"""
Mongo Utility
Reference: Pymongo Documentation
"""
import base64
import copy
import hashlib
import json
import os
import sys
from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime
from operator import itemgetter
from typing import Dict, List, Optional
from uuid import UUID

# from Crypto import Random
# from Crypto.Cipher import AES
from Cryptodome import Random
from Cryptodome.Cipher import AES
from pymongo import MongoClient

from scripts.constants import Conf, Constants
from scripts.exceptions import Exceptions
from scripts.exceptions.module_exceptions import MongoException, MongoUnknownDatatypeException, \
    MongoConnectionException, MongoFindException
from scripts.logging.logging import logger
from scripts.utils.common_utils import CommonUtils

# Types whose values are not encrypted; encrypt_dict_data moves them into the document's
# "encryption_salt" in plain text (json.dumps cannot serialise them).
exclude_encryption_datatypes = (datetime, UUID,)
try:
    file_name = Conf.MONGO_ENCRYPTION_FILE_PATH
    if not os.path.exists(file_name):
        # No configuration file: no collection is encrypted.
        encrypt_collection_dict = {}
    else:
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(file_name, encoding="utf-8") as f:
            mongo_encryption_constants_data = json.load(f)
        if "encrypt_collection_dict" in mongo_encryption_constants_data:
            encrypt_collection_dict = mongo_encryption_constants_data["encrypt_collection_dict"]
        else:
            encrypt_collection_dict = {}
except Exception as es:
    # Best effort: a broken/missing configuration disables encryption instead of failing import.
    encrypt_collection_dict = {}
    logger.exception(" Unable to fetch mongo encryption constants:" + str(es))


class AESCipher(object):
    """
    AES-CBC encryption of text with a random IV and PKCS#7-style padding.

    The AES key is the SHA-256 digest of the given passphrase. ``encrypt`` returns
    ``base64(iv + ciphertext)`` as bytes; ``decrypt`` reverses it and returns a UTF-8 str.

    NOTE(review): CBC without a MAC provides confidentiality only (no integrity), and a single
    SHA-256 of a passphrase is a weak key derivation. Both are kept unchanged so previously
    stored ciphertexts remain decryptable.
    """

    def __init__(self, key):
        """
        :param key: Passphrase (str); hashed with SHA-256 to obtain a 32-byte AES key.
        """
        self.bs = AES.block_size
        self.key = hashlib.sha256(key.encode()).digest()

    def encrypt(self, raw):
        """
        Encrypt ``raw`` with a fresh random IV.
        :param raw: Text to encrypt (str, encoded as UTF-8) or bytes.
        :return: base64-encoded ``iv + ciphertext`` (bytes).
        """
        padded = self._pad(raw)
        iv = Random.new().read(AES.block_size)
        cipher = AES.new(self.key, AES.MODE_CBC, iv)
        return base64.b64encode(iv + cipher.encrypt(padded))

    def decrypt(self, enc):
        """
        Decrypt the output of ``encrypt``.
        :param enc: base64-encoded ``iv + ciphertext`` (bytes or str).
        :return: The decrypted text (str).
        :raises ValueError: if the padding is invalid (e.g. wrong key or corrupted data).
        """
        enc = base64.b64decode(enc)
        iv = enc[:AES.block_size]
        cipher = AES.new(self.key, AES.MODE_CBC, iv)
        return self._unpad(cipher.decrypt(enc[AES.block_size:])).decode('utf-8')

    def _pad(self, s):
        """
        Pad ``s`` to a multiple of the block size, counting BYTES rather than characters.

        Counting characters (as before) mis-sized the padding for non-ASCII text, so the
        encoded plaintext was not block aligned and AES rejected it. For ASCII input the
        result is byte-identical to the previous implementation.
        :param s: str (encoded as UTF-8) or bytes-like data.
        :return: The padded data (bytes).
        """
        if isinstance(s, str):
            data = s.encode()
        elif isinstance(s, (bytes, bytearray)):
            data = bytes(s)
        else:
            raise TypeError("Cannot pad data of type '{}'".format(type(s).__name__))
        pad_len = self.bs - len(data) % self.bs
        return data + bytes([pad_len]) * pad_len

    @staticmethod
    def _unpad(s):
        """
        Remove PKCS#7-style padding added by ``_pad``.
        :param s: Padded bytes (a str is also accepted, as before).
        :return: ``s`` without its padding.
        :raises ValueError: if ``s`` is empty or the padding is malformed.
        """
        if not s:
            raise ValueError("Cannot unpad empty data")
        is_binary = isinstance(s, (bytes, bytearray))
        pad_len = s[-1] if is_binary else ord(s[-1])
        if pad_len < 1 or pad_len > len(s):
            raise ValueError("Invalid padding")
        expected = bytes([pad_len]) * pad_len if is_binary else chr(pad_len) * pad_len
        if s[-pad_len:] != expected:
            raise ValueError("Invalid padding")
        return s[:-pad_len]


class MongoDataEncryption(object):
    """
    Field-level encryption/decryption of Mongo documents.

    Which collections and keys are encrypted comes from the module-level
    ``encrypt_collection_dict``. For a configured collection, ``Constants.key_encrypt_keys``
    lists the keys to encrypt (``'*'`` means every key). With ``'*'``,
    ``Constants.key_exclude_encryption`` lists the keys that stay in plain text.

    Each encrypted value is stored as ``{'d': <AES ciphertext of its JSON>, 't': <base64 of the
    original type name>}``. Values whose type is in ``exclude_encryption_datatypes`` cannot be
    JSON-serialised, so they are cut out of the value before encryption. They are stored in
    plain text in the document's ``encryption_salt`` as ``{'p': <base64 path>, 'v': <value>}``,
    grouped under ``dt_<index of the type>``, and put back on decryption. Paths look like
    ``"dict-<key>.list-<index>."``.
    """

    def __init__(self):
        self.aes_cipher = AESCipher(key=Constants.cipher_key['k'])

    def create_encrypted_string(self, payload):
        """JSON-serialise ``payload`` and encrypt it; returns base64 bytes."""
        return self.aes_cipher.encrypt(raw=json.dumps(payload))

    def create_decrypted_string(self, payload):
        """Decrypt ``payload`` (output of create_encrypted_string) and JSON-deserialise it."""
        return json.loads(self.aes_cipher.decrypt(enc=payload))

    def encrypt_data(self, json_data, collection_name):
        """
        Encrypt a document, or a list of documents, based on the collection's configuration.
        :param json_data: The data to be encrypted (a dict, or a list of dicts)
        :param collection_name: The collection where the document is stored
        :return: Encrypted document(s) based on product defined configuration. For collections
                 that are not configured, the input is returned as-is, except that a dict gets
                 ``Constants.product_encrypted`` set to False (in place).
        :raises MongoException: on any failure, including an unsupported input type
        """
        # TODO: Automatically add an unsupported data type to the salt.
        try:
            if collection_name not in encrypt_collection_dict:
                logger.debug("Given data is not a part of the Mongo encryption setup. Skipping encryption")
                if isinstance(json_data, dict):
                    json_data[Constants.product_encrypted] = False
                return json_data
            if isinstance(json_data, list):
                # Encrypt every document of the INPUT list.
                return [self.encrypt_dict_data(doc=data, collection_name=collection_name) for data in json_data]
            if isinstance(json_data, dict):
                return self.encrypt_dict_data(doc=json_data, collection_name=collection_name)
            raise MongoUnknownDatatypeException("Unsupported datatype '{}' is being inserted to mongodb.".
                                                format(type(json_data)))
        except MongoException as e:
            raise MongoException(str(e)) from e
        except Exception as e:
            raise MongoException("Server faced a problem when encrypting the data --> {}".format(str(e))) from e

    def encrypt_dict_data(self, doc, collection_name):
        """
        Encrypt the keys of ``doc`` that are marked for encryption. ``doc`` itself is not modified.

        For every encrypted key, values of the types in ``exclude_encryption_datatypes`` are cut
        out of (a copy of) the value and recorded in ``encryption_salt``. The remainder is then
        encrypted. If the value itself is of an excluded type, the key is omitted and restored
        from the salt on decryption.

        Two keys are added to the result:
        - product_encrypted: True when at least one key was processed for encryption.
        - encryption_salt: ``{"dt_<i>": [{'p': <base64 path>, 'v': <value>}, ...]}``. It is
          dropped when nothing was encrypted.
        :param doc: The document considered for encryption
        :param collection_name: The collection where the document resides.
                                This is needed for the utility to read the encryption configuration
        :return: A new document with the relevant keys encrypted.
        :raises MongoException: on any failure
        """
        try:
            config = encrypt_collection_dict[collection_name]
            encrypt_keys = config[Constants.key_encrypt_keys]
            encrypt_all = '*' in encrypt_keys
            # The exclusion list is only part of the '*' configuration.
            excluded_keys = config[Constants.key_exclude_encryption] if encrypt_all else ()
            salt_types = [(index, datatype) for index, datatype in enumerate(exclude_encryption_datatypes)
                          if datatype not in (None, '')]
            search_types = tuple(datatype for _, datatype in salt_types)

            salt = {"dt_{}".format(index): [] for index, _ in salt_types}
            encrypted_data = {"encryption_salt": salt}
            is_product_encrypted = False
            for key, value in doc.items():
                should_encrypt = key not in excluded_keys if encrypt_all else key in encrypt_keys
                if not should_encrypt:
                    encrypted_data[key] = value
                    continue

                holder, found = self._strip_excluded_values(key, value, search_types)
                for entry in found:
                    group = next(index for index, datatype in salt_types if isinstance(entry['v'], datatype))
                    salt["dt_{}".format(group)].append({'p': base64.b64encode(entry['p'].encode()),
                                                        'v': entry['v']})
                if key in holder:
                    encrypted_data[key] = {'d': self.create_encrypted_string(payload=self.convert(holder[key])),
                                           't': base64.b64encode(type(value).__name__.encode())}
                # Also set when the whole value went to the salt, so decryption restores it.
                is_product_encrypted = True

            encrypted_data[Constants.product_encrypted] = is_product_encrypted
            if not is_product_encrypted:
                del encrypted_data["encryption_salt"]
            return encrypted_data
        except MongoException as e:
            raise MongoException(str(e)) from e
        except Exception as e:
            raise MongoException("Server faced a problem when encrypting the data --> {}".format(str(e))) from e

    def decrypt_data(self, dict_data, _collection_name):
        """
        Decrypt all the data that is encrypted. Values recorded in ``encryption_salt`` are put
        back at their original positions (list items are re-inserted at their original index).
        :param dict_data: The document that needs to be decrypted
        :param _collection_name: The collection to which the document belongs to
        :return: The decrypted data with the original data types intact. For collections that
                 are not configured, this is ``dict_data`` itself, modified in place.
        :raises MongoException: on any failure, or when a document flagged as encrypted has no salt
        """
        try:
            if _collection_name in encrypt_collection_dict:
                config = encrypt_collection_dict[_collection_name]
                encrypt_keys = config[Constants.key_encrypt_keys]
                if '*' in encrypt_keys:
                    excluded_keys = config[Constants.key_exclude_encryption]
                    decrypted_data = {key: value if key in excluded_keys else self._decrypt_value(value)
                                      for key, value in dict_data.items()}
                else:
                    decrypted_data = {key: self._decrypt_value(value) if key in encrypt_keys else value
                                      for key, value in dict_data.items()}
            else:
                decrypted_data = dict_data
            if dict_data.get(Constants.product_encrypted):
                if "encryption_salt" not in dict_data:
                    raise MongoException("Encrypted data does not have encryption salt! Unable to decrypt the data!")
                self._restore_salt(decrypted_data, dict_data["encryption_salt"])
            decrypted_data.pop(Constants.product_encrypted, None)
            decrypted_data.pop("encryption_salt", None)
            return decrypted_data
        except MongoException as e:
            raise MongoException(str(e)) from e
        except Exception as e:
            raise MongoException("Server faced a problem when decrypting the data: {}".format(str(e))) from e

    def decrypt_keys(self, encrypted_doc, collection_name, key_based=False):
        """
        Decrypt the encrypted keys of a document. The salt is not restored here.
        :param encrypted_doc: The document that needs to be decrypted
        :param collection_name: The collection to which the document belongs to.
        :param key_based: True: decrypt the keys listed in the collection's encrypt keys.
                          False ('*' configuration): decrypt every key except the excluded ones.
        :return: A new dict with the decrypted values
        :raises MongoException: on any failure
        """
        try:
            config = encrypt_collection_dict[collection_name]
            if key_based:
                encrypt_keys = config[Constants.key_encrypt_keys]
                return {key: self._decrypt_value(value) if key in encrypt_keys else value
                        for key, value in encrypted_doc.items()}
            # '*' mode: everything except the excluded keys was encrypted.
            excluded_keys = config[Constants.key_exclude_encryption]
            return {key: value if key in excluded_keys else self._decrypt_value(value)
                    for key, value in encrypted_doc.items()}
        except Exception as e:
            raise MongoException("Server faced a problem when decrypting the keys: {}".format(str(e))) from e

    def _decrypt_value(self, value):
        """Decrypt an ``{'d': ..., 't': ...}`` value; any other value is returned unchanged."""
        if not isinstance(value, dict) or 'd' not in value or 't' not in value:
            return value
        return self.decrypt_convert_proper_data_type(
            data=self.create_decrypted_string(payload=value['d']),
            # b64decode accepts both bytes (as written by encrypt_dict_data) and str.
            data_type=base64.b64decode(value['t']).decode()
        )

    def _restore_salt(self, target, salt):
        """
        Put every salt value back into ``target`` in place.
        Entries are applied in ascending path order, so list items are re-inserted at
        increasing indices and each index matches the original list.
        """
        entries = []
        for group in salt.values():
            for entry in group:
                path = base64.b64decode(entry['p']).decode()
                entries.append((self._parse_path(path), entry['v']))
        for steps, value in sorted(entries, key=itemgetter(0)):
            self._insert_at_path(target, steps, value)

    def _strip_excluded_values(self, key, value, excluded_types):
        """
        Return ``(holder, found)``.
        ``holder`` is ``{key: deep copy of value}`` with every instance of ``excluded_types``
        removed. ``found`` lists ``{'p': path, 'v': removed value}`` with paths relative to
        the document root. ``value`` itself is not modified.
        """
        holder = {key: copy.deepcopy(value)}
        found = self.search_datatype(holder, excluded_types)
        # Delete deepest / highest list indices first so the remaining paths stay valid.
        for entry in sorted(found, key=lambda item: self._parse_path(item['p']), reverse=True):
            self._delete_at_path(holder, self._parse_path(entry['p']))
        return holder, found

    @staticmethod
    def _parse_path(path):
        """
        Split a path such as ``"dict-a.list-2."`` into ``[("dict", "a"), ("list", 2)]``.
        Only the first '-' separates the kind from the key, so keys may contain '-'.
        Keys containing '.' cannot be represented in this path format.
        """
        steps = []
        for segment in path.split('.'):
            if not segment:
                continue
            kind, _, name = segment.partition('-')
            if kind == "dict":
                steps.append((kind, name))
            elif kind == "list":
                steps.append((kind, int(name)))
            else:
                raise ValueError("Unsupported path segment '{}' in '{}'".format(segment, path))
        return steps

    @staticmethod
    def _delete_at_path(container, steps):
        """Delete the value addressed by parsed ``steps`` from ``container`` (in place)."""
        parent = container
        for _, name in steps[:-1]:
            parent = parent[name]
        del parent[steps[-1][1]]

    @staticmethod
    def _insert_at_path(container, steps, value):
        """Insert ``value`` at parsed ``steps`` in ``container``: list index insert or dict assignment."""
        parent = container
        for _, name in steps[:-1]:
            parent = parent[name]
        kind, name = steps[-1]
        if kind == "list":
            parent.insert(name, value)
        else:
            parent[name] = value

    @staticmethod
    def decrypt_convert_proper_data_type(data, data_type):
        """
        Convert the de-serialized JSON object to the original data-type
        :param data: The de-serialized data
        :param data_type: The original data type name to which the data should be converted
        :return: The de-serialized data with its original data type.
        """
        if data_type == "int":
            return int(data)
        if data_type in ("list", "dict", "bool"):
            return data
        if data_type == "tuple" and isinstance(data, list):
            # JSON has no tuple type; the tuple was serialised as an array.
            return tuple(data)
        if isinstance(data, str):
            # NOTE(review): kept for compatibility with existing data, but this also strips
            # quotes that were genuinely part of the original string.
            return data.lstrip('"').rstrip('"')
        # float, None, ...: json.loads already restored the value.
        return data

    def convert(self, data):
        """
        Convert all byte-like objects into str (ASCII-decoded) before JSON serialisation.
        This supports conversion of nested dict, list and tuples (tuples stay tuples).
        :param data: The value to convert
        :return: The converted value
        """
        if isinstance(data, bytes):
            return data.decode('ascii')
        if isinstance(data, dict):
            return {self.convert(key): self.convert(value) for key, value in data.items()}
        if isinstance(data, tuple):
            return tuple(self.convert(item) for item in data)
        if isinstance(data, list):
            return [self.convert(item) for item in data]
        return data

    def search_datatype(self, _input, search_type, prev_datapoint_path=''):
        """
        Search for an excluded data type in a nested dictionary or list and record its path in the document.
        This does not support the exclusion of data of types dict and list.
        :param _input: The input data
        :param search_type: The data type (or tuple of types) to be searched for to exclude.
        :param prev_datapoint_path: The path of a value in a nested dict or nested list.
        :return: List of dictionaries, each with the true value ('v') and its path ('p').
        """
        try:
            if search_type is dict:
                raise Exception("Searching for datatype dict is not supported!")
            if search_type is list:
                raise Exception("Searching for datatype list is not supported!")
            output = []
            if isinstance(_input, dict):
                for dkey, dvalue in _input.items():
                    output.extend(self.search_datatype(
                        dvalue, search_type, "{}dict-{}.".format(prev_datapoint_path, dkey)))
            elif isinstance(_input, list):
                for index, item in enumerate(_input):
                    output.extend(self.search_datatype(
                        item, search_type, "{}list-{}.".format(prev_datapoint_path, index)))
            elif isinstance(_input, search_type):
                output.append(dict(p=prev_datapoint_path, v=_input))
            return output
        except Exception as e:
            raise Exception("Server faced a problem when searching for instances of datatype '{}' --> {}".
                            format(search_type, str(e))) from e

    @staticmethod
    def remove_value_of_datatype_command(remove_value, var_name):
        """
        Produce the Python statement that removes a value from a nested dict or list,
        given the path of that value in the source variable.
        Not used internally any more (values are removed by direct traversal); kept for callers.
        :param remove_value: The value (its path in 'p') to be removed.
        :param var_name: The variable the statement operates on.
        :return: The statement, or None if the path does not end in a dict/list segment.
        """
        segments = remove_value["p"].split('.')
        segments.remove('')
        last_kind, _, last_name = segments[-1].partition('-')
        if last_kind == "dict":
            prefix, suffix = "del ", ""
        elif last_kind == "list":
            prefix, suffix = "", ".pop({})".format(int(last_name))
            segments = segments[:-1]
        else:
            return None
        subscripts = ""
        for segment in segments:
            kind, _, name = segment.partition('-')
            if kind == "dict":
                # json.dumps yields a properly escaped string literal for the key.
                subscripts += "[{}]".format(json.dumps(name, ensure_ascii=False))
            elif kind == "list":
                subscripts += "[{}]".format(int(name))
        return prefix + var_name + subscripts + suffix

    @staticmethod
    def add_value_datatype_command(add_value, var_name, value):
        """
        Produce the Python statement that adds a value back to a nested dict or list,
        given the path of that value in the source variable.
        Not used internally any more (values are restored by direct traversal); kept for callers.
        :param add_value: The value (its path in 'p') to be added
        :param var_name: The source variable name the statement operates on.
        :param value: Source text of the expression producing the value.
        :return: The statement (dict assignment, or list append for a list path).
        """
        segments = add_value["p"].split('.')
        segments.remove('')
        subscripts = []
        for segment in segments:
            kind, _, name = segment.partition('-')
            if kind == "dict":
                subscripts.append("[{}]".format(json.dumps(name, ensure_ascii=False)))
            elif kind == "list":
                subscripts.append("[{}]".format(int(name)))
            else:
                raise Exception("Unsupported datatype given for add value")
        if segments[-1].partition('-')[0] == "dict":
            return "{}{} = {}".format(var_name, "".join(subscripts), value)
        # List path: drop the final index subscript and append to the containing list.
        return "{}{}.append({})".format(var_name, "".join(subscripts[:-1]), value)


class MongoConnect(MongoDataEncryption):
    """
    Wrapper around a ``pymongo.MongoClient`` that also offers the field-level encryption
    helpers of ``MongoDataEncryption``. Driver errors are re-raised as Mongo* exceptions.
    """

    def __init__(self):
        super().__init__()
        self._cu_ = CommonUtils()
        try:
            # connect=False defers the actual network connection until the first operation.
            self.client = MongoClient(self._cu_.db_uri(), connect=False)
        except Exception as e:
            # Raise instead of sys.exit(): library code should not terminate the process.
            logger.exception(f"Unable to create the Mongo client: {str(e)}")
            raise MongoConnectionException(f"Unable to create the Mongo client: {str(e)}") from e

    def insert_one(self,
                   database_name: str,
                   collection_name: str,
                   data: Dict):
        """
        Insert a document into a collection in a Mongo Database.
        :param database_name: Database Name
        :param collection_name: Collection Name
        :param data: Data to be inserted
        :return: Insert ID
        :raises MongoException: on any driver error
        """
        try:
            response = self.client[database_name][collection_name].insert_one(data)
            return response.inserted_id
        except Exception as e:
            raise MongoException(e) from e

    def insert_many(self,
                    database_name: str,
                    collection_name: str,
                    data: List):
        """
        Insert documents into a collection in a Mongo Database.
        :param database_name: Database Name
        :param collection_name: Collection Name
        :param data: List of Data to be inserted
        :return: Insert IDs
        :raises MongoException: on any driver error
        """
        try:
            response = self.client[database_name][collection_name].insert_many(data)
            return response.inserted_ids
        except Exception as e:
            raise MongoException(e) from e

    def find(self,
             database_name: str,
             collection_name: str,
             query: Dict,
             filter_dict: Optional[Dict] = None,
             sort=None,
             skip: Optional[int] = 0,
             limit: Optional[int] = None):
        """
        Query documents from a given collection in a Mongo Database.
        :param database_name: Database Name
        :param collection_name: Collection Name
        :param query: Query Dictionary
        :param filter_dict: Projection; defaults to ``{"_id": 0}`` (hide _id)
        :param sort: List of tuple with key and direction. [(key, -1), ...]
        :param skip: Skip Number (None is treated as 0)
        :param limit: Limit Number (falsy means no limit)
        :return: List of Documents
        :raises MongoException: on any driver error
        """
        if sort is None:
            sort = list()
        if filter_dict is None:
            filter_dict = {"_id": 0}
        try:
            cursor = self.client[database_name][collection_name].find(query, filter_dict)
            if sort:
                cursor = cursor.sort(sort)
            cursor = cursor.skip(skip or 0)
            if limit:
                cursor = cursor.limit(limit)
            try:
                return list(cursor)
            finally:
                cursor.close()
        except Exception as e:
            raise MongoException(e) from e

    def find_one(self,
                 database_name: str,
                 collection_name: str,
                 query: Dict,
                 filter_dict: Optional[Dict] = None):
        """
        Return the first document matching ``query`` (or None).
        :param filter_dict: Projection; defaults to ``{"_id": 0}`` (hide _id)
        :raises MongoException: on any driver error
        """
        try:
            if filter_dict is None:
                filter_dict = {"_id": 0}
            return self.client[database_name][collection_name].find_one(query, filter_dict)
        except Exception as e:
            raise MongoException(e) from e

    @staticmethod
    def _build_update_document(data):
        """
        Return the update document to send to the driver.
        Documents that already use update operators (``$set``, ``$inc``, ``$push``, ...) are
        sent unchanged. Plain field/value dicts are wrapped in ``$set``.
        """
        if not isinstance(data, dict):
            # e.g. an aggregation-pipeline update (list of stages) is passed through as-is.
            return data
        if any(isinstance(key, str) and key.startswith('$') for key in data):
            return data
        return {"$set": data}

    def update_one(self,
                   database_name: str,
                   collection_name: str,
                   query: Dict,
                   data: Dict,
                   upsert: bool = False):
        """
        Update the first document matching ``query``.
        :param upsert: Insert a new document when nothing matches
        :param database_name: Database Name
        :param collection_name: Collection Name
        :param query: Filter selecting the document
        :param data: Update document using operators, or plain fields (wrapped in ``$set``)
        :return: True once the driver call returned
        :raises MongoException: on any driver error
        """
        try:
            collection = self.client[database_name][collection_name]
            collection.update_one(query, self._build_update_document(data), upsert=upsert)
            return True
        except Exception as e:
            raise MongoException(e) from e

    def update_many(self,
                    database_name: str,
                    collection_name: str,
                    query: Dict,
                    data: Dict,
                    upsert: bool = False):
        """
        Update all documents matching ``query``.
        :param upsert: Insert a new document when nothing matches
        :param database_name: Database Name
        :param collection_name: Collection Name
        :param query: Filter selecting the documents
        :param data: Update document using operators, or plain fields (wrapped in ``$set``)
        :return: True once the driver call returned
        :raises MongoException: on any driver error
        """
        try:
            collection = self.client[database_name][collection_name]
            collection.update_many(query, self._build_update_document(data), upsert=upsert)
            return True
        except Exception as e:
            raise MongoException(e) from e

    def delete_many(self,
                    database_name: str,
                    collection_name: str,
                    query: Dict):
        """
        Delete all documents matching ``query``.
        :return: Number of deleted documents
        :raises MongoException: on any driver error
        """
        try:
            response = self.client[database_name][collection_name].delete_many(query)
            return response.deleted_count
        except Exception as e:
            raise MongoException(e) from e

    def delete_one(self,
                   database_name: str,
                   collection_name: str,
                   query: Dict):
        """
        Delete the first document matching ``query``.
        :return: Number of deleted documents (0 or 1)
        :raises MongoException: on any driver error
        """
        try:
            response = self.client[database_name][collection_name].delete_one(query)
            return response.deleted_count
        except Exception as e:
            raise MongoException(e) from e

    def distinct(self,
                 database_name: str,
                 collection_name: str,
                 query_key: str,
                 filter_json: Optional[Dict] = None):
        """
        Return the distinct values of ``query_key`` among documents matching ``filter_json``.
        :raises MongoException: on any driver error
        """
        try:
            return self.client[database_name][collection_name].distinct(query_key, filter_json)
        except Exception as e:
            raise MongoException(e) from e

    def fetch_records_from_object(self, body, _collection_name):
        """
        Decrypt every document of ``body`` in a thread pool.
        :param body: Any iterable of documents (a list or a pymongo cursor)
        :param _collection_name: Collection whose encryption configuration applies
        :return: List of decrypted documents, in input order
        :raises MongoException: on any failure
        """
        try:
            with ThreadPoolExecutor(max_workers=Constants.max_docs_per_batch) as executor:
                # No len(body) needed, so cursors work too; list() surfaces worker exceptions.
                final_list = list(executor.map(lambda doc: self.decrypt_data(doc, _collection_name), body))
        except Exception as e:
            raise MongoException(str(e)) from e
        return final_list

    def aggregate(self, db_name: str, collection_name: str, list_for_aggregation: Optional[List]):
        """
        Run an aggregation pipeline.
        :return: The driver's aggregation cursor
        :raises MongoException: on any driver error
        """
        try:
            return self.client[db_name][collection_name].aggregate(list_for_aggregation)
        except Exception as e:
            raise MongoException(str(e)) from e

    def find_util(self, **kwargs):
        """
        Build a find cursor from keyword arguments: database_name, collection_name,
        find_condition, select_condition, sort_condition, skip, limit.
        :return: The (unconsumed) pymongo cursor
        :raises MongoFindException: on any driver error
        """
        try:
            database_name = kwargs.get('database_name', None)
            collection_name = kwargs.get('collection_name', None)
            find_condition = kwargs.get('find_condition', dict())
            select_condition = kwargs.get('select_condition', None)
            sort_condition = kwargs.get('sort_condition', None)
            skip = kwargs.get('skip', 0)
            limit = kwargs.get('limit', None)

            collection = self.client[database_name][collection_name]
            if select_condition:
                mongo_response = collection.find(find_condition, select_condition)
            else:
                mongo_response = collection.find(find_condition)
            if sort_condition is not None:
                mongo_response = mongo_response.sort(sort_condition)
            if skip:
                mongo_response = mongo_response.skip(skip=skip)
            if limit is not None:
                mongo_response = mongo_response.limit(limit=limit)
            return mongo_response
        except Exception as e:
            logger.error(f"{Exceptions.MONGO003}: {str(e)}")
            raise MongoFindException(f"{Exceptions.MONGO005}: {str(e)}")

    def close_connection(self):
        """
        To close the mongo connection
        :raises MongoConnectionException: if closing fails
        """
        try:
            if self.client is not None:
                self.client.close()
            logger.debug("Mongo connection closed")
        except Exception as e:
            logger.error(f"{Exceptions.MONGO007}: {str(e)}")
            raise MongoConnectionException(f"{Exceptions.MONGO007}: {str(e)}")

    def find_count(self, json_data, database_name, collection_name):
        """
        Count the documents matching a filter.
        Uses ``count_documents`` because ``Cursor.count()`` was removed in PyMongo 4.
        NOTE: count_documents does not accept $where / $near / $nearSphere in the filter.
        :param json_data: Filter; None counts every document
        :param database_name: The database to which the collection/ documents belongs to.
        :param collection_name: The collection to which the documents belongs to.
        :return: Number of matching documents
        :raises MongoFindException: on any driver error
        """
        try:
            query = json_data if json_data is not None else {}
            mongo_response = self.client[database_name][collection_name].count_documents(query)
            logger.debug("fetched result count from mongo")
            return mongo_response
        except Exception as e:
            logger.error(f"{Exceptions.MONGO003}: {str(e)}")
            raise MongoFindException(f"{Exceptions.MONGO005}: {str(e)}")
