diff --git a/crates/g3-core/src/lib.rs b/crates/g3-core/src/lib.rs index 7839d63..0ee85b2 100644 --- a/crates/g3-core/src/lib.rs +++ b/crates/g3-core/src/lib.rs @@ -1004,9 +1004,9 @@ impl Agent { /// Convert cache config string to CacheControl enum fn parse_cache_control(cache_config: &str) -> Option { match cache_config { - "ephemeral" => Some(CacheControl::Ephemeral), - "5minute" => Some(CacheControl::FiveMinute), - "1hour" => Some(CacheControl::OneHour), + "ephemeral" => Some(CacheControl::ephemeral()), + "5minute" => Some(CacheControl::five_minute()), + "1hour" => Some(CacheControl::one_hour()), _ => { warn!("Invalid cache_config value: '{}'. Valid values are: ephemeral, 5minute, 1hour", cache_config); None diff --git a/crates/g3-providers/src/anthropic.rs b/crates/g3-providers/src/anthropic.rs index de90e46..69ac66f 100644 --- a/crates/g3-providers/src/anthropic.rs +++ b/crates/g3-providers/src/anthropic.rs @@ -172,15 +172,9 @@ impl AnthropicProvider { builder } - fn convert_cache_control(cache_control: &crate::CacheControl) -> AnthropicCacheControl { - let cache_type = match cache_control { - crate::CacheControl::Ephemeral => "ephemeral", - crate::CacheControl::FiveMinute => "5minute", - crate::CacheControl::OneHour => "1hour", - }; - AnthropicCacheControl { - cache_type: cache_type.to_string(), - } + fn convert_cache_control(cache_control: &crate::CacheControl) -> crate::CacheControl { + // Anthropic uses the same format, so just clone it + cache_control.clone() } fn convert_tools(&self, tools: &[Tool]) -> Vec { @@ -723,12 +717,6 @@ struct AnthropicMessage { content: Vec, } -#[derive(Debug, Serialize, Deserialize)] -struct AnthropicCacheControl { - #[serde(rename = "type")] - cache_type: String, -} - #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type")] enum AnthropicContent { @@ -736,7 +724,7 @@ enum AnthropicContent { Text { text: String, #[serde(skip_serializing_if = "Option::is_none")] - cache_control: Option, + cache_control: Option, }, #[serde(rename = "tool_use")] ToolUse { @@ -916,7 +904,7 @@ mod tests { let messages_with = vec![Message::with_cache_control( MessageRole::User, "Hello".to_string(), - crate::CacheControl::Ephemeral, + crate::CacheControl::ephemeral(), )]; let (_, anthropic_messages_with) = provider.convert_messages(&messages_with).unwrap(); let json_with = serde_json::to_string(&anthropic_messages_with).unwrap(); diff --git a/crates/g3-providers/src/databricks.rs b/crates/g3-providers/src/databricks.rs index 08cddae..635653c 100644 --- a/crates/g3-providers/src/databricks.rs +++ b/crates/g3-providers/src/databricks.rs @@ -238,13 +238,9 @@ impl DatabricksProvider { .collect() } - fn convert_cache_control(cache_control: &crate::CacheControl) -> DatabricksCacheControl { - let cache_type = match cache_control { - crate::CacheControl::Ephemeral => "ephemeral", - crate::CacheControl::FiveMinute => "5minute", - crate::CacheControl::OneHour => "1hour", - }; - DatabricksCacheControl { cache_type: cache_type.to_string() } + fn convert_cache_control(cache_control: &crate::CacheControl) -> crate::CacheControl { + // Databricks uses the same format, so just clone it + cache_control.clone() } fn convert_messages(&self, messages: &[Message]) -> Result> { @@ -1104,12 +1100,6 @@ struct DatabricksFunction { parameters: serde_json::Value, } -#[derive(Debug, Serialize, Deserialize)] -struct DatabricksCacheControl { - #[serde(rename = "type")] - cache_type: String, -} - #[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] enum DatabricksContent { @@ -1118,7 +1108,7 @@ enum DatabricksContent { content_type: String, text: String, #[serde(skip_serializing_if = "Option::is_none")] - cache_control: Option, + cache_control: Option, }, } @@ -1345,7 +1335,7 @@ mod tests { let messages_with = vec![Message::with_cache_control( MessageRole::User, "Hello".to_string(), - crate::CacheControl::Ephemeral, + crate::CacheControl::ephemeral(), )]; let databricks_messages_with = provider.convert_messages(&messages_with).unwrap(); let json_with = serde_json::to_string(&databricks_messages_with).unwrap(); diff --git a/crates/g3-providers/src/lib.rs b/crates/g3-providers/src/lib.rs index 8c80b64..f725c2f 100644 --- a/crates/g3-providers/src/lib.rs +++ b/crates/g3-providers/src/lib.rs @@ -38,13 +38,31 @@ pub struct CompletionRequest { } #[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CacheControl { + #[serde(rename = "type")] + pub cache_type: CacheType, + #[serde(skip_serializing_if = "Option::is_none")] + pub ttl: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(rename_all = "lowercase")] -pub enum CacheControl { +pub enum CacheType { Ephemeral, - #[serde(rename = "5minute")] - FiveMinute, - #[serde(rename = "1hour")] - OneHour, +} + +impl CacheControl { + pub fn ephemeral() -> Self { + Self { cache_type: CacheType::Ephemeral, ttl: None } + } + + pub fn five_minute() -> Self { + Self { cache_type: CacheType::Ephemeral, ttl: Some("5m".to_string()) } + } + + pub fn one_hour() -> Self { + Self { cache_type: CacheType::Ephemeral, ttl: Some("1h".to_string()) } + } } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -220,7 +238,7 @@ mod tests { let msg = Message::with_cache_control( MessageRole::User, "Hello".to_string(), - CacheControl::Ephemeral, + CacheControl::ephemeral(), ); let json = serde_json::to_string(&msg).unwrap(); @@ -229,7 +247,39 @@ mod tests { "JSON should contain 'cache_control' field when configured"); assert!(json.contains("ephemeral"), "JSON should contain 'ephemeral' value"); + assert!(json.contains("\"type\":"), + "JSON should contain 'type' field in cache_control"); assert!(!json.contains("null"), "JSON should not contain null values"); } + + #[test] + fn test_cache_control_five_minute_serialization() { + let msg = Message::with_cache_control( + MessageRole::User, + "Hello".to_string(), + CacheControl::five_minute(), + ); + let json = serde_json::to_string(&msg).unwrap(); + + println!("Message JSON with 5-minute cache_control: {}", json); + assert!(json.contains("cache_control"), "JSON should contain 'cache_control' field"); + assert!(json.contains("ephemeral"), "JSON should contain 'ephemeral' type"); + assert!(json.contains("\"ttl\":\"5m\""), "JSON should contain ttl field with 5m value"); + } + + #[test] + fn test_cache_control_one_hour_serialization() { + let msg = Message::with_cache_control( + MessageRole::User, + "Hello".to_string(), + CacheControl::one_hour(), + ); + let json = serde_json::to_string(&msg).unwrap(); + + println!("Message JSON with 1-hour cache_control: {}", json); + assert!(json.contains("cache_control"), "JSON should contain 'cache_control' field"); + assert!(json.contains("ephemeral"), "JSON should contain 'ephemeral' type"); + assert!(json.contains("\"ttl\":\"1h\""), "JSON should contain ttl field with 1h value"); + } }