Fix for tool use
This commit is contained in:
@@ -108,7 +108,7 @@ use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::{
|
||||
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
|
||||
MessageRole, Usage,
|
||||
MessageRole, Tool, ToolCall, Usage,
|
||||
};
|
||||
|
||||
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
|
||||
@@ -163,6 +163,51 @@ impl AnthropicProvider {
|
||||
builder
|
||||
}
|
||||
|
||||
fn convert_tools(&self, tools: &[Tool]) -> Vec<AnthropicTool> {
|
||||
tools
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
let mut schema = AnthropicToolInputSchema {
|
||||
schema_type: "object".to_string(),
|
||||
properties: serde_json::Value::Object(serde_json::Map::new()),
|
||||
required: None,
|
||||
};
|
||||
|
||||
// Extract properties and required fields from the input schema
|
||||
if let Ok(schema_obj) = serde_json::from_value::<serde_json::Map<String, serde_json::Value>>(tool.input_schema.clone()) {
|
||||
if let Some(properties) = schema_obj.get("properties") {
|
||||
schema.properties = properties.clone();
|
||||
}
|
||||
if let Some(required) = schema_obj.get("required") {
|
||||
if let Ok(required_vec) = serde_json::from_value::<Vec<String>>(required.clone()) {
|
||||
schema.required = Some(required_vec);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AnthropicTool {
|
||||
name: tool.name.clone(),
|
||||
description: tool.description.clone(),
|
||||
input_schema: schema,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn convert_anthropic_tool_calls(&self, content: &[AnthropicContent]) -> Vec<ToolCall> {
|
||||
content
|
||||
.iter()
|
||||
.filter_map(|c| match c {
|
||||
AnthropicContent::ToolUse { id, name, input } => Some(ToolCall {
|
||||
id: id.clone(),
|
||||
tool: name.clone(),
|
||||
args: input.clone(),
|
||||
}),
|
||||
_ => None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn convert_messages(&self, messages: &[Message]) -> Result<(Option<String>, Vec<AnthropicMessage>)> {
|
||||
let mut system_message = None;
|
||||
let mut anthropic_messages = Vec::new();
|
||||
@@ -200,6 +245,7 @@ impl AnthropicProvider {
|
||||
fn create_request_body(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
tools: Option<&[Tool]>,
|
||||
streaming: bool,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
@@ -210,12 +256,16 @@ impl AnthropicProvider {
|
||||
return Err(anyhow!("At least one user or assistant message is required"));
|
||||
}
|
||||
|
||||
// Convert tools if provided
|
||||
let anthropic_tools = tools.map(|t| self.convert_tools(t));
|
||||
|
||||
let request = AnthropicRequest {
|
||||
model: self.model.clone(),
|
||||
max_tokens,
|
||||
temperature,
|
||||
messages: anthropic_messages,
|
||||
system,
|
||||
tools: anthropic_tools,
|
||||
stream: streaming,
|
||||
};
|
||||
|
||||
@@ -233,6 +283,8 @@ impl AnthropicProvider {
|
||||
tx: mpsc::Sender<Result<CompletionChunk>>,
|
||||
) {
|
||||
let mut buffer = String::new();
|
||||
let mut current_tool_calls: Vec<ToolCall> = Vec::new();
|
||||
let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
match chunk_result {
|
||||
@@ -266,7 +318,7 @@ impl AnthropicProvider {
|
||||
let final_chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
tool_calls: None,
|
||||
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
|
||||
};
|
||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
@@ -274,9 +326,38 @@ impl AnthropicProvider {
|
||||
return;
|
||||
}
|
||||
|
||||
debug!("Raw Claude API JSON: {}", data);
|
||||
|
||||
match serde_json::from_str::<AnthropicStreamEvent>(data) {
|
||||
Ok(event) => {
|
||||
debug!("Parsed event: {:?}", event);
|
||||
match event.event_type.as_str() {
|
||||
"content_block_start" => {
|
||||
debug!("Received content_block_start event: {:?}", event);
|
||||
if let Some(content_block) = event.content_block {
|
||||
match content_block {
|
||||
AnthropicContent::ToolUse { id, name, input } => {
|
||||
debug!("Found tool use in content_block_start: id={}, name={}, input={:?}", id, name, input);
|
||||
debug!("Input JSON string: {}", serde_json::to_string(&input).unwrap_or_else(|_| "failed to serialize".to_string()));
|
||||
|
||||
// Create initial tool call - we'll update the args later from streaming JSON
|
||||
let tool_call = ToolCall {
|
||||
id: id.clone(),
|
||||
tool: name.clone(),
|
||||
args: input, // This might be empty initially
|
||||
};
|
||||
debug!("Created initial tool call: {:?}", tool_call);
|
||||
current_tool_calls.push(tool_call);
|
||||
|
||||
// Reset partial JSON accumulator for this tool
|
||||
partial_tool_json.clear();
|
||||
}
|
||||
_ => {
|
||||
debug!("Non-tool content block: {:?}", content_block);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"content_block_delta" => {
|
||||
if let Some(delta) = event.delta {
|
||||
if let Some(text) = delta.text {
|
||||
@@ -290,6 +371,44 @@ impl AnthropicProvider {
|
||||
return;
|
||||
}
|
||||
}
|
||||
// Handle partial JSON for tool calls
|
||||
if let Some(partial_json) = delta.partial_json {
|
||||
debug!("Received partial JSON: {}", partial_json);
|
||||
partial_tool_json.push_str(&partial_json);
|
||||
debug!("Accumulated tool JSON: {}", partial_tool_json);
|
||||
}
|
||||
}
|
||||
}
|
||||
"content_block_stop" => {
|
||||
// Tool call block is complete - now parse the accumulated JSON
|
||||
if !current_tool_calls.is_empty() && !partial_tool_json.is_empty() {
|
||||
debug!("Parsing complete tool JSON: {}", partial_tool_json);
|
||||
|
||||
// Parse the accumulated JSON and update the last tool call
|
||||
if let Ok(parsed_args) = serde_json::from_str::<serde_json::Value>(&partial_tool_json) {
|
||||
if let Some(last_tool) = current_tool_calls.last_mut() {
|
||||
last_tool.args = parsed_args;
|
||||
debug!("Updated tool call with complete args: {:?}", last_tool);
|
||||
}
|
||||
} else {
|
||||
debug!("Failed to parse accumulated JSON: {}", partial_tool_json);
|
||||
}
|
||||
|
||||
// Clear the accumulator
|
||||
partial_tool_json.clear();
|
||||
}
|
||||
|
||||
// Send the complete tool call
|
||||
if !current_tool_calls.is_empty() {
|
||||
let chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: false,
|
||||
tool_calls: Some(current_tool_calls.clone()),
|
||||
};
|
||||
if tx.send(Ok(chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
"message_stop" => {
|
||||
@@ -297,7 +416,7 @@ impl AnthropicProvider {
|
||||
let final_chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
tool_calls: None,
|
||||
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
|
||||
};
|
||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
@@ -338,7 +457,7 @@ impl AnthropicProvider {
|
||||
let final_chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
tool_calls: None,
|
||||
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls) },
|
||||
};
|
||||
let _ = tx.send(Ok(final_chunk)).await;
|
||||
}
|
||||
@@ -355,7 +474,13 @@ impl LLMProvider for AnthropicProvider {
|
||||
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
|
||||
let temperature = request.temperature.unwrap_or(self.temperature);
|
||||
|
||||
let request_body = self.create_request_body(&request.messages, false, max_tokens, temperature)?;
|
||||
let request_body = self.create_request_body(
|
||||
&request.messages,
|
||||
request.tools.as_deref(),
|
||||
false,
|
||||
max_tokens,
|
||||
temperature
|
||||
)?;
|
||||
|
||||
debug!("Sending request to Anthropic API: model={}, max_tokens={}, temperature={}",
|
||||
request_body.model, request_body.max_tokens, request_body.temperature);
|
||||
@@ -387,6 +512,7 @@ impl LLMProvider for AnthropicProvider {
|
||||
.iter()
|
||||
.filter_map(|c| match c {
|
||||
AnthropicContent::Text { text } => Some(text.as_str()),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("");
|
||||
@@ -418,10 +544,19 @@ impl LLMProvider for AnthropicProvider {
|
||||
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
|
||||
let temperature = request.temperature.unwrap_or(self.temperature);
|
||||
|
||||
let request_body = self.create_request_body(&request.messages, true, max_tokens, temperature)?;
|
||||
let request_body = self.create_request_body(
|
||||
&request.messages,
|
||||
request.tools.as_deref(),
|
||||
true,
|
||||
max_tokens,
|
||||
temperature
|
||||
)?;
|
||||
|
||||
debug!("Sending streaming request to Anthropic API: model={}, max_tokens={}, temperature={}",
|
||||
request_body.model, request_body.max_tokens, request_body.temperature);
|
||||
|
||||
// Debug: Log the full request body
|
||||
debug!("Full request body: {}", serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "Failed to serialize".to_string()));
|
||||
|
||||
let response = self
|
||||
.create_request_builder(true)
|
||||
@@ -475,9 +610,27 @@ struct AnthropicRequest {
|
||||
messages: Vec<AnthropicMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
system: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<AnthropicTool>>,
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct AnthropicTool {
|
||||
name: String,
|
||||
description: String,
|
||||
input_schema: AnthropicToolInputSchema,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct AnthropicToolInputSchema {
|
||||
#[serde(rename = "type")]
|
||||
schema_type: String,
|
||||
properties: serde_json::Value,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
required: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct AnthropicMessage {
|
||||
role: String,
|
||||
@@ -489,6 +642,12 @@ struct AnthropicMessage {
|
||||
enum AnthropicContent {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
#[serde(rename = "tool_use")]
|
||||
ToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input: serde_json::Value,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -520,6 +679,8 @@ struct AnthropicStreamEvent {
|
||||
delta: Option<AnthropicDelta>,
|
||||
#[serde(default)]
|
||||
error: Option<AnthropicError>,
|
||||
#[serde(default)]
|
||||
content_block: Option<AnthropicContent>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -527,6 +688,7 @@ struct AnthropicDelta {
|
||||
#[serde(rename = "type")]
|
||||
delta_type: Option<String>,
|
||||
text: Option<String>,
|
||||
partial_json: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -589,7 +751,7 @@ mod tests {
|
||||
];
|
||||
|
||||
let request_body = provider
|
||||
.create_request_body(&messages, false, 1000, 0.5)
|
||||
.create_request_body(&messages, None, false, 1000, 0.5)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(request_body.model, "claude-3-haiku-20240307");
|
||||
@@ -597,5 +759,70 @@ mod tests {
|
||||
assert_eq!(request_body.temperature, 0.5);
|
||||
assert!(!request_body.stream);
|
||||
assert_eq!(request_body.messages.len(), 1);
|
||||
assert!(request_body.tools.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_conversion() {
|
||||
let provider = AnthropicProvider::new(
|
||||
"test-key".to_string(),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
).unwrap();
|
||||
|
||||
let tools = vec![
|
||||
Tool {
|
||||
name: "get_weather".to_string(),
|
||||
description: "Get the current weather".to_string(),
|
||||
input_schema: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}),
|
||||
},
|
||||
];
|
||||
|
||||
let anthropic_tools = provider.convert_tools(&tools);
|
||||
|
||||
assert_eq!(anthropic_tools.len(), 1);
|
||||
assert_eq!(anthropic_tools[0].name, "get_weather");
|
||||
assert_eq!(anthropic_tools[0].description, "Get the current weather");
|
||||
assert_eq!(anthropic_tools[0].input_schema.schema_type, "object");
|
||||
assert!(anthropic_tools[0].input_schema.required.is_some());
|
||||
assert_eq!(anthropic_tools[0].input_schema.required.as_ref().unwrap()[0], "location");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_conversion() {
|
||||
let provider = AnthropicProvider::new(
|
||||
"test-key".to_string(),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
).unwrap();
|
||||
|
||||
let content = vec![
|
||||
AnthropicContent::Text {
|
||||
text: "I'll help you get the weather.".to_string(),
|
||||
},
|
||||
AnthropicContent::ToolUse {
|
||||
id: "toolu_123".to_string(),
|
||||
name: "get_weather".to_string(),
|
||||
input: serde_json::json!({"location": "San Francisco, CA"}),
|
||||
},
|
||||
];
|
||||
|
||||
let tool_calls = provider.convert_anthropic_tool_calls(&content);
|
||||
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].id, "toolu_123");
|
||||
assert_eq!(tool_calls[0].tool, "get_weather");
|
||||
assert_eq!(tool_calls[0].args["location"], "San Francisco, CA");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ pub struct CompletionRequest {
|
||||
pub max_tokens: Option<u32>,
|
||||
pub temperature: Option<f32>,
|
||||
pub stream: bool,
|
||||
pub tools: Option<Vec<Tool>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -75,6 +76,13 @@ pub struct ToolCall {
|
||||
pub args: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Tool {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub input_schema: serde_json::Value,
|
||||
}
|
||||
|
||||
pub mod anthropic;
|
||||
|
||||
pub use anthropic::AnthropicProvider;
|
||||
|
||||
Reference in New Issue
Block a user