import warnings

import cv2
import numpy as np
from openvino.inference_engine import IECore


class OpenVinoDetector:
    """Object detector backed by an OpenVINO Inference Engine network.

    The network is read from a model/weights pair, its input/output
    precisions are adjusted by :meth:`update_net`, and it is compiled for
    ``device_name``. :meth:`person_detector` runs one frame through it and
    returns the boxes that pass a confidence threshold and survive
    non-maximum suppression.
    """

    def __init__(self, confidence, model_detector_pth, model_bin, device_name="CPU"):
        """Read, configure and compile the detection network.

        Args:
            confidence: minimum detection score (exclusive) for a box to be kept.
            model_detector_pth: model description file, passed to
                ``IECore.read_network`` as ``model``.
            model_bin: weights file, passed to ``IECore.read_network`` as
                ``weights``.
            device_name: inference device passed to ``IECore.load_network``;
                defaults to ``"CPU"``.
        """
        self.confidence = confidence
        # Not read anywhere in this class.
        # NOTE(review): presumably consumed by external code; confirm or remove.
        self.sink_layer = {'0': 'conv2d_58/BiasAdd/Add', '1': 'conv2d_66/BiasAdd/Add', '2': 'conv2d_74/BiasAdd/Add'}
        self.model_bin = model_bin
        self.model_detector_pth = model_detector_pth
        self.ie = IECore()
        self.net1 = self.ie.read_network(model=self.model_detector_pth, weights=self.model_bin)
        self.input_blob = next(iter(self.net1.inputs))
        # The first input is unpacked as a 4-D (N, C, H, W) image input.
        self.n, self.c, self.h, self.w = self.net1.inputs[self.input_blob].shape
        self.out_blob = next(iter(self.net1.outputs))
        # Adjust precisions before the network is compiled by load_network.
        self.update_net(self.net1)
        self.exec_net = self.ie.load_network(network=self.net1, device_name=device_name)

    @staticmethod
    def update_net(net1):
        """Set input/output precisions on ``net1`` in place.

        Every 2-D input is switched to FP32 precision, and a warning is
        emitted when its shape is not ``[1, 3]`` or ``[1, 6]``. Inputs of
        any other rank are left untouched. The first output is switched to
        FP32 precision.

        Args:
            net1: network object returned by ``IECore.read_network``.
        """
        for input_key in net1.inputs:
            input_info = net1.inputs[input_key]
            if len(input_info.layout) != 2:
                continue
            input_info.precision = 'FP32'
            shape = input_info.shape
            if shape[0] != 1 or shape[1] not in (3, 6):
                warnings.warn(
                    f"Invalid input info '{input_key}'. Should be 3 or 6 values length.",
                    stacklevel=2,
                )

        first_output = next(iter(net1.outputs))
        net1.outputs[first_output].precision = "FP32"

    def person_detector(self, frame, overlap_thresh=0.4):
        """Detect objects in ``frame`` and return NMS-filtered boxes.

        The frame is resized to the network input size when needed, converted
        from HWC to CHW and copied into every batch slot. Each output row is
        read as ``[image_id, label, score, xmin, ymin, xmax, ymax]`` with
        coordinates relative to the image size. Rows scoring above both 0
        and ``self.confidence`` are scaled to pixels of ``frame``. Parsing
        stops at the first row whose ``image_id`` is negative.

        Args:
            frame: image array of shape (H, W, C). Assumes C and the channel
                order match what the network expects (TODO confirm).
            overlap_thresh: threshold passed to
                :meth:`non_max_suppression_fast`; defaults to 0.4.

        Returns:
            ``(boxes, frame_copy)``. ``boxes`` is a NumPy array whose rows are
            integer ``[ymin, xmin, ymax, xmax]`` in frame pixels; it is empty
            when nothing passes the threshold. ``frame_copy`` is a copy of
            ``frame``.
        """
        vino_frame = frame.copy()
        ih, iw = vino_frame.shape[:-1]

        image = vino_frame
        if (ih, iw) != (self.h, self.w):
            # cv2.resize takes the target size as (width, height).
            image = cv2.resize(image, (self.w, self.h))
        # Every batch slot receives the same frame, HWC -> CHW.
        images = np.empty((self.n, self.c, self.h, self.w))
        images[:] = image.transpose((2, 0, 1))

        res = self.exec_net.infer(inputs={self.input_blob: images})
        detections = res[self.out_blob][0][0]

        boxes = []
        for proposal in detections:
            # Builtin int replaces the np.int alias removed in NumPy 1.24.
            image_id = int(proposal[0])
            if image_id < 0:
                # A negative image_id marks the end of the valid detections;
                # any remaining rows are padding.
                break
            score = proposal[2]
            if score > 0 and score > self.confidence:
                xmin = int(iw * proposal[3])
                ymin = int(ih * proposal[4])
                xmax = int(iw * proposal[5])
                ymax = int(ih * proposal[6])
                # Boxes are stored y-first: [ymin, xmin, ymax, xmax].
                boxes.append([ymin, xmin, ymax, xmax])

        kept = self.non_max_suppression_fast(np.asarray(boxes, dtype=np.float32), overlap_thresh)
        return np.array(kept), vino_frame

    @staticmethod
    def non_max_suppression_fast(boxes, overlap_thresh):
        """Greedy non-maximum suppression over axis-aligned boxes.

        Boxes are picked in descending order of their 4th column. Each picked
        box suppresses every remaining box whose intersection with it,
        divided by that remaining box's own area, exceeds ``overlap_thresh``.
        This ratio is not IoU. Widths and heights use the inclusive ``+ 1``
        pixel convention. Scores are not considered.

        Args:
            boxes: array of shape (N, 4) with rows ``[a1, b1, a2, b2]`` giving
                opposite corners. The overlap metric is symmetric in the two
                axes, so y-first rows work too.
            overlap_thresh: suppression threshold on the overlap ratio.

        Returns:
            ``[]`` when ``boxes`` is empty. Otherwise an int array of the kept
            rows, in pick order.
        """
        if len(boxes) == 0:
            return []
        if boxes.dtype.kind == "i":
            boxes = boxes.astype("float")
        pick = []
        x1 = boxes[:, 0]
        y1 = boxes[:, 1]
        x2 = boxes[:, 2]
        y2 = boxes[:, 3]
        area = (x2 - x1 + 1) * (y2 - y1 + 1)
        # Ascending order; the candidate with the largest 4th column is last.
        idxs = np.argsort(y2)
        while len(idxs) > 0:
            last = len(idxs) - 1
            i = idxs[last]
            pick.append(i)
            # Intersection of the picked box with every remaining candidate.
            xx1 = np.maximum(x1[i], x1[idxs[:last]])
            yy1 = np.maximum(y1[i], y1[idxs[:last]])
            xx2 = np.minimum(x2[i], x2[idxs[:last]])
            yy2 = np.minimum(y2[i], y2[idxs[:last]])
            w = np.maximum(0, xx2 - xx1 + 1)
            h = np.maximum(0, yy2 - yy1 + 1)
            overlap = (w * h) / area[idxs[:last]]
            # Drop the picked box and every candidate it overlaps too much.
            idxs = np.delete(idxs, np.concatenate(([last], np.where(overlap > overlap_thresh)[0])))
        return boxes[pick].astype("int")
