Fix for tool use

This commit is contained in:
Dhanji Prasanna
2025-09-20 20:17:50 +10:00
parent 444245d7dd
commit 9a5486f2a8
4 changed files with 385 additions and 43 deletions

18
Cargo.lock generated
View File

@@ -1884,18 +1884,28 @@ checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0"
[[package]]
name = "serde"
version = "1.0.219"
version = "1.0.225"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6"
checksum = "fd6c24dee235d0da097043389623fb913daddf92c76e9f5a1db88607a0bcbd1d"
dependencies = [
"serde_core",
"serde_derive",
]
[[package]]
name = "serde_core"
version = "1.0.225"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "659356f9a0cb1e529b24c01e43ad2bdf520ec4ceaf83047b83ddcc2251f96383"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.219"
version = "1.0.225"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
checksum = "0ea936adf78b1f766949a4977b91d2f5595825bd6ec079aa9543ad2685fc4516"
dependencies = [
"proc-macro2",
"quote",

View File

@@ -1,8 +1,9 @@
use anyhow::Result;
use g3_config::Config;
use g3_execution::CodeExecutor;
use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry};
use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry, Tool};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::fs;
use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
@@ -440,13 +441,15 @@ impl Agent {
max_tokens: Some(512),
temperature: Some(0.1),
stream: false,
tools: None, // No tools needed for task splitting
};
// Use the non-streaming complete method
let response = provider.complete(request).await?;
// Split the response by newlines and filter out empty lines
let tasks: Vec<String> = response.content
let tasks: Vec<String> = response
.content
.lines()
.filter(|line| !line.trim().is_empty())
.map(|line| line.trim().to_string())
@@ -482,7 +485,10 @@ impl Agent {
// If we have multiple sub-tasks, execute them sequentially
if sub_tasks.len() > 1 {
println!("📋 Breaking down request into {} sub-tasks:", sub_tasks.len());
println!(
"📋 Breaking down request into {} sub-tasks:",
sub_tasks.len()
);
for (i, task) in sub_tasks.iter().enumerate() {
println!(" {}. {}", i + 1, task);
}
@@ -496,13 +502,15 @@ impl Agent {
println!();
// Execute each sub-task
let result = self.execute_single_task(
sub_task,
show_prompt,
show_code,
show_timing,
cancellation_token.clone()
).await?;
let result = self
.execute_single_task(
sub_task,
show_prompt,
show_code,
show_timing,
cancellation_token.clone(),
)
.await?;
all_responses.push(result);
@@ -522,8 +530,9 @@ impl Agent {
show_prompt,
show_code,
show_timing,
cancellation_token
).await
cancellation_token,
)
.await
}
}
@@ -542,7 +551,27 @@ impl Agent {
// Only add system message if this is the first interaction (empty conversation history)
if self.context_window.conversation_history.is_empty() {
let system_prompt = "You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems by writing code.
let provider = self.providers.get(None)?;
let system_prompt = if provider.has_native_tool_calling() {
// For native tool calling providers, use a more explicit system prompt
"You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems by writing code.
You have access to tools. When you need to accomplish a task, you MUST use the appropriate tool immediately. Do not just describe what you would do - actually use the tools.
IMPORTANT: You must call tools to complete tasks. When you receive a request:
1. Identify what needs to be done
2. Immediately call the appropriate tool with the required parameters
3. Wait for the tool result
4. Continue or complete the task based on the result
For shell commands: Use the shell tool with the exact command needed. Avoid commands that produce a large amount of output, and consider piping those outputs to files. Example: If asked to list files, immediately call the shell tool with command parameter \"ls\".
For task completion: Use the final_output tool with a summary.
Do not explain what you're going to do - just do it by calling the tools.".to_string()
} else {
// For non-native providers (embedded models), use JSON format instructions
"You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems by writing code.
# Tool Call Format
@@ -573,7 +602,8 @@ The tool will execute immediately and you'll receive the result (success or erro
- Use Markdown formatting for all responses except tool calls.
- Whenever taking actions, use the pronoun 'I'
".to_string();
".to_string()
};
if show_prompt {
println!("🔍 System Prompt:");
@@ -601,11 +631,20 @@ The tool will execute immediately and you'll receive the result (success or erro
// Use the complete conversation history for the request
let messages = self.context_window.conversation_history.clone();
// Check if provider supports native tool calling and add tools if so
let provider = self.providers.get(None)?;
let tools = if provider.has_native_tool_calling() {
Some(Self::create_tool_definitions())
} else {
None
};
let request = CompletionRequest {
messages,
max_tokens: Some(2048),
temperature: Some(0.1),
stream: true, // Enable streaming
tools,
};
// Time the LLM call with cancellation support and streaming
@@ -738,6 +777,40 @@ The tool will execute immediately and you'll receive the result (success or erro
self.stream_completion_with_tools(request).await
}
/// Create tool definitions for native tool calling providers.
///
/// Returns the two tools G3 exposes to the model: `shell`, which runs a
/// shell command, and `final_output`, which signals task completion with
/// a summary. Each tool's `input_schema` is a JSON Schema object whose
/// `properties`/`required` fields the provider adapters translate into
/// their own wire format (see `AnthropicProvider::convert_tools`).
fn create_tool_definitions() -> Vec<Tool> {
vec![
Tool {
name: "shell".to_string(),
description: "Execute shell commands".to_string(),
// JSON Schema: one required string parameter, `command`.
input_schema: json!({
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The shell command to execute"
}
},
"required": ["command"]
}),
},
Tool {
name: "final_output".to_string(),
description: "Signal task completion with a detailed summary".to_string(),
// JSON Schema: one required string parameter, `summary`.
input_schema: json!({
"type": "object",
"properties": {
"summary": {
"type": "string",
"description": "A detailed summary of what was accomplished"
}
},
"required": ["summary"]
}),
},
]
}
async fn stream_completion_with_tools(
&mut self,
mut request: CompletionRequest,
@@ -801,8 +874,7 @@ The tool will execute immediately and you'll receive the result (success or erro
first_token_time = Some(stream_start.elapsed());
}
// Check for tool calls - either from JSON parsing (embedded models)
// or from native tool calls (Anthropic, OpenAI, etc.)
// Check for tool calls - prioritize native tool calls over JSON parsing
let mut detected_tool_call = None;
// First check for native tool calls in the chunk
@@ -823,9 +895,9 @@ The tool will execute immediately and you'll receive the result (success or erro
debug!("No native tool calls in chunk, chunk.tool_calls is None");
}
// If no native tool calls, check for JSON tool calls in text (embedded models)
// IMPORTANT: Pass raw content to parser BEFORE cleaning stop sequences
if detected_tool_call.is_none() {
// Only fall back to JSON parsing if no native tool calls and provider doesn't support native calling
if detected_tool_call.is_none() && !provider.has_native_tool_calling() {
// For embedded models and other non-native providers, parse JSON from text
detected_tool_call = parser.add_chunk(&chunk.content);
}
@@ -958,6 +1030,11 @@ The tool will execute immediately and you'll receive the result (success or erro
// Update the request with the new context for next iteration
request.messages = self.context_window.conversation_history.clone();
// Ensure tools are included for native providers in subsequent iterations
if provider.has_native_tool_calling() {
request.tools = Some(Self::create_tool_definitions());
}
full_response.push_str(final_display_content);
full_response.push_str(&format!(
"\n\nTool executed: {} -> {}\n\n",
@@ -1023,10 +1100,21 @@ The tool will execute immediately and you'll receive the result (success or erro
}
async fn execute_tool(&self, tool_call: &ToolCall) -> Result<String> {
debug!("Executing tool: {}", tool_call.tool);
debug!("Tool call args: {:?}", tool_call.args);
debug!(
"Tool call args JSON: {}",
serde_json::to_string(&tool_call.args)
.unwrap_or_else(|_| "failed to serialize".to_string())
);
match tool_call.tool.as_str() {
"shell" => {
debug!("Processing shell tool call");
if let Some(command) = tool_call.args.get("command") {
debug!("Found command parameter: {:?}", command);
if let Some(command_str) = command.as_str() {
debug!("Command string: {}", command_str);
// Use shell escaping to handle filenames with spaces and special characters
let escaped_command = shell_escape_command(command_str);
@@ -1046,9 +1134,18 @@ The tool will execute immediately and you'll receive the result (success or erro
Err(e) => Ok(format!("❌ Execution error: {}", e)),
}
} else {
debug!("Command parameter is not a string: {:?}", command);
Ok("❌ Invalid command argument".to_string())
}
} else {
debug!("No command parameter found in args: {:?}", tool_call.args);
debug!(
"Available keys: {:?}",
tool_call
.args
.as_object()
.map(|obj| obj.keys().collect::<Vec<_>>())
);
Ok("❌ Missing command argument".to_string())
}
}

View File

@@ -108,7 +108,7 @@ use tracing::{debug, error, info, warn};
use crate::{
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
MessageRole, Usage,
MessageRole, Tool, ToolCall, Usage,
};
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
@@ -163,6 +163,51 @@ impl AnthropicProvider {
builder
}
/// Convert provider-agnostic [`Tool`] definitions into Anthropic's
/// tool wire format.
///
/// Each tool's generic JSON Schema `input_schema` is decomposed into the
/// `AnthropicToolInputSchema` struct: the `properties` object and the
/// `required` list are copied over when present; anything else in the
/// schema is dropped. A tool whose schema is not a JSON object (or has
/// no `properties`) still converts, ending up with an empty properties
/// map and no required fields rather than failing.
fn convert_tools(&self, tools: &[Tool]) -> Vec<AnthropicTool> {
tools
.iter()
.map(|tool| {
// Start from an empty object schema; fill in below if the
// source schema parses.
let mut schema = AnthropicToolInputSchema {
schema_type: "object".to_string(),
properties: serde_json::Value::Object(serde_json::Map::new()),
required: None,
};
// Extract properties and required fields from the input schema
if let Ok(schema_obj) = serde_json::from_value::<serde_json::Map<String, serde_json::Value>>(tool.input_schema.clone()) {
if let Some(properties) = schema_obj.get("properties") {
schema.properties = properties.clone();
}
if let Some(required) = schema_obj.get("required") {
// `required` must deserialize to a list of strings;
// otherwise it is silently omitted.
if let Ok(required_vec) = serde_json::from_value::<Vec<String>>(required.clone()) {
schema.required = Some(required_vec);
}
}
}
AnthropicTool {
name: tool.name.clone(),
description: tool.description.clone(),
input_schema: schema,
}
})
.collect()
}
/// Extract tool invocations from an Anthropic response's content blocks.
///
/// Only `tool_use` blocks are converted into provider-agnostic
/// [`ToolCall`]s (carrying Anthropic's tool-use `id`, the tool name, and
/// the raw JSON `input` as the call's args); `text` and any other block
/// kinds are filtered out.
fn convert_anthropic_tool_calls(&self, content: &[AnthropicContent]) -> Vec<ToolCall> {
content
.iter()
.filter_map(|c| match c {
AnthropicContent::ToolUse { id, name, input } => Some(ToolCall {
id: id.clone(),
tool: name.clone(),
args: input.clone(),
}),
_ => None,
})
.collect()
}
fn convert_messages(&self, messages: &[Message]) -> Result<(Option<String>, Vec<AnthropicMessage>)> {
let mut system_message = None;
let mut anthropic_messages = Vec::new();
@@ -200,6 +245,7 @@ impl AnthropicProvider {
fn create_request_body(
&self,
messages: &[Message],
tools: Option<&[Tool]>,
streaming: bool,
max_tokens: u32,
temperature: f32,
@@ -210,12 +256,16 @@ impl AnthropicProvider {
return Err(anyhow!("At least one user or assistant message is required"));
}
// Convert tools if provided
let anthropic_tools = tools.map(|t| self.convert_tools(t));
let request = AnthropicRequest {
model: self.model.clone(),
max_tokens,
temperature,
messages: anthropic_messages,
system,
tools: anthropic_tools,
stream: streaming,
};
@@ -233,6 +283,8 @@ impl AnthropicProvider {
tx: mpsc::Sender<Result<CompletionChunk>>,
) {
let mut buffer = String::new();
let mut current_tool_calls: Vec<ToolCall> = Vec::new();
let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
while let Some(chunk_result) = stream.next().await {
match chunk_result {
@@ -266,7 +318,7 @@ impl AnthropicProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
tool_calls: None,
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
};
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
@@ -274,9 +326,38 @@ impl AnthropicProvider {
return;
}
debug!("Raw Claude API JSON: {}", data);
match serde_json::from_str::<AnthropicStreamEvent>(data) {
Ok(event) => {
debug!("Parsed event: {:?}", event);
match event.event_type.as_str() {
"content_block_start" => {
debug!("Received content_block_start event: {:?}", event);
if let Some(content_block) = event.content_block {
match content_block {
AnthropicContent::ToolUse { id, name, input } => {
debug!("Found tool use in content_block_start: id={}, name={}, input={:?}", id, name, input);
debug!("Input JSON string: {}", serde_json::to_string(&input).unwrap_or_else(|_| "failed to serialize".to_string()));
// Create initial tool call - we'll update the args later from streaming JSON
let tool_call = ToolCall {
id: id.clone(),
tool: name.clone(),
args: input, // This might be empty initially
};
debug!("Created initial tool call: {:?}", tool_call);
current_tool_calls.push(tool_call);
// Reset partial JSON accumulator for this tool
partial_tool_json.clear();
}
_ => {
debug!("Non-tool content block: {:?}", content_block);
}
}
}
}
"content_block_delta" => {
if let Some(delta) = event.delta {
if let Some(text) = delta.text {
@@ -290,6 +371,44 @@ impl AnthropicProvider {
return;
}
}
// Handle partial JSON for tool calls
if let Some(partial_json) = delta.partial_json {
debug!("Received partial JSON: {}", partial_json);
partial_tool_json.push_str(&partial_json);
debug!("Accumulated tool JSON: {}", partial_tool_json);
}
}
}
"content_block_stop" => {
// Tool call block is complete - now parse the accumulated JSON
if !current_tool_calls.is_empty() && !partial_tool_json.is_empty() {
debug!("Parsing complete tool JSON: {}", partial_tool_json);
// Parse the accumulated JSON and update the last tool call
if let Ok(parsed_args) = serde_json::from_str::<serde_json::Value>(&partial_tool_json) {
if let Some(last_tool) = current_tool_calls.last_mut() {
last_tool.args = parsed_args;
debug!("Updated tool call with complete args: {:?}", last_tool);
}
} else {
debug!("Failed to parse accumulated JSON: {}", partial_tool_json);
}
// Clear the accumulator
partial_tool_json.clear();
}
// Send the complete tool call
if !current_tool_calls.is_empty() {
let chunk = CompletionChunk {
content: String::new(),
finished: false,
tool_calls: Some(current_tool_calls.clone()),
};
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return;
}
}
}
"message_stop" => {
@@ -297,7 +416,7 @@ impl AnthropicProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
tool_calls: None,
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
};
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
@@ -338,7 +457,7 @@ impl AnthropicProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
tool_calls: None,
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls) },
};
let _ = tx.send(Ok(final_chunk)).await;
}
@@ -355,7 +474,13 @@ impl LLMProvider for AnthropicProvider {
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);
let request_body = self.create_request_body(&request.messages, false, max_tokens, temperature)?;
let request_body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
false,
max_tokens,
temperature
)?;
debug!("Sending request to Anthropic API: model={}, max_tokens={}, temperature={}",
request_body.model, request_body.max_tokens, request_body.temperature);
@@ -387,6 +512,7 @@ impl LLMProvider for AnthropicProvider {
.iter()
.filter_map(|c| match c {
AnthropicContent::Text { text } => Some(text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join("");
@@ -418,11 +544,20 @@ impl LLMProvider for AnthropicProvider {
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);
let request_body = self.create_request_body(&request.messages, true, max_tokens, temperature)?;
let request_body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
true,
max_tokens,
temperature
)?;
debug!("Sending streaming request to Anthropic API: model={}, max_tokens={}, temperature={}",
request_body.model, request_body.max_tokens, request_body.temperature);
// Debug: Log the full request body
debug!("Full request body: {}", serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "Failed to serialize".to_string()));
let response = self
.create_request_builder(true)
.json(&request_body)
@@ -475,9 +610,27 @@ struct AnthropicRequest {
messages: Vec<AnthropicMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
system: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<AnthropicTool>>,
stream: bool,
}
/// A tool definition in Anthropic's request wire format.
///
/// Serialize-only: sent in the `tools` array of an `AnthropicRequest`.
#[derive(Debug, Serialize)]
struct AnthropicTool {
// Tool identifier the model uses to invoke it.
name: String,
// Human-readable description shown to the model.
description: String,
// Structured JSON Schema for the tool's input (see
// `AnthropicToolInputSchema`).
input_schema: AnthropicToolInputSchema,
}
/// JSON-Schema-shaped input description for an `AnthropicTool`.
///
/// Serializes as `{"type": ..., "properties": ..., "required": [...]}`;
/// `required` is omitted entirely when `None`.
#[derive(Debug, Serialize)]
struct AnthropicToolInputSchema {
// Serialized as `type` (renamed because `type` is a Rust keyword);
// always "object" as built by `convert_tools`.
#[serde(rename = "type")]
schema_type: String,
// The JSON Schema `properties` object, kept as raw JSON.
properties: serde_json::Value,
#[serde(skip_serializing_if = "Option::is_none")]
required: Option<Vec<String>>,
}
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicMessage {
role: String,
@@ -489,6 +642,12 @@ struct AnthropicMessage {
// A single content block in an Anthropic message. Tagged by the block's
// `type` field on the wire: plain text, or a tool invocation requested
// by the model (carrying the tool-use id, tool name, and raw JSON input).
enum AnthropicContent {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "tool_use")]
ToolUse {
id: String,
name: String,
// Tool arguments as raw JSON; may arrive empty in streaming
// `content_block_start` events and be filled in later from
// accumulated `partial_json` deltas.
input: serde_json::Value,
},
}
#[derive(Debug, Deserialize)]
@@ -520,6 +679,8 @@ struct AnthropicStreamEvent {
delta: Option<AnthropicDelta>,
#[serde(default)]
error: Option<AnthropicError>,
#[serde(default)]
content_block: Option<AnthropicContent>,
}
#[derive(Debug, Deserialize)]
@@ -527,6 +688,7 @@ struct AnthropicDelta {
#[serde(rename = "type")]
delta_type: Option<String>,
text: Option<String>,
partial_json: Option<String>,
}
#[derive(Debug, Deserialize)]
@@ -589,7 +751,7 @@ mod tests {
];
let request_body = provider
.create_request_body(&messages, false, 1000, 0.5)
.create_request_body(&messages, None, false, 1000, 0.5)
.unwrap();
assert_eq!(request_body.model, "claude-3-haiku-20240307");
@@ -597,5 +759,70 @@ mod tests {
assert_eq!(request_body.temperature, 0.5);
assert!(!request_body.stream);
assert_eq!(request_body.messages.len(), 1);
assert!(request_body.tools.is_none());
}
#[test]
// Verifies that `convert_tools` maps a generic `Tool` (with a JSON
// Schema `input_schema`) into Anthropic's structured tool format,
// preserving name, description, the "object" schema type, and the
// `required` field list.
fn test_tool_conversion() {
let provider = AnthropicProvider::new(
"test-key".to_string(),
None,
None,
None,
).unwrap();
let tools = vec![
Tool {
name: "get_weather".to_string(),
description: "Get the current weather".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}),
},
];
let anthropic_tools = provider.convert_tools(&tools);
assert_eq!(anthropic_tools.len(), 1);
assert_eq!(anthropic_tools[0].name, "get_weather");
assert_eq!(anthropic_tools[0].description, "Get the current weather");
assert_eq!(anthropic_tools[0].input_schema.schema_type, "object");
assert!(anthropic_tools[0].input_schema.required.is_some());
assert_eq!(anthropic_tools[0].input_schema.required.as_ref().unwrap()[0], "location");
}
#[test]
// Verifies that `convert_anthropic_tool_calls` extracts only `tool_use`
// blocks from mixed response content: the text block is dropped, and the
// tool-use id, name, and JSON input are carried into the `ToolCall`.
fn test_tool_call_conversion() {
let provider = AnthropicProvider::new(
"test-key".to_string(),
None,
None,
None,
).unwrap();
let content = vec![
// Text blocks must be filtered out — only tool_use converts.
AnthropicContent::Text {
text: "I'll help you get the weather.".to_string(),
},
AnthropicContent::ToolUse {
id: "toolu_123".to_string(),
name: "get_weather".to_string(),
input: serde_json::json!({"location": "San Francisco, CA"}),
},
];
let tool_calls = provider.convert_anthropic_tool_calls(&content);
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].id, "toolu_123");
assert_eq!(tool_calls[0].tool, "get_weather");
assert_eq!(tool_calls[0].args["location"], "San Francisco, CA");
}
}

View File

@@ -29,6 +29,7 @@ pub struct CompletionRequest {
pub max_tokens: Option<u32>,
pub temperature: Option<f32>,
pub stream: bool,
pub tools: Option<Vec<Tool>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -75,6 +76,13 @@ pub struct ToolCall {
pub args: serde_json::Value,
}
/// A provider-agnostic tool definition for native tool calling.
///
/// Provider implementations translate this into their own wire format
/// (e.g. `AnthropicProvider::convert_tools`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
// Tool identifier the model invokes it by.
pub name: String,
// Human-readable description presented to the model.
pub description: String,
// JSON Schema describing the tool's input parameters, kept as raw
// JSON so each provider can reshape it as needed.
pub input_schema: serde_json::Value,
}
pub mod anthropic;
pub use anthropic::AnthropicProvider;