adds cache_control
This commit is contained in:
@@ -11,12 +11,23 @@ model = "databricks-claude-sonnet-4"
|
|||||||
max_tokens = 4096
|
max_tokens = 4096
|
||||||
temperature = 0.1
|
temperature = 0.1
|
||||||
use_oauth = true
|
use_oauth = true
|
||||||
|
# cache_config = "ephemeral" # Optional: Enable prompt caching for Claude models
|
||||||
|
# Options: "ephemeral", "5minute", "1hour"
|
||||||
|
# Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs.
|
||||||
|
# The cache control will be automatically applied to:
|
||||||
|
# - The system prompt at the start of each session
|
||||||
|
# - Assistant responses after every 10 tool calls
|
||||||
|
# - 5minute costs $3/mtok, more details below
|
||||||
|
# https://docs.claude.com/en/docs/build-with-claude/prompt-caching#pricing
|
||||||
|
|
||||||
[providers.anthropic]
|
[providers.anthropic]
|
||||||
api_key = "your-anthropic-api-key"
|
api_key = "your-anthropic-api-key"
|
||||||
model = "claude-3-haiku-20240307" # Using a faster model for player
|
model = "claude-3-haiku-20240307" # Using a faster model for player
|
||||||
max_tokens = 4096
|
max_tokens = 4096
|
||||||
temperature = 0.3 # Slightly higher temperature for more creative implementations
|
temperature = 0.3 # Slightly higher temperature for more creative implementations
|
||||||
|
# cache_config = "ephemeral" # Optional: Enable prompt caching
|
||||||
|
# Options: "ephemeral", "5minute", "1hour"
|
||||||
|
# Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs.
|
||||||
|
|
||||||
[agent]
|
[agent]
|
||||||
fallback_default_max_tokens = 8192
|
fallback_default_max_tokens = 8192
|
||||||
|
|||||||
@@ -14,6 +14,15 @@ max_tokens = 4096 # Per-request output limit (how many tokens the model can gen
|
|||||||
# Note: This is different from max_context_length (total conversation history size)
|
# Note: This is different from max_context_length (total conversation history size)
|
||||||
temperature = 0.1
|
temperature = 0.1
|
||||||
use_oauth = true
|
use_oauth = true
|
||||||
|
# cache_config = "ephemeral" # Optional: Enable prompt caching for Claude models on Databricks
|
||||||
|
# Options: "ephemeral", "5minute", "1hour"
|
||||||
|
# Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs.
|
||||||
|
# The cache control will be automatically applied to:
|
||||||
|
# - The system prompt at the start of each session
|
||||||
|
# - Assistant responses after every 10 tool calls
|
||||||
|
# - 5minute costs $3/mtok, more details below
|
||||||
|
# https://docs.claude.com/en/docs/build-with-claude/prompt-caching#pricing
|
||||||
|
|
||||||
|
|
||||||
# Multiple OpenAI-compatible providers can be configured with custom names
|
# Multiple OpenAI-compatible providers can be configured with custom names
|
||||||
# Each provider gets its own section under [providers.openai_compatible.<name>]
|
# Each provider gets its own section under [providers.openai_compatible.<name>]
|
||||||
|
|||||||
@@ -40,6 +40,8 @@ pub struct AnthropicConfig {
|
|||||||
pub model: String,
|
pub model: String,
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
pub temperature: Option<f32>,
|
pub temperature: Option<f32>,
|
||||||
|
pub cache_config: Option<String>, // "ephemeral", "5minute", "1hour", or None to disable
|
||||||
|
pub enable_1m_context: Option<bool>, // Enable 1m context window (costs extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -49,6 +51,7 @@ pub struct DatabricksConfig {
|
|||||||
pub model: String,
|
pub model: String,
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
pub temperature: Option<f32>,
|
pub temperature: Option<f32>,
|
||||||
|
pub cache_config: Option<String>, // "ephemeral", "5minute", "1hour", or None to disable
|
||||||
pub use_oauth: Option<bool>, // Default to true if token not provided
|
pub use_oauth: Option<bool>, // Default to true if token not provided
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -132,6 +135,7 @@ impl Default for Config {
|
|||||||
model: "databricks-claude-sonnet-4".to_string(),
|
model: "databricks-claude-sonnet-4".to_string(),
|
||||||
max_tokens: Some(4096),
|
max_tokens: Some(4096),
|
||||||
temperature: Some(0.1),
|
temperature: Some(0.1),
|
||||||
|
cache_config: None,
|
||||||
use_oauth: Some(true),
|
use_oauth: Some(true),
|
||||||
}),
|
}),
|
||||||
embedded: None,
|
embedded: None,
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ use anyhow::Result;
|
|||||||
use g3_computer_control::WebDriverController;
|
use g3_computer_control::WebDriverController;
|
||||||
use g3_config::Config;
|
use g3_config::Config;
|
||||||
use g3_execution::CodeExecutor;
|
use g3_execution::CodeExecutor;
|
||||||
use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry, Tool};
|
use g3_providers::{CacheControl, CompletionRequest, Message, MessageRole, ProviderRegistry, Tool};
|
||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
@@ -423,18 +423,12 @@ Format this as a detailed but concise summary that can be used to resume the con
|
|||||||
self.used_tokens = 0;
|
self.used_tokens = 0;
|
||||||
|
|
||||||
// Add the summary as a system message
|
// Add the summary as a system message
|
||||||
let summary_message = Message {
|
let summary_message = Message::new(MessageRole::System, format!("Previous conversation summary:\n\n{}", summary));
|
||||||
role: MessageRole::System,
|
|
||||||
content: format!("Previous conversation summary:\n\n{}", summary),
|
|
||||||
};
|
|
||||||
self.add_message(summary_message);
|
self.add_message(summary_message);
|
||||||
|
|
||||||
// Add the latest user message if provided
|
// Add the latest user message if provided
|
||||||
if let Some(user_msg) = latest_user_message {
|
if let Some(user_msg) = latest_user_message {
|
||||||
self.add_message(Message {
|
self.add_message(Message::new(MessageRole::User, user_msg));
|
||||||
role: MessageRole::User,
|
|
||||||
content: user_msg,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let new_chars: usize = self
|
let new_chars: usize = self
|
||||||
@@ -756,6 +750,7 @@ pub struct Agent<W: UiWriter> {
|
|||||||
safaridriver_process: std::sync::Arc<tokio::sync::RwLock<Option<tokio::process::Child>>>,
|
safaridriver_process: std::sync::Arc<tokio::sync::RwLock<Option<tokio::process::Child>>>,
|
||||||
macax_controller:
|
macax_controller:
|
||||||
std::sync::Arc<tokio::sync::RwLock<Option<g3_computer_control::MacAxController>>>,
|
std::sync::Arc<tokio::sync::RwLock<Option<g3_computer_control::MacAxController>>>,
|
||||||
|
tool_call_count: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<W: UiWriter> Agent<W> {
|
impl<W: UiWriter> Agent<W> {
|
||||||
@@ -898,6 +893,8 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
Some(anthropic_config.model.clone()),
|
Some(anthropic_config.model.clone()),
|
||||||
anthropic_config.max_tokens,
|
anthropic_config.max_tokens,
|
||||||
anthropic_config.temperature,
|
anthropic_config.temperature,
|
||||||
|
anthropic_config.cache_config.clone(),
|
||||||
|
anthropic_config.enable_1m_context,
|
||||||
)?;
|
)?;
|
||||||
providers.register(anthropic_provider);
|
providers.register(anthropic_provider);
|
||||||
}
|
}
|
||||||
@@ -944,10 +941,7 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
|
|
||||||
// If README content is provided, add it as the first system message
|
// If README content is provided, add it as the first system message
|
||||||
if let Some(readme) = readme_content {
|
if let Some(readme) = readme_content {
|
||||||
let readme_message = Message {
|
let readme_message = Message::new(MessageRole::System, readme);
|
||||||
role: MessageRole::System,
|
|
||||||
content: readme,
|
|
||||||
};
|
|
||||||
context_window.add_message(readme_message);
|
context_window.add_message(readme_message);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1003,9 +997,23 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
None
|
None
|
||||||
}))
|
}))
|
||||||
},
|
},
|
||||||
|
tool_call_count: 0,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Convert cache config string to CacheControl enum
|
||||||
|
fn parse_cache_control(cache_config: &str) -> Option<CacheControl> {
|
||||||
|
match cache_config {
|
||||||
|
"ephemeral" => Some(CacheControl::Ephemeral),
|
||||||
|
"5minute" => Some(CacheControl::FiveMinute),
|
||||||
|
"1hour" => Some(CacheControl::OneHour),
|
||||||
|
_ => {
|
||||||
|
warn!("Invalid cache_config value: '{}'. Valid values are: ephemeral, 5minute, 1hour", cache_config);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn get_configured_context_length(config: &Config, providers: &ProviderRegistry) -> Result<u32> {
|
fn get_configured_context_length(config: &Config, providers: &ProviderRegistry) -> Result<u32> {
|
||||||
// First, check if there's a global max_context_length override in agent config
|
// First, check if there's a global max_context_length override in agent config
|
||||||
if let Some(max_context_length) = config.agent.max_context_length {
|
if let Some(max_context_length) = config.agent.max_context_length {
|
||||||
@@ -1185,7 +1193,11 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
// Only add system message if this is the first interaction (empty conversation history)
|
// Only add system message if this is the first interaction (empty conversation history)
|
||||||
if self.context_window.conversation_history.is_empty() {
|
if self.context_window.conversation_history.is_empty() {
|
||||||
let provider = self.providers.get(None)?;
|
let provider = self.providers.get(None)?;
|
||||||
let system_prompt = if provider.has_native_tool_calling() {
|
let provider_has_native_tool_calling = provider.has_native_tool_calling();
|
||||||
|
let provider_name_for_system = provider.name().to_string();
|
||||||
|
drop(provider); // Drop provider reference to avoid borrowing issues
|
||||||
|
|
||||||
|
let system_prompt = if provider_has_native_tool_calling {
|
||||||
// For native tool calling providers, use a more explicit system prompt
|
// For native tool calling providers, use a more explicit system prompt
|
||||||
"You are G3, an AI programming agent of the same skill level as a seasoned engineer at a major technology company. You analyze given tasks and write code to achieve goals.
|
"You are G3, an AI programming agent of the same skill level as a seasoned engineer at a major technology company. You analyze given tasks and write code to achieve goals.
|
||||||
|
|
||||||
@@ -1493,18 +1505,34 @@ If you can complete it with 1-2 tool calls, skip TODO.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add system message to context window
|
// Add system message to context window
|
||||||
let system_message = Message {
|
let system_message = {
|
||||||
role: MessageRole::System,
|
// Check if we should use cache control for system message
|
||||||
content: system_prompt,
|
if let Some(cache_config) = match provider_name_for_system.as_str() {
|
||||||
|
"anthropic" => self.config.providers.anthropic.as_ref()
|
||||||
|
.and_then(|c| c.cache_config.as_ref())
|
||||||
|
.and_then(|config| Self::parse_cache_control(config)),
|
||||||
|
"databricks" => self.config.providers.databricks.as_ref()
|
||||||
|
.and_then(|c| {
|
||||||
|
if c.model.contains("claude") {
|
||||||
|
c.cache_config.as_ref()
|
||||||
|
.and_then(|config| Self::parse_cache_control(config))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
_ => None,
|
||||||
|
} {
|
||||||
|
let provider = self.providers.get(None)?;
|
||||||
|
Message::with_cache_control_validated(MessageRole::System, system_prompt, cache_config, provider)
|
||||||
|
} else {
|
||||||
|
Message::new(MessageRole::System, system_prompt)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
self.context_window.add_message(system_message);
|
self.context_window.add_message(system_message);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add user message to context window
|
// Add user message to context window
|
||||||
let user_message = Message {
|
let user_message = Message::new(MessageRole::User, format!("Task: {}", description));
|
||||||
role: MessageRole::User,
|
|
||||||
content: format!("Task: {}", description),
|
|
||||||
};
|
|
||||||
self.context_window.add_message(user_message);
|
self.context_window.add_message(user_message);
|
||||||
|
|
||||||
// Use the complete conversation history for the request
|
// Use the complete conversation history for the request
|
||||||
@@ -1512,6 +1540,9 @@ If you can complete it with 1-2 tool calls, skip TODO.
|
|||||||
|
|
||||||
// Check if provider supports native tool calling and add tools if so
|
// Check if provider supports native tool calling and add tools if so
|
||||||
let provider = self.providers.get(None)?;
|
let provider = self.providers.get(None)?;
|
||||||
|
let provider_name = provider.name().to_string();
|
||||||
|
let has_native_tool_calling = provider.has_native_tool_calling();
|
||||||
|
let supports_cache_control = provider.supports_cache_control();
|
||||||
let tools = if provider.has_native_tool_calling() {
|
let tools = if provider.has_native_tool_calling() {
|
||||||
Some(Self::create_tool_definitions(
|
Some(Self::create_tool_definitions(
|
||||||
self.config.webdriver.enabled,
|
self.config.webdriver.enabled,
|
||||||
@@ -1521,9 +1552,10 @@ If you can complete it with 1-2 tool calls, skip TODO.
|
|||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
drop(provider); // Drop the provider reference to avoid borrowing issues
|
||||||
|
|
||||||
// Get max_tokens from provider configuration
|
// Get max_tokens from provider configuration
|
||||||
let max_tokens = match provider.name() {
|
let max_tokens = match provider_name.as_str() {
|
||||||
"databricks" => {
|
"databricks" => {
|
||||||
// Use the model's maximum limit for Databricks to allow large file generation
|
// Use the model's maximum limit for Databricks to allow large file generation
|
||||||
Some(32000)
|
Some(32000)
|
||||||
@@ -1578,9 +1610,32 @@ If you can complete it with 1-2 tool calls, skip TODO.
|
|||||||
// Add assistant response to context window only if not empty
|
// Add assistant response to context window only if not empty
|
||||||
// This prevents the "Skipping empty message" warning when only tools were executed
|
// This prevents the "Skipping empty message" warning when only tools were executed
|
||||||
if !response_content.trim().is_empty() {
|
if !response_content.trim().is_empty() {
|
||||||
let assistant_message = Message {
|
let assistant_message = {
|
||||||
role: MessageRole::Assistant,
|
// Check if we should use cache control (every 10 tool calls)
|
||||||
content: response_content.clone(),
|
if self.tool_call_count > 0 && self.tool_call_count % 10 == 0 {
|
||||||
|
let provider = self.providers.get(None)?;
|
||||||
|
if let Some(cache_config) = match provider.name() {
|
||||||
|
"anthropic" => self.config.providers.anthropic.as_ref()
|
||||||
|
.and_then(|c| c.cache_config.as_ref())
|
||||||
|
.and_then(|config| Self::parse_cache_control(config)),
|
||||||
|
"databricks" => self.config.providers.databricks.as_ref()
|
||||||
|
.and_then(|c| {
|
||||||
|
if c.model.contains("claude") {
|
||||||
|
c.cache_config.as_ref()
|
||||||
|
.and_then(|config| Self::parse_cache_control(config))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
_ => None,
|
||||||
|
} {
|
||||||
|
Message::with_cache_control_validated(MessageRole::Assistant, response_content.clone(), cache_config, provider)
|
||||||
|
} else {
|
||||||
|
Message::new(MessageRole::Assistant, response_content.clone())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Message::new(MessageRole::Assistant, response_content.clone())
|
||||||
|
}
|
||||||
};
|
};
|
||||||
self.context_window.add_message(assistant_message);
|
self.context_window.add_message(assistant_message);
|
||||||
} else {
|
} else {
|
||||||
@@ -1783,17 +1838,11 @@ If you can complete it with 1-2 tool calls, skip TODO.
|
|||||||
.join("\n\n");
|
.join("\n\n");
|
||||||
|
|
||||||
let summary_messages = vec![
|
let summary_messages = vec![
|
||||||
Message {
|
Message::new(MessageRole::System, "You are a helpful assistant that creates concise summaries.".to_string()),
|
||||||
role: MessageRole::System,
|
Message::new(MessageRole::User, format!(
|
||||||
content: "You are a helpful assistant that creates concise summaries.".to_string(),
|
|
||||||
},
|
|
||||||
Message {
|
|
||||||
role: MessageRole::User,
|
|
||||||
content: format!(
|
|
||||||
"Based on this conversation history, {}\n\nConversation:\n{}",
|
"Based on this conversation history, {}\n\nConversation:\n{}",
|
||||||
summary_prompt, conversation_text
|
summary_prompt, conversation_text
|
||||||
),
|
)),
|
||||||
},
|
|
||||||
];
|
];
|
||||||
|
|
||||||
let provider = self.providers.get(None)?;
|
let provider = self.providers.get(None)?;
|
||||||
@@ -2776,18 +2825,11 @@ If you can complete it with 1-2 tool calls, skip TODO.
|
|||||||
.join("\n\n");
|
.join("\n\n");
|
||||||
|
|
||||||
let summary_messages = vec![
|
let summary_messages = vec![
|
||||||
Message {
|
Message::new(MessageRole::System, "You are a helpful assistant that creates concise summaries.".to_string()),
|
||||||
role: MessageRole::System,
|
Message::new(MessageRole::User, format!(
|
||||||
content: "You are a helpful assistant that creates concise summaries."
|
|
||||||
.to_string(),
|
|
||||||
},
|
|
||||||
Message {
|
|
||||||
role: MessageRole::User,
|
|
||||||
content: format!(
|
|
||||||
"Based on this conversation history, {}\n\nConversation:\n{}",
|
"Based on this conversation history, {}\n\nConversation:\n{}",
|
||||||
summary_prompt, conversation_text
|
summary_prompt, conversation_text
|
||||||
),
|
)),
|
||||||
},
|
|
||||||
];
|
];
|
||||||
|
|
||||||
let provider = self.providers.get(None)?;
|
let provider = self.providers.get(None)?;
|
||||||
@@ -3273,29 +3315,20 @@ If you can complete it with 1-2 tool calls, skip TODO.
|
|||||||
// Add the tool call and result to the context window using RAW unfiltered content
|
// Add the tool call and result to the context window using RAW unfiltered content
|
||||||
// This ensures the log file contains the true raw content including JSON tool calls
|
// This ensures the log file contains the true raw content including JSON tool calls
|
||||||
let tool_message = if !raw_content_for_log.trim().is_empty() {
|
let tool_message = if !raw_content_for_log.trim().is_empty() {
|
||||||
Message {
|
Message::new(MessageRole::Assistant, format!(
|
||||||
role: MessageRole::Assistant,
|
|
||||||
content: format!(
|
|
||||||
"{}\n\n{{\"tool\": \"{}\", \"args\": {}}}",
|
"{}\n\n{{\"tool\": \"{}\", \"args\": {}}}",
|
||||||
raw_content_for_log.trim(),
|
raw_content_for_log.trim(),
|
||||||
tool_call.tool,
|
tool_call.tool,
|
||||||
tool_call.args
|
tool_call.args
|
||||||
),
|
))
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// No text content before tool call, just include the tool call
|
// No text content before tool call, just include the tool call
|
||||||
Message {
|
Message::new(MessageRole::Assistant, format!(
|
||||||
role: MessageRole::Assistant,
|
|
||||||
content: format!(
|
|
||||||
"{{\"tool\": \"{}\", \"args\": {}}}",
|
"{{\"tool\": \"{}\", \"args\": {}}}",
|
||||||
tool_call.tool, tool_call.args
|
tool_call.tool, tool_call.args
|
||||||
),
|
))
|
||||||
}
|
|
||||||
};
|
|
||||||
let result_message = Message {
|
|
||||||
role: MessageRole::User,
|
|
||||||
content: format!("Tool result: {}", tool_result),
|
|
||||||
};
|
};
|
||||||
|
let result_message = Message::new(MessageRole::User, format!("Tool result: {}", tool_result));
|
||||||
|
|
||||||
self.context_window.add_message(tool_message);
|
self.context_window.add_message(tool_message);
|
||||||
self.context_window.add_message(result_message);
|
self.context_window.add_message(result_message);
|
||||||
@@ -3304,7 +3337,8 @@ If you can complete it with 1-2 tool calls, skip TODO.
|
|||||||
request.messages = self.context_window.conversation_history.clone();
|
request.messages = self.context_window.conversation_history.clone();
|
||||||
|
|
||||||
// Ensure tools are included for native providers in subsequent iterations
|
// Ensure tools are included for native providers in subsequent iterations
|
||||||
if provider.has_native_tool_calling() {
|
let provider_for_tools = self.providers.get(None)?;
|
||||||
|
if provider_for_tools.has_native_tool_calling() {
|
||||||
request.tools = Some(Self::create_tool_definitions(
|
request.tools = Some(Self::create_tool_definitions(
|
||||||
self.config.webdriver.enabled,
|
self.config.webdriver.enabled,
|
||||||
self.config.macax.enabled,
|
self.config.macax.enabled,
|
||||||
@@ -3635,9 +3669,32 @@ If you can complete it with 1-2 tool calls, skip TODO.
|
|||||||
.replace("<</SYS>>", "");
|
.replace("<</SYS>>", "");
|
||||||
|
|
||||||
if !raw_clean.trim().is_empty() {
|
if !raw_clean.trim().is_empty() {
|
||||||
let assistant_message = Message {
|
let assistant_message = {
|
||||||
role: MessageRole::Assistant,
|
// Check if we should use cache control (every 10 tool calls)
|
||||||
content: raw_clean,
|
if self.tool_call_count > 0 && self.tool_call_count % 10 == 0 {
|
||||||
|
let provider = self.providers.get(None)?;
|
||||||
|
if let Some(cache_config) = match provider.name() {
|
||||||
|
"anthropic" => self.config.providers.anthropic.as_ref()
|
||||||
|
.and_then(|c| c.cache_config.as_ref())
|
||||||
|
.and_then(|config| Self::parse_cache_control(config)),
|
||||||
|
"databricks" => self.config.providers.databricks.as_ref()
|
||||||
|
.and_then(|c| {
|
||||||
|
if c.model.contains("claude") {
|
||||||
|
c.cache_config.as_ref()
|
||||||
|
.and_then(|config| Self::parse_cache_control(config))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
_ => None,
|
||||||
|
} {
|
||||||
|
Message::with_cache_control_validated(MessageRole::Assistant, raw_clean, cache_config, provider)
|
||||||
|
} else {
|
||||||
|
Message::new(MessageRole::Assistant, raw_clean)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Message::new(MessageRole::Assistant, raw_clean)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
self.context_window.add_message(assistant_message);
|
self.context_window.add_message(assistant_message);
|
||||||
}
|
}
|
||||||
@@ -3679,7 +3736,10 @@ If you can complete it with 1-2 tool calls, skip TODO.
|
|||||||
Ok(TaskResult::new(final_response, self.context_window.clone()))
|
Ok(TaskResult::new(final_response, self.context_window.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn execute_tool(&self, tool_call: &ToolCall) -> Result<String> {
|
pub async fn execute_tool(&mut self, tool_call: &ToolCall) -> Result<String> {
|
||||||
|
// Increment tool call count
|
||||||
|
self.tool_call_count += 1;
|
||||||
|
|
||||||
debug!("=== EXECUTING TOOL ===");
|
debug!("=== EXECUTING TOOL ===");
|
||||||
debug!("Tool name: {}", tool_call.tool);
|
debug!("Tool name: {}", tool_call.tool);
|
||||||
debug!("Tool args (raw): {:?}", tool_call.args);
|
debug!("Tool args (raw): {:?}", tool_call.args);
|
||||||
|
|||||||
@@ -21,22 +21,18 @@
|
|||||||
//! // Create the provider with your API key
|
//! // Create the provider with your API key
|
||||||
//! let provider = AnthropicProvider::new(
|
//! let provider = AnthropicProvider::new(
|
||||||
//! "your-api-key".to_string(),
|
//! "your-api-key".to_string(),
|
||||||
//! Some("claude-3-5-sonnet-20241022".to_string()), // Optional: defaults to claude-3-5-sonnet-20241022
|
//! Some("claude-3-5-sonnet-20241022".to_string()),
|
||||||
//! Some(4096), // Optional: max tokens
|
//! Some(4096),
|
||||||
//! Some(0.1), // Optional: temperature
|
//! Some(0.1),
|
||||||
|
//! None, // cache_config
|
||||||
|
//! None, // enable_1m_context
|
||||||
//! )?;
|
//! )?;
|
||||||
//!
|
//!
|
||||||
//! // Create a completion request
|
//! // Create a completion request
|
||||||
//! let request = CompletionRequest {
|
//! let request = CompletionRequest {
|
||||||
//! messages: vec![
|
//! messages: vec![
|
||||||
//! Message {
|
//! Message::new(MessageRole::System, "You are a helpful assistant.".to_string()),
|
||||||
//! role: MessageRole::System,
|
//! Message::new(MessageRole::User, "Hello! How are you?".to_string()),
|
||||||
//! content: "You are a helpful assistant.".to_string(),
|
|
||||||
//! },
|
|
||||||
//! Message {
|
|
||||||
//! role: MessageRole::User,
|
|
||||||
//! content: "Hello! How are you?".to_string(),
|
|
||||||
//! },
|
|
||||||
//! ],
|
//! ],
|
||||||
//! max_tokens: Some(1000),
|
//! max_tokens: Some(1000),
|
||||||
//! temperature: Some(0.7),
|
//! temperature: Some(0.7),
|
||||||
@@ -62,15 +58,16 @@
|
|||||||
//! async fn main() -> anyhow::Result<()> {
|
//! async fn main() -> anyhow::Result<()> {
|
||||||
//! let provider = AnthropicProvider::new(
|
//! let provider = AnthropicProvider::new(
|
||||||
//! "your-api-key".to_string(),
|
//! "your-api-key".to_string(),
|
||||||
//! None, None, None,
|
//! None,
|
||||||
|
//! None,
|
||||||
|
//! None,
|
||||||
|
//! None, // cache_config
|
||||||
|
//! None, // enable_1m_context
|
||||||
//! )?;
|
//! )?;
|
||||||
//!
|
//!
|
||||||
//! let request = CompletionRequest {
|
//! let request = CompletionRequest {
|
||||||
//! messages: vec![
|
//! messages: vec![
|
||||||
//! Message {
|
//! Message::new(MessageRole::User, "Write a short story about a robot.".to_string()),
|
||||||
//! role: MessageRole::User,
|
|
||||||
//! content: "Write a short story about a robot.".to_string(),
|
|
||||||
//! },
|
|
||||||
//! ],
|
//! ],
|
||||||
//! max_tokens: Some(1000),
|
//! max_tokens: Some(1000),
|
||||||
//! temperature: Some(0.7),
|
//! temperature: Some(0.7),
|
||||||
@@ -123,6 +120,8 @@ pub struct AnthropicProvider {
|
|||||||
model: String,
|
model: String,
|
||||||
max_tokens: u32,
|
max_tokens: u32,
|
||||||
temperature: f32,
|
temperature: f32,
|
||||||
|
cache_config: Option<String>,
|
||||||
|
enable_1m_context: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AnthropicProvider {
|
impl AnthropicProvider {
|
||||||
@@ -131,6 +130,8 @@ impl AnthropicProvider {
|
|||||||
model: Option<String>,
|
model: Option<String>,
|
||||||
max_tokens: Option<u32>,
|
max_tokens: Option<u32>,
|
||||||
temperature: Option<f32>,
|
temperature: Option<f32>,
|
||||||
|
cache_config: Option<String>,
|
||||||
|
enable_1m_context: Option<bool>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let client = Client::builder()
|
let client = Client::builder()
|
||||||
.timeout(Duration::from_secs(300))
|
.timeout(Duration::from_secs(300))
|
||||||
@@ -147,6 +148,8 @@ impl AnthropicProvider {
|
|||||||
model,
|
model,
|
||||||
max_tokens: max_tokens.unwrap_or(4096),
|
max_tokens: max_tokens.unwrap_or(4096),
|
||||||
temperature: temperature.unwrap_or(0.1),
|
temperature: temperature.unwrap_or(0.1),
|
||||||
|
cache_config,
|
||||||
|
enable_1m_context: enable_1m_context.unwrap_or(false),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,9 +159,12 @@ impl AnthropicProvider {
|
|||||||
.post(ANTHROPIC_API_URL)
|
.post(ANTHROPIC_API_URL)
|
||||||
.header("x-api-key", &self.api_key)
|
.header("x-api-key", &self.api_key)
|
||||||
.header("anthropic-version", ANTHROPIC_VERSION)
|
.header("anthropic-version", ANTHROPIC_VERSION)
|
||||||
// Anthropic beta 1m context window. Enable if needed. It costs extra, so check first.
|
|
||||||
// .header("anthropic-beta", "context-1m-2025-08-07")
|
|
||||||
.header("content-type", "application/json");
|
.header("content-type", "application/json");
|
||||||
|
|
||||||
|
if self.enable_1m_context {
|
||||||
|
builder = builder.header("anthropic-beta", "context-1m-2025-08-07");
|
||||||
|
}
|
||||||
|
|
||||||
if streaming {
|
if streaming {
|
||||||
builder = builder.header("accept", "text/event-stream");
|
builder = builder.header("accept", "text/event-stream");
|
||||||
}
|
}
|
||||||
@@ -166,6 +172,17 @@ impl AnthropicProvider {
|
|||||||
builder
|
builder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn convert_cache_control(cache_control: &crate::CacheControl) -> AnthropicCacheControl {
|
||||||
|
let cache_type = match cache_control {
|
||||||
|
crate::CacheControl::Ephemeral => "ephemeral",
|
||||||
|
crate::CacheControl::FiveMinute => "5minute",
|
||||||
|
crate::CacheControl::OneHour => "1hour",
|
||||||
|
};
|
||||||
|
AnthropicCacheControl {
|
||||||
|
cache_type: cache_type.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn convert_tools(&self, tools: &[Tool]) -> Vec<AnthropicTool> {
|
fn convert_tools(&self, tools: &[Tool]) -> Vec<AnthropicTool> {
|
||||||
tools
|
tools
|
||||||
.iter()
|
.iter()
|
||||||
@@ -214,6 +231,8 @@ impl AnthropicProvider {
|
|||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: vec![AnthropicContent::Text {
|
content: vec![AnthropicContent::Text {
|
||||||
text: message.content.clone(),
|
text: message.content.clone(),
|
||||||
|
cache_control: message.cache_control.as_ref()
|
||||||
|
.map(Self::convert_cache_control),
|
||||||
}],
|
}],
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -222,6 +241,8 @@ impl AnthropicProvider {
|
|||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: vec![AnthropicContent::Text {
|
content: vec![AnthropicContent::Text {
|
||||||
text: message.content.clone(),
|
text: message.content.clone(),
|
||||||
|
cache_control: message.cache_control.as_ref()
|
||||||
|
.map(Self::convert_cache_control),
|
||||||
}],
|
}],
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -564,7 +585,7 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
.content
|
.content
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|c| match c {
|
.filter_map(|c| match c {
|
||||||
AnthropicContent::Text { text } => Some(text.as_str()),
|
AnthropicContent::Text { text, .. } => Some(text.as_str()),
|
||||||
_ => None,
|
_ => None,
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
@@ -658,6 +679,11 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
// Claude models support native tool calling
|
// Claude models support native tool calling
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn supports_cache_control(&self) -> bool {
|
||||||
|
// Anthropic supports cache control
|
||||||
|
true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Anthropic API request/response structures
|
// Anthropic API request/response structures
|
||||||
@@ -697,11 +723,21 @@ struct AnthropicMessage {
|
|||||||
content: Vec<AnthropicContent>,
|
content: Vec<AnthropicContent>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct AnthropicCacheControl {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
cache_type: String,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
enum AnthropicContent {
|
enum AnthropicContent {
|
||||||
#[serde(rename = "text")]
|
#[serde(rename = "text")]
|
||||||
Text { text: String },
|
Text {
|
||||||
|
text: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
cache_control: Option<AnthropicCacheControl>,
|
||||||
|
},
|
||||||
#[serde(rename = "tool_use")]
|
#[serde(rename = "tool_use")]
|
||||||
ToolUse {
|
ToolUse {
|
||||||
id: String,
|
id: String,
|
||||||
@@ -771,21 +807,14 @@ mod tests {
|
|||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
).unwrap();
|
).unwrap();
|
||||||
|
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
Message {
|
Message::new(MessageRole::System, "You are a helpful assistant.".to_string()),
|
||||||
role: MessageRole::System,
|
Message::new(MessageRole::User, "Hello!".to_string()),
|
||||||
content: "You are a helpful assistant.".to_string(),
|
Message::new(MessageRole::Assistant, "Hi there!".to_string()),
|
||||||
},
|
|
||||||
Message {
|
|
||||||
role: MessageRole::User,
|
|
||||||
content: "Hello!".to_string(),
|
|
||||||
},
|
|
||||||
Message {
|
|
||||||
role: MessageRole::Assistant,
|
|
||||||
content: "Hi there!".to_string(),
|
|
||||||
},
|
|
||||||
];
|
];
|
||||||
|
|
||||||
let (system, anthropic_messages) = provider.convert_messages(&messages).unwrap();
|
let (system, anthropic_messages) = provider.convert_messages(&messages).unwrap();
|
||||||
@@ -803,14 +832,11 @@ mod tests {
|
|||||||
Some("claude-3-haiku-20240307".to_string()),
|
Some("claude-3-haiku-20240307".to_string()),
|
||||||
Some(1000),
|
Some(1000),
|
||||||
Some(0.5),
|
Some(0.5),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
).unwrap();
|
).unwrap();
|
||||||
|
|
||||||
let messages = vec![
|
let messages = vec![Message::new(MessageRole::User, "Test message".to_string())];
|
||||||
Message {
|
|
||||||
role: MessageRole::User,
|
|
||||||
content: "Test message".to_string(),
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
let request_body = provider
|
let request_body = provider
|
||||||
.create_request_body(&messages, None, false, 1000, 0.5)
|
.create_request_body(&messages, None, false, 1000, 0.5)
|
||||||
@@ -831,6 +857,8 @@ mod tests {
|
|||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
).unwrap();
|
).unwrap();
|
||||||
|
|
||||||
let tools = vec![
|
let tools = vec![
|
||||||
@@ -859,4 +887,48 @@ mod tests {
|
|||||||
assert!(anthropic_tools[0].input_schema.required.is_some());
|
assert!(anthropic_tools[0].input_schema.required.is_some());
|
||||||
assert_eq!(anthropic_tools[0].input_schema.required.as_ref().unwrap()[0], "location");
|
assert_eq!(anthropic_tools[0].input_schema.required.as_ref().unwrap()[0], "location");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cache_control_serialization() {
|
||||||
|
let provider = AnthropicProvider::new(
|
||||||
|
"test-key".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
).unwrap();
|
||||||
|
|
||||||
|
// Test message WITHOUT cache_control
|
||||||
|
let messages_without = vec![Message::new(MessageRole::User, "Hello".to_string())];
|
||||||
|
let (_, anthropic_messages_without) = provider.convert_messages(&messages_without).unwrap();
|
||||||
|
let json_without = serde_json::to_string(&anthropic_messages_without).unwrap();
|
||||||
|
|
||||||
|
println!("Anthropic JSON without cache_control: {}", json_without);
|
||||||
|
// Check if cache_control appears in the JSON
|
||||||
|
if json_without.contains("cache_control") {
|
||||||
|
println!("WARNING: JSON contains 'cache_control' field when not configured!");
|
||||||
|
assert!(!json_without.contains("\"cache_control\":null"),
|
||||||
|
"JSON should not contain 'cache_control: null'");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test message WITH cache_control
|
||||||
|
let messages_with = vec![Message::with_cache_control(
|
||||||
|
MessageRole::User,
|
||||||
|
"Hello".to_string(),
|
||||||
|
crate::CacheControl::Ephemeral,
|
||||||
|
)];
|
||||||
|
let (_, anthropic_messages_with) = provider.convert_messages(&messages_with).unwrap();
|
||||||
|
let json_with = serde_json::to_string(&anthropic_messages_with).unwrap();
|
||||||
|
|
||||||
|
println!("Anthropic JSON with cache_control: {}", json_with);
|
||||||
|
assert!(json_with.contains("cache_control"),
|
||||||
|
"JSON should contain 'cache_control' field when configured");
|
||||||
|
assert!(json_with.contains("ephemeral"),
|
||||||
|
"JSON should contain 'ephemeral' type");
|
||||||
|
|
||||||
|
// The key assertion: when cache_control is None, it should not appear in JSON
|
||||||
|
assert!(!json_without.contains("cache_control") || !json_without.contains("null"),
|
||||||
|
"JSON should not contain 'cache_control' field or null values when not configured");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,10 +39,7 @@
|
|||||||
//! // Create a completion request
|
//! // Create a completion request
|
||||||
//! let request = CompletionRequest {
|
//! let request = CompletionRequest {
|
||||||
//! messages: vec![
|
//! messages: vec![
|
||||||
//! Message {
|
//! Message::new(MessageRole::User, "Hello! How are you?".to_string()),
|
||||||
//! role: MessageRole::User,
|
|
||||||
//! content: "Hello! How are you?".to_string(),
|
|
||||||
//! },
|
|
||||||
//! ],
|
//! ],
|
||||||
//! max_tokens: Some(1000),
|
//! max_tokens: Some(1000),
|
||||||
//! temperature: Some(0.7),
|
//! temperature: Some(0.7),
|
||||||
@@ -241,6 +238,15 @@ impl DatabricksProvider {
|
|||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn convert_cache_control(cache_control: &crate::CacheControl) -> DatabricksCacheControl {
|
||||||
|
let cache_type = match cache_control {
|
||||||
|
crate::CacheControl::Ephemeral => "ephemeral",
|
||||||
|
crate::CacheControl::FiveMinute => "5minute",
|
||||||
|
crate::CacheControl::OneHour => "1hour",
|
||||||
|
};
|
||||||
|
DatabricksCacheControl { cache_type: cache_type.to_string() }
|
||||||
|
}
|
||||||
|
|
||||||
fn convert_messages(&self, messages: &[Message]) -> Result<Vec<DatabricksMessage>> {
|
fn convert_messages(&self, messages: &[Message]) -> Result<Vec<DatabricksMessage>> {
|
||||||
let mut databricks_messages = Vec::new();
|
let mut databricks_messages = Vec::new();
|
||||||
|
|
||||||
@@ -251,9 +257,24 @@ impl DatabricksProvider {
|
|||||||
MessageRole::Assistant => "assistant",
|
MessageRole::Assistant => "assistant",
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// If message has cache_control, use content array format
|
||||||
|
let content = if message.cache_control.is_some() {
|
||||||
|
// Use array format with cache_control
|
||||||
|
let content_block = DatabricksContent::Text {
|
||||||
|
content_type: "text".to_string(),
|
||||||
|
text: message.content.clone(),
|
||||||
|
cache_control: message.cache_control.as_ref()
|
||||||
|
.map(Self::convert_cache_control),
|
||||||
|
};
|
||||||
|
serde_json::to_value(vec![content_block])?
|
||||||
|
} else {
|
||||||
|
// Use simple string format
|
||||||
|
serde_json::Value::String(message.content.clone())
|
||||||
|
};
|
||||||
|
|
||||||
databricks_messages.push(DatabricksMessage {
|
databricks_messages.push(DatabricksMessage {
|
||||||
role: role.to_string(),
|
role: role.to_string(),
|
||||||
content: Some(message.content.clone()),
|
content: Some(content),
|
||||||
tool_calls: None, // Only used in responses, not requests
|
tool_calls: None, // Only used in responses, not requests
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -864,8 +885,22 @@ impl LLMProvider for DatabricksProvider {
|
|||||||
let content = databricks_response
|
let content = databricks_response
|
||||||
.choices
|
.choices
|
||||||
.first()
|
.first()
|
||||||
.and_then(|choice| choice.message.content.as_ref())
|
.and_then(|choice| {
|
||||||
.cloned()
|
choice.message.content.as_ref().map(|c| {
|
||||||
|
// Handle both string and array formats
|
||||||
|
if let Some(s) = c.as_str() {
|
||||||
|
s.to_string()
|
||||||
|
} else if let Some(arr) = c.as_array() {
|
||||||
|
// Extract text from content blocks
|
||||||
|
arr.iter()
|
||||||
|
.filter_map(|block| block.get("text").and_then(|t| t.as_str()))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("")
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
// Check if there are tool calls in the response
|
// Check if there are tool calls in the response
|
||||||
@@ -1037,6 +1072,11 @@ impl LLMProvider for DatabricksProvider {
|
|||||||
// This includes Claude, Llama, DBRX, and most other models on the platform
|
// This includes Claude, Llama, DBRX, and most other models on the platform
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn supports_cache_control(&self) -> bool {
|
||||||
|
// Databricks supports cache control when using Anthropic models
|
||||||
|
self.model.contains("claude")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Databricks API request/response structures
|
// Databricks API request/response structures
|
||||||
@@ -1064,10 +1104,29 @@ struct DatabricksFunction {
|
|||||||
parameters: serde_json::Value,
|
parameters: serde_json::Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct DatabricksCacheControl {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
cache_type: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
enum DatabricksContent {
|
||||||
|
Text {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
content_type: String,
|
||||||
|
text: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
cache_control: Option<DatabricksCacheControl>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct DatabricksMessage {
|
struct DatabricksMessage {
|
||||||
role: String,
|
role: String,
|
||||||
content: Option<String>, // Make content optional since tool calls might not have content
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
content: Option<serde_json::Value>, // Can be string or array of content blocks
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
tool_calls: Option<Vec<DatabricksToolCall>>, // Add tool_calls field for responses
|
tool_calls: Option<Vec<DatabricksToolCall>>, // Add tool_calls field for responses
|
||||||
}
|
}
|
||||||
@@ -1154,18 +1213,9 @@ mod tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
Message {
|
Message::new(MessageRole::System, "You are a helpful assistant.".to_string()),
|
||||||
role: MessageRole::System,
|
Message::new(MessageRole::User, "Hello!".to_string()),
|
||||||
content: "You are a helpful assistant.".to_string(),
|
Message::new(MessageRole::Assistant, "Hi there!".to_string()),
|
||||||
},
|
|
||||||
Message {
|
|
||||||
role: MessageRole::User,
|
|
||||||
content: "Hello!".to_string(),
|
|
||||||
},
|
|
||||||
Message {
|
|
||||||
role: MessageRole::Assistant,
|
|
||||||
content: "Hi there!".to_string(),
|
|
||||||
},
|
|
||||||
];
|
];
|
||||||
|
|
||||||
let databricks_messages = provider.convert_messages(&messages).unwrap();
|
let databricks_messages = provider.convert_messages(&messages).unwrap();
|
||||||
@@ -1187,10 +1237,7 @@ mod tests {
|
|||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let messages = vec![Message {
|
let messages = vec![Message::new(MessageRole::User, "Test message".to_string())];
|
||||||
role: MessageRole::User,
|
|
||||||
content: "Test message".to_string(),
|
|
||||||
}];
|
|
||||||
|
|
||||||
let request_body = provider
|
let request_body = provider
|
||||||
.create_request_body(&messages, None, false, 1000, 0.5)
|
.create_request_body(&messages, None, false, 1000, 0.5)
|
||||||
@@ -1273,4 +1320,53 @@ mod tests {
|
|||||||
assert!(llama_provider.has_native_tool_calling());
|
assert!(llama_provider.has_native_tool_calling());
|
||||||
assert!(dbrx_provider.has_native_tool_calling());
|
assert!(dbrx_provider.has_native_tool_calling());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cache_control_serialization() {
|
||||||
|
let provider = DatabricksProvider::from_token(
|
||||||
|
"https://test.databricks.com".to_string(),
|
||||||
|
"test-token".to_string(),
|
||||||
|
"databricks-claude-sonnet-4".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Test message WITHOUT cache_control - should use string format
|
||||||
|
let messages_without = vec![Message::new(MessageRole::User, "Hello".to_string())];
|
||||||
|
let databricks_messages_without = provider.convert_messages(&messages_without).unwrap();
|
||||||
|
let json_without = serde_json::to_string(&databricks_messages_without).unwrap();
|
||||||
|
|
||||||
|
println!("JSON without cache_control: {}", json_without);
|
||||||
|
assert!(!json_without.contains("cache_control"),
|
||||||
|
"JSON should not contain 'cache_control' field when not configured");
|
||||||
|
|
||||||
|
// Test message WITH cache_control - should use array format
|
||||||
|
let messages_with = vec![Message::with_cache_control(
|
||||||
|
MessageRole::User,
|
||||||
|
"Hello".to_string(),
|
||||||
|
crate::CacheControl::Ephemeral,
|
||||||
|
)];
|
||||||
|
let databricks_messages_with = provider.convert_messages(&messages_with).unwrap();
|
||||||
|
let json_with = serde_json::to_string(&databricks_messages_with).unwrap();
|
||||||
|
|
||||||
|
println!("JSON with cache_control: {}", json_with);
|
||||||
|
assert!(json_with.contains("cache_control"),
|
||||||
|
"JSON should contain 'cache_control' field when configured");
|
||||||
|
assert!(json_with.contains("ephemeral"),
|
||||||
|
"JSON should contain 'ephemeral' type");
|
||||||
|
assert!(!json_with.contains("null"),
|
||||||
|
"JSON should not contain null values");
|
||||||
|
|
||||||
|
// Verify the structure is correct
|
||||||
|
let msg_with = &databricks_messages_with[0];
|
||||||
|
if let Some(content) = &msg_with.content {
|
||||||
|
if let Some(arr) = content.as_array() {
|
||||||
|
assert_eq!(arr.len(), 1, "Content array should have one element");
|
||||||
|
assert!(arr[0].get("cache_control").is_some(), "Content should have cache_control");
|
||||||
|
} else {
|
||||||
|
panic!("Content should be an array when cache_control is present");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,6 +21,11 @@ pub trait LLMProvider: Send + Sync {
|
|||||||
fn has_native_tool_calling(&self) -> bool {
|
fn has_native_tool_calling(&self) -> bool {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Check if the provider supports cache control
|
||||||
|
fn supports_cache_control(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -32,10 +37,22 @@ pub struct CompletionRequest {
|
|||||||
pub tools: Option<Vec<Tool>>,
|
pub tools: Option<Vec<Tool>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum CacheControl {
|
||||||
|
Ephemeral,
|
||||||
|
#[serde(rename = "5minute")]
|
||||||
|
FiveMinute,
|
||||||
|
#[serde(rename = "1hour")]
|
||||||
|
OneHour,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct Message {
|
pub struct Message {
|
||||||
pub role: MessageRole,
|
pub role: MessageRole,
|
||||||
pub content: String,
|
pub content: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub cache_control: Option<CacheControl>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -95,6 +112,45 @@ pub use databricks::DatabricksProvider;
|
|||||||
pub use embedded::EmbeddedProvider;
|
pub use embedded::EmbeddedProvider;
|
||||||
pub use openai::OpenAIProvider;
|
pub use openai::OpenAIProvider;
|
||||||
|
|
||||||
|
impl Message {
|
||||||
|
/// Create a new message with optional cache control
|
||||||
|
pub fn new(role: MessageRole, content: String) -> Self {
|
||||||
|
Self {
|
||||||
|
role,
|
||||||
|
content,
|
||||||
|
cache_control: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new message with cache control
|
||||||
|
pub fn with_cache_control(role: MessageRole, content: String, cache_control: CacheControl) -> Self {
|
||||||
|
Self {
|
||||||
|
role,
|
||||||
|
content,
|
||||||
|
cache_control: Some(cache_control),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a message with cache control, with provider validation
|
||||||
|
pub fn with_cache_control_validated(
|
||||||
|
role: MessageRole,
|
||||||
|
content: String,
|
||||||
|
cache_control: CacheControl,
|
||||||
|
provider: &dyn LLMProvider
|
||||||
|
) -> Self {
|
||||||
|
if !provider.supports_cache_control() {
|
||||||
|
tracing::warn!(
|
||||||
|
"Cache control requested for provider '{}' which does not support it. \
|
||||||
|
Cache control is only supported by Anthropic and Anthropic via Databricks.",
|
||||||
|
provider.name()
|
||||||
|
);
|
||||||
|
return Self::new(role, content);
|
||||||
|
}
|
||||||
|
|
||||||
|
Self::with_cache_control(role, content, cache_control)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Provider registry for managing multiple LLM providers
|
/// Provider registry for managing multiple LLM providers
|
||||||
pub struct ProviderRegistry {
|
pub struct ProviderRegistry {
|
||||||
providers: HashMap<String, Box<dyn LLMProvider>>,
|
providers: HashMap<String, Box<dyn LLMProvider>>,
|
||||||
@@ -144,3 +200,36 @@ impl Default for ProviderRegistry {
|
|||||||
Self::new()
|
Self::new()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_message_serialization_without_cache_control() {
|
||||||
|
let msg = Message::new(MessageRole::User, "Hello".to_string());
|
||||||
|
let json = serde_json::to_string(&msg).unwrap();
|
||||||
|
|
||||||
|
println!("Message JSON without cache_control: {}", json);
|
||||||
|
assert!(!json.contains("cache_control"),
|
||||||
|
"JSON should not contain 'cache_control' field when not configured");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_message_serialization_with_cache_control() {
|
||||||
|
let msg = Message::with_cache_control(
|
||||||
|
MessageRole::User,
|
||||||
|
"Hello".to_string(),
|
||||||
|
CacheControl::Ephemeral,
|
||||||
|
);
|
||||||
|
let json = serde_json::to_string(&msg).unwrap();
|
||||||
|
|
||||||
|
println!("Message JSON with cache_control: {}", json);
|
||||||
|
assert!(json.contains("cache_control"),
|
||||||
|
"JSON should contain 'cache_control' field when configured");
|
||||||
|
assert!(json.contains("ephemeral"),
|
||||||
|
"JSON should contain 'ephemeral' value");
|
||||||
|
assert!(!json.contains("null"),
|
||||||
|
"JSON should not contain null values");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user