token counting fixes

This commit is contained in:
Dhanji Prasanna
2025-10-09 12:11:14 +11:00
parent 9d1eef82b9
commit 260c949576
7 changed files with 283 additions and 22 deletions

View File

@@ -376,6 +376,7 @@ macro_rules! error_context {
#[cfg(test)]
mod tests {
use super::*;
use anyhow::anyhow;
#[test]
fn test_error_classification() {

View File

@@ -226,6 +226,7 @@ impl StreamingToolParser {
pub struct ContextWindow {
pub used_tokens: u32,
pub total_tokens: u32,
pub cumulative_tokens: u32, // Track cumulative tokens across all interactions
pub conversation_history: Vec<Message>,
}
@@ -234,23 +235,49 @@ impl ContextWindow {
Self {
used_tokens: 0,
total_tokens,
cumulative_tokens: 0,
conversation_history: Vec::new(),
}
}
pub fn add_message(&mut self, message: Message) {
self.add_message_with_tokens(message, None);
}
/// Add a message with optional token count from the provider
pub fn add_message_with_tokens(&mut self, message: Message, tokens: Option<u32>) {
// Skip messages with empty content to avoid API errors
if message.content.trim().is_empty() {
warn!("Skipping empty message to avoid API error");
return;
}
// Better token estimation based on content type
let estimated_tokens = Self::estimate_tokens(&message.content);
self.used_tokens += estimated_tokens;
// Use provided token count if available, otherwise estimate
let token_count = tokens.unwrap_or_else(|| Self::estimate_tokens(&message.content));
self.used_tokens += token_count;
self.cumulative_tokens += token_count;
self.conversation_history.push(message);
debug!(
"Added message with {} tokens (used: {}/{}, cumulative: {})",
token_count, self.used_tokens, self.total_tokens, self.cumulative_tokens
);
}
/// Update token usage from provider response
pub fn update_usage_from_response(&mut self, usage: &g3_providers::Usage) {
// Update with actual token usage from the provider
// This replaces our estimate with the actual count
let old_used = self.used_tokens;
self.used_tokens = usage.total_tokens;
self.cumulative_tokens = self.cumulative_tokens - old_used + usage.total_tokens;
debug!(
"Updated token usage from provider: {} -> {} (cumulative: {})",
old_used, self.used_tokens, self.cumulative_tokens
);
}
/// More accurate token estimation
fn estimate_tokens(text: &str) -> u32 {
// Better heuristic:
@@ -266,8 +293,18 @@ impl ContextWindow {
}
pub fn update_usage(&mut self, usage: &g3_providers::Usage) {
// Update with actual token usage from the provider
self.used_tokens = usage.total_tokens;
// Deprecated: Use update_usage_from_response instead
self.update_usage_from_response(usage);
}
/// Update cumulative token usage (for streaming)
pub fn add_streaming_tokens(&mut self, new_tokens: u32) {
self.used_tokens += new_tokens;
self.cumulative_tokens += new_tokens;
debug!(
"Added {} streaming tokens (used: {}/{}, cumulative: {})",
new_tokens, self.used_tokens, self.total_tokens, self.cumulative_tokens
);
}
pub fn percentage_used(&self) -> f32 {
@@ -1237,6 +1274,7 @@ The tool will execute immediately and you'll receive the result (success or erro
let mut chunks_received = 0;
let mut raw_chunks: Vec<String> = Vec::new(); // Store raw chunks for debugging
let mut _last_error: Option<String> = None;
let mut accumulated_usage: Option<g3_providers::Usage> = None;
while let Some(chunk_result) = stream.next().await {
match chunk_result {
@@ -1244,6 +1282,15 @@ The tool will execute immediately and you'll receive the result (success or erro
// Notify UI about SSE received (including pings)
self.ui_writer.notify_sse_received();
// Capture usage data if available
if let Some(ref usage) = chunk.usage {
accumulated_usage = Some(usage.clone());
debug!(
"Received usage data - prompt: {}, completion: {}, total: {}",
usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
);
}
// Store raw chunk for debugging (limit to first 20 and last 5)
if chunks_received < 20 || chunk.finished {
raw_chunks.push(format!(
@@ -1644,6 +1691,17 @@ The tool will execute immediately and you'll receive the result (success or erro
}
}
}
// Update context window with actual usage if available
if let Some(usage) = accumulated_usage {
debug!("Updating context window with actual usage from stream");
self.context_window.update_usage_from_response(&usage);
} else {
// Fall back to estimation if no usage data was provided
debug!("No usage data from stream, using estimation");
let estimated_tokens = ContextWindow::estimate_tokens(&current_response);
self.context_window.add_streaming_tokens(estimated_tokens);
}
// If we get here and no tool was executed, we're done
if !tool_executed {

View File

@@ -0,0 +1,154 @@
use g3_core::ContextWindow;
use g3_providers::{Message, MessageRole, Usage};
#[test]
fn test_context_window_with_actual_tokens() {
let mut context = ContextWindow::new(10000);
// Add a message with known token count
let message = Message {
role: MessageRole::User,
content: "Hello, how are you today?".to_string(),
};
// Add with actual token count (let's say this is 7 tokens)
context.add_message_with_tokens(message.clone(), Some(7));
assert_eq!(context.used_tokens, 7);
assert_eq!(context.cumulative_tokens, 7);
// Add another message with estimation (no token count provided)
let message2 = Message {
role: MessageRole::Assistant,
content: "I'm doing well, thank you for asking!".to_string(),
};
context.add_message_with_tokens(message2, None);
// Should have added estimated tokens (roughly 10-11 tokens for this text)
assert!(context.used_tokens > 7);
assert_eq!(context.cumulative_tokens, context.used_tokens);
}
#[test]
fn test_context_window_update_from_response() {
let mut context = ContextWindow::new(10000);
// Add initial messages with estimation
let message1 = Message {
role: MessageRole::User,
content: "What is the capital of France?".to_string(),
};
context.add_message(message1);
let initial_estimate = context.used_tokens;
let initial_cumulative = context.cumulative_tokens;
// Now update with actual usage from provider
let usage = Usage {
prompt_tokens: 8,
completion_tokens: 15,
total_tokens: 23,
};
context.update_usage_from_response(&usage);
// Should have replaced estimate with actual
assert_eq!(context.used_tokens, 23);
// Cumulative should be adjusted
assert_eq!(context.cumulative_tokens, context.cumulative_tokens);
assert!(context.cumulative_tokens >= 23);
}
#[test]
fn test_streaming_token_accumulation() {
let mut context = ContextWindow::new(10000);
// Simulate streaming tokens being added
context.add_streaming_tokens(5);
assert_eq!(context.used_tokens, 5);
assert_eq!(context.cumulative_tokens, 5);
context.add_streaming_tokens(3);
assert_eq!(context.used_tokens, 8);
assert_eq!(context.cumulative_tokens, 8);
context.add_streaming_tokens(7);
assert_eq!(context.used_tokens, 15);
assert_eq!(context.cumulative_tokens, 15);
}
#[test]
fn test_context_window_percentage_with_actual_tokens() {
let mut context = ContextWindow::new(1000);
// Add messages with known token counts
let message1 = Message {
role: MessageRole::User,
content: "First message".to_string(),
};
context.add_message_with_tokens(message1, Some(100));
assert_eq!(context.percentage_used(), 10.0);
let message2 = Message {
role: MessageRole::Assistant,
content: "Second message".to_string(),
};
context.add_message_with_tokens(message2, Some(400));
assert_eq!(context.percentage_used(), 50.0);
// Test should_summarize threshold (80%)
let message3 = Message {
role: MessageRole::User,
content: "Third message".to_string(),
};
context.add_message_with_tokens(message3, Some(300));
assert_eq!(context.percentage_used(), 80.0);
assert!(context.should_summarize());
}
#[test]
fn test_fallback_to_estimation() {
let mut context = ContextWindow::new(10000);
// Add message without token count (should use estimation)
let message = Message {
role: MessageRole::User,
content: "This is a test message without token count".to_string(),
};
context.add_message_with_tokens(message.clone(), None);
// Should have estimated tokens (roughly 11-12 tokens for this text)
assert!(context.used_tokens > 0);
assert!(context.used_tokens < 20); // Reasonable upper bound
// Verify estimation is reasonable
let text_len = message.content.len();
let estimated = context.used_tokens;
let ratio = text_len as f32 / estimated as f32;
// Should be roughly 3-4 characters per token
assert!(ratio > 2.0 && ratio < 6.0);
}
#[test]
fn test_empty_message_handling() {
let mut context = ContextWindow::new(10000);
// Empty messages should be skipped
let empty_message = Message {
role: MessageRole::User,
content: " ".to_string(), // Only whitespace
};
context.add_message_with_tokens(empty_message, Some(10));
// Should not have added anything
assert_eq!(context.used_tokens, 0);
assert_eq!(context.cumulative_tokens, 0);
assert_eq!(context.conversation_history.len(), 0);
}