to get anthropic provider more reliable with tokens

This commit is contained in:
Michael Neale
2025-10-22 09:47:24 +11:00
parent 758e255af8
commit 738c3ac53e
4 changed files with 323 additions and 9 deletions

View File

@@ -319,10 +319,26 @@ impl ContextWindow {
/// Update token usage from provider response
pub fn update_usage_from_response(&mut self, usage: &g3_providers::Usage) {
// Add the tokens from this response to our running total
// The usage.total_tokens represents tokens used in this single API call
self.used_tokens += usage.total_tokens;
self.cumulative_tokens += usage.total_tokens;
// The provider's usage represents the tokens used in the last API call
// We need to be smarter about how we update our running total
let old_used = self.used_tokens;
// If the provider's total is greater than our current count, use it as the authoritative value
// This handles cases where our estimation was off
if usage.total_tokens > self.used_tokens {
self.used_tokens = usage.total_tokens;
self.cumulative_tokens += usage.total_tokens - old_used;
} else {
// Otherwise, add the tokens from this response
self.used_tokens += usage.completion_tokens; // Add only the new completion tokens
self.cumulative_tokens += usage.completion_tokens;
}
info!(
"Updated token usage - was: {}, now: {} (provider reported: prompt={}, completion={}, total={})",
old_used, self.used_tokens, usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
);
debug!(
"Added {} tokens from provider response (used: {}/{}, cumulative: {})",
@@ -429,8 +445,18 @@ Format this as a detailed but concise summary that can be used to resume the con
if current_percentage >= 50 {
let current_threshold = (current_percentage / 10) * 10; // Round down to nearest 10%
if current_threshold > self.last_thinning_percentage && current_threshold <= 80 {
info!(
"Context thinning triggered - usage: {}% ({}/{} tokens), threshold: {}%, last thinned at: {}%",
current_percentage,
self.used_tokens,
self.total_tokens,
current_threshold,
self.last_thinning_percentage
);
return true;
}
} else {
debug!("Context usage at {}% ({}/{} tokens) - no thinning needed", current_percentage, self.used_tokens, self.total_tokens);
}
false

View File

@@ -275,6 +275,7 @@ impl AnthropicProvider {
let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
let mut accumulated_usage: Option<Usage> = None;
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
let mut actual_completion_tokens: u32 = 0; // Track actual completion tokens
while let Some(chunk_result) = stream.next().await {
match chunk_result {
@@ -322,7 +323,12 @@ impl AnthropicProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
usage: accumulated_usage.as_ref().map(|u| Usage {
prompt_tokens: u.prompt_tokens,
// Use actual completion tokens if we tracked them, otherwise use the estimate
completion_tokens: if actual_completion_tokens > 0 { actual_completion_tokens } else { u.completion_tokens },
total_tokens: u.prompt_tokens + if actual_completion_tokens > 0 { actual_completion_tokens } else { u.completion_tokens },
}),
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
};
if tx.send(Ok(final_chunk)).await.is_err() {
@@ -336,6 +342,7 @@ impl AnthropicProvider {
match serde_json::from_str::<AnthropicStreamEvent>(data) {
Ok(event) => {
debug!("Parsed event type: {}, event: {:?}", event.event_type, event);
match event.event_type.as_str() {
"message_start" => {
// Extract usage data from message_start event
@@ -346,7 +353,10 @@ impl AnthropicProvider {
completion_tokens: usage.output_tokens,
total_tokens: usage.input_tokens + usage.output_tokens,
});
debug!("Captured usage from message_start: {:?}", accumulated_usage);
debug!("Captured initial usage from message_start - prompt: {}, completion: {} (estimated), total: {}",
usage.input_tokens,
usage.output_tokens,
usage.input_tokens + usage.output_tokens);
}
}
}
@@ -395,6 +405,9 @@ impl AnthropicProvider {
"content_block_delta" => {
if let Some(delta) = event.delta {
if let Some(text) = delta.text {
// Track actual completion tokens (rough estimate: 4 chars per token)
actual_completion_tokens += (text.len() as f32 / 4.0).ceil() as u32;
debug!("Sending text chunk of length {}: '{}'", text.len(), text);
let chunk = CompletionChunk {
content: text,
@@ -415,6 +428,19 @@ impl AnthropicProvider {
}
}
}
"message_delta" => {
// Check if message_delta contains updated usage data
if let Some(delta) = event.delta {
if let Some(usage) = delta.usage {
accumulated_usage = Some(Usage {
prompt_tokens: usage.input_tokens,
completion_tokens: usage.output_tokens,
total_tokens: usage.input_tokens + usage.output_tokens,
});
debug!("Updated usage from message_delta - prompt: {}, completion: {}, total: {}", usage.input_tokens, usage.output_tokens, usage.input_tokens + usage.output_tokens);
}
}
}
"content_block_stop" => {
// Tool call block is complete - now parse the accumulated JSON
if !current_tool_calls.is_empty() && !partial_tool_json.is_empty() {
@@ -449,11 +475,44 @@ impl AnthropicProvider {
}
}
"message_stop" => {
debug!("Received message stop event");
debug!("Received message_stop event: {:?}", event);
// Check if message_stop contains final usage data
if let Some(message) = event.message {
if let Some(usage) = message.usage {
// Update with final accurate usage data from message_stop
// This should have the actual completion token count
accumulated_usage = Some(Usage {
prompt_tokens: usage.input_tokens,
// Prefer the actual output_tokens from message_stop if available
// Otherwise use our tracked count, and as last resort the initial estimate
completion_tokens: if usage.output_tokens > 0 {
usage.output_tokens
} else if actual_completion_tokens > 0 {
actual_completion_tokens
} else { usage.output_tokens },
total_tokens: usage.input_tokens + usage.output_tokens,
});
debug!("Updated with final usage from message_stop - prompt: {}, completion: {}, total: {}",
usage.input_tokens,
usage.output_tokens,
usage.input_tokens + usage.output_tokens);
}
}
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
usage: accumulated_usage.as_ref().map(|u| Usage {
prompt_tokens: u.prompt_tokens,
// Use actual completion tokens if we tracked them and they're higher
completion_tokens: if actual_completion_tokens > u.completion_tokens {
actual_completion_tokens
} else {
u.completion_tokens
},
total_tokens: u.prompt_tokens + u32::max(actual_completion_tokens, u.completion_tokens),
}),
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
};
if tx.send(Ok(final_chunk)).await.is_err() {
@@ -495,10 +554,27 @@ impl AnthropicProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
usage: accumulated_usage.as_ref().map(|u| Usage {
prompt_tokens: u.prompt_tokens,
completion_tokens: if actual_completion_tokens > u.completion_tokens {
actual_completion_tokens
} else {
u.completion_tokens
},
total_tokens: u.prompt_tokens + u32::max(actual_completion_tokens, u.completion_tokens),
}),
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls) },
};
let _ = tx.send(Ok(final_chunk)).await;
// Log final usage for debugging
if let Some(ref usage) = accumulated_usage {
info!("Anthropic stream completed with final usage - prompt: {}, completion: {}, total: {}",
usage.prompt_tokens, usage.completion_tokens, usage.total_tokens);
} else {
warn!("Anthropic stream completed without usage data - token accounting will fall back to estimation");
}
accumulated_usage
}
}
@@ -736,6 +812,8 @@ struct AnthropicStreamMessage {
struct AnthropicDelta {
text: Option<String>,
partial_json: Option<String>,
#[serde(default)]
usage: Option<AnthropicUsage>,
}
#[derive(Debug, Deserialize)]

164
test_token_accounting.py Normal file
View File

@@ -0,0 +1,164 @@
#!/usr/bin/env python3
"""
Test script to verify token accounting is working correctly with the Anthropic provider.
This script will send multiple messages and verify that token counts accumulate properly.
"""
import subprocess
import json
import re
import sys
import time
def run_g3_command(prompt, provider="anthropic"):
"""Run a g3 command and capture the output."""
cmd = [
"cargo", "run", "--release", "--",
"--provider", provider,
prompt
]
env = {
"RUST_LOG": "g3_providers=debug,g3_core=info",
"RUST_BACKTRACE": "1"
}
result = subprocess.run(
cmd,
capture_output=True,
text=True,
env={**subprocess.os.environ, **env}
)
return result.stdout + result.stderr
def extract_token_info(output):
"""Extract token usage information from the output."""
token_info = {}
# Look for token usage updates
usage_pattern = r"Updated token usage.*was: (\d+), now: (\d+).*prompt=(\d+), completion=(\d+), total=(\d+)"
matches = re.findall(usage_pattern, output)
if matches:
last_match = matches[-1]
token_info['was'] = int(last_match[0])
token_info['now'] = int(last_match[1])
token_info['prompt'] = int(last_match[2])
token_info['completion'] = int(last_match[3])
token_info['total'] = int(last_match[4])
# Look for context percentage
context_pattern = r"Context usage at (\d+)%.*\((\d+)/(\d+) tokens\)"
matches = re.findall(context_pattern, output)
if matches:
last_match = matches[-1]
token_info['percentage'] = int(last_match[0])
token_info['used'] = int(last_match[1])
token_info['total_context'] = int(last_match[2])
# Look for thinning triggers
thinning_pattern = r"Context thinning triggered.*usage: (\d+)%.*\((\d+)/(\d+) tokens\)"
matches = re.findall(thinning_pattern, output)
if matches:
token_info['thinning_triggered'] = True
token_info['thinning_percentage'] = int(matches[-1][0])
# Look for final usage from Anthropic
final_usage_pattern = r"Anthropic stream completed with final usage.*prompt: (\d+), completion: (\d+), total: (\d+)"
matches = re.findall(final_usage_pattern, output)
if matches:
last_match = matches[-1]
token_info['final_prompt'] = int(last_match[0])
token_info['final_completion'] = int(last_match[1])
token_info['final_total'] = int(last_match[2])
return token_info
def main():
print("Testing Anthropic Provider Token Accounting")
print("="*50)
# Build the project first
print("Building project...")
subprocess.run(["cargo", "build", "--release"], capture_output=True)
# Test 1: Simple prompt
print("\nTest 1: Simple prompt")
print("-"*30)
output = run_g3_command("Say 'Hello, World!' and nothing else.")
tokens = extract_token_info(output)
if tokens:
print(f"Token usage: {tokens.get('now', 'N/A')} tokens")
print(f" Prompt tokens: {tokens.get('prompt', 'N/A')}")
print(f" Completion tokens: {tokens.get('completion', 'N/A')}")
print(f" Total from provider: {tokens.get('total', 'N/A')}")
if 'final_total' in tokens:
print(f" Final total from stream: {tokens['final_total']}")
if tokens.get('now') != tokens['final_total']:
print(f" ⚠️ WARNING: Mismatch between tracked ({tokens.get('now')}) and final ({tokens['final_total']})")
# Check if the completion tokens are reasonable (should be small for "Hello, World!")
if tokens.get('completion', 0) > 50:
print(f" ⚠️ WARNING: Completion tokens seem high for a simple response: {tokens.get('completion')}")
else:
print(" ❌ No token information found in output")
# Test 2: Longer response
print("\nTest 2: Longer response")
print("-"*30)
output = run_g3_command("Write a 3-paragraph essay about the importance of accurate token counting in LLM applications.")
tokens = extract_token_info(output)
if tokens:
print(f"Token usage: {tokens.get('now', 'N/A')} tokens")
print(f" Prompt tokens: {tokens.get('prompt', 'N/A')}")
print(f" Completion tokens: {tokens.get('completion', 'N/A')}")
print(f" Total from provider: {tokens.get('total', 'N/A')}")
if 'final_total' in tokens:
print(f" Final total from stream: {tokens['final_total']}")
if tokens.get('now') != tokens['final_total']:
print(f" ⚠️ WARNING: Mismatch between tracked ({tokens.get('now')}) and final ({tokens['final_total']})")
# Check if completion tokens are reasonable for a longer response
if tokens.get('completion', 0) < 100:
print(f" ⚠️ WARNING: Completion tokens seem low for a 3-paragraph essay: {tokens.get('completion')}")
else:
print(" ❌ No token information found in output")
# Test 3: Check for proper accumulation
print("\nTest 3: Token accumulation (multiple messages)")
print("-"*30)
# First message
output1 = run_g3_command("Count from 1 to 5.")
tokens1 = extract_token_info(output1)
# Second message (this would need to be in a conversation, but for now we test separately)
output2 = run_g3_command("Now count from 6 to 10.")
tokens2 = extract_token_info(output2)
if tokens1 and tokens2:
print(f"First message: {tokens1.get('now', 'N/A')} tokens")
print(f"Second message: {tokens2.get('now', 'N/A')} tokens")
# In a real conversation, tokens2['now'] should be greater than tokens1['now']
# But since these are separate invocations, we just check they're both reasonable
if tokens1.get('now', 0) > 0 and tokens2.get('now', 0) > 0:
print(" ✅ Both messages have token counts")
else:
print(" ❌ Missing token counts")
print("\n" + "="*50)
print("Test Summary:")
print("Check the output above for any warnings or errors.")
print("Key things to verify:")
print(" 1. Token counts are being captured from the provider")
print(" 2. Completion tokens are reasonable for the response length")
print(" 3. No mismatch between tracked and final token counts")
print(" 4. Context thinning triggers at appropriate thresholds")
if __name__ == "__main__":
main()

46
test_token_accounting.sh Executable file
View File

@@ -0,0 +1,46 @@
#!/bin/bash
# Test script to verify token accounting with Anthropic provider
echo "Testing token accounting with Anthropic provider..."
echo "This test will send a few messages and check if token counts are properly tracked."
echo ""
# Set up environment for testing
export RUST_LOG=g3_providers=debug,g3_core=info
export RUST_BACKTRACE=1
# Build the project first
echo "Building project..."
cargo build --release 2>&1 | grep -E "(Compiling|Finished)" || true
echo ""
echo "Running test with Anthropic provider..."
echo "Watch for these log messages:"
echo " - 'Captured initial usage from message_start'"
echo " - 'Updated usage from message_delta' (if available)"
echo " - 'Updated with final usage from message_stop' (if available)"
echo " - 'Anthropic stream completed with final usage'"
echo " - 'Updated token usage from provider'"
echo " - 'Context thinning triggered' (when reaching thresholds)"
echo ""
# Create a simple test that will generate some tokens
cat << 'EOF' > /tmp/test_prompt.txt
Please write a short paragraph about the importance of accurate token counting in LLM applications. Then list 3 reasons why token accounting might fail.
EOF
# Run the test
echo "Sending test prompt..."
cargo run --release -- --provider anthropic "$(cat /tmp/test_prompt.txt)" 2>&1 | tee /tmp/token_test.log
echo ""
echo "Analyzing results..."
echo ""
# Check for token accounting messages
echo "Token accounting messages found:"
grep -E "(usage from|token usage|Context thinning|Context usage)" /tmp/token_test.log | head -20
echo ""
echo "Test complete. Check /tmp/token_test.log for full output."