from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session, defer

from scripts.db.db_models import OEEDiscreteTable
from scripts.errors import ILensError
from scripts.logging import logger
from scripts.utils.postgres_util import SQLDBUtils


class DiscreteOEE(SQLDBUtils):
    """Query helpers for the discrete OEE table (``OEEDiscreteTable``).

    Every public method logs any exception with ``logger.exception`` and
    re-raises it unchanged.
    """

    def __init__(self, db: Session):
        # NOTE(review): ``self.session`` used by the methods below is presumably
        # populated from ``db`` by ``SQLDBUtils.__init__`` -- confirm there.
        super().__init__(db)
        self.table = OEEDiscreteTable

    def get_oee_data_all(self, start_time, end_time, hierarchy):
        """Return all OEE rows for ``hierarchy`` calculated within a window.

        :param start_time: inclusive lower bound on ``calculated_on``.
        :param end_time: inclusive upper bound on ``calculated_on``.
        :param hierarchy: value matched exactly against the ``hierarchy`` column.
        :return: list of ``jsonable_encoder`` encodings of the matching rows,
            ordered by ``calculated_on``; an empty list when nothing matches.
        """
        try:
            rows = (
                self.session.query(self.table)
                .filter(
                    self.table.hierarchy == hierarchy,
                    self.table.calculated_on >= start_time,
                    self.table.calculated_on <= end_time,
                )
                .order_by(self.table.calculated_on)
                .all()
            )
            return [jsonable_encoder(row) for row in rows]
        except Exception as e:
            logger.exception(e)
            raise

    def get_oee_data_batch_id(self, batch_id, hierarchy):
        """Return the earliest-calculated OEE row for a batch, or ``None``.

        :param batch_id: value matched exactly against the ``batch_id`` column.
        :param hierarchy: value matched exactly against the ``hierarchy`` column.
        :return: ``jsonable_encoder`` encoding of the first row ordered by
            ``calculated_on``, or ``None`` when no row matches.
        """
        try:
            record = (
                self.session.query(self.table)
                .filter(
                    self.table.hierarchy == hierarchy,
                    self.table.batch_id == batch_id,
                )
                .order_by(self.table.calculated_on)
                .first()
            )
            return jsonable_encoder(record) if record is not None else None
        except Exception as e:
            logger.exception(e)
            raise

    def get_batches(self, hierarchy, start_time, end_time):
        """Return the batch ids for ``hierarchy`` calculated within a window.

        :param hierarchy: value matched exactly against the ``hierarchy`` column.
        :param start_time: inclusive lower bound on ``calculated_on``.
        :param end_time: inclusive upper bound on ``calculated_on``.
        :return: list of ``batch_id`` values ordered by ``calculated_on``
            (duplicates are kept); an empty list when nothing matches.
        """
        try:
            rows = (
                self.session.query(self.table.batch_id)
                .filter(
                    self.table.hierarchy == hierarchy,
                    self.table.calculated_on >= start_time,
                    self.table.calculated_on <= end_time,
                )
                .order_by(self.table.calculated_on)
                .all()
            )
            # Each result row carries exactly one selected column, so unpack it
            # directly instead of looking it up through an attribute name.
            return [batch_id for (batch_id,) in rows]
        except Exception as e:
            logger.exception(e)
            raise

    def get_products(self, hierarchy, start_time, end_time):
        """Return id/start/end of batches lying entirely inside a window.

        A batch qualifies when ``batch_start_time >= start_time`` and
        ``batch_end_time <= end_time``.

        :param hierarchy: value matched exactly against the ``hierarchy`` column.
        :param start_time: inclusive lower bound on ``batch_start_time``.
        :param end_time: inclusive upper bound on ``batch_end_time``.
        :return: list of ``jsonable_encoder`` encodings of
            ``(batch_id, batch_start_time, batch_end_time)`` result rows,
            ordered by ``calculated_on``; an empty list when nothing matches.
        """
        try:
            rows = (
                self.session.query(
                    self.table.batch_id,
                    self.table.batch_start_time,
                    self.table.batch_end_time,
                )
                .filter(
                    self.table.hierarchy == hierarchy,
                    self.table.batch_start_time >= start_time,
                    self.table.batch_end_time <= end_time,
                )
                .order_by(self.table.calculated_on)
                .all()
            )
            # NOTE(review): the encoded shape of a column-tuple result row
            # depends on the SQLAlchemy version in use -- confirm callers'
            # expectations before changing this to ``row._asdict()``.
            return [jsonable_encoder(row) for row in rows]
        except Exception as e:
            logger.exception(e)
            raise

    def get_chart_data(
            self, start_time, end_time, hierarchy, product_id, aggregation=False
    ):
        """Return OEE chart data for batches lying entirely inside a window.

        :param start_time: inclusive lower bound on ``batch_start_time``.
        :param end_time: inclusive upper bound on ``batch_end_time``.
        :param hierarchy: value matched exactly against the ``hierarchy`` column.
        :param product_id: matched against ``batch_id``; only used when
            ``aggregation`` is false.
        :param aggregation: when true, return every matching row with the
            ``hierarchy``, ``batch_id`` and ``uom`` columns deferred (so they are
            left out of the encoding); when false, return the single row for
            ``product_id``.
        :return: a list of encoded rows ordered by ``calculated_on`` (possibly
            empty) when ``aggregation`` is true, otherwise one encoded row (the
            earliest by ``calculated_on`` if several match).
        :raises ILensError: when ``aggregation`` is false and no row matches.
        """
        try:
            window = (
                self.table.hierarchy == hierarchy,
                self.table.batch_start_time >= start_time,
                self.table.batch_end_time <= end_time,
            )
            if aggregation:
                rows = (
                    self.session.query(self.table)
                    .filter(*window)
                    .options(
                        defer(self.table.hierarchy),
                        defer(self.table.batch_id),
                        defer(self.table.uom),
                    )
                    .order_by(self.table.calculated_on)
                    .all()
                )
                # An empty window yields an empty list, not ILensError: the
                # previous truthiness check on the Query object never failed.
                return [jsonable_encoder(row) for row in rows]

            record = (
                self.session.query(self.table)
                .filter(*window, self.table.batch_id == product_id)
                .order_by(self.table.calculated_on)
                .first()
            )
            if record is None:
                raise ILensError("Record(s) not found")
            return jsonable_encoder(record)
        except Exception as e:
            logger.exception(e)
            raise
