import os
import shutil

import cv2
from loguru import logger


class TrainTestSplit:
    """
    Split a class-per-folder dataset into train, valid and test folders.

    The source dataset is read as ``dataset_path/<class_name>/<file>``. For
    every class sub-directory the files are *copied* (the source is left
    untouched) into ``<job_id>/train/<class_name>``,
    ``<job_id>/valid/<class_name>`` and ``<job_id>/test/<class_name>``.
    """

    def __init__(self, dataset_path, job_id, validation_size=0.20, test_size=0.10):
        """
        Prepare the output folders for a split.

        Args:
            dataset_path: Root folder containing one sub-directory per class.
            job_id: Output root; ``train``, ``valid`` and ``test`` folders are
                created inside it (along with ``job_id`` itself if missing).
            validation_size: Fraction of each class's files copied to ``valid``.
            test_size: Fraction of each class's files reserved for ``test``.

        Raises:
            ValueError: If a fraction is negative or the two fractions sum to
                more than 1 (which would make the train share negative).
        """
        # Validate before touching the filesystem so a bad config leaves no debris.
        if validation_size < 0 or test_size < 0 or validation_size + test_size > 1:
            raise ValueError(
                f"validation_size ({validation_size}) and test_size ({test_size}) "
                "must be non-negative and sum to at most 1"
            )
        self.path = dataset_path
        self.folder_train_name = os.path.join(job_id, "train")
        self.folder_valid_name = os.path.join(job_id, "valid")
        self.folder_test_name = os.path.join(job_id, "test")
        self.validation_size = validation_size
        self.test_size = test_size
        for folder in (self.folder_train_name, self.folder_valid_name, self.folder_test_name):
            # makedirs also creates a missing job_id parent and tolerates re-runs.
            os.makedirs(folder, exist_ok=True)

    def train_test_split(self):
        """
        Copy each class's files into the train/valid/test folders.

        Files are assigned in ``os.listdir`` order (no shuffling is done):
        the first ``round(n * (1 - validation_size - test_size))`` files go to
        train, the next ``round(n * validation_size)`` go to valid, and the
        remainder goes to test. Entries in the dataset root that are not
        directories, and entries inside a class folder that are not files,
        are skipped.

        Returns:
            A tuple ``(train_dir, valid_dir, test_dir, count)`` where
            ``count`` is the total number of files found across all class
            folders.
        """
        train_fraction = 1 - (self.validation_size + self.test_size)
        count = 0
        for class_name in self._class_dirs():
            src_dir = os.path.join(self.path, class_name)
            files = [
                name for name in os.listdir(src_dir)
                if os.path.isfile(os.path.join(src_dir, name))
            ]
            # Count by listing rather than decoding every image into memory.
            count += len(files)

            train_end = round(len(files) * train_fraction)
            valid_end = train_end + round(len(files) * self.validation_size)

            self._copy_files(src_dir, self.folder_train_name, class_name, files[:train_end])
            self._copy_files(src_dir, self.folder_valid_name, class_name, files[train_end:valid_end])
            # Test takes whatever is left so rounding never drops a file.
            self._copy_files(src_dir, self.folder_test_name, class_name, files[valid_end:])

        return self.folder_train_name, self.folder_valid_name, self.folder_test_name, count

    def _class_dirs(self):
        """Return the names of the sub-directories directly under ``self.path``."""
        return [
            name for name in os.listdir(self.path)
            if os.path.isdir(os.path.join(self.path, name))
        ]

    @staticmethod
    def _copy_files(src_dir, split_dir, class_name, names):
        """Copy ``names`` from ``src_dir`` into ``split_dir/class_name``, creating it if needed."""
        dst_dir = os.path.join(split_dir, class_name)
        # Created even when ``names`` is empty so every split has every class folder.
        os.makedirs(dst_dir, exist_ok=True)
        for name in names:
            shutil.copyfile(os.path.join(src_dir, name), os.path.join(dst_dir, name))
