Fix for tool use
@@ -1,8 +1,9 @@
use anyhow::Result;
use g3_config::Config;
use g3_execution::CodeExecutor;
use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry};
use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry, Tool};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::fs;
use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
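
For orientation, the newly imported Tool type is what create_tool_definitions (added later in this commit) constructs; a minimal sketch of its shape, with the fields inferred from that usage rather than taken from the g3_providers source:

// Sketch only: field names and types are inferred from how create_tool_definitions
// builds Tool values below; the real definition lives in g3_providers and may
// carry extra fields or derives.
pub struct Tool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}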
@@ -412,7 +413,7 @@ impl Agent {
    /// Split a complex request into simpler sub-tasks
    async fn split_complex_request(&mut self, description: &str) -> Result<Vec<String>> {
        let provider = self.providers.get(None)?;

        // Create a specific prompt to split the task
        let split_prompt = format!(
            "Analyze this request and split it into simpler, independent sub-tasks. \
@@ -423,7 +424,7 @@ impl Agent {
            Sub-tasks:",
            description
        );

        let messages = vec![
            Message {
                role: MessageRole::System,
@@ -434,24 +435,26 @@ impl Agent {
                content: split_prompt,
            },
        ];

        let request = CompletionRequest {
            messages,
            max_tokens: Some(512),
            temperature: Some(0.1),
            stream: false,
            tools: None, // No tools needed for task splitting
        };

        // Use the non-streaming complete method
        let response = provider.complete(request).await?;

        // Split the response by newlines and filter out empty lines
        let tasks: Vec<String> = response.content
        let tasks: Vec<String> = response
            .content
            .lines()
            .filter(|line| !line.trim().is_empty())
            .map(|line| line.trim().to_string())
            .collect();

        // If we got back multiple tasks, return them; otherwise return the original
        if tasks.len() > 1 {
            info!("Split complex request into {} sub-tasks", tasks.len());
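
As a side note, the filtering above simply treats every non-empty line of the model's reply as one sub-task; a self-contained sketch of that behaviour (the helper name and example strings here are illustrative, not part of the commit):

// Standalone illustration of the same line-splitting logic used above.
fn split_into_tasks(content: &str) -> Vec<String> {
    content
        .lines()
        .filter(|line| !line.trim().is_empty())
        .map(|line| line.trim().to_string())
        .collect()
}

fn main() {
    let tasks = split_into_tasks("1. Write the parser\n\n2. Add unit tests\n");
    assert_eq!(tasks, vec!["1. Write the parser", "2. Add unit tests"]);
}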
@@ -479,39 +482,44 @@ impl Agent {
        // First, attempt to split the request into simpler sub-tasks
        let sub_tasks = self.split_complex_request(description).await?;

        // If we have multiple sub-tasks, execute them sequentially
        if sub_tasks.len() > 1 {
            println!("📋 Breaking down request into {} sub-tasks:", sub_tasks.len());
            println!(
                "📋 Breaking down request into {} sub-tasks:",
                sub_tasks.len()
            );
            for (i, task) in sub_tasks.iter().enumerate() {
                println!("  {}. {}", i + 1, task);
            }
            println!();

            let mut all_responses = Vec::new();

            for (i, sub_task) in sub_tasks.iter().enumerate() {
                println!("━━━ Sub-task {}/{} ━━━", i + 1, sub_tasks.len());
                println!("📌 {}", sub_task);
                println!();

                // Execute each sub-task
                let result = self.execute_single_task(
                    sub_task,
                    show_prompt,
                    show_code,
                    show_timing,
                    cancellation_token.clone()
                ).await?;
                let result = self
                    .execute_single_task(
                        sub_task,
                        show_prompt,
                        show_code,
                        show_timing,
                        cancellation_token.clone(),
                    )
                    .await?;

                all_responses.push(result);

                // Add some spacing between tasks
                if i < sub_tasks.len() - 1 {
                    println!();
                }
            }

            // Combine all responses
            println!("\n━━━ All sub-tasks completed ━━━");
            Ok(all_responses.join("\n\n---\n\n"))
@@ -522,8 +530,9 @@ impl Agent {
                show_prompt,
                show_code,
                show_timing,
                cancellation_token
            ).await
                cancellation_token,
            )
            .await
        }
    }
@@ -542,7 +551,27 @@ impl Agent {
        // Only add system message if this is the first interaction (empty conversation history)
        if self.context_window.conversation_history.is_empty() {
            let system_prompt = "You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems by writing code.
            let provider = self.providers.get(None)?;
            let system_prompt = if provider.has_native_tool_calling() {
                // For native tool calling providers, use a more explicit system prompt
                "You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems by writing code.

You have access to tools. When you need to accomplish a task, you MUST use the appropriate tool immediately. Do not just describe what you would do - actually use the tools.

IMPORTANT: You must call tools to complete tasks. When you receive a request:
1. Identify what needs to be done
2. Immediately call the appropriate tool with the required parameters
3. Wait for the tool result
4. Continue or complete the task based on the result

For shell commands: Use the shell tool with the exact command needed. Avoid commands that produce a large amount of output, and consider piping those outputs to files. Example: If asked to list files, immediately call the shell tool with command parameter \"ls\".
For task completion: Use the final_output tool with a summary.

Do not explain what you're going to do - just do it by calling the tools.".to_string()
            } else {
                // For non-native providers (embedded models), use JSON format instructions
                "You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems by writing code.

# Tool Call Format
@@ -573,7 +602,8 @@ The tool will execute immediately and you'll receive the result (success or erro
- Use Markdown formatting for all responses except tool calls.
- Whenever taking actions, use the pronoun 'I'

".to_string();
".to_string()
            };

            if show_prompt {
                println!("🔍 System Prompt:");
@@ -601,11 +631,20 @@ The tool will execute immediately and you'll receive the result (success or erro
        // Use the complete conversation history for the request
        let messages = self.context_window.conversation_history.clone();

        // Check if provider supports native tool calling and add tools if so
        let provider = self.providers.get(None)?;
        let tools = if provider.has_native_tool_calling() {
            Some(Self::create_tool_definitions())
        } else {
            None
        };

        let request = CompletionRequest {
            messages,
            max_tokens: Some(2048),
            temperature: Some(0.1),
            stream: true, // Enable streaming
            tools,
        };
        // Time the LLM call with cancellation support and streaming
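
The tools field above is only populated when the provider reports native tool calling; a rough sketch of the capability flag the provider abstraction is assumed to expose (the trait name Provider is a placeholder; only has_native_tool_calling() appears in the diff):

// Illustrative only: the provider abstraction is assumed to expose a flag
// roughly like this, defaulting to the JSON-in-text fallback path.
trait Provider {
    /// True for APIs with first-class tool calling (e.g. Anthropic, OpenAI);
    /// false for embedded models, which receive JSON format instructions instead.
    fn has_native_tool_calling(&self) -> bool {
        false
    }
}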
@@ -738,6 +777,40 @@ The tool will execute immediately and you'll receive the result (success or erro
        self.stream_completion_with_tools(request).await
    }

    /// Create tool definitions for native tool calling providers
    fn create_tool_definitions() -> Vec<Tool> {
        vec![
            Tool {
                name: "shell".to_string(),
                description: "Execute shell commands".to_string(),
                input_schema: json!({
                    "type": "object",
                    "properties": {
                        "command": {
                            "type": "string",
                            "description": "The shell command to execute"
                        }
                    },
                    "required": ["command"]
                }),
            },
            Tool {
                name: "final_output".to_string(),
                description: "Signal task completion with a detailed summary".to_string(),
                input_schema: json!({
                    "type": "object",
                    "properties": {
                        "summary": {
                            "type": "string",
                            "description": "A detailed summary of what was accomplished"
                        }
                    },
                    "required": ["summary"]
                }),
            },
        ]
    }

    async fn stream_completion_with_tools(
        &mut self,
        mut request: CompletionRequest,
@@ -801,8 +874,7 @@ The tool will execute immediately and you'll receive the result (success or erro
                first_token_time = Some(stream_start.elapsed());
            }

            // Check for tool calls - either from JSON parsing (embedded models)
            // or from native tool calls (Anthropic, OpenAI, etc.)
            // Check for tool calls - prioritize native tool calls over JSON parsing
            let mut detected_tool_call = None;

            // First check for native tool calls in the chunk
@@ -823,9 +895,9 @@ The tool will execute immediately and you'll receive the result (success or erro
                debug!("No native tool calls in chunk, chunk.tool_calls is None");
            }

            // If no native tool calls, check for JSON tool calls in text (embedded models)
            // IMPORTANT: Pass raw content to parser BEFORE cleaning stop sequences
            if detected_tool_call.is_none() {
            // Only fall back to JSON parsing if no native tool calls and provider doesn't support native calling
            if detected_tool_call.is_none() && !provider.has_native_tool_calling() {
                // For embedded models and other non-native providers, parse JSON from text
                detected_tool_call = parser.add_chunk(&chunk.content);
            }
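
For non-native providers the parser is fed the raw streamed text and extracts a JSON tool call from it. The exact format is specified in the system prompt section that this diff truncates; the shape below is only an assumption, inferred from the ToolCall fields (tool, args) that execute_tool reads later on:

// Assumed shape of an embedded-model tool call; key names mirror the ToolCall
// fields used by execute_tool, not a confirmed wire format.
use serde::Deserialize;
use serde_json::Value;

#[derive(Debug, Deserialize)]
struct ToolCallSketch {
    tool: String,
    args: Value,
}

fn main() -> Result<(), serde_json::Error> {
    let raw = r#"{"tool": "shell", "args": {"command": "ls"}}"#;
    let call: ToolCallSketch = serde_json::from_str(raw)?;
    assert_eq!(call.tool, "shell");
    assert_eq!(call.args["command"], "ls");
    Ok(())
}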
@@ -958,6 +1030,11 @@ The tool will execute immediately and you'll receive the result (success or erro
            // Update the request with the new context for next iteration
            request.messages = self.context_window.conversation_history.clone();

            // Ensure tools are included for native providers in subsequent iterations
            if provider.has_native_tool_calling() {
                request.tools = Some(Self::create_tool_definitions());
            }

            full_response.push_str(final_display_content);
            full_response.push_str(&format!(
                "\n\nTool executed: {} -> {}\n\n",
@@ -1023,10 +1100,21 @@ The tool will execute immediately and you'll receive the result (success or erro
    }

    async fn execute_tool(&self, tool_call: &ToolCall) -> Result<String> {
        debug!("Executing tool: {}", tool_call.tool);
        debug!("Tool call args: {:?}", tool_call.args);
        debug!(
            "Tool call args JSON: {}",
            serde_json::to_string(&tool_call.args)
                .unwrap_or_else(|_| "failed to serialize".to_string())
        );

        match tool_call.tool.as_str() {
            "shell" => {
                debug!("Processing shell tool call");
                if let Some(command) = tool_call.args.get("command") {
                    debug!("Found command parameter: {:?}", command);
                    if let Some(command_str) = command.as_str() {
                        debug!("Command string: {}", command_str);
                        // Use shell escaping to handle filenames with spaces and special characters
                        let escaped_command = shell_escape_command(command_str);
@@ -1046,9 +1134,18 @@ The tool will execute immediately and you'll receive the result (success or erro
                            Err(e) => Ok(format!("❌ Execution error: {}", e)),
                        }
                    } else {
                        debug!("Command parameter is not a string: {:?}", command);
                        Ok("❌ Invalid command argument".to_string())
                    }
                } else {
                    debug!("No command parameter found in args: {:?}", tool_call.args);
                    debug!(
                        "Available keys: {:?}",
                        tool_call
                            .args
                            .as_object()
                            .map(|obj| obj.keys().collect::<Vec<_>>())
                    );
                    Ok("❌ Missing command argument".to_string())
                }
            }
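
The shell_escape_command helper referenced above is not shown in this commit. As one illustration of the kind of quoting such a helper typically performs for filenames with spaces and special characters (an assumption, not the project's implementation):

// Hypothetical POSIX-style single-quote escaping, shown only to illustrate the
// problem the comment above mentions; the real helper may work differently.
fn single_quote(arg: &str) -> String {
    // Wrap in single quotes and escape any embedded single quote as '\''.
    format!("'{}'", arg.replace('\'', r"'\''"))
}

fn main() {
    assert_eq!(single_quote("my file.txt"), "'my file.txt'");
    assert_eq!(single_quote("it's"), r"'it'\''s'");
}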