Quick Recap

This chapter has guided you through the essentials of building, training, and evaluating neural networks in PyTorch. Here’s a quick recap of the key concepts to remember.

  • Data Management. PyTorch provides the Dataset and DataLoader classes to load and iterate over batches of data. They can be combined with the torchvision.transforms.v2 API to apply transformations to image data before it is fed into the network.
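
    A minimal sketch of such a pipeline, assuming torchvision is installed (the dataset choice, batch size, and normalization values are illustrative):

    ```python
    import torch
    from torch.utils.data import DataLoader
    from torchvision import datasets
    from torchvision.transforms import v2

    # Convert PIL images to float tensors and normalize them.
    transform = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.5], std=[0.5]),  # illustrative statistics
    ])

    train_data = datasets.MNIST(root="data", train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=64, shuffle=True)

    images, labels = next(iter(train_loader))  # one batch: (64, 1, 28, 28) and (64,)
    ```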

  • Building Neural Networks. PyTorch makes it easy to define neural networks by subclassing torch.nn.Module and implementing the forward method using the layers available in torch.nn.
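
    As a sketch, a small fully connected classifier (the layer sizes are illustrative):

    ```python
    import torch.nn as nn

    class MLP(nn.Module):
        def __init__(self, in_features=784, num_classes=10):
            super().__init__()
            # Layers defined here are registered as parameters of the module.
            self.net = nn.Sequential(
                nn.Flatten(),
                nn.Linear(in_features, 128),
                nn.ReLU(),
                nn.Linear(128, num_classes),
            )

        def forward(self, x):
            # Return raw logits; the loss function applies the activation.
            return self.net(x)

    model = MLP()
    ```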

  • Loss Function. PyTorch provides a wide range of loss functions that can be used to train a model. The choice of loss function depends on the task at hand, as illustrated by the sketch after this list.

    • For binary classification, the network should be trained with nn.BCEWithLogitsLoss, which combines the sigmoid activation and the binary cross-entropy loss.

    • For multi-class classification, the network should be trained with nn.CrossEntropyLoss, which combines the softmax activation and the negative log-likelihood loss.

    • For regression, the network should be trained with nn.MSELoss.
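
    For example (the tensor shapes are illustrative; note that both losses expect raw logits, nn.CrossEntropyLoss takes integer class targets, and nn.BCEWithLogitsLoss takes float targets):

    ```python
    import torch
    import torch.nn as nn

    # Multi-class: logits of shape (batch, num_classes), integer class labels.
    logits = torch.randn(8, 10)
    targets = torch.randint(0, 10, (8,))
    loss = nn.CrossEntropyLoss()(logits, targets)

    # Binary: one logit per example, float targets in {0.0, 1.0}.
    binary_logits = torch.randn(8, 1)
    binary_targets = torch.randint(0, 2, (8, 1)).float()
    binary_loss = nn.BCEWithLogitsLoss()(binary_logits, binary_targets)
    ```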

  • Optimizer. PyTorch provides a wide range of optimizers that can be used to train a model, such as SGD, Adam, and RMSprop from the torch.optim module. The choice of optimizer depends on the task at hand, but Adam or AdamW is generally a good choice.
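
    A minimal sketch, reusing the model from the sketch above (the learning rate and weight decay values are illustrative):

    ```python
    import torch

    # AdamW decouples the weight decay from the gradient update.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
    ```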

  • Training Neural Networks. The training loop in PyTorch typically involves iterating over batches of data using a DataLoader, computing predictions with the model, calculating the prediction error using a loss function, and updating the model parameters with an optimizer. It is a good practice to isolate the training logic in a separate function, so that it can be reused easily to experiment with different architectures and hyperparameters.
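
    A training function along these lines might look as follows (a sketch, assuming the model, loader, loss function, and optimizer from the sketches above; device handling is omitted for brevity):

    ```python
    def train_one_epoch(model, loader, loss_fn, optimizer):
        model.train()
        total_loss = 0.0
        for inputs, targets in loader:
            optimizer.zero_grad()           # reset gradients from the previous step
            preds = model(inputs)           # forward pass
            loss = loss_fn(preds, targets)  # prediction error
            loss.backward()                 # backpropagation
            optimizer.step()                # parameter update
            total_loss += loss.item()
        return total_loss / len(loader)
    ```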

  • Overfitting. Neural networks tend to overfit the training data, i.e., they perform well on the training data but poorly on unseen data. It is a good practice to monitor the performance of the model on a separate validation set during training. This can be done quite easily by modifying the training loop to include an evaluation phase.
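
    For example, an evaluation phase for a classification model might look like this (a sketch; torch.no_grad() disables gradient tracking, since no parameter update happens here):

    ```python
    import torch

    @torch.no_grad()
    def evaluate(model, loader, loss_fn):
        model.eval()  # e.g. disables dropout, freezes batch-norm statistics
        total_loss, correct, count = 0.0, 0, 0
        for inputs, targets in loader:
            preds = model(inputs)
            total_loss += loss_fn(preds, targets).item()
            correct += (preds.argmax(dim=1) == targets).sum().item()
            count += targets.size(0)
        return total_loss / len(loader), correct / count
    ```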

  • Evaluation. Implementing evaluation metrics from scratch can be tedious and error-prone. Third-party libraries like TorchEval, TorchMetrics, and scikit-learn provide a wide range of pre-implemented metrics that can be easily integrated into your deep learning pipeline.
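
    For instance, with TorchMetrics (assuming the torchmetrics package is installed; the number of classes is illustrative, and model and loader come from the sketches above):

    ```python
    from torchmetrics.classification import MulticlassAccuracy

    metric = MulticlassAccuracy(num_classes=10)
    for inputs, targets in loader:
        preds = model(inputs)
        metric.update(preds, targets)  # accumulate statistics batch by batch
    accuracy = metric.compute()        # aggregate over all batches seen so far
    ```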