coach/player provider split + add OpenAI
Allows coach and player LLM providers to be separately specified. Also adds OpenAI provider
This commit is contained in:
@@ -12,3 +12,6 @@ thiserror = { workspace = true }
|
||||
toml = "0.8"
|
||||
shellexpand = "3.0"
|
||||
dirs = "5.0"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.8"
|
||||
|
||||
@@ -17,6 +17,8 @@ pub struct ProvidersConfig {
|
||||
pub databricks: Option<DatabricksConfig>,
|
||||
pub embedded: Option<EmbeddedConfig>,
|
||||
pub default_provider: String,
|
||||
pub coach: Option<String>, // Provider to use for coach in autonomous mode
|
||||
pub player: Option<String>, // Provider to use for player in autonomous mode
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -112,6 +114,8 @@ impl Default for Config {
|
||||
}),
|
||||
embedded: None,
|
||||
default_provider: "databricks".to_string(),
|
||||
coach: None, // Will use default_provider if not specified
|
||||
player: None, // Will use default_provider if not specified
|
||||
},
|
||||
agent: AgentConfig {
|
||||
max_context_length: 8192,
|
||||
@@ -224,6 +228,8 @@ impl Config {
|
||||
threads: Some(8),
|
||||
}),
|
||||
default_provider: "embedded".to_string(),
|
||||
coach: None, // Will use default_provider if not specified
|
||||
player: None, // Will use default_provider if not specified
|
||||
},
|
||||
agent: AgentConfig {
|
||||
max_context_length: 8192,
|
||||
@@ -300,4 +306,67 @@ impl Config {
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Get the provider to use for coach mode in autonomous execution
|
||||
pub fn get_coach_provider(&self) -> &str {
|
||||
self.providers.coach
|
||||
.as_deref()
|
||||
.unwrap_or(&self.providers.default_provider)
|
||||
}
|
||||
|
||||
/// Get the provider to use for player mode in autonomous execution
|
||||
pub fn get_player_provider(&self) -> &str {
|
||||
self.providers.player
|
||||
.as_deref()
|
||||
.unwrap_or(&self.providers.default_provider)
|
||||
}
|
||||
|
||||
/// Create a copy of the config with a different default provider
|
||||
pub fn with_provider_override(&self, provider: &str) -> Result<Self> {
|
||||
// Validate that the provider is configured
|
||||
match provider {
|
||||
"anthropic" if self.providers.anthropic.is_none() => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
|
||||
provider, provider
|
||||
));
|
||||
}
|
||||
"databricks" if self.providers.databricks.is_none() => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
|
||||
provider, provider
|
||||
));
|
||||
}
|
||||
"embedded" if self.providers.embedded.is_none() => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
|
||||
provider, provider
|
||||
));
|
||||
}
|
||||
"openai" if self.providers.openai.is_none() => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
|
||||
provider, provider
|
||||
));
|
||||
}
|
||||
_ => {} // Provider is configured or unknown (will be caught later)
|
||||
}
|
||||
|
||||
let mut config = self.clone();
|
||||
config.providers.default_provider = provider.to_string();
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Create a copy of the config for coach mode in autonomous execution
|
||||
pub fn for_coach(&self) -> Result<Self> {
|
||||
self.with_provider_override(self.get_coach_provider())
|
||||
}
|
||||
|
||||
/// Create a copy of the config for player mode in autonomous execution
|
||||
pub fn for_player(&self) -> Result<Self> {
|
||||
self.with_provider_override(self.get_player_provider())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
131
crates/g3-config/src/tests.rs
Normal file
131
crates/g3-config/src/tests.rs
Normal file
@@ -0,0 +1,131 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::Config;
|
||||
use std::fs;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
fn test_coach_player_providers() {
|
||||
// Create a temporary directory for the test config
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let config_path = temp_dir.path().join("test_config.toml");
|
||||
|
||||
// Write a test configuration with coach and player providers
|
||||
let config_content = r#"
|
||||
[providers]
|
||||
default_provider = "databricks"
|
||||
coach = "anthropic"
|
||||
player = "embedded"
|
||||
|
||||
[providers.databricks]
|
||||
host = "https://test.databricks.com"
|
||||
token = "test-token"
|
||||
model = "test-model"
|
||||
|
||||
[providers.anthropic]
|
||||
api_key = "test-key"
|
||||
model = "claude-3"
|
||||
|
||||
[providers.embedded]
|
||||
model_path = "test.gguf"
|
||||
model_type = "llama"
|
||||
|
||||
[agent]
|
||||
max_context_length = 8192
|
||||
enable_streaming = true
|
||||
timeout_seconds = 60
|
||||
"#;
|
||||
|
||||
fs::write(&config_path, config_content).unwrap();
|
||||
|
||||
// Load the configuration
|
||||
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
|
||||
|
||||
// Test that the providers are correctly identified
|
||||
assert_eq!(config.providers.default_provider, "databricks");
|
||||
assert_eq!(config.get_coach_provider(), "anthropic");
|
||||
assert_eq!(config.get_player_provider(), "embedded");
|
||||
|
||||
// Test creating coach config
|
||||
let coach_config = config.for_coach().unwrap();
|
||||
assert_eq!(coach_config.providers.default_provider, "anthropic");
|
||||
|
||||
// Test creating player config
|
||||
let player_config = config.for_player().unwrap();
|
||||
assert_eq!(player_config.providers.default_provider, "embedded");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_coach_player_fallback_to_default() {
|
||||
// Create a temporary directory for the test config
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let config_path = temp_dir.path().join("test_config.toml");
|
||||
|
||||
// Write a test configuration WITHOUT coach and player providers
|
||||
let config_content = r#"
|
||||
[providers]
|
||||
default_provider = "databricks"
|
||||
|
||||
[providers.databricks]
|
||||
host = "https://test.databricks.com"
|
||||
token = "test-token"
|
||||
model = "test-model"
|
||||
|
||||
[agent]
|
||||
max_context_length = 8192
|
||||
enable_streaming = true
|
||||
timeout_seconds = 60
|
||||
"#;
|
||||
|
||||
fs::write(&config_path, config_content).unwrap();
|
||||
|
||||
// Load the configuration
|
||||
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
|
||||
|
||||
// Test that coach and player fall back to default provider
|
||||
assert_eq!(config.get_coach_provider(), "databricks");
|
||||
assert_eq!(config.get_player_provider(), "databricks");
|
||||
|
||||
// Test creating coach config (should use default)
|
||||
let coach_config = config.for_coach().unwrap();
|
||||
assert_eq!(coach_config.providers.default_provider, "databricks");
|
||||
|
||||
// Test creating player config (should use default)
|
||||
let player_config = config.for_player().unwrap();
|
||||
assert_eq!(player_config.providers.default_provider, "databricks");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_provider_error() {
|
||||
// Create a temporary directory for the test config
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let config_path = temp_dir.path().join("test_config.toml");
|
||||
|
||||
// Write a test configuration with an unconfigured provider
|
||||
let config_content = r#"
|
||||
[providers]
|
||||
default_provider = "databricks"
|
||||
coach = "openai" # OpenAI is not configured
|
||||
|
||||
[providers.databricks]
|
||||
host = "https://test.databricks.com"
|
||||
token = "test-token"
|
||||
model = "test-model"
|
||||
|
||||
[agent]
|
||||
max_context_length = 8192
|
||||
enable_streaming = true
|
||||
timeout_seconds = 60
|
||||
"#;
|
||||
|
||||
fs::write(&config_path, config_content).unwrap();
|
||||
|
||||
// Load the configuration
|
||||
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
|
||||
|
||||
// Test that trying to create a coach config with unconfigured provider fails
|
||||
let result = config.for_coach();
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("not configured"));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user