diff --git a/crates/g3-core/tests/mock_provider_integration_test.rs b/crates/g3-core/tests/mock_provider_integration_test.rs
new file mode 100644
index 0000000..a00174b
--- /dev/null
+++ b/crates/g3-core/tests/mock_provider_integration_test.rs
@@ -0,0 +1,228 @@
+//! Integration tests using MockProvider
+//!
+//! These tests use the mock provider to exercise real code paths in
+//! stream_completion_with_tools without needing a real LLM.
+
+use g3_core::ui_writer::NullUiWriter;
+use g3_core::Agent;
+use g3_providers::mock::{MockProvider, MockResponse};
+use g3_providers::{Message, MessageRole, ProviderRegistry};
+use tempfile::TempDir;
+
+/// Helper to create an agent with a mock provider
+async fn create_agent_with_mock(provider: MockProvider) -> (Agent, TempDir) {
+    let temp_dir = TempDir::new().unwrap();
+
+    // Create a provider registry with the mock provider
+    let mut registry = ProviderRegistry::new();
+    registry.register(provider);
+
+    // Create a minimal config
+    let config = g3_config::Config::default();
+
+    let agent = Agent::new_for_test(
+        config,
+        NullUiWriter,
+        registry,
+    ).await.expect("Failed to create agent");
+
+    (agent, temp_dir)
+}
+
+/// Helper to count messages by role
+fn count_by_role(history: &[Message], role: MessageRole) -> usize {
+    history.iter().filter(|m| std::mem::discriminant(&m.role) == std::mem::discriminant(&role)).count()
+}
+
+/// Helper to check for consecutive user messages
+fn has_consecutive_user_messages(history: &[Message]) -> Option<(usize, usize)> {
+    for i in 0..history.len().saturating_sub(1) {
+        if matches!(history[i].role, MessageRole::User)
+            && matches!(history[i + 1].role, MessageRole::User)
+        {
+            return Some((i, i + 1));
+        }
+    }
+    None
+}
+
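+// The invariant these tests rely on: a healthy conversation history alternates
+// roles, e.g. (hypothetical) [User, Assistant, User, Assistant, ...]. If an
+// assistant reply is dropped, two User entries end up adjacent, which is
+// exactly the pattern `has_consecutive_user_messages` detects.
+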
+/// Test: Text-only response saves assistant message to context
+///
+/// This is the exact bug scenario from the butler session:
+/// - User sends a message
+/// - LLM responds with text only (no tool calls)
+/// - Assistant message should be saved to context window
+#[tokio::test]
+async fn test_text_only_response_saves_to_context() {
+    let provider = MockProvider::new()
+        .with_response(MockResponse::text("Hello! I'm here to help."));
+
+    let (mut agent, _temp_dir) = create_agent_with_mock(provider).await;
+
+    // Get initial message count
+    let initial_count = agent.get_context_window().conversation_history.len();
+
+    // Execute a task (this adds the user message and gets a response)
+    let result = agent.execute_task("Hello", None, false).await;
+    assert!(result.is_ok(), "Task should succeed: {:?}", result.err());
+
+    // Check that messages were added
+    let final_count = agent.get_context_window().conversation_history.len();
+    assert!(
+        final_count > initial_count,
+        "Should have more messages after task, got {} -> {}",
+        initial_count,
+        final_count
+    );
+
+    // Verify the last message is from the assistant
+    let history = &agent.get_context_window().conversation_history;
+    let last_msg = history.last().unwrap();
+    assert!(
+        matches!(last_msg.role, MessageRole::Assistant),
+        "Last message should be assistant, got {:?}",
+        last_msg.role
+    );
+}
+
+/// Test: Multiple text-only responses maintain proper alternation
+#[tokio::test]
+async fn test_multi_turn_text_only_maintains_alternation() {
+    let provider = MockProvider::new().with_responses(vec![
+        MockResponse::text("First response"),
+        MockResponse::text("Second response"),
+        MockResponse::text("Third response"),
+    ]);
+
+    let (mut agent, _temp_dir) = create_agent_with_mock(provider).await;
+
+    // Execute three tasks
+    agent.execute_task("First question", None, false).await.unwrap();
+    agent.execute_task("Second question", None, false).await.unwrap();
+    agent.execute_task("Third question", None, false).await.unwrap();
+
+    // Verify no consecutive user messages
+    let history = &agent.get_context_window().conversation_history;
+
+    if let Some((i, j)) = has_consecutive_user_messages(history) {
+        // Print debug info
+        eprintln!("\n=== BUG: Consecutive user messages ===");
+        for (idx, msg) in history.iter().enumerate() {
+            let marker = if idx == i || idx == j { ">>>" } else { "   " };
+            eprintln!("{} {}: {:?} - {}...",
+                marker, idx, msg.role,
+                msg.content.chars().take(50).collect::<String>()
+            );
+        }
+        panic!("Found consecutive user messages at positions {} and {}", i, j);
+    }
+}
+
+/// Test: Streaming response with multiple chunks saves correctly
+#[tokio::test]
+async fn test_streaming_chunks_save_complete_response() {
+    let provider = MockProvider::new()
+        .with_response(MockResponse::streaming(vec!["Hello ", "world ", "from ", "streaming!"]));
+
+    let (mut agent, _temp_dir) = create_agent_with_mock(provider).await;
+
+    agent.execute_task("Test streaming", None, false).await.unwrap();
+
+    // Find the assistant message
+    let history = &agent.get_context_window().conversation_history;
+    let assistant_msg = history
+        .iter()
+        .rev()
+        .find(|m| matches!(m.role, MessageRole::Assistant))
+        .expect("Should have an assistant message");
+
+    // The complete streamed content should be saved
+    assert!(
+        assistant_msg.content.contains("Hello")
+            && assistant_msg.content.contains("streaming"),
+        "Should contain full streamed content: {}",
+        assistant_msg.content
+    );
+}
+
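+// Note: the mock provider (see g3-providers/src/mock.rs) sleeps ~100µs between
+// chunks, so the streaming test above exercises the same incremental
+// accumulation path a real provider would, just much faster.
+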
+/// Test: Truncated response (max_tokens) still saves
+#[tokio::test]
+async fn test_truncated_response_saves() {
+    let provider = MockProvider::new()
+        .with_response(MockResponse::truncated("This response was cut off mid-sent"));
+
+    let (mut agent, _temp_dir) = create_agent_with_mock(provider).await;
+
+    agent.execute_task("Generate a long response", None, false).await.unwrap();
+
+    // Find the assistant message
+    let history = &agent.get_context_window().conversation_history;
+    let assistant_msg = history
+        .iter()
+        .rev()
+        .find(|m| matches!(m.role, MessageRole::Assistant))
+        .expect("Should have an assistant message");
+
+    assert!(
+        assistant_msg.content.contains("cut off"),
+        "Should save truncated content: {}",
+        assistant_msg.content
+    );
+}
+
+/// Test: The exact butler bug scenario
+///
+/// Scenario:
+/// 1. User sends a message
+/// 2. LLM responds with text (no tools) - this was NOT being saved
+/// 3. User sends another message
+/// 4. Result: consecutive user messages in context (BUG)
+#[tokio::test]
+async fn test_butler_bug_scenario() {
+    let provider = MockProvider::new().with_responses(vec![
+        MockResponse::text("Phew! 😅 Glad it's back. Sorry about that - direct SQLite manipulation was too risky."),
+        MockResponse::text("Yes, tasks with subtasks is a much safer approach!"),
+    ]);
+
+    let (mut agent, _temp_dir) = create_agent_with_mock(provider).await;
+
+    // Simulate the butler session:
+    agent.execute_task(
+        "Ok it's back. I have a different solution, instead of headings, what about tasks with inner subtasks?",
+        None,
+        false
+    ).await.unwrap();
+
+    agent.execute_task(
+        "yep that's good enough for now",
+        None,
+        false
+    ).await.unwrap();
+
+    // Verify: no consecutive user messages
+    let history = &agent.get_context_window().conversation_history;
+
+    if let Some((i, j)) = has_consecutive_user_messages(history) {
+        // Print debug info
+        eprintln!("\n=== BUG DETECTED: Consecutive user messages ===");
+        for (idx, msg) in history.iter().enumerate() {
+            let marker = if idx == i || idx == j { ">>>" } else { "   " };
+            eprintln!("{} {}: {:?} - {}...",
+                marker, idx, msg.role,
+                msg.content.chars().take(50).collect::<String>()
+            );
+        }
+        panic!(
+            "Found consecutive user messages at positions {} and {}",
+            i, j
+        );
+    }
+
+    // Also verify we have the expected assistant responses
+    let assistant_count = count_by_role(history, MessageRole::Assistant);
+    assert!(
+        assistant_count >= 2,
+        "Should have at least 2 assistant messages, got {}",
+        assistant_count
+    );
+}
diff --git a/crates/g3-providers/src/lib.rs b/crates/g3-providers/src/lib.rs
index ccfc6f5..f34860b 100644
--- a/crates/g3-providers/src/lib.rs
+++ b/crates/g3-providers/src/lib.rs
@@ -1,4 +1,7 @@
 mod streaming;
+pub mod mock;
+pub use mock::{MockProvider, MockResponse, MockChunk};
+
 pub use streaming::{decode_utf8_streaming, is_incomplete_json_error, make_final_chunk, make_text_chunk, make_tool_chunk};
 
 use anyhow::Result;
diff --git a/crates/g3-providers/src/mock.rs b/crates/g3-providers/src/mock.rs
new file mode 100644
index 0000000..40571bf
--- /dev/null
+++ b/crates/g3-providers/src/mock.rs
@@ -0,0 +1,596 @@
+#![allow(dead_code)]
+//! Mock LLM Provider for Testing
+//!
+//! This module provides a configurable mock provider that can simulate
+//! various LLM behaviors for integration testing. It allows precise control
+//! over streaming chunks, tool calls, and response patterns.
+//!
+//! # Example
+//!
+//! ```rust,ignore
+//! use g3_providers::mock::{MockProvider, MockResponse};
+//!
+//! // Simple text-only response
+//! let provider = MockProvider::new()
+//!     .with_response(MockResponse::text("Hello, world!"));
+//!
+//! // Response with a native tool call
+//! let provider = MockProvider::new()
+//!     .with_response(MockResponse::native_tool_call("shell", json!({"command": "ls"})));
+//!
+//! // Multi-chunk streaming response
+//! let provider = MockProvider::new()
+//!     .with_response(MockResponse::streaming(vec![
+//!         "Hello, ",
+//!         "world!",
+//!     ]));
+//! ```
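+//!
+//! Recorded requests can be inspected after the fact. A sketch using the
+//! accessors defined below (`get_requests`, `request_count`):
+//!
+//! ```rust,ignore
+//! let provider = MockProvider::new()
+//!     .with_default_response(MockResponse::text("OK"));
+//! // ... drive the code under test, then verify what was sent:
+//! assert_eq!(provider.request_count(), 1);
+//! let sent = provider.get_requests();
+//! assert_eq!(sent[0].messages.len(), 1);
+//! ```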
+
+use crate::{
+    CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider,
+    ToolCall, Usage,
+};
+use anyhow::Result;
+use std::sync::{Arc, Mutex};
+use std::sync::atomic::{AtomicU64, Ordering};
+use tokio::sync::mpsc;
+use tokio_stream::wrappers::ReceiverStream;
+
+/// Global counter for generating unique tool call IDs
+static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(1);
+
+/// A mock response that can be configured for testing
+#[derive(Debug, Clone)]
+pub struct MockResponse {
+    /// Chunks to stream (content, finished, tool_calls, stop_reason)
+    pub chunks: Vec<MockChunk>,
+    /// Usage stats to report
+    pub usage: Usage,
+}
+
+/// A single chunk in a mock streaming response
+#[derive(Debug, Clone)]
+pub struct MockChunk {
+    pub content: String,
+    pub finished: bool,
+    pub tool_calls: Option<Vec<ToolCall>>,
+    pub stop_reason: Option<String>,
+    pub tool_call_streaming: Option<String>,
+}
+
+impl MockChunk {
+    /// Create a content chunk (not finished)
+    pub fn content(text: &str) -> Self {
+        Self {
+            content: text.to_string(),
+            finished: false,
+            tool_calls: None,
+            stop_reason: None,
+            tool_call_streaming: None,
+        }
+    }
+
+    /// Create a final chunk with a stop reason
+    pub fn finished(stop_reason: &str) -> Self {
+        Self {
+            content: String::new(),
+            finished: true,
+            tool_calls: None,
+            stop_reason: Some(stop_reason.to_string()),
+            tool_call_streaming: None,
+        }
+    }
+
+    /// Create a chunk with a tool call
+    pub fn tool_call(tool: &str, args: serde_json::Value) -> Self {
+        Self {
+            content: String::new(),
+            finished: false,
+            tool_calls: Some(vec![ToolCall {
+                id: format!("tool_{}", TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst)),
+                tool: tool.to_string(),
+                args,
+            }]),
+            stop_reason: None,
+            tool_call_streaming: None,
+        }
+    }
+
+    /// Create a chunk indicating a tool call is streaming (for UI hint)
+    pub fn tool_streaming(tool_name: &str) -> Self {
+        Self {
+            content: String::new(),
+            finished: false,
+            tool_calls: None,
+            stop_reason: None,
+            tool_call_streaming: Some(tool_name.to_string()),
+        }
+    }
+}
+
+impl MockResponse {
+    /// Create a simple text-only response (single chunk + finish)
+    pub fn text(content: &str) -> Self {
+        Self {
+            chunks: vec![
+                MockChunk::content(content),
+                MockChunk::finished("end_turn"),
+            ],
+            usage: Usage {
+                prompt_tokens: 100,
+                completion_tokens: content.len() as u32 / 4,
+                total_tokens: 100 + content.len() as u32 / 4,
+            },
+        }
+    }
+
+    /// Create a streaming text response with multiple chunks
+    pub fn streaming(chunks: Vec<&str>) -> Self {
+        let total_content: String = chunks.iter().copied().collect();
+        let mut mock_chunks: Vec<MockChunk> = chunks
+            .into_iter()
+            .map(MockChunk::content)
+            .collect();
+        mock_chunks.push(MockChunk::finished("end_turn"));
+
+        Self {
+            chunks: mock_chunks,
+            usage: Usage {
+                prompt_tokens: 100,
+                completion_tokens: total_content.len() as u32 / 4,
+                total_tokens: 100 + total_content.len() as u32 / 4,
+            },
+        }
+    }
+
+    /// Create a response with a native tool call
+    pub fn native_tool_call(tool: &str, args: serde_json::Value) -> Self {
+        Self {
+            chunks: vec![
+                MockChunk::tool_streaming(tool),
+                MockChunk::tool_call(tool, args),
+                MockChunk::finished("tool_use"),
+            ],
+            usage: Usage {
+                prompt_tokens: 100,
+                completion_tokens: 50,
+                total_tokens: 150,
+            },
+        }
+    }
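+
+    // The chunk sequence above (streaming hint, then the tool call, then a
+    // finish with stop_reason "tool_use") is the shape this mock uses to
+    // approximate a native tool-calling provider; consumers should not assume
+    // the tool call and the finish arrive in the same chunk.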
+
+    /// Create a response with text followed by a JSON tool call (non-native)
+    pub fn text_with_json_tool(text: &str, tool: &str, args: serde_json::Value) -> Self {
+        let tool_json = serde_json::json!({
+            "tool": tool,
+            "args": args
+        });
+        let tool_str = serde_json::to_string(&tool_json).unwrap();
+        let full_content = format!("{}\n\n{}", text, tool_str);
+
+        Self {
+            chunks: vec![
+                MockChunk::content(text),
+                MockChunk::content("\n\n"),
+                MockChunk::content(&tool_str),
+                MockChunk::finished("end_turn"),
+            ],
+            usage: Usage {
+                prompt_tokens: 100,
+                completion_tokens: full_content.len() as u32 / 4,
+                total_tokens: 100 + full_content.len() as u32 / 4,
+            },
+        }
+    }
+
+    /// Create a response that gets cut off by max_tokens
+    pub fn truncated(content: &str) -> Self {
+        Self {
+            chunks: vec![
+                MockChunk::content(content),
+                MockChunk::finished("max_tokens"),
+            ],
+            usage: Usage {
+                prompt_tokens: 100,
+                completion_tokens: content.len() as u32 / 4,
+                total_tokens: 100 + content.len() as u32 / 4,
+            },
+        }
+    }
+
+    /// Create a custom response with explicit chunks
+    pub fn custom(chunks: Vec<MockChunk>, usage: Usage) -> Self {
+        Self { chunks, usage }
+    }
+
+    /// Builder: set custom usage
+    pub fn with_usage(mut self, usage: Usage) -> Self {
+        self.usage = usage;
+        self
+    }
+}
+
+/// A mock LLM provider for testing
+///
+/// The provider maintains a queue of responses that are returned in order.
+/// It also tracks all requests made, for verification in tests.
+pub struct MockProvider {
+    name: String,
+    model: String,
+    max_tokens: u32,
+    temperature: f32,
+    native_tool_calling: bool,
+    /// Queue of responses to return (FIFO)
+    responses: Arc<Mutex<Vec<MockResponse>>>,
+    /// All requests received (for verification)
+    requests: Arc<Mutex<Vec<CompletionRequest>>>,
+    /// Default response when the queue is empty
+    default_response: Option<MockResponse>,
+}
+
+impl MockProvider {
+    /// Create a new mock provider with default settings
+    pub fn new() -> Self {
+        Self {
+            name: "mock".to_string(),
+            model: "mock-model".to_string(),
+            max_tokens: 4096,
+            temperature: 0.7,
+            native_tool_calling: false,
+            responses: Arc::new(Mutex::new(Vec::new())),
+            requests: Arc::new(Mutex::new(Vec::new())),
+            default_response: None,
+        }
+    }
+
+    /// Set the provider name
+    pub fn with_name(mut self, name: &str) -> Self {
+        self.name = name.to_string();
+        self
+    }
+
+    /// Set the model name
+    pub fn with_model(mut self, model: &str) -> Self {
+        self.model = model.to_string();
+        self
+    }
+
+    /// Set max tokens
+    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
+        self.max_tokens = max_tokens;
+        self
+    }
+
+    /// Set temperature
+    pub fn with_temperature(mut self, temperature: f32) -> Self {
+        self.temperature = temperature;
+        self
+    }
+
+    /// Enable native tool calling
+    pub fn with_native_tool_calling(mut self, enabled: bool) -> Self {
+        self.native_tool_calling = enabled;
+        self
+    }
+
+    /// Add a response to the queue
+    pub fn with_response(self, response: MockResponse) -> Self {
+        self.responses.lock().unwrap().push(response);
+        self
+    }
+
+    /// Add multiple responses to the queue
+    pub fn with_responses(self, responses: Vec<MockResponse>) -> Self {
+        self.responses.lock().unwrap().extend(responses);
+        self
+    }
+
+    /// Set a default response for when the queue is empty
+    pub fn with_default_response(mut self, response: MockResponse) -> Self {
+        self.default_response = Some(response);
+        self
+    }
+
+    /// Get all requests that were made to this provider
+    pub fn get_requests(&self) -> Vec<CompletionRequest> {
+        self.requests.lock().unwrap().clone()
+    }
+
+    /// Get the number of requests made
+    pub fn request_count(&self) -> usize {
+        self.requests.lock().unwrap().len()
+    }
+
+    /// Clear recorded requests
+    pub fn clear_requests(&self) {
+        self.requests.lock().unwrap().clear();
+    }
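+
+    // Queue semantics: responses are consumed first-in-first-out, one per
+    // complete()/stream() call; once the queue is empty, the default response
+    // (if configured) is returned for every subsequent call.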
+
+    /// Get the next response from the queue (or the default)
+    fn next_response(&self) -> MockResponse {
+        let mut responses = self.responses.lock().unwrap();
+        if responses.is_empty() {
+            self.default_response
+                .clone()
+                .unwrap_or_else(|| MockResponse::text("Mock response (no responses configured)"))
+        } else {
+            responses.remove(0)
+        }
+    }
+}
+
+impl Default for MockProvider {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+#[async_trait::async_trait]
+impl LLMProvider for MockProvider {
+    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
+        // Record the request
+        self.requests.lock().unwrap().push(request);
+
+        let response = self.next_response();
+
+        // Combine all chunk content for the non-streaming response
+        let content: String = response
+            .chunks
+            .iter()
+            .map(|c| c.content.as_str())
+            .collect();
+
+        Ok(CompletionResponse {
+            content,
+            usage: response.usage,
+            model: self.model.clone(),
+        })
+    }
+
+    async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
+        // Record the request
+        self.requests.lock().unwrap().push(request);
+
+        let response = self.next_response();
+        let usage = response.usage.clone();
+
+        // Create a channel for streaming
+        let (tx, rx) = mpsc::channel(32);
+        let num_chunks = response.chunks.len();
+
+        // Spawn a task to send the chunks
+        tokio::spawn(async move {
+            for (i, chunk) in response.chunks.into_iter().enumerate() {
+                let is_last = chunk.finished;
+                let completion_chunk = CompletionChunk {
+                    content: chunk.content,
+                    finished: chunk.finished,
+                    tool_calls: chunk.tool_calls,
+                    usage: if is_last { Some(usage.clone()) } else { None },
+                    stop_reason: chunk.stop_reason,
+                    tool_call_streaming: chunk.tool_call_streaming,
+                };
+
+                if tx.send(Ok(completion_chunk)).await.is_err() {
+                    // Receiver dropped, stop sending
+                    break;
+                }
+
+                // Small delay between chunks to simulate streaming
+                if i < num_chunks - 1 {
+                    tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
+                }
+            }
+        });
+
+        Ok(ReceiverStream::new(rx))
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn model(&self) -> &str {
+        &self.model
+    }
+
+    fn has_native_tool_calling(&self) -> bool {
+        self.native_tool_calling
+    }
+
+    fn max_tokens(&self) -> u32 {
+        self.max_tokens
+    }
+
+    fn temperature(&self) -> f32 {
+        self.temperature
+    }
+}
+
+// ============================================================================
+// Preset Scenarios for Common Test Cases
+// ============================================================================
+
+/// Preset scenarios for common testing patterns
+pub mod scenarios {
+    use super::*;
+
+    /// Create a provider that returns a simple text response.
+    /// This simulates the bug scenario where text-only responses weren't saved.
+    pub fn text_only_response(text: &str) -> MockProvider {
+        MockProvider::new().with_response(MockResponse::text(text))
+    }
+
+    /// Create a provider that returns text followed by a tool call
+    pub fn text_then_tool(text: &str, tool: &str, args: serde_json::Value) -> MockProvider {
+        MockProvider::new().with_response(MockResponse::text_with_json_tool(text, tool, args))
+    }
+
+    /// Create a provider for a multi-turn conversation.
+    /// Each call returns the next response in sequence.
+    pub fn multi_turn(responses: Vec<&str>) -> MockProvider {
+        let mock_responses: Vec<MockResponse> =
+            responses.into_iter().map(MockResponse::text).collect();
+        MockProvider::new().with_responses(mock_responses)
+    }
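+
+    // Hypothetical usage: multi_turn(vec!["a", "b"]) yields "a" for the first
+    // completion and "b" for the second, since each call pops the next queued
+    // response.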
+
+    /// Create a provider that simulates a tool execution flow:
+    /// 1. First call: returns a tool call
+    /// 2. Second call: returns a text response after the tool result
+    pub fn tool_then_response(
+        tool: &str,
+        args: serde_json::Value,
+        final_response: &str,
+    ) -> MockProvider {
+        MockProvider::new()
+            .with_native_tool_calling(true)
+            .with_responses(vec![
+                MockResponse::native_tool_call(tool, args),
+                MockResponse::text(final_response),
+            ])
+    }
+
+    /// Create a provider that returns a truncated response (max_tokens hit)
+    pub fn truncated_response(partial_content: &str) -> MockProvider {
+        MockProvider::new().with_response(MockResponse::truncated(partial_content))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use tokio_stream::StreamExt;
+
+    #[tokio::test]
+    async fn test_mock_provider_text_response() {
+        let provider = MockProvider::new().with_response(MockResponse::text("Hello, world!"));
+
+        let request = CompletionRequest {
+            messages: vec![],
+            max_tokens: None,
+            temperature: None,
+            stream: false,
+            tools: None,
+            disable_thinking: false,
+        };
+
+        let response = provider.complete(request).await.unwrap();
+        assert_eq!(response.content, "Hello, world!");
+        assert_eq!(provider.request_count(), 1);
+    }
+
+    #[tokio::test]
+    async fn test_mock_provider_streaming() {
+        let provider =
+            MockProvider::new().with_response(MockResponse::streaming(vec!["Hello, ", "world!"]));
+
+        let request = CompletionRequest {
+            messages: vec![],
+            max_tokens: None,
+            temperature: None,
+            stream: true,
+            tools: None,
+            disable_thinking: false,
+        };
+
+        let mut stream = provider.stream(request).await.unwrap();
+
+        let mut content = String::new();
+        let mut chunk_count = 0;
+        while let Some(chunk) = stream.next().await {
+            let chunk = chunk.unwrap();
+            content.push_str(&chunk.content);
+            chunk_count += 1;
+        }
+
+        assert_eq!(content, "Hello, world!");
+        assert_eq!(chunk_count, 3); // 2 content chunks + 1 finish chunk
+    }
+
+    #[tokio::test]
+    async fn test_mock_provider_multi_turn() {
+        let provider = scenarios::multi_turn(vec!["First response", "Second response"]);
+
+        let request = CompletionRequest {
+            messages: vec![],
+            max_tokens: None,
+            temperature: None,
+            stream: false,
+            tools: None,
+            disable_thinking: false,
+        };
+
+        let response1 = provider.complete(request.clone()).await.unwrap();
+        assert_eq!(response1.content, "First response");
+
+        let response2 = provider.complete(request).await.unwrap();
+        assert_eq!(response2.content, "Second response");
+    }
+
+    #[tokio::test]
+    async fn test_mock_provider_tool_call() {
+        let provider = MockProvider::new()
+            .with_native_tool_calling(true)
+            .with_response(MockResponse::native_tool_call(
+                "shell",
+                serde_json::json!({"command": "ls"}),
+            ));
+
+        let request = CompletionRequest {
+            messages: vec![],
+            max_tokens: None,
+            temperature: None,
+            stream: true,
+            tools: None,
+            disable_thinking: false,
+        };
+
+        let mut stream = provider.stream(request).await.unwrap();
+
+        let mut found_tool_call = false;
+        while let Some(chunk) = stream.next().await {
+            let chunk = chunk.unwrap();
+            if let Some(tool_calls) = chunk.tool_calls {
+                assert_eq!(tool_calls.len(), 1);
+                assert_eq!(tool_calls[0].tool, "shell");
+                found_tool_call = true;
+            }
+        }
+
+        assert!(found_tool_call, "Should have received a tool call");
+    }
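+
+    // Both complete() and stream() record the request before producing a
+    // response, so request tracking (below) works for either path.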
+
+    #[tokio::test]
+    async fn test_mock_provider_request_tracking() {
+        let provider = MockProvider::new().with_default_response(MockResponse::text("OK"));
+
+        let request1 = CompletionRequest {
+            messages: vec![crate::Message::new(crate::MessageRole::User, "Hello".to_string())],
+            max_tokens: Some(100),
+            temperature: None,
+            stream: false,
+            tools: None,
+            disable_thinking: false,
+        };
+
+        let request2 = CompletionRequest {
+            messages: vec![crate::Message::new(
+                crate::MessageRole::User,
+                "World".to_string(),
+            )],
+            max_tokens: Some(200),
+            temperature: None,
+            stream: false,
+            tools: None,
+            disable_thinking: false,
+        };
+
+        provider.complete(request1).await.unwrap();
+        provider.complete(request2).await.unwrap();
+
+        let requests = provider.get_requests();
+        assert_eq!(requests.len(), 2);
+        assert_eq!(requests[0].max_tokens, Some(100));
+        assert_eq!(requests[1].max_tokens, Some(200));
+    }
+}