token counting fixes

This commit is contained in:
Dhanji Prasanna
2025-10-09 12:11:14 +11:00
parent 9d1eef82b9
commit 260c949576
7 changed files with 283 additions and 22 deletions

View File

@@ -269,11 +269,12 @@ impl AnthropicProvider {
&self,
mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
tx: mpsc::Sender<Result<CompletionChunk>>,
) {
) -> Option<Usage> {
let mut buffer = String::new();
let mut current_tool_calls: Vec<ToolCall> = Vec::new();
let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
let mut accumulated_usage: Option<Usage> = None;
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
@@ -284,7 +285,7 @@ impl AnthropicProvider {
let _ = tx
.send(Err(anyhow!("Invalid UTF-8 in stream chunk: {}", e)))
.await;
return;
return accumulated_usage;
}
};
@@ -306,12 +307,13 @@ impl AnthropicProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
};
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
}
return;
return accumulated_usage;
}
debug!("Raw Claude API JSON: {}", data);
@@ -320,6 +322,19 @@ impl AnthropicProvider {
Ok(event) => {
debug!("Parsed event type: {}, event: {:?}", event.event_type, event);
match event.event_type.as_str() {
"message_start" => {
// Extract usage data from message_start event
if let Some(message) = event.message {
if let Some(usage) = message.usage {
accumulated_usage = Some(Usage {
prompt_tokens: usage.input_tokens,
completion_tokens: usage.output_tokens,
total_tokens: usage.input_tokens + usage.output_tokens,
});
debug!("Captured usage from message_start: {:?}", accumulated_usage);
}
}
}
"content_block_start" => {
debug!("Received content_block_start event: {:?}", event);
if let Some(content_block) = event.content_block {
@@ -342,11 +357,12 @@ impl AnthropicProvider {
let chunk = CompletionChunk {
content: String::new(),
finished: false,
usage: None,
tool_calls: Some(vec![tool_call]),
};
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return;
return accumulated_usage;
}
} else {
// Arguments are empty, we'll accumulate them from partial_json
@@ -368,11 +384,12 @@ impl AnthropicProvider {
let chunk = CompletionChunk {
content: text,
finished: false,
usage: None,
tool_calls: None,
};
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return;
return accumulated_usage;
}
}
// Handle partial JSON for tool calls
@@ -407,11 +424,12 @@ impl AnthropicProvider {
let chunk = CompletionChunk {
content: String::new(),
finished: false,
usage: None,
tool_calls: Some(current_tool_calls.clone()),
};
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return;
return accumulated_usage;
}
}
}
@@ -420,12 +438,13 @@ impl AnthropicProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
};
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
}
return;
return accumulated_usage;
}
"error" => {
if let Some(error) = event.error {
@@ -433,7 +452,7 @@ impl AnthropicProvider {
let _ = tx
.send(Err(anyhow!("Anthropic API error: {:?}", error)))
.await;
return;
return accumulated_usage;
}
}
_ => {
@@ -452,7 +471,7 @@ impl AnthropicProvider {
Err(e) => {
error!("Stream error: {}", e);
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
return;
return accumulated_usage;
}
}
}
@@ -461,9 +480,11 @@ impl AnthropicProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls) },
};
let _ = tx.send(Ok(final_chunk)).await;
accumulated_usage
}
}
@@ -584,7 +605,14 @@ impl LLMProvider for AnthropicProvider {
// Spawn task to process the stream
let provider = self.clone();
tokio::spawn(async move {
provider.parse_streaming_response(stream, tx).await;
let usage = provider.parse_streaming_response(stream, tx).await;
// Log the final usage if available
if let Some(usage) = usage {
debug!(
"Stream completed with usage - prompt: {}, completion: {}, total: {}",
usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
);
}
});
Ok(ReceiverStream::new(rx))
@@ -679,6 +707,14 @@ struct AnthropicStreamEvent {
error: Option<AnthropicError>,
#[serde(default)]
content_block: Option<AnthropicContent>,
#[serde(default)]
message: Option<AnthropicStreamMessage>,
}
#[derive(Debug, Deserialize)]
struct AnthropicStreamMessage {
#[serde(default)]
usage: Option<AnthropicUsage>,
}
#[derive(Debug, Deserialize)]