Claude Code Plugins

Community-maintained marketplace


Expert guidance for JAX (Just After eXecution) - high-performance numerical computing with automatic differentiation, JIT compilation, vectorization, and GPU/TPU acceleration; includes transformations (grad, jit, vmap, pmap), sharp bits, gotchas, and differences from NumPy

Install Skill

1. Download skill

2. Enable skills in Claude: open claude.ai/settings/capabilities and find the "Skills" section

3. Upload to Claude: click "Upload skill" and select the downloaded ZIP file

Note: Please verify the skill by reading through its instructions before using it.

SKILL.md

name: python-jax
description: Expert guidance for JAX (Just After eXecution) - high-performance numerical computing with automatic differentiation, JIT compilation, vectorization, and GPU/TPU acceleration; includes transformations (grad, jit, vmap, pmap), sharp bits, gotchas, and differences from NumPy
allowed-tools: *

JAX - High-Performance Numerical Computing

Overview

JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. It combines a familiar NumPy-style API with powerful function transformations for automatic differentiation, compilation, vectorization, and parallelization.

Core value: Write NumPy-like Python code and automatically get gradients, GPU/TPU acceleration, vectorization, and parallelization through composable function transformations—without changing your mathematical notation.

When to Use JAX

Use JAX when:

  • Need automatic differentiation for optimization or machine learning
  • Want GPU/TPU acceleration with minimal code changes
  • Require high-performance numerical computing
  • Building custom gradient-based algorithms
  • Need to vectorize or parallelize functions automatically
  • Working on research requiring flexible differentiation
  • Want functional programming approach to numerical code

Don't use when:

  • Simple NumPy operations without performance needs (overhead not justified)
  • Heavy reliance on in-place mutations (JAX arrays are immutable)
  • Imperative, stateful code with side effects
  • Need control flow depending on runtime data values (limited support)
  • Working with libraries that require NumPy arrays (compatibility issues)

Core Transformations

JAX provides four fundamental, composable transformations:

1. jax.grad - Automatic Differentiation

Compute gradients automatically using reverse-mode autodiff.

import jax
import jax.numpy as jnp

def loss(x):
    return jnp.sum(x**2)

# Get gradient function
grad_loss = jax.grad(loss)

x = jnp.array([1.0, 2.0, 3.0])
gradient = grad_loss(x)
print(gradient)  # [2. 4. 6.]

Key features:

  • Returns a function that computes gradients
  • Composes for higher-order derivatives
  • Works with complex nested structures (pytrees)

2. jax.jit - Just-In-Time Compilation

Compile functions to XLA for dramatic speedups.

import jax
import jax.numpy as jnp

@jax.jit
def fast_function(x):
    return jnp.sum(x**2 + 3*x + 1)

# First call: compilation + execution (slower)
result = fast_function(jnp.array([1.0, 2.0, 3.0]))

# Subsequent calls: cached compiled version (much faster)
result = fast_function(jnp.array([4.0, 5.0, 6.0]))

Performance:

  • Speedups of 10-100x are common for numeric-heavy functions, depending on the workload
  • First call is slower (compilation overhead)
  • The compiled version is cached and reused for calls with the same shapes/dtypes (see the timing sketch below)
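
To see the compile-then-cache behavior, time the first and second calls. This is a minimal sketch; block_until_ready() is needed because JAX dispatches work asynchronously, so timing without it would measure only dispatch.

import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sum(x**2 + 3*x + 1)

x = jnp.ones((1000, 1000))

t0 = time.perf_counter()
f(x).block_until_ready()  # First call: traces and compiles
t1 = time.perf_counter()
f(x).block_until_ready()  # Second call: runs the cached executable
t2 = time.perf_counter()

print(f"first call:  {t1 - t0:.4f} s")
print(f"second call: {t2 - t1:.4f} s")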

3. jax.vmap - Automatic Vectorization

Vectorize functions across batch dimensions automatically.

import jax
import jax.numpy as jnp

def process_single(x):
    """Process single example"""
    return jnp.sum(x**2)

# Vectorize to process batch
process_batch = jax.vmap(process_single)

# Works on batch automatically
batch = jnp.array([[1, 2], [3, 4], [5, 6]])
results = process_batch(batch)
print(results)  # [5 25 61]

Benefits:

  • Eliminates manual loop writing
  • Often faster than explicit loops
  • Cleaner, more declarative code

4. jax.pmap - Parallel Map

Parallelize across multiple devices (GPUs/TPUs).

import jax
import jax.numpy as jnp

@jax.pmap
def parallel_fn(x):
    return x**2

# Runs on all available devices
devices = jax.devices()
x = jnp.arange(len(devices))
results = parallel_fn(x)

Use for:

  • Multi-GPU/TPU computation
  • Data parallelism in training (see the psum sketch after this list)
  • Large-scale simulations
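
As a sketch of the data-parallelism pattern mentioned above: each device reduces its own shard, and jax.lax.psum combines the per-device results over a named axis. The axis name 'devices' is arbitrary; pmap requires the leading axis of the input to equal the number of devices.

from functools import partial
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

@partial(jax.pmap, axis_name='devices')
def device_sum_of_means(x):
    # psum is a collective over the named axis: every device receives the global sum
    return jax.lax.psum(jnp.mean(x), axis_name='devices')

# Leading axis must equal the number of devices
data = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4)
print(device_sum_of_means(data))  # One identical value per device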

Transformation Composition

JAX transformations compose seamlessly:

# Gradient of JIT-compiled function
@jax.jit
def fast_loss(x):
    return jnp.sum(x**2)

grad_fast_loss = jax.grad(fast_loss)

# JIT of gradient (equally valid)
fast_grad_loss = jax.jit(jax.grad(fast_loss))

# Vectorized gradient
batch_grad_loss = jax.vmap(jax.grad(fast_loss))

# JIT + vmap + grad
fast_batch_grad = jax.jit(jax.vmap(jax.grad(fast_loss)))

Order matters for performance but not correctness.
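
For example, fast_batch_grad defined above computes per-example gradients for a whole batch in one call:

batch = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

per_example_grads = fast_batch_grad(batch)
print(per_example_grads.shape)  # (3, 2): one gradient per example
print(per_example_grads)        # [[ 2.  4.] [ 6.  8.] [10. 12.]]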

JAX vs NumPy - Critical Differences

1. Immutability

NumPy (mutable):

import numpy as np
x = np.array([1, 2, 3])
x[0] = 10  # Modifies x in-place
print(x)  # [10 2 3]

JAX (immutable):

import jax.numpy as jnp
x = jnp.array([1, 2, 3])
# x[0] = 10  # ERROR: JAX arrays are immutable

# Use functional update instead
x = x.at[0].set(10)  # Returns NEW array
print(x)  # [10 2 3]

Functional updates:

# Set value
x = x.at[0].set(10)

# Add to value
x = x.at[0].add(5)

# Multiply value
x = x.at[0].mul(2)

# Min/max
x = x.at[0].min(5)
x = x.at[0].max(5)

# Multi-index
x = x.at[0, 1].set(10)
x = x.at[[0, 2]].set(10)
x = x.at[0:3].set(10)

2. Random Number Generation

NumPy (global state):

import numpy as np
np.random.seed(42)
x = np.random.normal(size=3)
y = np.random.normal(size=3)  # Different from x

JAX (explicit keys):

import jax
key = jax.random.PRNGKey(42)

# Split key for independent randomness
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, shape=(3,))

key, subkey = jax.random.split(key)
y = jax.random.normal(subkey, shape=(3,))

# Parallel random numbers
keys = jax.random.split(key, num=10)
samples = jax.vmap(lambda k: jax.random.normal(k, shape=(3,)))(keys)

Key management pattern:

# Create initial key
key = jax.random.PRNGKey(0)

# Always split before using
key, subkey1, subkey2 = jax.random.split(key, 3)

# Use subkeys for random operations
x = jax.random.normal(subkey1, shape=(10,))
y = jax.random.uniform(subkey2, shape=(10,))

# Never reuse keys!
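
Reusing a key is not an error - it silently produces identical (fully correlated) samples, which is why you split before every use. A short check:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

a = jax.random.normal(key, shape=(3,))
b = jax.random.normal(key, shape=(3,))
print(jnp.allclose(a, b))  # True - same key, identical samples

key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, shape=(3,))
print(jnp.allclose(a, c))  # False - fresh subkey, independent samples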

3. 64-bit Precision

JAX defaults to 32-bit for performance.

import jax.numpy as jnp

x = jnp.array([1.0])
print(x.dtype)  # float32 (default)

# Enable 64-bit globally
from jax import config
config.update("jax_enable_x64", True)

x = jnp.array([1.0])
print(x.dtype)  # float64

# Or use environment variable
# JAX_ENABLE_X64=1 python script.py

4. Out-of-Bounds Indexing

NumPy (raises error):

import numpy as np
x = np.array([1, 2, 3])
# x[10]  # IndexError

JAX (clamps silently):

import jax.numpy as jnp
x = jnp.array([1, 2, 3])
print(x[10])  # 3 (returns last element, no error!)

# Updates also silently ignored
x = x.at[10].set(99)  # No error, no effect

⚠️ Warning: Don't rely on this behavior - it exists because raising errors from accelerator code is impractical, and it may not be what you expect. Validate indices, or request an explicit out-of-bounds mode (see the sketch below).
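
If you need explicitly defined behavior at the boundaries, the .at[] interface accepts an out-of-bounds mode; a small sketch using the documented mode and fill_value options:

import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# Read with an explicit out-of-bounds policy instead of the default clamp
print(x.at[10].get(mode='fill', fill_value=-1))  # -1

# Drop out-of-bounds writes explicitly (matches the default for .set)
y = x.at[10].set(99, mode='drop')
print(y)  # [1 2 3]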

5. Non-Array Inputs

NumPy (accepts lists):

import numpy as np
result = np.sum([1, 2, 3])  # Works

JAX (requires arrays):

import jax.numpy as jnp
# result = jnp.sum([1, 2, 3])  # ERROR

# Convert explicitly
result = jnp.sum(jnp.array([1, 2, 3]))  # Works

Reason: Prevents performance degradation during tracing.

Sharp Bits & Gotchas

1. Pure Functions Required

❌ Bad: Side effects

counter = 0

@jax.jit
def impure_fn(x):
    global counter
    counter += 1  # Side effect - UNRELIABLE
    return x**2

# First call: counter=1 (traced)
result = impure_fn(2.0)
# Subsequent calls: counter STILL 1 (uses cached trace)
result = impure_fn(3.0)

✅ Good: Pure function

@jax.jit
def pure_fn(x, counter):
    return x**2, counter + 1

result, counter = pure_fn(2.0, 0)
result, counter = pure_fn(3.0, counter)

2. Control Flow Limitations

❌ Bad: Value-dependent control flow

@jax.jit
def conditional(x):
    if x > 0:  # ERROR: x is tracer, not concrete value
        return x**2
    else:
        return x**3

✅ Good: Use jax.lax.cond

@jax.jit
def conditional(x):
    return jax.lax.cond(
        x > 0,
        lambda x: x**2,  # True branch
        lambda x: x**3,  # False branch
        x                # Operand passed to the selected branch
    )

For loops:

# ❌ Bad: Python for loop in JIT (unrolled at compile time)
@jax.jit
def loop_bad(x, n):
    for i in range(n):  # n must be static
        x = x + 1
    return x

# ✅ Good: Use jax.lax.fori_loop
@jax.jit
def loop_good(x, n):
    def body(i, val):
        return val + 1
    return jax.lax.fori_loop(0, n, body, x)

While loops:

@jax.jit
def while_loop(x):
    def cond_fun(val):
        return val < 10

    def body_fun(val):
        return val + 1

    return jax.lax.while_loop(cond_fun, body_fun, x)

3. Dynamic Shapes

❌ Bad: Shape depends on runtime values

@jax.jit
def dynamic_shape(x, mask):
    return x[mask]  # ERROR: output shape unknown at compile time

✅ Good: Use jnp.where for masking

@jax.jit
def static_shape(x, mask):
    return jnp.where(mask, x, 0)  # Shape stays same

4. In-Place Update Semantics

❌ Bad: Relying on shared references

x = jnp.array([1, 2, 3])
y = x
x = x.at[0].set(10)
print(y[0])  # Still 1 (y unchanged)

Note: .at returns a NEW array; doesn't modify original.

5. Print Debugging in JIT

❌ Bad: print() in JIT context

@jax.jit
def debug(x):
    print(f"x = {x}")  # Only prints ONCE (during trace)
    return x**2

✅ Good: Use jax.debug.print

@jax.jit
def debug(x):
    jax.debug.print("x = {}", x)  # Prints every call
    return x**2

6. Gradient Through Discrete Operations

❌ Bad: Gradient through argmax

def loss(x):
    idx = jnp.argmax(x)  # Discrete - the gradient does not flow through the index choice
    return x[idx]

grad_loss = jax.grad(loss)  # Gradient is a one-hot vector; it ignores how argmax changes with x

✅ Good: Use differentiable approximations

def loss(x, temperature=10.0):
    # Soft argmax (differentiable); larger temperature approaches the hard argmax
    weights = jax.nn.softmax(x * temperature)
    return jnp.sum(x * weights)

grad_loss = jax.grad(loss)

Automatic Differentiation

Basic Gradient

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x**2)

grad_f = jax.grad(f)

x = jnp.array([1.0, 2.0, 3.0])
print(grad_f(x))  # [2. 4. 6.]

Value and Gradient

# Compute both value and gradient efficiently
value_and_grad_f = jax.value_and_grad(f)

value, gradient = value_and_grad_f(x)
print(f"Value: {value}, Gradient: {gradient}")

Multiple Arguments

def f(x, y):
    return jnp.sum(x**2 + y**3)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])

# Gradient w.r.t. first argument (default)
grad_f = jax.grad(f)
print(grad_f(x, y))

# Gradient w.r.t. second argument
grad_f_wrt_y = jax.grad(f, argnums=1)
print(grad_f_wrt_y(x, y))

# Gradients w.r.t. both arguments
grad_f_both = jax.grad(f, argnums=(0, 1))
grad_x, grad_y = grad_f_both(x, y)

Auxiliary Data

def loss_with_aux(x):
    loss = jnp.sum(x**2)
    aux_data = {'norm': jnp.linalg.norm(x), 'mean': jnp.mean(x)}
    return loss, aux_data

# Tell JAX about auxiliary output
grad_fn = jax.grad(loss_with_aux, has_aux=True)
gradient, aux = grad_fn(x)

print(f"Gradient: {gradient}")
print(f"Auxiliary: {aux}")

Higher-Order Derivatives

# Second derivative by composing grad (scalar input, scalar output)
def g(x):
    return x**3

d2g = jax.grad(jax.grad(g))
print(d2g(2.0))  # 12.0

# Full Hessian for vector inputs
def f(x):
    return jnp.sum(x**3)

hessian_f = jax.hessian(f)

x = jnp.array([1.0, 2.0, 3.0])
print(hessian_f(x))  # Diagonal matrix with [6. 12. 18.] on the diagonal

Jacobian

def vector_fn(x):
    """Vector to vector function"""
    return jnp.array([x[0]**2, x[1]**3, x[0]*x[1]])

# Forward-mode (efficient for few inputs, many outputs)
jacfwd = jax.jacfwd(vector_fn)

# Reverse-mode (efficient for many inputs, few outputs)
jacrev = jax.jacrev(vector_fn)

x = jnp.array([2.0, 3.0])
print(jacfwd(x))
print(jacrev(x))  # Same result
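
The two modes compose: nesting reverse mode inside forward mode is the usual recipe for Hessians of scalar-valued functions (this is essentially what jax.hessian does). A minimal sketch:

def scalar_fn(x):
    return jnp.sum(jnp.sin(x))

hessian = jax.jacfwd(jax.jacrev(scalar_fn))

x = jnp.array([2.0, 3.0])
print(hessian(x))  # Diagonal matrix with -sin(x) on the diagonal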

Custom Gradients

@jax.custom_gradient
def f(x):
    # Forward pass
    result = jnp.exp(x)

    # Define custom gradient
    def grad_fn(g):
        # g is the gradient from downstream
        # Return gradient w.r.t. inputs
        return g * 2 * result  # Custom: 2x the normal gradient

    return result, grad_fn

# Use like normal
grad_f = jax.grad(f)
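
The more general interface for custom reverse-mode rules is jax.custom_vjp, where a forward function saves residuals and a backward function maps the output cotangent to input cotangents; a minimal sketch implementing the same doubled gradient:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.exp(x)

def f_fwd(x):
    result = jnp.exp(x)
    return result, result  # (output, residuals saved for the backward pass)

def f_bwd(residuals, g):
    result = residuals
    return (g * 2 * result,)  # Tuple of cotangents, one per input

f.defvjp(f_fwd, f_bwd)

grad_f = jax.grad(f)
print(grad_f(1.0))  # 2 * e ≈ 5.4366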

JIT Compilation

Basic Usage

import jax
import jax.numpy as jnp

def slow_fn(x):
    return jnp.sum(x**2 + 3*x + 1)

fast_fn = jax.jit(slow_fn)

# Or decorator
@jax.jit
def fast_fn2(x):
    return jnp.sum(x**2 + 3*x + 1)

Static Arguments

from functools import partial

x = jnp.ones(3)

@jax.jit
def fn(x, n):
    for i in range(n):  # ERROR: n is traced under jit, so range(n) fails
        x = x + 1
    return x

# fn(x, 5)  # Raises a tracer error: Python control flow needs a concrete n

# Solution: mark n as static (a compile-time constant)
@partial(jax.jit, static_argnums=(1,))  # static_argnames="n" is an equivalent, name-based option
def fn_static(x, n):
    for i in range(n):
        x = x + 1
    return x

# Works; each distinct value of n compiles its own specialization
result = fn_static(x, 5)
result = fn_static(x, 10)  # Recompiles for n=10

Avoiding Recompilation

# Bad: Different shapes trigger recompilation
@jax.jit
def process(x):
    return jnp.sum(x**2)

x1 = jnp.ones(10)    # Compiles for shape (10,)
x2 = jnp.ones(20)    # Recompiles for shape (20,)
x3 = jnp.ones(10)    # Uses cached version

# Good: Consistent shapes
batch_size = 32
x1 = jnp.ones((batch_size, 10))
x2 = jnp.ones((batch_size, 10))  # Same shape, cached

Vectorization (vmap)

Basic Batching

def process_single(x):
    """Process single example: scalar input"""
    return x**2 + 3*x

# Manually vectorize
def process_batch_manual(xs):
    return jnp.array([process_single(x) for x in xs])

# Automatic vectorization
process_batch = jax.vmap(process_single)

batch = jnp.array([1.0, 2.0, 3.0, 4.0])
print(process_batch(batch))  # Faster and cleaner

Batching Matrix Operations

def matrix_vector_product(matrix, vector):
    """Single matrix-vector product"""
    return matrix @ vector

# Batch over vectors
batch_mvp = jax.vmap(matrix_vector_product, in_axes=(None, 0))

A = jnp.ones((3, 3))
vectors = jnp.ones((10, 3))  # 10 vectors

results = batch_mvp(A, vectors)  # (10, 3)

# Batch over both
batch_both = jax.vmap(matrix_vector_product, in_axes=(0, 0))

matrices = jnp.ones((10, 3, 3))
results = batch_both(matrices, vectors)

Nested vmap

# Batch over two dimensions
def fn(x, y):
    return x * y

# Batch over first arg, then second
fn_batch = jax.vmap(jax.vmap(fn, in_axes=(None, 0)), in_axes=(0, None))

x = jnp.array([1, 2, 3])       # (3,)
y = jnp.array([10, 20])        # (2,)
result = fn_batch(x, y)        # (3, 2)

# Equivalent to: x[:, None] * y[None, :]

PyTrees - Nested Structures

JAX works with nested Python containers (pytrees):

import jax
import jax.numpy as jnp

# Dictionary of arrays
params = {
    'w1': jnp.ones((10, 5)),
    'b1': jnp.zeros(5),
    'w2': jnp.ones((5, 1)),
    'b2': jnp.zeros(1)
}

# Gradient works on entire pytree
def loss(params, x):
    h = x @ params['w1'] + params['b1']
    h = jax.nn.relu(h)
    out = h @ params['w2'] + params['b2']
    return jnp.mean(out**2)

grad_fn = jax.grad(loss)
x = jnp.ones((32, 10))

# Returns gradients in same structure
grads = grad_fn(params, x)
print(grads.keys())  # Gradients for 'w1', 'b1', 'w2', 'b2' - same structure as params

PyTree Operations

# Tree map - apply function to all leaves
scaled_params = jax.tree_util.tree_map(lambda x: x * 0.9, params)

# Tree reduce
total_params = jax.tree_util.tree_reduce(
    lambda total, x: total + x.size,
    params,
    initializer=0
)

# Flatten and unflatten
flat, treedef = jax.tree_util.tree_flatten(params)
reconstructed = jax.tree_util.tree_unflatten(treedef, flat)
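
A common use of tree_map is applying an update rule leaf by leaf to parameters and gradients, which share the same tree structure; a minimal SGD-style sketch reusing params and grads from above:

learning_rate = 0.01

# params and grads have the same tree structure, so tree_map pairs their leaves
new_params = jax.tree_util.tree_map(
    lambda p, g: p - learning_rate * g,
    params,
    grads
)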

Common Patterns

Optimization Loop

import jax
import jax.numpy as jnp

# Parameters
params = jnp.array([1.0, 2.0, 3.0])

# Loss function
def loss(params, x, y):
    pred = x @ params  # (100, 3) @ (3,) -> (100,)
    return jnp.mean((pred - y)**2)

# Gradient function
grad_fn = jax.jit(jax.grad(loss))

# Training data
x_train = jnp.ones((100, 3))
y_train = jnp.ones(100)

# Optimization loop
learning_rate = 0.01
for step in range(1000):
    grads = grad_fn(params, x_train, y_train)
    params = params - learning_rate * grads

    if step % 100 == 0:
        l = loss(params, x_train, y_train)
        print(f"Step {step}, Loss: {l:.4f}")

Mini-batch Training

def train_step(params, batch):
    """Single training step on one batch"""
    x, y = batch

    def batch_loss(params):
        pred = jnp.dot(x, params)
        return jnp.mean((pred - y)**2)

    loss_value, grads = jax.value_and_grad(batch_loss)(params)
    params = params - 0.01 * grads
    return params, loss_value

# JIT compile training step
train_step = jax.jit(train_step)

# Training loop
for epoch in range(10):
    for batch in data_loader:
        params, loss = train_step(params, batch)

Scan for Efficient Loops

def cumulative_sum(xs):
    """Efficient cumulative sum using scan"""
    def step(carry, x):
        new_carry = carry + x
        output = new_carry
        return new_carry, output

    final_carry, outputs = jax.lax.scan(step, 0, xs)
    return outputs

xs = jnp.array([1, 2, 3, 4, 5])
print(cumulative_sum(xs))  # [1 3 6 10 15]

RNN with Scan

def rnn_step(carry, x):
    """Single RNN step"""
    h = carry
    h_new = jnp.tanh(jnp.dot(W_h, h) + jnp.dot(W_x, x))
    return h_new, h_new

def rnn(xs, h0):
    """Run RNN over sequence"""
    final_h, all_h = jax.lax.scan(rnn_step, h0, xs)
    return all_h

# Process sequence
W_h = jnp.ones((5, 5))
W_x = jnp.ones((5, 3))
xs = jnp.ones((10, 3))  # Sequence length 10
h0 = jnp.zeros(5)

outputs = rnn(xs, h0)  # (10, 5)

Performance Best Practices

1. JIT Your Critical Paths

# Wrap expensive computations in jit
@jax.jit
def expensive_fn(x):
    for _ in range(100):
        x = jnp.dot(x, x.T)
    return x

# Avoid jitting trivial operations
def trivial(x):
    return x + 1  # JIT overhead not worth it

2. Use vmap Instead of Loops

# Slow: Python loop
def slow_batch(xs):
    return jnp.array([process(x) for x in xs])

# Fast: vmap
fast_batch = jax.vmap(process)

3. Minimize Host-Device Transfers

# Bad: Transfer back to host in loop
for i in range(1000):
    x = compute_on_gpu(x)
    print(float(x))  # Transfers to CPU each iteration!

# Good: Transfer once at end
for i in range(1000):
    x = compute_on_gpu(x)
print(float(x))  # Single transfer

4. Use Appropriate Precision

# Use float32 unless you need float64
x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)

# Enable float64 only when necessary
from jax import config
config.update("jax_enable_x64", True)

5. Preallocate When Possible

# Bad: Growing arrays
result = jnp.array([])
for i in range(1000):
    result = jnp.append(result, compute(i))

# Good: Preallocate
result = jnp.zeros(1000)
for i in range(1000):
    result = result.at[i].set(compute(i))

# Better: Use scan or vmap
def compute_all(i):
    return compute(i)

result = jax.vmap(compute_all)(jnp.arange(1000))

Debugging

Checking Array Values

# Use jax.debug.print in JIT context
@jax.jit
def debug_fn(x):
    jax.debug.print("x = {}", x)
    jax.debug.print("x shape = {}, dtype = {}", x.shape, x.dtype)
    return x**2
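
For NaNs specifically, the jax_debug_nans flag makes JAX raise an error at the first operation that produces a NaN; enable it only while debugging, since it adds checks and may rerun computations un-jitted. A short sketch:

import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)

@jax.jit
def f(x):
    return jnp.log(x)  # NaN for negative inputs

f(jnp.array(-1.0))  # Raises an error pointing at the NaN-producing operation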

Gradient Checking

from jax.test_util import check_grads

def f(x):
    return jnp.sum(x**3)

x = jnp.array([1.0, 2.0, 3.0])

# Verify gradients numerically
check_grads(f, (x,), order=2)  # Check 1st and 2nd derivatives

Inspecting Compiled Code

# See jaxpr (intermediate representation)
def f(x):
    return x**2 + 3*x

jaxpr = jax.make_jaxpr(f)(1.0)
print(jaxpr)

# See HLO (low-level compiled code)
compiled = jax.jit(f).lower(1.0).compile()
print(compiled.as_text())

Common Gotchas - Quick Reference

| Gotcha | NumPy | JAX | Solution |
|---|---|---|---|
| In-place update | x[0] = 1 | ❌ Error | x = x.at[0].set(1) |
| Random state | np.random.seed() | ❌ Not used by jax.random | key = jax.random.PRNGKey(seed) |
| List inputs | np.sum([1, 2, 3]) | ❌ Error | jnp.sum(jnp.array([1, 2, 3])) |
| Out-of-bounds indexing | IndexError | ⚠️ Silent clamp | Avoid; validate indices |
| Value-dependent if | Works | ❌ Error in JIT | jax.lax.cond() |
| Dynamic shapes | Works | ❌ Error in JIT | Keep shapes static |
| Default precision | float64 | float32 | Set jax_enable_x64 |
| print() in JIT | Works | ⚠️ Prints only once | jax.debug.print() |

Installation

# CPU only
pip install jax

# GPU (CUDA 12)
pip install "jax[cuda12]"

# TPU
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
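
To confirm which backend the installed JAX picked up (the exact device repr varies by version):

python -c "import jax; print(jax.devices())"
# e.g. a CUDA device entry on GPU installs, or a single CPU device on CPU-only installs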

Ecosystem Libraries

JAX has a rich ecosystem built on its transformation primitives:

Neural Networks:

  • Flax - Official neural network library
  • Haiku - DeepMind's neural network library
  • Equinox - Elegant PyTorch-like library

Optimization:

  • Optax - Gradient processing and optimization (see the sketch after these lists)
  • JAXopt - Non-linear optimization

Scientific Computing:

  • JAX-MD - Molecular dynamics
  • Diffrax - Differential equation solvers
  • BlackJAX - MCMC sampling

Utilities:

  • jaxtyping - Type annotations for arrays
  • chex - Testing utilities
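
As one example of how these libraries plug into plain JAX, an Optax optimizer is a pair of pure functions (init and update) operating on the same parameter pytrees used throughout this skill. A minimal sketch, assuming optax is installed and using a small made-up linear loss:

import jax
import jax.numpy as jnp
import optax

def loss(params, x):
    return jnp.mean((x @ params['w'] + params['b'])**2)

params = {'w': jnp.ones((3, 1)), 'b': jnp.zeros(1)}
x = jnp.ones((8, 3))

optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

grads = jax.grad(loss)(params, x)                         # Same pytree structure as params
updates, opt_state = optimizer.update(grads, opt_state)   # Pure: returns new updates and state
params = optax.apply_updates(params, updates)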

Additional Resources

Related Skills

  • python-optimization - Numerical optimization with scipy, pyomo
  • python-ase - Atomic simulation environment (can use JAX for forces)
  • pycse - Scientific computing utilities
  • python-best-practices - Code quality for JAX projects