Deep Metric Learning

Deep metric learning is a subfield of deep learning that focuses on learning a distance function over data points. Unlike classification, which maps inputs to discrete labels, metric learning trains a network to map inputs into an embedding space in which similar samples lie close together and dissimilar samples lie far apart. Distances in this space then serve as the similarity measure, which is why the approach is widely used in applications such as face recognition, image retrieval, signature verification, and person re-identification.
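Once such an embedding function is trained, comparing two samples reduces to measuring the distance between their embeddings. The minimal sketch below assumes a hypothetical `EmbeddingNet` (a small, untrained MLP over 32-dimensional inputs) and a hypothetical `retrieve` helper that ranks gallery items by Euclidean distance using `torch.cdist`. A real system would use a trained backbone suited to its data.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EmbeddingNet(nn.Module):
    """Hypothetical embedding network: maps inputs to unit-length vectors."""

    def __init__(self, in_dim: int, emb_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 128),
            nn.ReLU(),
            nn.Linear(128, emb_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # L2-normalize so all embeddings lie on the unit hypersphere
        return F.normalize(self.net(x), p=2, dim=1)


def retrieve(query: torch.Tensor, gallery: torch.Tensor, k: int = 1) -> torch.Tensor:
    """Return indices of the k gallery embeddings closest to each query."""
    dists = torch.cdist(query, gallery)  # Euclidean, shape (num_queries, num_gallery)
    return dists.topk(k, dim=1, largest=False).indices


# Usage with an untrained model: embed a gallery and some queries, then search.
torch.manual_seed(0)
model = EmbeddingNet(in_dim=32)
with torch.no_grad():
    gallery = model(torch.randn(10, 32))
    query = model(torch.randn(3, 32))
print(retrieve(query, gallery, k=2).shape)  # torch.Size([3, 2])
```

Training only changes how `EmbeddingNet` places samples in the space; the retrieval step stays the same.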

Metric learning methods can be broadly grouped into two families according to the type of loss function used to train the model.

  • Contrastive approaches. These methods explicitly define relationships between tuples of samples, such as pairs or triplets. They pull similar samples together and push dissimilar ones apart, usually until their distance exceeds a margin. Examples include pairwise (contrastive) loss, triplet loss, and quadruplet loss.

  • Non-contrastive approaches. These methods do not explicitly compare pairs or triplets of samples. Instead, they shape the embeddings with class-level constraints, such as per-class centers or angular margins on a softmax classifier. Examples include center loss, cosine-margin loss (CosFace), and additive angular margin loss (ArcFace).
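To make the two families concrete, here is a minimal PyTorch sketch of representative losses. The names `contrastive_loss`, `triplet_loss`, and `CenterLoss` are hypothetical. The pairwise loss follows one common convention (label 1 for similar pairs, a squared hinge on the margin for dissimilar pairs); other formulations exist. The center loss here learns its centers by plain gradient descent, which is an assumption: the original formulation updates centers with a separate rule and combines the loss with softmax cross-entropy.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def contrastive_loss(x1, x2, same, margin=1.0):
    """Pairwise contrastive loss (assumed convention: same=1.0 for similar pairs)."""
    d = F.pairwise_distance(x1, x2)
    pos = same * d.pow(2)                         # pull similar pairs together
    neg = (1 - same) * F.relu(margin - d).pow(2)  # push dissimilar pairs past the margin
    return (pos + neg).mean()


def triplet_loss(anchor, positive, negative, margin=1.0):
    """Anchor must be closer to the positive than to the negative by `margin`."""
    d_ap = F.pairwise_distance(anchor, positive)
    d_an = F.pairwise_distance(anchor, negative)
    return F.relu(d_ap - d_an + margin).mean()


class CenterLoss(nn.Module):
    """Non-contrastive example: pull each embedding toward its class center."""

    def __init__(self, num_classes, emb_dim):
        super().__init__()
        # Centers are ordinary learnable parameters in this sketch (an assumption)
        self.centers = nn.Parameter(torch.randn(num_classes, emb_dim))

    def forward(self, emb, labels):
        diff = emb - self.centers[labels]
        return 0.5 * diff.pow(2).sum(dim=1).mean()


# Usage on random embeddings
torch.manual_seed(0)
a, p, n = torch.randn(8, 16), torch.randn(8, 16), torch.randn(8, 16)
print(triplet_loss(a, p, n))
print(contrastive_loss(a, p, same=torch.ones(8)))
print(CenterLoss(num_classes=5, emb_dim=16)(a, torch.randint(0, 5, (8,))))
```

PyTorch also ships `torch.nn.TripletMarginLoss`, which with default arguments computes the same value as `triplet_loss` above; the hand-written versions only expose the mechanics.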

Deep metric learning is a powerful tool for tasks that require similarity-based decisions, and the right loss function depends on the application. Contrastive losses suit settings where supervision comes as direct comparisons, such as matched and unmatched pairs, while non-contrastive approaches like ArcFace and CosFace are often reported to give better class separability and generalization, particularly in face recognition. In this chapter, we explore both contrastive and non-contrastive methods for deep metric learning in PyTorch.
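The angular-margin losses can be sketched as a classification head that normalizes both embeddings and class weights, so each logit is the cosine of the angle θ between an embedding and a class weight vector. CosFace subtracts a margin from the target-class cosine, cos(θ) - m, while ArcFace adds the margin to the angle, cos(θ + m); both then multiply the logits by a scale s before softmax cross-entropy. The class name `MarginSoftmaxHead` and the defaults `s=64.0`, `m=0.5` are illustrative assumptions, and this sketch omits refinements found in reference ArcFace implementations, such as special handling when θ + m exceeds π.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MarginSoftmaxHead(nn.Module):
    """Hypothetical head applying an ArcFace ("arc") or CosFace ("cos") margin."""

    def __init__(self, emb_dim, num_classes, s=64.0, m=0.5, kind="arc"):
        super().__init__()
        if kind not in ("arc", "cos"):
            raise ValueError("kind must be 'arc' or 'cos'")
        self.weight = nn.Parameter(torch.empty(num_classes, emb_dim))
        nn.init.xavier_uniform_(self.weight)
        self.s, self.m, self.kind = s, m, kind

    def logits(self, emb, labels):
        # Cosine similarity between normalized embeddings and class weights
        cos = F.linear(F.normalize(emb, dim=1), F.normalize(self.weight, dim=1))
        # Clamp so acos stays finite and its gradient is defined
        cos = cos.clamp(-1 + 1e-7, 1 - 1e-7)
        if self.kind == "arc":
            target = torch.cos(torch.acos(cos) + self.m)  # cos(theta + m)
        else:
            target = cos - self.m                         # cos(theta) - m
        one_hot = F.one_hot(labels, num_classes=cos.size(1)).to(cos.dtype)
        # Apply the margin only to each sample's ground-truth class, then scale
        return self.s * (one_hot * target + (1 - one_hot) * cos)

    def forward(self, emb, labels):
        return F.cross_entropy(self.logits(emb, labels), labels)


# Usage: one training step's loss for a batch of 4 embeddings and 10 classes
torch.manual_seed(0)
head = MarginSoftmaxHead(emb_dim=16, num_classes=10, kind="arc")
emb = torch.randn(4, 16, requires_grad=True)
labels = torch.tensor([0, 3, 3, 9])
loss = head(emb, labels)
loss.backward()
```

After training, the head is typically discarded and the normalized embeddings are compared directly, for example by cosine similarity.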