Files
g3/crates/g3-providers/src/lib.rs
Jochen 52f78653b4 add context window monitor
Writes the current context window to logs/current_context_window (uses a symlink to a session ID).

This PR was unfortunately generated by a different LLM and did a ton of superficial reformating, it's actually a fairly small and benign change, but I don't want to roll back everything. Hope that's ok.
2025-11-27 21:00:02 +11:00

422 lines
12 KiB
Rust

use anyhow::Result;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Trait for LLM providers
#[async_trait::async_trait]
pub trait LLMProvider: Send + Sync {
/// Generate a completion for the given messages
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;
/// Stream a completion for the given messages
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream>;
/// Get the provider name
fn name(&self) -> &str;
/// Get the model name
fn model(&self) -> &str;
/// Check if the provider supports native tool calling
fn has_native_tool_calling(&self) -> bool {
false
}
/// Check if the provider supports cache control
fn supports_cache_control(&self) -> bool {
false
}
/// Get the configured max_tokens for this provider
fn max_tokens(&self) -> u32;
/// Get the configured temperature for this provider
fn temperature(&self) -> f32;
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionRequest {
pub messages: Vec<Message>,
pub max_tokens: Option<u32>,
pub temperature: Option<f32>,
pub stream: bool,
pub tools: Option<Vec<Tool>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheControl {
#[serde(rename = "type")]
pub cache_type: CacheType,
#[serde(skip_serializing_if = "Option::is_none")]
pub ttl: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum CacheType {
Ephemeral,
}
impl CacheControl {
pub fn ephemeral() -> Self {
Self {
cache_type: CacheType::Ephemeral,
ttl: None,
}
}
pub fn five_minute() -> Self {
Self {
cache_type: CacheType::Ephemeral,
ttl: Some("5m".to_string()),
}
}
pub fn one_hour() -> Self {
Self {
cache_type: CacheType::Ephemeral,
ttl: Some("1h".to_string()),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: MessageRole,
pub content: String,
#[serde(skip)]
pub id: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub cache_control: Option<CacheControl>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
System,
User,
Assistant,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse {
pub content: String,
pub usage: Usage,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
pub type CompletionStream = tokio_stream::wrappers::ReceiverStream<Result<CompletionChunk>>;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionChunk {
pub content: String,
pub finished: bool,
pub tool_calls: Option<Vec<ToolCall>>,
pub usage: Option<Usage>, // Add usage tracking for streaming
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
pub id: String,
pub tool: String,
pub args: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
pub name: String,
pub description: String,
pub input_schema: serde_json::Value,
}
pub mod anthropic;
pub mod databricks;
pub mod embedded;
pub mod oauth;
pub mod openai;
pub use anthropic::AnthropicProvider;
pub use databricks::DatabricksProvider;
pub use embedded::EmbeddedProvider;
pub use openai::OpenAIProvider;
impl Message {
/// Generate a unique message ID in format HHMMSS-XXX
/// where XXX are 3 random alphanumeric characters (upper and lowercase)
fn generate_id() -> String {
let now = chrono::Local::now();
let timestamp = now.format("%H%M%S").to_string();
let mut rng = rand::thread_rng();
let random_chars: String = (0..3)
.map(|_| {
let chars = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
let idx = rng.gen_range(0..chars.len());
chars[idx] as char
})
.collect();
format!("{}-{}", timestamp, random_chars)
}
/// Create a new message with optional cache control
pub fn new(role: MessageRole, content: String) -> Self {
Self {
role,
content,
id: Self::generate_id(),
cache_control: None,
}
}
/// Create a new message with cache control
pub fn with_cache_control(
role: MessageRole,
content: String,
cache_control: CacheControl,
) -> Self {
Self {
role,
content,
id: Self::generate_id(),
cache_control: Some(cache_control),
}
}
/// Create a message with cache control, with provider validation
pub fn with_cache_control_validated(
role: MessageRole,
content: String,
cache_control: CacheControl,
provider: &dyn LLMProvider,
) -> Self {
if !provider.supports_cache_control() {
tracing::warn!(
"Cache control requested for provider '{}' which does not support it. \
Cache control is only supported by Anthropic and Anthropic via Databricks.",
provider.name()
);
return Self::new(role, content);
}
Self::with_cache_control(role, content, cache_control)
}
}
/// Provider registry for managing multiple LLM providers
pub struct ProviderRegistry {
providers: HashMap<String, Box<dyn LLMProvider>>,
default_provider: String,
}
impl ProviderRegistry {
pub fn new() -> Self {
Self {
providers: HashMap::new(),
default_provider: String::new(),
}
}
pub fn register<P: LLMProvider + 'static>(&mut self, provider: P) {
let name = provider.name().to_string();
self.providers.insert(name.clone(), Box::new(provider));
if self.default_provider.is_empty() {
self.default_provider = name;
}
}
pub fn set_default(&mut self, provider_name: &str) -> Result<()> {
if !self.providers.contains_key(provider_name) {
anyhow::bail!("Provider '{}' not found", provider_name);
}
self.default_provider = provider_name.to_string();
Ok(())
}
pub fn get(&self, provider_name: Option<&str>) -> Result<&dyn LLMProvider> {
let name = provider_name.unwrap_or(&self.default_provider);
self.providers
.get(name)
.map(|p| p.as_ref())
.ok_or_else(|| anyhow::anyhow!("Provider '{}' not found", name))
}
pub fn list_providers(&self) -> Vec<&str> {
self.providers.keys().map(|s| s.as_str()).collect()
}
}
impl Default for ProviderRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_message_serialization_without_cache_control() {
let msg = Message::new(MessageRole::User, "Hello".to_string());
let json = serde_json::to_string(&msg).unwrap();
println!("Message JSON without cache_control: {}", json);
assert!(
!json.contains("cache_control"),
"JSON should not contain 'cache_control' field when not configured"
);
}
#[test]
fn test_message_serialization_with_cache_control() {
let msg = Message::with_cache_control(
MessageRole::User,
"Hello".to_string(),
CacheControl::ephemeral(),
);
let json = serde_json::to_string(&msg).unwrap();
println!("Message JSON with cache_control: {}", json);
assert!(
json.contains("cache_control"),
"JSON should contain 'cache_control' field when configured"
);
assert!(
json.contains("ephemeral"),
"JSON should contain 'ephemeral' value"
);
assert!(
json.contains("\"type\":"),
"JSON should contain 'type' field in cache_control"
);
assert!(
!json.contains("null"),
"JSON should not contain null values"
);
}
#[test]
fn test_cache_control_five_minute_serialization() {
let msg = Message::with_cache_control(
MessageRole::User,
"Hello".to_string(),
CacheControl::five_minute(),
);
let json = serde_json::to_string(&msg).unwrap();
println!("Message JSON with 5-minute cache_control: {}", json);
assert!(
json.contains("cache_control"),
"JSON should contain 'cache_control' field"
);
assert!(
json.contains("ephemeral"),
"JSON should contain 'ephemeral' type"
);
assert!(
json.contains("\"ttl\":\"5m\""),
"JSON should contain ttl field with 5m value"
);
}
#[test]
fn test_cache_control_one_hour_serialization() {
let msg = Message::with_cache_control(
MessageRole::User,
"Hello".to_string(),
CacheControl::one_hour(),
);
let json = serde_json::to_string(&msg).unwrap();
println!("Message JSON with 1-hour cache_control: {}", json);
assert!(
json.contains("cache_control"),
"JSON should contain 'cache_control' field"
);
assert!(
json.contains("ephemeral"),
"JSON should contain 'ephemeral' type"
);
assert!(
json.contains("\"ttl\":\"1h\""),
"JSON should contain ttl field with 1h value"
);
}
#[test]
fn test_message_id_generation() {
let msg = Message::new(MessageRole::User, "Hello".to_string());
// Check that id is not empty
assert!(!msg.id.is_empty(), "Message ID should not be empty");
// Check format: HHMMSS-XXX
let parts: Vec<&str> = msg.id.split('-').collect();
assert_eq!(parts.len(), 2, "Message ID should have format HHMMSS-XXX");
// Check timestamp part is 6 digits
assert_eq!(parts[0].len(), 6, "Timestamp should be 6 digits (HHMMSS)");
assert!(
parts[0].chars().all(|c| c.is_ascii_digit()),
"Timestamp should be all digits"
);
// Check random part is 3 alpha characters
assert_eq!(parts[1].len(), 3, "Random part should be 3 characters");
assert!(
parts[1].chars().all(|c| c.is_ascii_alphabetic()),
"Random part should be all alphabetic characters"
);
}
#[test]
fn test_message_id_uniqueness() {
let msg1 = Message::new(MessageRole::User, "Hello".to_string());
let msg2 = Message::new(MessageRole::User, "Hello".to_string());
// IDs should be different (due to random component)
// Note: There's a tiny chance they could be the same, but very unlikely
println!("msg1.id: {}, msg2.id: {}", msg1.id, msg2.id);
}
#[test]
fn test_message_id_not_serialized() {
let msg = Message::new(MessageRole::User, "Hello".to_string());
let json = serde_json::to_string(&msg).unwrap();
println!("Message JSON: {}", json);
assert!(
!json.contains("\"id\""),
"JSON should not contain 'id' field"
);
}
#[test]
fn test_message_with_cache_control_has_id() {
let msg = Message::with_cache_control(
MessageRole::User,
"Hello".to_string(),
CacheControl::ephemeral(),
);
assert!(
!msg.id.is_empty(),
"Message with cache control should have an ID"
);
assert!(
msg.id.contains('-'),
"Message ID should contain hyphen separator"
);
}
}