Implementation of Transfer Learning in Computer Vision using PyTorch
Here’s a simple example of how to apply transfer learning with a pre-trained model in PyTorch. It performs object detection with a Faster R-CNN model from the torchvision library that has already been trained on the COCO dataset, so no training is needed before running it. Here’s a brief explanation of its steps:
- Import Libraries:
  - The necessary libraries (`torch`, `torchvision`, `PIL`, `matplotlib`) are imported, and `functional as F` from `torchvision.transforms` is imported for image transformation functions.
- Load Pre-trained Model:
  - A pre-trained Faster R-CNN model (`fasterrcnn_resnet50_fpn`) is loaded and set to evaluation mode using `model.eval()`.
- Load Image Function:
  - The `load_image` function takes an image path, loads the image using `PIL`, converts it to RGB format, and then converts it to a tensor using `F.to_tensor`.
- Object Detection Function:
  - The `detect_objects` function takes the model, an image tensor, and a threshold value.
  - It determines the device (GPU if available, otherwise CPU) and moves the model and image to this device.
  - Inference is performed inside `torch.no_grad()` to disable gradient calculation.
  - The output is a dictionary of predicted boxes, labels, and scores; detections with scores below the threshold are filtered out.
- Plot Detections Function:
  - The `plot_detections` function takes the image tensor and the filtered detections.
  - The image tensor is converted to a NumPy array and transposed to the [H, W, C] format for plotting.
  - A `matplotlib` figure and axis are created, and the image is displayed.
  - Bounding boxes for detected objects are drawn on the image using `patches.Rectangle` with red borders.
  - The plot is shown with the title ‘Object Detections’.
- Main Execution:
  - An image is loaded from the specified path using `load_image`.
  - Object detection is performed on the image using `detect_objects`.
  - The image with detected bounding boxes is plotted using `plot_detections`.
Here is the entire script with these steps annotated:
```python
import torch
import torchvision
from PIL import Image
from torchvision.transforms import functional as F
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Load a Faster R-CNN model pre-trained on COCO
# (torchvision >= 0.13; older versions use pretrained=True instead of weights=)
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
model.eval()  # Set the model to evaluation mode

# Function to load an image and convert it to a tensor
def load_image(image_path):
    image = Image.open(image_path).convert("RGB")
    image = F.to_tensor(image)  # float tensor in [0, 1] with shape [C, H, W]
    return image

# Function to perform object detection
def detect_objects(model, image, threshold=0.5):
    # Move the model and the image to the same device (GPU if available)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    image = image.to(device)

    # Perform inference without tracking gradients
    with torch.no_grad():
        outputs = model([image])

    # Keep only detections whose score reaches the threshold
    detections = outputs[0]
    scores = detections['scores']
    keep = scores >= threshold
    filtered_detections = {k: v[keep].cpu() for k, v in detections.items()}
    return filtered_detections

# Function to plot the image with detected bounding boxes
def plot_detections(image, detections):
    # Convert the [C, H, W] tensor to an [H, W, C] NumPy array for matplotlib
    image = image.permute(1, 2, 0).cpu().numpy()

    # Create a figure and axis
    fig, ax = plt.subplots(1, figsize=(12, 9))

    # Display the image
    ax.imshow(image)

    # Plot each bounding box (boxes are [x1, y1, x2, y2] in pixels)
    for box in detections['boxes']:
        x1, y1, x2, y2 = box.tolist()
        width = x2 - x1
        height = y2 - y1
        rect = patches.Rectangle(
            (x1, y1), width, height, linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(rect)

    # Set plot title and show plot
    ax.set_title('Object Detections')
    plt.show()

# Load an image
image_path = "pawangunjan.jpg"
image = load_image(image_path)

# Perform object detection
detections = detect_objects(model, image, threshold=0.5)

# Plot the image with detections
plot_detections(image, detections)
```
Conclusion
Transfer learning is a versatile and effective technique for enhancing computer vision models, enabling them to achieve high performance with limited data and reduced training time. By leveraging pre-trained models, practitioners can build robust solutions for a wide range of applications, from image classification to object detection and beyond.
Transfer Learning for Computer Vision
Transfer learning is a powerful technique in computer vision in which a model pre-trained on a large dataset is fine-tuned for a different but related task. This approach reuses the knowledge gained during the initial training to improve performance and reduce training time on the new task. Here’s an overview of transfer learning for computer vision: