from pycaret import regression
from loguru import logger


class GetBestModel:
    """
    Compare a set of PyCaret regression models, tune the top-ranked one and report it.

    :param df: training data passed as ``data`` to ``regression.setup``
    :param target_col_list: value passed as ``target`` to ``regression.setup``
    :param list_of_models: models passed as ``include`` to ``regression.compare_models``
    :param no_of_models: ``n_select`` for ``compare_models``; only the top-ranked model is tuned
    """

    def __init__(self, df, target_col_list, list_of_models, no_of_models=1):
        self.df = df
        self.target_col_list = target_col_list
        self.no_of_models = no_of_models
        self.list_of_models = list_of_models

    @staticmethod
    def _top_model(compared):
        """Return the single top-ranked estimator from a ``compare_models`` result."""
        # With n_select > 1 compare_models returns a list ordered best-first,
        # but tune_model accepts only one estimator.
        if isinstance(compared, (list, tuple)):
            if not compared:
                raise ValueError("compare_models returned no models")
            return compared[0]
        return compared

    @staticmethod
    def _summary_metrics(results):
        """Return the aggregate metrics row of a PyCaret scoring grid as a dict."""
        # tune_model's grid has one row per CV fold followed by summary rows such
        # as 'Mean'; row 0 is just the first fold, so prefer 'Mean' when present.
        row = results.loc['Mean'] if 'Mean' in results.index else results.iloc[0]
        metrics = row.to_dict()
        # Drop a 'Model' name column if the grid carries one; it is not a metric.
        metrics.pop('Model', None)
        return metrics

    def compare_get_best_model(self, fine_tune_tech, comparison_metric):
        """
        Train and compare the configured models, then tune the best one.

        :param fine_tune_tech: search library for fine-tuning the selected model
            (forwarded as ``search_library`` to ``tune_model``)
        :param comparison_metric: metric used both to rank the models and to optimize tuning
        :return: ``(tuned_model, model_name, best_metrics, hyper_params)``, or ``None``
            if any step fails (the error is logged together with its traceback)
        """
        try:
            logger.info("Using Pycaret to train mentioned models")
            regression.setup(data=self.df, target=self.target_col_list)
            logger.info(f"Selecting the best model using the metric {comparison_metric}")
            compared = regression.compare_models(include=self.list_of_models, sort=comparison_metric,
                                                 n_select=self.no_of_models)
            best_model = self._top_model(compared)
            logger.info("Tuning the Model")
            tuned_model = regression.tune_model(best_model, optimize=comparison_metric,
                                                search_library=fine_tune_tech)
            best_metrics = self._summary_metrics(regression.pull())
            # The class name is reliable even when str(model) is not "Name(params...)".
            model_name = type(tuned_model).__name__
            hyper_params = tuned_model.get_params()
            logger.info("Model Training Completed")
            return tuned_model, model_name, best_metrics, hyper_params
        except Exception as e:
            # Keep the best-effort contract (return None) but record the full traceback.
            logger.exception(f"Unable to select the best model - {e}")
            return None
