SKILL.md

name: mlx-dev
description: Write correct, idiomatic Apple MLX code for Apple Silicon ML. Use when working with MLX arrays, neural networks, training loops, lazy evaluation, unified memory, mx.eval, mx.compile, Metal GPU, memory optimization, quantization, or Apple Silicon performance. Covers critical API differences from PyTorch/NumPy, array indexing gotchas (lists must be mx.array, slices create copies), NHWC format for Conv2d, __call__ not forward(), float64 CPU-only, mlx-lm integration, and debugging patterns.

MLX Development Guide

Environment Setup

Use uv for Python environment and package management:

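# Add MLX to the current uv project (create one first with `uv init` if needed)
uv add mlx

# Run MLX scripts inside the project environment
uv run python train.py

# Run a one-off script with MLX available, without adding it to a project
uv run --with mlx python script.py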

Critical Rules

1. Lazy Evaluation - Always Evaluate at Loop Boundaries

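Operations only build a computation graph. Nothing is computed until the result is evaluated, either explicitly with mx.eval() or implicitly (see the triggers below):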

# CORRECT: Evaluate once per iteration, at the loop boundary
# (assumes value_and_grad_fn = nn.value_and_grad(model, loss_fn))
for batch in dataset:
    loss, grads = value_and_grad_fn(model, batch)
    optimizer.update(model, grads)
    mx.eval(loss, model.parameters(), optimizer.state)  # the whole step is computed here

# WRONG: Evaluating too frequently
for _ in range(100):
    a = a + b
    mx.eval(a)  # one tiny graph per eval: per-call overhead dominates

Evaluation is also triggered implicitly by print(a), a.item(), np.array(a), and by using an array in a boolean context such as `if a > 0:`.

2. Array Indexing Differs from NumPy

import mlx.core as mx

a = mx.arange(5)
x = mx.arange(10)

# Lists must be mx.array
a[[0, 1]]              # ValueError!
a[mx.array([0, 1])]    # Works

# Slice bounds must be Python ints
i = mx.array(2)
x[i:i+2]               # ValueError!
x[i.item():i.item()+2] # Works (.item() forces evaluation)

# Slices create COPIES, not views (opposite of NumPy)
b = a[:]
b[2] = 0  # a is unchanged!

# Boolean mask READS are not supported
mask = a > 2
a[mask]  # Not supported: use mx.where() instead

# No bounds checking: out-of-bounds indexing is undefined behavior (no error, arbitrary values)

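A plain update such as a[idx] += 1 is shorthand for a[idx] = a[idx] + 1, so duplicate indices receive only one update. For updates that must accumulate, use the .at[] syntax: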

a = a.at[idx].add(1)  # Properly accumulates at duplicate indices

See references/array-indexing.md for complete patterns.

3. Neural Networks: NHWC Format and __call__

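import mlx.core as mx
import mlx.nn as nn

# Conv2d uses NHWC (not NCHW like PyTorch); x_torch is a PyTorch tensor in NCHW
x_mlx = mx.array(x_torch.detach().cpu().numpy().transpose(0, 2, 3, 1))

# Override __call__, not forward()
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 10)

    def __call__(self, x):  # NOT forward()
        return self.layer(x)

# No dtype argument in layer constructors: cast afterwards with set_dtype()
layer = nn.Linear(10, 10)
layer.set_dtype(mx.bfloat16)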

See references/neural-networks.md for layer equivalents.

4. Data Types: float64 is CPU-Only

import numpy as np
import mlx.core as mx

a = mx.array([1.0], dtype=mx.float64)
mx.exp(a, stream=mx.gpu)  # RuntimeError!

# Solutions:
mx.exp(a, stream=mx.cpu)                     # run float64 ops on the CPU stream
mx.exp(a.astype(mx.float32, stream=mx.cpu))  # or downcast on the CPU, then use float32

# bfloat16 arrays from external libraries (ml_dtypes) get misinterpreted
from ml_dtypes import bfloat16
x = np.array(1., dtype=bfloat16)
mx.array(x)  # Returns complex64!
mx.array(x.astype(np.float32), dtype=mx.bfloat16)  # Correct: go through float32

See references/dtypes.md for full type support table.

5. Compilation: Capture All Mutable State

from functools import partial
import mlx.core as mx
import mlx.nn as nn

# model, optimizer and loss_fn as defined in the training example below
state = [model.state, optimizer.state, mx.random.state]  # Include random state (e.g. for dropout)!

@partial(mx.compile, inputs=state, outputs=state)
def train_step(x, y):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss

# No print()/.item() on arrays inside compiled functions: tracing uses placeholder arrays with no data
# String decoding triggers recompilation - decode outside loop

See references/compilation.md for recompilation triggers.

Quick Reference Tables

Dtype Support

| Type | GPU | Notes |
|---|---|---|
| float32 | Yes | Default float |
| float16 | Yes | |
| bfloat16 | Yes | M3+ recommended |
| float64 | No (CPU only) | GPU ops throw! |
| int8/16/32/64, uint8/16/32/64 | Yes | |
| complex64 | Partial | No matmul |

PyTorch → MLX Equivalents

| PyTorch | MLX |
|---|---|
| tensor.to('cuda') | Not needed (unified memory) |
| nn.Module.forward() | nn.Module.__call__() |
| NCHW layout | NHWC layout |
| torch.gather() | mx.take_along_axis() |
| torch.scatter_add_() | arr.at[idx].add() |

Not Available in MLX

  • np.nonzero() - restructure the algorithm around fixed-shape masks (mx.where())
  • np.unique() - pre-sort or use dicts
  • arr[bool_mask] read - use mx.where()
  • np.linalg.det(), np.linalg.lstsq()

Performance Notes

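  • Transformers: MLX is reported to be typically 2-3x faster than PyTorch MPS
  • Convolutions: reported 10-150x SLOWER than PyTorch MPS (known limitation)
  • LLM inference: excellent, especially with quantized weights
  • Use float16/bfloat16 to halve memory footprint and bandwidth relative to float32
  • Use 4-bit quantization for LLMs: roughly 4x less weight memory traffic than float16 (slightly less in practice, due to per-group scales and biases)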


Idiomatic Training Example

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(784, 256), nn.Linear(256, 10)]

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = mx.maximum(layer(x), 0)  # ReLU
        return self.layers[-1](x)

def loss_fn(model, x, y):
    return nn.losses.cross_entropy(model(x), y, reduction="mean")

model = Model()
optimizer = optim.AdamW(learning_rate=1e-3)

# Everything the compiled step reads and mutates: parameters, optimizer state, PRNG key
state = [model.state, optimizer.state, mx.random.state]

@partial(mx.compile, inputs=state, outputs=state)
def train_step(x, y):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss

# Placeholder data so the example runs: random MNIST-shaped batches
# (replace with a real data pipeline)
num_epochs = 3
dataloader = [
    (mx.random.normal((64, 784)), mx.random.randint(0, 10, (64,)))
    for _ in range(10)
]

for epoch in range(num_epochs):
    for x_batch, y_batch in dataloader:
        loss = train_step(x_batch, y_batch)
        mx.eval(state)  # one evaluation per step, at the loop boundary
    print(f"Epoch {epoch}: {loss.item():.4f}")