from __future__ import print_function, division

import json
import shutil
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from loguru import logger
from pymongo import MongoClient
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torchvision import datasets, models, transforms
from torch.utils.data.dataloader import DataLoader
import time
import os
import copy

from scripts.constants.app_configuration import Mongo, job, StatusMessage


class VGG16Training:
    """
    Fine-tune a pretrained torchvision VGG16 on an image-folder dataset.

    The dataset root is expected to contain one sub-directory per split
    (``train`` and ``valid`` are the ones used by :meth:`train_model`), each
    laid out the way ``torchvision.datasets.ImageFolder`` expects.  Per-epoch
    metrics and job progress are written to the MongoDB collection named by
    ``job.job_collection``; the best model and a JSON summary are written to
    the ``vgg_results`` directory.
    """

    def __init__(self, master_config, training_path, validation_path, job_id):
        """
        Build the datasets, the model, the loss, the optimizer and the scheduler.

        Args:
            master_config: Mapping with a comma-separated ``'classes'`` string
                and a ``'training_params'`` mapping that provides ``'epochs'``.
            training_path: Stored on the instance; not otherwise used by this class.
            validation_path: Stored on the instance; not otherwise used by this class.
            job_id: Used as the dataset root directory (stored as ``data_dir``).
        """
        self.map_50_90 = []
        self.map_50 = []
        self.recall = []
        self.precision = []
        self.mongo_client = MongoClient(Mongo.mongo_uri)[Mongo.mongo_db]
        self.clsloss = []
        self.objloss = []
        self.bbloss = []
        self.master_config = master_config
        self.pbar = 0
        self.mloss = 0
        self.current_epoch = 0
        self.status = 0
        configured_names = master_config['classes'].split(',')
        self.classes = {'nc': len(configured_names), 'names': configured_names,
                        'train': 'train/', 'val': 'valid/'}
        self.epoch = 0
        # Per-epoch metric histories pushed to MongoDB by update_epoch().
        self.avg_loss = []
        self.avg_acc = []
        self.avg_loss_val = []
        self.avg_acc_val = []
        self.training_path = training_path
        self.validation_path = validation_path
        self.data_dir = job_id
        self.image_datasets = self.storing_image_datasets()
        self.dataset_sizes = self.get_dataset_size(self.image_datasets)
        # NOTE(review): this replaces the ``self.classes`` dict assigned above
        # with the label tensor of one training batch.  Kept for backward
        # compatibility -- confirm which of the two values callers rely on.
        self.inputs, self.classes = next(iter(self.dataloaders("train")))
        self.class_names = self.image_datasets['train'].classes
        self.vgg16 = models.vgg16(pretrained=True)
        # Freeze the convolutional feature extractor.  The attribute is
        # ``requires_grad``; the previous ``require_grad`` spelling only set an
        # unused attribute and left every layer trainable.
        for param in self.vgg16.features.parameters():
            param.requires_grad = False
        # NOTE(review): the classifier head is left at its pretrained output
        # size; it is not resized to len(self.class_names) -- confirm intent.
        self.use_gpu = torch.cuda.is_available()
        if self.use_gpu:
            logger.info("Using CUDA")
            self.vgg16.cuda()  # .cuda() will move everything to the GPU side
        else:
            logger.info("Not Using CUDA")
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer_ft = optim.SGD(self.vgg16.parameters(), lr=0.001, momentum=0.9)
        self.exp_lr_scheduler = lr_scheduler.StepLR(self.optimizer_ft, step_size=7, gamma=0.1)
        self.num_epochs = self.master_config['training_params'].get("epochs")

    def update_epoch(self, epoch_result, epochs_total):
        """
        Record one finished epoch and push the metric histories to MongoDB.

        Args:
            epoch_result: Mapping with the keys ``'Epoch'``, ``'avg_loss_train'``,
                ``'avg_acc_train'``, ``'avg_loss_val'`` and ``'avg_acc_val'``.
            epochs_total: Total number of epochs, used to scale the progress value.

        The updated document is selected by ``job.job_id`` from the application
        configuration (not by the ``job_id`` passed to the constructor).
        """
        logger.info("Updating Epoch!!!")
        self.epoch = int(epoch_result.get('Epoch'))
        self.avg_loss.append(epoch_result.get('avg_loss_train'))
        self.avg_acc.append(epoch_result.get('avg_acc_train'))
        self.avg_loss_val.append(epoch_result.get('avg_loss_val'))
        self.avg_acc_val.append(epoch_result.get('avg_acc_val'))
        # Progress starts at 35 and grows by 55 / epochs_total per epoch index.
        progress = float(35.0 + ((55 / epochs_total) * int(self.epoch)))
        self.mongo_client[job.job_collection].update_one(
            {'job_id': job.job_id},
            {'$set': {"job_status": StatusMessage.tr_in_progress,
                      'progress': progress,
                      'avg_loss_train': self.avg_loss,
                      'avg_acc_train': self.avg_acc,
                      'avg_loss_val': self.avg_loss_val,
                      'avg_acc_val': self.avg_acc_val}},
            upsert=False)

    def data_augmentation_transforms(self, name):
        """
        Return the transform pipeline for the dataset split called ``name``.

        Args:
            name: ``"train"`` (random resized crop + horizontal flip) or
                ``"valid"`` (resize + centre crop).

        Returns:
            A ``torchvision.transforms.Compose`` for the two known splits, or
            ``None`` for any other name.
        """
        if name == "train":
            return transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ])
        if name == "valid":
            return transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
            ])
        return None

    def storing_image_datasets(self):
        """
        Build one ``ImageFolder`` dataset per sub-directory of ``self.data_dir``.

        Returns:
            Dict mapping each sub-directory name (e.g. ``'train'``, ``'valid'``)
            to its dataset.  Plain files in ``self.data_dir`` are skipped, since
            ``ImageFolder`` can only be built from a directory.
        """
        image_datasets = {}
        for entry in os.listdir(self.data_dir):
            split_dir = os.path.join(self.data_dir, entry)
            if not os.path.isdir(split_dir):
                continue
            image_datasets[entry] = datasets.ImageFolder(
                split_dir, transform=self.data_augmentation_transforms(entry))
        return image_datasets

    def dataloaders(self, x, batch_size=8, shuffle=True, num_workers=4):
        """
        Create a new ``DataLoader`` over the dataset split ``x``.

        Args:
            x: Key into ``self.image_datasets`` (e.g. ``"train"`` or ``"valid"``).
            batch_size: Number of samples per batch.
            shuffle: Whether the loader reshuffles the data on each iteration.
            num_workers: Number of loader worker processes.

        Returns:
            The ``DataLoader`` for the requested split.
        """
        return DataLoader(
            self.image_datasets[x], batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
        )

    def get_dataset_size(self, image_datasets):
        """
        Return ``{split_name: number_of_samples}`` for the given datasets.

        The keys are taken from ``image_datasets`` itself, so the result always
        matches the mapping that was passed in and no filesystem access is needed.
        """
        return {name: len(dataset) for name, dataset in image_datasets.items()}

    def _run_training_phase(self):
        """Train on the first half of the training batches; return (avg_loss, avg_acc)."""
        loader = self.dataloaders("train")
        train_batches = len(loader)
        loss_sum = 0.0
        correct = 0
        seen = 0

        self.vgg16.train(True)
        for i, (inputs, labels) in enumerate(loader):
            # Use half training dataset
            if i >= train_batches / 2:
                break
            if i % 100 == 0:
                logger.info("Training batch {}/{}".format(i, train_batches / 2))

            if self.use_gpu:
                inputs, labels = inputs.cuda(), labels.cuda()

            self.optimizer_ft.zero_grad()
            outputs = self.vgg16(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer_ft.step()

            _, preds = torch.max(outputs.detach(), 1)
            batch_size = labels.size(0)
            # The criterion returns the batch mean, so weight it by the batch
            # size to obtain a true per-sample average at the end.
            loss_sum += loss.item() * batch_size
            correct += int(torch.sum(preds == labels).item())
            seen += batch_size

            del inputs, labels, outputs, preds, loss
            if self.use_gpu:
                torch.cuda.empty_cache()

        if seen == 0:
            return 0.0, 0.0
        # Divide by the samples actually processed rather than assuming that
        # exactly half of the dataset was seen.
        return loss_sum / seen, correct / seen

    def _run_validation_phase(self):
        """Evaluate on the whole validation split; return (avg_loss, avg_acc)."""
        loader = self.dataloaders("valid")
        val_batches = len(loader)
        loss_sum = 0.0
        correct = 0
        seen = 0

        self.vgg16.train(False)
        self.vgg16.eval()
        # no_grad() replaces the removed ``volatile=True`` flag: no autograd
        # graph is recorded during evaluation.
        with torch.no_grad():
            for i, (inputs, labels) in enumerate(loader):
                if i % 100 == 0:
                    logger.info("\rValidation batch {}/{}".format(i, val_batches))

                if self.use_gpu:
                    inputs, labels = inputs.cuda(), labels.cuda()

                outputs = self.vgg16(inputs)
                _, preds = torch.max(outputs, 1)
                loss = self.criterion(outputs, labels)

                batch_size = labels.size(0)
                loss_sum += loss.item() * batch_size
                correct += int(torch.sum(preds == labels).item())
                seen += batch_size

                del inputs, labels, outputs, preds, loss
                if self.use_gpu:
                    torch.cuda.empty_cache()

        if seen == 0:
            return 0.0, 0.0
        return loss_sum / seen, correct / seen

    def _save_results(self, results_list):
        """Write the model, its state dict and the metrics JSON to a fresh ``vgg_results``."""
        result_directory = Path("vgg_results")
        if result_directory.exists():
            shutil.rmtree(str(result_directory))
            logger.info('Deleting existing directory!!')
        result_directory.mkdir()

        torch.save(self.vgg16, str(result_directory / 'VGG16_model_complete.pt'))
        torch.save(self.vgg16.state_dict(), str(result_directory / 'VGG16_model_State_dict.pt'))

        results_json_file = os.path.join(str(result_directory), 'output_data.json')
        with open(results_json_file, 'w') as file:
            json.dump(results_list, file)
        return result_directory

    def train_model(self):
        """
        Run the full training loop and persist the best model.

        For each epoch the model is trained on half of the training batches,
        evaluated on the validation split, and the metrics are pushed through
        :meth:`update_epoch`.  The weights with the best validation accuracy are
        restored before saving.

        Returns:
            Tuple ``(results_list, result_directory)``: the per-epoch metric
            dicts followed by a final summary dict, and the ``Path`` of the
            directory holding the saved model files and ``output_data.json``.
        """
        logger.info('Initiating Training!!')
        since = time.time()
        best_model_wts = copy.deepcopy(self.vgg16.state_dict())
        best_acc = 0.0
        results_list = []

        for epoch in range(self.num_epochs):
            logger.info("Epoch {}/{}".format(epoch, self.num_epochs))
            logger.info('-' * 10)

            avg_loss, avg_acc = self._run_training_phase()
            # Advance the StepLR schedule once per epoch, after the optimizer
            # steps; without this call the learning rate never decays.
            self.exp_lr_scheduler.step()
            avg_loss_val, avg_acc_val = self._run_validation_phase()

            logger.info("Epoch {} result: ".format(epoch))
            logger.info("avg_loss_train: {:.4f}".format(avg_loss))
            logger.info("avg_acc_train: {:.4f}".format(avg_acc))
            logger.info("avg_loss_val: {:.4f}".format(avg_loss_val))
            logger.info("avg_acc_val: {:.4f}".format(avg_acc_val))
            logger.info('-' * 10)

            # Metrics are stored as strings, matching the existing MongoDB/JSON schema.
            epoch_result = {
                "Epoch": format(epoch),
                "avg_loss_train": format(avg_loss),
                "avg_acc_train": format(avg_acc),
                "avg_loss_val": format(avg_loss_val),
                "avg_acc_val": format(avg_acc_val)
            }
            self.update_epoch(epoch_result, self.num_epochs)
            results_list.append(epoch_result)

            if avg_acc_val > best_acc:
                best_acc = avg_acc_val
                best_model_wts = copy.deepcopy(self.vgg16.state_dict())

        elapsed_time = time.time() - since
        duration = "{:.0f}m {:.0f}s".format(elapsed_time // 60, elapsed_time % 60)

        logger.info("Training completed in {}".format(duration))
        logger.info("Best acc: {:.4f}".format(best_acc))

        training_result = {"Training completed in": duration,
                           "Best acc": "{:.4f}".format(best_acc)}

        self.mongo_client[job.job_collection].update_one({'job_id': job.job_id}, {
            '$set': {"job_status": StatusMessage.tr_completed, 'progress': 85.0}}, upsert=False)

        results_list.append(training_result)

        # Save the weights from the best validation epoch, not the last one.
        self.vgg16.load_state_dict(best_model_wts)
        result_directory = self._save_results(results_list)

        return results_list, result_directory
