Hi,
In the test functions for Q6:
In test_disc_reasonable, I saw “assert torch.all(torch.abs(disc_loss.mean() - 0.5) < 1e-5)”, which suggests that disc_loss should be a tensor.
In test_disc_loss, I saw “assert (disc_loss - 0.68).abs() < 0.05”, which suggests that disc_loss should be a scalar.
In the test functions for Q7:
In test_gen_reasonable, I saw “torch.all(torch.abs(gen_loss_tensor) < 1e-5)”, which suggests that gen_loss should be a tensor.
In test_gen_loss, I saw “assert (gen_loss - 0.7).abs() < 0.1”, which suggests that gen_loss should be a scalar.
I find it hard to choose an output shape that satisfies both test functions.
Can I have some hints on this?
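For reference, here is a minimal check, assuming the losses are PyTorch tensors, of how both assertion forms behave on a 0-dimensional (scalar) tensor. The variable name loss and the value 0.5 are placeholders, not the assignment's actual values.

```python
import torch

# Hypothetical stand-in for disc_loss / gen_loss: a 0-dimensional tensor.
loss = torch.tensor(0.5)
assert loss.dim() == 0  # a scalar tensor has no dimensions

# Form used in test_disc_reasonable: .mean() of a 0-dim tensor is the same value.
assert torch.all(torch.abs(loss.mean() - 0.5) < 1e-5)

# Form used in test_disc_loss: subtracting a float and calling .abs() also works,
# and a single-element boolean tensor can be used directly in an assert.
assert (loss - 0.5).abs() < 0.05
```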
In addition, criterion is not defined in test_disc_loss of Q6.
I keep failing “assert (disc_loss - 0.68).abs() < 0.05” in Q6 and “assert (gen_loss - 0.7).abs() < 0.1” in Q7.
Any hints will be appreciated!