import pandas as pd
import numpy as np
from loguru import logger
from scripts.utils.pycaret_util import PycaretUtil
from scripts.utils.preprocessing import DataPreprocessing
from scripts.utils.mlflow_util import ModelLoad


class TrainingInference:
    """Train and run a per-(inverter, MPPT) model that predicts ``current_mppt``.

    ``df_train`` is read by :meth:`data_training` and ``df_test`` by
    :meth:`data_inference`; both are first narrowed to a single
    ``inv_id``/``mppt_id`` pair. ``df`` is stored but not read by the methods
    of this class.
    """

    # Columns handed to the model as inputs (before standardization).
    FEATURE_COLUMNS = ('datetime', 'inv_id', 'mppt_id', 'hour', 'tilt_irradiance', 'voltage_mppt')
    # Passed to the DataPreprocessing helpers as ``param_list``; presumably the
    # columns left unscaled -- confirm against DataPreprocessing.
    ID_COLUMNS = ('datetime', 'inv_id', 'mppt_id')
    # Column the model is trained to predict.
    TARGET_COLUMN = 'current_mppt'

    def __init__(self, df, df_train, df_test):
        self.df = df
        self.df_train = df_train
        self.df_test = df_test

    @staticmethod
    def _select_inv_mppt(df, inv_id, mppt_id):
        """Return the rows of *df* for one inverter/MPPT pair, re-indexed from 0.

        The input frame is not modified. Re-indexing makes frames derived from
        the selection line up positionally in ``pd.concat(axis=1)`` whether or
        not the preprocessing helpers preserve the incoming index.
        """
        mask = (df['inv_id'] == inv_id) & (df['mppt_id'] == mppt_id)
        return df.loc[mask].reset_index(drop=True)

    @staticmethod
    def _as_column(values):
        """Return *values* (array-like, Series or ndarray) as an ndarray of shape (n_samples, 1)."""
        return np.asarray(values).reshape(-1, 1)

    def data_training(self, inv_id, mppt_id):
        """Fit feature/target scalers and obtain a model for one inverter/MPPT pair.

        Args:
            inv_id: Value matched against ``df_train['inv_id']``.
            mppt_id: Value matched against ``df_train['mppt_id']``.

        Returns:
            ``(model, scaler_x, scaler_y)`` on success, where the scalers are the
            ones returned by ``DataPreprocessing.get_standardized_data`` for the
            feature frame and the single-column target frame. Returns ``None``
            if no training rows match the pair (a warning is logged) or if any
            step raises (the exception is logged, not propagated).
        """
        try:
            df_train_mppt = self._select_inv_mppt(self.df_train, inv_id, mppt_id)
            if df_train_mppt.empty:
                logger.warning(f'No training rows for inv_id={inv_id}, mppt_id={mppt_id}')
                return None

            x_train = df_train_mppt[list(self.FEATURE_COLUMNS)]
            y_train = df_train_mppt[[self.TARGET_COLUMN]]

            data_preprocessing = DataPreprocessing()
            # list(...) copies so a helper that mutates param_list cannot alter the class constant.
            x_train_std, scaler_x = data_preprocessing.get_standardized_data(df=x_train,
                                                                             param_list=list(self.ID_COLUMNS))
            y_train_std, scaler_y = data_preprocessing.get_standardized_data(df=y_train)

            df_std = (pd.concat([x_train_std, y_train_std], axis=1)
                      .dropna(axis=0)
                      .reset_index(drop=True))
            # The second value returned by model_manager is not needed here.
            model, _pre_trained = ModelLoad().model_manager(df=df_std, target=self.TARGET_COLUMN,
                                                            inv_mppt_id=f'{inv_id}_{mppt_id}')
            return model, scaler_x, scaler_y
        except Exception as e:
            logger.exception(f'Exception - {e}')
            return None

    def data_inference(self, scaler_x, scaler_y, model, inv_id, mppt_id):
        """Predict ``current_mppt`` for one inverter/MPPT pair of the test set.

        Args:
            scaler_x: Feature scaler returned by :meth:`data_training`.
            scaler_y: Target scaler returned by :meth:`data_training`; it was
                fitted on a single column, so it is always fed (n_samples, 1).
            model: Object exposing ``predict`` (from :meth:`data_training`).
            inv_id: Value matched against ``df_test['inv_id']``.
            mppt_id: Value matched against ``df_test['mppt_id']``.

        Returns:
            ``(x_test, y_test, predictions)`` on success: ``x_test`` is the
            unscaled feature frame re-indexed from 0, ``y_test`` is the target
            after a transform/inverse-transform round trip through ``scaler_y``,
            and ``predictions`` is an (n_samples, 1) ndarray in original units.
            Returns ``None`` if no test rows match the pair (a warning is
            logged) or if any step raises (the exception is logged).
        """
        try:
            df_test_mppt = self._select_inv_mppt(self.df_test, inv_id, mppt_id)
            if df_test_mppt.empty:
                logger.warning(f'No test rows for inv_id={inv_id}, mppt_id={mppt_id}')
                return None

            x_test = df_test_mppt[list(self.FEATURE_COLUMNS)]
            y_test = df_test_mppt[[self.TARGET_COLUMN]]

            data_preprocessing = DataPreprocessing()
            x_test_std = data_preprocessing.get_transform_std_data(df=x_test,
                                                                   param_list=list(self.ID_COLUMNS),
                                                                   scaler=scaler_x)
            y_test_std = data_preprocessing.get_transform_std_data(df=y_test, scaler=scaler_y)

            # scaler_y knows one feature: pass n samples x 1 feature, not 1 sample x n features.
            predictions_std = self._as_column(model.predict(x_test_std))
            predictions = self._as_column(scaler_y.inverse_transform(predictions_std))
            y_test = scaler_y.inverse_transform(y_test_std)
            return x_test, y_test, predictions
        except Exception as e:
            logger.exception(f'Exception - {e}')
            return None
