Changes to make the Anthropic provider's token accounting more reliable

This commit is contained in:
Michael Neale
2025-10-22 09:47:24 +11:00
parent 758e255af8
commit 738c3ac53e
4 changed files with 323 additions and 9 deletions

View File

@@ -319,10 +319,26 @@ impl ContextWindow {
/// Update token usage from provider response
pub fn update_usage_from_response(&mut self, usage: &g3_providers::Usage) {
// Add the tokens from this response to our running total
// The usage.total_tokens represents tokens used in this single API call
self.used_tokens += usage.total_tokens;
self.cumulative_tokens += usage.total_tokens;
// The provider's usage represents the tokens used in the last API call
// We need to be smarter about how we update our running total
let old_used = self.used_tokens;
// If the provider's total is greater than our current count, use it as the authoritative value
// This handles cases where our estimation was off
if usage.total_tokens > self.used_tokens {
self.used_tokens = usage.total_tokens;
self.cumulative_tokens += usage.total_tokens - old_used;
} else {
// Otherwise, add the tokens from this response
self.used_tokens += usage.completion_tokens; // Add only the new completion tokens
self.cumulative_tokens += usage.completion_tokens;
}
info!(
"Updated token usage - was: {}, now: {} (provider reported: prompt={}, completion={}, total={})",
old_used, self.used_tokens, usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
);
debug!(
"Added {} tokens from provider response (used: {}/{}, cumulative: {})",
@@ -429,8 +445,18 @@ Format this as a detailed but concise summary that can be used to resume the con
if current_percentage >= 50 {
let current_threshold = (current_percentage / 10) * 10; // Round down to nearest 10%
if current_threshold > self.last_thinning_percentage && current_threshold <= 80 {
info!(
"Context thinning triggered - usage: {}% ({}/{} tokens), threshold: {}%, last thinned at: {}%",
current_percentage,
self.used_tokens,
self.total_tokens,
current_threshold,
self.last_thinning_percentage
);
return true;
}
} else {
debug!("Context usage at {}% ({}/{} tokens) - no thinning needed", current_percentage, self.used_tokens, self.total_tokens);
}
false

View File

@@ -275,6 +275,7 @@ impl AnthropicProvider {
let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
let mut accumulated_usage: Option<Usage> = None;
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
let mut actual_completion_tokens: u32 = 0; // Track actual completion tokens
while let Some(chunk_result) = stream.next().await {
match chunk_result {
@@ -322,7 +323,12 @@ impl AnthropicProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
usage: accumulated_usage.as_ref().map(|u| Usage {
prompt_tokens: u.prompt_tokens,
// Use our locally tracked completion-token estimate if any text was streamed; otherwise fall back to the provider-reported count
completion_tokens: if actual_completion_tokens > 0 { actual_completion_tokens } else { u.completion_tokens },
total_tokens: u.prompt_tokens + if actual_completion_tokens > 0 { actual_completion_tokens } else { u.completion_tokens },
}),
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
};
if tx.send(Ok(final_chunk)).await.is_err() {
@@ -336,6 +342,7 @@ impl AnthropicProvider {
match serde_json::from_str::<AnthropicStreamEvent>(data) {
Ok(event) => {
debug!("Parsed event type: {}, event: {:?}", event.event_type, event);
match event.event_type.as_str() {
"message_start" => {
// Extract usage data from message_start event
@@ -346,7 +353,10 @@ impl AnthropicProvider {
completion_tokens: usage.output_tokens,
total_tokens: usage.input_tokens + usage.output_tokens,
});
debug!("Captured usage from message_start: {:?}", accumulated_usage);
debug!("Captured initial usage from message_start - prompt: {}, completion: {} (estimated), total: {}",
usage.input_tokens,
usage.output_tokens,
usage.input_tokens + usage.output_tokens);
}
}
}
@@ -395,6 +405,9 @@ impl AnthropicProvider {
"content_block_delta" => {
if let Some(delta) = event.delta {
if let Some(text) = delta.text {
// Estimate completion tokens from the streamed text (rough heuristic: ~4 bytes per token; `text.len()` counts bytes, not chars)
actual_completion_tokens += (text.len() as f32 / 4.0).ceil() as u32;
debug!("Sending text chunk of length {}: '{}'", text.len(), text);
let chunk = CompletionChunk {
content: text,
@@ -415,6 +428,19 @@ impl AnthropicProvider {
}
}
}
"message_delta" => {
// Check if message_delta contains updated usage data
if let Some(delta) = event.delta {
if let Some(usage) = delta.usage {
accumulated_usage = Some(Usage {
prompt_tokens: usage.input_tokens,
completion_tokens: usage.output_tokens,
total_tokens: usage.input_tokens + usage.output_tokens,
});
debug!("Updated usage from message_delta - prompt: {}, completion: {}, total: {}", usage.input_tokens, usage.output_tokens, usage.input_tokens + usage.output_tokens);
}
}
}
"content_block_stop" => {
// Tool call block is complete - now parse the accumulated JSON
if !current_tool_calls.is_empty() && !partial_tool_json.is_empty() {
@@ -449,11 +475,44 @@ impl AnthropicProvider {
}
}
"message_stop" => {
debug!("Received message stop event");
debug!("Received message_stop event: {:?}", event);
// Check if message_stop contains final usage data
if let Some(message) = event.message {
if let Some(usage) = message.usage {
// Update with final accurate usage data from message_stop
// This should have the actual completion token count
accumulated_usage = Some(Usage {
prompt_tokens: usage.input_tokens,
// Prefer the actual output_tokens from message_stop if available
// Otherwise use our tracked count, and as last resort the initial estimate
completion_tokens: if usage.output_tokens > 0 {
usage.output_tokens
} else if actual_completion_tokens > 0 {
actual_completion_tokens
} else { usage.output_tokens },
total_tokens: usage.input_tokens + usage.output_tokens,
});
debug!("Updated with final usage from message_stop - prompt: {}, completion: {}, total: {}",
usage.input_tokens,
usage.output_tokens,
usage.input_tokens + usage.output_tokens);
}
}
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
usage: accumulated_usage.as_ref().map(|u| Usage {
prompt_tokens: u.prompt_tokens,
// Use actual completion tokens if we tracked them and they're higher
completion_tokens: if actual_completion_tokens > u.completion_tokens {
actual_completion_tokens
} else {
u.completion_tokens
},
total_tokens: u.prompt_tokens + u32::max(actual_completion_tokens, u.completion_tokens),
}),
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
};
if tx.send(Ok(final_chunk)).await.is_err() {
@@ -495,10 +554,27 @@ impl AnthropicProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
usage: accumulated_usage.as_ref().map(|u| Usage {
prompt_tokens: u.prompt_tokens,
completion_tokens: if actual_completion_tokens > u.completion_tokens {
actual_completion_tokens
} else {
u.completion_tokens
},
total_tokens: u.prompt_tokens + u32::max(actual_completion_tokens, u.completion_tokens),
}),
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls) },
};
let _ = tx.send(Ok(final_chunk)).await;
// Log final usage for debugging
if let Some(ref usage) = accumulated_usage {
info!("Anthropic stream completed with final usage - prompt: {}, completion: {}, total: {}",
usage.prompt_tokens, usage.completion_tokens, usage.total_tokens);
} else {
warn!("Anthropic stream completed without usage data - token accounting will fall back to estimation");
}
accumulated_usage
}
}
@@ -736,6 +812,8 @@ struct AnthropicStreamMessage {
/// Payload of the `delta` field on an Anthropic streaming event.
///
/// Every field is optional.
struct AnthropicDelta {
/// Incremental assistant text; read when handling `content_block_delta` events.
text: Option<String>,
/// Fragment of JSON carried by the delta.
/// NOTE(review): presumably tool-call input accumulated across deltas — confirm against the stream handler.
partial_json: Option<String>,
/// Token usage carried on the delta; read when handling `message_delta` events.
/// NOTE(review): Anthropic's streaming documentation appears to place `usage` beside `delta`
/// on `message_delta` events rather than nested inside it — confirm this field is ever populated.
#[serde(default)]
usage: Option<AnthropicUsage>,
}
#[derive(Debug, Deserialize)]