Test loss in W4 L2/L3

Hi there.

I'm a bit confused about the test losses in these two labs.

In both labs, test_loss is created as a tf.keras.metrics.Mean object within the strategy scope, and its internal state is updated within the test_step function.
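For reference, tf.keras.metrics.Mean basically just keeps a running total and a count, with result() returning total / count. A minimal pure-Python stand-in of that bookkeeping (not the real TF class, just the arithmetic):

```python
class MeanMetric:
    """Pure-Python stand-in for tf.keras.metrics.Mean's bookkeeping."""
    def __init__(self):
        self.total = 0.0   # running sum of values passed to update_state
        self.count = 0     # number of update_state calls (unweighted case)

    def update_state(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count

test_loss = MeanMetric()
for batch_loss in [0.9, 0.7, 0.5]:   # one update per test_step call
    test_loss.update_state(batch_loss)
print(round(test_loss.result(), 6))  # plain average of the values seen
```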

The confusion arises when printing the resulting loss in the training loop: in lab 2 (and in the accompanying video) it is simply the result of test_loss, whereas in lab 3 it is the result of test_loss divided by the number of replicas in sync.

I am struggling to see why this extra division is done in lab 3. Surely each call to update_state increments the count by 1, so the division shouldn't be needed? On the other hand, without it, the test loss comes out significantly higher than one would expect.

Is it the case that within the Mean object, for each update to count, there are num_replicas_in_sync additions to the total? This would explain the extra division factor required in lab 3, but doesn’t explain the lack of division in lab 2.
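I don't know the internals well enough to say, but the two readings can be checked numerically with a toy stand-in (plain Python, not TF; 2 pretend replicas, losses chosen to be exact in binary floating point):

```python
num_replicas = 2
# per-replica losses for one test step (arbitrary toy numbers)
replica_losses = [0.75, 0.25]

# Reading 1: each replica's update_state adds its value once and bumps
# count by 1 -> result() is already the mean across replicas.
total = sum(replica_losses)
count = len(replica_losses)
print(total / count)                          # mean; no extra division needed

# Reading 2 (the hypothesis): every bump of count comes with
# num_replicas_in_sync additions to total -> result() is too high by
# exactly that factor, and dividing by num_replicas recovers the mean.
total_hyp = sum(replica_losses)               # all replicas' values...
count_hyp = 1                                 # ...against a single count
print(total_hyp / count_hyp)                  # num_replicas x too high
print(total_hyp / count_hyp / num_replicas)   # back to the mean
```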

Any help understanding the reasoning behind the extra division would be appreciated.

Thanks in advance!

In Lab 2, the global batch size is set before compute_loss and is batch_size times the number of replicas (see the Prepare data section), but in Lab 3 the division by the global batch size is done inside the compute_loss function, at the return statement.
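If that's the difference, the factor of num_replicas falls out of where the division happens. A toy check (plain Python, 2 pretend replicas, made-up per-example losses) of the two placements:

```python
num_replicas = 2
batch_size = 4                       # per-replica batch
GLOBAL_BATCH_SIZE = batch_size * num_replicas

# made-up per-example losses, one list per replica
per_replica = [[1.0, 2.0, 3.0, 2.0], [2.0, 1.0, 1.0, 4.0]]
true_mean = sum(sum(r) for r in per_replica) / GLOBAL_BATCH_SIZE

# Placement 1: divide by GLOBAL_BATCH_SIZE inside each replica's loss
# -> the per-replica values already sum to the true global mean.
global_div_vals = [sum(r) / GLOBAL_BATCH_SIZE for r in per_replica]
assert sum(global_div_vals) == true_mean

# Placement 2: each replica averages over its own local batch only
# -> the combined value is num_replicas x too large, and an extra
# division by num_replicas is needed to recover the true mean.
local_div_vals = [sum(r) / batch_size for r in per_replica]
assert sum(local_div_vals) / num_replicas == true_mean
```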