from fastapi.encoders import jsonable_encoder
from sqlalchemy import func
from sqlalchemy.orm import Session

from scripts.db.db_models import OEEDiscreteTable
from scripts.errors import ILensError
from scripts.logging import logger
from scripts.utils.postgres_util import SQLDBUtils


class DiscreteOEE(SQLDBUtils):
    """Read-only query helpers over ``OEEDiscreteTable``.

    Every method logs any exception through ``logger.exception`` and then
    re-raises it unchanged, so callers see the original error.
    """

    def __init__(self, db: Session):
        """Bind the helper to a SQLAlchemy session and the OEE discrete table.

        Args:
            db: Session handed to the ``SQLDBUtils`` base initialiser.
        """
        super().__init__(db)
        self.table = OEEDiscreteTable

    def get_oee_data_all(self, prod_start_time, prod_end_time, hierarchy) -> list:
        """Return all records of a hierarchy calculated within a time window.

        Args:
            prod_start_time: Inclusive lower bound applied to ``calculated_on``.
            prod_end_time: Inclusive upper bound applied to ``calculated_on``.
            hierarchy: Value the ``hierarchy`` column must equal.

        Returns:
            JSON-encodable records ordered by ``calculated_on``; an empty
            list when nothing matches.
        """
        try:
            query = (
                self.session.query(self.table)
                .filter(
                    self.table.hierarchy == hierarchy,
                    self.table.calculated_on >= prod_start_time,
                    self.table.calculated_on <= prod_end_time,
                )
                .order_by(self.table.calculated_on)
            )
            # Iterating the query executes it; no rows simply yields [].
            return [jsonable_encoder(record) for record in query]
        except Exception as e:
            logger.exception(e)
            raise

    def get_oee_data_by_reference_id(self, reference_id, hierarchy, project_id):
        """Return the earliest-calculated record for a reference id.

        Args:
            reference_id: Value the ``reference_id`` column must equal.
            hierarchy: Value the ``hierarchy`` column must equal.
            project_id: Value the ``project_id`` column must equal.

        Returns:
            The first matching record (ordered by ``calculated_on``) in
            JSON-encodable form, or ``None`` when there is no match.
        """
        try:
            record = (
                self.session.query(self.table)
                .filter(
                    self.table.hierarchy == hierarchy,
                    self.table.reference_id == reference_id,
                    self.table.project_id == project_id,
                )
                .order_by(self.table.calculated_on)
                .first()
            )
            return jsonable_encoder(record) if record else None
        except Exception as e:
            logger.exception(e)
            raise

    def get_batches(self, hierarchy, prod_start_time, prod_end_time) -> list:
        """Return the ``batch_id`` values of a hierarchy within a time window.

        Args:
            hierarchy: Value the ``hierarchy`` column must equal.
            prod_start_time: Inclusive lower bound applied to ``calculated_on``.
            prod_end_time: Inclusive upper bound applied to ``calculated_on``.

        Returns:
            ``batch_id`` values ordered by ``calculated_on``; an empty list
            when nothing matches.
        """
        try:
            query = (
                self.session.query(self.table.batch_id)
                .filter(
                    self.table.hierarchy == hierarchy,
                    self.table.calculated_on >= prod_start_time,
                    self.table.calculated_on <= prod_end_time,
                )
                .order_by(self.table.calculated_on)
            )
            # Each row is a one-element tuple holding the selected column.
            # (Previously a column object was passed to getattr() as the
            # attribute name, which raised TypeError for any non-empty result.)
            return [batch_id for (batch_id,) in query]
        except Exception as e:
            logger.exception(e)
            raise

    def get_batches_info(self, hierarchy, prod_start_time, prod_end_time, tz) -> list:
        """Return reference ids with production start/end times shifted to ``tz``.

        Args:
            hierarchy: Value the ``hierarchy`` column must equal.
            prod_start_time: Inclusive lower bound applied to ``prod_start_time``.
            prod_end_time: Inclusive upper bound applied to ``prod_end_time``.
            tz: Time zone passed as the first argument of the SQL
                ``timezone()`` function.

        Returns:
            JSON-encodable rows carrying ``reference_id`` plus the converted
            end and start times (labelled with the original column keys),
            ordered by ``calculated_on``; an empty list when nothing matches.
        """
        try:
            # The times are returned as converted timestamps, not as
            # pre-formatted strings; formatting is left to the caller.
            local_end = func.timezone(tz, self.table.prod_end_time).label(
                self.table.prod_end_time.key
            )
            local_start = func.timezone(tz, self.table.prod_start_time).label(
                self.table.prod_start_time.key
            )
            query = (
                self.session.query(self.table.reference_id, local_end, local_start)
                .filter(
                    self.table.hierarchy == hierarchy,
                    self.table.prod_start_time >= prod_start_time,
                    self.table.prod_end_time <= prod_end_time,
                )
                .order_by(self.table.calculated_on)
            )
            return [jsonable_encoder(row) for row in query]
        except Exception as e:
            logger.exception(e)
            raise

    def get_chart_data(
            self, prod_start_time, prod_end_time, hierarchy, reference_id
    ):
        """Return one record for a reference id inside a production window.

        Args:
            prod_start_time: Inclusive lower bound applied to ``prod_start_time``.
            prod_end_time: Inclusive upper bound applied to ``prod_end_time``.
            hierarchy: Value the ``hierarchy`` column must equal.
            reference_id: Value the ``reference_id`` column must equal.

        Returns:
            The first matching record in JSON-encodable form.

        Raises:
            ILensError: When no record matches the filters.
        """
        try:
            record = (
                self.session.query(self.table)
                .filter(
                    self.table.hierarchy == hierarchy,
                    self.table.reference_id == reference_id,
                    self.table.prod_start_time >= prod_start_time,
                    self.table.prod_end_time <= prod_end_time,
                )
                .first()
            )
            if record:
                return jsonable_encoder(record)
            # Raised inside the try block so the miss is logged like any
            # other failure before propagating.
            raise ILensError("Record(s) not found")
        except Exception as e:
            logger.exception(e)
            raise