from typing import Optional, Dict, List

from scripts.constants.app_constants import DatabaseNames, CollectionNames, TaskInstanceDataKeys
from scripts.db.mongo.schema import MongoBaseSchema
from scripts.utils.mongo_util import MongoCollectionBaseClass


class TaskInstanceDataSchema(MongoBaseSchema):
    """
    Shape of one document in the task instance data collection.

    Documents fetched from the collection are turned into this model, and it
    is the common format for datastore reads and general responses.
    """
    # Identifiers tying the record to its task, stage and step.
    task_id: Optional[str]
    stage_id: Optional[str]
    step_id: Optional[str]
    # Payload captured for the step; defaults to an empty dict.
    step_data: Optional[Dict] = {}
    project_id: Optional[str]
    # Free-form remarks attached to the record; defaults to an empty list.
    remarks: Optional[List] = []
    # Boolean status flag; defaults to False.
    status: Optional[bool] = False


class TaskInstanceData(MongoCollectionBaseClass):
    """
    Data-access helper for the task instance data collection
    (``DatabaseNames.ilens_assistant`` / ``CollectionNames.task_instance_data``).

    Documents are addressed by ``task_id``, ``stage_id`` and ``step_id``; the
    step payload is stored under ``step_data``.

    Readers that return several documents always hand back plain lists/dicts
    (never a live cursor) and treat a ``None`` or empty ``find`` result as
    "no documents".
    """

    def __init__(self, mongo_client, project_id=None):
        super().__init__(mongo_client, database=DatabaseNames.ilens_assistant,
                         collection=CollectionNames.task_instance_data)
        # NOTE(review): kept on the instance, but no query in this class
        # filters on project_id.
        self.project_id = project_id

    # Field-name accessors backed by TaskInstanceDataKeys.
    @property
    def key_stage_id(self):
        return TaskInstanceDataKeys.KEY_STAGE_ID

    @property
    def key_step_id(self):
        return TaskInstanceDataKeys.KEY_STEP_ID

    @property
    def key_task_id(self):
        return TaskInstanceDataKeys.KEY_TASK_ID

    @property
    def key_status(self):
        return TaskInstanceDataKeys.KEY_STATUS

    def find_by_id(self, stage_id: str) -> Optional[TaskInstanceDataSchema]:
        """
        Return the first document whose ``stage_id`` matches, or None.

        Despite the name, the lookup is by stage id, not by Mongo ``_id``.

        :param stage_id: stage identifier to match.
        :return: the document parsed into TaskInstanceDataSchema, or None.
        """
        record = self.find_one({"stage_id": stage_id})
        if not record:
            return None
        return TaskInstanceDataSchema(**record)

    def find_by_task_id_step_id(self, task_id: str, step_id: str) -> Optional[TaskInstanceDataSchema]:
        """
        Return the document for the given task/step pair, or None.

        :param task_id: task identifier to match.
        :param step_id: step identifier to match.
        :return: the document parsed into TaskInstanceDataSchema, or None.
        """
        record = self.find_one({"task_id": task_id, "step_id": step_id})
        if not record:
            return None
        return TaskInstanceDataSchema(**record)

    def update_by_task_step_id(self, task_id: str, step_id: str, data: dict):
        """
        Call ``update_one`` on the task/step document with ``{"step_data": data}``.

        NOTE(review): the trailing ``True`` is presumably ``upsert`` -- confirm in
        MongoCollectionBaseClass. If so, an inserted document carries only
        ``task_id``/``step_id`` (no ``stage_id``).
        """
        query = {"task_id": task_id, "step_id": step_id}
        self.update_one(query, {"step_data": data}, True)

    def update_stage(self, stage_id, data):
        """Call ``update_one`` on the ``stage_id`` document with ``{"step_data": data}``."""
        self.update_one({"stage_id": stage_id}, {"step_data": data}, True)

    def update_many_stages(self, stage_list: list, data: dict):
        """
        Call ``update_many`` with ``data`` for every document whose ``stage_id``
        is in ``stage_list``.

        NOTE(review): if the trailing ``True`` is ``upsert``, MongoDB builds an
        upserted document only from equality conditions; a multi-value ``$in``
        supplies none, so a call matching nothing would insert one document
        without ``stage_id``. Confirm the flag is intended.
        """
        query = {"stage_id": {"$in": stage_list}}
        self.update_many(query, data, True)

    def update_stage_data(self, stage_id, data: dict):
        """Call ``update_one`` on the ``stage_id`` document with ``data`` as given."""
        self.update_one({"stage_id": stage_id}, data, True)

    def find_data_by_task_id(self, task_id) -> List[dict]:
        """
        Summarise every document of ``task_id``.

        :param task_id: task identifier to match.
        :return: ``[{"stage_id": ..., "data": <step_data>}, ...]``; an empty
            list when nothing matches.
        """
        # ``or []`` covers a None result from find(). ``.get`` (as in find_many)
        # keeps documents lacking stage_id/step_data -- e.g. ones upserted by
        # update_by_task_step_id -- from raising KeyError.
        records = self.find({"task_id": task_id}) or []
        return [dict(stage_id=record.get("stage_id"), data=record.get("step_data")) for record in records]

    def get_stage_map_steps(self, stages):
        """
        Look up the step of each stage in ``stages``.

        :param stages: stage ids to match (used in a ``$in`` filter).
        :return: ``(stages_map, steps)``. ``stages_map`` is
            ``{stage_id: step_id}``, where later documents win on a duplicate
            stage id. ``steps`` lists every matched ``step_id`` in result order.
            Returns ``({}, [])`` when nothing matches.
        """
        stages_map = dict()
        steps = list()
        for stage in self.find({"stage_id": {"$in": stages}}) or []:
            stages_map[stage.get("stage_id")] = stage.get("step_id")
            steps.append(stage.get("step_id"))
        return stages_map, steps

    def find_many(self, stages_list):
        """
        Collect payload and step id for each stage in ``stages_list``.

        :param stages_list: stage ids to match (used in a ``$in`` filter).
        :return: ``(stage_data, step_data)``. ``stage_data`` is
            ``{stage_id: step_data}`` and ``step_data`` is
            ``{stage_id: step_id}``. Returns ``({}, {})`` when nothing matches.
        """
        stage_data = dict()
        step_data = dict()
        for stage in self.find({"stage_id": {"$in": stages_list}}) or []:
            stage_data[stage.get("stage_id")] = stage.get("step_data")
            step_data[stage.get("stage_id")] = stage.get("step_id")
        return stage_data, step_data

    def find_data_for_multiple_stages(self, stages_list: list) -> List[dict]:
        """
        Return all documents whose ``stage_id`` is in ``stages_list``, in
        ascending ``_id`` order; an empty list when nothing matches.
        """
        records = self.find({"stage_id": {"$in": stages_list}}, sort=[('_id', 1)])
        return list(records or [])

    def find_data_with_task_id_step_list(self, task_id, steps_list: list) -> List[dict]:
        """
        Return all documents of ``task_id`` whose step id is in ``steps_list``;
        an empty list when nothing matches.
        """
        query = {self.key_task_id: task_id, self.key_step_id: {'$in': steps_list}}
        # ``or []`` guards against find() returning None, matching the siblings.
        return list(self.find(query) or [])

    def find_all_data_by_task_id(self, task_id) -> List[dict]:
        """
        Return every document of ``task_id`` as a list; an empty list when
        nothing matches.
        """
        # Materialised so callers get a re-iterable list like the other
        # readers. If find() returns a pymongo Cursor, it is single-use and
        # always truthy, so a truthiness check cannot detect "no results".
        return list(self.find({"task_id": task_id}) or [])
