use serde::{Deserialize, Serialize}; use anyhow::Result; use std::collections::HashMap; /// Trait for LLM providers #[async_trait::async_trait] pub trait LLMProvider: Send + Sync { /// Generate a completion for the given messages async fn complete(&self, request: CompletionRequest) -> Result; /// Stream a completion for the given messages async fn stream(&self, request: CompletionRequest) -> Result; /// Get the provider name fn name(&self) -> &str; /// Get the model name fn model(&self) -> &str; /// Check if the provider supports native tool calling fn has_native_tool_calling(&self) -> bool { false } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CompletionRequest { pub messages: Vec, pub max_tokens: Option, pub temperature: Option, pub stream: bool, pub tools: Option>, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Message { pub role: MessageRole, pub content: String, } #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum MessageRole { System, User, Assistant, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CompletionResponse { pub content: String, pub usage: Usage, pub model: String, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Usage { pub prompt_tokens: u32, pub completion_tokens: u32, pub total_tokens: u32, } pub type CompletionStream = tokio_stream::wrappers::ReceiverStream>; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CompletionChunk { pub content: String, pub finished: bool, pub tool_calls: Option>, pub usage: Option, // Add usage tracking for streaming } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolCall { pub id: String, pub tool: String, pub args: serde_json::Value, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Tool { pub name: String, pub description: String, pub input_schema: serde_json::Value, } pub mod anthropic; pub mod databricks; pub mod embedded; pub mod oauth; pub mod ollama; pub mod openai; pub use anthropic::AnthropicProvider; pub use databricks::DatabricksProvider; pub use embedded::EmbeddedProvider; pub use ollama::OllamaProvider; pub use openai::OpenAIProvider; /// Provider registry for managing multiple LLM providers pub struct ProviderRegistry { providers: HashMap>, default_provider: String, } impl ProviderRegistry { pub fn new() -> Self { Self { providers: HashMap::new(), default_provider: String::new(), } } pub fn register(&mut self, provider: P) { let name = provider.name().to_string(); self.providers.insert(name.clone(), Box::new(provider)); if self.default_provider.is_empty() { self.default_provider = name; } } pub fn set_default(&mut self, provider_name: &str) -> Result<()> { if !self.providers.contains_key(provider_name) { anyhow::bail!("Provider '{}' not found", provider_name); } self.default_provider = provider_name.to_string(); Ok(()) } pub fn get(&self, provider_name: Option<&str>) -> Result<&dyn LLMProvider> { let name = provider_name.unwrap_or(&self.default_provider); self.providers .get(name) .map(|p| p.as_ref()) .ok_or_else(|| anyhow::anyhow!("Provider '{}' not found", name)) } pub fn list_providers(&self) -> Vec<&str> { self.providers.keys().map(|s| s.as_str()).collect() } } impl Default for ProviderRegistry { fn default() -> Self { Self::new() } }