# When this file is executed as a script, read environment variables from
# 'config.env' before anything else in the module runs.
# NOTE(review): presumably this has to precede the project imports below so
# that scripts.constants.app_configuration sees the variables at import
# time — confirm before moving it.
if __name__ == "__main__":
    from dotenv import load_dotenv

    load_dotenv(dotenv_path='config.env')

import os.path
import shutil
import sys
import traceback
import warnings
from pathlib import Path
from pymongo import MongoClient
from loguru import logger

from scripts.constants.app_configuration import Mongo, job, StatusMessage
from utils.other.blob_util import BlobUtil
from utils.other.mlflow_vam_util import ModelReTrainer
from utils.preprocessing.dataset_augmentation import DataAugmentation
from utils.preprocessing.dataset_extraction import ExtractDataset
from utils.preprocessing.train_test_split_classifier import TrainTestSplit

# Suppress every warning for the lifetime of the process.
warnings.filterwarnings("ignore")

# Absolute location of this file, and the directory that contains it.
FILE = Path(__file__).resolve()
ROOT = FILE.parent
# Make that directory importable if it is not already on the search path.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
# From here on ROOT is expressed relative to the current working directory.
ROOT = Path(os.path.relpath(ROOT, os.getcwd()))

# Distributed-launch settings read from the environment
# (see https://pytorch.org/docs/stable/elastic/run.html); the fallbacks
# apply when the variables are not set.
LOCAL_RANK = int(os.environ.get('LOCAL_RANK', -1))
RANK = int(os.environ.get('RANK', -1))
WORLD_SIZE = int(os.environ.get('WORLD_SIZE', 1))


def _set_job_status(database, job_id, status, progress):
    """Write a status message and progress value onto the job's Mongo document.

    The update uses ``upsert=False``, so a job document that does not exist
    is not created.
    """
    database[job.job_collection].update_one(
        {'job_id': job_id},
        {'$set': {"job_status": status, 'progress': progress}},
        upsert=False)


def main():
    """Run one training job end to end, reporting progress to MongoDB.

    Steps, in order:
      1. Load the job document identified by ``job.job_id``.
      2. Create a ``<job_id>/temp`` workspace and download the dataset blob.
      3. Extract ``dataset.zip`` and augment every class directory.
      4. Combine the datasets, split them, and delete the temp workspace.
      5. Train via ``ModelReTrainer`` and upload every file in the result
         directory.
      6. Remove the job and result directories.

    Any exception is caught here: the traceback is printed, the working
    directories are removed on a best-effort basis, and the job is marked
    failed. The exception is not re-raised. The Mongo client is closed on
    both the success and the failure path.
    """
    # Bind everything the failure handler and ``finally`` touch up front, so
    # an early failure cannot raise UnboundLocalError inside the handler and
    # skip the "failed" status update.
    mmongo = None
    mongo_client = None
    job_id = None
    result_directory = None
    try:
        mmongo = MongoClient(Mongo.mongo_uri)
        mongo_client = mmongo[Mongo.mongo_db]
        job_id = job.job_id
        master_conf = mongo_client[job.job_collection].find_one({"job_id": job_id})
        if master_conf is None:
            # Fail with a clear message instead of a TypeError on first access.
            raise LookupError(f"No job configuration found for job_id {job_id!r}")

        # Raises if the job directory already exists (same as before).
        os.mkdir(job_id)
        root_data_path = os.path.join(job_id, 'temp')
        raw_dataset_path = os.path.join(root_data_path, "raw_dataset")
        extracted_dataset_path = os.path.join(root_data_path, "exatracted_dataset")
        unaugmented_dataset_path = os.path.join(root_data_path, "unaugmented_dataset")
        augmented_dataset_path = os.path.join(root_data_path, "augmented_dataset")
        split_dataset_path = os.path.join(root_data_path, "split_dataset")
        for directory in (root_data_path, raw_dataset_path, extracted_dataset_path,
                          unaugmented_dataset_path, augmented_dataset_path,
                          split_dataset_path):
            os.makedirs(directory, exist_ok=True)

        logger.info("Starting Vision Data Accusation Pipeline")
        blob_list = master_conf['blob_path']
        logger.info(f'Starting accusation of Stream {blob_list}')
        blob_util = BlobUtil(master_conf['project_id'], master_conf['Site_id'],
                             master_conf['line_id'], master_conf['camera_id'])
        _set_job_status(mongo_client, job_id, StatusMessage.tr_data_download_started, 5.0)
        if blob_util.download(blob_list, raw_dataset_path):
            logger.info("Data Downloaded Successfully!")
            _set_job_status(mongo_client, job_id, StatusMessage.tr_data_downloaded, 10.0)
        else:
            # The pipeline still continues as before; the failure is now
            # visible in the log instead of being silently dropped.
            logger.warning(f'Download of {blob_list} did not report success')

        extract = ExtractDataset()
        extract.extract_ds(os.path.join(raw_dataset_path, "dataset.zip"), extracted_dataset_path)
        extract.move_files(extracted_dataset_path, unaugmented_dataset_path)
        class_names = os.listdir(unaugmented_dataset_path)
        if not class_names:
            # Previously this surfaced later as a NameError on ``cls``.
            raise ValueError(f"No class directories found in {unaugmented_dataset_path}")

        logger.info('Initiating Augmentations of Dataset!!')
        # One entry per (augmentation type, value) pair from the job config.
        augmentation_functions = []
        for augmentation, augmentation_value in master_conf['general_augmentations'].items():
            augmentation_functions.extend(
                {"property": augmentation, "value": val} for val in augmentation_value)
            logger.info(
                f'Starting Data augmentation with type {augmentation} and Value {augmentation_value}')

        classification_data_aug_manager = DataAugmentation(augmentation_functions)

        for cls in class_names:
            classwise_augmented_path = os.path.join(augmented_dataset_path, cls)
            os.mkdir(classwise_augmented_path)
            classification_data_aug_manager.process(
                annotation_directory=os.path.join(unaugmented_dataset_path, cls),
                post_process_directory=classwise_augmented_path)

        _set_job_status(mongo_client, job_id, StatusMessage.tr_augmentation_completed, 20.0)
        # NOTE(review): combine_dataset is invoked once, with the last class
        # name only, exactly as the original flow did. Whether it is meant to
        # run once per class cannot be determined from this file — confirm.
        classification_data_aug_manager.combine_dataset(
            augmented_dataset_path=augmented_dataset_path,
            unaugmented_dataset_path=unaugmented_dataset_path,
            cls=class_names[-1])

        train_test_split_obj = TrainTestSplit(augmented_dataset_path, job_id,
                                              validation_size=0.20,
                                              test_size=0.10)
        training_path, validation_path, _test_path, dataset_size = \
            train_test_split_obj.train_test_split()

        # The temp workspace is removed before training starts.
        shutil.rmtree(f'./{root_data_path}')
        # master_conf is handed to the trainer below, so the size travels with it.
        master_conf['Dataset_size'] = dataset_size
        _set_job_status(mongo_client, job_id, StatusMessage.tr_init_training, 35.0)

        model_trainer = ModelReTrainer(master_conf['project_id'],
                                       master_conf['Site_id'],
                                       master_conf['line_id'],
                                       master_conf['camera_id'], training_path,
                                       validation_path, job_id, master_conf)
        result_directory = model_trainer.start_training()
        _set_job_status(mongo_client, job_id, StatusMessage.tr_data_upload_started, 90.0)

        blob_util = BlobUtil(master_conf['project_id'], master_conf['Site_id'],
                             master_conf['line_id'], master_conf['camera_id'])

        logger.info("Uploading Files!!!")
        failed_uploads = []
        for each_file in os.listdir(result_directory):
            if not blob_util.upload(result_directory, each_file):
                logger.info("Data Upload Unsuccessfull!")
                failed_uploads.append(each_file)
        if failed_uploads:
            # Do not claim success when some files did not upload.
            logger.error(f'Upload failed for files: {failed_uploads}')
        else:
            logger.info("Data Uploaded Successfully!")
        # NOTE(review): the job is still marked uploaded at 100% even when
        # some uploads failed (unchanged behaviour) — confirm this is intended.
        _set_job_status(mongo_client, job_id, StatusMessage.tr_data_uploaded, 100.0)

        shutil.rmtree(job_id)
        shutil.rmtree(result_directory)

    except Exception as exc:
        traceback.print_exc()
        logger.error(f'Training pipeline failed: {exc}')
        # Best-effort cleanup: the directories may never have been created or
        # may already be gone, and that must not prevent the status update.
        if job_id is not None:
            shutil.rmtree(job_id, ignore_errors=True)
        if result_directory is not None and os.path.exists(result_directory):
            shutil.rmtree(result_directory, ignore_errors=True)
        if mongo_client is not None and job_id is not None:
            _set_job_status(mongo_client, job_id, StatusMessage.tr_failed, 35.0)
    finally:
        # Release the connection on success as well as on failure.
        if mmongo is not None:
            mmongo.close()


# Script entry point: announce the run, invoke the pipeline, and report any
# exception that escapes main() both as a traceback and as an error log line.
if __name__ == '__main__':
    logger.info("Attempting to Start Training Pipeline!")
    try:
        main()
    except Exception as exc:
        traceback.print_exc()
        logger.error("Failed to Start Training Pipeline! : " + str(exc))
