import os
from pathlib import Path

import cv2
from uuid import uuid4
from multiprocessing.pool import ThreadPool as Pool

from albumentations import Compose, BboxParams
from albumentations.augmentations.geometric.rotate import Rotate
from albumentations.augmentations.geometric.transforms import Affine
from albumentations.augmentations.transforms import (
    HorizontalFlip, VerticalFlip, ToGray, HueSaturationValue, RandomBrightnessContrast, Blur, GaussNoise)


class AugmentImage:
    """Build an albumentations pipeline from a property config and apply it to an image.

    The config (``funs``) maps an augmentation name to a list of
    ``{'property': ..., 'value': ...}`` dicts. Transforms are appended in a fixed
    order (flips, rotation, colour, blur/noise, shears) regardless of key order in
    ``funs``. They are composed with YOLO-format bounding boxes whose labels travel
    in the ``category_ids`` field.
    """

    # Default outline colour for ``visualize_bbox``; (255, 0, 0) is blue in OpenCV's
    # BGR order and red only if the image is RGB.
    BOX_COLOR = (255, 0, 0)

    def __init__(self, funs):
        """Translate ``funs`` into albumentations transforms.

        Args:
            funs: mapping of augmentation name -> list (or tuple) of
                ``{'property': str, 'value': float}`` dicts. Recognised names are
                vertical_flip, horizontal_flip, rotation, grayscale, hue,
                saturation, brightness, exposure, blur, noise, horizontal_shear
                and vertical_shear. Entries whose ``'property'`` does not match
                are ignored. ``value`` scales each transform's strength (for
                grayscale it is used directly as the probability).

        Raises:
            TypeError: if a recognised key maps to something other than a list/tuple.
        """
        self.transformations = list()
        add = self.transformations.append

        for _ in self._matching_props(funs, 'vertical_flip', ('vertical_flip',)):
            add(VerticalFlip(p=1))

        for _ in self._matching_props(funs, 'horizontal_flip', ('horizontal_flip',)):
            add(HorizontalFlip(p=1))

        for prop in self._matching_props(funs, 'rotation', ('rotation',)):
            add(Rotate(
                limit=int(45 * prop['value']),
                interpolation=1,
                border_mode=0,
                value=None,
                mask_value=None,
                always_apply=False,
                p=0.7))

        # 'probability' is the historical property name. 'grayscale' is accepted too,
        # because every other augmentation uses property == augmentation key.
        for prop in self._matching_props(funs, 'grayscale', ('probability', 'grayscale')):
            add(ToGray(p=prop['value']))

        for prop in self._matching_props(funs, 'hue', ('hue',)):
            add(self._hsv_shift(hue=int(180 * prop['value'])))

        for prop in self._matching_props(funs, 'saturation', ('saturation',)):
            add(self._hsv_shift(sat=int(255 * prop['value'])))

        for prop in self._matching_props(funs, 'brightness', ('brightness',)):
            # NOTE(review): albumentations documents brightness_limit as a fractional
            # factor range. int(30 * value) yields whole numbers (0 for value < 1/30),
            # and always_apply=True makes p=0.7 irrelevant. Confirm the intended scale.
            add(RandomBrightnessContrast(
                brightness_limit=int(30 * prop['value']),
                contrast_limit=0,
                brightness_by_max=True,
                always_apply=True,
                p=0.7))

        for prop in self._matching_props(funs, 'exposure', ('exposure',)):
            add(self._hsv_shift(val=int(255 * prop['value'])))

        for prop in self._matching_props(funs, 'blur', ('blur',)):
            add(Blur(blur_limit=int(150 * prop['value']),
                     always_apply=False,
                     p=1))

        for prop in self._matching_props(funs, 'noise', ('noise',)):
            # NOTE(review): always_apply=True overrides p=0.7 here; confirm intent.
            add(GaussNoise(
                var_limit=int(50000 * prop['value']),
                mean=0,
                per_channel=True,
                always_apply=True,
                p=0.7))

        # Each shear entry adds a negative and a positive shear transform.
        for prop in self._matching_props(funs, 'horizontal_shear', ('horizontal_shear',)):
            add(self._shear(x=int(-45 * prop['value']), y=0))
            add(self._shear(x=int(45 * prop['value']), y=0))

        for prop in self._matching_props(funs, 'vertical_shear', ('vertical_shear',)):
            add(self._shear(x=0, y=int(-45 * prop['value'])))
            add(self._shear(x=0, y=int(45 * prop['value'])))

        self.transformer = Compose(self.transformations,
                                   bbox_params=BboxParams(
                                       format='yolo',
                                       label_fields=['category_ids']))

    @staticmethod
    def _matching_props(funs, key, property_names):
        """Yield the entries of ``funs[key]`` whose ``'property'`` is in ``property_names``."""
        if key not in funs:
            return
        props = funs[key]
        # The previous ``assert type(x), list`` was always true; validate for real.
        if not isinstance(props, (list, tuple)):
            raise TypeError(
                f"augmentation {key!r} expects a list of property dicts, "
                f"got {type(props).__name__}")
        for prop in props:
            if prop['property'] in property_names:
                yield prop

    @staticmethod
    def _hsv_shift(hue=0, sat=0, val=0):
        """Return a HueSaturationValue transform shifting only the given channel limits."""
        return HueSaturationValue(
            hue_shift_limit=hue,
            sat_shift_limit=sat,
            val_shift_limit=val,
            always_apply=False,
            p=0.7)

    @staticmethod
    def _shear(x, y):
        """Return an Affine transform that applies only the given x/y shear."""
        return Affine(
            scale=None,
            translate_percent=None,
            translate_px=None,
            rotate=None,
            shear={"x": x, "y": y},
            interpolation=0,
            mask_interpolation=0,
            cval=0,
            cval_mask=0,
            mode=0,
            fit_output=False,
            always_apply=False,
            p=0.9)

    def __call__(self, image, bounding_boxes, category_ids):
        """Apply the pipeline and return ``[image, bounding_boxes, category_ids]``.

        When no transform was configured, the inputs are returned unchanged in
        the same 3-item list shape, so callers can always unpack the result.
        The returned ids come from the transform output. Any boxes albumentations
        drops are therefore dropped together with their labels, keeping the two
        lists aligned.
        """
        if not self.transformations:
            return [image, bounding_boxes, category_ids]
        transformed = self.transformer(
            image=image,
            bboxes=bounding_boxes,
            category_ids=category_ids)
        return [transformed['image'], transformed['bboxes'], transformed['category_ids']]

    @staticmethod
    def conv_2_coco(img, *args):
        """Convert the FIRST normalized YOLO box of ``args[0]`` to absolute COCO form.

        Args:
            img: image array whose first two dimensions are (height, width).
            args: ``args[0]`` is a sequence of ``(x_center, y_center, w, h)`` boxes;
                only ``args[0][0]`` is converted.

        Returns:
            numpy array of shape (1, 4): ``[[x_min, y_min, width, height]]`` in pixels.
        """
        import numpy as np
        x_cen, y_cen, w, h = args[0][0]
        # shape[:2] also works for single-channel (2-D) images.
        ih, iw = img.shape[:2]
        x_min = (x_cen - w / 2) * iw
        y_min = (y_cen - h / 2) * ih
        return np.array([(x_min, y_min, w * iw, h * ih)])

    @staticmethod
    def visualize_bbox(img, bbox, class_name, color=BOX_COLOR, thickness=2):
        """Draw one ``(x_min, y_min, w, h)`` pixel box and its label onto ``img``.

        The outline uses ``color``. The label background is always the fixed
        colour (255, 0, 0) and the label text is white. Drawing is done with
        OpenCV directly on ``img``, which is also returned.
        """
        label_background = (255, 0, 0)
        text_color = (255, 255, 255)
        x_min, y_min, w, h = bbox
        x_min, x_max, y_min, y_max = int(x_min), int(x_min + w), int(y_min), int(y_min + h)

        cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=color, thickness=thickness)

        ((text_width, text_height), _) = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1)
        cv2.rectangle(img, (x_min, y_min - int(1.3 * text_height)), (x_min + text_width, y_min),
                      label_background, -1)
        cv2.putText(
            img,
            text=class_name,
            org=(x_min, y_min - int(0.3 * text_height)),
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=0.35,
            color=text_color,
            lineType=cv2.LINE_AA,
        )
        return img

    def visualize(self, image, bboxes, category_ids, category_id_to_name):
        """Return a copy of ``image`` with every box drawn and labelled by class name."""
        img = image.copy()
        for bbox, category_id in zip(bboxes, category_ids):
            class_name = category_id_to_name[category_id]
            img = self.visualize_bbox(img, bbox, class_name)
        return img


class ChainAugmentations:
    """Fan one input out to an independent AugmentImage per augmentation key.

    Every key of ``functions`` gets its own single-key pipeline. Calling the
    chain therefore yields one result per key, in key order, rather than one
    image with every augmentation stacked.
    """

    def __init__(self, functions):
        pipelines = []
        for key in functions:
            pipelines.append(AugmentImage({key: functions[key]}))
        self.transformers = pipelines

    def __call__(self, image, bounding_boxes, category_ids):
        """Run each pipeline on the same inputs and collect their results in a list."""
        results = []
        for pipeline in self.transformers:
            results.append(pipeline(image, bounding_boxes, category_ids))
        return results


class AugmentationManager:
    """Augment every labelled image in a directory using a thread pool.

    For each ``<name>.jpg/.jpeg/.png`` (case-insensitive) that has a sibling
    ``<name>.txt`` label file, every configured augmentation writes a new
    ``<name><uuid4>.jpg`` image and a matching ``.txt`` label file to the
    output directory.
    """

    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, functions=None, workers=12):
        """Create the augmentation chain and the worker pool.

        Args:
            functions: augmentation config passed to ``ChainAugmentations``.
                Defaults to blur (0.3) plus noise (0.1).
            workers: number of threads used by ``process``.
        """
        if functions is None:
            functions = {'blur': [{'property': 'blur', 'value': 0.3}], 'noise': [{'property': 'noise', 'value': 0.1}]}
        self.augment = ChainAugmentations(functions=functions)
        self.workers = workers
        self.pool = Pool(workers)

    def run_augmentations(self, annotation_directory, post_process_directory, filename, each_file):
        """Augment one image and write each variant with its label file.

        Args:
            annotation_directory: directory containing ``each_file``.
            post_process_directory: existing output directory.
            filename: image path without extension; ``filename + '.txt'`` is its label file.
            each_file: image file name relative to ``annotation_directory``.

        Raises:
            ValueError: if OpenCV cannot read the image.
        """
        category_ids, bounding_boxes = self._read_annotations(filename + ".txt")

        image_path = os.path.join(annotation_directory, each_file)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals an unreadable file by returning None, not by raising.
            raise ValueError(f"could not read image: {image_path}")

        stem = Path(each_file).stem
        for aug_image, aug_boxes, aug_ids in self.augment(image, bounding_boxes, category_ids):
            out_name = stem + str(uuid4())
            cv2.imwrite(os.path.join(post_process_directory, out_name + ".jpg"), aug_image)
            with open(os.path.join(post_process_directory, out_name + '.txt'), 'w') as out:
                out.write(self._format_annotations(aug_ids, aug_boxes))

    @staticmethod
    def _read_annotations(label_path):
        """Parse '<id> <v1> <v2> ...' lines into (category_ids, bounding_boxes); ids stay strings."""
        category_ids, bounding_boxes = [], []
        with open(label_path, 'r') as f:
            for line in f:
                fields = line.split()  # tolerates repeated spaces, tabs and CRLF
                if not fields:
                    continue
                category_ids.append(fields[0])
                bounding_boxes.append([float(v) for v in fields[1:]])
        return category_ids, bounding_boxes

    @staticmethod
    def _format_annotations(category_ids, bounding_boxes):
        """Render labels as newline-terminated '<id> <v1> <v2> ...' lines."""
        return ''.join(
            ' '.join(str(v) for v in (cid, *box)) + '\n'
            for cid, box in zip(category_ids, bounding_boxes))

    @staticmethod
    def _report_failure(each_file):
        """Return an apply_async error_callback that prints which image failed and why."""
        def callback(exc):
            print(f"augmentation failed for {each_file}: {exc!r}")
        return callback

    def process(self, annotation_directory, post_process_directory):
        """Augment all labelled images in ``annotation_directory`` into ``post_process_directory``.

        Blocks until every submitted task has finished. The output directory,
        including any missing parents, is created if needed. Per-image failures
        are printed rather than aborting the batch. May be called repeatedly:
        each call uses a fresh pool.
        """
        print("annotation dir")
        print(annotation_directory)
        print("post process")
        print(post_process_directory)
        assert os.path.exists(annotation_directory), f"{annotation_directory} does not exist"
        if not os.path.exists(post_process_directory):
            os.makedirs(post_process_directory, exist_ok=True)
            print(f"Path: {post_process_directory} does not exist, creating one now!")

        # A pool cannot be reused after close(), so take ownership of the current
        # one and let the next call build a fresh pool.
        pool = self.pool if self.pool is not None else Pool(self.workers)
        self.pool = None
        try:
            for each_file in os.listdir(annotation_directory):
                filename, file_extension = os.path.splitext(os.path.join(annotation_directory, each_file))
                if file_extension.lower() not in self._IMAGE_EXTENSIONS:
                    continue
                label_path = filename + ".txt"
                print(label_path)
                if os.path.isfile(label_path):
                    print("txt file exists")
                    # Without an error_callback, worker exceptions would be silently lost.
                    pool.apply_async(self.run_augmentations,
                                     (annotation_directory, post_process_directory, filename, each_file),
                                     error_callback=self._report_failure(each_file))
        finally:
            pool.close()
            pool.join()
#
#
# if __name__ == '__main__':
#     augment_manager = AugmentationManager(functions={
#         "blur": [
#             {
#                 "property": "blur",
#                 "value": 0.3
#             }
#         ],
#         "noise": [
#             {
#                 "property": "noise",
#                 "value": 0.1
#             }
#         ],
#         "horizontal_flip": [
#             {
#                 "property": "horizontal_flip",
#             }
#         ],
#         "grayscale": [
#             {
#                 "property": "grayscale",
#                 "value": 0.3
#             }
#         ],
#         "hue": [
#             {
#                 "property": "hue",
#                 "value": 1
#             }
#         ],
#         "saturation": [
#             {
#                 "property": "saturation",
#                 "value": 1
#             }
#         ],
#         "brightness": [
#             {
#                 "property": "brightness",
#                 "value": 0.8
#             }
#         ],
#         "exposure": [
#             {
#                 "property": "exposure",
#                 "value": 0.9
#             }
#         ],
#         "vertical_flip": [
#             {
#                 "property": "vertical_flip",
#                 "value": 0
#             }
#         ]
#
#
#     })
#     augment_manager.process(
#         annotation_directory=r"C:\Users\sikhin.vc\PycharmProjects\training_pipeline\jk_data\unaugmented_dataset",
#         post_process_directory=r"C:\Users\sikhin.vc\PycharmProjects\training_pipeline\jk_data\augmented_dataset")
