Image Dataset

Images are a type of data commonly used in deep learning. PyTorch provides tools to load and preprocess this type of data efficiently. The general workflow for creating a dataset from image files is as follows.

  • Step 1. Organize your images into folders, and remove any corrupt images.

  • Step 2. Define a sequence of transformations to apply to each image. This can include resizing, cropping, conversion to tensors, and normalization.

  • Step 3. Use the ImageFolder class to create a dataset that reads images from the folders, applies the transforms, and assigns labels based on the folder structure.

  • Step 4. Split the dataset into training and held-out (validation or test) sets using the Subset class.

For a more detailed explanation of this topic, you can refer to this tutorial and this tutorial.

import torch
import torchvision
from torchvision.transforms import v2

from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import pathlib
import os

Organize images into folders

The first step is to organize your images into a specific folder structure. Within the root directory, each subfolder represents a class. The images contained in a subfolder are treated as belonging to the corresponding class. For example, if you have a dataset of 3 classes (cats, dogs, and birds), your folder structure might look like this.

root/
│
├── cat/
│   ├── cat001.jpg
│   ├── cat002.jpg
│   └── ...
│
├── dog/
│   ├── dog001.jpg
│   ├── dog002.jpg
│   └── ...
│
└── bird/
    ├── bird001.jpg
    ├── bird002.jpg
    └── ...
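As a quick sanity check on this layout, the class folders and their image counts can be listed with pathlib. The sketch below is only an illustration: the helper name count_images_per_class is hypothetical, and it runs on a throwaway directory filled with empty placeholder files, so in practice you would point it at your own root directory.

```python
import pathlib
import tempfile


def count_images_per_class(root, pattern="*.jpg"):
    """Return {class_name: number of files matching pattern} for each subfolder of root."""
    root = pathlib.Path(root)
    return {
        class_dir.name: len(list(class_dir.glob(pattern)))
        for class_dir in sorted(root.iterdir())
        if class_dir.is_dir()
    }


# Build a toy layout with empty placeholder files (not real images).
with tempfile.TemporaryDirectory() as tmp:
    toy_root = pathlib.Path(tmp)
    for class_name, n_files in [("cat", 3), ("dog", 2), ("bird", 1)]:
        (toy_root / class_name).mkdir()
        for i in range(1, n_files + 1):
            (toy_root / class_name / f"{class_name}{i:03d}.jpg").touch()

    counts = count_images_per_class(toy_root)

print(counts)
```

A class folder with a count of zero, or an unexpected extra folder, is worth fixing before building the dataset, since every subfolder of the root is treated as a class.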

Removing corrupted images

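It is common to have corrupted images in a dataset. A simple way to check for them is to read the first few bytes of each file and check whether they contain the string “JFIF”, which is part of the header of JPEG files stored in the JFIF format. If this string is not present, the file is either corrupted, in a different format, or a JPEG variant that does not use a JFIF header (for example, Exif files written by many cameras), so review what the check flags before deleting anything you care about.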

Additionally, as an extra precaution, we can read and write all the images to ensure that they are in a consistent format. This is a more time-consuming process, however, and should only be done if necessary.

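data_folder = pathlib.Path(".data/cats_vs_dogs/PetImages")

num_deleted = 0

# Materialize the file list first, since files are deleted inside the loop
for path in list(data_folder.rglob("*.jpg")):

    # Check header: the JFIF marker sits within the first few bytes.
    # Note that peek() may return more bytes than requested.
    with open(path, "rb") as file:
        delete = b"JFIF" not in file.peek(10)

    # Optional (slow): read and re-write the image to normalize its format
    #if not delete:
    #    img = plt.imread(path)
    #    plt.imsave(path, img)

    # Delete image with bad header
    if delete:
        num_deleted += 1
        path.unlink()

print(f"Deleted {num_deleted} images.")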
Deleted 0 images.
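The header test above is deliberately loose. For comparison, a stricter variant can be sketched on raw bytes: a JFIF file starts with the JPEG start-of-image marker FF D8 and carries the identifier JFIF at byte offset 6. The function name has_jfif_header is hypothetical, and JPEG files that use an Exif header instead would be rejected by this check, so treat it as a sketch under those assumptions rather than a general validity test.

```python
def has_jfif_header(first_bytes: bytes) -> bool:
    """Check the first bytes of a file for a JFIF-style JPEG header."""
    starts_with_soi = first_bytes[:2] == b"\xff\xd8"  # JPEG start-of-image marker
    has_identifier = first_bytes[6:10] == b"JFIF"     # identifier inside the APP0 segment
    return starts_with_soi and has_identifier


# Hand-built byte strings standing in for the start of three files.
jfif_like = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01"
png_like = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR"
truncated = b"\xff\xd8"

results = [has_jfif_header(data) for data in (jfif_like, png_like, truncated)]
print(results)
```

In practice the argument would come from reading the first 10 bytes of a file opened in binary mode.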

Preprocessing

The PyTorch API allows you to transform images on the fly while loading them. The dataset classes provided by TorchVision normally accept a transform argument, which is applied to each image, and a target_transform argument, which is applied to each label. There are many transformations available in TorchVision. This example provides a visual illustration for some of them.
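To make the roles of the two arguments concrete, here is a minimal sketch of a dataset that follows the same pattern. ToyDataset is a hypothetical class written for illustration, not TorchVision's implementation, and the "images" are plain lists of pixel values so the example stays small.

```python
from torch.utils.data import Dataset


class ToyDataset(Dataset):
    """Hypothetical in-memory dataset mimicking the transform / target_transform pattern."""

    def __init__(self, samples, transform=None, target_transform=None):
        self.samples = samples  # list of (image, label) pairs
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        image, label = self.samples[index]
        if self.transform is not None:
            image = self.transform(image)          # applied to the image only
        if self.target_transform is not None:
            label = self.target_transform(label)   # applied to the label only
        return image, label


# Stand-in data: pixel lists as images, class names as labels.
toy = ToyDataset(
    samples=[([0, 128, 255], "cat"), ([255, 255, 0], "dog")],
    transform=lambda pixels: [p / 255 for p in pixels],
    target_transform=lambda name: {"cat": 0, "dog": 1}[name],
)

image, label = toy[0]
print(image, label)
```

Because the transforms run inside the item lookup, they are applied each time a sample is read rather than once when the dataset is created.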

Convert & Normalize

TorchVision datasets such as ImageFolder load images as PIL images by default. For training, we need the images as PyTorch tensors with values scaled to the range [0, 1]. To make these conversions, we use the following transforms provided in the module torchvision.transforms.v2.

  • ToImage - Convert a PIL Image to the Image type, which is a subclass of torch.Tensor.

  • ToDtype - Convert the data values to the given dtype (here, floats) and, when scale=True is passed, rescale them to the range [0, 1].

minimal_preprocess = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True)
])

Resize & Crop

Images in a dataset may have different sizes. To ensure that all images have the same size, and possibly the same aspect ratio, we can use the following transformations.

  • Resize - Resize the input image to the given size. When the size is a single integer, the shorter edge of the image is matched to it and the longer edge is scaled proportionally to maintain the aspect ratio. When the size is a (height, width) pair, the image is resized to exactly those dimensions.

  • CenterCrop - Crop the center of the image to the given size. If the image is smaller than the given size, it is padded with zeros.

custom_preprocess = v2.Compose([
    v2.ToImage(),
    v2.Resize(128),
    v2.CenterCrop(128),
    v2.ToDtype(torch.float32, scale=True),
])

Pre-trained model

Pre-trained models expect images to be preprocessed in exactly the same way as during their training. TorchVision bundles the necessary preprocessing transforms with each set of model weights. They are obtained by calling the transforms() method of the weights object.

weights = torchvision.models.MobileNet_V3_Large_Weights.DEFAULT

specific_preprocess = weights.transforms()

Class ImageFolder

The ImageFolder class is a dataset provided by TorchVision that reads images from a directory tree and derives the label of each image from the name of the subfolder that contains it. Class names are sorted alphabetically and mapped to integer indices starting at 0. The constructor takes the path to the root directory of the dataset, and an optional transform argument that is applied to each image as it is loaded.

dataset = torchvision.datasets.ImageFolder(data_folder, transform=custom_preprocess)

print("Images: ", len(dataset))
print("Classes:", dataset.classes)
Images:  23410
Classes: ['Cat', 'Dog']

Data split

When working with a dataset, it is common to split the data into training, validation, and test sets. The ImageFolder class does not provide a built-in way to split the data, but you can use the Subset class to create subsets of the dataset. The Subset class takes a dataset and a list of indices as arguments and returns a view of the dataset containing only the specified indices. The indices can be generated using the train_test_split function provided by the scikit-learn library; passing the labels through the stratify argument keeps the class proportions the same in both subsets. The code below makes a single training/test split.

train_idx, test_idx = train_test_split(
    range(len(dataset)),
    stratify=dataset.targets,  # keep the class proportions in both subsets
    test_size=0.2,
    shuffle=True,
    random_state=42,
)

train_ds = torch.utils.data.Subset(dataset, train_idx)
test_ds  = torch.utils.data.Subset(dataset, test_idx)

print("Train set:", len(train_ds))
print("Test set: ", len(test_ds))
Train set: 18728
Test set:  4682
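If a separate validation set is also needed, one possible approach is to call train_test_split twice: once to hold out a portion of the data, and once more to divide that portion into validation and test indices. The sketch below uses a made-up label list and a plain list as the stand-in dataset; the 60/20/20 proportions are an assumption chosen for illustration.

```python
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

# Stand-in data: 100 samples, 60 of class 0 and 40 of class 1.
targets = [0] * 60 + [1] * 40
samples = [f"image_{i}" for i in range(len(targets))]
indices = list(range(len(targets)))

# First split: 60% training, 40% held out.
train_idx, holdout_idx = train_test_split(
    indices, stratify=targets, test_size=0.4, random_state=42
)

# Second split: divide the held-out part equally into validation and test.
holdout_targets = [targets[i] for i in holdout_idx]
val_idx, test_idx = train_test_split(
    holdout_idx, stratify=holdout_targets, test_size=0.5, random_state=42
)

train_ds = Subset(samples, train_idx)
val_ds = Subset(samples, val_idx)
test_ds = Subset(samples, test_idx)

print(len(train_ds), len(val_ds), len(test_ds))
```

The second call stratifies on the labels of the held-out indices only, which keeps the class balance roughly equal across all three subsets.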