import os
import re
import tracemalloc

import mlflow
from loguru import logger
from scripts.constants.app_configuration import MlFlow, job
from scripts.core.vgg16_training import VGG16Training

mlflow_tracking_uri = MlFlow.mlflow_tracking_uri

# Azure Blob Storage credentials for the MLflow artifact store, exported through the
# environment variables MLflow's Azure artifact repository reads.
# Basic-auth tracking credentials (MLFLOW_TRACKING_USERNAME / MLFLOW_TRACKING_PASSWORD)
# are currently not exported.
os.environ.update({
    "AZURE_STORAGE_CONNECTION_STRING": MlFlow.azure_storage_connection_string,
    "AZURE_STORAGE_ACCESS_KEY": MlFlow.azure_storage_access_key,
})

# One server URI is used for both experiment tracking and the model registry.
mlflow.set_tracking_uri(mlflow_tracking_uri)
mlflow.set_registry_uri(mlflow_tracking_uri)

# Shared low-level client (used for tagging runs by id).
client = mlflow.tracking.MlflowClient()


class MlFlowUtil:
    """
    Exception-safe wrappers around MLflow logging calls.

    Each helper reports any exception through ``logger.exception`` instead of raising it,
    and returns True on success and False on failure.
    """

    # Characters stripped from metric names before logging: (), [] and {}.
    _METRIC_KEY_CHARS = re.compile(r"[([{})\]]")

    @staticmethod
    def _sanitize_metric_keys(metrics):
        """
        Return a new dict with bracket characters removed from every key of ``metrics``.

        If two keys become identical after stripping, the later one wins.
        :param metrics: dict of metric name -> value (names must be str)
        :return: dict with sanitized metric names
        """
        return {MlFlowUtil._METRIC_KEY_CHARS.sub("", key): value for key, value in metrics.items()}

    @staticmethod
    def log_model(model, model_name):
        """
        Log ``model`` to the active MLflow run using the sklearn flavor.

        :param model: model object accepted by ``mlflow.sklearn.log_model``
        :param model_name: passed as the second positional argument of ``mlflow.sklearn.log_model``
        :return: True on success, False on failure
        """
        try:
            mlflow.sklearn.log_model(model, model_name)
            logger.info("logged the model")
            return True
        except Exception as e:
            logger.exception(str(e))
            return False

    @staticmethod
    def log_metrics(metrics):
        """
        Log ``metrics`` to the active MLflow run after stripping (), [] and {} from their names.

        :param metrics: dict of metric name -> value
        :return: True on success, False on failure
        """
        try:
            mlflow.log_metrics(MlFlowUtil._sanitize_metric_keys(metrics))
            return True
        except Exception as e:
            logger.exception(str(e))
            return False

    @staticmethod
    def log_hyper_param(hyper_params):
        """
        Log ``hyper_params`` as parameters of the active MLflow run.

        :param hyper_params: dict of parameter name -> value
        :return: True on success, False on failure
        """
        try:
            mlflow.log_params(hyper_params)
            return True
        except Exception as e:
            logger.exception(str(e))
            return False

    @staticmethod
    def set_tag(child_run_id, key, value):
        """
        Set a single tag on the run ``child_run_id`` via the module-level ``client``.

        :param child_run_id: id of the run to tag
        :param key: tag name
        :param value: tag value
        :return: True on success, False on failure
        """
        try:
            client.set_tag(run_id=child_run_id, key=key, value=value)
            return True
        except Exception as e:
            logger.exception(f"Exception while setting the tag - {e}")
            return False


class ModelReTrainer:
    """
    Train a VGG16 model and track it in MLflow under a nested run hierarchy.

    Runs are organised as: experiment -> parent run (``parent_run_name``) -> line run
    (``Line_<line>``) -> camera run (``Camera_<camera>``) -> job run (``current_run_name``).
    Existing runs are reused by name; missing ones are created.
    """

    # Column of ``mlflow.search_runs`` results that holds each run's name.
    _RUN_NAME_KEY = 'tags.mlflow.runName'
    # Characters stripped from metric names before logging: (), [] and {}.
    _METRIC_KEY_CHARS = re.compile(r"[([{})\]]")

    def __init__(self, experiment_name, parent_run_name, line, camera, training_path, validation_path, job_id,
                 master_config):
        self.experiment_name = experiment_name
        self.parent_run_name = parent_run_name
        self.line = f'Line_{line}'
        self.camera = f'Camera_{camera}'
        self.training_path = training_path
        self.validation_path = validation_path
        self.job_id = job_id
        self.master_config = master_config
        self._mfu_ = MlFlowUtil()
        # NOTE(review): the job run name comes from the global ``job.job_id`` config rather than the
        # ``job_id`` argument stored above — confirm both are meant to be the same value.
        self.current_run_name = job.job_id

    def check_create_experiment(self):
        """
        Make ``self.experiment_name`` the active experiment, creating it first if it does not exist.

        :return: experiment_id of the experiment
        """
        experiment_info = mlflow.get_experiment_by_name(self.experiment_name)
        if experiment_info is None:
            logger.info(f"No experiment found with name {self.experiment_name}, So creating one")
            mlflow.create_experiment(self.experiment_name)
        else:
            logger.info(f"Proceeding with existing Experiment {self.experiment_name}")
        mlflow.set_experiment(experiment_name=self.experiment_name)
        experiment_info = mlflow.get_experiment_by_name(self.experiment_name)
        return experiment_info.experiment_id

    @staticmethod
    def _first_run_named(runs_df, run_name):
        """
        Return the run id of the first row of ``runs_df`` whose run name equals ``run_name``.

        MLflow tag values are strings, so ``run_name`` is compared in its ``str`` form; otherwise a
        non-str name (e.g. an int job id) would never match and a duplicate run would be created.
        :param runs_df: DataFrame returned by ``mlflow.search_runs``
        :param run_name: run name to look for
        :return: matching run id, or None when the frame is empty, lacks a run-name column, or has no match
        """
        if runs_df.empty or ModelReTrainer._RUN_NAME_KEY not in runs_df.columns:
            return None
        matches = runs_df[runs_df[ModelReTrainer._RUN_NAME_KEY] == str(run_name)]
        if matches.empty:
            return None
        return matches['run_id'].iloc[0]

    @staticmethod
    def _search_child_runs(experiment_id, parent_run_id):
        """Return the runs of ``experiment_id`` whose MLflow parent run is ``parent_run_id``."""
        return mlflow.search_runs(experiment_id, filter_string=f"tags.mlflow.parentRunId='{parent_run_id}'")

    @staticmethod
    def _create_nested_run(experiment_id, ancestor_run_ids, run_name):
        """
        Create a new run named ``run_name`` nested under the chain of ``ancestor_run_ids``.

        Each ancestor run is re-opened (outermost first) so the new nested run gets the last
        ancestor as its parent; all opened runs are closed again before returning.
        :param experiment_id: experiment id
        :param ancestor_run_ids: run ids from the outermost ancestor down to the direct parent
        :param run_name: name of the run to create
        :return: run id of the newly created run
        """
        from contextlib import ExitStack

        with ExitStack() as stack:
            for ancestor_run_id in ancestor_run_ids:
                stack.enter_context(mlflow.start_run(experiment_id=experiment_id, run_id=ancestor_run_id,
                                                     nested=True))
            child_run = stack.enter_context(mlflow.start_run(experiment_id=experiment_id, nested=True,
                                                             run_name=run_name))
            return child_run.info.run_id

    def check_create_parent_run(self, experiment_id):
        """
        Return the id of the run named ``self.parent_run_name`` in the experiment, creating it if absent.

        :param experiment_id: Experiment id
        :return: the parent run id
        """
        parent_run_id = self._first_run_named(mlflow.search_runs(experiment_id), self.parent_run_name)
        if parent_run_id is not None:
            logger.info(f"Proceeding with existing Parent Run {self.parent_run_name}")
            return parent_run_id
        logger.info(f"No Parent Run present {self.parent_run_name}")
        with mlflow.start_run(experiment_id=experiment_id, run_name=self.parent_run_name) as run:
            logger.info(f"Creating the parent Run {self.parent_run_name} with Parent Run Id {run.info.run_id}")
            return run.info.run_id

    def check_create_child_run(self, experiment_id, parent_run_id):
        """
        Return the id of the line run (named ``self.line``) under ``parent_run_id``, creating it if absent.

        :param experiment_id: experiment id
        :param parent_run_id: parent run id
        :return: line (child) run id
        """
        child_runs_df = self._search_child_runs(experiment_id, parent_run_id)
        if child_runs_df.empty:
            logger.info(f"Child runs are not present for Parent Run Id {parent_run_id}")
        else:
            logger.info(f"Already Child runs are present for Parent Run Id {parent_run_id}")
            child_run_id = self._first_run_named(child_runs_df, self.line)
            if child_run_id is not None:
                return child_run_id
        return self._create_nested_run(experiment_id, [parent_run_id], self.line)

    def create_camera_run(self, experiment_id, city_run_id, line_run_id):
        """
        Return the id of the camera run (named ``self.camera``) under ``line_run_id``, creating it if absent.

        :param experiment_id: experiment id
        :param city_run_id: top-level parent run id
        :param line_run_id: line run id (direct parent of the camera run)
        :return: camera run id
        """
        camera_run_id = self._first_run_named(self._search_child_runs(experiment_id, line_run_id), self.camera)
        if camera_run_id is not None:
            return camera_run_id
        return self._create_nested_run(experiment_id, [city_run_id, line_run_id], self.camera)

    def get_current_run(self, experiment_id, city_run_id, line_run_id, camera_run_id):
        """
        Return the id of this job's run (named ``self.current_run_name``) under ``camera_run_id``,
        creating it if absent.

        :param experiment_id: experiment id
        :param city_run_id: top-level parent run id
        :param line_run_id: line run id
        :param camera_run_id: camera run id (direct parent of the job run)
        :return: job run id
        """
        current_run_id = self._first_run_named(self._search_child_runs(experiment_id, camera_run_id),
                                               self.current_run_name)
        if current_run_id is not None:
            return current_run_id
        return self._create_nested_run(experiment_id, [city_run_id, line_run_id, camera_run_id],
                                       self.current_run_name)

    @staticmethod
    def flatten_dict(dd, separator='_', prefix=''):
        """
        Flatten nested dicts into one level, joining nested keys with ``separator``.

        :param dd: (possibly nested) dict; keys must be str when they get prefixed
        :param separator: string placed between parent and child keys
        :param prefix: optional prefix applied to every top-level key
        :return: flat dict; empty nested dicts contribute no entries
        """
        stack = [(dd, prefix)]
        flat_dict = {}

        while stack:
            cur_dict, cur_prefix = stack.pop()
            for key, val in cur_dict.items():
                new_key = cur_prefix + separator + key if cur_prefix else key
                if isinstance(val, dict):
                    stack.append((val, new_key))
                else:
                    flat_dict[new_key] = val

        return flat_dict

    def start_training(self):
        """
        Resolve/create the run hierarchy, train VGG16 inside this job's run and log metrics and artifacts.

        :return: the results directory returned by ``VGG16Training.train_model``
        """
        experiment_id = self.check_create_experiment()
        parent_run_id = self.check_create_parent_run(experiment_id)
        child_run_id = self.check_create_child_run(experiment_id, parent_run_id)
        camera_run_id = self.create_camera_run(experiment_id=experiment_id, city_run_id=parent_run_id,
                                               line_run_id=child_run_id)
        current_run_id = self.get_current_run(experiment_id=experiment_id, city_run_id=parent_run_id,
                                              line_run_id=child_run_id, camera_run_id=camera_run_id)
        with mlflow.start_run(run_id=current_run_id):
            logger.info('Creating the new model !')

            vgg16 = VGG16Training(self.master_config, self.training_path, self.validation_path, self.job_id)
            metrics_history, results_directory = vgg16.train_model()

            # NOTE(review): the second-to-last entry is used — presumably per-epoch metrics where the last
            # entry is not the one to report; confirm against VGG16Training.train_model.
            selected = metrics_history[-2]
            metrics = {name: float(selected.get(name))
                       for name in ('avg_loss_train', 'avg_acc_train', 'avg_loss_val', 'avg_acc_val')}

            logger.info(f'metrics - {metrics}')
            self.log_metrics(metrics=metrics)

            # NOTE(review): mlflow.log_artifact takes a file path; sub-directories of results_directory
            # will fail (and be logged by log_model) — consider mlflow.log_artifacts for those.
            for entry in os.listdir(results_directory):
                self.log_model(model_path=os.path.join(results_directory, entry))

            tracemalloc.clear_traces()
        logger.info(f"Loading the model from the child run id {camera_run_id}")
        return results_directory

    @staticmethod
    def log_model(model_path):
        """
        Upload the file at ``model_path`` as an artifact of the active MLflow run.

        :param model_path: local path of the file to upload
        :return: True on success, False on failure
        """
        try:
            mlflow.log_artifact(model_path)
            logger.info("logged the model")
            return True
        except Exception as e:
            logger.exception(str(e))
            return False

    @staticmethod
    def _sanitize_metric_keys(metrics):
        """
        Return a new dict with (), [] and {} removed from every key of ``metrics``.

        If two keys become identical after stripping, the later one wins.
        """
        return {ModelReTrainer._METRIC_KEY_CHARS.sub("", key): value for key, value in metrics.items()}

    @staticmethod
    def log_metrics(metrics):
        """
        Log ``metrics`` to the active MLflow run after stripping (), [] and {} from their names.

        :param metrics: dict of metric name -> value
        :return: True on success, False on failure
        """
        try:
            mlflow.log_metrics(ModelReTrainer._sanitize_metric_keys(metrics))
            logger.info('logged the model metric')
            return True
        except Exception as e:
            logger.exception(str(e))
            return False

    @staticmethod
    def log_hyper_param(hyperparameters):
        """
        Log ``hyperparameters`` as parameters of the active MLflow run.

        :param hyperparameters: dict of hyperparameters
        :return: True on success, False on failure
        """
        try:
            mlflow.log_params(hyperparameters)
            logger.debug('logged model hyper parameters')
            return True
        except Exception as e:
            logger.exception(str(e))
            return False

    @staticmethod
    def set_tag(run_id, params):
        """
        Set every key/value pair of ``params`` as a tag on the run ``run_id``.

        :param run_id: Run id in which the tags need to be added
        :param params: dict mapping tag name -> tag value
        :return: True on success, False on failure
        """
        try:
            for key, value in params.items():
                client.set_tag(run_id=run_id, key=key, value=value)
            logger.debug('set the tag for the model')
            return True
        except Exception as e:
            logger.exception(f"Exception while setting the tag - {e}")
            return False
