decent basic toolshim
This commit is contained in:
@@ -7,7 +7,7 @@ use std::fs;
|
||||
use std::path::Path;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, warn, debug};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
@@ -55,11 +55,11 @@ impl StreamingToolParser {
|
||||
// Look for JSON tool call pattern - check both raw JSON and inside code blocks
|
||||
// Also handle malformed patterns like {"{""tool"":
|
||||
let patterns = [
|
||||
r#"{"tool":"#, // Normal pattern
|
||||
r#"{"{""tool"":"#, // Malformed pattern with extra brace and doubled quotes
|
||||
r#"{{""tool"":"#, // Alternative malformed pattern
|
||||
r#"{"tool":"#, // Normal pattern
|
||||
r#"{"{""tool"":"#, // Malformed pattern with extra brace and doubled quotes
|
||||
r#"{{""tool"":"#, // Alternative malformed pattern
|
||||
];
|
||||
|
||||
|
||||
for pattern in &patterns {
|
||||
if let Some(pos) = self.buffer.rfind(pattern) {
|
||||
info!("Found tool call pattern '{}' at position: {}", pattern, pos);
|
||||
@@ -113,11 +113,12 @@ impl StreamingToolParser {
|
||||
|
||||
// Clean up malformed JSON patterns
|
||||
json_str = json_str
|
||||
.replace(r#"{"{""#, r#"{"#) // Fix {"{" -> {"
|
||||
.replace(r#"""}"#, r#""}"#) // Fix ""} -> "}
|
||||
.replace(r#"{{""#, r#"{"#) // Fix {{" -> {"
|
||||
.replace(r#"""}"#, r#""}"#); // Fix ""} -> "}
|
||||
.replace(r#"{"{""#, r#"{"#) // Fix {"{" -> {"
|
||||
.replace(r#"""}"#, r#""}"#) // Fix ""} -> "}
|
||||
.replace(r#"{{""#, r#"{"#) // Fix {{" -> {"
|
||||
.replace(r#"""}"#, r#""}"#); // Fix ""} -> "}
|
||||
|
||||
// First, try to parse the JSON as-is
|
||||
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(&json_str) {
|
||||
info!("Successfully parsed tool call: {:?}", tool_call);
|
||||
// Reset parser state
|
||||
@@ -126,13 +127,49 @@ impl StreamingToolParser {
|
||||
self.brace_count = 0;
|
||||
|
||||
return Some((tool_call, end_pos));
|
||||
} else {
|
||||
info!("Failed to parse JSON after cleanup: {}", json_str);
|
||||
// Invalid JSON, reset and continue looking
|
||||
self.in_tool_call = false;
|
||||
self.tool_start_pos = None;
|
||||
self.brace_count = 0;
|
||||
}
|
||||
|
||||
// If parsing failed and this is a shell command, try to fix nested quotes
|
||||
if json_str.contains(r#""tool": "shell""#) {
|
||||
let fixed_json = fix_nested_quotes_in_shell_command(&json_str);
|
||||
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(&fixed_json) {
|
||||
info!("Successfully parsed tool call after fixing nested quotes: {:?}", tool_call);
|
||||
// Reset parser state
|
||||
self.in_tool_call = false;
|
||||
self.tool_start_pos = None;
|
||||
self.brace_count = 0;
|
||||
|
||||
return Some((tool_call, end_pos));
|
||||
} else {
|
||||
info!(
|
||||
"Failed to parse JSON even after fixing nested quotes: {}",
|
||||
fixed_json
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Try to fix mixed quote issues (single quotes in JSON)
|
||||
let fixed_mixed_quotes = fix_mixed_quotes_in_json(&json_str);
|
||||
if fixed_mixed_quotes != json_str {
|
||||
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(&fixed_mixed_quotes) {
|
||||
info!("Successfully parsed tool call after fixing mixed quotes: {:?}", tool_call);
|
||||
// Reset parser state
|
||||
self.in_tool_call = false;
|
||||
self.tool_start_pos = None;
|
||||
self.brace_count = 0;
|
||||
|
||||
return Some((tool_call, end_pos));
|
||||
} else {
|
||||
info!("Failed to parse JSON even after fixing mixed quotes: {}", fixed_mixed_quotes);
|
||||
}
|
||||
} else {
|
||||
info!("Failed to parse JSON (no fixes applied): {}", json_str);
|
||||
}
|
||||
|
||||
// Invalid JSON, reset and continue looking
|
||||
self.in_tool_call = false;
|
||||
self.tool_start_pos = None;
|
||||
self.brace_count = 0;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
@@ -245,7 +282,10 @@ impl Agent {
|
||||
}
|
||||
|
||||
// Set default provider
|
||||
debug!("Setting default provider to: {}", config.providers.default_provider);
|
||||
debug!(
|
||||
"Setting default provider to: {}",
|
||||
config.providers.default_provider
|
||||
);
|
||||
providers.set_default(&config.providers.default_provider)?;
|
||||
debug!("Default provider set successfully");
|
||||
|
||||
@@ -394,7 +434,7 @@ When you need to execute a tool, write ONLY the JSON tool call on a new line:
|
||||
|
||||
{{\"tool\": \"tool_name\", \"args\": {{\"param\": \"value\"}}}}
|
||||
|
||||
The tool will execute immediately and you'll receive the result to continue with.
|
||||
The tool will execute immediately and you'll receive the result (success or error) to continue with.
|
||||
|
||||
# Available Tools
|
||||
|
||||
@@ -410,12 +450,12 @@ The tool will execute immediately and you'll receive the result to continue with
|
||||
1. Analyze the request and break down into smaller tasks if appropriate
|
||||
2. Execute ONE tool at a time
|
||||
3. STOP when the original request was satisfied
|
||||
4. End with final_output when done
|
||||
4. Call the final_output tool when done
|
||||
|
||||
# Response Guidelines
|
||||
|
||||
- Use Markdown formatting for all responses except tool calls.
|
||||
- Whenever calling tools, use the pronoun 'I'
|
||||
- Whenever taking actions, use the pronoun 'I'
|
||||
|
||||
");
|
||||
|
||||
@@ -616,10 +656,10 @@ The tool will execute immediately and you'll receive the result to continue with
|
||||
first_token_time = Some(stream_start.elapsed());
|
||||
}
|
||||
|
||||
// Check for tool calls - either from JSON parsing (embedded models)
|
||||
// Check for tool calls - either from JSON parsing (embedded models)
|
||||
// or from native tool calls (Anthropic, OpenAI, etc.)
|
||||
let mut detected_tool_call = None;
|
||||
|
||||
|
||||
// First check for native tool calls in the chunk
|
||||
if let Some(ref tool_calls) = chunk.tool_calls {
|
||||
debug!("Found native tool calls in chunk: {:?}", tool_calls);
|
||||
@@ -637,12 +677,13 @@ The tool will execute immediately and you'll receive the result to continue with
|
||||
} else {
|
||||
debug!("No native tool calls in chunk, chunk.tool_calls is None");
|
||||
}
|
||||
|
||||
|
||||
// If no native tool calls, check for JSON tool calls in text (embedded models)
|
||||
// IMPORTANT: Pass raw content to parser BEFORE cleaning stop sequences
|
||||
if detected_tool_call.is_none() {
|
||||
detected_tool_call = parser.add_chunk(&chunk.content);
|
||||
}
|
||||
|
||||
|
||||
if let Some((tool_call, tool_end_pos)) = detected_tool_call {
|
||||
// Found a complete tool call! Stop streaming and execute it
|
||||
let content_before_tool = parser.get_content_before_tool(tool_end_pos);
|
||||
@@ -667,17 +708,18 @@ The tool will execute immediately and you'll receive the result to continue with
|
||||
let final_display_content = clean_display_content.trim();
|
||||
|
||||
// Safely get the new content to display
|
||||
let new_content = if current_response.len() <= final_display_content.len() {
|
||||
// Use char indices to avoid UTF-8 boundary issues
|
||||
let chars_already_shown = current_response.chars().count();
|
||||
final_display_content
|
||||
.chars()
|
||||
.skip(chars_already_shown)
|
||||
.collect::<String>()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let new_content =
|
||||
if current_response.len() <= final_display_content.len() {
|
||||
// Use char indices to avoid UTF-8 boundary issues
|
||||
let chars_already_shown = current_response.chars().count();
|
||||
final_display_content
|
||||
.chars()
|
||||
.skip(chars_already_shown)
|
||||
.collect::<String>()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
// Only print if there's actually new content to show
|
||||
if !new_content.trim().is_empty() {
|
||||
print!("{}", new_content);
|
||||
@@ -764,7 +806,8 @@ The tool will execute immediately and you'll receive the result to continue with
|
||||
// Check if this was a final_output tool call - if so, stop the conversation
|
||||
if tool_call.tool == "final_output" {
|
||||
println!(); // New line after final output
|
||||
let ttft = first_token_time.unwrap_or_else(|| stream_start.elapsed());
|
||||
let ttft =
|
||||
first_token_time.unwrap_or_else(|| stream_start.elapsed());
|
||||
return Ok((full_response, ttft));
|
||||
}
|
||||
|
||||
@@ -774,12 +817,13 @@ The tool will execute immediately and you'll receive the result to continue with
|
||||
} else {
|
||||
// No tool call detected, continue streaming normally
|
||||
// Filter out stop tokens from the streaming output
|
||||
let clean_content = chunk.content
|
||||
let clean_content = chunk
|
||||
.content
|
||||
.replace("<|im_end|>", "")
|
||||
.replace("</s>", "")
|
||||
.replace("[/INST]", "")
|
||||
.replace("<</SYS>>", "");
|
||||
|
||||
|
||||
if !clean_content.is_empty() {
|
||||
print!("{}", clean_content);
|
||||
io::stdout().flush()?;
|
||||
@@ -978,6 +1022,118 @@ fn shell_escape_command(command: &str) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to fix nested quotes in shell command JSON
|
||||
fn fix_nested_quotes_in_shell_command(json_str: &str) -> String {
|
||||
// This handles cases where shell commands contain nested quotes that break JSON parsing
|
||||
// Example: {"tool": "shell", "args": {"command": "python -c 'import os; print("hello")'"}}
|
||||
|
||||
// Look for the pattern: "command": "
|
||||
if let Some(command_start) = json_str.find(r#""command": ""#) {
|
||||
let command_value_start = command_start + r#""command": ""#.len();
|
||||
|
||||
// Find the end of the command string by looking for the pattern "}
|
||||
// We need to be careful about nested quotes
|
||||
if let Some(end_marker) = json_str[command_value_start..].find(r#"" }"#) {
|
||||
let command_end = command_value_start + end_marker;
|
||||
|
||||
let before = &json_str[..command_value_start];
|
||||
let command_content = &json_str[command_value_start..command_end];
|
||||
let after = &json_str[command_end..];
|
||||
|
||||
// Fix the command content by properly escaping quotes
|
||||
let mut fixed_command = String::new();
|
||||
let mut chars = command_content.chars().peekable();
|
||||
|
||||
while let Some(ch) = chars.next() {
|
||||
match ch {
|
||||
'"' => {
|
||||
// Check if this quote is already escaped
|
||||
if fixed_command.ends_with('\\') {
|
||||
fixed_command.push(ch); // Already escaped, keep as-is
|
||||
} else {
|
||||
fixed_command.push_str(r#"\""#); // Escape the quote
|
||||
}
|
||||
}
|
||||
'\\' => {
|
||||
// Check what follows the backslash
|
||||
if let Some(&next_ch) = chars.peek() {
|
||||
if next_ch == '"' {
|
||||
// This is an escaped quote, keep the backslash
|
||||
fixed_command.push(ch);
|
||||
} else {
|
||||
// Regular backslash, escape it
|
||||
fixed_command.push_str(r#"\\"#);
|
||||
}
|
||||
} else {
|
||||
// Backslash at end, escape it
|
||||
fixed_command.push_str(r#"\\"#);
|
||||
}
|
||||
}
|
||||
_ => fixed_command.push(ch),
|
||||
}
|
||||
}
|
||||
|
||||
return format!("{}{}{}", before, fixed_command, after);
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: if we can't parse the structure, try some basic replacements
|
||||
json_str.to_string()
|
||||
}
|
||||
|
||||
// Helper function to fix mixed quotes in JSON (single quotes where double quotes should be)
|
||||
fn fix_mixed_quotes_in_json(json_str: &str) -> String {
|
||||
// This handles cases where the LLM uses single quotes in JSON values
|
||||
// Example: {"tool": "shell", "args": {"command": 'echo "hello"'}}
|
||||
|
||||
let mut result = String::new();
|
||||
let mut chars = json_str.chars().peekable();
|
||||
let mut in_string = false;
|
||||
let mut string_delimiter = '"';
|
||||
|
||||
while let Some(ch) = chars.next() {
|
||||
match ch {
|
||||
'"' if !in_string => {
|
||||
// Start of a double-quoted string
|
||||
in_string = true;
|
||||
string_delimiter = '"';
|
||||
result.push(ch);
|
||||
}
|
||||
'\'' if !in_string => {
|
||||
// Start of a single-quoted string - convert to double quotes
|
||||
in_string = true;
|
||||
string_delimiter = '\'';
|
||||
result.push('"'); // Convert single quote to double quote
|
||||
}
|
||||
c if in_string && c == string_delimiter => {
|
||||
// End of current string
|
||||
if string_delimiter == '\'' {
|
||||
result.push('"'); // Convert single quote to double quote
|
||||
} else {
|
||||
result.push(c);
|
||||
}
|
||||
in_string = false;
|
||||
}
|
||||
'"' if in_string && string_delimiter == '\'' => {
|
||||
// Double quote inside single-quoted string - escape it
|
||||
result.push_str(r#"\""#);
|
||||
}
|
||||
'\\' if in_string => {
|
||||
// Escape sequence - preserve it
|
||||
result.push(ch);
|
||||
if let Some(&next_ch) = chars.peek() {
|
||||
result.push(chars.next().unwrap());
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
result.push(ch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
pub mod providers {
|
||||
pub mod anthropic;
|
||||
pub mod embedded;
|
||||
|
||||
Reference in New Issue
Block a user