Adds a configurable mock LLM provider that can simulate various behaviors: - Text-only responses (single or multi-chunk streaming) - Native tool calls - JSON tool calls in text - Truncated responses (max_tokens) - Multi-turn conversations Features: - Builder pattern for easy test setup - Request tracking for verification - Preset scenarios for common patterns - Full LLMProvider trait implementation Also adds integration tests that use MockProvider to test the stream_completion_with_tools code path, including: - test_butler_bug_scenario: reproduces the exact bug where text-only responses were not saved to context, causing consecutive user messages This enables testing complex streaming behaviors without real API calls.
537 lines
15 KiB
Rust
537 lines
15 KiB
Rust
mod streaming;
|
|
pub mod mock;
|
|
pub use mock::{MockProvider, MockResponse, MockChunk};
|
|
|
|
pub use streaming::{decode_utf8_streaming, is_incomplete_json_error, make_final_chunk, make_text_chunk, make_tool_chunk};
|
|
|
|
use anyhow::Result;
|
|
use rand::Rng;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
|
|
/// Trait for LLM providers
|
|
#[async_trait::async_trait]
|
|
pub trait LLMProvider: Send + Sync {
|
|
/// Generate a completion for the given messages
|
|
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;
|
|
|
|
/// Stream a completion for the given messages
|
|
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream>;
|
|
|
|
/// Get the provider name
|
|
fn name(&self) -> &str;
|
|
|
|
/// Get the model name
|
|
fn model(&self) -> &str;
|
|
|
|
/// Check if the provider supports native tool calling
|
|
fn has_native_tool_calling(&self) -> bool {
|
|
false
|
|
}
|
|
|
|
/// Check if the provider supports cache control
|
|
fn supports_cache_control(&self) -> bool {
|
|
false
|
|
}
|
|
|
|
/// Get the configured max_tokens for this provider
|
|
fn max_tokens(&self) -> u32;
|
|
|
|
/// Get the configured temperature for this provider
|
|
fn temperature(&self) -> f32;
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct CompletionRequest {
|
|
pub messages: Vec<Message>,
|
|
pub max_tokens: Option<u32>,
|
|
pub temperature: Option<f32>,
|
|
pub stream: bool,
|
|
pub tools: Option<Vec<Tool>>,
|
|
/// Force disable thinking mode for this request (used when max_tokens is too low)
|
|
pub disable_thinking: bool,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct CacheControl {
|
|
#[serde(rename = "type")]
|
|
pub cache_type: CacheType,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub ttl: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
|
#[serde(rename_all = "lowercase")]
|
|
pub enum CacheType {
|
|
Ephemeral,
|
|
}
|
|
|
|
impl CacheControl {
|
|
pub fn ephemeral() -> Self {
|
|
Self {
|
|
cache_type: CacheType::Ephemeral,
|
|
ttl: None,
|
|
}
|
|
}
|
|
|
|
pub fn five_minute() -> Self {
|
|
Self {
|
|
cache_type: CacheType::Ephemeral,
|
|
ttl: Some("5m".to_string()),
|
|
}
|
|
}
|
|
|
|
pub fn one_hour() -> Self {
|
|
Self {
|
|
cache_type: CacheType::Ephemeral,
|
|
ttl: Some("1h".to_string()),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct Message {
|
|
pub role: MessageRole,
|
|
pub content: String,
|
|
#[serde(skip)]
|
|
pub images: Vec<ImageContent>,
|
|
#[serde(skip)]
|
|
pub id: String,
|
|
#[serde(skip)]
|
|
pub kind: MessageKind,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub cache_control: Option<CacheControl>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(rename_all = "lowercase")]
|
|
pub enum MessageRole {
|
|
System,
|
|
User,
|
|
Assistant,
|
|
}
|
|
|
|
/// Special message kinds for context management (ACD)
|
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
|
|
pub enum MessageKind {
|
|
/// Regular conversation message
|
|
#[default]
|
|
Regular,
|
|
/// Dehydrated context stub (contains fragment reference)
|
|
DehydratedStub,
|
|
/// Summary of dehydrated context (the response that followed dehydration)
|
|
Summary,
|
|
/// Rehydrated content (restored from a fragment)
|
|
Rehydrated,
|
|
}
|
|
|
|
/// Image content for multimodal messages
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ImageContent {
|
|
/// Media type (e.g., "image/png", "image/jpeg", "image/gif", "image/webp")
|
|
pub media_type: String,
|
|
/// Base64-encoded image data
|
|
pub data: String,
|
|
}
|
|
|
|
impl ImageContent {
|
|
pub fn new(media_type: &str, data: String) -> Self {
|
|
Self {
|
|
media_type: media_type.to_string(),
|
|
data,
|
|
}
|
|
}
|
|
|
|
/// Detect media type from file extension
|
|
pub fn media_type_from_extension(ext: &str) -> Option<&'static str> {
|
|
match ext.to_lowercase().as_str() {
|
|
"png" => Some("image/png"),
|
|
"jpg" | "jpeg" => Some("image/jpeg"),
|
|
"gif" => Some("image/gif"),
|
|
"webp" => Some("image/webp"),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
/// Detect media type from image data magic bytes (file signature)
|
|
/// This is more reliable than file extension as it checks actual content
|
|
pub fn media_type_from_bytes(bytes: &[u8]) -> Option<&'static str> {
|
|
if bytes.len() < 12 {
|
|
return None;
|
|
}
|
|
|
|
// PNG: 89 50 4E 47 0D 0A 1A 0A
|
|
if bytes.starts_with(&[0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]) {
|
|
return Some("image/png");
|
|
}
|
|
|
|
// JPEG: FF D8 FF
|
|
if bytes.starts_with(&[0xFF, 0xD8, 0xFF]) {
|
|
return Some("image/jpeg");
|
|
}
|
|
|
|
// GIF: 47 49 46 38 (GIF8)
|
|
if bytes.starts_with(&[0x47, 0x49, 0x46, 0x38]) {
|
|
return Some("image/gif");
|
|
}
|
|
|
|
// WebP: 52 49 46 46 ... 57 45 42 50 (RIFF....WEBP)
|
|
if bytes.starts_with(&[0x52, 0x49, 0x46, 0x46]) && bytes.len() >= 12 && &bytes[8..12] == b"WEBP" {
|
|
return Some("image/webp");
|
|
}
|
|
|
|
None
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct CompletionResponse {
|
|
pub content: String,
|
|
pub usage: Usage,
|
|
pub model: String,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct Usage {
|
|
pub prompt_tokens: u32,
|
|
pub completion_tokens: u32,
|
|
pub total_tokens: u32,
|
|
}
|
|
|
|
pub type CompletionStream = tokio_stream::wrappers::ReceiverStream<Result<CompletionChunk>>;
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct CompletionChunk {
|
|
pub content: String,
|
|
pub finished: bool,
|
|
pub tool_calls: Option<Vec<ToolCall>>,
|
|
pub usage: Option<Usage>, // Add usage tracking for streaming
|
|
/// Stop reason from the API (e.g., "end_turn", "max_tokens", "stop_sequence")
|
|
pub stop_reason: Option<String>,
|
|
/// Tool call currently being streamed (name only, for UI hint)
|
|
pub tool_call_streaming: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ToolCall {
|
|
pub id: String,
|
|
pub tool: String,
|
|
pub args: serde_json::Value,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct Tool {
|
|
pub name: String,
|
|
pub description: String,
|
|
pub input_schema: serde_json::Value,
|
|
}
|
|
|
|
pub mod anthropic;
|
|
pub mod databricks;
|
|
pub mod embedded;
|
|
pub mod oauth;
|
|
pub mod openai;
|
|
|
|
pub use anthropic::AnthropicProvider;
|
|
pub use databricks::DatabricksProvider;
|
|
pub use embedded::EmbeddedProvider;
|
|
pub use openai::OpenAIProvider;
|
|
|
|
impl Message {
|
|
/// Generate a unique message ID in format HHMMSS-XXX
|
|
/// where XXX are 3 random alphanumeric characters (upper and lowercase)
|
|
fn generate_id() -> String {
|
|
let now = chrono::Local::now();
|
|
let timestamp = now.format("%H%M%S").to_string();
|
|
|
|
let mut rng = rand::thread_rng();
|
|
let random_chars: String = (0..3)
|
|
.map(|_| {
|
|
let chars = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
|
|
let idx = rng.gen_range(0..chars.len());
|
|
chars[idx] as char
|
|
})
|
|
.collect();
|
|
|
|
format!("{}-{}", timestamp, random_chars)
|
|
}
|
|
|
|
/// Create a new message with optional cache control
|
|
pub fn new(role: MessageRole, content: String) -> Self {
|
|
Self {
|
|
role,
|
|
content,
|
|
images: Vec::new(),
|
|
id: Self::generate_id(),
|
|
kind: MessageKind::Regular,
|
|
cache_control: None,
|
|
}
|
|
}
|
|
|
|
/// Create a new message with cache control
|
|
pub fn with_cache_control(
|
|
role: MessageRole,
|
|
content: String,
|
|
cache_control: CacheControl,
|
|
) -> Self {
|
|
Self {
|
|
role,
|
|
content,
|
|
images: Vec::new(),
|
|
id: Self::generate_id(),
|
|
kind: MessageKind::Regular,
|
|
cache_control: Some(cache_control),
|
|
}
|
|
}
|
|
|
|
/// Create a new message with a specific kind (for ACD)
|
|
pub fn with_kind(role: MessageRole, content: String, kind: MessageKind) -> Self {
|
|
Self {
|
|
role,
|
|
content,
|
|
images: Vec::new(),
|
|
id: Self::generate_id(),
|
|
kind,
|
|
cache_control: None,
|
|
}
|
|
}
|
|
|
|
/// Check if this message is a dehydrated stub
|
|
pub fn is_dehydrated_stub(&self) -> bool {
|
|
self.kind == MessageKind::DehydratedStub
|
|
}
|
|
|
|
/// Check if this message is a summary
|
|
pub fn is_summary(&self) -> bool {
|
|
self.kind == MessageKind::Summary
|
|
}
|
|
|
|
/// Create a message with cache control, with provider validation
|
|
pub fn with_cache_control_validated(
|
|
role: MessageRole,
|
|
content: String,
|
|
cache_control: CacheControl,
|
|
provider: &dyn LLMProvider,
|
|
) -> Self {
|
|
if !provider.supports_cache_control() {
|
|
tracing::warn!(
|
|
"Cache control requested for provider '{}' which does not support it. \
|
|
Cache control is only supported by Anthropic and Anthropic via Databricks.",
|
|
provider.name()
|
|
);
|
|
return Self::new(role, content);
|
|
}
|
|
|
|
Self::with_cache_control(role, content, cache_control)
|
|
}
|
|
}
|
|
|
|
/// Provider registry for managing multiple LLM providers
|
|
pub struct ProviderRegistry {
|
|
providers: HashMap<String, Box<dyn LLMProvider>>,
|
|
default_provider: String,
|
|
}
|
|
|
|
impl ProviderRegistry {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
providers: HashMap::new(),
|
|
default_provider: String::new(),
|
|
}
|
|
}
|
|
|
|
pub fn register<P: LLMProvider + 'static>(&mut self, provider: P) {
|
|
let name = provider.name().to_string();
|
|
self.providers.insert(name.clone(), Box::new(provider));
|
|
|
|
if self.default_provider.is_empty() {
|
|
self.default_provider = name;
|
|
}
|
|
}
|
|
|
|
pub fn set_default(&mut self, provider_name: &str) -> Result<()> {
|
|
if !self.providers.contains_key(provider_name) {
|
|
anyhow::bail!("Provider '{}' not found", provider_name);
|
|
}
|
|
self.default_provider = provider_name.to_string();
|
|
Ok(())
|
|
}
|
|
|
|
pub fn get(&self, provider_name: Option<&str>) -> Result<&dyn LLMProvider> {
|
|
let name = provider_name.unwrap_or(&self.default_provider);
|
|
self.providers
|
|
.get(name)
|
|
.map(|p| p.as_ref())
|
|
.ok_or_else(|| anyhow::anyhow!("Provider '{}' not found", name))
|
|
}
|
|
|
|
pub fn list_providers(&self) -> Vec<&str> {
|
|
self.providers.keys().map(|s| s.as_str()).collect()
|
|
}
|
|
}
|
|
|
|
impl Default for ProviderRegistry {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn test_message_serialization_without_cache_control() {
|
|
let msg = Message::new(MessageRole::User, "Hello".to_string());
|
|
let json = serde_json::to_string(&msg).unwrap();
|
|
|
|
println!("Message JSON without cache_control: {}", json);
|
|
assert!(
|
|
!json.contains("cache_control"),
|
|
"JSON should not contain 'cache_control' field when not configured"
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_message_serialization_with_cache_control() {
|
|
let msg = Message::with_cache_control(
|
|
MessageRole::User,
|
|
"Hello".to_string(),
|
|
CacheControl::ephemeral(),
|
|
);
|
|
let json = serde_json::to_string(&msg).unwrap();
|
|
|
|
println!("Message JSON with cache_control: {}", json);
|
|
assert!(
|
|
json.contains("cache_control"),
|
|
"JSON should contain 'cache_control' field when configured"
|
|
);
|
|
assert!(
|
|
json.contains("ephemeral"),
|
|
"JSON should contain 'ephemeral' value"
|
|
);
|
|
assert!(
|
|
json.contains("\"type\":"),
|
|
"JSON should contain 'type' field in cache_control"
|
|
);
|
|
assert!(
|
|
!json.contains("null"),
|
|
"JSON should not contain null values"
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_cache_control_five_minute_serialization() {
|
|
let msg = Message::with_cache_control(
|
|
MessageRole::User,
|
|
"Hello".to_string(),
|
|
CacheControl::five_minute(),
|
|
);
|
|
let json = serde_json::to_string(&msg).unwrap();
|
|
|
|
println!("Message JSON with 5-minute cache_control: {}", json);
|
|
assert!(
|
|
json.contains("cache_control"),
|
|
"JSON should contain 'cache_control' field"
|
|
);
|
|
assert!(
|
|
json.contains("ephemeral"),
|
|
"JSON should contain 'ephemeral' type"
|
|
);
|
|
assert!(
|
|
json.contains("\"ttl\":\"5m\""),
|
|
"JSON should contain ttl field with 5m value"
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_cache_control_one_hour_serialization() {
|
|
let msg = Message::with_cache_control(
|
|
MessageRole::User,
|
|
"Hello".to_string(),
|
|
CacheControl::one_hour(),
|
|
);
|
|
let json = serde_json::to_string(&msg).unwrap();
|
|
|
|
println!("Message JSON with 1-hour cache_control: {}", json);
|
|
assert!(
|
|
json.contains("cache_control"),
|
|
"JSON should contain 'cache_control' field"
|
|
);
|
|
assert!(
|
|
json.contains("ephemeral"),
|
|
"JSON should contain 'ephemeral' type"
|
|
);
|
|
assert!(
|
|
json.contains("\"ttl\":\"1h\""),
|
|
"JSON should contain ttl field with 1h value"
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_message_id_generation() {
|
|
let msg = Message::new(MessageRole::User, "Hello".to_string());
|
|
|
|
// Check that id is not empty
|
|
assert!(!msg.id.is_empty(), "Message ID should not be empty");
|
|
|
|
// Check format: HHMMSS-XXX
|
|
let parts: Vec<&str> = msg.id.split('-').collect();
|
|
assert_eq!(parts.len(), 2, "Message ID should have format HHMMSS-XXX");
|
|
|
|
// Check timestamp part is 6 digits
|
|
assert_eq!(parts[0].len(), 6, "Timestamp should be 6 digits (HHMMSS)");
|
|
assert!(
|
|
parts[0].chars().all(|c| c.is_ascii_digit()),
|
|
"Timestamp should be all digits"
|
|
);
|
|
|
|
// Check random part is 3 alpha characters
|
|
assert_eq!(parts[1].len(), 3, "Random part should be 3 characters");
|
|
assert!(
|
|
parts[1].chars().all(|c| c.is_ascii_alphabetic()),
|
|
"Random part should be all alphabetic characters"
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_message_id_uniqueness() {
|
|
let msg1 = Message::new(MessageRole::User, "Hello".to_string());
|
|
let msg2 = Message::new(MessageRole::User, "Hello".to_string());
|
|
|
|
// IDs should be different (due to random component)
|
|
// Note: There's a tiny chance they could be the same, but very unlikely
|
|
println!("msg1.id: {}, msg2.id: {}", msg1.id, msg2.id);
|
|
}
|
|
|
|
#[test]
|
|
fn test_message_id_not_serialized() {
|
|
let msg = Message::new(MessageRole::User, "Hello".to_string());
|
|
let json = serde_json::to_string(&msg).unwrap();
|
|
|
|
println!("Message JSON: {}", json);
|
|
assert!(
|
|
!json.contains("\"id\""),
|
|
"JSON should not contain 'id' field"
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_message_with_cache_control_has_id() {
|
|
let msg = Message::with_cache_control(
|
|
MessageRole::User,
|
|
"Hello".to_string(),
|
|
CacheControl::ephemeral(),
|
|
);
|
|
|
|
assert!(
|
|
!msg.id.is_empty(),
|
|
"Message with cache control should have an ID"
|
|
);
|
|
assert!(
|
|
msg.id.contains('-'),
|
|
"Message ID should contain hyphen separator"
|
|
);
|
|
}
|
|
}
|