import warnings

import numpy as np
import pandas as pd
from loguru import logger
from sklearn import metrics
from sklearn.model_selection import train_test_split

from scripts.constants.constants import RawConstants
from scripts.core.model_loader import ModelLoader
from scripts.section_utils.bof_section import preprocess_bof_section
from scripts.section_utils.extruder_section import preprocess_extruder_section
from scripts.section_utils.material_section import preprocess_viscosity_section
from scripts.section_utils.mixer_section import preprocess_mixer_section
from scripts.section_utils.pickup_section import preprocess_pickup_section
from scripts.section_utils.sheet_supply_section import preprocess_sheet_section

# Suppress every Python warning for the whole process (no category or module
# filter is given, so the "ignore" action applies to all of them).
# NOTE(review): this also hides deprecation and numerical warnings raised by
# pandas / scikit-learn - confirm that silencing them globally is intended.
warnings.filterwarnings("ignore")


def model_trainer(df_grouped, index_no):
    """Evaluate the saved viscosity model for ``index_no`` on a hold-out split.

    Despite its name this function does not fit anything: it loads the
    already-saved model, splits ``df_grouped`` 75/25 with a fixed seed, predicts
    on the 25% test part and prints a set of regression metrics.

    Args:
        df_grouped: DataFrame holding the feature columns required by the
            selected model plus the ``"viscosity"`` target column.
        index_no: Model selector; ``1250`` and ``3294`` are supported.

    Returns:
        dict: Metric name -> value, in the same order in which it is printed.

    Raises:
        ValueError: If ``index_no`` has no model configured.
        KeyError: If ``df_grouped`` lacks a required column.
    """
    # index_no -> (saved model path, feature columns expected by that model)
    model_specs = {
        1250: (
            "models/fy676a",
            ['temperature_ws_side_std', 'calender_roll_upper_side_inlet_side_cooling_water_temperature_mean',
             '_calendar_current_mean', 'electric_energy_mean', 'seat_temperature_immediately_after_bof_mean',
             'Weighted_NITROGEN_type', 'ram_pressure_mean', 'surface_temperature_center_std',
             'drilled_side_left_exit_side_cooling_water_temperature_mean', 'Weighted_VM_type',
             'screw_operation_side_outlet_side_cooling_water_flow_rate_std', 'Weighted_DIRT_type',
             'screw_opposite_operation_side_outlet_side_cooling_water_temperature_std', 'residence_time_max',
             'calender_roll_lower_side_inlet_side_cooling_water_flow_rate_mean', 'Weighted_ASH_type',
             'Weighted_PO_type', 'drilled_side_right_exit_side_cooling_water_flow_rate_std'],
        ),
        3294: (
            "models/fy664g",
            ['Weighted_ASH_type', 'Weighted_NITROGEN_type', 'electric_energy_mean',
             'drilled_side_left_inlet_side_cooling_water_temperature_std',
             'seat_temperature_immediately_after_bof_mean',
             'mixer_rotor_left_outlet_side_cooling_water_flow_rate_mean', 'humidity_mean',
             'drilled_side_left_exit_side_cooling_water_flow_rate_mean',
             'calender_roll_lower_side_inlet_side_cooling_water_flow_rate_mean', 'calendar_bank_load_max',
             'drilled_side_right_inlet_side_cooling_water_flow_rate_mean', 'Weighted_PRI_type',
             'mixer_rotor_right_inlet_side_cooling_water_flow_rate_mean', 'temperature_ws_side_std',
             'dust_cv\nspeed_std', 'mixer_rotor_right_inlet_side_coolant_temperature_mean', 'ram_position_std',
             'drilled_side_right_exit_side_cooling_water_temperature_std',
             'calender_roll_upper_side__outlet__side_cooling_water_temperature_std',
             'Weighted_Temperature during transportation_type[℃]'],
        ),
    }
    # Fail early with a readable message; previously an unknown index crashed
    # later with a TypeError while concatenating None with a list.
    if index_no not in model_specs:
        raise ValueError(
            f"Unsupported index_no {index_no!r}; expected one of {sorted(model_specs)}"
        )
    model_path, cols_x = model_specs[index_no]
    cols_y = "viscosity"

    saved_model = ModelLoader({
        "type": "mlflow.sklearn",
        "path": model_path
    }).load_model()

    features = df_grouped[cols_x]
    labels = df_grouped[cols_y]

    # Fixed seed so the evaluated hold-out rows are the same on every run.
    x_train, x_test, y_train, y_test = train_test_split(features, labels, random_state=42, test_size=0.25)
    print(f'x_train shape - {x_train.shape}')
    print(f'x_test shape - {x_test.shape}')
    print(f'y_train shape - {y_train.shape}')
    print(f'y_test shape - {y_test.shape}')

    y_pred = saved_model.predict(x_test)
    # Metrics are computed on predictions rounded to two decimals.
    predictions = [round(value, 2) for value in y_pred]

    mse = metrics.mean_squared_error(y_test, predictions)
    metric_dictionary = {
        "Mean Absolute Error (MAE)": metrics.mean_absolute_error(y_test, predictions),
        "Mean Squared Error (MSE)": mse,
        "Root Mean Squared Error (RMSE)": np.sqrt(mse),
        "Mean Absolute Percentage Error (MAPE)": metrics.mean_absolute_percentage_error(y_test, predictions),
        "Explained Variance Score": metrics.explained_variance_score(y_test, predictions),
        "Max Error": metrics.max_error(y_test, predictions),
        "Median Absolute Error": metrics.median_absolute_error(y_test, predictions),
        "R2 Score": metrics.r2_score(y_test, predictions),
        "Mean Gamma Deviance": metrics.mean_gamma_deviance(y_test, predictions),
        "Mean Poisson Deviance": metrics.mean_poisson_deviance(y_test, predictions),
    }

    print(metric_dictionary)
    return metric_dictionary


def read_raw_data(raw_path, raw_skip_rows):
    """Load the raw export and rename its columns to ``RawConstants.columns``.

    The file is first read as an Excel workbook; if that fails for any reason
    it is read as CSV instead. Columns are renamed *by position*. When the file
    has fewer columns than ``RawConstants.columns`` the missing trailing
    columns are appended filled with NaN.

    Args:
        raw_path: Path of the raw data file (Excel or CSV).
        raw_skip_rows: Forwarded as ``skiprows`` to the pandas reader.

    Returns:
        pandas.DataFrame: The data with exactly ``RawConstants.columns`` as
        its column labels.

    Raises:
        ValueError: If the file has more columns than ``RawConstants.columns``.
    """
    try:
        df = pd.read_excel(raw_path, skiprows=raw_skip_rows)
    except Exception as excel_error:
        # The exception raised for a non-Excel file differs between pandas
        # versions and engines, so the catch stays broad - but the reason is
        # logged instead of being dropped silently.
        logger.info(f"Could not read {raw_path} as Excel ({excel_error}); reading it as CSV")
        # Honour the requested skip rows in the fallback as well.
        df = pd.read_csv(raw_path, skiprows=raw_skip_rows)

    expected_cols = RawConstants.columns
    expected_count = len(expected_cols)
    actual_count = len(df.columns)

    if actual_count > expected_count:
        # A positional rename cannot work here; say so explicitly rather than
        # letting pandas fail with a bare length-mismatch message.
        raise ValueError(
            f"Raw file {raw_path} has {actual_count} columns but only "
            f"{expected_count} are expected"
        )

    if actual_count == expected_count:
        logger.info(f"Total cols are {expected_count} and are same as the df cols length")
    else:
        missed_cols = expected_cols[actual_count:]
        logger.info(f"missed cols are {missed_cols}")
        # Append the padding positionally (placeholder labels, renamed below)
        # so that a header in the file that happens to equal one of the
        # missing names is never overwritten.
        padding = pd.DataFrame(np.nan, index=df.index, columns=range(len(missed_cols)))
        df = pd.concat([df, padding], axis=1)

    df.columns = expected_cols
    logger.info(f"Shape of df is {df.shape}")
    return df


def merged_all_sections(sheet_df, mixer_df, extruder_df, bof_df, pickup_df, viscosity_df):
    """Join all section aggregates into one frame keyed by ``'batch-date'``.

    ``sheet_df`` is the left-most frame; every other section is left-joined
    onto it, so only batches present in ``sheet_df`` survive. For the
    material/viscosity columns, zeros are treated as missing and all missing
    values are replaced by the column mean (computed without the zeros).
    ``'Batch Number'`` and ``'Date'`` are parsed out of the ``'batch-date'``
    key, the rows are sorted by date then batch number, and numeric values are
    rounded to six decimals.

    Args:
        sheet_df, mixer_df, extruder_df, bof_df, pickup_df: Per-section
            DataFrames, each with a ``'batch-date'`` column.
        viscosity_df: DataFrame with ``'batch-date'`` and the material columns
            listed in ``viscosity_rubber_cols`` below.

    Returns:
        pandas.DataFrame: The merged, imputed, sorted and rounded frame. The
        input frames are not modified.

    Raises:
        KeyError: If a material column or ``'batch-date'`` is missing.
    """
    df_grouped = sheet_df
    for section_df in (mixer_df, extruder_df, bof_df, pickup_df, viscosity_df):
        df_grouped = pd.merge(df_grouped, section_df, on='batch-date', how='left')

    viscosity_rubber_cols = ['Weight_type1', 'Weight_type2',
                             'Weighted_PO_type', 'Weighted_DIRT_type', 'Weighted_ASH_type',
                             'Weighted_VM_type', 'Weighted_PRI_type', 'Weighted_NITROGEN_type',
                             'Weighted_Temperature during transportation_type[℃]',
                             'Weighted_Humidity during transportation__type[%]', 'Weighted Sum',
                             'viscosity']
    # Zeros count as "not measured": blank them first so they do not drag the
    # mean down, then fill every gap with that mean.
    for col in viscosity_rubber_cols:
        without_zeros = df_grouped[col].replace(0, np.nan)
        df_grouped[col] = without_zeros.fillna(without_zeros.mean())

    # Keys are expected to look like "Batch_<number with a decimal point>_<YYYY-MM-DD>";
    # a key that does not match yields NaN / NaT for the respective column.
    batch_key = df_grouped['batch-date']
    df_grouped['Batch Number'] = batch_key.str.extract(r'Batch_(\d+\.\d+)_')[0].astype(float)
    df_grouped['Date'] = pd.to_datetime(batch_key.str.extract(r'_(\d{4}-\d{2}-\d{2})$')[0])

    # Sort by 'Date' first, then by 'Batch Number' within a date.
    df_grouped = df_grouped.sort_values(by=['Date', 'Batch Number'])
    return df_grouped.round(6)


def load_and_predict(df_grouped, index_no):
    """Predict viscosity for every row and write the result to a CSV file.

    The saved model for ``index_no`` is loaded, applied to its feature columns
    in ``df_grouped`` and the predictions are stored in a new
    ``'predicted_viscosity'`` column of ``df_grouped`` (the frame passed in is
    modified). ``'Date'``, ``'Batch Number'`` and the prediction are then
    written to ``<index_no>_final_predicted_viscosity.csv`` (a relative path).

    Args:
        df_grouped: DataFrame containing the model's feature columns plus
            ``'Date'`` and ``'Batch Number'``. A measured ``'viscosity'``
            column is not needed for prediction.
        index_no: Model selector; ``1250`` and ``3294`` are supported. Any
            other value is logged and skipped.

    Returns:
        None
    """
    # index_no -> (saved model path, feature columns expected by that model)
    model_specs = {
        1250: (
            "models/fy676a",
            ['temperature_ws_side_std', 'calender_roll_upper_side_inlet_side_cooling_water_temperature_mean',
             '_calendar_current_mean', 'electric_energy_mean', 'seat_temperature_immediately_after_bof_mean',
             'Weighted_NITROGEN_type', 'ram_pressure_mean', 'surface_temperature_center_std',
             'drilled_side_left_exit_side_cooling_water_temperature_mean', 'Weighted_VM_type',
             'screw_operation_side_outlet_side_cooling_water_flow_rate_std', 'Weighted_DIRT_type',
             'screw_opposite_operation_side_outlet_side_cooling_water_temperature_std', 'residence_time_max',
             'calender_roll_lower_side_inlet_side_cooling_water_flow_rate_mean', 'Weighted_ASH_type',
             'Weighted_PO_type', 'drilled_side_right_exit_side_cooling_water_flow_rate_std'],
        ),
        3294: (
            "models/fy664g",
            ['Weighted_ASH_type', 'Weighted_NITROGEN_type', 'electric_energy_mean',
             'drilled_side_left_inlet_side_cooling_water_temperature_std',
             'seat_temperature_immediately_after_bof_mean',
             'mixer_rotor_left_outlet_side_cooling_water_flow_rate_mean', 'humidity_mean',
             'drilled_side_left_exit_side_cooling_water_flow_rate_mean',
             'calender_roll_lower_side_inlet_side_cooling_water_flow_rate_mean', 'calendar_bank_load_max',
             'drilled_side_right_inlet_side_cooling_water_flow_rate_mean', 'Weighted_PRI_type',
             'mixer_rotor_right_inlet_side_cooling_water_flow_rate_mean', 'temperature_ws_side_std',
             'dust_cv\nspeed_std', 'mixer_rotor_right_inlet_side_coolant_temperature_mean',
             'ram_position_std',
             'drilled_side_right_exit_side_cooling_water_temperature_std',
             'calender_roll_upper_side__outlet__side_cooling_water_temperature_std',
             'Weighted_Temperature during transportation_type[℃]'],
        ),
    }

    spec = model_specs.get(index_no)
    if spec is None:
        # Keep the historical "do nothing" outcome for unknown indexes, but
        # make it visible in the log instead of returning silently.
        logger.warning(f"No model configured for index {index_no}; skipping prediction")
        return
    model_path, cols_x = spec

    logger.info(f"Loading model for {index_no}")
    saved_model = ModelLoader({
        "type": "mlflow.sklearn",
        "path": model_path
    }).load_model()

    features = df_grouped[cols_x]
    df_grouped['predicted_viscosity'] = saved_model.predict(features)
    final_df = df_grouped[['Date', 'Batch Number', 'predicted_viscosity']]
    final_df.to_csv(f'{index_no}_final_predicted_viscosity.csv')


def start_prediction(raw_path, viscosity_path, index_no, raw_skip_rows, viscosity_skip_rows):
    """Run the full pipeline for one index: read, preprocess, merge, predict.

    Steps, in order: read the raw file, read and preprocess the viscosity
    (material) workbook, preprocess the sheet, mixer, extruder, BOF and pickup
    sections, merge everything on the batch key and hand the merged frame to
    ``load_and_predict``.

    Args:
        raw_path: Path of the raw process data file.
        viscosity_path: Path of the viscosity Excel workbook.
        index_no: Index/model selector forwarded to every preprocessing step
            and to ``load_and_predict``.
        raw_skip_rows: ``skiprows`` used when reading the raw file.
        viscosity_skip_rows: ``skiprows`` used when reading the viscosity
            workbook.

    Returns:
        None
    """
    logger.info(f"Starting prediction for {index_no}")
    logger.info("Reading raw file data")
    df = read_raw_data(raw_path, raw_skip_rows)
    logger.info(f"Shape of raw df is {df.shape}")

    logger.info("Starting preprocessing material section")
    visc_df = pd.read_excel(viscosity_path, skiprows=viscosity_skip_rows)
    viscosity_df, raw_viscosity_df = preprocess_viscosity_section(visc_df, index_no)
    logger.info(f"The shape of the viscosity df is {viscosity_df.shape}")
    logger.info("Completed material section preprocessing")

    # (name used in start/complete messages, label used in the shape message,
    #  preprocessing function, positional arguments). The order matches the
    # parameter order of merged_all_sections.
    section_steps = (
        ("sheet", "Sheet", preprocess_sheet_section, (df, index_no)),
        ("mixer", "Mixer", preprocess_mixer_section, (df, index_no)),
        ("extruder", "Extruder", preprocess_extruder_section, (df, index_no, raw_viscosity_df)),
        ("bof", "BOF", preprocess_bof_section, (df, index_no, raw_viscosity_df)),
        # The pickup shape used to be logged under the "Extruder" label.
        ("pickup", "Pickup", preprocess_pickup_section, (df, index_no, raw_viscosity_df)),
    )
    section_frames = []
    for section_name, section_label, preprocess, args in section_steps:
        logger.info(f"Starting preprocessing {section_name} section")
        section_df = preprocess(*args)
        logger.info(f"The shape of the {section_label} df is {section_df.shape}")
        logger.info(f"Completed {section_name} section preprocessing")
        section_frames.append(section_df)

    df_grouped = merged_all_sections(*section_frames, viscosity_df)

    load_and_predict(df_grouped, index_no)
    # To evaluate the saved model against measured viscosity instead, call
    # model_trainer(df_grouped, index_no) here.


if __name__ == "__main__":
    # One entry per prediction run:
    # (index number, raw file, raw skip rows, viscosity file, viscosity skip rows)
    prediction_runs = (
        (1250, 'FY676-A-WO_Visc.xlsx', 0, 'viscosity_natural_rubber_data.xlsx', 3),
        (3294, 'fy664g_raw.csv', 0, 'fy664g-viscosity.xlsx', 2),
    )
    try:
        logger.info("Starting the model")
        # Runs share one try block: a failure in an earlier run stops the
        # remaining ones and is logged once below.
        for run_index, run_raw_path, run_raw_skip, run_visc_path, run_visc_skip in prediction_runs:
            start_prediction(run_raw_path, run_visc_path, run_index, run_raw_skip, run_visc_skip)
    except Exception as e:
        logger.exception(f"Module failed because of error {e}")
