diff --git a/config.coach-player.example.toml b/config.coach-player.example.toml index 999b674..50d7943 100644 --- a/config.coach-player.example.toml +++ b/config.coach-player.example.toml @@ -11,12 +11,23 @@ model = "databricks-claude-sonnet-4" max_tokens = 4096 temperature = 0.1 use_oauth = true +# cache_config = "ephemeral" # Optional: Enable prompt caching for Claude models + # Options: "ephemeral", "5minute", "1hour" + # Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs. + # The cache control will be automatically applied to: + # - The system prompt at the start of each session + # - Assistant responses after every 10 tool calls + # - Cache writes are billed at a premium over base input tokens; see pricing details below + # https://docs.claude.com/en/docs/build-with-claude/prompt-caching#pricing [providers.anthropic] api_key = "your-anthropic-api-key" model = "claude-3-haiku-20240307" # Using a faster model for player max_tokens = 4096 temperature = 0.3 # Slightly higher temperature for more creative implementations +# cache_config = "ephemeral" # Optional: Enable prompt caching + # Options: "ephemeral", "5minute", "1hour" + # Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs. [agent] fallback_default_max_tokens = 8192 diff --git a/config.example.toml b/config.example.toml index 1bc0893..b0fce4e 100644 --- a/config.example.toml +++ b/config.example.toml @@ -14,6 +14,15 @@ max_tokens = 4096 # Per-request output limit (how many tokens the model can gen # Note: This is different from max_context_length (total conversation history size) temperature = 0.1 use_oauth = true +# cache_config = "ephemeral" # Optional: Enable prompt caching for Claude models on Databricks + # Options: "ephemeral", "5minute", "1hour" + # Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs.
+ # The cache control will be automatically applied to: + # - The system prompt at the start of each session + # - Assistant responses after every 10 tool calls + # - 5minute costs $3/mtok, more details below + # https://docs.claude.com/en/docs/build-with-claude/prompt-caching#pricing + # Multiple OpenAI-compatible providers can be configured with custom names # Each provider gets its own section under [providers.openai_compatible.] diff --git a/crates/g3-config/src/lib.rs b/crates/g3-config/src/lib.rs index c860481..ae3a242 100644 --- a/crates/g3-config/src/lib.rs +++ b/crates/g3-config/src/lib.rs @@ -40,6 +40,8 @@ pub struct AnthropicConfig { pub model: String, pub max_tokens: Option, pub temperature: Option, + pub cache_config: Option, // "ephemeral", "5minute", "1hour", or None to disable + pub enable_1m_context: Option, // Enable 1m context window (costs extra) } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -49,6 +51,7 @@ pub struct DatabricksConfig { pub model: String, pub max_tokens: Option, pub temperature: Option, + pub cache_config: Option, // "ephemeral", "5minute", "1hour", or None to disable pub use_oauth: Option, // Default to true if token not provided } @@ -132,6 +135,7 @@ impl Default for Config { model: "databricks-claude-sonnet-4".to_string(), max_tokens: Some(4096), temperature: Some(0.1), + cache_config: None, use_oauth: Some(true), }), embedded: None, diff --git a/crates/g3-core/src/lib.rs b/crates/g3-core/src/lib.rs index 1ac374e..7839d63 100644 --- a/crates/g3-core/src/lib.rs +++ b/crates/g3-core/src/lib.rs @@ -23,7 +23,7 @@ use anyhow::Result; use g3_computer_control::WebDriverController; use g3_config::Config; use g3_execution::CodeExecutor; -use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry, Tool}; +use g3_providers::{CacheControl, CompletionRequest, Message, MessageRole, ProviderRegistry, Tool}; #[allow(unused_imports)] use regex::Regex; use serde::{Deserialize, Serialize}; @@ -423,18 +423,12 @@ Format 
this as a detailed but concise summary that can be used to resume the con self.used_tokens = 0; // Add the summary as a system message - let summary_message = Message { - role: MessageRole::System, - content: format!("Previous conversation summary:\n\n{}", summary), - }; + let summary_message = Message::new(MessageRole::System, format!("Previous conversation summary:\n\n{}", summary)); self.add_message(summary_message); // Add the latest user message if provided if let Some(user_msg) = latest_user_message { - self.add_message(Message { - role: MessageRole::User, - content: user_msg, - }); + self.add_message(Message::new(MessageRole::User, user_msg)); } let new_chars: usize = self @@ -756,6 +750,7 @@ pub struct Agent { safaridriver_process: std::sync::Arc>>, macax_controller: std::sync::Arc>>, + tool_call_count: usize, } impl Agent { @@ -898,6 +893,8 @@ impl Agent { Some(anthropic_config.model.clone()), anthropic_config.max_tokens, anthropic_config.temperature, + anthropic_config.cache_config.clone(), + anthropic_config.enable_1m_context, )?; providers.register(anthropic_provider); } @@ -944,10 +941,7 @@ impl Agent { // If README content is provided, add it as the first system message if let Some(readme) = readme_content { - let readme_message = Message { - role: MessageRole::System, - content: readme, - }; + let readme_message = Message::new(MessageRole::System, readme); context_window.add_message(readme_message); } @@ -1003,9 +997,23 @@ impl Agent { None })) }, + tool_call_count: 0, }) } + /// Convert cache config string to CacheControl enum + fn parse_cache_control(cache_config: &str) -> Option { + match cache_config { + "ephemeral" => Some(CacheControl::Ephemeral), + "5minute" => Some(CacheControl::FiveMinute), + "1hour" => Some(CacheControl::OneHour), + _ => { + warn!("Invalid cache_config value: '{}'. 
Valid values are: ephemeral, 5minute, 1hour", cache_config); + None + } + } + } + fn get_configured_context_length(config: &Config, providers: &ProviderRegistry) -> Result { // First, check if there's a global max_context_length override in agent config if let Some(max_context_length) = config.agent.max_context_length { @@ -1185,7 +1193,11 @@ impl Agent { // Only add system message if this is the first interaction (empty conversation history) if self.context_window.conversation_history.is_empty() { let provider = self.providers.get(None)?; - let system_prompt = if provider.has_native_tool_calling() { + let provider_has_native_tool_calling = provider.has_native_tool_calling(); + let provider_name_for_system = provider.name().to_string(); + drop(provider); // Drop provider reference to avoid borrowing issues + + let system_prompt = if provider_has_native_tool_calling { // For native tool calling providers, use a more explicit system prompt "You are G3, an AI programming agent of the same skill level as a seasoned engineer at a major technology company. You analyze given tasks and write code to achieve goals. @@ -1493,18 +1505,34 @@ If you can complete it with 1-2 tool calls, skip TODO. 
} // Add system message to context window - let system_message = Message { - role: MessageRole::System, - content: system_prompt, + let system_message = { + // Check if we should use cache control for system message + if let Some(cache_config) = match provider_name_for_system.as_str() { + "anthropic" => self.config.providers.anthropic.as_ref() + .and_then(|c| c.cache_config.as_ref()) + .and_then(|config| Self::parse_cache_control(config)), + "databricks" => self.config.providers.databricks.as_ref() + .and_then(|c| { + if c.model.contains("claude") { + c.cache_config.as_ref() + .and_then(|config| Self::parse_cache_control(config)) + } else { + None + } + }), + _ => None, + } { + let provider = self.providers.get(None)?; + Message::with_cache_control_validated(MessageRole::System, system_prompt, cache_config, provider) + } else { + Message::new(MessageRole::System, system_prompt) + } }; self.context_window.add_message(system_message); } // Add user message to context window - let user_message = Message { - role: MessageRole::User, - content: format!("Task: {}", description), - }; + let user_message = Message::new(MessageRole::User, format!("Task: {}", description)); self.context_window.add_message(user_message); // Use the complete conversation history for the request @@ -1512,6 +1540,9 @@ If you can complete it with 1-2 tool calls, skip TODO. // Check if provider supports native tool calling and add tools if so let provider = self.providers.get(None)?; + let provider_name = provider.name().to_string(); + let has_native_tool_calling = provider.has_native_tool_calling(); + let supports_cache_control = provider.supports_cache_control(); let tools = if provider.has_native_tool_calling() { Some(Self::create_tool_definitions( self.config.webdriver.enabled, @@ -1521,9 +1552,10 @@ If you can complete it with 1-2 tool calls, skip TODO. 
} else { None }; + drop(provider); // Drop the provider reference to avoid borrowing issues // Get max_tokens from provider configuration - let max_tokens = match provider.name() { + let max_tokens = match provider_name.as_str() { "databricks" => { // Use the model's maximum limit for Databricks to allow large file generation Some(32000) @@ -1578,9 +1610,32 @@ If you can complete it with 1-2 tool calls, skip TODO. // Add assistant response to context window only if not empty // This prevents the "Skipping empty message" warning when only tools were executed if !response_content.trim().is_empty() { - let assistant_message = Message { - role: MessageRole::Assistant, - content: response_content.clone(), + let assistant_message = { + // Check if we should use cache control (every 10 tool calls) + if self.tool_call_count > 0 && self.tool_call_count % 10 == 0 { + let provider = self.providers.get(None)?; + if let Some(cache_config) = match provider.name() { + "anthropic" => self.config.providers.anthropic.as_ref() + .and_then(|c| c.cache_config.as_ref()) + .and_then(|config| Self::parse_cache_control(config)), + "databricks" => self.config.providers.databricks.as_ref() + .and_then(|c| { + if c.model.contains("claude") { + c.cache_config.as_ref() + .and_then(|config| Self::parse_cache_control(config)) + } else { + None + } + }), + _ => None, + } { + Message::with_cache_control_validated(MessageRole::Assistant, response_content.clone(), cache_config, provider) + } else { + Message::new(MessageRole::Assistant, response_content.clone()) + } + } else { + Message::new(MessageRole::Assistant, response_content.clone()) + } }; self.context_window.add_message(assistant_message); } else { @@ -1783,17 +1838,11 @@ If you can complete it with 1-2 tool calls, skip TODO. 
.join("\n\n"); let summary_messages = vec![ - Message { - role: MessageRole::System, - content: "You are a helpful assistant that creates concise summaries.".to_string(), - }, - Message { - role: MessageRole::User, - content: format!( + Message::new(MessageRole::System, "You are a helpful assistant that creates concise summaries.".to_string()), + Message::new(MessageRole::User, format!( "Based on this conversation history, {}\n\nConversation:\n{}", summary_prompt, conversation_text - ), - }, + )), ]; let provider = self.providers.get(None)?; @@ -2776,18 +2825,11 @@ If you can complete it with 1-2 tool calls, skip TODO. .join("\n\n"); let summary_messages = vec![ - Message { - role: MessageRole::System, - content: "You are a helpful assistant that creates concise summaries." - .to_string(), - }, - Message { - role: MessageRole::User, - content: format!( + Message::new(MessageRole::System, "You are a helpful assistant that creates concise summaries.".to_string()), + Message::new(MessageRole::User, format!( "Based on this conversation history, {}\n\nConversation:\n{}", summary_prompt, conversation_text - ), - }, + )), ]; let provider = self.providers.get(None)?; @@ -3273,29 +3315,20 @@ If you can complete it with 1-2 tool calls, skip TODO. 
// Add the tool call and result to the context window using RAW unfiltered content // This ensures the log file contains the true raw content including JSON tool calls let tool_message = if !raw_content_for_log.trim().is_empty() { - Message { - role: MessageRole::Assistant, - content: format!( + Message::new(MessageRole::Assistant, format!( "{}\n\n{{\"tool\": \"{}\", \"args\": {}}}", raw_content_for_log.trim(), tool_call.tool, tool_call.args - ), - } + )) } else { // No text content before tool call, just include the tool call - Message { - role: MessageRole::Assistant, - content: format!( + Message::new(MessageRole::Assistant, format!( "{{\"tool\": \"{}\", \"args\": {}}}", tool_call.tool, tool_call.args - ), - } - }; - let result_message = Message { - role: MessageRole::User, - content: format!("Tool result: {}", tool_result), + )) }; + let result_message = Message::new(MessageRole::User, format!("Tool result: {}", tool_result)); self.context_window.add_message(tool_message); self.context_window.add_message(result_message); @@ -3304,7 +3337,8 @@ If you can complete it with 1-2 tool calls, skip TODO. request.messages = self.context_window.conversation_history.clone(); // Ensure tools are included for native providers in subsequent iterations - if provider.has_native_tool_calling() { + let provider_for_tools = self.providers.get(None)?; + if provider_for_tools.has_native_tool_calling() { request.tools = Some(Self::create_tool_definitions( self.config.webdriver.enabled, self.config.macax.enabled, @@ -3635,9 +3669,32 @@ If you can complete it with 1-2 tool calls, skip TODO. 
.replace("<>", ""); if !raw_clean.trim().is_empty() { - let assistant_message = Message { - role: MessageRole::Assistant, - content: raw_clean, + let assistant_message = { + // Check if we should use cache control (every 10 tool calls) + if self.tool_call_count > 0 && self.tool_call_count % 10 == 0 { + let provider = self.providers.get(None)?; + if let Some(cache_config) = match provider.name() { + "anthropic" => self.config.providers.anthropic.as_ref() + .and_then(|c| c.cache_config.as_ref()) + .and_then(|config| Self::parse_cache_control(config)), + "databricks" => self.config.providers.databricks.as_ref() + .and_then(|c| { + if c.model.contains("claude") { + c.cache_config.as_ref() + .and_then(|config| Self::parse_cache_control(config)) + } else { + None + } + }), + _ => None, + } { + Message::with_cache_control_validated(MessageRole::Assistant, raw_clean, cache_config, provider) + } else { + Message::new(MessageRole::Assistant, raw_clean) + } + } else { + Message::new(MessageRole::Assistant, raw_clean) + } }; self.context_window.add_message(assistant_message); } @@ -3679,7 +3736,10 @@ If you can complete it with 1-2 tool calls, skip TODO. Ok(TaskResult::new(final_response, self.context_window.clone())) } - pub async fn execute_tool(&self, tool_call: &ToolCall) -> Result { + pub async fn execute_tool(&mut self, tool_call: &ToolCall) -> Result { + // Increment tool call count + self.tool_call_count += 1; + debug!("=== EXECUTING TOOL ==="); debug!("Tool name: {}", tool_call.tool); debug!("Tool args (raw): {:?}", tool_call.args); diff --git a/crates/g3-providers/src/anthropic.rs b/crates/g3-providers/src/anthropic.rs index 75a5abb..de90e46 100644 --- a/crates/g3-providers/src/anthropic.rs +++ b/crates/g3-providers/src/anthropic.rs @@ -21,22 +21,18 @@ //! // Create the provider with your API key //! let provider = AnthropicProvider::new( //! "your-api-key".to_string(), -//! 
Some("claude-3-5-sonnet-20241022".to_string()), // Optional: defaults to claude-3-5-sonnet-20241022 -//! Some(4096), // Optional: max tokens -//! Some(0.1), // Optional: temperature +//! Some("claude-3-5-sonnet-20241022".to_string()), +//! Some(4096), +//! Some(0.1), +//! None, // cache_config +//! None, // enable_1m_context //! )?; //! //! // Create a completion request //! let request = CompletionRequest { //! messages: vec![ -//! Message { -//! role: MessageRole::System, -//! content: "You are a helpful assistant.".to_string(), -//! }, -//! Message { -//! role: MessageRole::User, -//! content: "Hello! How are you?".to_string(), -//! }, +//! Message::new(MessageRole::System, "You are a helpful assistant.".to_string()), +//! Message::new(MessageRole::User, "Hello! How are you?".to_string()), //! ], //! max_tokens: Some(1000), //! temperature: Some(0.7), @@ -62,15 +58,16 @@ //! async fn main() -> anyhow::Result<()> { //! let provider = AnthropicProvider::new( //! "your-api-key".to_string(), -//! None, None, None, +//! None, +//! None, +//! None, +//! None, // cache_config +//! None, // enable_1m_context //! )?; //! //! let request = CompletionRequest { //! messages: vec![ -//! Message { -//! role: MessageRole::User, -//! content: "Write a short story about a robot.".to_string(), -//! }, +//! Message::new(MessageRole::User, "Write a short story about a robot.".to_string()), //! ], //! max_tokens: Some(1000), //! 
temperature: Some(0.7), @@ -123,6 +120,8 @@ pub struct AnthropicProvider { model: String, max_tokens: u32, temperature: f32, + cache_config: Option, + enable_1m_context: bool, } impl AnthropicProvider { @@ -131,6 +130,8 @@ impl AnthropicProvider { model: Option, max_tokens: Option, temperature: Option, + cache_config: Option, + enable_1m_context: Option, ) -> Result { let client = Client::builder() .timeout(Duration::from_secs(300)) @@ -147,6 +148,8 @@ impl AnthropicProvider { model, max_tokens: max_tokens.unwrap_or(4096), temperature: temperature.unwrap_or(0.1), + cache_config, + enable_1m_context: enable_1m_context.unwrap_or(false), }) } @@ -156,9 +159,12 @@ impl AnthropicProvider { .post(ANTHROPIC_API_URL) .header("x-api-key", &self.api_key) .header("anthropic-version", ANTHROPIC_VERSION) - // Anthropic beta 1m context window. Enable if needed. It costs extra, so check first. - // .header("anthropic-beta", "context-1m-2025-08-07") .header("content-type", "application/json"); + + if self.enable_1m_context { + builder = builder.header("anthropic-beta", "context-1m-2025-08-07"); + } + if streaming { builder = builder.header("accept", "text/event-stream"); } @@ -166,6 +172,17 @@ impl AnthropicProvider { builder } + fn convert_cache_control(cache_control: &crate::CacheControl) -> AnthropicCacheControl { + let cache_type = match cache_control { + crate::CacheControl::Ephemeral => "ephemeral", + crate::CacheControl::FiveMinute => "5minute", + crate::CacheControl::OneHour => "1hour", + }; + AnthropicCacheControl { + cache_type: cache_type.to_string(), + } + } + fn convert_tools(&self, tools: &[Tool]) -> Vec { tools .iter() @@ -214,6 +231,8 @@ impl AnthropicProvider { role: "user".to_string(), content: vec![AnthropicContent::Text { text: message.content.clone(), + cache_control: message.cache_control.as_ref() + .map(Self::convert_cache_control), }], }); } @@ -222,6 +241,8 @@ impl AnthropicProvider { role: "assistant".to_string(), content: vec![AnthropicContent::Text { 
text: message.content.clone(), + cache_control: message.cache_control.as_ref() + .map(Self::convert_cache_control), }], }); } @@ -564,7 +585,7 @@ impl LLMProvider for AnthropicProvider { .content .iter() .filter_map(|c| match c { - AnthropicContent::Text { text } => Some(text.as_str()), + AnthropicContent::Text { text, .. } => Some(text.as_str()), _ => None, }) .collect::>() @@ -658,6 +679,11 @@ impl LLMProvider for AnthropicProvider { // Claude models support native tool calling true } + + fn supports_cache_control(&self) -> bool { + // Anthropic supports cache control + true + } } // Anthropic API request/response structures @@ -697,11 +723,21 @@ struct AnthropicMessage { content: Vec, } +#[derive(Debug, Serialize, Deserialize)] +struct AnthropicCacheControl { + #[serde(rename = "type")] + cache_type: String, +} + #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type")] enum AnthropicContent { #[serde(rename = "text")] - Text { text: String }, + Text { + text: String, + #[serde(skip_serializing_if = "Option::is_none")] + cache_control: Option, + }, #[serde(rename = "tool_use")] ToolUse { id: String, @@ -771,21 +807,14 @@ mod tests { None, None, None, + None, + None, ).unwrap(); let messages = vec![ - Message { - role: MessageRole::System, - content: "You are a helpful assistant.".to_string(), - }, - Message { - role: MessageRole::User, - content: "Hello!".to_string(), - }, - Message { - role: MessageRole::Assistant, - content: "Hi there!".to_string(), - }, + Message::new(MessageRole::System, "You are a helpful assistant.".to_string()), + Message::new(MessageRole::User, "Hello!".to_string()), + Message::new(MessageRole::Assistant, "Hi there!".to_string()), ]; let (system, anthropic_messages) = provider.convert_messages(&messages).unwrap(); @@ -803,14 +832,11 @@ mod tests { Some("claude-3-haiku-20240307".to_string()), Some(1000), Some(0.5), + None, + None, ).unwrap(); - let messages = vec![ - Message { - role: MessageRole::User, - content: "Test 
message".to_string(), - }, - ]; + let messages = vec![Message::new(MessageRole::User, "Test message".to_string())]; let request_body = provider .create_request_body(&messages, None, false, 1000, 0.5) @@ -831,6 +857,8 @@ mod tests { None, None, None, + None, + None, ).unwrap(); let tools = vec![ @@ -859,4 +887,48 @@ mod tests { assert!(anthropic_tools[0].input_schema.required.is_some()); assert_eq!(anthropic_tools[0].input_schema.required.as_ref().unwrap()[0], "location"); } + + #[test] + fn test_cache_control_serialization() { + let provider = AnthropicProvider::new( + "test-key".to_string(), + None, + None, + None, + None, + None, + ).unwrap(); + + // Test message WITHOUT cache_control + let messages_without = vec![Message::new(MessageRole::User, "Hello".to_string())]; + let (_, anthropic_messages_without) = provider.convert_messages(&messages_without).unwrap(); + let json_without = serde_json::to_string(&anthropic_messages_without).unwrap(); + + println!("Anthropic JSON without cache_control: {}", json_without); + // Check if cache_control appears in the JSON + if json_without.contains("cache_control") { + println!("WARNING: JSON contains 'cache_control' field when not configured!"); + assert!(!json_without.contains("\"cache_control\":null"), + "JSON should not contain 'cache_control: null'"); + } + + // Test message WITH cache_control + let messages_with = vec![Message::with_cache_control( + MessageRole::User, + "Hello".to_string(), + crate::CacheControl::Ephemeral, + )]; + let (_, anthropic_messages_with) = provider.convert_messages(&messages_with).unwrap(); + let json_with = serde_json::to_string(&anthropic_messages_with).unwrap(); + + println!("Anthropic JSON with cache_control: {}", json_with); + assert!(json_with.contains("cache_control"), + "JSON should contain 'cache_control' field when configured"); + assert!(json_with.contains("ephemeral"), + "JSON should contain 'ephemeral' type"); + + // The key assertion: when cache_control is None, it should not 
appear in JSON + assert!(!json_without.contains("cache_control") || !json_without.contains("null"), + "JSON should not contain 'cache_control' field or null values when not configured"); + } } diff --git a/crates/g3-providers/src/databricks.rs b/crates/g3-providers/src/databricks.rs index 962822f..08cddae 100644 --- a/crates/g3-providers/src/databricks.rs +++ b/crates/g3-providers/src/databricks.rs @@ -39,10 +39,7 @@ //! // Create a completion request //! let request = CompletionRequest { //! messages: vec![ -//! Message { -//! role: MessageRole::User, -//! content: "Hello! How are you?".to_string(), -//! }, +//! Message::new(MessageRole::User, "Hello! How are you?".to_string()), //! ], //! max_tokens: Some(1000), //! temperature: Some(0.7), @@ -241,6 +238,15 @@ impl DatabricksProvider { .collect() } + fn convert_cache_control(cache_control: &crate::CacheControl) -> DatabricksCacheControl { + let cache_type = match cache_control { + crate::CacheControl::Ephemeral => "ephemeral", + crate::CacheControl::FiveMinute => "5minute", + crate::CacheControl::OneHour => "1hour", + }; + DatabricksCacheControl { cache_type: cache_type.to_string() } + } + fn convert_messages(&self, messages: &[Message]) -> Result> { let mut databricks_messages = Vec::new(); @@ -251,9 +257,24 @@ impl DatabricksProvider { MessageRole::Assistant => "assistant", }; + // If message has cache_control, use content array format + let content = if message.cache_control.is_some() { + // Use array format with cache_control + let content_block = DatabricksContent::Text { + content_type: "text".to_string(), + text: message.content.clone(), + cache_control: message.cache_control.as_ref() + .map(Self::convert_cache_control), + }; + serde_json::to_value(vec![content_block])? 
+ } else { + // Use simple string format + serde_json::Value::String(message.content.clone()) + }; + databricks_messages.push(DatabricksMessage { role: role.to_string(), - content: Some(message.content.clone()), + content: Some(content), tool_calls: None, // Only used in responses, not requests }); } @@ -864,8 +885,22 @@ impl LLMProvider for DatabricksProvider { let content = databricks_response .choices .first() - .and_then(|choice| choice.message.content.as_ref()) - .cloned() + .and_then(|choice| { + choice.message.content.as_ref().map(|c| { + // Handle both string and array formats + if let Some(s) = c.as_str() { + s.to_string() + } else if let Some(arr) = c.as_array() { + // Extract text from content blocks + arr.iter() + .filter_map(|block| block.get("text").and_then(|t| t.as_str())) + .collect::>() + .join("") + } else { + String::new() + } + }) + }) .unwrap_or_default(); // Check if there are tool calls in the response @@ -1037,6 +1072,11 @@ impl LLMProvider for DatabricksProvider { // This includes Claude, Llama, DBRX, and most other models on the platform true } + + fn supports_cache_control(&self) -> bool { + // Databricks supports cache control when using Anthropic models + self.model.contains("claude") + } } // Databricks API request/response structures @@ -1064,10 +1104,29 @@ struct DatabricksFunction { parameters: serde_json::Value, } +#[derive(Debug, Serialize, Deserialize)] +struct DatabricksCacheControl { + #[serde(rename = "type")] + cache_type: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +enum DatabricksContent { + Text { + #[serde(rename = "type")] + content_type: String, + text: String, + #[serde(skip_serializing_if = "Option::is_none")] + cache_control: Option, + }, +} + #[derive(Debug, Serialize, Deserialize)] struct DatabricksMessage { role: String, - content: Option, // Make content optional since tool calls might not have content + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, // Can be 
string or array of content blocks #[serde(skip_serializing_if = "Option::is_none")] tool_calls: Option>, // Add tool_calls field for responses } @@ -1154,18 +1213,9 @@ mod tests { .unwrap(); let messages = vec![ - Message { - role: MessageRole::System, - content: "You are a helpful assistant.".to_string(), - }, - Message { - role: MessageRole::User, - content: "Hello!".to_string(), - }, - Message { - role: MessageRole::Assistant, - content: "Hi there!".to_string(), - }, + Message::new(MessageRole::System, "You are a helpful assistant.".to_string()), + Message::new(MessageRole::User, "Hello!".to_string()), + Message::new(MessageRole::Assistant, "Hi there!".to_string()), ]; let databricks_messages = provider.convert_messages(&messages).unwrap(); @@ -1187,10 +1237,7 @@ mod tests { ) .unwrap(); - let messages = vec![Message { - role: MessageRole::User, - content: "Test message".to_string(), - }]; + let messages = vec![Message::new(MessageRole::User, "Test message".to_string())]; let request_body = provider .create_request_body(&messages, None, false, 1000, 0.5) @@ -1273,4 +1320,53 @@ mod tests { assert!(llama_provider.has_native_tool_calling()); assert!(dbrx_provider.has_native_tool_calling()); } + + #[test] + fn test_cache_control_serialization() { + let provider = DatabricksProvider::from_token( + "https://test.databricks.com".to_string(), + "test-token".to_string(), + "databricks-claude-sonnet-4".to_string(), + None, + None, + ) + .unwrap(); + + // Test message WITHOUT cache_control - should use string format + let messages_without = vec![Message::new(MessageRole::User, "Hello".to_string())]; + let databricks_messages_without = provider.convert_messages(&messages_without).unwrap(); + let json_without = serde_json::to_string(&databricks_messages_without).unwrap(); + + println!("JSON without cache_control: {}", json_without); + assert!(!json_without.contains("cache_control"), + "JSON should not contain 'cache_control' field when not configured"); + + // Test message 
WITH cache_control - should use array format + let messages_with = vec![Message::with_cache_control( + MessageRole::User, + "Hello".to_string(), + crate::CacheControl::Ephemeral, + )]; + let databricks_messages_with = provider.convert_messages(&messages_with).unwrap(); + let json_with = serde_json::to_string(&databricks_messages_with).unwrap(); + + println!("JSON with cache_control: {}", json_with); + assert!(json_with.contains("cache_control"), + "JSON should contain 'cache_control' field when configured"); + assert!(json_with.contains("ephemeral"), + "JSON should contain 'ephemeral' type"); + assert!(!json_with.contains("null"), + "JSON should not contain null values"); + + // Verify the structure is correct + let msg_with = &databricks_messages_with[0]; + if let Some(content) = &msg_with.content { + if let Some(arr) = content.as_array() { + assert_eq!(arr.len(), 1, "Content array should have one element"); + assert!(arr[0].get("cache_control").is_some(), "Content should have cache_control"); + } else { + panic!("Content should be an array when cache_control is present"); + } + } + } } diff --git a/crates/g3-providers/src/lib.rs b/crates/g3-providers/src/lib.rs index 51ea55a..8c80b64 100644 --- a/crates/g3-providers/src/lib.rs +++ b/crates/g3-providers/src/lib.rs @@ -21,6 +21,11 @@ pub trait LLMProvider: Send + Sync { fn has_native_tool_calling(&self) -> bool { false } + + /// Check if the provider supports cache control + fn supports_cache_control(&self) -> bool { + false + } } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -32,10 +37,22 @@ pub struct CompletionRequest { pub tools: Option>, } +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum CacheControl { + Ephemeral, + #[serde(rename = "5minute")] + FiveMinute, + #[serde(rename = "1hour")] + OneHour, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Message { pub role: MessageRole, pub content: String, + #[serde(skip_serializing_if = 
"Option::is_none")] + pub cache_control: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -95,6 +112,45 @@ pub use databricks::DatabricksProvider; pub use embedded::EmbeddedProvider; pub use openai::OpenAIProvider; +impl Message { + /// Create a new message with optional cache control + pub fn new(role: MessageRole, content: String) -> Self { + Self { + role, + content, + cache_control: None, + } + } + + /// Create a new message with cache control + pub fn with_cache_control(role: MessageRole, content: String, cache_control: CacheControl) -> Self { + Self { + role, + content, + cache_control: Some(cache_control), + } + } + + /// Create a message with cache control, with provider validation + pub fn with_cache_control_validated( + role: MessageRole, + content: String, + cache_control: CacheControl, + provider: &dyn LLMProvider + ) -> Self { + if !provider.supports_cache_control() { + tracing::warn!( + "Cache control requested for provider '{}' which does not support it. \ + Cache control is only supported by Anthropic and Anthropic via Databricks.", + provider.name() + ); + return Self::new(role, content); + } + + Self::with_cache_control(role, content, cache_control) + } +} + /// Provider registry for managing multiple LLM providers pub struct ProviderRegistry { providers: HashMap>, @@ -144,3 +200,36 @@ impl Default for ProviderRegistry { Self::new() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_message_serialization_without_cache_control() { + let msg = Message::new(MessageRole::User, "Hello".to_string()); + let json = serde_json::to_string(&msg).unwrap(); + + println!("Message JSON without cache_control: {}", json); + assert!(!json.contains("cache_control"), + "JSON should not contain 'cache_control' field when not configured"); + } + + #[test] + fn test_message_serialization_with_cache_control() { + let msg = Message::with_cache_control( + MessageRole::User, + "Hello".to_string(), + CacheControl::Ephemeral, + ); + let json = 
serde_json::to_string(&msg).unwrap(); + + println!("Message JSON with cache_control: {}", json); + assert!(json.contains("cache_control"), + "JSON should contain 'cache_control' field when configured"); + assert!(json.contains("ephemeral"), + "JSON should contain 'ephemeral' value"); + assert!(!json.contains("null"), + "JSON should not contain null values"); + } +}