from fastapi.encoders import jsonable_encoder
from sqlalchemy import create_engine, inspect
from sqlalchemy.orm import Session

from scripts.config.app_configurations import DBConf
from scripts.db.psql.models.ope_formula_calculation import *
from scripts.logging.logging import logger, logging_config


class QueryLayer:
    """Persistence helper for the OPE summary and shift-wise summary tables.

    Wraps a SQLAlchemy ``Session`` and exposes the insert/update/delete
    operations used for the daily summary rows (``DBModelSummary``) and the
    per-shift rows (``DBModelShiftDetails``). Instantiating the class makes
    sure both tables exist.
    """

    # Markers looked for in incoming record keys; "summary" feeds the summary
    # table, every "shift_<x>" marker feeds the shift-wise table.
    _SHIFT_KEYS = ("summary", "shift_a", "shift_b", "shift_c")

    def __init__(self, db: "Session"):
        """Bind the layer to ``db`` and create the backing tables if missing."""
        self.session: Session = db
        self.table_summary = DBModelSummary
        self.shift_wise_summary = DBModelShiftDetails
        # Echo SQL statements only when the application logs at DEBUG level.
        self.echo = logging_config["level"].upper() == "DEBUG"
        self.create_table(self.table_summary)
        self.create_table(self.shift_wise_summary)

    def create_table(self, table) -> None:
        """Create ``table`` in the assistant database when it does not exist.

        Failures are logged and not propagated, so a start-up problem here
        does not prevent the layer from being constructed.
        """
        engine = None
        try:
            engine = create_engine(DBConf.ASSISTANT_DB_URI, echo=self.echo)
            # checkfirst=True performs the existence check itself, so no
            # separate inspector round trip is needed.
            table.__table__.create(bind=engine, checkfirst=True)
        except Exception as e:
            logger.error(f"Error occurred during start-up: {e}", exc_info=True)
        finally:
            # The engine is only needed for this DDL call; release its
            # connection pool instead of leaking one per invocation.
            if engine is not None:
                engine.dispose()

    @property
    def column_id(self) -> str:
        """Name of the ``id`` column."""
        return "id"

    @property
    def column_date(self) -> str:
        """Name of the ``date`` column."""
        return "date"

    @property
    def column_step_id(self) -> str:
        """Name of the ``step_id`` column."""
        return "step_id"

    @property
    def column_unadjusted_loss_in_time(self) -> str:
        """Name of the ``unadjusted_loss_in_time`` column."""
        return "unadjusted_loss_in_time"

    @property
    def column_booked_loss_in_time(self) -> str:
        """Name of the ``booked_loss_in_time`` column."""
        return "booked_loss_in_time"

    def _summary_query(self, step_id, date):
        """Return the summary-table query restricted to ``step_id`` and ``date``."""
        model = self.table_summary
        return self.session.query(model).filter(model.step_id == step_id,
                                                model.date == date)

    def add_to_table(self, records, step_id, date) -> list:
        """Convert ``records`` into ORM objects that still have to be saved.

        Each record is a mapping whose keys contain one of the markers in
        ``_SHIFT_KEYS``; the part of the key before ``_<marker>`` is used as
        the field name. Note the marker test is a substring match.

        * For the ``summary`` marker: when a summary row already exists for
          ``step_id``/``date`` it is updated and committed immediately;
          otherwise a new summary object is queued.
        * For every ``shift_<x>`` marker a shift-wise object is queued with
          ``shift`` set to the upper-cased ``<x>``.

        Returns:
            The list of queued (not yet persisted) ORM objects.

        Raises:
            TypeError: propagated unchanged.
            Exception: any other error is logged, the session is rolled back
                and the error is re-raised, so callers never receive ``None``
                in place of the list.
        """
        pending = []
        try:
            for record in records:
                for shift in self._SHIFT_KEYS:
                    shift_obj = {key.split(f"_{shift}")[0]: value
                                 for key, value in record.items() if shift in key}
                    if not shift_obj:
                        continue
                    payload = SummarySchema(**shift_obj).dict(exclude_none=True)
                    if shift == "summary":
                        query = self._summary_query(step_id, date)
                        if query.first() is None:
                            pending.append(
                                self.table_summary(**payload, step_id=step_id, date=date))
                        elif payload:
                            # Only issue the UPDATE when there is something to write.
                            query.update(payload)
                            self.session.commit()
                    else:
                        shift_in_db = shift.split("shift_")[1].upper()
                        pending.append(
                            self.shift_wise_summary(**payload, step_id=step_id, date=date,
                                                    shift=shift_in_db))
            return pending
        except TypeError:
            raise
        except Exception as e:
            logger.exception(f"Exception occurred while adding to postgres table {e}")
            # Leave the shared session usable and surface the real failure.
            self.session.rollback()
            raise

    def insert_data(self, object_models_list, step_id, date) -> bool:
        """Replace the stored data for ``step_id``/``date`` with ``object_models_list``.

        Existing shift-wise rows are deleted first, then the records are
        converted by :meth:`add_to_table` and bulk-saved.

        Returns:
            ``True`` on success.

        Raises:
            TypeError: propagated unchanged.
            Exception: logged, session rolled back, then re-raised.
        """
        try:
            self.delete_shift_data(step_id, date)
            mappings = self.add_to_table(object_models_list, step_id, date)
            self.session.bulk_save_objects(mappings)
            self.session.commit()
            return True
        except TypeError:
            raise
        except Exception as e:
            logger.exception(e)
            self.session.rollback()
            raise

    def delete_shift_data(self, step_id, date) -> bool:
        """Delete all shift-wise rows for ``step_id``/``date`` and commit.

        Returns:
            ``True`` on success.

        Raises:
            Exception: logged, session rolled back, then re-raised.
        """
        try:
            model = self.shift_wise_summary
            self.session.query(model) \
                .filter(model.step_id == step_id, model.date == date) \
                .delete()
            self.session.commit()
            return True
        except Exception as e:
            logger.error(e)
            self.session.rollback()
            raise

    def update_losses(self, unaccounted_losses, booked_loss, date, step_id) -> bool:
        """Store the loss figures on the summary row for ``step_id``/``date``.

        Returns:
            ``True`` when an existing row was updated, ``False`` when no row
            existed and a new one was inserted.

        Raises:
            Exception: logged, session rolled back, then re-raised.
        """
        try:
            existing = self._summary_query(step_id, date).first()
            if existing is None:
                new_row = self.table_summary(step_id=step_id, date=date,
                                             unadjusted_loss_in_time=unaccounted_losses,
                                             booked_loss_in_time=booked_loss)
                self.session.bulk_save_objects([new_row])
                # Commit here as well; otherwise the insert is lost unless
                # some later call on the same session happens to commit.
                self.session.commit()
                return False
            # Touch only the two loss columns so the remaining attributes are
            # not rewritten with JSON-coerced copies of themselves.
            setattr(existing, self.column_unadjusted_loss_in_time, unaccounted_losses)
            setattr(existing, self.column_booked_loss_in_time, booked_loss)
            self.session.commit()
            return True
        except Exception as e:
            logger.error(e)
            self.session.rollback()
            raise
