C1W1_Your_First_GAN - Test functions

I found the test functions used for checking our implementation very interesting. They act as a sort of unit test, and I think you don't see this very often in DL projects. However, I am not sure how they were written, or how the exact values in the assertions were chosen. For example:

assert torch.all(torch.abs(get_disc_loss(gen, disc, criterion, real, num_images, z_dim, 'cpu').mean() - 5) < 1e-5)

Why subtract 5 from the mean, and why should the result be less than 1e-5? Is there a reason I am missing, or are these values arbitrary and I am reading too much into this?

Thank you!

Hi @Andreea_Elena_Sandu,
Great to see you curious about the test functions! I agree: they are interesting and very handy for checking your implementation.

For tests like your example, where you have something like:
assert abs(<some_val> - 5) < 1e-5
the idea is that the test expects <some_val> to be 5. Floating-point arithmetic can introduce small rounding errors, so instead of asserting <some_val> == 5, the test checks whether the absolute difference between <some_val> and 5 is smaller than a very small tolerance, in this case 1e-5. The absolute value matters: without it, any value below 5 would also pass. (As you do more assignments, you'll notice that the tests frequently use 1e-5 for this purpose.)
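Here is a minimal Python sketch of the same idea. It uses plain floats rather than PyTorch tensors, and the helper name `close_enough` is hypothetical, not part of the course code. It shows why exact equality is fragile and why the absolute value (the `torch.abs(...)` in the original assertion) is needed:

```python
import math

def close_enough(value, expected, tol=1e-5):
    """Return True if value is within tol of expected (absolute-difference check)."""
    return abs(value - expected) < tol

# Floating-point arithmetic is rarely exact: 0.1 + 0.2 is stored as 0.30000000000000004.
computed = 0.1 + 0.2
print(computed == 0.3)              # False: exact comparison fails
print(close_enough(computed, 0.3))  # True: tolerance-based comparison passes

# Why abs() matters: without it, any value *below* the target passes.
too_small = 3.0
print(too_small - 5 < 1e-5)          # True: a false pass
print(close_enough(too_small, 5))    # False: correctly rejected

# The standard library offers the same check (note: math.isclose uses <= rather than <).
print(math.isclose(computed, 0.3, rel_tol=0.0, abs_tol=1e-5))  # True
```

For tensors, PyTorch provides built-in tolerance-based comparisons such as `torch.allclose` and `torch.testing.assert_close`, which check every element at once.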
