Batch Norm reducing internal covariate shift

Without batch norm, the mean and variance of Z[l] are determined by a complex interaction between the parameters and activations in the layers before l, so this distribution fluctuates a lot. With batch norm, however, the mean and variance can be controlled with beta and gamma. The learning algorithm learns good values of beta and gamma, and we get a relatively consistent distribution for each mini-batch, thereby reducing internal covariate shift.
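To make that concrete, here is a minimal NumPy sketch of the operation as I understand it (the function name and shapes are my own, following the course convention of units × batch size): Z is normalized to mean 0 and variance 1 per unit over the mini-batch, then rescaled by gamma and shifted by beta.

```python
import numpy as np

def batch_norm_forward(Z, gamma, beta, eps=1e-8):
    """Batch norm for one layer's pre-activations Z, shape (units, batch_size)."""
    mu = np.mean(Z, axis=1, keepdims=True)    # per-unit mean over the mini-batch
    var = np.var(Z, axis=1, keepdims=True)    # per-unit variance over the mini-batch
    Z_norm = (Z - mu) / np.sqrt(var + eps)    # forced to mean 0, variance 1
    Z_tilde = gamma * Z_norm + beta           # learned scale and shift
    return Z_tilde

# Toy example: 3 hidden units, mini-batch of 4 examples
Z = np.random.randn(3, 4) * 5 + 2
gamma = np.ones((3, 1))    # typical initialisation: start at variance 1
beta = np.zeros((3, 1))    # typical initialisation: start at mean 0
print(batch_norm_forward(Z, gamma, beta))
```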

My question is: how can we get a consistent distribution if beta and gamma also keep changing in every iteration of gradient descent? If beta and gamma fluctuate, the distribution fluctuates with them.

I understand that forcing mean=0 and variance=1 for all units in all layers reduces the expressive power of the neural network, so it cannot learn a good mapping from input to output. But I don't see how constantly fluctuating beta and gamma can give a consistent distribution. What's to stop beta and gamma from changing throughout training and shifting the distribution, just as it did before batch norm?

Appreciate any clarification :slight_smile:

As with other learning algorithms, the optimizer and learning rate play a vital role in determining the magnitude of change of the parameters (\beta and \gamma). The changes should be minor if the layer encounters many batches of data that are similarly distributed.
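For the plain gradient descent case, the updates are

\gamma := \gamma - \alpha \, d\gamma, \qquad \beta := \beta - \alpha \, d\beta

so each step can move the learned distribution by at most the learning rate \alpha times the gradient, and similarly distributed batches produce similar (small) gradients.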

Thanks for your response! I have another question.

The distribution goes from mean=0 and variance=1 at the beginning of training to whatever the ideal mean and variance turn out to be. Does this mean beta and gamma reach roughly optimal values early in the optimization process, so that during the rest of training they don't fluctuate much, giving us a steady distribution and letting us learn W and b much faster?

What if the mean (0) and variance (1) we start with are very far from what is ideal for a particular Z? It might take a long time for beta and gamma to reach good values. Is the hope that for the majority of the Zs we reach good values of beta and gamma fairly quickly, so that for most of the optimization process we are working with a reasonably steady distribution?

With mini-batches having similar distributions, and W and b updating only slightly at each iteration (controlled by the learning rate, etc.), why do we even need batch norm to begin with? It seems like there is very little change in the distribution from iteration to iteration anyway.

\beta and \gamma are learnt like the parameters in any other NN problem.

Wild fluctuation of the parameters IMO means that the learning rate is too high or, less likely, that the dataset isn't well shuffled (I haven't dug deep enough to observe this problem myself).
While learning will almost flatten out if there are many similarly distributed batches, we want to account for the reality that not all mini-batches are similarly distributed.

Consider a simple linear regression problem with 1 input feature. We don't pick 2 points, draw a line through them, and call it y. The line of best fit is the one with the least error, e.g. MSE. It's the same idea here: we learn the best rescaling parameters from mini-batches of data.
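As a quick sanity check, here's a minimal Keras sketch (assuming tf.keras.layers.BatchNormalization; exact variable names depend on the TF version) showing that gamma and beta are ordinary trainable weights, updated by the same optimizer as W and b:

```python
import tensorflow as tf

# Build a batch norm layer for a 4-feature input and list its trainable weights.
bn = tf.keras.layers.BatchNormalization()
bn.build(input_shape=(None, 4))

for v in bn.trainable_weights:
    print(v.name, v.shape)   # gamma (4,) and beta (4,); the moving mean/variance are non-trainable
```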

Thanks for your reply @balaji.ambresh.

I think I get it now. There are two cases to consider.

  1. If we choose a small learning rate, then beta and gamma change only slightly from iteration to iteration, giving us a relatively steady distribution to work with even as they march slowly towards their optimum values.

  2. If the learning rate is on the larger side, beta and gamma update by bigger amounts, so until they get near their optimum values we pay the price of more internal covariate shift. However, we don't pay this price for long, as beta and gamma reach their optimal values much sooner, and from that point on we have low ICS for the rest of training.

Would this be a fair way to summarize it?

This is not necessarily true since there’s nothing stopping you from providing a really high learning rate for rapid convergence.

Tuning the learning rate (which I've never had to do beyond what TensorFlow provides for this layer) carries the same meaning as in a regression problem. If the learning rate is too low, it takes more iterations for the model to reduce the loss and arrive at the local minimum. If the learning rate is too high, overshooting and bouncing around the local minimum are something to be aware of.

Correct. I guess I meant as high a learning rate as possible with which we can still converge, ideally not one so extreme that we get oscillations and potentially divergence.

Sure. There are lots of learning rate schedules out there. As long as you've got the hang of this intuition, you're good to go.
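For example, a decaying schedule (this sketch assumes Keras' built-in ExponentialDecay) uses a larger learning rate early for fast progress and smaller updates to W, b, \gamma and \beta later on:

```python
import tensorflow as tf

# Learning rate decays by 4% every 1000 steps.
schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.1,
    decay_steps=1000,
    decay_rate=0.96,
)
optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)
```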