Just as we set discriminator.trainable to True and False in phases 1 and 2 respectively, why don’t we set generator.trainable to False and True in these phases? When we train the discriminator, do we also want the weights of the generator to be affected?
I do understand that it is more important to disable the discriminator’s learning when we pass fake data off as real. But it seems to me that the generator may also learn the opposite of what it should when we train the discriminator, because its goal is the opposite, namely to succeed in fooling the discriminator.
def train_gan(gan, dataset, random_normal_dimensions, n_epochs=50):
    """ Defines the two-phase training loop of the GAN
    Args:
      gan -- the GAN model which has the generator and discriminator
      dataset -- the training set of real images
      random_normal_dimensions -- dimensionality of the input to the generator
      n_epochs -- number of epochs
    """
    # get the two sub networks from the GAN model
    generator, discriminator = gan.layers
    # start loop
    for epoch in range(n_epochs):
        print("Epoch {}/{}".format(epoch + 1, n_epochs))
        for real_images in dataset:
            # infer batch size from the training batch
            batch_size = real_images.shape[0]
            # Train the discriminator - PHASE 1
            # Create the noise
            noise = tf.random.normal(shape=[batch_size, random_normal_dimensions])
            # Use the noise to generate fake images
            fake_images = generator(noise)
            # Create a list by concatenating the fake images with the real ones
            mixed_images = tf.concat([fake_images, real_images], axis=0)
            # Create the labels for the discriminator
            # 0 for the fake images
            # 1 for the real images
            discriminator_labels = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
            # Ensure that the discriminator is trainable
            discriminator.trainable = True
            # Use train_on_batch to train the discriminator with the mixed images and the discriminator labels
            discriminator.train_on_batch(mixed_images, discriminator_labels)
            # Train the generator - PHASE 2
            # create a batch of noise input to feed to the GAN
            noise = tf.random.normal(shape=[batch_size, random_normal_dimensions])
            # label all generated images to be "real"
            generator_labels = tf.constant([[1.]] * batch_size)
            # Freeze the discriminator
            discriminator.trainable = False
            # Train the GAN on the noise with the labels all set to be true
            gan.train_on_batch(noise, generator_labels)
        # plot the fake images used to train the discriminator
        plot_multiple_images(fake_images, 8)
        plt.show()
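
One way to probe the concern above is to snapshot the generator's weights, run a single phase 1 step, and compare. The sketch below is only a rough check, not the course's code: the two tiny Dense models, the sizes, and the random stand-in for real_images are all assumptions made up for illustration. As far as I can tell, discriminator.train_on_batch only updates the variables that belong to the discriminator, and fake_images reaches it as plain input data, so the generator should come out unchanged without setting generator.trainable at all.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Hypothetical toy sizes, chosen only to keep the check fast
random_normal_dimensions = 4
image_dimensions = 6
batch_size = 8

# Tiny stand-in models (assumptions, not the models from the post)
generator = keras.Sequential([
    keras.Input(shape=(random_normal_dimensions,)),
    keras.layers.Dense(image_dimensions, activation="tanh"),
])
discriminator = keras.Sequential([
    keras.Input(shape=(image_dimensions,)),
    keras.layers.Dense(1, activation="sigmoid"),
])
discriminator.compile(loss="binary_crossentropy", optimizer="sgd")

# Snapshot both sets of weights before the phase 1 step
gen_before = [w.copy() for w in generator.get_weights()]
disc_before = [w.copy() for w in discriminator.get_weights()]

# Phase 1, following the post: the generator only produces the input batch
noise = tf.random.normal(shape=[batch_size, random_normal_dimensions])
fake_images = generator(noise)
# Random data standing in for a batch of real images
real_images = tf.random.normal(shape=[batch_size, image_dimensions])
mixed_images = tf.concat([fake_images, real_images], axis=0)
discriminator_labels = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
discriminator.train_on_batch(mixed_images, discriminator_labels)

# Compare the weights after the step with the snapshots
gen_unchanged = all(
    np.array_equal(before, after)
    for before, after in zip(gen_before, generator.get_weights())
)
disc_changed = any(
    not np.array_equal(before, after)
    for before, after in zip(disc_before, discriminator.get_weights())
)
print("generator unchanged:", gen_unchanged)
print("discriminator changed:", disc_changed)
```

Phase 2 is different because gan.train_on_batch runs on the combined model, which contains both sub networks, and that is where freezing the discriminator matters.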