import cv2
import json
import base64
import numpy as np
from scipy.spatial import distance
from expiringdict import ExpiringDict
import time
from edge_engine.common.logsetup import logger
from scripts.utils.infocenter import MongoLogger
# from yolov5processor.infer import ExecuteInference
from scripts.utils.edge_utils import get_extra_fields
from edge_engine.ai.model.modelwraper import ModelWrapper
from scripts.utils.centroidtracker import CentroidTracker
from scripts.common.constants import JanusDeploymentConstants
from scripts.utils.image_utils import draw_circles_on_frame, resize_to_64_64
from pymongo import MongoClient
from scripts.common.config import MONGO_URI
from uuid import uuid4
import cv2
import base64
import datetime
import numpy as np
import imutils
from collections import deque
from expiringdict import ExpiringDict
from sklearn.utils.linear_assignment_ import linear_assignment

from edge_engine.common.logsetup import logger
from edge_engine.ai.model.modelwraper import ModelWrapper

from scripts.utils.tracker import Tracker
from scripts.utils.helpers import box_iou2
from scripts.utils.edge_utils import Utilities
from scripts.utils.infocenter import MongoLogger
from scripts.utils.model_tracker import ModelCountTracker
from scripts.common.constants import JanusDeploymentConstants

# TRT Additions start
# from yolov5processor.infer import ExecuteInference
# from scripts.utils.yolov5_trt import YoloV5TRT
# TRT Additions stop

from scripts.utils.relay_util import RelayHandler
from paddleocr import PaddleOCR


class CementBagCounter(ModelWrapper):

    def __init__(self, config, model_config, pubs, device_id):
        """
        Set up the trackers, counters, caches and Mongo handles of the counter.

        :param config: deployment dict; only its "config" entry is stored
        :param model_config: mapping queried with ``.get`` for the optional
            "uncounted_obj_length" and "uncounted_obj_age" settings
        :param pubs: object whose ``rtp_write`` attribute is used as the
            video-stream publisher
        :param device_id: device id, used in payloads and for the camera lookup
        """
        super().__init__()
        self.config = config["config"]
        self.device_id = device_id
        self.rtp = pubs.rtp_write
        self.mongo_logger = MongoLogger()
        self.frame_skip = self.config.get('frame_skip', False)

        # The YOLOv5 / TensorRT detector set-up (engine "data/acc_v14.engine")
        # that used to be created here is disabled; only the class-id -> label
        # map is still defined.
        self.classes = {0: 'ambuja_plus', 1: 'acc_gold', 2: 'acc_suraksha_power_plus', 3: 'ambuja_buildcem', 4: 'mrp', 5: 'acc_suraksha_power', 6: 'acc_nfr', 7: 'acc_concrete_plus'}

        self.ct1 = CentroidTracker(maxDisappeared=5)
        self.ct2 = CentroidTracker(maxDisappeared=5)

        # Per-brand counters.
        self.count = 0
        self.cement_bag = 0
        self.count_suraksha = 0
        self.count_whitecem = 0
        self.count_gold = 0

        # Kalman tracking state (see kalman_tracker).
        self.tracker_list = []
        self.max_age = 3  # a track is dropped once no_losses exceeds this
        self.min_hits = 0  # minimum hits before a track is reported
        # Pool of reusable track ids "1".."49".
        self.track_id_list = deque([str(i) for i in range(1, 50)])
        self.prev_annotation = []
        self.mrp_counter = 0
        self.count_nfr = 0
        self.count_suraksha_power = 0
        self.count_concrete_plus = 0
        self.count_ambuja_plus = 0

        self.initial_object_position = None

        # Objects seen on the initial side of the line but not counted yet.
        self.uncounted_objects = ExpiringDict(max_len=model_config.get("uncounted_obj_length", 50),
                                              max_age_seconds=model_config.get("uncounted_obj_age", 60))
        # Single-entry cache of the deployment metadata, kept for 120 seconds.
        self.janus_metadata = ExpiringDict(max_age_seconds=120, max_len=1)
        self.mongo_alarm_coll = MongoClient(MONGO_URI)['ilens_events']['triggered_alarms']
        self.camera_details = self.mongo_logger.get_camera_details(self.device_id)
        self.black_white_ratio_dict = {'ambuja_plus': 1.2, 'acc_gold': 1.2, 'acc_suraksha_power_plus': 1.2,
                                       'ambuja_buildcem': 1.2, 'acc_suraksha_power': 1.2, 'acc_nfr': 1.2,
                                       'acc_concrete_plus': 1.2}
    def paddle_ocr_load_model(self):
        """
        Build the English PaddleOCR engine from the bundled model directories.

        :return: the constructed PaddleOCR instance
        """
        model_dirs = {
            "cls_model_dir": "paddleocr/model/ch_ppocr_mobile_v2.0_cls_infer",
            "rec_model_dir": "paddleocr/model/ch_PP-OCRv3_rec_infer",
            "det_model_dir": "paddleocr/model/en_PP-OCRv3_det_infer",
        }
        return PaddleOCR(lang="en", use_angle_cls=True, **model_dirs)
    def paddle_ocr_prediction(self, img_path, ocr):
        """
        Run OCR on one image and return the recognised strings.

        :param img_path: image handed to ``ocr.ocr`` unchanged
        :param ocr: OCR engine exposing ``ocr(img, cls=, det=, rec=)``
        :return: list with the text of every recognised line; empty when the
            engine reports nothing for the image
        """
        pages = ocr.ocr(img_path, cls=False, det=True, rec=True)
        # The entry for the image can be missing or None when no text is found;
        # treat that as "no lines" instead of failing on iteration.
        if not pages or not pages[0]:
            return []
        # Each line is indexed as line[1][0] for its text.
        return [line[1][0] for line in pages[0]]

    def check_character(self,character):
        """
        Accept a string only when it is at least 17 characters long.

        :param character: text to check
        :return: the text itself when long enough, otherwise False
        """
        return character if len(character) >= 17 else False

    def fixed_character(self,character):
        """
        Validate the fixed layout of a printed code.

        Expected layout, in order: the ``fixed`` prefix, the ``internal``
        part, one year character, the plant character, one month character.

        NOTE(review): ``fixed``, ``internal``, ``year_list``, ``plant_name``
        and ``month_list`` are not defined in the visible part of this file -
        presumably module-level constants; confirm before relying on this.

        :param character: text to validate
        :return: the text itself when every field matches, otherwise False
        """
        internal_start = len(fixed)
        internal_end = internal_start + len(internal)

        # The year character is checked first, then the remaining fields.
        if character[internal_end] not in year_list:
            return False

        mismatch = (
            fixed != character[0:internal_start]
            or internal != character[internal_start:internal_end]
            or plant_name != character[internal_end + 1]
            or character[internal_end + 2] not in month_list
        )
        return False if mismatch else character
    def craft_character_find(self, image=None, output_dir=None):
        """
        Detect text regions in an image with CRAFT and export the results.

        The previous version read ``image`` and ``output_dir`` without ever
        receiving them, so every call failed; both are now parameters.

        NOTE(review): ``read_image``, ``load_refinenet_model``,
        ``load_craftnet_model``, ``get_prediction``, ``export_detected_regions``,
        ``export_extra_results`` and ``empty_cuda_cache`` are not imported in
        the visible part of this file - presumably they come from a CRAFT
        text-detector package; confirm before use.

        :param image: image (or path) handed to ``read_image``
        :param output_dir: directory the detected regions and extra results
            are exported to
        :raises ValueError: if ``image`` or ``output_dir`` is not given
        """
        if image is None or output_dir is None:
            raise ValueError("craft_character_find requires both 'image' and 'output_dir'")

        # read image
        image = read_image(image)

        # load models (CPU only)
        refine_net = load_refinenet_model(cuda=False)
        craft_net = load_craftnet_model(cuda=False)

        # perform prediction
        prediction_result = get_prediction(
            image=image,
            craft_net=craft_net,
            refine_net=refine_net,
            text_threshold=0.7,
            link_threshold=9999999999999999999,
            low_text=0.4,
            cuda=False,
            long_size=1280
        )

        # export detected text regions
        export_detected_regions(
            image=image,
            regions=prediction_result["boxes"],
            output_dir=output_dir,
            rectify=True
        )

        # export heatmap, detection points, box visualization
        export_extra_results(
            image=image,
            regions=prediction_result["boxes"],
            heatmaps=prediction_result["heatmaps"],
            output_dir=output_dir
        )

        # unload models from gpu
        empty_cuda_cache()
        print("find")
    def classify_last_eight_chacter(self):
        """Placeholder step: only prints the marker "8"."""
        marker = "8"
        print(marker)
    def verify_mont_year_plant(self):
        """Placeholder step: only prints the marker "month"."""
        marker = "month"
        print(marker)
    def compare_to_get_final_output(self):
        """Placeholder step: only prints the marker "output"."""
        marker = "output"
        print(marker)
    def final_output(self):
        """Placeholder step: only prints "success"."""
        print("success")

    def _pre_process(self, x):
        """
        Do preprocessing here, if any
        :param x: payload
        :return: payload
        """
        return x

    def _post_process(self, x):
        """
        Apply post processing here, if any
        :param x: payload
        :return: payload
        """
        self.rtp.publish(x)  # video stream
        return x

    def send_payload(self, frame, bag_type='', label='CementBagDetected', bg_color="#474520", font_color="#FFFF00",
                     alert_sound=None,
                     message="Cement Bag Detected!", mrp_frmae='', mrp_roi = '', mrp = ''):
        """
        Insert event to Mongo
        :param frame: image, JPEG-encoded and embedded as a base64 data URI
        :param bag_type: bag label stored under "bag_type"
        :param label: stored under "activity"
        :param bg_color: stored under "bg_color"
        :param font_color: stored under "font_color"
        :param alert_sound: stored under "alert_sound"
        :param message: stored under "message"
        :param mrp_frmae: stored under "mrp_frmae" (spelling kept for existing callers/consumers)
        :param mrp_roi: stored under "mrp_roi"
        :param mrp: stored under "mrp"
        :return: None
        """
        _, jpeg_buffer = cv2.imencode('.jpg', frame)
        # tobytes() replaces the deprecated ndarray.tostring() alias.
        encoded_frame = base64.b64encode(jpeg_buffer.tobytes()).decode("utf-8")

        payload = {"deviceId": self.device_id, "message": message,
                   "frame": 'data:image/jpeg;base64,' + encoded_frame, "activity": label,
                   "bg_color": bg_color, "font_color": font_color, "alert_sound": alert_sound, "bag_type": bag_type,
                   "mrp_frmae": mrp_frmae, "mrp_roi": mrp_roi, "mrp": mrp}

        self.mongo_logger.insert_attendance_event_to_mongo(payload)

    def kalman_tracker(
            self,
            bboxs,
            img,
    ):
        """
        Update the Kalman trackers with the detections of the current frame.

        Detections are matched to the existing tracks by IoU. Matched tracks
        are corrected with their detection, unmatched detections start new
        tracks, and unmatched tracks are only predicted forward. Tracks whose
        ``no_losses`` exceeds ``self.max_age`` are dropped and their ids are
        returned to the id pool.

        :param bboxs: detection boxes of the current frame, each with four
            coordinates - read below as (top, left, bottom, right)
        :param img: current frame; returned unchanged
        :return: ``(img, objects, boxs)`` where ``objects`` is a list of
            ``[track id as int, [centre x, bottom edge]]`` and ``boxs`` holds
            the corresponding track boxes
        :raises IndexError: if a new track is needed while the id pool is empty
        """

        def box_from_state(tracker):
            # The state is an 8x1 column; rows 0, 2, 4 and 6 are the box coordinates.
            state = tracker.x_state.T[0].tolist()
            return [state[0], state[2], state[4], state[6]]

        z_box = bboxs
        x_box = [trk.box for trk in self.tracker_list]

        matched, unmatched_dets, unmatched_trks = self.assign_detections_to_trackers(x_box, z_box, iou_thrd=0.03)

        # Deal with matched detections: correct the track with its measurement
        for trk_idx, det_idx in matched:
            z = np.expand_dims(z_box[det_idx], axis=0).T
            tmp_trk = self.tracker_list[trk_idx]
            tmp_trk.kalman_filter(z)
            xx = box_from_state(tmp_trk)
            x_box[trk_idx] = xx
            tmp_trk.box = xx
            tmp_trk.hits += 1

        # Deal with unmatched detections: start a new track for each
        for idx in unmatched_dets:
            z = np.expand_dims(z_box[idx], axis=0).T
            tmp_trk = Tracker()  # Create a new tracker
            # Use the scalar entries z[i, 0]: mixing the size-1 arrays z[i]
            # with plain zeros builds a ragged array, which recent NumPy rejects.
            x = np.array([[z[0, 0], 0, z[1, 0], 0, z[2, 0], 0, z[3, 0], 0]]).T
            tmp_trk.x_state = x
            tmp_trk.predict_only()
            xx = box_from_state(tmp_trk)
            tmp_trk.box = xx
            tmp_trk.id = self.track_id_list.popleft()  # assign an ID for the tracker

            self.tracker_list.append(tmp_trk)
            x_box.append(xx)

        # Deal with unmatched tracks: no measurement, predict only
        for trk_idx in unmatched_trks:
            tmp_trk = self.tracker_list[trk_idx]
            tmp_trk.no_losses += 1
            tmp_trk.predict_only()
            xx = box_from_state(tmp_trk)
            tmp_trk.box = xx
            x_box[trk_idx] = xx

        # Collect the tracks that are reported for this frame
        objects = []
        boxs = []
        for trk in self.tracker_list:
            if (trk.hits >= self.min_hits) and (trk.no_losses <= self.max_age):
                x_cv2 = trk.box
                left, right, bottom = x_cv2[1], x_cv2[3], x_cv2[2]
                # Reference point: horizontal centre of the box on its bottom edge
                centroid = [int(left + ((right - left) / 2)), bottom]
                objects.append([int(trk.id), centroid])
                boxs.append(x_cv2)

        # Drop expired tracks and recycle their ids
        surviving = []
        for trk in self.tracker_list:
            if trk.no_losses > self.max_age:
                self.track_id_list.append(trk.id)
            else:
                surviving.append(trk)
        self.tracker_list = surviving

        return img, objects, boxs

    @staticmethod
    def assign_detections_to_trackers(
            trackers,
            detections,
            iou_thrd=0.3,
    ):
        """
        Match the current tracker boxes to the new detections by IoU.

        The assignment maximises the total IoU; pairs whose IoU is below
        ``iou_thrd`` are rejected and reported as unmatched on both sides.

        :param trackers: list of tracker boxes
        :param detections: list of detection boxes
        :param iou_thrd: minimum IoU for a pair to count as a match
        :return: ``(matches, unmatched_detections, unmatched_trackers)`` -
            ``matches`` is an int array of shape (k, 2) holding
            (tracker index, detection index) rows, the other two are arrays
            of indices
        """
        # scipy's solver replaces sklearn.utils.linear_assignment_, which
        # scikit-learn deprecated and later removed; it also accepts empty and
        # rectangular matrices.
        from scipy.optimize import linear_sum_assignment

        iou_mat = np.zeros((len(trackers), len(detections)), dtype=np.float32)
        for t, trk in enumerate(trackers):
            for d, det in enumerate(detections):
                iou_mat[t, d] = box_iou2(trk, det)

        # Negate the IoU so that minimising the cost maximises the overlap
        row_idx, col_idx = linear_sum_assignment(-iou_mat)
        row_idx, col_idx = row_idx.tolist(), col_idx.tolist()

        assigned_trackers = set(row_idx)
        assigned_detections = set(col_idx)
        unmatched_trackers = [t for t in range(len(trackers)) if t not in assigned_trackers]
        unmatched_detections = [d for d in range(len(detections)) if d not in assigned_detections]

        accepted = []
        for t, d in zip(row_idx, col_idx):
            if iou_mat[t, d] < iou_thrd:
                # Overlap too small: treat both sides as unmatched
                unmatched_trackers.append(t)
                unmatched_detections.append(d)
            else:
                accepted.append((t, d))

        matches = np.array(accepted, dtype=int).reshape(-1, 2)

        return matches, np.array(unmatched_detections), np.array(unmatched_trackers)

    def get_line_coordinates(self):
        """
        Get the line coordinates from the deployment JSON

        The deployment metadata is cached in ``self.janus_metadata`` and
        fetched again through ``get_extra_fields`` once the cache entry is
        missing or empty.

        :return: ``(alignment, coordinates)`` - the value stored under
            ``JanusDeploymentConstants.ALIGNMENT_KEY`` and the list of values
            for each key in ``JanusDeploymentConstants.LINE_COORDINATES``
        """
        # Keep a local reference: the cache entry may expire between the
        # check and later reads, and indexing an expired key raises KeyError.
        metadata = self.janus_metadata.get('metadata')
        if not metadata:
            metadata = get_extra_fields(self.device_id)
            self.janus_metadata['metadata'] = metadata

        _coordinates = [metadata.get(coordinate_key) for coordinate_key in
                        JanusDeploymentConstants.LINE_COORDINATES]
        _alignment = metadata.get(JanusDeploymentConstants.ALIGNMENT_KEY)
        return _alignment, _coordinates

    def line_point_position(self, point):
        """
        Get the position of point w.r.t. the line
        :param point: point to be compared
        :return: boolean
        """
        _alignment, line_coordinates = self.get_line_coordinates()

        assert len(line_coordinates) == 4, "Line coordinates variable is invalid"
        assert len(point) == 2, "Point variable is invalid"

        _slope = (line_coordinates[3] - line_coordinates[1]) / (line_coordinates[2] - line_coordinates[0])
        _point_equation_value = point[1] - line_coordinates[1] - _slope * (point[0] - line_coordinates[0])
        if _point_equation_value > 0:
            return True
        else:
            return False

    def insert_alarm_event(self, message="MRP Missing", asset_hierarchy=""):
        data = {
            "device_instance_id": asset_hierarchy,
            "device_instance_ids": asset_hierarchy,
            "alarm_id": "alarm_configuration_212",
            "triggered_devices": [
                asset_hierarchy
            ],
            "tag_value": 5.0,
            "start_time": None,
            "end_time": "",
            "current_level": 0,
            "id": "alarm_event_1567101",
            "trigger_time": [
                {
                    "start_time": None,
                    "counter": 0
                }
            ],
            "trigger_levels": [
                {
                    "timestamp": None,
                    "notificaton_profile": [
                        {
                            "usersOrUserGroup": [
                                {
                                    "value": "access_group_100",
                                    "type": "access_group",
                                    "label": "ACC Admin"
                                }
                            ],
                            "notificationProfile": [
                                "alarm_notify_type_4"
                            ],
                            "emailIds": [],
                            "phoneNumbers": [],
                            "notificationTone": "",
                            "isNotificationToneShow": True,
                            "triggers": [
                                {
                                    "device_instance_id": None,
                                    "tags": None,
                                    "customValueType": None,
                                    "customValue": None,
                                    "counter": 1
                                }
                            ],
                            "counter": 1,
                            "enable_custom": True
                        }
                    ]
                }
            ],
            "priority": "alarm_priority_type_109",
            "tag_id": "",
            "template": "MRP Missing",
            "acknowledge": True,
            "alarmName": message,
            "alarmType": "Alarm",
            "created_by": "user_100",
            "project_id": "project_101",
            "product_encrypted": False,
            "start_time_in_epoch": None,
            "show_data_viz": True,
            "alarm_condition": message,
            "tag_id_list": [
                asset_hierarchy
            ],
            "tag_value_json": {
                asset_hierarchy: 5.0
            },
            "alarm_tag_list": [
                asset_hierarchy
            ],
            "acknowledged_at": "2022-01-06 20:51:22",
            "acknowledged_by": None
        }
        epoch = int(time.time()) * 1000
        time_string = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
        data['start_time_in_epoch'] = epoch
        data['trigger_time'][0]['start_time'] = epoch
        data['start_time'] = time_string
        data['trigger_levels'][0]['timestamp'] = time_string
        data['acknowledged_at'] = time_string
        data['id'] = f"alarm_event_{str(uuid4()).split('-')[0]}"
        self.mongo_alarm_coll.insert_one(data)

    def validate_point_position(self, point):
        """
        Check that the point lies strictly between the two line end points
        along the axis selected by the alignment (y for a vertical line,
        x otherwise).
        :param point: centroid, as (x, y)
        :return: bool
        """
        _alignment, line_coordinates = self.get_line_coordinates()
        assert _alignment in [JanusDeploymentConstants.VERTICAL, JanusDeploymentConstants.HORIZONTAL], \
            "Invalid alignment variable"

        if _alignment == JanusDeploymentConstants.VERTICAL:
            start, end, value = line_coordinates[1], line_coordinates[3], point[1]
        else:
            start, end, value = line_coordinates[0], line_coordinates[2], point[0]

        # The end points may be given in either order.
        return bool(start < value < end or end < value < start)

    # def verifying_cement_bag_type(self, previous_class, current_class):
    #     if(previous_class == current_class):
    #         print("bag changed")

    def update_bag_count(self, frame, detection_objects, class_name, detections):
        """
        Maintains the bag counts
        :param frame: image
        :param detection_objects: detection object having object id and centroids
        """
        # for class_name, (objectID, centroid) in zip(classes, detection_objects):
        for (object_id, det, class_detected) in zip(
                detection_objects, detections, class_name
        ):
            centroid = object_id[1]
            object_id = object_id[0]
            logger.debug(detections)
            # print(object_id)
            frame = draw_circles_on_frame(
                frame, centroid, radius=10, color=(0, 0, 255), thickness=-1
            )
            if self.validate_point_position(centroid):
                logger.debug("centroid detected")

                if not isinstance(self.initial_object_position, bool):
                    logger.debug("Initializing the initial object position")
                    # self.initial_object_position = self.line_point_position(point=centroid)
                    self.initial_object_position = True
                    logger.debug(self.initial_object_position)

                _point_position = self.line_point_position(point=centroid)
                # print("object ID is : ", str(objectID))
                logger.debug(self.uncounted_objects)

                # Check point in the same side as the initial object
                if _point_position == self.initial_object_position:
                    logger.debug("same side only")
                    # print(class_name)
                    # Check the object is not already counted
                    if object_id not in self.uncounted_objects:
                        self.uncounted_objects[object_id] = centroid

                elif object_id in self.uncounted_objects:
                    # print("************")
                    # print(class_detected)
                    self.uncounted_objects.pop(object_id, None)

                    if class_detected == "acc_gold":
                        gold_flag = self.yellow_thresholding(
                            detections, frame, class_detected
                        )
                        if gold_flag:
                            self.count_gold += 1
                        else:
                            self.count_suraksha += 1
                            # self.verifying_cement_bag_type(self.prev_class_name, "acc_gold")
                            # self.prev_class_name = "acc_gold"
                            class_detected = "acc_suraksha_power_plus"
                        (
                            mrp_result,
                            mrp_frame,
                            mrp_roi,
                        ) = self.distances(detections, frame, class_detected)
                        if mrp_result:
                            if gold_flag:
                                self.send_payload(
                                    resize_to_64_64(frame=frame),
                                    bag_type="acc_gold",
                                    message="ACC GOLD, MRP:YES",
                                    mrp_frmae=mrp_frame,
                                    mrp_roi=mrp_roi,
                                    mrp="PASS",
                                )
                                logger.info(
                                    f"Count: {self.count_gold}, Print Found: True"
                                )
                            else:
                                self.send_payload(
                                    resize_to_64_64(frame=frame),
                                    bag_type="acc_suraksha_power_plus",
                                    message="ACC SURAKSHA PP, MRP:YES",
                                    mrp_frmae=mrp_frame,
                                    mrp_roi=mrp_roi,
                                    mrp="PASS",
                                )
                                logger.info(
                                    f"Count: {self.count_suraksha}, Print Found: True"
                                )

                        else:
                            if gold_flag:
                                self.send_payload(
                                    resize_to_64_64(frame=frame),
                                    bag_type="acc_gold",
                                    message="ACC GOLD, MRP:NO",
                                    mrp_frmae=mrp_frame,
                                    mrp_roi=mrp_roi,
                                )
                                logger.info(
                                    f"Count: {self.count_gold}, Print Found: False"
                                )

                            else:
                                self.send_payload(
                                    resize_to_64_64(frame=frame),
                                    bag_type="acc_suraksha_power_plus",
                                    message="ACC SURAKSHA PP, MRP:NO",
                                    mrp_frmae=mrp_frame,
                                    mrp_roi=mrp_roi,
                                )
                                logger.info(
                                    f"Count: {self.count_suraksha}, Print Found: False"
                                )

                    elif class_detected == "acc_suraksha_power_plus":
                        self.count_suraksha += 1
                        # self.verifying_cement_bag_type(self.prev_class_name, "acc_suraksha")
                        # self.prev_class_name = "acc_suraksha"
                        logger.debug(self.count_suraksha)
                        (
                            mrp_result,
                            mrp_frame,
                            mrp_roi
                        ) = self.distances(detections, frame, class_detected)
                        if mrp_result:
                            self.send_payload(
                                resize_to_64_64(frame=frame),
                                bag_type="acc_suraksha_power_plus",
                                message="ACC SURAKSHA PP, MRP:YES",
                                mrp_frmae=mrp_frame,
                                mrp_roi=mrp_roi,
                                mrp="PASS",
                            )
                            logger.info(
                                f"Count: {self.count_suraksha}, Print Found: True"
                            )
                        else:
                            self.send_payload(
                                resize_to_64_64(frame=frame),
                                bag_type="acc_suraksha_power_plus",
                                message="ACC SURAKSHA PP, MRP:NO",
                                mrp_frmae=mrp_frame,
                                mrp_roi=mrp_roi,
                            )
                            logger.info(
                                f"Count: {self.count_suraksha}, Print Found: False"
                            )

                    elif class_detected == "ambuja_buildcem":
                        self.count_whitecem += 1
                        # self.verifying_cement_bag_type(self.prev_class_name, "ambuja_buildcem")
                        # self.prev_class_name = "ambuja_buildcem"
                        (
                            mrp_result,
                            mrp_frame,
                            mrp_roi
                        ) = self.distances(detections, frame, class_detected)
                        if mrp_result:
                            self.send_payload(
                                resize_to_64_64(frame=frame),
                                bag_type="ambuja_buildcem",
                                message="AMBUJA BUILDCEM, MRP:YES",
                                mrp_frmae=mrp_frame,
                                mrp_roi=mrp_roi,
                                mrp="PASS",
                            )
                            logger.info(
                                f"Count: {self.count_whitecem}, Print Found: True"
                            )
                        else:
                            self.send_payload(
                                resize_to_64_64(frame=frame),
                                bag_type="ambuja_buildcem",
                                message="AMBUJA BUILDCEM, MRP:NO",
                                mrp_frmae=mrp_frame,
                                mrp_roi=mrp_roi,
                            )
                            logger.info(
                                f"Count: {self.count_whitecem}, Print Found: False"
                            )

                    elif class_detected == "acc_nfr":
                        self.count_nfr += 1
                        # self.verifying_cement_bag_type(self.prev_class_name, "acc_nfr")
                        # self.prev_class_name = "acc_nfr"
                        logger.debug(self.count_nfr)
                        (
                            mrp_result,
                            mrp_frame,
                            mrp_roi,
                        ) = self.distances(detections, frame, class_detected)
                        if mrp_result:
                            self.send_payload(
                                resize_to_64_64(frame=frame),
                                bag_type="acc_nfr",
                                message="ACC NFR, MRP:YES",
                                mrp_frmae=mrp_frame,
                                mrp_roi=mrp_roi,
                                mrp="PASS",
                            )
                            logger.info(f"Count: {self.count_nfr}, Print Found: True")
                        else:
                            self.send_payload(
                                resize_to_64_64(frame=frame),
                                bag_type="acc_nfr",
                                message="ACC NFR, MRP:NO",
                                mrp_frmae=mrp_frame,
                                mrp_roi=mrp_roi,
                            )
                            logger.info(f"Count: {self.count_nfr}, Print Found: False")

                    elif class_detected == "acc_suraksha_power":
                        self.count_suraksha_power += 1
                        # self.verifying_cement_bag_type(self.prev_class_name, "acc_suraksha_power")
                        # self.prev_class_name = "acc_suraksha_power"
                        logger.debug(self.count_suraksha_power)
                        (
                            mrp_result,
                            mrp_frame,
                            mrp_roi,
                        ) = self.distances(detections, frame, class_detected)
                        if mrp_result:
                            self.send_payload(
                                resize_to_64_64(frame=frame),
                                bag_type="acc_suraksha_power",
                                message="ACC SURAKSHA POWER, MRP:YES",
                                mrp_frmae=mrp_frame,
                                mrp_roi=mrp_roi,
                                mrp="PASS",
                            )
                            logger.info(
                                f"Count: {self.count_suraksha_power}, Print Found: True"
                            )
                        else:
                            self.send_payload(
                                resize_to_64_64(frame=frame),
                                bag_type="acc_suraksha_power",
                                message="ACC SURAKSHA POWER, MRP:NO",
                                mrp_frmae=mrp_frame,
                                mrp_roi=mrp_roi,
                            )
                            logger.info(
                                f"Count: {self.count_suraksha_power}, Print Found: False"
                            )

                    elif class_detected == "acc_concrete_plus":
                        self.count_concrete_plus += 1
                        # self.verifying_cement_bag_type(self.prev_class_name, "acc_nfr")
                        # self.prev_class_name = "acc_nfr"
                        logger.debug(self.count_concrete_plus)
                        (
                            mrp_result,
                            mrp_frame,
                            mrp_roi,
                        ) = self.distances(detections, frame, class_detected)
                        if mrp_result:
                            self.send_payload(
                                resize_to_64_64(frame=frame),
                                bag_type="acc_concrete_plus",
                                message="ACC CONCRETE PLUS, MRP:YES",
                                mrp_frmae=mrp_frame,
                                mrp_roi=mrp_roi,
                                mrp="PASS",
                            )
                            logger.info(
                                f"Count: {self.count_concrete_plus}, Print Found: True"
                            )
                        else:
                            self.send_payload(
                                resize_to_64_64(frame=frame),
                                bag_type="acc_concrete_plus",
                                message="ACC CONCRETE PLUS, MRP:NO",
                                mrp_frmae=mrp_frame,
                                mrp_roi=mrp_roi,
                            )
                            logger.info(
                                f"Count: {self.count_concrete_plus}, Print Found: False"
                            )

                    elif class_detected == "ambuja_plus":
                        self.count_ambuja_plus += 1
                        # self.verifying_cement_bag_type(self.prev_class_name, "acc_nfr")
                        # self.prev_class_name = "acc_nfr"
                        logger.debug(self.count_ambuja_plus)
                        (
                            mrp_result,
                            mrp_frame,
                            mrp_roi,
                        ) = self.distances(detections, frame, class_detected)
                        self.text = "COUNT : {ambuja_count}/10".format(
                            ambuja_count=self.count_ambuja_plus
                        )

                        if mrp_result:
                            self.send_payload(
                                resize_to_64_64(frame=frame),
                                bag_type="ambuja_plus",
                                message="Ambuja PLUS, MRP:YES",
                                mrp_frmae=mrp_frame,
                                mrp_roi=mrp_roi,
                                mrp="PASS",
                            )
                            logger.info(
                                f"Count: {self.count_ambuja_plus}, Print Found: True"
                            )
                        else:
                            self.send_payload(
                                resize_to_64_64(frame=frame),
                                bag_type="ambuja_plus",
                                message="Ambuja PLUS, MRP:NO",
                                mrp_frmae=mrp_frame,
                                mrp_roi=mrp_roi,
                            )
                            logger.info(
                                f"Count: {self.count_ambuja_plus}, Print Found: False"
                            )

                    frame = draw_circles_on_frame(
                        frame, centroid, radius=10, color=(0, 255, 0), thickness=-1
                    )
                    # cv2.waitKey(0)
                    # if centroid['has_print']:
                    #     self.send_payload(resize_to_64_64(frame=frame), message='Print Detected!')
                    #     logger.info(f"Count: {self.count}, Print Found: True")
                    # else:
                    #     self.send_payload(resize_to_64_64(frame=frame), message='Print Missing!')
                    #     logger.info(f"Count: {self.count}, Print Found: False")
                else:
                    frame = draw_circles_on_frame(
                        frame, centroid, radius=10, color=(0, 255, 0), thickness=-1
                    )

        # count_text_gold = "ACC_GOLD: " + str(self.count_gold)
        # count_text_suraksha = "ACC_SURAKSHA_PLUS: " + str(self.count_suraksha)
        # count_text_whitecem = "ACC_WHITE_CEM: " + str(self.count_whitecem)
        # count_text_suraksha_power = "ACC_SURAKSHA_POWER: " + str(self.count_suraksha_power)
        # count_text_nfr = "ACC_NFR: " + str(self.count_nfr)
        # cv2.putText(frame, count_text_gold, (1300, 200), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 3,
        #             cv2.LINE_AA)
        # cv2.putText(frame, count_text_suraksha, (1300, 400), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 3,
        #             cv2.LINE_AA)
        # cv2.putText(frame, count_text_whitecem, (1300, 600), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 3,
        #             cv2.LINE_AA)
        # cv2.putText(frame, count_text_suraksha_power, (1300, 800), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 3,
        #             cv2.LINE_AA)
        # cv2.putText(frame, count_text_nfr, (1300, 1000), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 3,
        #             cv2.LINE_AA)
        # cv2.putText(frame, self.prev_class_name, (1000, 800), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 3,
        #             cv2.LINE_AA)
        return frame

    def yellow_thresholding(self, detections, frame, class_detected):
        """
        Check whether a wide cement bag in the frame shows a large yellow region.

        Every detection whose class is not "mrp" is treated as a bag. The crop of
        the last bag wider than 500 px is searched for contours inside a fixed
        HSV range (hue 10-45, saturation >= 70).

        :param detections: iterable of dicts with "class" and "points" keys;
            points are indexed as [x1, y1, x2, y2]
        :param frame: image array the detections refer to (converted with
            COLOR_BGR2HSV, so BGR channel order is assumed)
        :param class_detected: not used; kept so existing callers keep working
        :return: True when a matching contour wider than 100 px exists, False
            otherwise - including when no bag is wide enough or the crop is empty
        """
        roi = None
        for det in detections:
            if det["class"] == "mrp":
                continue
            points = det["points"]
            if points[2] - points[0] > 500:
                # Trim 40 px from the top and 80 px from the right of the bag box.
                # As before, the last qualifying bag wins.
                roi = frame[points[1] + 40: points[3], points[0]: points[2] - 80]

        # Previously an unbound `roi` raised here when no bag was wide enough,
        # and an empty crop made cvtColor fail.
        if roi is None or roi.size == 0:
            return False

        hsv = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)
        lower = np.array([10, 70, 0], dtype="uint8")
        upper = np.array([45, 255, 255], dtype="uint8")
        mask = cv2.inRange(hsv, lower, upper)

        found = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # findContours returns 2 or 3 values depending on the OpenCV version.
        contours = found[0] if len(found) == 2 else found[1]

        # boundingRect -> (x, y, w, h); only the width decides the outcome.
        return any(cv2.boundingRect(contour)[2] > 100 for contour in contours)

    def draw_line_over_image(self, frame, color=(255, 255, 255)):
        """
        Overlay the counting line on a frame.

        The end points come from ``self.get_line_coordinates()`` and the line is
        rendered through ``self.drawline`` with thickness 3 (its default style
        is "dotted").

        :param frame: image to draw on
        :param color: colour handed through to ``drawline``
        :return: the same frame object that was passed in
        """
        _, coords = self.get_line_coordinates()
        assert len(coords) == 4, "Line coordinates variable is invalid"

        # coords holds two points laid out flat: (x1, y1, x2, y2)
        start_point = (coords[0], coords[1])
        end_point = (coords[2], coords[3])
        self.drawline(frame, start_point, end_point, color, thickness=3)
        return frame

    @staticmethod
    def drawline(img, pt1, pt2, color, thickness=1, style="dotted", gap=20):
        dist = ((pt1[0] - pt2[0]) ** 2 + (pt1[1] - pt2[1]) ** 2) ** 0.5
        pts = []
        for i in np.arange(0, dist, gap):
            r = i / dist
            x = int((pt1[0] * (1 - r) + pt2[0] * r) + 0.5)
            y = int((pt1[1] * (1 - r) + pt2[1] * r) + 0.5)
            p = (x, y)
            pts.append(p)

        if style == "dotted":
            for p in pts:
                cv2.circle(img, p, thickness, color, -1)
        else:
            s = pts[0]
            e = pts[0]
            i = 0
            for p in pts:
                s = e
                e = p
                if i % 2 == 1:
                    cv2.line(img, s, e, color, thickness)
                i += 1

    def crop_polygon(self, cement_bag_img):
        """
        Cut a fixed quadrilateral out of the bag image and put it on white.

        The polygon corners are hard-coded in bag-image pixel coordinates.
        Pixels of the bounding-box crop that fall outside the polygon are
        replaced with white.

        :param cement_bag_img: image array of the cement bag
        :return: the bounding-box sized crop with everything outside the polygon
            painted white
        """
        polygon = np.array([[100, 180], [530, 110], [555, 190], [100, 260]])

        # Crop the axis-aligned bounding box of the polygon.
        left, top, box_w, box_h = cv2.boundingRect(polygon)
        patch = cement_bag_img[top: top + box_h, left: left + box_w].copy()

        # Rasterise the polygon, shifted into crop-local coordinates, as a mask.
        local_polygon = polygon - polygon.min(axis=0)
        mask = np.zeros(patch.shape[:2], np.uint8)
        cv2.drawContours(mask, [local_polygon], -1, (255, 255, 255), -1, cv2.LINE_AA)

        # Keep only the pixels under the mask.
        inside = cv2.bitwise_and(patch, patch, mask=mask)

        # White canvas, inverted to black where the mask is set, so the sum
        # below leaves the polygon content untouched and the rest white.
        background = np.full_like(patch, 255, dtype=np.uint8)
        cv2.bitwise_not(background, background, mask=mask)
        return background + inside

    def mrp_digit_count(self, img, detected_class):
        """
        Count the printed characters in an MRP crop and compare with the
        expected number for the bag type.

        The crop is pasted onto a white 300x800 canvas at offset (20, 20),
        bright pixels (HSV value > 220) are masked out, and the result is
        resized to height 150, blurred, Otsu-thresholded and cleaned with
        morphological open/close. The contours and the thresholded image are
        then handed to ``self.find_chars`` (defined elsewhere; presumably it
        returns the number of character-like contours - confirm).

        All supported bag types share this pipeline; they differ only in the
        morphology kernel size and the contour approximation mode.

        :param img: 3-channel crop of the MRP region; it has to fit inside the
            canvas after the 20 px offset
        :param detected_class: one of "ambuja_plus", "acc_gold",
            "acc_suraksha_power_plus"
        :return: tuple ``(count, ok)`` where ``ok`` is True when ``count`` is at
            least the expected number of characters minus 2
        :raises ValueError: if ``detected_class`` is not a supported bag type
            (previously this surfaced as an UnboundLocalError)
        """
        digit_num = {"ambuja_plus": 17, "acc_gold": 13, "acc_suraksha_power_plus": 17}
        if detected_class not in digit_num:
            raise ValueError(
                "Unsupported bag type for MRP digit count: {!r}".format(detected_class)
            )

        # Per-class differences of the otherwise identical pipeline.
        kernel_size = (7, 7) if detected_class == "acc_suraksha_power_plus" else (5, 5)
        contour_approx = (
            cv2.CHAIN_APPROX_SIMPLE
            if detected_class == "ambuja_plus"
            else cv2.CHAIN_APPROX_NONE
        )

        # Paste the crop onto a white canvas.
        height, width = img.shape[:2]
        canvas = np.full((300, 800, 3), 255, np.uint8)
        offset = 20
        canvas[offset: offset + height, offset: offset + width] = img

        # Drop everything whose HSV value is above 220.
        value = cv2.split(cv2.cvtColor(canvas, cv2.COLOR_BGR2HSV))[2]
        _, dark_mask = cv2.threshold(value, 220, 255, cv2.THRESH_BINARY_INV)
        masked = cv2.bitwise_and(canvas, canvas, mask=dark_mask)

        # Resize, grayscale, blur and binarise.
        masked = imutils.resize(masked, height=150)
        gray = cv2.cvtColor(masked, cv2.COLOR_BGR2GRAY)
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)
        thresh = cv2.threshold(
            blurred, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU
        )[1]

        # The original passed an extra 5x5 array in the `dst` slot of the
        # opening call; it is omitted here and OpenCV allocates the output.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, kernel_size)
        thresh = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=2)
        thresh = cv2.morphologyEx(
            thresh,
            cv2.MORPH_CLOSE,
            kernel,
            iterations=1,
            borderType=cv2.BORDER_REFLECT101,
        )

        cnts = cv2.findContours(thresh.copy(), cv2.RETR_TREE, contour_approx)
        cnts = imutils.grab_contours(cnts)

        # `thresh` is single-channel, so only the first colour component is
        # used: contour outlines are drawn with value 0, as before.
        cv2.drawContours(thresh, cnts, -1, (0, 0, 0), 1)

        char_count = self.find_chars(cnts, thresh)
        return char_count, bool(char_count >= digit_num[detected_class] - 2)

    def distances(self, detections, frame, class_detected):
        """
        Look for the MRP print on the first sufficiently wide bag in the frame.

        Every detection whose class is not "mrp" is treated as a bag; bags that
        are 500 px wide or narrower are ignored. For a wide bag the lower half
        of its crop (minus a 50 px top strip and 150 px on the right) is passed
        to ``self.mrp_image`` - but only when the device's extra fields have
        ``JanusDeploymentConstants.MRP_DETECT_KEY`` set to "yes".

        On a miss ``self.mrp_counter`` is incremented; at 5 it is reset, an
        alarm event is inserted and the belt relay endpoint is sent
        ``triggerStatus="stop"``. A hit resets the counter.

        :param detections: iterable of dicts with "class" and "points" keys;
            points are indexed as [x1, y1, x2, y2]
        :param frame: image array the detections refer to
        :param class_detected: bag class name, forwarded to ``mrp_image`` and
            used in the alarm message
        :return: tuple ``(mrp_found, mrp_image_uri, bag_image_uri)``; the URIs
            are JPEG data URIs or "" when not available. A 3-tuple with
            ``mrp_found=False`` is also returned when no bag qualifies or MRP
            detection is disabled (previously this fell through and returned
            None, which broke callers that unpack the result).
        """

        def to_data_uri(image):
            # tobytes() replaces the deprecated ndarray.tostring().
            jpeg = cv2.imencode(".jpg", image)[1].tobytes()
            return "data:image/jpeg;base64," + base64.b64encode(jpeg).decode("utf-8")

        add_mrp = ""
        mrp_roi = ""
        bag_boxes = [det["points"] for det in detections if det["class"] != "mrp"]

        for box in bag_boxes:
            if box[2] - box[0] <= 500:
                continue

            roi = frame[box[1]: box[3], box[0]: box[2]]
            h, w, _ = roi.shape
            lower_half = roi[int(h / 2): h, 0:w]
            search_region = lower_half[50: int(h / 2), 0: w - 150]
            mrp_roi = to_data_uri(roi)

            extra_values = get_extra_fields(self.device_id)
            mrp_detect = extra_values.get(JanusDeploymentConstants.MRP_DETECT_KEY)
            if mrp_detect is None or mrp_detect.lower() != "yes":
                continue

            mrp_add = self.mrp_image(search_region, class_detected)
            if mrp_add is not None:
                self.mrp_counter = 0
                add_mrp = to_data_uri(mrp_add)
                return True, add_mrp, mrp_roi

            self.mrp_counter += 1
            if self.mrp_counter >= 5:
                self.mrp_counter = 0
                logger.debug("activate relay")
                self.insert_alarm_event(
                    asset_hierarchy=self.camera_details.get("asset_hierarchy", ""),
                    message=self.camera_details.get("asset_name", "")
                    + " - "
                    + class_detected
                    + " : MRP missed",
                )
                RelayHandler().update_relay_status(
                    self.camera_details.get("belt_relay_ep", ""),
                    dict(triggerStatus="stop"),
                )
                logger.debug(
                    "Stopped the relay because of 5 consecutive MRP misses"
                )

            return False, add_mrp, mrp_roi

        # No wide bag, or MRP detection is switched off for this device.
        return False, add_mrp, mrp_roi




    def inference(self, frame, classes):
        """
        Run the detector on a frame and split out the boxes of wanted classes.

        :param frame: image handed to ``self.yolo_v5_wrapper.infer``
        :param classes: container of class names to keep
        :return: tuple ``(bboxs, frame, dets, class_name)`` where ``dets`` lists
            every detection as a dict with "points", "conf" and "class",
            ``class_name`` holds the class of each kept detection and ``bboxs``
            holds their points reordered from indices [0, 1, 2, 3] to
            [1, 0, 3, 2]; ``frame`` is returned unchanged
        """
        boxes, scores, class_ids = self.yolo_v5_wrapper.infer(frame)

        dets = []
        for points, conf, class_id in zip(boxes, scores, class_ids):
            # Class ids without an entry in self.classes map to None.
            dets.append(
                {"points": list(points), "conf": conf, "class": self.classes.get(class_id)}
            )

        kept = [det for det in dets if det["class"] in classes]
        class_name = [det["class"] for det in kept]
        bboxs = [
            [det["points"][1], det["points"][0], det["points"][3], det["points"][2]]
            for det in kept
        ]
        return bboxs, frame, dets, class_name

    def _predict(self, obj):
        """
        Process one frame: detect bags, draw the counting line, track, count.

        ``obj['frame']`` is replaced with the annotated frame resized to the
        configured FRAME_WIDTH x FRAME_HEIGHT. If any step raises, the error is
        logged and the incoming frame is resized instead.

        :param obj: dict carrying the image under the 'frame' key
        :return: the same ``obj`` with 'frame' replaced
        """
        bag_classes = [
            "acc_gold",
            "acc_suraksha_power_plus",
            "ambuja_buildcem",
            "acc_nfr",
            "acc_suraksha_power",
            "acc_concrete_plus",
            "ambuja_plus",
        ]
        try:
            tracker_boxes, annotated, all_detections, class_name = self.inference(
                obj['frame'], bag_classes
            )
            annotated = self.draw_line_over_image(annotated)
            annotated, objects, _boxes = self.kalman_tracker(tracker_boxes, annotated)
            annotated = self.update_bag_count(
                frame=annotated,
                detection_objects=objects,
                class_name=class_name,
                detections=all_detections,
            )

            logger.debug(f"self.uncounted_objects --> {self.uncounted_objects}")

            out_size = (self.config.get('FRAME_WIDTH'), self.config.get('FRAME_HEIGHT'))
            obj['frame'] = cv2.resize(annotated, out_size)
        except Exception as e:
            logger.exception(f"Error: {e}", exc_info=True)
            # Fall back to the untouched input frame, still at the output size.
            out_size = (self.config.get('FRAME_WIDTH'), self.config.get('FRAME_HEIGHT'))
            obj['frame'] = cv2.resize(obj['frame'], out_size)

        return obj


    def detect_mrp(self, img, class_detected):
        """
        Decide from the share of non-black pixels whether an MRP print is present.

        A pixel counts as black when all three channels are 0. The score is
        ``(100 + non_black) / black * 100`` and is compared, after truncation to
        int, with the per-class threshold in ``self.black_white_ratio_dict``.

        :param img: 3-channel image array (any shape that reshapes to (-1, 3))
        :param class_detected: bag class name used to look up the threshold
        :return: True when the score reaches the threshold, False otherwise.
            An empty image yields False. An image without any black pixel
            yields True, because the score is unbounded there (both cases
            previously raised ZeroDivisionError).
        :raises KeyError: if there is no threshold for ``class_detected``
        """
        pixels = img.reshape(-1, 3)
        total = len(pixels)
        if total == 0:
            return False

        # Vectorised count of rows whose three channels are all zero; replaces
        # the per-pixel Python loop.
        black_pixel = int(np.count_nonzero(~pixels.any(axis=1)))
        # The non-black count keeps its historical head start of 100.
        white_pixel = 100 + (total - black_pixel)

        black_white_ratio = self.black_white_ratio_dict[f"{class_detected}"]

        if black_pixel == 0:
            return True

        black_to_white = (white_pixel / black_pixel) * 100
        return int(black_to_white) >= int(black_white_ratio)

    def mrp_image(self, img, class_detected):
        """
        Isolate the print-coloured pixels of a bag crop and test for an MRP.

        Each channel is flattened against its dilated, median-blurred
        background, the merged result is sharpened, and pixels inside a fixed
        HSV range are kept. The masked image is then judged by
        ``self.detect_mrp``.

        :param img: 3-channel crop of the region expected to hold the MRP
        :param class_detected: bag class name, forwarded to ``detect_mrp``
        :return: the masked image when ``detect_mrp`` reports a print, otherwise
            None; None is also returned (and the error logged) if any step raises
        """
        try:
            sharpen_kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])

            def flatten_background(plane):
                # Estimate the background of one channel and invert the
                # difference to it.
                dilated = cv2.dilate(plane, np.ones((7, 7), np.uint8))
                background = cv2.medianBlur(dilated, 21)
                return 255 - cv2.absdiff(plane, background)

            planes = [flatten_background(plane) for plane in cv2.split(img.copy())]
            sharpened = cv2.filter2D(
                src=cv2.merge(planes), ddepth=-1, kernel=sharpen_kernel
            )

            # Keep only pixels whose HSV value lies inside the fixed range.
            hsv = cv2.cvtColor(sharpened, cv2.COLOR_BGR2HSV)
            mask = cv2.inRange(hsv, np.array([0, 0, 160]), np.array([180, 230, 200]))
            res = cv2.bitwise_and(sharpened, sharpened, mask=mask)

            if self.detect_mrp(res, class_detected):
                return res
            return None

        except Exception as e:
            logger.exception(f"Error: {e}", exc_info=True)
            return None