---
name: pytorch-model-recovery
description: Guidance for recovering PyTorch model architectures from state dictionaries, retraining specific layers, and saving models in TorchScript format. This skill should be used when tasks involve reconstructing model architectures from saved weights, fine-tuning specific layers while freezing others, or converting models to TorchScript format.
---
# PyTorch Model Recovery
This skill provides guidance for tasks involving PyTorch model architecture recovery from state dictionaries, selective layer training, and TorchScript export.
## When to Use This Skill
This skill applies when:
- Reconstructing a model architecture from a state dictionary (a `.pt` or `.pth` file containing weights)
- Training or fine-tuning specific layers while keeping others frozen
- Converting a recovered model to TorchScript format
- Debugging model loading issues or architecture mismatches
## Approach Overview
Model recovery tasks require a systematic, incremental approach with verification at each step. The key phases are:
1. **Architecture Analysis** - infer the model structure from state dictionary keys
2. **Architecture Implementation** - build a model class that matches the state dict
3. **Verification** - confirm the weights load correctly before any training
4. **Training** - fine-tune specific layers with appropriate hyperparameters
5. **Export** - save to the required format (often TorchScript)
## Phase 1: Architecture Analysis

### Examining the State Dictionary
To understand the model architecture, first load and inspect the state dictionary:
```python
import torch

weights = torch.load('model_weights.pt', map_location='cpu')

# Print all keys with their shapes
for key, value in weights.items():
    print(f"{key}: {value.shape}")
```
### Key Patterns to Identify

Common patterns in state dictionary keys:

| Key Pattern | Indicates |
|---|---|
| `encoder.layers.N.*` | Transformer encoder with N+1 layers |
| `decoder.layers.N.*` | Transformer decoder with N+1 layers |
| `embedding.weight` | Embedding layer |
| `pos_encoder.pe` | Positional encoding (often a buffer) |
| `output_layer.weight/bias` | Final linear projection |
| `*.in_proj_weight` | Combined QKV projection in attention |
| `*.self_attn.*` | Self-attention component |
| `*.linear1/linear2.*` | Feed-forward network layers |
| `*.norm1/norm2.*` | Layer normalization |
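To spot these patterns quickly, it can help to summarize the keys programmatically rather than reading them one by one. A minimal sketch, assuming the `weights` dict loaded above:

```python
from collections import Counter

# Count keys by their top-level prefix to spot the major components
prefixes = Counter(key.split('.')[0] for key in weights)
print(prefixes)

# List the key suffixes of the first encoder layer; every other layer
# should repeat the same structure (requires Python 3.9+ for removeprefix)
layer0 = sorted(k.removeprefix('encoder.layers.0.')
                for k in weights if k.startswith('encoder.layers.0.'))
print(layer0)
```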
### Inferring Dimensions

Extract model dimensions from the weight shapes:

```python
# in_proj_weight holds the combined QKV projection, shape [3*d_model, d_model]
d_model = weights['encoder.layers.0.self_attn.in_proj_weight'].shape[1]

vocab_size = weights['embedding.weight'].shape[0]
num_layers = max(int(k.split('.')[2])
                 for k in weights if k.startswith('encoder.layers.')) + 1

# nhead cannot be recovered from weight shapes alone: the QKV projection has
# the same shape for any head count. It must divide d_model evenly, so try
# common values (e.g. 4 or 8) and validate with a forward pass.
```
## Phase 2: Architecture Implementation

### Building the Model Class

When implementing the model class:

- Match the exact layer names used in the state dictionary
- Use the same PyTorch module types (e.g., `nn.TransformerEncoder` vs. a custom implementation)
- Register buffers for non-learnable tensors (e.g., positional encodings)
```python
import torch.nn as nn

class RecoveredModel(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers, dim_feedforward):
        super().__init__()
        # Ensure attribute names match the state dict keys exactly
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Positional encoding stored as a buffer (see the PositionalEncoding
        # sketch below)
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True  # Check whether the original model used batch_first
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)
```
### Common Architecture Mistakes

- **Incorrect layer naming**: `self.fc` vs `self.output_layer` - names must match exactly
- **Missing buffers**: positional encodings are often registered as buffers, not parameters (see the sketch after this list)
- **Wrong module types**: custom attention vs `nn.MultiheadAttention`
- **Batch dimension mismatch**: `batch_first=True` vs `batch_first=False`
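A `pos_encoder.pe` key usually corresponds to a registered buffer. A minimal sketch of a sinusoidal `PositionalEncoding` module that registers `pe` as a buffer; the sinusoidal formula, `max_len`, and batch-first layout are assumptions that must be checked against the shape of the saved tensor:

```python
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding stored in a buffer named 'pe'."""
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2)
                             * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)  # batch_first layout
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        # register_buffer makes 'pos_encoder.pe' appear in the state dict
        # without being a trainable parameter
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]
```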
## Phase 3: Verification (Critical)

### Verify Architecture Before Training

Always verify that the model loads the weights correctly before any training:
```python
model = RecoveredModel(...)

# This raises an error if any keys don't match
model.load_state_dict(weights, strict=True)
print("Weights loaded successfully!")

# Verify that a forward pass works
with torch.no_grad():
    dummy_input = torch.randint(0, vocab_size, (1, 10))
    output = model(dummy_input)
    print(f"Output shape: {output.shape}")
```
### Handling Key Mismatches

If `load_state_dict` fails, compare the two key sets:

```python
model_keys = set(model.state_dict().keys())
weight_keys = set(weights.keys())

missing = model_keys - weight_keys     # expected by the model, absent from the weights
unexpected = weight_keys - model_keys  # present in the weights, absent from the model
print(f"Missing from weights: {missing}")
print(f"Unexpected in weights: {unexpected}")
```
### Verify TorchScript Compatibility Early

If TorchScript export is required, test it early:

```python
# Test that scripting works before investing time in training
try:
    scripted = torch.jit.script(model)
    print("TorchScript scripting successful")
except Exception as e:
    print(f"Scripting failed: {e}")
    # Fall back to tracing
    traced = torch.jit.trace(model, dummy_input)
    print("TorchScript tracing successful")
```
## Phase 4: Training Specific Layers

### Freezing Layers

To train only specific layers, freeze all the others:
```python
# Freeze all parameters first
for param in model.parameters():
    param.requires_grad = False

# Unfreeze only the target layers
for param in model.output_layer.parameters():
    param.requires_grad = True

# Verify the freeze status
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,} parameters")
```
### Computing Baseline Loss

Before training, establish a baseline:

```python
model.eval()
with torch.no_grad():
    outputs = model(inputs)
    original_loss = criterion(outputs, targets)
print(f"Baseline loss: {original_loss.item():.6f}")
```
### Training Loop Considerations

```python
# Create the optimizer over only the trainable parameters
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=0.001
)

# Training with progress tracking
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f"Epoch {epoch}: loss = {loss.item():.6f}")
```
### Alternative: Closed-Form Solution for Linear Layers

When retraining only a linear output layer, consider a closed-form least-squares solution instead of gradient descent:

```python
# Pre-compute the outputs of the frozen layers
model.eval()
with torch.no_grad():
    # Features entering the output layer; get_features is a helper you add
    # to the model that runs everything except output_layer
    features = model.get_features(inputs)  # Shape: [N, d_model]

# Solve the least-squares problem features @ W.T = targets,
# i.e. W.T = (X^T X)^{-1} X^T Y via the normal equations
solution = torch.linalg.lstsq(features, targets).solution  # [d_model, out_dim]
model.output_layer.weight.data = solution.T
```
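Note that the sketch above leaves any bias term untouched. Assuming `output_layer` was created with `bias=True`, one way to solve for the bias as well is to augment the features with a column of ones:

```python
# Append a ones column so the last row of the solution becomes the bias
ones = torch.ones(features.size(0), 1)
augmented = torch.cat([features, ones], dim=1)         # [N, d_model + 1]
sol = torch.linalg.lstsq(augmented, targets).solution  # [d_model + 1, out_dim]
model.output_layer.weight.data = sol[:-1].T
model.output_layer.bias.data = sol[-1]
```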
## Phase 5: TorchScript Export

### Saving the Model

```python
# Ensure the model is in eval mode
model.eval()

# Script the model (preferred when there is control flow)
scripted_model = torch.jit.script(model)
scripted_model.save('/app/model.pt')

# Or trace the model (for simpler, control-flow-free models)
traced_model = torch.jit.trace(model, example_input)
traced_model.save('/app/model.pt')
```
### Verify Saved Model

```python
# Reload and compare against the in-memory model
loaded = torch.jit.load('/app/model.pt')
loaded.eval()
with torch.no_grad():
    original_out = model(test_input)
    loaded_out = loaded(test_input)
diff = (original_out - loaded_out).abs().max()
print(f"Max difference: {diff.item()}")
assert diff < 1e-5, "Model outputs don't match!"
```
## Environment Considerations

### Handling Slow Environments

When operating in resource-constrained environments:

1. **Benchmark first**: test basic operations before committing to a full solution:

   ```python
   import time

   start = time.time()
   _ = model(torch.randint(0, vocab_size, (1, 10)))
   print(f"Single forward pass: {time.time() - start:.2f}s")
   ```

2. **Reduce batch size**: process samples individually if needed
3. **Set realistic timeouts**: base them on benchmarks, not arbitrary values
4. **Use incremental checkpoints**: save progress periodically
### Memory Management

```python
# Clear the GPU cache between operations
torch.cuda.empty_cache()

# Use gradient checkpointing for large models (trades compute for memory)
from torch.utils.checkpoint import checkpoint

# Process in smaller batches; process() stands in for your per-batch work
for batch in torch.split(data, batch_size):
    process(batch)
```
## Common Pitfalls

- **Not verifying the architecture match before training** - always test `load_state_dict` first
- **Arbitrary hyperparameters** - justify choices based on task characteristics
- **Ignoring TorchScript compatibility** - test export early, not after training
- **Syntax errors in edits** - review code changes carefully, especially string formatting
- **Incomplete state dict mapping** - verify that all keys are accounted for
- **Not establishing baseline metrics** - compute the original loss before training
- **Missing `torch.no_grad()` for inference** - use the context manager for evaluation (see the sketch after this list)
- **Forgetting to set `model.eval()`** - required for consistent behavior in evaluation and export
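The last two pitfalls pair naturally; a minimal evaluation pattern that avoids both:

```python
# eval() switches off dropout and freezes batch-norm statistics;
# no_grad() skips autograd bookkeeping for faster, leaner inference
model.eval()
with torch.no_grad():
    predictions = model(test_input)
```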
## Verification Checklist

Before considering the task complete:

- State dictionary keys fully analyzed and documented
- Model architecture matches the state dict exactly (verified with `load_state_dict`)
- Forward pass produces valid output
- Baseline loss/metric computed
- Target layers correctly unfrozen, others frozen
- Training improves the loss over the baseline
- TorchScript export succeeds
- Exported model produces the same outputs as the original
- Model saved to the required path