Fix temperature param + add thinking for Anthropic
The configured temperature was not passed to the LLM (requests always used a hard-coded 0.1). Anthropic models can now also be run in 'thinking' mode.
This commit is contained in:
@@ -24,6 +24,8 @@ temperature = 0.3 # Slightly higher temperature for more creative implementatio
|
|||||||
# Options: "ephemeral", "5minute", "1hour"
|
# Options: "ephemeral", "5minute", "1hour"
|
||||||
# Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs.
|
# Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs.
|
||||||
# enable_1m_context = true # optional, more expensive
|
# enable_1m_context = true # optional, more expensive
|
||||||
|
# thinking_budget_tokens = 10000 # Optional: Enable extended thinking mode with token budget
|
||||||
|
# Allows the model to "think" before responding. Useful for complex reasoning tasks.
|
||||||
|
|
||||||
|
|
||||||
# Multiple OpenAI-compatible providers can be configured with custom names
|
# Multiple OpenAI-compatible providers can be configured with custom names
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ pub struct AnthropicConfig {
|
|||||||
pub temperature: Option<f32>,
|
pub temperature: Option<f32>,
|
||||||
pub cache_config: Option<String>, // "ephemeral", "5minute", "1hour", or None to disable
|
pub cache_config: Option<String>, // "ephemeral", "5minute", "1hour", or None to disable
|
||||||
pub enable_1m_context: Option<bool>, // Enable 1m context window (costs extra)
|
pub enable_1m_context: Option<bool>, // Enable 1m context window (costs extra)
|
||||||
|
pub thinking_budget_tokens: Option<u32>, // Budget tokens for extended thinking
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|||||||
@@ -950,6 +950,7 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
anthropic_config.temperature,
|
anthropic_config.temperature,
|
||||||
anthropic_config.cache_config.clone(),
|
anthropic_config.cache_config.clone(),
|
||||||
anthropic_config.enable_1m_context,
|
anthropic_config.enable_1m_context,
|
||||||
|
anthropic_config.thinking_budget_tokens,
|
||||||
)?;
|
)?;
|
||||||
providers.register(anthropic_provider);
|
providers.register(anthropic_provider);
|
||||||
}
|
}
|
||||||
@@ -1167,6 +1168,17 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Look up the temperature configured for `provider_name` in the top-level config.
///
/// Recognised names are "anthropic", "openai", "databricks" and "embedded".
/// Returns `None` when the matching provider section is absent (the `?` after
/// `as_ref()` short-circuits), when that section leaves `temperature` unset,
/// or when `provider_name` is not one of the recognised names.
|
||||||
|
fn provider_temperature(config: &Config, provider_name: &str) -> Option<f32> {
|
||||||
|
match provider_name {
|
||||||
|
"anthropic" => config.providers.anthropic.as_ref()?.temperature,
|
||||||
|
"openai" => config.providers.openai.as_ref()?.temperature,
|
||||||
|
"databricks" => config.providers.databricks.as_ref()?.temperature,
|
||||||
|
"embedded" => config.providers.embedded.as_ref()?.temperature,
|
||||||
|
// Any other provider name has no temperature setting to read here.
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Resolve the max_tokens to use for a given provider, applying fallbacks
|
/// Resolve the max_tokens to use for a given provider, applying fallbacks
|
||||||
fn resolve_max_tokens(&self, provider_name: &str) -> u32 {
|
fn resolve_max_tokens(&self, provider_name: &str) -> u32 {
|
||||||
match provider_name {
|
match provider_name {
|
||||||
@@ -1179,6 +1191,16 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Resolve the temperature to use for a given provider, applying fallbacks.
///
/// Returns the value found by `provider_temperature` for `provider_name`, or
/// 0.1 when that lookup yields `None`.
///
/// NOTE(review): both match arms perform the same lookup with the same 0.1
/// default, so the dedicated "databricks" arm is currently redundant. Confirm
/// whether a provider-specific default was intended; otherwise the match can
/// collapse to a single `provider_temperature(..).unwrap_or(0.1)` call.
|
||||||
|
fn resolve_temperature(&self, provider_name: &str) -> f32 {
|
||||||
|
match provider_name {
|
||||||
|
"databricks" => Self::provider_temperature(&self.config, "databricks")
|
||||||
|
.unwrap_or(0.1),
|
||||||
|
other => Self::provider_temperature(&self.config, other)
|
||||||
|
.unwrap_or(0.1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Print provider diagnostics through the UiWriter for visibility
|
/// Print provider diagnostics through the UiWriter for visibility
|
||||||
pub fn print_provider_banner(&self, role_label: &str) {
|
pub fn print_provider_banner(&self, role_label: &str) {
|
||||||
if let Ok((provider_name, model)) = self.get_provider_info() {
|
if let Ok((provider_name, model)) = self.get_provider_info() {
|
||||||
@@ -1562,7 +1584,7 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
let request = CompletionRequest {
|
let request = CompletionRequest {
|
||||||
messages,
|
messages,
|
||||||
max_tokens,
|
max_tokens,
|
||||||
temperature: Some(0.1),
|
temperature: Some(self.resolve_temperature(&provider_name)),
|
||||||
stream: true, // Enable streaming
|
stream: true, // Enable streaming
|
||||||
tools,
|
tools,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -26,6 +26,7 @@
|
|||||||
//! Some(0.1),
|
//! Some(0.1),
|
||||||
//! None, // cache_config
|
//! None, // cache_config
|
||||||
//! None, // enable_1m_context
|
//! None, // enable_1m_context
|
||||||
|
//! None, // thinking_budget_tokens
|
||||||
//! )?;
|
//! )?;
|
||||||
//!
|
//!
|
||||||
//! // Create a completion request
|
//! // Create a completion request
|
||||||
@@ -63,6 +64,7 @@
|
|||||||
//! None,
|
//! None,
|
||||||
//! None, // cache_config
|
//! None, // cache_config
|
||||||
//! None, // enable_1m_context
|
//! None, // enable_1m_context
|
||||||
|
//! None, // thinking_budget_tokens
|
||||||
//! )?;
|
//! )?;
|
||||||
//!
|
//!
|
||||||
//! let request = CompletionRequest {
|
//! let request = CompletionRequest {
|
||||||
@@ -122,6 +124,7 @@ pub struct AnthropicProvider {
|
|||||||
temperature: f32,
|
temperature: f32,
|
||||||
cache_config: Option<String>,
|
cache_config: Option<String>,
|
||||||
enable_1m_context: bool,
|
enable_1m_context: bool,
|
||||||
|
thinking_budget_tokens: Option<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AnthropicProvider {
|
impl AnthropicProvider {
|
||||||
@@ -132,6 +135,7 @@ impl AnthropicProvider {
|
|||||||
temperature: Option<f32>,
|
temperature: Option<f32>,
|
||||||
cache_config: Option<String>,
|
cache_config: Option<String>,
|
||||||
enable_1m_context: Option<bool>,
|
enable_1m_context: Option<bool>,
|
||||||
|
thinking_budget_tokens: Option<u32>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let client = Client::builder()
|
let client = Client::builder()
|
||||||
.timeout(Duration::from_secs(300))
|
.timeout(Duration::from_secs(300))
|
||||||
@@ -150,6 +154,7 @@ impl AnthropicProvider {
|
|||||||
temperature: temperature.unwrap_or(0.1),
|
temperature: temperature.unwrap_or(0.1),
|
||||||
cache_config,
|
cache_config,
|
||||||
enable_1m_context: enable_1m_context.unwrap_or(false),
|
enable_1m_context: enable_1m_context.unwrap_or(false),
|
||||||
|
thinking_budget_tokens,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -279,6 +284,11 @@ impl AnthropicProvider {
|
|||||||
// Convert tools if provided
|
// Convert tools if provided
|
||||||
let anthropic_tools = tools.map(|t| self.convert_tools(t));
|
let anthropic_tools = tools.map(|t| self.convert_tools(t));
|
||||||
|
|
||||||
|
// Add thinking configuration if budget_tokens is set
|
||||||
|
let thinking = self.thinking_budget_tokens.map(|budget| {
|
||||||
|
ThinkingConfig::enabled(budget)
|
||||||
|
});
|
||||||
|
|
||||||
let request = AnthropicRequest {
|
let request = AnthropicRequest {
|
||||||
model: self.model.clone(),
|
model: self.model.clone(),
|
||||||
max_tokens,
|
max_tokens,
|
||||||
@@ -287,6 +297,7 @@ impl AnthropicProvider {
|
|||||||
system,
|
system,
|
||||||
tools: anthropic_tools,
|
tools: anthropic_tools,
|
||||||
stream: streaming,
|
stream: streaming,
|
||||||
|
thinking,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Ensure the conversation starts with a user message
|
// Ensure the conversation starts with a user message
|
||||||
@@ -777,6 +788,19 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
|
|
||||||
// Anthropic API request/response structures
|
// Anthropic API request/response structures
|
||||||
|
|
||||||
|
/// Serializable payload describing extended-thinking settings for a request.
///
/// Serializes to an object with a `type` key and a `budget_tokens` key.
#[derive(Debug, Serialize)]
|
||||||
|
struct ThinkingConfig {
|
||||||
|
// Emitted under the JSON key "type" (`type` is a Rust keyword, hence the rename).
#[serde(rename = "type")]
|
||||||
|
thinking_type: String,
|
||||||
|
// Token budget for thinking, serialized as-is under "budget_tokens".
budget_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ThinkingConfig {
|
||||||
|
/// Build a config whose `thinking_type` is the string "enabled" and whose
/// budget is `budget_tokens`. No validation of the budget is performed here.
fn enabled(budget_tokens: u32) -> Self {
|
||||||
|
Self { thinking_type: "enabled".to_string(), budget_tokens }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
struct AnthropicRequest {
|
struct AnthropicRequest {
|
||||||
model: String,
|
model: String,
|
||||||
@@ -788,6 +812,8 @@ struct AnthropicRequest {
|
|||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
tools: Option<Vec<AnthropicTool>>,
|
tools: Option<Vec<AnthropicTool>>,
|
||||||
stream: bool,
|
stream: bool,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
thinking: Option<ThinkingConfig>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
@@ -886,7 +912,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_message_conversion() {
|
fn test_message_conversion() {
|
||||||
let provider =
|
let provider =
|
||||||
AnthropicProvider::new("test-key".to_string(), None, None, None, None, None).unwrap();
|
AnthropicProvider::new("test-key".to_string(), None, None, None, None, None, None).unwrap();
|
||||||
|
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
Message::new(
|
Message::new(
|
||||||
@@ -914,6 +940,7 @@ mod tests {
|
|||||||
Some(0.5),
|
Some(0.5),
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@@ -934,7 +961,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_tool_conversion() {
|
fn test_tool_conversion() {
|
||||||
let provider =
|
let provider =
|
||||||
AnthropicProvider::new("test-key".to_string(), None, None, None, None, None).unwrap();
|
AnthropicProvider::new("test-key".to_string(), None, None, None, None, None, None).unwrap();
|
||||||
|
|
||||||
let tools = vec![Tool {
|
let tools = vec![Tool {
|
||||||
name: "get_weather".to_string(),
|
name: "get_weather".to_string(),
|
||||||
@@ -967,7 +994,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_cache_control_serialization() {
|
fn test_cache_control_serialization() {
|
||||||
let provider =
|
let provider =
|
||||||
AnthropicProvider::new("test-key".to_string(), None, None, None, None, None).unwrap();
|
AnthropicProvider::new("test-key".to_string(), None, None, None, None, None, None).unwrap();
|
||||||
|
|
||||||
// Test message WITHOUT cache_control
|
// Test message WITHOUT cache_control
|
||||||
let messages_without = vec![Message::new(MessageRole::User, "Hello".to_string())];
|
let messages_without = vec![Message::new(MessageRole::User, "Hello".to_string())];
|
||||||
@@ -1009,4 +1036,46 @@ mod tests {
|
|||||||
"JSON should not contain 'cache_control' field or null values when not configured"
|
"JSON should not contain 'cache_control' field or null values when not configured"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks that the serialized request JSON omits "thinking" when the provider
// has no thinking budget, and includes `"type":"enabled"` plus the configured
// `budget_tokens` when it has one.
#[test]
|
||||||
|
fn test_thinking_parameter_serialization() {
|
||||||
|
// Case 1: provider built WITHOUT a thinking budget (last argument is None).
|
||||||
|
let provider_without = AnthropicProvider::new(
|
||||||
|
"test-key".to_string(),
|
||||||
|
Some("claude-sonnet-4-5".to_string()),
|
||||||
|
Some(1000),
|
||||||
|
Some(0.5),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None, // No thinking budget
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let messages = vec![Message::new(MessageRole::User, "Test message".to_string())];
|
||||||
|
// presumably the arguments are (messages, tools, streaming, max_tokens,
// temperature) -- TODO confirm against `create_request_body`'s signature.
let request_without = provider_without
|
||||||
|
.create_request_body(&messages, None, false, 1000, 0.5)
|
||||||
|
.unwrap();
|
||||||
|
let json_without = serde_json::to_string(&request_without).unwrap();
|
||||||
|
assert!(!json_without.contains("thinking"), "JSON should not contain 'thinking' field when not configured");
|
||||||
|
|
||||||
|
// Case 2: provider built WITH a thinking budget of 10000 tokens.
// NOTE(review): the budget (10000) is larger than the 1000 passed to
// `create_request_body` below; this test only inspects the serialized JSON,
// so confirm against Anthropic's extended-thinking constraints before
// reusing these values for a live request.
|
||||||
|
let provider_with = AnthropicProvider::new(
|
||||||
|
"test-key".to_string(),
|
||||||
|
Some("claude-sonnet-4-5".to_string()),
|
||||||
|
Some(1000),
|
||||||
|
Some(0.5),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
Some(10000), // With thinking budget
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let request_with = provider_with
|
||||||
|
.create_request_body(&messages, None, false, 1000, 0.5)
|
||||||
|
.unwrap();
|
||||||
|
let json_with = serde_json::to_string(&request_with).unwrap();
|
||||||
|
// Substring checks rely on serde_json's compact output (no spaces after ':').
assert!(json_with.contains("thinking"), "JSON should contain 'thinking' field when configured");
|
||||||
|
assert!(json_with.contains("\"type\":\"enabled\""), "JSON should contain type: enabled");
|
||||||
|
assert!(json_with.contains("\"budget_tokens\":10000"), "JSON should contain budget_tokens: 10000");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user