---
name: replicate-integration
description: Integrate Replicate API for AI model deployment. Use when generating images with Flux, SDXL, or custom LoRA models via Replicate.
license: MIT
---
# Replicate API Integration

## Purpose

Deploy and run AI models in production using Replicate's cloud platform. Specialized for image generation with Flux Dev, SDXL, custom LoRA fine-tuning, and prediction polling patterns.

## When to Use

- Generating images with Flux Dev, SDXL, or custom models
- Fine-tuning models with custom LoRA weights
- Running predictions with polling/webhook patterns
- Deploying custom models to production
- Managing long-running AI workloads
## Architecture Pattern

### Project Structure

```
backend/
├── services/
│   ├── replicate_service.py    # Main Replicate client
│   └── model_service.py        # Model-specific logic
├── models/
│   └── replicate_models.py     # Pydantic models
├── config/
│   └── replicate_config.py     # Configuration
└── utils/
    └── polling.py              # Polling utilities
```

### Installation

```bash
pip install replicate httpx python-dotenv pydantic
```
### Environment Setup

```bash
# .env
REPLICATE_API_TOKEN=r8_...
FRONTEND_URL=http://localhost:3000
DEFAULT_MODEL=black-forest-labs/flux-dev
LORA_MODEL_ID=your-username/your-model-id
```
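
The `replicate_config.py` module from the project structure above can load these values in one place. A minimal sketch using pydantic-settings (an additional dependency, not in the install line; field names mirror the `.env` keys):

```python
# config/replicate_config.py - a sketch; requires `pip install pydantic-settings`
from pydantic_settings import BaseSettings, SettingsConfigDict

class ReplicateSettings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env")

    replicate_api_token: str                              # REPLICATE_API_TOKEN
    frontend_url: str = "http://localhost:3000"           # FRONTEND_URL
    default_model: str = "black-forest-labs/flux-dev"     # DEFAULT_MODEL
    lora_model_id: str | None = None                      # LORA_MODEL_ID

settings = ReplicateSettings()
```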
## Quick Start

### Basic Image Generation

```python
import os

import replicate

client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))

# Simple prediction
output = client.run(
    "black-forest-labs/flux-dev",
    input={
        "prompt": "A serene mountain landscape at sunset",
        "num_outputs": 1,
        "aspect_ratio": "16:9"
    }
)
print(f"Generated image: {output[0]}")
```
### Async Pattern with Polling

```python
import asyncio
import os
from typing import List

import replicate

async def generate_images_async(
    prompt: str,
    num_images: int = 4
) -> List[str]:
    client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))

    # Start prediction (model= targets an official model by name)
    prediction = client.predictions.create(
        model="black-forest-labs/flux-dev",
        input={
            "prompt": prompt,
            "num_outputs": num_images,
            "guidance_scale": 3.5,
            "num_inference_steps": 28
        }
    )

    # Poll for completion
    while prediction.status not in ["succeeded", "failed", "canceled"]:
        await asyncio.sleep(1)
        prediction = client.predictions.get(prediction.id)

    if prediction.status == "succeeded":
        return prediction.output
    raise Exception(f"Prediction failed: {prediction.error}")
```
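
A usage sketch for the function above:

```python
# Example usage from a script
if __name__ == "__main__":
    urls = asyncio.run(generate_images_async("a watercolor fox", num_images=2))
    for url in urls:
        print(url)
```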
## Flux Dev Integration

### Complete Flux Dev Configuration

```python
import os
from typing import List, Literal

import replicate
from pydantic import BaseModel, Field

class FluxDevInput(BaseModel):
    prompt: str = Field(description="Text description of image to generate")
    aspect_ratio: Literal["1:1", "16:9", "21:9", "3:2", "2:3", "4:5", "5:4", "9:16", "9:21"] = "1:1"
    num_outputs: int = Field(default=1, ge=1, le=4)
    num_inference_steps: int = Field(default=28, ge=1, le=50)
    guidance_scale: float = Field(default=3.5, ge=0, le=10)
    output_format: Literal["webp", "jpg", "png"] = "webp"
    output_quality: int = Field(default=80, ge=0, le=100)
    seed: int | None = None
    disable_safety_checker: bool = False

async def generate_flux_images(input: FluxDevInput) -> List[str]:
    client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))
    output = await client.async_run(
        "black-forest-labs/flux-dev",
        input=input.model_dump(exclude_none=True)  # drop unset optionals like seed
    )
    return output
```
### Flux Dev Best Practices

```python
import asyncio  # continues the previous snippet

# Optimal settings for portraits
PORTRAIT_CONFIG = {
    "aspect_ratio": "2:3",
    "num_inference_steps": 30,
    "guidance_scale": 4.0,
    "output_format": "webp",
    "output_quality": 90
}

# Optimal settings for landscapes
LANDSCAPE_CONFIG = {
    "aspect_ratio": "16:9",
    "num_inference_steps": 28,
    "guidance_scale": 3.5,
    "output_format": "webp",
    "output_quality": 85
}

# Batch generation pattern
async def batch_generate(prompts: List[str]) -> List[List[str]]:
    tasks = [generate_flux_images(FluxDevInput(prompt=p)) for p in prompts]
    return await asyncio.gather(*tasks)
```
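
`asyncio.gather` fires all requests at once, which can trip rate limits on large batches. A variant that bounds concurrency with a semaphore (the limit of 5 is an arbitrary choice):

```python
# Bounded batch generation: at most max_concurrent predictions in flight
async def batch_generate_bounded(
    prompts: List[str], max_concurrent: int = 5
) -> List[List[str]]:
    semaphore = asyncio.Semaphore(max_concurrent)

    async def _one(prompt: str) -> List[str]:
        async with semaphore:
            return await generate_flux_images(FluxDevInput(prompt=prompt))

    return await asyncio.gather(*[_one(p) for p in prompts])
```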
## LoRA Fine-Tuning Integration

### Using Custom LoRA Models

```python
class LoRAInput(BaseModel):
    prompt: str
    lora_scale: float = Field(default=1.0, ge=0, le=1, description="LoRA influence strength")
    trigger_word: str | None = Field(default=None, description="Special token for identity")
    num_outputs: int = Field(default=1, ge=1, le=10)

async def generate_with_lora(
    model_id: str,  # e.g., "daniel-carreon/danielcarrong"
    input: LoRAInput
) -> List[str]:
    client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))

    # Ensure the trigger word is in the prompt
    prompt = input.prompt
    if input.trigger_word and input.trigger_word not in prompt:
        prompt = f"{input.trigger_word} {prompt}"

    output = await client.async_run(
        model_id,
        input={
            "prompt": prompt,
            "lora_scale": input.lora_scale,
            "num_outputs": input.num_outputs,
            "num_inference_steps": 28,
            "guidance_scale": 3.5
        }
    )
    return output
```
### Training Custom LoRA

```python
async def train_lora(
    images_zip_url: str,
    trigger_word: str,
    steps: int = 1000
) -> str:
    """Train a custom LoRA model on Replicate."""
    client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))

    # trainings.create requires a fully qualified version reference
    # ("owner/model:version-hash"); look up the current hash on the
    # trainer's Replicate page.
    training = client.trainings.create(
        version="ostris/flux-dev-lora-trainer:<version-hash>",
        input={
            "input_images": images_zip_url,
            "trigger_word": trigger_word,
            "steps": steps,
            "lora_rank": 16,
            "optimizer": "adamw8bit",
            "batch_size": 1,
            "learning_rate": 4e-4,
            "caption_prefix": f"a photo of {trigger_word}"
        },
        destination=f"{os.getenv('REPLICATE_USERNAME')}/my-lora-model"
    )

    # Wait for training completion
    while training.status not in ["succeeded", "failed", "canceled"]:
        await asyncio.sleep(30)
        training = client.trainings.get(training.id)

    if training.status == "succeeded":
        return training.output  # Trained model/weights info
    raise Exception(f"Training failed: {training.error}")
```
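
Once training succeeds, the destination model runs like any other. A usage sketch reusing `generate_with_lora` from above (the trigger word `TOK` is a placeholder for whatever token you trained with):

```python
# Usage sketch: run the freshly trained model (call from an async context)
async def demo_trained_model():
    destination = f"{os.getenv('REPLICATE_USERNAME')}/my-lora-model"
    return await generate_with_lora(
        destination,
        LoRAInput(prompt="portrait in a forest", trigger_word="TOK", lora_scale=0.9)
    )
```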
## Polling Patterns

### Robust Polling with Backoff

```python
import asyncio
import os
import re
from typing import Any, Callable

import replicate
from pydantic import BaseModel

class PollConfig(BaseModel):
    max_wait: int = 300            # 5 minutes
    poll_interval: float = 1.0     # Initial interval, in seconds
    backoff_factor: float = 1.5    # Multiplier applied after each poll

async def poll_prediction(
    prediction_id: str,
    on_progress: Callable[[float], None] | None = None,
    config: PollConfig | None = None
) -> Any:
    """Poll a Replicate prediction with exponential backoff."""
    config = config or PollConfig()
    client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))
    start_time = asyncio.get_event_loop().time()
    poll_interval = config.poll_interval

    while True:
        prediction = client.predictions.get(prediction_id)

        # Report progress
        if on_progress and prediction.logs:
            on_progress(extract_progress(prediction.logs))

        # Check status
        if prediction.status == "succeeded":
            return prediction.output
        elif prediction.status in ["failed", "canceled"]:
            raise Exception(f"Prediction {prediction.status}: {prediction.error}")

        # Timeout check
        elapsed = asyncio.get_event_loop().time() - start_time
        if elapsed > config.max_wait:
            raise TimeoutError(f"Prediction timeout after {elapsed:.0f}s")

        # Exponential backoff, capped at 5 seconds
        await asyncio.sleep(poll_interval)
        poll_interval = min(poll_interval * config.backoff_factor, 5.0)

def extract_progress(logs: str) -> float:
    """Extract progress from logs (0.0 to 1.0), e.g. "Progress: 50%"."""
    match = re.search(r"Progress: (\d+)%", logs or "")
    return float(match.group(1)) / 100 if match else 0.0
```
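
Putting the pieces together: create a prediction, then hand its ID to the poller (a usage sketch):

```python
async def generate_and_wait(prompt: str):
    client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))
    prediction = client.predictions.create(
        model="black-forest-labs/flux-dev",
        input={"prompt": prompt}
    )
    return await poll_prediction(
        prediction.id,
        on_progress=lambda p: print(f"{p:.0%} complete")
    )
```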
### Webhook Pattern (Production)

```python
import os

import replicate
from fastapi import FastAPI, Request

app = FastAPI()

@app.post("/webhooks/replicate")
async def handle_webhook(request: Request):
    """Handle a Replicate webhook callback."""
    payload = await request.json()
    prediction_id = payload["id"]
    status = payload["status"]

    if status == "succeeded":
        output = payload["output"]
        # Process the completed prediction (save_results is your own persistence helper)
        await save_results(prediction_id, output)
    elif status == "failed":
        error = payload["error"]
        # Handle the error (log_error is your own logging helper)
        await log_error(prediction_id, error)

    return {"status": "received"}

# Start a prediction with a webhook instead of polling
def create_prediction_with_webhook(prompt: str) -> str:
    client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))
    prediction = client.predictions.create(
        model="black-forest-labs/flux-dev",
        input={"prompt": prompt},
        webhook=f"{os.getenv('BACKEND_URL')}/webhooks/replicate",
        webhook_events_filter=["completed"]
    )
    return prediction.id
```
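
Replicate signs webhook requests with Svix-style headers (`webhook-id`, `webhook-timestamp`, `webhook-signature`), so handlers can reject forged calls. A minimal verification sketch, assuming the signing secret (fetched once from the `GET /v1/webhooks/default/secret` API endpoint) is stored in a `REPLICATE_WEBHOOK_SECRET` environment variable of the form `whsec_...`:

```python
import base64
import hashlib
import hmac
import os

from fastapi import HTTPException, Request

async def verify_replicate_webhook(request: Request) -> bytes:
    """Validate the webhook signature; return the raw body on success."""
    body = await request.body()
    webhook_id = request.headers.get("webhook-id", "")
    timestamp = request.headers.get("webhook-timestamp", "")
    signatures = request.headers.get("webhook-signature", "")

    # Secret looks like "whsec_<base64>"; keep only the encoded part
    secret = os.environ["REPLICATE_WEBHOOK_SECRET"].split("_", 1)[1]

    signed_content = f"{webhook_id}.{timestamp}.{body.decode()}"
    expected = base64.b64encode(
        hmac.new(base64.b64decode(secret), signed_content.encode(), hashlib.sha256).digest()
    ).decode()

    # Header holds space-separated "v1,<sig>" entries
    candidates = [s.split(",", 1)[-1] for s in signatures.split()]
    if not any(hmac.compare_digest(expected, c) for c in candidates):
        raise HTTPException(status_code=401, detail="Invalid webhook signature")
    return body
```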
## Error Handling

### Comprehensive Error Handling

```python
import asyncio
import os
from typing import Any

import replicate

class ReplicateError(Exception):
    """Base Replicate error"""

class RateLimitError(ReplicateError):
    """Rate limit exceeded"""

class ModelNotFoundError(ReplicateError):
    """Model not found"""

async def safe_replicate_call(
    model: str,
    input: dict,
    max_retries: int = 3
) -> Any:
    """Call Replicate with retry logic."""
    client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))

    for attempt in range(max_retries):
        try:
            return await client.async_run(model, input=input)
        except replicate.exceptions.ModelError as e:
            if "not found" in str(e).lower():
                raise ModelNotFoundError(f"Model {model} not found")
            raise
        except replicate.exceptions.ReplicateError as e:
            if "rate limit" in str(e).lower():
                if attempt < max_retries - 1:
                    await asyncio.sleep(2 ** attempt)  # Exponential backoff
                    continue
                raise RateLimitError("Rate limit exceeded")
            raise ReplicateError(f"Replicate error: {e}")
        except Exception:
            if attempt < max_retries - 1:
                await asyncio.sleep(1)
                continue
            raise
```
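
A quick usage sketch for the wrapper above:

```python
async def demo_safe_call():
    output = await safe_replicate_call(
        "black-forest-labs/flux-dev",
        {"prompt": "a lighthouse in a storm"}
    )
    print(output)

asyncio.run(demo_safe_call())
```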
## FastAPI Integration

### Complete Replicate Endpoint

```python
import os
from typing import List

import replicate
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()

class GenerateRequest(BaseModel):
    prompt: str
    num_images: int = 4
    use_lora: bool = False
    lora_model_id: str | None = None

class GenerateResponse(BaseModel):
    prediction_id: str
    status: str
    images: List[str] | None = None

@app.post("/generate", response_model=GenerateResponse)
async def generate_images(request: GenerateRequest):
    """Start an image generation prediction."""
    if request.use_lora and not request.lora_model_id:
        raise HTTPException(status_code=400, detail="lora_model_id is required when use_lora is true")
    try:
        client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))

        # Select model
        model = request.lora_model_id if request.use_lora else "black-forest-labs/flux-dev"

        # Create the prediction and return immediately; clients poll for results
        prediction = client.predictions.create(
            model=model,
            input={
                "prompt": request.prompt,
                "num_outputs": request.num_images
            }
        )
        return GenerateResponse(
            prediction_id=prediction.id,
            status=prediction.status
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/predictions/{prediction_id}")
async def get_prediction(prediction_id: str):
    """Check prediction status."""
    try:
        client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))
        prediction = client.predictions.get(prediction_id)
        return GenerateResponse(
            prediction_id=prediction.id,
            status=prediction.status,
            images=prediction.output if prediction.status == "succeeded" else None
        )
    except Exception:
        raise HTTPException(status_code=404, detail="Prediction not found")
```
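
From the consumer side, the two endpoints combine into a submit-then-poll loop. A client sketch using httpx (already in the install line); the base URL is an assumption:

```python
import asyncio

import httpx

async def client_side_generate(prompt: str, base_url: str = "http://localhost:8000"):
    async with httpx.AsyncClient(base_url=base_url) as http:
        resp = await http.post("/generate", json={"prompt": prompt})
        resp.raise_for_status()
        prediction_id = resp.json()["prediction_id"]

        # Poll the status endpoint until images arrive
        while True:
            status = (await http.get(f"/predictions/{prediction_id}")).json()
            if status["status"] == "succeeded":
                return status["images"]
            if status["status"] in ("failed", "canceled"):
                raise RuntimeError(f"Prediction {status['status']}")
            await asyncio.sleep(2)
```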
## Rate Limiting

### Rate Limit Management

```python
import asyncio
from collections import deque
from datetime import datetime, timedelta

class RateLimiter:
    def __init__(self, max_requests: int = 50, window: int = 60):
        self.max_requests = max_requests
        self.window = timedelta(seconds=window)
        self.requests: deque[datetime] = deque()

    async def acquire(self):
        """Wait if the rate limit has been reached."""
        now = datetime.now()

        # Drop requests that have left the window
        while self.requests and self.requests[0] < now - self.window:
            self.requests.popleft()

        # Check limit
        if len(self.requests) >= self.max_requests:
            wait_time = (self.requests[0] + self.window - now).total_seconds()
            if wait_time > 0:
                await asyncio.sleep(wait_time)

        self.requests.append(now)

# Usage
limiter = RateLimiter(max_requests=50, window=60)

async def generate_with_limit(prompt: str):
    await limiter.acquire()
    return await generate_flux_images(FluxDevInput(prompt=prompt))
```
## Best Practices

- Always use async/await for non-blocking I/O
- Implement polling with backoff to avoid rate limits
- Use webhooks in production for long-running tasks
- Cache prediction results to avoid redundant API calls (see the sketch after this list)
- Monitor costs - log prediction IDs and metrics
- Handle errors gracefully with retry logic
- Use typed inputs with Pydantic models
- Set timeouts for all predictions
- Validate outputs before returning to users
- Store metadata for debugging and analytics
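
A minimal in-process sketch of the caching bullet above; the helper names are illustrative, and caching only pays off for deterministic inputs (a fixed `seed`), since otherwise identical prompts intentionally produce different images:

```python
import hashlib
import json
from typing import Any

# Naive in-process cache; swap for Redis or similar in production
_prediction_cache: dict[str, Any] = {}

def cache_key(model: str, input: dict) -> str:
    """Stable hash of model + input, so identical requests hit the cache."""
    raw = json.dumps({"model": model, "input": input}, sort_keys=True)
    return hashlib.sha256(raw.encode()).hexdigest()

async def cached_generate(model: str, input: dict) -> Any:
    key = cache_key(model, input)
    if key in _prediction_cache:
        return _prediction_cache[key]
    output = await safe_replicate_call(model, input)
    _prediction_cache[key] = output
    return output
```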
## Common Pitfalls

❌ Don't: Poll too frequently (wastes API calls)
✅ Do: Use exponential backoff (1s → 1.5s → 2.25s → ...)

❌ Don't: Forget to handle prediction failures
✅ Do: Check status and handle errors

❌ Don't: Hardcode model versions
✅ Do: Use environment variables for flexibility

❌ Don't: Block on predictions in API endpoints
✅ Do: Return prediction ID immediately, poll separately
## Complete Example: Production Service

```python
import os
from typing import List

import replicate
from fastapi import FastAPI

# Reuses RateLimiter, poll_prediction, and GenerateRequest from earlier sections

app = FastAPI()

class ImageGenerationService:
    def __init__(self):
        self.client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))
        self.rate_limiter = RateLimiter(max_requests=50, window=60)

    async def generate(
        self,
        prompt: str,
        num_images: int = 4,
        model_id: str = "black-forest-labs/flux-dev"
    ) -> List[str]:
        """Generate images with rate limiting and error handling."""
        await self.rate_limiter.acquire()
        try:
            prediction = self.client.predictions.create(
                model=model_id,
                input={
                    "prompt": prompt,
                    "num_outputs": num_images,
                    "num_inference_steps": 28,
                    "guidance_scale": 3.5
                }
            )
            # Poll for completion
            output = await poll_prediction(
                prediction.id,
                on_progress=lambda p: print(f"Progress: {p*100:.0f}%")
            )
            return output
        except Exception as e:
            print(f"Generation failed: {e}")
            raise

# Global service instance
service = ImageGenerationService()

@app.post("/api/generate")
async def generate_endpoint(request: GenerateRequest):
    try:
        images = await service.generate(
            prompt=request.prompt,
            num_images=request.num_images
        )
        return {"status": "success", "images": images}
    except Exception as e:
        return {"status": "error", "message": str(e)}
```