diff --git a/config.coach-player.example.toml b/config.coach-player.example.toml index 999b674..50d7943 100644 --- a/config.coach-player.example.toml +++ b/config.coach-player.example.toml @@ -11,12 +11,23 @@ model = "databricks-claude-sonnet-4" max_tokens = 4096 temperature = 0.1 use_oauth = true +# cache_config = "ephemeral" # Optional: Enable prompt caching for Claude models + # Options: "ephemeral", "5minute", "1hour" + # Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs. + # The cache control will be automatically applied to: + # - The system prompt at the start of each session + # - Assistant responses after every 10 tool calls + # - 5minute costs $3/mtok, more details below + # https://docs.claude.com/en/docs/build-with-claude/prompt-caching#pricing [providers.anthropic] api_key = "your-anthropic-api-key" model = "claude-3-haiku-20240307" # Using a faster model for player max_tokens = 4096 temperature = 0.3 # Slightly higher temperature for more creative implementations +# cache_config = "ephemeral" # Optional: Enable prompt caching + # Options: "ephemeral", "5minute", "1hour" + # Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs. [agent] fallback_default_max_tokens = 8192 diff --git a/config.example.toml b/config.example.toml index 1bc0893..b0fce4e 100644 --- a/config.example.toml +++ b/config.example.toml @@ -14,6 +14,15 @@ max_tokens = 4096 # Per-request output limit (how many tokens the model can gen # Note: This is different from max_context_length (total conversation history size) temperature = 0.1 use_oauth = true +# cache_config = "ephemeral" # Optional: Enable prompt caching for Claude models on Databricks + # Options: "ephemeral", "5minute", "1hour" + # Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs. 
+ # The cache control will be automatically applied to: + # - The system prompt at the start of each session + # - Assistant responses after every 10 tool calls + # - 5minute costs $3/mtok, more details below + # https://docs.claude.com/en/docs/build-with-claude/prompt-caching#pricing + # Multiple OpenAI-compatible providers can be configured with custom names # Each provider gets its own section under [providers.openai_compatible.] diff --git a/crates/g3-config/src/lib.rs b/crates/g3-config/src/lib.rs index c860481..827a83b 100644 --- a/crates/g3-config/src/lib.rs +++ b/crates/g3-config/src/lib.rs @@ -40,6 +40,8 @@ pub struct AnthropicConfig { pub model: String, pub max_tokens: Option, pub temperature: Option, + pub cache_config: Option, // "ephemeral", "5minute", "1hour", or None to disable + pub enable_1m_context: Option, // Enable 1m context window (costs extra) } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/crates/g3-core/src/lib.rs b/crates/g3-core/src/lib.rs index 1ac374e..ffc13e6 100644 --- a/crates/g3-core/src/lib.rs +++ b/crates/g3-core/src/lib.rs @@ -3,6 +3,8 @@ pub mod error_handling; pub mod project; pub mod task_result; pub mod ui_writer; + +use std::process::exit; pub use task_result::TaskResult; #[cfg(test)] @@ -23,7 +25,7 @@ use anyhow::Result; use g3_computer_control::WebDriverController; use g3_config::Config; use g3_execution::CodeExecutor; -use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry, Tool}; +use g3_providers::{CacheControl, CompletionRequest, Message, MessageRole, ProviderRegistry, Tool}; #[allow(unused_imports)] use regex::Regex; use serde::{Deserialize, Serialize}; @@ -423,18 +425,12 @@ Format this as a detailed but concise summary that can be used to resume the con self.used_tokens = 0; // Add the summary as a system message - let summary_message = Message { - role: MessageRole::System, - content: format!("Previous conversation summary:\n\n{}", summary), - }; + let summary_message = 
Message::new(MessageRole::System, format!("Previous conversation summary:\n\n{}", summary)); self.add_message(summary_message); // Add the latest user message if provided if let Some(user_msg) = latest_user_message { - self.add_message(Message { - role: MessageRole::User, - content: user_msg, - }); + self.add_message(Message::new(MessageRole::User, user_msg)); } let new_chars: usize = self @@ -756,6 +752,7 @@ pub struct Agent { safaridriver_process: std::sync::Arc>>, macax_controller: std::sync::Arc>>, + tool_call_count: usize, } impl Agent { @@ -898,6 +895,8 @@ impl Agent { Some(anthropic_config.model.clone()), anthropic_config.max_tokens, anthropic_config.temperature, + anthropic_config.cache_config.clone(), + anthropic_config.enable_1m_context, )?; providers.register(anthropic_provider); } @@ -944,10 +943,7 @@ impl Agent { // If README content is provided, add it as the first system message if let Some(readme) = readme_content { - let readme_message = Message { - role: MessageRole::System, - content: readme, - }; + let readme_message = Message::new(MessageRole::System, readme); context_window.add_message(readme_message); } @@ -1003,9 +999,23 @@ impl Agent { None })) }, + tool_call_count: 0, }) } + /// Convert cache config string to CacheControl enum + fn parse_cache_control(cache_config: &str) -> Option { + match cache_config { + "ephemeral" => Some(CacheControl::ephemeral()), + "5minute" => Some(CacheControl::five_minute()), + "1hour" => Some(CacheControl::one_hour()), + _ => { + warn!("Invalid cache_config value: '{}'. 
Valid values are: ephemeral, 5minute, 1hour", cache_config); + None + } + } + } + fn get_configured_context_length(config: &Config, providers: &ProviderRegistry) -> Result { // First, check if there's a global max_context_length override in agent config if let Some(max_context_length) = config.agent.max_context_length { @@ -1185,7 +1195,11 @@ impl Agent { // Only add system message if this is the first interaction (empty conversation history) if self.context_window.conversation_history.is_empty() { let provider = self.providers.get(None)?; - let system_prompt = if provider.has_native_tool_calling() { + let provider_has_native_tool_calling = provider.has_native_tool_calling(); + let provider_name_for_system = provider.name().to_string(); + drop(provider); // Drop provider reference to avoid borrowing issues + + let system_prompt = if provider_has_native_tool_calling { // For native tool calling providers, use a more explicit system prompt "You are G3, an AI programming agent of the same skill level as a seasoned engineer at a major technology company. You analyze given tasks and write code to achieve goals. @@ -1493,18 +1507,26 @@ If you can complete it with 1-2 tool calls, skip TODO. 
} // Add system message to context window - let system_message = Message { - role: MessageRole::System, - content: system_prompt, + let system_message = { + // Check if we should use cache control for system message + if let Some(cache_config) = match provider_name_for_system.as_str() { + "anthropic" => self.config.providers.anthropic.as_ref() + .and_then(|c| c.cache_config.as_ref()) + .and_then(|config| Self::parse_cache_control(config)), + _ => None, + } { + let provider = self.providers.get(None)?; + Message::with_cache_control_validated(MessageRole::System, system_prompt, cache_config, provider) + } else { + Message::new(MessageRole::System, system_prompt) + } }; + self.context_window.add_message(system_message); } // Add user message to context window - let user_message = Message { - role: MessageRole::User, - content: format!("Task: {}", description), - }; + let user_message = Message::new(MessageRole::User, format!("Task: {}", description)); self.context_window.add_message(user_message); // Use the complete conversation history for the request @@ -1512,6 +1534,9 @@ If you can complete it with 1-2 tool calls, skip TODO. // Check if provider supports native tool calling and add tools if so let provider = self.providers.get(None)?; + let provider_name = provider.name().to_string(); + let has_native_tool_calling = provider.has_native_tool_calling(); + let supports_cache_control = provider.supports_cache_control(); let tools = if provider.has_native_tool_calling() { Some(Self::create_tool_definitions( self.config.webdriver.enabled, @@ -1521,9 +1546,10 @@ If you can complete it with 1-2 tool calls, skip TODO. 
} else { None }; + drop(provider); // Drop the provider reference to avoid borrowing issues // Get max_tokens from provider configuration - let max_tokens = match provider.name() { + let max_tokens = match provider_name.as_str() { "databricks" => { // Use the model's maximum limit for Databricks to allow large file generation Some(32000) @@ -1578,9 +1604,23 @@ If you can complete it with 1-2 tool calls, skip TODO. // Add assistant response to context window only if not empty // This prevents the "Skipping empty message" warning when only tools were executed if !response_content.trim().is_empty() { - let assistant_message = Message { - role: MessageRole::Assistant, - content: response_content.clone(), + let assistant_message = { + // Check if we should use cache control (every 10 tool calls) + if self.tool_call_count > 0 && self.tool_call_count % 10 == 0 { + let provider = self.providers.get(None)?; + if let Some(cache_config) = match provider.name() { + "anthropic" => self.config.providers.anthropic.as_ref() + .and_then(|c| c.cache_config.as_ref()) + .and_then(|config| Self::parse_cache_control(config)), + _ => None, + } { + Message::with_cache_control_validated(MessageRole::Assistant, response_content.clone(), cache_config, provider) + } else { + Message::new(MessageRole::Assistant, response_content.clone()) + } + } else { + Message::new(MessageRole::Assistant, response_content.clone()) + } }; self.context_window.add_message(assistant_message); } else { @@ -1783,17 +1823,11 @@ If you can complete it with 1-2 tool calls, skip TODO. 
.join("\n\n"); let summary_messages = vec![ - Message { - role: MessageRole::System, - content: "You are a helpful assistant that creates concise summaries.".to_string(), - }, - Message { - role: MessageRole::User, - content: format!( + Message::new(MessageRole::System, "You are a helpful assistant that creates concise summaries.".to_string()), + Message::new(MessageRole::User, format!( "Based on this conversation history, {}\n\nConversation:\n{}", summary_prompt, conversation_text - ), - }, + )), ]; let provider = self.providers.get(None)?; @@ -2776,18 +2810,11 @@ If you can complete it with 1-2 tool calls, skip TODO. .join("\n\n"); let summary_messages = vec![ - Message { - role: MessageRole::System, - content: "You are a helpful assistant that creates concise summaries." - .to_string(), - }, - Message { - role: MessageRole::User, - content: format!( + Message::new(MessageRole::System, "You are a helpful assistant that creates concise summaries.".to_string()), + Message::new(MessageRole::User, format!( "Based on this conversation history, {}\n\nConversation:\n{}", summary_prompt, conversation_text - ), - }, + )), ]; let provider = self.providers.get(None)?; @@ -3273,29 +3300,20 @@ If you can complete it with 1-2 tool calls, skip TODO. 
// Add the tool call and result to the context window using RAW unfiltered content // This ensures the log file contains the true raw content including JSON tool calls let tool_message = if !raw_content_for_log.trim().is_empty() { - Message { - role: MessageRole::Assistant, - content: format!( + Message::new(MessageRole::Assistant, format!( "{}\n\n{{\"tool\": \"{}\", \"args\": {}}}", raw_content_for_log.trim(), tool_call.tool, tool_call.args - ), - } + )) } else { // No text content before tool call, just include the tool call - Message { - role: MessageRole::Assistant, - content: format!( + Message::new(MessageRole::Assistant, format!( "{{\"tool\": \"{}\", \"args\": {}}}", tool_call.tool, tool_call.args - ), - } - }; - let result_message = Message { - role: MessageRole::User, - content: format!("Tool result: {}", tool_result), + )) }; + let result_message = Message::new(MessageRole::User, format!("Tool result: {}", tool_result)); self.context_window.add_message(tool_message); self.context_window.add_message(result_message); @@ -3304,7 +3322,8 @@ If you can complete it with 1-2 tool calls, skip TODO. request.messages = self.context_window.conversation_history.clone(); // Ensure tools are included for native providers in subsequent iterations - if provider.has_native_tool_calling() { + let provider_for_tools = self.providers.get(None)?; + if provider_for_tools.has_native_tool_calling() { request.tools = Some(Self::create_tool_definitions( self.config.webdriver.enabled, self.config.macax.enabled, @@ -3635,9 +3654,23 @@ If you can complete it with 1-2 tool calls, skip TODO. 
.replace("<>", ""); if !raw_clean.trim().is_empty() { - let assistant_message = Message { - role: MessageRole::Assistant, - content: raw_clean, + let assistant_message = { + // Check if we should use cache control (every 10 tool calls) + if self.tool_call_count > 0 && self.tool_call_count % 10 == 0 { + let provider = self.providers.get(None)?; + if let Some(cache_config) = match provider.name() { + "anthropic" => self.config.providers.anthropic.as_ref() + .and_then(|c| c.cache_config.as_ref()) + .and_then(|config| Self::parse_cache_control(config)), + _ => None, + } { + Message::with_cache_control_validated(MessageRole::Assistant, raw_clean, cache_config, provider) + } else { + Message::new(MessageRole::Assistant, raw_clean) + } + } else { + Message::new(MessageRole::Assistant, raw_clean) + } }; self.context_window.add_message(assistant_message); } @@ -3679,7 +3712,10 @@ If you can complete it with 1-2 tool calls, skip TODO. Ok(TaskResult::new(final_response, self.context_window.clone())) } - pub async fn execute_tool(&self, tool_call: &ToolCall) -> Result { + pub async fn execute_tool(&mut self, tool_call: &ToolCall) -> Result { + // Increment tool call count + self.tool_call_count += 1; + debug!("=== EXECUTING TOOL ==="); debug!("Tool name: {}", tool_call.tool); debug!("Tool args (raw): {:?}", tool_call.args); diff --git a/crates/g3-core/src/task_result_comprehensive_tests.rs b/crates/g3-core/src/task_result_comprehensive_tests.rs index 3164d68..0f15f49 100644 --- a/crates/g3-core/src/task_result_comprehensive_tests.rs +++ b/crates/g3-core/src/task_result_comprehensive_tests.rs @@ -6,14 +6,10 @@ use std::sync::Arc; fn test_task_result_basic_functionality() { // Create a context window with some messages let mut context = ContextWindow::new(10000); - context.add_message(Message { - role: MessageRole::User, - content: "Test message 1".to_string(), - }); - context.add_message(Message { - role: MessageRole::Assistant, - content: "Response 1".to_string(), - }); + 
context.add_message(Message::new(MessageRole::User, "Test message 1".to_string()) + ); + context.add_message(Message::new(MessageRole::Assistant, "Response 1".to_string()) + ); // Create a TaskResult let response = "This is the response\n\nFinal output block".to_string(); @@ -100,10 +96,7 @@ fn test_context_window_preservation() { // Add some messages for i in 0..5 { - context.add_message(Message { - role: if i % 2 == 0 { MessageRole::User } else { MessageRole::Assistant }, - content: format!("Message {}", i), - }); + context.add_message(Message::new(if i % 2 == 0 { MessageRole::User } else { MessageRole::Assistant }, format!("Message {}", i))); } // Create TaskResult diff --git a/crates/g3-core/tests/test_context_thinning.rs b/crates/g3-core/tests/test_context_thinning.rs index db6761f..f0ef2a1 100644 --- a/crates/g3-core/tests/test_context_thinning.rs +++ b/crates/g3-core/tests/test_context_thinning.rs @@ -46,10 +46,10 @@ fn test_thin_context_basic() { // Add some messages to the first third for i in 0..9 { if i % 2 == 0 { - context.add_message(Message { - role: MessageRole::Assistant, - content: format!("Assistant message {}", i), - }); + context.add_message(Message::new( + MessageRole::Assistant, + format!("Assistant message {}", i), + )); } else { // Add tool results with varying sizes let content = if i == 1 { @@ -63,10 +63,10 @@ fn test_thin_context_basic() { format!("Tool result: small result {}", i) }; - context.add_message(Message { - role: MessageRole::User, + context.add_message(Message::new( + MessageRole::User, content, - }); + )); } } @@ -98,10 +98,10 @@ fn test_thin_write_file_tool_calls() { let mut context = ContextWindow::new(10000); // Add some messages including a write_file tool call with large content - context.add_message(Message { - role: MessageRole::User, - content: "Please create a large file".to_string(), - }); + context.add_message(Message::new( + MessageRole::User, + "Please create a large file".to_string(), + )); // Add an assistant 
message with a write_file tool call containing large content let large_content = "x".repeat(1500); @@ -109,22 +109,22 @@ fn test_thin_write_file_tool_calls() { r#"{{"tool": "write_file", "args": {{"file_path": "test.txt", "content": "{}"}}}}"#, large_content ); - context.add_message(Message { - role: MessageRole::Assistant, - content: format!("I'll create that file.\n\n{}", tool_call_json), - }); + context.add_message(Message::new( + MessageRole::Assistant, + format!("I'll create that file.\n\n{}", tool_call_json), + )); - context.add_message(Message { - role: MessageRole::User, - content: "Tool result: ✅ Successfully wrote 1500 lines".to_string(), - }); + context.add_message(Message::new( + MessageRole::User, + "Tool result: ✅ Successfully wrote 1500 lines".to_string(), + )); // Add more messages to ensure we have enough for "first third" logic for i in 0..6 { - context.add_message(Message { - role: MessageRole::Assistant, - content: format!("Response {}", i), - }); + context.add_message(Message::new( + MessageRole::Assistant, + format!("Response {}", i), + )); } // Trigger thinning at 50% @@ -154,10 +154,10 @@ fn test_thin_str_replace_tool_calls() { let mut context = ContextWindow::new(10000); // Add some messages including a str_replace tool call with large diff - context.add_message(Message { - role: MessageRole::User, - content: "Please update the file".to_string(), - }); + context.add_message(Message::new( + MessageRole::User, + "Please update the file".to_string(), + )); // Add an assistant message with a str_replace tool call containing large diff let large_diff = format!("--- old\n{}\n+++ new\n{}", "-old line\n".repeat(100), "+new line\n".repeat(100)); @@ -165,22 +165,22 @@ fn test_thin_str_replace_tool_calls() { r#"{{"tool": "str_replace", "args": {{"file_path": "test.txt", "diff": "{}"}}}}"#, large_diff.replace('\n', "\\n") ); - context.add_message(Message { - role: MessageRole::Assistant, - content: format!("I'll update that file.\n\n{}", 
tool_call_json), - }); + context.add_message(Message::new( + MessageRole::Assistant, + format!("I'll update that file.\n\n{}", tool_call_json), + )); - context.add_message(Message { - role: MessageRole::User, - content: "Tool result: ✅ applied unified diff".to_string(), - }); + context.add_message(Message::new( + MessageRole::User, + "Tool result: ✅ applied unified diff".to_string(), + )); // Add more messages to ensure we have enough for "first third" logic for i in 0..6 { - context.add_message(Message { - role: MessageRole::Assistant, - content: format!("Response {}", i), - }); + context.add_message(Message::new( + MessageRole::Assistant, + format!("Response {}", i), + )); } // Trigger thinning at 50% @@ -212,10 +212,10 @@ fn test_thin_context_no_large_results() { // Add only small messages for i in 0..9 { - context.add_message(Message { - role: MessageRole::User, - content: format!("Tool result: small {}", i), - }); + context.add_message(Message::new( + MessageRole::User, + format!("Tool result: small {}", i), + )); } context.used_tokens = 5000; @@ -244,7 +244,7 @@ fn test_thin_context_only_affects_first_third() { MessageRole::Assistant }; - context.add_message(Message { role, content }); + context.add_message(Message::new(role, content)); } context.used_tokens = 5000; diff --git a/crates/g3-core/tests/test_todo_context_thinning.rs b/crates/g3-core/tests/test_todo_context_thinning.rs index 016e3e6..27443a9 100644 --- a/crates/g3-core/tests/test_todo_context_thinning.rs +++ b/crates/g3-core/tests/test_todo_context_thinning.rs @@ -8,27 +8,18 @@ fn test_todo_read_results_not_thinned() { let mut context = ContextWindow::new(10000); // Add a todo_read tool call - context.add_message(Message { - role: MessageRole::Assistant, - content: r#"{"tool": "todo_read", "args": {}}"#.to_string(), - }); + context.add_message(Message::new(MessageRole::Assistant, r#"{"tool": "todo_read", "args": {}}"#.to_string())); // Add a large TODO result (> 500 chars) let large_todo_result = 
format!( "Tool result: 📝 TODO list:\n{}", "- [ ] Task with long description\n".repeat(50) ); - context.add_message(Message { - role: MessageRole::User, - content: large_todo_result.clone(), - }); + context.add_message(Message::new(MessageRole::User, large_todo_result.clone())); // Add more messages to ensure we have enough for "first third" logic for i in 0..6 { - context.add_message(Message { - role: MessageRole::Assistant, - content: format!("Response {}", i), - }); + context.add_message(Message::new(MessageRole::Assistant, format!("Response {}", i))) } // Trigger thinning at 50% @@ -65,27 +56,18 @@ fn test_todo_write_results_not_thinned() { // Add a todo_write tool call let large_content = "- [ ] Task\n".repeat(100); - context.add_message(Message { - role: MessageRole::Assistant, - content: format!(r#"{{"tool": "todo_write", "args": {{"content": "{}"}}}}"#, large_content), - }); + context.add_message(Message::new(MessageRole::Assistant, format!(r#"{{"tool": "todo_write", "args": {{"content": "{}"}}}}"#, large_content))); // Add a large TODO write result let large_todo_result = format!( "Tool result: ✅ TODO list updated ({} chars) and saved to todo.g3.md", large_content.len() ); - context.add_message(Message { - role: MessageRole::User, - content: large_todo_result.clone(), - }); + context.add_message(Message::new(MessageRole::User, large_todo_result.clone())); // Add more messages for i in 0..6 { - context.add_message(Message { - role: MessageRole::Assistant, - content: format!("Response {}", i), - }); + context.add_message(Message::new(MessageRole::Assistant, format!("Response {}", i))) } // Trigger thinning at 50% @@ -119,24 +101,15 @@ fn test_non_todo_results_still_thinned() { let mut context = ContextWindow::new(10000); // Add a non-TODO tool call (e.g., read_file) - context.add_message(Message { - role: MessageRole::Assistant, - content: r#"{"tool": "read_file", "args": {"file_path": "test.txt"}}"#.to_string(), - }); + 
context.add_message(Message::new(MessageRole::Assistant, r#"{"tool": "read_file", "args": {"file_path": "test.txt"}}"#.to_string())); // Add a large read_file result (> 500 chars) let large_result = format!("Tool result: {}", "x".repeat(1500)); - context.add_message(Message { - role: MessageRole::User, - content: large_result, - }); + context.add_message(Message::new(MessageRole::User, large_result)); // Add more messages for i in 0..6 { - context.add_message(Message { - role: MessageRole::Assistant, - content: format!("Response {}", i), - }); + context.add_message(Message::new(MessageRole::Assistant, format!("Response {}", i))) } // Trigger thinning at 50% @@ -172,27 +145,18 @@ fn test_todo_read_with_spaces_in_tool_name() { let mut context = ContextWindow::new(10000); // Add a todo_read tool call with spaces (JSON formatting variation) - context.add_message(Message { - role: MessageRole::Assistant, - content: r#"{"tool": "todo_read", "args": {}}"#.to_string(), - }); + context.add_message(Message::new(MessageRole::Assistant, r#"{"tool": "todo_read", "args": {}}"#.to_string())); // Add a large TODO result let large_todo_result = format!( "Tool result: 📝 TODO list:\n{}", "- [ ] Task\n".repeat(50) ); - context.add_message(Message { - role: MessageRole::User, - content: large_todo_result.clone(), - }); + context.add_message(Message::new(MessageRole::User, large_todo_result.clone())); // Add more messages for i in 0..6 { - context.add_message(Message { - role: MessageRole::Assistant, - content: format!("Response {}", i), - }); + context.add_message(Message::new(MessageRole::Assistant, format!("Response {}", i))) } // Trigger thinning diff --git a/crates/g3-core/tests/test_todo_persistence.rs b/crates/g3-core/tests/test_todo_persistence.rs index f43eed3..69baabd 100644 --- a/crates/g3-core/tests/test_todo_persistence.rs +++ b/crates/g3-core/tests/test_todo_persistence.rs @@ -27,7 +27,7 @@ fn get_todo_path(temp_dir: &TempDir) -> PathBuf { #[serial] async fn 
test_todo_write_creates_file() { let temp_dir = TempDir::new().unwrap(); - let agent = create_test_agent_in_dir(&temp_dir).await; + let mut agent = create_test_agent_in_dir(&temp_dir).await; let todo_path = get_todo_path(&temp_dir); // Initially, todo.g3.md should not exist @@ -67,7 +67,7 @@ async fn test_todo_read_from_file() { fs::write(&todo_path, test_content).unwrap(); // Create agent (should load from file) - let agent = create_test_agent_in_dir(&temp_dir).await; + let mut agent = create_test_agent_in_dir(&temp_dir).await; // Create a tool call to read TODO let tool_call = g3_core::ToolCall { @@ -88,7 +88,7 @@ async fn test_todo_read_from_file() { #[serial] async fn test_todo_read_empty_file() { let temp_dir = TempDir::new().unwrap(); - let agent = create_test_agent_in_dir(&temp_dir).await; + let mut agent = create_test_agent_in_dir(&temp_dir).await; // Create a tool call to read TODO (file doesn't exist) let tool_call = g3_core::ToolCall { @@ -111,7 +111,7 @@ async fn test_todo_persistence_across_agents() { // Agent 1: Write TODO { - let agent = create_test_agent_in_dir(&temp_dir).await; + let mut agent = create_test_agent_in_dir(&temp_dir).await; let tool_call = g3_core::ToolCall { tool: "todo_write".to_string(), args: serde_json::json!({ @@ -126,7 +126,7 @@ async fn test_todo_persistence_across_agents() { // Agent 2: Read TODO (new agent instance) { - let agent = create_test_agent_in_dir(&temp_dir).await; + let mut agent = create_test_agent_in_dir(&temp_dir).await; let tool_call = g3_core::ToolCall { tool: "todo_read".to_string(), args: serde_json::json!({}), @@ -143,7 +143,7 @@ async fn test_todo_persistence_across_agents() { #[serial] async fn test_todo_update_preserves_file() { let temp_dir = TempDir::new().unwrap(); - let agent = create_test_agent_in_dir(&temp_dir).await; + let mut agent = create_test_agent_in_dir(&temp_dir).await; let todo_path = get_todo_path(&temp_dir); // Write initial TODO @@ -173,7 +173,7 @@ async fn 
test_todo_update_preserves_file() { #[serial] async fn test_todo_handles_large_content() { let temp_dir = TempDir::new().unwrap(); - let agent = create_test_agent_in_dir(&temp_dir).await; + let mut agent = create_test_agent_in_dir(&temp_dir).await; let todo_path = get_todo_path(&temp_dir); // Create a large TODO (but under the 50k limit) @@ -202,7 +202,7 @@ async fn test_todo_handles_large_content() { #[serial] async fn test_todo_respects_size_limit() { let temp_dir = TempDir::new().unwrap(); - let agent = create_test_agent_in_dir(&temp_dir).await; + let mut agent = create_test_agent_in_dir(&temp_dir).await; // Create content that exceeds the default 50k limit let huge_content = "x".repeat(60_000); @@ -232,7 +232,7 @@ async fn test_todo_agent_initialization_loads_file() { fs::write(&todo_path, initial_content).unwrap(); // Create agent - should load the file during initialization - let agent = create_test_agent_in_dir(&temp_dir).await; + let mut agent = create_test_agent_in_dir(&temp_dir).await; // Read TODO - should return the pre-existing content let tool_call = g3_core::ToolCall { @@ -248,7 +248,7 @@ async fn test_todo_agent_initialization_loads_file() { #[serial] async fn test_todo_handles_unicode_content() { let temp_dir = TempDir::new().unwrap(); - let agent = create_test_agent_in_dir(&temp_dir).await; + let mut agent = create_test_agent_in_dir(&temp_dir).await; let todo_path = get_todo_path(&temp_dir); // Create TODO with unicode characters @@ -283,7 +283,7 @@ async fn test_todo_handles_unicode_content() { #[serial] async fn test_todo_empty_content_creates_empty_file() { let temp_dir = TempDir::new().unwrap(); - let agent = create_test_agent_in_dir(&temp_dir).await; + let mut agent = create_test_agent_in_dir(&temp_dir).await; let todo_path = get_todo_path(&temp_dir); // Write empty TODO @@ -306,7 +306,7 @@ async fn test_todo_empty_content_creates_empty_file() { #[serial] async fn test_todo_whitespace_only_content() { let temp_dir = TempDir::new().unwrap(); - 
let agent = create_test_agent_in_dir(&temp_dir).await; + let mut agent = create_test_agent_in_dir(&temp_dir).await; // Write whitespace-only TODO let tool_call = g3_core::ToolCall { diff --git a/crates/g3-providers/src/anthropic.rs b/crates/g3-providers/src/anthropic.rs index 75a5abb..69ac66f 100644 --- a/crates/g3-providers/src/anthropic.rs +++ b/crates/g3-providers/src/anthropic.rs @@ -21,22 +21,18 @@ //! // Create the provider with your API key //! let provider = AnthropicProvider::new( //! "your-api-key".to_string(), -//! Some("claude-3-5-sonnet-20241022".to_string()), // Optional: defaults to claude-3-5-sonnet-20241022 -//! Some(4096), // Optional: max tokens -//! Some(0.1), // Optional: temperature +//! Some("claude-3-5-sonnet-20241022".to_string()), +//! Some(4096), +//! Some(0.1), +//! None, // cache_config +//! None, // enable_1m_context //! )?; //! //! // Create a completion request //! let request = CompletionRequest { //! messages: vec![ -//! Message { -//! role: MessageRole::System, -//! content: "You are a helpful assistant.".to_string(), -//! }, -//! Message { -//! role: MessageRole::User, -//! content: "Hello! How are you?".to_string(), -//! }, +//! Message::new(MessageRole::System, "You are a helpful assistant.".to_string()), +//! Message::new(MessageRole::User, "Hello! How are you?".to_string()), //! ], //! max_tokens: Some(1000), //! temperature: Some(0.7), @@ -62,15 +58,16 @@ //! async fn main() -> anyhow::Result<()> { //! let provider = AnthropicProvider::new( //! "your-api-key".to_string(), -//! None, None, None, +//! None, +//! None, +//! None, +//! None, // cache_config +//! None, // enable_1m_context //! )?; //! //! let request = CompletionRequest { //! messages: vec![ -//! Message { -//! role: MessageRole::User, -//! content: "Write a short story about a robot.".to_string(), -//! }, +//! Message::new(MessageRole::User, "Write a short story about a robot.".to_string()), //! ], //! max_tokens: Some(1000), //! 
temperature: Some(0.7), @@ -123,6 +120,8 @@ pub struct AnthropicProvider { model: String, max_tokens: u32, temperature: f32, + cache_config: Option, + enable_1m_context: bool, } impl AnthropicProvider { @@ -131,6 +130,8 @@ impl AnthropicProvider { model: Option, max_tokens: Option, temperature: Option, + cache_config: Option, + enable_1m_context: Option, ) -> Result { let client = Client::builder() .timeout(Duration::from_secs(300)) @@ -147,6 +148,8 @@ impl AnthropicProvider { model, max_tokens: max_tokens.unwrap_or(4096), temperature: temperature.unwrap_or(0.1), + cache_config, + enable_1m_context: enable_1m_context.unwrap_or(false), }) } @@ -156,9 +159,12 @@ impl AnthropicProvider { .post(ANTHROPIC_API_URL) .header("x-api-key", &self.api_key) .header("anthropic-version", ANTHROPIC_VERSION) - // Anthropic beta 1m context window. Enable if needed. It costs extra, so check first. - // .header("anthropic-beta", "context-1m-2025-08-07") .header("content-type", "application/json"); + + if self.enable_1m_context { + builder = builder.header("anthropic-beta", "context-1m-2025-08-07"); + } + if streaming { builder = builder.header("accept", "text/event-stream"); } @@ -166,6 +172,11 @@ impl AnthropicProvider { builder } + fn convert_cache_control(cache_control: &crate::CacheControl) -> crate::CacheControl { + // Anthropic uses the same format, so just clone it + cache_control.clone() + } + fn convert_tools(&self, tools: &[Tool]) -> Vec { tools .iter() @@ -214,6 +225,8 @@ impl AnthropicProvider { role: "user".to_string(), content: vec![AnthropicContent::Text { text: message.content.clone(), + cache_control: message.cache_control.as_ref() + .map(Self::convert_cache_control), }], }); } @@ -222,6 +235,8 @@ impl AnthropicProvider { role: "assistant".to_string(), content: vec![AnthropicContent::Text { text: message.content.clone(), + cache_control: message.cache_control.as_ref() + .map(Self::convert_cache_control), }], }); } @@ -564,7 +579,7 @@ impl LLMProvider for 
AnthropicProvider { .content .iter() .filter_map(|c| match c { - AnthropicContent::Text { text } => Some(text.as_str()), + AnthropicContent::Text { text, .. } => Some(text.as_str()), _ => None, }) .collect::>() @@ -658,6 +673,11 @@ impl LLMProvider for AnthropicProvider { // Claude models support native tool calling true } + + fn supports_cache_control(&self) -> bool { + // Anthropic supports cache control + true + } } // Anthropic API request/response structures @@ -701,7 +721,11 @@ struct AnthropicMessage { #[serde(tag = "type")] enum AnthropicContent { #[serde(rename = "text")] - Text { text: String }, + Text { + text: String, + #[serde(skip_serializing_if = "Option::is_none")] + cache_control: Option, + }, #[serde(rename = "tool_use")] ToolUse { id: String, @@ -771,21 +795,14 @@ mod tests { None, None, None, + None, + None, ).unwrap(); let messages = vec![ - Message { - role: MessageRole::System, - content: "You are a helpful assistant.".to_string(), - }, - Message { - role: MessageRole::User, - content: "Hello!".to_string(), - }, - Message { - role: MessageRole::Assistant, - content: "Hi there!".to_string(), - }, + Message::new(MessageRole::System, "You are a helpful assistant.".to_string()), + Message::new(MessageRole::User, "Hello!".to_string()), + Message::new(MessageRole::Assistant, "Hi there!".to_string()), ]; let (system, anthropic_messages) = provider.convert_messages(&messages).unwrap(); @@ -803,14 +820,11 @@ mod tests { Some("claude-3-haiku-20240307".to_string()), Some(1000), Some(0.5), + None, + None, ).unwrap(); - let messages = vec![ - Message { - role: MessageRole::User, - content: "Test message".to_string(), - }, - ]; + let messages = vec![Message::new(MessageRole::User, "Test message".to_string())]; let request_body = provider .create_request_body(&messages, None, false, 1000, 0.5) @@ -831,6 +845,8 @@ mod tests { None, None, None, + None, + None, ).unwrap(); let tools = vec![ @@ -859,4 +875,48 @@ mod tests { 
assert!(anthropic_tools[0].input_schema.required.is_some()); assert_eq!(anthropic_tools[0].input_schema.required.as_ref().unwrap()[0], "location"); } + + #[test] + fn test_cache_control_serialization() { + let provider = AnthropicProvider::new( + "test-key".to_string(), + None, + None, + None, + None, + None, + ).unwrap(); + + // Test message WITHOUT cache_control + let messages_without = vec![Message::new(MessageRole::User, "Hello".to_string())]; + let (_, anthropic_messages_without) = provider.convert_messages(&messages_without).unwrap(); + let json_without = serde_json::to_string(&anthropic_messages_without).unwrap(); + + println!("Anthropic JSON without cache_control: {}", json_without); + // Check if cache_control appears in the JSON + if json_without.contains("cache_control") { + println!("WARNING: JSON contains 'cache_control' field when not configured!"); + assert!(!json_without.contains("\"cache_control\":null"), + "JSON should not contain 'cache_control: null'"); + } + + // Test message WITH cache_control + let messages_with = vec![Message::with_cache_control( + MessageRole::User, + "Hello".to_string(), + crate::CacheControl::ephemeral(), + )]; + let (_, anthropic_messages_with) = provider.convert_messages(&messages_with).unwrap(); + let json_with = serde_json::to_string(&anthropic_messages_with).unwrap(); + + println!("Anthropic JSON with cache_control: {}", json_with); + assert!(json_with.contains("cache_control"), + "JSON should contain 'cache_control' field when configured"); + assert!(json_with.contains("ephemeral"), + "JSON should contain 'ephemeral' type"); + + // The key assertion: when cache_control is None, it should not appear in JSON + assert!(!json_without.contains("cache_control") || !json_without.contains("null"), + "JSON should not contain 'cache_control' field or null values when not configured"); + } } diff --git a/crates/g3-providers/src/databricks.rs b/crates/g3-providers/src/databricks.rs index 962822f..d7aed9b 100644 --- 
a/crates/g3-providers/src/databricks.rs +++ b/crates/g3-providers/src/databricks.rs @@ -39,10 +39,7 @@ //! // Create a completion request //! let request = CompletionRequest { //! messages: vec![ -//! Message { -//! role: MessageRole::User, -//! content: "Hello! How are you?".to_string(), -//! }, +//! Message::new(MessageRole::User, "Hello! How are you?".to_string()), //! ], //! max_tokens: Some(1000), //! temperature: Some(0.7), @@ -251,9 +248,12 @@ impl DatabricksProvider { MessageRole::Assistant => "assistant", }; + // Always use simple string format (Databricks doesn't support cache_control) + let content = serde_json::Value::String(message.content.clone()); + databricks_messages.push(DatabricksMessage { role: role.to_string(), - content: Some(message.content.clone()), + content: Some(content), tool_calls: None, // Only used in responses, not requests }); } @@ -864,8 +864,22 @@ impl LLMProvider for DatabricksProvider { let content = databricks_response .choices .first() - .and_then(|choice| choice.message.content.as_ref()) - .cloned() + .and_then(|choice| { + choice.message.content.as_ref().map(|c| { + // Handle both string and array formats + if let Some(s) = c.as_str() { + s.to_string() + } else if let Some(arr) = c.as_array() { + // Extract text from content blocks + arr.iter() + .filter_map(|block| block.get("text").and_then(|t| t.as_str())) + .collect::>() + .join("") + } else { + String::new() + } + }) + }) .unwrap_or_default(); // Check if there are tool calls in the response @@ -1037,6 +1051,10 @@ impl LLMProvider for DatabricksProvider { // This includes Claude, Llama, DBRX, and most other models on the platform true } + + fn supports_cache_control(&self) -> bool { + false + } } // Databricks API request/response structures @@ -1067,7 +1085,8 @@ struct DatabricksFunction { #[derive(Debug, Serialize, Deserialize)] struct DatabricksMessage { role: String, - content: Option, // Make content optional since tool calls might not have content + 
#[serde(skip_serializing_if = "Option::is_none")] + content: Option, // Can be string or array of content blocks #[serde(skip_serializing_if = "Option::is_none")] tool_calls: Option>, // Add tool_calls field for responses } @@ -1154,18 +1173,9 @@ mod tests { .unwrap(); let messages = vec![ - Message { - role: MessageRole::System, - content: "You are a helpful assistant.".to_string(), - }, - Message { - role: MessageRole::User, - content: "Hello!".to_string(), - }, - Message { - role: MessageRole::Assistant, - content: "Hi there!".to_string(), - }, + Message::new(MessageRole::System, "You are a helpful assistant.".to_string()), + Message::new(MessageRole::User, "Hello!".to_string()), + Message::new(MessageRole::Assistant, "Hi there!".to_string()), ]; let databricks_messages = provider.convert_messages(&messages).unwrap(); @@ -1187,10 +1197,7 @@ mod tests { ) .unwrap(); - let messages = vec![Message { - role: MessageRole::User, - content: "Test message".to_string(), - }]; + let messages = vec![Message::new(MessageRole::User, "Test message".to_string())]; let request_body = provider .create_request_body(&messages, None, false, 1000, 0.5) @@ -1273,4 +1280,62 @@ mod tests { assert!(llama_provider.has_native_tool_calling()); assert!(dbrx_provider.has_native_tool_calling()); } + + #[test] + fn test_cache_control_serialization() { + let provider = DatabricksProvider::from_token( + "https://test.databricks.com".to_string(), + "test-token".to_string(), + "databricks-claude-sonnet-4".to_string(), + None, + None, + ) + .unwrap(); + + // Test message WITHOUT cache_control + let messages_without = vec![Message::new(MessageRole::User, "Hello".to_string())]; + let databricks_messages_without = provider.convert_messages(&messages_without).unwrap(); + let json_without = serde_json::to_string(&databricks_messages_without).unwrap(); + + println!("JSON without cache_control: {}", json_without); + assert!(!json_without.contains("cache_control"), + "JSON should not contain 
'cache_control' field when not configured"); + + // Test message WITH cache_control - should still NOT include it (Databricks doesn't support it) + let messages_with = vec![Message::with_cache_control( + MessageRole::User, + "Hello".to_string(), + crate::CacheControl::ephemeral(), + )]; + let databricks_messages_with = provider.convert_messages(&messages_with).unwrap(); + let json_with = serde_json::to_string(&databricks_messages_with).unwrap(); + + println!("JSON with cache_control: {}", json_with); + assert!(!json_with.contains("cache_control"), + "JSON should NOT contain 'cache_control' field - Databricks doesn't support it"); + } + + #[test] + fn test_databricks_does_not_support_cache_control() { + let claude_provider = DatabricksProvider::from_token( + "https://test.databricks.com".to_string(), + "test-token".to_string(), + "databricks-claude-sonnet-4".to_string(), + None, + None, + ) + .unwrap(); + + let llama_provider = DatabricksProvider::from_token( + "https://test.databricks.com".to_string(), + "test-token".to_string(), + "databricks-meta-llama-3-3-70b-instruct".to_string(), + None, + None, + ) + .unwrap(); + + assert!(!claude_provider.supports_cache_control(), "Databricks should not support cache_control even for Claude models"); + assert!(!llama_provider.supports_cache_control(), "Databricks should not support cache_control for Llama models"); + } } diff --git a/crates/g3-providers/src/lib.rs b/crates/g3-providers/src/lib.rs index 51ea55a..f725c2f 100644 --- a/crates/g3-providers/src/lib.rs +++ b/crates/g3-providers/src/lib.rs @@ -21,6 +21,11 @@ pub trait LLMProvider: Send + Sync { fn has_native_tool_calling(&self) -> bool { false } + + /// Check if the provider supports cache control + fn supports_cache_control(&self) -> bool { + false + } } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -32,10 +37,40 @@ pub struct CompletionRequest { pub tools: Option>, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CacheControl { + #[serde(rename 
= "type")] + pub cache_type: CacheType, + #[serde(skip_serializing_if = "Option::is_none")] + pub ttl: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum CacheType { + Ephemeral, +} + +impl CacheControl { + pub fn ephemeral() -> Self { + Self { cache_type: CacheType::Ephemeral, ttl: None } + } + + pub fn five_minute() -> Self { + Self { cache_type: CacheType::Ephemeral, ttl: Some("5m".to_string()) } + } + + pub fn one_hour() -> Self { + Self { cache_type: CacheType::Ephemeral, ttl: Some("1h".to_string()) } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Message { pub role: MessageRole, pub content: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_control: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -95,6 +130,45 @@ pub use databricks::DatabricksProvider; pub use embedded::EmbeddedProvider; pub use openai::OpenAIProvider; +impl Message { + /// Create a new message with optional cache control + pub fn new(role: MessageRole, content: String) -> Self { + Self { + role, + content, + cache_control: None, + } + } + + /// Create a new message with cache control + pub fn with_cache_control(role: MessageRole, content: String, cache_control: CacheControl) -> Self { + Self { + role, + content, + cache_control: Some(cache_control), + } + } + + /// Create a message with cache control, with provider validation + pub fn with_cache_control_validated( + role: MessageRole, + content: String, + cache_control: CacheControl, + provider: &dyn LLMProvider + ) -> Self { + if !provider.supports_cache_control() { + tracing::warn!( + "Cache control requested for provider '{}' which does not support it. 
\ + Cache control is only supported by the native Anthropic provider; the Databricks provider strips it.", + provider.name() + ); + return Self::new(role, content); + } + + Self::with_cache_control(role, content, cache_control) + } +} + /// Provider registry for managing multiple LLM providers pub struct ProviderRegistry { providers: HashMap>, @@ -144,3 +218,68 @@ impl Default for ProviderRegistry { Self::new() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_message_serialization_without_cache_control() { + let msg = Message::new(MessageRole::User, "Hello".to_string()); + let json = serde_json::to_string(&msg).unwrap(); + + println!("Message JSON without cache_control: {}", json); + assert!(!json.contains("cache_control"), + "JSON should not contain 'cache_control' field when not configured"); + } + + #[test] + fn test_message_serialization_with_cache_control() { + let msg = Message::with_cache_control( + MessageRole::User, + "Hello".to_string(), + CacheControl::ephemeral(), + ); + let json = serde_json::to_string(&msg).unwrap(); + + println!("Message JSON with cache_control: {}", json); + assert!(json.contains("cache_control"), + "JSON should contain 'cache_control' field when configured"); + assert!(json.contains("ephemeral"), + "JSON should contain 'ephemeral' value"); + assert!(json.contains("\"type\":"), + "JSON should contain 'type' field in cache_control"); + assert!(!json.contains("null"), + "JSON should not contain null values"); + } + + #[test] + fn test_cache_control_five_minute_serialization() { + let msg = Message::with_cache_control( + MessageRole::User, + "Hello".to_string(), + CacheControl::five_minute(), + ); + let json = serde_json::to_string(&msg).unwrap(); + + println!("Message JSON with 5-minute cache_control: {}", json); + assert!(json.contains("cache_control"), "JSON should contain 'cache_control' field"); + assert!(json.contains("ephemeral"), "JSON should contain 'ephemeral' type"); + assert!(json.contains("\"ttl\":\"5m\""), "JSON should 
contain ttl field with 5m value"); + } + + #[test] + fn test_cache_control_one_hour_serialization() { + let msg = Message::with_cache_control( + MessageRole::User, + "Hello".to_string(), + CacheControl::one_hour(), + ); + let json = serde_json::to_string(&msg).unwrap(); + + println!("Message JSON with 1-hour cache_control: {}", json); + assert!(json.contains("cache_control"), "JSON should contain 'cache_control' field"); + assert!(json.contains("ephemeral"), "JSON should contain 'ephemeral' type"); + assert!(json.contains("\"ttl\":\"1h\""), "JSON should contain ttl field with 1h value"); + } +} diff --git a/crates/g3-providers/tests/cache_control_error_regression_test.rs b/crates/g3-providers/tests/cache_control_error_regression_test.rs new file mode 100644 index 0000000..533c943 --- /dev/null +++ b/crates/g3-providers/tests/cache_control_error_regression_test.rs @@ -0,0 +1,131 @@ +//! Regression test for cache_control serialization bug +//! +//! This test verifies that cache_control is NOT serialized in the wrong format. +//! The bug was that it serialized as: +//! - `system.0.cache_control.ephemeral.ttl` (WRONG) +//! +//! It should serialize as: +//! - `"cache_control": {"type": "ephemeral"}` for ephemeral +//! - `"cache_control": {"type": "ephemeral", "ttl": "5m"}` for 5minute +//! 
- `"cache_control": {"type": "ephemeral", "ttl": "1h"}` for 1hour + +use g3_providers::{CacheControl, Message, MessageRole}; + +#[test] +fn test_no_wrong_serialization_format() { + // Test ephemeral + let msg = Message::with_cache_control( + MessageRole::System, + "Test".to_string(), + CacheControl::ephemeral(), + ); + let json = serde_json::to_string(&msg).unwrap(); + + println!("Ephemeral message JSON: {}", json); + + // Should NOT contain the wrong format + assert!(!json.contains("system.0.cache_control"), + "JSON should not contain 'system.0.cache_control' path"); + assert!(!json.contains("cache_control.ephemeral"), + "JSON should not contain 'cache_control.ephemeral' path"); + + // Should contain the correct format + assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#), + "JSON should contain correct cache_control format"); +} + +#[test] +fn test_five_minute_no_wrong_format() { + let msg = Message::with_cache_control( + MessageRole::System, + "Test".to_string(), + CacheControl::five_minute(), + ); + let json = serde_json::to_string(&msg).unwrap(); + + println!("5-minute message JSON: {}", json); + + // Should NOT contain the wrong format + assert!(!json.contains("system.0.cache_control"), + "JSON should not contain 'system.0.cache_control' path"); + assert!(!json.contains("cache_control.ephemeral.ttl"), + "JSON should not contain 'cache_control.ephemeral.ttl' path"); + + // Should contain the correct format with ttl as a direct field + assert!(json.contains(r#""type":"ephemeral""#), + "JSON should contain type field"); + assert!(json.contains(r#""ttl":"5m""#), + "JSON should contain ttl field with value 5m"); +} + +#[test] +fn test_one_hour_no_wrong_format() { + let msg = Message::with_cache_control( + MessageRole::System, + "Test".to_string(), + CacheControl::one_hour(), + ); + let json = serde_json::to_string(&msg).unwrap(); + + println!("1-hour message JSON: {}", json); + + // Should NOT contain the wrong format + 
assert!(!json.contains("system.0.cache_control"), + "JSON should not contain 'system.0.cache_control' path"); + assert!(!json.contains("cache_control.ephemeral.ttl"), + "JSON should not contain 'cache_control.ephemeral.ttl' path"); + + // Should contain the correct format with ttl as a direct field + assert!(json.contains(r#""type":"ephemeral""#), + "JSON should contain type field"); + assert!(json.contains(r#""ttl":"1h""#), + "JSON should contain ttl field with value 1h"); +} + +#[test] +fn test_cache_control_structure_is_flat() { + // Verify that the cache_control object has a flat structure + // with 'type' and optional 'ttl' at the same level + + let cache_control = CacheControl::five_minute(); + let json_value = serde_json::to_value(&cache_control).unwrap(); + + println!("Cache control as JSON value: {}", serde_json::to_string_pretty(&json_value).unwrap()); + + let obj = json_value.as_object().expect("Should be an object"); + + // Should have exactly 2 keys at the top level + assert_eq!(obj.len(), 2, "Cache control should have exactly 2 top-level fields"); + + // Both 'type' and 'ttl' should be at the same level + assert!(obj.contains_key("type"), "Should have 'type' field"); + assert!(obj.contains_key("ttl"), "Should have 'ttl' field"); + + // 'type' should be a string, not an object + assert!(obj["type"].is_string(), "'type' should be a string value"); + + // 'ttl' should be a string, not nested + assert!(obj["ttl"].is_string(), "'ttl' should be a string value"); +} + +#[test] +fn test_ephemeral_cache_control_structure() { + let cache_control = CacheControl::ephemeral(); + let json_value = serde_json::to_value(&cache_control).unwrap(); + + println!("Ephemeral cache control as JSON value: {}", serde_json::to_string_pretty(&json_value).unwrap()); + + let obj = json_value.as_object().expect("Should be an object"); + + // Should have exactly 1 key (only 'type', no 'ttl') + assert_eq!(obj.len(), 1, "Ephemeral cache control should have exactly 1 top-level field"); 
+ + // Should have 'type' field + assert!(obj.contains_key("type"), "Should have 'type' field"); + + // Should NOT have 'ttl' field + assert!(!obj.contains_key("ttl"), "Ephemeral should not have 'ttl' field"); + + // 'type' should be a string with value "ephemeral" + assert_eq!(obj["type"].as_str().unwrap(), "ephemeral"); +} diff --git a/crates/g3-providers/tests/cache_control_integration_test.rs b/crates/g3-providers/tests/cache_control_integration_test.rs new file mode 100644 index 0000000..5ec365c --- /dev/null +++ b/crates/g3-providers/tests/cache_control_integration_test.rs @@ -0,0 +1,164 @@ +//! Integration tests for cache_control feature +//! +//! These tests verify that cache_control is correctly serialized in messages +//! for both Anthropic and Databricks providers. + +use g3_providers::{CacheControl, Message, MessageRole}; +use serde_json::json; + +#[test] +fn test_ephemeral_cache_control_serialization() { + let cache_control = CacheControl::ephemeral(); + let json = serde_json::to_value(&cache_control).unwrap(); + + println!("Ephemeral cache_control JSON: {}", serde_json::to_string(&json).unwrap()); + + assert_eq!(json, json!({ + "type": "ephemeral" + })); + + // Verify no ttl field is present + assert!(!json.as_object().unwrap().contains_key("ttl")); +} + +#[test] +fn test_five_minute_cache_control_serialization() { + let cache_control = CacheControl::five_minute(); + let json = serde_json::to_value(&cache_control).unwrap(); + + println!("5-minute cache_control JSON: {}", serde_json::to_string(&json).unwrap()); + + assert_eq!(json, json!({ + "type": "ephemeral", + "ttl": "5m" + })); +} + +#[test] +fn test_one_hour_cache_control_serialization() { + let cache_control = CacheControl::one_hour(); + let json = serde_json::to_value(&cache_control).unwrap(); + + println!("1-hour cache_control JSON: {}", serde_json::to_string(&json).unwrap()); + + assert_eq!(json, json!({ + "type": "ephemeral", + "ttl": "1h" + })); +} + +#[test] +fn 
test_message_with_ephemeral_cache_control() { + let msg = Message::with_cache_control( + MessageRole::System, + "System prompt".to_string(), + CacheControl::ephemeral(), + ); + + let json = serde_json::to_value(&msg).unwrap(); + println!("Message with ephemeral cache_control: {}", serde_json::to_string(&json).unwrap()); + + let cache_control = json.get("cache_control").expect("cache_control field should exist"); + assert_eq!(cache_control.get("type").unwrap(), "ephemeral"); + assert!(!cache_control.as_object().unwrap().contains_key("ttl")); +} + +#[test] +fn test_message_with_five_minute_cache_control() { + let msg = Message::with_cache_control( + MessageRole::System, + "System prompt".to_string(), + CacheControl::five_minute(), + ); + + let json = serde_json::to_value(&msg).unwrap(); + println!("Message with 5-minute cache_control: {}", serde_json::to_string(&json).unwrap()); + + let cache_control = json.get("cache_control").expect("cache_control field should exist"); + assert_eq!(cache_control.get("type").unwrap(), "ephemeral"); + assert_eq!(cache_control.get("ttl").unwrap(), "5m"); +} + +#[test] +fn test_message_with_one_hour_cache_control() { + let msg = Message::with_cache_control( + MessageRole::System, + "System prompt".to_string(), + CacheControl::one_hour(), + ); + + let json = serde_json::to_value(&msg).unwrap(); + println!("Message with 1-hour cache_control: {}", serde_json::to_string(&json).unwrap()); + + let cache_control = json.get("cache_control").expect("cache_control field should exist"); + assert_eq!(cache_control.get("type").unwrap(), "ephemeral"); + assert_eq!(cache_control.get("ttl").unwrap(), "1h"); +} + +#[test] +fn test_message_without_cache_control() { + let msg = Message::new(MessageRole::User, "Hello".to_string()); + + let json = serde_json::to_value(&msg).unwrap(); + println!("Message without cache_control: {}", serde_json::to_string(&json).unwrap()); + + // cache_control field should not be present when not set + 
assert!(!json.as_object().unwrap().contains_key("cache_control")); +} + +#[test] +fn test_cache_control_json_format_ephemeral() { + let cache_control = CacheControl::ephemeral(); + let json_str = serde_json::to_string(&cache_control).unwrap(); + + println!("Ephemeral JSON string: {}", json_str); + + // Verify exact JSON format + assert_eq!(json_str, r#"{"type":"ephemeral"}"#); +} + +#[test] +fn test_cache_control_json_format_five_minute() { + let cache_control = CacheControl::five_minute(); + let json_str = serde_json::to_string(&cache_control).unwrap(); + + println!("5-minute JSON string: {}", json_str); + + // Verify exact JSON format + assert_eq!(json_str, r#"{"type":"ephemeral","ttl":"5m"}"#); +} + +#[test] +fn test_cache_control_json_format_one_hour() { + let cache_control = CacheControl::one_hour(); + let json_str = serde_json::to_string(&cache_control).unwrap(); + + println!("1-hour JSON string: {}", json_str); + + // Verify exact JSON format + assert_eq!(json_str, r#"{"type":"ephemeral","ttl":"1h"}"#); +} + +#[test] +fn test_deserialization_ephemeral() { + let json_str = r#"{"type":"ephemeral"}"#; + let cache_control: CacheControl = serde_json::from_str(json_str).unwrap(); + + assert_eq!(cache_control.ttl, None); +} + +#[test] +fn test_deserialization_five_minute() { + let json_str = r#"{"type":"ephemeral","ttl":"5m"}"#; + let cache_control: CacheControl = serde_json::from_str(json_str).unwrap(); + + assert_eq!(cache_control.ttl, Some("5m".to_string())); +} + +#[test] +fn test_deserialization_one_hour() { + let json_str = r#"{"type":"ephemeral","ttl":"1h"}"#; + let cache_control: CacheControl = serde_json::from_str(json_str).unwrap(); + + assert_eq!(cache_control.ttl, Some("1h".to_string())); +}