import cv2
import base64
import numpy as np
from scipy.spatial import distance
from expiringdict import ExpiringDict

from edge_engine.common.logsetup import logger
from scripts.utils.infocenter import MongoLogger
from yolov5processor.infer import ExecuteInference
from scripts.utils.edge_utils import get_extra_fields
from edge_engine.ai.model.modelwraper import ModelWrapper
from scripts.utils.centroidtracker import CentroidTracker
from scripts.common.constants import JanusDeploymentConstants
from scripts.utils.image_utils import draw_circles_on_frame, resize_to_64_64
from scripts.utils.edge_utils import Utilities

from collections import deque
from scripts.utils.tracker import Tracker
from scripts.utils.helpers import box_iou2
from sklearn.utils.linear_assignment_ import linear_assignment


class CementBagCounter(ModelWrapper):

    def __init__(self, config, model_config, pubs, device_id):
        """
        Set up the detector, the trackers and all counting state.

        :param config: deployment configuration; only its "config" entry is kept
        :param model_config: detector / tracking tuning values. Keys read here:
                             gpu, agnostic_nms, iou, confidence, print_eu_dist,
                             uncounted_obj_length, uncounted_obj_age
        :param pubs: publisher holder; its ``rtp_write`` attribute publishes the video stream
        :param device_id: identifier of the device this counter is attached to
        """
        super().__init__()
        self.config = config["config"]
        self.device_id = device_id
        self.rtp = pubs.rtp_write
        self.mongo_logger = MongoLogger()
        self.frame_skip = self.config.get('frame_skip', False)

        # Detector: weights path is fixed, thresholds come from model_config.
        model = "data/ACC_v3.pt"
        self.yp = ExecuteInference(weight=model,
                                   gpu=model_config.get("gpu", False),
                                   agnostic_nms=model_config.get("agnostic_nms", True),
                                   iou=model_config.get("iou", 0.2),
                                   confidence=model_config.get("confidence", 0.4))
        # Maximum bag-to-print centroid distance for a print to be attributed to a bag.
        self.print_eu_dist = model_config.get('print_eu_dist', 200)

        # One centroid tracker for bags (ct1) and one for print labels (ct2).
        self.ct1 = CentroidTracker(maxDisappeared=5)
        self.ct2 = CentroidTracker(maxDisappeared=5)

        # State for the "detect on every other frame" mode.
        self.frame_skipping = {
            "skip_current_frame": True,
            "detection_value": None
        }

        # Counters.
        self.count = 0
        self.cement_bag = 0
        self.count_suraksha = 0
        self.count_whitecem = 0
        self.count_gold = 0
        self.mrp_counter = 0

        # NOTE(review): the value loaded here is overwritten with None a few lines
        # below, so the stored direction is never used. Both statements are kept to
        # preserve current behaviour -- confirm which one is intended.
        self.initial_object_position = Utilities.get_direction(self.device_id)

        # Kalman-tracker state: live tracks, ageing thresholds and the pool of free IDs.
        self.tracker_list = []
        self.max_age = 15
        self.min_hits = 10
        self.track_id_list = deque([str(i) for i in range(1, 50)])
        self.prev_annotation = []

        self.initial_object_position = None

        # Objects seen on the starting side of the line but not yet counted.
        self.uncounted_objects = ExpiringDict(max_len=model_config.get("uncounted_obj_length", 50),
                                              max_age_seconds=model_config.get("uncounted_obj_age", 60))
        # Single-entry expiring cache for the deployment metadata.
        self.janus_metadata = ExpiringDict(max_age_seconds=120, max_len=1)

    def _pre_process(self, x):
        """
        Do preprocessing here, if any
        :param x: payload
        :return: payload
        """
        return x

    def _post_process(self, x):
        """
        Apply post processing here, if any
        :param x: payload
        :return: payload
        """
        self.rtp.publish(x)  # video stream
        return x

    def send_payload(self, frame, label='CementBagDetected', bg_color="#474520", font_color="#FFFF00", alert_sound=None,
                     message="Cement Bag Detected!"):
        """
        Insert an event (with the frame embedded as a base64 JPEG data URI) into Mongo.

        :param frame: image to attach to the event
        :param label: value stored under "activity"
        :param bg_color: value stored under "bg_color"
        :param font_color: value stored under "font_color"
        :param alert_sound: value stored under "alert_sound"
        :param message: value stored under "message"
        :return: None
        """
        jpeg_buffer = cv2.imencode('.jpg', frame)[1]
        # tobytes() replaces the deprecated ndarray.tostring() (same bytes).
        frame_b64 = base64.b64encode(jpeg_buffer.tobytes()).decode("utf-8")

        payload = {"deviceId": self.device_id, "message": message,
                   "frame": 'data:image/jpeg;base64,' + frame_b64, "activity": label,
                   "bg_color": bg_color, "font_color": font_color, "alert_sound": alert_sound}

        self.mongo_logger.insert_attendance_event_to_mongo(payload)

    def track_bags(self, tracker_obj, dets, im0, filter_name, centroid_color=(255, 0, 0)):
        """
        Track the bags using Centroid based tracking
        :param dets: prediction output
        :param tracker_obj: prediction output
        :param filter_name: prediction output
        :param im0: raw frame
        :param centroid_color: color given to the centroid marking
        :return: centroid points, frame
        """

        bags = list()
        classes = list()
        for det in dets:
            if (det["class"] in filter_name):
                bags.append(np.array(det['points']).astype("int"))
                classes.append(det["class"])


        objects = tracker_obj.update(bags)
        objects.pop("frame", None)

        if centroid_color is not False:
            for (objectID, centroid) in objects.items():

                if centroid['has_print']:
                    centroid_color = (0, 255, 0)

                cv2.putText(im0, str(objectID), (centroid['centroid'][0] - 10, centroid['centroid'][1] - 10),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            1, centroid_color, 2, cv2.LINE_AA)
                cv2.circle(im0, (centroid['centroid'][0], centroid['centroid'][1]), 8, centroid_color, -1)
        return objects, classes, im0

    def kalman_tracker(
            self,
            bboxs,
            img,
    ):
        """
        Update the Kalman tracks with the boxes detected in the current frame.

        Detections are associated with existing tracks, matched tracks are corrected,
        unmatched detections start new tracks and unmatched tracks are aged. Tracks whose
        ``no_losses`` exceeds ``self.max_age`` are dropped and their IDs go back to the pool.

        :param bboxs: detected boxes, four values each, in the order produced by ``inference``
        :param img: current frame (returned unchanged)
        :return: (img, objects, boxs) where objects is a list of [track id, centroid] and
                 boxs the matching track boxes, both limited to confirmed tracks
        """
        z_box = bboxs
        x_box = [trk.box for trk in self.tracker_list]

        matched, unmatched_dets, unmatched_trks = self.assign_detections_to_trackers(x_box, z_box, iou_thrd=0.01)

        # Matched detections: correct the track with the measurement.
        for trk_idx, det_idx in matched:
            z = np.expand_dims(z_box[det_idx], axis=0).T
            tmp_trk = self.tracker_list[trk_idx]
            tmp_trk.kalman_filter(z)
            tmp_trk.box = self._box_from_state(tmp_trk)
            tmp_trk.hits += 1

        # Unmatched detections: start a new track for each of them.
        for idx in unmatched_dets:
            z = np.expand_dims(z_box[idx], axis=0).T
            tmp_trk = Tracker()  # Create a new tracker
            # State layout: each box value is followed by a 0 (presumably its velocity
            # -- confirm against Tracker).
            # NOTE(review): z[k] is a 1-element array here, so this literal mixes arrays
            # and scalars; recent NumPy versions may reject that -- verify.
            x = np.array([[z[0], 0, z[1], 0, z[2], 0, z[3], 0]]).T
            tmp_trk.x_state = x
            tmp_trk.predict_only()
            tmp_trk.box = self._box_from_state(tmp_trk)
            # Raises IndexError when the pool of free IDs is exhausted.
            tmp_trk.id = self.track_id_list.popleft()  # assign an ID for the tracker
            self.tracker_list.append(tmp_trk)

        # Unmatched tracks: no measurement this frame, so age them and predict only.
        for trk_idx in unmatched_trks:
            tmp_trk = self.tracker_list[trk_idx]
            tmp_trk.no_losses += 1
            tmp_trk.predict_only()
            tmp_trk.box = self._box_from_state(tmp_trk)

        # Report only tracks that were hit often enough and are not stale.
        objects = []
        boxs = []
        for trk in self.tracker_list:
            if (trk.hits >= self.min_hits) and (trk.no_losses <= self.max_age):
                x_cv2 = trk.box
                left, right, bottom = x_cv2[1], x_cv2[3], x_cv2[2]
                centroid = [int(left + ((right - left) / 2)), bottom]
                objects.append([int(trk.id), centroid])
                boxs.append(x_cv2)

        # Drop stale tracks and recycle their IDs.
        surviving_tracks = []
        for trk in self.tracker_list:
            if trk.no_losses > self.max_age:
                self.track_id_list.append(trk.id)
            else:
                surviving_tracks.append(trk)
        self.tracker_list = surviving_tracks

        logger.debug("object is %s", objects)
        return img, objects, boxs

    @staticmethod
    def _box_from_state(trk):
        """Return the four box values (state entries 0, 2, 4 and 6) of a tracker's state."""
        state = trk.x_state.T[0].tolist()
        return [state[0], state[2], state[4], state[6]]

    @staticmethod
    def assign_detections_to_trackers(
            trackers,
            detections,
            iou_thrd=0.3,
    ):
        """
        From current list of trackers and new detections, output matched detections,
        un matched trackers, unmatched detections.
        """
        iou_mat = np.zeros((len(trackers), len(detections)), dtype=np.float32)
        for t, trk in enumerate(trackers):
            for d, det in enumerate(detections):
                iou_mat[t, d] = box_iou2(trk, det)

        matched_idx = linear_assignment(-iou_mat)

        unmatched_trackers, unmatched_detections = [], []
        for t, trk in enumerate(trackers):
            if t not in matched_idx[:, 0]:
                unmatched_trackers.append(t)

        for d, det in enumerate(detections):
            if d not in matched_idx[:, 1]:
                unmatched_detections.append(d)

        matches = []

        for m in matched_idx:
            if iou_mat[m[0], m[1]] < iou_thrd:
                unmatched_trackers.append(m[0])
                unmatched_detections.append(m[1])
            else:
                matches.append(m.reshape(1, 2))

        if len(matches) == 0:
            matches = np.empty((0, 2), dtype=int)
        else:
            matches = np.concatenate(matches, axis=0)

        return matches, np.array(unmatched_detections), np.array(unmatched_trackers)




    def get_line_coordinates(self):
        """
        Get the counting-line alignment and coordinates from the deployment metadata.

        The metadata is cached in ``self.janus_metadata`` and fetched again through
        ``get_extra_fields`` when the cache has no (truthy) entry.

        :return: (alignment, coordinates) where coordinates holds one value per key in
                 ``JanusDeploymentConstants.LINE_COORDINATES``
        """
        # Read the expiring cache once and keep a local reference: the entry could
        # expire between a check and a later lookup, which would raise KeyError.
        metadata = self.janus_metadata.get('metadata')
        if not metadata:
            metadata = get_extra_fields(self.device_id)
            self.janus_metadata['metadata'] = metadata

        _coordinates = [metadata.get(coordinate_key) for coordinate_key in
                        JanusDeploymentConstants.LINE_COORDINATES]
        _alignment = metadata.get(JanusDeploymentConstants.ALIGNMENT_KEY)
        return _alignment, _coordinates

    def line_point_position(self, point):
        """
        Get the position of point w.r.t. the line
        :param point: point to be compared
        :return: boolean
        """
        _alignment, line_coordinates = self.get_line_coordinates()

        assert len(line_coordinates) == 4, "Line coordinates variable is invalid"
        assert len(point) == 2, "Point variable is invalid"

        _slope = (line_coordinates[3] - line_coordinates[1]) / (line_coordinates[2] - line_coordinates[0])
        _point_equation_value = point[1] - line_coordinates[1] - _slope * (point[0] - line_coordinates[0])
        if _point_equation_value > 0:
            return True
        else:
            return False

    def validate_point_position(self, point):
        """
        Check that the point lies strictly within the span of the counting line.

        For a vertical alignment the point's y is compared with the line's two y values,
        otherwise the point's x is compared with the line's two x values.

        :param point: centroid
        :return: bool
        """
        _alignment, line_coordinates = self.get_line_coordinates()
        assert _alignment in [JanusDeploymentConstants.VERTICAL, JanusDeploymentConstants.HORIZONTAL], \
            "Invalid alignment variable"

        if _alignment == JanusDeploymentConstants.VERTICAL:
            # Compare along y: coordinates are laid out as (x1, y1, x2, y2).
            value, first, second = point[1], line_coordinates[1], line_coordinates[3]
        else:
            value, first, second = point[0], line_coordinates[0], line_coordinates[2]

        # Strictly between the two ends, whichever of them is the smaller one.
        return bool(first < value < second or second < value < first)

    def update_bag_count(self, frame, detection_objects, classes):
        """
        Maintain the bag counts and draw them on the frame.

        An object seen on the starting side of the line is remembered as uncounted. When
        a remembered object shows up on the other side it is counted (total and per
        class) and an event is sent.

        :param frame: image
        :param detection_objects: mapping of object id to a dict with 'centroid' and 'has_print'
        :param classes: class names, paired positionally with detection_objects
        :return: the annotated frame
        """
        # NOTE(review): classes and detection_objects are paired by position only; if
        # they can differ in order or length, objects get the wrong class -- verify
        # against the caller and the tracker.
        for class_name, (objectID, centroid) in zip(classes, detection_objects.items()):
            point = centroid['centroid']

            if not self.validate_point_position(point):
                continue
            logger.debug("centroid detected")

            # The side of the first object ever seen defines the "starting" side.
            if not isinstance(self.initial_object_position, bool):
                logger.debug("Initializing the initial object position")
                self.initial_object_position = self.line_point_position(point=point)
                Utilities.set_direction(self.device_id, self.initial_object_position)
                logger.debug(self.initial_object_position)

            _point_position = self.line_point_position(point=point)
            # The previous call passed the id as a format argument without a placeholder.
            logger.debug("object ID is : %s", objectID)
            logger.debug(self.uncounted_objects)

            if _point_position == self.initial_object_position:
                # Still on the starting side: remember the object until it crosses.
                logger.debug("same side only")
                if objectID not in self.uncounted_objects:
                    self.uncounted_objects[objectID] = point
                frame = draw_circles_on_frame(frame, point, radius=10, color=(0, 0, 255),
                                              thickness=-1)

            elif objectID in self.uncounted_objects:
                # Crossed the line: count it exactly once.
                logger.debug("different side")
                self.uncounted_objects.pop(objectID, None)
                # Total count; it is reported in the log lines below and was never
                # incremented before.
                self.count += 1
                # NOTE(review): _predict filters on "ambuja_buildcem", which has no
                # counter here, while "ambuja_whitecem" is counted -- confirm the names.
                if class_name == "acc_gold":
                    self.count_gold += 1
                    logger.debug(self.count_gold)
                elif class_name == "acc_suraksha_plus":
                    self.count_suraksha += 1
                    logger.debug(self.count_suraksha)
                elif class_name == "ambuja_whitecem":
                    self.count_whitecem += 1
                    logger.debug(self.count_whitecem)
                frame = draw_circles_on_frame(frame, point, radius=10, color=(0, 255, 0),
                                              thickness=-1)
                if centroid['has_print']:
                    self.send_payload(resize_to_64_64(frame=frame), message='Print Detected!')
                    logger.info(f"Count: {self.count}, Print Found: True")
                else:
                    self.send_payload(resize_to_64_64(frame=frame), message='Print Missing!')
                    logger.info(f"Count: {self.count}, Print Found: False")
            else:
                # On the far side and not pending: already counted or never registered.
                frame = draw_circles_on_frame(frame, point, radius=10, color=(0, 255, 0),
                                              thickness=-1)

        # Overlay the per-class counters at fixed pixel positions.
        count_text_gold = "ACC_GOLD: " + str(self.count_gold)
        count_text_suraksha = "ACC_SURAKSHA_P_PLUS: " + str(self.count_suraksha)
        count_text_whitecem = "PPC_WHITE: " + str(self.count_whitecem)
        cv2.putText(frame, count_text_gold, (1300, 200), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 3,
                    cv2.LINE_AA)
        cv2.putText(frame, count_text_suraksha, (1300, 400), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 3,
                    cv2.LINE_AA)
        cv2.putText(frame, count_text_whitecem, (1300, 600), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 3,
                    cv2.LINE_AA)
        return frame

    def draw_line_over_image(self, frame, color=(255, 255, 255)):
        """
        Draws line over the counting line
        :param frame: frame for
        :param color:
        :return:
        """
        _alignment, line_coordinates = self.get_line_coordinates()
        assert len(line_coordinates) == 4, "Line coordinates variable is invalid"

        # return cv2.line(frame, (line_coordinates[0], line_coordinates[1]), (line_coordinates[2], line_coordinates[3]),
        #                 color, 3)

        self.drawline(frame, (line_coordinates[0], line_coordinates[1]), (line_coordinates[2],
                                                                          line_coordinates[3]), color, thickness=3)
        return frame

    @staticmethod
    def drawline(img, pt1, pt2, color, thickness=1, style='dotted', gap=20):
        dist = ((pt1[0] - pt2[0]) ** 2 + (pt1[1] - pt2[1]) ** 2) ** .5
        pts = []
        for i in np.arange(0, dist, gap):
            r = i / dist
            x = int((pt1[0] * (1 - r) + pt2[0] * r) + .5)
            y = int((pt1[1] * (1 - r) + pt2[1] * r) + .5)
            p = (x, y)
            pts.append(p)

        if style == 'dotted':
            for p in pts:
                cv2.circle(img, p, thickness, color, -1)
        else:
            s = pts[0]
            e = pts[0]
            i = 0
            for p in pts:
                s = e
                e = p
                if i % 2 == 1:
                    cv2.line(img, s, e, color, thickness)
                i += 1

    def distances(self, objs1, objs2):
        for key1, val1 in objs1.items():
            for key2, val2 in objs2.items():
                dst = distance.euclidean(val1['centroid'], val2['centroid'])
                if objs1[key1]['has_print']:
                    self.mrp_counter += 1
                    if(self.mrp_counter >= 5):
                        #STOP THE RELAY
                        pass
                    continue
                elif dst < self.print_eu_dist:
                    objs1[key1]['has_print'] = True
                    self.mrp_counter = 0

    # def inference(
    #         self,
    #         frame,
    #         classes,
    #
    # ):
    #     dets = self.yp.predict(frame)
    #     class_name = list()
    #     bboxs = []
    #
    #     if dets:
    #         for i in dets:
    #             if(i["class"] in classes):
    #                 class_name.append(i["class"])
    #             bboxs.append([i["points"][1], i["points"][0], i["points"][3], i["points"][2]])
    #
    #     print("#######")
    #     print(bboxs)
    #     #frame = cv2.rectangle(frame, (bboxs[0][0], bboxs[0][1]), (bboxs[0][2], bboxs[0][3]),(255, 255, 0) , 2)
    #     return bboxs, frame, dets, class_name


    def inference(
            self,
            frame,
    ):
        dets = self.yp.predict(frame)
        bboxs = []
        if dets:
            for i in dets:
                bboxs.append([i["points"][1], i["points"][0], i["points"][3], i["points"][2]])
        return bboxs, frame, dets

    def _predict(self, obj):
        """
        Process one frame: detect, track, count and annotate.

        On any error the exception is logged and the original frame is only resized.

        :param obj: payload whose 'frame' entry holds the image; it is replaced by the
                    annotated frame resized to FRAME_WIDTH x FRAME_HEIGHT from the config
        :return: the same payload
        """
        # NOTE(review): update_bag_count counts "ambuja_whitecem" but this list holds
        # "ambuja_buildcem" -- confirm which class name the model emits.
        class_list = ["acc_gold", "acc_suraksha_plus", "ambuja_buildcem"]
        try:
            frame = obj['frame']

            if self.frame_skip:
                # Alternate between running the detector and reusing its last result.
                if not self.frame_skipping["skip_current_frame"]:
                    dets = self.yp.predict(frame)
                    self.frame_skipping["detection_value"] = dets
                    self.frame_skipping["skip_current_frame"] = True
                else:
                    dets = self.frame_skipping["detection_value"]
                    self.frame_skipping["skip_current_frame"] = False
            else:
                # Run the detector once and reuse its output below; the frame is not
                # modified in between, so a second predict() would repeat the same work.
                bboxs, frame, dets = self.inference(frame)
                logger.debug("inference boxes: %s", bboxs)
                logger.debug("inference detections: %s", dets)
                frame, tracked_objects, tracked_boxes = self.kalman_tracker(bboxs, frame)
                logger.debug("kalman objects: %s", tracked_objects)
                logger.debug("kalman boxes: %s", tracked_boxes)

            if dets is None:
                # No result yet (first skipped frame) or nothing returned by the detector.
                dets = []

            frame = self.draw_line_over_image(frame)
            mrp = ["mrp"]
            objects, classes_cement, frame = self.track_bags(self.ct1, dets, frame, class_list)
            mrp_objects, _mrp_classes, frame = self.track_bags(self.ct2, dets, frame, mrp)
            frame = self.update_bag_count(frame=frame, detection_objects=objects, classes=classes_cement)
            # NOTE(review): this opens a GUI window and fails without a display; the
            # error is then caught below and the rest of this frame's processing is skipped.
            cv2.imshow("output is ", cv2.resize(frame, (900, 600)))
            cv2.waitKey(1)
            self.distances(objects, mrp_objects)
            logger.debug("self.uncounted_objects --> {}".format(self.uncounted_objects))

            obj['frame'] = cv2.resize(frame, (self.config.get('FRAME_WIDTH'), self.config.
                                              get('FRAME_HEIGHT')))
        except Exception as e:
            logger.exception(f"Error: {e}", exc_info=True)
            obj['frame'] = cv2.resize(obj['frame'], (self.config.get('FRAME_WIDTH'), self.config.get('FRAME_HEIGHT')))

        return obj
