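Lab Requirements and Workflow
Lab Requirements
- Complete the previous lab and make sure you understand the principles and training method of GANs (Generative Adversarial Networks).
- Building on the lecture material, learn the basic structure and main purpose of models such as CGAN and pix2pix.
- Read the lab content in this guide, then run and complete the lab code as prompted, or briefly answer the questions. Keep the experiment results when you submit your assignment.
Lab Workflow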
- CGAN
- pix2pix
CGAN (Conditional GAN)
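As the previous lab showed, a GAN can generate near-realistic images, but a plain GAN is too unconstrained: we cannot control what it generates. CGAN (Conditional GAN) is a GAN with conditional constraints, which introduces a condition variable into the modeling of both the generator (G) and the discriminator (D). The condition can be based on many kinds of information, such as class labels or the partial data used for image inpainting. In the CGAN that follows, we use the class label as the condition variable for both G and D.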
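For reference, the conditional GAN objective is usually written as the ordinary GAN minimax game with both D and G conditioned on the label y. The train() function further down optimizes the common non-saturating variant, in which G minimizes -log D(G(z|y)|y) rather than log(1 - D(G(z|y)|y)):

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x \mid y)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z \mid y) \mid y)\big)\big]
```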
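The CGAN architecture below is similar to the DCGAN model shown in the previous lab. The main difference is that class labels are added to the inputs of both G and D. In G, the label (a one-hot vector; e.g. with 3 classes 0/1/2, the one-hot vector of class 2 is [0, 0, 1]) is concatenated with the noise z and fed into the first fully connected layer. In D, the label is concatenated with the input image and fed into the convolutional layers; each label is represented as a tensor of size (class_num, image_size, image_size) whose channel for the correct class is all 1s and whose other channels are all 0s.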
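Both label encodings can be built directly from class indices. This sketch uses torch.nn.functional.one_hot with the 3-class example from the text and a tiny 4×4 image size so the shapes are easy to read (both sizes are arbitrary choices for illustration):

```python
import torch
import torch.nn.functional as F

# Assumptions for illustration: 3 classes and 4x4 "images"
class_num, image_size = 3, 4
labels = torch.tensor([2, 0])  # a batch of two class indices

# Condition for G: one-hot vectors of shape (batch, class_num)
y_vec = F.one_hot(labels, num_classes=class_num).float()

# Condition for D: every one-hot entry broadcast over an image-sized plane,
# shape (batch, class_num, image_size, image_size); the true-class channel is all 1s
y_fill = y_vec[:, :, None, None].expand(-1, -1, image_size, image_size)

print(y_vec)         # rows: class 2 -> [0, 0, 1], class 0 -> [1, 0, 0]
print(y_fill.shape)
```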
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
%matplotlib inline
from utils import initialize_weights
class DCGenerator(nn.Module):
    def __init__(self, image_size=32, latent_dim=64, output_channel=1, class_num=3):
        super(DCGenerator, self).__init__()
        self.image_size = image_size
        self.latent_dim = latent_dim
        self.output_channel = output_channel
        self.class_num = class_num
        self.init_size = image_size // 8

        # fc: Linear -> BN -> ReLU
        self.fc = nn.Sequential(
            nn.Linear(latent_dim + class_num, 512 * self.init_size ** 2),
            nn.BatchNorm1d(512 * self.init_size ** 2),
            nn.ReLU(inplace=True)
        )

        # deconv: ConvTranspose2d(4, 2, 1) -> BN -> ReLU ->
        #         ConvTranspose2d(4, 2, 1) -> BN -> ReLU ->
        #         ConvTranspose2d(4, 2, 1) -> Tanh
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, output_channel, 4, stride=2, padding=1),
            nn.Tanh(),
        )
        initialize_weights(self)

    def forward(self, z, labels):
        """
        z : noise vector
        labels : one-hot vector
        """
        input_ = torch.cat((z, labels), dim=1)
        out = self.fc(input_)
        out = out.view(out.shape[0], 512, self.init_size, self.init_size)
        img = self.deconv(out)
        return img
class DCDiscriminator(nn.Module):
    def __init__(self, image_size=32, input_channel=1, class_num=3, sigmoid=True):
        super(DCDiscriminator, self).__init__()
        self.image_size = image_size
        self.input_channel = input_channel
        self.class_num = class_num
        self.fc_size = image_size // 8

        # conv: Conv2d(3,2,1) -> LeakyReLU
        #       Conv2d(3,2,1) -> BN -> LeakyReLU
        #       Conv2d(3,2,1) -> BN -> LeakyReLU
        self.conv = nn.Sequential(
            nn.Conv2d(input_channel + class_num, 128, 3, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 3, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 3, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
        )

        # fc: Linear -> Sigmoid (the Sigmoid is only appended when sigmoid=True)
        self.fc = nn.Sequential(
            nn.Linear(512 * self.fc_size * self.fc_size, 1),
        )
        if sigmoid:
            self.fc.add_module('sigmoid', nn.Sigmoid())
        initialize_weights(self)

    def forward(self, img, labels):
        """
        img : input image
        labels : (batch_size, class_num, image_size, image_size)
                 the i-th channel is filled with 1, and the others are filled with 0.
        """
        # concatenate image and label planes along the channel dimension
        input_ = torch.cat((img, labels), dim=1)
        out = self.conv(input_)
        out = out.view(out.shape[0], -1)
        out = self.fc(out)
        return out
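Why init_size and fc_size are both image_size // 8: G starts from a small square map and each ConvTranspose2d(4, 2, 1) doubles its side, while each Conv2d(3, 2, 1) in D halves it. The helpers below (illustrative names) apply the standard PyTorch output-size formulas, assuming dilation 1 and no output_padding, to check this for image_size = 32:

```python
def conv_out(size, kernel, stride, padding):
    # nn.Conv2d output size (dilation=1): floor((size + 2p - k) / s) + 1
    return (size + 2 * padding - kernel) // stride + 1

def deconv_out(size, kernel, stride, padding):
    # nn.ConvTranspose2d output size (dilation=1, output_padding=0): (size - 1) * s - 2p + k
    return (size - 1) * stride - 2 * padding + kernel

image_size = 32

# Generator: fc output reshaped to 512 x (32 // 8) x (32 // 8), then three ConvTranspose2d(4, 2, 1)
g_sizes = [image_size // 8]
for _ in range(3):
    g_sizes.append(deconv_out(g_sizes[-1], 4, 2, 1))

# Discriminator: three Conv2d(3, 2, 1) layers
d_sizes = [image_size]
for _ in range(3):
    d_sizes.append(conv_out(d_sizes[-1], 3, 2, 1))

print(g_sizes)                 # [4, 8, 16, 32]
print(d_sizes)                 # [32, 16, 8, 4]
print(512 * d_sizes[-1] ** 2)  # 8192
```

The flattened size 512 × 4 × 4 = 8192 is the in_features of D's Linear layer in the model printout further down, and 100 + 10 = 110 is the in_features of G's first Linear layer when latent_dim = 100 and class_num = 10.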
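Dataset
We train our CGAN on the familiar MNIST handwritten digit dataset. As before, we provide a reduced version of the dataset to speed up training. Unlike last time, this dataset contains all 10 digit classes, 0 to 9, with 200 images per class (2,000 images in total). The images are again 28×28 single-channel grayscale images, which we resize to 32×32. The code below loads the MNIST dataset.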
def load_mnist_data():
    """
    load the reduced mnist dataset (digits 0-9, 200 images per class)
    """
    transform = torchvision.transforms.Compose([
        # convert to a 1-channel gray image, since ImageFolder reads images in RGB mode
        transforms.Grayscale(1),
        # resize image from 28 * 28 to 32 * 32
        transforms.Resize(32),
        transforms.ToTensor(),
        # normalize with mean=0.5 std=0.5
        transforms.Normalize(mean=(0.5, ),
                             std=(0.5, ))
    ])

    train_dataset = torchvision.datasets.ImageFolder(root='./data/mnist', transform=transform)
    return train_dataset
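Why mean 0.5 and std 0.5: ToTensor() scales pixel values to [0, 1], and Normalize then maps them to [-1, 1], the same range as the generator's final Tanh. The scalar helpers below are hypothetical illustrations of that per-pixel mapping and of its inverse (what denorm does before display); they are not torchvision functions:

```python
def normalize_pixel(p, mean=0.5, std=0.5):
    # per-pixel effect of transforms.Normalize(mean=(0.5,), std=(0.5,)) after ToTensor()
    return (p - mean) / std

def denorm_pixel(x):
    # inverse mapping [-1, 1] -> [0, 1], clamped like denorm's out.clamp(0, 1)
    return min(max((x + 1) / 2, 0.0), 1.0)

print([normalize_pixel(p) for p in (0.0, 0.5, 1.0)])  # [-1.0, 0.0, 1.0]
print(denorm_pixel(normalize_pixel(0.25)))            # 0.25
```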
Next, let's look at real handwritten samples from each class. (Just run the following 2 cells; you do not need to understand the code.)
def denorm(x):
    # denormalize: map [-1, 1] back to [0, 1]
    out = (x + 1) / 2
    return out.clamp(0, 1)

from utils import show
"""
you can skip reading the code in this cell
"""
# show mnist real data
train_dataset = load_mnist_data()
images = []
for j in range(5):
    for i in range(10):
        images.append(train_dataset[i * 200 + j][0])
show(torchvision.utils.make_grid(denorm(torch.stack(images)), nrow=10))
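The index i * 200 + j above assumes that ImageFolder lists the samples class by class (it sorts the class folders, and the files within each folder) and that every class has exactly 200 images. Under those assumptions, and using a hypothetical stand-in for train_dataset.targets, each row of the displayed grid holds one image of every digit:

```python
per_class, class_num = 200, 10

# Hypothetical stand-in for train_dataset.targets: 200 labels of class 0, then 200 of class 1, ...
targets = [c for c in range(class_num) for _ in range(per_class)]

# Same traversal as the display cell: outer loop over rows j, inner loop over classes i
indices = [i * per_class + j for j in range(5) for i in range(class_num)]
grid_labels = [targets[k] for k in indices]

# make_grid(..., nrow=10) places 10 images per row, so every row reads 0, 1, ..., 9
rows = [grid_labels[r * class_num:(r + 1) * class_num] for r in range(5)]
print(rows[0])  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```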
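The training code is similar to before. The difference is that, from each sample's class, we build y_vec (a one-hot vector; e.g. class 2 corresponds to [0,0,1,0,0,0,0,0,0,0]) and y_fill (y_vec expanded to size (class_num, image_size, image_size), with the channel of the correct class all 1s and the other channels all 0s), which are fed into G and D respectively as condition variables. The rest of the training procedure is the same as for an ordinary GAN. We can first precompute vecs and fills for every class label.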
# class number
class_num = 10
# image size and channel
image_size = 32
image_channel = 1
# vecs: one-hot vectors of size(class_num, class_num)
# fills: vecs expand to size(class_num, class_num, image_size, image_size)
vecs = torch.eye(class_num)
fills = vecs.unsqueeze(2).unsqueeze(3).expand(class_num, class_num, image_size, image_size)
print(vecs)
print(fills)
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
tensor([[[[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
...,
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
...,
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]],
...,
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
...,
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
...,
[[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
...,
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
...,
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]]]])
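In train(), the per-batch conditions are obtained by indexing these tables with the label tensor y. Note that fills is created with expand, so it is a view that shares memory with vecs; indexing it with a tensor of labels (advanced indexing) returns a new tensor with its own storage. A minimal shape check, assuming a made-up batch of three labels:

```python
import torch

class_num, image_size = 10, 32
vecs = torch.eye(class_num)
fills = vecs.unsqueeze(2).unsqueeze(3).expand(class_num, class_num, image_size, image_size)

y = torch.tensor([4, 0, 9])  # example batch of labels, as a DataLoader would yield them
y_vec = vecs[y]              # (3, 10): row b is the one-hot vector of y[b]
y_fill = fills[y]            # (3, 10, 32, 32): channel y[b] of sample b is all 1s

print(y_vec.shape, y_fill.shape)
```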
def train(trainloader, G, D, G_optimizer, D_optimizer, loss_func, device, z_dim, class_num):
    """
    train a GAN with model G and D in one epoch
    Args:
        trainloader: data loader to train
        G: model Generator
        D: model Discriminator
        G_optimizer: optimizer of G (e.g. Adam, SGD)
        D_optimizer: optimizer of D (e.g. Adam, SGD)
        loss_func: Binary Cross Entropy (BCE) or MSE loss function
        device: cpu or cuda device
        z_dim: the dimension of random noise z
        class_num: number of classes (labels are looked up in the global vecs/fills defined above)
    Returns:
        the average D loss and the average G loss over all batches
    """
    # set train mode
    D.train()
    G.train()

    D_total_loss = 0
    G_total_loss = 0

    for i, (x, y) in enumerate(trainloader):
        x = x.to(device)
        batch_size_ = x.size(0)
        image_size = x.size(2)

        # real label and fake label
        real_label = torch.ones(batch_size_, 1).to(device)
        fake_label = torch.zeros(batch_size_, 1).to(device)

        # y_vec: (batch_size, class_num) one-hot vector, for example, [0,0,0,0,1,0,0,0,0,0] (label: 4)
        y_vec = vecs[y.long()].to(device)
        # y_fill: (batch_size, class_num, image_size, image_size)
        # y_fill: the i-th channel is filled with 1, and the others are filled with 0.
        y_fill = fills[y.long()].to(device)

        # noise z, sampled uniformly from [0, 1)
        z = torch.rand(batch_size_, z_dim).to(device)

        # update D network
        # D optimizer zero grads
        D_optimizer.zero_grad()

        # D real loss from real images
        d_real = D(x, y_fill)
        d_real_loss = loss_func(d_real, real_label)

        # D fake loss from fake images generated by G
        # (detach so that D's backward pass does not compute gradients for G)
        g_z = G(z, y_vec)
        d_fake = D(g_z.detach(), y_fill)
        d_fake_loss = loss_func(d_fake, fake_label)

        # D backward and step
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        D_optimizer.step()

        # update G network
        # G optimizer zero grads
        G_optimizer.zero_grad()

        # G loss: G wants D to classify its images as real
        g_z = G(z, y_vec)
        d_fake = D(g_z, y_fill)
        g_loss = loss_func(d_fake, real_label)

        # G backward and step
        g_loss.backward()
        G_optimizer.step()

        D_total_loss += d_loss.item()
        G_total_loss += g_loss.item()

    return D_total_loss / len(trainloader), G_total_loss / len(trainloader)
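To read the loss values printed during training, it helps to evaluate the BCE terms by hand. The sketch below re-implements binary cross-entropy for a single prediction (nn.BCELoss computes the same quantity and, by default, averages it over the batch); the probabilities 0.8 and 0.2 are arbitrary examples:

```python
import math

def bce(p, target):
    # binary cross-entropy for one prediction p = D(...) in (0, 1)
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

# G loss in train(): BCE(D(G(z|y)|y), real_label) reduces to -log D(G(z|y)|y)
d_fake = 0.2
g_loss = bce(d_fake, 1.0)

# D loss in train(): BCE on real images (target 1) plus BCE on fakes (target 0)
d_real = 0.8
d_loss = bce(d_real, 1.0) + bce(d_fake, 0.0)

# If D outputs 0.5 for everything (it cannot tell real from fake):
print(round(bce(0.5, 1.0), 4))                  # 0.6931  (= log 2)
print(round(bce(0.5, 1.0) + bce(0.5, 0.0), 4))  # 1.3863  (= 2 log 2)
```

So a D loss near 1.39 together with a G loss near 0.69 would mean D cannot distinguish real from generated images at all, whereas a D loss close to 0 with a growing G loss, as in the later epochs of the logs below, means D separates them confidently.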
The code of visualize_results and run_gan is not explained in detail.
def visualize_results(G, device, z_dim, class_num, class_result_size=5):
    G.eval()
    # class_result_size rows, each holding one sample of every class 0..class_num-1
    z = torch.rand(class_num * class_result_size, z_dim).to(device)
    y = torch.LongTensor([i for i in range(class_num)] * class_result_size)
    y_vec = vecs[y.long()].to(device)
    # no gradients are needed when only sampling images
    with torch.no_grad():
        g_z = G(z, y_vec)
    show(torchvision.utils.make_grid(denorm(g_z.cpu()), nrow=class_num))
def run_gan(trainloader, G, D, G_optimizer, D_optimizer, loss_func, n_epochs, device, latent_dim, class_num):
    d_loss_hist = []
    g_loss_hist = []
    for epoch in range(n_epochs):
        d_loss, g_loss = train(trainloader, G, D, G_optimizer, D_optimizer, loss_func, device,
                               latent_dim, class_num)
        print('Epoch {}: Train D loss: {:.4f}, G loss: {:.4f}'.format(epoch, d_loss, g_loss))

        d_loss_hist.append(d_loss)
        g_loss_hist.append(g_loss)

        if epoch == 0 or (epoch + 1) % 10 == 0:
            visualize_results(G, device, latent_dim, class_num)

    return d_loss_hist, g_loss_hist
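run_gan shows samples after the first epoch and then whenever (epoch + 1) is a multiple of 10, i.e. after every tenth epoch. With n_epochs = 120 as set in the next cell, that gives 13 sample grids:

```python
n_epochs = 120

# epochs (0-based) after which run_gan calls visualize_results
shown = [epoch for epoch in range(n_epochs) if epoch == 0 or (epoch + 1) % 10 == 0]
print(shown)  # [0, 9, 19, 29, 39, 49, 59, 69, 79, 89, 99, 109, 119]
```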
Now let's try training our CGAN.
# hyper params
# z dim
latent_dim = 100
# Adam lr and betas
learning_rate = 0.0002
betas = (0.5, 0.999)
# epochs and batch size
n_epochs = 120
batch_size = 32
# device : cpu or cuda:0/1/2/3 (falls back to cpu if CUDA is unavailable)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# mnist dataset and dataloader
train_dataset = load_mnist_data()
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# use BCELoss as loss function
bceloss = nn.BCELoss().to(device)
# G and D model
G = DCGenerator(image_size=image_size, latent_dim=latent_dim, output_channel=image_channel, class_num=class_num)
D = DCDiscriminator(image_size=image_size, input_channel=image_channel, class_num=class_num)
G.to(device)
D.to(device)
print(D)
print(G)
# G and D optimizer, use Adam or SGD
G_optimizer = optim.Adam(G.parameters(), lr=learning_rate, betas=betas)
D_optimizer = optim.Adam(D.parameters(), lr=learning_rate, betas=betas)
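train() returns the summed batch losses divided by len(trainloader). With the 2,000-image dataset, batch_size = 32 and the DataLoader's default drop_last=False, an epoch has 63 batches and the last one holds only 16 images. Batch size matters here because nn.BatchNorm1d in G's fc block raises an error in training mode when a batch contains a single sample; with these settings that cannot happen. The arithmetic:

```python
import math

dataset_size, batch_size = 2000, 32

# number of batches per epoch, i.e. len(trainloader) when drop_last=False
num_batches = math.ceil(dataset_size / batch_size)
# size of the final, partial batch
last_batch = dataset_size - (num_batches - 1) * batch_size

print(num_batches, last_batch)  # 63 16
```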
DCDiscriminator(
(conv): Sequential(
(0): Conv2d(11, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(1): LeakyReLU(negative_slope=0.2)
(2): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2)
(5): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2)
)
(fc): Sequential(
(0): Linear(in_features=8192, out_features=1, bias=True)
(sigmoid): Sigmoid()
)
)
DCGenerator(
(fc): Sequential(
(0): Linear(in_features=110, out_features=8192, bias=True)
(1): BatchNorm1d(8192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace)
)
(deconv): Sequential(
(0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace)
(3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace)
(6): ConvTranspose2d(128, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(7): Tanh()
)
)
d_loss_hist, g_loss_hist = run_gan(trainloader, G, D, G_optimizer, D_optimizer, bceloss,
n_epochs, device, latent_dim, class_num)
Epoch 0: Train D loss: 0.2962, G loss: 3.8550
Epoch 1: Train D loss: 0.6089, G loss: 3.6378
Epoch 2: Train D loss: 0.8812, G loss: 2.2457
Epoch 3: Train D loss: 0.8877, G loss: 1.9269
Epoch 4: Train D loss: 0.9665, G loss: 1.8893
Epoch 5: Train D loss: 0.9414, G loss: 1.7735
Epoch 6: Train D loss: 0.8708, G loss: 1.8289
Epoch 7: Train D loss: 0.8942, G loss: 1.7005
Epoch 8: Train D loss: 0.9111, G loss: 1.7255
Epoch 9: Train D loss: 0.8998, G loss: 1.7084
Epoch 10: Train D loss: 0.9060, G loss: 1.6594
Epoch 11: Train D loss: 0.9331, G loss: 1.6657
Epoch 12: Train D loss: 0.9313, G loss: 1.6259
Epoch 13: Train D loss: 0.9475, G loss: 1.6301
Epoch 14: Train D loss: 0.9856, G loss: 1.6319
Epoch 15: Train D loss: 0.9712, G loss: 1.5905
Epoch 16: Train D loss: 0.9892, G loss: 1.5713
Epoch 17: Train D loss: 1.0118, G loss: 1.5743
Epoch 18: Train D loss: 1.0041, G loss: 1.5457
Epoch 19: Train D loss: 1.0028, G loss: 1.6262
Epoch 20: Train D loss: 1.0085, G loss: 1.5393
Epoch 21: Train D loss: 1.0020, G loss: 1.6078
Epoch 22: Train D loss: 0.9486, G loss: 1.6651
Epoch 23: Train D loss: 0.9706, G loss: 1.6328
Epoch 24: Train D loss: 0.9127, G loss: 1.6835
Epoch 25: Train D loss: 0.9416, G loss: 1.6948
Epoch 26: Train D loss: 0.8698, G loss: 1.7693
Epoch 27: Train D loss: 0.8571, G loss: 1.8435
Epoch 28: Train D loss: 0.8520, G loss: 1.8850
Epoch 29: Train D loss: 0.7613, G loss: 2.0046
Epoch 30: Train D loss: 0.8708, G loss: 1.9706
Epoch 31: Train D loss: 0.6392, G loss: 2.0542
Epoch 32: Train D loss: 0.7748, G loss: 2.0904
Epoch 33: Train D loss: 0.7603, G loss: 2.1889
Epoch 34: Train D loss: 0.6701, G loss: 2.2419
Epoch 35: Train D loss: 0.4888, G loss: 2.4315
Epoch 36: Train D loss: 0.6143, G loss: 2.4058
Epoch 37: Train D loss: 0.5030, G loss: 2.5943
Epoch 38: Train D loss: 0.6665, G loss: 2.5604
Epoch 39: Train D loss: 0.2921, G loss: 2.8537
Epoch 40: Train D loss: 0.7130, G loss: 2.7242
Epoch 41: Train D loss: 0.3132, G loss: 2.9228
Epoch 42: Train D loss: 0.4735, G loss: 2.9304
Epoch 43: Train D loss: 0.1570, G loss: 3.3429
Epoch 44: Train D loss: 0.6236, G loss: 3.0557
Epoch 45: Train D loss: 0.2389, G loss: 3.2241
Epoch 46: Train D loss: 0.1189, G loss: 3.6270
Epoch 47: Train D loss: 0.1112, G loss: 3.8986
Epoch 48: Train D loss: 0.5740, G loss: 3.6167
Epoch 49: Train D loss: 0.2161, G loss: 3.4319
Epoch 50: Train D loss: 0.1162, G loss: 3.9703
Epoch 51: Train D loss: 0.0875, G loss: 4.1047
Epoch 52: Train D loss: 1.1022, G loss: 2.5413
Epoch 53: Train D loss: 0.1822, G loss: 3.4868
Epoch 54: Train D loss: 0.0919, G loss: 3.9516
Epoch 55: Train D loss: 0.0657, G loss: 4.2033
Epoch 56: Train D loss: 0.0595, G loss: 4.3836
Epoch 57: Train D loss: 0.0533, G loss: 4.5497
Epoch 58: Train D loss: 0.7047, G loss: 3.6997
Epoch 59: Train D loss: 0.2122, G loss: 3.7186
Epoch 60: Train D loss: 0.0671, G loss: 4.2783
Epoch 61: Train D loss: 0.0534, G loss: 4.5652
Epoch 62: Train D loss: 0.0490, G loss: 4.6673
Epoch 63: Train D loss: 0.0387, G loss: 4.8734
Epoch 64: Train D loss: 0.0347, G loss: 4.9742
Epoch 65: Train D loss: 0.2409, G loss: 5.1782
Epoch 66: Train D loss: 1.0484, G loss: 2.4625
Epoch 67: Train D loss: 0.4583, G loss: 3.2699
Epoch 68: Train D loss: 0.4521, G loss: 3.4144
Epoch 69: Train D loss: 0.1248, G loss: 4.0661
Epoch 70: Train D loss: 0.0579, G loss: 4.4066
Epoch 71: Train D loss: 0.0474, G loss: 4.7067
Epoch 72: Train D loss: 0.0375, G loss: 4.8429
Epoch 73: Train D loss: 0.0304, G loss: 5.0606
Epoch 74: Train D loss: 0.0243, G loss: 5.2481
Epoch 75: Train D loss: 0.0260, G loss: 5.3255
Epoch 76: Train D loss: 0.0225, G loss: 5.4283
Epoch 77: Train D loss: 1.2070, G loss: 2.4013
Epoch 78: Train D loss: 0.6930, G loss: 2.4867
Epoch 79: Train D loss: 0.5972, G loss: 3.1937
Epoch 80: Train D loss: 0.2452, G loss: 3.6573
Epoch 81: Train D loss: 0.0592, G loss: 4.4053
Epoch 82: Train D loss: 0.0456, G loss: 4.7146
Epoch 83: Train D loss: 0.0366, G loss: 4.8923
Epoch 84: Train D loss: 0.0303, G loss: 5.0758
Epoch 85: Train D loss: 0.0233, G loss: 5.2704
Epoch 86: Train D loss: 0.0254, G loss: 5.4018
Epoch 87: Train D loss: 0.8972, G loss: 3.2275
Epoch 88: Train D loss: 0.6262, G loss: 2.5182
Epoch 89: Train D loss: 0.5316, G loss: 3.4189
Epoch 90: Train D loss: 0.5059, G loss: 3.2730
Epoch 91: Train D loss: 0.0708, G loss: 4.4447
Epoch 92: Train D loss: 0.0399, G loss: 4.8059
Epoch 93: Train D loss: 0.0292, G loss: 5.0672
Epoch 94: Train D loss: 0.0242, G loss: 5.1704
Epoch 95: Train D loss: 0.0206, G loss: 5.3694
Epoch 96: Train D loss: 0.0209, G loss: 5.4811
Epoch 97: Train D loss: 0.0174, G loss: 5.5394
Epoch 98: Train D loss: 0.0174, G loss: 5.5801
Epoch 99: Train D loss: 0.0167, G loss: 5.7518
Epoch 100: Train D loss: 0.0147, G loss: 5.8225
Epoch 101: Train D loss: 0.0153, G loss: 5.9176
Epoch 102: Train D loss: 0.0133, G loss: 6.0194
Epoch 103: Train D loss: 0.0114, G loss: 6.0404
Epoch 104: Train D loss: 0.0125, G loss: 6.0783
Epoch 105: Train D loss: 0.0102, G loss: 6.2466
Epoch 106: Train D loss: 0.0109, G loss: 6.2441
Epoch 107: Train D loss: 0.6059, G loss: 5.0261
Epoch 108: Train D loss: 0.5775, G loss: 2.7050
Epoch 109: Train D loss: 0.5215, G loss: 2.7918
Epoch 110: Train D loss: 0.5460, G loss: 2.7928
Epoch 111: Train D loss: 0.5656, G loss: 3.0143
Epoch 112: Train D loss: 0.5745, G loss: 3.1358
Epoch 113: Train D loss: 0.3454, G loss: 3.3785
Epoch 114: Train D loss: 0.6632, G loss: 3.3066
Epoch 115: Train D loss: 0.1403, G loss: 3.9030
Epoch 116: Train D loss: 0.0821, G loss: 4.3970
Epoch 117: Train D loss: 0.4486, G loss: 3.9750
Epoch 118: Train D loss: 0.4547, G loss: 3.1868
Epoch 119: Train D loss: 0.7379, G loss: 3.3208
from utils import loss_plot
loss_plot(d_loss_hist, g_loss_hist)
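Assignment:
- In D, the input image and the labels can each be passed through a separate convolutional layer and then concatenated along dimension 1 (the channel dimension) before being fed into the rest of the network. Part of the network structure is already written in DCDiscriminator (implemented below as DCDiscriminator1); complete the forward function to implement this, and train the CGAN again on the same dataset. How do the results differ from the previous ones?
Answer: Compared with the previous results, the overall trend of the loss is similar, but here the loss fluctuates more frequently, and the lowest generator loss is higher than before. Early in training the discriminator and generator compete, so the generator loss falls while the discriminator loss rises; later, however, the discriminator dominates: its loss oscillates around a low value and the generator loss keeps rising. The generated images are also not as good; for example, the digits 3, 4 and 6 are hard to recognize.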
class DCDiscriminator1(nn.Module):
    def __init__(self, image_size=32, input_channel=1, class_num=3, sigmoid=True):
        super().__init__()
        self.image_size = image_size
        self.input_channel = input_channel
        self.class_num = class_num
        self.fc_size = image_size // 8

        # model : img    -> conv1_1: Conv2d(3,2,1) -> BN
        #         labels -> conv1_2: Conv2d(3,2,1) -> BN
        #         cat(img, labels) -> LeakyReLU
        #                          -> Conv2d(3,2,1) -> BN -> LeakyReLU
        #                          -> Conv2d(3,2,1) -> BN -> LeakyReLU
        self.conv1_1 = nn.Sequential(nn.Conv2d(input_channel, 64, 3, 2, 1),
                                     nn.BatchNorm2d(64))
        self.conv1_2 = nn.Sequential(nn.Conv2d(class_num, 64, 3, 2, 1),
                                     nn.BatchNorm2d(64))
        self.conv = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 3, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 3, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
        )

        # fc: Linear -> Sigmoid
        self.fc = nn.Sequential(
            nn.Linear(512 * self.fc_size * self.fc_size, 1),
        )
        if sigmoid:
            self.fc.add_module('sigmoid', nn.Sigmoid())
        initialize_weights(self)

    def forward(self, img, labels):
        """
        img : input image
        labels : (batch_size, class_num, image_size, image_size)
                 the i-th channel is filled with 1, and the others are filled with 0.
        """
        # each branch: (batch_size, 64, image_size / 2, image_size / 2)
        img_out = self.conv1_1(img)
        labels_out = self.conv1_2(labels)
        # merge along the channel dimension -> 128 channels
        out = torch.cat((img_out, labels_out), dim=1)
        out = self.conv(out)
        out = out.view(out.shape[0], -1)
        out = self.fc(out)
        return out
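In DCDiscriminator1 the two branches must produce feature maps of the same spatial size for torch.cat along dim=1 to work, and their channel counts must add up to the in_channels of the first convolution in self.conv. A quick shape check with random inputs (batch size 4 and label 3 are arbitrary choices):

```python
import torch
import torch.nn as nn

batch, class_num, image_size = 4, 10, 32
img = torch.randn(batch, 1, image_size, image_size)
labels = torch.zeros(batch, class_num, image_size, image_size)
labels[:, 3] = 1.0  # pretend every sample belongs to class 3

# Same first-stage layers as DCDiscriminator1: each branch -> 64 channels at half resolution
conv1_1 = nn.Sequential(nn.Conv2d(1, 64, 3, 2, 1), nn.BatchNorm2d(64))
conv1_2 = nn.Sequential(nn.Conv2d(class_num, 64, 3, 2, 1), nn.BatchNorm2d(64))

merged = torch.cat((conv1_1(img), conv1_2(labels)), dim=1)
print(merged.shape)  # torch.Size([4, 128, 16, 16])
```

64 + 64 = 128 channels at 16×16 is exactly what nn.Conv2d(128, 256, 3, 2, 1) expects.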
from utils import loss_plot
# hyper params
# device : cpu or cuda:0/1/2/3 (falls back to cpu if CUDA is unavailable)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# G and D model
G = DCGenerator(image_size=image_size, latent_dim=latent_dim, output_channel=image_channel, class_num=class_num)
D = DCDiscriminator1(image_size=image_size, input_channel=image_channel, class_num=class_num)
G.to(device)
D.to(device)
# G and D optimizer, use Adam or SGD
G_optimizer = optim.Adam(G.parameters(), lr=learning_rate, betas=betas)
D_optimizer = optim.Adam(D.parameters(), lr=learning_rate, betas=betas)
d_loss_hist, g_loss_hist = run_gan(trainloader, G, D, G_optimizer, D_optimizer, bceloss,
n_epochs, device, latent_dim, class_num)
loss_plot(d_loss_hist, g_loss_hist)
Epoch 0: Train D loss: 0.5783, G loss: 4.0314
Epoch 1: Train D loss: 0.3084, G loss: 4.6929
Epoch 2: Train D loss: 0.3652, G loss: 4.4327
Epoch 3: Train D loss: 0.6010, G loss: 3.2752
Epoch 4: Train D loss: 0.5115, G loss: 2.9438
Epoch 5: Train D loss: 0.5711, G loss: 2.8227
Epoch 6: Train D loss: 0.5215, G loss: 2.5894
Epoch 7: Train D loss: 0.5886, G loss: 2.6360
Epoch 8: Train D loss: 0.5324, G loss: 2.7322
Epoch 9: Train D loss: 0.4396, G loss: 2.6589
Epoch 10: Train D loss: 0.5049, G loss: 2.6881
Epoch 11: Train D loss: 0.6259, G loss: 2.5495
Epoch 12: Train D loss: 0.5501, G loss: 2.6990
Epoch 13: Train D loss: 0.6089, G loss: 2.5458
Epoch 14: Train D loss: 0.6249, G loss: 2.4948
Epoch 15: Train D loss: 0.6507, G loss: 2.2958
Epoch 16: Train D loss: 0.5265, G loss: 2.5526
Epoch 17: Train D loss: 0.6500, G loss: 2.4998
Epoch 18: Train D loss: 0.5119, G loss: 2.6624
Epoch 19: Train D loss: 0.7852, G loss: 2.2363
Epoch 20: Train D loss: 0.5294, G loss: 2.3413
Epoch 21: Train D loss: 0.6635, G loss: 2.7594
Epoch 22: Train D loss: 0.5128, G loss: 2.6446
Epoch 23: Train D loss: 0.6374, G loss: 2.4458
Epoch 24: Train D loss: 0.5262, G loss: 2.8333
Epoch 25: Train D loss: 0.4865, G loss: 2.6566
Epoch 26: Train D loss: 0.6546, G loss: 2.6343
Epoch 27: Train D loss: 0.6002, G loss: 2.8760
Epoch 28: Train D loss: 0.2794, G loss: 3.1967
Epoch 29: Train D loss: 0.3933, G loss: 3.2833
Epoch 30: Train D loss: 0.3230, G loss: 3.3384
Epoch 31: Train D loss: 0.4659, G loss: 3.4798
Epoch 32: Train D loss: 0.4419, G loss: 3.2220
Epoch 33: Train D loss: 0.7314, G loss: 2.6443
Epoch 34: Train D loss: 0.2897, G loss: 3.0850
Epoch 35: Train D loss: 0.2233, G loss: 3.6760
Epoch 36: Train D loss: 0.2126, G loss: 4.0898
Epoch 37: Train D loss: 0.8669, G loss: 2.8141
Epoch 38: Train D loss: 0.3106, G loss: 3.0525
Epoch 39: Train D loss: 0.1445, G loss: 3.7176
Epoch 40: Train D loss: 0.0959, G loss: 4.1259
Epoch 41: Train D loss: 0.9976, G loss: 3.0617
Epoch 42: Train D loss: 0.4574, G loss: 2.6349
Epoch 43: Train D loss: 0.6087, G loss: 2.9328
Epoch 44: Train D loss: 0.1489, G loss: 3.6072
Epoch 45: Train D loss: 0.0740, G loss: 4.2103
Epoch 46: Train D loss: 0.0629, G loss: 4.3956
Epoch 47: Train D loss: 0.5998, G loss: 3.2734
Epoch 48: Train D loss: 0.1357, G loss: 4.0024
Epoch 49: Train D loss: 0.0552, G loss: 4.5454
Epoch 50: Train D loss: 0.0484, G loss: 4.8041
Epoch 51: Train D loss: 0.6982, G loss: 3.7349
Epoch 52: Train D loss: 0.5501, G loss: 2.6109
Epoch 53: Train D loss: 0.2982, G loss: 3.6175
Epoch 54: Train D loss: 0.1213, G loss: 4.3322
Epoch 55: Train D loss: 0.5588, G loss: 3.3895
Epoch 56: Train D loss: 0.0901, G loss: 4.3445
Epoch 57: Train D loss: 0.0476, G loss: 4.8930
Epoch 58: Train D loss: 0.0491, G loss: 5.0212
Epoch 59: Train D loss: 0.0369, G loss: 5.0965
Epoch 60: Train D loss: 0.0337, G loss: 5.2226
Epoch 61: Train D loss: 1.2331, G loss: 2.7346
Epoch 62: Train D loss: 0.3062, G loss: 3.3701
Epoch 63: Train D loss: 0.4700, G loss: 3.3737
Epoch 64: Train D loss: 0.1190, G loss: 4.2390
Epoch 65: Train D loss: 0.0480, G loss: 4.6856
Epoch 66: Train D loss: 0.1715, G loss: 4.9363
Epoch 67: Train D loss: 0.9050, G loss: 2.3453
Epoch 68: Train D loss: 0.3282, G loss: 3.5270
Epoch 69: Train D loss: 0.3222, G loss: 3.8528
Epoch 70: Train D loss: 0.0641, G loss: 4.4396
Epoch 71: Train D loss: 0.0386, G loss: 4.9179
Epoch 72: Train D loss: 0.3381, G loss: 4.9477
Epoch 73: Train D loss: 0.7052, G loss: 2.2826
Epoch 74: Train D loss: 0.3881, G loss: 3.3018
Epoch 75: Train D loss: 0.3167, G loss: 3.6229
Epoch 76: Train D loss: 0.0612, G loss: 4.5307
Epoch 77: Train D loss: 0.0349, G loss: 5.0571
Epoch 78: Train D loss: 0.0259, G loss: 5.2854
Epoch 79: Train D loss: 0.6566, G loss: 4.2716
Epoch 80: Train D loss: 0.5006, G loss: 2.6297
Epoch 81: Train D loss: 0.4957, G loss: 3.3940
Epoch 82: Train D loss: 0.1931, G loss: 3.7004
Epoch 83: Train D loss: 0.0516, G loss: 4.6862
Epoch 84: Train D loss: 0.0361, G loss: 5.0259
Epoch 85: Train D loss: 0.0347, G loss: 5.2946
Epoch 86: Train D loss: 0.5135, G loss: 3.3887
Epoch 87: Train D loss: 0.0544, G loss: 4.6875
Epoch 88: Train D loss: 0.0276, G loss: 5.2771
Epoch 89: Train D loss: 0.0223, G loss: 5.4915
Epoch 90: Train D loss: 0.0286, G loss: 5.5067
Epoch 91: Train D loss: 0.0169, G loss: 5.7693
Epoch 92: Train D loss: 0.1816, G loss: 5.8505
Epoch 93: Train D loss: 0.6672, G loss: 2.9792
Epoch 94: Train D loss: 0.3550, G loss: 3.5908
Epoch 95: Train D loss: 0.1152, G loss: 4.2985
Epoch 96: Train D loss: 0.3241, G loss: 4.6347
Epoch 97: Train D loss: 0.1568, G loss: 4.1708
Epoch 98: Train D loss: 0.0337, G loss: 5.1337
Epoch 99: Train D loss: 0.2742, G loss: 5.3655
Epoch 100: Train D loss: 0.4777, G loss: 2.9118
Epoch 101: Train D loss: 0.3332, G loss: 3.7523
Epoch 102: Train D loss: 0.0812, G loss: 4.8293
Epoch 103: Train D loss: 0.0243, G loss: 5.4633
Epoch 104: Train D loss: 0.0205, G loss: 5.6605
Epoch 105: Train D loss: 0.0175, G loss: 5.7829
Epoch 106: Train D loss: 0.0174, G loss: 5.8838
Epoch 107: Train D loss: 0.0143, G loss: 6.0738
Epoch 108: Train D loss: 0.3694, G loss: 5.1934
Epoch 109: Train D loss: 0.2177, G loss: 4.0592
Epoch 110: Train D loss: 0.0368, G loss: 5.0899
Epoch 111: Train D loss: 0.3410, G loss: 5.0052
Epoch 112: Train D loss: 0.7964, G loss: 2.3244
Epoch 113: Train D loss: 0.6525, G loss: 2.7550
Epoch 114: Train D loss: 0.5676, G loss: 2.8861
Epoch 115: Train D loss: 0.4401, G loss: 3.0373
Epoch 116: Train D loss: 0.3829, G loss: 3.2658
Epoch 117: Train D loss: 0.2251, G loss: 3.6322
Epoch 118: Train D loss: 0.2588, G loss: 3.9562
Epoch 119: Train D loss: 0.2042, G loss: 3.9001
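- In D, the input image can be passed through one convolutional layer and then concatenated along dimension 1 (the channel dimension) with the labels (which have the same spatial size as the input image), before being fed into the rest of the network. Part of the network structure is already written in DCDiscriminator (implemented below as DCDiscriminator2); complete the forward function to implement this, and train the CGAN again on the same dataset. How do the results differ from the previous ones?
Answer: Compared with the previous results, the overall trend of the loss is similar. The loss here also fluctuates frequently, but within a narrower range. Early in training the discriminator and generator compete, so the generator loss falls while the discriminator loss rises; later the discriminator dominates: its loss oscillates around a low value and the generator loss keeps rising. The generated images are worse than in both previous runs; for example, the digits 2, 4, 6, 7, 8 and 9 are not generated well.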
class DCDiscriminator2(nn.Module):
    def __init__(self, image_size=32, input_channel=1, class_num=3, sigmoid=True):
        super().__init__()
        self.image_size = image_size
        self.input_channel = input_channel
        self.class_num = class_num
        self.fc_size = image_size // 8
        # model : img -> conv1
        #         labels -> maxpool
        #         (img U labels) -> Conv2d(3,2,1) -> BN -> LeakyReLU
        #         Conv2d(3,2,1) -> BN -> LeakyReLU
        self.conv1 = nn.Sequential(nn.Conv2d(input_channel, 128, 3, 2, 1),
                                   nn.BatchNorm2d(128))
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.Conv2d(128 + class_num, 256, 3, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 3, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
        )
        # fc: Linear -> Sigmoid
        self.fc = nn.Sequential(
            nn.Linear(512 * self.fc_size * self.fc_size, 1),
        )
        if sigmoid:
            self.fc.add_module('sigmoid', nn.Sigmoid())
        initialize_weights(self)
    def forward(self, img, labels):
        """
        img : input image
        labels : (batch_size, class_num, image_size, image_size)
                 the i-th channel is filled with 1, and the others are filled with 0.
        """
        img_out = self.conv1(img)          # (B, 128, image_size/2, image_size/2)
        labels_out = self.maxpool(labels)  # (B, class_num, image_size/2, image_size/2)
        out = torch.cat((img_out, labels_out), dim=1)  # concatenate along channels
        out = self.conv(out)
        out = out.view(out.shape[0], -1)
        out = self.fc(out)
        return out
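The forward pass above relies on conv1 (stride 2) and the 2x2 max-pool shrinking the image and the label maps to the same spatial size, so the two can be concatenated on dim 1. The standalone sketch below checks that shape arithmetic for 32x32 inputs, assuming 10 classes (MNIST digits), and builds the label maps from integer labels with `F.one_hot`. `make_label_maps` is a hypothetical helper for illustration; the notebook's own `run_gan` code that builds the maps is not shown in this section.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def make_label_maps(labels, class_num, image_size):
    """Hypothetical helper: integer labels (B,) -> (B, class_num, H, W) maps whose
    true-class channel is all ones and whose other channels are all zeros."""
    one_hot = F.one_hot(labels, num_classes=class_num).float()  # (B, class_num)
    return one_hot[:, :, None, None].expand(-1, -1, image_size, image_size)

batch, class_num, image_size = 4, 10, 32
imgs = torch.randn(batch, 1, image_size, image_size)
labels = torch.tensor([0, 3, 7, 9])
label_maps = make_label_maps(labels, class_num, image_size)

conv1 = nn.Conv2d(1, 128, 3, 2, 1)                # 32x32 -> 16x16
maxpool = nn.MaxPool2d(kernel_size=2, stride=2)   # 32x32 -> 16x16
merged = torch.cat((conv1(imgs), maxpool(label_maps)), dim=1)
print(merged.shape)  # torch.Size([4, 138, 16, 16])
```

Max-pooling a constant 0/1 map leaves it unchanged apart from the size, so the label information reaches the later layers intact.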
# hyper params
# device : cpu or cuda:0/1/2/3
device = torch.device('cuda:0')
# G and D model
G = DCGenerator(image_size=image_size, latent_dim=latent_dim, output_channel=image_channel, class_num=class_num)
D = DCDiscriminator2(image_size=image_size, input_channel=image_channel, class_num=class_num)
G.to(device)
D.to(device)
# G and D optimizer, use Adam or SGD
G_optimizer = optim.Adam(G.parameters(), lr=learning_rate, betas=betas)
D_optimizer = optim.Adam(D.parameters(), lr=learning_rate, betas=betas)
d_loss_hist, g_loss_hist = run_gan(trainloader, G, D, G_optimizer, D_optimizer, bceloss,
n_epochs, device, latent_dim, class_num)
loss_plot(d_loss_hist, g_loss_hist)
Epoch 0: Train D loss: 0.3652, G loss: 4.9241
Epoch 1: Train D loss: 0.4429, G loss: 5.9180
Epoch 2: Train D loss: 0.1754, G loss: 4.7812
Epoch 3: Train D loss: 0.1420, G loss: 4.3085
Epoch 4: Train D loss: 0.3841, G loss: 4.3876
Epoch 5: Train D loss: 0.2930, G loss: 4.0079
Epoch 6: Train D loss: 0.4288, G loss: 3.7424
Epoch 7: Train D loss: 0.2865, G loss: 3.3463
Epoch 8: Train D loss: 0.4004, G loss: 3.2854
Epoch 9: Train D loss: 0.5848, G loss: 3.1671
Epoch 10: Train D loss: 0.4002, G loss: 2.8178
Epoch 11: Train D loss: 0.5253, G loss: 2.8021
Epoch 12: Train D loss: 0.4601, G loss: 2.9233
Epoch 13: Train D loss: 0.4344, G loss: 2.8415
Epoch 14: Train D loss: 0.5028, G loss: 2.8215
Epoch 15: Train D loss: 0.5168, G loss: 2.7453
Epoch 16: Train D loss: 0.4130, G loss: 2.7670
Epoch 17: Train D loss: 0.4163, G loss: 3.0764
Epoch 18: Train D loss: 0.5104, G loss: 2.8070
Epoch 19: Train D loss: 0.3978, G loss: 2.8812
Epoch 20: Train D loss: 0.4791, G loss: 2.8178
Epoch 21: Train D loss: 0.4476, G loss: 2.8565
Epoch 22: Train D loss: 0.4595, G loss: 3.1004
Epoch 23: Train D loss: 0.4785, G loss: 2.8015
Epoch 24: Train D loss: 0.3153, G loss: 3.0215
Epoch 25: Train D loss: 0.5425, G loss: 2.8971
Epoch 26: Train D loss: 0.4823, G loss: 2.9873
Epoch 27: Train D loss: 0.3167, G loss: 3.1143
Epoch 28: Train D loss: 0.3648, G loss: 3.3072
Epoch 29: Train D loss: 0.2634, G loss: 3.4745
Epoch 30: Train D loss: 0.2306, G loss: 3.4688
Epoch 31: Train D loss: 0.4227, G loss: 3.5273
Epoch 32: Train D loss: 0.1514, G loss: 3.5782
Epoch 33: Train D loss: 0.7272, G loss: 3.0295
Epoch 34: Train D loss: 0.2905, G loss: 3.1706
Epoch 35: Train D loss: 0.3160, G loss: 3.8710
Epoch 36: Train D loss: 0.2567, G loss: 3.7115
Epoch 37: Train D loss: 0.2725, G loss: 3.4961
Epoch 38: Train D loss: 0.1762, G loss: 3.9646
Epoch 39: Train D loss: 0.1075, G loss: 4.2067
Epoch 40: Train D loss: 0.0735, G loss: 4.4320
Epoch 41: Train D loss: 0.9511, G loss: 2.4049
Epoch 42: Train D loss: 0.6208, G loss: 2.4724
Epoch 43: Train D loss: 0.4307, G loss: 2.8940
Epoch 44: Train D loss: 0.4893, G loss: 3.1014
Epoch 45: Train D loss: 0.3474, G loss: 2.8716
Epoch 46: Train D loss: 0.1160, G loss: 3.8310
Epoch 47: Train D loss: 0.0750, G loss: 4.2792
Epoch 48: Train D loss: 0.0646, G loss: 4.5145
Epoch 49: Train D loss: 0.6285, G loss: 3.3446
Epoch 50: Train D loss: 0.3821, G loss: 3.2104
Epoch 51: Train D loss: 0.1024, G loss: 4.2762
Epoch 52: Train D loss: 0.0519, G loss: 4.5700
Epoch 53: Train D loss: 0.0400, G loss: 4.8230
Epoch 54: Train D loss: 0.9730, G loss: 3.0698
Epoch 55: Train D loss: 0.4508, G loss: 2.9594
Epoch 56: Train D loss: 0.2250, G loss: 3.8201
Epoch 57: Train D loss: 0.3813, G loss: 3.9528
Epoch 58: Train D loss: 0.1910, G loss: 3.5154
Epoch 59: Train D loss: 0.0522, G loss: 4.5624
Epoch 60: Train D loss: 0.0466, G loss: 4.8247
Epoch 61: Train D loss: 0.4749, G loss: 3.5991
Epoch 62: Train D loss: 0.3555, G loss: 3.8786
Epoch 63: Train D loss: 0.1598, G loss: 3.7794
Epoch 64: Train D loss: 0.0524, G loss: 4.6509
Epoch 65: Train D loss: 1.0464, G loss: 2.4283
Epoch 66: Train D loss: 0.4957, G loss: 2.8442
Epoch 67: Train D loss: 0.3295, G loss: 3.3887
Epoch 68: Train D loss: 0.1129, G loss: 3.9749
Epoch 69: Train D loss: 0.1507, G loss: 4.4947
Epoch 70: Train D loss: 0.7100, G loss: 2.4590
Epoch 71: Train D loss: 0.2155, G loss: 3.7421
Epoch 72: Train D loss: 0.0485, G loss: 4.4712
Epoch 73: Train D loss: 0.0384, G loss: 4.7572
Epoch 74: Train D loss: 0.0426, G loss: 5.0863
Epoch 75: Train D loss: 0.0276, G loss: 5.1197
Epoch 76: Train D loss: 0.3221, G loss: 4.6973
Epoch 77: Train D loss: 0.1505, G loss: 3.9563
Epoch 78: Train D loss: 0.0345, G loss: 4.8970
Epoch 79: Train D loss: 0.0291, G loss: 5.1258
Epoch 80: Train D loss: 0.0274, G loss: 5.2251
Epoch 81: Train D loss: 0.5761, G loss: 3.7912
Epoch 82: Train D loss: 0.1272, G loss: 4.0800
Epoch 83: Train D loss: 0.0365, G loss: 5.0618
Epoch 84: Train D loss: 0.0256, G loss: 5.2438
Epoch 85: Train D loss: 0.0247, G loss: 5.5058
Epoch 86: Train D loss: 0.0233, G loss: 5.5718
Epoch 87: Train D loss: 0.5834, G loss: 5.2774
Epoch 88: Train D loss: 0.3599, G loss: 3.4697
Epoch 89: Train D loss: 0.5033, G loss: 3.5868
Epoch 90: Train D loss: 0.4733, G loss: 3.7724
Epoch 91: Train D loss: 0.4360, G loss: 3.1645
Epoch 92: Train D loss: 0.1544, G loss: 3.8841
Epoch 93: Train D loss: 0.0724, G loss: 4.4879
Epoch 94: Train D loss: 0.0429, G loss: 4.7762
Epoch 95: Train D loss: 1.1455, G loss: 2.3196
Epoch 96: Train D loss: 0.5529, G loss: 2.6795
Epoch 97: Train D loss: 0.5533, G loss: 3.1325
Epoch 98: Train D loss: 0.1586, G loss: 3.7296
Epoch 99: Train D loss: 0.2869, G loss: 4.0718
Epoch 100: Train D loss: 0.0795, G loss: 4.3183
Epoch 101: Train D loss: 0.0399, G loss: 4.7554
Epoch 102: Train D loss: 0.0357, G loss: 4.9615
Epoch 103: Train D loss: 0.0278, G loss: 5.0769
Epoch 104: Train D loss: 0.4849, G loss: 3.4960
Epoch 105: Train D loss: 0.0654, G loss: 4.4074
Epoch 106: Train D loss: 0.0316, G loss: 5.0200
Epoch 107: Train D loss: 0.6949, G loss: 3.0835
Epoch 108: Train D loss: 0.0958, G loss: 4.1572
Epoch 109: Train D loss: 0.0344, G loss: 4.9555
Epoch 110: Train D loss: 0.8575, G loss: 2.8936
Epoch 111: Train D loss: 0.4397, G loss: 3.1297
Epoch 112: Train D loss: 0.3955, G loss: 3.4052
Epoch 113: Train D loss: 0.0928, G loss: 4.1257
Epoch 114: Train D loss: 0.0494, G loss: 4.5448
Epoch 115: Train D loss: 0.0411, G loss: 4.7861
Epoch 116: Train D loss: 0.0321, G loss: 5.1476
Epoch 117: Train D loss: 0.0227, G loss: 5.2891
Epoch 118: Train D loss: 0.0228, G loss: 5.3384
Epoch 119: Train D loss: 0.5538, G loss: 3.7886
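- What if the class labels are not one-hot vectors? Instead, we first generate a random vector for each class and use that vector as the class label. Does this change the results? Run the code below and compare with the previous results: what is different?
Answer: Compared with the previous results, the loss curves follow roughly the same overall trend, but the discriminator takes longer to dominate: from epoch 1 to about epoch 55 the D loss stays between roughly 0.94 and 1.25 (G loss between roughly 1.1 and 1.8), and only afterwards does the D loss drop toward 0.03 to 0.1 while the G loss rises to about 4 to 5. Judging from how the generated images evolve, training is faster and works well from the start: the digit shapes are already roughly recognizable in the first sample image, which was not the case in the earlier experiments, and the final generated images are very sharp.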
vecs = torch.randn(class_num, class_num)
fills = vecs.unsqueeze(2).unsqueeze(3).expand(class_num, class_num, image_size, image_size)
print(vecs)
print(fills)
tensor([[ 1.1567, -0.9291, 0.8017, 0.1191, -0.6171, 0.2650, 1.4880, -1.4616,
-0.1435, -1.6340],
[-0.7911, -0.1545, 0.5527, -0.1140, -0.4090, -1.2365, 1.1262, -0.9392,
0.2391, -0.8417],
[ 0.7116, 1.2442, -0.5190, -0.9552, -0.7486, 0.3997, -0.4397, -0.0039,
-0.0925, 0.3558],
[ 0.3213, -0.1969, 0.1278, 1.2716, -0.4009, -0.5936, -0.4486, -0.4744,
-0.1520, 0.1896],
[ 0.6892, -0.4903, 1.0817, 1.0543, -1.6935, -0.2287, 0.1058, 1.6348,
-1.6293, -0.7025],
[-0.0451, -1.3523, 1.1683, -0.5997, -1.0793, 1.0965, -2.0173, 1.8741,
-0.8195, -0.9508],
[-0.1799, 1.3447, -1.4748, -0.5927, 0.1918, -0.0547, -0.3212, -1.8754,
-0.5544, 0.5947],
[ 0.2432, 0.9638, 0.7930, 1.0026, 0.0817, 0.4393, -0.2386, -0.4549,
-0.7598, 1.0468],
[ 0.4202, -0.8265, -0.4051, 1.3794, -0.4501, 0.7389, -0.4055, -0.5978,
-0.0337, -0.7161],
[ 1.1576, -0.0746, -0.2538, 0.3206, -1.9917, 0.0155, 1.6934, -0.9318,
-0.3510, -0.7363]])
tensor([[[[ 1.1567, 1.1567, 1.1567, ..., 1.1567, 1.1567, 1.1567],
[ 1.1567, 1.1567, 1.1567, ..., 1.1567, 1.1567, 1.1567],
[ 1.1567, 1.1567, 1.1567, ..., 1.1567, 1.1567, 1.1567],
...,
[ 1.1567, 1.1567, 1.1567, ..., 1.1567, 1.1567, 1.1567],
[ 1.1567, 1.1567, 1.1567, ..., 1.1567, 1.1567, 1.1567],
[ 1.1567, 1.1567, 1.1567, ..., 1.1567, 1.1567, 1.1567]],
[[-0.9291, -0.9291, -0.9291, ..., -0.9291, -0.9291, -0.9291],
[-0.9291, -0.9291, -0.9291, ..., -0.9291, -0.9291, -0.9291],
[-0.9291, -0.9291, -0.9291, ..., -0.9291, -0.9291, -0.9291],
...,
[-0.9291, -0.9291, -0.9291, ..., -0.9291, -0.9291, -0.9291],
[-0.9291, -0.9291, -0.9291, ..., -0.9291, -0.9291, -0.9291],
[-0.9291, -0.9291, -0.9291, ..., -0.9291, -0.9291, -0.9291]],
[[ 0.8017, 0.8017, 0.8017, ..., 0.8017, 0.8017, 0.8017],
[ 0.8017, 0.8017, 0.8017, ..., 0.8017, 0.8017, 0.8017],
[ 0.8017, 0.8017, 0.8017, ..., 0.8017, 0.8017, 0.8017],
...,
[ 0.8017, 0.8017, 0.8017, ..., 0.8017, 0.8017, 0.8017],
[ 0.8017, 0.8017, 0.8017, ..., 0.8017, 0.8017, 0.8017],
[ 0.8017, 0.8017, 0.8017, ..., 0.8017, 0.8017, 0.8017]],
...,
[[-1.4616, -1.4616, -1.4616, ..., -1.4616, -1.4616, -1.4616],
[-1.4616, -1.4616, -1.4616, ..., -1.4616, -1.4616, -1.4616],
[-1.4616, -1.4616, -1.4616, ..., -1.4616, -1.4616, -1.4616],
...,
[-1.4616, -1.4616, -1.4616, ..., -1.4616, -1.4616, -1.4616],
[-1.4616, -1.4616, -1.4616, ..., -1.4616, -1.4616, -1.4616],
[-1.4616, -1.4616, -1.4616, ..., -1.4616, -1.4616, -1.4616]],
[[-0.1435, -0.1435, -0.1435, ..., -0.1435, -0.1435, -0.1435],
[-0.1435, -0.1435, -0.1435, ..., -0.1435, -0.1435, -0.1435],
[-0.1435, -0.1435, -0.1435, ..., -0.1435, -0.1435, -0.1435],
...,
[-0.1435, -0.1435, -0.1435, ..., -0.1435, -0.1435, -0.1435],
[-0.1435, -0.1435, -0.1435, ..., -0.1435, -0.1435, -0.1435],
[-0.1435, -0.1435, -0.1435, ..., -0.1435, -0.1435, -0.1435]],
[[-1.6340, -1.6340, -1.6340, ..., -1.6340, -1.6340, -1.6340],
[-1.6340, -1.6340, -1.6340, ..., -1.6340, -1.6340, -1.6340],
[-1.6340, -1.6340, -1.6340, ..., -1.6340, -1.6340, -1.6340],
...,
[-1.6340, -1.6340, -1.6340, ..., -1.6340, -1.6340, -1.6340],
[-1.6340, -1.6340, -1.6340, ..., -1.6340, -1.6340, -1.6340],
[-1.6340, -1.6340, -1.6340, ..., -1.6340, -1.6340, -1.6340]]],
[[[-0.7911, -0.7911, -0.7911, ..., -0.7911, -0.7911, -0.7911],
[-0.7911, -0.7911, -0.7911, ..., -0.7911, -0.7911, -0.7911],
[-0.7911, -0.7911, -0.7911, ..., -0.7911, -0.7911, -0.7911],
...,
[-0.7911, -0.7911, -0.7911, ..., -0.7911, -0.7911, -0.7911],
[-0.7911, -0.7911, -0.7911, ..., -0.7911, -0.7911, -0.7911],
[-0.7911, -0.7911, -0.7911, ..., -0.7911, -0.7911, -0.7911]],
[[-0.1545, -0.1545, -0.1545, ..., -0.1545, -0.1545, -0.1545],
[-0.1545, -0.1545, -0.1545, ..., -0.1545, -0.1545, -0.1545],
[-0.1545, -0.1545, -0.1545, ..., -0.1545, -0.1545, -0.1545],
...,
[-0.1545, -0.1545, -0.1545, ..., -0.1545, -0.1545, -0.1545],
[-0.1545, -0.1545, -0.1545, ..., -0.1545, -0.1545, -0.1545],
[-0.1545, -0.1545, -0.1545, ..., -0.1545, -0.1545, -0.1545]],
[[ 0.5527, 0.5527, 0.5527, ..., 0.5527, 0.5527, 0.5527],
[ 0.5527, 0.5527, 0.5527, ..., 0.5527, 0.5527, 0.5527],
[ 0.5527, 0.5527, 0.5527, ..., 0.5527, 0.5527, 0.5527],
...,
[ 0.5527, 0.5527, 0.5527, ..., 0.5527, 0.5527, 0.5527],
[ 0.5527, 0.5527, 0.5527, ..., 0.5527, 0.5527, 0.5527],
[ 0.5527, 0.5527, 0.5527, ..., 0.5527, 0.5527, 0.5527]],
...,
[[-0.9392, -0.9392, -0.9392, ..., -0.9392, -0.9392, -0.9392],
[-0.9392, -0.9392, -0.9392, ..., -0.9392, -0.9392, -0.9392],
[-0.9392, -0.9392, -0.9392, ..., -0.9392, -0.9392, -0.9392],
...,
[-0.9392, -0.9392, -0.9392, ..., -0.9392, -0.9392, -0.9392],
[-0.9392, -0.9392, -0.9392, ..., -0.9392, -0.9392, -0.9392],
[-0.9392, -0.9392, -0.9392, ..., -0.9392, -0.9392, -0.9392]],
[[ 0.2391, 0.2391, 0.2391, ..., 0.2391, 0.2391, 0.2391],
[ 0.2391, 0.2391, 0.2391, ..., 0.2391, 0.2391, 0.2391],
[ 0.2391, 0.2391, 0.2391, ..., 0.2391, 0.2391, 0.2391],
...,
[ 0.2391, 0.2391, 0.2391, ..., 0.2391, 0.2391, 0.2391],
[ 0.2391, 0.2391, 0.2391, ..., 0.2391, 0.2391, 0.2391],
[ 0.2391, 0.2391, 0.2391, ..., 0.2391, 0.2391, 0.2391]],
[[-0.8417, -0.8417, -0.8417, ..., -0.8417, -0.8417, -0.8417],
[-0.8417, -0.8417, -0.8417, ..., -0.8417, -0.8417, -0.8417],
[-0.8417, -0.8417, -0.8417, ..., -0.8417, -0.8417, -0.8417],
...,
[-0.8417, -0.8417, -0.8417, ..., -0.8417, -0.8417, -0.8417],
[-0.8417, -0.8417, -0.8417, ..., -0.8417, -0.8417, -0.8417],
[-0.8417, -0.8417, -0.8417, ..., -0.8417, -0.8417, -0.8417]]],
[[[ 0.7116, 0.7116, 0.7116, ..., 0.7116, 0.7116, 0.7116],
[ 0.7116, 0.7116, 0.7116, ..., 0.7116, 0.7116, 0.7116],
[ 0.7116, 0.7116, 0.7116, ..., 0.7116, 0.7116, 0.7116],
...,
[ 0.7116, 0.7116, 0.7116, ..., 0.7116, 0.7116, 0.7116],
[ 0.7116, 0.7116, 0.7116, ..., 0.7116, 0.7116, 0.7116],
[ 0.7116, 0.7116, 0.7116, ..., 0.7116, 0.7116, 0.7116]],
[[ 1.2442, 1.2442, 1.2442, ..., 1.2442, 1.2442, 1.2442],
[ 1.2442, 1.2442, 1.2442, ..., 1.2442, 1.2442, 1.2442],
[ 1.2442, 1.2442, 1.2442, ..., 1.2442, 1.2442, 1.2442],
...,
[ 1.2442, 1.2442, 1.2442, ..., 1.2442, 1.2442, 1.2442],
[ 1.2442, 1.2442, 1.2442, ..., 1.2442, 1.2442, 1.2442],
[ 1.2442, 1.2442, 1.2442, ..., 1.2442, 1.2442, 1.2442]],
[[-0.5190, -0.5190, -0.5190, ..., -0.5190, -0.5190, -0.5190],
[-0.5190, -0.5190, -0.5190, ..., -0.5190, -0.5190, -0.5190],
[-0.5190, -0.5190, -0.5190, ..., -0.5190, -0.5190, -0.5190],
...,
[-0.5190, -0.5190, -0.5190, ..., -0.5190, -0.5190, -0.5190],
[-0.5190, -0.5190, -0.5190, ..., -0.5190, -0.5190, -0.5190],
[-0.5190, -0.5190, -0.5190, ..., -0.5190, -0.5190, -0.5190]],
...,
[[-0.0039, -0.0039, -0.0039, ..., -0.0039, -0.0039, -0.0039],
[-0.0039, -0.0039, -0.0039, ..., -0.0039, -0.0039, -0.0039],
[-0.0039, -0.0039, -0.0039, ..., -0.0039, -0.0039, -0.0039],
...,
[-0.0039, -0.0039, -0.0039, ..., -0.0039, -0.0039, -0.0039],
[-0.0039, -0.0039, -0.0039, ..., -0.0039, -0.0039, -0.0039],
[-0.0039, -0.0039, -0.0039, ..., -0.0039, -0.0039, -0.0039]],
[[-0.0925, -0.0925, -0.0925, ..., -0.0925, -0.0925, -0.0925],
[-0.0925, -0.0925, -0.0925, ..., -0.0925, -0.0925, -0.0925],
[-0.0925, -0.0925, -0.0925, ..., -0.0925, -0.0925, -0.0925],
...,
[-0.0925, -0.0925, -0.0925, ..., -0.0925, -0.0925, -0.0925],
[-0.0925, -0.0925, -0.0925, ..., -0.0925, -0.0925, -0.0925],
[-0.0925, -0.0925, -0.0925, ..., -0.0925, -0.0925, -0.0925]],
[[ 0.3558, 0.3558, 0.3558, ..., 0.3558, 0.3558, 0.3558],
[ 0.3558, 0.3558, 0.3558, ..., 0.3558, 0.3558, 0.3558],
[ 0.3558, 0.3558, 0.3558, ..., 0.3558, 0.3558, 0.3558],
...,
[ 0.3558, 0.3558, 0.3558, ..., 0.3558, 0.3558, 0.3558],
[ 0.3558, 0.3558, 0.3558, ..., 0.3558, 0.3558, 0.3558],
[ 0.3558, 0.3558, 0.3558, ..., 0.3558, 0.3558, 0.3558]]],
...,
[[[ 0.2432, 0.2432, 0.2432, ..., 0.2432, 0.2432, 0.2432],
[ 0.2432, 0.2432, 0.2432, ..., 0.2432, 0.2432, 0.2432],
[ 0.2432, 0.2432, 0.2432, ..., 0.2432, 0.2432, 0.2432],
...,
[ 0.2432, 0.2432, 0.2432, ..., 0.2432, 0.2432, 0.2432],
[ 0.2432, 0.2432, 0.2432, ..., 0.2432, 0.2432, 0.2432],
[ 0.2432, 0.2432, 0.2432, ..., 0.2432, 0.2432, 0.2432]],
[[ 0.9638, 0.9638, 0.9638, ..., 0.9638, 0.9638, 0.9638],
[ 0.9638, 0.9638, 0.9638, ..., 0.9638, 0.9638, 0.9638],
[ 0.9638, 0.9638, 0.9638, ..., 0.9638, 0.9638, 0.9638],
...,
[ 0.9638, 0.9638, 0.9638, ..., 0.9638, 0.9638, 0.9638],
[ 0.9638, 0.9638, 0.9638, ..., 0.9638, 0.9638, 0.9638],
[ 0.9638, 0.9638, 0.9638, ..., 0.9638, 0.9638, 0.9638]],
[[ 0.7930, 0.7930, 0.7930, ..., 0.7930, 0.7930, 0.7930],
[ 0.7930, 0.7930, 0.7930, ..., 0.7930, 0.7930, 0.7930],
[ 0.7930, 0.7930, 0.7930, ..., 0.7930, 0.7930, 0.7930],
...,
[ 0.7930, 0.7930, 0.7930, ..., 0.7930, 0.7930, 0.7930],
[ 0.7930, 0.7930, 0.7930, ..., 0.7930, 0.7930, 0.7930],
[ 0.7930, 0.7930, 0.7930, ..., 0.7930, 0.7930, 0.7930]],
...,
[[-0.4549, -0.4549, -0.4549, ..., -0.4549, -0.4549, -0.4549],
[-0.4549, -0.4549, -0.4549, ..., -0.4549, -0.4549, -0.4549],
[-0.4549, -0.4549, -0.4549, ..., -0.4549, -0.4549, -0.4549],
...,
[-0.4549, -0.4549, -0.4549, ..., -0.4549, -0.4549, -0.4549],
[-0.4549, -0.4549, -0.4549, ..., -0.4549, -0.4549, -0.4549],
[-0.4549, -0.4549, -0.4549, ..., -0.4549, -0.4549, -0.4549]],
[[-0.7598, -0.7598, -0.7598, ..., -0.7598, -0.7598, -0.7598],
[-0.7598, -0.7598, -0.7598, ..., -0.7598, -0.7598, -0.7598],
[-0.7598, -0.7598, -0.7598, ..., -0.7598, -0.7598, -0.7598],
...,
[-0.7598, -0.7598, -0.7598, ..., -0.7598, -0.7598, -0.7598],
[-0.7598, -0.7598, -0.7598, ..., -0.7598, -0.7598, -0.7598],
[-0.7598, -0.7598, -0.7598, ..., -0.7598, -0.7598, -0.7598]],
[[ 1.0468, 1.0468, 1.0468, ..., 1.0468, 1.0468, 1.0468],
[ 1.0468, 1.0468, 1.0468, ..., 1.0468, 1.0468, 1.0468],
[ 1.0468, 1.0468, 1.0468, ..., 1.0468, 1.0468, 1.0468],
...,
[ 1.0468, 1.0468, 1.0468, ..., 1.0468, 1.0468, 1.0468],
[ 1.0468, 1.0468, 1.0468, ..., 1.0468, 1.0468, 1.0468],
[ 1.0468, 1.0468, 1.0468, ..., 1.0468, 1.0468, 1.0468]]],
[[[ 0.4202, 0.4202, 0.4202, ..., 0.4202, 0.4202, 0.4202],
[ 0.4202, 0.4202, 0.4202, ..., 0.4202, 0.4202, 0.4202],
[ 0.4202, 0.4202, 0.4202, ..., 0.4202, 0.4202, 0.4202],
...,
[ 0.4202, 0.4202, 0.4202, ..., 0.4202, 0.4202, 0.4202],
[ 0.4202, 0.4202, 0.4202, ..., 0.4202, 0.4202, 0.4202],
[ 0.4202, 0.4202, 0.4202, ..., 0.4202, 0.4202, 0.4202]],
[[-0.8265, -0.8265, -0.8265, ..., -0.8265, -0.8265, -0.8265],
[-0.8265, -0.8265, -0.8265, ..., -0.8265, -0.8265, -0.8265],
[-0.8265, -0.8265, -0.8265, ..., -0.8265, -0.8265, -0.8265],
...,
[-0.8265, -0.8265, -0.8265, ..., -0.8265, -0.8265, -0.8265],
[-0.8265, -0.8265, -0.8265, ..., -0.8265, -0.8265, -0.8265],
[-0.8265, -0.8265, -0.8265, ..., -0.8265, -0.8265, -0.8265]],
[[-0.4051, -0.4051, -0.4051, ..., -0.4051, -0.4051, -0.4051],
[-0.4051, -0.4051, -0.4051, ..., -0.4051, -0.4051, -0.4051],
[-0.4051, -0.4051, -0.4051, ..., -0.4051, -0.4051, -0.4051],
...,
[-0.4051, -0.4051, -0.4051, ..., -0.4051, -0.4051, -0.4051],
[-0.4051, -0.4051, -0.4051, ..., -0.4051, -0.4051, -0.4051],
[-0.4051, -0.4051, -0.4051, ..., -0.4051, -0.4051, -0.4051]],
...,
[[-0.5978, -0.5978, -0.5978, ..., -0.5978, -0.5978, -0.5978],
[-0.5978, -0.5978, -0.5978, ..., -0.5978, -0.5978, -0.5978],
[-0.5978, -0.5978, -0.5978, ..., -0.5978, -0.5978, -0.5978],
...,
[-0.5978, -0.5978, -0.5978, ..., -0.5978, -0.5978, -0.5978],
[-0.5978, -0.5978, -0.5978, ..., -0.5978, -0.5978, -0.5978],
[-0.5978, -0.5978, -0.5978, ..., -0.5978, -0.5978, -0.5978]],
[[-0.0337, -0.0337, -0.0337, ..., -0.0337, -0.0337, -0.0337],
[-0.0337, -0.0337, -0.0337, ..., -0.0337, -0.0337, -0.0337],
[-0.0337, -0.0337, -0.0337, ..., -0.0337, -0.0337, -0.0337],
...,
[-0.0337, -0.0337, -0.0337, ..., -0.0337, -0.0337, -0.0337],
[-0.0337, -0.0337, -0.0337, ..., -0.0337, -0.0337, -0.0337],
[-0.0337, -0.0337, -0.0337, ..., -0.0337, -0.0337, -0.0337]],
[[-0.7161, -0.7161, -0.7161, ..., -0.7161, -0.7161, -0.7161],
[-0.7161, -0.7161, -0.7161, ..., -0.7161, -0.7161, -0.7161],
[-0.7161, -0.7161, -0.7161, ..., -0.7161, -0.7161, -0.7161],
...,
[-0.7161, -0.7161, -0.7161, ..., -0.7161, -0.7161, -0.7161],
[-0.7161, -0.7161, -0.7161, ..., -0.7161, -0.7161, -0.7161],
[-0.7161, -0.7161, -0.7161, ..., -0.7161, -0.7161, -0.7161]]],
[[[ 1.1576, 1.1576, 1.1576, ..., 1.1576, 1.1576, 1.1576],
[ 1.1576, 1.1576, 1.1576, ..., 1.1576, 1.1576, 1.1576],
[ 1.1576, 1.1576, 1.1576, ..., 1.1576, 1.1576, 1.1576],
...,
[ 1.1576, 1.1576, 1.1576, ..., 1.1576, 1.1576, 1.1576],
[ 1.1576, 1.1576, 1.1576, ..., 1.1576, 1.1576, 1.1576],
[ 1.1576, 1.1576, 1.1576, ..., 1.1576, 1.1576, 1.1576]],
[[-0.0746, -0.0746, -0.0746, ..., -0.0746, -0.0746, -0.0746],
[-0.0746, -0.0746, -0.0746, ..., -0.0746, -0.0746, -0.0746],
[-0.0746, -0.0746, -0.0746, ..., -0.0746, -0.0746, -0.0746],
...,
[-0.0746, -0.0746, -0.0746, ..., -0.0746, -0.0746, -0.0746],
[-0.0746, -0.0746, -0.0746, ..., -0.0746, -0.0746, -0.0746],
[-0.0746, -0.0746, -0.0746, ..., -0.0746, -0.0746, -0.0746]],
[[-0.2538, -0.2538, -0.2538, ..., -0.2538, -0.2538, -0.2538],
[-0.2538, -0.2538, -0.2538, ..., -0.2538, -0.2538, -0.2538],
[-0.2538, -0.2538, -0.2538, ..., -0.2538, -0.2538, -0.2538],
...,
[-0.2538, -0.2538, -0.2538, ..., -0.2538, -0.2538, -0.2538],
[-0.2538, -0.2538, -0.2538, ..., -0.2538, -0.2538, -0.2538],
[-0.2538, -0.2538, -0.2538, ..., -0.2538, -0.2538, -0.2538]],
...,
[[-0.9318, -0.9318, -0.9318, ..., -0.9318, -0.9318, -0.9318],
[-0.9318, -0.9318, -0.9318, ..., -0.9318, -0.9318, -0.9318],
[-0.9318, -0.9318, -0.9318, ..., -0.9318, -0.9318, -0.9318],
...,
[-0.9318, -0.9318, -0.9318, ..., -0.9318, -0.9318, -0.9318],
[-0.9318, -0.9318, -0.9318, ..., -0.9318, -0.9318, -0.9318],
[-0.9318, -0.9318, -0.9318, ..., -0.9318, -0.9318, -0.9318]],
[[-0.3510, -0.3510, -0.3510, ..., -0.3510, -0.3510, -0.3510],
[-0.3510, -0.3510, -0.3510, ..., -0.3510, -0.3510, -0.3510],
[-0.3510, -0.3510, -0.3510, ..., -0.3510, -0.3510, -0.3510],
...,
[-0.3510, -0.3510, -0.3510, ..., -0.3510, -0.3510, -0.3510],
[-0.3510, -0.3510, -0.3510, ..., -0.3510, -0.3510, -0.3510],
[-0.3510, -0.3510, -0.3510, ..., -0.3510, -0.3510, -0.3510]],
[[-0.7363, -0.7363, -0.7363, ..., -0.7363, -0.7363, -0.7363],
[-0.7363, -0.7363, -0.7363, ..., -0.7363, -0.7363, -0.7363],
[-0.7363, -0.7363, -0.7363, ..., -0.7363, -0.7363, -0.7363],
...,
[-0.7363, -0.7363, -0.7363, ..., -0.7363, -0.7363, -0.7363],
[-0.7363, -0.7363, -0.7363, ..., -0.7363, -0.7363, -0.7363],
[-0.7363, -0.7363, -0.7363, ..., -0.7363, -0.7363, -0.7363]]]])
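The printout shows that row c of `vecs` is the random code for class c, and that `fills[c]` broadcasts that code over an image-sized map. A minimal sketch of how such codes could be looked up per sample follows. Indexing by the label tensor is an assumption, since `run_gan`, which consumes `vecs` and `fills`, is not shown here; `image_size = 32` and the mini-batch of labels are likewise assumed.

```python
import torch

torch.manual_seed(0)
class_num, image_size = 10, 32
# one fixed random code per class, replacing the one-hot vectors
vecs = torch.randn(class_num, class_num)
fills = vecs.unsqueeze(2).unsqueeze(3).expand(class_num, class_num, image_size, image_size)

labels = torch.tensor([2, 2, 5])   # hypothetical mini-batch of class indices
g_cond = vecs[labels]              # (3, class_num): concatenated with z in G
d_cond = fills[labels]             # (3, class_num, 32, 32): concatenated with the image in D
print(g_cond.shape, d_cond.shape)  # torch.Size([3, 10]) torch.Size([3, 10, 32, 32])
```

Unlike one-hot vectors, random Gaussian codes are neither orthogonal nor limited to {0, 1}. As long as each class keeps the same distinct code throughout training, G and D can still learn to associate a code with its class.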
# hyper params
# device : cpu or cuda:0/1/2/3
device = torch.device('cuda:0')
# G and D model
G = DCGenerator(image_size=image_size, latent_dim=latent_dim, output_channel=image_channel, class_num=class_num)
D = DCDiscriminator(image_size=image_size, input_channel=image_channel, class_num=class_num)
G.to(device)
D.to(device)
# G and D optimizer, use Adam or SGD
G_optimizer = optim.Adam(G.parameters(), lr=learning_rate, betas=betas)
D_optimizer = optim.Adam(D.parameters(), lr=learning_rate, betas=betas)
d_loss_hist, g_loss_hist = run_gan(trainloader, G, D, G_optimizer, D_optimizer, bceloss,
n_epochs, device, latent_dim, class_num)
loss_plot(d_loss_hist, g_loss_hist)
Epoch 0: Train D loss: 0.7260, G loss: 2.1572
Epoch 1: Train D loss: 0.9390, G loss: 1.7824
Epoch 2: Train D loss: 1.0395, G loss: 1.7254
Epoch 3: Train D loss: 1.1483, G loss: 1.4995
Epoch 4: Train D loss: 1.1209, G loss: 1.3907
Epoch 5: Train D loss: 1.1032, G loss: 1.4722
Epoch 6: Train D loss: 1.0763, G loss: 1.4546
Epoch 7: Train D loss: 1.0108, G loss: 1.5468
Epoch 8: Train D loss: 1.0263, G loss: 1.5239
Epoch 9: Train D loss: 1.0970, G loss: 1.4444
Epoch 10: Train D loss: 1.1275, G loss: 1.3871
Epoch 11: Train D loss: 1.0939, G loss: 1.3838
Epoch 12: Train D loss: 1.0969, G loss: 1.3490
Epoch 13: Train D loss: 1.0826, G loss: 1.3967
Epoch 14: Train D loss: 1.1435, G loss: 1.3215
Epoch 15: Train D loss: 1.1192, G loss: 1.3385
Epoch 16: Train D loss: 1.1218, G loss: 1.3073
Epoch 17: Train D loss: 1.1673, G loss: 1.2399
Epoch 18: Train D loss: 1.1879, G loss: 1.2262
Epoch 19: Train D loss: 1.1955, G loss: 1.2235
Epoch 20: Train D loss: 1.2040, G loss: 1.1830
Epoch 21: Train D loss: 1.2068, G loss: 1.1786
Epoch 22: Train D loss: 1.2297, G loss: 1.1382
Epoch 23: Train D loss: 1.2207, G loss: 1.1666
Epoch 24: Train D loss: 1.2436, G loss: 1.1467
Epoch 25: Train D loss: 1.2206, G loss: 1.1358
Epoch 26: Train D loss: 1.2438, G loss: 1.1182
Epoch 27: Train D loss: 1.2187, G loss: 1.1273
Epoch 28: Train D loss: 1.2356, G loss: 1.1087
Epoch 29: Train D loss: 1.2386, G loss: 1.1199
Epoch 30: Train D loss: 1.2112, G loss: 1.1312
Epoch 31: Train D loss: 1.2411, G loss: 1.1433
Epoch 32: Train D loss: 1.2233, G loss: 1.1320
Epoch 33: Train D loss: 1.2163, G loss: 1.1253
Epoch 34: Train D loss: 1.2149, G loss: 1.1394
Epoch 35: Train D loss: 1.2262, G loss: 1.1641
Epoch 36: Train D loss: 1.2044, G loss: 1.1631
Epoch 37: Train D loss: 1.2155, G loss: 1.1475
Epoch 38: Train D loss: 1.1877, G loss: 1.1697
Epoch 39: Train D loss: 1.1980, G loss: 1.1835
Epoch 40: Train D loss: 1.2010, G loss: 1.1685
Epoch 41: Train D loss: 1.1966, G loss: 1.1575
Epoch 42: Train D loss: 1.1946, G loss: 1.2079
Epoch 43: Train D loss: 1.1545, G loss: 1.2180
Epoch 44: Train D loss: 1.1543, G loss: 1.2154
Epoch 45: Train D loss: 1.1398, G loss: 1.2518
Epoch 46: Train D loss: 1.1400, G loss: 1.2798
Epoch 47: Train D loss: 1.1451, G loss: 1.2790
Epoch 48: Train D loss: 1.1083, G loss: 1.2797
Epoch 49: Train D loss: 1.0828, G loss: 1.3435
Epoch 50: Train D loss: 1.0941, G loss: 1.3757
Epoch 51: Train D loss: 1.0729, G loss: 1.3664
Epoch 52: Train D loss: 1.0801, G loss: 1.4018
Epoch 53: Train D loss: 1.0361, G loss: 1.4298
Epoch 54: Train D loss: 0.9954, G loss: 1.4514
Epoch 55: Train D loss: 1.0083, G loss: 1.4741
Epoch 56: Train D loss: 0.9435, G loss: 1.5283
Epoch 57: Train D loss: 0.9614, G loss: 1.6080
Epoch 58: Train D loss: 1.0008, G loss: 1.6005
Epoch 59: Train D loss: 0.8756, G loss: 1.6289
Epoch 60: Train D loss: 0.9253, G loss: 1.6770
Epoch 61: Train D loss: 0.8097, G loss: 1.7874
Epoch 62: Train D loss: 0.8168, G loss: 1.8690
Epoch 63: Train D loss: 0.8487, G loss: 1.8578
Epoch 64: Train D loss: 0.8003, G loss: 1.9216
Epoch 65: Train D loss: 0.8222, G loss: 1.8820
Epoch 66: Train D loss: 0.7529, G loss: 1.9991
Epoch 67: Train D loss: 0.7870, G loss: 1.9942
Epoch 68: Train D loss: 0.6070, G loss: 2.1499
Epoch 69: Train D loss: 0.6849, G loss: 2.1850
Epoch 70: Train D loss: 0.5645, G loss: 2.3048
Epoch 71: Train D loss: 0.6525, G loss: 2.3541
Epoch 72: Train D loss: 0.5725, G loss: 2.3339
Epoch 73: Train D loss: 0.4690, G loss: 2.5266
Epoch 74: Train D loss: 0.4421, G loss: 2.6831
Epoch 75: Train D loss: 0.4909, G loss: 2.7405
Epoch 76: Train D loss: 0.4755, G loss: 2.6397
Epoch 77: Train D loss: 0.2924, G loss: 2.9268
Epoch 78: Train D loss: 0.3036, G loss: 3.0787
Epoch 79: Train D loss: 0.1643, G loss: 3.3467
Epoch 80: Train D loss: 0.4594, G loss: 3.1674
Epoch 81: Train D loss: 0.2945, G loss: 3.1968
Epoch 82: Train D loss: 0.1721, G loss: 3.4859
Epoch 83: Train D loss: 0.6296, G loss: 3.0034
Epoch 84: Train D loss: 0.1400, G loss: 3.5678
Epoch 85: Train D loss: 0.2866, G loss: 3.6853
Epoch 86: Train D loss: 0.3567, G loss: 3.3547
Epoch 87: Train D loss: 0.1112, G loss: 3.7273
Epoch 88: Train D loss: 0.0861, G loss: 4.0452
Epoch 89: Train D loss: 0.0946, G loss: 4.2372
Epoch 90: Train D loss: 0.8998, G loss: 2.6816
Epoch 91: Train D loss: 0.1990, G loss: 3.7121
Epoch 92: Train D loss: 0.4284, G loss: 3.2418
Epoch 93: Train D loss: 0.0850, G loss: 3.9878
Epoch 94: Train D loss: 0.0668, G loss: 4.1930
Epoch 95: Train D loss: 0.0567, G loss: 4.3359
Epoch 96: Train D loss: 0.0570, G loss: 4.4259
Epoch 97: Train D loss: 0.0474, G loss: 4.6437
Epoch 98: Train D loss: 0.9481, G loss: 2.7850
Epoch 99: Train D loss: 0.4302, G loss: 3.4014
Epoch 100: Train D loss: 0.0989, G loss: 3.9864
Epoch 101: Train D loss: 0.0609, G loss: 4.3214
Epoch 102: Train D loss: 0.0498, G loss: 4.5487
Epoch 103: Train D loss: 0.0451, G loss: 4.6218
Epoch 104: Train D loss: 0.0464, G loss: 4.7520
Epoch 105: Train D loss: 0.0371, G loss: 4.8745
Epoch 106: Train D loss: 0.1983, G loss: 4.9354
Epoch 107: Train D loss: 1.0577, G loss: 1.8109
Epoch 108: Train D loss: 0.7774, G loss: 3.1034
Epoch 109: Train D loss: 0.5359, G loss: 3.3039
Epoch 110: Train D loss: 0.2023, G loss: 3.8176
Epoch 111: Train D loss: 0.0711, G loss: 4.3033
Epoch 112: Train D loss: 0.0551, G loss: 4.4357
Epoch 113: Train D loss: 0.0455, G loss: 4.6138
Epoch 114: Train D loss: 0.7499, G loss: 3.3320
Epoch 115: Train D loss: 0.1707, G loss: 3.8742
Epoch 116: Train D loss: 0.0536, G loss: 4.4872
Epoch 117: Train D loss: 0.0465, G loss: 4.7260
Epoch 118: Train D loss: 0.0329, G loss: 4.9563
Epoch 119: Train D loss: 0.0338, G loss: 4.9434
Image-to-image translation
Next we introduce pix2pix, a model that uses a CGAN for image-to-image translation.
import os
import numpy as np
import math
import itertools
import time
import datetime
import sys
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import torch.nn as nn
import torch.nn.functional as F
import torch
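This lab uses the Facades dataset. Because of the dataset's format, each image contains two parts, as shown in the figure below: the left half is the ground truth and the right half is the outline. We therefore need to write our own dataset class, which the cell below implements. The goal is for the model to generate the building on the left from the outline on the right.
(Optional reading) The dataset code is shown below.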
import glob
import random
import os
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
class ImageDataset(Dataset):
    def __init__(self, root, transforms_=None, mode="train"):
        self.transform = transforms_
        # collect the image paths under root/mode/
        self.files = sorted(glob.glob(os.path.join(root, mode) + "/*.*"))
    def __getitem__(self, index):
        # split the image: the left half is the ground-truth image, the right half is its outline
        img = Image.open(self.files[index % len(self.files)])
        w, h = img.size
        img_B = img.crop((0, 0, w / 2, h))
        img_A = img.crop((w / 2, 0, w, h))
        if np.random.random() < 0.5:
            # flip both halves horizontally with probability 0.5 (data augmentation)
            img_A = Image.fromarray(np.array(img_A)[:, ::-1, :], "RGB")
            img_B = Image.fromarray(np.array(img_B)[:, ::-1, :], "RGB")
        img_A = self.transform(img_A)
        img_B = self.transform(img_B)
        return {"A": img_A, "B": img_B}
    def __len__(self):
        return len(self.files)
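To see what `__getitem__` does without downloading Facades, the sketch below runs the same crop-and-flip logic on a tiny synthetic 8x4 image whose green channel increases from left to right. `split_pair` is a hypothetical helper that mirrors the code above without the random draw and the transform. The point it illustrates is that A and B are flipped with the same decision, so the outline and the photo stay aligned.

```python
import numpy as np
from PIL import Image

def split_pair(img, flip):
    """Hypothetical mirror of ImageDataset.__getitem__ without randomness or transforms."""
    w, h = img.size
    img_B = img.crop((0, 0, w // 2, h))   # left half: ground truth
    img_A = img.crop((w // 2, 0, w, h))   # right half: outline
    if flip:
        # flip both halves with the same decision so the pair stays aligned
        img_A = Image.fromarray(np.array(img_A)[:, ::-1, :])
        img_B = Image.fromarray(np.array(img_B)[:, ::-1, :])
    return img_A, img_B

# synthetic 8x4 RGB "pair": green channel = 0, 30, ..., 210 from left to right
arr = np.zeros((4, 8, 3), dtype=np.uint8)
arr[:, :, 1] = np.arange(8, dtype=np.uint8) * 30
img = Image.fromarray(arr)

A, B = split_pair(img, flip=False)
A_f, B_f = split_pair(img, flip=True)
print(np.array(A)[0, :, 1], np.array(A_f)[0, :, 1])  # [120 150 180 210] [210 180 150 120]
```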
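The generator G is an encoder-decoder model that borrows the U-Net structure: U-Net concatenates the output of layer i onto layer n-i, which works because layers i and n-i produce feature maps of the same spatial size. The discriminator D in pix2pix is implemented as a patch discriminator (PatchGAN): no matter how large the generated image is, D judges it as a set of fixed-size patches. In the code below this is done convolutionally: D outputs a grid of scores (4x4 for a 64x64 input), and each score rates one region of the input.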
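The claim that layers i and n-i have the same size can be checked with the output-size formulas of `nn.Conv2d` and `nn.ConvTranspose2d` (dilation 1, no output padding). This stdlib-only sketch plugs in the kernel 4, stride 2, padding 1 used by `UNetDown` and `UNetUp` below, assuming a 64x64 input as in the rest of this lab.

```python
def conv_out(size, k=4, s=2, p=1):
    """Output size of nn.Conv2d (dilation 1): floor((size + 2p - k) / s) + 1."""
    return (size + 2 * p - k) // s + 1

def deconv_out(size, k=4, s=2, p=1):
    """Output size of nn.ConvTranspose2d (dilation 1, output_padding 0)."""
    return (size - 1) * s - 2 * p + k

size = 64
enc = []                      # spatial sizes of d1 ... d6
for _ in range(6):
    size = conv_out(size)
    enc.append(size)

dec = []                      # sizes after the transposed conv in up1 ... up5
for _ in range(5):
    size = deconv_out(size)
    dec.append(size)

# up_i (i = 1..5) is concatenated with d_(6-i): d5, d4, d3, d2, d1
pairs = [(dec[i], enc[4 - i]) for i in range(5)]
# final block: Upsample x2, ZeroPad2d((1, 0, 1, 0)), Conv2d(k=4, stride 1, padding=1)
out = conv_out(2 * size + 1, k=4, s=1, p=1)
print(enc, dec, out)  # [32, 16, 8, 4, 2, 1] [2, 4, 8, 16, 32] 64
```

Every pair matches, which is exactly the condition `torch.cat((x, skip_input), 1)` in `UNetUp` needs.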
import torch.nn as nn
import torch.nn.functional as F
import torch
##############################
# U-NET
##############################
class UNetDown(nn.Module):
    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
        if normalize:
            # when the batch size is 1, BN is replaced by instance normalization
            layers.append(nn.InstanceNorm2d(out_size))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)
    def forward(self, x):
        return self.model(x)
class UNetUp(nn.Module):
    def __init__(self, in_size, out_size, dropout=0.0):
        super(UNetUp, self).__init__()
        layers = [
            nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
            # when the batch size is 1, BN is replaced by instance normalization
            nn.InstanceNorm2d(out_size),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)
    def forward(self, x, skip_input):
        x = self.model(x)
        # skip connection: concatenate with the encoder feature map along channels
        x = torch.cat((x, skip_input), 1)
        return x
class GeneratorUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(GeneratorUNet, self).__init__()
        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 256, dropout=0.5)
        self.down5 = UNetDown(256, 256, dropout=0.5)
        self.down6 = UNetDown(256, 256, normalize=False, dropout=0.5)
        self.up1 = UNetUp(256, 256, dropout=0.5)
        self.up2 = UNetUp(512, 256)
        self.up3 = UNetUp(512, 256)
        self.up4 = UNetUp(512, 128)
        self.up5 = UNetUp(256, 64)
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh(),
        )
    def forward(self, x):
        # U-Net generator with skip connections from encoder to decoder
        d1 = self.down1(x)     # 32x32
        d2 = self.down2(d1)    # 16x16
        d3 = self.down3(d2)    # 8x8
        d4 = self.down4(d3)    # 4x4
        d5 = self.down5(d4)    # 2x2
        d6 = self.down6(d5)    # 1x1
        u1 = self.up1(d6, d5)  # 2x2
        u2 = self.up2(u1, d4)  # 4x4
        u3 = self.up3(u2, d3)  # 8x8
        u4 = self.up4(u3, d2)  # 16x16
        u5 = self.up5(u4, d1)  # 32x32
        return self.final(u5)  # 64x64
##############################
# Discriminator
##############################
class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()
        def discriminator_block(in_filters, out_filters, normalization=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalization:
                # when the batch size is 1, BatchNorm is replaced by InstanceNorm
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        self.model = nn.Sequential(
            *discriminator_block(in_channels * 2, 64, normalization=False),  # 32x32
            *discriminator_block(64, 128),   # 16x16
            *discriminator_block(128, 256),  # 8x8
            *discriminator_block(256, 256),  # 4x4
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(256, 1, 4, padding=1, bias=False)  # 4x4 patch map
        )
    def forward(self, img_A, img_B):
        # Concatenate image and condition image by channels to produce input
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)
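The size comments in the two models above, and the `patch = (1, img_size // 16, img_size // 16)` used later, can be checked with a little arithmetic. The sketch below is plain Python. Its helper names are hypothetical, it copies the layer hyperparameters from the code above, and it assumes square inputs. It traces the spatial size through the Discriminator and through `GeneratorUNet.final`. It also computes the channel count that enters each decoder stage after the skip concatenation.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of nn.Conv2d: floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

def discriminator_output_size(img_size):
    """Trace the spatial size through the Discriminator defined above."""
    size = img_size
    for _ in range(4):  # four discriminator_block calls, each Conv2d(k=4, s=2, p=1)
        size = conv_out(size, kernel=4, stride=2, padding=1)
    size += 1  # ZeroPad2d((1, 0, 1, 0)) adds one row on top and one column on the left
    return conv_out(size, kernel=4, padding=1)  # final Conv2d(256, 1, 4, padding=1)

def generator_final_size(u5_size):
    """Trace GeneratorUNet.final: Upsample(x2) -> ZeroPad2d((1, 0, 1, 0)) -> Conv2d(k=4, p=1)."""
    return conv_out(u5_size * 2 + 1, kernel=4, padding=1)

def decoder_input_channels():
    """Channels after each skip concatenation: UNetUp output + encoder feature map."""
    up_outs = [256, 256, 256, 128, 64]  # out_size of up1 ... up5
    skips = [256, 256, 256, 128, 64]    # channels of d5, d4, d3, d2, d1
    # entries feed up2, up3, up4, up5 and the Conv2d(128, ...) in final, respectively
    return [u + s for u, s in zip(up_outs, skips)]

print(discriminator_output_size(64))  # 4, i.e. img_size // 16
print(generator_final_size(32))       # 64
print(decoder_input_channels())       # [512, 512, 512, 256, 128]
```

The channel list matches the `in_size` arguments of `up2` to `up5` (512, 512, 512, 256) and the 128 input channels of the final convolution.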
(Optional reading) The function below displays the outline image, the generated image and the ground truth together for comparison.
from utils import show
def sample_images(dataloader, G, device):
    """Displays a generated sample from the validation set"""
    imgs = next(iter(dataloader))
    real_A = imgs["A"].to(device)
    real_B = imgs["B"].to(device)
    with torch.no_grad():
        fake_B = G(real_A)
    # stack outline / generated / ground truth vertically (dim -2 is the height axis)
    img_sample = torch.cat((real_A, fake_B, real_B), -2)
    show(torchvision.utils.make_grid(img_sample.cpu(), nrow=5, normalize=True))
Next, define the hyperparameters, including lambda_pixel.
# Hyperparameters
n_epochs = 200
batch_size = 2
lr = 0.0002
img_size = 64
channels = 3
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
betas = (0.5, 0.999)
# Loss weight of L1 pixel-wise loss between translated image and real image
lambda_pixel = 1
The pix2pix loss function is the CGAN loss plus an L1 loss. The L1 term is multiplied by a coefficient lambda (lambda_pixel), which balances the weight of the two terms.
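The objective below is written out in the least-squares (MSE) form that the training loop in this notebook actually uses. It is read off the code, with $x$ standing for image A (the outline) and $y$ for image B (the real facade). Because `MSELoss` and `L1Loss` use mean reduction, each expectation also averages over all PatchGAN outputs or all pixels. The $\lVert\cdot\rVert_1$ term is therefore a mean absolute error rather than a summed norm.

```latex
\mathcal{L}_D = \tfrac{1}{2}\,\mathbb{E}_{x,y}\big[(D(x, y) - 1)^2\big]
              + \tfrac{1}{2}\,\mathbb{E}_{x}\big[D(x, G(x))^2\big]

\mathcal{L}_G = \mathbb{E}_{x}\big[(D(x, G(x)) - 1)^2\big]
              + \lambda\,\mathbb{E}_{x,y}\big[\lVert y - G(x)\rVert_1\big]
```

Here $\lambda$ corresponds to `lambda_pixel`, the $\tfrac{1}{2}$ factors to `loss_D = 0.5 * (loss_real + loss_fake)`, and the targets 1 and 0 to `real_label` and `fake_label`.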
The loss functions and optimizers are defined below. MSELoss is used as the GAN loss, i.e. the least-squares GAN (LSGAN) formulation.
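The sketch below shows what the two criteria compute. It uses random stand-in tensors with assumed shapes: a 4x4 PatchGAN output and 64x64 RGB images. The built-in losses are compared against the same formulas written by hand.

```python
import torch

torch.manual_seed(0)
criterion_GAN = torch.nn.MSELoss()      # LSGAN adversarial loss
criterion_pixelwise = torch.nn.L1Loss()
lambda_pixel = 1                         # same weight as in the hyperparameter cell

# Random stand-ins (assumed shapes): D outputs on 4x4 patches, images in [-1, 1]
pred_fake = torch.rand(2, 1, 4, 4)
pred_real = torch.rand(2, 1, 4, 4)
fake_B = torch.rand(2, 3, 64, 64) * 2 - 1
real_B = torch.rand(2, 3, 64, 64) * 2 - 1
real_label = torch.ones_like(pred_fake)
fake_label = torch.zeros_like(pred_fake)

# Generator objective: fool D on every patch + stay close to the ground truth
loss_G = criterion_GAN(pred_fake, real_label) + lambda_pixel * criterion_pixelwise(fake_B, real_B)
# Discriminator objective: real patches -> 1, fake patches -> 0
loss_D = 0.5 * (criterion_GAN(pred_real, real_label) + criterion_GAN(pred_fake, fake_label))

# The same quantities written out by hand (mean over every element)
manual_G = ((pred_fake - 1) ** 2).mean() + lambda_pixel * (fake_B - real_B).abs().mean()
manual_D = 0.5 * (((pred_real - 1) ** 2).mean() + (pred_fake ** 2).mean())
print(torch.allclose(loss_G, manual_G), torch.allclose(loss_D, manual_D))  # True True
```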
from PIL import Image
from torch.utils.data import DataLoader
from utils import weights_init_normal
# Loss functions
criterion_GAN = torch.nn.MSELoss().to(device)
criterion_pixelwise = torch.nn.L1Loss().to(device)
# Output shape of the PatchGAN discriminator: four stride-2 convolutions shrink img_size by 16
patch = (1, img_size // 16, img_size // 16)
# Initialize generator and discriminator
G = GeneratorUNet().to(device)
D = Discriminator().to(device)
G.apply(weights_init_normal)
D.apply(weights_init_normal)
optimizer_G = torch.optim.Adam(G.parameters(), lr=lr, betas=betas)
optimizer_D = torch.optim.Adam(D.parameters(), lr=lr, betas=betas)
# Configure dataloaders
# (ImageDataset, which yields paired "A"/"B" images from ./data/facades, is assumed
#  to be defined or imported earlier in the notebook)
transforms_ = transforms.Compose([
    transforms.Resize((img_size, img_size), Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataloader = DataLoader(
    ImageDataset("./data/facades", transforms_=transforms_),
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)
val_dataloader = DataLoader(
    ImageDataset("./data/facades", transforms_=transforms_, mode="val"),
    batch_size=10,
    shuffle=True,
    num_workers=1,
)
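Now train pix2pix. Each iteration proceeds as follows:
- First train G: for each image A (the outline), G generates fake_B (a building facade). Compute the L1 loss between fake_B and real_B (the ground truth). Also let D judge the pair (fake_B, A) and compute the MSE loss against the label 1. Update G with the weighted sum of these two losses.
- Then train D: compute the MSE loss for (fake_B, A) with label 0 and for (real_B, A) with label 1, and update D.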
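The two steps above can be sketched on toy networks to see where gradients flow. `ToyG` and `ToyD` below are hypothetical stand-ins, far smaller than `GeneratorUNet` and `Discriminator`, and the data is random. The sketch checks two things. First, `fake_B.detach()` keeps the D update from sending gradients into G. Second, `opt_D.zero_grad()` is needed because `loss_G.backward()` also leaves gradients in D.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class ToyG(nn.Module):  # hypothetical tiny generator: image A -> image B
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Conv2d(3, 3, 3, padding=1), nn.Tanh())
    def forward(self, a):
        return self.net(a)

class ToyD(nn.Module):  # hypothetical tiny conditional PatchGAN: scores the (B, A) pair
    def __init__(self):
        super().__init__()
        self.net = nn.Conv2d(6, 1, 4, stride=4)  # 16x16 input -> 4x4 patch map
    def forward(self, img_B, img_A):
        return self.net(torch.cat((img_B, img_A), 1))

G, D = ToyG(), ToyD()
opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
mse, l1, lambda_pixel = nn.MSELoss(), nn.L1Loss(), 1.0

real_A = torch.rand(2, 3, 16, 16) * 2 - 1
real_B = torch.rand(2, 3, 16, 16) * 2 - 1
ones, zeros = torch.ones(2, 1, 4, 4), torch.zeros(2, 1, 4, 4)

# Step 1: update G with the adversarial loss (label 1) + weighted L1 loss
opt_G.zero_grad()
fake_B = G(real_A)
loss_G = mse(D(fake_B, real_A), ones) + lambda_pixel * l1(fake_B, real_B)
loss_G.backward()  # note: this also fills D's .grad fields
opt_G.step()

# Step 2: update D on a real pair (label 1) and a detached fake pair (label 0)
d_weight_before = D.net.weight.detach().clone()
opt_D.zero_grad()  # discard the gradients that loss_G.backward() left in D
for p in G.parameters():
    p.grad = None  # clear G's grads so we can verify step 2 does not touch them
loss_D = 0.5 * (mse(D(real_B, real_A), ones) + mse(D(fake_B.detach(), real_A), zeros))
loss_D.backward()
opt_D.step()

print(all(p.grad is None for p in G.parameters()))      # True: no gradient reached G
print(not torch.equal(d_weight_before, D.net.weight))  # True: D was updated
```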
for epoch in range(n_epochs):
    for i, batch in enumerate(dataloader):
        # G: A -> B
        real_A = batch["A"].to(device)
        real_B = batch["B"].to(device)
        # Adversarial ground truths
        real_label = torch.ones((real_A.size(0), *patch)).to(device)
        fake_label = torch.zeros((real_A.size(0), *patch)).to(device)
        # ------------------
        # Train Generator
        # ------------------
        optimizer_G.zero_grad()
        # GAN loss
        fake_B = G(real_A)
        pred_fake = D(fake_B, real_A)
        loss_GAN = criterion_GAN(pred_fake, real_label)
        # Pixel-wise loss
        loss_pixel = criterion_pixelwise(fake_B, real_B)
        # Total loss
        loss_G = loss_GAN + lambda_pixel * loss_pixel
        loss_G.backward()
        optimizer_G.step()
        # ---------------------
        # Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()
        # Real loss
        pred_real = D(real_B, real_A)
        loss_real = criterion_GAN(pred_real, real_label)
        # Fake loss (detach so that this loss does not backpropagate into G)
        pred_fake = D(fake_B.detach(), real_A)
        loss_fake = criterion_GAN(pred_fake, fake_label)
        # Total loss
        loss_D = 0.5 * (loss_real + loss_fake)
        loss_D.backward()
        optimizer_D.step()
        # Print log
        print(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, adv: %f]"
            % (
                epoch,
                n_epochs,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_pixel.item(),
                loss_GAN.item(),
            )
        )
    # At the sample interval, display sample images
    if epoch == 0 or (epoch + 1) % 5 == 0:
        sample_images(val_dataloader, G, device)
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/upsampling.py:129: UserWarning: nn.Upsample is deprecated. Use nn.functional.interpolate instead.
warnings.warn("nn.{} is deprecated. Use nn.functional.interpolate instead.".format(self.name))
[Epoch 0/200] [Batch 199/200] [D loss: 0.329559] [G loss: 0.837497, pixel: 0.370509, adv: 0.466988]
[Epoch 1/200] [Batch 199/200] [D loss: 0.187533] [G loss: 0.690237, pixel: 0.384734, adv: 0.305503]
[Epoch 2/200] [Batch 199/200] [D loss: 0.192769] [G loss: 0.710474, pixel: 0.357925, adv: 0.352549]
[Epoch 3/200] [Batch 199/200] [D loss: 0.257360] [G loss: 0.608871, pixel: 0.327612, adv: 0.281260]
[Epoch 4/200] [Batch 199/200] [D loss: 0.147929] [G loss: 0.887955, pixel: 0.474433, adv: 0.413522]
[Epoch 5/200] [Batch 199/200] [D loss: 0.377922] [G loss: 0.743606, pixel: 0.492447, adv: 0.251159]
[Epoch 6/200] [Batch 199/200] [D loss: 0.209727] [G loss: 0.689151, pixel: 0.384093, adv: 0.305057]
[Epoch 7/200] [Batch 199/200] [D loss: 0.224705] [G loss: 1.000042, pixel: 0.639260, adv: 0.360782]
[Epoch 8/200] [Batch 199/200] [D loss: 0.144029] [G loss: 1.020684, pixel: 0.503782, adv: 0.516902]
[Epoch 9/200] [Batch 199/200] [D loss: 0.254280] [G loss: 0.809810, pixel: 0.416601, adv: 0.393209]
[Epoch 10/200] [Batch 199/200] [D loss: 0.243891] [G loss: 0.895190, pixel: 0.446443, adv: 0.448747]
[Epoch 11/200] [Batch 199/200] [D loss: 0.210248] [G loss: 0.712496, pixel: 0.409450, adv: 0.303046]
[Epoch 12/200] [Batch 199/200] [D loss: 0.178942] [G loss: 0.673143, pixel: 0.382591, adv: 0.290552]
[Epoch 13/200] [Batch 199/200] [D loss: 0.116803] [G loss: 1.028422, pixel: 0.466901, adv: 0.561522]
[Epoch 14/200] [Batch 199/200] [D loss: 0.236468] [G loss: 0.860611, pixel: 0.383249, adv: 0.477362]
[Epoch 15/200] [Batch 199/200] [D loss: 0.220974] [G loss: 0.686148, pixel: 0.446806, adv: 0.239342]
[Epoch 16/200] [Batch 199/200] [D loss: 0.296042] [G loss: 1.235985, pixel: 0.507029, adv: 0.728956]
[Epoch 17/200] [Batch 199/200] [D loss: 0.223143] [G loss: 0.806767, pixel: 0.373452, adv: 0.433314]
[Epoch 18/200] [Batch 199/200] [D loss: 0.164129] [G loss: 1.060684, pixel: 0.519046, adv: 0.541638]
[Epoch 19/200] [Batch 199/200] [D loss: 0.132792] [G loss: 1.019057, pixel: 0.431385, adv: 0.587671]
[Epoch 20/200] [Batch 199/200] [D loss: 0.210773] [G loss: 1.006550, pixel: 0.355910, adv: 0.650640]
[Epoch 21/200] [Batch 199/200] [D loss: 0.197349] [G loss: 0.917636, pixel: 0.464914, adv: 0.452722]
[Epoch 22/200] [Batch 199/200] [D loss: 0.315029] [G loss: 0.775995, pixel: 0.473480, adv: 0.302515]
[Epoch 23/200] [Batch 199/200] [D loss: 0.215998] [G loss: 0.834821, pixel: 0.451176, adv: 0.383645]
[Epoch 24/200] [Batch 199/200] [D loss: 0.139875] [G loss: 0.799897, pixel: 0.412393, adv: 0.387504]
[Epoch 25/200] [Batch 199/200] [D loss: 0.190896] [G loss: 0.672478, pixel: 0.454510, adv: 0.217968]
[Epoch 26/200] [Batch 199/200] [D loss: 0.284857] [G loss: 0.748555, pixel: 0.531898, adv: 0.216657]
[Epoch 27/200] [Batch 199/200] [D loss: 0.119153] [G loss: 0.961913, pixel: 0.458838, adv: 0.503075]
[Epoch 28/200] [Batch 199/200] [D loss: 0.303614] [G loss: 0.743979, pixel: 0.453780, adv: 0.290199]
[Epoch 29/200] [Batch 199/200] [D loss: 0.152767] [G loss: 0.888828, pixel: 0.408863, adv: 0.479965]
[Epoch 30/200] [Batch 199/200] [D loss: 0.086363] [G loss: 1.132304, pixel: 0.526258, adv: 0.606046]
[Epoch 31/200] [Batch 199/200] [D loss: 0.170432] [G loss: 0.892955, pixel: 0.439683, adv: 0.453272]
[Epoch 32/200] [Batch 199/200] [D loss: 0.082255] [G loss: 1.501711, pixel: 0.480995, adv: 1.020716]
[Epoch 33/200] [Batch 199/200] [D loss: 0.100667] [G loss: 1.003713, pixel: 0.436757, adv: 0.566956]
[Epoch 34/200] [Batch 199/200] [D loss: 0.174403] [G loss: 0.696919, pixel: 0.390337, adv: 0.306583]
[Epoch 35/200] [Batch 199/200] [D loss: 0.327912] [G loss: 1.071013, pixel: 0.365537, adv: 0.705477]
[Epoch 36/200] [Batch 199/200] [D loss: 0.212015] [G loss: 0.746720, pixel: 0.390056, adv: 0.356664]
[Epoch 37/200] [Batch 199/200] [D loss: 0.084461] [G loss: 1.030632, pixel: 0.451042, adv: 0.579591]
[Epoch 38/200] [Batch 199/200] [D loss: 0.046271] [G loss: 1.881051, pixel: 0.425613, adv: 1.455438]
[Epoch 39/200] [Batch 199/200] [D loss: 0.100079] [G loss: 1.065145, pixel: 0.462124, adv: 0.603022]
[Epoch 40/200] [Batch 199/200] [D loss: 0.057483] [G loss: 1.097130, pixel: 0.521220, adv: 0.575910]
[Epoch 41/200] [Batch 199/200] [D loss: 0.084353] [G loss: 1.751491, pixel: 0.453270, adv: 1.298221]
[Epoch 42/200] [Batch 199/200] [D loss: 0.116546] [G loss: 1.467547, pixel: 0.437802, adv: 1.029745]
[Epoch 43/200] [Batch 199/200] [D loss: 0.081984] [G loss: 0.924543, pixel: 0.489197, adv: 0.435347]
[Epoch 44/200] [Batch 199/200] [D loss: 0.182606] [G loss: 0.823386, pixel: 0.483347, adv: 0.340038]
[Epoch 45/200] [Batch 199/200] [D loss: 0.138830] [G loss: 0.853829, pixel: 0.362748, adv: 0.491082]
[Epoch 46/200] [Batch 199/200] [D loss: 0.085006] [G loss: 1.059178, pixel: 0.424786, adv: 0.634392]
[Epoch 47/200] [Batch 199/200] [D loss: 0.138033] [G loss: 1.319859, pixel: 0.399257, adv: 0.920602]
[Epoch 48/200] [Batch 199/200] [D loss: 0.115646] [G loss: 1.412395, pixel: 0.324788, adv: 1.087608]
[Epoch 49/200] [Batch 199/200] [D loss: 0.044216] [G loss: 1.235876, pixel: 0.493417, adv: 0.742459]
[Epoch 50/200] [Batch 199/200] [D loss: 0.131177] [G loss: 0.773633, pixel: 0.376225, adv: 0.397407]
[Epoch 51/200] [Batch 199/200] [D loss: 0.067668] [G loss: 0.976136, pixel: 0.384608, adv: 0.591528]
[Epoch 52/200] [Batch 199/200] [D loss: 0.101439] [G loss: 1.171204, pixel: 0.405691, adv: 0.765513]
[Epoch 53/200] [Batch 199/200] [D loss: 0.055503] [G loss: 1.202201, pixel: 0.493737, adv: 0.708464]
[Epoch 54/200] [Batch 199/200] [D loss: 0.071844] [G loss: 1.415663, pixel: 0.513541, adv: 0.902122]
[Epoch 55/200] [Batch 199/200] [D loss: 0.024924] [G loss: 1.331125, pixel: 0.430270, adv: 0.900854]
[Epoch 56/200] [Batch 199/200] [D loss: 0.069257] [G loss: 1.498419, pixel: 0.474226, adv: 1.024193]
[Epoch 57/200] [Batch 199/200] [D loss: 0.070797] [G loss: 1.130181, pixel: 0.428668, adv: 0.701513]
[Epoch 58/200] [Batch 199/200] [D loss: 0.085243] [G loss: 1.172139, pixel: 0.406482, adv: 0.765656]
[Epoch 59/200] [Batch 199/200] [D loss: 0.091993] [G loss: 1.530317, pixel: 0.399569, adv: 1.130749]
[Epoch 60/200] [Batch 199/200] [D loss: 0.145450] [G loss: 1.073283, pixel: 0.454153, adv: 0.619130]
[Epoch 61/200] [Batch 199/200] [D loss: 0.078339] [G loss: 1.825652, pixel: 0.356573, adv: 1.469079]
[Epoch 62/200] [Batch 199/200] [D loss: 0.054064] [G loss: 1.838756, pixel: 0.324753, adv: 1.514003]
[Epoch 63/200] [Batch 199/200] [D loss: 0.103259] [G loss: 1.146301, pixel: 0.400082, adv: 0.746219]
[Epoch 64/200] [Batch 199/200] [D loss: 0.054458] [G loss: 1.055564, pixel: 0.504933, adv: 0.550630]
[Epoch 65/200] [Batch 199/200] [D loss: 0.061942] [G loss: 1.100672, pixel: 0.395495, adv: 0.705177]
[Epoch 66/200] [Batch 199/200] [D loss: 0.085689] [G loss: 0.860537, pixel: 0.386446, adv: 0.474090]
[Epoch 67/200] [Batch 199/200] [D loss: 0.050138] [G loss: 1.295279, pixel: 0.357417, adv: 0.937862]
[Epoch 68/200] [Batch 199/200] [D loss: 0.028183] [G loss: 1.624760, pixel: 0.593959, adv: 1.030801]
[Epoch 69/200] [Batch 199/200] [D loss: 0.044521] [G loss: 1.419075, pixel: 0.469773, adv: 0.949302]
[Epoch 70/200] [Batch 199/200] [D loss: 0.036712] [G loss: 1.157503, pixel: 0.481283, adv: 0.676221]
[Epoch 71/200] [Batch 199/200] [D loss: 0.075681] [G loss: 1.276237, pixel: 0.385465, adv: 0.890772]
[Epoch 72/200] [Batch 199/200] [D loss: 0.077527] [G loss: 1.177667, pixel: 0.379942, adv: 0.797725]
[Epoch 73/200] [Batch 199/200] [D loss: 0.042144] [G loss: 1.063881, pixel: 0.456144, adv: 0.607737]
[Epoch 74/200] [Batch 199/200] [D loss: 0.085770] [G loss: 1.122791, pixel: 0.492341, adv: 0.630450]
[Epoch 75/200] [Batch 199/200] [D loss: 0.103890] [G loss: 1.007681, pixel: 0.469902, adv: 0.537779]
[Epoch 76/200] [Batch 199/200] [D loss: 0.053791] [G loss: 1.174764, pixel: 0.411327, adv: 0.763437]
[Epoch 77/200] [Batch 199/200] [D loss: 0.020036] [G loss: 1.313630, pixel: 0.395490, adv: 0.918141]
[Epoch 78/200] [Batch 199/200] [D loss: 0.126697] [G loss: 1.074769, pixel: 0.373476, adv: 0.701293]
[Epoch 79/200] [Batch 199/200] [D loss: 0.024991] [G loss: 1.667502, pixel: 0.502399, adv: 1.165103]
[Epoch 80/200] [Batch 199/200] [D loss: 0.039185] [G loss: 1.543022, pixel: 0.401599, adv: 1.141423]
[Epoch 81/200] [Batch 199/200] [D loss: 0.026950] [G loss: 1.390851, pixel: 0.535749, adv: 0.855101]
[Epoch 82/200] [Batch 199/200] [D loss: 0.015496] [G loss: 1.451605, pixel: 0.495355, adv: 0.956250]
[Epoch 83/200] [Batch 199/200] [D loss: 0.040288] [G loss: 1.348929, pixel: 0.445604, adv: 0.903325]
[Epoch 84/200] [Batch 199/200] [D loss: 0.025638] [G loss: 1.267531, pixel: 0.403702, adv: 0.863829]
[Epoch 85/200] [Batch 199/200] [D loss: 0.088575] [G loss: 1.346314, pixel: 0.416517, adv: 0.929797]
[Epoch 86/200] [Batch 199/200] [D loss: 0.064832] [G loss: 1.030367, pixel: 0.419878, adv: 0.610489]
[Epoch 87/200] [Batch 199/200] [D loss: 0.016706] [G loss: 1.378341, pixel: 0.525949, adv: 0.852392]
[Epoch 88/200] [Batch 199/200] [D loss: 0.078510] [G loss: 0.865967, pixel: 0.411993, adv: 0.453975]
[Epoch 89/200] [Batch 199/200] [D loss: 0.049418] [G loss: 1.415064, pixel: 0.318340, adv: 1.096724]
[Epoch 90/200] [Batch 199/200] [D loss: 0.036470] [G loss: 1.326276, pixel: 0.561596, adv: 0.764681]
[Epoch 91/200] [Batch 199/200] [D loss: 0.026371] [G loss: 1.198210, pixel: 0.423642, adv: 0.774568]
[Epoch 92/200] [Batch 199/200] [D loss: 0.110314] [G loss: 1.079814, pixel: 0.404454, adv: 0.675360]
[Epoch 93/200] [Batch 199/200] [D loss: 0.022842] [G loss: 1.389670, pixel: 0.360465, adv: 1.029206]
[Epoch 94/200] [Batch 199/200] [D loss: 0.037473] [G loss: 1.001947, pixel: 0.345857, adv: 0.656090]
[Epoch 95/200] [Batch 199/200] [D loss: 0.037387] [G loss: 1.497730, pixel: 0.449730, adv: 1.048000]
[Epoch 96/200] [Batch 199/200] [D loss: 0.042948] [G loss: 1.324835, pixel: 0.525198, adv: 0.799637]
[Epoch 97/200] [Batch 199/200] [D loss: 0.070627] [G loss: 1.534212, pixel: 0.338898, adv: 1.195314]
[Epoch 98/200] [Batch 199/200] [D loss: 0.039812] [G loss: 1.462890, pixel: 0.330236, adv: 1.132654]
[Epoch 99/200] [Batch 199/200] [D loss: 0.033521] [G loss: 1.198216, pixel: 0.414350, adv: 0.783866]
[Epoch 100/200] [Batch 199/200] [D loss: 0.032366] [G loss: 1.328524, pixel: 0.449658, adv: 0.878866]
[Epoch 101/200] [Batch 199/200] [D loss: 0.021940] [G loss: 1.399054, pixel: 0.348522, adv: 1.050532]
[Epoch 102/200] [Batch 199/200] [D loss: 0.017531] [G loss: 1.258798, pixel: 0.380441, adv: 0.878357]
[Epoch 103/200] [Batch 199/200] [D loss: 0.019968] [G loss: 1.470231, pixel: 0.594463, adv: 0.875768]
[Epoch 104/200] [Batch 199/200] [D loss: 0.043118] [G loss: 1.165057, pixel: 0.454049, adv: 0.711009]
[Epoch 105/200] [Batch 199/200] [D loss: 0.020502] [G loss: 1.532076, pixel: 0.514340, adv: 1.017736]
[Epoch 106/200] [Batch 199/200] [D loss: 0.041665] [G loss: 1.530959, pixel: 0.358589, adv: 1.172370]
[Epoch 107/200] [Batch 199/200] [D loss: 0.026232] [G loss: 1.083097, pixel: 0.353141, adv: 0.729956]
[Epoch 108/200] [Batch 199/200] [D loss: 0.015473] [G loss: 1.475121, pixel: 0.432230, adv: 1.042891]
[Epoch 109/200] [Batch 199/200] [D loss: 0.016544] [G loss: 1.460269, pixel: 0.395489, adv: 1.064780]
[Epoch 110/200] [Batch 199/200] [D loss: 0.030761] [G loss: 1.332439, pixel: 0.365682, adv: 0.966757]
[Epoch 111/200] [Batch 199/200] [D loss: 0.050728] [G loss: 1.191074, pixel: 0.427588, adv: 0.763487]
[Epoch 112/200] [Batch 199/200] [D loss: 0.018158] [G loss: 1.542107, pixel: 0.520371, adv: 1.021735]
[Epoch 113/200] [Batch 199/200] [D loss: 0.050946] [G loss: 1.014630, pixel: 0.441678, adv: 0.572952]
[Epoch 114/200] [Batch 199/200] [D loss: 0.029447] [G loss: 1.650831, pixel: 0.527636, adv: 1.123196]
[Epoch 115/200] [Batch 199/200] [D loss: 0.053841] [G loss: 0.906547, pixel: 0.358696, adv: 0.547852]
[Epoch 116/200] [Batch 199/200] [D loss: 0.039013] [G loss: 1.084382, pixel: 0.304030, adv: 0.780352]
[Epoch 117/200] [Batch 199/200] [D loss: 0.022909] [G loss: 1.241905, pixel: 0.474598, adv: 0.767307]
[Epoch 118/200] [Batch 199/200] [D loss: 0.014374] [G loss: 1.314399, pixel: 0.463417, adv: 0.850982]
[Epoch 119/200] [Batch 199/200] [D loss: 0.066138] [G loss: 1.027380, pixel: 0.384822, adv: 0.642558]
[Epoch 120/200] [Batch 199/200] [D loss: 0.035893] [G loss: 1.009986, pixel: 0.343447, adv: 0.666538]
[Epoch 121/200] [Batch 199/200] [D loss: 0.023549] [G loss: 1.244753, pixel: 0.356718, adv: 0.888035]
[Epoch 122/200] [Batch 199/200] [D loss: 0.019047] [G loss: 1.731042, pixel: 0.560255, adv: 1.170787]
[Epoch 123/200] [Batch 199/200] [D loss: 0.015558] [G loss: 1.535109, pixel: 0.508780, adv: 1.026329]
[Epoch 124/200] [Batch 199/200] [D loss: 0.135557] [G loss: 1.524657, pixel: 0.447322, adv: 1.077335]
[Epoch 125/200] [Batch 199/200] [D loss: 0.017648] [G loss: 1.424236, pixel: 0.377841, adv: 1.046396]
[Epoch 126/200] [Batch 199/200] [D loss: 0.033450] [G loss: 1.626775, pixel: 0.434452, adv: 1.192323]
[Epoch 127/200] [Batch 199/200] [D loss: 0.072135] [G loss: 0.924539, pixel: 0.374359, adv: 0.550180]
[Epoch 128/200] [Batch 199/200] [D loss: 0.045950] [G loss: 1.254386, pixel: 0.363163, adv: 0.891223]
[Epoch 129/200] [Batch 199/200] [D loss: 0.010190] [G loss: 1.406185, pixel: 0.438934, adv: 0.967251]
[Epoch 130/200] [Batch 199/200] [D loss: 0.010876] [G loss: 1.308254, pixel: 0.472779, adv: 0.835475]
[Epoch 131/200] [Batch 199/200] [D loss: 0.024451] [G loss: 1.241876, pixel: 0.459188, adv: 0.782688]
[Epoch 132/200] [Batch 199/200] [D loss: 0.024267] [G loss: 1.150942, pixel: 0.381786, adv: 0.769156]
[Epoch 133/200] [Batch 199/200] [D loss: 0.028550] [G loss: 1.226290, pixel: 0.417267, adv: 0.809023]
[Epoch 134/200] [Batch 199/200] [D loss: 0.035529] [G loss: 1.150472, pixel: 0.461133, adv: 0.689339]
[Epoch 135/200] [Batch 199/200] [D loss: 0.075388] [G loss: 1.057127, pixel: 0.388512, adv: 0.668615]
[Epoch 136/200] [Batch 199/200] [D loss: 0.016374] [G loss: 1.203050, pixel: 0.430617, adv: 0.772433]
[Epoch 137/200] [Batch 199/200] [D loss: 0.015597] [G loss: 1.139218, pixel: 0.362835, adv: 0.776384]
[Epoch 138/200] [Batch 199/200] [D loss: 0.013657] [G loss: 1.274905, pixel: 0.380114, adv: 0.894791]
[Epoch 139/200] [Batch 199/200] [D loss: 0.011944] [G loss: 1.411466, pixel: 0.392570, adv: 1.018895]
[Epoch 140/200] [Batch 199/200] [D loss: 0.022011] [G loss: 1.341472, pixel: 0.383966, adv: 0.957506]
[Epoch 141/200] [Batch 199/200] [D loss: 0.034214] [G loss: 1.590953, pixel: 0.420876, adv: 1.170077]
[Epoch 142/200] [Batch 199/200] [D loss: 0.069331] [G loss: 0.913922, pixel: 0.432432, adv: 0.481490]
[Epoch 143/200] [Batch 199/200] [D loss: 0.028701] [G loss: 1.152182, pixel: 0.375248, adv: 0.776934]
[Epoch 144/200] [Batch 199/200] [D loss: 0.037139] [G loss: 1.084857, pixel: 0.308332, adv: 0.776525]
[Epoch 145/200] [Batch 199/200] [D loss: 0.072288] [G loss: 1.587902, pixel: 0.359959, adv: 1.227943]
[Epoch 146/200] [Batch 199/200] [D loss: 0.041215] [G loss: 1.109346, pixel: 0.392038, adv: 0.717308]
[Epoch 147/200] [Batch 199/200] [D loss: 0.034205] [G loss: 1.752288, pixel: 0.441332, adv: 1.310956]
[Epoch 148/200] [Batch 199/200] [D loss: 0.012378] [G loss: 1.168075, pixel: 0.303929, adv: 0.864146]
[Epoch 149/200] [Batch 199/200] [D loss: 0.051096] [G loss: 1.067259, pixel: 0.498688, adv: 0.568571]
[Epoch 150/200] [Batch 199/200] [D loss: 0.071198] [G loss: 0.817898, pixel: 0.358357, adv: 0.459542]
[Epoch 151/200] [Batch 199/200] [D loss: 0.024366] [G loss: 1.234428, pixel: 0.405045, adv: 0.829383]
[Epoch 152/200] [Batch 199/200] [D loss: 0.013819] [G loss: 1.482826, pixel: 0.491976, adv: 0.990851]
[Epoch 153/200] [Batch 199/200] [D loss: 0.043319] [G loss: 1.351990, pixel: 0.502132, adv: 0.849858]
[Epoch 154/200] [Batch 199/200] [D loss: 0.023567] [G loss: 1.616637, pixel: 0.458313, adv: 1.158324]
[Epoch 155/200] [Batch 199/200] [D loss: 0.042370] [G loss: 1.022207, pixel: 0.376785, adv: 0.645422]
[Epoch 156/200] [Batch 199/200] [D loss: 0.042269] [G loss: 0.945393, pixel: 0.363320, adv: 0.582073]
[Epoch 157/200] [Batch 199/200] [D loss: 0.032874] [G loss: 1.185990, pixel: 0.470290, adv: 0.715700]
[Epoch 158/200] [Batch 199/200] [D loss: 0.022171] [G loss: 1.377957, pixel: 0.428952, adv: 0.949005]
[Epoch 159/200] [Batch 199/200] [D loss: 0.059965] [G loss: 0.842113, pixel: 0.348511, adv: 0.493603]
[Epoch 160/200] [Batch 199/200] [D loss: 0.023744] [G loss: 1.060010, pixel: 0.316527, adv: 0.743483]
[Epoch 161/200] [Batch 199/200] [D loss: 0.021823] [G loss: 1.504025, pixel: 0.317761, adv: 1.186264]
[Epoch 162/200] [Batch 199/200] [D loss: 0.032255] [G loss: 1.077471, pixel: 0.350018, adv: 0.727453]
[Epoch 163/200] [Batch 199/200] [D loss: 0.018691] [G loss: 1.524051, pixel: 0.502410, adv: 1.021640]
[Epoch 164/200] [Batch 199/200] [D loss: 0.020606] [G loss: 1.301775, pixel: 0.348564, adv: 0.953212]
[Epoch 165/200] [Batch 199/200] [D loss: 0.045401] [G loss: 0.964343, pixel: 0.355987, adv: 0.608357]
[Epoch 166/200] [Batch 199/200] [D loss: 0.112676] [G loss: 0.938677, pixel: 0.465589, adv: 0.473088]
[Epoch 167/200] [Batch 199/200] [D loss: 0.025489] [G loss: 1.460020, pixel: 0.422306, adv: 1.037714]
[Epoch 168/200] [Batch 199/200] [D loss: 0.019899] [G loss: 1.384184, pixel: 0.531230, adv: 0.852953]
[Epoch 169/200] [Batch 199/200] [D loss: 0.033787] [G loss: 1.480100, pixel: 0.355705, adv: 1.124394]
[Epoch 170/200] [Batch 199/200] [D loss: 0.036215] [G loss: 0.898988, pixel: 0.282254, adv: 0.616734]
[Epoch 171/200] [Batch 199/200] [D loss: 0.091275] [G loss: 0.871186, pixel: 0.327691, adv: 0.543495]
[Epoch 172/200] [Batch 199/200] [D loss: 0.087182] [G loss: 0.780764, pixel: 0.335386, adv: 0.445378]
[Epoch 173/200] [Batch 199/200] [D loss: 0.015739] [G loss: 1.251271, pixel: 0.469624, adv: 0.781647]
[Epoch 174/200] [Batch 199/200] [D loss: 0.017278] [G loss: 1.346440, pixel: 0.402886, adv: 0.943554]
[Epoch 175/200] [Batch 199/200] [D loss: 0.019024] [G loss: 1.158731, pixel: 0.340505, adv: 0.818225]
[Epoch 176/200] [Batch 199/200] [D loss: 0.008746] [G loss: 1.529629, pixel: 0.400643, adv: 1.128986]
[Epoch 177/200] [Batch 199/200] [D loss: 0.021719] [G loss: 1.176342, pixel: 0.422043, adv: 0.754300]
[Epoch 178/200] [Batch 199/200] [D loss: 0.013136] [G loss: 1.187904, pixel: 0.373253, adv: 0.814650]
[Epoch 179/200] [Batch 199/200] [D loss: 0.026584] [G loss: 1.150412, pixel: 0.426622, adv: 0.723790]
[Epoch 180/200] [Batch 199/200] [D loss: 0.014295] [G loss: 1.537314, pixel: 0.404008, adv: 1.133307]
[Epoch 181/200] [Batch 199/200] [D loss: 0.013056] [G loss: 1.406345, pixel: 0.446078, adv: 0.960267]
[Epoch 182/200] [Batch 199/200] [D loss: 0.030828] [G loss: 1.034685, pixel: 0.394839, adv: 0.639846]
[Epoch 183/200] [Batch 199/200] [D loss: 0.020732] [G loss: 1.441298, pixel: 0.392655, adv: 1.048643]
[Epoch 184/200] [Batch 199/200] [D loss: 0.114661] [G loss: 0.802386, pixel: 0.425005, adv: 0.377381]
[Epoch 185/200] [Batch 199/200] [D loss: 0.059408] [G loss: 0.796534, pixel: 0.274744, adv: 0.521790]
[Epoch 186/200] [Batch 199/200] [D loss: 0.034220] [G loss: 0.985536, pixel: 0.342627, adv: 0.642909]
[Epoch 187/200] [Batch 199/200] [D loss: 0.035371] [G loss: 1.045623, pixel: 0.296336, adv: 0.749287]
[Epoch 188/200] [Batch 199/200] [D loss: 0.042729] [G loss: 1.420175, pixel: 0.410863, adv: 1.009312]
[Epoch 189/200] [Batch 199/200] [D loss: 0.021064] [G loss: 1.135419, pixel: 0.292835, adv: 0.842585]
[Epoch 190/200] [Batch 199/200] [D loss: 0.023207] [G loss: 1.274331, pixel: 0.346932, adv: 0.927399]
[Epoch 191/200] [Batch 199/200] [D loss: 0.017008] [G loss: 1.253284, pixel: 0.341651, adv: 0.911634]
[Epoch 192/200] [Batch 199/200] [D loss: 0.053124] [G loss: 0.952296, pixel: 0.424036, adv: 0.528260]
[Epoch 193/200] [Batch 199/200] [D loss: 0.028502] [G loss: 1.004215, pixel: 0.253622, adv: 0.750593]
[Epoch 194/200] [Batch 199/200] [D loss: 0.021864] [G loss: 1.111022, pixel: 0.212485, adv: 0.898536]
[Epoch 195/200] [Batch 199/200] [D loss: 0.029861] [G loss: 1.303421, pixel: 0.308407, adv: 0.995015]
[Epoch 196/200] [Batch 199/200] [D loss: 0.014942] [G loss: 1.487336, pixel: 0.484178, adv: 1.003158]
[Epoch 197/200] [Batch 199/200] [D loss: 0.037478] [G loss: 1.205319, pixel: 0.364015, adv: 0.841304]
[Epoch 198/200] [Batch 199/200] [D loss: 0.046917] [G loss: 0.954051, pixel: 0.309143, adv: 0.644907]
[Epoch 199/200] [Batch 199/200] [D loss: 0.026501] [G loss: 1.166066, pixel: 0.400264, adv: 0.765802]
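Assignment:
- Train pix2pix with the L1 loss only. Describe how the results differ.
Answer: The generated images are blurrier. Outlines are less sharp than before, many building details are lost, and the colors are more uniform.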
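One common intuition for this observation is sketched below; it is an illustration, not a proof. When the exact position of fine detail is uncertain, a per-pixel loss such as L1 can penalize a sharp but slightly misplaced texture more than a flat, averaged output. A generator trained on L1 alone therefore tends toward smooth results, and in pix2pix the adversarial term from D is meant to counter this. The 1-D signal below is made-up toy data.

```python
import torch
import torch.nn.functional as F

# Toy 1-D "texture": alternating bright/dark pixels in [-1, 1] (made-up data)
target = torch.tensor([1.0, -1.0, 1.0, -1.0, 1.0, -1.0])

# Prediction 1: the correct texture, but shifted by one pixel
shifted = torch.roll(target, shifts=1)  # [-1, 1, -1, 1, -1, 1]
# Prediction 2: a flat grey signal with no texture at all
flat = torch.zeros_like(target)

print(F.l1_loss(shifted, target).item())  # 2.0
print(F.l1_loss(flat, target).item())     # 1.0
```

The flat prediction gets the lower L1 loss even though it contains none of the texture, which is the kind of output an L1-only model is pushed toward.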
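# Re-create G and its optimizer so that the L1-only run starts from scratch
# instead of continuing from the G already trained with the GAN loss above
G = GeneratorUNet().to(device)
G.apply(weights_init_normal)
optimizer_G = torch.optim.Adam(G.parameters(), lr=lr, betas=betas)
for epoch in range(n_epochs):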
    for i, batch in enumerate(dataloader):
        # G: A -> B
        real_A = batch["A"].to(device)
        real_B = batch["B"].to(device)
        # ------------------
        # Train Generator (L1 loss only, no discriminator)
        # ------------------
        optimizer_G.zero_grad()
        # Generate fake_B
        fake_B = G(real_A)
        # Pixel-wise loss
        loss_pixel = criterion_pixelwise(fake_B, real_B)
        # Total loss: the L1 term alone
        loss_G = loss_pixel
        loss_G.backward()
        optimizer_G.step()
        # Print log
        print(
            "\r[Epoch %d/%d] [Batch %d/%d] [G loss: %f]"
            % (
                epoch,
                n_epochs,
                i,
                len(dataloader),
                loss_G.item()
            )
        )
    # At the sample interval, display sample images
    if epoch == 0 or (epoch + 1) % 5 == 0:
        sample_images(val_dataloader, G, device)
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/upsampling.py:129: UserWarning: nn.Upsample is deprecated. Use nn.functional.interpolate instead.
warnings.warn("nn.{} is deprecated. Use nn.functional.interpolate instead.".format(self.name))
[Epoch 0/200] [Batch 199/200] [G loss: 0.277459]
[Epoch 1/200] [Batch 199/200] [G loss: 0.280747]
[Epoch 2/200] [Batch 199/200] [G loss: 0.302678]
[Epoch 3/200] [Batch 199/200] [G loss: 0.277463]
[Epoch 4/200] [Batch 199/200] [G loss: 0.305907]
[Epoch 5/200] [Batch 199/200] [G loss: 0.361403]
[Epoch 6/200] [Batch 199/200] [G loss: 0.262338]
[Epoch 7/200] [Batch 199/200] [G loss: 0.242269]
[Epoch 8/200] [Batch 199/200] [G loss: 0.269101]
[Epoch 9/200] [Batch 199/200] [G loss: 0.282228]
[Epoch 10/200] [Batch 199/200] [G loss: 0.314959]
[Epoch 11/200] [Batch 199/200] [G loss: 0.264300]
[Epoch 12/200] [Batch 199/200] [G loss: 0.317328]
[Epoch 13/200] [Batch 199/200] [G loss: 0.288205]
[Epoch 14/200] [Batch 199/200] [G loss: 0.268344]
[Epoch 15/200] [Batch 199/200] [G loss: 0.270621]
[Epoch 16/200] [Batch 199/200] [G loss: 0.260496]
[Epoch 17/200] [Batch 199/200] [G loss: 0.295739]
[Epoch 18/200] [Batch 199/200] [G loss: 0.172208]
[Epoch 19/200] [Batch 199/200] [G loss: 0.208443]
[Epoch 20/200] [Batch 199/200] [G loss: 0.199149]
[Epoch 21/200] [Batch 199/200] [G loss: 0.252810]
[Epoch 22/200] [Batch 199/200] [G loss: 0.249091]
[Epoch 23/200] [Batch 199/200] [G loss: 0.215632]
[Epoch 24/200] [Batch 199/200] [G loss: 0.243048]
[Epoch 25/200] [Batch 199/200] [G loss: 0.203973]
[Epoch 26/200] [Batch 199/200] [G loss: 0.167193]
[Epoch 27/200] [Batch 199/200] [G loss: 0.198062]
[Epoch 28/200] [Batch 199/200] [G loss: 0.188340]
[Epoch 29/200] [Batch 199/200] [G loss: 0.203220]
[Epoch 30/200] [Batch 199/200] [G loss: 0.226458]
[Epoch 31/200] [Batch 199/200] [G loss: 0.179194]
[Epoch 32/200] [Batch 199/200] [G loss: 0.253033]
[Epoch 33/200] [Batch 199/200] [G loss: 0.248549]
[Epoch 34/200] [Batch 199/200] [G loss: 0.229301]
[Epoch 35/200] [Batch 199/200] [G loss: 0.198106]
[Epoch 36/200] [Batch 199/200] [G loss: 0.245744]
[Epoch 37/200] [Batch 199/200] [G loss: 0.197577]
[Epoch 38/200] [Batch 199/200] [G loss: 0.177224]
[Epoch 39/200] [Batch 199/200] [G loss: 0.184711]
[Epoch 40/200] [Batch 199/200] [G loss: 0.212995]
[Epoch 41/200] [Batch 199/200] [G loss: 0.226832]
[Epoch 42/200] [Batch 199/200] [G loss: 0.162086]
[Epoch 43/200] [Batch 199/200] [G loss: 0.155080]
[Epoch 44/200] [Batch 199/200] [G loss: 0.219378]
[Epoch 45/200] [Batch 199/200] [G loss: 0.202605]
[Epoch 46/200] [Batch 199/200] [G loss: 0.185659]
[Epoch 47/200] [Batch 199/200] [G loss: 0.157083]
[Epoch 48/200] [Batch 199/200] [G loss: 0.154050]
[Epoch 49/200] [Batch 199/200] [G loss: 0.146045]
[Epoch 50/200] [Batch 199/200] [G loss: 0.164124]
[Epoch 51/200] [Batch 199/200] [G loss: 0.160308]
[Epoch 52/200] [Batch 199/200] [G loss: 0.175741]
[Epoch 53/200] [Batch 199/200] [G loss: 0.175433]
[Epoch 54/200] [Batch 199/200] [G loss: 0.113329]
[Epoch 55/200] [Batch 199/200] [G loss: 0.182398]
[Epoch 56/200] [Batch 199/200] [G loss: 0.149941]
[Epoch 57/200] [Batch 199/200] [G loss: 0.139894]
[Epoch 58/200] [Batch 199/200] [G loss: 0.162389]
[Epoch 59/200] [Batch 199/200] [G loss: 0.162091]
[Epoch 60/200] [Batch 199/200] [G loss: 0.163489]
[Epoch 61/200] [Batch 199/200] [G loss: 0.140439]
[Epoch 62/200] [Batch 199/200] [G loss: 0.168068]
[Epoch 63/200] [Batch 199/200] [G loss: 0.180321]
[Epoch 64/200] [Batch 199/200] [G loss: 0.187986]
[Epoch 65/200] [Batch 199/200] [G loss: 0.149694]
[Epoch 66/200] [Batch 199/200] [G loss: 0.133518]
[Epoch 67/200] [Batch 199/200] [G loss: 0.117963]
[Epoch 68/200] [Batch 199/200] [G loss: 0.132153]
[Epoch 69/200] [Batch 199/200] [G loss: 0.133291]
[Epoch 70/200] [Batch 199/200] [G loss: 0.145365]
[Epoch 71/200] [Batch 199/200] [G loss: 0.119306]
[Epoch 72/200] [Batch 199/200] [G loss: 0.142480]
[Epoch 73/200] [Batch 199/200] [G loss: 0.171040]
[Epoch 74/200] [Batch 199/200] [G loss: 0.169992]
[Epoch 75/200] [Batch 199/200] [G loss: 0.125971]
[Epoch 76/200] [Batch 199/200] [G loss: 0.123406]
[Epoch 77/200] [Batch 199/200] [G loss: 0.147985]
[Epoch 78/200] [Batch 199/200] [G loss: 0.129160]
[Epoch 79/200] [Batch 199/200] [G loss: 0.139947]
[Epoch 80/200] [Batch 199/200] [G loss: 0.125355]
[Epoch 81/200] [Batch 199/200] [G loss: 0.118749]
[Epoch 82/200] [Batch 199/200] [G loss: 0.145414]
[Epoch 83/200] [Batch 199/200] [G loss: 0.162139]
[Epoch 84/200] [Batch 199/200] [G loss: 0.173825]
[Epoch 85/200] [Batch 199/200] [G loss: 0.122010]
[Epoch 86/200] [Batch 199/200] [G loss: 0.149573]
[Epoch 87/200] [Batch 199/200] [G loss: 0.131326]
[Epoch 88/200] [Batch 199/200] [G loss: 0.117505]
[Epoch 89/200] [Batch 199/200] [G loss: 0.124128]
[Epoch 90/200] [Batch 199/200] [G loss: 0.159866]
[Epoch 91/200] [Batch 199/200] [G loss: 0.135815]
[Epoch 92/200] [Batch 199/200] [G loss: 0.142748]
[Epoch 93/200] [Batch 199/200] [G loss: 0.160993]
[Epoch 94/200] [Batch 199/200] [G loss: 0.135749]
[Epoch 95/200] [Batch 199/200] [G loss: 0.122210]
[Epoch 96/200] [Batch 199/200] [G loss: 0.137793]
[Epoch 97/200] [Batch 199/200] [G loss: 0.129412]
[Epoch 98/200] [Batch 199/200] [G loss: 0.139037]
[Epoch 99/200] [Batch 199/200] [G loss: 0.125869]
[Epoch 100/200] [Batch 199/200] [G loss: 0.116563]
[Epoch 101/200] [Batch 199/200] [G loss: 0.139237]
[Epoch 102/200] [Batch 199/200] [G loss: 0.126479]
[Epoch 103/200] [Batch 199/200] [G loss: 0.145636]
[Epoch 104/200] [Batch 199/200] [G loss: 0.102723]
[Epoch 105/200] [Batch 199/200] [G loss: 0.118776]
[Epoch 106/200] [Batch 199/200] [G loss: 0.132644]
[Epoch 107/200] [Batch 199/200] [G loss: 0.102977]
[Epoch 108/200] [Batch 199/200] [G loss: 0.135690]
[Epoch 109/200] [Batch 199/200] [G loss: 0.137939]
[Epoch 110/200] [Batch 199/200] [G loss: 0.147280]
[Epoch 111/200] [Batch 199/200] [G loss: 0.135464]
[Epoch 112/200] [Batch 199/200] [G loss: 0.149474]
[Epoch 113/200] [Batch 199/200] [G loss: 0.123235]
[Epoch 114/200] [Batch 199/200] [G loss: 0.145617]
[Epoch 115/200] [Batch 199/200] [G loss: 0.151036]
[Epoch 116/200] [Batch 199/200] [G loss: 0.117502]
[Epoch 117/200] [Batch 199/200] [G loss: 0.128346]
[Epoch 118/200] [Batch 199/200] [G loss: 0.116811]
[Epoch 119/200] [Batch 199/200] [G loss: 0.123286]
[Epoch 120/200] [Batch 199/200] [G loss: 0.120572]
[Epoch 121/200] [Batch 199/200] [G loss: 0.125450]
[Epoch 122/200] [Batch 199/200] [G loss: 0.119538]
[Epoch 123/200] [Batch 199/200] [G loss: 0.121445]
[Epoch 124/200] [Batch 199/200] [G loss: 0.117159]
[Epoch 125/200] [Batch 199/200] [G loss: 0.110577]
[Epoch 126/200] [Batch 199/200] [G loss: 0.106227]
[Epoch 127/200] [Batch 199/200] [G loss: 0.121005]
[Epoch 128/200] [Batch 199/200] [G loss: 0.164577]
[Epoch 129/200] [Batch 199/200] [G loss: 0.133884]
[Epoch 130/200] [Batch 199/200] [G loss: 0.117050]
[Epoch 131/200] [Batch 199/200] [G loss: 0.138256]
[Epoch 132/200] [Batch 199/200] [G loss: 0.109242]
[Epoch 133/200] [Batch 199/200] [G loss: 0.118484]
[Epoch 134/200] [Batch 199/200] [G loss: 0.122282]
[Epoch 135/200] [Batch 199/200] [G loss: 0.126166]
[Epoch 136/200] [Batch 199/200] [G loss: 0.125155]
[Epoch 137/200] [Batch 199/200] [G loss: 0.125193]
[Epoch 138/200] [Batch 199/200] [G loss: 0.099766]
[Epoch 139/200] [Batch 199/200] [G loss: 0.117015]
[Epoch 140/200] [Batch 199/200] [G loss: 0.116830]
[Epoch 141/200] [Batch 199/200] [G loss: 0.119369]
[Epoch 142/200] [Batch 199/200] [G loss: 0.113811]
[Epoch 143/200] [Batch 199/200] [G loss: 0.102296]
[Epoch 144/200] [Batch 199/200] [G loss: 0.128800]
[Epoch 145/200] [Batch 199/200] [G loss: 0.112457]
[Epoch 146/200] [Batch 199/200] [G loss: 0.119446]
[Epoch 147/200] [Batch 199/200] [G loss: 0.109735]
[Epoch 148/200] [Batch 199/200] [G loss: 0.104489]
[Epoch 149/200] [Batch 199/200] [G loss: 0.102157]
[Epoch 150/200] [Batch 199/200] [G loss: 0.113293]
[Epoch 151/200] [Batch 199/200] [G loss: 0.084870]
[Epoch 152/200] [Batch 199/200] [G loss: 0.099532]
[Epoch 153/200] [Batch 199/200] [G loss: 0.111813]
[Epoch 154/200] [Batch 199/200] [G loss: 0.138587]
[Epoch 155/200] [Batch 199/200] [G loss: 0.128371]
[Epoch 156/200] [Batch 199/200] [G loss: 0.106724]
[Epoch 157/200] [Batch 199/200] [G loss: 0.103742]
[Epoch 158/200] [Batch 199/200] [G loss: 0.090922]
[Epoch 159/200] [Batch 199/200] [G loss: 0.102734]
[Epoch 160/200] [Batch 199/200] [G loss: 0.098833]
[Epoch 161/200] [Batch 199/200] [G loss: 0.109438]
[Epoch 162/200] [Batch 199/200] [G loss: 0.104891]
[Epoch 163/200] [Batch 199/200] [G loss: 0.096490]
[Epoch 164/200] [Batch 199/200] [G loss: 0.116669]
[Epoch 165/200] [Batch 199/200] [G loss: 0.113991]
[Epoch 166/200] [Batch 199/200] [G loss: 0.109866]
[Epoch 167/200] [Batch 199/200] [G loss: 0.112979]
[Epoch 168/200] [Batch 199/200] [G loss: 0.116685]
[Epoch 169/200] [Batch 199/200] [G loss: 0.123616]
[Epoch 170/200] [Batch 199/200] [G loss: 0.119336]
[Epoch 171/200] [Batch 199/200] [G loss: 0.126123]
[Epoch 172/200] [Batch 199/200] [G loss: 0.118350]
[Epoch 173/200] [Batch 199/200] [G loss: 0.120627]
[Epoch 174/200] [Batch 199/200] [G loss: 0.109667]
[Epoch 175/200] [Batch 199/200] [G loss: 0.118170]
[Epoch 176/200] [Batch 199/200] [G loss: 0.119886]
[Epoch 177/200] [Batch 199/200] [G loss: 0.119359]
[Epoch 178/200] [Batch 199/200] [G loss: 0.106798]
[Epoch 179/200] [Batch 199/200] [G loss: 0.119646]
[Epoch 180/200] [Batch 199/200] [G loss: 0.103119]
[Epoch 181/200] [Batch 199/200] [G loss: 0.101382]
[Epoch 182/200] [Batch 199/200] [G loss: 0.133359]
[Epoch 183/200] [Batch 199/200] [G loss: 0.118980]
[Epoch 184/200] [Batch 199/200] [G loss: 0.083562]
[Epoch 185/200] [Batch 199/200] [G loss: 0.105433]
[Epoch 186/200] [Batch 199/200] [G loss: 0.131912]
[Epoch 187/200] [Batch 199/200] [G loss: 0.096478]
[Epoch 188/200] [Batch 199/200] [G loss: 0.080363]
[Epoch 189/200] [Batch 199/200] [G loss: 0.098164]
[Epoch 190/200] [Batch 199/200] [G loss: 0.116934]
[Epoch 191/200] [Batch 199/200] [G loss: 0.102533]
[Epoch 192/200] [Batch 199/200] [G loss: 0.099125]
[Epoch 193/200] [Batch 199/200] [G loss: 0.092987]
[Epoch 194/200] [Batch 199/200] [G loss: 0.102197]
[Epoch 195/200] [Batch 199/200] [G loss: 0.084951]
[Epoch 196/200] [Batch 199/200] [G loss: 0.109301]
[Epoch 197/200] [Batch 199/200] [G loss: 0.096025]
[Epoch 198/200] [Batch 199/200] [G loss: 0.106507]
[Epoch 199/200] [Batch 199/200] [G loss: 0.096771]
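- Train pix2pix with the CGAN loss only (fill in the corresponding code in the cell below and run it). Describe how the results differ.
Answer: The generated images are considerably sharper than with the L1 loss alone and come close to the sharpness obtained when both losses are used, but their colors are still less rich and less realistic than when both losses are combined.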
for epoch in range(n_epochs):
    for i, batch in enumerate(dataloader):
        # G: A -> B (the generator translates real_A into fake_B)
        real_A = batch["A"].to(device)
        real_B = batch["B"].to(device)
        # Adversarial ground truths, one label per discriminator patch
        real_label = torch.ones((real_A.size(0), *patch)).to(device)
        fake_label = torch.zeros((real_A.size(0), *patch)).to(device)
        # ------------------
        #  Train Generators
        # ------------------
        optimizer_G.zero_grad()
        # GAN loss
        fake_B = G(real_A)
        pred_fake = D(fake_B, real_A)
        loss_GAN = criterion_GAN(pred_fake, real_label)
        # Total loss: cGAN loss only (no L1 term in this experiment)
        loss_G = loss_GAN
        loss_G.backward()
        optimizer_G.step()
        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()
        # Real loss
        pred_real = D(real_B, real_A)
        loss_real = criterion_GAN(pred_real, real_label)
        # Fake loss (detach so no gradient flows back into G)
        pred_fake = D(fake_B.detach(), real_A)
        loss_fake = criterion_GAN(pred_fake, fake_label)
        # Total loss
        loss_D = 0.5 * (loss_real + loss_fake)
        loss_D.backward()
        optimizer_D.step()
        # Print log
        print(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (
                epoch,
                n_epochs,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item()
            )
        )
    # If at sample interval save image
    if epoch == 0 or (epoch + 1) % 5 == 0:
        sample_images(val_dataloader, G, device)
[Epoch 0/200] [Batch 199/200] [D loss: 0.254110] [G loss: 0.400683]
[Epoch 1/200] [Batch 199/200] [D loss: 0.342961] [G loss: 0.359540]
[Epoch 2/200] [Batch 199/200] [D loss: 0.385767] [G loss: 0.340707]
[Epoch 3/200] [Batch 199/200] [D loss: 0.326501] [G loss: 0.188578]
[Epoch 4/200] [Batch 199/200] [D loss: 0.102095] [G loss: 0.752533]
[Epoch 5/200] [Batch 199/200] [D loss: 0.122336] [G loss: 0.466870]
[Epoch 6/200] [Batch 199/200] [D loss: 0.242705] [G loss: 0.309910]
[Epoch 7/200] [Batch 199/200] [D loss: 0.305823] [G loss: 0.485629]
[Epoch 8/200] [Batch 199/200] [D loss: 0.213961] [G loss: 0.397673]
[Epoch 9/200] [Batch 199/200] [D loss: 0.344844] [G loss: 0.432747]
[Epoch 10/200] [Batch 199/200] [D loss: 0.184985] [G loss: 0.340543]
[Epoch 11/200] [Batch 199/200] [D loss: 0.132156] [G loss: 0.509994]
[Epoch 12/200] [Batch 199/200] [D loss: 0.197557] [G loss: 0.312057]
[Epoch 13/200] [Batch 199/200] [D loss: 0.224186] [G loss: 0.181903]
[Epoch 14/200] [Batch 199/200] [D loss: 0.099012] [G loss: 0.685086]
[Epoch 15/200] [Batch 199/200] [D loss: 0.225293] [G loss: 0.548334]
[Epoch 16/200] [Batch 199/200] [D loss: 0.286089] [G loss: 0.723756]
[Epoch 17/200] [Batch 199/200] [D loss: 0.291427] [G loss: 0.749057]
[Epoch 18/200] [Batch 199/200] [D loss: 0.120387] [G loss: 0.559266]
[Epoch 19/200] [Batch 199/200] [D loss: 0.078057] [G loss: 0.525624]
[Epoch 20/200] [Batch 199/200] [D loss: 0.241737] [G loss: 0.502774]
[Epoch 21/200] [Batch 199/200] [D loss: 0.096059] [G loss: 0.497557]
[Epoch 22/200] [Batch 199/200] [D loss: 0.309472] [G loss: 0.092431]
[Epoch 23/200] [Batch 199/200] [D loss: 0.186846] [G loss: 0.214923]
[Epoch 24/200] [Batch 199/200] [D loss: 0.173517] [G loss: 0.269690]
[Epoch 25/200] [Batch 199/200] [D loss: 0.092891] [G loss: 0.587741]
[Epoch 26/200] [Batch 199/200] [D loss: 0.160067] [G loss: 0.451931]
[Epoch 27/200] [Batch 199/200] [D loss: 0.075956] [G loss: 0.541692]
[Epoch 28/200] [Batch 199/200] [D loss: 0.129337] [G loss: 0.469304]
[Epoch 29/200] [Batch 199/200] [D loss: 0.049432] [G loss: 0.772484]
[Epoch 30/200] [Batch 199/200] [D loss: 0.369290] [G loss: 0.426039]
[Epoch 31/200] [Batch 199/200] [D loss: 0.153306] [G loss: 0.507966]
[Epoch 32/200] [Batch 199/200] [D loss: 0.292500] [G loss: 0.671189]
[Epoch 33/200] [Batch 199/200] [D loss: 0.158917] [G loss: 0.467205]
[Epoch 34/200] [Batch 199/200] [D loss: 0.071784] [G loss: 0.620297]
[Epoch 35/200] [Batch 199/200] [D loss: 0.116610] [G loss: 0.813179]
[Epoch 36/200] [Batch 199/200] [D loss: 0.166591] [G loss: 0.762671]
[Epoch 37/200] [Batch 199/200] [D loss: 0.129473] [G loss: 0.673858]
[Epoch 38/200] [Batch 199/200] [D loss: 0.114900] [G loss: 0.363758]
[Epoch 39/200] [Batch 199/200] [D loss: 0.145779] [G loss: 0.318783]
[Epoch 40/200] [Batch 199/200] [D loss: 0.162806] [G loss: 0.379244]
[Epoch 41/200] [Batch 199/200] [D loss: 0.141195] [G loss: 0.496114]
[Epoch 42/200] [Batch 199/200] [D loss: 0.082435] [G loss: 0.576555]
[Epoch 43/200] [Batch 199/200] [D loss: 0.099793] [G loss: 0.556394]
[Epoch 44/200] [Batch 199/200] [D loss: 0.318942] [G loss: 0.077569]
[Epoch 45/200] [Batch 199/200] [D loss: 0.229349] [G loss: 0.271100]
[Epoch 46/200] [Batch 199/200] [D loss: 0.056999] [G loss: 0.610993]
[Epoch 47/200] [Batch 199/200] [D loss: 0.117804] [G loss: 0.388312]
[Epoch 48/200] [Batch 199/200] [D loss: 0.042054] [G loss: 1.032030]
[Epoch 49/200] [Batch 199/200] [D loss: 0.124821] [G loss: 0.436014]
[Epoch 50/200] [Batch 199/200] [D loss: 0.052324] [G loss: 0.650074]
[Epoch 51/200] [Batch 199/200] [D loss: 0.113365] [G loss: 0.607839]
[Epoch 52/200] [Batch 199/200] [D loss: 0.092997] [G loss: 0.953787]
[Epoch 53/200] [Batch 199/200] [D loss: 0.094141] [G loss: 0.536500]
[Epoch 54/200] [Batch 199/200] [D loss: 0.030567] [G loss: 1.092693]
[Epoch 55/200] [Batch 199/200] [D loss: 0.030092] [G loss: 0.798391]
[Epoch 56/200] [Batch 199/200] [D loss: 0.095452] [G loss: 0.916641]
[Epoch 57/200] [Batch 199/200] [D loss: 0.039354] [G loss: 0.780800]
[Epoch 58/200] [Batch 199/200] [D loss: 0.098653] [G loss: 0.679285]
[Epoch 59/200] [Batch 199/200] [D loss: 0.022100] [G loss: 0.843827]
[Epoch 60/200] [Batch 199/200] [D loss: 0.045137] [G loss: 0.824282]
[Epoch 61/200] [Batch 199/200] [D loss: 0.023420] [G loss: 0.946648]
[Epoch 62/200] [Batch 199/200] [D loss: 0.098902] [G loss: 1.133402]
[Epoch 63/200] [Batch 199/200] [D loss: 0.044117] [G loss: 1.303398]
[Epoch 64/200] [Batch 199/200] [D loss: 0.080212] [G loss: 0.455219]
[Epoch 65/200] [Batch 199/200] [D loss: 0.074195] [G loss: 0.522145]
[Epoch 66/200] [Batch 199/200] [D loss: 0.059246] [G loss: 0.886048]
[Epoch 67/200] [Batch 199/200] [D loss: 0.022533] [G loss: 0.761089]
[Epoch 68/200] [Batch 199/200] [D loss: 0.075360] [G loss: 0.715215]
[Epoch 69/200] [Batch 199/200] [D loss: 0.098906] [G loss: 1.055892]
[Epoch 70/200] [Batch 199/200] [D loss: 0.115444] [G loss: 0.392026]
[Epoch 71/200] [Batch 199/200] [D loss: 0.056990] [G loss: 0.651682]
[Epoch 72/200] [Batch 199/200] [D loss: 0.045823] [G loss: 0.773473]
[Epoch 73/200] [Batch 199/200] [D loss: 0.081872] [G loss: 0.871012]
[Epoch 74/200] [Batch 199/200] [D loss: 0.040572] [G loss: 0.923304]
[Epoch 75/200] [Batch 199/200] [D loss: 0.093709] [G loss: 0.780252]
[Epoch 76/200] [Batch 199/200] [D loss: 0.085103] [G loss: 1.123193]
[Epoch 77/200] [Batch 199/200] [D loss: 0.058822] [G loss: 0.696394]
[Epoch 78/200] [Batch 199/200] [D loss: 0.047456] [G loss: 0.927635]
[Epoch 79/200] [Batch 199/200] [D loss: 0.031657] [G loss: 0.914478]
[Epoch 80/200] [Batch 199/200] [D loss: 0.027254] [G loss: 0.701341]
[Epoch 81/200] [Batch 199/200] [D loss: 0.023676] [G loss: 1.058345]
[Epoch 82/200] [Batch 199/200] [D loss: 0.043505] [G loss: 0.873614]
[Epoch 83/200] [Batch 199/200] [D loss: 0.032101] [G loss: 1.026267]
[Epoch 84/200] [Batch 199/200] [D loss: 0.018859] [G loss: 0.928631]
[Epoch 85/200] [Batch 199/200] [D loss: 0.048639] [G loss: 0.658756]
[Epoch 86/200] [Batch 199/200] [D loss: 0.043439] [G loss: 0.900952]
[Epoch 87/200] [Batch 199/200] [D loss: 0.063972] [G loss: 0.885661]
[Epoch 88/200] [Batch 199/200] [D loss: 0.016437] [G loss: 1.086205]
[Epoch 89/200] [Batch 199/200] [D loss: 0.113462] [G loss: 0.358929]
[Epoch 90/200] [Batch 199/200] [D loss: 0.040559] [G loss: 0.599311]
[Epoch 91/200] [Batch 199/200] [D loss: 0.038592] [G loss: 1.244941]
[Epoch 92/200] [Batch 199/200] [D loss: 0.125126] [G loss: 0.893965]
[Epoch 93/200] [Batch 199/200] [D loss: 0.043307] [G loss: 0.658253]
[Epoch 94/200] [Batch 199/200] [D loss: 0.015796] [G loss: 1.042205]
[Epoch 95/200] [Batch 199/200] [D loss: 0.025060] [G loss: 1.062810]
[Epoch 96/200] [Batch 199/200] [D loss: 0.057385] [G loss: 0.865152]
[Epoch 97/200] [Batch 199/200] [D loss: 0.035650] [G loss: 0.651488]
[Epoch 98/200] [Batch 199/200] [D loss: 0.118526] [G loss: 0.486323]
[Epoch 99/200] [Batch 199/200] [D loss: 0.074473] [G loss: 0.479583]
[Epoch 100/200] [Batch 199/200] [D loss: 0.095015] [G loss: 1.005855]
[Epoch 101/200] [Batch 199/200] [D loss: 0.023201] [G loss: 0.887087]
[Epoch 102/200] [Batch 199/200] [D loss: 0.045691] [G loss: 0.600945]
[Epoch 103/200] [Batch 199/200] [D loss: 0.008456] [G loss: 0.984253]
[Epoch 104/200] [Batch 199/200] [D loss: 0.034703] [G loss: 0.908872]
[Epoch 105/200] [Batch 199/200] [D loss: 0.014694] [G loss: 0.774785]
[Epoch 106/200] [Batch 199/200] [D loss: 0.024673] [G loss: 0.776556]
[Epoch 107/200] [Batch 199/200] [D loss: 0.017629] [G loss: 0.921517]
[Epoch 108/200] [Batch 199/200] [D loss: 0.067858] [G loss: 0.496951]
[Epoch 109/200] [Batch 199/200] [D loss: 0.018423] [G loss: 0.788722]
[Epoch 110/200] [Batch 199/200] [D loss: 0.050647] [G loss: 1.166952]
[Epoch 111/200] [Batch 199/200] [D loss: 0.014269] [G loss: 1.062882]
[Epoch 112/200] [Batch 199/200] [D loss: 0.027838] [G loss: 0.830265]
[Epoch 113/200] [Batch 199/200] [D loss: 0.106117] [G loss: 0.462354]
[Epoch 114/200] [Batch 199/200] [D loss: 0.037023] [G loss: 1.001458]
[Epoch 115/200] [Batch 199/200] [D loss: 0.053321] [G loss: 0.964119]
[Epoch 116/200] [Batch 199/200] [D loss: 0.045950] [G loss: 0.804897]
[Epoch 117/200] [Batch 199/200] [D loss: 0.047809] [G loss: 1.042855]
[Epoch 118/200] [Batch 199/200] [D loss: 0.016845] [G loss: 1.110005]
[Epoch 119/200] [Batch 199/200] [D loss: 0.021090] [G loss: 0.985897]
[Epoch 120/200] [Batch 199/200] [D loss: 0.013367] [G loss: 0.876777]
[Epoch 121/200] [Batch 199/200] [D loss: 0.024044] [G loss: 0.886552]
[Epoch 122/200] [Batch 199/200] [D loss: 0.016847] [G loss: 0.801245]
[Epoch 123/200] [Batch 199/200] [D loss: 0.080545] [G loss: 1.302130]
[Epoch 124/200] [Batch 199/200] [D loss: 0.044526] [G loss: 0.799763]
[Epoch 125/200] [Batch 199/200] [D loss: 0.039700] [G loss: 1.287937]
[Epoch 126/200] [Batch 199/200] [D loss: 0.032246] [G loss: 0.756250]
[Epoch 127/200] [Batch 199/200] [D loss: 0.013014] [G loss: 0.999001]
[Epoch 128/200] [Batch 199/200] [D loss: 0.038288] [G loss: 0.970052]
[Epoch 129/200] [Batch 199/200] [D loss: 0.026947] [G loss: 0.966035]
[Epoch 130/200] [Batch 199/200] [D loss: 0.017097] [G loss: 1.018042]
[Epoch 131/200] [Batch 199/200] [D loss: 0.022436] [G loss: 0.793433]
[Epoch 132/200] [Batch 199/200] [D loss: 0.012662] [G loss: 0.935876]
[Epoch 133/200] [Batch 199/200] [D loss: 0.014513] [G loss: 0.743143]
[Epoch 134/200] [Batch 199/200] [D loss: 0.016531] [G loss: 1.129284]
[Epoch 135/200] [Batch 199/200] [D loss: 0.016159] [G loss: 1.021123]
[Epoch 136/200] [Batch 199/200] [D loss: 0.029625] [G loss: 0.759095]
[Epoch 137/200] [Batch 199/200] [D loss: 0.026612] [G loss: 0.797332]
[Epoch 138/200] [Batch 199/200] [D loss: 0.013228] [G loss: 0.776929]
[Epoch 139/200] [Batch 199/200] [D loss: 0.012583] [G loss: 1.063368]
[Epoch 140/200] [Batch 199/200] [D loss: 0.037750] [G loss: 0.834311]
[Epoch 141/200] [Batch 199/200] [D loss: 0.020066] [G loss: 1.099112]
[Epoch 142/200] [Batch 199/200] [D loss: 0.010482] [G loss: 1.197470]
[Epoch 143/200] [Batch 199/200] [D loss: 0.094754] [G loss: 1.204931]
[Epoch 144/200] [Batch 199/200] [D loss: 0.026892] [G loss: 0.673463]
[Epoch 145/200] [Batch 199/200] [D loss: 0.061226] [G loss: 1.317254]
[Epoch 146/200] [Batch 199/200] [D loss: 0.025262] [G loss: 0.769292]
[Epoch 147/200] [Batch 199/200] [D loss: 0.040495] [G loss: 1.239729]
[Epoch 148/200] [Batch 199/200] [D loss: 0.013362] [G loss: 1.051308]
[Epoch 149/200] [Batch 199/200] [D loss: 0.021302] [G loss: 0.730047]
[Epoch 150/200] [Batch 199/200] [D loss: 0.025892] [G loss: 0.814001]
[Epoch 151/200] [Batch 199/200] [D loss: 0.020776] [G loss: 0.982714]
[Epoch 152/200] [Batch 199/200] [D loss: 0.026047] [G loss: 1.213889]
[Epoch 153/200] [Batch 199/200] [D loss: 0.013361] [G loss: 1.006207]
[Epoch 154/200] [Batch 199/200] [D loss: 0.051965] [G loss: 0.551684]
[Epoch 155/200] [Batch 199/200] [D loss: 0.025992] [G loss: 1.195173]
[Epoch 156/200] [Batch 199/200] [D loss: 0.039328] [G loss: 0.642962]
[Epoch 157/200] [Batch 199/200] [D loss: 0.022489] [G loss: 0.885466]
[Epoch 158/200] [Batch 199/200] [D loss: 0.013182] [G loss: 1.080920]
[Epoch 159/200] [Batch 199/200] [D loss: 0.012617] [G loss: 1.099098]
[Epoch 160/200] [Batch 199/200] [D loss: 0.032842] [G loss: 1.059318]
[Epoch 161/200] [Batch 199/200] [D loss: 0.022808] [G loss: 0.761112]
[Epoch 162/200] [Batch 199/200] [D loss: 0.007848] [G loss: 1.067001]
[Epoch 163/200] [Batch 199/200] [D loss: 0.005567] [G loss: 1.051462]
[Epoch 164/200] [Batch 199/200] [D loss: 0.025706] [G loss: 1.148592]
[Epoch 165/200] [Batch 199/200] [D loss: 0.009668] [G loss: 1.005182]
[Epoch 166/200] [Batch 199/200] [D loss: 0.037075] [G loss: 0.700570]
[Epoch 167/200] [Batch 199/200] [D loss: 0.042427] [G loss: 0.589349]
[Epoch 168/200] [Batch 199/200] [D loss: 0.045177] [G loss: 0.554835]
[Epoch 169/200] [Batch 199/200] [D loss: 0.024359] [G loss: 0.690428]
[Epoch 170/200] [Batch 199/200] [D loss: 0.018959] [G loss: 0.780097]
[Epoch 171/200] [Batch 199/200] [D loss: 0.019923] [G loss: 0.929609]
[Epoch 172/200] [Batch 199/200] [D loss: 0.018804] [G loss: 0.794430]
[Epoch 173/200] [Batch 199/200] [D loss: 0.015542] [G loss: 1.019145]
[Epoch 174/200] [Batch 199/200] [D loss: 0.011219] [G loss: 0.911498]
[Epoch 175/200] [Batch 199/200] [D loss: 0.022086] [G loss: 1.036923]
[Epoch 176/200] [Batch 199/200] [D loss: 0.045086] [G loss: 0.868485]
[Epoch 177/200] [Batch 199/200] [D loss: 0.052415] [G loss: 1.125537]
[Epoch 178/200] [Batch 199/200] [D loss: 0.017944] [G loss: 0.801145]
[Epoch 179/200] [Batch 199/200] [D loss: 0.008215] [G loss: 0.998200]
[Epoch 180/200] [Batch 199/200] [D loss: 0.027437] [G loss: 0.762241]
[Epoch 181/200] [Batch 199/200] [D loss: 0.021706] [G loss: 0.888318]
[Epoch 182/200] [Batch 199/200] [D loss: 0.017923] [G loss: 1.255804]
[Epoch 183/200] [Batch 199/200] [D loss: 0.013489] [G loss: 0.859919]
[Epoch 184/200] [Batch 199/200] [D loss: 0.016861] [G loss: 1.131081]
[Epoch 185/200] [Batch 199/200] [D loss: 0.038280] [G loss: 1.020154]
[Epoch 186/200] [Batch 199/200] [D loss: 0.015845] [G loss: 1.005845]
[Epoch 187/200] [Batch 199/200] [D loss: 0.014177] [G loss: 0.855151]
[Epoch 188/200] [Batch 199/200] [D loss: 0.005620] [G loss: 0.985205]
[Epoch 189/200] [Batch 199/200] [D loss: 0.010096] [G loss: 0.989456]
[Epoch 190/200] [Batch 199/200] [D loss: 0.009696] [G loss: 1.126480]
[Epoch 191/200] [Batch 199/200] [D loss: 0.007405] [G loss: 1.073396]
[Epoch 192/200] [Batch 199/200] [D loss: 0.039247] [G loss: 0.804929]
[Epoch 193/200] [Batch 199/200] [D loss: 0.027823] [G loss: 1.100355]
[Epoch 194/200] [Batch 199/200] [D loss: 0.020142] [G loss: 0.842804]
[Epoch 195/200] [Batch 199/200] [D loss: 0.008569] [G loss: 0.983230]
[Epoch 196/200] [Batch 199/200] [D loss: 0.013945] [G loss: 0.900784]
[Epoch 197/200] [Batch 199/200] [D loss: 0.021424] [G loss: 0.746807]
[Epoch 198/200] [Batch 199/200] [D loss: 0.030270] [G loss: 1.077675]
[Epoch 199/200] [Batch 199/200] [D loss: 0.013651] [G loss: 0.832313]
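The answer above compares three generator objectives: L1 only, cGAN only, and both combined. The following is a minimal sketch of how the generator loss differs between them. It assumes `criterion_GAN` is a least-squares (MSE) adversarial loss, and it introduces a hypothetical `criterion_pixelwise` (L1 loss), a hypothetical weight `lambda_pixel = 100` (the value reported in the pix2pix paper, not read from this notebook), and a hypothetical helper `generator_loss`. Random tensors stand in for real images and a 16x16 patch discriminator output.

```python
import torch
import torch.nn as nn

# Assumptions (not taken from this notebook): MSE adversarial loss,
# L1 pixel loss, and lambda_pixel = 100 as in the pix2pix paper.
criterion_GAN = nn.MSELoss()
criterion_pixelwise = nn.L1Loss()
lambda_pixel = 100.0


def generator_loss(pred_fake, fake_B, real_B, mode="both"):
    """Hypothetical helper: generator objective for one of three variants.

    pred_fake : discriminator output D(fake_B, real_A), one score per patch
    mode      : "l1" (L1 only), "gan" (cGAN only) or "both" (cGAN + lambda * L1)
    """
    # Patch-shaped target: the generator wants every patch judged "real"
    real_label = torch.ones_like(pred_fake)
    loss_GAN = criterion_GAN(pred_fake, real_label)
    # Pixel-wise distance to the ground-truth target image
    loss_pixel = criterion_pixelwise(fake_B, real_B)
    if mode == "l1":
        return loss_pixel
    if mode == "gan":
        return loss_GAN
    if mode == "both":
        return loss_GAN + lambda_pixel * loss_pixel
    raise ValueError(f"unknown mode: {mode}")


# Dummy batch: 4 RGB 256x256 images in [-1, 1] and a hypothetical
# PatchGAN output of shape (4, 1, 16, 16).
torch.manual_seed(0)
real_B = torch.rand(4, 3, 256, 256) * 2 - 1
fake_B = torch.rand(4, 3, 256, 256) * 2 - 1
pred_fake = torch.rand(4, 1, 16, 16)

for mode in ("l1", "gan", "both"):
    print(mode, generator_loss(pred_fake, fake_B, real_B, mode).item())
```

With `mode="gan"` the generator only has to fool the patch discriminator, which rewards locally sharp texture but does not tie the output's colors to the ground truth. The usual explanation, given in the pix2pix paper, is that the L1 term supplies this low-frequency constraint, which fits the observation in the answer above.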