| name | training-loop-architecture |
| description | Training loops - checkpointing, validation, memory management, proper train/val/test splits |
Training Loop Architecture
Overview
Core Principle: A properly structured training loop is the foundation of all successful deep learning projects. Success requires: (1) correct train/val/test data separation, (2) validation after EVERY epoch (not just once), (3) complete checkpoint state (model + optimizer + scheduler), (4) comprehensive logging/monitoring, and (5) graceful error handling. Poor loop structure causes: silent overfitting, broken resume functionality, undetectable training issues, and memory leaks.
Training loop failures manifest as: overfitting that goes unnoticed because training metrics still look good, crashes on resume, unexplained loss spikes, or out-of-memory errors. These stem from misunderstanding when validation should run, what state must be saved, or how to manage GPU memory. Systematic architecture beats trial-and-error fixes.
When to Use
Use this skill when:
- Implementing a new training loop from scratch
- Training loop is crashing unexpectedly
- Can't resume training from checkpoint correctly
- Model overfits while reported metrics still look good (often a sign of validating on training data)
- Out-of-memory errors during training
- Unsure about train/val/test data split
- Need to monitor training progress properly
- Implementing early stopping or checkpoint selection
- Loss spikes or divergence when resuming training from a checkpoint
- Adding logging/monitoring to training
Don't use when:
- Debugging single backward pass (use gradient-management skill)
- Tuning learning rate (use learning-rate-scheduling skill)
- Fixing specific loss function (use loss-functions-and-objectives skill)
- Data loading issues (use data-augmentation-strategies skill)
Symptoms triggering this skill:
- "Training loss decreases but validation loss increases (overfitting)"
- "Training crashes when resuming from checkpoint"
- "Out of memory errors after epoch 20"
- "I validated on training data and didn't realize"
- "Can't detect overfitting because I don't validate"
- "Training loss spikes when resuming"
- "My checkpoint doesn't load correctly"
Complete Training Loop Structure
1. The Standard Training Loop (The Reference)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import logging
from pathlib import Path
# Setup logging (ALWAYS do this first)
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class TrainingLoop:
"""Complete training loop with validation, checkpointing, and monitoring."""
def __init__(self, model, optimizer, scheduler, criterion, device='cuda'):
self.model = model
self.optimizer = optimizer
self.scheduler = scheduler
self.criterion = criterion
self.device = device
# Tracking metrics
self.train_losses = []
self.val_losses = []
self.best_val_loss = float('inf')
self.epochs_without_improvement = 0
def train_epoch(self, train_loader):
"""Train for one epoch."""
self.model.train()
total_loss = 0.0
num_batches = 0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(self.device), target.to(self.device)
# Forward pass
self.optimizer.zero_grad()
output = self.model(data)
loss = self.criterion(output, target)
# Backward pass with gradient clipping (if needed)
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.optimizer.step()
# Accumulate loss
total_loss += loss.item()
num_batches += 1
# Log progress every 10 batches
if batch_idx % 10 == 0:
logger.debug(f"Batch {batch_idx}: loss={loss.item():.4f}")
avg_loss = total_loss / num_batches
return avg_loss
def validate_epoch(self, val_loader):
"""Validate on validation set (AFTER each epoch, not during)."""
self.model.eval()
total_loss = 0.0
num_batches = 0
with torch.no_grad(): # ✅ CRITICAL: No gradients during validation
for data, target in val_loader:
data, target = data.to(self.device), target.to(self.device)
output = self.model(data)
loss = self.criterion(output, target)
total_loss += loss.item()
num_batches += 1
avg_loss = total_loss / num_batches
return avg_loss
def save_checkpoint(self, epoch, val_loss, checkpoint_dir='checkpoints'):
"""Save complete checkpoint (model + optimizer + scheduler)."""
        Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
checkpoint = {
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'val_loss': val_loss,
'train_losses': self.train_losses,
'val_losses': self.val_losses,
}
# Save last checkpoint
torch.save(checkpoint, f'{checkpoint_dir}/checkpoint_latest.pt')
# Save best checkpoint
if val_loss < self.best_val_loss:
self.best_val_loss = val_loss
torch.save(checkpoint, f'{checkpoint_dir}/checkpoint_best.pt')
logger.info(f"New best validation loss: {val_loss:.4f}")
def load_checkpoint(self, checkpoint_path):
"""Load checkpoint and resume training correctly."""
checkpoint = torch.load(checkpoint_path, map_location=self.device)
        # ✅ CRITICAL: Restore ALL three states - model, optimizer, AND scheduler
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
# Restore metrics history
self.train_losses = checkpoint['train_losses']
self.val_losses = checkpoint['val_losses']
self.best_val_loss = min(self.val_losses) if self.val_losses else float('inf')
        # Resume from the NEXT epoch: the saved epoch has already completed
        start_epoch = checkpoint['epoch'] + 1
        logger.info(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
        return start_epoch
def train(self, train_loader, val_loader, num_epochs, checkpoint_dir='checkpoints'):
"""Full training loop with validation and checkpointing."""
start_epoch = 0
# Try to resume from checkpoint if it exists
checkpoint_path = f'{checkpoint_dir}/checkpoint_latest.pt'
if Path(checkpoint_path).exists():
start_epoch = self.load_checkpoint(checkpoint_path)
logger.info(f"Resuming training from epoch {start_epoch}")
for epoch in range(start_epoch, num_epochs):
try:
# Train for one epoch
train_loss = self.train_epoch(train_loader)
self.train_losses.append(train_loss)
# ✅ CRITICAL: Validate after every epoch
val_loss = self.validate_epoch(val_loader)
self.val_losses.append(val_loss)
# Step scheduler (after epoch)
self.scheduler.step()
# Log metrics
logger.info(
f"Epoch {epoch}: train_loss={train_loss:.4f}, "
f"val_loss={val_loss:.4f}, lr={self.optimizer.param_groups[0]['lr']:.2e}"
)
                # Check for improvement BEFORE save_checkpoint updates best_val_loss
                improved = val_loss < self.best_val_loss
                # Checkpoint every epoch (save_checkpoint also tracks the best model)
                self.save_checkpoint(epoch, val_loss, checkpoint_dir)
                # Early stopping (optional)
                if improved:
                    self.epochs_without_improvement = 0
                else:
                    self.epochs_without_improvement += 1
                    if self.epochs_without_improvement >= 10:
                        logger.info(f"Early stopping at epoch {epoch}")
                        break
            except KeyboardInterrupt:
                logger.info("Training interrupted by user")
                # Save with epoch - 1 so the interrupted (incomplete) epoch is re-run on resume;
                # val_loss may not exist yet, so fall back to the last recorded value
                last_val = self.val_losses[-1] if self.val_losses else float('inf')
                self.save_checkpoint(epoch - 1, last_val, checkpoint_dir)
                break
except RuntimeError as e:
logger.error(f"Error in epoch {epoch}: {e}")
raise
logger.info("Training complete")
return self.model
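A minimal usage sketch of the class above. The toy model, random tensors, and hyperparameters are illustrative placeholders, not recommendations:
# Hypothetical usage of TrainingLoop with a toy model and random data
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 2)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
criterion = nn.CrossEntropyLoss()
train_loader = DataLoader(list(zip(torch.randn(200, 10), torch.randint(0, 2, (200,)))), batch_size=32, shuffle=True)
val_loader = DataLoader(list(zip(torch.randn(50, 10), torch.randint(0, 2, (50,)))), batch_size=32, shuffle=False)
loop = TrainingLoop(model, optimizer, scheduler, criterion, device=device)
loop.train(train_loader, val_loader, num_epochs=20, checkpoint_dir='checkpoints')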
2. Data Split: Train/Val/Test Separation (CRITICAL)
import torch
from torch.utils.data import DataLoader, Subset
# ✅ CORRECT: Proper three-way split with NO data leakage
class DataSplitter:
"""Ensures clean train/val/test splits without data leakage."""
@staticmethod
def split_dataset(dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, random_state=42):
"""
Split dataset into train/val/test.
CRITICAL: Split indices first, then create loaders.
This prevents any data leakage.
"""
        assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, "Ratios must sum to 1"
        n = len(dataset)
        # Shuffle once with a fixed seed so the split is random but reproducible
        generator = torch.Generator().manual_seed(random_state)
        indices = torch.randperm(n, generator=generator).tolist()
        # First split: train vs (val + test)
        train_size = int(train_ratio * n)
        train_indices = indices[:train_size]
        remaining_indices = indices[train_size:]
# Second split: val vs test
remaining_size = len(remaining_indices)
val_size = int(val_ratio / (val_ratio + test_ratio) * remaining_size)
val_indices = remaining_indices[:val_size]
test_indices = remaining_indices[val_size:]
# Create subset datasets (same transforms, different data)
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)
test_dataset = Subset(dataset, test_indices)
logger.info(
f"Dataset split: train={len(train_dataset)}, "
f"val={len(val_dataset)}, test={len(test_dataset)}"
)
return train_dataset, val_dataset, test_dataset
# Usage
train_dataset, val_dataset, test_dataset = DataSplitter.split_dataset(
full_dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15
)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
# ✅ CRITICAL: Validate that splits are actually different
print(f"Train samples: {len(train_loader.dataset)}")
print(f"Val samples: {len(val_loader.dataset)}")
print(f"Test samples: {len(test_loader.dataset)}")
# ✅ CRITICAL: Never mix splits (don't re-shuffle or combine)
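To make the "splits are actually different" check concrete, here is a small sanity check on the underlying indices; it assumes the `Subset` objects produced by `DataSplitter` above:
# Verify no sample index appears in more than one split
train_idx = set(train_dataset.indices)
val_idx = set(val_dataset.indices)
test_idx = set(test_dataset.indices)
assert train_idx.isdisjoint(val_idx), "Train/val overlap - data leakage!"
assert train_idx.isdisjoint(test_idx), "Train/test overlap - data leakage!"
assert val_idx.isdisjoint(test_idx), "Val/test overlap - data leakage!"
assert len(train_idx) + len(val_idx) + len(test_idx) == len(full_dataset), "Some samples were dropped"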
3. Monitoring and Logging (Reproducibility)
import json
from datetime import datetime
class TrainingMonitor:
"""Track all metrics for reproducibility and debugging."""
def __init__(self, log_dir='logs'):
self.log_dir = Path(log_dir)
self.log_dir.mkdir(exist_ok=True)
# Metrics to track
self.metrics = {
'timestamp': datetime.now().isoformat(),
'epochs': [],
'train_losses': [],
'val_losses': [],
'learning_rates': [],
'gradient_norms': [],
'batch_times': [],
}
def log_epoch(self, epoch, train_loss, val_loss, lr, gradient_norm=None, batch_time=None):
"""Log metrics for one epoch."""
self.metrics['epochs'].append(epoch)
self.metrics['train_losses'].append(train_loss)
self.metrics['val_losses'].append(val_loss)
self.metrics['learning_rates'].append(lr)
if gradient_norm is not None:
self.metrics['gradient_norms'].append(gradient_norm)
if batch_time is not None:
self.metrics['batch_times'].append(batch_time)
def save_metrics(self):
"""Save metrics to JSON for post-training analysis."""
metrics_path = self.log_dir / f'metrics_{datetime.now().strftime("%Y%m%d_%H%M%S")}.json'
with open(metrics_path, 'w') as f:
json.dump(self.metrics, f, indent=2)
logger.info(f"Metrics saved to {metrics_path}")
def plot_metrics(self):
"""Plot training curves."""
import matplotlib.pyplot as plt
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
# Loss curves
axes[0, 0].plot(self.metrics['epochs'], self.metrics['train_losses'], label='Train')
axes[0, 0].plot(self.metrics['epochs'], self.metrics['val_losses'], label='Val')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].set_title('Training and Validation Loss')
# Learning rate schedule
axes[0, 1].plot(self.metrics['epochs'], self.metrics['learning_rates'])
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Learning Rate')
axes[0, 1].set_title('Learning Rate Schedule')
axes[0, 1].set_yscale('log')
# Gradient norms (if available)
if self.metrics['gradient_norms']:
axes[1, 0].plot(self.metrics['epochs'], self.metrics['gradient_norms'])
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Gradient Norm')
axes[1, 0].set_title('Gradient Norms')
# Batch times (if available)
if self.metrics['batch_times']:
axes[1, 1].plot(self.metrics['epochs'], self.metrics['batch_times'])
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Time (seconds)')
axes[1, 1].set_title('Batch Processing Time')
plt.tight_layout()
plot_path = self.log_dir / f'training_curves_{datetime.now().strftime("%Y%m%d_%H%M%S")}.png'
        plt.savefig(plot_path)
        plt.close(fig)  # release the figure so repeated calls don't accumulate memory
        logger.info(f"Plot saved to {plot_path}")
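A sketch of how the monitor plugs into an epoch loop. Dummy loss and learning-rate values stand in for real training results so the snippet runs on its own (plotting assumes matplotlib is installed):
# Hypothetical wiring of TrainingMonitor (dummy values stand in for real epoch results)
monitor = TrainingMonitor(log_dir='logs')
for epoch in range(3):
    train_loss = 1.0 / (epoch + 1)   # stand-in for your epoch's training loss
    val_loss = 1.2 / (epoch + 1)     # stand-in for your epoch's validation loss
    lr = 1e-3 * (0.9 ** epoch)       # stand-in for optimizer.param_groups[0]['lr']
    monitor.log_epoch(epoch, train_loss, val_loss, lr)
monitor.save_metrics()
monitor.plot_metrics()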
4. Checkpointing and Resuming (Complete State)
class CheckpointManager:
"""Properly save and load ALL training state."""
def __init__(self, checkpoint_dir='checkpoints'):
self.checkpoint_dir = Path(checkpoint_dir)
self.checkpoint_dir.mkdir(exist_ok=True)
def save_full_checkpoint(self, epoch, model, optimizer, scheduler, metrics, path_suffix=''):
"""Save COMPLETE state for resuming training."""
checkpoint = {
# Model and optimizer state
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
# Training metrics (for monitoring)
            # Use .get so partial metrics dicts don't raise KeyError
            'train_losses': metrics.get('train_losses', []),
            'val_losses': metrics.get('val_losses', []),
            'learning_rates': metrics.get('learning_rates', []),
# Timestamp for recovery
'timestamp': datetime.now().isoformat(),
}
# Save as latest
latest_path = self.checkpoint_dir / f'checkpoint_latest{path_suffix}.pt'
torch.save(checkpoint, latest_path)
# Save periodically (every 10 epochs)
if epoch % 10 == 0:
periodic_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch:04d}.pt'
torch.save(checkpoint, periodic_path)
logger.info(f"Checkpoint saved: {latest_path}")
return latest_path
def load_full_checkpoint(self, model, optimizer, scheduler, checkpoint_path):
"""Load COMPLETE state correctly."""
if not Path(checkpoint_path).exists():
logger.warning(f"Checkpoint not found: {checkpoint_path}")
return 0, None
checkpoint = torch.load(checkpoint_path, map_location='cpu')
        # ✅ CRITICAL: Restore ALL three states - model, optimizer, AND scheduler
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
epoch = checkpoint['epoch']
metrics = {
'train_losses': checkpoint.get('train_losses', []),
'val_losses': checkpoint.get('val_losses', []),
'learning_rates': checkpoint.get('learning_rates', []),
}
logger.info(
f"Loaded checkpoint from epoch {epoch}, "
f"saved at {checkpoint.get('timestamp', 'unknown')}"
)
        # Resume from the NEXT epoch: the saved epoch has already completed
        return epoch + 1, metrics
def get_best_checkpoint(self):
"""Find checkpoint with best validation loss."""
checkpoints = list(self.checkpoint_dir.glob('checkpoint_epoch_*.pt'))
if not checkpoints:
return None
best_loss = float('inf')
best_path = None
for ckpt_path in checkpoints:
checkpoint = torch.load(ckpt_path, map_location='cpu')
            val_losses = checkpoint.get('val_losses', [])
            # Compare the loss AT this checkpoint's epoch (the last entry), not the
            # best value in its history, which every later checkpoint also contains
            if val_losses and val_losses[-1] < best_loss:
                best_loss = val_losses[-1]
                best_path = ckpt_path
return best_path
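A save/resume round trip with the manager above, using a throwaway linear model; the checkpoint directory and hyperparameters are illustrative:
# Hypothetical save/resume round trip with CheckpointManager
manager = CheckpointManager('checkpoints')
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
metrics = {'train_losses': [0.9], 'val_losses': [1.0], 'learning_rates': [0.1]}
manager.save_full_checkpoint(0, model, optimizer, scheduler, metrics)
# Later (e.g., after a crash): rebuild the same objects, then restore ALL state
start_epoch, metrics = manager.load_full_checkpoint(
    model, optimizer, scheduler, 'checkpoints/checkpoint_latest.pt'
)
print(f"Resuming at epoch {start_epoch}")  # the epoch after the one that was saved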
5. Memory Management (Prevent Leaks)
class MemoryManager:
"""Prevent out-of-memory errors during long training."""
def __init__(self, device='cuda'):
self.device = device
def clear_cache(self):
"""Clear GPU cache between epochs."""
if self.device.startswith('cuda'):
torch.cuda.empty_cache()
            # Optional: wait for pending kernels so memory stats are accurate
            torch.cuda.synchronize()
def check_memory(self):
"""Log GPU memory usage."""
if self.device.startswith('cuda'):
allocated = torch.cuda.memory_allocated() / 1e9
reserved = torch.cuda.memory_reserved() / 1e9
logger.info(f"GPU memory - allocated: {allocated:.2f}GB, reserved: {reserved:.2f}GB")
def training_loop_with_memory_management(self, model, train_loader, optimizer, criterion):
"""Training loop with proper memory management."""
model.train()
total_loss = 0.0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(self.device), target.to(self.device)
# Forward and backward
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
total_loss += loss.item()
# ✅ Clear temporary tensors (data and target go out of scope)
# ✅ Don't hold onto loss or output after using them
# Periodically check memory
if batch_idx % 100 == 0:
self.check_memory()
# Clear cache between epochs
self.clear_cache()
return total_loss / len(train_loader)
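Note that `torch.cuda.empty_cache()` only returns cached, unused blocks to the driver; it cannot free tensors your code still references, so fixing a leak means dropping references (as in the loop above). A minimal standalone use of the manager, which safely no-ops on CPU-only machines:
# Hypothetical standalone use of MemoryManager
mem = MemoryManager(device='cuda' if torch.cuda.is_available() else 'cpu')
mem.check_memory()   # logs allocated/reserved GPU memory (skipped on CPU)
mem.clear_cache()    # releases cached blocks between epochs (skipped on CPU)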
Error Handling and Recovery
class RobustTrainingLoop:
"""Training loop with proper error handling."""
def train_with_error_handling(self, model, train_loader, val_loader, optimizer,
scheduler, criterion, num_epochs, checkpoint_dir):
"""Training with error recovery."""
checkpoint_manager = CheckpointManager(checkpoint_dir)
memory_manager = MemoryManager()
        # Resume from last checkpoint if available
        start_epoch, metrics = checkpoint_manager.load_full_checkpoint(
            model, optimizer, scheduler, f'{checkpoint_dir}/checkpoint_latest.pt'
        )
        # Keep a running history so every checkpoint carries the full loss curves
        history = metrics or {'train_losses': [], 'val_losses': [], 'learning_rates': []}
for epoch in range(start_epoch, num_epochs):
try:
# Train
train_loss = self.train_epoch(model, train_loader, optimizer, criterion)
# Validate
val_loss = self.validate_epoch(model, val_loader, criterion)
# Update scheduler
scheduler.step()
# Log
logger.info(
f"Epoch {epoch}: train={train_loss:.4f}, val={val_loss:.4f}, "
f"lr={optimizer.param_groups[0]['lr']:.2e}"
)
                # Checkpoint with the accumulated history, not just this epoch's values
                history['train_losses'].append(train_loss)
                history['val_losses'].append(val_loss)
                history['learning_rates'].append(optimizer.param_groups[0]['lr'])
                checkpoint_manager.save_full_checkpoint(
                    epoch, model, optimizer, scheduler, history
                )
# Memory management
memory_manager.clear_cache()
except KeyboardInterrupt:
                logger.warning("Training interrupted - saving checkpoint")
                # Save with epoch - 1 so the interrupted (incomplete) epoch is re-run on resume
                checkpoint_manager.save_full_checkpoint(
                    epoch - 1, model, optimizer, scheduler, history
                )
break
except RuntimeError as e:
if 'out of memory' in str(e).lower():
logger.error("Out of memory error")
memory_manager.clear_cache()
                    # In practice you might retry with a smaller batch size; here we re-raise
                    raise
else:
logger.error(f"Runtime error: {e}")
raise
except Exception as e:
logger.error(f"Unexpected error in epoch {epoch}: {e}")
                # The failed epoch did not complete; save epoch - 1 so it is retried on resume
                checkpoint_manager.save_full_checkpoint(
                    epoch - 1, model, optimizer, scheduler, history
                )
raise
return model
Common Pitfalls and How to Avoid Them
Pitfall 1: Validating on Training Data
# ❌ WRONG
val_loader = train_loader # Same loader!
# ✅ CORRECT
train_dataset, val_dataset = split_dataset(full_dataset)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
Pitfall 2: Missing Optimizer State in Checkpoint
# ❌ WRONG
torch.save({'model': model.state_dict()}, 'ckpt.pt')
# ✅ CORRECT
torch.save({
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
}, 'ckpt.pt')
Pitfall 3: Not Validating During Training
# ❌ WRONG
for epoch in range(100):
train_epoch()
final_val = evaluate() # Only at the end!
# ✅ CORRECT
for epoch in range(100):
train_epoch()
validate_epoch() # After every epoch
Pitfall 4: Holding Onto Tensor References
# ❌ WRONG
all_losses = []
for data, target in loader:
loss = criterion(model(data), target)
    all_losses.append(loss)  # Memory leak: each loss tensor keeps its whole computation graph alive
# ✅ CORRECT
total_loss = 0.0
for data, target in loader:
loss = criterion(model(data), target)
total_loss += loss.item() # Scalar value
Pitfall 5: Forgetting torch.no_grad() in Validation
# ❌ WRONG
model.eval()
for data, target in val_loader:
output = model(data) # Gradients still computed!
loss = criterion(output, target)
# ✅ CORRECT
model.eval()
with torch.no_grad():
for data, target in val_loader:
output = model(data) # No gradients
loss = criterion(output, target)
Pitfall 6: Resetting Scheduler on Resume
# ❌ WRONG
checkpoint = torch.load('ckpt.pt')
model.load_state_dict(checkpoint['model'])
scheduler = CosineAnnealingLR(optimizer, T_max=100) # Fresh scheduler!
# Now at epoch 50, scheduler thinks it's epoch 0
# ✅ CORRECT
checkpoint = torch.load('ckpt.pt')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler']) # Resume scheduler state
Pitfall 7: Not Handling Early Stopping Correctly
# ❌ WRONG
best_loss = float('inf')
for epoch in range(100):
val_loss = validate()
if val_loss < best_loss:
best_loss = val_loss
# No checkpoint! Can't recover best model
# ✅ CORRECT
best_loss = float('inf')
patience = 10
patience_counter = 0
for epoch in range(100):
val_loss = validate()
if val_loss < best_loss:
best_loss = val_loss
patience_counter = 0
save_checkpoint(model, optimizer, scheduler, epoch) # Save best
else:
patience_counter += 1
if patience_counter >= patience:
break # Stop early
Pitfall 8: Mixing Train and Validation Mode
# ❌ WRONG
for epoch in range(100):
for data, target in train_loader:
output = model(data) # Is model in train or eval mode?
loss = criterion(output, target)
# ✅ CORRECT
for epoch in range(100):
    model.train()  # set the mode explicitly at the start of every epoch
    for data, target in train_loader:
        output = model(data)  # definitely in train mode
        loss = criterion(output, target)
    model.eval()
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)  # definitely in eval mode
Pitfall 9: Loading Checkpoint on Wrong Device
# ❌ WRONG
checkpoint = torch.load('ckpt.pt') # Loads on GPU if saved on GPU
model.load_state_dict(checkpoint['model']) # Might be on wrong device
# ✅ CORRECT
checkpoint = torch.load('ckpt.pt', map_location='cuda:0') # Specify device
model.load_state_dict(checkpoint['model'])
model.to('cuda:0') # Move to device
Pitfall 10: Not Clearing GPU Cache
# ❌ WRONG
for epoch in range(100):
train_epoch()
validate_epoch()
    # Cached/fragmented GPU memory can accumulate across epochs
# ✅ CORRECT
for epoch in range(100):
train_epoch()
validate_epoch()
torch.cuda.empty_cache() # Clear cache
Integration with Optimization Techniques
Complete Training Loop with All Techniques
class FullyOptimizedTrainingLoop:
"""Integrates: gradient clipping, mixed precision, learning rate scheduling."""
def train_with_all_techniques(self, model, train_loader, val_loader,
num_epochs, checkpoint_dir='checkpoints'):
"""Training with all optimization techniques integrated."""
# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
criterion = nn.CrossEntropyLoss()
# Mixed precision (if using AMP)
scaler = torch.cuda.amp.GradScaler()
# Training loop
for epoch in range(num_epochs):
model.train()
total_loss = 0.0
for data, target in train_loader:
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
# Mixed precision forward pass
with torch.autocast('cuda'):
output = model(data)
loss = criterion(output, target)
# Gradient scaling for mixed precision
scaler.scale(loss).backward()
# Gradient clipping (unscale first!)
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Optimizer step
scaler.step(optimizer)
scaler.update()
total_loss += loss.item()
train_loss = total_loss / len(train_loader)
# Validation
model.eval()
val_loss = 0.0
with torch.no_grad():
for data, target in val_loader:
data, target = data.to(device), target.to(device)
output = model(data)
val_loss += criterion(output, target).item()
val_loss /= len(val_loader)
# Scheduler step
scheduler.step()
logger.info(
f"Epoch {epoch}: train={train_loss:.4f}, val={val_loss:.4f}, "
f"lr={optimizer.param_groups[0]['lr']:.2e}"
)
return model
Rationalization Table: When to Deviate from Standard
| Deviation from standard | Standard Practice | Acceptable? | Rationale |
|---|---|---|---|
| Validate only at the end | Validate every epoch | ✗ Never | Can't detect overfitting |
| Save only the model | Save model + optimizer + scheduler | ✗ Never | Resuming training breaks |
| Mix train/val data | Keep splits completely separate | ✗ Never | Data leakage and false metrics |
| Dynamic batch size | Fixed batch size for reproducibility | ✓ Sometimes | May be needed to fit memory |
| Single constant LR | Use a scheduler | ✓ Sometimes | <10-epoch runs or hyperparameter search |
| Skip early stopping | Implement early stopping | ✓ Sometimes | If training time is unlimited |
| Log every batch | Log every 10-100 batches | ✓ Often | Frequent logging mainly adds I/O overhead |
| Clear GPU cache every epoch | Clear GPU cache only when needed | ✓ Sometimes | Only if hitting OOM issues |
Red Flags: Immediate Warning Signs
- Training loss much lower than validation loss (>2x) → Overfitting (a minimal automated check is sketched after this list)
- Loss spikes on resume → Optimizer state not loaded
- GPU memory grows over time → Memory leak, likely tensor accumulation
- Validation never runs → Check if validation is in loop
- Best model not saved → Check checkpoint logic
- Different results on resume → Scheduler not loaded
- Early stopping not working → Checkpoint not at best model
- OOM during training → Clear GPU cache, check for accumulated tensors
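A minimal sketch that turns the first red flags into automated warnings; the 2x loss ratio and 1% learning-rate tolerance are illustrative thresholds, not tuned values:
# Hypothetical per-epoch red-flag checks (thresholds are illustrative)
def check_red_flags(train_loss, val_loss, current_lr, expected_lr=None):
    if train_loss > 0 and val_loss / train_loss > 2.0:
        logger.warning(f"Possible overfitting: val/train loss ratio = {val_loss / train_loss:.2f}")
    if expected_lr and abs(current_lr - expected_lr) / expected_lr > 0.01:
        logger.warning(
            f"LR mismatch after resume: got {current_lr:.2e}, expected {expected_lr:.2e} "
            "- was the scheduler state restored?"
        )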
Testing Your Training Loop
def test_training_loop():
"""Quick test to verify training loop is correct."""
# Create dummy data
X_train = torch.randn(100, 10)
y_train = torch.randint(0, 2, (100,))
X_val = torch.randn(20, 10)
y_val = torch.randint(0, 2, (20,))
train_loader = DataLoader(
list(zip(X_train, y_train)), batch_size=16
)
val_loader = DataLoader(
list(zip(X_val, y_val)), batch_size=16
)
# Simple model
model = nn.Sequential(
nn.Linear(10, 64),
nn.ReLU(),
nn.Linear(64, 2)
)
    # Training (fall back to CPU so the test also runs without a GPU)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    criterion = nn.CrossEntropyLoss()
    loop = TrainingLoop(model, optimizer, scheduler, criterion, device=device)
# Should complete without errors
loop.train(train_loader, val_loader, num_epochs=5, checkpoint_dir='test_ckpts')
# Check outputs
assert len(loop.train_losses) == 5
assert len(loop.val_losses) == 5
assert all(isinstance(l, float) for l in loop.train_losses)
print("✓ Training loop test passed")
if __name__ == '__main__':
test_training_loop()