%pip install click requests tqdm pyspng ninja imageio-ffmpeg==0.4.3
# --------------------------------------------------------
# Written by Yufei Ye (https://github.com/JudyYe)
# --------------------------------------------------------
from __future__ import print_function
import argparse
import os
import os.path as osp
import numpy as np
import types
from LBFGS import FullBatchLBFGS
import torch
import torch.nn as nn
import torch.nn.functional as F
import imageio
import torchvision.utils as vutils
from torchvision.models import vgg19
from dataloader import get_data_loader
from matplotlib import pyplot as plt
from PIL import Image
device = 'cuda'
def build_model(name):
    if name.startswith('vanilla'):
        z_dim = 100
        model_path = 'data_weight/pretrained/%s.ckpt' % name
        pretrain = torch.load(model_path)
        from vanilla.models import DCGenerator
        model = DCGenerator(z_dim, 32, 'instance')
        model.load_state_dict(pretrain)
    elif name == 'stylegan':
        model_path = 'data_weight/pretrained/%s.ckpt' % name
        import sys
        sys.path.insert(0, 'stylegan')
        from stylegan import dnnlib, legacy
        with dnnlib.util.open_url(model_path) as f:
            model = legacy.load_network_pkl(f)['G_ema']
        z_dim = model.z_dim
    else:
        raise NotImplementedError('model [%s] is not implemented' % name)
    if torch.cuda.is_available():
        model = model.cuda()
    model.eval()
    return model, z_dim
class Wrapper(nn.Module):
    """The wrapper abstracts over stylegan / vanilla GAN and the z / w / w+ latents."""
    def __init__(self, args, model, z_dim):
        super().__init__()
        self.model, self.z_dim = model, z_dim
        self.latent = args.latent
        self.is_style = args.model == 'stylegan'

    def forward(self, param):
        if self.latent == 'z':
            if self.is_style:
                image = self.model(param, None)
            else:
                image = self.model(param)
        else:
            # w / w+: only supported by stylegan
            assert self.is_style
            if self.latent == 'w':
                # broadcast the single w vector to every synthesis layer
                param = param.repeat(1, self.model.mapping.num_ws, 1)
            image = self.model.synthesis(param)
        return image
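For reference, a minimal usage sketch of the wrapper (a sketch, assuming the stylegan checkpoint loaded by build_model is available and that w_dim == z_dim, as in the stylegan2-ada checkpoints):

model, z_dim = build_model('stylegan')
opts = types.SimpleNamespace(model='stylegan', latent='w+')
wrapper = Wrapper(opts, model, z_dim)
# one style vector per synthesis layer; z would be (N, z_dim), w would be (N, 1, z_dim)
wp = torch.randn(2, model.mapping.num_ws, z_dim, device=device)
print(wrapper(wp).shape)  # (2, 3, H, W) at the generator's native resolution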
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

# create a module to normalize the input image so we can easily put it in an
# nn.Sequential
class Normalization(nn.Module):
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        # .view the mean and std as [C x 1 x 1] so that they broadcast against
        # an image tensor of shape [B x C x H x W].
        # B is batch size, C is number of channels, H is height and W is width.
        # clone().detach() avoids the "copy construct from a tensor" UserWarning.
        self.mean = mean.clone().detach().view(-1, 1, 1)
        self.std = std.clone().detach().view(-1, 1, 1)

    def forward(self, img):
        # normalize img
        return (img - self.mean) / self.std
class PerceptualLoss(nn.Module):
    def __init__(self, layer=['conv_1']):
        super(PerceptualLoss, self).__init__()
        cnn = vgg19(pretrained=True).features.to(device).eval()
        # normalization module
        normalization = Normalization(cnn_normalization_mean, cnn_normalization_std).to(device)
        # assuming cnn is an nn.Sequential, build a new nn.Sequential containing
        # the modules up to (and including) the requested layer
        model = nn.Sequential(normalization)
        i = 0  # increment every time we see a conv
        for iterlayer in cnn.children():
            if isinstance(iterlayer, nn.Conv2d):
                i += 1
                name = 'conv_{}'.format(i)
            elif isinstance(iterlayer, nn.ReLU):
                name = 'relu_{}'.format(i)
                # The in-place version doesn't play nicely with losses computed on
                # intermediate activations, so replace it with an out-of-place one.
                iterlayer = nn.ReLU(inplace=False)
            elif isinstance(iterlayer, nn.MaxPool2d):
                name = 'pool_{}'.format(i)
            elif isinstance(iterlayer, nn.BatchNorm2d):
                name = 'bn_{}'.format(i)
            else:
                raise RuntimeError('Unrecognized layer: {}'.format(iterlayer.__class__.__name__))
            model.add_module(name, iterlayer)
            if name in layer:
                # stop once the last requested feature layer has been added
                break
        self.model = model

    def forward(self, pred, target, mask=None):
        if mask is None:
            return F.mse_loss(self.model(pred), self.model(target).detach())
        else:
            return F.mse_loss(self.model(pred) * mask, (self.model(target) * mask).detach())
class Criterion(nn.Module):
    def __init__(self, args, mask=False, layer=['conv_2']):
        super().__init__()
        self.perc_wgt = args.perc_wgt
        self.mask = mask
        self.perc = PerceptualLoss(layer)

    def forward(self, pred, target):
        """Calculate the loss between prediction and target in pixel (MSE) and perceptual space."""
        if self.mask:
            # masked loss: only compare the regions where the mask is on
            target, mask = target
            mse_loss = F.mse_loss(pred * (mask.detach() > 0.5), (target * (mask > 0.5)).detach())
            perc_loss = self.perc(pred, target, (mask.detach() > 0.5))
        else:
            mse_loss = F.mse_loss(pred, target)
            perc_loss = self.perc(pred, target)
        loss = mse_loss + self.perc_wgt * perc_loss
        return loss, mse_loss, perc_loss
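A quick sanity check of the criterion on dummy tensors (a sketch; the _args namespace is hypothetical and only supplies perc_wgt, and the 64x64 shape is just an example):

_args = types.SimpleNamespace(perc_wgt=0.05)  # hypothetical: only perc_wgt is read
crit = Criterion(_args)  # unmasked: loss = mse + perc_wgt * perceptual
pred = torch.randn(1, 3, 64, 64, device=device)
target = torch.randn(1, 3, 64, 64, device=device)
loss, mse, perc = crit(pred, target)
print(loss.item(), mse.item(), perc.item())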
def save_images(image, fname, col=8):
    image = image.cpu().detach()
    image = image / 2 + 0.5  # [-1, 1] -> [0, 1]
    image = vutils.make_grid(image, nrow=col)  # (C, H, W)
    image = image.numpy().transpose([1, 2, 0])
    image = np.clip(255 * image, 0, 255).astype(np.uint8)
    if fname is not None:
        os.makedirs(os.path.dirname(fname), exist_ok=True)
        imageio.imwrite(fname + '.png', image)
    return image
def save_gifs(image_list, fname, col=1):
    """
    :param image_list: [(N, C, H, W), ] in scale [-1, 1]
    """
    image_list = [save_images(each, None, col) for each in image_list]
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    imageio.mimsave(fname + '.gif', image_list)
def sample_noise(dim, device, latent, model, N=1, from_mean=False):
    samplesInMean = 10
    z = torch.randn(N, dim, device=device)
    if latent == 'z':
        return z
    # w / w+
    if from_mean:
        # map a bunch of z through the mapping network and take the mean of w / w+
        # (see stylegan/generate_gif.py L70:81 for how to call stylegan2)
        waccum = model.mapping(torch.randn(N, dim, device=device), None)
        if latent == 'w':
            waccum = torch.mean(waccum, 1, keepdim=True)
        for _ in range(samplesInMean - 1):
            wsample = model.mapping(torch.randn(N, dim, device=device), None)
            if latent == 'w':
                wsample = torch.mean(wsample, 1, keepdim=True)
            waccum = waccum + wsample  # accumulate the running sum
        w = waccum / samplesInMean
    else:
        # take a single random z and map it to w / w+
        w = model.mapping(z, None)
        if latent == 'w':
            w = torch.mean(w, 1, keepdim=True)
    return w
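The three latent modes yield different shapes; a sketch, assuming model and z_dim come from build_model('stylegan') and that w_dim == z_dim:

z  = sample_noise(z_dim, device, 'z', model)   # (N, z_dim): fed straight to the generator
w  = sample_noise(z_dim, device, 'w', model)   # (N, 1, w_dim): Wrapper tiles it across num_ws layers
wp = sample_noise(z_dim, device, 'w+', model)  # (N, num_ws, w_dim): one style per synthesis layer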
def optimize_para(wrapper, param, target, criterion, num_step, save_prefix=None, res=False):
    """
    wrapper: image = wrapper(z / w / w+): an interface for a generator forward pass.
    param: z / w / w+
    target: (1, C, H, W)
    criterion: loss(pred, target)
    """
    param = param.requires_grad_().to(device)
    optimizer = FullBatchLBFGS([param], lr=.1, line_search='Wolfe')
    iter_count = [0]

    def closure():
        image = wrapper(param)
        if args.mode == 'draw':  # note: reads the global args
            # CriterionDraw also needs the latent to regularize it toward its init
            loss, mseLoss, percepMSELoss = criterion(image, target, param)
        else:
            loss, mseLoss, percepMSELoss = criterion(image, target)
        iter_count[0] += 1
        if iter_count[0] % 500 == 0 and save_prefix is not None:
            # visualization code
            print('iter count {} loss {:4f} {:4f} {:4f}'.format(
                iter_count, loss.item(), mseLoss.item(), percepMSELoss.item()))
            iter_result = image.data.clamp_(-1, 1)
            save_images(iter_result, save_prefix + '_%d' % iter_count[0])
        return loss

    # FullBatchLBFGS expects an initial loss/gradient evaluation before stepping
    loss = closure()
    loss.backward()
    while iter_count[0] <= num_step:
        options = {'closure': closure, 'max_ls': 10}
        loss, _, lr, _, F_eval, G_eval, _, _ = optimizer.step(options)
    image = wrapper(param)
    return param, image
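FullBatchLBFGS comes from a third-party LBFGS.py. For reference only, a rough sketch of the same closure pattern with the built-in torch.optim.LBFGS (a hypothetical alternative, not what was run here; it assumes the unmasked criterion signature):

def optimize_para_builtin(wrapper, param, target, criterion, num_step):
    # sketch: built-in L-BFGS with a strong-Wolfe line search
    param = param.requires_grad_().to(device)
    opt = torch.optim.LBFGS([param], lr=0.1, max_iter=20, line_search_fn='strong_wolfe')
    def closure():
        opt.zero_grad()
        loss, _, _ = criterion(wrapper(param), target)
        loss.backward()
        return loss
    for _ in range(max(1, num_step // 20)):  # each step() runs up to max_iter inner iterations
        opt.step(closure)
    return param, wrapper(param)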
def sample(args):
    model, z_dim = build_model(args.model)
    wrapper = Wrapper(args, model, z_dim)
    batch_size = 16
    noise = sample_noise(z_dim, device, args.latent, model, batch_size)
    image = wrapper(noise)
    fname = 'output/forward/%s_%s' % (args.model, args.mode)
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    save_images(image, fname)
def project(args):
    # load images
    loader = get_data_loader(args.input, is_train=False)
    # define and load the pre-trained model
    model, z_dim = build_model(args.model)
    wrapper = Wrapper(args, model, z_dim)
    print('model {} loaded'.format(args.model))
    criterion = Criterion(args)
    # project each image (only the first one, for the write-up)
    for idx, (data, _) in enumerate(loader):
        target = data.to(device)
        save_images(data, 'output/project/%d_data' % idx, 1)
        param = sample_noise(z_dim, device, args.latent, model)
        optimize_para(wrapper, param, target, criterion, args.n_iters,
                      'output/project/%d_%s_%s_%g' % (idx, args.model, args.latent, args.perc_wgt))
        if idx >= 0:
            break
args = types.SimpleNamespace()
args.model = 'vanilla'   # choices: ['vanilla', 'stylegan']
args.latent = 'z'        # choices: ['z', 'w', 'w+']
args.n_iters = 1000      # number of optimization steps in the image projection
args.perc_wgt = 0.05     # perc loss lambda
args.input = 'data_weight/data/cat/*.png'  # path to the input image
args.mode = 'sample'
sample(args)
idx = 0
iter = 500
path = 'output/project/%d_%s_%s_%g_%g.png' % (idx, args.model, args.latent, args.perc_wgt, iter)
print(path)
image = Image.open(path)
plt.imshow(image)
output/project/0_vanilla_z_0.05_500.png
args = types.SimpleNamespace()
args.model = 'stylegan'  # choices: ['vanilla', 'stylegan']
args.latent = 'w+'       # choices: ['z', 'w', 'w+']
args.n_iters = 1000      # number of optimization steps in the image projection
args.perc_wgt = 0.05     # perc loss lambda
args.input = 'data_weight/data/cat/*.png'  # path to the input image
args.mode = 'sample'
sample(args)
Setting up PyTorch plugin "bias_act_plugin"... Done. Setting up PyTorch plugin "upfirdn2d_plugin"... Done.
args = types.SimpleNamespace()
args.model = 'vanilla'   # choices: ['vanilla', 'stylegan']
args.latent = 'z'        # choices: ['z', 'w', 'w+']
args.n_iters = 1000      # number of optimization steps in the image projection
args.perc_wgt = 0.05     # perc loss lambda
args.input = 'data_weight/data/cat/*.png'  # path to the input image
args.mode = 'project'
idx = 0
iter = 500
path = 'output/project/%d_data.png' % (idx)
print(path)
image = Image.open(path)
plt.imshow(image)
plt.title('target')
project(args)
output/project/0_data.png
204
model vanilla loaded
iter count [500] loss 0.342176 0.147557 3.892370
iter count [1000] loss 0.358353 0.158697 3.993124
idx = 0
iter = 500
path = 'output/project/%d_%s_%s_%g_%g.png' % (idx, args.model, args.latent, args.perc_wgt, iter)
print(path)
image = Image.open(path)
plt.title('after 500 iterations')
plt.imshow(image)
output/project/0_vanilla_z_0.05_500.png
idx = 0
iter = 10000
path = 'output/project/%d_%s_%s_%g_%g.png' % (idx, args.model, args.latent, args.perc_wgt, iter)
print(path)
image = Image.open(path)
plt.title('after 10000 iterations')
plt.imshow(image)
output/project/0_stylegan_z_1_10000.png
---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
<ipython-input-31-f46d2f48a120> in <module>
      3 path = 'output/project/%d_%s_%s_%g_%g.png' % (idx, args.model, args.latent, args.perc_wgt, iter)
      4 print(path)
----> 5 image = Image.open(path)
      6 plt.title('after 10000 iterations')
      7 plt.imshow(image)

/opt/conda/lib/python3.7/site-packages/PIL/Image.py in open(fp, mode, formats)
   2902
   2903     if filename:
-> 2904         fp = builtins.open(filename, "rb")
   2905         exclusive_fp = True
   2906

FileNotFoundError: [Errno 2] No such file or directory: 'output/project/0_stylegan_z_1_10000.png'
Poor reconstruction
args = types.SimpleNamespace()
args.model = 'stylegan'  # choices: ['vanilla', 'stylegan']
args.latent = 'w'        # choices: ['z', 'w', 'w+']
args.n_iters = 10000     # number of optimization steps in the image projection
args.perc_wgt = 5.0      # perc loss lambda
args.input = 'data_weight/data/cat/*.png'  # path to the input image
args.mode = 'project'
idx = 0
iter = 500
path = 'output/project/%d_data.png' % (idx)
print(path)
image = Image.open(path)
plt.imshow(image)
plt.title('target')
project(args)
output/project/0_data.png 204 model stylegan loaded
iter count [500] loss 16.948500 0.134638 3.362772
iter count [1000] loss 19.693346 0.144765 3.909716
iter count [1500] loss 629.884277 7.422346 124.492378
iter count [2000] loss 19.788034 0.152908 3.927025
iter count [2500] loss 2814.095703 19.278767 558.963379
iter count [3000] loss 22.322737 0.233933 4.417761
iter count [3500] loss 2769.279297 20.673231 549.721191
iter count [4000] loss 19.603922 0.134589 3.893867
iter count [4500] loss 48.259853 0.614441 9.529082
iter count [5000] loss 19.851524 0.174300 3.935445
iter count [5500] loss 50.523563 0.616701 9.981373
iter count [6000] loss 19.351143 0.140402 3.842148
iter count [6500] loss 19.185463 0.142273 3.808638
iter count [7000] loss 19.544216 0.168309 3.875182
iter count [7500] loss 20.631704 0.202930 4.085755
iter count [8000] loss 24.303215 0.259793 4.808684
iter count [8500] loss 19.617533 0.148083 3.893890
iter count [9000] loss 6791.041504 68.574554 1344.493408
iter count [9500] loss 17.887630 0.159433 3.545639
iter count [10000] loss 1943144.125000 35434.714844 381541.875000
idx = 0
iter = 500
path = 'output/project/%d_%s_%s_%g_%g.png' % (idx, args.model, args.latent, args.perc_wgt, iter)
print(path)
image = Image.open(path)
plt.title('after 500 iterations')
plt.imshow(image)
output/project/0_stylegan_w_5_500.png
idx = 0
iter = 9000
path = 'output/project/%d_%s_%s_%g_%g.png' % (idx, args.model, args.latent, args.perc_wgt, iter)
print(path)
image = Image.open(path)
plt.title('after 9000 iterations')
plt.imshow(image)
I saw some instability with this one; possibly there are too many parameters for L-BFGS to be effective.
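One way to probe that hypothesis would be to swap in a first-order optimizer; a hedged sketch with Adam (hypothetical, not run here; same wrapper / criterion interfaces as optimize_para):

def optimize_para_adam(wrapper, param, target, criterion, num_step, lr=0.01):
    # Adam ignores curvature, so it tends to be more stable than L-BFGS on
    # noisy, high-dimensional objectives, at the cost of more iterations
    param = param.requires_grad_().to(device)
    opt = torch.optim.Adam([param], lr=lr)
    for step in range(num_step):
        opt.zero_grad()
        loss, mse_loss, perc_loss = criterion(wrapper(param), target)
        loss.backward()
        opt.step()
        if (step + 1) % 500 == 0:
            print('step {} loss {:.4f}'.format(step + 1, loss.item()))
    return param, wrapper(param)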
args = types.SimpleNamespace()
args.model = 'stylegan'  # choices: ['vanilla', 'stylegan']
args.latent = 'w+'       # choices: ['z', 'w', 'w+']
args.n_iters = 10000     # number of optimization steps in the image projection
args.perc_wgt = 0.05     # perc loss lambda
args.input = 'data_weight/data/cat/*.png'  # path to the input image
args.mode = 'project'
idx = 0
iter = 500
path = 'output/project/%d_data.png' % (idx)
print(path)
image = Image.open(path)
plt.imshow(image)
plt.title('target')
project(args)
output/project/0_data.png
204
model stylegan loaded
iter count [500] loss 0.237195 0.086875 3.006400
iter count [1000] loss 0.299407 0.119029 3.607553
iter count [1500] loss 0.236567 0.085567 3.019995
iter count [2000] loss 0.330041 0.135870 3.883427
iter count [2500] loss 0.278460 0.102508 3.519037
iter count [3000] loss 0.236473 0.086935 2.990759
iter count [3500] loss 0.357191 0.146226 4.219302
iter count [4000] loss 0.270474 0.103678 3.335905
iter count [4500] loss 0.592601 0.284957 6.152874
iter count [5000] loss 0.210453 0.074453 2.720001
iter count [5500] loss 0.593633 0.285160 6.169464
iter count [6000] loss 0.684081 0.363711 6.407407
iter count [6500] loss 0.240056 0.088397 3.033170
iter count [7000] loss 0.323874 0.132949 3.818483
iter count [7500] loss 0.190908 0.065962 2.498916
iter count [8000] loss 0.233947 0.085764 2.963649
iter count [8500] loss 0.185506 0.063369 2.442754
iter count [9000] loss 0.187505 0.063285 2.484388
iter count [9500] loss 0.183167 0.062690 2.409536
iter count [10000] loss 0.265832 0.095240 3.411850
idx = 0
iter = 500
path = 'output/project/%d_%s_%s_%g_%g.png' % (idx, args.model, args.latent, args.perc_wgt, iter)
print(path)
image = Image.open(path)
plt.title('after 500 iterations')
plt.imshow(image)
output/project/0_stylegan_w+_0.05_500.png
idx = 0
iter = 10000
path = 'output/project/%d_%s_%s_%g_%g.png' % (idx, args.model, args.latent, args.perc_wgt, iter)
print(path)
image = Image.open(path)
plt.title('after 10000 iterations')
plt.imshow(image)
output/project/0_stylegan_w+_0.05_10000.png
Looks very good, even after a small number of iterations.
def interpolate(args):
    model, z_dim = build_model(args.model)
    wrapper = Wrapper(args, model, z_dim)
    countStepsInclusive = 9
    # load the target images
    loader = get_data_loader(args.input)
    criterion = Criterion(args)
    for idx, (image, _) in enumerate(loader):
        save_images(image, 'output/interpolate/%d' % (idx))
        target = image.to(device)
        param = sample_noise(z_dim, device, args.latent, model, from_mean=True)
        initial = wrapper(param)
        save_images(initial, 'output/interpolate/%d_%s_%s_initial' % (idx, args.model, args.latent))
        param, recon = optimize_para(wrapper, param, target, criterion, args.n_iters,
                                     'output/interpolate/%d_%s_%s_%g' % (idx, args.model, args.latent, args.perc_wgt))
        save_images(recon, 'output/interpolate/%d_%s_%s' % (idx, args.model, args.latent))
        # project pairs of images, then interpolate between their latents
        if idx % 2 == 0:
            src = param
            continue
        dst = param
        image_list = []
        with torch.no_grad():
            for ite in range(countStepsInclusive):
                theta = ite / (countStepsInclusive - 1)
                print(theta)
                # linear interpolation from dst (theta=0) to src (theta=1)
                param = theta * src + (1 - theta) * dst
                image_list.append(wrapper(param))
        save_gifs(image_list, 'output/interpolate/%d_%s_%s' % (idx, args.model, args.latent))
        if idx >= 3:
            break
    return
args = types.SimpleNamespace()
args.model = 'stylegan'  # choices: ['vanilla', 'stylegan']
args.latent = 'w'        # choices: ['z', 'w', 'w+']
args.n_iters = 1000      # number of optimization steps in the image projection
args.perc_wgt = 1        # perc loss lambda
args.input = 'data_weight/data/cat/*.png'  # path to the input image
args.mode = 'interpolate'
interpolate(args)
204
iter count [500] loss 2.852360 0.090001 2.762359
iter count [1000] loss 3.007579 0.098214 2.909365
iter count [500] loss 5.317166 0.179487 5.137679
iter count [1000] loss 5.033410 0.170084 4.863326
0.0 0.125 0.25 0.375 0.5 0.625 0.75 0.875 1.0
iter count [500] loss 4.335003 0.138111 4.196892
iter count [1000] loss 3.600891 0.090524 3.510366
iter count [500] loss 3.713975 0.131058 3.582917
iter count [1000] loss 3.151961 0.101543 3.050418
0.0 0.125 0.25 0.375 0.5 0.625 0.75 0.875 1.0
Interpolations look smooth. The second gif has a blurry cat in it, but the interpolated frames themselves aren't blurry.
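Linear interpolation works well in w-space; for Gaussian z codes, spherical interpolation (slerp) is often preferred since high-dimensional Gaussian samples concentrate near a shell. A sketch (slerp is a standard technique, not part of the code above):

def slerp(p0, p1, t):
    # spherical interpolation between two latent codes of the same shape
    v0, v1 = p0.flatten(1), p1.flatten(1)
    cos = (F.normalize(v0, dim=1) * F.normalize(v1, dim=1)).sum(1, keepdim=True).clamp(-1, 1)
    omega = torch.acos(cos)  # angle between the two codes
    so = torch.sin(omega)
    out = (torch.sin((1 - t) * omega) / so) * v0 + (torch.sin(t * omega) / so) * v1
    return out.view_as(p0)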
class CriterionDraw(nn.Module):
    def __init__(self, args, initParam, mask=False, layer=['conv_2']):
        super().__init__()
        self.crit = Criterion(args, mask, layer)
        self.initParam = initParam.clone().detach()
        model, z_dim = build_model(args.model)
        wrapper = Wrapper(args, model, z_dim)
        self.initIm = wrapper(self.initParam).detach()

    def forward(self, pred, target, param):
        """Masked pixel / perceptual loss, plus an L2 penalty keeping the latent near its initialization."""
        loss, MSELoss, MSELossPerc = self.crit(pred, target)
        l2loss = F.mse_loss(param, self.initParam)
        loss = loss + 0.01 * l2loss
        return loss, MSELoss, MSELossPerc
def draw(args):
    # define and load the pre-trained model
    model, z_dim = build_model(args.model)
    wrapper = Wrapper(args, model, z_dim)
    # load the target and mask
    loader = get_data_loader(args.input, alpha=True)
    for idx, (rgb, mask) in enumerate(loader):
        rgb, mask = rgb.to(device), mask.to(device)
        save_images(rgb, 'output/draw/%d_data' % idx, 1)
        save_images(mask, 'output/draw/%d_mask' % idx, 1)
        param = sample_noise(z_dim, device, args.latent, model, from_mean=True)
        criterion = CriterionDraw(args, param, True)
        param, recon = optimize_para(wrapper, param, [rgb, mask], criterion, args.n_iters,
                                     'output/draw/%d_%s_%s_%g' % (idx, args.model, args.latent, args.perc_wgt))
        save_images(recon, 'output/draw/%d_%s_%s' % (idx, args.model, args.latent))
args = types.SimpleNamespace()
args.model = 'stylegan'  # choices: ['vanilla', 'stylegan']
args.latent = 'z'        # choices: ['z', 'w', 'w+']
args.n_iters = 10000     # number of optimization steps in the image projection
args.perc_wgt = 1        # perc loss lambda
args.input = 'Drawing.png'  # path to the input image
args.mode = 'draw'
idx = 0
iter = 500
path = 'Drawing.png'
print(path)
image = Image.open(path)
plt.imshow(image)
plt.title('target')
draw(args)
'output/draw/%d_%s_%s' % (idx, args.model, args.latent)
Drawing.png
1
iter count [500] loss 2.702252 0.162954 2.539296
iter count [1000] loss 2.673340 0.159021 2.514314
iter count [1500] loss 2.671141 0.156637 2.514497
iter count [2000] loss 2.481525 0.100375 2.380785
iter count [2500] loss 2.634187 0.157386 2.476795
iter count [3000] loss 5.713142 0.202379 5.499958
iter count [3500] loss 2.625915 0.159062 2.466849
iter count [4000] loss 2.705186 0.078826 2.621944
iter count [4500] loss 2.743344 0.162872 2.580469
iter count [5000] loss 2.120368 0.054528 2.064186
iter count [5500] loss 2.758537 0.161872 2.596662
iter count [6000] loss 2.338967 0.095161 2.243235
iter count [6500] loss 2.209269 0.062371 2.145559
iter count [7000] loss 2.330429 0.118698 2.211527
iter count [7500] loss 2.438490 0.131428 2.306953
iter count [8000] loss 2.424048 0.135036 2.288936
iter count [8500] loss 2.332158 0.124621 2.207386
iter count [9000] loss 2.513986 0.145015 2.368944
iter count [9500] loss 2.355243 0.124627 2.230463
iter count [10000] loss 2.643920 0.151577 2.492331
'output/draw/0_stylegan_z'
idx = 0
iter = 500
path = 'output/draw/%d_%s_%s_%g_%g.png' % (idx, args.model, args.latent, args.perc_wgt, iter)
print(path)
image = Image.open(path)
plt.title('after 500 iterations')
plt.imshow(image)
output/draw/0_stylegan_z_1_500.png
idx = 0
iter = 5000
path = 'output/draw/%d_%s_%s_%g_%g.png' % (idx, args.model, args.latent, args.perc_wgt, iter)
print(path)
image = Image.open(path)
plt.title('after 5000 iterations')
plt.imshow(image)
output/draw/0_stylegan_z_1_5000.png
With only z to optimize, it has a lot of trouble reproducing the color of the input, but it does seem to reproduce the zoom level.
args = types.SimpleNamespace()
args.model = 'stylegan'  # choices: ['vanilla', 'stylegan']
args.latent = 'w'        # choices: ['z', 'w', 'w+']
args.n_iters = 10000     # number of optimization steps in the image projection
args.perc_wgt = 1        # perc loss lambda
args.input = 'Drawing.png'  # path to the input image
args.mode = 'draw'
idx = 0
iter = 500
path = 'Drawing.png'
print(path)
image = Image.open(path)
plt.imshow(image)
plt.title('target')
draw(args)
'output/draw/%d_%s_%s' % (idx, args.model, args.latent)
Drawing.png
1
iter count [500] loss 1.549890 0.034893 1.509155
iter count [1000] loss 1.301463 0.008336 1.287699
iter count [1500] loss 1.267575 0.012505 1.250288
iter count [2000] loss 1.390269 0.030032 1.347930
iter count [2500] loss 1.394323 0.011730 1.380482
iter count [3000] loss 1.255619 0.005215 1.233957
iter count [3500] loss 1.413735 0.006407 1.406524
iter count [4000] loss 1.287025 0.017149 1.235233
iter count [4500] loss 2.102963 0.122126 1.980837
iter count [5000] loss 1.165328 0.004910 1.147495
iter count [5500] loss 2.102387 0.121903 1.980484
iter count [6000] loss 1.132562 0.007984 1.116429
iter count [6500] loss 1.039030 0.008503 1.014045
iter count [7000] loss 7.268490 1.068840 6.030033
iter count [7500] loss 1.020605 0.014255 0.985070
iter count [8000] loss 1.364294 0.013917 1.298249
iter count [8500] loss 1.141460 0.005933 1.128488
iter count [9000] loss 1.180856 0.014252 1.148262
iter count [9500] loss 1.016343 0.007897 0.999450
iter count [10000] loss 1.166458 0.007156 1.153436
'output/draw/0_stylegan_w'
idx = 0
iter = 2000
path = 'output/draw/%d_%s_%s_%g_%g.png' % (idx, args.model, args.latent, args.perc_wgt, iter)
print(path)
image = Image.open(path)
plt.title('after 2000 iterations')
plt.imshow(image)
output/draw/0_stylegan_w_1_2000.png
idx = 0
iter = 5500
path = 'output/draw/%d_%s_%s_%g_%g.png' % (idx, args.model, args.latent, args.perc_wgt, iter)
print(path)
image = Image.open(path)
plt.title('after 5500 iterations')
plt.imshow(image)
output/draw/0_stylegan_w_1_5500.png
This one is pretty unstable: many of the iterations generate non-cat images. Sometimes it gets back onto the cat manifold, but the zoom is way off.
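A standard way to pull wandering samples back toward the mode would be the truncation trick: blend the optimized w toward the mapping network's running average before synthesis. A sketch, assuming the stylegan2-ada generator exposes the model.mapping.w_avg buffer:

def truncate_w(w, model, psi=0.7):
    # psi < 1 shrinks w toward the average style, trading diversity for
    # samples that stay closer to the learned (cat) manifold
    w_avg = model.mapping.w_avg.view(1, 1, -1).to(w.device)
    return w_avg + psi * (w - w_avg)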