Pretraining#

Training a deep network for object localization from scratch can be slow and may require a large amount of data. To accelerate training and benefit from prior knowledge, you can use a pretrained model as a starting point. PyTorch provides many pretrained models in torchvision.models. These models are trained on large datasets (ImageNet, COCO, …) and can provide rich visual features that generalize well to new tasks.

Feature Extraction#

For single-object localization, you will use MobileNetV3 as a feature extractor. The procedure is as follows:

  • Load the pretrained MobileNetV3 model.

  • Pass each image from your dataset through the backbone to obtain a feature tensor.

  • Store these feature tensors on disk, along with the corresponding labels and bounding box annotations.

After this step, the localization model is trained only on the extracted features, not on the original images. This greatly reduces training time and memory requirements, since the expensive convolutional computations of MobileNetV3 are done only once.

The following code snippet shows how to implement a feature extractor using MobileNetV3. Note that the extracted features are not flattened, so you will need to apply this operation in your model before feeding them into a fully connected layer.

import torch
from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights


class MobileNetExtractor(torch.nn.Module):
    """
    Extracts features from a batch of images with MobileNetV3.
    If `pool` is True, applies global average pooling to the features, 
    which significantly reduces their size.

    Note that the features are not flattened. 
    Add this operation in your model as needed:
    ```
    x = torch.flatten(x, 1)
    ```
    """

    def __init__(self, pool: bool = False):
        super().__init__()
        self.device = torch.device(
            "cuda" if torch.cuda.is_available()
            else "mps" if torch.backends.mps.is_available()
            else "cpu"
        )
        self.model = mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.DEFAULT)
        self.model.eval()
        self.model.to(self.device)
        self.pool = pool

    @torch.no_grad()  # feature extraction does not need gradients
    def forward(self, images: torch.Tensor) -> torch.Tensor:
        images = images.to(self.device)
        feat = self.model.features(images)

        if self.pool:
            feat = self.model.avgpool(feat)

        return feat.cpu()
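
As a quick sanity check, you can pass a random batch through the extractor and inspect the shape of the result (an illustrative snippet; the exact feature dimensions depend on the input size and the pool flag).

extractor = MobileNetExtractor(pool=True)
images = torch.rand(4, 3, 224, 224)       # dummy batch of four 224x224 RGB images
features = extractor(images)              # expected shape: (4, 960, 1, 1) with pooling
flat = torch.flatten(features, 1)         # shape (4, 960), ready for a fully connected layer
print(features.shape, flat.shape)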

Preprocessing#

When using a pretrained model, you must apply the same preprocessing as was used during its original training. This usually includes resizing, cropping, and normalization. Most importantly, any preprocessing applied to the image must also be applied consistently to the corresponding bounding boxes. For example:

  • If the image is resized, bounding box coordinates must be scaled accordingly.

  • If the image is cropped, bounding boxes must be shifted and clipped to match the new image region.

Ensuring that preprocessing is consistently applied to both images and bounding boxes is essential for correct training, since the model learns to map visual features to precise spatial coordinates. Fortunately, TorchVision provides utilities to help with this process.
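
Before looking at those utilities, here is the bookkeeping for the resize case above, written out by hand (an illustrative sketch with made-up sizes and coordinates):

# Manual version, for illustration only: the v2 transforms below do this automatically.
W, H = 500, 375                       # original image size (width, height)
w, h = 224, 224                       # image size after resizing
x1, y1, x2, y2 = 48, 240, 195, 371    # bounding box in the original image
scaled_box = (x1 * w / W, y1 * h / H, x2 * w / W, y2 * h / H)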

Images#

Pretrained models in TorchVision come with predefined preprocessing transforms that match the training conditions. For MobileNetV3, you can obtain the appropriate image preprocessing as follows.

weights = MobileNet_V3_Large_Weights.DEFAULT
preprocess = weights.transforms()
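
The object returned by weights.transforms() is an ImageClassification preset. Its resize_size and crop_size fields describe the geometry applied to the images, and the same values are reused below to transform the bounding boxes consistently (the printed values depend on the weights version, so treat them as an example).

# Inspect the geometric parameters of the preset.
print(preprocess.resize_size, preprocess.crop_size)  # e.g. [232] [224]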

Bounding boxes#

Most pretrained classification models in TorchVision (such as MobileNetV3) were trained using the legacy transforms from torchvision.transforms. These transforms work for images but are not compatible with bounding boxes. To solve this, TorchVision introduced a new transform API in torchvision.transforms.v2.

The v2 API provides the same functionality as v1, but extends it to work on all the data types defined in torchvision.tv_tensors, including images and bounding boxes. Once the bounding boxes of an image are wrapped in a BoundingBoxes object, the v2 transforms can be applied to them in the same way as to the image. Geometric transforms (like Resize, Crop, Flip) modify bounding box coordinates consistently, while pixel-level transforms (like Normalize) do not affect the bounding boxes.
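
For illustration, the snippet below applies ImageNet-style v2 transforms to a dummy image together with a wrapped bounding box; the geometric steps update the box coordinates, while normalization leaves them unchanged (sizes and coordinates are made up).

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

image = tv_tensors.Image(torch.rand(3, 375, 500))
boxes = tv_tensors.BoundingBoxes(
    [[48, 240, 195, 371]], format="XYXY", canvas_size=(375, 500)
)

transform = v2.Compose([
    v2.Resize(232),       # geometric: rescales the image and the box coordinates
    v2.CenterCrop(224),   # geometric: shifts and clips the box to the cropped region
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # pixel-level: box untouched
])

out_image, out_boxes = transform(image, boxes)
print(out_boxes, out_boxes.canvas_size)  # coordinates now refer to the 224x224 crop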

Implementation#

Both the VOCDataset and FilteredVOCDataset classes accept a transforms argument (note the trailing s) that allows you to preprocess samples on the fly. You can define a custom transform class like the one provided below to apply ImageNet-style preprocessing to both images and bounding boxes. Your task is to implement the operations used to preprocess the bounding boxes.

import torch
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F
from torchvision.transforms._presets import ImageClassification
from torchvision.tv_tensors._dataset_wrapper import VOC_DETECTION_CATEGORY_TO_IDX


class SingleClassPreprocess(torch.nn.Module):
    """
    Joint preprocessing for images and bounding boxes in VOC dataset.
    """

    def __init__(self, category: str, preprocess: ImageClassification):
        super().__init__()
        self.transform = preprocess
        self.target_transform = v2.Compose([
            FilterBoundingBoxes(category),   # TODO: Implement this class
            ResizeBoundingBoxes(preprocess), # TODO: Implement this class
            NormalizeBoundingBoxes()         # TODO: Implement this class
        ])

    def forward(self, image, target):
        image = self.transform(image)
        target = self.target_transform(target)
        return image, target

More specifically, you need to implement the following classes.

  • FilterBoundingBoxes

    • It removes all bounding boxes except the one corresponding to a specific category. If no such object is present, it creates a (0, 0, 0, 0) bounding box.

    • It also sets the label to 1 if an object of the specified category is present, otherwise 0.

    • Note 1: The category index is obtained from the dictionary VOC_DETECTION_CATEGORY_TO_IDX.

    • Note 2: The bounding box must be a tensor of shape (1, 4) and of type BoundingBoxes.

  • ResizeBoundingBoxes

    • It resizes and center crops the bounding box to match the image preprocessing.

    • Note 1: Resizing uses the resize_size value of the ImageClassification object.

    • Note 2: Center cropping uses the crop_size value of the ImageClassification object.

  • NormalizeBoundingBoxes

    • It normalizes the bounding box coordinates to be in the range [0, 1] relative to the image size.

    • Note: The image size is obtained from the canvas_size field of the BoundingBoxes object.

The general structure of each transform class is as follows.

class TransformBoundingBoxes(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, target):
        boxes = target['boxes']
        labels = target['labels']

        ... # Modify boxes and labels as needed

        target['boxes'] = boxes
        target['labels'] = labels
        return target
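
As an example of how to fill in this template, the simplest of the three transforms, NormalizeBoundingBoxes, could look roughly like this (one possible sketch, not the only valid implementation):

class NormalizeBoundingBoxes(torch.nn.Module):
    """Possible sketch: rescale box coordinates to [0, 1] relative to the image size."""

    def forward(self, target):
        boxes = target['boxes']
        h, w = boxes.canvas_size                             # size of the preprocessed image
        scale = torch.tensor([w, h, w, h], dtype=torch.float32)
        target['boxes'] = boxes.to(torch.float32) / scale    # relative (x1, y1, x2, y2) coordinates
        return target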

Usage#

To use the SingleClassPreprocess transform, simply instantiate it with the desired category and the image preprocessing, then pass it to a dataset (either VOCDataset or FilteredVOCDataset).

from torchvision.datasets import VOCDetection, wrap_dataset_for_transforms_v2

CATEGORY = "person"

weights = MobileNet_V3_Large_Weights.DEFAULT
preprocess = weights.transforms()
transforms = SingleClassPreprocess(CATEGORY, preprocess)

voc = VOCDetection(".data/", transforms=transforms)
voc = wrap_dataset_for_transforms_v2(voc)

image, target = voc[0]

To check for correctness, you can visualize an image and its bounding box before and after applying the preprocessing. (For visualization, you need to convert the bounding box back to absolute coordinates.)

[Figure: the sample image and its bounding box before the transform (original) and after preprocessing.]

Summary#

In a nutshell, this is what you should do in preparation for training.

  • Instantiate the FilteredVOCDataset dataset with the SingleClassPreprocess transform.

  • Create a DataLoader to iterate over the dataset in batches. (No shuffle.)

  • Loop through the dataset to extract features from the images. Store the resulting features, labels, and bounding boxes in three separate lists.

  • Once the loop is complete, convert the lists into PyTorch tensors and save them using torch.save.

Afterwards, you can load the tensors back with torch.load and wrap them in a TensorDataset. This dataset can then be used for training without any further preprocessing.
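
Putting these steps together, the preparation script might look roughly like the sketch below. The FilteredVOCDataset constructor arguments and the file name features.pt are assumptions; adapt them to your own setup.

from torch.utils.data import DataLoader, TensorDataset

extractor = MobileNetExtractor(pool=True)
transforms = SingleClassPreprocess(CATEGORY, preprocess)
dataset = FilteredVOCDataset(".data/", category=CATEGORY, transforms=transforms)  # hypothetical signature
loader = DataLoader(dataset, batch_size=32, shuffle=False)

all_features, all_labels, all_boxes = [], [], []
for images, targets in loader:
    feats = extractor(images)                        # expensive convolutions run only once
    all_features.append(torch.flatten(feats, 1))
    all_labels.append(targets["labels"].reshape(-1))
    all_boxes.append(targets["boxes"].reshape(-1, 4))

features = torch.cat(all_features)
labels = torch.cat(all_labels)
boxes = torch.cat(all_boxes)
torch.save((features, labels, boxes), "features.pt")

# Later, reload the tensors and wrap them in a TensorDataset for training.
features, labels, boxes = torch.load("features.pt")
train_set = TensorDataset(features, labels, boxes)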