DP-GAN: Improving the Training Stability of GANs

It is well known that GAN training is hard and unstable, which often results in blurry images. In this post, several techniques are introduced to improve the training stability of GANs.

Multi-scale discriminator

Method

This technique was first introduced in the paper High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs, where several techniques are proposed to enhance the realism of generated images.

According to the paper, a discriminator with a large receptive field helps differentiate high-resolution real and synthesized images. But this would require either a deeper network or larger convolutional kernels, both of which would increase the network capacity and potentially cause overfitting. Moreover, both choices demand a larger memory footprint for training, which is already a scarce resource for high-resolution image generation.
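
To make the receptive-field argument concrete, here is a minimal sketch (not from the original post) of the standard receptive-field recurrence RF_in = RF_out * s + (k - s), applied to the BaseDiscriminator defined later in this post:

def receptive_field(layers):
    # Receptive field of a conv stack, given (kernel, stride) pairs
    # listed from input to output, via RF_in = RF_out * s + (k - s).
    rf = 1
    for k, s in reversed(layers):
        rf = rf * s + (k - s)
    return rf

# BaseDiscriminator below: four 4x4 stride-2 convs plus a final 4x4 stride-1 conv.
print(receptive_field([(4, 2)] * 4 + [(4, 1)]))  # 94

A single discriminator of this shape covers 94x94 input pixels; run on a 4x-downsampled copy of the image, the same network effectively covers a 376x376 region of the original, which is where the "more global view" described below comes from.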

Therefore, the authors use multi-scale discriminators: three discriminators with an identical network structure that operate at different image scales, referred to as D1, D2 and D3. Specifically, the real and synthesized high-resolution images are downsampled by factors of 2 and 4 to create an image pyramid of 3 scales. D1, D2 and D3 are then trained to differentiate real and synthesized images at the 3 different scales, respectively. Although the discriminators have an identical architecture, the one that operates at the coarsest scale has the largest receptive field. It has a more global view of the image and can guide the generator to produce globally consistent images. On the other hand, the discriminator at the finest scale encourages the generator to produce finer details. This also makes training a coarse-to-fine generator easier, since extending a low-resolution model to a higher resolution only requires adding a discriminator at the finest level, rather than retraining from scratch. Without the multi-scale discriminators, many repeated patterns often appear in the generated images.
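
As a minimal sketch of this pyramid construction (hypothetical code; the full implementation below folds the same average-pool downsampling into the multi-scale discriminator's forward pass):

import torch.nn.functional as F

def image_pyramid(img, num_scales=3):
    # Full resolution first, then /2 and /4, giving the 3-scale pyramid.
    pyramid = [img]
    for _ in range(num_scales - 1):
        img = F.avg_pool2d(img, kernel_size=3, stride=2, padding=1,
                           count_include_pad=False)
        pyramid.append(img)
    return pyramid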

PyTorch

Here, BaseDiscriminator implements a single discriminator network. MultipleDiscriminator then receives the parameter num_D and instantiates that many copies of BaseDiscriminator, downsampling the input between them. Finally, the helper function define_D constructs the whole multi-scale discriminator.

import torch
import torch.nn as nn

def define_D(in_channels=3, num_D=3):
    net_D = MultipleDiscriminator(in_channels, num_D)
    return net_D

class MultipleDiscriminator(nn.Module):
    def __init__(self, in_channels=3, num_D=3):
        super(MultipleDiscriminator, self).__init__()
        self.num_D = num_D
        # One identical discriminator per scale.
        for i in range(num_D):
            netD = BaseDiscriminator(in_channels)
            setattr(self, 'layer' + str(i), netD.model)
        # Downsampling between scales builds the image pyramid on the fly.
        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1],
                                       count_include_pad=False)

    def singleD_forward(self, model, input):
        return [model(input)]

    def forward(self, input):
        num_D = self.num_D
        result = []
        input_downsampled = input
        for i in range(num_D):
            # layer{num_D-1} sees the full-resolution input; the input is
            # then downsampled before the next discriminator runs.
            model = getattr(self, 'layer' + str(num_D - 1 - i))
            result.append(self.singleD_forward(model, input_downsampled))
            if i != (num_D - 1):
                input_downsampled = self.downsample(input_downsampled)
        return result

class BaseDiscriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(BaseDiscriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, normalization=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalization:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # in_channels*2 because the image and its condition image are
        # concatenated along the channel dimension before being fed in.
        self.model = nn.Sequential(
            *discriminator_block(in_channels * 2, 64, normalization=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1, bias=False)
        )

    def forward(self, img_A, img_B):
        # Concatenate image and condition image by channels to produce input.
        # Note: MultipleDiscriminator calls self.model directly with a
        # pre-concatenated input; this forward is for standalone use.
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)
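
As a quick sanity check (a hypothetical snippet, assuming a 256x256 input), we can feed a random concatenated pair through the multi-scale discriminator and inspect the per-scale output shapes:

disc = define_D(in_channels=3, num_D=3)
# 6 channels: a 3-channel image concatenated with its 3-channel condition.
x = torch.randn(1, 6, 256, 256)
for out in disc(x):
    print(out[-1].shape)
# torch.Size([1, 1, 16, 16])  full resolution
# torch.Size([1, 1, 8, 8])    1/2 resolution
# torch.Size([1, 1, 4, 4])    1/4 resolution

Each scale produces a patch map rather than a single score, so every spatial location is classified as real or fake separately.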

We also need to customize our own loss, because each discriminator has its own output and corresponding MSE loss, and we want the sum of the MSE losses over all discriminators.

from torch.autograd import Variable

class GANLoss(nn.Module):
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        # LSGAN replaces the cross-entropy objective with least squares.
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        # Lazily (re)create a label tensor matching the prediction's size.
        if target_is_real:
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var
        return target_tensor

    def __call__(self, input, target_is_real):
        if isinstance(input[0], list):
            # Multi-scale case: sum the loss over all discriminators.
            loss = 0
            for input_i in input:
                pred = input_i[-1]
                target_tensor = self.get_target_tensor(pred, target_is_real)
                loss += self.loss(pred, target_tensor)
            return loss
        else:
            target_tensor = self.get_target_tensor(input[-1], target_is_real)
            return self.loss(input[-1], target_tensor)
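
For example (a hypothetical check, reusing disc and x from the earlier snippet), the multi-scale branch sums the per-discriminator losses:

criterion = GANLoss(use_lsgan=True)
pred = disc(x)                 # list with one output list per scale
loss = criterion(pred, True)   # MSE against all-ones targets, summed over scales
print(loss.item())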

Once the above pieces are in place, we can wire them together.

# opt, real_A, real_B, fake_B and weights_init_normal come from the
# surrounding training script.
def discriminate(input_label, test_image):
    input_concat = torch.cat((input_label, test_image.detach()), dim=1)
    return discriminator(input_concat)

cuda = True if torch.cuda.is_available() else False
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

criterion_GAN = GANLoss(use_lsgan=not opt.no_lsgan, tensor=Tensor)
discriminator = define_D(num_D=opt.num_discriminator)
if cuda:
    discriminator.cuda()
discriminator.apply(weights_init_normal)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# ---------------------
#  Train Discriminator
# ---------------------
optimizer_D.zero_grad()

# Real loss
pred_real = discriminate(real_B, real_A)
loss_real = criterion_GAN(pred_real, True)

# Fake loss
pred_fake = discriminate(fake_B.detach(), real_A)
loss_fake = criterion_GAN(pred_fake, False)

# Total loss
loss_D = 0.5 * (loss_real + loss_fake)
loss_D.backward()
optimizer_D.step()
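
The generator update is symmetric. A minimal sketch (assuming an optimizer_G for the generator and omitting any non-GAN loss terms such as feature matching):

optimizer_G.zero_grad()

# The generator wants the multi-scale discriminator to label its output
# as real; fake_B is NOT detached here so gradients reach the generator.
pred_fake_G = discriminator(torch.cat((fake_B, real_A), dim=1))
loss_G = criterion_GAN(pred_fake_G, True)

loss_G.backward()
optimizer_G.step()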

Progressive growing of GANs

See the original paper, Progressive Growing of GANs for Improved Quality, Stability, and Variation, and the pytorch-progan implementation for details.