W4L2. About train_loss and test_loss

In the main training loop, I noticed that we calculate the training loss “manually” instead of instantiating a tf.keras.metrics.Mean() as we do for test_loss. This section of a TF tutorial (very similar to our lab) explains why we shouldn’t use tf.keras.metrics.Mean() for train_loss. However, this leaves me confused on two points:

  • Why can we use tf.keras.metrics.Mean() for test_loss?
  • In test_step(), why do we call loss_object(labels, predictions) directly instead of compute_loss(labels, predictions) as in train_step()?

I think this has to do with the fact that test images do not go through training, so there is no backward propagation.

I apologize, but I am still confused on both points. How are they linked to whether we are training or not?

Hi @mchanchee,

I think the idea @gent.spah is referring to is best described in the “Define the loss function” section of the tutorial you found, here.

The basic idea is that for training, when we do backprop, the strategy will sum the gradient values from each replica, so we need to make sure the loss value we pass to tape.gradient() in train_step() takes this into account. That means that when calculating the loss for training, we need to divide by GLOBAL_BATCH_SIZE rather than by the size of the replica’s batch. compute_loss() does this by calling tf.nn.compute_average_loss() with the appropriate GLOBAL_BATCH_SIZE. You’ll also notice that later, in distributed_train_step(), we call
strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
to account for the fact that we divided by GLOBAL_BATCH_SIZE in the earlier loss calculation, so we need to sum the losses from each replica to get an accurate loss.
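To make that concrete, here is a minimal sketch of the training side of this pattern, following the names in the tutorial (the toy model, optimizer, and GLOBAL_BATCH_SIZE value here are placeholders of my own, not from the lab):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 64  # per-replica batch size * number of replicas

with strategy.scope():
    # Placeholder model/optimizer just so the sketch runs.
    model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
    optimizer = tf.keras.optimizers.Adam()

    # reduction=NONE: we handle the averaging ourselves, so we can
    # divide by the *global* batch size rather than the replica's.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def compute_loss(labels, predictions):
        per_example_loss = loss_object(labels, predictions)
        return tf.nn.compute_average_loss(
            per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

def train_step(inputs):
    images, labels = inputs
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        # Loss is scaled by GLOBAL_BATCH_SIZE before tape.gradient()
        loss = compute_loss(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

@tf.function
def distributed_train_step(dataset_inputs):
    per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
    # Each replica already divided by GLOBAL_BATCH_SIZE, so summing the
    # per-replica losses recovers the true global-batch average.
    return strategy.reduce(tf.distribute.ReduceOp.SUM,
                           per_replica_losses, axis=None)
```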

For the test loss, though, we don’t need to calculate gradients since we’re not training, so there is no need to divide the loss by GLOBAL_BATCH_SIZE. That’s why we don’t call compute_loss() there, and also why we don’t need to call strategy.reduce() in distributed_test_step(). A plain tf.keras.metrics.Mean() is enough: it just averages the per-example loss values across whatever batches each replica sees.
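For comparison, a sketch of the test side under the same assumptions as the training sketch above:

```python
with strategy.scope():
    test_loss = tf.keras.metrics.Mean(name='test_loss')

def test_step(inputs):
    images, labels = inputs
    predictions = model(images, training=False)
    # No gradients here, so no GLOBAL_BATCH_SIZE scaling is needed:
    # loss_object is called directly, and the Mean metric simply
    # averages the per-example losses it receives.
    t_loss = loss_object(labels, predictions)
    test_loss.update_state(t_loss)

@tf.function
def distributed_test_step(dataset_inputs):
    return strategy.run(test_step, args=(dataset_inputs,))
```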

I hope this helps.
Thanks for clarifying, Wendy. I was not sure why that was so.

I understand much better now. Thank you very much, @Wendy!

Thank you also, @gent.spah 🙂