from edge_engine.ai.model.modelwraper import ModelWrapper
from scripts.utils.infocenter import MongoLogger
from edge_engine.common.logsetup import logger
from scripts.utils.notify_infocenter import NotificationFilter
import cv2
import base64
import numpy as np
from imutils.video import FPS
import datetime
import traceback
import psutil
# from scripts.utils.yolo_params import YoloParams
from yolov5processor.infer import ExecuteInference
import os
from math import exp as exp
import time

try:
    import urlparse
except ImportError:
    import urllib.parse as urlparse


class SRF_Cans(ModelWrapper):
    """
    YOLOv5-based can counter.

    Each frame is resized to INFER_SIZE, run through the cans detector,
    annotated with the accepted detections and handed back for publishing.
    Accepted detections are accumulated into a "batch" (``all_pred`` /
    ``length_list``). Once the batch holds more than MIN_BATCH_ENTRIES entries
    and the last IDLE_WINDOW history entries contain no detections, the largest
    per-frame detection count seen in the batch is added to ``count`` and a
    notification is stored through MongoLogger.
    """

    # Detections whose confidence, scaled by 100 and truncated, is not above this are ignored.
    CONF_THRESHOLD_PCT = 60
    # Only boxes whose first-corner y coordinate (points[1]) is greater than this row are counted.
    ROI_MIN_Y = 208
    # Class index that is counted; detections of any other class are discarded.
    CAN_LABEL = 0
    # A batch must hold more than this many history entries before it can be closed.
    MIN_BATCH_ENTRIES = 51
    # Number of trailing history entries that must all be zero to close a batch.
    IDLE_WINDOW = 20
    # (width, height) the frame is resized to before inference; ROI_MIN_Y is expressed in this space.
    INFER_SIZE = (640, 400)
    # (width, height) of the JPEG thumbnail attached to notifications.
    THUMBNAIL_SIZE = (64, 64)
    # Colour (OpenCV channel order) and thickness of the drawn can boxes.
    BOX_COLOR = (0, 0, 255)
    BOX_THICKNESS = 2

    def __init__(self, config, model_config, pubs, device_id):
        """
        Load the cans detector and initialise the counting state.

        :param config: mapping whose "config" entry holds runtime settings
                       (FRAME_WIDTH / FRAME_HEIGHT are read from it)
        :param model_config: mapping; 'notification_ttl_value_sec' is passed to NotificationFilter
        :param pubs: publisher holder; its ``rtp_write`` is used for the video stream
        :param device_id: identifier stored as "deviceId" in every notification payload
        """
        super().__init__()
        self.config = config["config"]
        self.device_id = device_id
        self.rtp = pubs.rtp_write
        logger.info("Loading model to the device")

        self.mongologger = MongoLogger()
        self.notify_bool = NotificationFilter(ttl_value=model_config.get('notification_ttl_value_sec'))

        self.base_model_path = 'scripts/model/'
        logger.info("[INFO] loading yolov5-Cans Detection Model")

        self.model_detector_pth = os.path.join(self.base_model_path, "cans_yolov5s.pt")
        self.yp = ExecuteInference(weight=self.model_detector_pth)
        # Detection history of the currently open batch (one entry per detection, or [] for an idle frame).
        self.all_pred = []
        # Running total of packed cans reported so far.
        self.count = 0
        # Largest per-frame detection count seen in the current batch.
        self.max_val = 0
        # Per-entry detection counts; only the last IDLE_WINDOW values are kept/used.
        self.length_list = []
        self.frame_count = 0
        self.pre_time = time.time()

    def _pre_process(self, x):
        """
        Do preprocessing here, if any
        :param x: payload
        :return: payload
        """
        return x

    def _post_process(self, x):
        """
        Publish the payload on the RTP video stream and return it unchanged.
        :param x: payload
        :return: payload
        """
        self.rtp.publish(x)  # video stream
        return x

    def send_payload(self, message, frame, label, bg_color, font_color, alert_sound):
        """
        Store a notification event, with a small JPEG thumbnail of ``frame``, via MongoLogger.

        :param message: text shown in the notification
        :param frame: image to attach; it is downscaled to THUMBNAIL_SIZE first
        :param label: value stored under the "activity" key
        :param bg_color: background colour of the notification, e.g. "#12e6cd"
        :param font_color: font colour of the notification, e.g. "#000d0b"
        :param alert_sound: identifier of the sound to play, e.g. "sound_1"
        :return: None
        """
        thumbnail = cv2.resize(frame, self.THUMBNAIL_SIZE)
        # ndarray.tostring() is deprecated (removed in newer NumPy); tobytes() returns the same bytes.
        jpeg_bytes = cv2.imencode('.jpg', thumbnail)[1].tobytes()
        payload = {
            "deviceId": self.device_id,
            "message": message,
            "frame": 'data:image/jpeg;base64,' + base64.b64encode(jpeg_bytes).decode("utf-8"),
            "activity": label,
            "bg_color": bg_color,
            "font_color": font_color,
            "alert_sound": alert_sound,
        }
        self.mongologger.insert_attendance_event_to_mongo(payload)

    def _predict(self, obj):
        """
        Run detection and counting on ``obj['frame']``.

        On success ``obj['frame']`` is replaced by the annotated frame resized to
        (FRAME_WIDTH, FRAME_HEIGHT) and ``obj["timestamp"]`` is set. On failure the
        error is logged and ``error`` / ``message`` / ``status`` / ``timestamp``
        are set instead.

        :param obj: payload dict carrying the input image under 'frame'
        :return: the same (mutated) payload dict
        """
        out_size = (self.config.get('FRAME_WIDTH'), self.config.get('FRAME_HEIGHT'))
        try:
            frame = self.process_frame(cv2.resize(obj['frame'], self.INFER_SIZE))
            curr_time = time.time()
            logger.info("Time for 1 frame: {}".format(curr_time - self.pre_time))
            self.pre_time = curr_time
            obj['frame'] = cv2.resize(frame, out_size)
            obj["timestamp"] = datetime.datetime.now().replace(microsecond=0).isoformat()
        except Exception as e:
            # logger.exception already records the traceback.
            logger.exception(f"Error: {e}")
            try:
                obj['frame'] = cv2.resize(obj['frame'], out_size)
            except Exception:
                # The frame itself may be what failed (missing / invalid); do not let the
                # error handler raise and skip populating the error fields below.
                logger.exception("Could not resize frame while handling a prediction error")
            obj["error"] = "{}".format(e)
            obj["message"] = "{}".format("error processing frame")
            obj["status"] = False
            obj["timestamp"] = datetime.datetime.now().replace(microsecond=0).isoformat()
        return obj

    def _is_counted_can(self, det):
        """Return True if detection ``det`` passes the confidence, ROI and class filters."""
        if int(det['conf'] * 100) <= self.CONF_THRESHOLD_PCT:
            return False
        label = int(det['class'].item())
        return int(det['points'][1]) > self.ROI_MIN_Y and label == self.CAN_LABEL

    def process_frame(self, frame):
        """
        Detect cans in ``frame``, draw them, and update the packed-cans counter.

        :param frame: image at INFER_SIZE; boxes are drawn on it in place
        :return: the same, annotated, frame
        """
        pred = [det for det in self.yp.predict(frame) if self._is_counted_can(det)]

        if pred:
            # NOTE(review): one history entry is appended per *detection*, not per frame,
            # so MIN_BATCH_ENTRIES effectively counts detections. Kept as-is because the
            # threshold was presumably tuned against this behaviour — confirm intent.
            for det in pred:
                self.all_pred.append(pred)
                self.length_list.append(len(pred))
                pts = det['points']
                cv2.rectangle(frame, (int(pts[0]), int(pts[1])), (int(pts[2]), int(pts[3])),
                              self.BOX_COLOR, self.BOX_THICKNESS)
            # Track the batch maximum incrementally instead of rescanning all_pred each frame.
            self.max_val = max(self.max_val, len(pred))
        elif self.all_pred:
            # An empty frame inside an open batch records a zero so the idle window can fill up.
            self.all_pred.append(pred)
            self.length_list.append(0)

        # Only the trailing IDLE_WINDOW entries are ever read; drop older ones to bound memory.
        del self.length_list[:-self.IDLE_WINDOW]
        logger.debug("max value--> {}".format(self.max_val))

        idle = sum(self.length_list[-self.IDLE_WINDOW:]) == 0
        if len(self.all_pred) > self.MIN_BATCH_ENTRIES and idle:
            # Batch closed: the fullest frame of the batch is taken as the number of cans packed.
            self.count = self.count + self.max_val
            self.all_pred = []
            self.max_val = 0
            self.send_payload("Cans Packed: " + str(self.count), frame, "Cans Packed", "#12e6cd", "#000d0b",
                              "sound_1")

        return frame

