token counting fixes
This commit is contained in:
@@ -269,11 +269,12 @@ impl AnthropicProvider {
|
||||
&self,
|
||||
mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
|
||||
tx: mpsc::Sender<Result<CompletionChunk>>,
|
||||
) {
|
||||
) -> Option<Usage> {
|
||||
let mut buffer = String::new();
|
||||
let mut current_tool_calls: Vec<ToolCall> = Vec::new();
|
||||
let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
|
||||
|
||||
let mut accumulated_usage: Option<Usage> = None;
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
@@ -284,7 +285,7 @@ impl AnthropicProvider {
|
||||
let _ = tx
|
||||
.send(Err(anyhow!("Invalid UTF-8 in stream chunk: {}", e)))
|
||||
.await;
|
||||
return;
|
||||
return accumulated_usage;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -306,12 +307,13 @@ impl AnthropicProvider {
|
||||
let final_chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
usage: accumulated_usage.clone(),
|
||||
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
|
||||
};
|
||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
}
|
||||
return;
|
||||
return accumulated_usage;
|
||||
}
|
||||
|
||||
debug!("Raw Claude API JSON: {}", data);
|
||||
@@ -320,6 +322,19 @@ impl AnthropicProvider {
|
||||
Ok(event) => {
|
||||
debug!("Parsed event type: {}, event: {:?}", event.event_type, event);
|
||||
match event.event_type.as_str() {
|
||||
"message_start" => {
|
||||
// Extract usage data from message_start event
|
||||
if let Some(message) = event.message {
|
||||
if let Some(usage) = message.usage {
|
||||
accumulated_usage = Some(Usage {
|
||||
prompt_tokens: usage.input_tokens,
|
||||
completion_tokens: usage.output_tokens,
|
||||
total_tokens: usage.input_tokens + usage.output_tokens,
|
||||
});
|
||||
debug!("Captured usage from message_start: {:?}", accumulated_usage);
|
||||
}
|
||||
}
|
||||
}
|
||||
"content_block_start" => {
|
||||
debug!("Received content_block_start event: {:?}", event);
|
||||
if let Some(content_block) = event.content_block {
|
||||
@@ -342,11 +357,12 @@ impl AnthropicProvider {
|
||||
let chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: false,
|
||||
usage: None,
|
||||
tool_calls: Some(vec![tool_call]),
|
||||
};
|
||||
if tx.send(Ok(chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
return;
|
||||
return accumulated_usage;
|
||||
}
|
||||
} else {
|
||||
// Arguments are empty, we'll accumulate them from partial_json
|
||||
@@ -368,11 +384,12 @@ impl AnthropicProvider {
|
||||
let chunk = CompletionChunk {
|
||||
content: text,
|
||||
finished: false,
|
||||
usage: None,
|
||||
tool_calls: None,
|
||||
};
|
||||
if tx.send(Ok(chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
return;
|
||||
return accumulated_usage;
|
||||
}
|
||||
}
|
||||
// Handle partial JSON for tool calls
|
||||
@@ -407,11 +424,12 @@ impl AnthropicProvider {
|
||||
let chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: false,
|
||||
usage: None,
|
||||
tool_calls: Some(current_tool_calls.clone()),
|
||||
};
|
||||
if tx.send(Ok(chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
return;
|
||||
return accumulated_usage;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -420,12 +438,13 @@ impl AnthropicProvider {
|
||||
let final_chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
usage: accumulated_usage.clone(),
|
||||
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
|
||||
};
|
||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
}
|
||||
return;
|
||||
return accumulated_usage;
|
||||
}
|
||||
"error" => {
|
||||
if let Some(error) = event.error {
|
||||
@@ -433,7 +452,7 @@ impl AnthropicProvider {
|
||||
let _ = tx
|
||||
.send(Err(anyhow!("Anthropic API error: {:?}", error)))
|
||||
.await;
|
||||
return;
|
||||
return accumulated_usage;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
@@ -452,7 +471,7 @@ impl AnthropicProvider {
|
||||
Err(e) => {
|
||||
error!("Stream error: {}", e);
|
||||
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
|
||||
return;
|
||||
return accumulated_usage;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -461,9 +480,11 @@ impl AnthropicProvider {
|
||||
let final_chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
usage: accumulated_usage.clone(),
|
||||
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls) },
|
||||
};
|
||||
let _ = tx.send(Ok(final_chunk)).await;
|
||||
accumulated_usage
|
||||
}
|
||||
}
|
||||
|
||||
@@ -584,7 +605,14 @@ impl LLMProvider for AnthropicProvider {
|
||||
// Spawn task to process the stream
|
||||
let provider = self.clone();
|
||||
tokio::spawn(async move {
|
||||
provider.parse_streaming_response(stream, tx).await;
|
||||
let usage = provider.parse_streaming_response(stream, tx).await;
|
||||
// Log the final usage if available
|
||||
if let Some(usage) = usage {
|
||||
debug!(
|
||||
"Stream completed with usage - prompt: {}, completion: {}, total: {}",
|
||||
usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
Ok(ReceiverStream::new(rx))
|
||||
@@ -679,6 +707,14 @@ struct AnthropicStreamEvent {
|
||||
error: Option<AnthropicError>,
|
||||
#[serde(default)]
|
||||
content_block: Option<AnthropicContent>,
|
||||
#[serde(default)]
|
||||
message: Option<AnthropicStreamMessage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AnthropicStreamMessage {
|
||||
#[serde(default)]
|
||||
usage: Option<AnthropicUsage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
||||
@@ -291,11 +291,12 @@ impl DatabricksProvider {
|
||||
&self,
|
||||
mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
|
||||
tx: mpsc::Sender<Result<CompletionChunk>>,
|
||||
) {
|
||||
) -> Option<Usage> {
|
||||
let mut buffer = String::new();
|
||||
let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> =
|
||||
std::collections::HashMap::new(); // index -> (id, name, args)
|
||||
let mut incomplete_data_line = String::new(); // Buffer for incomplete data: lines
|
||||
let accumulated_usage: Option<Usage> = None;
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
match chunk_result {
|
||||
@@ -318,7 +319,7 @@ impl DatabricksProvider {
|
||||
let _ = tx
|
||||
.send(Err(anyhow!("Invalid UTF-8 in stream chunk: {}", e)))
|
||||
.await;
|
||||
return;
|
||||
return accumulated_usage;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -377,6 +378,7 @@ impl DatabricksProvider {
|
||||
let final_chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
usage: accumulated_usage.clone(),
|
||||
tool_calls: if final_tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
@@ -386,7 +388,7 @@ impl DatabricksProvider {
|
||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
}
|
||||
return;
|
||||
return accumulated_usage;
|
||||
}
|
||||
|
||||
// Debug: Log every raw JSON payload from Databricks API (truncated for large payloads)
|
||||
@@ -410,11 +412,12 @@ impl DatabricksProvider {
|
||||
let chunk = CompletionChunk {
|
||||
content,
|
||||
finished: false,
|
||||
usage: None,
|
||||
tool_calls: None,
|
||||
};
|
||||
if tx.send(Ok(chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
return;
|
||||
return accumulated_usage;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -502,6 +505,7 @@ impl DatabricksProvider {
|
||||
let final_chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
usage: accumulated_usage.clone(),
|
||||
tool_calls: if final_tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
@@ -511,7 +515,7 @@ impl DatabricksProvider {
|
||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
}
|
||||
return;
|
||||
return accumulated_usage;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -553,7 +557,7 @@ impl DatabricksProvider {
|
||||
Err(e) => {
|
||||
error!("Stream error: {}", e);
|
||||
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
|
||||
return;
|
||||
return accumulated_usage;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -592,6 +596,7 @@ impl DatabricksProvider {
|
||||
let final_chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
usage: accumulated_usage.clone(),
|
||||
tool_calls: if final_tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
@@ -599,6 +604,7 @@ impl DatabricksProvider {
|
||||
},
|
||||
};
|
||||
let _ = tx.send(Ok(final_chunk)).await;
|
||||
accumulated_usage
|
||||
}
|
||||
|
||||
pub async fn fetch_supported_models(&mut self) -> Result<Option<Vec<String>>> {
|
||||
|
||||
@@ -655,6 +655,7 @@ impl LLMProvider for EmbeddedProvider {
|
||||
let chunk = CompletionChunk {
|
||||
content: remaining_to_send.to_string(),
|
||||
finished: false,
|
||||
usage: None,
|
||||
tool_calls: None,
|
||||
};
|
||||
let _ = tx.blocking_send(Ok(chunk));
|
||||
@@ -681,6 +682,7 @@ impl LLMProvider for EmbeddedProvider {
|
||||
let chunk = CompletionChunk {
|
||||
content: remaining_to_send.to_string(),
|
||||
finished: false,
|
||||
usage: None,
|
||||
tool_calls: None,
|
||||
};
|
||||
let _ = tx.blocking_send(Ok(chunk));
|
||||
@@ -714,6 +716,7 @@ impl LLMProvider for EmbeddedProvider {
|
||||
let chunk = CompletionChunk {
|
||||
content: to_send.to_string(),
|
||||
finished: false,
|
||||
usage: None,
|
||||
tool_calls: None,
|
||||
};
|
||||
if tx.blocking_send(Ok(chunk)).is_err() {
|
||||
@@ -729,6 +732,7 @@ impl LLMProvider for EmbeddedProvider {
|
||||
let chunk = CompletionChunk {
|
||||
content: unsent_tokens.clone(),
|
||||
finished: false,
|
||||
usage: None,
|
||||
tool_calls: None,
|
||||
};
|
||||
if tx.blocking_send(Ok(chunk)).is_err() {
|
||||
@@ -749,6 +753,7 @@ impl LLMProvider for EmbeddedProvider {
|
||||
let final_chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
usage: None, // Embedded models calculate usage differently
|
||||
tool_calls: None,
|
||||
};
|
||||
let _ = tx.blocking_send(Ok(final_chunk));
|
||||
|
||||
@@ -67,6 +67,7 @@ pub struct CompletionChunk {
|
||||
pub content: String,
|
||||
pub finished: bool,
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
pub usage: Option<Usage>, // Add usage tracking for streaming
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
||||
Reference in New Issue
Block a user