From 358faa15e027b29158753b98594248db04db6dce Mon Sep 17 00:00:00 2001
From: Dhanji Prasanna
Date: Tue, 9 Sep 2025 15:18:08 +1000
Subject: [PATCH] End tool call

---
 crates/g3-core/src/lib.rs                | 102 ++++++++----
 crates/g3-core/src/providers/embedded.rs | 194 ++++++++++++++++++-----
 2 files changed, 224 insertions(+), 72 deletions(-)

diff --git a/crates/g3-core/src/lib.rs b/crates/g3-core/src/lib.rs
index 164582e..8a31322 100644
--- a/crates/g3-core/src/lib.rs
+++ b/crates/g3-core/src/lib.rs
@@ -53,22 +53,31 @@ impl StreamingToolParser {
         // Look for the start of a tool call pattern: {"tool":
         if !self.in_tool_call {
             // Look for JSON tool call pattern - check both raw JSON and inside code blocks
-            if let Some(pos) = self.buffer.rfind(r#"{"tool":"#) {
-                //info!("Found tool call pattern at position: {}", pos);
+            // Also handle malformed patterns like {"{""tool"":
+            let patterns = [
+                r#"{"tool":"#,     // Normal pattern
+                r#"{"{""tool"":"#, // Malformed pattern with extra brace and doubled quotes
+                r#"{{""tool"":"#,  // Alternative malformed pattern
+            ];
+
+            for pattern in &patterns {
+                if let Some(pos) = self.buffer.rfind(pattern) {
+                    info!("Found tool call pattern '{}' at position: {}", pattern, pos);
 
-                // Check if this is inside a code block
-                let before_pos = &self.buffer[..pos];
-                let code_block_count = before_pos.matches("```").count();
+                    // Check if this is inside a code block
+                    let before_pos = &self.buffer[..pos];
+                    let code_block_count = before_pos.matches("```").count();
 
-                // Accept tool calls both inside and outside code blocks
-                // The LLM might use either format despite our instructions
-                //info!("Starting tool call parsing (code block status: {})", code_block_count % 2 == 1);
-                self.in_tool_call = true;
-                self.tool_start_pos = Some(pos);
-                self.brace_count = 0; // Start counting from 0, we'll count the opening brace in parsing
+                    // Accept tool calls both inside and outside code blocks
+                    // The LLM might use either format despite our instructions
+                    //info!("Starting tool call parsing (code block status: {})", code_block_count % 2 == 1);
+                    self.in_tool_call = true;
+                    self.tool_start_pos = Some(pos);
+                    self.brace_count = 0; // Start counting from 0, we'll count the opening brace in parsing
 
-                // Continue parsing from after the opening brace
-                return self.parse_from_start_pos(pos);
+                    // Continue parsing from after the opening brace
+                    return self.parse_from_start_pos(pos);
+                }
             }
         } else {
             //info!("Already in tool call, continuing parsing");
@@ -100,10 +109,16 @@ impl StreamingToolParser {
                     if current_brace_count == 0 {
                         // Found complete JSON object
                         let end_pos = start_pos + i + 1;
-                        let json_str = &self.buffer[start_pos..end_pos];
+                        let mut json_str = self.buffer[start_pos..end_pos].to_string();
 
-                        if let Ok(tool_call) = serde_json::from_str::<ToolCall>(json_str) {
-                            //info!("Successfully parsed tool call: {:?}", tool_call);
+                        // Clean up malformed JSON patterns
+                        json_str = json_str
+                            .replace(r#"{"{""#, r#"{"#)  // Fix {"{" -> {"
+                            .replace(r#"""}"#, r#""}"#)  // Fix ""} -> "}
+                            .replace(r#"{{""#, r#"{"#);  // Fix {{" -> {"
+
+                        if let Ok(tool_call) = serde_json::from_str::<ToolCall>(&json_str) {
+                            info!("Successfully parsed tool call: {:?}", tool_call);
                             // Reset parser state
                             self.in_tool_call = false;
                             self.tool_start_pos = None;
@@ -111,7 +126,7 @@ impl StreamingToolParser {
 
                             return Some((tool_call, end_pos));
                         } else {
-                            info!("Failed to parse JSON: {}", json_str);
+                            info!("Failed to parse JSON after cleanup: {}", json_str);
                             // Invalid JSON, reset and continue looking
                             self.in_tool_call = false;
                             self.tool_start_pos = None;
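
The new detection logic scans for several opening patterns and then counts braces until the candidate object closes. Reduced to a standalone routine it looks like the sketch below; find_tool_call is a hypothetical helper, not the real StreamingToolParser API, and like the parser above it does not account for braces inside JSON string values:

    // Locate the most recent complete {"tool": ...} object in a streaming
    // buffer, returning the cleaned JSON and the byte offset just past it.
    fn find_tool_call(buffer: &str) -> Option<(String, usize)> {
        // Normal and malformed openings the model has been seen to emit.
        let patterns = [r#"{"tool":"#, r#"{"{""tool"":"#, r#"{{""tool"":"#];

        for pattern in patterns {
            let Some(start) = buffer.rfind(pattern) else { continue };

            // Walk forward, counting braces until the object closes.
            let mut depth = 0usize;
            for (i, ch) in buffer[start..].char_indices() {
                match ch {
                    '{' => depth += 1,
                    '}' => {
                        depth -= 1;
                        if depth == 0 {
                            let end = start + i + 1;
                            // Mirror the patch's quote/brace cleanup before parsing.
                            let cleaned = buffer[start..end]
                                .replace(r#"{"{""#, r#"{"#)
                                .replace(r#"""}"#, r#""}"#)
                                .replace(r#"{{""#, r#"{"#);
                            return Some((cleaned, end));
                        }
                    }
                    _ => {}
                }
            }
        }
        None // No complete object yet; wait for more tokens.
    }

    fn main() {
        let buf = r#"Running it now: {"tool": "shell", "args": {"cmd": "ls"}} done"#;
        let (json, end) = find_tool_call(buf).unwrap();
        assert_eq!(json, r#"{"tool": "shell", "args": {"cmd": "ls"}}"#);
        assert_eq!(end, buf.len() - " done".len());
    }

Using rfind means only the most recent candidate is considered, which matches the parser's behavior of resetting and rescanning after an invalid object.
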
@@ -261,6 +276,7 @@ impl Agent {
             "codellama" => 16384, // CodeLlama supports 16k context
             "llama" => 4096,      // Base Llama models
             "mistral" => 8192,    // Mistral models
+            "qwen" => 32768,      // Qwen2.5 supports 32k context
             _ => 4096,            // Conservative default
         }
     })
@@ -630,28 +646,42 @@ The tool will execute immediately and you'll receive the result to continue with
                     // Found a complete tool call! Stop streaming and execute it
                     let content_before_tool = parser.get_content_before_tool(tool_end_pos);
 
-                    // Display content up to the tool call (excluding the JSON)
+                    // Display content up to the tool call (excluding the JSON and any stop tokens)
                     let display_content = if let Some(json_start) =
                         content_before_tool.rfind(r#"{"tool":"#)
                     {
-                        &content_before_tool[..json_start]
+                        // Only show content before the JSON tool call
+                        content_before_tool[..json_start].trim()
                     } else {
-                        &content_before_tool
+                        // Fallback: clean any stop tokens from the content
+                        content_before_tool.trim()
                    };
 
+                    // Clean stop tokens from display content
+                    let clean_display_content = display_content
+                        .replace("<|im_end|>", "")
+                        .replace("</s>", "")
+                        .replace("[/INST]", "")
+                        .replace("<</SYS>>", "");
+                    let final_display_content = clean_display_content.trim();
+
                     // Safely get the new content to display
-                    let new_content = if current_response.len() <= display_content.len() {
+                    let new_content = if current_response.len() <= final_display_content.len() {
                         // Use char indices to avoid UTF-8 boundary issues
                         let chars_already_shown = current_response.chars().count();
-                        display_content
+                        final_display_content
                             .chars()
                             .skip(chars_already_shown)
                             .collect::<String>()
                     } else {
                         String::new()
                     };
-                    print!("{}", new_content);
-                    io::stdout().flush()?;
+
+                    // Only print if there's actually new content to show
+                    if !new_content.trim().is_empty() {
+                        print!("{}", new_content);
+                        io::stdout().flush()?;
+                    }
 
                     // Execute the tool with formatted output
                     println!(); // New line before tool execution
@@ -724,20 +754,36 @@ The tool will execute immediately and you'll receive the result to continue with
 
                     // Update the request with the new context for next iteration
                     request.messages = self.context_window.conversation_history.clone();
-                    full_response.push_str(display_content);
+                    full_response.push_str(final_display_content);
                     full_response.push_str(&format!(
                         "\n\nTool executed: {} -> {}\n\n",
                         tool_call.tool, tool_result
                     ));
 
+                    // Check if this was a final_output tool call - if so, stop the conversation
+                    if tool_call.tool == "final_output" {
+                        println!(); // New line after final output
+                        let ttft = first_token_time.unwrap_or_else(|| stream_start.elapsed());
+                        return Ok((full_response, ttft));
+                    }
+
                     tool_executed = true;
 
                     // Break out of current stream to start a new one with updated context
                     break;
                 } else {
                     // No tool call detected, continue streaming normally
-                    print!("{}", chunk.content);
-                    io::stdout().flush()?;
-                    current_response.push_str(&chunk.content);
+                    // Filter out stop tokens from the streaming output
+                    let clean_content = chunk.content
+                        .replace("<|im_end|>", "")
+                        .replace("</s>", "")
+                        .replace("[/INST]", "")
+                        .replace("<</SYS>>", "");
+
+                    if !clean_content.is_empty() {
+                        print!("{}", clean_content);
+                        io::stdout().flush()?;
+                        current_response.push_str(&clean_content);
+                    }
                 }
 
                 if chunk.finished {
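
The same four stop tokens are now scrubbed in three places in lib.rs. Factored out, the filter is a fold over replace calls; a minimal sketch, with a hypothetical strip_stop_tokens name and the token list mirroring the patch:

    const STOP_TOKENS: &[&str] = &["<|im_end|>", "</s>", "[/INST]", "<</SYS>>"];

    fn strip_stop_tokens(chunk: &str) -> String {
        STOP_TOKENS
            .iter()
            .fold(chunk.to_string(), |acc, tok| acc.replace(tok, ""))
    }

    fn main() {
        assert_eq!(strip_stop_tokens("done.<|im_end|>"), "done.");
        assert_eq!(strip_stop_tokens("ok</s>"), "ok");
    }

A per-chunk replace cannot catch a stop token split across two streamed chunks; that case is what the unsent_tokens holdback buffer added to embedded.rs below is for.
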
diff --git a/crates/g3-core/src/providers/embedded.rs b/crates/g3-core/src/providers/embedded.rs
index c10cdec..dab87be 100644
--- a/crates/g3-core/src/providers/embedded.rs
+++ b/crates/g3-core/src/providers/embedded.rs
@@ -86,27 +86,51 @@ impl EmbeddedProvider {
     }
 
     fn format_messages(&self, messages: &[Message]) -> String {
-        // Use proper prompt format for CodeLlama
-        let mut formatted = String::new();
+        // Determine the appropriate format based on model type
+        let model_name_lower = self.model_name.to_lowercase();
+
+        if model_name_lower.contains("qwen") {
+            // Qwen format: <|im_start|>role\ncontent<|im_end|>
+            let mut formatted = String::new();
+
+            for message in messages {
+                let role = match message.role {
+                    MessageRole::System => "system",
+                    MessageRole::User => "user",
+                    MessageRole::Assistant => "assistant",
+                };
+
+                formatted.push_str(&format!(
+                    "<|im_start|>{}\n{}<|im_end|>\n",
+                    role, message.content
+                ));
+            }
+
+            // Add the start of assistant response
+            formatted.push_str("<|im_start|>assistant\n");
+            formatted
+        } else {
+            // Use Llama/CodeLlama format for other models
+            let mut formatted = String::new();
 
-        for message in messages {
-            match message.role {
-                MessageRole::System => {
-                    formatted.push_str(&format!(
-                        "[INST] <<SYS>>\n{}\n<</SYS>>\n\n",
-                        message.content
-                    ));
-                }
-                MessageRole::User => {
-                    formatted.push_str(&format!("{} [/INST] ", message.content));
-                }
-                MessageRole::Assistant => {
-                    formatted.push_str(&format!("{} [INST] ", message.content));
+            for message in messages {
+                match message.role {
+                    MessageRole::System => {
+                        formatted.push_str(&format!(
+                            "[INST] <<SYS>>\n{}\n<</SYS>>\n\n",
+                            message.content
+                        ));
+                    }
+                    MessageRole::User => {
+                        formatted.push_str(&format!("{} [/INST] ", message.content));
+                    }
+                    MessageRole::Assistant => {
+                        formatted.push_str(&format!("{} [INST] ", message.content));
+                    }
                 }
             }
+            formatted
         }
-
-        formatted
     }
 
     async fn generate_completion(
@@ -138,10 +162,26 @@ impl EmbeddedProvider {
         let result = tokio::time::timeout(
             timeout_duration,
             tokio::task::spawn_blocking(move || {
-                let mut session = match session.try_lock() {
-                    Ok(ctx) => ctx,
-                    Err(_) => return Err(anyhow::anyhow!("Model is busy, please try again")),
-                };
+                // Retry logic for acquiring the session lock
+                let mut session_guard = None;
+                for attempt in 0..5 {
+                    match session.try_lock() {
+                        Ok(ctx) => {
+                            session_guard = Some(ctx);
+                            break;
+                        }
+                        Err(_) => {
+                            if attempt < 4 {
+                                debug!("Session busy, retrying in {}ms (attempt {}/5)", 100 * (attempt + 1), attempt + 1);
+                                std::thread::sleep(std::time::Duration::from_millis(100 * (attempt + 1) as u64));
+                            } else {
+                                return Err(anyhow::anyhow!("Model is busy after 5 attempts, please try again"));
+                            }
+                        }
+                    }
+                }
+
+                let mut session = session_guard
+                    .ok_or_else(|| anyhow::anyhow!("Failed to acquire session lock"))?;
 
                 debug!(
                     "Starting inference with prompt length: {} chars, estimated {} tokens",
@@ -264,7 +304,14 @@ impl EmbeddedProvider {
         // Determine model type from model_name
         let model_name_lower = self.model_name.to_lowercase();
 
-        if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama") {
+        if model_name_lower.contains("qwen") {
+            vec![
+                "<|im_end|>",    // Qwen ChatML format end token
+                "<|endoftext|>", // Alternative end token
+                "</s>",          // Generic end of sequence
+                "<|im_start|>",  // Start of new message (shouldn't appear in response)
+            ]
+        } else if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama") {
             vec![
                 "</s>",    // End of sequence
                 "[/INST]", // End of instruction
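
Both the blocking completion path above and the streaming path below now use the same bounded try_lock loop with linear backoff. Extracted into a sketch, with a plain std::sync::Mutex standing in for the session handle and a hypothetical lock_with_retry name (anyhow is already a dependency here):

    use std::sync::{Mutex, MutexGuard};
    use std::thread;
    use std::time::Duration;

    // Try the lock up to `attempts` times, sleeping 100ms, 200ms, 300ms, ...
    // between attempts, then give up with an error.
    fn lock_with_retry<T>(m: &Mutex<T>, attempts: u64) -> anyhow::Result<MutexGuard<'_, T>> {
        for attempt in 0..attempts {
            if let Ok(guard) = m.try_lock() {
                return Ok(guard);
            }
            if attempt + 1 < attempts {
                thread::sleep(Duration::from_millis(100 * (attempt + 1)));
            }
        }
        Err(anyhow::anyhow!("model is busy after {attempts} attempts"))
    }

    fn main() -> anyhow::Result<()> {
        let session = Mutex::new("llama session");
        let guard = lock_with_retry(&session, 5)?;
        println!("acquired: {}", *guard);
        Ok(())
    }

With five attempts the loop blocks its thread for at most 100 + 200 + 300 + 400 = 1000ms before failing, which is tolerable inside spawn_blocking.
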
@@ -381,11 +428,30 @@ impl LLMProvider for EmbeddedProvider {
 
         // Spawn streaming task
         tokio::task::spawn_blocking(move || {
-            let mut session = match session.try_lock() {
-                Ok(ctx) => ctx,
-                Err(_) => {
-                    let _ =
-                        tx.blocking_send(Err(anyhow::anyhow!("Model is busy, please try again")));
+            // Retry logic for acquiring the session lock
+            let mut session_guard = None;
+            for attempt in 0..5 {
+                match session.try_lock() {
+                    Ok(ctx) => {
+                        session_guard = Some(ctx);
+                        break;
+                    }
+                    Err(_) => {
+                        if attempt < 4 {
+                            debug!("Session busy, retrying in {}ms (attempt {}/5)", 100 * (attempt + 1), attempt + 1);
+                            std::thread::sleep(std::time::Duration::from_millis(100 * (attempt + 1) as u64));
+                        } else {
+                            let _ = tx.blocking_send(Err(anyhow::anyhow!("Model is busy after 5 attempts, please try again")));
+                            return;
+                        }
+                    }
+                }
+            }
+
+            let mut session = match session_guard {
+                Some(ctx) => ctx,
+                None => {
+                    let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to acquire session lock")));
                     return;
                 }
             };
@@ -418,11 +484,13 @@ impl LLMProvider for EmbeddedProvider {
 
         let mut accumulated_text = String::new();
         let mut token_count = 0;
+        let mut unsent_tokens = String::new(); // Buffer for tokens we're holding back
 
         // Get stop sequences dynamically based on model type
-        // We need to create a temporary EmbeddedProvider instance to access the method
-        // Since we can't access self in the spawned task, we'll use a static approach
-        let stop_sequences = if prompt.contains("[INST]") || prompt.contains("<<SYS>>") {
+        let stop_sequences = if prompt.contains("<|im_start|>") {
+            // Qwen ChatML format detected
+            vec!["<|im_end|>", "<|endoftext|>", "</s>", "<|im_start|>"]
+        } else if prompt.contains("[INST]") || prompt.contains("<<SYS>>") {
             // Llama/CodeLlama format detected
             vec!["</s>", "[/INST]", "<</SYS>>", "[INST]", "<<SYS>>", "### Human:", "### Assistant:"]
         } else {
@@ -435,21 +503,21 @@ impl LLMProvider for EmbeddedProvider {
                     let token_string = session.model().token_to_piece(token);
 
                     accumulated_text.push_str(&token_string);
+                    unsent_tokens.push_str(&token_string);
                     token_count += 1;
 
-                    // Check if we've hit a stop sequence
+                    // Check if we've hit a complete stop sequence
                     let mut hit_stop = false;
                     for stop_seq in &stop_sequences {
                         if accumulated_text.contains(stop_seq) {
-                            debug!("Hit stop sequence in streaming: {}", stop_seq);
+                            debug!("Hit complete stop sequence in streaming: {}", stop_seq);
                             hit_stop = true;
                             break;
                         }
                     }
 
                     if hit_stop {
-                        // Don't send the token that contains the stop sequence
-                        // Instead, send only the part before the stop sequence
+                        // Send any remaining clean content before stopping
                         let mut clean_accumulated = accumulated_text.clone();
                         for stop_seq in &stop_sequences {
                             if let Some(pos) = clean_accumulated.find(stop_seq) {
@@ -459,7 +527,7 @@ impl LLMProvider for EmbeddedProvider {
                         }
 
                         // Calculate what part we haven't sent yet
-                        let already_sent_len = accumulated_text.len() - token_string.len();
+                        let already_sent_len = accumulated_text.len() - unsent_tokens.len();
                         if clean_accumulated.len() > already_sent_len {
                             let remaining_to_send = &clean_accumulated[already_sent_len..];
                             if !remaining_to_send.is_empty() {
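
The holdback decision below hinges on a partial-match test: does the accumulated text end with any proper prefix of any stop sequence, such that the next token could complete one? As a standalone predicate (hypothetical name):

    fn ends_with_partial(accumulated: &str, stop_seqs: &[&str]) -> bool {
        stop_seqs
            .iter()
            .any(|seq| (1..seq.len()).any(|i| accumulated.ends_with(&seq[..i])))
    }

    fn main() {
        let stops = ["<|im_end|>"];
        assert!(ends_with_partial("hello <|im_", &stops)); // may become <|im_end|>
        assert!(!ends_with_partial("hello world", &stops)); // safe to flush
    }

The cost is proportional to the total length of the stop sequences per token, which is negligible for a handful of short markers.
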
@@ -472,16 +540,54 @@ impl LLMProvider for EmbeddedProvider {
                             }
                         }
                         break;
-                    } else {
-                        // Normal token, send it
-                        let chunk = CompletionChunk {
-                            content: token_string.clone(),
-                            finished: false,
-                            tool_calls: None,
-                        };
+                    }
 
-                        if tx.blocking_send(Ok(chunk)).is_err() {
-                            break; // Receiver dropped
+                    // Check if we're building towards a stop sequence
+                    let mut might_be_stop = false;
+                    for stop_seq in &stop_sequences {
+                        for i in 1..stop_seq.len() {
+                            let partial = &stop_seq[..i];
+                            if accumulated_text.ends_with(partial) {
+                                debug!("Detected potential partial stop sequence: '{}'", partial);
+                                might_be_stop = true;
+                                break;
+                            }
+                        }
+                        if might_be_stop {
+                            break;
+                        }
+                    }
+
+                    if might_be_stop {
+                        // Hold back tokens, but only for a limited buffer size
+                        if unsent_tokens.len() > 20 { // Don't hold back more than 20 characters
+                            // Send the oldest part and keep only the recent part that might be a stop sequence
+                            let to_send = &unsent_tokens[..unsent_tokens.len() - 10];
+                            if !to_send.is_empty() {
+                                let chunk = CompletionChunk {
+                                    content: to_send.to_string(),
+                                    finished: false,
+                                    tool_calls: None,
+                                };
+                                if tx.blocking_send(Ok(chunk)).is_err() {
+                                    break;
+                                }
+                            }
+                            unsent_tokens = unsent_tokens[unsent_tokens.len() - 10..].to_string();
+                        }
+                        // Continue to next token without sending
+                    } else {
+                        // No potential stop sequence, send all unsent tokens
+                        if !unsent_tokens.is_empty() {
+                            let chunk = CompletionChunk {
+                                content: unsent_tokens.clone(),
+                                finished: false,
+                                tool_calls: None,
+                            };
+                            if tx.blocking_send(Ok(chunk)).is_err() {
+                                break;
+                            }
+                            unsent_tokens.clear();
+                        }
                     }
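
One caveat in the 20/10 trim above: unsent_tokens is sliced at fixed byte offsets, so a cut that lands inside a multi-byte UTF-8 character will panic, the same boundary hazard the lib.rs changes avoid with char indices. A boundary-safe variant might look like the following sketch (split_off_safe is a hypothetical helper, not part of the patch):

    // Flush everything except roughly the last `keep_last` bytes, never
    // cutting inside a multi-byte character.
    fn split_off_safe(buf: &mut String, keep_last: usize) -> String {
        let mut cut = buf.len().saturating_sub(keep_last);
        // Nudge the cut forward to the next char boundary if needed.
        while cut < buf.len() && !buf.is_char_boundary(cut) {
            cut += 1;
        }
        let tail = buf.split_off(cut); // buf now holds only the head
        std::mem::replace(buf, tail) // keep the tail, return the flushed head
    }

    fn main() {
        let mut unsent = String::from("héllo wörld…"); // multi-byte characters
        let flushed = split_off_safe(&mut unsent, 4);
        println!("flush {flushed:?}, hold {unsent:?}");
    }
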