experiment-tracking

@tachyon-beep/skillpacks

Experiment tracking - metrics, artifacts; TensorBoard, Weights & Biases, MLflow

SKILL.md

name: experiment-tracking
description: Experiment tracking - metrics, artifacts; TensorBoard, Weights & Biases, MLflow

Experiment Tracking Skill

When to Use This Skill

Use this skill when:

  • User starts training a model and asks "should I track this experiment?"
  • User wants to reproduce a previous result but doesn't remember settings
  • Training runs overnight and user needs persistent logs
  • User asks "which tool should I use: TensorBoard, W&B, or MLflow?"
  • Multiple experiments running and user can't compare results
  • User wants to share results with teammates or collaborators
  • Model checkpoints accumulating with no organization or versioning
  • User asks "what should I track?" or "how do I make experiments reproducible?"
  • Debugging training issues and needs historical data (metrics, gradients)
  • User wants to visualize training curves or compare hyperparameters
  • Working on a research project that requires tracking many experiments
  • User lost their best result and can't reproduce it

Do NOT use when:

  • User is doing quick prototyping with throwaway code (<5 minutes)
  • Only running inference on pre-trained models (no training)
  • Single experiment that's already tracked and working
  • User is asking about hyperparameter tuning strategy (not tracking)
  • Discussing model architecture design (not experiment management)

Core Principles

1. Track Before You Need It (Can't Add Retroactively)

The BIGGEST mistake: waiting to track until results are worth saving.

The Reality:

  • The best result is all too often the one you didn't track
  • Can't add tracking after the experiment completes
  • Human memory fails within hours (let alone days/weeks)
  • Print statements disappear when terminal closes
  • Code changes between experiments (git state matters)

When Tracking Matters:

Experiment value curve:
    ^
    |                    ╱─  Peak result (untracked = lost forever)
    |                  ╱
    |                ╱
    |              ╱
    |            ╱
    |          ╱
    |        ╱
    |      ╱
    |____╱________________________________>
         Start                        Time

If you wait to track "important" experiments, you've already lost them.

Track From Day 1:

  • First experiment (even if "just testing")
  • Every hyperparameter change
  • Every model architecture variation
  • Every data preprocessing change

Decision Rule: If you're running python train.py, you should be tracking. No exceptions.
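Tracking from day 1 does not require any of the tools covered below. As a minimal sketch using only the Python standard library (the RunLogger class, the runs/ directory layout, and the file names are illustrative assumptions, not part of any tool), each run gets its own directory, with the config written once and metrics appended as JSON lines:

```python
import json
import time
from pathlib import Path


class RunLogger:
    """Minimal tracker (hypothetical): config.json written once, metrics.jsonl appended per call."""

    def __init__(self, root, name, config):
        self.dir = Path(root) / name
        # exist_ok=False: never silently reuse (and overwrite) an earlier run's directory
        self.dir.mkdir(parents=True, exist_ok=False)
        (self.dir / "config.json").write_text(json.dumps(config, indent=2, sort_keys=True))
        self._metrics_path = self.dir / "metrics.jsonl"

    def log(self, step, **metrics):
        record = {"step": step, "time": time.time(), **metrics}
        # Open, append one line, close: closing flushes the line out of Python's buffer
        with self._metrics_path.open("a") as f:
            f.write(json.dumps(record) + "\n")

    def read_metrics(self):
        """Return every logged record, oldest first."""
        with self._metrics_path.open() as f:
            return [json.loads(line) for line in f]
```

Typical use is RunLogger("runs", "resnet18_cifar10_20241030_lr0.01", config) at startup and logger.log(step, train_loss=loss) inside the loop. Because each record is appended and the file closed on every call, a crash loses at most the line being written.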


2. Complete Tracking = Hyperparameters + Metrics + Artifacts + Code Version + Environment

Reproducibility requires tracking EVERYTHING that affects the result.

The Five Categories:

┌─────────────────────────────────────────────────────────┐
│ 1. HYPERPARAMETERS (what you're tuning)                 │
├─────────────────────────────────────────────────────────┤
│ • Learning rate, batch size, optimizer type             │
│ • Model architecture (width, depth, activation)         │
│ • Regularization (weight decay, dropout)                │
│ • Training length (epochs, steps)                       │
│ • Data augmentation settings                            │
└─────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────┐
│ 2. METRICS (how you're doing)                           │
├─────────────────────────────────────────────────────────┤
│ • Training loss (every step or epoch)                   │
│ • Validation loss (every epoch)                         │
│ • Evaluation metrics (accuracy, F1, mAP, etc.)          │
│ • Learning rate schedule (actual LR each step)          │
│ • Gradient norms (for debugging)                        │
└─────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────┐
│ 3. ARTIFACTS (what you're saving)                       │
├─────────────────────────────────────────────────────────┤
│ • Model checkpoints (with epoch/step metadata)          │
│ • Training plots (loss curves, confusion matrices)      │
│ • Predictions on validation set                         │
│ • Logs (stdout, stderr)                                 │
│ • Config files (for reproducibility)                    │
└─────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────┐
│ 4. CODE VERSION (what you're running)                   │
├─────────────────────────────────────────────────────────┤
│ • Git commit hash                                       │
│ • Git branch name                                       │
│ • Dirty status (uncommitted changes)                    │
│ • Code diff (if uncommitted)                            │
└─────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────┐
│ 5. ENVIRONMENT (where you're running)                   │
├─────────────────────────────────────────────────────────┤
│ • Python version, PyTorch version                       │
│ • CUDA version, GPU type                                │
│ • Random seeds (Python, NumPy, PyTorch, CUDA)           │
│ • Data version (if dataset changes)                     │
│ • Hardware (CPU, RAM, GPU count)                        │
└─────────────────────────────────────────────────────────┘

Reproducibility Test:

Can someone else (or future you) reproduce the result with ONLY the tracked information?

If NO, you're not tracking enough.
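The test can be made mechanical. A small sketch, assuming each run's record is a dictionary keyed by the five categories above (the key names and values are illustrative, not any tool's schema), that lists which categories are missing or empty:

```python
REQUIRED_CATEGORIES = ("hyperparameters", "metrics", "artifacts", "code_version", "environment")


def missing_categories(record):
    """Return the categories that are absent or empty in a run record."""
    return [c for c in REQUIRED_CATEGORIES if not record.get(c)]


# Illustrative record: the environment was never captured
record = {
    "hyperparameters": {"learning_rate": 0.01, "batch_size": 128},
    "metrics": {"val_accuracy": 0.87},
    "artifacts": ["checkpoints/best_model.pt"],
    "code_version": {"commit": "abc123", "is_dirty": False},
    "environment": {},
}
print(missing_categories(record))  # ['environment']
```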


3. Tool Selection: Local vs Team vs Production

Different tools for different use cases. Choose based on your needs.

Tool Comparison:

Feature              TensorBoard     Weights & Biases          MLflow           Custom
Setup complexity     Low             Low                       Medium           High
Runs locally         Yes             No (cloud by default)     Yes              Yes
Team collaboration   Limited         Excellent                 Good             Custom
Cost                 Free            Free tier + paid plans    Free             Free
Scalability          Medium          High                      High             Low
Visualization        Good            Excellent                 Good             Custom
Integrations         PyTorch, TF     Most frameworks           Most frameworks  Manual
Best for             Solo projects   Team research             Production       Specific needs

Decision Tree:

Do you need team collaboration?
├─ YES → Is a cloud-hosted service acceptable for your data?
│   ├─ YES → Weights & Biases (best team features)
│   └─ NO → MLflow (self-hosted, more control)
│
└─ NO → Will the project move to production deployment?
    ├─ NO → TensorBoard (simplest, local)
    └─ YES → MLflow (scales to production)

Budget constraints?
├─ FREE only → TensorBoard, MLflow, or W&B's free tier
└─ Can pay → W&B (worth it for teams)

Production deployment?
├─ YES → MLflow (production-ready)
└─ NO → TensorBoard or W&B (research)

Recommendation:

  • Starting out / learning: TensorBoard (easiest, free, local)
  • Research team / collaboration: Weights & Biases (best UX, sharing)
  • Production ML / enterprise: MLflow (self-hosted, model registry)
  • Specific needs / customization: Custom logging (CSV + Git)
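The recommendations above can be condensed into a helper. This is one reading of the guidance, not a rule, and the function and argument names are hypothetical:

```python
def recommend_tracker(team: bool, production: bool, custom_needs: bool = False) -> str:
    """Map the recommendation list onto a single suggested tool."""
    if custom_needs:
        return "Custom logging (CSV + Git)"
    if production:
        # Production/enterprise work takes priority: self-hosting and a model registry
        return "MLflow"
    if team:
        return "Weights & Biases"
    return "TensorBoard"


print(recommend_tracker(team=True, production=False))  # Weights & Biases
```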

4. Minimal Overhead, Maximum Value

Tracking should cost 1-5% overhead, not 50%.

What to Track at Different Frequencies:

# loss, optimizer, config, train_losses, val_loss, val_acc, epoch, and global_step
# come from your training loop; get_git_hash, save_checkpoint, and log_artifact
# are placeholders for your own helpers or your tracking tool's equivalents.
import sys
import torch

# Every step (high frequency, small data):
log_every_step = {
    "train_loss": loss.item(),
    "learning_rate": optimizer.param_groups[0]['lr'],
    "step": global_step,
}

# Every epoch (medium frequency, medium data):
log_every_epoch = {
    "train_loss_avg": sum(train_losses) / len(train_losses),  # train_losses: list of floats
    "val_loss": val_loss,
    "val_accuracy": val_acc,
    "epoch": epoch,
}

# Once per experiment (low frequency, large data):
log_once = {
    "hyperparameters": config,
    "git_commit": get_git_hash(),
    "environment": {
        "python_version": sys.version,
        "torch_version": torch.__version__,
        "cuda_version": torch.version.cuda,
    },
}

# Only on improvement (conditional):
if val_loss < best_val_loss:
    best_val_loss = val_loss
    save_checkpoint(model, optimizer, epoch, val_loss)
    log_artifact("best_model.pt")

Overhead Guidelines (rough figures; actual cost depends on model size, I/O, and backend):

  • Logging scalars (loss, accuracy): <0.1% overhead (always do)
  • Logging images/plots: 1-2% overhead (do every epoch)
  • Logging checkpoints: 5-10% overhead (do only on improvement)
  • Logging gradients: 10-20% overhead (do only for debugging)
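One way to apply these frequencies without scattering modulo checks through the training loop is a small schedule object. A sketch; the LogSchedule class and the intervals are illustrative assumptions, to be tuned against the overhead you actually measure:

```python
class LogSchedule:
    """Decide whether a group of values is due for logging at a given step (hypothetical helper)."""

    def __init__(self, intervals):
        # intervals: group name -> log every N steps; groups not listed are never logged
        self.intervals = dict(intervals)

    def due(self, group, step):
        every = self.intervals.get(group)
        return every is not None and step % every == 0


# Illustrative intervals: cheap scalars often, images rarely, gradients only occasionally
schedule = LogSchedule({"scalars": 10, "images": 1000, "gradients": 5000})
print([s for s in range(3001) if schedule.due("images", s)])  # [0, 1000, 2000, 3000]
```

Inside the loop this reads as if schedule.due("scalars", global_step): followed by the logging call.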

Don't Track:

  • Raw training data (too large, use data versioning instead)
  • Every intermediate activation (use profiling tools instead)
  • Full model weights every step (only on improvement)

5. Experiment Organization: Naming, Tagging, Grouping

With 100+ experiments, organization is survival.

Naming Convention:

# GOOD: Descriptive, sortable, parseable
experiment_name = f"{model}_{dataset}_{timestamp}_{hyperparams}"
# Examples:
# "resnet18_cifar10_20241030_lr0.01_bs128"
# "bert_squad_20241030_lr3e-5_warmup1000"
# "gpt2_wikitext_20241030_ctx512_layers12"

# BAD: Uninformative
experiment_name = "test"
experiment_name = "final"
experiment_name = "model_v2"
experiment_name = "test_again_actually_final"
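A name is only parseable if it is built by one function and read back by its inverse. A sketch, assuming model, dataset, and hyperparameter keys contain no underscores and each hyperparameter is written as its key immediately followed by its value (as in lr0.01); make_run_name and parse_run_name are hypothetical helpers:

```python
import re
from datetime import date

_HPARAM = re.compile(r"([A-Za-z]+)(.+)")  # leading letters are the key, the rest is the value


def make_run_name(model, dataset, day, **hparams):
    """Build '<model>_<dataset>_<YYYYMMDD>_<key><value>...' from its parts."""
    parts = [model, dataset, day.strftime("%Y%m%d")]
    parts += [f"{key}{value}" for key, value in hparams.items()]
    return "_".join(parts)


def parse_run_name(name):
    """Inverse of make_run_name; hyperparameter values come back as strings."""
    model, dataset, day, *tokens = name.split("_")
    hparams = {}
    for token in tokens:
        match = _HPARAM.fullmatch(token)
        if match is None:
            raise ValueError(f"cannot parse hyperparameter token {token!r}")
        hparams[match.group(1)] = match.group(2)
    return {"model": model, "dataset": dataset, "date": day, "hparams": hparams}


name = make_run_name("resnet18", "cifar10", date(2024, 10, 30), lr=0.01, bs=128)
print(name)  # resnet18_cifar10_20241030_lr0.01_bs128
print(parse_run_name("bert_squad_20241030_lr3e-5_warmup1000")["hparams"])
# {'lr': '3e-5', 'warmup': '1000'}
```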

Tagging Strategy:

# Tags for filtering and grouping
tags = {
    "model": "resnet18",
    "dataset": "cifar10",
    "experiment_type": "hyperparameter_search",
    "status": "completed",
    "goal": "beat_baseline",
    "author": "john",
}

# Can filter later:
# - Show me all "hyperparameter_search" experiments
# - Show me all "resnet18" on "cifar10"
# - Show me experiments by "john"
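Most tools provide tag filtering in their UI or API; the underlying idea fits in a few lines of plain Python. A sketch with hypothetical runs and a hypothetical filter_runs helper:

```python
def filter_runs(runs, **criteria):
    """Keep runs whose tags match every key=value pair in criteria."""
    return [r for r in runs if all(r["tags"].get(k) == v for k, v in criteria.items())]


# Hypothetical runs and tags
runs = [
    {"name": "resnet18_cifar10_20241030_lr0.01",
     "tags": {"model": "resnet18", "dataset": "cifar10",
              "experiment_type": "hyperparameter_search", "author": "john"}},
    {"name": "resnet34_cifar10_20241031_lr0.01",
     "tags": {"model": "resnet34", "dataset": "cifar10",
              "experiment_type": "architecture_search", "author": "john"}},
    {"name": "resnet18_cifar100_20241101_lr0.1",
     "tags": {"model": "resnet18", "dataset": "cifar100",
              "experiment_type": "hyperparameter_search", "author": "maria"}},
]

print([r["name"] for r in filter_runs(runs, model="resnet18", dataset="cifar10")])
# ['resnet18_cifar10_20241030_lr0.01']
```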

Grouping Related Experiments:

# Group by goal/project
project = "cifar10_sota"
group = "learning_rate_search"
experiment_name = f"{project}/{group}/lr_{lr}"

# Hierarchy:
# cifar10_sota/
#   ├─ learning_rate_search/
#   │   ├─ lr_0.001
#   │   ├─ lr_0.01
#   │   └─ lr_0.1
#   ├─ architecture_search/
#   │   ├─ resnet18
#   │   ├─ resnet34
#   │   └─ resnet50
#   └─ regularization_search/
#       ├─ dropout_0.1
#       ├─ dropout_0.3
#       └─ dropout_0.5
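The same hierarchy can be rebuilt from the run names alone. A sketch, assuming names follow the project/group/run pattern above; group_runs is a hypothetical helper:

```python
from collections import defaultdict


def group_runs(paths):
    """Turn 'project/group/run' names into {project: {group: [run, ...]}}."""
    tree = defaultdict(lambda: defaultdict(list))
    for path in paths:
        # Needs at least three levels; anything after the second '/' stays in the run name
        project, group, run = path.split("/", 2)
        tree[project][group].append(run)
    return {project: dict(groups) for project, groups in tree.items()}


names = [
    "cifar10_sota/learning_rate_search/lr_0.001",
    "cifar10_sota/learning_rate_search/lr_0.01",
    "cifar10_sota/architecture_search/resnet18",
]
print(group_runs(names))
```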

Tool-Specific Integration

TensorBoard (Local, Simple)

Setup:

from torch.utils.tensorboard import SummaryWriter

# Create writer
writer = SummaryWriter(f"runs/{experiment_name}")

# Hyperparameters for this run; they are paired with final metrics after training
hparams = {
    "learning_rate": 0.01,
    "batch_size": 128,
    "optimizer": "adam",
}

During Training:

best_val_acc = 0.0
for epoch in range(num_epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        # Training step
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Log every N steps
        global_step = epoch * len(train_loader) + batch_idx
        if global_step % log_interval == 0:
            writer.add_scalar("train/loss", loss.item(), global_step)
            writer.add_scalar("train/lr", optimizer.param_groups[0]['lr'], global_step)

    # Validation
    val_loss, val_acc = evaluate(model, val_loader)
    best_val_acc = max(best_val_acc, val_acc)
    writer.add_scalar("val/loss", val_loss, epoch)
    writer.add_scalar("val/accuracy", val_acc, epoch)

    # Log figures (confusion matrix, etc.) every 10 epochs
    if epoch % 10 == 0:
        fig = plot_confusion_matrix(model, val_loader)
        writer.add_figure("val/confusion_matrix", fig, epoch)

# Pair hyperparameters with the real final metric for TensorBoard's HParams dashboard
writer.add_hparams(hparams, {"hparam/best_val_acc": best_val_acc})
writer.close()

View Results:

tensorboard --logdir=runs
# Opens web UI at http://localhost:6006

Pros:

  • Simple setup (2 lines of code)
  • Local (no cloud dependency)
  • Good visualizations (scalars, images, graphs)
  • Integrated with PyTorch

Cons:

  • Hyperparameter comparison limited to the HParams dashboard (fed by add_hparams)
  • Limited team collaboration
  • No artifact storage (checkpoints)
  • Manual experiment management

Weights & Biases (Team, Cloud)

Setup:

import wandb

# Initialize experiment
wandb.init(
    project="cifar10-sota",
    name=experiment_name,
    config={
        "learning_rate": 0.01,
        "batch_size": 128,
        "optimizer": "adam",
        "model": "resnet18",
        "dataset": "cifar10",
    },
    tags=["hyperparameter_search", "resnet"],
)

# Config is automatically tracked
config = wandb.config

During Training:

best_val_acc = 0.0  # best validation accuracy seen so far
for epoch in range(num_epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Log metrics
        wandb.log({
            "train/loss": loss.item(),
            "train/lr": optimizer.param_groups[0]['lr'],
            "epoch": epoch,
        })

    # Validation
    val_loss, val_acc = evaluate(model, val_loader)
    wandb.log({
        "val/loss": val_loss,
        "val/accuracy": val_acc,
        "epoch": epoch,
    })

    # Save checkpoint
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model.pt")
        wandb.save("best_model.pt")  # Upload to cloud

# Log final results
wandb.log({"best_val_accuracy": best_val_acc})
wandb.finish()

Advanced Features:

# Log images
wandb.log({"examples": [wandb.Image(img, caption=f"Pred: {pred}") for img, pred in samples]})

# Log plots
fig = plot_confusion_matrix(model, val_loader)
wandb.log({"confusion_matrix": wandb.Image(fig)})

# Log tables (for result analysis)
table = wandb.Table(columns=["epoch", "train_loss", "val_loss", "val_acc"])
for epoch, tl, vl, va in zip(epochs, train_losses, val_losses, val_accs):
    table.add_data(epoch, tl, vl, va)
wandb.log({"results": table})

# Log gradient and parameter histograms every 100 batches
# (call once, before the training loop starts)
wandb.watch(model, log="all", log_freq=100)

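View Results:

wandb.init() prints a link to the run page; every run in the project is listed on the project's page at https://wandb.ai.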

Pros:

  • Excellent team collaboration (share links)
  • Beautiful visualizations
  • Hyperparameter comparison (parallel coordinates)
  • Artifact versioning (models, data)
  • Integration with everything (PyTorch, TF, JAX)

Cons:

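  • Cloud-hosted by default (syncing needs network access; offline mode can sync later)
  • Free tier has storage and usage limits (check the current plan)
  • Data leaves your machine unless you self-host (privacy concern)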

MLflow (Production, Self-Hosted)

Setup:

import mlflow
import mlflow.pytorch
import torch

# Start experiment
mlflow.set_experiment("cifar10-sota")

# Start run
with mlflow.start_run(run_name=experiment_name) as run:
    # Log hyperparameters
    mlflow.log_param("learning_rate", 0.01)
    mlflow.log_param("batch_size", 128)
    mlflow.log_param("optimizer", "adam")
    mlflow.log_param("model", "resnet18")

    # Training loop
    best_val_acc = 0.0
    for epoch in range(num_epochs):
        train_loss = train_epoch(model, train_loader, optimizer)
        val_loss, val_acc = evaluate(model, val_loader)

        # Log metrics
        mlflow.log_metric("train_loss", train_loss, step=epoch)
        mlflow.log_metric("val_loss", val_loss, step=epoch)
        mlflow.log_metric("val_accuracy", val_acc, step=epoch)

        # Keep the best weights on disk so they can be logged as an artifact
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "best_model.pt")

    # Log final metrics
    mlflow.log_metric("best_val_accuracy", best_val_acc)

    # Log the final model in MLflow's PyTorch format
    mlflow.pytorch.log_model(model, "model")

    # Log artifacts (config.yaml must already exist in the working directory)
    mlflow.log_artifact("config.yaml")
    mlflow.log_artifact("best_model.pt")

View Results:

mlflow ui
# Opens web UI at http://localhost:5000

Model Registry (for production):

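# Register the model logged under "model" in a finished run
# (run_id is run.info.run_id, e.g. from `with mlflow.start_run() as run:`)
model_uri = f"runs:/{run_id}/model"
model_version = mlflow.register_model(model_uri, "cifar10-resnet18")

# Load that registered version. Stage URIs such as "models:/cifar10-resnet18/Production"
# belong to the older stage workflow; newer MLflow releases favor aliases
# ("models:/cifar10-resnet18@<alias>").
model = mlflow.pytorch.load_model(f"models:/cifar10-resnet18/{model_version.version}")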

Pros:

  • Self-hosted (full control, privacy)
  • Model registry (production deployment)
  • Scales to large teams
  • Integration with deployment tools

Cons:

  • More setup for team use (a shared tracking server; mlflow ui alone is enough for local runs)
  • Visualization not as good as W&B
  • Less intuitive UI

Reproducibility Patterns

1. Seed Everything

import random
import numpy as np
import torch

def set_seed(seed):
    """Set all random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Deterministic operations (slower but reproducible)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# At start of training
set_seed(42)

# Log seed
config = {"seed": 42}

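Warning: Deterministic mode can be noticeably slower (the cost depends on the model and hardware), and these cuDNN flags alone do not make every CUDA operation deterministic. Trade off speed against reproducibility.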


2. Capture Git State

import subprocess

def get_git_info():
    """Capture current git state."""
    try:
        commit = subprocess.check_output(['git', 'rev-parse', 'HEAD'], encoding='utf-8').strip()
        branch = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD'], encoding='utf-8').strip()

        # Check for uncommitted changes (untracked files also count)
        status = subprocess.check_output(['git', 'status', '--porcelain'], encoding='utf-8').strip()
        is_dirty = len(status) > 0

        # Get diff if dirty: `git diff HEAD` covers staged and unstaged edits
        # (untracked files are not in the diff; commit them or record them separately)
        diff = None
        if is_dirty:
            diff = subprocess.check_output(['git', 'diff', 'HEAD'], encoding='utf-8')

        return {
            "commit": commit,
            "branch": branch,
            "is_dirty": is_dirty,
            "diff": diff,
        }
    except Exception as e:
        return {"error": str(e)}

# Log git info
git_info = get_git_info()
if git_info.get("is_dirty"):
    print("WARNING: Uncommitted changes detected!")
    print("Experiment may not be reproducible without the diff.")

3. Environment Capture

import sys
import torch

def get_environment_info():
    """Capture environment details."""
    return {
        "python_version": sys.version,
        "torch_version": torch.__version__,
        "cuda_version": torch.version.cuda,
        "cudnn_version": torch.backends.cudnn.version(),
        "gpu_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
        "gpu_count": torch.cuda.device_count(),
    }

# Save requirements.txt
# pip freeze > requirements.txt

# Or use pip-tools
# pip-compile requirements.in
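pip freeze captures packages from the shell; inside a training script, the standard library's importlib.metadata (Python 3.8+) can record the same information next to the run. A sketch; the function names and the default package list are assumptions, and calling package_versions() with no argument records every installed distribution, roughly like pip freeze:

```python
import json
import platform
import sys
from importlib import metadata


def package_versions(wanted=None):
    """Return {distribution name: version}; all installed packages if wanted is None."""
    versions = {}
    for dist in metadata.distributions():
        name = dist.metadata["Name"]
        if not name:
            continue  # skip distributions with broken metadata
        if wanted is None or name.lower() in wanted:
            versions[name] = dist.version
    return dict(sorted(versions.items(), key=lambda kv: kv[0].lower()))


def environment_snapshot(packages=("torch", "numpy")):
    """Interpreter, OS, and selected package versions as a JSON-friendly dict."""
    return {
        "python": sys.version,
        "platform": platform.platform(),
        "packages": package_versions({p.lower() for p in packages}),
    }


print(json.dumps(environment_snapshot(), indent=2))
```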

4. Config Files for Reproducibility

# config.yaml
model:
  name: resnet18
  num_classes: 10

training:
  learning_rate: 0.01
  batch_size: 128
  num_epochs: 100
  optimizer: adam
  weight_decay: 0.0001

data:
  dataset: cifar10
  augmentation: true
  normalize: true

# Load config
import yaml
with open("config.yaml") as f:
    config = yaml.safe_load(f)

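# Save config alongside results (create the results directory first)
import shutil
from pathlib import Path
results_dir = Path("results") / experiment_name
results_dir.mkdir(parents=True, exist_ok=True)
shutil.copy("config.yaml", results_dir / "config.yaml")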

Experiment Comparison

1. Comparing Metrics

# TensorBoard: Compare multiple runs
# tensorboard --logdir=runs --port=6006
# Select multiple runs in UI

# W&B: Filter and compare
# Go to project page, select runs, click "Compare"

# MLflow: Query experiments
import mlflow

# Get all runs from an experiment
experiment = mlflow.get_experiment_by_name("cifar10-sota")
runs = mlflow.search_runs(experiment_ids=[experiment.experiment_id])

# Filter by metric
good_runs = runs[runs["metrics.val_accuracy"] > 0.85]

# Sort by metric
ranked_runs = runs.sort_values("metrics.val_accuracy", ascending=False)

# Analyze hyperparameter impact
import seaborn as sns

# Plot learning rate vs accuracy (search_runs returns params as strings, so convert)
sns.scatterplot(x=runs["params.learning_rate"].astype(float), y=runs["metrics.val_accuracy"])

2. Hyperparameter Analysis

# W&B: Parallel coordinates plot
# Shows which hyperparameter combinations lead to best results
# UI: Click "Parallel Coordinates" in project view

# MLflow: Custom analysis
import matplotlib.pyplot as plt

# Group by hyperparameter
for lr in [0.001, 0.01, 0.1]:
    lr_runs = runs[runs["params.learning_rate"] == str(lr)]
    accuracies = lr_runs["metrics.val_accuracy"]
    plt.scatter([lr] * len(accuracies), accuracies, alpha=0.5, label=f"LR={lr}")

plt.xlabel("Learning Rate")
plt.ylabel("Validation Accuracy")
plt.xscale("log")
plt.legend()
plt.title("Learning Rate vs Accuracy")
plt.show()
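With several seeds per setting, a per-value mean and spread is often easier to read than a scatter of points. A plain-Python sketch; summarize_by is a hypothetical helper, the runs are made-up numbers, and rows from the MLflow DataFrame could be fed in via runs.to_dict("records") using keys such as "params.learning_rate":

```python
import statistics
from collections import defaultdict


def summarize_by(runs, param, metric):
    """Group runs by one hyperparameter value; report mean, std, and count of a metric."""
    groups = defaultdict(list)
    for run in runs:
        groups[run[param]].append(run[metric])
    return {
        value: {
            "mean": statistics.mean(vals),
            # stdev needs at least two samples; report 0.0 for a single run
            "std": statistics.stdev(vals) if len(vals) > 1 else 0.0,
            "n": len(vals),
        }
        for value, vals in sorted(groups.items())
    }


# Illustrative runs (made-up numbers): two seeds at lr=0.01, one at lr=0.1
example_runs = [
    {"learning_rate": 0.01, "val_accuracy": 0.80},
    {"learning_rate": 0.01, "val_accuracy": 0.84},
    {"learning_rate": 0.1, "val_accuracy": 0.70},
]
for lr, stats in summarize_by(example_runs, "learning_rate", "val_accuracy").items():
    print(f"lr={lr}: {stats['mean']:.3f} ± {stats['std']:.3f} (n={stats['n']})")
```

A single run per setting shows up with n=1, which is a reminder that its number says little on its own.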

3. Comparing Artifacts

# Compare model checkpoints
import torch
from torchvision.models import resnet18

# Load two models (num_classes must match the checkpoints; 10 for CIFAR-10)
model_a = resnet18(num_classes=10)
model_a.load_state_dict(torch.load("experiments/exp_a/best_model.pt", map_location="cpu"))

model_b = resnet18(num_classes=10)
model_b.load_state_dict(torch.load("experiments/exp_b/best_model.pt", map_location="cpu"))

# Compare on validation set (evaluate returns (loss, accuracy) as in the examples above)
_, acc_a = evaluate(model_a, val_loader)
_, acc_b = evaluate(model_b, val_loader)

print(f"Model A: {acc_a:.2%}")
print(f"Model B: {acc_b:.2%}")

# Compare predictions (eval mode, no gradient tracking)
model_a.eval()
model_b.eval()
with torch.no_grad():
    preds_a = model_a(val_data).argmax(1)
    preds_b = model_b(val_data).argmax(1)
agreement = (preds_a == preds_b).float().mean().item()
print(f"Prediction agreement: {agreement:.2%}")

Collaboration Workflows

1. Sharing Results (W&B)

# Share experiment link
# https://wandb.ai/your-username/cifar10-sota/runs/run-id

# Create report
# W&B UI: Click "Create Report" → Add charts, text, code

# Export results
# W&B UI: Click "Export" → CSV, JSON, or API

# API access for programmatic sharing
import wandb
api = wandb.Api()
runs = api.runs("your-username/cifar10-sota")

for run in runs:
    # Use the key passed to wandb.log; .get avoids a KeyError for runs that never logged it
    print(f"{run.name}: {run.summary.get('val/accuracy')}")

2. Team Experiment Dashboard

# MLflow: Shared tracking server
# Server machine:
mlflow server --host 0.0.0.0 --port 5000

# Team members:
import mlflow
mlflow.set_tracking_uri("http://shared-server:5000")

# Everyone logs to same server
with mlflow.start_run():
    mlflow.log_metric("val_accuracy", 0.87)

3. Experiment Handoff

# Package experiment for reproducibility
experiment_package = {
    "code": "git_commit_hash",
    "config": "config.yaml",
    "model": "best_model.pt",
    "results": "results.csv",
    "logs": "training.log",
    "environment": "requirements.txt",
}

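Reproducibility script (reproduce.sh, with <commit-hash> filled in from the tracked git state):

#!/bin/bash
set -euo pipefail
git checkout <commit-hash>
# If the run had uncommitted changes, also apply the saved diff:
# git apply git_diff.patch
pip install -r requirements.txt
python train.py --config config.yaml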

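A handoff package is easier to trust if the recipient can check that the files arrived unchanged. A sketch that writes a manifest of SHA-256 hashes next to the experiment files; the manifest.json name and layout and the helper names are assumptions, not a standard format:

```python
import hashlib
import json
from pathlib import Path


def sha256_of(path, chunk_size=1 << 20):
    """Hash a file in chunks so large checkpoints need not fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def write_manifest(exp_dir, files, git_commit):
    """Record each handoff file's size and SHA-256 in exp_dir/manifest.json."""
    exp_dir = Path(exp_dir)
    manifest = {
        "git_commit": git_commit,
        "files": {
            name: {"sha256": sha256_of(exp_dir / name), "bytes": (exp_dir / name).stat().st_size}
            for name in files
        },
    }
    (exp_dir / "manifest.json").write_text(json.dumps(manifest, indent=2))
    return manifest


def verify_manifest(exp_dir):
    """Return the names of files whose current hash no longer matches the manifest."""
    exp_dir = Path(exp_dir)
    manifest = json.loads((exp_dir / "manifest.json").read_text())
    return [name for name, info in manifest["files"].items()
            if sha256_of(exp_dir / name) != info["sha256"]]
```

The sender calls write_manifest(exp_dir, ["config.yaml", "best_model.pt", "requirements.txt"], commit); the recipient runs verify_manifest(exp_dir) and expects an empty list.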
Complete Tracking Example

Here's a fairly complete tracking setup that logs to TensorBoard, W&B, and local files at once:

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import wandb
import yaml
import subprocess
from pathlib import Path
from datetime import datetime

class ExperimentTracker:
    """Complete experiment tracking wrapper."""

    def __init__(self, config, experiment_name=None, use_wandb=True, use_tensorboard=True):
        self.config = config
        self.use_wandb = use_wandb
        self.use_tensorboard = use_tensorboard

        # Generate experiment name
        if experiment_name is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            experiment_name = f"{config['model']}_{config['dataset']}_{timestamp}"
        self.experiment_name = experiment_name

        # Create experiment directory
        self.exp_dir = Path(f"experiments/{experiment_name}")
        self.exp_dir.mkdir(parents=True, exist_ok=True)

        # Initialize tracking tools
        if self.use_tensorboard:
            self.tb_writer = SummaryWriter(self.exp_dir / "tensorboard")

        if self.use_wandb:
            wandb.init(
                project=config.get("project", "default"),
                name=experiment_name,
                config=config,
                dir=self.exp_dir,
            )

        # Save config
        with open(self.exp_dir / "config.yaml", "w") as f:
            yaml.dump(config, f)

        # Capture environment
        self._log_environment()

        # Capture git state
        self._log_git_state()

        # Setup logging
        self._setup_logging()

        self.global_step = 0
        self.best_metric = float('-inf')

    def _log_environment(self):
        """Log environment information."""
        import sys
        env_info = {
            "python_version": sys.version,
            "torch_version": torch.__version__,
            "cuda_version": torch.version.cuda if torch.cuda.is_available() else None,
            "gpu_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
            "gpu_count": torch.cuda.device_count(),
        }

        # Save to file
        with open(self.exp_dir / "environment.yaml", "w") as f:
            yaml.dump(env_info, f)

        # Log to W&B
        if self.use_wandb:
            wandb.config.update({"environment": env_info})

    def _log_git_state(self):
        """Log git commit and status."""
        try:
            commit = subprocess.check_output(['git', 'rev-parse', 'HEAD'], encoding='utf-8').strip()
            branch = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD'], encoding='utf-8').strip()
            status = subprocess.check_output(['git', 'status', '--porcelain'], encoding='utf-8').strip()
            is_dirty = len(status) > 0

            git_info = {
                "commit": commit,
                "branch": branch,
                "is_dirty": is_dirty,
            }

            # Save to file
            with open(self.exp_dir / "git_info.yaml", "w") as f:
                yaml.dump(git_info, f)

            # Save diff if dirty (`git diff HEAD` includes staged and unstaged edits;
            # untracked files are not included)
            if is_dirty:
                diff = subprocess.check_output(['git', 'diff', 'HEAD'], encoding='utf-8')
                with open(self.exp_dir / "git_diff.patch", "w", encoding="utf-8") as f:
                    f.write(diff)
                print("WARNING: Uncommitted changes detected! Saved to git_diff.patch")

            # Log to W&B
            if self.use_wandb:
                wandb.config.update({"git": git_info})

        except Exception as e:
            print(f"Failed to capture git state: {e}")

    def _setup_logging(self):
        """Setup file logging."""
        import logging
        self.logger = logging.getLogger(self.experiment_name)
        self.logger.setLevel(logging.INFO)

        # File handler
        fh = logging.FileHandler(self.exp_dir / "training.log")
        fh.setLevel(logging.INFO)

        # Console handler
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)

        # Formatter
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        fh.setFormatter(formatter)
        ch.setFormatter(formatter)

        self.logger.addHandler(fh)
        self.logger.addHandler(ch)

    def log_metrics(self, metrics, step=None):
        """Log metrics to all tracking backends."""
        if step is None:
            step = self.global_step

        # TensorBoard
        if self.use_tensorboard:
            for key, value in metrics.items():
                if isinstance(value, (int, float)):
                    self.tb_writer.add_scalar(key, value, step)

        # W&B
        if self.use_wandb:
            wandb.log(metrics, step=step)

        # File
        self.logger.info(f"Step {step}: {metrics}")

        self.global_step = step + 1

    def save_checkpoint(self, model, optimizer, epoch, metric_value, metric_name="val_accuracy"):
        """Save model checkpoint with metadata."""
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            metric_name: metric_value,
            "config": self.config,
        }

        # Save this epoch's checkpoint
        checkpoint_path = self.exp_dir / "checkpoints" / f"checkpoint_epoch_{epoch}.pt"
        checkpoint_path.parent.mkdir(exist_ok=True)
        torch.save(checkpoint, checkpoint_path)

        # Save best checkpoint
        if metric_value > self.best_metric:
            self.best_metric = metric_value
            best_path = self.exp_dir / "checkpoints" / "best_model.pt"
            torch.save(checkpoint, best_path)

            self.logger.info(f"New best model saved: {metric_name}={metric_value:.4f}")

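            # Record the best value in the W&B run summary; a wandb.log call without
            # step= here would advance W&B's step counter and break the explicit steps
            if self.use_wandb:
                wandb.run.summary[f"best_{metric_name}"] = metric_value
                wandb.save(str(best_path))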

        return checkpoint_path

    def log_figure(self, name, figure, step=None):
        """Log matplotlib figure."""
        if step is None:
            step = self.global_step

        # TensorBoard
        if self.use_tensorboard:
            self.tb_writer.add_figure(name, figure, step)

        # W&B
        if self.use_wandb:
            wandb.log({name: wandb.Image(figure)}, step=step)

        # Save to disk
        fig_path = self.exp_dir / "figures" / f"{name}_step_{step}.png"
        fig_path.parent.mkdir(exist_ok=True)
        figure.savefig(fig_path)

    def finish(self):
        """Clean up and close tracking backends."""
        if self.use_tensorboard:
            self.tb_writer.close()

        if self.use_wandb:
            wandb.finish()

        self.logger.info("Experiment tracking finished.")


# Usage example
if __name__ == "__main__":
    config = {
        "project": "cifar10-sota",
        "model": "resnet18",
        "dataset": "cifar10",
        "learning_rate": 0.01,
        "batch_size": 128,
        "num_epochs": 100,
        "optimizer": "adam",
        "seed": 42,
    }

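    # Initialize tracker, then seed everything (set_seed from "Seed Everything" above)
    tracker = ExperimentTracker(config)
    set_seed(config["seed"])

    # Training loop. create_model, train_epoch, evaluate, plot_confusion_matrix,
    # train_loader, and val_loader are placeholders for your own code.
    model = create_model(config)
    optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])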

    for epoch in range(config["num_epochs"]):
        train_loss = train_epoch(model, train_loader, optimizer)
        val_loss, val_acc = evaluate(model, val_loader)

        # Log metrics
        tracker.log_metrics({
            "train/loss": train_loss,
            "val/loss": val_loss,
            "val/accuracy": val_acc,
            "epoch": epoch,
        })

        # Save checkpoint
        tracker.save_checkpoint(model, optimizer, epoch, val_acc)

        # Log figure (every 10 epochs)
        if epoch % 10 == 0:
            fig = plot_confusion_matrix(model, val_loader)
            tracker.log_figure("confusion_matrix", fig)

    # Finish
    tracker.finish()

Pitfalls and Anti-Patterns

Pitfall 1: Tracking Metrics But Not Config

Symptom: Have CSV with 50 experiments' metrics, but no idea what hyperparameters produced them.

Why It Happens:

  • User focuses on "what matters" (the metric)
  • Assumes they'll remember settings
  • Doesn't realize metrics without context are useless

Fix:

# WRONG: Only metrics
with open("results.csv", "a") as f:
    f.write(f"{epoch},{train_loss},{val_loss}\n")

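# RIGHT: Metrics + config, linked by one experiment ID
experiment_id = f"exp_{timestamp}"
# Once, at the start of the run:
with open(f"{experiment_id}_config.yaml", "w") as f:
    yaml.dump(config, f)
# Every epoch (append mode, so earlier rows are kept):
with open(f"{experiment_id}_results.csv", "a") as f:
    f.write(f"{epoch},{train_loss},{val_loss}\n")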

Pitfall 2: Overwriting Checkpoints Without Versioning

Symptom: Always saving to "best_model.pt", can't recover earlier checkpoints.

Why It Happens:

  • Disk space concerns (real, but better handled by pruning old checkpoints than by keeping only one)
  • Only care about "best" model
  • Don't anticipate evaluation bugs

Fix:

# WRONG: Overwriting
torch.save(model.state_dict(), "best_model.pt")

# RIGHT: Versioned checkpoints
torch.save(model.state_dict(), f"checkpoints/model_epoch_{epoch}.pt")
torch.save(model.state_dict(), f"checkpoints/best_model_val_acc_{val_acc:.4f}.pt")
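Versioned checkpoints trade against the storage bloat described in Pitfall 6. A common compromise is to keep only the newest few epoch checkpoints plus the best one. A sketch, assuming the model_epoch_<N>.pt naming from the fix above; prune_checkpoints is a hypothetical helper:

```python
from pathlib import Path


def prune_checkpoints(ckpt_dir, keep_last=3, pattern="model_epoch_*.pt", protect=()):
    """Delete all but the newest `keep_last` epoch checkpoints; return deleted names.

    Epochs are parsed from the file name and sorted numerically, so epoch 10
    counts as newer than epoch 9. Names listed in `protect` are never deleted.
    """
    ckpt_dir = Path(ckpt_dir)
    protected = {Path(p).name for p in protect}
    ckpts = sorted(ckpt_dir.glob(pattern), key=lambda p: int(p.stem.rsplit("_", 1)[1]))
    stale = ckpts[:-keep_last] if keep_last > 0 else ckpts
    removed = []
    for path in stale:
        if path.name not in protected:
            path.unlink()
            removed.append(path.name)
    return removed
```

Calling it after each save keeps disk usage bounded; files that do not match the pattern, such as best_model_val_acc_0.9123.pt, are left alone.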

Pitfall 3: Using Print Instead of Logging

Symptom: Training crashes, all print output lost, can't debug.

Why It Happens:

  • Print is simpler than logging
  • Works for short scripts
  • Doesn't anticipate crashes

Fix:

# WRONG: Print statements
print(f"Epoch {epoch}: loss={loss}")

# RIGHT: Proper logging
import logging
logging.basicConfig(filename="training.log", level=logging.INFO)
logging.info(f"Epoch {epoch}: loss={loss}")

Pitfall 4: No Git Tracking for Code Changes

Symptom: Can't reproduce result because code changed between experiments.

Why It Happens:

  • Rapid iteration (uncommitted changes)
  • "I'll commit later"
  • Don't realize code version matters

Fix:

import subprocess
import sys

# Log git commit at start of training
git_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD'], encoding='utf-8').strip()
config["git_commit"] = git_hash

# Stricter: refuse to run with uncommitted changes
status = subprocess.check_output(['git', 'status', '--porcelain'], encoding='utf-8').strip()
if status:
    print("ERROR: Uncommitted changes detected!", file=sys.stderr)
    print("Commit your changes before running experiments.", file=sys.stderr)
    sys.exit(1)

Pitfall 5: Not Tracking Random Seeds

Symptom: Same code, same hyperparameters, different results every time.

Why It Happens:

  • Forget to set seed
  • Set seed in one place but not others (PyTorch, NumPy, CUDA)
  • Don't log seed value

Fix:

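import random
import numpy as np
import torch

# Set all seeds
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Optional: trade some speed for deterministic cuDNN kernels
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Use seed from config
set_seed(config["seed"])

# Log seed
tracker.log_metrics({"seed": config["seed"]})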
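Seeding is easy to get subtly wrong, for example when a library keeps its own RNG or a generator is created before `set_seed` runs. A cheap sanity check is to run a short seeded code path twice and compare outputs. A sketch with a hypothetical `check_reproducible` helper (standard library only); with PyTorch you would pass the `set_seed` above as `seed_fn`. Passing on CPU does not guarantee bit-identical GPU runs, since nondeterministic kernels can still differ:

```python
import random


def check_reproducible(fn, seed, seed_fn=random.seed):
    """Seed, run fn, re-seed, run fn again; True if both outputs are equal."""
    seed_fn(seed)
    first = fn()
    seed_fn(seed)
    second = fn()
    return first == second
```

For example, `check_reproducible(lambda: torch.randn(3).tolist(), seed=0, seed_fn=set_seed)` should return `True`; a `False` means some randomness is not covered by the seed.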

Pitfall 6: Tracking Too Much Data (Storage Bloat)

Symptom: 100GB of logs for 50 experiments, can't store more.

Why It Happens:

  • Logging every step (not just epoch)
  • Saving all checkpoints (not just best)
  • Logging high-resolution images

Fix:

import torch.nn.functional as F

# Log at appropriate frequency
if global_step % 100 == 0:  # Every 100 steps, not every step
    tracker.log_metrics({"train/loss": loss})

# Save only best checkpoints
if val_acc > best_val_acc:  # Only when improving
    best_val_acc = val_acc
    tracker.save_checkpoint(model, optimizer, epoch, val_acc)

# Downsample images before logging (img is an N x C x H x W tensor)
img_low_res = F.interpolate(img, size=(64, 64), mode="bilinear", align_corners=False)  # instead of full 224x224
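Logging every 100th step discards the other 99 values, so the logged curve is a noisy sample of the loss. An alternative with the same storage cost is to log the mean over each interval. A minimal sketch; `IntervalAverager` is our own name for illustration, not a library class:

```python
class IntervalAverager:
    """Accumulate per-step values and emit their mean every `every` steps."""

    def __init__(self, every=100):
        self.every = every
        self.total = 0.0
        self.count = 0

    def update(self, value, step):
        """Add one value; return the interval mean when it is time to log, else None."""
        self.total += value
        self.count += 1
        if (step + 1) % self.every == 0:
            mean = self.total / self.count
            self.total, self.count = 0.0, 0
            return mean
        return None
```

In the loop: `mean_loss = averager.update(loss.item(), global_step)`, and call `tracker.log_metrics({"train/loss": mean_loss})` only when `mean_loss is not None`.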

Pitfall 7: No Experiment Naming Convention

Symptom: experiments/test, experiments/test2, experiments/final, experiments/final_final

Why It Happens:

  • No planning for multiple experiments
  • Naming feels unimportant
  • "I'll organize later"

Fix:

from datetime import datetime

# Good naming convention
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
experiment_name = f"{config['model']}_{config['dataset']}_{timestamp}_lr{config['lr']}"
# Example: "resnet18_cifar10_20241030_120000_lr0.01"
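Wrapping the convention in one function keeps every script consistent and guards against characters that break paths (for example a dataset name containing `/`). A sketch, where `make_experiment_name` is hypothetical and the config keys mirror the snippet above; `now` is injectable so the output is testable:

```python
import re
from datetime import datetime


def make_experiment_name(config, now=None):
    """Build '<model>_<dataset>_<timestamp>_lr<lr>' with filesystem-safe characters."""
    now = now or datetime.now()
    timestamp = now.strftime("%Y%m%d_%H%M%S")
    name = f"{config['model']}_{config['dataset']}_{timestamp}_lr{config['lr']}"
    # Replace anything other than letters, digits, '.', '_' or '-' with '-'
    return re.sub(r"[^A-Za-z0-9._-]", "-", name)
```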

Pitfall 8: Not Tracking Evaluation Metrics

Symptom: Saved best model by training loss, but validation loss was actually increasing (overfitting).

Why It Happens:

  • Only tracking training metrics
  • Assuming training loss = model quality
  • Not validating frequently enough

Fix:

# Track both training and validation
tracker.log_metrics({
    "train/loss": train_loss,
    "val/loss": val_loss,  # Don't forget validation!
    "val/accuracy": val_acc,
})

# Save best model by validation metric, not training
if val_acc > best_val_acc:
    best_val_acc = val_acc  # remember the new best, or every later epoch "improves"
    tracker.save_checkpoint(model, optimizer, epoch, val_acc)
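Forgetting to update `best_val_acc`, or comparing a loss with `>`, are easy mistakes once this logic is copied between scripts. A small helper, sketched here and not part of any library, centralizes the comparison for both "higher is better" and "lower is better" metrics; a NaN value never counts as an improvement because every comparison with NaN is false:

```python
import math


class BestMetric:
    """Track the best value of a validation metric ('max' for accuracy, 'min' for loss)."""

    def __init__(self, mode="max"):
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.mode = mode
        self.best = -math.inf if mode == "max" else math.inf

    def update(self, value):
        """Record value and return True if it improves on the best so far."""
        improved = value > self.best if self.mode == "max" else value < self.best
        if improved:
            self.best = value
        return improved
```

Usage: `best_val_loss = BestMetric(mode="min")`, then `if best_val_loss.update(val_loss): tracker.save_checkpoint(...)`.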

Pitfall 9: Local-Only Tracking for Team Projects

Symptom: Team members can't see each other's experiments, duplicate work.

Why It Happens:

  • TensorBoard is local by default
  • Don't realize collaboration tools exist
  • Privacy concerns (addressable with private team projects or a self-hosted MLflow server)

Fix:

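# Use a team-friendly tool
import wandb
wandb.init(project="team-project", entity="your-team")  # entity = your W&B team name (placeholder)

# Or: Share TensorBoard logs
# scp -r runs/ shared-server:/path/
# tensorboard --logdir=/path/runs --host=0.0.0.0   # binds all interfaces; restrict on untrusted networks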

Pitfall 10: No Tracking Until "Important" Experiment

Symptom: First 20 experiments untracked, realize they had valuable insights.

Why It Happens:

  • "Just testing" mentality
  • Tracking feels like overhead
  • Don't realize importance until later

Fix:

# Track from experiment 1
# Even if "just testing", it takes 30 seconds to set up tracking
tracker = ExperimentTracker(config)

# Future you will thank past you
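If even a tracking library feels like too much for a "just testing" run, the bar can be this low. The sketch below uses a hypothetical `MinimalTracker` class (not the `ExperimentTracker` used elsewhere in this skill) that writes the config once and appends one JSON line per `log_metrics` call; it assumes the config is JSON-serializable:

```python
import json
import time
from pathlib import Path


class MinimalTracker:
    """Zero-dependency tracker: config.json once, then one JSON line per log call."""

    def __init__(self, config, root="experiments", name=None):
        name = name or time.strftime("%Y%m%d_%H%M%S")
        self.run_dir = Path(root) / name
        self.run_dir.mkdir(parents=True, exist_ok=True)
        (self.run_dir / "config.json").write_text(json.dumps(config, indent=2))
        self.metrics_path = self.run_dir / "metrics.jsonl"

    def log_metrics(self, metrics, step=None):
        record = {"time": time.time(), "step": step, **metrics}
        # Open, append, close on every call: lines already written survive a crash
        with open(self.metrics_path, "a") as f:
            f.write(json.dumps(record) + "\n")

    def read_metrics(self):
        if not self.metrics_path.exists():
            return []
        with open(self.metrics_path) as f:
            return [json.loads(line) for line in f]
```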

Rationalization vs Reality Table

| User Rationalization | Reality | Recommendation |
|---|---|---|
| "I'll remember what I tried" | You won't (memory fails within hours) | Track from day 1, always |
| "Print statements are enough" | Lost on crash or terminal close | Use proper logging to a file |
| "Only track final metrics" | Can't debug without intermediate data | Track every epoch at minimum |
| "Just save best model" | Earlier checkpoints are needed for analysis | Version all important checkpoints |
| "Tracking adds too much overhead" | Typically <1% overhead for scalars | Log metrics, not raw data |
| "I only need the model file" | Hyperparameters are needed to understand it | Save config + model + metrics |
| "TensorBoard is too complex" | About 2 lines of code to set up | Start simple, expand later |
| "I'll organize experiments later" | It never happens, and chaos ensues | Use a naming convention from the start |
| "Git commits slow me down" | Uncommitted code = irreproducible | Commit before experiments |
| "Cloud tracking costs money" | Free tiers exist; TensorBoard and self-hosted MLflow cost nothing | Check current W&B free-tier limits before committing |
| "I don't need reproducibility" | Your future self will | Track environment + seed + git |
| "Tracking is for production, not research" | Research needs it more (more exploration) | More experiments = more tracking |

Red Flags (Likely to Fail)

  1. "I'll track it later"

    • Reality: Later = never, and the best results tend to be the untracked ones
    • Action: Track from experiment 1
  2. "Just using print statements"

    • Reality: Lost on crash/close; can't analyze later
    • Action: Use logging framework or tracking tool
  3. "Only tracking the final metric"

    • Reality: Can't debug convergence issues; no training curves
    • Action: Track every epoch at minimum
  4. "Saving to best_model.pt (overwriting)"

    • Reality: Can't recover earlier checkpoints; evaluation bugs = disaster
    • Action: Version checkpoints with epoch/metric
  5. "Don't need to track hyperparameters"

    • Reality: Metrics without config are meaningless
    • Action: Log config alongside metrics
  6. "Not tracking git commit"

    • Reality: Code changes = irreproducible
    • Action: Log git hash, check for uncommitted changes
  7. "Random seed doesn't matter"

    • Reality: Seed-to-seed variance can exceed the improvement you are trying to measure
    • Action: Set and log all seeds
  8. "TensorBoard/W&B is overkill for me"

    • Reality: Setup takes 2 minutes, saves hours later
    • Action: Use simplest tool (TensorBoard), expand if needed
  9. "I'm just testing, don't need tracking"

    • Reality: Some of the best results come from "tests"
    • Action: Track everything, including tests
  10. "Team doesn't need to see my experiments"

    • Reality: Collaboration requires transparency
    • Action: Use shared tracking (W&B, MLflow server)

When This Skill Applies

Strong Signals (definitely use):

  • Starting a new ML project (even "quick prototype")
  • User asks "should I track this?"
  • User lost their best result and can't reproduce it
  • Multiple experiments running (need comparison)
  • Team collaboration (need to share results)
  • User asks about TensorBoard, W&B, or MLflow
  • Training crashes and user needs debugging data

Weak Signals (maybe use):

  • User has tracking but it's incomplete
  • Asking about reproducibility
  • Discussing hyperparameter tuning (needs tracking)
  • Long-running training (overnight, multi-day)

Not Applicable:

  • Pure inference (no training)
  • Single experiment already tracked
  • Discussing model architecture only
  • Data preprocessing questions (pre-training)

Success Criteria

You've successfully applied this skill when:

  1. Complete Tracking: Hyperparameters + metrics + artifacts + git + environment all logged
  2. Reproducibility: Someone else (or future you) can reproduce the result from tracked info
  3. Tool Choice: Selected appropriate tool (TensorBoard, W&B, MLflow) for use case
  4. Organization: Experiments have clear naming, tagging, grouping
  5. Comparison: Can compare experiments side-by-side, analyze hyperparameter impact
  6. Collaboration: Team can see and discuss results (if team project)
  7. Minimal Overhead: Tracking adds <5% runtime overhead
  8. Persistence: Logs survive crashes, terminal closes, reboots
  9. Historical Analysis: Can go back to any experiment and understand what was done
  10. Best Practices: Git commits before experiments, seeds set, evaluation bugs recoverable (versioned checkpoints)

Final Test: Can you reproduce the best result from 6 months ago using only the tracked information?

If YES: Excellent tracking! If NO: Gaps remain.
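One way to make the final test mechanical is to check each run record for the fields listed under Complete Tracking. A sketch using only the standard library; the record layout (`config`, `seed`, `git_commit`, `environment`) and both helper names are assumptions, so adapt the keys to whatever your tracker actually stores:

```python
import platform
import sys
from importlib import metadata

REPRO_FIELDS = ("config", "seed", "git_commit", "environment")  # assumed record layout


def capture_environment(packages=("numpy", "torch")):
    """Record interpreter, OS, and installed versions of selected packages."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed
    return {
        "python": sys.version.split()[0],
        "platform": platform.platform(),
        "packages": versions,
    }


def missing_fields(run_record):
    """List the reproducibility fields a run record lacks (or has set to None)."""
    return [key for key in REPRO_FIELDS if run_record.get(key) is None]
```

An empty list from `missing_fields` does not prove a run is reproducible, but a non-empty one shows exactly which gap remains.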


Advanced Tracking Patterns

1. Multi-Run Experiments (Hyperparameter Sweeps)

When running many experiments systematically:

# W&B Sweeps
import wandb
sweep_config = {
    "method": "random",
    "metric": {"name": "val_accuracy", "goal": "maximize"},
    "parameters": {
        "learning_rate": {"values": [0.001, 0.01, 0.1]},
        "batch_size": {"values": [32, 64, 128]},
        "optimizer": {"values": ["adam", "sgd"]},
    },
}

sweep_id = wandb.sweep(sweep_config, project="cifar10-sweep")

def train():
    run = wandb.init()
    config = wandb.config

    model = create_model(config)
    # ... training code ...
    wandb.log({"val_accuracy": val_acc})

wandb.agent(sweep_id, train, count=10)

# MLflow with Optuna
import optuna
import mlflow

def objective(trial):
    with mlflow.start_run(nested=True):
        lr = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)  # suggest_loguniform is deprecated
        batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])

        mlflow.log_params({"learning_rate": lr, "batch_size": batch_size})

        val_acc = train_and_evaluate(lr, batch_size)
        mlflow.log_metric("val_accuracy", val_acc)

        return val_acc

with mlflow.start_run():
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=20)

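    # Log each best hyperparameter as its own param (a dict would be stored as one string)
    mlflow.log_params({f"best_{k}": v for k, v in study.best_params.items()})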
    mlflow.log_metric("best_accuracy", study.best_value)
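It helps to know how large the search space is before choosing a budget. The W&B sweep above has 3 × 3 × 2 = 18 combinations, so `count=10` with `"method": "random"` samples at most 10 runs and may repeat a combination; setting `"method"` to `"grid"` would enumerate all 18 instead. A quick plain-Python check of that arithmetic:

```python
import itertools

sweep_parameters = {
    "learning_rate": [0.001, 0.01, 0.1],
    "batch_size": [32, 64, 128],
    "optimizer": ["adam", "sgd"],
}

# Every combination a grid search would have to run
grid = [
    dict(zip(sweep_parameters, values))
    for values in itertools.product(*sweep_parameters.values())
]
```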

2. Distributed Training Tracking

When training on multiple GPUs or machines:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_distributed_tracking(rank, config):
    """Set up tracking for distributed training."""

    # Only rank 0 logs to avoid duplicates
    if rank == 0:
        return ExperimentTracker(config)
    return None

def train_distributed(rank, world_size, config):
    # Expects MASTER_ADDR/MASTER_PORT in the environment (torchrun sets them)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)  # single-node assumption: rank == GPU index

    tracker = setup_distributed_tracking(rank, config)

    model = DDP(create_model(config).cuda(rank), device_ids=[rank])
    # ... build train_loader (with a DistributedSampler) and optimizer here ...

    for epoch in range(config["num_epochs"]):
        train_loss = train_epoch(model, train_loader, optimizer)

        # Average the per-rank loss across all ranks
        train_loss_tensor = torch.tensor(train_loss, device=f"cuda:{rank}")
        dist.all_reduce(train_loss_tensor, op=dist.ReduceOp.SUM)
        avg_train_loss = train_loss_tensor.item() / world_size

        # Only rank 0 holds a tracker
        if tracker is not None:
            tracker.log_metrics({
                "train/loss": avg_train_loss,
                "epoch": epoch,
            })

    if tracker is not None:
        tracker.finish()

    dist.destroy_process_group()
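Summing per-rank mean losses and dividing by `world_size` is exact only when every rank processes the same number of samples. If shards differ (for example, an uneven final batch), reducing the loss sum and the sample count separately avoids the bias. The arithmetic, sketched in plain Python with a hypothetical helper:

```python
def combine_rank_stats(per_rank):
    """Mean loss over all samples from per-rank (loss_sum, num_samples) pairs.

    With torch.distributed, each rank would put [loss_sum, num_samples] into one
    tensor and call dist.all_reduce(t, op=dist.ReduceOp.SUM); afterwards every
    rank holds the same totals that this function computes.
    """
    total_loss = sum(loss_sum for loss_sum, _ in per_rank)
    total_samples = sum(num for _, num in per_rank)
    return total_loss / total_samples


# Rank 0: mean loss 1.0 over 100 samples; rank 1: mean loss 2.0 over 50 samples
stats = [(1.0 * 100, 100), (2.0 * 50, 50)]
weighted = combine_rank_stats(stats)  # 200 / 150, about 1.333
naive = (1.0 + 2.0) / 2               # 1.5, the "average of averages"
```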

3. Experiment Resumption

Tracking setup for resumable experiments:

import os

import torch

class ResumableExperimentTracker(ExperimentTracker):
    """Experiment tracker with resume support."""

    def __init__(self, config, checkpoint_path=None):
        super().__init__(config)

        self.checkpoint_path = checkpoint_path
        # Fresh-run defaults (keeps values if the base class already set them)
        self.global_step = getattr(self, "global_step", 0)
        self.best_metric = getattr(self, "best_metric", float('-inf'))
        self.start_epoch = 0
        self.resume_state = None  # full checkpoint dict when resuming

        if checkpoint_path and os.path.exists(checkpoint_path):
            self.resume_from_checkpoint()

    def resume_from_checkpoint(self):
        """Resume tracking from saved checkpoint."""
        checkpoint = torch.load(self.checkpoint_path, map_location="cpu")

        self.global_step = checkpoint.get("global_step", 0)
        self.best_metric = checkpoint.get("best_metric", float('-inf'))
        self.start_epoch = checkpoint.get("epoch", -1) + 1
        self.resume_state = checkpoint

        self.logger.info(f"Resumed from checkpoint: step={self.global_step}")

    def save_checkpoint(self, model, optimizer, epoch, metric_value, metric_name="val_accuracy"):
        """Save checkpoint with tracker state."""
        # Update best first so the saved "best_metric" is current
        is_best = metric_value > self.best_metric
        if is_best:
            self.best_metric = metric_value

        checkpoint = {
            "epoch": epoch,
            "global_step": self.global_step,
            "best_metric": self.best_metric,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            metric_name: metric_value,
            "config": self.config,
        }

        checkpoint_dir = self.exp_dir / "checkpoints"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        checkpoint_path = checkpoint_dir / "latest.pt"
        torch.save(checkpoint, checkpoint_path)

        # Also save best
        if is_best:
            torch.save(checkpoint, checkpoint_dir / "best.pt")

        return checkpoint_path

# Usage: point checkpoint_path at the latest.pt written by the interrupted run
tracker = ResumableExperimentTracker(config, checkpoint_path="checkpoints/latest.pt")

# Restore model and optimizer state when resuming
if tracker.resume_state is not None:
    model.load_state_dict(tracker.resume_state["model_state_dict"])
    optimizer.load_state_dict(tracker.resume_state["optimizer_state_dict"])

# Training continues from where it left off
for epoch in range(tracker.start_epoch, num_epochs):
    # ... training ...
    tracker.save_checkpoint(model, optimizer, epoch, val_acc)
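Because `latest.pt` is overwritten on every save, a crash or preemption during the write can leave a truncated file exactly when you need to resume. Writing to a temporary file in the same directory and renaming it over the target is a common remedy. A sketch with a hypothetical `atomic_save` helper; it defaults to `pickle` so it runs without PyTorch, and `torch.save` could be passed as `save_fn` because it also takes `(obj, file)`:

```python
import os
import pickle
import tempfile


def atomic_save(obj, path, save_fn=None):
    """Write obj to a temp file beside path, then rename it over path.

    A failure mid-write leaves the previous file intact instead of a truncated one.
    """
    save_fn = save_fn or pickle.dump
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            save_fn(obj, f)
            f.flush()
            os.fsync(f.fileno())  # make sure bytes reach the disk before renaming
        os.replace(tmp_path, path)  # atomic on POSIX within one filesystem
    except BaseException:
        os.unlink(tmp_path)
        raise
```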

4. Experiment Comparison and Analysis

Programmatic experiment analysis:

import pandas as pd

def analyze_experiments(project_name):
    """Analyze all experiments in a project."""

    # W&B
    import wandb
    api = wandb.Api()
    runs = api.runs(project_name)

    # Extract data
    data = []
    for run in runs:
        data.append({
            "name": run.name,
            "learning_rate": run.config.get("learning_rate"),
            "batch_size": run.config.get("batch_size"),
            "val_accuracy": run.summary.get("val_accuracy"),
            "train_time": run.summary.get("_runtime"),
        })

    df = pd.DataFrame(data).dropna(subset=["val_accuracy"])  # skip runs that never logged it

    # Analysis
    print("Top 5 experiments by accuracy:")
    print(df.nlargest(5, "val_accuracy"))

    # Hyperparameter impact
    print("\nAverage accuracy by learning rate:")
    print(df.groupby("learning_rate")["val_accuracy"].mean())

    # Visualization
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # Learning rate vs accuracy
    axes[0].scatter(df["learning_rate"], df["val_accuracy"])
    axes[0].set_xlabel("Learning Rate")
    axes[0].set_ylabel("Validation Accuracy")
    axes[0].set_xscale("log")

    # Batch size vs accuracy
    axes[1].scatter(df["batch_size"], df["val_accuracy"])
    axes[1].set_xlabel("Batch Size")
    axes[1].set_ylabel("Validation Accuracy")

    plt.tight_layout()
    plt.savefig("experiment_analysis.png")

    return df

# Run analysis
df = analyze_experiments("team/cifar10-sota")

5. Data Versioning Integration

Tracking data versions alongside experiments:

import hashlib
from pathlib import Path

import yaml

def hash_dataset(dataset_path, chunk_size=1 << 20):
    """Compute a SHA-256 over file names and contents for dataset versioning."""
    root = Path(dataset_path)
    hasher = hashlib.sha256()

    for file in sorted(root.rglob("*")):
        if file.is_file():
            # Include the relative path so renamed or moved files change the hash
            hasher.update(file.relative_to(root).as_posix().encode("utf-8") + b"\0")
            with open(file, "rb") as f:
                # Read in chunks so large files don't have to fit in memory
                for chunk in iter(lambda: f.read(chunk_size), b""):
                    hasher.update(chunk)

    return hasher.hexdigest()

# Track data version
data_version = hash_dataset("data/cifar10")
config["data_version"] = data_version

tracker = ExperimentTracker(config)

# Or use DVC instead of hash_dataset. In the shell:
#   dvc init
#   dvc add data/cifar10
#   git add data/cifar10.dvc data/.gitignore
#
# Then log the DVC hash in the experiment:
with open("data/cifar10.dvc") as f:
    dvc_config = yaml.safe_load(f)
config["data_hash"] = dvc_config["outs"][0]["md5"]

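A single hash tells you *that* the data changed but not *what* changed. Saving a per-file manifest with each run (for example as a JSON artifact) lets you diff two runs' data later. A sketch using only the standard library; `build_manifest` and `diff_manifests` are illustrative names:

```python
import hashlib
from pathlib import Path


def build_manifest(dataset_path, chunk_size=1 << 20):
    """Map each file's relative path to its SHA-256 digest."""
    root = Path(dataset_path)
    manifest = {}
    for file in sorted(root.rglob("*")):
        if file.is_file():
            hasher = hashlib.sha256()
            with open(file, "rb") as f:
                for chunk in iter(lambda: f.read(chunk_size), b""):
                    hasher.update(chunk)
            manifest[file.relative_to(root).as_posix()] = hasher.hexdigest()
    return manifest


def diff_manifests(old, new):
    """Return (added, removed, changed) file lists between two manifests."""
    added = sorted(new.keys() - old.keys())
    removed = sorted(old.keys() - new.keys())
    changed = sorted(k for k in old.keys() & new.keys() if old[k] != new[k])
    return added, removed, changed
```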
6. Artifact Management Best Practices

Organizing and managing experiment artifacts:

from pathlib import Path

import torch

class ArtifactManager:
    """Manages experiment artifacts (models, plots, logs)."""

    def __init__(self, experiment_dir):
        self.exp_dir = Path(experiment_dir)

        # Create subdirectories
        self.checkpoints_dir = self.exp_dir / "checkpoints"
        self.figures_dir = self.exp_dir / "figures"
        self.logs_dir = self.exp_dir / "logs"

        for d in [self.checkpoints_dir, self.figures_dir, self.logs_dir]:
            d.mkdir(parents=True, exist_ok=True)

    def save_checkpoint(self, checkpoint, name):
        """Save checkpoint with automatic cleanup."""
        path = self.checkpoints_dir / f"{name}.pt"
        torch.save(checkpoint, path)

        # Keep only last N checkpoints (except best)
        self._cleanup_checkpoints(keep_n=5)

        return path

    def _cleanup_checkpoints(self, keep_n=5):
        """Keep only recent checkpoints to save space."""
        checkpoints = sorted(
            self.checkpoints_dir.glob("checkpoint_epoch_*.pt"),
            key=lambda p: p.stat().st_mtime,
            reverse=True,
        )

        # Delete old checkpoints (keep best + last N)
        for ckpt in checkpoints[keep_n:]:
            if "best" not in ckpt.name:
                ckpt.unlink()

    def save_figure(self, fig, name, step=None):
        """Save matplotlib figure with metadata."""
        if step is not None:
            filename = f"{name}_step_{step}.png"
        else:
            filename = f"{name}.png"

        path = self.figures_dir / filename
        fig.savefig(path, dpi=150, bbox_inches="tight")

        return path

    def get_artifact_summary(self):
        """Get summary of stored artifacts."""
        summary = {
            "num_checkpoints": len(list(self.checkpoints_dir.glob("*.pt"))),
            "num_figures": len(list(self.figures_dir.glob("*.png"))),
            "total_size_mb": sum(
                f.stat().st_size for f in self.exp_dir.rglob("*") if f.is_file()
            ) / (1024 * 1024),
        }
        return summary

# Usage
artifacts = ArtifactManager(experiment_dir)
artifacts.save_checkpoint(checkpoint, f"checkpoint_epoch_{epoch}")
artifacts.save_figure(fig, "training_curve")
print(artifacts.get_artifact_summary())
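`_cleanup_checkpoints` ranks files by modification time alone. An alternative retention policy keeps the newest N checkpoints *and* the top K by validation metric, using records the training loop already has. A pure-function sketch; the `records` layout (`path`, `epoch`, `metric`) and the function name are assumptions:

```python
def checkpoints_to_delete(records, keep_last=3, keep_best=2, mode="max"):
    """Return paths to delete, keeping the newest `keep_last` and best `keep_best`.

    records: list of dicts with "path", "epoch" and "metric" keys.
    """
    by_recency = sorted(records, key=lambda r: r["epoch"], reverse=True)
    by_quality = sorted(records, key=lambda r: r["metric"], reverse=(mode == "max"))
    keep = {r["path"] for r in by_recency[:keep_last]}
    keep |= {r["path"] for r in by_quality[:keep_best]}
    return [r["path"] for r in records if r["path"] not in keep]
```

Keeping the selection separate from the deletion makes the policy easy to test and to dry-run before anything is unlinked.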

7. Real-Time Monitoring and Alerts

Set up alerts for experiment issues:

# W&B Alerts
import math
import wandb

wandb.init(project="cifar10")
milestone_alerted = False  # send the milestone alert only once

for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_loader, optimizer)
    val_acc = evaluate(model, val_loader)  # placeholder for your validation routine

    wandb.log({"train/loss": train_loss, "val/accuracy": val_acc, "epoch": epoch})

    # Alert on divergence (NaN or inf, or an implausibly large loss)
    if not math.isfinite(train_loss) or train_loss > 10:
        wandb.alert(
            title="Training Diverged",
            text=f"Loss is {train_loss} at epoch {epoch}",
            level=wandb.AlertLevel.ERROR,
        )
        break

    # Alert on milestone
    if val_acc > 0.90 and not milestone_alerted:
        wandb.alert(
            title="90% Accuracy Reached!",
            text=f"Validation accuracy: {val_acc:.2%}",
            level=wandb.AlertLevel.INFO,
        )
        milestone_alerted = True

# Slack integration (incoming webhook)
def send_slack_alert(message, webhook_url):
    """Send alert to Slack."""
    import requests
    response = requests.post(webhook_url, json={"text": message}, timeout=10)
    response.raise_for_status()  # surface webhook errors instead of failing silently

# Email alerts
def send_email_alert(subject, body, to_email, from_email="training@localhost"):
    """Send email alert (assumes an SMTP server listening on localhost)."""
    import smtplib
    from email.message import EmailMessage

    msg = EmailMessage()
    msg["Subject"] = subject
    msg["From"] = from_email
    msg["To"] = to_email
    msg.set_content(body)

    # Send via SMTP
    with smtplib.SMTP("localhost") as s:
        s.send_message(msg)
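The fixed `train_loss > 10` threshold depends on the loss scale of the task. A scale-free alternative compares each loss against the best seen so far. A sketch with a hypothetical `DivergenceDetector`; the 5× factor is an arbitrary default to tune, and the returned reason string could be passed as the `text` of `wandb.alert` or to `send_slack_alert`:

```python
import math


class DivergenceDetector:
    """Flag non-finite losses or losses far above the best seen so far."""

    def __init__(self, factor=5.0, warmup=1):
        self.factor = factor  # arbitrary default: alert at 5x the best loss
        self.warmup = warmup  # finite values to observe before ratio checks apply
        self.best = math.inf
        self.seen = 0

    def check(self, loss):
        """Return a reason string if the loss looks diverged, else None."""
        if not math.isfinite(loss):
            return f"non-finite loss: {loss}"
        self.seen += 1
        if self.seen > self.warmup and loss > self.factor * self.best:
            return f"loss {loss:.4g} exceeds {self.factor}x best ({self.best:.4g})"
        self.best = min(self.best, loss)
        return None
```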

Common Integration Patterns

Pattern 1: Training Script with Complete Tracking

#!/usr/bin/env python3
"""
Complete training script with experiment tracking.
"""

import argparse
import yaml
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from experiment_tracker import ExperimentTracker
from models import create_model
from data import load_dataset

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Config file")
    parser.add_argument("--resume", type=str, help="Resume from checkpoint")
    parser.add_argument("--name", type=str, help="Experiment name")
    return parser.parse_args()

def main():
    args = parse_args()

    # Load config
    with open(args.config) as f:
        config = yaml.safe_load(f)

    # Initialize tracking
    tracker = ExperimentTracker(
        config=config,
        experiment_name=args.name,
        use_wandb=True,
        use_tensorboard=True,
    )

    # Setup training
    model = create_model(config)
    optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])
    criterion = nn.CrossEntropyLoss()

    train_loader, val_loader = load_dataset(config)

    # Resume if checkpoint provided
    start_epoch = 0
    if args.resume:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        tracker.logger.info(f"Resumed from epoch {start_epoch}")

    # Training loop
    best_val_acc = 0.0
    for epoch in range(start_epoch, config["num_epochs"]):
        # Train
        model.train()
        train_losses = []
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())

            # Log every N batches
            if batch_idx % config.get("log_interval", 100) == 0:
                tracker.log_metrics({
                    "train/loss": loss.item(),
                    "train/lr": optimizer.param_groups[0]['lr'],
                }, step=epoch * len(train_loader) + batch_idx)

        # Validate
        model.eval()
        val_losses = []
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in val_loader:
                output = model(data)
                loss = criterion(output, target)
                val_losses.append(loss.item())

                pred = output.argmax(dim=1)
                correct += (pred == target).sum().item()
                total += target.size(0)

        train_loss = sum(train_losses) / len(train_losses)
        val_loss = sum(val_losses) / len(val_losses)
        val_acc = correct / total

        # Log epoch metrics
        tracker.log_metrics({
            "train/loss_epoch": train_loss,
            "val/loss": val_loss,
            "val/accuracy": val_acc,
            "epoch": epoch,
        })

        # Save checkpoint
        tracker.save_checkpoint(model, optimizer, epoch, val_acc)

        # Update best
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            tracker.logger.info(f"New best accuracy: {val_acc:.4f}")

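        # Sanity abort: stop runs that are clearly not learning
        # (a simple threshold check, not patience-based early stopping)
        if epoch > 50 and val_acc < 0.5:
            tracker.logger.warning("Validation accuracy still below 0.5 after 50 epochs, stopping")
            break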

    # Log final results
    tracker.log_metrics({"best_val_accuracy": best_val_acc})
    tracker.logger.info(f"Training completed. Best accuracy: {best_val_acc:.4f}")

    # Cleanup
    tracker.finish()

if __name__ == "__main__":
    main()

Usage:

# Train new model
python train.py --config configs/resnet18.yaml --name resnet18_baseline

# Resume training
python train.py --config configs/resnet18.yaml --resume experiments/resnet18_baseline/checkpoints/latest.pt
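Since the whole run is driven by the YAML file, a check right after `yaml.safe_load` catches a bad config before any compute is spent. A sketch with a hypothetical `validate_config`; the required keys are the ones `train.py` reads directly (`learning_rate`, `num_epochs`, optional `log_interval`), and a real config will likely need more for `create_model` and `load_dataset`. One gotcha this catches: PyYAML loads `learning_rate: 1e-3` as the string `'1e-3'`, while `1.0e-3` loads as a float.

```python
# Keys that train.py reads directly; create_model/load_dataset likely need more.
REQUIRED_CONFIG_KEYS = {
    "learning_rate": (int, float),
    "num_epochs": (int,),
}
CONFIG_DEFAULTS = {"log_interval": 100}


def validate_config(config):
    """Return config with defaults filled in, or raise ValueError listing every problem."""
    errors = []
    for key, allowed in REQUIRED_CONFIG_KEYS.items():
        if key not in config:
            errors.append(f"missing key: {key}")
        # bool is a subclass of int, so reject it explicitly
        elif isinstance(config[key], bool) or not isinstance(config[key], allowed):
            names = " or ".join(t.__name__ for t in allowed)
            errors.append(f"{key} should be {names}, got {type(config[key]).__name__}")
    if errors:
        raise ValueError("invalid config: " + "; ".join(errors))
    return {**CONFIG_DEFAULTS, **config}
```

In `main()`, this would become `config = validate_config(yaml.safe_load(f))`.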


Remember: Experiment tracking is insurance. It costs 1% overhead but saves 100% when disaster strikes. Track from day 1, track everything, and your future self will thank you.