from scripts.db.mongo.ilens_configuration.collections.a_batch_info import \
    ABatchInfoCollection
from scripts.db.mongo.ilens_configuration.collections.b_batch_parameter_data import \
    BBatchParameterDataCollection
from scripts.db.mongo.ilens_configuration.collections.c_raw_material_data import \
    CRawMaterialDataCollection
from scripts.db.mongo.ilens_configuration.collections.d_non_compliance_data import \
    DNonComplianceDataCollection
from scripts.logging import logger
from scripts.schemas.ut_schema import GetSample, GetTableData
from scripts.constants import BatchAppHeaderContentConstants
from py2neo import Graph, Node, Relationship

# Module-level py2neo Graph handle, created once at import time and used by
# UTHandler.get_batch_info_neo_4j to run Cypher queries.
# NOTE(review): the bolt URI and the credentials are hard-coded here; consider
# loading them from configuration/environment instead of source.
graph = Graph("bolt://192.168.0.220:7687", auth=("neo4j", "root"))


def make_response_for_batch_info(data, parent_batch_no):
    """Convert batch records into a ``{"edges": [...], "nodes": [...]}`` graph.

    Args:
        data: Iterable of dict records. The keys read are ``id``,
            ``Batch_No``, ``Level`` and ``Parent`` (an iterable of parent
            batch numbers).
        parent_batch_no: Batch number of the requested batch; the record
            with this ``Batch_No`` is always emitted as a node.

    Returns:
        dict with ``nodes`` (``{"id", "label", "title"}`` entries, at most one
        per id, where ``title`` is the record itself) and ``edges``
        (``{"to", "from"}`` entries, at most one per parent/child pair).

    Side effects:
        For every record emitted through one of its parents, ``Parent`` is
        rewritten in place from a list to a comma-separated string.

    Raises:
        Exception: any error is logged and re-raised unchanged.
    """
    try:
        # Lookup tables: batch number -> record id, and ids of Level -1 records.
        batch_no_id_dict = {}
        parent_ids = set()
        for record in data:
            batch_no_id_dict[record.get("Batch_No", "")] = record.get("id", "")
            if record.get("Level") == -1:
                parent_ids.add(record.get("id", ""))

        nodes_list = []
        edges_list = []
        seen_node_ids = set()
        seen_edges = set()

        def add_node(record):
            # Emit each id once: a record with several parents (or the
            # requested batch that also has a parent) used to be appended
            # repeatedly, producing duplicate node ids in the response.
            node_id = record.get("id", "")
            if node_id in seen_node_ids:
                return
            seen_node_ids.add(node_id)
            nodes_list.append(
                {"id": node_id,
                 "label": record.get("Batch_No", ""),
                 "title": record}
            )

        for record in data:
            node_id = record.get("id", "")
            level = record.get("Level")
            if record.get("Batch_No") == parent_batch_no:
                add_node(record)
            if level == 0:
                # Level 0 records are never linked through their parents.
                continue

            # Snapshot the parents before "Parent" is overwritten below, so
            # the joined string is always built from the original list and
            # never from an already-joined string.
            parents = list(record.get("Parent") or [])
            for each_parent in parents:
                parent_id = batch_no_id_dict.get(each_parent, "")
                if not parent_id:
                    # Parent is not part of this result set (or has a falsy id).
                    continue
                if level == -2 and parent_id not in parent_ids:
                    # Level -2 records are kept only under a Level -1 parent.
                    continue
                record["Parent"] = ",".join(parents)
                add_node(record)
                if node_id == parent_id:
                    # Never emit a self-loop.
                    continue
                edge_key = (parent_id, node_id)
                if edge_key in seen_edges:
                    continue
                seen_edges.add(edge_key)
                edges_list.append({"to": node_id, "from": parent_id})

        return {"edges": edges_list, "nodes": nodes_list}
    except Exception as e:
        logger.error(f"Error while getting sample data - {e}")
        raise


def create_query_for_batch_details(direction, level, batch_no):
    """Build the Cypher text that walks ``OCCURRED`` relationships from a batch.

    Args:
        direction: ``"upstream"`` selects the pattern whose arrow points at
            ``source``; any other value selects the pattern whose arrow
            points away from it.
        level: Upper bound of the variable-length hop range (``*0..level``).
        batch_no: Value compared against ``source.Batch_No``.

    Returns:
        The Cypher query as a string.

    NOTE(review): ``level`` and ``batch_no`` are formatted straight into the
    query text rather than passed as query parameters, so a string batch
    number is neither quoted nor escaped here - confirm what callers pass.
    """
    try:
        if direction == "upstream":
            match_pattern = f"(source)<-[relationship:OCCURRED*0..{level}]-(target)"
        else:
            match_pattern = f"(source)-[relationship:OCCURRED*0..{level}]->(target)"
        return f"MATCH {match_pattern} WHERE source.Batch_No = {batch_no} return source, relationship,target;"
    except Exception as e:
        logger.error(f"Error while getting sample data - {e}")
        raise e


class UTHandler:
    """Serve batch lineage graphs and per-batch table data.

    Lineage comes either from the module-level Neo4j ``graph`` or from the
    ``a_batch_info`` collection; the table endpoints read the parameter,
    raw-material and non-compliance collections.
    """

    # (display column, stored field) pairs, in the order columns are emitted.
    _BATCH_PARAMETER_COLUMNS = (
        ("Batch No", "Batch_No"),
        ("Product Name", "Product_Name"),
        ("Mfg. Date", "Mfg_Date"),
        ("Site", "Site"),
        ("Plant", "Plant"),
        ("Mfg. Stage", "Mfg_Stage"),
        ("Unit Operation", "Unit_Operation"),
        ("Parameter", "Parameter"),
        ("Value", "Value"),
        ("UOM", "UOM"),
        ("Source", "Source"),
    )
    _RAW_MATERIAL_COLUMNS = (
        ("Batch No", "Batch_No"),
        ("Material No", "Material_No"),
        ("Material Description", "Material_Description"),
        ("Mfg. Date", "Mfg_Date"),
        ("Expiration Date", "Expiration_Date"),
        ("Vendor", "Vendor"),
        ("Parameter", "Parameter"),
        ("Value", "Value"),
        ("UOM", "UOM"),
    )
    _NON_COMPLIANCE_COLUMNS = (
        ("Batch No", "Batch_No"),
        ("NC ID", "NC_ID"),
        ("Classification", "Classification"),
        ("Description", "Description"),
        ("Status", "Status"),
        ("Observed Date", "Observed_Date"),
        ("Closed Date", "Closed_Date"),
        ("Due Date", "Due_Date"),
    )

    def __init__(self):
        self.a_batch_info = ABatchInfoCollection()
        self.b_batch_parameter_data = BBatchParameterDataCollection()
        self.c_raw_material_data = CRawMaterialDataCollection()
        self.d_non_compliance_data = DNonComplianceDataCollection()

    @staticmethod
    def _project_rows(documents, columns):
        """Return one display-keyed dict per document, following ``columns``."""
        return [
            {label: document.get(field) for label, field in columns}
            for document in documents
        ]

    @staticmethod
    def _node_payload(node):
        """Rename the dotted ``Mfg.`` properties on ``node`` and build its entry.

        The rename is applied to the in-memory node object only. A node
        missing either property raises ``KeyError``.
        """
        node["Mfg_Date"] = node.pop("Mfg._Date")
        node["Mfg_Stage"] = node.pop("Mfg._Stage")
        return {
            "id": node.identity,
            "label": node.get("Batch_No", ""),
            "title": dict(node),
        }

    async def get_batch_info_neo_4j(self, request_data: GetSample):
        """Return the lineage of a batch from Neo4j as ``{"edges", "nodes"}``.

        Reads ``Batch_No``, ``Batch_Level`` and ``Traverse`` from
        ``request_data``. Every node and every (start, end) relationship pair
        appears once in the response. Errors are logged and re-raised.
        """
        try:
            cypher_query = create_query_for_batch_details(
                batch_no=request_data.Batch_No,
                level=request_data.Batch_Level,
                direction=request_data.Traverse.lower().replace(" ", ""),
            )
            records = graph.run(cypher=cypher_query).data()

            nodes_list = []
            edges_list = []
            seen_node_ids = set()
            seen_edges = set()
            for record in records:
                # Source and target get the same treatment; sharing one code
                # path keeps the Mfg_Date / Mfg_Stage mapping from diverging
                # between the two (the target branch used to swap them).
                for node in (record['source'], record['target']):
                    if node.identity in seen_node_ids:
                        continue
                    seen_node_ids.add(node.identity)
                    nodes_list.append(self._node_payload(node))

                for each in record['relationship']:
                    edge_key = (each.start_node.identity, each.end_node.identity)
                    if edge_key in seen_edges:
                        continue
                    seen_edges.add(edge_key)
                    edges_list.append({
                        "from": each.start_node.identity,
                        "to": each.end_node.identity,
                    })

            return {"edges": edges_list, "nodes": nodes_list}
        except Exception as e:
            logger.error(f"Error while getting sample data - {e}")
            raise

    async def get_batch_info(self, request_data: GetSample):
        """Return the lineage of a batch from ``a_batch_info`` as a graph.

        Fetches the requested batch together with the records that list it
        as ``Parent`` (at Level -1 when ``Batch_Level`` > 0, otherwise at
        Level 0). For a downstream traversal with ``Batch_Level`` >= 2, all
        records at levels ``-Batch_Level .. -2`` are added as well. The
        request object is left unmodified. Errors are logged and re-raised.
        """
        try:
            batch_no = request_data.Batch_No
            # Work on a local copy: negating request_data.Batch_Level in
            # place would leak a negative level to whoever reuses the request.
            batch_level = request_data.Batch_Level
            data = list(self.a_batch_info.find(
                {
                    "$or": [{"Parent": batch_no, "Level": -1 if batch_level > 0 else 0},
                            {"Batch_No": batch_no}]
                }
            ))
            if batch_level >= 2 and request_data.Traverse.replace(" ", "").lower() == "downstream":
                levels = list(range(-batch_level, -1))
                data.extend(self.a_batch_info.find({"Level": {"$in": levels}}))

            return make_response_for_batch_info(data, batch_no)
        except Exception as e:
            logger.error(f"Error while getting sample data - {e}")
            raise

    async def get_batch_parameter_table_data(self, request_data: GetTableData):
        """Return the batch-parameter rows for ``request_data.Batch_No``.

        The rows are placed under ``message.tableData.bodyContent``; the
        response also carries a top-level ``tableData`` with the same header
        and an empty body. Errors are logged and re-raised.
        """
        try:
            documents = self.b_batch_parameter_data.find({"Batch_No": request_data.Batch_No})
            tab_list = self._project_rows(documents, self._BATCH_PARAMETER_COLUMNS)
            return {
                "message": {
                    "tableData": {
                        "bodyContent": tab_list,
                        "headerContent": BatchAppHeaderContentConstants.B_BATCH_PARAMETER_HEADER,
                    }
                },
                "tableData": {
                    "bodyContent": [],
                    "headerContent": BatchAppHeaderContentConstants.B_BATCH_PARAMETER_HEADER,
                },
            }
        except Exception as e:
            logger.error(f"Error while getting sample data - {e}")
            raise

    async def get_raw_material_table_data(self, request_data: GetTableData):
        """Return the raw-material rows for ``request_data.Batch_No``.

        The rows are placed under ``message.tableData.bodyContent``. Errors
        are logged and re-raised.
        """
        try:
            documents = self.c_raw_material_data.find({"Batch_No": request_data.Batch_No})
            tab_list = self._project_rows(documents, self._RAW_MATERIAL_COLUMNS)
            return {
                "message": {
                    "tableData": {
                        "bodyContent": tab_list,
                        "headerContent": BatchAppHeaderContentConstants.C_RAW_MATERIAL_DATA,
                    }
                }
            }
        except Exception as e:
            logger.error(f"Error while getting sample data - {e}")
            raise

    async def get_non_compliance_table_data(self, request_data: GetTableData):
        """Return the non-compliance rows for ``request_data.Batch_No``.

        The rows are placed under ``message.tableData.bodyContent``. Errors
        are logged and re-raised.
        """
        try:
            documents = self.d_non_compliance_data.find({"Batch_No": request_data.Batch_No})
            tab_list = self._project_rows(documents, self._NON_COMPLIANCE_COLUMNS)
            return {
                "message": {
                    "tableData": {
                        "bodyContent": tab_list,
                        "headerContent": BatchAppHeaderContentConstants.D_NON_COMPLIANCE_DATA,
                    }
                }
            }
        except Exception as e:
            logger.error(f"Error while getting sample data - {e}")
            raise
