import os
import socket
import time
from copy import deepcopy
from datetime import datetime, timedelta

import httpx
import pytz
from dateutil import parser
from fastapi import Request

from scripts.config.app_configurations import KafkaConf, PathToServices
from scripts.constants import CommonKeys, CommonConstants
from scripts.constants.api import EventsEndPoints
from scripts.constants.app_constants import CommonStatusCode
from scripts.constants.date_constants import ui_time_format_data
from scripts.core.engine.task_engine import TaskEngine
from scripts.core.schemas.forms import CustomActionsModel
from scripts.core.schemas.other_schemas import ExternRequest
from scripts.db import mongo_client, TaskInstanceData, Trigger, TaskInstance
from scripts.db.mongo.ilens_assistant.collections.logbook import LogbookInfo
from scripts.db.mongo.ilens_configuration.aggregations.config_aggregate import ConfigAggregate
from scripts.db.mongo.ilens_configuration.collections.customer_projects import CustomerProjects
from scripts.db.mongo.ilens_configuration.collections.lookup_table import LookupTable
from scripts.db.mongo.ilens_configuration.collections.shifts import Shifts
from scripts.db.mongo.ilens_configuration.collections.unique_id import UniqueIdSchema, UniqueId
from scripts.db.mongo.ilens_configuration.collections.user import User
from scripts.db.mongo.ilens_configuration.collections.user_project import UserProject
from scripts.logging.logging import logger
from scripts.utils.ilens_publish_data import KairosWriter


class CommonUtils(CommonKeys):
    def __init__(self, project_id=None):
        """Instantiate the collection handles and helpers used by the utility methods.

        :param project_id: Optional project identifier forwarded to the
            project-aware collaborators (unique id, task engine, logbook, task
            instance data, trigger, task instance and lookup table).
        """
        # User and user-project handles are built without a project id.
        self.user_conn = User(mongo_client)
        self.user_proj = UserProject(mongo_client)
        self.unique_con = UniqueId(mongo_client, project_id=project_id)
        # Full URL of the "create event" endpoint used by trigger_create_event.
        self.events_api = f"{PathToServices.ILENS_EVENTS}{EventsEndPoints.api_create_event}"
        self.task_engine = TaskEngine(project_id=project_id)
        self.logbook_conn = LogbookInfo(mongo_client=mongo_client, project_id=project_id)
        self.task_inst_data = TaskInstanceData(mongo_client, project_id=project_id)
        self.trigger_conn = Trigger(mongo_client, project_id=project_id)
        self.task_instance_conn = TaskInstance(mongo_client, project_id=project_id)
        self.customer_projects_con = CustomerProjects(mongo_client=mongo_client)
        self.config_aggregate = ConfigAggregate()
        self.lookup_data_conn = LookupTable(mongo_client, project_id=project_id)
        # Event code used when no "event" action supplies one; empty string if
        # the DEFAULT_EVENT_CODE environment variable is unset.
        self.default_code = os.environ.get("DEFAULT_EVENT_CODE", "")

    @staticmethod
    def get_time_now():
        return time.time() * 1000

    @staticmethod
    def get_ip_of_user():
        """Return the address that this machine's own hostname resolves to.

        Despite the name, nothing about a remote user is consulted: the value
        comes from resolving ``socket.gethostname()``.
        """
        return socket.gethostbyname(socket.gethostname())

    @staticmethod
    def meta_composer(user_id, is_update: bool = False):
        if is_update:
            meta_dict = dict(updated_by=user_id,
                             updated_at=int(time.time()))
            return meta_dict
        meta_dict = dict(created_by=user_id,
                         created_at=int(time.time()))
        return meta_dict

    @staticmethod
    def get_time_in_ms():
        return int(time.time() * 1000)

    def get_user_roles(self, user_id):
        user_rec = self.user_conn.find_user(user_id)
        user_rec = user_rec if bool(user_rec) else {}
        return user_rec.get("userrole", [])

    @staticmethod
    def get_time_by_ts(timestamp, timezone, time_format=None):
        """Convert an epoch timestamp to a datetime (or string) in a timezone.

        :param timestamp: Epoch value handed directly to
            ``datetime.fromtimestamp`` (i.e. seconds).
        :param timezone: Timezone name understood by ``pytz.timezone``.
        :param time_format: Optional ``strftime`` pattern.
        :return: Formatted string when ``time_format`` is truthy, otherwise the
            timezone-aware ``datetime``.
        """
        localized = datetime.fromtimestamp(timestamp, pytz.timezone(timezone))
        if not time_format:
            return localized
        return str(localized.strftime(time_format))

    @staticmethod
    def convert_str_to_ts(_date, _time, _format, tz):
        """Parse a date string and a time string into epoch milliseconds.

        :param _date: Date part of the text.
        :param _time: Time part of the text.
        :param _format: ``strptime`` pattern matching ``"<_date> <_time>"``.
        :param tz: Timezone name in which the wall-clock value is interpreted.
        :return: Epoch milliseconds as an int. The timestamp is truncated to
            whole seconds before scaling, so the result is a multiple of 1000.
        """
        zone = pytz.timezone(tz)
        naive_moment = datetime.strptime(f"{_date} {_time}", _format)
        aware_moment = zone.localize(naive_moment)
        whole_seconds = int(aware_moment.timestamp())
        return whole_seconds * 1000

    def get_user_meta(self, user_id=None, check_flag=False):
        data_for_meta = {}
        if check_flag:
            data_for_meta[self.KEY_CREATED_BY] = user_id
            data_for_meta[self.KEY_CREATED_TIME] = int(time.time() * 1000)
        data_for_meta[self.KEY_UPDATED_AT] = user_id
        data_for_meta[self.KEY_LAST_UPDATED_TIME] = int(time.time() * 1000)
        return data_for_meta

    def get_user_name_from_id(self, user_id):
        user = self.user_conn.find_user(user_id)
        return user.get("name", "") if bool(user) else ""

    @staticmethod
    def time_zone_converter(epoch_ts, tz, to_format=None):
        """Convert epoch milliseconds to a datetime (or string) in a timezone.

        :param epoch_ts: Epoch milliseconds; floor-divided by 1000 first, so
            sub-second precision is dropped.
        :param tz: Timezone name understood by ``pytz.timezone``.
        :param to_format: Optional ``strftime`` pattern.
        :return: Formatted string when ``to_format`` is truthy, otherwise the
            timezone-aware ``datetime``.
        """
        whole_seconds = epoch_ts // 1000
        localized = datetime.fromtimestamp(whole_seconds, tz=pytz.timezone(tz))
        if to_format:
            return str(localized.strftime(to_format))
        return localized

    @staticmethod
    def add_days_to_epoch(days, ts, tz):
        """Shift an epoch-millisecond timestamp by a number of days.

        :param days: Number of days to add (passed to ``timedelta(days=...)``).
        :param ts: Epoch milliseconds; floor-divided by 1000 before conversion.
        :param tz: Timezone name used to build the intermediate datetime.
        :return: Shifted instant as epoch milliseconds (int).
        """
        whole_seconds = ts // 1000
        anchor = datetime.fromtimestamp(whole_seconds, pytz.timezone(tz))
        shifted = anchor + timedelta(days=days)
        return int(shifted.timestamp() * 1000)

    @staticmethod
    def get_next_date(_date, _format, num):
        """Add ``num`` days to a formatted date string and re-format it.

        :param _date: Date text matching the pattern selected by ``_format``.
        :param _format: Key into ``ui_time_format_data`` selecting the
            ``strptime``/``strftime`` pattern.
        :param num: Number of days to add.
        :return: The shifted date rendered with the same pattern.
        """
        parsed = datetime.strptime(_date, ui_time_format_data[_format])
        shifted = parsed + timedelta(days=num)
        return shifted.strftime(ui_time_format_data[_format])

    @staticmethod
    def convert_trigger_date_to_epoch(triggers, request_data=None):
        if utc_date := triggers.get("date"):
            epoch_trigger = parser.parse(utc_date).timestamp() * 1000
            if request_data:
                request_data.date = epoch_trigger
            return epoch_trigger
        for each in triggers.keys():
            if "date" in each:
                utc_date = triggers[each]
                epoch_trigger = parser.parse(utc_date).timestamp() * 1000
                if request_data:
                    request_data.date = epoch_trigger
                return epoch_trigger
        return False

    def get_trigger_in_epoch(self, triggers, submitted_data, field_props):
        if epoch_value := self.convert_trigger_date_to_epoch(triggers):
            return epoch_value
        if not all([submitted_data, "data" in submitted_data, submitted_data.get("data")]):
            return False
        trigger_prop_dict = {x: y for x, y in field_props.items() if
                             "triggerOnChange" in y.keys() and y["triggerOnChange"] == "true"}

        for each in submitted_data["data"].keys():
            if "date" in each and each in trigger_prop_dict:
                utc_date = submitted_data["data"][each]
                return parser.parse(utc_date).timestamp() * 1000
        return False

    @staticmethod
    def get_hierarchy_name(input_data: str, site_data: dict):
        final_response = str()
        try:
            hierarchy_type = input_data.split("_")[0]
            hierarchy_id = f"{hierarchy_type}_id"
            hierarchy_name = f"{hierarchy_type}_name"
            if hierarchy_type == "site":
                return_data = site_data.get("site_name", str())
                return return_data
            return_data = deepcopy(site_data.get(hierarchy_type))
            for each_data in return_data:
                if each_data[hierarchy_id] == input_data:
                    final_response = each_data.get(hierarchy_name)
                    break
            return final_response
        except Exception as e:
            logger.error(f"Error while fetching hierarchy details:{str(e)}")
        return final_response

    @staticmethod
    def auditing_with_kafka(audits):
        """Send audit records to the audit topic; failures are logged, not raised.

        :param audits: Sized collection of audit records handed to
            ``KairosWriter.audit_data`` together with ``KafkaConf.audit_topic``.
        :return: ``None``.
        """
        try:
            kairos_writer = KairosWriter()
            logger.debug(f"Data going to kafka writer in audit logs, len: {len(audits)}")
            kairos_writer.audit_data(audits, KafkaConf.audit_topic)
            logger.debug("Audited data successfully")
        except Exception as e:
            # The exception is embedded in the message: passing it as a bare
            # extra argument leaves the logger with no placeholder to fill,
            # so the failure detail would be lost.
            logger.exception(f"Failed in auditing_with_kafka: {e}")

    @staticmethod
    def publish_data_to_kafka(tag_dict, project_id):
        """Publish tag data, routing older timestamps to the backdated topic.

        Entries with empty data are dropped. Entries whose timestamp key is
        before the start of the current day (as seen by the naive
        ``datetime.now()`` of this process) go to ``KafkaConf.backdated_topic``;
        the rest go to ``KafkaConf.topic``. Failures are logged, not raised.

        :param tag_dict: Mapping of epoch-millisecond timestamp to tag data.
        :param project_id: Project identifier forwarded to the writer.
        :return: ``None``.
        """
        try:
            kairos_writer = KairosWriter()
            logger.debug(f"Data going to kafka writer, len: {len(tag_dict)}")
            midnight = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0).timestamp() * 1000
            # To avoid rule execution on backdated timestamps
            backdated = {ts: data for ts, data in tag_dict.items() if ts < midnight and data}
            non_backdated = {ts: data for ts, data in tag_dict.items() if ts >= midnight and data}
            if backdated:
                kairos_writer.write_data(backdated, KafkaConf.backdated_topic, project_id)
                logger.debug("Published to backdated topic successfully")
            if non_backdated:
                kairos_writer.write_data(non_backdated, KafkaConf.topic, project_id)
                logger.debug("Published to non-backdated topic successfully")
        except Exception as e:
            # The exception is embedded in the message: passing it as a bare
            # extra argument leaves the logger with no placeholder to fill,
            # so the failure detail would be lost.
            logger.exception(f"Failed in publish_data_to_kafka: {e}")

    @staticmethod
    def get_updated_reference_data(records: list):
        try:
            if bool(records):
                if len(records) == 1:
                    return records[0]
                final_dict = records[0]
                temp_dict = {}
                for _data in records[1:]:
                    temp_dict |= _data["data"]
                temp_dict.update(final_dict["data"])
                final_dict["data"] = temp_dict
                return final_dict
            return {}
        except Exception as e:
            logger.exception(f"Exception occurred while fetching the reference data {e}")
            return {}

    def get_user_roles_by_project_id(self, user_id, project_id):
        user_rec = self.user_conn.find_user_by_project_id(user_id=user_id, project_id=project_id)

        user_rec = user_rec if bool(user_rec) else {}
        if not user_rec:
            user_rec = self.user_proj.fetch_user_project(user_id=user_id, project_id=project_id)
            user_rec = user_rec if bool(user_rec) else {}
        return user_rec.get("userrole", [])

    def get_next_id(self, _param):
        """Advance the counter stored for ``_param`` and persist it.

        When the stored record has no id yet, the counter is initialised to
        ``"100"`` through ``insert_record``; otherwise the stored id is
        incremented by one and written through ``update_record``.

        :param _param: Key identifying the counter.
        :return: Whatever ``insert_record`` / ``update_record`` returns.
        """
        next_record = UniqueIdSchema(key=_param)
        current = self.unique_con.find_one_record(key=_param)
        if current.id:
            # Ids are kept as strings; convert, bump, convert back.
            next_record.id = str(int(current.id) + 1)
            return self.unique_con.update_record(next_record)
        next_record.id = "100"
        return self.unique_con.insert_record(next_record)

    @staticmethod
    def get_iso_format(timestamp, timezone='UTC', timeformat=CommonConstants.__iso_format__):
        """Render an epoch timestamp in a timezone, formatted by ``timeformat``.

        :param timestamp: Epoch value handed directly to
            ``datetime.fromtimestamp`` (i.e. seconds).
        :param timezone: Timezone name understood by ``pytz.timezone``.
        :param timeformat: ``strftime`` pattern; a falsy value returns the
            timezone-aware ``datetime`` itself.
        :return: Formatted string, or the ``datetime`` when no format is given.
        """
        moment = datetime.fromtimestamp(timestamp, pytz.timezone(timezone))
        if timeformat:
            return moment.strftime(timeformat)
        return moment

    def get_shift(self, project_id, from_time: str, end_time: str):
        """Return the name of the first shift that contains the given interval.

        Only the first entry of each shift's ``activities`` list is consulted
        for the shift's start and end time.

        :param project_id: Project whose shifts are fetched.
        :param from_time: Interval start as ``"HH:MM"``.
        :param end_time: Interval end as ``"HH:MM"``.
        :return: ``shift_name`` of the first matching shift, or ``""``.
        :raises Exception: Any error is logged and re-raised.
        """
        try:
            shift_docs = Shifts(mongo_client, project_id).find_shifts_by_project_id(project_id=project_id)
            for shift_doc in shift_docs:
                activities = shift_doc.get("activities", [])
                if not activities:
                    continue
                first_activity = activities[0]
                window_start = first_activity.get("shift_start")
                window_end = first_activity.get("shift_end")
                if self.in_shift(from_time, end_time, window_start, window_end):
                    return shift_doc.get("shift_name", "")
            return ""
        except Exception as e:
            logger.exception(e)
            raise

    @staticmethod
    def in_shift(*args):
        now = datetime.now()
        args_modified = []
        for each in args:
            if each:
                hour_minute = each.split(":")
                args_modified.append(now.replace(hour=int(hour_minute[0]), minute=int(hour_minute[1]), second=0,
                                                 microsecond=0).timestamp())
        if len(args_modified) == 4:
            if args_modified[0] >= args_modified[2] and args_modified[1] <= args_modified[3]:
                return True
        return False

    def trigger_create_event(self, request_data, task_data, user_id, request_obj: Request):
        """Build a "human event" for a task action and post it to the events API.

        The event code comes from the workflow's actions for the user's first
        role (the last action of type ``"event"`` wins) and defaults to
        ``self.default_code``. The asset id is the hierarchy string resolved
        from the task's logbook and the project's site template.

        :param request_data: Mapping with ``project_id``, ``type``, ``tz``,
            ``task_id`` and optionally ``submitted_data``.
        :param task_data: Mapping with ``associated_workflow_id``,
            ``associated_workflow_version`` and ``logbook_id``.
        :param user_id: Acting user; their first project role selects the
            actions.
        :param request_obj: Incoming request whose cookies are forwarded to the
            events API.
        :return: Decoded JSON body when the API answers with a success status,
            otherwise ``None``.
        :raises Exception: Any error (including a user without roles, which
            surfaces as ``IndexError``) is logged and re-raised.
        """
        try:
            project_id = request_data.get("project_id")
            site_templates = self.customer_projects_con.get_project_data_by_aggregate(
                self.config_aggregate.get_project_template(project_id))
            site_templates = site_templates[0].get("data") if site_templates else []
            hierarchy_id_str = ""
            role_id = self.get_user_roles_by_project_id(user_id=user_id, project_id=project_id)
            user_role = role_id[0]
            actions = self.get_actions(workflow_id=task_data.get('associated_workflow_id'),
                                       workflow_version=task_data.get('associated_workflow_version'),
                                       user_role=user_role, on_click=request_data.get('type'))
            event_code = self.default_code
            for state in actions:
                if state["action_type"] == 'event':
                    event_code = state.get('event_codes', event_code)
            logbook_data = self.logbook_conn.find_by_id(task_data.get("logbook_id"))
            if hierarchy := self.task_engine.get_hierarchy(logbook_data.dict(), task_data):
                hierarchy_id_str = self.task_engine.get_hierarchy_string(hierarchy, site_templates)
            event_time = datetime.now().astimezone(pytz.timezone(request_data.get("tz"))).strftime("%Y-%m-%d %H:%M:%S")
            event_dict = dict(asset_id=hierarchy_id_str, processes=','.join(logbook_data.logbook_tags),
                              user_action=request_data.get("type"),
                              user_trigger_data=dict(task_id=request_data.get("task_id"),
                                                     logbook_id=logbook_data.logbook_id,
                                                     submitted_data=request_data.get('submitted_data', {})),
                              query_key='task_id')
            event_final_dict = dict(event_code=event_code, event_time=event_time, event_src_type="human_events",
                                    project_id=project_id, event_table="process_human_events",
                                    event_data=event_dict)
            event_payload = dict(data=event_final_dict, project_id=project_id)
            with httpx.Client() as client:
                resp = client.post(url=self.events_api, cookies=request_obj.cookies, json=event_payload, timeout=15)
            # Cookies carry the caller's session credentials, so they are
            # deliberately kept out of the log output.
            response_summary = f"Resp Message:{resp.status_code} \nRest API: {self.events_api}"
            if resp.status_code in CommonStatusCode.SUCCESS_CODES:
                json_res = resp.json()
                logger.info(response_summary)
                return json_res
            if resp.status_code == 404:
                logger.info(f"Module not found: {self.events_api}")
            elif resp.status_code == 401:
                logger.info(f"Unauthorized to execute request on {self.events_api}")
            logger.info(response_summary)
            return None
        except Exception as e:
            logger.exception(e)
            raise

    def get_actions(self, workflow_id, workflow_version, user_role, on_click):
        trigger_data = self.trigger_conn.fetch_by_id(workflow_id=workflow_id,
                                                     workflow_version=workflow_version,
                                                     role=user_role,
                                                     on_click=on_click)
        actions = trigger_data.actions
        return actions

    @staticmethod
    def hit_external_service(api_url, payload=None, request_cookies=None,
                             timeout=int(os.environ.get("REQUEST_TIMEOUT", default=30)), method="post", params=None,
                             auth=None):
        """Call an external HTTP service, retrying up to three times.

        A pause of three seconds is inserted between consecutive attempts (and
        not after the last one), so the retries are actually spaced out.

        :param api_url: Target URL.
        :param payload: Optional JSON body; only sent when truthy.
        :param request_cookies: Optional cookies forwarded with the request.
        :param timeout: Request timeout in seconds; the default is read from
            the ``REQUEST_TIMEOUT`` environment variable (30 if unset) when the
            function is defined.
        :param method: Name of the ``httpx.Client`` method to use, e.g.
            ``"post"`` or ``"get"``.
        :param params: Optional query parameters.
        :param auth: Optional auth value forwarded to the client call.
        :return: Decoded JSON body of the first successful response, or
            ``None`` when every attempt fails.
        :raises ModuleNotFoundError: When the service answers 404.
        :raises Exception: Any other error is logged and re-raised.
        """
        max_attempts = 3
        retry_delay_seconds = 3
        try:
            logger.info(f"Inside function to hit external services\nURL - {api_url}")
            payload_json = ExternRequest(url=api_url, timeout=timeout, cookies=request_cookies, params=params,
                                         auth=auth)
            payload_json = payload_json.dict(exclude_none=True)
            if payload:
                payload_json.update(json=payload)
            with httpx.Client() as client:
                method_type = getattr(client, method)
                for attempt in range(1, max_attempts + 1):
                    resp = method_type(**payload_json)
                    logger.info(f"Resp Code:{resp.status_code}")
                    if resp.status_code in CommonStatusCode.SUCCESS_CODES:
                        return resp.json()
                    elif resp.status_code == 404:
                        logger.info(f"Module not found: {api_url}")
                        raise ModuleNotFoundError
                    elif resp.status_code == 401:
                        logger.info(f"Unauthorized to execute request on {api_url}")
                    # Cookies carry session credentials, so they are
                    # deliberately kept out of the log output.
                    logger.info(f"Resp Message:{resp.status_code} \nRest API: {api_url}")
                    if attempt < max_attempts:
                        time.sleep(retry_delay_seconds)
            return None
        except Exception as e:
            logger.error(e)
            raise

    @staticmethod
    def get_task_time(task_time, custom_model: CustomActionsModel, task_property_name, task_type="start"):
        """Pick the effective task time from submitted data or the task itself.

        When ``task_property_name`` is present in the submitted data, its value
        is parsed as a date. For ``task_type == "start"`` a parsed value that
        is not later than the task's own time is replaced by the task's time.
        In every fallback case the task's time (or the current time when
        ``task_time`` is falsy) is returned.

        The replacement in the "start" case is a timezone-aware datetime, so
        every path returns a ``datetime`` rather than a raw float epoch on
        that one branch.

        :param task_time: Task time in epoch milliseconds; falsy means "now".
        :param custom_model: Object exposing ``submitted_data`` (mapping) and
            ``tz`` (timezone name).
        :param task_property_name: Key of the submitted value holding the time.
        :param task_type: ``"start"`` (case-insensitive) enables the clamp to
            the task's own time.
        :return: ``datetime`` of the effective task time.
        """
        required_task_time = None
        try:
            task_time = task_time / 1000 if task_time else time.time()
            if task_property_name in custom_model.submitted_data:
                try:
                    required_task_time = parser.parse(timestr=custom_model.submitted_data.get(task_property_name))
                    if task_type.lower() == "start":
                        task_created_at = datetime.fromtimestamp(task_time, tz=pytz.timezone(custom_model.tz))
                        # NOTE(review): a timezone-naive parsed value makes this
                        # comparison raise TypeError, which the handler below
                        # turns into a fallback to the task time - confirm
                        # that this is intended.
                        if required_task_time <= task_created_at:
                            logger.info("OEE Start Time is less than Task Start Time")
                            required_task_time = task_created_at
                except Exception as e:
                    logger.info(f"Exception occurred while converting datetime {e.args}")
                    required_task_time = None
        except Exception as e:
            logger.exception(f"Exception occurred while fetching the task creation time {e.args}")
        if not required_task_time:
            required_task_time = datetime.fromtimestamp(task_time, tz=pytz.timezone(custom_model.tz))
        return required_task_time
