Commit 1fa8df91 authored by harshavardhan.c

dev: Initial Project setup.

parent e45d7531
@@ -2,4 +2,4 @@ GRAPH_HOST=192.168.0.220
GRAPH_PORT=7687
GRAPH_USERNAME=neo4j
GRAPH_PASSWORD=root
REDIS_URI=redis://192.168.0.220:6379
\ No newline at end of file
FROM python:3.9.10-slim
COPY requirements.txt /code/requirements.txt
WORKDIR /code
RUN pip install -r requirements.txt
RUN apt update && apt install curl -y
COPY . /code
CMD [ "python", "app.py" ]
\ No newline at end of file
if __name__ == '__main__':
from dotenv import load_dotenv
load_dotenv()
from scripts.core.engine import GraphTraversal
from scripts.db import get_db
from scripts.schemas import GraphData, GetNodeInfo
def ingest_data_handler(graph_data: GraphData):
try:
node_list = []
relations = []
for node_type, node_obj in graph_data.__root__.items():
db = get_db()
print(node_list)
except Exception as e:
print(e.args)
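# Sample payload: four "Events" nodes for project_099. Node1 has no edges; Node2-Node4
# carry edges that bind them back to earlier nodes (e.g. a "Causes" relation from Node2 to Node1).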
input_data = {
"Node1": {
"node_id": "event_1",
"action": "add",
"node_name": "Event 1",
"project_id": "project_099",
"node_type": "Events",
"properties": {
"name": "Event 1",
"external_data_source": "mongo",
"external_data_id": "101",
"event_info": "847263",
"received_date": 643308200,
"completed_date": 1643394600
},
"edges": [],
"tz": "Asia/Kolkata"
},
"Node2": {
"node_id": "event_2",
"project_id": "project_099",
"action": "add",
"node_name": "Event 2",
"node_type": "Events,Harsha",
"properties": {
"name": "Event 2",
"external_data_source": "mongo",
"external_data_id": "101",
"event_info": "847263",
"received_date": 643308200,
"completed_date": 1643394600
},
"edges": [{
"action": "add",
"rel_name": "Causes",
"new_rel_name": "",
"bind_to": "Node1",
"bind_id": "BSCH270120022"
}]
},
"Node3": {
"node_id": "event_3",
"action": "add",
"node_name": "Event 3",
"project_id": "project_099",
"node_type": "Events",
"properties": {
"name": "Event 3",
"external_data_source": "mongo",
"external_data_id": "101",
"event_info": "847263",
"received_date": 643308200,
"completed_date": 1643394600
},
"edges": [{
"action": "add",
"rel_name": "Error",
"new_rel_name": "",
"bind_to": "Node1",
"bind_id": "BSCH270120022"
}, {
"action": "add",
"rel_name": "Error2",
"new_rel_name": "janu_s_UI",
"bind_to": "Node2",
"bind_id": "BSCH270120022"
}],
"tz": "Asia/Kolkata"
}, "Node4": {
"node_id": "event_4",
"action": "add",
"node_name": "Event 4",
"project_id": "project_099",
"node_type": "Events",
"properties": {
"name": "Event 4",
"external_data_source": "mongo",
"external_data_id": "101",
"event_info": "847263",
"received_date": 643308200,
"completed_date": 1643394600
},
"edges": [{
"action": "add",
"rel_name": "Causes",
"new_rel_name": "",
"bind_to": "Node1",
"bind_id": "BSCH270120022"
}],
"tz": "Asia/Kolkata"
}
}
gt_obj = GraphTraversal(db=get_db())
GraphData(__root__=input_data)
# gt_obj.ingest_data_handler(graph_data=GraphData(__root__=input_data))
print(gt_obj.fetch_node_data(graph_request=GetNodeInfo(project_id="project_099",
node_id="event_1")))
if __name__ == '__main__':
from dotenv import load_dotenv
load_dotenv()
import argparse
import gc
import uvicorn
from scripts.config import Service
from scripts.logging.logging import logger
gc.collect()
ap = argparse.ArgumentParser()
if __name__ == "__main__":
ap.add_argument(
"--port",
"-p",
required=False,
default=Service.PORT,
help="Port to start the application.",
)
ap.add_argument(
"--bind",
"-b",
required=False,
default=Service.HOST,
help="IP to start the application.",
)
arguments = vars(ap.parse_args())
logger.info(f"App Starting at {arguments['bind']}:{arguments['port']}")
uvicorn.run("main:app", host=arguments["bind"], port=int(arguments["port"]))
[MODULE]
name = graph-management
[SERVICE]
port = 3973
host = 0.0.0.0
port=3973
host=0.0.0.0
[GRAPH_DB]
GRAPH_HOST=$GRAPH_HOST
@@ -12,3 +16,14 @@ DB_TYPE=$DB_TYPE
[LOGGING]
level=$LOG_LEVEL
traceback=true
[REDIS]
uri=$REDIS_URI
login_db = 9
project_tags_db = 18
user_role_permissions=21
[DIRECTORY]
base_path = $BASE_PATH
mount_dir = $MOUNT_DIR
keys_path = data/keys
\ No newline at end of file
if __name__ == '__main__':
from dotenv import load_dotenv
import os
load_dotenv()
from scripts.core.engine import GraphTraversal
from scripts.db import get_db
from fastapi import FastAPI, Depends
from fastapi.middleware.cors import CORSMiddleware
from jwt_signature_validator.encoded_payload import (
EncodedPayloadSignatureMiddleware as SignatureVerificationMiddleware
)
from scripts.schemas import GraphData
from scripts.config import Service
from scripts.constants import Secrets
from scripts.services import service_router
from scripts.utils.security_utils.decorators import CookieAuthentication
auth = CookieAuthentication()
secure_access = os.environ.get("SECURE_ACCESS", default=False)
def ingest_data_handler(graph_data: GraphData):
try:
node_list = []
relations = []
for node_type, node_obj in graph_data.__root__.items():
db = get_db()
app = FastAPI(
title="GraphDB Management",
version="1.0.0",
description="Graph Management App",
openapi_url=os.environ.get("SW_OPENAPI_URL"),
docs_url=os.environ.get("SW_DOCS_URL"),
redoc_url=None,
root_path="/rel_mnt"
)
print(node_list)
except Exception as e:
print(e.args)
if Service.verify_signature in [True, 'True', 'true']:
app.add_middleware(
SignatureVerificationMiddleware,
jwt_secret=Secrets.signature_key,
jwt_algorithms=Secrets.signature_key_alg,
protect_hosts=Service.protected_hosts,
)
origins_list = os.environ.get("CORS_URLS", default="")
origins_list = origins_list.split(',') if origins_list else ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=origins_list,
allow_credentials=True,
allow_methods=["GET", "POST", "DELETE", "PUT"],
allow_headers=["*"],
)
input_data = {
"Node1": {
"node_id": "event_001",
"action": "add",
"node_name": "Event 20",
"project_id": "project_099",
"node_type": "Events",
"properties": {
"name": "Event 1",
"external_data_source": "mongo",
"external_data_id": "101",
"event_info": "847263",
"received_date": 643308200,
"completed_date": 1643394600
},
"edges": [],
"tz": "Asia/Kolkata"
},
"Node10": {
"node_id": "event_10",
"project_id": "project_099",
"action": "add",
"node_name": "Event Name",
"node_type": "Events,Harsha",
"properties": {
"name": "Event 2",
"external_data_source": "mongo",
"external_data_id": "101",
"event_info": "847263",
"received_date": 643308200,
"completed_date": 1643394600
},
"edges": [{
"action": "add",
"rel_name": "madhuri_don_boscho",
"new_rel_name": "janu_s_UI",
"bind_to": "Node1",
"bind_id": "BSCH270120022"
}]
}
}
gt_obj = GraphTraversal(db=get_db())
GraphData(__root__=input_data)
gt_obj.ingest_data_handler(graph_data=GraphData(__root__=input_data))
@app.get(f"/api/{Service.MODULE_NAME}/healthcheck")
def ping():
return {"status": 200}
auth_enabled = [Depends(auth)] if secure_access in [True, 'true', 'True'] else []
app.include_router(service_router, dependencies=auth_enabled)
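A quick way to exercise the healthcheck route above is FastAPI's TestClient. This is a minimal sketch under two assumptions: the environment variables the config expects are available, and the app object is importable from main (i.e. defined at module scope rather than behind the __main__ guard shown earlier).

from dotenv import load_dotenv
load_dotenv()

from fastapi.testclient import TestClient
from main import app  # assumption: the application module is importable as main

client = TestClient(app)
response = client.get("/api/graph-management/healthcheck")  # "graph-management" is the [MODULE] name in the service conf
assert response.json() == {"status": 200}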
python-dotenv~=0.19.2
SQLAlchemy==1.4.35
GQLAlchemy
psycopg2-binary==2.9.3
fastapi~=0.74.1
pytz~=2021.3
PyYAML~=6.0
@@ -11,5 +9,11 @@ pymongo==3.7.2
ilens-kafka-publisher==0.4.2
kafka-python==1.4.7
faust==1.10.4
SQLAlchemy-Utils==0.38.2
uvicorn[standard]~=0.18.2
cryptography>=3.3.1
pendulum==2.1.2
jwt_signature_validator==0.0.5
crypto~=1.4.1
pycryptodomex==3.9.9
pycryptodome==3.9.9
shortuuid==1.0.8
\ No newline at end of file
import os
import shutil
import sys
from configparser import BasicInterpolation, ConfigParser
@@ -28,8 +29,11 @@ except Exception as e:
class Service:
port = config["SERVICE"]["port"]
host = config["SERVICE"]["host"]
MODULE_NAME = config["MODULE"]["name"]
PORT = config["SERVICE"]["port"]
HOST = config["SERVICE"]["host"]
verify_signature = os.environ.get("VERIFY_SIGNATURE", False)
protected_hosts = os.environ.get("PROTECTED_HOSTS", "").split(",")
class DBConf:
@@ -47,3 +51,22 @@ class Logging:
level = level or "INFO"
tb_flag = config.getboolean("LOGGING", "traceback", fallback=True)
tb_flag = tb_flag if tb_flag is not None else True
class RedisConfig(object):
uri = config.get("REDIS", "uri")
login_db = config["REDIS"]["login_db"]
project_tags_db = config.getint("REDIS", "project_tags_db")
user_role_permissions = config.getint("REDIS", "user_role_permissions")
class KeyPath(object):
keys_path = config['DIRECTORY']['keys_path']
if not os.path.isfile(os.path.join(keys_path, "public")) or not os.path.isfile(
os.path.join(keys_path, "private")):
if not os.path.exists(keys_path):
os.makedirs(keys_path)
shutil.copy(os.path.join("assets", "keys", "public"), os.path.join(keys_path, "public"))
shutil.copy(os.path.join("assets", "keys", "private"), os.path.join(keys_path, "private"))
public = os.path.join(keys_path, "public")
private = os.path.join(keys_path, "private")
@@ -8,3 +8,14 @@ class APIEndPoints:
api_create = '/create'
graph_traverse = "/traverse"
ingest_graph_data = "/ingest"
class Secrets:
LOCK_OUT_TIME_MINS = 30
leeway_in_mins = 10
unique_key = '45c37939-0f75'
token = '8674cd1d-2578-4a62-8ab7-d3ee5f9a'
issuer = "ilens"
alg = "RS256"
signature_key = 'kliLensKLiLensKL'
signature_key_alg = ["HS256"]
@@ -62,9 +62,9 @@ class GraphTraversal:
raise
def fetch_node_data(self, graph_request: GetNodeInfo):
-return_data = ResponseModelSchema(nodes=[], links=[])
+return_data = ResponseModelSchema(series_data=dict(nodes=[], links=[]))
try:
-existing_data = self.graph_util.get_connecting_nodes_info(input_data=graph_request.dict())
+existing_data = self.graph_util.get_connecting_nodes_info(input_data=graph_request)
existing_node_info = []
for k, v in existing_data.items():
for _item in v:
@@ -76,9 +76,9 @@
ui_dict["linkName"] = ui_dict.pop("_type")
return_data.series_data["links"].append(ui_dict)
continue
node_id = node_info.get("id")
node_id = node_info.get("node_id")
unique_id = node_info.get("_id")
-if not node_id or node_id in unique_id:
+if not node_id or unique_id in existing_node_info:
continue
existing_node_info.append(unique_id)
ui_dict.update({"x": '', "y": ''})
......
import redis
from scripts.config import RedisConfig
login_db = redis.from_url(RedisConfig.uri, db=int(RedisConfig.login_db), decode_responses=True)
project_details_db = redis.from_url(RedisConfig.uri, db=int(RedisConfig.project_tags_db), decode_responses=True)
user_role_permissions_redis = redis.from_url(
RedisConfig.uri, db=int(RedisConfig.user_role_permissions), decode_responses=True
)
class InternalError(Exception):
pass
class UnauthorizedError(Exception):
pass
class ProjectIdError(Exception):
pass
class ILensPermissionError(Exception):
pass
class DuplicateTemplateNameError(Exception):
pass
class DuplicateWorkflowNameError(Exception):
pass
class ImplementationError(Exception):
pass
class DuplicatestepNameError(Exception):
pass
class StepNotFound(Exception):
pass
class CustomError(Exception):
pass
class DatetimeMismatchError(Exception):
pass
class DuplicateLogbookNameError(Exception):
pass
class FileExceptions(Exception):
pass
class FileFormatNotSupported(FileExceptions):
pass
class RequiredColumnsMissingException(Exception):
pass
class FileNotFoundException(Exception):
pass
class DisplayFieldsException(Exception):
pass
class TaskCreationLimitExceeded(Exception):
pass
class ILensErrors(Exception):
    """
    Base Error Class
    """

    def __init__(self, msg):
        Exception.__init__(self, msg)
class ErrorCodes:
ERR001 = "ERR001 - Operating Time is greater than Planned Time"
ERR002 = "ERR002 - Zero Values are not allowed"
ERR003 = "ERR003 - Operating Time is less than Productive Time"
ERR004 = "ERR004 - Rejected Units is greater than Total Units"
class DowntimeResponseError(ILensErrors):
"""
Error Occurred during fetch of downtime
"""
class AuthenticationError(ILensErrors):
"""
JWT Authentication Error
"""
class BulkUploadError(ILensErrors):
"""
Bulk Upload Custom errors
"""
class CustomILensError(ILensErrors):
"""
Bulk Upload Custom errors
"""
class NoPropertiesFound(ILensErrors):
"""
    No properties found for the field even though the disable key is false
"""
class ErrorMessages:
ERROR001 = "Authentication Failed. Please verify token"
ERROR002 = "Signature Expired"
ERROR003 = "Signature Not Valid"
@@ -87,8 +87,9 @@ class GraphData(BaseModel):
class GetNodeInfo(BaseModel):
label: Optional[str] = "Events"
project_id: str
-id: str
+node_id: str
class ResponseModelSchema(BaseModel):
......
from fastapi import APIRouter
from scripts.services.graph_service import graph_router
service_router = APIRouter()
service_router.include_router(graph_router)
\ No newline at end of file
@@ -9,10 +9,10 @@ from scripts.logging import logger
from scripts.schemas import GraphData, GetNodeInfo
from scripts.schemas.responses import DefaultFailureResponse, DefaultResponse
-router = APIRouter(prefix=APIEndPoints.graph_base, tags=["Graph Traversal"])
+graph_router = APIRouter(prefix=APIEndPoints.graph_base, tags=["Graph Traversal"])
-@router.post(
+@graph_router.post(
APIEndPoints.ingest_graph_data
)
def ingest_data_service(request_data: GraphData, db=Depends(get_db)):
@@ -27,7 +27,7 @@ def ingest_data_service(request_data: GraphData, db=Depends(get_db)):
return DefaultFailureResponse(error=e.args).dict()
-@router.post(
+@graph_router.post(
APIEndPoints.api_graph_link
)
def ingest_data_service(request_data: GetNodeInfo, db=Depends(get_db)):
......
@@ -8,9 +8,11 @@ class CommonUtils:
...
@staticmethod
-def convert_dict_str_format(input_dict):
+def convert_dict_str_format(input_dict: dict, exclude_keys: str = "label"):
for k in exclude_keys.split(","):
input_dict.pop(k, None)
return_str = json.dumps(input_dict).replace('"', "'")
-for each_key in input_dict.keys():
+for each_key in input_dict:
return_str = return_str.replace(f"'{each_key}'", each_key)
return return_str.replace('{', '').replace('}', '')
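# Worked example (illustrative):
#   convert_dict_str_format({"label": "Events", "project_id": "project_099", "node_id": "event_1"})
# pops the excluded "label" key and returns the Cypher-style property fragment
#   project_id: 'project_099', node_id: 'event_1'
# which get_connecting_nodes_info interpolates into its MATCH query further below.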
......
import json
from functools import lru_cache
@lru_cache()
def get_db_name(redis_client, project_id: str, database: str, delimiter="__"):
if not project_id:
return database
val = redis_client.get(project_id)
if val is None:
raise ValueError(
f"Unknown Project, Project ID: {project_id} Not Found!!!")
val = json.loads(val)
if not val:
return database
# Get the prefix flag to apply project_id prefix to any db
prefix_condition = bool(
val.get("source_meta", {}).get("add_prefix_to_database"))
if prefix_condition:
# Get the prefix name from mongo or default to project_id
prefix_name = val.get("source_meta", {}).get("prefix") or project_id
return f"{prefix_name}{delimiter}{database}"
return database
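For reference, a minimal sketch of how get_db_name resolves a prefixed database name; the fake Redis client, project record, and database name below are illustrative assumptions, not part of the repository. Note that lru_cache memoises the result per argument combination.

class _FakeRedis:
    """Stands in for a real redis client and returns a canned project record."""
    def __init__(self, mapping):
        self._mapping = {k: json.dumps(v) for k, v in mapping.items()}

    def get(self, key):
        return self._mapping.get(key)

fake_client = _FakeRedis({
    "project_099": {"source_meta": {"add_prefix_to_database": True, "prefix": None}}
})

# "prefix" is empty, so project_id itself becomes the prefix:
# -> "project_099__ilens_configuration"
print(get_db_name(fake_client, "project_099", "ilens_configuration"))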
from typing import List
from gqlalchemy import GQLAlchemyError
from gqlalchemy.query_builders import neo4j_query_builder
from scripts.db import get_db
from scripts.db.graphdb.graph_query import QueryFormation
from scripts.db.graphdb.neo4j import Neo4jHandler
from scripts.db.models import RelationShipMapper, NodePropertiesSchema
from scripts.logging import logger
from scripts.schemas import GetNodeInfo
from scripts.utils.common_utils import CommonUtils
@@ -78,9 +77,10 @@ class GraphUtility:
logger.exception(f'Exception Occurred while fetching the relation details -> {e.args}')
raise
-def get_connecting_nodes_info(self, input_data, label='NodeCreationSchema'):
+def get_connecting_nodes_info(self, input_data: GetNodeInfo):
try:
query = f"MATCH (a: {label}{{{self.common_util.convert_dict_str_format(input_dict=input_data)}}})-[r]-(b) RETURN r,a,b;"
query_dict = input_data.dict()
query = f"MATCH (a: {input_data.label}{{{self.common_util.convert_dict_str_format(input_dict=query_dict)}}})-[r]-(b) RETURN r,a,b;"
return self.common_util.process_generator_result(self.db.execute_and_fetch(query=query))
except GQLAlchemyError as e:
logger.debug(f'{e.args}')
......
import base64
from Cryptodome import Random
from Cryptodome.Cipher import AES
class AESCipher:
"""
A classical AES Cipher. Can use any size of data and any size of password thanks to padding.
Also ensure the coherence and the type of the data with a unicode to byte converter.
"""
def __init__(self, key):
self.bs = 16
self.key = AESCipher.str_to_bytes(key)
@staticmethod
def str_to_bytes(data):
u_type = type(b''.decode('utf8'))
if isinstance(data, u_type):
return data.encode('utf8')
return data
def _pad(self, s):
return s + (self.bs - len(s) % self.bs) * AESCipher.str_to_bytes(chr(self.bs - len(s) % self.bs))
@staticmethod
def _unpad(s):
return s[:-ord(s[len(s) - 1:])]
def encrypt(self, raw):
raw = self._pad(AESCipher.str_to_bytes(raw))
iv = Random.new().read(AES.block_size)
cipher = AES.new(self.key, AES.MODE_CBC, iv)
return base64.b64encode(iv + cipher.encrypt(raw)).decode('utf-8')
def decrypt(self, enc):
enc = base64.b64decode(enc)
iv = enc[:AES.block_size]
cipher = AES.new(self.key, AES.MODE_CBC, iv)
data = self._unpad(cipher.decrypt(enc[AES.block_size:]))
return data.decode('utf-8')
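A minimal round-trip sketch of the cipher above; the key string is illustrative and must be 16, 24, or 32 bytes long for AES.

cipher = AESCipher("0123456789abcdef")  # illustrative 16-byte key (AES-128)
token = cipher.encrypt("hello graph")   # base64(iv + ciphertext)
assert cipher.decrypt(token) == "hello graph"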
import uuid
from datetime import timedelta, datetime, timezone
from scripts.constants import Secrets
from scripts.db.redis_connections import login_db
from scripts.utils.security_utils.jwt_util import JWT
jwt = JWT()
def create_token(user_id, ip, token, age=Secrets.LOCK_OUT_TIME_MINS, login_token=None, project_id=None):
"""
This method is to create a cookie
"""
try:
uid = login_token
if not uid:
uid = str(uuid.uuid4()).replace("-", "")
payload = {
"ip": ip,
"user_id": user_id,
"token": token,
"uid": uid,
"age": age
}
if project_id:
payload["project_id"] = project_id
exp = datetime.now(timezone.utc) + timedelta(minutes=age)
_extras = {"iss": Secrets.issuer, "exp": exp}
_payload = payload | _extras
new_token = jwt.encode(_payload)
# Add session to redis
login_db.set(uid, new_token)
login_db.expire(uid, timedelta(minutes=age))
return uid
except Exception:
raise
from secrets import compare_digest
from typing import Optional
from fastapi import Response, Request, HTTPException, status
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security import APIKeyCookie
from fastapi.security.api_key import APIKeyBase
from pydantic import BaseModel, Field
from scripts.config import Service
from scripts.constants import Secrets
from scripts.db.redis_connections import login_db
from scripts.logging.logging import logger
from scripts.utils.security_utils.apply_encrytion_util import create_token
from scripts.utils.security_utils.jwt_util import JWT
class CookieAuthentication(APIKeyBase):
"""
Authentication backend using a cookie.
Internally, uses a JWT token to store the data.
"""
scheme: APIKeyCookie
cookie_name: str
cookie_secure: bool
def __init__(
self,
cookie_name: str = "login-token",
):
super().__init__()
self.model: APIKey = APIKey(**{"in": APIKeyIn.cookie}, name=cookie_name)
self.scheme_name = self.__class__.__name__
self.cookie_name = cookie_name
self.scheme = APIKeyCookie(name=self.cookie_name, auto_error=False)
self.login_redis = login_db
self.jwt = JWT()
async def __call__(self, request: Request, response: Response) -> str:
cookies = request.cookies
login_token = cookies.get("login-token")
if not login_token:
login_token = request.headers.get("login-token")
if not login_token:
raise HTTPException(status_code=401)
jwt_token = self.login_redis.get(login_token)
if not jwt_token:
raise HTTPException(status_code=401)
try:
decoded_token = self.jwt.validate(token=jwt_token)
if not decoded_token:
raise HTTPException(status_code=401)
except Exception as e:
raise HTTPException(status_code=401, detail=e.args)
user_id = decoded_token.get("user_id")
project_id = decoded_token.get("project_id")
_token = decoded_token.get("token")
_age = int(decoded_token.get("age", Secrets.LOCK_OUT_TIME_MINS))
if any(
[
not compare_digest(Secrets.token, _token),
login_token != decoded_token.get("uid"),
]
):
raise HTTPException(status_code=401)
request.cookies.update(
{"user_id": user_id, "project_id": project_id, "projectId": project_id, "userId": user_id}
)
try:
new_token = create_token(
user_id=user_id,
ip=request.client.host,
token=Secrets.token,
age=_age,
login_token=login_token,
project_id=project_id,
)
except Exception as e:
raise HTTPException(status_code=401, detail=e.args)
response.set_cookie(
"login-token",
new_token,
samesite="strict",
httponly=True,
secure=Service.secure_cookie,
max_age=Secrets.LOCK_OUT_TIME_MINS * 60,
)
        # If the project ID is missing from the token, downstream handlers are susceptible to 500 errors.
        # Ensure token creation embeds the project ID in the login token.
if not project_id:
logger.info("Project ID not found in Old token. Soon to be deprecated. Proceeding for now")
response.headers.update({"login-token": new_token, "userId": user_id, "user_id": user_id})
return user_id
response.headers.update(
{
"login-token": new_token,
"projectId": project_id,
"project_id": project_id,
"userId": user_id,
"user_id": user_id,
}
)
return user_id
class MetaInfoSchema(BaseModel):
projectId: Optional[str] = ""
project_id: Optional[str] = ""
user_id: Optional[str] = ""
language: Optional[str] = ""
ip_address: Optional[str] = ""
login_token: Optional[str] = Field(alias="login-token")
class Config:
allow_population_by_field_name = True
class MetaInfoCookie(APIKeyBase):
"""
Project ID backend using a cookie.
"""
scheme: APIKeyCookie
def __init__(self):
super().__init__()
self.model: APIKey = APIKey(**{"in": APIKeyIn.cookie}, name="meta")
self.scheme_name = self.__class__.__name__
def __call__(self, request: Request, response: Response):
cookies = request.cookies
cookie_json = {
"projectId": cookies.get("projectId", request.headers.get("projectId")),
"userId": cookies.get("user_id", cookies.get("userId", request.headers.get("userId"))),
"language": cookies.get("language", request.headers.get("language")),
}
return MetaInfoSchema(
project_id=cookie_json["projectId"],
user_id=cookie_json["userId"],
projectId=cookie_json["projectId"],
language=cookie_json["language"],
ip_address=request.client.host,
login_token=cookies.get("login-token"),
)
class GetUserID(APIKeyBase):
"""
Project ID backend using a cookie.
"""
scheme: APIKeyCookie
def __init__(self):
super().__init__()
self.model: APIKey = APIKey(**{"in": APIKeyIn.cookie}, name="user_id")
self.scheme_name = self.__class__.__name__
def __call__(self, request: Request, response: Response):
if user_id := request.cookies.get("user_id", request.cookies.get("userId", request.headers.get("userId"))):
return user_id
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
import jwt
from jwt.exceptions import (
InvalidSignatureError,
ExpiredSignatureError,
MissingRequiredClaimError,
)
from scripts.config import KeyPath
from scripts.constants import Secrets
from scripts.errors import AuthenticationError, ErrorMessages
from scripts.logging.logging import logger
class JWT:
def __init__(self):
self.max_login_age = Secrets.LOCK_OUT_TIME_MINS
self.issuer = Secrets.issuer
self.alg = Secrets.alg
self.public = KeyPath.public
self.private = KeyPath.private
def encode(self, payload):
try:
with open(self.private, "r") as f:
key = f.read()
return jwt.encode(payload, key, algorithm=self.alg)
except Exception as e:
logger.exception(f'Exception while encoding JWT: {str(e)}')
raise
def validate(self, token):
try:
with open(self.public, "r") as f:
key = f.read()
payload = jwt.decode(
token,
key,
algorithms=self.alg,
leeway=Secrets.leeway_in_mins,
options={"require": ["exp", "iss"]},
)
return payload
except InvalidSignatureError:
raise AuthenticationError(ErrorMessages.ERROR003)
except ExpiredSignatureError:
raise AuthenticationError(ErrorMessages.ERROR002)
except MissingRequiredClaimError:
raise AuthenticationError(ErrorMessages.ERROR002)
except Exception as e:
logger.exception(f'Exception while validating JWT: {str(e)}')
raise
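For reference, a hedged round-trip sketch of the JWT helper above; it assumes the RSA key pair is already in place under the configured keys_path (see KeyPath in scripts/config.py), and the user_id claim is illustrative. One caveat worth noting: PyJWT interprets a numeric leeway as seconds, so leeway_in_mins=10 currently allows ten seconds of clock skew rather than ten minutes.

from datetime import datetime, timedelta, timezone

from scripts.constants import Secrets
from scripts.utils.security_utils.jwt_util import JWT

jwt_util = JWT()
token = jwt_util.encode({
    "user_id": "user_001",  # illustrative claim
    "iss": Secrets.issuer,
    "exp": datetime.now(timezone.utc) + timedelta(minutes=Secrets.LOCK_OUT_TIME_MINS),
})
claims = jwt_util.validate(token=token)  # raises AuthenticationError on invalid or expired signatures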
from typing import Optional
from fastapi import Response, Request
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.api_key import APIKeyBase, APIKeyCookie
from pydantic import BaseModel
class MetaInfoSchema(BaseModel):
project_id: Optional[str] = ""
user_id: Optional[str] = ""
language: Optional[str] = ""
class MetaInfoCookie(APIKeyBase):
"""
Project ID backend using a cookie.
"""
scheme: APIKeyCookie
cookie_name: str
def __init__(self, cookie_name: str = "projectId"):
super().__init__()
self.model: APIKey = APIKey(**{"in": APIKeyIn.cookie}, name=cookie_name)
self.cookie_name = cookie_name
self.scheme_name = self.__class__.__name__
self.scheme = APIKeyCookie(name=self.cookie_name, auto_error=False)
async def __call__(self, request: Request, response: Response):
cookies = request.cookies
cookie_json = {
"projectId": cookies.get("projectId", request.headers.get("projectId")),
"userId": cookies.get("user_id", cookies.get("userId", request.headers.get("userId"))),
"language": cookies.get("language", request.headers.get("language"))
}
return MetaInfoSchema(project_id=cookie_json["projectId"], user_id=cookie_json["userId"],
language=cookie_json["language"])
@staticmethod
def set_response_info(cookie_name, cookie_value, response: Response):
response.set_cookie(
cookie_name,
cookie_value,
samesite="strict",
httponly=True
)
response.headers[cookie_name] = cookie_value
import logging
import orjson as json
from fastapi import HTTPException, Request, status
from scripts.db.mongo import mongo_client
from scripts.db.mongo.ilens_configuration.collections.user import User
from scripts.db.mongo.ilens_configuration.collections.user_project import UserProject
from scripts.db.redis_connections import user_role_permissions_redis
from scripts.utils.common_utils import timed_lru_cache
@timed_lru_cache(seconds=60, maxsize=1000)
def get_user_role_id(user_id, project_id):
logging.debug("Fetching user role from DB")
user_conn = User(mongo_client=mongo_client)
if user_role := user_conn.find_user_role_for_user_id(user_id=user_id, project_id=project_id):
return user_role["userrole"][0]
# if user not found in primary collection, check if user is in project collection
user_proj_conn = UserProject(mongo_client=mongo_client)
if user_role := user_proj_conn.find_user_role_for_user_id(user_id=user_id, project_id=project_id):
return user_role["userrole"][0]
class RBAC:
def __init__(self, entity_name: str, operation: list[str]):
self.entity_name = entity_name
self.operation = operation
def check_permissions(self, user_id: str, project_id: str) -> dict[str, bool]:
user_role_id = get_user_role_id(user_id, project_id)
if not user_role_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User role not found!")
r_key = f"{project_id}__{user_role_id}" # eg: project_100__user_role_100
user_role_rec = user_role_permissions_redis.hget(r_key, self.entity_name)
if not user_role_rec:
return {} # TODO: raise exception here
user_role_rec = json.loads(user_role_rec)
if permission_dict := {i: True for i in self.operation if user_role_rec.get(i)}:
return permission_dict
else:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Insufficient Permission!")
def __call__(self, request: Request) -> dict[str, bool]:
user_id = request.cookies.get("userId", request.headers.get("userId"))
project_id = request.cookies.get("projectId", request.headers.get("projectId"))
return self.check_permissions(user_id=user_id, project_id=project_id)
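Because __call__ only needs the incoming Request, an RBAC instance can be used directly as a FastAPI dependency. A minimal sketch; the route path, entity name, and operation list are illustrative, not taken from this repository.

from fastapi import APIRouter, Depends

rbac_router = APIRouter()

@rbac_router.get(
    "/nodes",  # illustrative route
    dependencies=[Depends(RBAC(entity_name="graph", operation=["view"]))],
)
def list_nodes():
    # Reached only when check_permissions() finds the "view" flag for the caller's role;
    # otherwise the dependency raises HTTP 403.
    return {"status": 200}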