from typing import Optional, Dict

from scripts.constants import DatabaseNames, CollectionNames
from scripts.constants.db_keys import StepTemplateKeys
from scripts.db.mongo.schema import MongoBaseSchema
from scripts.utils.mongo_util import MongoCollectionBaseClass


class StepTemplatesSchema(MongoBaseSchema):
    """Typed representation of one step-template record.

    ``StepTemplates.find_template`` and ``StepTemplates.find_by_id`` build
    instances of this class by unpacking a fetched record as keyword
    arguments.

    Most fields are declared ``Optional`` without an explicit default;
    whether they may be omitted at construction time is decided by
    ``MongoBaseSchema``, which is not visible in this file.
    """

    # Identifier of the template (presumably the value matched by
    # StepTemplateKeys.KEY_TEMPLATE_ID -- confirm against db_keys).
    template_id: Optional[str]
    # Name of the template (presumably the value matched by
    # StepTemplateKeys.KEY_TEMPLATE_NAME -- confirm against db_keys).
    template_name: Optional[str]
    # Presumably the id of the logbook this template is linked to -- TODO confirm.
    logbook_id: Optional[str]
    # Defaults to 1 when the record carries no value for this field.
    associated_workflow_version: Optional[int] = 1
    # Presumably free-text description of the template -- TODO confirm.
    description: Optional[str]
    # Arbitrary dict, defaulting to an empty dict.
    # NOTE(review): the default is a mutable literal; this assumes
    # MongoBaseSchema copies field defaults per instance (as pydantic models
    # do) -- confirm.
    meta: Optional[Dict] = {}
    # Presumably the location of the template file -- TODO confirm.
    file_path: Optional[str]


class StepTemplates(MongoCollectionBaseClass):
    """Query helper bound to the step-templates collection.

    The instance is wired to the collection ``CollectionNames.step_templates``
    inside the database ``DatabaseNames.ilens_assistant``.  All reads and
    writes go through the ``find_one`` / ``aggregate`` / ``update_one`` /
    ``delete_one`` methods available on ``self``.
    """

    def __init__(self, mongo_client, project_id=None):
        """Bind the helper to its database/collection and remember the project.

        Args:
            mongo_client: Client object forwarded unchanged to the parent
                initialiser.
            project_id: Stored on the instance as ``self.project_id``; none of
                the methods defined in this class read it.
        """
        super().__init__(
            mongo_client,
            database=DatabaseNames.ilens_assistant,
            collection=CollectionNames.step_templates,
        )
        self.project_id = project_id

    @property
    def key_template_id(self):
        """Field name used to match a template id (``KEY_TEMPLATE_ID``)."""
        return StepTemplateKeys.KEY_TEMPLATE_ID

    @property
    def key_template_name(self):
        """Field name used to match a template name (``KEY_TEMPLATE_NAME``)."""
        return StepTemplateKeys.KEY_TEMPLATE_NAME

    @property
    def key_project_id(self):
        """Field name used to match a project id (``KEY_PROJECT_ID``)."""
        return StepTemplateKeys.KEY_PROJECT_ID

    def get_template_data_by_aggregate(self, query: list):
        """Run the aggregation pipeline ``query`` and return the results as a list."""
        results = self.aggregate(pipelines=query)
        return list(results)

    def delete_template(self, template_id):
        """Delete one record whose template-id field equals ``template_id``.

        Returns:
            Whatever ``self.delete_one`` returns.
        """
        id_filter = {self.key_template_id: template_id}
        return self.delete_one(query=id_filter)

    def find_template(self, template_name: str):
        """Look a template up by name.

        Returns:
            A ``StepTemplatesSchema`` built from the matching record, or
            ``None`` when ``find_one`` yields nothing (or an empty record).
        """
        document = self.find_one({self.key_template_name: template_name})
        return StepTemplatesSchema(**document) if document else None

    def find_by_id(self, template_id: str):
        """Look a template up by id.

        Unlike ``find_template``, a miss does not produce ``None``: a
        ``StepTemplatesSchema`` constructed with no arguments is returned
        instead.

        Returns:
            A ``StepTemplatesSchema`` built from the matching record, or one
            built with no arguments when nothing was found.
        """
        document = self.find_one({self.key_template_id: template_id})
        return StepTemplatesSchema(**document) if document else StepTemplatesSchema()

    def update_template_data(self, template_id, project_id, data, upsert=False):
        """Update the record matching both ``template_id`` and ``project_id``.

        Args:
            template_id: Value to match on the template-id field.
            project_id: Value to match on the project-id field.
            data: Payload handed to ``self.update_one`` as ``data``.
            upsert: Forwarded to ``self.update_one`` as ``upsert``.

        Returns:
            Whatever ``self.update_one`` returns.
        """
        match_filter = {
            self.key_template_id: template_id,
            self.key_project_id: project_id,
        }
        return self.update_one(data=data, query=match_filter, upsert=upsert)

    def get_steps_data_data_by_aggregate(self, query: list):
        """Run the aggregation pipeline ``query`` and return the results as a list.

        The body performs the same call as ``get_template_data_by_aggregate``.
        """
        results = self.aggregate(pipelines=query)
        return list(results)
