from sqlalchemy import Float, Integer, Text, text, update, create_engine
from sqlalchemy import MetaData, Column, Table
from sqlalchemy.dialects.postgresql import JSON
from sqlalchemy.orm import Session
from sqlalchemy_utils import database_exists, create_database

from scripts.config.app_configurations import DBConf
from scripts.db.psql.databases import Base
from scripts.logging.logging import logger as LOG, logging_config
from scripts.utils.common_utils import CommonUtils


class TicketEntry(Base):
    """ORM model for the ``ticket_entry`` table.

    The ``column_*`` static helpers return column names as plain strings so
    callers can build filters / DDL without hard-coding literals.

    NOTE(review): ``column_audit_type`` and ``column_event_id`` return names of
    columns that are NOT declared on this model nor in ``table_def_user_entry``;
    they are kept for backward compatibility.
    """
    __tablename__ = "ticket_entry"
    workflow_id = Column(Text)
    template_id = Column(Text)
    ticket_title = Column(Text)
    site_hierarchy = Column(Text)
    data = Column(JSON)
    user_id = Column(Text)
    created_on = Column(Float(precision=20, decimal_return_scale=True))
    last_updated = Column(Float(precision=20, decimal_return_scale=True))
    expiry_date = Column(Float(precision=20, decimal_return_scale=True))
    assign_to = Column(Text)
    id = Column(Integer, primary_key=True, autoincrement=True)
    event_type = Column(Text)
    event_status = Column(Text)
    project_id = Column(Text)

    @staticmethod
    def column_template_id():
        return 'template_id'

    @staticmethod
    def column_workflow_id():
        return 'workflow_id'

    @staticmethod
    def column_ticket_title():
        return 'ticket_title'

    @staticmethod
    def column_site_hierarchy():
        return 'site_hierarchy'

    @staticmethod
    def column_event_status():
        return 'event_status'

    @staticmethod
    def column_audit_type():
        # No matching column is declared on this model.
        return 'audit_type'

    @staticmethod
    def column_event_type():
        return 'event_type'

    @staticmethod
    def column_event_id():
        # No matching column is declared on this model.
        return 'event_id'

    @staticmethod
    def column_created_on():
        return "created_on"

    @staticmethod
    def column_last_updated():
        return "last_updated"

    @staticmethod
    def column_expiry_date():
        return "expiry_date"

    @staticmethod
    def column_assign_to():
        return "assign_to"

    @staticmethod
    def column_user_id():
        return "user_id"

    @staticmethod
    def column_data():
        return "data"

    @staticmethod
    def column_id():
        return "id"

    @staticmethod
    def column_project_id():
        return "project_id"

    def pagination_search(self, search_query, table):
        """Return an unfiltered query over ``table``.

        NOTE(review): unfinished stub -- ``search_query`` is ignored, and this
        model does not visibly define a ``session`` attribute, so the call
        raises ``AttributeError`` unless one is provided (e.g. by ``Base`` or
        assigned on the instance). Confirm intended behaviour before use.
        """
        return self.session.query(table)

    def table_def_user_entry(self, meta=None):
        """Define a Core ``Table`` mirroring this model's columns on ``meta``.

        :param meta: ``MetaData`` the table is registered on. When omitted a
            fresh ``MetaData`` is created per call. A shared default instance
            would make every call after the first fail, because the table name
            would already be defined on it.
        :return: the newly defined ``Table``.
        """
        if meta is None:
            meta = MetaData()
        return Table(self.__tablename__, meta,
                     Column(self.column_template_id(), Text),
                     Column(self.column_workflow_id(), Text),
                     Column(self.column_ticket_title(), Text),
                     Column(self.column_site_hierarchy(), Text),
                     Column(self.column_data(), JSON),
                     Column(self.column_user_id(), Text),
                     Column(self.column_created_on(), Float(precision=20, decimal_return_scale=True)),
                     Column(self.column_last_updated(), Float(precision=20, decimal_return_scale=True)),
                     Column(self.column_expiry_date(), Float(precision=20, decimal_return_scale=True)),
                     Column(self.column_assign_to(), Text),
                     Column(self.column_event_status(), Text),
                     Column(self.column_event_type(), Text),
                     Column(self.column_id(), Integer, primary_key=True, autoincrement=True),
                     Column(self.column_project_id(), Text))


class SQLDBUtils(CommonUtils):
    """SQLAlchemy helpers (schema bootstrap, filtering, CRUD) bound to a session.

    Filters accepted by ``update``, ``delete`` and ``distinct_values_by_column``
    are iterables of dicts of the form
    ``{"column": <column>, "value": <value>, "expression": <op>}``, where
    ``<op>`` is a key of ``_FILTER_OPERATORS`` (default ``'eq'``). Filters whose
    column is None are skipped; unknown ops leave the query unchanged.
    """

    # Maps a filter "expression" keyword to a comparison builder.
    # NOTE: 'le'/'ge' are STRICT (< / >) while 'lte'/'gte' are inclusive. This
    # is preserved as-is from the original implementation; confirm with callers
    # before changing it.
    _FILTER_OPERATORS = {
        'eq': lambda column, value: column == value,
        'le': lambda column, value: column < value,
        'ge': lambda column, value: column > value,
        'lte': lambda column, value: column <= value,
        'gte': lambda column, value: column >= value,
        'neq': lambda column, value: column != value,
    }

    def __init__(self, db: Session):
        """Bind the helper to ``db``.

        :param db: SQLAlchemy session used by the instance-level helpers.
        """
        self.session: Session = db
        # Filter dict currently being applied; set per item by ``_filter``.
        self.filter = None
        # Echo SQL from engines created here only when logging at DEBUG level.
        self.echo = logging_config["level"].upper() == "DEBUG"
        super().__init__()

    def add_data(self, table):
        """Add the ORM instance ``table`` to the session and commit it."""
        self.session.add(table)
        self.session.commit()

    @staticmethod
    def enable_traceback():
        """Whether error logs in this class include the traceback."""
        return True

    def create_db(self):
        """Create the database at ``DBConf.MAINTENANCE_DB_URI`` if it is missing.

        Best-effort: any failure is logged, not raised.
        """
        engine = None
        try:
            engine = create_engine(DBConf.MAINTENANCE_DB_URI, echo=self.echo)
            if not database_exists(engine.url):
                create_database(engine.url)
        except Exception as e:
            LOG.error(f"Error occurred during start-up: {e}", exc_info=True)
        finally:
            # The engine is only needed for this check; release its pool.
            if engine is not None:
                engine.dispose()

    @staticmethod
    def create_all_tables(engine, meta):
        """Create the ``ticket_entry`` table (if absent) and ensure ``workflow_id`` exists.

        Adding ``workflow_id`` is best-effort: failures are logged as warnings.
        """
        if not engine.dialect.has_table(engine, TicketEntry.__tablename__):
            TicketEntry().table_def_user_entry(meta=meta)
        meta.create_all(engine)
        try:
            column = Column(TicketEntry.column_workflow_id(), Text(), primary_key=False)
            # add_column is a staticmethod: call it on the class. Instantiating
            # SQLDBUtils here would fail, because __init__ requires a session.
            SQLDBUtils.add_column(engine, TicketEntry.__tablename__, column, meta)
        except Exception as e:
            LOG.warning(f"Could not ensure column on {TicketEntry.__tablename__}: {e}", exc_info=True)

    @staticmethod
    def add_column(engine, table_name, column, meta):
        """Add ``column`` to the existing table ``table_name`` unless it is already present.

        :param engine: engine used for reflection and for the ALTER statement.
        :param table_name: name of the (trusted) table to alter.
        :param column: unattached ``Column`` describing the new column.
        :param meta: ``MetaData`` used to reflect the table.
        """
        reflected = Table(table_name, meta, autoload_with=engine)
        existing = {c.name for c in reflected.columns}
        # Compare on the plain name; the compiled form is a Compiled object,
        # not a str, and would never match.
        if column.name in existing:
            return
        preparer = engine.dialect.identifier_preparer
        column_name = column.compile(dialect=engine.dialect)
        column_type = column.type.compile(engine.dialect)
        ddl = 'ALTER TABLE %s ADD COLUMN %s %s' % (preparer.quote(table_name), column_name, column_type)
        with engine.begin() as conn:
            conn.execute(text(ddl))

    @staticmethod
    def key_filter_expression():
        return "expression"

    @staticmethod
    def key_filter_column():
        return "column"

    @staticmethod
    def key_filter_value():
        return "value"

    def filter_expression(self):
        """Return the operator keyword of the current filter (default ``'eq'``)."""
        filter_expression = self.filter.get(self.key_filter_expression(), 'eq')
        LOG.debug(f"Filter expression: {filter_expression}")
        return filter_expression

    def filter_column(self):
        """Return the column of the current filter, or None."""
        column = self.filter.get(self.key_filter_column(), None)
        LOG.debug(f"Filter column: {column}")
        return column

    def filter_value(self):
        """Return the comparison value of the current filter, or None."""
        filter_value = self.filter.get(self.key_filter_value(), None)
        LOG.debug(f"Filter value: {filter_value}")
        return filter_value

    def _filter(self, session_query, filters=None):
        """Apply every filter dict in ``filters`` to ``session_query`` and return it.

        Side effect: leaves ``self.filter`` set to the last filter processed.
        """
        if filters is not None:
            for _filter in filters:
                self.filter = _filter
                if self.filter_column() is None:
                    continue
                session_query = self.get_session_query(session_query=session_query)
        return session_query

    def get_session_query(self, session_query):
        """Apply the current ``self.filter`` to ``session_query``.

        Unknown operator keywords leave the query unchanged. Errors are logged
        and the query is returned as it was before this filter.
        """
        try:
            build = self._FILTER_OPERATORS.get(self.filter_expression())
            if build is not None:
                session_query = session_query.filter(build(self.filter_column(), self.filter_value()))
        except Exception as e:
            LOG.error(f"Error occurred while filtering the session query {e}", exc_info=self.enable_traceback())
        return session_query

    def insert_one(self, session, table, insert_json):
        """Build a ``table`` instance from ``insert_json``, merge it and commit.

        :return: True on success.
        :raises: re-raises any exception after logging and rolling back.
        """
        try:
            row = table()
            for key, value in insert_json.items():
                setattr(row, key, value)
            session.merge(row)
            session.commit()
            session.close()
            return True
        except Exception as e:
            LOG.error(f"Error while inserting the record {e}", exc_info=self.enable_traceback())
            session.rollback()
            raise

    def update(self, table, update_json, filters=None, insert=False, insert_id=None):
        """Update the first row of ``table`` matching ``filters`` with ``update_json``.

        :param table: ORM model class to query.
        :param update_json: mapping of attribute name to new value (not mutated).
        :param filters: filter dicts (see class docstring).
        :param insert: when True and no row matches, insert a new row instead.
        :param insert_id: mapping merged over ``update_json`` for the inserted
            row. The insert is skipped when this is None.
        :return: True when a row was updated or inserted, otherwise False.
        :raises: re-raises any exception after logging and rolling back.
        """
        session = self.session
        try:
            LOG.debug(filters)
            query = self._filter(session_query=session.query(table), filters=filters)
            filtered_row = query.first()
            if filtered_row is None:
                LOG.debug("There are no rows meeting the given update criteria.")
                if not insert:
                    return False
                LOG.debug("Trying to insert a new record")
                if insert_id is None:
                    LOG.warning("ID not provided to insert record. Skipping insert.")
                    return False
                # Build a new dict so the caller's update_json is left untouched.
                insert_json = dict(update_json)
                insert_json.update(insert_id)
                return bool(self.insert_one(session=session, table=table, insert_json=insert_json))
            LOG.debug("Record available to update")
            for key, value in update_json.items():
                setattr(filtered_row, key, value)
            session.commit()
            session.close()
            return True
        except Exception as e:
            LOG.error(f"Error while updating the record {e}", exc_info=self.enable_traceback())
            session.rollback()
            raise

    def update_many(self, table, update_json, filters, conn):
        """Run ``UPDATE table SET update_json WHERE filters`` on ``conn``.

        ``filters`` is a SQL expression (not the filter-dict list). ``conn`` is
        closed afterwards, whether or not the update succeeds.
        """
        try:
            stmt = update(table).where(filters).values(update_json)
            conn.execute(stmt)
        except Exception as e:
            LOG.error(f"Error while updating the record {e}", exc_info=self.enable_traceback())
            raise
        finally:
            conn.close()

    def delete(self, table, filters=None):
        """Delete all rows of ``table`` matching ``filters`` and commit.

        WARNING: with ``filters`` None/empty, every row of ``table`` is deleted.

        :return: True on success.
        :raises: re-raises any exception after logging and rolling back.
        """
        session = self.session
        try:
            query = self._filter(session_query=session.query(table), filters=filters)
            deleted = query.delete()
            if not deleted:
                LOG.debug("There were no records to be deleted")
            session.commit()
            session.close()
            return True
        except Exception as e:
            LOG.error(f"Failed to delete a record {e}", exc_info=self.enable_traceback())
            session.rollback()
            raise

    def distinct_values_by_column(self, table, session, column, filters=None):
        """Return the distinct values of attribute ``column`` of ``table`` after ``filters``."""
        query = session.query(getattr(table, column).distinct().label(column))
        query = self._filter(session_query=query, filters=filters)
        distinct_values = [getattr(row, column) for row in query.all()]
        session.close()
        return distinct_values

    def execute_query(self, table=None, query=None):
        """Execute ``query`` (or ``select * from <table>``) and return rows as dicts.

        ``table`` is interpolated into SQL verbatim, so it must be trusted.
        Plain-string queries are wrapped in ``text()``.
        """
        session = self.session
        if query is None:
            query = f"select * from {table}"
        if isinstance(query, str):
            query = text(query)
        result = session.execute(query)
        keys = list(result.keys())
        output = [dict(zip(keys, row)) for row in result]
        session.close()
        return output

    @staticmethod
    def fetch_from_table(table, session, filter_text, limit_value, skip_value, project_id):
        """Return a page of ``table`` rows matching the raw SQL ``filter_text`` as dicts.

        SECURITY: ``filter_text`` is embedded as raw SQL; never pass untrusted
        input. NOTE(review): ``project_id`` is currently unused.
        """
        LOG.debug(filter_text)
        # text() builds a SQL fragment; the former Text(...) was the column TYPE.
        stmt = (session.query(table)
                .filter(text(filter_text))
                .limit(limit_value)
                .offset(skip_value)
                .statement)
        result = session.execute(stmt)
        keys = list(result.keys())
        output = [dict(zip(keys, row)) for row in result]
        session.close()
        return output


class TicketEntryTable(SQLDBUtils):
    """``SQLDBUtils`` specialised for the ``TicketEntry`` model."""

    def __init__(self, db: Session):
        super(TicketEntryTable, self).__init__(db=db)
        # Model class this helper operates on.
        self.table = TicketEntry
