import re

from scripts.constants.db_connections import mongo_client
from scripts.db.mongo.ilens_configuration.aggregations.tag_hierarchy import TagHierarchyAggregate
from scripts.db.mongo.ilens_configuration.collections.tag_hierarchy import TagHierarchy
from scripts.db.mongo.schema.tag_hierarchy import GetTagsLists, OutputTagsList
from scripts.logging import logger


class TagHierarchyHandler:
    """Look up tag lists in the tag-hierarchy collection by hierarchy id(s)."""

    def __init__(self, project_id=None):
        """Create the tag-hierarchy collection wrapper scoped to ``project_id``."""
        self.tag_hierarchy_conn = TagHierarchy(mongo_client=mongo_client, project_id=project_id)

    @staticmethod
    def _tag_pattern(hierarchy) -> str:
        """Return a regex fragment that matches ``<hierarchy>$tag`` literally.

        ``re.escape`` turns the ``$`` separator (and any regex metacharacters
        inside the hierarchy id) into literal characters.
        """
        return re.escape(f'{hierarchy}$tag')

    @staticmethod
    def _first_or_empty(results):
        """Return the first aggregation result, or ``{}`` when there is none."""
        return results[0] if results else {}

    def get_tags_list_by_hierarchy(self, input_data: GetTagsLists):
        """Fetch the tag list for a single hierarchy.

        Args:
            input_data: request carrying ``project_id`` and ``hierarchy``.

        Returns:
            The first document produced by the ``tag_aggregate`` pipeline, or
            ``{}`` if the aggregation returned nothing.

        Raises:
            Exception: any error from building or running the aggregation is
                logged and re-raised unchanged.
        """
        try:
            # NOTE(review): hierarchy_id is an escaped regex fragment; presumably
            # tag_aggregate applies it as a pattern — confirm in the aggregation.
            aggregate_query = TagHierarchyAggregate.tag_aggregate(
                project_id=input_data.project_id,
                hierarchy_id=self._tag_pattern(input_data.hierarchy),
            )
            tags_list = self.tag_hierarchy_conn.get_tag_hierarchy_by_aggregate(query=aggregate_query)
            return self._first_or_empty(tags_list)
        except Exception as e:
            logger.exception(f"failed to fetch tags_list by hierarchy {e.args}")
            raise

    def get_output_tags_for_oee(self, input_data: OutputTagsList):
        """Fetch output tags for one or more hierarchies.

        The hierarchy filter is chosen by precedence:
        ``hierarchy_list`` (any of the listed hierarchies), then
        ``hierarchy_level`` (used as-is), then the single ``hierarchy``.
        If none of them is set, ``{}`` is returned without querying.

        Args:
            input_data: request carrying ``project_id`` and at least one of
                ``hierarchy_list``, ``hierarchy_level`` or ``hierarchy``.

        Returns:
            The first document produced by the
            ``tag_aggregate_by_hierarchy_list`` pipeline, or ``{}``.

        Raises:
            Exception: any error from building or running the aggregation is
                logged and re-raised unchanged.
        """
        try:
            if input_data.hierarchy_list:
                # Escape each hierarchy separately and join with an *unescaped*
                # "|" so the result is a regex alternation. Escaping the joined
                # string would also escape the "|" and match only the literal
                # text "a$tag|b$tag".
                hierarchy_str = "|".join(self._tag_pattern(each) for each in input_data.hierarchy_list)
            elif input_data.hierarchy_level:
                # NOTE(review): passed through unescaped, so presumably it is
                # already a regex pattern — confirm with callers.
                hierarchy_str = input_data.hierarchy_level
            elif input_data.hierarchy:
                hierarchy_str = self._tag_pattern(input_data.hierarchy)
            else:
                return {}
            aggregate_query = TagHierarchyAggregate.tag_aggregate_by_hierarchy_list(
                project_id=input_data.project_id,
                hierarchy_str=hierarchy_str,
            )
            tags_list = self.tag_hierarchy_conn.get_tag_hierarchy_by_aggregate(query=aggregate_query)
            return self._first_or_empty(tags_list)
        except Exception as e:
            logger.exception(f"failed to fetch output tags list by hierarchy {e.args}")
            raise
