| name | batch-size-and-memory-tradeoffs |
| description | Batch size selection - memory, convergence, generalization, gradient accumulation |
Batch Size and Memory Tradeoffs
Overview
Batch size is one of the most misunderstood hyperparameters. Most engineers think: "larger batch = faster training = better". Wrong. Batch size affects convergence speed, generalization, memory usage, and actual wall-clock training time in complex ways. Larger batch size is NOT always better.
Core principle: Batch size selection is a system optimization problem, not a memory constraint problem. Choose batch size based on computational speed, convergence requirements, and generalization targets - not just what fits in memory.
When to Use This Skill
Use this skill when:
- Choosing batch size for new training
- Training is slow and considering larger batches
- Out-of-memory errors during training
- Learning rate needs adjustment after batch size change
- Distributed training needs batch size scaling
- Gradient accumulation considerations
- User asks "what batch size should I use?"
- Training accuracy varies widely between batch sizes
- Convergence takes too long or is unstable
- Memory per sample calculation needed
- Comparing training speed: iterations vs epochs vs wall-clock time
- Fine-tuning with different batch sizes than pre-training
Symptoms you need this skill:
- "I have memory, what's the maximum batch size?" (wrong question)
- "Larger batches train faster, so use 512?" (incomplete)
- "Batch size doesn't affect accuracy, only speed?" (false)
- "Gradient accumulation is a workaround for small memory?" (misconception)
- "Just scale learning rate by 2x when doubling batch size?" (incomplete)
- "We get OOM at batch 256, so use 128 forever" (not optimized)
Don't use when:
- User has pure memory/infrastructure questions (use pytorch-engineering)
- User asks about optimizer selection (use optimizer-selection-framework)
- User asks about learning rate scheduling (use learning-rate-scheduling)
- User has general training failure (not batch-size specific)
Core Patterns
Pattern 1: The Batch Size Tradeoff Space
The critical insight: Batch size affects FOUR dimensions simultaneously; optimizing one impacts the others.
The four dimensions:
1. TRAINING SPEED (iterations to converge)
├─ Larger batch → fewer iterations to convergence ✓
├─ BUT: Gradient noise decreases, and some noise helps generalization
└─ Result: Mixed - can't just maximize batch
2. COMPUTATIONAL EFFICIENCY (wall-clock time)
├─ Larger batch → amortize overhead per sample ✓
├─ BUT: Larger batch → need larger LR (unstable)
├─ AND: Gradient accumulation = repeated backward (slow)
└─ Result: Optimal ≠ Maximum
3. GENERALIZATION (test accuracy)
├─ Smaller batch → noisier gradients → better regularization ✓
├─ Larger batch → cleaner gradient → overfit risk ✗
├─ BUT: Can compensate with stronger regularization
└─ Result: Batch size ↔ regularization coupling
4. MEMORY USAGE (GPU memory required)
├─ Larger batch → linear increase in activation memory
├─ Parameters constant regardless of batch
├─ Optimizer state constant regardless of batch
└─ Result: Memory ∝ batch size (linear only for activations)
The mental model:
LARGER BATCH:
✓ Fewer iterations to convergence
✓ Better computational efficiency (up to point)
✗ Worse generalization (harder to regularize)
✗ Requires larger learning rate (instability risk)
✗ Higher memory usage
SMALLER BATCH:
✗ More iterations to convergence
✗ Worse computational efficiency
✓ Better generalization (noise helps)
✓ Smaller learning rates are stable
✓ Lower memory usage
Finding the sweet spot:
- Start with batch size that uses ~80% GPU memory
- Adjust learning rate using linear scaling rule
- Monitor validation accuracy
- If validation accuracy drops → batch too large, reduce or regularize
- If training is slow → may need gradient accumulation, not larger batch
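The sweet-spot checklist above can be sketched as a small helper. This is a rough sketch: the 80% memory margin, the reference batch/LR pair, the halving step, and the 2% accuracy-drop threshold are assumptions to adapt, not fixed rules, and the function name is illustrative.
from typing import Optional

def suggest_starting_config(max_batch_that_fits: int,
                            reference_batch: int = 64,
                            reference_lr: float = 1e-3,
                            val_accuracy_drop: Optional[float] = None) -> dict:
    """Pick a starting batch size and LR; shrink the batch if validation degrades."""
    batch = max(1, int(max_batch_that_fits * 0.8))   # leave ~20% memory headroom
    lr = reference_lr * batch / reference_batch      # linear scaling rule
    if val_accuracy_drop is not None and val_accuracy_drop > 0.02:
        batch = max(1, batch // 2)                   # batch likely too large: halve (or regularize)
        lr = reference_lr * batch / reference_batch
    return {"batch_size": batch, "learning_rate": lr}

print(suggest_starting_config(512))  # e.g. {'batch_size': 409, 'learning_rate': ~0.0064}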
Pattern 2: Linear Learning Rate Scaling Rule
The rule that changes everything:
If you increase batch size by factor K, increase learning rate by factor K.
New LR = Old LR × (New Batch Size / Old Batch Size)
Why this works (the intuition):
Gradient descent update: param = param - lr * gradient
With batch size B, the gradient is an average over B samples:
gradient_B = (1/B) * sum(gradients from B samples)
update_B = lr * gradient_B
Doubling the batch to 2B roughly averages two consecutive batch-B gradients, and you take half as many updates per epoch.
If the gradient changes slowly across those two steps, one update with lr' = 2*lr at batch 2B ≈ two updates with lr at batch B, so progress per epoch stays the same.
Equivalent view: the SGD noise scale behaves like lr/B (larger batch = less gradient noise); scaling lr with B keeps the noise level, and thus the training dynamics, roughly constant.
Empirically validated: Goyal et al. (2017), "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour"
Implementation:
# Pattern 1: Direct scaling
original_lr = 0.001
original_batch_size = 32
new_batch_size = 128
scaling_factor = new_batch_size / original_batch_size # 4x
new_lr = original_lr * scaling_factor # 0.004
# Pattern 2: When changing both batch AND learning rate
def compute_scaled_lr(base_lr, base_batch_size, current_batch_size):
"""
Compute learning rate for new batch size using linear scaling rule.
Args:
base_lr: Learning rate at reference batch size
base_batch_size: Batch size where base_lr was tuned (usually 32 or 256)
current_batch_size: New batch size
Returns:
Scaled learning rate
WHY: Linear scaling rule keeps update magnitude constant
"""
scale_factor = current_batch_size / base_batch_size
return base_lr * scale_factor
# Example: ResNet-50 training (ImageNet baseline)
# Reference: batch=256, lr=0.1
# Now training at: batch=1024
scaled_lr = compute_scaled_lr(0.1, 256, 1024) # 0.4
print(f"Batch 256 with lr=0.1 → Batch 1024 with lr={scaled_lr}")
When linear scaling works:
# CASE 1: Scaling works well
# Batch: 32 → 256 (8x increase)
# Learning rate: 0.001 → 0.008 (8x)
# Training: ✓ Converges normally, same final accuracy
# Wall-clock: ✓ Faster (fewer iterations, better hardware utilization)
# CASE 2: Scaling doesn't work
# Batch: 32 → 1024 (32x increase!)
# Learning rate: 0.001 → 0.032 (32x)
# Problem: Learning rate too large, training diverges
# Solution: Need warmup phase
The Critical Caveat: WARMUP IS REQUIRED
# WRONG: Apply full scaled LR immediately
optimizer = torch.optim.SGD(model.parameters(), lr=0.032) # Too large!
for epoch in range(100):
for batch in train_loader:
loss = criterion(model(batch), targets)
loss.backward()
optimizer.step() # Loss diverges on first iteration!
# CORRECT: Warmup phase before scaled LR
def warmup_lr_schedule(base_lr, current_batch_size, reference_batch_size,
current_step, warmup_steps):
"""
Linear warmup from 0 to scaled LR.
WHY: Large LR jumps can cause divergence.
Gradual warmup lets model adapt to larger updates.
"""
scaled_lr = base_lr * (current_batch_size / reference_batch_size)
if current_step < warmup_steps:
# Linear warmup: ramp from 0 to scaled_lr
return scaled_lr * (current_step / warmup_steps)
else:
# Full scaled LR after warmup
return scaled_lr
# Implementation with PyTorch scheduler
from torch.optim.lr_scheduler import LambdaLR
def get_warmup_scheduler(optimizer, warmup_steps):
base_lrs = [param_group['lr'] for param_group in optimizer.param_groups]
def lr_lambda(current_step):
if current_step < warmup_steps:
return float(current_step) / float(max(1, warmup_steps))
return 1.0
return LambdaLR(optimizer, lr_lambda)
# Training loop
optimizer = torch.optim.SGD(model.parameters(), lr=0.032)
scheduler = get_warmup_scheduler(optimizer, warmup_steps=1000)
for epoch in range(100):
    for step, batch in enumerate(train_loader):
        optimizer.zero_grad()
        loss = criterion(model(batch), targets)
        loss.backward()
        optimizer.step()
        scheduler.step()  # Gradually increase LR during warmup
Practical guidelines:
| Batch size increase | Learning rate scale | Warmup needed? | Why |
|---|---|---|---|
| 2x (64→128) | 2x (0.001→0.002) | No | Safe, gradual |
| 4x (64→256) | 4x (0.001→0.004) | Maybe | Starting to matter |
| 8x (64→512) | 8x (0.001→0.008) | YES | Risky without warmup |
| 16x+ (64→1024) | 16x+ (0.001→0.016) | CRITICAL | Risk of divergence |
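A rough helper that mirrors the table: warmup length grows with the learning-rate scale factor. The specific fractions of training used for warmup (0%, ~1%, up to ~5%) are assumptions to tune, not published constants, and the function name is illustrative.
def suggest_warmup_steps(old_batch: int, new_batch: int, total_steps: int) -> int:
    """Rough warmup-length heuristic keyed to how much the LR is being scaled up."""
    scale = new_batch / old_batch
    if scale <= 2:
        return 0                                       # small increase: warmup optional
    if scale <= 4:
        return int(0.01 * total_steps)                 # ~1% of training
    return int(min(0.05, 0.01 * scale) * total_steps)  # bigger jumps: up to ~5% of training

print(suggest_warmup_steps(64, 512, total_steps=100_000))  # 8x scale-up -> 5000 warmup steps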
Pattern 3: Gradient Accumulation - The Alternative to Large Batches
What gradient accumulation does:
Gradient accumulation simulates a large batch size without the corresponding GPU memory. Instead of one forward+backward pass over a batch of 256, do 8 forward+backward passes over batches of 32 and accumulate the gradients. Same effective batch, roughly 1/8th the activation memory.
How it works:
# SIMPLE APPROACH (without accumulation)
batch_size = 256
effective_batch_size = 256 # Process full batch at once
memory_required = HIGH # Can't fit in GPU
for batch in train_loader: # batch.size() = 256
output = model(batch)
loss = criterion(output, target)
loss.backward()
optimizer.step()
# GRADIENT ACCUMULATION APPROACH
batch_size = 32
accumulation_steps = 8
effective_batch_size = batch_size * accumulation_steps  # 256 - same as above!
memory_required = LOW  # Only a batch of 32 in memory at once
data_iter = iter(train_loader)
optimizer.zero_grad()
for accumulation_step in range(accumulation_steps):
    batch, target = next(data_iter)  # batch.size(0) = 32
    output = model(batch)
    loss = criterion(output, target) / accumulation_steps  # scale so gradients match batch 256
    loss.backward()  # Accumulate gradients (don't zero!)
    # Don't call optimizer.step() yet!
optimizer.step()  # Update weights after accumulation completes
# Effect: weights updated as if we had processed batch_size=256
When to use gradient accumulation:
# CASE 1: Model too large to fit large batch
# Model: GPT-2 (124M parameters)
# Available GPU: 24GB
# Desired batch: 512 per GPU
# Fits in memory: No, only 32 fits
# Solution: Accumulate 16 steps of batch 32 = effective 512
model_params = 124_000_000 # 124M
param_memory = model_params * 4 # bytes (FP32)
optimizer_memory = model_params * 8 # Adam m + v: 2 states × 4 bytes = 8 bytes per param (FP32)
batch_size = 32
sequence_length = 512
activation_memory_per_sample = param_memory / 10 # Rough estimate
total_memory = param_memory + optimizer_memory + (batch_size * activation_memory_per_sample)
# ≈3 GB per optimizer step with these rough numbers
# Well under 24 GB, so 16 accumulation steps of batch 32 fit comfortably
# CASE 2: Distributed training across 8 GPUs
# Per-GPU batch: 32
# Number of GPUs: 8
# Local accumulation: 4 steps
# Total effective: 32 * 8 * 4 = 1024 (synchronized across 8 GPUs)
# Accumulation enables large total batch without massive per-GPU batch
The memory math:
Memory with Gradient Accumulation:
Without accumulation (batch_size = 256):
- Parameters: Fixed
- Optimizer state: Fixed (2x params for Adam, ~8 bytes/param in FP32)
- Gradients: Fixed (same size as parameters, independent of batch)
- Activations: O(batch_size) = O(256)
- Total ≈ 1.0x baseline memory
With accumulation (batch_size = 32, steps = 8):
- Parameters: Fixed (same)
- Optimizer state: Fixed (same)
- Gradients: Fixed (same - accumulated in place, not stored per micro-batch)
- Activations: O(batch_size) = O(32) = 8x SMALLER
- Total: only the activation term shrinks, but for activation-dominated models that is the bulk of training memory
Cost: 8 backward passes over smaller batches instead of 1 - similar total compute, but lower hardware utilization plus loop overhead
Net wall-clock: typically ~1.5-2x slower
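A rough calculator for the breakdown above, as a sketch. The function name, the `activation_bytes_per_sample` value, and the 124M-parameter example are illustrative assumptions; measure activation memory on your own model for anything precise.
def training_memory_gb(num_params: int, batch_size: int,
                       activation_bytes_per_sample: int,
                       adam: bool = True, dtype_bytes: int = 4) -> float:
    """Very rough training-memory estimate; only the activation term scales with batch."""
    params = num_params * dtype_bytes
    grads = num_params * dtype_bytes                        # same size as params, batch-independent
    opt_state = 2 * num_params * dtype_bytes if adam else 0 # Adam m + v
    activations = batch_size * activation_bytes_per_sample  # the only batch-dependent term
    return (params + grads + opt_state + activations) / 1024**3

full = training_memory_gb(124_000_000, batch_size=256, activation_bytes_per_sample=8_000_000)
accum = training_memory_gb(124_000_000, batch_size=32, activation_bytes_per_sample=8_000_000)
print(f"batch 256: {full:.1f} GB | batch 32 + 8x accumulation: {accum:.1f} GB")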
Implementation patterns:
# Pattern 1: Manual gradient accumulation
num_accumulation_steps = 8
optimizer.zero_grad()
for step, (batch, target) in enumerate(train_loader):
output = model(batch)
loss = criterion(output, target)
# Scale loss by accumulation steps
# WHY: Otherwise gradient magnitudes stack up across steps
loss = loss / num_accumulation_steps
loss.backward() # Accumulate gradients
if (step + 1) % num_accumulation_steps == 0:
optimizer.step() # Update after accumulation complete
optimizer.zero_grad()
# Pattern 2: With learning rate adjustment
# IMPORTANT: Accumulation itself is transparent to the optimizer - don't scale the LR
# for the small per-step micro-batch.
# The batch size that matters is: effective_batch = batch_size * num_accumulation_steps
# So the LR should match the effective batch, NOT the per-GPU micro-batch
original_lr = 0.1 # Tuned for batch_size = 32
num_accumulation_steps = 8
effective_batch = 32 * 8 # 256
# Linear scaling rule based on effective batch:
# Batch 32 → 256 is 8x increase
# So LR: 0.1 → 0.8 (8x)
new_lr = original_lr * 8 # 0.8
optimizer = torch.optim.SGD(model.parameters(), lr=new_lr)
# Pattern 3: Distributed training with gradient accumulation
# Per-GPU batch: 32
# Number of GPUs: 8
# Accumulation steps: 4
# Effective batch: 32 * 8 * 4 = 1024
from torch.nn.parallel import DistributedDataParallel as DDP
model = DDP(model)
num_accumulation_steps = 4
optimizer.zero_grad()
for step, (batch, target) in enumerate(train_loader):
output = model(batch)
loss = criterion(output, target)
loss = loss / num_accumulation_steps
loss.backward()
if (step + 1) % num_accumulation_steps == 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
# Pattern 4: Skipping gradient sync between accumulation steps (distributed)
class GradientAccumulator:
    def __init__(self, model, num_accumulation_steps):
        self.model = model  # expected to be a DDP-wrapped model (provides .no_sync())
        self.num_steps = num_accumulation_steps
        self.step_count = 0

    def is_update_step(self):
        """
        In DDP, only sync gradients on the micro-batch right before optimizer.step().
        Skipping the all-reduce on the other micro-batches reduces communication overhead.
        """
        return (self.step_count + 1) % self.num_steps == 0

    def backward(self, loss):
        loss = loss / self.num_steps
        if self.is_update_step():
            loss.backward()  # sync gradients across ranks
        else:
            with self.model.no_sync():  # skip gradient all-reduce in DDP
                loss.backward()
        self.step_count += 1
Gradient accumulation vs large batch - when to choose:
# When gradient accumulation is GOOD choice:
# 1. Memory-constrained (can't fit large batch)
# 2. Need large effective batch for convergence
# 3. Can tolerate ~1.5-2x slowdown
# 4. Training wall-clock time not critical
# When gradient accumulation is a BAD choice:
# 1. Can fit desired batch size in memory
# 2. Training speed is critical (wall-clock matters)
# 3. Already have good convergence with smaller batches
# 4. You are relying on small-batch gradient noise for regularization
#    (accumulation behaves like the large effective batch, not the micro-batch)
# Comparison:
#                        LARGE BATCH         GRADIENT ACCUMULATION
# Memory                 High                Low (activations ~1/accumulation)
# Wall-clock time        Fast                ~1.5-2x slower
# Convergence speed      Good                Same (effective batch is the same)
# Implementation         Simple              Requires manual loop
# Activation memory      Baseline            ~8x smaller (with 8x accumulation)
# When to use            When memory OK      When memory constrained
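The choice in the comparison above can be sketched as a tiny planner: use a plain large batch when it fits, fall back to accumulation only when it doesn't. The function and its ceiling-division rounding are illustrative assumptions, not a library API.
def plan_effective_batch(target_effective_batch: int, max_batch_in_memory: int) -> dict:
    """Prefer a real large batch; otherwise accumulate over the largest batch that fits."""
    if max_batch_in_memory >= target_effective_batch:
        return {"batch_size": target_effective_batch, "accumulation_steps": 1}
    steps = -(-target_effective_batch // max_batch_in_memory)  # ceiling division
    # Note: the realized effective batch is batch_size * steps, which may slightly
    # overshoot the target (e.g. 48 * 6 = 288 in the usage below).
    return {"batch_size": max_batch_in_memory, "accumulation_steps": steps}

print(plan_effective_batch(256, max_batch_in_memory=48))
# {'batch_size': 48, 'accumulation_steps': 6}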
Pattern 4: Memory Estimation and Optimization
Understanding memory components:
Total GPU Memory = Parameters + Optimizer State + Activations + Gradients
Example: Training BERT-base (110M params) with batch_size=32, seq_len=512
1. PARAMETERS (Fixed)
- BERT: 110M × 4 bytes (FP32) = 440 MB
- Or 110M × 2 bytes (FP16) = 220 MB
2. OPTIMIZER STATE (Fixed)
- SGD: No extra state = 0 MB
- Adam: m + v = 2 × params = 880 MB (FP32) or 440 MB (FP16)
- AdamW: Same as Adam
3. ACTIVATIONS (Linear in batch_size, seq_len)
- Stored during the forward pass (needed for backward)
- Hidden states alone: batch × seq_len × hidden_dim × 4 bytes
  = 32 × 512 × 768 × 4 bytes ≈ 50 MB per layer
- Each layer also keeps attention and FFN intermediates (several tensors of similar size),
  so ~300-350 MB per layer is a reasonable rough figure
- × 12 layers ≈ 3.8 GB
4. GRADIENTS (Linear in batch_size)
- Stored after backward, until optimizer.step()
- Same size as parameters = 440 MB
TOTAL MEMORY = 440 + 880 + 3800 + 440 = ~5.6 GB
Typical budget: Use ~80% GPU = 19 GB with 24GB GPU
Room for more: Can increase batch from 32 → 128 safely
Memory calculation framework:
def estimate_memory_usage(
num_params: int,
batch_size: int,
seq_length: int,
hidden_dim: int,
num_layers: int,
dtype_bytes: int = 4, # 4 for FP32, 2 for FP16
optimizer: str = "adam", # or "sgd"
use_gradient_checkpointing: bool = False,
):
"""
Estimate memory for training a transformer model.
Args:
num_params: Total parameters
batch_size: Batch size
seq_length: Sequence length
hidden_dim: Hidden dimension (for activation estimation)
num_layers: Number of transformer layers
dtype_bytes: 4 for FP32, 2 for FP16, 1 for INT8
optimizer: "sgd" (no state), "adam" (8x params)
use_gradient_checkpointing: If True, reduce activation memory
Returns:
Memory in GB
WHY: Helps choose batch size without trial-and-error OOM
"""
# 1. Parameter memory
param_memory = num_params * dtype_bytes
# 2. Optimizer state
if optimizer.lower() == "adam":
opt_memory = 2 * num_params * dtype_bytes # m + v
elif optimizer.lower() == "adamw":
opt_memory = 2 * num_params * dtype_bytes # m + v
else: # SGD
opt_memory = 0
    # 3. Activation memory (transformer-specific, rough)
    # A layer stores more than one hidden-state-sized tensor during the forward pass
    # (attention and FFN intermediates too). ACTIVATION_TENSORS_PER_LAYER ≈ 7 is a
    # rough assumption here; measure on your model for anything precise.
    ACTIVATION_TENSORS_PER_LAYER = 7
    activation_memory_per_layer = (batch_size * seq_length * hidden_dim
                                   * dtype_bytes * ACTIVATION_TENSORS_PER_LAYER)
    total_activation_memory = activation_memory_per_layer * num_layers
    if use_gradient_checkpointing:
        # With checkpointing: keep roughly one layer's activations,
        # recompute the rest during the backward pass
        total_activation_memory = activation_memory_per_layer
# 4. Gradient memory (same as parameter memory)
gradient_memory = num_params * dtype_bytes
# Total
total_bytes = param_memory + opt_memory + total_activation_memory + gradient_memory
total_gb = total_bytes / (1024**3)
return total_gb
# Example: BERT training
memory_gb = estimate_memory_usage(
num_params=110_000_000, # BERT-base
batch_size=32,
seq_length=512,
hidden_dim=768,
num_layers=12,
dtype_bytes=4, # FP32
optimizer="adam",
use_gradient_checkpointing=False,
)
print(f"Memory: {memory_gb:.1f} GB") # ~5.6 GB
# Optimize by reducing batch
memory_gb_batch16 = estimate_memory_usage(
num_params=110_000_000,
batch_size=16, # 2x smaller
seq_length=512,
hidden_dim=768,
num_layers=12,
dtype_bytes=4,
optimizer="adam",
use_gradient_checkpointing=False,
)
print(f"Memory with batch 16: {memory_gb_batch16:.1f} GB") # ~3.8 GB
# Optimize by mixed precision
memory_gb_fp16 = estimate_memory_usage(
num_params=110_000_000,
batch_size=32,
seq_length=512,
hidden_dim=768,
num_layers=12,
dtype_bytes=2, # FP16 instead of FP32
optimizer="adam",
use_gradient_checkpointing=False,
)
print(f"Memory with FP16: {memory_gb_fp16:.1f} GB") # ~2.8 GB
# Optimize with checkpointing
memory_gb_ckpt = estimate_memory_usage(
num_params=110_000_000,
batch_size=32,
seq_length=512,
hidden_dim=768,
num_layers=12,
dtype_bytes=4,
optimizer="adam",
use_gradient_checkpointing=True, # Save only last layer activations
)
print(f"Memory with checkpointing: {memory_gb_ckpt:.1f} GB") # ~1.0 GB
Memory optimization techniques:
# Technique 1: Gradient Checkpointing
# Recompute activations instead of storing them
# Memory: O(sqrt(num_layers)) instead of O(num_layers)
# Cost: ~30% slower training (recompute activations during backward)
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    # assumes self.attention and self.feedforward are defined in __init__ (elided here)
    def forward(self, x):
        # Forward: run under checkpoint so activations are NOT stored
        # Backward: recompute the activations on the fly
        return checkpoint(self._forward, x, use_reentrant=False)

    def _forward(self, x):
        x = self.attention(x)
        x = self.feedforward(x)
        return x
# Technique 2: Mixed Precision (FP16 / AMP)
# Use FP16 for forward+backward (activations take half the memory)
# Keep FP32 master weights (avoid accumulating rounding error)
# Memory: up to ~50% reduction
# Speed: 1.3-2x faster on modern GPUs
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for batch, target in train_loader:
optimizer.zero_grad()
with autocast(): # Automatic FP16 casting
output = model(batch)
loss = criterion(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
# Technique 3: Quantization-Aware Training
# Store weights in INT8 or FP8
# Requires special hardware support
# Memory: 75-90% reduction
# Speed: 2-4x faster
# Technique 4: Batch Size Scheduling
# Start with a small batch, increase it during training
# Reason: small batch early gives useful gradient noise (generalization);
#         increasing the batch later acts much like decaying the learning rate
# Memory: grows gradually as the batch grows
def get_adaptive_batch_size(epoch, total_epochs):
"""Increase batch size as training progresses"""
base_batch = 32
max_batch = 256
# Linear increase: start small, end large
scale_factor = base_batch + (max_batch - base_batch) * (epoch / total_epochs)
return int(scale_factor)
Pattern 5: Batch Size Effects on Convergence and Generalization
The generalization gap - why bigger batch = worse accuracy:
Generalization Gap = Test Accuracy (Small Batch) − Test Accuracy (Large Batch), i.e. how much accuracy is lost by moving to the larger batch
Why small batch generalizes better:
1. Gradient Noise: Small batch = noisy gradients
- Noise acts as regularization
- Forces model to find robust minima
- Larger noise → larger generalization margin
2. Loss Landscape: SGD with noise explores landscape differently
- Large batch: Gradient descent (exact gradient)
- Small batch: Stochastic gradient (noisy)
- Noise helps escape sharp minima (bad generalization)
- Leads to flat minima (good generalization)
3. Batch Normalization Interaction:
- BN computes statistics per batch
- Larger batch → more stable statistics
- More stable → less regularization effect
- Less regularization → worse generalization
Real numbers (ResNet-50 on ImageNet):
- Batch 256: 76.0% accuracy
- Batch 1024: 74.8% accuracy (1.2% gap!)
- Batch 4096: 72.0% accuracy (4% gap!)
The sharp minima problem:
SMALL BATCH (32):
Loss landscape: Finds FLAT minima
- Small change in weights → loss increases slowly
- Generalizes well (robust to input variations)
- Test accuracy ≈ Train accuracy
- Variance: Higher (gradient noise)
LARGE BATCH (1024):
Loss landscape: Finds SHARP minima
- Small change in weights → loss increases quickly
- Generalizes poorly (sensitive to input variations)
- Test accuracy << Train accuracy (overfitting)
- Variance: Lower (stable gradients)
SOLUTION: Add regularization to large batch training
- L2 regularization (weight decay)
- Dropout
- Data augmentation
- Label smoothing
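A hedged example of those compensations in code. The two-layer model is a toy stand-in, and the concrete values (0.05 weight decay, 0.1 dropout and label smoothing, lr 4e-3) are common starting points for large-batch training, not tuned answers.
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(512, 256), nn.ReLU(),
    nn.Dropout(p=0.1),                      # add back some of the noise the large batch removed
    nn.Linear(256, 10),
)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)           # label smoothing
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-3,     # LR scaled for the larger batch
                              weight_decay=0.05)               # stronger weight decay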
Batch size effects on different architectures:
# Architecture 1: ResNets (well-studied)
# Batch 256: 76.0% top-1 accuracy (ImageNet)
# Batch 1024: 74.8% (-1.2%)
# Batch 4096: 72.0% (-4%)
# Conclusion: Batch size matters; the gap grows quickly as batch size increases
# Architecture 2: Vision Transformers
# Batch 512: 82.0% accuracy
# Batch 1024: 81.8% (-0.2%)
# Batch 4096: 81.0% (-1%)
# Conclusion: Less sensitive to batch size (more robust)
# Architecture 3: BERT (Language)
# Batch 128: 89.0% GLUE score
# Batch 256: 88.8% (-0.2%)
# Batch 512: 88.2% (-0.8%)
# Conclusion: Moderate sensitivity
# WHY THE DIFFERENCES?
# - ResNets: Simple architecture, sharp minima
# - Vision Transformers: Attention provides regularization
# - BERT: Pre-training + fine-tuning, already regularized
Empirical guidelines for batch size vs generalization:
# Rule 1: Start with batch 128-256
# Most tasks achieve good accuracy at this range
# Memory reasonable on modern GPUs
# Generalization gap minimal
# Rule 2: If increasing batch size - add regularization
def add_regularization_for_large_batch(batch_size, base_batch=256):
"""Adjust regularization strength for larger batch size"""
# Start from base: batch 256, weight_decay 0.0001
# Double batch → increase regularization
scale_factor = batch_size / base_batch
weight_decay = 0.0001 * (scale_factor ** 0.5) # sqrt scale
dropout = 0.1 # Add dropout
label_smoothing = 0.1 # Label smoothing helps
return {
'weight_decay': weight_decay,
'dropout': dropout,
'label_smoothing': label_smoothing,
}
# Rule 3: Validate on validation set
# Don't assume scaling rule works for accuracy
# Larger batch might need different epochs/learning rate schedule
# Rule 4: Gradient accumulation doesn't restore small-batch generalization
# With the loss scaled by the number of accumulation steps, the accumulated gradient
# matches the large-batch gradient (only BatchNorm statistics differ, since they are
# computed per micro-batch). So expect the generalization behavior of the EFFECTIVE
# batch size, not of the small micro-batch - accumulation just takes longer.
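A quick self-contained check of that point, using a toy linear model chosen purely for illustration: with the loss divided by the number of micro-batches, the accumulated gradient matches the full-batch gradient, so accumulation inherits large-batch behavior rather than small-batch noise.
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

# One backward pass over the full batch of 8
model.zero_grad()
nn.functional.mse_loss(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Four accumulated micro-batches of 2, each loss scaled by 1/4
model.zero_grad()
for i in range(4):
    xb, yb = x[2 * i:2 * i + 2], y[2 * i:2 * i + 2]
    (nn.functional.mse_loss(model(xb), yb) / 4).backward()
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-6))  # True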
Pattern 6: Finding Optimal Batch Size (Not Just Maximum)
The batch size selection framework:
Step 1: Calculate memory budget
→ Max memory available (e.g., 24GB GPU)
→ Estimate parameters + optimizer state
→ Available for batch = Total - (params + opt state)
Step 2: Estimate per-sample memory
→ Run small batch (8), measure memory
→ Divide by 8 to get per-sample
→ Max batch = Available Memory / per-sample
Step 3: Find memory-safe batch
→ Use 80% of max (leaves margin)
→ This is maximum batch that's safe
Step 4: Check convergence at maximum batch
→ Train model with maximum safe batch
→ Compare accuracy to smaller batches
→ If >2% accuracy drop: reduce batch or add regularization
Step 5: Optimize for wall-clock time
→ Profile training time at different batch sizes
→ Wall-clock = (iterations) × (time per iteration)
→ Iterations = (samples / batch) × epochs
→ Find batch that minimizes wall-clock time
→ Often NOT the maximum batch!
Step 6: Select based on task requirements
→ If convergence matters more: smaller batch
→ If speed matters more: larger batch
→ If memory constrained: gradient accumulation
→ If fine-tuning: smaller batch (preserve pre-training)
Implementation:
def find_optimal_batch_size(
model,
train_loader,
criterion,
device,
target_accuracy=None,
time_budget_seconds=None,
):
"""
Find optimal batch size by profiling at different sizes.
Args:
model: PyTorch model to profile
train_loader: DataLoader
criterion: Loss function
device: torch.device
target_accuracy: If set, find batch that achieves this
time_budget_seconds: If set, find fastest batch within budget
Returns:
Optimal batch size, profiling results
WHY: Maximum batch ≠ optimal batch
"""
batch_sizes = [32, 64, 128, 256, 512]
results = {}
for batch_size in batch_sizes:
# Measure memory for this batch size
try:
batch, target = next(iter(train_loader))
batch = batch[:batch_size].to(device)
target = target[:batch_size].to(device)
torch.cuda.reset_peak_memory_stats(device)
with torch.cuda.device(device):
output = model(batch)
loss = criterion(output, target)
loss.backward()
memory_mb = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
            # Measure iteration time (synchronize so GPU work is actually counted)
            import time
            torch.cuda.synchronize(device)
            start = time.time()
            for _ in range(10):
                output = model(batch)
                loss = criterion(output, target)
                loss.backward()
            torch.cuda.synchronize(device)
            iteration_time = (time.time() - start) / 10
# Calculate total training time
# Assume 100 epochs, 50k samples
iterations_per_epoch = 50000 // batch_size
total_iterations = iterations_per_epoch * 100
total_time = total_iterations * iteration_time
results[batch_size] = {
'memory_mb': memory_mb,
'iteration_time_ms': iteration_time * 1000,
'total_time_hours': total_time / 3600,
}
        except RuntimeError as e:
            torch.cuda.empty_cache()  # release partial allocations from the failed attempt
            results[batch_size] = {'error': str(e)}
    # Find optimal based on criteria
    if target_accuracy is not None:
        # NOTE: choosing by accuracy requires training at each batch size and comparing
        # validation accuracy; the profiling above only measures memory and speed.
        # Default to the smallest profiled batch (best expected generalization) as a placeholder.
        return min(results.keys())
elif time_budget_seconds is not None:
# Choose largest batch within time budget
valid = {bs: r for bs, r in results.items()
if 'error' not in r and r['total_time_hours'] * 3600 < time_budget_seconds}
return max(valid.keys()) if valid else None
else:
# Default: choose largest batch within 80% memory limit
memory_limit = 0.8 * torch.cuda.get_device_properties(device).total_memory / (1024**2)
valid = {bs: r for bs, r in results.items()
if 'error' not in r and r['memory_mb'] < memory_limit}
return max(valid.keys()) if valid else None
# Batch size discovery loop
def discover_optimal_batch_size(model, train_loader, criterion, device):
"""
Progressive batch size search starting from small.
Pattern: Double batch size until OOM, then back off.
"""
    batch_size = 8
    prev_batch = batch_size
    # Grab one (large) batch to slice from; assumes the loader's batch is big enough
    data_batch, data_target = next(iter(train_loader))
    while True:
        try:
            # Try current batch size
            batch = data_batch[:batch_size].to(device)
            target = data_target[:batch_size].to(device)
            model.zero_grad(set_to_none=True)
            output = model(batch)
            loss = criterion(output, target)
            loss.backward()
            print(f"✓ Batch {batch_size} works")
            # Try 2x
            prev_batch = batch_size
            batch_size *= 2
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                torch.cuda.empty_cache()
                print(f"✗ Batch {batch_size} OOM")
                # Check whether 1.5x the last working size also fits
                test_batch = int(prev_batch * 1.5)
                try:
                    batch = data_batch[:test_batch].to(device)
                    target = data_target[:test_batch].to(device)
                    model.zero_grad(set_to_none=True)
                    output = model(batch)
                    loss = criterion(output, target)
                    loss.backward()
                    print(f"✓ Batch {test_batch} also works, use this")
                    return test_batch
                except RuntimeError:
                    torch.cuda.empty_cache()
                    print(f"→ Use batch size {prev_batch} (safe margin)")
                    return prev_batch
            else:
                raise
Batch size selection by use case:
# Use Case 1: Maximum accuracy matters (research, publication)
# → Choose smaller batch (128-256)
# → More gradient noise = better generalization
# → Willing to train longer if accuracy is better
optimal_batch_size = 128
# Use Case 2: Training speed matters (prototyping, iteration)
# → Choose larger batch (512-1024)
# → Trade some accuracy for wall-clock speed
# → Need to add regularization to reduce generalization gap
optimal_batch_size = 512
regularization_strength = 'strong' # weight_decay, dropout
# Use Case 3: Memory severely constrained (mobile, edge)
# → Choose small batch (16-32)
# → Use gradient accumulation to simulate larger batch
# → Accept lower accuracy if necessary
optimal_batch_size = 16
accumulation_steps = 8 # Simulate batch 128
# Use Case 4: Fine-tuning small dataset
# → Choose small batch (16-32)
# → Preserve pre-training (smaller updates)
# → Larger batch risks forgetting pre-trained knowledge
optimal_batch_size = 16
# Use Case 5: Large model, large dataset
# → Choose medium-large batch (256-512)
# → Gradient accumulation for effective larger batch
# → Mixed precision for memory savings
optimal_batch_size = 256
use_mixed_precision = True
use_gradient_accumulation = False # Fits with mixed precision
# Use Case 6: Distributed training (multiple GPUs/TPUs)
# → Per-GPU batch: 32-64
# → Accumulation: 4-8 steps
# → Total effective: per_gpu * num_gpus * accumulation
# → Large total effective batch, small per-GPU batch
per_gpu_batch = 64
num_gpus = 8
accumulation_steps = 4
effective_batch = 64 * 8 * 4 # 2048
Common Pitfalls
❌ Pitfall 1: Confusing Maximum Batch with Optimal Batch
→ Symptom: "I have 24GB memory, so I should use the largest batch that fits" → Why it breaks: Larger batch = worse generalization. Maximum batch might achieve 2-3% lower accuracy. → Fix: Use 80% of maximum batch size, validate accuracy, adjust if needed.
# WRONG
max_batch = find_max_batch_that_fits(model, memory=24_000_000_000)
train(model, batch_size=max_batch) # Likely overfit
# CORRECT
safe_batch = int(max_batch * 0.8)  # 80% of maximum
train(model, batch_size=safe_batch)
accuracy_drop = validate_accuracy(model)  # drop vs. the smaller-batch baseline
if accuracy_drop > 0.02:  # more than 2% worse
    reduce_batch_size(safe_batch * 0.8)
    add_regularization()
❌ Pitfall 2: Ignoring Learning Rate Scaling
→ Symptom: "I doubled my batch size, training diverges now" → Why it breaks: Gradient magnitudes decrease with larger batch, so learning rate must increase proportionally. → Fix: Use linear scaling rule: new_lr = old_lr × (new_batch / old_batch)
# WRONG
batch_size = 64
learning_rate = 0.001
# Increase batch without adjusting LR
batch_size = 256
# Learning rate still 0.001 - too small!
# Gradient updates too conservative, very slow convergence
# CORRECT
batch_size = 64
learning_rate = 0.001
batch_size = 256
learning_rate = 0.001 * (256 / 64) # Scale by 4x
# = 0.004
❌ Pitfall 3: Using Huge Learning Rate Without Warmup
→ Symptom: "I scaled my learning rate by 10x and now training diverges immediately" → Why it breaks: Very large learning rate jumps cause instability. Model can't adapt. → Fix: Add linear warmup phase: gradually increase LR from 0 to scaled value.
# WRONG
scaled_lr = 0.001 * 10  # 0.01
optimizer = SGD(model.parameters(), lr=scaled_lr)
for epoch in range(100):
    for batch in train_loader:
        loss = criterion(model(batch), target)
        loss.backward()
        optimizer.step()  # Diverges almost immediately!
# CORRECT
base_lr = 0.001
scaled_lr = base_lr * 10  # 0.01
warmup_steps = 1000
def lr_lambda(step):
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))  # ramp 0 → 1 over warmup
    return 1.0  # stay at the full scaled LR after warmup
optimizer = SGD(model.parameters(), lr=scaled_lr)  # scheduler base LR = the scaled LR
scheduler = LambdaLR(optimizer, lr_lambda)
for epoch in range(100):
    for batch in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(batch), target)
        loss.backward()
        optimizer.step()
        scheduler.step()
❌ Pitfall 4: Gradient Accumulation Without LR Adjustment
→ Symptom: "I added gradient accumulation but training is much slower to converge" → Why it breaks: Accumulation itself doesn't require LR change, but if effective batch increased, LR should too. → Fix: Adjust LR based on effective batch size, not per-GPU batch size.
# WRONG
batch_size = 32 # Per-GPU
num_accumulation = 8
# Learning rate still tuned for batch 32
# Effective batch = 32 × 8 = 256
# But LR not scaled for batch 256
# Convergence slower because LR too conservative
# CORRECT
batch_size = 32
num_accumulation = 8
effective_batch = batch_size * num_accumulation # 256
# Get LR for batch 32
base_lr_batch32 = 0.001
# Scale for batch 256
lr_batch256 = base_lr_batch32 * (256 / 32) # 0.008
optimizer = SGD(model.parameters(), lr=lr_batch256)
❌ Pitfall 5: Assuming Batch Size Doesn't Affect Accuracy
→ Symptom: "Batch size only affects speed, not accuracy" → Why it breaks: Batch size strongly affects generalization (1-4% gap is common). → Fix: Always validate final accuracy at different batch sizes. Larger batch might need different hyperparameters.
# WRONG - assume accuracy independent of batch
batch_sizes = [64, 256, 1024]
for batch_size in batch_sizes:
model = train(learning_rate=0.001) # Same LR for all!
accuracy = evaluate(model)
# Accuracy will differ significantly!
# CORRECT - adjust hyperparameters per batch
for batch_size in batch_sizes:
lr = 0.001 * (batch_size / 64) # Scale LR
weight_decay = 0.0001 * (batch_size / 64) ** 0.5 # Increase regularization
model = train(learning_rate=lr, weight_decay=weight_decay)
accuracy = evaluate(model)
# More consistent accuracy across batch sizes
❌ Pitfall 6: Not Considering Synchronous vs Asynchronous Batch Norm
→ Symptom: "My distributed training accuracy is much worse than single-GPU" → Why it breaks: Batch norm computes statistics per batch. Distributed training with small per-GPU batch = incorrect statistics. → Fix: Use SyncBatchNorm for correct statistics across all GPUs.
# WRONG - Synchronous data parallel, asynchronous BN
from torch.nn.parallel import DataParallel
model = DataParallel(model, device_ids=[0, 1, 2, 3])
# Each GPU has batch_size=32
# BN computes stats from only its 32 samples
# Stats unstable, training broken
# CORRECT - Synchronous batch norm with DDP
from torch.nn import SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel
model = SyncBatchNorm.convert_sync_batchnorm(model)
model = DistributedDataParallel(model, find_unused_parameters=False)
# Each GPU: batch 32, but BN aggregates statistics across all 4 GPUs = 128
# Stats computed from all 128 samples, stable
❌ Pitfall 7: Gradient Accumulation Too Large (>16x)
→ Symptom: "I'm accumulating gradients over 32 steps but training diverges" → Why it breaks: Large accumulation means many iterations of gradient computation before update. Gradients become stale, divergence risk. → Fix: Keep accumulation ≤ 16x. Use distributed training for larger effective batches.
# PROBLEMATIC - extreme accumulation with tiny micro-batches
batch_size = 4
accumulation_steps = 32  # effective batch = 128, but micro-batch of only 4
# BatchNorm statistics from 4 samples are very noisy, and weights update
# only once every 32 micro-batches
# CORRECT - reasonable accumulation
batch_size = 32
accumulation_steps = 8  # effective batch = 256
# Micro-batch large enough for stable BN statistics, updates reasonably frequent
# OR use distributed training instead
per_gpu_batch = 32
num_gpus = 8
effective_batch = per_gpu_batch * num_gpus  # 256 - same effective batch, no accumulation
# Better hardware utilization and more frequent updates
❌ Pitfall 8: Mixing Gradient Accumulation with Exponential Moving Average (EMA)
→ Symptom: "I'm using gradient accumulation with learning rate scheduler and EMA, but training is unstable" → Why it breaks: EMA expects one update per step. With accumulation, multiple backward passes → stale momentum terms. → Fix: Update EMA only when you call optimizer.step(), not every backward pass.
# WRONG - updating EMA every backward pass
ema_model = ExponentialMovingAverage(model.parameters(), decay=0.999)
for step, batch in enumerate(train_loader):
loss = criterion(model(batch), target)
loss.backward()
ema_model.update(model.parameters()) # Called every iteration!
if (step + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
# CORRECT - update EMA only on optimizer.step()
for step, batch in enumerate(train_loader):
loss = criterion(model(batch), target)
loss.backward()
if (step + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
ema_model.update(model.parameters()) # Only here!
❌ Pitfall 9: Batch Size Doubling Without Validation
→ Symptom: "I increased batch from 64 to 128 based on linear scaling rule, but accuracy dropped 2%" → Why it breaks: Linear scaling rule gives convergence rate, not accuracy guarantee. Generalization gap widens. → Fix: Always validate on holdout set when changing batch size. Accept accuracy drop or add regularization.
# WRONG - assume linear scaling guarantees accuracy
original_batch = 64
original_lr = 0.001
original_accuracy = 0.85
new_batch = 128
new_lr = 0.001 * (128 / 64) # 0.002
new_accuracy = 0.83 # Dropped 2%! Should have validated first
# CORRECT - validate and adjust regularization if needed
new_batch = 128
new_lr = 0.001 * (128 / 64)
model = train(lr=new_lr, batch=new_batch)
val_accuracy = validate(model)
if val_accuracy < 0.84: # Acceptable drop?
# Add regularization for larger batch
model = train(
lr=new_lr,
batch=new_batch,
weight_decay=0.0002, # Increase
dropout=0.2, # Add/increase
)
val_accuracy = validate(model)
❌ Pitfall 10: Using Maximum Batch in Fine-tuning
→ Symptom: "I fine-tuned with large batch size and catastrophically forgot pre-training" → Why it breaks: Large batch = large updates. Pre-trained weights overwritten too quickly. → Fix: Fine-tuning requires SMALLER batch size (32-64) and smaller learning rate than pre-training.
# WRONG - fine-tuning with large batch
pretrained_model = load_pretrained_bert()
batch_size = 512 # Large!
learning_rate = 0.001 # Too large!
model = fine_tune(pretrained_model, batch_size=512, lr=0.001)
# Overfit to task, forget pre-trained knowledge
# Pre-training lost!
# CORRECT - conservative fine-tuning
pretrained_model = load_pretrained_bert()
batch_size = 32 # Small, conservative
learning_rate = 0.00001 # Tiny, preserve pre-training
model = fine_tune(
pretrained_model,
batch_size=batch_size,
lr=learning_rate,
weight_decay=0.001, # Strong L2 regularization
)
# Preserves pre-training knowledge, adapts carefully
Practical Decision Framework
Quick Batch Size Decision Tree
1. How much GPU memory do you have?
├─ < 8 GB: Start with batch 16-32
├─ 8-16 GB: Start with batch 32-64
├─ 16-24 GB: Start with batch 64-128
└─ 24+ GB: Start with batch 128-256
2. Can you fit your target batch in memory?
├─ Yes: Use it (with LR scaling)
├─ No, by <2x: Use gradient accumulation
└─ No, by >2x: Use smaller batch + stronger regularization
3. Is accuracy your priority or speed?
├─ Accuracy: Use smaller batch (32-128)
├─ Speed: Use larger batch (256-1024)
└─ Both: Gradient accumulation + mixed precision
4. Are you fine-tuning or training from scratch?
├─ Fine-tuning: Use small batch (16-32), small LR
└─ From scratch: Use medium batch (64-256), scale LR
5. Are you using distributed training?
├─ Yes: Per-GPU batch 32-64, accumulate for effective 256-512
└─ No: Single GPU batch 64-256
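The decision tree above, restated as a rough helper function. The thresholds simply mirror the tree and the function name is illustrative; treat the output as a starting point, not a tuned answer.
def suggest_batch_size(gpu_memory_gb: float, fine_tuning: bool = False,
                       priority: str = "balanced") -> dict:
    """Map the decision tree to a starting batch-size suggestion."""
    if fine_tuning:
        return {"batch_size": 32, "note": "small batch + small LR to preserve pre-training"}
    if gpu_memory_gb < 8:
        start = 32
    elif gpu_memory_gb < 16:
        start = 64
    elif gpu_memory_gb < 24:
        start = 128
    else:
        start = 256
    if priority == "accuracy":
        start = min(start, 128)          # favor smaller batch for generalization
    elif priority == "speed":
        start = max(start, 256)          # favor larger batch, add regularization
    return {"batch_size": start, "note": "scale LR linearly, warm up if >4x, validate accuracy"}

print(suggest_batch_size(24, priority="speed"))        # {'batch_size': 256, ...}
print(suggest_batch_size(16, fine_tuning=True))        # {'batch_size': 32, ...}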
Red Flags - Stop and Clarify
| Excuse | Reality | What To Do |
|---|---|---|
| "Just use the maximum batch that fits" | Worse generalization likely. Need to validate accuracy. | Measure accuracy at 80% of max, validate trade-offs. |
| "Linear scaling rule means I don't need to validate" | Rule gives convergence rate, not accuracy guarantee. Generalization gap exists. | Always validate final accuracy with new batch size. |
| "Gradient accumulation is just for memory-constrained settings" | It's a legitimate technique with trade-offs (slowness) worth understanding. | Use when memory constrained; understand slowdown cost. |
| "Batch size only affects speed, not accuracy" | Incorrect. Batch size strongly affects final accuracy (1-4% typical gap). | Always measure accuracy, expect gap, add regularization. |
| "I'll use the batch size from a paper, it should work" | Different model, data, hardware - need to validate. | Use paper as starting point, but validate and adjust. |
| "Larger batch = faster training" | Depends on what you measure (iterations vs epochs vs wall-clock). | Measure actual wall-clock time at different batch sizes. |
| "Just double the learning rate when doubling batch" | Linear scaling rule requires warmup for large increases. | Add warmup phase, measure convergence. |
| "Fine-tuning works same as pre-training, just different data" | Fine-tuning needs much smaller batch and LR (preserve pre-training). | Use batch 16-32, LR 10-100x smaller than pre-training. |
Advanced Patterns: Batch Size Optimization in Production
Pattern 7: Batch Size Scheduling During Training
Increasing batch size as training progresses - when and why:
# Intuition: Start with small batch (good generalization),
# increase later (finish training faster)
def get_scheduled_batch_size(epoch, total_epochs, base_batch=32, max_batch=256):
"""
Increase batch size linearly with epochs.
WHY: Start small for generalization, increase for speed later.
Research shows this works well for long training.
"""
# Linear increase: 0 → 100% over training
scale = epoch / total_epochs
return int(base_batch + (max_batch - base_batch) * scale)
# Usage in training loop
for epoch in range(total_epochs):
batch_size = get_scheduled_batch_size(epoch, total_epochs)
for batch, target in get_data_loader(batch_size=batch_size):
# Adjust learning rate dynamically
lr = 0.001 * (batch_size / 32) # Scale with batch
update_learning_rate(optimizer, lr)
        optimizer.zero_grad()
        output = model(batch)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
# Alternative: exponential schedule
def get_exponential_batch_schedule(epoch, total_epochs, base_batch=32, max_batch=256):
    """Exponential increase instead of linear"""
    scale = epoch / total_epochs
    return int(base_batch * (max_batch / base_batch) ** scale)
When batch size scheduling is valuable:
GOOD FIT:
- Long training (100+ epochs)
- Starting generalization is important
- Speed only matters at end
- Example: ResNet on ImageNet
NOT NEEDED:
- Short training (10-20 epochs)
- Already regularized enough (BERT fine-tuning)
- Batch size well-chosen from start
Pattern 8: Batch Size vs Other Hyperparameters
Understanding interactions with other hyperparameters:
# Interaction 1: Batch size ↔ Learning rate
# Already covered: linear scaling rule
# Interaction 2: Batch size ↔ Weight decay
# Larger batch → worse generalization
# Solution: Increase weight decay when increasing batch
# Typical: weight_decay ~ sqrt(batch_size)
def adjust_weight_decay(base_wd=0.0001, base_batch=256, new_batch=512):
"""Scale weight decay with batch size"""
return base_wd * (new_batch / base_batch) ** 0.5
# Example
wd_batch_256 = 0.0001
wd_batch_512 = adjust_weight_decay(wd_batch_256, 256, 512) # 0.000141
# Interaction 3: Batch size ↔ Dropout
# Larger batch → add/increase dropout
# Dropout magnitude depends on layer, typically 0.1-0.5
def adjust_dropout(base_dropout=0.1, base_batch=256, new_batch=512):
"""Increase dropout for larger batches"""
# Dropout strength ~ sqrt(batch_size)
scale = (new_batch / base_batch) ** 0.5
return min(base_dropout * scale, 0.5) # Cap at 0.5
# Interaction 4: Batch size ↔ Number of epochs / updates
# Larger batch → fewer optimizer updates per epoch (updates per epoch = samples / batch)
# Large-batch runs often need somewhat MORE epochs (1.5-2x) to reach the same accuracy,
# even though each epoch contains far fewer updates
base_batch = 64
base_epochs = 100
base_updates = (50000 / base_batch) * base_epochs  # ~78k optimizer updates
new_batch = 256
updates_same_epochs = (50000 / new_batch) * base_epochs  # ~19.5k updates (4x fewer)
epochs_to_match_updates = base_updates / (50000 / new_batch)  # ~400 epochs - impractical
# In practice: keep epochs roughly constant (or 1.5-2x more), scale the LR, and validate
# Interaction 5: Batch size ↔ Optimizer choice
# SGD with momentum: the standard baseline at all batch sizes; at large batch it
#   needs the scaled LR plus warmup to stay stable
# Adam/AdamW: adaptive, generally less sensitive to batch size and LR tuning
# RMSprop: similar behavior to Adam
# Recommendation:
# - Small batch (32-128): SGD with momentum or Adam
# - Large batch (512+): AdamW (easier to stabilize) or SGD with warmup + scaled LR
# Interaction 6: Batch size ↔ Normalization technique
# Batch Norm: statistics from batch, larger batch = better stats
# Layer Norm: independent of batch size
# Group Norm: middle ground, works well with any batch size
# If using BatchNorm with small batch (< 16):
# → Use SyncBatchNorm across devices
# → Or use GroupNorm instead
# If using BatchNorm with large batch (> 1024):
# → Standard BatchNorm fine
# → May want to reduce BN momentum (accumulate stats slower)
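For the small-per-GPU-batch case above, one option besides SyncBatchNorm is swapping BatchNorm for GroupNorm, which is batch-size independent. A minimal sketch: the helper name is illustrative, 32 groups is a common convention assumed here, and gcd just keeps the group count divisible into the channel count.
import math
import torch.nn as nn

def batchnorm_to_groupnorm(module: nn.Module, num_groups: int = 32) -> nn.Module:
    """Recursively replace BatchNorm2d layers with GroupNorm."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            groups = math.gcd(num_groups, child.num_features)  # must divide num_features
            setattr(module, name, nn.GroupNorm(groups, child.num_features))
        else:
            batchnorm_to_groupnorm(child, num_groups)
    return module

# Usage (torchvision model assumed available):
# model = batchnorm_to_groupnorm(torchvision.models.resnet50(weights=None))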
Rationalization Table: Common Excuses About Batch Size
| Rationalization | Why It's Wrong | Correct Approach |
|---|---|---|
| "Larger batch is always better for speed" | Wall-clock time depends on iterations AND time-per-iteration. Larger batch may have lower throughput. | Profile wall-clock time at different batch sizes, choose fastest. |
| "I'll tune batch size last, it's not important" | Batch size affects convergence rate, generalization, and stability early. Tuning last wastes time. | Choose good batch size early (based on memory), validate accuracy. |
| "Maximum batch that fits = optimal batch" | Generalization gap widens with batch size (1-4% typical). Maximum might hit accuracy target. | Use 80% of max, validate on validation set, adjust if needed. |
| "Linear scaling rule means I don't validate" | Scaling rule gives convergence rate. Accuracy still varies with batch size due to generalization gap. | Always validate test/validation accuracy with new batch. |
| "Gradient accumulation is slow, don't use it" | True, it's slower (1.5-2x). But if memory is bottleneck, only alternative. Choose based on constraints. | Use when memory constrained. Accept slowdown. Don't use if memory OK. |
| "I don't need warmup, I'll just use scaled LR" | Large LR jumps cause divergence. Warmup prevents this. | Add linear warmup phase for scaled LR. |
| "My paper used batch X, I'll use that" | Different model, data, hardware converge differently. Paper batch might not be optimal for you. | Use paper as starting point. Validate and adjust for your setup. |
| "Fine-tuning uses same batch as pre-training" | Fine-tuning needs much smaller batch (preserve knowledge). Using pre-training batch erases pre-training. | Use batch 10-20x smaller than pre-training. Use tiny LR. |
| "Batch size only affects speed, not accuracy" | Batch size strongly affects generalization (1-4% gap common). Different final accuracy with different batch. | Expect accuracy variation with batch. Validate at each batch size. |
| "I increased batch, why is training slower?" | Fewer iterations (good) but longer per-iteration (bad). Total wall-clock = iterations × time-per-iteration. | Profile actual wall-clock time. May need gradient accumulation. |
| "I'll start with large batch to save memory" | Large batch → bad generalization early → harder to recover later. Start small, increase if needed. | Start with batch 32-64, increase during training if memory allows. |
Comprehensive Example: Training a Vision Transformer
Let's put it all together with a real example:
import math
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torch.optim import AdamW
from torchvision import models, datasets, transforms
def train_vision_transformer_optimized():
"""
Complete example: training Vision Transformer with batch size optimization.
"""
# Step 1: Model and data
device = torch.device("cuda:0")
    model = models.vit_b_16(weights=None).to(device)
criterion = torch.nn.CrossEntropyLoss()
# Dataset (ImageNet-scale)
transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225))
    ])
    # Dataset - the path below is a placeholder for your ImageNet-style data
    dataset = datasets.ImageFolder("/path/to/imagenet/train", transform=transform)
# Step 2: Determine batch size
# ViT-Base: 86M parameters
# GPU: 40GB A100
# Memory estimate: params (344MB) + optimizer (688MB) + activations
# Can fit batch 256-512
    base_batch = 256
    num_accumulation_steps = 1  # fits in memory directly at this batch size
    effective_batch = base_batch
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=base_batch, shuffle=True, num_workers=8, pin_memory=True)
# Step 3: Initialize optimizer with scaled LR
# Base LR tuned for batch 256
base_lr = 1e-4
scaled_lr = base_lr * (effective_batch / 256) # 1e-4 (no scaling needed)
optimizer = AdamW(model.parameters(), lr=scaled_lr, weight_decay=0.05)
# Step 4: Warmup scheduler
warmup_steps = 1000
total_steps = 100 * len(dataset) // effective_batch
    def warmup_cosine_schedule(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
scheduler = LambdaLR(optimizer, warmup_cosine_schedule)
# Step 5: Training loop with gradient accumulation (even though not needed)
# Good practice for larger models
model.train()
optimizer.zero_grad()
for epoch in range(100):
for step, (images, labels) in enumerate(train_loader):
images = images.to(device)
labels = labels.to(device)
# Forward + backward
logits = model(images)
loss = criterion(logits, labels)
loss = loss / num_accumulation_steps
loss.backward()
# Update on accumulation step
if (step + 1) % num_accumulation_steps == 0:
# Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if step % 100 == 0:
print(f"Epoch {epoch}, step {step}, loss {loss.item():.4f}, "
f"lr {optimizer.param_groups[0]['lr']:.2e}")
        # Validate every epoch (validate() assumed to be defined elsewhere)
        val_accuracy = validate(model, device)
print(f"Epoch {epoch} validation accuracy: {val_accuracy:.2%}")
return model
# Key patterns demonstrated:
# 1. Batch size chosen based on memory (80% of max)
# 2. Learning rate scaled for batch size
# 3. Warmup phase for gradual LR increase
# 4. Cosine annealing for LR decay
# 5. Gradient accumulation structure (even if not needed)
# 6. Gradient clipping for stability
# 7. Regular validation to monitor accuracy
Summary: Batch Size and Memory Decision Making
The core principle: Batch size is a system design choice affecting convergence, generalization, speed, and memory simultaneously. There is no universal "right" batch size - it depends on your constraints and priorities.
The decision process:
- Memory constraint: Start with 80% of maximum batch
- Convergence: Scale learning rate 1:1 with batch increase (with warmup)
- Generalization: Validate accuracy, reduce if gap >2% (or add regularization)
- Performance: Profile wall-clock time at different batch sizes
- Architecture: Different models have different optimal batches
The key insights:
- Larger batch = faster iterations but worse generalization
- Linear scaling rule requires warmup for large increases
- Gradient accumulation is a legitimate technique (understand slowdown cost)
- Fine-tuning requires smaller batch than pre-training
- Distributed training needs care with batch norm and gradient updates
- Always measure, validate, and adjust - don't assume rules apply to your case
The testing approach:
When pressure-tested, this skill should:
- Explain why maximum batch ≠ optimal batch (generalization gap)
- Provide concrete examples of linear scaling rule with warmup
- Address gradient accumulation systematically (when, why, cost)
- Discuss memory estimation and optimization techniques
- Help select batch size based on constraints AND priorities
- Resist rationalizations and always recommend validation
References and Further Reading
Key papers:
- Goyal et al. (2017) "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour" - Linear scaling rule + warmup
- You et al. (2019) "Large Batch Optimization for Deep Learning" - Theory
- Smith et al. (2017) "Don't Decay the Learning Rate, Increase the Batch Size" - Batch size scheduling
Techniques mentioned:
- Batch Normalization: Ioffe & Szegedy (2015)
- Layer Normalization: Ba et al. (2016)
- Mixed Precision Training: Micikevicius et al. (2017)
- Gradient Checkpointing: Chen et al. (2016)
Related Yzmir Skills:
- learning-rate-scheduling - LR schedule choices beyond linear scaling
- gradient-management - Gradient clipping and accumulation for stability
- optimization-algorithms - Optimizer selection and hyperparameter tuning