from edge_engine.ai.model.modelwraper import ModelWrapper
from scripts.utils.infocenter import MongoLogger
from edge_engine.common.logsetup import logger
from scripts.utils.notify_infocenter import NotificationFilter
import cv2
import base64
import numpy as np
from imutils.video import FPS
import datetime
import traceback
# from scripts.utils.yolo_params import YoloParams
from yolov5processor.infer import ExecuteInference
import os
from math import exp as exp
import time

try:
    import urlparse
except ImportError:
    import urllib.parse as urlparse


class fruits_model(ModelWrapper):
    """
    YOLOv5-based fruit detector for the edge engine.

    ``_predict`` resizes the incoming frame, runs the detector on it, draws a
    labelled bounding box for every confident detection and resizes the
    result to the configured output resolution. ``_post_process`` publishes
    the payload through ``pubs.rtp_write``.
    """

    # (width, height) that frames are resized to before inference.
    INFERENCE_SIZE = (640, 400)
    # A detection is drawn only when its confidence, truncated to an integer
    # percentage, is strictly greater than this value.
    CONF_THRESHOLD_PCT = 60
    # Side length in pixels of the square thumbnail attached to infocenter payloads.
    THUMBNAIL_SIZE = 64
    # Model class index -> (caption text, BGR colour) for the fruits_300 weights.
    # Indices not listed here are not drawn.
    # NOTE(review): "Gaya" may be a typo for the "Gala" apple variety - confirm.
    LABELS = {
        0: ("Apple: Red Delicious", (0, 0, 255)),
        1: ("Apple: Gaya", (255, 0, 255)),
        2: ("Apple: Granny Smith", (0, 255, 0)),
        3: ("Orange", (51, 87, 255)),
        4: ("Mango", (0, 195, 255)),
        5: ("Kiwi Actinidia Deliciosa", (0, 0, 128)),
        6: ("Kiwi Hardy", (0, 128, 0)),
    }

    def __init__(self, config, model_config, pubs, device_id):
        """
        Set up publishers and infocenter helpers, then load the YOLOv5 weights.

        :param config: mapping whose "config" entry holds runtime settings
            (FRAME_WIDTH / FRAME_HEIGHT are read by ``_predict``)
        :param model_config: mapping whose 'notification_ttl_value_sec' entry
            is passed to NotificationFilter as ``ttl_value``
        :param pubs: publisher container; its ``rtp_write`` attribute is used
            to publish processed payloads
        :param device_id: identifier included in infocenter payloads
        """
        super().__init__()
        self.config = config["config"]
        self.device_id = device_id
        self.rtp = pubs.rtp_write
        logger.info("Loading model to the device")

        self.mongologger = MongoLogger()
        self.notify_bool = NotificationFilter(ttl_value=model_config.get('notification_ttl_value_sec'))

        self.base_model_path = 'scripts/model/'
        logger.info("loading yolov5-Fruits Detection Model")

        self.model_detector_pth = os.path.join(self.base_model_path, "fruits_300.pt")
        self.yp = ExecuteInference(weight=self.model_detector_pth)

    def _pre_process(self, x):
        """
        Do preprocessing here, if any
        :param x: payload
        :return: payload, unchanged
        """
        return x

    def _post_process(self, x):
        """
        Publish the payload on the RTP writer (video stream) and return it.
        :param x: payload
        :return: payload, unchanged
        """
        self.rtp.publish(x)  # video stream
        return x

    def send_payload(self, message, frame, label, bg_color, font_color, alert_sound):
        """
        Send a notification event with a JPEG thumbnail of ``frame`` to the
        infocenter through MongoLogger.

        :param message: message to be shown
        :param frame: image to attach; it is downscaled to a square thumbnail
        :param label: activity label, stored under the "activity" key
        :param bg_color: value stored under the "bg_color" key
        :param font_color: value stored under the "font_color" key
        :param alert_sound: value stored under the "alert_sound" key
        :return: None
        """
        thumbnail = cv2.resize(frame, (self.THUMBNAIL_SIZE, self.THUMBNAIL_SIZE))
        ok, jpeg = cv2.imencode('.jpg', thumbnail)
        if not ok:
            logger.error("Failed to JPEG-encode notification frame; payload not sent")
            return
        payload = {
            "deviceId": self.device_id,
            "message": message,
            # ndarray.tobytes() is the supported replacement for the
            # deprecated ndarray.tostring().
            "frame": 'data:image/jpeg;base64,' + base64.b64encode(jpeg.tobytes()).decode("utf-8"),
            "activity": label,
            "bg_color": bg_color,
            "font_color": font_color,
            "alert_sound": alert_sound,
        }
        self.mongologger.insert_attendance_event_to_mongo(payload)

    def _predict(self, obj):
        """
        Detect fruits in ``obj['frame']`` and replace it with the annotated
        frame, resized to (FRAME_WIDTH, FRAME_HEIGHT) from the config.

        On failure the exception is logged. The frame is still resized to the
        output resolution on a best-effort basis. ``obj`` is then tagged with
        the "error", "message", "status" (False) and "timestamp" keys.

        :param obj: payload mapping that contains a 'frame' image
        :return: the same ``obj``, updated in place
        """
        out_size = (self.config.get('FRAME_WIDTH'), self.config.get('FRAME_HEIGHT'))
        try:
            resized = cv2.resize(obj['frame'], self.INFERENCE_SIZE)
            annotated = self.process_frame(resized)
            obj['frame'] = cv2.resize(annotated, out_size)
        except Exception as e:
            # logger.exception records the full traceback.
            logger.exception(f"Error: {e}")
            try:
                obj['frame'] = cv2.resize(obj['frame'], out_size)
            except Exception:
                # Do not let a second failure escape the handler; callers
                # must still receive the error-tagged payload below.
                logger.exception("Could not resize frame after processing error")
            obj["error"] = "{}".format(e)
            obj["message"] = "{}".format("error processing frame")
            obj["status"] = False
            obj["timestamp"] = datetime.datetime.now().replace(microsecond=0).isoformat()
        return obj

    def process_frame(self, frame):
        """
        Run the detector on ``frame`` and draw a box and caption for every
        detection above the confidence threshold whose class is in LABELS.

        :param frame: image to annotate; it is modified in place
        :return: the same, annotated, frame
        """
        for det in self.yp.predict(frame):
            # float() accepts Python numbers, NumPy scalars and one-element tensors.
            score = float(det['conf']) * 100
            if int(score) <= self.CONF_THRESHOLD_PCT:
                continue
            entry = self.LABELS.get(int(det['class'].item()))
            if entry is None:
                continue
            name, color = entry
            pts = det['points']
            box = (int(pts[0]), int(pts[1]), int(pts[2]), int(pts[3]))
            self._draw_detection(frame, box, f"{name} {score:.2f}", color)
        return frame

    @staticmethod
    def _draw_detection(frame, box, text, color):
        """Draw one bounding box and its caption onto ``frame`` in place."""
        x1, y1, x2, y2 = box
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        # Put the caption 10 px above the box. Clamp it so it stays inside the
        # image when the box touches the top edge.
        text_org = (x1, max(y1 - 10, 15))
        # LINE_AA is a line type, not a font face, so pass it as lineType.
        cv2.putText(frame, text, text_org, cv2.FONT_HERSHEY_SIMPLEX, 0.5, color,
                    thickness=2, lineType=cv2.LINE_AA)

