decent basic toolshim

This commit is contained in:
Dhanji Prasanna
2025-09-15 08:18:31 +10:00
parent 358faa15e0
commit 0dc3b42f38
2 changed files with 240 additions and 38 deletions

View File

@@ -7,7 +7,7 @@ use std::fs;
use std::path::Path;
use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn, debug};
use tracing::{debug, error, info, warn};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
@@ -118,6 +118,7 @@ impl StreamingToolParser {
.replace(r#"{{""#, r#"{"#) // Fix {{" -> {"
.replace(r#"""}"#, r#""}"#); // Fix ""} -> "}
// First, try to parse the JSON as-is
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(&json_str) {
info!("Successfully parsed tool call: {:?}", tool_call);
// Reset parser state
@@ -125,16 +126,52 @@ impl StreamingToolParser {
self.tool_start_pos = None;
self.brace_count = 0;
return Some((tool_call, end_pos));
}
// If parsing failed and this is a shell command, try to fix nested quotes
if json_str.contains(r#""tool": "shell""#) {
let fixed_json = fix_nested_quotes_in_shell_command(&json_str);
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(&fixed_json) {
info!("Successfully parsed tool call after fixing nested quotes: {:?}", tool_call);
// Reset parser state
self.in_tool_call = false;
self.tool_start_pos = None;
self.brace_count = 0;
return Some((tool_call, end_pos));
} else {
info!("Failed to parse JSON after cleanup: {}", json_str);
info!(
"Failed to parse JSON even after fixing nested quotes: {}",
fixed_json
);
}
}
// Try to fix mixed quote issues (single quotes in JSON)
let fixed_mixed_quotes = fix_mixed_quotes_in_json(&json_str);
if fixed_mixed_quotes != json_str {
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(&fixed_mixed_quotes) {
info!("Successfully parsed tool call after fixing mixed quotes: {:?}", tool_call);
// Reset parser state
self.in_tool_call = false;
self.tool_start_pos = None;
self.brace_count = 0;
return Some((tool_call, end_pos));
} else {
info!("Failed to parse JSON even after fixing mixed quotes: {}", fixed_mixed_quotes);
}
} else {
info!("Failed to parse JSON (no fixes applied): {}", json_str);
}
// Invalid JSON, reset and continue looking
self.in_tool_call = false;
self.tool_start_pos = None;
self.brace_count = 0;
}
}
}
_ => {}
}
}
@@ -245,7 +282,10 @@ impl Agent {
}
// Set default provider
debug!("Setting default provider to: {}", config.providers.default_provider);
debug!(
"Setting default provider to: {}",
config.providers.default_provider
);
providers.set_default(&config.providers.default_provider)?;
debug!("Default provider set successfully");
@@ -394,7 +434,7 @@ When you need to execute a tool, write ONLY the JSON tool call on a new line:
{{\"tool\": \"tool_name\", \"args\": {{\"param\": \"value\"}}}}
The tool will execute immediately and you'll receive the result to continue with.
The tool will execute immediately and you'll receive the result (success or error) to continue with.
# Available Tools
@@ -410,12 +450,12 @@ The tool will execute immediately and you'll receive the result to continue with
1. Analyze the request and break down into smaller tasks if appropriate
2. Execute ONE tool at a time
3. STOP when the original request was satisfied
4. End with final_output when done
4. Call the final_output tool when done
# Response Guidelines
- Use Markdown formatting for all responses except tool calls.
- Whenever calling tools, use the pronoun 'I'
- Whenever taking actions, use the pronoun 'I'
");
@@ -639,6 +679,7 @@ The tool will execute immediately and you'll receive the result to continue with
}
// If no native tool calls, check for JSON tool calls in text (embedded models)
// IMPORTANT: Pass raw content to parser BEFORE cleaning stop sequences
if detected_tool_call.is_none() {
detected_tool_call = parser.add_chunk(&chunk.content);
}
@@ -667,7 +708,8 @@ The tool will execute immediately and you'll receive the result to continue with
let final_display_content = clean_display_content.trim();
// Safely get the new content to display
let new_content = if current_response.len() <= final_display_content.len() {
let new_content =
if current_response.len() <= final_display_content.len() {
// Use char indices to avoid UTF-8 boundary issues
let chars_already_shown = current_response.chars().count();
final_display_content
@@ -764,7 +806,8 @@ The tool will execute immediately and you'll receive the result to continue with
// Check if this was a final_output tool call - if so, stop the conversation
if tool_call.tool == "final_output" {
println!(); // New line after final output
let ttft = first_token_time.unwrap_or_else(|| stream_start.elapsed());
let ttft =
first_token_time.unwrap_or_else(|| stream_start.elapsed());
return Ok((full_response, ttft));
}
@@ -774,7 +817,8 @@ The tool will execute immediately and you'll receive the result to continue with
} else {
// No tool call detected, continue streaming normally
// Filter out stop tokens from the streaming output
let clean_content = chunk.content
let clean_content = chunk
.content
.replace("<|im_end|>", "")
.replace("</s>", "")
.replace("[/INST]", "")
@@ -978,6 +1022,118 @@ fn shell_escape_command(command: &str) -> String {
}
}
// Helper function to fix nested quotes in shell command JSON
fn fix_nested_quotes_in_shell_command(json_str: &str) -> String {
// This handles cases where shell commands contain nested quotes that break JSON parsing
// Example: {"tool": "shell", "args": {"command": "python -c 'import os; print("hello")'"}}
// Look for the pattern: "command": "
if let Some(command_start) = json_str.find(r#""command": ""#) {
let command_value_start = command_start + r#""command": ""#.len();
// Find the end of the command string by looking for the pattern "}
// We need to be careful about nested quotes
if let Some(end_marker) = json_str[command_value_start..].find(r#"" }"#) {
let command_end = command_value_start + end_marker;
let before = &json_str[..command_value_start];
let command_content = &json_str[command_value_start..command_end];
let after = &json_str[command_end..];
// Fix the command content by properly escaping quotes
let mut fixed_command = String::new();
let mut chars = command_content.chars().peekable();
while let Some(ch) = chars.next() {
match ch {
'"' => {
// Check if this quote is already escaped
if fixed_command.ends_with('\\') {
fixed_command.push(ch); // Already escaped, keep as-is
} else {
fixed_command.push_str(r#"\""#); // Escape the quote
}
}
'\\' => {
// Check what follows the backslash
if let Some(&next_ch) = chars.peek() {
if next_ch == '"' {
// This is an escaped quote, keep the backslash
fixed_command.push(ch);
} else {
// Regular backslash, escape it
fixed_command.push_str(r#"\\"#);
}
} else {
// Backslash at end, escape it
fixed_command.push_str(r#"\\"#);
}
}
_ => fixed_command.push(ch),
}
}
return format!("{}{}{}", before, fixed_command, after);
}
}
// Fallback: if we can't parse the structure, try some basic replacements
json_str.to_string()
}
// Helper function to fix mixed quotes in JSON (single quotes where double quotes should be)
fn fix_mixed_quotes_in_json(json_str: &str) -> String {
// This handles cases where the LLM uses single quotes in JSON values
// Example: {"tool": "shell", "args": {"command": 'echo "hello"'}}
let mut result = String::new();
let mut chars = json_str.chars().peekable();
let mut in_string = false;
let mut string_delimiter = '"';
while let Some(ch) = chars.next() {
match ch {
'"' if !in_string => {
// Start of a double-quoted string
in_string = true;
string_delimiter = '"';
result.push(ch);
}
'\'' if !in_string => {
// Start of a single-quoted string - convert to double quotes
in_string = true;
string_delimiter = '\'';
result.push('"'); // Convert single quote to double quote
}
c if in_string && c == string_delimiter => {
// End of current string
if string_delimiter == '\'' {
result.push('"'); // Convert single quote to double quote
} else {
result.push(c);
}
in_string = false;
}
'"' if in_string && string_delimiter == '\'' => {
// Double quote inside single-quoted string - escape it
result.push_str(r#"\""#);
}
'\\' if in_string => {
// Escape sequence - preserve it
result.push(ch);
if let Some(&next_ch) = chars.peek() {
result.push(chars.next().unwrap());
}
}
_ => {
result.push(ch);
}
}
}
result
}
pub mod providers {
pub mod anthropic;
pub mod embedded;

View File

@@ -517,7 +517,53 @@ impl LLMProvider for EmbeddedProvider {
}
if hit_stop {
// Send any remaining clean content before stopping
// Before stopping, check if there might be an incomplete tool call
// Look for JSON tool call patterns that might be cut off by the stop sequence
let has_potential_tool_call = accumulated_text.contains(r#"{"tool":"#) ||
accumulated_text.contains(r#"{"{""tool"":"#) ||
accumulated_text.contains(r#"{{""tool"":"#);
if has_potential_tool_call {
// Check if the tool call appears to be complete (has closing brace after the stop sequence)
let mut complete_tool_call = false;
for stop_seq in &stop_sequences {
if let Some(stop_pos) = accumulated_text.find(stop_seq) {
// Look for tool call pattern before the stop sequence
let before_stop = &accumulated_text[..stop_pos];
if let Some(tool_start) = before_stop.rfind(r#"{"tool":"#) {
let tool_part = &before_stop[tool_start..];
// Count braces to see if JSON is complete
let open_braces = tool_part.matches('{').count();
let close_braces = tool_part.matches('}').count();
if open_braces > 0 && open_braces == close_braces {
complete_tool_call = true;
break;
}
}
}
}
// If tool call is incomplete, send the raw content including stop sequences
// so the main parser can handle it properly
if !complete_tool_call {
debug!("Found incomplete tool call, sending raw content with stop sequences");
let already_sent_len = accumulated_text.len() - unsent_tokens.len();
if accumulated_text.len() > already_sent_len {
let remaining_to_send = &accumulated_text[already_sent_len..];
if !remaining_to_send.is_empty() {
let chunk = CompletionChunk {
content: remaining_to_send.to_string(),
finished: false,
tool_calls: None,
};
let _ = tx.blocking_send(Ok(chunk));
}
}
break;
}
}
// Send any remaining clean content before stopping (original behavior)
let mut clean_accumulated = accumulated_text.clone();
for stop_seq in &stop_sequences {
if let Some(pos) = clean_accumulated.find(stop_seq) {