import ctypes
import time
import cv2
import numpy as np
import pycuda.autoinit
import pycuda.driver as cuda
import tensorrt as trt

# Path (relative to the current working directory) of the shared library
# that is loaded when this module is imported.
PLUGIN_LIBRARY = "data/libmyplugins.so"

# Load the shared library for its side effects only; the returned handle is
# discarded. NOTE(review): presumably this registers the custom TensorRT
# plugin layers needed before the engine is deserialized - confirm.
ctypes.CDLL(PLUGIN_LIBRARY)


class YoloV5TRT(object):
    """
    description: A YOLOv5 class that wraps TensorRT ops, preprocess and postprocess ops.
    """

    def __init__(self, engine_file_path, conf_thresh=0.5, iou_thresh=0.4):
        """
        description: Create a CUDA context, deserialize the TensorRT engine and
                     allocate one host/device buffer pair per engine binding.
        param:
           engine_file_path: path of the serialized TensorRT engine file
           conf_thresh:      confidence threshold applied in post-processing
           iou_thresh:       IoU threshold applied during NMS
        raises:
           RuntimeError: if deserialization does not yield an engine
        """
        self.CONF_THRESH = conf_thresh
        self.IOU_THRESHOLD = iou_thresh
        # make_context() leaves the new context pushed; destroy() pops it.
        self.ctx = cuda.Device(0).make_context()
        try:
            self._build_engine_and_buffers(engine_file_path)
        except Exception:
            # Do not leave the context on the stack when construction fails,
            # because the caller never gets an object to call destroy() on.
            self.ctx.pop()
            raise

    def _build_engine_and_buffers(self, engine_file_path):
        """
        description: Deserialize the engine and allocate page-locked host memory
                     plus device memory for every binding.
        param:
           engine_file_path: path of the serialized TensorRT engine file
        """
        stream = cuda.Stream()
        runtime = trt.Runtime(trt.Logger(trt.Logger.INFO))

        with open(engine_file_path, "rb") as f:
            engine = runtime.deserialize_cuda_engine(f.read())
        if engine is None:
            raise RuntimeError(
                "Failed to deserialize TensorRT engine: {}".format(engine_file_path)
            )
        context = engine.create_execution_context()

        host_inputs = []
        cuda_inputs = []
        host_outputs = []
        cuda_outputs = []
        bindings = []

        for binding in engine:
            shape = engine.get_binding_shape(binding)
            size = trt.volume(shape) * engine.max_batch_size
            dtype = trt.nptype(engine.get_binding_dtype(binding))

            host_mem = cuda.pagelocked_empty(size, dtype)
            cuda_mem = cuda.mem_alloc(host_mem.nbytes)

            # The execution context receives the device addresses as plain ints.
            bindings.append(int(cuda_mem))

            if engine.binding_is_input(binding):
                # The last two dimensions of the input binding are (height, width).
                self.input_w = shape[-1]
                self.input_h = shape[-2]
                host_inputs.append(host_mem)
                cuda_inputs.append(cuda_mem)
            else:
                host_outputs.append(host_mem)
                cuda_outputs.append(cuda_mem)

        self.stream = stream
        self.context = context
        self.engine = engine
        self.host_inputs = host_inputs
        self.cuda_inputs = cuda_inputs
        self.host_outputs = host_outputs
        self.cuda_outputs = cuda_outputs
        self.bindings = bindings
        self.batch_size = engine.max_batch_size

    def infer(self, frame):
        """
        description: Run detection on a single BGR frame.
        param:
           frame: BGR image as a numpy array of shape (h, w, c)
        return:
           result_boxes:   boxes, each row is [x1, y1, x2, y2] in frame pixels
           result_scores:  confidence of each box
           result_classid: class id of each box
           When batch_size > 1 only the results of the last batch slot are
           returned (each iteration overwrites the previous one).
        """
        # Pure CPU work: done before the CUDA context is pushed.
        input_image, _, origin_h, origin_w = self.preprocess_image(frame)

        self.ctx.push()
        try:
            np.copyto(self.host_inputs[0], input_image.ravel())
            cuda.memcpy_htod_async(self.cuda_inputs[0], self.host_inputs[0], self.stream)
            self.context.execute_async(
                batch_size=self.batch_size,
                bindings=self.bindings,
                stream_handle=self.stream.handle,
            )
            cuda.memcpy_dtoh_async(self.host_outputs[0], self.cuda_outputs[0], self.stream)
            self.stream.synchronize()
        finally:
            # Always balance the push above, even when a CUDA/TensorRT call fails.
            self.ctx.pop()

        output = self.host_outputs[0]
        # The host buffer was allocated as volume(binding shape) * max_batch_size,
        # so this is the per-image length (6001 for the stock output layout).
        per_image = output.size // self.batch_size

        result_boxes, result_scores, result_classid = [], [], []
        for i in range(self.batch_size):
            result_boxes, result_scores, result_classid = self.post_process(
                output[i * per_image: (i + 1) * per_image], origin_h, origin_w)
        return result_boxes, result_scores, result_classid

    def destroy(self):
        """
        description: Pop the CUDA context created in __init__ from the context stack.
        """
        self.ctx.pop()

    def preprocess_image(self, raw_bgr_image):
        """
        description: Convert BGR image to RGB,
                  resize and pad it to target size, normalize to [0,1],
                  transform to NCHW format.
        param:
           raw_bgr_image: BGR image as a numpy array of shape (h, w, c)
        return:
           image:  the processed image, float32, shape (1, c, input_h, input_w)
           image_raw: the original image
           h: original height
           w: original width
        """
        image_raw = raw_bgr_image
        h, w, c = image_raw.shape
        image = cv2.cvtColor(image_raw, cv2.COLOR_BGR2RGB)

        # Scale by the smaller ratio so the whole image fits, then pad the
        # remaining side symmetrically.
        r_w = self.input_w / w
        r_h = self.input_h / h
        if r_h > r_w:
            tw = self.input_w
            th = int(r_w * h)
            tx1 = tx2 = 0
            ty1 = int((self.input_h - th) / 2)
            ty2 = self.input_h - th - ty1
        else:
            tw = int(r_h * w)
            th = self.input_h
            tx1 = int((self.input_w - tw) / 2)
            tx2 = self.input_w - tw - tx1
            ty1 = ty2 = 0

        image = cv2.resize(image, (tw, th))

        # The border colour must be passed as `value`; the 7th positional
        # parameter of cv2.copyMakeBorder is `dst`, not the colour.
        image = cv2.copyMakeBorder(
            image, ty1, ty2, tx1, tx2, cv2.BORDER_CONSTANT, value=(128, 128, 128)
        )
        image = image.astype(np.float32)

        image /= 255.0

        # HWC -> CHW, then add the batch dimension.
        image = np.transpose(image, [2, 0, 1])
        image = np.expand_dims(image, axis=0)

        image = np.ascontiguousarray(image)
        return image, image_raw, h, w

    def xywh2xyxy(self, origin_h, origin_w, x):
        """
        description:    Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right,
                        and map them from network-input coordinates back to original-image coordinates
        param:
           origin_h:   height of original image
           origin_w:   width of original image
           x:          A boxes numpy, each row is a box [center_x, center_y, w, h]
        return:
           y:          A boxes numpy, each row is a box [x1, y1, x2, y2]
        """
        y = np.zeros_like(x)
        r_w = self.input_w / origin_w
        r_h = self.input_h / origin_h
        if r_h > r_w:
            # Image was padded top/bottom: remove the vertical offset.
            pad = (self.input_h - r_w * origin_h) / 2
            y[:, 0] = x[:, 0] - x[:, 2] / 2
            y[:, 2] = x[:, 0] + x[:, 2] / 2
            y[:, 1] = x[:, 1] - x[:, 3] / 2 - pad
            y[:, 3] = x[:, 1] + x[:, 3] / 2 - pad
            y /= r_w
        else:
            # Image was padded left/right: remove the horizontal offset.
            pad = (self.input_w - r_h * origin_w) / 2
            y[:, 0] = x[:, 0] - x[:, 2] / 2 - pad
            y[:, 2] = x[:, 0] + x[:, 2] / 2 - pad
            y[:, 1] = x[:, 1] - x[:, 3] / 2
            y[:, 3] = x[:, 1] + x[:, 3] / 2
            y /= r_h

        return y

    def post_process(self, output, origin_h, origin_w):
        """
        description: postprocess the prediction
        param:
           output:     A numpy likes [num_boxes,cx,cy,w,h,conf,cls_id, cx,cy,w,h,conf,cls_id, ...]
           origin_h:   height of original image
           origin_w:   width of original image
        return:
           result_boxes: final boxes, a boxes numpy, each row is a box [x1, y1, x2, y2]
           result_scores: final scores, a numpy, each element is the score corresponding to box
           result_classid: final classid, a numpy, each element is the classid corresponding to box
           All three are empty arrays when nothing is detected.
        """
        # The first element is the number of valid detections; the rest are 6-value records.
        num = int(output[0])
        pred = np.reshape(output[1:], (-1, 6))[:num, :]

        boxes = self.non_max_suppression(pred, origin_h, origin_w, conf_thres=self.CONF_THRESH,
                                         nms_thres=self.IOU_THRESHOLD)
        if not len(boxes):
            return np.array([]), np.array([]), np.array([])

        result_boxes = boxes[:, :4].astype(int)
        result_scores = boxes[:, 4]
        result_classid = boxes[:, 5].astype(int)
        return result_boxes, result_scores, result_classid

    def bbox_iou(self, box1, box2, x1y1x2y2=True):
        """
        description: compute the IoU of two bounding boxes
        param:
           box1: A boxes numpy of shape (n, 4), rows are (x1, y1, x2, y2) or (x, y, w, h)
           box2: A boxes numpy of shape (m, 4), rows are (x1, y1, x2, y2) or (x, y, w, h)
           x1y1x2y2: select the coordinate format
        return:
           iou: computed iou (box extents are treated as inclusive pixels, hence the +1)
        """
        if not x1y1x2y2:
            # Center/size format: derive the corners first.
            b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
            b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
            b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
            b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
        else:
            b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
            b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]

        inter_rect_x1 = np.maximum(b1_x1, b2_x1)
        inter_rect_y1 = np.maximum(b1_y1, b2_y1)
        inter_rect_x2 = np.minimum(b1_x2, b2_x2)
        inter_rect_y2 = np.minimum(b1_y2, b2_y2)

        # Clipping at 0 makes non-overlapping boxes contribute no intersection.
        inter_w = np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, None)
        inter_h = np.clip(inter_rect_y2 - inter_rect_y1 + 1, 0, None)
        inter_area = inter_w * inter_h

        b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
        b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)

        # The small epsilon guards against division by zero.
        iou = inter_area / (b1_area + b2_area - inter_area + 1e-16)

        return iou

    def non_max_suppression(self, prediction, origin_h, origin_w, conf_thres=0.5, nms_thres=0.4):
        """
        description: Removes detections with lower object confidence score than 'conf_thres' and performs
        Non-Maximum Suppression to further filter detections.
        param:
           prediction: detections, each row is (cx, cy, w, h, conf, cls_id) in network-input coordinates
           origin_h: original image height
           origin_w: original image width
           conf_thres: a confidence threshold to filter detections
           nms_thres: a iou threshold to filter detections
        return:
           boxes: output after nms, each row is (x1, y1, x2, y2, conf, cls_id) in original-image
                  coordinates; an empty 1-D array when nothing survives
        """
        # Boolean indexing copies, so the caller's array is not modified below.
        boxes = prediction[prediction[:, 4] >= conf_thres]

        boxes[:, :4] = self.xywh2xyxy(origin_h, origin_w, boxes[:, :4])

        # Keep every corner inside the original image.
        boxes[:, 0] = np.clip(boxes[:, 0], 0, origin_w - 1)
        boxes[:, 2] = np.clip(boxes[:, 2], 0, origin_w - 1)
        boxes[:, 1] = np.clip(boxes[:, 1], 0, origin_h - 1)
        boxes[:, 3] = np.clip(boxes[:, 3], 0, origin_h - 1)

        # Highest confidence first.
        boxes = boxes[np.argsort(-boxes[:, 4])]

        keep_boxes = []
        while boxes.shape[0]:
            # Drop every remaining box of the same class that overlaps the
            # current best box too much (the best box itself is dropped too,
            # after being recorded).
            large_overlap = self.bbox_iou(np.expand_dims(boxes[0, :4], 0), boxes[:, :4]) > nms_thres
            label_match = boxes[0, -1] == boxes[:, -1]

            invalid = large_overlap & label_match
            keep_boxes.append(boxes[0])
            boxes = boxes[~invalid]

        return np.stack(keep_boxes, 0) if keep_boxes else np.array([])


class inferThread():
    """
    description: Feeds the frames of a video file through a detector wrapper
                 and prints the detections of every frame.
    """

    def __init__(self, yolov5_wrapper):
        """
        param:
           yolov5_wrapper: object exposing infer(frame) that returns
                           (result_boxes, result_scores, result_classid)
        """
        self.yolov5_wrapper = yolov5_wrapper

    def run(self, video_path="test2.mp4"):
        """
        description: Read the video frame by frame until reading fails and
                     print one list of detection dicts per frame.
        param:
           video_path: path of the video to process (defaults to the
                       previously hard-coded "test2.mp4")
        """
        cap = cv2.VideoCapture(video_path)
        try:
            ret, frame = cap.read()
            while ret:
                result_boxes, result_scores, result_classid = self.yolov5_wrapper.infer(frame)
                out = [
                    {"points": list(points), "conf": conf, "class": class_id}
                    for points, conf, class_id in zip(result_boxes, result_scores, result_classid)
                ]
                print(out)
                ret, frame = cap.read()
        finally:
            # Release the capture handle even if inference raises.
            cap.release()


if __name__ == "__main__":

    engine_file_path = "build/yolov5.engine"

    # Class names by class id; not used by the printing loop, which reports
    # the numeric class id only.
    categories = ["cement_bag"]

    # The constructor takes both thresholds; calling it with the engine path
    # alone raised a TypeError before the first frame was read.
    yolov5_wrapper = YoloV5TRT(engine_file_path, conf_thresh=0.5, iou_thresh=0.4)
    try:
        inf = inferThread(yolov5_wrapper)
        inf.run()
    finally:
        # Pop the CUDA context whether or not inference succeeded.
        yolov5_wrapper.destroy()
