From 260c9495763e9ba2c13f8fb416b2690619a496e8 Mon Sep 17 00:00:00 2001
From: Dhanji Prasanna
Date: Thu, 9 Oct 2025 12:11:14 +1100
Subject: [PATCH] token counting fixes

---
 crates/g3-core/src/error_handling.rs        |   1 +
 crates/g3-core/src/lib.rs                   |  68 ++++++++-
 crates/g3-core/tests/test_token_counting.rs | 154 ++++++++++++++++++++
 crates/g3-providers/src/anthropic.rs        |  58 ++++++--
 crates/g3-providers/src/databricks.rs       |  18 ++-
 crates/g3-providers/src/embedded.rs         |   5 +
 crates/g3-providers/src/lib.rs              |   1 +
 7 files changed, 283 insertions(+), 22 deletions(-)
 create mode 100644 crates/g3-core/tests/test_token_counting.rs

diff --git a/crates/g3-core/src/error_handling.rs b/crates/g3-core/src/error_handling.rs
index 3194dfe..9911f42 100644
--- a/crates/g3-core/src/error_handling.rs
+++ b/crates/g3-core/src/error_handling.rs
@@ -376,6 +376,7 @@ macro_rules! error_context {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use anyhow::anyhow;
 
     #[test]
     fn test_error_classification() {
diff --git a/crates/g3-core/src/lib.rs b/crates/g3-core/src/lib.rs
index d4adc3f..04abb9f 100644
--- a/crates/g3-core/src/lib.rs
+++ b/crates/g3-core/src/lib.rs
@@ -226,6 +226,7 @@ impl StreamingToolParser {
 pub struct ContextWindow {
     pub used_tokens: u32,
     pub total_tokens: u32,
+    pub cumulative_tokens: u32, // Track cumulative tokens across all interactions
     pub conversation_history: Vec<Message>,
 }
 
@@ -234,23 +235,49 @@ impl ContextWindow {
         Self {
             used_tokens: 0,
             total_tokens,
+            cumulative_tokens: 0,
             conversation_history: Vec::new(),
         }
     }
 
     pub fn add_message(&mut self, message: Message) {
+        self.add_message_with_tokens(message, None);
+    }
+
+    /// Add a message with optional token count from the provider
+    pub fn add_message_with_tokens(&mut self, message: Message, tokens: Option<u32>) {
         // Skip messages with empty content to avoid API errors
         if message.content.trim().is_empty() {
             warn!("Skipping empty message to avoid API error");
             return;
         }
 
-        // Better token estimation based on content type
-        let estimated_tokens = Self::estimate_tokens(&message.content);
-        self.used_tokens += estimated_tokens;
+        // Use provided token count if available, otherwise estimate
+        let token_count = tokens.unwrap_or_else(|| Self::estimate_tokens(&message.content));
+        self.used_tokens += token_count;
+        self.cumulative_tokens += token_count;
         self.conversation_history.push(message);
+
+        debug!(
+            "Added message with {} tokens (used: {}/{}, cumulative: {})",
+            token_count, self.used_tokens, self.total_tokens, self.cumulative_tokens
+        );
     }
 
+    /// Update token usage from provider response
+    pub fn update_usage_from_response(&mut self, usage: &g3_providers::Usage) {
+        // Update with actual token usage from the provider
+        // This replaces our estimate with the actual count
+        let old_used = self.used_tokens;
+        self.used_tokens = usage.total_tokens;
+        self.cumulative_tokens = self.cumulative_tokens - old_used + usage.total_tokens;
+
+        debug!(
+            "Updated token usage from provider: {} -> {} (cumulative: {})",
+            old_used, self.used_tokens, self.cumulative_tokens
+        );
+    }
+
     /// More accurate token estimation
     fn estimate_tokens(text: &str) -> u32 {
         // Better heuristic:
@@ -266,8 +293,18 @@
     }
 
     pub fn update_usage(&mut self, usage: &g3_providers::Usage) {
-        // Update with actual token usage from the provider
-        self.used_tokens = usage.total_tokens;
+        // Deprecated: Use update_usage_from_response instead
+        self.update_usage_from_response(usage);
+    }
+
+    /// Update cumulative token usage (for streaming)
+    pub fn add_streaming_tokens(&mut self, new_tokens: u32) {
+        self.used_tokens += new_tokens;
+        self.cumulative_tokens += new_tokens;
+        debug!(
+            "Added {} streaming tokens (used: {}/{}, cumulative: {})",
+            new_tokens, self.used_tokens, self.total_tokens, self.cumulative_tokens
+        );
     }
 
     pub fn percentage_used(&self) -> f32 {
@@ -1237,6 +1274,7 @@ The tool will execute immediately and you'll receive the result (success or erro
         let mut chunks_received = 0;
         let mut raw_chunks: Vec<String> = Vec::new(); // Store raw chunks for debugging
         let mut _last_error: Option<String> = None;
+        let mut accumulated_usage: Option<g3_providers::Usage> = None;
 
         while let Some(chunk_result) = stream.next().await {
             match chunk_result {
@@ -1244,6 +1282,15 @@ The tool will execute immediately and you'll receive the result (success or erro
                 Ok(chunk) => {
                     // Notify UI about SSE received (including pings)
                     self.ui_writer.notify_sse_received();
+
+                    // Capture usage data if available
+                    if let Some(ref usage) = chunk.usage {
+                        accumulated_usage = Some(usage.clone());
+                        debug!(
+                            "Received usage data - prompt: {}, completion: {}, total: {}",
+                            usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
+                        );
+                    }
                     // Store raw chunk for debugging (limit to first 20 and last 5)
                     if chunks_received < 20 || chunk.finished {
                         raw_chunks.push(format!(
@@ -1644,6 +1691,17 @@ The tool will execute immediately and you'll receive the result (success or erro
                 }
             }
         }
+
+        // Update context window with actual usage if available
+        if let Some(usage) = accumulated_usage {
+            debug!("Updating context window with actual usage from stream");
+            self.context_window.update_usage_from_response(&usage);
+        } else {
+            // Fall back to estimation if no usage data was provided
+            debug!("No usage data from stream, using estimation");
+            let estimated_tokens = ContextWindow::estimate_tokens(&current_response);
+            self.context_window.add_streaming_tokens(estimated_tokens);
+        }
 
         // If we get here and no tool was executed, we're done
         if !tool_executed {
diff --git a/crates/g3-core/tests/test_token_counting.rs b/crates/g3-core/tests/test_token_counting.rs
new file mode 100644
index 0000000..7dbd269
--- /dev/null
+++ b/crates/g3-core/tests/test_token_counting.rs
@@ -0,0 +1,154 @@
+use g3_core::ContextWindow;
+use g3_providers::{Message, MessageRole, Usage};
+
+#[test]
+fn test_context_window_with_actual_tokens() {
+    let mut context = ContextWindow::new(10000);
+
+    // Add a message with known token count
+    let message = Message {
+        role: MessageRole::User,
+        content: "Hello, how are you today?".to_string(),
+    };
+
+    // Add with actual token count (let's say this is 7 tokens)
+    context.add_message_with_tokens(message.clone(), Some(7));
+
+    assert_eq!(context.used_tokens, 7);
+    assert_eq!(context.cumulative_tokens, 7);
+
+    // Add another message with estimation (no token count provided)
+    let message2 = Message {
+        role: MessageRole::Assistant,
+        content: "I'm doing well, thank you for asking!".to_string(),
+    };
+
+    context.add_message_with_tokens(message2, None);
+
+    // Should have added estimated tokens (roughly 10-11 tokens for this text)
+    assert!(context.used_tokens > 7);
+    assert_eq!(context.cumulative_tokens, context.used_tokens);
+}
+
+#[test]
+fn test_context_window_update_from_response() {
+    let mut context = ContextWindow::new(10000);
+
+    // Add initial messages with estimation
+    let message1 = Message {
+        role: MessageRole::User,
+        content: "What is the capital of France?".to_string(),
+    };
+    context.add_message(message1);
+
+    let initial_estimate = context.used_tokens;
+    let initial_cumulative = context.cumulative_tokens;
+
+    // Now update with actual usage from provider
+    let usage = Usage {
+        prompt_tokens: 8,
+        completion_tokens: 15,
+        total_tokens: 23,
+    };
+
+    context.update_usage_from_response(&usage);
+
+    // Should have replaced estimate with actual
+    assert_eq!(context.used_tokens, 23);
+    // Cumulative should be adjusted: the estimate is removed, the actual count added
+    assert_eq!(context.cumulative_tokens, initial_cumulative - initial_estimate + 23);
+    assert!(context.cumulative_tokens >= 23);
+}
+
+#[test]
+fn test_streaming_token_accumulation() {
+    let mut context = ContextWindow::new(10000);
+
+    // Simulate streaming tokens being added
+    context.add_streaming_tokens(5);
+    assert_eq!(context.used_tokens, 5);
+    assert_eq!(context.cumulative_tokens, 5);
+
+    context.add_streaming_tokens(3);
+    assert_eq!(context.used_tokens, 8);
+    assert_eq!(context.cumulative_tokens, 8);
+
+    context.add_streaming_tokens(7);
+    assert_eq!(context.used_tokens, 15);
+    assert_eq!(context.cumulative_tokens, 15);
+}
+
+#[test]
+fn test_context_window_percentage_with_actual_tokens() {
+    let mut context = ContextWindow::new(1000);
+
+    // Add messages with known token counts
+    let message1 = Message {
+        role: MessageRole::User,
+        content: "First message".to_string(),
+    };
+    context.add_message_with_tokens(message1, Some(100));
+
+    assert_eq!(context.percentage_used(), 10.0);
+
+    let message2 = Message {
+        role: MessageRole::Assistant,
+        content: "Second message".to_string(),
+    };
+    context.add_message_with_tokens(message2, Some(400));
+
+    assert_eq!(context.percentage_used(), 50.0);
+
+    // Test should_summarize threshold (80%)
+    let message3 = Message {
+        role: MessageRole::User,
+        content: "Third message".to_string(),
+    };
+    context.add_message_with_tokens(message3, Some(300));
+
+    assert_eq!(context.percentage_used(), 80.0);
+    assert!(context.should_summarize());
+}
+
+#[test]
+fn test_fallback_to_estimation() {
+    let mut context = ContextWindow::new(10000);
+
+    // Add message without token count (should use estimation)
+    let message = Message {
+        role: MessageRole::User,
+        content: "This is a test message without token count".to_string(),
+    };
+
+    context.add_message_with_tokens(message.clone(), None);
+
+    // Should have estimated tokens (roughly 11-12 tokens for this text)
+    assert!(context.used_tokens > 0);
+    assert!(context.used_tokens < 20); // Reasonable upper bound
+
+    // Verify estimation is reasonable
+    let text_len = message.content.len();
+    let estimated = context.used_tokens;
+    let ratio = text_len as f32 / estimated as f32;
+
+    // Should be roughly 3-4 characters per token
+    assert!(ratio > 2.0 && ratio < 6.0);
+}
+
+#[test]
+fn test_empty_message_handling() {
+    let mut context = ContextWindow::new(10000);
+
+    // Empty messages should be skipped
+    let empty_message = Message {
+        role: MessageRole::User,
+        content: " ".to_string(), // Only whitespace
+    };
+
+    context.add_message_with_tokens(empty_message, Some(10));
+
+    // Should not have added anything
+    assert_eq!(context.used_tokens, 0);
+    assert_eq!(context.cumulative_tokens, 0);
+    assert_eq!(context.conversation_history.len(), 0);
+}
\ No newline at end of file
diff --git a/crates/g3-providers/src/anthropic.rs b/crates/g3-providers/src/anthropic.rs
index ef691cd..b84e92e 100644
--- a/crates/g3-providers/src/anthropic.rs
+++ b/crates/g3-providers/src/anthropic.rs
@@ -269,11 +269,12 @@ impl AnthropicProvider {
         &self,
         mut stream: impl futures_util::Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin,
         tx: mpsc::Sender<Result<CompletionChunk>>,
-    ) {
+    ) -> Option<Usage> {
         let mut buffer = String::new();
         let mut current_tool_calls: Vec<ToolCall> = Vec::new();
         let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
-
+        let mut accumulated_usage: Option<Usage> = None;
+
         while let Some(chunk_result) = stream.next().await {
             match chunk_result {
                 Ok(chunk) => {
@@ -284,7 +285,7 @@
                         let _ = tx
                             .send(Err(anyhow!("Invalid UTF-8 in stream chunk: {}", e)))
                             .await;
-                        return;
+                        return accumulated_usage;
                     }
                 };
@@ -306,12 +307,13 @@
                             let final_chunk = CompletionChunk {
                                 content: String::new(),
                                 finished: true,
+                                usage: accumulated_usage.clone(),
                                 tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
                             };
                             if tx.send(Ok(final_chunk)).await.is_err() {
                                 debug!("Receiver dropped, stopping stream");
                             }
-                            return;
+                            return accumulated_usage;
                         }
 
                         debug!("Raw Claude API JSON: {}", data);
@@ -320,6 +322,19 @@
                             Ok(event) => {
                                 debug!("Parsed event type: {}, event: {:?}", event.event_type, event);
                                 match event.event_type.as_str() {
+                                    "message_start" => {
+                                        // Extract usage data from message_start event
+                                        if let Some(message) = event.message {
+                                            if let Some(usage) = message.usage {
+                                                accumulated_usage = Some(Usage {
+                                                    prompt_tokens: usage.input_tokens,
+                                                    completion_tokens: usage.output_tokens,
+                                                    total_tokens: usage.input_tokens + usage.output_tokens,
+                                                });
+                                                debug!("Captured usage from message_start: {:?}", accumulated_usage);
+                                            }
+                                        }
+                                    }
                                     "content_block_start" => {
                                         debug!("Received content_block_start event: {:?}", event);
                                         if let Some(content_block) = event.content_block {
@@ -342,11 +357,12 @@
                                             let chunk = CompletionChunk {
                                                 content: String::new(),
                                                 finished: false,
+                                                usage: None,
                                                 tool_calls: Some(vec![tool_call]),
                                             };
                                             if tx.send(Ok(chunk)).await.is_err() {
                                                 debug!("Receiver dropped, stopping stream");
-                                                return;
+                                                return accumulated_usage;
                                             }
                                         } else {
                                             // Arguments are empty, we'll accumulate them from partial_json
@@ -368,11 +384,12 @@
                                             let chunk = CompletionChunk {
                                                 content: text,
                                                 finished: false,
+                                                usage: None,
                                                 tool_calls: None,
                                             };
                                             if tx.send(Ok(chunk)).await.is_err() {
                                                 debug!("Receiver dropped, stopping stream");
-                                                return;
+                                                return accumulated_usage;
                                             }
                                         }
                                         // Handle partial JSON for tool calls
@@ -407,11 +424,12 @@
                                             let chunk = CompletionChunk {
                                                 content: String::new(),
                                                 finished: false,
+                                                usage: None,
                                                 tool_calls: Some(current_tool_calls.clone()),
                                             };
                                             if tx.send(Ok(chunk)).await.is_err() {
                                                 debug!("Receiver dropped, stopping stream");
-                                                return;
+                                                return accumulated_usage;
                                             }
                                         }
                                     }
@@ -420,12 +438,13 @@
                                        let final_chunk = CompletionChunk {
                                            content: String::new(),
                                            finished: true,
+                                           usage: accumulated_usage.clone(),
                                            tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
                                        };
                                        if tx.send(Ok(final_chunk)).await.is_err() {
                                            debug!("Receiver dropped, stopping stream");
                                        }
-                                       return;
+                                       return accumulated_usage;
                                     }
                                     "error" => {
                                         if let Some(error) = event.error {
                                             error!("Anthropic API error: {:?}", error);
                                             let _ = tx
                                                 .send(Err(anyhow!("Anthropic API error: {:?}", error)))
                                                 .await;
-                                            return;
+                                            return accumulated_usage;
                                         }
                                     }
                                     _ => {
@@ -452,7 +471,7 @@
                 Err(e) => {
                     error!("Stream error: {}", e);
                     let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
-                    return;
+                    return accumulated_usage;
                 }
             }
         }
@@ -461,9 +480,11 @@
         let final_chunk = CompletionChunk {
             content: String::new(),
             finished: true,
+            usage: accumulated_usage.clone(),
             tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls) },
         };
         let _ = tx.send(Ok(final_chunk)).await;
+        accumulated_usage
     }
 }
 
@@ -584,7 +605,14 @@ impl LLMProvider for AnthropicProvider {
         // Spawn task to process the stream
         let provider = self.clone();
         tokio::spawn(async move {
-            provider.parse_streaming_response(stream, tx).await;
+            let usage = provider.parse_streaming_response(stream, tx).await;
+            // Log the final usage if available
+            if let Some(usage) = usage {
+                debug!(
+                    "Stream completed with usage - prompt: {}, completion: {}, total: {}",
+                    usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
+                );
+            }
         });
 
         Ok(ReceiverStream::new(rx))
@@ -679,6 +707,14 @@ struct AnthropicStreamEvent {
     error: Option<AnthropicError>,
     #[serde(default)]
     content_block: Option<AnthropicContentBlock>,
+    #[serde(default)]
+    message: Option<AnthropicStreamMessage>,
+}
+
+#[derive(Debug, Deserialize)]
+struct AnthropicStreamMessage {
+    #[serde(default)]
+    usage: Option<AnthropicUsage>,
 }
 
 #[derive(Debug, Deserialize)]
diff --git a/crates/g3-providers/src/databricks.rs b/crates/g3-providers/src/databricks.rs
index 6bae162..c8bad2d 100644
--- a/crates/g3-providers/src/databricks.rs
+++ b/crates/g3-providers/src/databricks.rs
@@ -291,11 +291,12 @@ impl DatabricksProvider {
         &self,
         mut stream: impl futures_util::Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin,
         tx: mpsc::Sender<Result<CompletionChunk>>,
-    ) {
+    ) -> Option<Usage> {
         let mut buffer = String::new();
         let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> =
             std::collections::HashMap::new(); // index -> (id, name, args)
         let mut incomplete_data_line = String::new(); // Buffer for incomplete data: lines
+        let accumulated_usage: Option<Usage> = None;
 
         while let Some(chunk_result) = stream.next().await {
             match chunk_result {
@@ -318,7 +319,7 @@
                         let _ = tx
                             .send(Err(anyhow!("Invalid UTF-8 in stream chunk: {}", e)))
                             .await;
-                        return;
+                        return accumulated_usage;
                     }
                 };
@@ -377,6 +378,7 @@
                             let final_chunk = CompletionChunk {
                                 content: String::new(),
                                 finished: true,
+                                usage: accumulated_usage.clone(),
                                 tool_calls: if final_tool_calls.is_empty() {
                                     None
                                 } else {
@@ -386,7 +388,7 @@
                             if tx.send(Ok(final_chunk)).await.is_err() {
                                 debug!("Receiver dropped, stopping stream");
                             }
-                            return;
+                            return accumulated_usage;
                         }
 
                         // Debug: Log every raw JSON payload from Databricks API (truncated for large payloads)
@@ -410,11 +412,12 @@
                                     let chunk = CompletionChunk {
                                         content,
                                         finished: false,
+                                        usage: None,
                                         tool_calls: None,
                                     };
                                     if tx.send(Ok(chunk)).await.is_err() {
                                         debug!("Receiver dropped, stopping stream");
-                                        return;
+                                        return accumulated_usage;
                                     }
                                 }
@@ -502,6 +505,7 @@
                                 let final_chunk = CompletionChunk {
                                     content: String::new(),
                                     finished: true,
+                                    usage: accumulated_usage.clone(),
                                     tool_calls: if final_tool_calls.is_empty() {
                                         None
                                     } else {
@@ -511,7 +515,7 @@
                                 if tx.send(Ok(final_chunk)).await.is_err() {
                                     debug!("Receiver dropped, stopping stream");
                                 }
-                                return;
+                                return accumulated_usage;
                             }
                         }
                     }
@@ -553,7 +557,7 @@
                 Err(e) => {
                     error!("Stream error: {}", e);
                     let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
-                    return;
+                    return accumulated_usage;
                 }
             }
         }
@@ -592,6 +596,7 @@
         let final_chunk = CompletionChunk {
             content: String::new(),
             finished: true,
+            usage: accumulated_usage.clone(),
             tool_calls: if final_tool_calls.is_empty() {
                 None
             } else {
@@ -599,6 +604,7 @@
             },
         };
         let _ = tx.send(Ok(final_chunk)).await;
+        accumulated_usage
     }
 
     pub async fn fetch_supported_models(&mut self) -> Result<Option<Vec<String>>> {
diff --git a/crates/g3-providers/src/embedded.rs b/crates/g3-providers/src/embedded.rs
index 7f8b0b8..361f155 100644
--- a/crates/g3-providers/src/embedded.rs
+++ b/crates/g3-providers/src/embedded.rs
@@ -655,6 +655,7 @@ impl LLMProvider for EmbeddedProvider {
                 let chunk = CompletionChunk {
                     content: remaining_to_send.to_string(),
                     finished: false,
+                    usage: None,
                     tool_calls: None,
                 };
                 let _ = tx.blocking_send(Ok(chunk));
@@ -681,6 +682,7 @@ impl LLMProvider for EmbeddedProvider {
                 let chunk = CompletionChunk {
                     content: remaining_to_send.to_string(),
                     finished: false,
+                    usage: None,
                     tool_calls: None,
                 };
                 let _ = tx.blocking_send(Ok(chunk));
@@ -714,6 +716,7 @@ impl LLMProvider for EmbeddedProvider {
                 let chunk = CompletionChunk {
                     content: to_send.to_string(),
                     finished: false,
+                    usage: None,
                     tool_calls: None,
                 };
                 if tx.blocking_send(Ok(chunk)).is_err() {
@@ -729,6 +732,7 @@ impl LLMProvider for EmbeddedProvider {
                 let chunk = CompletionChunk {
                     content: unsent_tokens.clone(),
                     finished: false,
+                    usage: None,
                     tool_calls: None,
                 };
                 if tx.blocking_send(Ok(chunk)).is_err() {
@@ -749,6 +753,7 @@ impl LLMProvider for EmbeddedProvider {
                 let final_chunk = CompletionChunk {
                     content: String::new(),
                     finished: true,
+                    usage: None, // Embedded models calculate usage differently
                     tool_calls: None,
                 };
                 let _ = tx.blocking_send(Ok(final_chunk));
diff --git a/crates/g3-providers/src/lib.rs b/crates/g3-providers/src/lib.rs
index 6dc645f..df3cd6e 100644
--- a/crates/g3-providers/src/lib.rs
+++ b/crates/g3-providers/src/lib.rs
@@ -67,6 +67,7 @@ pub struct CompletionChunk {
    pub content: String,
    pub finished: bool,
    pub tool_calls: Option<Vec<ToolCall>>,
+   pub usage: Option<Usage>, // Add usage tracking for streaming
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
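
Reviewer note (illustrative, not part of the diff): a minimal sketch of how a
downstream consumer might use the new `usage` field on `CompletionChunk`. The
helper name `drain_stream` and the 4-chars-per-token fallback are assumptions
for this sketch; the real agent loop in crates/g3-core/src/lib.rs does the
equivalent with `accumulated_usage` and the private `ContextWindow::estimate_tokens`.

use futures_util::StreamExt;
use g3_core::ContextWindow;
use g3_providers::{CompletionChunk, Usage};

/// Hypothetical consumer: accumulate streamed text, prefer provider-reported
/// usage, and fall back to a rough estimate when none was sent.
async fn drain_stream(
    mut stream: impl futures_util::Stream<Item = anyhow::Result<CompletionChunk>> + Unpin,
    context_window: &mut ContextWindow,
) -> anyhow::Result<String> {
    let mut text = String::new();
    let mut final_usage: Option<Usage> = None;

    while let Some(chunk) = stream.next().await {
        let chunk = chunk?;
        text.push_str(&chunk.content);
        // Providers attach usage to the final chunk (Anthropic also reports
        // input tokens on message_start), so keep the latest value seen.
        if chunk.usage.is_some() {
            final_usage = chunk.usage.clone();
        }
        if chunk.finished {
            break;
        }
    }

    match final_usage {
        // Actual counts from the provider replace any running estimate.
        Some(usage) => context_window.update_usage_from_response(&usage),
        // estimate_tokens is private to g3-core, so approximate here with the
        // usual rough heuristic of ~4 characters per token (assumption).
        None => context_window.add_streaming_tokens((text.len() as u32 / 4).max(1)),
    }

    Ok(text)
}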