How Does Cross Entropy Loss Work in PyTorch?
ML @ 🏃♀️Runhouse🏠
Introduction to Cross-Entropy Loss
In machine learning, the choice of loss function is important for training effective models. One of the most widely used loss functions is cross-entropy loss, which is particularly well-suited for classification problems. Cross-entropy measures the difference between the predicted probability distribution and the true probability distribution.
The Math Behind Cross-Entropy Loss
In a multi-class setting, the total cross-entropy loss is the sum of the individual cross-entropy losses for each class. For a single prediction: L = -sum(y_true * log(y_pred))
Here, `y_true` is a one-hot encoded vector of the true class, and `y_pred` is the vector of predicted probabilities for each class.
You can average the per-prediction loss over all of a batch in order to get the average loss.
As you expect, the cross-entropy loss function is minimized when predicted probabilities match the true probabilities. As with all loss functions, as the model learns, the cross-entropy loss will decrease if the model becomes more accurate.
Why is Cross-Entropy Loss Popular?
Alignment with Probability Distributions
Cross-entropy loss calculates the "distance" between the predicted probability distribution of a model and the true distribution (often represented as one-hot encoded labels). By focusing on probabilities, it ensures that the model doesn’t just assign the correct class the highest probability but aims to match the true distribution as closely as possible. This probabilistic approach is especially useful in multi-class settings, as it emphasizes correct classification in terms of confidence levels.
Sensitivity to Class Confidence
Unlike other loss functions (e.g., Mean Squared Error, MSE), cross-entropy is sensitive to how confident the model is in its predictions. It penalizes incorrect, high-confidence predictions more than those with lower confidence. This drives the model to make highly confident predictions only when accurate, improving its ability to discriminate between classes.
Suitability for Multi-Class Classification
Cross-entropy loss handles multiple classes naturally by summing the log-loss over each class. In a multi-class problem, it calculates loss across all possible classes, which is hard to achieve with simpler losses like binary or mean squared error loss. It also scales efficiently as the number of classes increases, making it efficient for large multi-class problems.
Gradient Properties That Facilitate Convergence
Cross-entropy loss typically has well-behaved gradients that lead to faster convergence during training compared to alternatives like MSE, which can suffer from "vanishing gradients" in classification tasks. This makes cross-entropy easier to optimize, especially for deep neural networks, where gradients are crucial for learning.
Better Performance with Softmax
Cross-entropy loss pairs effectively with the softmax activation function in the output layer of a neural network, as softmax converts raw output scores into probability distributions. Combined with softmax, cross-entropy directly reflects the likelihood of the true class, making it a more interpretable and naturally suited loss function for probabilistic outputs.
Using Cross-Entropy Loss in PyTorch
PyTorch provides a implements cross-entropy loss through the `torch.nn.CrossEntropyLoss` module. This function takes two inputs: the model's logits (unnormalized output scores) and the true class labels (as integer indices).
Here's an extremely minimal snippet of using `CrossEntropyLoss` in a PyTorch model:
import torch.nn as nn # Define the loss function criterion = nn.CrossEntropyLoss() # Forward pass through the model logits = model(inputs) loss = criterion(logits, targets) # Backpropagate and update model parameters loss.backward() optimizer.step()
The `CrossEntopyLoss` function handles the conversion from logits to probabilities, as well as the summation over classes. It also supports additional options, such as class weighting, handling of ignored classes, and different reduction modes (sum, mean, etc.).
Example: Training a Simple Model with Cross-Entropy Loss
Let's walk through a complete example of training a simple classification model using cross-entropy loss on the MNIST dataset:
import torch import torch.nn as nn import torch.optim as optim from torchvision.datasets import MNIST from torchvision.transforms import Compose, ToTensor from torch.utils.data import DataLoader # Define the model class MnistClassifier(nn.Module): def __init__(self): super().__init__() self.fc1 = nn.Linear(784, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = x.view(-1, 784) x = torch.relu(self.fc1(x)) x = self.fc2(x) return x # Load the MNIST dataset transform = Compose([ToTensor()]) train_dataset = MNIST(root='./data', train=True, download=True, transform=transform) train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) # Define the model, loss, and optimizer model = MnistClassifier() criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) # Training loop for epoch in range(10): for inputs, targets in train_loader: # Forward pass outputs = model(inputs) loss = criterion(outputs, targets) # Backward pass and optimization optimizer.zero_grad() loss.backward() optimizer.step() print(f'Epoch [{epoch+1}/10], Loss: {loss.item():.4f}')
This example demonstrates how to use `CrossEntropyLoss` to train a simple classifier on the MNIST dataset. By minimizing the cross-entropy loss, the model learns to map the input images to the correct digit classes.
For more information, refer to the PyTorch documentation.
To see how to use Runhouse to simplify model training and ensure your code runs identically across research and production, check out our Torch training example.
Stay up to speed 🏃♀️📩
Subscribe to our newsletter to receive updates about upcoming Runhouse features and announcements.