Day 15 - Batch Normalization

What does batch normalization do?

Using batch normalization for input standardization

Trainable and non-trainable parameters

Performance impact during training and inference

Using the right activation functions and weight initialization strategies significantly reduces the problem of unstable gradients at the beginning of training. However, it doesn't guarantee that the problem won't come back during training.

Batch Normalization adds an operation in the model before or after the activation function in each hidden layer.

This operation zero-centers and normalizes each input, then scales and shifts the result.

The mean and variance used to zero-center and normalize the inputs are calculated per mini-batch of training data.

Batch Normalization also acts as a regularizer, reducing the need for other regularization techniques (e.g. dropout).

If a Batch Normalization layer is added as the first layer of a neural network, it will approximately standardize the inputs, so Scikit-Learn's StandardScaler doesn't need to be used.

The standardization is only approximate because the layer estimates the mean and variance from one mini-batch at a time, and because it can also scale and shift each input feature after standardizing it.
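To see why the result is only approximate, the sketch below compares the statistics a StandardScaler-style preprocessing step would use (the whole training set) with those a first Batch Normalization layer sees during training (a single mini-batch). The data values are made up for illustration, and Python's statistics module stands in for both tools.

```python
from statistics import fmean, pstdev

# Toy single-feature training set (made-up values)
data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]

# StandardScaler-style: statistics over the whole training set
full_mean, full_std = fmean(data), pstdev(data)

# Batch-norm-style during training: statistics over one mini-batch only
batch = data[:4]
batch_mean, batch_std = fmean(batch), pstdev(batch)

x = 5.0
print((x - full_mean) / full_std)    # 0.0
print((x - batch_mean) / batch_std)  # about 1.73
```

The same value is mapped to different standardized outputs depending on which mini-batch it lands in, which is the source of the approximation.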

The Batch Normalization Algorithm

  1. Calculate the mean of each input across the entire mini-batch.
  1. Calculate the standard deviation of each input across the entire mini-batch.
  1. Standardize each input instance in the mini-batch using the mean and standard deviation calculated above.
  1. Scale and offset each standardized value, using a separate scaling parameter and offset parameter for each input.
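The four steps above can be sketched in plain Python for a 2-D mini-batch of shape [batch size, features]. This is a minimal illustration, not Keras's implementation: the function name is hypothetical, and the small epsilon added to the variance (assumed to be 1e-3 here, the Keras default) only prevents division by zero.

```python
import math

def batch_norm_train(batch, gamma, beta, eps=1e-3):
    """Training-mode batch norm for a mini-batch given as [batch_size][n_features] lists.

    Returns the outputs plus the per-feature mean and variance of this mini-batch.
    """
    n = len(batch)
    n_features = len(batch[0])
    # Steps 1 and 2: per-feature mean and variance across the mini-batch
    means = [sum(row[j] for row in batch) / n for j in range(n_features)]
    variances = [sum((row[j] - means[j]) ** 2 for row in batch) / n
                 for j in range(n_features)]
    outputs = []
    for row in batch:
        out_row = []
        for j in range(n_features):
            # Step 3: standardize (eps keeps the denominator non-zero)
            x_hat = (row[j] - means[j]) / math.sqrt(variances[j] + eps)
            # Step 4: scale and offset with this feature's own gamma and beta
            out_row.append(gamma[j] * x_hat + beta[j])
        outputs.append(out_row)
    return outputs, means, variances

batch = [[1.0, 10.0], [3.0, 20.0], [5.0, 30.0]]  # 3 instances, 2 features (toy values)
outputs, means, variances = batch_norm_train(batch, gamma=[1.0, 1.0], beta=[0.0, 0.0])
print(means)  # [3.0, 20.0]
```

With gamma set to 1 and beta set to 0, each output column has a mean of 0 and a variance just under 1 (because of epsilon); other gamma and beta values let the network undo or adjust that standardization.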

Each input will have its own mean, standard deviation, offset and scaling parameters.

The offset and scaling parameter vectors are learned through backpropagation. They are trainable parameters in Keras.

The mean and standard deviation vectors are computed from each mini-batch during training rather than learned through backpropagation. The vectors Keras stores for them are therefore listed as non-trainable parameters.

Keras keeps an exponential moving average of these statistics during training (stored as the moving_mean and moving_variance weights) and uses these estimates instead of per-batch statistics when making predictions after training.
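As a quick arithmetic check: with Keras's default center=True and scale=True, each normalized input gets four values, namely gamma (scale) and beta (offset), which are trainable, and moving_mean and moving_variance, which are not. The helper name below is hypothetical.

```python
def batch_norm_param_counts(n_features):
    """Parameter counts for a Keras-style BatchNormalization layer over n_features inputs."""
    trainable = 2 * n_features      # gamma (scale) and beta (offset)
    non_trainable = 2 * n_features  # moving_mean and moving_variance
    return trainable, non_trainable

# e.g. a BatchNormalization layer applied to 784 flattened 28x28 pixels
trainable, non_trainable = batch_norm_param_counts(784)
print(trainable, non_trainable, trainable + non_trainable)  # 1568 1568 3136
```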

Batch Normalization adds parameters and computations to the model, so each epoch takes longer. However, this is usually offset by faster convergence: fewer epochs are needed to reach the same performance.

Batch Normalization also slows down predictions, since there are more computations to perform. However, after training it is possible to fuse a Batch Normalization layer into the previous layer to avoid this cost. This is done by updating the previous layer's weights and biases so that it directly produces outputs with the scale and offset the Batch Normalization layer would have produced.
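The fusion can be checked numerically. Assuming the previous layer is a dense layer z = xW + b and the Batch Normalization layer runs in inference mode (using its moving averages), each output column j of W can be multiplied by gamma_j / sqrt(var_j + eps), with the moving mean and beta folded into the bias. The helper names are hypothetical; this is a plain-Python sketch, not a Keras API.

```python
import math

def dense(x, W, b):
    # x: [n_in], W: [n_in][n_out], b: [n_out]
    return [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]

def batch_norm_inference(z, gamma, beta, moving_mean, moving_var, eps=1e-3):
    # Inference mode: use the stored moving averages instead of batch statistics
    return [gamma[j] * (z[j] - moving_mean[j]) / math.sqrt(moving_var[j] + eps) + beta[j]
            for j in range(len(z))]

def fuse_dense_bn(W, b, gamma, beta, moving_mean, moving_var, eps=1e-3):
    """Fold an inference-mode batch norm into the preceding dense layer's weights and bias."""
    scale = [gamma[j] / math.sqrt(moving_var[j] + eps) for j in range(len(b))]
    W_fused = [[W[i][j] * scale[j] for j in range(len(b))] for i in range(len(W))]
    b_fused = [(b[j] - moving_mean[j]) * scale[j] + beta[j] for j in range(len(b))]
    return W_fused, b_fused

# Toy layer with 2 inputs and 2 outputs (made-up values)
x = [1.0, 2.0]
W = [[0.5, -1.0], [2.0, 0.25]]
b = [0.1, -0.2]
gamma, beta = [1.5, 0.8], [0.0, 0.3]
moving_mean, moving_var = [0.2, -0.1], [4.0, 0.5]

unfused = batch_norm_inference(dense(x, W, b), gamma, beta, moving_mean, moving_var)
W_f, b_f = fuse_dense_bn(W, b, gamma, beta, moving_mean, moving_var)
fused = dense(x, W_f, b_f)  # a single layer, same result as dense followed by batch norm
```

The fused layer does one matrix multiply and one bias add, so the extra per-prediction work of the Batch Normalization layer disappears.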

Using Batch Normalization in Keras

Batch Normalization can be added before or after each hidden layer’s activation function. There is a lot of debate about which is better, and it usually depends on the task and dataset.

To add it after a hidden layer’s activation, add a BatchNormalization layer right after that hidden layer.

To add it before a layer’s activation, remove the activation function from the hidden layer and add it as a separate layer after the BatchNormalization layer. Since the BatchNormalization layer includes an offset parameter per input, the previous layer’s bias is redundant and can be removed (in Keras, by passing use_bias=False).
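One way to see why the bias is redundant: Batch Normalization subtracts the mini-batch mean, so any constant added by the previous layer's bias cancels out before the activation is applied. The sketch below uses plain Python rather than Keras, with hypothetical helper names and made-up values; it standardizes one neuron's pre-activations with and without a bias, applies ReLU, and gets the same outputs up to rounding. In Keras, the corresponding layer order would be Dense(..., use_bias=False), then BatchNormalization(), then Activation(...).

```python
import math

def standardize_feature(z, eps=1e-3):
    """Standardize one feature across a mini-batch (gamma=1, beta=0 for simplicity)."""
    mean = sum(z) / len(z)
    var = sum((v - mean) ** 2 for v in z) / len(z)
    return [(v - mean) / math.sqrt(var + eps) for v in z]

def relu(v):
    return max(0.0, v)

# Pre-activations of one neuron for a mini-batch of 4 instances (made-up values)
pre_activations = [0.5, -1.2, 2.0, 0.3]
bias = 0.7  # hypothetical bias the previous layer would add

# Batch norm placed before the activation, with and without the bias
without_bias = [relu(v) for v in standardize_feature(pre_activations)]
with_bias = [relu(v) for v in standardize_feature([v + bias for v in pre_activations])]
```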

Hyperparameters

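The momentum hyperparameter is used when updating the exponential moving averages of the mean and standard deviation vectors. Given a moving average $\vec{v}$ and the corresponding statistic $\vec{v}_{\text{new}}$ computed on the current mini-batch, each update is: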

$$\vec{v} \leftarrow \vec{v} \times \text{momentum} + \vec{v}_{\text{new}} \times (1 - \text{momentum})$$

A good momentum value is close to 1 (e.g. 0.9, 0.99, 0.999). Use more 9s for larger datasets and smaller mini-batches.
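The effect of momentum shows up when the same mini-batch statistic is fed into the update rule repeatedly. This plain-Python sketch uses hypothetical helper names and starts the moving average at 0, as Keras does for the moving mean. A momentum closer to 1 gives a smoother average that needs more updates to catch up, which is why more 9s suit larger datasets (more updates per epoch) and smaller, noisier mini-batches.

```python
def update_moving_average(v, v_new, momentum):
    # v <- v * momentum + v_new * (1 - momentum), applied element-wise
    return [a * momentum + b * (1 - momentum) for a, b in zip(v, v_new)]

def track(momentum, steps, true_value=1.0):
    """Apply `steps` updates with the same batch statistic, starting from 0."""
    v = [0.0]
    for _ in range(steps):
        v = update_moving_average(v, [true_value], momentum)
    return v[0]

for m in (0.9, 0.99):
    print(m, round(track(m, steps=10), 4))  # 0.9 0.6513, then 0.99 0.0956
```

After 10 updates the average has covered 1 - momentum^10 of the distance to the true value: about 65% with 0.9 but under 10% with 0.99.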

The axis hyperparameter determines which axis should be normalized. The default value is -1, which means the last axis will be normalized. The mean and standard deviation will be calculated across other axes.

For 2-dimensional data with dimensions [batch size, features], this is exactly what we need: each input feature is normalized using the mean and standard deviation calculated across the mini-batch.

For 3-dimensional data such as MNIST handwritten digits, with dimensions [batch size, height, width], the default causes the mean and standard deviation to be calculated per column of pixels, across all instances in the mini-batch and across all rows in that column. This gives 28 means, 28 standard deviations, 28 scale parameters and 28 shift parameters.

To normalize each pixel independently, set axis=[1, 2]. The statistics are then calculated across the mini-batch only, giving 784 (28 × 28) of each parameter.
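The difference between the two axis settings comes down to which axes the statistics are averaged over. This plain-Python sketch uses hypothetical helper names and toy 2×3 images instead of 28×28 MNIST digits, and computes only the means; the standard deviations are taken over the same axes.

```python
def column_means(x):
    """axis=-1 on [batch, height, width]: one mean per column, over the batch and all rows."""
    batch, height, width = len(x), len(x[0]), len(x[0][0])
    return [sum(x[b][h][w] for b in range(batch) for h in range(height)) / (batch * height)
            for w in range(width)]

def pixel_means(x):
    """axis=[1, 2] on [batch, height, width]: one mean per pixel, over the batch only."""
    batch, height, width = len(x), len(x[0]), len(x[0][0])
    return [[sum(x[b][h][w] for b in range(batch)) / batch for w in range(width)]
            for h in range(height)]

# Two tiny 2x3 "images" (toy values)
x = [
    [[0.0, 1.0, 2.0],
     [3.0, 4.0, 5.0]],
    [[6.0, 7.0, 8.0],
     [9.0, 10.0, 11.0]],
]
print(column_means(x))  # [4.5, 5.5, 6.5]: one mean per column
print(pixel_means(x))   # [[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]: one mean per pixel
```

For 28×28 images the first approach yields 28 means and the second yields 784.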