import os
import warnings

import mlflow
import numpy as np
import pandas as pd
from sklearn import metrics
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.metrics import *
from sklearn.model_selection import train_test_split

from scripts.core.model_loader import ModelLoader
from scripts.section_utils.mlflow_util import ModelLoaderSaver

warnings.filterwarnings("ignore")


# Aggregated input tables in merge order, each paired with the file (written
# to the current working directory) that receives its ``describe()`` summary.
_SOURCE_TABLES = (
    ("sheet-agg.csv", "sheet_desc.csv"),
    ("mixer-agg.csv", "mixer_desc.csv"),
    ("extruder-agg.csv", "extru_desc.csv"),
    ("bof-agg.csv", "bof_desc.csv"),
    ("pickup-agg.csv", "pickup_desc.csv"),
    ("viscosity-agg.csv", "visc_desc.csv"),
)

# Columns in which a value of 0 is treated as a missing reading and replaced
# by the column mean of the non-zero, non-missing values.
_ZERO_AS_MISSING_COLS = (
    'Weight_type1', 'Weight_type2',
    'Weighted_PO_type', 'Weighted_DIRT_type', 'Weighted_ASH_type',
    'Weighted_VM_type', 'Weighted_PRI_type', 'Weighted_NITROGEN_type',
    'Weighted_Temperature during transportation_type[℃]',
    'Weighted_Humidity during transportation__type[%]', 'Weighted Sum',
    'viscosity',
)


def _load_and_merge(data_dir):
    """Read the aggregated CSVs from *data_dir*, dump summaries, and left-join them.

    Each table's ``describe()`` output is written to the working directory.
    The remaining tables are then left-joined onto the sheet table on
    ``'batch-date'``, so every sheet row is kept.

    Returns:
        The merged DataFrame.
    """
    frames = [pd.read_csv(os.path.join(data_dir, csv_name)) for csv_name, _ in _SOURCE_TABLES]
    for frame, (_, desc_name) in zip(frames, _SOURCE_TABLES):
        frame.describe().to_csv(desc_name)

    merged = frames[0]
    for frame in frames[1:]:
        merged = pd.merge(merged, frame, on='batch-date', how='left')
    return merged


def _prepare(df):
    """Impute zero readings, add batch/date keys, sort, and round to 6 decimals.

    The imputation and the new ``'Batch Number'`` / ``'Date'`` columns are
    applied to *df* in place. A new sorted, rounded DataFrame is returned.
    """
    # NOTE(review): means are computed over the full dataset before the
    # train/test split, and the target 'viscosity' is imputed too. Confirm
    # that this leakage and the imputed labels are acceptable.
    for col in _ZERO_AS_MISSING_COLS:
        df[col] = df[col].replace(0, np.nan)
        df[col] = df[col].fillna(df[col].mean())

    # Per the regexes, 'batch-date' is expected to contain 'Batch_<d.d>_'
    # and to end with '_YYYY-MM-DD'.
    df['Batch Number'] = df['batch-date'].str.extract(r'Batch_(\d+\.\d+)_')[0].astype(float)
    df['Date'] = pd.to_datetime(df['batch-date'].str.extract(r'_(\d{4}-\d{2}-\d{2})$')[0])

    # Sort chronologically first, then by batch number within a date.
    return round(df.sort_values(by=['Date', 'Batch Number']), 6)


def _compute_metrics(y_true, predictions):
    """Return the regression metrics reported by this script.

    Args:
        y_true: Ground-truth target values.
        predictions: Predicted values aligned with ``y_true``.

    Returns:
        A dict mapping a human-readable metric name to its value.

    Raises:
        ValueError: from scikit-learn's deviance metrics when their
            positivity requirements on targets/predictions are not met.
    """
    mae = metrics.mean_absolute_error(y_true, predictions)
    mse = metrics.mean_squared_error(y_true, predictions)
    mape = metrics.mean_absolute_percentage_error(y_true, predictions)
    explained_variance = metrics.explained_variance_score(y_true, predictions)
    max_err = metrics.max_error(y_true, predictions)
    r2 = metrics.r2_score(y_true, predictions)
    median_ae = metrics.median_absolute_error(y_true, predictions)
    poisson_deviance = metrics.mean_poisson_deviance(y_true, predictions)
    gamma_deviance = metrics.mean_gamma_deviance(y_true, predictions)

    return {
        "Mean Absolute Error (MAE)": mae,
        "Mean Squared Error (MSE)": mse,
        "Root Mean Squared Error (RMSE)": np.sqrt(mse),
        "Mean Absolute Percentage Error (MAPE)": mape,
        "Explained Variance Score": explained_variance,
        "Max Error": max_err,
        "Median Absolute Error": median_ae,
        "R2 Score": r2,
        "Mean Gamma Deviance": gamma_deviance,
        "Mean Poisson Deviance": poisson_deviance,
    }


def _evaluate(model, x_test, y_test):
    """Predict on *x_test*, round predictions to 2 decimals, print and return the metrics."""
    predictions = [round(value, 2) for value in model.predict(x_test)]
    metric_dictionary = _compute_metrics(y_test, predictions)
    print(metric_dictionary)
    return metric_dictionary


def model_trainer(data_dir=r'D:\kalypso\bsj-model-inference'):
    """Train an ExtraTreesRegressor for viscosity and compare it with stored models.

    Steps:
      1. Load and left-join the ``*-agg.csv`` tables from *data_dir*. This
         writes ``*_desc.csv`` summaries, then ``grouped.csv`` and
         ``final.csv``, all to the working directory.
      2. Split 75/25 (``random_state=42``), fit a fresh ExtraTreesRegressor,
         and print its test-set metrics.
      3. Print test-set metrics for the model returned by
         ``ModelLoaderSaver.get_latest_model()``.
      4. Print test-set metrics for the model loaded from ``models/fy676a``
         through ``ModelLoader`` (type ``mlflow.sklearn``).

    Args:
        data_dir: Directory containing the aggregated input CSVs. Defaults to
            the original hard-coded location.
    """
    df_grouped = _prepare(_load_and_merge(data_dir))
    df_grouped.to_csv('grouped.csv')

    cols_x = [
        'temperature_ws_side_std',
        'calender_roll_upper_side_inlet_side_cooling_water_temperature_mean',
        '_calendar_current_mean',
        'electric_energy_mean',
        'seat_temperature_immediately_after_bof_mean',
        'Weighted_NITROGEN_type',
        'ram_pressure_mean',
        'surface_temperature_center_std',
        'drilled_side_left_exit_side_cooling_water_temperature_mean',
        'Weighted_VM_type',
        'screw_operation_side_outlet_side_cooling_water_flow_rate_std',
        'Weighted_DIRT_type',
        'screw_opposite_operation_side_outlet_side_cooling_water_temperature_std',
        'residence_time_max',
        'calender_roll_lower_side_inlet_side_cooling_water_flow_rate_mean',
        'Weighted_ASH_type',
        'Weighted_PO_type',
        'drilled_side_right_exit_side_cooling_water_flow_rate_std',
    ]
    cols_y = "viscosity"
    features = df_grouped[cols_x]
    labels = df_grouped[cols_y]
    df_grouped[cols_x + [cols_y]].to_csv('final.csv')

    x_train, x_test, y_train, y_test = train_test_split(features, labels, random_state=42, test_size=0.25)
    print(f'x_train shape - {x_train.shape}')
    print(f'x_test shape - {x_test.shape}')
    print(f'y_train shape - {y_train.shape}')
    print(f'y_test shape - {y_test.shape}')

    params = {'bootstrap': False,
              'ccp_alpha': 0.0,
              'criterion': 'squared_error',
              'max_depth': None,
              'max_features': 1.0,
              'max_leaf_nodes': None,
              'max_samples': None,
              'min_impurity_decrease': 0.0,
              'min_samples_leaf': 1,
              'min_samples_split': 2,
              'min_weight_fraction_leaf': 0.0,
              'n_estimators': 100,
              'n_jobs': -1,
              'oob_score': False,
              'random_state': 123,
              'verbose': 0,
              'warm_start': False}
    model = ExtraTreesRegressor(**params)
    model.fit(x_train, y_train)
    metric_dictionary = _evaluate(model, x_test, y_test)

    experiment_name = "BSJ-Models"
    parent_run_name = model_save_name = model_type = "fy676a"
    # NOTE(review): the freshly trained `model` is NOT passed here (first
    # argument is None); only its metrics and params are. Confirm intent.
    obj = ModelLoaderSaver(None, metric_dictionary, params, experiment_name, parent_run_name, model_save_name, model_type)
    _evaluate(obj.get_latest_model(), x_test, y_test)

    # The model at models/fy676a was previously produced with
    # mlflow.sklearn.save_model (formerly a commented-out line here).
    saved_model = ModelLoader({
        "type": "mlflow.sklearn",
        "path": "models/fy676a"
    }).load_model()
    _evaluate(saved_model, x_test, y_test)
# Run training only when executed as a script, not as a side effect of import.
if __name__ == "__main__":
    model_trainer()


# Metrics recorded from an earlier run of model_trainer(), kept for comparison:
# {'Mean Absolute Error (MAE)': 1.4711585365853663, 'Mean Squared Error (MSE)': 3.193666768292685, 'Root Mean Squared Error (RMSE)': 1.7870833131929482,
#  'Mean Absolute Percentage Error (MAPE)': 0.015400607504235945, 'Explained Variance Score': 0.5937040328784624, 'Max Error': 4.709999999999994,
#  'Median Absolute Error': 1.4399999999999977, 'R2 Score': 0.5936331791226861, 'Mean Gamma Deviance': 0.0003503027256495745,
#  'Mean Poisson Deviance': 0.03343612041755939}
