Claude Code Plugins

rust-gpu-scheduling

@gar-ai/mallorn

Implement GPU scheduling with VRAM budgets, work queues, and model lifecycle management. Use when building ML pipeline orchestrators.

Install Skill

1. Download skill

2. Enable skills in Claude: open claude.ai/settings/capabilities and find the "Skills" section

3. Upload to Claude: click "Upload skill" and select the downloaded ZIP file

Note: Please verify the skill by reviewing its instructions before using it.

SKILL.md

name: rust-gpu-scheduling
description: Implement GPU scheduling with VRAM budgets, work queues, and model lifecycle management. Use when building ML pipeline orchestrators.

GPU Scheduling

VRAM-aware GPU scheduling for ML model orchestration.

Model Types and VRAM Configuration

use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelType {
    Whisper = 0,    // 4.5 GB
    VideoMAE = 1,   // 5.0 GB
    CLAP = 2,       // 2.0 GB
    Qwen3 = 3,      // 1.5 GB
    DINOv3 = 4,     // 1.0 GB
}

#[derive(Debug, Clone)]
pub struct VramConfig {
    pub total_vram_gb: f64,      // e.g., 16.0
    pub safety_factor: f64,       // e.g., 0.9 (use 90%)
    pub whisper_vram_gb: f64,
    pub videomae_vram_gb: f64,
    pub clap_vram_gb: f64,
    pub qwen3_vram_gb: f64,
    pub dinov3_vram_gb: f64,
}

impl VramConfig {
    pub fn usable_vram_gb(&self) -> f64 {
        self.total_vram_gb * self.safety_factor
    }

    pub fn model_vram(&self, model: ModelType) -> f64 {
        match model {
            ModelType::Whisper => self.whisper_vram_gb,
            ModelType::VideoMAE => self.videomae_vram_gb,
            ModelType::CLAP => self.clap_vram_gb,
            ModelType::Qwen3 => self.qwen3_vram_gb,
            ModelType::DINOv3 => self.dinov3_vram_gb,
        }
    }

    pub fn can_fit(&self, model: ModelType) -> bool {
        self.model_vram(model) <= self.usable_vram_gb()
    }
}

impl Default for VramConfig {
    fn default() -> Self {
        Self {
            total_vram_gb: 16.0,
            safety_factor: 0.9,
            whisper_vram_gb: 4.5,
            videomae_vram_gb: 5.0,
            clap_vram_gb: 2.0,
            qwen3_vram_gb: 1.5,
            dinov3_vram_gb: 1.0,
        }
    }
}
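
A quick check of the budget math, written as a test sketch; it assumes the VramConfig and ModelType definitions above are in scope.

#[cfg(test)]
mod vram_budget_tests {
    use super::*;

    #[test]
    fn respects_safety_factor() {
        let config = VramConfig::default();
        // 16.0 GB total * 0.9 safety factor = 14.4 GB usable
        assert!(config.can_fit(ModelType::VideoMAE)); // 5.0 GB fits

        // A hypothetical 4 GB card has only 3.6 GB usable, so Whisper (4.5 GB) does not fit
        let small = VramConfig { total_vram_gb: 4.0, ..VramConfig::default() };
        assert!(!small.can_fit(ModelType::Whisper));
    }
}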

Work Item with Dependencies

use std::collections::HashSet;
use std::time::{SystemTime, UNIX_EPOCH};

#[derive(Debug, Clone)]
pub struct WorkItem {
    pub id: String,
    pub video_id: String,
    pub model_type: ModelType,
    pub input_data: String,
    pub priority: u32,
    pub dependencies: Vec<String>,
    pub submitted_at: f64,
    pub resolved_deps: HashSet<String>,
}

impl WorkItem {
    pub fn new(
        video_id: String,
        model_type: ModelType,
        input_data: String,
        dependencies: Option<Vec<String>>,
        priority: Option<u32>,
    ) -> Self {
        // Requires the `uuid` crate (with the `v4` feature) for unique work-item IDs
        let id = format!("{}_{:?}_{}", video_id, model_type, uuid::Uuid::new_v4());

        Self {
            id,
            video_id,
            model_type,
            input_data,
            priority: priority.unwrap_or(0),
            dependencies: dependencies.unwrap_or_default(),
            submitted_at: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_secs_f64())
                .unwrap_or(0.0),
            resolved_deps: HashSet::new(),
        }
    }

    pub fn is_ready(&self) -> bool {
        self.dependencies.iter()
            .all(|dep| self.resolved_deps.contains(dep))
    }

    pub fn resolve_dependency(&mut self, dep_id: &str) {
        self.resolved_deps.insert(dep_id.to_string());
    }

    pub fn waiting_time(&self) -> f64 {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_secs_f64())
            .unwrap_or(0.0);
        now - self.submitted_at
    }
}
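
A sketch of the dependency flow, assuming the WorkItem and ModelType definitions above; the video ID and input paths are illustrative.

#[cfg(test)]
mod work_item_tests {
    use super::*;

    #[test]
    fn becomes_ready_once_deps_resolve() {
        let upstream = WorkItem::new(
            "video_42".to_string(),
            ModelType::Whisper,
            "/data/video_42.wav".to_string(),
            None,
            Some(10),
        );

        // The downstream item waits on the upstream item's generated ID
        let mut downstream = WorkItem::new(
            "video_42".to_string(),
            ModelType::Qwen3,
            "transcript".to_string(),
            Some(vec![upstream.id.clone()]),
            None,
        );

        assert!(upstream.is_ready());    // no dependencies
        assert!(!downstream.is_ready()); // blocked on upstream

        downstream.resolve_dependency(&upstream.id);
        assert!(downstream.is_ready());
    }
}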

Work Queue with Priority

use std::collections::{HashMap, VecDeque};

pub struct WorkQueue {
    model_type: ModelType,
    ready_items: VecDeque<WorkItem>,
    waiting_items: Vec<WorkItem>,
}

impl WorkQueue {
    pub fn new(model_type: ModelType) -> Self {
        Self {
            model_type,
            ready_items: VecDeque::new(),
            waiting_items: Vec::new(),
        }
    }

    pub fn push(&mut self, item: WorkItem) {
        if item.is_ready() {
            // Insert by priority (higher priority first)
            let pos = self.ready_items
                .iter()
                .position(|i| i.priority < item.priority)
                .unwrap_or(self.ready_items.len());
            self.ready_items.insert(pos, item);
        } else {
            self.waiting_items.push(item);
        }
    }

    pub fn pop(&mut self) -> Option<WorkItem> {
        self.ready_items.pop_front()
    }

    pub fn resolve_dependency(&mut self, dep_id: &str) {
        let mut newly_ready = Vec::new();
        let mut i = 0;

        while i < self.waiting_items.len() {
            self.waiting_items[i].resolve_dependency(dep_id);

            if self.waiting_items[i].is_ready() {
                let item = self.waiting_items.swap_remove(i);
                newly_ready.push(item);
            } else {
                i += 1;
            }
        }

        for item in newly_ready {
            self.push(item);
        }
    }

    pub fn ready_count(&self) -> usize {
        self.ready_items.len()
    }

    pub fn waiting_count(&self) -> usize {
        self.waiting_items.len()
    }
}
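
A sketch showing priority ordering and dependency release within one queue, assuming the types above; the "dep-1" ID is illustrative.

#[cfg(test)]
mod work_queue_tests {
    use super::*;

    #[test]
    fn orders_by_priority_and_releases_waiting_items() {
        let mut queue = WorkQueue::new(ModelType::Whisper);

        let low = WorkItem::new("a".into(), ModelType::Whisper, "x".into(), None, Some(1));
        let high = WorkItem::new("b".into(), ModelType::Whisper, "y".into(), None, Some(5));
        let blocked = WorkItem::new(
            "c".into(),
            ModelType::Whisper,
            "z".into(),
            Some(vec!["dep-1".into()]),
            Some(9),
        );

        queue.push(low);
        queue.push(high);
        queue.push(blocked);

        assert_eq!(queue.ready_count(), 2);
        assert_eq!(queue.waiting_count(), 1);

        // Higher priority pops first
        assert_eq!(queue.pop().unwrap().priority, 5);

        // Resolving the dependency moves the blocked item into the ready queue
        queue.resolve_dependency("dep-1");
        assert_eq!(queue.waiting_count(), 0);
        assert_eq!(queue.pop().unwrap().priority, 9);
    }
}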

Model State Manager

use std::time::{Duration, Instant};

pub struct ModelStateManager {
    current_model: Option<ModelType>,
    loaded_at: Option<Instant>,
    load_counts: [u32; 5],
    last_load_times: [Option<Duration>; 5],
}

impl ModelStateManager {
    pub fn new() -> Self {
        Self {
            current_model: None,
            loaded_at: None,
            load_counts: [0; 5],
            last_load_times: [None; 5],
        }
    }

    pub fn mark_loaded(&mut self, model: ModelType, load_time: Duration) {
        self.current_model = Some(model);
        self.loaded_at = Some(Instant::now());
        self.load_counts[model as usize] += 1;
        self.last_load_times[model as usize] = Some(load_time);
    }

    pub fn mark_unloaded(&mut self) {
        self.current_model = None;
        self.loaded_at = None;
    }

    pub fn current_model(&self) -> Option<ModelType> {
        self.current_model
    }

    pub fn estimated_load_time(&self, model: ModelType) -> Duration {
        self.last_load_times[model as usize].unwrap_or_else(|| {
            // Default estimates
            match model {
                ModelType::Whisper => Duration::from_secs(5),
                ModelType::VideoMAE => Duration::from_secs(8),
                ModelType::CLAP => Duration::from_secs(4),
                ModelType::Qwen3 => Duration::from_secs(10),
                ModelType::DINOv3 => Duration::from_secs(6),
            }
        })
    }

    pub fn needs_switch(&self, target: ModelType) -> bool {
        self.current_model != Some(target)
    }
}
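
A small sketch of how the state manager tracks the loaded model and load-time estimates; it assumes the types above.

#[cfg(test)]
mod model_state_tests {
    use super::*;

    #[test]
    fn tracks_current_model_and_load_times() {
        let mut state = ModelStateManager::new();

        // Nothing loaded yet, so any target requires a switch
        assert!(state.needs_switch(ModelType::Whisper));

        state.mark_loaded(ModelType::Whisper, Duration::from_secs(3));
        assert_eq!(state.current_model(), Some(ModelType::Whisper));
        assert!(!state.needs_switch(ModelType::Whisper));
        assert!(state.needs_switch(ModelType::CLAP));

        // An observed load time replaces the default estimate...
        assert_eq!(state.estimated_load_time(ModelType::Whisper), Duration::from_secs(3));
        // ...while unseen models fall back to the hard-coded defaults
        assert_eq!(state.estimated_load_time(ModelType::CLAP), Duration::from_secs(4));

        state.mark_unloaded();
        assert_eq!(state.current_model(), None);
    }
}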

GPU Scheduler

use parking_lot::RwLock;
use std::sync::Arc;

pub struct GPUScheduler {
    queues: Arc<RwLock<HashMap<ModelType, WorkQueue>>>,
    state: Arc<RwLock<ModelStateManager>>,
    vram_config: VramConfig,
}

impl GPUScheduler {
    pub fn new(vram_config: VramConfig) -> Self {
        let mut queues = HashMap::new();
        for model in [
            ModelType::Whisper,
            ModelType::VideoMAE,
            ModelType::CLAP,
            ModelType::Qwen3,
            ModelType::DINOv3,
        ] {
            queues.insert(model, WorkQueue::new(model));
        }

        Self {
            queues: Arc::new(RwLock::new(queues)),
            state: Arc::new(RwLock::new(ModelStateManager::new())),
            vram_config,
        }
    }

    pub fn submit(&self, item: WorkItem) -> String {
        let id = item.id.clone();
        let model = item.model_type;

        let mut queues = self.queues.write();
        if let Some(queue) = queues.get_mut(&model) {
            queue.push(item);
        }

        id
    }

    pub fn resolve_dependency(&self, dep_id: &str) {
        let mut queues = self.queues.write();
        for queue in queues.values_mut() {
            queue.resolve_dependency(dep_id);
        }
    }

    pub fn get_next_batch(&self, max_batch_size: usize) -> Option<(ModelType, Vec<WorkItem>)> {
        let state = self.state.read();
        let mut queues = self.queues.write();

        // Prefer current model to minimize switches
        if let Some(current) = state.current_model() {
            if let Some(queue) = queues.get_mut(&current) {
                if queue.ready_count() > 0 {
                    let mut batch = Vec::new();
                    while batch.len() < max_batch_size {
                        if let Some(item) = queue.pop() {
                            batch.push(item);
                        } else {
                            break;
                        }
                    }
                    if !batch.is_empty() {
                        return Some((current, batch));
                    }
                }
            }
        }

        // Find queue with most ready items
        let best_model = queues
            .iter()
            .filter(|(_, q)| q.ready_count() > 0)
            .max_by_key(|(_, q)| q.ready_count())
            .map(|(m, _)| *m);

        if let Some(model) = best_model {
            let queue = queues.get_mut(&model).unwrap();
            let mut batch = Vec::new();
            while batch.len() < max_batch_size {
                if let Some(item) = queue.pop() {
                    batch.push(item);
                } else {
                    break;
                }
            }
            if !batch.is_empty() {
                return Some((model, batch));
            }
        }

        None
    }

    pub fn total_pending(&self) -> usize {
        self.queues.read()
            .values()
            .map(|q| q.ready_count() + q.waiting_count())
            .sum()
    }
}
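
An end-to-end sketch of a drain loop that wires the pieces together. The process_batch function is a hypothetical stand-in for the real inference call; everything else uses the types defined above.

fn run_pipeline() {
    let scheduler = GPUScheduler::new(VramConfig::default());

    // Transcription must finish before summarization can start
    let transcribe = WorkItem::new(
        "video_7".into(), ModelType::Whisper, "audio".into(), None, Some(1),
    );
    let transcribe_id = scheduler.submit(transcribe);

    let summarize = WorkItem::new(
        "video_7".into(), ModelType::Qwen3, "transcript".into(),
        Some(vec![transcribe_id]), None,
    );
    scheduler.submit(summarize);

    // Hypothetical inference step
    fn process_batch(_model: ModelType, batch: &[WorkItem]) {
        println!("processing {} items", batch.len());
    }

    while scheduler.total_pending() > 0 {
        match scheduler.get_next_batch(8) {
            Some((model, batch)) => {
                process_batch(model, &batch);
                // Completed items may unblock dependents in any queue
                for item in &batch {
                    scheduler.resolve_dependency(&item.id);
                }
            }
            // Anything left is waiting on dependencies that will never
            // resolve in this loop; bail out instead of spinning
            None => break,
        }
    }
}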

Semaphore-Based VRAM Limiting

use tokio::sync::{AcquireError, Semaphore, SemaphorePermit};

// Alternative: use a semaphore for simpler VRAM tracking
pub struct VramSemaphore {
    semaphore: Semaphore,
    units_per_gb: u32,
}

impl VramSemaphore {
    pub fn new(total_gb: f64, units_per_gb: u32) -> Self {
        // Convert gigabytes into integer permit units (e.g., 10 units per GB)
        let total_units = (total_gb * units_per_gb as f64) as usize;
        Self {
            semaphore: Semaphore::new(total_units),
            units_per_gb,
        }
    }

    pub async fn acquire(&self, vram_gb: f64) -> Result<SemaphorePermit<'_>, AcquireError> {
        let units = (vram_gb * self.units_per_gb as f64) as u32;
        self.semaphore.acquire_many(units).await
    }
}

// Usage
async fn run_whisper(vram: &VramSemaphore) -> Result<(), AcquireError> {
    let _permit = vram.acquire(4.5).await?;  // 45 units
    // Run model - permit is released when dropped
    Ok(())
}

#[tokio::main]
async fn main() -> Result<(), AcquireError> {
    let vram = VramSemaphore::new(16.0, 10);  // 160 units total
    run_whisper(&vram).await
}

Guidelines

  • Track VRAM usage per model type
  • Use safety factor (90%) to avoid OOM
  • Minimize model switches (prefer current model)
  • Support work item dependencies
  • Use priority queues for important work
  • Track model load times for scheduling decisions (see the sketch after this list)
  • Use semaphores for simpler VRAM limiting
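
One way to act on the load-time guideline above is to weigh the estimated reload cost against the ready backlog before leaving the currently loaded model. This is not part of the scheduler above; it is a sketch of one possible policy, and the one-second-per-item estimate is an illustrative assumption.

// Sketch of a switch policy built on ModelStateManager::estimated_load_time.
// The per-item cost below is an illustrative assumption, not a tuned value.
pub fn should_switch(
    state: &ModelStateManager,
    candidate: ModelType,
    candidate_ready: usize,
    current_ready: usize,
) -> bool {
    if !state.needs_switch(candidate) {
        return true; // already loaded, nothing to pay
    }
    if current_ready > 0 {
        return false; // keep draining the current model first
    }

    // Assume each ready item is worth roughly one second of GPU time and
    // require the backlog to outweigh the estimated load time.
    let switch_cost = state.estimated_load_time(candidate);
    let backlog = std::time::Duration::from_secs(candidate_ready as u64);
    backlog >= switch_cost
}

In practice the per-item cost would come from measured batch throughput rather than a fixed estimate.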

Examples

See hercules-local-algo/src/scheduler/ for the complete implementation.