In this picture of the distribution strategy, num_batches is seen to increase by 1 per step. What about the case where we have more than one device? For instance, with three devices, each global batch is split into three per-replica batches, one for each device.
Does num_batches still increase by one for every batch in train_dist_dataset?
If it does, is total_loss the total loss summed up across all the replicas?
Edit: I think I have figured it out now. total_loss does not receive each replica's loss added separately; instead, on each step it receives the reduced (summed) loss across all the replicas, added once per distributed batch.
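To make this concrete, here is a minimal sketch of the usual custom-training-loop pattern with tf.distribute.MirroredStrategy. The model, loss, and dataset are hypothetical placeholders, not the code from the picture; the point is only to show that the loop iterates once per global batch, so num_batches increases by 1 regardless of the number of devices, while total_loss accumulates the strategy.reduce(SUM) of the per-replica losses:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 64  # with 3 devices, this is split ~evenly across replicas

with strategy.scope():
    # Hypothetical model and optimizer, just to make the sketch runnable.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    optimizer = tf.keras.optimizers.SGD()
    loss_object = tf.keras.losses.MeanSquaredError(
        reduction=tf.keras.losses.Reduction.NONE)

def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    # Scale by the GLOBAL batch size, so that summing the per-replica
    # losses yields the correct mean loss for the whole global batch.
    return tf.nn.compute_average_loss(
        per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

def train_step(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
        predictions = model(features, training=True)
        loss = compute_loss(labels, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

@tf.function
def distributed_train_step(dist_inputs):
    # Runs train_step once on EACH replica, returning a PerReplica of losses.
    per_replica_losses = strategy.run(train_step, args=(dist_inputs,))
    # Reduce (sum) the per-replica losses into a single scalar.
    return strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

# Hypothetical in-memory dataset.
x = tf.random.normal((640, 8))
y = tf.random.normal((640, 1))
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(GLOBAL_BATCH_SIZE)
train_dist_dataset = strategy.experimental_distribute_dataset(dataset)

total_loss, num_batches = 0.0, 0
for dist_inputs in train_dist_dataset:
    # One iteration per GLOBAL batch: total_loss grows by the reduced
    # (summed) loss over all replicas, and num_batches grows by exactly 1,
    # no matter how many devices the batch was split across.
    total_loss += distributed_train_step(dist_inputs)
    num_batches += 1

train_loss = total_loss / num_batches
```

So with three devices, the loop body still runs once per element of train_dist_dataset: num_batches += 1 happens once per global batch, and the three replica losses arrive already summed into one scalar by strategy.reduce.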