Commit 15cdc800 authored by Sikhin VC

initial commit

parent ef41ac0c
# These are supported funding model platforms
github: [lucidrains]
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
name: Upload Python Package
on:
release:
types: [created]
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine
- name: Build and publish
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: |
python setup.py sdist bdist_wheel
twine upload dist/*
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
.aim
models/
results/
MIT License
Copyright (c) 2021 Phil Wang
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# lightweight_gan_synthetic_dataset_generation

<img src="./images/sample-512.jpg" width="600px"></img>
*512x512 flowers after 12 hours of training, 1 gpu*
<img src="./images/sample-256.jpg" width="400px"></img>
*256x256 flowers after 12 hours of training, 1 gpu*
<img src="./images/pizza-512.jpg" width="600px"></img>
*Pizza*
## 'Lightweight' GAN
[![PyPI version](https://badge.fury.io/py/lightweight-gan.svg)](https://badge.fury.io/py/lightweight-gan)
Implementation of <a href="https://openreview.net/forum?id=1Fqg133qRaI">'lightweight' GAN</a>, proposed in ICLR 2021, in PyTorch. The main contributions of the paper are a skip-layer excitation in the generator, paired with autoencoding self-supervised learning in the discriminator. Quoting the one-line summary: "converge on single gpu with few hours' training, on 1024 resolution sub-hundred images".
## Install
```bash
$ pip install lightweight-gan
```
## Use
One command
```bash
$ lightweight_gan --data ./path/to/images --image-size 512
```
The model will be saved to `./models/{name}` every 1000 iterations, and samples from the model will be saved to `./results/{name}`. `name` defaults to `default`.
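If you interrupt training, rerunning the same command with the same `--name` resumes from the latest saved checkpoint (the CLI's `load_from` argument defaults to `-1`, i.e. latest). A minimal sketch:
```bash
$ lightweight_gan --data ./path/to/images --image-size 512 --name default
```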
## Training settings
Pretty self-explanatory for deep learning practitioners:
```bash
$ lightweight_gan \
--data ./path/to/images \
--name {name of run} \
--batch-size 16 \
--gradient-accumulate-every 4 \
--num-train-steps 200000
```
## Augmentation
Augmentation is essential for Lightweight GAN to work effectively in a low-data setting.
By default, the augmentation types are set to translation and cutout, with color omitted. You can include color as well with the following:
```bash
$ lightweight_gan --data ./path/to/images --aug-prob 0.25 --aug-types [translation,cutout,color]
```
### Test augmentation
You can test and see how your images will be augmented before they are passed into the neural network (if you use augmentation). Let's see how it works on this image:
![](./docs/aug_test/lena.jpg)
#### Basic usage
To augment your image, pass `--aug-test` and put the path to your image in `--data`:
```bash
lightweight_gan \
--aug-test \
--data ./path/to/lena.jpg
```
This will create the file `lena_augs.jpg`, which will look something like this:
![](./docs/aug_test/lena_augs_default.jpg)
#### Options
You can use several options to change the result:
- `--image-size 256` to change the size of the image tiles in the result. Default: `256`.
- `--aug-types [color,cutout,translation]` to combine several augmentations. Default: `[cutout,translation]`.
- `--batch-size 10` to change the number of images in the result image. Default: `10`.
- `--num-image-tiles 5` to change the number of tiles in the result image. Default: `5`.
Try this command:
```bash
lightweight_gan \
--aug-test \
--data ./path/to/lena.jpg \
--batch-size 16 \
--num-image-tiles 4 \
--aug-types [color,translation]
```
The result will look something like this:
![](./docs/aug_test/lena_augs.jpg)
### Types of augmentations
This library contains several types of embedded augmentations.
Some of these work by default; others can be controlled from the command line via the `--aug-types` option:
- Horizontal flip (works by default, not configurable; runs in the `AugWrapper` class);
- `color` randomly changes brightness, saturation and contrast;
- `cutout` creates random black boxes on the image;
- `offset` randomly shifts the image along the x- and y-axes, wrapping the image around;
- `offset_h` the same, but only along the x-axis;
- `offset_v` the same, but only along the y-axis;
- `translation` randomly moves the image on a canvas with a black background.
The full set of augmentations is `--aug-types [color,cutout,offset,translation]`, as shown below.
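For example, a run with every augmentation type enabled (the path and probability are illustrative):
```bash
$ lightweight_gan --data ./path/to/images --aug-prob 0.25 --aug-types [color,cutout,offset,translation]
```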
The general recommendation is to use augmentations suitable for your data, and as many as possible; then, after some time of training, disable the augmentations most destructive to the images.
#### Color
![](./docs/aug_types/lena_augs_color.jpg)
#### Cutout
![](./docs/aug_types/lena_augs_cutout.jpg)
#### Offset
![](./docs/aug_types/lena_augs_offset.jpg)
Only x-axis:
![](./docs/aug_types/lena_augs_offset_h.jpg)
Only y-axis:
![](./docs/aug_types/lena_augs_offset_v.jpg)
#### Translation
![](./docs/aug_types/lena_augs_translation.jpg)
## Mixed precision
You can turn on automatic mixed precision with one flag, `--amp`.
You should expect it to be 33% faster and save up to 40% memory.
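For example, added to the basic training command:
```bash
$ lightweight_gan --data ./path/to/images --image-size 512 --amp
```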
## Multiple GPUs
Likewise, only one flag is needed: `--multi-gpus`.
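For example:
```bash
$ lightweight_gan --data ./path/to/images --multi-gpus
```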
## Visualizing training insights with Aim
[Aim](https://github.com/aimhubio/aim) is an open-source experiment tracker that logs your training runs, provides a beautiful UI to compare them, and an API to query them programmatically.
First, you need to install `aim` with `pip`:
```bash
$ pip install aim
```
Next, you can specify the Aim logs directory with the `--aim_repo` flag; otherwise, logs will be stored in the current directory.
```bash
$ lightweight_gan --data ./path/to/images --image-size 512 --use-aim --aim_repo ./path/to/logs/
```
Execute `aim up --repo ./path/to/logs/` to run Aim UI on your server.
**View all tracked runs, each metric's last tracked values and tracked hyperparameters in the Runs Dashboard:**
<img width="1431" alt="Screen Shot 2022-04-19 at 00 48 55" src="https://user-images.githubusercontent.com/11066664/163875698-dc497334-1f77-4e18-a37e-ac0f874b9814.png">
**Compare loss curves with Metrics Explorer - group and aggregate by any hyperparameter to easily compare the runs:**
<img width="1440" alt="Screen Shot 2022-04-12 at 16 56 35" src="https://user-images.githubusercontent.com/11066664/163875452-1da3bf36-f3bc-449f-906e-cebaf9a4fd6c.png">
**Compare and debug generated images across training steps and runs via Images Explorer:**
<img width="1439" alt="Screen Shot 2022-04-12 at 16 57 24" src="https://user-images.githubusercontent.com/11066664/163875815-9cd8ce85-2815-4f0a-80dd-0f3258193c19.png">
## Generating
Once you have finished training, you can generate samples with one command. You can select which checkpoint number to load from. If `--load-from` is not specified, it will default to the latest.
```bash
$ lightweight_gan \
--name {name of run} \
--load-from {checkpoint num} \
--generate \
--generate-types {types of result, default: [default,ema]} \
--num-image-tiles {count of image result}
```
After running this command, you will get a folder next to the results image folder with the postfix "-generated-{checkpoint num}".
You can also generate interpolations:
```bash
$ lightweight_gan --name {name of run} --generate-interpolation
```
## Show progress
After creating several checkpoints of the model, you can generate the training progress as a sequence of images with the command:
```bash
$ lightweight_gan \
--name {name of run} \
--show-progress \
--generate-types {types of result, default: [default,ema]} \
--num-image-tiles {count of image result}
```
After running this command, you will get a new folder in the results folder, with the postfix "-progress". You can convert the images to a video with ffmpeg, as shown below.
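The command, verbatim (run it where the `*-ema.jpg` frames live):
```bash
$ ffmpeg -framerate 10 -pattern_type glob -i '*-ema.jpg' out.mp4
```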
![Show progress gif demonstration](./docs/show_progress/show-progress.gif)
![Show progress video demonstration](./docs/show_progress/show-progress.mp4)
## Discriminator output size
The author has kindly let me know that the discriminator output size (5x5 vs 1x1) leads to different results on different datasets. (5x5 works better for art than for faces, as an example). You can toggle this with a single flag
```bash
# disc output size is by default 1x1
$ lightweight_gan --data ./path/to/art --image-size 512 --disc-output-size 5
```
## Attention
You can add linear + axial attention to specific resolution layers with the following
```bash
# make sure there are no spaces between the values within the brackets []
$ lightweight_gan --data ./path/to/images --image-size 512 --attn-res-layers [32,64] --aug-prob 0.25
```
## Dual Contrastive Loss
A recent paper has proposed that a novel contrastive loss between the real and fake logits can improve quality slightly over the default hinge loss.
You can use this with one extra flag as follows
```bash
$ lightweight_gan --data ./path/to/images --dual-contrast-loss
```
## Bonus
You can also train with transparent images
```bash
$ lightweight_gan --data ./path/to/images --transparent
```
Or greyscale
```bash
$ lightweight_gan --data ./path/to/images --greyscale
```
## Alternatives
If you want the current state of the art GAN, you can find it at https://github.com/lucidrains/stylegan2-pytorch
## Citations
```bibtex
@inproceedings{
anonymous2021towards,
title = {Towards Faster and Stabilized {GAN} Training for High-fidelity Few-shot Image Synthesis},
author = {Anonymous},
booktitle = {Submitted to International Conference on Learning Representations},
year = {2021},
url = {https://openreview.net/forum?id=1Fqg133qRaI},
note = {under review}
}
```
```bibtex
@misc{cao2020global,
title = {Global Context Networks},
author = {Yue Cao and Jiarui Xu and Stephen Lin and Fangyun Wei and Han Hu},
year = {2020},
eprint = {2012.13375},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{qin2020fcanet,
title = {FcaNet: Frequency Channel Attention Networks},
author = {Zequn Qin and Pengyi Zhang and Fei Wu and Xi Li},
year = {2020},
eprint = {2012.11879},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{yu2021dual,
title = {Dual Contrastive Loss and Attention for GANs},
author = {Ning Yu and Guilin Liu and Aysegul Dundar and Andrew Tao and Bryan Catanzaro and Larry Davis and Mario Fritz},
year = {2021},
eprint = {2103.16748},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@article{Sunkara2022NoMS,
title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
author = {Raja Sunkara and Tie Luo},
journal = {ArXiv},
year = {2022},
volume = {abs/2208.03641}
}
```
*What I cannot create, I do not understand* - Richard Feynman
from lightweight_gan.lightweight_gan import LightweightGAN, Generator, Discriminator, Trainer, NanException
from kornia.filters import filter2d
import os
import fire
import random
from retry.api import retry_call
from tqdm import tqdm
from datetime import datetime
from functools import wraps
from lightweight_gan import Trainer, NanException
from lightweight_gan.diff_augment_test import DiffAugmentTest
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
import numpy as np
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_list(el):
return el if isinstance(el, list) else [el]
def timestamped_filename(prefix = 'generated-'):
now = datetime.now()
timestamp = now.strftime("%m-%d-%Y_%H-%M-%S")
return f'{prefix}{timestamp}'
def set_seed(seed):
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(seed)
random.seed(seed)
def run_training(rank, world_size, model_args, data, load_from, new, num_train_steps, name, seed, use_aim, aim_repo, aim_run_hash):
is_main = rank == 0
is_ddp = world_size > 1
if is_ddp:
set_seed(seed)
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group('nccl', rank=rank, world_size=world_size)
print(f"{rank + 1}/{world_size} process initialized.")
model_args.update(
is_ddp = is_ddp,
rank = rank,
world_size = world_size
)
model = Trainer(**model_args, hparams=model_args, use_aim=use_aim, aim_repo=aim_repo, aim_run_hash=aim_run_hash)
if not new:
model.load(load_from)
else:
model.clear()
model.set_data_src(data)
progress_bar = tqdm(initial = model.steps, total = num_train_steps, mininterval=10., desc=f'{name}<{data}>')
while model.steps < num_train_steps:
retry_call(model.train, tries=3, exceptions=NanException)
progress_bar.n = model.steps
progress_bar.refresh()
if is_main and model.steps % 50 == 0:
model.print_log()
model.save(model.checkpoint_num)
if is_ddp:
dist.destroy_process_group()
def train_from_folder(
data = './data',
results_dir = './results',
models_dir = './models',
name = 'default',
new = False,
load_from = -1,
image_size = 256,
optimizer = 'adam',
fmap_max = 512,
transparent = False,
greyscale = False,
batch_size = 10,
gradient_accumulate_every = 4,
num_train_steps = 150000,
learning_rate = 2e-4,
save_every = 1000,
evaluate_every = 1000,
generate = False,
generate_types = ['default', 'ema'],
generate_interpolation = False,
aug_test = False,
aug_prob=None,
aug_types=['cutout', 'translation'],
dataset_aug_prob=0.,
attn_res_layers = [32],
freq_chan_attn = False,
disc_output_size = 1,
dual_contrast_loss = False,
antialias = False,
interpolation_num_steps = 100,
save_frames = False,
num_image_tiles = None,
num_workers = None,
multi_gpus = False,
calculate_fid_every = None,
calculate_fid_num_images = 12800,
clear_fid_cache = False,
seed = 42,
amp = False,
show_progress = False,
use_aim = False,
aim_repo = None,
aim_run_hash = None,
load_strict = True
):
num_image_tiles = default(num_image_tiles, 4 if image_size > 512 else 8)
model_args = dict(
name = name,
results_dir = results_dir,
models_dir = models_dir,
batch_size = batch_size,
gradient_accumulate_every = gradient_accumulate_every,
attn_res_layers = cast_list(attn_res_layers),
freq_chan_attn = freq_chan_attn,
disc_output_size = disc_output_size,
dual_contrast_loss = dual_contrast_loss,
antialias = antialias,
image_size = image_size,
num_image_tiles = num_image_tiles,
optimizer = optimizer,
num_workers = num_workers,
fmap_max = fmap_max,
transparent = transparent,
greyscale = greyscale,
lr = learning_rate,
save_every = save_every,
evaluate_every = evaluate_every,
aug_prob = aug_prob,
aug_types = cast_list(aug_types),
dataset_aug_prob = dataset_aug_prob,
calculate_fid_every = calculate_fid_every,
calculate_fid_num_images = calculate_fid_num_images,
clear_fid_cache = clear_fid_cache,
amp = amp,
load_strict = load_strict
)
if generate:
model = Trainer(**model_args, use_aim = use_aim)
model.load(load_from)
samples_name = timestamped_filename()
checkpoint = model.checkpoint_num
dir_result = model.generate(samples_name, num_image_tiles, checkpoint, generate_types)
print(f'sample images generated at {dir_result}')
return
if generate_interpolation:
model = Trainer(**model_args, use_aim = use_aim)
model.load(load_from)
samples_name = timestamped_filename()
model.generate_interpolation(samples_name, num_image_tiles, num_steps = interpolation_num_steps, save_frames = save_frames)
print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
return
if show_progress:
model = Trainer(**model_args, use_aim = use_aim)
model.show_progress(num_images=num_image_tiles, types=generate_types)
return
if aug_test:
DiffAugmentTest(data=data, image_size=image_size, batch_size=batch_size, types=aug_types, nrow=num_image_tiles)
return
world_size = torch.cuda.device_count()
if world_size == 1 or not multi_gpus:
run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed, use_aim, aim_repo, aim_run_hash)
return
mp.spawn(run_training,
args=(world_size, model_args, data, load_from, new, num_train_steps, name, seed, use_aim, aim_repo, aim_run_hash,),
nprocs=world_size,
join=True)
def main():
fire.Fire(train_from_folder)
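# A usage sketch (hedged; assumes this module is importable as
# `lightweight_gan.cli`): `train_from_folder` can also be called directly from
# Python, bypassing Fire, with the same keyword arguments the CLI exposes:
#
#   from lightweight_gan.cli import train_from_folder
#   train_from_folder(data='./data/flowers', image_size=256, num_train_steps=1000)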
import random
import torch
import torch.nn.functional as F
def DiffAugment(x, types=[]):
for p in types:
for f in AUGMENT_FNS[p]:
x = f(x)
return x.contiguous()
# """
# Augmentation functions got images as `x`
# where `x` is tensor with this dimensions:
# 0 - count of images
# 1 - channels
# 2 - width
# 3 - height of image
# """
def rand_brightness(x):
x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
return x
def rand_saturation(x):
x_mean = x.mean(dim=1, keepdim=True)
x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
return x
def rand_contrast(x):
x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
return x
def rand_translation(x, ratio=0.125):
shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
grid_batch, grid_x, grid_y = torch.meshgrid(
torch.arange(x.size(0), dtype=torch.long, device=x.device),
torch.arange(x.size(2), dtype=torch.long, device=x.device),
torch.arange(x.size(3), dtype=torch.long, device=x.device),
indexing = 'ij')
grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
return x
def rand_offset(x, ratio=1, ratio_h=1, ratio_v=1):
w, h = x.size(2), x.size(3)
imgs = []
for img in x.unbind(dim = 0):
max_h = int(w * ratio * ratio_h)
max_v = int(h * ratio * ratio_v)
value_h = random.randint(0, max_h) * 2 - max_h
value_v = random.randint(0, max_v) * 2 - max_v
if abs(value_h) > 0:
img = torch.roll(img, value_h, 2)
if abs(value_v) > 0:
img = torch.roll(img, value_v, 1)
imgs.append(img)
return torch.stack(imgs)
def rand_offset_h(x, ratio=1):
return rand_offset(x, ratio=1, ratio_h=ratio, ratio_v=0)
def rand_offset_v(x, ratio=1):
return rand_offset(x, ratio=1, ratio_h=0, ratio_v=ratio)
def rand_cutout(x, ratio=0.5):
cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
grid_batch, grid_x, grid_y = torch.meshgrid(
torch.arange(x.size(0), dtype=torch.long, device=x.device),
torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
indexing = 'ij')
grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
mask[grid_batch, grid_x, grid_y] = 0
x = x * mask.unsqueeze(1)
return x
AUGMENT_FNS = {
'color': [rand_brightness, rand_saturation, rand_contrast],
'offset': [rand_offset],
'offset_h': [rand_offset_h],
'offset_v': [rand_offset_v],
'translation': [rand_translation],
'cutout': [rand_cutout],
}
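# Usage sketch (assumes a float tensor in NCHW layout): each key in `types`
# selects a list of ops from AUGMENT_FNS, applied in order:
#
#   x = torch.rand(8, 3, 256, 256)
#   x_aug = DiffAugment(x, types=['color', 'translation', 'cutout'])
#
# Every op is differentiable, so gradients flow back to the generator.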
import os
import tempfile
from pathlib import Path
from shutil import copyfile
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from lightweight_gan.lightweight_gan import AugWrapper, ImageDataset
assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
class DummyModel(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x
@torch.no_grad()
def DiffAugmentTest(image_size = 256, data = './data/0.jpg', types = [], batch_size = 10, rank = 0, nrow = 5):
model = DummyModel()
aug_wrapper = AugWrapper(model, image_size)
with tempfile.TemporaryDirectory() as directory:
file = Path(data)
if os.path.exists(file):
file_name, ext = os.path.splitext(data)
for i in range(batch_size):
tmp_file_name = str(i) + ext
copyfile(file, os.path.join(directory, tmp_file_name))
dataset = ImageDataset(directory, image_size, aug_prob=0)
dataloader = DataLoader(dataset, batch_size=batch_size)
image_batch = next(iter(dataloader)).cuda(rank)
images_augment = aug_wrapper(images=image_batch, prob=1, types=types, detach=True)
save_result = file_name + f'_augs{ext}'
torchvision.utils.save_image(images_augment, save_result, nrow=nrow)
print('Save result to:', save_result)
else:
            print(f'File not found: {file}')
import os
import json
import multiprocessing
from random import random
import math
from math import log2, floor
from functools import lru_cache, partial
from contextlib import contextmanager, ExitStack
from pathlib import Path
from shutil import rmtree
import torch
from torch.cuda.amp import autocast, GradScaler
from torch.optim import Adam
from torch import nn, einsum
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd import grad as torch_grad
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from PIL import Image
import torchvision
from torchvision import transforms
from kornia.filters import filter2d
from lightweight_gan.diff_augment import DiffAugment
from lightweight_gan.version import __version__
from tqdm import tqdm
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange
from adabelief_pytorch import AdaBelief
# asserts
assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
# constants
NUM_CORES = multiprocessing.cpu_count()
EXTS = ['jpg', 'jpeg', 'png', 'tiff']
# helpers
def exists(val):
return val is not None
@contextmanager
def null_context():
yield
def combine_contexts(contexts):
@contextmanager
def multi_contexts():
with ExitStack() as stack:
yield [stack.enter_context(ctx()) for ctx in contexts]
return multi_contexts
def is_power_of_two(val):
return log2(val).is_integer()
def default(val, d):
return val if exists(val) else d
def set_requires_grad(model, bool):
for p in model.parameters():
p.requires_grad = bool
def cycle(iterable):
while True:
for i in iterable:
yield i
def raise_if_nan(t):
if torch.isnan(t):
raise NanException
def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps):
if is_ddp:
num_no_syncs = gradient_accumulate_every - 1
head = [combine_contexts(map(lambda ddp: ddp.no_sync, ddps))] * num_no_syncs
tail = [null_context]
contexts = head + tail
else:
contexts = [null_context] * gradient_accumulate_every
for context in contexts:
with context():
yield
def evaluate_in_chunks(max_batch_size, model, *args):
split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
chunked_outputs = [model(*i) for i in split_args]
if len(chunked_outputs) == 1:
return chunked_outputs[0]
return torch.cat(chunked_outputs, dim=0)
def slerp(val, low, high):
low_norm = low / torch.norm(low, dim=1, keepdim=True)
high_norm = high / torch.norm(high, dim=1, keepdim=True)
omega = torch.acos((low_norm * high_norm).sum(1))
so = torch.sin(omega)
res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
return res
def safe_div(n, d):
try:
res = n / d
except ZeroDivisionError:
prefix = '' if int(n >= 0) else '-'
res = float(f'{prefix}inf')
return res
# loss functions
def gen_hinge_loss(fake, real):
return fake.mean()
def hinge_loss(real, fake):
return (F.relu(1 + real) + F.relu(1 - fake)).mean()
def dual_contrastive_loss(real_logits, fake_logits):
device = real_logits.device
real_logits, fake_logits = map(lambda t: rearrange(t, '... -> (...)'), (real_logits, fake_logits))
def loss_half(t1, t2):
t1 = rearrange(t1, 'i -> i ()')
t2 = repeat(t2, 'j -> i j', i = t1.shape[0])
t = torch.cat((t1, t2), dim = -1)
return F.cross_entropy(t, torch.zeros(t1.shape[0], device = device, dtype = torch.long))
return loss_half(real_logits, fake_logits) + loss_half(-fake_logits, -real_logits)
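# Shape sketch: both arguments are logit tensors (any shape); they are
# flattened, and each real logit is contrasted via cross entropy against all
# fake logits (the correct class being the real one), then the same with roles
# and signs swapped, e.g.
#
#   loss = dual_contrastive_loss(torch.randn(16), torch.randn(16))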
@lru_cache(maxsize=10)
def det_randn(*args):
"""
deterministic random to track the same latent vars (and images) across training steps
helps to visualize same image over training steps
"""
return torch.randn(*args)
def interpolate_between(a, b, *, num_samples, dim):
assert num_samples > 2
samples = []
step_size = 0
for _ in range(num_samples):
sample = torch.lerp(a, b, step_size)
samples.append(sample)
step_size += 1 / (num_samples - 1)
return torch.stack(samples, dim=dim)
# helper classes
class NanException(Exception):
pass
class EMA():
def __init__(self, beta):
super().__init__()
self.beta = beta
def update_average(self, old, new):
if not exists(old):
return new
return old * self.beta + (1 - self.beta) * new
class RandomApply(nn.Module):
def __init__(self, prob, fn, fn_else = lambda x: x):
super().__init__()
self.fn = fn
self.fn_else = fn_else
self.prob = prob
def forward(self, x):
fn = self.fn if random() < self.prob else self.fn_else
return fn(x)
class ChanNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = ChanNorm(dim)
def forward(self, x):
return self.fn(self.norm(x))
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x) + x
class SumBranches(nn.Module):
def __init__(self, branches):
super().__init__()
self.branches = nn.ModuleList(branches)
def forward(self, x):
return sum(map(lambda fn: fn(x), self.branches))
class Blur(nn.Module):
def __init__(self):
super().__init__()
f = torch.Tensor([1, 2, 1])
self.register_buffer('f', f)
def forward(self, x):
f = self.f
        f = f[None, None, :] * f[None, :, None]
return filter2d(x, f, normalized=True)
class Noise(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.zeros(1))
def forward(self, x, noise = None):
b, _, h, w, device = *x.shape, x.device
if not exists(noise):
noise = torch.randn(b, 1, h, w, device = device)
return x + self.weight * noise
def Conv2dSame(dim_in, dim_out, kernel_size, bias = True):
pad_left = kernel_size // 2
pad_right = (pad_left - 1) if (kernel_size % 2) == 0 else pad_left
return nn.Sequential(
nn.ZeroPad2d((pad_left, pad_right, pad_left, pad_right)),
nn.Conv2d(dim_in, dim_out, kernel_size, bias = bias)
)
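# Note: this reproduces 'same' padding for stride-1 convolutions; even kernel
# sizes pad asymmetrically (one extra pixel on the left/top), so e.g.
# Conv2dSame(64, 64, 4) maps a (b, 64, h, w) input to (b, 64, h, w).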
# attention
class DepthWiseConv2d(nn.Module):
def __init__(self, dim_in, dim_out, kernel_size, padding = 0, stride = 1, bias = True):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
)
def forward(self, x):
return self.net(x)
class LinearAttention(nn.Module):
def __init__(self, dim, dim_head = 64, heads = 8, kernel_size = 3):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
self.dim_head = dim_head
inner_dim = dim_head * heads
self.kernel_size = kernel_size
self.nonlin = nn.GELU()
self.to_lin_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
self.to_lin_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding = 1, bias = False)
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim * 2, dim, 1)
def forward(self, fmap):
h, x, y = self.heads, *fmap.shape[-2:]
# linear attention
lin_q, lin_k, lin_v = (self.to_lin_q(fmap), *self.to_lin_kv(fmap).chunk(2, dim = 1))
lin_q, lin_k, lin_v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (lin_q, lin_k, lin_v))
lin_q = lin_q.softmax(dim = -1)
lin_k = lin_k.softmax(dim = -2)
lin_q = lin_q * self.scale
context = einsum('b n d, b n e -> b d e', lin_k, lin_v)
lin_out = einsum('b n d, b d e -> b n e', lin_q, context)
lin_out = rearrange(lin_out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
# conv-like full attention
q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1))
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) c x y', h = h), (q, k, v))
k = F.unfold(k, kernel_size = self.kernel_size, padding = self.kernel_size // 2)
v = F.unfold(v, kernel_size = self.kernel_size, padding = self.kernel_size // 2)
k, v = map(lambda t: rearrange(t, 'b (d j) n -> b n j d', d = self.dim_head), (k, v))
q = rearrange(q, 'b c ... -> b (...) c') * self.scale
sim = einsum('b i d, b i j d -> b i j', q, k)
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
attn = sim.softmax(dim = -1)
full_out = einsum('b i j, b i j d -> b i d', attn, v)
full_out = rearrange(full_out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
# add outputs of linear attention + conv like full attention
lin_out = self.nonlin(lin_out)
out = torch.cat((lin_out, full_out), dim = 1)
return self.to_out(out)
# dataset
def convert_image_to(img_type, image):
if image.mode != img_type:
return image.convert(img_type)
return image
class identity(object):
def __call__(self, tensor):
return tensor
class expand_greyscale(object):
def __init__(self, transparent):
self.transparent = transparent
def __call__(self, tensor):
channels = tensor.shape[0]
num_target_channels = 4 if self.transparent else 3
if channels == num_target_channels:
return tensor
alpha = None
if channels == 1:
color = tensor.expand(3, -1, -1)
elif channels == 2:
color = tensor[:1].expand(3, -1, -1)
alpha = tensor[1:]
else:
raise Exception(f'image with invalid number of channels given {channels}')
if not exists(alpha) and self.transparent:
alpha = torch.ones(1, *tensor.shape[1:], device=tensor.device)
return color if not self.transparent else torch.cat((color, alpha))
def resize_to_minimum_size(min_size, image):
if max(*image.size) < min_size:
return torchvision.transforms.functional.resize(image, min_size)
return image
class ImageDataset(Dataset):
def __init__(
self,
folder,
image_size,
transparent = False,
greyscale = False,
aug_prob = 0.
):
super().__init__()
self.folder = folder
self.image_size = image_size
self.paths = [p for ext in EXTS for p in Path(f'{folder}').glob(f'**/*.{ext}')]
assert len(self.paths) > 0, f'No images were found in {folder} for training'
if transparent:
num_channels = 4
pillow_mode = 'RGBA'
expand_fn = expand_greyscale(transparent)
elif greyscale:
num_channels = 1
pillow_mode = 'L'
expand_fn = identity()
else:
num_channels = 3
pillow_mode = 'RGB'
expand_fn = expand_greyscale(transparent)
convert_image_fn = partial(convert_image_to, pillow_mode)
self.transform = transforms.Compose([
transforms.Lambda(convert_image_fn),
transforms.Lambda(partial(resize_to_minimum_size, image_size)),
transforms.Resize(image_size),
RandomApply(aug_prob, transforms.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)), transforms.CenterCrop(image_size)),
transforms.ToTensor(),
transforms.Lambda(expand_fn)
])
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
path = self.paths[index]
img = Image.open(path)
return self.transform(img)
# augmentations
def random_hflip(tensor, prob):
if prob > random():
return tensor
return torch.flip(tensor, dims=(3,))
class AugWrapper(nn.Module):
def __init__(self, D, image_size):
super().__init__()
self.D = D
def forward(self, images, prob = 0., types = [], detach = False, **kwargs):
context = torch.no_grad if detach else null_context
with context():
if random() < prob:
images = random_hflip(images, prob=0.5)
images = DiffAugment(images, types=types)
return self.D(images, **kwargs)
# modifiable global variables
norm_class = nn.BatchNorm2d
class PixelShuffleUpsample(nn.Module):
def __init__(self, dim, dim_out = None):
super().__init__()
dim_out = default(dim_out, dim)
conv = nn.Conv2d(dim, dim_out * 4, 1)
self.net = nn.Sequential(
conv,
nn.SiLU(),
nn.PixelShuffle(2)
)
self.init_conv_(conv)
def init_conv_(self, conv):
o, i, h, w = conv.weight.shape
conv_weight = torch.empty(o // 4, i, h, w)
nn.init.kaiming_uniform_(conv_weight)
conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def forward(self, x):
return self.net(x)
def SPConvDownsample(dim, dim_out = None):
# https://arxiv.org/abs/2208.03641 shows this is the most optimal way to downsample
# named SP-conv in the paper, but basically a pixel unshuffle
dim_out = default(dim_out, dim)
return nn.Sequential(
Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
nn.Conv2d(dim * 4, dim_out, 1)
)
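# Shape sketch: the Rearrange is a pixel unshuffle, halving each spatial dim
# while multiplying channels by 4; the 1x1 conv then mixes channels down, e.g.
#
#   down = SPConvDownsample(64, 128)
#   down(torch.rand(2, 64, 32, 32)).shape  # -> torch.Size([2, 128, 16, 16])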
# squeeze excitation classes
# global context network
# https://arxiv.org/abs/2012.13375
# similar to squeeze-excite, but with a simplified attention pooling and a subsequent layer norm
class GlobalContext(nn.Module):
def __init__(
self,
*,
chan_in,
chan_out
):
super().__init__()
self.to_k = nn.Conv2d(chan_in, 1, 1)
chan_intermediate = max(3, chan_out // 2)
self.net = nn.Sequential(
nn.Conv2d(chan_in, chan_intermediate, 1),
nn.LeakyReLU(0.1),
nn.Conv2d(chan_intermediate, chan_out, 1),
nn.Sigmoid()
)
def forward(self, x):
context = self.to_k(x)
context = context.flatten(2).softmax(dim = -1)
out = einsum('b i n, b c n -> b c i', context, x.flatten(2))
out = out.unsqueeze(-1)
return self.net(out)
# frequency channel attention
# https://arxiv.org/abs/2012.11879
def get_1d_dct(i, freq, L):
result = math.cos(math.pi * freq * (i + 0.5) / L) / math.sqrt(L)
return result * (1 if freq == 0 else math.sqrt(2))
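# The orthonormal DCT-II basis computed above, for position i and frequency
# `freq` over length L:
#   B(i, freq) = cos(pi * freq * (i + 0.5) / L) / sqrt(L) * (sqrt(2) if freq > 0 else 1)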
def get_dct_weights(width, channel, fidx_u, fidx_v):
dct_weights = torch.zeros(1, channel, width, width)
c_part = channel // len(fidx_u)
for i, (u_x, v_y) in enumerate(zip(fidx_u, fidx_v)):
for x in range(width):
for y in range(width):
coor_value = get_1d_dct(x, u_x, width) * get_1d_dct(y, v_y, width)
dct_weights[:, i * c_part: (i + 1) * c_part, x, y] = coor_value
return dct_weights
class FCANet(nn.Module):
def __init__(
self,
*,
chan_in,
chan_out,
reduction = 4,
width
):
super().__init__()
freq_w, freq_h = ([0] * 8), list(range(8)) # in paper, it seems 16 frequencies was ideal
dct_weights = get_dct_weights(width, chan_in, [*freq_w, *freq_h], [*freq_h, *freq_w])
self.register_buffer('dct_weights', dct_weights)
chan_intermediate = max(3, chan_out // reduction)
self.net = nn.Sequential(
nn.Conv2d(chan_in, chan_intermediate, 1),
nn.LeakyReLU(0.1),
nn.Conv2d(chan_intermediate, chan_out, 1),
nn.Sigmoid()
)
def forward(self, x):
x = reduce(x * self.dct_weights, 'b c (h h1) (w w1) -> b c h1 w1', 'sum', h1 = 1, w1 = 1)
return self.net(x)
# generative adversarial network
class Generator(nn.Module):
def __init__(
self,
*,
image_size,
latent_dim = 256,
fmap_max = 512,
fmap_inverse_coef = 12,
transparent = False,
greyscale = False,
attn_res_layers = [],
freq_chan_attn = False
):
super().__init__()
resolution = log2(image_size)
assert is_power_of_two(image_size), 'image size must be a power of 2'
if transparent:
init_channel = 4
elif greyscale:
init_channel = 1
else:
init_channel = 3
fmap_max = default(fmap_max, latent_dim)
self.initial_conv = nn.Sequential(
nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
norm_class(latent_dim * 2),
nn.GLU(dim = 1)
)
num_layers = int(resolution) - 2
features = list(map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), range(2, num_layers + 2)))
features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))
features = list(map(lambda n: 3 if n[0] >= 8 else n[1], features))
features = [latent_dim, *features]
in_out_features = list(zip(features[:-1], features[1:]))
self.res_layers = range(2, num_layers + 2)
self.layers = nn.ModuleList([])
self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))
self.sle_map = ((3, 7), (4, 8), (5, 9), (6, 10))
self.sle_map = list(filter(lambda t: t[0] <= resolution and t[1] <= resolution, self.sle_map))
self.sle_map = dict(self.sle_map)
self.num_layers_spatial_res = 1
for (res, (chan_in, chan_out)) in zip(self.res_layers, in_out_features):
image_width = 2 ** res
attn = None
if image_width in attn_res_layers:
attn = PreNorm(chan_in, LinearAttention(chan_in))
sle = None
if res in self.sle_map:
residual_layer = self.sle_map[res]
sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]
if freq_chan_attn:
sle = FCANet(
chan_in = chan_out,
chan_out = sle_chan_out,
width = 2 ** (res + 1)
)
else:
sle = GlobalContext(
chan_in = chan_out,
chan_out = sle_chan_out
)
layer = nn.ModuleList([
nn.Sequential(
PixelShuffleUpsample(chan_in),
Blur(),
Conv2dSame(chan_in, chan_out * 2, 4),
Noise(),
norm_class(chan_out * 2),
nn.GLU(dim = 1)
),
sle,
attn
])
self.layers.append(layer)
self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding = 1)
def forward(self, x):
x = rearrange(x, 'b c -> b c () ()')
x = self.initial_conv(x)
x = F.normalize(x, dim = 1)
residuals = dict()
for (res, (up, sle, attn)) in zip(self.res_layers, self.layers):
if exists(attn):
x = attn(x) + x
x = up(x)
if exists(sle):
out_res = self.sle_map[res]
residual = sle(x)
residuals[out_res] = residual
next_res = res + 1
if next_res in residuals:
x = x * residuals[next_res]
return self.out_conv(x)
class SimpleDecoder(nn.Module):
def __init__(
self,
*,
chan_in,
chan_out = 3,
num_upsamples = 4,
):
super().__init__()
self.layers = nn.ModuleList([])
final_chan = chan_out
chans = chan_in
for ind in range(num_upsamples):
last_layer = ind == (num_upsamples - 1)
chan_out = chans if not last_layer else final_chan * 2
layer = nn.Sequential(
PixelShuffleUpsample(chans),
nn.Conv2d(chans, chan_out, 3, padding = 1),
nn.GLU(dim = 1)
)
self.layers.append(layer)
chans //= 2
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
class Discriminator(nn.Module):
def __init__(
self,
*,
image_size,
fmap_max = 512,
fmap_inverse_coef = 12,
transparent = False,
greyscale = False,
disc_output_size = 5,
attn_res_layers = []
):
super().__init__()
resolution = log2(image_size)
assert is_power_of_two(image_size), 'image size must be a power of 2'
assert disc_output_size in {1, 5}, 'discriminator output dimensions can only be 5x5 or 1x1'
resolution = int(resolution)
if transparent:
init_channel = 4
elif greyscale:
init_channel = 1
else:
init_channel = 3
num_non_residual_layers = max(0, int(resolution) - 8)
num_residual_layers = 8 - 3
non_residual_resolutions = range(min(8, resolution), 2, -1)
features = list(map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), non_residual_resolutions))
features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))
if num_non_residual_layers == 0:
res, _ = features[0]
features[0] = (res, init_channel)
chan_in_out = list(zip(features[:-1], features[1:]))
self.non_residual_layers = nn.ModuleList([])
for ind in range(num_non_residual_layers):
first_layer = ind == 0
last_layer = ind == (num_non_residual_layers - 1)
chan_out = features[0][-1] if last_layer else init_channel
self.non_residual_layers.append(nn.Sequential(
Blur(),
nn.Conv2d(init_channel, chan_out, 4, stride = 2, padding = 1),
nn.LeakyReLU(0.1)
))
self.residual_layers = nn.ModuleList([])
for (res, ((_, chan_in), (_, chan_out))) in zip(non_residual_resolutions, chan_in_out):
image_width = 2 ** res
attn = None
if image_width in attn_res_layers:
attn = PreNorm(chan_in, LinearAttention(chan_in))
self.residual_layers.append(nn.ModuleList([
SumBranches([
nn.Sequential(
Blur(),
SPConvDownsample(chan_in, chan_out),
nn.LeakyReLU(0.1),
nn.Conv2d(chan_out, chan_out, 3, padding = 1),
nn.LeakyReLU(0.1)
),
nn.Sequential(
Blur(),
nn.AvgPool2d(2),
nn.Conv2d(chan_in, chan_out, 1),
nn.LeakyReLU(0.1),
)
]),
attn
]))
last_chan = features[-1][-1]
if disc_output_size == 5:
self.to_logits = nn.Sequential(
nn.Conv2d(last_chan, last_chan, 1),
nn.LeakyReLU(0.1),
nn.Conv2d(last_chan, 1, 4)
)
elif disc_output_size == 1:
self.to_logits = nn.Sequential(
Blur(),
nn.Conv2d(last_chan, last_chan, 3, stride = 2, padding = 1),
nn.LeakyReLU(0.1),
nn.Conv2d(last_chan, 1, 4)
)
self.to_shape_disc_out = nn.Sequential(
nn.Conv2d(init_channel, 64, 3, padding = 1),
Residual(PreNorm(64, LinearAttention(64))),
SumBranches([
nn.Sequential(
Blur(),
SPConvDownsample(64, 32),
nn.LeakyReLU(0.1),
nn.Conv2d(32, 32, 3, padding = 1),
nn.LeakyReLU(0.1)
),
nn.Sequential(
Blur(),
nn.AvgPool2d(2),
nn.Conv2d(64, 32, 1),
nn.LeakyReLU(0.1),
)
]),
Residual(PreNorm(32, LinearAttention(32))),
nn.AdaptiveAvgPool2d((4, 4)),
nn.Conv2d(32, 1, 4)
)
self.decoder1 = SimpleDecoder(chan_in = last_chan, chan_out = init_channel)
self.decoder2 = SimpleDecoder(chan_in = features[-2][-1], chan_out = init_channel) if resolution >= 9 else None
def forward(self, x, calc_aux_loss = False):
orig_img = x
for layer in self.non_residual_layers:
x = layer(x)
layer_outputs = []
for (net, attn) in self.residual_layers:
if exists(attn):
x = attn(x) + x
x = net(x)
layer_outputs.append(x)
out = self.to_logits(x).flatten(1)
img_32x32 = F.interpolate(orig_img, size = (32, 32))
out_32x32 = self.to_shape_disc_out(img_32x32)
if not calc_aux_loss:
return out, out_32x32, None
# self-supervised auto-encoding loss
layer_8x8 = layer_outputs[-1]
layer_16x16 = layer_outputs[-2]
recon_img_8x8 = self.decoder1(layer_8x8)
aux_loss = F.mse_loss(
recon_img_8x8,
F.interpolate(orig_img, size = recon_img_8x8.shape[2:])
)
if exists(self.decoder2):
select_random_quadrant = lambda rand_quadrant, img: rearrange(img, 'b c (m h) (n w) -> (m n) b c h w', m = 2, n = 2)[rand_quadrant]
crop_image_fn = partial(select_random_quadrant, floor(random() * 4))
img_part, layer_16x16_part = map(crop_image_fn, (orig_img, layer_16x16))
recon_img_16x16 = self.decoder2(layer_16x16_part)
aux_loss_16x16 = F.mse_loss(
recon_img_16x16,
F.interpolate(img_part, size = recon_img_16x16.shape[2:])
)
aux_loss = aux_loss + aux_loss_16x16
return out, out_32x32, aux_loss
class LightweightGAN(nn.Module):
def __init__(
self,
*,
latent_dim,
image_size,
optimizer = "adam",
fmap_max = 512,
fmap_inverse_coef = 12,
transparent = False,
greyscale = False,
disc_output_size = 5,
attn_res_layers = [],
freq_chan_attn = False,
ttur_mult = 1.,
lr = 2e-4,
rank = 0,
ddp = False
):
super().__init__()
self.latent_dim = latent_dim
self.image_size = image_size
G_kwargs = dict(
image_size = image_size,
latent_dim = latent_dim,
fmap_max = fmap_max,
fmap_inverse_coef = fmap_inverse_coef,
transparent = transparent,
greyscale = greyscale,
attn_res_layers = attn_res_layers,
freq_chan_attn = freq_chan_attn
)
self.G = Generator(**G_kwargs)
self.D = Discriminator(
image_size = image_size,
fmap_max = fmap_max,
fmap_inverse_coef = fmap_inverse_coef,
transparent = transparent,
greyscale = greyscale,
attn_res_layers = attn_res_layers,
disc_output_size = disc_output_size
)
self.ema_updater = EMA(0.995)
self.GE = Generator(**G_kwargs)
set_requires_grad(self.GE, False)
if optimizer == "adam":
self.G_opt = Adam(self.G.parameters(), lr = lr, betas=(0.5, 0.9))
self.D_opt = Adam(self.D.parameters(), lr = lr * ttur_mult, betas=(0.5, 0.9))
elif optimizer == "adabelief":
self.G_opt = AdaBelief(self.G.parameters(), lr = lr, betas=(0.5, 0.9))
self.D_opt = AdaBelief(self.D.parameters(), lr = lr * ttur_mult, betas=(0.5, 0.9))
else:
assert False, "No valid optimizer is given"
self.apply(self._init_weights)
self.reset_parameter_averaging()
self.cuda(rank)
self.D_aug = AugWrapper(self.D, image_size)
def _init_weights(self, m):
if type(m) in {nn.Conv2d, nn.Linear}:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
def EMA(self):
def update_moving_average(ma_model, current_model):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = self.ema_updater.update_average(old_weight, up_weight)
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
new_buffer_value = self.ema_updater.update_average(ma_buffer, current_buffer)
ma_buffer.copy_(new_buffer_value)
update_moving_average(self.GE, self.G)
def reset_parameter_averaging(self):
self.GE.load_state_dict(self.G.state_dict())
def forward(self, x):
        raise NotImplementedError
# trainer
class Trainer():
def __init__(
self,
name = 'default',
results_dir = 'results',
models_dir = 'models',
base_dir = './',
optimizer = 'adam',
num_workers = None,
latent_dim = 256,
image_size = 128,
num_image_tiles = 8,
fmap_max = 512,
transparent = False,
greyscale = False,
batch_size = 4,
gp_weight = 10,
gradient_accumulate_every = 1,
attn_res_layers = [],
freq_chan_attn = False,
disc_output_size = 5,
dual_contrast_loss = False,
antialias = False,
lr = 2e-4,
lr_mlp = 1.,
ttur_mult = 1.,
save_every = 1000,
evaluate_every = 1000,
aug_prob = None,
aug_types = ['translation', 'cutout'],
dataset_aug_prob = 0.,
calculate_fid_every = None,
calculate_fid_num_images = 12800,
clear_fid_cache = False,
is_ddp = False,
rank = 0,
world_size = 1,
log = False,
amp = False,
hparams = None,
use_aim = True,
aim_repo = None,
aim_run_hash = None,
load_strict = True,
*args,
**kwargs
):
self.GAN_params = [args, kwargs]
self.GAN = None
self.name = name
base_dir = Path(base_dir)
self.base_dir = base_dir
self.results_dir = base_dir / results_dir
self.models_dir = base_dir / models_dir
self.fid_dir = base_dir / 'fid' / name
self.config_path = self.models_dir / name / '.config.json'
assert is_power_of_two(image_size), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
assert all(map(is_power_of_two, attn_res_layers)), 'resolution layers of attention must all be powers of 2 (16, 32, 64, 128, 256, 512)'
assert not (dual_contrast_loss and disc_output_size > 1), 'discriminator output size cannot be greater than 1 if using dual contrastive loss'
self.image_size = image_size
self.num_image_tiles = num_image_tiles
self.latent_dim = latent_dim
self.fmap_max = fmap_max
self.transparent = transparent
self.greyscale = greyscale
assert (int(self.transparent) + int(self.greyscale)) < 2, 'you can only set either transparency or greyscale'
self.aug_prob = aug_prob
self.aug_types = aug_types
self.lr = lr
self.optimizer = optimizer
self.num_workers = num_workers
self.ttur_mult = ttur_mult
self.batch_size = batch_size
self.gradient_accumulate_every = gradient_accumulate_every
self.gp_weight = gp_weight
self.evaluate_every = evaluate_every
self.save_every = save_every
self.steps = 0
self.attn_res_layers = attn_res_layers
self.freq_chan_attn = freq_chan_attn
self.disc_output_size = disc_output_size
self.antialias = antialias
self.dual_contrast_loss = dual_contrast_loss
self.d_loss = 0
self.g_loss = 0
self.last_gp_loss = None
self.last_recon_loss = None
self.last_fid = None
self.init_folders()
self.loader = None
self.dataset_aug_prob = dataset_aug_prob
self.calculate_fid_every = calculate_fid_every
self.calculate_fid_num_images = calculate_fid_num_images
self.clear_fid_cache = clear_fid_cache
self.is_ddp = is_ddp
self.is_main = rank == 0
self.rank = rank
self.world_size = world_size
self.syncbatchnorm = is_ddp
self.load_strict = load_strict
self.amp = amp
self.G_scaler = GradScaler(enabled = self.amp)
self.D_scaler = GradScaler(enabled = self.amp)
self.run = None
self.hparams = hparams
if self.is_main and use_aim:
try:
import aim
self.aim = aim
except ImportError:
                print('unable to import aim experiment tracker - please run `pip install aim` first')
                raise
self.run = self.aim.Run(run_hash=aim_run_hash, repo=aim_repo)
self.run['hparams'] = hparams
@property
def image_extension(self):
return 'jpg' if not self.transparent else 'png'
@property
def checkpoint_num(self):
return floor(self.steps // self.save_every)
def init_GAN(self):
args, kwargs = self.GAN_params
# set some global variables before instantiating GAN
global norm_class
global Blur
norm_class = nn.SyncBatchNorm if self.syncbatchnorm else nn.BatchNorm2d
Blur = nn.Identity if not self.antialias else Blur
# handle bugs when
# switching from multi-gpu back to single gpu
if self.syncbatchnorm and not self.is_ddp:
import torch.distributed as dist
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group('nccl', rank=0, world_size=1)
# instantiate GAN
self.GAN = LightweightGAN(
optimizer=self.optimizer,
lr = self.lr,
latent_dim = self.latent_dim,
attn_res_layers = self.attn_res_layers,
freq_chan_attn = self.freq_chan_attn,
image_size = self.image_size,
ttur_mult = self.ttur_mult,
fmap_max = self.fmap_max,
disc_output_size = self.disc_output_size,
transparent = self.transparent,
greyscale = self.greyscale,
rank = self.rank,
*args,
**kwargs
)
if self.is_ddp:
ddp_kwargs = {'device_ids': [self.rank], 'output_device': self.rank, 'find_unused_parameters': True}
self.G_ddp = DDP(self.GAN.G, **ddp_kwargs)
self.D_ddp = DDP(self.GAN.D, **ddp_kwargs)
self.D_aug_ddp = DDP(self.GAN.D_aug, **ddp_kwargs)
def write_config(self):
self.config_path.write_text(json.dumps(self.config()))
def load_config(self):
config = self.config() if not self.config_path.exists() else json.loads(self.config_path.read_text())
self.image_size = config['image_size']
self.transparent = config['transparent']
self.syncbatchnorm = config['syncbatchnorm']
self.disc_output_size = config['disc_output_size']
self.greyscale = config.pop('greyscale', False)
self.attn_res_layers = config.pop('attn_res_layers', [])
self.freq_chan_attn = config.pop('freq_chan_attn', False)
self.optimizer = config.pop('optimizer', 'adam')
self.fmap_max = config.pop('fmap_max', 512)
del self.GAN
self.init_GAN()
def config(self):
return {
'image_size': self.image_size,
'transparent': self.transparent,
'greyscale': self.greyscale,
'syncbatchnorm': self.syncbatchnorm,
'disc_output_size': self.disc_output_size,
'optimizer': self.optimizer,
'attn_res_layers': self.attn_res_layers,
'freq_chan_attn': self.freq_chan_attn
}
def set_data_src(self, folder):
num_workers = default(self.num_workers, math.ceil(NUM_CORES / self.world_size))
self.dataset = ImageDataset(folder, self.image_size, transparent = self.transparent, greyscale = self.greyscale, aug_prob = self.dataset_aug_prob)
sampler = DistributedSampler(self.dataset, rank=self.rank, num_replicas=self.world_size, shuffle=True) if self.is_ddp else None
dataloader = DataLoader(self.dataset, num_workers = num_workers, batch_size = math.ceil(self.batch_size / self.world_size), sampler = sampler, shuffle = not self.is_ddp, drop_last = True, pin_memory = True)
self.loader = cycle(dataloader)
# auto set augmentation prob for user if dataset is detected to be low
num_samples = len(self.dataset)
if not exists(self.aug_prob) and num_samples < 1e5:
self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6)
print(f'autosetting augmentation probability to {round(self.aug_prob * 100)}%')
def train(self):
assert exists(self.loader), 'You must first initialize the data source with `.set_data_src(<folder of images>)`'
device = torch.device(f'cuda:{self.rank}')
if not exists(self.GAN):
self.init_GAN()
self.GAN.train()
total_disc_loss = torch.zeros([], device=device)
total_gen_loss = torch.zeros([], device=device)
batch_size = math.ceil(self.batch_size / self.world_size)
image_size = self.GAN.image_size
latent_dim = self.GAN.latent_dim
aug_prob = default(self.aug_prob, 0)
aug_types = self.aug_types
aug_kwargs = {'prob': aug_prob, 'types': aug_types}
G = self.GAN.G if not self.is_ddp else self.G_ddp
D = self.GAN.D if not self.is_ddp else self.D_ddp
D_aug = self.GAN.D_aug if not self.is_ddp else self.D_aug_ddp
apply_gradient_penalty = self.steps % 4 == 0
# amp related contexts and functions
amp_context = autocast if self.amp else null_context
# discriminator loss fn
if self.dual_contrast_loss:
D_loss_fn = dual_contrastive_loss
else:
D_loss_fn = hinge_loss
# train discriminator
self.GAN.D_opt.zero_grad()
for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[D_aug, G]):
latents = torch.randn(batch_size, latent_dim).cuda(self.rank)
image_batch = next(self.loader).cuda(self.rank)
image_batch.requires_grad_()
with amp_context():
with torch.no_grad():
generated_images = G(latents)
fake_output, fake_output_32x32, _ = D_aug(generated_images, detach = True, **aug_kwargs)
real_output, real_output_32x32, real_aux_loss = D_aug(image_batch, calc_aux_loss = True, **aug_kwargs)
real_output_loss = real_output
fake_output_loss = fake_output
divergence = D_loss_fn(real_output_loss, fake_output_loss)
divergence_32x32 = D_loss_fn(real_output_32x32, fake_output_32x32)
disc_loss = divergence + divergence_32x32
aux_loss = real_aux_loss
disc_loss = disc_loss + aux_loss
if apply_gradient_penalty:
outputs = [real_output, real_output_32x32]
outputs = list(map(self.D_scaler.scale, outputs)) if self.amp else outputs
scaled_gradients = torch_grad(outputs=outputs, inputs=image_batch,
grad_outputs=list(map(lambda t: torch.ones(t.size(), device = image_batch.device), outputs)),
create_graph=True, retain_graph=True, only_inputs=True)[0]
inv_scale = safe_div(1., self.D_scaler.get_scale()) if self.amp else 1.
if inv_scale != float('inf'):
gradients = scaled_gradients * inv_scale
with amp_context():
gradients = gradients.reshape(batch_size, -1)
gp = self.gp_weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
if not torch.isnan(gp):
disc_loss = disc_loss + gp
self.last_gp_loss = gp.clone().detach().item()
with amp_context():
disc_loss = disc_loss / self.gradient_accumulate_every
disc_loss.register_hook(raise_if_nan)
self.D_scaler.scale(disc_loss).backward()
total_disc_loss += divergence
self.last_recon_loss = aux_loss.item()
self.d_loss = float(total_disc_loss.item() / self.gradient_accumulate_every)
self.D_scaler.step(self.GAN.D_opt)
self.D_scaler.update()
# generator loss fn
if self.dual_contrast_loss:
G_loss_fn = dual_contrastive_loss
G_requires_calc_real = True
else:
G_loss_fn = gen_hinge_loss
G_requires_calc_real = False
# train generator
self.GAN.G_opt.zero_grad()
for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[G, D_aug]):
latents = torch.randn(batch_size, latent_dim).cuda(self.rank)
if G_requires_calc_real:
image_batch = next(self.loader).cuda(self.rank)
image_batch.requires_grad_()
with amp_context():
generated_images = G(latents)
fake_output, fake_output_32x32, _ = D_aug(generated_images, **aug_kwargs)
real_output, real_output_32x32, _ = D_aug(image_batch, **aug_kwargs) if G_requires_calc_real else (None, None, None)
loss = G_loss_fn(fake_output, real_output)
loss_32x32 = G_loss_fn(fake_output_32x32, real_output_32x32)
gen_loss = loss + loss_32x32
gen_loss = gen_loss / self.gradient_accumulate_every
gen_loss.register_hook(raise_if_nan)
self.G_scaler.scale(gen_loss).backward()
total_gen_loss += loss
self.g_loss = float(total_gen_loss.item() / self.gradient_accumulate_every)
self.G_scaler.step(self.GAN.G_opt)
self.G_scaler.update()
# calculate moving averages
if self.is_main and self.steps % 10 == 0 and self.steps > 20000:
self.GAN.EMA()
if self.is_main and self.steps <= 25000 and self.steps % 1000 == 2:
self.GAN.reset_parameter_averaging()
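# Schedule note: the EMA copy (GE) only starts accumulating (every 10 steps)
# once training passes 20k steps; before 25k steps, `reset_parameter_averaging`
# (defined on the GAN) periodically re-seeds the averaged weights from the live
# networks, so early unstable weights do not dominate the average.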
# recover from NaN errors by reloading the last checkpoint
if any(torch.isnan(l) for l in (total_gen_loss, total_disc_loss)):
print(f'NaN detected for generator or discriminator. Loading from checkpoint #{self.checkpoint_num}')
self.load(self.checkpoint_num)
raise NanException
del total_disc_loss
del total_gen_loss
# periodically save results
if self.is_main:
if self.steps % self.save_every == 0:
self.save(self.checkpoint_num)
if self.steps % self.evaluate_every == 0 or (self.steps % 100 == 0 and self.steps < 20000):
self.evaluate(floor(self.steps / self.evaluate_every), num_image_tiles = self.num_image_tiles)
if exists(self.calculate_fid_every) and self.steps % self.calculate_fid_every == 0 and self.steps != 0:
num_batches = math.ceil(self.calculate_fid_num_images / self.batch_size)
fid = self.calculate_fid(num_batches)
self.last_fid = fid
with open(str(self.results_dir / self.name / 'fid_scores.txt'), 'a') as f:
f.write(f'{self.steps},{fid}\n')
self.steps += 1
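# A minimal sketch of the gradient penalty computed in `train` above, with the
# AMP loss-scaling bookkeeping stripped away (assumes plain FP32; `critic` is a
# hypothetical discriminator returning one logit per image):
def gradient_penalty_sketch(critic, images, weight = 10.):
    images.requires_grad_()
    output = critic(images)
    gradients = torch_grad(outputs = output, inputs = images,
                           grad_outputs = torch.ones_like(output),
                           create_graph = True, retain_graph = True, only_inputs = True)[0]
    # penalize the distance of the per-sample gradient norm from 1
    gradients = gradients.reshape(images.shape[0], -1)
    return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()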
@torch.no_grad()
def evaluate(self, num = 0, num_image_tiles = 4):
self.GAN.eval()
ext = self.image_extension
num_rows = num_image_tiles
latent_dim = self.GAN.latent_dim
image_size = self.GAN.image_size
# latents and noise
def image_to_pil(image):
ndarr = image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
im = Image.fromarray(ndarr)
return im
latents = det_randn((num_rows ** 2, latent_dim)).cuda(self.rank)
interpolate_latents = interpolate_between(latents[:num_rows], latents[-num_rows:],
num_samples=num_rows,
dim=0).flatten(end_dim=1)
generate_interpolations = self.generate_(self.GAN.G, interpolate_latents)
if self.run is not None:
grouped = generate_interpolations.view(num_rows, num_rows, *generate_interpolations.shape[1:])
for idx, images in enumerate(grouped):
alpha = idx / (len(grouped) - 1)
aim_images = []
for image in images:
im = image_to_pil(image)
aim_images.append(self.aim.Image(im, caption=f'#{idx}'))
self.run.track(value=aim_images, name='generated',
step=self.steps,
context={'interpolated': True,
'alpha': alpha})
torchvision.utils.save_image(generate_interpolations, str(self.results_dir / self.name / f'{str(num)}-interp.{ext}'), nrow=num_rows)
# regular
generated_images = self.generate_(self.GAN.G, latents)
if self.run is not None:
aim_images = []
for idx, image in enumerate(generated_images):
im = image_to_pil(image)
aim_images.append(self.aim.Image(im, caption=f'#{idx}'))
self.run.track(value=aim_images, name='generated',
step=self.steps,
context={'ema': False})
torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}.{ext}'), nrow=num_rows)
# moving averages
generated_images = self.generate_(self.GAN.GE, latents)
if self.run is not None:
aim_images = []
for idx, image in enumerate(generated_images):
im = image_to_pil(image)
aim_images.append(self.aim.Image(im, caption=f'EMA #{idx}'))
self.run.track(value=aim_images, name='generated',
step=self.steps,
context={'ema': True})
torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-ema.{ext}'), nrow=num_rows)
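# Hedged sketch of what `interpolate_between` (defined earlier in this file)
# plausibly does: linearly interpolate between two latent batches, stacking the
# steps along a new leading dimension, which `.flatten(end_dim=1)` above then
# merges into one batch.
def interpolate_between_sketch(a, b, num_samples, dim = 0):
    ts = torch.linspace(0., 1., num_samples, device = a.device)
    return torch.stack([torch.lerp(a, b, t) for t in ts], dim = dim)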
@torch.no_grad()
def generate(self, num=0, num_image_tiles=4, checkpoint=None, types=['default', 'ema']):
self.GAN.eval()
latent_dim = self.GAN.latent_dim
dir_name = self.name + '-generated-' + str(checkpoint)
dir_full = Path().absolute() / self.results_dir / dir_name
ext = self.image_extension
dir_full.mkdir(parents=True, exist_ok=True)
# regular
if 'default' in types:
for i in tqdm(range(num_image_tiles), desc='Saving generated default images'):
latents = torch.randn((1, latent_dim)).cuda(self.rank)
generated_image = self.generate_(self.GAN.G, latents)
path = str(self.results_dir / dir_name / f'{str(num)}-{str(i)}.{ext}')
torchvision.utils.save_image(generated_image[0], path, nrow=1)
# moving averages
if 'ema' in types:
for i in tqdm(range(num_image_tiles), desc='Saving generated EMA images'):
latents = torch.randn((1, latent_dim)).cuda(self.rank)
generated_image = self.generate_(self.GAN.GE, latents)
path = str(self.results_dir / dir_name / f'{str(num)}-{str(i)}-ema.{ext}')
torchvision.utils.save_image(generated_image[0], path, nrow=1)
return dir_full
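# Usage sketch (flag names per the project README; the run name and checkpoint
# number are illustrative):
#   $ lightweight_gan --name my-run --load-from 50 --generate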
@torch.no_grad()
def show_progress(self, num_images=4, types=['default', 'ema']):
checkpoints = self.get_checkpoints()
assert exists(checkpoints), 'cannot find any checkpoints to create a training progress video for'
dir_name = self.name + '-progress'
dir_full = Path().absolute() / self.results_dir / dir_name
ext = self.image_extension
latents = None
zfill_length = len(str(checkpoints[-1])) # checkpoints are sorted ascending, so pad to the widest number
dir_full.mkdir(parents=True, exist_ok=True)
for checkpoint in tqdm(checkpoints, desc='Generating progress images'):
self.load(checkpoint, print_version=False)
self.GAN.eval()
if checkpoint == 0:
latents = torch.randn((num_images, self.GAN.latent_dim)).cuda(self.rank)
# regular
if 'default' in types:
generated_image = self.generate_(self.GAN.G, latents)
path = str(self.results_dir / dir_name / f'{str(checkpoint).zfill(zfill_length)}.{ext}')
torchvision.utils.save_image(generated_image, path, nrow=num_images)
# moving averages
if 'ema' in types:
generated_image = self.generate_(self.GAN.GE, latents)
path = str(self.results_dir / dir_name / f'{str(checkpoint).zfill(zfill_length)}-ema.{ext}')
torchvision.utils.save_image(generated_image, path, nrow=num_images)
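# The saved frames can be stitched into a video with ffmpeg, e.g. (path and
# extension illustrative; match the glob to the configured image extension):
#   $ ffmpeg -framerate 10 -pattern_type glob -i 'results/my-run-progress/*.jpg' progress.mp4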
@torch.no_grad()
def calculate_fid(self, num_batches):
from pytorch_fid import fid_score
torch.cuda.empty_cache()
real_path = self.fid_dir / 'real'
fake_path = self.fid_dir / 'fake'
# remove any existing files used for fid calculation and recreate directories
if not real_path.exists() or self.clear_fid_cache:
rmtree(real_path, ignore_errors=True)
os.makedirs(real_path)
for batch_num in tqdm(range(num_batches), desc='calculating FID - saving reals'):
real_batch = next(self.loader)
for k, image in enumerate(real_batch.unbind(0)):
ind = k + batch_num * self.batch_size
torchvision.utils.save_image(image, real_path / f'{ind}.png')
# generate a batch of fake images into fid_dir / 'fake'
rmtree(fake_path, ignore_errors=True)
os.makedirs(fake_path)
self.GAN.eval()
ext = self.image_extension
latent_dim = self.GAN.latent_dim
image_size = self.GAN.image_size
for batch_num in tqdm(range(num_batches), desc='calculating FID - saving generated'):
# latents and noise
latents = torch.randn(self.batch_size, latent_dim).cuda(self.rank)
# moving averages
generated_images = self.generate_(self.GAN.GE, latents)
for j, image in enumerate(generated_images.unbind(0)):
ind = j + batch_num * self.batch_size
torchvision.utils.save_image(image, str(fake_path / f'{str(ind)}-ema.{ext}'))
return fid_score.calculate_fid_given_paths([str(real_path), str(fake_path)], 256, latents.device, 2048)
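# Note: pytorch-fid's positional arguments are (paths, batch_size, device, dims);
# dims=2048 selects Inception-v3's final average-pool features, the standard
# choice for FID.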
@torch.no_grad()
def generate_(self, G, style, num_image_tiles = 8):
generated_images = evaluate_in_chunks(self.batch_size, G, style)
return generated_images.clamp_(0., 1.)
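# Hedged sketch of `evaluate_in_chunks` (the real helper is defined earlier in
# this file): run the model over its inputs in chunks of at most `max_batch_size`
# and concatenate the results, bounding peak GPU memory.
def evaluate_in_chunks_sketch(max_batch_size, model, *args):
    split_args = list(zip(*map(lambda x: x.split(max_batch_size, dim = 0), args)))
    chunked_outputs = [model(*a) for a in split_args]
    if len(chunked_outputs) == 1:
        return chunked_outputs[0]
    return torch.cat(chunked_outputs, dim = 0)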
@torch.no_grad()
def generate_interpolation(self, num = 0, num_image_tiles = 8, num_steps = 100, save_frames = False):
self.GAN.eval()
ext = self.image_extension
num_rows = num_image_tiles
latent_dim = self.GAN.latent_dim
image_size = self.GAN.image_size
# latents and noise
latents_low = torch.randn(num_rows ** 2, latent_dim).cuda(self.rank)
latents_high = torch.randn(num_rows ** 2, latent_dim).cuda(self.rank)
ratios = torch.linspace(0., 8., num_steps)
frames = []
for ratio in tqdm(ratios):
interp_latents = slerp(ratio, latents_low, latents_high)
generated_images = self.generate_(self.GAN.GE, interp_latents)
images_grid = torchvision.utils.make_grid(generated_images, nrow = num_rows)
pil_image = transforms.ToPILImage()(images_grid.cpu())
if self.transparent:
background = Image.new('RGBA', pil_image.size, (255, 255, 255))
pil_image = Image.alpha_composite(background, pil_image)
frames.append(pil_image)
frames[0].save(str(self.results_dir / self.name / f'{str(num)}.gif'), save_all=True, append_images=frames[1:], duration=80, loop=0, optimize=True)
if save_frames:
folder_path = (self.results_dir / self.name / f'{str(num)}')
folder_path.mkdir(parents=True, exist_ok=True)
for ind, frame in enumerate(frames):
frame.save(str(folder_path / f'{str(ind)}.{ext}'))
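# Hedged sketch of `slerp` (spherical linear interpolation; the real helper is
# defined earlier in this file): walk along the great circle between two latent
# batches. Values of `val` beyond [0, 1] continue around the hypersphere, which
# is presumably why the ratios above run from 0 to 8.
def slerp_sketch(val, low, high):
    low_norm = low / torch.norm(low, dim = 1, keepdim = True)
    high_norm = high / torch.norm(high, dim = 1, keepdim = True)
    # clamp to avoid NaNs from acos when the dot product drifts past +/-1
    omega = torch.acos((low_norm * high_norm).sum(1).clamp(-1 + 1e-7, 1 - 1e-7))
    so = torch.sin(omega)
    return (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low \
         + (torch.sin(val * omega) / so).unsqueeze(1) * high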
def print_log(self):
data = [
('G', self.g_loss),
('D', self.d_loss),
('GP', self.last_gp_loss),
('SS', self.last_recon_loss),
('FID', self.last_fid)
]
data = [d for d in data if exists(d[1])]
log = ' | '.join(map(lambda n: f'{n[0]}: {n[1]:.2f}', data))
print(log)
if self.run is not None:
for key, value in data:
self.run.track(value, key, step=self.steps)
return data
def model_name(self, num):
return str(self.models_dir / self.name / f'model_{num}.pt')
def init_folders(self):
(self.results_dir / self.name).mkdir(parents=True, exist_ok=True)
(self.models_dir / self.name).mkdir(parents=True, exist_ok=True)
def clear(self):
rmtree(str(self.models_dir / self.name), True)
rmtree(str(self.results_dir / self.name), True)
rmtree(str(self.fid_dir), True)
rmtree(str(self.config_path), True)
self.init_folders()
def save(self, num):
save_data = {
'GAN': self.GAN.state_dict(),
'version': __version__,
'G_scaler': self.G_scaler.state_dict(),
'D_scaler': self.D_scaler.state_dict()
}
torch.save(save_data, self.model_name(num))
self.write_config()
def load(self, num=-1, print_version=True):
self.load_config()
name = num
if num == -1:
checkpoints = self.get_checkpoints()
if not exists(checkpoints):
return
name = checkpoints[-1]
print(f'continuing from previous epoch - {name}')
self.steps = name * self.save_every
load_data = torch.load(self.model_name(name))
if print_version and 'version' in load_data and self.is_main:
print(f"loading from version {load_data['version']}")
try:
self.GAN.load_state_dict(load_data['GAN'], strict = self.load_strict)
except Exception as e:
saved_version = load_data['version']
print(f'unable to load saved model. please try downgrading the package to the version specified by the saved model (to do so, run `pip install lightweight-gan=={saved_version}`)')
raise e
if 'G_scaler' in load_data:
self.G_scaler.load_state_dict(load_data['G_scaler'])
if 'D_scaler' in load_data:
self.D_scaler.load_state_dict(load_data['D_scaler'])
def get_checkpoints(self):
file_paths = list((self.models_dir / self.name).glob('model_*.pt'))
saved_nums = sorted(map(lambda x: int(x.stem.split('_')[1]), file_paths))
if len(saved_nums) == 0:
return None
return saved_nums
__version__ = '1.1.1'
import sys
from setuptools import setup, find_packages
sys.path[0:0] = ['lightweight_gan']
from version import __version__
setup(
name = 'lightweight-gan',
packages = find_packages(),
entry_points={
'console_scripts': [
'lightweight_gan = lightweight_gan.cli:main',
],
},
version = __version__,
license='MIT',
description = 'Lightweight GAN',
author = 'Phil Wang',
author_email = 'lucidrains@gmail.com',
url = 'https://github.com/lucidrains/lightweight-gan',
keywords = [
'artificial intelligence',
'deep learning',
'generative adversarial networks'
],
install_requires=[
'adabelief-pytorch',
'einops>=0.3',
'fire',
'kornia>=0.5.4',
'numpy',
'pillow',
'retry',
'torch>=1.10',
'torchvision',
'tqdm'
],
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
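# Installation and a first training run (data path illustrative; see the project
# README for the full set of flags):
#   $ pip install lightweight-gan
#   $ lightweight_gan --data ./path/to/images --image-size 512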