import os

import mlflow
from azure.storage.blob import BlobServiceClient
from loguru import logger

from scripts.constants.app_configuration import MlflowMetaData
from scripts.constants.app_constants import MODEL_NAME

# Connection settings taken from the application configuration.
mlflow_tracking_uri = MlflowMetaData.MLFLOW_TRACKING_URI
AZURE_STORAGE_CONNECTION_STRING = MlflowMetaData.AZURE_STORAGE_CONNECTION_STRING
AZURE_STORAGE_ACCESS_KEY = MlflowMetaData.AZURE_STORAGE_ACCESS_KEY

# Export credentials as environment variables, presumably so the MLflow client and
# its Azure artifact store can pick them up. Keys are set one by one, in this order.
os.environ.update({
    "MLFLOW_TRACKING_USERNAME": MlflowMetaData.MLFLOW_TRACKING_USERNAME,
    "MLFLOW_TRACKING_PASSWORD": MlflowMetaData.MLFLOW_TRACKING_PASSWORD,
    "AZURE_STORAGE_CONNECTION_STRING": AZURE_STORAGE_CONNECTION_STRING,
    "AZURE_STORAGE_ACCESS_KEY": AZURE_STORAGE_ACCESS_KEY,
})

# The tracking server also serves as the model registry.
mlflow.set_tracking_uri(mlflow_tracking_uri)
mlflow.set_registry_uri(mlflow_tracking_uri)
client = mlflow.tracking.MlflowClient()


class MlFlowUtil:
    """Helpers for removing MLflow artifacts stored in Azure Blob Storage."""

    @staticmethod
    def delete_artifact(run_id, parent_run_name, artifact_uri, file_path, model_name):
        """Delete a single model blob of a run if it still exists.

        The blob location is derived from ``artifact_uri``, assumed to have the shape
        ``<scheme>://<container>@<host>/<seg1>/<seg2>/...``:

        * the container is the text between ``//`` and ``@``;
        * the blob path is ``<seg1>/<seg2>/<run_id>/artifacts/<file_path>/<model_name>``.

        :param run_id: id of the run whose artifact is removed.
        :param parent_run_name: name of the parent run (used only for logging).
        :param artifact_uri: the run's artifact URI as reported by MLflow.
        :param file_path: artifact sub-directory the model was logged under.
        :param model_name: final path component of the blob to delete.
        """
        logger.info(f"Deleting artifact for {run_id} under {parent_run_name}")

        # Drop the scheme, then split "<container>@<host>/<path...>".
        location = artifact_uri.split("//")[-1]
        container_name = location.split('@')[0]
        segments = location.split('@')[-1].split('/')
        # segments[0] is the storage host; the next two components prefix the blob path.
        prefix, experiment_part = segments[1], segments[2]
        blob_path = f'{prefix}/{experiment_part}/{run_id}/artifacts/{file_path}/{model_name}'
        logger.info(f'Deleting artifact from path: {blob_path}')

        blob_client = (
            BlobServiceClient.from_connection_string(MlflowMetaData.AZURE_STORAGE_CONNECTION_STRING)
            .get_container_client(container_name)
            .get_blob_client(blob_path)
        )
        if not blob_client.exists():
            logger.info(f"The blob {blob_path} does not exist, which means its already deleted")
            return
        logger.info(f"The blob {blob_path} exists, so deleting it")
        blob_client.delete_blob()


class MlflowCleanUp:
    """Prune old model artifacts from the child runs of an MLflow experiment.

    For every parent run in the experiment, only child runs that logged a model are
    considered. The newest ``TOTAL_MODELS_NEEDED`` of them (by ``start_time``) are kept.
    The model artifact of every older one is deleted via ``MlFlowUtil.delete_artifact``.
    """

    # search_runs column holding the run-name tag.
    _RUN_NAME_KEY = 'tags.mlflow.runName'

    def __init__(self, experiment_name, parent_run_name, model_save_name):
        """
        :param experiment_name: name of the MLflow experiment to clean up.
        :param parent_run_name: stored on the instance; not used by the methods of this class.
        :param model_save_name: artifact sub-directory the model was logged under
            (passed as ``file_path`` to ``MlFlowUtil.delete_artifact``).
        """
        self.experiment_name = experiment_name
        self.parent_run_name = parent_run_name
        self.model_save_name = model_save_name
        self._mfu_ = MlFlowUtil()
        # How many of the most recent models to keep per parent run.
        self.total_models_needed = int(MlflowMetaData.TOTAL_MODELS_NEEDED)
        self.model_name = MODEL_NAME
        # search_runs columns for the model-history and parent-run-id tags.
        self.model_history_key = 'tags.mlflow.log-model.history'
        self.model_parent_run_id_key = 'tags.mlflow.parentRunId'

    def check_experiment(self):
        """Look up ``self.experiment_name`` and make it the active experiment.

        :return: the experiment id, or ``None`` if no experiment with that name exists.
        """
        experiment_info = mlflow.get_experiment_by_name(self.experiment_name)
        if experiment_info is None:
            logger.info(f"No experiment found with name {self.experiment_name}")
            return None
        logger.info(f"Proceeding with existing Experiment {self.experiment_name}")
        # Activate the experiment. Its metadata is already in hand, so no second lookup is needed.
        mlflow.set_experiment(experiment_name=self.experiment_name)
        return experiment_info.experiment_id

    @staticmethod
    def check_runs_data(experiment_id):
        """Return all runs of ``experiment_id`` as a DataFrame, or ``None`` if there are none."""
        runs_df = mlflow.search_runs(experiment_id)
        if runs_df.empty:
            logger.info('No runs found for the experiment...')
            return None
        return runs_df

    def delete_run_model_data(self, df, run_name_mapping):
        """Delete model artifacts of all but the newest child runs of each parent run.

        :param df: runs DataFrame (as returned by ``mlflow.search_runs``) containing the
            parent-run-id and model-history tag columns.
        :param run_name_mapping: ``{run_id: run_name}``, used for log messages only.
        """
        cols = ['run_id', 'artifact_uri', 'start_time', 'end_time',
                self.model_parent_run_id_key, self.model_history_key]
        for parent_run_id in df[self.model_parent_run_id_key].dropna().unique():
            # The parent may be missing from the mapping (e.g. a deleted parent is not
            # returned by search_runs). Fall back to its id rather than raising KeyError.
            run_name = run_name_mapping.get(parent_run_id, parent_run_id)
            logger.info(f'Checking for Run Name {run_name}')
            temp_df = df[df[self.model_parent_run_id_key] == parent_run_id]
            temp_df = temp_df[temp_df[self.model_history_key].notna()]
            if temp_df.empty:
                logger.info(f'Nothing to cleanup for Run {run_name}')
                continue

            # Newest first, so the leading `total_models_needed` rows are the ones kept.
            # search_runs already orders by start_time DESC by default. Sorting here makes
            # the retention policy independent of how `df` was built. The stable sort
            # preserves the incoming order for ties.
            temp_df = temp_df[cols].sort_values('start_time', ascending=False, kind='mergesort')
            logger.info(f'Total models present are {len(temp_df)}')
            to_delete = temp_df.iloc[self.total_models_needed:]
            logger.info(f'Total models to cleanup are {len(to_delete)}')

            records = to_delete.to_dict('records')
            if not records:
                logger.info(f'No records to cleanup for Run {run_name}')
                continue
            for record in records:
                self._mfu_.delete_artifact(record['run_id'], run_name, record['artifact_uri'],
                                           self.model_save_name, self.model_name)

    def start_cleanup(self):
        """Run the cleanup for ``self.experiment_name``, logging and returning early if there is nothing to do."""
        experiment_id = self.check_experiment()
        if experiment_id is None:
            logger.info("Not a valid experiment...")
            return

        runs_df = self.check_runs_data(experiment_id)
        # Check for None before touching `.columns`, because check_runs_data returns
        # None when the experiment has no runs.
        if runs_df is None:
            logger.info('No runs found for experiment, so no cleanup')
            return
        if self.model_parent_run_id_key not in runs_df.columns:
            logger.info('No parent runs found for experiment, so no cleanup')
            return
        # search_runs only emits a tag column if at least one run carries that tag.
        if self.model_history_key not in runs_df.columns:
            logger.info('No logged models found for experiment, so no cleanup')
            return

        if self._RUN_NAME_KEY in runs_df.columns:
            run_names = runs_df[self._RUN_NAME_KEY]
        else:
            run_names = runs_df['run_id']
        run_name_mapping = dict(zip(runs_df['run_id'], run_names))

        # Child runs (those with a parent id) that have a model logged.
        child_runs = runs_df[runs_df[self.model_parent_run_id_key].notna()]
        model_runs = child_runs[child_runs[self.model_history_key].notna()]
        self.delete_run_model_data(model_runs, run_name_mapping)
