token counting fixes
@@ -376,6 +376,7 @@ macro_rules! error_context {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use anyhow::anyhow;

     #[test]
     fn test_error_classification() {
@@ -226,6 +226,7 @@ impl StreamingToolParser {
 pub struct ContextWindow {
     pub used_tokens: u32,
     pub total_tokens: u32,
+    pub cumulative_tokens: u32, // Track cumulative tokens across all interactions
     pub conversation_history: Vec<Message>,
 }

@@ -234,21 +235,47 @@ impl ContextWindow {
         Self {
             used_tokens: 0,
             total_tokens,
+            cumulative_tokens: 0,
             conversation_history: Vec::new(),
         }
     }

     pub fn add_message(&mut self, message: Message) {
+        self.add_message_with_tokens(message, None);
+    }
+
+    /// Add a message with optional token count from the provider
+    pub fn add_message_with_tokens(&mut self, message: Message, tokens: Option<u32>) {
         // Skip messages with empty content to avoid API errors
         if message.content.trim().is_empty() {
             warn!("Skipping empty message to avoid API error");
             return;
         }

-        // Better token estimation based on content type
-        let estimated_tokens = Self::estimate_tokens(&message.content);
-        self.used_tokens += estimated_tokens;
+        // Use provided token count if available, otherwise estimate
+        let token_count = tokens.unwrap_or_else(|| Self::estimate_tokens(&message.content));
+        self.used_tokens += token_count;
+        self.cumulative_tokens += token_count;
         self.conversation_history.push(message);
+
+        debug!(
+            "Added message with {} tokens (used: {}/{}, cumulative: {})",
+            token_count, self.used_tokens, self.total_tokens, self.cumulative_tokens
+        );
+    }
+
+    /// Update token usage from provider response
+    pub fn update_usage_from_response(&mut self, usage: &g3_providers::Usage) {
+        // Update with actual token usage from the provider
+        // This replaces our estimate with the actual count
+        let old_used = self.used_tokens;
+        self.used_tokens = usage.total_tokens;
+        self.cumulative_tokens = self.cumulative_tokens - old_used + usage.total_tokens;
+
+        debug!(
+            "Updated token usage from provider: {} -> {} (cumulative: {})",
+            old_used, self.used_tokens, self.cumulative_tokens
+        );
     }

     /// More accurate token estimation
@@ -266,8 +293,18 @@ impl ContextWindow {
     }

     pub fn update_usage(&mut self, usage: &g3_providers::Usage) {
-        // Update with actual token usage from the provider
-        self.used_tokens = usage.total_tokens;
+        // Deprecated: Use update_usage_from_response instead
+        self.update_usage_from_response(usage);
+    }
+
+    /// Update cumulative token usage (for streaming)
+    pub fn add_streaming_tokens(&mut self, new_tokens: u32) {
+        self.used_tokens += new_tokens;
+        self.cumulative_tokens += new_tokens;
+        debug!(
+            "Added {} streaming tokens (used: {}/{}, cumulative: {})",
+            new_tokens, self.used_tokens, self.total_tokens, self.cumulative_tokens
+        );
     }

     pub fn percentage_used(&self) -> f32 {
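Taken together, the ContextWindow changes give callers three entry points: exact counts via add_message_with_tokens, estimation via add_message, and after-the-fact correction via update_usage_from_response. A minimal usage sketch, using only names that appear in this diff and in the new test file below (illustrative, not part of the commit):

    use g3_core::ContextWindow;
    use g3_providers::{Message, MessageRole, Usage};

    fn sketch() {
        let mut ctx = ContextWindow::new(10_000);

        // Exact count supplied by the provider: no estimation happens.
        ctx.add_message_with_tokens(
            Message { role: MessageRole::User, content: "Hello".to_string() },
            Some(12),
        );

        // No count available: add_message now delegates to
        // add_message_with_tokens(_, None) and estimates internally.
        ctx.add_message(Message {
            role: MessageRole::Assistant,
            content: "Hi! How can I help?".to_string(),
        });

        // Once the provider reports real usage, the estimate is replaced and
        // cumulative_tokens is re-based by the same delta.
        ctx.update_usage_from_response(&Usage {
            prompt_tokens: 12,
            completion_tokens: 9,
            total_tokens: 21,
        });

        assert_eq!(ctx.used_tokens, 21);
        assert_eq!(ctx.cumulative_tokens, 21);
    }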
@@ -1237,6 +1274,7 @@ The tool will execute immediately and you'll receive the result (success or erro
         let mut chunks_received = 0;
         let mut raw_chunks: Vec<String> = Vec::new(); // Store raw chunks for debugging
         let mut _last_error: Option<String> = None;
+        let mut accumulated_usage: Option<g3_providers::Usage> = None;

         while let Some(chunk_result) = stream.next().await {
             match chunk_result {
@@ -1244,6 +1282,15 @@ The tool will execute immediately and you'll receive the result (success or erro
                     // Notify UI about SSE received (including pings)
                     self.ui_writer.notify_sse_received();

+                    // Capture usage data if available
+                    if let Some(ref usage) = chunk.usage {
+                        accumulated_usage = Some(usage.clone());
+                        debug!(
+                            "Received usage data - prompt: {}, completion: {}, total: {}",
+                            usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
+                        );
+                    }
+
                     // Store raw chunk for debugging (limit to first 20 and last 5)
                     if chunks_received < 20 || chunk.finished {
                         raw_chunks.push(format!(
@@ -1645,6 +1692,17 @@ The tool will execute immediately and you'll receive the result (success or erro
                 }
             }

+            // Update context window with actual usage if available
+            if let Some(usage) = accumulated_usage {
+                debug!("Updating context window with actual usage from stream");
+                self.context_window.update_usage_from_response(&usage);
+            } else {
+                // Fall back to estimation if no usage data was provided
+                debug!("No usage data from stream, using estimation");
+                let estimated_tokens = ContextWindow::estimate_tokens(&current_response);
+                self.context_window.add_streaming_tokens(estimated_tokens);
+            }
+
             // If we get here and no tool was executed, we're done
             if !tool_executed {
                 // Check if we have any text in the parser that wasn't added to current_response
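The bookkeeping added to the streaming loop reduces to: remember the last usage report seen on any chunk, then after the stream ends either apply it or fall back to estimation over the accumulated text. A standalone sketch of that flow, assuming ContextWindow::estimate_tokens and CompletionChunk are reachable from the caller's crate (the helper name is hypothetical):

    use g3_core::ContextWindow;
    use g3_providers::{CompletionChunk, Usage};

    // Hypothetical helper mirroring the control flow added above.
    fn settle_usage(ctx: &mut ContextWindow, chunks: &[CompletionChunk], response_text: &str) {
        let mut accumulated_usage: Option<Usage> = None;
        for chunk in chunks {
            if let Some(ref usage) = chunk.usage {
                accumulated_usage = Some(usage.clone()); // last usage report wins
            }
        }
        match accumulated_usage {
            // Prefer the provider's actual numbers when the stream carried them.
            Some(usage) => ctx.update_usage_from_response(&usage),
            // Otherwise estimate from the text accumulated while streaming.
            None => ctx.add_streaming_tokens(ContextWindow::estimate_tokens(response_text)),
        }
    }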
new file: crates/g3-core/tests/test_token_counting.rs (154 lines)
@@ -0,0 +1,154 @@
+use g3_core::ContextWindow;
+use g3_providers::{Message, MessageRole, Usage};
+
+#[test]
+fn test_context_window_with_actual_tokens() {
+    let mut context = ContextWindow::new(10000);
+
+    // Add a message with known token count
+    let message = Message {
+        role: MessageRole::User,
+        content: "Hello, how are you today?".to_string(),
+    };
+
+    // Add with actual token count (let's say this is 7 tokens)
+    context.add_message_with_tokens(message.clone(), Some(7));
+
+    assert_eq!(context.used_tokens, 7);
+    assert_eq!(context.cumulative_tokens, 7);
+
+    // Add another message with estimation (no token count provided)
+    let message2 = Message {
+        role: MessageRole::Assistant,
+        content: "I'm doing well, thank you for asking!".to_string(),
+    };
+
+    context.add_message_with_tokens(message2, None);
+
+    // Should have added estimated tokens (roughly 10-11 tokens for this text)
+    assert!(context.used_tokens > 7);
+    assert_eq!(context.cumulative_tokens, context.used_tokens);
+}
+
+#[test]
+fn test_context_window_update_from_response() {
+    let mut context = ContextWindow::new(10000);
+
+    // Add initial messages with estimation
+    let message1 = Message {
+        role: MessageRole::User,
+        content: "What is the capital of France?".to_string(),
+    };
+    context.add_message(message1);
+
+    let initial_estimate = context.used_tokens;
+    let initial_cumulative = context.cumulative_tokens;
+
+    // Now update with actual usage from provider
+    let usage = Usage {
+        prompt_tokens: 8,
+        completion_tokens: 15,
+        total_tokens: 23,
+    };
+
+    context.update_usage_from_response(&usage);
+
+    // Should have replaced estimate with actual
+    assert_eq!(context.used_tokens, 23);
+    // Cumulative should be adjusted
+    assert_eq!(context.cumulative_tokens, initial_cumulative - initial_estimate + 23);
+    assert!(context.cumulative_tokens >= 23);
+}
+
+#[test]
+fn test_streaming_token_accumulation() {
+    let mut context = ContextWindow::new(10000);
+
+    // Simulate streaming tokens being added
+    context.add_streaming_tokens(5);
+    assert_eq!(context.used_tokens, 5);
+    assert_eq!(context.cumulative_tokens, 5);
+
+    context.add_streaming_tokens(3);
+    assert_eq!(context.used_tokens, 8);
+    assert_eq!(context.cumulative_tokens, 8);
+
+    context.add_streaming_tokens(7);
+    assert_eq!(context.used_tokens, 15);
+    assert_eq!(context.cumulative_tokens, 15);
+}
+
+#[test]
+fn test_context_window_percentage_with_actual_tokens() {
+    let mut context = ContextWindow::new(1000);
+
+    // Add messages with known token counts
+    let message1 = Message {
+        role: MessageRole::User,
+        content: "First message".to_string(),
+    };
+    context.add_message_with_tokens(message1, Some(100));
+
+    assert_eq!(context.percentage_used(), 10.0);
+
+    let message2 = Message {
+        role: MessageRole::Assistant,
+        content: "Second message".to_string(),
+    };
+    context.add_message_with_tokens(message2, Some(400));
+
+    assert_eq!(context.percentage_used(), 50.0);
+
+    // Test should_summarize threshold (80%)
+    let message3 = Message {
+        role: MessageRole::User,
+        content: "Third message".to_string(),
+    };
+    context.add_message_with_tokens(message3, Some(300));
+
+    assert_eq!(context.percentage_used(), 80.0);
+    assert!(context.should_summarize());
+}
+
+#[test]
+fn test_fallback_to_estimation() {
+    let mut context = ContextWindow::new(10000);
+
+    // Add message without token count (should use estimation)
+    let message = Message {
+        role: MessageRole::User,
+        content: "This is a test message without token count".to_string(),
+    };
+
+    context.add_message_with_tokens(message.clone(), None);
+
+    // Should have estimated tokens (roughly 11-12 tokens for this text)
+    assert!(context.used_tokens > 0);
+    assert!(context.used_tokens < 20); // Reasonable upper bound
+
+    // Verify estimation is reasonable
+    let text_len = message.content.len();
+    let estimated = context.used_tokens;
+    let ratio = text_len as f32 / estimated as f32;
+
+    // Should be roughly 3-4 characters per token
+    assert!(ratio > 2.0 && ratio < 6.0);
+}
+
+#[test]
+fn test_empty_message_handling() {
+    let mut context = ContextWindow::new(10000);
+
+    // Empty messages should be skipped
+    let empty_message = Message {
+        role: MessageRole::User,
+        content: "   ".to_string(), // Only whitespace
+    };
+
+    context.add_message_with_tokens(empty_message, Some(10));
+
+    // Should not have added anything
+    assert_eq!(context.used_tokens, 0);
+    assert_eq!(context.cumulative_tokens, 0);
+    assert_eq!(context.conversation_history.len(), 0);
+}
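Note that test_fallback_to_estimation pins the estimator only indirectly: for a short English sentence it must come out between 2 and 6 characters per token. The estimator itself is not shown in this diff; a deliberately naive stand-in that would satisfy those bounds (hypothetical, not the g3-core implementation) is:

    // Hypothetical stand-in, not the shipped estimator: ~4 characters per token.
    fn estimate_tokens_stub(text: &str) -> u32 {
        ((text.len() as u32) + 3) / 4
    }

    fn main() {
        let text = "This is a test message without token count"; // 42 chars
        let estimated = estimate_tokens_stub(text); // 11, ratio ~3.8
        let ratio = text.len() as f32 / estimated as f32;
        assert!(estimated > 0 && estimated < 20);
        assert!(ratio > 2.0 && ratio < 6.0);
    }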
@@ -269,10 +269,11 @@ impl AnthropicProvider {
         &self,
         mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
         tx: mpsc::Sender<Result<CompletionChunk>>,
-    ) {
+    ) -> Option<Usage> {
         let mut buffer = String::new();
         let mut current_tool_calls: Vec<ToolCall> = Vec::new();
         let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
+        let mut accumulated_usage: Option<Usage> = None;

         while let Some(chunk_result) = stream.next().await {
             match chunk_result {
@@ -284,7 +285,7 @@ impl AnthropicProvider {
                         let _ = tx
                             .send(Err(anyhow!("Invalid UTF-8 in stream chunk: {}", e)))
                             .await;
-                        return;
+                        return accumulated_usage;
                     }
                 };

@@ -306,12 +307,13 @@ impl AnthropicProvider {
                     let final_chunk = CompletionChunk {
                         content: String::new(),
                         finished: true,
+                        usage: accumulated_usage.clone(),
                         tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
                     };
                     if tx.send(Ok(final_chunk)).await.is_err() {
                         debug!("Receiver dropped, stopping stream");
                     }
-                    return;
+                    return accumulated_usage;
                 }

                 debug!("Raw Claude API JSON: {}", data);
@@ -320,6 +322,19 @@ impl AnthropicProvider {
                     Ok(event) => {
                         debug!("Parsed event type: {}, event: {:?}", event.event_type, event);
                         match event.event_type.as_str() {
+                            "message_start" => {
+                                // Extract usage data from message_start event
+                                if let Some(message) = event.message {
+                                    if let Some(usage) = message.usage {
+                                        accumulated_usage = Some(Usage {
+                                            prompt_tokens: usage.input_tokens,
+                                            completion_tokens: usage.output_tokens,
+                                            total_tokens: usage.input_tokens + usage.output_tokens,
+                                        });
+                                        debug!("Captured usage from message_start: {:?}", accumulated_usage);
+                                    }
+                                }
+                            }
                             "content_block_start" => {
                                 debug!("Received content_block_start event: {:?}", event);
                                 if let Some(content_block) = event.content_block {
@@ -342,11 +357,12 @@ impl AnthropicProvider {
                                         let chunk = CompletionChunk {
                                             content: String::new(),
                                             finished: false,
+                                            usage: None,
                                             tool_calls: Some(vec![tool_call]),
                                         };
                                         if tx.send(Ok(chunk)).await.is_err() {
                                             debug!("Receiver dropped, stopping stream");
-                                            return;
+                                            return accumulated_usage;
                                         }
                                     } else {
                                         // Arguments are empty, we'll accumulate them from partial_json
@@ -368,11 +384,12 @@ impl AnthropicProvider {
                                 let chunk = CompletionChunk {
                                     content: text,
                                     finished: false,
+                                    usage: None,
                                     tool_calls: None,
                                 };
                                 if tx.send(Ok(chunk)).await.is_err() {
                                     debug!("Receiver dropped, stopping stream");
-                                    return;
+                                    return accumulated_usage;
                                 }
                             }
                             // Handle partial JSON for tool calls
@@ -407,11 +424,12 @@ impl AnthropicProvider {
                                     let chunk = CompletionChunk {
                                         content: String::new(),
                                         finished: false,
+                                        usage: None,
                                         tool_calls: Some(current_tool_calls.clone()),
                                     };
                                     if tx.send(Ok(chunk)).await.is_err() {
                                         debug!("Receiver dropped, stopping stream");
-                                        return;
+                                        return accumulated_usage;
                                     }
                                 }
                             }
@@ -420,12 +438,13 @@ impl AnthropicProvider {
                                 let final_chunk = CompletionChunk {
                                     content: String::new(),
                                     finished: true,
+                                    usage: accumulated_usage.clone(),
                                     tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
                                 };
                                 if tx.send(Ok(final_chunk)).await.is_err() {
                                     debug!("Receiver dropped, stopping stream");
                                 }
-                                return;
+                                return accumulated_usage;
                             }
                             "error" => {
                                 if let Some(error) = event.error {
@@ -433,7 +452,7 @@ impl AnthropicProvider {
                                     let _ = tx
                                         .send(Err(anyhow!("Anthropic API error: {:?}", error)))
                                         .await;
-                                    return;
+                                    return accumulated_usage;
                                 }
                             }
                             _ => {
@@ -452,7 +471,7 @@ impl AnthropicProvider {
             Err(e) => {
                 error!("Stream error: {}", e);
                 let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
-                return;
+                return accumulated_usage;
             }
         }
     }
@@ -461,9 +480,11 @@ impl AnthropicProvider {
         let final_chunk = CompletionChunk {
             content: String::new(),
             finished: true,
+            usage: accumulated_usage.clone(),
             tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls) },
         };
         let _ = tx.send(Ok(final_chunk)).await;
+        accumulated_usage
     }
 }

@@ -584,7 +605,14 @@ impl LLMProvider for AnthropicProvider {
         // Spawn task to process the stream
         let provider = self.clone();
         tokio::spawn(async move {
-            provider.parse_streaming_response(stream, tx).await;
+            let usage = provider.parse_streaming_response(stream, tx).await;
+            // Log the final usage if available
+            if let Some(usage) = usage {
+                debug!(
+                    "Stream completed with usage - prompt: {}, completion: {}, total: {}",
+                    usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
+                );
+            }
         });

         Ok(ReceiverStream::new(rx))
@@ -679,6 +707,14 @@ struct AnthropicStreamEvent {
     error: Option<AnthropicError>,
     #[serde(default)]
     content_block: Option<AnthropicContent>,
+    #[serde(default)]
+    message: Option<AnthropicStreamMessage>,
+}
+
+#[derive(Debug, Deserialize)]
+struct AnthropicStreamMessage {
+    #[serde(default)]
+    usage: Option<AnthropicUsage>,
 }

 #[derive(Debug, Deserialize)]
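For reference, the message_start arm added above targets an SSE event of roughly this shape (abridged to the fields the new structs read; real events from Anthropic's streaming API carry more fields):

    data: {"type": "message_start", "message": {"usage": {"input_tokens": 8, "output_tokens": 1}}}

Only message.usage is consumed, which is why AnthropicStreamMessage declares a single optional usage field rather than mirroring the whole message object.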
@@ -291,11 +291,12 @@ impl DatabricksProvider {
         &self,
         mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
         tx: mpsc::Sender<Result<CompletionChunk>>,
-    ) {
+    ) -> Option<Usage> {
         let mut buffer = String::new();
         let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> =
             std::collections::HashMap::new(); // index -> (id, name, args)
         let mut incomplete_data_line = String::new(); // Buffer for incomplete data: lines
+        let accumulated_usage: Option<Usage> = None;

         while let Some(chunk_result) = stream.next().await {
             match chunk_result {
@@ -318,7 +319,7 @@ impl DatabricksProvider {
                     let _ = tx
                         .send(Err(anyhow!("Invalid UTF-8 in stream chunk: {}", e)))
                         .await;
-                    return;
+                    return accumulated_usage;
                 }
             };

@@ -377,6 +378,7 @@ impl DatabricksProvider {
                     let final_chunk = CompletionChunk {
                         content: String::new(),
                         finished: true,
+                        usage: accumulated_usage.clone(),
                         tool_calls: if final_tool_calls.is_empty() {
                             None
                         } else {
@@ -386,7 +388,7 @@ impl DatabricksProvider {
                     if tx.send(Ok(final_chunk)).await.is_err() {
                         debug!("Receiver dropped, stopping stream");
                     }
-                    return;
+                    return accumulated_usage;
                 }

                 // Debug: Log every raw JSON payload from Databricks API (truncated for large payloads)
@@ -410,11 +412,12 @@ impl DatabricksProvider {
                             let chunk = CompletionChunk {
                                 content,
                                 finished: false,
+                                usage: None,
                                 tool_calls: None,
                             };
                             if tx.send(Ok(chunk)).await.is_err() {
                                 debug!("Receiver dropped, stopping stream");
-                                return;
+                                return accumulated_usage;
                             }
                         }

@@ -502,6 +505,7 @@ impl DatabricksProvider {
                         let final_chunk = CompletionChunk {
                             content: String::new(),
                             finished: true,
+                            usage: accumulated_usage.clone(),
                             tool_calls: if final_tool_calls.is_empty() {
                                 None
                             } else {
@@ -511,7 +515,7 @@ impl DatabricksProvider {
                         if tx.send(Ok(final_chunk)).await.is_err() {
                             debug!("Receiver dropped, stopping stream");
                         }
-                        return;
+                        return accumulated_usage;
                     }
                 }
             }
@@ -553,7 +557,7 @@ impl DatabricksProvider {
             Err(e) => {
                 error!("Stream error: {}", e);
                 let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
-                return;
+                return accumulated_usage;
             }
         }
     }
@@ -592,6 +596,7 @@ impl DatabricksProvider {
         let final_chunk = CompletionChunk {
             content: String::new(),
             finished: true,
+            usage: accumulated_usage.clone(),
             tool_calls: if final_tool_calls.is_empty() {
                 None
             } else {
@@ -599,6 +604,7 @@ impl DatabricksProvider {
             },
         };
         let _ = tx.send(Ok(final_chunk)).await;
+        accumulated_usage
     }

     pub async fn fetch_supported_models(&mut self) -> Result<Option<Vec<String>>> {
@@ -655,6 +655,7 @@ impl LLMProvider for EmbeddedProvider {
                     let chunk = CompletionChunk {
                         content: remaining_to_send.to_string(),
                         finished: false,
+                        usage: None,
                         tool_calls: None,
                     };
                     let _ = tx.blocking_send(Ok(chunk));
@@ -681,6 +682,7 @@ impl LLMProvider for EmbeddedProvider {
                     let chunk = CompletionChunk {
                         content: remaining_to_send.to_string(),
                         finished: false,
+                        usage: None,
                         tool_calls: None,
                     };
                     let _ = tx.blocking_send(Ok(chunk));
@@ -714,6 +716,7 @@ impl LLMProvider for EmbeddedProvider {
                     let chunk = CompletionChunk {
                         content: to_send.to_string(),
                         finished: false,
+                        usage: None,
                         tool_calls: None,
                     };
                     if tx.blocking_send(Ok(chunk)).is_err() {
@@ -729,6 +732,7 @@ impl LLMProvider for EmbeddedProvider {
                 let chunk = CompletionChunk {
                     content: unsent_tokens.clone(),
                     finished: false,
+                    usage: None,
                     tool_calls: None,
                 };
                 if tx.blocking_send(Ok(chunk)).is_err() {
@@ -749,6 +753,7 @@ impl LLMProvider for EmbeddedProvider {
             let final_chunk = CompletionChunk {
                 content: String::new(),
                 finished: true,
+                usage: None, // Embedded models calculate usage differently
                 tool_calls: None,
             };
             let _ = tx.blocking_send(Ok(final_chunk));
@@ -67,6 +67,7 @@ pub struct CompletionChunk {
     pub content: String,
     pub finished: bool,
     pub tool_calls: Option<Vec<ToolCall>>,
+    pub usage: Option<Usage>, // Add usage tracking for streaming
 }

 #[derive(Debug, Clone, Serialize, Deserialize)]
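With the new field on CompletionChunk, producers emit two chunk shapes: interim chunks with usage: None and a final chunk forwarding whatever the provider reported. A small sketch, assuming both types are exported from g3_providers (values illustrative only):

    use g3_providers::{CompletionChunk, Usage};

    fn main() {
        // Interim chunk: text only, no usage attached yet.
        let interim = CompletionChunk {
            content: "partial text".to_string(),
            finished: false,
            tool_calls: None,
            usage: None,
        };

        // Final chunk: empty content, finished, carrying provider-reported usage.
        let last = CompletionChunk {
            content: String::new(),
            finished: true,
            tool_calls: None,
            usage: Some(Usage { prompt_tokens: 8, completion_tokens: 15, total_tokens: 23 }),
        };

        println!("{} ... {:?}", interim.content, last.usage);
    }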