From d4941dc95a6db62cdbd8a7d312c93dcffb7c5c47 Mon Sep 17 00:00:00 2001 From: "Dhanji R. Prasanna" Date: Thu, 29 Jan 2026 11:39:46 +1100 Subject: [PATCH] refactor(providers): improve readability of embedded.rs and gemini.rs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit embedded.rs (937→789 lines, -16%): - Extract duplicated inference setup into prepare_context() helper - Extract stop sequence handling into find_stop_sequence() and truncate_at_stop_sequence() - Add InferenceParams struct to consolidate request parameter extraction - Add clear section markers for code organization - Tests now use module-level format functions directly (no duplication) gemini.rs: - Extract common request building into build_request() method - Reduces duplication between complete() and stream() methods All 399 unit tests pass. Behavior unchanged. Agent: carmack --- crates/g3-providers/src/embedded.rs | 986 ++++++++++++---------------- crates/g3-providers/src/gemini.rs | 30 +- 2 files changed, 432 insertions(+), 584 deletions(-) diff --git a/crates/g3-providers/src/embedded.rs b/crates/g3-providers/src/embedded.rs index ce62d42..17211e7 100644 --- a/crates/g3-providers/src/embedded.rs +++ b/crates/g3-providers/src/embedded.rs @@ -1,14 +1,23 @@ +//! Embedded LLM provider using llama.cpp with Metal acceleration on macOS. +//! +//! Supports multiple model families with their native chat templates: +//! - Qwen (ChatML format) +//! - GLM-4 (ChatGLM4 format) +//! - Mistral (Instruct format) +//! - Llama/CodeLlama (Llama2 format) + use crate::{ - CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message, - MessageRole, Usage, - streaming::{make_text_chunk, make_final_chunk_with_reason}, + CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message, MessageRole, + Usage, + streaming::{make_final_chunk_with_reason, make_text_chunk}, }; use anyhow::Result; use llama_cpp_2::{ + context::LlamaContext, context::params::LlamaContextParams, llama_backend::LlamaBackend, llama_batch::LlamaBatch, - model::{params::LlamaModelParams, AddBos, LlamaModel, Special}, + model::{AddBos, LlamaModel, Special, params::LlamaModelParams}, sampling::LlamaSampler, }; use std::num::NonZeroU32; @@ -19,74 +28,74 @@ use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tracing::{debug, error}; +// ============================================================================ +// Global Backend +// ============================================================================ + /// Global llama.cpp backend - can only be initialized once per process static LLAMA_BACKEND: OnceLock> = OnceLock::new(); /// Get or initialize the global llama.cpp backend fn get_or_init_backend() -> Result> { - // Check if already initialized if let Some(backend) = LLAMA_BACKEND.get() { return Ok(Arc::clone(backend)); } - - // Suppress llama.cpp's verbose logging to stderr before initialization + + // Suppress llama.cpp's verbose logging to stderr + suppress_llama_logging(); + + debug!("Initializing llama.cpp backend..."); + let backend = LlamaBackend::init() + .map_err(|e| anyhow::anyhow!("Failed to initialize llama.cpp backend: {:?}", e))?; + + // Store it (ignore if another thread beat us to it) + let _ = LLAMA_BACKEND.set(Arc::new(backend)); + Ok(Arc::clone(LLAMA_BACKEND.get().expect("backend was just set"))) +} + +fn suppress_llama_logging() { unsafe { unsafe extern "C" fn void_log( _level: std::ffi::c_int, _text: *const std::os::raw::c_char, _user_data: *mut 
std::os::raw::c_void, ) { - // Intentionally empty - suppress all llama.cpp logging + // Intentionally empty + } + extern "C" { + fn llama_log_set( + log_callback: Option< + unsafe extern "C" fn( + std::ffi::c_int, + *const std::os::raw::c_char, + *mut std::os::raw::c_void, + ), + >, + user_data: *mut std::os::raw::c_void, + ); } - // Call the underlying C function directly - extern "C" { fn llama_log_set(log_callback: Option, user_data: *mut std::os::raw::c_void); } llama_log_set(Some(void_log), std::ptr::null_mut()); } - - // Try to initialize - debug!("Initializing llama.cpp backend..."); - let backend = LlamaBackend::init() - .map_err(|e| anyhow::anyhow!("Failed to initialize llama.cpp backend: {:?}", e))?; - - // Store it (ignore if another thread beat us to it) - let _ = LLAMA_BACKEND.set(Arc::new(backend)); - let backend = LLAMA_BACKEND.get().expect("backend was just set"); - Ok(Arc::clone(backend)) } -/// Embedded LLM provider using llama.cpp with Metal acceleration on macOS. -/// -/// Supports multiple model families with their native chat templates: -/// - Qwen (ChatML format) -/// - GLM-4 (ChatGLM4 format) -/// - Mistral (Instruct format) -/// - Llama/CodeLlama (Llama2 format) +// ============================================================================ +// Provider Struct +// ============================================================================ + pub struct EmbeddedProvider { - /// Provider name in format "embedded.{config_name}" name: String, - /// The loaded model model: Arc, - /// The llama.cpp backend (must be kept alive) backend: Arc, - /// Model type identifier (e.g., "qwen", "glm4", "mistral") model_type: String, - /// Full model name for display model_name: String, - /// Maximum tokens to generate (None = auto-calculate) max_tokens: Option, - /// Sampling temperature temperature: f32, - /// Context window size context_length: u32, - /// Number of threads threads: Option, } impl EmbeddedProvider { - /// Create a new embedded provider with default naming. - /// - /// The provider will be registered as "embedded" (legacy behavior). - /// For proper multi-provider support, use `new_with_name()` instead. + /// Create a new embedded provider with default naming ("embedded"). pub fn new( model_path: String, model_type: String, @@ -109,16 +118,6 @@ impl EmbeddedProvider { } /// Create a new embedded provider with a custom name. - /// - /// # Arguments - /// * `name` - Provider name (e.g., "embedded.glm4", "embedded.qwen") - /// * `model_path` - Path to the GGUF model file (supports ~ expansion) - /// * `model_type` - Model family identifier ("qwen", "glm4", "glm", "mistral", "llama", etc.) 
- /// * `context_length` - Context window size (default: auto-detected from GGUF) - /// * `max_tokens` - Maximum tokens to generate (default: min(4096, context/4)) - /// * `temperature` - Sampling temperature (default: 0.1) - /// * `gpu_layers` - Number of layers to offload to GPU (default: 99 for Apple Silicon) - /// * `threads` - Number of CPU threads for inference pub fn new_with_name( name: String, model_path: String, @@ -131,7 +130,6 @@ impl EmbeddedProvider { ) -> Result { debug!("Loading embedded model from: {}", model_path); - // Expand tilde in path let expanded_path = shellexpand::tilde(&model_path); let model_path_buf = PathBuf::from(expanded_path.as_ref()); @@ -139,20 +137,16 @@ impl EmbeddedProvider { anyhow::bail!("Model file not found: {}", model_path_buf.display()); } - // Get or initialize the global llama.cpp backend let backend = get_or_init_backend()?; - // Set up model parameters let n_gpu_layers = gpu_layers.unwrap_or(99); let model_params = LlamaModelParams::default().with_n_gpu_layers(n_gpu_layers); debug!("Using {} GPU layers", n_gpu_layers); - // Load the model debug!("Loading model..."); let model = LlamaModel::load_from_file(&backend, &model_path_buf, &model_params) .map_err(|e| anyhow::anyhow!("Failed to load model: {:?}", e))?; - // Auto-detect context length from GGUF metadata, or use provided value let model_ctx_train = model.n_ctx_train(); let context_size = context_length.unwrap_or(model_ctx_train); debug!( @@ -175,211 +169,275 @@ impl EmbeddedProvider { }) } - /// Format messages according to the model's native chat template. - fn format_messages(&self, messages: &[Message]) -> String { - let model_type = &self.model_type; - - if model_type.contains("glm") { - self.format_glm4_messages(messages) - } else if model_type.contains("qwen") { - self.format_qwen_messages(messages) - } else if model_type.contains("mistral") { - self.format_mistral_messages(messages) - } else { - // Default to Llama format - self.format_llama_messages(messages) - } - } - - /// GLM-4 ChatGLM4 format: [gMASK]<|role|>\ncontent - fn format_glm4_messages(&self, messages: &[Message]) -> String { - let mut formatted = String::from("[gMASK]"); - - for message in messages { - let role = match message.role { - MessageRole::System => "<|system|>", - MessageRole::User => "<|user|>", - MessageRole::Assistant => "<|assistant|>", - }; - formatted.push_str(&format!("{}\n{}", role, message.content)); - } - - // Add the start of assistant response - formatted.push_str("<|assistant|>\n"); - formatted - } - - /// Qwen ChatML format: <|im_start|>role\ncontent<|im_end|> - fn format_qwen_messages(&self, messages: &[Message]) -> String { - let mut formatted = String::new(); - - for message in messages { - let role = match message.role { - MessageRole::System => "system", - MessageRole::User => "user", - MessageRole::Assistant => "assistant", - }; - - formatted.push_str(&format!( - "<|im_start|>{}\n{}<|im_end|>\n", - role, message.content - )); - } - - // Add the start of assistant response - formatted.push_str("<|im_start|>assistant\n"); - formatted - } - - /// Mistral Instruct format: [INST] ... 
[/INST] response - fn format_mistral_messages(&self, messages: &[Message]) -> String { - let mut formatted = String::new(); - let mut in_conversation = false; - - for (i, message) in messages.iter().enumerate() { - match message.role { - MessageRole::System => { - // Mistral doesn't have a special system token, include it at the start - if i == 0 { - formatted.push_str("[INST] "); - formatted.push_str(&message.content); - formatted.push_str("\n\n"); - in_conversation = true; - } - } - MessageRole::User => { - if !in_conversation { - formatted.push_str("[INST] "); - } - formatted.push_str(&message.content); - formatted.push_str(" [/INST]"); - in_conversation = false; - } - MessageRole::Assistant => { - formatted.push(' '); - formatted.push_str(&message.content); - formatted.push_str(" "); - in_conversation = false; - } - } - } - - // If the last message was from user, add a space for the assistant's response - if messages - .last() - .is_some_and(|m| matches!(m.role, MessageRole::User)) - { - formatted.push(' '); - } - - formatted - } - - /// Llama/CodeLlama format: [INST] <>\nsystem<>\n\nuser [/INST] - fn format_llama_messages(&self, messages: &[Message]) -> String { - let mut formatted = String::new(); - - for message in messages { - match message.role { - MessageRole::System => { - formatted.push_str(&format!( - "[INST] <>\n{}\n<>\n\n", - message.content - )); - } - MessageRole::User => { - formatted.push_str(&format!("{} [/INST] ", message.content)); - } - MessageRole::Assistant => { - formatted.push_str(&format!("{} [INST] ", message.content)); - } - } - } - formatted - } - - /// Estimate token count from text (rough approximation: ~4 chars per token) - fn estimate_tokens(&self, text: &str) -> u32 { - (text.len() as f32 / 4.0).ceil() as u32 - } - - /// Get stop sequences based on model type. 
- fn get_stop_sequences(&self) -> Vec<&'static str> { - let model_type = &self.model_type; - - if model_type.contains("glm") { - vec![ - "<|endoftext|>", // GLM end of text - "<|user|>", // Start of new user turn - "<|observation|>", // Tool observation (shouldn't appear in response) - "<|system|>", // System message (shouldn't appear in response) - ] - } else if model_type.contains("qwen") { - vec![ - "<|im_end|>", // Qwen ChatML format end token - "<|endoftext|>", // Alternative end token - "", // Generic end of sequence - "<|im_start|>", // Start of new message (shouldn't appear in response) - ] - } else if model_type.contains("codellama") || model_type.contains("code-llama") { - vec![ - "", // End of sequence - "[/INST]", // End of instruction - "<>", // End of system message - "[INST]", // Start of new instruction - "<>", // Start of system - ] - } else if model_type.contains("llama") { - vec![ - "", // End of sequence - "[/INST]", // End of instruction - "<>", // End of system message - "### Human:", // Conversation format - "### Assistant:", // Conversation format - "[INST]", // Start of new instruction - ] - } else if model_type.contains("mistral") { - vec![ - "", // End of sequence - "[/INST]", // End of instruction - "<|im_end|>", // ChatML format (some Mistral fine-tunes) - ] - } else if model_type.contains("vicuna") || model_type.contains("wizard") { - vec![ - "### Human:", // Conversation format - "### Assistant:", // Conversation format - "USER:", // Alternative format - "ASSISTANT:", // Alternative format - "", // End of sequence - ] - } else if model_type.contains("alpaca") { - vec![ - "### Instruction:", // Alpaca format - "### Response:", // Alpaca format - "### Input:", // Alpaca format - "", // End of sequence - ] - } else { - // Generic/unknown model - use common stop sequences - vec![ - "", // Most common end sequence - "<|endoftext|>", // GPT-style - "<|im_end|>", // ChatML - "### Human:", // Common conversation format - "### Assistant:", // Common conversation format - "[/INST]", // Instruction format - "<>", // System format - ] - } - } - - /// Get the effective max tokens for generation fn effective_max_tokens(&self) -> u32 { self.max_tokens .unwrap_or_else(|| std::cmp::min(4096, self.context_length / 4)) } + + /// Estimate token count from text (~4 chars per token) + fn estimate_tokens(&self, text: &str) -> u32 { + (text.len() as f32 / 4.0).ceil() as u32 + } } +// ============================================================================ +// Chat Template Formatting +// ============================================================================ + +impl EmbeddedProvider { + /// Format messages according to the model's native chat template. + fn format_messages(&self, messages: &[Message]) -> String { + match self.model_type.as_str() { + t if t.contains("glm") => format_glm4(messages), + t if t.contains("qwen") => format_qwen(messages), + t if t.contains("mistral") => format_mistral(messages), + _ => format_llama(messages), + } + } + + /// Get stop sequences based on model type. 
+    fn get_stop_sequences(&self) -> &'static [&'static str] {
+        get_stop_sequences_for_model(&self.model_type)
+    }
+}
+
+/// GLM-4 ChatGLM4 format: [gMASK]<|role|>\ncontent
+fn format_glm4(messages: &[Message]) -> String {
+    let mut out = String::from("[gMASK]");
+    for msg in messages {
+        let role = match msg.role {
+            MessageRole::System => "<|system|>",
+            MessageRole::User => "<|user|>",
+            MessageRole::Assistant => "<|assistant|>",
+        };
+        out.push_str(&format!("{}\n{}", role, msg.content));
+    }
+    out.push_str("<|assistant|>\n");
+    out
+}
+
+/// Qwen ChatML format: <|im_start|>role\ncontent<|im_end|>
+fn format_qwen(messages: &[Message]) -> String {
+    let mut out = String::new();
+    for msg in messages {
+        let role = match msg.role {
+            MessageRole::System => "system",
+            MessageRole::User => "user",
+            MessageRole::Assistant => "assistant",
+        };
+        out.push_str(&format!("<|im_start|>{}\n{}<|im_end|>\n", role, msg.content));
+    }
+    out.push_str("<|im_start|>assistant\n");
+    out
+}
+
+/// Mistral Instruct format: [INST] ... [/INST] response
+fn format_mistral(messages: &[Message]) -> String {
+    let mut out = String::new();
+    let mut in_inst = false;
+
+    for (i, msg) in messages.iter().enumerate() {
+        match msg.role {
+            MessageRole::System if i == 0 => {
+                out.push_str("[INST] ");
+                out.push_str(&msg.content);
+                out.push_str("\n\n");
+                in_inst = true;
+            }
+            MessageRole::System => {} // Ignore non-first system messages
+            MessageRole::User => {
+                if !in_inst {
+                    out.push_str("[INST] ");
+                }
+                out.push_str(&msg.content);
+                out.push_str(" [/INST]");
+                in_inst = false;
+            }
+            MessageRole::Assistant => {
+                out.push(' ');
+                out.push_str(&msg.content);
+                out.push_str(" ");
+                in_inst = false;
+            }
+        }
+    }
+
+    if messages.last().is_some_and(|m| matches!(m.role, MessageRole::User)) {
+        out.push(' ');
+    }
+    out
+}
+
+/// Llama/CodeLlama format: [INST] <<SYS>>\nsystem<</SYS>>\n\nuser [/INST]
+fn format_llama(messages: &[Message]) -> String {
+    let mut out = String::new();
+    for msg in messages {
+        match msg.role {
+            MessageRole::System => {
+                out.push_str(&format!("[INST] <<SYS>>\n{}\n<</SYS>>\n\n", msg.content));
+            }
+            MessageRole::User => {
+                out.push_str(&format!("{} [/INST] ", msg.content));
+            }
+            MessageRole::Assistant => {
+                out.push_str(&format!("{} [INST] ", msg.content));
+            }
+        }
+    }
+    out
+}
+
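For reference, a minimal sketch of what the template helpers above emit for a two-message conversation; it only uses Message, MessageRole, format_qwen, and format_glm4 as defined in this file, and the expected strings follow directly from the functions above.

fn template_examples() {
    let msgs = vec![
        Message::new(MessageRole::System, "You are terse.".to_string()),
        Message::new(MessageRole::User, "Hi!".to_string()),
    ];

    // Qwen / ChatML: each turn wrapped in <|im_start|>...<|im_end|>, ending with an open assistant turn.
    assert_eq!(
        format_qwen(&msgs),
        "<|im_start|>system\nYou are terse.<|im_end|>\n\
         <|im_start|>user\nHi!<|im_end|>\n\
         <|im_start|>assistant\n"
    );

    // GLM-4: [gMASK] prefix, then <|role|>\ncontent blocks, ending with an open <|assistant|> turn.
    assert_eq!(
        format_glm4(&msgs),
        "[gMASK]<|system|>\nYou are terse.<|user|>\nHi!<|assistant|>\n"
    );
}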
+/// Get stop sequences for a model type.
+fn get_stop_sequences_for_model(model_type: &str) -> &'static [&'static str] {
+    if model_type.contains("glm") {
+        &["<|endoftext|>", "<|user|>", "<|observation|>", "<|system|>"]
+    } else if model_type.contains("qwen") {
+        &["<|im_end|>", "<|endoftext|>", "</s>", "<|im_start|>"]
+    } else if model_type.contains("code-llama") || model_type.contains("codellama") {
+        &["</s>", "[/INST]", "<</SYS>>", "[INST]", "<<SYS>>"]
+    } else if model_type.contains("llama") {
+        &[
+            "</s>",
+            "[/INST]",
+            "<</SYS>>",
+            "### Human:",
+            "### Assistant:",
+            "[INST]",
+        ]
+    } else if model_type.contains("mistral") {
+        &["</s>", "[/INST]", "<|im_end|>"]
+    } else if model_type.contains("vicuna") || model_type.contains("wizard") {
+        &[
+            "### Human:",
+            "### Assistant:",
+            "USER:",
+            "ASSISTANT:",
+            "</s>",
+        ]
+    } else if model_type.contains("alpaca") {
+        &["### Instruction:", "### Response:", "### Input:", "</s>"]
+    } else {
+        // Generic fallback
+        &[
+            "</s>",
+            "<|endoftext|>",
+            "<|im_end|>",
+            "### Human:",
+            "### Assistant:",
+            "[/INST]",
+            "<</SYS>>",
+        ]
+    }
+}
+
+// ============================================================================
+// Inference Helpers
+// ============================================================================
+
+/// Parameters for inference, extracted from request and provider defaults.
+struct InferenceParams {
+    prompt: String,
+    max_tokens: u32,
+    temperature: f32,
+    stop_sequences: Vec<String>,
+}
+
+/// Prepared inference context with tokenized prompt ready for generation.
+struct PreparedContext<'a> {
+    ctx: LlamaContext<'a>,
+    batch: LlamaBatch,
+    sampler: LlamaSampler,
+    token_count: i32,
+}
+
+impl EmbeddedProvider {
+    /// Extract inference parameters from a completion request.
+    fn extract_params(&self, request: &CompletionRequest) -> InferenceParams {
+        InferenceParams {
+            prompt: self.format_messages(&request.messages),
+            max_tokens: request.max_tokens.unwrap_or_else(|| self.effective_max_tokens()),
+            temperature: request.temperature.unwrap_or(self.temperature),
+            stop_sequences: self
+                .get_stop_sequences()
+                .iter()
+                .map(|s| s.to_string())
+                .collect(),
+        }
+    }
+}
+
+/// Prepare the inference context: create context, tokenize prompt, decode initial batch.
+fn prepare_context<'a>( + model: &'a LlamaModel, + backend: &'a LlamaBackend, + prompt: &str, + temperature: f32, + context_length: u32, + threads: Option, +) -> Result> { + let n_ctx = NonZeroU32::new(context_length).unwrap_or(NonZeroU32::new(4096).unwrap()); + let mut ctx_params = LlamaContextParams::default() + .with_n_ctx(Some(n_ctx)) + .with_n_batch(context_length); + if let Some(n_threads) = threads { + ctx_params = ctx_params.with_n_threads(n_threads as i32); + } + + let mut ctx = model + .new_context(backend, ctx_params) + .map_err(|e| anyhow::anyhow!("Failed to create context: {:?}", e))?; + + let tokens = model + .str_to_token(prompt, AddBos::Always) + .map_err(|e| anyhow::anyhow!("Failed to tokenize: {:?}", e))?; + + debug!("Tokenized prompt: {} tokens", tokens.len()); + + let batch_size = std::cmp::max(512, tokens.len()); + let mut batch = LlamaBatch::new(batch_size, 1); + for (i, token) in tokens.iter().enumerate() { + batch + .add(*token, i as i32, &[0], i == tokens.len() - 1) + .map_err(|e| anyhow::anyhow!("Failed to add token to batch: {:?}", e))?; + } + + ctx.decode(&mut batch) + .map_err(|e| anyhow::anyhow!("Failed to decode prompt: {:?}", e))?; + + let sampler = LlamaSampler::chain_simple([ + LlamaSampler::temp(temperature), + LlamaSampler::dist(1234), + ]); + + Ok(PreparedContext { + ctx, + batch, + sampler, + token_count: tokens.len() as i32, + }) +} + +/// Check if text contains any stop sequence. Returns the truncation position if found. +fn find_stop_sequence(text: &str, stop_sequences: &[String]) -> Option { + for stop_seq in stop_sequences { + if let Some(pos) = text.find(stop_seq) { + return Some(pos); + } + } + None +} + +/// Truncate text at the first stop sequence, if any. +fn truncate_at_stop_sequence(text: &mut String, stop_sequences: &[String]) { + if let Some(pos) = find_stop_sequence(text, stop_sequences) { + text.truncate(pos); + } +} + +// ============================================================================ +// LLMProvider Implementation +// ============================================================================ + #[async_trait::async_trait] impl LLMProvider for EmbeddedProvider { async fn complete(&self, request: CompletionRequest) -> Result { @@ -388,115 +446,66 @@ impl LLMProvider for EmbeddedProvider { request.messages.len() ); - let prompt = self.format_messages(&request.messages); - let max_tokens = request.max_tokens.unwrap_or_else(|| self.effective_max_tokens()); - let temperature = request.temperature.unwrap_or(self.temperature); + let params = self.extract_params(&request); + let prompt_tokens = self.estimate_tokens(¶ms.prompt); - debug!("Formatted prompt length: {} chars", prompt.len()); - - // Estimate prompt tokens before moving prompt into closure - let prompt_tokens = self.estimate_tokens(&prompt); + debug!("Formatted prompt length: {} chars", params.prompt.len()); // Clone what we need for the blocking task let model = self.model.clone(); let backend = self.backend.clone(); let context_length = self.context_length; let threads = self.threads; - let stop_sequences: Vec = self.get_stop_sequences().iter().map(|s| s.to_string()).collect(); + let model_name = self.model_name.clone(); let (content, completion_tokens) = tokio::task::spawn_blocking(move || { - // Create context for this completion - let n_ctx = NonZeroU32::new(context_length).unwrap_or(NonZeroU32::new(4096).unwrap()); - let mut ctx_params = LlamaContextParams::default() - .with_n_ctx(Some(n_ctx)) - .with_n_batch(context_length); // Batch size must accommodate full 
prompt - if let Some(n_threads) = threads { - ctx_params = ctx_params.with_n_threads(n_threads as i32); - } + let mut prepared = prepare_context( + &model, + &backend, + ¶ms.prompt, + params.temperature, + context_length, + threads, + )?; - let mut ctx = model - .new_context(&backend, ctx_params) - .map_err(|e| anyhow::anyhow!("Failed to create context: {:?}", e))?; - - // Tokenize the prompt - let tokens = model - .str_to_token(&prompt, AddBos::Always) - .map_err(|e| anyhow::anyhow!("Failed to tokenize: {:?}", e))?; - - debug!("Tokenized prompt: {} tokens", tokens.len()); - - // Create batch large enough for the prompt tokens - // The batch size must be at least as large as the number of tokens we're adding - let batch_size = std::cmp::max(512, tokens.len()); - let mut batch = LlamaBatch::new(batch_size, 1); - for (i, token) in tokens.iter().enumerate() { - batch - .add(*token, i as i32, &[0], i == tokens.len() - 1) - .map_err(|e| anyhow::anyhow!("Failed to add token to batch: {:?}", e))?; - } - - // Decode the prompt - ctx.decode(&mut batch) - .map_err(|e| anyhow::anyhow!("Failed to decode prompt: {:?}", e))?; - - // Set up sampler - let mut sampler = LlamaSampler::chain_simple([ - LlamaSampler::temp(temperature), - LlamaSampler::dist(1234), - ]); - - // Generate tokens let mut generated_text = String::new(); - let mut n_cur = tokens.len() as i32; let mut token_count = 0u32; - for _ in 0..max_tokens { - let new_token = sampler.sample(&ctx, batch.n_tokens() - 1); - sampler.accept(new_token); + for _ in 0..params.max_tokens { + let new_token = prepared.sampler.sample(&prepared.ctx, prepared.batch.n_tokens() - 1); + prepared.sampler.accept(new_token); - // Check for end of generation if model.is_eog_token(new_token) { debug!("Hit end-of-generation token at {} tokens", token_count); break; } - // Decode token to string - let token_str = model.token_to_str(new_token, Special::Tokenize) + let token_str = model + .token_to_str(new_token, Special::Tokenize) .unwrap_or_default(); generated_text.push_str(&token_str); token_count += 1; - // Check for stop sequences - let mut hit_stop = false; - for stop_seq in &stop_sequences { - if generated_text.contains(stop_seq) { - debug!("Hit stop sequence '{}' at {} tokens", stop_seq, token_count); - hit_stop = true; - break; - } - } - if hit_stop { + if find_stop_sequence(&generated_text, ¶ms.stop_sequences).is_some() { + debug!("Hit stop sequence at {} tokens", token_count); break; } // Prepare next batch - batch.clear(); - batch - .add(new_token, n_cur, &[0], true) + prepared.batch.clear(); + prepared + .batch + .add(new_token, prepared.token_count, &[0], true) .map_err(|e| anyhow::anyhow!("Failed to add token to batch: {:?}", e))?; - n_cur += 1; + prepared.token_count += 1; - ctx.decode(&mut batch) + prepared + .ctx + .decode(&mut prepared.batch) .map_err(|e| anyhow::anyhow!("Failed to decode: {:?}", e))?; } - // Clean stop sequences from output - for stop_seq in &stop_sequences { - if let Some(pos) = generated_text.find(stop_seq) { - generated_text.truncate(pos); - break; - } - } + truncate_at_stop_sequence(&mut generated_text, ¶ms.stop_sequences); Ok::<_, anyhow::Error>((generated_text.trim().to_string(), token_count)) }) @@ -512,7 +521,7 @@ impl LLMProvider for EmbeddedProvider { cache_creation_tokens: 0, cache_read_tokens: 0, }, - model: self.model_name.clone(), + model: model_name, }) } @@ -522,155 +531,84 @@ impl LLMProvider for EmbeddedProvider { request.messages.len() ); - let prompt = self.format_messages(&request.messages); - let max_tokens = 
request.max_tokens.unwrap_or_else(|| self.effective_max_tokens()); - let temperature = request.temperature.unwrap_or(self.temperature); - - // Estimate prompt tokens for usage tracking - let prompt_tokens = self.estimate_tokens(&prompt); + let params = self.extract_params(&request); + let prompt_tokens = self.estimate_tokens(¶ms.prompt); let (tx, rx) = mpsc::channel(100); - // Clone what we need for the blocking task let model = self.model.clone(); let backend = self.backend.clone(); let context_length = self.context_length; let threads = self.threads; - let stop_sequences: Vec = self.get_stop_sequences().iter().map(|s| s.to_string()).collect(); tokio::task::spawn_blocking(move || { - // Create context for this completion - let n_ctx = NonZeroU32::new(context_length).unwrap_or(NonZeroU32::new(4096).unwrap()); - let mut ctx_params = LlamaContextParams::default() - .with_n_ctx(Some(n_ctx)) - .with_n_batch(context_length); // Batch size must accommodate full prompt - if let Some(n_threads) = threads { - ctx_params = ctx_params.with_n_threads(n_threads as i32); - } - - let mut ctx = match model.new_context(&backend, ctx_params) { - Ok(ctx) => ctx, + let mut prepared = match prepare_context( + &model, + &backend, + ¶ms.prompt, + params.temperature, + context_length, + threads, + ) { + Ok(p) => p, Err(e) => { - let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to create context: {:?}", e))); + let _ = tx.blocking_send(Err(e)); return; } }; - // Tokenize the prompt - let tokens = match model.str_to_token(&prompt, AddBos::Always) { - Ok(t) => t, - Err(e) => { - let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to tokenize: {:?}", e))); - return; - } - }; - - debug!("Tokenized prompt: {} tokens", tokens.len()); - - // Create batch large enough for the prompt tokens - // The batch size must be at least as large as the number of tokens we're adding - let batch_size = std::cmp::max(512, tokens.len()); - let mut batch = LlamaBatch::new(batch_size, 1); - for (i, token) in tokens.iter().enumerate() { - if let Err(e) = batch.add(*token, i as i32, &[0], i == tokens.len() - 1) { - let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to add token to batch: {:?}", e))); - return; - } - } - - // Decode the prompt - if let Err(e) = ctx.decode(&mut batch) { - let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to decode prompt: {:?}", e))); - return; - } - - // Set up sampler - let mut sampler = LlamaSampler::chain_simple([ - LlamaSampler::temp(temperature), - LlamaSampler::dist(1234), - ]); - - // Generate tokens let mut accumulated_text = String::new(); - let mut n_cur = tokens.len() as i32; let mut token_count = 0u32; let mut stop_reason: Option = None; - for _ in 0..max_tokens { - let new_token = sampler.sample(&ctx, batch.n_tokens() - 1); - sampler.accept(new_token); + for _ in 0..params.max_tokens { + let new_token = prepared.sampler.sample(&prepared.ctx, prepared.batch.n_tokens() - 1); + prepared.sampler.accept(new_token); - // Check for end of generation if model.is_eog_token(new_token) { debug!("Hit end-of-generation token at {} tokens", token_count); stop_reason = Some("end_turn".to_string()); break; } - // Decode token to string - let token_str = model.token_to_str(new_token, Special::Tokenize) + let token_str = model + .token_to_str(new_token, Special::Tokenize) .unwrap_or_default(); - + accumulated_text.push_str(&token_str); token_count += 1; - // Check for stop sequences - let mut hit_stop = false; - for stop_seq in &stop_sequences { - if accumulated_text.contains(stop_seq) { - debug!("Hit 
stop sequence '{}' at {} tokens", stop_seq, token_count); - hit_stop = true; - stop_reason = Some("stop_sequence".to_string()); - break; - } - } - - if hit_stop { - // Send any remaining clean content - let mut clean_text = accumulated_text.clone(); - for stop_seq in &stop_sequences { - if let Some(pos) = clean_text.find(stop_seq) { - clean_text.truncate(pos); - break; - } - } - // We've been sending incrementally, so just break + if find_stop_sequence(&accumulated_text, ¶ms.stop_sequences).is_some() { + debug!("Hit stop sequence at {} tokens", token_count); + stop_reason = Some("stop_sequence".to_string()); break; } - // Send the token - let chunk = make_text_chunk(token_str); - if tx.blocking_send(Ok(chunk)).is_err() { - break; + // Stream the token + if tx.blocking_send(Ok(make_text_chunk(token_str))).is_err() { + return; // Receiver dropped } - // Check token limit - if token_count >= max_tokens { - debug!("Reached max token limit: {}", max_tokens); + if token_count >= params.max_tokens { + debug!("Reached max token limit: {}", params.max_tokens); stop_reason = Some("max_tokens".to_string()); break; } // Prepare next batch - batch.clear(); - if let Err(e) = batch.add(new_token, n_cur, &[0], true) { + prepared.batch.clear(); + if let Err(e) = prepared.batch.add(new_token, prepared.token_count, &[0], true) { error!("Failed to add token to batch: {:?}", e); break; } - n_cur += 1; + prepared.token_count += 1; - if let Err(e) = ctx.decode(&mut batch) { + if let Err(e) = prepared.ctx.decode(&mut prepared.batch) { error!("Failed to decode: {:?}", e); break; } } - // If no stop reason set, it was end_turn (natural completion) - if stop_reason.is_none() { - stop_reason = Some("end_turn".to_string()); - } - - // Send final chunk with usage information let usage = Usage { prompt_tokens, completion_tokens: token_count, @@ -678,7 +616,8 @@ impl LLMProvider for EmbeddedProvider { cache_creation_tokens: 0, cache_read_tokens: 0, }; - let final_chunk = make_final_chunk_with_reason(vec![], Some(usage), stop_reason); + let final_chunk = + make_final_chunk_with_reason(vec![], Some(usage), stop_reason.or(Some("end_turn".to_string()))); let _ = tx.blocking_send(Ok(final_chunk)); }); @@ -706,6 +645,10 @@ impl LLMProvider for EmbeddedProvider { } } +// ============================================================================ +// Tests +// ============================================================================ + #[cfg(test)] mod tests { use super::*; @@ -717,8 +660,8 @@ mod tests { Message::new(MessageRole::User, "Hello!".to_string()), ]; - let formatted = format_glm4_messages_standalone(&messages); - + let formatted = format_glm4(&messages); + assert!(formatted.starts_with("[gMASK]")); assert!(formatted.contains("<|system|>\nYou are a helpful assistant.")); assert!(formatted.contains("<|user|>\nHello!")); @@ -732,8 +675,8 @@ mod tests { Message::new(MessageRole::User, "Hello!".to_string()), ]; - let formatted = format_qwen_messages_standalone(&messages); - + let formatted = format_qwen(&messages); + assert!(formatted.contains("<|im_start|>system\nYou are a helpful assistant.<|im_end|>")); assert!(formatted.contains("<|im_start|>user\nHello!<|im_end|>")); assert!(formatted.ends_with("<|im_start|>assistant\n")); @@ -746,8 +689,8 @@ mod tests { Message::new(MessageRole::User, "Hello!".to_string()), ]; - let formatted = format_mistral_messages_standalone(&messages); - + let formatted = format_mistral(&messages); + assert!(formatted.starts_with("[INST] ")); assert!(formatted.contains("You are a helpful 
assistant.")); assert!(formatted.contains("Hello!")); @@ -761,8 +704,8 @@ mod tests { Message::new(MessageRole::User, "Hello!".to_string()), ]; - let formatted = format_llama_messages_standalone(&messages); - + let formatted = format_llama(&messages); + assert!(formatted.contains("<>")); assert!(formatted.contains("You are a helpful assistant.")); assert!(formatted.contains("<>")); @@ -772,8 +715,8 @@ mod tests { #[test] fn test_glm4_stop_sequences() { - let stop_seqs = get_stop_sequences_for_model_type("glm4"); - + let stop_seqs = get_stop_sequences_for_model("glm4"); + assert!(stop_seqs.contains(&"<|endoftext|>")); assert!(stop_seqs.contains(&"<|user|>")); assert!(stop_seqs.contains(&"<|observation|>")); @@ -782,8 +725,8 @@ mod tests { #[test] fn test_qwen_stop_sequences() { - let stop_seqs = get_stop_sequences_for_model_type("qwen"); - + let stop_seqs = get_stop_sequences_for_model("qwen"); + assert!(stop_seqs.contains(&"<|im_end|>")); assert!(stop_seqs.contains(&"<|endoftext|>")); assert!(stop_seqs.contains(&"<|im_start|>")); @@ -793,145 +736,54 @@ mod tests { fn test_glm4_multi_turn_conversation() { let messages = vec![ Message::new(MessageRole::System, "You are a coding assistant.".to_string()), - Message::new(MessageRole::User, "Write a hello world in Python.".to_string()), - Message::new(MessageRole::Assistant, "print('Hello, World!')".to_string()), + Message::new( + MessageRole::User, + "Write a hello world in Python.".to_string(), + ), + Message::new( + MessageRole::Assistant, + "print('Hello, World!')".to_string(), + ), Message::new(MessageRole::User, "Now in Rust.".to_string()), ]; - let formatted = format_glm4_messages_standalone(&messages); - + let formatted = format_glm4(&messages); + // Verify all parts are present in order let system_pos = formatted.find("<|system|>").unwrap(); let user1_pos = formatted.find("<|user|>\nWrite a hello world").unwrap(); let assistant_pos = formatted.find("<|assistant|>\nprint").unwrap(); let user2_pos = formatted.find("<|user|>\nNow in Rust").unwrap(); let final_assistant_pos = formatted.rfind("<|assistant|>\n").unwrap(); - + assert!(system_pos < user1_pos); assert!(user1_pos < assistant_pos); assert!(assistant_pos < user2_pos); assert!(user2_pos < final_assistant_pos); } - // Standalone formatting functions for testing without needing a full provider - fn format_glm4_messages_standalone(messages: &[Message]) -> String { - let mut formatted = String::from("[gMASK]"); - for message in messages { - let role = match message.role { - MessageRole::System => "<|system|>", - MessageRole::User => "<|user|>", - MessageRole::Assistant => "<|assistant|>", - }; - formatted.push_str(&format!("{}\n{}", role, message.content)); - } - formatted.push_str("<|assistant|>\n"); - formatted + #[test] + fn test_find_stop_sequence() { + let stop_seqs = vec!["".to_string(), "<|im_end|>".to_string()]; + + assert_eq!(find_stop_sequence("hello world", &stop_seqs), None); + assert_eq!(find_stop_sequence("helloworld", &stop_seqs), Some(5)); + assert_eq!( + find_stop_sequence("hello<|im_end|>world", &stop_seqs), + Some(5) + ); } - fn format_qwen_messages_standalone(messages: &[Message]) -> String { - let mut formatted = String::new(); - for message in messages { - let role = match message.role { - MessageRole::System => "system", - MessageRole::User => "user", - MessageRole::Assistant => "assistant", - }; - formatted.push_str(&format!( - "<|im_start|>{}\n{}<|im_end|>\n", - role, message.content - )); - } - formatted.push_str("<|im_start|>assistant\n"); - formatted - } + 
#[test] + fn test_truncate_at_stop_sequence() { + let stop_seqs = vec!["".to_string()]; - fn format_mistral_messages_standalone(messages: &[Message]) -> String { - let mut formatted = String::new(); - let mut in_conversation = false; - for (i, message) in messages.iter().enumerate() { - match message.role { - MessageRole::System => { - if i == 0 { - formatted.push_str("[INST] "); - formatted.push_str(&message.content); - formatted.push_str("\n\n"); - in_conversation = true; - } - } - MessageRole::User => { - if !in_conversation { - formatted.push_str("[INST] "); - } - formatted.push_str(&message.content); - formatted.push_str(" [/INST]"); - in_conversation = false; - } - MessageRole::Assistant => { - formatted.push(' '); - formatted.push_str(&message.content); - formatted.push_str(" "); - in_conversation = false; - } - } - } - if messages.last().is_some_and(|m| matches!(m.role, MessageRole::User)) { - formatted.push(' '); - } - formatted - } + let mut text = "helloworld".to_string(); + truncate_at_stop_sequence(&mut text, &stop_seqs); + assert_eq!(text, "hello"); - fn format_llama_messages_standalone(messages: &[Message]) -> String { - let mut formatted = String::new(); - for message in messages { - match message.role { - MessageRole::System => { - formatted.push_str(&format!( - "[INST] <>\n{}\n<>\n\n", - message.content - )); - } - MessageRole::User => { - formatted.push_str(&format!("{} [/INST] ", message.content)); - } - MessageRole::Assistant => { - formatted.push_str(&format!("{} [INST] ", message.content)); - } - } - } - formatted - } - - fn get_stop_sequences_for_model_type(model_type: &str) -> Vec<&'static str> { - if model_type.contains("glm") { - vec![ - "<|endoftext|>", - "<|user|>", - "<|observation|>", - "<|system|>", - ] - } else if model_type.contains("qwen") { - vec![ - "<|im_end|>", - "<|endoftext|>", - "", - "<|im_start|>", - ] - } else if model_type.contains("mistral") { - vec![ - "", - "[/INST]", - "<|im_end|>", - ] - } else { - vec![ - "", - "<|endoftext|>", - "<|im_end|>", - "### Human:", - "### Assistant:", - "[/INST]", - "<>", - ] - } + let mut text2 = "no stop here".to_string(); + truncate_at_stop_sequence(&mut text2, &stop_seqs); + assert_eq!(text2, "no stop here"); } } diff --git a/crates/g3-providers/src/gemini.rs b/crates/g3-providers/src/gemini.rs index bc6ac4b..db0148e 100644 --- a/crates/g3-providers/src/gemini.rs +++ b/crates/g3-providers/src/gemini.rs @@ -523,12 +523,11 @@ fn try_parse_json_from_buffer(buffer: &mut String) -> Option { // LLMProvider Implementation // ============================================================================ -#[async_trait] -impl LLMProvider for GeminiProvider { - async fn complete(&self, request: CompletionRequest) -> Result { +impl GeminiProvider { + /// Build a GeminiRequest from a CompletionRequest. 
+    fn build_request(&self, request: &CompletionRequest) -> GeminiRequest {
         let (contents, system_instruction) = convert_messages(&request.messages);
-
-        let gemini_request = GeminiRequest {
+        GeminiRequest {
             contents,
             system_instruction,
             tools: request.tools.as_ref().map(|t| convert_tools(t)),
@@ -536,7 +535,14 @@ impl LLMProvider for GeminiProvider {
                 max_output_tokens: request.max_tokens.or(Some(self.max_tokens)),
                 temperature: request.temperature.or(Some(self.temperature)),
             },
-        };
+        }
+    }
+}
+
+#[async_trait]
+impl LLMProvider for GeminiProvider {
+    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
+        let gemini_request = self.build_request(&request);
 
         let url = self.get_api_url(false);
         debug!("Gemini request URL: {}", url);
@@ -579,17 +585,7 @@ impl LLMProvider for GeminiProvider {
     }
 
     async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
-        let (contents, system_instruction) = convert_messages(&request.messages);
-
-        let gemini_request = GeminiRequest {
-            contents,
-            system_instruction,
-            tools: request.tools.as_ref().map(|t| convert_tools(t)),
-            generation_config: GeminiGenerationConfig {
-                max_output_tokens: request.max_tokens.or(Some(self.max_tokens)),
-                temperature: request.temperature.or(Some(self.temperature)),
-            },
-        };
+        let gemini_request = self.build_request(&request);
 
         // For streaming, add alt=sse parameter
         let url = format!("{}&alt=sse", self.get_api_url(true));
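To show how the pieces extracted in embedded.rs compose, here is a minimal sketch of the shared generation pattern that complete() and stream() now both follow. push_token is a hypothetical stand-in for the sampler/decode step; only get_stop_sequences_for_model, find_stop_sequence, and truncate_at_stop_sequence from this patch are assumed.

fn generate_with_stops(mut push_token: impl FnMut() -> Option<String>) -> String {
    // Stop sequences are resolved once per request from the model type.
    let stop_sequences: Vec<String> = get_stop_sequences_for_model("qwen")
        .iter()
        .map(|s| s.to_string())
        .collect();

    let mut generated = String::new();
    while let Some(piece) = push_token() {
        generated.push_str(&piece);
        // Break as soon as any configured stop sequence appears in the accumulated text.
        if find_stop_sequence(&generated, &stop_sequences).is_some() {
            break;
        }
    }
    // Drop the stop sequence and anything after it before returning.
    truncate_at_stop_sequence(&mut generated, &stop_sequences);
    generated
}

Both complete() and stream() follow this shape, differing mainly in whether each decoded piece is also sent over the mpsc channel as it is produced.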