Add MockProvider for integration testing

Adds a configurable mock LLM provider that can simulate various behaviors:
- Text-only responses (single or multi-chunk streaming)
- Native tool calls
- JSON tool calls in text
- Truncated responses (max_tokens)
- Multi-turn conversations

Features:
- Builder pattern for easy test setup
- Request tracking for verification
- Preset scenarios for common patterns
- Full LLMProvider trait implementation

Also adds integration tests that use MockProvider to test the
stream_completion_with_tools code path, including:
- test_butler_bug_scenario: reproduces the exact bug where text-only
  responses were not saved to context, causing consecutive user messages

This enables testing complex streaming behaviors without real API calls.
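
A typical setup looks roughly like this (a sketch of the builder API
added here; see the integration tests for the full Agent wiring):

    use g3_providers::mock::{MockProvider, MockResponse};

    // Responses are returned in FIFO order; the default response (if
    // set) is used once the queue runs dry.
    let provider = MockProvider::new()
        .with_response(MockResponse::text("First reply"))
        .with_response(MockResponse::native_tool_call(
            "shell",
            serde_json::json!({"command": "ls"}),
        ))
        .with_default_response(MockResponse::text("OK"));

    // Nothing has been requested yet, so the tracker is empty.
    assert_eq!(provider.request_count(), 0);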
Author: Dhanji R. Prasanna
Date:   2026-01-19 13:59:31 +05:30
Parent: 349230d0b7
Commit: 292a3aa48d
3 changed files with 827 additions and 0 deletions

@@ -0,0 +1,228 @@
//! Integration tests using MockProvider
//!
//! These tests use the mock provider to exercise real code paths in
//! stream_completion_with_tools without needing a real LLM.

use g3_core::ui_writer::NullUiWriter;
use g3_core::Agent;
use g3_providers::mock::{MockProvider, MockResponse};
use g3_providers::{Message, MessageRole, ProviderRegistry};
use tempfile::TempDir;

/// Helper to create an agent with a mock provider
async fn create_agent_with_mock(provider: MockProvider) -> (Agent<NullUiWriter>, TempDir) {
    let temp_dir = TempDir::new().unwrap();

    // Create a provider registry with the mock provider
    let mut registry = ProviderRegistry::new();
    registry.register(provider);

    // Create a minimal config
    let config = g3_config::Config::default();

    let agent = Agent::new_for_test(config, NullUiWriter, registry)
        .await
        .expect("Failed to create agent");

    (agent, temp_dir)
}

/// Helper to count messages by role
fn count_by_role(history: &[Message], role: MessageRole) -> usize {
    history
        .iter()
        .filter(|m| std::mem::discriminant(&m.role) == std::mem::discriminant(&role))
        .count()
}

/// Helper to check for consecutive user messages
fn has_consecutive_user_messages(history: &[Message]) -> Option<(usize, usize)> {
    for i in 0..history.len().saturating_sub(1) {
        if matches!(history[i].role, MessageRole::User)
            && matches!(history[i + 1].role, MessageRole::User)
        {
            return Some((i, i + 1));
        }
    }
    None
}

/// Test: Text-only response saves assistant message to context
///
/// This is the exact bug scenario from the butler session:
/// - User sends a message
/// - LLM responds with text only (no tool calls)
/// - Assistant message should be saved to context window
#[tokio::test]
async fn test_text_only_response_saves_to_context() {
    let provider =
        MockProvider::new().with_response(MockResponse::text("Hello! I'm here to help."));
    let (mut agent, _temp_dir) = create_agent_with_mock(provider).await;

    // Get initial message count
    let initial_count = agent.get_context_window().conversation_history.len();

    // Execute a task (this adds user message and gets response)
    let result = agent.execute_task("Hello", None, false).await;
    assert!(result.is_ok(), "Task should succeed: {:?}", result.err());

    // Check that messages were added
    let final_count = agent.get_context_window().conversation_history.len();
    assert!(
        final_count > initial_count,
        "Should have more messages after task, got {} -> {}",
        initial_count,
        final_count
    );

    // Verify the last message is from assistant
    let history = &agent.get_context_window().conversation_history;
    let last_msg = history.last().unwrap();
    assert!(
        matches!(last_msg.role, MessageRole::Assistant),
        "Last message should be assistant, got {:?}",
        last_msg.role
    );
}

/// Test: Multiple text-only responses maintain proper alternation
#[tokio::test]
async fn test_multi_turn_text_only_maintains_alternation() {
    let provider = MockProvider::new().with_responses(vec![
        MockResponse::text("First response"),
        MockResponse::text("Second response"),
        MockResponse::text("Third response"),
    ]);
    let (mut agent, _temp_dir) = create_agent_with_mock(provider).await;

    // Execute three tasks
    agent.execute_task("First question", None, false).await.unwrap();
    agent.execute_task("Second question", None, false).await.unwrap();
    agent.execute_task("Third question", None, false).await.unwrap();

    // Verify no consecutive user messages
    let history = &agent.get_context_window().conversation_history;
    if let Some((i, j)) = has_consecutive_user_messages(history) {
        // Print debug info
        eprintln!("\n=== BUG: Consecutive user messages ===");
        for (idx, msg) in history.iter().enumerate() {
            let marker = if idx == i || idx == j { ">>>" } else { "   " };
            eprintln!(
                "{} {}: {:?} - {}...",
                marker,
                idx,
                msg.role,
                msg.content.chars().take(50).collect::<String>()
            );
        }
        panic!("Found consecutive user messages at positions {} and {}", i, j);
    }
}

/// Test: Streaming response with multiple chunks saves correctly
#[tokio::test]
async fn test_streaming_chunks_save_complete_response() {
    let provider = MockProvider::new()
        .with_response(MockResponse::streaming(vec!["Hello ", "world ", "from ", "streaming!"]));
    let (mut agent, _temp_dir) = create_agent_with_mock(provider).await;

    agent.execute_task("Test streaming", None, false).await.unwrap();

    // Find the assistant message
    let history = &agent.get_context_window().conversation_history;
    let assistant_msg = history
        .iter()
        .rev()
        .find(|m| matches!(m.role, MessageRole::Assistant))
        .expect("Should have an assistant message");

    // The complete streamed content should be saved
    assert!(
        assistant_msg.content.contains("Hello")
            && assistant_msg.content.contains("streaming"),
        "Should contain full streamed content: {}",
        assistant_msg.content
    );
}

/// Test: Truncated response (max_tokens) still saves
#[tokio::test]
async fn test_truncated_response_saves() {
    let provider = MockProvider::new()
        .with_response(MockResponse::truncated("This response was cut off mid-sent"));
    let (mut agent, _temp_dir) = create_agent_with_mock(provider).await;

    agent.execute_task("Generate a long response", None, false).await.unwrap();

    // Find the assistant message
    let history = &agent.get_context_window().conversation_history;
    let assistant_msg = history
        .iter()
        .rev()
        .find(|m| matches!(m.role, MessageRole::Assistant))
        .expect("Should have an assistant message");

    assert!(
        assistant_msg.content.contains("cut off"),
        "Should save truncated content: {}",
        assistant_msg.content
    );
}

/// Test: The exact butler bug scenario
///
/// Scenario:
/// 1. User sends message
/// 2. LLM responds with text (no tools) - this was NOT being saved
/// 3. User sends another message
/// 4. Result: consecutive user messages in context (BUG)
#[tokio::test]
async fn test_butler_bug_scenario() {
    let provider = MockProvider::new().with_responses(vec![
        MockResponse::text("Phew! 😅 Glad it's back. Sorry about that - direct SQLite manipulation was too risky."),
        MockResponse::text("Yes, tasks with subtasks is a much safer approach!"),
    ]);
    let (mut agent, _temp_dir) = create_agent_with_mock(provider).await;

    // Simulate the butler session:
    agent
        .execute_task(
            "Ok it's back. I have a different solution, instead of headings, what about tasks with inner subtasks?",
            None,
            false,
        )
        .await
        .unwrap();
    agent
        .execute_task("yep that's good enough for now", None, false)
        .await
        .unwrap();

    // Verify: no consecutive user messages
    let history = &agent.get_context_window().conversation_history;
    if let Some((i, j)) = has_consecutive_user_messages(history) {
        // Print debug info
        eprintln!("\n=== BUG DETECTED: Consecutive user messages ===");
        for (idx, msg) in history.iter().enumerate() {
            let marker = if idx == i || idx == j { ">>>" } else { "   " };
            eprintln!(
                "{} {}: {:?} - {}...",
                marker,
                idx,
                msg.role,
                msg.content.chars().take(50).collect::<String>()
            );
        }
        panic!(
            "Found consecutive user messages at positions {} and {}",
            i, j
        );
    }

    // Also verify we have the expected assistant responses
    let assistant_count = count_by_role(history, MessageRole::Assistant);
    assert!(
        assistant_count >= 2,
        "Should have at least 2 assistant messages, got {}",
        assistant_count
    );
}

@@ -1,4 +1,7 @@
mod streaming;
pub mod mock;
pub use mock::{MockProvider, MockResponse, MockChunk};
pub use streaming::{decode_utf8_streaming, is_incomplete_json_error, make_final_chunk, make_text_chunk, make_tool_chunk};
use anyhow::Result;

@@ -0,0 +1,596 @@
#![allow(dead_code)]
//! Mock LLM Provider for Testing
//!
//! This module provides a configurable mock provider that can simulate
//! various LLM behaviors for integration testing. It allows precise control
//! over streaming chunks, tool calls, and response patterns.
//!
//! # Example
//!
//! ```rust,ignore
//! use g3_providers::mock::{MockProvider, MockResponse};
//!
//! // Simple text-only response
//! let provider = MockProvider::new()
//!     .with_response(MockResponse::text("Hello, world!"));
//!
//! // Response with a native tool call
//! let provider = MockProvider::new()
//!     .with_response(MockResponse::native_tool_call("shell", serde_json::json!({"command": "ls"})));
//!
//! // Multi-chunk streaming response
//! let provider = MockProvider::new()
//!     .with_response(MockResponse::streaming(vec![
//!         "Hello, ",
//!         "world!",
//!     ]));
//! ```

use crate::{
    CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider,
    ToolCall, Usage,
};
use anyhow::Result;
use std::sync::{Arc, Mutex};
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;

/// Global counter for generating unique tool call IDs
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(1);

/// A mock response that can be configured for testing
#[derive(Debug, Clone)]
pub struct MockResponse {
    /// Chunks to stream (content, finished, tool_calls, stop_reason)
    pub chunks: Vec<MockChunk>,
    /// Usage stats to report
    pub usage: Usage,
}

/// A single chunk in a mock streaming response
#[derive(Debug, Clone)]
pub struct MockChunk {
    pub content: String,
    pub finished: bool,
    pub tool_calls: Option<Vec<ToolCall>>,
    pub stop_reason: Option<String>,
    pub tool_call_streaming: Option<String>,
}

impl MockChunk {
    /// Create a content chunk (not finished)
    pub fn content(text: &str) -> Self {
        Self {
            content: text.to_string(),
            finished: false,
            tool_calls: None,
            stop_reason: None,
            tool_call_streaming: None,
        }
    }

    /// Create a final chunk with stop reason
    pub fn finished(stop_reason: &str) -> Self {
        Self {
            content: String::new(),
            finished: true,
            tool_calls: None,
            stop_reason: Some(stop_reason.to_string()),
            tool_call_streaming: None,
        }
    }

    /// Create a chunk with a tool call
    pub fn tool_call(tool: &str, args: serde_json::Value) -> Self {
        Self {
            content: String::new(),
            finished: false,
            tool_calls: Some(vec![ToolCall {
                id: format!("tool_{}", TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst)),
                tool: tool.to_string(),
                args,
            }]),
            stop_reason: None,
            tool_call_streaming: None,
        }
    }

    /// Create a chunk indicating tool call is streaming (for UI hint)
    pub fn tool_streaming(tool_name: &str) -> Self {
        Self {
            content: String::new(),
            finished: false,
            tool_calls: None,
            stop_reason: None,
            tool_call_streaming: Some(tool_name.to_string()),
        }
    }
}

impl MockResponse {
    /// Create a simple text-only response (single chunk + finish)
    pub fn text(content: &str) -> Self {
        Self {
            chunks: vec![
                MockChunk::content(content),
                MockChunk::finished("end_turn"),
            ],
            usage: Usage {
                prompt_tokens: 100,
                completion_tokens: content.len() as u32 / 4,
                total_tokens: 100 + content.len() as u32 / 4,
            },
        }
    }

    /// Create a streaming text response with multiple chunks
    pub fn streaming(chunks: Vec<&str>) -> Self {
        let total_content: String = chunks.iter().copied().collect();
        let mut mock_chunks: Vec<MockChunk> = chunks
            .into_iter()
            .map(MockChunk::content)
            .collect();
        mock_chunks.push(MockChunk::finished("end_turn"));
        Self {
            chunks: mock_chunks,
            usage: Usage {
                prompt_tokens: 100,
                completion_tokens: total_content.len() as u32 / 4,
                total_tokens: 100 + total_content.len() as u32 / 4,
            },
        }
    }

    /// Create a response with a native tool call
    pub fn native_tool_call(tool: &str, args: serde_json::Value) -> Self {
        Self {
            chunks: vec![
                MockChunk::tool_streaming(tool),
                MockChunk::tool_call(tool, args),
                MockChunk::finished("tool_use"),
            ],
            usage: Usage {
                prompt_tokens: 100,
                completion_tokens: 50,
                total_tokens: 150,
            },
        }
    }

    /// Create a response with text followed by a JSON tool call (non-native)
    pub fn text_with_json_tool(text: &str, tool: &str, args: serde_json::Value) -> Self {
        let tool_json = serde_json::json!({
            "tool": tool,
            "args": args
        });
        let tool_str = serde_json::to_string(&tool_json).unwrap();
        let full_content = format!("{}\n\n{}", text, tool_str);
        Self {
            chunks: vec![
                MockChunk::content(text),
                MockChunk::content("\n\n"),
                MockChunk::content(&tool_str),
                MockChunk::finished("end_turn"),
            ],
            usage: Usage {
                prompt_tokens: 100,
                completion_tokens: full_content.len() as u32 / 4,
                total_tokens: 100 + full_content.len() as u32 / 4,
            },
        }
    }

    /// Create a response that gets cut off by max_tokens
    pub fn truncated(content: &str) -> Self {
        Self {
            chunks: vec![
                MockChunk::content(content),
                MockChunk::finished("max_tokens"),
            ],
            usage: Usage {
                prompt_tokens: 100,
                completion_tokens: content.len() as u32 / 4,
                total_tokens: 100 + content.len() as u32 / 4,
            },
        }
    }

    /// Create a custom response with explicit chunks
    pub fn custom(chunks: Vec<MockChunk>, usage: Usage) -> Self {
        Self { chunks, usage }
    }

    /// Builder: set custom usage
    pub fn with_usage(mut self, usage: Usage) -> Self {
        self.usage = usage;
        self
    }
}

/// A mock LLM provider for testing
///
/// The provider maintains a queue of responses that are returned in order.
/// It also tracks all requests made for verification in tests.
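///
/// # Example
///
/// ```rust,ignore
/// use g3_providers::mock::{MockProvider, MockResponse};
///
/// // A sketch of the queue + tracking flow: responses come back in
/// // FIFO order, and the default response (if set) is used once the
/// // queue is empty.
/// let provider = MockProvider::new()
///     .with_response(MockResponse::text("first"))
///     .with_default_response(MockResponse::text("fallback"));
///
/// // After driving the provider, inspect what it received via
/// // provider.get_requests(), provider.request_count(), provider.clear_requests().
/// ```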
pub struct MockProvider {
    name: String,
    model: String,
    max_tokens: u32,
    temperature: f32,
    native_tool_calling: bool,
    /// Queue of responses to return (FIFO)
    responses: Arc<Mutex<Vec<MockResponse>>>,
    /// All requests received (for verification)
    requests: Arc<Mutex<Vec<CompletionRequest>>>,
    /// Default response when queue is empty
    default_response: Option<MockResponse>,
}

impl MockProvider {
    /// Create a new mock provider with default settings
    pub fn new() -> Self {
        Self {
            name: "mock".to_string(),
            model: "mock-model".to_string(),
            max_tokens: 4096,
            temperature: 0.7,
            native_tool_calling: false,
            responses: Arc::new(Mutex::new(Vec::new())),
            requests: Arc::new(Mutex::new(Vec::new())),
            default_response: None,
        }
    }

    /// Set the provider name
    pub fn with_name(mut self, name: &str) -> Self {
        self.name = name.to_string();
        self
    }

    /// Set the model name
    pub fn with_model(mut self, model: &str) -> Self {
        self.model = model.to_string();
        self
    }

    /// Set max tokens
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Set temperature
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = temperature;
        self
    }

    /// Enable native tool calling
    pub fn with_native_tool_calling(mut self, enabled: bool) -> Self {
        self.native_tool_calling = enabled;
        self
    }

    /// Add a response to the queue
    pub fn with_response(self, response: MockResponse) -> Self {
        self.responses.lock().unwrap().push(response);
        self
    }

    /// Add multiple responses to the queue
    pub fn with_responses(self, responses: Vec<MockResponse>) -> Self {
        self.responses.lock().unwrap().extend(responses);
        self
    }

    /// Set a default response for when the queue is empty
    pub fn with_default_response(mut self, response: MockResponse) -> Self {
        self.default_response = Some(response);
        self
    }

    /// Get all requests that were made to this provider
    pub fn get_requests(&self) -> Vec<CompletionRequest> {
        self.requests.lock().unwrap().clone()
    }

    /// Get the number of requests made
    pub fn request_count(&self) -> usize {
        self.requests.lock().unwrap().len()
    }

    /// Clear recorded requests
    pub fn clear_requests(&self) {
        self.requests.lock().unwrap().clear();
    }

    /// Get the next response from the queue (or default)
    fn next_response(&self) -> MockResponse {
        let mut responses = self.responses.lock().unwrap();
        if responses.is_empty() {
            self.default_response
                .clone()
                .unwrap_or_else(|| MockResponse::text("Mock response (no responses configured)"))
        } else {
            responses.remove(0)
        }
    }
}

impl Default for MockProvider {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait::async_trait]
impl LLMProvider for MockProvider {
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        // Record the request
        self.requests.lock().unwrap().push(request);
        let response = self.next_response();

        // Combine all chunk content for non-streaming response
        let content: String = response
            .chunks
            .iter()
            .map(|c| c.content.as_str())
            .collect();

        Ok(CompletionResponse {
            content,
            usage: response.usage,
            model: self.model.clone(),
        })
    }

    async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
        // Record the request
        self.requests.lock().unwrap().push(request);
        let response = self.next_response();
        let usage = response.usage.clone();

        // Create a channel for streaming
        let (tx, rx) = mpsc::channel(32);
        let num_chunks = response.chunks.len();

        // Spawn a task to send chunks
        tokio::spawn(async move {
            for (i, chunk) in response.chunks.into_iter().enumerate() {
                let is_last = chunk.finished;
                let completion_chunk = CompletionChunk {
                    content: chunk.content,
                    finished: chunk.finished,
                    tool_calls: chunk.tool_calls,
                    usage: if is_last { Some(usage.clone()) } else { None },
                    stop_reason: chunk.stop_reason,
                    tool_call_streaming: chunk.tool_call_streaming,
                };
                if tx.send(Ok(completion_chunk)).await.is_err() {
                    // Receiver dropped, stop sending
                    break;
                }
                // Small delay between chunks to simulate streaming
                if i < num_chunks - 1 {
                    tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
                }
            }
        });

        Ok(ReceiverStream::new(rx))
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn model(&self) -> &str {
        &self.model
    }

    fn has_native_tool_calling(&self) -> bool {
        self.native_tool_calling
    }

    fn max_tokens(&self) -> u32 {
        self.max_tokens
    }

    fn temperature(&self) -> f32 {
        self.temperature
    }
}

// ============================================================================
// Preset Scenarios for Common Test Cases
// ============================================================================
/// Preset scenarios for common testing patterns
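///
/// For example (a sketch using the helpers defined below):
///
/// ```rust,ignore
/// // First request returns a native "shell" tool call, the second
/// // returns a text wrap-up after the tool result.
/// let provider = scenarios::tool_then_response(
///     "shell",
///     serde_json::json!({"command": "ls"}),
///     "Done - the files are listed above.",
/// );
/// ```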
pub mod scenarios {
    use super::*;

    /// Create a provider that returns a simple text response
    /// This simulates the bug scenario where text-only responses weren't saved
    pub fn text_only_response(text: &str) -> MockProvider {
        MockProvider::new().with_response(MockResponse::text(text))
    }

    /// Create a provider that returns text followed by a tool call
    pub fn text_then_tool(text: &str, tool: &str, args: serde_json::Value) -> MockProvider {
        MockProvider::new().with_response(MockResponse::text_with_json_tool(text, tool, args))
    }

    /// Create a provider for multi-turn conversation
    /// Each call returns the next response in sequence
    pub fn multi_turn(responses: Vec<&str>) -> MockProvider {
        let mock_responses: Vec<MockResponse> =
            responses.into_iter().map(MockResponse::text).collect();
        MockProvider::new().with_responses(mock_responses)
    }

    /// Create a provider that simulates tool execution flow:
    /// 1. First call: returns tool call
    /// 2. Second call: returns text response after tool result
    pub fn tool_then_response(
        tool: &str,
        args: serde_json::Value,
        final_response: &str,
    ) -> MockProvider {
        MockProvider::new()
            .with_native_tool_calling(true)
            .with_responses(vec![
                MockResponse::native_tool_call(tool, args),
                MockResponse::text(final_response),
            ])
    }

    /// Create a provider that returns a truncated response (max_tokens hit)
    pub fn truncated_response(partial_content: &str) -> MockProvider {
        MockProvider::new().with_response(MockResponse::truncated(partial_content))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tokio_stream::StreamExt;

    #[tokio::test]
    async fn test_mock_provider_text_response() {
        let provider = MockProvider::new().with_response(MockResponse::text("Hello, world!"));
        let request = CompletionRequest {
            messages: vec![],
            max_tokens: None,
            temperature: None,
            stream: false,
            tools: None,
            disable_thinking: false,
        };
        let response = provider.complete(request).await.unwrap();
        assert_eq!(response.content, "Hello, world!");
        assert_eq!(provider.request_count(), 1);
    }

    #[tokio::test]
    async fn test_mock_provider_streaming() {
        let provider =
            MockProvider::new().with_response(MockResponse::streaming(vec!["Hello, ", "world!"]));
        let request = CompletionRequest {
            messages: vec![],
            max_tokens: None,
            temperature: None,
            stream: true,
            tools: None,
            disable_thinking: false,
        };
        let mut stream = provider.stream(request).await.unwrap();
        let mut content = String::new();
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            let chunk = chunk.unwrap();
            content.push_str(&chunk.content);
            chunk_count += 1;
        }
        assert_eq!(content, "Hello, world!");
        assert_eq!(chunk_count, 3); // 2 content chunks + 1 finish chunk
    }

    #[tokio::test]
    async fn test_mock_provider_multi_turn() {
        let provider = scenarios::multi_turn(vec!["First response", "Second response"]);
        let request = CompletionRequest {
            messages: vec![],
            max_tokens: None,
            temperature: None,
            stream: false,
            tools: None,
            disable_thinking: false,
        };
        let response1 = provider.complete(request.clone()).await.unwrap();
        assert_eq!(response1.content, "First response");
        let response2 = provider.complete(request).await.unwrap();
        assert_eq!(response2.content, "Second response");
    }

    #[tokio::test]
    async fn test_mock_provider_tool_call() {
        let provider = MockProvider::new()
            .with_native_tool_calling(true)
            .with_response(MockResponse::native_tool_call(
                "shell",
                serde_json::json!({"command": "ls"}),
            ));
        let request = CompletionRequest {
            messages: vec![],
            max_tokens: None,
            temperature: None,
            stream: true,
            tools: None,
            disable_thinking: false,
        };
        let mut stream = provider.stream(request).await.unwrap();
        let mut found_tool_call = false;
        while let Some(chunk) = stream.next().await {
            let chunk = chunk.unwrap();
            if let Some(tool_calls) = chunk.tool_calls {
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].tool, "shell");
                found_tool_call = true;
            }
        }
        assert!(found_tool_call, "Should have received a tool call");
    }

    #[tokio::test]
    async fn test_mock_provider_request_tracking() {
        let provider = MockProvider::new().with_default_response(MockResponse::text("OK"));
        let request1 = CompletionRequest {
            messages: vec![crate::Message::new(crate::MessageRole::User, "Hello".to_string())],
            max_tokens: Some(100),
            temperature: None,
            stream: false,
            tools: None,
            disable_thinking: false,
        };
        let request2 = CompletionRequest {
            messages: vec![crate::Message::new(
                crate::MessageRole::User,
                "World".to_string(),
            )],
            max_tokens: Some(200),
            temperature: None,
            stream: false,
            tools: None,
            disable_thinking: false,
        };
        provider.complete(request1).await.unwrap();
        provider.complete(request2).await.unwrap();

        let requests = provider.get_requests();
        assert_eq!(requests.len(), 2);
        assert_eq!(requests[0].max_tokens, Some(100));
        assert_eq!(requests[1].max_tokens, Some(200));
    }
}
}