config: default agent settings and provider override
This commit is contained in:
@@ -7,8 +7,11 @@ use std::path::Path;
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Config {
|
||||
pub providers: ProvidersConfig,
|
||||
#[serde(default)]
|
||||
pub agent: AgentConfig,
|
||||
#[serde(default)]
|
||||
pub computer_control: ComputerControlConfig,
|
||||
#[serde(default)]
|
||||
pub webdriver: WebDriverConfig,
|
||||
}
|
||||
|
||||
@@ -17,32 +20,32 @@ pub struct Config {
|
||||
pub struct ProvidersConfig {
|
||||
/// Default provider in format "<provider_type>.<config_name>"
|
||||
pub default_provider: String,
|
||||
|
||||
|
||||
/// Provider for planner mode (optional, falls back to default_provider)
|
||||
pub planner: Option<String>,
|
||||
|
||||
|
||||
/// Provider for coach in autonomous mode (optional, falls back to default_provider)
|
||||
pub coach: Option<String>,
|
||||
|
||||
|
||||
/// Provider for player in autonomous mode (optional, falls back to default_provider)
|
||||
pub player: Option<String>,
|
||||
|
||||
|
||||
/// Named Anthropic provider configs
|
||||
#[serde(default)]
|
||||
pub anthropic: HashMap<String, AnthropicConfig>,
|
||||
|
||||
|
||||
/// Named OpenAI provider configs
|
||||
#[serde(default)]
|
||||
pub openai: HashMap<String, OpenAIConfig>,
|
||||
|
||||
|
||||
/// Named Databricks provider configs
|
||||
#[serde(default)]
|
||||
pub databricks: HashMap<String, DatabricksConfig>,
|
||||
|
||||
|
||||
/// Named embedded provider configs
|
||||
#[serde(default)]
|
||||
pub embedded: HashMap<String, EmbeddedConfig>,
|
||||
|
||||
|
||||
/// Multiple named OpenAI-compatible providers (e.g., openrouter, groq, etc.)
|
||||
#[serde(default)]
|
||||
pub openai_compatible: HashMap<String, OpenAIConfig>,
|
||||
@@ -92,24 +95,59 @@ pub struct EmbeddedConfig {
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AgentConfig {
|
||||
pub max_context_length: Option<u32>,
|
||||
#[serde(default = "default_fallback_max_tokens")]
|
||||
pub fallback_default_max_tokens: usize,
|
||||
#[serde(default = "default_true")]
|
||||
pub enable_streaming: bool,
|
||||
#[serde(default = "default_timeout_seconds")]
|
||||
pub timeout_seconds: u64,
|
||||
#[serde(default = "default_true")]
|
||||
pub auto_compact: bool,
|
||||
#[serde(default = "default_max_retry_attempts")]
|
||||
pub max_retry_attempts: u32,
|
||||
#[serde(default = "default_autonomous_max_retry_attempts")]
|
||||
pub autonomous_max_retry_attempts: u32,
|
||||
#[serde(default = "default_check_todo_staleness")]
|
||||
pub check_todo_staleness: bool,
|
||||
}
|
||||
|
||||
fn default_fallback_max_tokens() -> usize {
|
||||
8192
|
||||
}
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
fn default_false() -> bool {
|
||||
false
|
||||
}
|
||||
fn default_timeout_seconds() -> u64 {
|
||||
120
|
||||
}
|
||||
fn default_max_retry_attempts() -> u32 {
|
||||
3
|
||||
}
|
||||
fn default_autonomous_max_retry_attempts() -> u32 {
|
||||
6
|
||||
}
|
||||
fn default_max_actions_per_second() -> u32 {
|
||||
5
|
||||
}
|
||||
fn default_check_todo_staleness() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_safari_port() -> u16 {
|
||||
4444
|
||||
}
|
||||
fn default_chrome_port() -> u16 {
|
||||
9515
|
||||
}
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ComputerControlConfig {
|
||||
#[serde(default = "default_true")]
|
||||
pub enabled: bool,
|
||||
#[serde(default = "default_false")]
|
||||
pub require_confirmation: bool,
|
||||
#[serde(default = "default_max_actions_per_second")]
|
||||
pub max_actions_per_second: u32,
|
||||
}
|
||||
|
||||
@@ -117,17 +155,19 @@ pub struct ComputerControlConfig {
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum WebDriverBrowser {
|
||||
#[default]
|
||||
Safari,
|
||||
#[default]
|
||||
#[serde(rename = "chrome-headless")]
|
||||
ChromeHeadless,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct WebDriverConfig {
|
||||
#[serde(default = "default_true")]
|
||||
pub enabled: bool,
|
||||
#[serde(default = "default_safari_port")]
|
||||
pub safari_port: u16,
|
||||
#[serde(default)]
|
||||
#[serde(default = "default_chrome_port")]
|
||||
pub chrome_port: u16,
|
||||
#[serde(default)]
|
||||
/// Optional path to Chrome binary (e.g., Chrome for Testing)
|
||||
@@ -141,24 +181,25 @@ pub struct WebDriverConfig {
|
||||
pub browser: WebDriverBrowser,
|
||||
}
|
||||
|
||||
impl Default for WebDriverConfig {
|
||||
impl Default for AgentConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: true,
|
||||
safari_port: 4444,
|
||||
chrome_port: 9515,
|
||||
chrome_binary: None,
|
||||
chromedriver_binary: None,
|
||||
browser: WebDriverBrowser::Safari,
|
||||
max_context_length: None,
|
||||
fallback_default_max_tokens: 8192,
|
||||
enable_streaming: true,
|
||||
timeout_seconds: 120,
|
||||
auto_compact: true,
|
||||
max_retry_attempts: 3,
|
||||
autonomous_max_retry_attempts: 6,
|
||||
check_todo_staleness: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ComputerControlConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
require_confirmation: true,
|
||||
enabled: true,
|
||||
require_confirmation: false,
|
||||
max_actions_per_second: 5,
|
||||
}
|
||||
}
|
||||
@@ -296,17 +337,17 @@ impl Config {
|
||||
if let Some(path) = config_path_to_load {
|
||||
// Read and parse the config file
|
||||
let config_content = std::fs::read_to_string(&path)?;
|
||||
|
||||
|
||||
// Check for old format (direct provider config without named configs)
|
||||
if Self::is_old_format(&config_content) {
|
||||
anyhow::bail!("{}", OLD_CONFIG_FORMAT_ERROR);
|
||||
}
|
||||
|
||||
|
||||
let config: Config = toml::from_str(&config_content)?;
|
||||
|
||||
|
||||
// Validate the default_provider format
|
||||
config.validate_provider_reference(&config.providers.default_provider)?;
|
||||
|
||||
|
||||
return Ok(config);
|
||||
}
|
||||
|
||||
@@ -317,7 +358,7 @@ impl Config {
|
||||
fn is_old_format(content: &str) -> bool {
|
||||
// Old format has [providers.anthropic] with api_key directly
|
||||
// New format has [providers.anthropic.<name>] with api_key
|
||||
|
||||
|
||||
// Parse as TOML value to inspect structure
|
||||
if let Ok(value) = content.parse::<toml::Value>() {
|
||||
if let Some(providers) = value.get("providers") {
|
||||
@@ -445,20 +486,26 @@ impl Config {
|
||||
|
||||
// Apply provider override
|
||||
if let Some(provider) = provider_override {
|
||||
// Validate the override
|
||||
// If provider doesn't contain '.', assume '.default'
|
||||
let provider = if provider.contains('.') {
|
||||
provider
|
||||
} else {
|
||||
format!("{}.default", provider)
|
||||
};
|
||||
config.validate_provider_reference(&provider)?;
|
||||
config.providers.default_provider = provider;
|
||||
}
|
||||
|
||||
// Apply model override to the active provider
|
||||
if let Some(model) = model_override {
|
||||
let (provider_type, config_name) = Self::parse_provider_reference(
|
||||
&config.providers.default_provider
|
||||
)?;
|
||||
let (provider_type, config_name) =
|
||||
Self::parse_provider_reference(&config.providers.default_provider)?;
|
||||
|
||||
match provider_type.as_str() {
|
||||
"anthropic" => {
|
||||
if let Some(ref mut anthropic_config) = config.providers.anthropic.get_mut(&config_name) {
|
||||
if let Some(ref mut anthropic_config) =
|
||||
config.providers.anthropic.get_mut(&config_name)
|
||||
{
|
||||
anthropic_config.model = model;
|
||||
} else {
|
||||
return Err(anyhow::anyhow!(
|
||||
@@ -468,7 +515,9 @@ impl Config {
|
||||
}
|
||||
}
|
||||
"databricks" => {
|
||||
if let Some(ref mut databricks_config) = config.providers.databricks.get_mut(&config_name) {
|
||||
if let Some(ref mut databricks_config) =
|
||||
config.providers.databricks.get_mut(&config_name)
|
||||
{
|
||||
databricks_config.model = model;
|
||||
} else {
|
||||
return Err(anyhow::anyhow!(
|
||||
@@ -478,7 +527,9 @@ impl Config {
|
||||
}
|
||||
}
|
||||
"embedded" => {
|
||||
if let Some(ref mut embedded_config) = config.providers.embedded.get_mut(&config_name) {
|
||||
if let Some(ref mut embedded_config) =
|
||||
config.providers.embedded.get_mut(&config_name)
|
||||
{
|
||||
embedded_config.model_path = model;
|
||||
} else {
|
||||
return Err(anyhow::anyhow!(
|
||||
@@ -488,7 +539,9 @@ impl Config {
|
||||
}
|
||||
}
|
||||
"openai" => {
|
||||
if let Some(ref mut openai_config) = config.providers.openai.get_mut(&config_name) {
|
||||
if let Some(ref mut openai_config) =
|
||||
config.providers.openai.get_mut(&config_name)
|
||||
{
|
||||
openai_config.model = model;
|
||||
} else {
|
||||
return Err(anyhow::anyhow!(
|
||||
@@ -499,13 +552,12 @@ impl Config {
|
||||
}
|
||||
_ => {
|
||||
// Check openai_compatible
|
||||
if let Some(ref mut compat_config) = config.providers.openai_compatible.get_mut(&provider_type) {
|
||||
if let Some(ref mut compat_config) =
|
||||
config.providers.openai_compatible.get_mut(&provider_type)
|
||||
{
|
||||
compat_config.model = model;
|
||||
} else {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Unknown provider type: {}",
|
||||
provider_type
|
||||
));
|
||||
return Err(anyhow::anyhow!("Unknown provider type: {}", provider_type));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -585,36 +637,42 @@ impl Config {
|
||||
|
||||
/// Get the current default provider's config
|
||||
pub fn get_default_provider_config(&self) -> Result<ProviderConfigRef<'_>> {
|
||||
let (provider_type, config_name) = Self::parse_provider_reference(
|
||||
&self.providers.default_provider
|
||||
)?;
|
||||
let (provider_type, config_name) =
|
||||
Self::parse_provider_reference(&self.providers.default_provider)?;
|
||||
|
||||
match provider_type.as_str() {
|
||||
"anthropic" => {
|
||||
self.providers.anthropic.get(&config_name)
|
||||
.map(ProviderConfigRef::Anthropic)
|
||||
.ok_or_else(|| anyhow::anyhow!("Anthropic config '{}' not found", config_name))
|
||||
}
|
||||
"openai" => {
|
||||
self.providers.openai.get(&config_name)
|
||||
.map(ProviderConfigRef::OpenAI)
|
||||
.ok_or_else(|| anyhow::anyhow!("OpenAI config '{}' not found", config_name))
|
||||
}
|
||||
"databricks" => {
|
||||
self.providers.databricks.get(&config_name)
|
||||
.map(ProviderConfigRef::Databricks)
|
||||
.ok_or_else(|| anyhow::anyhow!("Databricks config '{}' not found", config_name))
|
||||
}
|
||||
"embedded" => {
|
||||
self.providers.embedded.get(&config_name)
|
||||
.map(ProviderConfigRef::Embedded)
|
||||
.ok_or_else(|| anyhow::anyhow!("Embedded config '{}' not found", config_name))
|
||||
}
|
||||
_ => {
|
||||
self.providers.openai_compatible.get(&provider_type)
|
||||
.map(ProviderConfigRef::OpenAICompatible)
|
||||
.ok_or_else(|| anyhow::anyhow!("OpenAI compatible config '{}' not found", provider_type))
|
||||
}
|
||||
"anthropic" => self
|
||||
.providers
|
||||
.anthropic
|
||||
.get(&config_name)
|
||||
.map(ProviderConfigRef::Anthropic)
|
||||
.ok_or_else(|| anyhow::anyhow!("Anthropic config '{}' not found", config_name)),
|
||||
"openai" => self
|
||||
.providers
|
||||
.openai
|
||||
.get(&config_name)
|
||||
.map(ProviderConfigRef::OpenAI)
|
||||
.ok_or_else(|| anyhow::anyhow!("OpenAI config '{}' not found", config_name)),
|
||||
"databricks" => self
|
||||
.providers
|
||||
.databricks
|
||||
.get(&config_name)
|
||||
.map(ProviderConfigRef::Databricks)
|
||||
.ok_or_else(|| anyhow::anyhow!("Databricks config '{}' not found", config_name)),
|
||||
"embedded" => self
|
||||
.providers
|
||||
.embedded
|
||||
.get(&config_name)
|
||||
.map(ProviderConfigRef::Embedded)
|
||||
.ok_or_else(|| anyhow::anyhow!("Embedded config '{}' not found", config_name)),
|
||||
_ => self
|
||||
.providers
|
||||
.openai_compatible
|
||||
.get(&provider_type)
|
||||
.map(ProviderConfigRef::OpenAICompatible)
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!("OpenAI compatible config '{}' not found", provider_type)
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user