Commit 1eb75421 authored by Sikhin VC

initial commit

parent 174a8e28
FROM python:3.8
RUN apt-get update && apt-get install -y tzdata openssl ffmpeg libsm6 libxext6
ADD . /app
WORKDIR /app
RUN pip install -r requirements.txt
CMD [ "python3","app.py" ]
# vgg16_training_pipeline
# poc-vision-active-learning-off-prem
Project repo for the VGG16 classifier training pipeline (off-prem vision active-learning POC): it downloads an annotated dataset from Azure Blob Storage, extracts and augments it, splits it into train/valid/test sets, fine-tunes a VGG16 model, logs the run to MLflow, and uploads the results back to blob storage.
if __name__ == "__main__":
from dotenv import load_dotenv
load_dotenv(dotenv_path='config.env')
import os.path
import shutil
import sys
import traceback
import warnings
from pathlib import Path
from pymongo import MongoClient
from loguru import logger
from scripts.constants.app_configuration import Mongo, job, StatusMessage
from utils.other.blob_util import BlobUtil
from utils.other.mlflow_vam_util import ModelReTrainer
from utils.preprocessing.dataset_augmentation import DataAugmentation
from utils.preprocessing.dataset_extraction import ExtractDataset
from utils.preprocessing.train_test_split_classifier import TrainTestSplit
warnings.filterwarnings("ignore")
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0] # YOLOv5 root directory
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT)) # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
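# LOCAL_RANK / RANK / WORLD_SIZE mirror the environment variables set by torchrun for distributed
# training; with the defaults (-1, -1, 1) the pipeline runs as a single process. They appear to be
# carried over from the YOLOv5 training scaffold referenced above and are not used elsewhere in this script.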
def main():
    # Bind the names that the except-block cleanup relies on before entering the try,
    # so an early failure cannot raise NameError during cleanup.
    result_directory = None
    job_id = job.job_id
    job_collection = job.job_collection
    try:
        mmongo = MongoClient(Mongo.mongo_uri)
        mongo_client = mmongo[Mongo.mongo_db]
master_conf = mongo_client[job_collection].find_one({"job_id": job_id})
os.mkdir(job_id)
root_data_path = os.path.join(job_id, 'temp')
raw_dataset_path = os.path.join(root_data_path, "raw_dataset")
        extracted_dataset_path = os.path.join(root_data_path, "extracted_dataset")
unaugmented_dataset_path = os.path.join(root_data_path, "unaugmented_dataset")
augmented_dataset_path = os.path.join(root_data_path, "augmented_dataset")
split_dataset_path = os.path.join(root_data_path, "split_dataset")
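        # Working layout created under the job id (the temp tree is removed once the split is done):
        #   <job_id>/temp/raw_dataset          downloaded dataset.zip
        #   <job_id>/temp/extracted_dataset    unzipped archive
        #   <job_id>/temp/unaugmented_dataset  one folder per class, as shipped
        #   <job_id>/temp/augmented_dataset    augmented copies plus the originals
        #   <job_id>/temp/split_dataset        created but not otherwise used in this script; the
        #                                      split itself is written to <job_id>/train|valid|test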
if not os.path.exists(root_data_path):
os.mkdir(root_data_path)
if not os.path.exists(raw_dataset_path):
os.mkdir(raw_dataset_path)
if not os.path.exists(extracted_dataset_path):
os.mkdir(extracted_dataset_path)
if not os.path.exists(unaugmented_dataset_path):
os.mkdir(unaugmented_dataset_path)
if not os.path.exists(augmented_dataset_path):
os.mkdir(augmented_dataset_path)
if not os.path.exists(split_dataset_path):
os.mkdir(split_dataset_path)
logger.info("Starting Vision Data Accusation Pipeline")
blob_list = master_conf['blob_path']
        logger.info(f'Starting acquisition of Stream {blob_list}')
blob_util = BlobUtil(master_conf['project_id'], master_conf['Site_id'],
master_conf['line_id'], master_conf['camera_id'])
mongo_client[job.job_collection].update_one({'job_id': job_id}, {
'$set': {"job_status": StatusMessage.tr_data_download_started, 'progress': 5.0}}, upsert=False)
if blob_download_status := blob_util.download(master_conf['blob_path'], raw_dataset_path):
logger.info("Data Downloaded Successfully!")
mongo_client[job.job_collection].update_one({'job_id': job_id}, {
'$set':
{"job_status": StatusMessage.tr_data_downloaded,
'progress': 10.0}},
upsert=False)
extract = ExtractDataset()
#raw_dataset_path = Path("data_serin")
#extract.extract_ds(os.path.join(raw_dataset_path, "serin_dataset.zip"), extracted_dataset_path)
extract.extract_ds(os.path.join(raw_dataset_path, "dataset.zip"), extracted_dataset_path)
extract.move_files(extracted_dataset_path, unaugmented_dataset_path)
class_names = os.listdir(unaugmented_dataset_path)
logger.info('Initiating Augmentations of Dataset!!')
general_augmentation_list = master_conf['general_augmentations'].keys()
augmentation_functions = []
for augmentation in general_augmentation_list:
                augmentation_value = master_conf['general_augmentations'][augmentation]
augmentation_functions.extend(
{"property": augmentation, "value": val} for val in augmentation_value)
logger.info(
f'Starting Data augmentation with type {augmentation} and Value {augmentation_value}')
classification_data_aug_manager = DataAugmentation(augmentation_functions)
for cls in class_names:
classwise_unaugment_dataset_path = os.path.join(unaugmented_dataset_path, cls)
os.mkdir(os.path.join(augmented_dataset_path, cls))
classification_data_aug_manager.process(annotation_directory=classwise_unaugment_dataset_path,
post_process_directory=os.path.join(augmented_dataset_path,
cls))
mongo_client[job.job_collection].update_one({'job_id': job_id}, {
'$set':
{"job_status": StatusMessage.tr_augmentation_completed,
'progress': 20.0}},
upsert=False)
classification_data_aug_manager.combine_dataset(augmented_dataset_path=augmented_dataset_path,
unaugmented_dataset_path=unaugmented_dataset_path, cls=cls)
train_test_split_obj = TrainTestSplit(augmented_dataset_path, job_id,
validation_size=0.20,
test_size=0.10)
training_path, validation_path, test_path, dataset_size = train_test_split_obj.train_test_split()
shutil.rmtree(f'./{root_data_path}')
master_conf['Dataset_size'] = dataset_size
mongo_client[job.job_collection].update_one({'job_id': job_id}, {
'$set':
{"job_status": StatusMessage.tr_init_training,
'progress': 35.0}},
upsert=False)
model_trainer = ModelReTrainer(master_conf['project_id'],
master_conf['Site_id'],
master_conf['line_id'],
master_conf['camera_id'], training_path,
validation_path, job_id, master_conf)
result_directory = model_trainer.start_training()
mongo_client[job.job_collection].update_one({'job_id': job_id}, {
'$set': {"job_status": StatusMessage.tr_data_upload_started, 'progress': 90.0}}, upsert=False)
blob_util = BlobUtil(master_conf['project_id'], master_conf['Site_id'],
master_conf['line_id'], master_conf['camera_id'])
logger.info("Uploading Files!!!")
for each_file in os.listdir(result_directory):
if not blob_util.upload(result_directory, f'{each_file}'):
logger.info("Data Upload Unsuccessfull!")
logger.info("Data Uploaded Successfully!")
mongo_client[job.job_collection].update_one({'job_id': job_id}, {
'$set':
{"job_status": StatusMessage.tr_data_uploaded,
'progress': 100.0}},
upsert=False)
shutil.rmtree(job_id)
shutil.rmtree(result_directory)
    except Exception:
        traceback.print_exc()
        if os.path.exists(job_id):
            shutil.rmtree(job_id)
        if result_directory is not None and os.path.exists(result_directory):
            shutil.rmtree(result_directory)
mongo_client[job.job_collection].update_one({'job_id': job_id}, {
'$set':
{"job_status": StatusMessage.tr_failed,
'progress': 35.0}},
upsert=False)
mmongo.close()
if __name__ == '__main__':
logger.info("Attempting to Start Training Pipeline!")
try:
main()
except Exception as e:
traceback.print_exc()
logger.error(f"Failed to Start Training Pipeline! : {str(e)}")
[BLOB_STORAGE]
account_name=$ACCOUNT_NAME
account_key=$ACCOUNT_KEY
container_name=$CONTAINER_NAME
[MLFLOW]
mlflow_tracking_uri=$MLFLOW_TRACKING_URI
mlflow_tracking_username=$MLFLOW_TRACKING_USERNAME
mlflow_tracking_password=$MLFLOW_TRACKING_PASSWORD
azure_storage_connection_string=$AZURE_STORAGE_CONNECTION_STRING
azure_storage_access_key=$AZURE_STORAGE_ACCESS_KEY
user=$USER
experiment_name=$EXPERIMENT_NAME
run_name=$RUN_NAME
model_name=$MODEL_NAME
check_param=$CHECK_PARAM
model_check_param=$MODEL_CHECK_PARAM
total_models_needed=$TOTAL_MODELS_NEEDED
[MONGO]
mongo_uri=$MONGO_URI
mongo_db=$MONGO_DATABASE
[JOB]
job_id=$JOB_ID
collection=$COLLECTION
master_config:
project: 'JK_Cements'
site: 'Chittorgarh'
line: 1
camera: 1
blob_path:
"JK_Cements/Chittorgarh/1/1/Annotated_data/camera_41_annotated_data.zip"
data_path:
"jk_data"
augmentation_types:
general_augmentations:
blur:
value: 0.3
sepia:
value: 0.3
noise:
value: 0.1
cutout:
value: 0.1
horizontal_flip:
value: 0
grayscale:
value: 0.4
hue:
value: 1
saturation:
value: 1
brightness:
value: 0.8
exposure:
value: 0.9
vertical_flip:
value: 0
bounding_box_level_augmentations:
horizontal_flip:
value: 0
random_translate:
value: 0.3
rotate:
value: [ 90, 180, 270 ]
random_shear:
value: 0.2
resize:
value: 608
random_hsv:
value: [ 100, 100, 100 ]
training_params:
weights: 'models/jk_v5_cam_44.pt' #'initial weights path'
cfg: 'cfg/yolov5s.yaml' #help='model.yaml path'
data: 'hyp/jk_data.yaml' # help='hyp.yaml path'
    hyp: 'hyp/hyp.scratch.yaml' # help='hyperparameters path'
epochs: 40
batch_size: 16 # help='total batch size for all GPUs'
imgsz: 416 # help='[train, test] image sizes'
rect: False # help='rectangular training')
resume: False # help='resume most recent training')
nosave: False # help='only save final checkpoint')
notest: False # help='only test final epoch')
noautoanchor: False # help='disable autoanchor check')
evolve: False # help='evolve hyperparameters')
bucket: '' # help='gsutil bucket')
cache: False # help='cache images for faster training')
image_weights: False # help='use weighted image selection for training')
device: 'cpu' # help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
multi_scale: False # help='vary img-size +/- 50%%')
single_cls: False # help='train multi-class hyp as single-class')
adam: False # help='use torch.optim.Adam() optimizer')
sync-bn: False # help='use SyncBatchNorm, only available in DDP mode')
local_rank: -1 # help='DDP parameter, do not modify')
workers: 8 # help='maximum number of dataloader workers')
project: 'runs/train' # help='save to project/name')
entity: None # help='W&B entity')
name: '0' # help='save to project/name')
exist_ok: False # help='existing project/name ok, do not increment')
quad: False # help='quad dataloader')
linear_lr: False # help='linear LR')
label_smoothing: 0.0 # help='Label smoothing epsilon')
upload_dataset: False # help='Upload dataset as W&B artifact table')
bbox_interval: -1 # help='Set bounding-box image logging interval for W&B')
save_period: -1 # help='Log model after every "save_period" epoch')
artifact_alias: "latest" # help='version of dataset artifact to be used')
freeze: 0 # help='Freeze layers: backbone of yolov7=50, first3=0 1 2')
v5_metric: False # help='assume maximum recall as 1.0 in AP calculation')
patience: 100
noval: False
wts_conversion_params:
    output: 'models/best.wts'
ACCOUNT_NAME = azrmlilensqa006382180551
ACCOUNT_KEY = tDGOKfiZ2svfoMvVmS0Fbpf0FTHfTq4wKYuDX7cAxlhve/3991QuzdvJHm9vWc+lo6mtC+x9yPSghWNR4+gacg==
CONTAINER_NAME = vision-app-videos
MLFLOW_TRACKING_URI=http://192.168.2.147:50000/
AZURE_STORAGE_CONNECTION_STRING=DefaultEndpointsProtocol=https;AccountName=azrmlilensqa006382180551;AccountKey=tDGOKfiZ2svfoMvVmS0Fbpf0FTHfTq4wKYuDX7cAxlhve/3991QuzdvJHm9vWc+lo6mtC+x9yPSghWNR4+gacg==;EndpointSuffix=core.windows.net
AZURE_STORAGE_ACCESS_KEY=tDGOKfiZ2svfoMvVmS0Fbpf0FTHfTq4wKYuDX7cAxlhve/3991QuzdvJHm9vWc+lo6mtC+x9yPSghWNR4+gacg==
USER=Dalmia_degradation
EXPERIMENT_NAME=Solar-String-Level-Degradation-Model-Factory
RUN_NAME=ObjectDetection
MODEL_NAME=versioning
CHECK_PARAM=hours
MODEL_CHECK_PARAM=480
MONGO_URI=mongodb://admin:iLensVisMongo%23723@192.168.2.147:2717
MONGO_DATABASE=admin
JOB_ID=05_10_23_132322
COLLECTION=job
absl-py==1.4.0
albumentations==1.1.0
alembic==1.11.1
anyio==3.7.1
azure-core==1.26.3
azure-identity==1.12.0
azure-storage-blob==12.15.0
blinker==1.6.2
cachetools==5.2.0
certifi==2021.10.8
cffi==1.15.1
charset-normalizer==3.2.0
click==8.1.6
cloudpickle==2.2.1
colorama==0.4.6
contourpy==1.1.0
cryptography==41.0.2
cycler==0.11.0
databricks-cli==0.17.7
dnspython==2.4.1
docker==6.1.3
entrypoints==0.4
exceptiongroup==1.1.2
fastapi==0.100.1
Flask==2.3.2
fonttools==4.41.1
gitdb==4.0.10
GitPython==3.1.32
google-auth==2.22.0
google-auth-oauthlib==0.4.6
greenlet==2.0.2
grpcio==1.56.2
h11==0.14.0
idna==3.4
imageio==2.31.1
importlib-metadata==6.8.0
importlib-resources==6.0.0
isodate==0.6.1
itsdangerous==2.1.2
Jinja2==3.1.2
joblib==1.3.1
kiwisolver==1.4.4
llvmlite==0.40.1
loguru==0.6.0
Mako==1.2.4
Markdown==3.4.4
MarkupSafe==2.1.3
matplotlib==3.7.2
mlflow==2.2.2
mongoengine==0.27.0
msal==1.23.0
msal-extensions==1.0.0
networkx==3.1
numba==0.57.1
numpy==1.24.4
oauthlib==3.2.2
opencv-python==4.6.0.66
opencv-python-headless==4.8.0.74
packaging==23.1
pandas==1.3.4
Pillow==8.4.0
portalocker==2.7.0
protobuf==3.20.3
pyarrow==11.0.0
pyasn1==0.5.0
pyasn1-modules==0.3.0
pycparser==2.21
pydantic==1.8.2
PyJWT==2.8.0
pymongo==4.4.1
pyparsing==3.0.9
python-dateutil==2.8.2
python-dotenv==0.19.0
pytz==2022.7.1
PyWavelets==1.4.1
#pywin32==306
PyYAML==6.0
qudida==0.0.4
querystring-parser==1.2.4
requests==2.31.0
requests-oauthlib==1.3.1
rsa==4.9
scikit-image==0.19.3
scikit-learn==1.3.0
scipy==1.10.1
seaborn==0.12.2
shap==0.42.1
six==1.16.0
slicer==0.0.7
smmap==5.0.0
sniffio==1.3.0
SQLAlchemy==2.0.19
sqlparse==0.4.4
starlette==0.27.0
tabulate==0.9.0
tensorboard==2.11.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
threadpoolctl==3.2.0
tifffile==2023.7.10
torch==1.13.0
torchaudio==0.13.0
torchvision==0.14.0
tqdm==4.64.1
typing_extensions==4.7.1
urllib3==1.26.16
uvicorn==0.23.2
waitress==2.1.2
websocket-client==1.6.1
Werkzeug==2.3.6
win32-setctime==1.1.0
zipp==3.16.2
import os
import os.path
import sys
from configparser import ConfigParser, BasicInterpolation
import yaml
master_configuration_file = r"./conf/master_config.yml"
class EnvInterpolation(BasicInterpolation):
"""
Interpolation which expands environment variables in values.
"""
    def before_get(self, parser, section, option, value, defaults):
        value = super().before_get(parser, section, option, value, defaults)
        expanded = os.path.expandvars(value)
        # If the referenced environment variable is unset, expandvars leaves the
        # "$VAR" placeholder untouched; treat that as a missing value.
        return expanded if not expanded.startswith("$") else None
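# Illustrative example (assuming ACCOUNT_NAME has been exported, e.g. via load_dotenv('config.env') in app.py):
#   application.conf:  account_name=$ACCOUNT_NAME
#   config["BLOB_STORAGE"]["account_name"]  ->  the value of ACCOUNT_NAME, or None if it is unset.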
try:
config = ConfigParser(interpolation=EnvInterpolation())
config.read("conf/application.conf")
except Exception as e:
print(f"Error while loading the config: {e}")
print("Failed to Load Configuration. Exiting!!!")
sys.exit()
class Logging:
level = config.get("LOGGING", "level", fallback="INFO")
level = level or "INFO"
tb_flag = config.getboolean("LOGGING", "traceback", fallback=True)
tb_flag = tb_flag if tb_flag is not None else True
BLOB_ACCOUNT_NAME = config["BLOB_STORAGE"]["account_name"]
BLOB_ACCOUNT_KEY = config["BLOB_STORAGE"]["account_key"]
BLOB_CONTAINER_NAME = config["BLOB_STORAGE"]["container_name"]
class MlFlow:
mlflow_tracking_uri = config['MLFLOW']['mlflow_tracking_uri']
# mlflow_tracking_username = config['MLFLOW']['mlflow_tracking_username']
# mlflow_tracking_password = config['MLFLOW']['mlflow_tracking_password']
azure_storage_connection_string = config['MLFLOW']['azure_storage_connection_string']
azure_storage_access_key = config['MLFLOW']['azure_storage_access_key']
user = config['MLFLOW']['user']
experiment_name = config['MLFLOW']['experiment_name']
run_name = config['MLFLOW']['run_name']
model_name = config['MLFLOW']['model_name']
check_param = config['MLFLOW']['check_param']
model_check_param = config['MLFLOW']['model_check_param']
class StatusMessage:
    tr_data_uploaded = 'Data Uploaded Successfully'
tr_data_upload_started = 'Data Upload Started'
tr_started = 'Started'
tr_data_download_started = 'Data Download Started'
tr_data_downloaded = 'Data Downloaded Successfully'
tr_augmentation_completed = 'Augmentation Completed'
tr_init_training = 'Initiating Training'
tr_in_progress = 'Training In-Progress'
tr_completed = 'Training Completed'
tr_failed = 'Training Failed'
class Mongo:
mongo_uri = config['MONGO']['mongo_uri']
mongo_db = config['MONGO']['mongo_db']
class job:
job_id = config['JOB']['job_id']
job_collection = config['JOB']['collection']
class ReqTimeZone:
required_tz = "Asia/Kolkata"
from __future__ import print_function, division
import json
import shutil
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from loguru import logger
from pymongo import MongoClient
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torchvision import datasets, models, transforms
from torch.utils.data.dataloader import DataLoader
import time
import os
import copy
from scripts.constants.app_configuration import Mongo, job, StatusMessage
class VGG16Training:
"""
Entire code for vgg16
"""
def __init__(self, master_config, training_path, validation_path, job_id):
self.map_50_90 = []
self.map_50 = []
self.recall = []
self.precision = []
self.mongo_client = MongoClient(Mongo.mongo_uri)[Mongo.mongo_db]
self.clsloss = []
self.objloss = []
self.bbloss = []
self.master_config = master_config
self.pbar = 0
self.mloss = 0
self.current_epoch = 0
self.status = 0
self.classes = {'nc': len(master_config['classes'].split(',')), 'names': master_config['classes'].split(','),
'train': 'train/', 'val': 'valid/'}
self.epoch = 0
self.avg_loss = []
self.avg_acc = []
self.avg_loss_val = []
self.avg_acc_val = []
self.training_path = training_path
self.validation_path = validation_path
self.data_dir = job_id
self.image_datasets = self.storing_image_datasets()
self.dataset_sizes = self.get_dataset_size(self.image_datasets)
self.inputs, self.classes = next(iter(self.dataloaders("train")))
self.class_names = self.image_datasets['train'].classes
self.vgg16 = models.vgg16(pretrained=True)
for param in self.vgg16.features.parameters():
            param.requires_grad = False  # freeze the convolutional feature extractor
self.use_gpu = torch.cuda.is_available()
if self.use_gpu:
logger.info("Using CUDA")
self.vgg16.cuda() # .cuda() will move everything to the GPU side
else:
logger.info("Not Using CUDA")
self.criterion = nn.CrossEntropyLoss()
self.optimizer_ft = optim.SGD(self.vgg16.parameters(), lr=0.001, momentum=0.9)
self.exp_lr_scheduler = lr_scheduler.StepLR(self.optimizer_ft, step_size=7, gamma=0.1)
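        # With the feature extractor frozen above (requires_grad=False), only the classifier layers
        # accumulate gradients; SGD skips parameters whose .grad stays None, so passing all of
        # vgg16.parameters() to the optimizer is effectively harmless.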
self.num_epochs = self.master_config['training_params'].get("epochs")
def update_epoch(self, epoch_result, epochs_total):
"""
Update each epoch after its finished
"""
logger.info("Updating Epoch!!!")
self.epoch = int(epoch_result.get('Epoch'))
self.avg_loss.append(epoch_result.get('avg_loss_train'))
self.avg_acc.append(epoch_result.get('avg_acc_train'))
self.avg_loss_val.append(epoch_result.get('avg_loss_val'))
self.avg_acc_val.append(epoch_result.get('avg_acc_val'))
self.mongo_client[job.job_collection].update_one({'job_id': job.job_id}, {
'$set':
{"job_status": StatusMessage.tr_in_progress,
'progress': float(
35.0 + ((55 / epochs_total) * int(self.epoch)))
, 'avg_loss_train': self.avg_loss, 'avg_acc_train': self.avg_acc,
'avg_loss_val': self.avg_loss_val,
'avg_acc_val': self.avg_acc_val}},
upsert=False)
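        # Worked example: assuming epochs_total=40 (the epochs value in the sample master_config)
        # and epoch=20, progress = 35.0 + (55 / 40) * 20 = 62.5.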
def data_augmentation_transforms(self, name):
"""
Returns a composed data augmentation transformation pipeline.
Returns:
A torchvision.transforms.Compose object representing the data augmentation pipeline.
"""
if name == "train":
return transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])
if name == "valid":
return transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
])
# if name == self.test:
# return transforms.Compose([
# transforms.Resize(256),
# transforms.CenterCrop(224),
# transforms.ToTensor(),
# ])
def storing_image_datasets(self):
image_datasets = {} # Initialize an empty dictionary to store the datasets
for each in os.listdir(self.data_dir):
dataset = datasets.ImageFolder(os.path.join(self.data_dir, each),
transform=self.data_augmentation_transforms(each))
image_datasets[each] = dataset # Store the dataset in the dictionary using 'train' or 'valid' as keys
return image_datasets
def dataloaders(self, x, batch_size=8, shuffle=True, num_workers=4):
"""
Create data loaders for training and validation datasets.
"""
dataloaders = {}
dataloaders[x] = DataLoader(
self.image_datasets[x], batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
)
return dataloaders[x]
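    # Illustrative usage (note that a new DataLoader is constructed on every call):
    #   for inputs, labels in self.dataloaders("train", batch_size=8):
    #       ...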
def get_dataset_size(self, image_datasets):
dataset_sizes = {x: len(image_datasets[x]) for x in os.listdir(self.data_dir)}
return dataset_sizes
# def imshow(self, inp, title=None):
# inp = inp.numpy().transpose((1, 2, 0))
# plt.axis('off')
# plt.imshow(inp)
# if title is not None:
# plt.title(title)
# plt.pause(0)
#
# def show_databatch(self, inputs, classes):
# out = torchvision.utils.make_grid(inputs)
# self.imshow(out, title=[self.class_names[x] for x in classes])
def train_model(self):
"""
The code for training a model
"""
logger.info('Initiating Training!!')
since = time.time()
best_model_wts = copy.deepcopy(self.vgg16.state_dict())
best_acc = 0.0
avg_loss = 0
avg_acc = 0
avg_loss_val = 0
avg_acc_val = 0
train_batches = len(self.dataloaders("train"))
val_batches = len(self.dataloaders("valid"))
results_list = []
for epoch in range(self.num_epochs):
logger.info("Epoch {}/{}".format(epoch, self.num_epochs))
logger.info('-' * 10)
loss_train = 0
loss_val = 0
acc_train = 0
acc_val = 0
self.vgg16.train(True)
for i, data in enumerate(self.dataloaders("train")):
if i % 100 == 0:
logger.info("Training batch {}/{}".format(i, train_batches / 2), end='', flush=True)
# Use half training dataset
if i >= train_batches / 2:
break
                inputs, labels = data
                if self.use_gpu:
                    inputs, labels = inputs.cuda(), labels.cuda()
self.optimizer_ft.zero_grad()
outputs = self.vgg16(inputs)
_, preds = torch.max(outputs.data, 1)
loss = self.criterion(outputs, labels)
loss.backward()
self.optimizer_ft.step()
loss_train += loss.data
acc_train += torch.sum(preds == labels.data)
del inputs, labels, outputs, preds
torch.cuda.empty_cache()
# * 2 as we only used half of the dataset
avg_loss = loss_train * 2 / self.dataset_sizes['train']
avg_acc = acc_train * 2 / self.dataset_sizes['train']
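            # Each epoch only iterates over roughly half of the training batches (the break above),
            # so the accumulated loss/accuracy are doubled before dividing by the full dataset size
            # to approximate per-sample averages.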
self.vgg16.train(False)
self.vgg16.eval()
for i, data in enumerate(self.dataloaders("valid")):
if i % 100 == 0:
logger.info("\rValidation batch {}/{}".format(i, val_batches), end='', flush=True)
                inputs, labels = data
                if self.use_gpu:
                    inputs, labels = inputs.cuda(), labels.cuda()
                # Validation forward pass; gradients are not needed here.
                with torch.no_grad():
                    outputs = self.vgg16(inputs)
                    _, preds = torch.max(outputs.data, 1)
                    loss = self.criterion(outputs, labels)
loss_val += loss.data
acc_val += torch.sum(preds == labels.data)
del inputs, labels, outputs, preds
torch.cuda.empty_cache()
avg_loss_val = loss_val / self.dataset_sizes['valid']
avg_acc_val = acc_val / self.dataset_sizes['valid']
logger.info("Epoch {} result: ".format(epoch))
logger.info("avg_loss_train: {:.4f}".format(avg_loss))
logger.info("avg_acc_train: {:.4f}".format(avg_acc))
logger.info("avg_loss_val: {:.4f}".format(avg_loss_val))
logger.info("avg_acc_val: {:.4f}".format(avg_acc_val))
logger.info('-' * 10)
# Initialize an empty dictionary to store the results for each epoch
epoch_result = {
"Epoch": format(epoch),
"avg_loss_train": format(avg_loss),
"avg_acc_train": format(avg_acc),
"avg_loss_val": format(avg_loss_val),
"avg_acc_val": format(avg_acc_val)
}
self.update_epoch(epoch_result, self.num_epochs)
results_list.append(epoch_result)
# Print the final dictionary
if avg_acc_val > best_acc:
best_acc = avg_acc_val
best_model_wts = copy.deepcopy(self.vgg16.state_dict())
elapsed_time = time.time() - since
logger.info("Training completed in {:.0f}m {:.0f}s".format(elapsed_time // 60, elapsed_time % 60))
logger.info("Best acc: {:.4f}".format(best_acc))
training_result = {"Training completed in": "{:.0f}m {:.0f}s".format(elapsed_time // 60, elapsed_time % 60),
"Best acc": "{:.4f}".format(best_acc)}
self.mongo_client[job.job_collection].update_one({'job_id': job.job_id}, {
'$set': {"job_status": StatusMessage.tr_completed, 'progress': 85.0}}, upsert=False)
results_list.append(training_result)
self.vgg16.load_state_dict(best_model_wts)
if os.path.exists("vgg_results"):
shutil.rmtree("vgg_results")
logger.info('Deleting existing directory!!')
result_directory = Path("vgg_results")
if not os.path.exists(result_directory):
os.mkdir(result_directory)
torch.save(self.vgg16, 'vgg_results/VGG16_model_complete.pt')
torch.save(self.vgg16.state_dict(), 'vgg_results/VGG16_model_State_dict.pt')
results_json_file = os.path.join(result_directory, 'output_data.json')
with open(results_json_file, 'w') as file:
json.dump(results_list, file)
return results_list, result_directory
import os
from loguru import logger
import datetime
import traceback
from azure.storage.blob import BlobServiceClient
from scripts.constants.app_configuration import BLOB_CONTAINER_NAME, BLOB_ACCOUNT_NAME, BLOB_ACCOUNT_KEY
class Blob_Uploader:
def __init__(self, project, site, line, camera, local_path):
self.project = project
self.site = site
self.line = line
self.camera = camera
self.blob_service_client = BlobServiceClient(f"https://{BLOB_ACCOUNT_NAME}.blob.core.windows.net",
credential=BLOB_ACCOUNT_KEY)
# self.local_path = "./hyp"
self.local_path = local_path
self.container = BLOB_CONTAINER_NAME
def upload(self):
try:
for i in os.listdir(self.local_path):
upload_file_path = os.path.join(self.local_path, i)
blob_client = self.blob_service_client.get_blob_client(container=self.container,
blob=f'{self.project}/{self.site}/{self.line}/'
f'{self.camera}/'
f'{datetime.datetime.now().day}_'
f'{datetime.datetime.now().month}_'
f'{datetime.datetime.now().year}'
f'/images/{i}')
with open(file=upload_file_path, mode="rb") as data:
blob_client.upload_blob(data)
logger.info(
f'Uploaded to Azure Storage with blob path:{self.project}/{self.site}/{self.line}/{self.camera}/'
f'{datetime.datetime.now().day}_{datetime.datetime.now().month}_{datetime.datetime.now().year}'
f'/images')
return True
except Exception as e:
traceback.print_exc()
logger.error(f"Failed to Push Files to blob Storage! : {str(e)}")
return False
import datetime
import os
import traceback
from pathlib import Path
from azure.storage.blob import BlobServiceClient
from loguru import logger
from scripts.constants.app_configuration import BLOB_CONTAINER_NAME, BLOB_ACCOUNT_NAME, BLOB_ACCOUNT_KEY
class BlobUtil:
"""
    For downloading and uploading data via Azure Blob Storage.
"""
def __init__(self, project, site, line, camera):
self.project = project
self.site = site
self.line = line
self.camera = camera
self.blob_service_client = BlobServiceClient(f"https://{BLOB_ACCOUNT_NAME}.blob.core.windows.net",
credential=BLOB_ACCOUNT_KEY)
self.container = BLOB_CONTAINER_NAME
def download(self, download_path, dest_path):
try:
download_file_path = os.path.join(dest_path, "dataset.zip")
blob_client = self.blob_service_client.get_blob_client(container=self.container,
blob=download_path)
with open(download_file_path, "wb") as download_file:
blob_client.download_blob().readinto(download_file)
logger.info(f'download from Azure Storage with blob path:{download_path}')
return True
except Exception as e:
traceback.print_exc()
logger.error(f"Failed to Pull Files from blob Storage! : {str(e)}")
return False
def upload(self, local_path, file_name):
try:
upload_file_path = os.path.join(local_path, file_name)
blob_client = self.blob_service_client.get_blob_client(container=self.container,
blob=f'{self.project}/{self.site}/{self.line}/'
f'{self.camera}/'
f'{datetime.datetime.now().hour}_'
f'{datetime.datetime.now().day}_'
f'{datetime.datetime.now().month}_'
f'{datetime.datetime.now().year}'
f'/{file_name}')
with open(file=upload_file_path, mode="rb") as data:
blob_client.upload_blob(data)
logger.info(
f'Uploaded to Azure Storage with blob path:{self.project}/{self.site}/{self.line}/{self.camera}/'
f'{datetime.datetime.now().hour}_{datetime.datetime.now().day}_{datetime.datetime.now().month}_{datetime.datetime.now().year}'
f'/{file_name}')
return True
except Exception as e:
traceback.print_exc()
logger.error(f"Failed to Push Files to blob Storage! : {str(e)}")
return False
import os
import re
import tracemalloc
import mlflow
from loguru import logger
from scripts.constants.app_configuration import MlFlow, job
from scripts.core.vgg16_training import VGG16Training
mlflow_tracking_uri = MlFlow.mlflow_tracking_uri
# os.environ["MLFLOW_TRACKING_USERNAME"] = MlFlow.mlflow_tracking_username
# os.environ["MLFLOW_TRACKING_PASSWORD"] = MlFlow.mlflow_tracking_password
os.environ["AZURE_STORAGE_CONNECTION_STRING"] = MlFlow.azure_storage_connection_string
os.environ["AZURE_STORAGE_ACCESS_KEY"] = MlFlow.azure_storage_access_key
mlflow.set_tracking_uri(mlflow_tracking_uri)
mlflow.set_registry_uri(mlflow_tracking_uri)
client = mlflow.tracking.MlflowClient()
class MlFlowUtil:
@staticmethod
def log_model(model, model_name):
try:
mlflow.sklearn.log_model(model, model_name)
logger.info("logged the model")
return True
except Exception as e:
logger.exception(str(e))
@staticmethod
def log_metrics(metrics):
try:
updated_metric = {}
for key, value in metrics.items():
key = re.sub(r"[([{})\]]", "", key)
updated_metric[key] = value
mlflow.log_metrics(updated_metric)
return True
except Exception as e:
logger.exception(str(e))
@staticmethod
def log_hyper_param(hyper_params):
try:
mlflow.log_params(hyper_params)
return True
except Exception as e:
logger.exception(str(e))
@staticmethod
def set_tag(child_run_id, key, value):
try:
client.set_tag(run_id=child_run_id, key=key, value=value)
except Exception as e:
logger.exception(f"Exception while setting the tag - {e}")
class ModelReTrainer:
def __init__(self, experiment_name, parent_run_name, line, camera, training_path, validation_path, job_id,
master_config):
self.experiment_name = experiment_name
self.parent_run_name = parent_run_name
self.line = f'Line_{line}'
self.camera = f'Camera_{camera}'
self.training_path = training_path
self.validation_path = validation_path
self.job_id = job_id
self.master_config = master_config
self._mfu_ = MlFlowUtil()
self.current_run_name = job.job_id
def check_create_experiment(self):
"""
check if experiment exists, if not creates a new experiment
:return: experiment_id of the experiment
"""
experiment_info = mlflow.get_experiment_by_name(self.experiment_name)
if experiment_info is None:
logger.info(f"No experiment found with name {self.experiment_name}, So creating one")
mlflow.create_experiment(self.experiment_name)
else:
logger.info(f"Proceeding with existing Experiment {self.experiment_name}")
mlflow.set_experiment(experiment_name=self.experiment_name)
experiment_info = mlflow.get_experiment_by_name(self.experiment_name)
experiment_id = experiment_info.experiment_id
return experiment_id
def check_create_parent_run(self, experiment_id):
"""
check if a parent run exists in the experiment, if not create it with the mentioned parent run name
:param experiment_id: Experiment id
:return: returns the parent run id
"""
parent_runs_df = mlflow.search_runs(experiment_id)
run_key = 'tags.mlflow.runName'
if run_key in parent_runs_df.columns:
parent_runs_df = parent_runs_df[parent_runs_df[run_key] == self.parent_run_name]
else:
parent_runs_df = parent_runs_df.iloc[:0]
if not parent_runs_df.empty:
logger.info(f"Proceeding with existing Parent Run {self.parent_run_name}")
return list(parent_runs_df['run_id'])[0]
# no parent run found
logger.info(f"No Parent Run present {self.parent_run_name}")
with mlflow.start_run(experiment_id=experiment_id, run_name=self.parent_run_name) as run:
logger.info(f"Creating the parent Run {self.parent_run_name} with Parent Run Id {run.info.run_id}")
return run.info.run_id
def check_create_child_run(self, experiment_id, parent_run_id):
"""
check if a child run exists in the experiment id under the parent run id
if exists take the child run id which has the model saved and validate when was it lastly trained.
Based on the lastly trained see if you have to retrain or not. if retrain create a new child run
else if no child run exists under the parent run id of experiment id, create a new child run
:param experiment_id: experiment id
:param parent_run_id: parent run id
:return: child run id, retrain flag
"""
child_runs_df = mlflow.search_runs(experiment_id, filter_string=f"tags.mlflow.parentRunId='{parent_run_id}'")
if not child_runs_df.empty:
logger.info(f"Already Child runs are present for Parent Run Id {parent_run_id}")
child_runs_df = child_runs_df[child_runs_df['tags.mlflow.runName'] == str(self.line)]
# child_run_id, retrain = self.get_latest_child_run(experiment_id, parent_run_id, child_runs_df)
if child_runs_df.empty:
with mlflow.start_run(experiment_id=experiment_id, run_id=parent_run_id, nested=True):
with mlflow.start_run(experiment_id=experiment_id, nested=True, run_name=self.line) as child_run:
return child_run.info.run_id
return list(child_runs_df['run_id'])[0]
else:
logger.info(f"Child runs are not present for Parent Run Id {parent_run_id}")
with mlflow.start_run(experiment_id=experiment_id, run_id=parent_run_id, nested=True):
with mlflow.start_run(experiment_id=experiment_id, nested=True, run_name=self.line) as child_run:
return child_run.info.run_id
def create_camera_run(self, experiment_id, city_run_id, line_run_id):
camera_child_runs_df = mlflow.search_runs(experiment_id,
filter_string=f"tags.mlflow.parentRunId='{line_run_id}'")
if not camera_child_runs_df.empty:
child_runs_df = camera_child_runs_df[camera_child_runs_df['tags.mlflow.runName'] == str(self.camera)]
if child_runs_df.empty:
with mlflow.start_run(experiment_id=experiment_id, run_id=city_run_id, nested=True):
with mlflow.start_run(experiment_id=experiment_id, nested=True, run_id=line_run_id):
with mlflow.start_run(experiment_id=experiment_id, nested=True,
run_name=self.camera) as child_run:
return child_run.info.run_id
return list(child_runs_df['run_id'])[0]
with mlflow.start_run(experiment_id=experiment_id, run_id=city_run_id, nested=True):
with mlflow.start_run(experiment_id=experiment_id, nested=True, run_id=line_run_id):
with mlflow.start_run(experiment_id=experiment_id, nested=True, run_name=self.camera) as child_run:
return child_run.info.run_id
def get_current_run(self, experiment_id, city_run_id, line_run_id, camera_run_id):
current_child_runs_df = mlflow.search_runs(experiment_id,
filter_string=f"tags.mlflow.parentRunId='{camera_run_id}'")
if not current_child_runs_df.empty:
child_runs_df = current_child_runs_df[current_child_runs_df['tags.mlflow.runName'] == self.current_run_name]
if child_runs_df.empty:
with mlflow.start_run(experiment_id=experiment_id, run_id=city_run_id, nested=True):
with mlflow.start_run(experiment_id=experiment_id, nested=True, run_id=line_run_id):
with mlflow.start_run(experiment_id=experiment_id, nested=True, run_id=camera_run_id):
with mlflow.start_run(experiment_id=experiment_id, nested=True,
run_name=self.current_run_name) as child_run:
return child_run.info.run_id
return list(child_runs_df['run_id'])[0]
else:
with mlflow.start_run(experiment_id=experiment_id, run_id=city_run_id, nested=True):
with mlflow.start_run(experiment_id=experiment_id, nested=True, run_id=line_run_id):
with mlflow.start_run(experiment_id=experiment_id, nested=True, run_id=camera_run_id):
with mlflow.start_run(experiment_id=experiment_id, nested=True,
run_name=self.current_run_name) as child_run:
return child_run.info.run_id
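    # Run hierarchy maintained by this class (given how app.py constructs ModelReTrainer):
    #   experiment = project id -> parent run = site id -> child run "Line_<line>"
    #   -> child run "Camera_<camera>" -> child run named after the job id, where training is logged.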
@staticmethod
def flatten_dict(dd, separator='_', prefix=''):
stack = [(dd, prefix)]
flat_dict = {}
while stack:
cur_dict, cur_prefix = stack.pop()
for key, val in cur_dict.items():
new_key = cur_prefix + separator + key if cur_prefix else key
if isinstance(val, dict):
stack.append((val, new_key))
else:
flat_dict[new_key] = val
return flat_dict
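    # Example: flatten_dict({"training_params": {"epochs": 40, "imgsz": 416}})
    #          -> {"training_params_epochs": 40, "training_params_imgsz": 416}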
def start_training(self):
"""
This is the Main function which will return the latest model
:return:
"""
experiment_id = self.check_create_experiment()
parent_run_id = self.check_create_parent_run(experiment_id)
child_run_id = self.check_create_child_run(experiment_id, parent_run_id)
camera_run_id = self.create_camera_run(experiment_id=experiment_id, city_run_id=parent_run_id,
line_run_id=child_run_id)
current_run_id = self.get_current_run(experiment_id=experiment_id, city_run_id=parent_run_id,
line_run_id=child_run_id, camera_run_id=camera_run_id)
with mlflow.start_run(run_id=current_run_id):
logger.info('Creating the new model !')
vgg16 = VGG16Training(self.master_config, self.training_path, self.validation_path, self.job_id)
metrics, results_directory = vgg16.train_model()
metrics = metrics[-2]
metrics = {'avg_loss_train': float(metrics.get('avg_loss_train')),
'avg_acc_train': float(metrics.get('avg_acc_train')),
'avg_loss_val': float(metrics.get('avg_loss_val')),
'avg_acc_val': float(metrics.get('avg_acc_val'))}
logger.info(f'metrics - {metrics}')
self.log_metrics(metrics=metrics)
for each in os.listdir(results_directory):
self.log_model(model_path=os.path.join(results_directory, each))
tracemalloc.clear_traces()
tracemalloc.get_traced_memory()
# with mlflow.start_run(run_id=camera_run_id):
# self.log_metrics(metrics=metrics)
logger.info(f"Loading the model from the child run id {camera_run_id}")
return results_directory
@staticmethod
def log_model(model_path):
"""
Function is to log the model
:param model_path : model Path
:return: Boolean Value
"""
try:
mlflow.log_artifact(model_path)
logger.info("logged the model")
return True
except Exception as e:
logger.exception(str(e))
@staticmethod
def log_metrics(metrics):
"""
Function is to log the metrics
:param metrics: dict of metrics
        :return: Boolean value
"""
try:
updated_metric = {}
for key, value in metrics.items():
                key = re.sub(r"[([{})\]]", "", key)
updated_metric[key] = value
mlflow.log_metrics(updated_metric)
logger.info('logged the model metric')
return True
except Exception as e:
logger.exception(str(e))
@staticmethod
def log_hyper_param(hyperparameters):
"""
Function is to log the hyper params
:param hyperparameters: dict of hyperparameters
        :return: Boolean value
"""
try:
mlflow.log_params(hyperparameters)
logger.debug('logged model hyper parameters')
return True
except Exception as e:
logger.exception(str(e))
@staticmethod
def set_tag(run_id, params):
"""
Function is to set the tag for a particular run
:param run_id: Run id in which the tags need to be added
:param key: Name of the key
:param value: what needs to tagged in the value
"""
try:
            for key, value in params.items():
                client.set_tag(run_id=run_id, key=key, value=value)
logger.debug('set the tag for the model')
except Exception as e:
logger.exception(f"Exception while setting the tag - {e}")
import os
import shutil
import cv2
import numpy as np
from loguru import logger
from pathlib import Path
import random
class DataAugmentation:
"""
    Applies various image-level augmentations to a dataset.
"""
def __init__(self, functions):
self.functions = functions
# [AugmentImage({each_prop['property']: each_prop}) for each_prop in functions]
# for item in functions:
# self.values = tuple([item['value']])
# print(self.values)
    def fill(self, img, h, w):
        # Resize back to the original size; cv2.resize expects dsize as (width, height).
        img = cv2.resize(img, (w, h), interpolation=cv2.INTER_CUBIC)
        return img
def horizontal_shift(self, img):
for augmentation in self.functions:
if "horizontal_shift" == augmentation['property']:
ratio = augmentation['value']
if ratio > 1 or ratio < 0:
print('Value should be less than 1 and greater than 0')
return img
ratio = random.uniform(-ratio, ratio)
h, w = img.shape[:2]
to_shift = w * ratio
if ratio > 0:
img = img[:, :int(w - to_shift), :]
if ratio < 0:
img = img[:, int(-1 * to_shift):, :]
img = self.fill(img, h, w)
return img
def vertical_shift(self, img):
for augmentation in self.functions:
if "vertical_shift" == augmentation['property']:
ratio = augmentation['value']
if ratio > 1 or ratio < 0:
print('Value should be less than 1 and greater than 0')
return img
ratio = random.uniform(-ratio, ratio)
h, w = img.shape[:2]
to_shift = h * ratio
if ratio > 0:
img = img[:int(h - to_shift), :, :]
if ratio < 0:
img = img[int(-1 * to_shift):, :, :]
img = self.fill(img, h, w)
return img
def brightness(self, img):
for augmentation in self.functions:
if "brightness" == augmentation['property']:
low = augmentation['value']
value = random.uniform(low, low + 2.5)
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
hsv = np.array(hsv, dtype=np.float64)
hsv[:, :, 1] = hsv[:, :, 1] * value
hsv[:, :, 1][hsv[:, :, 1] > 255] = 255
hsv[:, :, 2] = hsv[:, :, 2] * value
hsv[:, :, 2][hsv[:, :, 2] > 255] = 255
hsv = np.array(hsv, dtype=np.uint8)
img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
return img
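    # brightness() scales the saturation (channel 1) and value (channel 2) planes of the HSV image
    # by a random factor drawn from [low, low + 2.5], clipping at 255, which brightens and mildly
    # saturates the image.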
def channel_shift(self, img):
for augmentation in self.functions:
if "channel_shift" == augmentation['property']:
value = augmentation['value']
value = int(random.uniform(-value, value))
img = img + value
img[:, :, :][img[:, :, :] > 255] = 255
img[:, :, :][img[:, :, :] < 0] = 0
img = img.astype(np.uint8)
return img
def zoom(self, img):
for augmentation in self.functions:
if "zoom" == augmentation['property']:
value = augmentation['value']
if value > 1 or value < 0:
print('Value for zoom should be less than 1 and greater than 0')
return img
value = random.uniform(value, 1)
h, w = img.shape[:2]
h_taken = int(value * h)
w_taken = int(value * w)
h_start = random.randint(0, h - h_taken)
w_start = random.randint(0, w - w_taken)
img = img[h_start:h_start + h_taken, w_start:w_start + w_taken, :]
img = self.fill(img, h, w)
return img
    def horizontal_flip(self, img):
        for augmentation in self.functions:
            if "horizontal_flip" == augmentation['property']:
                if augmentation['value']:
                    return cv2.flip(img, 1)
                return img
        # Fall back to the unmodified image if the augmentation is not configured.
        return img
    def vertical_flip(self, img):
        for augmentation in self.functions:
            if "vertical_flip" == augmentation['property']:
                if augmentation['value']:
                    return cv2.flip(img, 0)
                return img
        return img
def rotation(self, img):
for augmentation in self.functions:
if "rotation" == augmentation['property']:
angle = augmentation['value']
angle = int(random.uniform(-angle, angle))
h, w = img.shape[:2]
M = cv2.getRotationMatrix2D((int(w / 2), int(h / 2)), angle, 1)
img = cv2.warpAffine(img, M, (w, h))
return img
def process(self, annotation_directory, post_process_directory):
assert os.path.exists(annotation_directory)
if not os.path.exists(post_process_directory):
os.mkdir(post_process_directory)
logger.info(f"Path: {post_process_directory} does not exist, creating one now!")
for each_file in os.listdir(annotation_directory):
filename, file_extension = os.path.splitext(os.path.join(annotation_directory, each_file))
if file_extension in ['.jpg', '.jpeg', '.png']:
image = cv2.imread(os.path.join(annotation_directory, each_file))
multi_images = (
self.horizontal_shift(image), self.vertical_shift(image), self.brightness(image),
self.zoom(image), self.channel_shift(image), self.horizontal_flip(image),
self.vertical_flip(image), self.rotation(image))
                _file_name = 0
                for each_element in multi_images:
                    image = each_element
                    cv2.imwrite(
                        os.path.join(post_process_directory,
                                     f"{os.path.splitext(each_file)[0]}_{_file_name}.jpg"),
                        image)
                    _file_name = _file_name + 1
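    # Each source image therefore produces eight variants (shifts, brightness, zoom, channel shift,
    # flips, rotation), written as "<image stem>_<0..7>.jpg" in the post-process directory.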
def combine_dataset(self, augmented_dataset_path, unaugmented_dataset_path, cls):
for each_file in os.listdir(os.path.join(unaugmented_dataset_path, cls)):
shutil.copy(os.path.join(unaugmented_dataset_path, cls, each_file),
os.path.join(augmented_dataset_path, cls))
from zipfile import ZipFile
import os
import shutil
class ExtractDataset:
"""
Extracts the dataset.
"""
def __init__(self):
self.dest_path = None
self.dataset_path = None
def extract_ds(self, dataset_path, dest_path):
self.dataset_path = dataset_path
self.dest_path = dest_path
with ZipFile(self.dataset_path, 'r') as zObject:
zObject.extractall(path=dest_path)
def move_files(self, dataset_path, dest_path):
self.dataset_path = dataset_path
self.dest_path = dest_path
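        # Assumes the extracted archive contains a single top-level folder (only the first listed
        # entry is used) whose subfolders are the class directories, i.e.
        # extracted_dataset/<archive_root>/<class_name>/<images>; each class folder is moved up
        # into the destination directory.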
dataset_outer_path = os.listdir(self.dataset_path)
classes = os.listdir(os.path.join(self.dataset_path, dataset_outer_path[0]))
for cls in classes:
shutil.move(os.path.join(self.dataset_path, dataset_outer_path[0], cls), os.path.join(self.dest_path, cls))
# for name in files:
# shutil.move(os.path.join(root, name), os.path.join(self.dest_path, name))
import os
import shutil
import cv2
from loguru import logger
class TrainTestSplit:
"""
    Split the dataset into train, valid and test subsets.
"""
def __init__(self, dataset_path, job_id, validation_size=0.20, test_size=0.10):
self.path = dataset_path
self.folder_train_name = os.path.join(job_id, "train")
self.folder_valid_name = os.path.join(job_id, "valid")
self.folder_test_name = os.path.join(job_id, "test")
self.validation_size = validation_size
self.test_size = test_size
if not os.path.exists(self.folder_train_name):
os.mkdir(self.folder_train_name)
if not os.path.exists(self.folder_valid_name):
os.mkdir(self.folder_valid_name)
if not os.path.exists(self.folder_test_name):
os.mkdir(self.folder_test_name)
    def train_test_split(self):
        main_dir = os.listdir(self.path)
        # Count the dataset size without loading every image into memory.
        count = sum(len(os.listdir(os.path.join(self.path, each))) for each in main_dir)
for each_dir in main_dir:
os.mkdir(os.path.join(self.folder_train_name, each_dir))
os.mkdir(os.path.join(self.folder_test_name, each_dir))
os.mkdir(os.path.join(self.folder_valid_name, each_dir))
files = os.listdir(os.path.join(self.path, each_dir))
train_per = round(len(files) * (1 - (self.validation_size + self.test_size)))
valid_per = round(len(files) * self.validation_size)
test_per = round(len(files) * self.test_size)
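            # Example: with 100 files in a class, validation_size=0.20 and test_size=0.10 give
            # 70 train, 20 valid and 10 test images (files are taken in directory-listing order,
            # not shuffled).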
for every_file in files[:train_per]:
shutil.copyfile(os.path.join(self.path, each_dir, every_file),
os.path.join(self.folder_train_name, each_dir, every_file))
for every_file in files[train_per:train_per + valid_per]:
shutil.copyfile(os.path.join(self.path, each_dir, every_file),
os.path.join(self.folder_valid_name, each_dir, every_file))
for every_file in files[train_per + valid_per:]:
shutil.copyfile(os.path.join(self.path, each_dir, every_file),
os.path.join(self.folder_test_name, each_dir, every_file))
return self.folder_train_name, self.folder_valid_name, self.folder_test_name, count