decent basic toolshim
This commit is contained in:
@@ -7,7 +7,7 @@ use std::fs;
|
||||
use std::path::Path;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, warn, debug};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
@@ -118,6 +118,7 @@ impl StreamingToolParser {
|
||||
.replace(r#"{{""#, r#"{"#) // Fix {{" -> {"
|
||||
.replace(r#"""}"#, r#""}"#); // Fix ""} -> "}
|
||||
|
||||
// First, try to parse the JSON as-is
|
||||
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(&json_str) {
|
||||
info!("Successfully parsed tool call: {:?}", tool_call);
|
||||
// Reset parser state
|
||||
@@ -125,16 +126,52 @@ impl StreamingToolParser {
|
||||
self.tool_start_pos = None;
|
||||
self.brace_count = 0;
|
||||
|
||||
return Some((tool_call, end_pos));
|
||||
}
|
||||
|
||||
// If parsing failed and this is a shell command, try to fix nested quotes
|
||||
if json_str.contains(r#""tool": "shell""#) {
|
||||
let fixed_json = fix_nested_quotes_in_shell_command(&json_str);
|
||||
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(&fixed_json) {
|
||||
info!("Successfully parsed tool call after fixing nested quotes: {:?}", tool_call);
|
||||
// Reset parser state
|
||||
self.in_tool_call = false;
|
||||
self.tool_start_pos = None;
|
||||
self.brace_count = 0;
|
||||
|
||||
return Some((tool_call, end_pos));
|
||||
} else {
|
||||
info!("Failed to parse JSON after cleanup: {}", json_str);
|
||||
info!(
|
||||
"Failed to parse JSON even after fixing nested quotes: {}",
|
||||
fixed_json
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Try to fix mixed quote issues (single quotes in JSON)
|
||||
let fixed_mixed_quotes = fix_mixed_quotes_in_json(&json_str);
|
||||
if fixed_mixed_quotes != json_str {
|
||||
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(&fixed_mixed_quotes) {
|
||||
info!("Successfully parsed tool call after fixing mixed quotes: {:?}", tool_call);
|
||||
// Reset parser state
|
||||
self.in_tool_call = false;
|
||||
self.tool_start_pos = None;
|
||||
self.brace_count = 0;
|
||||
|
||||
return Some((tool_call, end_pos));
|
||||
} else {
|
||||
info!("Failed to parse JSON even after fixing mixed quotes: {}", fixed_mixed_quotes);
|
||||
}
|
||||
} else {
|
||||
info!("Failed to parse JSON (no fixes applied): {}", json_str);
|
||||
}
|
||||
|
||||
// Invalid JSON, reset and continue looking
|
||||
self.in_tool_call = false;
|
||||
self.tool_start_pos = None;
|
||||
self.brace_count = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
@@ -245,7 +282,10 @@ impl Agent {
|
||||
}
|
||||
|
||||
// Set default provider
|
||||
debug!("Setting default provider to: {}", config.providers.default_provider);
|
||||
debug!(
|
||||
"Setting default provider to: {}",
|
||||
config.providers.default_provider
|
||||
);
|
||||
providers.set_default(&config.providers.default_provider)?;
|
||||
debug!("Default provider set successfully");
|
||||
|
||||
@@ -394,7 +434,7 @@ When you need to execute a tool, write ONLY the JSON tool call on a new line:
|
||||
|
||||
{{\"tool\": \"tool_name\", \"args\": {{\"param\": \"value\"}}}}
|
||||
|
||||
The tool will execute immediately and you'll receive the result to continue with.
|
||||
The tool will execute immediately and you'll receive the result (success or error) to continue with.
|
||||
|
||||
# Available Tools
|
||||
|
||||
@@ -410,12 +450,12 @@ The tool will execute immediately and you'll receive the result to continue with
|
||||
1. Analyze the request and break down into smaller tasks if appropriate
|
||||
2. Execute ONE tool at a time
|
||||
3. STOP when the original request was satisfied
|
||||
4. End with final_output when done
|
||||
4. Call the final_output tool when done
|
||||
|
||||
# Response Guidelines
|
||||
|
||||
- Use Markdown formatting for all responses except tool calls.
|
||||
- Whenever calling tools, use the pronoun 'I'
|
||||
- Whenever taking actions, use the pronoun 'I'
|
||||
|
||||
");
|
||||
|
||||
@@ -639,6 +679,7 @@ The tool will execute immediately and you'll receive the result to continue with
|
||||
}
|
||||
|
||||
// If no native tool calls, check for JSON tool calls in text (embedded models)
|
||||
// IMPORTANT: Pass raw content to parser BEFORE cleaning stop sequences
|
||||
if detected_tool_call.is_none() {
|
||||
detected_tool_call = parser.add_chunk(&chunk.content);
|
||||
}
|
||||
@@ -667,7 +708,8 @@ The tool will execute immediately and you'll receive the result to continue with
|
||||
let final_display_content = clean_display_content.trim();
|
||||
|
||||
// Safely get the new content to display
|
||||
let new_content = if current_response.len() <= final_display_content.len() {
|
||||
let new_content =
|
||||
if current_response.len() <= final_display_content.len() {
|
||||
// Use char indices to avoid UTF-8 boundary issues
|
||||
let chars_already_shown = current_response.chars().count();
|
||||
final_display_content
|
||||
@@ -764,7 +806,8 @@ The tool will execute immediately and you'll receive the result to continue with
|
||||
// Check if this was a final_output tool call - if so, stop the conversation
|
||||
if tool_call.tool == "final_output" {
|
||||
println!(); // New line after final output
|
||||
let ttft = first_token_time.unwrap_or_else(|| stream_start.elapsed());
|
||||
let ttft =
|
||||
first_token_time.unwrap_or_else(|| stream_start.elapsed());
|
||||
return Ok((full_response, ttft));
|
||||
}
|
||||
|
||||
@@ -774,7 +817,8 @@ The tool will execute immediately and you'll receive the result to continue with
|
||||
} else {
|
||||
// No tool call detected, continue streaming normally
|
||||
// Filter out stop tokens from the streaming output
|
||||
let clean_content = chunk.content
|
||||
let clean_content = chunk
|
||||
.content
|
||||
.replace("<|im_end|>", "")
|
||||
.replace("</s>", "")
|
||||
.replace("[/INST]", "")
|
||||
@@ -978,6 +1022,118 @@ fn shell_escape_command(command: &str) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
/// Repairs unescaped quotes inside the `"command"` value of a shell tool call.
///
/// Models sometimes emit a shell command whose inner double quotes are not
/// escaped, which makes the surrounding JSON unparseable, e.g.
/// `{"tool": "shell", "args": {"command": "python -c 'print("hello")'"}}`.
///
/// The command value is taken to start right after the literal `"command": "`
/// and to end at the *last* double quote that is followed (ignoring
/// whitespace) by a closing brace. Using the last such quote keeps commands
/// that themselves contain `"}` (for example awk programs) intact, and
/// tolerating whitespace accepts both the compact `"}` and the spaced `" }`
/// layouts.
///
/// Within the command value:
/// - a `"` that is not preceded by a backslash is escaped as `\"`;
/// - a backslash directly before a `"` is kept as-is (the quote is already escaped);
/// - any other backslash is doubled.
///
/// Returns the input unchanged when the `"command": "` key or a closing
/// quote cannot be located.
fn fix_nested_quotes_in_shell_command(json_str: &str) -> String {
    const COMMAND_KEY: &str = r#""command": ""#;

    let command_value_start = match json_str.find(COMMAND_KEY) {
        Some(key_pos) => key_pos + COMMAND_KEY.len(),
        None => return json_str.to_string(),
    };
    let tail = &json_str[command_value_start..];

    // Walk the quotes from the right and stop at the first one that closes an
    // object, i.e. the last quote in `tail` followed by optional whitespace
    // and `}`. `"` is a single byte, so `i + 1` is always a char boundary.
    let closing_quote = tail
        .rmatch_indices('"')
        .map(|(i, _)| i)
        .find(|&i| tail[i + 1..].trim_start().starts_with('}'));
    let end_marker = match closing_quote {
        Some(i) => i,
        None => return json_str.to_string(),
    };

    let before = &json_str[..command_value_start];
    let command_content = &tail[..end_marker];
    // `after` begins with the closing quote of the command value.
    let after = &tail[end_marker..];

    let mut fixed_command = String::with_capacity(command_content.len() + 8);
    let mut chars = command_content.chars().peekable();

    while let Some(ch) = chars.next() {
        match ch {
            // Quote already preceded by a kept backslash: leave it alone.
            '"' if fixed_command.ends_with('\\') => fixed_command.push('"'),
            // Bare quote inside the command: escape it.
            '"' => fixed_command.push_str(r#"\""#),
            // Backslash that escapes the following quote: keep it single.
            '\\' if chars.peek() == Some(&'"') => fixed_command.push('\\'),
            // Any other backslash (including a trailing one) is doubled.
            '\\' => fixed_command.push_str(r"\\"),
            _ => fixed_command.push(ch),
        }
    }

    format!("{}{}{}", before, fixed_command, after)
}
|
||||
|
||||
/// Rewrites single-quoted JSON strings as double-quoted ones.
///
/// Handles model output such as
/// `{"tool": "shell", "args": {"command": 'echo "hello"'}}`, where a value is
/// delimited with single quotes and therefore is not valid JSON.
///
/// Scanning rules:
/// - outside a string, `"` or `'` opens a string; the opening delimiter is
///   always emitted as `"`;
/// - the matching delimiter closes the string and is emitted as `"`;
/// - a `"` inside a single-quoted string is escaped as `\"`;
/// - inside a single-quoted string, `\'` becomes a plain `'`, because `\'` is
///   not a legal JSON escape and the quote no longer needs escaping once the
///   delimiter is `"`;
/// - every other backslash sequence inside a string is copied verbatim
///   together with the character it escapes;
/// - everything else is copied unchanged.
///
/// Input that contains no single-quoted strings is returned unchanged.
fn fix_mixed_quotes_in_json(json_str: &str) -> String {
    let mut result = String::with_capacity(json_str.len());
    let mut chars = json_str.chars();
    // `Some(d)` while inside a string opened with delimiter `d`, else `None`.
    let mut delimiter: Option<char> = None;

    while let Some(ch) = chars.next() {
        match (delimiter, ch) {
            // Opening delimiter of either flavour is normalised to `"`.
            (None, '"') | (None, '\'') => {
                delimiter = Some(ch);
                result.push('"');
            }
            // Closing delimiter of the current string.
            (Some(open), c) if c == open => {
                delimiter = None;
                result.push('"');
            }
            // Double quote inside a single-quoted string must be escaped.
            (Some('\''), '"') => result.push_str(r#"\""#),
            // Escape sequence: consume the escaped character with the backslash
            // so it is never mistaken for a delimiter.
            (Some(open), '\\') => match chars.next() {
                Some('\'') if open == '\'' => result.push('\''),
                Some(escaped) => {
                    result.push('\\');
                    result.push(escaped);
                }
                None => result.push('\\'),
            },
            (_, c) => result.push(c),
        }
    }

    result
}
|
||||
|
||||
pub mod providers {
|
||||
pub mod anthropic;
|
||||
pub mod embedded;
|
||||
|
||||
@@ -517,7 +517,53 @@ impl LLMProvider for EmbeddedProvider {
|
||||
}
|
||||
|
||||
if hit_stop {
|
||||
// Send any remaining clean content before stopping
|
||||
// Before stopping, check if there might be an incomplete tool call
|
||||
// Look for JSON tool call patterns that might be cut off by the stop sequence
|
||||
let has_potential_tool_call = accumulated_text.contains(r#"{"tool":"#) ||
|
||||
accumulated_text.contains(r#"{"{""tool"":"#) ||
|
||||
accumulated_text.contains(r#"{{""tool"":"#);
|
||||
|
||||
if has_potential_tool_call {
|
||||
// Check if the tool call appears to be complete (has closing brace after the stop sequence)
|
||||
let mut complete_tool_call = false;
|
||||
for stop_seq in &stop_sequences {
|
||||
if let Some(stop_pos) = accumulated_text.find(stop_seq) {
|
||||
// Look for tool call pattern before the stop sequence
|
||||
let before_stop = &accumulated_text[..stop_pos];
|
||||
if let Some(tool_start) = before_stop.rfind(r#"{"tool":"#) {
|
||||
let tool_part = &before_stop[tool_start..];
|
||||
// Count braces to see if JSON is complete
|
||||
let open_braces = tool_part.matches('{').count();
|
||||
let close_braces = tool_part.matches('}').count();
|
||||
if open_braces > 0 && open_braces == close_braces {
|
||||
complete_tool_call = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If tool call is incomplete, send the raw content including stop sequences
|
||||
// so the main parser can handle it properly
|
||||
if !complete_tool_call {
|
||||
debug!("Found incomplete tool call, sending raw content with stop sequences");
|
||||
let already_sent_len = accumulated_text.len() - unsent_tokens.len();
|
||||
if accumulated_text.len() > already_sent_len {
|
||||
let remaining_to_send = &accumulated_text[already_sent_len..];
|
||||
if !remaining_to_send.is_empty() {
|
||||
let chunk = CompletionChunk {
|
||||
content: remaining_to_send.to_string(),
|
||||
finished: false,
|
||||
tool_calls: None,
|
||||
};
|
||||
let _ = tx.blocking_send(Ok(chunk));
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Send any remaining clean content before stopping (original behavior)
|
||||
let mut clean_accumulated = accumulated_text.clone();
|
||||
for stop_seq in &stop_sequences {
|
||||
if let Some(pos) = clean_accumulated.find(stop_seq) {
|
||||
|
||||
Reference in New Issue
Block a user