Added --provider and --model flags
This commit is contained in:
@@ -62,6 +62,14 @@ pub struct Cli {
|
|||||||
/// Color theme for retro mode (default, dracula, or path to theme file)
|
/// Color theme for retro mode (default, dracula, or path to theme file)
|
||||||
#[arg(long, value_name = "THEME")]
|
#[arg(long, value_name = "THEME")]
|
||||||
pub theme: Option<String>,
|
pub theme: Option<String>,
|
||||||
|
|
||||||
|
/// Override the configured provider (anthropic, databricks, embedded, openai)
|
||||||
|
#[arg(long, value_name = "PROVIDER")]
|
||||||
|
pub provider: Option<String>,
|
||||||
|
|
||||||
|
/// Override the model for the selected provider
|
||||||
|
#[arg(long, value_name = "MODEL")]
|
||||||
|
pub model: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run() -> Result<()> {
|
pub async fn run() -> Result<()> {
|
||||||
@@ -140,8 +148,23 @@ pub async fn run() -> Result<()> {
|
|||||||
info!("Using workspace: {}", project.workspace().display());
|
info!("Using workspace: {}", project.workspace().display());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load configuration
|
// Load configuration with CLI overrides
|
||||||
let config = Config::load(cli.config.as_deref())?;
|
let config = Config::load_with_overrides(
|
||||||
|
cli.config.as_deref(),
|
||||||
|
cli.provider.clone(),
|
||||||
|
cli.model.clone(),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Validate provider if specified
|
||||||
|
if let Some(ref provider) = cli.provider {
|
||||||
|
let valid_providers = ["anthropic", "databricks", "embedded", "openai"];
|
||||||
|
if !valid_providers.contains(&provider.as_str()) {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"Invalid provider '{}'. Valid options: {:?}",
|
||||||
|
provider, valid_providers
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize agent
|
// Initialize agent
|
||||||
let ui_writer = ConsoleUiWriter::new();
|
let ui_writer = ConsoleUiWriter::new();
|
||||||
@@ -184,7 +207,7 @@ pub async fn run() -> Result<()> {
|
|||||||
if cli.retro {
|
if cli.retro {
|
||||||
// Use retro terminal UI
|
// Use retro terminal UI
|
||||||
run_interactive_retro(
|
run_interactive_retro(
|
||||||
config,
|
config, // Already has overrides applied
|
||||||
cli.show_prompt,
|
cli.show_prompt,
|
||||||
cli.show_code,
|
cli.show_code,
|
||||||
cli.theme,
|
cli.theme,
|
||||||
@@ -1100,7 +1123,8 @@ async fn run_autonomous(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a new agent instance for coach mode to ensure fresh context
|
// Create a new agent instance for coach mode to ensure fresh context
|
||||||
let config = g3_config::Config::load(None)?;
|
// Use the same config with overrides that was passed to the player agent
|
||||||
|
let config = agent.get_config().clone();
|
||||||
let ui_writer = ConsoleUiWriter::new();
|
let ui_writer = ConsoleUiWriter::new();
|
||||||
let mut coach_agent = Agent::new_autonomous(config, ui_writer).await?;
|
let mut coach_agent = Agent::new_autonomous(config, ui_writer).await?;
|
||||||
|
|
||||||
|
|||||||
@@ -202,4 +202,64 @@ impl Config {
|
|||||||
std::fs::write(path, toml_string)?;
|
std::fs::write(path, toml_string)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn load_with_overrides(
|
||||||
|
config_path: Option<&str>,
|
||||||
|
provider_override: Option<String>,
|
||||||
|
model_override: Option<String>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
// Load the base configuration
|
||||||
|
let mut config = Self::load(config_path)?;
|
||||||
|
|
||||||
|
// Apply provider override
|
||||||
|
if let Some(provider) = provider_override {
|
||||||
|
config.providers.default_provider = provider;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply model override to the active provider
|
||||||
|
if let Some(model) = model_override {
|
||||||
|
match config.providers.default_provider.as_str() {
|
||||||
|
"anthropic" => {
|
||||||
|
if let Some(ref mut anthropic) = config.providers.anthropic {
|
||||||
|
anthropic.model = model;
|
||||||
|
} else {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"Provider 'anthropic' is not configured. Please add anthropic configuration to your config file."
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"databricks" => {
|
||||||
|
if let Some(ref mut databricks) = config.providers.databricks {
|
||||||
|
databricks.model = model;
|
||||||
|
} else {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"Provider 'databricks' is not configured. Please add databricks configuration to your config file."
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"embedded" => {
|
||||||
|
if let Some(ref mut embedded) = config.providers.embedded {
|
||||||
|
embedded.model_path = model;
|
||||||
|
} else {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"Provider 'embedded' is not configured. Please add embedded configuration to your config file."
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"openai" => {
|
||||||
|
if let Some(ref mut openai) = config.providers.openai {
|
||||||
|
openai.model = model;
|
||||||
|
} else {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"Provider 'openai' is not configured. Please add openai configuration to your config file."
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => return Err(anyhow::anyhow!("Unknown provider: {}",
|
||||||
|
config.providers.default_provider)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(config)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -412,6 +412,7 @@ Format this as a detailed but concise summary that can be used to resume the con
|
|||||||
pub struct Agent<W: UiWriter> {
|
pub struct Agent<W: UiWriter> {
|
||||||
providers: ProviderRegistry,
|
providers: ProviderRegistry,
|
||||||
context_window: ContextWindow,
|
context_window: ContextWindow,
|
||||||
|
config: Config,
|
||||||
session_id: Option<String>,
|
session_id: Option<String>,
|
||||||
tool_call_metrics: Vec<(String, Duration, bool)>, // (tool_name, duration, success)
|
tool_call_metrics: Vec<(String, Duration, bool)>, // (tool_name, duration, success)
|
||||||
ui_writer: W,
|
ui_writer: W,
|
||||||
@@ -549,6 +550,7 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
Ok(Self {
|
Ok(Self {
|
||||||
providers,
|
providers,
|
||||||
context_window,
|
context_window,
|
||||||
|
config,
|
||||||
session_id: None,
|
session_id: None,
|
||||||
tool_call_metrics: Vec::new(),
|
tool_call_metrics: Vec::new(),
|
||||||
ui_writer,
|
ui_writer,
|
||||||
@@ -958,6 +960,10 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
pub fn get_tool_call_metrics(&self) -> &Vec<(String, Duration, bool)> {
|
pub fn get_tool_call_metrics(&self) -> &Vec<(String, Duration, bool)> {
|
||||||
&self.tool_call_metrics
|
&self.tool_call_metrics
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn get_config(&self) -> &Config {
|
||||||
|
&self.config
|
||||||
|
}
|
||||||
|
|
||||||
async fn stream_completion(
|
async fn stream_completion(
|
||||||
&mut self,
|
&mut self,
|
||||||
|
|||||||
Reference in New Issue
Block a user