import os
import uuid
from scripts.utils.edge_utils import get_extra_fields
from scripts.utils.centroidtracker import CentroidTracker
from scripts.utils.image_utils import draw_circles_on_frame, resize_to_64_64

import cv2
import base64
import numpy as np

from collections import deque
from expiringdict import ExpiringDict
from sklearn.utils.linear_assignment_ import linear_assignment
# from scipy.optimize import linear_sum_assignment as linear_assignment

from edge_engine.common.logsetup import logger
from edge_engine.ai.model.modelwraper import ModelWrapper

from scripts.utils.tracker import Tracker
from scripts.utils.helpers import box_iou2
from scripts.utils.infocenter import MongoLogger
from scripts.common.constants import JanusDeploymentConstants
import time
import json



class Ppe(ModelWrapper):
    """
    PPE (personal protective equipment) compliance model.

    Each frame's person detections are tracked with Kalman trackers. Every tracked person is
    checked for the required safety equipment: air breathing mask, safety helmet, hand gloves
    and coverall suit. Violations are sent to Mongo through ``MongoLogger``, and a dashboard
    overlay is drawn on the frame.
    """

    # Equipment a person must wear, chosen by which headgear was detected on them.
    _REQUIRED_WITH_MASK_AND_HELMET = frozenset({"Air Breathing Mask", "Hand gloves", "coverall suit", "Safety helmet"})
    _REQUIRED_WITH_MASK = frozenset({"Air Breathing Mask", "Hand gloves", "coverall suit"})
    _REQUIRED_WITH_HELMET = frozenset({"Safety helmet", "Hand gloves", "coverall suit"})

    # Seconds after which a person who is still violating gets reported again.
    _REPORT_INTERVAL_SECONDS = 30

    # Dashboard panels, in drawing order: (header top-left, header bottom-right, header colour,
    # title, title origin). Each header has a 70 px dark-grey strip below it for the count.
    _OVERLAY_PANELS = (
        ((870, 0), (1400, 50), (0, 0, 255), "Air Breathing Mask Violation", (880, 30)),
        ((1400, 300), (1900, 350), (0, 255, 0), "People following compliance", (1410, 330)),
        ((870, 300), (1400, 350), (0, 0, 255), "People not following compliance", (880, 330)),
        ((1400, 0), (1900, 50), (0, 0, 255), "Coverall Suit Violation", (1410, 30)),
        ((870, 150), (1400, 200), (0, 0, 255), "Helmet Violation", (880, 180)),
        ((1400, 150), (1900, 200), (0, 0, 255), "Glove Violation", (1410, 180)),
    )

    def __init__(self, config, model_config, pubs, device_id):
        """
        Set up the tracking state, caches and publishers.

        :param config: deployment config. ``config['inputConf']['sourceType']`` selects the input
            type, and ``config['config']`` holds the model settings (frame size, ``frame_skip``).
        :param model_config: model settings (``print_eu_dist``, ``uncounted_obj_length``,
            ``uncounted_obj_age``)
        :param pubs: publisher bundle; ``pubs.rtp_write`` receives every processed payload
        :param device_id: device identifier attached to every Mongo event
        """
        super().__init__()
        self.type = config['inputConf']['sourceType']
        if self.type == 'videofile':
            # Offline mode: per-frame detections are replayed from a pre-computed JSON dump.
            with open('aarti.json', "r") as f:
                self.dets = json.load(f)
        self.config = config["config"]
        self.device_id = device_id
        self.rtp = pubs.rtp_write
        self.mongo_logger = MongoLogger()
        self.frame_skip = self.config.get('frame_skip', False)
        self.print_eu_dist = model_config.get('print_eu_dist', 200)
        self.ct1 = CentroidTracker(maxDisappeared=5)
        self.ct2 = CentroidTracker(maxDisappeared=5)
        self.frame_skipping = {
            "skip_current_frame": True,
            "detection_value": None
        }
        self.count = 0  # number of frames passed to _predict

        # Kalman tracking state.
        self.tracker_list = []
        self.max_age = 3  # consecutive unmatched frames a track survives before it is dropped
        self.min_hits = 0  # matches a track needs before it is reported
        self.track_id_list = deque([str(i) for i in range(1, 50)])  # pool of reusable track ids
        self._next_track_id = len(self.track_id_list) + 1  # next id to mint once the pool runs dry
        self.prev_annotation = []

        self.initial_object_position = None

        self.uncounted_objects = ExpiringDict(max_len=model_config.get("uncounted_obj_length", 50),
                                              max_age_seconds=model_config.get("uncounted_obj_age", 60))
        self.janus_metadata = ExpiringDict(max_age_seconds=120, max_len=1)  # cached get_extra_fields() result
        self.safety_equip = ExpiringDict(max_age_seconds=30, max_len=10)  # track id -> equipment classes seen on it
        self.polygon = np.array([[[400, 700], [1600, 500], [1900, 650], [1900, 1000], [700, 1000]]])
        self.tracking_people = {}
        self.final_ppe_result = {}
        self.skip_frame_bool = False
        self.skip_frame_count = 0
        self.frame_id = 0  # id of the frame being processed; attached to violation events
        self.reported_violation_ids = {}  # track id -> time.time() of its last violation report
        self.violation_count = self._new_violation_count()  # per-frame violation tallies for the overlay

        self.payload_classes = {"Air Breathing Mask": "air_breathing_mask_violation", "Safety helmet": "helmet_violation", "Hand gloves": "glove_violation", "coverall suit": "coverall_suit_violation"}

        # NOTE(review): active_rec is only read and cleared in this class; presumably it is populated elsewhere.
        self.active_rec = {}
        self._detector_warning_logged = False

    @staticmethod
    def _new_violation_count():
        """Return an empty per-class violation tally."""
        return {"Air Breathing Mask": [], "Safety helmet": [], "Hand gloves": [], "coverall suit": []}

    def _pre_process(self, x):
        """
        Do preprocessing here, if any
        :param x: payload
        :return: payload
        """
        return x

    def _post_process(self, x):
        """
        Publish the processed payload on the RTP writer and return it unchanged.
        :param x: payload
        :return: payload
        """
        self.rtp.publish(x)  # video stream
        return x

    def send_payload(self, frame, label='PPEFrame', bg_color="#474520", font_color="#FFFF00", alert_sound=None,
                     message=None, event="", frame_id=None):
        """
        Insert an event into Mongo through ``MongoLogger.insert_attendance_event_to_mongo``.

        :param frame: image, JPEG-encoded and embedded as a base64 data URL
        :param label: activity label
        :param bg_color: background colour for the event
        :param font_color: font colour for the event
        :param alert_sound: optional alert sound
        :param message: list of violation identifiers; defaults to an empty list
        :param event: event type string
        :param frame_id: id of the frame the event belongs to
        :return: None
        """
        if message is None:
            # A fresh list per call; the previous `message=[]` default was shared across calls.
            message = []
        # ndarray.tostring() was removed in NumPy 2.0; tobytes() returns the same bytes.
        jpeg_bytes = cv2.imencode('.jpg', frame)[1].tobytes()
        payload = {"deviceId": self.device_id, "message": message,
                   "frame": 'data:image/jpeg;base64,' + base64.b64encode(jpeg_bytes).decode("utf-8"),
                   "activity": label,
                   "bg_color": bg_color, "font_color": font_color, "alert_sound": alert_sound,
                   "frame_id": frame_id, "event_type": event}

        self.mongo_logger.insert_attendance_event_to_mongo(payload)

    @staticmethod
    def _state_to_box(state):
        """Extract the four position entries (indices 0, 2, 4, 6) from a Kalman state column."""
        values = state.T[0].tolist()
        return [values[0], values[2], values[4], values[6]]

    def _next_free_track_id(self):
        """Take a recycled id from the pool, or mint a new one once the pool is exhausted."""
        if self.track_id_list:
            return self.track_id_list.popleft()
        # The fixed pool used to raise IndexError here and abort the whole frame.
        new_id = self._next_track_id
        self._next_track_id += 1
        return str(new_id)

    def _forget_person(self, object_id):
        """Drop cached equipment and report state for a track id that is being recycled."""
        self.safety_equip.pop(object_id, None)
        self.reported_violation_ids.pop(object_id, None)

    def kalman_tracker(
            self,
            bboxs,
            img,
    ):
        """
        Update the Kalman trackers with this frame's person boxes.

        NOTE(review): boxes come from ``inference`` as ``[x1, y1, x2, y2]``, but the centroid
        maths below reads them as ``[top, left, bottom, right]``. Confirm which order ``Tracker``
        expects.

        :param bboxs: person detection boxes for this frame
        :param img: frame; returned unchanged
        :return: ``(img, objects, boxs)``. ``objects`` is a list of ``[track_id, centroid]``, and
            ``boxs`` holds the tracker boxes in the same order.
        """
        z_box = bboxs
        x_box = [trk.box for trk in self.tracker_list]

        matched, unmatched_dets, unmatched_trks = self.assign_detections_to_trackers(x_box, z_box, iou_thrd=0.03)

        # Matched detections: correct the existing track with the new measurement.
        for trk_idx, det_idx in matched:
            z = np.expand_dims(z_box[det_idx], axis=0).T
            tmp_trk = self.tracker_list[trk_idx]
            tmp_trk.kalman_filter(z)
            xx = self._state_to_box(tmp_trk.x_state)
            x_box[trk_idx] = xx
            tmp_trk.box = xx
            tmp_trk.hits += 1

        # Unmatched detections: start a new track.
        for idx in unmatched_dets:
            z = np.asarray(z_box[idx]).ravel()
            tmp_trk = Tracker()
            # Interleave each coordinate with a zero velocity. Scalars are used here: the old code
            # mixed shape-(1,) arrays with scalars, which NumPy >= 1.24 rejects as a ragged array.
            tmp_trk.x_state = np.array([[z[0], 0, z[1], 0, z[2], 0, z[3], 0]]).T
            tmp_trk.predict_only()
            xx = self._state_to_box(tmp_trk.x_state)
            tmp_trk.box = xx
            tmp_trk.id = self._next_free_track_id()
            self.tracker_list.append(tmp_trk)
            x_box.append(xx)

        # Unmatched tracks: coast on the prediction and count the miss.
        for trk_idx in unmatched_trks:
            tmp_trk = self.tracker_list[trk_idx]
            tmp_trk.no_losses += 1
            tmp_trk.predict_only()
            xx = self._state_to_box(tmp_trk.x_state)
            tmp_trk.box = xx
            x_box[trk_idx] = xx

        # Tracks to annotate.
        objects = []
        boxs = []
        for trk in self.tracker_list:
            if trk.hits >= self.min_hits and trk.no_losses <= self.max_age:
                x_cv2 = trk.box
                left, right, bottom = x_cv2[1], x_cv2[3], x_cv2[2]
                centroid = [int(left + ((right - left) / 2)), int(bottom)]
                objects.append([int(trk.id), centroid])
                boxs.append(x_cv2)

        # Retire stale tracks. Their ids go back to the pool, so their per-id state is cleared too.
        surviving = []
        for trk in self.tracker_list:
            if trk.no_losses > self.max_age:
                self.track_id_list.append(trk.id)
                self._forget_person(int(trk.id))
            else:
                surviving.append(trk)
        self.tracker_list = surviving

        return img, objects, boxs

    @staticmethod
    def assign_detections_to_trackers(
            trackers,
            detections,
            iou_thrd=0.3,
    ):
        """
        Match tracker boxes to detection boxes by maximising total IoU.

        Pairs whose IoU falls below ``iou_thrd`` are reported as unmatched on both sides.

        :return: ``(matches, unmatched_detections, unmatched_trackers)``, where ``matches`` is a
            ``(k, 2)`` int array of ``(tracker_index, detection_index)`` rows
        """
        iou_mat = np.zeros((len(trackers), len(detections)), dtype=np.float32)
        for t, trk in enumerate(trackers):
            for d, det in enumerate(detections):
                iou_mat[t, d] = box_iou2(trk, det)

        # Minimising negative IoU == maximising IoU.
        # NOTE(review): sklearn.utils.linear_assignment_ was removed in scikit-learn 0.23;
        # scipy.optimize.linear_sum_assignment is the maintained replacement.
        matched_idx = linear_assignment(-iou_mat)

        assigned_trackers = set(matched_idx[:, 0].tolist())
        assigned_detections = set(matched_idx[:, 1].tolist())
        unmatched_trackers = [t for t in range(len(trackers)) if t not in assigned_trackers]
        unmatched_detections = [d for d in range(len(detections)) if d not in assigned_detections]

        matches = []
        for m in matched_idx:
            if iou_mat[m[0], m[1]] < iou_thrd:
                unmatched_trackers.append(m[0])
                unmatched_detections.append(m[1])
            else:
                matches.append(m.reshape(1, 2))

        matches = np.concatenate(matches, axis=0) if matches else np.empty((0, 2), dtype=int)
        return matches, np.array(unmatched_detections), np.array(unmatched_trackers)

    def get_line_coordinates(self):
        """
        Get the counting-line alignment and coordinates from the deployment JSON.

        The metadata returned by ``get_extra_fields`` is cached in ``self.janus_metadata``.

        :return: ``(alignment, coordinates)``, where ``coordinates`` lists the values of the
            ``JanusDeploymentConstants.LINE_COORDINATES`` keys in order
        """
        metadata = self.janus_metadata.get('metadata')
        if not metadata:
            metadata = get_extra_fields(self.device_id)
            self.janus_metadata['metadata'] = metadata

        coordinates = [metadata.get(coordinate_key) for coordinate_key in
                       JanusDeploymentConstants.LINE_COORDINATES]
        alignment = metadata.get(JanusDeploymentConstants.ALIGNMENT_KEY)
        return alignment, coordinates

    def line_point_position(self, point):
        """
        Get the position of a point relative to the configured line.

        For a non-vertical line this is True when ``y - y1 - slope * (x - x1) > 0``. For a
        vertical line (which used to raise ZeroDivisionError) the result is the limit of that
        test as the line's x-extent approaches 0 from above.

        :param point: ``(x, y)`` point to compare
        :return: boolean
        :raises ValueError: if both line endpoints coincide
        """
        _alignment, line_coordinates = self.get_line_coordinates()

        assert len(line_coordinates) == 4, "Line coordinates variable is invalid"
        assert len(point) == 2, "Point variable is invalid"

        x1, y1, x2, y2 = line_coordinates
        dx = x2 - x1
        dy = y2 - y1
        if dx == 0 and dy == 0:
            raise ValueError("Line coordinates describe a single point")
        # Cross product of the line direction and (point - start). Dividing it by dx gives the
        # point-slope value, so "cross * dx > 0" is the same test without the division.
        cross = dx * (point[1] - y1) - dy * (point[0] - x1)
        if dx != 0:
            return bool(cross * dx > 0)
        return bool(cross > 0)

    def validate_point_position(self, point):
        """
        Check that the point lies strictly within the line's span along its alignment axis.

        :param point: centroid ``(x, y)``
        :return: True when the y (vertical line) or x (horizontal line) coordinate lies strictly
            between the two endpoints' coordinates on that axis
        """
        alignment, line_coordinates = self.get_line_coordinates()
        assert alignment in [JanusDeploymentConstants.VERTICAL, JanusDeploymentConstants.HORIZONTAL], \
            "Invalid alignment variable"
        if alignment == JanusDeploymentConstants.VERTICAL:
            start, end, value = line_coordinates[1], line_coordinates[3], point[1]
        else:
            start, end, value = line_coordinates[0], line_coordinates[2], point[0]
        return bool(start < value < end or end < value < start)

    def ppe_detection(self, frame, bbox, detection_objects, class_name, other_class_name, other_centroid):
        """
        Check every tracked person for the required safety equipment, report violations, and
        draw the compliance overlay.

        :param frame: image; annotated in place and returned
        :param bbox: person boxes, aligned index-by-index with ``detection_objects``
        :param detection_objects: ``[track_id, centroid]`` pairs from ``kalman_tracker``
        :param class_name: class label of each person detection. It is not needed for the check
            and is kept for interface compatibility.
        :param other_class_name: class labels of the non-person (equipment) detections
        :param other_centroid: ``[x, y]`` centroid of each non-person detection
        :return: the annotated frame
        """
        person_with_safety_kit = 0
        person_without_safety_kit = 0

        for (object_id, centroid), person_bb in zip(detection_objects, bbox):
            cv2.circle(frame, (int(centroid[1]), int(centroid[0])), 2, (0, 255, 0), thickness=1, lineType=8, shift=0)

            equipment = self._update_safety_equipment(object_id, person_bb, other_class_name, other_centroid)
            required = self._required_equipment(equipment)
            # Safe means every required item was seen. Extra detected classes no longer turn a
            # compliant person into an empty "violation".
            violations = required - equipment
            top_left = (int(person_bb[0]), int(person_bb[1]))
            bottom_right = (int(person_bb[2]), int(person_bb[3]))

            if not violations:
                person_with_safety_kit += 1
                cv2.rectangle(frame, top_left, bottom_right, (0, 255, 0), 2)
                self._flush_active_recording(frame, object_id)
            else:
                for elem in violations:
                    self.violation_count[elem].append(elem)
                # Reported before the red box is drawn, so the snapshot shows the frame as it was.
                self._report_violation(frame, object_id, violations)
                person_without_safety_kit += 1
                cv2.rectangle(frame, top_left, bottom_right, (0, 0, 255), 2)

        self._draw_compliance_overlay(frame, person_with_safety_kit, person_without_safety_kit)
        self.violation_count = self._new_violation_count()
        return frame

    def _update_safety_equipment(self, object_id, person_bb, other_class_name, other_centroid):
        """
        Record the equipment detections whose centroid lies strictly inside ``person_bb``.

        :return: every equipment class remembered for ``object_id`` in ``self.safety_equip``
            (an empty set if nothing has been seen)
        """
        seen = {
            safety_object
            for safety_object, object_bb in zip(other_class_name, other_centroid)
            if person_bb[0] < object_bb[0] < person_bb[2] and person_bb[1] < object_bb[1] < person_bb[3]
        }
        # Keyed by track id: the old literal key "object_id" pooled equipment across all
        # persons, and raised KeyError before anything had been seen.
        equipment = set(self.safety_equip.get(object_id) or ())
        if seen:
            logger.debug("safety objects for %s: %s", object_id, seen)
            equipment |= seen
            # Re-storing resets the ExpiringDict entry's age, as the previous code did.
            self.safety_equip[object_id] = equipment
        return equipment

    @classmethod
    def _required_equipment(cls, equipment):
        """Choose the required equipment set based on the headgear found in ``equipment``."""
        if "Air Breathing Mask" in equipment and "Safety helmet" in equipment:
            return cls._REQUIRED_WITH_MASK_AND_HELMET
        if "Air Breathing Mask" in equipment:
            return cls._REQUIRED_WITH_MASK
        return cls._REQUIRED_WITH_HELMET

    def _report_violation(self, frame, object_id, violations):
        """
        Send a violation event for ``object_id`` unless one was sent recently.

        If the previous report is older than ``_REPORT_INTERVAL_SECONDS``, the id is cleared so
        the next violating frame reports again, and pending ``active_rec`` events are flushed.
        """
        temp_violation_list = sorted(self.payload_classes[v] for v in violations)
        violated_items = ', '.join(temp_violation_list)
        logger.debug("violated items: %s", violated_items)

        if object_id not in self.reported_violation_ids:
            logger.debug("sending to mongo")
            self.send_payload(frame=resize_to_64_64(frame), message=temp_violation_list,
                              event=violated_items, frame_id=self.frame_id)
            self.reported_violation_ids[object_id] = time.time()
        elif time.time() - self.reported_violation_ids[object_id] > self._REPORT_INTERVAL_SECONDS:
            del self.reported_violation_ids[object_id]
            self._flush_active_recording(frame, object_id)

    def _flush_active_recording(self, frame, object_id):
        """
        Send every event queued in ``self.active_rec[object_id]``, then drop the entry.

        This is a no-op when there is no entry. NOTE(review): the entry layout is presumably
        ``(..., frame_id, events)`` at indices 2 and 3; confirm against the code that fills it.
        """
        if object_id not in self.active_rec:
            return
        rec_inf = self.active_rec[object_id]
        temp_v_list = rec_inf[3]
        logger.debug("pending recorded events: %s", temp_v_list)
        for v in temp_v_list:
            self.send_payload(frame=resize_to_64_64(frame), message=temp_v_list,
                              event=v, frame_id=rec_inf[2])
        del self.active_rec[object_id]

    def _draw_compliance_overlay(self, frame, person_with_safety_kit, person_without_safety_kit):
        """Draw the fixed-position dashboard (panel titles and non-zero counts) onto ``frame``."""
        font = cv2.FONT_HERSHEY_SIMPLEX
        for header_tl, header_br, header_color, title, title_org in self._OVERLAY_PANELS:
            cv2.rectangle(frame, header_tl, header_br, header_color, -1)
            cv2.rectangle(frame, (header_tl[0], header_br[1]), (header_br[0], header_br[1] + 70), (50, 50, 50), -1)
            cv2.putText(frame, title, title_org, font, 1, (0, 0, 0), 4, 2)

        # Counts, in the original drawing order; zero counts are left blank.
        counts = (
            (person_without_safety_kit, (1100, 400)),
            # The old code indexed violation_count[""] here, which raised KeyError.
            (len(self.violation_count["Air Breathing Mask"]), (1100, 100)),
            (person_with_safety_kit, (1600, 400)),
            (len(self.violation_count["coverall suit"]), (1600, 100)),
            (len(self.violation_count["Safety helmet"]), (1100, 250)),
            (len(self.violation_count["Hand gloves"]), (1600, 250)),
        )
        for value, org in counts:
            if value > 0:
                cv2.putText(frame, str(value), org, font, 1, (255, 255, 255), 4, 2)

    def draw_line_over_image(self, frame, color=(255, 255, 255)):
        """
        Draw the configured counting line over the frame as a dotted line.

        :param frame: frame to draw on (modified in place)
        :param color: line colour
        :return: the frame
        """
        _alignment, line_coordinates = self.get_line_coordinates()
        assert len(line_coordinates) == 4, "Line coordinates variable is invalid"

        start = (line_coordinates[0], line_coordinates[1])
        end = (line_coordinates[2], line_coordinates[3])
        self.drawline(frame, start, end, color, thickness=3)
        return frame

    @staticmethod
    def drawline(img, pt1, pt2, color, thickness=1, style='dotted', gap=20):
        """
        Draw a dotted (default) or dashed line from ``pt1`` to ``pt2``.

        Points are sampled every ``gap`` pixels along the line. The dotted style draws a filled
        circle at each sample; any other style joins alternate pairs of consecutive samples.
        A zero-length line draws nothing.
        """
        dist = ((pt1[0] - pt2[0]) ** 2 + (pt1[1] - pt2[1]) ** 2) ** .5
        pts = []
        for i in np.arange(0, dist, gap):
            r = i / dist
            x = int((pt1[0] * (1 - r) + pt2[0] * r) + .5)
            y = int((pt1[1] * (1 - r) + pt2[1] * r) + .5)
            pts.append((x, y))

        if not pts:
            # Zero-length line: the dashed branch below used to raise IndexError on pts[0].
            return

        if style == 'dotted':
            for p in pts:
                cv2.circle(img, p, thickness, color, -1)
        else:
            e = pts[0]
            for i, p in enumerate(pts):
                s, e = e, p
                if i % 2 == 1:
                    cv2.line(img, s, e, color, thickness)

    def inference(
            self,
            dets,
            classes,
            frame

    ):
        """
        Split raw detections into tracked-class boxes and other-class centroids, drawing each.

        :param dets: detection dicts. Detections of a class in ``classes`` need ``"points"``
            (four box values); all others need ``"centroid"``.
        :param classes: class labels treated as tracked objects (persons)
        :param frame: image to draw on
        :return: ``(bboxs, frame, class_name, other_class_name, other_centroid)``
        """
        class_name = []
        bboxs = []
        other_centroid = []
        other_class_name = []
        logger.debug("PREDICTIONS %s", dets)
        for det in dets or ():
            if det["class"] in classes:
                points = det["points"]
                class_name.append(det["class"])
                # cv2 drawing calls require integer pixel coordinates.
                frame = cv2.rectangle(frame, (int(points[0]), int(points[1])), (int(points[2]), int(points[3])),
                                      (255, 255, 0), 2)
                bboxs.append([points[0], points[1], points[2], points[3]])
            else:
                other_class_name.append(det["class"])
                cv2.circle(frame, tuple(int(v) for v in det["centroid"]), 1, 1, thickness=8, lineType=8, shift=0)
                other_centroid.append([det["centroid"][0], det["centroid"][1]])
        return bboxs, frame, class_name, other_class_name, other_centroid

    def _predict(self, obj):
        """
        Process one frame: look up detections, track persons, run the PPE check, and resize.

        :param obj: payload holding ``'frame'`` and ``'frameId'``. ``obj['frame']`` is replaced by
            the annotated frame resized to ``FRAME_WIDTH`` x ``FRAME_HEIGHT``. On error it is the
            input frame resized, and the error is logged.
        :return: ``obj``
        """
        self.count += 1

        class_list = ["person"]

        try:
            time.sleep(0.05)
            frame = obj['frame']
            frame_index = int(obj['frameId'])
            # Attach the real frame id to events (it previously stayed at 0 forever).
            self.frame_id = frame_index
            if self.type == 'videofile':
                dets = self.dets[frame_index][str(frame_index)]['detections']
            else:
                # No live detector is wired in. Process the frame without detections instead of
                # failing every frame on an unbound `dets`.
                if not self._detector_warning_logged:
                    logger.warning("No detector configured for source type %s; running without detections",
                                   self.type)
                    self._detector_warning_logged = True
                dets = []

            bbox, frame, class_name, other_class_name, other_centroid = self.inference(dets, class_list, frame)
            logger.debug("bounding box %s", bbox)
            frame, objects, boxs = self.kalman_tracker(bbox, frame)
            logger.debug("kalman outputs %s", objects)

            # Pass the tracker boxes: they are index-aligned with `objects`, whereas the raw
            # detections are in detector order and would pair ids with the wrong person.
            frame = self.ppe_detection(frame=frame, bbox=boxs, detection_objects=objects, class_name=class_name,
                                       other_class_name=other_class_name, other_centroid=other_centroid)

            logger.debug("self.uncounted_objects --> %s", self.uncounted_objects)
            obj['frame'] = cv2.resize(frame, (self.config.get('FRAME_WIDTH'), self.config.get('FRAME_HEIGHT')))

            # Debug display only; this text is drawn after the output frame was taken.
            cv2.putText(frame, str(self.frame_id), (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0),
                        1, cv2.LINE_AA)
            cv2.imshow("output is ", cv2.resize(frame, (1000, 800)))
            cv2.waitKey(1)
        except Exception as e:
            logger.exception(f"Error: {e}", exc_info=True)
            obj['frame'] = cv2.resize(obj['frame'], (self.config.get('FRAME_WIDTH'), self.config.get('FRAME_HEIGHT')))

        return obj
