import json
from datetime import datetime

import pandas as pd
import pytz
from sqlalchemy.orm import Session

from scripts.config import DBConf
from scripts.constants import ResponseCodes, CommonConstants, TagCategoryConstants
from scripts.core.engine.oee_calculator import OEEEngine, OEETagFinder
from scripts.core.handlers.tag_handler import TagHierarchyHandler
from scripts.db.mongo.schema.tag_hierarchy import GetTagsLists
from scripts.db.psql.oee_discrete import DiscreteOEE
from scripts.db.redis_connections import oee_production_db
from scripts.errors import DataNotFound
from scripts.logging import logger
from scripts.schemas.batch_oee import OEEDataInsertRequest, BatchOEEData, OEEDataSaveRequest
from scripts.schemas.response_models import DefaultResponse
from scripts.utils.common_utils import CommonUtils
from scripts.utils.kairos_db_util import BaseQuery
from scripts.utils.kairos_db_util.df_formation_util import create_kairos_df
from scripts.utils.kairos_db_util.query_kairos import KairosQuery

# Module-level OEEEngine instance shared by every CalculateBatchOEEHandler; its
# start_batch_oee_calc() is used for both inserting and updating batch OEE records.
oee_engine = OEEEngine()


class CalculateBatchOEEHandler:
    """Calculate, persist and cache batch OEE records.

    A batch row is looked up by ``reference_id``, ``hierarchy`` and ``project_id``.
    The first request for a batch inserts a new row; later requests merge into and
    recalculate the stored row. While a batch has no ``prod_end_time`` its data is
    cached in Redis under ``"<project_id>$<reference_id>"``; the key is deleted once
    a successful update supplies an end time.
    """

    def __init__(self, project_id=None):
        """Create the helpers used for downtime lookup, Kairos queries and tag lookup.

        :param project_id: forwarded to ``TagHierarchyHandler``.
        """
        self.common_util = CommonUtils()
        self.base_query = BaseQuery()
        self.oee_tag_finder = OEETagFinder()
        self.tag_hierarchy_handler = TagHierarchyHandler(project_id=project_id)

    @staticmethod
    def _parse_user_time(value: str, tz: str) -> datetime:
        """Parse a user-supplied timestamp into an aware datetime in ``tz``.

        Naive results of ``CommonConstants.USER_META_TIME_FORMAT`` are treated as
        wall-clock time in ``tz``. Calling ``naive.astimezone(tz)`` would instead
        interpret them in the server's local timezone, making the stored times
        depend on where the service runs. Values that already carry an offset are
        converted to ``tz``.
        """
        zone = pytz.timezone(tz)
        parsed = datetime.strptime(value, CommonConstants.USER_META_TIME_FORMAT)
        if parsed.tzinfo is None:
            # pytz zones must be attached with localize(); replace(tzinfo=...) would pick LMT offsets.
            return zone.localize(parsed)
        return parsed.astimezone(zone)

    def calculate_oee(self, db, request_data: OEEDataInsertRequest):
        """Insert or update the batch OEE record described by ``request_data``.

        Produced and rejected unit counts are read from Kairos for the batch window
        (see ``get_data_for_tags``) and stored on ``request_data`` before saving.

        :param db: database session passed to ``DiscreteOEE``.
        :param request_data: incoming batch data; ``prod_start_time``/``prod_end_time``
            are expected in ``CommonConstants.USER_META_TIME_FORMAT``.
        :return: ``DefaultResponse`` carrying the new OEE calculation (insert) or the
            update status (update).
        :raises Exception: any error is logged and re-raised unchanged.
        """
        table_obj = DiscreteOEE(db=db)
        try:
            record_presence = table_obj.get_oee_data_by_reference_id(reference_id=request_data.reference_id,
                                                                     hierarchy=request_data.hierarchy,
                                                                     project_id=request_data.project_id)
            redis_key = f"{request_data.project_id}${request_data.reference_id}"
            if not record_presence and not request_data.prod_start_time:
                # A new batch without a start time starts now (see below): its production
                # window is empty and the Kairos lookup needs an explicit start, so count zero.
                request_data.total_units, request_data.reject_units = 0, 0
            else:
                request_data.total_units, request_data.reject_units = self.get_data_for_tags(input_data=request_data)

            if not record_presence:
                if request_data.prod_start_time:
                    request_data.prod_start_time = self._parse_user_time(
                        request_data.prod_start_time, request_data.tz).isoformat()
                else:
                    request_data.prod_start_time = datetime.now(tz=pytz.timezone(request_data.tz)).isoformat()
                if request_data.prod_end_time:
                    request_data.prod_end_time = self._parse_user_time(
                        request_data.prod_end_time, request_data.tz).isoformat()
                request_data = OEEDataSaveRequest(**request_data.dict(exclude_none=True))
                request_data.downtime = self.common_util.get_downtime_details_by_hierarchy(
                    hierarchy=request_data.hierarchy, project_id=request_data.project_id)
                oee_calculation = oee_engine.start_batch_oee_calc(request_data=request_data)
                self.save_oee_data(oee_calculation, db)
                if not request_data.prod_end_time:
                    # Batch still running: cache its data. json.dumps() cannot serialise a
                    # pydantic model, so let the model produce its own JSON.
                    oee_production_db.set(name=redis_key,
                                          value=BatchOEEData(**request_data.dict(exclude_none=True)).json())
                return DefaultResponse(
                    status=ResponseCodes.SUCCESS,
                    data=oee_calculation,
                    message="OEE saved Successfully",
                )

            # NOTE(review): unlike the insert path, prod_start_time/prod_end_time are merged
            # into the stored record here without ISO conversion — confirm the engine accepts both.
            status = self.update_oee_data(oee_data=OEEDataSaveRequest(**request_data.dict(exclude_none=True)),
                                          old_record=record_presence, db=db)
            if status:
                if request_data.prod_end_time:
                    oee_production_db.delete(redis_key)
                else:
                    # update_oee_data merged the request into record_presence in place, so the
                    # merged batch data is cached here.
                    # NOTE(review): assumes every value of the DB record is JSON-serialisable — confirm.
                    oee_production_db.set(name=redis_key,
                                          value=json.dumps(record_presence))
            return DefaultResponse(
                status=ResponseCodes.SUCCESS,
                data=status,
                message="OEE updated Successfully",
            )
        except Exception as e:
            logger.exception(f"Exception while saving oee record: {e}")
            raise

    @staticmethod
    def save_oee_data(oee_data: BatchOEEData, db: Session):
        """Insert ``oee_data`` as a new row of the ``DiscreteOEE`` table.

        :return: ``True`` once the insert call returns; errors propagate to the caller.
        """
        table_obj = DiscreteOEE(db=db)
        table_obj.insert_one(table=table_obj.table, insert_json=oee_data.dict())
        return True

    @staticmethod
    def update_oee_data(oee_data: OEEDataSaveRequest, old_record: dict, db: Session):
        """Merge ``oee_data`` into ``old_record``, recalculate OEE and update the stored row.

        ``old_record`` is updated IN PLACE with every non-None field of ``oee_data``;
        ``calculate_oee`` relies on this to cache the merged record. The row to update
        is matched on ``project_id``, ``reference_id`` and ``hierarchy``.

        :return: ``True`` once the update call returns; errors propagate to the caller.
        """
        table_obj = DiscreteOEE(db=db)
        old_record.update(oee_data.dict(exclude_none=True))
        oee_calculation = oee_engine.start_batch_oee_calc(request_data=OEEDataSaveRequest(**old_record))
        filters = [
            {'expression': 'eq', 'column': table_obj.table.project_id, 'value': oee_data.project_id},
            {'expression': 'eq', 'column': table_obj.table.reference_id, 'value': oee_data.reference_id},
            {'expression': 'eq', 'column': table_obj.table.hierarchy, 'value': oee_data.hierarchy},
        ]
        table_obj.update(update_json=oee_calculation.dict(), table=table_obj.table, filters=filters,
                         update_one=True)
        return True

    def get_data_for_tags(self, input_data: OEEDataInsertRequest):
        """Return ``(total_units, reject_units)`` for the batch's production window.

        The window starts at ``prod_start_time`` (required) and ends at
        ``prod_end_time``; when no end time is given the batch is treated as still
        running and the window ends now. Both times are parsed with
        ``CommonConstants.USER_META_TIME_FORMAT`` in ``input_data.tz``. Each value is
        the sum of the ``<tag_id>_diff`` column of the Kairos data frame for the
        hierarchy's total/reject unit tags. If the total-units column is absent both
        values are 0; if only the reject-units column is absent, rejects are 0.

        :raises DataNotFound: if Kairos returns no data for the window.
        """
        total_units_value = 0
        reject_units_value = 0
        try:
            start_epoch = int(self._parse_user_time(input_data.prod_start_time, input_data.tz).timestamp()) * 1000
            if input_data.prod_end_time:
                end_time = self._parse_user_time(input_data.prod_end_time, input_data.tz)
            else:
                end_time = datetime.now(tz=pytz.timezone(input_data.tz))
            end_epoch = int(end_time.timestamp()) * 1000

            hierarchy_tags = self.tag_hierarchy_handler.get_tags_list_by_hierarchy(GetTagsLists(**input_data.dict()))
            total_units_tag_id = self.oee_tag_finder.get_total_units_tag_id(input_data=hierarchy_tags)
            reject_units_tag_id = self.oee_tag_finder.get_reject_units_tag_id(input_data=hierarchy_tags)
            unit_tags = [total_units_tag_id, reject_units_tag_id]

            kairos_util = KairosQuery(url=DBConf.KAIROS_URL)
            data = kairos_util.query(
                self.base_query.form_generic_query(tags_list=unit_tags,
                                                   project_id=input_data.project_id,
                                                   start_epoch=start_epoch, end_epoch=end_epoch))
            master_df = pd.DataFrame()
            data = [data] if not isinstance(data, list) else data
            for each_data in data:
                master_df = create_kairos_df(
                    master_df=master_df,
                    response_data=each_data,
                    tags_list=unit_tags,
                    group_by_tags=[*unit_tags, DBConf.KAIROS_DEFAULT_FULL_TAG],
                    tz=input_data.tz
                )
            if master_df.empty:
                raise DataNotFound

            total_column = f'{total_units_tag_id}_diff'
            reject_column = f'{reject_units_tag_id}_diff'
            if total_column not in master_df.columns:
                return total_units_value, reject_units_value
            total_units_value = master_df[total_column].sum()
            if reject_column in master_df.columns:
                reject_units_value = master_df[reject_column].sum()
            return total_units_value, reject_units_value
        except Exception as e:
            logger.exception(f'Exception occurred while fetching tag details{e.args}')
            raise
