import jwt
from jwt.exceptions import ExpiredSignatureError, InvalidSignatureError, MissingRequiredClaimError

from tb_sdk.config import KeyPath
from tb_sdk.connectors.constants.secrets import Secrets
from tb_sdk.connectors.errors import AuthenticationError, ErrorMessages


class JWT:
    """Sign and verify JSON Web Tokens using the key files named by ``KeyPath``.

    Tokens are signed with the private key file and verified with the public
    key file. Both files are re-read from disk on every call.
    """

    def __init__(self):
        # Read from Secrets; not used by any method of this class.
        self.max_login_age = Secrets.LOCK_OUT_TIME_MINS
        # NOTE(review): the issuer is stored but never passed to jwt.decode, so
        # ``validate`` only requires an ``iss`` claim to be present; it does not
        # check it against this value. Confirm whether issuer verification
        # (``issuer=self.issuer``) is intended.
        self.issuer = Secrets.issuer
        # Single algorithm name, used for signing and as the decode allow-list.
        self.alg = Secrets.alg
        self.public = KeyPath.public
        self.private = KeyPath.private

    @staticmethod
    def _read_key(path):
        """Return the full text of the key file at *path*."""
        # The ``with`` block closes the file on every path. No extra close() is
        # needed, and none is attempted when open() itself fails.
        with open(path, encoding="utf-8") as key_file:
            return key_file.read()

    def _allowed_algorithms(self):
        """Return the algorithm allow-list for ``jwt.decode`` as a list."""
        # PyJWT documents ``algorithms`` as a list of names. With a bare string,
        # the ``alg in algorithms`` check would degrade to a substring test.
        if isinstance(self.alg, str):
            return [self.alg]
        return list(self.alg)

    def encode(self, payload):
        """Sign *payload* with the private key and return the encoded token.

        Any error, including an unreadable key file, is printed and re-raised
        unchanged.
        """
        try:
            key = self._read_key(self.private)
            return jwt.encode(payload, key, algorithm=self.alg)
        except Exception as e:
            print(f"Exception while encoding JWT: {str(e)}")
            raise

    def validate(self, token):
        """Verify *token* with the public key and return its decoded payload.

        The ``exp`` and ``iss`` claims must be present.

        Raises:
            AuthenticationError: ``ERROR003`` if the signature is invalid;
                ``ERROR002`` if the token has expired or lacks a required claim.
            Exception: any other error (e.g. an unreadable key file) is printed
                and re-raised unchanged.
        """
        try:
            key = self._read_key(self.public)
            return jwt.decode(
                token,
                key,
                algorithms=self._allowed_algorithms(),
                # NOTE(review): PyJWT interprets a numeric leeway as seconds, but
                # this setting's name says minutes. Confirm the unit.
                leeway=Secrets.leeway_in_mins,
                options={"require": ["exp", "iss"]},
            )
        except InvalidSignatureError as e:
            raise AuthenticationError(ErrorMessages.ERROR003) from e
        except ExpiredSignatureError as e:
            raise AuthenticationError(ErrorMessages.ERROR002) from e
        except MissingRequiredClaimError as e:
            # Deliberately reported with the same message as an expired token.
            raise AuthenticationError(ErrorMessages.ERROR002) from e
        except Exception as e:
            print(f"Exception while validating JWT: {str(e)}")
            raise

    def decode(self, token):
        """Verify *token* with the public key and return its decoded payload.

        Unlike ``validate``, no claims are required, no leeway is applied, and
        errors are not translated into ``AuthenticationError``. Any error is
        printed and re-raised unchanged.
        """
        try:
            key = self._read_key(self.public)
            return jwt.decode(token, key, algorithms=self._allowed_algorithms())
        except Exception as e:
            print(f"Exception while decoding JWT: {str(e)}")
            raise
