Saving and Loading Model
Method 1: Using torch.save() and torch.load()
The following code saves and loads the model using the built-in functions of the torch module. Passing the model object itself to torch.save() serializes the whole model with pickle, and torch.load() reads it back into memory. Because the saved file refers to the model class by name, the SimpleCNN class definition must be available when the file is loaded.
Python
# Save the entire model object
torch.save(cnn_model, 'cnn_model.pth')

# Load the entire model (recent PyTorch releases need weights_only=False to unpickle a full model)
loaded_model = torch.load('cnn_model.pth', weights_only=False)

# Set the model to evaluation mode
loaded_model.eval()
Output:
SimpleCNN(
(conv1_layer): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2_layer): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(fc1_layer): Linear(in_features=1568, out_features=128, bias=True)
(fc2_layer): Linear(in_features=128, out_features=10, bias=True)
)
Method 2: Using model.state_dict()
Another way to save and load the model is through its state_dict(), a dictionary that maps the name of each parameter and buffer to its tensor. Only this dictionary is written to disk. To load the model, first create a new instance with the same architecture, then copy the stored values into it with load_state_dict(). Because only the parameters are stored, the saved file stays compact and does not depend on pickling the model class. The following code snippet illustrates this method.
Python
# Saving the model
torch.save(cnn_model.state_dict(), 'cnn_model.pth')

# Loading the model
loaded_model = SimpleCNN()
loaded_model.load_state_dict(torch.load('cnn_model.pth'))
print(loaded_model)
Output:
SimpleCNN(
(conv1_layer): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2_layer): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(fc1_layer): Linear(in_features=1568, out_features=128, bias=True)
(fc2_layer): Linear(in_features=128, out_features=10, bias=True)
)
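To confirm that a state_dict round trip really restores the weights, the following minimal sketch saves a model's parameters, loads them into a fresh instance, and compares the tensors. TinyNet, its layer sizes, and the temporary file name are hypothetical stand-ins chosen to keep the example small; they are not the article's SimpleCNN.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for SimpleCNN, kept small for illustration."""

    def __init__(self):
        super().__init__()
        self.fc1_layer = nn.Linear(4, 8)
        self.fc2_layer = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2_layer(torch.relu(self.fc1_layer(x)))


def roundtrip_state_dict(model, path):
    """Save the model's state_dict to `path` and load it into a new TinyNet."""
    torch.save(model.state_dict(), path)
    restored = TinyNet()  # same architecture, freshly initialised weights
    restored.load_state_dict(torch.load(path))
    restored.eval()
    return restored


original = TinyNet()
with tempfile.TemporaryDirectory() as tmp_dir:
    restored = roundtrip_state_dict(original, os.path.join(tmp_dir, "tiny_net.pth"))

# The state_dict maps parameter names to tensors.
state_keys = list(original.state_dict().keys())
print(state_keys)

# Every restored tensor should be identical to the one that was saved.
weights_match = all(
    torch.equal(tensor, restored.state_dict()[name])
    for name, tensor in original.state_dict().items()
)
print(weights_match)
```

The keys of the state_dict come from the attribute names of the layers, which is why the new instance must use the same architecture and the same layer names as the model that was saved.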
Method 3: Saving and Loading using the Checkpoints
The checkpoint method saves a dictionary containing everything needed to resume training, such as the model state_dict, the optimizer state_dict, the current epoch, and the loss. To load the model, the checkpoint file is read back and each entry is restored into a newly created model and its optimizer. This method is demonstrated below:
Python
# Saving the model
checkpoint = {
    'epoch': epoch,
    'model_state_dict': cnn_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # add any other information you want to keep
}
torch.save(checkpoint, 'checkpoint.pth')

# Loading the model
checkpoint = torch.load('checkpoint.pth')
cnn_model = SimpleCNN()
cnn_model.load_state_dict(checkpoint['model_state_dict'])
# The optimizer must already be constructed over cnn_model.parameters()
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
print(cnn_model)
Output:
SimpleCNN(
(conv1_layer): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2_layer): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(fc1_layer): Linear(in_features=1568, out_features=128, bias=True)
(fc2_layer): Linear(in_features=128, out_features=10, bias=True)
)
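The checkpoint workflow can also be sketched end to end to show that a resumed run continues exactly where the original left off. In this sketch, the single linear layer, the SGD optimizer with momentum, the random data, and the helper names build_model_and_optimizer and train_one_epoch are all assumptions made for illustration; in practice the article's SimpleCNN and its own optimizer would take their place.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)


def build_model_and_optimizer():
    """Hypothetical setup; a real project would build its own model here."""
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    return model, optimizer


def train_one_epoch(model, optimizer, inputs, targets):
    """Run a single optimisation step and return the loss as a float."""
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()
    return loss.item()


inputs, targets = torch.randn(16, 4), torch.randn(16, 2)
model, optimizer = build_model_and_optimizer()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "checkpoint.pth")

    # Train for two epochs, then save everything needed to resume.
    for epoch in range(2):
        loss = train_one_epoch(model, optimizer, inputs, targets)
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
        },
        path,
    )

    # Resume: rebuild the objects first, then restore their saved state.
    resumed_model, resumed_optimizer = build_model_and_optimizer()
    checkpoint = torch.load(path)
    resumed_model.load_state_dict(checkpoint["model_state_dict"])
    resumed_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1

# Take one more step with both copies; they should stay in sync.
train_one_epoch(model, optimizer, inputs, targets)
train_one_epoch(resumed_model, resumed_optimizer, inputs, targets)
weights_match = torch.allclose(model.weight, resumed_model.weight)
print(start_epoch, weights_match)
```

Saving the optimizer state matters here because optimizers such as SGD with momentum keep internal buffers. Restoring only the model weights would make the next update differ from the one the original run would have taken.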
Save and Load Models in PyTorch
We often need to reuse an already-trained model in our development environment. Retraining the model every time would waste time and compute, so it is far more practical to save the trained model once and load it whenever it is needed. This article shows how to save and load models in PyTorch.