from datetime import datetime

import pytz

from scripts.constants import TagCategoryConstants
from scripts.errors import ErrorCodes
from scripts.logging import logger
from scripts.schemas.batch_oee import OEEDataSaveRequest, BatchOEEData
from scripts.utils.common_utils import CommonUtils


class OEECalculator:
    """Compute the three OEE factors (availability, performance, quality) and OEE itself.

    Every factor is returned as a ratio, not a percentage; callers scale by 100.
    Validation failures are logged once with ``logger.error`` and raised as
    ``ValueError(<error code>)``. Unexpected arithmetic errors are logged with a
    traceback and re-raised.
    """

    def __init__(self):
        self.common_utils = CommonUtils()

    @staticmethod
    def calculate_availability(operating_time, planned_prod_time):
        """Return availability as ``operating_time / planned_prod_time``.

        Raises:
            ValueError: ``ErrorCodes.ERR001`` when operating time exceeds the
                planned production time.
            ZeroDivisionError: when ``planned_prod_time`` is zero (logged, re-raised).
        """
        if operating_time > planned_prod_time:
            logger.error(ErrorCodes.ERR001)
            raise ValueError(ErrorCodes.ERR001)
        try:
            return operating_time / planned_prod_time
        except Exception as e:
            logger.exception(e)
            raise

    @staticmethod
    def calculate_performance(units_produced, cycle_time, operating_time):
        """Return performance as ``productive_time / operating_time``.

        ``productive_time`` is ``units_produced * (1 / cycle_time)``.
        NOTE(review): this divides by ``cycle_time``, i.e. treats it as a rate
        (units per time unit) rather than time per unit -- confirm against the
        cycle-design tag definition.

        Raises:
            ValueError: ``ErrorCodes.ERR002`` when ``cycle_time`` or
                ``operating_time`` is zero; ``ErrorCodes.ERR003`` when the
                productive time exceeds the operating time.
        """
        try:
            if cycle_time == 0 or operating_time == 0:
                logger.error(ErrorCodes.ERR002)
                raise ValueError(ErrorCodes.ERR002)
            productive_time = units_produced * (1 / cycle_time)
            if productive_time > operating_time:
                logger.error(ErrorCodes.ERR003)
                raise ValueError(ErrorCodes.ERR003)
            return productive_time / operating_time
        except ValueError:
            # Validation failures were already logged via logger.error above;
            # re-raise without logging a second, traceback-level entry.
            raise
        except Exception as e:
            logger.exception(e)
            raise

    @staticmethod
    def calculate_productive_time(units_produced, cycle_time):
        """Return productive time as ``units_produced * (1 / cycle_time)``.

        Raises:
            ValueError: ``ErrorCodes.ERR002`` when ``cycle_time`` is zero.
        """
        try:
            if cycle_time == 0:
                logger.error(ErrorCodes.ERR002)
                raise ValueError(ErrorCodes.ERR002)
            return units_produced * (1 / cycle_time)
        except ValueError:
            # Already logged via logger.error above; avoid a duplicate log entry.
            raise
        except Exception as e:
            logger.exception(e)
            raise

    @staticmethod
    def calculate_quality(rejected_units, total_units):
        """Return quality as the share of good units: ``(total - rejected) / total``.

        Returns 0 when ``total_units`` is zero instead of raising.

        Raises:
            ValueError: ``ErrorCodes.ERR004`` when more units were rejected than produced.
        """
        if rejected_units > total_units:
            logger.error(ErrorCodes.ERR004)
            raise ValueError(ErrorCodes.ERR004)
        try:
            return (total_units - rejected_units) / total_units
        except ZeroDivisionError:
            # No units produced: report zero quality rather than failing the batch.
            return 0
        except Exception as e:
            logger.exception(e)
            raise

    @staticmethod
    def calculate_oee(availability, performance, quality):
        """Return OEE as the product of the three factor ratios."""
        try:
            return availability * performance * quality
        except Exception as e:
            logger.exception(e)
            raise


class OEELossesCalculator:
    """Express OEE losses as percentages of the available (planned) time."""

    @staticmethod
    def calculate_availability_loss(downtime, available_time):
        """Return downtime as a percentage of ``available_time``.

        Raises ZeroDivisionError when ``available_time`` is zero.
        """
        downtime_share = downtime / available_time
        return downtime_share * 100

    @staticmethod
    def calculate_quality_loss(reject_units, cycle_time, available_time):
        """Return the time spent on rejected units as a percentage of ``available_time``.

        The time spent on rejects is ``reject_units * (1 / cycle_time)``.
        Raises ZeroDivisionError when ``cycle_time`` or ``available_time`` is zero.
        """
        reject_time = reject_units * (1 / cycle_time)
        reject_share = reject_time / available_time
        return reject_share * 100

    @staticmethod
    def calculate_performance_loss(
            oee_percentage, availability_loss, quality_loss
    ):
        """Return whatever remains of 100 % after the other losses and the OEE itself."""
        # Subtract in the same order as the original so float results are identical.
        remaining = 100 - availability_loss
        remaining -= quality_loss
        remaining -= oee_percentage
        return remaining


class OEETagFinder:
    """Resolve the mandatory tag ids for an OEE calculation from a category mapping."""

    @staticmethod
    def _require_tag(input_data: dict, category_name, error_code):
        """Return ``input_data[category_name]``; log and raise ``ValueError(error_code)`` if it is missing or falsy."""
        tag_id = input_data.get(category_name)
        if tag_id:
            return tag_id
        logger.error(error_code)
        raise ValueError(error_code)

    @staticmethod
    def get_total_units_tag_id(input_data: dict, category_name=TagCategoryConstants.TOTAL_UNITS_CATEGORY):
        """Return the total-units tag id; raises ``ValueError(ErrorCodes.ERR006)`` if absent."""
        return OEETagFinder._require_tag(input_data, category_name, ErrorCodes.ERR006)

    @staticmethod
    def get_reject_units_tag_id(input_data: dict, category_name=TagCategoryConstants.REJECT_UNITS_CATEGORY):
        """Return the reject-units tag id; raises ``ValueError(ErrorCodes.ERR007)`` if absent."""
        return OEETagFinder._require_tag(input_data, category_name, ErrorCodes.ERR007)

    @staticmethod
    def get_cycle_time_tag_id(input_data: dict, category_name=TagCategoryConstants.OEE_CYCLE_DESIGN_CATEGORY):
        """Return the cycle-design tag id; raises ``ValueError(ErrorCodes.ERR008)`` if absent."""
        return OEETagFinder._require_tag(input_data, category_name, ErrorCodes.ERR008)


class OEEEngine:
    """Run a complete batch OEE calculation for a single save request."""

    def __init__(self):
        self.oee_calc = OEECalculator()
        self.oee_loss_calc = OEELossesCalculator()
        self.common_util = CommonUtils()

    def _planned_production_time(self, request_data):
        """Return the batch window's duration, converted via the request's unit of measure."""
        uom_kind = self.common_util.get_uom_type(uom_type=request_data.uom)
        # Start and end times are expected to be milliseconds since epoch.
        window = self.common_util.get_duration(
            tz=request_data.tz,
            meta=request_data.dict(),
            difference=True,
        )
        return self.common_util.get_diff_duration_in_int(input_time=window, return_type=uom_kind)

    def start_batch_oee_calc(
            self,
            request_data: OEEDataSaveRequest
    ) -> BatchOEEData:
        """Compute OEE factors and losses for ``request_data`` and wrap them in ``BatchOEEData``.

        Factors (availability, performance, quality, oee) are stored as
        percentages. Any exception is logged and re-raised.
        """
        try:
            logger.debug(f"Calculating OEE for {request_data.reference_id}")

            planned_time = self._planned_production_time(request_data)
            # Operating (production) time is the planned time minus downtime.
            operating_time = planned_time - request_data.downtime

            factors = self.oee_calc
            availability = factors.calculate_availability(
                operating_time=operating_time,
                planned_prod_time=planned_time,
            )
            performance = factors.calculate_performance(
                units_produced=request_data.total_units,
                operating_time=operating_time,
                cycle_time=request_data.cycle_time,
            )
            quality = factors.calculate_quality(
                total_units=request_data.total_units,
                rejected_units=request_data.reject_units,
            )
            oee = factors.calculate_oee(
                availability=availability,
                performance=performance,
                quality=quality,
            )
            productive_time = factors.calculate_productive_time(
                cycle_time=request_data.cycle_time,
                units_produced=request_data.total_units,
            )

            losses = self.oee_loss_calc
            availability_loss = losses.calculate_availability_loss(
                downtime=request_data.downtime,
                available_time=planned_time,
            )
            quality_loss = losses.calculate_quality_loss(
                reject_units=request_data.reject_units,
                available_time=planned_time,
                cycle_time=request_data.cycle_time,
            )
            performance_loss = losses.calculate_performance_loss(
                oee_percentage=oee * 100,
                availability_loss=availability_loss,
                quality_loss=quality_loss,
            )

            # Ratios -> percentages, computed once and reused for logging and the result.
            availability_pct = availability * 100
            performance_pct = performance * 100
            quality_pct = quality * 100

            oee_dict = {
                "availability": availability_pct,
                "performance": performance_pct,
                "quality": quality_pct,
            }
            oee_loss = {
                "availability_loss": availability_loss,
                "quality_loss": quality_loss,
                "performance_loss": performance_loss,
            }
            logger.debug(f"OEE: {request_data.reference_id}: {oee_dict}")
            logger.debug(f"OEE Loss: {request_data.reference_id}: {oee_loss}")

            return BatchOEEData(
                **request_data.dict(),
                calculated_on=datetime.now().astimezone(tz=pytz.timezone(request_data.tz)).isoformat(),
                productive_time=productive_time,
                availability=availability_pct,
                performance=performance_pct,
                quality=quality_pct,
                availability_loss=availability_loss,
                quality_loss=quality_loss,
                performance_loss=performance_loss,
                oee=oee * 100,
            )
        except Exception as e:
            logger.exception(f"Exception occurred while calculating batch oee {e.args}")
            raise
