import copy
import json

import pandas as pd
import requests
import yaml
from loguru import logger


class ComplianceDataPuller:
    """Pull time-series data for one configured query from the Kairos database.

    The query body and the tag-id -> column-name map are read from
    ``conf[query_name]``. ``get_data`` POSTs the query to ``request_url`` and
    returns the results as one DataFrame with a column per tag, outer-joined
    on ``timestamp``.
    """

    def __init__(self, conf=None,
                 query_name='query',
                 request_url=None,
                 absolute_time=None,
                 preset_timestamp=False,
                 preset_timestamp_value=None, validation=False):
        """Prepare the JSON request payload for ``query_name``.

        :param conf: Mapping whose ``conf[query_name]`` entry holds a ``"query"``
            dict (the request body) and a ``"column_renamer"`` dict mapping tag
            ids to output column names. Required despite the ``None`` default.
            It is not modified.
        :param query_name: Key of the query section inside ``conf``.
        :param request_url: URL the query is POSTed to.
        :param absolute_time: Optional dict with ``"start_absolute"`` and
            ``"end_absolute"``. When given, these replace any
            ``start_relative``/``end_relative`` in the query.
        :param preset_timestamp: When true, ``get_data`` overwrites every row's
            ``timestamp`` with ``preset_timestamp_value`` before merging.
        :param preset_timestamp_value: Value used when ``preset_timestamp`` is set.
        :param validation: When true, ``get_data`` adds any renamed column that
            is missing from the result, filled with 0.
        :raises TypeError: if ``conf`` is None.
        """
        if conf is None:
            raise TypeError("ComplianceDataPuller requires a 'conf' mapping")
        self.request_url = request_url
        self.conf = conf[query_name]
        # Edit a private copy: the time-range changes below must not leak back
        # into the caller's config, which may be reused for other pullers.
        payload = copy.deepcopy(self.conf["query"])
        self.column_rename = self.conf["column_renamer"]
        self.validation = validation
        if absolute_time is not None:
            # An absolute range replaces any relative one.
            payload.pop("start_relative", None)
            payload.pop("end_relative", None)
            payload["start_absolute"] = absolute_time["start_absolute"]
            payload["end_absolute"] = absolute_time["end_absolute"]
        # Stored pre-serialised; get_data sends it as the raw request body.
        self.payload = json.dumps(payload)
        self.preset_timestamp = preset_timestamp
        self.preset_timestamp_value = preset_timestamp_value

    def get_data(self):
        """Run the query and return all series as one DataFrame.

        Each non-empty result series becomes a column. The column is named
        through ``column_renamer``, falling back to the raw tag id when no
        mapping exists. Columns are outer-joined on the epoch-millisecond
        ``timestamp`` column, and a ``datetime`` column is added holding
        ``timestamp`` converted to Asia/Kolkata time. When the response holds
        no data points, an empty frame with ``timestamp`` and ``datetime``
        columns is returned instead of raising.

        :return: pandas DataFrame.
        """
        logger.info("Data for the parameters being pulled from Kairos Database")
        response_data = requests.post(url=self.request_url, data=self.payload).json()
        output_data = response_data["queries"]
        logger.debug("Data pull complete")
        df_final = pd.DataFrame()
        for query_result in output_data:
            for each_grouped_data in query_result["results"]:
                value = each_grouped_data["values"]
                if not value:
                    logger.warning("No data for a tag")
                    continue
                tag_id = None
                try:
                    # The tag id is read from the first group_by entry's "c3" field.
                    tag_id = each_grouped_data["group_by"][0]["group"]["c3"]
                    column_name = self.column_rename[tag_id]
                    logger.debug("Renamed {} to {} in Data".format(tag_id, column_name))
                except KeyError as key_error:
                    logger.debug("Column Renaming Logic not found for {}".format(tag_id))
                    logger.warning(f"{key_error}")
                    column_name = tag_id
                df_column_data = pd.DataFrame(data=value, columns=["timestamp", column_name])
                if self.preset_timestamp:
                    # NOTE(review): all rows of every series share this value, so the
                    # outer merge below becomes a cross join when a series has more
                    # than one row. Presumably each query returns a single point.
                    df_column_data["timestamp"] = self.preset_timestamp_value
                if df_final.empty:
                    df_final = df_column_data
                else:
                    df_final = df_final.merge(df_column_data, how="outer", left_on="timestamp",
                                              right_on="timestamp")
        if df_final.empty:
            # No data points at all: return the expected columns rather than
            # failing with KeyError on the missing 'timestamp' column.
            logger.warning("No data returned for the query")
            df_final = pd.DataFrame({"timestamp": pd.Series(dtype="int64")})
        df_final["datetime"] = pd.to_datetime(df_final['timestamp'],
                                              unit="ms").dt.tz_localize('UTC').dt.tz_convert('Asia/Kolkata')
        if self.validation:
            df_final = self.dataframe_validator(df_final, list(self.column_rename.values()))
        logger.debug("Final number of columns : {}".format(len(df_final.columns)))
        return df_final

    @staticmethod
    def shift_identifier(row):
        """Return a shift label for ``row["timestamp"]`` based on its hour.

        ``row["timestamp"]`` must expose ``.hour``, e.g. a ``pd.Timestamp``.
        Note that get_data's own ``timestamp`` column holds epoch milliseconds;
        its datetime column is ``datetime``.
        """
        # NOTE(review): the "A" branch is unreachable. Hours 0-6 and 15-23 match
        # the first test and return "B"; every remaining hour (7-14) satisfies
        # hour <= 14 and returns "C". The conditions look like mis-negated
        # ranges; confirm the intended shift boundaries before changing them.
        if (row["timestamp"].hour <= 6) or (row["timestamp"].hour > 14):
            return "B"
        elif (row["timestamp"].hour <= 14) or (row["timestamp"].hour > 22):
            return "C"
        else:
            return "A"

    @staticmethod
    def dataframe_validator(df_data, columns):
        """Ensure every name in ``columns`` exists in ``df_data``.

        Missing columns are added in place, filled with 0. The same DataFrame
        is returned for convenience.
        """
        for column in columns:
            if column not in df_data.columns:
                df_data[column] = 0
        return df_data

    @staticmethod
    def shift_date_identifier(row):
        """Return the date-bearing timestamp a row's shift is attributed to.

        Rows whose ``row["shift"]`` is "B" or "C" keep their timestamp. For any
        other shift, timestamps with hour <= 22 are unchanged and hour-23
        timestamps are moved back one day.
        """
        # NOTE(review): shift_identifier as written never returns "A", so the
        # adjustment below only applies if "shift" is set elsewhere. Moving
        # 23:00 back a day while leaving 00:00-06:00 unchanged also looks
        # inverted for an overnight shift; confirm the intended rule.
        if row["shift"] in ["B", "C"]:
            return row["timestamp"]
        else:
            if row["timestamp"].hour <= 22:
                return row["timestamp"] + pd.Timedelta(days=0)
            else:
                return row["timestamp"] + pd.Timedelta(days=-1)


class DataPuller:
    """Pull time-series data from the Kairos database using a YAML-configured query.

    The query body and the tag-id -> column-name map are read from a YAML
    config file. ``get_data`` POSTs the query and returns one DataFrame with a
    column per tag.
    """

    def __init__(self, db_host, data_config, payload, absolute_time=None, optional_payload=None):
        """Load the config file and prepare the JSON request payload.

        :param db_host: Base URL of the Kairos host. Queries are sent to
            ``{db_host}/api/v1/datapoints/query``.
        :param data_config: Path to the YAML config file. It must contain the
            ``payload`` key (the query body) and a ``column_renamer`` mapping.
        :param payload: Key of the query body inside the config.
        :param absolute_time: Optional dict with ``"start_absolute"`` and
            ``"end_absolute"``. When given, these replace any
            ``start_relative``/``end_relative`` in the query.
        :param optional_payload: Stored on the instance as-is; not used by this class.
        """
        self.optional_payload = optional_payload
        # YAML files are UTF-8; do not depend on the platform's default encoding.
        # NOTE(review): yaml.full_load is not meant for untrusted input; prefer
        # yaml.safe_load if this config can come from outside the deployment.
        with open(data_config, 'r', encoding="utf-8") as config_file:
            self.conf = yaml.full_load(config_file)
        self.db_host_url = db_host
        self.request_url = "{kairos_host}/api/v1/datapoints/query".format(kairos_host=self.db_host_url)
        self.payload = self.conf[payload]
        self.column_rename = self.conf["column_renamer"]
        if absolute_time is not None:
            # An absolute range replaces any relative one.
            self.payload.pop("start_relative", None)
            self.payload.pop("end_relative", None)
            self.payload["start_absolute"] = absolute_time["start_absolute"]
            self.payload["end_absolute"] = absolute_time["end_absolute"]
        # Stored pre-serialised; get_data sends it as the raw request body.
        self.payload = json.dumps(self.payload)

    def get_data(self):
        """Run the query and return all series as one DataFrame.

        Each non-empty result series becomes a column. The column is named
        through ``column_renamer``, falling back to the raw tag id when no
        mapping exists. Columns are outer-joined on ``timestamp``. The raw
        epoch-millisecond value is kept in ``epochtime``, and ``timestamp`` is
        converted to Asia/Kolkata time. When the response holds no data points,
        an empty frame with ``timestamp`` and ``epochtime`` columns is returned
        instead of raising.

        :return: pandas DataFrame.
        """
        logger.info("Data for the parameters being pulled from Kairos Database")
        response_data = requests.post(url=self.request_url, data=self.payload).json()
        output_data = response_data["queries"]
        logger.debug("Data pull complete")
        df_final = pd.DataFrame()
        for query_result in output_data:
            for each_grouped_data in query_result["results"]:
                value = each_grouped_data["values"]
                if not value:
                    logger.warning("No data for a tag")
                    continue
                tag_id = None
                try:
                    # The tag id is read from the first group_by entry's "c3" field.
                    tag_id = each_grouped_data["group_by"][0]["group"]["c3"]
                    column_name = self.column_rename[tag_id]
                    logger.debug("Renamed {} to {} in Data".format(tag_id, column_name))
                except KeyError as ke:
                    logger.debug("Column Renaming Logic not found for {}".format(tag_id))
                    logger.warning(f"{ke}")
                    column_name = tag_id
                df_column_data = pd.DataFrame(data=value, columns=["timestamp", column_name])
                if df_final.empty:
                    df_final = df_column_data
                else:
                    df_final = df_final.merge(df_column_data, how="outer", left_on="timestamp",
                                              right_on="timestamp")
        if df_final.empty:
            # No data points at all: return the expected columns rather than
            # failing with KeyError on the missing 'timestamp' column.
            logger.warning("No data returned for the query")
            df_final = pd.DataFrame({"timestamp": pd.Series(dtype="int64")})
        # Keep the raw epoch-millisecond value before converting 'timestamp' in place.
        df_final["epochtime"] = df_final["timestamp"]
        df_final["timestamp"] = pd.to_datetime(df_final['timestamp'], unit="ms").dt.tz_localize('UTC').dt.tz_convert(
            'Asia/Kolkata')
        logger.debug("Final number of columns : {}".format(len(df_final.columns)))
        return df_final
