Hey @abdou_brk,
No worries, let me try it again with some mathematics, so that we can use the formulations to our advantage. Let's say we have the inputs z^{(1)}, z^{(2)}, z^{(3)}, z^{(4)}, z^{(5)} to a single neuron, coming from 5 batches with varying mean and variance, on which we want to apply batch normalization. Note that BatchNorm is applied to each neuron individually, i.e., there is a different \gamma and \beta for each neuron.
Consider case 1, in which there is no batch normalization. Let's say that
- z^{(1)} follows a distribution D_1 having mean \mu_1 and variance \sigma_1^2
- z^{(2)} follows a distribution D_2 having mean \mu_2 and variance \sigma_2^2
and so on. Now, in this case, the weights will have to adjust themselves to 5 different distributions, and this only gets harder (for the optimization algorithm) as the number of batches with varying distributions increases. In simple words, different batches of inputs will have different means and variances, and the weights will adjust (to a considerable extent) to the most recent batch of inputs, i.e., the weights will keep oscillating instead of converging. It gets even worse if the dev/test set samples follow a (even slightly) different distribution than the train set samples: the weights won't be adjusted to these samples, and the network will give poor performance on them.
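Just to make the "varying mean and variance" part concrete, here is a tiny NumPy sketch (the batch means and standard deviations below are made up purely for illustration) that draws 5 batches from different distributions and prints their statistics; without BatchNorm, the weights downstream of this neuron would see a different input scale on every update:

```python
import numpy as np

rng = np.random.default_rng(0)

# Five batches of pre-activations z for a single neuron, each drawn from a
# different (hypothetical) distribution D_i with its own mean and variance.
batch_means = [0.0, 2.0, -1.5, 5.0, 0.5]
batch_stds  = [1.0, 0.5,  2.0, 3.0, 1.2]

for i, (mu, sigma) in enumerate(zip(batch_means, batch_stds), start=1):
    z = rng.normal(mu, sigma, size=64)   # one batch of 64 samples
    print(f"batch {i}: mean = {z.mean():+.2f}, var = {z.var():.2f}")
```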
Now, let's see how Batch Normalization can help us avoid all this. I guess the normalization part of BatchNorm is pretty straightforward, so I will be using the following short-hand notation from here on:
z_{norm}^{(i)} = \dfrac{z^{(i)} - \mu_i}{\sqrt{\sigma^2_i + \epsilon}}
Now, if we don't involve any learnable parameters here, then for each of the neurons, the inputs will have mean 0 and variance 1. Intuitively, we can understand that this would not be good, since all the features would lose their identity (in terms of their mean and variance across the batches). Empirically, you can see this by training 2 networks, one with plain normalization and one with full batch normalization, and comparing their performance.
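If you want to try that comparison yourself, here is a minimal PyTorch sketch (the layer sizes are arbitrary, and affine=False simply disables the learnable \gamma and \beta in the BatchNorm layer):

```python
import torch.nn as nn

# Two small toy networks on 10 input features: one normalizes without learnable
# parameters (affine=False, i.e., gamma/beta fixed to 1/0), the other uses full
# BatchNorm with learnable gamma and beta.
def make_net(affine: bool) -> nn.Sequential:
    return nn.Sequential(
        nn.Linear(10, 32),
        nn.BatchNorm1d(32, affine=affine),
        nn.ReLU(),
        nn.Linear(32, 1),
    )

plain_norm_net = make_net(affine=False)   # normalization only
batchnorm_net  = make_net(affine=True)    # normalization + gamma, beta

# Train both on the same data with the same optimizer settings and compare
# their dev-set performance.
```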
So, I am assuming we agree now that normalization alone isn't going to cut it. Now, we will introduce the learnable parameters, \gamma and \beta, and the resulting normalized values would be
\tilde{z}^{(i)} = \gamma \cdot z_{norm}^{(i)} + \beta
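Putting the two formulas together, here is a minimal NumPy sketch of the forward pass for a single neuron during training (the values of \gamma, \beta and the batch itself are just illustrative):

```python
import numpy as np

def batchnorm_forward(z, gamma, beta, eps=1e-5):
    """Batch norm for one neuron over one batch z (1-D array): normalize with
    the batch mean/variance, then scale and shift with gamma and beta."""
    mu = z.mean()
    var = z.var()
    z_norm = (z - mu) / np.sqrt(var + eps)   # z_norm^(i)
    return gamma * z_norm + beta             # z_tilde^(i)

# A batch whose raw mean/variance are far from (0, 1)
z = np.array([4.0, 6.0, 5.0, 7.0, 3.0])
print(batchnorm_forward(z, gamma=2.0, beta=0.5))
# The output has mean ~0.5 and standard deviation ~2.0, no matter what the
# batch's raw statistics were.
```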
Now, consider the role of the weights and bias: they model the relative importance of the different features, and the relations between one representation of the features and the next, so that when we combine the features using the weights and bias, we get the most appropriate representation. Without BatchNorm, however, they would also have to account for the different distributions of the different batches. With BatchNorm, we can hand this duty of accounting for the different batch distributions over to \gamma and \beta, i.e., we get a separation of tasks between the weights/bias and \gamma/\beta.
Here, you might wonder whether \gamma and \beta will also change, and indeed they will. But once the first epoch is completed, \gamma and \beta will have seen all the batches and will have approximated their means and variances to a great extent, so they change rather little in the following epochs (at least in the generic case of a fixed, stored dataset, where the batch distributions don't change across epochs). The weights and bias, on the other hand, must keep changing, since they have that other duty as well, which is in fact their major task, and this continuous change would make it much harder for them to additionally account for the varying batch distributions.
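If you want to see this effect for yourself, one rough (and purely illustrative) way is to snapshot the parameters at the start of each epoch and measure how far each group moves; in PyTorch, \gamma and \beta live in bn.weight and bn.bias of a BatchNorm layer:

```python
import torch
import torch.nn as nn

def snapshot(module):
    return [p.detach().clone() for p in module.parameters()]

def total_change(module, prev):
    with torch.no_grad():
        return sum((p - q).abs().sum().item()
                   for p, q in zip(module.parameters(), prev))

linear = nn.Linear(10, 32)
bn = nn.BatchNorm1d(32)   # bn.weight is gamma, bn.bias is beta

# Inside the training loop, once per epoch:
# prev_linear, prev_bn = snapshot(linear), snapshot(bn)
# ... run one epoch of training ...
# print("weights/bias moved:", total_change(linear, prev_linear),
#       "| gamma/beta moved:", total_change(bn, prev_bn))
```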
And perhaps the greatest advantage of BatchNorm is seen during inference. In the generic case, the distribution of samples differs slightly between the train and dev/test sets. Batch norm normalizes the dev/test samples with the running mean and running variance (computed over the train batches), and then uses the learnt \gamma and \beta to shift them to the mean and variance that the weights of the network have adjusted to.
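Here is a rough NumPy sketch of that train-time vs. inference-time difference for a single neuron (the exponential-moving-average update and the momentum value follow the usual convention, but the exact details vary between frameworks):

```python
import numpy as np

class BatchNormOneNeuron:
    """Minimal sketch: gamma and beta would normally be learnt; here they are
    passed in by hand just to show how the running statistics are used."""

    def __init__(self, gamma=1.0, beta=0.0, momentum=0.9, eps=1e-5):
        self.gamma, self.beta = gamma, beta
        self.momentum, self.eps = momentum, eps
        self.running_mean, self.running_var = 0.0, 1.0

    def forward_train(self, z):
        mu, var = z.mean(), z.var()
        # track the train-set statistics with exponential moving averages
        self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * mu
        self.running_var  = self.momentum * self.running_var  + (1 - self.momentum) * var
        z_norm = (z - mu) / np.sqrt(var + self.eps)
        return self.gamma * z_norm + self.beta

    def forward_eval(self, z):
        # no batch statistics here: reuse the running mean/variance from training
        z_norm = (z - self.running_mean) / np.sqrt(self.running_var + self.eps)
        return self.gamma * z_norm + self.beta
```

During training you would call forward_train on every batch (so the running statistics end up tracking the train distribution), and on the dev/test set you would call forward_eval, which never looks at the dev/test batch statistics.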
Let me know if this helps.
Cheers,
Elemento