import copy
import json

import pandas as pd
import requests
from loguru import logger

from scripts.constants.app_constants import KairosConstants


class KairosQueryBuilder:
    """Assemble a KairosDB datapoint query from an input specification.

    ``inp_dict`` is read for ``metric_name`` and ``tag_details`` and, for
    training queries, ``start_absolute`` / ``end_absolute``. The query
    templates themselves come from ``KairosConstants``.
    """

    def __init__(self, inp_dict):
        self.inp_dict = inp_dict

    @staticmethod
    def group_tags(tag_details):
        """Merge tag entries that share identical ``group_by`` and ``aggregators``.

        Groups are kept in first-seen order; the ``tag`` value of each later
        matching entry is concatenated (``+``) onto its group's ``tag``.

        The input entries are never modified: each group starts from a deep
        copy of its first entry, so calling this repeatedly on the same
        ``tag_details`` (e.g. building both a training and a live query from
        one ``inp_dict``) returns the same result instead of re-appending tags.

        :param tag_details: sequence of dicts with ``tag``, ``group_by`` and
            ``aggregators`` keys.
        :return: list of merged tag dicts.
        """
        logger.info("Grouping all the tags having same group_by and aggregators")
        tags_dict = []
        for tag_entry in tag_details:
            for group in tags_dict:
                if (
                    group["group_by"] == tag_entry["group_by"]
                    and group["aggregators"] == tag_entry["aggregators"]
                ):
                    group["tag"] = group["tag"] + tag_entry["tag"]
                    # Groups are unique by construction: at most one can match.
                    break
            else:
                tags_dict.append(copy.deepcopy(tag_entry))
        return tags_dict

    @staticmethod
    def check_group_by(temp_dict, k):
        """Attach the default ``group_by`` section to ``temp_dict`` when requested.

        :param temp_dict: per-tag scratch dict; mutated in place.
        :param k: grouped tag entry; only the truthiness of ``k["group_by"]``
            is used.
        :return: ``(temp_dict, groupby_flag)``; the flag is ``True`` when a
            ``group_by`` section was added.
        """
        groupby_flag = True
        if k["group_by"]:
            # Deep-copy so queries never share (and mutate) the constant template.
            temp_dict["group_by"] = copy.deepcopy(KairosConstants.GROUP_BY)["group_by"]
        else:
            groupby_flag = False
        return temp_dict, groupby_flag

    @staticmethod
    def check_aggregation(temp_dict, k):
        """Fill the first aggregator template in ``temp_dict`` from ``k["aggregators"]``.

        Reads ``name``, ``sampling_value`` and ``sampling_unit`` from
        ``k["aggregators"]``; the name and unit are translated through
        ``KairosConstants.AGGREGATOR_KEYS_MAPPING`` / ``UNITS_MAPPING``.

        :param temp_dict: per-tag scratch dict whose ``aggregators`` list holds
            at least one template entry; mutated in place.
        :param k: grouped tag entry.
        :return: ``(temp_dict, aggregation_flag)``; the flag is ``False`` when
            ``k["aggregators"]`` is empty or ``None``.
        """
        aggregators = k["aggregators"]
        if not aggregators:
            return temp_dict, False
        first_aggregator = temp_dict["aggregators"][0]
        first_aggregator["name"] = KairosConstants.AGGREGATOR_KEYS_MAPPING[
            aggregators["name"]
        ]
        first_aggregator["sampling"]["value"] = aggregators["sampling_value"]
        first_aggregator["sampling"]["unit"] = KairosConstants.UNITS_MAPPING[
            aggregators["sampling_unit"]
        ]
        return temp_dict, True

    def get_query_type(self, train=True):
        """Return a fresh copy of the training or live query template.

        :param train: ``True`` for the training template (its absolute start
            and end are copied from ``inp_dict``), ``False`` for the live one.
        """
        if train:
            logger.info("Building the Training Query")
            output_query = copy.deepcopy(KairosConstants.TRAINING_QUERY)
            output_query["start_absolute"] = self.inp_dict["start_absolute"]
            output_query["end_absolute"] = self.inp_dict["end_absolute"]
        else:
            logger.info("Building the Live Query")
            output_query = copy.deepcopy(KairosConstants.LIVE_QUERY)
        return output_query

    def build_query(self, train=True):
        """Build the full query: one entry in ``metrics`` per tag group.

        :param train: selects the training or live template
            (see ``get_query_type``).
        :return: the query dict.
        """
        output_query = self.get_query_type(train)
        metric_name = self.inp_dict["metric_name"]
        tag_details = self.inp_dict["tag_details"]
        if not tag_details:
            logger.info("tag details not found")
            return output_query

        for k in self.group_tags(tag_details):
            tags_skeleton = copy.deepcopy(KairosConstants.TAGS)
            tags_skeleton["name"] = metric_name
            # NOTE(review): "c3" appears to be the Kairos tag key holding the
            # tag ids - confirm against KairosConstants.TAGS.
            temp_dict = {
                "c3": k["tag"],
                "aggregators": copy.deepcopy(KairosConstants.AGGREGATORS)[
                    "aggregators"
                ],
            }
            temp_dict, groupby_flag = self.check_group_by(temp_dict, k)
            temp_dict, aggregation_flag = self.check_aggregation(temp_dict, k)

            # An empty/None aggregators section carries no "align" key; treat
            # it as "no alignment" rather than raising KeyError.
            align = (k["aggregators"] or {}).get("align")
            if align is None:
                logger.info("No align needed")
            elif align in ("Sample", "Start Time", "End Time"):
                first_aggregator = temp_dict["aggregators"][0]
                if align in ("Start Time", "End Time"):
                    # Start/End Time alignment is always paired with the
                    # "Sample" alignment flag here.
                    first_aggregator[KairosConstants.ALIGNMENT_MAPPING["Sample"]] = True
                first_aggregator[KairosConstants.ALIGNMENT_MAPPING[align]] = True
            else:
                logger.warning(f"Unsupported align value {align!r}; ignoring it")

            tags_skeleton["tags"]["c3"] = temp_dict["c3"]
            if groupby_flag:
                tags_skeleton["group_by"] = temp_dict["group_by"]
            if aggregation_flag:
                tags_skeleton["aggregators"] = temp_dict["aggregators"]
            output_query["metrics"].append(tags_skeleton)
        return output_query


class DataPuller(object):
    def __init__(self, db_host, payload, absolute_time=None, optional_payload=None):
        self.db_host_url = db_host
        self.request_url = "{kairos_host}/api/v1/datapoints/query".format(
            kairos_host=self.db_host_url
        )
        self.payload = payload
        self.column_rename = {}
        if absolute_time is not None:
            if "start_relative" in self.payload:
                del self.payload["start_relative"]
            if "end_relative" in self.payload:
                del self.payload["end_relative"]
            self.payload["start_absolute"] = absolute_time["start_absolute"]
            self.payload["end_absolute"] = absolute_time["end_absolute"]
        self.payload = json.dumps(self.payload)

    def get_data(self):
        logger.info("Data for the parameters being pulled from Kairos Database")
        response_data = requests.post(url=self.request_url, data=self.payload).json()
        output_data = response_data["queries"]
        logger.debug("Data pull complete")
        df_final = pd.DataFrame()
        for i in range(len(output_data)):
            grouped_output_data = output_data[i]["results"]
            for each_grouped_data in grouped_output_data:
                value = each_grouped_data["values"]
                tag_id = each_grouped_data["group_by"][0]["group"]["c3"]
                try:
                    logger.debug(
                        "Renamed {} to {} in Data".format(
                            tag_id, self.column_rename[tag_id]
                        )
                    )
                    column_name = self.column_rename[tag_id]
                except KeyError as ke:
                    logger.debug(f"Column Renaming Logic not found for {tag_id} - {ke}")
                    column_name = tag_id
                df_column_data = pd.DataFrame(
                    data=value, columns=["timestamp", column_name]
                )
                if df_final.empty:
                    df_final = df_column_data
                else:
                    df_final = df_final.merge(
                        df_column_data,
                        how="outer",
                        left_on="timestamp",
                        right_on="timestamp",
                    )
        df_final["epochtime"] = df_final["timestamp"]
        df_final["timestamp"] = (
            pd.to_datetime(df_final["timestamp"], unit="ms")
            .dt.tz_localize("UTC")
            .dt.tz_convert("Asia/Kolkata")
        )
        df_final.to_csv("data-upload.csv", index=False)
        df_final["shift"] = df_final.apply(self.shift_identifier, axis=1)
        df_final["date"] = df_final.apply(self.shift_date_identifier, axis=1)
        logger.debug(
            "Final number of columns : {}".format(str(len(list(df_final.columns))))
        )
        df_final.to_csv("data-upload", index=False)
        return df_final

    @staticmethod
    def shift_identifier(row):
        # morning 6 am to afternoon 2 pm is shift A, afternoon 2 pm to evening
        # 10 pm is shift B, evening 10 pm to night
        # 6 am is shift C
        if 6 <= row["timestamp"].hour < 14:
            return "A"
        elif 14 <= row["timestamp"].hour < 22:
            return "B"
        else:
            return "C"

    @staticmethod
    def shift_date_identifier(row):
        if 6 <= row["timestamp"].hour < 14:
            return row["timestamp"].date()
        elif 14 <= row["timestamp"].hour < 22:
            return row["timestamp"].date()
        elif 22 <= row["timestamp"].hour <= 23:
            return row["timestamp"].date() + pd.Timedelta(days=0)
        else:
            return row["timestamp"].date() + pd.Timedelta(days=-1)
