Implementation of Transfer Learning in Computer Vision using PyTorch

Here’s a simple example of how to use a pre-trained model in PyTorch as the starting point for transfer learning. It performs object detection with a Faster R-CNN model from the torchvision library that was pre-trained on the COCO dataset, and applies it directly (without fine-tuning) to a new image. Here’s a brief explanation of its steps:

  1. Import Libraries:
    • The required libraries (torch, torchvision, PIL, matplotlib) are imported.
    • The functional module from torchvision.transforms is imported as F for its image transformation functions.
  2. Load Pre-trained Model:
    • A Faster R-CNN model (fasterrcnn_resnet50_fpn) with COCO pre-trained weights is loaded and switched to evaluation mode with model.eval().
  3. Load Image Function:
    • The load_image function takes an image path, opens the image with PIL, converts it to RGB, and then converts it with F.to_tensor to a float tensor of shape [C, H, W] with values in [0, 1].
  4. Object Detection Function:
    • The detect_objects function takes the model, an image tensor, and a threshold value.
    • It selects a device (the GPU if available, otherwise the CPU) and moves both the model and the image to it.
    • Inference is performed with torch.no_grad() to disable gradient calculation.
    • The model returns one dictionary per input image containing 'boxes', 'labels' and 'scores'; detections scoring below the threshold are filtered out and the results are moved back to the CPU.
  5. Plot Detections Function:
    • The plot_detections function takes the image tensor and filtered detections.
    • The image tensor is permuted from [C, H, W] to [H, W, C] and converted to a NumPy array for plotting.
    • A matplotlib figure and axis are created, and the image is displayed.
    • Bounding boxes for detected objects are drawn on the image using patches.Rectangle with red borders.
    • The plot is shown with a title ‘Object Detections’.
  6. Main Execution:
    • An image is loaded from the specified path using load_image.
    • Object detection is performed on the image using detect_objects.
    • The image with detected bounding boxes is plotted using plot_detections.
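Steps 4 and 5 rely on two tensor operations that are easy to check in isolation: boolean-mask filtering of every field in the detection dictionary, and permuting a [C, H, W] image to [H, W, C]. The sketch below uses small hand-made tensors in place of real model output (the values are invented for illustration); only the dictionary keys `boxes`, `labels` and `scores` mirror what torchvision detection models return in evaluation mode. `filter_by_score` is a hypothetical helper name.

```python
import torch

# Hypothetical detection output for one image, shaped like a torchvision
# detection model's eval-mode result: N boxes in (x1, y1, x2, y2) format.
detections = {
    "boxes": torch.tensor([[10.0, 20.0, 110.0, 220.0],
                           [50.0, 60.0, 80.0, 90.0],
                           [0.0, 0.0, 30.0, 40.0]]),
    "labels": torch.tensor([1, 3, 18]),
    "scores": torch.tensor([0.92, 0.31, 0.67]),
}

def filter_by_score(detections, threshold=0.5):
    """Keep only detections whose score is at least `threshold`."""
    keep = detections["scores"] >= threshold  # boolean mask of shape [N]
    # The same mask indexes the first dimension of every field,
    # so boxes, labels and scores stay aligned.
    return {k: v[keep] for k, v in detections.items()}

filtered = filter_by_score(detections, threshold=0.5)
print(filtered["labels"].tolist())  # [1, 18]

# A random 3-channel image in PyTorch's [C, H, W] layout.
image = torch.rand(3, 4, 5)
hwc = image.permute(1, 2, 0)  # matplotlib's imshow expects [H, W, C]
print(tuple(hwc.shape))  # (4, 5, 3)
```

Because the mask is computed once from the scores and reused for every key, a dropped box can never leave behind a stray label or score.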

Here is the entire script with these steps annotated:

Python
import torch
import torchvision
from PIL import Image
from torchvision.transforms import functional as F
import matplotlib.pyplot as plt
import matplotlib.patches as patches

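# Load a Faster R-CNN model with weights pre-trained on COCO
# (torchvision >= 0.13; older versions used pretrained=True instead)
weights = torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=weights)
model.eval()  # Set the model to evaluation mode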

# Function to load an image and convert it to a tensor
def load_image(image_path):
    image = Image.open(image_path).convert("RGB")
    image = F.to_tensor(image)
    return image

# Function to perform object detection
def detect_objects(model, image, threshold=0.5):
    # Pick a device and move both the model and the image to it
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    image = image.to(device)

    # Perform inference without tracking gradients
    with torch.no_grad():
        outputs = model([image])

    # The model returns one dict per input image with 'boxes', 'labels' and 'scores';
    # keep only detections whose score is at least the threshold
    detections = outputs[0]
    keep = detections['scores'] >= threshold
    filtered_detections = {k: v[keep].cpu() for k, v in detections.items()}

    return filtered_detections

# Function to plot the image with detected bounding boxes
def plot_detections(image, detections):
    # Permute the [C, H, W] tensor to [H, W, C] and convert it to a NumPy array
    image = image.permute(1, 2, 0).numpy()

    # Create a figure and axis
    fig, ax = plt.subplots(1, figsize=(12, 9))

    # Display the image
    ax.imshow(image)

    # Plot each bounding box; boxes are in (x1, y1, x2, y2) pixel coordinates
    for box in detections['boxes']:
        x1, y1, x2, y2 = box.tolist()  # convert tensor values to Python floats
        width = x2 - x1
        height = y2 - y1
        rect = patches.Rectangle(
            (x1, y1), width, height, linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(rect)

    # Set plot title and show plot
    ax.set_title('Object Detections')
    plt.show()


# Load an image (replace the path with one of your own images)
image_path = "pawangunjan.jpg"
image = load_image(image_path)

# Perform object detection
detections = detect_objects(model, image, threshold=0.5)

# Plot the image with detections
plot_detections(image, detections)

Conclusion

Transfer learning is a versatile and effective technique for enhancing computer vision models, enabling them to achieve high performance with limited data and reduced training time. By leveraging pre-trained models, practitioners can build robust solutions for a wide range of applications, from image classification to object detection and beyond.