Fix for tool use
This commit is contained in:
18
Cargo.lock
generated
18
Cargo.lock
generated
@@ -1884,18 +1884,28 @@ checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde"
|
name = "serde"
|
||||||
version = "1.0.219"
|
version = "1.0.225"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6"
|
checksum = "fd6c24dee235d0da097043389623fb913daddf92c76e9f5a1db88607a0bcbd1d"
|
||||||
|
dependencies = [
|
||||||
|
"serde_core",
|
||||||
|
"serde_derive",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_core"
|
||||||
|
version = "1.0.225"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "659356f9a0cb1e529b24c01e43ad2bdf520ec4ceaf83047b83ddcc2251f96383"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"serde_derive",
|
"serde_derive",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_derive"
|
name = "serde_derive"
|
||||||
version = "1.0.219"
|
version = "1.0.225"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
|
checksum = "0ea936adf78b1f766949a4977b91d2f5595825bd6ec079aa9543ad2685fc4516"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use g3_config::Config;
|
use g3_config::Config;
|
||||||
use g3_execution::CodeExecutor;
|
use g3_execution::CodeExecutor;
|
||||||
use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry};
|
use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry, Tool};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::json;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
@@ -440,13 +441,15 @@ impl Agent {
|
|||||||
max_tokens: Some(512),
|
max_tokens: Some(512),
|
||||||
temperature: Some(0.1),
|
temperature: Some(0.1),
|
||||||
stream: false,
|
stream: false,
|
||||||
|
tools: None, // No tools needed for task splitting
|
||||||
};
|
};
|
||||||
|
|
||||||
// Use the non-streaming complete method
|
// Use the non-streaming complete method
|
||||||
let response = provider.complete(request).await?;
|
let response = provider.complete(request).await?;
|
||||||
|
|
||||||
// Split the response by newlines and filter out empty lines
|
// Split the response by newlines and filter out empty lines
|
||||||
let tasks: Vec<String> = response.content
|
let tasks: Vec<String> = response
|
||||||
|
.content
|
||||||
.lines()
|
.lines()
|
||||||
.filter(|line| !line.trim().is_empty())
|
.filter(|line| !line.trim().is_empty())
|
||||||
.map(|line| line.trim().to_string())
|
.map(|line| line.trim().to_string())
|
||||||
@@ -482,7 +485,10 @@ impl Agent {
|
|||||||
|
|
||||||
// If we have multiple sub-tasks, execute them sequentially
|
// If we have multiple sub-tasks, execute them sequentially
|
||||||
if sub_tasks.len() > 1 {
|
if sub_tasks.len() > 1 {
|
||||||
println!("📋 Breaking down request into {} sub-tasks:", sub_tasks.len());
|
println!(
|
||||||
|
"📋 Breaking down request into {} sub-tasks:",
|
||||||
|
sub_tasks.len()
|
||||||
|
);
|
||||||
for (i, task) in sub_tasks.iter().enumerate() {
|
for (i, task) in sub_tasks.iter().enumerate() {
|
||||||
println!(" {}. {}", i + 1, task);
|
println!(" {}. {}", i + 1, task);
|
||||||
}
|
}
|
||||||
@@ -496,13 +502,15 @@ impl Agent {
|
|||||||
println!();
|
println!();
|
||||||
|
|
||||||
// Execute each sub-task
|
// Execute each sub-task
|
||||||
let result = self.execute_single_task(
|
let result = self
|
||||||
|
.execute_single_task(
|
||||||
sub_task,
|
sub_task,
|
||||||
show_prompt,
|
show_prompt,
|
||||||
show_code,
|
show_code,
|
||||||
show_timing,
|
show_timing,
|
||||||
cancellation_token.clone()
|
cancellation_token.clone(),
|
||||||
).await?;
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
all_responses.push(result);
|
all_responses.push(result);
|
||||||
|
|
||||||
@@ -522,8 +530,9 @@ impl Agent {
|
|||||||
show_prompt,
|
show_prompt,
|
||||||
show_code,
|
show_code,
|
||||||
show_timing,
|
show_timing,
|
||||||
cancellation_token
|
cancellation_token,
|
||||||
).await
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -542,7 +551,27 @@ impl Agent {
|
|||||||
|
|
||||||
// Only add system message if this is the first interaction (empty conversation history)
|
// Only add system message if this is the first interaction (empty conversation history)
|
||||||
if self.context_window.conversation_history.is_empty() {
|
if self.context_window.conversation_history.is_empty() {
|
||||||
let system_prompt = "You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems by writing code.
|
let provider = self.providers.get(None)?;
|
||||||
|
let system_prompt = if provider.has_native_tool_calling() {
|
||||||
|
// For native tool calling providers, use a more explicit system prompt
|
||||||
|
"You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems by writing code.
|
||||||
|
|
||||||
|
You have access to tools. When you need to accomplish a task, you MUST use the appropriate tool immediately. Do not just describe what you would do - actually use the tools.
|
||||||
|
|
||||||
|
IMPORTANT: You must call tools to complete tasks. When you receive a request:
|
||||||
|
1. Identify what needs to be done
|
||||||
|
2. Immediately call the appropriate tool with the required parameters
|
||||||
|
3. Wait for the tool result
|
||||||
|
4. Continue or complete the task based on the result
|
||||||
|
|
||||||
|
For shell commands: Use the shell tool with the exact command needed. Avoid commands that produce a large amount of output, and consider piping those outputs to files. Example: If asked to list files, immediately call the shell tool with command parameter \"ls\".
|
||||||
|
For task completion: Use the final_output tool with a summary.
|
||||||
|
|
||||||
|
|
||||||
|
Do not explain what you're going to do - just do it by calling the tools.".to_string()
|
||||||
|
} else {
|
||||||
|
// For non-native providers (embedded models), use JSON format instructions
|
||||||
|
"You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems by writing code.
|
||||||
|
|
||||||
# Tool Call Format
|
# Tool Call Format
|
||||||
|
|
||||||
@@ -573,7 +602,8 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
- Use Markdown formatting for all responses except tool calls.
|
- Use Markdown formatting for all responses except tool calls.
|
||||||
- Whenever taking actions, use the pronoun 'I'
|
- Whenever taking actions, use the pronoun 'I'
|
||||||
|
|
||||||
".to_string();
|
".to_string()
|
||||||
|
};
|
||||||
|
|
||||||
if show_prompt {
|
if show_prompt {
|
||||||
println!("🔍 System Prompt:");
|
println!("🔍 System Prompt:");
|
||||||
@@ -601,11 +631,20 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
// Use the complete conversation history for the request
|
// Use the complete conversation history for the request
|
||||||
let messages = self.context_window.conversation_history.clone();
|
let messages = self.context_window.conversation_history.clone();
|
||||||
|
|
||||||
|
// Check if provider supports native tool calling and add tools if so
|
||||||
|
let provider = self.providers.get(None)?;
|
||||||
|
let tools = if provider.has_native_tool_calling() {
|
||||||
|
Some(Self::create_tool_definitions())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
let request = CompletionRequest {
|
let request = CompletionRequest {
|
||||||
messages,
|
messages,
|
||||||
max_tokens: Some(2048),
|
max_tokens: Some(2048),
|
||||||
temperature: Some(0.1),
|
temperature: Some(0.1),
|
||||||
stream: true, // Enable streaming
|
stream: true, // Enable streaming
|
||||||
|
tools,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Time the LLM call with cancellation support and streaming
|
// Time the LLM call with cancellation support and streaming
|
||||||
@@ -738,6 +777,40 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
self.stream_completion_with_tools(request).await
|
self.stream_completion_with_tools(request).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create tool definitions for native tool calling providers
|
||||||
|
fn create_tool_definitions() -> Vec<Tool> {
|
||||||
|
vec![
|
||||||
|
Tool {
|
||||||
|
name: "shell".to_string(),
|
||||||
|
description: "Execute shell commands".to_string(),
|
||||||
|
input_schema: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The shell command to execute"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["command"]
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
Tool {
|
||||||
|
name: "final_output".to_string(),
|
||||||
|
description: "Signal task completion with a detailed summary".to_string(),
|
||||||
|
input_schema: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"summary": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "A detailed summary of what was accomplished"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["summary"]
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
async fn stream_completion_with_tools(
|
async fn stream_completion_with_tools(
|
||||||
&mut self,
|
&mut self,
|
||||||
mut request: CompletionRequest,
|
mut request: CompletionRequest,
|
||||||
@@ -801,8 +874,7 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
first_token_time = Some(stream_start.elapsed());
|
first_token_time = Some(stream_start.elapsed());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for tool calls - either from JSON parsing (embedded models)
|
// Check for tool calls - prioritize native tool calls over JSON parsing
|
||||||
// or from native tool calls (Anthropic, OpenAI, etc.)
|
|
||||||
let mut detected_tool_call = None;
|
let mut detected_tool_call = None;
|
||||||
|
|
||||||
// First check for native tool calls in the chunk
|
// First check for native tool calls in the chunk
|
||||||
@@ -823,9 +895,9 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
debug!("No native tool calls in chunk, chunk.tool_calls is None");
|
debug!("No native tool calls in chunk, chunk.tool_calls is None");
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no native tool calls, check for JSON tool calls in text (embedded models)
|
// Only fall back to JSON parsing if no native tool calls and provider doesn't support native calling
|
||||||
// IMPORTANT: Pass raw content to parser BEFORE cleaning stop sequences
|
if detected_tool_call.is_none() && !provider.has_native_tool_calling() {
|
||||||
if detected_tool_call.is_none() {
|
// For embedded models and other non-native providers, parse JSON from text
|
||||||
detected_tool_call = parser.add_chunk(&chunk.content);
|
detected_tool_call = parser.add_chunk(&chunk.content);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -958,6 +1030,11 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
// Update the request with the new context for next iteration
|
// Update the request with the new context for next iteration
|
||||||
request.messages = self.context_window.conversation_history.clone();
|
request.messages = self.context_window.conversation_history.clone();
|
||||||
|
|
||||||
|
// Ensure tools are included for native providers in subsequent iterations
|
||||||
|
if provider.has_native_tool_calling() {
|
||||||
|
request.tools = Some(Self::create_tool_definitions());
|
||||||
|
}
|
||||||
|
|
||||||
full_response.push_str(final_display_content);
|
full_response.push_str(final_display_content);
|
||||||
full_response.push_str(&format!(
|
full_response.push_str(&format!(
|
||||||
"\n\nTool executed: {} -> {}\n\n",
|
"\n\nTool executed: {} -> {}\n\n",
|
||||||
@@ -1023,10 +1100,21 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn execute_tool(&self, tool_call: &ToolCall) -> Result<String> {
|
async fn execute_tool(&self, tool_call: &ToolCall) -> Result<String> {
|
||||||
|
debug!("Executing tool: {}", tool_call.tool);
|
||||||
|
debug!("Tool call args: {:?}", tool_call.args);
|
||||||
|
debug!(
|
||||||
|
"Tool call args JSON: {}",
|
||||||
|
serde_json::to_string(&tool_call.args)
|
||||||
|
.unwrap_or_else(|_| "failed to serialize".to_string())
|
||||||
|
);
|
||||||
|
|
||||||
match tool_call.tool.as_str() {
|
match tool_call.tool.as_str() {
|
||||||
"shell" => {
|
"shell" => {
|
||||||
|
debug!("Processing shell tool call");
|
||||||
if let Some(command) = tool_call.args.get("command") {
|
if let Some(command) = tool_call.args.get("command") {
|
||||||
|
debug!("Found command parameter: {:?}", command);
|
||||||
if let Some(command_str) = command.as_str() {
|
if let Some(command_str) = command.as_str() {
|
||||||
|
debug!("Command string: {}", command_str);
|
||||||
// Use shell escaping to handle filenames with spaces and special characters
|
// Use shell escaping to handle filenames with spaces and special characters
|
||||||
let escaped_command = shell_escape_command(command_str);
|
let escaped_command = shell_escape_command(command_str);
|
||||||
|
|
||||||
@@ -1046,9 +1134,18 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
Err(e) => Ok(format!("❌ Execution error: {}", e)),
|
Err(e) => Ok(format!("❌ Execution error: {}", e)),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
debug!("Command parameter is not a string: {:?}", command);
|
||||||
Ok("❌ Invalid command argument".to_string())
|
Ok("❌ Invalid command argument".to_string())
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
debug!("No command parameter found in args: {:?}", tool_call.args);
|
||||||
|
debug!(
|
||||||
|
"Available keys: {:?}",
|
||||||
|
tool_call
|
||||||
|
.args
|
||||||
|
.as_object()
|
||||||
|
.map(|obj| obj.keys().collect::<Vec<_>>())
|
||||||
|
);
|
||||||
Ok("❌ Missing command argument".to_string())
|
Ok("❌ Missing command argument".to_string())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ use tracing::{debug, error, info, warn};
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
|
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
|
||||||
MessageRole, Usage,
|
MessageRole, Tool, ToolCall, Usage,
|
||||||
};
|
};
|
||||||
|
|
||||||
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
|
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
|
||||||
@@ -163,6 +163,51 @@ impl AnthropicProvider {
|
|||||||
builder
|
builder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn convert_tools(&self, tools: &[Tool]) -> Vec<AnthropicTool> {
|
||||||
|
tools
|
||||||
|
.iter()
|
||||||
|
.map(|tool| {
|
||||||
|
let mut schema = AnthropicToolInputSchema {
|
||||||
|
schema_type: "object".to_string(),
|
||||||
|
properties: serde_json::Value::Object(serde_json::Map::new()),
|
||||||
|
required: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Extract properties and required fields from the input schema
|
||||||
|
if let Ok(schema_obj) = serde_json::from_value::<serde_json::Map<String, serde_json::Value>>(tool.input_schema.clone()) {
|
||||||
|
if let Some(properties) = schema_obj.get("properties") {
|
||||||
|
schema.properties = properties.clone();
|
||||||
|
}
|
||||||
|
if let Some(required) = schema_obj.get("required") {
|
||||||
|
if let Ok(required_vec) = serde_json::from_value::<Vec<String>>(required.clone()) {
|
||||||
|
schema.required = Some(required_vec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
AnthropicTool {
|
||||||
|
name: tool.name.clone(),
|
||||||
|
description: tool.description.clone(),
|
||||||
|
input_schema: schema,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_anthropic_tool_calls(&self, content: &[AnthropicContent]) -> Vec<ToolCall> {
|
||||||
|
content
|
||||||
|
.iter()
|
||||||
|
.filter_map(|c| match c {
|
||||||
|
AnthropicContent::ToolUse { id, name, input } => Some(ToolCall {
|
||||||
|
id: id.clone(),
|
||||||
|
tool: name.clone(),
|
||||||
|
args: input.clone(),
|
||||||
|
}),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
fn convert_messages(&self, messages: &[Message]) -> Result<(Option<String>, Vec<AnthropicMessage>)> {
|
fn convert_messages(&self, messages: &[Message]) -> Result<(Option<String>, Vec<AnthropicMessage>)> {
|
||||||
let mut system_message = None;
|
let mut system_message = None;
|
||||||
let mut anthropic_messages = Vec::new();
|
let mut anthropic_messages = Vec::new();
|
||||||
@@ -200,6 +245,7 @@ impl AnthropicProvider {
|
|||||||
fn create_request_body(
|
fn create_request_body(
|
||||||
&self,
|
&self,
|
||||||
messages: &[Message],
|
messages: &[Message],
|
||||||
|
tools: Option<&[Tool]>,
|
||||||
streaming: bool,
|
streaming: bool,
|
||||||
max_tokens: u32,
|
max_tokens: u32,
|
||||||
temperature: f32,
|
temperature: f32,
|
||||||
@@ -210,12 +256,16 @@ impl AnthropicProvider {
|
|||||||
return Err(anyhow!("At least one user or assistant message is required"));
|
return Err(anyhow!("At least one user or assistant message is required"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Convert tools if provided
|
||||||
|
let anthropic_tools = tools.map(|t| self.convert_tools(t));
|
||||||
|
|
||||||
let request = AnthropicRequest {
|
let request = AnthropicRequest {
|
||||||
model: self.model.clone(),
|
model: self.model.clone(),
|
||||||
max_tokens,
|
max_tokens,
|
||||||
temperature,
|
temperature,
|
||||||
messages: anthropic_messages,
|
messages: anthropic_messages,
|
||||||
system,
|
system,
|
||||||
|
tools: anthropic_tools,
|
||||||
stream: streaming,
|
stream: streaming,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -233,6 +283,8 @@ impl AnthropicProvider {
|
|||||||
tx: mpsc::Sender<Result<CompletionChunk>>,
|
tx: mpsc::Sender<Result<CompletionChunk>>,
|
||||||
) {
|
) {
|
||||||
let mut buffer = String::new();
|
let mut buffer = String::new();
|
||||||
|
let mut current_tool_calls: Vec<ToolCall> = Vec::new();
|
||||||
|
let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
|
||||||
|
|
||||||
while let Some(chunk_result) = stream.next().await {
|
while let Some(chunk_result) = stream.next().await {
|
||||||
match chunk_result {
|
match chunk_result {
|
||||||
@@ -266,7 +318,7 @@ impl AnthropicProvider {
|
|||||||
let final_chunk = CompletionChunk {
|
let final_chunk = CompletionChunk {
|
||||||
content: String::new(),
|
content: String::new(),
|
||||||
finished: true,
|
finished: true,
|
||||||
tool_calls: None,
|
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
|
||||||
};
|
};
|
||||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||||
debug!("Receiver dropped, stopping stream");
|
debug!("Receiver dropped, stopping stream");
|
||||||
@@ -274,9 +326,38 @@ impl AnthropicProvider {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
debug!("Raw Claude API JSON: {}", data);
|
||||||
|
|
||||||
match serde_json::from_str::<AnthropicStreamEvent>(data) {
|
match serde_json::from_str::<AnthropicStreamEvent>(data) {
|
||||||
Ok(event) => {
|
Ok(event) => {
|
||||||
|
debug!("Parsed event: {:?}", event);
|
||||||
match event.event_type.as_str() {
|
match event.event_type.as_str() {
|
||||||
|
"content_block_start" => {
|
||||||
|
debug!("Received content_block_start event: {:?}", event);
|
||||||
|
if let Some(content_block) = event.content_block {
|
||||||
|
match content_block {
|
||||||
|
AnthropicContent::ToolUse { id, name, input } => {
|
||||||
|
debug!("Found tool use in content_block_start: id={}, name={}, input={:?}", id, name, input);
|
||||||
|
debug!("Input JSON string: {}", serde_json::to_string(&input).unwrap_or_else(|_| "failed to serialize".to_string()));
|
||||||
|
|
||||||
|
// Create initial tool call - we'll update the args later from streaming JSON
|
||||||
|
let tool_call = ToolCall {
|
||||||
|
id: id.clone(),
|
||||||
|
tool: name.clone(),
|
||||||
|
args: input, // This might be empty initially
|
||||||
|
};
|
||||||
|
debug!("Created initial tool call: {:?}", tool_call);
|
||||||
|
current_tool_calls.push(tool_call);
|
||||||
|
|
||||||
|
// Reset partial JSON accumulator for this tool
|
||||||
|
partial_tool_json.clear();
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
debug!("Non-tool content block: {:?}", content_block);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
"content_block_delta" => {
|
"content_block_delta" => {
|
||||||
if let Some(delta) = event.delta {
|
if let Some(delta) = event.delta {
|
||||||
if let Some(text) = delta.text {
|
if let Some(text) = delta.text {
|
||||||
@@ -290,6 +371,44 @@ impl AnthropicProvider {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Handle partial JSON for tool calls
|
||||||
|
if let Some(partial_json) = delta.partial_json {
|
||||||
|
debug!("Received partial JSON: {}", partial_json);
|
||||||
|
partial_tool_json.push_str(&partial_json);
|
||||||
|
debug!("Accumulated tool JSON: {}", partial_tool_json);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"content_block_stop" => {
|
||||||
|
// Tool call block is complete - now parse the accumulated JSON
|
||||||
|
if !current_tool_calls.is_empty() && !partial_tool_json.is_empty() {
|
||||||
|
debug!("Parsing complete tool JSON: {}", partial_tool_json);
|
||||||
|
|
||||||
|
// Parse the accumulated JSON and update the last tool call
|
||||||
|
if let Ok(parsed_args) = serde_json::from_str::<serde_json::Value>(&partial_tool_json) {
|
||||||
|
if let Some(last_tool) = current_tool_calls.last_mut() {
|
||||||
|
last_tool.args = parsed_args;
|
||||||
|
debug!("Updated tool call with complete args: {:?}", last_tool);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
debug!("Failed to parse accumulated JSON: {}", partial_tool_json);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear the accumulator
|
||||||
|
partial_tool_json.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the complete tool call
|
||||||
|
if !current_tool_calls.is_empty() {
|
||||||
|
let chunk = CompletionChunk {
|
||||||
|
content: String::new(),
|
||||||
|
finished: false,
|
||||||
|
tool_calls: Some(current_tool_calls.clone()),
|
||||||
|
};
|
||||||
|
if tx.send(Ok(chunk)).await.is_err() {
|
||||||
|
debug!("Receiver dropped, stopping stream");
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"message_stop" => {
|
"message_stop" => {
|
||||||
@@ -297,7 +416,7 @@ impl AnthropicProvider {
|
|||||||
let final_chunk = CompletionChunk {
|
let final_chunk = CompletionChunk {
|
||||||
content: String::new(),
|
content: String::new(),
|
||||||
finished: true,
|
finished: true,
|
||||||
tool_calls: None,
|
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
|
||||||
};
|
};
|
||||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||||
debug!("Receiver dropped, stopping stream");
|
debug!("Receiver dropped, stopping stream");
|
||||||
@@ -338,7 +457,7 @@ impl AnthropicProvider {
|
|||||||
let final_chunk = CompletionChunk {
|
let final_chunk = CompletionChunk {
|
||||||
content: String::new(),
|
content: String::new(),
|
||||||
finished: true,
|
finished: true,
|
||||||
tool_calls: None,
|
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls) },
|
||||||
};
|
};
|
||||||
let _ = tx.send(Ok(final_chunk)).await;
|
let _ = tx.send(Ok(final_chunk)).await;
|
||||||
}
|
}
|
||||||
@@ -355,7 +474,13 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
|
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
|
||||||
let temperature = request.temperature.unwrap_or(self.temperature);
|
let temperature = request.temperature.unwrap_or(self.temperature);
|
||||||
|
|
||||||
let request_body = self.create_request_body(&request.messages, false, max_tokens, temperature)?;
|
let request_body = self.create_request_body(
|
||||||
|
&request.messages,
|
||||||
|
request.tools.as_deref(),
|
||||||
|
false,
|
||||||
|
max_tokens,
|
||||||
|
temperature
|
||||||
|
)?;
|
||||||
|
|
||||||
debug!("Sending request to Anthropic API: model={}, max_tokens={}, temperature={}",
|
debug!("Sending request to Anthropic API: model={}, max_tokens={}, temperature={}",
|
||||||
request_body.model, request_body.max_tokens, request_body.temperature);
|
request_body.model, request_body.max_tokens, request_body.temperature);
|
||||||
@@ -387,6 +512,7 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
.iter()
|
.iter()
|
||||||
.filter_map(|c| match c {
|
.filter_map(|c| match c {
|
||||||
AnthropicContent::Text { text } => Some(text.as_str()),
|
AnthropicContent::Text { text } => Some(text.as_str()),
|
||||||
|
_ => None,
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join("");
|
.join("");
|
||||||
@@ -418,11 +544,20 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
|
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
|
||||||
let temperature = request.temperature.unwrap_or(self.temperature);
|
let temperature = request.temperature.unwrap_or(self.temperature);
|
||||||
|
|
||||||
let request_body = self.create_request_body(&request.messages, true, max_tokens, temperature)?;
|
let request_body = self.create_request_body(
|
||||||
|
&request.messages,
|
||||||
|
request.tools.as_deref(),
|
||||||
|
true,
|
||||||
|
max_tokens,
|
||||||
|
temperature
|
||||||
|
)?;
|
||||||
|
|
||||||
debug!("Sending streaming request to Anthropic API: model={}, max_tokens={}, temperature={}",
|
debug!("Sending streaming request to Anthropic API: model={}, max_tokens={}, temperature={}",
|
||||||
request_body.model, request_body.max_tokens, request_body.temperature);
|
request_body.model, request_body.max_tokens, request_body.temperature);
|
||||||
|
|
||||||
|
// Debug: Log the full request body
|
||||||
|
debug!("Full request body: {}", serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "Failed to serialize".to_string()));
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
.create_request_builder(true)
|
.create_request_builder(true)
|
||||||
.json(&request_body)
|
.json(&request_body)
|
||||||
@@ -475,9 +610,27 @@ struct AnthropicRequest {
|
|||||||
messages: Vec<AnthropicMessage>,
|
messages: Vec<AnthropicMessage>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
system: Option<String>,
|
system: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tools: Option<Vec<AnthropicTool>>,
|
||||||
stream: bool,
|
stream: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct AnthropicTool {
|
||||||
|
name: String,
|
||||||
|
description: String,
|
||||||
|
input_schema: AnthropicToolInputSchema,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct AnthropicToolInputSchema {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
schema_type: String,
|
||||||
|
properties: serde_json::Value,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
required: Option<Vec<String>>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct AnthropicMessage {
|
struct AnthropicMessage {
|
||||||
role: String,
|
role: String,
|
||||||
@@ -489,6 +642,12 @@ struct AnthropicMessage {
|
|||||||
enum AnthropicContent {
|
enum AnthropicContent {
|
||||||
#[serde(rename = "text")]
|
#[serde(rename = "text")]
|
||||||
Text { text: String },
|
Text { text: String },
|
||||||
|
#[serde(rename = "tool_use")]
|
||||||
|
ToolUse {
|
||||||
|
id: String,
|
||||||
|
name: String,
|
||||||
|
input: serde_json::Value,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
@@ -520,6 +679,8 @@ struct AnthropicStreamEvent {
|
|||||||
delta: Option<AnthropicDelta>,
|
delta: Option<AnthropicDelta>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
error: Option<AnthropicError>,
|
error: Option<AnthropicError>,
|
||||||
|
#[serde(default)]
|
||||||
|
content_block: Option<AnthropicContent>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
@@ -527,6 +688,7 @@ struct AnthropicDelta {
|
|||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
delta_type: Option<String>,
|
delta_type: Option<String>,
|
||||||
text: Option<String>,
|
text: Option<String>,
|
||||||
|
partial_json: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
@@ -589,7 +751,7 @@ mod tests {
|
|||||||
];
|
];
|
||||||
|
|
||||||
let request_body = provider
|
let request_body = provider
|
||||||
.create_request_body(&messages, false, 1000, 0.5)
|
.create_request_body(&messages, None, false, 1000, 0.5)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(request_body.model, "claude-3-haiku-20240307");
|
assert_eq!(request_body.model, "claude-3-haiku-20240307");
|
||||||
@@ -597,5 +759,70 @@ mod tests {
|
|||||||
assert_eq!(request_body.temperature, 0.5);
|
assert_eq!(request_body.temperature, 0.5);
|
||||||
assert!(!request_body.stream);
|
assert!(!request_body.stream);
|
||||||
assert_eq!(request_body.messages.len(), 1);
|
assert_eq!(request_body.messages.len(), 1);
|
||||||
|
assert!(request_body.tools.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tool_conversion() {
|
||||||
|
let provider = AnthropicProvider::new(
|
||||||
|
"test-key".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
).unwrap();
|
||||||
|
|
||||||
|
let tools = vec![
|
||||||
|
Tool {
|
||||||
|
name: "get_weather".to_string(),
|
||||||
|
description: "Get the current weather".to_string(),
|
||||||
|
input_schema: serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location"]
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
let anthropic_tools = provider.convert_tools(&tools);
|
||||||
|
|
||||||
|
assert_eq!(anthropic_tools.len(), 1);
|
||||||
|
assert_eq!(anthropic_tools[0].name, "get_weather");
|
||||||
|
assert_eq!(anthropic_tools[0].description, "Get the current weather");
|
||||||
|
assert_eq!(anthropic_tools[0].input_schema.schema_type, "object");
|
||||||
|
assert!(anthropic_tools[0].input_schema.required.is_some());
|
||||||
|
assert_eq!(anthropic_tools[0].input_schema.required.as_ref().unwrap()[0], "location");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tool_call_conversion() {
|
||||||
|
let provider = AnthropicProvider::new(
|
||||||
|
"test-key".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
).unwrap();
|
||||||
|
|
||||||
|
let content = vec![
|
||||||
|
AnthropicContent::Text {
|
||||||
|
text: "I'll help you get the weather.".to_string(),
|
||||||
|
},
|
||||||
|
AnthropicContent::ToolUse {
|
||||||
|
id: "toolu_123".to_string(),
|
||||||
|
name: "get_weather".to_string(),
|
||||||
|
input: serde_json::json!({"location": "San Francisco, CA"}),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
let tool_calls = provider.convert_anthropic_tool_calls(&content);
|
||||||
|
|
||||||
|
assert_eq!(tool_calls.len(), 1);
|
||||||
|
assert_eq!(tool_calls[0].id, "toolu_123");
|
||||||
|
assert_eq!(tool_calls[0].tool, "get_weather");
|
||||||
|
assert_eq!(tool_calls[0].args["location"], "San Francisco, CA");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ pub struct CompletionRequest {
|
|||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
pub temperature: Option<f32>,
|
pub temperature: Option<f32>,
|
||||||
pub stream: bool,
|
pub stream: bool,
|
||||||
|
pub tools: Option<Vec<Tool>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -75,6 +76,13 @@ pub struct ToolCall {
|
|||||||
pub args: serde_json::Value,
|
pub args: serde_json::Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Tool {
|
||||||
|
pub name: String,
|
||||||
|
pub description: String,
|
||||||
|
pub input_schema: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
pub mod anthropic;
|
pub mod anthropic;
|
||||||
|
|
||||||
pub use anthropic::AnthropicProvider;
|
pub use anthropic::AnthropicProvider;
|
||||||
|
|||||||
Reference in New Issue
Block a user