import math
from datetime import datetime

import pandas as pd
from loguru import logger

from scripts.constants.app_configuration import KAIROS_DB_HOST, METADATA
from scripts.core.data.data_import import DataPuller


class Compliance:
    """Evaluate tag values against their upper/lower limits over time windows.

    For every window, data is pulled through ``DataPuller``. Each tag in
    ``tags_data`` gets a boolean ``<tag>_compliance`` column. Each row is then
    scored as overall compliant (1) or not (0), depending on whether enough of
    those flags are True to reach ``METADATA['compliance_percentage']``.
    Results are written to CSV files.
    """

    def __init__(self, payload, column_rename, tags_data):
        """Store the configuration and create the data puller.

        :param payload: passed through to ``DataPuller`` unchanged.
        :param column_rename: passed through to ``DataPuller`` unchanged.
        :param tags_data: mapping whose keys are tag base names. For each key the
            pulled data must contain ``<key>_live``, ``<key>_upper`` and
            ``<key>_lower`` columns.
        """
        self.payload = payload
        self.column_rename = column_rename
        self.tags_data = tags_data
        self._dp_ = DataPuller(db_host=KAIROS_DB_HOST, payload=self.payload, column_rename=self.column_rename)

    @staticmethod
    def add_compliance_column(df, column, live_col, upper_col, lower_col):
        """Add a boolean ``<column>_compliance`` column to ``df`` in place.

        A row is compliant when ``lower < live < upper``. The bounds are strict,
        so a value exactly on a limit is non-compliant.

        :return: the same ``df`` object. If the comparison fails, a warning is
            logged and ``df`` is returned without the new column.
        """
        try:
            df[f'{column}_compliance'] = (df[live_col] > df[lower_col]) & (df[live_col] < df[upper_col])
        except Exception as e:
            logger.warning(f"Error adding compliance column - {e}")
        return df

    @staticmethod
    def _required_compliant_count(compliance_percentage, total_columns):
        """Return the minimum number of compliant columns that meets the percentage.

        Rounds up, so the satisfied fraction is never below the configured
        percentage (e.g. 80% of 3 columns needs 3, not 2). Multiplying before
        dividing keeps integer percentages exact in floating point.
        """
        return math.ceil(compliance_percentage * total_columns / 100)

    @staticmethod
    def create_compliance_sheet(final_data_dict, compliance_cols, total_columns_criteria):
        """Build the overall-compliance sheet from row-wise compliance flags.

        :param final_data_dict: mapping of row key -> ``{column: value}``, as
            produced by ``DataFrame.to_dict(orient='index')``. Values under keys
            listed in ``compliance_cols`` are treated as flags. Every other value
            is appended to the ``time`` column, so each row should carry exactly
            one non-flag entry.
        :param compliance_cols: names of the per-parameter compliance columns.
        :param total_columns_criteria: minimum number of ``True`` flags a row
            needs to score ``1``; otherwise it scores ``0``.
        :return: DataFrame with columns client, site, project_name, time,
            compliance. On error the exception is logged (with traceback) and
            the partially built frame is returned.
        """
        logger.info("Calculating overall compliance...")
        c_df = pd.DataFrame()
        try:
            row_count = len(final_data_dict)
            for meta_key in ('client', 'site', 'project_name'):
                c_df[meta_key] = [METADATA[meta_key]] * row_count

            flag_cols = set(compliance_cols)  # O(1) membership per cell
            timestamp_data = []
            final_compliance_list = []
            for row in final_data_dict.values():
                flags = []
                for key, value in row.items():
                    if key in flag_cols:
                        flags.append(value)
                    else:
                        timestamp_data.append(value)
                final_compliance_list.append(1 if flags.count(True) >= total_columns_criteria else 0)

            c_df['time'] = timestamp_data
            c_df['compliance'] = final_compliance_list
        except Exception as e:
            # Best-effort: keep returning the partial frame, but keep the traceback.
            logger.opt(exception=True).warning(f'Error - {e}')
        return c_df

    def start_calculation(self, all_timestamps,
                          overall_csv='r5-overall-compliance.csv',
                          parameter_csv='r5-parameter-wise-compliance.csv'):
        """Compute compliance for every window and write the results to CSV.

        :param all_timestamps: iterable of mappings with ``'start'`` and ``'end'``
            keys. The values are divided by 1000 before being converted to
            datetimes, i.e. they are treated as epoch milliseconds.
        :param overall_csv: output path for the row-wise overall compliance sheet.
        :param parameter_csv: output path for the per-parameter compliance data.

        Windows whose pulled data does not have the expected column count are
        skipped with a warning. Nothing is written if no window produced data.
        """
        all_dfs = []
        parameter_wise_dfs = []
        # Both depend only on tags_data, so compute them once for all windows.
        compliance_cols = [f'{column}_compliance' for column, _ in self.tags_data.items()]
        required_total_cols = len(self.tags_data) * 3  # live/upper/lower per tag

        for window in all_timestamps:
            start_timestamp = window['start']
            end_timestamp = window['end']
            start_time = datetime.fromtimestamp(start_timestamp // 1000)
            end_time = datetime.fromtimestamp(end_timestamp // 1000)
            logger.info(f"Calculating for {start_time} to {end_time}")
            df = self._dp_.get_data(start_timestamp, end_timestamp)

            # NOTE(review): "- 2" presumably discounts 'timestamp' plus one other
            # non-tag column returned by get_data - confirm against DataPuller.
            if len(df.columns) - 2 != required_total_cols:
                logger.warning(f"No Data for {start_time} to {end_time}")
                continue

            df.dropna(inplace=True)
            for column, _ in self.tags_data.items():
                df = self.add_compliance_column(df, column, f'{column}_live', f'{column}_upper', f'{column}_lower')
            parameter_wise_dfs.append(df)

            total_columns_criteria = self._required_compliant_count(
                METADATA['compliance_percentage'], len(compliance_cols))
            logger.info(f"Need {total_columns_criteria} from {len(compliance_cols)} columns to satisfy the "
                        f"compliance")

            final_data_dict = df[compliance_cols + ['timestamp']].to_dict(orient='index')
            all_dfs.append(self.create_compliance_sheet(final_data_dict, compliance_cols, total_columns_criteria))

        if not all_dfs:
            logger.warning("No compliance data collected; no CSV written")
            return

        logger.info("Combining the Data")
        pd.concat(all_dfs).to_csv(overall_csv, index=False)
        # parameter_wise_dfs is appended in lockstep with all_dfs, so it is non-empty here.
        pd.concat(parameter_wise_dfs).to_csv(parameter_csv, index=False)
