import traceback
from scripts.common.config_parser import config
from scripts.common.constants import ModelObjectConstants, ComponentExceptions
from scripts.common.logsetup import logger
from sklearn.ensemble import RandomForestRegressor
import pickle
import os


class RandomForest:
    """Build an unfitted ``RandomForestRegressor`` from config and pickle it.

    The estimator is only constructed here (``fit`` is never called) and is
    written to ``<component_out_dir>/random_forest.pkl``.
    """

    def __init__(self, model_object, component_out_dir, query):
        """Read the Random Forest hyper-parameters from ``query``.

        Args:
            model_object: Model name; used only in log messages by this class.
            component_out_dir: Directory that receives the pickle file.
            query: Mapping providing the keys ``n_estimators``,
                ``min_samples_split``, ``min_samples_leaf``, ``max_features``,
                ``max_depth``, ``bootstrap`` and ``criterion``. Extra keys are
                ignored.

        Raises:
            KeyError: If ``query`` lacks one of the hyper-parameter keys.
        """
        self.model_object = model_object
        self.component_out_dir = component_out_dir
        self.query = query
        self.n_estimators = self.query['n_estimators']
        self.min_samples_split = self.query['min_samples_split']
        self.min_samples_leaf = self.query['min_samples_leaf']
        self.max_features = self.query['max_features']
        self.max_depth = self.query['max_depth']
        self.bootstrap = self.query['bootstrap']
        self.criterion = self.query['criterion']
        # Snapshot of the hyper-parameters; random_forest() logs it.
        self.parameters = {
            "n_estimators": self.n_estimators,
            "min_samples_split": self.min_samples_split,
            "min_samples_leaf": self.min_samples_leaf,
            "max_features": self.max_features,
            "max_depth": self.max_depth,
            "bootstrap": self.bootstrap,
            "criterion": self.criterion,
        }

    def random_forest(self):
        """Construct the regressor and pickle it into ``component_out_dir``.

        Returns:
            True once the pickle file has been written.

        Raises:
            ValueError: Wrapping any exception raised while building or
                pickling the model; the original exception is chained as
                ``__cause__``.
        """
        try:
            # str() so a non-string model name cannot turn logging into a TypeError.
            logger.info("Creating Model object for the model " + str(self.model_object))
            logger.info("These are the parameters used " + str(self.parameters))
            rf = RandomForestRegressor(n_estimators=self.n_estimators,
                                       min_samples_split=self.min_samples_split,
                                       min_samples_leaf=self.min_samples_leaf,
                                       max_features=self.max_features,
                                       max_depth=self.max_depth,
                                       bootstrap=self.bootstrap,
                                       criterion=self.criterion)
            filename = 'random_forest.pkl'
            logger.info("Pickling the model.....")
            # The context manager flushes and closes the handle even if dump() fails;
            # the previous inline open() leaked the file object.
            with open(os.path.join(self.component_out_dir, filename), 'wb') as pickle_file:
                pickle.dump(rf, pickle_file)
            return True
        except Exception as e:
            logger.error(traceback.format_exc())
            raise ValueError(e) from e


if __name__ == '__main__':
    # Script entry point: build and pickle the Random Forest only when the
    # configured model is a random forest, then sanity-check the output dir.
    try:
        if config['model_name'] == ModelObjectConstants.RANDOMFOREST:
            # Construct inside the check: the constructor indexes RF-specific
            # keys, which other model configs need not provide.
            obj = RandomForest(config['model_name'], config['component_output_dir'], config)
            if obj.random_forest():
                if os.listdir(config['component_output_dir']):
                    logger.info("File created Successfully")
                else:
                    logger.info("The output directory is empty")
            else:
                # random_forest() as written returns True or raises, so this is
                # defensive. There is no active exception here, so no traceback
                # is logged (format_exc() would only yield "NoneType: None").
                logger.error("Random Forest Component Failed")
    except Exception:
        logger.error("Random Forest Component Failed")
        logger.error(traceback.format_exc())
