Make the Anthropic provider more reliable with token accounting
@@ -319,10 +319,26 @@ impl ContextWindow {
    /// Update token usage from provider response
    pub fn update_usage_from_response(&mut self, usage: &g3_providers::Usage) {
        // Add the tokens from this response to our running total
        // The usage.total_tokens represents tokens used in this single API call
        self.used_tokens += usage.total_tokens;
        self.cumulative_tokens += usage.total_tokens;
        // The provider's usage represents the tokens used in the last API call
        // We need to be smarter about how we update our running total

        let old_used = self.used_tokens;

        // If the provider's total is greater than our current count, use it as the authoritative value
        // This handles cases where our estimation was off
        if usage.total_tokens > self.used_tokens {
            self.used_tokens = usage.total_tokens;
            self.cumulative_tokens += usage.total_tokens - old_used;
        } else {
            // Otherwise, add the tokens from this response
            self.used_tokens += usage.completion_tokens; // Add only the new completion tokens
            self.cumulative_tokens += usage.completion_tokens;
        }

        info!(
            "Updated token usage - was: {}, now: {} (provider reported: prompt={}, completion={}, total={})",
            old_used, self.used_tokens, usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
        );

        debug!(
            "Added {} tokens from provider response (used: {}/{}, cumulative: {})",
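The replacement logic above is easier to follow with concrete numbers. The following is a minimal, self-contained sketch of the same two update paths, using stripped-down stand-ins for the real g3-core ContextWindow and g3-providers Usage types (a sketch, not the committed code):

// Hypothetical stand-ins; the real types carry more fields.
struct ContextWindow {
    used_tokens: u32,
    cumulative_tokens: u32,
}

struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

impl ContextWindow {
    fn update_usage_from_response(&mut self, usage: &Usage) {
        let old_used = self.used_tokens;
        if usage.total_tokens > self.used_tokens {
            // Provider total is authoritative; it corrects any local estimate.
            self.used_tokens = usage.total_tokens;
            self.cumulative_tokens += usage.total_tokens - old_used;
        } else {
            // Provider total lags our count; add only the new completion tokens.
            self.used_tokens += usage.completion_tokens;
            self.cumulative_tokens += usage.completion_tokens;
        }
    }
}

fn main() {
    let mut ctx = ContextWindow { used_tokens: 500, cumulative_tokens: 500 };
    // Case 1: provider reports more than we tracked, so its total is adopted.
    ctx.update_usage_from_response(&Usage { prompt_tokens: 700, completion_tokens: 100, total_tokens: 800 });
    assert_eq!(ctx.used_tokens, 800);
    assert_eq!(ctx.cumulative_tokens, 800);
    // Case 2: provider reports less, so only the completion tokens are added.
    ctx.update_usage_from_response(&Usage { prompt_tokens: 300, completion_tokens: 50, total_tokens: 350 });
    assert_eq!(ctx.used_tokens, 850);
    assert_eq!(ctx.cumulative_tokens, 850);
}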
@@ -429,8 +445,18 @@ Format this as a detailed but concise summary that can be used to resume the conversation
        if current_percentage >= 50 {
            let current_threshold = (current_percentage / 10) * 10; // Round down to nearest 10%
            if current_threshold > self.last_thinning_percentage && current_threshold <= 80 {
                info!(
                    "Context thinning triggered - usage: {}% ({}/{} tokens), threshold: {}%, last thinned at: {}%",
                    current_percentage,
                    self.used_tokens,
                    self.total_tokens,
                    current_threshold,
                    self.last_thinning_percentage
                );
                return true;
            }
        } else {
            debug!("Context usage at {}% ({}/{} tokens) - no thinning needed", current_percentage, self.used_tokens, self.total_tokens);
        }

        false
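The thinning check relies on integer division: usage is rounded down to the nearest 10%, and thinning fires at most once per decade between 50% and 80%. A small free-standing sketch of that arithmetic (hypothetical helper function, not the g3-core method itself):

fn should_thin(current_percentage: u32, last_thinning_percentage: u32) -> bool {
    if current_percentage < 50 {
        return false;
    }
    // Integer division rounds down to the nearest 10%, e.g. 67% -> 60%.
    let current_threshold = (current_percentage / 10) * 10;
    current_threshold > last_thinning_percentage && current_threshold <= 80
}

fn main() {
    assert!(!should_thin(49, 0));  // below the 50% floor
    assert!(should_thin(53, 0));   // 53% rounds to 50, first trigger
    assert!(!should_thin(57, 50)); // still in the 50s decade, already thinned
    assert!(should_thin(61, 50));  // crossed into the 60s decade
    assert!(!should_thin(95, 80)); // capped: no trigger above the 80% threshold
}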
@@ -275,6 +275,7 @@ impl AnthropicProvider {
        let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
        let mut accumulated_usage: Option<Usage> = None;
        let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
        let mut actual_completion_tokens: u32 = 0; // Track actual completion tokens

        while let Some(chunk_result) = stream.next().await {
            match chunk_result {
@@ -322,7 +323,12 @@ impl AnthropicProvider {
                    let final_chunk = CompletionChunk {
                        content: String::new(),
                        finished: true,
                        usage: accumulated_usage.clone(),
                        usage: accumulated_usage.as_ref().map(|u| Usage {
                            prompt_tokens: u.prompt_tokens,
                            // Use actual completion tokens if we tracked them, otherwise use the estimate
                            completion_tokens: if actual_completion_tokens > 0 { actual_completion_tokens } else { u.completion_tokens },
                            total_tokens: u.prompt_tokens + if actual_completion_tokens > 0 { actual_completion_tokens } else { u.completion_tokens },
                        }),
                        tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
                    };
                    if tx.send(Ok(final_chunk)).await.is_err() {
@@ -336,6 +342,7 @@ impl AnthropicProvider {
                    match serde_json::from_str::<AnthropicStreamEvent>(data) {
                        Ok(event) => {
                            debug!("Parsed event type: {}, event: {:?}", event.event_type, event);

                            match event.event_type.as_str() {
                                "message_start" => {
                                    // Extract usage data from message_start event
@@ -346,7 +353,10 @@ impl AnthropicProvider {
                                        completion_tokens: usage.output_tokens,
                                        total_tokens: usage.input_tokens + usage.output_tokens,
                                    });
                                    debug!("Captured usage from message_start: {:?}", accumulated_usage);
                                    debug!("Captured initial usage from message_start - prompt: {}, completion: {} (estimated), total: {}",
                                        usage.input_tokens,
                                        usage.output_tokens,
                                        usage.input_tokens + usage.output_tokens);
                                }
                            }
                        }
@@ -395,6 +405,9 @@ impl AnthropicProvider {
                                "content_block_delta" => {
                                    if let Some(delta) = event.delta {
                                        if let Some(text) = delta.text {
                                            // Track actual completion tokens (rough estimate: 4 chars per token)
                                            actual_completion_tokens += (text.len() as f32 / 4.0).ceil() as u32;

                                            debug!("Sending text chunk of length {}: '{}'", text.len(), text);
                                            let chunk = CompletionChunk {
                                                content: text,
@@ -415,6 +428,19 @@ impl AnthropicProvider {
                                        }
                                    }
                                }
                                "message_delta" => {
                                    // Check if message_delta contains updated usage data
                                    if let Some(delta) = event.delta {
                                        if let Some(usage) = delta.usage {
                                            accumulated_usage = Some(Usage {
                                                prompt_tokens: usage.input_tokens,
                                                completion_tokens: usage.output_tokens,
                                                total_tokens: usage.input_tokens + usage.output_tokens,
                                            });
                                            debug!("Updated usage from message_delta - prompt: {}, completion: {}, total: {}", usage.input_tokens, usage.output_tokens, usage.input_tokens + usage.output_tokens);
                                        }
                                    }
                                }
                                "content_block_stop" => {
                                    // Tool call block is complete - now parse the accumulated JSON
                                    if !current_tool_calls.is_empty() && !partial_tool_json.is_empty() {
@@ -449,11 +475,44 @@ impl AnthropicProvider {
                                    }
                                }
                                "message_stop" => {
                                    debug!("Received message stop event");
                                    debug!("Received message_stop event: {:?}", event);

                                    // Check if message_stop contains final usage data
                                    if let Some(message) = event.message {
                                        if let Some(usage) = message.usage {
                                            // Update with final accurate usage data from message_stop
                                            // This should have the actual completion token count
                                            accumulated_usage = Some(Usage {
                                                prompt_tokens: usage.input_tokens,
                                                // Prefer the actual output_tokens from message_stop if available,
                                                // otherwise fall back to our tracked running estimate
                                                completion_tokens: if usage.output_tokens > 0 {
                                                    usage.output_tokens
                                                } else if actual_completion_tokens > 0 {
                                                    actual_completion_tokens
                                                } else { usage.output_tokens },
                                                // Keep the total consistent with whichever completion count was chosen
                                                total_tokens: usage.input_tokens
                                                    + if usage.output_tokens > 0 { usage.output_tokens } else { actual_completion_tokens },
                                            });
                                            debug!("Updated with final usage from message_stop - prompt: {}, completion: {}, total: {}",
                                                usage.input_tokens,
                                                usage.output_tokens,
                                                usage.input_tokens + usage.output_tokens);
                                        }
                                    }

                                    let final_chunk = CompletionChunk {
                                        content: String::new(),
                                        finished: true,
                                        usage: accumulated_usage.clone(),
                                        usage: accumulated_usage.as_ref().map(|u| Usage {
                                            prompt_tokens: u.prompt_tokens,
                                            // Use actual completion tokens if we tracked them and they're higher
                                            completion_tokens: if actual_completion_tokens > u.completion_tokens {
                                                actual_completion_tokens
                                            } else {
                                                u.completion_tokens
                                            },
                                            total_tokens: u.prompt_tokens + u32::max(actual_completion_tokens, u.completion_tokens),
                                        }),
                                        tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
                                    };
                                    if tx.send(Ok(final_chunk)).await.is_err() {
@@ -495,10 +554,27 @@ impl AnthropicProvider {
        let final_chunk = CompletionChunk {
            content: String::new(),
            finished: true,
            usage: accumulated_usage.clone(),
            usage: accumulated_usage.as_ref().map(|u| Usage {
                prompt_tokens: u.prompt_tokens,
                completion_tokens: if actual_completion_tokens > u.completion_tokens {
                    actual_completion_tokens
                } else {
                    u.completion_tokens
                },
                total_tokens: u.prompt_tokens + u32::max(actual_completion_tokens, u.completion_tokens),
            }),
            tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls) },
        };
        let _ = tx.send(Ok(final_chunk)).await;

        // Log final usage for debugging
        if let Some(ref usage) = accumulated_usage {
            info!("Anthropic stream completed with final usage - prompt: {}, completion: {}, total: {}",
                usage.prompt_tokens, usage.completion_tokens, usage.total_tokens);
        } else {
            warn!("Anthropic stream completed without usage data - token accounting will fall back to estimation");
        }

        accumulated_usage
    }
}
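Taken together, the streaming path keeps two candidate completion counts: a running chars/4 estimate accumulated per text delta, and whatever output_tokens the provider reports; the final chunk takes the larger. A compact sketch of that interplay, using hypothetical free-standing helpers rather than the provider code verbatim:

// Rough estimate used per content_block_delta: about 4 characters per token.
fn estimate_tokens(text: &str) -> u32 {
    (text.len() as f32 / 4.0).ceil() as u32
}

// Final-chunk rule from the diff: the larger of reported vs tracked wins.
fn final_completion_tokens(reported: u32, tracked_estimate: u32) -> u32 {
    u32::max(tracked_estimate, reported)
}

fn main() {
    let mut tracked = 0u32;
    for delta in ["Hello, ", "World!"] {
        tracked += estimate_tokens(delta); // 2 + 2 = 4
    }
    assert_eq!(tracked, 4);
    // message_start often carries only a tiny placeholder output_tokens
    // (the diff labels it "(estimated)"); the tracked estimate wins then.
    assert_eq!(final_completion_tokens(1, tracked), 4);
    // When message_delta/message_stop report a real count, it wins.
    assert_eq!(final_completion_tokens(12, tracked), 12);
}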
@@ -736,6 +812,8 @@ struct AnthropicStreamMessage {
struct AnthropicDelta {
    text: Option<String>,
    partial_json: Option<String>,
    #[serde(default)]
    usage: Option<AnthropicUsage>,
}

#[derive(Debug, Deserialize)]

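The new usage field on AnthropicDelta only shows up on message_delta events; ordinary text deltas omit the key entirely, which the Option field (with the explicit #[serde(default)]) tolerates. A quick check against both payload shapes, using a simplified mirror of the structs and assuming serde and serde_json as dependencies:

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct AnthropicUsage {
    input_tokens: u32,
    output_tokens: u32,
}

#[derive(Debug, Deserialize)]
struct AnthropicDelta {
    text: Option<String>,
    partial_json: Option<String>,
    #[serde(default)]
    usage: Option<AnthropicUsage>,
}

fn main() {
    // Plain text delta: no "usage" key at all, field stays None.
    let d: AnthropicDelta = serde_json::from_str(r#"{"text":"hi"}"#).unwrap();
    assert!(d.usage.is_none());

    // message_delta-style payload carrying running usage counts.
    let d: AnthropicDelta = serde_json::from_str(
        r#"{"text":null,"partial_json":null,"usage":{"input_tokens":12,"output_tokens":34}}"#,
    ).unwrap();
    assert_eq!(d.usage.unwrap().output_tokens, 34);
}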
164
test_token_accounting.py
Normal file
@@ -0,0 +1,164 @@
#!/usr/bin/env python3
"""
Test script to verify token accounting is working correctly with the Anthropic provider.
This script will send multiple messages and verify that token counts accumulate properly.
"""

import os
import re
import subprocess

def run_g3_command(prompt, provider="anthropic"):
    """Run a g3 command and capture the output."""
    cmd = [
        "cargo", "run", "--release", "--",
        "--provider", provider,
        prompt
    ]

    env = {
        "RUST_LOG": "g3_providers=debug,g3_core=info",
        "RUST_BACKTRACE": "1"
    }

    result = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        env={**os.environ, **env}
    )

    return result.stdout + result.stderr

def extract_token_info(output):
    """Extract token usage information from the output."""
    token_info = {}

    # Look for token usage updates
    usage_pattern = r"Updated token usage.*was: (\d+), now: (\d+).*prompt=(\d+), completion=(\d+), total=(\d+)"
    matches = re.findall(usage_pattern, output)
    if matches:
        last_match = matches[-1]
        token_info['was'] = int(last_match[0])
        token_info['now'] = int(last_match[1])
        token_info['prompt'] = int(last_match[2])
        token_info['completion'] = int(last_match[3])
        token_info['total'] = int(last_match[4])

    # Look for context percentage
    context_pattern = r"Context usage at (\d+)%.*\((\d+)/(\d+) tokens\)"
    matches = re.findall(context_pattern, output)
    if matches:
        last_match = matches[-1]
        token_info['percentage'] = int(last_match[0])
        token_info['used'] = int(last_match[1])
        token_info['total_context'] = int(last_match[2])

    # Look for thinning triggers
    thinning_pattern = r"Context thinning triggered.*usage: (\d+)%.*\((\d+)/(\d+) tokens\)"
    matches = re.findall(thinning_pattern, output)
    if matches:
        token_info['thinning_triggered'] = True
        token_info['thinning_percentage'] = int(matches[-1][0])

    # Look for final usage from Anthropic
    final_usage_pattern = r"Anthropic stream completed with final usage.*prompt: (\d+), completion: (\d+), total: (\d+)"
    matches = re.findall(final_usage_pattern, output)
    if matches:
        last_match = matches[-1]
        token_info['final_prompt'] = int(last_match[0])
        token_info['final_completion'] = int(last_match[1])
        token_info['final_total'] = int(last_match[2])

    return token_info

def main():
    print("Testing Anthropic Provider Token Accounting")
    print("="*50)

    # Build the project first
    print("Building project...")
    subprocess.run(["cargo", "build", "--release"], capture_output=True)

    # Test 1: Simple prompt
    print("\nTest 1: Simple prompt")
    print("-"*30)
    output = run_g3_command("Say 'Hello, World!' and nothing else.")
    tokens = extract_token_info(output)

    if tokens:
        print(f"Token usage: {tokens.get('now', 'N/A')} tokens")
        print(f" Prompt tokens: {tokens.get('prompt', 'N/A')}")
        print(f" Completion tokens: {tokens.get('completion', 'N/A')}")
        print(f" Total from provider: {tokens.get('total', 'N/A')}")

        if 'final_total' in tokens:
            print(f" Final total from stream: {tokens['final_total']}")
            if tokens.get('now') != tokens['final_total']:
                print(f" ⚠️ WARNING: Mismatch between tracked ({tokens.get('now')}) and final ({tokens['final_total']})")

        # Check if the completion tokens are reasonable (should be small for "Hello, World!")
        if tokens.get('completion', 0) > 50:
            print(f" ⚠️ WARNING: Completion tokens seem high for a simple response: {tokens.get('completion')}")
    else:
        print(" ❌ No token information found in output")

    # Test 2: Longer response
    print("\nTest 2: Longer response")
    print("-"*30)
    output = run_g3_command("Write a 3-paragraph essay about the importance of accurate token counting in LLM applications.")
    tokens = extract_token_info(output)

    if tokens:
        print(f"Token usage: {tokens.get('now', 'N/A')} tokens")
        print(f" Prompt tokens: {tokens.get('prompt', 'N/A')}")
        print(f" Completion tokens: {tokens.get('completion', 'N/A')}")
        print(f" Total from provider: {tokens.get('total', 'N/A')}")

        if 'final_total' in tokens:
            print(f" Final total from stream: {tokens['final_total']}")
            if tokens.get('now') != tokens['final_total']:
                print(f" ⚠️ WARNING: Mismatch between tracked ({tokens.get('now')}) and final ({tokens['final_total']})")

        # Check if completion tokens are reasonable for a longer response
        if tokens.get('completion', 0) < 100:
            print(f" ⚠️ WARNING: Completion tokens seem low for a 3-paragraph essay: {tokens.get('completion')}")
    else:
        print(" ❌ No token information found in output")

    # Test 3: Check for proper accumulation
    print("\nTest 3: Token accumulation (multiple messages)")
    print("-"*30)

    # First message
    output1 = run_g3_command("Count from 1 to 5.")
    tokens1 = extract_token_info(output1)

    # Second message (this would need to be in a conversation, but for now we test separately)
    output2 = run_g3_command("Now count from 6 to 10.")
    tokens2 = extract_token_info(output2)

    if tokens1 and tokens2:
        print(f"First message: {tokens1.get('now', 'N/A')} tokens")
        print(f"Second message: {tokens2.get('now', 'N/A')} tokens")

        # In a real conversation, tokens2['now'] should be greater than tokens1['now']
        # But since these are separate invocations, we just check they're both reasonable
        if tokens1.get('now', 0) > 0 and tokens2.get('now', 0) > 0:
            print(" ✅ Both messages have token counts")
        else:
            print(" ❌ Missing token counts")

    print("\n" + "="*50)
    print("Test Summary:")
    print("Check the output above for any warnings or errors.")
    print("Key things to verify:")
    print(" 1. Token counts are being captured from the provider")
    print(" 2. Completion tokens are reasonable for the response length")
    print(" 3. No mismatch between tracked and final token counts")
    print(" 4. Context thinning triggers at appropriate thresholds")

if __name__ == "__main__":
    main()
46
test_token_accounting.sh
Executable file
@@ -0,0 +1,46 @@
#!/bin/bash

# Test script to verify token accounting with Anthropic provider

echo "Testing token accounting with Anthropic provider..."
echo "This test will send a few messages and check if token counts are properly tracked."
echo ""

# Set up environment for testing
export RUST_LOG=g3_providers=debug,g3_core=info
export RUST_BACKTRACE=1

# Build the project first
echo "Building project..."
cargo build --release 2>&1 | grep -E "(Compiling|Finished)" || true

echo ""
echo "Running test with Anthropic provider..."
echo "Watch for these log messages:"
echo " - 'Captured initial usage from message_start'"
echo " - 'Updated usage from message_delta' (if available)"
echo " - 'Updated with final usage from message_stop' (if available)"
echo " - 'Anthropic stream completed with final usage'"
echo " - 'Updated token usage' (from the provider response)"
echo " - 'Context thinning triggered' (when reaching thresholds)"
echo ""

# Create a simple test that will generate some tokens
cat << 'EOF' > /tmp/test_prompt.txt
Please write a short paragraph about the importance of accurate token counting in LLM applications. Then list 3 reasons why token accounting might fail.
EOF

# Run the test
echo "Sending test prompt..."
cargo run --release -- --provider anthropic "$(cat /tmp/test_prompt.txt)" 2>&1 | tee /tmp/token_test.log

echo ""
echo "Analyzing results..."
echo ""

# Check for token accounting messages
echo "Token accounting messages found:"
grep -E "(usage from|token usage|Context thinning|Context usage)" /tmp/token_test.log | head -20

echo ""
echo "Test complete. Check /tmp/token_test.log for full output."
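Both test scripts assume a working Anthropic configuration for g3, presumably an ANTHROPIC_API_KEY (or equivalent provider credentials) exported in the environment, and are meant to be run from the repository root, e.g. python3 test_token_accounting.py or ./test_token_accounting.sh.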