---
name: rust-candle-whisper
description: Implement native Rust ML inference with the Candle framework. Use when building GPU-accelerated ML pipelines without Python dependencies.
---
# Native ML with Candle

Pure-Rust ML inference using the Candle framework for GPU-accelerated models, demonstrated with a Whisper speech-to-text pipeline.
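Candle exposes a small PyTorch-like tensor API. A minimal sketch (device selection plus a matrix multiply) to confirm the toolchain works:

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    // CPU here; swap in Device::new_cuda(0)? once the `cuda` feature is enabled.
    let device = Device::Cpu;
    let a = Tensor::randn(0f32, 1.0, (2, 3), &device)?;
    let b = Tensor::randn(0f32, 1.0, (3, 4), &device)?;
    let c = a.matmul(&b)?;
    println!("{:?}", c.shape()); // [2, 4]
    Ok(())
}
```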
## Setup

```toml
# Cargo.toml
[dependencies]
candle-core = "0.4"
candle-nn = "0.4"
candle-transformers = "0.4"
hf-hub = "0.3"
tokenizers = "0.15"
symphonia = { version = "0.5", features = ["all"] }
serde_json = "1"
tracing = "0.1"
parking_lot = "0.12"

[features]
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
```
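With the feature gate above, GPU support stays opt-in at build time: `cargo build --release --features cuda` compiles the CUDA kernels, while the default build runs entirely on CPU.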
## Model Structure

```rust
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::whisper::{self as m, Config};
use hf_hub::{Repo, RepoType};
use std::path::{Path, PathBuf};

pub struct WhisperModel {
    model: m::model::Whisper,    // encoder + decoder weights
    tokenizer: WhisperTokenizer, // crate-local wrapper around `tokenizers`
    mel_filters: Vec<f32>,       // precomputed mel filterbank coefficients
    device: Device,              // CUDA if available, otherwise CPU
}
```
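The `Result` alias and the `Error::Audio` variant used below are assumed to be crate-local. A minimal sketch with `thiserror` (an assumed extra dependency); adjust the variants to the exact error types your crate versions export:

```rust
use thiserror::Error; // assumed dependency: thiserror = "1"

#[derive(Debug, Error)]
pub enum Error {
    #[error("candle: {0}")]
    Candle(#[from] candle_core::Error),
    #[error("io: {0}")]
    Io(#[from] std::io::Error),
    #[error("json: {0}")]
    Json(#[from] serde_json::Error),
    #[error("hub: {0}")]
    Hub(#[from] hf_hub::api::sync::ApiError),
    #[error("codec: {0}")]
    Symphonia(#[from] symphonia::core::errors::Error),
    #[error("audio: {0}")]
    Audio(String),
}

pub type Result<T> = std::result::Result<T, Error>;
```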
## Device Initialization

```rust
impl WhisperModel {
    fn init_device() -> Result<Device> {
        // Try CUDA first
        #[cfg(feature = "cuda")]
        {
            if let Ok(device) = Device::new_cuda(0) {
                tracing::info!("Using CUDA device");
                return Ok(device);
            }
        }
        // Fall back to CPU
        tracing::info!("Using CPU device");
        Ok(Device::Cpu)
    }
}
```
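On macOS the same fallback chain can try Metal before the CPU. A sketch, assuming a `metal` cargo feature wired to the candle crates analogously to `cuda`:

```rust
// Inside init_device(), after the CUDA attempt. Assumes a `metal` feature
// mirroring the `cuda` feature above (candle-core/metal, candle-nn/metal, ...).
#[cfg(feature = "metal")]
{
    if let Ok(device) = Device::new_metal(0) {
        tracing::info!("Using Metal device");
        return Ok(device);
    }
}
```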
## Loading from HuggingFace Hub

```rust
impl WhisperModel {
    pub fn load(model_id: &str, cache_dir: Option<&Path>) -> Result<Self> {
        tracing::info!("Loading model: {}", model_id);

        // Point hf-hub at a custom cache directory
        let cache_path = cache_dir
            .map(|p| p.to_path_buf())
            .unwrap_or_else(|| PathBuf::from("models/hf"));
        std::env::set_var("HF_HOME", &cache_path);
        let api = hf_hub::api::sync::ApiBuilder::new()
            .with_cache_dir(cache_path)
            .build()?;
        let repo = api.repo(Repo::new(model_id.to_string(), RepoType::Model));

        // Download model files (cached after the first run)
        let config_path = repo.get("config.json")?;
        let tokenizer_path = repo.get("tokenizer.json")?;
        let weights_path = repo.get("model.safetensors")?;

        // Load configuration
        let config: Config = {
            let content = std::fs::read_to_string(&config_path)?;
            serde_json::from_str(&content)?
        };
        tracing::info!(
            "Config: {} encoder layers, {} decoder layers",
            config.encoder_layers,
            config.decoder_layers
        );

        // Initialize device
        let device = Self::init_device()?;

        // Memory-map the safetensors weights instead of copying them into RAM
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, &device)?
        };

        // Build model
        let model = m::model::Whisper::load(&vb, config)?;

        // Load tokenizer and mel filters
        let tokenizer = WhisperTokenizer::load(&tokenizer_path)?;
        let mel_filters = load_mel_filters()?;

        Ok(Self {
            model,
            tokenizer,
            mel_filters,
            device,
        })
    }
}
```
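`load_mel_filters` is referenced above but not shown. Whisper's mel filterbank is usually shipped as precomputed little-endian `f32` coefficients (candle's whisper example bundles them as `melfilters128.bytes`). A sketch, with the file path and the `byteorder` dependency as assumptions:

```rust
use byteorder::{ByteOrder, LittleEndian}; // assumed dependency: byteorder = "1"

fn load_mel_filters() -> Result<Vec<f32>> {
    // 128 mel bins x 201 FFT bins of f32 coefficients, little-endian.
    let bytes = std::fs::read("models/melfilters128.bytes")?; // path is an assumption
    let mut filters = vec![0f32; bytes.len() / 4];
    LittleEndian::read_f32_into(&bytes, &mut filters);
    Ok(filters)
}
```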
## Audio Loading with Symphonia

```rust
use symphonia::core::audio::SampleBuffer;
use symphonia::core::codecs::DecoderOptions;
use symphonia::core::formats::FormatOptions;
use symphonia::core::io::MediaSourceStream;
use symphonia::core::probe::Hint;

pub fn load_audio(path: &Path) -> Result<Vec<f32>> {
    let file = std::fs::File::open(path)?;
    let mss = MediaSourceStream::new(Box::new(file), Default::default());

    // Give the probe a file-extension hint when one is available
    let mut hint = Hint::new();
    if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
        hint.with_extension(ext);
    }
    let probed = symphonia::default::get_probe()
        .format(&hint, mss, &FormatOptions::default(), &Default::default())?;
    let mut format = probed.format;

    let track = format.default_track()
        .ok_or_else(|| Error::Audio("No audio track".into()))?;
    let mut decoder = symphonia::default::get_codecs()
        .make(&track.codec_params, &DecoderOptions::default())?;
    let track_id = track.id;

    let mut samples = Vec::new();
    loop {
        let packet = match format.next_packet() {
            Ok(p) => p,
            // EOF is the normal termination condition for next_packet
            Err(symphonia::core::errors::Error::IoError(ref e))
                if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
            Err(e) => return Err(e.into()),
        };
        if packet.track_id() != track_id {
            continue;
        }

        let decoded = decoder.decode(&packet)?;
        let spec = *decoded.spec();
        let mut sample_buf = SampleBuffer::<f32>::new(decoded.capacity() as u64, spec);
        sample_buf.copy_interleaved_ref(decoded);

        // Convert to mono if stereo by averaging channels
        let channel_samples = sample_buf.samples();
        if spec.channels.count() > 1 {
            let channels = spec.channels.count();
            for chunk in channel_samples.chunks(channels) {
                let avg: f32 = chunk.iter().sum::<f32>() / channels as f32;
                samples.push(avg);
            }
        } else {
            samples.extend_from_slice(channel_samples);
        }
    }
    Ok(samples)
}
```
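Symphonia decodes but does not resample, and Whisper expects 16 kHz mono. If `spec.rate` is not 16 000, resample before feature extraction. A minimal linear-interpolation sketch (a quality resampler such as the `rubato` crate is the better choice in production):

```rust
/// Naive linear-interpolation resampler; fine as a sketch, not production-grade.
pub fn resample(samples: &[f32], from_hz: u32, to_hz: u32) -> Vec<f32> {
    if from_hz == to_hz || samples.is_empty() {
        return samples.to_vec();
    }
    let ratio = from_hz as f64 / to_hz as f64;
    let out_len = (samples.len() as f64 / ratio) as usize;
    (0..out_len)
        .map(|i| {
            let pos = i as f64 * ratio;
            let idx = pos as usize;
            let frac = (pos - idx as f64) as f32;
            let a = samples[idx];
            let b = samples[(idx + 1).min(samples.len() - 1)];
            a * (1.0 - frac) + b * frac // interpolate between neighbors
        })
        .collect()
}
```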
## Mel Spectrogram Computation

```rust
const N_FFT: usize = 400;
const HOP_LENGTH: usize = 160;
const N_MELS: usize = 128; // large-v3 uses 128 mel bins; earlier Whisper models use 80

pub fn pcm_to_mel(samples: &[f32], filters: &[f32], device: &Device) -> Result<Tensor> {
    if samples.len() < N_FFT {
        return Err(Error::Audio("audio shorter than one FFT window".into()));
    }
    let n_frames = (samples.len() - N_FFT) / HOP_LENGTH + 1;

    // Pre-compute Hann window
    let hann_window: Vec<f32> = (0..N_FFT)
        .map(|i| 0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / N_FFT as f32).cos()))
        .collect();

    // Compute STFT power spectrum with a naive O(N^2) DFT
    // (swap in an FFT crate such as rustfft if this becomes a bottleneck)
    let fft_size = N_FFT / 2 + 1;
    let mut magnitudes = vec![0.0f32; n_frames * fft_size];
    for frame_idx in 0..n_frames {
        let start = frame_idx * HOP_LENGTH;
        // Apply window
        let windowed: Vec<f32> = samples[start..start + N_FFT]
            .iter()
            .zip(&hann_window)
            .map(|(s, w)| s * w)
            .collect();
        // DFT for power spectrum
        for k in 0..fft_size {
            let mut real = 0.0f32;
            let mut imag = 0.0f32;
            for (n, &sample) in windowed.iter().enumerate() {
                let angle = -2.0 * std::f32::consts::PI * k as f32 * n as f32 / N_FFT as f32;
                real += sample * angle.cos();
                imag += sample * angle.sin();
            }
            magnitudes[frame_idx * fft_size + k] = real * real + imag * imag;
        }
    }

    // Apply mel filterbank; store mel-major so the buffer layout matches the
    // (n_mels, n_frames) tensor shape below
    let mut mel_spec = vec![0.0f32; N_MELS * n_frames];
    for frame in 0..n_frames {
        for mel in 0..N_MELS {
            let mut sum = 0.0f32;
            for k in 0..fft_size {
                sum += filters[mel * fft_size + k] * magnitudes[frame * fft_size + k];
            }
            mel_spec[mel * n_frames + frame] = sum.max(1e-10);
        }
    }

    // Log scale and normalize, following Whisper's convention:
    // log10, clamp to (max - 8), then (x + 4) / 4
    let log_spec: Vec<f32> = mel_spec.iter().map(|&x| x.log10()).collect();
    let max_val = log_spec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let normalized: Vec<f32> = log_spec
        .iter()
        .map(|&x| (x.max(max_val - 8.0) + 4.0) / 4.0)
        .collect();

    // Create tensor: (1, n_mels, n_frames)
    Tensor::from_vec(normalized, (1, N_MELS, n_frames), device)
        .map_err(Into::into)
}
```
## Autoregressive Decoding

```rust
use candle_core::IndexOp;
use candle_nn::ops::softmax;

pub struct Decoder<'a> {
    model: &'a mut m::model::Whisper,
    tokenizer: &'a WhisperTokenizer,
    device: &'a Device,
    suppress_tokens: Vec<u32>,
}

impl<'a> Decoder<'a> {
    pub fn decode(&mut self, audio_features: &Tensor) -> Result<String> {
        // Initial tokens: <|startoftranscript|><|en|><|transcribe|><|notimestamps|>
        let mut tokens: Vec<u32> = vec![50258, 50259, 50359, 50363];
        let mut all_tokens = tokens.clone();

        // Autoregressive loop, capped at Whisper's 448-token context
        for step in 0..448 {
            let token_tensor = Tensor::new(tokens.as_slice(), self.device)?
                .unsqueeze(0)?;

            // Run the decoder; the flag flushes the KV cache on the first step only
            let ys = self.model.decoder
                .forward(&token_tensor, audio_features, step == 0)?;

            // The decoder returns hidden states; project the last position
            // onto the vocabulary to get logits
            let seq_len = ys.dim(1)?;
            let last_hidden = ys.i((.., seq_len - 1..seq_len, ..))?;
            let logits = self.model.decoder.final_linear(&last_hidden)?;
            let last_logits = logits.i((.., 0, ..))?;

            // Apply suppression and take the greedy (argmax) token;
            // softmax is kept so this is easy to extend to sampling
            let last_logits = self.apply_suppression(&last_logits)?;
            let probs = softmax(&last_logits, candle_core::D::Minus1)?;
            let next_token = probs
                .argmax(candle_core::D::Minus1)?
                .to_dtype(DType::U32)?
                .to_vec1::<u32>()?[0];

            // Check for end of transcript
            if next_token == 50257 { // <|endoftext|>
                break;
            }
            all_tokens.push(next_token);
            // With the KV cache warm, only the new token is fed on later steps
            tokens = vec![next_token];
        }

        // Decode tokens to text, dropping special tokens (IDs >= 50257)
        let text_tokens: Vec<u32> = all_tokens
            .iter()
            .filter(|&&t| t < 50257)
            .copied()
            .collect();
        self.tokenizer.decode(&text_tokens)
    }

    fn apply_suppression(&self, logits: &Tensor) -> Result<Tensor> {
        // Round-trip through host memory to mask the suppressed tokens
        let dims = logits.dims2()?;
        let mut flat: Vec<f32> = logits.to_vec2::<f32>()?.into_iter().flatten().collect();
        for &token in &self.suppress_tokens {
            flat[token as usize] = f32::NEG_INFINITY;
        }
        Tensor::from_vec(flat, dims, self.device).map_err(Into::into)
    }
}
```
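Wiring it together: the mel tensor runs through the encoder once, and the resulting audio features feed every decoder step. A sketch assuming candle's `AudioEncoder::forward` (whose second argument flushes the KV cache); `suppress_tokens` is left empty here, though Whisper normally suppresses a standard token set:

```rust
impl WhisperModel {
    pub fn transcribe_mel(&mut self, mel: &Tensor) -> Result<String> {
        // Encode once; `true` flushes the encoder cache for a fresh run.
        let audio_features = self.model.encoder.forward(mel, true)?;
        let mut decoder = Decoder {
            model: &mut self.model,
            tokenizer: &self.tokenizer,
            device: &self.device,
            suppress_tokens: vec![], // populate from the model config in practice
        };
        decoder.decode(&audio_features)
    }
}
```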
## Global Model Caching

```rust
use std::sync::OnceLock;
use parking_lot::Mutex;

static WHISPER_MODEL: OnceLock<Mutex<WhisperModel>> = OnceLock::new();

pub fn transcribe(audio_path: &Path) -> Result<String> {
    // Lazily load the model on first use, then reuse it for every call
    let model = WHISPER_MODEL.get_or_init(|| {
        tracing::info!("Loading Whisper model (first use)...");
        Mutex::new(WhisperModel::load_default().expect("Failed to load model"))
    });
    let mut model_guard = model.lock();
    model_guard.transcribe(audio_path)
}

pub fn preload_model() -> Result<()> {
    if WHISPER_MODEL.get().is_some() {
        return Ok(());
    }
    // Load outside get_or_init so failures surface as Err instead of a panic
    let model = WhisperModel::load_default()?;
    let _ = WHISPER_MODEL.get_or_init(|| Mutex::new(model));
    Ok(())
}
```
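Typical startup flow, with a hypothetical audio path:

```rust
fn main() -> Result<()> {
    // Pay the model-load cost at startup rather than on the first request.
    preload_model()?;
    let text = transcribe(Path::new("samples/hello.wav"))?; // path is hypothetical
    println!("{text}");
    Ok(())
}
```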
## VRAM Estimation

```rust
fn estimate_vram_gb(config: &Config) -> f32 {
    // Rough count: four projection matrices (q, k, v, out) per layer;
    // MLP blocks and biases are ignored, so treat this as a lower bound.
    let encoder_params = config.encoder_layers * config.d_model * config.d_model * 4;
    let decoder_params = config.decoder_layers * config.d_model * config.d_model * 4;
    let vocab_params = config.vocab_size * config.d_model;
    let total_params = encoder_params + decoder_params + vocab_params;
    // float32 = 4 bytes, plus 20% overhead for activations and buffers
    (total_params as f32 * 4.0 * 1.2) / (1024.0 * 1024.0 * 1024.0)
}
```
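The estimate can gate device selection before committing to a GPU load. A sketch with an assumed 8 GB budget:

```rust
// Hypothetical gate: fall back to CPU when the rough estimate exceeds the budget.
const VRAM_BUDGET_GB: f32 = 8.0;

fn pick_device(config: &Config) -> Result<Device> {
    if estimate_vram_gb(config) > VRAM_BUDGET_GB {
        tracing::warn!("Estimated VRAM exceeds budget; using CPU");
        return Ok(Device::Cpu);
    }
    WhisperModel::init_device()
}
```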
## Guidelines

- Use the `cuda` feature for GPU acceleration
- Memory-map weights with `from_mmaped_safetensors`
- Cache models globally with `OnceLock`
- Use Symphonia for pure-Rust audio decoding
- Pre-compute mel filterbank coefficients
- Implement token suppression for stable decoding
- Estimate VRAM before loading models
## Examples

See `hercules-local-algo/src/whisper/` for a complete Whisper implementation.