config: default agent settings and provider override

This commit is contained in:
Dhanji R. Prasanna
2026-01-14 20:14:33 +05:30
parent 38828c7757
commit f4562cd4c9
5 changed files with 595 additions and 396 deletions

View File

@@ -7,8 +7,11 @@ use std::path::Path;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
pub providers: ProvidersConfig,
#[serde(default)]
pub agent: AgentConfig,
#[serde(default)]
pub computer_control: ComputerControlConfig,
#[serde(default)]
pub webdriver: WebDriverConfig,
}
@@ -17,32 +20,32 @@ pub struct Config {
pub struct ProvidersConfig {
/// Default provider in format "<provider_type>.<config_name>"
pub default_provider: String,
/// Provider for planner mode (optional, falls back to default_provider)
pub planner: Option<String>,
/// Provider for coach in autonomous mode (optional, falls back to default_provider)
pub coach: Option<String>,
/// Provider for player in autonomous mode (optional, falls back to default_provider)
pub player: Option<String>,
/// Named Anthropic provider configs
#[serde(default)]
pub anthropic: HashMap<String, AnthropicConfig>,
/// Named OpenAI provider configs
#[serde(default)]
pub openai: HashMap<String, OpenAIConfig>,
/// Named Databricks provider configs
#[serde(default)]
pub databricks: HashMap<String, DatabricksConfig>,
/// Named embedded provider configs
#[serde(default)]
pub embedded: HashMap<String, EmbeddedConfig>,
/// Multiple named OpenAI-compatible providers (e.g., openrouter, groq, etc.)
#[serde(default)]
pub openai_compatible: HashMap<String, OpenAIConfig>,
@@ -92,24 +95,59 @@ pub struct EmbeddedConfig {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
pub max_context_length: Option<u32>,
#[serde(default = "default_fallback_max_tokens")]
pub fallback_default_max_tokens: usize,
#[serde(default = "default_true")]
pub enable_streaming: bool,
#[serde(default = "default_timeout_seconds")]
pub timeout_seconds: u64,
#[serde(default = "default_true")]
pub auto_compact: bool,
#[serde(default = "default_max_retry_attempts")]
pub max_retry_attempts: u32,
#[serde(default = "default_autonomous_max_retry_attempts")]
pub autonomous_max_retry_attempts: u32,
#[serde(default = "default_check_todo_staleness")]
pub check_todo_staleness: bool,
}
fn default_fallback_max_tokens() -> usize {
8192
}
fn default_true() -> bool {
true
}
fn default_false() -> bool {
false
}
fn default_timeout_seconds() -> u64 {
120
}
fn default_max_retry_attempts() -> u32 {
3
}
fn default_autonomous_max_retry_attempts() -> u32 {
6
}
fn default_max_actions_per_second() -> u32 {
5
}
fn default_check_todo_staleness() -> bool {
true
}
fn default_safari_port() -> u16 {
4444
}
fn default_chrome_port() -> u16 {
9515
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComputerControlConfig {
#[serde(default = "default_true")]
pub enabled: bool,
#[serde(default = "default_false")]
pub require_confirmation: bool,
#[serde(default = "default_max_actions_per_second")]
pub max_actions_per_second: u32,
}
@@ -117,17 +155,19 @@ pub struct ComputerControlConfig {
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum WebDriverBrowser {
#[default]
Safari,
#[default]
#[serde(rename = "chrome-headless")]
ChromeHeadless,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct WebDriverConfig {
#[serde(default = "default_true")]
pub enabled: bool,
#[serde(default = "default_safari_port")]
pub safari_port: u16,
#[serde(default)]
#[serde(default = "default_chrome_port")]
pub chrome_port: u16,
#[serde(default)]
/// Optional path to Chrome binary (e.g., Chrome for Testing)
@@ -141,24 +181,25 @@ pub struct WebDriverConfig {
pub browser: WebDriverBrowser,
}
impl Default for WebDriverConfig {
impl Default for AgentConfig {
fn default() -> Self {
Self {
enabled: true,
safari_port: 4444,
chrome_port: 9515,
chrome_binary: None,
chromedriver_binary: None,
browser: WebDriverBrowser::Safari,
max_context_length: None,
fallback_default_max_tokens: 8192,
enable_streaming: true,
timeout_seconds: 120,
auto_compact: true,
max_retry_attempts: 3,
autonomous_max_retry_attempts: 6,
check_todo_staleness: true,
}
}
}
impl Default for ComputerControlConfig {
fn default() -> Self {
Self {
enabled: false,
require_confirmation: true,
enabled: true,
require_confirmation: false,
max_actions_per_second: 5,
}
}
@@ -296,17 +337,17 @@ impl Config {
if let Some(path) = config_path_to_load {
// Read and parse the config file
let config_content = std::fs::read_to_string(&path)?;
// Check for old format (direct provider config without named configs)
if Self::is_old_format(&config_content) {
anyhow::bail!("{}", OLD_CONFIG_FORMAT_ERROR);
}
let config: Config = toml::from_str(&config_content)?;
// Validate the default_provider format
config.validate_provider_reference(&config.providers.default_provider)?;
return Ok(config);
}
@@ -317,7 +358,7 @@ impl Config {
fn is_old_format(content: &str) -> bool {
// Old format has [providers.anthropic] with api_key directly
// New format has [providers.anthropic.<name>] with api_key
// Parse as TOML value to inspect structure
if let Ok(value) = content.parse::<toml::Value>() {
if let Some(providers) = value.get("providers") {
@@ -445,20 +486,26 @@ impl Config {
// Apply provider override
if let Some(provider) = provider_override {
// Validate the override
// If provider doesn't contain '.', assume '.default'
let provider = if provider.contains('.') {
provider
} else {
format!("{}.default", provider)
};
config.validate_provider_reference(&provider)?;
config.providers.default_provider = provider;
}
// Apply model override to the active provider
if let Some(model) = model_override {
let (provider_type, config_name) = Self::parse_provider_reference(
&config.providers.default_provider
)?;
let (provider_type, config_name) =
Self::parse_provider_reference(&config.providers.default_provider)?;
match provider_type.as_str() {
"anthropic" => {
if let Some(ref mut anthropic_config) = config.providers.anthropic.get_mut(&config_name) {
if let Some(ref mut anthropic_config) =
config.providers.anthropic.get_mut(&config_name)
{
anthropic_config.model = model;
} else {
return Err(anyhow::anyhow!(
@@ -468,7 +515,9 @@ impl Config {
}
}
"databricks" => {
if let Some(ref mut databricks_config) = config.providers.databricks.get_mut(&config_name) {
if let Some(ref mut databricks_config) =
config.providers.databricks.get_mut(&config_name)
{
databricks_config.model = model;
} else {
return Err(anyhow::anyhow!(
@@ -478,7 +527,9 @@ impl Config {
}
}
"embedded" => {
if let Some(ref mut embedded_config) = config.providers.embedded.get_mut(&config_name) {
if let Some(ref mut embedded_config) =
config.providers.embedded.get_mut(&config_name)
{
embedded_config.model_path = model;
} else {
return Err(anyhow::anyhow!(
@@ -488,7 +539,9 @@ impl Config {
}
}
"openai" => {
if let Some(ref mut openai_config) = config.providers.openai.get_mut(&config_name) {
if let Some(ref mut openai_config) =
config.providers.openai.get_mut(&config_name)
{
openai_config.model = model;
} else {
return Err(anyhow::anyhow!(
@@ -499,13 +552,12 @@ impl Config {
}
_ => {
// Check openai_compatible
if let Some(ref mut compat_config) = config.providers.openai_compatible.get_mut(&provider_type) {
if let Some(ref mut compat_config) =
config.providers.openai_compatible.get_mut(&provider_type)
{
compat_config.model = model;
} else {
return Err(anyhow::anyhow!(
"Unknown provider type: {}",
provider_type
));
return Err(anyhow::anyhow!("Unknown provider type: {}", provider_type));
}
}
}
@@ -585,36 +637,42 @@ impl Config {
/// Get the current default provider's config
pub fn get_default_provider_config(&self) -> Result<ProviderConfigRef<'_>> {
let (provider_type, config_name) = Self::parse_provider_reference(
&self.providers.default_provider
)?;
let (provider_type, config_name) =
Self::parse_provider_reference(&self.providers.default_provider)?;
match provider_type.as_str() {
"anthropic" => {
self.providers.anthropic.get(&config_name)
.map(ProviderConfigRef::Anthropic)
.ok_or_else(|| anyhow::anyhow!("Anthropic config '{}' not found", config_name))
}
"openai" => {
self.providers.openai.get(&config_name)
.map(ProviderConfigRef::OpenAI)
.ok_or_else(|| anyhow::anyhow!("OpenAI config '{}' not found", config_name))
}
"databricks" => {
self.providers.databricks.get(&config_name)
.map(ProviderConfigRef::Databricks)
.ok_or_else(|| anyhow::anyhow!("Databricks config '{}' not found", config_name))
}
"embedded" => {
self.providers.embedded.get(&config_name)
.map(ProviderConfigRef::Embedded)
.ok_or_else(|| anyhow::anyhow!("Embedded config '{}' not found", config_name))
}
_ => {
self.providers.openai_compatible.get(&provider_type)
.map(ProviderConfigRef::OpenAICompatible)
.ok_or_else(|| anyhow::anyhow!("OpenAI compatible config '{}' not found", provider_type))
}
"anthropic" => self
.providers
.anthropic
.get(&config_name)
.map(ProviderConfigRef::Anthropic)
.ok_or_else(|| anyhow::anyhow!("Anthropic config '{}' not found", config_name)),
"openai" => self
.providers
.openai
.get(&config_name)
.map(ProviderConfigRef::OpenAI)
.ok_or_else(|| anyhow::anyhow!("OpenAI config '{}' not found", config_name)),
"databricks" => self
.providers
.databricks
.get(&config_name)
.map(ProviderConfigRef::Databricks)
.ok_or_else(|| anyhow::anyhow!("Databricks config '{}' not found", config_name)),
"embedded" => self
.providers
.embedded
.get(&config_name)
.map(ProviderConfigRef::Embedded)
.ok_or_else(|| anyhow::anyhow!("Embedded config '{}' not found", config_name)),
_ => self
.providers
.openai_compatible
.get(&provider_type)
.map(ProviderConfigRef::OpenAICompatible)
.ok_or_else(|| {
anyhow::anyhow!("OpenAI compatible config '{}' not found", provider_type)
}),
}
}
}