Fix for tool use

This commit is contained in:
Dhanji Prasanna
2025-09-20 20:17:50 +10:00
parent 444245d7dd
commit 9a5486f2a8
4 changed files with 385 additions and 43 deletions

View File

@@ -108,7 +108,7 @@ use tracing::{debug, error, info, warn};
use crate::{
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
MessageRole, Usage,
MessageRole, Tool, ToolCall, Usage,
};
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
@@ -163,6 +163,51 @@ impl AnthropicProvider {
builder
}
fn convert_tools(&self, tools: &[Tool]) -> Vec<AnthropicTool> {
tools
.iter()
.map(|tool| {
let mut schema = AnthropicToolInputSchema {
schema_type: "object".to_string(),
properties: serde_json::Value::Object(serde_json::Map::new()),
required: None,
};
// Extract properties and required fields from the input schema
if let Ok(schema_obj) = serde_json::from_value::<serde_json::Map<String, serde_json::Value>>(tool.input_schema.clone()) {
if let Some(properties) = schema_obj.get("properties") {
schema.properties = properties.clone();
}
if let Some(required) = schema_obj.get("required") {
if let Ok(required_vec) = serde_json::from_value::<Vec<String>>(required.clone()) {
schema.required = Some(required_vec);
}
}
}
AnthropicTool {
name: tool.name.clone(),
description: tool.description.clone(),
input_schema: schema,
}
})
.collect()
}
fn convert_anthropic_tool_calls(&self, content: &[AnthropicContent]) -> Vec<ToolCall> {
content
.iter()
.filter_map(|c| match c {
AnthropicContent::ToolUse { id, name, input } => Some(ToolCall {
id: id.clone(),
tool: name.clone(),
args: input.clone(),
}),
_ => None,
})
.collect()
}
fn convert_messages(&self, messages: &[Message]) -> Result<(Option<String>, Vec<AnthropicMessage>)> {
let mut system_message = None;
let mut anthropic_messages = Vec::new();
@@ -200,6 +245,7 @@ impl AnthropicProvider {
fn create_request_body(
&self,
messages: &[Message],
tools: Option<&[Tool]>,
streaming: bool,
max_tokens: u32,
temperature: f32,
@@ -210,12 +256,16 @@ impl AnthropicProvider {
return Err(anyhow!("At least one user or assistant message is required"));
}
// Convert tools if provided
let anthropic_tools = tools.map(|t| self.convert_tools(t));
let request = AnthropicRequest {
model: self.model.clone(),
max_tokens,
temperature,
messages: anthropic_messages,
system,
tools: anthropic_tools,
stream: streaming,
};
@@ -233,6 +283,8 @@ impl AnthropicProvider {
tx: mpsc::Sender<Result<CompletionChunk>>,
) {
let mut buffer = String::new();
let mut current_tool_calls: Vec<ToolCall> = Vec::new();
let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
while let Some(chunk_result) = stream.next().await {
match chunk_result {
@@ -266,7 +318,7 @@ impl AnthropicProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
tool_calls: None,
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
};
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
@@ -274,9 +326,38 @@ impl AnthropicProvider {
return;
}
debug!("Raw Claude API JSON: {}", data);
match serde_json::from_str::<AnthropicStreamEvent>(data) {
Ok(event) => {
debug!("Parsed event: {:?}", event);
match event.event_type.as_str() {
"content_block_start" => {
debug!("Received content_block_start event: {:?}", event);
if let Some(content_block) = event.content_block {
match content_block {
AnthropicContent::ToolUse { id, name, input } => {
debug!("Found tool use in content_block_start: id={}, name={}, input={:?}", id, name, input);
debug!("Input JSON string: {}", serde_json::to_string(&input).unwrap_or_else(|_| "failed to serialize".to_string()));
// Create initial tool call - we'll update the args later from streaming JSON
let tool_call = ToolCall {
id: id.clone(),
tool: name.clone(),
args: input, // This might be empty initially
};
debug!("Created initial tool call: {:?}", tool_call);
current_tool_calls.push(tool_call);
// Reset partial JSON accumulator for this tool
partial_tool_json.clear();
}
_ => {
debug!("Non-tool content block: {:?}", content_block);
}
}
}
}
"content_block_delta" => {
if let Some(delta) = event.delta {
if let Some(text) = delta.text {
@@ -290,6 +371,44 @@ impl AnthropicProvider {
return;
}
}
// Handle partial JSON for tool calls
if let Some(partial_json) = delta.partial_json {
debug!("Received partial JSON: {}", partial_json);
partial_tool_json.push_str(&partial_json);
debug!("Accumulated tool JSON: {}", partial_tool_json);
}
}
}
"content_block_stop" => {
// Tool call block is complete - now parse the accumulated JSON
if !current_tool_calls.is_empty() && !partial_tool_json.is_empty() {
debug!("Parsing complete tool JSON: {}", partial_tool_json);
// Parse the accumulated JSON and update the last tool call
if let Ok(parsed_args) = serde_json::from_str::<serde_json::Value>(&partial_tool_json) {
if let Some(last_tool) = current_tool_calls.last_mut() {
last_tool.args = parsed_args;
debug!("Updated tool call with complete args: {:?}", last_tool);
}
} else {
debug!("Failed to parse accumulated JSON: {}", partial_tool_json);
}
// Clear the accumulator
partial_tool_json.clear();
}
// Send the complete tool call
if !current_tool_calls.is_empty() {
let chunk = CompletionChunk {
content: String::new(),
finished: false,
tool_calls: Some(current_tool_calls.clone()),
};
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return;
}
}
}
"message_stop" => {
@@ -297,7 +416,7 @@ impl AnthropicProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
tool_calls: None,
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
};
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
@@ -338,7 +457,7 @@ impl AnthropicProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
tool_calls: None,
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls) },
};
let _ = tx.send(Ok(final_chunk)).await;
}
@@ -355,7 +474,13 @@ impl LLMProvider for AnthropicProvider {
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);
let request_body = self.create_request_body(&request.messages, false, max_tokens, temperature)?;
let request_body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
false,
max_tokens,
temperature
)?;
debug!("Sending request to Anthropic API: model={}, max_tokens={}, temperature={}",
request_body.model, request_body.max_tokens, request_body.temperature);
@@ -387,6 +512,7 @@ impl LLMProvider for AnthropicProvider {
.iter()
.filter_map(|c| match c {
AnthropicContent::Text { text } => Some(text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join("");
@@ -418,10 +544,19 @@ impl LLMProvider for AnthropicProvider {
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);
let request_body = self.create_request_body(&request.messages, true, max_tokens, temperature)?;
let request_body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
true,
max_tokens,
temperature
)?;
debug!("Sending streaming request to Anthropic API: model={}, max_tokens={}, temperature={}",
request_body.model, request_body.max_tokens, request_body.temperature);
// Debug: Log the full request body
debug!("Full request body: {}", serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "Failed to serialize".to_string()));
let response = self
.create_request_builder(true)
@@ -475,9 +610,27 @@ struct AnthropicRequest {
messages: Vec<AnthropicMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
system: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<AnthropicTool>>,
stream: bool,
}
#[derive(Debug, Serialize)]
struct AnthropicTool {
name: String,
description: String,
input_schema: AnthropicToolInputSchema,
}
#[derive(Debug, Serialize)]
struct AnthropicToolInputSchema {
#[serde(rename = "type")]
schema_type: String,
properties: serde_json::Value,
#[serde(skip_serializing_if = "Option::is_none")]
required: Option<Vec<String>>,
}
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicMessage {
role: String,
@@ -489,6 +642,12 @@ struct AnthropicMessage {
enum AnthropicContent {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "tool_use")]
ToolUse {
id: String,
name: String,
input: serde_json::Value,
},
}
#[derive(Debug, Deserialize)]
@@ -520,6 +679,8 @@ struct AnthropicStreamEvent {
delta: Option<AnthropicDelta>,
#[serde(default)]
error: Option<AnthropicError>,
#[serde(default)]
content_block: Option<AnthropicContent>,
}
#[derive(Debug, Deserialize)]
@@ -527,6 +688,7 @@ struct AnthropicDelta {
#[serde(rename = "type")]
delta_type: Option<String>,
text: Option<String>,
partial_json: Option<String>,
}
#[derive(Debug, Deserialize)]
@@ -589,7 +751,7 @@ mod tests {
];
let request_body = provider
.create_request_body(&messages, false, 1000, 0.5)
.create_request_body(&messages, None, false, 1000, 0.5)
.unwrap();
assert_eq!(request_body.model, "claude-3-haiku-20240307");
@@ -597,5 +759,70 @@ mod tests {
assert_eq!(request_body.temperature, 0.5);
assert!(!request_body.stream);
assert_eq!(request_body.messages.len(), 1);
assert!(request_body.tools.is_none());
}
#[test]
fn test_tool_conversion() {
let provider = AnthropicProvider::new(
"test-key".to_string(),
None,
None,
None,
).unwrap();
let tools = vec![
Tool {
name: "get_weather".to_string(),
description: "Get the current weather".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}),
},
];
let anthropic_tools = provider.convert_tools(&tools);
assert_eq!(anthropic_tools.len(), 1);
assert_eq!(anthropic_tools[0].name, "get_weather");
assert_eq!(anthropic_tools[0].description, "Get the current weather");
assert_eq!(anthropic_tools[0].input_schema.schema_type, "object");
assert!(anthropic_tools[0].input_schema.required.is_some());
assert_eq!(anthropic_tools[0].input_schema.required.as_ref().unwrap()[0], "location");
}
#[test]
fn test_tool_call_conversion() {
let provider = AnthropicProvider::new(
"test-key".to_string(),
None,
None,
None,
).unwrap();
let content = vec![
AnthropicContent::Text {
text: "I'll help you get the weather.".to_string(),
},
AnthropicContent::ToolUse {
id: "toolu_123".to_string(),
name: "get_weather".to_string(),
input: serde_json::json!({"location": "San Francisco, CA"}),
},
];
let tool_calls = provider.convert_anthropic_tool_calls(&content);
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].id, "toolu_123");
assert_eq!(tool_calls[0].tool, "get_weather");
assert_eq!(tool_calls[0].args["location"], "San Francisco, CA");
}
}

View File

@@ -29,6 +29,7 @@ pub struct CompletionRequest {
pub max_tokens: Option<u32>,
pub temperature: Option<f32>,
pub stream: bool,
pub tools: Option<Vec<Tool>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -75,6 +76,13 @@ pub struct ToolCall {
pub args: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
pub name: String,
pub description: String,
pub input_schema: serde_json::Value,
}
pub mod anthropic;
pub use anthropic::AnthropicProvider;