In this post, we go over what Batch Normalization is and try to develop an intuitive understanding of BatchNorm by visualizing how it works. We will be using PyTorch to build a simple CNN and explore how BatchNorm works inside it.
What is Batch Normalization
Batch Normalization is a deep learning technique that helps the model train faster, allows the use of a higher learning rate, makes the model less dependent on initialization, and can even improve the accuracy of the model. All these benefits have made Batch Normalization one of the most commonly used techniques for training deep neural networks.
In simple words, a Batch Normalization layer normalizes the output of a linear layer using the mean and standard deviation computed on each batch of data, and the normalized output is then fed as input to a non-linear activation layer (ReLU, etc.). It rescales the input to the activation layer to have zero mean and unit standard deviation (μ=0, σ=1), which is why it is usually placed between a linear layer and a non-linear activation.
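As a minimal sketch of this placement, here is a linear layer followed by BatchNorm and a ReLU. The layer sizes and the deliberately unstandardized random input are arbitrary choices for illustration; in train mode, each feature of the BatchNorm output ends up with (almost exactly) zero mean and unit standard deviation over the batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Linear layer -> BatchNorm -> non-linear activation
linear = nn.Linear(8, 4)
bn = nn.BatchNorm1d(4)   # one mean/std per output feature
act = nn.ReLU()

x = torch.randn(32, 8) * 3 + 2   # a batch of 32 samples, deliberately not standardized

pre_bn = linear(x)
post_bn = bn(pre_bn)   # train mode by default: normalizes with this batch's statistics
out = act(post_bn)

# per-feature statistics before and after BatchNorm (biased std, as BatchNorm uses)
print(pre_bn.mean(dim=0), pre_bn.std(dim=0, unbiased=False))
print(post_bn.mean(dim=0), post_bn.std(dim=0, unbiased=False))
```

The standard deviation after BatchNorm is very slightly below 1 because of the small ε added to the variance for numerical stability.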
So the next question is: how are the values normalized? This is easiest to understand by exploring how Batch Normalization works in a CNN, specifically how nn.BatchNorm2d() works.
For each batch, the input to the nn.BatchNorm2d() layer has the shape [N, Cout, H, W], where N is the number of samples in the batch, Cout is the number of output channels of the preceding convolution, and H and W are the height and width of the feature map.
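nn.BatchNorm2d() normalizes the values of each channel (along the Cout dimension) of the input by calculating the mean μc and standard deviation σc of each channel across all images in the batch and all spatial positions, i.e. over the N, H and W dimensions. Apart from this, nn.BatchNorm2d() also has two learnable per-channel parameters, weight (γc) and bias (βc), which are used to scale and shift the normalized output: yc = γc · (xc − μc) / √(σc² + ε) + βc, where ε is a small constant (1e-5 by default) added for numerical stability.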
These steps are (not so) neatly shown below.
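The same per-channel steps can be checked in code. This is a minimal sketch with random data (the shapes, values and the helper name manual_batchnorm2d are illustrative assumptions) that recomputes the train-mode output of nn.BatchNorm2d by hand and compares the two.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

N, C, H, W = 8, 16, 6, 6
x = torch.randn(N, C, H, W) * 0.2 + 0.1   # [N, Cout, H, W] input to the BatchNorm layer

bn = nn.BatchNorm2d(C)
with torch.no_grad():
    # give weight (gamma) and bias (beta) non-default values so the scale/shift step matters
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)

def manual_batchnorm2d(x, weight, bias, eps=1e-5):
    # per-channel statistics over the batch and spatial dims (N, H, W)
    mean = x.mean(dim=(0, 2, 3), keepdim=True)                # mu_c, shape [1, C, 1, 1]
    var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)  # sigma_c^2 (biased estimate)
    x_hat = (x - mean) / torch.sqrt(var + eps)                # normalize
    # scale and shift with the learnable per-channel parameters
    return x_hat * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)

expected = bn(x)   # train mode: uses the current batch statistics
manual = manual_batchnorm2d(x, bn.weight, bn.bias)
print(torch.allclose(expected, manual, atol=1e-5))
```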
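It is also important to note that the BatchNorm layer works differently in train and eval (test) mode. In train mode, μc and σc are calculated on each batch of the training data, and the layer also keeps track of an exponential moving average of them; let's call them running μc and running σc. (PyTorch actually stores the running mean and the running variance, in the running_mean and running_var buffers, and updates them with momentum=0.1 by default.) In eval mode, the running μc and running σc are used to compute the normalized output instead. Because of this, if model.eval() isn't called during testing, the layer keeps normalizing with per-batch statistics and keeps updating its running averages, and things can go bad real quick.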
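The difference between the two modes can be sketched as follows, assuming PyTorch's documented defaults (momentum=0.1, running mean initialized to 0 and running variance to 1, and an unbiased variance estimate stored in running_var). The random batch is illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm2d(3)               # momentum=0.1 by default
x = torch.randn(4, 3, 5, 5) * 2 + 1  # one batch with mean ~1 and std ~2 per channel

# train mode: normalize with batch statistics and update the running averages
bn.train()
y_train = bn(x)

batch_mean = x.mean(dim=(0, 2, 3))
batch_var_unbiased = x.var(dim=(0, 2, 3), unbiased=True)
# running stats start at mean 0 / var 1 and move 10% of the way toward the batch stats
expected_running_mean = 0.9 * torch.zeros(3) + 0.1 * batch_mean
expected_running_var = 0.9 * torch.ones(3) + 0.1 * batch_var_unbiased
print(torch.allclose(bn.running_mean, expected_running_mean, atol=1e-5))
print(torch.allclose(bn.running_var, expected_running_var, atol=1e-5))

# eval mode: normalize with the running statistics; they are not updated
bn.eval()
with torch.no_grad():
    y_eval = bn(x)
rm = bn.running_mean.view(1, -1, 1, 1)
rv = bn.running_var.view(1, -1, 1, 1)
manual_eval = (x - rm) / torch.sqrt(rv + bn.eps)   # weight=1, bias=0 at initialization
print(torch.allclose(y_eval, manual_eval, atol=1e-5))
print(torch.allclose(y_train, y_eval, atol=1e-3))  # the two modes give different outputs
```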
Visualizing Batch Normalization in a CNN
Now let's try to better understand BatchNorm with some visualizations. We will use a binary image classification problem on a subset of the CIFAR-10 dataset. Our model will classify if an image is a
We are using a simple ResNet model for the classifier.
This model gives us a training accuracy of 95% and a validation accuracy of 89%, which is decent performance for a model with ~70k parameters.
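The post's exact architecture is not reproduced here, so the following is only a hypothetical stand-in: a small residual network whose attribute names (res_block1.bn1) match the layer explored below, with a 16-channel BatchNorm2d between a convolution and a ReLU. The layer sizes, pooling and class names SmallResNet/ResBlock are assumptions, and its parameter count differs from the ~70k of the original model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):
    """Hypothetical residual block: conv -> BatchNorm -> ReLU, plus a skip connection."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        return out + x   # skip connection

class SmallResNet(nn.Module):
    """Hypothetical stand-in for the post's classifier (not the author's exact model)."""
    def __init__(self, channels=16, n_classes=2):
        super().__init__()
        self.conv_in = nn.Conv2d(3, channels, kernel_size=3, padding=1)
        self.res_block1 = ResBlock(channels)
        self.res_block2 = ResBlock(channels)
        self.fc = nn.Linear(channels * 8 * 8, n_classes)

    def forward(self, x):
        out = F.max_pool2d(F.relu(self.conv_in(x)), 2)  # 32x32 -> 16x16
        out = self.res_block1(out)
        out = F.max_pool2d(self.res_block2(out), 2)     # 16x16 -> 8x8
        return self.fc(out.flatten(1))

model = SmallResNet()
logits = model(torch.randn(4, 3, 32, 32))   # a batch of 4 CIFAR-10-sized images
print(logits.shape)          # torch.Size([4, 2])
print(model.res_block1.bn1)  # BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
```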
We choose the model.res_block1.bn1 module, a BatchNorm2d layer with 16 channels, as the candidate for further exploration. I've always found that the best way to understand what's happening within a neural network is to capture the intermediate inputs, outputs, parameters, etc., and visualize them; this is exactly what we will do here. PyTorch offers hooks to capture these intermediate values.
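The post's code uses FastAI's Hook class; as an alternative, here is a minimal sketch using plain PyTorch's register_forward_hook. The tiny nn.Sequential model, the random batches and the names captured/save_bn_io are hypothetical stand-ins; with the post's model you would register the hook on model.res_block1.bn1 instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# hypothetical stand-in model; with the post's model, hook model.res_block1.bn1 instead
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)
bn_layer = model[1]

captured = {"inputs": [], "outputs": []}

def save_bn_io(module, inputs, output):
    # forward hooks receive the module, a tuple of positional inputs, and the output
    captured["inputs"].append(inputs[0].detach().cpu())
    captured["outputs"].append(output.detach().cpu())

handle = bn_layer.register_forward_hook(save_bn_io)

model.train()
for _ in range(3):   # stand-in for iterating over a DataLoader
    model(torch.randn(8, 3, 32, 32))

handle.remove()      # stop capturing once done

print(len(captured["inputs"]), captured["inputs"][0].shape)
```

Detaching and moving the tensors to the CPU keeps the hook from holding on to the autograd graph or GPU memory while values accumulate over an epoch.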
Visualizing Input and Output
Now let's run the model for 1 training epoch and plot the overall distribution of the input and output values for each of the 16 channels in the model.res_block1.bn1 layer, across all batches.
The input values of most channels seem to be roughly centered within the [-0.5, 0.5] range and have a low standard deviation, so their distributions are narrow. The outputs, in contrast, are normalized: each channel is centered near 0 with a standard deviation close to 1.
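To get per-channel distributions like these from hooked tensors, the captured [N, C, H, W] batches can be flattened into one row of values per channel. This is a sketch: the helper name per_channel_values is hypothetical, and the synthetic inputs (narrow distributions with different centers) only imitate the shape of the data, not the post's actual values.

```python
import torch
import torch.nn as nn

def per_channel_values(batches):
    """Concatenate [N, C, H, W] tensors and return the values of each channel as rows of a [C, N*H*W] tensor."""
    x = torch.cat(batches, dim=0)                       # [total_N, C, H, W]
    return x.transpose(0, 1).reshape(x.shape[1], -1)    # [C, total_N*H*W]

torch.manual_seed(0)
bn = nn.BatchNorm2d(16)
# illustrative inputs: narrow per-channel distributions centered between -0.5 and 0.5
centers = torch.linspace(-0.5, 0.5, 16).view(1, -1, 1, 1)
inputs = [torch.randn(8, 16, 16, 16) * 0.1 + centers for _ in range(5)]
outputs = [bn(x).detach() for x in inputs]   # train mode: each batch is normalized

in_vals = per_channel_values(inputs)
out_vals = per_channel_values(outputs)
print(in_vals.mean(dim=1), in_vals.std(dim=1))
print(out_vals.mean(dim=1), out_vals.std(dim=1))
# each row can then be histogrammed, e.g. with matplotlib's plt.hist(row.numpy(), bins=50)
```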
Visualizing running μ and σ, weight and bias
We can also visualize how the running μ and σ and the learnable parameters weight and bias change when the model is trained for 50 epochs.
The running mean and running var plots reiterate the input distribution we saw in the previous plot. The running mean plot also shows that the mean of the channel values can vary significantly from batch to batch; BatchNorm removes this variation from its output.
bn.weight and bn.bias show slight variation from their initial values (1 and 0, respectively) as we train the model.
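Tracking these values over training only requires copying the layer's buffers and parameters after every epoch. This is a minimal sketch: the tiny model, the random images and labels, the 3 "epochs" of 10 batches and the helper name snapshot_bn are all hypothetical stand-ins for the post's ResNet, CIFAR-10 subset and 50-epoch training run.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def snapshot_bn(bn):
    """Copy a BatchNorm layer's running statistics and learnable parameters."""
    return {
        "running_mean": bn.running_mean.detach().clone(),
        "running_var": bn.running_var.detach().clone(),
        "weight": bn.weight.detach().clone(),  # gamma, initialized to 1
        "bias": bn.bias.detach().clone(),      # beta, initialized to 0
    }

# hypothetical tiny model and random data standing in for the post's setup
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 2),
)
bn = model[1]
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

history = [snapshot_bn(bn)]
for epoch in range(3):
    model.train()
    for _ in range(10):   # 10 random batches per "epoch"
        images = torch.randn(16, 3, 32, 32)
        labels = torch.randint(0, 2, (16,))
        optimizer.zero_grad()
        loss = loss_fn(model(images), labels)
        loss.backward()
        optimizer.step()
    history.append(snapshot_bn(bn))   # one snapshot per epoch

# [n_snapshots, 16] tensors, one column per channel, ready to plot against the epoch
weights = torch.stack([h["weight"] for h in history])
running_means = torch.stack([h["running_mean"] for h in history])
print(weights.shape, running_means.shape)
```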
With these visualizations, we are now equipped to form an intuition of "why Batch Normalization works".
Why does Batch Normalization work?
The original Batch Normalization paper (Ioffe and Szegedy, 2015) explains that BatchNorm works by reducing Internal Covariate Shift, i.e. the change in the distribution of each layer's inputs during training as the parameters of the preceding layers change. There has been some debate on whether this is really why it works.
We can see that BatchNorm normalizes the input given to the non-linear activation layer. Our intuitive understanding is based on the BatchNorm output's interaction with the non-linear activation layer.
Every non-linear activation function has a range of inputs in which it responds strongly; let's call this the activation region. For ReLU it is the positive inputs (x > 0), where the gradient is 1, and for tanh it is roughly [-1, 1]. Outside this region, the rate of change (gradient) of the activation function is small (for ReLU, it is exactly zero for negative inputs); let's call this the saturated region. When passing inputs to the activation layer, it helps if they have a stable distribution that lies in the activation region, because the rate of change is high there, so the inputs consistently receive larger gradients. If the input data lies in the saturated region of the activation function, the gradient of the activation with respect to its input is very low. Higher gradients from the activation functions help gradient descent make faster progress during backpropagation.
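The difference between the two regions is easy to see with autograd. This sketch evaluates the local gradients of tanh and ReLU at a handful of arbitrarily chosen inputs: tanh's gradient, 1 − tanh(x)², is 1 at x = 0 but drops below 0.01 at |x| = 3, and ReLU's gradient is 0 for non-positive inputs (PyTorch uses 0 at exactly x = 0) and 1 for positive ones.

```python
import torch

x = torch.tensor([-3.0, -0.5, 0.0, 0.5, 3.0], requires_grad=True)

# gradient of tanh at each input: d/dx tanh(x) = 1 - tanh(x)^2
torch.tanh(x).sum().backward()
tanh_grad = x.grad.clone()

x.grad = None   # reset before computing the next gradient
# gradient of ReLU: 1 for positive inputs, 0 otherwise
torch.relu(x).sum().backward()
relu_grad = x.grad.clone()

print(tanh_grad)
print(relu_grad)
```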
Now we can appreciate why Batch Normalization helps optimize the model faster. By rescaling the inputs of the activation layer to keep them out of the saturated region, it keeps the gradients from vanishing as they flow backward through the activation.
So from this, we can say that the intuition behind BatchNorm helping the model optimize faster is:
Batch Normalization rescales the inputs of activation functions to have a stable distribution in the activation region of non-linear activation functions. This keeps the inputs out of the saturated region during the forward pass, so the gradients are not lost during the backward pass, which helps gradient descent optimize faster.
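This intuition can be checked numerically with a small experiment. The sketch below assumes pre-activations drawn far from zero (mean 4, std 2, chosen arbitrarily to mimic a poorly scaled layer) and compares the average local tanh gradient with and without a freshly initialized BatchNorm1d in front of the activation; mean_tanh_grad is a hypothetical helper.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# pre-activations far from zero, e.g. from a poorly scaled linear layer (illustrative values)
pre_act = torch.randn(256, 8) * 2 + 4

def mean_tanh_grad(z):
    """Average local gradient d tanh(z)/dz = 1 - tanh(z)^2 over all elements."""
    return (1 - torch.tanh(z) ** 2).mean().item()

bn = nn.BatchNorm1d(8)   # train mode: normalizes each feature with the batch statistics
normalized = bn(pre_act).detach()

print(f"without BatchNorm: {mean_tanh_grad(pre_act):.4f}")
print(f"with BatchNorm:    {mean_tanh_grad(normalized):.4f}")
```

Without BatchNorm, most inputs sit in tanh's saturated region and the average gradient is small; after normalization, most inputs fall in the activation region and the average gradient is several times larger.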
To me, this seems like a verifiable and easy-to-understand explanation of why Batch Normalization works. I hope this helps everyone who gives it a read 😄. The code for the model and the visualizations is available in the Colab notebook. It's also possible to visualize the other BatchNorm layers in the model, which could be an interesting thing to do.
Deep Learning with PyTorch book - The ideas in the post and parts of the code are from this amazing book by Eli Stevens, Luca Antiga, and Thomas Viehmann.
W&B PyTorch study group - The study group organized by Sanyam Bhutani, and its awesome community, was a major inspiration for writing this post.
FastAI - The Hook class used in the code is from the FastAI library.