| name | gradient-management |
| description | Gradient management - clipping, accumulation, scaling; prevent exploding/vanishing; mixed precision |
Gradient Management Skill
When to Use This Skill
Use this skill when:
- Loss becomes NaN or Inf during training
- Training is unstable with loss spikes
- User asks about gradient clipping
- User wants larger batch size but has OOM issues
- User mentions "exploding gradients" or "vanishing gradients"
- Gradients are very large (>100) or very small (<1e-8)
- Implementing gradient accumulation
- Using mixed precision (AMP) with gradient clipping
- User asks "why is my training unstable?"
- Training Transformers, RNNs, or very deep networks
- User implements gradient accumulation without loss scaling (RED FLAG)
- User clips gradients after optimizer.step() (RED FLAG)
- User doesn't unscale before clipping with AMP (RED FLAG)
- Reinforcement learning (policy gradients often explode)
- Distributed training with gradient synchronization questions
- User says "just lower learning rate" for NaN loss (may need clipping)
Do NOT use when:
- Training is stable with no gradient issues
- User has architecture questions unrelated to gradients
- User only asks about learning rate (use learning-rate-scheduling skill)
- User asks about data issues (different problem space)
Core Principles
1. The Critical Importance of Gradient Management
Gradients are the foundation of neural network training:
- Backpropagation computes gradients of loss w.r.t. parameters
- Optimizer uses gradients to update parameters
- Gradient magnitude determines update size
- Gradient stability determines training stability
- Wrong gradient handling → training failure (NaN, no convergence)
Common Impact:
- Gradient clipping: Difference between training and NaN loss
- Gradient accumulation: Train with 8x larger effective batch size on same hardware
- Proper diagnosis: 1-2 hours to fix vs days of confusion
- Mixed precision integration: 2x speedup without breaking training
This is NOT optional:
- Every Transformer paper uses gradient clipping
- Gradient accumulation is standard for large models
- Production training code always monitors gradients
- Ignoring gradients → fragile, unreliable training
2. Gradient Flow in Training
Understanding the training loop gradient flow:
# Step 1: Zero gradients from previous iteration
optimizer.zero_grad()
# Step 2: Forward pass (compute loss)
output = model(input)
loss = criterion(output, target)
# Step 3: Backward pass (compute gradients)
# This computes: param.grad = ∂loss/∂param for all parameters
loss.backward()
# Step 4: [OPTIONAL] Modify gradients (clipping, scaling, etc.)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Step 5: Optimizer step (update parameters using gradients)
# This does: param = param - lr * param.grad (simplified)
optimizer.step()
Critical ordering:
- Gradients are computed by backward()
- Gradients can be modified between backward() and step()
- Gradients are consumed by step() to update parameters
- Gradient modifications MUST happen after backward(), before step()
Mental model:
- backward() produces gradients
- Your code can inspect/modify gradients (see the sketch after this list)
- step() consumes gradients to update parameters
- Modifications after step() are useless (gradients already consumed)
- Modifications before backward() are useless (gradients don't exist yet)
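A minimal sketch of that window, using a hypothetical toy model: gradients are printed and rescaled strictly between backward() and step().
import torch
import torch.nn as nn
# Toy model and data (hypothetical names, just to illustrate the ordering)
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
x, y = torch.randn(8, 10), torch.randn(8, 1)
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()                              # gradients now exist in param.grad
for name, param in model.named_parameters():
    print(name, param.grad.norm().item())    # inspect
    param.grad.mul_(0.5)                     # or modify (e.g., halve every gradient)
optimizer.step()                             # consumes the (modified) gradients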
Gradient Clipping
Why Gradient Clipping Matters
The exploding gradients problem:
- Deep networks multiply gradients through chain rule
- Each layer multiplies gradient by weights and activation derivatives
- If these multiplications are >1, gradients grow exponentially
- Large gradients → large parameter updates → training instability
- Extremely large gradients → NaN or Inf loss
Real-world symptoms:
- Loss suddenly jumps to NaN after normal training
- Loss oscillates wildly between iterations
- Training is stable initially, then diverges
- Parameters become NaN or Inf
- Gradient norms >100 or >1000
Why it happens:
- Transformers: Attention mechanism can amplify gradients
- RNNs: Backpropagation through time multiplies gradients across timesteps
- Very deep networks: Many layers multiply gradients
- Poor initialization: Large initial weights amplify gradients
- High learning rates: Amplify already-large gradients
Norm-Based Gradient Clipping (Primary Method)
The standard solution:
# Clip gradients by global norm
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Complete training loop:
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
What it does:
- Computes total gradient norm: total_norm = sqrt(sum(g^2 for g in all gradients))
- If total_norm > max_norm:
  - Scaling factor = max_norm / total_norm
  - All gradients are multiplied by this factor
- Result: Gradient direction preserved, magnitude limited (a manual sketch of this computation follows below)
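For intuition, here is a rough re-implementation of that rule (a sketch only; prefer torch.nn.utils.clip_grad_norm_, which also handles edge cases such as non-finite norms):
import torch
def clip_grad_norm_manual(parameters, max_norm):
    # Compute the global L2 norm over all gradients, then rescale if it exceeds max_norm
    grads = [p.grad for p in parameters if p.grad is not None]
    total_norm = torch.norm(torch.stack([g.detach().norm(2) for g in grads]), 2)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for g in grads:
            g.mul_(clip_coef)
    return total_norm  # norm BEFORE clipping, matching the built-in's return value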
Why this is good:
- Preserves gradient direction (doesn't distort signal)
- Only scales when needed (if total_norm ≤ max_norm, no change)
- Global view (considers all parameters together)
- Mathematically elegant (scales the gradient vector back onto the ball of radius max_norm)
Typical values for max_norm:
# Transformers (BERT, GPT, T5)
max_norm = 1.0 # Most common
max_norm = 5.0 # Sometimes used for very large models
# RNNs/LSTMs
max_norm = 0.5 # More aggressive clipping
max_norm = 1.0 # Also common
# Reinforcement Learning (policy gradients)
max_norm = 0.5 # RL gradients are particularly unstable
# CNNs (ResNets, etc.)
# Usually DON'T clip - residual connections provide stability
# Only clip if you observe instability
# Very deep networks (>100 layers)
max_norm = 1.0 # Helps with stability
When to use norm-based clipping:
- ✅ Training Transformers (almost always needed)
- ✅ Training RNNs/LSTMs (essential for long sequences)
- ✅ Reinforcement learning (policy gradients)
- ✅ Any time you see loss → NaN
- ✅ Loss spikes or wild oscillations
- ✅ Very deep networks (>50 layers)
When NOT to use:
- ❌ Stable CNN training (ResNet on ImageNet)
- ❌ Training is already stable with no issues
- ❌ As a preemptive measure without evidence of need
Value-Based Gradient Clipping (Rare)
Clips each gradient element individually:
# Clip each gradient value to [-clip_value, +clip_value]
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
# What it does:
for param in model.parameters():
if param.grad is not None:
param.grad.clamp_(-clip_value, clip_value)
Difference from norm-based:
- Norm-based: Scales entire gradient vector to limit total magnitude
- Value-based: Clamps each gradient element independently
- Value-based is MORE aggressive (can change gradient direction)
- Value-based treats all parameters equally (ignores scale differences)
When to use value-based clipping:
- Debugging: Identify which specific parameters have large gradients
- Extreme outliers: Some parameters have huge gradients while others are normal
- Legacy code: Some old papers use this
Usually prefer norm-based:
- Norm-based is standard in modern deep learning
- Preserves gradient direction
- Better theoretical properties
- Used in all major Transformer implementations
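A tiny numerical illustration of the difference, using a made-up two-element gradient:
import torch
g = torch.tensor([4.0, 0.5])                  # pretend this is one parameter's gradient
# Value clipping: clamp each element independently -- the direction changes
value_clipped = g.clamp(-1.0, 1.0)            # tensor([1.0, 0.5]), ratio 2:1 instead of 8:1
# Norm clipping: rescale the whole vector -- the direction is preserved
max_norm = 1.0
total_norm = g.norm(2)                        # ≈ 4.03
norm_clipped = g * (max_norm / total_norm)    # ≈ tensor([0.992, 0.124]), still 8:1
print(value_clipped, norm_clipped)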
Complete Clipping Implementation
import torch
import torch.nn as nn
# Model and optimizer
model = TransformerModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# Training loop with gradient clipping
for epoch in range(num_epochs):
for batch in train_loader:
# 1. Zero gradients
optimizer.zero_grad()
# 2. Forward pass
output = model(batch['input'])
loss = criterion(output, batch['target'])
# 3. Backward pass (compute gradients)
loss.backward()
# 4. Clip gradients (CRITICAL: after backward, before step)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# 5. Optimizer step (update parameters)
optimizer.step()
Common mistakes - WRONG ORDER:
# WRONG: Clipping after optimizer.step()
loss.backward()
optimizer.step()
clip_grad_norm_(model.parameters(), 1.0) # ❌ Too late! Already updated.
# WRONG: Clipping before backward()
optimizer.zero_grad()
clip_grad_norm_(model.parameters(), 1.0) # ❌ No gradients exist yet!
loss.backward()
optimizer.step()
# RIGHT: Clipping after backward(), before step()
loss.backward() # Compute gradients
clip_grad_norm_(model.parameters(), 1.0) # Modify gradients
optimizer.step() # Use modified gradients
How to Choose max_norm Value
Start with standard values:
# Default starting point for Transformers
max_norm = 1.0
# If still unstable (loss spikes)
max_norm = 0.5 # More aggressive clipping
# If training seems too constrained (slow convergence)
max_norm = 2.0 # Less aggressive clipping
Systematic tuning:
1. Monitor gradient norms WITHOUT clipping:
# Check typical gradient magnitudes
total_norm = 0.0
for p in model.parameters():
    if p.grad is not None:
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
print(f"Gradient norm: {total_norm:.4f}")
2. Set max_norm based on typical norms:
- If typical norms are 0.5-2.0, set max_norm=2.0 or 3.0
- If typical norms are 5-10, set max_norm=5.0 or 10.0
- Goal: Clip outliers without affecting normal gradients
3. Verify clipping is helping (a compute_grad_norm helper is sketched below):
# Log how often clipping activates
max_norm = 1.0
grad_norm_before = compute_grad_norm(model)
clip_grad_norm_(model.parameters(), max_norm=max_norm)
grad_norm_after = compute_grad_norm(model)
if grad_norm_before > max_norm:
    print(f"Clipped: {grad_norm_before:.4f} -> {grad_norm_after:.4f}")
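The compute_grad_norm helper used above is not a PyTorch built-in; a minimal version could look like this:
def compute_grad_norm(model):
    # Global L2 norm over all parameter gradients -- the same quantity clip_grad_norm_ limits
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += p.grad.norm(2).item() ** 2
    return total ** 0.5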
Signs you need clipping:
- Gradient norms occasionally >10 or >100
- Loss occasionally spikes or becomes NaN
- Training is initially stable then diverges
- Gradient norms grow over time
Signs your max_norm is too low:
- Clipping activates on EVERY iteration
- Training converges very slowly
- Gradient norm is always exactly max_norm (always clipping)
Signs your max_norm is too high:
- Still getting NaN or loss spikes
- Clipping never activates
- Not solving the stability problem
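One way to quantify these signs is to track how often clipping actually fires. A monitoring sketch (assuming the usual model / optimizer / criterion / train_loader from the examples above):
import torch
max_norm = 1.0
clip_events, total_steps = 0, 0
for step, batch in enumerate(train_loader):
    optimizer.zero_grad()
    loss = criterion(model(batch['input']), batch['target'])
    loss.backward()
    # clip_grad_norm_ returns the total norm BEFORE clipping
    grad_norm = float(torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm))
    optimizer.step()
    total_steps += 1
    clip_events += int(grad_norm > max_norm)
    if (step + 1) % 100 == 0:
        # ~100% clip rate suggests max_norm is too low; ~0% with instability suggests it is too high
        print(f"step {step + 1}: clip rate {clip_events / total_steps:.1%}, last grad norm {grad_norm:.3f}")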
Gradient Accumulation
Why Gradient Accumulation Matters
The memory vs batch size problem:
- Larger batch sizes often improve training (more stable gradients)
- Larger batches require more GPU memory
- Memory is limited (GPU VRAM)
- Can't always fit desired batch size in memory
Example scenario:
- Want batch size 256 for stable training
- Only fit batch size 32 in GPU memory
- Can't afford bigger GPU
- Solution: Gradient accumulation
What gradient accumulation does:
- Accumulate gradients over multiple small batches
- Update parameters once with accumulated gradients
- Equivalent to training with one large batch
- Same results, but fits in memory
Real-world impact:
- Train models 4-8x larger batch size on same hardware
- Standard technique in production training
- Used in all large model training (GPT, BERT, etc.)
- Essential for competitive performance on limited hardware
Correct Gradient Accumulation Implementation
The critical implementation:
# Want effective batch size 256, but can only fit 64 in memory
# Solution: Accumulate over 4 steps (256 = 64 * 4)
accumulation_steps = 4
optimizer.zero_grad()
for i, (data, target) in enumerate(train_loader):
# Forward pass
output = model(data)
loss = criterion(output, target)
# Backward pass with CRITICAL loss scaling
# MUST divide loss by accumulation_steps!
(loss / accumulation_steps).backward()
# Update weights every accumulation_steps
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
Why scale loss by accumulation_steps?
# Without scaling:
loss.backward() # Adds gradients: param.grad += ∂loss/∂param
# After 4 accumulation steps:
# param.grad = ∂loss1/∂param + ∂loss2/∂param + ∂loss3/∂param + ∂loss4/∂param
# This is 4x larger than single batch!
# With scaling:
(loss / 4).backward() # Adds: param.grad += (∂loss/∂param) / 4
# After 4 accumulation steps:
# param.grad = (∂loss1/∂param + ∂loss2/∂param + ∂loss3/∂param + ∂loss4/∂param) / 4
# This is the AVERAGE gradient - equivalent to single large batch!
Mathematical equivalence:
- Large batch loss: L = (l1 + l2 + l3 + l4) / 4 (mean over samples)
- Large batch gradient: ∂L/∂param = (∂l1/∂param + ∂l2/∂param + ∂l3/∂param + ∂l4/∂param) / 4
- Accumulated gradient: same result (a quick numerical check is sketched below)
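If you want to convince yourself, the check below compares one batch of 8 against two accumulated micro-batches of 4 on a toy linear model (names are illustrative). The equivalence holds exactly for mean-reduced losses and models without batch-dependent layers such as BatchNorm.
import torch
import torch.nn as nn
torch.manual_seed(0)
model = nn.Linear(4, 1)
criterion = nn.MSELoss()
x, y = torch.randn(8, 4), torch.randn(8, 1)
# (a) One large batch of 8
model.zero_grad()
criterion(model(x), y).backward()
big_batch_grad = model.weight.grad.clone()
# (b) Two accumulated micro-batches of 4, each loss divided by accumulation_steps=2
model.zero_grad()
for xb, yb in zip(x.chunk(2), y.chunk(2)):
    (criterion(model(xb), yb) / 2).backward()
accumulated_grad = model.weight.grad.clone()
print(torch.allclose(big_batch_grad, accumulated_grad, atol=1e-6))  # True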
Common Gradient Accumulation Mistakes
WRONG: Not scaling loss
# ❌ WRONG - Gradients are accumulation_steps times too large!
accumulation_steps = 4
for i, batch in enumerate(train_loader):
loss = criterion(model(batch), target)
loss.backward() # ❌ Not scaled!
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
# Result: Equivalent to learning_rate * accumulation_steps
# Acts like LR is 4x too high → unstable training
WRONG: Scaling gradients instead of loss
# ❌ WRONG - Inefficient and error-prone!
accumulation_steps = 4
for i, batch in enumerate(train_loader):
loss = criterion(model(batch), target)
loss.backward()
# Manually scale gradients
for param in model.parameters():
if param.grad is not None:
param.grad /= accumulation_steps # ❌ Inefficient!
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
# Why wrong:
# - More code, more error-prone
# - Less efficient (iterates all parameters)
# - Easy to forget or do incorrectly
# - Scaling loss is cleaner and standard
WRONG: Forgetting to zero_grad() after update
# ❌ WRONG - Gradients keep accumulating forever!
accumulation_steps = 4
for i, batch in enumerate(train_loader):
loss = criterion(model(batch), target)
(loss / accumulation_steps).backward()
if (i + 1) % accumulation_steps == 0:
optimizer.step()
# ❌ Missing optimizer.zero_grad()!
# Next accumulation will add to these gradients!
# Result: Gradients never reset, accumulate across updates
# Acts like accumulation_steps grows over time
WRONG: Zeroing gradients inside accumulation loop
# ❌ WRONG - Clears gradients before accumulating!
accumulation_steps = 4
for i, batch in enumerate(train_loader):
optimizer.zero_grad() # ❌ Clears previous accumulation!
loss = criterion(model(batch), target)
(loss / accumulation_steps).backward()
if (i + 1) % accumulation_steps == 0:
optimizer.step()
# Result: Only last batch's gradients are used (no accumulation!)
Complete Gradient Accumulation Implementation
import torch
import torch.nn as nn
# Configuration
batch_size_per_step = 64 # What fits in memory
accumulation_steps = 4 # Accumulate over 4 steps
effective_batch_size = batch_size_per_step * accumulation_steps # = 256
# DataLoader with smaller batch size
train_loader = DataLoader(dataset, batch_size=batch_size_per_step)
# Model and optimizer
model = TransformerModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# Training loop
optimizer.zero_grad() # Zero once before accumulation loop
for epoch in range(num_epochs):
for i, (data, target) in enumerate(train_loader):
# Forward pass
output = model(data)
loss = criterion(output, target)
# Backward pass with scaled loss
(loss / accumulation_steps).backward()
# Update every accumulation_steps
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
# Handle remaining batches at end of epoch
# (if total batches not divisible by accumulation_steps)
if len(train_loader) % accumulation_steps != 0:
optimizer.step()
optimizer.zero_grad()
Gradient Accumulation with Gradient Clipping
Correct order:
accumulation_steps = 4
optimizer.zero_grad()
for i, (data, target) in enumerate(train_loader):
output = model(data)
loss = criterion(output, target)
# Scale loss and backward
(loss / accumulation_steps).backward()
# Update every accumulation_steps
if (i + 1) % accumulation_steps == 0:
# Clip BEFORE optimizer step (on accumulated gradients)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
Why this order?
- Gradients accumulate over accumulation_steps iterations
- After accumulation, gradients are ready for clipping
- Clip once on the full accumulated gradients
- Then update parameters with clipped gradients
WRONG: Clipping on each accumulation step:
# ❌ WRONG - Clips partial gradients!
for i, (data, target) in enumerate(train_loader):
(loss / accumulation_steps).backward()
# ❌ Clipping partial gradients!
clip_grad_norm_(model.parameters(), max_norm=1.0)
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
# Why wrong:
# - Clipping partial gradients distorts accumulation
# - Each partial gradient is ~1/4 of final gradient
# - Clipping these small gradients has wrong threshold
# - Clip ONCE on final accumulated gradient
Gradient Accumulation with Learning Rate Scheduling
Correct implementation:
accumulation_steps = 4
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)
optimizer.zero_grad()
for i, (data, target) in enumerate(train_loader):
output = model(data)
loss = criterion(output, target)
(loss / accumulation_steps).backward()
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
# Step scheduler AFTER optimizer step (once per update)
scheduler.step()
Key points:
- Scheduler steps once per parameter update (not per batch)
- Matches the effective batch size timing
- Scheduler sees num_batches / accumulation_steps total steps (see the sketch below for sizing T_max accordingly)
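A small sketch of sizing the schedule to updates rather than batches (train_loader, optimizer, accumulation_steps, and num_epochs are assumed from the snippets above):
import math
import torch
# Count optimizer updates, not batches: the scheduler should step once per update
updates_per_epoch = math.ceil(len(train_loader) / accumulation_steps)
total_updates = updates_per_epoch * num_epochs
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_updates)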
Gradient Diagnosis
Why Diagnosis Matters
Don't guess - measure:
- "Training isn't working" could be many issues
- Gradient issues have specific symptoms
- Measuring gradients identifies the problem
- Diagnosis guides the solution
What to diagnose:
- Gradient magnitudes (too large? too small?)
- Gradient distribution across layers (vanishing in early layers?)
- NaN or Inf gradients (numerical issues?)
- Gradient patterns over time (growing? shrinking?)
Checking Gradient Magnitudes
Basic gradient checking:
def check_gradients(model):
"""Check gradient magnitudes for all parameters"""
total_norm = 0.0
param_norms = {}
for name, param in model.named_parameters():
if param.grad is not None:
# Compute gradient norm for this parameter
param_norm = param.grad.data.norm(2).item()
param_norms[name] = param_norm
total_norm += param_norm ** 2
total_norm = total_norm ** 0.5
print(f"Total gradient norm: {total_norm:.4f}")
# Show top 5 largest gradients
print("\nLargest gradients:")
for name, norm in sorted(param_norms.items(), key=lambda x: x[1], reverse=True)[:5]:
print(f" {name}: {norm:.4f}")
# Show top 5 smallest gradients
print("\nSmallest gradients:")
for name, norm in sorted(param_norms.items(), key=lambda x: x[1])[:5]:
print(f" {name}: {norm:.4e}")
return total_norm
# Usage in training loop:
loss.backward()
grad_norm = check_gradients(model)
optimizer.step()
What to look for:
# Healthy gradients:
# Total norm: 0.1 to 10
# Layer norms: Similar order of magnitude across layers
# No NaN or Inf values
# Exploding gradients:
# Total norm: >100 or >1000
# Some layers have huge gradients (>10)
# → Solution: Gradient clipping
# Vanishing gradients:
# Total norm: <1e-6
# Early layers have much smaller gradients than late layers
# → Solution: Better activation/initialization/architecture
# NaN gradients:
# Any gradient is NaN or Inf
# → Solution: Check for numerical instability in loss or model
Comprehensive Gradient Diagnostics
def diagnose_gradients(model, threshold_low=1e-8, threshold_high=100):
"""
Comprehensive gradient diagnostics with automatic issue detection
Args:
model: PyTorch model
threshold_low: Threshold for vanishing gradients
threshold_high: Threshold for exploding gradients
Returns:
dict with diagnostic information
"""
diagnostics = {
'total_norm': 0.0,
'param_norms': {},
'has_nan': False,
'has_inf': False,
'vanishing': [],
'exploding': [],
}
total_norm = 0.0
for name, param in model.named_parameters():
if param.grad is not None:
grad = param.grad.data
# Check for NaN or Inf
if torch.isnan(grad).any():
diagnostics['has_nan'] = True
print(f"⚠️ NaN gradient detected in {name}")
if torch.isinf(grad).any():
diagnostics['has_inf'] = True
print(f"⚠️ Inf gradient detected in {name}")
# Compute norm
param_norm = grad.norm(2).item()
diagnostics['param_norms'][name] = param_norm
total_norm += param_norm ** 2
# Check for vanishing
if param_norm < threshold_low:
diagnostics['vanishing'].append((name, param_norm))
# Check for exploding
if param_norm > threshold_high:
diagnostics['exploding'].append((name, param_norm))
total_norm = total_norm ** 0.5
diagnostics['total_norm'] = total_norm
# Print diagnosis
print(f"\n{'='*60}")
print(f"GRADIENT DIAGNOSTICS")
print(f"{'='*60}")
print(f"Total gradient norm: {total_norm:.4f}")
if diagnostics['has_nan']:
print("\n🚨 CRITICAL: NaN gradients detected!")
print(" Possible causes:")
print(" - Division by zero in loss or model")
print(" - Log of zero or negative number")
print(" - Numerical overflow")
print(" - Already-NaN parameters or inputs")
if diagnostics['has_inf']:
print("\n🚨 CRITICAL: Inf gradients detected!")
print(" Possible causes:")
print(" - Numerical overflow (very large values)")
print(" - Division by very small number")
print(" - Exponential of very large number")
if total_norm > threshold_high:
print(f"\n⚠️ EXPLODING GRADIENTS: Total norm {total_norm:.2f} > {threshold_high}")
print(" Solution: Add gradient clipping")
print(f" torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm={threshold_high/10:.1f})")
if diagnostics['exploding']:
print(f"\n Top exploding layers:")
for name, norm in sorted(diagnostics['exploding'], key=lambda x: x[1], reverse=True)[:5]:
print(f" - {name}: {norm:.2f}")
if total_norm < threshold_low:
print(f"\n⚠️ VANISHING GRADIENTS: Total norm {total_norm:.2e} < {threshold_low}")
print(" Possible solutions:")
print(" - Use ReLU/GELU instead of sigmoid/tanh")
print(" - Check weight initialization (use He/Xavier)")
print(" - Add batch normalization")
print(" - Add residual connections")
print(" - Increase learning rate (after other fixes)")
if diagnostics['vanishing']:
print(f"\n Layers with vanishing gradients:")
for name, norm in sorted(diagnostics['vanishing'], key=lambda x: x[1])[:5]:
print(f" - {name}: {norm:.2e}")
print(f"{'='*60}\n")
return diagnostics
# Usage:
loss.backward()
diagnostics = diagnose_gradients(model)
if diagnostics['has_nan'] or diagnostics['has_inf']:
# Stop training, fix the issue
raise RuntimeError("NaN or Inf gradients detected!")
Gradient Monitoring and Logging
Log gradient statistics during training:
import wandb # or tensorboard
def log_gradient_stats(model, logger, step):
"""Log gradient statistics for monitoring"""
total_norm = 0.0
layer_norms = {}
for name, param in model.named_parameters():
if param.grad is not None:
# Gradient norm
grad_norm = param.grad.data.norm(2).item()
layer_norms[name] = grad_norm
total_norm += grad_norm ** 2
# Parameter norm (for ratio calculation)
param_norm = param.data.norm(2).item()
# Log individual layer stats
logger.log({
f"gradients/{name}/norm": grad_norm,
f"gradients/{name}/mean": param.grad.data.mean().item(),
f"gradients/{name}/std": param.grad.data.std().item(),
f"gradients/{name}/max": param.grad.data.abs().max().item(),
}, step=step)
# Log ratio of gradient norm to parameter norm
# Healthy ratio is typically 0.001 to 0.01
if param_norm > 0:
ratio = grad_norm / param_norm
logger.log({f"gradients/{name}/ratio": ratio}, step=step)
total_norm = total_norm ** 0.5
# Log total gradient norm
logger.log({"gradients/total_norm": total_norm}, step=step)
return total_norm
# Usage in training loop:
for step, batch in enumerate(train_loader):
optimizer.zero_grad()
loss = model(batch)
loss.backward()
# Log gradients (before clipping to see true magnitudes)
grad_norm = log_gradient_stats(model, wandb, step)
# Clip and update
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
What to watch in gradient logs:
# Healthy training:
# - Total gradient norm: Relatively stable (0.1 to 10)
# - Layer norms: Similar across layers (no huge disparities)
# - Ratios: ~0.001 (gradients much smaller than parameters)
# - No sudden spikes or drops to zero
# Warning signs:
# - Total norm suddenly spikes (>100) → exploding gradients
# - Total norm gradually decreases to near-zero → vanishing gradients
# - Early layers have much smaller norms than late layers → vanishing
# - Ratios > 0.1 → updates are too large relative to parameters
# - Sudden drop to zero → dead neurons or broken gradient flow
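Building on the logging above, a running-average spike detector can turn these warning signs into alerts. This is a plain-Python sketch with made-up thresholds; tune them to your run.
def make_spike_detector(smoothing=0.99, factor=10.0):
    # Keep an exponential moving average of the gradient norm; flag sudden jumps
    state = {"ema": None}
    def check(grad_norm):
        if state["ema"] is None:
            state["ema"] = grad_norm
            return False
        is_spike = grad_norm > factor * state["ema"]
        state["ema"] = smoothing * state["ema"] + (1 - smoothing) * grad_norm
        return is_spike
    return check
detect_spike = make_spike_detector()
# Inside the training loop, after computing grad_norm:
# if detect_spike(grad_norm):
#     print(f"⚠️ Gradient spike: {grad_norm:.2f} -- inspect this batch, consider clipping")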
Vanishing Gradients
Recognizing Vanishing Gradients
Symptoms:
- Training loss decreases very slowly or not at all
- Validation metrics don't improve
- Gradient norms are extremely small (<1e-6)
- Early layers have much smaller gradients than later layers
- Training seems "stuck" after initialization
How to confirm:
# Check gradient magnitudes by layer depth
loss.backward()
print("Layer-wise gradient norms:")
for name, param in model.named_parameters():
if param.grad is not None:
norm = param.grad.norm(2).item()
print(f"{name}: {norm:.2e}")
# Example output showing vanishing gradients:
# layer1.weight: 1.23e-02 ← Early layer
# layer5.weight: 3.45e-04
# layer10.weight: 8.91e-06
# layer15.weight: 2.34e-07
# layer20.weight: 5.67e-09 ← Late layer
# Pattern: Gradients shrink exponentially with depth
# This is vanishing gradients!
Causes of Vanishing Gradients
1. Too many layers (very deep networks):
- Each layer multiplies gradient by weights during backprop
- If multiplication factor <1, gradients shrink exponentially
- More layers = more multiplication = smaller gradients
2. Saturating activation functions (see the short numerical demo after this list):
- Sigmoid: σ'(x) ≈ 0 when |x| is large (saturates)
- Tanh: tanh'(x) ≈ 0 when |x| is large
- Gradient flows through: grad = grad * activation'(x)
- If activation'(x) ≈ 0, the gradient vanishes
3. Poor weight initialization:
- Weights too small → activations too small → gradients too small
- Weights initialized uniformly → improper scaling across layers
4. Learning rate too low:
- Not a root cause, but can make problem worse
- Tiny gradients * tiny LR = no learning
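The saturation effect from cause 2 is easy to see numerically: the derivative of the sigmoid collapses toward zero as |x| grows, so any gradient multiplied by it shrinks.
import torch
x = torch.tensor([0.0, 2.0, 5.0, 10.0], requires_grad=True)
torch.sigmoid(x).sum().backward()
print(x.grad)   # ≈ [0.25, 0.105, 0.0066, 4.5e-05] -- gradient nearly gone at |x| = 10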
Solutions for Vanishing Gradients
Solution 1: Use Better Activation Functions
# AVOID: Sigmoid and Tanh (saturate easily)
class BadModel(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(100, 100),
nn.Sigmoid(), # ❌ Saturates, kills gradients
nn.Linear(100, 100),
nn.Sigmoid(), # ❌ Even worse with depth
nn.Linear(100, 10)
)
# PREFER: ReLU, GELU, or other non-saturating activations
class GoodModel(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(100, 100),
nn.ReLU(), # ✅ Doesn't saturate (for x>0)
nn.Linear(100, 100),
nn.GELU(), # ✅ Smooth, non-saturating
nn.Linear(100, 10)
)
# Why it helps:
# ReLU: grad = 1 for x>0, doesn't shrink gradient
# GELU: Smooth version of ReLU, widely used in Transformers
# Both avoid saturation that kills gradients
Solution 2: Proper Weight Initialization
# Use He initialization for ReLU networks
def init_weights(m):
if isinstance(m, nn.Linear):
# He initialization: optimal for ReLU
nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.zeros_(m.bias)
model = GoodModel()
model.apply(init_weights)
# Use Xavier initialization for Tanh/Sigmoid (if you must use them)
def init_weights_xavier(m):
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
# Why it helps:
# Proper initialization ensures gradients have appropriate scale
# He init accounts for ReLU's effect on variance
# Xavier init maintains variance across layers for symmetric activations
Solution 3: Batch Normalization
# Add BatchNorm between layers
class ModelWithBatchNorm(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(100, 100),
nn.BatchNorm1d(100), # ✅ Normalizes activations
nn.ReLU(),
nn.Linear(100, 100),
nn.BatchNorm1d(100), # ✅ Helps gradient flow
nn.ReLU(),
nn.Linear(100, 10)
)
# For CNNs:
class CNNWithBatchNorm(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1),
nn.BatchNorm2d(64), # ✅ After conv, before activation
nn.ReLU(),
nn.Conv2d(64, 128, 3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
)
# Why it helps:
# BatchNorm normalizes activations to have mean=0, std=1
# Prevents activations from getting too small or too large
# Helps maintain gradient scale through network
# Widely used in modern architectures
Solution 4: Residual Connections (Skip Connections)
# Add skip connections (ResNet-style)
class ResidualBlock(nn.Module):
def __init__(self, dim):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(dim, dim),
nn.ReLU(),
nn.Linear(dim, dim)
)
def forward(self, x):
# Skip connection: add input to output
return x + self.layers(x) # ✅ Gradient flows through skip connection
class ResidualNetwork(nn.Module):
def __init__(self):
super().__init__()
self.blocks = nn.Sequential(
ResidualBlock(100),
ResidualBlock(100),
ResidualBlock(100),
# Can stack many blocks without vanishing gradients!
)
self.output = nn.Linear(100, 10)
def forward(self, x):
x = self.blocks(x)
return self.output(x)
# Why it helps:
# Gradients can flow directly through skip connections
# Backprop path: grad flows through addition (no multiplication)
# Allows training very deep networks (ResNet-152, ResNet-200)
# Essential for modern deep architectures
Solution 5: Layer Normalization (for Transformers)
# Transformers use Layer Normalization
class TransformerBlock(nn.Module):
def __init__(self, d_model):
super().__init__()
self.attention = MultiHeadAttention(d_model)
self.norm1 = nn.LayerNorm(d_model) # ✅ Layer norm
self.ffn = FeedForward(d_model)
self.norm2 = nn.LayerNorm(d_model) # ✅ Layer norm
def forward(self, x):
# Pre-norm architecture (modern standard)
x = x + self.attention(self.norm1(x))
x = x + self.ffn(self.norm2(x))
return x
# Why Layer Norm:
# BatchNorm doesn't work well for sequences (different lengths)
# LayerNorm normalizes across features (not batch)
# Standard in Transformers (BERT, GPT, etc.)
Solution 6: Gradient Checkpointing (if memory-constrained)
# Trade computation for memory (from pytorch-engineering pack)
from torch.utils.checkpoint import checkpoint
class DeepModel(nn.Module):
def __init__(self):
super().__init__()
self.blocks = nn.ModuleList([
ResidualBlock(100) for _ in range(50) # Very deep!
])
def forward(self, x):
for block in self.blocks:
# Use checkpointing to save memory
x = checkpoint(block, x, use_reentrant=False)
return x
# Why it helps:
# Allows training deeper networks in same memory
# Doesn't directly solve vanishing gradients
# But removes memory constraint that prevents using deeper models
# Compatible with all other solutions (BN, residuals, etc.)
Systematic Approach to Vanishing Gradients
Step 1: Confirm diagnosis
# Check gradient magnitudes
loss.backward()
for name, param in model.named_parameters():
if param.grad is not None:
print(f"{name}: {param.grad.norm(2).item():.2e}")
# Look for: Early layers << Late layers
Step 2: Apply architectural fixes (priority order)
- Switch to ReLU/GELU activations (highest impact)
- Add proper weight initialization (He/Xavier)
- Add BatchNorm or LayerNorm
- Add residual connections if very deep (>20 layers)
Step 3: Verify improvement
# After fixes, check gradients again
# Should see more uniform gradient magnitudes across layers
Step 4: Adjust learning rate if needed
# Only AFTER architectural fixes
# May need slightly higher LR with better gradient flow
IMPORTANT NOTE: When Small Gradients Are Actually OK
Don't blindly "fix" small gradients if training is working well:
# Scenario: Gradients are small (1e-7) but training is progressing
# Epoch 1: Loss 2.34, Grad norm: 3.45e-07
# Epoch 2: Loss 1.89, Grad norm: 2.91e-07 ← Loss decreasing!
# Epoch 3: Loss 1.52, Grad norm: 2.34e-07 ← Still improving!
# This is OK! Don't fix what isn't broken.
Healthy small gradients:
- Training progressing (loss decreasing, metrics improving) ✓
- Gradients relatively uniform across layers
- Gradients stable over time
Unhealthy vanishing gradients:
- Training stuck (loss not decreasing)
- Early layers << late layers (1000x difference)
- Gradients decreasing over time
Key insight: Absolute gradient magnitude depends on parameter scale, loss scale, and learning rate. What matters is: Is the model learning?
# Better diagnostic: Check relative gradients across layers
import numpy as np
grad_norms = {}
for name, param in model.named_parameters():
    if param.grad is not None:
        grad_norms[name] = param.grad.norm(2).item()
# Check ratio: Are early layers much smaller than late layers?
# (adapt the name matching below to your model's actual parameter names)
early_layers = [v for k, v in grad_norms.items() if 'layer0.' in k or 'layer1.' in k]
late_layers = [v for k, v in grad_norms.items() if 'layer19.' in k or 'layer20.' in k]
if early_layers and late_layers:
ratio = np.mean(late_layers) / np.mean(early_layers)
if ratio > 1000:
print(f"⚠️ Vanishing gradients: late/early ratio = {ratio:.0f}")
else:
print(f"✅ Gradient flow OK: late/early ratio = {ratio:.0f}")
Decision rule:
- Training working well + gradients stable → No action needed
- Training stuck + early << late → Apply architectural fixes
- Training working + improving over time → Monitor but don't change
Exploding Gradients
Recognizing Exploding Gradients
Symptoms:
- Loss suddenly becomes NaN or Inf during training
- Loss oscillates wildly (jumps up and down)
- Parameters become very large or NaN
- Gradient norms >100 or >1000
- Training is stable initially then suddenly diverges
How to confirm:
# Check gradient magnitudes
loss.backward()
total_norm = 0.0
for param in model.parameters():
if param.grad is not None:
param_norm = param.grad.data.norm(2)
total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
print(f"Total gradient norm: {total_norm:.4f}")
# If total_norm > 100: Exploding gradients!
# If any parameter grad norm > 100: Exploding gradients!
Causes of Exploding Gradients
1. Learning rate too high:
- Large gradients * large LR = huge parameter updates
- Updates overshoot optimal values
- Can cause oscillation or divergence
2. Poor weight initialization:
- Weights too large → activations too large → gradients too large
- Random initialization without proper scaling
3. Lack of gradient clipping:
- Occasional gradient spikes are normal in some architectures
- Without clipping, one spike can break training
4. Numerical instability in model:
- Division by very small numbers
- Exponential of large numbers
- Log of numbers close to zero
5. Architecture-specific issues:
- Transformers: Attention mechanism can amplify gradients
- RNNs: Backprop through time multiplies gradients across timesteps
- Very deep networks: Many layers multiply gradients
Solutions for Exploding Gradients
Solution 1: Gradient Clipping (Primary Solution)
# Add gradient clipping - THE solution for exploding gradients
optimizer.zero_grad()
loss.backward()
# Clip gradients to maximum norm
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
# Why this works:
# Limits gradient magnitude while preserving direction
# Prevents huge parameter updates
# Standard practice for Transformers, RNNs, RL
Solution 2: Lower Learning Rate
# If gradients are consistently large, try lower LR
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) # Was 1e-3
# But NOTE:
# Gradient clipping is usually BETTER than just lowering LR
# Clipping handles occasional spikes without limiting normal gradients
# Lowering LR slows down ALL learning, even when gradients are normal
# Best approach: Use both
# - Gradient clipping for stability (handles spikes)
# - Reasonable learning rate for speed (not too high or too low)
Solution 3: Better Weight Initialization
# Use proper initialization
def init_weights(m):
if isinstance(m, nn.Linear):
# He initialization for ReLU
nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
nn.init.zeros_(m.bias)
model.apply(init_weights)
# Why it helps:
# Proper initialization ensures weights are appropriate scale
# Prevents initial gradients from being too large
# Particularly important for very deep networks
Solution 4: Batch Normalization
# Add BatchNorm to stabilize training
class StableModel(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(100, 100),
nn.BatchNorm1d(100), # ✅ Stabilizes gradients
nn.ReLU(),
nn.Linear(100, 100),
nn.BatchNorm1d(100),
nn.ReLU(),
nn.Linear(100, 10)
)
# Why it helps:
# Normalizes activations, which stabilizes gradients
# Reduces internal covariate shift
# Makes training more robust to hyperparameter choices
Solution 5: Check for Numerical Issues
# AVOID: Operations that can cause numerical instability
# ❌ Division by small numbers
loss = 1.0 / predictions # If predictions ≈ 0, loss explodes
# ✅ Add epsilon for stability
eps = 1e-8
loss = 1.0 / (predictions + eps)
# ❌ Log of values close to zero
loss = -torch.log(predictions) # If predictions ≈ 0, loss → -∞
# ✅ Add epsilon
loss = -torch.log(predictions + eps)
# ❌ Exp of large values
loss = torch.exp(logits) # If logits are large, exp explodes
# ✅ Use log-sum-exp trick or built-in stable functions
loss = F.cross_entropy(logits, targets) # Handles numerics internally
# ❌ Custom loss without stability
def unstable_loss(pred, target):
return ((pred - target) / pred).pow(2).mean() # Division can explode
# ✅ Add stability
def stable_loss(pred, target):
return ((pred - target) / (pred.abs() + eps)).pow(2).mean()
Solution 6: Use Residual Connections
# Residual connections help stability
class ResidualBlock(nn.Module):
def __init__(self, dim):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(dim, dim),
nn.ReLU(),
nn.Linear(dim, dim)
)
def forward(self, x):
return x + self.layers(x) # ✅ Skip connection provides stable path
# Why it helps:
# Gradients can flow through skip connections
# Prevents gradients from exploding through many layers
# Used in all modern deep architectures (ResNet, Transformer, etc.)
Systematic Approach to Exploding Gradients
Step 1: Confirm diagnosis
# Monitor gradient norms
loss.backward()
total_norm = sum(p.grad.data.norm(2).item() ** 2
for p in model.parameters() if p.grad is not None) ** 0.5
print(f"Gradient norm: {total_norm:.4f}")
# If norm > 100 or training diverges: Exploding gradients
Step 2: Apply fixes (priority order)
1. Add gradient clipping (highest priority, most effective)
   clip_grad_norm_(model.parameters(), max_norm=1.0)
2. Check learning rate (if still unstable after clipping)
   optimizer = Adam(model.parameters(), lr=1e-4)  # Try lower
3. Verify initialization (if problems from start of training)
   model.apply(init_weights)  # Use He/Xavier init
4. Check for numerical issues (if NaN appears)
   # Add epsilon to divisions, logs, etc.
Step 3: Verify improvement
# Monitor gradient norms during training
# Should stay in reasonable range (0.1 to 10)
# No sudden spikes to >100
# No NaN or Inf
When Clipping Doesn't Fix NaN
If you've added gradient clipping but still get NaN loss:
The problem may be in your loss function, not gradients. Diagnose systematically:
# Step 1: Check if loss is NaN BEFORE backward()
optimizer.zero_grad()
output = model(batch)
loss = custom_loss(output, target)
# Check loss BEFORE backward
if torch.isnan(loss):
print("❌ Loss is NaN BEFORE backward - problem is in loss function!")
print(f" Output range: {output.min():.4f} to {output.max():.4f}")
print(f" Target range: {target.min():.4f} to {target.max():.4f}")
# Don't proceed with backward - fix loss function first
else:
print("✅ Loss is valid before backward")
loss.backward()
# Check gradients after backward
for name, param in model.named_parameters():
if param.grad is not None and torch.isnan(param.grad).any():
print(f"❌ NaN gradient in {name} - gradient issue")
Common loss function numerical issues:
# ❌ UNSTABLE: Log of zero or negative
def bad_loss(pred, target):
return -torch.log(pred).mean() # NaN if pred <= 0!
# ✅ STABLE: Add epsilon
def good_loss(pred, target):
eps = 1e-8
return -torch.log(pred + eps).mean()
---
# ❌ UNSTABLE: Division by zero or very small number
def bad_loss2(pred, target):
return (target / pred).mean() # Explodes if pred ≈ 0
# ✅ STABLE: Add epsilon
def good_loss2(pred, target):
eps = 1e-8
return (target / (pred + eps)).mean()
---
# ❌ UNSTABLE: Sqrt at zero (gradient of sqrt(x) is 1/(2*sqrt(x)), which blows up as x → 0)
def bad_loss3(pred, target):
    diff = pred - target
    return torch.sqrt(diff ** 2).mean()  # NaN/Inf gradient whenever diff == 0
# ✅ STABLE: Add epsilon inside the sqrt (or just use diff.abs())
def good_loss3(pred, target):
    eps = 1e-8
    diff = pred - target
    return torch.sqrt(diff ** 2 + eps).mean()
---
# ❌ UNSTABLE: Exp of large values
def bad_loss4(logits):
return torch.exp(logits).sum() # Explodes if logits > 100
# ✅ STABLE: Use built-in stable functions
def good_loss4(logits, targets):
return F.cross_entropy(logits, targets) # Handles log-sum-exp internally
Diagnostic order when NaN appears:
1. Check loss before backward(): if torch.isnan(loss): ...
   - If NaN here → fix loss function (add epsilon, clamp, use stable functions)
   - If not NaN → gradient issue
2. Check gradients after backward():
   - If gradients are NaN → clipping placement correct? Unscaling (AMP)?
   - If gradients OK → parameters NaN from previous update?
3. Check parameters:
   for name, param in model.named_parameters():
       if torch.isnan(param).any():
           print(f"❌ NaN in parameter {name} - previous update caused NaN")
Summary decision tree:
Loss becomes NaN
│
├─ Check: Is loss NaN before backward()?
│ │
│ ├─ YES → Problem in loss function
│ │ • Add epsilon to divisions
│ │ • Add epsilon to logs
│ │ • Clamp inputs to sqrt
│ │ • Use stable built-in functions
│ │
│ └─ NO → Problem in backward/gradients
│ • Check gradient clipping is correctly placed
│ • Check unscaling if using AMP
│ • Check for numerical instability in model
│ • Verify proper initialization
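The decision tree can be folded into a single training-loop guard. A sketch, assuming the usual model / loss names; the function replaces the bare loss.backward() call so it can check the loss first:
import torch
def nan_guard(model, loss):
    # Returns a short diagnosis string, or None if everything is finite
    if not torch.isfinite(loss):
        return "loss is non-finite BEFORE backward: fix the loss function (epsilon, clamp, stable ops)"
    loss.backward()
    for name, param in model.named_parameters():
        if param.grad is not None and not torch.isfinite(param.grad).all():
            return f"non-finite gradient in {name}: check clipping placement / AMP unscaling / model numerics"
        if not torch.isfinite(param).all():
            return f"non-finite parameter {name}: a previous update already diverged"
    return None
# Usage (in place of loss.backward()):
# problem = nan_guard(model, loss)
# if problem:
#     raise RuntimeError(problem)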
Mixed Precision Training Integration
Gradient Clipping with AMP
The critical interaction:
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
model = TransformerModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for batch in train_loader:
optimizer.zero_grad()
# Forward pass with autocast (mixed precision)
with autocast():
output = model(batch['input'])
loss = criterion(output, batch['target'])
# Backward pass (gradients are SCALED)
scaler.scale(loss).backward()
# CRITICAL: Unscale before clipping!
scaler.unscale_(optimizer)
# Now clip (on unscaled gradients)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Optimizer step (scaler handles it)
scaler.step(optimizer)
scaler.update()
Why unscale before clipping?
# Understanding the problem:
# GradScaler multiplies gradients by large factor (e.g., 2^16 = 65536)
# This prevents underflow in fp16 gradients
# But clipping should happen on TRUE gradient values, not scaled values
# WITHOUT unscaling:
scaler.scale(loss).backward() # Gradients are now 65536x larger
clip_grad_norm_(model.parameters(), max_norm=1.0) # ❌ Clips at 1.0
# But gradients are scaled! Effective clip threshold is 65536, not 1.0
# Clipping does nothing - gradients are rarely >65536
# WITH unscaling:
scaler.scale(loss).backward() # Gradients are 65536x larger
scaler.unscale_(optimizer) # Gradients back to true values
clip_grad_norm_(model.parameters(), max_norm=1.0) # ✅ Clips at true 1.0
# Clipping works correctly on true gradient magnitudes
The flow:
1. Forward pass with autocast() → activations in fp16
2. Compute loss (in fp16 or fp32 depending on operation)
3. scaler.scale(loss).backward() → multiply gradients by scale factor
4. scaler.unscale_(optimizer) → divide gradients by scale factor (back to true values)
5. clip_grad_norm_() → clip true gradient values
6. scaler.step(optimizer) → check for inf/NaN, update parameters if safe
7. scaler.update() → adjust scale factor for next iteration
Complete AMP + Clipping + Accumulation:
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
accumulation_steps = 4
optimizer.zero_grad()
for i, batch in enumerate(train_loader):
# Forward pass with autocast
with autocast():
output = model(batch['input'])
loss = criterion(output, batch['target'])
# Scale loss for accumulation
scaled_loss = loss / accumulation_steps
# Backward pass (scaled)
scaler.scale(scaled_loss).backward()
# Update every accumulation_steps
if (i + 1) % accumulation_steps == 0:
# Unscale before clipping
scaler.unscale_(optimizer)
# Clip gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Optimizer step with scaler
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
Common AMP + Gradient Mistakes
WRONG: Not unscaling before clipping
# ❌ WRONG - Clipping scaled gradients
scaler.scale(loss).backward()
clip_grad_norm_(model.parameters(), max_norm=1.0) # ❌ On scaled gradients!
scaler.step(optimizer)
scaler.update()
# Result: Clipping doesn't work, training may diverge
WRONG: Unscaling multiple times
# ❌ WRONG - Unscaling twice
scaler.scale(loss).backward()
scaler.unscale_(optimizer) # Unscale once
clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.unscale_(optimizer) # ❌ Unscale again! Gradients now too small
scaler.step(optimizer)
# Result: Gradients become too small, slow training
WRONG: Calling step() directly instead of scaler.step()
# ❌ WRONG - Bypassing scaler
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step() # ❌ Should use scaler.step()!
scaler.update()
# Result: Scaler can't skip updates when inf/NaN detected
# Training may diverge from inf/NaN gradients
Advanced Topics
Per-Layer Gradient Clipping
When global clipping isn't enough:
def clip_grad_norm_per_layer(model, max_norm):
"""
Clip each layer's gradients independently
Use when some layers have much larger gradients than others
and global clipping is too aggressive or not aggressive enough
"""
for name, module in model.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d, nn.MultiheadAttention)):
# Get parameters for this layer
params = [p for p in module.parameters() if p.grad is not None]
if params:
# Clip this layer's gradients
layer_norm = torch.nn.utils.clip_grad_norm_(params, max_norm)
# Log if clipping was applied
if layer_norm > max_norm:
print(f"Clipped {name}: {layer_norm:.4f} -> {max_norm}")
# Usage:
loss.backward()
clip_grad_norm_per_layer(model, max_norm=1.0)
optimizer.step()
# When to use:
# - Attention layers have much larger gradients than FFN layers
# - Some task heads have huge gradients while backbone is normal
# - Global clipping clips too much for some layers, too little for others
# Trade-off:
# ✅ More fine-grained control
# ❌ More complex, harder to tune
# ❌ Less common in literature (harder to compare)
Gradient Noise and Stability
Adding noise to gradients (advanced technique):
def add_gradient_noise(model, noise_scale=1e-3):
"""
Add Gaussian noise to gradients
Can help with:
- Escaping sharp minima (better generalization)
- Privacy (differential privacy)
- Exploration in RL
"""
for param in model.parameters():
if param.grad is not None:
noise = torch.randn_like(param.grad) * noise_scale
param.grad.add_(noise)
# Usage:
loss.backward()
add_gradient_noise(model, noise_scale=1e-3)
clip_grad_norm_(model.parameters(), max_norm=1.0) # Clip after adding noise
optimizer.step()
# When to use:
# - Research setting (exploring new techniques)
# - Differential privacy requirements
# - NOT recommended for standard training (adds complexity)
Gradient Checkpointing Interaction
Gradient checkpointing compatibility:
from torch.utils.checkpoint import checkpoint
# Gradient checkpointing (from pytorch-engineering pack)
# Trades computation for memory by recomputing activations during backward
class CheckpointedModel(nn.Module):
def __init__(self):
super().__init__()
self.blocks = nn.ModuleList([
TransformerBlock(dim=512) for _ in range(24)
])
def forward(self, x):
for block in self.blocks:
# Checkpoint each block
x = checkpoint(block, x, use_reentrant=False)
return x
# Training with checkpointing + clipping + accumulation:
model = CheckpointedModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
accumulation_steps = 4
optimizer.zero_grad()
for i, batch in enumerate(train_loader):
output = model(batch) # Uses checkpointing internally
loss = criterion(output, target)
(loss / accumulation_steps).backward() # Recomputes activations
if (i + 1) % accumulation_steps == 0:
# Clipping works normally (no special handling needed)
clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
# Compatibility:
# ✅ Gradient clipping: Works normally after backward()
# ✅ Gradient accumulation: No special handling needed
# ✅ Mixed precision: Combine with AMP as usual
# ✅ All gradient management techniques: Fully compatible
# Performance note:
# Checkpointing increases backward pass time by ~30-50%
# But enables training much larger models or batch sizes
# Trade computation for memory
Distributed Training Considerations
Gradient clipping in DDP (DistributedDataParallel):
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# Setup DDP
model = TransformerModel().cuda()
model = DDP(model, device_ids=[local_rank])
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for batch in train_loader:
optimizer.zero_grad()
output = model(batch)
loss = criterion(output, target)
loss.backward()
# Gradient clipping in DDP
# IMPORTANT: Clip AFTER backward() (gradients are already synchronized)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
# How DDP works:
# 1. Forward pass: Each GPU computes independently
# 2. Backward pass: Gradients computed on each GPU
# 3. Gradient synchronization: DDP averages gradients across GPUs (automatic)
# 4. Clipping: Happens AFTER synchronization (on averaged gradients)
# 5. Optimizer step: Each GPU updates identically (same gradients)
# Key points:
# ✅ Clip after backward() as usual - DDP handles synchronization automatically
# ✅ All GPUs see same averaged gradients, so clipping is consistent
# ❌ DON'T manually synchronize gradients (DDP does this)
# ❌ DON'T clip before backward() (gradients don't exist yet)
Gradient accumulation with DDP (Optimized):
IMPORTANT: DDP synchronizes gradients on every backward() by default. With accumulation, this is wasteful - we only need to sync ONCE per update. Use no_sync() to optimize.
from contextlib import nullcontext
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# Setup DDP
model = TransformerModel().cuda()
model = DDP(model, device_ids=[local_rank])
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
accumulation_steps = 4
optimizer.zero_grad()
for i, batch in enumerate(train_loader):
# Disable gradient synchronization for accumulation steps
# Only sync on the last accumulation step
is_accumulation_step = (i + 1) % accumulation_steps != 0
# Context manager: no_sync() when accumulating, normal when updating
with model.no_sync() if is_accumulation_step else nullcontext():
output = model(batch)
loss = criterion(output, target)
(loss / accumulation_steps).backward()
# Update on last accumulation step (gradients are now synchronized)
if (i + 1) % accumulation_steps == 0:
# Gradients are synchronized across all GPUs
clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
How this works:
WITHOUT no_sync() (inefficient):
Step 1: backward() → sync gradients across GPUs (communication!)
Step 2: backward() → sync gradients across GPUs (communication!)
Step 3: backward() → sync gradients across GPUs (communication!)
Step 4: backward() → sync gradients across GPUs (communication!)
optimizer.step() → update parameters
Total: 4 synchronizations per update
WITH no_sync() (optimized):
Step 1: backward() with no_sync() → no communication
Step 2: backward() with no_sync() → no communication
Step 3: backward() with no_sync() → no communication
Step 4: backward() without no_sync() → sync accumulated gradients (communication!)
optimizer.step() → update parameters
Total: 1 synchronization per update
Performance improvement: one gradient synchronization per update instead of four (4x fewer all-reduces with accumulation_steps=4)
Why no_sync() is necessary:
- DDP normally synchronizes gradients on every backward() (default behavior)
- With accumulation, we only want to sync ONCE (on last step)
- no_sync() temporarily disables DDP's all-reduce operation
- On last step (without no_sync()), DDP performs normal synchronization
- Result: Accumulated gradients are synchronized once and correctly averaged
Complete DDP + Accumulation + Clipping + AMP:
from torch.cuda.amp import autocast, GradScaler
from contextlib import nullcontext
model = DDP(model, device_ids=[local_rank])
scaler = GradScaler()
accumulation_steps = 4
optimizer.zero_grad()
for i, batch in enumerate(train_loader):
is_accumulation_step = (i + 1) % accumulation_steps != 0
# Disable sync on accumulation steps
with model.no_sync() if is_accumulation_step else nullcontext():
# Mixed precision forward
with autocast():
output = model(batch)
loss = criterion(output, target)
# Scale and backward
scaled_loss = loss / accumulation_steps
scaler.scale(scaled_loss).backward()
# Update after accumulation
if (i + 1) % accumulation_steps == 0:
# Gradients now synchronized across GPUs
scaler.unscale_(optimizer) # Unscale for clipping
clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
# This combines ALL techniques correctly:
# ✅ DDP distributed training
# ✅ Gradient accumulation (with loss scaling)
# ✅ Mixed precision (with proper unscaling)
# ✅ Gradient clipping (on correct values)
# ✅ Optimized communication (no_sync())
Performance comparison:
# Measure with and without no_sync()
# WITHOUT no_sync(): ~40 seconds per epoch (excessive communication)
# WITH no_sync(): ~12 seconds per epoch (optimized communication)
# Speedup: 3.3x faster with accumulation_steps=4
# The more GPUs you have, the more important no_sync() becomes
# 2 GPUs: ~2x speedup
# 4 GPUs: ~3x speedup
# 8 GPUs: ~4x speedup
Common mistake:
# ❌ WRONG - Synchronizing on every step (slow!)
model = DDP(model)
accumulation_steps = 4
for i, batch in enumerate(train_loader):
(loss / accumulation_steps).backward() # Syncs every time!
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
# Result: Correct results but 3-4x slower than necessary
Common Gradient Pitfalls
Pitfall 1: Not Clipping When Needed
Symptom: Training becomes NaN after few epochs, loss spikes
WRONG:
# User sees NaN loss and thinks: "Must be learning rate"
optimizer = Adam(model.parameters(), lr=1e-5) # ❌ Lower LR to "fix" it
# Result: Training is slow and may still diverge
# Root cause (exploding gradients) not addressed
RIGHT:
# Recognize exploding gradients, add clipping
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
# Result: Training is stable, no NaN
# This is THE solution for exploding gradients
Pitfall 2: Wrong Gradient Accumulation Scaling
Symptom: Gradient accumulation gives worse results than small batch
WRONG:
# ❌ Not scaling loss
accumulation_steps = 4
for i, batch in enumerate(train_loader):
loss = criterion(model(batch), target)
loss.backward() # ❌ Gradients are 4x too large!
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
RIGHT:
# ✅ Scale loss by accumulation_steps
accumulation_steps = 4
for i, batch in enumerate(train_loader):
loss = criterion(model(batch), target)
(loss / accumulation_steps).backward() # ✅ Correct scaling
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
Pitfall 3: Clipping After optimizer.step()
Symptom: Clipping doesn't help, training still unstable
WRONG:
# ❌ Clipping after step (useless!)
loss.backward()
optimizer.step()
clip_grad_norm_(model.parameters(), max_norm=1.0) # ❌ Too late!
RIGHT:
# ✅ Clipping after backward, before step
loss.backward()
clip_grad_norm_(model.parameters(), max_norm=1.0) # ✅ Correct timing
optimizer.step()
Pitfall 4: Not Unscaling Before Clipping (AMP)
Symptom: Mixed precision training diverges, regular training works
WRONG:
# ❌ Clipping scaled gradients
scaler.scale(loss).backward()
clip_grad_norm_(model.parameters(), max_norm=1.0) # ❌ Wrong scale!
scaler.step(optimizer)
scaler.update()
RIGHT:
# ✅ Unscale before clipping
scaler.scale(loss).backward()
scaler.unscale_(optimizer) # ✅ Unscale first!
clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()
Pitfall 5: Forgetting to zero_grad() After Accumulation
Symptom: Loss decreases then increases, training unstable
WRONG:
# ❌ Missing zero_grad() after update
accumulation_steps = 4
for i, batch in enumerate(train_loader):
(loss / accumulation_steps).backward()
if (i + 1) % accumulation_steps == 0:
optimizer.step()
# ❌ Missing optimizer.zero_grad()!
RIGHT:
# ✅ Zero gradients after update
accumulation_steps = 4
for i, batch in enumerate(train_loader):
(loss / accumulation_steps).backward()
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad() # ✅ Clear gradients for next accumulation
Pitfall 6: Using Value Clipping Instead of Norm Clipping
Symptom: Training works but slower convergence than expected
SUBOPTIMAL:
# Value clipping changes gradient direction
clip_grad_value_(model.parameters(), clip_value=0.5) # Can distort gradients
BETTER:
# Norm clipping preserves direction
clip_grad_norm_(model.parameters(), max_norm=1.0) # Preferred method
Pitfall 7: Applying Clipping to All Models
Symptom: Unnecessarily slow training, limiting gradient flow
WRONG:
# ❌ Clipping when not needed (ResNet on ImageNet)
model = ResNet50()
optimizer = SGD(model.parameters(), lr=0.1)
for batch in train_loader:
loss.backward()
clip_grad_norm_(model.parameters(), max_norm=1.0) # ❌ Not needed!
optimizer.step()
# Result: Limits gradient flow, may slow convergence
RIGHT:
# ✅ Only clip when needed (training is unstable)
model = ResNet50()
optimizer = SGD(model.parameters(), lr=0.1)
for batch in train_loader:
loss.backward()
# No clipping - ResNets are naturally stable
optimizer.step()
# Only add clipping if you observe:
# - Loss becomes NaN
# - Loss spikes
# - Training instability
Pitfall 8: Not Monitoring Gradients
Symptom: Training fails, no visibility into why
WRONG:
# ❌ No gradient monitoring
for batch in train_loader:
loss = train_step(batch)
# Training fails, no idea why
RIGHT:
# ✅ Monitor gradient norms
for step, batch in enumerate(train_loader):
optimizer.zero_grad()
loss = criterion(model(batch), target)
loss.backward()
# Monitor gradients
if step % 100 == 0:
total_norm = 0.0
for p in model.parameters():
if p.grad is not None:
total_norm += p.grad.data.norm(2).item() ** 2
total_norm = total_norm ** 0.5
print(f"Step {step}, Loss: {loss.item():.4f}, Grad norm: {total_norm:.4f}")
clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
# Now you can see:
# - When gradients explode (norm suddenly large)
# - When gradients vanish (norm goes to zero)
# - How clipping affects training
Pitfall 9: Wrong DDP Gradient Synchronization
Symptom: DDP with accumulation slower than expected or wrong results
WRONG:
# ❌ DDP synchronizes on every backward (wasteful with accumulation)
model = DDP(model)
accumulation_steps = 4
for i, batch in enumerate(train_loader):
(loss / accumulation_steps).backward() # ❌ Syncs every time!
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
RIGHT:
# ✅ Disable sync except on the last accumulation step
from contextlib import nullcontext
model = DDP(model)
accumulation_steps = 4
for i, batch in enumerate(train_loader):
is_accumulation_step = (i + 1) % accumulation_steps != 0
with model.no_sync() if is_accumulation_step else nullcontext():
(loss / accumulation_steps).backward()
if (i + 1) % accumulation_steps == 0:
clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
Pitfall 10: Clipping Too Aggressively
Symptom: Training converges very slowly, gradient norm always at max_norm
WRONG:
# ❌ max_norm too low, clipping every iteration
clip_grad_norm_(model.parameters(), max_norm=0.01) # Way too aggressive!
# Result: All gradients clipped, learning very slow
RIGHT:
# ✅ Monitor and tune max_norm appropriately
# Check typical gradient norms without clipping
total_norm = compute_grad_norm(model)  # helper sketched below
print(f"Gradient norm: {total_norm:.4f}")
# Set max_norm to clip outliers, not normal gradients
# If typical norms are 0.5-2.0, set max_norm=5.0
clip_grad_norm_(model.parameters(), max_norm=5.0) # Clips outliers only
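Note: compute_grad_norm is not a PyTorch built-in; a minimal version, consistent with the monitoring loop in Pitfall 8:
def compute_grad_norm(model):
    # Global L2 norm over all parameter gradients (call after backward(), before any clipping)
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total_norm += p.grad.detach().norm(2).item() ** 2
    return total_norm ** 0.5

# Alternatively, clip_grad_norm_ returns the total norm it measured before clipping,
# so you can log its return value instead of computing the norm separately.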
Rationalization Prevention Table
| When Agent Wants To Say | STOP - Say This Instead |
|---|---|
| "Just lower the learning rate" | "This is likely exploding gradients. Add gradient clipping: clip_grad_norm_(model.parameters(), max_norm=1.0) BEFORE optimizer.step(). Then adjust LR if still needed." |
| "Try a smaller model to save memory" | "Use gradient accumulation to train with larger effective batch size: Scale loss by accumulation_steps and update every N batches. This is standard practice." |
| "Gradient accumulation is complicated" | "It's actually simple: (loss / accumulation_steps).backward() to accumulate, optimizer.step() every N batches. MUST scale loss - this is critical." |
| "Mixed precision doesn't work with clipping" | "AMP + clipping work together perfectly. You MUST unscale before clipping: scaler.unscale_(optimizer) then clip_grad_norm_(). This is documented and standard." |
| "Your gradients are too small, just increase LR" | "This is vanishing gradients. Architectural fixes are needed: Use ReLU/GELU activations, proper initialization (He/Xavier), BatchNorm, and residual connections. Increasing LR alone won't fix it." |
| "Clipping is a hack, don't use it" | "Clipping is standard practice in Transformers, RNNs, and RL. Every major paper (BERT, GPT, etc.) uses gradient clipping. It's essential for training stability, not a hack." |
| "The paper didn't use clipping, so you shouldn't" | "Papers don't always document all techniques. Clipping may have been used but not mentioned. If you observe instability (NaN, spikes), add clipping regardless of what paper says." |
| "Try different optimizer, maybe SGD works better" | "Switching optimizer doesn't fix exploding gradients. Add gradient clipping first, then compare optimizers. Clipping works with any optimizer." |
| "Gradient issues are mysterious and hard to debug" | "Gradient issues are systematic: Check gradient norms. >100 = exploding (clip). <1e-6 = vanishing (fix activations/init). NaN = numerical instability (check loss/model). Clear diagnosis → clear solution." |
| "You can clip anytime in the training loop" | "Clipping MUST happen after backward(), before step(). Timing is critical: backward() creates gradients, clip() modifies them, step() consumes them. Wrong order = useless clipping." |
| "Scale gradients instead of loss for accumulation" | "Scale LOSS, not gradients: (loss / accumulation_steps).backward(). Scaling gradients manually is error-prone and inefficient. Loss scaling is the standard, correct way." |
| "Batch norm is optional for deep networks" | "BatchNorm is critical for very deep networks (>20 layers). It normalizes activations and stabilizes gradients. Essential for training stability. Use unless you have specific reason not to." |
| "Residual connections are just a fancy trick" | "Residual connections are fundamental to training deep networks (>50 layers). They provide direct gradient path, preventing vanishing gradients. ResNet, Transformer - all use residuals." |
| "Just clip more aggressively (max_norm=0.01)" | "Too-aggressive clipping limits all gradients, slowing learning. Monitor typical gradient norms. Set max_norm to clip outliers (>100) without affecting normal gradients (1-10)." |
| "DDP handles everything automatically" | "DDP synchronizes gradients on backward(). For accumulation, use model.no_sync() on intermediate steps to avoid unnecessary synchronization. Only sync on final accumulation step." |
| "Your model is too complex, that's why training fails" | "Model complexity alone doesn't cause training failure. Gradient issues do. Diagnose gradients first. Most complex models (GPT-3, etc.) train successfully with proper gradient management." |
| "Checkpointing and clipping don't work together" | "They're fully compatible. Checkpoint affects forward/backward computation. Clipping affects gradients after backward(). No interaction - use both together freely." |
| "You need expensive GPUs for large batches" | "Use gradient accumulation for larger effective batches on any GPU. Accumulate over N steps = N× batch size, same memory. Standard technique for training large models on consumer hardware." |
| "Loss → NaN means your data has NaN" | "Usually exploding gradients, not data. Check gradient norms. If >100, add clipping. Data NaN would cause issues immediately, not after several epochs." |
| "Transformers just don't train stably" | "Transformers train extremely well with proper gradient management. BERT, GPT, T5 - all trained successfully. Use gradient clipping (max_norm=1.0), proper LR, and you'll have stable training." |
| "Clipping is expensive, will slow training significantly" | "Clipping overhead is <1%, not 5-10%. It's computing gradient norms (one pass) then scaling. Much cheaper than backward pass. 1% cost to prevent catastrophic training failure is excellent trade-off." |
| "I added clipping but still get NaN, it doesn't work" | "Check if loss is NaN BEFORE backward(). If yes, problem is in loss function (add epsilon to divisions/logs), not gradients. If no, check clipping placement and AMP unscaling." |
| "Accumulation scaling depends on batch content" | "NO. Accumulation scaling is ALWAYS accumulation_steps (constant). Sample weighting is separate concern (handled in loss function). Don't confuse these two independent concepts." |
| "Paper doesn't mention clipping, so I shouldn't use it" | "Papers don't document all implementation details. Clipping may have been used but not mentioned. Check official code if available. If your training is unstable, add clipping - stable training is prerequisite for valid comparison." |
| "Different sources give conflicting advice on clipping" | "Context matters. Transformers/RNNs usually need clipping. CNNs usually don't. Decide based on YOUR architecture and stability. Monitor gradient norms. If you see spikes >100 or NaN, add clipping. Empiricism over dogma." |
| "Use PyTorch Lightning so I need to manually add clipping" | "Lightning has built-in clipping: Trainer(gradient_clip_val=1.0, gradient_clip_algorithm='norm'). No manual code needed. Check your framework docs - most have built-in gradient management features." |
| "My model is complex so gradients will always be problematic" | "Model complexity doesn't determine gradient behavior. GPT-3 (175B parameters) trains successfully. Proper gradient management (clipping, architecture, initialization) enables training ANY size model. Complexity is not the issue." |
| "Small gradients mean training is broken" | "Only if training is stuck. If loss is decreasing and metrics improving, small absolute gradient values are OK. What matters: relative gradients across layers and whether learning is happening. Don't fix what isn't broken." |
Red Flags Checklist
When you see these patterns, IMMEDIATELY suggest gradient management solutions:
Critical Red Flags (Fix Immediately)
🚨 Loss becomes NaN during training
- Cause: Exploding gradients or numerical instability
- Solution: Add gradient clipping + check for numerical issues in loss/model
🚨 User implements gradient accumulation without scaling loss
# ❌ RED FLAG
loss.backward() # Should be: (loss / accumulation_steps).backward()
- Impact: Gradients are accumulation_steps times too large
- Solution: Scale loss by accumulation_steps
🚨 User clips gradients after optimizer.step()
# ❌ RED FLAG
optimizer.step()
clip_grad_norm_(...) # Too late!
- Impact: Clipping does nothing (gradients already consumed)
- Solution: Move clipping between backward() and step()
🚨 User uses AMP + clipping without unscaling
# ❌ RED FLAG
scaler.scale(loss).backward()
clip_grad_norm_(...) # Should unscale first!
- Impact: Clipping wrong magnitude (on scaled gradients)
- Solution: Add scaler.unscale_(optimizer) before clipping
Warning Signs (Suggest Improvements)
⚠️ Training transformers/RNNs without gradient clipping
- Likely to hit exploding gradients eventually
- Suggest preemptive clipping (max_norm=1.0)
⚠️ Very deep network (>20 layers) with sigmoid/tanh activations
- Vanishing gradients likely
- Suggest ReLU/GELU + BatchNorm + residual connections
⚠️ User says "want larger batch but OOM"
- Perfect use case for gradient accumulation
- Explain technique and correct implementation
⚠️ Gradient norms consistently >10 or <1e-6
- Exploding or vanishing gradients
- Diagnose and suggest appropriate solution
⚠️ User lowers learning rate to fix NaN loss
- Treating symptom, not cause
- Likely exploding gradients - suggest clipping
⚠️ DDP training with gradient accumulation, no no_sync()
- Inefficient (synchronizing unnecessarily)
- Suggest no_sync() on accumulation steps
⚠️ User asks "is gradient clipping necessary?"
- Depends on architecture and stability
- Provide decision criteria (Transformers: yes, CNNs: maybe not)
⚠️ Custom loss function with divisions or logs
- Potential numerical instability
- Check for epsilon additions and proper handling (see the sketch below)
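For the custom-loss warning just above, a minimal sketch of the kind of guards to check for (a hand-rolled binary cross-entropy; the eps value is an arbitrary illustration):
import torch

def stable_log_loss(prob, target, eps=1e-8):
    # Clamp anything that feeds log() or a division: log(0) gives -inf,
    # and backward() turns inf in the loss into NaN gradients
    prob = torch.clamp(prob, min=eps, max=1.0 - eps)
    return -(target * torch.log(prob) + (1 - target) * torch.log(1 - prob)).mean()

# In practice, prefer the numerically stable built-ins (e.g. nn.BCEWithLogitsLoss)
# over hand-rolled logs and divisions.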
Optimization Opportunities (Mention If Relevant)
💡 User monitors loss but not gradients
- Suggest logging gradient norms for better visibility
💡 User training large model on single GPU with small batch
- Suggest gradient accumulation for better results
💡 Gradient clipping activates every iteration
- max_norm might be too low
- Suggest monitoring and tuning threshold
💡 Using value clipping instead of norm clipping
- Suggest norm clipping (preserves direction)
Summary
Gradient management is essential for reliable training:
Gradient Clipping
- PRIMARY solution for exploding gradients (NaN, spikes)
- Use norm-based clipping: clip_grad_norm_(model.parameters(), max_norm=1.0)
- Place after backward(), before step()
- Standard for Transformers, RNNs, RL
Gradient Accumulation
- Train with larger effective batch size on same hardware
- MUST scale loss: (loss / accumulation_steps).backward()
- Update every N steps, zero_grad() after update
- Standard technique in production training
Gradient Diagnosis
- Don't guess - measure gradient norms
- >100: Exploding (clip)
- <1e-6: Vanishing (fix architecture)
- NaN: Numerical issues (check loss/model)
Vanishing Gradients
- Use ReLU/GELU activations (not sigmoid/tanh)
- Proper initialization (He for ReLU, Xavier for tanh)
- Add BatchNorm/LayerNorm
- Add residual connections for deep networks (see the sketch below)
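A minimal sketch of what these fixes look like together (a pre-norm residual MLP block; the dimensions, GELU, LayerNorm, and He init are illustrative choices, not requirements):
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)   # normalizes activations, stabilizes gradient scale
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)
        self.act = nn.GELU()            # ReLU-family activation, no saturating flat region on the positive side
        # He/Kaiming init is designed for ReLU-family activations
        nn.init.kaiming_normal_(self.fc1.weight, nonlinearity='relu')
        nn.init.kaiming_normal_(self.fc2.weight, nonlinearity='relu')

    def forward(self, x):
        # x + f(x): the identity path gives gradients a direct route through deep stacks
        return x + self.fc2(self.act(self.fc1(self.norm(x))))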
Exploding Gradients
- Add gradient clipping (primary solution)
- Check learning rate (secondary)
- Verify initialization
- Check for numerical issues
Mixed Precision Integration
- MUST unscale before clipping: scaler.unscale_(optimizer)
- Then clip on true gradient values
- Standard pattern in AMP training
Common Pitfalls
- Not scaling loss in accumulation (gradients too large)
- Clipping after step() (useless)
- Not unscaling before clipping in AMP
- Forgetting zero_grad() after accumulation
- Not monitoring gradients (no visibility)
This is NOT optional:
- Gradient management determines training success or failure
- Every production training system handles gradients properly
- The difference between reliable training and mysterious failures
Master these techniques and you'll have stable, efficient training.