Commit b640a3eb authored by khusraj.jain's avatar khusraj.jain 🎯

Merge branch 'master' into 'master'

# Conflicts:
#   Dockerfile
parents faa8b4c6 1fa8df91
...@@ -2,4 +2,4 @@ GRAPH_HOST=192.168.0.220 ...@@ -2,4 +2,4 @@ GRAPH_HOST=192.168.0.220
GRAPH_PORT=7687 GRAPH_PORT=7687
GRAPH_USERNAME=neo4j GRAPH_USERNAME=neo4j
GRAPH_PASSWORD=root GRAPH_PASSWORD=root
REDIS_URI=redis://192.168.0.220:6379
\ No newline at end of file
if __name__ == '__main__':
from dotenv import load_dotenv
load_dotenv()
from scripts.core.engine import GraphTraversal
from scripts.db import get_db
from scripts.schemas import GraphData, GetNodeInfo
def ingest_data_handler(graph_data: GraphData):
try:
node_list = []
relations = []
for node_type, node_obj in graph_data.__root__.items():
db = get_db()
print(node_list)
except Exception as e:
print(e.args)
input_data = {
"Node1": {
"node_id": "event_1",
"action": "add",
"node_name": "Event 1",
"project_id": "project_099",
"node_type": "Events",
"properties": {
"name": "Event 1",
"external_data_source": "mongo",
"external_data_id": "101",
"event_info": "847263",
"received_date": 643308200,
"completed_date": 1643394600
},
"edges": [],
"tz": "Asia/Kolkata"
},
"Node2": {
"node_id": "event_2",
"project_id": "project_099",
"action": "add",
"node_name": "Event 2",
"node_type": "Events,Harsha",
"properties": {
"name": "Event 2",
"external_data_source": "mongo",
"external_data_id": "101",
"event_info": "847263",
"received_date": 643308200,
"completed_date": 1643394600
},
"edges": [{
"action": "add",
"rel_name": "Causes",
"new_rel_name": "",
"bind_to": "Node1",
"bind_id": "BSCH270120022"
}]
},
"Node3": {
"node_id": "event_3",
"action": "add",
"node_name": "Event 3",
"project_id": "project_099",
"node_type": "Events",
"properties": {
"name": "Event 3",
"external_data_source": "mongo",
"external_data_id": "101",
"event_info": "847263",
"received_date": 643308200,
"completed_date": 1643394600
},
"edges": [{
"action": "add",
"rel_name": "Error",
"new_rel_name": "",
"bind_to": "Node1",
"bind_id": "BSCH270120022"
}, {
"action": "add",
"rel_name": "Error2",
"new_rel_name": "janu_s_UI",
"bind_to": "Node2",
"bind_id": "BSCH270120022"
}],
"tz": "Asia/Kolkata"
}, "Node4": {
"node_id": "event_4",
"action": "add",
"node_name": "Event 4",
"project_id": "project_099",
"node_type": "Events",
"properties": {
"name": "Event 4",
"external_data_source": "mongo",
"external_data_id": "101",
"event_info": "847263",
"received_date": 643308200,
"completed_date": 1643394600
},
"edges": [{
"action": "add",
"rel_name": "Causes",
"new_rel_name": "",
"bind_to": "Node1",
"bind_id": "BSCH270120022"
}],
"tz": "Asia/Kolkata"
}
}
gt_obj = GraphTraversal(db=get_db())
GraphData(__root__=input_data)
# gt_obj.ingest_data_handler(graph_data=GraphData(__root__=input_data))
print(gt_obj.fetch_node_data(graph_request=GetNodeInfo(project_id="project_099",
node_id="event_1")))
if __name__ == '__main__':
from dotenv import load_dotenv
load_dotenv()
import argparse
import gc
import uvicorn
from scripts.config import Service
from scripts.logging.logging import logger
gc.collect()
ap = argparse.ArgumentParser()
if __name__ == "__main__":
ap.add_argument(
"--port",
"-p",
required=False,
default=Service.PORT,
help="Port to start the application.",
)
ap.add_argument(
"--bind",
"-b",
required=False,
default=Service.HOST,
help="IP to start the application.",
)
arguments = vars(ap.parse_args())
logger.info(f"App Starting at {arguments['bind']}:{arguments['port']}")
uvicorn.run("main:app", host=arguments["bind"], port=int(arguments["port"]))
[MODULE]
name = graph-management
[SERVICE] [SERVICE]
port = 3973 port=3973
host = 0.0.0.0 host=0.0.0.0
[GRAPH_DB] [GRAPH_DB]
GRAPH_HOST=$GRAPH_HOST GRAPH_HOST=$GRAPH_HOST
...@@ -11,4 +15,15 @@ DB_TYPE=$DB_TYPE ...@@ -11,4 +15,15 @@ DB_TYPE=$DB_TYPE
[LOGGING] [LOGGING]
level=$LOG_LEVEL level=$LOG_LEVEL
traceback=true traceback=true
\ No newline at end of file
[REDIS]
uri=$REDIS_URI
login_db = 9
project_tags_db = 18
user_role_permissions=21
[DIRECTORY]
base_path = $BASE_PATH
mount_dir = $MOUNT_DIR
keys_path = data/keys
\ No newline at end of file
if __name__ == '__main__': import os
from dotenv import load_dotenv
load_dotenv() from fastapi import FastAPI, Depends
from scripts.core.engine import GraphTraversal from fastapi.middleware.cors import CORSMiddleware
from scripts.db import get_db from jwt_signature_validator.encoded_payload import (
EncodedPayloadSignatureMiddleware as SignatureVerificationMiddleware
)
from scripts.schemas import GraphData from scripts.config import Service
from scripts.constants import Secrets
from scripts.services import service_router
from scripts.utils.security_utils.decorators import CookieAuthentication
auth = CookieAuthentication()
secure_access = os.environ.get("SECURE_ACCESS", default=False)
def ingest_data_handler(graph_data: GraphData): app = FastAPI(
try: title="GraphDB Management",
node_list = [] version="1.0.0",
relations = [] description="Graph Management App",
for node_type, node_obj in graph_data.__root__.items(): openapi_url=os.environ.get("SW_OPENAPI_URL"),
db = get_db() docs_url=os.environ.get("SW_DOCS_URL"),
redoc_url=None,
root_path="/rel_mnt"
)
print(node_list) if Service.verify_signature in [True, 'True', 'true']:
except Exception as e: app.add_middleware(
print(e.args) SignatureVerificationMiddleware,
jwt_secret=Secrets.signature_key,
jwt_algorithms=Secrets.signature_key_alg,
protect_hosts=Service.protected_hosts,
)
origins_list = os.environ.get("CORS_URLS", default="")
origins_list = origins_list.split(',') if origins_list else ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=origins_list,
allow_credentials=True,
allow_methods=["GET", "POST", "DELETE", "PUT"],
allow_headers=["*"],
)
input_data = {
"Node1": { @app.get(f"/api/{Service.MODULE_NAME}/healthcheck")
"node_id": "event_001", def ping():
"action": "add", return {"status": 200}
"node_name": "Event 20",
"project_id": "project_099",
"node_type": "Events", auth_enabled = [Depends(auth)] if secure_access in [True, 'true', 'True'] else []
"properties": { app.include_router(service_router, dependencies=auth_enabled)
"name": "Event 1",
"external_data_source": "mongo",
"external_data_id": "101",
"event_info": "847263",
"received_date": 643308200,
"completed_date": 1643394600
},
"edges": [],
"tz": "Asia/Kolkata"
},
"Node10": {
"node_id": "event_10",
"project_id": "project_099",
"action": "add",
"node_name": "Event Name",
"node_type": "Events,Harsha",
"properties": {
"name": "Event 2",
"external_data_source": "mongo",
"external_data_id": "101",
"event_info": "847263",
"received_date": 643308200,
"completed_date": 1643394600
},
"edges": [{
"action": "add",
"rel_name": "madhuri_don_boscho",
"new_rel_name": "janu_s_UI",
"bind_to": "Node1",
"bind_id": "BSCH270120022"
}]
}
}
gt_obj = GraphTraversal(db=get_db())
GraphData(__root__=input_data)
gt_obj.ingest_data_handler(graph_data=GraphData(__root__=input_data))
python-dotenv~=0.19.2 python-dotenv~=0.19.2
SQLAlchemy==1.4.35
GQLAlchemy GQLAlchemy
psycopg2-binary==2.9.3
fastapi~=0.74.1 fastapi~=0.74.1
pytz~=2021.3 pytz~=2021.3
PyYAML~=6.0 PyYAML~=6.0
...@@ -11,5 +9,11 @@ pymongo==3.7.2 ...@@ -11,5 +9,11 @@ pymongo==3.7.2
ilens-kafka-publisher==0.4.2 ilens-kafka-publisher==0.4.2
kafka-python==1.4.7 kafka-python==1.4.7
faust==1.10.4 faust==1.10.4
SQLAlchemy-Utils==0.38.2 uvicorn[standard]~=0.18.2
cryptography>=3.3.1
pendulum==2.1.2 pendulum==2.1.2
jwt_signature_validator==0.0.5
crypto~=1.4.1
pycryptodomex==3.9.9
pycryptodome==3.9.9
shortuuid==1.0.8
\ No newline at end of file
import os import os
import shutil
import sys import sys
from configparser import BasicInterpolation, ConfigParser from configparser import BasicInterpolation, ConfigParser
...@@ -28,8 +29,11 @@ except Exception as e: ...@@ -28,8 +29,11 @@ except Exception as e:
class Service: class Service:
port = config["SERVICE"]["port"] MODULE_NAME = config["MODULE"]["name"]
host = config["SERVICE"]["host"] PORT = config["SERVICE"]["port"]
HOST = config["SERVICE"]["host"]
verify_signature = os.environ.get("VERIFY_SIGNATURE", False)
protected_hosts = os.environ.get("PROTECTED_HOSTS", "").split(",")
class DBConf: class DBConf:
...@@ -47,3 +51,22 @@ class Logging: ...@@ -47,3 +51,22 @@ class Logging:
level = level or "INFO" level = level or "INFO"
tb_flag = config.getboolean("LOGGING", "traceback", fallback=True) tb_flag = config.getboolean("LOGGING", "traceback", fallback=True)
tb_flag = tb_flag if tb_flag is not None else True tb_flag = tb_flag if tb_flag is not None else True
class RedisConfig(object):
uri = config.get("REDIS", "uri")
login_db = config["REDIS"]["login_db"]
project_tags_db = config.getint("REDIS", "project_tags_db")
user_role_permissions = config.getint("REDIS", "user_role_permissions")
class KeyPath(object):
keys_path = config['DIRECTORY']['keys_path']
if not os.path.isfile(os.path.join(keys_path, "public")) or not os.path.isfile(
os.path.join(keys_path, "private")):
if not os.path.exists(keys_path):
os.makedirs(keys_path)
shutil.copy(os.path.join("assets", "keys", "public"), os.path.join(keys_path, "public"))
shutil.copy(os.path.join("assets", "keys", "private"), os.path.join(keys_path, "private"))
public = os.path.join(keys_path, "public")
private = os.path.join(keys_path, "private")
...@@ -8,3 +8,14 @@ class APIEndPoints: ...@@ -8,3 +8,14 @@ class APIEndPoints:
api_create = '/create' api_create = '/create'
graph_traverse = "/traverse" graph_traverse = "/traverse"
ingest_graph_data = "/ingest" ingest_graph_data = "/ingest"
class Secrets:
LOCK_OUT_TIME_MINS = 30
leeway_in_mins = 10
unique_key = '45c37939-0f75'
token = '8674cd1d-2578-4a62-8ab7-d3ee5f9a'
issuer = "ilens"
alg = "RS256"
signature_key = 'kliLensKLiLensKL'
signature_key_alg = ["HS256"]
...@@ -62,9 +62,9 @@ class GraphTraversal: ...@@ -62,9 +62,9 @@ class GraphTraversal:
raise raise
def fetch_node_data(self, graph_request: GetNodeInfo): def fetch_node_data(self, graph_request: GetNodeInfo):
return_data = ResponseModelSchema(nodes=[], links=[]) return_data = ResponseModelSchema(series_data=dict(nodes=[], links=[]))
try: try:
existing_data = self.graph_util.get_connecting_nodes_info(input_data=graph_request.dict()) existing_data = self.graph_util.get_connecting_nodes_info(input_data=graph_request)
existing_node_info = [] existing_node_info = []
for k, v in existing_data.items(): for k, v in existing_data.items():
for _item in v: for _item in v:
...@@ -76,9 +76,9 @@ class GraphTraversal: ...@@ -76,9 +76,9 @@ class GraphTraversal:
ui_dict["linkName"] = ui_dict.pop("_type") ui_dict["linkName"] = ui_dict.pop("_type")
return_data.series_data["links"].append(ui_dict) return_data.series_data["links"].append(ui_dict)
continue continue
node_id = node_info.get("id") node_id = node_info.get("node_id")
unique_id = node_info.get("_id") unique_id = node_info.get("_id")
if not node_id or node_id in unique_id: if not node_id or unique_id in existing_node_info:
continue continue
existing_node_info.append(unique_id) existing_node_info.append(unique_id)
ui_dict.update({"x": '', "y": ''}) ui_dict.update({"x": '', "y": ''})
......
import redis
from scripts.config import RedisConfig
login_db = redis.from_url(RedisConfig.uri, db=int(RedisConfig.login_db), decode_responses=True)
project_details_db = redis.from_url(RedisConfig.uri, db=int(RedisConfig.project_tags_db), decode_responses=True)
user_role_permissions_redis = redis.from_url(
RedisConfig.uri, db=int(RedisConfig.user_role_permissions), decode_responses=True
)
class InternalError(Exception):
pass
class UnauthorizedError(Exception):
pass
class ProjectIdError(Exception):
pass
class ILensPermissionError(Exception):
pass
class DuplicateTemplateNameError(Exception):
pass
class DuplicateWorkflowNameError(Exception):
pass
class ImplementationError(Exception):
pass
class DuplicatestepNameError(Exception):
pass
class StepNotFound(Exception):
pass
class CustomError(Exception):
pass
class DatetimeMismatchError(Exception):
pass
class DuplicateLogbookNameError(Exception):
pass
class FileExceptions(Exception):
pass
class FileFormatNotSupported(FileExceptions):
pass
class RequiredColumnsMissingException(Exception):
pass
class FileNotFoundException(Exception):
pass
class DisplayFieldsException(Exception):
pass
class TaskCreationLimitExceeded(Exception):
pass
class ILensErrors(Exception):
def __init__(self, msg):
Exception.__init__(self, msg)
"""
Base Error Class
"""
class ErrorCodes:
ERR001 = "ERR001 - Operating Time is greater than Planned Time"
ERR002 = "ERR002 - Zero Values are not allowed"
ERR003 = "ERR003 - Operating Time is less than Productive Time"
ERR004 = "ERR004 - Rejected Units is greater than Total Units"
class DowntimeResponseError(ILensErrors):
"""
Error Occurred during fetch of downtime
"""
class AuthenticationError(ILensErrors):
"""
JWT Authentication Error
"""
class BulkUploadError(ILensErrors):
"""
Bulk Upload Custom errors
"""
class CustomILensError(ILensErrors):
"""
Bulk Upload Custom errors
"""
class NoPropertiesFound(ILensErrors):
"""
No properties found for field even disable key is false
"""
class ErrorMessages:
ERROR001 = "Authentication Failed. Please verify token"
ERROR002 = "Signature Expired"
ERROR003 = "Signature Not Valid"
...@@ -87,8 +87,9 @@ class GraphData(BaseModel): ...@@ -87,8 +87,9 @@ class GraphData(BaseModel):
class GetNodeInfo(BaseModel): class GetNodeInfo(BaseModel):
label: Optional[str] = "Events"
project_id: str project_id: str
id: str node_id: str
class ResponseModelSchema(BaseModel): class ResponseModelSchema(BaseModel):
......
from fastapi import APIRouter
from scripts.services.graph_service import graph_router
service_router = APIRouter()
service_router.include_router(graph_router)
\ No newline at end of file
...@@ -9,10 +9,10 @@ from scripts.logging import logger ...@@ -9,10 +9,10 @@ from scripts.logging import logger
from scripts.schemas import GraphData, GetNodeInfo from scripts.schemas import GraphData, GetNodeInfo
from scripts.schemas.responses import DefaultFailureResponse, DefaultResponse from scripts.schemas.responses import DefaultFailureResponse, DefaultResponse
router = APIRouter(prefix=APIEndPoints.graph_base, tags=["Graph Traversal"]) graph_router = APIRouter(prefix=APIEndPoints.graph_base, tags=["Graph Traversal"])
@router.post( @graph_router.post(
APIEndPoints.ingest_graph_data APIEndPoints.ingest_graph_data
) )
def ingest_data_service(request_data: GraphData, db=Depends(get_db)): def ingest_data_service(request_data: GraphData, db=Depends(get_db)):
...@@ -27,7 +27,7 @@ def ingest_data_service(request_data: GraphData, db=Depends(get_db)): ...@@ -27,7 +27,7 @@ def ingest_data_service(request_data: GraphData, db=Depends(get_db)):
return DefaultFailureResponse(error=e.args).dict() return DefaultFailureResponse(error=e.args).dict()
@router.post( @graph_router.post(
APIEndPoints.api_graph_link APIEndPoints.api_graph_link
) )
def ingest_data_service(request_data: GetNodeInfo, db=Depends(get_db)): def ingest_data_service(request_data: GetNodeInfo, db=Depends(get_db)):
......
...@@ -8,9 +8,11 @@ class CommonUtils: ...@@ -8,9 +8,11 @@ class CommonUtils:
... ...
@staticmethod @staticmethod
def convert_dict_str_format(input_dict): def convert_dict_str_format(input_dict: dict, exclude_keys: str = "label"):
for k in exclude_keys.split(","):
input_dict.pop(k, None)
return_str = json.dumps(input_dict).replace('"', "'") return_str = json.dumps(input_dict).replace('"', "'")
for each_key in input_dict.keys(): for each_key in input_dict:
return_str = return_str.replace(f"'{each_key}'", each_key) return_str = return_str.replace(f"'{each_key}'", each_key)
return return_str.replace('{', '').replace('}', '') return return_str.replace('{', '').replace('}', '')
......
import json
from functools import lru_cache
@lru_cache()
def get_db_name(redis_client, project_id: str, database: str, delimiter="__"):
if not project_id:
return database
val = redis_client.get(project_id)
if val is None:
raise ValueError(
f"Unknown Project, Project ID: {project_id} Not Found!!!")
val = json.loads(val)
if not val:
return database
# Get the prefix flag to apply project_id prefix to any db
prefix_condition = bool(
val.get("source_meta", {}).get("add_prefix_to_database"))
if prefix_condition:
# Get the prefix name from mongo or default to project_id
prefix_name = val.get("source_meta", {}).get("prefix") or project_id
return f"{prefix_name}{delimiter}{database}"
return database
from typing import List from typing import List
from gqlalchemy import GQLAlchemyError from gqlalchemy import GQLAlchemyError
from gqlalchemy.query_builders import neo4j_query_builder
from scripts.db import get_db
from scripts.db.graphdb.graph_query import QueryFormation from scripts.db.graphdb.graph_query import QueryFormation
from scripts.db.graphdb.neo4j import Neo4jHandler from scripts.db.graphdb.neo4j import Neo4jHandler
from scripts.db.models import RelationShipMapper, NodePropertiesSchema from scripts.db.models import RelationShipMapper, NodePropertiesSchema
from scripts.logging import logger from scripts.logging import logger
from scripts.schemas import GetNodeInfo
from scripts.utils.common_utils import CommonUtils from scripts.utils.common_utils import CommonUtils
...@@ -78,9 +77,10 @@ class GraphUtility: ...@@ -78,9 +77,10 @@ class GraphUtility:
logger.exception(f'Exception Occurred while fetching the relation details -> {e.args}') logger.exception(f'Exception Occurred while fetching the relation details -> {e.args}')
raise raise
def get_connecting_nodes_info(self, input_data, label='NodeCreationSchema'): def get_connecting_nodes_info(self, input_data: GetNodeInfo):
try: try:
query = f"MATCH (a: {label}{{{self.common_util.convert_dict_str_format(input_dict=input_data)}}})-[r]-(b) RETURN r,a,b;" query_dict = input_data.dict()
query = f"MATCH (a: {input_data.label}{{{self.common_util.convert_dict_str_format(input_dict=query_dict)}}})-[r]-(b) RETURN r,a,b;"
return self.common_util.process_generator_result(self.db.execute_and_fetch(query=query)) return self.common_util.process_generator_result(self.db.execute_and_fetch(query=query))
except GQLAlchemyError as e: except GQLAlchemyError as e:
logger.debug(f'{e.args}') logger.debug(f'{e.args}')
......
import base64
from Cryptodome import Random
from Cryptodome.Cipher import AES
class AESCipher:
"""
A classical AES Cipher. Can use any size of data and any size of password thanks to padding.
Also ensure the coherence and the type of the data with a unicode to byte converter.
"""
def __init__(self, key):
self.bs = 16
self.key = AESCipher.str_to_bytes(key)
@staticmethod
def str_to_bytes(data):
u_type = type(b''.decode('utf8'))
if isinstance(data, u_type):
return data.encode('utf8')
return data
def _pad(self, s):
return s + (self.bs - len(s) % self.bs) * AESCipher.str_to_bytes(chr(self.bs - len(s) % self.bs))
@staticmethod
def _unpad(s):
return s[:-ord(s[len(s) - 1:])]
def encrypt(self, raw):
raw = self._pad(AESCipher.str_to_bytes(raw))
iv = Random.new().read(AES.block_size)
cipher = AES.new(self.key, AES.MODE_CBC, iv)
return base64.b64encode(iv + cipher.encrypt(raw)).decode('utf-8')
def decrypt(self, enc):
enc = base64.b64decode(enc)
iv = enc[:AES.block_size]
cipher = AES.new(self.key, AES.MODE_CBC, iv)
data = self._unpad(cipher.decrypt(enc[AES.block_size:]))
return data.decode('utf-8')
import uuid
from datetime import timedelta, datetime, timezone
from scripts.constants import Secrets
from scripts.db.redis_connections import login_db
from scripts.utils.security_utils.jwt_util import JWT
jwt = JWT()
def create_token(user_id, ip, token, age=Secrets.LOCK_OUT_TIME_MINS, login_token=None, project_id=None):
"""
This method is to create a cookie
"""
try:
uid = login_token
if not uid:
uid = str(uuid.uuid4()).replace("-", "")
payload = {
"ip": ip,
"user_id": user_id,
"token": token,
"uid": uid,
"age": age
}
if project_id:
payload["project_id"] = project_id
exp = datetime.now(timezone.utc) + timedelta(minutes=age)
_extras = {"iss": Secrets.issuer, "exp": exp}
_payload = payload | _extras
new_token = jwt.encode(_payload)
# Add session to redis
login_db.set(uid, new_token)
login_db.expire(uid, timedelta(minutes=age))
return uid
except Exception:
raise
from secrets import compare_digest
from typing import Optional
from fastapi import Response, Request, HTTPException, status
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security import APIKeyCookie
from fastapi.security.api_key import APIKeyBase
from pydantic import BaseModel, Field
from scripts.config import Service
from scripts.constants import Secrets
from scripts.db.redis_connections import login_db
from scripts.logging.logging import logger
from scripts.utils.security_utils.apply_encrytion_util import create_token
from scripts.utils.security_utils.jwt_util import JWT
class CookieAuthentication(APIKeyBase):
"""
Authentication backend using a cookie.
Internally, uses a JWT token to store the data.
"""
scheme: APIKeyCookie
cookie_name: str
cookie_secure: bool
def __init__(
self,
cookie_name: str = "login-token",
):
super().__init__()
self.model: APIKey = APIKey(**{"in": APIKeyIn.cookie}, name=cookie_name)
self.scheme_name = self.__class__.__name__
self.cookie_name = cookie_name
self.scheme = APIKeyCookie(name=self.cookie_name, auto_error=False)
self.login_redis = login_db
self.jwt = JWT()
async def __call__(self, request: Request, response: Response) -> str:
cookies = request.cookies
login_token = cookies.get("login-token")
if not login_token:
login_token = request.headers.get("login-token")
if not login_token:
raise HTTPException(status_code=401)
jwt_token = self.login_redis.get(login_token)
if not jwt_token:
raise HTTPException(status_code=401)
try:
decoded_token = self.jwt.validate(token=jwt_token)
if not decoded_token:
raise HTTPException(status_code=401)
except Exception as e:
raise HTTPException(status_code=401, detail=e.args)
user_id = decoded_token.get("user_id")
project_id = decoded_token.get("project_id")
_token = decoded_token.get("token")
_age = int(decoded_token.get("age", Secrets.LOCK_OUT_TIME_MINS))
if any(
[
not compare_digest(Secrets.token, _token),
login_token != decoded_token.get("uid"),
]
):
raise HTTPException(status_code=401)
request.cookies.update(
{"user_id": user_id, "project_id": project_id, "projectId": project_id, "userId": user_id}
)
try:
new_token = create_token(
user_id=user_id,
ip=request.client.host,
token=Secrets.token,
age=_age,
login_token=login_token,
project_id=project_id,
)
except Exception as e:
raise HTTPException(status_code=401, detail=e.args)
response.set_cookie(
"login-token",
new_token,
samesite="strict",
httponly=True,
secure=Service.secure_cookie,
max_age=Secrets.LOCK_OUT_TIME_MINS * 60,
)
# If project ID is null, this is susceptible to 500 Status Code. Ensure token formation has project ID in
# login token
if not project_id:
logger.info("Project ID not found in Old token. Soon to be deprecated. Proceeding for now")
response.headers.update({"login-token": new_token, "userId": user_id, "user_id": user_id})
return user_id
response.headers.update(
{
"login-token": new_token,
"projectId": project_id,
"project_id": project_id,
"userId": user_id,
"user_id": user_id,
}
)
return user_id
class MetaInfoSchema(BaseModel):
projectId: Optional[str] = ""
project_id: Optional[str] = ""
user_id: Optional[str] = ""
language: Optional[str] = ""
ip_address: Optional[str] = ""
login_token: Optional[str] = Field(alias="login-token")
class Config:
allow_population_by_field_name = True
class MetaInfoCookie(APIKeyBase):
"""
Project ID backend using a cookie.
"""
scheme: APIKeyCookie
def __init__(self):
super().__init__()
self.model: APIKey = APIKey(**{"in": APIKeyIn.cookie}, name="meta")
self.scheme_name = self.__class__.__name__
def __call__(self, request: Request, response: Response):
cookies = request.cookies
cookie_json = {
"projectId": cookies.get("projectId", request.headers.get("projectId")),
"userId": cookies.get("user_id", cookies.get("userId", request.headers.get("userId"))),
"language": cookies.get("language", request.headers.get("language")),
}
return MetaInfoSchema(
project_id=cookie_json["projectId"],
user_id=cookie_json["userId"],
projectId=cookie_json["projectId"],
language=cookie_json["language"],
ip_address=request.client.host,
login_token=cookies.get("login-token"),
)
class GetUserID(APIKeyBase):
"""
Project ID backend using a cookie.
"""
scheme: APIKeyCookie
def __init__(self):
super().__init__()
self.model: APIKey = APIKey(**{"in": APIKeyIn.cookie}, name="user_id")
self.scheme_name = self.__class__.__name__
def __call__(self, request: Request, response: Response):
if user_id := request.cookies.get("user_id", request.cookies.get("userId", request.headers.get("userId"))):
return user_id
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
import jwt
from jwt.exceptions import (
InvalidSignatureError,
ExpiredSignatureError,
MissingRequiredClaimError,
)
from scripts.config import KeyPath
from scripts.constants import Secrets
from scripts.errors import AuthenticationError, ErrorMessages
from scripts.logging.logging import logger
class JWT:
def __init__(self):
self.max_login_age = Secrets.LOCK_OUT_TIME_MINS
self.issuer = Secrets.issuer
self.alg = Secrets.alg
self.public = KeyPath.public
self.private = KeyPath.private
def encode(self, payload):
try:
with open(self.private, "r") as f:
key = f.read()
return jwt.encode(payload, key, algorithm=self.alg)
except Exception as e:
logger.exception(f'Exception while encoding JWT: {str(e)}')
raise
finally:
f.close()
def validate(self, token):
try:
with open(self.public, "r") as f:
key = f.read()
payload = jwt.decode(
token,
key,
algorithms=self.alg,
leeway=Secrets.leeway_in_mins,
options={"require": ["exp", "iss"]},
)
return payload
except InvalidSignatureError:
raise AuthenticationError(ErrorMessages.ERROR003)
except ExpiredSignatureError:
raise AuthenticationError(ErrorMessages.ERROR002)
except MissingRequiredClaimError:
raise AuthenticationError(ErrorMessages.ERROR002)
except Exception as e:
logger.exception(f'Exception while validating JWT: {str(e)}')
raise
finally:
f.close()
from typing import Optional
from fastapi import Response, Request
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.api_key import APIKeyBase, APIKeyCookie
from pydantic import BaseModel
class MetaInfoSchema(BaseModel):
project_id: Optional[str] = ""
user_id: Optional[str] = ""
language: Optional[str] = ""
class MetaInfoCookie(APIKeyBase):
"""
Project ID backend using a cookie.
"""
scheme: APIKeyCookie
cookie_name: str
def __init__(self, cookie_name: str = "projectId"):
super().__init__()
self.model: APIKey = APIKey(**{"in": APIKeyIn.cookie}, name=cookie_name)
self.cookie_name = cookie_name
self.scheme_name = self.__class__.__name__
self.scheme = APIKeyCookie(name=self.cookie_name, auto_error=False)
async def __call__(self, request: Request, response: Response):
cookies = request.cookies
cookie_json = {
"projectId": cookies.get("projectId", request.headers.get("projectId")),
"userId": cookies.get("user_id", cookies.get("userId", request.headers.get("userId"))),
"language": cookies.get("language", request.headers.get("language"))
}
return MetaInfoSchema(project_id=cookie_json["projectId"], user_id=cookie_json["userId"],
language=cookie_json["language"])
@staticmethod
def set_response_info(cookie_name, cookie_value, response: Response):
response.set_cookie(
cookie_name,
cookie_value,
samesite="strict",
httponly=True
)
response.headers[cookie_name] = cookie_value
import logging
import orjson as json
from fastapi import HTTPException, Request, status
from scripts.db.mongo import mongo_client
from scripts.db.mongo.ilens_configuration.collections.user import User
from scripts.db.mongo.ilens_configuration.collections.user_project import UserProject
from scripts.db.redis_connections import user_role_permissions_redis
from scripts.utils.common_utils import timed_lru_cache
@timed_lru_cache(seconds=60, maxsize=1000)
def get_user_role_id(user_id, project_id):
logging.debug("Fetching user role from DB")
user_conn = User(mongo_client=mongo_client)
if user_role := user_conn.find_user_role_for_user_id(user_id=user_id, project_id=project_id):
return user_role["userrole"][0]
# if user not found in primary collection, check if user is in project collection
user_proj_conn = UserProject(mongo_client=mongo_client)
if user_role := user_proj_conn.find_user_role_for_user_id(user_id=user_id, project_id=project_id):
return user_role["userrole"][0]
class RBAC:
def __init__(self, entity_name: str, operation: list[str]):
self.entity_name = entity_name
self.operation = operation
def check_permissions(self, user_id: str, project_id: str) -> dict[str, bool]:
user_role_id = get_user_role_id(user_id, project_id)
if not user_role_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User role not found!")
r_key = f"{project_id}__{user_role_id}" # eg: project_100__user_role_100
user_role_rec = user_role_permissions_redis.hget(r_key, self.entity_name)
if not user_role_rec:
return {} # TODO: raise exception here
user_role_rec = json.loads(user_role_rec)
if permission_dict := {i: True for i in self.operation if user_role_rec.get(i)}:
return permission_dict
else:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Insufficient Permission!")
def __call__(self, request: Request) -> dict[str, bool]:
user_id = request.cookies.get("userId", request.headers.get("userId"))
project_id = request.cookies.get("projectId", request.headers.get("projectId"))
return self.check_permissions(user_id=user_id, project_id=project_id)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment