Here is the complete code for that test cell:
```python
# UNIT TEST
def test_gen_reasonable(num_images=10):
    gen = torch.zeros_like
    disc = lambda x, y: torch.ones(len(x), 1)
    real = None
    condition = torch.ones(num_images, 3, 10, 10)
    adv_criterion = torch.mul
    recon_criterion = lambda x, y: torch.tensor(0)
    lambda_recon = 0
    assert get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon).sum() == num_images
    disc = lambda x, y: torch.zeros(len(x), 1)
    assert torch.abs(get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon)).sum() == 0
    adv_criterion = lambda x, y: torch.tensor(0)
    recon_criterion = lambda x, y: torch.abs(x - y).max()
    real = torch.randn(num_images, 3, 10, 10)
    lambda_recon = 2
    gen = lambda x: real + 1
    assert torch.abs(get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon) - 2) < 1e-4
    adv_criterion = lambda x, y: (x + y).max() + x.max()
    assert torch.abs(get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon) - 3) < 1e-4
test_gen_reasonable()
print("Success!")
```
You can see that the test cell is playing some pretty sophisticated games, using the fact that the function under test takes functions as some of its arguments. It looks like your code is failing the very first test. Notice that in that case `real` is in fact set to `None` and the `adv_criterion` function is `torch.mul`.
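To see why that first case trips up code that hands `real` to `adv_criterion`, here is a small sketch that reproduces just that one call with the test's stubs. It assumes your code calls something like `adv_criterion(disc_output, real)`; the name `disc_fake_hat` is my own placeholder for the discriminator's output, and the all-ones target is an assumption based on the usual GAN generator loss, not something the test cell itself states.

```python
import torch

# Stubs copied from the first case of the unit test
num_images = 10
disc_fake_hat = torch.ones(num_images, 1)  # what the stub discriminator returns
real = None                                # the test deliberately sets real to None
adv_criterion = torch.mul

# Passing real "as is": torch.mul has no overload that accepts None, so it raises
try:
    adv_criterion(disc_fake_hat, real)
    failed = False
except (TypeError, RuntimeError):
    failed = True

# Assumed alternative: compare the discriminator output against an all-ones target
adv_loss = adv_criterion(disc_fake_hat, torch.ones_like(disc_fake_hat))
print(failed, adv_loss.sum().item())
```

With `real` set to `None`, the call fails before any loss is computed, while the all-ones comparison sums to 10, which is exactly what the first assertion checks against `num_images`.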
It’s fine to use pseudo-code to describe your solution. Your first two steps agree with mine. I think your mistake is passing `real` “as is” to the `adv_criterion` function. My reading of the instructions is that `real` should instead be fed to the reconstruction criterion function. My code from that line on looks quite a bit different from yours.