Files
g3/crates/g3-core/tests/streaming_completion_test.rs
Dhanji R. Prasanna 5b4079e861 Add prompt cache statistics tracking to /stats command
- Extend Usage struct with cache_creation_tokens and cache_read_tokens fields
- Parse Anthropic cache_creation_input_tokens and cache_read_input_tokens
- Parse OpenAI prompt_tokens_details.cached_tokens for automatic prefix caching
- Add CacheStats struct to Agent for cumulative tracking across API calls
- Add "Prompt Cache Statistics" section to /stats output showing:
  - API call count and cache hit count
  - Hit rate percentage
  - Total input tokens and cache read/creation tokens
  - Cache efficiency (% of input served from cache)
- Update all provider implementations and test files
2026-01-27 11:32:45 +11:00

539 lines
18 KiB
Rust

//! Streaming Completion Integration Test
//!
//! This test verifies that the Agent correctly processes a streaming response
//! containing multiple message types (TEXT, TOOL_CALL, TOOL_CALL, TEXT, TEXT)
//! and that control is not returned to the caller until all messages have been
//! processed and the stream signals completion.
//!
//! This protects against regressions where control might be returned mid-stream
//! after a single tool call, leaving subsequent messages unprocessed.
use anyhow::Result;
use async_trait::async_trait;
use g3_core::ui_writer::UiWriter;
use g3_core::Agent;
use g3_providers::{
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider,
ProviderRegistry, ToolCall, Usage,
};
use serial_test::serial;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use tempfile::TempDir;
use tokio::sync::mpsc;
// =============================================================================
// Mock Provider
// =============================================================================
/// A mock LLM provider that streams a predefined sequence of chunks.
/// On the FIRST call to stream(), it sends: TEXT -> TOOL_CALL -> TOOL_CALL -> TEXT -> TEXT -> FINISHED
/// On subsequent calls, it just sends a simple text response and finishes.
struct MockStreamingProvider {
/// Counter to track how many times stream() has been called
stream_call_count: Arc<AtomicUsize>,
/// Flag set when the first stream has sent all 6 chunks including the finished signal
first_stream_all_chunks_sent: Arc<AtomicBool>,
}
impl MockStreamingProvider {
fn new() -> Self {
Self {
stream_call_count: Arc::new(AtomicUsize::new(0)),
first_stream_all_chunks_sent: Arc::new(AtomicBool::new(false)),
}
}
#[allow(dead_code)]
fn first_stream_completed(&self) -> bool {
self.first_stream_all_chunks_sent.load(Ordering::SeqCst)
}
#[allow(dead_code)]
fn stream_call_count(&self) -> usize {
self.stream_call_count.load(Ordering::SeqCst)
}
}
fn default_usage() -> Usage {
Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
cache_creation_tokens: 0,
cache_read_tokens: 0,
}
}
#[async_trait]
impl LLMProvider for MockStreamingProvider {
async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse> {
Ok(CompletionResponse {
content: String::new(),
usage: default_usage(),
model: "mock".to_string(),
})
}
async fn stream(&self, _request: CompletionRequest) -> Result<CompletionStream> {
let call_num = self.stream_call_count.fetch_add(1, Ordering::SeqCst);
let (tx, rx) = mpsc::channel(32);
let first_stream_completed = self.first_stream_all_chunks_sent.clone();
if call_num == 0 {
// First call: send the full sequence with tool calls
tokio::spawn(async move {
// Chunk 1: Initial text
let _ = tx
.send(Ok(CompletionChunk {
content: "I'll help you with that task. Let me ".to_string(),
finished: false,
tool_calls: None,
usage: None,
stop_reason: None,
tool_call_streaming: None,
}))
.await;
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
// Chunk 2: First tool call
let _ = tx
.send(Ok(CompletionChunk {
content: String::new(),
finished: false,
tool_calls: Some(vec![ToolCall {
id: "call_1".to_string(),
tool: "shell".to_string(),
args: serde_json::json!({"command": "echo 'first tool call'"}),
}]),
usage: None,
stop_reason: None,
tool_call_streaming: None,
}))
.await;
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
// Chunk 3: Second tool call
let _ = tx
.send(Ok(CompletionChunk {
content: String::new(),
finished: false,
tool_calls: Some(vec![ToolCall {
id: "call_2".to_string(),
tool: "shell".to_string(),
args: serde_json::json!({"command": "echo 'second tool call'"}),
}]),
usage: None,
stop_reason: None,
tool_call_streaming: None,
}))
.await;
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
// Chunk 4: More text
let _ = tx
.send(Ok(CompletionChunk {
content: "Both commands executed. ".to_string(),
finished: false,
tool_calls: None,
usage: None,
stop_reason: None,
tool_call_streaming: None,
}))
.await;
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
// Chunk 5: Final text
let _ = tx
.send(Ok(CompletionChunk {
content: "Done!".to_string(),
finished: false,
tool_calls: None,
usage: None,
stop_reason: None,
tool_call_streaming: None,
}))
.await;
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
// Chunk 6: Finished signal
let _ = tx
.send(Ok(CompletionChunk {
content: String::new(),
finished: true,
tool_calls: None,
usage: Some(Usage {
prompt_tokens: 100,
completion_tokens: 50,
total_tokens: 150,
cache_creation_tokens: 0,
cache_read_tokens: 0,
}),
stop_reason: Some("end_turn".to_string()),
tool_call_streaming: None,
}))
.await;
// Mark that we sent all chunks
first_stream_completed.store(true, Ordering::SeqCst);
});
} else {
// Subsequent calls: just send a simple completion
tokio::spawn(async move {
let _ = tx
.send(Ok(CompletionChunk {
content: "Task complete.".to_string(),
finished: false,
tool_calls: None,
usage: None,
stop_reason: None,
tool_call_streaming: None,
}))
.await;
let _ = tx
.send(Ok(CompletionChunk {
content: String::new(),
finished: true,
tool_calls: None,
usage: Some(Usage {
prompt_tokens: 50,
completion_tokens: 10,
total_tokens: 60,
cache_creation_tokens: 0,
cache_read_tokens: 0,
}),
stop_reason: Some("end_turn".to_string()),
tool_call_streaming: None,
}))
.await;
});
}
Ok(tokio_stream::wrappers::ReceiverStream::new(rx))
}
fn name(&self) -> &str {
"mock"
}
fn model(&self) -> &str {
"mock-streaming-model"
}
fn has_native_tool_calling(&self) -> bool {
true
}
fn supports_cache_control(&self) -> bool {
false
}
fn max_tokens(&self) -> u32 {
4096
}
fn temperature(&self) -> f32 {
0.0
}
}
// =============================================================================
// Test UI Writer that tracks events
// =============================================================================
/// A UI writer that tracks tool call events for verification
#[derive(Clone)]
struct TrackingUiWriter {
tool_calls_seen: Arc<AtomicUsize>,
responses: Arc<std::sync::Mutex<Vec<String>>>,
}
impl TrackingUiWriter {
fn new(tool_calls_seen: Arc<AtomicUsize>) -> Self {
Self {
tool_calls_seen,
responses: Arc::new(std::sync::Mutex::new(Vec::new())),
}
}
fn tool_call_count(&self) -> usize {
self.tool_calls_seen.load(Ordering::SeqCst)
}
#[allow(dead_code)]
fn responses(&self) -> Vec<String> {
self.responses.lock().unwrap().clone()
}
}
impl UiWriter for TrackingUiWriter {
fn print(&self, _message: &str) {}
fn println(&self, _message: &str) {}
fn print_inline(&self, _message: &str) {}
fn print_system_prompt(&self, _prompt: &str) {}
fn print_context_status(&self, _message: &str) {}
fn print_g3_progress(&self, _message: &str) {}
fn print_g3_status(&self, _message: &str, _status: &str) {}
fn print_thin_result(&self, _result: &g3_core::ThinResult) {}
fn print_tool_header(&self, _tool_name: &str, _tool_args: Option<&serde_json::Value>) {
// Count each tool call
self.tool_calls_seen.fetch_add(1, Ordering::SeqCst);
}
fn print_tool_arg(&self, _key: &str, _value: &str) {}
fn print_tool_output_header(&self) {}
fn update_tool_output_line(&self, _line: &str) {}
fn print_tool_output_line(&self, _line: &str) {}
fn print_tool_output_summary(&self, _total_lines: usize) {}
fn print_tool_timing(&self, _duration: &str, _tokens: u32, _context_pct: f32) {}
fn print_tool_compact(
&self,
_tool_name: &str,
_summary: &str,
_duration: &str,
_tokens: u32,
_context_pct: f32,
) -> bool {
false
}
fn print_todo_compact(&self, _content: Option<&str>, _is_write: bool) -> bool {
false
}
fn print_tool_streaming_hint(&self, _tool_name: &str) {}
fn print_tool_streaming_active(&self) {}
fn print_agent_prompt(&self) {}
fn print_agent_response(&self, response: &str) {
self.responses.lock().unwrap().push(response.to_string());
}
fn flush(&self) {}
fn finish_streaming_markdown(&self) {}
fn reset_json_filter(&self) {}
fn filter_json_tool_calls(&self, content: &str) -> String {
content.to_string()
}
fn wants_full_output(&self) -> bool {
false
}
fn notify_sse_received(&self) {}
fn prompt_user_yes_no(&self, _message: &str) -> bool {
false
}
fn prompt_user_choice(&self, _message: &str, _options: &[&str]) -> usize {
0
}
}
// =============================================================================
// Integration Tests
// =============================================================================
/// Test that all streaming chunks are processed before control returns.
/// This simulates the interactive mode flow where a user sends a message
/// and the agent processes the full response including multiple tool calls.
///
/// The key assertion is that BOTH tool calls from the first stream are
/// processed - if control returned after the first tool call, we'd only see 1.
#[tokio::test]
#[serial]
async fn test_streaming_processes_all_chunks_before_returning() {
let temp_dir = TempDir::new().unwrap();
std::env::set_current_dir(temp_dir.path()).unwrap();
// Create mock provider
let mock_provider = MockStreamingProvider::new();
let first_stream_completed = mock_provider.first_stream_all_chunks_sent.clone();
// Create provider registry with mock
let mut registry = ProviderRegistry::new();
registry.register(mock_provider);
// Create tracking UI writer
let tool_calls_seen = Arc::new(AtomicUsize::new(0));
let ui_writer = TrackingUiWriter::new(tool_calls_seen.clone());
// Create agent with mock provider
let config = g3_config::Config::default();
let mut agent = Agent::new_for_test(config, ui_writer.clone(), registry)
.await
.unwrap();
// Execute a task - this should process ALL chunks before returning
let result = agent.execute_task("test task", None, false).await;
// The task may complete or error (due to auto-continue logic), but that's ok
let _ = result;
// CRITICAL ASSERTION 1: The first stream must have sent all its chunks
assert!(
first_stream_completed.load(Ordering::SeqCst),
"First stream did not complete sending all chunks - control may have returned early"
);
// CRITICAL ASSERTION 2: Both tool calls from the first stream must have been processed
let tool_count = ui_writer.tool_call_count();
assert!(
tool_count >= 2,
"Expected at least 2 tool calls to be processed, but only {} were seen. \
This indicates control was returned after the first tool call.",
tool_count
);
}
/// Test that the finished signal (chunk.finished = true) properly terminates
/// the stream processing loop.
#[tokio::test]
#[serial]
async fn test_finished_signal_terminates_stream() {
let temp_dir = TempDir::new().unwrap();
std::env::set_current_dir(temp_dir.path()).unwrap();
// Create a simpler mock that just sends text and finishes
struct SimpleFinishProvider {
post_finish_chunk_processed: Arc<AtomicBool>,
}
#[async_trait]
impl LLMProvider for SimpleFinishProvider {
async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse> {
Ok(CompletionResponse {
content: String::new(),
usage: Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
cache_creation_tokens: 0,
cache_read_tokens: 0,
},
model: "simple".to_string(),
})
}
async fn stream(&self, _request: CompletionRequest) -> Result<CompletionStream> {
let (tx, rx) = mpsc::channel(32);
let post_finish_flag = self.post_finish_chunk_processed.clone();
tokio::spawn(async move {
// Send some text
let _ = tx
.send(Ok(CompletionChunk {
content: "Hello, this is a test response.".to_string(),
finished: false,
tool_calls: None,
usage: None,
stop_reason: None,
tool_call_streaming: None,
}))
.await;
// Send finished signal
let _ = tx
.send(Ok(CompletionChunk {
content: String::new(),
finished: true,
tool_calls: None,
usage: Some(Usage {
prompt_tokens: 10,
completion_tokens: 10,
total_tokens: 20,
cache_creation_tokens: 0,
cache_read_tokens: 0,
}),
stop_reason: Some("end_turn".to_string()),
tool_call_streaming: None,
}))
.await;
// Wait a bit then send another chunk (should not be processed)
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
// If this send succeeds and the receiver is still listening,
// it means the stream wasn't properly terminated
if tx
.send(Ok(CompletionChunk {
content: "THIS_SHOULD_NOT_APPEAR".to_string(),
finished: false,
tool_calls: None,
usage: None,
stop_reason: None,
tool_call_streaming: None,
}))
.await
.is_ok()
{
// Channel still open - but this doesn't mean it was processed
// The flag is set if the content appears in responses
}
post_finish_flag.store(true, Ordering::SeqCst);
});
Ok(tokio_stream::wrappers::ReceiverStream::new(rx))
}
fn name(&self) -> &str {
"simple"
}
fn model(&self) -> &str {
"simple-model"
}
fn has_native_tool_calling(&self) -> bool {
false
}
fn supports_cache_control(&self) -> bool {
false
}
fn max_tokens(&self) -> u32 {
4096
}
fn temperature(&self) -> f32 {
0.0
}
}
let post_finish_flag = Arc::new(AtomicBool::new(false));
let provider = SimpleFinishProvider {
post_finish_chunk_processed: post_finish_flag.clone(),
};
let mut registry = ProviderRegistry::new();
registry.register(provider);
let tool_calls_seen = Arc::new(AtomicUsize::new(0));
let ui_writer = TrackingUiWriter::new(tool_calls_seen);
let config = g3_config::Config::default();
let mut agent = Agent::new_for_test(config, ui_writer.clone(), registry)
.await
.unwrap();
let result = agent.execute_task("test", None, false).await;
assert!(result.is_ok(), "Task should complete successfully");
// Verify the post-finish content was NOT processed
let responses = ui_writer.responses();
let all_responses = responses.join("");
assert!(
!all_responses.contains("THIS_SHOULD_NOT_APPEAR"),
"Content after finished signal should not be processed. Got: {}",
all_responses
);
assert!(
all_responses.contains("Hello"),
"Content before finished signal should be processed. Got: {}",
all_responses
);
}