Add streaming completion integration test with mock LLM provider
Adds tests to verify that:
- All streaming chunks are processed before control returns to the caller
- Both tool calls in a multi-tool-call stream are executed
- The finished signal properly terminates stream processing

Also adds Agent::new_for_test() to allow injecting mock providers.
This commit is contained in:
@@ -4,6 +4,9 @@ version = "0.1.0"
|
||||
edition = "2021"
|
||||
description = "Core engine for G3 AI coding agent"
|
||||
|
||||
[features]
|
||||
test-support = []
|
||||
|
||||
[dependencies]
|
||||
g3-providers = { path = "../g3-providers" }
|
||||
g3-config = { path = "../g3-config" }
|
||||
|
||||
@@ -180,6 +180,64 @@ impl<W: UiWriter> Agent<W> {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Create a new agent with a custom provider registry (for testing).
///
/// This allows tests to inject mock providers without needing real API
/// credentials. Unlike the production constructors, no provider discovery or
/// credential resolution happens here: the caller supplies a fully populated
/// `ProviderRegistry`.
///
/// The agent is created quiet, non-autonomous, with auto-compaction disabled
/// and no session/working-directory state, which keeps tests hermetic.
///
/// **Note**: This method is intended for testing only. Do not use in production code.
// NOTE(review): this commit also adds a `test-support` feature in Cargo.toml,
// but this method is only hidden via `#[doc(hidden)]`, not gated with
// `#[cfg(feature = "test-support")]` — confirm whether gating was intended.
#[doc(hidden)]
pub async fn new_for_test(
    config: Config,
    ui_writer: W,
    providers: ProviderRegistry,
) -> Result<Self> {
    // Function-local imports keep these test-only dependencies out of the
    // module's top-level import list.
    use crate::context_window::ContextWindow;
    use crate::prompts::get_system_prompt_for_native;
    use g3_providers::{Message, MessageRole};

    // Use a reasonable default context length for tests when the config does
    // not specify one (200k tokens).
    let context_length = config.agent.max_context_length.unwrap_or(200_000);
    let mut context_window = ContextWindow::new(context_length);

    // Seed the context window with the native-tool-calling system prompt, so
    // the mocked conversation starts the same way a real one would.
    let system_prompt = get_system_prompt_for_native();
    let system_message = Message::new(MessageRole::System, system_prompt);
    context_window.add_message(system_message);

    Ok(Self {
        providers,
        context_window,
        // Compaction/thinning disabled: keeps token bookkeeping deterministic.
        auto_compact: false,
        pending_90_compaction: false,
        thinning_events: Vec::new(),
        compaction_events: Vec::new(),
        first_token_times: Vec::new(),
        config,
        session_id: None,
        tool_call_metrics: Vec::new(),
        ui_writer,
        todo_content: std::sync::Arc::new(tokio::sync::RwLock::new(String::new())),
        is_autonomous: false,
        // Quiet by default: tests assert via the injected UiWriter, not stdout.
        quiet: true,
        computer_controller: None,
        webdriver_session: std::sync::Arc::new(tokio::sync::RwLock::new(None)),
        webdriver_process: std::sync::Arc::new(tokio::sync::RwLock::new(None)),
        tool_call_count: 0,
        tool_calls_this_turn: Vec::new(),
        requirements_sha: None,
        working_dir: None,
        background_process_manager: std::sync::Arc::new(
            background_process::BackgroundProcessManager::new(
                paths::get_background_processes_dir(),
            ),
        ),
        pending_images: Vec::new(),
        is_agent_mode: false,
        agent_name: None,
        auto_memory: false,
        acd_enabled: false,
    })
}
|
||||
|
||||
async fn new_with_mode(
|
||||
config: Config,
|
||||
ui_writer: W,
|
||||
|
||||
528
crates/g3-core/tests/streaming_completion_test.rs
Normal file
528
crates/g3-core/tests/streaming_completion_test.rs
Normal file
@@ -0,0 +1,528 @@
|
||||
//! Streaming Completion Integration Test
|
||||
//!
|
||||
//! This test verifies that the Agent correctly processes a streaming response
|
||||
//! containing multiple message types (TEXT, TOOL_CALL, TOOL_CALL, TEXT, TEXT)
|
||||
//! and that control is not returned to the caller until all messages have been
|
||||
//! processed and the stream signals completion.
|
||||
//!
|
||||
//! This protects against regressions where control might be returned mid-stream
|
||||
//! after a single tool call, leaving subsequent messages unprocessed.
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use g3_core::ui_writer::UiWriter;
|
||||
use g3_core::Agent;
|
||||
use g3_providers::{
|
||||
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider,
|
||||
ProviderRegistry, ToolCall, Usage,
|
||||
};
|
||||
use serial_test::serial;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tempfile::TempDir;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
// =============================================================================
|
||||
// Mock Provider
|
||||
// =============================================================================
|
||||
|
||||
/// A mock LLM provider that streams a predefined sequence of chunks.
/// On the FIRST call to stream(), it sends: TEXT -> TOOL_CALL -> TOOL_CALL -> TEXT -> TEXT -> FINISHED
/// On subsequent calls, it just sends a simple text response and finishes.
struct MockStreamingProvider {
    /// Counter to track how many times stream() has been called.
    /// Atomic + Arc so the spawned sender tasks and the test thread can both
    /// observe it without locking.
    stream_call_count: Arc<AtomicUsize>,
    /// Flag set when the first stream has sent all 6 chunks including the finished signal.
    /// The test asserts on this to prove the consumer drained the stream.
    first_stream_all_chunks_sent: Arc<AtomicBool>,
}
|
||||
|
||||
impl MockStreamingProvider {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
stream_call_count: Arc::new(AtomicUsize::new(0)),
|
||||
first_stream_all_chunks_sent: Arc::new(AtomicBool::new(false)),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn first_stream_completed(&self) -> bool {
|
||||
self.first_stream_all_chunks_sent.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn stream_call_count(&self) -> usize {
|
||||
self.stream_call_count.load(Ordering::SeqCst)
|
||||
}
|
||||
}
|
||||
|
||||
fn default_usage() -> Usage {
|
||||
Usage {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0,
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LLMProvider for MockStreamingProvider {
|
||||
async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse> {
|
||||
Ok(CompletionResponse {
|
||||
content: String::new(),
|
||||
usage: default_usage(),
|
||||
model: "mock".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn stream(&self, _request: CompletionRequest) -> Result<CompletionStream> {
|
||||
let call_num = self.stream_call_count.fetch_add(1, Ordering::SeqCst);
|
||||
let (tx, rx) = mpsc::channel(32);
|
||||
let first_stream_completed = self.first_stream_all_chunks_sent.clone();
|
||||
|
||||
if call_num == 0 {
|
||||
// First call: send the full sequence with tool calls
|
||||
tokio::spawn(async move {
|
||||
// Chunk 1: Initial text
|
||||
let _ = tx
|
||||
.send(Ok(CompletionChunk {
|
||||
content: "I'll help you with that task. Let me ".to_string(),
|
||||
finished: false,
|
||||
tool_calls: None,
|
||||
usage: None,
|
||||
stop_reason: None,
|
||||
tool_call_streaming: None,
|
||||
}))
|
||||
.await;
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
|
||||
|
||||
// Chunk 2: First tool call
|
||||
let _ = tx
|
||||
.send(Ok(CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: false,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: "call_1".to_string(),
|
||||
tool: "shell".to_string(),
|
||||
args: serde_json::json!({"command": "echo 'first tool call'"}),
|
||||
}]),
|
||||
usage: None,
|
||||
stop_reason: None,
|
||||
tool_call_streaming: None,
|
||||
}))
|
||||
.await;
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
|
||||
|
||||
// Chunk 3: Second tool call
|
||||
let _ = tx
|
||||
.send(Ok(CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: false,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: "call_2".to_string(),
|
||||
tool: "shell".to_string(),
|
||||
args: serde_json::json!({"command": "echo 'second tool call'"}),
|
||||
}]),
|
||||
usage: None,
|
||||
stop_reason: None,
|
||||
tool_call_streaming: None,
|
||||
}))
|
||||
.await;
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
|
||||
|
||||
// Chunk 4: More text
|
||||
let _ = tx
|
||||
.send(Ok(CompletionChunk {
|
||||
content: "Both commands executed. ".to_string(),
|
||||
finished: false,
|
||||
tool_calls: None,
|
||||
usage: None,
|
||||
stop_reason: None,
|
||||
tool_call_streaming: None,
|
||||
}))
|
||||
.await;
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
|
||||
|
||||
// Chunk 5: Final text
|
||||
let _ = tx
|
||||
.send(Ok(CompletionChunk {
|
||||
content: "Done!".to_string(),
|
||||
finished: false,
|
||||
tool_calls: None,
|
||||
usage: None,
|
||||
stop_reason: None,
|
||||
tool_call_streaming: None,
|
||||
}))
|
||||
.await;
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
|
||||
|
||||
// Chunk 6: Finished signal
|
||||
let _ = tx
|
||||
.send(Ok(CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
tool_calls: None,
|
||||
usage: Some(Usage {
|
||||
prompt_tokens: 100,
|
||||
completion_tokens: 50,
|
||||
total_tokens: 150,
|
||||
}),
|
||||
stop_reason: Some("end_turn".to_string()),
|
||||
tool_call_streaming: None,
|
||||
}))
|
||||
.await;
|
||||
|
||||
// Mark that we sent all chunks
|
||||
first_stream_completed.store(true, Ordering::SeqCst);
|
||||
});
|
||||
} else {
|
||||
// Subsequent calls: just send a simple completion
|
||||
tokio::spawn(async move {
|
||||
let _ = tx
|
||||
.send(Ok(CompletionChunk {
|
||||
content: "Task complete.".to_string(),
|
||||
finished: false,
|
||||
tool_calls: None,
|
||||
usage: None,
|
||||
stop_reason: None,
|
||||
tool_call_streaming: None,
|
||||
}))
|
||||
.await;
|
||||
|
||||
let _ = tx
|
||||
.send(Ok(CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
tool_calls: None,
|
||||
usage: Some(Usage {
|
||||
prompt_tokens: 50,
|
||||
completion_tokens: 10,
|
||||
total_tokens: 60,
|
||||
}),
|
||||
stop_reason: Some("end_turn".to_string()),
|
||||
tool_call_streaming: None,
|
||||
}))
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
Ok(tokio_stream::wrappers::ReceiverStream::new(rx))
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"mock"
|
||||
}
|
||||
|
||||
fn model(&self) -> &str {
|
||||
"mock-streaming-model"
|
||||
}
|
||||
|
||||
fn has_native_tool_calling(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn supports_cache_control(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn max_tokens(&self) -> u32 {
|
||||
4096
|
||||
}
|
||||
|
||||
fn temperature(&self) -> f32 {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Test UI Writer that tracks events
|
||||
// =============================================================================
|
||||
|
||||
/// A UI writer that tracks tool call events for verification.
///
/// `Clone` is derived so the test can keep one handle for assertions while
/// handing another to the agent; both share the same state through `Arc`.
#[derive(Clone)]
struct TrackingUiWriter {
    /// Number of tool headers printed — one per tool call the agent executes.
    tool_calls_seen: Arc<AtomicUsize>,
    /// Every full agent response passed to `print_agent_response`, in order.
    responses: Arc<std::sync::Mutex<Vec<String>>>,
}
|
||||
|
||||
impl TrackingUiWriter {
|
||||
fn new(tool_calls_seen: Arc<AtomicUsize>) -> Self {
|
||||
Self {
|
||||
tool_calls_seen,
|
||||
responses: Arc::new(std::sync::Mutex::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
|
||||
fn tool_call_count(&self) -> usize {
|
||||
self.tool_calls_seen.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn responses(&self) -> Vec<String> {
|
||||
self.responses.lock().unwrap().clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl UiWriter for TrackingUiWriter {
|
||||
fn print(&self, _message: &str) {}
|
||||
fn println(&self, _message: &str) {}
|
||||
fn print_inline(&self, _message: &str) {}
|
||||
fn print_system_prompt(&self, _prompt: &str) {}
|
||||
fn print_context_status(&self, _message: &str) {}
|
||||
fn print_g3_progress(&self, _message: &str) {}
|
||||
fn print_g3_status(&self, _message: &str, _status: &str) {}
|
||||
fn print_context_thinning(&self, _message: &str) {}
|
||||
|
||||
fn print_tool_header(&self, _tool_name: &str, _tool_args: Option<&serde_json::Value>) {
|
||||
// Count each tool call
|
||||
self.tool_calls_seen.fetch_add(1, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
fn print_tool_arg(&self, _key: &str, _value: &str) {}
|
||||
fn print_tool_output_header(&self) {}
|
||||
fn update_tool_output_line(&self, _line: &str) {}
|
||||
fn print_tool_output_line(&self, _line: &str) {}
|
||||
fn print_tool_output_summary(&self, _total_lines: usize) {}
|
||||
fn print_tool_timing(&self, _duration: &str, _tokens: u32, _context_pct: f32) {}
|
||||
fn print_tool_compact(
|
||||
&self,
|
||||
_tool_name: &str,
|
||||
_summary: &str,
|
||||
_duration: &str,
|
||||
_tokens: u32,
|
||||
_context_pct: f32,
|
||||
) -> bool {
|
||||
false
|
||||
}
|
||||
fn print_todo_compact(&self, _content: Option<&str>, _is_write: bool) -> bool {
|
||||
false
|
||||
}
|
||||
fn print_tool_streaming_hint(&self, _tool_name: &str) {}
|
||||
fn print_tool_streaming_active(&self) {}
|
||||
|
||||
fn print_agent_prompt(&self) {}
|
||||
|
||||
fn print_agent_response(&self, response: &str) {
|
||||
self.responses.lock().unwrap().push(response.to_string());
|
||||
}
|
||||
|
||||
fn flush(&self) {}
|
||||
fn finish_streaming_markdown(&self) {}
|
||||
fn reset_json_filter(&self) {}
|
||||
fn filter_json_tool_calls(&self, content: &str) -> String {
|
||||
content.to_string()
|
||||
}
|
||||
fn wants_full_output(&self) -> bool {
|
||||
false
|
||||
}
|
||||
fn notify_sse_received(&self) {}
|
||||
|
||||
fn prompt_user_yes_no(&self, _message: &str) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn prompt_user_choice(&self, _message: &str, _options: &[&str]) -> usize {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Integration Tests
|
||||
// =============================================================================
|
||||
|
||||
/// Test that all streaming chunks are processed before control returns.
/// This simulates the interactive mode flow where a user sends a message
/// and the agent processes the full response including multiple tool calls.
///
/// The key assertion is that BOTH tool calls from the first stream are
/// processed - if control returned after the first tool call, we'd only see 1.
#[tokio::test]
#[serial]
async fn test_streaming_processes_all_chunks_before_returning() {
    // Run inside a throwaway directory so shell tool calls can't touch the
    // real workspace. NOTE(review): the previous cwd is never restored and
    // the TempDir is deleted on drop, so later tests in this process start
    // in a removed directory unless they set their own cwd (the #[serial]
    // tests here all do).
    let temp_dir = TempDir::new().unwrap();
    std::env::set_current_dir(temp_dir.path()).unwrap();

    // Create mock provider; keep a handle to its completion flag before
    // ownership moves into the registry.
    let mock_provider = MockStreamingProvider::new();
    let first_stream_completed = mock_provider.first_stream_all_chunks_sent.clone();

    // Create provider registry with mock
    let mut registry = ProviderRegistry::new();
    registry.register(mock_provider);

    // Create tracking UI writer; the clone handed to the agent shares these
    // counters via Arc, so we can assert on our handle afterwards.
    let tool_calls_seen = Arc::new(AtomicUsize::new(0));
    let ui_writer = TrackingUiWriter::new(tool_calls_seen.clone());

    // Create agent with mock provider
    let config = g3_config::Config::default();
    let mut agent = Agent::new_for_test(config, ui_writer.clone(), registry)
        .await
        .unwrap();

    // Execute a task - this should process ALL chunks before returning
    let result = agent.execute_task("test task", None, false).await;

    // The task may complete or error (due to auto-continue logic), but that's ok
    let _ = result;

    // CRITICAL ASSERTION 1: The first stream must have sent all its chunks
    assert!(
        first_stream_completed.load(Ordering::SeqCst),
        "First stream did not complete sending all chunks - control may have returned early"
    );

    // CRITICAL ASSERTION 2: Both tool calls from the first stream must have been processed
    let tool_count = ui_writer.tool_call_count();
    assert!(
        tool_count >= 2,
        "Expected at least 2 tool calls to be processed, but only {} were seen. \
         This indicates control was returned after the first tool call.",
        tool_count
    );
}
|
||||
|
||||
/// Test that the finished signal (chunk.finished = true) properly terminates
/// the stream processing loop.
///
/// The provider sends TEXT -> FINISHED, waits 50 ms, then sends one more
/// chunk. If the agent kept reading past `finished: true`, the marker string
/// "THIS_SHOULD_NOT_APPEAR" would show up in the recorded responses.
#[tokio::test]
#[serial]
async fn test_finished_signal_terminates_stream() {
    // Isolate filesystem side effects in a throwaway directory.
    // NOTE(review): the previous cwd is never restored after the test.
    let temp_dir = TempDir::new().unwrap();
    std::env::set_current_dir(temp_dir.path()).unwrap();

    // Create a simpler mock that just sends text and finishes
    struct SimpleFinishProvider {
        // NOTE(review): this flag is stored by the sender task but never
        // asserted on below — looks like a leftover hook; confirm whether an
        // assertion on it was intended.
        post_finish_chunk_processed: Arc<AtomicBool>,
    }

    #[async_trait]
    impl LLMProvider for SimpleFinishProvider {
        // Non-streaming completion is unused by this test.
        async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                content: String::new(),
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
                model: "simple".to_string(),
            })
        }

        // Sends TEXT, then the FINISHED chunk, then — after a 50 ms delay —
        // one extra chunk that a correct consumer must never surface.
        async fn stream(&self, _request: CompletionRequest) -> Result<CompletionStream> {
            let (tx, rx) = mpsc::channel(32);
            let post_finish_flag = self.post_finish_chunk_processed.clone();

            tokio::spawn(async move {
                // Send some text
                let _ = tx
                    .send(Ok(CompletionChunk {
                        content: "Hello, this is a test response.".to_string(),
                        finished: false,
                        tool_calls: None,
                        usage: None,
                        stop_reason: None,
                        tool_call_streaming: None,
                    }))
                    .await;

                // Send finished signal
                let _ = tx
                    .send(Ok(CompletionChunk {
                        content: String::new(),
                        finished: true,
                        tool_calls: None,
                        usage: Some(Usage {
                            prompt_tokens: 10,
                            completion_tokens: 10,
                            total_tokens: 20,
                        }),
                        stop_reason: Some("end_turn".to_string()),
                        tool_call_streaming: None,
                    }))
                    .await;

                // Wait a bit then send another chunk (should not be processed)
                tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;

                // If this send succeeds and the receiver is still listening,
                // it means the stream wasn't properly terminated
                if tx
                    .send(Ok(CompletionChunk {
                        content: "THIS_SHOULD_NOT_APPEAR".to_string(),
                        finished: false,
                        tool_calls: None,
                        usage: None,
                        stop_reason: None,
                        tool_call_streaming: None,
                    }))
                    .await
                    .is_ok()
                {
                    // Channel still open - but this doesn't mean it was processed
                    // The flag is set if the content appears in responses
                }

                post_finish_flag.store(true, Ordering::SeqCst);
            });

            Ok(tokio_stream::wrappers::ReceiverStream::new(rx))
        }

        fn name(&self) -> &str {
            "simple"
        }
        fn model(&self) -> &str {
            "simple-model"
        }
        // Non-native tool calling: the agent treats stream output as plain text.
        fn has_native_tool_calling(&self) -> bool {
            false
        }
        fn supports_cache_control(&self) -> bool {
            false
        }
        fn max_tokens(&self) -> u32 {
            4096
        }
        fn temperature(&self) -> f32 {
            0.0
        }
    }

    let post_finish_flag = Arc::new(AtomicBool::new(false));
    let provider = SimpleFinishProvider {
        post_finish_chunk_processed: post_finish_flag.clone(),
    };

    let mut registry = ProviderRegistry::new();
    registry.register(provider);

    // Tool calls are not expected on this path; the counter exists only to
    // satisfy the TrackingUiWriter constructor.
    let tool_calls_seen = Arc::new(AtomicUsize::new(0));
    let ui_writer = TrackingUiWriter::new(tool_calls_seen);
    let config = g3_config::Config::default();
    let mut agent = Agent::new_for_test(config, ui_writer.clone(), registry)
        .await
        .unwrap();

    let result = agent.execute_task("test", None, false).await;

    assert!(result.is_ok(), "Task should complete successfully");

    // Verify the post-finish content was NOT processed
    let responses = ui_writer.responses();
    let all_responses = responses.join("");

    assert!(
        !all_responses.contains("THIS_SHOULD_NOT_APPEAR"),
        "Content after finished signal should not be processed. Got: {}",
        all_responses
    );
    assert!(
        all_responses.contains("Hello"),
        "Content before finished signal should be processed. Got: {}",
        all_responses
    );
}
|
||||
Reference in New Issue
Block a user