A question about batch norm at test time

I have already read some of the questions related to batch norm, but I still do not understand how batch norm is applied at test time. Here is my understanding:

  1. Batch Norm at Train Time

For each mini-batch at step t:

  • forward propagation on X^{t}, Y^{t} (the mini-batch) with batch norm applied at each of the L layers (i.e., the whole network)
  • backward propagation using X^{t}, Y^{t}

My understanding is that batch norm uses the batch mean and std, i.e. the row-wise (per-feature) mean and std of X^{t}, for normalizing during this procedure. So each combination of t-th batch and l-th layer uses a different mean and std for normalizing the inputs.
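To make this concrete, here is a minimal NumPy sketch of the train-time step (a sketch only, assuming the convention that X^{t} has shape (features, batch size); `gamma` and `beta` are the usual learnable scale and shift, each of shape (features, 1)):

```python
import numpy as np

def batchnorm_train(X, gamma, beta, eps=1e-5):
    """Normalize a mini-batch X of shape (features, batch_size)
    using the batch's own statistics, then scale and shift."""
    mu = X.mean(axis=1, keepdims=True)     # per-feature (row-wise) mean over the batch
    var = X.var(axis=1, keepdims=True)     # per-feature variance over the batch
    X_hat = (X - mu) / np.sqrt(var + eps)  # normalized activations
    return gamma * X_hat + beta, mu, var   # return stats too, for the running averages
```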

  2. Batch Norm at Test Time: The Problem

With mini-batch training alone, we have obtained all the parameters, so we can let a single example pass through the model and get a prediction as usual (just passing through each layer with the usual forward propagation). But with batch norm, normalization is done per mini-batch at each l-th layer, so the question arises: what mean and std should be used at each layer instead?

  3. Solution

As we have seen, the batch norm mean and std depend on both t and l. Therefore, gather the means and stds that appear at the l-th layer across all mini-batches during forward propagation, maintain an exponentially weighted moving average of them, and use those values to normalize the inputs at the l-th layer during testing.
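A sketch of this bookkeeping, continuing the NumPy example above (the `momentum` hyperparameter and function names are illustrative, not from any particular framework):

```python
def update_running_stats(running_mu, running_var, mu, var, momentum=0.9):
    """Exponentially weighted moving average of the batch statistics,
    updated once per mini-batch at each batch-norm layer."""
    running_mu = momentum * running_mu + (1 - momentum) * mu
    running_var = momentum * running_var + (1 - momentum) * var
    return running_mu, running_var

def batchnorm_test(x, gamma, beta, running_mu, running_var, eps=1e-5):
    """At test time, normalize with the stored running statistics,
    so even a single example can pass through the layer."""
    x_hat = (x - running_mu) / np.sqrt(running_var + eps)
    return gamma * x_hat + beta
```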

Questions:

  • Is this correct understanding of batch norm?
  • Why not use the usual (simple) average instead of an exponentially weighted moving average? The latter puts more weight on the most recently seen mini-batches, and I am not sure why one would do that. One benefit I see is that you can simply update one moving average per layer, which is much easier to handle (just L variables) than storing all the averages across all mini-batches and layers. But I am not entirely convinced that this is the main reason for using an exponentially weighted moving average here.

Hey @Taxxi,

I will try to answer your questions; let me know if you need more clarification.

First, your understanding of batch normalization is generally correct, but let me clarify and address your questions:

1. Is this the correct understanding of batch norm?

Your understanding of how batch normalization works during training and testing is accurate. Batch normalization normalizes the inputs using the mean and standard deviation calculated within each mini-batch during training. During testing, however, it needs another way to obtain a mean and standard deviation for normalization, since there are no mini-batches. The commonly used method is to take the moving averages of the mean and standard deviation computed during training, which is exactly your proposed solution.
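For instance, PyTorch's BatchNorm1d keeps these running averages internally and switches between the two behaviors with train()/eval() (the layer size and batch size below are arbitrary):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(num_features=4)  # tracks running_mean / running_var internally

x = torch.randn(32, 4)               # a mini-batch of 32 examples
bn.train()
_ = bn(x)                            # normalizes with batch stats, updates the running averages

bn.eval()
single = torch.randn(1, 4)
y = bn(single)                       # normalizes with running_mean / running_var instead
print(bn.running_mean, bn.running_var)
```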

2. Why use exponentially weighted moving averages instead of the usual average?

The choice of using exponentially weighted moving averages for batch normalization statistics during testing serves several important purposes:

  • Stability: Using exponentially weighted moving averages adds a degree of stability to the normalization. It ensures that the statistics used during testing are not too sensitive to the specific characteristics of the last few mini-batches seen during training. This stability can help prevent sudden shifts in model behavior during testing.

  • Compatibility: Exponentially weighted moving averages allow the model to maintain a sense of memory about the entire training process. It’s like a long-term memory of the network’s experience with various mini-batches. Using a simple average would treat all mini-batches equally, potentially discarding useful information from earlier in training.

  • Adaptability: As you mentioned, it’s also computationally efficient to maintain exponentially weighted moving averages since you only need to update a few variables (mean and standard deviation per layer). It also naturally decays older statistics, which aligns well with the idea that older data may not be as relevant as newer data in a non-stationary environment.

So the use of exponentially weighted moving averages in batch normalization during testing is a technique to strike a balance between stability, adaptability, and computational efficiency.
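To make the decay concrete, here is a small sketch comparing the effective weight each past mini-batch receives under a simple average versus an exponentially weighted moving average (the momentum value is arbitrary, and the running value is assumed to start at zero):

```python
momentum, T = 0.9, 10

# Simple average over T mini-batches: every batch gets the same weight 1/T.
simple_weights = [1 / T] * T

# EWMA with update running = momentum * running + (1 - momentum) * batch:
# the batch seen k steps ago gets weight (1 - momentum) * momentum**k,
# so older statistics decay geometrically.
ewma_weights = [(1 - momentum) * momentum ** k for k in range(T)]

print(simple_weights[0], simple_weights[-1])  # 0.1 0.1
print(ewma_weights[0], ewma_weights[-1])      # 0.1 ~0.0387 (most recent vs. oldest)
```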

Regards,
Jamal

Thank you for your reply. I just have one question.

I understood this part: ‘(Adaptability) It also naturally decays older statistics, which aligns well with the idea that older data may not be as relevant as newer data in a non-stationary environment.’

With this in mind, I think “(Compatibility) Using a simple average would treat all mini-batches equally, potentially putting too much weight on older data” might have slightly better nuance (since EWMA decays older statistics compared to a simple average)? I might be being too meticulous here, but it really helps me to understand all the subtleties while I am studying.


Yes, you are right!