Overview

The goal of this project is to develop a retrieval system that can search for natural images using hand-drawn sketches as queries. The system works by:

  • computing a similarity score between a query sketch and each photo in the retrieval dataset,

  • ranking the photos based on their similarity score,

  • returning the top-ranked photos as the most relevant results.
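The three steps above can be sketched in a few lines of PyTorch. This is a minimal illustration, not a prescribed implementation. `retrieve` is a hypothetical helper, and it assumes the sketch and photo embeddings are already normalized to unit length, so the dot product equals the cosine similarity.

```python
import torch


def retrieve(sketch_emb: torch.Tensor, photo_embs: torch.Tensor,
             k: int = 5) -> tuple[torch.Tensor, torch.Tensor]:
    """Rank photos by similarity to one sketch and return the top-k.

    Assumes both inputs are already L2-normalized, so the dot product
    equals cosine similarity.
    sketch_emb: (D,) embedding of the query sketch
    photo_embs: (N, D) embeddings of all photos in the retrieval set
    """
    scores = photo_embs @ sketch_emb             # (N,) one score per photo
    k = min(k, photo_embs.shape[0])              # guard against k > N
    top_scores, top_idx = torch.topk(scores, k)  # sorted, highest score first
    return top_scores, top_idx
```

The photo embeddings only need to be computed once for the whole retrieval set; each query then costs a single matrix-vector product followed by a top-k selection.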

You will implement the retrieval system using deep metric learning, where embeddings are learned from two distinct types of images: sketches and photos. One convolutional neural network (CNN) will generate embeddings from photos, while a separate CNN will generate embeddings from sketches. Although the two networks process different input modalities, both are trained to map their inputs into the same embedding space, so that a sketch embedding can be compared directly with a photo embedding.

Pretraining the backbones

The photo and sketch networks consist of two parts: a convolutional backbone and a fully-connected head. To simplify training and improve results, the backbone can be taken from a network that has already been trained on ImageNet, such as ResNet-50 or MobileNet. The head, however, is initialized randomly, since it needs to learn to produce embeddings that work specifically for sketch-photo retrieval.

Recommendation

At the start, it’s best to freeze the backbone and train only the head. This is called feature extraction. Once the head is working well, you can try fine-tuning the entire network.
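A minimal sketch of freezing in PyTorch, assuming the backbone and head are separate modules. The hypothetical `freeze` helper disables gradients and switches the backbone to evaluation mode, and only the head's parameters are given to the optimizer. The toy backbone, layer sizes, and learning rate are placeholders.

```python
import torch
from torch import nn


def freeze(module: nn.Module) -> nn.Module:
    """Turn off gradients for all parameters of `module` and put it in eval mode."""
    for param in module.parameters():
        param.requires_grad_(False)
    module.eval()  # also stops BatchNorm layers from updating their running statistics
    return module


# Hypothetical stand-ins: a tiny "backbone" and a randomly initialized head
backbone = freeze(nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
))
head = nn.Linear(8, 4)

# Only the trainable head parameters are passed to the optimizer
optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)
```

Calling `.train()` on a parent module that contains the frozen backbone switches its BatchNorm layers back to training mode, so call `backbone.eval()` again afterwards. For fine-tuning, the backbone parameters can be unfrozen with `requires_grad_(True)` and added to the optimizer, often with a smaller learning rate.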

Feature extraction

You will use the same pretrained model for both the sketch and photo networks. The pretrained model will be modified to output the features of the global average pooling layer, flattened to a 1D vector. Since the backbone is frozen, you can extract the features once for all photos and sketches, and save them to disk. Then, you will create a Dataset that provides tuples of (photo feature, sketch feature, label). Each item in this dataset corresponds to a sketch-photo pair that belongs to the same category. This dataset will be used to train the heads of the embedding networks.
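The sketch below shows one way to organize this step, under several assumptions: the images come from a DataLoader that yields (image, label) batches, `extract_features` and `SketchPhotoPairs` are hypothetical names, and each sketch is paired with a randomly chosen photo of the same category every time it is accessed (other pairing rules are possible).

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


@torch.no_grad()  # the backbone is frozen, so no autograd graph is needed
def extract_features(backbone: nn.Module, loader: DataLoader,
                     device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the frozen backbone once over a loader of (image, label) batches."""
    backbone.eval().to(device)
    feats, labels = [], []
    for images, y in loader:
        feats.append(backbone(images.to(device)).cpu())
        labels.append(y)
    return torch.cat(feats), torch.cat(labels)


class SketchPhotoPairs(Dataset):
    """Yields (photo feature, sketch feature, label) tuples.

    Hypothetical pairing rule: each sketch is paired with a photo drawn at
    random from the same category every time the item is requested.
    """

    def __init__(self, photo_feats: torch.Tensor, photo_labels: torch.Tensor,
                 sketch_feats: torch.Tensor, sketch_labels: torch.Tensor):
        self.photo_feats = photo_feats
        self.sketch_feats = sketch_feats
        self.sketch_labels = sketch_labels
        # Indices of the photos belonging to each category
        self.photos_by_label = {
            int(c): torch.nonzero(photo_labels == c).flatten()
            for c in photo_labels.unique()
        }

    def __len__(self) -> int:
        return len(self.sketch_feats)

    def __getitem__(self, i: int):
        label = int(self.sketch_labels[i])
        candidates = self.photos_by_label[label]
        j = candidates[torch.randint(len(candidates), (1,))].item()
        return self.photo_feats[j], self.sketch_feats[i], label
```

The tensors returned by `extract_features` can be written with `torch.save` and read back with `torch.load`, so the backbone only has to run once. Because a new photo is drawn on every access, each epoch sees different sketch-photo pairs.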

Head architecture

The head of each network consists of several fully-connected layers with ReLU activations and a final output layer with no activation.

  • The output layer determines the dimensionality of the embedding space. Common choices are 128, 256, or 512 dimensions.

  • The embeddings must be normalized to unit length. Include this normalization in the forward pass of the head, using torch.nn.functional.normalize, so you don’t have to remember to do it later.

  • Batch normalization can help with convergence. You can insert torch.nn.BatchNorm1d between the intermediate layers and their activation functions.
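A head following these guidelines might look like the sketch below. `EmbeddingHead` is a hypothetical name, and the default sizes are assumptions: 2048 input features (the pooled output of ResNet-50), one hidden layer of 1024 units, and a 256-dimensional embedding.

```python
import torch
import torch.nn.functional as F
from torch import nn


class EmbeddingHead(nn.Module):
    """[Linear -> BatchNorm1d -> ReLU] x N, then a plain Linear output layer."""

    def __init__(self, in_dim: int = 2048, hidden_dims: tuple[int, ...] = (1024,),
                 out_dim: int = 256):
        super().__init__()
        layers = []
        prev = in_dim
        for h in hidden_dims:
            layers += [nn.Linear(prev, h), nn.BatchNorm1d(h), nn.ReLU()]
            prev = h
        layers.append(nn.Linear(prev, out_dim))  # output layer: no activation
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize inside forward() so every embedding has unit length
        return F.normalize(self.mlp(x), dim=1)
```

In training mode, BatchNorm1d raises an error for a batch of size one, so creating the DataLoader with `drop_last=True` avoids a lone final sample.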

Training the heads

After the photo-sketch dataset is ready and the network architecture is defined, you will train the heads that produce the sketch and photo embeddings. This is done with deep metric learning, using a loss function that pulls the embeddings of a sketch and its corresponding photo close together while pushing the embeddings of photos from other categories away. Several loss functions can be used for this purpose, such as:

  • Triplet loss,

  • NT-Xent loss,

  • CosFace loss,

  • ArcFace loss.

Recommendation

It is perhaps best to start with triplet loss, as it is the simplest to implement and understand. In this case, the anchor is the sketch, the positive is the corresponding photo, and the negative is a photo from a different category. Triplets are formed within each batch using the batch-all mining strategy, in which every valid (anchor, positive, negative) combination in the batch contributes to the loss.
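For a single triplet, the loss is max(0, d(a, p) - d(a, n) + m), where d is a distance between embeddings and m is the margin. The sketch below is one possible batch-all implementation that follows the `LossFunction` interface (anchor, positive, label) used in the train_step section. It assumes squared Euclidean distance and a margin of 0.2; every photo in the batch that shares the anchor's category acts as a positive and every other photo as a negative, and the loss is averaged only over triplets with positive loss, which is a common batch-all variant. `BatchAllTripletLoss` is a hypothetical name.

```python
import torch
from torch import nn


class BatchAllTripletLoss(nn.Module):
    """Batch-all triplet loss: sketches are anchors, photos are positives/negatives."""

    def __init__(self, margin: float = 0.2):  # margin value is an assumption
        super().__init__()
        self.margin = margin

    def forward(self, anchor: torch.Tensor, positive: torch.Tensor,
                label: torch.Tensor) -> torch.Tensor:
        # dist[i, j] = squared Euclidean distance between sketch i and photo j
        dist = (anchor.pow(2).sum(dim=1, keepdim=True)
                + positive.pow(2).sum(dim=1).unsqueeze(0)
                - 2 * anchor @ positive.T).clamp(min=0)
        # same[i, j] is True when photo j is in sketch i's category
        same = label.unsqueeze(1) == label.unsqueeze(0)
        # triplet[i, p, n] = dist[i, p] - dist[i, n] + margin
        triplet = dist.unsqueeze(2) - dist.unsqueeze(1) + self.margin
        # Valid triplets: photo p shares the anchor's category, photo n does not
        valid = same.unsqueeze(2) & ~same.unsqueeze(1)
        losses = torch.relu(triplet[valid])
        # Average over triplets that still violate the margin (loss > 0)
        num_active = (losses > 0).sum().clamp(min=1)
        return losses.sum() / num_active
```

Since all B x B x B triplets are scored at once, memory use grows with the cube of the batch size.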

Custom train_step function

You can reuse the training loop from the tutorials by providing the appropriate train_step() function. The general approach is similar to the Weak Supervision tutorial, but with a few differences.

  • The photo and sketch networks should be combined into a single PyTorch model, which will be passed to the train_step() function.

    class EmbeddingNet(nn.Module):
    
        def __init__(self, ...):
            super().__init__()
            self.photo_net = ...
            self.sketch_net = ...
    
  • The loss function should take three arguments, corresponding to a batch of sketch embeddings, photo embeddings, and labels.

    class LossFunction(torch.nn.Module):
    
        def __init__(self, ...):
            super().__init__()
            self.margin = ...
    
        def forward(self, anchor: torch.Tensor, positive: torch.Tensor, label: torch.Tensor):
            """
            For each i, anchor[i] and positive[i] belong to the same category label[i].
            For each j ≠ i, anchor[i] and positive[j] may belong to different categories, 
            depending on their labels.
            """
            ...
            return loss
    
  • The train_step() function should compute the embeddings for both photos and sketches, pass them to the loss function, and then continue as usual (backward pass, optimizer step, etc.).

    def train_step(model: EmbeddingNet, 
                   batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
                   loss_fn: torch.nn.Module, 
                   optimizer: torch.optim.Optimizer, 
                   device: torch.device) -> float:
        
        # Unpack the batch
        photos, sketches, labels = batch
    
        # Send data to device
        photos   = photos.to(device)
        sketches = sketches.to(device)
        labels   = labels.to(device)
    
        # Forward pass
        anchors   = model.sketch_net(sketches)
        positives = model.photo_net(photos)
    
        # Loss function
        loss = loss_fn(anchors, positives, labels)
    
        ... # To be continued
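For reference, one possible completion of `train_step` is sketched below. It assumes the tutorial's training loop expects `train_step` to zero the gradients, perform the parameter update, and return the loss as a Python float (as the `-> float` annotation suggests); check this against the loop you are reusing.

```python
import torch
from torch import nn


def train_step(model: nn.Module,
               batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
               loss_fn: nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device) -> float:
    # Unpack the batch and move every tensor to the target device
    photos, sketches, labels = (t.to(device) for t in batch)

    # Forward pass: sketches are anchors, photos are positives
    anchors = model.sketch_net(sketches)
    positives = model.photo_net(photos)
    loss = loss_fn(anchors, positives, labels)

    # Backward pass and parameter update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Return a plain float so the training loop can log it
    return loss.item()
```

With precomputed backbone features, `model.sketch_net` and `model.photo_net` are just the two heads, so this step updates only the heads.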