import socket
import threading
import logging
import time

logger = logging.getLogger('elbus')

# first byte of the server greeting checked by ElbusClient.connect()
GREETINGS = 0xEB

# accepted protocol version (greeting bytes 1-2, little-endian)
PROTOCOL_VERSION = 1

# frame operation codes
OP_NOP = 0x00
OP_PUBLISH = 0x01
OP_SUBSCRIBE = 0x02
OP_UNSUBSCRIBE = 0x03
OP_MESSAGE = 0x12
OP_BROADCAST = 0x13
OP_ACK = 0xFE

# response code expected from the server during the handshake
RESPONSE_OK = 0x01

# error codes 0x71..0x77, in the order listed
(ERR_CLIENT_NOT_REGISTERED, ERR_DATA, ERR_IO, ERR_OTHER, ERR_NOT_SUPPORTED,
 ERR_BUSY, ERR_NOT_DELIVERED) = range(0x71, 0x78)

# keep-alive frame: an all-zero 9-byte header (4-byte frame id, 1 flags
# byte, 4-byte body length), the header layout send_message() writes
PING_FRAME = bytes(9)


def on_message_default(message):
    """Default ``on_message`` callback: accept the message and ignore it."""


class ElbusClient:
    """Client for an elbus server reachable over a UNIX or TCP socket.

    After :meth:`connect` succeeds, two daemon threads run. A reader thread
    completes pending outgoing frames when an OP_ACK arrives and passes every
    other frame, as a Message, to ``self.on_message``. A pinger thread sends
    ``PING_FRAME`` every ``ping_interval`` seconds.
    """

    def __init__(self, path, name):
        """Configure (but do not open) a client.

        Args:
            path: UNIX socket path (ends with ``.sock``/``.socket`` or starts
                with ``/``), otherwise ``host:port`` for TCP.
            name: client name sent to the server during the handshake.
        """
        self.path = path
        self.socket = None
        self.buf_size = 8192
        self.name = name
        self.frame_id = 0
        self.ping_interval = 1
        self.on_message = on_message_default
        # held around sendall() by the pinger and send_message(), and around
        # every access to self.frames
        self.socket_lock = threading.Lock()
        # serializes connect(), disconnect() and I/O-thread failure handling
        self.mgmt_lock = threading.Lock()
        self.connected = False
        # frame_id -> OutgingFrame waiting for an OP_ACK (qos > 0 only)
        self.frames = {}
        self.timeout = 5

    def connect(self):
        """Connect, perform the handshake and start the I/O threads.

        Raises:
            ValueError: a TCP ``path`` is not in ``host:port`` form.
            RuntimeError: unsupported protocol or protocol version, or the
                server rejected the handshake.
            socket.timeout: no server answer within ``self.timeout`` seconds.
            OSError: connecting failed; this includes ConnectionError when
                the server closes the connection during the handshake.
        """
        with self.mgmt_lock:
            if (self.path.endswith(('.sock', '.socket')) or
                    self.path.startswith('/')):
                family, address = socket.AF_UNIX, self.path
            else:
                family = socket.AF_INET
                address = self._parse_tcp_address(self.path)
            sock = socket.socket(family, socket.SOCK_STREAM)
            self.socket = sock
            try:
                if family == socket.AF_INET:
                    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF,
                                self.buf_size)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF,
                                self.buf_size)
                # set the timeout before connect() so an unreachable server
                # cannot block this call indefinitely
                sock.settimeout(self.timeout)
                sock.connect(address)
                self._handshake()
            except BaseException:
                # do not leak the descriptor of a failed connection attempt
                sock.close()
                raise
            self.connected = True
            threading.Thread(target=self._t_reader, daemon=True).start()
            threading.Thread(target=self._t_pinger, daemon=True).start()

    @staticmethod
    def _parse_tcp_address(path):
        """Split ``host:port`` on its LAST colon into ``(host, int(port))``."""
        host, sep, port = path.rpartition(':')
        if not sep:
            raise ValueError(f'invalid elbus TCP address: {path!r}')
        return host, int(port)

    def _handshake(self):
        """Run the greeting / version / client-name exchange on self.socket."""
        greeting = self._read_handshake(3)
        if greeting[0] != GREETINGS:
            raise RuntimeError('Unsupported protocol')
        if int.from_bytes(greeting[1:3], 'little') != PROTOCOL_VERSION:
            raise RuntimeError('Unsupported protocol version')
        # the received greeting is sent back unchanged
        self.socket.sendall(greeting)
        self._check_handshake_response()
        name = self.name.encode()
        self.socket.sendall(len(name).to_bytes(2, 'little') + name)
        self._check_handshake_response()

    def _read_handshake(self, data_len):
        """Read exactly ``data_len`` bytes or raise socket.timeout."""
        data = self.read_exact(data_len)
        if len(data) < data_len:
            # read_exact() returns early on timeout while not connected
            raise socket.timeout('elbus handshake timed out')
        return data

    def _check_handshake_response(self):
        """Read one response byte; raise RuntimeError unless RESPONSE_OK."""
        code = self._read_handshake(1)[0]
        if code != RESPONSE_OK:
            raise RuntimeError(f'Server response: {hex(code)}')

    def handle_daemon_exception(self):
        """Tear down the connection after an I/O thread failed.

        Must be called from inside an ``except`` clause. If the client is
        still connected, the socket is closed and the active exception is
        re-raised, which ends the calling thread. If the client is already
        disconnected (e.g. via :meth:`disconnect`), the error is ignored.
        """
        with self.mgmt_lock:
            if self.connected:
                self.socket.close()
                self.connected = False
                # NOTE(review): frames still in self.frames are never
                # completed after this; wait_completed() without a timeout
                # blocks forever on them
                raise

    def _t_pinger(self):
        """Pinger thread: send PING_FRAME every ``ping_interval`` seconds."""
        try:
            while True:
                time.sleep(self.ping_interval)
                with self.socket_lock:
                    self.socket.sendall(PING_FRAME)
        except BaseException:
            self.handle_daemon_exception()

    def _t_reader(self):
        """Reader thread: dispatch incoming frames until the connection ends."""
        try:
            while True:
                # 6-byte header: opcode, 4-byte little-endian field
                # (frame id for OP_ACK, body length otherwise), 1 more byte
                header = self.read_exact(6)
                if len(header) < 6:
                    # short reads only happen once the client is disconnected
                    return
                op = header[0]
                if op == OP_NOP:
                    continue
                if op == OP_ACK:
                    self._complete_frame(header)
                    continue
                data_len = int.from_bytes(header[1:5], 'little')
                data = self.read_exact(data_len)
                if len(data) < data_len:
                    return
                message = self._parse_message(op, data)
                if message is None:
                    continue
                try:
                    self.on_message(message)
                except Exception:
                    # a faulty handler must not stop the reader thread
                    logger.exception('on_message handler failed')
        except BaseException:
            self.handle_daemon_exception()

    def _complete_frame(self, header):
        """Complete the pending OutgingFrame acknowledged by an OP_ACK header.

        Acknowledgements for unknown frame ids are ignored.
        """
        op_id = int.from_bytes(header[1:5], 'little')
        with self.socket_lock:
            frame = self.frames.pop(op_id, None)
            if frame is not None:
                frame.result = header[5]
                frame.completed.set()

    @staticmethod
    def _parse_message(op, data):
        """Build a Message from a frame body, or return None if malformed.

        OP_PUBLISH bodies are sender, topic and payload separated by NUL
        bytes; all other bodies are sender and payload separated by NUL.
        """
        message = Message()
        message.type = op
        try:
            if op == OP_PUBLISH:
                sender, topic, message.payload = data.split(b'\x00',
                                                            maxsplit=2)
                message.topic = topic.decode()
            else:
                sender, message.payload = data.split(b'\x00', maxsplit=1)
                message.topic = None
            message.sender = sender.decode()
        except ValueError as e:
            # missing separator, or UnicodeDecodeError (a ValueError)
            logger.error('Invalid message from the server: %s', e)
            return None
        return message

    def disconnect(self):
        """Close the connection.

        The I/O threads stop once their next socket operation fails. Safe to
        call on a client that never connected.
        """
        with self.mgmt_lock:
            if self.socket is not None:
                self.socket.close()
            self.connected = False

    def read_exact(self, data_len):
        """Read ``data_len`` bytes from the socket.

        Socket timeouts are retried while the client is connected. Once it is
        not, a timeout ends the read early, and the (possibly short) data
        received so far is returned.

        Raises:
            ConnectionError: the peer closed the connection.
        """
        chunks = []
        received = 0
        while received < data_len:
            try:
                chunk = self.socket.recv(
                    min(data_len - received, self.buf_size))
            except socket.timeout:
                if not self.connected:
                    break
                continue
            if not chunk:
                # recv() returns b'' only at end of stream; looping on it
                # would spin forever
                raise ConnectionError('elbus connection closed by peer')
            chunks.append(chunk)
            received += len(chunk)
        return b''.join(chunks)

    def send_message(self, target, message):
        """Send ``message`` and return an OutgingFrame tracking it.

        Args:
            target: destination name (str). It is unused for OP_SUBSCRIBE and
                OP_UNSUBSCRIBE, whose body is the NUL-joined topic list.
            message: object with ``type`` and ``qos``, plus ``topic`` (a str
                or list of str) for (un)subscribe frames, or ``payload`` (str
                or bytes) for all other frames.

        Returns:
            OutgingFrame. When ``message.qos > 0`` it is registered in
            ``self.frames`` and completed by the reader thread on OP_ACK.

        Raises:
            ValueError: the encoded body is longer than 0xffff_ffff bytes.
        """
        with self.socket_lock:
            self.frame_id += 1
            if self.frame_id > 0xffff_ffff:
                self.frame_id = 1
            flags = message.type | message.qos << 6
            if message.type in (OP_SUBSCRIBE, OP_UNSUBSCRIBE):
                topics = (message.topic if isinstance(message.topic, list)
                          else [message.topic])
                parts = [b'\x00'.join(t.encode() for t in topics)]
            else:
                payload = message.payload
                if isinstance(payload, str):
                    payload = payload.encode()
                parts = [target.encode(), b'\x00', payload]
            # measure the ENCODED bytes: len() of a str counts characters,
            # which under-reports non-ASCII text and desyncs the stream
            body_len = sum(map(len, parts))
            if body_len > 0xffff_ffff:
                raise ValueError('frame too large')
            header = (self.frame_id.to_bytes(4, 'little') +
                      flags.to_bytes(1, 'little') +
                      body_len.to_bytes(4, 'little'))
            self.socket.sendall(b''.join([header, *parts]))
            frame = OutgingFrame(message.qos)
            if message.qos > 0:
                self.frames[self.frame_id] = frame
            return frame

    def subscribe(self, topics):
        """Subscribe to ``topics`` (str or list of str); see send_message."""
        message = Message(tp=OP_SUBSCRIBE)
        message.topic = topics
        return self.send_message(None, message)

    def unsubscribe(self, topics):
        """Unsubscribe from ``topics`` (str or list of str)."""
        message = Message(tp=OP_UNSUBSCRIBE)
        message.topic = topics
        return self.send_message(None, message)


class OutgingFrame:
    """Handle for a frame sent by ElbusClient.send_message.

    For ``qos > 0`` the frame is completed by the client's reader thread when
    an OP_ACK for it arrives. ``result`` is then set to the code byte carried
    in that acknowledgement. Frames with ``qos <= 0`` are not tracked.
    """

    def __init__(self, qos):
        self.qos = qos
        # stays 0 until an acknowledgement sets it
        self.result = 0
        if qos > 0:
            self.completed = threading.Event()

    def is_completed(self):
        """Return True once acknowledged; always True when ``qos <= 0``."""
        if self.qos > 0:
            return self.completed.is_set()
        return True

    def wait_completed(self, *args, **kwargs):
        """Block until the frame is acknowledged, then return ``result``.

        Arguments are forwarded to ``threading.Event.wait`` (e.g. a
        ``timeout``). On timeout the current ``result`` is returned, which is
        0 if no acknowledgement has arrived. For ``qos <= 0`` there is nothing
        to wait for, so ``result`` is returned immediately rather than failing
        with AttributeError.
        """
        if self.qos > 0:
            self.completed.wait(*args, **kwargs)
        return self.result


class Message:
    """A message sent to or received from the bus.

    The constructor sets ``type`` (an operation code), ``qos`` and, only when
    one is given, ``payload``. Other attributes such as ``topic`` and
    ``sender`` are assigned afterwards by the code that uses the message.
    """

    def __init__(self, payload=None, tp=OP_MESSAGE, qos=1):
        has_payload = payload is not None
        if has_payload:
            # left unset otherwise, so hasattr(msg, 'payload') is False
            self.payload = payload
        self.type, self.qos = tp, qos
