diff --git a/crates/g3-core/src/lib.rs b/crates/g3-core/src/lib.rs
index 93098bf..225eb55 100644
--- a/crates/g3-core/src/lib.rs
+++ b/crates/g3-core/src/lib.rs
@@ -319,10 +319,26 @@ impl ContextWindow {
     /// Update token usage from provider response
     pub fn update_usage_from_response(&mut self, usage: &g3_providers::Usage) {
-        // Add the tokens from this response to our running total
-        // The usage.total_tokens represents tokens used in this single API call
-        self.used_tokens += usage.total_tokens;
-        self.cumulative_tokens += usage.total_tokens;
+        // The provider's usage represents the tokens used in the last API call,
+        // so we need to be smarter about how we update our running total.
+
+        let old_used = self.used_tokens;
+
+        // If the provider's total is greater than our current count, use it as the
+        // authoritative value. This handles cases where our estimation was off.
+        if usage.total_tokens > self.used_tokens {
+            self.used_tokens = usage.total_tokens;
+            self.cumulative_tokens += usage.total_tokens - old_used;
+        } else {
+            // Otherwise, add the tokens from this response
+            self.used_tokens += usage.completion_tokens; // Add only the new completion tokens
+            self.cumulative_tokens += usage.completion_tokens;
+        }
+
+        info!(
+            "Updated token usage - was: {}, now: {} (provider reported: prompt={}, completion={}, total={})",
+            old_used, self.used_tokens, usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
+        );

         debug!(
             "Added {} tokens from provider response (used: {}/{}, cumulative: {})",
@@ -429,8 +445,18 @@ Format this as a detailed but concise summary that can be used to resume the con
         if current_percentage >= 50 {
             let current_threshold = (current_percentage / 10) * 10; // Round down to nearest 10%
             if current_threshold > self.last_thinning_percentage && current_threshold <= 80 {
+                info!(
+                    "Context thinning triggered - usage: {}% ({}/{} tokens), threshold: {}%, last thinned at: {}%",
+                    current_percentage,
+                    self.used_tokens,
+                    self.total_tokens,
+                    current_threshold,
+                    self.last_thinning_percentage
+                );
                 return true;
             }
+        } else {
+            debug!("Context usage at {}% ({}/{} tokens) - no thinning needed", current_percentage, self.used_tokens, self.total_tokens);
         }

         false
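The lib.rs change above replaces blind accumulation with a reconciliation rule: when the provider-reported total exceeds the local running count, adopt it as authoritative; otherwise add only the new completion tokens, since the prompt side of the call was already counted. A minimal standalone sketch of that rule, using hypothetical `Usage` and `Window` stand-ins for `g3_providers::Usage` and `ContextWindow` (only the fields the rule touches are modeled):

```rust
// Hedged sketch of the reconciliation rule in update_usage_from_response.
// `Usage` and `Window` are illustrative stand-ins, not the crate's real types.
#[derive(Clone, Copy)]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

#[derive(Default)]
struct Window {
    used_tokens: u32,
    cumulative_tokens: u32,
}

impl Window {
    fn update(&mut self, u: Usage) {
        if u.total_tokens > self.used_tokens {
            // Provider total exceeds our estimate: adopt it as authoritative.
            self.cumulative_tokens += u.total_tokens - self.used_tokens;
            self.used_tokens = u.total_tokens;
        } else {
            // Provider total is not ahead of us: count only the new completion tokens.
            self.used_tokens += u.completion_tokens;
            self.cumulative_tokens += u.completion_tokens;
        }
    }
}

fn main() {
    let mut w = Window::default();
    // Turn 1: provider reports more than our (empty) estimate, so we adopt its total.
    w.update(Usage { prompt_tokens: 900, completion_tokens: 100, total_tokens: 1000 });
    assert_eq!((w.used_tokens, w.cumulative_tokens), (1000, 1000));
    // Turn 2: provider total is below our running count, so only completions are added.
    w.update(Usage { prompt_tokens: 700, completion_tokens: 100, total_tokens: 800 });
    assert_eq!((w.used_tokens, w.cumulative_tokens), (1100, 1100));
}
```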
diff --git a/crates/g3-providers/src/anthropic.rs b/crates/g3-providers/src/anthropic.rs
index ae140f4..00021f9 100644
--- a/crates/g3-providers/src/anthropic.rs
+++ b/crates/g3-providers/src/anthropic.rs
@@ -275,6 +275,7 @@ impl AnthropicProvider {
         let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
         let mut accumulated_usage: Option<Usage> = None;
         let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
+        let mut actual_completion_tokens: u32 = 0; // Completion tokens estimated from streamed text

         while let Some(chunk_result) = stream.next().await {
             match chunk_result {
@@ -322,7 +323,12 @@
                         let final_chunk = CompletionChunk {
                             content: String::new(),
                             finished: true,
-                            usage: accumulated_usage.clone(),
+                            usage: accumulated_usage.as_ref().map(|u| Usage {
+                                prompt_tokens: u.prompt_tokens,
+                                // Use our tracked completion tokens if we have them, otherwise the provider's estimate
+                                completion_tokens: if actual_completion_tokens > 0 { actual_completion_tokens } else { u.completion_tokens },
+                                total_tokens: u.prompt_tokens + if actual_completion_tokens > 0 { actual_completion_tokens } else { u.completion_tokens },
+                            }),
                             tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
                         };
                         if tx.send(Ok(final_chunk)).await.is_err() {
@@ -336,6 +342,7 @@ impl AnthropicProvider {
                     match serde_json::from_str::<AnthropicStreamEvent>(data) {
                         Ok(event) => {
                             debug!("Parsed event type: {}, event: {:?}", event.event_type, event);
+
                             match event.event_type.as_str() {
                                 "message_start" => {
                                     // Extract usage data from message_start event
@@ -346,7 +353,10 @@ impl AnthropicProvider {
                                             completion_tokens: usage.output_tokens,
                                             total_tokens: usage.input_tokens + usage.output_tokens,
                                         });
-                                        debug!("Captured usage from message_start: {:?}", accumulated_usage);
+                                        debug!("Captured initial usage from message_start - prompt: {}, completion: {} (estimated), total: {}",
+                                            usage.input_tokens,
+                                            usage.output_tokens,
+                                            usage.input_tokens + usage.output_tokens);
                                     }
                                 }
                             }
@@ -395,6 +405,9 @@ impl AnthropicProvider {
                                 "content_block_delta" => {
                                     if let Some(delta) = event.delta {
                                         if let Some(text) = delta.text {
+                                            // Track completion tokens as text streams in (rough estimate: 4 chars per token)
+                                            actual_completion_tokens += (text.len() as f32 / 4.0).ceil() as u32;
+
                                             debug!("Sending text chunk of length {}: '{}'", text.len(), text);
                                             let chunk = CompletionChunk {
                                                 content: text,
@@ -415,6 +428,21 @@
                                         }
                                     }
                                 }
+                                "message_delta" => {
+                                    // Check if message_delta contains updated usage data
+                                    if let Some(delta) = event.delta {
+                                        if let Some(usage) = delta.usage {
+                                            // message_delta may omit input_tokens, so keep the prompt count from message_start
+                                            let prompt = if usage.input_tokens > 0 { usage.input_tokens } else { accumulated_usage.as_ref().map_or(0, |u| u.prompt_tokens) };
+                                            accumulated_usage = Some(Usage {
+                                                prompt_tokens: prompt,
+                                                completion_tokens: usage.output_tokens,
+                                                total_tokens: prompt + usage.output_tokens,
+                                            });
+                                            debug!("Updated usage from message_delta - prompt: {}, completion: {}, total: {}", prompt, usage.output_tokens, prompt + usage.output_tokens);
+                                        }
+                                    }
+                                }
                                 "content_block_stop" => {
                                     // Tool call block is complete - now parse the accumulated JSON
                                     if !current_tool_calls.is_empty() && !partial_tool_json.is_empty() {
@@ -449,11 +477,44 @@
                                     }
                                 }
                                 "message_stop" => {
-                                    debug!("Received message stop event");
+                                    debug!("Received message_stop event: {:?}", event);
+
+                                    // Check if message_stop contains final usage data
+                                    if let Some(message) = event.message {
+                                        if let Some(usage) = message.usage {
+                                            // Update with final accurate usage data from message_stop;
+                                            // this should have the actual completion token count
+                                            accumulated_usage = Some(Usage {
+                                                prompt_tokens: usage.input_tokens,
+                                                // Prefer the actual output_tokens from message_stop if available,
+                                                // otherwise fall back to our tracked count
+                                                completion_tokens: if usage.output_tokens > 0 {
+                                                    usage.output_tokens
+                                                } else if actual_completion_tokens > 0 {
+                                                    actual_completion_tokens
+                                                } else { usage.output_tokens },
+                                                total_tokens: usage.input_tokens + if usage.output_tokens > 0 { usage.output_tokens } else { actual_completion_tokens },
+                                            });
+                                            debug!("Updated with final usage from message_stop - prompt: {}, completion: {}, total: {}",
+                                                usage.input_tokens,
+                                                usage.output_tokens,
+                                                usage.input_tokens + usage.output_tokens);
+                                        }
+                                    }
+
                                     let final_chunk = CompletionChunk {
                                         content: String::new(),
                                         finished: true,
-                                        usage: accumulated_usage.clone(),
+                                        usage: accumulated_usage.as_ref().map(|u| Usage {
+                                            prompt_tokens: u.prompt_tokens,
+                                            // Use our tracked completion tokens if they're higher than the reported count
+                                            completion_tokens: if actual_completion_tokens > u.completion_tokens {
+                                                actual_completion_tokens
+                                            } else {
+                                                u.completion_tokens
+                                            },
+                                            total_tokens: u.prompt_tokens + u32::max(actual_completion_tokens, u.completion_tokens),
+                                        }),
                                         tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
                                     };
                                     if tx.send(Ok(final_chunk)).await.is_err() {
@@ -495,10 +556,27 @@ impl AnthropicProvider {
            let final_chunk = CompletionChunk {
                content: String::new(),
                finished: true,
-               usage: accumulated_usage.clone(),
+               usage: accumulated_usage.as_ref().map(|u| Usage {
+                   prompt_tokens: u.prompt_tokens,
+                   completion_tokens: if actual_completion_tokens > u.completion_tokens {
+                       actual_completion_tokens
+                   } else {
+                       u.completion_tokens
+                   },
+                   total_tokens: u.prompt_tokens + u32::max(actual_completion_tokens, u.completion_tokens),
+               }),
                tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls) },
            };
            let _ = tx.send(Ok(final_chunk)).await;
+
+           // Log final usage for debugging
+           if let Some(ref usage) = accumulated_usage {
+               info!("Anthropic stream completed with final usage - prompt: {}, completion: {}, total: {}",
+                   usage.prompt_tokens, usage.completion_tokens, usage.total_tokens);
+           } else {
+               warn!("Anthropic stream completed without usage data - token accounting will fall back to estimation");
+           }
+
            accumulated_usage
        }
    }
@@ -736,6 +814,8 @@ struct AnthropicStreamMessage {
 struct AnthropicDelta {
     text: Option<String>,
     partial_json: Option<String>,
+    #[serde(default)]
+    usage: Option<AnthropicUsage>,
 }

 #[derive(Debug, Deserialize)]
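On the provider side, the patch now reconciles three usage signals when the final chunk is built: the provisional count from `message_start`, any update captured from `message_delta`, and a running estimate accumulated from streamed text at roughly 4 characters per token. A hedged sketch of that merge, with a hypothetical `finalize` helper that mirrors the `u32::max` logic in the final-chunk construction (the 4-chars-per-token ratio comes from the patch; the helper itself is illustrative, not crate code):

```rust
#[derive(Clone, Copy, Debug)]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

/// Rough stream-side estimate used by the patch: about 4 characters per token.
fn estimate_tokens(text: &str) -> u32 {
    (text.len() as f32 / 4.0).ceil() as u32
}

/// Prefer the larger of the provider-reported and locally tracked completion counts,
/// and recompute the total from its parts so it stays self-consistent.
fn finalize(reported: Option<Usage>, tracked_completion: u32) -> Option<Usage> {
    reported.map(|u| {
        let completion = u32::max(tracked_completion, u.completion_tokens);
        Usage {
            prompt_tokens: u.prompt_tokens,
            completion_tokens: completion,
            total_tokens: u.prompt_tokens + completion,
        }
    })
}

fn main() {
    // Accumulate the estimate the way the content_block_delta handler does.
    let mut tracked = 0;
    for delta in ["Hello, ", "World!"] {
        tracked += estimate_tokens(delta);
    }
    // message_start often reports a tiny provisional output count (e.g. 1 token).
    let reported = Usage { prompt_tokens: 50, completion_tokens: 1, total_tokens: 51 };
    let final_usage = finalize(Some(reported), tracked).unwrap();
    assert_eq!(final_usage.completion_tokens, 4); // tracked estimate wins over the stub
    assert_eq!(final_usage.total_tokens, 54);
}
```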
diff --git a/test_token_accounting.py b/test_token_accounting.py
new file mode 100644
index 0000000..46a444e
--- /dev/null
+++ b/test_token_accounting.py
@@ -0,0 +1,164 @@
+#!/usr/bin/env python3
+"""
+Test script to verify token accounting is working correctly with the Anthropic provider.
+This script will send multiple messages and verify that token counts accumulate properly.
+"""
+
+import subprocess
+import os
+import re
+import sys
+import time
+
+def run_g3_command(prompt, provider="anthropic"):
+    """Run a g3 command and capture the output."""
+    cmd = [
+        "cargo", "run", "--release", "--",
+        "--provider", provider,
+        prompt
+    ]
+
+    env = {
+        "RUST_LOG": "g3_providers=debug,g3_core=info",
+        "RUST_BACKTRACE": "1"
+    }
+
+    result = subprocess.run(
+        cmd,
+        capture_output=True,
+        text=True,
+        env={**os.environ, **env}
+    )
+
+    return result.stdout + result.stderr
+
+def extract_token_info(output):
+    """Extract token usage information from the output."""
+    token_info = {}
+
+    # Look for token usage updates
+    usage_pattern = r"Updated token usage.*was: (\d+), now: (\d+).*prompt=(\d+), completion=(\d+), total=(\d+)"
+    matches = re.findall(usage_pattern, output)
+    if matches:
+        last_match = matches[-1]
+        token_info['was'] = int(last_match[0])
+        token_info['now'] = int(last_match[1])
+        token_info['prompt'] = int(last_match[2])
+        token_info['completion'] = int(last_match[3])
+        token_info['total'] = int(last_match[4])
+
+    # Look for context percentage
+    context_pattern = r"Context usage at (\d+)%.*\((\d+)/(\d+) tokens\)"
+    matches = re.findall(context_pattern, output)
+    if matches:
+        last_match = matches[-1]
+        token_info['percentage'] = int(last_match[0])
+        token_info['used'] = int(last_match[1])
+        token_info['total_context'] = int(last_match[2])
+
+    # Look for thinning triggers
+    thinning_pattern = r"Context thinning triggered.*usage: (\d+)%.*\((\d+)/(\d+) tokens\)"
+    matches = re.findall(thinning_pattern, output)
+    if matches:
+        token_info['thinning_triggered'] = True
+        token_info['thinning_percentage'] = int(matches[-1][0])
+
+    # Look for final usage from Anthropic
+    final_usage_pattern = r"Anthropic stream completed with final usage.*prompt: (\d+), completion: (\d+), total: (\d+)"
+    matches = re.findall(final_usage_pattern, output)
+    if matches:
+        last_match = matches[-1]
+        token_info['final_prompt'] = int(last_match[0])
+        token_info['final_completion'] = int(last_match[1])
+        token_info['final_total'] = int(last_match[2])
+
+    return token_info
+
+def main():
+    print("Testing Anthropic Provider Token Accounting")
+    print("="*50)
+
+    # Build the project first
+    print("Building project...")
+    subprocess.run(["cargo", "build", "--release"], capture_output=True)
+
+    # Test 1: Simple prompt
+    print("\nTest 1: Simple prompt")
+    print("-"*30)
+    output = run_g3_command("Say 'Hello, World!' and nothing else.")
+    tokens = extract_token_info(output)
+
+    if tokens:
+        print(f"Token usage: {tokens.get('now', 'N/A')} tokens")
+        print(f"  Prompt tokens: {tokens.get('prompt', 'N/A')}")
+        print(f"  Completion tokens: {tokens.get('completion', 'N/A')}")
+        print(f"  Total from provider: {tokens.get('total', 'N/A')}")
+
+        if 'final_total' in tokens:
+            print(f"  Final total from stream: {tokens['final_total']}")
+            if tokens.get('now') != tokens['final_total']:
+                print(f"  ⚠️ WARNING: Mismatch between tracked ({tokens.get('now')}) and final ({tokens['final_total']})")
+
+        # Check if the completion tokens are reasonable (should be small for "Hello, World!")
+        if tokens.get('completion', 0) > 50:
+            print(f"  ⚠️ WARNING: Completion tokens seem high for a simple response: {tokens.get('completion')}")
+    else:
+        print("  ❌ No token information found in output")
+
+    # Test 2: Longer response
+    print("\nTest 2: Longer response")
+    print("-"*30)
+    output = run_g3_command("Write a 3-paragraph essay about the importance of accurate token counting in LLM applications.")
+    tokens = extract_token_info(output)
+
+    if tokens:
+        print(f"Token usage: {tokens.get('now', 'N/A')} tokens")
+        print(f"  Prompt tokens: {tokens.get('prompt', 'N/A')}")
+        print(f"  Completion tokens: {tokens.get('completion', 'N/A')}")
+        print(f"  Total from provider: {tokens.get('total', 'N/A')}")
+
+        if 'final_total' in tokens:
+            print(f"  Final total from stream: {tokens['final_total']}")
+            if tokens.get('now') != tokens['final_total']:
+                print(f"  ⚠️ WARNING: Mismatch between tracked ({tokens.get('now')}) and final ({tokens['final_total']})")
+
+        # Check if completion tokens are reasonable for a longer response
+        if tokens.get('completion', 0) < 100:
+            print(f"  ⚠️ WARNING: Completion tokens seem low for a 3-paragraph essay: {tokens.get('completion')}")
+    else:
+        print("  ❌ No token information found in output")
+
+    # Test 3: Check for proper accumulation
+    print("\nTest 3: Token accumulation (multiple messages)")
+    print("-"*30)
+
+    # First message
+    output1 = run_g3_command("Count from 1 to 5.")
+    tokens1 = extract_token_info(output1)
+
+    # Second message (this would need to be in a conversation, but for now we test separately)
+    output2 = run_g3_command("Now count from 6 to 10.")
+    tokens2 = extract_token_info(output2)
+
+    if tokens1 and tokens2:
+        print(f"First message: {tokens1.get('now', 'N/A')} tokens")
+        print(f"Second message: {tokens2.get('now', 'N/A')} tokens")
+
+        # In a real conversation, tokens2['now'] should be greater than tokens1['now'],
+        # but since these are separate invocations we just check they're both reasonable.
+        if tokens1.get('now', 0) > 0 and tokens2.get('now', 0) > 0:
+            print("  ✅ Both messages have token counts")
+        else:
+            print("  ❌ Missing token counts")
+
+    print("\n" + "="*50)
+    print("Test Summary:")
+    print("Check the output above for any warnings or errors.")
+    print("Key things to verify:")
+    print("  1. Token counts are being captured from the provider")
+    print("  2. Completion tokens are reasonable for the response length")
+    print("  3. No mismatch between tracked and final token counts")
+    print("  4. Context thinning triggers at appropriate thresholds")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/test_token_accounting.sh b/test_token_accounting.sh
new file mode 100755
index 0000000..23b07bd
--- /dev/null
+++ b/test_token_accounting.sh
@@ -0,0 +1,46 @@
+#!/bin/bash
+
+# Test script to verify token accounting with Anthropic provider
+
+echo "Testing token accounting with Anthropic provider..."
+echo "This test will send a few messages and check if token counts are properly tracked."
+echo ""
+
+# Set up environment for testing
+export RUST_LOG=g3_providers=debug,g3_core=info
+export RUST_BACKTRACE=1
+
+# Build the project first
+echo "Building project..."
+cargo build --release 2>&1 | grep -E "(Compiling|Finished)" || true
+
+echo ""
+echo "Running test with Anthropic provider..."
+echo "Watch for these log messages:"
+echo "  - 'Captured initial usage from message_start'"
+echo "  - 'Updated usage from message_delta' (if available)"
+echo "  - 'Updated with final usage from message_stop' (if available)"
+echo "  - 'Anthropic stream completed with final usage'"
+echo "  - 'Updated token usage - was: ..., now: ...'"
+echo "  - 'Context thinning triggered' (when reaching thresholds)"
+echo ""
+
+# Create a simple test that will generate some tokens
+cat << 'EOF' > /tmp/test_prompt.txt
+Please write a short paragraph about the importance of accurate token counting in LLM applications. Then list 3 reasons why token accounting might fail.
+EOF
+
+# Run the test
+echo "Sending test prompt..."
+cargo run --release -- --provider anthropic "$(cat /tmp/test_prompt.txt)" 2>&1 | tee /tmp/token_test.log
+
+echo ""
+echo "Analyzing results..."
+echo ""
+
+# Check for token accounting messages
+echo "Token accounting messages found:"
+grep -E "(usage from|token usage|Context thinning|Context usage)" /tmp/token_test.log | head -20
+
+echo ""
+echo "Test complete. Check /tmp/token_test.log for full output."
\ No newline at end of file
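Both scripts exercise the binary end to end, so they depend on the log formats staying in sync with the Rust changes above. The serde-level change at the bottom of anthropic.rs can also be checked in isolation; a hedged sketch of such a check follows, reusing the struct shapes as reconstructed in this patch (these are internal types, so the names here are assumptions for illustration only):

```rust
// Illustrative check: a message_delta-style payload that omits input_tokens
// should still deserialize once `usage` is optional and its fields default.
// `AnthropicDelta`/`AnthropicUsage` mirror the patch's reconstruction, not a public API.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct AnthropicUsage {
    #[serde(default)]
    input_tokens: u32,
    #[serde(default)]
    output_tokens: u32,
}

#[derive(Debug, Deserialize)]
struct AnthropicDelta {
    text: Option<String>,
    partial_json: Option<String>,
    #[serde(default)]
    usage: Option<AnthropicUsage>,
}

fn main() {
    // Anthropic's message_delta usage typically carries only output_tokens.
    let json = r#"{"usage": {"output_tokens": 42}}"#;
    let delta: AnthropicDelta = serde_json::from_str(json).expect("delta should parse");
    let usage = delta.usage.expect("usage should be present");
    assert_eq!(usage.input_tokens, 0); // defaulted, which is why the prompt fallback matters
    assert_eq!(usage.output_tokens, 42);
    assert!(delta.text.is_none() && delta.partial_json.is_none());
}
```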