Add prompt cache statistics tracking to /stats command
- Extend Usage struct with cache_creation_tokens and cache_read_tokens fields - Parse Anthropic cache_creation_input_tokens and cache_read_input_tokens - Parse OpenAI prompt_tokens_details.cached_tokens for automatic prefix caching - Add CacheStats struct to Agent for cumulative tracking across API calls - Add "Prompt Cache Statistics" section to /stats output showing: - API call count and cache hit count - Hit rate percentage - Total input tokens and cache read/creation tokens - Cache efficiency (% of input served from cache) - Update all provider implementations and test files
This commit is contained in:
@@ -74,6 +74,22 @@ pub struct ToolCall {
|
||||
pub args: serde_json::Value, // Should be a JSON object with tool-specific arguments
|
||||
}
|
||||
|
||||
/// Cumulative cache statistics for prompt caching efficacy tracking.
|
||||
/// Tracks both Anthropic-style (cache_creation + cache_read) and OpenAI-style (cached_tokens) caching.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct CacheStats {
|
||||
/// Total tokens written to cache across all API calls
|
||||
pub total_cache_creation_tokens: u64,
|
||||
/// Total tokens read from cache across all API calls
|
||||
pub total_cache_read_tokens: u64,
|
||||
/// Total input tokens (for calculating cache hit rate)
|
||||
pub total_input_tokens: u64,
|
||||
/// Number of API calls that had cache hits
|
||||
pub cache_hit_calls: u32,
|
||||
/// Total number of API calls
|
||||
pub total_calls: u32,
|
||||
}
|
||||
|
||||
// Re-export WebDriverSession from its own module
|
||||
pub use webdriver_session::WebDriverSession;
|
||||
|
||||
@@ -103,6 +119,8 @@ pub struct Agent<W: UiWriter> {
|
||||
auto_compact: bool, // whether to auto-compact at 90% before tool calls
|
||||
compaction_events: Vec<usize>, // chars saved per compaction event
|
||||
first_token_times: Vec<Duration>, // time to first token for each completion
|
||||
/// Cumulative cache statistics across all API calls
|
||||
cache_stats: CacheStats,
|
||||
config: Config,
|
||||
session_id: Option<String>,
|
||||
tool_call_metrics: Vec<(String, Duration, bool)>, // (tool_name, duration, success)
|
||||
@@ -211,6 +229,7 @@ impl<W: UiWriter> Agent<W> {
|
||||
thinning_events: Vec::new(),
|
||||
compaction_events: Vec::new(),
|
||||
first_token_times: Vec::new(),
|
||||
cache_stats: CacheStats::default(),
|
||||
config,
|
||||
session_id: None,
|
||||
tool_call_metrics: Vec::new(),
|
||||
@@ -272,6 +291,7 @@ impl<W: UiWriter> Agent<W> {
|
||||
thinning_events: Vec::new(),
|
||||
compaction_events: Vec::new(),
|
||||
first_token_times: Vec::new(),
|
||||
cache_stats: CacheStats::default(),
|
||||
config,
|
||||
session_id: None,
|
||||
tool_call_metrics: Vec::new(),
|
||||
@@ -387,6 +407,7 @@ impl<W: UiWriter> Agent<W> {
|
||||
thinning_events: Vec::new(),
|
||||
compaction_events: Vec::new(),
|
||||
first_token_times: Vec::new(),
|
||||
cache_stats: CacheStats::default(),
|
||||
config,
|
||||
session_id: None,
|
||||
tool_call_metrics: Vec::new(),
|
||||
@@ -986,6 +1007,8 @@ impl<W: UiWriter> Agent<W> {
|
||||
prompt_tokens: 100, // Estimate
|
||||
completion_tokens: response_content.len() as u32 / 4, // Rough estimate
|
||||
total_tokens: 100 + (response_content.len() as u32 / 4),
|
||||
cache_creation_tokens: 0,
|
||||
cache_read_tokens: 0,
|
||||
};
|
||||
|
||||
// Update context window with estimated token usage
|
||||
@@ -1408,6 +1431,7 @@ impl<W: UiWriter> Agent<W> {
|
||||
first_token_times: &self.first_token_times,
|
||||
tool_call_metrics: &self.tool_call_metrics,
|
||||
provider_info: self.get_provider_info().ok(),
|
||||
cache_stats: &self.cache_stats,
|
||||
};
|
||||
|
||||
snapshot.format()
|
||||
@@ -2111,6 +2135,17 @@ Skip if nothing new. Be brief."#;
|
||||
if let Some(ref usage) = chunk.usage {
|
||||
iter.accumulated_usage = Some(usage.clone());
|
||||
state.turn_accumulated_usage = Some(usage.clone());
|
||||
|
||||
// Update cumulative cache statistics
|
||||
self.cache_stats.total_calls += 1;
|
||||
self.cache_stats.total_input_tokens += usage.prompt_tokens as u64;
|
||||
self.cache_stats.total_cache_creation_tokens +=
|
||||
usage.cache_creation_tokens as u64;
|
||||
self.cache_stats.total_cache_read_tokens +=
|
||||
usage.cache_read_tokens as u64;
|
||||
if usage.cache_read_tokens > 0 {
|
||||
self.cache_stats.cache_hit_calls += 1;
|
||||
}
|
||||
debug!(
|
||||
"Received usage data - prompt: {}, completion: {}, total: {}",
|
||||
usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
|
||||
|
||||
@@ -7,6 +7,7 @@ use g3_providers::MessageRole;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::context_window::ContextWindow;
|
||||
use crate::CacheStats;
|
||||
|
||||
/// Data required to format agent statistics.
|
||||
/// This struct captures a snapshot of agent state for formatting.
|
||||
@@ -17,6 +18,7 @@ pub struct AgentStatsSnapshot<'a> {
|
||||
pub first_token_times: &'a [Duration],
|
||||
pub tool_call_metrics: &'a [(String, Duration, bool)],
|
||||
pub provider_info: Option<(String, String)>,
|
||||
pub cache_stats: &'a CacheStats,
|
||||
}
|
||||
|
||||
impl<'a> AgentStatsSnapshot<'a> {
|
||||
@@ -33,6 +35,7 @@ impl<'a> AgentStatsSnapshot<'a> {
|
||||
self.format_performance_metrics(&mut stats);
|
||||
self.format_conversation_history(&mut stats);
|
||||
self.format_tool_call_metrics(&mut stats);
|
||||
self.format_cache_stats(&mut stats);
|
||||
self.format_provider_info(&mut stats);
|
||||
|
||||
stats.push_str(&"=".repeat(60));
|
||||
@@ -184,6 +187,53 @@ impl<'a> AgentStatsSnapshot<'a> {
|
||||
stats.push('\n');
|
||||
}
|
||||
|
||||
fn format_cache_stats(&self, stats: &mut String) {
|
||||
stats.push_str("💾 Prompt Cache Statistics:\n");
|
||||
stats.push_str(&format!(
|
||||
" • API Calls: {:>10}\n",
|
||||
self.cache_stats.total_calls
|
||||
));
|
||||
stats.push_str(&format!(
|
||||
" • Cache Hits: {:>10}\n",
|
||||
self.cache_stats.cache_hit_calls
|
||||
));
|
||||
|
||||
// Calculate hit rate
|
||||
let hit_rate = if self.cache_stats.total_calls > 0 {
|
||||
(self.cache_stats.cache_hit_calls as f64 / self.cache_stats.total_calls as f64) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
stats.push_str(&format!(" • Hit Rate: {:>9.1}%\n", hit_rate));
|
||||
|
||||
stats.push_str(&format!(
|
||||
" • Total Input Tokens:{:>10}\n",
|
||||
self.cache_stats.total_input_tokens
|
||||
));
|
||||
stats.push_str(&format!(
|
||||
" • Cache Created: {:>10} tokens\n",
|
||||
self.cache_stats.total_cache_creation_tokens
|
||||
));
|
||||
stats.push_str(&format!(
|
||||
" • Cache Read: {:>10} tokens\n",
|
||||
self.cache_stats.total_cache_read_tokens
|
||||
));
|
||||
|
||||
// Calculate cache read percentage of total input
|
||||
let cache_read_pct = if self.cache_stats.total_input_tokens > 0 {
|
||||
(self.cache_stats.total_cache_read_tokens as f64
|
||||
/ self.cache_stats.total_input_tokens as f64)
|
||||
* 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
stats.push_str(&format!(
|
||||
" • Cache Efficiency: {:>9.1}% of input from cache\n",
|
||||
cache_read_pct
|
||||
));
|
||||
stats.push('\n');
|
||||
}
|
||||
|
||||
fn format_provider_info(&self, stats: &mut String) {
|
||||
stats.push_str("🔌 Provider:\n");
|
||||
if let Some((provider, model)) = &self.provider_info {
|
||||
@@ -201,6 +251,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_format_stats_empty() {
|
||||
let context_window = ContextWindow::new(100000);
|
||||
let cache_stats = CacheStats::default();
|
||||
let snapshot = AgentStatsSnapshot {
|
||||
context_window: &context_window,
|
||||
thinning_events: &[],
|
||||
@@ -208,6 +259,7 @@ mod tests {
|
||||
first_token_times: &[],
|
||||
tool_call_metrics: &[],
|
||||
provider_info: None,
|
||||
cache_stats: &cache_stats,
|
||||
};
|
||||
|
||||
let stats = snapshot.format();
|
||||
@@ -215,6 +267,7 @@ mod tests {
|
||||
assert!(stats.contains("Used Tokens"));
|
||||
assert!(stats.contains("Thinning Events"));
|
||||
assert!(stats.contains("Tool Call Metrics"));
|
||||
assert!(stats.contains("Prompt Cache Statistics"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -222,6 +275,13 @@ mod tests {
|
||||
let context_window = ContextWindow::new(100000);
|
||||
let thinning_events = vec![1000, 2000, 1500];
|
||||
let compaction_events = vec![5000];
|
||||
let cache_stats = CacheStats {
|
||||
total_calls: 5,
|
||||
cache_hit_calls: 3,
|
||||
total_input_tokens: 10000,
|
||||
total_cache_creation_tokens: 2000,
|
||||
total_cache_read_tokens: 6000,
|
||||
};
|
||||
let first_token_times = vec![
|
||||
Duration::from_millis(100),
|
||||
Duration::from_millis(150),
|
||||
@@ -240,6 +300,7 @@ mod tests {
|
||||
first_token_times: &first_token_times,
|
||||
tool_call_metrics: &tool_call_metrics,
|
||||
provider_info: Some(("anthropic".to_string(), "claude-3".to_string())),
|
||||
cache_stats: &cache_stats,
|
||||
};
|
||||
|
||||
let stats = snapshot.format();
|
||||
@@ -259,5 +320,12 @@ mod tests {
|
||||
// Check provider info
|
||||
assert!(stats.contains("Provider: anthropic"));
|
||||
assert!(stats.contains("Model: claude-3"));
|
||||
|
||||
// Check cache stats
|
||||
assert!(stats.contains("Prompt Cache Statistics"));
|
||||
assert!(stats.contains("API Calls: 5"));
|
||||
assert!(stats.contains("Cache Hits: 3"));
|
||||
assert!(stats.contains("Hit Rate:") && stats.contains("60.0%"));
|
||||
assert!(stats.contains("Cache Efficiency:"));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user