In this assignment, we dived into GAN photo editing, an image-modification task that uses a generative adversarial network to produce a desired image under suitable constraints. Specifically, we explored subtopics such as generator inversion, GAN image interpolation, and generating photorealistic images from scribbles.
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt
Unless stated otherwise, all experiments in this assignment use the following settings, with all other hyperparameters left at their starter-file and PyTorch defaults.
In this section, we implement a generator inversion optimization process to retrieve latent variable $z$ from an input image $x$. Specifically, our optimization target is
$$
z^{\ast} = \arg \min_z \left\{(1 - \lambda) \cdot \mathcal{L}_2(G(z), x) + \lambda \cdot \left(\sum_i \mathcal{L}^{(i)}_{content}(G(z), x) + \sum_j \mathcal{L}^{(j)}_{style}(G(z), x)\right)\right\},
$$
where $\mathcal{L}_2$ is the L2 loss, and $\mathcal{L}_{content}$ and $\mathcal{L}_{style}$ are the content and style loss components. The network provides multiple content/style loss units, so we sum them up to form the perceptual loss. Finally, we combine the L2 loss and the perceptual loss via the coefficient $\lambda$, and we optimize our generator inversion problem over this target function.
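To make this concrete, below is a minimal sketch of the inversion loop. The names are illustrative, not the starter code's API: `generator` is assumed to be a frozen network mapping a latent to an image, `perceptual_loss` a helper returning the summed content and style terms, `x` the target image tensor, and MSE stands in for the L2 term.

import torch
import torch.nn.functional as F

def invert(generator, perceptual_loss, x, lam=0.5, latent_dim=128, num_steps=10000):
    """Recover a latent z such that generator(z) approximates x."""
    z = torch.randn(1, latent_dim, requires_grad=True)  # optimize z; G's weights stay frozen
    optimizer = torch.optim.Adam([z], lr=1e-2)
    for _ in range(num_steps):
        optimizer.zero_grad()
        recon = generator(z)
        # (1 - lambda) * L2 term + lambda * perceptual term, as in the objective above
        loss = (1 - lam) * F.mse_loss(recon, x) + lam * perceptual_loss(recon, x)
        loss.backward()
        optimizer.step()
    return z.detach()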
We can also switch from optimizing the latent variable $z$ directly to optimizing higher-dimensional intermediate embeddings such as $w$ and $w+$.
We show the effect of the perceptual loss by comparing reconstructed images generated with different $\lambda$ values while keeping all other settings identical. Specifically, we recover $z$ using the vanilla generator with $\lambda$ set to $0, 0.25, 0.5, 0.75, 1$, covering L2 loss only, mixtures of L2 and perceptual losses, and perceptual loss only, respectively. Visualizations of the reconstructed images are as follows.
images = {
'original input': media.read_image('output/project/0_data.png'),
'lambda = 0': media.read_image('output/project/0_vanilla_z_0_10000.png'),
'lambda = 0.25': media.read_image('output/project/0_vanilla_z_0.25_10000.png'),
'lambda = 0.5': media.read_image('output/project/0_vanilla_z_0.5_10000.png'),
'lambda = 0.75': media.read_image('output/project/0_vanilla_z_0.75_10000.png'),
'lambda = 1': media.read_image('output/project/0_vanilla_z_1_10000.png'),
}
media.show_images(images, columns=6, border=True)
Comment on the previous results: $\lambda = 0.5$ appears to give the best result among all settings; in my opinion, it strikes a proper balance between reconstruction of fine details and overall style similarity.
We further test the effectiveness of swapping the encoding from the latent $z$ to $w$ and $w+$. To show this, we employ the provided StyleGAN generator, which has a proper $w$/$w+$ mapping network. We reconstruct images with a fixed $\lambda = 0.5$ using $z$, $w$, and $w+$ respectively.
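As a rough sketch of how the three parameterizations differ, the snippet below initializes each latent to be optimized, assuming the generator exposes a `mapping` network and an `n_latent` layer count (both attribute names are assumptions, not necessarily the starter code's actual API). The reconstruction results follow.

import torch

def init_latent(generator, space='w', latent_dim=512):
    z = torch.randn(1, latent_dim)
    if space == 'z':
        latent = z                     # optimize in the input latent space
    elif space == 'w':
        latent = generator.mapping(z)  # one intermediate code shared by all layers
    else:  # 'w+'
        # an independent copy of w for every synthesis layer
        latent = generator.mapping(z).unsqueeze(1).repeat(1, generator.n_latent, 1)
    return latent.detach().requires_grad_(True)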
images = {
'original input': media.read_image('output/project/0_data.png'),
'z': media.read_image('output/project/0_stylegan_z_0.5_10000.png'),
'w': media.read_image('output/project/0_stylegan_w_0.5_10000.png'),
'w+': media.read_image('output/project/0_stylegan_w+_0.5_10000.png'),
}
media.show_images(images, columns=4, border=True)
Comment on the previous results: First, even when only using the latent $z$, the result is better than its vanilla counterpart, which shows that the provided StyleGAN generator is a stronger model than the vanilla one. As for reconstruction quality across embeddings, this time the $w$ embedding seems to give the best result in terms of color, style, and overall similarity to the original input. For this reason, I will use the StyleGAN generator with the $w$ space and $\lambda = 0.5$ in the following sections. As for optimization speed, the StyleGAN generator takes significantly longer to optimize than the vanilla model, while the choice of embedding does not affect the speed at all.
In this section, we blend two cat images together by reweighting their corresponding latent variables $z_1$ and $z_2$. Specifically, for gif generation, our $\theta$ ranges from $0$ to $1$ inclusive with a step size of $0.02$ (i.e. torch.arange(0, 1.02, 0.02)), interpolating via $z = (1 - \theta) z_1 + \theta z_2$. This gives a smooth transition from using $0\%$ of $z_2$ to using $100\%$ of $z_2$. We also tried this process under various settings (i.e. different models and embeddings).
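A minimal sketch of the interpolation itself, assuming two latents `z1` and `z2` already recovered by inversion and the same hypothetical `generator` as above:

import torch

def interpolate(generator, z1, z2):
    frames = []
    # theta sweeps 0 -> 1 inclusive in steps of 0.02 (51 frames)
    for theta in torch.arange(0, 1.02, 0.02):
        z = (1 - theta) * z1 + theta * z2  # convex combination of the two latents
        with torch.no_grad():
            frames.append(generator(z))
    return frames  # e.g. written out as a gif for the visualizations below

Reconstruction results as well as gif interpolation results are as follows.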
images = {
'original input': media.read_image('output/interpolate/0.png'),
'vanilla z': media.read_image('output/interpolate/0_vanilla_z.png'),
'StyleGAN z': media.read_image('output/interpolate/0_stylegan_z.png'),
'StyleGAN w': media.read_image('output/interpolate/0_stylegan_w.png'),
'StyleGAN w+': media.read_image('output/interpolate/0_stylegan_w+.png'),
}
media.show_images(images, columns=5, border=True)
images = {
'original input': media.read_image('output/interpolate/1.png'),
'vanilla z': media.read_image('output/interpolate/1_vanilla_z.png'),
'StyleGAN z': media.read_image('output/interpolate/1_stylegan_z.png'),
'StyleGAN w': media.read_image('output/interpolate/1_stylegan_w.png'),
'StyleGAN w+': media.read_image('output/interpolate/1_stylegan_w+.png'),
}
media.show_images(images, columns=5, border=True)
images = {
'original input': media.read_image('output/interpolate/2.png'),
'vanilla z': media.read_image('output/interpolate/2_vanilla_z.png'),
'StyleGAN z': media.read_image('output/interpolate/2_stylegan_z.png'),
'StyleGAN w': media.read_image('output/interpolate/2_stylegan_w.png'),
'StyleGAN w+': media.read_image('output/interpolate/2_stylegan_w+.png'),
}
media.show_images(images, columns=5, border=True)
images = {
'original input': media.read_image('output/interpolate/3.png'),
'vanilla z': media.read_image('output/interpolate/3_vanilla_z.png'),
'StyleGAN z': media.read_image('output/interpolate/3_stylegan_z.png'),
'StyleGAN w': media.read_image('output/interpolate/3_stylegan_w.png'),
'StyleGAN w+': media.read_image('output/interpolate/3_stylegan_w+.png'),
}
media.show_images(images, columns=5, border=True)
videos = {
'vanilla z 0->1': media.read_video('output/interpolate/1_vanilla_z.gif'),
'StyleGAN z 0->1': media.read_video('output/interpolate/1_stylegan_z.gif'),
'StyleGAN w 0->1': media.read_video('output/interpolate/1_stylegan_w.gif'),
'StyleGAN w+ 0->1': media.read_video('output/interpolate/1_stylegan_w+.gif'),
}
media.show_videos(videos, columns=5, border=True, codec='gif')
videos = {
'vanilla z 2->3': media.read_video('output/interpolate/3_vanilla_z.gif'),
'StyleGAN z 2->3': media.read_video('output/interpolate/3_stylegan_z.gif'),
'StyleGAN w 2->3': media.read_video('output/interpolate/3_stylegan_w.gif'),
'StyleGAN w+ 2->3': media.read_video('output/interpolate/3_stylegan_w+.gif'),
}
media.show_videos(videos, columns=5, border=True, codec='gif')
Comment on the previous results: The interpolation quality is largely in accordance with the image reconstruction quality, so it is no surprise that StyleGAN with the $w$ or $w+$ space yields the best results across all experiments. As for interpolation consistency, all settings are fairly good overall, except for StyleGAN with $z$ on 2->3, which may be due to the poor reconstruction of the input image.
In this section, we try to convert a hand-drawn cat into a photorealistic cat image produced by a GAN generator. Specifically, we use color scribble constraints and try different settings to generate from the same drawing.
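A minimal sketch of the color scribble constraint, assuming a binary `mask` tensor that is 1 on scribbled pixels and 0 elsewhere (all names are illustrative, not the starter code's):

import torch

def scribble_loss(recon, scribble, mask):
    # penalize color mismatch only where the user actually drew;
    # unconstrained pixels are left entirely to the generator's prior
    masked_sq_err = ((recon - scribble) ** 2) * mask
    return masked_sq_err.sum() / mask.sum().clamp(min=1)

Results are shown below.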
images = {
'original': media.read_image('output/draw/0_data.png'),
'mask': media.read_image('output/draw/0_mask.png'),
'vanilla z': media.read_image('output/draw/0_vanilla_z_0_10000.png'),
'StyleGAN w': media.read_image('output/draw/0_stylegan_w_0_1500.png'),
'StyleGAN w+': media.read_image('output/draw/0_stylegan_w+_0_5000.png'),
}
media.show_images(images, columns=5, border=True)
images = {
'original': media.read_image('output/draw/1_data.png'),
'mask': media.read_image('output/draw/1_mask.png'),
'vanilla z': media.read_image('output/draw/1_vanilla_z_0_10000.png'),
'StyleGAN w': media.read_image('output/draw/1_stylegan_w_0_3000.png'),
'StyleGAN w+': media.read_image('output/draw/1_stylegan_w+_0_5000.png'),
}
media.show_images(images, columns=5, border=True)
images = {
'original': media.read_image('output/draw/2_data.png'),
'mask': media.read_image('output/draw/2_mask.png'),
'vanilla z': media.read_image('output/draw/2_vanilla_z_0_10000.png'),
'StyleGAN w': media.read_image('output/draw/2_stylegan_w_0_2750.png'),
'StyleGAN w+': media.read_image('output/draw/2_stylegan_w+_0_5000.png'),
}
media.show_images(images, columns=5, border=True)
images = {
'original': media.read_image('output/draw/3_data.png'),
'mask': media.read_image('output/draw/3_mask.png'),
'vanilla z': media.read_image('output/draw/3_vanilla_z_0_10000.png'),
'StyleGAN w': media.read_image('output/draw/3_stylegan_w_0_5000.png'),
'StyleGAN w+': media.read_image('output/draw/3_stylegan_w+_0_4750.png'),
}
media.show_images(images, columns=5, border=True)
images = {
'original': media.read_image('output/draw/4_data.png'),
'mask': media.read_image('output/draw/4_mask.png'),
'vanilla z': media.read_image('output/draw/4_vanilla_z_0_10000.png'),
'StyleGAN w': media.read_image('output/draw/4_stylegan_w_0_3500.png'),
'StyleGAN w+': media.read_image('output/draw/4_stylegan_w+_0_2500.png'),
}
media.show_images(images, columns=5, border=True)
images = {
'original': media.read_image('output/draw/5_data.png'),
'mask': media.read_image('output/draw/5_mask.png'),
'vanilla z': media.read_image('output/draw/5_vanilla_z_0_10000.png'),
'StyleGAN w': media.read_image('output/draw/5_stylegan_w_0_5000.png'),
'StyleGAN w+': media.read_image('output/draw/5_stylegan_w+_0_5000.png'),
}
media.show_images(images, columns=5, border=True)
Comment on the previous results: The overall quality is not great but discernible. I notice that the generator sometimes bounces between an image that better matches the scribble colors and another that better resembles the scribble's structure; this appears to be a trade-off. A possible explanation is that the single constraint we impose cannot discriminate between the two, so the model struggles to decide which way to go. We can address this by adding further constraints.
We inject a style loss, the same $\mathcal{L}_{style}$ described in Part 1, to achieve better style transfer between the hand-drawn images and the photorealistic reconstructions.
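A sketch of this texture constraint via Gram matrices, computed over hypothetical feature maps `feat_g` and `feat_x` taken from matching layers of a fixed feature extractor (e.g. VGG):

import torch
import torch.nn.functional as F

def gram(feat):
    # feat: (N, C, H, W) -> normalized channel-by-channel correlations
    n, c, h, w = feat.shape
    f = feat.view(n, c, h * w)
    return f @ f.transpose(1, 2) / (c * h * w)

def style_loss(feat_g, feat_x):
    # match second-order feature statistics rather than exact pixel values
    return F.mse_loss(gram(feat_g), gram(feat_x))

Here are the results when using a mixture of the color scribble constraint and the texture constraint.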
images = {
'original': media.read_image('output/draw/0_data.png'),
'mask': media.read_image('output/draw/0_mask.png'),
'lambda = 0.25': media.read_image('output/draw/0_vanilla_z_0.25_10000.png'),
'lambda = 0.5': media.read_image('output/draw/0_vanilla_z_0.5_10000.png'),
'lambda = 0.75': media.read_image('output/draw/0_vanilla_z_0.75_10000.png'),
}
media.show_images(images, columns=5, border=True)
images = {
'original': media.read_image('output/draw/1_data.png'),
'mask': media.read_image('output/draw/1_mask.png'),
'lambda = 0.25': media.read_image('output/draw/1_vanilla_z_0.25_10000.png'),
'lambda = 0.5': media.read_image('output/draw/1_vanilla_z_0.5_10000.png'),
'lambda = 0.75': media.read_image('output/draw/1_vanilla_z_0.75_10000.png'),
}
media.show_images(images, columns=5, border=True)
images = {
'original': media.read_image('output/draw/2_data.png'),
'mask': media.read_image('output/draw/2_mask.png'),
'lambda = 0.25': media.read_image('output/draw/2_vanilla_z_0.25_10000.png'),
'lambda = 0.5': media.read_image('output/draw/2_vanilla_z_0.5_10000.png'),
'lambda = 0.75': media.read_image('output/draw/2_vanilla_z_0.75_10000.png'),
}
media.show_images(images, columns=5, border=True)
images = {
'original': media.read_image('output/draw/3_data.png'),
'mask': media.read_image('output/draw/3_mask.png'),
'lambda = 0.25': media.read_image('output/draw/3_vanilla_z_0.25_10000.png'),
'lambda = 0.5': media.read_image('output/draw/3_vanilla_z_0.5_10000.png'),
'lambda = 0.75': media.read_image('output/draw/3_vanilla_z_0.75_10000.png'),
}
media.show_images(images, columns=5, border=True)
images = {
'original': media.read_image('output/draw/4_data.png'),
'mask': media.read_image('output/draw/4_mask.png'),
'lambda = 0.25': media.read_image('output/draw/4_vanilla_z_0.25_10000.png'),
'lambda = 0.5': media.read_image('output/draw/4_vanilla_z_0.5_10000.png'),
'lambda = 0.75': media.read_image('output/draw/4_vanilla_z_0.75_10000.png'),
}
media.show_images(images, columns=5, border=True)
images = {
'original': media.read_image('output/draw/5_data.png'),
'mask': media.read_image('output/draw/5_mask.png'),
'lambda = 0.25': media.read_image('output/draw/5_vanilla_z_0.25_10000.png'),
'lambda = 0.5': media.read_image('output/draw/5_vanilla_z_0.5_10000.png'),
'lambda = 0.75': media.read_image('output/draw/5_vanilla_z_0.75_10000.png'),
}
media.show_images(images, columns=5, border=True)
We upgrade the whole pipeline to operate on high-resolution cat images and redo some tasks from Parts 1 through 3. Some "cherry-picked" results are shown below.
images = {
'src input': media.read_image('output/interpolate_high_res/12.png'),
'src reconst': media.read_image('output/interpolate_high_res/12_stylegan_w.png'),
'dst input': media.read_image('output/interpolate_high_res/13.png'),
'dst reconst': media.read_image('output/interpolate_high_res/13_stylegan_w.png'),
}
media.show_images(images, columns=5, border=True)
images = {
'src input': media.read_image('output/interpolate_high_res/16.png'),
'src reconst': media.read_image('output/interpolate_high_res/16_stylegan_w.png'),
'dst input': media.read_image('output/interpolate_high_res/17.png'),
'dst reconst': media.read_image('output/interpolate_high_res/17_stylegan_w.png'),
}
media.show_images(images, columns=5, border=True)
images = {
'src input': media.read_image('output/interpolate_high_res/18.png'),
'src reconst': media.read_image('output/interpolate_high_res/18_stylegan_w.png'),
'dst input': media.read_image('output/interpolate_high_res/19.png'),
'dst reconst': media.read_image('output/interpolate_high_res/19_stylegan_w.png'),
}
media.show_images(images, columns=5, border=True)
images = {
'src input': media.read_image('output/interpolate_high_res/20.png'),
'src reconst': media.read_image('output/interpolate_high_res/20_stylegan_w.png'),
'dst input': media.read_image('output/interpolate_high_res/21.png'),
'dst reconst': media.read_image('output/interpolate_high_res/21_stylegan_w.png'),
}
media.show_images(images, columns=5, border=True)
videos = {
'12->13': media.read_video('output/interpolate_high_res/13_stylegan_w.gif'),
'16->17': media.read_video('output/interpolate_high_res/17_stylegan_w.gif'),
'18->19': media.read_video('output/interpolate_high_res/19_stylegan_w.gif'),
'20->21': media.read_video('output/interpolate_high_res/21_stylegan_w.gif'),
}
media.show_videos(videos, columns=5, border=True, codec='gif')
images = {
'original': media.read_image('output/draw_high_res/0_data.png'),
'mask': media.read_image('output/draw_high_res/0_mask.png'),
'StyleGAN z': media.read_image('output/draw_high_res/0_stylegan_z_0_4500.png'),
}
media.show_images(images, columns=3, border=True)
images = {
'original': media.read_image('output/draw_high_res/1_data.png'),
'mask': media.read_image('output/draw_high_res/1_mask.png'),
'StyleGAN z': media.read_image('output/draw_high_res/1_stylegan_z_0_4500.png'),
}
media.show_images(images, columns=3, border=True)
images = {
'original': media.read_image('output/draw_high_res/2_data.png'),
'mask': media.read_image('output/draw_high_res/2_mask.png'),
'StyleGAN z': media.read_image('output/draw_high_res/2_stylegan_z_0_4500.png'),
}
media.show_images(images, columns=3, border=True)
images = {
'original': media.read_image('output/draw_high_res/3_data.png'),
'mask': media.read_image('output/draw_high_res/3_mask.png'),
'StyleGAN z': media.read_image('output/draw_high_res/3_stylegan_z_0_5000.png'),
}
media.show_images(images, columns=3, border=True)
images = {
'original': media.read_image('output/draw_high_res/4_data.png'),
'mask': media.read_image('output/draw_high_res/4_mask.png'),
'StyleGAN z': media.read_image('output/draw_high_res/4_stylegan_z_0_5000.png'),
}
media.show_images(images, columns=3, border=True)
images = {
'original': media.read_image('output/draw_high_res/5_data.png'),
'mask': media.read_image('output/draw_high_res/5_mask.png'),
'StyleGAN z': media.read_image('output/draw_high_res/5_stylegan_z_0_5000.png'),
}
media.show_images(images, columns=3, border=True)