Implement planning mode
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
/// Main configuration structure
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Config {
|
||||
pub providers: ProvidersConfig,
|
||||
@@ -11,18 +13,40 @@ pub struct Config {
|
||||
pub macax: MacAxConfig,
|
||||
}
|
||||
|
||||
/// Provider configuration with named configs per provider type
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProvidersConfig {
|
||||
pub openai: Option<OpenAIConfig>,
|
||||
/// Default provider in format "<provider_type>.<config_name>"
|
||||
pub default_provider: String,
|
||||
|
||||
/// Provider for planner mode (optional, falls back to default_provider)
|
||||
pub planner: Option<String>,
|
||||
|
||||
/// Provider for coach in autonomous mode (optional, falls back to default_provider)
|
||||
pub coach: Option<String>,
|
||||
|
||||
/// Provider for player in autonomous mode (optional, falls back to default_provider)
|
||||
pub player: Option<String>,
|
||||
|
||||
/// Named Anthropic provider configs
|
||||
#[serde(default)]
|
||||
pub anthropic: HashMap<String, AnthropicConfig>,
|
||||
|
||||
/// Named OpenAI provider configs
|
||||
#[serde(default)]
|
||||
pub openai: HashMap<String, OpenAIConfig>,
|
||||
|
||||
/// Named Databricks provider configs
|
||||
#[serde(default)]
|
||||
pub databricks: HashMap<String, DatabricksConfig>,
|
||||
|
||||
/// Named embedded provider configs
|
||||
#[serde(default)]
|
||||
pub embedded: HashMap<String, EmbeddedConfig>,
|
||||
|
||||
/// Multiple named OpenAI-compatible providers (e.g., openrouter, groq, etc.)
|
||||
#[serde(default)]
|
||||
pub openai_compatible: std::collections::HashMap<String, OpenAIConfig>,
|
||||
pub anthropic: Option<AnthropicConfig>,
|
||||
pub databricks: Option<DatabricksConfig>,
|
||||
pub embedded: Option<EmbeddedConfig>,
|
||||
pub default_provider: String,
|
||||
pub coach: Option<String>, // Provider to use for coach in autonomous mode
|
||||
pub player: Option<String>, // Provider to use for player in autonomous mode
|
||||
pub openai_compatible: HashMap<String, OpenAIConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -40,30 +64,30 @@ pub struct AnthropicConfig {
|
||||
pub model: String,
|
||||
pub max_tokens: Option<u32>,
|
||||
pub temperature: Option<f32>,
|
||||
pub cache_config: Option<String>, // "ephemeral", "5minute", "1hour", or None to disable
|
||||
pub enable_1m_context: Option<bool>, // Enable 1m context window (costs extra)
|
||||
pub thinking_budget_tokens: Option<u32>, // Budget tokens for extended thinking
|
||||
pub cache_config: Option<String>,
|
||||
pub enable_1m_context: Option<bool>,
|
||||
pub thinking_budget_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DatabricksConfig {
|
||||
pub host: String,
|
||||
pub token: Option<String>, // Optional - will use OAuth if not provided
|
||||
pub token: Option<String>,
|
||||
pub model: String,
|
||||
pub max_tokens: Option<u32>,
|
||||
pub temperature: Option<f32>,
|
||||
pub use_oauth: Option<bool>, // Default to true if token not provided
|
||||
pub use_oauth: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EmbeddedConfig {
|
||||
pub model_path: String,
|
||||
pub model_type: String, // e.g., "llama", "mistral", "codellama"
|
||||
pub model_type: String,
|
||||
pub context_length: Option<u32>,
|
||||
pub max_tokens: Option<u32>,
|
||||
pub temperature: Option<f32>,
|
||||
pub gpu_layers: Option<u32>, // Number of layers to offload to GPU
|
||||
pub threads: Option<u32>, // Number of CPU threads to use
|
||||
pub gpu_layers: Option<u32>,
|
||||
pub threads: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -120,7 +144,7 @@ impl Default for WebDriverConfig {
|
||||
impl Default for ComputerControlConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false, // Disabled by default for safety
|
||||
enabled: false,
|
||||
require_confirmation: true,
|
||||
max_actions_per_second: 5,
|
||||
}
|
||||
@@ -129,23 +153,30 @@ impl Default for ComputerControlConfig {
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
let mut databricks_configs = HashMap::new();
|
||||
databricks_configs.insert(
|
||||
"default".to_string(),
|
||||
DatabricksConfig {
|
||||
host: "https://your-workspace.cloud.databricks.com".to_string(),
|
||||
token: None,
|
||||
model: "databricks-claude-sonnet-4".to_string(),
|
||||
max_tokens: Some(4096),
|
||||
temperature: Some(0.1),
|
||||
use_oauth: Some(true),
|
||||
},
|
||||
);
|
||||
|
||||
Self {
|
||||
providers: ProvidersConfig {
|
||||
openai: None,
|
||||
openai_compatible: std::collections::HashMap::new(),
|
||||
anthropic: None,
|
||||
databricks: Some(DatabricksConfig {
|
||||
host: "https://your-workspace.cloud.databricks.com".to_string(),
|
||||
token: None, // Will use OAuth by default
|
||||
model: "databricks-claude-sonnet-4".to_string(),
|
||||
max_tokens: Some(4096),
|
||||
temperature: Some(0.1),
|
||||
use_oauth: Some(true),
|
||||
}),
|
||||
embedded: None,
|
||||
default_provider: "databricks".to_string(),
|
||||
coach: None, // Will use default_provider if not specified
|
||||
player: None, // Will use default_provider if not specified
|
||||
default_provider: "databricks.default".to_string(),
|
||||
planner: None,
|
||||
coach: None,
|
||||
player: None,
|
||||
anthropic: HashMap::new(),
|
||||
openai: HashMap::new(),
|
||||
databricks: databricks_configs,
|
||||
embedded: HashMap::new(),
|
||||
openai_compatible: HashMap::new(),
|
||||
},
|
||||
agent: AgentConfig {
|
||||
max_context_length: None,
|
||||
@@ -165,26 +196,54 @@ impl Default for Config {
|
||||
}
|
||||
}
|
||||
|
||||
/// Error message for old config format
|
||||
const OLD_CONFIG_FORMAT_ERROR: &str = r#"Your configuration file uses an old format that is no longer supported.
|
||||
|
||||
Please update your configuration to use the new provider format:
|
||||
|
||||
```toml
|
||||
[providers]
|
||||
default_provider = "anthropic.default" # Format: "<provider_type>.<config_name>"
|
||||
planner = "anthropic.planner" # Optional: specific provider for planner
|
||||
coach = "anthropic.default" # Optional: specific provider for coach
|
||||
player = "openai.player" # Optional: specific provider for player
|
||||
|
||||
# Named configs per provider type
|
||||
[providers.anthropic.default]
|
||||
api_key = "your-api-key"
|
||||
model = "claude-sonnet-4-5"
|
||||
max_tokens = 64000
|
||||
|
||||
[providers.anthropic.planner]
|
||||
api_key = "your-api-key"
|
||||
model = "claude-opus-4-5"
|
||||
thinking_budget_tokens = 16000
|
||||
|
||||
[providers.openai.player]
|
||||
api_key = "your-api-key"
|
||||
model = "gpt-5"
|
||||
```
|
||||
|
||||
Each mode (planner, coach, player) can specify a full path like "<provider_type>.<config_name>".
|
||||
If not specified, they fall back to `default_provider`."#;
|
||||
|
||||
impl Config {
|
||||
pub fn load(config_path: Option<&str>) -> Result<Self> {
|
||||
// Check if any config file exists
|
||||
let config_exists = if let Some(path) = config_path {
|
||||
Path::new(path).exists()
|
||||
} else {
|
||||
// Check default locations
|
||||
let default_paths = ["./g3.toml", "~/.config/g3/config.toml", "~/.g3.toml"];
|
||||
|
||||
default_paths.iter().any(|path| {
|
||||
let expanded_path = shellexpand::tilde(path);
|
||||
Path::new(expanded_path.as_ref()).exists()
|
||||
})
|
||||
};
|
||||
|
||||
// If no config exists, create and save a default Databricks config
|
||||
// If no config exists, create and save a default config
|
||||
if !config_exists {
|
||||
let databricks_config = Self::default();
|
||||
let default_config = Self::default();
|
||||
|
||||
// Save to default location
|
||||
let config_dir = dirs::home_dir()
|
||||
.map(|mut path| {
|
||||
path.push(".config");
|
||||
@@ -193,89 +252,171 @@ impl Config {
|
||||
})
|
||||
.unwrap_or_else(|| std::path::PathBuf::from("."));
|
||||
|
||||
// Create directory if it doesn't exist
|
||||
std::fs::create_dir_all(&config_dir).ok();
|
||||
|
||||
let config_file = config_dir.join("config.toml");
|
||||
if let Err(e) = databricks_config.save(config_file.to_str().unwrap()) {
|
||||
if let Err(e) = default_config.save(config_file.to_str().unwrap()) {
|
||||
eprintln!("Warning: Could not save default config: {}", e);
|
||||
} else {
|
||||
println!(
|
||||
"Created default Databricks configuration at: {}",
|
||||
"Created default configuration at: {}",
|
||||
config_file.display()
|
||||
);
|
||||
}
|
||||
|
||||
return Ok(databricks_config);
|
||||
return Ok(default_config);
|
||||
}
|
||||
|
||||
// Existing config loading logic
|
||||
let mut settings = config::Config::builder();
|
||||
|
||||
// Load default configuration
|
||||
settings = settings.add_source(config::Config::try_from(&Config::default())?);
|
||||
|
||||
// Load from config file if provided
|
||||
if let Some(path) = config_path {
|
||||
if Path::new(path).exists() {
|
||||
settings = settings.add_source(config::File::with_name(path));
|
||||
}
|
||||
// Load config from file
|
||||
let config_path_to_load = if let Some(path) = config_path {
|
||||
Some(path.to_string())
|
||||
} else {
|
||||
// Try to load from default locations
|
||||
let default_paths = ["./g3.toml", "~/.config/g3/config.toml", "~/.g3.toml"];
|
||||
|
||||
for path in &default_paths {
|
||||
default_paths.iter().find_map(|path| {
|
||||
let expanded_path = shellexpand::tilde(path);
|
||||
if Path::new(expanded_path.as_ref()).exists() {
|
||||
settings = settings.add_source(config::File::with_name(expanded_path.as_ref()));
|
||||
break;
|
||||
Some(expanded_path.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
if let Some(path) = config_path_to_load {
|
||||
// Read and parse the config file
|
||||
let config_content = std::fs::read_to_string(&path)?;
|
||||
|
||||
// Check for old format (direct provider config without named configs)
|
||||
if Self::is_old_format(&config_content) {
|
||||
anyhow::bail!("{}", OLD_CONFIG_FORMAT_ERROR);
|
||||
}
|
||||
|
||||
let config: Config = toml::from_str(&config_content)?;
|
||||
|
||||
// Validate the default_provider format
|
||||
config.validate_provider_reference(&config.providers.default_provider)?;
|
||||
|
||||
return Ok(config);
|
||||
}
|
||||
|
||||
Ok(Self::default())
|
||||
}
|
||||
|
||||
/// Check if the config content uses the old format
|
||||
fn is_old_format(content: &str) -> bool {
|
||||
// Old format has [providers.anthropic] with api_key directly
|
||||
// New format has [providers.anthropic.<name>] with api_key
|
||||
|
||||
// Parse as TOML value to inspect structure
|
||||
if let Ok(value) = content.parse::<toml::Value>() {
|
||||
if let Some(providers) = value.get("providers") {
|
||||
if let Some(providers_table) = providers.as_table() {
|
||||
// Check anthropic section
|
||||
if let Some(anthropic) = providers_table.get("anthropic") {
|
||||
if let Some(anthropic_table) = anthropic.as_table() {
|
||||
// If anthropic has api_key directly, it's old format
|
||||
if anthropic_table.contains_key("api_key") {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Check databricks section
|
||||
if let Some(databricks) = providers_table.get("databricks") {
|
||||
if let Some(databricks_table) = databricks.as_table() {
|
||||
// If databricks has host directly, it's old format
|
||||
if databricks_table.contains_key("host") {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Check openai section
|
||||
if let Some(openai) = providers_table.get("openai") {
|
||||
if let Some(openai_table) = openai.as_table() {
|
||||
// If openai has api_key directly, it's old format
|
||||
if openai_table.contains_key("api_key") {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Validate a provider reference (format: "<provider_type>.<config_name>")
|
||||
fn validate_provider_reference(&self, reference: &str) -> Result<()> {
|
||||
let parts: Vec<&str> = reference.split('.').collect();
|
||||
if parts.len() != 2 {
|
||||
anyhow::bail!(
|
||||
"Invalid provider reference '{}'. Expected format: '<provider_type>.<config_name>'",
|
||||
reference
|
||||
);
|
||||
}
|
||||
|
||||
let (provider_type, config_name) = (parts[0], parts[1]);
|
||||
|
||||
match provider_type {
|
||||
"anthropic" => {
|
||||
if !self.providers.anthropic.contains_key(config_name) {
|
||||
anyhow::bail!(
|
||||
"Provider config 'anthropic.{}' not found. Available: {:?}",
|
||||
config_name,
|
||||
self.providers.anthropic.keys().collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
}
|
||||
"openai" => {
|
||||
if !self.providers.openai.contains_key(config_name) {
|
||||
anyhow::bail!(
|
||||
"Provider config 'openai.{}' not found. Available: {:?}",
|
||||
config_name,
|
||||
self.providers.openai.keys().collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
}
|
||||
"databricks" => {
|
||||
if !self.providers.databricks.contains_key(config_name) {
|
||||
anyhow::bail!(
|
||||
"Provider config 'databricks.{}' not found. Available: {:?}",
|
||||
config_name,
|
||||
self.providers.databricks.keys().collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
}
|
||||
"embedded" => {
|
||||
if !self.providers.embedded.contains_key(config_name) {
|
||||
anyhow::bail!(
|
||||
"Provider config 'embedded.{}' not found. Available: {:?}",
|
||||
config_name,
|
||||
self.providers.embedded.keys().collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Check openai_compatible providers
|
||||
if !self.providers.openai_compatible.contains_key(provider_type) {
|
||||
anyhow::bail!(
|
||||
"Unknown provider type '{}'. Valid types: anthropic, openai, databricks, embedded, or openai_compatible names",
|
||||
provider_type
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Override with environment variables
|
||||
settings = settings.add_source(config::Environment::with_prefix("G3").separator("_"));
|
||||
|
||||
let config = settings.build()?.try_deserialize()?;
|
||||
Ok(config)
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn default_qwen_config() -> Self {
|
||||
Self {
|
||||
providers: ProvidersConfig {
|
||||
openai: None,
|
||||
openai_compatible: std::collections::HashMap::new(),
|
||||
anthropic: None,
|
||||
databricks: None,
|
||||
embedded: Some(EmbeddedConfig {
|
||||
model_path: "~/.cache/g3/models/qwen2.5-7b-instruct-q3_k_m.gguf".to_string(),
|
||||
model_type: "qwen".to_string(),
|
||||
context_length: Some(32768), // Qwen2.5 supports 32k context
|
||||
max_tokens: Some(2048),
|
||||
temperature: Some(0.1),
|
||||
gpu_layers: Some(32),
|
||||
threads: Some(8),
|
||||
}),
|
||||
default_provider: "embedded".to_string(),
|
||||
coach: None, // Will use default_provider if not specified
|
||||
player: None, // Will use default_provider if not specified
|
||||
},
|
||||
agent: AgentConfig {
|
||||
max_context_length: None,
|
||||
fallback_default_max_tokens: 8192,
|
||||
enable_streaming: true,
|
||||
allow_multiple_tool_calls: false,
|
||||
timeout_seconds: 60,
|
||||
auto_compact: true,
|
||||
max_retry_attempts: 3,
|
||||
autonomous_max_retry_attempts: 6,
|
||||
check_todo_staleness: true,
|
||||
},
|
||||
computer_control: ComputerControlConfig::default(),
|
||||
webdriver: WebDriverConfig::default(),
|
||||
macax: MacAxConfig::default(),
|
||||
/// Parse a provider reference into (provider_type, config_name)
|
||||
pub fn parse_provider_reference(reference: &str) -> Result<(String, String)> {
|
||||
let parts: Vec<&str> = reference.split('.').collect();
|
||||
if parts.len() != 2 {
|
||||
anyhow::bail!(
|
||||
"Invalid provider reference '{}'. Expected format: '<provider_type>.<config_name>'",
|
||||
reference
|
||||
);
|
||||
}
|
||||
Ok((parts[0].to_string(), parts[1].to_string()))
|
||||
}
|
||||
|
||||
pub fn save(&self, path: &str) -> Result<()> {
|
||||
@@ -289,58 +430,72 @@ impl Config {
|
||||
provider_override: Option<String>,
|
||||
model_override: Option<String>,
|
||||
) -> Result<Self> {
|
||||
// Load the base configuration
|
||||
let mut config = Self::load(config_path)?;
|
||||
|
||||
// Apply provider override
|
||||
if let Some(provider) = provider_override {
|
||||
// Validate the override
|
||||
config.validate_provider_reference(&provider)?;
|
||||
config.providers.default_provider = provider;
|
||||
}
|
||||
|
||||
// Apply model override to the active provider
|
||||
if let Some(model) = model_override {
|
||||
match config.providers.default_provider.as_str() {
|
||||
let (provider_type, config_name) = Self::parse_provider_reference(
|
||||
&config.providers.default_provider
|
||||
)?;
|
||||
|
||||
match provider_type.as_str() {
|
||||
"anthropic" => {
|
||||
if let Some(ref mut anthropic) = config.providers.anthropic {
|
||||
anthropic.model = model;
|
||||
if let Some(ref mut anthropic_config) = config.providers.anthropic.get_mut(&config_name) {
|
||||
anthropic_config.model = model;
|
||||
} else {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Provider 'anthropic' is not configured. Please add anthropic configuration to your config file."
|
||||
"Provider config 'anthropic.{}' not found.",
|
||||
config_name
|
||||
));
|
||||
}
|
||||
}
|
||||
"databricks" => {
|
||||
if let Some(ref mut databricks) = config.providers.databricks {
|
||||
databricks.model = model;
|
||||
if let Some(ref mut databricks_config) = config.providers.databricks.get_mut(&config_name) {
|
||||
databricks_config.model = model;
|
||||
} else {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Provider 'databricks' is not configured. Please add databricks configuration to your config file."
|
||||
"Provider config 'databricks.{}' not found.",
|
||||
config_name
|
||||
));
|
||||
}
|
||||
}
|
||||
"embedded" => {
|
||||
if let Some(ref mut embedded) = config.providers.embedded {
|
||||
embedded.model_path = model;
|
||||
if let Some(ref mut embedded_config) = config.providers.embedded.get_mut(&config_name) {
|
||||
embedded_config.model_path = model;
|
||||
} else {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Provider 'embedded' is not configured. Please add embedded configuration to your config file."
|
||||
"Provider config 'embedded.{}' not found.",
|
||||
config_name
|
||||
));
|
||||
}
|
||||
}
|
||||
"openai" => {
|
||||
if let Some(ref mut openai) = config.providers.openai {
|
||||
openai.model = model;
|
||||
if let Some(ref mut openai_config) = config.providers.openai.get_mut(&config_name) {
|
||||
openai_config.model = model;
|
||||
} else {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Provider 'openai' is not configured. Please add openai configuration to your config file."
|
||||
"Provider config 'openai.{}' not found.",
|
||||
config_name
|
||||
));
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Unknown provider: {}",
|
||||
config.providers.default_provider
|
||||
))
|
||||
// Check openai_compatible
|
||||
if let Some(ref mut compat_config) = config.providers.openai_compatible.get_mut(&provider_type) {
|
||||
compat_config.model = model;
|
||||
} else {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Unknown provider type: {}",
|
||||
provider_type
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -348,7 +503,15 @@ impl Config {
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Get the provider to use for coach mode in autonomous execution
|
||||
/// Get the provider reference for planner mode
|
||||
pub fn get_planner_provider(&self) -> &str {
|
||||
self.providers
|
||||
.planner
|
||||
.as_deref()
|
||||
.unwrap_or(&self.providers.default_provider)
|
||||
}
|
||||
|
||||
/// Get the provider reference for coach mode in autonomous execution
|
||||
pub fn get_coach_provider(&self) -> &str {
|
||||
self.providers
|
||||
.coach
|
||||
@@ -356,7 +519,7 @@ impl Config {
|
||||
.unwrap_or(&self.providers.default_provider)
|
||||
}
|
||||
|
||||
/// Get the provider to use for player mode in autonomous execution
|
||||
/// Get the provider reference for player mode in autonomous execution
|
||||
pub fn get_player_provider(&self) -> &str {
|
||||
self.providers
|
||||
.player
|
||||
@@ -365,41 +528,20 @@ impl Config {
|
||||
}
|
||||
|
||||
/// Create a copy of the config with a different default provider
|
||||
pub fn with_provider_override(&self, provider: &str) -> Result<Self> {
|
||||
pub fn with_provider_override(&self, provider_ref: &str) -> Result<Self> {
|
||||
// Validate that the provider is configured
|
||||
match provider {
|
||||
"anthropic" if self.providers.anthropic.is_none() => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
|
||||
provider, provider
|
||||
));
|
||||
}
|
||||
"databricks" if self.providers.databricks.is_none() => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
|
||||
provider, provider
|
||||
));
|
||||
}
|
||||
"embedded" if self.providers.embedded.is_none() => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
|
||||
provider, provider
|
||||
));
|
||||
}
|
||||
"openai" if self.providers.openai.is_none() => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
|
||||
provider, provider
|
||||
));
|
||||
}
|
||||
_ => {} // Provider is configured or unknown (will be caught later)
|
||||
}
|
||||
self.validate_provider_reference(provider_ref)?;
|
||||
|
||||
let mut config = self.clone();
|
||||
config.providers.default_provider = provider.to_string();
|
||||
config.providers.default_provider = provider_ref.to_string();
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Create a copy of the config for planner mode
|
||||
pub fn for_planner(&self) -> Result<Self> {
|
||||
self.with_provider_override(self.get_planner_provider())
|
||||
}
|
||||
|
||||
/// Create a copy of the config for coach mode in autonomous execution
|
||||
pub fn for_coach(&self) -> Result<Self> {
|
||||
self.with_provider_override(self.get_coach_provider())
|
||||
@@ -409,6 +551,71 @@ impl Config {
|
||||
pub fn for_player(&self) -> Result<Self> {
|
||||
self.with_provider_override(self.get_player_provider())
|
||||
}
|
||||
|
||||
/// Get Anthropic config by name
|
||||
pub fn get_anthropic_config(&self, name: &str) -> Option<&AnthropicConfig> {
|
||||
self.providers.anthropic.get(name)
|
||||
}
|
||||
|
||||
/// Get OpenAI config by name
|
||||
pub fn get_openai_config(&self, name: &str) -> Option<&OpenAIConfig> {
|
||||
self.providers.openai.get(name)
|
||||
}
|
||||
|
||||
/// Get Databricks config by name
|
||||
pub fn get_databricks_config(&self, name: &str) -> Option<&DatabricksConfig> {
|
||||
self.providers.databricks.get(name)
|
||||
}
|
||||
|
||||
/// Get Embedded config by name
|
||||
pub fn get_embedded_config(&self, name: &str) -> Option<&EmbeddedConfig> {
|
||||
self.providers.embedded.get(name)
|
||||
}
|
||||
|
||||
/// Get the current default provider's config
|
||||
pub fn get_default_provider_config(&self) -> Result<ProviderConfigRef<'_>> {
|
||||
let (provider_type, config_name) = Self::parse_provider_reference(
|
||||
&self.providers.default_provider
|
||||
)?;
|
||||
|
||||
match provider_type.as_str() {
|
||||
"anthropic" => {
|
||||
self.providers.anthropic.get(&config_name)
|
||||
.map(ProviderConfigRef::Anthropic)
|
||||
.ok_or_else(|| anyhow::anyhow!("Anthropic config '{}' not found", config_name))
|
||||
}
|
||||
"openai" => {
|
||||
self.providers.openai.get(&config_name)
|
||||
.map(ProviderConfigRef::OpenAI)
|
||||
.ok_or_else(|| anyhow::anyhow!("OpenAI config '{}' not found", config_name))
|
||||
}
|
||||
"databricks" => {
|
||||
self.providers.databricks.get(&config_name)
|
||||
.map(ProviderConfigRef::Databricks)
|
||||
.ok_or_else(|| anyhow::anyhow!("Databricks config '{}' not found", config_name))
|
||||
}
|
||||
"embedded" => {
|
||||
self.providers.embedded.get(&config_name)
|
||||
.map(ProviderConfigRef::Embedded)
|
||||
.ok_or_else(|| anyhow::anyhow!("Embedded config '{}' not found", config_name))
|
||||
}
|
||||
_ => {
|
||||
self.providers.openai_compatible.get(&provider_type)
|
||||
.map(ProviderConfigRef::OpenAICompatible)
|
||||
.ok_or_else(|| anyhow::anyhow!("OpenAI compatible config '{}' not found", provider_type))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Reference to a provider configuration
|
||||
#[derive(Debug)]
|
||||
pub enum ProviderConfigRef<'a> {
|
||||
Anthropic(&'a AnthropicConfig),
|
||||
OpenAI(&'a OpenAIConfig),
|
||||
Databricks(&'a DatabricksConfig),
|
||||
Embedded(&'a EmbeddedConfig),
|
||||
OpenAICompatible(&'a OpenAIConfig),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -4,29 +4,45 @@ mod tests {
|
||||
use std::fs;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn test_config_footer() -> &'static str {
|
||||
r#"
|
||||
[computer_control]
|
||||
enabled = false
|
||||
require_confirmation = true
|
||||
max_actions_per_second = 10
|
||||
|
||||
[webdriver]
|
||||
enabled = false
|
||||
safari_port = 4444
|
||||
|
||||
[macax]
|
||||
enabled = false
|
||||
"#
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_coach_player_providers() {
|
||||
// Create a temporary directory for the test config
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let config_path = temp_dir.path().join("test_config.toml");
|
||||
|
||||
// Write a test configuration with coach and player providers
|
||||
let config_content = r#"
|
||||
// Write a test configuration with coach and player providers (new format)
|
||||
let config_content = format!(r#"
|
||||
[providers]
|
||||
default_provider = "databricks"
|
||||
coach = "anthropic"
|
||||
player = "embedded"
|
||||
default_provider = "databricks.default"
|
||||
coach = "anthropic.default"
|
||||
player = "embedded.local"
|
||||
|
||||
[providers.databricks]
|
||||
[providers.databricks.default]
|
||||
host = "https://test.databricks.com"
|
||||
token = "test-token"
|
||||
model = "test-model"
|
||||
|
||||
[providers.anthropic]
|
||||
[providers.anthropic.default]
|
||||
api_key = "test-key"
|
||||
model = "claude-3"
|
||||
|
||||
[providers.embedded]
|
||||
[providers.embedded.local]
|
||||
model_path = "test.gguf"
|
||||
model_type = "llama"
|
||||
|
||||
@@ -34,7 +50,11 @@ model_type = "llama"
|
||||
fallback_default_max_tokens = 8192
|
||||
enable_streaming = true
|
||||
timeout_seconds = 60
|
||||
"#;
|
||||
auto_compact = true
|
||||
allow_multiple_tool_calls = false
|
||||
max_retry_attempts = 3
|
||||
autonomous_max_retry_attempts = 6
|
||||
{}"#, test_config_footer());
|
||||
|
||||
fs::write(&config_path, config_content).unwrap();
|
||||
|
||||
@@ -42,17 +62,17 @@ timeout_seconds = 60
|
||||
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
|
||||
|
||||
// Test that the providers are correctly identified
|
||||
assert_eq!(config.providers.default_provider, "databricks");
|
||||
assert_eq!(config.get_coach_provider(), "anthropic");
|
||||
assert_eq!(config.get_player_provider(), "embedded");
|
||||
assert_eq!(config.providers.default_provider, "databricks.default");
|
||||
assert_eq!(config.get_coach_provider(), "anthropic.default");
|
||||
assert_eq!(config.get_player_provider(), "embedded.local");
|
||||
|
||||
// Test creating coach config
|
||||
let coach_config = config.for_coach().unwrap();
|
||||
assert_eq!(coach_config.providers.default_provider, "anthropic");
|
||||
assert_eq!(coach_config.providers.default_provider, "anthropic.default");
|
||||
|
||||
// Test creating player config
|
||||
let player_config = config.for_player().unwrap();
|
||||
assert_eq!(player_config.providers.default_provider, "embedded");
|
||||
assert_eq!(player_config.providers.default_provider, "embedded.local");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -61,12 +81,12 @@ timeout_seconds = 60
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let config_path = temp_dir.path().join("test_config.toml");
|
||||
|
||||
// Write a test configuration WITHOUT coach and player providers
|
||||
let config_content = r#"
|
||||
// Write a test configuration WITHOUT coach and player providers (new format)
|
||||
let config_content = format!(r#"
|
||||
[providers]
|
||||
default_provider = "databricks"
|
||||
default_provider = "databricks.default"
|
||||
|
||||
[providers.databricks]
|
||||
[providers.databricks.default]
|
||||
host = "https://test.databricks.com"
|
||||
token = "test-token"
|
||||
model = "test-model"
|
||||
@@ -75,7 +95,11 @@ model = "test-model"
|
||||
fallback_default_max_tokens = 8192
|
||||
enable_streaming = true
|
||||
timeout_seconds = 60
|
||||
"#;
|
||||
auto_compact = true
|
||||
allow_multiple_tool_calls = false
|
||||
max_retry_attempts = 3
|
||||
autonomous_max_retry_attempts = 6
|
||||
{}"#, test_config_footer());
|
||||
|
||||
fs::write(&config_path, config_content).unwrap();
|
||||
|
||||
@@ -83,16 +107,16 @@ timeout_seconds = 60
|
||||
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
|
||||
|
||||
// Test that coach and player fall back to default provider
|
||||
assert_eq!(config.get_coach_provider(), "databricks");
|
||||
assert_eq!(config.get_player_provider(), "databricks");
|
||||
assert_eq!(config.get_coach_provider(), "databricks.default");
|
||||
assert_eq!(config.get_player_provider(), "databricks.default");
|
||||
|
||||
// Test creating coach config (should use default)
|
||||
let coach_config = config.for_coach().unwrap();
|
||||
assert_eq!(coach_config.providers.default_provider, "databricks");
|
||||
assert_eq!(coach_config.providers.default_provider, "databricks.default");
|
||||
|
||||
// Test creating player config (should use default)
|
||||
let player_config = config.for_player().unwrap();
|
||||
assert_eq!(player_config.providers.default_provider, "databricks");
|
||||
assert_eq!(player_config.providers.default_provider, "databricks.default");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -101,13 +125,13 @@ timeout_seconds = 60
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let config_path = temp_dir.path().join("test_config.toml");
|
||||
|
||||
// Write a test configuration with an unconfigured provider
|
||||
let config_content = r#"
|
||||
// Write a test configuration with an unconfigured provider (new format)
|
||||
let config_content = format!(r#"
|
||||
[providers]
|
||||
default_provider = "databricks"
|
||||
coach = "openai" # OpenAI is not configured
|
||||
default_provider = "databricks.default"
|
||||
coach = "openai.default" # OpenAI default is not configured
|
||||
|
||||
[providers.databricks]
|
||||
[providers.databricks.default]
|
||||
host = "https://test.databricks.com"
|
||||
token = "test-token"
|
||||
model = "test-model"
|
||||
@@ -116,7 +140,11 @@ model = "test-model"
|
||||
fallback_default_max_tokens = 8192
|
||||
enable_streaming = true
|
||||
timeout_seconds = 60
|
||||
"#;
|
||||
auto_compact = true
|
||||
allow_multiple_tool_calls = false
|
||||
max_retry_attempts = 3
|
||||
autonomous_max_retry_attempts = 6
|
||||
{}"#, test_config_footer());
|
||||
|
||||
fs::write(&config_path, config_content).unwrap();
|
||||
|
||||
@@ -126,6 +154,123 @@ timeout_seconds = 60
|
||||
// Test that trying to create a coach config with unconfigured provider fails
|
||||
let result = config.for_coach();
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("not configured"));
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
assert!(err_msg.contains("not found") || err_msg.contains("not configured"),
|
||||
"Expected error message to contain 'not found' or 'not configured', got: {}", err_msg);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_old_format_detection() {
|
||||
// Create a temporary directory for the test config
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let config_path = temp_dir.path().join("test_config.toml");
|
||||
|
||||
// Write a test configuration with OLD format (api_key directly under [providers.anthropic])
|
||||
let config_content = format!(r#"
|
||||
[providers]
|
||||
default_provider = "anthropic"
|
||||
|
||||
[providers.anthropic]
|
||||
api_key = "test-key"
|
||||
model = "claude-3"
|
||||
|
||||
[agent]
|
||||
fallback_default_max_tokens = 8192
|
||||
enable_streaming = true
|
||||
timeout_seconds = 60
|
||||
auto_compact = true
|
||||
allow_multiple_tool_calls = false
|
||||
max_retry_attempts = 3
|
||||
autonomous_max_retry_attempts = 6
|
||||
{}"#, test_config_footer());
|
||||
|
||||
fs::write(&config_path, config_content).unwrap();
|
||||
|
||||
// Loading should fail with old format error
|
||||
let result = Config::load(Some(config_path.to_str().unwrap()));
|
||||
assert!(result.is_err());
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
assert!(err_msg.contains("old format") || err_msg.contains("no longer supported"),
|
||||
"Expected error about old format, got: {}", err_msg);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_planner_provider() {
|
||||
// Create a temporary directory for the test config
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let config_path = temp_dir.path().join("test_config.toml");
|
||||
|
||||
// Write a test configuration with planner provider (new format)
|
||||
let config_content = format!(r#"
|
||||
[providers]
|
||||
default_provider = "databricks.default"
|
||||
planner = "anthropic.planner"
|
||||
|
||||
[providers.databricks.default]
|
||||
host = "https://test.databricks.com"
|
||||
token = "test-token"
|
||||
model = "test-model"
|
||||
|
||||
[providers.anthropic.planner]
|
||||
api_key = "test-key"
|
||||
model = "claude-opus"
|
||||
thinking_budget_tokens = 16000
|
||||
|
||||
[agent]
|
||||
fallback_default_max_tokens = 8192
|
||||
enable_streaming = true
|
||||
timeout_seconds = 60
|
||||
auto_compact = true
|
||||
allow_multiple_tool_calls = false
|
||||
max_retry_attempts = 3
|
||||
autonomous_max_retry_attempts = 6
|
||||
{}"#, test_config_footer());
|
||||
|
||||
fs::write(&config_path, config_content).unwrap();
|
||||
|
||||
// Load the configuration
|
||||
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
|
||||
|
||||
// Test that the planner provider is correctly identified
|
||||
assert_eq!(config.get_planner_provider(), "anthropic.planner");
|
||||
|
||||
// Test creating planner config
|
||||
let planner_config = config.for_planner().unwrap();
|
||||
assert_eq!(planner_config.providers.default_provider, "anthropic.planner");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_planner_fallback_to_default() {
|
||||
// Create a temporary directory for the test config
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let config_path = temp_dir.path().join("test_config.toml");
|
||||
|
||||
// Write a test configuration WITHOUT planner provider
|
||||
let config_content = format!(r#"
|
||||
[providers]
|
||||
default_provider = "databricks.default"
|
||||
|
||||
[providers.databricks.default]
|
||||
host = "https://test.databricks.com"
|
||||
token = "test-token"
|
||||
model = "test-model"
|
||||
|
||||
[agent]
|
||||
fallback_default_max_tokens = 8192
|
||||
enable_streaming = true
|
||||
timeout_seconds = 60
|
||||
auto_compact = true
|
||||
allow_multiple_tool_calls = false
|
||||
max_retry_attempts = 3
|
||||
autonomous_max_retry_attempts = 6
|
||||
{}"#, test_config_footer());
|
||||
|
||||
fs::write(&config_path, config_content).unwrap();
|
||||
|
||||
// Load the configuration
|
||||
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
|
||||
|
||||
// Test that planner falls back to default provider
|
||||
assert_eq!(config.get_planner_provider(), "databricks.default");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user