from datetime import date
from typing import Optional, List, Dict

from scripts.constants import StepCategories
from scripts.constants.app_constants import DatabaseNames, CollectionNames
from scripts.constants.db_keys import ReferenceDataKeys
from scripts.db.mongo.schema import MongoBaseSchema
from scripts.utils.mongo_util import MongoCollectionBaseClass


class ReferenceStepSchema(MongoBaseSchema):
    """
    Schema for reference-step records returned by ``ReferenceStep``.

    ``ReferenceStep.find_by_id`` and ``ReferenceStep.find_by_date_and_step``
    build instances with ``ReferenceStepSchema(**record)`` straight from raw
    Mongo records, and ``find_by_date_and_step`` also builds an empty
    ``ReferenceStepSchema()`` when nothing matches. Only the three fields
    below are declared. How other stored keys (e.g. the date, task id or
    entity name the queries filter on) are treated depends on
    ``MongoBaseSchema``. NOTE(review): confirm its extra-field policy.
    """
    # Step identifier; ReferenceStep filters on it via ReferenceDataKeys.KEY_STEP_ID.
    step_id: Optional[str]
    # Metadata sub-document. ReferenceStep filters on nested date,
    # step_category and event_id fields under ReferenceDataKeys.KEY_PROPERTIES
    # (presumably that key is "properties" — confirm).
    # NOTE(review): the dict() defaults here are only safe if MongoBaseSchema
    # copies field defaults per instance (pydantic models do) — confirm.
    properties: Optional[dict] = dict()
    # Step payload; ReferenceStep's update methods write it under
    # ReferenceDataKeys.KEY_DATA (presumably "data" — confirm).
    data: Optional[Dict] = dict()


class ReferenceStep(MongoCollectionBaseClass):
    """
    Data-access wrapper for reference-step documents.

    Operates on collection ``CollectionNames.reference_steps`` in database
    ``DatabaseNames.ilens_assistant``. Field names come from
    ``ReferenceDataKeys`` through the ``key_*`` properties. Several filters
    address nested fields as ``"<key_property>.<field>"``.
    """

    def __init__(self, mongo_client, project_id=None):
        """
        :param mongo_client: client passed through to ``MongoCollectionBaseClass``.
        :param project_id: kept on the instance; no method in this class reads it.
        """
        super().__init__(mongo_client, database=DatabaseNames.ilens_assistant,
                         collection=CollectionNames.reference_steps)
        self.project_id = project_id

    # --- Field-name accessors (values are defined in ReferenceDataKeys) ---

    @property
    def key_step_id(self):
        return ReferenceDataKeys.KEY_STEP_ID

    @property
    def key_date(self):
        return ReferenceDataKeys.KEY_DATE

    @property
    def key_data(self):
        return ReferenceDataKeys.KEY_DATA

    @property
    def key_property(self):
        return ReferenceDataKeys.KEY_PROPERTIES

    @property
    def key_step_category(self):
        return ReferenceDataKeys.KEY_STEP_CATEGORY

    @property
    def key_entity_name(self):
        return ReferenceDataKeys.KEY_ENTITY_NAME

    @property
    def key_event_id(self):
        return ReferenceDataKeys.KEY_EVENT_ID

    @property
    def key_task_id(self):
        return ReferenceDataKeys.KEY_TASK_ID

    # --- Queries ---

    def find_by_id(self, step_id: str) -> Optional[ReferenceStepSchema]:
        """
        Return the first record whose step id equals ``step_id``.

        :return: a ``ReferenceStepSchema`` or ``None`` when nothing matches
            (unlike ``find_by_date_and_step``, which returns an empty schema).
        """
        record = self.find_one({self.key_step_id: step_id})
        if not record:
            return None
        return ReferenceStepSchema(**record)

    def find_by_date_and_step(self, _date: str, step_id: str) -> ReferenceStepSchema:
        """
        Return the record matching ``step_id`` and the top-level ``key_date`` field.

        Note that this filters on the top-level date field (as written by
        ``insert_date``), not on ``"<key_property>.<key_date>"``.

        :return: a ``ReferenceStepSchema``; an empty ``ReferenceStepSchema()``
            when nothing matches.
        """
        record = self.find_one({self.key_date: _date, self.key_step_id: step_id})
        if not record:
            return ReferenceStepSchema()
        return ReferenceStepSchema(**record)

    def find_by_date_and_multi_step(self, _date: str, step_id_list: List,
                                    task_id: Optional[str] = None) -> dict:
        """
        Fetch records for several steps on one date, indexed by step id.

        :param _date: matched against ``"<key_property>.<key_date>"``.
        :param step_id_list: step ids to match with ``$in``.
        :param task_id: extra filter; skipped when falsy.
        :return: ``{step_id: record}``; ``{}`` when nothing matches.
        """
        query = {f"{self.key_property}.{self.key_date}": _date, self.key_step_id: {"$in": step_id_list}}
        if task_id:
            query[self.key_task_id] = task_id
        return self.fetch_data_from_query(query)

    def find_by_multi_step_without_date(self, step_id_list: List, task_id: Optional[str] = None) -> dict:
        """
        Fetch records for several steps regardless of date, indexed by step id.

        :param step_id_list: step ids to match with ``$in``.
        :param task_id: extra filter; skipped when falsy.
        :return: ``{step_id: record}``; ``{}`` when nothing matches.
        """
        query = {self.key_step_id: {"$in": step_id_list}}
        if task_id:
            query[self.key_task_id] = task_id
        return self.fetch_data_from_query(query)

    def insert_date(self, _date: date, step_id: str, data) -> bool:
        """
        Insert a new record with ``_date`` stored in the top-level ``key_date`` field.

        No existence check is made, so repeated calls insert separate records
        (unless a unique index exists — not visible here).

        :return: always ``True``.
        """
        self.insert_one({self.key_date: _date, self.key_step_id: step_id, self.key_data: data})
        return True

    def update_data_with_date(self, _date: str, step_id: str, data, step_category: str, entity_name: str,
                              task_id: str) -> bool:
        """
        Upsert ``data`` for a step / task / category / entity combination.

        For ``StepCategories.PERIODIC`` and ``StepCategories.TRIGGER_BASED``
        the filter also includes ``"<key_property>.<key_date>"``, so a
        different date matches (or upserts) a different record. For all other
        categories ``_date`` is ignored.

        The update document ``{key_data: data}`` is passed as-is to
        ``update_one``. NOTE(review): presumably the base class wraps it in
        ``$set`` — confirm.

        :return: always ``True``.
        """
        query = {}
        if step_category in (StepCategories.PERIODIC, StepCategories.TRIGGER_BASED):
            # Keep the date as the first filter field: with upsert=True MongoDB
            # seeds a newly inserted document from the filter's equality fields.
            query[f"{self.key_property}.{self.key_date}"] = _date
        query.update({
            self.key_step_id: step_id,
            self.key_task_id: task_id,
            f"{self.key_property}.{self.key_step_category}": step_category,
            self.key_entity_name: entity_name,
        })
        self.update_one(query, {self.key_data: data}, upsert=True)
        return True

    def fetch_data_from_query(self, query) -> dict:
        """
        Run ``query`` and index the matching records by step id.

        :return: ``{step_id: record}``; when several records share a step id
            the last one yielded by ``find`` wins. ``{}`` when nothing matches.
        """
        records = self.find(query)
        if not records:
            return dict()
        # Index by the same field the filters match on, not a hard-coded name.
        return {record.get(self.key_step_id): record for record in records}

    def find_data_from_query(self, query, sort_json=None, find_one=True):
        """
        Run ``query`` and return the first match or all matches.

        :param query: Mongo filter document.
        :param sort_json: optional mapping of field -> direction. Its items are
            passed, in order, as the ``sort`` argument of ``find``.
        :param find_one: when true, return only the first matching record.
        :return: with ``find_one`` true, the first record, or ``{}`` when none
            match. Otherwise a list of records, empty when none match.
        """
        find_kwargs = {"sort": list(sort_json.items())} if sort_json else {}
        records = self.find(query, **find_kwargs)
        if find_one:
            # Take just the first document instead of materialising every match.
            first = next(iter(records), None)
            return first if first else dict()
        return list(records)

    def update_data_for_trigger_steps(self, _date: str, step_id: str, data, step_category: str, entity_name: str,
                                      event_id: str) -> bool:
        """
        Upsert ``data`` for the record identified by date, step id, event id,
        category and entity.

        Date, event id and category are matched under ``key_property``. Unlike
        ``update_data_with_date``, this filter does not include a task id.

        :return: always ``True``.
        """
        query = {
            f"{self.key_property}.{self.key_date}": _date,
            self.key_step_id: step_id,
            f"{self.key_property}.{self.key_event_id}": event_id,
            f"{self.key_property}.{self.key_step_category}": step_category,
            self.key_entity_name: entity_name,
        }
        self.update_one(query, {self.key_data: data}, upsert=True)
        return True
