Commit 15cdc800 authored by Sikhin VC

initial commit

parent ef41ac0c
.github/FUNDING.yml

# These are supported funding model platforms

github: [lucidrains]
.github/workflows/python-publish.yml

# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Upload Python Package

on:
  release:
    types: [created]

jobs:
  deploy:
    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install setuptools wheel twine
    - name: Build and publish
      env:
        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
      run: |
        python setup.py sdist bdist_wheel
        twine upload dist/*
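A note on the workflow above: it reads PYPI_USERNAME and PYPI_PASSWORD from the repository's secrets, so both must be configured before the release trigger fires. If authenticating with a PyPI API token rather than a password, the conventional setup is the literal username __token__ with the token itself as the password.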
.gitignore

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
.aim
models/
results/
LICENSE

MIT License

Copyright (c) 2021 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
This diff is collapsed.
lightweight_gan/__init__.py

from lightweight_gan.lightweight_gan import LightweightGAN, Generator, Discriminator, Trainer, NanException
from kornia.filters import filter2d
lightweight_gan/cli.py

import os
import fire
import random
from retry.api import retry_call
from tqdm import tqdm
from datetime import datetime
from functools import wraps
from lightweight_gan import Trainer, NanException
from lightweight_gan.diff_augment_test import DiffAugmentTest
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
import numpy as np
def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def cast_list(el):
    return el if isinstance(el, list) else [el]

def timestamped_filename(prefix = 'generated-'):
    now = datetime.now()
    timestamp = now.strftime("%m-%d-%Y_%H-%M-%S")
    return f'{prefix}{timestamp}'

def set_seed(seed):
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
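# Example output: timestamped_filename('generated-') yields a name like
# 'generated-05-14-2021_13-45-09' (month-day-year_hour-minute-second).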
def run_training(rank, world_size, model_args, data, load_from, new, num_train_steps, name, seed, use_aim, aim_repo, aim_run_hash):
    is_main = rank == 0
    is_ddp = world_size > 1

    if is_ddp:
        # one process per GPU: seed identically, then join the NCCL process group
        set_seed(seed)
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        dist.init_process_group('nccl', rank=rank, world_size=world_size)

        print(f"{rank + 1}/{world_size} process initialized.")

    model_args.update(
        is_ddp = is_ddp,
        rank = rank,
        world_size = world_size
    )

    model = Trainer(**model_args, hparams=model_args, use_aim=use_aim, aim_repo=aim_repo, aim_run_hash=aim_run_hash)

    if not new:
        model.load(load_from)
    else:
        model.clear()

    model.set_data_src(data)

    progress_bar = tqdm(initial = model.steps, total = num_train_steps, mininterval=10., desc=f'{name}<{data}>')
    while model.steps < num_train_steps:
        # retry a training step up to 3 times if the loss turns NaN
        retry_call(model.train, tries=3, exceptions=NanException)
        progress_bar.n = model.steps
        progress_bar.refresh()
        if is_main and model.steps % 50 == 0:
            model.print_log()

    model.save(model.checkpoint_num)

    if is_ddp:
        dist.destroy_process_group()
def train_from_folder(
    data = './data',
    results_dir = './results',
    models_dir = './models',
    name = 'default',
    new = False,
    load_from = -1,
    image_size = 256,
    optimizer = 'adam',
    fmap_max = 512,
    transparent = False,
    greyscale = False,
    batch_size = 10,
    gradient_accumulate_every = 4,
    num_train_steps = 150000,
    learning_rate = 2e-4,
    save_every = 1000,
    evaluate_every = 1000,
    generate = False,
    generate_types = ['default', 'ema'],
    generate_interpolation = False,
    aug_test = False,
    aug_prob = None,
    aug_types = ['cutout', 'translation'],
    dataset_aug_prob = 0.,
    attn_res_layers = [32],
    freq_chan_attn = False,
    disc_output_size = 1,
    dual_contrast_loss = False,
    antialias = False,
    interpolation_num_steps = 100,
    save_frames = False,
    num_image_tiles = None,
    num_workers = None,
    multi_gpus = False,
    calculate_fid_every = None,
    calculate_fid_num_images = 12800,
    clear_fid_cache = False,
    seed = 42,
    amp = False,
    show_progress = False,
    use_aim = False,
    aim_repo = None,
    aim_run_hash = None,
    load_strict = True
):
    num_image_tiles = default(num_image_tiles, 4 if image_size > 512 else 8)

    model_args = dict(
        name = name,
        results_dir = results_dir,
        models_dir = models_dir,
        batch_size = batch_size,
        gradient_accumulate_every = gradient_accumulate_every,
        attn_res_layers = cast_list(attn_res_layers),
        freq_chan_attn = freq_chan_attn,
        disc_output_size = disc_output_size,
        dual_contrast_loss = dual_contrast_loss,
        antialias = antialias,
        image_size = image_size,
        num_image_tiles = num_image_tiles,
        optimizer = optimizer,
        num_workers = num_workers,
        fmap_max = fmap_max,
        transparent = transparent,
        greyscale = greyscale,
        lr = learning_rate,
        save_every = save_every,
        evaluate_every = evaluate_every,
        aug_prob = aug_prob,
        aug_types = cast_list(aug_types),
        dataset_aug_prob = dataset_aug_prob,
        calculate_fid_every = calculate_fid_every,
        calculate_fid_num_images = calculate_fid_num_images,
        clear_fid_cache = clear_fid_cache,
        amp = amp,
        load_strict = load_strict
    )
    if generate:
        model = Trainer(**model_args, use_aim = use_aim)
        model.load(load_from)
        samples_name = timestamped_filename()
        checkpoint = model.checkpoint_num
        dir_result = model.generate(samples_name, num_image_tiles, checkpoint, generate_types)
        print(f'sample images generated at {dir_result}')
        return

    if generate_interpolation:
        model = Trainer(**model_args, use_aim = use_aim)
        model.load(load_from)
        samples_name = timestamped_filename()
        model.generate_interpolation(samples_name, num_image_tiles, num_steps = interpolation_num_steps, save_frames = save_frames)
        print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    if show_progress:
        model = Trainer(**model_args, use_aim = use_aim)
        model.show_progress(num_images=num_image_tiles, types=generate_types)
        return

    if aug_test:
        DiffAugmentTest(data=data, image_size=image_size, batch_size=batch_size, types=aug_types, nrow=num_image_tiles)
        return

    world_size = torch.cuda.device_count()

    if world_size == 1 or not multi_gpus:
        run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed, use_aim, aim_repo, aim_run_hash)
        return

    mp.spawn(run_training,
        args=(world_size, model_args, data, load_from, new, num_train_steps, name, seed, use_aim, aim_repo, aim_run_hash,),
        nprocs=world_size,
        join=True)

def main():
    fire.Fire(train_from_folder)
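Since fire.Fire wraps train_from_folder directly, each keyword argument above doubles as a command-line flag on the lightweight_gan console script declared in setup.py. A minimal usage sketch (the dataset path and run name are placeholders, and a CUDA-capable GPU is assumed):

from lightweight_gan.cli import train_from_folder

train_from_folder(
    data = './my-images',       # placeholder folder of training images
    name = 'demo-run',          # placeholder run name
    image_size = 256,
    num_train_steps = 1000,
)

# equivalent console invocation:
#   lightweight_gan --data ./my-images --name demo-run --image_size 256 --num_train_steps 1000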
lightweight_gan/diff_augment.py

import random

import torch
import torch.nn.functional as F
def DiffAugment(x, types=[]):
    for p in types:
        for f in AUGMENT_FNS[p]:
            x = f(x)
    return x.contiguous()

# Augmentation functions receive a batch of images as `x`,
# a tensor of shape (batch, channels, height, width).

def rand_brightness(x):
    x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
    return x
def rand_saturation(x):
    x_mean = x.mean(dim=1, keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
    return x

def rand_contrast(x):
    x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
    return x

def rand_translation(x, ratio=0.125):
    # sample an independent integer shift per image, up to `ratio` of the spatial size
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
        indexing = 'ij')
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    # zero-pad by one pixel on each spatial side, then gather the shifted grid,
    # so pixels translated past the border are filled with zeros
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
    return x
def rand_offset(x, ratio=1, ratio_h=1, ratio_v=1):
    # x is (batch, channels, height, width)
    h, w = x.size(2), x.size(3)

    imgs = []
    for img in x.unbind(dim = 0):
        max_h = int(w * ratio * ratio_h)
        max_v = int(h * ratio * ratio_v)

        value_h = random.randint(0, max_h) * 2 - max_h
        value_v = random.randint(0, max_v) * 2 - max_v

        # roll wraps pixels around, unlike the zero-fill in rand_translation
        if abs(value_h) > 0:
            img = torch.roll(img, value_h, 2)

        if abs(value_v) > 0:
            img = torch.roll(img, value_v, 1)

        imgs.append(img)

    return torch.stack(imgs)

def rand_offset_h(x, ratio=1):
    return rand_offset(x, ratio=1, ratio_h=ratio, ratio_v=0)

def rand_offset_v(x, ratio=1):
    return rand_offset(x, ratio=1, ratio_h=0, ratio_v=ratio)

def rand_cutout(x, ratio=0.5):
    # zero out one random rectangular patch per image, `ratio` of the spatial size on a side
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
        indexing = 'ij')
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    x = x * mask.unsqueeze(1)
    return x

AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'offset': [rand_offset],
    'offset_h': [rand_offset_h],
    'offset_v': [rand_offset_v],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}
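A minimal sketch of how DiffAugment composes these functions, assuming the module lives at lightweight_gan.diff_augment: each key in types selects a list from AUGMENT_FNS, and every function in that list is applied in order, preserving the tensor's shape.

import torch
from lightweight_gan.diff_augment import DiffAugment

images = torch.randn(4, 3, 64, 64)  # dummy batch: (batch, channels, height, width)

# applies rand_brightness, rand_saturation, rand_contrast, then rand_cutout
augmented = DiffAugment(images, types=['color', 'cutout'])
assert augmented.shape == images.shape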
lightweight_gan/diff_augment_test.py

import os
import tempfile
from pathlib import Path
from shutil import copyfile

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader

from lightweight_gan.lightweight_gan import AugWrapper, ImageDataset

assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
class DummyModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

@torch.no_grad()
def DiffAugmentTest(image_size = 256, data = './data/0.jpg', types = [], batch_size = 10, rank = 0, nrow = 5):
    model = DummyModel()
    aug_wrapper = AugWrapper(model, image_size)

    with tempfile.TemporaryDirectory() as directory:
        file = Path(data)

        if os.path.exists(file):
            file_name, ext = os.path.splitext(data)

            # copy the single sample image batch_size times so the augmentations
            # can be compared side by side on identical inputs
            for i in range(batch_size):
                tmp_file_name = str(i) + ext
                copyfile(file, os.path.join(directory, tmp_file_name))

            dataset = ImageDataset(directory, image_size, aug_prob=0)
            dataloader = DataLoader(dataset, batch_size=batch_size)
            image_batch = next(iter(dataloader)).cuda(rank)
            images_augment = aug_wrapper(images=image_batch, prob=1, types=types, detach=True)

            save_result = file_name + f'_augs{ext}'
            torchvision.utils.save_image(images_augment, save_result, nrow=nrow)

            print('Saved result to:', save_result)
        else:
            print('File not found:', file)
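The harness above is normally reached through the CLI's --aug_test flag, but it can also be called directly; a sketch (the sample image path is a placeholder, and the module-level assert requires CUDA):

from lightweight_gan.diff_augment_test import DiffAugmentTest

# copies the sample image 10 times, augments each copy, and saves a
# side-by-side grid next to the input as ./samples/cat_augs.jpg
DiffAugmentTest(
    data = './samples/cat.jpg',
    image_size = 256,
    batch_size = 10,
    types = ['cutout', 'translation'],
    nrow = 5,
)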
This diff is collapsed.
lightweight_gan/version.py

__version__ = '1.1.1'
setup.py

import sys
from setuptools import setup, find_packages

sys.path[0:0] = ['lightweight_gan']
from version import __version__

setup(
  name = 'lightweight-gan',
  packages = find_packages(),
  entry_points={
    'console_scripts': [
      'lightweight_gan = lightweight_gan.cli:main',
    ],
  },
  version = __version__,
  license='MIT',
  description = 'Lightweight GAN',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/lightweight-gan',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'generative adversarial networks'
  ],
  install_requires=[
    'adabelief-pytorch',
    'einops>=0.3',
    'fire',
    'kornia>=0.5.4',
    'numpy',
    'pillow',
    'retry',
    'torch>=1.10',
    'torchvision',
    'tqdm'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
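Two notes on setup.py: the sys.path[0:0] = ['lightweight_gan'] line makes version.py importable on its own, so the version string can be read without importing the package itself (whose __init__.py pulls in torch before the dependencies are installed). And because of the console_scripts entry, a local pip install -e . from the repo root puts the lightweight_gan command on PATH, mapped to cli:main.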