Claude Code Plugins

Community-maintained marketplace

Feedback

pytorch-model-recovery

@majiayu000/claude-skill-registry
5
0

Guidance for recovering PyTorch model architectures from state dictionaries, retraining specific layers, and saving models in TorchScript format. This skill should be used when tasks involve reconstructing model architectures from saved weights, fine-tuning specific layers while freezing others, or converting models to TorchScript format.

Install Skill

1Download skill
2Enable skills in Claude

Open claude.ai/settings/capabilities and find the "Skills" section

3Upload to Claude

Click "Upload skill" and select the downloaded ZIP file

Note: Please verify skill by going through its instructions before using it.

SKILL.md

name pytorch-model-recovery
description Guidance for recovering PyTorch model architectures from state dictionaries, retraining specific layers, and saving models in TorchScript format. This skill should be used when tasks involve reconstructing model architectures from saved weights, fine-tuning specific layers while freezing others, or converting models to TorchScript format.

PyTorch Model Recovery

This skill provides guidance for tasks involving PyTorch model architecture recovery from state dictionaries, selective layer training, and TorchScript export.

When to Use This Skill

This skill applies when:

  • Reconstructing a model architecture from a state dictionary (.pt or .pth file containing weights)
  • Training or fine-tuning specific layers while keeping others frozen
  • Converting a recovered model to TorchScript format
  • Debugging model loading issues or architecture mismatches

Approach Overview

Model recovery tasks require a systematic, incremental approach with verification at each step. The key phases are:

  1. Architecture Analysis - Infer model structure from state dictionary keys
  2. Architecture Implementation - Build the model class to match the state dict
  3. Verification - Confirm weights load correctly before any training
  4. Training - Fine-tune specific layers with appropriate hyperparameters
  5. Export - Save to required format (often TorchScript)

Phase 1: Architecture Analysis

Examining the State Dictionary

To understand the model architecture, first load and inspect the state dictionary:

import torch

weights = torch.load('model_weights.pt', map_location='cpu')

# Print all keys with shapes
for key, value in weights.items():
    print(f"{key}: {value.shape}")

Key Patterns to Identify

Common patterns in state dictionary keys:

Key Pattern Indicates
encoder.layers.N.* Transformer encoder with N+1 layers
decoder.layers.N.* Transformer decoder with N+1 layers
embedding.weight Embedding layer
pos_encoder.pe Positional encoding (often a buffer)
output_layer.weight/bias Final linear projection
*.in_proj_weight Combined QKV projection in attention
*.self_attn.* Self-attention component
*.linear1/linear2.* Feed-forward network layers
*.norm1/norm2.* Layer normalization

Inferring Dimensions

Extract model dimensions from weight shapes:

# Example: Inferring transformer dimensions
d_model = weights['encoder.layers.0.self_attn.in_proj_weight'].shape[1]
nhead = weights['encoder.layers.0.self_attn.in_proj_weight'].shape[0] // (3 * d_model) * nhead_factor
# Note: in_proj_weight has shape [3*d_model, d_model] for combined QKV

vocab_size = weights['embedding.weight'].shape[0]
num_layers = max(int(k.split('.')[2]) for k in weights if 'encoder.layers' in k) + 1

Phase 2: Architecture Implementation

Building the Model Class

When implementing the model class:

  1. Match the exact layer names used in the state dictionary
  2. Use the same PyTorch module types (e.g., nn.TransformerEncoder vs custom)
  3. Register buffers for non-learnable tensors (e.g., positional encodings)
class RecoveredModel(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers, dim_feedforward):
        super().__init__()
        # Ensure attribute names match state dict keys exactly
        self.embedding = nn.Embedding(vocab_size, d_model)

        # For positional encoding stored as buffer
        self.pos_encoder = PositionalEncoding(d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True  # Check if original used batch_first
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)

Common Architecture Mistakes

  • Incorrect layer naming: self.fc vs self.output_layer - must match exactly
  • Missing buffers: Positional encodings often registered as buffers, not parameters
  • Wrong module types: Custom attention vs nn.MultiheadAttention
  • Batch dimension mismatch: batch_first=True vs batch_first=False

Phase 3: Verification (Critical)

Verify Architecture Before Training

Always verify the model loads weights correctly before any training:

model = RecoveredModel(...)

# This will raise an error if keys don't match
model.load_state_dict(weights, strict=True)
print("Weights loaded successfully!")

# Verify a forward pass works
with torch.no_grad():
    dummy_input = torch.randint(0, vocab_size, (1, 10))
    output = model(dummy_input)
    print(f"Output shape: {output.shape}")

Handling Key Mismatches

If load_state_dict fails, compare keys:

model_keys = set(model.state_dict().keys())
weight_keys = set(weights.keys())

missing = weight_keys - model_keys
unexpected = model_keys - weight_keys

print(f"Missing in model: {missing}")
print(f"Unexpected in model: {unexpected}")

Verify TorchScript Compatibility Early

If TorchScript export is required, test it early:

# Test scripting works before investing time in training
try:
    scripted = torch.jit.script(model)
    print("TorchScript scripting successful")
except Exception as e:
    print(f"Scripting failed: {e}")
    # Try tracing instead
    traced = torch.jit.trace(model, dummy_input)
    print("TorchScript tracing successful")

Phase 4: Training Specific Layers

Freezing Layers

To train only specific layers, freeze all others:

# Freeze all parameters first
for param in model.parameters():
    param.requires_grad = False

# Unfreeze only target layers
for param in model.output_layer.parameters():
    param.requires_grad = True

# Verify freeze status
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,} parameters")

Computing Baseline Loss

Before training, establish a baseline:

model.eval()
with torch.no_grad():
    outputs = model(inputs)
    original_loss = criterion(outputs, targets)
    print(f"Original MSE loss: {original_loss.item()}")

Training Loop Considerations

# Create optimizer only for trainable parameters
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=0.001
)

# Training with progress tracking
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    outputs = model(inputs)
    loss = criterion(outputs, targets)

    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch}: Loss = {loss.item():.6f}")

Alternative: Closed-Form Solution for Linear Layers

When retraining only a linear output layer, consider a closed-form solution for efficiency:

# Pre-compute frozen layer outputs
model.eval()
with torch.no_grad():
    # Get features before output layer
    features = model.get_features(inputs)  # Shape: [N, d_model]

# Solve linear regression: W*features = targets
# Using pseudo-inverse: W = targets @ features.T @ (features @ features.T)^-1
solution = torch.linalg.lstsq(features, targets).solution
model.output_layer.weight.data = solution.T

Phase 5: TorchScript Export

Saving the Model

# Ensure model is in eval mode
model.eval()

# Script the model (preferred for control flow)
scripted_model = torch.jit.script(model)
scripted_model.save('/app/model.pt')

# Or trace the model (for simpler models)
traced_model = torch.jit.trace(model, example_input)
traced_model.save('/app/model.pt')

Verify Saved Model

# Reload and verify
loaded = torch.jit.load('/app/model.pt')
loaded.eval()

with torch.no_grad():
    original_out = model(test_input)
    loaded_out = loaded(test_input)

    diff = (original_out - loaded_out).abs().max()
    print(f"Max difference: {diff.item()}")
    assert diff < 1e-5, "Model outputs don't match!"

Environment Considerations

Handling Slow Environments

When operating in resource-constrained environments:

  1. Benchmark first: Test basic operations before committing to full solution

    import time
    start = time.time()
    _ = model(torch.randint(0, vocab_size, (1, 10)))
    print(f"Single forward pass: {time.time() - start:.2f}s")
    
  2. Reduce batch size: Process samples individually if needed

  3. Set realistic timeouts: Base on benchmarks, not arbitrary values

  4. Use incremental checkpoints: Save progress periodically

Memory Management

# Clear GPU cache between operations
torch.cuda.empty_cache()

# Use gradient checkpointing for large models
from torch.utils.checkpoint import checkpoint

# Process in smaller batches
for batch in torch.split(data, batch_size):
    process(batch)

Common Pitfalls

  1. Not verifying architecture match before training - Always test load_state_dict first
  2. Arbitrary hyperparameters - Justify choices based on task characteristics
  3. Ignoring TorchScript compatibility - Test export early, not after training
  4. Syntax errors in edits - Review code changes carefully, especially string formatting
  5. Incomplete state dict mapping - Verify all keys are accounted for
  6. Not establishing baseline metrics - Compute original loss before training
  7. Missing torch.no_grad() for inference - Use context manager for evaluation
  8. Forgetting to set model.eval() - Required for consistent behavior in eval/export

Verification Checklist

Before considering the task complete:

  • State dictionary keys fully analyzed and documented
  • Model architecture matches state dict exactly (verified with load_state_dict)
  • Forward pass produces valid output
  • Baseline loss/metric computed
  • Target layers correctly unfrozen, others frozen
  • Training improves loss over baseline
  • TorchScript export succeeds
  • Exported model produces same outputs as original
  • Model saved to required path