import json
from abc import abstractmethod

import paho.mqtt.client as mqtt
from loguru import logger


class BaseMQTTUtil(object):
    """Shared configuration and connection handling for paho-mqtt utilities.

    The constructor builds a ``paho.mqtt.client.Client`` from the given
    settings and immediately calls ``connect()`` on it. Subclasses provide the
    messaging behaviour by implementing ``on_message``, ``on_connect``,
    ``on_disconnect``, ``publish`` and ``subscribe``.

    NOTE(review): this class does not use ``abc.ABCMeta``, so the
    ``@abstractmethod`` markers below document intent but are not enforced
    when the class (or an incomplete subclass) is instantiated.
    """

    def __init__(
        self,
        client_id="",
        host="localhost",
        port=1883,
        username=None,
        password=None,
        ssl=False,
        certificate_type=None,
        connection_type="tcp",
    ) -> None:
        """Store connection settings, build the client and connect it.

        :param client_id: passed to paho as the MQTT client identifier.
        :param host: broker host used by ``connect()``.
        :param port: broker port used by ``connect()``.
        :param username: user name; credentials are only configured when both
            ``username`` and ``password`` are not None.
        :param password: password (see ``username``).
        :param ssl: when truthy, TLS is configured on the client via ``tls_set()``.
        :param certificate_type: stored on the instance; not used by this class.
        :param connection_type: passed to paho as the ``transport`` (default "tcp").
        """
        super().__init__()
        self.client_id = client_id
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.ssl = ssl
        self.certificate_type = certificate_type
        self.connection_type = connection_type
        # TLS material forwarded to tls_set(); subclasses may populate these in
        # prepare_tls_set_args(). None values are passed through unchanged.
        self.ca_cert = None
        self.certfile = None
        self.keyfile = None
        # Identity mapping of disconnect return codes, used in log messages.
        self.connection_error_codes = {1: 1, 2: 2, 3: 3, 4: 4, 5: 5}
        self._client = self.create_client()
        # (topic, qos) pairs accumulated by subscribe().
        self.topiclist = []
        # Callback that receives incoming data; set via executor_function().
        self.func = None
        self._client.connect(host=self.host, port=self.port, keepalive=60)

    def create_client(self):
        """Build a paho client from the stored settings.

        TLS (when ``self.ssl`` is set) and username/password credentials (when
        both are given) are applied before the client is returned.

        :returns: a configured, not yet connected ``mqtt.Client``.
        """
        client = mqtt.Client(
            client_id=self.client_id, clean_session=True, transport=self.connection_type
        )
        if self.ssl:
            # Let subclasses fill in ca_cert/certfile/keyfile before they are used.
            self.prepare_tls_set_args()
            # Previously tls_set() was called without arguments, so the
            # certificate attributes were silently ignored.
            client.tls_set(
                ca_certs=self.ca_cert, certfile=self.certfile, keyfile=self.keyfile
            )
        if self.username is not None and self.password is not None:
            logger.info("Configuring Credentials for MQTT")
            client.username_pw_set(username=self.username, password=self.password)
        return client

    def executor_function(self, f):
        """
        Function to be executed for the data after subscribing to topics
        - This function should always be called before subscribing to a topic
        """
        self.func = f

    @property
    def client(self):
        """The underlying paho ``Client`` instance."""
        return self._client

    def disconnect(self):
        """Disconnect from the broker and stop the background network loop, if any."""
        self._client.disconnect()
        # Join a loop_start() thread (after it flushes the DISCONNECT packet) and
        # clear paho's thread handle so a later loop_start() can start a fresh
        # one. Returns an error code, harmlessly, when no thread was started.
        self._client.loop_stop()

    def f(self):
        """Placeholder with no behaviour."""
        ...

    def prepare_tls_set_args(self):
        """Hook called before ``tls_set()``; may set ca_cert/certfile/keyfile."""
        ...

    @abstractmethod
    def on_message(self, client, userdata, msg):
        """paho ``on_message`` callback; implemented by subclasses."""
        ...

    @abstractmethod
    def on_connect(self, client, userdata, flags, rc):
        """paho ``on_connect`` callback; implemented by subclasses."""
        ...

    @abstractmethod
    def on_disconnect(self, client, userdata, rc=0):
        """paho ``on_disconnect`` callback; implemented by subclasses."""
        ...

    @abstractmethod
    def publish(self, topic, payload=None, qos=0, retain=False):
        """Publish ``payload`` to ``topic``; implemented by subclasses."""
        ...

    @abstractmethod
    def subscribe(self, topic, qos=0):
        """Subscribe to ``topic``; implemented by subclasses."""
        ...


class MQTTUtil(BaseMQTTUtil):
    """
    ### Usage:
    ----------

    #### Subscribing to a topic
        >>> mqtt_obj = MQTTUtil(host='localhost', port=1883)
        >>> mqtt_obj.executor_function(print)
        >>> mqtt_obj.subscribe(topic="mqtt/topic", qos=0, return_type="payload")

    #### Publishing to a topic
        >>> mqtt_obj = MQTTUtil(host='localhost', port=1883)
        >>> mqtt_obj.publish(topic="mqtt/topic", payload="data", qos=0, retain=False)

    """

    def __init__(self, **kwargs) -> None:
        """Forward all keyword arguments to ``BaseMQTTUtil.__init__``."""
        super().__init__(**kwargs)

    def on_message(self, client, userdata, msg):
        """paho callback: decode ``msg`` and hand it to the executor function.

        With ``return_type == "payload"`` the payload is decoded as UTF-8 JSON
        and the parsed value is passed on. With ``"all"`` a dict of the decoded
        payload, topic, qos and the raw message object is passed on. Messages
        that cannot be decoded are logged and skipped: an exception raised here
        is re-raised by paho inside its network thread and would stop all
        further message delivery.

        :raises TypeError: if ``self.return_type`` is not "payload" or "all".
        """
        logger.trace("Message received on high priority channel")
        if self.return_type == "payload":
            try:
                # ValueError covers both UnicodeDecodeError and JSONDecodeError.
                data = json.loads(msg.payload.decode("utf-8"))
            except ValueError as err:
                logger.warning(
                    "Ignoring MQTT message on topic '{}': {}".format(msg.topic, err)
                )
                return
            self.func(data)
        elif self.return_type == "all":
            try:
                text = msg.payload.decode("utf-8")
            except UnicodeDecodeError as err:
                logger.warning(
                    "Ignoring MQTT message on topic '{}': {}".format(msg.topic, err)
                )
                return
            self.func(
                {
                    "data": text,
                    "topic": msg.topic,
                    "qos": msg.qos,
                    "msg": msg,
                }
            )
        else:
            raise TypeError("Unsupported return type for the executor function")

    def on_connect(self, client, userdata, flags, rc):
        """paho callback: (re)subscribe to every registered topic after a connect.

        A non-zero ``rc`` means the broker refused the connection, so nothing is
        subscribed and an error is logged instead of a success message.
        """
        if rc != 0:
            logger.error("MQTT connection refused, return code {}".format(rc))
            return
        logger.info("Successfully connected to (MQTT)")
        client.subscribe(self.topiclist)
        logger.debug(
            "Agent has subscribed to the MQTT topic '{}'".format(self.topiclist)
        )

    def on_disconnect(self, client, userdata, rc=0):
        """paho callback invoked when the connection ends.

        ``rc == 0`` is paho's signal that ``disconnect()`` was requested, so the
        client is left disconnected. Any other code is logged and an immediate
        reconnect of the reporting client is attempted.
        """
        if rc == 0:
            logger.info("MQTT client disconnected")
            return
        # .get(): paho can report codes outside the 1-5 mapping (e.g. 7 for a
        # lost connection); a KeyError would be raised inside paho's thread.
        reason = self.connection_error_codes.get(rc, rc)
        logger.warning("MQTT lost connection: {}".format(reason))
        try:
            # Reconnect the client that reported the drop (the callback argument).
            client.reconnect()
        except (OSError, ValueError) as err:
            # Broker still unreachable: don't let the exception kill the network
            # thread; paho's loop_start() thread also retries on its own.
            logger.warning("MQTT reconnect failed: {}".format(err))

    def _connect(self):
        """Connect the existing client to the broker and start its network loop."""
        logger.info("client Not connected\nConnecting client to Host...")
        # Reusing the client keeps its credentials, TLS settings and callbacks;
        # connect() closes any stale socket before opening a new one.
        self._client.connect(host=self.host, port=self.port, keepalive=60)
        # Without a running loop the CONNACK is never processed, is_connected()
        # stays False and every publish would reconnect. loop_start() only
        # returns an error code if the loop thread is already running.
        self._client.loop_start()
        logger.info("client connected")

    def publish(self, topic, payload=None, qos=0, retain=False):
        """Publish ``payload`` to ``topic``, connecting first if necessary.

        :returns: the ``MQTTMessageInfo`` returned by paho's ``publish``.
        """
        if not self._client.is_connected():
            self._connect()
        return self._client.publish(topic, payload=payload, qos=qos, retain=retain)

    def subscribe(self, topic, qos=0, return_type="payload"):
        """
        :param -> return_type = "payload" | "all"
            payload: returns decoded subscribed mqtt message
            all: returns a dict of all keys of mqtt message (with a mqtt msg object)

        :raises TypeError: if ``return_type`` is not "payload" or "all".
        :raises ModuleNotFoundError: if ``executor_function`` was not called first.
        """
        # Validate here, in the caller's thread, rather than failing later
        # inside the paho network thread on the first incoming message.
        if return_type not in ("payload", "all"):
            raise TypeError("Unsupported return type for the executor function")
        if self.func is None:
            raise ModuleNotFoundError(
                "Executor Function is not set.\ncall executor function and then"
                " pass the function to be executed for subscribed topic"
            )
        self.return_type = return_type
        self.topiclist.append((topic, qos))
        # Register callbacks on every call: the connection may have been opened
        # by publish(), which does not install them.
        self._client.on_connect = self.on_connect
        self._client.on_disconnect = self.on_disconnect
        self._client.on_message = self.on_message
        if self._client.is_connected():
            # Live session: subscribe just the new topic. on_connect re-subscribes
            # the whole topiclist after any reconnect, so no teardown (and no
            # reinitialise(), which discarded credentials/TLS) is needed.
            self._client.subscribe((topic, qos))
        else:
            self._connect()

    def prepare_tls_set_args(self):
        """No-op TLS hook; override to set ca_cert/certfile/keyfile."""
        ...
