import jwt
from jwt.exceptions import (
    InvalidSignatureError,
    ExpiredSignatureError,
    MissingRequiredClaimError,
)

from scripts.constants.app_configuration import KeyPath
from scripts.constants.app_constants import Secrets
from scripts.exceptions import AuthenticationError, ErrorMessages
from scripts.logging.logging import logger


class JWT:
    """Sign and verify JWTs using key files whose paths come from ``KeyPath``.

    The signing algorithm is taken from ``Secrets.alg``. Verification leeway
    is taken from ``Secrets.leeway_in_mins``.
    """

    def __init__(self):
        # NOTE(review): ``max_login_age`` and ``issuer`` are not read by any
        # method in this class; presumably consumed by callers elsewhere.
        self.max_login_age = Secrets.LOCK_OUT_TIME_MINS
        self.issuer = Secrets.issuer
        self.alg = Secrets.alg
        self.public = KeyPath.public
        self.private = KeyPath.private

    @staticmethod
    def _read_key(path):
        """Return the full text content of the key file at *path*."""
        # ``with`` closes the file on every path, including when ``read``
        # fails. If ``open`` itself fails, there is nothing to close, and its
        # exception propagates unmasked.
        with open(path, "r", encoding="utf-8") as key_file:
            return key_file.read()

    def encode(self, payload):
        """Sign *payload* with the private key and return the token.

        Args:
            payload: Claims passed unchanged to ``jwt.encode``.

        Returns:
            Whatever ``jwt.encode`` returns for ``algorithm=self.alg``.

        Raises:
            Any error from reading the key file or from ``jwt.encode``. It is
            logged at debug level and re-raised unchanged.
        """
        try:
            logger.debug('Inside encode')
            key = self._read_key(self.private)
            return jwt.encode(payload, key, algorithm=self.alg)
        except Exception as e:
            logger.debug(f'Exception in encode: {str(e)}')
            raise

    def validate(self, token):
        """Verify *token* against the public key and return its decoded payload.

        The token must carry ``exp`` and ``iss`` claims.

        Raises:
            AuthenticationError: ``ErrorMessages.ERROR003`` for a bad
                signature; ``ErrorMessages.ERROR002`` for an expired token or
                a missing required claim.
            Exception: Any other error (key-file I/O, malformed token, ...) is
                logged at debug level and re-raised unchanged.
        """
        try:
            logger.debug('Inside validate')
            key = self._read_key(self.public)
            return jwt.decode(
                token,
                key,
                # PyJWT expects a *list* of allowed algorithms. A bare string
                # would be matched with substring ``in`` semantics.
                algorithms=[self.alg],
                # NOTE(review): PyJWT treats a numeric ``leeway`` as seconds
                # (or accepts a timedelta). The setting name says minutes, so
                # confirm the configured unit.
                leeway=Secrets.leeway_in_mins,
                # NOTE(review): ``iss`` must be present, but its value is not
                # compared with ``self.issuer`` (no ``issuer=`` argument is
                # passed). Confirm whether issuer verification is intended.
                options={"require": ["exp", "iss"]},
            )
        except InvalidSignatureError as err:
            raise AuthenticationError(ErrorMessages.ERROR003) from err
        except ExpiredSignatureError as err:
            raise AuthenticationError(ErrorMessages.ERROR002) from err
        except MissingRequiredClaimError as err:
            raise AuthenticationError(ErrorMessages.ERROR002) from err
        except Exception as e:
            logger.debug(f'Exception in validate: {str(e)}')
            raise
