import base64

from Cryptodome import Random
from Cryptodome.Cipher import AES
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric import rsa

from scripts.logging.logging import logger as LOG


class AESCipher(object):
    """
    AES-CBC cipher for text or byte payloads of any length.

    The plaintext is PKCS#7-padded to the AES block size, so the *data* may be
    any length. The *key* is NOT padded or derived: it is used exactly as given
    (str is UTF-8 encoded first) and must be 16, 24 or 32 bytes long
    (AES-128/192/256), otherwise ``AES.new`` rejects it.

    ``encrypt`` returns ``base64(iv || ciphertext)`` as a str, using a fresh
    random IV on every call; ``decrypt`` reverses that.

    NOTE(review): CBC without a MAC gives confidentiality only - ciphertexts are
    not authenticated and are malleable. An AEAD mode (e.g. GCM) would be safer,
    but switching changes the stored format, so it is not done here.
    """

    def __init__(self, key):
        """
        :param key: str or bytes; str is UTF-8 encoded. Must be 16/24/32 bytes.
        """
        self.bs = AES.block_size
        self.key = AESCipher.str_to_bytes(key)

    @staticmethod
    def str_to_bytes(data):
        """Return *data* UTF-8 encoded if it is text; return anything else unchanged."""
        if isinstance(data, str):
            return data.encode('utf8')
        return data

    def _pad(self, s):
        """Append PKCS#7 padding: 1..bs bytes, each equal to the pad length."""
        pad_len = self.bs - len(s) % self.bs
        return s + bytes([pad_len]) * pad_len

    @staticmethod
    def _unpad(s, block_size=16):
        """
        Strip and validate PKCS#7 padding.

        :param s: padded bytes
        :param block_size: cipher block size in bytes (16 for AES)
        :returns: *s* without its padding
        :raises ValueError: if *s* is empty or its padding is malformed
            (typically a wrong key or corrupted ciphertext)
        """
        if not s:
            raise ValueError("Invalid PKCS#7 padding")
        pad_len = s[-1]
        # A valid pad length is 1..block_size and every pad byte equals it;
        # a 0 byte would otherwise make s[:-0] silently return b''.
        if not 1 <= pad_len <= min(block_size, len(s)) \
                or s[-pad_len:] != bytes([pad_len]) * pad_len:
            raise ValueError("Invalid PKCS#7 padding")
        return s[:-pad_len]

    def encrypt(self, raw):
        """
        Encrypt *raw* (str or bytes).

        :returns: str, base64 of the random IV followed by the ciphertext
        """
        raw = self._pad(AESCipher.str_to_bytes(raw))
        iv = Random.new().read(self.bs)
        cipher = AES.new(self.key, AES.MODE_CBC, iv)
        return base64.b64encode(iv + cipher.encrypt(raw)).decode('utf-8')

    def decrypt(self, enc):
        """
        Decrypt a value produced by ``encrypt`` and return the plaintext as str.

        :raises ValueError: for bad base64 (binascii.Error), bad padding, or
            non-UTF-8 plaintext (UnicodeDecodeError); errors raised by
            Cryptodome for a bad key or IV length propagate unchanged.
        """
        enc = base64.b64decode(enc)
        iv = enc[:self.bs]
        cipher = AES.new(self.key, AES.MODE_CBC, iv)
        data = self._unpad(cipher.decrypt(enc[self.bs:]), self.bs)
        return data.decode('utf-8')


class AsymmetricEncryption(object):
    """
    Static RSA helpers for:
        1. generating private/public key pairs (RSA-2048, e = 65537)
        2. serializing keys to PEM strings
        3. deserializing PEM strings back to key objects
        4. encrypting/decrypting text with RSA-OAEP (SHA-256)
        5. signing/verifying text with RSA-PSS (SHA-256)

    Messages are UTF-8 encoded before use; ciphertexts and signatures are
    exchanged as base64 strings.
    """

    @staticmethod
    def _oaep_padding():
        """OAEP(SHA-256) padding shared by encrypt_data and decrypt_data so they cannot drift apart."""
        return padding.OAEP(
            mgf=padding.MGF1(algorithm=hashes.SHA256()),
            algorithm=hashes.SHA256(),
            label=None
        )

    @staticmethod
    def _pss_padding():
        """PSS(SHA-256, max salt) padding shared by gen_signature and verify_signature."""
        return padding.PSS(
            mgf=padding.MGF1(hashes.SHA256()),
            salt_length=padding.PSS.MAX_LENGTH
        )

    @staticmethod
    def generate_private_key():
        """
        Generate a new RSA-2048 private key with public exponent 65537.

        :returns: private key object, or None if generation failed (error is logged)
        """
        try:
            return rsa.generate_private_key(
                public_exponent=65537,
                key_size=2048,
                backend=default_backend()
            )
        except Exception as e:
            LOG.error("Exception in generating private key: " + str(e))
            return None

    @staticmethod
    def generate_public_key(private_key):
        """Return the public half of *private_key*."""
        return private_key.public_key()

    @staticmethod
    def encrypt_data(public_key, message):
        """
        Encrypt a text message with RSA-OAEP (SHA-256).

        :param public_key: RSA public key object
        :param message: str; UTF-8 encoded before encryption. For a 2048-bit
            key, OAEP-SHA256 limits the encoded message to 190 bytes.
        :returns: base64-encoded ciphertext as str
        :raises Exception: any encryption error is logged and re-raised
        """
        try:
            encrypted_msg = public_key.encrypt(
                message.encode('utf-8'),
                AsymmetricEncryption._oaep_padding()
            )
        except Exception as e:
            LOG.error("Exception in encryption: " + str(e))
            raise
        return base64.b64encode(encrypted_msg).decode('utf-8')

    @staticmethod
    def decrypt_data(private_key, encrypted_data):
        """
        Decrypt a base64 ciphertext produced by ``encrypt_data``.

        :param private_key: RSA private key object
        :param encrypted_data: str, base64-encoded ciphertext
        :returns: decrypted str, or None on any failure (error is logged)
        """
        try:
            decrypted_data = private_key.decrypt(
                base64.b64decode(encrypted_data.encode('utf-8')),
                AsymmetricEncryption._oaep_padding()
            )
            return decrypted_data.decode('utf-8')
        except Exception as e:
            LOG.error("Exception in decryption: " + str(e))
            return None

    @staticmethod
    def gen_signature(message, private_key):
        """
        Sign a text message with RSA-PSS (SHA-256).

        :param message: str; UTF-8 encoded before signing
        :param private_key: RSA private key object
        :returns: base64-encoded signature as str, or None on failure (error is logged)
        """
        try:
            signature = private_key.sign(
                data=message.encode('utf-8'),
                padding=AsymmetricEncryption._pss_padding(),
                algorithm=hashes.SHA256()
            )
            return base64.b64encode(signature).decode("utf-8")
        except Exception as e:
            LOG.error("Exception in signing: " + str(e))
            return None

    @staticmethod
    def verify_signature(signature, public_key, message):
        """
        Check a signature produced by ``gen_signature``.

        :param signature: str, base64-encoded signature
        :param public_key: RSA public key object
        :param message: str, the message that was signed
        :returns: True if the signature is valid; False if it does not match or
            is not decodable base64
        """
        try:
            raw_signature = base64.b64decode(signature.encode("utf-8"))
        except ValueError:
            # binascii.Error (a ValueError subclass): malformed input cannot be
            # a valid signature, so report it as such instead of crashing.
            return False
        try:
            public_key.verify(
                signature=raw_signature,
                data=message.encode("utf-8"),
                padding=AsymmetricEncryption._pss_padding(),
                algorithm=hashes.SHA256()
            )
        except InvalidSignature:
            return False
        return True

    @staticmethod
    def serialize_public_key(public_key):
        """Return *public_key* as a PEM (SubjectPublicKeyInfo) str."""
        return public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        ).decode("utf-8")

    @staticmethod
    def serialize_private_key(private_key):
        """
        Return *private_key* as an unencrypted PEM (PKCS#8) str.

        NOTE(review): the output is not password-protected; protect it at rest.
        """
        return private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption()
        ).decode("utf-8")

    @staticmethod
    def deserialize_private_key(private_key_pem):
        """Load an unencrypted PEM private key from a str."""
        return serialization.load_pem_private_key(
            private_key_pem.encode("utf-8"),
            password=None,
            backend=default_backend()
        )

    @staticmethod
    def deserialize_public_key(public_key_pem):
        """Load a PEM public key from a str."""
        return serialization.load_pem_public_key(
            public_key_pem.encode("utf-8"),
            backend=default_backend()
        )
