decent basic toolshim
This commit is contained in:
@@ -7,7 +7,7 @@ use std::fs;
|
|||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
use tracing::{error, info, warn, debug};
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ToolCall {
|
pub struct ToolCall {
|
||||||
@@ -55,9 +55,9 @@ impl StreamingToolParser {
|
|||||||
// Look for JSON tool call pattern - check both raw JSON and inside code blocks
|
// Look for JSON tool call pattern - check both raw JSON and inside code blocks
|
||||||
// Also handle malformed patterns like {"{""tool"":
|
// Also handle malformed patterns like {"{""tool"":
|
||||||
let patterns = [
|
let patterns = [
|
||||||
r#"{"tool":"#, // Normal pattern
|
r#"{"tool":"#, // Normal pattern
|
||||||
r#"{"{""tool"":"#, // Malformed pattern with extra brace and doubled quotes
|
r#"{"{""tool"":"#, // Malformed pattern with extra brace and doubled quotes
|
||||||
r#"{{""tool"":"#, // Alternative malformed pattern
|
r#"{{""tool"":"#, // Alternative malformed pattern
|
||||||
];
|
];
|
||||||
|
|
||||||
for pattern in &patterns {
|
for pattern in &patterns {
|
||||||
@@ -113,11 +113,12 @@ impl StreamingToolParser {
|
|||||||
|
|
||||||
// Clean up malformed JSON patterns
|
// Clean up malformed JSON patterns
|
||||||
json_str = json_str
|
json_str = json_str
|
||||||
.replace(r#"{"{""#, r#"{"#) // Fix {"{" -> {"
|
.replace(r#"{"{""#, r#"{"#) // Fix {"{" -> {"
|
||||||
.replace(r#"""}"#, r#""}"#) // Fix ""} -> "}
|
.replace(r#"""}"#, r#""}"#) // Fix ""} -> "}
|
||||||
.replace(r#"{{""#, r#"{"#) // Fix {{" -> {"
|
.replace(r#"{{""#, r#"{"#) // Fix {{" -> {"
|
||||||
.replace(r#"""}"#, r#""}"#); // Fix ""} -> "}
|
.replace(r#"""}"#, r#""}"#); // Fix ""} -> "}
|
||||||
|
|
||||||
|
// First, try to parse the JSON as-is
|
||||||
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(&json_str) {
|
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(&json_str) {
|
||||||
info!("Successfully parsed tool call: {:?}", tool_call);
|
info!("Successfully parsed tool call: {:?}", tool_call);
|
||||||
// Reset parser state
|
// Reset parser state
|
||||||
@@ -126,13 +127,49 @@ impl StreamingToolParser {
|
|||||||
self.brace_count = 0;
|
self.brace_count = 0;
|
||||||
|
|
||||||
return Some((tool_call, end_pos));
|
return Some((tool_call, end_pos));
|
||||||
} else {
|
|
||||||
info!("Failed to parse JSON after cleanup: {}", json_str);
|
|
||||||
// Invalid JSON, reset and continue looking
|
|
||||||
self.in_tool_call = false;
|
|
||||||
self.tool_start_pos = None;
|
|
||||||
self.brace_count = 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If parsing failed and this is a shell command, try to fix nested quotes
|
||||||
|
if json_str.contains(r#""tool": "shell""#) {
|
||||||
|
let fixed_json = fix_nested_quotes_in_shell_command(&json_str);
|
||||||
|
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(&fixed_json) {
|
||||||
|
info!("Successfully parsed tool call after fixing nested quotes: {:?}", tool_call);
|
||||||
|
// Reset parser state
|
||||||
|
self.in_tool_call = false;
|
||||||
|
self.tool_start_pos = None;
|
||||||
|
self.brace_count = 0;
|
||||||
|
|
||||||
|
return Some((tool_call, end_pos));
|
||||||
|
} else {
|
||||||
|
info!(
|
||||||
|
"Failed to parse JSON even after fixing nested quotes: {}",
|
||||||
|
fixed_json
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to fix mixed quote issues (single quotes in JSON)
|
||||||
|
let fixed_mixed_quotes = fix_mixed_quotes_in_json(&json_str);
|
||||||
|
if fixed_mixed_quotes != json_str {
|
||||||
|
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(&fixed_mixed_quotes) {
|
||||||
|
info!("Successfully parsed tool call after fixing mixed quotes: {:?}", tool_call);
|
||||||
|
// Reset parser state
|
||||||
|
self.in_tool_call = false;
|
||||||
|
self.tool_start_pos = None;
|
||||||
|
self.brace_count = 0;
|
||||||
|
|
||||||
|
return Some((tool_call, end_pos));
|
||||||
|
} else {
|
||||||
|
info!("Failed to parse JSON even after fixing mixed quotes: {}", fixed_mixed_quotes);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
info!("Failed to parse JSON (no fixes applied): {}", json_str);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalid JSON, reset and continue looking
|
||||||
|
self.in_tool_call = false;
|
||||||
|
self.tool_start_pos = None;
|
||||||
|
self.brace_count = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
@@ -245,7 +282,10 @@ impl Agent {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set default provider
|
// Set default provider
|
||||||
debug!("Setting default provider to: {}", config.providers.default_provider);
|
debug!(
|
||||||
|
"Setting default provider to: {}",
|
||||||
|
config.providers.default_provider
|
||||||
|
);
|
||||||
providers.set_default(&config.providers.default_provider)?;
|
providers.set_default(&config.providers.default_provider)?;
|
||||||
debug!("Default provider set successfully");
|
debug!("Default provider set successfully");
|
||||||
|
|
||||||
@@ -394,7 +434,7 @@ When you need to execute a tool, write ONLY the JSON tool call on a new line:
|
|||||||
|
|
||||||
{{\"tool\": \"tool_name\", \"args\": {{\"param\": \"value\"}}}}
|
{{\"tool\": \"tool_name\", \"args\": {{\"param\": \"value\"}}}}
|
||||||
|
|
||||||
The tool will execute immediately and you'll receive the result to continue with.
|
The tool will execute immediately and you'll receive the result (success or error) to continue with.
|
||||||
|
|
||||||
# Available Tools
|
# Available Tools
|
||||||
|
|
||||||
@@ -410,12 +450,12 @@ The tool will execute immediately and you'll receive the result to continue with
|
|||||||
1. Analyze the request and break down into smaller tasks if appropriate
|
1. Analyze the request and break down into smaller tasks if appropriate
|
||||||
2. Execute ONE tool at a time
|
2. Execute ONE tool at a time
|
||||||
3. STOP when the original request was satisfied
|
3. STOP when the original request was satisfied
|
||||||
4. End with final_output when done
|
4. Call the final_output tool when done
|
||||||
|
|
||||||
# Response Guidelines
|
# Response Guidelines
|
||||||
|
|
||||||
- Use Markdown formatting for all responses except tool calls.
|
- Use Markdown formatting for all responses except tool calls.
|
||||||
- Whenever calling tools, use the pronoun 'I'
|
- Whenever taking actions, use the pronoun 'I'
|
||||||
|
|
||||||
");
|
");
|
||||||
|
|
||||||
@@ -639,6 +679,7 @@ The tool will execute immediately and you'll receive the result to continue with
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If no native tool calls, check for JSON tool calls in text (embedded models)
|
// If no native tool calls, check for JSON tool calls in text (embedded models)
|
||||||
|
// IMPORTANT: Pass raw content to parser BEFORE cleaning stop sequences
|
||||||
if detected_tool_call.is_none() {
|
if detected_tool_call.is_none() {
|
||||||
detected_tool_call = parser.add_chunk(&chunk.content);
|
detected_tool_call = parser.add_chunk(&chunk.content);
|
||||||
}
|
}
|
||||||
@@ -667,16 +708,17 @@ The tool will execute immediately and you'll receive the result to continue with
|
|||||||
let final_display_content = clean_display_content.trim();
|
let final_display_content = clean_display_content.trim();
|
||||||
|
|
||||||
// Safely get the new content to display
|
// Safely get the new content to display
|
||||||
let new_content = if current_response.len() <= final_display_content.len() {
|
let new_content =
|
||||||
// Use char indices to avoid UTF-8 boundary issues
|
if current_response.len() <= final_display_content.len() {
|
||||||
let chars_already_shown = current_response.chars().count();
|
// Use char indices to avoid UTF-8 boundary issues
|
||||||
final_display_content
|
let chars_already_shown = current_response.chars().count();
|
||||||
.chars()
|
final_display_content
|
||||||
.skip(chars_already_shown)
|
.chars()
|
||||||
.collect::<String>()
|
.skip(chars_already_shown)
|
||||||
} else {
|
.collect::<String>()
|
||||||
String::new()
|
} else {
|
||||||
};
|
String::new()
|
||||||
|
};
|
||||||
|
|
||||||
// Only print if there's actually new content to show
|
// Only print if there's actually new content to show
|
||||||
if !new_content.trim().is_empty() {
|
if !new_content.trim().is_empty() {
|
||||||
@@ -764,7 +806,8 @@ The tool will execute immediately and you'll receive the result to continue with
|
|||||||
// Check if this was a final_output tool call - if so, stop the conversation
|
// Check if this was a final_output tool call - if so, stop the conversation
|
||||||
if tool_call.tool == "final_output" {
|
if tool_call.tool == "final_output" {
|
||||||
println!(); // New line after final output
|
println!(); // New line after final output
|
||||||
let ttft = first_token_time.unwrap_or_else(|| stream_start.elapsed());
|
let ttft =
|
||||||
|
first_token_time.unwrap_or_else(|| stream_start.elapsed());
|
||||||
return Ok((full_response, ttft));
|
return Ok((full_response, ttft));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -774,7 +817,8 @@ The tool will execute immediately and you'll receive the result to continue with
|
|||||||
} else {
|
} else {
|
||||||
// No tool call detected, continue streaming normally
|
// No tool call detected, continue streaming normally
|
||||||
// Filter out stop tokens from the streaming output
|
// Filter out stop tokens from the streaming output
|
||||||
let clean_content = chunk.content
|
let clean_content = chunk
|
||||||
|
.content
|
||||||
.replace("<|im_end|>", "")
|
.replace("<|im_end|>", "")
|
||||||
.replace("</s>", "")
|
.replace("</s>", "")
|
||||||
.replace("[/INST]", "")
|
.replace("[/INST]", "")
|
||||||
@@ -978,6 +1022,118 @@ fn shell_escape_command(command: &str) -> String {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper function to fix nested quotes in shell command JSON
|
||||||
|
fn fix_nested_quotes_in_shell_command(json_str: &str) -> String {
|
||||||
|
// This handles cases where shell commands contain nested quotes that break JSON parsing
|
||||||
|
// Example: {"tool": "shell", "args": {"command": "python -c 'import os; print("hello")'"}}
|
||||||
|
|
||||||
|
// Look for the pattern: "command": "
|
||||||
|
if let Some(command_start) = json_str.find(r#""command": ""#) {
|
||||||
|
let command_value_start = command_start + r#""command": ""#.len();
|
||||||
|
|
||||||
|
// Find the end of the command string by looking for the pattern "}
|
||||||
|
// We need to be careful about nested quotes
|
||||||
|
if let Some(end_marker) = json_str[command_value_start..].find(r#"" }"#) {
|
||||||
|
let command_end = command_value_start + end_marker;
|
||||||
|
|
||||||
|
let before = &json_str[..command_value_start];
|
||||||
|
let command_content = &json_str[command_value_start..command_end];
|
||||||
|
let after = &json_str[command_end..];
|
||||||
|
|
||||||
|
// Fix the command content by properly escaping quotes
|
||||||
|
let mut fixed_command = String::new();
|
||||||
|
let mut chars = command_content.chars().peekable();
|
||||||
|
|
||||||
|
while let Some(ch) = chars.next() {
|
||||||
|
match ch {
|
||||||
|
'"' => {
|
||||||
|
// Check if this quote is already escaped
|
||||||
|
if fixed_command.ends_with('\\') {
|
||||||
|
fixed_command.push(ch); // Already escaped, keep as-is
|
||||||
|
} else {
|
||||||
|
fixed_command.push_str(r#"\""#); // Escape the quote
|
||||||
|
}
|
||||||
|
}
|
||||||
|
'\\' => {
|
||||||
|
// Check what follows the backslash
|
||||||
|
if let Some(&next_ch) = chars.peek() {
|
||||||
|
if next_ch == '"' {
|
||||||
|
// This is an escaped quote, keep the backslash
|
||||||
|
fixed_command.push(ch);
|
||||||
|
} else {
|
||||||
|
// Regular backslash, escape it
|
||||||
|
fixed_command.push_str(r#"\\"#);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Backslash at end, escape it
|
||||||
|
fixed_command.push_str(r#"\\"#);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => fixed_command.push(ch),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return format!("{}{}{}", before, fixed_command, after);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: if we can't parse the structure, try some basic replacements
|
||||||
|
json_str.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to fix mixed quotes in JSON (single quotes where double quotes should be)
|
||||||
|
fn fix_mixed_quotes_in_json(json_str: &str) -> String {
|
||||||
|
// This handles cases where the LLM uses single quotes in JSON values
|
||||||
|
// Example: {"tool": "shell", "args": {"command": 'echo "hello"'}}
|
||||||
|
|
||||||
|
let mut result = String::new();
|
||||||
|
let mut chars = json_str.chars().peekable();
|
||||||
|
let mut in_string = false;
|
||||||
|
let mut string_delimiter = '"';
|
||||||
|
|
||||||
|
while let Some(ch) = chars.next() {
|
||||||
|
match ch {
|
||||||
|
'"' if !in_string => {
|
||||||
|
// Start of a double-quoted string
|
||||||
|
in_string = true;
|
||||||
|
string_delimiter = '"';
|
||||||
|
result.push(ch);
|
||||||
|
}
|
||||||
|
'\'' if !in_string => {
|
||||||
|
// Start of a single-quoted string - convert to double quotes
|
||||||
|
in_string = true;
|
||||||
|
string_delimiter = '\'';
|
||||||
|
result.push('"'); // Convert single quote to double quote
|
||||||
|
}
|
||||||
|
c if in_string && c == string_delimiter => {
|
||||||
|
// End of current string
|
||||||
|
if string_delimiter == '\'' {
|
||||||
|
result.push('"'); // Convert single quote to double quote
|
||||||
|
} else {
|
||||||
|
result.push(c);
|
||||||
|
}
|
||||||
|
in_string = false;
|
||||||
|
}
|
||||||
|
'"' if in_string && string_delimiter == '\'' => {
|
||||||
|
// Double quote inside single-quoted string - escape it
|
||||||
|
result.push_str(r#"\""#);
|
||||||
|
}
|
||||||
|
'\\' if in_string => {
|
||||||
|
// Escape sequence - preserve it
|
||||||
|
result.push(ch);
|
||||||
|
if let Some(&next_ch) = chars.peek() {
|
||||||
|
result.push(chars.next().unwrap());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
result.push(ch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
pub mod providers {
|
pub mod providers {
|
||||||
pub mod anthropic;
|
pub mod anthropic;
|
||||||
pub mod embedded;
|
pub mod embedded;
|
||||||
|
|||||||
@@ -517,7 +517,53 @@ impl LLMProvider for EmbeddedProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if hit_stop {
|
if hit_stop {
|
||||||
// Send any remaining clean content before stopping
|
// Before stopping, check if there might be an incomplete tool call
|
||||||
|
// Look for JSON tool call patterns that might be cut off by the stop sequence
|
||||||
|
let has_potential_tool_call = accumulated_text.contains(r#"{"tool":"#) ||
|
||||||
|
accumulated_text.contains(r#"{"{""tool"":"#) ||
|
||||||
|
accumulated_text.contains(r#"{{""tool"":"#);
|
||||||
|
|
||||||
|
if has_potential_tool_call {
|
||||||
|
// Check if the tool call appears to be complete (has closing brace after the stop sequence)
|
||||||
|
let mut complete_tool_call = false;
|
||||||
|
for stop_seq in &stop_sequences {
|
||||||
|
if let Some(stop_pos) = accumulated_text.find(stop_seq) {
|
||||||
|
// Look for tool call pattern before the stop sequence
|
||||||
|
let before_stop = &accumulated_text[..stop_pos];
|
||||||
|
if let Some(tool_start) = before_stop.rfind(r#"{"tool":"#) {
|
||||||
|
let tool_part = &before_stop[tool_start..];
|
||||||
|
// Count braces to see if JSON is complete
|
||||||
|
let open_braces = tool_part.matches('{').count();
|
||||||
|
let close_braces = tool_part.matches('}').count();
|
||||||
|
if open_braces > 0 && open_braces == close_braces {
|
||||||
|
complete_tool_call = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If tool call is incomplete, send the raw content including stop sequences
|
||||||
|
// so the main parser can handle it properly
|
||||||
|
if !complete_tool_call {
|
||||||
|
debug!("Found incomplete tool call, sending raw content with stop sequences");
|
||||||
|
let already_sent_len = accumulated_text.len() - unsent_tokens.len();
|
||||||
|
if accumulated_text.len() > already_sent_len {
|
||||||
|
let remaining_to_send = &accumulated_text[already_sent_len..];
|
||||||
|
if !remaining_to_send.is_empty() {
|
||||||
|
let chunk = CompletionChunk {
|
||||||
|
content: remaining_to_send.to_string(),
|
||||||
|
finished: false,
|
||||||
|
tool_calls: None,
|
||||||
|
};
|
||||||
|
let _ = tx.blocking_send(Ok(chunk));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send any remaining clean content before stopping (original behavior)
|
||||||
let mut clean_accumulated = accumulated_text.clone();
|
let mut clean_accumulated = accumulated_text.clone();
|
||||||
for stop_seq in &stop_sequences {
|
for stop_seq in &stop_sequences {
|
||||||
if let Some(pos) = clean_accumulated.find(stop_seq) {
|
if let Some(pos) = clean_accumulated.find(stop_seq) {
|
||||||
|
|||||||
Reference in New Issue
Block a user