GAN training is known to be hard and unstable, which often results in blurry images. In this post, several techniques are introduced to improve the training stability of GANs.
Multi-scale discriminator
Method
This technique was first introduced in the paper High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs, where several techniques are proposed to improve the realism of generated images.
According to the paper, a discriminator with a large receptive field is needed to differentiate high-resolution real and synthesized images. But this would require either a deeper network or larger convolutional kernels, both of which would increase the network capacity and potentially cause overfitting. Also, both choices demand a larger memory footprint for training, which is already a scarce resource for high-resolution image generation.
Therefore, the authors use multi-scale discriminators: 3 discriminators that have an identical network structure but operate at different image scales.
We will refer to the discriminators as D1, D2 and D3. Specifically, we downsample the real and synthesized high-resolution images by a factor of 2 and 4 to create an image pyramid of 3 scales. The discriminators D1, D2 and D3 are then trained to differentiate real and synthesized images at the 3 different scales, respectively. Although the discriminators have an identical architecture, the one that operates at the coarsest scale has the largest receptive field. It has a more global view of the image and can guide the generator to generate globally consistent images. On the other hand, the discriminator at the finest scale encourages the generator to produce finer details. This also makes training the coarse-to-fine generator easier, since extending a low-resolution model to a higher resolution only requires adding a discriminator at the finest level, rather than retraining from scratch. Without the multi-scale discriminators, we observe that many repeated patterns often appear in the generated images.
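To make the image pyramid concrete, here is a minimal sketch of downsampling a batch by factors of 2 and 4 with average pooling. The function name and pooling settings are my own choice; the paper only specifies the downsampling factors.

```python
import torch
import torch.nn.functional as F

def build_image_pyramid(img, num_scales=3):
    """Downsample an image by factors of 1, 2 and 4 to form a 3-scale pyramid."""
    pyramid = [img]
    for _ in range(num_scales - 1):
        # Average pooling with stride 2 halves the spatial resolution each time.
        img = F.avg_pool2d(img, kernel_size=3, stride=2, padding=1, count_include_pad=False)
        pyramid.append(img)
    return pyramid

# Example: a 1024x2048 image yields scales 1024x2048, 512x1024 and 256x512.
x = torch.randn(1, 3, 1024, 2048)
scales = build_image_pyramid(x)
print([tuple(s.shape[-2:]) for s in scales])  # [(1024, 2048), (512, 1024), (256, 512)]
```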
PyTorch
Here, BaseDiscriminator implements a single discriminator network, which is identical across scales. MultipleDiscriminator receives the parameter num_of_discriminator and builds the multi-scale network structure from that many copies of BaseDiscriminator. We can then call the function define_D directly to use the method, as sketched below.
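A minimal sketch of such an implementation, keeping the names BaseDiscriminator, MultipleDiscriminator, num_of_discriminator and define_D from the text. The internal layer configuration (a PatchGAN-style stack of strided convolutions) is my own assumption, not necessarily the original post's code.

```python
import torch.nn as nn

class BaseDiscriminator(nn.Module):
    """A single discriminator; every scale shares this architecture."""
    def __init__(self, input_nc=3, ndf=64, n_layers=3):
        super().__init__()
        layers = [nn.Conv2d(input_nc, ndf, 4, stride=2, padding=2),
                  nn.LeakyReLU(0.2, True)]
        nf = ndf
        for _ in range(1, n_layers):
            nf_next = min(nf * 2, 512)
            layers += [nn.Conv2d(nf, nf_next, 4, stride=2, padding=2),
                       nn.InstanceNorm2d(nf_next),
                       nn.LeakyReLU(0.2, True)]
            nf = nf_next
        layers += [nn.Conv2d(nf, 1, 4, stride=1, padding=2)]  # 1-channel patch predictions
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

class MultipleDiscriminator(nn.Module):
    """Holds num_of_discriminator identical discriminators, one per image scale."""
    def __init__(self, input_nc=3, ndf=64, n_layers=3, num_of_discriminator=3):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [BaseDiscriminator(input_nc, ndf, n_layers) for _ in range(num_of_discriminator)])
        # Average pooling builds the image pyramid on the fly.
        self.downsample = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)

    def forward(self, x):
        outputs = []
        for i, D in enumerate(self.discriminators):
            outputs.append(D(x))
            if i != len(self.discriminators) - 1:
                x = self.downsample(x)  # halve the resolution for the next discriminator
        return outputs  # one prediction map per scale

def define_D(input_nc=3, ndf=64, n_layers=3, num_of_discriminator=3):
    """Convenience constructor for the multi-scale discriminator."""
    return MultipleDiscriminator(input_nc, ndf, n_layers, num_of_discriminator)
```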
We also need to customize our own MSE loss, because each discriminator has its own output and corresponding MSE loss, and we want the sum of the MSE losses over all discriminators.
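A sketch of such a loss, written as an LSGAN-style criterion that sums the per-discriminator MSE terms. The class name MultiScaleMSELoss is hypothetical; it simply consumes the list of outputs returned by the MultipleDiscriminator sketch above.

```python
import torch
import torch.nn as nn

class MultiScaleMSELoss(nn.Module):
    """Sums an MSE (least-squares GAN) term over every discriminator's output."""
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, predictions, target_is_real):
        # `predictions` is the list returned by MultipleDiscriminator.forward,
        # one prediction map per scale.
        loss = 0.0
        for pred in predictions:
            target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
            loss = loss + self.mse(pred, target)
        return loss
```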
Once the above pieces are in place, we can put them together in the training loop.
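A short usage sketch, assuming the define_D and MultiScaleMSELoss definitions from above are in scope; the batch data here is a placeholder rather than a real dataset.

```python
import torch

netD = define_D(input_nc=3, num_of_discriminator=3)
criterion = MultiScaleMSELoss()

real_images = torch.randn(4, 3, 256, 256)   # real batch (placeholder data)
fake_images = torch.randn(4, 3, 256, 256)   # generator output (placeholder data)

# Discriminator update: real images should score as real, fakes as fake, at every scale.
loss_D = criterion(netD(real_images), True) + criterion(netD(fake_images.detach()), False)

# Generator update: try to make all discriminators classify the fakes as real.
loss_G = criterion(netD(fake_images), True)
```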