In this assignment, we explored the training of Generative Adversarial Networks (GANs). Specifically, we trained two different types of GANs on two different tasks: one generates photorealistic images and the other performs style transfer. Through this assignment, we systematically learned how to use GANs to tackle real-world problems.
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt
import numpy as np
For calculating the padding size, we follow the equation $$ D_{out} = \left\lfloor \frac{D_{in} + 2 \cdot \text{padding} - \text{kernel\_size}}{\text{stride}} + 1 \right\rfloor, $$ where $D$ denotes a spatial dimension (e.g. height or width) of the input and output tensors.
For conv1 through conv4, we want each output dimension to be half of the input dimension; with $\text{kernel\_size} = 4$ and $\text{stride} = 2$, the equation gives $\text{padding} = 1$ (e.g. $D_{in} = 64$ yields $\lfloor (64 + 2 - 4)/2 + 1 \rfloor = 32$).
For conv5, we want the padding size to be zero.
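As a sanity check, we can evaluate the formula directly. Below is a minimal sketch; the 64-pixel input and the stride-1 conv5 are illustrative assumptions, not values taken from the starter code.
import math

def conv_out_size(d_in, kernel_size, stride, padding):
    # Output spatial dimension of a conv layer, matching the formula above.
    return math.floor((d_in + 2 * padding - kernel_size) / stride + 1)

# conv1-conv4: kernel_size=4, stride=2, padding=1 halves each dimension.
assert conv_out_size(64, kernel_size=4, stride=2, padding=1) == 32
assert conv_out_size(32, kernel_size=4, stride=2, padding=1) == 16
# conv5: padding=0 (here assumed to collapse a 4x4 map with a 4x4 kernel).
assert conv_out_size(4, kernel_size=4, stride=1, padding=0) == 1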
In this section, all our DCGANs are trained with conv_dim=64. We show screenshots of the discriminator and generator training losses with both --data_aug=basic and --data_aug=deluxe. The smoothed training losses of both networks decrease over the course of training, which indicates that the networks are learning as expected.
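For reference, here is a minimal sketch of what the two augmentation pipelines could look like with torchvision; the exact resize factor and the contents of the deluxe pipeline are our assumptions, not the handout's definition.
import torchvision.transforms as T

def get_transform(data_aug, image_size=64):
    # Basic vs. deluxe preprocessing (the deluxe contents are assumed).
    if data_aug == 'basic':
        transforms = [T.Resize(image_size), T.ToTensor()]
    else:  # 'deluxe': assumed resize -> random crop -> random horizontal flip
        transforms = [
            T.Resize(int(1.1 * image_size)),
            T.RandomCrop(image_size),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
        ]
    # Scale pixels from [0, 1] to [-1, 1] to match a tanh generator output.
    transforms.append(T.Normalize(mean=[0.5] * 3, std=[0.5] * 3))
    return T.Compose(transforms)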
image = media.read_image('figs/DCGAN_discriminator_training_loss.png')
media.show_image(image, border=True, title='discriminator loss')
image = media.read_image('figs/DCGAN_generator_training_loss.png')
media.show_image(image, border=True, title='generator loss')
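The two curves above are the adversarial losses for the two networks. A minimal sketch of how they can be computed, assuming the least-squares GAN objective (the exact objective in our training code may differ):
def discriminator_loss(D, real_images, fake_images):
    # Least-squares GAN loss for D: push D(real) -> 1 and D(fake) -> 0.
    d_real = D(real_images)
    d_fake = D(fake_images.detach())  # do not backprop into the generator here
    return 0.5 * ((d_real - 1) ** 2).mean() + 0.5 * (d_fake ** 2).mean()

def generator_loss(D, fake_images):
    # Least-squares GAN loss for G: push D(fake) -> 1 to fool the discriminator.
    return 0.5 * ((D(fake_images) - 1) ** 2).mean()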
For the output samples of the DCGAN trained with deluxe data augmentation, we show one sample grid every 200 iterations, from iteration 200 to 1200. The generated results clearly improve over time: the pixel-level noise decreases dramatically, and the cat-related features become clearer and more dominant.
images = {
'iter 200': media.read_image('figs/DCGAN_200.png'),
'iter 400': media.read_image('figs/DCGAN_400.png'),
'iter 600': media.read_image('figs/DCGAN_600.png'),
'iter 800': media.read_image('figs/DCGAN_800.png'),
'iter 1K': media.read_image('figs/DCGAN_1000.png'),
'iter 1.2K': media.read_image('figs/DCGAN_1200.png'),
}
media.show_images(images, border=True, height=300)
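The grids above are sampled at fixed intervals; typically the same fixed noise vectors are reused so that checkpoints are directly comparable. A sketch of that sampling step, where the generator G and noise_size are hypothetical placeholders for our training code:
import torch

@torch.no_grad()
def sample_grid(G, fixed_noise):
    # Generate samples from the same fixed noise so grids across
    # iterations show the generator's progress on identical inputs.
    G.eval()
    samples = G(fixed_noise)   # in [-1, 1] for a tanh output layer
    G.train()
    return (samples + 1) / 2   # rescale to [0, 1] for saving

# Hypothetical usage inside the training loop:
# fixed_noise = torch.randn(16, noise_size, 1, 1)
# if iteration % 200 == 0:
#     grid = sample_grid(G, fixed_noise)  # then tile and save as figs/DCGAN_<iter>.png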
Next, we turn to CycleGAN for style transfer. First, here are our early results without cycle consistency loss.
images = {
'XY iter 400': media.read_image('figs/CG_NO_XY_400.png'),
'YX iter 400': media.read_image('figs/CG_NO_YX_400.png'),
'XY iter 600': media.read_image('figs/CG_NO_XY_600.png'),
'YX iter 600': media.read_image('figs/CG_NO_YX_600.png'),
}
media.show_images(images, border=True, height=600)
Here are our results with cycle consistency loss.
images = {
'XY iter 400': media.read_image('figs/CG_YES_XY_400.png'),
'YX iter 400': media.read_image('figs/CG_YES_YX_400.png'),
'XY iter 600': media.read_image('figs/CG_YES_XY_600.png'),
'YX iter 600': media.read_image('figs/CG_YES_YX_600.png'),
}
media.show_images(images, border=True, height=600)
After checking that these shorter runs work, we move on to training for longer: 10K iterations. We show the final results at iteration 10K both with and without cycle consistency loss. First, here are the results that do not use cycle consistency loss.
images = {
'XY': media.read_image('figs/CG_NO_XY_10000.png'),
'YX': media.read_image('figs/CG_NO_YX_10000.png'),
}
media.show_images(images, border=True, height=600)
Here are our final results with cycle consistency loss.
images = {
'XY': media.read_image('figs/CG_YES_XY_10000.png'),
'YX': media.read_image('figs/CG_YES_YX_10000.png'),
}
media.show_images(images, border=True, height=600)
As can be seen above, the cycle consistency loss contributes substantially to the final image quality. In particular, the Y->X style transfer improves drastically thanks to it. The reason is presumably that the cycle consistency loss forces generated images to stay within the two domains involved in training: every translated image must map back to its original, so the generator cannot drift toward some random unseen style that is irrelevant to our task. Without it, the generator may accidentally learn such an off-domain mapping. A sketch of this loss term follows below.
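Concretely, the cycle consistency term penalizes the round trips X -> Y -> X and Y -> X -> Y. A minimal sketch, where the generator names G_XtoY/G_YtoX, the L1 penalty, and the weight lambda_cycle=10 are assumptions rather than our exact implementation:
import torch.nn.functional as F

def cycle_consistency_loss(G_XtoY, G_YtoX, real_X, real_Y, lambda_cycle=10.0):
    # Round-trip reconstructions should match the originals, which keeps
    # translations anchored to the two training domains.
    loss_X = F.l1_loss(G_YtoX(G_XtoY(real_X)), real_X)  # X -> Y -> X
    loss_Y = F.l1_loss(G_XtoY(G_YtoX(real_Y)), real_Y)  # Y -> X -> Y
    return lambda_cycle * (loss_X + loss_Y)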