Applying Batch Normalization in CNN model using PyTorch
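In PyTorch, batch normalization is added to a CNN as a layer from ‘torch.nn’. In the examples below, each BN layer sits between a convolution and its ReLU activation, and its argument is the number of channels produced by that convolution.
To apply batch normalization in a 1D convolutional neural network, use ‘nn.BatchNorm1d()’, which expects input of shape (batch, channels, length).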
import torch
import torch.nn as nn


class CNN1D(nn.Module):
    def __init__(self):
        super(CNN1D, self).__init__()
        self.conv1 = nn.Conv1d(3, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm1d(16)  # Normalizes the 16 output channels of conv1
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm1d(32)  # Normalizes the 32 output channels of conv2
        # padding=1 preserves the sequence length, so this assumes inputs of length 28
        self.fc = nn.Linear(32 * 28, 10)  # Example fully connected layer

    def forward(self, x):
        # x: (batch, 3, 28)
        x = torch.relu(self.bn1(self.conv1(x)))  # Conv -> BN -> ReLU
        x = torch.relu(self.bn2(self.conv2(x)))
        x = torch.flatten(x, 1)  # Reshape to (batch, 32 * 28) for the fully connected layer
        x = self.fc(x)
        return x


# Instantiate the model
model = CNN1D()
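To see what ‘nn.BatchNorm1d()’ actually computes, the sketch below compares its training-mode output with a manual calculation. It assumes a random input shaped like the conv1 output of CNN1D, (batch, 16, 28); the batch size of 8 is arbitrary. Each channel is normalized with the mean and biased variance taken over the batch and length dimensions, then scaled by ‘weight’ (gamma, initialized to 1) and shifted by ‘bias’ (beta, initialized to 0).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed input: 8 sequences, 16 channels, length 28 (shape of conv1's output in CNN1D)
x = torch.randn(8, 16, 28)

bn = nn.BatchNorm1d(16)  # one (gamma, beta) pair per channel
bn.train()               # training mode: normalize with the current batch statistics
out = bn(x)

# Manual computation: per-channel statistics over the batch and length dims (0 and 2)
mean = x.mean(dim=(0, 2), keepdim=True)
var = ((x - mean) ** 2).mean(dim=(0, 2), keepdim=True)  # biased variance
manual = (x - mean) / torch.sqrt(var + bn.eps)
manual = manual * bn.weight.view(1, -1, 1) + bn.bias.view(1, -1, 1)

print(torch.allclose(out, manual, atol=1e-5))
```

Because the statistics are shared across all positions of a sequence, this layer learns only 16 scale and 16 shift parameters, one pair per channel.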
To apply batch normalization in a 2D convolutional neural network, use ‘nn.BatchNorm2d()’, which expects input of shape (batch, channels, height, width).
import torch
import torch.nn as nn


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)  # Normalizes the 16 output channels of conv1
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(32)  # Normalizes the 32 output channels of conv2
        # padding=1 preserves height and width, so this assumes 28x28 inputs
        self.fc = nn.Linear(32 * 28 * 28, 10)  # Example fully connected layer

    def forward(self, x):
        # x: (batch, 3, 28, 28)
        x = torch.relu(self.bn1(self.conv1(x)))  # Conv -> BN -> ReLU
        x = torch.relu(self.bn2(self.conv2(x)))
        x = torch.flatten(x, 1)  # Reshape to (batch, 32 * 28 * 28) for the fully connected layer
        x = self.fc(x)
        return x


# Instantiate the model
model = CNN()
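‘nn.BatchNorm2d()’ behaves differently in training and evaluation mode, which matters when the 2D CNN above is used for inference. The following sketch assumes a random batch shaped like the conv1 output, (4, 16, 28, 28), and the layer's default momentum of 0.1. In training mode the layer normalizes each channel over the batch, height and width dimensions and updates ‘running_mean’ and ‘running_var’; after ‘eval()’, it normalizes with those stored running estimates instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed batch shaped like conv1's output: (batch, channels, height, width)
x = torch.randn(4, 16, 28, 28)

bn = nn.BatchNorm2d(16)  # defaults: eps=1e-5, momentum=0.1

# Training mode: normalize with batch statistics and update the running estimates
bn.train()
_ = bn(x)

batch_mean = x.mean(dim=(0, 2, 3))
# running_mean starts at 0 and is updated as (1 - momentum) * old + momentum * new
expected_running_mean = 0.9 * torch.zeros(16) + 0.1 * batch_mean

# Evaluation mode: normalize with the stored running statistics instead
bn.eval()
with torch.no_grad():
    y = bn(x)
    manual = (x - bn.running_mean.view(1, -1, 1, 1)) / torch.sqrt(
        bn.running_var.view(1, -1, 1, 1) + bn.eps
    )
    manual = manual * bn.weight.view(1, -1, 1, 1) + bn.bias.view(1, -1, 1, 1)

print(torch.allclose(bn.running_mean, expected_running_mean, atol=1e-6))
print(torch.allclose(y, manual, atol=1e-5))
```

In practice this means calling ‘model.eval()’ before validation or inference and ‘model.train()’ before resuming training; otherwise each sample's output depends on the other samples in its batch.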
For a more detailed explanation of the implementation, refer to
What is Batch Normalization in CNN?
Batch normalization (BN) is a technique that normalizes the inputs of a layer using statistics computed over each mini-batch, which stabilizes and often speeds up the training of neural networks, particularly CNNs. This article provides an overview of batch normalization in CNNs along with its implementation in PyTorch and TensorFlow.
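As a concrete illustration of that normalization, the following NumPy sketch implements the training-mode forward pass for a batch of CNN feature maps shaped (N, C, H, W). The function name ‘batch_norm_2d’ is hypothetical, and the sketch leaves out the running averages that frameworks keep for inference. For each channel, it computes the mean and variance over the batch and spatial dimensions, normalizes, and then applies a learnable per-channel scale (gamma) and shift (beta).

```python
import numpy as np


def batch_norm_2d(x, gamma, beta, eps=1e-5):
    """Hypothetical helper: training-mode batch normalization for CNN feature maps.

    x:     array of shape (N, C, H, W)
    gamma: per-channel scale, shape (C,)
    beta:  per-channel shift, shape (C,)
    """
    # 1. Per-channel mean and (biased) variance over the batch and spatial dims
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    # 2. Normalize to zero mean and unit variance; eps avoids division by zero
    x_hat = (x - mean) / np.sqrt(var + eps)
    # 3. Scale and shift with the learnable per-channel parameters
    return gamma.reshape(1, -1, 1, 1) * x_hat + beta.reshape(1, -1, 1, 1)


rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(8, 4, 5, 5))  # arbitrary example input
y = batch_norm_2d(x, gamma=np.ones(4), beta=np.zeros(4))
print(y.mean(axis=(0, 2, 3)), y.std(axis=(0, 2, 3)))
```

With gamma set to ones and beta to zeros, each output channel has a mean close to 0 and a standard deviation close to 1; during training, gamma and beta are learned so the network can undo the normalization where that helps.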
Table of Contents
- Overview of Batch Normalization
- Need for Batch Normalization in CNN model
- How Does Batch Normalization Work in CNN?
  - 1. Normalization within Mini-Batch
  - 2. Scaling and Shifting
  - 3. Learnable Parameters
  - 4. Applying Batch Normalization
  - 5. Training and Inference
- Applying Batch Normalization in CNN model using TensorFlow
- Applying Batch Normalization in CNN model using PyTorch
- Advantages of Batch Normalization in CNN