import math
import warnings
from datetime import datetime

import numpy as np
import pandas as pd
from loguru import logger

from scripts.constants.constants import ExtruderConstants

# Suppress ALL warnings for the whole interpreter as a side effect of importing
# this module (warnings.filterwarnings mutates the global filter list).
# NOTE(review): this presumably exists to hide pandas warnings such as
# SettingWithCopyWarning from column assignments on sliced frames in this file;
# it also hides every other library's warnings. Consider scoping it with
# warnings.catch_warnings() around the code that needs it instead.
warnings.filterwarnings("ignore")


def mixer_section_start_end_time(raw_df, index_no):
    """Return the first and last mixer-section timestamp of every mixing batch.

    Rows are restricted to ``'Size No (INDEX No).3' == index_no`` and a
    non-zero ``'Mixing batch number'``, then grouped by calendar day and batch
    number.

    Args:
        raw_df: Raw process dataframe containing at least 'Time Stamp',
            'Size No (INDEX No).3', 'Size name', 'Mixing batch number' and
            'idle time between batches'. It is not modified.
        index_no: Size/index number to keep.

    Returns:
        dict mapping ``'Batch_<batch number>_<YYYY-MM-DD>'`` to
        ``{"start_time": str, "end_time": str}``. The timestamps are ``str()``
        of pandas Timestamps. Numeric columns are cast to float first, so a
        numeric batch number renders as e.g. ``Batch_1.0_2023-01-01``.

    Raises:
        Exception: wrapping any error raised while processing (logged first).
    """
    try:
        mixer_cols = ['Time Stamp',
                      'Size No (INDEX No).3',
                      'Size name',
                      'Mixing batch number',
                      'idle time between batches',
                      ]
        # .copy() so the column assignments below act on an independent frame,
        # not on a slice of the caller's dataframe.
        mixer_df = raw_df[mixer_cols].copy()
        mixer_df['Time Stamp'] = pd.to_datetime(mixer_df['Time Stamp'])
        mixer_df = mixer_df.sort_values(by='Time Stamp')

        # Cast numeric columns to float; this fixes the batch-number text used
        # in the keys (e.g. "1.0"), which other code matches against.
        numeric_cols = mixer_df.select_dtypes(include=['int', 'float']).columns
        mixer_df[numeric_cols] = mixer_df[numeric_cols].astype(float)
        mixer_df['day'] = mixer_df['Time Stamp'].dt.date

        keep = (mixer_df["Size No (INDEX No).3"] == index_no) & (mixer_df["Mixing batch number"] != 0)
        mixer_df = mixer_df[keep]

        grouped = (
            mixer_df.groupby(['day', 'Mixing batch number'])['Time Stamp']
            .agg(time_min='min', time_max='max')
            .reset_index()
        )
        batch_keys = ('Batch_' + grouped['Mixing batch number'].astype(str) + '_'
                      + grouped['day'].astype(str))

        date_dict = {}
        for key, start, end in zip(batch_keys, grouped['time_min'], grouped['time_max']):
            # setdefault keeps the first occurrence should two keys ever collide.
            date_dict.setdefault(key, {"start_time": str(start), 'end_time': str(end)})
        return date_dict
    except Exception as err:
        logger.error(f'Exception in extruder mixer time fetch {str(err)}')
        raise Exception(str(err)) from err


def _extruder_daily_discharge_length(extruder_df):
    """Return {day: max - min of 'discharge length'} for each day in *extruder_df*.

    Raises:
        Exception: if a day's discharge-length span is not positive.
    """
    day_length_dic = {}
    # sort=False iterates days in first-appearance order, so the first
    # offending day (in time order) is the one reported.
    for day, lengths in extruder_df.groupby("day", sort=False)["discharge length"]:
        span = lengths.max() - lengths.min()
        if span <= 0:
            raise Exception(f"Discharge length in extruder section for the day {day} is 0")
        day_length_dic[day] = span
    return day_length_dic


def _viscosity_cumulative_batch_lengths(viscosity_df, day_length_dic):
    """Split each day's extruded length across that day's batches by rubber weight.

    Returns {day: {batch_no: cumulative_length}}. Each inner dict is ordered by
    'Mixing date'. cumulative_length is the running total of the per-batch
    lengths, each rounded up with math.ceil.
    """
    visc = viscosity_df.sort_values(by="Mixing date", ascending=True)
    visc["day"] = visc["Mixing date"].dt.date.astype("str")
    visc = visc[["Batch No.", "Input rubber weight(0.1kg)", "day", "Mixing date"]].copy()
    weight = visc["Input rubber weight(0.1kg)"]

    # Days with no extruder data contribute zero length.
    visc["length_from_extruder"] = visc["day"].map(day_length_dic).fillna(0)
    # The /10 presumably converts the 0.1 kg units in the column name to kg.
    daily_sum_weight = visc.groupby("day")["Input rubber weight(0.1kg)"].sum() / 10
    visc["m/kg"] = visc["length_from_extruder"] / visc["day"].map(daily_sum_weight)
    visc["batch_length"] = (visc["m/kg"] * weight / 10).astype("float64").apply(math.ceil)
    visc["cumulative_length"] = visc.groupby("day")["batch_length"].cumsum()

    return {
        day: dict(zip(group["Batch No."], group["cumulative_length"]))
        for day, group in visc.groupby("day")
    }


def _assign_extruder_batch_numbers(days, discharge_lengths, discharge_dict):
    """Map every extruder sample to the first batch whose cumulative length covers it.

    - A discharge length of 0 maps to batch 0.
    - A day missing from *discharge_dict* maps to NaN.
    - A length beyond every cumulative length repeats the previous sample's
      batch, or NaN if there is no previous sample.
    """
    batch_numbers = []
    for day, discharge_length in zip(days, discharge_lengths):
        if discharge_length == 0:
            batch_numbers.append(0)
            continue
        day_batches = discharge_dict.get(day)
        if day_batches is None:
            batch_numbers.append(np.nan)
            continue
        for batch_no, cumulative_length in day_batches.items():
            if discharge_length <= cumulative_length:
                batch_numbers.append(batch_no)
                break
        else:
            # Guard the very first sample, which has no previous batch to repeat.
            batch_numbers.append(batch_numbers[-1] if batch_numbers else np.nan)
    return batch_numbers


def _door_open_batch_sequence(door_values, days):
    """Number rising edges of 'lower door open' (value == 1), restarting every day."""
    batch_list = []
    batch_number = 0
    door_was_open = False
    current_day = None
    for value, day in zip(door_values, days):
        if day != current_day:
            current_day = day
            batch_number = 0
            # Reset the edge detector too, so a door that is already open at the
            # day boundary still counts as the first batch of the new day.
            door_was_open = False
        if value == 1:
            if not door_was_open:
                batch_number += 1
                door_was_open = True
        else:
            door_was_open = False
        batch_list.append(batch_number)
    return batch_list


def _wall_clock(value):
    """Parse *value* as a pandas Timestamp and drop any timezone, keeping local wall-clock time."""
    stamp = pd.Timestamp(value)
    if stamp is pd.NaT:
        return stamp
    return stamp.tz_localize(None) if stamp.tzinfo is not None else stamp


def _mixer_window_flags(batch_nos, batch_keys, time_stamps, date_dict):
    """Return True for each sample whose timestamp lies strictly inside its batch's mixer window.

    Batch 0 and batches without an entry in *date_dict* are never inside a window.
    """
    windows = {}
    flags = []
    for batch_no, key, stamp in zip(batch_nos, batch_keys, time_stamps):
        if batch_no == 0.0:
            flags.append(False)
            continue
        if key not in windows:
            window = date_dict.get(key)
            windows[key] = None if window is None else (
                _wall_clock(window.get("start_time")),
                _wall_clock(window.get("end_time")),
            )
        bounds = windows[key]
        if bounds is None:
            flags.append(False)
            continue
        start, end = bounds
        flags.append(bool(start < _wall_clock(stamp) < end))
    return flags


def return_batch_no_df(raw_df, viscosity_df, date_dict, bof_cols, additional_cols, index_no):
    """Assign a mixing-batch number to every extruder sample of one size index.

    Steps:

    1. Per day, the extruded length is the max - min of 'discharge length'.
    2. That length is split across the day's viscosity-table batches in
       proportion to 'Input rubber weight(0.1kg)', giving cumulative lengths.
    3. Each sample gets the first batch whose cumulative length is >= its
       'discharge length'. This assumes 'discharge length' is a running
       counter within the day — TODO confirm.
    4. If the sample's timestamp lies strictly inside that batch's mixer
       window from *date_dict*, 1 is subtracted.

    Args:
        raw_df: Raw process dataframe; needs 'Time Stamp', 'Mixing batch number',
            'Size No (INDEX No).4', 'discharge length' and 'lower door open'
            (via *bof_cols* / *additional_cols*). It is not modified.
        viscosity_df: Viscosity table with 'Mixing date' (datetime), 'Batch No.'
            and 'Input rubber weight(0.1kg)'.
        date_dict: Mapping 'Batch_<float batch>_<YYYY-MM-DD>' ->
            {"start_time": str, "end_time": str} of mixer windows.
        bof_cols: Extruder columns to keep.
        additional_cols: Extra columns to keep; must include 'day',
            'Time Stamp' and 'lower door open'.
        index_no: Size/index number to keep.

    Returns:
        The filtered, time-sorted extruder frame with these added columns:
        'batch_no', 'extruder_batch_date', 'extruder_flag' ("true"/"false"),
        'extruder_batch_diff' (1/0), 'updtaed_bt_list' (per-day door-open
        counter), 'extruder_batch_number' and 'batch-date'.

    Raises:
        Exception: wrapping any error (e.g. a day with no discharge), logged first.
    """
    try:
        raw_df = raw_df.sort_values(by='Time Stamp')
        raw_df['Time Stamp'] = pd.to_datetime(raw_df['Time Stamp'])
        raw_df["day"] = raw_df["Time Stamp"].dt.date.astype("str")
        raw_df["Mixing batch number"] = raw_df["Mixing batch number"].astype("float")
        raw_df["batch-date"] = (
            "Batch_" + raw_df["Mixing batch number"].astype("str") + "_" + raw_df["day"].astype("str")
        )

        extruder_df = raw_df[bof_cols + additional_cols].sort_values(by="Time Stamp", ascending=True)
        # .copy() so the new columns are added to an independent frame, not a filtered slice.
        extruder_df = extruder_df[extruder_df["Size No (INDEX No).4"] == index_no].copy()

        day_length_dic = _extruder_daily_discharge_length(extruder_df)
        discharge_dict = _viscosity_cumulative_batch_lengths(viscosity_df, day_length_dic)

        extruder_df["batch_no"] = _assign_extruder_batch_numbers(
            extruder_df["day"], extruder_df["discharge length"], discharge_dict
        )
        batch_list = _door_open_batch_sequence(extruder_df["lower door open"], extruder_df["day"])

        extruder_df["batch_no"] = extruder_df["batch_no"].astype("float")
        extruder_df["extruder_batch_date"] = (
            "Batch_" + extruder_df["batch_no"].astype("str") + "_" + extruder_df["day"].astype("str")
        )
        in_mixer_window = _mixer_window_flags(
            extruder_df["batch_no"], extruder_df["extruder_batch_date"], extruder_df["Time Stamp"], date_dict
        )
        extruder_df["extruder_flag"] = ["true" if flag else "false" for flag in in_mixer_window]
        extruder_df["extruder_batch_diff"] = [1 if flag else 0 for flag in in_mixer_window]
        extruder_df["updtaed_bt_list"] = batch_list

        # A sample taken while its batch is still in the mixer belongs to the previous batch.
        extruder_df["extruder_batch_number"] = (
            extruder_df["batch_no"] - extruder_df["extruder_batch_diff"].astype("float")
        )
        extruder_df["batch-date"] = (
            "Batch_" + extruder_df["extruder_batch_number"].astype("str") + "_" + extruder_df["day"].astype("str")
        )
        return extruder_df
    except Exception as err:
        logger.error(f'Exception in generating extruder batch {str(err)}')
        raise Exception(str(err)) from err


def preprocess_extruder_section(df, index_number, vis_df):
    """Aggregate extruder-section signals per extruder batch for one size index.

    Args:
        df: Raw process dataframe with the mixer columns, the extruder columns
            listed in ExtruderConstants.extruder_cols, 'Time Stamp' and
            'lower door open'.
        index_number: Size/index number to keep.
        vis_df: Viscosity table with 'Mixing date' (datetime), 'Batch No.',
            'Index No' and 'Input rubber weight(0.1kg)'.

    Returns:
        One row per 'batch-date' with the columns of
        ExtruderConstants.aggregate_dict aggregated and renamed to
        ``<col>_<agg>`` (lower-cased, parentheses removed, spaces -> '_').
        Missing values are filled with numeric column means, and the result
        is rounded to 6 decimals.

    Raises:
        Exception: wrapping any error raised during processing (logged first).
    """
    try:
        extruder_cols = ExtruderConstants.extruder_cols
        date_dict = mixer_section_start_end_time(df, index_number)
        additional_cols = ['day', 'Time Stamp', 'lower door open']

        # Adding date col to the viscosity df and keep only this size index.
        vis_df = vis_df.sort_values(by='Mixing date')
        vis_df['date'] = vis_df['Mixing date'].dt.date
        vis_df['batch-date'] = 'Batch_' + vis_df['Batch No.'].astype('float').astype(str) + '_' + vis_df[
            'date'].astype(str)
        vis_df = vis_df[vis_df['Index No'] == index_number]

        extruder_merged_df_final = return_batch_no_df(df, vis_df, date_dict, extruder_cols, additional_cols,
                                                      index_number)
        # Batch 0 marks samples with no discharge; exclude them from the aggregates.
        extruder_merged_df_final = extruder_merged_df_final[extruder_merged_df_final['extruder_batch_number'] != 0]

        aggregate_dict = ExtruderConstants.aggregate_dict
        df_extruder_grouped = extruder_merged_df_final.groupby(['batch-date']).agg(aggregate_dict).reset_index()

        # These columns keep their names; every other one gets an "_<agg>" suffix.
        keep_names = ['viscosity', 'time_min', 'time_max', 'Mixing Weight (Integrated Value)_diff', 'max_rpm_count']
        col_renamer = {}
        for col, col_agg in aggregate_dict.items():
            if col not in keep_names:
                col_renamer[col] = f'{col.replace("(", "").replace(")", "").replace(" ", "_")}_{col_agg}'.lower()
            else:
                col_renamer[col] = col
        df_extruder_grouped = df_extruder_grouped.rename(columns=col_renamer)

        # numeric_only=True: the 'batch-date' key column is text, and on
        # pandas >= 2.0 DataFrame.mean() raises TypeError for it otherwise.
        df_extruder_grouped = df_extruder_grouped.fillna(df_extruder_grouped.mean(numeric_only=True))
        return df_extruder_grouped.round(6)
    except Exception as err:
        logger.error(f"Exception in extruder preprocess {str(err)}")
        raise Exception(str(err)) from err
