
SKILL.md

name pytorch-common-pitfalls
description Fix common PyTorch bugs: percentile calculations, LayerNorm for Conv1d, buffer edge cases. Trigger when writing PyTorch code for RL or neural networks.
author Claude Code
date 2025-01-01

PyTorch Common Pitfalls

Experiment Overview

Date: 2025-01-01
Goal: Document and fix common PyTorch bugs found during code review
Environment: PyTorch 2.x, CUDA-enabled training
Status: Success

Context

During review of advanced RL components (ensemble PPO, VAE OOD detector, curiosity module, dilated CNN), GitHub Copilot identified several common PyTorch pitfalls. These patterns appear frequently in ML code and can cause subtle bugs.

Verified Patterns

1. Percentile Calculation: Use torch.quantile, NOT kthvalue

Problem: kthvalue takes a 1-based integer index, so converting a percentile to k invites off-by-one and rounding errors (and k = 0 raises an error).

Wrong:

# BUG: kthvalue needs exact index, rounds incorrectly
k = int(self.threshold_percentile / 100.0 * errors.numel())
threshold = errors.view(-1).kthvalue(k).values

Correct:

# Use quantile for proper interpolation
q = self.threshold_percentile / 100.0
q = max(0.0, min(1.0, q))  # Clamp to valid range
threshold = torch.quantile(errors.view(-1), q)

Why: torch.quantile handles interpolation automatically and accepts float percentiles directly.
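
A quick way to see the difference is to compare both approaches on a toy error tensor (the values and the 95th percentile below are purely illustrative):

import torch

errors = torch.linspace(0.0, 1.0, steps=101)  # toy "reconstruction errors"
percentile = 95.0

# kthvalue path: 1-based integer index, prone to rounding drift
k = int(percentile / 100.0 * errors.numel())
kth = errors.kthvalue(k).values

# quantile path: accepts the float percentile directly and interpolates
q = max(0.0, min(1.0, percentile / 100.0))
quant = torch.quantile(errors, q)

print(kth.item(), quant.item())  # 0.94 vs. 0.95: kthvalue lands one element low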

2. LayerNorm with Conv1d: Create a Wrapper

Problem: Conv1d outputs are (batch, channels, time), but nn.LayerNorm(num_channels) normalizes over the trailing dimension, which here is time, not channels.

Wrong:

# BUG: LayerNorm(num_channels) expects the last dim to have size num_channels
self.norm = nn.LayerNorm(num_channels)
x = self.conv(x)  # (batch, channels, time)
x = self.norm(x)  # Shape error unless time == num_channels; even then it normalizes over time, not channels

Correct:

from torch import Tensor, nn


class ChannelLayerNorm1d(nn.Module):
    """LayerNorm wrapper for Conv1d outputs shaped (batch, channels, time)."""

    def __init__(self, num_channels: int):
        super().__init__()
        self.ln = nn.LayerNorm(num_channels)

    def forward(self, x: Tensor) -> Tensor:
        # x: (batch, channels, time) -> transpose -> (batch, time, channels)
        x = x.transpose(1, 2)
        x = self.ln(x)  # Now normalizes over channels
        return x.transpose(1, 2)  # Back to (batch, channels, time)

# Usage
self.norm = ChannelLayerNorm1d(num_channels)
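
A quick shape check (toy sizes) confirms the wrapper preserves the Conv1d layout:

import torch

norm = ChannelLayerNorm1d(num_channels=16)
x = torch.randn(4, 16, 100)        # (batch, channels, time)
out = norm(x)
assert out.shape == (4, 16, 100)   # layout preserved; normalization ran over channels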

3. Buffer Edge Cases: Handle Last Transitions

Problem: When sampling from replay buffers, next_obs = obs[indices + 1] fails for the last transition.

Wrong:

# BUG: Index out of bounds when indices contains last index
indices = torch.randint(0, total_steps, (batch_size,))
next_obs = obs_flat[indices + 1]  # Crashes if indices contains (total_steps - 1)

Correct:

# Handle edge case for last transitions
total_steps = obs_flat.shape[0]
next_indices = indices + 1
has_next = next_indices < total_steps

# Use safe indices (clamp to valid range)
safe_next_indices = next_indices.clone()
safe_next_indices[~has_next] = total_steps - 1

next_obs = obs_flat[safe_next_indices]

# Mark transitions without valid next_obs as terminal
dones = dones_flat[indices].clone()
dones[~has_next] = True  # These are terminal states

Alternative - Filter invalid transitions:

# Only use transitions with valid next_obs
valid_mask = next_indices < total_steps
valid_indices = indices[valid_mask]
valid_next_indices = next_indices[valid_mask]

obs = obs_flat[valid_indices]
next_obs = obs_flat[valid_next_indices]
actions = actions_flat[valid_indices]
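
Either variant can be wrapped into a sampling helper. A minimal self-contained sketch of the clamp-and-mark-terminal approach (the function name and flat-buffer arguments are illustrative, not taken from the reviewed code):

import torch

def sample_transitions(obs_flat, actions_flat, dones_flat, batch_size):
    """Sample (obs, action, next_obs, done) while handling the last transition safely."""
    total_steps = obs_flat.shape[0]
    indices = torch.randint(0, total_steps, (batch_size,))

    next_indices = indices + 1
    has_next = next_indices < total_steps

    # Clamp out-of-range indices to a valid position...
    safe_next_indices = next_indices.clone()
    safe_next_indices[~has_next] = total_steps - 1

    # ...and mark those transitions terminal so their next_obs is never bootstrapped from
    dones = dones_flat[indices].clone()
    dones[~has_next] = True

    return obs_flat[indices], actions_flat[indices], obs_flat[safe_next_indices], dones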

4. Safe Action Gating with torch.where

Problem: Boolean masking with multiplication can be unclear and error-prone.

Wrong:

# Unclear intent, potential floating-point issues
gated = proposed * (1 - mask.float()) + hold * mask.float()

Correct:

# Clear, explicit conditional
gated_action = torch.where(
    intervention_mask,  # Boolean tensor
    torch.full_like(proposed_action, self.hold_action),  # If True
    proposed_action,  # If False
)
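
As a standalone illustration (toy tensors; the hold value of 0.0 is made up):

import torch

proposed_action = torch.tensor([0.3, -0.8, 0.5])
intervention_mask = torch.tensor([False, True, False])
hold_action = 0.0  # illustrative "do nothing" value

gated_action = torch.where(
    intervention_mask,
    torch.full_like(proposed_action, hold_action),
    proposed_action,
)
# gated_action -> tensor([0.3000, 0.0000, 0.5000])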

5. NaN/Inf Checks in Calibration

Problem: Fitting distributions or thresholds on data that contains NaN/Inf values silently corrupts the resulting statistics.

# Always check for invalid values before calibration
errors = self._compute_reconstruction_errors(data)

if torch.isnan(errors).any() or torch.isinf(errors).any():
    raise ValueError(
        f"Invalid values in calibration data: "
        f"NaN={torch.isnan(errors).sum().item()}, "
        f"Inf={torch.isinf(errors).sum().item()}"
    )

# Safe to proceed with calibration
q = self.threshold_percentile / 100.0
self.threshold = torch.quantile(errors.view(-1), q)
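
The same check can be factored into a small reusable helper (the function name is hypothetical):

import torch

def assert_finite(t: torch.Tensor, name: str = "tensor") -> None:
    """Raise early if a tensor contains NaN or Inf, e.g. before fitting a calibration threshold."""
    n_nan = torch.isnan(t).sum().item()
    n_inf = torch.isinf(t).sum().item()
    if n_nan or n_inf:
        raise ValueError(f"Invalid values in {name}: NaN={n_nan}, Inf={n_inf}")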

6. Minimum Sample Size Checks

Problem: Computing statistics on too few samples produces unreliable results.

# Ensure sufficient samples before computing diversity metrics
min_samples = 32

if n_samples < min_samples or total_transitions == 0:
    return torch.zeros(self.config.n_learners, device=self.device)

# Safe to compute statistics
diversity_bonus = self._compute_diversity(samples)
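
A standalone version of the guard might look like this; the function name, the 32-sample floor, and the placeholder metric are illustrative only:

import torch

def diversity_bonus(samples: torch.Tensor, n_learners: int, min_samples: int = 32) -> torch.Tensor:
    """Return a neutral bonus instead of unreliable statistics when data is scarce."""
    if samples.shape[0] < min_samples:
        return torch.zeros(n_learners, device=samples.device)
    # Placeholder stand-in for the real diversity metric
    return torch.full((n_learners,), samples.var().item(), device=samples.device)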

Failed Attempts

Attempt, why it failed, and the lesson learned:

  • Using kthvalue for percentiles: off-by-one errors and integer rounding issues. Lesson: use torch.quantile for floating-point percentiles.
  • LayerNorm(channels) on Conv1d output: normalizes the wrong dimension (time instead of channels). Lesson: create a wrapper that transposes before/after LayerNorm.
  • obs[indices + 1] without a bounds check: index out of bounds for the last transition. Lesson: always handle the edge case or filter invalid indices.
  • Silent failures returning zeros: bugs stay hidden and are hard to debug. Lesson: raise ValueError with an informative message.
  • Annotating x: any = ... without an import: the built-in any is not typing.Any and is not a valid type hint. Lesson: use from typing import Any and annotate as x: Any = ...

Key Insights

  • torch.quantile > kthvalue: Proper interpolation, accepts float directly
  • Transpose for LayerNorm: Conv1d is (B,C,T), LayerNorm normalizes last dim
  • Bounds checking is mandatory: Buffer sampling must handle edge cases
  • Fail loudly: raise ValueError > return zeros silently
  • Type hints matter: Any from typing, not Python builtin any
