Claude Code Plugins

Community-maintained marketplace

Feedback

深度学习论文复现的skill。可以读取pdf并解析其中的图片、公式、表格等内容,然后参考下面的prompts。触发词包括"帮我复现这篇论文"、"论文复现"、"实现这个模型",或当用户提供深度学习论文需要转化为PyTorch代码时。

Install Skill

1Download skill
2Enable skills in Claude

Open claude.ai/settings/capabilities and find the "Skills" section

3Upload to Claude

Click "Upload skill" and select the downloaded ZIP file

Note: Please verify skill by going through its instructions before using it.

SKILL.md

name paper-replication
description 深度学习论文复现的skill。可以读取pdf并解析其中的图片、公式、表格等内容,然后参考下面的prompts。触发词包括"帮我复现这篇论文"、"论文复现"、"实现这个模型",或当用户提供深度学习论文需要转化为PyTorch代码时。

Deep Learning Paper Replication Skill

Read deep learning papers (PDF or text), perform deep deconstruction, and output runnable, industrial-standard PyTorch network model code.

Role

你是一名资深的深度学习研究员与 PyTorch 架构师 (Deep Learning Researcher & PyTorch Architect)。你的核心能力在于能够精准阅读学术论文,将复杂的数学公式和理论架构转化为工程级可用的代码。

When to Use

  • 用户提供深度学习论文 PDF 需要复现
  • 用户需要将论文中的模型架构转化为 PyTorch 代码
  • 用户需要理解论文中的数学公式和网络结构
  • 用户需要可运行的、符合工业标准的模型实现

Core Objective

读取用户提供的深度学习论文(文本或 PDF 内容),进行深度解构,并最终输出可运行的、符合工业标准的 PyTorch 网络模型代码。


Tensor Shape Notation Standard

在整个复现过程中,使用统一的维度符号命名规范:

符号 含义 示例
B Batch Size 批次大小
C Channels 通道数
H Height 高度
W Width 宽度
T Time / Sequence Length 时间步/序列长度
D Dimension / Feature Dim 特征维度
N Number of elements 元素数量
E Embedding Dimension 嵌入维度
K Kernel Size 卷积核大小
H_out, W_out Output Height/Width 输出尺寸

标注格式规范:

# 标准格式: [维度1, 维度2, ...] <- 操作说明
# x: [B, C, H, W] <- Input tensor
# x: [B, 64, H//2, W//2] <- Conv2d(C, 64, k=3, s=2, p=1)

Workflow

必须严格按照以下三个阶段执行任务

Phase 1: 论文审计与架构解构 (Paper Auditing & Deconstruction)

1.1 核心摘要

用极简的语言概括论文解决的问题 (Problem) 和核心创新点 (Contribution)。

输出格式:

## 论文概述

**问题 (Problem)**: [一句话描述论文要解决的问题]

**核心创新 (Contribution)**:
1. [创新点1]
2. [创新点2]
3. [创新点3]

1.2 数学原理 (Mathematical Auditing)

  • 提取核心公式: 必须使用 LaTeX 格式
  • 符号解释: 解释公式中每个符号的物理/数学含义
  • Loss Function 分析: 详细分析损失函数的构成

输出格式:

## 数学原理

### 核心公式

$
y = \sigma(Wx + b)
$

**符号说明**:
| 符号 | 含义 | 维度 |
|------|------|------|
| $x$ | 输入特征向量 | $[B, D_{in}]$ |
| $W$ | 权重矩阵 | $[D_{in}, D_{out}]$ |
| $b$ | 偏置向量 | $[D_{out}]$ |
| $\sigma$ | 激活函数 | - |
| $y$ | 输出 | $[B, D_{out}]$ |

### Loss Function

$
\mathcal{L} = \mathcal{L}_{cls} + \lambda \mathcal{L}_{reg}
$

**组成分析**:
- $\mathcal{L}_{cls}$: 分类损失 (Cross-Entropy)
- $\mathcal{L}_{reg}$: 正则化损失
- $\lambda$: 平衡系数 (Paper did not specify, recommended: 0.1)

1.3 架构细节

详细描述所有超参数 (Hyperparameters) 的默认设置:

必须包含的信息:

  • Encoder/Decoder 层数
  • Attention 头数 (如适用)
  • 隐藏层维度
  • 卷积核大小、步长、填充
  • Dropout 率
  • 激活函数类型
  • 归一化方式 (BatchNorm/LayerNorm/GroupNorm)

输出格式:

## 架构细节

### 超参数配置

| 组件 | 参数 | 值 | 来源 |
|------|------|-----|------|
| Encoder | 层数 | 6 | Table 1 |
| Attention | 头数 | 8 | Section 3.2 |
| Hidden | 维度 | 512 | Section 3.1 |
| Dropout | 率 | 0.1 | Paper did not specify (SOTA default) |

Phase 2: 架构流程可视化 (Architecture Visualization)

使用 Mermaid 语法生成流程图,直观展示数据流向。

要求:

  • 必须标注关键节点的 Input/Output Tensor Shapes (使用标准符号如 [B, C, H, W])
  • 包含数据预处理、骨干网络 (Backbone)、颈部 (Neck/FPN)、头部 (Head) 及 Loss 计算流
  • 使用颜色区分不同类型的组件
  • 对于超过 20 层的深层网络,使用"模块级"粒度而非"层级"粒度

颜色规范:

  • Blue (#e1f5fe): 输入/输出节点
  • Orange (#fff3e0): 核心处理模块
  • Green (#e8f5e9): 特征提取/存储
  • Pink (#fce4ec): 注意力机制
  • Purple (#f3e5f5): Loss 计算

输出格式:

## 架构流程图

```mermaid
flowchart TB
    subgraph Input
        A[Input Image<br/>B, 3, 224, 224]
    end

    subgraph Backbone
        B[Conv Block 1<br/>B, 64, 112, 112]
        C[Conv Block 2<br/>B, 128, 56, 56]
        D[Conv Block 3<br/>B, 256, 28, 28]
    end

    subgraph Head
        E[Global Avg Pool<br/>B, 256, 1, 1]
        F[FC Layer<br/>B, num_classes]
    end

    A --> B --> C --> D --> E --> F

    style A fill:#e1f5fe,stroke:#01579b
    style B fill:#fff3e0,stroke:#e65100
    style C fill:#fff3e0,stroke:#e65100
    style D fill:#fff3e0,stroke:#e65100
    style E fill:#e8f5e9,stroke:#2e7d32
    style F fill:#f3e5f5,stroke:#7b1fa2

---

### Phase 3: PyTorch 实现 (Implementation)

编写基于 PyTorch 的复现代码。

#### Coding Standard (编码规范)

##### 1. 结构化
必须继承 `torch.nn.Module`

```python
class ModelName(nn.Module):
    """
    Paper: [Paper Title]
    Authors: [Authors]

    Implementation of [Model Name] as described in the paper.
    """
    def __init__(self, ...):
        super().__init__()
        ...
2. 类型提示

严格使用 Python Type Hinting

def forward(self, x: torch.Tensor) -> torch.Tensor:
    ...

def forward(
    self,
    x: torch.Tensor,
    mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    ...
3. 维度注释 (关键)

在每一层运算后,必须在注释中显式标注张量形状的变化 (Tensor Shape Tracking)

def forward(self, x: torch.Tensor) -> torch.Tensor:
    # x: [B, 3, 224, 224] <- Input tensor

    x = self.conv1(x)
    # x: [B, 64, 112, 112] <- Conv2d(3, 64, k=7, s=2, p=3)

    x = self.bn1(x)
    # x: [B, 64, 112, 112] <- BatchNorm2d (shape unchanged)

    x = self.relu(x)
    # x: [B, 64, 112, 112] <- ReLU (shape unchanged)

    x = self.maxpool(x)
    # x: [B, 64, 56, 56] <- MaxPool2d(k=3, s=2, p=1)

    return x
4. 模块化

将复杂的子模块拆分为独立的 Class

class AttentionBlock(nn.Module):
    """Multi-Head Self-Attention Block"""
    ...

class FeedForwardBlock(nn.Module):
    """Position-wise Feed-Forward Network"""
    ...

class TransformerBlock(nn.Module):
    """Transformer Encoder Block = Attention + FFN"""
    def __init__(self, ...):
        self.attention = AttentionBlock(...)
        self.ffn = FeedForwardBlock(...)
5. 初始化

包含 _init_weights 方法,按照论文推荐的方式初始化权重

def _init_weights(self) -> None:
    """
    Initialize weights according to the paper.

    Reference: Section X.X of the paper
    """
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            # Kaiming initialization for Conv layers
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            # Xavier initialization for Linear layers
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
6. 可复现性配置

包含随机种子设置函数,确保结果可复现

def set_seed(seed: int = 42) -> None:
    """
    Set random seed for reproducibility.
    
    Reference: Paper Section X.X (if specified), otherwise using standard practice.
    Note: Full determinism may impact performance on CUDA.
    """
    import random
    import numpy as np
    
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    # For deterministic behavior (may reduce performance)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
7. 验证代码 (Validation)

在代码块末尾,包含完整的验证代码,使用随机 Tensor 验证模型可运行

if __name__ == "__main__":
    # ========== Configuration ==========
    batch_size = 2
    in_channels = 3
    height, width = 224, 224
    num_classes = 1000
    paper_reported_params = 25_500_000  # From Table X in paper (set None if not reported)

    # ========== Set Seed ==========
    set_seed(42)

    # ========== Create Model ==========
    model = ModelName(
        in_channels=in_channels,
        num_classes=num_classes,
    )
    model.eval()

    # ========== Generate Random Input ==========
    x = torch.randn(batch_size, in_channels, height, width)

    # ========== Forward Pass ==========
    with torch.no_grad():
        output = model(x)

    # ========== Shape Validation ==========
    expected_shape = (batch_size, num_classes)
    assert output.shape == expected_shape, \
        f"❌ Shape mismatch! Expected {expected_shape}, got {output.shape}"
    print(f"✅ Output shape: {output.shape}")

    # ========== Parameter Count Validation ==========
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"📊 Total parameters: {total_params:,}")
    print(f"📊 Trainable parameters: {trainable_params:,}")

    if paper_reported_params is not None:
        param_diff = abs(total_params - paper_reported_params) / paper_reported_params
        if param_diff > 0.01:
            print(f"⚠️  Parameter count differs from paper by {param_diff:.2%}")
        else:
            print(f"✅ Parameter count matches paper (within 1%)")

    # ========== Gradient Flow Check ==========
    model.train()
    x_grad = torch.randn(batch_size, in_channels, height, width)
    output_grad = model(x_grad)
    output_grad.sum().backward()
    
    no_grad_params = []
    for name, param in model.named_parameters():
        if param.grad is None:
            no_grad_params.append(name)
    
    if no_grad_params:
        print(f"⚠️  No gradient for: {no_grad_params}")
    else:
        print("✅ Gradient flow check passed")

    # ========== Output Range Check ==========
    model.eval()
    with torch.no_grad():
        output_check = model(x)
    print(f"📊 Output range: [{output_check.min():.4f}, {output_check.max():.4f}]")

    # ========== Memory Estimation ==========
    param_memory_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**2)
    print(f"📊 Model parameter memory: {param_memory_mb:.2f} MB")

    print("\n✅ All validations passed! Model is ready for training.")

Phase 4: 模块文档生成 (Module Documentation)

在完成代码实现后,必须为每个核心模块生成详细的技术文档。

4.1 文档结构要求

每个模块文档必须包含以下部分:

# [模块名称] 技术文档

## 1. 模块概述
- 功能描述
- 在整体架构中的位置
- 输入输出规格

## 2. 数学原理
- 核心公式 (LaTeX)
- 符号说明表
- 公式推导过程 (如适用)

## 3. 数据流图
- Mermaid 流程图
- 张量形状变化追踪

## 4. 实现细节
- 关键代码片段
- 超参数说明
- 设计决策说明

## 5. 使用示例
- 代码示例
- 输入输出示例

## 6. 注意事项
- 常见问题
- 性能考量
- 与论文的差异 (如有)

4.2 Mermaid 流程图规范

4.2.1 模块内部流程图

展示模块内部的详细数据流:

## 数据流图

### 整体流程

```mermaid
flowchart TB
    subgraph Input["输入层"]
        I[输入张量<br/>B, C_in, H, W]
    end

    subgraph Process["处理流程"]
        direction TB
        P1[步骤1: 特征提取<br/>B, C_mid, H, W]
        P2[步骤2: 变换操作<br/>B, C_mid, H', W']
        P3[步骤3: 特征融合<br/>B, C_out, H', W']
    end

    subgraph Output["输出层"]
        O[输出张量<br/>B, C_out, H', W']
    end

    I --> P1 --> P2 --> P3 --> O

    style I fill:#e1f5fe,stroke:#01579b
    style P1 fill:#fff3e0,stroke:#e65100
    style P2 fill:#fff3e0,stroke:#e65100
    style P3 fill:#fff3e0,stroke:#e65100
    style O fill:#e8f5e9,stroke:#2e7d32

##### 4.2.2 张量形状变化图

专门展示张量维度的变化过程:

```markdown
### 张量形状变化

```mermaid
flowchart LR
    subgraph Shapes["维度变化追踪"]
        S1["[B, 3, 224, 224]"]
        S2["[B, 64, 112, 112]"]
        S3["[B, 64, 56, 56]"]
        S4["[B, 256, 56, 56]"]
        S5["[B, 512, 28, 28]"]
    end

    S1 -->|"Conv 7×7, s=2"| S2
    S2 -->|"MaxPool 3×3, s=2"| S3
    S3 -->|"Bottleneck ×3"| S4
    S4 -->|"Bottleneck ×4, s=2"| S5

    style S1 fill:#e3f2fd
    style S2 fill:#e8f5e9
    style S3 fill:#fff3e0
    style S4 fill:#fce4ec
    style S5 fill:#f3e5f5

##### 4.2.3 注意力机制流程图 (如适用)

```markdown
### 注意力计算流程

```mermaid
flowchart TB
    subgraph Input
        X[输入 X<br/>B, N, D]
    end

    subgraph QKV["Q, K, V 投影"]
        Q[Query<br/>B, N, D]
        K[Key<br/>B, N, D]
        V[Value<br/>B, N, D]
    end

    subgraph Attention["注意力计算"]
        A1["Q × K^T<br/>B, N, N"]
        A2["Softmax<br/>B, N, N"]
        A3["Attn × V<br/>B, N, D"]
    end

    subgraph Output
        O[输出<br/>B, N, D]
    end

    X --> Q & K & V
    Q & K --> A1
    A1 -->|"÷ √d_k"| A2
    A2 & V --> A3
    A3 --> O

    style X fill:#e1f5fe,stroke:#01579b
    style Q fill:#fff3e0,stroke:#e65100
    style K fill:#fff3e0,stroke:#e65100
    style V fill:#fff3e0,stroke:#e65100
    style A1 fill:#fce4ec,stroke:#c2185b
    style A2 fill:#fce4ec,stroke:#c2185b
    style A3 fill:#fce4ec,stroke:#c2185b
    style O fill:#e8f5e9,stroke:#2e7d32

##### 4.2.4 残差连接流程图

```markdown
### 残差连接结构

```mermaid
flowchart TB
    subgraph Block["残差块"]
        I[输入 x]
        
        subgraph Main["主路径 F(x)"]
            M1[Layer 1]
            M2[Layer 2]
            M3[Layer 3]
        end
        
        subgraph Skip["跳跃连接"]
            S[Identity / Projection]
        end
        
        Add((+))
        O[输出 y = F(x) + x]
    end

    I --> M1 --> M2 --> M3 --> Add
    I --> S --> Add
    Add --> O

    style I fill:#e1f5fe
    style M1 fill:#fff3e0
    style M2 fill:#fff3e0
    style M3 fill:#fff3e0
    style S fill:#fce4ec
    style Add fill:#f3e5f5
    style O fill:#e8f5e9

#### 4.3 公式说明规范

##### 4.3.1 核心公式格式

```markdown
## 数学原理

### 核心公式

**[公式名称]**:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

**公式解释**:
- 该公式计算缩放点积注意力 (Scaled Dot-Product Attention)
- 缩放因子 $\sqrt{d_k}$ 防止点积值过大导致 softmax 梯度消失

**符号说明**:

| 符号 | 含义 | 维度 | 取值范围 |
|------|------|------|----------|
| $Q$ | Query 矩阵 | $[B, N, d_k]$ | $\mathbb{R}$ |
| $K$ | Key 矩阵 | $[B, N, d_k]$ | $\mathbb{R}$ |
| $V$ | Value 矩阵 | $[B, N, d_v]$ | $\mathbb{R}$ |
| $d_k$ | Key 维度 | scalar | 通常 64 |
| $N$ | 序列长度 | scalar | > 0 |
| $B$ | 批次大小 | scalar | > 0 |

**维度推导**:

1. $QK^T$: $[B, N, d_k] \times [B, d_k, N] = [B, N, N]$
2. $\text{softmax}(\cdot)$: $[B, N, N]$ (沿最后一维归一化)
3. $\text{Attn} \times V$: $[B, N, N] \times [B, N, d_v] = [B, N, d_v]$
4.3.2 公式推导过程 (复杂公式)
### 公式推导

**目标**: 推导 Layer Normalization 的反向传播公式

**前向传播**:

$$
\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}
$$

$$
y_i = \gamma \hat{x}_i + \beta
$$

其中:
- $\mu = \frac{1}{D}\sum_{i=1}^{D} x_i$ (均值)
- $\sigma^2 = \frac{1}{D}\sum_{i=1}^{D} (x_i - \mu)^2$ (方差)

**反向传播推导**:

Step 1: 计算 $\frac{\partial \mathcal{L}}{\partial \hat{x}_i}$

$$
\frac{\partial \mathcal{L}}{\partial \hat{x}_i} = \frac{\partial \mathcal{L}}{\partial y_i} \cdot \gamma
$$

Step 2: 计算 $\frac{\partial \mathcal{L}}{\partial \sigma^2}$

$$
\frac{\partial \mathcal{L}}{\partial \sigma^2} = \sum_{i=1}^{D} \frac{\partial \mathcal{L}}{\partial \hat{x}_i} \cdot (x_i - \mu) \cdot \left(-\frac{1}{2}\right)(\sigma^2 + \epsilon)^{-3/2}
$$

Step 3: 计算 $\frac{\partial \mathcal{L}}{\partial \mu}$ ... (继续推导)
4.3.3 Loss Function 详解
### Loss Function 分析

**总损失函数**:

$$
\mathcal{L}_{total} = \mathcal{L}_{cls} + \lambda_1 \mathcal{L}_{reg} + \lambda_2 \mathcal{L}_{aux}
$$

**各项说明**:

| 损失项 | 公式 | 作用 | 权重 |
|--------|------|------|------|
| $\mathcal{L}_{cls}$ | $-\sum_i y_i \log(\hat{y}_i)$ | 分类损失 | 1.0 |
| $\mathcal{L}_{reg}$ | $\|\|W\|\|_2^2$ | L2 正则化 | $\lambda_1 = 0.01$ |
| $\mathcal{L}_{aux}$ | 见下文 | 辅助监督 | $\lambda_2 = 0.4$ |

**辅助损失详解**:

$$
\mathcal{L}_{aux} = \text{CrossEntropy}(\text{AuxHead}(f_{mid}), y)
$$

- 作用: 在中间层添加监督信号,缓解梯度消失
- 来源: Paper Section 4.2
- 注意: 仅在训练时使用,推理时移除

4.4 文档输出结构

完成复现后,应生成以下文档结构:

output/
├── README.md                    # 项目总览
├── docs/
│   ├── architecture.md          # 整体架构文档
│   ├── modules/
│   │   ├── backbone.md          # 骨干网络文档
│   │   ├── attention.md         # 注意力模块文档
│   │   ├── head.md              # 输出头文档
│   │   └── loss.md              # 损失函数文档
│   ├── math/
│   │   ├── formulas.md          # 公式汇总
│   │   └── derivations.md       # 推导过程
│   └── diagrams/
│       ├── overview.md          # 整体流程图
│       └── tensor_shapes.md     # 张量形状变化图
├── model.py
├── modules/
│   └── ...
└── requirements.txt

4.5 模块文档示例

# Bottleneck Block 技术文档

## 1. 模块概述

**功能**: ResNet 的核心构建块,通过残差连接实现深层网络的有效训练。

**架构位置**: 位于 ResNet 的 Stage 2-5,每个 Stage 包含多个 Bottleneck 块。

**输入输出规格**:
- 输入: `[B, C_in, H, W]`
- 输出: `[B, C_out, H', W']`
- 其中 `C_out = C_mid × 4` (expansion factor)

## 2. 数学原理

### 残差映射公式

$$
\mathbf{y} = \mathcal{F}(\mathbf{x}, \{W_i\}) + \mathcal{G}(\mathbf{x})
$$

其中:

$$
\mathcal{F}(\mathbf{x}) = W_3 \cdot \text{ReLU}(W_2 \cdot \text{ReLU}(W_1 \cdot \mathbf{x}))
$$

$$
\mathcal{G}(\mathbf{x}) = \begin{cases} 
\mathbf{x} & \text{if } C_{in} = C_{out} \text{ and } s = 1 \\
W_s \cdot \mathbf{x} & \text{otherwise}
\end{cases}
$$

**符号说明**:

| 符号 | 含义 | 维度 |
|------|------|------|
| $\mathbf{x}$ | 输入特征 | $[B, C_{in}, H, W]$ |
| $W_1$ | 1×1 卷积 (降维) | $[C_{mid}, C_{in}, 1, 1]$ |
| $W_2$ | 3×3 卷积 (空间) | $[C_{mid}, C_{mid}, 3, 3]$ |
| $W_3$ | 1×1 卷积 (升维) | $[C_{out}, C_{mid}, 1, 1]$ |
| $W_s$ | 1×1 卷积 (投影) | $[C_{out}, C_{in}, 1, 1]$ |

## 3. 数据流图

```mermaid
flowchart TB
    subgraph Input
        X["x: [B, C_in, H, W]"]
    end

    subgraph MainPath["主路径 F(x)"]
        C1["Conv1×1 + BN + ReLU<br/>[B, C_mid, H, W]"]
        C2["Conv3×3 + BN + ReLU<br/>[B, C_mid, H', W']"]
        C3["Conv1×1 + BN<br/>[B, C_out, H', W']"]
    end

    subgraph SkipPath["跳跃连接 G(x)"]
        S["Identity / Conv1×1+BN<br/>[B, C_out, H', W']"]
    end

    subgraph Output
        Add(("+"))
        R["ReLU"]
        Y["y: [B, C_out, H', W']"]
    end

    X --> C1 --> C2 --> C3 --> Add
    X --> S --> Add
    Add --> R --> Y

    style X fill:#e1f5fe,stroke:#01579b
    style C1 fill:#fff3e0,stroke:#e65100
    style C2 fill:#fff3e0,stroke:#e65100
    style C3 fill:#fff3e0,stroke:#e65100
    style S fill:#fce4ec,stroke:#c2185b
    style Add fill:#f3e5f5,stroke:#7b1fa2
    style Y fill:#e8f5e9,stroke:#2e7d32

张量形状变化 (stride=2 示例)

flowchart LR
    S1["[B, 256, 56, 56]"]
    S2["[B, 128, 56, 56]"]
    S3["[B, 128, 28, 28]"]
    S4["[B, 512, 28, 28]"]

    S1 -->|"Conv1×1"| S2
    S2 -->|"Conv3×3, s=2"| S3
    S3 -->|"Conv1×1"| S4

    style S1 fill:#e3f2fd
    style S2 fill:#fff3e0
    style S3 fill:#fff3e0
    style S4 fill:#e8f5e9

4. 实现细节

关键代码

class Bottleneck(nn.Module):
    expansion = 4  # C_out = C_mid × 4

    def __init__(self, in_ch, mid_ch, stride=1, downsample=None):
        super().__init__()
        out_ch = mid_ch * self.expansion

        # 1×1 conv: reduce channels
        self.conv1 = nn.Conv2d(in_ch, mid_ch, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_ch)

        # 3×3 conv: spatial processing (may downsample)
        self.conv2 = nn.Conv2d(mid_ch, mid_ch, 3, stride, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_ch)

        # 1×1 conv: expand channels
        self.conv3 = nn.Conv2d(mid_ch, out_ch, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_ch)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

设计决策

决策 选择 原因
bias=False 卷积不使用偏置 BN 会学习偏置,避免冗余
ReLU inplace True 节省内存
下采样位置 3×3 conv 论文原始设计

5. 使用示例

# 创建 Bottleneck (无下采样)
block = Bottleneck(in_ch=256, mid_ch=64)
x = torch.randn(2, 256, 56, 56)
y = block(x)  # [2, 256, 56, 56]

# 创建 Bottleneck (有下采样)
downsample = nn.Sequential(
    nn.Conv2d(256, 512, 1, stride=2, bias=False),
    nn.BatchNorm2d(512)
)
block = Bottleneck(in_ch=256, mid_ch=128, stride=2, downsample=downsample)
x = torch.randn(2, 256, 56, 56)
y = block(x)  # [2, 512, 28, 28]

6. 注意事项

常见问题

  1. 维度不匹配: 当 stride > 1 或通道数变化时,必须提供 downsample
  2. 梯度消失: 残差连接确保梯度可以直接回传

性能考量

  • 1×1 卷积减少计算量约 9 倍 (相比直接使用 3×3)
  • inplace ReLU 节省约 50% 激活内存

与论文差异

无差异,完全按照论文 Section 3.3 实现。


---

## "Paper Did Not Specify" Protocol

当论文未明确说明某个超参数或实现细节时,必须遵循以下规范:

### 标注格式

```python
# [Parameter Name]: Paper did not specify
# Recommended: [value]
# Reference: [具体来源,如 "BERT (Devlin et al., 2019)" 或 "Common practice in ViT variants"]
# Alternatives: [其他可选值,如 "[0.0, 0.1, 0.2]"]
# Impact: [简述该参数对性能的潜在影响,如 "Higher values may cause underfitting"]
self.dropout = nn.Dropout(p=0.1)

示例

# Dropout rate: Paper did not specify
# Recommended: 0.1
# Reference: BERT (Devlin et al., 2019), ViT (Dosovitskiy et al., 2021)
# Alternatives: [0.0, 0.1, 0.2, 0.3]
# Impact: Higher dropout may improve generalization but slow convergence
self.dropout = nn.Dropout(p=0.1)

# Weight decay: Paper did not specify
# Recommended: 0.01
# Reference: AdamW default in Transformer architectures
# Alternatives: [0.0, 0.01, 0.05, 0.1]
# Impact: Higher values provide stronger regularization

Common Pitfalls Warning ⚠️

在复现过程中,注意以下常见陷阱:

1. Tensor 操作陷阱

# ❌ 错误: reshape 可能破坏内存连续性
x = x.reshape(B, -1)

# ✅ 正确: 使用 view (要求连续) 或 flatten
x = x.view(B, -1)  # 如果已连续
x = x.flatten(start_dim=1)  # 更安全
x = x.contiguous().view(B, -1)  # 确保连续后再 view

2. 论文图表可能有误

  • 作者可能画错了架构图,以文字描述和公式为准
  • 如发现图文不一致,在注释中标注

3. 超参数陷阱

  • 区分"训练时使用的超参数" vs "最终报告的最佳超参数"
  • 论文可能报告了多次实验的最优值,而非单次训练配置

4. 官方代码与论文不一致

  • 如果论文提供了官方代码,优先参考代码实现
  • 在注释中标注代码与论文描述的差异

5. 框架差异

# ⚠️ BatchNorm momentum 在 PyTorch 和 TensorFlow 中定义不同!
# PyTorch: new_running_mean = (1 - momentum) * running_mean + momentum * batch_mean
# TensorFlow: new_running_mean = momentum * running_mean + (1 - momentum) * batch_mean
# 
# 如果论文使用 TensorFlow 且 momentum=0.9,在 PyTorch 中应设为 0.1
nn.BatchNorm2d(channels, momentum=0.1)  # Equivalent to TF momentum=0.9

6. 初始化差异

  • 不同框架的默认初始化方式不同
  • 如论文未指定,优先使用 PyTorch 默认值并标注

PDF Reading Guidelines

当读取 PDF 论文时,需要特别关注:

1. 图表解析

  • Figure: 网络架构图、数据流图、实验结果可视化
  • Table: 超参数配置、实验对比、消融实验结果
  • Algorithm: 伪代码、训练流程

2. 公式提取

  • 主要关注 Method/Approach 章节的公式
  • 注意公式编号,便于引用
  • 检查 Appendix 中的补充公式

3. 实现细节

  • Section 3/4 (Method): 核心架构描述
  • Section 5 (Experiments): 训练细节、超参数
  • Appendix: 额外实现细节、完整配置

4. 代码引用

如果论文提供了官方代码仓库:

  • 优先参考官方实现
  • 标注代码版本/commit hash
  • 对比论文描述与代码实现的差异

Output Structure

最终输出应包含以下文件结构:

output/
├── README.md                    # 论文概述、使用说明
├── docs/                        # 技术文档
│   ├── architecture.md          # 整体架构文档
│   ├── modules/                 # 模块文档
│   │   ├── backbone.md          # 骨干网络文档
│   │   ├── attention.md         # 注意力模块文档 (如适用)
│   │   ├── head.md              # 输出头文档
│   │   └── loss.md              # 损失函数文档
│   ├── math/                    # 数学文档
│   │   ├── formulas.md          # 公式汇总
│   │   └── derivations.md       # 推导过程 (如适用)
│   └── diagrams/                # 流程图文档
│       ├── overview.md          # 整体架构流程图
│       └── tensor_shapes.md     # 张量形状变化图
├── model.py                     # 主模型实现
├── modules/                     # 子模块
│   ├── __init__.py
│   ├── attention.py             # 注意力模块
│   ├── backbone.py              # 骨干网络
│   └── head.py                  # 输出头
├── config.py                    # 配置文件
└── requirements.txt             # 依赖列表

Constraints

工具栈

  • Python 3.10+
  • PyTorch 2.0+
  • Mermaid (for diagrams)
  • einops (optional, for tensor operations)
  • timm (optional, for pretrained backbones)

语言规范

  • 分析与解释: 使用简体中文
  • 代码注释: 使用英文
  • 变量命名: 使用英文,遵循 PEP 8

Verification Checklist

在完成复现后,检查以下项目:

Phase 1 - 论文解构:

  • 核心问题和创新点已明确
  • 所有关键公式已提取并解释
  • 超参数配置完整,未指定项已按规范标注

Phase 2 - 架构可视化:

  • Mermaid 流程图包含所有主要组件
  • 每个节点标注了 Tensor Shape (使用标准符号)
  • 使用颜色区分不同类型组件
  • 深层网络使用模块级粒度

Phase 3 - 代码实现:

  • 继承 nn.Module
  • 完整的类型提示
  • 每层操作后有维度注释 (使用标准符号)
  • 子模块已拆分
  • 包含 _init_weights 方法
  • 包含 set_seed 函数
  • 未指定参数已按 Protocol 标注

Phase 4 - 模块文档:

  • 每个核心模块有独立文档
  • 文档包含完整的 6 个部分 (概述/数学/流程图/实现/示例/注意事项)
  • Mermaid 流程图清晰展示数据流
  • 张量形状变化图完整
  • 公式使用 LaTeX 格式并有符号说明表
  • 复杂公式包含推导过程
  • Loss Function 有详细分析
  • 文档结构符合规范 (docs/modules/, docs/math/, docs/diagrams/)

验证测试:

  • 随机 Tensor 前向传播无报错
  • 输出 Shape 与预期一致
  • 参数量与论文一致 (如有报告,允许 1% 误差)
  • 梯度流检查通过
  • 输出数值范围合理

常见陷阱检查:

  • 已检查 reshape/view 的使用
  • 已核对论文图表与文字描述
  • 已注意框架差异 (如 BatchNorm momentum)