from loguru import logger
from pycaret import regression
from scripts.constants.app_configuration import MlFlow, PycaretParams


class PycaretUtil:
    """Helper around ``pycaret.regression`` that selects, tunes and describes a model.

    Settings are read from ``PycaretParams`` when the instance is created:
    the candidate model ids, the metric used for ranking/tuning, and the
    hyperparameter search library passed to ``tune_model``.
    """

    # Regression metrics for which a smaller value is better. Any other metric
    # (e.g. R2) is treated as higher-is-better when rows must be ranked.
    _LOWER_IS_BETTER = frozenset({"MAE", "MSE", "RMSE", "RMSLE", "MAPE"})

    def __init__(self):
        # Comma-separated ids such as "lr, rf, et": strip surrounding spaces and
        # drop empty entries (e.g. from a trailing comma) so every id is usable.
        self.model_list = [name.strip() for name in PycaretParams.model_list.split(",") if name.strip()]
        self.selected_metric = PycaretParams.selected_metric
        self.hyperparameter_tuning_method = PycaretParams.hyperparameter_tuning_method

    def get_best_model(self, df, target):
        """Pick the best candidate model for ``df``/``target`` and tune it.

        Runs ``regression.setup`` on the data, compares the models in
        ``self.model_list`` ranked by ``self.selected_metric``, then tunes the
        winner with ``self.hyperparameter_tuning_method`` as the search library.

        Args:
            df: Training data passed to ``regression.setup``.
            target: Name of the target column in ``df``.

        Returns:
            ``(tuned_model, metrics)``, where ``metrics`` is a ``{metric: value}``
            dict taken from the score grid that ``regression.pull()`` returns
            after tuning. Returns ``None`` if any step raises; the exception is
            logged.
        """
        try:
            regression.setup(data=df, target=target)
            best_model = regression.compare_models(include=self.model_list, sort=self.selected_metric,
                                                   n_select=1)
            tuned_model = regression.tune_model(best_model, optimize=self.selected_metric,
                                                search_library=self.hyperparameter_tuning_method)
            results = regression.pull()
            return tuned_model, self._summary_metrics(results)
        except Exception as e:
            logger.exception(f'Exception - {e}')
            return None

    def _summary_metrics(self, results):
        """Reduce a PyCaret score grid to a single ``{metric: value}`` dict."""
        # NOTE(review): after tune_model, PyCaret's pull() is expected to hold the
        # per-fold CV scores plus "Mean"/"Std" summary rows; the Mean row is the
        # representative score. Confirm against the pinned PyCaret version.
        if "Mean" in results.index:
            row = results.loc["Mean"]
        else:
            # Fallback: rank the rows, honouring whether smaller is better for
            # the metric. Avoid inplace sorting so pull()'s frame isn't mutated.
            ascending = str(self.selected_metric).upper() in self._LOWER_IS_BETTER
            row = results.sort_values(self.selected_metric, ascending=ascending).iloc[0]
        metrics = row.to_dict()
        # A compare_models-style grid carries a non-numeric "Model" column.
        metrics.pop('Model', None)
        return metrics

    @staticmethod
    def get_model_name(model):
        """Return the class name of ``model`` (e.g. ``"LGBMRegressor"``).

        Uses the object's type rather than parsing ``str(model)``. Parsing only
        works for objects whose repr starts with ``ClassName(``; objects with a
        default ``<module.Class object at 0x...>`` repr would return the whole
        string.
        """
        return type(model).__name__

    def get_auto_ml_model(self, df, target):
        """Run AutoML selection/tuning and describe the resulting model.

        Args:
            df: Training data.
            target: Name of the target column in ``df``.

        Returns:
            ``(model, model_name, metrics, hyper_params)``, where
            ``hyper_params`` comes from ``model.get_params()``. Returns ``None``
            if model selection fails or any later step raises; the failure is
            logged.
        """
        try:
            best = self.get_best_model(df=df, target=target)
            if best is None:
                # get_best_model already logged the underlying exception.
                logger.error("Unable to get the AutoML model - model selection failed")
                return None
            model, metrics = best
            model_name = self.get_model_name(model)
            hyper_params = model.get_params()
            return model, model_name, metrics, hyper_params
        except Exception as e:
            logger.exception(f"Unable to get the AutoML model - {e}")
            return None
