In the code segment that updates the critic weights, we call:
crit_loss.backward(retain_graph=True)
Why do we need to retain the graph at this point?
For more context, here is the code segment:
for _ in range(crit_repeats):
    ### Update critic ###
    crit_opt.zero_grad()
    fake_noise = get_noise(cur_batch_size, z_dim, device=device)
    fake = gen(fake_noise)
    crit_fake_pred = crit(fake.detach())
    crit_real_pred = crit(real)

    epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)
    gradient = get_gradient(crit, real, fake.detach(), epsilon)
    gp = gradient_penalty(gradient)
    crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)

    # Keep track of the average critic loss in this batch
    mean_iteration_critic_loss += crit_loss.item() / crit_repeats
    # Update gradients
    crit_loss.backward(retain_graph=True)
    # Update optimizer
    crit_opt.step()
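
To make the question concrete, here is a minimal sketch of what I understand retain_graph to change. It is separate from the training loop above: the helper names backward_twice and fresh_graph_each_step are made up for illustration, and the toy loss (x * x).sum() stands in for crit_loss. It only shows the general autograd behaviour and does not settle whether the critic loop itself ever reuses a graph.

```python
import torch


def backward_twice(retain_first: bool) -> bool:
    """Return True if a second backward() over the same graph succeeds."""
    x = torch.ones(3, requires_grad=True)
    loss = (x * x).sum()  # the multiply saves x for use in the backward pass
    # First backward: saved tensors are freed afterwards unless retained.
    loss.backward(retain_graph=retain_first)
    try:
        loss.backward()  # second pass over the very same graph
        return True
    except RuntimeError:
        # Raised when the saved tensors were already freed.
        return False


def fresh_graph_each_step(steps: int = 3) -> bool:
    """Rebuild the loss every iteration and call backward() without retaining."""
    w = torch.ones(3, requires_grad=True)
    for _ in range(steps):
        loss = (w * w).sum()  # a new graph is built on every iteration
        loss.backward()       # no retain_graph needed for a one-off graph
        w.grad.zero_()
    return True


if __name__ == "__main__":
    print("retain_graph=True, second backward ok:", backward_twice(True))
    print("retain_graph=False, second backward ok:", backward_twice(False))
    print("fresh graph per step ok:", fresh_graph_each_step())
```

If this sketch is right, retaining is only required when the same graph is traversed again after backward(), which is why I am unsure what it buys in a loop where the forward pass is recomputed on every iteration.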