What's the purpose of the multiplication in this line of code:
# Accumulate the training loss for the batch
running_loss += loss.item() * images.size(0)
loss.item() is the loss for the entire batch, isn't it? IMO it should just be running_loss += loss.item()
Thank you !
hi @Liviu_Marian_Mircea
In PyTorch, the default behavior for loss functions is to average the loss over the samples in a batch (reduction='mean').
So loss.item() represents the mean loss per sample for that batch, not the total sum of losses. Multiplying by images.size(0) (the batch size) converts this average back into the total loss for the batch.
Basically, they are using this approach to get the true average loss for an epoch: the total cumulative loss divided by the total number of samples. Simply averaging the per-batch means would be wrong whenever batches have different sizes (for example, the last batch of an epoch is usually smaller), because every batch would then be weighted equally regardless of how many samples it contains.
By multiplying by images.size(0), we ensure we are summing the absolute loss contribution of every single image in the dataset.
The formula is:
average epoch loss = Σ(batch mean × batch size) / total number of samples
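To make the formula concrete, here is a small arithmetic sketch with made-up batch mean losses and batch sizes (the last batch is smaller, as at the end of an epoch). It shows why weighting each batch mean by its size gives a different (correct) answer than naively averaging the batch means:

```python
# Hypothetical per-batch mean losses (what loss.item() would return)
# and batch sizes (what images.size(0) would return); the last batch is smaller.
batch_means = [0.90, 0.70, 0.50]
batch_sizes = [32, 32, 16]

# Accumulate loss.item() * images.size(0) per batch, then divide by the
# total number of samples -- exactly the formula above.
running_loss = sum(m * n for m, n in zip(batch_means, batch_sizes))
epoch_loss = running_loss / sum(batch_sizes)

# Naively averaging the batch means over-weights the small last batch.
naive_loss = sum(batch_means) / len(batch_means)

print(round(epoch_loss, 4))  # 0.74
print(round(naive_loss, 4))  # 0.7
```

The two results differ because the small final batch contributes only 16 of the 80 samples, but the naive average treats it as a full third of the epoch.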
Alternatively, this multiplication is not needed if you initialise the loss function with reduction='sum':
CrossEntropyLoss — PyTorch 2.11 documentation
criterion = nn.CrossEntropyLoss(reduction='sum')
# Now loss.item() is already the sum for the entire batch
running_loss += loss.item()
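The two accumulation styles end up at the same epoch average. A minimal pure-Python sketch (the per-sample losses are invented for illustration; in PyTorch they would come from the loss function with the respective reduction):

```python
# Hypothetical per-sample losses for two batches of different sizes.
batches = [[0.9, 0.8, 1.0], [0.4, 0.6]]

running_mean_style = 0.0  # reduction='mean': accumulate mean * batch size
running_sum_style = 0.0   # reduction='sum':  accumulate the batch sum directly
for batch in batches:
    batch_sum = sum(batch)
    batch_mean = batch_sum / len(batch)
    running_mean_style += batch_mean * len(batch)
    running_sum_style += batch_sum

total = sum(len(b) for b in batches)
print(round(running_mean_style / total, 6))  # 0.74
print(round(running_sum_style / total, 6))   # 0.74
```

Either way, the epoch average is the total loss over all samples divided by the total sample count; reduction='sum' just saves the multiplication in the training loop.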
Regards
Dr. Deepti