I am stuck on the assertion tests, specifically the one that checks whether the generator's weights are changing.
AssertionError                            Traceback (most recent call last)
Input In [18], in <cell line: 45>()
     41     assert not torch.all(torch.eq(old_weight, new_weight))
     44 test_gen_reasonable(10)
---> 45 test_gen_loss(18)
     46 print("Success!")

Input In [18], in test_gen_loss(num_images)
     39 gen_opt.step()
     40 new_weight = gen.gen[0][0].weight
---> 41 assert not torch.all(torch.eq(old_weight, new_weight))

AssertionError:
I am using PyTorch version '1.13.1+cu116'.
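For reference, the failing check appears to follow this pattern (a minimal, self-contained sketch reconstructed from the traceback; the stand-in model and loss below are hypothetical, not the notebook's actual code):

import torch
from torch import nn

# Stand-in generator so that gen.gen[0][0] indexes a weight, as in the
# traceback; the real architecture in the assignment is different.
class Gen(nn.Module):
    def __init__(self):
        super().__init__()
        self.gen = nn.Sequential(nn.Sequential(nn.Linear(10, 10)))
    def forward(self, x):
        return self.gen(x)

gen = Gen()
gen_opt = torch.optim.SGD(gen.parameters(), lr=0.1)

# Snapshot a weight, run one generator update, then assert the weight moved.
old_weight = gen.gen[0][0].weight.detach().clone()
loss = gen(torch.randn(4, 10)).sum()  # stand-in for the generator loss
loss.backward()
gen_opt.step()
new_weight = gen.gen[0][0].weight
assert not torch.all(torch.eq(old_weight, new_weight))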
That assertion is checking that the generator's weights change when you use the generator loss function. So how could that not happen? Did you use any "detach" operations in your generator loss function? Note that in the discriminator loss function we detach the generator's output, so that we don't bother computing generator gradients. But it's key to understand that the two situations are asymmetric: the generator loss goes through the discriminator, so we can't detach anything there, while the discriminator's loss does not actually require the generator's gradients.
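To make that concrete, here is a minimal, self-contained sketch of where the detach belongs (the tiny models and names below are illustrative stand-ins, not the assignment's actual code):

import torch
from torch import nn

z_dim, num_images = 16, 8
gen = nn.Sequential(nn.Linear(z_dim, 32), nn.ReLU(), nn.Linear(32, 64))
disc = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 1))
criterion = nn.BCEWithLogitsLoss()
real = torch.randn(num_images, 64)  # stand-in for a batch of real images

# Discriminator loss: detach the fakes so no gradients flow back into
# the generator, since only the discriminator is being updated here.
fake = gen(torch.randn(num_images, z_dim))
disc_fake_pred = disc(fake.detach())
disc_real_pred = disc(real)
disc_loss = (criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
             + criterion(disc_real_pred, torch.ones_like(disc_real_pred))) / 2

# Generator loss: no detach anywhere. The gradient has to flow through
# the discriminator and back into the generator's weights.
fake = gen(torch.randn(num_images, z_dim))
gen_fake_pred = disc(fake)  # note: no .detach() here
gen_loss = criterion(gen_fake_pred, torch.ones_like(gen_fake_pred))
gen_loss.backward()
assert gen[0].weight.grad is not None  # generator gradients exist

If you detach the fake images inside the generator loss, backward() still runs (it populates the discriminator's gradients), but the generator's .grad stays None, so gen_opt.step() changes nothing and the weight-change assertion fails, which matches the error above.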
Here’s a thread from a while back that discusses this in a bit more detail.