import cv2
import base64
import datetime
import traceback
import numpy as np
from expiringdict import ExpiringDict
import time
from edge_engine.common.logsetup import logger
from scripts.utils.infocenter import MongoLogger
#from yolov5processor.infer import ExecuteInference
from scripts.utils.edge_utils import get_extra_fields
from edge_engine.ai.model.modelwraper import ModelWrapper
from scripts.utils.centroidtracker import CentroidTracker
from scripts.common.constants import JanusDeploymentConstants
from scripts.utils.image_utils import draw_circles_on_frame, resize_to_64_64
import json


class CementBagCounter(ModelWrapper):
    """
    Counts cement bags that cross a configured counting line.

    Detections are tracked with a centroid tracker. A tracked object inside the span of
    the line is remembered while it is on the same side as the first object ever seen.
    It is counted once it later appears on the opposite side while still remembered.
    For ``videofile`` sources the detections are replayed from ``cement_bag.json``.
    """

    def __init__(self, config, model_config, pubs, device_id):
        """
        init function

        :param config: pipeline configuration; ``config['inputConf']['sourceType']`` selects the
                       detection source and ``config['config']`` supplies FRAME_WIDTH/FRAME_HEIGHT
        :param model_config: tracker/counter settings (``ct_frames``, ``uncounted_obj_length``,
                             ``uncounted_obj_age``, ``skip_alternative_frames``, ``frame_delay``)
        :param pubs: publishers; ``pubs.rtp_write`` receives the processed video stream
        :param device_id: device identifier used to look up deployment fields and tag events
        """
        super().__init__()
        self.type = config['inputConf']['sourceType']
        if self.type == 'videofile':
            # Video-file sources replay pre-computed detections instead of running a model.
            # The context manager closes the file even if the JSON is malformed.
            with open('cement_bag.json', "r") as f:
                self.dets = json.load(f)
        self.config = config["config"]
        self.device_id = device_id
        self.rtp = pubs.rtp_write
        self.mongo_logger = MongoLogger()
        # NOTE: live inference is currently disabled; non-videofile sources cannot produce
        # detections until the model below is re-enabled (see ``_get_detections``).
        #model = get_extra_fields(self.device_id).get(JanusDeploymentConstants.MODEL_KEY,
        #                                             "/app/data/yolov5_s_v2.pt")
        #self.yp = ExecuteInference(weight=model,
        #                           gpu=model_config.get("gpu", False),
        #                           agnostic_nms=model_config.get("agnostic_nms", False),
        #                           iou=model_config.get("iou", 0.5),
        #                           confidence=model_config.get("confidence", 0.6))
        self.ct = CentroidTracker(maxDisappeared=model_config.get("ct_frames", 1))
        self.classes = {0: 'Cement Bag'}
        # Running bag count; stays None until the first object is seen within the line span.
        self.count = None
        # Side of the line (bool from line_point_position) on which the first object was seen.
        self.initial_object_position = None
        # Objects seen on the initial side and not yet counted; bounded in size and age.
        self.uncounted_objects = ExpiringDict(max_len=model_config.get("uncounted_obj_length", 50),
                                              max_age_seconds=model_config.get("uncounted_obj_age", 60))
        # Pause (seconds) before processing each frame; default matches the former hard-coded value.
        self.frame_delay = model_config.get("frame_delay", 0.05)
        # When enabled, detections are computed on every other frame and reused on the next one.
        self.frame_skipping = {
            "to_skip": model_config.get("skip_alternative_frames", False),
            "skip_current_frame": False,
            "detection_value": None
        }

    def _pre_process(self, x):
        """
        Do preprocessing here, if any
        :param x: payload
        :return: payload, unchanged
        """
        return x

    def _post_process(self, x):
        """
        Publish the payload to the RTP video stream
        :param x: payload
        :return: payload, unchanged
        """
        self.rtp.publish(x)  # video stream
        return x

    def track_bags(self, dets, im0, centroid_color=(255, 0, 0)):
        """
        Track the bags using Centroid based tracking
        :param dets: prediction output; each detection's ``points`` holds at least four values
                     (presumably x1, y1, x2, y2 -- confirm against the detector)
        :param im0: raw frame; centroids are drawn on it in place
        :param centroid_color: color given to the centroid marking, or False to skip drawing
        :return: (mapping of object id -> centroid from the tracker, frame)
        """
        bags = []
        for det in dets:
            points = det['points']
            bags.append(np.array([points[0], points[1], points[2], points[3]]).astype("int"))
        objects = self.ct.update(bags)
        if centroid_color is not False:
            for centroid in objects.values():
                # Plain Python ints: some OpenCV builds reject numpy integer scalars as a point.
                cv2.circle(im0, tuple(int(c) for c in centroid), 4, centroid_color, -1)
        return objects, im0

    def get_line_coordinates(self):
        """
        Get the line coordinates from the deployment JSON
        :return: (alignment, [coordinate values in JanusDeploymentConstants.LINE_COORDINATES order]);
                 missing keys come back as None
        """
        _janus_deployment = get_extra_fields(self.device_id)
        _coordinates = [_janus_deployment.get(coordinate_key) for coordinate_key in
                        JanusDeploymentConstants.LINE_COORDINATES]
        _alignment = _janus_deployment.get(JanusDeploymentConstants.ALIGNMENT_KEY)
        return _alignment, _coordinates

    def line_point_position(self, point):
        """
        Get the position of point w.r.t. the line
        :param point: (x, y) point to be compared
        :return: True if the point lies on the "positive" side of the line, else False
                 (always a Python bool, which update_bag_count relies on)
        """
        _alignment, line_coordinates = self.get_line_coordinates()

        assert len(line_coordinates) == 4, "Line coordinates variable is invalid"
        assert len(point) == 2, "Point variable is invalid"

        x1, y1, x2, y2 = line_coordinates
        if x2 == x1:
            # Vertical line: the slope is undefined, so compare x directly. Previously this
            # was only done when alignment == "vertical", and any other alignment divided by zero.
            return bool(point[0] > x2)

        slope = (y2 - y1) / (x2 - x1)
        return bool(point[1] - y1 - slope * (point[0] - x1) > 0)

    def draw_line_over_image(self, frame, color=(255, 0, 0)):
        """
        Draws the counting line over the frame
        :param frame: frame to draw on (modified in place by cv2.line)
        :param color: line color
        :return: the frame with the line drawn
        """
        _alignment, line_coordinates = self.get_line_coordinates()
        assert len(line_coordinates) == 4, "Line coordinates variable is invalid"
        x1, y1, x2, y2 = line_coordinates
        return cv2.line(frame, (x1, y1), (x2, y2), color, 3)

    def validate_point_position(self, point):
        """
        Validate the position of the point w.r.t. the line
        :param point: centroid (x, y)
        :return: True if the point lies strictly within the line's extent along its alignment
                 (y-extent for a vertical line, x-extent otherwise)
        """
        _alignment, line_coordinates = self.get_line_coordinates()
        assert _alignment in [JanusDeploymentConstants.VERTICAL, JanusDeploymentConstants.HORIZONTAL], \
            "Invalid alignment variable"
        if _alignment == JanusDeploymentConstants.VERTICAL:
            bound_a, bound_b, value = line_coordinates[1], line_coordinates[3], point[1]
        else:
            bound_a, bound_b, value = line_coordinates[0], line_coordinates[2], point[0]
        return bool(min(bound_a, bound_b) < value < max(bound_a, bound_b))

    def send_payload(self, frame, label='CementBagDetected', bg_color="#474520", font_color="#FFFF00", alert_sound=None,
                     message="Cement Bag Detected!"):
        """
        Insert event to Mongo
        :param frame: image attached to the event as a base64 JPEG data URI
        :param label: activity label
        :param bg_color: background color for the event display
        :param font_color: font color for the event display
        :param alert_sound: optional alert sound
        :param message: event message
        :return: None
        """
        _, encoded = cv2.imencode('.jpg', frame)
        # ndarray.tostring() was deprecated in NumPy 1.19 and removed in NumPy 2.0;
        # tobytes() returns the same bytes.
        jpeg_b64 = base64.b64encode(encoded.tobytes()).decode("utf-8")

        payload = {"deviceId": self.device_id, "message": message,
                   "frame": 'data:image/jpeg;base64,' + jpeg_b64, "activity": label,
                   "bg_color": bg_color, "font_color": font_color, "alert_sound": alert_sound, "app": "cement"}

        self.mongo_logger.insert_attendance_event_to_mongo(payload)

    def update_bag_count(self, frame, detection_objects):
        """
        Maintains the bag counts
        :param frame: image; objects are marked on it with draw_circles_on_frame
        :param detection_objects: mapping of object id -> centroid from the tracker
        :return: the annotated frame
        """
        for object_id, centroid in detection_objects.items():
            # Only objects within the span of the counting line take part in counting.
            if not self.validate_point_position(centroid):
                continue

            if not isinstance(self.count, int):
                logger.debug("Initializing the count variable")
                self.count = 0

            point_position = self.line_point_position(point=centroid)
            if not isinstance(self.initial_object_position, bool):
                logger.debug("Initializing the initial object position")
                # The first object seen defines which side counts as "before the line".
                self.initial_object_position = point_position

            if point_position == self.initial_object_position:
                # Still on the initial side: remember it so a later crossing can be counted.
                if object_id not in self.uncounted_objects:
                    self.uncounted_objects[object_id] = centroid
                frame = draw_circles_on_frame(frame, centroid, radius=10, color=(0, 0, 255), thickness=-1)
            elif object_id in self.uncounted_objects:
                # Crossed the line since being remembered: count it exactly once.
                self.uncounted_objects.pop(object_id, None)
                self.count += 1
                frame = draw_circles_on_frame(frame, centroid, radius=10, color=(0, 255, 0), thickness=-1)
                self.send_payload(resize_to_64_64(frame=frame))
            else:
                frame = draw_circles_on_frame(frame, centroid, radius=10, color=(0, 255, 0), thickness=-1)

        return frame

    def _get_detections(self, frame, frame_id):
        """
        Fetch the detections for one frame.
        :param frame: raw frame (unused while live inference is disabled)
        :param frame_id: frame index used to look up replayed detections
        :return: list of detections, each with a ``points`` entry
        :raises NotImplementedError: for non-videofile sources, since no live model is loaded
        """
        if self.type == 'videofile':
            return self.dets[frame_id][str(frame_id)]['detections']
        # The ExecuteInference model is commented out in __init__. Previously this path fell
        # through with ``dets`` unassigned and failed with an opaque UnboundLocalError.
        raise NotImplementedError(
            f"Live inference is disabled: no detection model is loaded for source type '{self.type}'")

    def _predict(self, obj):
        """
        Run detection, tracking and counting on one frame payload.
        :param obj: payload with at least ``frame`` (image) and ``frameId``
        :return: the payload with the annotated, resized ``frame`` and a ``timestamp``;
                 on failure ``error``, ``message`` and ``status`` (False) are also set
        """
        try:
            time.sleep(self.frame_delay)
            frame = obj['frame']
            frame_id = int(obj['frameId'])
            skipping = self.frame_skipping
            if skipping["to_skip"]:
                if not skipping["skip_current_frame"]:
                    dets = self._get_detections(frame, frame_id)
                    skipping["detection_value"] = dets
                    skipping["skip_current_frame"] = True
                else:
                    # Alternate frame: reuse the previous frame's detections.
                    dets = skipping["detection_value"]
                    skipping["skip_current_frame"] = False
            else:
                dets = self._get_detections(frame, frame_id)

            objects, frame = self.track_bags(dets, frame)
            frame = self.update_bag_count(frame=frame, detection_objects=objects)
            logger.debug("Counts: {}".format(self.count))

            output_size = (self.config.get('FRAME_WIDTH'), self.config.get('FRAME_HEIGHT'))
            obj['frame'] = cv2.resize(self.draw_line_over_image(frame), output_size)
            obj["timestamp"] = datetime.datetime.now().replace(microsecond=0).isoformat()
        except Exception as e:
            # Pipeline boundary: report the failure on the payload instead of stopping the stream.
            # logger.exception already records the full traceback.
            logger.exception(f"Error: {e}")
            obj['frame'] = cv2.resize(obj['frame'], (self.config.get('FRAME_WIDTH'), self.config.get('FRAME_HEIGHT')))
            obj["error"] = "{}".format(e)
            obj["message"] = "{}".format("error processing frame")
            obj["status"] = False
            obj["timestamp"] = datetime.datetime.now().replace(microsecond=0).isoformat()
        return obj
