config: default agent settings and provider override

Dhanji R. Prasanna
2026-01-14 20:14:33 +05:30
parent 38828c7757
commit f4562cd4c9
5 changed files with 595 additions and 396 deletions

View File

@@ -1,24 +1,30 @@
# g3 Configuration Example
#
# This file demonstrates the new provider configuration format.
# Provider references use the format: "<provider_type>.<config_name>"
# Most settings have sensible defaults. A minimal config only needs:
#
# [providers]
# default_provider = "anthropic.default"
#
# [providers.anthropic.default]
# api_key = "your-api-key"
# model = "claude-sonnet-4-5"
#
# Everything else below is optional.
[providers]
# Default provider, used when none is specified explicitly
default_provider = "anthropic.default"
# Optional: Specify different providers for each mode
# If not specified, these fall back to default_provider
# planner = "anthropic.planner" # Provider for planning mode
# coach = "anthropic.default" # Provider for coach (code reviewer) in autonomous mode
# player = "anthropic.default" # Provider for player (code implementer) in autonomous mode
# coach = "anthropic.default" # Provider for coach in autonomous mode
# player = "anthropic.default" # Provider for player in autonomous mode
# Named Anthropic configurations
[providers.anthropic.default]
api_key = "your-anthropic-api-key"
model = "claude-sonnet-4-5"
max_tokens = 64000
temperature = 0.3
# max_tokens = 64000 # Optional (default: provider's max)
# temperature = 0.3 # Optional
# cache_config = "ephemeral" # Optional: Enable prompt caching
# enable_1m_context = true # Optional: Enable 1M context (costs extra)
# thinking_budget_tokens = 10000 # Optional: Enable extended thinking mode
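# Note: per the provider tests later in this commit, extended thinking needs
# headroom above the budget: max_tokens must exceed thinking_budget_tokens + 1024
# (a budget of 10000 requires max_tokens > 11024), otherwise thinking is disabled.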
@@ -27,99 +33,50 @@ temperature = 0.3
# [providers.anthropic.planner]
# api_key = "your-anthropic-api-key"
# model = "claude-opus-4-5"
# max_tokens = 64000
# thinking_budget_tokens = 16000
# Named Databricks configurations
[providers.databricks.default]
host = "https://your-workspace.cloud.databricks.com"
# token = "your-databricks-token" # Optional - will use OAuth if not provided
model = "databricks-claude-sonnet-4"
max_tokens = 4096
temperature = 0.1
use_oauth = true
# Databricks provider example
# [providers.databricks.default]
# host = "https://your-workspace.cloud.databricks.com"
# model = "databricks-claude-sonnet-4"
# use_oauth = true
# Named OpenAI configurations
# OpenAI provider example
# [providers.openai.default]
# api_key = "your-openai-api-key"
# model = "gpt-4-turbo"
# max_tokens = 4096
# temperature = 0.1
# Multiple OpenAI-compatible providers can be configured
# OpenAI-compatible providers (OpenRouter, Groq, etc.)
# [providers.openai_compatible.openrouter]
# api_key = "your-openrouter-api-key"
# model = "anthropic/claude-3.5-sonnet"
# base_url = "https://openrouter.ai/api/v1"
# max_tokens = 4096
# temperature = 0.1
# [providers.openai_compatible.groq]
# api_key = "your-groq-api-key"
# model = "llama-3.3-70b-versatile"
# base_url = "https://api.groq.com/openai/v1"
# max_tokens = 4096
# temperature = 0.1
# =============================================================================
# Agent settings (all optional - these are the defaults)
# =============================================================================
# [agent]
# fallback_default_max_tokens = 8192
# enable_streaming = true
# timeout_seconds = 120
# auto_compact = true
# max_retry_attempts = 3
# autonomous_max_retry_attempts = 6
# max_context_length = 200000 # Override context window size
[agent]
fallback_default_max_tokens = 8192
# max_context_length: Override the context window size for all providers.
# This is the total size of the conversation history, not the per-request output limit.
# max_context_length = 200000
enable_streaming = true
timeout_seconds = 60
max_retry_attempts = 3
autonomous_max_retry_attempts = 6
allow_multiple_tool_calls = true
# =============================================================================
# Computer control (all optional - enabled by default)
# =============================================================================
# [computer_control]
# enabled = true # Requires OS accessibility permissions
# require_confirmation = true
# max_actions_per_second = 5
# Retry Configuration for Planning/Autonomous Mode
#
# The retry infrastructure handles transient errors during LLM API calls:
# - Rate limits (HTTP 429)
# - Network errors (connection failures)
# - Server errors (HTTP 5xx)
# - Request timeouts
# - Model capacity issues (model busy)
#
# Default retry behavior:
# - max_retry_attempts: Used in interactive mode (3 retries)
# - autonomous_max_retry_attempts: Used in planning/autonomous mode (6 retries)
#
# Note: The retry logic uses exponential backoff with longer delays in
# autonomous mode to handle rate limits gracefully.
#
# Example player retry config (in code):
# RetryConfig::planning("player") # Creates: max_retries=3, is_autonomous=true
# RetryConfig::planning("player").with_max_retries(6) # Override max retries
#
# Example coach retry config (in code):
# RetryConfig::planning("coach") # Creates: max_retries=3, is_autonomous=true
# RetryConfig::planning("coach").with_max_retries(6) # Override max retries
#
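# Hypothetical backoff sketch (the real delays live in the retry module and
# may differ; the base values here are assumptions, not the shipped numbers):
#
# fn backoff_delay(attempt: u32, is_autonomous: bool) -> Duration {
#     let base_ms: u64 = if is_autonomous { 2000 } else { 500 }; // assumed bases
#     Duration::from_millis(base_ms * 2u64.pow(attempt))         // exponential backoff
# }
#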
[computer_control]
enabled = false # Set to true to enable computer control (requires OS permissions)
require_confirmation = true
max_actions_per_second = 5
[webdriver]
enabled = false
safari_port = 4444
chrome_port = 9515
# Browser to use: "safari" (default) or "chrome-headless"
# Safari opens a visible browser window
# Chrome headless runs in the background without a visible window
browser = "safari"
# Optional: Path to Chrome binary (e.g., Chrome for Testing)
# If not set, ChromeDriver will use the default Chrome installation
# Use this to avoid version mismatch issues between Chrome and ChromeDriver
# Run: ./scripts/setup-chrome-for-testing.sh to install matching versions
# chrome_binary = "/Users/yourname/.chrome-for-testing/chrome-mac-arm64/Google Chrome for Testing.app/Contents/MacOS/Google Chrome for Testing"
# chrome_binary = "/Users/yourname/.chrome-for-testing/chrome-mac-x64/Google Chrome for Testing.app/Contents/MacOS/Google Chrome for Testing"
# Optional: Path to ChromeDriver binary
# If not set, looks for 'chromedriver' in PATH
# The setup script creates a symlink at ~/.local/bin/chromedriver
# chromedriver_binary = "/Users/yourname/.local/bin/chromedriver"
[macax]
enabled = false
# =============================================================================
# WebDriver browser automation (all optional)
# =============================================================================
# [webdriver]
# enabled = true
# browser = "chrome-headless" # Default. Alternative: "safari"
# chrome_binary = "/path/to/chrome" # Optional: custom Chrome path
# chromedriver_binary = "/path/to/driver" # Optional: custom ChromeDriver path

View File

@@ -55,7 +55,7 @@ pub struct Cli {
#[arg(long)]
pub chat: bool,
/// Override the configured provider (anthropic, databricks, embedded, openai)
/// Override the configured provider (e.g., 'openai' or 'openai.default')
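/// Bare provider types are normalized to "<type>.default", so 'openai'
/// is treated as 'openai.default' (see the override handling in config).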
#[arg(long, value_name = "PROVIDER")]
pub provider: Option<String>,

View File

@@ -7,8 +7,11 @@ use std::path::Path;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
pub providers: ProvidersConfig,
#[serde(default)]
pub agent: AgentConfig,
#[serde(default)]
pub computer_control: ComputerControlConfig,
#[serde(default)]
pub webdriver: WebDriverConfig,
}
@@ -92,24 +95,59 @@ pub struct EmbeddedConfig {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
pub max_context_length: Option<u32>,
#[serde(default = "default_fallback_max_tokens")]
pub fallback_default_max_tokens: usize,
#[serde(default = "default_true")]
pub enable_streaming: bool,
#[serde(default = "default_timeout_seconds")]
pub timeout_seconds: u64,
#[serde(default = "default_true")]
pub auto_compact: bool,
#[serde(default = "default_max_retry_attempts")]
pub max_retry_attempts: u32,
#[serde(default = "default_autonomous_max_retry_attempts")]
pub autonomous_max_retry_attempts: u32,
#[serde(default = "default_check_todo_staleness")]
pub check_todo_staleness: bool,
}
fn default_fallback_max_tokens() -> usize {
8192
}
fn default_true() -> bool {
true
}
fn default_false() -> bool {
false
}
fn default_timeout_seconds() -> u64 {
120
}
fn default_max_retry_attempts() -> u32 {
3
}
fn default_autonomous_max_retry_attempts() -> u32 {
6
}
fn default_max_actions_per_second() -> u32 {
5
}
fn default_check_todo_staleness() -> bool {
true
}
fn default_safari_port() -> u16 {
4444
}
fn default_chrome_port() -> u16 {
9515
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComputerControlConfig {
#[serde(default = "default_true")]
pub enabled: bool,
#[serde(default = "default_false")]
pub require_confirmation: bool,
#[serde(default = "default_max_actions_per_second")]
pub max_actions_per_second: u32,
}
@@ -117,17 +155,19 @@ pub struct ComputerControlConfig {
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum WebDriverBrowser {
#[default]
Safari,
#[default]
#[serde(rename = "chrome-headless")]
ChromeHeadless,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct WebDriverConfig {
#[serde(default = "default_true")]
pub enabled: bool,
#[serde(default = "default_safari_port")]
pub safari_port: u16,
#[serde(default)]
#[serde(default = "default_chrome_port")]
pub chrome_port: u16,
#[serde(default)]
/// Optional path to Chrome binary (e.g., Chrome for Testing)
@@ -141,24 +181,25 @@ pub struct WebDriverConfig {
pub browser: WebDriverBrowser,
}
impl Default for WebDriverConfig {
impl Default for AgentConfig {
fn default() -> Self {
Self {
enabled: true,
safari_port: 4444,
chrome_port: 9515,
chrome_binary: None,
chromedriver_binary: None,
browser: WebDriverBrowser::Safari,
max_context_length: None,
fallback_default_max_tokens: 8192,
enable_streaming: true,
timeout_seconds: 120,
auto_compact: true,
max_retry_attempts: 3,
autonomous_max_retry_attempts: 6,
check_todo_staleness: true,
}
}
}
impl Default for ComputerControlConfig {
fn default() -> Self {
Self {
enabled: false,
require_confirmation: true,
enabled: true,
require_confirmation: false,
max_actions_per_second: 5,
}
}
@@ -445,20 +486,26 @@ impl Config {
// Apply provider override
if let Some(provider) = provider_override {
// Validate the override
// If provider doesn't contain '.', assume '.default'
let provider = if provider.contains('.') {
provider
} else {
format!("{}.default", provider)
};
config.validate_provider_reference(&provider)?;
config.providers.default_provider = provider;
}
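// Illustrative normalization (examples, not exhaustive):
//   "openai"          -> "openai.default"
//   "openai.planner"  -> "openai.planner" (already qualified; unchanged)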
// Apply model override to the active provider
if let Some(model) = model_override {
let (provider_type, config_name) = Self::parse_provider_reference(
&config.providers.default_provider
)?;
let (provider_type, config_name) =
Self::parse_provider_reference(&config.providers.default_provider)?;
match provider_type.as_str() {
"anthropic" => {
if let Some(ref mut anthropic_config) = config.providers.anthropic.get_mut(&config_name) {
if let Some(ref mut anthropic_config) =
config.providers.anthropic.get_mut(&config_name)
{
anthropic_config.model = model;
} else {
return Err(anyhow::anyhow!(
@@ -468,7 +515,9 @@ impl Config {
}
}
"databricks" => {
if let Some(ref mut databricks_config) = config.providers.databricks.get_mut(&config_name) {
if let Some(ref mut databricks_config) =
config.providers.databricks.get_mut(&config_name)
{
databricks_config.model = model;
} else {
return Err(anyhow::anyhow!(
@@ -478,7 +527,9 @@ impl Config {
}
}
"embedded" => {
if let Some(ref mut embedded_config) = config.providers.embedded.get_mut(&config_name) {
if let Some(ref mut embedded_config) =
config.providers.embedded.get_mut(&config_name)
{
embedded_config.model_path = model;
} else {
return Err(anyhow::anyhow!(
@@ -488,7 +539,9 @@ impl Config {
}
}
"openai" => {
if let Some(ref mut openai_config) = config.providers.openai.get_mut(&config_name) {
if let Some(ref mut openai_config) =
config.providers.openai.get_mut(&config_name)
{
openai_config.model = model;
} else {
return Err(anyhow::anyhow!(
@@ -499,13 +552,12 @@ impl Config {
}
_ => {
// Check openai_compatible
if let Some(ref mut compat_config) = config.providers.openai_compatible.get_mut(&provider_type) {
if let Some(ref mut compat_config) =
config.providers.openai_compatible.get_mut(&provider_type)
{
compat_config.model = model;
} else {
return Err(anyhow::anyhow!(
"Unknown provider type: {}",
provider_type
));
return Err(anyhow::anyhow!("Unknown provider type: {}", provider_type));
}
}
}
@@ -585,36 +637,42 @@ impl Config {
/// Get the current default provider's config
pub fn get_default_provider_config(&self) -> Result<ProviderConfigRef<'_>> {
let (provider_type, config_name) = Self::parse_provider_reference(
&self.providers.default_provider
)?;
let (provider_type, config_name) =
Self::parse_provider_reference(&self.providers.default_provider)?;
match provider_type.as_str() {
"anthropic" => {
self.providers.anthropic.get(&config_name)
"anthropic" => self
.providers
.anthropic
.get(&config_name)
.map(ProviderConfigRef::Anthropic)
.ok_or_else(|| anyhow::anyhow!("Anthropic config '{}' not found", config_name))
}
"openai" => {
self.providers.openai.get(&config_name)
.ok_or_else(|| anyhow::anyhow!("Anthropic config '{}' not found", config_name)),
"openai" => self
.providers
.openai
.get(&config_name)
.map(ProviderConfigRef::OpenAI)
.ok_or_else(|| anyhow::anyhow!("OpenAI config '{}' not found", config_name))
}
"databricks" => {
self.providers.databricks.get(&config_name)
.ok_or_else(|| anyhow::anyhow!("OpenAI config '{}' not found", config_name)),
"databricks" => self
.providers
.databricks
.get(&config_name)
.map(ProviderConfigRef::Databricks)
.ok_or_else(|| anyhow::anyhow!("Databricks config '{}' not found", config_name))
}
"embedded" => {
self.providers.embedded.get(&config_name)
.ok_or_else(|| anyhow::anyhow!("Databricks config '{}' not found", config_name)),
"embedded" => self
.providers
.embedded
.get(&config_name)
.map(ProviderConfigRef::Embedded)
.ok_or_else(|| anyhow::anyhow!("Embedded config '{}' not found", config_name))
}
_ => {
self.providers.openai_compatible.get(&provider_type)
.ok_or_else(|| anyhow::anyhow!("Embedded config '{}' not found", config_name)),
_ => self
.providers
.openai_compatible
.get(&provider_type)
.map(ProviderConfigRef::OpenAICompatible)
.ok_or_else(|| anyhow::anyhow!("OpenAI compatible config '{}' not found", provider_type))
}
.ok_or_else(|| {
anyhow::anyhow!("OpenAI compatible config '{}' not found", provider_type)
}),
}
}
}
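A minimal usage sketch for the resolver above (hedged: `Config::load` is assumed from context and is not shown in this diff; error handling elided):

// Sketch only: resolve the active provider and read its model.
let config = Config::load()?;
match config.get_default_provider_config()? {
    ProviderConfigRef::Anthropic(c) => println!("anthropic model: {}", c.model),
    ProviderConfigRef::OpenAI(c) => println!("openai model: {}", c.model),
    _ => println!("other provider type in use"),
}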

View File

@@ -1,32 +1,38 @@
pub mod acd;
pub mod context_window;
pub mod background_process;
pub mod compaction;
pub mod code_search;
pub mod compaction;
pub mod context_window;
pub mod error_handling;
pub mod feedback_extraction;
pub mod paths;
pub mod project;
pub mod provider_registration;
pub mod provider_config;
pub mod provider_registration;
pub mod retry;
pub mod session;
pub mod session_continuation;
pub mod stats;
pub mod streaming;
pub mod streaming_parser;
pub mod task_result;
pub mod tool_dispatch;
pub mod tool_definitions;
pub mod tool_dispatch;
pub mod tools;
pub mod ui_writer;
pub mod streaming;
pub mod utils;
pub mod webdriver_session;
pub mod stats;
pub use feedback_extraction::{
extract_coach_feedback, ExtractedFeedback, FeedbackExtractionConfig, FeedbackSource,
};
pub use retry::{execute_with_retry, retry_operation, RetryConfig, RetryResult};
pub use session_continuation::{
clear_continuation, find_incomplete_agent_session, format_session_time, get_session_dir,
has_valid_continuation, list_sessions_for_directory, load_context_from_session_log,
load_continuation, save_continuation, SessionContinuation,
};
pub use task_result::TaskResult;
pub use retry::{RetryConfig, RetryResult, execute_with_retry, retry_operation};
pub use feedback_extraction::{ExtractedFeedback, FeedbackSource, FeedbackExtractionConfig, extract_coach_feedback};
pub use session_continuation::{SessionContinuation, load_continuation, save_continuation, clear_continuation, has_valid_continuation, get_session_dir, load_context_from_session_log, find_incomplete_agent_session, list_sessions_for_directory, format_session_time};
// Re-export context window types
pub use context_window::{ContextWindow, ThinScope};
@@ -55,12 +61,12 @@ use tokio_util::sync::CancellationToken;
use tracing::{debug, error, warn};
// Re-export path utilities
pub use paths::{
G3_WORKSPACE_PATH_ENV, ensure_session_dir, get_context_summary_file, get_g3_dir,
get_session_file, get_session_logs_dir, get_session_todo_path, get_thinned_dir,
get_errors_dir, get_background_processes_dir, get_discovery_dir,
};
use paths::get_todo_path;
pub use paths::{
ensure_session_dir, get_background_processes_dir, get_context_summary_file, get_discovery_dir,
get_errors_dir, get_g3_dir, get_session_file, get_session_logs_dir, get_session_todo_path,
get_thinned_dir, G3_WORKSPACE_PATH_ENV,
};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
@@ -68,7 +74,6 @@ pub struct ToolCall {
pub args: serde_json::Value, // Should be a JSON object with tool-specific arguments
}
// Re-export WebDriverSession from its own module
pub use webdriver_session::WebDriverSession;
@@ -87,9 +92,8 @@ pub enum StreamState {
Resuming,
}
// Re-export StreamingToolParser from its own module
pub use streaming_parser::{StreamingToolParser, sanitize_inline_tool_patterns, LBRACE_HOMOGLYPH};
pub use streaming_parser::{sanitize_inline_tool_patterns, StreamingToolParser, LBRACE_HOMOGLYPH};
pub struct Agent<W: UiWriter> {
providers: ProviderRegistry,
@@ -108,9 +112,7 @@ pub struct Agent<W: UiWriter> {
computer_controller: Option<Box<dyn g3_computer_control::ComputerController>>,
todo_content: std::sync::Arc<tokio::sync::RwLock<String>>,
webdriver_session: std::sync::Arc<
tokio::sync::RwLock<
Option<std::sync::Arc<tokio::sync::Mutex<WebDriverSession>>>,
>,
tokio::sync::RwLock<Option<std::sync::Arc<tokio::sync::Mutex<WebDriverSession>>>>,
>,
webdriver_process: std::sync::Arc<tokio::sync::RwLock<Option<tokio::process::Child>>>,
tool_call_count: usize,
@@ -167,7 +169,15 @@ impl<W: UiWriter> Agent<W> {
custom_system_prompt: String,
readme_content: Option<String>,
) -> Result<Self> {
Self::new_with_mode_and_readme(config, ui_writer, false, readme_content, false, Some(custom_system_prompt)).await
Self::new_with_mode_and_readme(
config,
ui_writer,
false,
readme_content,
false,
Some(custom_system_prompt),
)
.await
}
async fn new_with_mode(
@@ -188,8 +198,10 @@ impl<W: UiWriter> Agent<W> {
custom_system_prompt: Option<String>,
) -> Result<Self> {
// Register providers using the extracted module
let providers_to_register = provider_registration::determine_providers_to_register(&config, is_autonomous);
let providers = provider_registration::register_providers(&config, &providers_to_register).await?;
let providers_to_register =
provider_registration::determine_providers_to_register(&config, is_autonomous);
let providers =
provider_registration::register_providers(&config, &providers_to_register).await?;
// Determine context window size based on active provider
let mut context_warnings = Vec::new();
@@ -273,8 +285,9 @@ impl<W: UiWriter> Agent<W> {
working_dir: None,
background_process_manager: std::sync::Arc::new(
background_process::BackgroundProcessManager::new(
paths::get_background_processes_dir()
)),
paths::get_background_processes_dir(),
),
),
pending_images: Vec::new(),
is_agent_mode: false,
agent_name: None,
@@ -309,7 +322,9 @@ impl<W: UiWriter> Agent<W> {
// Check for system prompt markers that are present in both standard and agent mode
// Agent mode replaces the identity line but keeps all other instructions
let has_tool_instructions = first_message.content.contains("IMPORTANT: You must call tools to achieve goals");
let has_tool_instructions = first_message
.content
.contains("IMPORTANT: You must call tools to achieve goals");
if !has_tool_instructions {
panic!("FATAL: First system message does not contain the system prompt. This likely means the README was added before the system prompt.");
}
@@ -347,7 +362,10 @@ impl<W: UiWriter> Agent<W> {
let (provider_type, config_name) = provider_config::parse_provider_ref(provider_name);
match provider_type {
"anthropic" => self.config.providers.anthropic
"anthropic" => self
.config
.providers
.anthropic
.get(config_name)
.and_then(|c| c.cache_config.as_ref())
.and_then(|config| Self::parse_cache_control(config)),
@@ -362,8 +380,16 @@ impl<W: UiWriter> Agent<W> {
/// Get the thinking budget tokens for Anthropic provider, if configured.
/// Pre-flight check to validate max_tokens for thinking.budget_tokens constraint.
fn preflight_validate_max_tokens(&self, provider_name: &str, proposed_max_tokens: u32) -> (u32, bool) {
provider_config::preflight_validate_max_tokens(&self.config, provider_name, proposed_max_tokens)
fn preflight_validate_max_tokens(
&self,
provider_name: &str,
proposed_max_tokens: u32,
) -> (u32, bool) {
provider_config::preflight_validate_max_tokens(
&self.config,
provider_name,
proposed_max_tokens,
)
}
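// Grounding from the provider tests in this commit: thinking requires
// max_tokens > budget_tokens + 1024 (budget 10000 => max_tokens > 11024);
// this preflight validates that constraint before the request is sent.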
/// Calculate max_tokens for a summary request.
@@ -377,8 +403,17 @@ impl<W: UiWriter> Agent<W> {
}
/// Apply the fallback sequence to free up context space for thinking budget.
fn apply_max_tokens_fallback_sequence(&mut self, provider_name: &str, initial_max_tokens: u32, hard_coded_minimum: u32) -> u32 {
self.apply_fallback_sequence_impl(provider_name, Some(initial_max_tokens), hard_coded_minimum)
fn apply_max_tokens_fallback_sequence(
&mut self,
provider_name: &str,
initial_max_tokens: u32,
hard_coded_minimum: u32,
) -> u32 {
self.apply_fallback_sequence_impl(
provider_name,
Some(initial_max_tokens),
hard_coded_minimum,
)
}
/// Unified implementation of the fallback sequence for freeing context space.
@@ -405,27 +440,33 @@ impl<W: UiWriter> Agent<W> {
);
// Step 1: Try thinnify (first third of context)
self.ui_writer.print_context_status("🥒 Step 1: Trying thinnify...\n");
self.ui_writer
.print_context_status("🥒 Step 1: Trying thinnify...\n");
let thin_msg = self.do_thin_context();
self.ui_writer.print_context_thinning(&thin_msg);
// Recalculate after thinnify
let (new_max, still_needs_reduction) = self.recalculate_max_tokens(provider_name, initial_max_tokens.is_some());
let (new_max, still_needs_reduction) =
self.recalculate_max_tokens(provider_name, initial_max_tokens.is_some());
max_tokens = new_max;
if !still_needs_reduction {
self.ui_writer.print_context_status("✅ Thinnify resolved capacity issue. Continuing...\n");
self.ui_writer
.print_context_status("✅ Thinnify resolved capacity issue. Continuing...\n");
return max_tokens;
}
// Step 2: Try skinnify (entire context)
self.ui_writer.print_context_status("🦴 Step 2: Trying skinnify...\n");
self.ui_writer
.print_context_status("🦴 Step 2: Trying skinnify...\n");
let skinny_msg = self.do_thin_context_all();
self.ui_writer.print_context_thinning(&skinny_msg);
// Recalculate after skinnify
let (final_max, final_needs_reduction) = self.recalculate_max_tokens(provider_name, initial_max_tokens.is_some());
let (final_max, final_needs_reduction) =
self.recalculate_max_tokens(provider_name, initial_max_tokens.is_some());
if !final_needs_reduction {
self.ui_writer.print_context_status("✅ Skinnify resolved capacity issue. Continuing...\n");
self.ui_writer
.print_context_status("✅ Skinnify resolved capacity issue. Continuing...\n");
return final_max;
}
@@ -937,7 +978,7 @@ impl<W: UiWriter> Agent<W> {
/// Manually trigger context compaction regardless of context window size
/// Returns Ok(true) if compaction was successful, Ok(false) if it failed
pub async fn force_compact(&mut self) -> Result<bool> {
use crate::compaction::{CompactionConfig, perform_compaction};
use crate::compaction::{perform_compaction, CompactionConfig};
debug!("Manual compaction triggered");
@@ -971,10 +1012,12 @@ impl<W: UiWriter> Agent<W> {
compaction_config,
&self.ui_writer,
&mut self.thinning_events,
).await?;
)
.await?;
if result.success {
self.ui_writer.print_context_status("✅ Context compacted successfully.\n");
self.ui_writer
.print_context_status("✅ Context compacted successfully.\n");
self.compaction_events.push(result.chars_saved);
Ok(true)
} else {
@@ -984,7 +1027,7 @@ impl<W: UiWriter> Agent<W> {
Ok(false)
}
}
/// Manually trigger context thinning regardless of thresholds
pub fn force_thin(&mut self) -> String {
debug!("Manual context thinning triggered");
self.do_thin_context()
@@ -1006,7 +1049,9 @@ impl<W: UiWriter> Agent<W> {
/// Internal helper: thin all context and track the event
fn do_thin_context_all(&mut self) -> String {
let (message, chars_saved) = self.context_window.thin_context_all(self.session_id.as_deref());
let (message, chars_saved) = self
.context_window
.thin_context_all(self.session_id.as_deref());
self.thinning_events.push(chars_saved);
message
}
@@ -1030,10 +1075,12 @@ impl<W: UiWriter> Agent<W> {
self.ui_writer.print_context_thinning(&thin_summary);
if !self.context_window.should_compact() {
self.ui_writer.print_context_status("✅ Thinning resolved capacity issue. Continuing...\n");
self.ui_writer
.print_context_status("✅ Thinning resolved capacity issue. Continuing...\n");
return Ok(false);
}
self.ui_writer.print_context_status("⚠️ Thinning insufficient. Proceeding with compaction...\n");
self.ui_writer
.print_context_status("⚠️ Thinning insufficient. Proceeding with compaction...\n");
}
// Compaction still needed
@@ -1041,7 +1088,7 @@ impl<W: UiWriter> Agent<W> {
return Ok(false);
}
use crate::compaction::{CompactionConfig, perform_compaction};
use crate::compaction::{perform_compaction, CompactionConfig};
self.ui_writer.print_context_status(&format!(
"\n🗜️ Context window reaching capacity ({}%). Compacting...",
@@ -1049,7 +1096,10 @@ impl<W: UiWriter> Agent<W> {
));
let provider_name = self.providers.get(None)?.name().to_string();
let latest_user_msg = request.messages.iter().rev()
let latest_user_msg = request
.messages
.iter()
.rev()
.find(|m| matches!(m.role, MessageRole::User))
.map(|m| m.content.clone());
@@ -1065,17 +1115,21 @@ impl<W: UiWriter> Agent<W> {
compaction_config,
&self.ui_writer,
&mut self.thinning_events,
).await?;
)
.await?;
if result.success {
self.ui_writer.print_context_status("✅ Context compacted successfully. Continuing...\n");
self.ui_writer
.print_context_status("✅ Context compacted successfully. Continuing...\n");
self.compaction_events.push(result.chars_saved);
request.messages = self.context_window.conversation_history.clone();
return Ok(true);
}
self.ui_writer.print_context_status("⚠️ Unable to compact context. Consider starting a new session if you continue to see errors.\n");
Err(anyhow::anyhow!("Context window at capacity and compaction failed. Please start a new session."))
Err(anyhow::anyhow!(
"Context window at capacity and compaction failed. Please start a new session."
))
}
/// Check if a tool call is a duplicate of the last tool call in the previous assistant message.
@@ -1090,11 +1144,13 @@ impl<W: UiWriter> Agent<W> {
let content = &msg.content;
// Look for the last occurrence of a tool call pattern
let last_tool_start = content.rfind(r#"{"tool""#)
let last_tool_start = content
.rfind(r#"{"tool""#)
.or_else(|| content.rfind(r#"{ "tool""#))?;
// Find the end of this JSON object
let end_offset = StreamingToolParser::find_complete_json_object_end(&content[last_tool_start..])?;
let end_offset =
StreamingToolParser::find_complete_json_object_end(&content[last_tool_start..])?;
let end_idx = last_tool_start + end_offset + 1;
let tool_json = &content[last_tool_start..end_idx];
@@ -1234,7 +1290,10 @@ impl<W: UiWriter> Agent<W> {
.unwrap_or_else(|_| ".".to_string());
// Get description from first user message (strip "Task: " prefix if present)
let description = self.context_window.conversation_history.iter()
let description = self
.context_window
.conversation_history
.iter()
.find(|m| matches!(m.role, g3_providers::MessageRole::User))
.map(|m| {
let content = m.content.strip_prefix("Task: ").unwrap_or(&m.content);
@@ -1272,13 +1331,19 @@ impl<W: UiWriter> Agent<W> {
/// Enable auto-memory reminders after turns with tool calls
pub fn set_auto_memory(&mut self, enabled: bool) {
self.auto_memory = enabled;
debug!("Auto-memory reminders: {}", if enabled { "enabled" } else { "disabled" });
debug!(
"Auto-memory reminders: {}",
if enabled { "enabled" } else { "disabled" }
);
}
/// Enable or disable aggressive context dehydration (ACD)
pub fn set_acd_enabled(&mut self, enabled: bool) {
self.acd_enabled = enabled;
debug!("ACD (aggressive context dehydration): {}", if enabled { "enabled" } else { "disabled" });
debug!(
"ACD (aggressive context dehydration): {}",
if enabled { "enabled" } else { "disabled" }
);
}
/// Build the final response and prepare for return.
@@ -1342,7 +1407,8 @@ impl<W: UiWriter> Agent<W> {
// Find the index of the last dehydration stub (marks the end of previously dehydrated content)
// We only want to dehydrate messages AFTER the last stub+summary pair
let last_stub_index = self.context_window
let last_stub_index = self
.context_window
.conversation_history
.iter()
.rposition(|m| m.is_dehydrated_stub());
@@ -1356,14 +1422,19 @@ impl<W: UiWriter> Agent<W> {
};
// Get the preceding fragment ID (if any)
let preceding_id = crate::acd::get_latest_fragment_id(&session_id).ok().flatten();
let preceding_id = crate::acd::get_latest_fragment_id(&session_id)
.ok()
.flatten();
// Extract only NEW non-system messages to dehydrate (after the last stub+summary)
let messages_to_dehydrate: Vec<_> = self.context_window
let messages_to_dehydrate: Vec<_> = self
.context_window
.conversation_history
.iter()
.enumerate()
.filter(|(idx, m)| *idx >= dehydrate_start && !matches!(m.role, g3_providers::MessageRole::System))
.filter(|(idx, m)| {
*idx >= dehydrate_start && !matches!(m.role, g3_providers::MessageRole::System)
})
.map(|(_, m)| m.clone())
.collect();
@@ -1393,7 +1464,8 @@ impl<W: UiWriter> Agent<W> {
// Now replace the context: keep system messages + previous stubs/summaries, add new stub, add new summary
// Extract messages to keep: system messages + everything up to (but not including) dehydrate_start
let messages_to_keep: Vec<_> = self.context_window
let messages_to_keep: Vec<_> = self
.context_window
.conversation_history
.iter()
.enumerate()
@@ -1445,14 +1517,18 @@ impl<W: UiWriter> Agent<W> {
// Check if any tools were called this turn
if self.tool_calls_this_turn.is_empty() {
debug!("Auto-memory: No tools called, skipping reminder");
self.ui_writer.print_context_status("📝 Auto-memory: No tools called this turn, skipping reminder.\n");
self.ui_writer.print_context_status(
"📝 Auto-memory: No tools called this turn, skipping reminder.\n",
);
return Ok(false);
}
// Check if remember was already called this turn - no need to remind
if self.tool_calls_this_turn.iter().any(|t| t == "remember") {
debug!("Auto-memory: 'remember' was already called this turn, skipping reminder");
self.ui_writer.print_context_status("\n📝 Auto-memory: 'remember' already called, skipping reminder.\n");
self.ui_writer.print_context_status(
"\n📝 Auto-memory: 'remember' already called, skipping reminder.\n",
);
self.tool_calls_this_turn.clear();
return Ok(false);
}
@@ -1460,7 +1536,11 @@ impl<W: UiWriter> Agent<W> {
// Take the tools list and reset for next turn
let tools_called = std::mem::take(&mut self.tool_calls_this_turn);
debug!("Auto-memory: Sending reminder to LLM ({} tools called this turn: {:?})", tools_called.len(), tools_called);
debug!(
"Auto-memory: Sending reminder to LLM ({} tools called this turn: {:?})",
tools_called.len(),
tools_called
);
self.ui_writer.print_context_status("\nMemory checkpoint: ");
let reminder = r#"MEMORY CHECKPOINT: If you discovered code locations worth remembering, call `remember` now.
@@ -1496,10 +1576,8 @@ Save/restore session state across g3 invocations using symlink-based approach.
Skip if nothing new. Be brief."#;
// Add the reminder as a user message and get a response
self.context_window.add_message(Message::new(
MessageRole::User,
reminder.to_string(),
));
self.context_window
.add_message(Message::new(MessageRole::User, reminder.to_string()));
// Build the completion request
let messages = self.context_window.conversation_history.clone();
@@ -1584,7 +1662,8 @@ Skip if nothing new. Be brief."#;
// Restore messages from session log (skip system messages as they're preserved)
for msg in messages {
let role_str = msg.get("role").and_then(|r| r.as_str()).unwrap_or("user");
let role_str =
msg.get("role").and_then(|r| r.as_str()).unwrap_or("user");
let content = msg.get("content").and_then(|c| c.as_str()).unwrap_or("");
let role = match role_str {
@@ -1918,10 +1997,10 @@ Skip if nothing new. Be brief."#;
// Always process all tool calls - they will be executed after stream ends
// De-duplicate tool calls (sequential duplicates in chunk + duplicates from previous message)
let deduplicated_tools = streaming::deduplicate_tool_calls(
completed_tools,
|tc| self.check_duplicate_in_previous_message(tc),
);
let deduplicated_tools =
streaming::deduplicate_tool_calls(completed_tools, |tc| {
self.check_duplicate_in_previous_message(tc)
});
// Process each tool call
for (tool_call, duplicate_type) in deduplicated_tools {
@@ -1933,7 +2012,8 @@ Skip if nothing new. Be brief."#;
"Skipping duplicate tool call ({}): {} with args {}",
dup_type,
tool_call.tool,
serde_json::to_string(&tool_call.args).unwrap_or_else(|_| "<unserializable>".to_string())
serde_json::to_string(&tool_call.args)
.unwrap_or_else(|_| "<unserializable>".to_string())
);
continue;
}
@@ -1954,11 +2034,13 @@ Skip if nothing new. Be brief."#;
let text_content = parser.get_text_content();
let clean_content = streaming::clean_llm_tokens(&text_content);
let raw_content_for_log = clean_content.clone();
let filtered_content = self.ui_writer.filter_json_tool_calls(&clean_content);
let filtered_content =
self.ui_writer.filter_json_tool_calls(&clean_content);
let final_display_content = filtered_content.trim();
// Extract only the new (undisplayed) portion
let new_content = if current_response.len() <= final_display_content.len() {
let new_content =
if current_response.len() <= final_display_content.len() {
final_display_content
.chars()
.skip(already_displayed_chars)
@@ -1981,11 +2063,13 @@ Skip if nothing new. Be brief."#;
self.ui_writer.finish_streaming_markdown();
let is_todo_tool = tool_call.tool == "todo_read" || tool_call.tool == "todo_write";
let is_todo_tool =
tool_call.tool == "todo_read" || tool_call.tool == "todo_write";
// Tool call header (TODO tools print their own)
if !is_todo_tool {
self.ui_writer.print_tool_header(&tool_call.tool, Some(&tool_call.args));
self.ui_writer
.print_tool_header(&tool_call.tool, Some(&tool_call.args));
if let Some(args_obj) = tool_call.args.as_object() {
for (key, value) in args_obj {
let value_str = streaming::format_tool_arg_value(
@@ -1999,7 +2083,17 @@ Skip if nothing new. Be brief."#;
}
// Check if this is a compact tool (file operations)
let is_compact_tool = matches!(tool_call.tool.as_str(), "read_file" | "write_file" | "str_replace" | "remember" | "screenshot" | "coverage" | "rehydrate" | "code_search");
let is_compact_tool = matches!(
tool_call.tool.as_str(),
"read_file"
| "write_file"
| "str_replace"
| "remember"
| "screenshot"
| "coverage"
| "rehydrate"
| "code_search"
);
// Only print output header for non-compact tools
if !is_compact_tool && !is_todo_tool {
@@ -2051,29 +2145,53 @@ Skip if nothing new. Be brief."#;
Some(streaming::truncate_for_display(&tool_result, 60))
} else {
match tool_call.tool.as_str() {
"read_file" => Some(streaming::format_read_file_summary(output_len, tool_result.len())),
"write_file" => Some(streaming::format_write_file_result(&tool_result)),
"read_file" => {
Some(streaming::format_read_file_summary(
output_len,
tool_result.len(),
))
}
"write_file" => Some(
streaming::format_write_file_result(&tool_result),
),
"str_replace" => {
let (ins, del) = parse_diff_stats(&tool_result);
Some(streaming::format_str_replace_summary(ins, del))
Some(streaming::format_str_replace_summary(
ins, del,
))
}
"remember" => Some(streaming::format_remember_summary(&tool_result)),
"screenshot" => Some(streaming::format_screenshot_summary(&tool_result)),
"coverage" => Some(streaming::format_coverage_summary(&tool_result)),
"rehydrate" => Some(streaming::format_rehydrate_summary(&tool_result)),
"code_search" => Some(streaming::format_code_search_summary(&tool_result)),
"remember" => Some(streaming::format_remember_summary(
&tool_result,
)),
"screenshot" => Some(
streaming::format_screenshot_summary(&tool_result),
),
"coverage" => Some(streaming::format_coverage_summary(
&tool_result,
)),
"rehydrate" => Some(
streaming::format_rehydrate_summary(&tool_result),
),
"code_search" => Some(
streaming::format_code_search_summary(&tool_result),
),
_ => Some("✅ completed".to_string()),
}
}
} else {
// Regular tools: show truncated output lines
let max_lines_to_show = if wants_full { output_len } else { MAX_LINES };
let max_lines_to_show =
if wants_full { output_len } else { MAX_LINES };
for (idx, line) in output_lines.iter().enumerate() {
if !wants_full && idx >= max_lines_to_show {
break;
}
self.ui_writer.update_tool_output_line(
&streaming::truncate_line(line, MAX_LINE_WIDTH, !wants_full)
&streaming::truncate_line(
line,
MAX_LINE_WIDTH,
!wants_full,
),
);
}
if !wants_full && output_len > MAX_LINES {
@@ -2143,7 +2261,10 @@ Skip if nothing new. Be brief."#;
self.context_window.add_message(result_message);
// Closure marker with timing
let tokens_delta = self.context_window.used_tokens.saturating_sub(tokens_before);
let tokens_delta = self
.context_window
.used_tokens
.saturating_sub(tokens_before);
// TODO tools handle their own output via print_todo_compact, skip timing
if !is_todo_tool {
@@ -2157,10 +2278,11 @@ Skip if nothing new. Be brief."#;
self.context_window.percentage_used(),
);
} else {
self.ui_writer
.print_tool_timing(&streaming::format_duration(exec_duration),
self.ui_writer.print_tool_timing(
&streaming::format_duration(exec_duration),
tokens_delta,
self.context_window.percentage_used());
self.context_window.percentage_used(),
);
}
}
self.ui_writer.print_agent_prompt();
@@ -2179,7 +2301,8 @@ Skip if nothing new. Be brief."#;
if self.agent_name.as_deref() == Some("scout") {
tool_config = tool_config.with_research_excluded();
}
request.tools = Some(tool_definitions::create_tool_definitions(tool_config));
request.tools =
Some(tool_definitions::create_tool_definitions(tool_config));
}
// DO NOT add final_display_content to full_response here!
@@ -2196,7 +2319,6 @@ Skip if nothing new. Be brief."#;
// This gives the LLM fresh attempts since it's making progress
auto_summary_attempts = 0;
// Reset the JSON tool call filter state after each tool execution
// This ensures the filter doesn't stay in suppression mode for subsequent streaming content
self.ui_writer.reset_json_filter();
@@ -2204,7 +2326,9 @@ Skip if nothing new. Be brief."#;
// Only reset parser if there are no more unexecuted tool calls in the buffer
// This handles the case where the LLM emits multiple tool calls in one response
if parser.has_unexecuted_tool_call() {
debug!("Parser still has unexecuted tool calls, not resetting buffer");
debug!(
"Parser still has unexecuted tool calls, not resetting buffer"
);
// Mark current tool as consumed so we don't re-detect it
parser.mark_tool_calls_consumed();
} else {
@@ -2366,7 +2490,6 @@ Skip if nothing new. Be brief."#;
error!("Parser state at error: text_buffer_len={}, has_incomplete={}, message_stopped={}",
parser.text_buffer_len(), parser.has_incomplete_tool_call(), parser.is_message_stopped());
// Check if this is a recoverable connection error
let is_connection_error = streaming::is_connection_error(&error_msg);
@@ -2452,7 +2575,8 @@ Skip if nothing new. Be brief."#;
}
// Check if the response was truncated due to max_tokens
let was_truncated_by_max_tokens = stream_stop_reason.as_deref() == Some("max_tokens");
let was_truncated_by_max_tokens =
stream_stop_reason.as_deref() == Some("max_tokens");
if was_truncated_by_max_tokens {
debug!("Response was truncated due to max_tokens limit");
warn!("LLM response was cut off due to max_tokens limit - will auto-continue");
@@ -2491,8 +2615,13 @@ Skip if nothing new. Be brief."#;
"\n🔄 Model stopped without providing summary. Auto-continuing...\n",
),
};
warn!("{} ({} iterations, auto-continue attempt {}/{})",
log_msg, iteration_count, auto_summary_attempts, MAX_AUTO_SUMMARY_ATTEMPTS);
warn!(
"{} ({} iterations, auto-continue attempt {}/{})",
log_msg,
iteration_count,
auto_summary_attempts,
MAX_AUTO_SUMMARY_ATTEMPTS
);
self.ui_writer.print_context_status(ui_msg);
// Add any text response to context before prompting for continuation
@@ -2613,11 +2742,14 @@ Skip if nothing new. Be brief."#;
Ok(s) => s.clone(),
Err(e) => format!("ERROR: {}", e),
};
debug!("Tool {} completed: {}", tool_call.tool, &log_str.chars().take(100).collect::<String>());
debug!(
"Tool {} completed: {}",
tool_call.tool,
&log_str.chars().take(100).collect::<String>()
);
result
}
async fn execute_tool_inner_in_dir(
&mut self,
tool_call: &ToolCall,
@@ -2660,11 +2792,8 @@ Skip if nothing new. Be brief."#;
Ok(result)
}
}
// Re-export utility functions
pub use utils::apply_unified_diff_to_string;
use utils::truncate_to_word_boundary;
@@ -2678,12 +2807,20 @@ fn parse_diff_stats(result: &str) -> (i32, i32) {
// Look for "+N insertions" pattern
if let Some(pos) = result.find("+") {
let after_plus = &result[pos + 1..];
insertions = after_plus.split_whitespace().next().and_then(|s| s.parse().ok()).unwrap_or(0);
insertions = after_plus
.split_whitespace()
.next()
.and_then(|s| s.parse().ok())
.unwrap_or(0);
}
// Look for "-M deletions" pattern
if let Some(pos) = result.find("-") {
let after_minus = &result[pos + 1..];
deletions = after_minus.split_whitespace().next().and_then(|s| s.parse().ok()).unwrap_or(0);
deletions = after_minus
.split_whitespace()
.next()
.and_then(|s| s.parse().ok())
.unwrap_or(0);
}
(insertions, deletions)
}
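A hedged example of the expected parse (the input string is illustrative; the helper scans for the first '+' and '-' markers):

// e.g. parse_diff_stats("1 file changed, +12 insertions, -3 deletions")
//      yields (12, 3); a missing marker leaves that count at 0.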

View File

@@ -110,9 +110,12 @@ use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error};
use crate::{
streaming::{
decode_utf8_streaming, make_final_chunk, make_final_chunk_with_reason, make_text_chunk,
make_tool_chunk,
},
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
MessageRole, Tool, ToolCall, Usage,
streaming::{decode_utf8_streaming, make_final_chunk, make_final_chunk_with_reason, make_text_chunk, make_tool_chunk},
};
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
@@ -156,7 +159,7 @@ impl AnthropicProvider {
name: "anthropic".to_string(),
api_key,
model,
max_tokens: max_tokens.unwrap_or(4096),
max_tokens: max_tokens.unwrap_or(32768),
temperature: temperature.unwrap_or(0.1),
cache_config,
enable_1m_context: enable_1m_context.unwrap_or(false),
@@ -182,14 +185,17 @@ impl AnthropicProvider {
let model = model.unwrap_or_else(|| "claude-3-5-sonnet-20241022".to_string());
debug!("Initialized Anthropic provider '{}' with model: {}", name, model);
debug!(
"Initialized Anthropic provider '{}' with model: {}",
name, model
);
Ok(Self {
client,
name,
api_key,
model,
max_tokens: max_tokens.unwrap_or(4096),
max_tokens: max_tokens.unwrap_or(32768),
temperature: temperature.unwrap_or(0.1),
cache_config,
enable_1m_context: enable_1m_context.unwrap_or(false),
@@ -292,7 +298,10 @@ impl AnthropicProvider {
// Add text content
content_blocks.push(AnthropicContent::Text {
text: message.content.clone(),
cache_control: message.cache_control.as_ref().map(Self::convert_cache_control),
cache_control: message
.cache_control
.as_ref()
.map(Self::convert_cache_control),
});
anthropic_messages.push(AnthropicMessage {
@@ -427,7 +436,10 @@ impl AnthropicProvider {
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
debug!("Received stream completion marker");
let final_chunk = make_final_chunk(current_tool_calls.clone(), accumulated_usage.clone());
let final_chunk = make_final_chunk(
current_tool_calls.clone(),
accumulated_usage.clone(),
);
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
}
@@ -491,7 +503,8 @@ impl AnthropicProvider {
{
// We have complete arguments, send the tool call immediately
debug!("Tool call has complete args, sending immediately: {:?}", tool_call);
let chunk = make_tool_chunk(vec![tool_call]);
let chunk =
make_tool_chunk(vec![tool_call]);
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return accumulated_usage;
@@ -575,7 +588,8 @@ impl AnthropicProvider {
// Send the complete tool call
if !current_tool_calls.is_empty() {
let chunk = make_tool_chunk(current_tool_calls.clone());
let chunk =
make_tool_chunk(current_tool_calls.clone());
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return accumulated_usage;
@@ -597,7 +611,11 @@ impl AnthropicProvider {
"message_stop" => {
debug!("Received message stop event");
message_stopped = true;
let final_chunk = make_final_chunk_with_reason(current_tool_calls.clone(), accumulated_usage.clone(), stop_reason.clone());
let final_chunk = make_final_chunk_with_reason(
current_tool_calls.clone(),
accumulated_usage.clone(),
stop_reason.clone(),
);
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
}
@@ -826,7 +844,10 @@ struct ThinkingConfig {
impl ThinkingConfig {
fn enabled(budget_tokens: u32) -> Self {
Self { thinking_type: "enabled".to_string(), budget_tokens }
Self {
thinking_type: "enabled".to_string(),
budget_tokens,
}
}
}
@@ -889,9 +910,7 @@ enum AnthropicContent {
input: serde_json::Value,
},
#[serde(rename = "image")]
Image {
source: AnthropicImageSource,
},
Image { source: AnthropicImageSource },
}
/// Image source for Anthropic API
@@ -962,7 +981,8 @@ mod tests {
#[test]
fn test_message_conversion() {
let provider =
AnthropicProvider::new("test-key".to_string(), None, None, None, None, None, None).unwrap();
AnthropicProvider::new("test-key".to_string(), None, None, None, None, None, None)
.unwrap();
let messages = vec![
Message::new(
@@ -1011,7 +1031,8 @@ mod tests {
#[test]
fn test_tool_conversion() {
let provider =
AnthropicProvider::new("test-key".to_string(), None, None, None, None, None, None).unwrap();
AnthropicProvider::new("test-key".to_string(), None, None, None, None, None, None)
.unwrap();
let tools = vec![Tool {
name: "get_weather".to_string(),
@@ -1044,7 +1065,8 @@ mod tests {
#[test]
fn test_cache_control_serialization() {
let provider =
AnthropicProvider::new("test-key".to_string(), None, None, None, None, None, None).unwrap();
AnthropicProvider::new("test-key".to_string(), None, None, None, None, None, None)
.unwrap();
// Test message WITHOUT cache_control
let messages_without = vec![Message::new(MessageRole::User, "Hello".to_string())];
@@ -1106,7 +1128,10 @@ mod tests {
.create_request_body(&messages, None, false, 1000, 0.5, false)
.unwrap();
let json_without = serde_json::to_string(&request_without).unwrap();
assert!(!json_without.contains("thinking"), "JSON should not contain 'thinking' field when not configured");
assert!(
!json_without.contains("thinking"),
"JSON should not contain 'thinking' field when not configured"
);
// Test WITH thinking parameter - max_tokens must be > budget_tokens + 1024
// Using budget=10000 requires max_tokens > 11024
@@ -1125,16 +1150,28 @@ mod tests {
.create_request_body(&messages, None, false, 20000, 0.5, false)
.unwrap();
let json_with = serde_json::to_string(&request_with).unwrap();
assert!(json_with.contains("thinking"), "JSON should contain 'thinking' field when configured");
assert!(json_with.contains("\"type\":\"enabled\""), "JSON should contain type: enabled");
assert!(json_with.contains("\"budget_tokens\":10000"), "JSON should contain budget_tokens: 10000");
assert!(
json_with.contains("thinking"),
"JSON should contain 'thinking' field when configured"
);
assert!(
json_with.contains("\"type\":\"enabled\""),
"JSON should contain type: enabled"
);
assert!(
json_with.contains("\"budget_tokens\":10000"),
"JSON should contain budget_tokens: 10000"
);
// Test WITH thinking parameter but INSUFFICIENT max_tokens - thinking should be disabled
let request_insufficient = provider_with
.create_request_body(&messages, None, false, 5000, 0.5, false) // Less than budget + 1024
.unwrap();
let json_insufficient = serde_json::to_string(&request_insufficient).unwrap();
assert!(!json_insufficient.contains("thinking"), "JSON should NOT contain 'thinking' field when max_tokens is insufficient");
assert!(
!json_insufficient.contains("thinking"),
"JSON should NOT contain 'thinking' field when max_tokens is insufficient"
);
}
#[test]
@@ -1158,14 +1195,20 @@ mod tests {
.create_request_body(&messages, None, false, 20000, 0.5, false)
.unwrap();
let json_with = serde_json::to_string(&request_with_thinking).unwrap();
assert!(json_with.contains("thinking"), "JSON should contain 'thinking' field when not disabled");
assert!(
json_with.contains("thinking"),
"JSON should contain 'thinking' field when not disabled"
);
// With disable_thinking=true, thinking should be disabled even with sufficient max_tokens
let request_without_thinking = provider
.create_request_body(&messages, None, false, 20000, 0.5, true)
.unwrap();
let json_without = serde_json::to_string(&request_without_thinking).unwrap();
assert!(!json_without.contains("thinking"), "JSON should NOT contain 'thinking' field when explicitly disabled");
assert!(
!json_without.contains("thinking"),
"JSON should NOT contain 'thinking' field when explicitly disabled"
);
}
#[test]
@@ -1188,10 +1231,14 @@ mod tests {
assert_eq!(response.model, "claude-sonnet-4-5");
// Extract only text content (thinking should be filtered out)
let text_content: Vec<_> = response.content.iter().filter_map(|c| match c {
let text_content: Vec<_> = response
.content
.iter()
.filter_map(|c| match c {
AnthropicContent::Text { text, .. } => Some(text.as_str()),
_ => None,
}).collect();
})
.collect();
assert_eq!(text_content.len(), 1);
assert_eq!(text_content[0], "Here is my response.");