adds cache_control

This commit is contained in:
Jochen
2025-11-18 22:38:52 +11:00
parent 39efa24c55
commit 296bf5a449
7 changed files with 466 additions and 125 deletions

View File

@@ -21,6 +21,11 @@ pub trait LLMProvider: Send + Sync {
fn has_native_tool_calling(&self) -> bool {
false
}
/// Check if the provider supports cache control
fn supports_cache_control(&self) -> bool {
false
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -32,10 +37,22 @@ pub struct CompletionRequest {
pub tools: Option<Vec<Tool>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CacheControl {
Ephemeral,
#[serde(rename = "5minute")]
FiveMinute,
#[serde(rename = "1hour")]
OneHour,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: MessageRole,
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub cache_control: Option<CacheControl>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -95,6 +112,45 @@ pub use databricks::DatabricksProvider;
pub use embedded::EmbeddedProvider;
pub use openai::OpenAIProvider;
impl Message {
/// Create a new message with optional cache control
pub fn new(role: MessageRole, content: String) -> Self {
Self {
role,
content,
cache_control: None,
}
}
/// Create a new message with cache control
pub fn with_cache_control(role: MessageRole, content: String, cache_control: CacheControl) -> Self {
Self {
role,
content,
cache_control: Some(cache_control),
}
}
/// Create a message with cache control, with provider validation
pub fn with_cache_control_validated(
role: MessageRole,
content: String,
cache_control: CacheControl,
provider: &dyn LLMProvider
) -> Self {
if !provider.supports_cache_control() {
tracing::warn!(
"Cache control requested for provider '{}' which does not support it. \
Cache control is only supported by Anthropic and Anthropic via Databricks.",
provider.name()
);
return Self::new(role, content);
}
Self::with_cache_control(role, content, cache_control)
}
}
/// Provider registry for managing multiple LLM providers
pub struct ProviderRegistry {
providers: HashMap<String, Box<dyn LLMProvider>>,
@@ -144,3 +200,36 @@ impl Default for ProviderRegistry {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_message_serialization_without_cache_control() {
let msg = Message::new(MessageRole::User, "Hello".to_string());
let json = serde_json::to_string(&msg).unwrap();
println!("Message JSON without cache_control: {}", json);
assert!(!json.contains("cache_control"),
"JSON should not contain 'cache_control' field when not configured");
}
#[test]
fn test_message_serialization_with_cache_control() {
let msg = Message::with_cache_control(
MessageRole::User,
"Hello".to_string(),
CacheControl::Ephemeral,
);
let json = serde_json::to_string(&msg).unwrap();
println!("Message JSON with cache_control: {}", json);
assert!(json.contains("cache_control"),
"JSON should contain 'cache_control' field when configured");
assert!(json.contains("ephemeral"),
"JSON should contain 'ephemeral' value");
assert!(!json.contains("null"),
"JSON should not contain null values");
}
}