---
name: Type Hints for ML Code
description: Apply appropriate type hints for ML/PyTorch code. Use when adding type annotations to ML code or addressing mypy errors.
---

# Type Hints for ML Code - Constitutional AI
## When to Apply
Automatically activate when:
- Adding type hints to ML code
- Addressing mypy type errors
- Working with PyTorch tensors, models, or optimizers
- Dealing with HuggingFace transformers types
## Project Type Checking Status
Current mypy status:
- 41 errors remaining (documented in MYPY_ANALYSIS_REPORT.md)
- Status: Accepted as reasonable for ML research code
- CI behavior: `continue-on-error: true` (doesn't fail builds)
**Key insight:** Perfect type coverage is not the goal for ML code. Prioritize correctness and readability over type perfection.
## When to Use Type Hints

### ✅ Do Use Type Hints For
**Public API functions**

```python
def load_model(model_name: str) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load model and tokenizer."""
    pass
```

**Configuration dataclasses**

```python
@dataclass
class TrainingConfig:
    learning_rate: float
    batch_size: int
    num_epochs: int
```

**Clear input/output types**

```python
def evaluate_text(text: str, framework: ConstitutionalFramework) -> dict[str, Any]:
    """Evaluate text against principles."""
    pass
```

**Helper functions with simple types**

```python
def format_prompt(prompt: str, examples: list[str]) -> str:
    """Format prompt with examples."""
    pass
```
### ❌ Don't Force Type Hints For

- Complex tensor operations (mypy struggles with tensor shapes)
- Dynamic PyTorch internals (intentionally uses `Any`)
- NumPy operations (overload resolution issues)
- Training loop internals (too complex, low value)
## PyTorch Type Patterns

### Tensor Types

```python
import torch
from torch import Tensor

# ✅ Basic tensor type
def forward(inputs: Tensor) -> Tensor:
    return inputs * 2

# ✅ Optional tensor (common in ML)
def process_batch(
    inputs: Tensor,
    labels: Tensor | None = None,
) -> Tensor:
    pass

# ⚠️ Avoid overly specific tensor types (mypy can't verify shapes).
# This is too specific and won't type-check well:
def bad_example(inputs: Tensor[int, 32, 768]) -> Tensor[int, 32, 10]:
    pass
```
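If shape information genuinely matters at an API boundary, third-party annotation packages can document it. A minimal sketch assuming the `jaxtyping` package is installed (it is not a project dependency); the string names the expected dimensions, but plain mypy does not verify them:

```python
# Sketch only: assumes the third-party `jaxtyping` package is installed.
# These annotations document expected shapes; mypy alone does not check them.
from jaxtyping import Float
from torch import Tensor

def attention_scores(
    queries: Float[Tensor, "batch seq dim"],
    keys: Float[Tensor, "batch seq dim"],
) -> Float[Tensor, "batch seq seq"]:
    # (batch, seq, dim) @ (batch, dim, seq) -> (batch, seq, seq)
    return queries @ keys.transpose(-1, -2)
```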
### Model and Optimizer Types

```python
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers import PreTrainedModel, PreTrainedTokenizer

# ✅ Use base classes for flexibility
def train_model(
    model: Module,  # or PreTrainedModel for HuggingFace
    optimizer: Optimizer,
    dataloader: DataLoader,
) -> None:
    pass

# ✅ HuggingFace-specific types
def generate_text(
    prompt: str,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    max_length: int = 100,
) -> str:
    pass
```
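For context, a body for the `generate_text` signature above might look like the following sketch (generation arguments vary by model, so treat this as illustrative rather than canonical):

```python
def generate_text(
    prompt: str,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    max_length: int = 100,
) -> str:
    # Tokenize, generate, then decode back to a plain string
    inputs = tokenizer(prompt, return_tensors="pt")
    output_ids = model.generate(**inputs, max_length=max_length)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```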
### Device Types

```python
import torch
from torch import Tensor
from torch import device as Device

# ✅ Device type
def move_to_device(tensor: Tensor, device: Device | str) -> Tensor:
    return tensor.to(device)

# Common pattern:
device: str | Device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```
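A small helper (the name `resolve_device` is hypothetical) keeps the `str | Device` union in one place instead of repeating it at every call site:

```python
import torch
from torch import device as Device

def resolve_device(preferred: str | Device | None = None) -> Device:
    """Resolve a device argument, defaulting to CUDA when available."""
    if preferred is not None:
        # torch.device accepts both strings and existing device objects
        return torch.device(preferred)
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
```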
## Handling `Any` Types

### When `Any` is Acceptable

```python
from typing import Any

from torch import Tensor

# ✅ For complex nested structures
def process_model_output(output: Any) -> dict[str, Any]:
    """Process model output (structure varies by model)."""
    pass

# ✅ For highly dynamic operations
def advanced_tensor_op(tensors: list[Tensor]) -> Any:
    """Complex operation with unpredictable output type."""
    pass

# ✅ For configuration dictionaries
config: dict[str, Any] = {
    'learning_rate': 0.001,
    'model_name': 'gpt2',
    'device': 'cuda',
}
```
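When the config keys are known up front, a `TypedDict` gives mypy something concrete to check while keeping plain dict syntax. A minimal sketch (the `TrainConfig` name is hypothetical):

```python
from typing import TypedDict

class TrainConfig(TypedDict):
    learning_rate: float
    model_name: str
    device: str

# Mypy now flags misspelled keys and wrong value types
typed_config: TrainConfig = {
    'learning_rate': 0.001,
    'model_name': 'gpt2',
    'device': 'cuda',
}
```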
### Prefer Specific Types When Possible

```python
# ❌ Too vague
def process_data(data: Any) -> Any:
    pass

# ✅ More specific
def process_data(data: list[dict[str, float]]) -> dict[str, Tensor]:
    pass
```
## Optional and Union Types

```python
import torch
from torch import Tensor

# Modern syntax (Python 3.10+)
def load_checkpoint(path: str | None = None) -> dict[str, Tensor] | None:
    if path is None:
        return None
    return torch.load(path)

# Equivalent older syntax
from typing import Optional

def load_checkpoint(path: Optional[str] = None) -> Optional[dict[str, Tensor]]:
    pass
```
## Common ML Type Patterns

### Return Multiple Values

```python
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

# ✅ Tuple with type hints
def load_model(name: str) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
    model = AutoModelForCausalLM.from_pretrained(name)
    tokenizer = AutoTokenizer.from_pretrained(name)
    return model, tokenizer

# Usage
model, tokenizer = load_model("gpt2")
```
### Dataclass for Complex Returns

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class EvaluationResult:
    any_flagged: bool
    flagged_principles: list[str]
    weighted_score: float
    details: dict[str, Any]

def evaluate_text(text: str) -> EvaluationResult:
    """Evaluate text and return structured result."""
    pass
```
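Continuing the snippet above, callers then get attribute access that mypy can actually check, instead of opaque string keys:

```python
result = evaluate_text("some generated text")
if result.any_flagged:
    # Mypy knows flagged_principles is list[str]
    print(", ".join(result.flagged_principles))
```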
### Generator Types

```python
from collections.abc import Iterator

from torch import Tensor

# Each yielded item is a list slice, hence Iterator[list[Tensor]]
def batch_generator(data: list[Tensor], batch_size: int) -> Iterator[list[Tensor]]:
    """Generate batches (lists of tensors) from data."""
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size]
```
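If a generator also returns a value or accepts `.send()`, the fuller `Generator[YieldType, SendType, ReturnType]` form from `collections.abc` captures all three. A sketch (the batch-counting behavior is hypothetical):

```python
from collections.abc import Generator

from torch import Tensor

def counting_batches(
    data: list[Tensor], batch_size: int
) -> Generator[list[Tensor], None, int]:
    """Yield batches, then return how many were produced."""
    count = 0
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size]
        count += 1
    return count  # surfaced as StopIteration.value
```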
## Known Mypy Challenges in Project

### Category 1: NumPy Overload Resolution (14 errors)

```python
# Mypy struggles with NumPy's 6+ overload variants
import numpy as np

# This may show a mypy error, but it's correct at runtime
scores = np.array([1.0, 2.0, 3.0])  # mypy: "cannot infer type"
```

**Resolution:** Use `# type: ignore[misc]` or accept the error (CI allows it).
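Where NumPy annotations do help, `numpy.typing.NDArray` narrows array types without fighting the overload machinery. A minimal sketch:

```python
import numpy as np
import numpy.typing as npt

def normalize(scores: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
    """Scale scores to zero mean and unit variance."""
    return (scores - scores.mean()) / scores.std()
```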
### Category 2: Optional Attribute Access (7 errors)

```python
# Mypy warns about potential None access
def process_model(model: PreTrainedModel | None):
    if model is not None:
        output = model.generate(...)  # mypy may still warn
```

**Resolution:** Assert non-None or use `# type: ignore`.
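Raising early (or asserting) narrows the type for the rest of the function, which usually silences the warning without an ignore. A sketch:

```python
from transformers import PreTrainedModel

def process_model(model: PreTrainedModel | None) -> None:
    if model is None:
        raise ValueError("model must be loaded before processing")
    # From here on, mypy narrows `model` to PreTrainedModel
    output = model.generate(max_length=50)
```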
### Category 3: Tensor Type Inference (8 errors)

```python
# PyTorch intentionally uses dynamic typing
loss = criterion(outputs, targets)  # mypy: "Cannot determine type"
```

**Resolution:** Accept as a limitation - PyTorch is dynamically typed by design.
## Type Hints Best Practices for This Project

### 1. Prioritize Public APIs

```python
# ✅ Type hints for exported functions
def setup_default_framework() -> ConstitutionalFramework:
    """Public API - should have types."""
    pass

# ⚠️ Optional for internal helpers
def _internal_helper(data):
    """Internal - types optional."""
    pass
```
### 2. Use `Any` Strategically

```python
# ✅ Good use of Any
def process_principle_config(config: dict[str, Any]) -> ConstitutionalPrinciple:
    """Config structure varies - Any is appropriate."""
    pass

# ❌ Overuse of Any
def add(a: Any, b: Any) -> Any:
    """Too vague - use specific types."""
    return a + b
```
### 3. Document Complex Types

```python
from typing import TypeAlias

from torch import Tensor

# Define aliases for complex types
TrainingBatch: TypeAlias = dict[str, Tensor]
EvalResults: TypeAlias = dict[str, float | bool | list[str]]

def train_step(batch: TrainingBatch) -> EvalResults:
    """Type aliases keep the signature readable."""
    pass
```
### 4. Handle Protocol/Abstract Types

```python
from typing import Any, Protocol

class Evaluator(Protocol):
    """Protocol for evaluation functions."""

    def evaluate(self, text: str) -> dict[str, Any]: ...

def run_evaluation(text: str, evaluator: Evaluator) -> dict[str, Any]:
    """Accept any object matching the Evaluator protocol."""
    return evaluator.evaluate(text)
```
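Because protocols are structural, any class with a matching `evaluate` method satisfies `Evaluator` without inheriting from it. Continuing the snippet above, a sketch with a hypothetical `KeywordEvaluator`:

```python
class KeywordEvaluator:
    """Hypothetical evaluator that flags listed keywords."""

    def __init__(self, keywords: list[str]) -> None:
        self.keywords = keywords

    def evaluate(self, text: str) -> dict[str, Any]:
        matches = [kw for kw in self.keywords if kw in text]
        return {"any_flagged": bool(matches), "matches": matches}

# Type-checks: KeywordEvaluator structurally matches Evaluator
result = run_evaluation("some text", KeywordEvaluator(["harm"]))
```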
## Mypy Configuration (from pyproject.toml)

```toml
[tool.mypy]
python_version = "3.10"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = false  # Not required for ML code
ignore_missing_imports = true  # Many ML libraries lack stubs
```

Key settings:

- `disallow_untyped_defs = false` - types are helpful but not required
- `ignore_missing_imports = true` - PyTorch and transformers lack complete type stubs
## When to Use `# type: ignore`

### Acceptable Use Cases

```python
# ✅ Known mypy limitation with NumPy
scores = np.array(data)  # type: ignore[misc]

# ✅ PyTorch dynamic typing
loss = criterion(outputs, targets)  # type: ignore[arg-type]

# ✅ Third-party library without stubs
from some_ml_lib import advanced_feature  # type: ignore
```
### Avoid Overuse

```python
# ❌ Don't ignore fixable errors
def add(a: int, b: int) -> str:
    return a + b  # type: ignore  # Fix the return type instead!

# ✅ Fix the actual issue
def add(a: int, b: int) -> int:
    return a + b
```
## Type Checking Commands

```bash
# Run mypy on all source code
mypy constitutional_ai/ --ignore-missing-imports

# Run on a specific module
mypy constitutional_ai/framework.py

# Show error codes (useful for targeted ignores)
mypy constitutional_ai/ --show-error-codes

# Strict mode (educational, will fail)
mypy constitutional_ai/ --strict
```
## Summary: Type Hint Philosophy for ML Code

- Use type hints where they add clarity (public APIs, config)
- Skip type hints where they fight the framework (PyTorch internals)
- Accept mypy errors for known ML ecosystem limitations (documented in MYPY_ANALYSIS_REPORT.md)
- Use `Any` strategically, not lazily (complex structures = OK, simple functions = not OK)
- CI should pass with 41 known errors (`continue-on-error: true`)
- Prioritize correctness over type perfection (tests are the real validation)
## Quick Reference

| Type | When to Use | Example |
|---|---|---|
| `str`, `int`, `float` | Primitives | `def format(text: str) -> str` |
| `list[T]`, `dict[K, V]` | Collections | `def process(data: list[str]) -> dict[str, int]` |
| `Tensor` | PyTorch tensors | `def forward(x: Tensor) -> Tensor` |
| `Module` | PyTorch models | `def train(model: Module) -> None` |
| `PreTrainedModel` | HF models | `def generate(model: PreTrainedModel) -> str` |
| `Any` | Unknown/dynamic | `def process(config: dict[str, Any]) -> Any` |
| `T \| None` | Optional | `def load(path: str \| None) -> Tensor \| None` |
| `tuple[A, B]` | Multiple returns | `def load_model() -> tuple[Module, Tokenizer]` |
**Remember:** The goal is helpful type hints, not perfect type coverage. This is research code.