DLS course-4 general question on tfl.BatchNormalization

What does it mean when they say tfl.BatchNormalization() is performed along an axis?

I understand normalization but what is meant by normalizing along an axis?

Can someone please explain?

In the programming assignments we mostly did normalization along the features/channels axis.
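Here is a sketch of the channels case. It assumes the usual NHWC image layout (batch, height, width, channels) that goes with `tfl.BatchNormalization(axis=3)`, where `tfl` is `tf.keras.layers`. The layer keeps one mean and one variance per channel, and each is pooled over the batch, height and width axes. The random batch is made up for illustration. The manual comparison assumes that `training=True` normalizes with the biased batch variance plus the layer's `epsilon`.

```python
import numpy as np
import tensorflow as tf

# Hypothetical mini-batch in (N, H, W, C) layout: 2 images, 4x4 pixels, 3 channels.
rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(2, 4, 4, 3)).astype(np.float32)

# Normalize along axis=3 (channels); center/scale off so the output is the bare normalization.
bn = tf.keras.layers.BatchNormalization(axis=3, center=False, scale=False)
y = np.asarray(bn(x, training=True))  # training=True -> use this batch's statistics

# Manual version: one mean/variance per channel, pooled over N, H and W (axes 0, 1, 2).
mean = x.mean(axis=(0, 1, 2), keepdims=True)  # shape (1, 1, 1, 3)
var = x.var(axis=(0, 1, 2), keepdims=True)    # biased variance, shape (1, 1, 1, 3)
y_manual = (x - mean) / np.sqrt(var + bn.epsilon)

print(mean.shape)                           # (1, 1, 1, 3)
print(np.allclose(y, y_manual, atol=1e-4))  # True
```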

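Normalizing along an axis means that the statistics are kept per position on that axis. For example, suppose you have a tensor of shape [a, b, c] and normalize along dimension c. A separate mean and variance is computed for each of the c positions, and each value is normalized using only the values that share its position in that dimension. Think of each horizontal slice of a cube being normalized on its own when you normalize along the height dimension.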
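A minimal NumPy sketch of this idea follows. The [a, b, c] = [2, 3, 4] tensor and the `eps` value are made up for illustration. Computing the statistics over axes a and b gives one mean and one variance per position along c. That result matches a loop that normalizes each slice `x[:, :, k]` using only its own values.

```python
import numpy as np

# Hypothetical tensor of shape [a, b, c] = [2, 3, 4], to be normalized along the last axis (c).
x = np.arange(24, dtype=np.float64).reshape(2, 3, 4) ** 2
eps = 1e-3  # small constant for numerical stability (assumed value)

# One mean and one variance per position k along c, pooled over the other axes (a and b).
mean = x.mean(axis=(0, 1))           # shape (4,)
var = x.var(axis=(0, 1))             # shape (4,)
y = (x - mean) / np.sqrt(var + eps)  # broadcasting pairs x[..., k] with mean[k] and var[k]

# Equivalent loop: each slice x[:, :, k] is normalized using only its own values.
y_loop = np.empty_like(x)
for k in range(x.shape[2]):
    s = x[:, :, k]
    y_loop[:, :, k] = (s - s.mean()) / np.sqrt(s.var() + eps)

print(mean.shape)              # (4,)
print(np.allclose(y, y_loop))  # True
```

Nothing is averaged *across* different positions of c. The other axes (a and b) are only the pool of values that each per-position statistic is computed from.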

Could you please explain with a concrete example using some sample values?
I understand the axis/dimension part, but I cannot understand what exactly is done along the specified axis (and what is not done along the other axes).

Hello @ydsk1234,

I will share one example with you along with the key points, but you will have to experiment with it to find out anything else you are interested in :wink:

This is a toy dataset of 2 samples and 3 features.

```python
import numpy as np
import tensorflow as tf

X = np.array([
    [1., 4.],
    [2., 5.],
    [3., 6.]], dtype=np.float32)
```

Let’s initialize a BN layer. Note that we pass X.T (shape (2, 3)), so the zeroth axis becomes the sample axis and the first axis becomes the feature axis.

```python
bn = tf.keras.layers.BatchNormalization(axis=1, center=False, scale=False, name='bn')
_ = bn(X.T, training=False)  # to initialize (build) the BatchNormalization layer
```

And print the weights out.

```python
print(bn.weights)
# Output:
# [
#     <tf.Variable 'bn/moving_mean:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>,
#     <tf.Variable 'bn/moving_variance:0' shape=(3,) dtype=float32, numpy=array([1., 1., 1.], dtype=float32)>
# ]
```

Now there are two variables, and their names tell you what they do. The key point is that each of them holds three values. We chose axis=1, which is the “feature axis”, and there are three features in total, so the layer keeps one mean and one variance per feature.
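To see that the axis decides how many statistics the layer keeps, you can build the same layer on the same toy data with `axis=0` and compare it with `axis=1`. The layer names here are made up. The expected shapes follow from the layer keeping one value per position on the chosen axis. With `axis=0` the statistics are per sample row, which also ties the layer to a batch of exactly 2 rows.

```python
import numpy as np
import tensorflow as tf

X = np.array([[1., 4.], [2., 5.], [3., 6.]], dtype=np.float32)  # same toy data as above

# axis=1 (features): one statistic per feature -> 3 values per variable.
bn_features = tf.keras.layers.BatchNormalization(axis=1, center=False, scale=False)
_ = bn_features(X.T, training=False)

# axis=0 (samples): one statistic per sample row -> 2 values per variable.
bn_samples = tf.keras.layers.BatchNormalization(axis=0, center=False, scale=False)
_ = bn_samples(X.T, training=False)

print([w.shape for w in bn_features.get_weights()])  # [(3,), (3,)]
print([w.shape for w in bn_samples.get_weights()])   # [(2,), (2,)]
```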

The initial means and variances are boring, so let’s change them to something else:

```python
bn.set_weights([
    np.array([5, 4, 3], dtype=np.float32),               # new moving_mean
    np.array([10**2, 10**2, 10**2], dtype=np.float32),   # new moving_variance
])
```

And I will let you check whether the updated means and variances produce the results you expect:

```python
print(bn(X.T, training=False))
```
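To check by hand, the Keras documentation gives the inference-mode output (with `center=False, scale=False`) as `(x - moving_mean) / sqrt(moving_variance + epsilon)`, applied per feature. Here `epsilon` defaults to `0.001`. The NumPy sketch below recomputes that formula for the toy data and the weights set above. If the layer uses those weights as expected, its printed output should be approximately the same numbers.

```python
import numpy as np

X = np.array([[1., 4.], [2., 5.], [3., 6.]], dtype=np.float32)
moving_mean = np.array([5, 4, 3], dtype=np.float32)
moving_variance = np.array([10**2, 10**2, 10**2], dtype=np.float32)
eps = 1e-3  # Keras' default epsilon for BatchNormalization

# Inference mode: each feature column j uses only moving_mean[j] and moving_variance[j].
expected = (X.T - moving_mean) / np.sqrt(moving_variance + eps)
print(expected)  # approximately [[-0.4, -0.2, 0.0], [-0.1, 0.1, 0.3]]
```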

Finally, don’t forget to read the BatchNormalization documentation for more details about the implementation.
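One such detail is training mode. The Keras docs describe the output during training as normalized with the *batch's* mean and variance, and describe the moving statistics as updated with `moving = moving * momentum + batch_stat * (1 - momentum)`, where `momentum` defaults to `0.99`. The helper below, `bn_train_step`, is a hypothetical NumPy sketch of those formulas for `axis=1` with `center=False, scale=False`. It uses the biased variance throughout. The real layer may differ in details, for example whether the moving variance is Bessel-corrected.

```python
import numpy as np

def bn_train_step(x, moving_mean, moving_var, momentum=0.99, eps=1e-3):
    """Hypothetical helper: training-mode batch norm along axis=1 (gamma=1, beta=0)."""
    batch_mean = x.mean(axis=0)  # one value per feature, pooled over the samples
    batch_var = x.var(axis=0)    # biased (population) variance per feature
    y = (x - batch_mean) / np.sqrt(batch_var + eps)
    # Moving statistics drift toward the batch statistics; inference mode uses them.
    new_mean = moving_mean * momentum + batch_mean * (1 - momentum)
    new_var = moving_var * momentum + batch_var * (1 - momentum)
    return y, new_mean, new_var

X = np.array([[1., 4.], [2., 5.], [3., 6.]], dtype=np.float32)
y, m, v = bn_train_step(X.T, np.zeros(3, np.float32), np.ones(3, np.float32))
print(y)  # each column is approximately [-1, 1]: normalized with this batch's own statistics
print(m)  # approximately [0.025, 0.035, 0.045]
```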

Cheers,
Raymond
