from scripts.db.graphdb.neo4j import Neo4jHandler
from scripts.db.models import NodePropertiesSchema, RelationShipMapper
from scripts.logging import logger
from scripts.schemas import GraphData, InputRequestSchema, NodeActionOptions, GetNodeInfo, ResponseModelSchema
from scripts.utils.graph_utility import GraphUtility


class GraphTraversal:
    """
    Ingests node/relationship data into the graph database and fetches
    connected-node data back out, delegating the actual database work to
    ``GraphUtility`` (save / fetch) and ``Neo4jHandler`` (delete).
    """

    def __init__(self, db):
        """
        Build the handler around a database handle.

        Args:
            db: Database handle. It is passed to ``GraphUtility`` here and to
                ``Neo4jHandler`` whenever a node is deleted.
        """
        self.db = db
        self.graph_util = GraphUtility(db=db)

    def ingest_data_handler(self, graph_data: GraphData):
        """
        Apply every node action in ``graph_data``, then link the saved nodes.

        Each ``(key, node_obj)`` pair of ``graph_data.__root__`` is processed in
        order: nodes whose ``action`` is ``NodeActionOptions.delete`` are
        deleted; all others are saved and their declared edges are collected
        under ``key``. After every node is handled, each collected edge is
        persisted only if both its end node (the declaring ``key``) and its
        start node (``edge.bind_to``) were saved during this same call.

        Errors are logged and NOT re-raised, so this method always returns None.

        Args:
            graph_data: Mapping (via ``__root__``) of node key -> node payload.
        """
        try:
            nodes_map = {}
            edges_info = {}
            for node_type, node_obj in graph_data.__root__.items():
                if node_obj.action == NodeActionOptions.delete:
                    self.perform_delete_node_action(graph_data=node_obj)
                else:
                    nodes_map, edges_info = self.perform_save_node_action(node_obj=node_obj, nodes_map=nodes_map,
                                                                          edges_info=edges_info, node_type=node_type)
            self._save_relationships(nodes_map=nodes_map, edges_info=edges_info)
        except Exception as e:
            # Deliberately best-effort: failures are logged, not propagated to the caller.
            logger.exception(f'Exception occurred while ingesting graph data {e.args}')

    def _save_relationships(self, nodes_map: dict, edges_info: dict):
        """
        Persist each collected edge whose start and end nodes are both present
        (truthy) in ``nodes_map``.

        Args:
            nodes_map: Node key -> saved node returned by ``save_single_node``.
            edges_info: End-node key -> list of edge requests declared by that node.
        """
        for des_edge, edges in edges_info.items():
            # The end node is the same for every edge in this group; look it up once.
            end_node = nodes_map.get(des_edge)
            for edge in edges:
                start_node = nodes_map.get(edge.bind_to)
                # Skip edges whose endpoints were not saved in this request.
                if not (end_node and start_node):
                    continue
                self.graph_util.save_relationship(
                    rel_data=RelationShipMapper(_start_node_id=start_node._id,
                                                _end_node_id=end_node._id,
                                                _type=edge.rel_name,
                                                project_id=end_node.project_id,
                                                ), new_rel_name=edge.new_rel_name)

    def perform_save_node_action(self, node_obj: InputRequestSchema, nodes_map: dict, edges_info: dict, node_type):
        """
        Save a single node and record it, plus any edges it declares, for later linking.

        The node's labels come from splitting ``node_obj.node_type`` on commas.

        Args:
            node_obj: Node payload to save.
            nodes_map: Accumulator of node key -> saved node; mutated in place.
            edges_info: Accumulator of node key -> edge requests; mutated in place.
            node_type: Key under which the node and its edges are recorded.

        Returns:
            The (mutated) ``(nodes_map, edges_info)`` pair.

        Raises:
            Any exception from schema construction or the save call, after logging it.
        """
        try:
            res = self.graph_util.save_single_node(
                node_data=NodePropertiesSchema(_labels=node_obj.node_type.split(","), **node_obj.dict()))
            nodes_map[node_type] = res
            if node_obj.edges and isinstance(node_obj.edges, list):
                if not edges_info.get(node_type):
                    edges_info[node_type] = []
                edges_info[node_type].extend(node_obj.edges)
            return nodes_map, edges_info
        except Exception as e:
            logger.exception(f'Exception occurred while updating node info {e.args}')
            raise

    def perform_delete_node_action(self, graph_data: InputRequestSchema):
        """
        Delete the node described by ``graph_data`` through ``Neo4jHandler``.

        Args:
            graph_data: Node payload identifying the node to delete.

        Returns:
            The confirmation string "Node Deleted Successfully".

        Raises:
            Any exception from the delete call, after logging it.
        """
        try:
            neo4j_handler = Neo4jHandler(self.db)
            neo4j_handler.delete_node_by_id(node=NodePropertiesSchema(**graph_data.dict()))
            return "Node Deleted Successfully"
        except Exception as e:
            logger.exception(f'Exception Occurred while deleting node info {e.args}')
            raise

    def fetch_node_data(self, graph_request: GetNodeInfo):
        """
        Fetch connected-node data and reshape it into UI-friendly nodes and links.

        Items under a result key equal to "r" (case-insensitive) are treated as
        relationships. They are appended to ``series_data["links"]`` with
        ``_start_node_id``/``_end_node_id``/``_type`` renamed to
        ``source``/``target``/``linkName``. Every other item is appended to
        ``series_data["nodes"]`` (with empty ``x``/``y`` keys added), but only
        if it has a truthy ``node_id`` and an ``_id`` not already emitted.
        ``None``-valued fields are omitted from all output dicts.

        Args:
            graph_request: Query passed through to ``get_connecting_nodes_info``.

        Returns:
            ResponseModelSchema whose ``series_data`` has ``nodes`` and ``links`` lists.

        Raises:
            Any exception from the fetch or reshaping, after logging it.
        """
        return_data = ResponseModelSchema(series_data=dict(nodes=[], links=[]))
        try:
            existing_data = self.graph_util.get_connecting_nodes_info(input_data=graph_request)
            # A set keeps the duplicate check O(1) per item (a list made it O(n) each).
            seen_unique_ids = set()
            for k, v in existing_data.items():
                is_relationship = k.lower() == "r"
                for _item in v:
                    ui_dict = _item.dict(exclude_none=True)
                    if is_relationship:
                        ui_dict["source"] = ui_dict.pop("_start_node_id")
                        ui_dict["target"] = ui_dict.pop("_end_node_id")
                        ui_dict["linkName"] = ui_dict.pop("_type")
                        return_data.series_data["links"].append(ui_dict)
                        continue
                    # exclude_none only drops None values, so .get() here matches a plain .dict().
                    node_id = ui_dict.get("node_id")
                    unique_id = ui_dict.get("_id")
                    if not node_id or unique_id in seen_unique_ids:
                        continue
                    seen_unique_ids.add(unique_id)
                    ui_dict.update({"x": '', "y": ''})
                    return_data.series_data["nodes"].append(ui_dict)
            return return_data
        except Exception as e:
            logger.exception(f"Exception Occurred while fetching data from node - {e.args}")
            raise
