Added --provider and --model flags

This commit is contained in:
Dhanji Prasanna
2025-10-12 17:05:46 +11:00
parent 037bff7021
commit 318355e864
3 changed files with 94 additions and 4 deletions

View File

@@ -62,6 +62,14 @@ pub struct Cli {
/// Color theme for retro mode (default, dracula, or path to theme file) /// Color theme for retro mode (default, dracula, or path to theme file)
#[arg(long, value_name = "THEME")] #[arg(long, value_name = "THEME")]
pub theme: Option<String>, pub theme: Option<String>,
/// Override the configured provider (anthropic, databricks, embedded, openai)
#[arg(long, value_name = "PROVIDER")]
pub provider: Option<String>,
/// Override the model for the selected provider
#[arg(long, value_name = "MODEL")]
pub model: Option<String>,
} }
pub async fn run() -> Result<()> { pub async fn run() -> Result<()> {
@@ -140,8 +148,23 @@ pub async fn run() -> Result<()> {
info!("Using workspace: {}", project.workspace().display()); info!("Using workspace: {}", project.workspace().display());
} }
// Load configuration // Load configuration with CLI overrides
let config = Config::load(cli.config.as_deref())?; let config = Config::load_with_overrides(
cli.config.as_deref(),
cli.provider.clone(),
cli.model.clone(),
)?;
// Validate provider if specified
if let Some(ref provider) = cli.provider {
let valid_providers = ["anthropic", "databricks", "embedded", "openai"];
if !valid_providers.contains(&provider.as_str()) {
return Err(anyhow::anyhow!(
"Invalid provider '{}'. Valid options: {:?}",
provider, valid_providers
));
}
}
// Initialize agent // Initialize agent
let ui_writer = ConsoleUiWriter::new(); let ui_writer = ConsoleUiWriter::new();
@@ -184,7 +207,7 @@ pub async fn run() -> Result<()> {
if cli.retro { if cli.retro {
// Use retro terminal UI // Use retro terminal UI
run_interactive_retro( run_interactive_retro(
config, config, // Already has overrides applied
cli.show_prompt, cli.show_prompt,
cli.show_code, cli.show_code,
cli.theme, cli.theme,
@@ -1100,7 +1123,8 @@ async fn run_autonomous(
} }
// Create a new agent instance for coach mode to ensure fresh context // Create a new agent instance for coach mode to ensure fresh context
let config = g3_config::Config::load(None)?; // Use the same config with overrides that was passed to the player agent
let config = agent.get_config().clone();
let ui_writer = ConsoleUiWriter::new(); let ui_writer = ConsoleUiWriter::new();
let mut coach_agent = Agent::new_autonomous(config, ui_writer).await?; let mut coach_agent = Agent::new_autonomous(config, ui_writer).await?;

View File

@@ -202,4 +202,64 @@ impl Config {
std::fs::write(path, toml_string)?; std::fs::write(path, toml_string)?;
Ok(()) Ok(())
} }
pub fn load_with_overrides(
config_path: Option<&str>,
provider_override: Option<String>,
model_override: Option<String>,
) -> Result<Self> {
// Load the base configuration
let mut config = Self::load(config_path)?;
// Apply provider override
if let Some(provider) = provider_override {
config.providers.default_provider = provider;
}
// Apply model override to the active provider
if let Some(model) = model_override {
match config.providers.default_provider.as_str() {
"anthropic" => {
if let Some(ref mut anthropic) = config.providers.anthropic {
anthropic.model = model;
} else {
return Err(anyhow::anyhow!(
"Provider 'anthropic' is not configured. Please add anthropic configuration to your config file."
));
}
}
"databricks" => {
if let Some(ref mut databricks) = config.providers.databricks {
databricks.model = model;
} else {
return Err(anyhow::anyhow!(
"Provider 'databricks' is not configured. Please add databricks configuration to your config file."
));
}
}
"embedded" => {
if let Some(ref mut embedded) = config.providers.embedded {
embedded.model_path = model;
} else {
return Err(anyhow::anyhow!(
"Provider 'embedded' is not configured. Please add embedded configuration to your config file."
));
}
}
"openai" => {
if let Some(ref mut openai) = config.providers.openai {
openai.model = model;
} else {
return Err(anyhow::anyhow!(
"Provider 'openai' is not configured. Please add openai configuration to your config file."
));
}
}
_ => return Err(anyhow::anyhow!("Unknown provider: {}",
config.providers.default_provider)),
}
}
Ok(config)
}
} }

View File

@@ -412,6 +412,7 @@ Format this as a detailed but concise summary that can be used to resume the con
pub struct Agent<W: UiWriter> { pub struct Agent<W: UiWriter> {
providers: ProviderRegistry, providers: ProviderRegistry,
context_window: ContextWindow, context_window: ContextWindow,
config: Config,
session_id: Option<String>, session_id: Option<String>,
tool_call_metrics: Vec<(String, Duration, bool)>, // (tool_name, duration, success) tool_call_metrics: Vec<(String, Duration, bool)>, // (tool_name, duration, success)
ui_writer: W, ui_writer: W,
@@ -549,6 +550,7 @@ impl<W: UiWriter> Agent<W> {
Ok(Self { Ok(Self {
providers, providers,
context_window, context_window,
config,
session_id: None, session_id: None,
tool_call_metrics: Vec::new(), tool_call_metrics: Vec::new(),
ui_writer, ui_writer,
@@ -959,6 +961,10 @@ The tool will execute immediately and you'll receive the result (success or erro
&self.tool_call_metrics &self.tool_call_metrics
} }
pub fn get_config(&self) -> &Config {
&self.config
}
async fn stream_completion( async fn stream_completion(
&mut self, &mut self,
request: CompletionRequest, request: CompletionRequest,