Claude Code Plugins

Community-maintained marketplace

Feedback

Implement native Rust ML inference with Candle framework. Use when building GPU-accelerated ML pipelines without Python dependencies.

Install Skill

1. Download the skill
2. Enable skills in Claude

Open claude.ai/settings/capabilities and find the "Skills" section

3. Upload to Claude

Click "Upload skill" and select the downloaded ZIP file

Note: Please verify the skill by reading through its instructions before using it.

SKILL.md

name: rust-candle-whisper
description: Implement native Rust ML inference with Candle framework. Use when building GPU-accelerated ML pipelines without Python dependencies.

Native ML with Candle

Pure Rust ML inference using the Candle framework for GPU-accelerated models.

Setup

# Cargo.toml
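[dependencies]
candle-core = "0.4"
candle-nn = "0.4"
candle-transformers = "0.4"
hf-hub = "0.3"
tokenizers = "0.15"
symphonia = { version = "0.5", features = ["all"] }
# Also used by the snippets below (config parsing, logging, global model lock)
serde_json = "1"
tracing = "0.1"
parking_lot = "0.12"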

[features]
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]

Model Structure

use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::whisper::{self as m, Config};
use hf_hub::{Repo, RepoType};
use std::path::Path;

/// A Whisper speech-recognition model loaded with Candle, bundled with the
/// tokenizer, mel filterbank and device it runs with.
pub struct WhisperModel {
    /// Candle Whisper network (built from the checkpoint's config and weights).
    model: m::model::Whisper,
    /// Tokenizer loaded from the checkpoint's `tokenizer.json`.
    tokenizer: WhisperTokenizer,
    /// Mel filterbank coefficients returned by `load_mel_filters`
    /// (flattened; layout assumed to be (n_mels, n_fft/2 + 1) — TODO confirm).
    mel_filters: Vec<f32>,
    /// Device the weights were loaded onto (CUDA when available, else CPU).
    device: Device,
}

Device Initialization

impl WhisperModel {
    /// Picks the compute device: CUDA device 0 when the `cuda` feature is enabled
    /// and the device can be opened, otherwise the CPU.
    ///
    /// A CUDA initialisation failure is not fatal. It is logged at `warn` level
    /// (instead of being silently discarded) and the CPU is used, so the reason
    /// for a fallback is visible in the logs.
    fn init_device() -> Result<Device> {
        #[cfg(feature = "cuda")]
        {
            match Device::new_cuda(0) {
                Ok(device) => {
                    tracing::info!("Using CUDA device");
                    return Ok(device);
                }
                Err(err) => {
                    tracing::warn!("CUDA device 0 unavailable, falling back to CPU: {}", err);
                }
            }
        }

        tracing::info!("Using CPU device");
        Ok(Device::Cpu)
    }
}

Loading from HuggingFace Hub

impl WhisperModel {
    /// Downloads (or reuses from the local cache) a Whisper checkpoint from the
    /// Hugging Face Hub and builds a ready-to-run model.
    ///
    /// * `model_id` - Hub repository id. The repo must contain `config.json`,
    ///   `tokenizer.json` and `model.safetensors`.
    /// * `cache_dir` - download cache directory. Defaults to `models/hf`,
    ///   relative to the current working directory.
    ///
    /// # Errors
    /// Fails if any of the following fail:
    /// * building the Hub client,
    /// * downloading or reading a file,
    /// * parsing the config as JSON,
    /// * loading the weights, tokenizer or mel filters.
    pub fn load(model_id: &str, cache_dir: Option<&Path>) -> Result<Self> {
        tracing::info!("Loading model: {}", model_id);

        // `PathBuf` is spelled out because only `Path` is imported at file level.
        let cache_path = cache_dir
            .map(Path::to_path_buf)
            .unwrap_or_else(|| std::path::PathBuf::from("models/hf"));

        // NOTE: `set_var` mutates process-global state and is not thread-safe (it is
        // an `unsafe fn` starting with Rust 2024). It is kept for code that reads
        // HF_HOME; the Hub client below is configured explicitly via `with_cache_dir`.
        std::env::set_var("HF_HOME", &cache_path);

        let api = hf_hub::api::sync::ApiBuilder::new()
            .with_cache_dir(cache_path)
            .build()?;
        let repo = api.repo(Repo::new(model_id.to_string(), RepoType::Model));

        // Download (or resolve cached) model files.
        let config_path = repo.get("config.json")?;
        let tokenizer_path = repo.get("tokenizer.json")?;
        let weights_path = repo.get("model.safetensors")?;

        let config: Config = {
            let content = std::fs::read_to_string(&config_path)?;
            serde_json::from_str(&content)?
        };

        tracing::info!(
            "Config: {} encoder layers, {} decoder layers",
            config.encoder_layers,
            config.decoder_layers
        );

        let device = Self::init_device()?;

        // SAFETY: the weights file is memory-mapped rather than copied, so it must not
        // be modified or truncated while the mapping is alive. We assume nothing
        // rewrites files in the Hub cache while the model is loaded.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, &device)?
        };

        let model = m::model::Whisper::load(&vb, config)?;

        let tokenizer = WhisperTokenizer::load(&tokenizer_path)?;
        // NOTE(review): the filterbank's mel count must match the checkpoint's
        // `num_mel_bins` (80 vs 128) — confirm `load_mel_filters` picks the right one.
        let mel_filters = load_mel_filters()?;

        Ok(Self {
            model,
            tokenizer,
            mel_filters,
            device,
        })
    }
}

Audio Loading with Symphonia

use symphonia::core::audio::SampleBuffer;
use symphonia::core::codecs::DecoderOptions;
use symphonia::core::formats::FormatOptions;
use symphonia::core::io::MediaSourceStream;
use symphonia::core::probe::Hint;

pub fn load_audio(path: &Path) -> Result<Vec<f32>> {
    let file = std::fs::File::open(path)?;
    let mss = MediaSourceStream::new(Box::new(file), Default::default());

    let mut hint = Hint::new();
    if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
        hint.with_extension(ext);
    }

    let probed = symphonia::default::get_probe()
        .format(&hint, mss, &FormatOptions::default(), &Default::default())?;

    let mut format = probed.format;
    let track = format.default_track()
        .ok_or_else(|| Error::Audio("No audio track".into()))?;

    let mut decoder = symphonia::default::get_codecs()
        .make(&track.codec_params, &DecoderOptions::default())?;

    let track_id = track.id;
    let mut samples = Vec::new();

    loop {
        let packet = match format.next_packet() {
            Ok(p) => p,
            Err(symphonia::core::errors::Error::IoError(ref e))
                if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
            Err(e) => return Err(e.into()),
        };

        if packet.track_id() != track_id {
            continue;
        }

        let decoded = decoder.decode(&packet)?;
        let spec = *decoded.spec();

        let mut sample_buf = SampleBuffer::<f32>::new(decoded.capacity() as u64, spec);
        sample_buf.copy_interleaved_ref(decoded);

        // Convert to mono if stereo
        let channel_samples = sample_buf.samples();
        if spec.channels.count() > 1 {
            let channels = spec.channels.count();
            for chunk in channel_samples.chunks(channels) {
                let avg: f32 = chunk.iter().sum::<f32>() / channels as f32;
                samples.push(avg);
            }
        } else {
            samples.extend_from_slice(channel_samples);
        }
    }

    Ok(samples)
}

Mel Spectrogram Computation

const N_FFT: usize = 400;
const HOP_LENGTH: usize = 160;
const N_MELS: usize = 128;

pub fn pcm_to_mel(samples: &[f32], filters: &[f32], device: &Device) -> Result<Tensor> {
    let n_frames = (samples.len() - N_FFT) / HOP_LENGTH + 1;

    // Pre-compute Hann window
    let hann_window: Vec<f32> = (0..N_FFT)
        .map(|i| 0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / N_FFT as f32).cos()))
        .collect();

    // Compute STFT magnitudes
    let fft_size = N_FFT / 2 + 1;
    let mut magnitudes = vec![0.0f32; n_frames * fft_size];

    for frame_idx in 0..n_frames {
        let start = frame_idx * HOP_LENGTH;

        // Apply window
        let windowed: Vec<f32> = samples[start..start + N_FFT]
            .iter()
            .zip(&hann_window)
            .map(|(s, w)| s * w)
            .collect();

        // DFT for power spectrum
        for k in 0..fft_size {
            let mut real = 0.0f32;
            let mut imag = 0.0f32;

            for (n, &sample) in windowed.iter().enumerate() {
                let angle = -2.0 * std::f32::consts::PI * k as f32 * n as f32 / N_FFT as f32;
                real += sample * angle.cos();
                imag += sample * angle.sin();
            }

            magnitudes[frame_idx * fft_size + k] = real * real + imag * imag;
        }
    }

    // Apply mel filterbank
    let mut mel_spec = vec![0.0f32; n_frames * N_MELS];

    for frame in 0..n_frames {
        for mel in 0..N_MELS {
            let mut sum = 0.0f32;
            for k in 0..fft_size {
                sum += filters[mel * fft_size + k] * magnitudes[frame * fft_size + k];
            }
            mel_spec[frame * N_MELS + mel] = sum.max(1e-10);
        }
    }

    // Log scale and normalize
    let log_spec: Vec<f32> = mel_spec.iter().map(|&x| x.ln().max(-10.0)).collect();
    let max_val = log_spec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let normalized: Vec<f32> = log_spec
        .iter()
        .map(|&x| ((x - max_val) / 4.0).clamp(-1.0, 1.0))
        .collect();

    // Create tensor: (1, n_mels, n_frames)
    Tensor::from_vec(normalized, (1, N_MELS, n_frames), device)
        .map_err(Into::into)
}

Autoregressive Decoding

use candle_nn::ops::softmax;

/// Greedy autoregressive text decoder for a Candle Whisper model.
pub struct Decoder<'a> {
    model: &'a mut m::model::Whisper,
    tokenizer: &'a WhisperTokenizer,
    device: &'a Device,
    /// Token ids whose logits are forced to `-inf` before picking the next token.
    suppress_tokens: Vec<u32>,
}

impl<'a> Decoder<'a> {
    /// Upper bound on the decoder sequence length (prompt + generated tokens).
    /// Presumably the checkpoint's `max_target_positions` (448) — confirm per model.
    const MAX_SEQ_LEN: usize = 448;

    /// Greedily decodes `audio_features` (the encoder output) into text.
    ///
    /// The whole token sequence is fed to the decoder on every step. Candle's
    /// Whisper decoder caches only the cross-attention keys/values, so
    /// self-attention and positional embeddings need the full prefix. The
    /// cross-attention cache is flushed on the first step only.
    ///
    /// Decoding stops at `<|endoftext|>` or when the sequence reaches
    /// `MAX_SEQ_LEN` tokens. Special tokens (ids >= 50257) are dropped before
    /// detokenizing.
    pub fn decode(&mut self, audio_features: &Tensor) -> Result<String> {
        // <|startoftranscript|><|en|><|transcribe|><|notimestamps|>
        // NOTE(review): these ids match the pre-v3 multilingual vocabulary. large-v3
        // (128 mel bins) shifted <|transcribe|>/<|notimestamps|> by one, so prefer
        // resolving them through the tokenizer — TODO confirm for the target model.
        const PROMPT: [u32; 4] = [50258, 50259, 50359, 50363];
        const END_OF_TEXT: u32 = 50257;

        let mut tokens: Vec<u32> = PROMPT.to_vec();

        while tokens.len() < Self::MAX_SEQ_LEN {
            let flush_kv_cache = tokens.len() == PROMPT.len();
            let token_tensor = Tensor::new(tokens.as_slice(), self.device)?.unsqueeze(0)?;

            // `forward` yields hidden states; `final_linear` projects them onto the vocabulary.
            let hidden = self
                .model
                .decoder
                .forward(&token_tensor, audio_features, flush_kv_cache)?;
            let logits = self.model.decoder.final_linear(&hidden)?;

            // Keep only the last position: (batch, seq, vocab) -> (batch, vocab).
            let seq_len = logits.dim(1)?;
            let last_logits = logits.narrow(1, seq_len - 1, 1)?.squeeze(1)?;

            let last_logits = self.apply_suppression(&last_logits)?;
            let probs = softmax(&last_logits, candle_core::D::Minus1)?;

            let next_token = probs
                .argmax(candle_core::D::Minus1)?
                .to_dtype(DType::U32)?
                .to_vec1::<u32>()?[0];

            if next_token == END_OF_TEXT {
                break;
            }
            tokens.push(next_token);
        }

        let text_tokens: Vec<u32> = tokens
            .iter()
            .copied()
            .filter(|&t| t < END_OF_TEXT)
            .collect();

        self.tokenizer.decode(&text_tokens)
    }

    /// Sets the logit of every id in `suppress_tokens` to `-inf` so it can never be
    /// picked.
    ///
    /// `logits` must be 2-D `(batch, vocab)`, and every row is suppressed. Ids
    /// outside the vocabulary are ignored instead of causing an index panic.
    fn apply_suppression(&self, logits: &Tensor) -> Result<Tensor> {
        let (rows, vocab) = logits.dims2()?;
        if vocab == 0 || self.suppress_tokens.is_empty() {
            return Ok(logits.clone());
        }

        let mut values = logits.flatten_all()?.to_vec1::<f32>()?;
        for row in values.chunks_exact_mut(vocab) {
            for &token in &self.suppress_tokens {
                if let Some(v) = row.get_mut(token as usize) {
                    *v = f32::NEG_INFINITY;
                }
            }
        }

        Tensor::from_vec(values, (rows, vocab), self.device).map_err(Into::into)
    }
}

Global Model Caching

use std::sync::OnceLock;
use parking_lot::Mutex;

static WHISPER_MODEL: OnceLock<Mutex<WhisperModel>> = OnceLock::new();

/// Returns the process-wide model, loading it on first use.
///
/// Load failures are returned to the caller instead of panicking, and a failed
/// load leaves the cell empty so a later call can retry. If several threads race
/// on first use, each may load a model; the first one stored wins and the others
/// are dropped. This can temporarily use extra memory.
fn shared_model() -> Result<&'static Mutex<WhisperModel>> {
    if let Some(model) = WHISPER_MODEL.get() {
        return Ok(model);
    }

    tracing::info!("Loading Whisper model (first use)...");
    let model = WhisperModel::load_default()?;
    Ok(WHISPER_MODEL.get_or_init(|| Mutex::new(model)))
}

/// Transcribes the audio file at `audio_path` with the shared model.
///
/// The model lock is held for the whole transcription, so concurrent calls are
/// serialized.
///
/// # Errors
/// Returns the load error on first use if the model cannot be loaded, or any
/// error from transcription.
pub fn transcribe(audio_path: &Path) -> Result<String> {
    let mut model_guard = shared_model()?.lock();
    model_guard.transcribe(audio_path)
}

/// Loads the shared model ahead of time. Does nothing if it is already loaded.
///
/// # Errors
/// Returns the load error if the model cannot be loaded.
pub fn preload_model() -> Result<()> {
    shared_model().map(|_| ())
}

VRAM Estimation

/// Rough estimate, in GiB, of the memory needed to hold a Whisper checkpoint's
/// weights as `f32`. The weights are loaded as `DType::F32`.
///
/// Per-layer parameter counts follow the standard Whisper transformer layout with
/// a 4x-wide MLP (d = `d_model`):
/// - encoder layer: self-attention (4·d²) + MLP (8·d²) = 12·d²
/// - decoder layer: self-attention (4·d²) + cross-attention (4·d²) + MLP (8·d²) = 16·d²
///
/// The token embedding (`vocab_size · d_model`) is added on top. Biases, layer
/// norms, the audio convolutions and positional embeddings are not counted
/// individually; the 20% overhead factor is meant to absorb them. Activations and
/// caches are not included.
fn estimate_vram_gb(config: &Config) -> f32 {
    let d_sq = config.d_model * config.d_model;
    let encoder_params = config.encoder_layers * 12 * d_sq;
    let decoder_params = config.decoder_layers * 16 * d_sq;
    let vocab_params = config.vocab_size * config.d_model;
    let total_params = encoder_params + decoder_params + vocab_params;

    // float32 = 4 bytes, plus 20% overhead
    (total_params as f32 * 4.0 * 1.2) / (1024.0 * 1024.0 * 1024.0)
}

Guidelines

  • Use cuda feature for GPU acceleration
  • Memory-map weights with from_mmaped_safetensors
  • Cache models globally with OnceLock
  • Use Symphonia for pure-Rust audio decoding
  • Pre-compute mel filterbank coefficients
  • Implement token suppression for stable decoding
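  • Estimate VRAM before loading models
  • Resample audio to 16 kHz mono before computing the mel spectrogram (Whisper's expected input rate)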

Examples

See hercules-local-algo/src/whisper/ for a complete Whisper implementation.