Claude Code Plugins

Community-maintained marketplace


module-design-patterns

@tachyon-beep/skillpacks

nn.Module design - hooks, initialization, serialization, device movement


SKILL.md

name module-design-patterns
description nn.Module design - hooks, initialization, serialization, device movement

PyTorch nn.Module Design Patterns

Overview

Core Principle: nn.Module is not just a container for forward passes. It's PyTorch's contract for model serialization, device management, parameter enumeration, and inspection. Follow conventions or face subtle bugs during scaling, deployment, and debugging.

Poor module design shows up as state dict key mismatches, DDP errors, hook-related memory leaks, fragile initialization, and architectures you can't inspect. These bugs often stay silent until you scale up or deploy, so design modules correctly from the start using PyTorch's established patterns.
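To make the "contract" concrete, here is a small sketch. The TinyBlock class is hypothetical and exists only for illustration. It shows what nn.Module tracks automatically once submodules and parameters are assigned as attributes: named parameters, a state_dict() that also carries BatchNorm's running statistics, and device movement through a single .to() call.

```python
import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Hypothetical block used only to show what nn.Module tracks for you."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 8)                # registered submodule
        self.norm = nn.BatchNorm1d(8)              # parameters AND running-stat buffers
        self.scale = nn.Parameter(torch.ones(1))   # registered parameter

    def forward(self, x):
        return self.norm(self.proj(x)) * self.scale


block = TinyBlock()

# Parameter enumeration: everything assigned as a Module or Parameter is found.
param_names = [name for name, _ in block.named_parameters()]
print(param_names)

# Serialization: state_dict() also includes BatchNorm's running statistics.
print(sorted(block.state_dict().keys()))

# Device management: one .to() call moves every registered tensor.
device = "cuda" if torch.cuda.is_available() else "cpu"
block = block.to(device)
print({p.device.type for p in block.parameters()})
```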

When to Use

Use this skill when:

  • Implementing custom nn.Module subclasses
  • Adding forward/backward hooks for feature extraction or debugging
  • Designing modular architectures with swappable components
  • Implementing custom weight initialization strategies
  • Building reusable model components (blocks, layers, heads)
  • Encountering state dict issues, DDP failures, or hook problems

Don't use when:

  • Simple model composition (stack existing modules)
  • Training loop issues (use training-optimization)
  • Memory debugging unrelated to modules (use tensor-operations-and-memory)

Symptoms triggering this skill:

  • "State dict keys don't match after loading"
  • "DDP not syncing gradients properly"
  • "Hooks causing memory leaks"
  • "Can't move model to device"
  • "Model serialization breaks after changes"
  • "Need to extract intermediate features"
  • "Want to make architecture more modular"

Expert Module Design Patterns

Pattern 1: Always Use nn.Module, Never None

Problem: Conditional module assignment using None breaks PyTorch's module contract.

import torch.nn as nn
import torch.nn.functional as F

# ❌ WRONG: Conditional None assignment
class ResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # PROBLEM: Using None for conditional skip connection
        if stride != 1 or in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, 1, stride, 0)
        else:
            self.skip = None  # ❌ Breaks module enumeration!

    def forward(self, x):
        out = self.bn1(self.conv1(x))
        # Conditional check needed
        if self.skip is not None:
            x = self.skip(x)
        return F.relu(out + x)

Why this is fragile:

  • model.named_modules() and print(model) skip a None attribute, so the visible architecture changes with the constructor arguments
  • There is no module to hook, inspect, or swap: model.skip.register_forward_hook(...) fails with AttributeError when skip is None
  • Every forward() (and any code that walks submodules by name, e.g. for freezing or per-module init) needs an is-not-None branch

✅ CORRECT: Use nn.Identity() for no-op

class ResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # ✅ ALWAYS assign an nn.Module subclass
        if stride != 1 or in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, 1, stride, 0)
        else:
            self.skip = nn.Identity()  # ✅ No-op module

    def forward(self, x):
        out = self.bn1(self.conv1(x))
        # No conditional needed!
        x = self.skip(x)  # Identity passes through unchanged
        return F.relu(out + x)

Why this works:

  • nn.Identity() returns its input unchanged (no-op)
  • Consistent module hierarchy across all code paths: model.skip always exists and appears in named_modules() and print(model)
  • forward() needs no None branch
  • Identity has no parameters, so it adds nothing to state_dict(), .to(device), or DDP's gradient sync
  • Hooks and swaps work regardless of configuration: model.skip.register_forward_hook(...) is always valid

Rule: Never assign None to self.* for modules. Use nn.Identity() for no-ops.
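A runnable sketch of the difference, using two hypothetical classes that mirror the ResNetBlock examples above: the None version has no skip entry in named_modules(), the Identity version always has one, and both produce the same output shape.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BlockWithNone(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride, 1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        needs_proj = stride != 1 or in_ch != out_ch
        self.skip = nn.Conv2d(in_ch, out_ch, 1, stride, 0) if needs_proj else None

    def forward(self, x):
        out = self.bn1(self.conv1(x))
        if self.skip is not None:
            x = self.skip(x)
        return F.relu(out + x)


class BlockWithIdentity(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride, 1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        needs_proj = stride != 1 or in_ch != out_ch
        self.skip = nn.Conv2d(in_ch, out_ch, 1, stride, 0) if needs_proj else nn.Identity()

    def forward(self, x):
        return F.relu(self.bn1(self.conv1(x)) + self.skip(x))


none_block = BlockWithNone(16, 16)
id_block = BlockWithIdentity(16, 16)

# None is stored as a plain attribute, so it never enters the module hierarchy.
print("skip" in dict(none_block.named_modules()))  # False
print("skip" in dict(id_block.named_modules()))    # True

x = torch.randn(2, 16, 8, 8)
print(none_block(x).shape == id_block(x).shape)    # True
```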


Pattern 2: Functional vs Module Operations - When to Use Each

Core Question: When should you use F.relu(x) vs self.relu = nn.ReLU()?

Decision Framework:

| Use functional (F.*) when | Use module (nn.*) when |
|---|---|
| Simple, stateless operations | You need to hook the operation |
| Performance-critical paths | You need to inspect or modify it later |
| Operations inside complex control flow | You want a clear module hierarchy |
| One-off computations | The operation has learnable parameters |
| Loss functions | It's an activation you might swap |

Example: When functional is fine

class SimpleBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, 1, 1)
        self.bn = nn.BatchNorm2d(channels)
        # No need to store ReLU as module for simple blocks

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x)  # ✅ Fine for simple cases

Example: When module storage matters

class FeatureExtractorBlock(nn.Module):
    def __init__(self, channels, activation='relu'):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, 1, 1)
        self.bn = nn.BatchNorm2d(channels)

        # ✅ Store as module for flexibility and inspection
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'gelu':
            self.activation = nn.GELU()
        else:
            self.activation = nn.Identity()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return self.activation(x)  # ✅ Can hook, swap, inspect

Why storing as module matters:

  1. Hooks: Can only hook module operations, not functional

    # ✅ Can register hook
    model.layer1.activation.register_forward_hook(hook_fn)
    
    # ❌ Can't hook F.relu() calls
    
  2. Inspection: Module hierarchy shows architecture

    for name, module in model.named_modules():
        print(f"{name}: {module}")
    # With nn.ReLU: "layer1.activation: ReLU()"
    # With F.relu: activation not shown
    
  3. Modification: Can swap modules after creation

    # ✅ Can replace activation
    model.layer1.activation = nn.GELU()
    
    # ❌ Can't modify F.relu() usage without code changes
    
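  4. Quantization: eager-mode fusion (Conv + BN + ReLU) matches submodules by name, so the ReLU must be an nn.ReLU module

    # ✅ Fusable (assumes submodules named conv, bn, relu; Conv-BN fusion needs eval mode)
    fused = torch.ao.quantization.fuse_modules(model.eval(), [['conv', 'bn', 'relu']])
    
    # ❌ An F.relu() call has no module name to fuse (FX graph-mode quantization can still trace it)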
    

Pattern to follow:

  • Simple internal blocks: Functional is fine
  • Top-level operations you might modify: Use modules
  • When building reusable components: Use modules
  • When unsure: Use modules (negligible overhead)
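Because module-based activations are ordinary attributes, they can be swapped after construction. The helper below, replace_activations, is a hypothetical sketch (not a PyTorch API). It walks the hierarchy with named_children() and replaces every nn.ReLU with a fresh nn.GELU. Note that it cannot see F.relu() calls inside forward().

```python
import torch.nn as nn


def replace_activations(model: nn.Module, old_type=nn.ReLU, new_factory=nn.GELU) -> int:
    """Hypothetical helper: swap every `old_type` submodule for `new_factory()`.

    Returns how many modules were replaced. Only activations stored as
    modules are visible; functional calls in forward() are not.
    """
    replaced = 0
    # Snapshot the module list so the hierarchy isn't mutated while iterating it.
    for parent in list(model.modules()):
        for child_name, child in list(parent.named_children()):
            if isinstance(child, old_type):
                setattr(parent, child_name, new_factory())  # re-registers under the same name
                replaced += 1
    return replaced


model = nn.Sequential(
    nn.Linear(8, 8),
    nn.ReLU(),
    nn.Sequential(nn.Linear(8, 8), nn.ReLU()),
)
count = replace_activations(model)
print(count)  # 2
print(model)
```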

Pattern 3: Modular Design with Substitutable Components

Problem: Hardcoding architecture choices makes variants difficult.

# ❌ WRONG: Hardcoded architecture
class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Hardcoded: ReLU, BatchNorm, specific conv config
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        return x

Problem: To use LayerNorm or GELU, you must copy-paste the class and create a new one.

✅ CORRECT: Modular design with substitutable components

class EncoderBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        norm_layer=nn.BatchNorm2d,  # ✅ Substitutable
        activation=nn.ReLU,          # ✅ Substitutable
        bias=True
    ):
        super().__init__()

        # Use provided norm and activation
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=bias)
        self.norm1 = norm_layer(out_channels)
        self.act1 = activation()

        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=bias)
        self.norm2 = norm_layer(out_channels)
        self.act2 = activation()

    def forward(self, x):
        x = self.act1(self.norm1(self.conv1(x)))
        x = self.act2(self.norm2(self.conv2(x)))
        return x

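# Usage examples:
from functools import partial

# Standard: BatchNorm + ReLU
block1 = EncoderBlock(64, 128)

# LayerNorm-style normalization + GELU: GroupNorm with one group normalizes over (C, H, W).
# nn.LayerNorm(128) would instead normalize the last axis (image width) of the NCHW output.
block2 = EncoderBlock(64, 128, norm_layer=partial(nn.GroupNorm, 1), activation=nn.GELU)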

# No normalization
block3 = EncoderBlock(64, 128, norm_layer=nn.Identity, activation=nn.ReLU)

Advanced: None defaults with factory callables

from typing import Callable, Optional

class EncoderBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,  # None -> BatchNorm2d
        activation: Optional[Callable[[], nn.Module]] = None,     # None -> ReLU
    ):
        super().__init__()

        # Set defaults
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if activation is None:
            activation = nn.ReLU

        # Classes, functools.partial objects, and lambdas are all factories that build
        # a NEW module per call, so norm1 and norm2 never share parameters.
        # Don't pass a module instance: it is callable too, and calling it runs forward().
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.norm1 = norm_layer(out_channels)
        self.act1 = activation()

        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        self.norm2 = norm_layer(out_channels)
        self.act2 = activation()

    def forward(self, x):
        x = self.act1(self.norm1(self.conv1(x)))
        x = self.act2(self.norm2(self.conv2(x)))
        return x

Benefits:

  • One class supports many architectural variants
  • Easy to experiment: swap LayerNorm, GELU, etc.
  • Code reuse without duplication
  • Matches PyTorch's own design (e.g., the norm_layer parameter of torchvision's ResNet)

Pattern: Accept layer constructors as arguments, not hardcoded classes.
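The sketch below exercises a factory-based EncoderBlock with three variants. It assumes functools.partial(nn.GroupNorm, 1) as a LayerNorm-style norm for conv feature maps (a swapped-in choice, since nn.LayerNorm(channels) would normalize the width axis of an NCHW tensor), and checks that each factory call yields a separate module.

```python
from functools import partial

import torch
import torch.nn as nn


class EncoderBlock(nn.Module):
    """Same structure as the substitutable block above: norm/activation are factories."""

    def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d, activation=nn.ReLU):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.norm1 = norm_layer(out_channels)  # fresh instance
        self.act1 = activation()
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        self.norm2 = norm_layer(out_channels)  # another fresh instance
        self.act2 = activation()

    def forward(self, x):
        x = self.act1(self.norm1(self.conv1(x)))
        return self.act2(self.norm2(self.conv2(x)))


x = torch.randn(2, 3, 16, 16)
variants = {
    "batchnorm_relu": EncoderBlock(3, 8),
    # One group => statistics over (C, H, W), a LayerNorm-style choice for conv maps
    "groupnorm_gelu": EncoderBlock(3, 8, norm_layer=partial(nn.GroupNorm, 1), activation=nn.GELU),
    # nn.Identity accepts and ignores constructor arguments
    "no_norm": EncoderBlock(3, 8, norm_layer=nn.Identity),
}
for name, block in variants.items():
    print(name, tuple(block(x).shape))

# Because the block receives a factory, norm1 and norm2 never share parameters.
b = variants["batchnorm_relu"]
print(b.norm1 is b.norm2)  # False
```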


Pattern 4: Proper State Management and __init__ Structure

Core principle: __init__ defines the module's structure, forward defines computation.

class WellStructuredModule(nn.Module):
    """
    Template for well-structured PyTorch modules.
    """

    def __init__(self, config):
        # 1. ALWAYS call super().__init__() first
        super().__init__()

        # 2. Store configuration (for reproducibility/serialization)
        self.config = config

        # 3. Initialize all submodules (parameters registered automatically)
        self._build_layers()

        # 4. Initialize weights AFTER building layers
        self.reset_parameters()

    def _build_layers(self):
        """
        Separate method for building layers (cleaner __init__).
        """
        self.encoder = nn.Sequential(
            nn.Linear(self.config.input_dim, self.config.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.config.hidden_dim, self.config.hidden_dim)
        )

        self.decoder = nn.Linear(self.config.hidden_dim, self.config.output_dim)

        # ✅ Use nn.Identity() when no projection is needed
        # (the identity path requires input_dim == output_dim)
        if self.config.use_projection:
            self.skip = nn.Linear(self.config.input_dim, self.config.output_dim)
        else:
            self.skip = nn.Identity()

    def reset_parameters(self):
        """
        Custom initialization following PyTorch convention.

        This method can be called to re-initialize the module:
        - After creation
        - When loading partial checkpoints
        - For training experiments
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:  # ✅ Check before accessing
                    nn.init.zeros_(module.bias)

    def forward(self, x):
        """
        Forward pass - pure computation, no module construction.
        """
        # ❌ NEVER create modules here!
        # ❌ NEVER assign self.* here!

        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        skip = self.skip(x)

        return decoded + skip

Critical rules:

  1. Never create modules in forward()

    # ❌ WRONG
    def forward(self, x):
        self.temp_layer = nn.Linear(10, 10)  # ❌ Created during forward!
        return self.temp_layer(x)
    

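    Why: A new, randomly initialized layer replaces the old one on every call, so it never trains. An optimizer built before the first forward can't see its parameters, the layer is created on the default device rather than the model's, and DDP, which records parameters when the model is wrapped, doesn't know about it.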

  2. Never use self.* for intermediate results

    # ❌ WRONG
    def forward(self, x):
        self.intermediate = self.encoder(x)  # ❌ Storing as attribute!
        return self.decoder(self.intermediate)
    

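    Why: The attribute keeps the last activation (and its autograd graph) alive after forward() returns, wasting memory, and concurrent forward calls overwrite each other's value.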

  3. All modules defined in __init__

    # ✅ CORRECT
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(10, 10)  # ✅ Defined in __init__
        self.decoder = nn.Linear(10, 10)  # ✅ Defined in __init__
    
    def forward(self, x):
        intermediate = self.encoder(x)  # ✅ Local variable
        return self.decoder(intermediate)
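A quick check of rule 1, using hypothetical BadModule and GoodModule classes: the layer created in forward() doesn't exist until the first call, and each call replaces it with a freshly initialized one.

```python
import torch
import torch.nn as nn


class BadModule(nn.Module):
    def forward(self, x):
        # ❌ A new, randomly initialized layer on every call
        self.temp_layer = nn.Linear(4, 4)
        return self.temp_layer(x)


class GoodModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)  # ✅ Built once, in __init__

    def forward(self, x):
        return self.layer(x)


x = torch.randn(1, 4)

bad = BadModule()
print(len(list(bad.parameters())))  # 0: nothing an optimizer could be given yet
bad(x)
first_layer = bad.temp_layer
bad(x)
print(first_layer is bad.temp_layer)  # False: a brand-new layer replaced it
print(len(list(bad.parameters())))    # 2 (weight, bias) of the newest layer only

good = GoodModule()
print(len(list(good.parameters())))   # 2, available before training starts
```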
    

Hook Management Best Practices

Pattern 5: Forward Hooks for Feature Extraction

Problem: Naive hook usage causes memory leaks and handle management issues.

# ❌ WRONG: Multiple problems
import torch
import torch.nn as nn

features = {}  # ❌ Global state

def get_features(name):
    def hook(module, input, output):
        features[name] = output  # ❌ Retains computation graph!
    return hook

model = nn.Sequential(...)
# ❌ No handle stored, can't remove
model[2].register_forward_hook(get_features('layer2'))

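output = model(x)  # grad-enabled forward pass, e.g. inside a training loop

# features['layer2'] now references a tensor still attached to the autograd graph

Why this breaks:

  1. Graph retention: with gradients enabled, the stored output keeps its autograd graph (and the activations saved for backward) alive for as long as features holds it. Forward hooks run in the caller's grad mode, so under torch.no_grad() there is no graph to retain; the problem shows up in grad-enabled passes.
  2. Global state: Not thread-safe, can't have multiple concurrent extractions
  3. No cleanup: the handle was discarded, so the hook can't be removed and keeps firing on every later forward pass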

✅ CORRECT: Encapsulated hook handler with proper cleanup

class FeatureExtractor:
    """
    Proper feature extraction using forward hooks.

    Example:
        extractor = FeatureExtractor(model, layers=['layer2', 'layer3'])
        with extractor:
            output = model(input)
        features = extractor.features  # Dict of detached tensors
    """

    def __init__(self, model, layers):
        self.model = model
        self.layers = layers
        self.features = {}
        self.handles = []  # ✅ Store handles for cleanup

    def _make_hook(self, name):
        def hook(module, input, output):
            # ✅ CRITICAL: Detach and optionally clone
            self.features[name] = output.detach()
            # If the output might be modified in place later, use .clone():
            # self.features[name] = output.detach().clone()
        return hook

    def __enter__(self):
        """Register hooks when entering context."""
        self.features.clear()

        for name, module in self.model.named_modules():
            if name in self.layers:
                handle = module.register_forward_hook(self._make_hook(name))
                self.handles.append(handle)  # ✅ Store handle

        return self

    def __exit__(self, *args):
        """Clean up hooks when exiting context."""
        # ✅ CRITICAL: Remove all hooks
        for handle in self.handles:
            handle.remove()
        self.handles.clear()

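# Usage (assumes torchvision is installed)
import torch
from torchvision.models import resnet50

model = resnet50()
extractor = FeatureExtractor(model, layers=['layer2', 'layer3', 'layer4'])
input_tensor = torch.randn(1, 3, 224, 224)

with extractor:
    output = model(input_tensor)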

# Features extracted and hooks cleaned up
pyramid_features = [
    extractor.features['layer2'],
    extractor.features['layer3'],
    extractor.features['layer4']
]

Key points:

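  • ✅ Encapsulated in a class (no global state)
  • ✅ output.detach() drops the reference to the autograd graph (prevents graph retention)
  • ✅ Hook handles stored and removed (no persistent hooks)
  • ✅ Context manager ensures cleanup even if an error occurs
  • ✅ Independent state: each extractor has its own features dict (the hooks still live on the shared model, so concurrent forward passes on one model are not isolated)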
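The sketch below shows when graph retention actually happens. Forward hooks run inside the caller's grad mode, so an undetached output carries a grad_fn during a normal (grad-enabled) forward pass but not under torch.no_grad(); detach() removes the reference in both cases. The two hooks are illustrative only.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
captured = {}


def raw_hook(module, inputs, output):
    captured["raw"] = output  # keeps the autograd graph when grad is enabled


def detached_hook(module, inputs, output):
    captured["detached"] = output.detach()  # never references the graph


handles = [
    model[0].register_forward_hook(raw_hook),
    model[0].register_forward_hook(detached_hook),
]
try:
    x = torch.randn(3, 4)

    model(x)  # grad-enabled forward, as in a training step
    print(captured["raw"].grad_fn is not None)   # True: graph is referenced
    print(captured["detached"].grad_fn is None)  # True

    with torch.no_grad():
        model(x)  # hooks run inside the no_grad context too
    print(captured["raw"].grad_fn is None)       # True: nothing to retain
finally:
    for h in handles:
        h.remove()
```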

Pattern 6: When to Detach vs Clone in Hooks

Question: Should hooks detach, clone, or both?

Decision framework:

def hook(module, input, output):
    # Decision tree:

    # 1. Just reading output, no modifications?
    self.features[name] = output.detach()  # ✅ Sufficient

    # 2. Output might be modified in-place later?
    self.features[name] = output.detach().clone()  # ✅ Safer

    # 3. Need gradients for analysis (rare)?
    self.features[name] = output  # ⚠️ Dangerous, ensure short lifetime

Example: When clone matters

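# Scenario: an in-place op overwrites the hooked output after the hook returns
features = {}

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 10)    # Hook here
        self.relu = nn.ReLU(inplace=True)  # ❌ Modifies layer1's output in place!
        self.layer2 = nn.Linear(10, 10)

    def forward(self, x):
        x = self.layer1(x)
        x = self.relu(x)
        return self.layer2(x)

# ❌ WRONG: Detach without clone
def hook(module, input, output):
    features['layer1'] = output.detach()  # Still shares memory!

# After forward pass:
# features['layer1'] holds the ReLU'd values, not layer1's raw output!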

# ✅ CORRECT: Clone to get independent copy
def hook(module, input, output):
    features['layer1'] = output.detach().clone()  # Independent copy

Rule of thumb:

  • Detach only: Reading features for analysis, no in-place ops
  • Detach + clone: Features might be modified, or unsure
  • Neither: Only if you need gradients (rare, risky)
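A deterministic sketch of the in-place case: layer1's weights are fixed to the negative identity (an assumption chosen so every output is negative), and the following nn.ReLU(inplace=True) zeroes that output after the hook has run. The detached tensor sees the zeros; the clone keeps layer1's real output.

```python
import torch
import torch.nn as nn

layer1 = nn.Linear(4, 4, bias=False)
with torch.no_grad():
    layer1.weight.copy_(-torch.eye(4))  # output = -input, all negative for an input of ones

model = nn.Sequential(layer1, nn.ReLU(inplace=True), nn.Linear(4, 2))
features = {}


def hook(module, inputs, output):
    features["detached"] = output.detach()        # shares storage with output
    features["cloned"] = output.detach().clone()  # independent copy


handle = layer1.register_forward_hook(hook)
try:
    with torch.no_grad():
        model(torch.ones(1, 4))
finally:
    handle.remove()

print(features["cloned"])    # all -1: layer1's real output
print(features["detached"])  # all 0: overwritten by ReLU(inplace=True)
```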

Pattern 7: Backward Hooks for Gradient Inspection

Use case: Debugging gradient flow, detecting vanishing/exploding gradients.

class GradientInspector:
    """
    Inspect gradients during backward pass.

    Example:
        inspector = GradientInspector(model, layers=['layer1', 'layer2'])
        with inspector:
            output = model(input)
            loss = criterion(output, target)
            loss.backward()

        # Check gradient statistics
        for name, stats in inspector.grad_stats.items():
            print(f"{name}: mean={stats['mean']:.4f}, std={stats['std']:.4f}")
    """

    def __init__(self, model, layers):
        self.model = model
        self.layers = layers
        self.grad_stats = {}
        self.handles = []

    def _make_hook(self, name):
        def hook(module, grad_input, grad_output):
            # grad_output: gradients w.r.t. this module's outputs (arriving from later layers)
            # grad_input: gradients w.r.t. this module's inputs (passed on to earlier layers)

            # Check grad_output (most common)
            if grad_output[0] is not None:
                grad = grad_output[0].detach()
                self.grad_stats[name] = {
                    'mean': grad.abs().mean().item(),
                    'std': grad.std().item(),
                    'max': grad.abs().max().item(),
                    'min': grad.abs().min().item(),
                }
        return hook

    def __enter__(self):
        self.grad_stats.clear()

        for name, module in self.model.named_modules():
            if name in self.layers:
                handle = module.register_full_backward_hook(self._make_hook(name))
                self.handles.append(handle)

        return self

    def __exit__(self, *args):
        for handle in self.handles:
            handle.remove()
        self.handles.clear()

# Usage for gradient debugging
model = MyModel()
inspector = GradientInspector(model, layers=['encoder.layer1', 'decoder.layer1'])

with inspector:
    output = model(input)
    loss = criterion(output, target)
    loss.backward()

# Check for vanishing/exploding gradients
for name, stats in inspector.grad_stats.items():
    if stats['mean'] < 1e-7:
        print(f"⚠️ Vanishing gradient in {name}")
    if stats['mean'] > 100:
        print(f"⚠️ Exploding gradient in {name}")

Critical differences from forward hooks:

  • Backward hooks run during .backward(): Not during forward pass
  • Receive gradient tensors: Not activations
  • Used for gradient analysis: Not feature extraction
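A minimal sketch confirming the timing and the shape of grad_output with register_full_backward_hook: nothing fires during the forward pass, and once backward() runs the hook receives gradients shaped like the module's output. The model and input sizes are arbitrary.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 3), nn.Tanh(), nn.Linear(3, 1))
calls = []


def grad_hook(module, grad_input, grad_output):
    # grad_output[0] has the same shape as the module's forward output
    calls.append(tuple(grad_output[0].shape))


handle = model[0].register_full_backward_hook(grad_hook)
try:
    x = torch.randn(5, 4, requires_grad=True)
    loss = model(x).sum()
    print(len(calls))  # 0: the forward pass does not trigger backward hooks
    loss.backward()
    print(calls)       # [(5, 3)]
finally:
    handle.remove()
```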

Pattern 8: Hook Handle Management Patterns

Never do this:

# ❌ WRONG: No handle stored
model.layer.register_forward_hook(hook_fn)
# Hook persists forever, can't remove!

Three patterns for handle management:

Pattern A: Context manager (recommended for temporary hooks)

class HookManager:
    def __init__(self, module, hook_fn):
        self.module = module
        self.hook_fn = hook_fn
        self.handle = None

    def __enter__(self):
        self.handle = self.module.register_forward_hook(self.hook_fn)
        return self

    def __exit__(self, *args):
        if self.handle:
            self.handle.remove()

# Usage
with HookManager(model.layer1, my_hook):
    output = model(input)
# Hook automatically removed

Pattern B: Explicit cleanup (for long-lived hooks)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 10)
        self.hook_handle = self.layer.register_forward_hook(self._debug_hook)

    def _debug_hook(self, module, input, output):
        print(f"Output shape: {output.shape}")

    def remove_hooks(self):
        """Explicit cleanup method."""
        if self.hook_handle:
            self.hook_handle.remove()
            self.hook_handle = None

# Usage
model = Model()
# ... use model ...
model.remove_hooks()  # Clean up before saving or finishing

Pattern C: List of handles (multiple hooks)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(5)])
        self.hook_handles = []

    def register_debug_hooks(self):
        """Register hooks on all layers."""
        for i, layer in enumerate(self.layers):
            handle = layer.register_forward_hook(
                lambda m, inp, out, idx=i: print(f"Layer {idx}: {out.shape}")
            )
            self.hook_handles.append(handle)

    def remove_all_hooks(self):
        """Remove all registered hooks."""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles.clear()

Critical rule: Every register_*_hook() call MUST have a corresponding handle.remove().
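The rule in its smallest form, as a sketch: pair the registration with try/finally so the handle is removed even if a forward pass raises, then confirm the hook stops firing.

```python
import torch
import torch.nn as nn

layer = nn.Linear(2, 2)
calls = []


def counting_hook(module, inputs, output):
    calls.append(output.shape)


handle = layer.register_forward_hook(counting_hook)
try:
    layer(torch.randn(1, 2))
    layer(torch.randn(1, 2))
finally:
    handle.remove()  # runs even if one of the forward passes above raises

layer(torch.randn(1, 2))  # the hook is gone, so this call is not recorded
print(len(calls))  # 2
```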


Weight Initialization Patterns

Pattern 9: The reset_parameters() Convention

PyTorch convention: Custom initialization goes in reset_parameters(), called from __init__.

# ❌ WRONG: Initialization in __init__ after submodule creation
class CustomModule(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()

        self.linear1 = nn.Linear(in_dim, out_dim)
        self.linear2 = nn.Linear(out_dim, out_dim)

        # ❌ Initializing here is fragile
        nn.init.xavier_uniform_(self.linear1.weight)
        nn.init.xavier_uniform_(self.linear2.weight)
        # What if linear has bias=False? This crashes:
        nn.init.zeros_(self.linear1.bias)  # ❌ AttributeError if bias=False

Problems:

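  1. No named entry point: nn.Linear has already run its own reset_parameters(), and these lines overwrite that init inline where nothing else can reuse them
  2. Can't re-initialize later: model.reset_parameters() raises AttributeError because the method doesn't exist
  3. Fragile: assumes bias exists
  4. Violates PyTorch convention (built-in layers such as nn.Linear and nn.Conv2d each define reset_parameters())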

✅ CORRECT: Define reset_parameters() method

class CustomModule(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()

        self.linear1 = nn.Linear(in_dim, out_dim)
        self.linear2 = nn.Linear(out_dim, out_dim)

        # ✅ Call reset_parameters at end of __init__
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize module parameters.

        Following PyTorch convention, this method:
        - Can be called to re-initialize the module
        - Is called automatically at end of __init__
        - Allows for custom initialization strategies
        """
        # ✅ Defensive: check if bias exists
        nn.init.xavier_uniform_(self.linear1.weight)
        if self.linear1.bias is not None:
            nn.init.zeros_(self.linear1.bias)

        nn.init.xavier_uniform_(self.linear2.weight)
        if self.linear2.bias is not None:
            nn.init.zeros_(self.linear2.bias)

    def forward(self, x):
        return self.linear2(F.relu(self.linear1(x)))

# Benefits:
# 1. Can re-initialize: model.reset_parameters()
# 2. Defensive checks for optional bias
# 3. Follows PyTorch convention
# 4. Clear separation: __init__ defines structure, reset_parameters initializes
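A sketch of the re-initialization benefit. It adds a bias flag to CustomModule (an assumption, so the defensive checks are exercised): perturbed biases are reset by calling reset_parameters() again, and the bias=False variant builds without an AttributeError.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CustomModule(nn.Module):
    # Same structure as above, plus a hypothetical `bias` flag.
    def __init__(self, in_dim, out_dim, bias=True):
        super().__init__()
        self.linear1 = nn.Linear(in_dim, out_dim, bias=bias)
        self.linear2 = nn.Linear(out_dim, out_dim, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        for linear in (self.linear1, self.linear2):
            nn.init.xavier_uniform_(linear.weight)
            if linear.bias is not None:  # defensive check
                nn.init.zeros_(linear.bias)

    def forward(self, x):
        return self.linear2(F.relu(self.linear1(x)))


model = CustomModule(8, 4)
with torch.no_grad():
    model.linear1.bias.fill_(3.0)  # simulate drifted weights
model.reset_parameters()           # re-run the same init on demand
print(model.linear1.bias.abs().sum().item())  # 0.0

no_bias = CustomModule(8, 4, bias=False)  # no AttributeError thanks to the checks
print(no_bias.linear1.bias)               # None
```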

Pattern 10: Hierarchical Initialization

Pattern: When modules contain submodules, iterate through hierarchy.

class ComplexModel(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Linear(config.input_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Linear(config.hidden_dim, config.hidden_dim)
        )

        self.attention = nn.MultiheadAttention(config.hidden_dim, config.num_heads)

        self.decoder = nn.Linear(config.hidden_dim, config.output_dim)

        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize all submodules hierarchically.
        """
        # Method 1: Iterate through all modules
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
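            elif isinstance(module, nn.MultiheadAttention):
                # MultiheadAttention's packed in_proj weights are not an nn.Linear, so the
                # branch above misses them; re-run its own initializer instead.
                # Note: _reset_parameters() is private and may change between releases.
                module._reset_parameters()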

        # Method 2: Specific initialization for specific layers
        # Override general initialization for decoder
        nn.init.xavier_uniform_(self.decoder.weight, gain=0.5)

Two strategies:

  1. Uniform initialization: Iterate all modules, apply same rules

    for module in self.modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.kaiming_normal_(module.weight)
    
  2. Layered initialization: Different rules for different components

    def reset_parameters(self):
        # Encoder: Xavier
        for module in self.encoder.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
    
        # Decoder: Xavier with small gain
        nn.init.xavier_uniform_(self.decoder.weight, gain=0.5)
    

Defensive checks:

def reset_parameters(self):
    for module in self.modules():
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)

            # ✅ Always check for bias
            if module.bias is not None:
                nn.init.zeros_(module.bias)

        elif isinstance(module, nn.BatchNorm2d):
            # BatchNorm's weight/bias are a per-channel scale and shift;
            # ✅ they are None when affine=False, so check first
            if module.affine:
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
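An equivalent way to apply the same defensive rules is nn.Module.apply(), which calls a function on every submodule (children first, then the module itself). init_weights below is a hypothetical helper; the affine check keeps it from touching a BatchNorm whose weight and bias are None.

```python
import torch.nn as nn


def init_weights(module: nn.Module) -> None:
    """Hypothetical per-module init function for use with nn.Module.apply()."""
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)) and module.affine:
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)


model = nn.Sequential(
    nn.Linear(8, 16),
    nn.BatchNorm1d(16),
    nn.Linear(16, 16, bias=False),
    nn.BatchNorm1d(16, affine=False),  # weight and bias are None here
)
model.apply(init_weights)  # recurses through every submodule
```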

Pattern 11: Initialization with Learnable Parameters

Use case: Custom parameters that need special initialization.

import math

import torch
import torch.nn as nn

class AttentionWithTemperature(nn.Module):
    def __init__(self, d_model, d_k):
        super().__init__()

        self.d_k = d_k

        self.query = nn.Linear(d_model, d_k)
        self.key = nn.Linear(d_model, d_k)
        self.value = nn.Linear(d_model, d_k)
        self.output = nn.Linear(d_k, d_model)

        # ✅ Learnable temperature parameter
        # Initialize to 1/sqrt(d_k), but make it learnable
        self.temperature = nn.Parameter(torch.ones(1))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize all parameters."""
        # Standard initialization for linear layers
        for linear in [self.query, self.key, self.value]:
            nn.init.xavier_uniform_(linear.weight)
            if linear.bias is not None:
                nn.init.zeros_(linear.bias)

        # Output projection with smaller gain
        nn.init.xavier_uniform_(self.output.weight, gain=0.5)
        if self.output.bias is not None:
            nn.init.zeros_(self.output.bias)

        # ✅ Custom parameter initialization
        nn.init.constant_(self.temperature, 1.0 / math.sqrt(self.d_k))

    def forward(self, x):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.temperature
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)

        return self.output(out)

Key points:

  • Custom parameters defined with nn.Parameter()
  • Initialized in reset_parameters() like other parameters
  • Can use nn.init.* functions on parameters
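To confirm that a custom nn.Parameter joins the module contract like any built-in weight, this sketch uses ScaledScores, a hypothetical cut-down AttentionWithTemperature: the temperature appears in named_parameters() and state_dict() and starts at 1/sqrt(d_k).

```python
import math

import torch
import torch.nn as nn


class ScaledScores(nn.Module):
    """Hypothetical cut-down AttentionWithTemperature: only q/k and the temperature."""

    def __init__(self, d_model, d_k):
        super().__init__()
        self.d_k = d_k
        self.query = nn.Linear(d_model, d_k)
        self.key = nn.Linear(d_model, d_k)
        self.temperature = nn.Parameter(torch.ones(1))  # registered: trained, saved, moved
        self.reset_parameters()

    def reset_parameters(self):
        for linear in (self.query, self.key):
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)
        nn.init.constant_(self.temperature, 1.0 / math.sqrt(self.d_k))

    def forward(self, x):
        return torch.matmul(self.query(x), self.key(x).transpose(-2, -1)) * self.temperature


m = ScaledScores(d_model=32, d_k=16)
print("temperature" in dict(m.named_parameters()))  # True: the optimizer will update it
print("temperature" in m.state_dict())               # True: saved with checkpoints
print(m.temperature.item())                          # 0.25 == 1/sqrt(16)
```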

Common Pitfalls

Consolidated Pitfalls Table

| # | Pitfall | Symptom | Root Cause | Fix |
|---|---|---|---|---|
| 1 | Using self.x = None for conditional modules | Slot missing from named_modules(), can't hook or swap it, None checks in forward() | None is not an nn.Module | Use nn.Identity() |
| 2 | Using functional ops when hooks/inspection needed | Can't hook activations, architecture invisible | Functional calls bypass the module hierarchy | Store as self.activation = nn.ReLU() |
| 3 | Hooks retaining computation graphs | Memory grows during grad-enabled feature extraction | Hook doesn't detach outputs | Use output.detach() in hook |
| 4 | No hook handle cleanup | Hooks persist, memory leak, unexpected behavior | Handles not stored/removed | Store handles, call handle.remove() |
| 5 | Global state in hook closures | Not thread-safe, coupling issues | Mutable global variables | Encapsulate in a class |
| 6 | Initialization in __init__ instead of reset_parameters() | Can't re-initialize, fragile timing | Violates PyTorch convention | Define reset_parameters() |
| 7 | Accessing bias without checking existence | Crashes with AttributeError | Assumes bias always exists | Check if module.bias is not None: |
| 8 | Creating modules in forward() | Layer re-initialized every call; optimizer and DDP never see it | Structure built in forward() instead of __init__ | Move to __init__, use local vars |
| 9 | Storing intermediate results as self.* | Extra memory held after forward(), not thread-safe | Attribute keeps activation and graph alive | Use local variables only |
| 10 | Not using context managers for hooks | Hooks not cleaned up on error | Missing try/finally | Use __enter__/__exit__ pattern |

Pitfall 1: Conditional None Assignment

# ❌ WRONG
class Block(nn.Module):
    def __init__(self, use_projection):
        super().__init__()
        self.layer = nn.Linear(10, 10)
        self.skip = nn.Linear(10, 10) if use_projection else None  # ❌

    def forward(self, x):
        out = self.layer(x)
        shortcut = self.skip(x) if self.skip is not None else x
        return out + shortcut

# ✅ CORRECT
class Block(nn.Module):
    def __init__(self, use_projection):
        super().__init__()
        self.layer = nn.Linear(10, 10)
        self.skip = nn.Linear(10, 10) if use_projection else nn.Identity()  # ✅

    def forward(self, x):
        out = self.layer(x)
        return out + self.skip(x)  # No conditional needed

Symptom: The shortcut slot vanishes from named_modules() and print(model), can't be hooked or swapped, and forward() needs None checks
Fix: Always use nn.Identity() for no-op modules
Caveat: nn.Identity() passes x through, so it stands in for a projection on a residual path, not for "no residual at all".


Pitfall 2: Functional Ops Preventing Hooks

# ❌ WRONG: Can't hook ReLU
class Encoder(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        x = self.linear(x)
        return F.relu(x)  # ❌ Can't hook this!

# Can't do this:
# encoder.relu.register_forward_hook(hook)  # AttributeError!

# ✅ CORRECT: Hookable activation
class Encoder(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.relu = nn.ReLU()  # ✅ Stored as module

    def forward(self, x):
        x = self.linear(x)
        return self.relu(x)

# ✅ Now can hook:
encoder.relu.register_forward_hook(hook)

Symptom: Can't register hooks on operations
Fix: Store operations as modules when you need inspection/hooks


Pitfall 3: Hook Memory Leak

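import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
features = {}

# ❌ WRONG: Hook retains graph
def hook(module, input, output):
    features['layer'] = output  # ❌ Keeps output's autograd graph alive!

handle = model[0].register_forward_hook(hook)
output = model(torch.randn(4, 10))  # Forward pass with autograd enabled
# features['layer'].grad_fn is not None: the graph and the tensors it saved
# for backward stay in memory as long as the dict holds this tensor
handle.remove()

# ✅ CORRECT: Detach in hook
def hook(module, input, output):
    features['layer'] = output.detach()  # ✅ Breaks graph

# Even better: Clone if the output might be modified in-place later
def hook(module, input, output):
    features['layer'] = output.detach().clone()  # ✅ Independent copy

Symptom: Memory grows during feature extraction whenever forward passes run with autograd enabled (torch.no_grad() builds no graph, but the same hooks often run during training) Fix: Always .detach() in hooks (and .clone() if needed)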
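To confirm what a hook is actually holding on to, inspect `grad_fn`: a tensor with a non-None `grad_fn` keeps its autograd graph (and the tensors saved for backward) alive. The minimal check below uses a throwaway two-layer model for illustration.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
captured = {}


def raw_hook(module, inputs, output):
    captured['raw'] = output  # Keeps the graph alive


def detached_hook(module, inputs, output):
    captured['detached'] = output.detach()  # No graph attached


h1 = model[0].register_forward_hook(raw_hook)
h2 = model[0].register_forward_hook(detached_hook)
model(torch.randn(4, 10))  # Autograd enabled: a graph is built
h1.remove()
h2.remove()

print(captured['raw'].grad_fn is not None)   # True
print(captured['detached'].grad_fn is None)  # True
print(captured['detached'].requires_grad)    # False

# Under torch.no_grad() no graph is built at all
with torch.no_grad():
    out = model(torch.randn(4, 10))
print(out.grad_fn is None)  # True
```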


Pitfall 4: Missing Hook Cleanup

import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU())

def my_hook(module, input, output):
    print(output.shape)

# ❌ WRONG: No handle management
model[0].register_forward_hook(my_hook)
# Without the returned handle there is no public way to remove this hook!

# ✅ CORRECT: Store and clean up handle
class HookManager:
    def __init__(self):
        self.handle = None

    def register(self, module, hook):
        self.handle = module.register_forward_hook(hook)

    def cleanup(self):
        if self.handle is not None:
            self.handle.remove()
            self.handle = None  # Safe to call cleanup() twice

manager = HookManager()
manager.register(model[0], my_hook)
# ... use model ...
manager.cleanup()  # ✅ Remove hook

Symptom: Hooks persist, unexpected behavior, memory leaks Fix: Always store handles and call .remove()
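A context manager (`__enter__`/`__exit__`) makes cleanup automatic, even when the code between registration and removal raises. A minimal sketch, using a hypothetical `FeatureExtractor` class (not a PyTorch API) that detaches captured outputs and removes every handle on exit:

```python
import torch
import torch.nn as nn


class FeatureExtractor:
    """Hypothetical helper: capture detached outputs of named submodules."""

    def __init__(self, model, layer_names):
        self.model = model
        self.layer_names = layer_names
        self.features = {}
        self.handles = []

    def _make_hook(self, name):
        def hook(module, inputs, output):
            self.features[name] = output.detach()
        return hook

    def __enter__(self):
        modules = dict(self.model.named_modules())
        for name in self.layer_names:
            handle = modules[name].register_forward_hook(self._make_hook(name))
            self.handles.append(handle)
        return self

    def __exit__(self, exc_type, exc, tb):
        for handle in self.handles:
            handle.remove()  # Runs even if the with-block raised
        self.handles.clear()
        return False  # Don't swallow exceptions


model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
with FeatureExtractor(model, ['0', '2']) as extractor:
    model(torch.randn(3, 10))

print({k: tuple(v.shape) for k, v in extractor.features.items()})
# {'0': (3, 20), '2': (3, 5)}
```

Encapsulating the captured dict in the object also avoids the global-variable closures flagged in the checklist later in this skill.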


Pitfall 5: Initialization Timing

# ❌ WRONG: Inline init in __init__ (fragile)
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)  # Already initialized!

        # This works but is fragile:
        nn.init.xavier_uniform_(self.linear.weight)  # Overwrites default init

# ✅ CORRECT: Init in reset_parameters()
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)
        self.reset_parameters()  # ✅ Clear separation

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.linear.weight)
        if self.linear.bias is not None:  # ✅ Defensive
            nn.init.zeros_(self.linear.bias)

Symptom: Can't re-initialize, crashes on bias=False Fix: Define reset_parameters(), call from __init__
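Because initialization lives in one method, re-initializing (for example between experiment runs) is a single call, and seeding makes it reproducible. The sketch adds a `bias` argument to the `MyModule` above, an illustrative addition, to show the defensive check working with `bias=False`.

```python
import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.linear = nn.Linear(10, 10, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.linear.weight)
        if self.linear.bias is not None:  # bias=False leaves bias as None
            nn.init.zeros_(self.linear.bias)


module = MyModule(bias=False)  # No crash: the bias check skips it

torch.manual_seed(0)
module.reset_parameters()
first = module.linear.weight.detach().clone()

torch.manual_seed(0)
module.reset_parameters()
second = module.linear.weight.detach().clone()

print(torch.equal(first, second))  # True: same seed, same weights
```

One caution for a generic "re-initialize everything" helper: `model.modules()` visits a parent before its children, so calling every submodule's `reset_parameters()` would let `nn.Linear`'s default initialization overwrite the parent's Xavier init. Calling the top-level module's own method, as above, avoids that.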


Red Flags - Stop and Reconsider

If you catch yourself doing ANY of these, STOP and follow patterns:

| Red Flag Action | Reality | What to Do Instead |
| --- | --- | --- |
| "I'll assign None to this module attribute" | Breaks PyTorch's module contract | Use nn.Identity() |
| "F.relu() is simpler than nn.ReLU()" | True, but prevents inspection/hooks | Use module if you might need hooks |
| "I'll store hook output directly" | Retains computation graph | Always .detach() first |
| "I don't need to store the hook handle" | Can't remove hook later | Always store handles |
| "I'll just initialize inline in __init__" | Can't re-initialize later | Use reset_parameters() |
| "Bias always exists, right?" | No! bias=False is common | Check if bias is not None: |
| "I'll save intermediate results as self.*" | Memory leak, not thread-safe | Use local variables only |
| "I'll create this module in forward()" | Parameters not registered | All modules in __init__ |

Critical rule: Follow PyTorch conventions or face subtle bugs in production.


Complete Example: Well-Designed ResNet Block

import torch
import torch.nn as nn

class ResNetBlock(nn.Module):
    """
    Well-designed ResNet block following all best practices.

    Features:
    - Substitutable norm and activation layers
    - Proper use of nn.Identity() for skip connections
    - Hook-friendly (all operations are modules)
    - Correct initialization via reset_parameters()
    - Defensive bias checking
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride=1,
        norm_layer=nn.BatchNorm2d,
        activation=nn.ReLU,
        bias=False  # Usually False with BatchNorm
    ):
        super().__init__()

        # Store config for potential serialization
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

        # Main path: conv -> norm -> activation -> conv -> norm
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=bias)
        self.norm1 = norm_layer(out_channels)
        self.act1 = activation()

        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=bias)
        self.norm2 = norm_layer(out_channels)

        # Skip connection (dimension matching)
        # ✅ CRITICAL: Use nn.Identity(), never None
        if stride != 1 or in_channels != out_channels:
            self.skip = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, 0, bias=bias),
                norm_layer(out_channels)
            )
        else:
            self.skip = nn.Identity()

        # Final activation (applied after residual addition)
        self.act2 = activation()

        # ✅ Initialize weights following convention
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize weights using He initialization (good for ReLU).
        """
        # Iterate through all conv layers
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                # He initialization for ReLU
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                # ✅ Defensive: check bias exists
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

            elif isinstance(module, nn.BatchNorm2d):
                # BatchNorm standard initialization
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """
        Forward pass: residual connection with skip path.

        Note: All operations are modules, so can be hooked or modified.
        """
        # Main path
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.act1(out)  # ✅ Module, not F.relu()

        out = self.conv2(out)
        out = self.norm2(out)

        # Skip connection (always works, no conditional)
        skip = self.skip(x)  # ✅ Identity passes through if no projection needed

        # Residual addition and final activation
        out = out + skip
        out = self.act2(out)  # ✅ Module, not F.relu()

        return out

# Usage examples:
# Standard ResNet block
block1 = ResNetBlock(64, 128, stride=2)

# With GroupNorm and GELU (GroupNorm also needs num_groups, so wrap it)
block2 = ResNetBlock(64, 128, norm_layer=lambda c: nn.GroupNorm(32, c), activation=nn.GELU)

# Can hook any operation:
handle = block1.act1.register_forward_hook(lambda m, i, o: print(f"ReLU output shape: {o.shape}"))

# Can re-initialize:
block1.reset_parameters()

# Can inspect architecture:
for name, module in block1.named_modules():
    print(f"{name}: {module}")

Why this design is robust:

  1. ✅ No None assignments (uses nn.Identity())
  2. ✅ All operations are modules (hookable)
  3. ✅ Substitutable components (norm, activation)
  4. ✅ Proper initialization (reset_parameters())
  5. ✅ Defensive bias checking
  6. ✅ Clear module hierarchy
  7. ✅ Configuration stored (reproducibility)
  8. ✅ No hardcoded norm or activation choices
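The substitutable `norm_layer` argument in the block above expects a callable that takes only the channel count. Norm layers with extra required arguments, such as `nn.GroupNorm(num_groups, num_channels)`, can be adapted with `functools.partial`; the group count of 32 and the `factories` dict below are illustrative choices.

```python
from functools import partial

import torch
import torch.nn as nn

# Each factory takes only the channel count, like norm_layer(out_channels)
factories = {
    'batch': nn.BatchNorm2d,
    'group': partial(nn.GroupNorm, 32),  # -> nn.GroupNorm(32, channels)
    'instance': partial(nn.InstanceNorm2d, affine=True),
}

x = torch.randn(2, 128, 8, 8)
shapes = {}
for name, factory in factories.items():
    norm = factory(128)
    shapes[name] = (type(norm).__name__, tuple(norm(x).shape))

print(shapes)
```

`partial` objects also pickle cleanly, unlike lambdas, which can matter if the model config is serialized.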

Edge Cases and Advanced Scenarios

Edge Case 1: Dynamic Module Lists (nn.ModuleList)

Scenario: Need variable number of layers based on config.

# ❌ WRONG: Using Python list for modules
class DynamicModel(nn.Module):
    def __init__(self, num_layers):
        super().__init__()
        self.layers = []  # ❌ Python list, parameters not registered!
        for i in range(num_layers):
            self.layers.append(nn.Linear(10, 10))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

# model.parameters() is empty: the optimizer, .to(device), state_dict() and DDP all miss these layers!

# ✅ CORRECT: Use nn.ModuleList
class DynamicModel(nn.Module):
    def __init__(self, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([  # ✅ Registers all parameters
            nn.Linear(10, 10) for _ in range(num_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

Rule: Use nn.ModuleList for lists of modules, nn.ModuleDict for dicts.
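The rule above also covers dictionaries. A minimal sketch of `nn.ModuleDict` holding task-specific heads (the `MultiHeadModel` name and the task keys are hypothetical); every head is registered, so its parameters show up in `parameters()` and `state_dict()` under the dict key.

```python
import torch
import torch.nn as nn


class MultiHeadModel(nn.Module):
    def __init__(self, tasks):
        super().__init__()
        self.backbone = nn.Linear(16, 32)
        self.act = nn.ReLU()
        # ✅ ModuleDict registers each head under its key
        self.heads = nn.ModuleDict(
            {task: nn.Linear(32, n_out) for task, n_out in tasks.items()}
        )

    def forward(self, x, task):
        return self.heads[task](self.act(self.backbone(x)))


model = MultiHeadModel({'classify': 10, 'regress': 1})
head_keys = sorted(k for k in model.state_dict() if k.startswith('heads.'))
print(head_keys)
# ['heads.classify.bias', 'heads.classify.weight', 'heads.regress.bias', 'heads.regress.weight']
out = model(torch.randn(4, 16), 'regress')
print(out.shape)  # torch.Size([4, 1])
```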


Edge Case 2: Hooks on nn.Sequential

Problem: Hooking specific layers inside nn.Sequential.

import torch.nn as nn

def hook(module, input, output):
    print(f"{module.__class__.__name__}: {tuple(output.shape)}")

model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 20),
    nn.ReLU(),
    nn.Linear(20, 10)
)

# ❌ WRONG: Sequential children are named '0', '1', '2', ..., not layer2
# model.layer2.register_forward_hook(hook)  # AttributeError

# ✅ CORRECT: Access by index
handle = model[2].register_forward_hook(hook)  # Third module (Linear 20->20)

# ✅ BETTER: Select by type via named_modules(), keeping every handle
handles = []
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        print(f"Hooking {name}")
        handles.append(module.register_forward_hook(hook))

Best practice: For hookable models, use explicit named attributes instead of Sequential:

class HookableModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 20)
        self.act1 = nn.ReLU()
        self.layer2 = nn.Linear(20, 20)  # ✅ Named, easy to hook
        self.act2 = nn.ReLU()
        self.layer3 = nn.Linear(20, 10)

    def forward(self, x):
        x = self.act1(self.layer1(x))
        x = self.act2(self.layer2(x))
        return self.layer3(x)

# Easy to hook specific layers:
hookable = HookableModel()
handle = hookable.layer2.register_forward_hook(hook)
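If you prefer to keep `nn.Sequential`, it also accepts an `OrderedDict`, which gives each layer a real attribute name that you can hook directly. The layer names below are illustrative.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

model = nn.Sequential(OrderedDict([
    ('layer1', nn.Linear(10, 20)),
    ('act1', nn.ReLU()),
    ('layer2', nn.Linear(20, 20)),
    ('act2', nn.ReLU()),
    ('layer3', nn.Linear(20, 10)),
]))

captured = {}
# dict.update returns None, so the hook does not replace the module's output
handle = model.layer2.register_forward_hook(
    lambda module, inputs, output: captured.update(layer2=output.detach())
)
model(torch.randn(4, 10))
handle.remove()

child_names = [name for name, _ in model.named_children()]
print(child_names)               # ['layer1', 'act1', 'layer2', 'act2', 'layer3']
print(captured['layer2'].shape)  # torch.Size([4, 20])
```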

Edge Case 3: Hooks with In-Place Operations

Problem: In-place operations modify hooked tensors.

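import torch
import torch.nn as nn

class ModelWithInPlace(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 10)
        self.layer2 = nn.Linear(10, 10)

    def forward(self, x):
        x = self.layer1(x)  # Hook here
        x += 10  # ❌ In-place modification of layer1's output!
        x = self.layer2(x)
        return x

features = {}

# Hook that only detaches:
def hook(module, input, output):
    features['layer1'] = output.detach()  # ❌ Still shares memory!

model = ModelWithInPlace()
model.layer1.register_forward_hook(hook)
model(torch.randn(2, 10))
# features['layer1'] now includes the +10: detach() shares storage with output

# ✅ CORRECT: Detach AND clone
def hook(module, input, output):
    features['layer1'] = output.detach().clone()  # ✅ Independent copy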

Decision tree for hooks:

Is output modified in-place later?
├─ Yes → Use .detach().clone()
└─ No → Use .detach() (sufficient)

Need gradients for analysis?
├─ Yes → Don't detach (but ensure short lifetime!)
└─ No → Detach (prevents memory leak)
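For the "need gradients" branch, capturing gradients directly with a backward hook often avoids holding forward outputs at all. A minimal sketch using `register_full_backward_hook`, which calls `hook(module, grad_input, grad_output)` during `backward()`; the gradients are detached before storing, and the model is a throwaway example.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1))
grads = {}


def save_grads(module, grad_input, grad_output):
    grads['input'] = grad_input[0].detach()    # d loss / d (module input)
    grads['output'] = grad_output[0].detach()  # d loss / d (module output)


handle = model[2].register_full_backward_hook(save_grads)
loss = model(torch.randn(4, 10)).sum()
loss.backward()
handle.remove()

print(grads['input'].shape)  # torch.Size([4, 20])
print(grads['output'])       # All ones, since loss is a plain sum
```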

Edge Case 4: Partial State Dict Loading

Scenario: Loading checkpoint with different architecture.

# Original model
class ModelV1(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(10, 20)
        self.decoder = nn.Linear(20, 10)

# New model with additional layer
class ModelV2(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(10, 20)
        self.middle = nn.Linear(20, 20)  # New layer!
        self.decoder = nn.Linear(20, 10)

        self.reset_parameters()

    def reset_parameters(self):
        # ✅ Initialize all layers
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

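# Load V1 checkpoint into V2 model
model_v2 = ModelV2()  # __init__ already ran reset_parameters() on every layer, including middle
checkpoint = torch.load('model_v1.pth')  # Assumes it holds ModelV1().state_dict()

# ✅ Use strict=False for partial loading, then inspect the result
result = model_v2.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)     # ['middle.weight', 'middle.bias']
print(result.unexpected_keys)  # []

# ✅ No extra re-initialization needed: missing keys keep the Xavier init from
# ModelV2.reset_parameters(). (model_v2.middle.reset_parameters() would instead
# apply nn.Linear's default init.)

Pattern: When loading partial checkpoints:

  1. Load with strict=False
  2. Check which keys are missing/unexpected (load_state_dict() returns both lists)
  3. Keep the fresh initialization for new layers; don't call the model's reset_parameters() after loading, since it would overwrite the loaded weights
  4. Expect shape changes to still fail: strict=False does not skip keys whose shapes differ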
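`strict=False` only tolerates missing and unexpected keys; a key present in both with a different shape still raises a size-mismatch `RuntimeError`. One common workaround, sketched here with a hypothetical `load_matching` helper and toy V1/V2 models, is to filter the checkpoint down to shape-compatible entries first.

```python
import torch
import torch.nn as nn


def load_matching(model, state_dict):
    """Hypothetical helper: load only entries whose key and shape match."""
    own = model.state_dict()
    compatible = {
        k: v for k, v in state_dict.items() if k in own and own[k].shape == v.shape
    }
    skipped = sorted(set(state_dict) - set(compatible))
    result = model.load_state_dict(compatible, strict=False)
    return result.missing_keys, skipped


v1 = nn.ModuleDict({'encoder': nn.Linear(10, 20), 'decoder': nn.Linear(20, 10)})
v2 = nn.ModuleDict({'encoder': nn.Linear(10, 20), 'decoder': nn.Linear(20, 5)})  # Output size changed

missing, skipped = load_matching(v2, v1.state_dict())
print(missing)  # ['decoder.weight', 'decoder.bias']
print(skipped)  # ['decoder.bias', 'decoder.weight']
print(torch.equal(v2['encoder'].weight, v1['encoder'].weight))  # True
```

Log `skipped` somewhere visible: silently dropping a layer you expected to load is its own class of bug.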

Edge Case 5: Hook Removal During Forward Pass

Problem: Removing a hook from inside a hook changes the module's hook registry while forward() is still running those hooks.

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 10)
        self.hook_handles = []

    def add_temporary_hook(self):
        def hook(module, input, output):
            print("Hook called!")
            # ❌ WRONG: Removing handles inside the hook
            for h in self.hook_handles:
                h.remove()  # Mutates the hook dict forward() is iterating over (may raise, depending on PyTorch version)

        handle = self.layer.register_forward_hook(hook)
        self.hook_handles.append(handle)

# ✅ CORRECT: Flag for removal, remove after forward pass
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 10)
        self.hook_handles = []
        self.hooks_to_remove = []

    def add_temporary_hook(self):
        def hook(module, input, output):
            print("Hook called!")
            # ✅ Flag for removal (once, even if forward runs again before cleanup)
            if handle not in self.hooks_to_remove:
                self.hooks_to_remove.append(handle)

        handle = self.layer.register_forward_hook(hook)
        self.hook_handles.append(handle)

    def cleanup_hooks(self):
        """Call after forward pass"""
        for handle in self.hooks_to_remove:
            handle.remove()
            self.hook_handles.remove(handle)
        self.hooks_to_remove.clear()

Rule: Never modify hook handles during forward pass. Flag for removal and clean up after.


Edge Case 6: Custom Modules with Buffers

Pattern: Buffers are non-parameter tensors that should be saved/moved with model.

import torch
import torch.nn as nn

class RunningStatsModule(nn.Module):
    def __init__(self, num_features):
        super().__init__()

        # ❌ WRONG: Just store as attribute (not in state_dict, not moved by .to())
        # self.running_mean = torch.zeros(num_features)
        # (It would also make the register_buffer() call below raise KeyError,
        #  because the attribute name is already taken.)

        # ✅ CORRECT: Register as buffer
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

        # Parameters (learnable)
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        # Update running stats (in training mode)
        if self.training:
            # ✅ In-place update of buffers, outside autograd; otherwise the
            # buffers join the autograd graph and carry it across steps
            with torch.no_grad():
                self.running_mean.mul_(0.9).add_(x.mean(dim=0), alpha=0.1)
                self.running_var.mul_(0.9).add_(x.var(dim=0), alpha=0.1)

        # Normalize using running stats
        normalized = (x - self.running_mean) / torch.sqrt(self.running_var + 1e-5)
        return normalized * self.weight + self.bias

# Buffers are moved with model:
model = RunningStatsModule(10)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)  # ✅ running_mean and running_var move too

# Buffers are saved in state_dict:
torch.save(model.state_dict(), 'model.pth')  # ✅ Includes buffers

When to use buffers:

  • Running statistics (BatchNorm-style)
  • Fixed embeddings (not updated by optimizer)
  • Positional encodings (not learned)
  • Masks or indices

Rule: Use register_buffer() for tensors that aren't parameters but should be saved/moved.
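Positional encodings are listed above as a typical buffer. A minimal sketch of the sinusoidal encoding popularized by the original Transformer, stored with `register_buffer()` (the `SinusoidalPositionalEncoding` class name and the `max_len` default are assumptions, and `d_model` is assumed even). It is saved and moved with the model but never seen by the optimizer.

```python
import math

import torch
import torch.nn as nn


class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)  # Even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # Odd dimensions
        self.register_buffer('pe', pe)  # Fixed, but saved and moved with the model

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        return x + self.pe[: x.size(1)]


enc = SinusoidalPositionalEncoding(d_model=16)
print(list(enc.state_dict().keys()))  # ['pe']
print(len(list(enc.parameters())))    # 0
out = enc(torch.zeros(2, 10, 16))
print(out.shape)                      # torch.Size([2, 10, 16])
```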


Common Rationalizations (Don't Do These)

| Excuse | Reality | Correct Approach |
| --- | --- | --- |
| "User wants quick solution, I'll use None" | Quick becomes slow when DDP breaks | Always use nn.Identity(), same speed |
| "It's just a prototype, proper patterns later" | Prototype becomes production, tech debt compounds | Build correctly from start, no extra time |
| "F.relu() is more Pythonic/simpler" | True, but prevents hooks and modification | Use nn.ReLU() if any chance of needing hooks |
| "I'll fix initialization in training loop" | Defeats purpose of reset_parameters() | Put in reset_parameters(), 5 extra lines |
| "Bias is almost always there" | False! Many models use bias=False | Check if bias is not None, always |
| "Hooks are advanced, user won't use them" | Until they need debugging or feature extraction | Design hookable from start, no cost |
| "I'll clean up hooks manually later" | Later never comes, memory leaks persist | Context manager takes 10 lines, bulletproof |
| "This module is simple, no need for modularity" | Simple modules get extended and reused | Substitutable components from start |
| "State dict loading always matches architecture" | False! Checkpoints get reused across versions | Implement reset_parameters() for partial loads |
| "In-place ops are fine, I'll remember detach+clone" | Won't remember under pressure | Document decision in code, add comment |

Critical insight: "Shortcuts for simplicity" become "bugs in production." Proper patterns take seconds more, prevent hours of debugging.


Decision Frameworks

Framework 1: Module vs Functional Operations

Question: Should I use nn.ReLU() or F.relu()?

Will you ever need to:
├─ Register hooks on this operation? → Use nn.ReLU()
├─ Inspect architecture (model.named_modules())? → Use nn.ReLU()
├─ Swap activation (ReLU→GELU)? → Use nn.ReLU()
├─ Use quantization? → Use nn.ReLU()
└─ None of above AND performance critical? → F.relu() acceptable

Default: When in doubt, use module version. Performance difference negligible.
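The "swap activation" branch is where module-based activations pay off most: a model can be modified after construction by walking its children. A minimal sketch with a hypothetical `replace_activations` helper that swaps every `nn.ReLU` for `nn.GELU`; activations applied with `F.relu` inside `forward()` would be invisible to it.

```python
import torch.nn as nn


def replace_activations(module, old=nn.ReLU, new=nn.GELU):
    """Hypothetical helper: recursively replace activation modules in place."""
    for name, child in module.named_children():
        if isinstance(child, old):
            setattr(module, name, new())  # Re-registers under the same name
        else:
            replace_activations(child, old, new)
    return module


model = nn.Sequential(
    nn.Linear(8, 16),
    nn.ReLU(),
    nn.Sequential(nn.Linear(16, 16), nn.ReLU()),
)
replace_activations(model)
act_names = [type(m).__name__ for m in model.modules() if isinstance(m, (nn.ReLU, nn.GELU))]
print(act_names)  # ['GELU', 'GELU']
```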


Framework 2: Hook Detachment Strategy

Question: In my hook, should I use detach(), detach().clone(), or neither?

Do you need gradients for analysis?
├─ Yes → Don't detach (but ensure short lifetime!)
└─ No → Continue...

Will the output be modified in-place later?
├─ Yes → Use .detach().clone()
├─ Unsure → Use .detach().clone() (safer)
└─ No → Use .detach() (sufficient)

Example decision:

features = {}

# Scenario: Extract features for visualization (no gradients needed, no in-place)
def hook(module, input, output):
    features['layer'] = output.detach()  # ✅ Sufficient

# Scenario: Extract features, model has in-place ops (x += y)
def hook(module, input, output):
    features['layer'] = output.detach().clone()  # ✅ Necessary

# Scenario: Gradient analysis (rare!)
def hook(module, input, output):
    features['layer'] = output  # ⚠️ Keeps the graph alive; clear it right after use

# Store, don't return: a forward hook that returns a non-None value replaces
# the module's output, so `return output.detach()` would cut the graph.

Framework 3: Initialization Strategy Selection

Question: Which initialization should I use?

Activation function?
├─ ReLU family → Kaiming (He) initialization
├─ Tanh/Sigmoid → Xavier (Glorot) initialization
├─ GELU/Swish → Xavier or Kaiming (experiment)
└─ None/Linear → Xavier

Layer type?
├─ Conv → Usually Kaiming with mode='fan_out'
├─ Linear → Kaiming or Xavier depending on activation
├─ Embedding → Normal(0, 1) or Xavier
└─ LSTM/GRU → Xavier for gates

Special considerations?
├─ ResNet-style → Zero-init the weight of each residual block's last norm layer (torchvision's zero_init_residual option)
├─ Transformer → Xavier uniform, specific scale for embeddings
├─ GAN → Initialization matters; DCGAN, for example, used Normal(0, 0.02)
└─ Pre-trained → Don't re-initialize! Load checkpoint

Code example:

def reset_parameters(self):
    for module in self.modules():
        if isinstance(module, nn.Conv2d):
            # ReLU activation → Kaiming
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Linear):
            # Check what activation follows (from self.config or hardcoded)
            if self.activation == 'relu':
                nn.init.kaiming_uniform_(module.weight, nonlinearity='relu')
            else:
                nn.init.xavier_uniform_(module.weight)

            if module.bias is not None:
                nn.init.zeros_(module.bias)
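A cheap way to check that an initializer does what the decision tree intends is to compare the empirical standard deviation with the formula. For Kaiming normal with `nonlinearity='relu'` the target std is sqrt(2 / fan), and for Xavier uniform it is sqrt(2 / (fan_in + fan_out)); the tensor shapes and the hypothetical `check` helper below are arbitrary choices for illustration.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)


def check(name, weight, expected_std):
    # Relative error between empirical and theoretical std
    actual = weight.std().item()
    print(f"{name}: std={actual:.4f} expected={expected_std:.4f}")
    return abs(actual - expected_std) / expected_std


linear_w = torch.empty(256, 512)  # nn.Linear layout: (out_features, in_features)
nn.init.kaiming_normal_(linear_w, nonlinearity='relu')  # default mode='fan_in'
err_kaiming = check("kaiming linear", linear_w, math.sqrt(2 / 512))

conv_w = torch.empty(128, 64, 3, 3)  # fan_out = 128 * 3 * 3
nn.init.kaiming_normal_(conv_w, mode='fan_out', nonlinearity='relu')
err_conv = check("kaiming conv", conv_w, math.sqrt(2 / (128 * 9)))

xavier_w = torch.empty(256, 512)
nn.init.xavier_uniform_(xavier_w)  # bound = sqrt(6 / (fan_in + fan_out))
err_xavier = check("xavier linear", xavier_w, math.sqrt(2 / (256 + 512)))
```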

Framework 4: When to Use Buffers vs Parameters vs Attributes

Decision tree:

Is it a tensor that belongs to the model's state?
├─ No → Regular attribute (self.x = value)
└─ Yes → Continue...

Should it be updated by optimizer?
├─ Yes → nn.Parameter()
└─ No → Continue...

Should it be saved in state_dict?
├─ Yes → register_buffer() (also moves with .to(device))
└─ No → register_buffer(..., persistent=False) (moves with the model, not saved)

Examples:
- Model weights → nn.Parameter()
- Running statistics (BatchNorm) → register_buffer()
- Configuration dict → Regular attribute
- Fixed positional encoding → register_buffer()
- Dropout probability → Regular attribute
- Learnable temperature → nn.Parameter()
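The tree above maps directly onto what ends up in `named_parameters()`, `named_buffers()` and `state_dict()`. A sketch with one of each (the `Scaler` module and its attribute names are hypothetical), including a non-persistent buffer created with `register_buffer(..., persistent=False)`:

```python
import torch
import torch.nn as nn


class Scaler(nn.Module):
    def __init__(self, dim, dropout_p=0.1):
        super().__init__()
        self.temperature = nn.Parameter(torch.ones(1))                   # Learnable
        self.register_buffer('running_max', torch.zeros(dim))             # Saved + moved
        self.register_buffer('mask', torch.ones(dim), persistent=False)   # Moved, not saved
        self.dropout_p = dropout_p                                        # Plain config attribute

    def forward(self, x):
        return x * self.mask / self.temperature


m = Scaler(4)
param_names = [n for n, _ in m.named_parameters()]
buffer_names = [n for n, _ in m.named_buffers()]
saved_keys = list(m.state_dict().keys())
print(param_names)   # ['temperature']
print(buffer_names)  # ['running_max', 'mask']
print(saved_keys)    # ['temperature', 'running_max']
```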

Pressure Testing Scenarios

Scenario 1: Time Pressure

User: "I need this module quickly, just make it work."

Agent thought: "I'll use None and functional ops, faster to write."

Reality: Taking 30 seconds more to use nn.Identity() and nn.ReLU() prevents hours of debugging DDP issues.

Correct response: Apply patterns anyway. They're not slower to write once familiar.


Scenario 2: "Simple" Module

User: "This is a simple block, don't overcomplicate it."

Agent thought: "I'll hardcode ReLU and BatchNorm, it's just a prototype."

Reality: Prototypes become production. Making activation/norm substitutable takes one extra line.

Correct response: Design modularly from the start. "Simple" doesn't mean "brittle."


Scenario 3: Existing Codebase

User: "The existing code uses None for optional modules."

Agent thought: "I should match existing style for consistency."

Reality: Existing code may have bugs. Improving patterns is better than perpetuating anti-patterns.

Correct response: Use correct patterns. Offer to refactor existing code if user wants.


Scenario 4: "Just Getting Started"

User: "I'm just experimenting, I'll clean it up later."

Agent thought: "Proper patterns can wait until it works."

Reality: Later never comes. Or worse, you can't iterate quickly because of accumulated tech debt.

Correct response: Proper patterns don't slow down experimentation. They enable faster iteration.


Red Flags Checklist

Before writing __init__ or forward, check yourself:

Module Definition Red Flags

  • Am I assigning None to a module attribute?
    • FIX: Use nn.Identity()
  • Am I using functional ops (F.relu) without considering hooks?
    • ASK: Will this ever need inspection/modification?
  • Am I hardcoding architecture choices (ReLU, BatchNorm)?
    • FIX: Make them substitutable parameters
  • Am I creating modules in forward()?
    • FIX: All modules in __init__

Hook Usage Red Flags

  • Am I storing hook output without detaching?
    • FIX: Use .detach() or .detach().clone()
  • Am I registering hooks without storing handles?
    • FIX: Store handles, clean up in __exit__
  • Am I using global variables in hook closures?
    • FIX: Encapsulate in a class
  • Am I modifying hook handles during forward pass?
    • FIX: Flag for removal, clean up after

Initialization Red Flags

  • Am I initializing weights inline in __init__ instead of in reset_parameters()?
    • FIX: Define reset_parameters(), call from __init__
  • Am I accessing .bias without checking if it exists?
    • FIX: Check if module.bias is not None:
  • Am I using one initialization for all layers?
    • ASK: Should different layers have different strategies?

State Management Red Flags

  • Am I storing intermediate results as self.*?
    • FIX: Use local variables only
  • Am I using Python list for modules?
    • FIX: Use nn.ModuleList
  • Do I have tensors that should be buffers but aren't?
    • FIX: Use register_buffer()

If ANY red flag is true, STOP and apply the pattern before proceeding.
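Two of the state-management red flags can be checked mechanically. The hypothetical `find_unregistered_state` helper below (not a PyTorch API) walks a model and reports tensors stored as plain attributes and Python lists containing modules, neither of which `nn.Module` registers.

```python
import torch
import torch.nn as nn


def find_unregistered_state(model):
    """Hypothetical audit: plain tensor attributes and lists of modules."""
    problems = []
    for mod_name, module in model.named_modules():
        # Registered params/buffers/submodules live in _parameters, _buffers
        # and _modules, so anything found directly in __dict__ is unregistered
        for attr, value in vars(module).items():
            path = f"{mod_name}.{attr}" if mod_name else attr
            if isinstance(value, torch.Tensor):
                problems.append(f"{path}: tensor attribute (use register_buffer or nn.Parameter)")
            elif isinstance(value, (list, tuple)) and any(isinstance(v, nn.Module) for v in value):
                problems.append(f"{path}: Python list of modules (use nn.ModuleList)")
    return problems


class Leaky(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(4, 4)]     # ❌ Not registered
        self.running_mean = torch.zeros(4)  # ❌ Not a buffer
        self.head = nn.Linear(4, 2)         # ✅ Registered


problems = find_unregistered_state(Leaky())
for problem in problems:
    print(problem)
# layers: Python list of modules (use nn.ModuleList)
# running_mean: tensor attribute (use register_buffer or nn.Parameter)
```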


Quick Reference Cards

Card 1: Module Design Checklist

✓ super().__init__() called first
✓ All modules defined in __init__ (not forward)
✓ No None assignments (use nn.Identity())
✓ Substitutable components (norm_layer, activation args)
✓ reset_parameters() defined and called
✓ Defensive checks (if bias is not None)
✓ Buffers registered (register_buffer())
✓ No self.* assignments in forward()

Card 2: Hook Checklist

✓ Hook detaches output (.detach() or .detach().clone())
✓ Hook handles stored in list
✓ Context manager for cleanup (__enter__/__exit__)
✓ No global state mutation
✓ Error handling (try/except in hook)
✓ Documented whether hook modifies output

Card 3: Initialization Checklist

✓ reset_parameters() method defined
✓ Called from __init__
✓ Iterates through modules or layers
✓ Checks if bias is not None
✓ Uses appropriate init strategy (Kaiming/Xavier)
✓ Documents why this initialization
✓ Can be called to re-initialize

References

Related Skills:

  • tensor-operations-and-memory (memory management)
  • debugging-techniques (using hooks for debugging)
  • distributed-training-strategies (DDP-compatible module design)
  • checkpointing-and-reproducibility (state dict best practices)