Add streaming completion integration test with mock LLM provider

Adds tests to verify that:
- All streaming chunks are processed before control returns to the caller
- Both tool calls in a multi-tool-call stream are executed
- The finished signal properly terminates stream processing

Also adds Agent::new_for_test() to allow injecting mock providers.
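
For illustration, a minimal sketch of the injection flow (MockProvider and
NullUiWriter are hypothetical stand-ins for any LLMProvider / UiWriter
implementation):

#[tokio::test]
async fn injects_mock_provider() -> anyhow::Result<()> {
    // MockProvider / NullUiWriter are hypothetical; any LLMProvider
    // and UiWriter implementations work here.
    let mut registry = ProviderRegistry::new();
    registry.register(MockProvider::new());

    let mut agent =
        Agent::new_for_test(g3_config::Config::default(), NullUiWriter, registry).await?;
    agent.execute_task("hello", None, false).await?;
    Ok(())
}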
Author: Dhanji R. Prasanna
Date: 2026-01-16 20:52:32 +05:30
Parent: 0e33465342
Commit: fc702168ab
3 changed files with 589 additions and 0 deletions


@@ -4,6 +4,9 @@ version = "0.1.0"
edition = "2021" edition = "2021"
description = "Core engine for G3 AI coding agent" description = "Core engine for G3 AI coding agent"
[features]
test-support = []
[dependencies]
g3-providers = { path = "../g3-providers" }
g3-config = { path = "../g3-config" }


@@ -180,6 +180,64 @@ impl<W: UiWriter> Agent<W> {
.await
}
/// Create a new agent with a custom provider registry (for testing).
/// This allows tests to inject mock providers without needing real API credentials.
///
/// **Note**: This method is intended for testing only. Do not use in production code.
#[doc(hidden)]
pub async fn new_for_test(
config: Config,
ui_writer: W,
providers: ProviderRegistry,
) -> Result<Self> {
use crate::context_window::ContextWindow;
use crate::prompts::get_system_prompt_for_native;
use g3_providers::{Message, MessageRole};
// Use a reasonable default context length for tests
let context_length = config.agent.max_context_length.unwrap_or(200_000);
let mut context_window = ContextWindow::new(context_length);
// Add system prompt
let system_prompt = get_system_prompt_for_native();
let system_message = Message::new(MessageRole::System, system_prompt);
context_window.add_message(system_message);
Ok(Self {
providers,
context_window,
auto_compact: false,
pending_90_compaction: false,
thinning_events: Vec::new(),
compaction_events: Vec::new(),
first_token_times: Vec::new(),
config,
session_id: None,
tool_call_metrics: Vec::new(),
ui_writer,
todo_content: std::sync::Arc::new(tokio::sync::RwLock::new(String::new())),
is_autonomous: false,
quiet: true,
computer_controller: None,
webdriver_session: std::sync::Arc::new(tokio::sync::RwLock::new(None)),
webdriver_process: std::sync::Arc::new(tokio::sync::RwLock::new(None)),
tool_call_count: 0,
tool_calls_this_turn: Vec::new(),
requirements_sha: None,
working_dir: None,
background_process_manager: std::sync::Arc::new(
background_process::BackgroundProcessManager::new(
paths::get_background_processes_dir(),
),
),
pending_images: Vec::new(),
is_agent_mode: false,
agent_name: None,
auto_memory: false,
acd_enabled: false,
})
}
async fn new_with_mode(
config: Config,
ui_writer: W,


@@ -0,0 +1,528 @@
//! Streaming Completion Integration Test
//!
//! This test verifies that the Agent correctly processes a streaming response
//! containing multiple message types (TEXT, TOOL_CALL, TOOL_CALL, TEXT, TEXT)
//! and that control is not returned to the caller until all messages have been
//! processed and the stream signals completion.
//!
//! This protects against regressions where control might be returned mid-stream
//! after a single tool call, leaving subsequent messages unprocessed.
use anyhow::Result;
use async_trait::async_trait;
use g3_core::ui_writer::UiWriter;
use g3_core::Agent;
use g3_providers::{
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider,
ProviderRegistry, ToolCall, Usage,
};
use serial_test::serial;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use tempfile::TempDir;
use tokio::sync::mpsc;
// =============================================================================
// Mock Provider
// =============================================================================
/// A mock LLM provider that streams a predefined sequence of chunks.
/// On the FIRST call to stream(), it sends: TEXT -> TOOL_CALL -> TOOL_CALL -> TEXT -> TEXT -> FINISHED
/// On subsequent calls, it just sends a simple text response and finishes.
struct MockStreamingProvider {
/// Counter to track how many times stream() has been called
stream_call_count: Arc<AtomicUsize>,
/// Flag set when the first stream has sent all 6 chunks including the finished signal
first_stream_all_chunks_sent: Arc<AtomicBool>,
}
impl MockStreamingProvider {
fn new() -> Self {
Self {
stream_call_count: Arc::new(AtomicUsize::new(0)),
first_stream_all_chunks_sent: Arc::new(AtomicBool::new(false)),
}
}
#[allow(dead_code)]
fn first_stream_completed(&self) -> bool {
self.first_stream_all_chunks_sent.load(Ordering::SeqCst)
}
#[allow(dead_code)]
fn stream_call_count(&self) -> usize {
self.stream_call_count.load(Ordering::SeqCst)
}
}
fn default_usage() -> Usage {
Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
}
}
#[async_trait]
impl LLMProvider for MockStreamingProvider {
async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse> {
Ok(CompletionResponse {
content: String::new(),
usage: default_usage(),
model: "mock".to_string(),
})
}
async fn stream(&self, _request: CompletionRequest) -> Result<CompletionStream> {
let call_num = self.stream_call_count.fetch_add(1, Ordering::SeqCst);
let (tx, rx) = mpsc::channel(32);
let first_stream_completed = self.first_stream_all_chunks_sent.clone();
if call_num == 0 {
// First call: send the full sequence with tool calls
tokio::spawn(async move {
// Chunk 1: Initial text
let _ = tx
.send(Ok(CompletionChunk {
content: "I'll help you with that task. Let me ".to_string(),
finished: false,
tool_calls: None,
usage: None,
stop_reason: None,
tool_call_streaming: None,
}))
.await;
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
// Chunk 2: First tool call
let _ = tx
.send(Ok(CompletionChunk {
content: String::new(),
finished: false,
tool_calls: Some(vec![ToolCall {
id: "call_1".to_string(),
tool: "shell".to_string(),
args: serde_json::json!({"command": "echo 'first tool call'"}),
}]),
usage: None,
stop_reason: None,
tool_call_streaming: None,
}))
.await;
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
// Chunk 3: Second tool call
let _ = tx
.send(Ok(CompletionChunk {
content: String::new(),
finished: false,
tool_calls: Some(vec![ToolCall {
id: "call_2".to_string(),
tool: "shell".to_string(),
args: serde_json::json!({"command": "echo 'second tool call'"}),
}]),
usage: None,
stop_reason: None,
tool_call_streaming: None,
}))
.await;
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
// Chunk 4: More text
let _ = tx
.send(Ok(CompletionChunk {
content: "Both commands executed. ".to_string(),
finished: false,
tool_calls: None,
usage: None,
stop_reason: None,
tool_call_streaming: None,
}))
.await;
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
// Chunk 5: Final text
let _ = tx
.send(Ok(CompletionChunk {
content: "Done!".to_string(),
finished: false,
tool_calls: None,
usage: None,
stop_reason: None,
tool_call_streaming: None,
}))
.await;
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
// Chunk 6: Finished signal
let _ = tx
.send(Ok(CompletionChunk {
content: String::new(),
finished: true,
tool_calls: None,
usage: Some(Usage {
prompt_tokens: 100,
completion_tokens: 50,
total_tokens: 150,
}),
stop_reason: Some("end_turn".to_string()),
tool_call_streaming: None,
}))
.await;
// Mark that we sent all chunks
first_stream_completed.store(true, Ordering::SeqCst);
});
} else {
// Subsequent calls: just send a simple completion
tokio::spawn(async move {
let _ = tx
.send(Ok(CompletionChunk {
content: "Task complete.".to_string(),
finished: false,
tool_calls: None,
usage: None,
stop_reason: None,
tool_call_streaming: None,
}))
.await;
let _ = tx
.send(Ok(CompletionChunk {
content: String::new(),
finished: true,
tool_calls: None,
usage: Some(Usage {
prompt_tokens: 50,
completion_tokens: 10,
total_tokens: 60,
}),
stop_reason: Some("end_turn".to_string()),
tool_call_streaming: None,
}))
.await;
});
}
Ok(tokio_stream::wrappers::ReceiverStream::new(rx))
}
fn name(&self) -> &str {
"mock"
}
fn model(&self) -> &str {
"mock-streaming-model"
}
fn has_native_tool_calling(&self) -> bool {
true
}
fn supports_cache_control(&self) -> bool {
false
}
fn max_tokens(&self) -> u32 {
4096
}
fn temperature(&self) -> f32 {
0.0
}
}
// =============================================================================
// Test UI Writer that tracks events
// =============================================================================
/// A UI writer that tracks tool call events for verification
#[derive(Clone)]
struct TrackingUiWriter {
tool_calls_seen: Arc<AtomicUsize>,
responses: Arc<std::sync::Mutex<Vec<String>>>,
}
impl TrackingUiWriter {
fn new(tool_calls_seen: Arc<AtomicUsize>) -> Self {
Self {
tool_calls_seen,
responses: Arc::new(std::sync::Mutex::new(Vec::new())),
}
}
fn tool_call_count(&self) -> usize {
self.tool_calls_seen.load(Ordering::SeqCst)
}
#[allow(dead_code)]
fn responses(&self) -> Vec<String> {
self.responses.lock().unwrap().clone()
}
}
impl UiWriter for TrackingUiWriter {
fn print(&self, _message: &str) {}
fn println(&self, _message: &str) {}
fn print_inline(&self, _message: &str) {}
fn print_system_prompt(&self, _prompt: &str) {}
fn print_context_status(&self, _message: &str) {}
fn print_g3_progress(&self, _message: &str) {}
fn print_g3_status(&self, _message: &str, _status: &str) {}
fn print_context_thinning(&self, _message: &str) {}
fn print_tool_header(&self, _tool_name: &str, _tool_args: Option<&serde_json::Value>) {
// Count each tool call
self.tool_calls_seen.fetch_add(1, Ordering::SeqCst);
}
fn print_tool_arg(&self, _key: &str, _value: &str) {}
fn print_tool_output_header(&self) {}
fn update_tool_output_line(&self, _line: &str) {}
fn print_tool_output_line(&self, _line: &str) {}
fn print_tool_output_summary(&self, _total_lines: usize) {}
fn print_tool_timing(&self, _duration: &str, _tokens: u32, _context_pct: f32) {}
fn print_tool_compact(
&self,
_tool_name: &str,
_summary: &str,
_duration: &str,
_tokens: u32,
_context_pct: f32,
) -> bool {
false
}
fn print_todo_compact(&self, _content: Option<&str>, _is_write: bool) -> bool {
false
}
fn print_tool_streaming_hint(&self, _tool_name: &str) {}
fn print_tool_streaming_active(&self) {}
fn print_agent_prompt(&self) {}
fn print_agent_response(&self, response: &str) {
self.responses.lock().unwrap().push(response.to_string());
}
fn flush(&self) {}
fn finish_streaming_markdown(&self) {}
fn reset_json_filter(&self) {}
fn filter_json_tool_calls(&self, content: &str) -> String {
content.to_string()
}
fn wants_full_output(&self) -> bool {
false
}
fn notify_sse_received(&self) {}
fn prompt_user_yes_no(&self, _message: &str) -> bool {
false
}
fn prompt_user_choice(&self, _message: &str, _options: &[&str]) -> usize {
0
}
}
// =============================================================================
// Integration Tests
// =============================================================================
/// Test that all streaming chunks are processed before control returns.
/// This simulates the interactive mode flow where a user sends a message
/// and the agent processes the full response including multiple tool calls.
///
/// The key assertion is that BOTH tool calls from the first stream are
/// processed - if control returned after the first tool call, we'd only see 1.
#[tokio::test]
#[serial]
async fn test_streaming_processes_all_chunks_before_returning() {
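// Runs under #[serial] because set_current_dir mutates process-global state.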
let temp_dir = TempDir::new().unwrap();
std::env::set_current_dir(temp_dir.path()).unwrap();
// Create mock provider
let mock_provider = MockStreamingProvider::new();
let first_stream_completed = mock_provider.first_stream_all_chunks_sent.clone();
// Create provider registry with mock
let mut registry = ProviderRegistry::new();
registry.register(mock_provider);
// Create tracking UI writer
let tool_calls_seen = Arc::new(AtomicUsize::new(0));
let ui_writer = TrackingUiWriter::new(tool_calls_seen.clone());
// Create agent with mock provider
let config = g3_config::Config::default();
let mut agent = Agent::new_for_test(config, ui_writer.clone(), registry)
.await
.unwrap();
// Execute a task - this should process ALL chunks before returning
let result = agent.execute_task("test task", None, false).await;
// The task may complete or error (due to auto-continue logic), but that's ok
let _ = result;
// CRITICAL ASSERTION 1: The first stream must have sent all its chunks
assert!(
first_stream_completed.load(Ordering::SeqCst),
"First stream did not complete sending all chunks - control may have returned early"
);
// CRITICAL ASSERTION 2: Both tool calls from the first stream must have been processed
let tool_count = ui_writer.tool_call_count();
assert!(
tool_count >= 2,
"Expected at least 2 tool calls to be processed, but only {} were seen. \
This indicates control was returned after the first tool call.",
tool_count
);
}
/// Test that the finished signal (chunk.finished = true) properly terminates
/// the stream processing loop.
#[tokio::test]
#[serial]
async fn test_finished_signal_terminates_stream() {
let temp_dir = TempDir::new().unwrap();
std::env::set_current_dir(temp_dir.path()).unwrap();
// Create a simpler mock that just sends text and finishes
struct SimpleFinishProvider {
post_finish_chunk_processed: Arc<AtomicBool>,
}
#[async_trait]
impl LLMProvider for SimpleFinishProvider {
async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse> {
Ok(CompletionResponse {
content: String::new(),
usage: Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
},
model: "simple".to_string(),
})
}
async fn stream(&self, _request: CompletionRequest) -> Result<CompletionStream> {
let (tx, rx) = mpsc::channel(32);
let post_finish_flag = self.post_finish_chunk_processed.clone();
tokio::spawn(async move {
// Send some text
let _ = tx
.send(Ok(CompletionChunk {
content: "Hello, this is a test response.".to_string(),
finished: false,
tool_calls: None,
usage: None,
stop_reason: None,
tool_call_streaming: None,
}))
.await;
// Send finished signal
let _ = tx
.send(Ok(CompletionChunk {
content: String::new(),
finished: true,
tool_calls: None,
usage: Some(Usage {
prompt_tokens: 10,
completion_tokens: 10,
total_tokens: 20,
}),
stop_reason: Some("end_turn".to_string()),
tool_call_streaming: None,
}))
.await;
// Wait a bit then send another chunk (should not be processed)
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
// If this send succeeds and the receiver is still listening,
// it means the stream wasn't properly terminated
if tx
.send(Ok(CompletionChunk {
content: "THIS_SHOULD_NOT_APPEAR".to_string(),
finished: false,
tool_calls: None,
usage: None,
stop_reason: None,
tool_call_streaming: None,
}))
.await
.is_ok()
{
// Channel still open - but that alone doesn't prove the chunk was
// processed; the assertion below checks the recorded responses
// for the sentinel string.
}
post_finish_flag.store(true, Ordering::SeqCst);
});
Ok(tokio_stream::wrappers::ReceiverStream::new(rx))
}
fn name(&self) -> &str {
"simple"
}
fn model(&self) -> &str {
"simple-model"
}
fn has_native_tool_calling(&self) -> bool {
false
}
fn supports_cache_control(&self) -> bool {
false
}
fn max_tokens(&self) -> u32 {
4096
}
fn temperature(&self) -> f32 {
0.0
}
}
let post_finish_flag = Arc::new(AtomicBool::new(false));
let provider = SimpleFinishProvider {
post_finish_chunk_processed: post_finish_flag.clone(),
};
let mut registry = ProviderRegistry::new();
registry.register(provider);
let tool_calls_seen = Arc::new(AtomicUsize::new(0));
let ui_writer = TrackingUiWriter::new(tool_calls_seen);
let config = g3_config::Config::default();
let mut agent = Agent::new_for_test(config, ui_writer.clone(), registry)
.await
.unwrap();
let result = agent.execute_task("test", None, false).await;
assert!(result.is_ok(), "Task should complete successfully");
// Verify the post-finish content was NOT processed
let responses = ui_writer.responses();
let all_responses = responses.join("");
assert!(
!all_responses.contains("THIS_SHOULD_NOT_APPEAR"),
"Content after finished signal should not be processed. Got: {}",
all_responses
);
assert!(
all_responses.contains("Hello"),
"Content before finished signal should be processed. Got: {}",
all_responses
);
}