Commit 1eb75421 authored by Sikhin VC

initial commit

parent 174a8e28
FROM python:3.8
# tzdata/openssl for timezone data and TLS; ffmpeg/libsm6/libxext6 for OpenCV.
RUN apt-get update && apt-get install -y --no-install-recommends \
    tzdata openssl ffmpeg libsm6 libxext6 && rm -rf /var/lib/apt/lists/*
WORKDIR /app
# Install dependencies before copying the code so the pip layer is cached.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . /app
CMD ["python3", "app.py"]
# vgg16_training_pipeline

## poc-vision-active-learning-off-prem

This repository maintains the code for the training pipeline: dataset download from blob storage, extraction, augmentation, train/validation/test split, model retraining, and result upload.
if __name__ == "__main__":
    # Load environment variables before app_configuration is imported below,
    # since the config expands $VARS from the environment at import time.
    from dotenv import load_dotenv
    load_dotenv(dotenv_path='config.env')

import os
import shutil
import sys
import traceback
import warnings
from pathlib import Path

from pymongo import MongoClient
from loguru import logger

from scripts.constants.app_configuration import Mongo, job, StatusMessage
from utils.other.blob_util import BlobUtil
from utils.other.mlflow_vam_util import ModelReTrainer
from utils.preprocessing.dataset_augmentation import DataAugmentation
from utils.preprocessing.dataset_extraction import ExtractDataset
from utils.preprocessing.train_test_split_classifier import TrainTestSplit

warnings.filterwarnings("ignore")

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))

def main():
    result_directory = None
    job_id = job.job_id
    job_collection = job.job_collection
    mongo_conn = MongoClient(Mongo.mongo_uri)
    mongo_client = mongo_conn[Mongo.mongo_db]
    try:
        master_conf = mongo_client[job_collection].find_one({"job_id": job_id})
        os.makedirs(job_id, exist_ok=True)
        root_data_path = os.path.join(job_id, 'temp')
        raw_dataset_path = os.path.join(root_data_path, "raw_dataset")
        extracted_dataset_path = os.path.join(root_data_path, "extracted_dataset")
        unaugmented_dataset_path = os.path.join(root_data_path, "unaugmented_dataset")
        augmented_dataset_path = os.path.join(root_data_path, "augmented_dataset")
        split_dataset_path = os.path.join(root_data_path, "split_dataset")
        # Create the working directory tree idempotently.
        for path in (root_data_path, raw_dataset_path, extracted_dataset_path,
                     unaugmented_dataset_path, augmented_dataset_path, split_dataset_path):
            os.makedirs(path, exist_ok=True)
        logger.info("Starting Vision Data Acquisition Pipeline")
        blob_list = master_conf['blob_path']
        logger.info(f'Starting acquisition of Stream {blob_list}')
        blob_util = BlobUtil(master_conf['project_id'], master_conf['Site_id'],
                             master_conf['line_id'], master_conf['camera_id'])
        mongo_client[job.job_collection].update_one(
            {'job_id': job_id},
            {'$set': {"job_status": StatusMessage.tr_data_download_started, 'progress': 5.0}},
            upsert=False)
        if not blob_util.download(master_conf['blob_path'], raw_dataset_path):
            # Abort here so the except-branch below marks the job as failed.
            raise RuntimeError("Dataset download from blob storage failed")
        logger.info("Data Downloaded Successfully!")
        mongo_client[job.job_collection].update_one(
            {'job_id': job_id},
            {'$set': {"job_status": StatusMessage.tr_data_downloaded, 'progress': 10.0}},
            upsert=False)
        extract = ExtractDataset()
        extract.extract_ds(os.path.join(raw_dataset_path, "dataset.zip"), extracted_dataset_path)
        extract.move_files(extracted_dataset_path, unaugmented_dataset_path)
        class_names = os.listdir(unaugmented_dataset_path)
        logger.info('Initiating Augmentations of Dataset!')
        augmentation_functions = []
        for augmentation, augmentation_value in master_conf['general_augmentations'].items():
            # Assumes each entry holds an iterable of values, one augmentation pass per value.
            augmentation_functions.extend(
                {"property": augmentation, "value": val} for val in augmentation_value)
            logger.info(
                f'Starting Data augmentation with type {augmentation} and Value {augmentation_value}')
        classification_data_aug_manager = DataAugmentation(augmentation_functions)
        for cls in class_names:
            classwise_unaugment_dataset_path = os.path.join(unaugmented_dataset_path, cls)
            os.mkdir(os.path.join(augmented_dataset_path, cls))
            classification_data_aug_manager.process(
                annotation_directory=classwise_unaugment_dataset_path,
                post_process_directory=os.path.join(augmented_dataset_path, cls))
            # Merge the original (unaugmented) images for this class into the
            # augmented set so training sees both.
            classification_data_aug_manager.combine_dataset(
                augmented_dataset_path=augmented_dataset_path,
                unaugmented_dataset_path=unaugmented_dataset_path, cls=cls)
        mongo_client[job.job_collection].update_one(
            {'job_id': job_id},
            {'$set': {"job_status": StatusMessage.tr_augmentation_completed, 'progress': 20.0}},
            upsert=False)
        train_test_split_obj = TrainTestSplit(augmented_dataset_path, job_id,
                                              validation_size=0.20,
                                              test_size=0.10)
        training_path, validation_path, test_path, dataset_size = train_test_split_obj.train_test_split()
        shutil.rmtree(root_data_path)
        master_conf['Dataset_size'] = dataset_size
        mongo_client[job.job_collection].update_one(
            {'job_id': job_id},
            {'$set': {"job_status": StatusMessage.tr_init_training, 'progress': 35.0}},
            upsert=False)
        model_trainer = ModelReTrainer(master_conf['project_id'],
                                       master_conf['Site_id'],
                                       master_conf['line_id'],
                                       master_conf['camera_id'], training_path,
                                       validation_path, job_id, master_conf)
        result_directory = model_trainer.start_training()
        mongo_client[job.job_collection].update_one(
            {'job_id': job_id},
            {'$set': {"job_status": StatusMessage.tr_data_upload_started, 'progress': 90.0}},
            upsert=False)
        blob_util = BlobUtil(master_conf['project_id'], master_conf['Site_id'],
                             master_conf['line_id'], master_conf['camera_id'])
        logger.info("Uploading Files!")
        upload_ok = True
        for each_file in os.listdir(result_directory):
            if not blob_util.upload(result_directory, each_file):
                upload_ok = False
                logger.error(f"Upload failed for {each_file}!")
        if upload_ok:
            logger.info("Data Uploaded Successfully!")
        mongo_client[job.job_collection].update_one(
            {'job_id': job_id},
            {'$set': {"job_status": StatusMessage.tr_data_uploaded, 'progress': 100.0}},
            upsert=False)
        shutil.rmtree(job_id)
        shutil.rmtree(result_directory)
    except Exception:
        traceback.print_exc()
        if os.path.exists(job_id):
            shutil.rmtree(job_id)
        # result_directory is only set once training has produced output.
        if result_directory and os.path.exists(result_directory):
            shutil.rmtree(result_directory)
        mongo_client[job.job_collection].update_one(
            {'job_id': job_id},
            {'$set': {"job_status": StatusMessage.tr_failed, 'progress': 35.0}},
            upsert=False)
    finally:
        mongo_conn.close()

if __name__ == '__main__':
    logger.info("Attempting to Start Training Pipeline!")
    try:
        main()
    except Exception as e:
        traceback.print_exc()
        logger.error(f"Failed to Start Training Pipeline! : {str(e)}")
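# application.conf: values of the form $NAME are placeholders that
# EnvInterpolation (scripts/constants/app_configuration.py) expands from the
# environment at load time; options whose variables are unset resolve to None.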
[BLOB_STORAGE]
account_name=$ACCOUNT_NAME
account_key=$ACCOUNT_KEY
container_name=$CONTAINER_NAME
[MLFLOW]
mlflow_tracking_uri=$MLFLOW_TRACKING_URI
mlflow_tracking_username=$MLFLOW_TRACKING_USERNAME
mlflow_tracking_password=$MLFLOW_TRACKING_PASSWORD
azure_storage_connection_string=$AZURE_STORAGE_CONNECTION_STRING
azure_storage_access_key=$AZURE_STORAGE_ACCESS_KEY
user=$USER
experiment_name=$EXPERIMENT_NAME
run_name=$RUN_NAME
model_name=$MODEL_NAME
check_param=$CHECK_PARAM
model_check_param=$MODEL_CHECK_PARAM
total_models_needed=$TOTAL_MODELS_NEEDED
[MONGO]
mongo_uri=$MONGO_URI
mongo_db=$MONGO_DATABASE
[JOB]
job_id=$JOB_ID
collection=$COLLECTION
master_config:
  project: 'JK_Cements'
  site: 'Chittorgarh'
  line: 1
  camera: 1
  blob_path: "JK_Cements/Chittorgarh/1/1/Annotated_data/camera_41_annotated_data.zip"
  data_path: "jk_data"
  augmentation_types:
    general_augmentations:
      blur:
        value: 0.3
      sepia:
        value: 0.3
      noise:
        value: 0.1
      cutout:
        value: 0.1
      horizontal_flip:
        value: 0
      grayscale:
        value: 0.4
      hue:
        value: 1
      saturation:
        value: 1
      brightness:
        value: 0.8
      exposure:
        value: 0.9
      vertical_flip:
        value: 0
    bounding_box_level_augmentations:
      horizontal_flip:
        value: 0
      random_translate:
        value: 0.3
      rotate:
        value: [ 90, 180, 270 ]
      random_shear:
        value: 0.2
      resize:
        value: 608
      random_hsv:
        value: [ 100, 100, 100 ]
  training_params:
    weights: 'models/jk_v5_cam_44.pt'  # initial weights path
    cfg: 'cfg/yolov5s.yaml'  # model.yaml path
    data: 'hyp/jk_data.yaml'  # data.yaml path
    hyp: 'hyp/hyp.scratch.yaml'  # hyperparameters path
    epochs: 40
    batch_size: 16  # total batch size for all GPUs
    imgsz: 416  # [train, test] image sizes
    rect: False  # rectangular training
    resume: False  # resume most recent training
    nosave: False  # only save final checkpoint
    notest: False  # only test final epoch
    noautoanchor: False  # disable autoanchor check
    evolve: False  # evolve hyperparameters
    bucket: ''  # gsutil bucket
    cache: False  # cache images for faster training
    image_weights: False  # use weighted image selection for training
    device: 'cpu'  # cuda device, i.e. 0 or 0,1,2,3 or cpu
    multi_scale: False  # vary img-size +/- 50%
    single_cls: False  # train multi-class data as single-class
    adam: False  # use torch.optim.Adam() optimizer
    sync_bn: False  # use SyncBatchNorm, only available in DDP mode
    local_rank: -1  # DDP parameter, do not modify
    workers: 8  # maximum number of dataloader workers
    project: 'runs/train'  # save to project/name
    entity: None  # W&B entity
    name: '0'  # save to project/name
    exist_ok: False  # existing project/name ok, do not increment
    quad: False  # quad dataloader
    linear_lr: False  # linear LR
    label_smoothing: 0.0  # label smoothing epsilon
    upload_dataset: False  # upload dataset as W&B artifact table
    bbox_interval: -1  # set bounding-box image logging interval for W&B
    save_period: -1  # log model after every "save_period" epochs
    artifact_alias: "latest"  # version of dataset artifact to be used
    freeze: 0  # freeze layers: backbone of yolov7=50, first3=0 1 2
    v5_metric: False  # assume maximum recall as 1.0 in AP calculation
    patience: 100
    noval: False
  wts_conversion_params:
    output: 'models/best.wts'
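
# A minimal sketch (not part of the repo) of loading master_config.yml with
# PyYAML; the path is the one defined in app_configuration.py. Note that app.py
# reads `general_augmentations` from the top level of the Mongo job document,
# while this YAML nests it under `augmentation_types`.
import yaml

with open("./conf/master_config.yml") as f:
    master_conf = yaml.safe_load(f)["master_config"]

for name, entry in master_conf["augmentation_types"]["general_augmentations"].items():
    print(name, entry["value"])  # e.g. blur 0.3

# config.env: the values below are loaded by app.py via load_dotenv before
# application.conf's $VARS are expanded from the environment.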
ACCOUNT_NAME = azrmlilensqa006382180551
ACCOUNT_KEY = tDGOKfiZ2svfoMvVmS0Fbpf0FTHfTq4wKYuDX7cAxlhve/3991QuzdvJHm9vWc+lo6mtC+x9yPSghWNR4+gacg==
CONTAINER_NAME = vision-app-videos
MLFLOW_TRACKING_URI=http://192.168.2.147:50000/
AZURE_STORAGE_CONNECTION_STRING=DefaultEndpointsProtocol=https;AccountName=azrmlilensqa006382180551;AccountKey=tDGOKfiZ2svfoMvVmS0Fbpf0FTHfTq4wKYuDX7cAxlhve/3991QuzdvJHm9vWc+lo6mtC+x9yPSghWNR4+gacg==;EndpointSuffix=core.windows.net
AZURE_STORAGE_ACCESS_KEY=tDGOKfiZ2svfoMvVmS0Fbpf0FTHfTq4wKYuDX7cAxlhve/3991QuzdvJHm9vWc+lo6mtC+x9yPSghWNR4+gacg==
USER=Dalmia_degradation
EXPERIMENT_NAME=Solar-String-Level-Degradation-Model-Factory
RUN_NAME=ObjectDetection
MODEL_NAME=versioning
CHECK_PARAM=hours
MODEL_CHECK_PARAM=480
MONGO_URI=mongodb://admin:iLensVisMongo%23723@192.168.2.147:2717
MONGO_DATABASE=admin
JOB_ID=05_10_23_132322
COLLECTION=job
absl-py==1.4.0
albumentations==1.1.0
alembic==1.11.1
anyio==3.7.1
azure-core==1.26.3
azure-identity==1.12.0
azure-storage-blob==12.15.0
blinker==1.6.2
cachetools==5.2.0
certifi==2021.10.8
cffi==1.15.1
charset-normalizer==3.2.0
click==8.1.6
cloudpickle==2.2.1
colorama==0.4.6
contourpy==1.1.0
cryptography==41.0.2
cycler==0.11.0
databricks-cli==0.17.7
dnspython==2.4.1
docker==6.1.3
entrypoints==0.4
exceptiongroup==1.1.2
fastapi==0.100.1
Flask==2.3.2
fonttools==4.41.1
gitdb==4.0.10
GitPython==3.1.32
google-auth==2.22.0
google-auth-oauthlib==0.4.6
greenlet==2.0.2
grpcio==1.56.2
h11==0.14.0
idna==3.4
imageio==2.31.1
importlib-metadata==6.8.0
importlib-resources==6.0.0
isodate==0.6.1
itsdangerous==2.1.2
Jinja2==3.1.2
joblib==1.3.1
kiwisolver==1.4.4
llvmlite==0.40.1
loguru==0.6.0
Mako==1.2.4
Markdown==3.4.4
MarkupSafe==2.1.3
matplotlib==3.7.2
mlflow==2.2.2
mongoengine==0.27.0
msal==1.23.0
msal-extensions==1.0.0
networkx==3.1
numba==0.57.1
numpy==1.24.4
oauthlib==3.2.2
opencv-python==4.6.0.66
opencv-python-headless==4.8.0.74
packaging==23.1
pandas==1.3.4
Pillow==8.4.0
portalocker==2.7.0
protobuf==3.20.3
pyarrow==11.0.0
pyasn1==0.5.0
pyasn1-modules==0.3.0
pycparser==2.21
pydantic==1.8.2
PyJWT==2.8.0
pymongo==4.4.1
pyparsing==3.0.9
python-dateutil==2.8.2
python-dotenv==0.19.0
pytz==2022.7.1
PyWavelets==1.4.1
#pywin32==306
PyYAML==6.0
qudida==0.0.4
querystring-parser==1.2.4
requests==2.31.0
requests-oauthlib==1.3.1
rsa==4.9
scikit-image==0.19.3
scikit-learn==1.3.0
scipy==1.10.1
seaborn==0.12.2
shap==0.42.1
six==1.16.0
slicer==0.0.7
smmap==5.0.0
sniffio==1.3.0
SQLAlchemy==2.0.19
sqlparse==0.4.4
starlette==0.27.0
tabulate==0.9.0
tensorboard==2.11.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
threadpoolctl==3.2.0
tifffile==2023.7.10
torch==1.13.0
torchaudio==0.13.0
torchvision==0.14.0
tqdm==4.64.1
typing_extensions==4.7.1
urllib3==1.26.16
uvicorn==0.23.2
waitress==2.1.2
websocket-client==1.6.1
Werkzeug==2.3.6
win32-setctime==1.1.0
zipp==3.16.2
import os
import sys
from configparser import BasicInterpolation, ConfigParser

import yaml

master_configuration_file = r"./conf/master_config.yml"


class EnvInterpolation(BasicInterpolation):
    """
    Interpolation which expands environment variables in values.
    """

    def before_get(self, parser, section, option, value, defaults):
        value = super().before_get(parser, section, option, value, defaults)
        expanded = os.path.expandvars(value)
        # os.path.expandvars leaves "$NAME" untouched when NAME is unset;
        # treat such values as not configured.
        if not expanded.startswith("$"):
            return expanded
        return None
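
# Hedged example of the interpolation behaviour (names are illustrative): with
# ACCOUNT_NAME set in the environment, `account_name=$ACCOUNT_NAME` in
# application.conf resolves to the env value; if it is unset, the option
# resolves to None.
#
#     os.environ["ACCOUNT_NAME"] = "myaccount"
#     cp = ConfigParser(interpolation=EnvInterpolation())
#     cp.read_string("[BLOB_STORAGE]\naccount_name=$ACCOUNT_NAME\n")
#     assert cp["BLOB_STORAGE"]["account_name"] == "myaccount"
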
try:
    config = ConfigParser(interpolation=EnvInterpolation())
    config.read("conf/application.conf")
except Exception as e:
    print(f"Error while loading the config: {e}")
    print("Failed to Load Configuration. Exiting!!!")
    sys.exit()

class Logging:
    level = config.get("LOGGING", "level", fallback="INFO") or "INFO"
    tb_flag = config.getboolean("LOGGING", "traceback", fallback=True)


BLOB_ACCOUNT_NAME = config["BLOB_STORAGE"]["account_name"]
BLOB_ACCOUNT_KEY = config["BLOB_STORAGE"]["account_key"]
BLOB_CONTAINER_NAME = config["BLOB_STORAGE"]["container_name"]


class MlFlow:
    mlflow_tracking_uri = config['MLFLOW']['mlflow_tracking_uri']
    # mlflow_tracking_username = config['MLFLOW']['mlflow_tracking_username']
    # mlflow_tracking_password = config['MLFLOW']['mlflow_tracking_password']
    azure_storage_connection_string = config['MLFLOW']['azure_storage_connection_string']
    azure_storage_access_key = config['MLFLOW']['azure_storage_access_key']
    user = config['MLFLOW']['user']
    experiment_name = config['MLFLOW']['experiment_name']
    run_name = config['MLFLOW']['run_name']
    model_name = config['MLFLOW']['model_name']
    check_param = config['MLFLOW']['check_param']
    model_check_param = config['MLFLOW']['model_check_param']


class StatusMessage:
    tr_started = 'Started'
    tr_data_download_started = 'Data Download Started'
    tr_data_downloaded = 'Data Downloaded Successfully'
    tr_augmentation_completed = 'Augmentation Completed'
    tr_init_training = 'Initiating Training'
    tr_in_progress = 'Training In-Progress'
    tr_completed = 'Training Completed'
    tr_data_upload_started = 'Data Upload Started'
    tr_data_uploaded = 'Data Uploaded Successfully'
    tr_failed = 'Training Failed'


class Mongo:
    mongo_uri = config['MONGO']['mongo_uri']
    mongo_db = config['MONGO']['mongo_db']


class job:
    job_id = config['JOB']['job_id']
    job_collection = config['JOB']['collection']


class ReqTimeZone:
    required_tz = "Asia/Kolkata"
import datetime
import os
import traceback

from azure.storage.blob import BlobServiceClient
from loguru import logger

from scripts.constants.app_configuration import BLOB_CONTAINER_NAME, BLOB_ACCOUNT_NAME, BLOB_ACCOUNT_KEY


class Blob_Uploader:
    def __init__(self, project, site, line, camera, local_path):
        self.project = project
        self.site = site
        self.line = line
        self.camera = camera
        self.blob_service_client = BlobServiceClient(f"https://{BLOB_ACCOUNT_NAME}.blob.core.windows.net",
                                                     credential=BLOB_ACCOUNT_KEY)
        self.local_path = local_path
        self.container = BLOB_CONTAINER_NAME

    def upload(self):
        try:
            # Date-stamp once so every file in this batch lands under the same prefix.
            now = datetime.datetime.now()
            blob_prefix = f'{self.project}/{self.site}/{self.line}/{self.camera}/' \
                          f'{now.day}_{now.month}_{now.year}/images'
            for i in os.listdir(self.local_path):
                upload_file_path = os.path.join(self.local_path, i)
                blob_client = self.blob_service_client.get_blob_client(container=self.container,
                                                                       blob=f'{blob_prefix}/{i}')
                with open(file=upload_file_path, mode="rb") as data:
                    blob_client.upload_blob(data)
            logger.info(f'Uploaded to Azure Storage with blob path: {blob_prefix}')
            return True
        except Exception as e:
            traceback.print_exc()
            logger.error(f"Failed to Push Files to blob Storage! : {str(e)}")
            return False
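
# Hedged usage sketch (identifiers are illustrative, not taken from a real job):
#
#     uploader = Blob_Uploader(project="JK_Cements", site="Chittorgarh",
#                              line=1, camera=1, local_path="./results")
#     if uploader.upload():
#         print("all files pushed")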
import datetime
import os
import traceback

from azure.storage.blob import BlobServiceClient
from loguru import logger

from scripts.constants.app_configuration import BLOB_CONTAINER_NAME, BLOB_ACCOUNT_NAME, BLOB_ACCOUNT_KEY


class BlobUtil:
    """
    Downloads and uploads data via Azure blob storage.
    """

    def __init__(self, project, site, line, camera):
        self.project = project
        self.site = site
        self.line = line
        self.camera = camera
        self.blob_service_client = BlobServiceClient(f"https://{BLOB_ACCOUNT_NAME}.blob.core.windows.net",
                                                     credential=BLOB_ACCOUNT_KEY)
        self.container = BLOB_CONTAINER_NAME

    def download(self, download_path, dest_path):
        try:
            download_file_path = os.path.join(dest_path, "dataset.zip")
            blob_client = self.blob_service_client.get_blob_client(container=self.container,
                                                                   blob=download_path)
            with open(download_file_path, "wb") as download_file:
                blob_client.download_blob().readinto(download_file)
            logger.info(f'Downloaded from Azure Storage with blob path: {download_path}')
            return True
        except Exception as e:
            traceback.print_exc()
            logger.error(f"Failed to Pull Files from blob Storage! : {str(e)}")
            return False

    def upload(self, local_path, file_name):
        try:
            upload_file_path = os.path.join(local_path, file_name)
            # Stamp the blob path once so hour/day cannot roll over mid-upload.
            now = datetime.datetime.now()
            blob_path = f'{self.project}/{self.site}/{self.line}/{self.camera}/' \
                        f'{now.hour}_{now.day}_{now.month}_{now.year}/{file_name}'
            blob_client = self.blob_service_client.get_blob_client(container=self.container,
                                                                   blob=blob_path)
            with open(file=upload_file_path, mode="rb") as data:
                blob_client.upload_blob(data)
            logger.info(f'Uploaded to Azure Storage with blob path: {blob_path}')
            return True
        except Exception as e:
            traceback.print_exc()
            logger.error(f"Failed to Push Files to blob Storage! : {str(e)}")
            return False
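
# Hedged usage sketch, mirroring the calls in app.py (local paths are illustrative):
#
#     util = BlobUtil("JK_Cements", "Chittorgarh", 1, 1)
#     util.download("JK_Cements/Chittorgarh/1/1/Annotated_data/camera_41_annotated_data.zip",
#                   "raw_dataset")
#     util.upload("runs/train/0", "best.pt")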
import os
import random
import shutil

import cv2
import numpy as np
from loguru import logger


class DataAugmentation:
    """
    Applies a configurable set of augmentations to a dataset.
    """

    def __init__(self, functions):
        # functions: list of {"property": <augmentation name>, "value": <parameter>}
        self.functions = functions

    def fill(self, img, h, w):
        # cv2.resize takes (width, height); restore the original image size.
        return cv2.resize(img, (w, h), interpolation=cv2.INTER_CUBIC)

    def horizontal_shift(self, img):
        for augmentation in self.functions:
            if "horizontal_shift" == augmentation['property']:
                ratio = augmentation['value']
                if ratio > 1 or ratio < 0:
                    print('Value should be less than 1 and greater than 0')
                    return img
                ratio = random.uniform(-ratio, ratio)
                h, w = img.shape[:2]
                to_shift = w * ratio
                if ratio > 0:
                    img = img[:, :int(w - to_shift), :]
                if ratio < 0:
                    img = img[:, int(-1 * to_shift):, :]
                img = self.fill(img, h, w)
        return img

    def vertical_shift(self, img):
        for augmentation in self.functions:
            if "vertical_shift" == augmentation['property']:
                ratio = augmentation['value']
                if ratio > 1 or ratio < 0:
                    print('Value should be less than 1 and greater than 0')
                    return img
                ratio = random.uniform(-ratio, ratio)
                h, w = img.shape[:2]
                to_shift = h * ratio
                if ratio > 0:
                    img = img[:int(h - to_shift), :, :]
                if ratio < 0:
                    img = img[int(-1 * to_shift):, :, :]
                img = self.fill(img, h, w)
        return img

    def brightness(self, img):
        for augmentation in self.functions:
            if "brightness" == augmentation['property']:
                low = augmentation['value']
                value = random.uniform(low, low + 2.5)
                hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
                hsv = np.array(hsv, dtype=np.float64)
                # Scale saturation and value channels, clipping at 255.
                hsv[:, :, 1] = np.clip(hsv[:, :, 1] * value, 0, 255)
                hsv[:, :, 2] = np.clip(hsv[:, :, 2] * value, 0, 255)
                img = cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)
        return img

    def channel_shift(self, img):
        for augmentation in self.functions:
            if "channel_shift" == augmentation['property']:
                value = int(random.uniform(-augmentation['value'], augmentation['value']))
                # Shift in a wider dtype so the result cannot wrap around uint8.
                img = np.clip(img.astype(np.int16) + value, 0, 255).astype(np.uint8)
        return img

    def zoom(self, img):
        for augmentation in self.functions:
            if "zoom" == augmentation['property']:
                value = augmentation['value']
                if value > 1 or value < 0:
                    print('Value for zoom should be less than 1 and greater than 0')
                    return img
                value = random.uniform(value, 1)
                h, w = img.shape[:2]
                h_taken = int(value * h)
                w_taken = int(value * w)
                h_start = random.randint(0, h - h_taken)
                w_start = random.randint(0, w - w_taken)
                img = img[h_start:h_start + h_taken, w_start:w_start + w_taken, :]
                img = self.fill(img, h, w)
        return img

    def horizontal_flip(self, img):
        for augmentation in self.functions:
            if "horizontal_flip" == augmentation['property'] and augmentation['value']:
                return cv2.flip(img, 1)
        return img

    def vertical_flip(self, img):
        for augmentation in self.functions:
            if "vertical_flip" == augmentation['property'] and augmentation['value']:
                return cv2.flip(img, 0)
        return img

    def rotation(self, img):
        for augmentation in self.functions:
            if "rotation" == augmentation['property']:
                angle = int(random.uniform(-augmentation['value'], augmentation['value']))
                h, w = img.shape[:2]
                M = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1)
                img = cv2.warpAffine(img, M, (w, h))
        return img

    def process(self, annotation_directory, post_process_directory):
        assert os.path.exists(annotation_directory)
        if not os.path.exists(post_process_directory):
            logger.info(f"Path: {post_process_directory} does not exist, creating one now!")
            os.mkdir(post_process_directory)
        for each_file in os.listdir(annotation_directory):
            stem, file_extension = os.path.splitext(each_file)
            if file_extension in ['.jpg', '.jpeg', '.png']:
                image = cv2.imread(os.path.join(annotation_directory, each_file))
                multi_images = (
                    self.horizontal_shift(image), self.vertical_shift(image), self.brightness(image),
                    self.zoom(image), self.channel_shift(image), self.horizontal_flip(image),
                    self.vertical_flip(image), self.rotation(image))
                for index, augmented in enumerate(multi_images):
                    cv2.imwrite(os.path.join(post_process_directory, f"{stem}_{index}.jpg"), augmented)

    def combine_dataset(self, augmented_dataset_path, unaugmented_dataset_path, cls):
        # Copy the original class images in alongside their augmented variants.
        for each_file in os.listdir(os.path.join(unaugmented_dataset_path, cls)):
            shutil.copy(os.path.join(unaugmented_dataset_path, cls, each_file),
                        os.path.join(augmented_dataset_path, cls))
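
# Hedged usage sketch: the functions list mirrors what app.py builds from the
# job's general_augmentations config (the class folder names are illustrative):
#
#     aug = DataAugmentation([{"property": "brightness", "value": 0.8},
#                             {"property": "horizontal_flip", "value": 0}])
#     aug.process(annotation_directory="unaugmented_dataset/ok",
#                 post_process_directory="augmented_dataset/ok")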
import os
import shutil
from zipfile import ZipFile


class ExtractDataset:
    """
    Extracts the dataset archive and flattens its top-level folder.
    """

    def __init__(self):
        self.dest_path = None
        self.dataset_path = None

    def extract_ds(self, dataset_path, dest_path):
        self.dataset_path = dataset_path
        self.dest_path = dest_path
        with ZipFile(self.dataset_path, 'r') as zObject:
            zObject.extractall(path=dest_path)

    def move_files(self, dataset_path, dest_path):
        # Assumes the archive contains a single top-level folder whose
        # sub-folders are the class directories.
        self.dataset_path = dataset_path
        self.dest_path = dest_path
        dataset_outer_path = os.listdir(self.dataset_path)
        classes = os.listdir(os.path.join(self.dataset_path, dataset_outer_path[0]))
        for cls in classes:
            shutil.move(os.path.join(self.dataset_path, dataset_outer_path[0], cls),
                        os.path.join(self.dest_path, cls))
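
# Hedged usage sketch (paths are illustrative):
#
#     extract = ExtractDataset()
#     extract.extract_ds("raw_dataset/dataset.zip", "extracted_dataset")
#     extract.move_files("extracted_dataset", "unaugmented_dataset")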
import os
import random
import shutil


class TrainTestSplit:
    """
    Splits the dataset into three folders:
        train
        valid
        test
    """

    def __init__(self, dataset_path, job_id, validation_size=0.20, test_size=0.10):
        self.path = dataset_path
        self.folder_train_name = os.path.join(job_id, "train")
        self.folder_valid_name = os.path.join(job_id, "valid")
        self.folder_test_name = os.path.join(job_id, "test")
        self.validation_size = validation_size
        self.test_size = test_size
        for folder in (self.folder_train_name, self.folder_valid_name, self.folder_test_name):
            if not os.path.exists(folder):
                os.mkdir(folder)

    def train_test_split(self):
        main_dir = os.listdir(self.path)
        # Count files without loading the images into memory.
        count = sum(len(os.listdir(os.path.join(self.path, each))) for each in main_dir)
        for each_dir in main_dir:
            os.mkdir(os.path.join(self.folder_train_name, each_dir))
            os.mkdir(os.path.join(self.folder_test_name, each_dir))
            os.mkdir(os.path.join(self.folder_valid_name, each_dir))
            files = os.listdir(os.path.join(self.path, each_dir))
            # Shuffle so the split does not depend on directory listing order.
            random.shuffle(files)
            train_per = round(len(files) * (1 - (self.validation_size + self.test_size)))
            valid_per = round(len(files) * self.validation_size)
            for every_file in files[:train_per]:
                shutil.copyfile(os.path.join(self.path, each_dir, every_file),
                                os.path.join(self.folder_train_name, each_dir, every_file))
            for every_file in files[train_per:train_per + valid_per]:
                shutil.copyfile(os.path.join(self.path, each_dir, every_file),
                                os.path.join(self.folder_valid_name, each_dir, every_file))
            for every_file in files[train_per + valid_per:]:
                shutil.copyfile(os.path.join(self.path, each_dir, every_file),
                                os.path.join(self.folder_test_name, each_dir, every_file))
        return self.folder_train_name, self.folder_valid_name, self.folder_test_name, count
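
# Hedged usage sketch, mirroring the call in app.py (job_id from config.env):
#
#     splitter = TrainTestSplit("augmented_dataset", "05_10_23_132322",
#                               validation_size=0.20, test_size=0.10)
#     train_dir, valid_dir, test_dir, n_images = splitter.train_test_split()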