import base64
import copy
import hashlib
import json
import os
from datetime import datetime
from operator import itemgetter
from uuid import UUID

from Crypto import Random
from Crypto.Cipher import AES

from scripts.config import Security
from scripts.errors import CustomError
from scripts.errors.mongo_exceptions import (
    MongoException,
    MongoUnknownDatatypeException,
)
from scripts.logging import logger
from scripts.utils.jwt_util import JWT

exclude_encryption_datatypes = (
    datetime,
    UUID,
)

class MongoEncryptionConstants:
    # mongo encryption keys
    key_encrypt_keys = "encrypt_keys"
    key_exclude_encryption = "exclude_encryption"
    product_encrypted = "product_encrypted"
    max_docs_per_batch = 5
    cipher_key = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJrZXkiOiItLS0tLUJFR0lOIFJTQSBQUklWQVRFIEtFWS0tLS0tXG5NSUlFb3dJQkFBS0NBUUVBclZFRDVjcit0TXRGdFZtWGwyTzBjdlFiRWdvWVNJRmQ4eXZrbW02ejdYQWRYNkVnXG5Za0tlejB5ZFRsMjZLT2RKMThBN0tuOGV0V0dlOG5Ua1NHaGVKbDlybi9KK2xFMXpwbzRaZy9UM3dEbk04Rk0zXG55dU0yNnZwSWIrMG9KbU5jOURrRlhvNFd0eFJGWkR5dGRFVGcvWXlJK2VKWURSRHJaU3JscUF6SURwQWRMcHY5XG5VaHNNaFlRKzJuM1BjYXVMZUpiMGRLUFZUYzZrU3ZHQ3MzTFowV3lUYlJuUXlKTUNXbmF4enBTSVVjSDdxYXFPXG5LQy9mQkNLc1ptUmpSTlNtUTNnZXB6NFZuUUt5SkNtN0NKaytjUWlRTVF6cnNwUlB2aG1Hb3VIWlVNMzZLanNHXG42eWx4MkJ1Nk9ZeS9IYnJkUmtKS05sdjN1NkJCTDZQbi9aSlpHUUlEQVFBQkFvSUJBQkk4ZU1oRVNuWWJtMVJJXG5XOFM4WXplSU8xUHoxM2hEa3U3Y0FyY0VLRzcya2NTbTU4a25BTjVIamJLNTluVkkxdEo2Z2M4NEpuTkgxUWxtXG5ac0crcDQ5cWtXQzRTM3pQeEhnMU1mYWFQenBNNnFVcjRHNDY1Nk9rVjV4ZFRCRHorZ3NoZDlEcDZ2WnpEZFVjXG45RlJNVGc4bnF4Nzk0NjFtUnhwelA4eGxvYVEwTmNLQnpGSzllM2cvNGk3Mkx3Z05QM0U2eG1FU2l1N2dvcUoxXG5HT0FJMm1KaWUzVFRZMXo4c2Y0dWlTRkxNYUZyRXhrcTR6NEtrd1M3cUYybk9KeGh2OEgvZzlUR1BOV3JuekF3XG55Qkh3SU5Cb1VhSndpT1Q1MXh4SURMZ05RaU5vSUZ1YU1LVnUybCtyV3RvUVdLR2lPbncxWmhZeGVKQ1hCeVhDXG5RcXBBZmdFQ2dZRUF3cHpTZnlvdDNQQWx4bTlpVks1WmM2bFJkQnE3SmF6dDd0OTFVNnplWTdDNHh6TkcxVHVmXG5jU1lLM3FSd2xNdzJ1WGw5YXV4eVY0MXJ6aVg5c1podEZVbm00amNHdjlNSGVhQWFTU1BTc3ZydFpERkJTN2t5XG5sMkl4azEwNzhMVFpDTE1ZbUFLQ0FyMlhMbVNoQlBTVmN1YUxrRFJYNHJ2dzdzY1dtTWI4NndFQ2dZRUE0L3lDXG5FQWpYbEwwV2xPWURKM0ovL1BnNGlCdEllZEhYbW4zMGdvTnVDQkJhb1l5Z1hhcGV5dEVtVTJxNWh5YlFUTVRYXG5WbC92SUFGaXUwVFg4MVZRN0xETEphYmVyLzdHRXNJVDN4K3htMGpGdk94RllWaFQ1YjBzMHoxQ1FvbG5SRnNBXG5kSXdRNXU1R2tQNjVoeUpVYTNaTWgrTDZWaXNTQ1RLcEFjbzlaaGtDZ1lBS0ZaNUN3S2pIdmhuM0FtYVNCTWJWXG4yM3hCQy9HT3JqdFdHWFkyODhwQ1dESDdBSWszRzNQVHBTa0RDSHBjKzRnS2JHVTNXVEZEb0N4cDdrWUxJZDdsXG5MNE1yVGJhbjBnT2RKZEsyMzRoWGhmRXZNKzR5UWxLQXpiSEw5UlRhRUVUKzBtai8xNEZ0S3UzZWxaQlNkV29aXG5IaUUxUThFYUdxc05kSHVUUnh4c0FRS0JnUUNxdzdlbnl2ZXVzUEw1RkUvSWZEcmhnQXJYNTVlaHAwdVdyRUU0XG5nTGtwMFJZUmF3T3pKS2xid015MExueElmd29HZG1uVWlJYlRzallCanM4eHMvV3BVOExWc09lYmEzbHhFMjFPXG44cTVWWVd5NjFUNGlhOVpyamdiRk1sMHUrVHdnTndsZ1FvbG1iNUxyaDkvdkdBZWpkamhjaitaeUpGQ2VFeFFFXG5BemQ2QVFLQmdCaGUrRndNaFR1czk2MWpxRUtYQlhtMC9PYU9nek9kZ2wvYXN1QzhvTFU3Y0FWRDdzUzJMRmNVXG51N29mSVZJRzZjUldScnVhakl1Q2RsSWNMT2VkVEU0WUw1akF1UkwxVHlWdnhNbTBGc3JrV1BBQkZySFdoc1pzXG5UU3pwaU9GSmtMSlRWblQ3aGxXLyttMHFyS2lXMHpyRnphMEphRndQL2xqK2hScllHa09sXG4tLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLSIsImlzcyI6ImlsZW5zIiwiZXhwIjoxNzg4NTg2MzYyfQ.K6PrPcum1ACp9jQtqL3oNncnmXtTnPEOLYWCaaHmFMpLnAPAnKlYblsQkx4nv4pskJ3DBzSk6H-7Tnns4oejfaZI56wHhGz99JZN9mQ9JrQazZ01uccAwhcaOOMnMEny5J4Q6FB0OyyNIxSsScx2s21Vx-eJvOV1FOrCBjvZG78"


enc = AES.MODE_CBC
jwt = JWT()

try:
    file_name = Security.ENCRYPTION_CONSTANTS_FILE_PATH
    if not os.path.exists(file_name):
        encrypt_collection_dict = {}
    else:
        with open(file_name) as f:
            mongo_encryption_constants_data = json.load(f)
        if "encrypt_collection_dict" in mongo_encryption_constants_data and Security.USER_ENCRYPTION:
            encrypt_collection_dict = mongo_encryption_constants_data["encrypt_collection_dict"]
        else:
            encrypt_collection_dict = {}
except Exception as es:
    encrypt_collection_dict = {}
    logger.exception(f" Unable to fetch mongo encryption constants:{str(es)}")


class AESCipher:
    def __init__(self, key):
        self.bs = AES.block_size
        self.key = hashlib.sha256(key.encode()).digest()

    def get_cipher(self, iv):
        return AES.new(self.key, enc, iv)

    def encrypt(self, raw):
        raw = self._pad(raw)
        iv = Random.new().read(AES.block_size)
        cipher = self.get_cipher(iv)
        return base64.b64encode(iv + cipher.encrypt(raw.encode()))

    def decrypt(self, enc):
        enc = base64.b64decode(enc)
        iv = enc[: AES.block_size]
        cipher = self.get_cipher(iv)
        return self._unpad(cipher.decrypt(enc[AES.block_size :])).decode("utf-8")

    def _pad(self, s):
        return s + (self.bs - len(s) % self.bs) * chr(self.bs - len(s) % self.bs)

    @staticmethod
    def _unpad(s):
        return s[: -ord(s[len(s) - 1 :])]


class MongoDataEncryption:
    def __init__(self):
        self.decoded_aes = jwt.decode(MongoEncryptionConstants.cipher_key).get("key")
        self.aes_cipher = AESCipher(key=self.decoded_aes)
        # pass

    def create_encrypted_string(self, payload):
        return self.aes_cipher.encrypt(raw=json.dumps(payload))

    def create_decrypted_string(self, payload):
        decrypted_result = self.aes_cipher.decrypt(enc=payload)
        try:
            result = json.loads(decrypted_result)
        except json.JSONDecodeError:
            if isinstance(decrypted_result, bytes):
                result = decrypted_result.decode("utf-8")
            else:
                result = decrypted_result

        return result

    def encrypt_data(self, json_data, collection_name):
        """
        Encrypt the data in mongo based on the collection and key to be encrypted.
        :param json_data: The data to be encrypted
        :param collection_name: The collection where the document is stored
        :return: Encrypted document based on product defined configuration.
        """
        try:
            if collection_name in encrypt_collection_dict.keys():
                if isinstance(json_data, list):
                    encrypted_data = []
                    for data in encrypted_data:
                        dict_data = self.encrypt_dict_data(doc=data, collection_name=collection_name)
                        encrypted_data.append(dict_data)
                elif isinstance(json_data, dict):
                    encrypted_data = self.encrypt_dict_data(doc=json_data, collection_name=collection_name)
                else:
                    raise MongoUnknownDatatypeException(
                        "Unsupported datatype '{}' is being inserted to mongodb.".format(type(json_data))
                    )
            else:
                logger.info("Given data is not a part of the Mongo encryption setup. Skipping encryption")
                if isinstance(json_data, dict):
                    encrypted_data = json_data
                    encrypted_data[MongoEncryptionConstants.product_encrypted] = False
                else:
                    encrypted_data = json_data
            return encrypted_data
        except MongoException as e:
            raise MongoException(str(e))
        except Exception as e:
            raise MongoException("Server faced a problem when encrypting the data --> {}".format(str(e)))

    def encrypt_dict_data(self, doc, collection_name):  # NOSONAR
        """
        This method crawls the document and encrypts the keys that are marked for encryption.
        Skips encrypting the keys with values of the datatypes defines in the tuple 'exclude_encryption_datatypes'
        Adds two new keys to the document 'product_encrypted' and 'encryption_salt'
        Key product_encrypted - Is a boolean value which flags a document as encrypted by this utility.
        Key encryption_salt - List of all the values that were excluded from encryption due to datatype constraints.
        :param doc: The document considered for encryption
        :param collection_name: The collection where the document resided.
                                This is needed for the utility to read the encryption configuration
        :return: The input document with the relevant keys encrypted.
        """
        try:
            is_mlens_encrypted = False
            encrypted_data = {}
            encrypted_data["encryption_salt"] = {}
            if "*" in encrypt_collection_dict[collection_name][MongoEncryptionConstants.key_encrypt_keys]:
                # Forming encryption salt
                for index, exclude_encryption_datatype in enumerate(exclude_encryption_datatypes):
                    if exclude_encryption_datatype not in [None, ""]:
                        encrypted_data["encryption_salt"][f"dt_{index}"] = self.search_datatype(
                            doc, exclude_encryption_datatype
                        )
                        sorted_path = sorted(
                            encrypted_data["encryption_salt"][f"dt_{index}"], key=itemgetter("p"), reverse=True
                        )
                        for _path in sorted_path:
                            to_pop = self.remove_value_of_datatype_command(_path, "dict_data")
                            exec(to_pop)
                for dt in encrypted_data["encryption_salt"]:
                    for path_index, _path in enumerate(encrypted_data["encryption_salt"][dt]):
                        encrypted_data["encryption_salt"][dt][path_index]["p"] = base64.b64encode(_path["p"].encode())

                # Encrypting the data
                for key in doc.keys():
                    if (
                        key
                        not in encrypt_collection_dict[collection_name][MongoEncryptionConstants.key_exclude_encryption]
                    ):
                        encrypted_data[key] = {
                            "d": self.create_encrypted_string(payload=self.convert(doc[key])),
                            "t": base64.b64encode(type(doc[key]).__name__.encode()),
                        }
                        is_mlens_encrypted = True
                    else:
                        encrypted_data[key] = doc[key]
            else:
                for key in doc.keys():
                    if key in encrypt_collection_dict[collection_name][MongoEncryptionConstants.key_encrypt_keys]:
                        # Forming encryption salt
                        for index, exclude_encryption_datatype in enumerate(exclude_encryption_datatypes):
                            if exclude_encryption_datatype not in [None, ""]:
                                temp_dict_data = {}
                                temp_dict_data[key] = copy.deepcopy(doc[key])
                                encrypted_data["encryption_salt"][f"dt_{index}"] = self.search_datatype(
                                    temp_dict_data, exclude_encryption_datatype
                                )
                                sorted_path = sorted(
                                    encrypted_data["encryption_salt"][f"dt_{index}"],
                                    key=itemgetter("p"),
                                    reverse=True,
                                )
                                for _path in sorted_path:
                                    to_pop = self.remove_value_of_datatype_command(_path, "dict_data")
                                    exec(to_pop)
                        for dt in encrypted_data["encryption_salt"]:
                            for path_index, _path in enumerate(encrypted_data["encryption_salt"][dt]):
                                encrypted_data["encryption_salt"][dt][path_index]["p"] = base64.b64encode(
                                    _path["p"].encode()
                                )
                        # Encrypting the data
                        encrypted_data[key] = {
                            "d": self.create_encrypted_string(payload=self.convert(doc[key])),
                            "t": base64.b64encode(type(doc[key]).__name__.encode()),
                        }
                        is_mlens_encrypted = True
                    else:
                        encrypted_data[key] = doc[key]
            encrypted_data[MongoEncryptionConstants.product_encrypted] = is_mlens_encrypted
            if not encrypted_data[MongoEncryptionConstants.product_encrypted]:
                del encrypted_data["encryption_salt"]
            return encrypted_data
        except MongoException as e:
            raise MongoException(str(e))
        except Exception as e:
            raise MongoException("Server faced a problem when encrypting the data --> {}".format(str(e)))

    def decrypt_data(self, dict_data, _collection_name):  # NOSONAR
        """
        This method decrypts all the data that is encrypted.
        Keys that were excluded during encryption and have been added to the encryption_salt
        will be added back to their original positions.
        :param dict_data: The document that needs to be decrypted
        :param _collection_name: The collection to which the document belongs to
        :return: The decrypted data with the original data types intact
        """
        try:
            if _collection_name in encrypt_collection_dict.keys():
                decrypted_data = {}
                if "*" in encrypt_collection_dict[_collection_name][MongoEncryptionConstants.key_encrypt_keys]:
                    for key in dict_data.keys():
                        if key not in encrypt_collection_dict[_collection_name][
                            MongoEncryptionConstants.key_exclude_encryption
                        ] and not isinstance(dict_data[key], exclude_encryption_datatypes):
                            if isinstance(dict_data[key], dict):
                                if "d" in dict_data[key].keys() and "t" in dict_data[key].keys():
                                    decrypted_data[key] = self.decrypt_convert_proper_data_type(
                                        data=self.create_decrypted_string(payload=dict_data[key]["d"]),
                                        data_type=base64.b64decode(dict_data[key]["t"].decode()).decode(),
                                    )
                                else:
                                    decrypted_data[key] = dict_data[key]
                            else:
                                decrypted_data[key] = dict_data[key]
                        else:
                            decrypted_data[key] = dict_data[key]
                else:
                    for key in dict_data.keys():
                        if key in encrypt_collection_dict[_collection_name][
                            MongoEncryptionConstants.key_encrypt_keys
                        ] and not isinstance(dict_data[key], exclude_encryption_datatypes):
                            if isinstance(dict_data[key], dict):
                                if "d" in dict_data[key].keys() and "t" in dict_data[key].keys():
                                    decrypted_data[key] = self.decrypt_convert_proper_data_type(
                                        data=self.create_decrypted_string(payload=dict_data[key]["d"]),
                                        data_type=base64.b64decode(dict_data[key]["t"].decode()).decode(),
                                    )
                                else:
                                    decrypted_data[key] = dict_data[key]
                            else:
                                decrypted_data[key] = dict_data[key]
                        else:
                            decrypted_data[key] = dict_data[key]
            else:
                decrypted_data = dict_data
            return decrypted_data
        except MongoException as e:
            raise MongoException(str(e))
        except Exception as e:
            raise MongoException(f"Server faced a problem when decrypting the data: {str(e)}")

    @staticmethod
    def decrypt_convert_proper_data_type(data, data_type):
        """
        Convert the de-serialized JSON object to the original data-type
        :param data: The de-serialized data
        :param data_type: The original data type to which the de-serialized data should be converted to
        :return: The de-serialized data with it's original data type.
        """
        if data_type == "int":
            return int(data)
        elif data_type == "list":
            return data
        elif data_type == "dict":
            return data
        elif data_type == "bool":
            return data
        else:
            return data.lstrip('"').rstrip('"')

    def convert(self, data):
        """
        Convert all byte-like objects into the proper data types.
        This supports conversion of nested dict, list and tuples.
        :param data:
        :return:
        """
        if isinstance(data, bytes):
            return data.decode("ascii")
        if isinstance(data, dict):
            return dict(map(self.convert, data.items()))
        if isinstance(data, tuple):
            return map(self.convert, data)
        if isinstance(data, list):
            return list(map(self.convert, data))
        return data

    def search_datatype(self, _input, search_type, prev_datapoint_path=""):  # NOSONAR
        """
        Search for an excluded data type in a nested dictionary or list and record it's path in the document.
        This does not support the exclusion of data of types dict and list.
        :param _input: The input data
        :param search_type: The data type to be searched for to exclude.
        :param prev_datapoint_path: The path of a value in a nested dict or nested list.
        :return: List of dictionaries, with each dictionary containing the true value and it's path.
        """
        try:
            output = []
            current_datapoint = _input
            current_datapoint_path = prev_datapoint_path
            if search_type is dict:
                raise CustomError("Searching for datatype dict is not supported!")
            elif search_type is list:
                raise CustomError("Searching for datatype list is not supported!")
            else:
                if isinstance(current_datapoint, dict):
                    for dkey in current_datapoint:
                        temp_datapoint_path = current_datapoint_path
                        temp_datapoint_path += "dict-{}.".format(dkey)
                        for index in self.search_datatype(current_datapoint[dkey], search_type, temp_datapoint_path):
                            output.append(index)
                elif isinstance(current_datapoint, list):
                    for index in range(0, len(current_datapoint)):
                        temp_datapoint_path = current_datapoint_path
                        temp_datapoint_path += "list-{}.".format(index)
                        for index_1 in self.search_datatype(current_datapoint[index], search_type, temp_datapoint_path):
                            output.append(index_1)
                elif isinstance(current_datapoint, search_type):
                    output.append({"p": current_datapoint_path, "v": current_datapoint})
                output = filter(None, output)
                return list(output)
        except Exception:
            raise CustomError(f"Server faced a problem when searching for instances of datatype  --> '{search_type}' ")

    @staticmethod
    def remove_value_of_datatype_command(remove_value, var_name):
        """
        This method produces the command for the value to be removed from a nested dict or list,
        when given the path of that value in the source variable.
        :param remove_value: The value (it's path) to be removed.
        :param var_name: The variable on which the exec function should run on to remove the non-serializable value.
        :return: The final command that will run in the exec function to remove the value from a nested dict or list.
        """
        temp_path = ""
        individual_path_list = remove_value["p"].split(".")
        individual_path_list.remove("")
        if individual_path_list[len(individual_path_list) - 1].split("-")[0] == "dict":
            orig_path = "del {var_name}{path}"
        elif individual_path_list[len(individual_path_list) - 1].split("-")[0] == "list":
            pop_index = ".pop({})".format(individual_path_list[len(individual_path_list) - 1].split("-")[1])
            orig_path = "{var_name}{path}" + pop_index
            individual_path_list.pop(len(individual_path_list) - 1)
        else:
            return
        for _path_index, path in enumerate(individual_path_list):
            if path.split("-")[0] == "dict":
                temp_path += '["{}"]'.format(path.split("-")[1])
            elif path.split("-")[0] == "list":
                temp_path += "[{}]".format(path.split("-")[1])
        orig_path = orig_path.format(path=temp_path, var_name=var_name)
        return orig_path

    @staticmethod
    def add_value_datatype_command(add_value, var_name, value):
        """
        This method produces the command for the value to be added back to a nested dict or list,
        when given the path of that value in the source variable.
        :param add_value: The value (it's path) to be added
        :param var_name: The source variable name on which the exec function should run on.
        :param value: The original non-serialized value.
        :return: The command to be executed on the source variable.
        """
        path_string = ""
        temp_path_string = ""
        individual_path_list = add_value["p"].split(".")
        individual_path_list.remove("")
        for _path_index, path in enumerate(individual_path_list):
            if path.split("-")[0] == "dict":
                temp_path_string = '["{}"]'.format(path.split("-")[1])
            elif path.split("-")[0] == "list":
                temp_path_string = "[{}]".format(path.split("-")[1])
            else:
                raise CustomError("Unsupported datatype given for add value")
            path_string += temp_path_string
        if individual_path_list[len(individual_path_list) - 1].split("-")[0] == "dict":
            command = "{var_name}{path} = {value}".format(var_name=var_name, path=path_string, value=value)
        elif individual_path_list[len(individual_path_list) - 1].split("-")[0] == "list":
            command = "{var_name}{path}].append({value})".format(
                var_name=var_name, path=path_string.rstrip(temp_path_string), value=value
            )
        else:
            raise CustomError("Unsupported datatype given for add value")
        return command
