can choose per mode models for auto mode
This commit is contained in:
131
crates/g3-config/src/autonomous_config_tests.rs
Normal file
131
crates/g3-config/src/autonomous_config_tests.rs
Normal file
@@ -0,0 +1,131 @@
|
||||
#[cfg(test)]
|
||||
mod autonomous_config_tests {
|
||||
use crate::{Config, AnthropicConfig, DatabricksConfig};
|
||||
|
||||
#[test]
|
||||
fn test_default_autonomous_config() {
|
||||
let config = Config::default();
|
||||
assert!(config.autonomous.coach_provider.is_none());
|
||||
assert!(config.autonomous.coach_model.is_none());
|
||||
assert!(config.autonomous.player_provider.is_none());
|
||||
assert!(config.autonomous.player_model.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_for_coach_with_overrides() {
|
||||
let mut config = Config::default();
|
||||
|
||||
// Set up base config with anthropic
|
||||
config.providers.anthropic = Some(AnthropicConfig {
|
||||
api_key: "test-key".to_string(),
|
||||
model: "claude-3-5-sonnet-20241022".to_string(),
|
||||
max_tokens: Some(4096),
|
||||
temperature: Some(0.1),
|
||||
});
|
||||
|
||||
// Set coach overrides
|
||||
config.autonomous.coach_provider = Some("anthropic".to_string());
|
||||
config.autonomous.coach_model = Some("claude-3-opus-20240229".to_string());
|
||||
|
||||
let coach_config = config.for_coach().unwrap();
|
||||
|
||||
// Verify coach uses overridden provider and model
|
||||
assert_eq!(coach_config.providers.default_provider, "anthropic");
|
||||
assert_eq!(
|
||||
coach_config.providers.anthropic.as_ref().unwrap().model,
|
||||
"claude-3-opus-20240229"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_for_player_with_overrides() {
|
||||
let mut config = Config::default();
|
||||
|
||||
// Set up base config with databricks
|
||||
config.providers.databricks = Some(DatabricksConfig {
|
||||
host: "https://test.databricks.com".to_string(),
|
||||
token: Some("test-token".to_string()),
|
||||
model: "databricks-meta-llama-3-1-70b-instruct".to_string(),
|
||||
max_tokens: Some(4096),
|
||||
temperature: Some(0.1),
|
||||
use_oauth: Some(false),
|
||||
});
|
||||
|
||||
// Set player overrides
|
||||
config.autonomous.player_provider = Some("databricks".to_string());
|
||||
config.autonomous.player_model = Some("databricks-dbrx-instruct".to_string());
|
||||
|
||||
let player_config = config.for_player().unwrap();
|
||||
|
||||
// Verify player uses overridden provider and model
|
||||
assert_eq!(player_config.providers.default_provider, "databricks");
|
||||
assert_eq!(
|
||||
player_config.providers.databricks.as_ref().unwrap().model,
|
||||
"databricks-dbrx-instruct"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_overrides_uses_defaults() {
|
||||
let mut config = Config::default();
|
||||
config.providers.default_provider = "databricks".to_string();
|
||||
|
||||
let coach_config = config.for_coach().unwrap();
|
||||
let player_config = config.for_player().unwrap();
|
||||
|
||||
// Both should use the default provider when no overrides
|
||||
assert_eq!(coach_config.providers.default_provider, "databricks");
|
||||
assert_eq!(player_config.providers.default_provider, "databricks");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_override_only() {
|
||||
let mut config = Config::default();
|
||||
|
||||
config.providers.anthropic = Some(AnthropicConfig {
|
||||
api_key: "test-key".to_string(),
|
||||
model: "claude-3-5-sonnet-20241022".to_string(),
|
||||
max_tokens: Some(4096),
|
||||
temperature: Some(0.1),
|
||||
});
|
||||
|
||||
// Only override provider, not model
|
||||
config.autonomous.coach_provider = Some("anthropic".to_string());
|
||||
|
||||
let coach_config = config.for_coach().unwrap();
|
||||
|
||||
// Should use overridden provider with its default model
|
||||
assert_eq!(coach_config.providers.default_provider, "anthropic");
|
||||
assert_eq!(
|
||||
coach_config.providers.anthropic.as_ref().unwrap().model,
|
||||
"claude-3-5-sonnet-20241022"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_override_only() {
|
||||
let mut config = Config::default();
|
||||
config.providers.default_provider = "databricks".to_string();
|
||||
|
||||
config.providers.databricks = Some(DatabricksConfig {
|
||||
host: "https://test.databricks.com".to_string(),
|
||||
token: Some("test-token".to_string()),
|
||||
model: "databricks-meta-llama-3-1-70b-instruct".to_string(),
|
||||
max_tokens: Some(4096),
|
||||
temperature: Some(0.1),
|
||||
use_oauth: Some(false),
|
||||
});
|
||||
|
||||
// Only override model, not provider
|
||||
config.autonomous.player_model = Some("databricks-dbrx-instruct".to_string());
|
||||
|
||||
let player_config = config.for_player().unwrap();
|
||||
|
||||
// Should use default provider with overridden model
|
||||
assert_eq!(player_config.providers.default_provider, "databricks");
|
||||
assert_eq!(
|
||||
player_config.providers.databricks.as_ref().unwrap().model,
|
||||
"databricks-dbrx-instruct"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -2,12 +2,16 @@ use serde::{Deserialize, Serialize};
|
||||
use anyhow::Result;
|
||||
use std::path::Path;
|
||||
|
||||
#[cfg(test)]
|
||||
mod autonomous_config_tests;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Config {
|
||||
pub providers: ProvidersConfig,
|
||||
pub agent: AgentConfig,
|
||||
pub computer_control: ComputerControlConfig,
|
||||
pub webdriver: WebDriverConfig,
|
||||
pub autonomous: AutonomousConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -86,6 +90,20 @@ impl Default for WebDriverConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AutonomousConfig {
|
||||
pub coach_provider: Option<String>,
|
||||
pub coach_model: Option<String>,
|
||||
pub player_provider: Option<String>,
|
||||
pub player_model: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for AutonomousConfig {
|
||||
fn default() -> Self {
|
||||
Self { coach_provider: None, coach_model: None, player_provider: None, player_model: None }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ComputerControlConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@@ -120,6 +138,7 @@ impl Default for Config {
|
||||
},
|
||||
computer_control: ComputerControlConfig::default(),
|
||||
webdriver: WebDriverConfig::default(),
|
||||
autonomous: AutonomousConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -232,6 +251,7 @@ impl Config {
|
||||
},
|
||||
computer_control: ComputerControlConfig::default(),
|
||||
webdriver: WebDriverConfig::default(),
|
||||
autonomous: AutonomousConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -300,4 +320,78 @@ impl Config {
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Create a config for the coach agent in autonomous mode
|
||||
pub fn for_coach(&self) -> Result<Self> {
|
||||
let mut config = self.clone();
|
||||
|
||||
// Apply coach-specific overrides if configured
|
||||
if let Some(ref coach_provider) = self.autonomous.coach_provider {
|
||||
config.providers.default_provider = coach_provider.clone();
|
||||
}
|
||||
|
||||
if let Some(ref coach_model) = self.autonomous.coach_model {
|
||||
// Apply model override to the coach's provider
|
||||
match config.providers.default_provider.as_str() {
|
||||
"anthropic" => {
|
||||
if let Some(ref mut anthropic) = config.providers.anthropic {
|
||||
anthropic.model = coach_model.clone();
|
||||
} else {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Coach provider 'anthropic' is not configured. Please add anthropic configuration to your config file."
|
||||
));
|
||||
}
|
||||
}
|
||||
"databricks" => {
|
||||
if let Some(ref mut databricks) = config.providers.databricks {
|
||||
databricks.model = coach_model.clone();
|
||||
} else {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Coach provider 'databricks' is not configured. Please add databricks configuration to your config file."
|
||||
));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Create a config for the player agent in autonomous mode
|
||||
pub fn for_player(&self) -> Result<Self> {
|
||||
let mut config = self.clone();
|
||||
|
||||
// Apply player-specific overrides if configured
|
||||
if let Some(ref player_provider) = self.autonomous.player_provider {
|
||||
config.providers.default_provider = player_provider.clone();
|
||||
}
|
||||
|
||||
if let Some(ref player_model) = self.autonomous.player_model {
|
||||
// Apply model override to the player's provider
|
||||
match config.providers.default_provider.as_str() {
|
||||
"anthropic" => {
|
||||
if let Some(ref mut anthropic) = config.providers.anthropic {
|
||||
anthropic.model = player_model.clone();
|
||||
} else {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Player provider 'anthropic' is not configured. Please add anthropic configuration to your config file."
|
||||
));
|
||||
}
|
||||
}
|
||||
"databricks" => {
|
||||
if let Some(ref mut databricks) = config.providers.databricks {
|
||||
databricks.model = player_model.clone();
|
||||
} else {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Player provider 'databricks' is not configured. Please add databricks configuration to your config file."
|
||||
));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user