From fc702168ab5b1347e3a9ec92545082e13f6e6c4d Mon Sep 17 00:00:00 2001
From: "Dhanji R. Prasanna"
Date: Fri, 16 Jan 2026 20:52:32 +0530
Subject: [PATCH] Add streaming completion integration test with mock LLM provider

Adds tests to verify that:
- All streaming chunks are processed before control returns to the caller
- Both tool calls in a multi-tool-call stream are executed
- The finished signal properly terminates stream processing

Also adds Agent::new_for_test() to allow injecting mock providers.
---
 crates/g3-core/Cargo.toml                     |   3 +
 crates/g3-core/src/lib.rs                     |  58 ++
 .../tests/streaming_completion_test.rs        | 528 ++++++++++++++++++
 3 files changed, 589 insertions(+)
 create mode 100644 crates/g3-core/tests/streaming_completion_test.rs

diff --git a/crates/g3-core/Cargo.toml b/crates/g3-core/Cargo.toml
index 63f7272..2ddf354 100644
--- a/crates/g3-core/Cargo.toml
+++ b/crates/g3-core/Cargo.toml
@@ -4,6 +4,9 @@ version = "0.1.0"
 edition = "2021"
 description = "Core engine for G3 AI coding agent"
 
+[features]
+test-support = []
+
 [dependencies]
 g3-providers = { path = "../g3-providers" }
 g3-config = { path = "../g3-config" }
diff --git a/crates/g3-core/src/lib.rs b/crates/g3-core/src/lib.rs
index 56e64ae..688f9d9 100644
--- a/crates/g3-core/src/lib.rs
+++ b/crates/g3-core/src/lib.rs
@@ -180,6 +180,64 @@ impl Agent {
         .await
     }
 
+    /// Create a new agent with a custom provider registry (for testing).
+    /// This allows tests to inject mock providers without needing real API credentials.
+    ///
+    /// **Note**: This method is intended for testing only. Do not use in production code.
+    #[doc(hidden)]
+    pub async fn new_for_test(
+        config: Config,
+        ui_writer: W,
+        providers: ProviderRegistry,
+    ) -> Result<Self> {
+        use crate::context_window::ContextWindow;
+        use crate::prompts::get_system_prompt_for_native;
+        use g3_providers::{Message, MessageRole};
+
+        // Use a reasonable default context length for tests
+        let context_length = config.agent.max_context_length.unwrap_or(200_000);
+        let mut context_window = ContextWindow::new(context_length);
+
+        // Add system prompt
+        let system_prompt = get_system_prompt_for_native();
+        let system_message = Message::new(MessageRole::System, system_prompt);
+        context_window.add_message(system_message);
+
+        Ok(Self {
+            providers,
+            context_window,
+            auto_compact: false,
+            pending_90_compaction: false,
+            thinning_events: Vec::new(),
+            compaction_events: Vec::new(),
+            first_token_times: Vec::new(),
+            config,
+            session_id: None,
+            tool_call_metrics: Vec::new(),
+            ui_writer,
+            todo_content: std::sync::Arc::new(tokio::sync::RwLock::new(String::new())),
+            is_autonomous: false,
+            quiet: true,
+            computer_controller: None,
+            webdriver_session: std::sync::Arc::new(tokio::sync::RwLock::new(None)),
+            webdriver_process: std::sync::Arc::new(tokio::sync::RwLock::new(None)),
+            tool_call_count: 0,
+            tool_calls_this_turn: Vec::new(),
+            requirements_sha: None,
+            working_dir: None,
+            background_process_manager: std::sync::Arc::new(
+                background_process::BackgroundProcessManager::new(
+                    paths::get_background_processes_dir(),
+                ),
+            ),
+            pending_images: Vec::new(),
+            is_agent_mode: false,
+            agent_name: None,
+            auto_memory: false,
+            acd_enabled: false,
+        })
+    }
+
     async fn new_with_mode(
         config: Config,
         ui_writer: W,
diff --git a/crates/g3-core/tests/streaming_completion_test.rs b/crates/g3-core/tests/streaming_completion_test.rs
new file mode 100644
index 0000000..353ff11
--- /dev/null
+++ b/crates/g3-core/tests/streaming_completion_test.rs
@@ -0,0 +1,528 @@
//! Streaming Completion Integration Test
//!
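//! Run with `cargo test -p g3-core --test streaming_completion_test`.
//!
//! Both tests build the agent through `Agent::new_for_test`, injecting a mock
//! provider registry instead of a real LLM backend. In sketch form (`config`
//! and `ui_writer` are constructed as in the tests below):
//!
//! ```ignore
//! let mut registry = ProviderRegistry::new();
//! registry.register(MockStreamingProvider::new());
//! let mut agent = Agent::new_for_test(config, ui_writer, registry)
//!     .await
//!     .unwrap();
//! ```
//!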
//! This test verifies that the Agent correctly processes a streaming response
//! containing multiple message types (TEXT, TOOL_CALL, TOOL_CALL, TEXT, TEXT)
//! and that control is not returned to the caller until all messages have been
//! processed and the stream signals completion.
//!
//! This protects against regressions where control might be returned mid-stream
//! after a single tool call, leaving subsequent messages unprocessed.

use anyhow::Result;
use async_trait::async_trait;
use g3_core::ui_writer::UiWriter;
use g3_core::Agent;
use g3_providers::{
    CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider,
    ProviderRegistry, ToolCall, Usage,
};
use serial_test::serial;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use tempfile::TempDir;
use tokio::sync::mpsc;

// =============================================================================
// Mock Provider
// =============================================================================

/// A mock LLM provider that streams a predefined sequence of chunks.
/// On the FIRST call to stream(), it sends: TEXT -> TOOL_CALL -> TOOL_CALL -> TEXT -> TEXT -> FINISHED.
/// On subsequent calls, it just sends a simple text response and finishes.
struct MockStreamingProvider {
    /// Counter to track how many times stream() has been called
    stream_call_count: Arc<AtomicUsize>,
    /// Flag set when the first stream has sent all 6 chunks including the finished signal
    first_stream_all_chunks_sent: Arc<AtomicBool>,
}

impl MockStreamingProvider {
    fn new() -> Self {
        Self {
            stream_call_count: Arc::new(AtomicUsize::new(0)),
            first_stream_all_chunks_sent: Arc::new(AtomicBool::new(false)),
        }
    }

    #[allow(dead_code)]
    fn first_stream_completed(&self) -> bool {
        self.first_stream_all_chunks_sent.load(Ordering::SeqCst)
    }

    #[allow(dead_code)]
    fn stream_call_count(&self) -> usize {
        self.stream_call_count.load(Ordering::SeqCst)
    }
}

fn default_usage() -> Usage {
    Usage {
        prompt_tokens: 0,
        completion_tokens: 0,
        total_tokens: 0,
    }
}

#[async_trait]
impl LLMProvider for MockStreamingProvider {
    async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse> {
        Ok(CompletionResponse {
            content: String::new(),
            usage: default_usage(),
            model: "mock".to_string(),
        })
    }

    async fn stream(&self, _request: CompletionRequest) -> Result<CompletionStream> {
        let call_num = self.stream_call_count.fetch_add(1, Ordering::SeqCst);
        let (tx, rx) = mpsc::channel(32);
        let first_stream_completed = self.first_stream_all_chunks_sent.clone();

        if call_num == 0 {
            // First call: send the full sequence with tool calls
            tokio::spawn(async move {
                // Chunk 1: Initial text
                let _ = tx
                    .send(Ok(CompletionChunk {
                        content: "I'll help you with that task. Let me ".to_string(),
                        finished: false,
                        tool_calls: None,
                        usage: None,
                        stop_reason: None,
                        tool_call_streaming: None,
                    }))
                    .await;

                tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;

                // Chunk 2: First tool call
                let _ = tx
                    .send(Ok(CompletionChunk {
                        content: String::new(),
                        finished: false,
                        tool_calls: Some(vec![ToolCall {
                            id: "call_1".to_string(),
                            tool: "shell".to_string(),
                            args: serde_json::json!({"command": "echo 'first tool call'"}),
                        }]),
                        usage: None,
                        stop_reason: None,
                        tool_call_streaming: None,
                    }))
                    .await;

                tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;

                // Chunk 3: Second tool call
                let _ = tx
                    .send(Ok(CompletionChunk {
                        content: String::new(),
                        finished: false,
                        tool_calls: Some(vec![ToolCall {
                            id: "call_2".to_string(),
                            tool: "shell".to_string(),
                            args: serde_json::json!({"command": "echo 'second tool call'"}),
                        }]),
                        usage: None,
                        stop_reason: None,
                        tool_call_streaming: None,
                    }))
                    .await;

                tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;

                // Chunk 4: More text
                let _ = tx
                    .send(Ok(CompletionChunk {
                        content: "Both commands executed. ".to_string(),
                        finished: false,
                        tool_calls: None,
                        usage: None,
                        stop_reason: None,
                        tool_call_streaming: None,
                    }))
                    .await;

                tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;

                // Chunk 5: Final text
                let _ = tx
                    .send(Ok(CompletionChunk {
                        content: "Done!".to_string(),
                        finished: false,
                        tool_calls: None,
                        usage: None,
                        stop_reason: None,
                        tool_call_streaming: None,
                    }))
                    .await;

                tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;

                // Chunk 6: Finished signal
                let _ = tx
                    .send(Ok(CompletionChunk {
                        content: String::new(),
                        finished: true,
                        tool_calls: None,
                        usage: Some(Usage {
                            prompt_tokens: 100,
                            completion_tokens: 50,
                            total_tokens: 150,
                        }),
                        stop_reason: Some("end_turn".to_string()),
                        tool_call_streaming: None,
                    }))
                    .await;

                // Mark that we sent all chunks
                first_stream_completed.store(true, Ordering::SeqCst);
            });
        } else {
            // Subsequent calls: just send a simple completion
            tokio::spawn(async move {
                let _ = tx
                    .send(Ok(CompletionChunk {
                        content: "Task complete.".to_string(),
                        finished: false,
                        tool_calls: None,
                        usage: None,
                        stop_reason: None,
                        tool_call_streaming: None,
                    }))
                    .await;

                let _ = tx
                    .send(Ok(CompletionChunk {
                        content: String::new(),
                        finished: true,
                        tool_calls: None,
                        usage: Some(Usage {
                            prompt_tokens: 50,
                            completion_tokens: 10,
                            total_tokens: 60,
                        }),
                        stop_reason: Some("end_turn".to_string()),
                        tool_call_streaming: None,
                    }))
                    .await;
            });
        }

        Ok(tokio_stream::wrappers::ReceiverStream::new(rx))
    }

    fn name(&self) -> &str {
        "mock"
    }

    fn model(&self) -> &str {
        "mock-streaming-model"
    }

    fn has_native_tool_calling(&self) -> bool {
        true
    }

    fn supports_cache_control(&self) -> bool {
        false
    }

    fn max_tokens(&self) -> u32 {
        4096
    }

    fn temperature(&self) -> f32 {
        0.0
    }
}

// =============================================================================
// Test UI Writer that tracks events
// =============================================================================

/// A UI writer that tracks tool call events for verification
#[derive(Clone)]
struct TrackingUiWriter {
    tool_calls_seen: Arc<AtomicUsize>,
    responses: Arc<std::sync::Mutex<Vec<String>>>,
}

impl TrackingUiWriter {
    fn new(tool_calls_seen: Arc<AtomicUsize>) -> Self {
        Self {
            tool_calls_seen,
            responses: Arc::new(std::sync::Mutex::new(Vec::new())),
        }
    }

    fn tool_call_count(&self) -> usize {
        self.tool_calls_seen.load(Ordering::SeqCst)
    }

    #[allow(dead_code)]
    fn responses(&self) -> Vec<String> {
        self.responses.lock().unwrap().clone()
    }
}

impl UiWriter for TrackingUiWriter {
    fn print(&self, _message: &str) {}
    fn println(&self, _message: &str) {}
    fn print_inline(&self, _message: &str) {}
    fn print_system_prompt(&self, _prompt: &str) {}
    fn print_context_status(&self, _message: &str) {}
    fn print_g3_progress(&self, _message: &str) {}
    fn print_g3_status(&self, _message: &str, _status: &str) {}
    fn print_context_thinning(&self, _message: &str) {}

    fn print_tool_header(&self, _tool_name: &str, _tool_args: Option<&serde_json::Value>) {
        // Count each tool call
        self.tool_calls_seen.fetch_add(1, Ordering::SeqCst);
    }

    fn print_tool_arg(&self, _key: &str, _value: &str) {}
    fn print_tool_output_header(&self) {}
    fn update_tool_output_line(&self, _line: &str) {}
    fn print_tool_output_line(&self, _line: &str) {}
    fn print_tool_output_summary(&self, _total_lines: usize) {}
    fn print_tool_timing(&self, _duration: &str, _tokens: u32, _context_pct: f32) {}
    fn print_tool_compact(
        &self,
        _tool_name: &str,
        _summary: &str,
        _duration: &str,
        _tokens: u32,
        _context_pct: f32,
    ) -> bool {
        false
    }
    fn print_todo_compact(&self, _content: Option<&str>, _is_write: bool) -> bool {
        false
    }
    fn print_tool_streaming_hint(&self, _tool_name: &str) {}
    fn print_tool_streaming_active(&self) {}

    fn print_agent_prompt(&self) {}

    fn print_agent_response(&self, response: &str) {
        self.responses.lock().unwrap().push(response.to_string());
    }

    fn flush(&self) {}
    fn finish_streaming_markdown(&self) {}
    fn reset_json_filter(&self) {}
    fn filter_json_tool_calls(&self, content: &str) -> String {
        content.to_string()
    }
    fn wants_full_output(&self) -> bool {
        false
    }
    fn notify_sse_received(&self) {}

    fn prompt_user_yes_no(&self, _message: &str) -> bool {
        false
    }

    fn prompt_user_choice(&self, _message: &str, _options: &[&str]) -> usize {
        0
    }
}

// =============================================================================
// Integration Tests
// =============================================================================

/// Test that all streaming chunks are processed before control returns.
/// This simulates the interactive mode flow where a user sends a message
/// and the agent processes the full response including multiple tool calls.
///
/// The key assertion is that BOTH tool calls from the first stream are
/// processed - if control returned after the first tool call, we'd only see 1.
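///
/// Sketch of the invariant, mirroring the mock's chunk sequence above:
///
///   stream 1:  TEXT, TOOL_CALL(call_1), TOOL_CALL(call_2), TEXT, TEXT, FINISHED
///   expected:  first_stream_all_chunks_sent == true
///              ui_writer.tool_call_count() >= 2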
#[tokio::test]
#[serial]
async fn test_streaming_processes_all_chunks_before_returning() {
    let temp_dir = TempDir::new().unwrap();
    std::env::set_current_dir(temp_dir.path()).unwrap();

    // Create mock provider
    let mock_provider = MockStreamingProvider::new();
    let first_stream_completed = mock_provider.first_stream_all_chunks_sent.clone();

    // Create provider registry with mock
    let mut registry = ProviderRegistry::new();
    registry.register(mock_provider);

    // Create tracking UI writer
    let tool_calls_seen = Arc::new(AtomicUsize::new(0));
    let ui_writer = TrackingUiWriter::new(tool_calls_seen.clone());

    // Create agent with mock provider
    let config = g3_config::Config::default();
    let mut agent = Agent::new_for_test(config, ui_writer.clone(), registry)
        .await
        .unwrap();

    // Execute a task - this should process ALL chunks before returning
    let result = agent.execute_task("test task", None, false).await;

    // The task may complete or error (due to auto-continue logic); either is fine here
    let _ = result;

    // CRITICAL ASSERTION 1: The first stream must have sent all its chunks
    assert!(
        first_stream_completed.load(Ordering::SeqCst),
        "First stream did not complete sending all chunks - control may have returned early"
    );

    // CRITICAL ASSERTION 2: Both tool calls from the first stream must have been processed
    let tool_count = ui_writer.tool_call_count();
    assert!(
        tool_count >= 2,
        "Expected at least 2 tool calls to be processed, but only {} were seen. \
         This indicates control was returned after the first tool call.",
        tool_count
    );
}

/// Test that the finished signal (chunk.finished = true) properly terminates
/// the stream processing loop. (A timeline sketch appears at the end of this file.)
#[tokio::test]
#[serial]
async fn test_finished_signal_terminates_stream() {
    let temp_dir = TempDir::new().unwrap();
    std::env::set_current_dir(temp_dir.path()).unwrap();

    // Create a simpler mock that just sends text and finishes
    struct SimpleFinishProvider {
        post_finish_chunk_processed: Arc<AtomicBool>,
    }

    #[async_trait]
    impl LLMProvider for SimpleFinishProvider {
        async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                content: String::new(),
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
                model: "simple".to_string(),
            })
        }

        async fn stream(&self, _request: CompletionRequest) -> Result<CompletionStream> {
            let (tx, rx) = mpsc::channel(32);
            let post_finish_flag = self.post_finish_chunk_processed.clone();

            tokio::spawn(async move {
                // Send some text
                let _ = tx
                    .send(Ok(CompletionChunk {
                        content: "Hello, this is a test response.".to_string(),
                        finished: false,
                        tool_calls: None,
                        usage: None,
                        stop_reason: None,
                        tool_call_streaming: None,
                    }))
                    .await;

                // Send finished signal
                let _ = tx
                    .send(Ok(CompletionChunk {
                        content: String::new(),
                        finished: true,
                        tool_calls: None,
                        usage: Some(Usage {
                            prompt_tokens: 10,
                            completion_tokens: 10,
                            total_tokens: 20,
                        }),
                        stop_reason: Some("end_turn".to_string()),
                        tool_call_streaming: None,
                    }))
                    .await;

                // Wait a bit, then send another chunk (it should never be processed)
                tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;

                // A successful send here only proves the channel is still open; whether
                // the chunk was actually PROCESSED is verified below by checking the
                // recorded responses for the marker string.
                if tx
                    .send(Ok(CompletionChunk {
                        content: "THIS_SHOULD_NOT_APPEAR".to_string(),
                        finished: false,
                        tool_calls: None,
                        usage: None,
                        stop_reason: None,
                        tool_call_streaming: None,
                    }))
                    .await
                    .is_ok()
                {
                    // Channel still open. That alone doesn't prove the chunk was
                    // processed; the assertions at the end of the test check the
                    // recorded responses instead.
                }

                // Signal that the sender task ran to completion, including the
                // post-finish send attempt above.
                post_finish_flag.store(true, Ordering::SeqCst);
            });

            Ok(tokio_stream::wrappers::ReceiverStream::new(rx))
        }

        fn name(&self) -> &str {
            "simple"
        }
        fn model(&self) -> &str {
            "simple-model"
        }
        fn has_native_tool_calling(&self) -> bool {
            false
        }
        fn supports_cache_control(&self) -> bool {
            false
        }
        fn max_tokens(&self) -> u32 {
            4096
        }
        fn temperature(&self) -> f32 {
            0.0
        }
    }

    let post_finish_flag = Arc::new(AtomicBool::new(false));
    let provider = SimpleFinishProvider {
        post_finish_chunk_processed: post_finish_flag.clone(),
    };

    let mut registry = ProviderRegistry::new();
    registry.register(provider);

    let tool_calls_seen = Arc::new(AtomicUsize::new(0));
    let ui_writer = TrackingUiWriter::new(tool_calls_seen);
    let config = g3_config::Config::default();
    let mut agent = Agent::new_for_test(config, ui_writer.clone(), registry)
        .await
        .unwrap();

    let result = agent.execute_task("test", None, false).await;

    assert!(result.is_ok(), "Task should complete successfully");

    // Verify the post-finish content was NOT processed
    let responses = ui_writer.responses();
    let all_responses = responses.join("");

    assert!(
        !all_responses.contains("THIS_SHOULD_NOT_APPEAR"),
        "Content after finished signal should not be processed. Got: {}",
        all_responses
    );
    assert!(
        all_responses.contains("Hello"),
        "Content before finished signal should be processed. Got: {}",
        all_responses
    );
}
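
// =============================================================================
// Timeline sketch for test_finished_signal_terminates_stream (illustrative)
// =============================================================================
//
//   t = 0     send TEXT "Hello, this is a test response."  -> must appear in responses
//   t = 0     send FINISHED (finished: true)               -> processing loop must stop here
//   t = 50ms  send TEXT "THIS_SHOULD_NOT_APPEAR"           -> must NOT appear in responses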