Fix for tool use

This commit is contained in:
Dhanji Prasanna
2025-09-20 20:17:50 +10:00
parent 444245d7dd
commit 9a5486f2a8
4 changed files with 385 additions and 43 deletions

View File

@@ -1,8 +1,9 @@
use anyhow::Result;
use g3_config::Config;
use g3_execution::CodeExecutor;
use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry};
use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry, Tool};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::fs;
use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
@@ -412,7 +413,7 @@ impl Agent {
/// Split a complex request into simpler sub-tasks
async fn split_complex_request(&mut self, description: &str) -> Result<Vec<String>> {
let provider = self.providers.get(None)?;
// Create a specific prompt to split the task
let split_prompt = format!(
"Analyze this request and split it into simpler, independent sub-tasks. \
@@ -423,7 +424,7 @@ impl Agent {
Sub-tasks:",
description
);
let messages = vec![
Message {
role: MessageRole::System,
@@ -434,24 +435,26 @@ impl Agent {
content: split_prompt,
},
];
let request = CompletionRequest {
messages,
max_tokens: Some(512),
temperature: Some(0.1),
stream: false,
tools: None, // No tools needed for task splitting
};
// Use the non-streaming complete method
let response = provider.complete(request).await?;
// Split the response by newlines and filter out empty lines
let tasks: Vec<String> = response.content
let tasks: Vec<String> = response
.content
.lines()
.filter(|line| !line.trim().is_empty())
.map(|line| line.trim().to_string())
.collect();
// If we got back multiple tasks, return them; otherwise return the original
if tasks.len() > 1 {
info!("Split complex request into {} sub-tasks", tasks.len());
@@ -479,39 +482,44 @@ impl Agent {
// First, attempt to split the request into simpler sub-tasks
let sub_tasks = self.split_complex_request(description).await?;
// If we have multiple sub-tasks, execute them sequentially
if sub_tasks.len() > 1 {
println!("📋 Breaking down request into {} sub-tasks:", sub_tasks.len());
println!(
"📋 Breaking down request into {} sub-tasks:",
sub_tasks.len()
);
for (i, task) in sub_tasks.iter().enumerate() {
println!(" {}. {}", i + 1, task);
}
println!();
let mut all_responses = Vec::new();
for (i, sub_task) in sub_tasks.iter().enumerate() {
println!("━━━ Sub-task {}/{} ━━━", i + 1, sub_tasks.len());
println!("📌 {}", sub_task);
println!();
// Execute each sub-task
let result = self.execute_single_task(
sub_task,
show_prompt,
show_code,
show_timing,
cancellation_token.clone()
).await?;
let result = self
.execute_single_task(
sub_task,
show_prompt,
show_code,
show_timing,
cancellation_token.clone(),
)
.await?;
all_responses.push(result);
// Add some spacing between tasks
if i < sub_tasks.len() - 1 {
println!();
}
}
// Combine all responses
println!("\n━━━ All sub-tasks completed ━━━");
Ok(all_responses.join("\n\n---\n\n"))
@@ -522,8 +530,9 @@ impl Agent {
show_prompt,
show_code,
show_timing,
cancellation_token
).await
cancellation_token,
)
.await
}
}
@@ -542,7 +551,27 @@ impl Agent {
// Only add system message if this is the first interaction (empty conversation history)
if self.context_window.conversation_history.is_empty() {
let system_prompt = "You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems by writing code.
let provider = self.providers.get(None)?;
let system_prompt = if provider.has_native_tool_calling() {
// For native tool calling providers, use a more explicit system prompt
"You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems by writing code.
You have access to tools. When you need to accomplish a task, you MUST use the appropriate tool immediately. Do not just describe what you would do - actually use the tools.
IMPORTANT: You must call tools to complete tasks. When you receive a request:
1. Identify what needs to be done
2. Immediately call the appropriate tool with the required parameters
3. Wait for the tool result
4. Continue or complete the task based on the result
For shell commands: Use the shell tool with the exact command needed. Avoid commands that produce a large amount of output, and consider piping those outputs to files. Example: If asked to list files, immediately call the shell tool with command parameter \"ls\".
For task completion: Use the final_output tool with a summary.
Do not explain what you're going to do - just do it by calling the tools.".to_string()
} else {
// For non-native providers (embedded models), use JSON format instructions
"You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems by writing code.
# Tool Call Format
@@ -573,7 +602,8 @@ The tool will execute immediately and you'll receive the result (success or erro
- Use Markdown formatting for all responses except tool calls.
- Whenever taking actions, use the pronoun 'I'
".to_string();
".to_string()
};
if show_prompt {
println!("🔍 System Prompt:");
@@ -601,11 +631,20 @@ The tool will execute immediately and you'll receive the result (success or erro
// Use the complete conversation history for the request
let messages = self.context_window.conversation_history.clone();
// Check if provider supports native tool calling and add tools if so
let provider = self.providers.get(None)?;
let tools = if provider.has_native_tool_calling() {
Some(Self::create_tool_definitions())
} else {
None
};
let request = CompletionRequest {
messages,
max_tokens: Some(2048),
temperature: Some(0.1),
stream: true, // Enable streaming
tools,
};
// Time the LLM call with cancellation support and streaming
@@ -738,6 +777,40 @@ The tool will execute immediately and you'll receive the result (success or erro
self.stream_completion_with_tools(request).await
}
/// Create tool definitions for native tool calling providers
fn create_tool_definitions() -> Vec<Tool> {
vec![
Tool {
name: "shell".to_string(),
description: "Execute shell commands".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The shell command to execute"
}
},
"required": ["command"]
}),
},
Tool {
name: "final_output".to_string(),
description: "Signal task completion with a detailed summary".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"summary": {
"type": "string",
"description": "A detailed summary of what was accomplished"
}
},
"required": ["summary"]
}),
},
]
}
async fn stream_completion_with_tools(
&mut self,
mut request: CompletionRequest,
@@ -801,8 +874,7 @@ The tool will execute immediately and you'll receive the result (success or erro
first_token_time = Some(stream_start.elapsed());
}
// Check for tool calls - either from JSON parsing (embedded models)
// or from native tool calls (Anthropic, OpenAI, etc.)
// Check for tool calls - prioritize native tool calls over JSON parsing
let mut detected_tool_call = None;
// First check for native tool calls in the chunk
@@ -823,9 +895,9 @@ The tool will execute immediately and you'll receive the result (success or erro
debug!("No native tool calls in chunk, chunk.tool_calls is None");
}
// If no native tool calls, check for JSON tool calls in text (embedded models)
// IMPORTANT: Pass raw content to parser BEFORE cleaning stop sequences
if detected_tool_call.is_none() {
// Only fall back to JSON parsing if no native tool calls and provider doesn't support native calling
if detected_tool_call.is_none() && !provider.has_native_tool_calling() {
// For embedded models and other non-native providers, parse JSON from text
detected_tool_call = parser.add_chunk(&chunk.content);
}
@@ -958,6 +1030,11 @@ The tool will execute immediately and you'll receive the result (success or erro
// Update the request with the new context for next iteration
request.messages = self.context_window.conversation_history.clone();
// Ensure tools are included for native providers in subsequent iterations
if provider.has_native_tool_calling() {
request.tools = Some(Self::create_tool_definitions());
}
full_response.push_str(final_display_content);
full_response.push_str(&format!(
"\n\nTool executed: {} -> {}\n\n",
@@ -1023,10 +1100,21 @@ The tool will execute immediately and you'll receive the result (success or erro
}
async fn execute_tool(&self, tool_call: &ToolCall) -> Result<String> {
debug!("Executing tool: {}", tool_call.tool);
debug!("Tool call args: {:?}", tool_call.args);
debug!(
"Tool call args JSON: {}",
serde_json::to_string(&tool_call.args)
.unwrap_or_else(|_| "failed to serialize".to_string())
);
match tool_call.tool.as_str() {
"shell" => {
debug!("Processing shell tool call");
if let Some(command) = tool_call.args.get("command") {
debug!("Found command parameter: {:?}", command);
if let Some(command_str) = command.as_str() {
debug!("Command string: {}", command_str);
// Use shell escaping to handle filenames with spaces and special characters
let escaped_command = shell_escape_command(command_str);
@@ -1046,9 +1134,18 @@ The tool will execute immediately and you'll receive the result (success or erro
Err(e) => Ok(format!("❌ Execution error: {}", e)),
}
} else {
debug!("Command parameter is not a string: {:?}", command);
Ok("❌ Invalid command argument".to_string())
}
} else {
debug!("No command parameter found in args: {:?}", tool_call.args);
debug!(
"Available keys: {:?}",
tool_call
.args
.as_object()
.map(|obj| obj.keys().collect::<Vec<_>>())
);
Ok("❌ Missing command argument".to_string())
}
}