| name | model-serving-patterns |
| description | Design production serving with FastAPI, TorchServe, gRPC, batching, and containerization. |
Model Serving Patterns Skill
When to Use This Skill
Use this skill when:
- Deploying ML models to production environments
- Building model serving APIs for real-time inference
- Optimizing model serving for throughput and latency
- Containerizing models for consistent deployment
- Implementing request batching for efficiency
- Choosing between serving frameworks and protocols
When NOT to use: Notebook prototyping, training jobs, or single-prediction scripts where serving infrastructure is premature.
Core Principle
Serving infrastructure is not one-size-fits-all. Pattern selection is context-dependent.
Without proper serving infrastructure:
- model.pkl in repo (manual dependency hell)
- Wrong protocol choice (gRPC for simple REST use cases)
- No batching (1 req/sec instead of 100 req/sec)
- Not containerized (works on my machine syndrome)
- Static batching when dynamic needed (underutilized GPU)
Formula: Right framework (FastAPI vs TorchServe vs gRPC vs ONNX) + Request batching (dynamic > static) + Containerization (Docker + model) + Clear selection criteria = Production-ready serving.
Serving Framework Decision Tree
┌────────────────────────────────────────┐
│    What's your primary requirement?    │
└──────────────┬─────────────────────────┘
               │
       ┌───────┴───────┐
       ▼               ▼
  Flexibility  Batteries Included
       │               │
       ▼               ▼
    FastAPI       TorchServe
   (Custom)        (PyTorch)
       │               │
       │       ┌───────┴───────┐
       │       ▼               ▼
       │  Low Latency   Cross-Framework
       │       │               │
       │       ▼               ▼
       │     gRPC        ONNX Runtime
       │       │               │
       └───────┴───────┬───────┘
                       │
                       ▼
           ┌────────────────────────┐
           │  Add Request Batching  │
           │    Dynamic > Static    │
           └───────────┬────────────┘
                       │
                       ▼
           ┌────────────────────────┐
           │   Containerize with    │
           │  Docker + Dependencies │
           └────────────────────────┘
Part 1: FastAPI for Custom Serving
When to use: Need flexibility, custom preprocessing, or non-standard workflows.
Advantages: Full control, easy debugging, Python ecosystem integration. Disadvantages: Manual optimization, no built-in model management.
Basic FastAPI Serving
# serve_fastapi.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import torch
import numpy as np
from typing import List, Optional
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(title="Model Serving API", version="1.0.0")
class PredictionRequest(BaseModel):
"""Request schema with validation."""
inputs: List[List[float]] = Field(..., description="Input features as 2D array")
return_probabilities: bool = Field(False, description="Return class probabilities")
class Config:
schema_extra = {
"example": {
"inputs": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
"return_probabilities": True
}
}
class PredictionResponse(BaseModel):
"""Response schema."""
predictions: List[int]
probabilities: Optional[List[List[float]]] = None
latency_ms: float
class ModelServer:
"""
Model server with lazy loading and caching.
WHY: Load model once at startup, reuse across requests.
WHY: Avoids 5-10 second model loading per request.
"""
def __init__(self, model_path: str):
self.model_path = model_path
self.model = None
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model(self):
"""Load model on first request (lazy loading)."""
if self.model is None:
logger.info(f"Loading model from {self.model_path}...")
self.model = torch.load(self.model_path, map_location=self.device)
self.model.eval() # WHY: Disable dropout, batchnorm for inference
logger.info("Model loaded successfully")
def predict(self, inputs: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""
Run inference.
Args:
inputs: Input array (batch_size, features)
Returns:
(predictions, probabilities)
"""
self.load_model()
# Convert to tensor
x = torch.tensor(inputs, dtype=torch.float32).to(self.device)
# WHY: torch.no_grad() disables gradient computation for inference
# WHY: Reduces memory usage by 50% and speeds up by 2×
with torch.no_grad():
logits = self.model(x)
probabilities = torch.softmax(logits, dim=1)
predictions = torch.argmax(probabilities, dim=1)
return predictions.cpu().numpy(), probabilities.cpu().numpy()
# Global model server instance
model_server = ModelServer(model_path="model.pth")
@app.on_event("startup")
async def startup_event():
"""Load model at startup for faster first request."""
model_server.load_model()
logger.info("Server startup complete")
@app.get("/health")
async def health_check():
"""Health check endpoint for load balancers."""
return {
"status": "healthy",
"model_loaded": model_server.model is not None,
"device": str(model_server.device)
}
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
"""
Prediction endpoint with validation and error handling.
WHY: Pydantic validates inputs automatically.
WHY: Returns 422 for invalid inputs, not 500.
"""
import time
start_time = time.time()
try:
inputs = np.array(request.inputs)
# Validate shape
if inputs.ndim != 2:
raise HTTPException(
status_code=400,
detail=f"Expected 2D array, got {inputs.ndim}D"
)
predictions, probabilities = model_server.predict(inputs)
latency_ms = (time.time() - start_time) * 1000
response = PredictionResponse(
predictions=predictions.tolist(),
probabilities=probabilities.tolist() if request.return_probabilities else None,
latency_ms=latency_ms
)
logger.info(f"Predicted {len(predictions)} samples in {latency_ms:.2f}ms")
return response
except HTTPException:
    # WHY: Re-raise explicit client errors (e.g. the 400 above) unchanged
    # WHY: instead of converting them into a 500
    raise
except Exception as e:
    logger.error(f"Prediction error: {e}")
    raise HTTPException(status_code=500, detail=str(e))
# Run with: uvicorn serve_fastapi:app --host 0.0.0.0 --port 8000 --workers 4
Performance characteristics:
| Metric | Value | Notes |
|---|---|---|
| Cold start | 5-10s | Model loading time |
| Warm latency | 10-50ms | Per request |
| Throughput | 100-500 req/sec | Single worker |
| Memory | 2-8GB | Model + runtime |
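A minimal client sketch for the endpoint above (assumes the server is running on localhost:8000 and a 3-feature model; adjust the payload to your model's input shape):
# client_fastapi.py
import requests

payload = {
    "inputs": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
    "return_probabilities": True,
}

resp = requests.post("http://localhost:8000/predict", json=payload, timeout=10)
resp.raise_for_status()  # Surface 4xx/5xx errors instead of parsing a bad body
result = resp.json()

print("Predictions:", result["predictions"])
print("Probabilities:", result["probabilities"])
print(f"Server-side latency: {result['latency_ms']:.2f}ms")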
Advanced: Async FastAPI with Background Tasks
# serve_fastapi_async.py
from fastapi import FastAPI
from asyncio import Queue, create_task, sleep
import asyncio
from typing import Dict, List
import uuid
import numpy as np

# WHY: Reuse the model server and request schema from serve_fastapi.py
from serve_fastapi import ModelServer, PredictionRequest, model_server
app = FastAPI()
class AsyncBatchPredictor:
"""
Async batch predictor with request queuing.
WHY: Collect multiple requests, predict as batch.
WHY: GPU utilization: 20% (1 req) → 80% (batch of 32).
"""
def __init__(self, model_server: ModelServer, batch_size: int = 32, wait_ms: int = 10):
self.model_server = model_server
self.batch_size = batch_size
self.wait_ms = wait_ms
self.queue: Queue = Queue()
self.pending_requests: Dict[str, asyncio.Future] = {}
async def start(self):
"""Start background batch processing loop."""
create_task(self._batch_processing_loop())
async def _batch_processing_loop(self):
"""
Continuously collect and process batches.
WHY: Wait for batch_size OR timeout, then process.
WHY: Balances throughput (large batch) and latency (timeout).
"""
while True:
batch_requests = []
batch_ids = []
# Collect batch
deadline = asyncio.get_event_loop().time() + (self.wait_ms / 1000)
while len(batch_requests) < self.batch_size:
timeout = max(0, deadline - asyncio.get_event_loop().time())
try:
request_id, inputs = await asyncio.wait_for(
self.queue.get(),
timeout=timeout
)
batch_requests.append(inputs)
batch_ids.append(request_id)
except asyncio.TimeoutError:
break # Timeout reached, process what we have
if not batch_requests:
await sleep(0.001) # Brief sleep before next iteration
continue
# Process batch
batch_array = np.array(batch_requests)
predictions, probabilities = self.model_server.predict(batch_array)
# Return results to waiting requests
for i, request_id in enumerate(batch_ids):
future = self.pending_requests.pop(request_id)
future.set_result((predictions[i], probabilities[i]))
async def predict_async(self, inputs: List[float]) -> tuple[int, np.ndarray]:
"""
Add request to queue and await result.
WHY: Returns immediately if batch ready, waits if not.
WHY: Client doesn't know about batching (transparent).
"""
request_id = str(uuid.uuid4())
future = asyncio.Future()
self.pending_requests[request_id] = future
await self.queue.put((request_id, inputs))
# Wait for batch processing to complete
prediction, probability = await future
return prediction, probability
# Global async predictor
async_predictor = None
@app.on_event("startup")
async def startup():
global async_predictor
model_server.load_model()
async_predictor = AsyncBatchPredictor(model_server, batch_size=32, wait_ms=10)
await async_predictor.start()
@app.post("/predict_async")
async def predict_async(request: PredictionRequest):
"""
Async prediction with automatic batching.
WHY: Up to 10× higher throughput than synchronous serving.
WHY: Similar latency, much better GPU utilization.
"""
# Single input for simplicity (extend for batch)
inputs = request.inputs[0]
prediction, probability = await async_predictor.predict_async(inputs)
return {
"prediction": int(prediction),
"probability": probability.tolist()
}
Performance improvement:
| Approach | Throughput | GPU Utilization | Latency P95 |
|---|---|---|---|
| Synchronous | 100 req/sec | 20% | 15ms |
| Async batching | 1000 req/sec | 80% | 25ms |
| Improvement | 10× | 4× | +10ms |
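A rough load-generation sketch for the /predict_async endpoint above (assumes it is served on localhost:8000; concurrent requests are needed so the batcher can actually form batches instead of processing one request at a time):
# load_test_async.py
import time
from concurrent.futures import ThreadPoolExecutor
import requests

URL = "http://localhost:8000/predict_async"
PAYLOAD = {"inputs": [[1.0, 2.0, 3.0]]}

def one_request(_):
    # Each call blocks until its batch has been processed server-side
    return requests.post(URL, json=PAYLOAD, timeout=10).json()

start = time.time()
with ThreadPoolExecutor(max_workers=64) as pool:
    results = list(pool.map(one_request, range(256)))
elapsed = time.time() - start

print(f"{len(results)} requests in {elapsed:.2f}s "
      f"({len(results) / elapsed:.0f} req/sec)")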
Part 2: TorchServe for PyTorch Models
When to use: PyTorch models, want batteries-included solution with monitoring, metrics, and model management.
Advantages: Built-in batching, model versioning, A/B testing, metrics. Disadvantages: PyTorch-only, less flexibility, steeper learning curve.
Creating a TorchServe Handler
# handler.py
import json
import torch
import torch.nn.functional as F
from ts.torch_handler.base_handler import BaseHandler
import logging
logger = logging.getLogger(__name__)
class CustomClassifierHandler(BaseHandler):
"""
Custom TorchServe handler with preprocessing and batching.
WHY: TorchServe provides: model versioning, A/B testing, metrics, monitoring.
WHY: Built-in dynamic batching (no custom code needed).
"""
def initialize(self, context):
"""
Initialize handler (called once at startup).
Args:
context: TorchServe context with model artifacts
"""
self.manifest = context.manifest
properties = context.system_properties
# Set device
self.device = torch.device(
"cuda:" + str(properties.get("gpu_id"))
if torch.cuda.is_available()
else "cpu"
)
# Load model
model_dir = properties.get("model_dir")
serialized_file = self.manifest["model"]["serializedFile"]
model_path = f"{model_dir}/{serialized_file}"
self.model = torch.jit.load(model_path, map_location=self.device)
self.model.eval()
logger.info(f"Model loaded successfully on {self.device}")
# WHY: Initialize preprocessing parameters
self.mean = torch.tensor([0.485, 0.456, 0.406]).to(self.device)
self.std = torch.tensor([0.229, 0.224, 0.225]).to(self.device)
self.initialized = True
def preprocess(self, data):
"""
Preprocess input data.
Args:
data: List of input requests
Returns:
Preprocessed tensor batch
WHY: TorchServe batches requests automatically.
WHY: This method receives multiple requests at once.
"""
inputs = []
for row in data:
# Get input from request (JSON or binary)
input_data = row.get("data") or row.get("body")
# Parse and convert
if isinstance(input_data, (bytes, bytearray)):
    input_data = input_data.decode("utf-8")
if isinstance(input_data, str):
    # WHY: json.loads is safe; eval() would execute arbitrary client input
    input_data = json.loads(input_data)
# Convert to tensor
tensor = torch.tensor(input_data, dtype=torch.float32)
# Normalize
tensor = (tensor - self.mean) / self.std
inputs.append(tensor)
# Stack into batch
batch = torch.stack(inputs).to(self.device)
return batch
def inference(self, batch):
"""
Run inference on batch.
Args:
batch: Preprocessed batch tensor
Returns:
Model output
WHY: torch.no_grad() for inference (faster, less memory).
"""
with torch.no_grad():
output = self.model(batch)
return output
def postprocess(self, inference_output):
"""
Postprocess inference output.
Args:
inference_output: Raw model output
Returns:
List of predictions (one per request in batch)
WHY: Convert tensors to JSON-serializable format.
WHY: Return predictions in same order as inputs.
"""
# Apply softmax
probabilities = F.softmax(inference_output, dim=1)
# Get predictions
predictions = torch.argmax(probabilities, dim=1)
# Convert to list (one entry per request)
results = []
for i in range(len(predictions)):
results.append({
"prediction": predictions[i].item(),
"probabilities": probabilities[i].tolist()
})
return results
TorchServe Configuration
# model_config.yaml
# WHY: Configuration controls batching, workers, timeouts
# WHY: Tune these for your latency/throughput requirements
minWorkers: 2 # WHY: Minimum workers (always ready)
maxWorkers: 4 # WHY: Maximum workers (scale up under load)
batchSize: 32 # WHY: Maximum batch size (GPU utilization)
maxBatchDelay: 10 # WHY: Max wait time for batch (ms)
# WHY: Trade-off: larger batch (better GPU util) vs latency
responseTimeout: 120 # WHY: Request timeout (seconds)
# WHY: Prevent hung requests
# Device assignment
deviceType: "gpu" # WHY: Use GPU if available
deviceIds: [0] # WHY: Specific GPU ID
# Metrics
metrics:
enable: true
prometheus: true # WHY: Export to Prometheus for monitoring
Packaging and Serving
# Package model for TorchServe
# WHY: .mar file contains model + handler + config (portable)
torch-model-archiver \
--model-name classifier \
--version 1.0 \
--serialized-file model.pt \
--handler handler.py \
--extra-files "model_config.yaml" \
--export-path model_store/
# Start TorchServe
# WHY: Serves on 8080 (inference), 8081 (management), 8082 (metrics)
torchserve \
--start \
--ncs \
--model-store model_store \
--models classifier.mar \
--ts-config config.properties
# Register model (if not auto-loaded)
curl -X POST "http://localhost:8081/models?url=classifier.mar&batch_size=32&max_batch_delay=10"
# Make prediction
curl -X POST "http://localhost:8080/predictions/classifier" \
-H "Content-Type: application/json" \
-d '{"data": [[1.0, 2.0, 3.0]]}'
# Get metrics (for monitoring)
curl http://localhost:8082/metrics
# Unregister model (for updates)
curl -X DELETE "http://localhost:8081/models/classifier"
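The same inference call from Python, mirroring the curl example (assumes the classifier model is registered and TorchServe is listening on port 8080):
# client_torchserve.py
import requests

resp = requests.post(
    "http://localhost:8080/predictions/classifier",
    json={"data": [[1.0, 2.0, 3.0]]},
    timeout=10,
)
resp.raise_for_status()
print(resp.json())  # e.g. {"prediction": ..., "probabilities": [...]}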
TorchServe advantages:
| Feature | Built-in? | Notes |
|---|---|---|
| Dynamic batching | ✓ | Automatic, configurable |
| Model versioning | ✓ | A/B testing support |
| Metrics/monitoring | ✓ | Prometheus integration |
| Multi-model serving | ✓ | Multiple models per server |
| GPU management | ✓ | Automatic device assignment |
| Custom preprocessing | ✓ | Via handler |
Part 3: gRPC for Low-Latency Serving
When to use: Low latency critical (< 10ms), internal services, microservices architecture.
Advantages: 3-5× faster than REST, binary protocol, streaming support. Disadvantages: More complex, requires proto definitions, harder debugging.
Protocol Definition
// model_service.proto
syntax = "proto3";
package modelserving;
// WHY: Define service contract in .proto file
// WHY: Code generation for multiple languages (Python, Go, Java, etc.)
service ModelService {
// Unary RPC (one request, one response)
rpc Predict (PredictRequest) returns (PredictResponse);
// Server streaming (one request, stream responses)
rpc PredictStream (PredictRequest) returns (stream PredictResponse);
// Bidirectional streaming (stream requests and responses)
rpc PredictBidi (stream PredictRequest) returns (stream PredictResponse);
}
message PredictRequest {
// WHY: Repeated = array/list
repeated float features = 1; // WHY: Input features
bool return_probabilities = 2;
}
message PredictResponse {
int32 prediction = 1;
repeated float probabilities = 2;
float latency_ms = 3;
}
// Health check service (for load balancers)
service Health {
rpc Check (HealthCheckRequest) returns (HealthCheckResponse);
}
message HealthCheckRequest {
string service = 1;
}
message HealthCheckResponse {
enum ServingStatus {
UNKNOWN = 0;
SERVING = 1;
NOT_SERVING = 2;
}
ServingStatus status = 1;
}
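One way to generate the model_service_pb2*.py stubs used below, sketched with grpcio-tools (assumes the package is installed and model_service.proto sits in the working directory):
# generate_stubs.py
from grpc_tools import protoc

protoc.main([
    "protoc",               # argv[0] placeholder
    "-I.",                  # proto include path
    "--python_out=.",       # generates model_service_pb2.py
    "--grpc_python_out=.",  # generates model_service_pb2_grpc.py
    "model_service.proto",
])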
gRPC Server Implementation
# serve_grpc.py
import grpc
from concurrent import futures
import time
import logging
import torch
import numpy as np
# Generated from proto file (run: python -m grpc_tools.protoc ...)
import model_service_pb2
import model_service_pb2_grpc
logger = logging.getLogger(__name__)
class ModelServicer(model_service_pb2_grpc.ModelServiceServicer):
"""
gRPC service implementation.
WHY: gRPC is 3-5× faster than REST (binary protocol, HTTP/2).
WHY: Use for low-latency internal services (< 10ms target).
"""
def __init__(self, model_path: str):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = torch.load(model_path, map_location=self.device)
self.model.eval()
logger.info(f"Model loaded on {self.device}")
def Predict(self, request, context):
"""
Unary RPC prediction.
WHY: Fastest for single predictions.
WHY: 3-5ms latency vs 10-15ms for REST.
"""
start_time = time.time()
try:
# Convert proto repeated field to numpy
features = np.array(request.features, dtype=np.float32)
# Reshape for model
x = torch.tensor(features).unsqueeze(0).to(self.device)
# Inference
with torch.no_grad():
logits = self.model(x)
probs = torch.softmax(logits, dim=1)
pred = torch.argmax(probs, dim=1)
latency_ms = (time.time() - start_time) * 1000
# Build response
response = model_service_pb2.PredictResponse(
prediction=int(pred.item()),
latency_ms=latency_ms
)
# WHY: Only include probabilities if requested (reduce bandwidth)
if request.return_probabilities:
response.probabilities.extend(probs[0].cpu().tolist())
return response
except Exception as e:
logger.error(f"Prediction error: {e}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(str(e))
return model_service_pb2.PredictResponse()
def PredictStream(self, request, context):
"""
Server streaming RPC.
WHY: Send multiple predictions over one connection.
WHY: Lower overhead for batch processing.
"""
# Stream multiple predictions (example: time series)
for i in range(10): # Simulate 10 predictions
response = self.Predict(request, context)
yield response
time.sleep(0.01) # Simulate processing delay
def PredictBidi(self, request_iterator, context):
"""
Bidirectional streaming RPC.
WHY: Real-time inference (send request, get response immediately).
WHY: Lowest latency for streaming use cases.
"""
for request in request_iterator:
response = self.Predict(request, context)
yield response
class HealthServicer(model_service_pb2_grpc.HealthServicer):
"""Health check service for load balancers."""
def Check(self, request, context):
# WHY: Load balancers need health checks to route traffic
return model_service_pb2.HealthCheckResponse(
status=model_service_pb2.HealthCheckResponse.SERVING
)
def serve():
"""
Start gRPC server.
WHY: ThreadPoolExecutor for concurrent request handling.
WHY: max_workers controls concurrency (tune based on CPU cores).
"""
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=10),
options=[
# WHY: These options optimize for low latency
('grpc.max_send_message_length', 10 * 1024 * 1024), # 10MB
('grpc.max_receive_message_length', 10 * 1024 * 1024),
('grpc.so_reuseport', 1), # WHY: Allows multiple servers on same port
('grpc.use_local_subchannel_pool', 1) # WHY: Better connection reuse
]
)
# Add services
model_service_pb2_grpc.add_ModelServiceServicer_to_server(
ModelServicer("model.pth"), server
)
model_service_pb2_grpc.add_HealthServicer_to_server(
HealthServicer(), server
)
# Bind to port
server.add_insecure_port('[::]:50051')
server.start()
logger.info("gRPC server started on port 50051")
server.wait_for_termination()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
serve()
gRPC Client
# client_grpc.py
import grpc
import model_service_pb2
import model_service_pb2_grpc
import time
def benchmark_grpc_vs_rest():
"""
Benchmark gRPC vs REST latency.
WHY: gRPC is faster, but how much faster?
"""
# gRPC client
channel = grpc.insecure_channel('localhost:50051')
stub = model_service_pb2_grpc.ModelServiceStub(channel)
# Warm up
request = model_service_pb2.PredictRequest(
features=[1.0, 2.0, 3.0],
return_probabilities=True
)
for _ in range(10):
stub.Predict(request)
# Benchmark
iterations = 1000
start = time.time()
for _ in range(iterations):
response = stub.Predict(request)
grpc_latency = ((time.time() - start) / iterations) * 1000
print(f"gRPC average latency: {grpc_latency:.2f}ms")
# Compare with REST (FastAPI)
import requests
rest_url = "http://localhost:8000/predict"
# Warm up
for _ in range(10):
requests.post(rest_url, json={"inputs": [[1.0, 2.0, 3.0]]})
# Benchmark
start = time.time()
for _ in range(iterations):
requests.post(rest_url, json={"inputs": [[1.0, 2.0, 3.0]]})
rest_latency = ((time.time() - start) / iterations) * 1000
print(f"REST average latency: {rest_latency:.2f}ms")
print(f"gRPC is {rest_latency/grpc_latency:.1f}× faster")
# Typical results:
# gRPC: 3-5ms
# REST: 10-15ms
# gRPC is 3-5× faster
if __name__ == "__main__":
benchmark_grpc_vs_rest()
gRPC vs REST comparison:
| Metric | gRPC | REST | Advantage |
|---|---|---|---|
| Latency | 3-5ms | 10-15ms | gRPC 3× faster |
| Throughput | 10k req/sec | 3k req/sec | gRPC 3× higher |
| Payload size | Binary (smaller) | JSON (larger) | gRPC 30-50% smaller |
| Debugging | Harder | Easier | REST |
| Browser support | No (requires proxy) | Yes | REST |
| Streaming | Native | Complex (SSE/WebSocket) | gRPC |
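A client-side sketch of the bidirectional streaming RPC, which is where gRPC's native streaming pays off (assumes the server above is running on localhost:50051 and the stubs have been generated):
# client_grpc_stream.py
import grpc
import model_service_pb2
import model_service_pb2_grpc

def request_generator():
    # Stream several requests over a single HTTP/2 connection
    for i in range(5):
        yield model_service_pb2.PredictRequest(
            features=[float(i), float(i + 1), float(i + 2)],
            return_probabilities=False,
        )

channel = grpc.insecure_channel("localhost:50051")
stub = model_service_pb2_grpc.ModelServiceStub(channel)

# Responses arrive as soon as each request has been processed
for response in stub.PredictBidi(request_generator()):
    print(f"prediction={response.prediction}, latency={response.latency_ms:.2f}ms")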
Part 4: ONNX Runtime for Cross-Framework Serving
When to use: Need cross-framework support (PyTorch, TensorFlow, etc.), want maximum performance, or deploying to edge devices.
Advantages: Framework-agnostic, highly optimized, smaller deployment size. Disadvantages: Not all models convert easily, limited debugging.
Converting PyTorch to ONNX
# convert_to_onnx.py
import torch
import torch.onnx
def convert_pytorch_to_onnx(model_path: str, output_path: str):
"""
Convert PyTorch model to ONNX format.
WHY: ONNX is framework-agnostic (portable).
WHY: ONNX Runtime is 2-3× faster than native PyTorch inference.
WHY: Smaller deployment size (no PyTorch dependency).
"""
# Load PyTorch model
model = torch.load(model_path)
model.eval()
# Create dummy input (for tracing)
dummy_input = torch.randn(1, 3, 224, 224) # Example: image
# Export to ONNX
torch.onnx.export(
model,
dummy_input,
output_path,
export_params=True, # WHY: Include model weights
opset_version=17, # WHY: Latest stable ONNX opset
do_constant_folding=True, # WHY: Optimize constants at export time
input_names=['input'],
output_names=['output'],
dynamic_axes={
'input': {0: 'batch_size'}, # WHY: Support variable batch size
'output': {0: 'batch_size'}
}
)
print(f"Model exported to {output_path}")
# Verify ONNX model
import onnx
onnx_model = onnx.load(output_path)
onnx.checker.check_model(onnx_model)
print("ONNX model validation successful")
# Example usage
convert_pytorch_to_onnx("model.pth", "model.onnx")
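Before serving the exported model, it is worth checking numerical parity between PyTorch and ONNX outputs. A quick sketch (assumes model.pth and model.onnx from the conversion above and a 1×3×224×224 input):
# verify_onnx.py
import numpy as np
import torch
import onnxruntime as ort

dummy = np.random.randn(1, 3, 224, 224).astype(np.float32)

# PyTorch output
model = torch.load("model.pth")
model.eval()
with torch.no_grad():
    torch_out = model(torch.from_numpy(dummy)).numpy()

# ONNX Runtime output ("input" matches input_names in the export call)
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
onnx_out = session.run(None, {"input": dummy})[0]

# WHY: Small numerical differences are expected (op fusion, fp rounding)
np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-5)
print("PyTorch and ONNX outputs match within tolerance")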
ONNX Runtime Serving
# serve_onnx.py
import onnxruntime as ort
import numpy as np
from fastapi import FastAPI
import logging

# WHY: Reuse the request schema from the FastAPI example above
from serve_fastapi import PredictionRequest
logger = logging.getLogger(__name__)
app = FastAPI()
class ONNXModelServer:
"""
ONNX Runtime server with optimizations.
WHY: ONNX Runtime is 2-3× faster than PyTorch inference.
WHY: Smaller memory footprint (no PyTorch/TensorFlow).
WHY: Cross-platform (Windows, Linux, Mac, mobile, edge).
"""
def __init__(self, model_path: str):
self.model_path = model_path
self.session = None
def load_model(self):
"""Load ONNX model with optimizations."""
if self.session is None:
# Set execution providers (GPU > CPU)
# WHY: Tries GPU first, falls back to CPU
providers = [
'CUDAExecutionProvider', # NVIDIA GPU
'CPUExecutionProvider' # CPU fallback
]
# Session options for optimization
sess_options = ort.SessionOptions()
# WHY: Enable graph optimizations (fuse ops, constant folding)
sess_options.graph_optimization_level = (
ort.GraphOptimizationLevel.ORT_ENABLE_ALL
)
# WHY: Intra-op parallelism (parallel ops within graph)
sess_options.intra_op_num_threads = 4
# WHY: Inter-op parallelism (parallel independent subgraphs)
sess_options.inter_op_num_threads = 2
# WHY: Enable memory pattern optimization
sess_options.enable_mem_pattern = True
# WHY: Enable CPU memory arena (reduces allocation overhead)
sess_options.enable_cpu_mem_arena = True
self.session = ort.InferenceSession(
self.model_path,
sess_options=sess_options,
providers=providers
)
# Get input/output metadata
self.input_name = self.session.get_inputs()[0].name
self.output_name = self.session.get_outputs()[0].name
logger.info(f"ONNX model loaded: {self.model_path}")
logger.info(f"Execution provider: {self.session.get_providers()[0]}")
def predict(self, inputs: np.ndarray) -> np.ndarray:
"""
Run ONNX inference.
WHY: ONNX Runtime automatically optimizes:
- Operator fusion (combine multiple ops)
- Constant folding (compute constants at load time)
- Memory reuse (reduce allocations)
"""
self.load_model()
# Run inference
outputs = self.session.run(
[self.output_name],
{self.input_name: inputs.astype(np.float32)}
)
return outputs[0]
def benchmark_vs_pytorch(self, num_iterations: int = 1000):
"""Compare ONNX vs PyTorch inference speed."""
import time
import torch
dummy_input = np.random.randn(1, 3, 224, 224).astype(np.float32)
# Warm up
for _ in range(10):
self.predict(dummy_input)
# Benchmark ONNX
start = time.time()
for _ in range(num_iterations):
self.predict(dummy_input)
onnx_time = (time.time() - start) / num_iterations * 1000
# Benchmark PyTorch
pytorch_model = torch.load(self.model_path.replace('.onnx', '.pth'))
pytorch_model.eval()
dummy_tensor = torch.from_numpy(dummy_input)
# Warm up
with torch.no_grad():
for _ in range(10):
pytorch_model(dummy_tensor)
# Benchmark
start = time.time()
with torch.no_grad():
for _ in range(num_iterations):
pytorch_model(dummy_tensor)
pytorch_time = (time.time() - start) / num_iterations * 1000
print(f"ONNX Runtime: {onnx_time:.2f}ms")
print(f"PyTorch: {pytorch_time:.2f}ms")
print(f"ONNX is {pytorch_time/onnx_time:.1f}× faster")
# Typical results:
# ONNX: 5-8ms
# PyTorch: 12-20ms
# ONNX is 2-3× faster
# Global server
onnx_server = ONNXModelServer("model.onnx")
@app.on_event("startup")
async def startup():
onnx_server.load_model()
@app.post("/predict")
async def predict(request: PredictionRequest):
"""ONNX prediction endpoint."""
inputs = np.array(request.inputs, dtype=np.float32)
outputs = onnx_server.predict(inputs)
return {
"predictions": outputs.tolist()
}
ONNX Runtime advantages:
| Feature | Benefit | Measurement |
|---|---|---|
| Speed | Optimized operators | 2-3× faster than native |
| Size | No framework dependency | 10-50MB vs 500MB+ (PyTorch) |
| Portability | Framework-agnostic | PyTorch/TF/etc → ONNX |
| Edge deployment | Lightweight runtime | Mobile, IoT, embedded |
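A small sanity-check sketch: confirm which execution provider is actually active and that the dynamic_axes export really accepts variable batch sizes (assumes model.onnx from the conversion step above):
# check_onnx_runtime.py
import numpy as np
import onnxruntime as ort

requested = ["CUDAExecutionProvider", "CPUExecutionProvider"]
available = ort.get_available_providers()
providers = [p for p in requested if p in available]
print("Available providers:", available)

session = ort.InferenceSession("model.onnx", providers=providers)
print("Active provider:", session.get_providers()[0])

input_name = session.get_inputs()[0].name
for batch_size in (1, 8, 32):
    x = np.random.randn(batch_size, 3, 224, 224).astype(np.float32)
    out = session.run(None, {input_name: x})[0]
    print(f"batch_size={batch_size} -> output shape {out.shape}")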
Part 5: Request Batching Patterns
Core principle: Batch requests for GPU efficiency.
Why batching matters:
- GPU utilization: 20% (single request) → 80% (batch of 32)
- Throughput: 100 req/sec (unbatched) → 1000 req/sec (batched)
- Cost: ~10× reduction in GPU cost per request (back-of-envelope math sketched below)
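A back-of-envelope sketch of where those numbers come from; the per-batch GPU times are illustrative assumptions, not measurements:
# batching_math.py
# Assumption: a batch of 32 takes ~3× longer than a batch of 1, not 32× longer,
# because a single request leaves most of the GPU idle.
ILLUSTRATIVE_BATCH_MS = {1: 10.0, 8: 15.0, 32: 30.0}

for batch_size, batch_ms in ILLUSTRATIVE_BATCH_MS.items():
    throughput = batch_size / (batch_ms / 1000)   # requests per second
    gpu_ms_per_req = batch_ms / batch_size        # GPU cost per request
    print(f"batch={batch_size:>2}: {throughput:>5.0f} req/sec, "
          f"{gpu_ms_per_req:.2f} GPU-ms/request")

# batch=1: ~100 req/sec; batch=32: ~1067 req/sec -- roughly the 10× throughput
# and 10× cost-per-request improvement quoted above.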
Dynamic Batching (Adaptive)
# dynamic_batching.py
import asyncio
from asyncio import Queue, Lock
from typing import List, Tuple
import numpy as np
import time
import logging
logger = logging.getLogger(__name__)
class DynamicBatcher:
"""
Dynamic batching with adaptive timeout.
WHY: Static batching waits for full batch (high latency at low load).
WHY: Dynamic batching adapts: full batch OR timeout (balanced).
Key parameters:
- max_batch_size: Maximum batch size (GPU memory limit)
- max_wait_ms: Maximum wait time (latency target)
Trade-off:
- Larger batch → better GPU utilization, higher throughput
- Shorter timeout → lower latency, worse GPU utilization
"""
def __init__(
self,
model_server,
max_batch_size: int = 32,
max_wait_ms: int = 10
):
self.model_server = model_server
self.max_batch_size = max_batch_size
self.max_wait_ms = max_wait_ms
self.request_queue: Queue = Queue()
self.batch_lock = Lock()
self.stats = {
"total_requests": 0,
"total_batches": 0,
"avg_batch_size": 0,
"gpu_utilization": 0
}
async def start(self):
"""Start batch processing loop."""
asyncio.create_task(self._batch_loop())
async def _batch_loop(self):
"""
Main batching loop.
Algorithm:
1. Wait for first request
2. Start timeout timer
3. Collect requests until:
- Batch full (max_batch_size reached)
- OR timeout expired (max_wait_ms)
4. Process batch
5. Return results to waiting requests
"""
while True:
batch = []
futures = []
# Wait for first request (no timeout)
request_data, future = await self.request_queue.get()
batch.append(request_data)
futures.append(future)
# Start deadline timer
deadline = asyncio.get_event_loop().time() + (self.max_wait_ms / 1000)
# Collect additional requests until batch full or timeout
while len(batch) < self.max_batch_size:
remaining_time = max(0, deadline - asyncio.get_event_loop().time())
try:
request_data, future = await asyncio.wait_for(
self.request_queue.get(),
timeout=remaining_time
)
batch.append(request_data)
futures.append(future)
except asyncio.TimeoutError:
# Timeout: process what we have
break
# Process batch
await self._process_batch(batch, futures)
async def _process_batch(
self,
batch: List[np.ndarray],
futures: List[asyncio.Future]
):
"""Process batch and return results."""
batch_size = len(batch)
# Convert to batch array
batch_array = np.array(batch)
# Run inference
start_time = time.time()
predictions, probabilities = self.model_server.predict(batch_array)
inference_time = (time.time() - start_time) * 1000
# Update stats
self.stats["total_requests"] += batch_size
self.stats["total_batches"] += 1
self.stats["avg_batch_size"] = (
self.stats["total_requests"] / self.stats["total_batches"]
)
self.stats["gpu_utilization"] = (
self.stats["avg_batch_size"] / self.max_batch_size * 100
)
logger.info(
f"Processed batch: size={batch_size}, "
f"inference_time={inference_time:.2f}ms, "
f"avg_batch_size={self.stats['avg_batch_size']:.1f}, "
f"gpu_util={self.stats['gpu_utilization']:.1f}%"
)
# Return results to waiting requests
for i, future in enumerate(futures):
if not future.done():
future.set_result((predictions[i], probabilities[i]))
async def predict(self, inputs: np.ndarray) -> Tuple[int, np.ndarray]:
"""
Add request to batch queue.
WHY: Transparent batching (caller doesn't see batching).
WHY: Returns when batch processed (might wait for other requests).
"""
future = asyncio.Future()
await self.request_queue.put((inputs, future))
# Wait for batch to be processed
prediction, probability = await future
return prediction, probability
def get_stats(self):
"""Get batching statistics."""
return self.stats
# Example usage with load simulation
async def simulate_load():
"""
Simulate varying load to demonstrate dynamic batching.
WHY: Shows how batcher adapts to load:
- High load: Fills batches quickly (high GPU util)
- Low load: Processes smaller batches (low latency)
"""
from serve_fastapi import ModelServer
model_server = ModelServer("model.pth")
model_server.load_model()
batcher = DynamicBatcher(
model_server,
max_batch_size=32,
max_wait_ms=10
)
await batcher.start()
# High load (32 concurrent requests)
print("Simulating HIGH LOAD (32 concurrent)...")
tasks = []
for i in range(32):
inputs = np.random.randn(10)
task = asyncio.create_task(batcher.predict(inputs))
tasks.append(task)
results = await asyncio.gather(*tasks)
print(f"High load results: {len(results)} predictions")
print(f"Stats: {batcher.get_stats()}")
# Expected: avg_batch_size ≈ 32, gpu_util ≈ 100%
await asyncio.sleep(0.1) # Reset
# Low load (1 request at a time)
print("\nSimulating LOW LOAD (1 at a time)...")
for i in range(10):
inputs = np.random.randn(10)
result = await batcher.predict(inputs)
await asyncio.sleep(0.02) # 20ms between requests
print(f"Stats: {batcher.get_stats()}")
# Expected: avg_batch_size ≈ 1-2, gpu_util ≈ 5-10%
# WHY: Timeout expires before batch fills (low latency maintained)
if __name__ == "__main__":
asyncio.run(simulate_load())
Batching performance:
| Load | Batch Size | GPU Util | Latency | Throughput |
|---|---|---|---|---|
| High (100 req/sec) | 28-32 | 90% | 12ms | 1000 req/sec |
| Medium (20 req/sec) | 8-12 | 35% | 11ms | 200 req/sec |
| Low (5 req/sec) | 1-2 | 10% | 10ms | 50 req/sec |
Key insight: Dynamic batching adapts to load while maintaining latency target.
Part 6: Containerization
Why containerize: "Works on my machine" → "Works everywhere"
Benefits:
- Reproducible builds (same dependencies, versions)
- Isolated environment (no conflicts)
- Portable deployment (dev, staging, prod identical)
- Easy scaling (K8s, Docker Swarm)
Multi-Stage Docker Build
# Dockerfile
# WHY: Multi-stage build reduces image size by 50-80%
# WHY: Build stage has compilers, runtime stage only has runtime deps
# ==================== Stage 1: Build ====================
FROM python:3.11-slim as builder
# WHY: Install build dependencies (needed for compilation)
RUN apt-get update && apt-get install -y \
gcc \
g++ \
&& rm -rf /var/lib/apt/lists/*
# WHY: Create virtual environment in builder stage
RUN python -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# WHY: Copy only requirements first (layer caching)
# WHY: If requirements.txt unchanged, this layer is cached
COPY requirements.txt .
# WHY: Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt
# ==================== Stage 2: Runtime ====================
FROM python:3.11-slim
# WHY: curl is required by the HEALTHCHECK below (not included in slim images)
RUN apt-get update && apt-get install -y --no-install-recommends curl \
    && rm -rf /var/lib/apt/lists/*
# WHY: Copy only virtual environment from builder (not build tools)
COPY --from=builder /opt/venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# WHY: Set working directory
WORKDIR /app
# WHY: Copy application code
COPY serve_fastapi.py .
COPY model.pth .
# WHY: Non-root user for security
RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app
USER appuser
# WHY: Expose port (documentation, not enforcement)
EXPOSE 8000
# WHY: Health check for container orchestration
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
# WHY: Run with uvicorn (production ASGI server)
CMD ["uvicorn", "serve_fastapi:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
Docker Compose for Multi-Service
# docker-compose.yml
# WHY: Docker Compose for local development and testing
# WHY: Defines multiple services (API, model, monitoring)
version: '3.8'
services:
# Model serving API
model-api:
build:
context: .
dockerfile: Dockerfile
ports:
- "8000:8000"
environment:
# WHY: Environment variables for configuration
- MODEL_PATH=/app/model.pth
- LOG_LEVEL=INFO
volumes:
# WHY: Mount model directory (for updates without rebuild)
- ./models:/app/models:ro
deploy:
resources:
# WHY: Limit resources to prevent resource exhaustion
limits:
cpus: '2'
memory: 4G
reservations:
# WHY: Reserve GPU (requires nvidia-docker)
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 30s
timeout: 10s
retries: 3
# Redis for caching
redis:
image: redis:7-alpine
ports:
- "6379:6379"
volumes:
- redis-data:/data
command: redis-server --appendonly yes
# Prometheus for metrics
prometheus:
image: prom/prometheus:latest
ports:
- "9090:9090"
volumes:
- ./prometheus.yml:/etc/prometheus/prometheus.yml
- prometheus-data:/prometheus
command:
- '--config.file=/etc/prometheus/prometheus.yml'
# Grafana for visualization
grafana:
image: grafana/grafana:latest
ports:
- "3000:3000"
environment:
- GF_SECURITY_ADMIN_PASSWORD=admin
volumes:
- grafana-data:/var/lib/grafana
volumes:
redis-data:
prometheus-data:
grafana-data:
Build and Deploy
# Build image
# WHY: Tag with version for rollback capability
docker build -t model-api:1.0.0 .
# Run container
docker run -d \
--name model-api \
-p 8000:8000 \
--gpus all \
model-api:1.0.0
# Check logs
docker logs -f model-api
# Test API
curl http://localhost:8000/health
# Start all services with docker-compose
docker-compose up -d
# Scale API service (multiple instances)
# WHY: Load balancer distributes traffic across instances
docker-compose up -d --scale model-api=3
# View logs
docker-compose logs -f model-api
# Stop all services
docker-compose down
Container image sizes:
| Stage | Size | Contents |
|---|---|---|
| Full build | 2.5 GB | Python + build tools + deps + model |
| Multi-stage | 800 MB | Python + runtime deps + model |
| Optimized | 400 MB | Minimal Python + deps + model |
| Savings | 84% | From 2.5 GB → 400 MB |
Part 7: Framework Selection Guide
Decision Matrix
# framework_selector.py
from enum import Enum
from typing import List
class Requirement(Enum):
FLEXIBILITY = "flexibility" # Custom preprocessing, business logic
BATTERIES_INCLUDED = "batteries" # Minimal setup, built-in features
LOW_LATENCY = "low_latency" # < 10ms target
CROSS_FRAMEWORK = "cross_framework" # PyTorch + TensorFlow support
EDGE_DEPLOYMENT = "edge" # Mobile, IoT, embedded
EASE_OF_DEBUG = "debug" # Development experience
HIGH_THROUGHPUT = "throughput" # > 1000 req/sec
class Framework(Enum):
FASTAPI = "fastapi"
TORCHSERVE = "torchserve"
GRPC = "grpc"
ONNX = "onnx"
# Framework capabilities (0-5 scale)
FRAMEWORK_SCORES = {
Framework.FASTAPI: {
Requirement.FLEXIBILITY: 5, # Full control
Requirement.BATTERIES_INCLUDED: 2, # Manual implementation
Requirement.LOW_LATENCY: 3, # 10-20ms
Requirement.CROSS_FRAMEWORK: 4, # Any Python model
Requirement.EDGE_DEPLOYMENT: 2, # Heavyweight
Requirement.EASE_OF_DEBUG: 5, # Excellent debugging
Requirement.HIGH_THROUGHPUT: 3 # 100-500 req/sec
},
Framework.TORCHSERVE: {
Requirement.FLEXIBILITY: 3, # Customizable via handlers
Requirement.BATTERIES_INCLUDED: 5, # Everything built-in
Requirement.LOW_LATENCY: 4, # 5-15ms
Requirement.CROSS_FRAMEWORK: 1, # PyTorch only
Requirement.EDGE_DEPLOYMENT: 2, # Heavyweight
Requirement.EASE_OF_DEBUG: 3, # Learning curve
Requirement.HIGH_THROUGHPUT: 5 # 1000+ req/sec with batching
},
Framework.GRPC: {
Requirement.FLEXIBILITY: 4, # Binary protocol, custom logic
Requirement.BATTERIES_INCLUDED: 2, # Manual implementation
Requirement.LOW_LATENCY: 5, # 3-8ms
Requirement.CROSS_FRAMEWORK: 4, # Any model
Requirement.EDGE_DEPLOYMENT: 3, # Moderate size
Requirement.EASE_OF_DEBUG: 2, # Binary protocol harder
Requirement.HIGH_THROUGHPUT: 5 # 1000+ req/sec
},
Framework.ONNX: {
Requirement.FLEXIBILITY: 3, # Limited to ONNX ops
Requirement.BATTERIES_INCLUDED: 3, # Runtime provided
Requirement.LOW_LATENCY: 5, # 2-6ms (optimized)
Requirement.CROSS_FRAMEWORK: 5, # Any framework → ONNX
Requirement.EDGE_DEPLOYMENT: 5, # Lightweight runtime
Requirement.EASE_OF_DEBUG: 2, # Conversion can be tricky
Requirement.HIGH_THROUGHPUT: 4 # 500-1000 req/sec
}
}
def select_framework(
requirements: List[Requirement],
weights: List[float] = None
) -> Framework:
"""
Select best framework based on requirements.
Args:
requirements: List of requirements
weights: Importance weight for each requirement (0-1)
Returns:
Best framework
"""
if weights is None:
weights = [1.0] * len(requirements)
scores = {}
for framework in Framework:
score = 0
for req, weight in zip(requirements, weights):
score += FRAMEWORK_SCORES[framework][req] * weight
scores[framework] = score
best_framework = max(scores, key=scores.get)
print(f"\nFramework Selection:")
print(f"Requirements: {[r.value for r in requirements]}")
print(f"\nScores:")
for framework, score in sorted(scores.items(), key=lambda x: x[1], reverse=True):
print(f" {framework.value}: {score:.1f}")
return best_framework
# Example use cases
print("=" * 60)
print("Use Case 1: Prototyping with flexibility")
print("=" * 60)
selected = select_framework([
Requirement.FLEXIBILITY,
Requirement.EASE_OF_DEBUG
])
print(f"\nRecommendation: {selected.value}")
# Expected: FASTAPI
print("\n" + "=" * 60)
print("Use Case 2: Production PyTorch with minimal setup")
print("=" * 60)
selected = select_framework([
Requirement.BATTERIES_INCLUDED,
Requirement.HIGH_THROUGHPUT
])
print(f"\nRecommendation: {selected.value}")
# Expected: TORCHSERVE
print("\n" + "=" * 60)
print("Use Case 3: Low-latency microservice")
print("=" * 60)
selected = select_framework([
Requirement.LOW_LATENCY,
Requirement.HIGH_THROUGHPUT
])
print(f"\nRecommendation: {selected.value}")
# Expected: GRPC or ONNX
print("\n" + "=" * 60)
print("Use Case 4: Edge deployment (mobile/IoT)")
print("=" * 60)
selected = select_framework([
Requirement.EDGE_DEPLOYMENT,
Requirement.CROSS_FRAMEWORK,
Requirement.LOW_LATENCY
])
print(f"\nRecommendation: {selected.value}")
# Expected: ONNX
print("\n" + "=" * 60)
print("Use Case 5: Multi-framework ML platform")
print("=" * 60)
selected = select_framework([
Requirement.CROSS_FRAMEWORK,
Requirement.HIGH_THROUGHPUT,
Requirement.BATTERIES_INCLUDED
])
print(f"\nRecommendation: {selected.value}")
# Expected: ONNX or TORCHSERVE (depending on weights)
Quick Reference Guide
| Scenario | Framework | Why |
|---|---|---|
| Prototyping | FastAPI | Fast iteration, easy debugging |
| PyTorch production | TorchServe | Built-in batching, metrics, management |
| Internal microservices | gRPC | Lowest latency, high throughput |
| Multi-framework | ONNX Runtime | Framework-agnostic, optimized |
| Edge/mobile | ONNX Runtime | Lightweight, cross-platform |
| Custom preprocessing | FastAPI | Full flexibility |
| High throughput batch | TorchServe + batching | Dynamic batching built-in |
| Real-time streaming | gRPC | Bidirectional streaming |
Summary
Model serving is pattern matching, not one-size-fits-all.
Core patterns:
- FastAPI: Flexibility, custom logic, easy debugging
- TorchServe: PyTorch batteries-included, built-in batching
- gRPC: Low latency (3-5ms), high throughput, microservices
- ONNX Runtime: Cross-framework, optimized, edge deployment
- Dynamic batching: Adaptive batch size, balances latency and throughput
- Containerization: Reproducible, portable, scalable
Selection checklist:
- ✓ Identify primary requirement (flexibility, latency, throughput, etc.)
- ✓ Match requirement to framework strengths
- ✓ Consider deployment environment (cloud, edge, on-prem)
- ✓ Evaluate trade-offs (development speed vs performance)
- ✓ Implement batching if GPU-based (up to 10× higher throughput)
- ✓ Containerize for reproducibility
- ✓ Monitor metrics (latency, throughput, GPU util)
- ✓ Iterate based on production data
Anti-patterns to avoid:
- ✗ model.pkl in repo (dependency hell)
- ✗ gRPC for simple REST use cases (over-engineering)
- ✗ No batching with GPU (wasted 80% capacity)
- ✗ Not containerized (deployment inconsistency)
- ✗ Static batching (poor latency at low load)
Production-ready model serving requires matching infrastructure pattern to requirements.