import os
import re
from datetime import datetime

import mlflow
import pandas as pd
import pytz
from dateutil import tz
from loguru import logger
from azure.storage.blob import BlobServiceClient

from scripts.constants.app_configuration import REQUIRED_TZ, MlflowMetaData

# Resolve tracking/storage settings from the application configuration once, at import time.
mlflow_tracking_uri = MlflowMetaData.MLFLOW_TRACKING_URI
AZURE_STORAGE_CONNECTION_STRING = MlflowMetaData.AZURE_STORAGE_CONNECTION_STRING
AZURE_STORAGE_ACCESS_KEY = MlflowMetaData.AZURE_STORAGE_ACCESS_KEY

# Credentials are exported through process environment variables rather than passed explicitly.
# NOTE(review): presumably MLflow's REST client (basic auth) and the Azure artifact store read
# these variables — confirm against the deployed MLflow / azure-storage-blob versions.
# os.environ only accepts str values, so a missing (None) setting makes this module fail to import.
os.environ["MLFLOW_TRACKING_USERNAME"] = MlflowMetaData.MLFLOW_TRACKING_USERNAME
os.environ["MLFLOW_TRACKING_PASSWORD"] = MlflowMetaData.MLFLOW_TRACKING_PASSWORD
os.environ["AZURE_STORAGE_CONNECTION_STRING"] = AZURE_STORAGE_CONNECTION_STRING
os.environ["AZURE_STORAGE_ACCESS_KEY"] = AZURE_STORAGE_ACCESS_KEY
# Point both the tracking server and the model registry at the same URI.
mlflow.set_tracking_uri(mlflow_tracking_uri)
mlflow.set_registry_uri(mlflow_tracking_uri)
# Module-wide client (used by MlFlowUtil.set_tag); created after the URIs above are configured.
client = mlflow.tracking.MlflowClient()


class MlFlowUtil:
    """Thin helpers around MLflow logging/tagging and Azure Blob artifact deletion.

    The ``log_*`` helpers and ``set_tag`` return ``True`` on success and ``False`` after
    logging the exception, so callers can branch on the result without their own ``try``.
    """

    @staticmethod
    def get_last_run_time_diff(run_info):
        """Return the age, in whole days, of the run described by the first row of ``run_info``.

        :param run_info: DataFrame with an ``end_time`` column holding tz-aware timestamps.
            Only the first row is used, so callers should order it newest-first
            (presumably a frame from ``mlflow.search_runs`` — confirm at the call site).
            The frame is not modified.
        :return: number of full days between that run's end time and now, or ``0`` if the age
            cannot be computed (e.g. empty frame); the error is logged as a warning.
        """
        try:
            logger.info("Checking the time difference in days")
            # tz_convert requires tz-aware values; a naive column raises and falls through to 0.
            end_times = pd.to_datetime(run_info['end_time']).dt.tz_convert(REQUIRED_TZ)
            last_model_time = end_times.iloc[0].to_pydatetime()
            now_local = datetime.now(pytz.utc).astimezone(tz.gettz(REQUIRED_TZ))
            return int((now_local - last_model_time).days)
        except Exception as e:
            logger.warning(f"Exception while checking the last run time of the model - {e}")
            return 0

    @staticmethod
    def log_model(model, model_name):
        """Log ``model`` to the active run with the sklearn flavour under artifact path ``model_name``.

        :return: ``True`` on success, ``False`` if logging failed (the exception is logged).
        """
        try:
            mlflow.sklearn.log_model(model, model_name)
            logger.info("logged the model")
            return True
        except Exception as e:
            logger.exception(str(e))
            return False

    @staticmethod
    def log_metrics(metrics):
        """Log a ``{name: value}`` mapping of metrics to the active run.

        Parentheses, square brackets and braces are stripped from every metric name first,
        because MLflow restricts the characters allowed in metric names. If two names become
        identical after stripping, the later one wins.

        :return: ``True`` on success, ``False`` if logging failed (the exception is logged).
        """
        try:
            sanitized = {re.sub(r"[([{})\]]", "", key): value for key, value in metrics.items()}
            mlflow.log_metrics(sanitized)
            return True
        except Exception as e:
            logger.exception(str(e))
            return False

    @staticmethod
    def log_hyper_param(hyper_params):
        """Log a ``{name: value}`` mapping of hyper-parameters to the active run.

        :return: ``True`` on success, ``False`` if logging failed (the exception is logged).
        """
        try:
            mlflow.log_params(hyper_params)
            return True
        except Exception as e:
            logger.exception(str(e))
            return False

    @staticmethod
    def set_tag(child_run_id, key, value):
        """Set tag ``key=value`` on run ``child_run_id`` via the module-level MLflow client.

        :return: ``True`` on success, ``False`` if tagging failed (the exception is logged).
        """
        try:
            client.set_tag(run_id=child_run_id, key=key, value=value)
            return True
        except Exception as e:
            logger.exception(f"Exception while setting the tag - {e}")
            return False

    @staticmethod
    def remove_file_if_exists(path):
        """Delete the file at ``path``; silently do nothing if it does not exist."""
        try:
            os.remove(path)
        except FileNotFoundError:
            # Already absent (or removed concurrently) - the desired end state.
            pass

    @staticmethod
    def delete_artifact(run_id, artifact_uri, file_path):
        """Delete ``<file_path>/requirements.txt`` from a run's artifacts in Azure Blob Storage.

        Only that single ``requirements.txt`` blob is removed, not the whole ``file_path`` folder.
        ``artifact_uri`` is parsed as ``<scheme>://<container>@<host>/<p1>/<p2>/...`` and the blob
        path is rebuilt as ``<p1>/<p2>/<run_id>/artifacts/<file_path>/requirements.txt``
        (layout assumed from the parsing below — TODO confirm against the artifact store).

        Errors from URI parsing or the Azure SDK propagate to the caller.
        """
        location = artifact_uri.split("//")[-1]
        container_name = location.split('@')[0]
        path_parts = location.split('@')[-1].split('/')
        mlflow_name, mlflow_id = path_parts[1], path_parts[2]
        path = f'{mlflow_name}/{mlflow_id}/{run_id}/artifacts/{file_path}/requirements.txt'
        logger.info(f'Deleting artifact from path: {path}')
        # The context manager closes the client's HTTP transport once the delete has completed.
        with BlobServiceClient.from_connection_string(
                MlflowMetaData.AZURE_STORAGE_CONNECTION_STRING) as blob_service_client:
            container_client = blob_service_client.get_container_client(container_name)
            blob_client = container_client.get_blob_client(path)
            blob_client.delete_blob()


class MlflowCleanUp:
    """Discover the runs of an MLflow experiment, and decide whether an existing model needs retraining.

    NOTE: the artifact deletion step of the cleanup is not implemented yet (see
    ``check_under_parent_run``); ``start_cleanup`` currently only discovers and logs runs.
    """

    def __init__(self, experiment_name, parent_run_name, model_save_name):
        """
        :param experiment_name: name of the MLflow experiment to inspect
        :param parent_run_name: parent run name (stored; not used by the methods of this class)
        :param model_save_name: artifact path the model is logged under inside a run
        Conversion errors propagate if ``MlflowMetaData.MODEL_AGE_IN_DAYS`` is not numeric.
        """
        self.experiment_name = experiment_name
        self.parent_run_name = parent_run_name
        self.model_save_name = model_save_name
        self._mfu_ = MlFlowUtil()
        # Age limit (in days) at or beyond which an existing model must be retrained.
        self.model_age = int(MlflowMetaData.MODEL_AGE_IN_DAYS)

    def check_experiment(self):
        """Activate the configured experiment if it exists.

        :return: the experiment id, or ``None`` if no experiment with that name exists.
        """
        experiment_info = mlflow.get_experiment_by_name(self.experiment_name)
        if experiment_info is None:
            logger.info(f"No experiment found with name {self.experiment_name}")
            return None
        logger.info(f"Proceeding with existing Experiment {self.experiment_name}")
        # Makes it the active experiment for subsequent fluent-API calls.
        mlflow.set_experiment(experiment_name=self.experiment_name)
        return experiment_info.experiment_id

    @staticmethod
    def check_parent_run(experiment_id):
        """Find the top-level (parent) runs of an experiment.

        :param experiment_id: id of the experiment to search
        :return: ``{'parent_runs': [run_id, ...], 'df': runs_df}``, where ``runs_df`` holds every run
            of the experiment and its ``tags.mlflow.parentRunId`` column is ``'parent'`` for
            top-level runs; ``None`` if the experiment has no runs or no top-level runs.
        """
        parent_key = 'tags.mlflow.parentRunId'
        runs_df = mlflow.search_runs(experiment_id)
        if runs_df.empty:
            logger.info('No runs found for the experiment...')
            return None
        if parent_key in runs_df.columns:
            # Assign back instead of ``fillna(inplace=True)`` on a column selection: under pandas
            # Copy-on-Write that chained in-place call would leave runs_df unchanged.
            runs_df[parent_key] = runs_df[parent_key].fillna('parent')
        else:
            # The tag column is only present when at least one run carries it, i.e. is nested;
            # without it every run is top-level.
            runs_df[parent_key] = 'parent'
        parent_runs = list(runs_df.loc[runs_df[parent_key] == 'parent', 'run_id'])
        if not parent_runs:
            logger.info('No parent runs found for the experiment')
            return None
        logger.info('Parent runs found for the experiment')
        return {'parent_runs': parent_runs, 'df': runs_df}

    def get_nested_runs(self, exp_id, parent_run_id):
        """Recursively collect every descendant run of ``parent_run_id``.

        :return: dict mapping each direct child run id to the same structure for that child's
            own children, or to ``None`` when the child has no children. ``parent_run_id`` itself
            is not a key; an empty dict means it has no children. One ``mlflow.search_runs`` call
            is made per visited run.
        """
        child_runs_dict = {}
        child_run_ids = mlflow.search_runs([exp_id], f"tags.mlflow.parentRunId = '{parent_run_id}'")["run_id"]
        for child_run_id in child_run_ids:
            nested_child_runs = self.get_nested_runs(exp_id, child_run_id)
            child_runs_dict[child_run_id] = nested_child_runs or None
        return child_runs_dict

    def check_under_parent_run(self, experiment_id, parent_run_id, df):
        """Collect and log the runs nested under ``parent_run_id``.

        :param experiment_id: experiment the runs belong to
        :param parent_run_id: top-level run whose descendants are collected
        :param df: runs DataFrame from ``check_parent_run``; currently unused, kept for the
            pending deletion step, which needs each run's ``artifact_uri``
        :return: the nested child-run dict produced by ``get_nested_runs``
        """
        logger.info(f"Getting all runs under parent run {parent_run_id}")
        child_runs = self.get_nested_runs(experiment_id, parent_run_id)
        logger.debug(f"Runs under parent run {parent_run_id}: {child_runs}")
        # TODO: artifact deletion is not enabled yet. The intended step is, for each child run id,
        # to look up its 'artifact_uri' in ``df`` and call
        # ``self._mfu_.delete_artifact(run_id, artifact_uri, self.model_save_name)``.
        return child_runs

    def start_cleanup(self):
        """Walk every parent run of the configured experiment and process its nested runs.

        :return: ``True`` once all parent runs have been processed; ``False`` if the experiment
            does not exist or has no parent runs.
        """
        experiment_id = self.check_experiment()
        if experiment_id is None:
            logger.info("Not a valid experiment...")
            return False
        parent_runs_dict = self.check_parent_run(experiment_id)
        if parent_runs_dict is None:
            logger.info('No parent runs found for experiment, so no cleanup')
            return False
        parent_runs = parent_runs_dict['parent_runs']
        df = parent_runs_dict['df']
        logger.info(f'Total parent runs found are {len(parent_runs)}')
        for run in parent_runs:
            self.check_under_parent_run(experiment_id, run, df)
        return True

    def check_existing_model_retrain(self, latest_child_run_id, child_run_info, retrain):
        """
        If retrain is True, it returns true as retraining is required.
        If retrain is False, it checks the age of the last child run and returns True when that age
        in whole days is at least ``self.model_age``, otherwise False.
        :param latest_child_run_id: last child run id
        :param child_run_info: last child run info
        :param retrain: retrain flag
        :return: final retrain flag
        """
        if retrain:
            logger.info("Retraining Needed...")
            return True
        logger.info(f"Already trained model is present, checking the age of the existing model of run id "
                    f"{latest_child_run_id}")
        time_diff = self._mfu_.get_last_run_time_diff(child_run_info)
        # get_last_run_time_diff returns 0 when the age cannot be determined, so in that case the
        # existing model is kept (unless model_age is 0).
        if time_diff >= self.model_age:
            logger.info(f"Existing model is {time_diff} day(s) old (limit {self.model_age}), Retraining Needed...")
            return True
        logger.info(f"Existing model is {time_diff} day(s) old (limit {self.model_age}), reusing it")
        return False

    def forming_loading_path(self, latest_run_id):
        """
        Creates the ``runs:/`` model URI from the child run id
        :param latest_run_id: latest child run id
        :return: ``runs:/<latest_run_id>/<model_save_name>``, or None if it could not be formed
        """
        try:
            return f"runs:/{latest_run_id}/{self.model_save_name}"
        except Exception as e:
            logger.exception(f"Exception while forming loading path  - {e}")
            return None
