import asyncio
import logging
import math
from copy import deepcopy
from datetime import datetime, timedelta
from pprint import pformat
from typing import Dict, Tuple

import httpx

from scripts.constants import StatusCodes
from scripts.errors import KairosDBError, ErrorMessages
from scripts.utils.common_utils import CommonUtils
from . import KairosConstants

# Shared Kairos constants (query API path, time keys, unit conversions, empty result)
# used throughout KairosQuery.
kairos_const = KairosConstants()
# Default of 30 s for connect/write/pool; read=None disables the read timeout so that
# long-running Kairos queries are not cut off.
# NOTE(review): a hung Kairos server can therefore block a read indefinitely.
httpx_timeout = httpx.Timeout(timeout=30, read=None)


class KairosQuery:
    """
    Client for the KairosDB REST query endpoint.

    Provides a blocking ``query``, an async ``query_async`` and ``parallel_query``, which
    splits one long time range into several smaller queries issued concurrently.
    """

    def __init__(self, url):
        """
        :param url: base URL of the Kairos server; ``kairos_const.kairos_query_api`` is appended to it.
        """
        self.url = url
        self.common_util = CommonUtils()

    def query(self, query_data: dict) -> dict:
        """
        An API Interface to query Kairos REST API with a standard query given as input
        :param query_data: query json/dictionary as given in
        documentation: https://kairosdb.github.io/docs/restapi/QueryMetrics.html#id3
        :return: Dictionary with Data, if available
        :raises KairosDBError: on a non-success status code or when the response carries errors.
        Any exception is logged and re-raised.
        """
        try:
            with httpx.Client() as client:
                logging.debug(pformat(query_data))
                r = client.post(
                    self.url + kairos_const.kairos_query_api,
                    json=query_data,
                    timeout=httpx_timeout
                )
                if r.status_code not in StatusCodes.SUCCESS_CODES:
                    logging.error(f"KAIROS RETURNED {r.status_code}")
                    raise KairosDBError
                logging.debug(f"status code: {r.status_code}")
                # Decode the body once and reuse it for validation and the return value.
                result = r.json()
                self.validate_results(result)
                return result
        except Exception as e:
            logging.exception(e)
            raise

    async def query_async(self, query_data: dict, client=None) -> dict:
        """
        An API Interface to query Kairos REST API with a standard query given as input
        :param client: httpx.AsyncClient; if not given, a temporary client is opened and closed.
        :param query_data: query json/dictionary as given in
        documentation: https://kairosdb.github.io/docs/restapi/QueryMetrics.html#id3
        :return: Dictionary with Data, if available. On any exception the error is logged and
        a copy of ``kairos_const.kairos_empty_result`` is returned instead (best-effort).
        """
        try:
            if client:
                return await self.perform_query(query_data, client)
            async with httpx.AsyncClient() as own_client:
                return await self.perform_query(query_data, own_client)
        except Exception as e:
            logging.exception(e)
            # Hand out a copy so a caller mutating the result cannot corrupt the shared constant.
            return deepcopy(kairos_const.kairos_empty_result)

    async def perform_query(self, query_data: dict, client):
        """
        POST one query through an already-open async client and return the decoded response.
        :param query_data: Kairos query dictionary.
        :param client: httpx.AsyncClient used to send the request.
        :return: decoded JSON response.
        :raises KairosDBError: when the response is not a dict or contains an ``errors`` key.
        """
        # AsyncClient.post is a coroutine function: it must be awaited to obtain the Response.
        r = await client.post(self.url + kairos_const.kairos_query_api, json=query_data,
                              timeout=httpx_timeout)
        logging.debug(f"status code: {r.status_code}, elapsed time: {r.elapsed}")
        result = r.json()
        self.validate_results(result)
        return result

    @staticmethod
    def validate_results(response) -> None:
        """
        Validates the response from Kairos. If errors are found, it raises an Error.
        :raises KairosDBError: if ``response`` is not a dict or has an ``errors`` key.
        """
        if not isinstance(response, dict):
            raise KairosDBError(ErrorMessages.K_ERROR2)
        if "errors" in response:
            logging.error(f"Kairos returned with an error: {response.get('errors')}")
            raise KairosDBError(ErrorMessages.K_ERROR2)

    @staticmethod
    def get_timedelta(value):
        """
        Convert a Kairos relative-time dict (keys ``value`` and ``unit``, e.g. the contents of
        ``start_relative`` / ``end_relative``) into a ``timedelta``.
        ``kairos_const.SECONDS_IN_UNIT`` is assumed to map a unit name to its length in seconds.
        """
        seconds = int(value['value']) * kairos_const.SECONDS_IN_UNIT[value['unit']]
        return timedelta(seconds=seconds)

    def get_needed_absolute_time_range(self, time_range, now=None):
        """
        Create date-times from Kairos timestamp data.
        :param time_range: dict, containing Kairos timestamp data: keys {start,end}_{relative,absolute}.
        Absolute values are epoch milliseconds; they take precedence over relative values.
        :param now: datetime.datetime, optional. set to remove drift in time during execution.
        :return: 2-tuple (start, end), both datetime.datetime. end may be NoneType.
        :raises ValueError: if neither ``start_absolute`` nor ``start_relative`` is given.
        """
        if now is None:
            now = datetime.now()

        # Compare against None so an epoch-zero timestamp (0) is still treated as absolute.
        if time_range.get('start_absolute') is not None:
            start = datetime.fromtimestamp(int(time_range['start_absolute']) / 1000)
        elif time_range.get('start_relative'):
            start = now - self.get_timedelta(time_range['start_relative'])
        else:
            raise ValueError("time_range must contain 'start_absolute' or 'start_relative'")

        if time_range.get('end_absolute') is not None:
            end = datetime.fromtimestamp(int(time_range['end_absolute']) / 1000)
        elif time_range.get('end_relative'):
            end = now - self.get_timedelta(time_range['end_relative'])
        else:
            end = None
        return start, end

    def get_chunked_time_ranges(self, time_range, chunk_length=3600, num_chunks=10):
        """
        Given a long kairos range, return up to ``num_chunks`` timestamp pairs so we can
        parallelize COLD calls (ordered new->old). This implements up to second precision.

        Chunks are contiguous and non-overlapping: each chunk starts one second after the end
        of the next-older chunk, and the oldest chunk starts exactly at the requested start.

        :param time_range: dict, generated by populate_time_range containing kairos-formatted keys.
        :param chunk_length: preferred chunk size in seconds (default one hour). It is grown when
            the range would otherwise need more than ``num_chunks`` chunks.
        :param num_chunks: maximum number of chunks (default 10).
        :return: list of 2-tuples (start, end), each converted with
            ``self.common_util.convert_to_timestamp``.
        :raises ValueError: if ``chunk_length`` or ``num_chunks`` is smaller than 1.
        """
        if chunk_length < 1 or num_chunks < 1:
            raise ValueError("chunk_length and num_chunks must both be >= 1")
        convert = self.common_util.convert_to_timestamp
        now = datetime.now()

        start_time, end_time = self.get_needed_absolute_time_range(time_range, now)
        if end_time is None:
            end_time = now
        start_time = start_time.replace(microsecond=0)
        end_time = end_time.replace(microsecond=0)

        elapsed_secs = (end_time - start_time).total_seconds()
        if elapsed_secs <= chunk_length:
            return [(convert(start_time), convert(end_time))]

        if elapsed_secs > chunk_length * num_chunks:
            # Grow chunks so num_chunks of them cover the whole range; ceil (not truncation)
            # so no seconds at the old end of the range are left uncovered.
            chunk_length = math.ceil(elapsed_secs / num_chunks)
        else:
            # Chunks of chunk_length already fit; use only as many as needed.
            num_chunks = math.ceil(elapsed_secs / chunk_length)

        length_td = timedelta(seconds=chunk_length)
        one_second = timedelta(seconds=1)
        chunks = []
        chunk_end = end_time
        for _ in range(num_chunks):
            chunk_start = chunk_end - length_td
            if chunk_start <= start_time:
                # Oldest (possibly partial) chunk: clamp to, and include, the requested start.
                chunks.append((convert(start_time), convert(chunk_end)))
                break
            # +1s so this chunk does not overlap the older chunk, which ends at chunk_start.
            chunks.append((convert(chunk_start + one_second), convert(chunk_end)))
            chunk_end = chunk_start
        return chunks

    @staticmethod
    async def form_query(query: Dict, start_time: int, end_time: int):
        """
        Helper function for forming a query after chunking.
        Returns a deep copy of ``query`` with the two ``kairos_const.kairos_time_keys`` entries
        set to ``start_time`` and ``end_time``; the input query is not modified.
        """
        new_query = deepcopy(query)
        new_query.update(
            {
                kairos_const.kairos_time_keys[0]: start_time,
                kairos_const.kairos_time_keys[1]: end_time,
            }
        )
        return new_query

    async def parallel_query(self, query, chunk_length=3600, num_chunks=10) -> list:
        """
        A helper function that has built-in algorithm to split queries into a given number of chunks.
        The chunk size and number of chunks are configurable; some aggregations may not work well
        when using this.

        The primary purpose of this function is to query large amount of data in parallel to reduce
        query time and load.

        :param query: Kairos query dict containing both ``kairos_const.kairos_time_keys`` entries.
        :param chunk_length: preferred chunk size in seconds, see ``get_chunked_time_ranges``.
        :param num_chunks: maximum number of chunks, see ``get_chunked_time_ranges``.
        :return: list of per-chunk results from ``query_async``, ordered new->old.
        """
        logging.debug(f"kairos query : , {query}")
        time_range = {key: query[key] for key in kairos_const.kairos_time_keys}
        time_chunks = self.get_chunked_time_ranges(
            time_range, chunk_length=chunk_length, num_chunks=num_chunks
        )

        # form_query is a coroutine function, so the gather must be awaited to get the dicts.
        queries = await asyncio.gather(*[
            self.form_query(query, start_time=chunk_start, end_time=chunk_end)
            for chunk_start, chunk_end in time_chunks
        ])
        async with httpx.AsyncClient() as client:
            results = await asyncio.gather(*[self.query_async(q, client) for q in queries])
        return results
