import os
import shutil

import cv2
import numpy as np
from loguru import logger
from pathlib import Path
import random


class DataAugmentation:
    """
    Apply a fixed set of image augmentations to every image in a directory.

    ``functions`` is a list of dicts, each with a ``'property'`` key naming an
    augmentation method of this class (e.g. ``'zoom'``, ``'rotation'``) and a
    ``'value'`` key holding that augmentation's parameter. When the same
    property appears more than once, the last entry wins.
    """

    def __init__(self, functions):
        """
        :param functions: list of ``{'property': name, 'value': parameter}``
            dicts configuring the augmentation methods below.
        """
        self.functions = functions

    def _get_value(self, prop):
        """
        Return the configured ``'value'`` for augmentation ``prop``.

        The last matching entry in ``self.functions`` is used.

        :raises KeyError: if ``prop`` has no entry in ``self.functions``.
        """
        values = [item['value'] for item in self.functions if item['property'] == prop]
        if not values:
            raise KeyError(f"Augmentation '{prop}' is not configured in functions")
        return values[-1]

    def fill(self, img, h, w):
        """
        Resize ``img`` to ``h`` rows by ``w`` columns with bicubic interpolation.
        """
        # cv2.resize takes dsize as (width, height). Interpolation must be passed
        # by keyword because the third positional parameter is ``dst``.
        return cv2.resize(img, (w, h), interpolation=cv2.INTER_CUBIC)

    def _random_shift(self, img, prop, axis):
        """
        Crop a random fraction off one end of ``img`` along ``axis`` and resize
        the result back to the original height/width.

        The maximum fraction is read from the ``prop`` configuration entry and
        must lie in [0, 1]; otherwise a warning is logged and ``img`` is
        returned unchanged.
        """
        ratio = self._get_value(prop)
        if ratio > 1 or ratio < 0:
            logger.warning(f"{prop}: Value should be less than 1 and greater than 0")
            return img
        ratio = random.uniform(-ratio, ratio)
        h, w = img.shape[:2]
        size = img.shape[axis]
        to_shift = size * ratio
        if ratio > 0:
            # Drop the trailing part; always keep at least one pixel.
            keep = slice(None, max(1, int(size - to_shift)))
        elif ratio < 0:
            # Drop the leading part; always keep at least one pixel.
            keep = slice(min(size - 1, int(-to_shift)), None)
        else:
            keep = slice(None)
        index = [slice(None)] * img.ndim
        index[axis] = keep
        return self.fill(img[tuple(index)], h, w)

    def horizontal_shift(self, img):
        """Randomly shift ``img`` horizontally (configured by ``'horizontal_shift'``)."""
        return self._random_shift(img, "horizontal_shift", axis=1)

    def vertical_shift(self, img):
        """Randomly shift ``img`` vertically (configured by ``'vertical_shift'``)."""
        return self._random_shift(img, "vertical_shift", axis=0)

    def brightness(self, img):
        """
        Scale saturation and value (HSV) by a random factor drawn from
        ``[low, low + 2.5]``, where ``low`` is the ``'brightness'`` value.

        Expects a BGR image (it is converted with ``COLOR_BGR2HSV``).
        """
        low = self._get_value("brightness")
        value = random.uniform(low, low + 2.5)
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.float64)
        # Clip both ends: a negative factor would otherwise produce values that
        # cannot be cast safely to uint8.
        hsv[:, :, 1:] = np.clip(hsv[:, :, 1:] * value, 0, 255)
        return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)

    def channel_shift(self, img):
        """
        Add a random integer offset in ``(-value, value)`` to every pixel and
        clamp the result to [0, 255] as ``uint8``. ``value`` is the
        ``'channel_shift'`` configuration entry.
        """
        limit = self._get_value("channel_shift")
        shift = int(random.uniform(-limit, limit))
        # Widen before adding: uint8 arithmetic wraps around (250 + 10 -> 4) or
        # raises on negative Python ints, which would defeat the clipping.
        shifted = img.astype(np.int32) + shift
        return np.clip(shifted, 0, 255).astype(np.uint8)

    def zoom(self, img):
        """
        Crop a random window covering a fraction in ``[value, 1]`` of each side
        and resize it back to the original size. ``value`` is the ``'zoom'``
        entry and must lie in [0, 1]; otherwise ``img`` is returned unchanged.
        """
        value = self._get_value("zoom")
        if value > 1 or value < 0:
            logger.warning('Value for zoom should be less than 1 and greater than 0')
            return img
        value = random.uniform(value, 1)
        h, w = img.shape[:2]
        # Keep at least one pixel so the crop can never be empty (value may be 0).
        h_taken = max(1, int(value * h))
        w_taken = max(1, int(value * w))
        h_start = random.randint(0, h - h_taken)
        w_start = random.randint(0, w - w_taken)
        img = img[h_start:h_start + h_taken, w_start:w_start + w_taken]
        return self.fill(img, h, w)

    def horizontal_flip(self, img):
        """Mirror ``img`` left-right when the ``'horizontal_flip'`` value is truthy."""
        if self._get_value("horizontal_flip"):
            return cv2.flip(img, 1)
        return img

    def vertical_flip(self, img):
        """Mirror ``img`` top-bottom when the ``'vertical_flip'`` value is truthy."""
        if self._get_value("vertical_flip"):
            return cv2.flip(img, 0)
        return img

    def rotation(self, img):
        """
        Rotate ``img`` about its centre by a random whole-degree angle in
        ``(-angle, angle)``, where ``angle`` is the ``'rotation'`` value.
        The output keeps the input size.
        """
        max_angle = self._get_value("rotation")
        angle = int(random.uniform(-max_angle, max_angle))
        h, w = img.shape[:2]
        matrix = cv2.getRotationMatrix2D((int(w / 2), int(h / 2)), angle, 1)
        return cv2.warpAffine(img, matrix, (w, h))

    def process(self, annotation_directory, post_process_directory):
        """
        Augment every .jpg/.jpeg/.png image in ``annotation_directory``.

        For each image ``<stem>.<ext>``, eight files ``<stem>_0.jpg`` ...
        ``<stem>_7.jpg`` are written to ``post_process_directory``. They are, in
        order: horizontal shift, vertical shift, brightness, zoom, channel
        shift, horizontal flip, vertical flip and rotation.

        :raises FileNotFoundError: if ``annotation_directory`` does not exist.
        """
        if not os.path.exists(annotation_directory):
            raise FileNotFoundError(f"Path: {annotation_directory} does not exist")
        if not os.path.exists(post_process_directory):
            logger.info(f"Path: {post_process_directory} does not exist, creating one now!")
            os.makedirs(post_process_directory)

        augmentations = (
            self.horizontal_shift, self.vertical_shift, self.brightness,
            self.zoom, self.channel_shift, self.horizontal_flip,
            self.vertical_flip, self.rotation,
        )
        for each_file in os.listdir(annotation_directory):
            stem, extension = os.path.splitext(each_file)
            if extension.lower() not in ('.jpg', '.jpeg', '.png'):
                continue
            image = cv2.imread(os.path.join(annotation_directory, each_file))
            if image is None:
                # cv2.imread signals unreadable/corrupt files by returning None.
                logger.warning(f"Could not read image {each_file}, skipping")
                continue
            results = [augment(image) for augment in augmentations]
            for index, augmented in enumerate(results):
                cv2.imwrite(os.path.join(post_process_directory, f"{stem}_{index}.jpg"), augmented)

    def combine_dataset(self, augmented_dataset_path, unaugmented_dataset_path, cls):
        """
        Copy every file of class ``cls`` from the unaugmented dataset into the
        same class folder of the augmented dataset. The target folder is
        created if missing, and subdirectories are skipped.
        """
        source_dir = os.path.join(unaugmented_dataset_path, cls)
        target_dir = os.path.join(augmented_dataset_path, cls)
        # Without this, a missing target folder makes shutil.copy write every
        # file onto a single file named ``cls``.
        os.makedirs(target_dir, exist_ok=True)
        for each_file in os.listdir(source_dir):
            source_path = os.path.join(source_dir, each_file)
            if os.path.isfile(source_path):
                shutil.copy(source_path, target_dir)
