use crate::{ CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message, MessageRole, Usage, streaming::{make_text_chunk, make_final_chunk}, }; use anyhow::Result; use llama_cpp::{ standard_sampler::{SamplerStage, StandardSampler}, LlamaModel, LlamaParams, LlamaSession, SessionParams, }; use std::path::{Path, PathBuf}; use std::sync::Arc; use tokio::sync::mpsc; use tokio::sync::Mutex; use tokio_stream::wrappers::ReceiverStream; use tracing::{debug, error}; pub struct EmbeddedProvider { session: Arc>, model_name: String, max_tokens: u32, temperature: f32, context_length: u32, } impl EmbeddedProvider { pub fn new( model_path: String, model_type: String, context_length: Option, max_tokens: Option, temperature: Option, gpu_layers: Option, threads: Option, ) -> Result { debug!("Loading embedded model from: {}", model_path); // Expand tilde in path let expanded_path = shellexpand::tilde(&model_path); let model_path_buf = PathBuf::from(expanded_path.as_ref()); // If model doesn't exist and it's the default Qwen model, offer to download it if !model_path_buf.exists() { if model_path.contains("qwen2.5-7b-instruct-q3_k_m.gguf") { debug!("Model file not found. Attempting to download Qwen 2.5 7B model..."); Self::download_qwen_model(&model_path_buf)?; } else { anyhow::bail!("Model file not found: {}", model_path_buf.display()); } } let model_path = model_path_buf.as_path(); // Set up model parameters let mut params = LlamaParams::default(); if let Some(gpu_layers) = gpu_layers { params.n_gpu_layers = gpu_layers; debug!("Using {} GPU layers", gpu_layers); } let context_size = context_length.unwrap_or(4096); debug!("Using context length: {}", context_size); // Load the model debug!("Loading model..."); let model = LlamaModel::load_from_file(model_path, params) .map_err(|e| anyhow::anyhow!("Failed to load model: {}", e))?; // Create session with parameters let mut session_params = SessionParams { n_ctx: context_size, ..Default::default() }; if let Some(threads) = threads { session_params.n_threads = threads; } let session = model .create_session(session_params) .map_err(|e| anyhow::anyhow!("Failed to create session: {}", e))?; debug!("Successfully loaded {} model", model_type); Ok(Self { session: Arc::new(Mutex::new(session)), model_name: format!("embedded-{}", model_type), max_tokens: max_tokens.unwrap_or(2048), temperature: temperature.unwrap_or(0.1), context_length: context_size, }) } fn format_messages(&self, messages: &[Message]) -> String { // Determine the appropriate format based on model type let model_name_lower = self.model_name.to_lowercase(); if model_name_lower.contains("qwen") { // Qwen format: <|im_start|>role\ncontent<|im_end|> let mut formatted = String::new(); for message in messages { let role = match message.role { MessageRole::System => "system", MessageRole::User => "user", MessageRole::Assistant => "assistant", }; formatted.push_str(&format!( "<|im_start|>{}\n{}<|im_end|>\n", role, message.content )); } // Add the start of assistant response formatted.push_str("<|im_start|>assistant\n"); formatted } else if model_name_lower.contains("mistral") { // Mistral Instruct format: [INST] ... [/INST] assistant_response let mut formatted = String::new(); let mut in_conversation = false; for (i, message) in messages.iter().enumerate() { match message.role { MessageRole::System => { // Mistral doesn't have a special system token, include it at the start if i == 0 { formatted.push_str("[INST] "); formatted.push_str(&message.content); formatted.push_str("\n\n"); in_conversation = true; } } MessageRole::User => { if !in_conversation { formatted.push_str("[INST] "); } formatted.push_str(&message.content); formatted.push_str(" [/INST]"); in_conversation = false; } MessageRole::Assistant => { formatted.push(' '); formatted.push_str(&message.content); formatted.push_str(" "); in_conversation = false; } } } // If the last message was from user, add a space for the assistant's response if messages .last() .is_some_and(|m| matches!(m.role, MessageRole::User)) { formatted.push(' '); } formatted } else { // Use Llama/CodeLlama format for other models let mut formatted = String::new(); for message in messages { match message.role { MessageRole::System => { formatted.push_str(&format!( "[INST] <>\n{}\n<>\n\n", message.content )); } MessageRole::User => { formatted.push_str(&format!("{} [/INST] ", message.content)); } MessageRole::Assistant => { formatted.push_str(&format!("{} [INST] ", message.content)); } } } formatted } } async fn generate_completion( &self, prompt: &str, max_tokens: u32, temperature: f32, ) -> Result { let session = self.session.clone(); let prompt = prompt.to_string(); // Calculate dynamic max tokens based on available context headroom let prompt_tokens = self.estimate_tokens(&prompt); let available_tokens = self .context_length .saturating_sub(prompt_tokens) .saturating_sub(50); // Reserve 50 tokens for safety let dynamic_max_tokens = std::cmp::min(max_tokens as usize, available_tokens as usize); debug!("Context calculation: prompt_tokens={}, context_length={}, available_tokens={}, dynamic_max_tokens={}", prompt_tokens, self.context_length, available_tokens, dynamic_max_tokens); // Get stop sequences before entering the closure let stop_sequences = self.get_stop_sequences(); // Add timeout to the entire operation let timeout_duration = std::time::Duration::from_secs(30); // Increased timeout for larger contexts let result = tokio::time::timeout( timeout_duration, tokio::task::spawn_blocking(move || { // Retry logic for acquiring the session lock let mut session_guard = None; for attempt in 0..5 { match session.try_lock() { Ok(ctx) => { session_guard = Some(ctx); break; } Err(_) => { if attempt < 4 { debug!( "Session busy, retrying in {}ms (attempt {}/5)", 100 * (attempt + 1), attempt + 1 ); std::thread::sleep(std::time::Duration::from_millis( 100 * (attempt + 1) as u64, )); } else { return Err(anyhow::anyhow!( "Model is busy after 5 attempts, please try again" )); } } } } let mut session = session_guard .ok_or_else(|| anyhow::anyhow!("Failed to acquire session lock"))?; debug!( "Starting inference with prompt length: {} chars, estimated {} tokens", prompt.len(), prompt_tokens ); // Set context to the prompt debug!("About to call set_context..."); session .set_context(&prompt) .map_err(|e| anyhow::anyhow!("Failed to set context: {}", e))?; debug!("set_context completed successfully"); // Create sampler with temperature debug!("Creating sampler..."); let stages = vec![ SamplerStage::Temperature(temperature), SamplerStage::TopK(40), SamplerStage::TopP(0.9), ]; let sampler = StandardSampler::new_softmax(stages, 1); debug!("Sampler created successfully"); // Start completion with dynamic max tokens debug!( "About to call start_completing_with with {} max tokens...", dynamic_max_tokens ); let mut completion_handle = session .start_completing_with(sampler, dynamic_max_tokens) .map_err(|e| anyhow::anyhow!("Failed to start completion: {}", e))?; debug!("start_completing_with completed successfully"); let mut generated_text = String::new(); let mut token_count = 0; let start_time = std::time::Instant::now(); debug!("Starting token generation loop..."); // Generate tokens with dynamic limits while let Some(token) = completion_handle.next_token() { // Check for timeout on each token if start_time.elapsed() > std::time::Duration::from_secs(25) { debug!("Token generation timeout after {} tokens", token_count); break; } let token_string = session.model().token_to_piece(token); generated_text.push_str(&token_string); token_count += 1; if token_count <= 10 || token_count % 50 == 0 { debug!("Generated token {}: '{}'", token_count, token_string); } // Use dynamic token limit if token_count >= dynamic_max_tokens { debug!("Reached dynamic token limit: {}", dynamic_max_tokens); break; } // Stop on completion markers let mut hit_stop = false; for stop_seq in &stop_sequences { if generated_text.contains(stop_seq) { debug!("Hit stop sequence '{}' at {} tokens", stop_seq, token_count); hit_stop = true; break; } } if hit_stop { break; } } debug!( "Token generation loop completed. Generated {} tokens in {:?}", token_count, start_time.elapsed() ); Ok((generated_text, token_count)) }), ) .await; match result { Ok(inner_result) => match inner_result { Ok(task_result) => match task_result { Ok((text, token_count)) => { debug!( "Completed generation: {} tokens (dynamic limit was {})", token_count, dynamic_max_tokens ); // Clean stop sequences from the generated text after the closure Ok(self.clean_stop_sequences(&text)) } Err(e) => Err(e), }, Err(e) => Err(e.into()), }, Err(_) => { error!("Generation timed out after 30 seconds"); Err(anyhow::anyhow!("Generation timed out")) } } } // Helper function to estimate token count from text fn estimate_tokens(&self, text: &str) -> u32 { // Rough estimation: average 4 characters per token // This is conservative - actual tokenization might be different (text.len() as f32 / 4.0).ceil() as u32 } // Helper function to get stop sequences based on model type fn get_stop_sequences(&self) -> Vec<&'static str> { // Determine model type from model_name let model_name_lower = self.model_name.to_lowercase(); if model_name_lower.contains("qwen") { vec![ "<|im_end|>", // Qwen ChatML format end token "<|endoftext|>", // Alternative end token "", // Generic end of sequence "<|im_start|>", // Start of new message (shouldn't appear in response) ] } else if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama") { vec![ "", // End of sequence "[/INST]", // End of instruction "<>", // End of system message "[INST]", // Start of new instruction (shouldn't appear in response) "<>", // Start of system (shouldn't appear in response) ] } else if model_name_lower.contains("llama") { vec![ "", // End of sequence "[/INST]", // End of instruction "<>", // End of system message "### Human:", // Conversation format "### Assistant:", // Conversation format "[INST]", // Start of new instruction ] } else if model_name_lower.contains("mistral") { vec![ "", // End of sequence "[/INST]", // End of instruction "<|im_end|>", // ChatML format ] } else if model_name_lower.contains("vicuna") || model_name_lower.contains("wizard") { vec![ "### Human:", // Conversation format "### Assistant:", // Conversation format "USER:", // Alternative format "ASSISTANT:", // Alternative format "", // End of sequence ] } else if model_name_lower.contains("alpaca") { vec![ "### Instruction:", // Alpaca format "### Response:", // Alpaca format "### Input:", // Alpaca format "", // End of sequence ] } else { // Generic/unknown model - use common stop sequences vec![ "", // Most common end sequence "<|endoftext|>", // GPT-style "<|im_end|>", // ChatML "### Human:", // Common conversation format "### Assistant:", // Common conversation format "[/INST]", // Instruction format "<>", // System format ] } } // Helper function to clean up stop sequences from generated text fn clean_stop_sequences(&self, text: &str) -> String { let mut cleaned = text.to_string(); let stop_sequences = self.get_stop_sequences(); for stop_seq in &stop_sequences { if let Some(pos) = cleaned.find(stop_seq) { cleaned.truncate(pos); break; // Only remove the first occurrence to avoid over-truncation } } cleaned.trim().to_string() } // Download the Qwen 2.5 7B model if it doesn't exist fn download_qwen_model(model_path: &Path) -> Result<()> { use std::fs; use std::process::Command; const MODEL_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-7B-Instruct-GGUF/resolve/main/qwen2.5-7b-instruct-q3_k_m.gguf"; const MODEL_SIZE_MB: u64 = 3631; // Approximate size in MB // Create the parent directory if it doesn't exist if let Some(parent) = model_path.parent() { fs::create_dir_all(parent)?; } debug!("Downloading Qwen 2.5 7B model (Q3_K_M quantization, ~3.5GB)..."); debug!("This is a one-time download that may take several minutes depending on your connection."); debug!("Downloading to: {}", model_path.display()); // Use curl with progress bar for download let output = Command::new("curl") .args([ "-L", // Follow redirects "-#", // Show progress bar "-f", // Fail on HTTP errors "-o", model_path.to_str().unwrap(), MODEL_URL, ]) .output()?; if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); // If curl is not available, provide alternative instructions if stderr.contains("command not found") || stderr.contains("not found") { error!( "curl is not installed. Please install curl or manually download the model." ); error!("Manual download instructions:"); error!("1. Download from: {}", MODEL_URL); error!("2. Save to: {}", model_path.display()); anyhow::bail!( "curl not found - please install curl or download the model manually" ); } anyhow::bail!("Failed to download model: {}", stderr); } // Verify the file was created and has reasonable size let metadata = fs::metadata(model_path)?; let size_mb = metadata.len() / (1024 * 1024); if size_mb < MODEL_SIZE_MB - 100 { // Allow some variance fs::remove_file(model_path).ok(); // Clean up partial download anyhow::bail!( "Downloaded file appears incomplete ({}MB vs expected ~{}MB). Please try again.", size_mb, MODEL_SIZE_MB ); } debug!("Successfully downloaded Qwen 2.5 7B model ({}MB)", size_mb); Ok(()) } } #[async_trait::async_trait] impl LLMProvider for EmbeddedProvider { async fn complete(&self, request: CompletionRequest) -> Result { debug!( "Processing completion request with {} messages", request.messages.len() ); let prompt = self.format_messages(&request.messages); let max_tokens = request.max_tokens.unwrap_or(self.max_tokens); let temperature = request.temperature.unwrap_or(self.temperature); debug!("Formatted prompt length: {} chars", prompt.len()); let content = self .generate_completion(&prompt, max_tokens, temperature) .await?; // Estimate token usage (rough approximation) let prompt_tokens = (prompt.len() / 4) as u32; // Rough estimate: 4 chars per token let completion_tokens = (content.len() / 4) as u32; Ok(CompletionResponse { content, usage: Usage { prompt_tokens, completion_tokens, total_tokens: prompt_tokens + completion_tokens, }, model: self.model_name.clone(), }) } async fn stream(&self, request: CompletionRequest) -> Result { debug!( "Processing streaming request with {} messages", request.messages.len() ); let prompt = self.format_messages(&request.messages); let max_tokens = request.max_tokens.unwrap_or(self.max_tokens); let temperature = request.temperature.unwrap_or(self.temperature); let (tx, rx) = mpsc::channel(100); let session = self.session.clone(); let prompt = prompt.to_string(); // Spawn streaming task tokio::task::spawn_blocking(move || { // Retry logic for acquiring the session lock let mut session_guard = None; for attempt in 0..5 { match session.try_lock() { Ok(ctx) => { session_guard = Some(ctx); break; } Err(_) => { if attempt < 4 { debug!( "Session busy, retrying in {}ms (attempt {}/5)", 100 * (attempt + 1), attempt + 1 ); std::thread::sleep(std::time::Duration::from_millis( 100 * (attempt + 1) as u64, )); } else { let _ = tx.blocking_send(Err(anyhow::anyhow!( "Model is busy after 5 attempts, please try again" ))); return; } } } } let mut session = match session_guard { Some(ctx) => ctx, None => { let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to acquire session lock"))); return; } }; // Set context to the prompt if let Err(e) = session.set_context(&prompt) { let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to set context: {}", e))); return; } // Create sampler with temperature let stages = vec![ SamplerStage::Temperature(temperature), SamplerStage::TopK(40), SamplerStage::TopP(0.9), ]; let sampler = StandardSampler::new_softmax(stages, 1); // Start completion let mut completion_handle = match session .start_completing_with(sampler, max_tokens as usize) { Ok(handle) => handle, Err(e) => { let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to start completion: {}", e))); return; } }; let mut accumulated_text = String::new(); let mut token_count = 0; let mut unsent_tokens = String::new(); // Buffer for tokens we're holding back // Get stop sequences dynamically based on model type let stop_sequences = if prompt.contains("<|im_start|>") { // Qwen ChatML format detected vec!["<|im_end|>", "<|endoftext|>", "", "<|im_start|>"] } else if prompt.contains("[INST]") || prompt.contains("<>") { // Llama/CodeLlama format detected vec![ "", "[/INST]", "<>", "[INST]", "<>", "### Human:", "### Assistant:", ] } else { // Generic format vec![ "", "<|endoftext|>", "<|im_end|>", "### Human:", "### Assistant:", "[/INST]", "<>", ] }; // Stream tokens with proper limits while let Some(token) = completion_handle.next_token() { let token_string = session.model().token_to_piece(token); accumulated_text.push_str(&token_string); unsent_tokens.push_str(&token_string); token_count += 1; // Check if we've hit a complete stop sequence let mut hit_stop = false; for stop_seq in &stop_sequences { if accumulated_text.contains(stop_seq) { debug!("Hit complete stop sequence in streaming: {}", stop_seq); hit_stop = true; break; } } if hit_stop { // Before stopping, check if there might be an incomplete tool call // Look for JSON tool call patterns that might be cut off by the stop sequence let has_potential_tool_call = accumulated_text.contains(r#"{"tool":"#) || accumulated_text.contains(r#"{"{""tool"":"#) || accumulated_text.contains(r#"{{""tool"":"#); if has_potential_tool_call { // Check if the tool call appears to be complete (has closing brace after the stop sequence) let mut complete_tool_call = false; for stop_seq in &stop_sequences { if let Some(stop_pos) = accumulated_text.find(stop_seq) { // Look for tool call pattern before the stop sequence let before_stop = &accumulated_text[..stop_pos]; if let Some(tool_start) = before_stop.rfind(r#"{"tool":"#) { let tool_part = &before_stop[tool_start..]; // Count braces to see if JSON is complete let open_braces = tool_part.matches('{').count(); let close_braces = tool_part.matches('}').count(); if open_braces > 0 && open_braces == close_braces { complete_tool_call = true; break; } } } } // If tool call is incomplete, send the raw content including stop sequences // so the main parser can handle it properly if !complete_tool_call { debug!("Found incomplete tool call, sending raw content with stop sequences"); let already_sent_len = accumulated_text.len() - unsent_tokens.len(); if accumulated_text.len() > already_sent_len { let remaining_to_send = &accumulated_text[already_sent_len..]; if !remaining_to_send.is_empty() { let chunk = make_text_chunk(remaining_to_send.to_string()); let _ = tx.blocking_send(Ok(chunk)); } } break; } } // Send any remaining clean content before stopping (original behavior) let mut clean_accumulated = accumulated_text.clone(); for stop_seq in &stop_sequences { if let Some(pos) = clean_accumulated.find(stop_seq) { clean_accumulated.truncate(pos); break; } } // Calculate what part we haven't sent yet let already_sent_len = accumulated_text.len() - unsent_tokens.len(); if clean_accumulated.len() > already_sent_len { let remaining_to_send = &clean_accumulated[already_sent_len..]; if !remaining_to_send.is_empty() { let chunk = make_text_chunk(remaining_to_send.to_string()); let _ = tx.blocking_send(Ok(chunk)); } } break; } // Check if we're building towards a stop sequence let mut might_be_stop = false; for stop_seq in &stop_sequences { for i in 1..stop_seq.len() { let partial = &stop_seq[..i]; if accumulated_text.ends_with(partial) { debug!("Detected potential partial stop sequence: '{}'", partial); might_be_stop = true; break; } } if might_be_stop { break; } } if might_be_stop { // Hold back tokens, but only for a limited buffer size if unsent_tokens.len() > 20 { // Don't hold back more than 20 characters // Send the oldest part and keep only the recent part that might be a stop sequence let to_send = &unsent_tokens[..unsent_tokens.len() - 10]; if !to_send.is_empty() { let chunk = make_text_chunk(to_send.to_string()); if tx.blocking_send(Ok(chunk)).is_err() { break; } } unsent_tokens = unsent_tokens[unsent_tokens.len() - 10..].to_string(); } // Continue to next token without sending } else { // No potential stop sequence, send all unsent tokens if !unsent_tokens.is_empty() { let chunk = make_text_chunk(unsent_tokens.clone()); if tx.blocking_send(Ok(chunk)).is_err() { break; } unsent_tokens.clear(); } } // Enforce token limit if token_count >= max_tokens as usize { debug!("Reached max token limit in streaming: {}", max_tokens); break; } } // Send final chunk let final_chunk = make_final_chunk(vec![], None); let _ = tx.blocking_send(Ok(final_chunk)); }); Ok(ReceiverStream::new(rx)) } fn name(&self) -> &str { "embedded" } fn model(&self) -> &str { &self.model_name } fn max_tokens(&self) -> u32 { self.max_tokens } fn temperature(&self) -> f32 { self.temperature } }