From 0dc3b42f3868e61b3f65fad40cf2ec8ecde022ca Mon Sep 17 00:00:00 2001 From: Dhanji Prasanna Date: Mon, 15 Sep 2025 08:18:31 +1000 Subject: [PATCH] decent basic toolshim --- crates/g3-core/src/lib.rs | 230 +++++++++++++++++++---- crates/g3-core/src/providers/embedded.rs | 48 ++++- 2 files changed, 240 insertions(+), 38 deletions(-) diff --git a/crates/g3-core/src/lib.rs b/crates/g3-core/src/lib.rs index 8a31322..d1a4b93 100644 --- a/crates/g3-core/src/lib.rs +++ b/crates/g3-core/src/lib.rs @@ -7,7 +7,7 @@ use std::fs; use std::path::Path; use std::time::{Duration, Instant}; use tokio_util::sync::CancellationToken; -use tracing::{error, info, warn, debug}; +use tracing::{debug, error, info, warn}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolCall { @@ -55,11 +55,11 @@ impl StreamingToolParser { // Look for JSON tool call pattern - check both raw JSON and inside code blocks // Also handle malformed patterns like {"{""tool"": let patterns = [ - r#"{"tool":"#, // Normal pattern - r#"{"{""tool"":"#, // Malformed pattern with extra brace and doubled quotes - r#"{{""tool"":"#, // Alternative malformed pattern + r#"{"tool":"#, // Normal pattern + r#"{"{""tool"":"#, // Malformed pattern with extra brace and doubled quotes + r#"{{""tool"":"#, // Alternative malformed pattern ]; - + for pattern in &patterns { if let Some(pos) = self.buffer.rfind(pattern) { info!("Found tool call pattern '{}' at position: {}", pattern, pos); @@ -113,11 +113,12 @@ impl StreamingToolParser { // Clean up malformed JSON patterns json_str = json_str - .replace(r#"{"{""#, r#"{"#) // Fix {"{" -> {" - .replace(r#"""}"#, r#""}"#) // Fix ""} -> "} - .replace(r#"{{""#, r#"{"#) // Fix {{" -> {" - .replace(r#"""}"#, r#""}"#); // Fix ""} -> "} + .replace(r#"{"{""#, r#"{"#) // Fix {"{" -> {" + .replace(r#"""}"#, r#""}"#) // Fix ""} -> "} + .replace(r#"{{""#, r#"{"#) // Fix {{" -> {" + .replace(r#"""}"#, r#""}"#); // Fix ""} -> "} + // First, try to parse the JSON as-is if let Ok(tool_call) = serde_json::from_str::(&json_str) { info!("Successfully parsed tool call: {:?}", tool_call); // Reset parser state @@ -126,13 +127,49 @@ impl StreamingToolParser { self.brace_count = 0; return Some((tool_call, end_pos)); - } else { - info!("Failed to parse JSON after cleanup: {}", json_str); - // Invalid JSON, reset and continue looking - self.in_tool_call = false; - self.tool_start_pos = None; - self.brace_count = 0; } + + // If parsing failed and this is a shell command, try to fix nested quotes + if json_str.contains(r#""tool": "shell""#) { + let fixed_json = fix_nested_quotes_in_shell_command(&json_str); + if let Ok(tool_call) = serde_json::from_str::(&fixed_json) { + info!("Successfully parsed tool call after fixing nested quotes: {:?}", tool_call); + // Reset parser state + self.in_tool_call = false; + self.tool_start_pos = None; + self.brace_count = 0; + + return Some((tool_call, end_pos)); + } else { + info!( + "Failed to parse JSON even after fixing nested quotes: {}", + fixed_json + ); + } + } + + // Try to fix mixed quote issues (single quotes in JSON) + let fixed_mixed_quotes = fix_mixed_quotes_in_json(&json_str); + if fixed_mixed_quotes != json_str { + if let Ok(tool_call) = serde_json::from_str::(&fixed_mixed_quotes) { + info!("Successfully parsed tool call after fixing mixed quotes: {:?}", tool_call); + // Reset parser state + self.in_tool_call = false; + self.tool_start_pos = None; + self.brace_count = 0; + + return Some((tool_call, end_pos)); + } else { + info!("Failed to parse JSON even after fixing mixed quotes: {}", fixed_mixed_quotes); + } + } else { + info!("Failed to parse JSON (no fixes applied): {}", json_str); + } + + // Invalid JSON, reset and continue looking + self.in_tool_call = false; + self.tool_start_pos = None; + self.brace_count = 0; } } _ => {} @@ -245,7 +282,10 @@ impl Agent { } // Set default provider - debug!("Setting default provider to: {}", config.providers.default_provider); + debug!( + "Setting default provider to: {}", + config.providers.default_provider + ); providers.set_default(&config.providers.default_provider)?; debug!("Default provider set successfully"); @@ -394,7 +434,7 @@ When you need to execute a tool, write ONLY the JSON tool call on a new line: {{\"tool\": \"tool_name\", \"args\": {{\"param\": \"value\"}}}} -The tool will execute immediately and you'll receive the result to continue with. +The tool will execute immediately and you'll receive the result (success or error) to continue with. # Available Tools @@ -410,12 +450,12 @@ The tool will execute immediately and you'll receive the result to continue with 1. Analyze the request and break down into smaller tasks if appropriate 2. Execute ONE tool at a time 3. STOP when the original request was satisfied -4. End with final_output when done +4. Call the final_output tool when done # Response Guidelines - Use Markdown formatting for all responses except tool calls. -- Whenever calling tools, use the pronoun 'I' +- Whenever taking actions, use the pronoun 'I' "); @@ -616,10 +656,10 @@ The tool will execute immediately and you'll receive the result to continue with first_token_time = Some(stream_start.elapsed()); } - // Check for tool calls - either from JSON parsing (embedded models) + // Check for tool calls - either from JSON parsing (embedded models) // or from native tool calls (Anthropic, OpenAI, etc.) let mut detected_tool_call = None; - + // First check for native tool calls in the chunk if let Some(ref tool_calls) = chunk.tool_calls { debug!("Found native tool calls in chunk: {:?}", tool_calls); @@ -637,12 +677,13 @@ The tool will execute immediately and you'll receive the result to continue with } else { debug!("No native tool calls in chunk, chunk.tool_calls is None"); } - + // If no native tool calls, check for JSON tool calls in text (embedded models) + // IMPORTANT: Pass raw content to parser BEFORE cleaning stop sequences if detected_tool_call.is_none() { detected_tool_call = parser.add_chunk(&chunk.content); } - + if let Some((tool_call, tool_end_pos)) = detected_tool_call { // Found a complete tool call! Stop streaming and execute it let content_before_tool = parser.get_content_before_tool(tool_end_pos); @@ -667,17 +708,18 @@ The tool will execute immediately and you'll receive the result to continue with let final_display_content = clean_display_content.trim(); // Safely get the new content to display - let new_content = if current_response.len() <= final_display_content.len() { - // Use char indices to avoid UTF-8 boundary issues - let chars_already_shown = current_response.chars().count(); - final_display_content - .chars() - .skip(chars_already_shown) - .collect::() - } else { - String::new() - }; - + let new_content = + if current_response.len() <= final_display_content.len() { + // Use char indices to avoid UTF-8 boundary issues + let chars_already_shown = current_response.chars().count(); + final_display_content + .chars() + .skip(chars_already_shown) + .collect::() + } else { + String::new() + }; + // Only print if there's actually new content to show if !new_content.trim().is_empty() { print!("{}", new_content); @@ -764,7 +806,8 @@ The tool will execute immediately and you'll receive the result to continue with // Check if this was a final_output tool call - if so, stop the conversation if tool_call.tool == "final_output" { println!(); // New line after final output - let ttft = first_token_time.unwrap_or_else(|| stream_start.elapsed()); + let ttft = + first_token_time.unwrap_or_else(|| stream_start.elapsed()); return Ok((full_response, ttft)); } @@ -774,12 +817,13 @@ The tool will execute immediately and you'll receive the result to continue with } else { // No tool call detected, continue streaming normally // Filter out stop tokens from the streaming output - let clean_content = chunk.content + let clean_content = chunk + .content .replace("<|im_end|>", "") .replace("", "") .replace("[/INST]", "") .replace("<>", ""); - + if !clean_content.is_empty() { print!("{}", clean_content); io::stdout().flush()?; @@ -978,6 +1022,118 @@ fn shell_escape_command(command: &str) -> String { } } +// Helper function to fix nested quotes in shell command JSON +fn fix_nested_quotes_in_shell_command(json_str: &str) -> String { + // This handles cases where shell commands contain nested quotes that break JSON parsing + // Example: {"tool": "shell", "args": {"command": "python -c 'import os; print("hello")'"}} + + // Look for the pattern: "command": " + if let Some(command_start) = json_str.find(r#""command": ""#) { + let command_value_start = command_start + r#""command": ""#.len(); + + // Find the end of the command string by looking for the pattern "} + // We need to be careful about nested quotes + if let Some(end_marker) = json_str[command_value_start..].find(r#"" }"#) { + let command_end = command_value_start + end_marker; + + let before = &json_str[..command_value_start]; + let command_content = &json_str[command_value_start..command_end]; + let after = &json_str[command_end..]; + + // Fix the command content by properly escaping quotes + let mut fixed_command = String::new(); + let mut chars = command_content.chars().peekable(); + + while let Some(ch) = chars.next() { + match ch { + '"' => { + // Check if this quote is already escaped + if fixed_command.ends_with('\\') { + fixed_command.push(ch); // Already escaped, keep as-is + } else { + fixed_command.push_str(r#"\""#); // Escape the quote + } + } + '\\' => { + // Check what follows the backslash + if let Some(&next_ch) = chars.peek() { + if next_ch == '"' { + // This is an escaped quote, keep the backslash + fixed_command.push(ch); + } else { + // Regular backslash, escape it + fixed_command.push_str(r#"\\"#); + } + } else { + // Backslash at end, escape it + fixed_command.push_str(r#"\\"#); + } + } + _ => fixed_command.push(ch), + } + } + + return format!("{}{}{}", before, fixed_command, after); + } + } + + // Fallback: if we can't parse the structure, try some basic replacements + json_str.to_string() +} + +// Helper function to fix mixed quotes in JSON (single quotes where double quotes should be) +fn fix_mixed_quotes_in_json(json_str: &str) -> String { + // This handles cases where the LLM uses single quotes in JSON values + // Example: {"tool": "shell", "args": {"command": 'echo "hello"'}} + + let mut result = String::new(); + let mut chars = json_str.chars().peekable(); + let mut in_string = false; + let mut string_delimiter = '"'; + + while let Some(ch) = chars.next() { + match ch { + '"' if !in_string => { + // Start of a double-quoted string + in_string = true; + string_delimiter = '"'; + result.push(ch); + } + '\'' if !in_string => { + // Start of a single-quoted string - convert to double quotes + in_string = true; + string_delimiter = '\''; + result.push('"'); // Convert single quote to double quote + } + c if in_string && c == string_delimiter => { + // End of current string + if string_delimiter == '\'' { + result.push('"'); // Convert single quote to double quote + } else { + result.push(c); + } + in_string = false; + } + '"' if in_string && string_delimiter == '\'' => { + // Double quote inside single-quoted string - escape it + result.push_str(r#"\""#); + } + '\\' if in_string => { + // Escape sequence - preserve it + result.push(ch); + if let Some(&next_ch) = chars.peek() { + result.push(chars.next().unwrap()); + } + } + _ => { + result.push(ch); + } + } + } + + result +} + pub mod providers { pub mod anthropic; pub mod embedded; diff --git a/crates/g3-core/src/providers/embedded.rs b/crates/g3-core/src/providers/embedded.rs index dab87be..22f9972 100644 --- a/crates/g3-core/src/providers/embedded.rs +++ b/crates/g3-core/src/providers/embedded.rs @@ -517,7 +517,53 @@ impl LLMProvider for EmbeddedProvider { } if hit_stop { - // Send any remaining clean content before stopping + // Before stopping, check if there might be an incomplete tool call + // Look for JSON tool call patterns that might be cut off by the stop sequence + let has_potential_tool_call = accumulated_text.contains(r#"{"tool":"#) || + accumulated_text.contains(r#"{"{""tool"":"#) || + accumulated_text.contains(r#"{{""tool"":"#); + + if has_potential_tool_call { + // Check if the tool call appears to be complete (has closing brace after the stop sequence) + let mut complete_tool_call = false; + for stop_seq in &stop_sequences { + if let Some(stop_pos) = accumulated_text.find(stop_seq) { + // Look for tool call pattern before the stop sequence + let before_stop = &accumulated_text[..stop_pos]; + if let Some(tool_start) = before_stop.rfind(r#"{"tool":"#) { + let tool_part = &before_stop[tool_start..]; + // Count braces to see if JSON is complete + let open_braces = tool_part.matches('{').count(); + let close_braces = tool_part.matches('}').count(); + if open_braces > 0 && open_braces == close_braces { + complete_tool_call = true; + break; + } + } + } + } + + // If tool call is incomplete, send the raw content including stop sequences + // so the main parser can handle it properly + if !complete_tool_call { + debug!("Found incomplete tool call, sending raw content with stop sequences"); + let already_sent_len = accumulated_text.len() - unsent_tokens.len(); + if accumulated_text.len() > already_sent_len { + let remaining_to_send = &accumulated_text[already_sent_len..]; + if !remaining_to_send.is_empty() { + let chunk = CompletionChunk { + content: remaining_to_send.to_string(), + finished: false, + tool_calls: None, + }; + let _ = tx.blocking_send(Ok(chunk)); + } + } + break; + } + } + + // Send any remaining clean content before stopping (original behavior) let mut clean_accumulated = accumulated_text.clone(); for stop_seq in &stop_sequences { if let Some(pos) = clean_accumulated.find(stop_seq) {