batch-size-and-memory-tradeoffs

@tachyon-beep/skillpacks

Batch size selection - memory, convergence, generalization, gradient accumulation


SKILL.md

name batch-size-and-memory-tradeoffs
description Batch size selection - memory, convergence, generalization, gradient accumulation

Batch Size and Memory Tradeoffs

Overview

Batch size is one of the most misunderstood hyperparameters. Most engineers think: "larger batch = faster training = better". Wrong. Batch size affects convergence speed, generalization, memory usage, and actual wall-clock training time in complex ways. Larger batch size is NOT always better.

Core principle: Batch size selection is a system optimization problem, not a memory constraint problem. Choose batch size based on computational speed, convergence requirements, and generalization targets - not just what fits in memory.


When to Use This Skill

Use this skill when:

  • Choosing batch size for new training
  • Training is slow and considering larger batches
  • Out-of-memory errors during training
  • Learning rate needs adjustment after batch size change
  • Distributed training needs batch size scaling
  • Gradient accumulation considerations
  • User asks "what batch size should I use?"
  • Training accuracy varies widely between batch sizes
  • Convergence takes too long or is unstable
  • Memory per sample calculation needed
  • Comparing training speed: iterations vs epochs vs wall-clock time
  • Fine-tuning with different batch sizes than pre-training

Symptoms you need this skill:

  • "I have memory, what's the maximum batch size?" (wrong question)
  • "Larger batches train faster, so use 512?" (incomplete)
  • "Batch size doesn't affect accuracy, only speed?" (false)
  • "Gradient accumulation is a workaround for small memory?" (misconception)
  • "Just scale learning rate by 2x when doubling batch size?" (incomplete)
  • "We get OOM at batch 256, so use 128 forever" (not optimized)

Don't use when:

  • User has pure memory/infrastructure questions (use pytorch-engineering)
  • User asks about optimizer selection (use optimizer-selection-framework)
  • User asks about learning rate scheduling (use learning-rate-scheduling)
  • User has general training failure (not batch-size specific)

Core Patterns

Pattern 1: The Batch Size Tradeoff Space

The critical insight: Batch size affects four coupled dimensions simultaneously. Optimizing one changes the others.

The four dimensions:

1. TRAINING SPEED (iterations to converge)
   ├─ Larger batch → fewer iterations to convergence ✓
   ├─ BUT: Gradient noise decreases (some noise helps generalization), with diminishing returns
   └─ Result: Mixed - can't just maximize batch

2. COMPUTATIONAL EFFICIENCY (wall-clock time)
   ├─ Larger batch → amortize overhead per sample ✓
   ├─ BUT: Larger batch → need larger LR (unstable)
   ├─ AND: Gradient accumulation = repeated backward (slow)
   └─ Result: Optimal ≠ Maximum

3. GENERALIZATION (test accuracy)
   ├─ Smaller batch → noisier gradients → better regularization ✓
   ├─ Larger batch → cleaner gradient → overfit risk ✗
   ├─ BUT: Can compensate with stronger regularization
   └─ Result: Batch size ↔ regularization coupling

4. MEMORY USAGE (GPU memory required)
   ├─ Larger batch → linear increase in activation memory
   ├─ Parameters constant regardless of batch
   ├─ Optimizer state constant regardless of batch
   └─ Result: Memory ∝ batch size (linear only for activations)

The mental model:

LARGER BATCH:
  ✓ Fewer iterations to convergence
  ✓ Better computational efficiency (up to a point)
  ✗ Worse generalization (harder to regularize)
  ✗ Requires larger learning rate (instability risk)
  ✗ Higher memory usage

SMALLER BATCH:
  ✗ More iterations to convergence
  ✗ Worse computational efficiency
  ✓ Better generalization (noise helps)
  ✓ Smaller learning rates are stable
  ✓ Lower memory usage

Finding the sweet spot:

  • Start with batch size that uses ~80% GPU memory
  • Adjust learning rate using linear scaling rule
  • Monitor validation accuracy
  • If validation accuracy drops → batch too large, reduce or regularize
  • If training is slow → may need gradient accumulation, not larger batch

Pattern 2: Linear Learning Rate Scaling Rule

The rule that changes everything:

If you increase batch size by factor K, increase learning rate by factor K.

New LR = Old LR × (New Batch Size / Old Batch Size)

Why this works (the math):

Gradient Descent Update: param = param - lr * gradient

With batch size B, the gradient is the average over B samples:
  gradient_B = (1/B) * sum(gradients from B samples)
  update_B = lr * gradient_B

With batch size 2B, the gradient is the average over 2B samples:
  gradient_2B = (1/(2B)) * sum(gradients from 2B samples)

The expected gradient is the same in both cases; only its variance halves.
But one epoch at batch 2B takes half as many steps as at batch B.
Two steps at batch B move the weights by about 2 * lr * (average gradient),
assuming the gradient changes little between the two steps.
To cover the same distance in one step at batch 2B: lr should increase by 2x.

Empirically validated: Goyal et al. (2017) "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour"
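
A quick numerical check of the linear scaling rule on a toy problem. This is only a sketch: the per-sample loss 0.5 * (w - x)^2, the four data points, and the helper name sgd_step are assumptions chosen for illustration, not part of the original skill. It compares two SGD steps at batch size 2 against one step at batch size 4 with a doubled learning rate.

```python
def sgd_step(w, batch, lr):
    """One SGD step on the loss 0.5 * (w - x)^2 averaged over the batch."""
    grad = sum(w - x for x in batch) / len(batch)
    return w - lr * grad

data = [1.0, 3.0, 2.0, 6.0]  # four samples = two mini-batches of size 2
w0, lr = 10.0, 0.01

# Two steps at batch size 2 with learning rate lr
w_small = sgd_step(sgd_step(w0, data[:2], lr), data[2:], lr)

# One step at batch size 4 with learning rate 2 * lr (linear scaling)
w_large = sgd_step(w0, data, 2 * lr)

# The two results differ only by a second-order term in lr
gap = w_small - w_large
print(f"small-batch: {w_small:.6f}, large-batch: {w_large:.6f}, gap: {gap:.6f}")
```

For this toy loss the gap works out to lr^2 * (w0 - mean of the first mini-batch), so it shrinks quadratically as the learning rate shrinks. With a large learning rate the gap is no longer negligible, which is one way to see why very large scaling factors need extra care.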

Implementation:

# Pattern 1: Direct scaling
original_lr = 0.001
original_batch_size = 32
new_batch_size = 128

scaling_factor = new_batch_size / original_batch_size  # 4x
new_lr = original_lr * scaling_factor  # 0.004

# Pattern 2: When changing both batch AND learning rate
def compute_scaled_lr(base_lr, base_batch_size, current_batch_size):
    """
    Compute learning rate for new batch size using linear scaling rule.

    Args:
        base_lr: Learning rate at reference batch size
        base_batch_size: Batch size where base_lr was tuned (usually 32 or 256)
        current_batch_size: New batch size

    Returns:
        Scaled learning rate

    WHY: Linear scaling rule keeps update magnitude constant
    """
    scale_factor = current_batch_size / base_batch_size
    return base_lr * scale_factor

# Example: ResNet-50 training (ImageNet baseline)
# Reference: batch=256, lr=0.1
# Now training at: batch=1024
scaled_lr = compute_scaled_lr(0.1, 256, 1024)  # 0.4
print(f"Batch 256 with lr=0.1 → Batch 1024 with lr={scaled_lr}")

When linear scaling works:

# CASE 1: Scaling works well
# Batch: 32 → 256 (8x increase)
# Learning rate: 0.001 → 0.008 (8x)
# Training: ✓ Converges normally, same final accuracy
# Wall-clock: ✓ Faster (fewer iterations, better hardware utilization)

# CASE 2: Scaling doesn't work
# Batch: 32 → 1024 (32x increase!)
# Learning rate: 0.001 → 0.032 (32x)
# Problem: Learning rate too large, training diverges
# Solution: Need warmup phase

The Critical Caveat: WARMUP IS REQUIRED

# WRONG: Apply full scaled LR immediately
import torch

optimizer = torch.optim.SGD(model.parameters(), lr=0.032)  # Too large for the first steps!
for epoch in range(100):
    for batch, targets in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(batch), targets)
        loss.backward()
        optimizer.step()  # Loss can diverge within the first few iterations
# CORRECT: Warmup phase before scaled LR
def warmup_lr_schedule(base_lr, current_batch_size, reference_batch_size,
                       current_step, warmup_steps):
    """
    Linear warmup from 0 to scaled LR.

    WHY: Large LR jumps can cause divergence.
    Gradual warmup lets model adapt to larger updates.
    """
    scaled_lr = base_lr * (current_batch_size / reference_batch_size)

    if current_step < warmup_steps:
        # Linear warmup: ramp from 0 to scaled_lr
        return scaled_lr * (current_step / warmup_steps)
    else:
        # Full scaled LR after warmup
        return scaled_lr

# Implementation with PyTorch scheduler
import torch
from torch.optim.lr_scheduler import LambdaLR

def get_warmup_scheduler(optimizer, warmup_steps):
    # LambdaLR multiplies each param group's initial LR by lr_lambda(step)
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        return 1.0

    return LambdaLR(optimizer, lr_lambda)

# Training loop
optimizer = torch.optim.SGD(model.parameters(), lr=0.032)  # Target (scaled) LR
scheduler = get_warmup_scheduler(optimizer, warmup_steps=1000)

for epoch in range(100):
    for step, (batch, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        loss = criterion(model(batch), targets)
        loss.backward()
        optimizer.step()
        scheduler.step()  # Gradually increase LR

Practical guidelines:

BATCH SIZE INCREASE    LEARNING RATE SCALE    WARMUP NEEDED?    WHY
2x (64→128)           2x (0.001→0.002)       No                Safe, gradual
4x (64→256)           4x (0.001→0.004)       Maybe              Starting to matter
8x (64→512)           8x (0.001→0.008)       YES                Risky without warmup
16x+ (64→1024)        16x+ (0.001→0.016)     CRITICAL           Risk of divergence

Pattern 3: Gradient Accumulation - The Alternative to Large Batches

What gradient accumulation does:

Gradient accumulation simulates a large batch size without the GPU memory a large batch needs. Instead of 1 forward+backward pass over a batch of 256, do 8 forward+backward passes over batches of 32. Same effective batch, 1/8th of the activation memory.

How it works:

# SIMPLE APPROACH (without accumulation)
batch_size = 256  # Full batch processed at once; may not fit in GPU memory

for batch, target in train_loader:  # batch.size(0) == 256
    optimizer.zero_grad()
    output = model(batch)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

# GRADIENT ACCUMULATION APPROACH
batch_size = 32  # Only 32 samples' activations in memory at once
accumulation_steps = 8
effective_batch_size = batch_size * accumulation_steps  # 256, same as above

data_iter = iter(train_loader)  # Create the iterator once; batch.size(0) == 32
optimizer.zero_grad()
for accumulation_step in range(accumulation_steps):
    batch, target = next(data_iter)
    output = model(batch)
    loss = criterion(output, target) / accumulation_steps  # Average, not sum
    loss.backward()  # Accumulate gradients (don't zero!)
    # Don't call optimizer.step() yet!

optimizer.step()  # Update weights after accumulation complete
# Effect: Updated weights as if we processed batch_size=256
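
The claim that accumulated micro-batch gradients match the full-batch gradient can be checked without a GPU. The sketch below is a pure-Python illustration under assumed conditions: a one-parameter model y = w * x with mean squared error, made-up data, and a hypothetical helper mse_grad. It is not taken from the original skill.

```python
def mse_grad(w, samples):
    """Gradient of the mean squared error for the model y = w * x."""
    return sum(2 * (w * x - y) * x for x, y in samples) / len(samples)

samples = [(1.0, 2.0), (2.0, 3.5), (3.0, 6.5), (4.0, 8.0),
           (5.0, 9.0), (6.0, 12.5), (7.0, 13.0), (8.0, 17.0)]
w = 0.5
accumulation_steps = 4
micro_batch = len(samples) // accumulation_steps  # 2 samples per micro-batch

# Gradient of the full batch, computed in one pass
full_grad = mse_grad(w, samples)

# Same gradient built from micro-batches, each scaled by 1/accumulation_steps
accumulated_grad = 0.0
for i in range(accumulation_steps):
    chunk = samples[i * micro_batch:(i + 1) * micro_batch]
    accumulated_grad += mse_grad(w, chunk) / accumulation_steps

print(full_grad, accumulated_grad)
```

The match relies on equal-sized micro-batches and a loss that averages over samples. Layers that compute per-batch statistics, such as batch normalization, still see only the micro-batch, so a real model is not guaranteed to match exactly.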

When to use gradient accumulation:

# CASE 1: Large batch does not fit in memory
# Model: GPT-2 (124M parameters)
# Available GPU: 24GB
# Desired batch: 512 per GPU
# Fits in memory: No, only 32 fits
# Solution: Accumulate 16 steps of batch 32 = effective 512

model_params = 124_000_000  # 124M
param_memory = model_params * 4  # bytes (FP32)
gradient_memory = model_params * 4  # bytes, same size as parameters
optimizer_memory = model_params * 8  # Adam m + v: 8 bytes per parameter (2x param memory)
batch_size = 32
sequence_length = 512
activation_memory_per_sample = param_memory / 10  # Placeholder; measure this for a real model
total_memory = (param_memory + gradient_memory + optimizer_memory
                + batch_size * activation_memory_per_sample)
# ~3.6 GB with this placeholder estimate
# Peak memory does not grow with the number of accumulation steps:
# each micro-batch's activations are freed before the next one runs

# CASE 2: Distributed training across 8 GPUs
# Per-GPU batch: 32
# Number of GPUs: 8
# Local accumulation: 4 steps
# Total effective: 32 * 8 * 4 = 1024 (synchronized across 8 GPUs)

# Accumulation enables large total batch without massive per-GPU batch
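
The effective-batch arithmetic above, combined with the linear scaling rule, can be wrapped in two small helpers. This is a sketch; the function names and default arguments are assumptions for illustration rather than part of any library.

```python
def effective_batch_size(per_gpu_batch, num_gpus=1, accumulation_steps=1):
    """Samples contributing to a single optimizer step."""
    return per_gpu_batch * num_gpus * accumulation_steps

def lr_for_effective_batch(base_lr, base_batch, per_gpu_batch,
                           num_gpus=1, accumulation_steps=1):
    """Linear scaling rule applied to the effective batch, not the per-GPU batch."""
    effective = effective_batch_size(per_gpu_batch, num_gpus, accumulation_steps)
    return base_lr * effective / base_batch

# CASE 2 from above: 32 per GPU, 8 GPUs, 4 accumulation steps
total = effective_batch_size(32, num_gpus=8, accumulation_steps=4)
lr = lr_for_effective_batch(0.1, 256, 32, num_gpus=8, accumulation_steps=4)
print(total, lr)
```

The reference point (lr 0.1 at batch 256) is the ResNet-50 baseline quoted earlier. For large scaling factors the result should still be paired with warmup.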

The memory math:

Memory with Gradient Accumulation:

Without accumulation (batch_size = 256):
  - Parameters: Fixed
  - Optimizer state: Fixed (2x params for Adam)
  - Gradients: Fixed (same size as params)
  - Activations: O(batch_size) = O(256)
  - Activation memory = 1.0x baseline

With accumulation (batch_size = 32, steps = 8):
  - Parameters: Fixed (same)
  - Optimizer state: Fixed (same)
  - Gradients: Fixed (same)
  - Activations: O(batch_size) = O(32) = 8x SMALLER
  - Activation memory ≈ 0.125x baseline

Savings: ~87% of activation memory (fixed costs are unchanged)
Cost: 8 forward+backward passes per update instead of 1, each on 1/8 of the data
Net wall-clock: ~1.5-2x slower (small batches use the GPU less efficiently, plus per-step overhead)

Implementation patterns:

# Pattern 1: Manual gradient accumulation
num_accumulation_steps = 8
optimizer.zero_grad()

for step, (batch, target) in enumerate(train_loader):
    output = model(batch)
    loss = criterion(output, target)

    # Scale loss by accumulation steps
    # WHY: Otherwise gradient magnitudes stack up across steps
    loss = loss / num_accumulation_steps

    loss.backward()  # Accumulate gradients

    if (step + 1) % num_accumulation_steps == 0:
        optimizer.step()  # Update after accumulation complete
        optimizer.zero_grad()

# Pattern 2: With learning rate adjustment
# IMPORTANT: The optimizer only sees the accumulated gradient,
# so accumulation is transparent to it.
# effective_batch = batch_size * num_accumulation_steps
# Set LR for the effective batch, NOT the per-step (micro) batch

original_lr = 0.1  # Tuned for batch_size = 32
num_accumulation_steps = 8
effective_batch = 32 * 8  # 256

# Linear scaling rule based on effective batch:
# Batch 32 → 256 is 8x increase
# So LR: 0.1 → 0.8 (8x)
new_lr = original_lr * 8  # 0.8
optimizer = torch.optim.SGD(model.parameters(), lr=new_lr)

# Pattern 3: Distributed training with gradient accumulation
# Per-GPU batch: 32
# Number of GPUs: 8
# Accumulation steps: 4
# Effective batch: 32 * 8 * 4 = 1024

from torch.nn.parallel import DistributedDataParallel as DDP

model = DDP(model)

num_accumulation_steps = 4
optimizer.zero_grad()

for step, (batch, target) in enumerate(train_loader):
    output = model(batch)
    loss = criterion(output, target)
    loss = loss / num_accumulation_steps

    loss.backward()

    if (step + 1) % num_accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

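# Pattern 4: With synchronization (distributed)
class GradientAccumulator:
    def __init__(self, model, num_accumulation_steps, sync_gradients_every=None):
        self.model = model  # Expected to be DDP-wrapped (provides no_sync())
        self.num_steps = num_accumulation_steps
        # Default: sync once per optimizer step, on the last micro-batch
        self.sync_every = sync_gradients_every or num_accumulation_steps
        self.step_count = 0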

    def should_sync_gradients(self):
        """
        In DDP, only sync gradients when we're about to do optimizer.step().
        This reduces communication overhead.
        """
        return (self.step_count + 1) % self.sync_every == 0

    def backward(self, loss):
        loss = loss / self.num_steps

        # Only sync if we're about to step
        if self.should_sync_gradients():
            loss.backward()
        else:
            with self.model.no_sync():  # Skip gradient sync in DDP
                loss.backward()

        self.step_count += 1

Gradient accumulation vs large batch - when to choose:

# When gradient accumulation is GOOD choice:
# 1. Memory-constrained (can't fit large batch)
# 2. Need large effective batch for convergence
# 3. Can tolerate ~1.5-2x slowdown
# 4. Training wall-clock time not critical

# When gradient accumulation is BAD choice:
# 1. Can fit desired batch size in memory
# 2. Training speed is critical (wall-clock matters)
# 3. Already have good convergence with smaller batches
# 4. Task benefits from small-batch gradient noise (accumulation averages it away)

# Comparison table:
#                    LARGE BATCH      GRADIENT ACCUMULATION
# Activation memory  High             Low (1/accumulation_steps)
# Wall-clock time    Fast             ~1.5-2x slower
# Convergence speed  Good             Same (effective batch same)
# Implementation     Simple           Requires manual loop
# Memory savings     None             ~87% of activations (with 8x accumulation)
# When to use        When memory OK   When memory constrained

Pattern 4: Memory Estimation and Optimization

Understanding memory components:

Total GPU Memory = Parameters + Optimizer State + Activations + Gradients

Example: Training BERT-base (110M params) with batch_size=32, seq_len=512

1. PARAMETERS (Fixed)
   - BERT: 110M × 4 bytes (FP32) = 440 MB
   - Or 110M × 2 bytes (FP16) = 220 MB

2. OPTIMIZER STATE (Fixed)
   - SGD: No extra state = 0 MB
   - Adam: m + v = 2 × params = 880 MB (FP32) or 440 MB (FP16)
   - AdamW: Same as Adam

3. ACTIVATIONS (Linear in batch_size, seq_len)
   - Stored during forward pass (for backward)
   - BERT layer, one hidden-state tensor: batch × seq_len × hidden_dim × 4 bytes
   - = 32 × 512 × 768 × 4 bytes
   - = ~50 MB per layer
   - × 12 layers = ~0.6 GB
   - Lower bound: each layer also stores attention scores and intermediate
     tensors, so measured activation memory is several times larger

4. GRADIENTS (Fixed, independent of batch_size)
   - Stored after backward, until optimizer.step()
   - Same size as parameters = 440 MB

TOTAL MEMORY = 440 + 880 + 600 + 440 = ~2.4 GB (lower bound)
Typical budget: Use ~80% GPU = 19 GB with 24GB GPU
Room for more: Can likely increase batch from 32 → 128; confirm by measuring peak memory

Memory calculation framework:

def estimate_memory_usage(
    num_params: int,
    batch_size: int,
    seq_length: int,
    hidden_dim: int,
    num_layers: int,
    dtype_bytes: int = 4,  # 4 for FP32, 2 for FP16
    optimizer: str = "adam",  # or "sgd"
    use_gradient_checkpointing: bool = False,
):
    """
    Estimate memory for training a transformer model.

    Args:
        num_params: Total parameters
        batch_size: Batch size
        seq_length: Sequence length
        hidden_dim: Hidden dimension (for activation estimation)
        num_layers: Number of transformer layers
        dtype_bytes: 4 for FP32, 2 for FP16, 1 for INT8
        optimizer: "sgd" (no state), "adam"/"adamw" (m + v = 2x params)
        use_gradient_checkpointing: If True, reduce activation memory

    Returns:
        Memory in GB

    WHY: Helps choose batch size without trial-and-error OOM
    """

    # 1. Parameter memory
    param_memory = num_params * dtype_bytes

    # 2. Optimizer state
    if optimizer.lower() == "adam":
        opt_memory = 2 * num_params * dtype_bytes  # m + v
    elif optimizer.lower() == "adamw":
        opt_memory = 2 * num_params * dtype_bytes  # m + v
    else:  # SGD
        opt_memory = 0

    # 3. Activation memory (transformer-specific)
    # Lower-bound estimate: one hidden-state tensor stored per layer
    # Per layer: batch × seq_len × hidden_dim × dtype_bytes
    # × num_layers
    activation_memory_per_layer = batch_size * seq_length * hidden_dim * dtype_bytes
    total_activation_memory = activation_memory_per_layer * num_layers

    if use_gradient_checkpointing:
        # With checkpointing: simplified to one layer's activations live at a time
        # (others are recomputed during backward)
        total_activation_memory = activation_memory_per_layer  # Only 1 layer

    # 4. Gradient memory (same as parameter memory)
    gradient_memory = num_params * dtype_bytes

    # Total
    total_bytes = param_memory + opt_memory + total_activation_memory + gradient_memory
    total_gb = total_bytes / (1024**3)

    return total_gb

# Example: BERT training
memory_gb = estimate_memory_usage(
    num_params=110_000_000,  # BERT-base
    batch_size=32,
    seq_length=512,
    hidden_dim=768,
    num_layers=12,
    dtype_bytes=4,  # FP32
    optimizer="adam",
    use_gradient_checkpointing=False,
)
print(f"Memory: {memory_gb:.1f} GB")  # ~2.2 (GiB; about 2.4 decimal GB)

# Optimize by reducing batch
memory_gb_batch16 = estimate_memory_usage(
    num_params=110_000_000,
    batch_size=16,  # 2x smaller
    seq_length=512,
    hidden_dim=768,
    num_layers=12,
    dtype_bytes=4,
    optimizer="adam",
    use_gradient_checkpointing=False,
)
print(f"Memory with batch 16: {memory_gb_batch16:.1f} GB")  # ~1.9

# Optimize by mixed precision
memory_gb_fp16 = estimate_memory_usage(
    num_params=110_000_000,
    batch_size=32,
    seq_length=512,
    hidden_dim=768,
    num_layers=12,
    dtype_bytes=2,  # FP16 instead of FP32
    optimizer="adam",
    use_gradient_checkpointing=False,
)
print(f"Memory with FP16: {memory_gb_fp16:.1f} GB")  # ~1.1

# Optimize with checkpointing
memory_gb_ckpt = estimate_memory_usage(
    num_params=110_000_000,
    batch_size=32,
    seq_length=512,
    hidden_dim=768,
    num_layers=12,
    dtype_bytes=4,
    optimizer="adam",
    use_gradient_checkpointing=True,  # Save only last layer activations
)
print(f"Memory with checkpointing: {memory_gb_ckpt:.1f} GB")  # ~1.7
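
The same memory model can be inverted to ask how large a batch a given GPU budget allows. The sketch below is an illustration only: max_batch_for_budget and its parameters are hypothetical names, and it uses the same lower-bound activation estimate (one hidden-state tensor per layer), so the real limit is typically much lower.

```python
def max_batch_for_budget(budget_gb, num_params, seq_length, hidden_dim, num_layers,
                         dtype_bytes=4, optimizer_states=2, usable_fraction=0.8):
    """
    Largest batch size whose estimated memory fits in the usable budget.

    optimizer_states: extra parameter-sized tensors (2 for Adam, 0 for plain SGD).
    usable_fraction: share of GPU memory to plan for (leaves a safety margin).
    """
    # Fixed cost: parameters + gradients + optimizer state
    fixed_bytes = num_params * dtype_bytes * (2 + optimizer_states)
    # Batch-dependent cost per sample (lower-bound activation estimate)
    per_sample_bytes = seq_length * hidden_dim * dtype_bytes * num_layers
    usable_bytes = budget_gb * 1024**3 * usable_fraction
    return max(0, int((usable_bytes - fixed_bytes) // per_sample_bytes))

# BERT-base on a 24 GB GPU, FP32, Adam
bert_max = max_batch_for_budget(24, 110_000_000, 512, 768, 12)
print(f"Upper bound on batch size: {bert_max}")
```

Treat the result as a ceiling to start a search from, not a batch size to use directly. Measuring peak memory at two or three real batch sizes gives a far better per-sample figure.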

Memory optimization techniques:

# Technique 1: Gradient Checkpointing
# Recompute activations instead of storing them
# Memory: O(sqrt(num_layers)) instead of O(num_layers)
# Cost: ~30% slower training (recompute activations during backward)

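import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    # __init__ (not shown) must define self.attention and self.feedforward
    def forward(self, x):
        # Forward: run _forward without storing its intermediate activations
        # Backward: recompute those activations when gradients are needed
        return checkpoint(self._forward, x, use_reentrant=False)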

    def _forward(self, x):
        x = self.attention(x)
        x = self.feedforward(x)
        return x

# Technique 2: Mixed Precision (FP16)
# Use FP16 for forward+backward (activations take half the memory)
# Keep FP32 master weights (avoids accumulating rounding errors)
# Memory: up to ~50% reduction in activation memory
# Speed: 1.3-2x faster on modern GPUs

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch, target in train_loader:
    optimizer.zero_grad()

    with autocast():  # Automatic FP16 casting
        output = model(batch)
        loss = criterion(output, target)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

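# Technique 3: Low-Precision Weights (INT8/FP8)
# Store weights in INT8 or FP8
# Requires special hardware support
# Memory: ~75% reduction for 8-bit weights vs FP32
# Speed: 2-4x faster
# Note: quantization-aware training proper simulates quantization on FP32
# weights, so its savings show up at inference rather than during training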

# Technique 4: Batch Size Scheduling
# Start with small batch, increase during training
# Reason: Large batch early = poor generalization
# Large batch late = good generalization
# Memory: Gradually increases as needed

def get_adaptive_batch_size(epoch, total_epochs):
    """Increase batch size as training progresses"""
    base_batch = 32
    max_batch = 256

    # Linear increase: start small, end large
    batch_size = base_batch + (max_batch - base_batch) * (epoch / total_epochs)
    return int(batch_size)

Pattern 5: Batch Size Effects on Convergence and Generalization

The generalization gap - why bigger batch = worse accuracy:

Generalization Gap = Test Accuracy (Small Batch) - Test Accuracy (Large Batch)

Why small batch generalizes better:
1. Gradient Noise: Small batch = noisy gradients
   - Noise acts as regularization
   - Forces model to find robust minima
   - Larger noise → larger generalization margin

2. Loss Landscape: SGD with noise explores landscape differently
   - Large batch: Gradient descent (exact gradient)
   - Small batch: Stochastic gradient (noisy)
   - Noise helps escape sharp minima (bad generalization)
   - Leads to flat minima (good generalization)

3. Batch Normalization Interaction:
   - BN computes statistics per batch
   - Larger batch → more stable statistics
   - More stable → less regularization effect
   - Less regularization → worse generalization

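Illustrative numbers (ResNet-50 on ImageNet, LR schedule not retuned for the larger batch):
- Batch 256: 76.0% accuracy
- Batch 1024: 74.8% accuracy (1.2% gap!)
- Batch 4096: 72.0% accuracy (4% gap!)
With linear LR scaling and warmup, Goyal et al. (2017) matched the batch-256 accuracy up to batch 8192.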

The sharp minima problem:

SMALL BATCH (32):
  Loss landscape: Finds FLAT minima
  - Small change in weights → loss increases slowly
  - Generalizes well (robust to input variations)
  - Test accuracy ≈ Train accuracy
  - Variance: Higher (gradient noise)

LARGE BATCH (1024):
  Loss landscape: Finds SHARP minima
  - Small change in weights → loss increases quickly
  - Generalizes poorly (sensitive to input variations)
  - Test accuracy << Train accuracy (overfitting)
  - Variance: Lower (stable gradients)

SOLUTION: Add regularization to large batch training
- L2 regularization (weight decay)
- Dropout
- Data augmentation
- Label smoothing

Batch size effects on different architectures:

# Architecture 1: ResNets (well-studied)
# Batch 256: 76.0% top-1 accuracy (ImageNet)
# Batch 1024: 74.8% (-1.2%)
# Batch 4096: 72.0% (-4%)
# Conclusion: Batch size matters, gap widens as batch grows

# Architecture 2: Vision Transformers
# Batch 512: 82.0% accuracy
# Batch 1024: 81.8% (-0.2%)
# Batch 4096: 81.0% (-1%)
# Conclusion: Less sensitive to batch size (more robust)

# Architecture 3: BERT (Language)
# Batch 128: 89.0% GLUE score
# Batch 256: 88.8% (-0.2%)
# Batch 512: 88.2% (-0.8%)
# Conclusion: Moderate sensitivity

# WHY THE DIFFERENCES?
# - ResNets: Simple architecture, sharp minima
# - Vision Transformers: Attention provides regularization
# - BERT: Pre-training + fine-tuning, already regularized

Empirical guidelines for batch size vs generalization:

# Rule 1: Start with batch 128-256
# Most tasks achieve good accuracy at this range
# Memory reasonable on modern GPUs
# Generalization gap minimal

# Rule 2: If increasing batch size - add regularization
def add_regularization_for_large_batch(batch_size, base_batch=256):
    """Adjust regularization strength for larger batch size"""

    # Start from base: batch 256, weight_decay 0.0001
    # Double batch → increase regularization
    scale_factor = batch_size / base_batch

    weight_decay = 0.0001 * (scale_factor ** 0.5)  # sqrt scale
    dropout = 0.1  # Add dropout
    label_smoothing = 0.1  # Label smoothing helps

    return {
        'weight_decay': weight_decay,
        'dropout': dropout,
        'label_smoothing': label_smoothing,
    }

# Rule 3: Validate on validation set
# Don't assume scaling rule works for accuracy
# Larger batch might need different epochs/learning rate schedule

# Rule 4: Gradient accumulation doesn't help generalization
# Accumulation = large batch for gradient statistics
# The accumulated gradient equals the gradient of the full effective batch
# It just takes longer (multiple forward+backward passes)
# Generalization is the same as if you had memory for the full batch
# Exception: BatchNorm statistics are still computed per micro-batch

Pattern 6: Finding Optimal Batch Size (Not Just Maximum)

The batch size selection framework:

Step 1: Calculate memory budget
  → Max memory available (e.g., 24GB GPU)
  → Estimate parameters + optimizer state
  → Available for batch = Total - (params + opt state)

Step 2: Estimate per-sample memory
  → Run two small batches (e.g., 8 and 16), measure peak memory for each
  → Per-sample = (memory at 16 - memory at 8) / 8, which cancels the fixed cost
  → Max batch = Available Memory / per-sample

Step 3: Find memory-safe batch
  → Use 80% of max (leaves margin)
  → This is maximum batch that's safe

Step 4: Check convergence at maximum batch
  → Train model with maximum safe batch
  → Compare accuracy to smaller batches
  → If >2% accuracy drop: reduce batch or add regularization

Step 5: Optimize for wall-clock time
  → Profile training time at different batch sizes
  → Wall-clock = (iterations) × (time per iteration)
  → Iterations = (samples / batch) × epochs
  → Find batch that minimizes wall-clock time
  → Often NOT the maximum batch!

Step 6: Select based on task requirements
  → If convergence matters more: smaller batch
  → If speed matters more: larger batch
  → If memory constrained: gradient accumulation
  → If fine-tuning: smaller batch (preserve pre-training)
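
Steps 2 and 5 come down to a few lines of arithmetic. The sketch below uses made-up measurements (the memory readings and per-iteration timings are hypothetical, as are the helper names) to show how the per-sample memory and the wall-clock estimate would be computed from profiling data.

```python
def per_sample_memory_mb(mem_small_mb, mem_large_mb, batch_small, batch_large):
    """Step 2: the difference of two measurements cancels the fixed memory cost."""
    return (mem_large_mb - mem_small_mb) / (batch_large - batch_small)

def estimated_hours(num_samples, epochs, batch_size, seconds_per_iteration):
    """Step 5: wall-clock = iterations x time per iteration."""
    iterations_per_epoch = -(-num_samples // batch_size)  # ceiling division
    return iterations_per_epoch * epochs * seconds_per_iteration / 3600

# Hypothetical peak-memory readings at batch 8 and batch 16
per_sample = per_sample_memory_mb(2400.0, 2720.0, 8, 16)

# Hypothetical profiled seconds per iteration at each batch size
profile = {64: 0.050, 128: 0.085, 256: 0.160, 512: 0.330}
hours = {bs: estimated_hours(50_000, 100, bs, t) for bs, t in profile.items()}
fastest = min(hours, key=hours.get)

print(f"Per-sample memory: {per_sample:.0f} MB")
for bs, h in hours.items():
    print(f"Batch {bs}: {h:.2f} h")
print(f"Fastest batch size: {fastest}")
```

With these invented timings the fastest setting is batch 256 rather than 512, because time per iteration grows faster than the iteration count falls. Real profiles vary by model and hardware, so the numbers here only illustrate the calculation.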

Implementation:

def find_optimal_batch_size(
    model,
    train_loader,
    criterion,
    device,
    target_accuracy=None,
    time_budget_seconds=None,
):
    """
    Find optimal batch size by profiling at different sizes.

    Args:
        model: PyTorch model to profile
        train_loader: DataLoader
        criterion: Loss function
        device: torch.device
        target_accuracy: If set, find batch that achieves this
        time_budget_seconds: If set, find fastest batch within budget

    Returns:
        Optimal batch size, profiling results

    WHY: Maximum batch ≠ optimal batch
    """

    batch_sizes = [32, 64, 128, 256, 512]
    results = {}

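    for batch_size in batch_sizes:
        # Measure memory for this batch size
        # NOTE: train_loader must yield batches of at least max(batch_sizes)
        # samples; slicing a shorter batch silently profiles a smaller size
        try:
            batch, target = next(iter(train_loader))
            batch = batch[:batch_size].to(device)
            target = target[:batch_size].to(device)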

            torch.cuda.reset_peak_memory_stats(device)
            with torch.cuda.device(device):
                output = model(batch)
                loss = criterion(output, target)
                loss.backward()

            memory_mb = torch.cuda.max_memory_allocated(device) / (1024 ** 2)

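            # Measure iteration time (synchronize so queued GPU work is counted)
            import time
            torch.cuda.synchronize(device)
            start = time.time()
            for _ in range(10):
                model.zero_grad(set_to_none=True)
                output = model(batch)
                loss = criterion(output, target)
                loss.backward()
            torch.cuda.synchronize(device)
            iteration_time = (time.time() - start) / 10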

            # Calculate total training time
            # Assume 100 epochs, 50k samples
            iterations_per_epoch = 50000 // batch_size
            total_iterations = iterations_per_epoch * 100
            total_time = total_iterations * iteration_time

            results[batch_size] = {
                'memory_mb': memory_mb,
                'iteration_time_ms': iteration_time * 1000,
                'total_time_hours': total_time / 3600,
            }

        except RuntimeError as e:
            results[batch_size] = {'error': str(e)}

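    # Find optimal based on criteria
    if target_accuracy is not None:
        # Accuracy is not measured here; as a proxy, return the smallest
        # batch that profiled without error (train and validate to confirm)
        ok = [bs for bs, r in results.items() if 'error' not in r]
        return min(ok) if ok else None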
    elif time_budget_seconds is not None:
        # Choose largest batch within time budget
        valid = {bs: r for bs, r in results.items()
                if 'error' not in r and r['total_time_hours'] * 3600 < time_budget_seconds}
        return max(valid.keys()) if valid else None
    else:
        # Default: choose largest batch within 80% memory limit
        memory_limit = 0.8 * torch.cuda.get_device_properties(device).total_memory / (1024**2)
        valid = {bs: r for bs, r in results.items()
                if 'error' not in r and r['memory_mb'] < memory_limit}
        return max(valid.keys()) if valid else None

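# Batch size discovery loop
def discover_optimal_batch_size(model, train_loader, criterion, device):
    """
    Progressive batch size search starting from small.

    Pattern: Double batch size until OOM, then back off.
    NOTE: train_loader must yield batches larger than the sizes to test;
    the search stops once it runs out of samples to slice.
    """
    full_batch, full_target = next(iter(train_loader))
    limit = len(full_batch)  # Cannot test sizes beyond what the loader yields

    def fits(size):
        """Run one forward+backward at `size`; return False on OOM."""
        try:
            model.zero_grad(set_to_none=True)
            batch = full_batch[:size].to(device)
            target = full_target[:size].to(device)
            loss = criterion(model(batch), target)
            loss.backward()
            return True
        except RuntimeError as e:
            if "out of memory" not in str(e).lower():
                raise
            torch.cuda.empty_cache()  # Release cached blocks before retrying
            return False

    batch_size = 8
    if limit < batch_size:
        raise ValueError("train_loader batches are smaller than 8 samples")

    prev_batch = None  # Last size that worked
    while batch_size <= limit and fits(batch_size):
        print(f"✓ Batch {batch_size} works")
        prev_batch = batch_size
        batch_size *= 2

    if prev_batch is None:
        raise RuntimeError("Batch size 8 does not fit in memory")
    if batch_size > limit:
        return prev_batch  # Ran out of samples to test, not out of memory

    print(f"✗ Batch {batch_size} OOM")
    # Check the midpoint between the last working size and the failed size
    test_batch = int(prev_batch * 1.5)
    if fits(test_batch):
        print(f"✓ Batch {test_batch} also works, use this")
        return test_batch
    print(f"→ Use batch size {prev_batch} (safe margin)")
    return prev_batch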

Batch size selection by use case:

# Use Case 1: Maximum accuracy matters (research, publication)
# → Choose smaller batch (128-256)
# → More gradient noise = better generalization
# → Willing to train longer if accuracy is better

optimal_batch_size = 128

# Use Case 2: Training speed matters (prototyping, iteration)
# → Choose larger batch (512-1024)
# → Trade some accuracy for wall-clock speed
# → Need to add regularization to reduce generalization gap

optimal_batch_size = 512
regularization_strength = 'strong'  # weight_decay, dropout

# Use Case 3: Memory severely constrained (mobile, edge)
# → Choose small batch (16-32)
# → Use gradient accumulation to simulate larger batch
# → Accept lower accuracy if necessary

optimal_batch_size = 16
accumulation_steps = 8  # Simulate batch 128

# Use Case 4: Fine-tuning small dataset
# → Choose small batch (16-32)
# → Preserve pre-training (smaller updates)
# → Larger batch risks forgetting pre-trained knowledge

optimal_batch_size = 16

# Use Case 5: Large model, large dataset
# → Choose medium-large batch (256-512)
# → Gradient accumulation only if a larger effective batch is needed
# → Mixed precision for memory savings

optimal_batch_size = 256
use_mixed_precision = True
use_gradient_accumulation = False  # Fits with mixed precision

# Use Case 6: Distributed training (multiple GPUs/TPUs)
# → Per-GPU batch: 32-64
# → Accumulation: 4-8 steps
# → Total effective: per_gpu * num_gpus * accumulation
# → Large total effective batch, small per-GPU batch

per_gpu_batch = 64
num_gpus = 8
accumulation_steps = 4
effective_batch = 64 * 8 * 4  # 2048
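
The effective-batch arithmetic in Use Case 6 can also be run in reverse: given a target effective batch, work out how many accumulation steps are needed. A minimal sketch; the function name `accumulation_steps_for` is hypothetical and not part of any library.

```python
import math

def accumulation_steps_for(target_effective_batch, per_gpu_batch, num_gpus=1):
    """Smallest number of accumulation steps that reaches the target effective batch."""
    if per_gpu_batch <= 0 or num_gpus <= 0:
        raise ValueError("per_gpu_batch and num_gpus must be positive")
    per_pass = per_gpu_batch * num_gpus  # samples seen per forward/backward pass
    return max(1, math.ceil(target_effective_batch / per_pass))

steps = accumulation_steps_for(2048, per_gpu_batch=64, num_gpus=8)
print(steps, 64 * 8 * steps)  # 4 2048
```

When the target is not an exact multiple, the result rounds up, so the effective batch may slightly exceed the target.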

Common Pitfalls

Pitfall 1: Confusing Maximum Batch with Optimal Batch

Symptom: "I have 24GB memory, so I should use the largest batch that fits" → Why it breaks: Larger batch = worse generalization. Maximum batch might achieve 2-3% lower accuracy. → Fix: Use 80% of maximum batch size, validate accuracy, adjust if needed.

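# WRONG
max_batch = find_max_batch_that_fits(model, memory=24_000_000_000)
train(model, batch_size=max_batch)  # Likely to generalize worse

# CORRECT
max_batch = find_max_batch_that_fits(model, memory=24_000_000_000)
safe_batch = int(max_batch * 0.8)  # 80% of maximum
train(model, batch_size=safe_batch)
accuracy = validate_accuracy(model)  # Check if acceptable
# baseline_accuracy: accuracy of a small-batch reference run (assumed available)
accuracy_drop = baseline_accuracy - accuracy
if accuracy_drop > 0.02:  # More than 2 points lost
    safe_batch = int(safe_batch * 0.8)
    add_regularization()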

Pitfall 2: Ignoring Learning Rate Scaling

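Symptom: "I increased my batch size and now training is much slower to converge" → Why it breaks: A larger batch means fewer optimizer updates per epoch, and each update uses a less noisy gradient, so an unchanged learning rate makes less progress per epoch. The learning rate must increase proportionally. → Fix: Use linear scaling rule: new_lr = old_lr × (new_batch / old_batch)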

# WRONG
batch_size = 64
learning_rate = 0.001

# Increase batch without adjusting LR
batch_size = 256
# Learning rate still 0.001 - too small!
# Gradient updates too conservative, very slow convergence

# CORRECT
batch_size = 64
learning_rate = 0.001

batch_size = 256
learning_rate = 0.001 * (256 / 64)  # Scale by 4x
# = 0.004

Pitfall 3: Using Huge Learning Rate Without Warmup

Symptom: "I scaled my learning rate by 10x and now training diverges immediately" → Why it breaks: Very large learning rate jumps cause instability. Model can't adapt. → Fix: Add linear warmup phase: gradually increase LR from 0 to scaled value.

# WRONG
from torch.optim import SGD

scaled_lr = 0.001 * 10  # 0.01
optimizer = SGD(model.parameters(), lr=scaled_lr)
for epoch in range(100):
    for batch, target in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(batch), target)
        loss.backward()
        optimizer.step()  # Full scaled LR from the first iteration: can diverge

# CORRECT
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

base_lr = 0.001
scaled_lr = base_lr * 10  # 0.01
warmup_steps = 1000

def lr_lambda(step):
    # Multiplier applied to scaled_lr: ramps 0 → 1 over warmup, then stays at 1
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))
    return 1.0

optimizer = SGD(model.parameters(), lr=scaled_lr)
scheduler = LambdaLR(optimizer, lr_lambda)

for epoch in range(100):
    for batch, target in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(batch), target)
        loss.backward()
        optimizer.step()
        scheduler.step()
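
A warmup schedule is easy to get subtly wrong (for example, a multiplier that drops back to the base LR once warmup ends). One cheap check is to evaluate the schedule at a few steps before training. A minimal sketch in plain Python; `lr_at_step` is a hypothetical helper, and the 10x scaled LR and 1000 warmup steps are taken from the example above.

```python
def lr_at_step(step, scaled_lr=0.01, warmup_steps=1000):
    """LR under linear warmup: ramps from 0 to scaled_lr, then holds."""
    if step < warmup_steps:
        return scaled_lr * step / max(1, warmup_steps)
    return scaled_lr

for step in (0, 250, 500, 1000, 5000):
    print(f"step {step:>5}: lr = {lr_at_step(step):.4f}")
```

The LR should rise monotonically during warmup and never fall below the scaled value afterwards, unless a decay schedule is layered on top.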

Pitfall 4: Gradient Accumulation Without LR Adjustment

Symptom: "I added gradient accumulation but training is much slower to converge" → Why it breaks: Accumulation itself doesn't require LR change, but if effective batch increased, LR should too. → Fix: Adjust LR based on effective batch size, not per-GPU batch size.

# WRONG
batch_size = 32  # Per-GPU
num_accumulation = 8
# Learning rate still tuned for batch 32

# Effective batch = 32 × 8 = 256
# But LR not scaled for batch 256
# Convergence slower because LR too conservative

# CORRECT
batch_size = 32
num_accumulation = 8
effective_batch = batch_size * num_accumulation  # 256

# Get LR for batch 32
base_lr_batch32 = 0.001

# Scale for batch 256
lr_batch256 = base_lr_batch32 * (256 / 32)  # 0.008
optimizer = SGD(model.parameters(), lr=lr_batch256)
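
Why the effective batch is the right reference for LR: when each micro-batch loss is divided by the number of accumulation steps, the accumulated gradient equals the gradient of one large batch. The sketch below checks this in plain Python for a one-parameter model `y = w * x` with mean squared error. The model, data, and the helper `mse_grad` are illustrative assumptions, not part of the skill.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean squared error w.r.t. w for the model y = w * x."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n

xs = [float(i) for i in range(1, 257)]  # 256 samples
ys = [3.0 * x + 1.0 for x in xs]
w = 0.5

full_batch_grad = mse_grad(w, xs, ys)   # one batch of 256

micro_batch, accumulation_steps = 32, 8
accumulated_grad = 0.0
for k in range(accumulation_steps):
    lo, hi = k * micro_batch, (k + 1) * micro_batch
    # Dividing by accumulation_steps mirrors `loss = loss / accumulation_steps`
    accumulated_grad += mse_grad(w, xs[lo:hi], ys[lo:hi]) / accumulation_steps

print(full_batch_grad, accumulated_grad)
```

The two values agree up to floating-point error, so the optimizer sees a batch-256 gradient and the LR should be chosen for batch 256. Layers with batch-dependent statistics (BatchNorm) are the exception: they still see only 32 samples per pass.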

Pitfall 5: Assuming Batch Size Doesn't Affect Accuracy

Symptom: "Batch size only affects speed, not accuracy" → Why it breaks: Batch size strongly affects generalization (1-4% gap is common). → Fix: Always validate final accuracy at different batch sizes. Larger batch might need different hyperparameters.

# WRONG - assume accuracy independent of batch
batch_sizes = [64, 256, 1024]
for batch_size in batch_sizes:
    model = train(batch_size=batch_size, learning_rate=0.001)  # Same LR for all!
    accuracy = evaluate(model)
    # Accuracy will differ significantly!

# CORRECT - adjust hyperparameters per batch
batch_sizes = [64, 256, 1024]
for batch_size in batch_sizes:
    lr = 0.001 * (batch_size / 64)  # Scale LR
    weight_decay = 0.0001 * (batch_size / 64) ** 0.5  # Increase regularization
    model = train(batch_size=batch_size, learning_rate=lr, weight_decay=weight_decay)
    accuracy = evaluate(model)
    # More consistent accuracy across batch sizes

Pitfall 6: Not Considering Synchronous vs Asynchronous Batch Norm

Symptom: "My distributed training accuracy is much worse than single-GPU" → Why it breaks: Batch norm computes statistics per batch. Distributed training with small per-GPU batch = incorrect statistics. → Fix: Use SyncBatchNorm for correct statistics across all GPUs.

# WRONG - data parallel with per-GPU (unsynchronized) BN
from torch.nn.parallel import DataParallel

model = DataParallel(model, device_ids=[0, 1, 2, 3])
# Each GPU has batch_size=32
# BN computes stats from only its 32 samples
# Stats noisier than with the full batch of 128

# CORRECT - Synchronous batch norm
from torch.nn.modules.batchnorm import SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel

# Assumes torch.distributed.init_process_group() has already been called
model = SyncBatchNorm.convert_sync_batchnorm(model)
model = DistributedDataParallel(model, find_unused_parameters=False)
# Each GPU: batch 32, but BN aggregates across all 4 GPUs = 128
# Stats computed from all 128 samples, stable
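
The difference between per-GPU and synchronized statistics can be seen without any GPUs. The sketch below is a conceptual illustration in plain Python, not how `SyncBatchNorm` is implemented internally: it splits one feature's activations for a global batch of 128 across 4 simulated GPUs, then compares each shard's own mean with statistics pooled across shards. The random data and all variable names are assumptions.

```python
import random
import statistics

random.seed(0)
num_gpus, per_gpu_batch = 4, 32
# One feature's activations for a global batch of 128, split across 4 GPUs
shards = [[random.gauss(0.0, 1.0) for _ in range(per_gpu_batch)]
          for _ in range(num_gpus)]

# Unsynchronized BN: each GPU normalizes with its own 32-sample mean
per_gpu_means = [statistics.fmean(s) for s in shards]

# Synchronized BN (conceptually): pool sums and squared sums across GPUs
total = num_gpus * per_gpu_batch
global_mean = sum(sum(s) for s in shards) / total
global_var = sum(x * x for s in shards for x in s) / total - global_mean ** 2

print("per-GPU means:", [round(m, 3) for m in per_gpu_means])
print("global mean:", round(global_mean, 3), "global var:", round(global_var, 3))
```

Each GPU's mean deviates from the pooled value, and the deviation grows as the per-GPU batch shrinks.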

Pitfall 7: Gradient Accumulation Too Large (>16x)

Symptom: "I'm accumulating gradients over 32 steps but training diverges" → Why it breaks: Weights stay fixed between micro-batches, so accumulated gradients are not stale; the trouble comes from what heavy accumulation implies. Micro-batches are tiny (BatchNorm statistics from 4 samples are noisy), every optimizer update costs 32 forward/backward passes, and LR and warmup are often left tuned for the micro-batch instead of the effective batch. → Fix: Keep accumulation ≤ 16x. Use distributed training for larger effective batches.

# WRONG - excessive accumulation
batch_size = 4
accumulation_steps = 32  # Effective batch 128 built from micro-batches of 4
# BatchNorm statistics from 4 samples are very noisy
# 32 forward/backward passes per optimizer update

# CORRECT - reasonable accumulation
batch_size = 32
accumulation_steps = 8  # Effective batch 256, acceptable
# Micro-batch large enough for usable BatchNorm statistics
# Only 8 forward/backward passes per optimizer update

# OR use distributed training instead
per_gpu_batch = 32
num_gpus = 8
effective_batch = per_gpu_batch * num_gpus  # 256, same as above, but no accumulation
# One pass per update on each GPU: faster wall-clock

Pitfall 8: Mixing Gradient Accumulation with Exponential Moving Average (EMA)

Symptom: "I'm using gradient accumulation with learning rate scheduler and EMA, but training is unstable" → Why it breaks: EMA expects one update per optimizer step. With accumulation the weights do not change between optimizer steps, so updating EMA on every backward pass averages the same weights repeatedly and shortens the effective averaging window. → Fix: Update EMA only when you call optimizer.step(), not every backward pass.

# WRONG - updating EMA every backward pass
ema_model = ExponentialMovingAverage(model.parameters(), decay=0.999)

for step, (batch, target) in enumerate(train_loader):
    loss = criterion(model(batch), target) / accumulation_steps
    loss.backward()

    ema_model.update(model.parameters())  # Called every iteration!

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# CORRECT - update EMA only on optimizer.step()
ema_model = ExponentialMovingAverage(model.parameters(), decay=0.999)

for step, (batch, target) in enumerate(train_loader):
    loss = criterion(model(batch), target) / accumulation_steps
    loss.backward()

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        ema_model.update(model.parameters())  # Only here!
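
The cost of updating EMA on every backward pass can be quantified. Because the weights are constant between optimizer steps, applying the EMA update k times with the same weights is equivalent to a single update with decay `decay ** k`. A scalar sketch in plain Python; `ema_update` is a hypothetical stand-in for an EMA library's update rule.

```python
def ema_update(ema, weight, decay):
    """One exponential-moving-average update."""
    return decay * ema + (1.0 - decay) * weight

decay, accumulation_steps = 0.999, 8
ema_start, weight = 0.0, 1.0  # weight is constant between optimizer steps

# EMA updated on every backward pass (8 times with the same weight)
ema_every_pass = ema_start
for _ in range(accumulation_steps):
    ema_every_pass = ema_update(ema_every_pass, weight, decay)

# EMA updated once per optimizer.step()
ema_per_step = ema_update(ema_start, weight, decay)

effective_decay = decay ** accumulation_steps
print(ema_per_step, ema_every_pass, effective_decay)
```

With 8 accumulation steps, a nominal decay of 0.999 behaves like roughly 0.992, so the averaged model tracks the raw weights about 8 times faster than intended.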

Pitfall 9: Batch Size Doubling Without Validation

Symptom: "I increased batch from 64 to 128 based on linear scaling rule, but accuracy dropped 2%" → Why it breaks: Linear scaling rule gives convergence rate, not accuracy guarantee. Generalization gap widens. → Fix: Always validate on holdout set when changing batch size. Accept accuracy drop or add regularization.

# WRONG - assume linear scaling guarantees accuracy
original_batch = 64
original_lr = 0.001
original_accuracy = 0.85

new_batch = 128
new_lr = 0.001 * (128 / 64)  # 0.002
new_accuracy = 0.83  # Dropped 2%! Should have validated first

# CORRECT - validate and adjust regularization if needed
new_batch = 128
new_lr = 0.001 * (128 / 64)
model = train(lr=new_lr, batch=new_batch)
val_accuracy = validate(model)

if val_accuracy < 0.84:  # Acceptable drop?
    # Add regularization for larger batch
    model = train(
        lr=new_lr,
        batch=new_batch,
        weight_decay=0.0002,  # Increase
        dropout=0.2,  # Add/increase
    )
    val_accuracy = validate(model)

Pitfall 10: Using Maximum Batch in Fine-tuning

Symptom: "I fine-tuned with large batch size and catastrophically forgot pre-training" → Why it breaks: A large batch is usually paired with a large (scaled) learning rate, so each update is large and pre-trained weights are overwritten too quickly. → Fix: Fine-tuning requires SMALLER batch size (16-32) and smaller learning rate than pre-training.

# WRONG - fine-tuning with large batch
pretrained_model = load_pretrained_bert()
batch_size = 512  # Large!
learning_rate = 0.001  # Too large!

model = fine_tune(pretrained_model, batch_size=512, lr=0.001)
# Overfit to task, forget pre-trained knowledge
# Pre-training lost!

# CORRECT - conservative fine-tuning
pretrained_model = load_pretrained_bert()
batch_size = 32  # Small, conservative
learning_rate = 0.00001  # Tiny, preserve pre-training

model = fine_tune(
    pretrained_model,
    batch_size=batch_size,
    lr=learning_rate,
    weight_decay=0.001,  # Strong L2 regularization
)
# Preserves pre-training knowledge, adapts carefully

Practical Decision Framework

Quick Batch Size Decision Tree

1. How much GPU memory do you have?
   ├─ < 8 GB: Start with batch 16-32
   ├─ 8-16 GB: Start with batch 32-64
   ├─ 16-24 GB: Start with batch 64-128
   └─ 24+ GB: Start with batch 128-256

2. Can you fit your target batch in memory?
   ├─ Yes: Use it (with LR scaling)
   ├─ No, by <2x: Use gradient accumulation
   └─ No, by >2x: Use smaller batch + stronger regularization

3. Is accuracy your priority or speed?
   ├─ Accuracy: Use smaller batch (32-128)
   ├─ Speed: Use larger batch (256-1024)
   └─ Both: Gradient accumulation + mixed precision

4. Are you fine-tuning or training from scratch?
   ├─ Fine-tuning: Use small batch (16-32), small LR
   └─ From scratch: Use medium batch (64-256), scale LR

5. Are you using distributed training?
   ├─ Yes: Per-GPU batch 32-64, accumulate for effective 256-512
   └─ No: Single GPU batch 64-256
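
Steps 1 and 4 of the tree can be written as a small lookup. This is a sketch under two assumptions that the tree leaves open: a memory size on a bracket boundary (8, 16, 24 GB) falls into the higher bracket, and fine-tuning overrides the memory-based range. The function name `suggest_starting_batch` is hypothetical.

```python
def suggest_starting_batch(gpu_memory_gb, fine_tuning=False):
    """Return a (low, high) starting batch range from the decision tree."""
    if fine_tuning:
        return (16, 32)  # step 4: small batch, small LR
    if gpu_memory_gb < 8:
        return (16, 32)
    if gpu_memory_gb < 16:
        return (32, 64)
    if gpu_memory_gb < 24:
        return (64, 128)
    return (128, 256)

for memory_gb in (6, 12, 24, 40):
    print(memory_gb, "GB ->", suggest_starting_batch(memory_gb))
```

The range is only a starting point; the remaining steps (fit in memory, accuracy vs speed, distributed setup) still decide the final value.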

Red Flags - Stop and Clarify

| Excuse | Reality | What To Do |
|---|---|---|
| "Just use the maximum batch that fits" | Worse generalization likely. Need to validate accuracy. | Measure accuracy at 80% of max, validate trade-offs. |
| "Linear scaling rule means I don't need to validate" | Rule gives convergence rate, not accuracy guarantee. Generalization gap exists. | Always validate final accuracy with new batch size. |
| "Gradient accumulation is just for memory-constrained settings" | It's a legitimate technique with trade-offs (slowness) worth understanding. | Use when memory constrained; understand slowdown cost. |
| "Batch size only affects speed, not accuracy" | Incorrect. Batch size strongly affects final accuracy (1-4% typical gap). | Always measure accuracy, expect gap, add regularization. |
| "I'll use the batch size from a paper, it should work" | Different model, data, hardware - need to validate. | Use paper as starting point, but validate and adjust. |
| "Larger batch = faster training" | Depends on what you measure (iterations vs epochs vs wall-clock). | Measure actual wall-clock time at different batch sizes. |
| "Just double the learning rate when doubling batch" | Linear scaling rule requires warmup for large increases. | Add warmup phase, measure convergence. |
| "Fine-tuning works same as pre-training, just different data" | Fine-tuning needs much smaller batch and LR (preserve pre-training). | Use batch 16-32, LR 10-100x smaller than pre-training. |

Advanced Patterns: Batch Size Optimization in Production

Pattern 7: Batch Size Scheduling During Training

Increasing batch size as training progresses - when and why:

# Intuition: Start with small batch (good generalization),
# increase later (finish training faster)

def get_scheduled_batch_size(epoch, total_epochs, base_batch=32, max_batch=256):
    """
    Increase batch size linearly with epochs.

    WHY: Start small for generalization, increase for speed later.
    Research shows this works well for long training.
    """
    # Linear increase: 0 → 100% over training
    scale = epoch / total_epochs
    return int(base_batch + (max_batch - base_batch) * scale)

# Usage in training loop
for epoch in range(total_epochs):
    batch_size = get_scheduled_batch_size(epoch, total_epochs)

    # Adjust learning rate once per epoch to match the new batch size
    lr = 0.001 * (batch_size / 32)  # Scale with batch
    update_learning_rate(optimizer, lr)

    for batch, target in get_data_loader(batch_size=batch_size):
        optimizer.zero_grad()
        output = model(batch)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# Alternative: exponential schedule
def get_exponential_batch_schedule(epoch, total_epochs, base_batch=32, max_batch=256):
    """Exponential increase instead of linear"""
    scale = epoch / total_epochs
    return int(base_batch * (max_batch / base_batch) ** scale)
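
To see how the two schedules differ, it helps to print them side by side. The sketch below restates both in self-contained form; the names `linear_batch` and `exponential_batch` are hypothetical, and `total_epochs` is passed explicitly as an assumption so neither function depends on a global.

```python
def linear_batch(epoch, total_epochs, base_batch=32, max_batch=256):
    """Batch size grows by a constant amount per epoch."""
    scale = epoch / total_epochs
    return int(base_batch + (max_batch - base_batch) * scale)

def exponential_batch(epoch, total_epochs, base_batch=32, max_batch=256):
    """Batch size grows by a constant factor per epoch."""
    scale = epoch / total_epochs
    return int(base_batch * (max_batch / base_batch) ** scale)

total_epochs = 100
for epoch in (0, 25, 50, 75, 99):
    print(epoch, linear_batch(epoch, total_epochs),
          exponential_batch(epoch, total_epochs))
```

The exponential schedule stays at small batches for longer and grows late, which suits the case where early generalization matters most.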

When batch size scheduling is valuable:

GOOD FIT:
- Long training (100+ epochs)
- Starting generalization is important
- Speed only matters at end
- Example: ResNet on ImageNet

NOT NEEDED:
- Short training (10-20 epochs)
- Already regularized enough (BERT fine-tuning)
- Batch size well-chosen from start

Pattern 8: Batch Size vs Other Hyperparameters

Understanding interactions with other hyperparameters:

# Interaction 1: Batch size ↔ Learning rate
# Already covered: linear scaling rule

# Interaction 2: Batch size ↔ Weight decay
# Larger batch → worse generalization
# Solution: Increase weight decay when increasing batch
# Typical: weight_decay ~ sqrt(batch_size)

def adjust_weight_decay(base_wd=0.0001, base_batch=256, new_batch=512):
    """Scale weight decay with batch size"""
    return base_wd * (new_batch / base_batch) ** 0.5

# Example
wd_batch_256 = 0.0001
wd_batch_512 = adjust_weight_decay(wd_batch_256, 256, 512)  # 0.000141

# Interaction 3: Batch size ↔ Dropout
# Larger batch → add/increase dropout
# Dropout magnitude depends on layer, typically 0.1-0.5

def adjust_dropout(base_dropout=0.1, base_batch=256, new_batch=512):
    """Increase dropout for larger batches"""
    # Dropout strength ~ sqrt(batch_size)
    scale = (new_batch / base_batch) ** 0.5
    return min(base_dropout * scale, 0.5)  # Cap at 0.5

# Interaction 4: Batch size ↔ Number of epochs
# Larger batch → fewer iterations per epoch → more epochs needed to converge
# Iterations = (samples / batch) × epochs
# Keeping iterations constant is the upper bound: batch 4x → epochs 4x
# With LR scaling, batch 4x typically needs epochs 1.5-2x to match convergence

base_batch = 64
base_epochs = 100
base_iterations = (50000 / base_batch) * base_epochs  # 78,125 total iterations

new_batch = 256
# To maintain same iterations:
new_epochs = base_iterations / (50000 / new_batch)  # 400 epochs
# With a scaled LR, 150-200 epochs (1.5-2x) is the typical range:
# fewer total iterations, so wall-clock is faster despite more epochs

# Interaction 5: Batch size ↔ Optimizer choice
# SGD: works well at all batch sizes
# Momentum: accumulates larger steps, works best with smaller batch
# Adam: adaptive, less sensitive to batch size
# RMSprop: similar to Adam

# Recommendation:
# - Small batch (32-128): SGD with momentum or Adam
# - Large batch (512+): Adam (more stable) or SGD with warmup + large LR

# Interaction 6: Batch size ↔ Normalization technique
# Batch Norm: statistics from batch, larger batch = better stats
# Layer Norm: independent of batch size
# Group Norm: middle ground, works well with any batch size

# If using BatchNorm with small batch (< 16):
# → Use SyncBatchNorm across devices
# → Or use GroupNorm instead

# If using BatchNorm with large batch (> 1024):
# → Standard BatchNorm fine
# → May want to reduce BN momentum (accumulate stats slower)

Rationalization Table: Common Excuses About Batch Size

| Rationalization | Why It's Wrong | Correct Approach |
|---|---|---|
| "Larger batch is always better for speed" | Wall-clock time depends on iterations AND time-per-iteration. Larger batch may have lower throughput. | Profile wall-clock time at different batch sizes, choose fastest. |
| "I'll tune batch size last, it's not important" | Batch size affects convergence rate, generalization, and stability early. Tuning last wastes time. | Choose good batch size early (based on memory), validate accuracy. |
| "Maximum batch that fits = optimal batch" | Generalization gap widens with batch size (1-4% typical). Maximum might miss accuracy target. | Use 80% of max, validate on validation set, adjust if needed. |
| "Linear scaling rule means I don't validate" | Scaling rule gives convergence rate. Accuracy still varies with batch size due to generalization gap. | Always validate test/validation accuracy with new batch. |
| "Gradient accumulation is slow, don't use it" | True, it's slower (1.5-2x). But if memory is the bottleneck, it may be the only alternative. Choose based on constraints. | Use when memory constrained. Accept slowdown. Don't use if memory OK. |
| "I don't need warmup, I'll just use scaled LR" | Large LR jumps cause divergence. Warmup prevents this. | Add linear warmup phase for scaled LR. |
| "My paper used batch X, I'll use that" | Different model, data, hardware converge differently. Paper batch might not be optimal for you. | Use paper as starting point. Validate and adjust for your setup. |
| "Fine-tuning uses same batch as pre-training" | Fine-tuning needs much smaller batch (preserve knowledge). Using pre-training batch erases pre-training. | Use batch 10-20x smaller than pre-training. Use tiny LR. |
| "Batch size only affects speed, not accuracy" | Batch size strongly affects generalization (1-4% gap common). Different final accuracy with different batch. | Expect accuracy variation with batch. Validate at each batch size. |
| "I increased batch, why is training slower?" | Fewer iterations (good) but longer per-iteration (bad). Total wall-clock = iterations × time-per-iteration. | Profile actual wall-clock time. May need gradient accumulation. |
| "I'll start with large batch to save time" | Large batch → bad generalization early → harder to recover later. Start small, increase if needed. | Start with batch 32-64, increase during training if memory allows. |

Comprehensive Example: Training a Vision Transformer

Let's put it all together with a real example:

import math

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torchvision import models, datasets, transforms

def train_vision_transformer_optimized():
    """
    Complete example: training Vision Transformer with batch size optimization.
    """

    # Step 1: Model and data
    device = torch.device("cuda:0")
    model = models.vit_b_16(weights=None).to(device)  # Train from scratch
    criterion = torch.nn.CrossEntropyLoss()

    # Dataset (ImageNet-scale)
    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                           (0.229, 0.224, 0.225))
    ])

    # Step 2: Determine batch size
    # ViT-Base: 86M parameters
    # GPU: 40GB A100
    # Memory estimate: params (344MB) + gradients (344MB) + optimizer (688MB) + activations
    # Can fit batch 256-512

    base_batch = 256
    num_accumulation_steps = 1  # Can fit directly
    effective_batch = base_batch * num_accumulation_steps

    # The dataset path is a placeholder (assumption): point it at your own data
    train_dataset = datasets.ImageFolder("path/to/imagenet/train", transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=base_batch, shuffle=True,
                              num_workers=8, pin_memory=True)

    # Step 3: Initialize optimizer with scaled LR
    # Base LR tuned for batch 256
    base_lr = 1e-4
    scaled_lr = base_lr * (effective_batch / 256)  # 1e-4 (no scaling needed)

    optimizer = AdamW(model.parameters(), lr=scaled_lr, weight_decay=0.05)

    # Step 4: Warmup scheduler
    warmup_steps = 1000
    total_steps = 100 * len(train_dataset) // effective_batch

    def warmup_cosine_schedule(step):
        # Linear warmup 0 → 1, then cosine decay 1 → 0
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))

    scheduler = LambdaLR(optimizer, warmup_cosine_schedule)

    # Step 5: Training loop with gradient accumulation (even though not needed)
    # Good practice for larger models
    model.train()
    optimizer.zero_grad()

    for epoch in range(100):
        for step, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)

            # Forward + backward
            logits = model(images)
            loss = criterion(logits, labels)
            loss = loss / num_accumulation_steps
            loss.backward()

            # Update on accumulation step
            if (step + 1) % num_accumulation_steps == 0:
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if step % 100 == 0:
                print(f"Epoch {epoch}, step {step}, loss {loss.item():.4f}, "
                      f"lr {optimizer.param_groups[0]['lr']:.2e}")

        # Validate every epoch (validate() is assumed to be defined elsewhere)
        val_accuracy = validate(model, device)
        print(f"Epoch {epoch} validation accuracy: {val_accuracy:.2%}")

    return model

# Key patterns demonstrated:
# 1. Batch size chosen based on memory (80% of max)
# 2. Learning rate scaled for batch size
# 3. Warmup phase for gradual LR increase
# 4. Cosine annealing for LR decay
# 5. Gradient accumulation structure (even if not needed)
# 6. Gradient clipping for stability
# 7. Regular validation to monitor accuracy
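
The memory figures in Step 2 follow from simple arithmetic: 4 bytes per FP32 value, one gradient per parameter, and two moment buffers per parameter for AdamW. A minimal sketch of that estimate; `static_memory_mb` is a hypothetical helper, it uses decimal megabytes, and it deliberately excludes activations, which depend on batch size and architecture.

```python
def static_memory_mb(num_params, bytes_per_value=4, optimizer_states_per_param=2):
    """Memory for weights, gradients, and optimizer state (activations excluded)."""
    weights = num_params * bytes_per_value
    gradients = num_params * bytes_per_value
    optimizer = num_params * bytes_per_value * optimizer_states_per_param
    return {
        "weights_mb": weights / 1e6,
        "gradients_mb": gradients / 1e6,
        "optimizer_mb": optimizer / 1e6,
        "total_mb": (weights + gradients + optimizer) / 1e6,
    }

estimate = static_memory_mb(86_000_000)  # ViT-Base parameter count from Step 2
print(estimate)
```

For ViT-Base this static part is well under 2 GB, so on a 40 GB GPU almost all of the batch-size budget is set by activations.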

Summary: Batch Size and Memory Decision Making

The core principle: Batch size is a system design choice affecting convergence, generalization, speed, and memory simultaneously. There is no universal "right" batch size - it depends on your constraints and priorities.

The decision process:

  1. Memory constraint: Start with 80% of maximum batch
  2. Convergence: Scale learning rate 1:1 with batch increase (with warmup)
  3. Generalization: Validate accuracy, reduce if gap >2% (or add regularization)
  4. Performance: Profile wall-clock time at different batch sizes
  5. Architecture: Different models have different optimal batches
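
Step 4 of the decision process can be sketched as a small timing harness. `samples_per_second` and `dummy_step` are hypothetical names; `dummy_step` stands in for one real forward/backward/optimizer step and should be replaced by your own training step.

```python
import time

def samples_per_second(step_fn, batch_sizes, iters=20, warmup=3):
    """Time step_fn(batch_size) and report throughput for each batch size."""
    results = {}
    for batch_size in batch_sizes:
        for _ in range(warmup):  # Warm-up iterations are not timed
            step_fn(batch_size)
        start = time.perf_counter()
        for _ in range(iters):
            step_fn(batch_size)
        elapsed = time.perf_counter() - start
        results[batch_size] = batch_size * iters / elapsed
    return results

def dummy_step(batch_size):
    """Stand-in workload whose cost grows with batch size."""
    return sum(i * i for i in range(batch_size * 100))

for batch_size, rate in samples_per_second(dummy_step, [32, 64, 128]).items():
    print(f"batch {batch_size:>4}: {rate:,.0f} samples/s")
```

When timing GPU work, call `torch.cuda.synchronize()` before reading the clock, since CUDA kernels launch asynchronously and the timer would otherwise stop early. Throughput alone does not settle the choice; it has to be weighed against the accuracy measured at each batch size.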

The key insights:

  • Larger batch = fewer iterations per epoch but often worse generalization
  • Linear scaling rule requires warmup for large increases
  • Gradient accumulation is a legitimate technique (understand slowdown cost)
  • Fine-tuning requires smaller batch than pre-training
  • Distributed training needs care with batch norm and gradient updates
  • Always measure, validate, and adjust - don't assume rules apply to your case

The testing approach:

When pressure-tested, this skill should:

  • Explain why maximum batch ≠ optimal batch (generalization gap)
  • Provide concrete examples of linear scaling rule with warmup
  • Address gradient accumulation systematically (when, why, cost)
  • Discuss memory estimation and optimization techniques
  • Help select batch size based on constraints AND priorities
  • Resist rationalizations and always recommend validation

References and Further Reading

Key papers:

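  • Goyal et al. (2017) "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour" - Linear scaling rule and warmup
  • You et al. (2019) "Large Batch Optimization for Deep Learning: Training BERT in 76 minutes" - LAMB optimizer for large batches
  • Smith et al. (2017) "Don't Decay the Learning Rate, Increase the Batch Size" - Batch size schedules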

Techniques mentioned:

  • Batch Normalization: Ioffe & Szegedy (2015)
  • Layer Normalization: Ba et al. (2016)
  • Mixed Precision Training: Micikevicius et al. (2017)
  • Gradient Checkpointing: Chen et al. (2016)

Related Yzmir Skills:

  • learning-rate-scheduling - LR schedule choices beyond linear scaling
  • gradient-management - Gradient clipping and accumulation for stability
  • optimization-algorithms - Optimizer selection and hyperparameter tuning