decent basic toolshim

This commit is contained in:
Dhanji Prasanna
2025-09-15 08:18:31 +10:00
parent 358faa15e0
commit 0dc3b42f38
2 changed files with 240 additions and 38 deletions

View File

@@ -7,7 +7,7 @@ use std::fs;
use std::path::Path; use std::path::Path;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn, debug}; use tracing::{debug, error, info, warn};
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall { pub struct ToolCall {
@@ -118,6 +118,7 @@ impl StreamingToolParser {
.replace(r#"{{""#, r#"{"#) // Fix {{" -> {" .replace(r#"{{""#, r#"{"#) // Fix {{" -> {"
.replace(r#"""}"#, r#""}"#); // Fix ""} -> "} .replace(r#"""}"#, r#""}"#); // Fix ""} -> "}
// First, try to parse the JSON as-is
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(&json_str) { if let Ok(tool_call) = serde_json::from_str::<ToolCall>(&json_str) {
info!("Successfully parsed tool call: {:?}", tool_call); info!("Successfully parsed tool call: {:?}", tool_call);
// Reset parser state // Reset parser state
@@ -125,16 +126,52 @@ impl StreamingToolParser {
self.tool_start_pos = None; self.tool_start_pos = None;
self.brace_count = 0; self.brace_count = 0;
return Some((tool_call, end_pos));
}
// If parsing failed and this is a shell command, try to fix nested quotes
if json_str.contains(r#""tool": "shell""#) {
let fixed_json = fix_nested_quotes_in_shell_command(&json_str);
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(&fixed_json) {
info!("Successfully parsed tool call after fixing nested quotes: {:?}", tool_call);
// Reset parser state
self.in_tool_call = false;
self.tool_start_pos = None;
self.brace_count = 0;
return Some((tool_call, end_pos)); return Some((tool_call, end_pos));
} else { } else {
info!("Failed to parse JSON after cleanup: {}", json_str); info!(
"Failed to parse JSON even after fixing nested quotes: {}",
fixed_json
);
}
}
// Try to fix mixed quote issues (single quotes in JSON)
let fixed_mixed_quotes = fix_mixed_quotes_in_json(&json_str);
if fixed_mixed_quotes != json_str {
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(&fixed_mixed_quotes) {
info!("Successfully parsed tool call after fixing mixed quotes: {:?}", tool_call);
// Reset parser state
self.in_tool_call = false;
self.tool_start_pos = None;
self.brace_count = 0;
return Some((tool_call, end_pos));
} else {
info!("Failed to parse JSON even after fixing mixed quotes: {}", fixed_mixed_quotes);
}
} else {
info!("Failed to parse JSON (no fixes applied): {}", json_str);
}
// Invalid JSON, reset and continue looking // Invalid JSON, reset and continue looking
self.in_tool_call = false; self.in_tool_call = false;
self.tool_start_pos = None; self.tool_start_pos = None;
self.brace_count = 0; self.brace_count = 0;
} }
} }
}
_ => {} _ => {}
} }
} }
@@ -245,7 +282,10 @@ impl Agent {
} }
// Set default provider // Set default provider
debug!("Setting default provider to: {}", config.providers.default_provider); debug!(
"Setting default provider to: {}",
config.providers.default_provider
);
providers.set_default(&config.providers.default_provider)?; providers.set_default(&config.providers.default_provider)?;
debug!("Default provider set successfully"); debug!("Default provider set successfully");
@@ -394,7 +434,7 @@ When you need to execute a tool, write ONLY the JSON tool call on a new line:
{{\"tool\": \"tool_name\", \"args\": {{\"param\": \"value\"}}}} {{\"tool\": \"tool_name\", \"args\": {{\"param\": \"value\"}}}}
The tool will execute immediately and you'll receive the result to continue with. The tool will execute immediately and you'll receive the result (success or error) to continue with.
# Available Tools # Available Tools
@@ -410,12 +450,12 @@ The tool will execute immediately and you'll receive the result to continue with
1. Analyze the request and break down into smaller tasks if appropriate 1. Analyze the request and break down into smaller tasks if appropriate
2. Execute ONE tool at a time 2. Execute ONE tool at a time
3. STOP when the original request was satisfied 3. STOP when the original request was satisfied
4. End with final_output when done 4. Call the final_output tool when done
# Response Guidelines # Response Guidelines
- Use Markdown formatting for all responses except tool calls. - Use Markdown formatting for all responses except tool calls.
- Whenever calling tools, use the pronoun 'I' - Whenever taking actions, use the pronoun 'I'
"); ");
@@ -639,6 +679,7 @@ The tool will execute immediately and you'll receive the result to continue with
} }
// If no native tool calls, check for JSON tool calls in text (embedded models) // If no native tool calls, check for JSON tool calls in text (embedded models)
// IMPORTANT: Pass raw content to parser BEFORE cleaning stop sequences
if detected_tool_call.is_none() { if detected_tool_call.is_none() {
detected_tool_call = parser.add_chunk(&chunk.content); detected_tool_call = parser.add_chunk(&chunk.content);
} }
@@ -667,7 +708,8 @@ The tool will execute immediately and you'll receive the result to continue with
let final_display_content = clean_display_content.trim(); let final_display_content = clean_display_content.trim();
// Safely get the new content to display // Safely get the new content to display
let new_content = if current_response.len() <= final_display_content.len() { let new_content =
if current_response.len() <= final_display_content.len() {
// Use char indices to avoid UTF-8 boundary issues // Use char indices to avoid UTF-8 boundary issues
let chars_already_shown = current_response.chars().count(); let chars_already_shown = current_response.chars().count();
final_display_content final_display_content
@@ -764,7 +806,8 @@ The tool will execute immediately and you'll receive the result to continue with
// Check if this was a final_output tool call - if so, stop the conversation // Check if this was a final_output tool call - if so, stop the conversation
if tool_call.tool == "final_output" { if tool_call.tool == "final_output" {
println!(); // New line after final output println!(); // New line after final output
let ttft = first_token_time.unwrap_or_else(|| stream_start.elapsed()); let ttft =
first_token_time.unwrap_or_else(|| stream_start.elapsed());
return Ok((full_response, ttft)); return Ok((full_response, ttft));
} }
@@ -774,7 +817,8 @@ The tool will execute immediately and you'll receive the result to continue with
} else { } else {
// No tool call detected, continue streaming normally // No tool call detected, continue streaming normally
// Filter out stop tokens from the streaming output // Filter out stop tokens from the streaming output
let clean_content = chunk.content let clean_content = chunk
.content
.replace("<|im_end|>", "") .replace("<|im_end|>", "")
.replace("</s>", "") .replace("</s>", "")
.replace("[/INST]", "") .replace("[/INST]", "")
@@ -978,6 +1022,118 @@ fn shell_escape_command(command: &str) -> String {
} }
} }
/// Repairs unescaped double quotes inside the `"command"` value of a shell tool call.
///
/// Models sometimes emit a shell command whose nested quotes are not escaped, which
/// makes the surrounding JSON unparseable:
///
/// ```text
/// {"tool": "shell", "args": {"command": "python -c 'import os; print("hello")'"}}
/// ```
///
/// The command value is taken to start right after the first `"command": "` and to
/// end at the LAST `"` that is followed (ignoring whitespace) by a closing `}`.
/// Inside that span:
///
/// * a bare `"` becomes `\"`;
/// * an already escaped `\"` is kept as-is;
/// * any other backslash is doubled.
///
/// Limitation: this is a heuristic that assumes `command` is the last string value
/// of its object; anything between the opening quote and that final `"` is treated
/// as command text.
///
/// Returns the input unchanged when no command value can be located.
fn fix_nested_quotes_in_shell_command(json_str: &str) -> String {
    const COMMAND_PREFIX: &str = r#""command": ""#;

    let command_value_start = match json_str.find(COMMAND_PREFIX) {
        Some(prefix_pos) => prefix_pos + COMMAND_PREFIX.len(),
        None => return json_str.to_string(),
    };
    let tail = &json_str[command_value_start..];

    // Scan quotes from the right: the closing quote of the command is the last one
    // directly followed by `}` (optionally separated by whitespace). Searching from
    // the right keeps quotes nested inside the command from being mistaken for it,
    // and tolerating whitespace accepts both `"}` and `" }`.
    let closing_quote = tail
        .rmatch_indices('"')
        .map(|(idx, _)| idx)
        .find(|&idx| tail[idx + 1..].trim_start().starts_with('}'));
    let command_len = match closing_quote {
        Some(idx) => idx,
        None => return json_str.to_string(),
    };

    let before = &json_str[..command_value_start];
    let (command_content, after) = tail.split_at(command_len);

    let mut fixed_command = String::with_capacity(command_content.len() + 8);
    let mut chars = command_content.chars().peekable();
    while let Some(ch) = chars.next() {
        match ch {
            // The backslash of an existing `\"` was just emitted: keep the quote bare.
            '"' if fixed_command.ends_with('\\') => fixed_command.push('"'),
            // Unescaped nested quote: escape it.
            '"' => fixed_command.push_str(r#"\""#),
            // Backslash that already escapes a quote: keep it single.
            '\\' if chars.peek() == Some(&'"') => fixed_command.push('\\'),
            // Any other backslash (including a trailing one) is doubled.
            '\\' => fixed_command.push_str(r#"\\"#),
            other => fixed_command.push(other),
        }
    }

    format!("{}{}{}", before, fixed_command, after)
}
/// Rewrites single-quoted strings in JSON-like text as double-quoted JSON strings.
///
/// Handles model output such as
///
/// ```text
/// {"tool": "shell", "args": {"command": 'echo "hello"'}}
/// ```
///
/// Rules applied while scanning the input once:
///
/// * outside a string, `'` or `"` opens a string and is emitted as `"`;
/// * the matching delimiter closes the string and is emitted as `"`;
/// * a `"` inside a single-quoted string is escaped as `\"`;
/// * `\'` inside a single-quoted string becomes a plain `'`, because `\'` is not a
///   valid JSON escape and is unnecessary once the string is double-quoted;
/// * every other backslash escape inside a string is copied through together with
///   the character it escapes;
/// * everything else is copied verbatim.
///
/// Text that is already valid JSON comes back unchanged, so callers can compare the
/// result with the input to see whether anything was repaired.
fn fix_mixed_quotes_in_json(json_str: &str) -> String {
    let mut result = String::with_capacity(json_str.len());
    let mut chars = json_str.chars();
    // `Some(delimiter)` while inside a string literal, `None` otherwise.
    let mut open_delimiter: Option<char> = None;

    while let Some(ch) = chars.next() {
        match (open_delimiter, ch) {
            // Start of a string: always emit a JSON double quote.
            (None, '"') | (None, '\'') => {
                open_delimiter = Some(ch);
                result.push('"');
            }
            // Matching delimiter: end of the current string.
            (Some(delimiter), c) if c == delimiter => {
                open_delimiter = None;
                result.push('"');
            }
            // Double quote inside a single-quoted string must be escaped.
            (Some('\''), '"') => result.push_str(r#"\""#),
            // Escape sequence: consume the escaped character too, so an escaped
            // delimiter does not terminate the string.
            (Some(delimiter), '\\') => match chars.next() {
                Some('\'') if delimiter == '\'' => result.push('\''),
                Some(escaped) => {
                    result.push('\\');
                    result.push(escaped);
                }
                None => result.push('\\'),
            },
            (_, other) => result.push(other),
        }
    }

    result
}
pub mod providers { pub mod providers {
pub mod anthropic; pub mod anthropic;
pub mod embedded; pub mod embedded;

View File

@@ -517,7 +517,53 @@ impl LLMProvider for EmbeddedProvider {
} }
if hit_stop { if hit_stop {
// Send any remaining clean content before stopping // Before stopping, check if there might be an incomplete tool call
// Look for JSON tool call patterns that might be cut off by the stop sequence
let has_potential_tool_call = accumulated_text.contains(r#"{"tool":"#) ||
accumulated_text.contains(r#"{"{""tool"":"#) ||
accumulated_text.contains(r#"{{""tool"":"#);
if has_potential_tool_call {
// Check if the tool call appears to be complete (balanced braces before the stop sequence)
let mut complete_tool_call = false;
for stop_seq in &stop_sequences {
if let Some(stop_pos) = accumulated_text.find(stop_seq) {
// Look for tool call pattern before the stop sequence
let before_stop = &accumulated_text[..stop_pos];
if let Some(tool_start) = before_stop.rfind(r#"{"tool":"#) {
let tool_part = &before_stop[tool_start..];
// Count braces to see if JSON is complete
let open_braces = tool_part.matches('{').count();
let close_braces = tool_part.matches('}').count();
if open_braces > 0 && open_braces == close_braces {
complete_tool_call = true;
break;
}
}
}
}
// If tool call is incomplete, send the raw content including stop sequences
// so the main parser can handle it properly
if !complete_tool_call {
debug!("Found incomplete tool call, sending raw content with stop sequences");
let already_sent_len = accumulated_text.len() - unsent_tokens.len();
if accumulated_text.len() > already_sent_len {
let remaining_to_send = &accumulated_text[already_sent_len..];
if !remaining_to_send.is_empty() {
let chunk = CompletionChunk {
content: remaining_to_send.to_string(),
finished: false,
tool_calls: None,
};
let _ = tx.blocking_send(Ok(chunk));
}
}
break;
}
}
// Send any remaining clean content before stopping (original behavior)
let mut clean_accumulated = accumulated_text.clone(); let mut clean_accumulated = accumulated_text.clone();
for stop_seq in &stop_sequences { for stop_seq in &stop_sequences {
if let Some(pos) = clean_accumulated.find(stop_seq) { if let Some(pos) = clean_accumulated.find(stop_seq) {