| name | loss-functions-and-objectives |
| description | Loss function selection - numerical stability, custom losses, multi-task weighting, debugging |
Loss Functions and Objectives Skill
When to Use This Skill
Use this skill when:
- User asks "what loss function should I use?"
- Implementing binary, multi-class, or multi-label classification
- Implementing regression models
- Training on imbalanced datasets (class imbalance)
- Multi-task learning with multiple loss terms
- Custom loss function implementation needed
- Loss goes to NaN or infinity during training
- Loss not decreasing despite valid training loop
- User suggests BCE instead of BCEWithLogitsLoss (RED FLAG)
- User adds softmax before CrossEntropyLoss (RED FLAG)
- Multi-task losses without weighting (RED FLAG)
- Division or log operations in custom loss (stability concern)
- Segmentation, ranking, or specialized tasks
- Loss debugging and troubleshooting
Do NOT use when:
- User has specific bugs unrelated to loss functions
- Only discussing model architecture (no loss questions)
- Loss function already working well and no questions asked
- User needs general training advice (use optimization-algorithms skill)
Core Principles
1. The Critical Importance of Loss Functions
Loss functions are fundamental to deep learning:
- Direct objective that gradients optimize
- Wrong loss → model optimizes wrong thing
- Numerically unstable loss → NaN, training collapse
- Unweighted multi-task → one task dominates
- Mismatched loss for task → poor performance
Common Impact:
- Proper loss selection: 5-15% performance improvement
- Numerical stability: difference between training and crashing
- Class balancing: difference between 95% accuracy (useless) and 85% F1 (useful)
- Multi-task weighting: difference between all tasks learning vs one task dominating
This is NOT optional:
- Every SOTA paper carefully selects and tunes losses
- Loss function debugging is essential skill
- One mistake (BCE vs BCEWithLogitsLoss) can break training
2. Loss Selection Decision Tree
What is your task?
│
├─ Classification?
│ │
│ ├─ Binary (2 classes, single output)
│ │ → Use: BCEWithLogitsLoss (NOT BCELoss!)
│ │ → Model outputs: logits (no sigmoid)
│ │ → Target shape: (batch,) or (batch, 1) with 0/1
│ │ → Imbalanced? Add pos_weight parameter
│ │
│ ├─ Multi-class (N classes, one label per sample)
│ │ → Use: CrossEntropyLoss
│ │ → Model outputs: logits (batch, num_classes) - no softmax!
│ │ → Target shape: (batch,) with class indices [0, N-1]
│ │ → Imbalanced? Add weight parameter or use focal loss
│ │
│ └─ Multi-label (N classes, multiple labels per sample)
│ → Use: BCEWithLogitsLoss
│ → Model outputs: logits (batch, num_classes) - no sigmoid!
│ → Target shape: (batch, num_classes) with 0/1
│ → Each class is independent binary classification
│
├─ Regression?
│ │
│ ├─ Standard regression, squared errors
│ │ → Use: MSELoss (L2 loss)
│ │ → Sensitive to outliers
│ │ → Penalizes large errors heavily
│ │
│ ├─ Robust to outliers
│ │ → Use: L1Loss (MAE)
│ │ → Less sensitive to outliers
│ │ → Linear penalty
│ │
│ └─ Best of both (recommended)
│ → Use: SmoothL1Loss (Huber loss)
│ → L2 for small errors, L1 for large errors
│ → Good default choice
│
├─ Segmentation?
│ │
│ ├─ Binary segmentation
│ │ → Use: BCEWithLogitsLoss or DiceLoss
│ │ → Combine both: α*BCE + (1-α)*Dice
│ │
│ └─ Multi-class segmentation
│ → Use: CrossEntropyLoss or DiceLoss
│ → Imbalanced pixels? Use weighted CE or focal loss
│
├─ Ranking/Similarity?
│ │
│ ├─ Triplet learning
│ │ → Use: TripletMarginLoss
│ │ → Learn embeddings with anchor, positive, negative
│ │
│ ├─ Pairwise ranking
│ │ → Use: MarginRankingLoss
│ │ → Learn x1 > x2 or x2 > x1
│ │
│ └─ Contrastive learning
│ → Use: ContrastiveLoss or NTXentLoss
│ → Pull similar together, push dissimilar apart
│
└─ Multi-Task?
→ Combine losses with careful weighting
→ See Multi-Task Learning section below
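To make the tree concrete, here is a minimal sketch mapping a task label to a sensible default PyTorch criterion. The function name and task strings are illustrative, not a standard API; the sections below cover each branch in detail.
import torch.nn as nn

def default_criterion(task: str) -> nn.Module:
    """Illustrative mapping from task type to a sensible default loss."""
    table = {
        "binary":     nn.BCEWithLogitsLoss(),  # logits in, targets 0/1
        "multiclass": nn.CrossEntropyLoss(),   # logits in, targets are class indices
        "multilabel": nn.BCEWithLogitsLoss(),  # logits in, targets (batch, C) of 0/1
        "regression": nn.SmoothL1Loss(),       # robust default; normalize targets first
    }
    return table[task]

criterion = default_criterion("multiclass")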
Section 1: Binary Classification - BCEWithLogitsLoss
THE GOLDEN RULE: ALWAYS Use BCEWithLogitsLoss, NEVER BCELoss
This is the MOST COMMON loss function mistake in deep learning.
❌ WRONG: BCELoss (Numerically Unstable)
class BinaryClassifier(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(100, 1)
self.sigmoid = nn.Sigmoid() # ❌ DON'T DO THIS
def forward(self, x):
logits = self.fc(x)
return self.sigmoid(logits) # ❌ Applying sigmoid in model
# Training loop
output = model(x) # Probabilities [0, 1]
loss = F.binary_cross_entropy(output, target) # ❌ UNSTABLE!
Why this is WRONG:
- Numerical instability: log(sigmoid(x)) underflows for large negative x
- Gradient issues: sigmoid saturates, then BCE takes the log → compound saturation
- NaN risk: When sigmoid(logits) = 0 or 1, log(0) = -inf
- Slower training: Less stable gradients
✅ RIGHT: BCEWithLogitsLoss (Numerically Stable)
class BinaryClassifier(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(100, 1)
# ✅ NO sigmoid in model!
def forward(self, x):
return self.fc(x) # ✅ Return logits
# Training loop
logits = model(x) # Raw logits (can be any value)
loss = F.binary_cross_entropy_with_logits(logits, target) # ✅ STABLE!
Why this is RIGHT:
- Numerically stable: Uses log-sum-exp trick internally
- Better gradients: Single combined operation
- No NaN: Stable for all logit values
- Faster training: More stable optimization
The Math Behind the Stability
Unstable version (BCELoss):
# BCE computes: -[y*log(σ(x)) + (1-y)*log(1-σ(x))]
# Problem: log(σ(x)) = log(1/(1+exp(-x))) underflows for large negative x
# Example:
x = -100 # Large negative logit
sigmoid(x) = 1 / (1 + exp(100)) ≈ 0 # Underflows to 0
log(sigmoid(x)) = log(0) = -inf # ❌ NaN!
Stable version (BCEWithLogitsLoss):
# BCEWithLogitsLoss uses log-sum-exp trick:
# log(σ(x)) = log(1/(1+exp(-x))) = -log(1+exp(-x))
# Rewritten as: -log1p(exp(-x)) for stability
# For positive x: use log(sigmoid(x)) = -log1p(exp(-x))
# For negative x: use log(sigmoid(x)) = x - log1p(exp(x))
# This is ALWAYS stable!
# Example:
x = -100
log(sigmoid(x)) = -100 - log1p(exp(-100))
= -100 - log1p(≈0)
= -100 # ✅ Stable!
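A small runnable check of the claim above, assuming float32 tensors (exact printed values may vary slightly across PyTorch versions): at a saturated logit, the two-step sigmoid-then-BCE path produces no useful gradient, while the fused loss still does.
import torch
import torch.nn.functional as F

target = torch.tensor([1.0])

# Two-step path: sigmoid(-100) underflows to exactly 0.0 in float32
logit_a = torch.tensor([-100.0], requires_grad=True)
loss_a = F.binary_cross_entropy(torch.sigmoid(logit_a), target)
loss_a.backward()
print(loss_a.item(), logit_a.grad)  # loss is internally clamped (~100); gradient collapses to ~0

# Fused path: gradient is sigmoid(x) - y ≈ -1, so the model can still recover
logit_b = torch.tensor([-100.0], requires_grad=True)
loss_b = F.binary_cross_entropy_with_logits(logit_b, target)
loss_b.backward()
print(loss_b.item(), logit_b.grad)  # loss = 100.0, gradient ≈ -1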
Inference: Converting Logits to Probabilities
# During training
logits = model(x)
loss = F.binary_cross_entropy_with_logits(logits, target)
# During inference/evaluation
logits = model(x)
probs = torch.sigmoid(logits) # ✅ NOW apply sigmoid
predictions = (probs > 0.5).float() # Binary predictions
Handling Class Imbalance with pos_weight
# Dataset: 95% negative (class 0), 5% positive (class 1)
# Problem: Model predicts all negatives → 95% accuracy but useless!
# Solution 1: pos_weight parameter
neg_count = 950
pos_count = 50
pos_weight = torch.tensor([neg_count / pos_count]) # 950/50 = 19.0
loss = F.binary_cross_entropy_with_logits(
logits, target,
pos_weight=pos_weight # Weight positive class 19x more
)
# pos_weight effect:
# - Positive examples contribute 19x to loss
# - Forces model to care about minority class
# - Balances gradient contributions
# Solution 2: Focal Loss (see Advanced Techniques section)
Complete Binary Classification Example
import torch
import torch.nn as nn
import torch.nn.functional as F
class BinaryClassifier(nn.Module):
def __init__(self, input_dim):
super().__init__()
self.fc1 = nn.Linear(input_dim, 64)
self.fc2 = nn.Linear(64, 1) # Single output for binary
def forward(self, x):
x = F.relu(self.fc1(x))
return self.fc2(x) # ✅ Return logits (no sigmoid)
# Training setup
model = BinaryClassifier(input_dim=100)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Handle imbalanced data
class_counts = torch.bincount(train_labels.long())
pos_weight = class_counts[0] / class_counts[1]
# Training loop
model.train()
for x, y in train_loader:
optimizer.zero_grad()
logits = model(x)
loss = F.binary_cross_entropy_with_logits(
logits.squeeze(), # Shape: (batch,)
y.float(), # Shape: (batch,)
        pos_weight=pos_weight if imbalanced else None  # 'imbalanced': a flag you set from your class counts
)
loss.backward()
optimizer.step()
# Evaluation
model.eval()
with torch.no_grad():
logits = model(x_test)
probs = torch.sigmoid(logits) # ✅ Apply sigmoid for inference
preds = (probs > 0.5).float()
# Compute metrics
accuracy = (preds.squeeze() == y_test).float().mean()
# Better: Use F1, precision, recall for imbalanced data
Section 2: Multi-Class Classification - CrossEntropyLoss
THE GOLDEN RULE: Pass Logits (NOT Softmax) to CrossEntropyLoss
❌ WRONG: Applying Softmax Before CrossEntropyLoss
class MultiClassifier(nn.Module):
def __init__(self, num_classes):
super().__init__()
self.fc = nn.Linear(100, num_classes)
self.softmax = nn.Softmax(dim=1) # ❌ DON'T DO THIS
def forward(self, x):
logits = self.fc(x)
return self.softmax(logits) # ❌ Applying softmax in model
# Training
probs = model(x) # Already softmaxed
loss = F.cross_entropy(probs, target) # ❌ WRONG! Double softmax!
Why this is WRONG:
- Double softmax: CrossEntropyLoss applies softmax internally
- Numerical instability: Extra softmax operation
- Wrong gradients: Backprop through unnecessary operation
- Confusion: Model outputs different things in train vs eval
✅ RIGHT: Pass Logits to CrossEntropyLoss
class MultiClassifier(nn.Module):
def __init__(self, num_classes):
super().__init__()
self.fc = nn.Linear(100, num_classes)
# ✅ NO softmax in model!
def forward(self, x):
return self.fc(x) # ✅ Return logits
# Training
logits = model(x) # Shape: (batch, num_classes)
target = ... # Shape: (batch,) with class indices [0, num_classes-1]
loss = F.cross_entropy(logits, target) # ✅ CORRECT!
Target Shape Requirements
# ✅ CORRECT: Target is class indices
logits = torch.randn(32, 10) # (batch=32, num_classes=10)
target = torch.randint(0, 10, (32,)) # (batch=32,) with values in [0, 9]
loss = F.cross_entropy(logits, target) # ✅ Works!
# ❌ WRONG: One-hot encoded target
target_onehot = F.one_hot(target, num_classes=10) # (batch=32, num_classes=10)
loss = F.cross_entropy(logits, target_onehot) # ❌ Type error!
# If you have one-hot, convert back to indices:
target_indices = target_onehot.argmax(dim=1) # (batch,)
loss = F.cross_entropy(logits, target_indices) # ✅ Works!
Handling Class Imbalance with Weights
# Dataset: Class 0: 1000 samples, Class 1: 100 samples, Class 2: 50 samples
# Problem: Model biased toward majority class
# Solution 1: Class weights (inverse frequency)
class_counts = torch.tensor([1000., 100., 50.])
class_weights = 1.0 / class_counts
class_weights = class_weights / class_weights.sum() * len(class_weights)
# Normalizes so weights sum to num_classes
# class_weights ≈ [0.097, 0.968, 1.935]
# Minority classes weighted much higher
loss = F.cross_entropy(logits, target, weight=class_weights)
# Solution 2: Per-sample weighting (same effect, applied manually)
# Weight each sample by inverse class frequency
sample_weights = class_weights[target] # Index into weights
loss = F.cross_entropy(logits, target, reduction='none')
weighted_loss = (loss * sample_weights).mean()
# Solution 3: Focal Loss (see Advanced Techniques section)
Inference: Converting Logits to Probabilities
# During training
logits = model(x)
loss = F.cross_entropy(logits, target)
# During inference/evaluation
logits = model(x) # (batch, num_classes)
probs = F.softmax(logits, dim=1) # ✅ NOW apply softmax
preds = logits.argmax(dim=1) # Or directly argmax logits (same result)
# Why argmax logits works:
# argmax(softmax(logits)) = argmax(logits) because softmax is monotonic
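If you want to convince yourself of the monotonicity comment above, a two-line check suffices:
import torch
import torch.nn.functional as F

logits = torch.randn(8, 10)
assert torch.equal(logits.argmax(dim=1), F.softmax(logits, dim=1).argmax(dim=1))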
Complete Multi-Class Example
class MultiClassifier(nn.Module):
def __init__(self, input_dim, num_classes):
super().__init__()
self.fc1 = nn.Linear(input_dim, 128)
self.fc2 = nn.Linear(128, num_classes)
def forward(self, x):
x = F.relu(self.fc1(x))
return self.fc2(x) # ✅ Return logits
# Training setup
num_classes = 10
model = MultiClassifier(input_dim=100, num_classes=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Compute class weights for imbalanced data
class_counts = torch.bincount(train_labels)
class_weights = 1.0 / class_counts.float()
class_weights = class_weights / class_weights.sum() * num_classes
# Training loop
model.train()
for x, y in train_loader:
optimizer.zero_grad()
logits = model(x) # (batch, num_classes)
loss = F.cross_entropy(logits, y, weight=class_weights)
loss.backward()
optimizer.step()
# Evaluation
model.eval()
with torch.no_grad():
logits = model(x_test)
probs = F.softmax(logits, dim=1) # For calibration analysis
preds = logits.argmax(dim=1) # Class predictions
accuracy = (preds == y_test).float().mean()
Section 3: Multi-Label Classification
Use BCEWithLogitsLoss (Each Class is Independent)
# Task: Predict multiple tags for an image
# Example: [dog, outdoor, sunny] → target = [1, 0, 1, 0, 1, 0, ...]
class MultiLabelClassifier(nn.Module):
def __init__(self, input_dim, num_classes):
super().__init__()
self.fc = nn.Linear(input_dim, num_classes)
# ✅ NO sigmoid! Return logits
def forward(self, x):
return self.fc(x) # (batch, num_classes) logits
# Training
logits = model(x) # (batch, num_classes)
target = ... # (batch, num_classes) with 0/1 for each class
# Each class is independent binary classification
loss = F.binary_cross_entropy_with_logits(logits, target.float())
# Inference
logits = model(x_test)
probs = torch.sigmoid(logits) # Per-class probabilities
preds = (probs > 0.5).float() # Threshold each class independently
# Example output:
# probs = [0.9, 0.3, 0.8, 0.1, 0.7, ...]
# preds = [1.0, 0.0, 1.0, 0.0, 1.0, ...] (dog=yes, outdoor=no, sunny=yes, ...)
Handling Imbalanced Labels
# Some labels are rare (e.g., "sunset" appears in 2% of images)
# Solution 1: Per-class pos_weight
label_counts = train_labels.sum(dim=0) # Count per class
num_samples = len(train_labels)
neg_counts = num_samples - label_counts
pos_weights = neg_counts / label_counts # (num_classes,)
loss = F.binary_cross_entropy_with_logits(
logits, target.float(),
pos_weight=pos_weights
)
# Solution 2: Focal loss per class (see Advanced Techniques)
Section 4: Regression Losses
MSELoss (L2 Loss) - Default Choice
# Mean Squared Error: (pred - target)^2
pred = model(x) # (batch, output_dim)
target = ... # (batch, output_dim)
loss = F.mse_loss(pred, target)
# Characteristics:
# ✅ Smooth gradients
# ✅ Penalizes large errors heavily (squared term)
# ❌ Sensitive to outliers (outliers dominate loss)
# ❌ Can be numerically large if targets not normalized
# When to use:
# - Standard regression tasks
# - Targets are normalized (similar scale to predictions)
# - Outliers are rare or not expected
L1Loss (MAE) - Robust to Outliers
# Mean Absolute Error: |pred - target|
pred = model(x)
loss = F.l1_loss(pred, target)
# Characteristics:
# ✅ Robust to outliers (linear penalty)
# ✅ Numerically stable
# ❌ Non-smooth at zero (gradient discontinuity)
# ❌ Equal penalty for all error magnitudes
# When to use:
# - Outliers present in data
# - Want robust loss
# - Median prediction preferred over mean
SmoothL1Loss (Huber Loss) - Best of Both Worlds
# Smooth L1: L2 for small errors, L1 for large errors
pred = model(x)
loss = F.smooth_l1_loss(pred, target, beta=1.0)
# Formula:
# loss = 0.5 * (pred - target)^2 / beta if |pred - target| < beta
# loss = |pred - target| - 0.5 * beta otherwise
# Characteristics:
# ✅ Smooth gradients everywhere
# ✅ Robust to outliers (L1 for large errors)
# ✅ Fast convergence (L2 for small errors)
# ✅ Best default for regression
# When to use:
# - General regression (RECOMMENDED DEFAULT)
# - Uncertainty about outliers
# - Want fast convergence + robustness
Target Normalization (CRITICAL)
# ❌ WRONG: Unnormalized targets
pred = model(x) # Model outputs in range [0, 1] (e.g., after sigmoid)
target = ... # Range [1000, 10000] - NOT NORMALIZED!
loss = F.mse_loss(pred, target) # Huge loss values, bad gradients
# ✅ RIGHT: Normalize targets
# Option 1: Min-Max normalization to [0, 1]
target_min = train_targets.min()
target_max = train_targets.max()
target_normalized = (target - target_min) / (target_max - target_min)
pred = model(x) # Range [0, 1]
loss = F.mse_loss(pred, target_normalized) # ✅ Same scale
# Denormalize for evaluation:
pred_denorm = pred * (target_max - target_min) + target_min
# Option 2: Standardization to mean=0, std=1
target_mean = train_targets.mean()
target_std = train_targets.std()
target_standardized = (target - target_mean) / target_std
pred = model(x) # Should output standardized values
loss = F.mse_loss(pred, target_standardized) # ✅ Normalized scale
# Denormalize for evaluation:
pred_denorm = pred * target_std + target_mean
# Why normalization matters:
# 1. Loss values in reasonable range (not 1e6)
# 2. Better gradient flow
# 3. Learning rate can be standard (1e-3)
# 4. Faster convergence
Complete Regression Example
class Regressor(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
self.fc1 = nn.Linear(input_dim, 128)
self.fc2 = nn.Linear(128, output_dim)
def forward(self, x):
x = F.relu(self.fc1(x))
return self.fc2(x) # Linear output for regression
# Normalize targets
target_mean = train_targets.mean(dim=0)
target_std = train_targets.std(dim=0)
def normalize(targets):
return (targets - target_mean) / (target_std + 1e-8)
def denormalize(preds):
return preds * target_std + target_mean
# Training
model = Regressor(input_dim=100, output_dim=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model.train()
for x, y in train_loader:
optimizer.zero_grad()
pred = model(x)
y_norm = normalize(y)
loss = F.smooth_l1_loss(pred, y_norm) # Using Huber loss
loss.backward()
optimizer.step()
# Evaluation
model.eval()
with torch.no_grad():
pred_norm = model(x_test)
pred = denormalize(pred_norm) # Back to original scale
mse = F.mse_loss(pred, y_test)
print(f"Test MSE: {mse.item()}")
Section 5: Numerical Stability in Loss Computation
Critical Rule: Avoid log(0), log(negative), and division by zero
Problem 1: Division by Zero
# ❌ UNSTABLE: No protection
def iou_loss(pred, target):
intersection = (pred * target).sum()
union = pred.sum() + target.sum()
iou = intersection / union # ❌ Division by zero if both empty!
return 1 - iou
# ✅ STABLE: Add epsilon
def iou_loss(pred, target):
eps = 1e-8
intersection = (pred * target).sum()
union = pred.sum() + target.sum()
iou = (intersection + eps) / (union + eps) # ✅ Safe
return 1 - iou
# Why epsilon works:
# - Denominator never zero: union + 1e-8 ≥ 1e-8
# - Doesn't affect result when union is large
# - Prevents NaN propagation
Problem 2: Log of Zero or Negative
# ❌ UNSTABLE: No clamping
def custom_loss(pred, target):
ratio = pred / target
return torch.log(ratio).mean() # ❌ log(0) = -inf, log(neg) = nan
# ✅ STABLE: Clamp before log
def custom_loss(pred, target):
eps = 1e-8
ratio = pred / (target + eps) # Safe division
ratio = torch.clamp(ratio, min=eps) # Ensure positive
return torch.log(ratio).mean() # ✅ Safe log
# Alternative: Use log1p for log(1+x)
def custom_loss(pred, target):
eps = 1e-8
ratio = pred / (target + eps)
return torch.log1p(ratio).mean() # log1p(x) = log(1+x), more stable
Problem 3: Exponential Overflow
# ❌ UNSTABLE: Direct exp can overflow
def custom_loss(logits):
return torch.exp(logits).mean() # ❌ exp(100) = overflow!
# ✅ STABLE: Clamp logits or use stable operations
def custom_loss(logits):
# Option 1: Clamp logits
logits = torch.clamp(logits, max=10) # Prevent overflow
return torch.exp(logits).mean()
# Option 2: Use log-space operations
# If computing log(exp(x)), just return x!
Problem 4: Custom Softmax (Use Built-in Instead)
# ❌ UNSTABLE: Manual softmax
def manual_softmax(logits):
exp_logits = torch.exp(logits) # ❌ Overflow for large logits!
return exp_logits / exp_logits.sum(dim=1, keepdim=True)
# ✅ STABLE: Use F.softmax (uses max subtraction trick)
def stable_softmax(logits):
return F.softmax(logits, dim=1) # ✅ Handles overflow internally
# Built-in implementation (for understanding):
def softmax_stable(logits):
# Subtract max for numerical stability
logits_max = logits.max(dim=1, keepdim=True)[0]
logits = logits - logits_max # Now max(logits) = 0
exp_logits = torch.exp(logits) # No overflow!
return exp_logits / exp_logits.sum(dim=1, keepdim=True)
Epsilon Best Practices
# Epsilon guidelines:
eps = 1e-8 # ✅ Good default for float32
eps = 1e-6 # ✅ Alternative, more conservative
eps = 1e-10 # ❌ Too small, can still underflow
# Where to add epsilon:
# 1. Denominators: x / (y + eps)
# 2. Before log: log(x + eps) or log(clamp(x, min=eps))
# 3. Before sqrt: sqrt(x + eps)
# Where NOT to add epsilon:
# 4. ❌ Numerators usually don't need it
# 5. ❌ Inside standard PyTorch functions (already stable)
# 6. ❌ After stable operations
Complete Stable Custom Loss Template
class StableCustomLoss(nn.Module):
def __init__(self):
super().__init__()
self.eps = 1e-8
def forward(self, pred, target):
# 1. Ensure inputs are valid
assert not torch.isnan(pred).any(), "pred contains NaN"
assert not torch.isnan(target).any(), "target contains NaN"
# 2. Safe division
ratio = pred / (target + self.eps)
# 3. Clamp before log/sqrt/pow
ratio = torch.clamp(ratio, min=self.eps, max=1e8)
# 4. Safe log operation
log_ratio = torch.log(ratio)
# 5. Check output
loss = log_ratio.mean()
assert not torch.isnan(loss), "loss is NaN"
return loss
Section 6: Multi-Task Learning and Loss Weighting
The Problem: Different Loss Scales
# Task 1: Classification, CrossEntropyLoss ~ 0.5-2.0
# Task 2: Regression, MSELoss ~ 100-1000
# Task 3: Reconstruction, L2 Loss ~ 10-50
# ❌ WRONG: Naive sum (task 2 dominates!)
loss1 = F.cross_entropy(logits1, target1) # ~0.5
loss2 = F.mse_loss(pred2, target2) # ~500.0
loss3 = F.mse_loss(recon, input) # ~20.0
total_loss = loss1 + loss2 + loss3 # ≈ 520.5
# Gradient analysis:
# ∂total_loss/∂θ ≈ ∂loss2/∂θ (loss1 and loss3 contribute <5%)
# Model learns ONLY task 2, ignores tasks 1 and 3!
Solution 1: Manual Weighting
# Balance losses to similar magnitudes
loss1 = F.cross_entropy(logits1, target1) # ~0.5
loss2 = F.mse_loss(pred2, target2) # ~500.0
loss3 = F.mse_loss(recon, input) # ~20.0
# Set weights so weighted losses are similar scale
w1 = 1.0 # Keep as is
w2 = 0.001 # Scale down by 1000x
w3 = 0.05 # Scale down by 20x
total_loss = w1 * loss1 + w2 * loss2 + w3 * loss3
# = 1.0*0.5 + 0.001*500 + 0.05*20
# = 0.5 + 0.5 + 1.0 = 2.0
# All tasks contribute meaningfully!
# How to find weights:
# 1. Run 1 epoch with equal weights
# 2. Print loss magnitudes
# 3. Set weights inversely proportional to magnitudes
# 4. Iterate until balanced
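A minimal sketch of the four-step recipe above; balance_weights is an illustrative helper, not a library function. Feed it the per-task mean losses observed with equal weights and it returns weights that bring every term to roughly the same scale.
def balance_weights(avg_losses, target_scale=1.0):
    """avg_losses: per-task mean losses observed after one epoch with equal weights."""
    return [target_scale / (l + 1e-8) for l in avg_losses]

w1, w2, w3 = balance_weights([0.5, 500.0, 20.0])
print(w1, w2, w3)  # ≈ 2.0, 0.002, 0.05 — each weighted term then lands near 1.0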
Solution 2: Uncertainty Weighting (Learnable)
# "Multi-Task Learning Using Uncertainty to Weigh Losses" (Kendall et al., 2018)
# Learn task weights during training!
class MultiTaskLoss(nn.Module):
def __init__(self, num_tasks):
super().__init__()
# Log variance parameters (learnable)
self.log_vars = nn.Parameter(torch.zeros(num_tasks))
def forward(self, losses):
"""
losses: list of task losses [loss1, loss2, loss3, ...]
For each task:
        weighted_loss = (1 / σ²) * loss + log(σ²)
        (the code below uses this common simplification; it is exactly twice the
        paper's (1 / (2σ²)) * loss + log σ term, so the minima are unchanged)
        Where σ² = exp(log_var) is the learned uncertainty
- High uncertainty → lower weight on that task
- Low uncertainty → higher weight on that task
"""
weighted_losses = []
for i, loss in enumerate(losses):
precision = torch.exp(-self.log_vars[i]) # 1/σ²
weighted_loss = precision * loss + self.log_vars[i]
weighted_losses.append(weighted_loss)
return sum(weighted_losses)
# Usage
model = MultiTaskModel()
multi_loss = MultiTaskLoss(num_tasks=3)
# Optimize both model and loss weights
optimizer = torch.optim.Adam([
{'params': model.parameters()},
{'params': multi_loss.parameters(), 'lr': 0.01} # Can use different LR
])
# Training loop
for x, targets in train_loader:
optimizer.zero_grad()
# Compute task predictions
out1, out2, out3 = model(x)
# Compute task losses
loss1 = F.cross_entropy(out1, targets[0])
loss2 = F.mse_loss(out2, targets[1])
loss3 = F.mse_loss(out3, targets[2])
# Combine with learned weighting
total_loss = multi_loss([loss1, loss2, loss3])
total_loss.backward()
optimizer.step()
# Monitor learned weights
if step % 100 == 0:
weights = torch.exp(-multi_loss.log_vars)
print(f"Task weights: {weights.detach()}")
Solution 3: Gradient Normalization
# GradNorm: balances task learning by normalizing gradient magnitudes
def grad_norm_step(model, losses, alpha=1.5):
"""
Adjust task weights to balance gradient magnitudes
losses: list of task losses
alpha: balancing parameter (1.5 typical)
"""
# Get initial loss ratios
initial_losses = [l.item() for l in losses]
# Compute average gradient norm per task
shared_params = list(model.shared_layers.parameters())
grad_norms = []
for loss in losses:
model.zero_grad()
loss.backward(retain_graph=True)
# Compute gradient norm
grad_norm = 0
for p in shared_params:
if p.grad is not None:
grad_norm += p.grad.norm(2).item() ** 2
grad_norms.append(grad_norm ** 0.5)
# Target: all tasks have same gradient norm
mean_grad_norm = sum(grad_norms) / len(grad_norms)
# Adjust weights
weights = []
for gn in grad_norms:
weight = mean_grad_norm / (gn + 1e-8)
weights.append(weight ** alpha)
# Normalize weights
weights = torch.tensor(weights)
weights = weights / weights.sum() * len(weights)
return weights
# Note: GradNorm is more complex, this is simplified version
# For production, use manual or uncertainty weighting
Solution 4: Loss Normalization
# Scale each loss by a running estimate of its magnitude before combining
class NormalizedMultiTaskLoss(nn.Module):
    def __init__(self, num_tasks):
        super().__init__()
        # Track a running mean of each task's loss magnitude
        self.register_buffer('running_mean', torch.ones(num_tasks))
        self.momentum = 0.9

    def forward(self, losses):
        """Divide each loss by its running magnitude before combining"""
        losses_tensor = torch.stack(losses)
        if self.training:
            # Update per-task statistics (detach so the buffer carries no grad history)
            self.running_mean = (self.momentum * self.running_mean +
                                 (1 - self.momentum) * losses_tensor.detach())
        # Each task now contributes a term of roughly unit scale
        normalized = losses_tensor / (self.running_mean + 1e-8)
        return normalized.sum()
Best Practices for Multi-Task Loss
# Recommended approach:
1. Start with manual weighting:
- Run 1 epoch, check loss magnitudes
- Set weights to balance scales
- Quick and interpretable
2. If tasks have different difficulties:
- Use uncertainty weighting
- Let model learn task importance
- More training time but adaptive
3. Monitor individual task metrics:
- Don't just watch total loss
- Track accuracy/error per task
- Ensure all tasks learning
4. Curriculum learning:
- Start with easy tasks
- Gradually add harder tasks
- Can improve stability
# Example monitoring:
if step % 100 == 0:
print(f"Total Loss: {total_loss.item():.4f}")
print(f"Task 1 (CE): {loss1.item():.4f}")
print(f"Task 2 (MSE): {loss2.item():.4f}")
print(f"Task 3 (Recon): {loss3.item():.4f}")
# Check if any task stuck
if loss1 > 5.0: # Not learning
print("WARNING: Task 1 not learning, increase weight")
Section 7: Custom Loss Function Implementation
Template for Custom Loss
class CustomLoss(nn.Module):
"""
Template for implementing custom losses
"""
def __init__(self, weight=None, reduction='mean'):
"""
Args:
weight: Manual sample weights (optional)
reduction: 'mean', 'sum', or 'none'
"""
super().__init__()
self.weight = weight
self.reduction = reduction
self.eps = 1e-8 # For numerical stability
def forward(self, pred, target):
"""
Args:
pred: Model predictions
target: Ground truth
Returns:
Loss value (scalar if reduction != 'none')
"""
# 1. Input validation
assert pred.shape == target.shape, "Shape mismatch"
assert not torch.isnan(pred).any(), "pred contains NaN"
# 2. Compute element-wise loss
loss = self.compute_loss(pred, target)
# 3. Apply sample weights if provided
if self.weight is not None:
loss = loss * self.weight
# 4. Apply reduction
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else: # 'none'
return loss
def compute_loss(self, pred, target):
"""Override this method with your loss computation"""
# Example: MSE
return (pred - target) ** 2
Example 1: Dice Loss (Segmentation)
class DiceLoss(nn.Module):
"""
Dice Loss for segmentation tasks
Dice = 2 * |X ∩ Y| / (|X| + |Y|)
Loss = 1 - Dice
Good for:
- Binary segmentation
- Handling class imbalance
- Smooth gradients
"""
def __init__(self, smooth=1.0):
super().__init__()
self.smooth = smooth # Prevent division by zero
def forward(self, pred, target):
"""
Args:
pred: (batch, C, H, W) logits
target: (batch, C, H, W) binary masks
"""
# Apply sigmoid to get probabilities
pred = torch.sigmoid(pred)
# Flatten spatial dimensions
pred = pred.view(pred.size(0), pred.size(1), -1) # (batch, C, H*W)
target = target.view(target.size(0), target.size(1), -1)
# Compute dice per sample and per class
intersection = (pred * target).sum(dim=2) # (batch, C)
union = pred.sum(dim=2) + target.sum(dim=2) # (batch, C)
dice = (2 * intersection + self.smooth) / (union + self.smooth)
# Average over classes and batch
return 1 - dice.mean()
# Usage
criterion = DiceLoss(smooth=1.0)
loss = criterion(logits, masks)
# Often combined with BCE:
dice_loss = DiceLoss()
bce_loss = nn.BCEWithLogitsLoss()
total_loss = 0.5 * dice_loss(logits, masks) + 0.5 * bce_loss(logits, masks)
Example 2: Focal Loss (Imbalanced Classification)
class FocalLoss(nn.Module):
"""
Focal Loss for addressing class imbalance
FL = -α * (1 - p)^γ * log(p)
- α: class balancing weight
- γ: focusing parameter (typical: 2.0)
- (1-p)^γ: down-weights easy examples
Good for:
- Highly imbalanced datasets (e.g., object detection)
- Many easy negatives, few hard positives
- When class weights aren't enough
"""
def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
super().__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, logits, target):
"""
Args:
logits: (batch, num_classes) raw logits
target: (batch,) class indices
"""
# Compute cross entropy
ce_loss = F.cross_entropy(logits, target, reduction='none')
# Compute pt = e^(-CE) = probability of true class
pt = torch.exp(-ce_loss)
# Compute focal loss
focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
# Apply reduction
if self.reduction == 'mean':
return focal_loss.mean()
elif self.reduction == 'sum':
return focal_loss.sum()
else:
return focal_loss
# Usage
criterion = FocalLoss(alpha=1.0, gamma=2.0)
loss = criterion(logits, target)
# Effect of gamma:
# γ = 0: equivalent to CrossEntropyLoss
# γ = 2: typical value, strong down-weighting of easy examples
# γ = 5: extreme focusing, only hardest examples matter
# Example probability and loss weights:
# pt = 0.9 (easy): (1-0.9)^2 = 0.01 → 1% weight
# pt = 0.5 (medium): (1-0.5)^2 = 0.25 → 25% weight
# pt = 0.1 (hard): (1-0.1)^2 = 0.81 → 81% weight
Example 3: Contrastive Loss (Metric Learning)
class ContrastiveLoss(nn.Module):
"""
Contrastive Loss for learning embeddings
Pulls similar pairs together, pushes dissimilar pairs apart
Good for:
- Face recognition
- Similarity learning
- Few-shot learning
"""
def __init__(self, margin=1.0):
super().__init__()
self.margin = margin
def forward(self, embedding1, embedding2, label):
"""
Args:
embedding1: (batch, embedding_dim) first embeddings
embedding2: (batch, embedding_dim) second embeddings
label: (batch,) 1 if similar, 0 if dissimilar
"""
# Euclidean distance
distance = F.pairwise_distance(embedding1, embedding2)
# Loss for similar pairs: want distance = 0
loss_similar = label * distance.pow(2)
# Loss for dissimilar pairs: want distance ≥ margin
loss_dissimilar = (1 - label) * F.relu(self.margin - distance).pow(2)
loss = loss_similar + loss_dissimilar
return loss.mean()
# Usage
criterion = ContrastiveLoss(margin=1.0)
for (img1, img2, is_similar) in train_loader:
emb1 = model(img1)
emb2 = model(img2)
loss = criterion(emb1, emb2, is_similar)
Example 4: Perceptual Loss (Style Transfer, Super-Resolution)
import torchvision

class PerceptualLoss(nn.Module):
"""
Perceptual Loss using VGG features
Compares high-level features instead of pixels
Good for:
- Image generation
- Super-resolution
- Style transfer
"""
def __init__(self, layer='relu3_3'):
super().__init__()
# Load pre-trained VGG
        vgg = torchvision.models.vgg16(pretrained=True).features  # on newer torchvision, pass weights=... instead
self.vgg = vgg.eval()
# Freeze VGG
for param in self.vgg.parameters():
param.requires_grad = False
# Select layer
self.layer_map = {
'relu1_2': 4,
'relu2_2': 9,
'relu3_3': 16,
'relu4_3': 23,
}
self.layer_idx = self.layer_map[layer]
def forward(self, pred, target):
"""
Args:
pred: (batch, 3, H, W) predicted images
target: (batch, 3, H, W) target images
"""
# Extract features
pred_features = self.extract_features(pred)
target_features = self.extract_features(target)
# MSE in feature space
loss = F.mse_loss(pred_features, target_features)
return loss
def extract_features(self, x):
"""Extract features from VGG layer"""
for i, layer in enumerate(self.vgg):
x = layer(x)
if i == self.layer_idx:
return x
return x
# Usage
perceptual_loss = PerceptualLoss(layer='relu3_3')
pixel_loss = nn.L1Loss()
# Combine pixel and perceptual loss
total_loss = pixel_loss(pred, target) + 0.1 * perceptual_loss(pred, target)
Example 5: Custom Weighted MSE
class WeightedMSELoss(nn.Module):
"""
MSE with per-element importance weighting
Good for:
- Focusing on important regions (e.g., foreground)
- Time-series with different importance
- Confidence-weighted regression
"""
def __init__(self):
super().__init__()
def forward(self, pred, target, weight):
"""
Args:
pred: (batch, ...) predictions
target: (batch, ...) targets
weight: (batch, ...) importance weights (0-1)
"""
# Element-wise squared error
squared_error = (pred - target) ** 2
# Weight by importance
weighted_error = squared_error * weight
# Average only over weighted elements
# (avoid counting zero-weight elements)
loss = weighted_error.sum() / (weight.sum() + 1e-8)
return loss
# Usage example: Foreground-focused loss
criterion = WeightedMSELoss()
# Create importance map (1.0 for foreground, 0.1 for background)
weight = torch.where(mask > 0.5, torch.tensor(1.0), torch.tensor(0.1))
loss = criterion(pred, target, weight)
Section 8: Advanced Loss Techniques
Technique 1: Label Smoothing
# Problem: Hard labels [0, 0, 1, 0, 0] cause overconfident predictions
# Solution: Soft labels [0.025, 0.025, 0.9, 0.025, 0.025]
# PyTorch 1.10+ built-in support
loss = F.cross_entropy(logits, target, label_smoothing=0.1)
# What it does:
# Original: y = [0, 0, 1, 0, 0]
# Smoothed: y = (1-α)*[0, 0, 1, 0, 0] + α*[0.2, 0.2, 0.2, 0.2, 0.2]
# = [0.02, 0.02, 0.92, 0.02, 0.02] (for α=0.1, num_classes=5)
# Manual implementation (for understanding):
class LabelSmoothingLoss(nn.Module):
def __init__(self, num_classes, smoothing=0.1):
super().__init__()
self.num_classes = num_classes
self.smoothing = smoothing
self.confidence = 1.0 - smoothing
def forward(self, logits, target):
"""
logits: (batch, num_classes)
target: (batch,) class indices
"""
log_probs = F.log_softmax(logits, dim=1)
# Create smooth labels
smooth_labels = torch.zeros_like(log_probs)
smooth_labels.fill_(self.smoothing / (self.num_classes - 1))
smooth_labels.scatter_(1, target.unsqueeze(1), self.confidence)
# NLL with smooth labels
loss = (-smooth_labels * log_probs).sum(dim=1)
return loss.mean()
# Benefits:
# 1. Better calibration (confidence closer to accuracy)
# 2. Prevents overconfidence
# 3. Acts as regularization
# 4. Often improves test accuracy by 0.5-1%
# When to use:
# ✅ Classification with CrossEntropyLoss
# ✅ Large models prone to overfitting
# ✅ Clean labels (not noisy)
# ❌ Small models (might hurt performance)
# ❌ Noisy labels (already have uncertainty)
Technique 2: Class-Balanced Loss
# Problem: 1000 samples class 0, 10 samples class 1
# Standard CE treats all samples equally → biased to class 0
# Solution 1: Inverse frequency weighting
class_counts = torch.bincount(train_labels)
class_weights = 1.0 / class_counts.float()
class_weights = class_weights / class_weights.sum() * len(class_weights)
loss = F.cross_entropy(logits, target, weight=class_weights)
# Solution 2: Effective number of samples (better for extreme imbalance)
def get_eff_num_weights(num_samples_per_class, beta=0.999):
"""
Effective number of samples: (1 - β^n) / (1 - β)
Handles extreme imbalance better than inverse frequency
Args:
num_samples_per_class: [n1, n2, ..., nC]
beta: Hyperparameter (0.99-0.9999), higher for more imbalance
"""
effective_num = 1.0 - torch.pow(beta, num_samples_per_class)
weights = (1.0 - beta) / effective_num
weights = weights / weights.sum() * len(weights)
return weights
# Usage
class_counts = torch.bincount(train_labels)
weights = get_eff_num_weights(class_counts.float(), beta=0.9999)
loss = F.cross_entropy(logits, target, weight=weights)
# Solution 3: Focal loss (see Example 2 in Section 7)
Technique 3: Mixup / CutMix Loss
# Mixup: Blend two samples and their labels
import numpy as np

def mixup_data(x, y, alpha=1.0):
"""
Args:
x: (batch, ...) input
y: (batch,) labels
alpha: Mixup parameter
"""
lam = np.random.beta(alpha, alpha)
batch_size = x.size(0)
index = torch.randperm(batch_size)
mixed_x = lam * x + (1 - lam) * x[index]
y_a, y_b = y, y[index]
return mixed_x, y_a, y_b, lam
def mixup_criterion(pred, y_a, y_b, lam):
"""Compute mixed loss"""
return lam * F.cross_entropy(pred, y_a) + (1 - lam) * F.cross_entropy(pred, y_b)
# Training with Mixup
for x, y in train_loader:
x, y_a, y_b, lam = mixup_data(x, y, alpha=1.0)
optimizer.zero_grad()
pred = model(x)
loss = mixup_criterion(pred, y_a, y_b, lam)
loss.backward()
optimizer.step()
# Benefits:
# - Regularization
# - Better generalization
# - Smooth decision boundaries
# - +1-2% accuracy on CIFAR/ImageNet
Technique 4: Gradient Clipping for Loss Stability
# Problem: Loss spikes to NaN during training
# Often caused by exploding gradients
# Solution: Clip gradients before optimizer step
for x, y in train_loader:
optimizer.zero_grad()
pred = model(x)
loss = criterion(pred, y)
loss.backward()
# Clip gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Or clip by value:
# torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
optimizer.step()
# When to use:
# ✅ RNNs/LSTMs (prone to exploding gradients)
# ✅ Transformers with high learning rates
# ✅ Loss occasionally spikes to NaN
# ✅ Large models or deep networks
# ❌ Stable training (unnecessary overhead)
# How to choose max_norm:
# - Start with 1.0
# - If still unstable, reduce to 0.5
# - Monitor: print gradient norms to see if clipping activates
Technique 5: Loss Scaling for Mixed Precision
# Problem: Mixed precision (FP16) can cause gradients to underflow
# Solution: Scale loss up, then scale gradients down
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for x, y in train_loader:
optimizer.zero_grad()
# Forward in FP16
with autocast():
pred = model(x)
loss = criterion(pred, y)
# Scale loss and backward
scaler.scale(loss).backward()
# Unscale gradients and step
scaler.step(optimizer)
scaler.update()
# GradScaler automatically:
# 1. Scales loss by factor (e.g., 65536)
# 2. Backprop computes scaled gradients
# 3. Unscales gradients before optimizer step
# 4. Adjusts scale factor dynamically
Section 9: Common Loss Function Pitfalls
Pitfall 1: BCE Instead of BCEWithLogitsLoss
# ❌ WRONG (one of the most common mistakes in beginner code)
probs = torch.sigmoid(logits)
loss = F.binary_cross_entropy(probs, target)
# ✅ RIGHT
loss = F.binary_cross_entropy_with_logits(logits, target)
# Impact: Training instability, NaN losses, worse performance
# Fix time: 2 minutes
# Performance gain: Stable training, +2-5% accuracy
Pitfall 2: Softmax Before CrossEntropyLoss
# ❌ WRONG (another very common beginner mistake)
probs = F.softmax(logits, dim=1)
loss = F.cross_entropy(probs, target)
# ✅ RIGHT
loss = F.cross_entropy(logits, target) # Expects logits!
# Impact: Suboptimal learning, double softmax
# Fix time: 1 minute
# Performance gain: +1-3% accuracy
Pitfall 3: Wrong Target Shape for CrossEntropyLoss
# ❌ WRONG: One-hot encoded targets
target = F.one_hot(labels, num_classes=10) # (batch, 10)
loss = F.cross_entropy(logits, target) # Type error!
# ✅ RIGHT: Class indices
target = labels # (batch,) with values in [0, 9]
loss = F.cross_entropy(logits, target)
# Impact: Runtime error or wrong loss computation
# Fix time: 2 minutes
Pitfall 4: Ignoring Class Imbalance
# ❌ WRONG: 95% negative, 5% positive
loss = F.binary_cross_entropy_with_logits(logits, target)
# Model predicts all negative → 95% accuracy but useless!
# ✅ RIGHT: Weight positive class
pos_weight = torch.tensor([19.0]) # 95/5
loss = F.binary_cross_entropy_with_logits(logits, target, pos_weight=pos_weight)
# Impact: Model learns trivial predictor
# Fix time: 5 minutes
# Performance gain: From useless to actually working
Pitfall 5: Not Normalizing Regression Targets
# ❌ WRONG: Targets in [1000, 10000], predictions in [0, 1]
loss = F.mse_loss(pred, target) # Huge loss, bad gradients
# ✅ RIGHT: Normalize targets
target_norm = (target - target.mean()) / target.std()
loss = F.mse_loss(pred, target_norm)
# Impact: Slow convergence, high loss values, need very small LR
# Fix time: 5 minutes
# Performance gain: 10-100x faster convergence
Pitfall 6: Unweighted Multi-Task Loss
# ❌ WRONG: Different scales
loss1 = F.cross_entropy(out1, target1) # ~0.5
loss2 = F.mse_loss(out2, target2) # ~500.0
total = loss1 + loss2 # Task 2 dominates!
# ✅ RIGHT: Balance scales
total = 1.0 * loss1 + 0.001 * loss2 # Both ~0.5
# Impact: One task learns, others ignored
# Fix time: 10 minutes (trial and error)
# Performance gain: All tasks learn instead of one
Pitfall 7: Division by Zero in Custom Loss
# ❌ WRONG: No epsilon
iou = intersection / union # Division by zero!
# ✅ RIGHT: Add epsilon
eps = 1e-8
iou = (intersection + eps) / (union + eps)
# Impact: NaN losses, training crash
# Fix time: 2 minutes
Pitfall 8: Missing optimizer.zero_grad()
# ❌ WRONG: Gradients accumulate!
for x, y in train_loader:
loss = criterion(model(x), y)
loss.backward()
optimizer.step() # Missing zero_grad!
# ✅ RIGHT: Reset gradients
for x, y in train_loader:
optimizer.zero_grad() # ✅ Critical!
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
# Impact: Loss doesn't decrease, weird behavior
# Fix time: 1 minute
# This is caught by systematic debugging
Pitfall 9: Wrong Reduction for Custom Loss
# ❌ SUBOPTIMAL: Sum over batch
loss = (pred - target).pow(2).sum() # Loss scales with batch size!
# ✅ BETTER: Mean over batch
loss = (pred - target).pow(2).mean() # Loss independent of batch size
# Impact: Learning rate depends on batch size
# Fix time: 2 minutes
# When to use sum vs mean:
# - mean: Default, loss independent of batch size
# - sum: When you want loss to scale with batch size (rare)
# - none: When you want per-sample losses (for weighting)
Pitfall 10: Using Accuracy for Imbalanced Data
# ❌ WRONG: 95-5 imbalance
accuracy = (pred == target).float().mean() # 95% for trivial predictor!
# ✅ RIGHT: Use F1, precision, recall
from sklearn.metrics import f1_score, precision_score, recall_score
f1 = f1_score(target, pred) # Balanced metric
precision = precision_score(target, pred)
recall = recall_score(target, pred)
# Or use balanced accuracy (mean of per-class recall):
from sklearn.metrics import balanced_accuracy_score
balanced_acc = balanced_accuracy_score(target, pred)
# Impact: Misinterpreting model performance
# Fix time: 5 minutes
Section 10: Loss Debugging Methodology
When Loss is NaN
# Step 1: Check inputs for NaN
print(f"Input has NaN: {torch.isnan(x).any()}")
print(f"Target has NaN: {torch.isnan(target).any()}")
if torch.isnan(x).any():
# Data loading issue
print("Fix: Check data preprocessing")
# Step 2: Check for numerical instability in loss
# - Division by zero
# - Log of zero or negative
# - Exp overflow
# Step 3: Check gradients before NaN
for name, param in model.named_parameters():
if param.grad is not None:
grad_norm = param.grad.norm()
print(f"{name}: {grad_norm.item()}")
if grad_norm > 1000:
print(f"Exploding gradient in {name}")
# Step 4: Add gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Step 5: Lower learning rate
# LR too high can cause NaN
# Step 6: Check loss computation
# Add assertions in custom loss:
def custom_loss(pred, target):
loss = compute_loss(pred, target)
assert not torch.isnan(loss), f"Loss is NaN, pred range: [{pred.min()}, {pred.max()}]"
assert not torch.isinf(loss), f"Loss is inf"
return loss
When Loss Not Decreasing
# Systematic debugging checklist:
# 1. Check loss value
print(f"Loss: {loss.item()}")
# - Is it reasonable? (CE should be ~ln(num_classes) initially)
# - Is it constant? (optimizer not stepping)
# - Is it very high? (wrong scale)
# 2. Check gradients
for name, param in model.named_parameters():
if param.grad is not None:
print(f"{name} grad: mean={param.grad.abs().mean():.6f}, max={param.grad.abs().max():.6f}")
# If all gradients ~ 0:
# → Vanishing gradients (check activation functions, initialization)
# If gradients very large (>10):
# → Exploding gradients (add gradient clipping, lower LR)
# If no gradients printed:
# → Missing loss.backward() or parameters not requiring grad
# 3. Check predictions
print(f"Pred range: [{pred.min():.4f}, {pred.max():.4f}]")
print(f"Target range: [{target.min():.4f}, {target.max():.4f}]")
print(f"Pred mean: {pred.mean():.4f}, Target mean: {target.mean():.4f}")
# If predictions are constant:
# → Model not learning (check optimizer.step(), zero_grad())
# If predictions are random:
# → Model learning but task too hard or wrong loss
# If pred/target ranges very different:
# → Normalization issue
# 4. Verify training setup
print(f"Model training mode: {model.training}") # Should be True
print(f"Requires grad: {next(model.parameters()).requires_grad}") # Should be True
# Check optimizer.zero_grad() is called
# Check loss.backward() is called
# Check optimizer.step() is called
# 5. Check learning rate
print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
# Too low (< 1e-6): Won't learn
# Too high (> 1e-2): Unstable
# 6. Verify loss function matches task
# Classification → CrossEntropyLoss
# Regression → MSELoss or SmoothL1Loss
# Binary classification → BCEWithLogitsLoss
# 7. Check data
# Visualize a batch:
print(f"Batch input shape: {x.shape}")
print(f"Batch target shape: {target.shape}")
print(f"Target unique values: {target.unique()}")
# Are labels correct?
# Is data normalized?
# Any NaN in data?
# 8. Overfit single batch
# Can model fit one batch perfectly?
single_x, single_y = next(iter(train_loader))
for i in range(1000):
optimizer.zero_grad()
pred = model(single_x)
loss = criterion(pred, single_y)
loss.backward()
optimizer.step()
if i % 100 == 0:
print(f"Step {i}: Loss = {loss.item():.4f}")
# If can't overfit single batch:
# → Model architecture issue
# → Loss function wrong
# → Bug in training loop
When Loss Stuck at Same Value
# Scenario: Loss stays at 0.693 for binary classification (ln(2))
# Diagnosis: Model predicting 0.5 probability for all samples
# Possible causes:
# 1. Learning rate too low
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) # Try 1e-3, 1e-4
# 2. Dead neurons (all ReLU outputs are 0)
# Check activations:
activations = model.fc1(x)
print(f"Activations: {activations.abs().mean()}")
if activations.abs().mean() < 0.01:
print("Dead neurons! Try:")
print("- Different initialization")
print("- LeakyReLU instead of ReLU")
print("- Lower learning rate")
# 3. Gradient flow blocked
# Check each layer:
for name, param in model.named_parameters():
if param.grad is not None:
print(f"{name}: {param.grad.abs().mean():.6f}")
else:
print(f"{name}: NO GRADIENT!")
# 4. Wrong optimizer state (if resuming training)
# Solution: Create fresh optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# 5. Model too simple for task
# Try: Larger model, more layers, more parameters
# 6. Task is actually random
# Check: Can humans solve this task?
# Check: Is there signal in the data?
When Loss Oscillating / Unstable
# Scenario: Loss jumps around: 0.5 → 2.0 → 0.3 → 5.0 → ...
# Diagnosis: Unstable training
# Possible causes:
# 1. Learning rate too high
# Solution: Lower LR by 10x
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) # Down from 1e-3
# 2. Batch size too small
# Solution: Increase batch size (more stable gradients)
train_loader = DataLoader(dataset, batch_size=64) # Up from 32
# 3. No gradient clipping
# Solution: Clip gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# 4. Numerical instability in loss
# Solution: Use stable loss functions
# BCEWithLogitsLoss instead of BCE
# Add epsilon to custom losses
# 5. Data outliers
# Solution:
# - Remove outliers
# - Use robust loss (L1, SmoothL1, Huber)
# - Clip targets to reasonable range
# 6. Exploding gradients
# Check and clip:
total_norm = 0
for p in model.parameters():
if p.grad is not None:
total_norm += p.grad.norm().item() ** 2
total_norm = total_norm ** 0.5
print(f"Gradient norm: {total_norm}")
if total_norm > 10:
print("Exploding gradients! Add gradient clipping.")
Rationalization Prevention Table
| Rationalization | Why It's Wrong | What You Must Do |
|---|---|---|
| "BCE is simpler than BCEWithLogitsLoss" | BCE is numerically unstable, causes NaN | ALWAYS use BCEWithLogitsLoss. Non-negotiable. |
| "Loss weighting is just extra hyperparameter tuning" | Unweighted multi-task losses fail completely | Check loss scales, weight them. One task will dominate otherwise. |
| "The optimizer will figure out the scale differences" | Optimizers don't balance losses, they follow gradients | Manual balance required. SGD sees gradient magnitude, not task importance. |
| "95% accuracy is great!" | With 95-5 imbalance, this is trivial predictor | Check F1/precision/recall. Accuracy misleading for imbalanced data. |
| "Data is clean, no need for epsilon" | Even clean data can hit edge cases (empty masks, zeros) | Add epsilon anyway. Cost is negligible, prevents NaN. |
| "Softmax before CE makes output clearer" | CE applies softmax internally, this causes double softmax | Pass logits to CE. Never apply softmax first. |
| "One-hot encoding is more standard" | CrossEntropyLoss expects class indices, not one-hot | Use class indices. Shape must be (batch,) not (batch, C). |
| "Reduction parameter is optional" | Controls how loss aggregates, affects training dynamics | Understand and choose: mean (default), sum (rare), none (per-sample). |
| "Just lower LR to fix NaN" | NaN usually from numerical instability, not LR | Fix root cause first: epsilon, clipping, stable loss. Then adjust LR. |
| "Papers use different loss, I should too" | Papers don't always use optimal losses, context matters | Evaluate if appropriate for your data/task. Don't blindly copy. |
| "Custom loss is more flexible" | Built-in losses are optimized and tested | Use built-ins when possible. Only custom when necessary. |
| "Loss function doesn't matter much" | Loss is THE OBJECTIVE your model optimizes | Loss choice is critical. Wrong loss = optimizing wrong thing. |
| "I'll tune loss later" | Loss should match task from the start | Choose correct loss immediately. Tuning won't fix fundamentally wrong loss. |
| "Focal loss is always better for imbalance" | Focal loss has hyperparameters, can hurt if tuned wrong | Try class weights first (simpler, fewer hyperparameters). |
| "Division by zero won't happen in practice" | Edge cases happen: empty batches, all-zero masks | Defensive programming: always add epsilon to denominators. |
Red Flags Checklist
When reviewing loss function code, watch for these RED FLAGS:
Critical (Fix Immediately):
- Using F.binary_cross_entropy instead of F.binary_cross_entropy_with_logits
- Applying sigmoid or softmax before a stable loss (BCEWithLogitsLoss, CrossEntropyLoss)
- Division without epsilon: x / y instead of x / (y + 1e-8)
- Log without clamping: torch.log(x) instead of torch.log(torch.clamp(x, min=1e-8))
- Missing optimizer.zero_grad() in the training loop
- Multi-task losses added without weighting (different scales)
- Loss goes to NaN during training
Important (Fix Soon):
- Class imbalance ignored (no weight or pos_weight parameter)
- Regression targets not normalized (huge loss values)
- Wrong target shape for CrossEntropyLoss (one-hot instead of indices)
- Custom loss without numerical stability checks
- Using accuracy metric for highly imbalanced data
- No gradient clipping for RNNs/Transformers
- Reduction not specified in custom loss
Best Practices (Improve):
- No label smoothing for classification (consider adding)
- No focal loss for extreme imbalance (>100:1 ratio)
- Not monitoring individual task losses in multi-task learning
- Not checking gradient norms during training
- No assertions in custom loss for debugging
- Not testing loss function on toy data first
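For the last item above, a minimal sketch of what "test the loss on toy data first" can look like: feed the criterion inputs whose expected loss you can reason about and assert the result.
import math
import torch
import torch.nn.functional as F

num_classes = 10
logits = torch.zeros(64, num_classes)              # uniform predictions
target = torch.randint(0, num_classes, (64,))

loss = F.cross_entropy(logits, target)
assert abs(loss.item() - math.log(num_classes)) < 1e-4       # uniform logits -> ln(C)

confident = F.one_hot(target, num_classes).float() * 100.0   # confident, correct logits
assert F.cross_entropy(confident, target).item() < 1e-3      # near zero when predictions are right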
Summary: Loss Function Selection Flowchart
START
|
├─ Binary Classification?
| → BCEWithLogitsLoss + pos_weight for imbalance
|
├─ Multi-Class Classification?
| → CrossEntropyLoss + class weights for imbalance
| → Consider Focal Loss if extreme imbalance (>100:1)
|
├─ Multi-Label Classification?
| → BCEWithLogitsLoss + per-class pos_weight
|
├─ Regression?
| → SmoothL1Loss (good default)
| → MSELoss if no outliers
| → L1Loss if robust to outliers needed
| → ALWAYS normalize targets!
|
├─ Segmentation?
| → BCEWithLogitsLoss + DiceLoss (combine both)
| → Consider Focal Loss for imbalanced pixels
|
├─ Ranking/Similarity?
| → TripletMarginLoss or ContrastiveLoss
|
└─ Multi-Task?
→ Combine with careful weighting
→ Start with manual balance (check scales!)
→ Consider uncertainty weighting
ALWAYS:
✅ Use logits (no sigmoid/softmax before stable losses)
✅ Add epsilon to divisions and before log/sqrt
✅ Check for class/label imbalance
✅ Normalize regression targets
✅ Monitor loss values and gradients
✅ Test loss on toy data first
NEVER:
❌ Use BCE instead of BCEWithLogitsLoss
❌ Apply softmax before CrossEntropyLoss
❌ Ignore different scales in multi-task
❌ Divide without epsilon
❌ Trust accuracy alone for imbalanced data
Final Checklist Before Training
Before starting training, verify:
Loss Function Matches Task:
- Binary classification → BCEWithLogitsLoss
- Multi-class → CrossEntropyLoss
- Regression → SmoothL1Loss or MSE with normalized targets
Numerical Stability:
- Using stable loss (BCEWithLogitsLoss, not BCE)
- Epsilon in divisions: x / (y + 1e-8)
- Clamp before log: torch.log(torch.clamp(x, min=1e-8))
Class Imbalance Handled:
- Checked class distribution
- Added weight or pos_weight if imbalanced
- Using F1/precision/recall metrics, not just accuracy
Multi-Task Weighting:
- Checked loss scales (printed first batch)
- Added manual weights or uncertainty weighting
- Monitoring individual task metrics
Target Preparation:
- CrossEntropyLoss: targets are class indices (batch,)
- BCEWithLogitsLoss: targets are 0/1 floats
- Regression: targets normalized to similar scale as predictions
Training Loop:
- optimizer.zero_grad() before backward
- loss.backward() to compute gradients
- optimizer.step() to update parameters
- Gradient clipping if using RNN/Transformer
Debugging Setup:
- Can print loss value: loss.item()
- Can print gradient norms
- Can visualize predictions vs targets
- Have tested overfitting a single batch
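A minimal sketch that automates a few of the checklist items above in one call; preflight_check is an illustrative name, and the checks are intentionally conservative.
import torch

def preflight_check(model, criterion, x, y):
    """One forward/backward pass on a single batch with basic sanity assertions."""
    model.train()
    loss = criterion(model(x), y)
    assert torch.isfinite(loss), f"loss is not finite: {loss.item()}"
    loss.backward()
    grad_norms = [p.grad.norm() for p in model.parameters() if p.grad is not None]
    assert grad_norms, "no gradients reached the model parameters"
    assert all(torch.isfinite(g) for g in grad_norms), "NaN/Inf gradient detected"
    model.zero_grad()
    return loss.item()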
When to Seek Help
If after following this skill you still have loss issues:
Loss is NaN:
- Checked all numerical stability issues?
- Added epsilon everywhere?
- Tried gradient clipping?
- Lowered learning rate? → If still NaN, may need architecture change or data investigation
Loss not decreasing:
- Verified training loop is correct (zero_grad, backward, step)?
- Checked gradients are flowing?
- Tried overfitting single batch?
- Verified loss function matches task? → If still not decreasing, may be model capacity or data issue
Loss decreasing but metrics poor:
- Is loss the right objective for your metric?
- Example: CE minimizes NLL, not accuracy
- Consider a metric-aware loss or post-hoc calibration (see the temperature-scaling sketch after this list)
Multi-task learning not working:
- Tried multiple weighting strategies?
- Monitored individual task losses?
- Ensured all tasks getting gradient signal? → May need task-specific heads or curriculum learning
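Expanding on the post-hoc calibration note earlier in this list: a minimal temperature-scaling sketch, assuming you have held-out validation logits and labels (the function name is illustrative, not a library API).
import torch
import torch.nn.functional as F

def fit_temperature(val_logits, val_targets, iters=200, lr=0.01):
    """val_logits: (N, C) detached logits from a trained model; val_targets: (N,) class indices."""
    log_t = torch.zeros(1, requires_grad=True)   # optimize log(T) so the temperature stays positive
    opt = torch.optim.Adam([log_t], lr=lr)
    for _ in range(iters):
        opt.zero_grad()
        F.cross_entropy(val_logits / log_t.exp(), val_targets).backward()
        opt.step()
    return log_t.exp().item()

# Calibrated probabilities at test time: F.softmax(test_logits / T, dim=1)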
Remember: Loss function is the heart of deep learning. Get this right first before tuning everything else.
Additional Resources
Key Papers:
- Focal Loss: "Focal Loss for Dense Object Detection" (Lin et al., 2017)
- Label Smoothing: "Rethinking the Inception Architecture" (Szegedy et al., 2016)
- Uncertainty Weighting: "Multi-Task Learning Using Uncertainty to Weigh Losses" (Kendall et al., 2018)
- Class-Balanced Loss: "Class-Balanced Loss Based on Effective Number of Samples" (Cui et al., 2019)
PyTorch Documentation:
- Loss Functions: https://pytorch.org/docs/stable/nn.html#loss-functions
- Numerical Stability: Use built-in combined operations (BCEWithLogitsLoss, etc.)
Common Loss Functions Quick Reference:
# Classification
F.binary_cross_entropy_with_logits(logits, target, pos_weight=...)
F.cross_entropy(logits, target, weight=..., label_smoothing=...)
F.nll_loss(log_probs, target) # If you already have log_probs
# Regression
F.mse_loss(pred, target)
F.l1_loss(pred, target)
F.smooth_l1_loss(pred, target, beta=1.0)
F.huber_loss(pred, target, delta=1.0) # PyTorch 1.10+
# Ranking
F.margin_ranking_loss(input1, input2, target, margin=0.0)
F.triplet_margin_loss(anchor, positive, negative, margin=1.0)
F.cosine_embedding_loss(input1, input2, target)
# Other
F.kl_div(log_probs, target_probs) # KL divergence
F.poisson_nll_loss(log_input, target) # Poisson regression
END OF SKILL
When you use this skill, you become an expert in loss function selection and implementation. You will:
- Choose the correct loss for any task
- Ensure numerical stability
- Handle class imbalance appropriately
- Weight multi-task losses correctly
- Debug loss issues systematically
- Avoid all common loss function pitfalls
Remember: The loss function IS your model's objective. Get this right, and everything else follows.