## Conclusion
This chapter has guided you through the essentials of building and training neural networks in PyTorch. Here’s a quick recap of the key concepts to remember.
**Data Management.** PyTorch provides the `Dataset` and `DataLoader` classes to load and iterate over batches of data. They can be combined with the `torchvision.transforms.v2` API to apply transformations to image data before it is fed into the network.
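As a reminder of how these pieces fit together, here is a minimal sketch; the `ToyImageDataset` class, the tensor shapes, and the normalization values are illustrative assumptions, and `v2.ToDtype` assumes a recent torchvision release:

```python
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2

# Hypothetical dataset wrapping in-memory tensors; a real project would
# typically load images from disk in __getitem__ instead.
class ToyImageDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images        # e.g., a uint8 tensor of shape (N, C, H, W)
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]

# A transforms.v2 pipeline: convert to float in [0, 1], then normalize.
transform = v2.Compose([
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.5], std=[0.5]),
])

dataset = ToyImageDataset(
    torch.randint(0, 256, (100, 1, 28, 28), dtype=torch.uint8),
    torch.randint(0, 10, (100,)),
    transform=transform,
)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
```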
**Building Neural Networks.** PyTorch makes it easy to define neural networks by subclassing `torch.nn.Module` and implementing the `forward` method using the layers available in `torch.nn`.
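For instance, a small image classifier could be defined as follows; the architecture and layer sizes are arbitrary choices made for illustration:

```python
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, in_features=784, hidden=128, num_classes=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),                      # (N, C, H, W) -> (N, C*H*W)
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        return self.net(x)                     # raw logits, no activation
```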
**Loss Function.** PyTorch provides a wide range of loss functions that can be used to train a model. The choice of loss function depends on the task at hand (see the sketch after this list):
- For binary classification, the network should be trained with `nn.BCEWithLogitsLoss`, which combines the sigmoid activation and the binary cross-entropy loss.
- For multi-class classification, the network should be trained with `nn.CrossEntropyLoss`, which combines the softmax activation and the negative log-likelihood loss.
- For regression, the network should be trained with `nn.MSELoss`.
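A common consequence of these combined losses is that the model should output raw logits, not probabilities. The snippet below illustrates the expected input and target shapes; all tensors are random placeholders:

```python
import torch
import torch.nn as nn

# Binary classification: float targets in {0., 1.}, same shape as the logits.
logits_bin = torch.randn(8, 1)               # raw scores, no sigmoid applied
targets_bin = torch.rand(8, 1).round()
bce = nn.BCEWithLogitsLoss()(logits_bin, targets_bin)

# Multi-class classification: integer class indices as targets.
logits_mc = torch.randn(8, 10)               # raw scores, no softmax applied
targets_mc = torch.randint(0, 10, (8,))
ce = nn.CrossEntropyLoss()(logits_mc, targets_mc)

# Regression: predictions and targets share the same shape.
preds = torch.randn(8, 1)
mse = nn.MSELoss()(preds, torch.randn(8, 1))
```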
**Optimizer.** PyTorch provides a wide range of optimizers that can be used to train a model, such as `SGD`, `Adam`, and `RMSprop` from the `torch.optim` module. The choice of optimizer depends on the task at hand, but `Adam` is generally a good choice.
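Creating an optimizer only requires the model's parameters and a few hyperparameters; the learning rate below is just a common starting point, not a recommendation from this chapter:

```python
import torch

model = MLP()  # the module sketched above; any nn.Module works here
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
```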
**Training Neural Networks.** The training loop in PyTorch typically involves iterating over batches of data using a `DataLoader`, computing predictions with the model, calculating the prediction error using a loss function, and updating the model parameters with an optimizer. It is good practice to isolate the training logic in a separate function, so that it can be reused easily to experiment with different architectures and hyperparameters; a sketch of such a function follows.
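A minimal version of that reusable function might look like this; the name and the device handling are illustrative:

```python
def train_one_epoch(model, loader, loss_fn, optimizer, device="cpu"):
    model.train()
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()             # reset gradients from the previous step
        outputs = model(inputs)           # forward pass: compute predictions
        loss = loss_fn(outputs, targets)  # measure the prediction error
        loss.backward()                   # backpropagate
        optimizer.step()                  # update the model parameters
```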
**Overfitting.** Neural networks tend to overfit the training data, i.e., they perform well on the training data but poorly on unseen data. It is good practice to monitor the performance of the model on a separate validation set during training. This can be done quite easily by extending the training loop with an evaluation phase, as sketched below.
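One way to add that evaluation phase, under the same assumptions as the training function above (`train_loader`, `val_loader`, `loss_fn`, `optimizer`, and `num_epochs` are assumed to be defined elsewhere):

```python
import torch

@torch.no_grad()  # no gradients are needed during evaluation
def evaluate(model, loader, loss_fn, device="cpu"):
    model.eval()  # switch layers such as dropout to evaluation behavior
    total_loss, n_batches = 0.0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        total_loss += loss_fn(model(inputs), targets).item()
        n_batches += 1
    return total_loss / max(n_batches, 1)

for epoch in range(num_epochs):
    train_one_epoch(model, train_loader, loss_fn, optimizer)
    val_loss = evaluate(model, val_loader, loss_fn)
    print(f"epoch {epoch}: validation loss {val_loss:.4f}")
```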