use serde::{Deserialize, Serialize}; use anyhow::Result; use std::collections::HashMap; /// Trait for LLM providers #[async_trait::async_trait] pub trait LLMProvider: Send + Sync { /// Generate a completion for the given messages async fn complete(&self, request: CompletionRequest) -> Result; /// Stream a completion for the given messages async fn stream(&self, request: CompletionRequest) -> Result; /// Get the provider name fn name(&self) -> &str; /// Get the model name fn model(&self) -> &str; /// Check if the provider supports native tool calling fn has_native_tool_calling(&self) -> bool { false } /// Check if the provider supports cache control fn supports_cache_control(&self) -> bool { false } /// Get the configured max_tokens for this provider fn max_tokens(&self) -> u32; /// Get the configured temperature for this provider fn temperature(&self) -> f32; } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CompletionRequest { pub messages: Vec, pub max_tokens: Option, pub temperature: Option, pub stream: bool, pub tools: Option>, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CacheControl { #[serde(rename = "type")] pub cache_type: CacheType, #[serde(skip_serializing_if = "Option::is_none")] pub ttl: Option, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(rename_all = "lowercase")] pub enum CacheType { Ephemeral, } impl CacheControl { pub fn ephemeral() -> Self { Self { cache_type: CacheType::Ephemeral, ttl: None } } pub fn five_minute() -> Self { Self { cache_type: CacheType::Ephemeral, ttl: Some("5m".to_string()) } } pub fn one_hour() -> Self { Self { cache_type: CacheType::Ephemeral, ttl: Some("1h".to_string()) } } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Message { pub role: MessageRole, pub content: String, #[serde(skip_serializing_if = "Option::is_none")] pub cache_control: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum MessageRole { System, User, Assistant, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CompletionResponse { pub content: String, pub usage: Usage, pub model: String, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Usage { pub prompt_tokens: u32, pub completion_tokens: u32, pub total_tokens: u32, } pub type CompletionStream = tokio_stream::wrappers::ReceiverStream>; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CompletionChunk { pub content: String, pub finished: bool, pub tool_calls: Option>, pub usage: Option, // Add usage tracking for streaming } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolCall { pub id: String, pub tool: String, pub args: serde_json::Value, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Tool { pub name: String, pub description: String, pub input_schema: serde_json::Value, } pub mod anthropic; pub mod databricks; pub mod embedded; pub mod oauth; pub mod openai; pub use anthropic::AnthropicProvider; pub use databricks::DatabricksProvider; pub use embedded::EmbeddedProvider; pub use openai::OpenAIProvider; impl Message { /// Create a new message with optional cache control pub fn new(role: MessageRole, content: String) -> Self { Self { role, content, cache_control: None, } } /// Create a new message with cache control pub fn with_cache_control(role: MessageRole, content: String, cache_control: CacheControl) -> Self { Self { role, content, cache_control: Some(cache_control), } } /// Create a message with cache control, with provider validation pub fn with_cache_control_validated( role: MessageRole, content: String, cache_control: CacheControl, provider: &dyn LLMProvider ) -> Self { if !provider.supports_cache_control() { tracing::warn!( "Cache control requested for provider '{}' which does not support it. \ Cache control is only supported by Anthropic and Anthropic via Databricks.", provider.name() ); return Self::new(role, content); } Self::with_cache_control(role, content, cache_control) } } /// Provider registry for managing multiple LLM providers pub struct ProviderRegistry { providers: HashMap>, default_provider: String, } impl ProviderRegistry { pub fn new() -> Self { Self { providers: HashMap::new(), default_provider: String::new(), } } pub fn register(&mut self, provider: P) { let name = provider.name().to_string(); self.providers.insert(name.clone(), Box::new(provider)); if self.default_provider.is_empty() { self.default_provider = name; } } pub fn set_default(&mut self, provider_name: &str) -> Result<()> { if !self.providers.contains_key(provider_name) { anyhow::bail!("Provider '{}' not found", provider_name); } self.default_provider = provider_name.to_string(); Ok(()) } pub fn get(&self, provider_name: Option<&str>) -> Result<&dyn LLMProvider> { let name = provider_name.unwrap_or(&self.default_provider); self.providers .get(name) .map(|p| p.as_ref()) .ok_or_else(|| anyhow::anyhow!("Provider '{}' not found", name)) } pub fn list_providers(&self) -> Vec<&str> { self.providers.keys().map(|s| s.as_str()).collect() } } impl Default for ProviderRegistry { fn default() -> Self { Self::new() } } #[cfg(test)] mod tests { use super::*; #[test] fn test_message_serialization_without_cache_control() { let msg = Message::new(MessageRole::User, "Hello".to_string()); let json = serde_json::to_string(&msg).unwrap(); println!("Message JSON without cache_control: {}", json); assert!(!json.contains("cache_control"), "JSON should not contain 'cache_control' field when not configured"); } #[test] fn test_message_serialization_with_cache_control() { let msg = Message::with_cache_control( MessageRole::User, "Hello".to_string(), CacheControl::ephemeral(), ); let json = serde_json::to_string(&msg).unwrap(); println!("Message JSON with cache_control: {}", json); assert!(json.contains("cache_control"), "JSON should contain 'cache_control' field when configured"); assert!(json.contains("ephemeral"), "JSON should contain 'ephemeral' value"); assert!(json.contains("\"type\":"), "JSON should contain 'type' field in cache_control"); assert!(!json.contains("null"), "JSON should not contain null values"); } #[test] fn test_cache_control_five_minute_serialization() { let msg = Message::with_cache_control( MessageRole::User, "Hello".to_string(), CacheControl::five_minute(), ); let json = serde_json::to_string(&msg).unwrap(); println!("Message JSON with 5-minute cache_control: {}", json); assert!(json.contains("cache_control"), "JSON should contain 'cache_control' field"); assert!(json.contains("ephemeral"), "JSON should contain 'ephemeral' type"); assert!(json.contains("\"ttl\":\"5m\""), "JSON should contain ttl field with 5m value"); } #[test] fn test_cache_control_one_hour_serialization() { let msg = Message::with_cache_control( MessageRole::User, "Hello".to_string(), CacheControl::one_hour(), ); let json = serde_json::to_string(&msg).unwrap(); println!("Message JSON with 1-hour cache_control: {}", json); assert!(json.contains("cache_control"), "JSON should contain 'cache_control' field"); assert!(json.contains("ephemeral"), "JSON should contain 'ephemeral' type"); assert!(json.contains("\"ttl\":\"1h\""), "JSON should contain ttl field with 1h value"); } }