diff --git a/Cargo.lock b/Cargo.lock index cd28da5..acb650b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1409,6 +1409,7 @@ dependencies = [ "config", "dirs 5.0.1", "serde", + "serde_json", "shellexpand", "tempfile", "thiserror 1.0.69", diff --git a/config.coach-player.example.toml b/config.coach-player.example.toml index d92e9df..dbb0f7f 100644 --- a/config.coach-player.example.toml +++ b/config.coach-player.example.toml @@ -33,4 +33,5 @@ temperature = 0.3 # Slightly higher temperature for more creative implementatio [agent] fallback_default_max_tokens = 8192 enable_streaming = true -timeout_seconds = 60 \ No newline at end of file +timeout_seconds = 60 +allow_multiple_tool_calls = true # Enable multiple tool calls, will usually only work with Anthropic \ No newline at end of file diff --git a/config.example.toml b/config.example.toml index 5420cfb..68e5aeb 100644 --- a/config.example.toml +++ b/config.example.toml @@ -57,6 +57,7 @@ timeout_seconds = 60 # Retry configuration for recoverable errors (timeouts, rate limits, etc.) max_retry_attempts = 3 # Default mode retry attempts autonomous_max_retry_attempts = 6 # Autonomous mode retry attempts (higher for long-running tasks) +allow_multiple_tool_calls = true # Enable multiple tool calls [computer_control] enabled = false # Set to true to enable computer control (requires OS permissions) diff --git a/crates/g3-config/Cargo.toml b/crates/g3-config/Cargo.toml index 92e0b89..67c228b 100644 --- a/crates/g3-config/Cargo.toml +++ b/crates/g3-config/Cargo.toml @@ -15,3 +15,4 @@ dirs = "5.0" [dev-dependencies] tempfile = "3.8" +serde_json = { workspace = true } diff --git a/crates/g3-config/src/lib.rs b/crates/g3-config/src/lib.rs index 827a83b..e8f567f 100644 --- a/crates/g3-config/src/lib.rs +++ b/crates/g3-config/src/lib.rs @@ -70,6 +70,7 @@ pub struct AgentConfig { pub max_context_length: Option, pub fallback_default_max_tokens: usize, pub enable_streaming: bool, + pub allow_multiple_tool_calls: bool, pub timeout_seconds: u64, pub auto_compact: bool, pub max_retry_attempts: u32, @@ -145,6 +146,7 @@ impl Default for Config { max_context_length: None, fallback_default_max_tokens: 8192, enable_streaming: true, + allow_multiple_tool_calls: false, timeout_seconds: 60, auto_compact: true, max_retry_attempts: 3, @@ -265,6 +267,7 @@ impl Config { max_context_length: None, fallback_default_max_tokens: 8192, enable_streaming: true, + allow_multiple_tool_calls: false, timeout_seconds: 60, auto_compact: true, max_retry_attempts: 3, diff --git a/crates/g3-config/tests/test_multiple_tool_calls.rs b/crates/g3-config/tests/test_multiple_tool_calls.rs new file mode 100644 index 0000000..5500a7e --- /dev/null +++ b/crates/g3-config/tests/test_multiple_tool_calls.rs @@ -0,0 +1,39 @@ +#[cfg(test)] +mod test_multiple_tool_calls { + use g3_config::{Config, AgentConfig}; + + #[test] + fn test_config_has_multiple_tool_calls_field() { + let config = Config::default(); + + // Test that the field exists and defaults to false + assert_eq!(config.agent.allow_multiple_tool_calls, false); + + // Test that we can create a config with the field set to true + let mut custom_config = Config::default(); + custom_config.agent.allow_multiple_tool_calls = true; + assert_eq!(custom_config.agent.allow_multiple_tool_calls, true); + } + + #[test] + fn test_agent_config_serialization() { + let agent_config = AgentConfig { + max_context_length: Some(100000), + fallback_default_max_tokens: 8192, + enable_streaming: true, + allow_multiple_tool_calls: true, + timeout_seconds: 60, + auto_compact: true, + max_retry_attempts: 3, + autonomous_max_retry_attempts: 6, + }; + + // Test serialization + let json = serde_json::to_string(&agent_config).unwrap(); + assert!(json.contains("\"allow_multiple_tool_calls\":true")); + + // Test deserialization + let deserialized: AgentConfig = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.allow_multiple_tool_calls, true); + } +} diff --git a/crates/g3-core/src/lib.rs b/crates/g3-core/src/lib.rs index d1c429a..5af0fd9 100644 --- a/crates/g3-core/src/lib.rs +++ b/crates/g3-core/src/lib.rs @@ -27,14 +27,18 @@ use g3_computer_control::WebDriverController; use g3_config::Config; use g3_execution::CodeExecutor; use g3_providers::{CacheControl, CompletionRequest, Message, MessageRole, ProviderRegistry, Tool}; +use chrono::Local; #[allow(unused_imports)] use regex::Regex; use serde::{Deserialize, Serialize}; use serde_json::json; +use std::fs::OpenOptions; +use std::io::Write; +use std::sync::{Mutex, OnceLock}; use std::time::{Duration, Instant}; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, warn}; -use prompts::{SYSTEM_PROMPT_FOR_NON_NATIVE_TOOL_USE, SYSTEM_PROMPT_FOR_NATIVE_TOOL_USE}; +use prompts::{SYSTEM_PROMPT_FOR_NON_NATIVE_TOOL_USE, get_system_prompt_for_native}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolCall { @@ -955,10 +959,10 @@ impl Agent { let provider = providers.get(None)?; let provider_has_native_tool_calling = provider.has_native_tool_calling(); let _ = provider; // Drop provider reference to avoid borrowing issues - + let system_prompt = if provider_has_native_tool_calling { // For native tool calling providers, use a more explicit system prompt - SYSTEM_PROMPT_FOR_NATIVE_TOOL_USE.to_string() + get_system_prompt_for_native(config.agent.allow_multiple_tool_calls) } else { // For non-native providers (embedded models), use JSON format instructions SYSTEM_PROMPT_FOR_NON_NATIVE_TOOL_USE.to_string() @@ -1214,6 +1218,63 @@ impl Agent { Ok(context_length) } + fn tool_log_handle() -> Option<&'static Mutex> { + static TOOL_LOG: OnceLock>> = OnceLock::new(); + + TOOL_LOG + .get_or_init(|| { + if let Err(e) = std::fs::create_dir_all("logs") { + error!("Failed to create logs directory for tool log: {}", e); + return None; + } + + let ts = Local::now().format("%Y%m%d_%H%M%S").to_string(); + let path = format!("logs/tool_calls_{}.log", ts); + + match OpenOptions::new() + .create(true) + .append(true) + .open(&path) + { + Ok(file) => Some(Mutex::new(file)), + Err(e) => { + error!("Failed to open tool log file {}: {}", path, e); + None + } + } + }) + .as_ref() + } + + fn log_tool_call(&self, tool_call: &ToolCall, response: &str) { + if let Some(handle) = Self::tool_log_handle() { + let timestamp = Local::now().format("%Y-%m-%d %H:%M:%S").to_string(); + let args_str = serde_json::to_string(&tool_call.args) + .unwrap_or_else(|_| "".to_string()); + + fn sanitize(s: &str) -> String { + s.replace('\n', "\\n") + } + fn truncate(s: &str, limit: usize) -> String { + s.chars().take(limit).collect() + } + + let args_snippet = truncate(&sanitize(&args_str), 80); + let response_snippet = truncate(&sanitize(response), 80); + + let tool_field = format!("{:<15}", tool_call.tool); + let line = format!( + "{} {} {} 🟩 {}\n", + timestamp, tool_field, args_snippet, response_snippet + ); + + if let Ok(mut file) = handle.lock() { + let _ = file.write_all(line.as_bytes()); + let _ = file.flush(); + } + } + } + pub fn get_provider_info(&self) -> Result<(String, String)> { let provider = self.providers.get(None)?; Ok((provider.name().to_string(), provider.model().to_string())) @@ -2729,8 +2790,12 @@ impl Agent { tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; } - let provider = self.providers.get(None)?; - debug!("Got provider: {}", provider.name()); + // Get provider info for logging, then drop it to avoid borrow issues + let (provider_name, provider_model) = { + let provider = self.providers.get(None)?; + (provider.name().to_string(), provider.model().to_string()) + }; + debug!("Got provider: {}", provider_name); // Create error context for detailed logging let last_prompt = request @@ -2743,8 +2808,8 @@ impl Agent { let error_context = ErrorContext::new( "stream_completion".to_string(), - provider.name().to_string(), - provider.model().to_string(), + provider_name.clone(), + provider_model.clone(), last_prompt, self.session_id.clone(), self.context_window.used_tokens, @@ -2757,8 +2822,8 @@ impl Agent { // Log initial request details debug!("Starting stream with provider={}, model={}, messages={}, tools={}, max_tokens={:?}", - provider.name(), - provider.model(), + provider_name, + provider_model, request.messages.len(), request.tools.is_some(), request.max_tokens @@ -2848,10 +2913,125 @@ impl Agent { // Process chunk with the new parser let completed_tools = parser.process_chunk(&chunk); - // Handle completed tool calls - if let Some(tool_call) = completed_tools.into_iter().next() { + // Handle completed tool calls - process all if multiple calls enabled + let tools_to_process: Vec = if self.config.agent.allow_multiple_tool_calls { + completed_tools + } else { + // Original behavior - only take the first tool + completed_tools.into_iter().take(1).collect() + }; + + // Helper function to check if two tool calls are duplicates + let are_duplicates = |tc1: &ToolCall, tc2: &ToolCall| -> bool { + tc1.tool == tc2.tool && tc1.args == tc2.args + }; + + // De-duplicate tool calls and track duplicates + let mut seen_in_chunk: Vec = Vec::new(); + let mut deduplicated_tools: Vec<(ToolCall, Option)> = Vec::new(); + + for tool_call in tools_to_process { + let mut duplicate_type = None; + + // Check for duplicates in current chunk + if seen_in_chunk.iter().any(|tc| are_duplicates(tc, &tool_call)) { + duplicate_type = Some("DUP IN CHUNK".to_string()); + } else { + // Check for duplicate against previous message in history + // Look at the last assistant message that contains tool calls + let mut found_in_prev = false; + for msg in self.context_window.conversation_history.iter().rev() { + if matches!(msg.role, MessageRole::Assistant) { + // Try to parse tool calls from the message content + if msg.content.contains(r#"\"tool\""#) { + // Simple JSON extraction for tool calls + let content = &msg.content; + let mut start_idx = 0; + while let Some(tool_start) = content[start_idx..].find(r#"{\"tool\""#) { + let tool_start = start_idx + tool_start; + // Find the end of this JSON object + let mut brace_count = 0; + let mut in_string = false; + let mut escape_next = false; + let mut end_idx = tool_start; + + for (i, ch) in content[tool_start..].char_indices() { + if escape_next { + escape_next = false; + continue; + } + if ch == '\\' && in_string { + escape_next = true; + continue; + } + if ch == '"' && !escape_next { + in_string = !in_string; + } + if !in_string { + if ch == '{' { + brace_count += 1; + } else if ch == '}' { + brace_count -= 1; + if brace_count == 0 { + end_idx = tool_start + i + 1; + break; + } + } + } + } + + if end_idx > tool_start { + let tool_json = &content[tool_start..end_idx]; + if let Ok(prev_tool) = serde_json::from_str::(tool_json) { + if are_duplicates(&prev_tool, &tool_call) { + found_in_prev = true; + break; + } + } + } + start_idx = end_idx; + } + } + // Only check the most recent assistant message + break; + } + } + + if found_in_prev { + duplicate_type = Some("DUP IN MSG".to_string()); + } + } + + // Add to seen list if not a duplicate in chunk + if duplicate_type.as_ref().map_or(true, |s| s != "DUP IN CHUNK") { + seen_in_chunk.push(tool_call.clone()); + } + + deduplicated_tools.push((tool_call, duplicate_type)); + } + + // Process each tool call + for (tool_call, duplicate_type) in deduplicated_tools { debug!("Processing completed tool call: {:?}", tool_call); + // If it's a duplicate, log it and return a warning + if let Some(dup_type) = &duplicate_type { + // Log the duplicate with red prefix + let prefixed_tool_name = format!("🟥 {} {}", tool_call.tool, dup_type); + let warning_msg = format!( + "⚠️ Duplicate tool call detected ({}): Skipping execution of {} with args {}", + dup_type, + tool_call.tool, + serde_json::to_string(&tool_call.args).unwrap_or_else(|_| "".to_string()) + ); + + // Log to tool log with red prefix + let mut modified_tool_call = tool_call.clone(); + modified_tool_call.tool = prefixed_tool_name; + self.log_tool_call(&modified_tool_call, &warning_msg); + continue; // Skip execution of duplicate + } + // Check if we should auto-compact at 90% BEFORE executing the tool // We need to do this before any borrows of self if self.auto_compact && self.context_window.percentage_used() >= 90.0 { @@ -3140,7 +3320,16 @@ impl Agent { current_response.clear(); // Reset response_started flag for next iteration response_started = false; - break; // Break out of current stream to start a new one + + // For single tool mode, break immediately + if !self.config.agent.allow_multiple_tool_calls { + break; // Break out of current stream to start a new one + } + } // End of for loop processing each tool call + + // If we processed any tools in multiple mode, break out to start new stream + if tool_executed && self.config.agent.allow_multiple_tool_calls { + break; } // If no tool calls were completed, continue streaming normally @@ -3223,8 +3412,8 @@ impl Agent { error!("Iteration: {}/{}", iteration_count, MAX_ITERATIONS); error!( "Provider: {} (model: {})", - provider.name(), - provider.model() + provider_name, + provider_model ); error!("Chunks received: {}", chunks_received); error!("Parser state:"); @@ -3503,7 +3692,17 @@ impl Agent { pub async fn execute_tool(&mut self, tool_call: &ToolCall) -> Result { // Increment tool call count self.tool_call_count += 1; - + + let result = self.execute_tool_inner(tool_call).await; + let log_str = match &result { + Ok(s) => s.clone(), + Err(e) => format!("ERROR: {}", e), + }; + self.log_tool_call(tool_call, &log_str); + result + } + + async fn execute_tool_inner(&mut self, tool_call: &ToolCall) -> Result { debug!("=== EXECUTING TOOL ==="); debug!("Tool name: {}", tool_call.tool); debug!("Tool args (raw): {:?}", tool_call.args); diff --git a/crates/g3-core/src/prompts.rs b/crates/g3-core/src/prompts.rs index c1a1a34..3acabf0 100644 --- a/crates/g3-core/src/prompts.rs +++ b/crates/g3-core/src/prompts.rs @@ -187,6 +187,24 @@ Do not explain what you're going to do - just do it by calling the tools. pub const SYSTEM_PROMPT_FOR_NATIVE_TOOL_USE: &'static str = concatcp!(CODING_STYLE, SYSTEM_NATIVE_TOOL_CALLS); +/// Generate system prompt based on whether multiple tool calls are allowed +pub fn get_system_prompt_for_native(allow_multiple: bool) -> String { + if allow_multiple { + // Replace the "ONE tool" instruction with multiple tools instruction + let base = SYSTEM_PROMPT_FOR_NATIVE_TOOL_USE.to_string(); + base.replace( + "2. Call the appropriate tool with the required parameters", + "2. Call the appropriate tool(s) with the required parameters - you may call multiple tools in parallel when appropriate. + + For maximum efficiency, whenever you perform multiple independent operations, invoke all relevant tools simultaneously rather than sequentially. Prioritize calling tools in parallel whenever possible. For example, when reading 3 files, run 3 tool calls in parallel to read all 3 files into context at the same time. When running multiple read-only commands like `ls` or `list_dir`, always run all of the commands in parallel. Err on the side of maximizing parallel tool calls rather than running too many tools sequentially. + +" + ) + } else { + SYSTEM_PROMPT_FOR_NATIVE_TOOL_USE.to_string() + } +} + const SYSTEM_NON_NATIVE_TOOL_USE: &'static str = "You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems by writing code.