This long post is my distillation of various discussions across two threads, to try to get to the bottom of this. The take-home is: training=True is wrong when hardwired into batch-norm layers, and it produces peculiar behaviour. There's an MWE at the bottom that illustrates this. The following discusses what the issue is, with some detail on how batch normalisation works.
To follow up on this: I just wanted to explain batch-norm (as I understand it) to see concretely what may be happening here.
In an ordinary forward pass, the activations of layer l are calculated from those of layer l-1 as a^l = g(z^l) = g(W^l a^{l-1} + b^l), where g is the activation function (relu or whatever).
In batch-norm, you normalise the z's. You drop the initial bias term and calculate z^l = W^l a^{l-1}. You then calculate a normalised z, written \hat z^l, as:
\hat z^l = \frac{z^l - \mu^l}{\sqrt{(\sigma^2)^l + \epsilon}}.
\mu^l is a vector, one entry per neuron, of the means of z^l taken across the particular mini-batch; (\sigma^2)^l is the corresponding vector of variances.
This normalised value is then scaled and shifted by two more learnable parameters: \tilde z^l = \gamma^l \hat z^l + \beta^l. Finally, the activations of the layer are a^l = g(\tilde z^l).
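To make this concrete, here is a minimal numpy sketch of the forward pass just described, for one dense layer (the names W, gamma, beta and so on are mine for illustration, not from any particular library):

import numpy as np

rng = np.random.default_rng(0)
a_prev = rng.normal(size=(32, 10))      # a^{l-1}: mini-batch of 32, 10 units
W = rng.normal(size=(10, 5))            # W^l (no bias term when batch-norming)
gamma, beta = np.ones(5), np.zeros(5)   # learnable scale and shift
eps = 1e-3

z = a_prev @ W                          # z^l = W^l a^{l-1}
mu = z.mean(axis=0)                     # per-neuron mean over the mini-batch
var = z.var(axis=0)                     # per-neuron variance over the mini-batch
z_hat = (z - mu) / np.sqrt(var + eps)   # \hat z^l
z_tilde = gamma * z_hat + beta          # \tilde z^l
a = np.maximum(z_tilde, 0)              # a^l = g(\tilde z^l), with g = relu here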
I think the key terms here are the means and variances. These are calculated for each mini-batch. During training, running values of these are accumulated; during inference, those stored running values are used (as they correspond to the learned parameters). If you set training=True in the batch-norm layers and then hand the model just one example, you effectively 'normalise' that one example away. In other words, your batch norms should, in theory, return just the 'bias' term \beta^l, because \hat z^l = 0 since z^l = \mu^l for a batch of one.
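This is easy to check directly (a minimal sketch; I set beta to a non-zero constant purely so the effect is visible against the default of zeros):

import tensorflow as tf

bn = tf.keras.layers.BatchNormalization(
    beta_initializer=tf.keras.initializers.Constant(0.5))
x = tf.random.normal((1, 4))    # a 'batch' of one example with four features
out = bn(x, training=True)
print(out.numpy())              # ~[[0.5 0.5 0.5 0.5]] = beta, whatever x was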
I thought this might mean that all predictions would be equal for any input, but this is not the case. My guess is that batch norm applied to convolutional layers is not as simple as I've presented above, and that (maybe) the statistics are kept per channel rather than per activation? That could stop the normalisation terms from disappearing as I suggested in the previous paragraph. But I've never seen the nuts and bolts of a conv-net batch-norm implementation, so I'm really not sure on this point.
The main point is: I'm fairly sure training=True is the issue here, and that it shouldn't be hardwired into the definitions of the identity or convolutional blocks.
Edit:
I did a bit more sleuthing and yes, in convolutional networks batch norm does keep separate statistics per channel, computing the means and variances over the batch, height and width axes (that's the axis = 3 bit: axis 3 is the channel axis that is kept). There's a nice explainer here. Because the spatial positions of even a single image contribute to each channel's statistics, z^l \neq \mu^l in general, and this is why different inputs give different predictions with this model even when training=True and you hand the predict() function only one example.
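A small sketch of what that means in practice for a channels-last feature map: the statistics have one entry per channel and are reduced over the batch, height and width axes, so even a batch of one image produces non-degenerate means and variances:

import tensorflow as tf

x = tf.random.normal((1, 64, 64, 3))           # one image: (batch, H, W, channels)
mu = tf.reduce_mean(x, axis=[0, 1, 2])         # one mean per channel -> shape (3,)
var = tf.math.reduce_variance(x, axis=[0, 1, 2])
print(mu.shape, var.shape)                     # (3,) (3,): individual pixel values
                                               # still differ from their channel mean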
To test all of the above I built a basic MWE, see below. It shows that if you run a neural network with batch norm that only has densely connected layers, set training=True and then run predict(), all inputs give the same prediction, while how you predict (on one example or on all of X_test) changes the result. Swap the commented lines to convert the fully connected net into a conv net and you will see that training=True no longer makes all predictions equal (because z^l \neq \mu^l any more), but you still see that how you predict matters in this regime: pred(X_{test}[i]) \neq pred(X_{test})[i].
I think I fully understand this now - thanks for everyone's input. Happy to answer any questions!
from tensorflow.keras.layers import Input, Dense, Activation, BatchNormalization, Flatten, Conv2D
from tensorflow.keras.models import Model
from tensorflow.keras.initializers import glorot_uniform
from resnets_utils import *  # course helpers: load_dataset() and convert_to_one_hot()
def ResNet50(input_shape=(64, 64, 3), classes=6):
    """
    Basic NN for demonstrating the effect of BatchNormalization(training=True).

    Arguments:
    input_shape -- shape of the images of the dataset
    classes -- integer, number of classes

    Returns:
    model -- a Model() instance in Keras
    """
    # Define the input as a tensor with shape input_shape
    X_input = Input(input_shape)

    # To test batch-norm effects with convolutions, use the next line and comment
    # out the following two; to see batch norm on a regular dense net, do the opposite.
    # X = Conv2D(64, (7, 7), strides=(2, 2), kernel_initializer=glorot_uniform(seed=0))(X_input)
    X = Flatten()(X_input)
    X = Dense(64 * 7 * 7)(X)

    # With the training=True line active, the predictions are the same for all
    # inputs (e.g. all of i = 0, 1, 2, ...), and the predictions differ depending
    # on whether you hand predict() all of X_test or just X_test[i].
    # X = BatchNormalization()(X)
    X = BatchNormalization()(X, training=True)
    X = Activation('relu')(X)

    # For the convolutional variant, uncomment the next line; for the dense net
    # leave it commented out.
    # X = Flatten()(X)
    X = Dense(classes, activation='softmax', kernel_initializer=glorot_uniform(seed=0))(X)

    # Create model
    model = Model(inputs=X_input, outputs=X)
    return model
# load data
X_train_orig, Y_train_orig, X_test_orig, Y_test_orig, classes = load_dataset()
X_train = X_train_orig / 255.
X_test = X_test_orig / 255.
Y_train = convert_to_one_hot(Y_train_orig, 6).T
Y_test = convert_to_one_hot(Y_test_orig, 6).T
# build and fit basic NN model
model = ResNet50(input_shape = (64, 64, 3), classes = 6)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(X_train, Y_train, epochs = 3, batch_size = 32)
# Compare predictions from running on the whole test set to a specific example only.
# Key points: with training=True, prediction_i_direct is the same for every example
# (the model loses all input information because the batch-norm means equal the
# inputs, so you are basically just fitting the constants gamma and beta); and
# with training=True, how you predict changes the answer, since the BN means and
# variances depend on whether you feed the model all of X_test or just X_test[i].
for i in range(3):
    prediction_i_direct = model.predict(X_test[[i]])
    prediction_i_from_all_preds = model.predict(X_test)[i]
    print("Pred vector - inference from one example, i =", i, ":", prediction_i_direct)
    print("Pred vector - inference from all examples, i =", i, ":", prediction_i_from_all_preds)