Please familiarize yourself with gradient accumulation.
Gradient accumulation is employed when one wants to train a large network with limited memory. To give a concrete example, if your hardware can only support a batch size of 8 but you want to train the network with an effective batch size of 32, you run the forward and backward passes on 4 micro-batches of 8, accumulating their gradients, and only then perform the weight update. In fact, this is exactly what the top figure is doing.
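As a concrete sketch (the model, optimizer, loss, and data below are placeholders I picked for illustration, not anything from your setup), a gradient-accumulation loop in PyTorch typically looks like this:

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 4)                          # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

accum_steps = 4                                   # 4 micro-batches of 8 -> effective batch size 32
data_loader = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(8)]

optimizer.zero_grad()
for step, (x, y) in enumerate(data_loader):
    loss = loss_fn(model(x), y) / accum_steps     # scale so the accumulated gradient matches a batch of 32
    loss.backward()                               # gradients accumulate in .grad across calls
    if (step + 1) % accum_steps == 0:
        optimizer.step()                          # weight update only after 4 micro-batches
        optimizer.zero_grad()
```

Note that backward() runs for every micro-batch; it is only the optimizer step that is deferred.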
Pay attention to the rightmost column, Update: each device performs its weight update only after the gradients from all micro-batches have been accumulated, which is why the Update action is stacked at the end of the timeline. It also helps to pay attention to the notation in the provided image: F_{i,j} refers to the forward pass of the j^{th} micro-batch through the i^{th} layer of the network, i.e. each Device i is responsible for training one layer of the network, and B_{i,j} is the corresponding backward pass.
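If it helps, here is a rough, illustrative sketch of how such a schedule lines up in time, assuming a GPipe-style timeline with 4 devices and 4 micro-batches (counts I picked, not taken from your figure). It only prints which F_{i,j} / B_{i,j} steps can run at each clock tick; it is not a runnable training loop:

```python
num_devices = 4   # each Device i holds layer i
num_micro = 4     # micro-batches accumulated before the single update

# Forward: Device i can start micro-batch j only at tick i + j, because it
# must wait for F_{i-1,j} to finish on the previous device.
for t in range(num_devices + num_micro - 1):
    steps = [f"F_{{{i},{t - i}}}" for i in range(num_devices) if 0 <= t - i < num_micro]
    print(f"t={t:2d}:  " + "  ".join(steps))

# Backward mirrors the forward pass: the last device starts first and the
# gradients flow back, accumulating on each device. (The exact micro-batch
# order within the backward phase may differ from the figure.)
offset = num_devices + num_micro - 1
for t in range(num_devices + num_micro - 1):
    steps = [f"B_{{{i},{t - (num_devices - 1 - i)}}}"
             for i in reversed(range(num_devices))
             if 0 <= t - (num_devices - 1 - i) < num_micro]
    print(f"t={t + offset:2d}:  " + "  ".join(steps))

# Only after every B_{i,j} has completed does each device apply its Update.
print("Update on all devices")
```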
I hope this clarifies how the image at the bottom makes more efficient use of the hardware.