model-serving-patterns

@tachyon-beep/skillpacks

Design production serving with FastAPI, TorchServe, gRPC, batching, and containerization.


SKILL.md

name model-serving-patterns
description Design production serving with FastAPI, TorchServe, gRPC, batching, and containerization.

Model Serving Patterns Skill

When to Use This Skill

Use this skill when:

  • Deploying ML models to production environments
  • Building model serving APIs for real-time inference
  • Optimizing model serving for throughput and latency
  • Containerizing models for consistent deployment
  • Implementing request batching for efficiency
  • Choosing between serving frameworks and protocols

When NOT to use: Notebook prototyping, training jobs, or single-prediction scripts where serving infrastructure is premature.

Core Principle

Serving infrastructure is not one-size-fits-all. Pattern selection is context-dependent.

Without proper serving infrastructure:

  • model.pkl in repo (manual dependency hell)
  • Wrong protocol choice (gRPC for simple REST use cases)
  • No batching (1 req/sec instead of 100 req/sec)
  • Not containerized (works on my machine syndrome)
  • Static batching when dynamic needed (underutilized GPU)

Formula: Right framework (FastAPI vs TorchServe vs gRPC vs ONNX) + Request batching (dynamic > static) + Containerization (Docker + model) + Clear selection criteria = Production-ready serving.

Serving Framework Decision Tree

┌────────────────────────────────────────┐
│     What's your primary requirement?   │
└──────────────┬─────────────────────────┘
               │
       ┌───────┴───────┐
       ▼               ▼
   Flexibility    Batteries Included
       │               │
       ▼               ▼
   FastAPI         TorchServe
   (Custom)        (PyTorch)
       │               │
       │       ┌───────┴───────┐
       │       ▼               ▼
       │   Low Latency   Cross-Framework
       │       │               │
       │       ▼               ▼
       │    gRPC         ONNX Runtime
       │       │               │
       └───────┴───────────────┘
                   │
                   ▼
        ┌────────────────────────┐
        │  Add Request Batching  │
        │  Dynamic > Static      │
        └───────────┬────────────┘
                    │
                    ▼
        ┌────────────────────────┐
        │   Containerize with    │
        │   Docker + Dependencies│
        └────────────────────────┘

Part 1: FastAPI for Custom Serving

When to use: Need flexibility, custom preprocessing, or non-standard workflows.

Advantages: Full control, easy debugging, Python ecosystem integration. Disadvantages: Manual optimization, no built-in model management.

Basic FastAPI Serving

# serve_fastapi.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import torch
import numpy as np
from typing import List, Optional
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Model Serving API", version="1.0.0")

class PredictionRequest(BaseModel):
    """Request schema with validation."""
    inputs: List[List[float]] = Field(..., description="Input features as 2D array")
    return_probabilities: bool = Field(False, description="Return class probabilities")

    class Config:
        schema_extra = {
            "example": {
                "inputs": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
                "return_probabilities": True
            }
        }

class PredictionResponse(BaseModel):
    """Response schema."""
    predictions: List[int]
    probabilities: Optional[List[List[float]]] = None
    latency_ms: float

class ModelServer:
    """
    Model server with lazy loading and caching.

    WHY: Load model once at startup, reuse across requests.
    WHY: Avoids 5-10 second model loading per request.
    """

    def __init__(self, model_path: str):
        self.model_path = model_path
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self):
        """Load model on first request (lazy loading)."""
        if self.model is None:
            logger.info(f"Loading model from {self.model_path}...")
            self.model = torch.load(self.model_path, map_location=self.device)
            self.model.eval()  # WHY: Disable dropout, batchnorm for inference
            logger.info("Model loaded successfully")

    def predict(self, inputs: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """
        Run inference.

        Args:
            inputs: Input array (batch_size, features)

        Returns:
            (predictions, probabilities)
        """
        self.load_model()

        # Convert to tensor
        x = torch.tensor(inputs, dtype=torch.float32).to(self.device)

        # WHY: torch.no_grad() disables gradient computation for inference
        # WHY: Reduces memory usage by 50% and speeds up by 2×
        with torch.no_grad():
            logits = self.model(x)
            probabilities = torch.softmax(logits, dim=1)
            predictions = torch.argmax(probabilities, dim=1)

        return predictions.cpu().numpy(), probabilities.cpu().numpy()

# Global model server instance
model_server = ModelServer(model_path="model.pth")

@app.on_event("startup")
async def startup_event():
    """Load model at startup for faster first request."""
    model_server.load_model()
    logger.info("Server startup complete")

@app.get("/health")
async def health_check():
    """Health check endpoint for load balancers."""
    return {
        "status": "healthy",
        "model_loaded": model_server.model is not None,
        "device": str(model_server.device)
    }

@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """
    Prediction endpoint with validation and error handling.

    WHY: Pydantic validates inputs automatically.
    WHY: Returns 422 for invalid inputs, not 500.
    """
    import time
    start_time = time.time()

    try:
        inputs = np.array(request.inputs)

        # Validate shape
        if inputs.ndim != 2:
            raise HTTPException(
                status_code=400,
                detail=f"Expected 2D array, got {inputs.ndim}D"
            )

        predictions, probabilities = model_server.predict(inputs)

        latency_ms = (time.time() - start_time) * 1000

        response = PredictionResponse(
            predictions=predictions.tolist(),
            probabilities=probabilities.tolist() if request.return_probabilities else None,
            latency_ms=latency_ms
        )

        logger.info(f"Predicted {len(predictions)} samples in {latency_ms:.2f}ms")
        return response

    except Exception as e:
        logger.error(f"Prediction error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

# Run with: uvicorn serve_fastapi:app --host 0.0.0.0 --port 8000 --workers 4

Performance characteristics:

| Metric | Value | Notes |
|---|---|---|
| Cold start | 5-10 s | Model loading time |
| Warm latency | 10-50 ms | Per request |
| Throughput | 100-500 req/sec | Single worker |
| Memory | 2-8 GB | Model + runtime |
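
With the server running, the endpoint can be exercised from any HTTP client. A minimal sketch using requests (assumes the server above is listening on localhost:8000):

# client_fastapi.py -- minimal sketch; assumes serve_fastapi.py is running on localhost:8000
import requests

# Health check (what a load balancer or readiness probe would call)
print(requests.get("http://localhost:8000/health").json())

# Prediction request; Pydantic returns 422 automatically if the payload doesn't match the schema
payload = {
    "inputs": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
    "return_probabilities": True,
}
resp = requests.post("http://localhost:8000/predict", json=payload, timeout=10)
resp.raise_for_status()
print(resp.json())  # {"predictions": [...], "probabilities": [...], "latency_ms": ...}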

Advanced: Async FastAPI with Background Tasks

# serve_fastapi_async.py
from fastapi import FastAPI
from asyncio import Queue, create_task, sleep
import asyncio
from typing import Dict, List
import uuid

import numpy as np

# WHY: Reuse the model server and request schema from the synchronous example
from serve_fastapi import ModelServer, PredictionRequest

app = FastAPI()
model_server = ModelServer(model_path="model.pth")

class AsyncBatchPredictor:
    """
    Async batch predictor with request queuing.

    WHY: Collect multiple requests, predict as batch.
    WHY: GPU utilization: 20% (1 req) → 80% (batch of 32).
    """

    def __init__(self, model_server: ModelServer, batch_size: int = 32, wait_ms: int = 10):
        self.model_server = model_server
        self.batch_size = batch_size
        self.wait_ms = wait_ms
        self.queue: Queue = Queue()
        self.pending_requests: Dict[str, asyncio.Future] = {}

    async def start(self):
        """Start background batch processing loop."""
        create_task(self._batch_processing_loop())

    async def _batch_processing_loop(self):
        """
        Continuously collect and process batches.

        WHY: Wait for batch_size OR timeout, then process.
        WHY: Balances throughput (large batch) and latency (timeout).
        """
        while True:
            batch_requests = []
            batch_ids = []

            # Collect batch
            deadline = asyncio.get_event_loop().time() + (self.wait_ms / 1000)

            while len(batch_requests) < self.batch_size:
                timeout = max(0, deadline - asyncio.get_event_loop().time())

                try:
                    request_id, inputs = await asyncio.wait_for(
                        self.queue.get(),
                        timeout=timeout
                    )
                    batch_requests.append(inputs)
                    batch_ids.append(request_id)
                except asyncio.TimeoutError:
                    break  # Timeout reached, process what we have

            if not batch_requests:
                await sleep(0.001)  # Brief sleep before next iteration
                continue

            # Process batch
            batch_array = np.array(batch_requests)
            predictions, probabilities = self.model_server.predict(batch_array)

            # Return results to waiting requests
            for i, request_id in enumerate(batch_ids):
                future = self.pending_requests.pop(request_id)
                future.set_result((predictions[i], probabilities[i]))

    async def predict_async(self, inputs: List[float]) -> tuple[int, np.ndarray]:
        """
        Add request to queue and await result.

        WHY: Returns immediately if batch ready, waits if not.
        WHY: Client doesn't know about batching (transparent).
        """
        request_id = str(uuid.uuid4())
        future = asyncio.Future()
        self.pending_requests[request_id] = future

        await self.queue.put((request_id, inputs))

        # Wait for batch processing to complete
        prediction, probability = await future
        return prediction, probability

# Global async predictor
async_predictor = None

@app.on_event("startup")
async def startup():
    global async_predictor
    model_server.load_model()
    async_predictor = AsyncBatchPredictor(model_server, batch_size=32, wait_ms=10)
    await async_predictor.start()

@app.post("/predict_async")
async def predict_async(request: PredictionRequest):
    """
    Async prediction with automatic batching.

    WHY: 10× better GPU utilization than synchronous.
    WHY: Same latency, much higher throughput.
    """
    # Single input for simplicity (extend for batch)
    inputs = request.inputs[0]
    prediction, probability = await async_predictor.predict_async(inputs)

    return {
        "prediction": int(prediction),
        "probability": probability.tolist()
    }

Performance improvement:

| Approach | Throughput | GPU Utilization | Latency (P95) |
|---|---|---|---|
| Synchronous | 100 req/sec | 20% | 15 ms |
| Async batching | 1000 req/sec | 80% | 25 ms |
| Improvement | 10× | 4× | +10 ms |
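
To reproduce numbers like these, fire many concurrent requests at /predict_async. A minimal load-generation sketch (httpx is an assumption here; any async HTTP client works):

# load_test_async.py -- minimal sketch; assumes the async server runs on localhost:8000 and httpx is installed
import asyncio
import time

import httpx

async def run(concurrency: int = 200):
    payload = {"inputs": [[1.0, 2.0, 3.0]]}
    async with httpx.AsyncClient(base_url="http://localhost:8000", timeout=30.0) as client:
        start = time.time()
        # All requests are in flight at once; the server groups them into GPU batches
        responses = await asyncio.gather(
            *(client.post("/predict_async", json=payload) for _ in range(concurrency))
        )
        elapsed = time.time() - start

    ok = sum(r.status_code == 200 for r in responses)
    print(f"{ok}/{concurrency} succeeded in {elapsed:.2f}s ({concurrency / elapsed:.0f} req/sec)")

if __name__ == "__main__":
    asyncio.run(run())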

Part 2: TorchServe for PyTorch Models

When to use: PyTorch models, want batteries-included solution with monitoring, metrics, and model management.

Advantages: Built-in batching, model versioning, A/B testing, metrics. Disadvantages: PyTorch-only, less flexibility, steeper learning curve.

Creating a TorchServe Handler

# handler.py
import json
import logging

import torch
import torch.nn.functional as F
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)

class CustomClassifierHandler(BaseHandler):
    """
    Custom TorchServe handler with preprocessing and batching.

    WHY: TorchServe provides: model versioning, A/B testing, metrics, monitoring.
    WHY: Built-in dynamic batching (no custom code needed).
    """

    def initialize(self, context):
        """
        Initialize handler (called once at startup).

        Args:
            context: TorchServe context with model artifacts
        """
        self.manifest = context.manifest
        properties = context.system_properties

        # Set device
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available()
            else "cpu"
        )

        # Load model
        model_dir = properties.get("model_dir")
        serialized_file = self.manifest["model"]["serializedFile"]
        model_path = f"{model_dir}/{serialized_file}"

        self.model = torch.jit.load(model_path, map_location=self.device)
        self.model.eval()

        logger.info(f"Model loaded successfully on {self.device}")

        # WHY: Initialize preprocessing parameters
        self.mean = torch.tensor([0.485, 0.456, 0.406]).to(self.device)
        self.std = torch.tensor([0.229, 0.224, 0.225]).to(self.device)

        self.initialized = True

    def preprocess(self, data):
        """
        Preprocess input data.

        Args:
            data: List of input requests

        Returns:
            Preprocessed tensor batch

        WHY: TorchServe batches requests automatically.
        WHY: This method receives multiple requests at once.
        """
        inputs = []

        for row in data:
            # Get input from request (JSON or binary)
            input_data = row.get("data") or row.get("body")

            # Parse and convert
            if isinstance(input_data, (bytes, bytearray)):
                input_data = input_data.decode("utf-8")

            # Convert to tensor (json.loads, not eval: eval on request payloads is unsafe)
            tensor = torch.tensor(json.loads(input_data), dtype=torch.float32)

            # Normalize
            tensor = (tensor - self.mean) / self.std

            inputs.append(tensor)

        # Stack into batch
        batch = torch.stack(inputs).to(self.device)
        return batch

    def inference(self, batch):
        """
        Run inference on batch.

        Args:
            batch: Preprocessed batch tensor

        Returns:
            Model output

        WHY: torch.no_grad() for inference (faster, less memory).
        """
        with torch.no_grad():
            output = self.model(batch)

        return output

    def postprocess(self, inference_output):
        """
        Postprocess inference output.

        Args:
            inference_output: Raw model output

        Returns:
            List of predictions (one per request in batch)

        WHY: Convert tensors to JSON-serializable format.
        WHY: Return predictions in same order as inputs.
        """
        # Apply softmax
        probabilities = F.softmax(inference_output, dim=1)

        # Get predictions
        predictions = torch.argmax(probabilities, dim=1)

        # Convert to list (one entry per request)
        results = []
        for i in range(len(predictions)):
            results.append({
                "prediction": predictions[i].item(),
                "probabilities": probabilities[i].tolist()
            })

        return results

TorchServe Configuration

# model_config.yaml
# WHY: Configuration controls batching, workers, timeouts
# WHY: Tune these for your latency/throughput requirements

minWorkers: 2          # WHY: Minimum workers (always ready)
maxWorkers: 4          # WHY: Maximum workers (scale up under load)
batchSize: 32          # WHY: Maximum batch size (GPU utilization)
maxBatchDelay: 10      # WHY: Max wait time for batch (ms)
                       # WHY: Trade-off: larger batch (better GPU util) vs latency

responseTimeout: 120   # WHY: Request timeout (seconds)
                       # WHY: Prevent hung requests

# Device assignment
deviceType: "gpu"      # WHY: Use GPU if available
deviceIds: [0]         # WHY: Specific GPU ID

# Metrics
metrics:
  enable: true
  prometheus: true     # WHY: Export to Prometheus for monitoring

Packaging and Serving

# Package model for TorchServe
# WHY: .mar file contains model + handler + config (portable)
torch-model-archiver \
  --model-name classifier \
  --version 1.0 \
  --serialized-file model.pt \
  --handler handler.py \
  --extra-files "model_config.yaml" \
  --export-path model_store/

# Start TorchServe
# WHY: Serves on 8080 (inference), 8081 (management), 8082 (metrics)
torchserve \
  --start \
  --ncs \
  --model-store model_store \
  --models classifier.mar \
  --ts-config config.properties

# Register model (if not auto-loaded)
curl -X POST "http://localhost:8081/models?url=classifier.mar&batch_size=32&max_batch_delay=10"

# Make prediction
curl -X POST "http://localhost:8080/predictions/classifier" \
  -H "Content-Type: application/json" \
  -d '{"data": [[1.0, 2.0, 3.0]]}'

# Get metrics (for monitoring)
curl http://localhost:8082/metrics

# Unregister model (for updates)
curl -X DELETE "http://localhost:8081/models/classifier"
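
The torchserve command above passes --ts-config config.properties, which is not shown. A minimal sketch of that file (standard TorchServe settings; adjust addresses and paths for your environment):

# config.properties -- minimal sketch for --ts-config
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
model_store=model_store
load_models=all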

TorchServe advantages:

| Feature | Built-in? | Notes |
|---|---|---|
| Dynamic batching | Yes | Automatic, configurable |
| Model versioning | Yes | A/B testing support |
| Metrics/monitoring | Yes | Prometheus integration |
| Multi-model serving | Yes | Multiple models per server |
| GPU management | Yes | Automatic device assignment |
| Custom preprocessing | Yes | Via handler |

Part 3: gRPC for Low-Latency Serving

When to use: Low latency critical (< 10ms), internal services, microservices architecture.

Advantages: 3-5× faster than REST, binary protocol, streaming support. Disadvantages: More complex, requires proto definitions, harder debugging.

Protocol Definition

// model_service.proto
syntax = "proto3";

package modelserving;

// WHY: Define service contract in .proto file
// WHY: Code generation for multiple languages (Python, Go, Java, etc.)
service ModelService {
  // Unary RPC (one request, one response)
  rpc Predict (PredictRequest) returns (PredictResponse);

  // Server streaming (one request, stream responses)
  rpc PredictStream (PredictRequest) returns (stream PredictResponse);

  // Bidirectional streaming (stream requests and responses)
  rpc PredictBidi (stream PredictRequest) returns (stream PredictResponse);
}

message PredictRequest {
  // WHY: Repeated = array/list
  repeated float features = 1;  // WHY: Input features
  bool return_probabilities = 2;
}

message PredictResponse {
  int32 prediction = 1;
  repeated float probabilities = 2;
  float latency_ms = 3;
}

// Health check service (for load balancers)
service Health {
  rpc Check (HealthCheckRequest) returns (HealthCheckResponse);
}

message HealthCheckRequest {
  string service = 1;
}

message HealthCheckResponse {
  enum ServingStatus {
    UNKNOWN = 0;
    SERVING = 1;
    NOT_SERVING = 2;
  }
  ServingStatus status = 1;
}
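
The Python modules imported by the server below (model_service_pb2, model_service_pb2_grpc) are generated from this .proto. A typical invocation, assuming grpcio-tools is installed:

# Generate model_service_pb2.py and model_service_pb2_grpc.py
python -m grpc_tools.protoc \
  -I. \
  --python_out=. \
  --grpc_python_out=. \
  model_service.proto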

gRPC Server Implementation

# serve_grpc.py
import grpc
from concurrent import futures
import time
import logging
import torch
import numpy as np

# Generated from proto file (run: python -m grpc_tools.protoc ...)
import model_service_pb2
import model_service_pb2_grpc

logger = logging.getLogger(__name__)

class ModelServicer(model_service_pb2_grpc.ModelServiceServicer):
    """
    gRPC service implementation.

    WHY: gRPC is 3-5× faster than REST (binary protocol, HTTP/2).
    WHY: Use for low-latency internal services (< 10ms target).
    """

    def __init__(self, model_path: str):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = torch.load(model_path, map_location=self.device)
        self.model.eval()
        logger.info(f"Model loaded on {self.device}")

    def Predict(self, request, context):
        """
        Unary RPC prediction.

        WHY: Fastest for single predictions.
        WHY: 3-5ms latency vs 10-15ms for REST.
        """
        start_time = time.time()

        try:
            # Convert proto repeated field to numpy
            features = np.array(request.features, dtype=np.float32)

            # Reshape for model
            x = torch.tensor(features).unsqueeze(0).to(self.device)

            # Inference
            with torch.no_grad():
                logits = self.model(x)
                probs = torch.softmax(logits, dim=1)
                pred = torch.argmax(probs, dim=1)

            latency_ms = (time.time() - start_time) * 1000

            # Build response
            response = model_service_pb2.PredictResponse(
                prediction=int(pred.item()),
                latency_ms=latency_ms
            )

            # WHY: Only include probabilities if requested (reduce bandwidth)
            if request.return_probabilities:
                response.probabilities.extend(probs[0].cpu().tolist())

            return response

        except Exception as e:
            logger.error(f"Prediction error: {e}")
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(str(e))
            return model_service_pb2.PredictResponse()

    def PredictStream(self, request, context):
        """
        Server streaming RPC.

        WHY: Send multiple predictions over one connection.
        WHY: Lower overhead for batch processing.
        """
        # Stream multiple predictions (example: time series)
        for i in range(10):  # Simulate 10 predictions
            response = self.Predict(request, context)
            yield response
            time.sleep(0.01)  # Simulate processing delay

    def PredictBidi(self, request_iterator, context):
        """
        Bidirectional streaming RPC.

        WHY: Real-time inference (send request, get response immediately).
        WHY: Lowest latency for streaming use cases.
        """
        for request in request_iterator:
            response = self.Predict(request, context)
            yield response

class HealthServicer(model_service_pb2_grpc.HealthServicer):
    """Health check service for load balancers."""

    def Check(self, request, context):
        # WHY: Load balancers need health checks to route traffic
        return model_service_pb2.HealthCheckResponse(
            status=model_service_pb2.HealthCheckResponse.SERVING
        )

def serve():
    """
    Start gRPC server.

    WHY: ThreadPoolExecutor for concurrent request handling.
    WHY: max_workers controls concurrency (tune based on CPU cores).
    """
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=10),
        options=[
            # WHY: These options optimize for low latency
            ('grpc.max_send_message_length', 10 * 1024 * 1024),  # 10MB
            ('grpc.max_receive_message_length', 10 * 1024 * 1024),
            ('grpc.so_reuseport', 1),  # WHY: Allows multiple servers on same port
            ('grpc.use_local_subchannel_pool', 1)  # WHY: Better connection reuse
        ]
    )

    # Add services
    model_service_pb2_grpc.add_ModelServiceServicer_to_server(
        ModelServicer("model.pth"), server
    )
    model_service_pb2_grpc.add_HealthServicer_to_server(
        HealthServicer(), server
    )

    # Bind to port
    server.add_insecure_port('[::]:50051')

    server.start()
    logger.info("gRPC server started on port 50051")

    server.wait_for_termination()

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    serve()

gRPC Client

# client_grpc.py
import grpc
import model_service_pb2
import model_service_pb2_grpc
import time

def benchmark_grpc_vs_rest():
    """
    Benchmark gRPC vs REST latency.

    WHY: gRPC is faster, but how much faster?
    """
    # gRPC client
    channel = grpc.insecure_channel('localhost:50051')
    stub = model_service_pb2_grpc.ModelServiceStub(channel)

    # Warm up
    request = model_service_pb2.PredictRequest(
        features=[1.0, 2.0, 3.0],
        return_probabilities=True
    )
    for _ in range(10):
        stub.Predict(request)

    # Benchmark
    iterations = 1000
    start = time.time()
    for _ in range(iterations):
        response = stub.Predict(request)
    grpc_latency = ((time.time() - start) / iterations) * 1000

    print(f"gRPC average latency: {grpc_latency:.2f}ms")

    # Compare with REST (FastAPI)
    import requests
    rest_url = "http://localhost:8000/predict"

    # Warm up
    for _ in range(10):
        requests.post(rest_url, json={"inputs": [[1.0, 2.0, 3.0]]})

    # Benchmark
    start = time.time()
    for _ in range(iterations):
        requests.post(rest_url, json={"inputs": [[1.0, 2.0, 3.0]]})
    rest_latency = ((time.time() - start) / iterations) * 1000

    print(f"REST average latency: {rest_latency:.2f}ms")
    print(f"gRPC is {rest_latency/grpc_latency:.1f}× faster")

    # Typical results:
    # gRPC: 3-5ms
    # REST: 10-15ms
    # gRPC is 3-5× faster

if __name__ == "__main__":
    benchmark_grpc_vs_rest()
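
The benchmark above only exercises the unary Predict RPC. Streaming uses the same generated stub; a minimal sketch of the bidirectional streaming call (assumes the server from serve_grpc.py is running):

# client_grpc_stream.py -- minimal sketch: bidirectional streaming with the generated stub
import grpc

import model_service_pb2
import model_service_pb2_grpc

def request_generator(n: int = 5):
    """Yield a stream of requests; gRPC sends each one as it is produced."""
    for i in range(n):
        yield model_service_pb2.PredictRequest(
            features=[float(i), float(i) + 1.0, float(i) + 2.0],
            return_probabilities=False,
        )

def main():
    channel = grpc.insecure_channel("localhost:50051")
    stub = model_service_pb2_grpc.ModelServiceStub(channel)

    # Pass an iterator of requests; responses arrive as an iterator too
    for response in stub.PredictBidi(request_generator()):
        print(f"prediction={response.prediction} latency={response.latency_ms:.2f}ms")

if __name__ == "__main__":
    main()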

gRPC vs REST comparison:

| Metric | gRPC | REST | Advantage |
|---|---|---|---|
| Latency | 3-5 ms | 10-15 ms | gRPC (3× faster) |
| Throughput | 10k req/sec | 3k req/sec | gRPC (3× higher) |
| Payload size | Binary (smaller) | JSON (larger) | gRPC (30-50% smaller) |
| Debugging | Harder | Easier | REST |
| Browser support | No (requires proxy) | Yes | REST |
| Streaming | Native | Complex (SSE/WebSocket) | gRPC |

Part 4: ONNX Runtime for Cross-Framework Serving

When to use: Need cross-framework support (PyTorch, TensorFlow, etc.), want maximum performance, or deploying to edge devices.

Advantages: Framework-agnostic, highly optimized, smaller deployment size. Disadvantages: Not all models convert easily, limited debugging.

Converting PyTorch to ONNX

# convert_to_onnx.py
import torch
import torch.onnx

def convert_pytorch_to_onnx(model_path: str, output_path: str):
    """
    Convert PyTorch model to ONNX format.

    WHY: ONNX is framework-agnostic (portable).
    WHY: ONNX Runtime is 2-3× faster than native PyTorch inference.
    WHY: Smaller deployment size (no PyTorch dependency).
    """
    # Load PyTorch model
    model = torch.load(model_path)
    model.eval()

    # Create dummy input (for tracing)
    dummy_input = torch.randn(1, 3, 224, 224)  # Example: image

    # Export to ONNX
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,  # WHY: Include model weights
        opset_version=17,    # WHY: Widely supported, stable ONNX opset
        do_constant_folding=True,  # WHY: Optimize constants at export time
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},   # WHY: Support variable batch size
            'output': {0: 'batch_size'}
        }
    )

    print(f"Model exported to {output_path}")

    # Verify ONNX model
    import onnx
    onnx_model = onnx.load(output_path)
    onnx.checker.check_model(onnx_model)
    print("ONNX model validation successful")

# Example usage
convert_pytorch_to_onnx("model.pth", "model.onnx")
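
Before serving the exported model, confirm it produces the same outputs as the original. A minimal verification sketch (assumes model.pth and model.onnx exist and onnxruntime is installed):

# verify_onnx.py -- minimal sketch: check that ONNX output matches PyTorch within tolerance
import numpy as np
import onnxruntime as ort
import torch

dummy = np.random.randn(1, 3, 224, 224).astype(np.float32)

# PyTorch reference output
model = torch.load("model.pth")
model.eval()
with torch.no_grad():
    torch_out = model(torch.from_numpy(dummy)).numpy()

# ONNX Runtime output
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
onnx_out = session.run(None, {session.get_inputs()[0].name: dummy})[0]

# Small numerical differences are expected from graph optimizations
np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-5)
print("ONNX and PyTorch outputs match within tolerance")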

ONNX Runtime Serving

# serve_onnx.py
import onnxruntime as ort
import numpy as np
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
import logging

logger = logging.getLogger(__name__)

app = FastAPI()

class PredictionRequest(BaseModel):
    """Request schema: inputs as a 2D array of floats."""
    inputs: List[List[float]]

class ONNXModelServer:
    """
    ONNX Runtime server with optimizations.

    WHY: ONNX Runtime is 2-3× faster than PyTorch inference.
    WHY: Smaller memory footprint (no PyTorch/TensorFlow).
    WHY: Cross-platform (Windows, Linux, Mac, mobile, edge).
    """

    def __init__(self, model_path: str):
        self.model_path = model_path
        self.session = None

    def load_model(self):
        """Load ONNX model with optimizations."""
        if self.session is None:
            # Set execution providers (GPU > CPU)
            # WHY: Tries GPU first, falls back to CPU
            providers = [
                'CUDAExecutionProvider',  # NVIDIA GPU
                'CPUExecutionProvider'     # CPU fallback
            ]

            # Session options for optimization
            sess_options = ort.SessionOptions()

            # WHY: Enable graph optimizations (fuse ops, constant folding)
            sess_options.graph_optimization_level = (
                ort.GraphOptimizationLevel.ORT_ENABLE_ALL
            )

            # WHY: Intra-op parallelism (parallel ops within graph)
            sess_options.intra_op_num_threads = 4

            # WHY: Inter-op parallelism (parallel independent subgraphs)
            sess_options.inter_op_num_threads = 2

            # WHY: Enable memory pattern optimization
            sess_options.enable_mem_pattern = True

            # WHY: Enable CPU memory arena (reduces allocation overhead)
            sess_options.enable_cpu_mem_arena = True

            self.session = ort.InferenceSession(
                self.model_path,
                sess_options=sess_options,
                providers=providers
            )

            # Get input/output metadata
            self.input_name = self.session.get_inputs()[0].name
            self.output_name = self.session.get_outputs()[0].name

            logger.info(f"ONNX model loaded: {self.model_path}")
            logger.info(f"Execution provider: {self.session.get_providers()[0]}")

    def predict(self, inputs: np.ndarray) -> np.ndarray:
        """
        Run ONNX inference.

        WHY: ONNX Runtime automatically optimizes:
        - Operator fusion (combine multiple ops)
        - Constant folding (compute constants at load time)
        - Memory reuse (reduce allocations)
        """
        self.load_model()

        # Run inference
        outputs = self.session.run(
            [self.output_name],
            {self.input_name: inputs.astype(np.float32)}
        )

        return outputs[0]

    def benchmark_vs_pytorch(self, num_iterations: int = 1000):
        """Compare ONNX vs PyTorch inference speed."""
        import time
        import torch

        dummy_input = np.random.randn(1, 3, 224, 224).astype(np.float32)

        # Warm up
        for _ in range(10):
            self.predict(dummy_input)

        # Benchmark ONNX
        start = time.time()
        for _ in range(num_iterations):
            self.predict(dummy_input)
        onnx_time = (time.time() - start) / num_iterations * 1000

        # Benchmark PyTorch
        pytorch_model = torch.load(self.model_path.replace('.onnx', '.pth'))
        pytorch_model.eval()

        dummy_tensor = torch.from_numpy(dummy_input)

        # Warm up
        with torch.no_grad():
            for _ in range(10):
                pytorch_model(dummy_tensor)

        # Benchmark
        start = time.time()
        with torch.no_grad():
            for _ in range(num_iterations):
                pytorch_model(dummy_tensor)
        pytorch_time = (time.time() - start) / num_iterations * 1000

        print(f"ONNX Runtime: {onnx_time:.2f}ms")
        print(f"PyTorch: {pytorch_time:.2f}ms")
        print(f"ONNX is {pytorch_time/onnx_time:.1f}× faster")

        # Typical results:
        # ONNX: 5-8ms
        # PyTorch: 12-20ms
        # ONNX is 2-3× faster

# Global server
onnx_server = ONNXModelServer("model.onnx")

@app.on_event("startup")
async def startup():
    onnx_server.load_model()

@app.post("/predict")
async def predict(request: PredictionRequest):
    """ONNX prediction endpoint."""
    inputs = np.array(request.inputs, dtype=np.float32)
    outputs = onnx_server.predict(inputs)

    return {
        "predictions": outputs.tolist()
    }

ONNX Runtime advantages:

| Feature | Benefit | Measurement |
|---|---|---|
| Speed | Optimized operators | 2-3× faster than native |
| Size | No framework dependency | 10-50 MB vs 500+ MB (PyTorch) |
| Portability | Framework-agnostic | PyTorch/TF/etc. → ONNX |
| Edge deployment | Lightweight runtime | Mobile, IoT, embedded |

Part 5: Request Batching Patterns

Core principle: Batch requests for GPU efficiency.

Why batching matters:

  • GPU utilization: 20% (single request) → 80% (batch of 32)
  • Throughput: 100 req/sec (unbatched) → 1000 req/sec (batched)
  • Cost: 10× reduction in GPU cost per request

Dynamic Batching (Adaptive)

# dynamic_batching.py
import asyncio
from asyncio import Queue, Lock
from typing import List, Tuple
import numpy as np
import time
import logging

logger = logging.getLogger(__name__)

class DynamicBatcher:
    """
    Dynamic batching with adaptive timeout.

    WHY: Static batching waits for full batch (high latency at low load).
    WHY: Dynamic batching adapts: full batch OR timeout (balanced).

    Key parameters:
    - max_batch_size: Maximum batch size (GPU memory limit)
    - max_wait_ms: Maximum wait time (latency target)

    Trade-off:
    - Larger batch → better GPU utilization, higher throughput
    - Shorter timeout → lower latency, worse GPU utilization
    """

    def __init__(
        self,
        model_server,
        max_batch_size: int = 32,
        max_wait_ms: int = 10
    ):
        self.model_server = model_server
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms

        self.request_queue: Queue = Queue()
        self.batch_lock = Lock()

        self.stats = {
            "total_requests": 0,
            "total_batches": 0,
            "avg_batch_size": 0,
            "gpu_utilization": 0
        }

    async def start(self):
        """Start batch processing loop."""
        asyncio.create_task(self._batch_loop())

    async def _batch_loop(self):
        """
        Main batching loop.

        Algorithm:
        1. Wait for first request
        2. Start timeout timer
        3. Collect requests until:
           - Batch full (max_batch_size reached)
           - OR timeout expired (max_wait_ms)
        4. Process batch
        5. Return results to waiting requests
        """
        while True:
            batch = []
            futures = []

            # Wait for first request (no timeout)
            request_data, future = await self.request_queue.get()
            batch.append(request_data)
            futures.append(future)

            # Start deadline timer
            deadline = asyncio.get_event_loop().time() + (self.max_wait_ms / 1000)

            # Collect additional requests until batch full or timeout
            while len(batch) < self.max_batch_size:
                remaining_time = max(0, deadline - asyncio.get_event_loop().time())

                try:
                    request_data, future = await asyncio.wait_for(
                        self.request_queue.get(),
                        timeout=remaining_time
                    )
                    batch.append(request_data)
                    futures.append(future)
                except asyncio.TimeoutError:
                    # Timeout: process what we have
                    break

            # Process batch
            await self._process_batch(batch, futures)

    async def _process_batch(
        self,
        batch: List[np.ndarray],
        futures: List[asyncio.Future]
    ):
        """Process batch and return results."""
        batch_size = len(batch)

        # Convert to batch array
        batch_array = np.array(batch)

        # Run inference
        start_time = time.time()
        predictions, probabilities = self.model_server.predict(batch_array)
        inference_time = (time.time() - start_time) * 1000

        # Update stats
        self.stats["total_requests"] += batch_size
        self.stats["total_batches"] += 1
        self.stats["avg_batch_size"] = (
            self.stats["total_requests"] / self.stats["total_batches"]
        )
        self.stats["gpu_utilization"] = (
            self.stats["avg_batch_size"] / self.max_batch_size * 100
        )

        logger.info(
            f"Processed batch: size={batch_size}, "
            f"inference_time={inference_time:.2f}ms, "
            f"avg_batch_size={self.stats['avg_batch_size']:.1f}, "
            f"gpu_util={self.stats['gpu_utilization']:.1f}%"
        )

        # Return results to waiting requests
        for i, future in enumerate(futures):
            if not future.done():
                future.set_result((predictions[i], probabilities[i]))

    async def predict(self, inputs: np.ndarray) -> Tuple[int, np.ndarray]:
        """
        Add request to batch queue.

        WHY: Transparent batching (caller doesn't see batching).
        WHY: Returns when batch processed (might wait for other requests).
        """
        future = asyncio.Future()
        await self.request_queue.put((inputs, future))

        # Wait for batch to be processed
        prediction, probability = await future
        return prediction, probability

    def get_stats(self):
        """Get batching statistics."""
        return self.stats

# Example usage with load simulation
async def simulate_load():
    """
    Simulate varying load to demonstrate dynamic batching.

    WHY: Shows how batcher adapts to load:
    - High load: Fills batches quickly (high GPU util)
    - Low load: Processes smaller batches (low latency)
    """
    from serve_fastapi import ModelServer

    model_server = ModelServer("model.pth")
    model_server.load_model()

    batcher = DynamicBatcher(
        model_server,
        max_batch_size=32,
        max_wait_ms=10
    )
    await batcher.start()

    # High load (32 concurrent requests)
    print("Simulating HIGH LOAD (32 concurrent)...")
    tasks = []
    for i in range(32):
        inputs = np.random.randn(10)
        task = asyncio.create_task(batcher.predict(inputs))
        tasks.append(task)

    results = await asyncio.gather(*tasks)
    print(f"High load results: {len(results)} predictions")
    print(f"Stats: {batcher.get_stats()}")
    # Expected: avg_batch_size ≈ 32, gpu_util ≈ 100%

    await asyncio.sleep(0.1)  # Reset

    # Low load (1 request at a time)
    print("\nSimulating LOW LOAD (1 at a time)...")
    for i in range(10):
        inputs = np.random.randn(10)
        result = await batcher.predict(inputs)
        await asyncio.sleep(0.02)  # 20ms between requests

    print(f"Stats: {batcher.get_stats()}")
    # Expected: avg_batch_size ≈ 1-2, gpu_util ≈ 5-10%
    # WHY: Timeout expires before batch fills (low latency maintained)

if __name__ == "__main__":
    asyncio.run(simulate_load())

Batching performance:

| Load | Batch Size | GPU Util | Latency | Throughput |
|---|---|---|---|---|
| High (100 req/sec) | 28-32 | 90% | 12 ms | 1000 req/sec |
| Medium (20 req/sec) | 8-12 | 35% | 11 ms | 200 req/sec |
| Low (5 req/sec) | 1-2 | 10% | 10 ms | 50 req/sec |

Key insight: Dynamic batching adapts to load while maintaining latency target.
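
Wiring the batcher behind an HTTP endpoint mirrors the AsyncBatchPredictor from Part 1. A minimal sketch reusing the classes defined above (module names follow the earlier examples):

# serve_with_batcher.py -- minimal sketch: expose DynamicBatcher through FastAPI
import numpy as np
from fastapi import FastAPI

from serve_fastapi import ModelServer, PredictionRequest
from dynamic_batching import DynamicBatcher

app = FastAPI()
model_server = ModelServer("model.pth")
batcher = DynamicBatcher(model_server, max_batch_size=32, max_wait_ms=10)

@app.on_event("startup")
async def startup():
    model_server.load_model()
    await batcher.start()

@app.post("/predict")
async def predict(request: PredictionRequest):
    # One sample per request here; concurrent callers are grouped into GPU batches
    prediction, probability = await batcher.predict(np.array(request.inputs[0]))
    return {"prediction": int(prediction), "probabilities": probability.tolist()}

@app.get("/batching_stats")
async def batching_stats():
    return batcher.get_stats()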


Part 6: Containerization

Why containerize: "Works on my machine" → "Works everywhere"

Benefits:

  • Reproducible builds (same dependencies, versions)
  • Isolated environment (no conflicts)
  • Portable deployment (dev, staging, prod identical)
  • Easy scaling (K8s, Docker Swarm)

Multi-Stage Docker Build

# Dockerfile
# WHY: Multi-stage build reduces image size by 50-80%
# WHY: Build stage has compilers, runtime stage only has runtime deps

# ==================== Stage 1: Build ====================
FROM python:3.11-slim as builder

# WHY: Install build dependencies (needed for compilation)
RUN apt-get update && apt-get install -y \
    gcc \
    g++ \
    && rm -rf /var/lib/apt/lists/*

# WHY: Create virtual environment in builder stage
RUN python -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"

# WHY: Copy only requirements first (layer caching)
# WHY: If requirements.txt unchanged, this layer is cached
COPY requirements.txt .

# WHY: Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# ==================== Stage 2: Runtime ====================
FROM python:3.11-slim

# WHY: Copy only virtual environment from builder (not build tools)
COPY --from=builder /opt/venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"

# WHY: Set working directory
WORKDIR /app

# WHY: Copy application code
COPY serve_fastapi.py .
COPY model.pth .

# WHY: curl is required by the HEALTHCHECK below (not included in python:3.11-slim)
RUN apt-get update && apt-get install -y --no-install-recommends curl \
    && rm -rf /var/lib/apt/lists/*

# WHY: Non-root user for security
RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app
USER appuser

# WHY: Expose port (documentation, not enforcement)
EXPOSE 8000

# WHY: Health check for container orchestration
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
  CMD curl -f http://localhost:8000/health || exit 1

# WHY: Run with uvicorn (production ASGI server)
CMD ["uvicorn", "serve_fastapi:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
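
The build stage copies a requirements.txt that is not shown. An illustrative sketch for the FastAPI server above (left unpinned here; pin exact versions in practice so builds stay reproducible):

# requirements.txt -- illustrative; pin versions (fastapi==..., torch==...) for reproducible builds
fastapi
uvicorn[standard]
pydantic
numpy
torch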

Docker Compose for Multi-Service

# docker-compose.yml
# WHY: Docker Compose for local development and testing
# WHY: Defines multiple services (API, model, monitoring)

version: '3.8'

services:
  # Model serving API
  model-api:
    build:
      context: .
      dockerfile: Dockerfile
    ports:
      - "8000:8000"
    environment:
      # WHY: Environment variables for configuration
      - MODEL_PATH=/app/model.pth
      - LOG_LEVEL=INFO
    volumes:
      # WHY: Mount model directory (for updates without rebuild)
      - ./models:/app/models:ro
    deploy:
      resources:
        # WHY: Limit resources to prevent resource exhaustion
        limits:
          cpus: '2'
          memory: 4G
        reservations:
          # WHY: Reserve GPU (requires nvidia-docker)
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 10s
      retries: 3

  # Redis for caching
  redis:
    image: redis:7-alpine
    ports:
      - "6379:6379"
    volumes:
      - redis-data:/data
    command: redis-server --appendonly yes

  # Prometheus for metrics
  prometheus:
    image: prom/prometheus:latest
    ports:
      - "9090:9090"
    volumes:
      - ./prometheus.yml:/etc/prometheus/prometheus.yml
      - prometheus-data:/prometheus
    command:
      - '--config.file=/etc/prometheus/prometheus.yml'

  # Grafana for visualization
  grafana:
    image: grafana/grafana:latest
    ports:
      - "3000:3000"
    environment:
      - GF_SECURITY_ADMIN_PASSWORD=admin
    volumes:
      - grafana-data:/var/lib/grafana

volumes:
  redis-data:
  prometheus-data:
  grafana-data:

Build and Deploy

# Build image
# WHY: Tag with version for rollback capability
docker build -t model-api:1.0.0 .

# Run container
docker run -d \
  --name model-api \
  -p 8000:8000 \
  --gpus all \
  model-api:1.0.0

# Check logs
docker logs -f model-api

# Test API
curl http://localhost:8000/health

# Start all services with docker-compose
docker-compose up -d

# Scale API service (multiple instances)
# WHY: Load balancer distributes traffic across instances
docker-compose up -d --scale model-api=3

# View logs
docker-compose logs -f model-api

# Stop all services
docker-compose down

Container image sizes:

| Stage | Size | Contents |
|---|---|---|
| Full build | 2.5 GB | Python + build tools + deps + model |
| Multi-stage | 800 MB | Python + runtime deps + model |
| Optimized | 400 MB | Minimal Python + deps + model |
| Savings | 84% | From 2.5 GB → 400 MB |
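
To see where the bytes go in your own image, standard Docker commands are enough:

# Inspect the final image size
docker images model-api:1.0.0

# Show per-layer sizes (large layers are the optimization targets)
docker history model-api:1.0.0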

Part 7: Framework Selection Guide

Decision Matrix

# framework_selector.py
from enum import Enum
from typing import List

class Requirement(Enum):
    FLEXIBILITY = "flexibility"          # Custom preprocessing, business logic
    BATTERIES_INCLUDED = "batteries"     # Minimal setup, built-in features
    LOW_LATENCY = "low_latency"         # < 10ms target
    CROSS_FRAMEWORK = "cross_framework"  # PyTorch + TensorFlow support
    EDGE_DEPLOYMENT = "edge"            # Mobile, IoT, embedded
    EASE_OF_DEBUG = "debug"             # Development experience
    HIGH_THROUGHPUT = "throughput"      # > 1000 req/sec

class Framework(Enum):
    FASTAPI = "fastapi"
    TORCHSERVE = "torchserve"
    GRPC = "grpc"
    ONNX = "onnx"

# Framework capabilities (0-5 scale)
FRAMEWORK_SCORES = {
    Framework.FASTAPI: {
        Requirement.FLEXIBILITY: 5,        # Full control
        Requirement.BATTERIES_INCLUDED: 2, # Manual implementation
        Requirement.LOW_LATENCY: 3,        # 10-20ms
        Requirement.CROSS_FRAMEWORK: 4,    # Any Python model
        Requirement.EDGE_DEPLOYMENT: 2,    # Heavyweight
        Requirement.EASE_OF_DEBUG: 5,      # Excellent debugging
        Requirement.HIGH_THROUGHPUT: 3     # 100-500 req/sec
    },
    Framework.TORCHSERVE: {
        Requirement.FLEXIBILITY: 3,        # Customizable via handlers
        Requirement.BATTERIES_INCLUDED: 5, # Everything built-in
        Requirement.LOW_LATENCY: 4,        # 5-15ms
        Requirement.CROSS_FRAMEWORK: 1,    # PyTorch only
        Requirement.EDGE_DEPLOYMENT: 2,    # Heavyweight
        Requirement.EASE_OF_DEBUG: 3,      # Learning curve
        Requirement.HIGH_THROUGHPUT: 5     # 1000+ req/sec with batching
    },
    Framework.GRPC: {
        Requirement.FLEXIBILITY: 4,        # Binary protocol, custom logic
        Requirement.BATTERIES_INCLUDED: 2, # Manual implementation
        Requirement.LOW_LATENCY: 5,        # 3-8ms
        Requirement.CROSS_FRAMEWORK: 4,    # Any model
        Requirement.EDGE_DEPLOYMENT: 3,    # Moderate size
        Requirement.EASE_OF_DEBUG: 2,      # Binary protocol harder
        Requirement.HIGH_THROUGHPUT: 5     # 1000+ req/sec
    },
    Framework.ONNX: {
        Requirement.FLEXIBILITY: 3,        # Limited to ONNX ops
        Requirement.BATTERIES_INCLUDED: 3, # Runtime provided
        Requirement.LOW_LATENCY: 5,        # 2-6ms (optimized)
        Requirement.CROSS_FRAMEWORK: 5,    # Any framework → ONNX
        Requirement.EDGE_DEPLOYMENT: 5,    # Lightweight runtime
        Requirement.EASE_OF_DEBUG: 2,      # Conversion can be tricky
        Requirement.HIGH_THROUGHPUT: 4     # 500-1000 req/sec
    }
}

def select_framework(
    requirements: List[Requirement],
    weights: List[float] = None
) -> Framework:
    """
    Select best framework based on requirements.

    Args:
        requirements: List of requirements
        weights: Importance weight for each requirement (0-1)

    Returns:
        Best framework
    """
    if weights is None:
        weights = [1.0] * len(requirements)

    scores = {}

    for framework in Framework:
        score = 0
        for req, weight in zip(requirements, weights):
            score += FRAMEWORK_SCORES[framework][req] * weight
        scores[framework] = score

    best_framework = max(scores, key=scores.get)

    print(f"\nFramework Selection:")
    print(f"Requirements: {[r.value for r in requirements]}")
    print(f"\nScores:")
    for framework, score in sorted(scores.items(), key=lambda x: x[1], reverse=True):
        print(f"  {framework.value}: {score:.1f}")

    return best_framework

# Example use cases
print("=" * 60)
print("Use Case 1: Prototyping with flexibility")
print("=" * 60)
selected = select_framework([
    Requirement.FLEXIBILITY,
    Requirement.EASE_OF_DEBUG
])
print(f"\nRecommendation: {selected.value}")
# Expected: FASTAPI

print("\n" + "=" * 60)
print("Use Case 2: Production PyTorch with minimal setup")
print("=" * 60)
selected = select_framework([
    Requirement.BATTERIES_INCLUDED,
    Requirement.HIGH_THROUGHPUT
])
print(f"\nRecommendation: {selected.value}")
# Expected: TORCHSERVE

print("\n" + "=" * 60)
print("Use Case 3: Low-latency microservice")
print("=" * 60)
selected = select_framework([
    Requirement.LOW_LATENCY,
    Requirement.HIGH_THROUGHPUT
])
print(f"\nRecommendation: {selected.value}")
# Expected: GRPC or ONNX

print("\n" + "=" * 60)
print("Use Case 4: Edge deployment (mobile/IoT)")
print("=" * 60)
selected = select_framework([
    Requirement.EDGE_DEPLOYMENT,
    Requirement.CROSS_FRAMEWORK,
    Requirement.LOW_LATENCY
])
print(f"\nRecommendation: {selected.value}")
# Expected: ONNX

print("\n" + "=" * 60)
print("Use Case 5: Multi-framework ML platform")
print("=" * 60)
selected = select_framework([
    Requirement.CROSS_FRAMEWORK,
    Requirement.HIGH_THROUGHPUT,
    Requirement.BATTERIES_INCLUDED
])
print(f"\nRecommendation: {selected.value}")
# Expected: ONNX or TORCHSERVE (depending on weights)

Quick Reference Guide

| Scenario | Framework | Why |
|---|---|---|
| Prototyping | FastAPI | Fast iteration, easy debugging |
| PyTorch production | TorchServe | Built-in batching, metrics, management |
| Internal microservices | gRPC | Lowest latency, high throughput |
| Multi-framework | ONNX Runtime | Framework-agnostic, optimized |
| Edge/mobile | ONNX Runtime | Lightweight, cross-platform |
| Custom preprocessing | FastAPI | Full flexibility |
| High-throughput batch | TorchServe + batching | Dynamic batching built-in |
| Real-time streaming | gRPC | Bidirectional streaming |

Summary

Model serving is pattern matching, not one-size-fits-all.

Core patterns:

  1. FastAPI: Flexibility, custom logic, easy debugging
  2. TorchServe: PyTorch batteries-included, built-in batching
  3. gRPC: Low latency (3-5ms), high throughput, microservices
  4. ONNX Runtime: Cross-framework, optimized, edge deployment
  5. Dynamic batching: Adaptive batch size, balances latency and throughput
  6. Containerization: Reproducible, portable, scalable

Selection checklist:

  • ✓ Identify primary requirement (flexibility, latency, throughput, etc.)
  • ✓ Match requirement to framework strengths
  • ✓ Consider deployment environment (cloud, edge, on-prem)
  • ✓ Evaluate trade-offs (development speed vs performance)
  • ✓ Implement batching if GPU-based (10× better utilization)
  • ✓ Containerize for reproducibility
  • ✓ Monitor metrics (latency, throughput, GPU util)
  • ✓ Iterate based on production data

Anti-patterns to avoid:

  • ✗ model.pkl in repo (dependency hell)
  • ✗ gRPC for simple REST use cases (over-engineering)
  • ✗ No batching with GPU (wasted 80% capacity)
  • ✗ Not containerized (deployment inconsistency)
  • ✗ Static batching (poor latency at low load)

Production-ready model serving requires matching infrastructure pattern to requirements.