diff --git a/Cargo.lock b/Cargo.lock index af6e193..57b1543 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1884,18 +1884,28 @@ checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" [[package]] name = "serde" -version = "1.0.219" +version = "1.0.225" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "fd6c24dee235d0da097043389623fb913daddf92c76e9f5a1db88607a0bcbd1d" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.225" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "659356f9a0cb1e529b24c01e43ad2bdf520ec4ceaf83047b83ddcc2251f96383" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.219" +version = "1.0.225" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "0ea936adf78b1f766949a4977b91d2f5595825bd6ec079aa9543ad2685fc4516" dependencies = [ "proc-macro2", "quote", diff --git a/crates/g3-core/src/lib.rs b/crates/g3-core/src/lib.rs index 5727c51..b1b20f4 100644 --- a/crates/g3-core/src/lib.rs +++ b/crates/g3-core/src/lib.rs @@ -1,8 +1,9 @@ use anyhow::Result; use g3_config::Config; use g3_execution::CodeExecutor; -use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry}; +use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry, Tool}; use serde::{Deserialize, Serialize}; +use serde_json::json; use std::fs; use std::time::{Duration, Instant}; use tokio_util::sync::CancellationToken; @@ -412,7 +413,7 @@ impl Agent { /// Split a complex request into simpler sub-tasks async fn split_complex_request(&mut self, description: &str) -> Result> { let provider = self.providers.get(None)?; - + // Create a specific prompt to split the task let split_prompt = format!( "Analyze this request and split it into simpler, independent sub-tasks. \ @@ -423,7 +424,7 @@ impl Agent { Sub-tasks:", description ); - + let messages = vec![ Message { role: MessageRole::System, @@ -434,24 +435,26 @@ impl Agent { content: split_prompt, }, ]; - + let request = CompletionRequest { messages, max_tokens: Some(512), temperature: Some(0.1), stream: false, + tools: None, // No tools needed for task splitting }; - + // Use the non-streaming complete method let response = provider.complete(request).await?; - + // Split the response by newlines and filter out empty lines - let tasks: Vec = response.content + let tasks: Vec = response + .content .lines() .filter(|line| !line.trim().is_empty()) .map(|line| line.trim().to_string()) .collect(); - + // If we got back multiple tasks, return them; otherwise return the original if tasks.len() > 1 { info!("Split complex request into {} sub-tasks", tasks.len()); @@ -479,39 +482,44 @@ impl Agent { // First, attempt to split the request into simpler sub-tasks let sub_tasks = self.split_complex_request(description).await?; - + // If we have multiple sub-tasks, execute them sequentially if sub_tasks.len() > 1 { - println!("šŸ“‹ Breaking down request into {} sub-tasks:", sub_tasks.len()); + println!( + "šŸ“‹ Breaking down request into {} sub-tasks:", + sub_tasks.len() + ); for (i, task) in sub_tasks.iter().enumerate() { println!(" {}. {}", i + 1, task); } println!(); - + let mut all_responses = Vec::new(); - + for (i, sub_task) in sub_tasks.iter().enumerate() { println!("━━━ Sub-task {}/{} ━━━", i + 1, sub_tasks.len()); println!("šŸ“Œ {}", sub_task); println!(); - + // Execute each sub-task - let result = self.execute_single_task( - sub_task, - show_prompt, - show_code, - show_timing, - cancellation_token.clone() - ).await?; - + let result = self + .execute_single_task( + sub_task, + show_prompt, + show_code, + show_timing, + cancellation_token.clone(), + ) + .await?; + all_responses.push(result); - + // Add some spacing between tasks if i < sub_tasks.len() - 1 { println!(); } } - + // Combine all responses println!("\n━━━ All sub-tasks completed ━━━"); Ok(all_responses.join("\n\n---\n\n")) @@ -522,8 +530,9 @@ impl Agent { show_prompt, show_code, show_timing, - cancellation_token - ).await + cancellation_token, + ) + .await } } @@ -542,7 +551,27 @@ impl Agent { // Only add system message if this is the first interaction (empty conversation history) if self.context_window.conversation_history.is_empty() { - let system_prompt = "You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems by writing code. + let provider = self.providers.get(None)?; + let system_prompt = if provider.has_native_tool_calling() { + // For native tool calling providers, use a more explicit system prompt + "You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems by writing code. + +You have access to tools. When you need to accomplish a task, you MUST use the appropriate tool immediately. Do not just describe what you would do - actually use the tools. + +IMPORTANT: You must call tools to complete tasks. When you receive a request: +1. Identify what needs to be done +2. Immediately call the appropriate tool with the required parameters +3. Wait for the tool result +4. Continue or complete the task based on the result + +For shell commands: Use the shell tool with the exact command needed. Avoid commands that produce a large amount of output, and consider piping those outputs to files. Example: If asked to list files, immediately call the shell tool with command parameter \"ls\". +For task completion: Use the final_output tool with a summary. + + +Do not explain what you're going to do - just do it by calling the tools.".to_string() + } else { + // For non-native providers (embedded models), use JSON format instructions + "You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems by writing code. # Tool Call Format @@ -573,7 +602,8 @@ The tool will execute immediately and you'll receive the result (success or erro - Use Markdown formatting for all responses except tool calls. - Whenever taking actions, use the pronoun 'I' -".to_string(); +".to_string() + }; if show_prompt { println!("šŸ” System Prompt:"); @@ -601,11 +631,20 @@ The tool will execute immediately and you'll receive the result (success or erro // Use the complete conversation history for the request let messages = self.context_window.conversation_history.clone(); + // Check if provider supports native tool calling and add tools if so + let provider = self.providers.get(None)?; + let tools = if provider.has_native_tool_calling() { + Some(Self::create_tool_definitions()) + } else { + None + }; + let request = CompletionRequest { messages, max_tokens: Some(2048), temperature: Some(0.1), stream: true, // Enable streaming + tools, }; // Time the LLM call with cancellation support and streaming @@ -738,6 +777,40 @@ The tool will execute immediately and you'll receive the result (success or erro self.stream_completion_with_tools(request).await } + /// Create tool definitions for native tool calling providers + fn create_tool_definitions() -> Vec { + vec![ + Tool { + name: "shell".to_string(), + description: "Execute shell commands".to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The shell command to execute" + } + }, + "required": ["command"] + }), + }, + Tool { + name: "final_output".to_string(), + description: "Signal task completion with a detailed summary".to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "summary": { + "type": "string", + "description": "A detailed summary of what was accomplished" + } + }, + "required": ["summary"] + }), + }, + ] + } + async fn stream_completion_with_tools( &mut self, mut request: CompletionRequest, @@ -801,8 +874,7 @@ The tool will execute immediately and you'll receive the result (success or erro first_token_time = Some(stream_start.elapsed()); } - // Check for tool calls - either from JSON parsing (embedded models) - // or from native tool calls (Anthropic, OpenAI, etc.) + // Check for tool calls - prioritize native tool calls over JSON parsing let mut detected_tool_call = None; // First check for native tool calls in the chunk @@ -823,9 +895,9 @@ The tool will execute immediately and you'll receive the result (success or erro debug!("No native tool calls in chunk, chunk.tool_calls is None"); } - // If no native tool calls, check for JSON tool calls in text (embedded models) - // IMPORTANT: Pass raw content to parser BEFORE cleaning stop sequences - if detected_tool_call.is_none() { + // Only fall back to JSON parsing if no native tool calls and provider doesn't support native calling + if detected_tool_call.is_none() && !provider.has_native_tool_calling() { + // For embedded models and other non-native providers, parse JSON from text detected_tool_call = parser.add_chunk(&chunk.content); } @@ -958,6 +1030,11 @@ The tool will execute immediately and you'll receive the result (success or erro // Update the request with the new context for next iteration request.messages = self.context_window.conversation_history.clone(); + // Ensure tools are included for native providers in subsequent iterations + if provider.has_native_tool_calling() { + request.tools = Some(Self::create_tool_definitions()); + } + full_response.push_str(final_display_content); full_response.push_str(&format!( "\n\nTool executed: {} -> {}\n\n", @@ -1023,10 +1100,21 @@ The tool will execute immediately and you'll receive the result (success or erro } async fn execute_tool(&self, tool_call: &ToolCall) -> Result { + debug!("Executing tool: {}", tool_call.tool); + debug!("Tool call args: {:?}", tool_call.args); + debug!( + "Tool call args JSON: {}", + serde_json::to_string(&tool_call.args) + .unwrap_or_else(|_| "failed to serialize".to_string()) + ); + match tool_call.tool.as_str() { "shell" => { + debug!("Processing shell tool call"); if let Some(command) = tool_call.args.get("command") { + debug!("Found command parameter: {:?}", command); if let Some(command_str) = command.as_str() { + debug!("Command string: {}", command_str); // Use shell escaping to handle filenames with spaces and special characters let escaped_command = shell_escape_command(command_str); @@ -1046,9 +1134,18 @@ The tool will execute immediately and you'll receive the result (success or erro Err(e) => Ok(format!("āŒ Execution error: {}", e)), } } else { + debug!("Command parameter is not a string: {:?}", command); Ok("āŒ Invalid command argument".to_string()) } } else { + debug!("No command parameter found in args: {:?}", tool_call.args); + debug!( + "Available keys: {:?}", + tool_call + .args + .as_object() + .map(|obj| obj.keys().collect::>()) + ); Ok("āŒ Missing command argument".to_string()) } } diff --git a/crates/g3-providers/src/anthropic.rs b/crates/g3-providers/src/anthropic.rs index 0eb1ad3..5fc9be7 100644 --- a/crates/g3-providers/src/anthropic.rs +++ b/crates/g3-providers/src/anthropic.rs @@ -108,7 +108,7 @@ use tracing::{debug, error, info, warn}; use crate::{ CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message, - MessageRole, Usage, + MessageRole, Tool, ToolCall, Usage, }; const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages"; @@ -163,6 +163,51 @@ impl AnthropicProvider { builder } + fn convert_tools(&self, tools: &[Tool]) -> Vec { + tools + .iter() + .map(|tool| { + let mut schema = AnthropicToolInputSchema { + schema_type: "object".to_string(), + properties: serde_json::Value::Object(serde_json::Map::new()), + required: None, + }; + + // Extract properties and required fields from the input schema + if let Ok(schema_obj) = serde_json::from_value::>(tool.input_schema.clone()) { + if let Some(properties) = schema_obj.get("properties") { + schema.properties = properties.clone(); + } + if let Some(required) = schema_obj.get("required") { + if let Ok(required_vec) = serde_json::from_value::>(required.clone()) { + schema.required = Some(required_vec); + } + } + } + + AnthropicTool { + name: tool.name.clone(), + description: tool.description.clone(), + input_schema: schema, + } + }) + .collect() + } + + fn convert_anthropic_tool_calls(&self, content: &[AnthropicContent]) -> Vec { + content + .iter() + .filter_map(|c| match c { + AnthropicContent::ToolUse { id, name, input } => Some(ToolCall { + id: id.clone(), + tool: name.clone(), + args: input.clone(), + }), + _ => None, + }) + .collect() + } + fn convert_messages(&self, messages: &[Message]) -> Result<(Option, Vec)> { let mut system_message = None; let mut anthropic_messages = Vec::new(); @@ -200,6 +245,7 @@ impl AnthropicProvider { fn create_request_body( &self, messages: &[Message], + tools: Option<&[Tool]>, streaming: bool, max_tokens: u32, temperature: f32, @@ -210,12 +256,16 @@ impl AnthropicProvider { return Err(anyhow!("At least one user or assistant message is required")); } + // Convert tools if provided + let anthropic_tools = tools.map(|t| self.convert_tools(t)); + let request = AnthropicRequest { model: self.model.clone(), max_tokens, temperature, messages: anthropic_messages, system, + tools: anthropic_tools, stream: streaming, }; @@ -233,6 +283,8 @@ impl AnthropicProvider { tx: mpsc::Sender>, ) { let mut buffer = String::new(); + let mut current_tool_calls: Vec = Vec::new(); + let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls while let Some(chunk_result) = stream.next().await { match chunk_result { @@ -266,7 +318,7 @@ impl AnthropicProvider { let final_chunk = CompletionChunk { content: String::new(), finished: true, - tool_calls: None, + tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) }, }; if tx.send(Ok(final_chunk)).await.is_err() { debug!("Receiver dropped, stopping stream"); @@ -274,9 +326,38 @@ impl AnthropicProvider { return; } + debug!("Raw Claude API JSON: {}", data); + match serde_json::from_str::(data) { Ok(event) => { + debug!("Parsed event: {:?}", event); match event.event_type.as_str() { + "content_block_start" => { + debug!("Received content_block_start event: {:?}", event); + if let Some(content_block) = event.content_block { + match content_block { + AnthropicContent::ToolUse { id, name, input } => { + debug!("Found tool use in content_block_start: id={}, name={}, input={:?}", id, name, input); + debug!("Input JSON string: {}", serde_json::to_string(&input).unwrap_or_else(|_| "failed to serialize".to_string())); + + // Create initial tool call - we'll update the args later from streaming JSON + let tool_call = ToolCall { + id: id.clone(), + tool: name.clone(), + args: input, // This might be empty initially + }; + debug!("Created initial tool call: {:?}", tool_call); + current_tool_calls.push(tool_call); + + // Reset partial JSON accumulator for this tool + partial_tool_json.clear(); + } + _ => { + debug!("Non-tool content block: {:?}", content_block); + } + } + } + } "content_block_delta" => { if let Some(delta) = event.delta { if let Some(text) = delta.text { @@ -290,6 +371,44 @@ impl AnthropicProvider { return; } } + // Handle partial JSON for tool calls + if let Some(partial_json) = delta.partial_json { + debug!("Received partial JSON: {}", partial_json); + partial_tool_json.push_str(&partial_json); + debug!("Accumulated tool JSON: {}", partial_tool_json); + } + } + } + "content_block_stop" => { + // Tool call block is complete - now parse the accumulated JSON + if !current_tool_calls.is_empty() && !partial_tool_json.is_empty() { + debug!("Parsing complete tool JSON: {}", partial_tool_json); + + // Parse the accumulated JSON and update the last tool call + if let Ok(parsed_args) = serde_json::from_str::(&partial_tool_json) { + if let Some(last_tool) = current_tool_calls.last_mut() { + last_tool.args = parsed_args; + debug!("Updated tool call with complete args: {:?}", last_tool); + } + } else { + debug!("Failed to parse accumulated JSON: {}", partial_tool_json); + } + + // Clear the accumulator + partial_tool_json.clear(); + } + + // Send the complete tool call + if !current_tool_calls.is_empty() { + let chunk = CompletionChunk { + content: String::new(), + finished: false, + tool_calls: Some(current_tool_calls.clone()), + }; + if tx.send(Ok(chunk)).await.is_err() { + debug!("Receiver dropped, stopping stream"); + return; + } } } "message_stop" => { @@ -297,7 +416,7 @@ impl AnthropicProvider { let final_chunk = CompletionChunk { content: String::new(), finished: true, - tool_calls: None, + tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) }, }; if tx.send(Ok(final_chunk)).await.is_err() { debug!("Receiver dropped, stopping stream"); @@ -338,7 +457,7 @@ impl AnthropicProvider { let final_chunk = CompletionChunk { content: String::new(), finished: true, - tool_calls: None, + tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls) }, }; let _ = tx.send(Ok(final_chunk)).await; } @@ -355,7 +474,13 @@ impl LLMProvider for AnthropicProvider { let max_tokens = request.max_tokens.unwrap_or(self.max_tokens); let temperature = request.temperature.unwrap_or(self.temperature); - let request_body = self.create_request_body(&request.messages, false, max_tokens, temperature)?; + let request_body = self.create_request_body( + &request.messages, + request.tools.as_deref(), + false, + max_tokens, + temperature + )?; debug!("Sending request to Anthropic API: model={}, max_tokens={}, temperature={}", request_body.model, request_body.max_tokens, request_body.temperature); @@ -387,6 +512,7 @@ impl LLMProvider for AnthropicProvider { .iter() .filter_map(|c| match c { AnthropicContent::Text { text } => Some(text.as_str()), + _ => None, }) .collect::>() .join(""); @@ -418,10 +544,19 @@ impl LLMProvider for AnthropicProvider { let max_tokens = request.max_tokens.unwrap_or(self.max_tokens); let temperature = request.temperature.unwrap_or(self.temperature); - let request_body = self.create_request_body(&request.messages, true, max_tokens, temperature)?; + let request_body = self.create_request_body( + &request.messages, + request.tools.as_deref(), + true, + max_tokens, + temperature + )?; debug!("Sending streaming request to Anthropic API: model={}, max_tokens={}, temperature={}", request_body.model, request_body.max_tokens, request_body.temperature); + + // Debug: Log the full request body + debug!("Full request body: {}", serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "Failed to serialize".to_string())); let response = self .create_request_builder(true) @@ -475,9 +610,27 @@ struct AnthropicRequest { messages: Vec, #[serde(skip_serializing_if = "Option::is_none")] system: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, stream: bool, } +#[derive(Debug, Serialize)] +struct AnthropicTool { + name: String, + description: String, + input_schema: AnthropicToolInputSchema, +} + +#[derive(Debug, Serialize)] +struct AnthropicToolInputSchema { + #[serde(rename = "type")] + schema_type: String, + properties: serde_json::Value, + #[serde(skip_serializing_if = "Option::is_none")] + required: Option>, +} + #[derive(Debug, Serialize, Deserialize)] struct AnthropicMessage { role: String, @@ -489,6 +642,12 @@ struct AnthropicMessage { enum AnthropicContent { #[serde(rename = "text")] Text { text: String }, + #[serde(rename = "tool_use")] + ToolUse { + id: String, + name: String, + input: serde_json::Value, + }, } #[derive(Debug, Deserialize)] @@ -520,6 +679,8 @@ struct AnthropicStreamEvent { delta: Option, #[serde(default)] error: Option, + #[serde(default)] + content_block: Option, } #[derive(Debug, Deserialize)] @@ -527,6 +688,7 @@ struct AnthropicDelta { #[serde(rename = "type")] delta_type: Option, text: Option, + partial_json: Option, } #[derive(Debug, Deserialize)] @@ -589,7 +751,7 @@ mod tests { ]; let request_body = provider - .create_request_body(&messages, false, 1000, 0.5) + .create_request_body(&messages, None, false, 1000, 0.5) .unwrap(); assert_eq!(request_body.model, "claude-3-haiku-20240307"); @@ -597,5 +759,70 @@ mod tests { assert_eq!(request_body.temperature, 0.5); assert!(!request_body.stream); assert_eq!(request_body.messages.len(), 1); + assert!(request_body.tools.is_none()); + } + + #[test] + fn test_tool_conversion() { + let provider = AnthropicProvider::new( + "test-key".to_string(), + None, + None, + None, + ).unwrap(); + + let tools = vec![ + Tool { + name: "get_weather".to_string(), + description: "Get the current weather".to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state" + } + }, + "required": ["location"] + }), + }, + ]; + + let anthropic_tools = provider.convert_tools(&tools); + + assert_eq!(anthropic_tools.len(), 1); + assert_eq!(anthropic_tools[0].name, "get_weather"); + assert_eq!(anthropic_tools[0].description, "Get the current weather"); + assert_eq!(anthropic_tools[0].input_schema.schema_type, "object"); + assert!(anthropic_tools[0].input_schema.required.is_some()); + assert_eq!(anthropic_tools[0].input_schema.required.as_ref().unwrap()[0], "location"); + } + + #[test] + fn test_tool_call_conversion() { + let provider = AnthropicProvider::new( + "test-key".to_string(), + None, + None, + None, + ).unwrap(); + + let content = vec![ + AnthropicContent::Text { + text: "I'll help you get the weather.".to_string(), + }, + AnthropicContent::ToolUse { + id: "toolu_123".to_string(), + name: "get_weather".to_string(), + input: serde_json::json!({"location": "San Francisco, CA"}), + }, + ]; + + let tool_calls = provider.convert_anthropic_tool_calls(&content); + + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].id, "toolu_123"); + assert_eq!(tool_calls[0].tool, "get_weather"); + assert_eq!(tool_calls[0].args["location"], "San Francisco, CA"); } } diff --git a/crates/g3-providers/src/lib.rs b/crates/g3-providers/src/lib.rs index d49b375..18dc8ce 100644 --- a/crates/g3-providers/src/lib.rs +++ b/crates/g3-providers/src/lib.rs @@ -29,6 +29,7 @@ pub struct CompletionRequest { pub max_tokens: Option, pub temperature: Option, pub stream: bool, + pub tools: Option>, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -75,6 +76,13 @@ pub struct ToolCall { pub args: serde_json::Value, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Tool { + pub name: String, + pub description: String, + pub input_schema: serde_json::Value, +} + pub mod anthropic; pub use anthropic::AnthropicProvider;