g3/crates/g3-providers/src/embedded.rs
Commit ad198a8501 by Jochen, 2025-11-25 22:51:32 +11:00: add code exploration fast start

This tries to short-circuit multiple round-trips to the LLM for reading code.
It's a precursor to context engineering tailored to specific tasks. In initial
experiments it's only marginally faster than regular mode, and it burns more tokens.

use anyhow::Result;
use crate::{
    CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider,
    Message, MessageRole, Usage,
};
use llama_cpp::{
    standard_sampler::{SamplerStage, StandardSampler},
    LlamaModel, LlamaParams, LlamaSession, SessionParams,
};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error, info};
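
/// LLM provider that runs a local GGUF model in-process via llama.cpp,
/// guarding the single model session with a mutex so only one generation
/// runs at a time.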
pub struct EmbeddedProvider {
    session: Arc<Mutex<LlamaSession>>,
    model_name: String,
    max_tokens: u32,
    temperature: f32,
    context_length: u32,
}

impl EmbeddedProvider {
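    /// Loads a GGUF model from `model_path` (tilde-expanded). If the path is the
    /// default Qwen 2.5 7B file and it's missing, the model is downloaded first.
    /// Defaults: 4096-token context, 2048 max tokens, temperature 0.1.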
    pub fn new(
        model_path: String,
        model_type: String,
        context_length: Option<u32>,
        max_tokens: Option<u32>,
        temperature: Option<f32>,
        gpu_layers: Option<u32>,
        threads: Option<u32>,
    ) -> Result<Self> {
        info!("Loading embedded model from: {}", model_path);

        // Expand tilde in path
        let expanded_path = shellexpand::tilde(&model_path);
        let model_path_buf = PathBuf::from(expanded_path.as_ref());

        // If the model doesn't exist and it's the default Qwen model, download it
        if !model_path_buf.exists() {
            if model_path.contains("qwen2.5-7b-instruct-q3_k_m.gguf") {
                info!("Model file not found. Attempting to download Qwen 2.5 7B model...");
                Self::download_qwen_model(&model_path_buf)?;
            } else {
                anyhow::bail!("Model file not found: {}", model_path_buf.display());
            }
        }
        let model_path = model_path_buf.as_path();

        // Set up model parameters
        let mut params = LlamaParams::default();
        if let Some(gpu_layers) = gpu_layers {
            params.n_gpu_layers = gpu_layers;
            info!("Using {} GPU layers", gpu_layers);
        }

        let context_size = context_length.unwrap_or(4096);
        info!("Using context length: {}", context_size);

        // Load the model
        info!("Loading model...");
        let model = LlamaModel::load_from_file(model_path, params)
            .map_err(|e| anyhow::anyhow!("Failed to load model: {}", e))?;

        // Create session with parameters
        let mut session_params = SessionParams {
            n_ctx: context_size,
            ..Default::default()
        };
        if let Some(threads) = threads {
            session_params.n_threads = threads;
        }
        let session = model
            .create_session(session_params)
            .map_err(|e| anyhow::anyhow!("Failed to create session: {}", e))?;
        info!("Successfully loaded {} model", model_type);

        Ok(Self {
            session: Arc::new(Mutex::new(session)),
            model_name: format!("embedded-{}", model_type),
            max_tokens: max_tokens.unwrap_or(2048),
            temperature: temperature.unwrap_or(0.1),
            context_length: context_size,
        })
    }
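
    /// Renders the chat history into the prompt template for the model family:
    /// Qwen ChatML, Mistral Instruct, or Llama/CodeLlama [INST] blocks.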
    fn format_messages(&self, messages: &[Message]) -> String {
        // Determine the appropriate format based on model type
        let model_name_lower = self.model_name.to_lowercase();

        if model_name_lower.contains("qwen") {
            // Qwen format: <|im_start|>role\ncontent<|im_end|>
            let mut formatted = String::new();
            for message in messages {
                let role = match message.role {
                    MessageRole::System => "system",
                    MessageRole::User => "user",
                    MessageRole::Assistant => "assistant",
                };
                formatted.push_str(&format!(
                    "<|im_start|>{}\n{}<|im_end|>\n",
                    role, message.content
                ));
            }
            // Add the start of the assistant response
            formatted.push_str("<|im_start|>assistant\n");
            formatted
        } else if model_name_lower.contains("mistral") {
            // Mistral Instruct format: <s>[INST] ... [/INST] assistant_response</s>
            let mut formatted = String::new();
            let mut in_conversation = false;
            for (i, message) in messages.iter().enumerate() {
                match message.role {
                    MessageRole::System => {
                        // Mistral doesn't have a special system token, include it at the start
                        if i == 0 {
                            formatted.push_str("<s>[INST] ");
                            formatted.push_str(&message.content);
                            formatted.push_str("\n\n");
                            in_conversation = true;
                        }
                    }
                    MessageRole::User => {
                        if !in_conversation {
                            formatted.push_str("<s>[INST] ");
                        }
                        formatted.push_str(&message.content);
                        formatted.push_str(" [/INST]");
                        in_conversation = false;
                    }
                    MessageRole::Assistant => {
                        formatted.push(' ');
                        formatted.push_str(&message.content);
                        formatted.push_str("</s> ");
                        in_conversation = false;
                    }
                }
            }
            // If the last message was from the user, add a space for the assistant's response
            if messages
                .last()
                .is_some_and(|m| matches!(m.role, MessageRole::User))
            {
                formatted.push(' ');
            }
            formatted
        } else {
            // Use Llama/CodeLlama format for other models
            let mut formatted = String::new();
            for message in messages {
                match message.role {
                    MessageRole::System => {
                        formatted.push_str(&format!(
                            "[INST] <<SYS>>\n{}\n<</SYS>>\n\n",
                            message.content
                        ));
                    }
                    MessageRole::User => {
                        formatted.push_str(&format!("{} [/INST] ", message.content));
                    }
                    MessageRole::Assistant => {
                        formatted.push_str(&format!("{} </s><s>[INST] ", message.content));
                    }
                }
            }
            formatted
        }
    }
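
    /// Runs one blocking generation on a spawned thread. Output is capped by the
    /// remaining context headroom (minus a 50-token reserve), bounded by a 30s
    /// overall timeout plus a 25s in-loop check, and truncated at the first
    /// model-specific stop sequence.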
    async fn generate_completion(
        &self,
        prompt: &str,
        max_tokens: u32,
        temperature: f32,
    ) -> Result<String> {
        let session = self.session.clone();
        let prompt = prompt.to_string();

        // Calculate dynamic max tokens based on available context headroom
        let prompt_tokens = self.estimate_tokens(&prompt);
        let available_tokens = self
            .context_length
            .saturating_sub(prompt_tokens)
            .saturating_sub(50); // Reserve 50 tokens for safety
        let dynamic_max_tokens = std::cmp::min(max_tokens as usize, available_tokens as usize);
        debug!(
            "Context calculation: prompt_tokens={}, context_length={}, available_tokens={}, dynamic_max_tokens={}",
            prompt_tokens, self.context_length, available_tokens, dynamic_max_tokens
        );

        // Get stop sequences before entering the closure
        let stop_sequences = self.get_stop_sequences();

        // Add a timeout to the entire operation
        let timeout_duration = std::time::Duration::from_secs(30); // Increased timeout for larger contexts
        let result = tokio::time::timeout(
            timeout_duration,
            tokio::task::spawn_blocking(move || {
                // Retry logic for acquiring the session lock
                let mut session_guard = None;
                for attempt in 0..5 {
                    match session.try_lock() {
                        Ok(ctx) => {
                            session_guard = Some(ctx);
                            break;
                        }
                        Err(_) => {
                            if attempt < 4 {
                                debug!(
                                    "Session busy, retrying in {}ms (attempt {}/5)",
                                    100 * (attempt + 1),
                                    attempt + 1
                                );
                                std::thread::sleep(std::time::Duration::from_millis(
                                    100 * (attempt + 1) as u64,
                                ));
                            } else {
                                return Err(anyhow::anyhow!(
                                    "Model is busy after 5 attempts, please try again"
                                ));
                            }
                        }
                    }
                }
                let mut session = session_guard
                    .ok_or_else(|| anyhow::anyhow!("Failed to acquire session lock"))?;
                debug!(
                    "Starting inference with prompt length: {} chars, estimated {} tokens",
                    prompt.len(),
                    prompt_tokens
                );

                // Set context to the prompt
                debug!("About to call set_context...");
                session
                    .set_context(&prompt)
                    .map_err(|e| anyhow::anyhow!("Failed to set context: {}", e))?;
                debug!("set_context completed successfully");

                // Create sampler with temperature
                debug!("Creating sampler...");
                let stages = vec![
                    SamplerStage::Temperature(temperature),
                    SamplerStage::TopK(40),
                    SamplerStage::TopP(0.9),
                ];
                let sampler = StandardSampler::new_softmax(stages, 1);
                debug!("Sampler created successfully");

                // Start completion with dynamic max tokens
                debug!(
                    "About to call start_completing_with with {} max tokens...",
                    dynamic_max_tokens
                );
                let mut completion_handle = session
                    .start_completing_with(sampler, dynamic_max_tokens)
                    .map_err(|e| anyhow::anyhow!("Failed to start completion: {}", e))?;
                debug!("start_completing_with completed successfully");

                let mut generated_text = String::new();
                let mut token_count = 0;
                let start_time = std::time::Instant::now();
                debug!("Starting token generation loop...");

                // Generate tokens with dynamic limits
                while let Some(token) = completion_handle.next_token() {
                    // Check for timeout on each token
                    if start_time.elapsed() > std::time::Duration::from_secs(25) {
                        debug!("Token generation timeout after {} tokens", token_count);
                        break;
                    }

                    let token_string = session.model().token_to_piece(token);
                    generated_text.push_str(&token_string);
                    token_count += 1;

                    if token_count <= 10 || token_count % 50 == 0 {
                        debug!("Generated token {}: '{}'", token_count, token_string);
                    }

                    // Use dynamic token limit
                    if token_count >= dynamic_max_tokens {
                        debug!("Reached dynamic token limit: {}", dynamic_max_tokens);
                        break;
                    }

                    // Stop on completion markers
                    let mut hit_stop = false;
                    for stop_seq in &stop_sequences {
                        if generated_text.contains(stop_seq) {
                            debug!("Hit stop sequence '{}' at {} tokens", stop_seq, token_count);
                            hit_stop = true;
                            break;
                        }
                    }
                    if hit_stop {
                        break;
                    }
                }
                debug!(
                    "Token generation loop completed. Generated {} tokens in {:?}",
                    token_count,
                    start_time.elapsed()
                );
                Ok((generated_text, token_count))
            }),
        )
        .await;

        match result {
            Ok(inner_result) => match inner_result {
                Ok(task_result) => match task_result {
                    Ok((text, token_count)) => {
                        info!(
                            "Completed generation: {} tokens (dynamic limit was {})",
                            token_count, dynamic_max_tokens
                        );
                        // Clean stop sequences from the generated text after the closure
                        Ok(self.clean_stop_sequences(&text))
                    }
                    Err(e) => Err(e),
                },
                Err(e) => Err(e.into()),
            },
            Err(_) => {
                error!("Generation timed out after 30 seconds");
                Err(anyhow::anyhow!("Generation timed out"))
            }
        }
    }

    // Helper function to estimate token count from text
    fn estimate_tokens(&self, text: &str) -> u32 {
        // Rough estimation: average 4 characters per token
        // This is conservative - actual tokenization might be different
        (text.len() as f32 / 4.0).ceil() as u32
    }

    // Helper function to get stop sequences based on model type
    fn get_stop_sequences(&self) -> Vec<&'static str> {
        // Determine model type from model_name
        let model_name_lower = self.model_name.to_lowercase();
        if model_name_lower.contains("qwen") {
            vec![
                "<|im_end|>",    // Qwen ChatML format end token
                "<|endoftext|>", // Alternative end token
                "</s>",          // Generic end of sequence
                "<|im_start|>",  // Start of new message (shouldn't appear in response)
            ]
        } else if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama")
        {
            vec![
                "</s>",     // End of sequence
                "[/INST]",  // End of instruction
                "<</SYS>>", // End of system message
                "[INST]",   // Start of new instruction (shouldn't appear in response)
                "<<SYS>>",  // Start of system (shouldn't appear in response)
            ]
        } else if model_name_lower.contains("llama") {
            vec![
                "</s>",           // End of sequence
                "[/INST]",        // End of instruction
                "<</SYS>>",       // End of system message
                "### Human:",     // Conversation format
                "### Assistant:", // Conversation format
                "[INST]",         // Start of new instruction
            ]
        } else if model_name_lower.contains("mistral") {
            vec![
                "</s>",       // End of sequence
                "[/INST]",    // End of instruction
                "<|im_end|>", // ChatML format
            ]
        } else if model_name_lower.contains("vicuna") || model_name_lower.contains("wizard") {
            vec![
                "### Human:",     // Conversation format
                "### Assistant:", // Conversation format
                "USER:",          // Alternative format
                "ASSISTANT:",     // Alternative format
                "</s>",           // End of sequence
            ]
        } else if model_name_lower.contains("alpaca") {
            vec![
                "### Instruction:", // Alpaca format
                "### Response:",    // Alpaca format
                "### Input:",       // Alpaca format
                "</s>",             // End of sequence
            ]
        } else {
            // Generic/unknown model - use common stop sequences
            vec![
                "</s>",           // Most common end sequence
                "<|endoftext|>",  // GPT-style
                "<|im_end|>",     // ChatML
                "### Human:",     // Common conversation format
                "### Assistant:", // Common conversation format
                "[/INST]",        // Instruction format
                "<</SYS>>",       // System format
            ]
        }
    }

    // Helper function to clean up stop sequences from generated text
    fn clean_stop_sequences(&self, text: &str) -> String {
        let mut cleaned = text.to_string();
        let stop_sequences = self.get_stop_sequences();
        for stop_seq in &stop_sequences {
            if let Some(pos) = cleaned.find(stop_seq) {
                cleaned.truncate(pos);
                break; // Only remove the first occurrence to avoid over-truncation
            }
        }
        cleaned.trim().to_string()
    }

    // Download the Qwen 2.5 7B model if it doesn't exist
    fn download_qwen_model(model_path: &Path) -> Result<()> {
        use std::fs;
        use std::process::Command;

        const MODEL_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-7B-Instruct-GGUF/resolve/main/qwen2.5-7b-instruct-q3_k_m.gguf";
        const MODEL_SIZE_MB: u64 = 3631; // Approximate size in MB

        // Create the parent directory if it doesn't exist
        if let Some(parent) = model_path.parent() {
            fs::create_dir_all(parent)?;
        }

        info!("Downloading Qwen 2.5 7B model (Q3_K_M quantization, ~3.5GB)...");
        info!("This is a one-time download that may take several minutes depending on your connection.");
        info!("Downloading to: {}", model_path.display());

        // Use curl with progress bar for download
        let output = Command::new("curl")
            .args([
                "-L", // Follow redirects
                "-#", // Show progress bar
                "-f", // Fail on HTTP errors
                "-o",
                model_path.to_str().unwrap(),
                MODEL_URL,
            ])
            .output()?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            // If curl is not available, provide alternative instructions
            if stderr.contains("command not found") || stderr.contains("not found") {
                error!("curl is not installed. Please install curl or manually download the model.");
                error!("Manual download instructions:");
                error!("1. Download from: {}", MODEL_URL);
                error!("2. Save to: {}", model_path.display());
                anyhow::bail!("curl not found - please install curl or download the model manually");
            }
            anyhow::bail!("Failed to download model: {}", stderr);
        }

        // Verify the file was created and has reasonable size
        let metadata = fs::metadata(model_path)?;
        let size_mb = metadata.len() / (1024 * 1024);
        if size_mb < MODEL_SIZE_MB - 100 {
            // Allow some variance
            fs::remove_file(model_path).ok(); // Clean up partial download
            anyhow::bail!(
                "Downloaded file appears incomplete ({}MB vs expected ~{}MB). Please try again.",
                size_mb,
                MODEL_SIZE_MB
            );
        }

        info!("Successfully downloaded Qwen 2.5 7B model ({}MB)", size_mb);
        Ok(())
    }
}

#[async_trait::async_trait]
impl LLMProvider for EmbeddedProvider {
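    /// Formats the messages for the model family, generates synchronously, and
    /// estimates usage at roughly four characters per token.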
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        debug!(
            "Processing completion request with {} messages",
            request.messages.len()
        );

        let prompt = self.format_messages(&request.messages);
        let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
        let temperature = request.temperature.unwrap_or(self.temperature);
        debug!("Formatted prompt length: {} chars", prompt.len());

        let content = self
            .generate_completion(&prompt, max_tokens, temperature)
            .await?;

        // Estimate token usage (rough approximation)
        let prompt_tokens = (prompt.len() / 4) as u32; // Rough estimate: 4 chars per token
        let completion_tokens = (content.len() / 4) as u32;

        Ok(CompletionResponse {
            content,
            usage: Usage {
                prompt_tokens,
                completion_tokens,
                total_tokens: prompt_tokens + completion_tokens,
            },
            model: self.model_name.clone(),
        })
    }
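
    /// Streams tokens over an mpsc channel. Text that could be the prefix of a
    /// stop sequence is held back (up to ~20 characters) so partial stop markers
    /// never reach the client; incomplete tool-call JSON is forwarded raw so the
    /// caller's parser can handle it.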
    async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
        debug!(
            "Processing streaming request with {} messages",
            request.messages.len()
        );

        let prompt = self.format_messages(&request.messages);
        let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
        let temperature = request.temperature.unwrap_or(self.temperature);

        let (tx, rx) = mpsc::channel(100);
        let session = self.session.clone();
        let prompt = prompt.to_string();

        // Spawn streaming task
        tokio::task::spawn_blocking(move || {
            // Retry logic for acquiring the session lock
            let mut session_guard = None;
            for attempt in 0..5 {
                match session.try_lock() {
                    Ok(ctx) => {
                        session_guard = Some(ctx);
                        break;
                    }
                    Err(_) => {
                        if attempt < 4 {
                            debug!(
                                "Session busy, retrying in {}ms (attempt {}/5)",
                                100 * (attempt + 1),
                                attempt + 1
                            );
                            std::thread::sleep(std::time::Duration::from_millis(
                                100 * (attempt + 1) as u64,
                            ));
                        } else {
                            let _ = tx.blocking_send(Err(anyhow::anyhow!(
                                "Model is busy after 5 attempts, please try again"
                            )));
                            return;
                        }
                    }
                }
            }
            let mut session = match session_guard {
                Some(ctx) => ctx,
                None => {
                    let _ =
                        tx.blocking_send(Err(anyhow::anyhow!("Failed to acquire session lock")));
                    return;
                }
            };

            // Set context to the prompt
            if let Err(e) = session.set_context(&prompt) {
                let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to set context: {}", e)));
                return;
            }

            // Create sampler with temperature
            let stages = vec![
                SamplerStage::Temperature(temperature),
                SamplerStage::TopK(40),
                SamplerStage::TopP(0.9),
            ];
            let sampler = StandardSampler::new_softmax(stages, 1);

            // Start completion
            let mut completion_handle =
                match session.start_completing_with(sampler, max_tokens as usize) {
                    Ok(handle) => handle,
                    Err(e) => {
                        let _ = tx
                            .blocking_send(Err(anyhow::anyhow!("Failed to start completion: {}", e)));
                        return;
                    }
                };

            let mut accumulated_text = String::new();
            let mut token_count = 0;
            let mut unsent_tokens = String::new(); // Buffer for tokens we're holding back

            // Get stop sequences dynamically, based on the prompt format
            let stop_sequences = if prompt.contains("<|im_start|>") {
                // Qwen ChatML format detected
                vec!["<|im_end|>", "<|endoftext|>", "</s>", "<|im_start|>"]
            } else if prompt.contains("[INST]") || prompt.contains("<<SYS>>") {
                // Llama/CodeLlama format detected
                vec![
                    "</s>", "[/INST]", "<</SYS>>", "[INST]", "<<SYS>>", "### Human:",
                    "### Assistant:",
                ]
            } else {
                // Generic format
                vec![
                    "</s>", "<|endoftext|>", "<|im_end|>", "### Human:", "### Assistant:",
                    "[/INST]", "<</SYS>>",
                ]
            };

            // Stream tokens with proper limits
            while let Some(token) = completion_handle.next_token() {
                let token_string = session.model().token_to_piece(token);
                accumulated_text.push_str(&token_string);
                unsent_tokens.push_str(&token_string);
                token_count += 1;

                // Check if we've hit a complete stop sequence
                let mut hit_stop = false;
                for stop_seq in &stop_sequences {
                    if accumulated_text.contains(stop_seq) {
                        debug!("Hit complete stop sequence in streaming: {}", stop_seq);
                        hit_stop = true;
                        break;
                    }
                }

                if hit_stop {
                    // Before stopping, check if there might be an incomplete tool call
                    // cut off by the stop sequence
                    let has_potential_tool_call = accumulated_text.contains(r#"{"tool":"#)
                        || accumulated_text.contains(r#"{"{""tool"":"#)
                        || accumulated_text.contains(r#"{{""tool"":"#);
                    if has_potential_tool_call {
                        // Check if the tool call appears to be complete (balanced braces
                        // before the stop sequence)
                        let mut complete_tool_call = false;
                        for stop_seq in &stop_sequences {
                            if let Some(stop_pos) = accumulated_text.find(stop_seq) {
                                // Look for a tool call pattern before the stop sequence
                                let before_stop = &accumulated_text[..stop_pos];
                                if let Some(tool_start) = before_stop.rfind(r#"{"tool":"#) {
                                    let tool_part = &before_stop[tool_start..];
                                    // Count braces to see if the JSON is complete
                                    let open_braces = tool_part.matches('{').count();
                                    let close_braces = tool_part.matches('}').count();
                                    if open_braces > 0 && open_braces == close_braces {
                                        complete_tool_call = true;
                                        break;
                                    }
                                }
                            }
                        }

                        // If the tool call is incomplete, send the raw content including
                        // stop sequences so the main parser can handle it properly
                        if !complete_tool_call {
                            debug!("Found incomplete tool call, sending raw content with stop sequences");
                            let already_sent_len = accumulated_text.len() - unsent_tokens.len();
                            if accumulated_text.len() > already_sent_len {
                                let remaining_to_send = &accumulated_text[already_sent_len..];
                                if !remaining_to_send.is_empty() {
                                    let chunk = CompletionChunk {
                                        content: remaining_to_send.to_string(),
                                        finished: false,
                                        usage: None,
                                        tool_calls: None,
                                    };
                                    let _ = tx.blocking_send(Ok(chunk));
                                }
                            }
                            break;
                        }
                    }

                    // Send any remaining clean content before stopping (original behavior)
                    let mut clean_accumulated = accumulated_text.clone();
                    for stop_seq in &stop_sequences {
                        if let Some(pos) = clean_accumulated.find(stop_seq) {
                            clean_accumulated.truncate(pos);
                            break;
                        }
                    }
                    // Calculate what part we haven't sent yet
                    let already_sent_len = accumulated_text.len() - unsent_tokens.len();
                    if clean_accumulated.len() > already_sent_len {
                        let remaining_to_send = &clean_accumulated[already_sent_len..];
                        if !remaining_to_send.is_empty() {
                            let chunk = CompletionChunk {
                                content: remaining_to_send.to_string(),
                                finished: false,
                                usage: None,
                                tool_calls: None,
                            };
                            let _ = tx.blocking_send(Ok(chunk));
                        }
                    }
                    break;
                }

                // Check if we're building towards a stop sequence
                let mut might_be_stop = false;
                for stop_seq in &stop_sequences {
                    for i in 1..stop_seq.len() {
                        let partial = &stop_seq[..i];
                        if accumulated_text.ends_with(partial) {
                            debug!("Detected potential partial stop sequence: '{}'", partial);
                            might_be_stop = true;
                            break;
                        }
                    }
                    if might_be_stop {
                        break;
                    }
                }
                if might_be_stop {
                    // Hold back tokens, but don't hold back more than 20 characters
                    if unsent_tokens.len() > 20 {
                        // Send the oldest part and keep only the recent tail that might
                        // be a stop sequence. Split on a char boundary so a multi-byte
                        // UTF-8 token can't cause a panic.
                        let mut split = unsent_tokens.len() - 10;
                        while !unsent_tokens.is_char_boundary(split) {
                            split -= 1;
                        }
                        let to_send = &unsent_tokens[..split];
                        if !to_send.is_empty() {
                            let chunk = CompletionChunk {
                                content: to_send.to_string(),
                                finished: false,
                                usage: None,
                                tool_calls: None,
                            };
                            if tx.blocking_send(Ok(chunk)).is_err() {
                                break;
                            }
                        }
                        unsent_tokens = unsent_tokens[split..].to_string();
                    }
                    // Continue to next token without sending
                } else {
                    // No potential stop sequence, send all unsent tokens
                    if !unsent_tokens.is_empty() {
                        let chunk = CompletionChunk {
                            content: unsent_tokens.clone(),
                            finished: false,
                            usage: None,
                            tool_calls: None,
                        };
                        if tx.blocking_send(Ok(chunk)).is_err() {
                            break;
                        }
                        unsent_tokens.clear();
                    }
                }

                // Enforce token limit
                if token_count >= max_tokens as usize {
                    debug!("Reached max token limit in streaming: {}", max_tokens);
                    break;
                }
            }

            // Send final chunk
            let final_chunk = CompletionChunk {
                content: String::new(),
                finished: true,
                usage: None, // Embedded models calculate usage differently
                tool_calls: None,
            };
            let _ = tx.blocking_send(Ok(final_chunk));
        });

        Ok(ReceiverStream::new(rx))
    }

    fn name(&self) -> &str {
        "embedded"
    }

    fn model(&self) -> &str {
        &self.model_name
    }

    fn max_tokens(&self) -> u32 {
        self.max_tokens
    }

    fn temperature(&self) -> f32 {
        self.temperature
    }
}
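
For reference, a minimal usage sketch. This is hypothetical, not part of the
commit: it assumes the crate is importable as g3_providers and that
CompletionRequest and Message can be built from just the fields this file
reads (role, content, max_tokens, temperature).

use g3_providers::{CompletionRequest, EmbeddedProvider, LLMProvider, Message, MessageRole};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // "qwen" in model_type selects the ChatML prompt format and stop sequences above
    let provider = EmbeddedProvider::new(
        "~/models/qwen2.5-7b-instruct-q3_k_m.gguf".to_string(), // hypothetical path
        "qwen".to_string(),
        Some(8192), // context_length
        Some(1024), // max_tokens
        Some(0.1),  // temperature
        Some(0),    // gpu_layers (CPU-only)
        None,       // threads (library default)
    )?;

    let request = CompletionRequest {
        messages: vec![Message {
            role: MessageRole::User,
            content: "Summarize what this repository does.".to_string(),
        }],
        max_tokens: None,  // fall back to the provider default (2048)
        temperature: None, // fall back to the provider default (0.1)
    };

    let response = provider.complete(request).await?;
    println!("{} [{} tokens]", response.content, response.usage.total_tokens);
    Ok(())
}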