import cv2
from edge_engine.common.logsetup import logger
from edge_engine.ai.model.modelwraper import ModelWrapper
import json
import os
import time
from scripts.utils.openvino_utils import OpenVinoDetector
from scripts.utils.infocenter import InfoCenter
try:
    import urlparse
except ImportError:
    import urllib.parse as urlparse


class Welspun_Stitch_Detection(ModelWrapper):
    """Stitch-detection model wrapper that scores frames and raises alerts.

    For live sources each frame is scored by an OpenVINO model loaded from
    ``assests/resnet34.xml`` / ``.bin``; for ``videofile`` sources the score is
    looked up in the pre-computed detections stored in ``assests/welspun.json``.
    A confident detection is annotated on the frame and sent to InfoCenter;
    every frame is resized to ``FRAME_WIDTH`` x ``FRAME_HEIGHT`` and handed to
    ``pubs.rtp_write`` in :meth:`_post_process`.
    """

    #: Minimum probability for a frame to count as a stitch detection.
    PROB_THRESHOLD = 0.90
    #: Alerts are only raised on frames whose id is a multiple of this value,
    #: which limits how often InfoCenter is notified.
    ALERT_EVERY_N_FRAMES = 40

    def __init__(self, config, pubs, device_id):
        """
        :param config: model configuration; reads ``config["config"]`` (frame
            size settings) and ``config["inputConf"]["sourceType"]``
        :param pubs: publisher container; ``pubs.rtp_write`` receives output payloads
        :param device_id: device identifier forwarded to ``InfoCenter``
        """
        super().__init__()
        self.config = config["config"]
        # Source type of the input; 'videofile' switches scoring to the JSON lookup.
        self.type = config['inputConf']['sourceType']
        # Pre-computed per-frame detections used when replaying a video file.
        # NOTE(review): loaded for every source type, so the file must exist even
        # for live streams.
        with open('assests/welspun.json', 'r', encoding='utf-8') as f:
            self.dets = json.load(f)
        self.rtp = pubs.rtp_write
        self.ic = InfoCenter(device_id=device_id)
        logger.info("Loading model to the device")
        self.base_model_path = 'assests/'
        logger.info("loading Stitch Detection Model")
        # open-vino: the weights file sits next to the .xml with a .bin extension.
        self.model_detector_pth = os.path.join(self.base_model_path, "resnet34.xml")
        self.model_bin = os.path.splitext(self.model_detector_pth)[0] + ".bin"
        self.ov = OpenVinoDetector(model_detector_pth=self.model_detector_pth, model_bin=self.model_bin)
        # NOTE(review): not read anywhere in this class; kept for external callers.
        self.skip_timer = None

    def _pre_process(self, x):
        """
        Do preprocessing here, if any (currently a pass-through).
        :param x: payload
        :return: payload, unchanged
        """
        return x

    def _post_process(self, x):
        """
        Publish the payload to the RTP video stream.
        :param x: payload (its ``'frame'`` has already been resized by :meth:`_predict`)
        :return: payload, unchanged
        """
        self.rtp.publish(x)  # video stream
        return x

    def _predict(self, x):
        """Score one frame, alert on a confident detection, and resize it for streaming.

        :param x: payload dict; reads ``x['frame']`` (image array) and ``x['frameId']``
        :return: the same payload with ``x['frame']`` replaced by the (possibly
            annotated) frame resized to the configured stream size. Any error is
            logged and the original frame is resized and returned instead.
        """
        try:
            # NOTE(review): fixed per-frame delay; presumably throttles the pipeline.
            time.sleep(0.05)
            frame = x['frame']
            # Normalize once: the lookup path and the modulo below both need an int.
            frame_id = int(x['frameId'])
            score = self._to_score(self._type(self.type, frame, frame_id))
            if score > self.PROB_THRESHOLD and frame_id % self.ALERT_EVERY_N_FRAMES == 0:
                frame = self._annotate(frame, score)
                # NOTE(review): "Stich" typo kept; downstream consumers may match on it.
                self.ic.send_payload(frame=frame,
                                     message="Stich Detected",
                                     alert_sound="sound_1")
            x['frame'] = self._resize_for_stream(frame)
            return x
        except Exception as e:
            logger.error(f"Error: {e}", exc_info=True)
            if x.get('frame') is not None:
                x['frame'] = self._resize_for_stream(x['frame'])
            return x

    def _type(self, video_type, frame, frame_id):
        """Return the raw stitch score for a frame.

        :param video_type: input source type; ``'videofile'`` uses the JSON lookup
        :param frame: image passed to the OpenVINO detector for live sources
        :param frame_id: frame index used to look up pre-computed detections
        :return: whatever ``OpenVinoDetector.stitch_detector`` returns, or the
            ``'detections'`` entry of the JSON record; normalize with :meth:`_to_score`
        """
        if video_type != 'videofile':
            return self.ov.stitch_detector(frame)
        # assumes welspun.json is a list indexed by frame id whose entries map
        # str(frame_id) -> {'detections': ...} — TODO confirm against the file
        return self.dets[int(frame_id)][str(frame_id)]['detections']

    @staticmethod
    def _to_score(prob):
        """Reduce a detector/lookup result to a single float probability.

        Accepts a bare number, or a (possibly nested) list/tuple/array whose
        first element holds the probability.
        """
        while isinstance(prob, (list, tuple)) or getattr(prob, 'ndim', 0) > 0:
            prob = prob[0]
        return float(prob)

    @staticmethod
    def _annotate(frame, score):
        """Return a copy of ``frame`` with the detection probability drawn on it."""
        annotated = frame.copy()  # leave the caller's frame untouched
        text = "Stitch Detected with Probability :" + f"{score:.3f}"
        font_face, font_scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 5, 5
        (_, text_height), _ = cv2.getTextSize(text, font_face, font_scale, thickness)
        # putText's org is the bottom-left of the text baseline; shift it down by
        # the text height so the label is not clipped at the top of the frame.
        cv2.putText(annotated, text=text, org=(20, 20 + text_height),
                    fontFace=font_face, fontScale=font_scale,
                    color=(0, 0, 255), thickness=thickness,
                    lineType=cv2.LINE_AA)
        return annotated

    def _resize_for_stream(self, frame):
        """Resize ``frame`` to the configured ``FRAME_WIDTH`` x ``FRAME_HEIGHT``.

        Returns the frame unchanged when either dimension is not configured,
        rather than letting ``cv2.resize`` fail on ``None`` (which previously
        made the error handler in :meth:`_predict` raise as well).
        """
        width = self.config.get('FRAME_WIDTH')
        height = self.config.get('FRAME_HEIGHT')
        if not width or not height:
            return frame
        return cv2.resize(frame, (int(width), int(height)))
