In the main training loop, I noticed we are “manually” calculating the training loss instead of instantiating a tf.keras.metrics.Mean() as we do for test_loss. This section of a TF tutorial (very similar to our lab) explains why we shouldn’t use tf.keras.metrics.Mean() for train_loss. However, this leaves me confused on two points:
Why can we use tf.keras.metrics.Mean() for test_loss?
In function test_step(), why do we directly call loss_object(labels, predictions) instead of doing compute_loss(labels, predictions) like in train_step()?
I think the idea @gent.spah is referring to is best described in the “Define the loss function” section of the tutorial you found, here.
The basic idea is that during training, when we do backprop, the strategy sums the gradient values from each replica, so the loss value we pass to tape.gradient() in train_step() needs to take this into account. That means when calculating the training loss, we need to divide by GLOBAL_BATCH_SIZE rather than by the size of the replica’s batch. compute_loss() does this by calling tf.nn.compute_average_loss() with the appropriate GLOBAL_BATCH_SIZE. You’ll also notice that later, in distributed_train_step(), we call strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None) to account for the fact that we divided by GLOBAL_BATCH_SIZE in the earlier loss calculation: summing the per-replica losses then gives an accurate loss for the whole global batch.
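To make that arithmetic concrete, here’s a small pure-Python sketch (no TF needed; the replica losses and batch split are made-up numbers) of what tf.nn.compute_average_loss() plus strategy.reduce(SUM) accomplish together: dividing each replica’s summed loss by GLOBAL_BATCH_SIZE and then summing across replicas recovers the mean loss over the full global batch.

```python
# Hypothetical per-example losses for one global batch of 8 examples,
# split across 2 replicas (4 examples each). Values are made up.
replica_losses = [
    [0.2, 0.4, 0.6, 0.8],  # replica 0
    [1.0, 1.2, 1.4, 1.6],  # replica 1
]
GLOBAL_BATCH_SIZE = sum(len(r) for r in replica_losses)  # 8

# What compute_loss() does on each replica: sum the replica's per-example
# losses and divide by GLOBAL_BATCH_SIZE (NOT by the replica's batch size).
per_replica_loss = [sum(r) / GLOBAL_BATCH_SIZE for r in replica_losses]

# What strategy.reduce(tf.distribute.ReduceOp.SUM, ...) then does:
train_loss = sum(per_replica_loss)

# This matches the plain mean over the whole global batch.
all_losses = [x for r in replica_losses for x in r]
assert abs(train_loss - sum(all_losses) / len(all_losses)) < 1e-9

# Dividing by the replica batch size instead would inflate the result,
# since the SUM reduction would then add up two full per-replica means.
wrong = sum(sum(r) / len(r) for r in replica_losses)
assert wrong > train_loss
```

With these numbers the correct loss works out to 0.9, while dividing by the replica batch size and summing would give 1.8, i.e. off by a factor equal to the number of replicas.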
For the test loss, though, we don’t need to calculate gradients since we’re not training, so there is no need to divide the loss by GLOBAL_BATCH_SIZE. That’s why test_step() calls loss_object(labels, predictions) directly instead of compute_loss(), simply feeding the result to the test_loss metric, and it’s also why we don’t need to call strategy.reduce() in distributed_test_step().
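And here’s why tf.keras.metrics.Mean() is fine for test_loss: the metric just keeps a running total and count (with the per-replica states combined for you under the strategy), so its result is already the mean over everything it has seen, with no GLOBAL_BATCH_SIZE scaling required. Below is a toy stand-in for the metric (not TF’s actual implementation; the loss values are made up) illustrating that bookkeeping:

```python
# Toy stand-in for tf.keras.metrics.Mean: a running total and count,
# with result() = total / count. Under a distribution strategy the
# per-replica totals and counts get summed for you when you read it.
class ToyMean:
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, values):
        self.total += sum(values)
        self.count += len(values)

    def result(self):
        return self.total / self.count

# Hypothetical per-example test losses on 2 replicas (made-up values),
# i.e. what loss_object(labels, predictions) returns in test_step().
replica_losses = [[0.2, 0.4, 0.6, 0.8], [1.0, 1.2, 1.4, 1.6]]

test_loss = ToyMean()
for r in replica_losses:       # each replica's test_step feeds the metric
    test_loss.update_state(r)

# The metric's own total/count already yield the global mean, so no
# division by GLOBAL_BATCH_SIZE and no strategy.reduce() are needed.
all_losses = [x for r in replica_losses for x in r]
assert abs(test_loss.result() - sum(all_losses) / len(all_losses)) < 1e-9
```

Since nothing downstream (like tape.gradient()) consumes this value, there is no gradient-summing step that the loss scaling would need to compensate for.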