import mlflow
from loguru import logger
import pandas as pd
import re
import os
import pytz
from datetime import datetime
from scripts.constants.app_configuration import MlFlow, ReqTimeZone
from scripts.utils.pycaret_util import PycaretUtil

# One URI is used for both the tracking server and the model registry.
mlflow_tracking_uri = MlFlow.mlflow_tracking_uri

# Credentials are placed in the process environment before the MLflow
# client below is configured and created.
os.environ.update({
    "MLFLOW_TRACKING_USERNAME": MlFlow.mlflow_tracking_username,
    "MLFLOW_TRACKING_PASSWORD": MlFlow.mlflow_tracking_password,
    "AZURE_STORAGE_CONNECTION_STRING": MlFlow.azure_storage_connection_string,
    "AZURE_STORAGE_ACCESS_KEY": MlFlow.azure_storage_access_key,
})

mlflow.set_tracking_uri(mlflow_tracking_uri)
mlflow.set_registry_uri(mlflow_tracking_uri)

# Module-wide low-level client shared by the ModelLoad helpers.
client = mlflow.tracking.MlflowClient()


class ModelLoad(object):
    """Reuse a recently logged MLflow model or train and log a new one.

    Run layout produced by :meth:`model_manager` inside the experiment:
    a top-level run named after the city, below it a run named
    ``MlFlow.run_name + '_' + inv_mppt_id``, and below that an unnamed run
    that holds the model, its metrics, hyper parameters and algorithm tag.

    Public methods log and swallow their own exceptions; a failure shows up
    as ``None``/``False`` (or as a "stale" result from
    :meth:`fetch_latest_model`) instead of a raised error.
    """

    # Bracket characters stripped from metric names before logging.
    _METRIC_KEY_JUNK = re.compile(r"[\([{})\]]")
    # Columns of the ``mlflow.search_runs`` frame used by this class.
    _RUN_NAME_COLUMN = "tags.mlflow.runName"
    _MODEL_HISTORY_COLUMN = 'tags.mlflow.log-model.history'

    def model_manager(self, df, target, inv_mppt_id, city):
        """
        Load the latest model when it is recent enough, otherwise train one.
        :param df: training data handed to ``PycaretUtil().get_auto_ml_model``
        :param target: target column handed to ``PycaretUtil().get_auto_ml_model``
        :param inv_mppt_id: suffix appended to ``MlFlow.run_name`` to name the run
        :param city: name of the top-level run the new runs are nested under
        :return: the loaded or newly trained model, None when anything failed
        """
        try:
            experiment_id = self.create_experiment(experiment_name=MlFlow.experiment_name)
            inverter_run_name = MlFlow.run_name + '_' + inv_mppt_id
            days, latest_run_id = self.fetch_latest_model(experiment_id=experiment_id,
                                                          run_name=inverter_run_name)
            # An unknown age or a missing run id is treated like a stale model.
            model_is_fresh = (days is not None and bool(latest_run_id)
                              and days < int(MlFlow.model_check_param))
            if model_is_fresh:
                logger.debug('Using the pretrained model !')
                energy_model = self.load_model_pyfunc(
                    model_path=self.forming_loading_path(latest_run_id=latest_run_id))
            else:
                city_run_id = self.creating_run(experiment_id=experiment_id,
                                                run_name=city)
                with mlflow.start_run(run_id=city_run_id):
                    inverter_run_id = self.creating_new_nested_run(experiment_id=experiment_id,
                                                                   run_id=city_run_id,
                                                                   run_name=inverter_run_name,
                                                                   nested=True)
                    nested_run_id = self.creating_new_nested_run(experiment_id=experiment_id,
                                                                 run_id=inverter_run_id,
                                                                 nested=True)
                    with mlflow.start_run(run_id=nested_run_id, nested=True):
                        logger.debug('Creating the new model !')
                        energy_model, model_name, metrics, hyper_params = \
                            PycaretUtil().get_auto_ml_model(df=df, target=target)
                        self.log_model(model=energy_model, model_name=MlFlow.model_name)
                        self.log_metrics(metrics=metrics)
                        self.log_hyper_param(hyperparameters=hyper_params)
                        self.set_tag(run_id=nested_run_id, key="algorithm", value=model_name)
            return energy_model
        except Exception as e:
            logger.exception(str(e))
            return None

    @staticmethod
    def create_experiment(experiment_name):
        """
        Function is to fetch the experiment by name, creating it when missing
        :param experiment_name: Name of the experiment
        :return: Experiment id, None when the lookup/creation failed
        """
        try:
            experiment = mlflow.get_experiment_by_name(experiment_name)
            if not experiment:
                # set_experiment creates the experiment (and makes it active).
                mlflow.set_experiment(experiment_name)
                experiment = mlflow.get_experiment_by_name(experiment_name)
            return experiment.experiment_id
        except Exception as e:
            logger.exception(str(e))
            return None

    @staticmethod
    def _latest_run_id_by_name(runs, run_name):
        """Return the id of the most recently started run called run_name, else None."""
        name_column = ModelLoad._RUN_NAME_COLUMN
        if runs.empty or name_column not in runs.columns:
            return None
        matches = runs[runs[name_column] == run_name]
        if matches.empty:
            return None
        if "start_time" in matches.columns:
            # Sort explicitly instead of relying on the order of the search result.
            matches = matches.sort_values("start_time", ascending=False, kind="mergesort")
        return matches["run_id"].iloc[0]

    @staticmethod
    def _start_new_run(experiment_id, run_name, nested, tags=None):
        """Create a run through the client, open and close it once, return its id."""
        created = client.create_run(experiment_id=experiment_id, tags=tags)
        with mlflow.start_run(experiment_id=experiment_id, run_id=created.info.run_id,
                              run_name=run_name, nested=nested) as run:
            return run.info.run_id

    @staticmethod
    def creating_run(experiment_id, run_id=None, run_name=None, nested=False):
        """
        Function is to reuse an existing run or create a new one
        :param experiment_id: Experiment Id
        :param run_id: run id to reuse when it exists in the experiment
        :param run_name: run name to look up (only used when run_id is not given)
        :param nested: passed on to mlflow.start_run when a run is created
        :return: run id, None when neither run_id nor run_name is given or on failure
        """
        try:
            if run_id:
                known_runs = mlflow.search_runs([experiment_id])
                if run_id in set(known_runs["run_id"]):
                    return run_id
                return ModelLoad._start_new_run(experiment_id, run_name, nested)
            if run_name:
                known_runs = mlflow.search_runs([experiment_id])
                existing_run_id = ModelLoad._latest_run_id_by_name(known_runs, run_name)
                if existing_run_id:
                    return existing_run_id
                return ModelLoad._start_new_run(
                    experiment_id, run_name, nested,
                    tags={"mlflow.runName": run_name, "mlflow.user": MlFlow.user})
            return None
        except Exception as e:
            logger.exception(str(e))
            return None

    @staticmethod
    def creating_new_nested_run(experiment_id, run_id=None, run_name=None, nested=False):
        """
        Function is to create a nested run
        :param experiment_id: Experiment Id
        :param run_id: id of the run that is opened as the parent
        :param run_name: name given to the new nested run
        :param nested: whether the parent itself is opened as a nested run
        :return : return nested run id, None on failure
        """
        try:
            with mlflow.start_run(experiment_id=experiment_id, run_id=run_id, nested=nested):
                with mlflow.start_run(experiment_id=experiment_id, nested=True, run_name=run_name) as run:
                    return run.info.run_id
        except Exception as e:
            logger.exception(str(e))
            return None

    @staticmethod
    def log_model(model, model_name):
        """
        Function is to log the model in the active run
        :param model : model
        :param model_name : artifact path the model is logged under
        :return: True when logged, False on failure
        """
        try:
            mlflow.sklearn.log_model(model, model_name)
            logger.info("logged the model")
            return True
        except Exception as e:
            logger.exception(str(e))
            return False

    @staticmethod
    def _clean_metric_keys(metrics):
        """Return a copy of metrics whose keys have all bracket characters removed."""
        return {ModelLoad._METRIC_KEY_JUNK.sub("", key): value
                for key, value in metrics.items()}

    @staticmethod
    def log_metrics(metrics):
        """
        Function is to log the metrics in the active run
        :param metrics: dict of metrics; brackets are stripped from the names
        :return: True when logged, False on failure
        """
        try:
            mlflow.log_metrics(ModelLoad._clean_metric_keys(metrics))
            logger.debug('logged the model metric')
            return True
        except Exception as e:
            logger.exception(str(e))
            return False

    @staticmethod
    def log_hyper_param(hyperparameters):
        """
        Function is to log the hyper params in the active run
        :param hyperparameters: dict of hyperparameters
        :return: True when logged, False on failure
        """
        try:
            mlflow.log_params(hyperparameters)
            logger.debug('logged model hyper parameters')
            return True
        except Exception as e:
            logger.exception(str(e))
            return False

    def fetch_latest_model(self, experiment_id, run_name):
        """
        Function is to fetch the latest run holding a model
        :param experiment_id: Experiment Id
        :param run_name: Name of the parent run whose child runs are inspected
        :return: (age, run id). age is expressed in the unit configured by
                 MlFlow.check_param. When no usable model is found the age is
                 int(MlFlow.model_check_param) + 1 and the run id is "", so the
                 caller sees the model as stale. The age is None only when the
                 threshold itself could not be read.
        """
        stale_age = None
        try:
            stale_age = int(MlFlow.model_check_param) + 1
            if not experiment_id:
                return stale_age, ""
            parent_run_id = self.get_parent_run_id(experiment_id, run_name)
            if not parent_run_id:
                return stale_age, ""
            run_info = mlflow.search_runs(
                [experiment_id],
                filter_string="tags.mlflow.parentRunId='{run_id}'".format(run_id=parent_run_id))
            # check_model_existing looks rows up by label, format_mlflow_time by
            # position; a fresh RangeIndex makes both refer to the same row.
            run_info = run_info.reset_index(drop=True)
            for ind in run_info.index:
                model_history, days, latest_run_id = self.check_model_existing(run_info=run_info,
                                                                               index=ind)
                if model_history is not None and days is not None:
                    return days, latest_run_id
            logger.info("No Model is existing with this experiment")
            return stale_age, ""
        except Exception as e:
            logger.exception(f"Exception while fetching the latest model  - {e}")
            return stale_age, ""

    @staticmethod
    def get_parent_run_id(experiment_id, run_name):
        """
        Function is to fetch the most recently started run with the given name
        :param experiment_id: Experiment Id
        :param run_name: Name of the run
        :return: run id, None when no run has that name or on failure
        """
        try:
            runs = mlflow.search_runs([experiment_id])
            result_run_id = ModelLoad._latest_run_id_by_name(runs, run_name)
            if result_run_id is None:
                logger.info(f"No Run is existing with this Experiment id - {experiment_id}")
            return result_run_id
        except Exception as e:
            logger.exception(f"Exception while fetching the latest run_id  - {e}")
            return None

    def check_model_existing(self, run_info, index):
        """
        Function is to check if a model was logged in one run
        :param run_info: Dataframe of run details
        :param index: row of the run; used as a label here and as a position by
                      format_mlflow_time, so run_info needs a default RangeIndex
        :return: (model_history, age, run id). model_history is None when the
                 run has no usable log-model history tag; all three are None
                 on failure
        """
        try:
            model_history = None
            # Difference between the current date and latest available model date
            days = self.format_mlflow_time(run_info=run_info, index=index,
                                           date_param=MlFlow.check_param)
            latest_run_id = run_info.loc[index, 'run_id']
            history_column = ModelLoad._MODEL_HISTORY_COLUMN
            if history_column in run_info:
                raw_history = run_info.loc[index, history_column]
                # Third ':'-separated field, cut at the first ','.
                fields = raw_history.split(":") if isinstance(raw_history, str) else []
                if len(fields) > 2:
                    model_history = fields[2].split(",")[0]
                else:
                    logger.info("No Model is existing")
            return model_history, days, latest_run_id
        except Exception as e:
            logger.exception(f"Exception while checking the model name  - {e}")
            return None, None, None

    @staticmethod
    def forming_loading_path(latest_run_id):
        """
        Function is to form the loading path
        :param latest_run_id: Run id
        :return : Return the loading path ("runs:/<run id>/<MlFlow.model_name>")
        """
        try:
            return f"runs:/{latest_run_id}/{MlFlow.model_name}"
        except Exception as e:
            logger.exception(f"Exception while forming loading path  - {e}")
            return None

    @staticmethod
    def _elapsed_in_unit(time_diff, date_param):
        """Express a timedelta in whole days, hours or minutes; None for any other unit."""
        unit = date_param.lower()
        if unit == "days":
            return int(time_diff.days)
        if unit == "hours":
            return int(time_diff.total_seconds() // 3600)
        if unit == "minutes":
            return int(time_diff.total_seconds() // 60)
        return None

    @staticmethod
    def format_mlflow_time(run_info, index, date_param):
        """
        Calculate how long ago a run ended
        :param run_info: details of the runs (needs an 'end_time' column)
        :param index: position of the run in the dataframe
        :param date_param: unit of the result: "days", "hours" or "minutes"
                           (case-insensitive)
        :return: whole units between the run's end time and now; None when the
                 run has no end time, the unit is not supported, or on failure
        """
        try:
            # utc=True keeps aware timestamps and reads naive ones as UTC.
            end_time = pd.to_datetime(run_info['end_time'].iloc[index], utc=True)
            if pd.isna(end_time):
                logger.info("Run has no end time")
                return None
            last_model_time = end_time.tz_convert(ReqTimeZone.required_tz).to_pydatetime()
            central_current = datetime.now(pytz.utc).astimezone(pytz.timezone(ReqTimeZone.required_tz))
            elapsed = ModelLoad._elapsed_in_unit(central_current - last_model_time, date_param)
            if elapsed is None:
                logger.info("No Valid Date format was given")
            return elapsed
        except Exception as e:
            logger.exception(f"Exception while formatting the mlflow time - {e}")
            return None

    @staticmethod
    def set_tag(run_id, key, value):
        """
        Function is to set the tag for a particular run
        :param run_id: Run id in which the tags need to be added
        :param key: Name of the key
        :param value: what needs to tagged in the value
        """
        try:
            client.set_tag(run_id=run_id, key=key, value=value)
            logger.debug('set the tag for the model')
        except Exception as e:
            logger.exception(f"Exception while setting the tag - {e}")

    @staticmethod
    def load_model_pyfunc(model_path):
        """
        Function is to load a logged model through mlflow.pyfunc
        :param model_path: path of the model
        :return: the loaded model, None on failure
        """
        try:
            model = mlflow.pyfunc.load_model(model_path)
            logger.info("loading the model")
            return model
        except Exception as e:
            logger.exception(str(e))
            return None