Implement planning mode
This commit is contained in:
@@ -40,6 +40,18 @@ use std::time::{Duration, Instant};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
/// Get the path to the todo.g3.md file.
|
||||
///
|
||||
/// Checks for G3_TODO_PATH environment variable first (used by planning mode),
|
||||
/// then falls back to todo.g3.md in the current directory.
|
||||
fn get_todo_path() -> std::path::PathBuf {
|
||||
if let Ok(custom_path) = std::env::var("G3_TODO_PATH") {
|
||||
std::path::PathBuf::from(custom_path)
|
||||
} else {
|
||||
std::env::current_dir().unwrap_or_default().join("todo.g3.md")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
pub tool: String,
|
||||
@@ -1119,12 +1131,18 @@ impl<W: UiWriter> Agent<W> {
|
||||
vec![config.providers.default_provider.clone()]
|
||||
};
|
||||
|
||||
// Only register providers that are configured AND selected as the default provider
|
||||
// Only register providers that are configured AND selected
|
||||
// This prevents unnecessary initialization of heavy providers like embedded models
|
||||
|
||||
// Register embedded provider if configured AND it's the default provider
|
||||
if let Some(embedded_config) = &config.providers.embedded {
|
||||
if providers_to_register.contains(&"embedded".to_string()) {
|
||||
// Helper to check if a provider ref should be registered
|
||||
let should_register = |provider_type: &str, config_name: &str| -> bool {
|
||||
let full_ref = format!("{}.{}", provider_type, config_name);
|
||||
providers_to_register.iter().any(|p| p == &full_ref || p.starts_with(&format!("{}.", provider_type)))
|
||||
};
|
||||
|
||||
// Register embedded providers from HashMap
|
||||
for (name, embedded_config) in &config.providers.embedded {
|
||||
if should_register("embedded", name) {
|
||||
let embedded_provider = g3_providers::EmbeddedProvider::new(
|
||||
embedded_config.model_path.clone(),
|
||||
embedded_config.model_type.clone(),
|
||||
@@ -1138,10 +1156,11 @@ impl<W: UiWriter> Agent<W> {
|
||||
}
|
||||
}
|
||||
|
||||
// Register OpenAI provider if configured AND it's the default provider
|
||||
if let Some(openai_config) = &config.providers.openai {
|
||||
if providers_to_register.contains(&"openai".to_string()) {
|
||||
let openai_provider = g3_providers::OpenAIProvider::new(
|
||||
// Register OpenAI providers from HashMap
|
||||
for (name, openai_config) in &config.providers.openai {
|
||||
if should_register("openai", name) {
|
||||
let openai_provider = g3_providers::OpenAIProvider::new_with_name(
|
||||
format!("openai.{}", name),
|
||||
openai_config.api_key.clone(),
|
||||
Some(openai_config.model.clone()),
|
||||
openai_config.base_url.clone(),
|
||||
@@ -1154,7 +1173,7 @@ impl<W: UiWriter> Agent<W> {
|
||||
|
||||
// Register OpenAI-compatible providers (e.g., OpenRouter, Groq, etc.)
|
||||
for (name, openai_config) in &config.providers.openai_compatible {
|
||||
if providers_to_register.contains(name) {
|
||||
if should_register(name, "default") {
|
||||
let openai_provider = g3_providers::OpenAIProvider::new_with_name(
|
||||
name.clone(),
|
||||
openai_config.api_key.clone(),
|
||||
@@ -1167,10 +1186,11 @@ impl<W: UiWriter> Agent<W> {
|
||||
}
|
||||
}
|
||||
|
||||
// Register Anthropic provider if configured AND it's the default provider
|
||||
if let Some(anthropic_config) = &config.providers.anthropic {
|
||||
if providers_to_register.contains(&"anthropic".to_string()) {
|
||||
let anthropic_provider = g3_providers::AnthropicProvider::new(
|
||||
// Register Anthropic providers from HashMap
|
||||
for (name, anthropic_config) in &config.providers.anthropic {
|
||||
if should_register("anthropic", name) {
|
||||
let anthropic_provider = g3_providers::AnthropicProvider::new_with_name(
|
||||
format!("anthropic.{}", name),
|
||||
anthropic_config.api_key.clone(),
|
||||
Some(anthropic_config.model.clone()),
|
||||
anthropic_config.max_tokens,
|
||||
@@ -1183,12 +1203,13 @@ impl<W: UiWriter> Agent<W> {
|
||||
}
|
||||
}
|
||||
|
||||
// Register Databricks provider if configured AND it's the default provider
|
||||
if let Some(databricks_config) = &config.providers.databricks {
|
||||
if providers_to_register.contains(&"databricks".to_string()) {
|
||||
// Register Databricks providers from HashMap
|
||||
for (name, databricks_config) in &config.providers.databricks {
|
||||
if should_register("databricks", name) {
|
||||
let databricks_provider = if let Some(token) = &databricks_config.token {
|
||||
// Use token-based authentication
|
||||
g3_providers::DatabricksProvider::from_token(
|
||||
g3_providers::DatabricksProvider::from_token_with_name(
|
||||
format!("databricks.{}", name),
|
||||
databricks_config.host.clone(),
|
||||
token.clone(),
|
||||
databricks_config.model.clone(),
|
||||
@@ -1197,7 +1218,8 @@ impl<W: UiWriter> Agent<W> {
|
||||
)?
|
||||
} else {
|
||||
// Use OAuth authentication
|
||||
g3_providers::DatabricksProvider::from_oauth(
|
||||
g3_providers::DatabricksProvider::from_oauth_with_name(
|
||||
format!("databricks.{}", name),
|
||||
databricks_config.host.clone(),
|
||||
databricks_config.model.clone(),
|
||||
databricks_config.max_tokens,
|
||||
@@ -1253,13 +1275,9 @@ impl<W: UiWriter> Agent<W> {
|
||||
}
|
||||
|
||||
// Load existing TODO list if present (after system prompt and README)
|
||||
let todo_path = std::env::current_dir().ok().map(|p| p.join("todo.g3.md"));
|
||||
let initial_todo_content = if let Some(ref path) = todo_path {
|
||||
if path.exists() {
|
||||
std::fs::read_to_string(path).ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
let todo_path = get_todo_path();
|
||||
let initial_todo_content = if todo_path.exists() {
|
||||
std::fs::read_to_string(&todo_path).ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -1304,13 +1322,8 @@ impl<W: UiWriter> Agent<W> {
|
||||
ui_writer,
|
||||
todo_content: std::sync::Arc::new(tokio::sync::RwLock::new({
|
||||
// Initialize from TODO.md file if it exists
|
||||
let todo_path = std::env::current_dir().ok().map(|p| p.join("todo.g3.md"));
|
||||
|
||||
if let Some(path) = todo_path {
|
||||
std::fs::read_to_string(&path).unwrap_or_default()
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
let todo_path = get_todo_path();
|
||||
std::fs::read_to_string(&todo_path).unwrap_or_default()
|
||||
})),
|
||||
is_autonomous,
|
||||
quiet,
|
||||
@@ -1386,22 +1399,40 @@ impl<W: UiWriter> Agent<W> {
|
||||
|
||||
/// Get the configured max_tokens for a provider from top-level config
|
||||
fn provider_max_tokens(config: &Config, provider_name: &str) -> Option<u32> {
|
||||
match provider_name {
|
||||
"anthropic" => config.providers.anthropic.as_ref()?.max_tokens,
|
||||
"openai" => config.providers.openai.as_ref()?.max_tokens,
|
||||
"databricks" => config.providers.databricks.as_ref()?.max_tokens,
|
||||
"embedded" => config.providers.embedded.as_ref()?.max_tokens,
|
||||
// Parse provider reference (format: "provider_type.config_name")
|
||||
let parts: Vec<&str> = provider_name.split('.').collect();
|
||||
let (provider_type, config_name) = if parts.len() == 2 {
|
||||
(parts[0], parts[1])
|
||||
} else {
|
||||
// Fallback for simple provider names - assume "default" config
|
||||
(provider_name, "default")
|
||||
};
|
||||
|
||||
match provider_type {
|
||||
"anthropic" => config.providers.anthropic.get(config_name)?.max_tokens,
|
||||
"openai" => config.providers.openai.get(config_name)?.max_tokens,
|
||||
"databricks" => config.providers.databricks.get(config_name)?.max_tokens,
|
||||
"embedded" => config.providers.embedded.get(config_name)?.max_tokens,
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the configured temperature for a provider from top-level config
|
||||
fn provider_temperature(config: &Config, provider_name: &str) -> Option<f32> {
|
||||
match provider_name {
|
||||
"anthropic" => config.providers.anthropic.as_ref()?.temperature,
|
||||
"openai" => config.providers.openai.as_ref()?.temperature,
|
||||
"databricks" => config.providers.databricks.as_ref()?.temperature,
|
||||
"embedded" => config.providers.embedded.as_ref()?.temperature,
|
||||
// Parse provider reference (format: "provider_type.config_name")
|
||||
let parts: Vec<&str> = provider_name.split('.').collect();
|
||||
let (provider_type, config_name) = if parts.len() == 2 {
|
||||
(parts[0], parts[1])
|
||||
} else {
|
||||
// Fallback for simple provider names - assume "default" config
|
||||
(provider_name, "default")
|
||||
};
|
||||
|
||||
match provider_type {
|
||||
"anthropic" => config.providers.anthropic.get(config_name)?.temperature,
|
||||
"openai" => config.providers.openai.get(config_name)?.temperature,
|
||||
"databricks" => config.providers.databricks.get(config_name)?.temperature,
|
||||
"embedded" => config.providers.embedded.get(config_name)?.temperature,
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@@ -1430,11 +1461,23 @@ impl<W: UiWriter> Agent<W> {
|
||||
}
|
||||
|
||||
/// Get the thinking budget tokens for Anthropic provider, if configured
|
||||
fn get_thinking_budget_tokens(&self) -> Option<u32> {
|
||||
self.config
|
||||
.providers
|
||||
.anthropic
|
||||
.as_ref()
|
||||
fn get_thinking_budget_tokens(&self, provider_name: &str) -> Option<u32> {
|
||||
// Parse provider reference (format: "provider_type.config_name")
|
||||
let parts: Vec<&str> = provider_name.split('.').collect();
|
||||
let (provider_type, config_name) = if parts.len() == 2 {
|
||||
(parts[0], parts[1])
|
||||
} else {
|
||||
// Fallback for simple provider names - assume "default" config
|
||||
(provider_name, "default")
|
||||
};
|
||||
|
||||
// Only Anthropic has thinking_budget_tokens
|
||||
if provider_type != "anthropic" {
|
||||
return None;
|
||||
}
|
||||
|
||||
self.config.providers.anthropic
|
||||
.get(config_name)
|
||||
.and_then(|c| c.thinking_budget_tokens)
|
||||
}
|
||||
|
||||
@@ -1448,12 +1491,15 @@ impl<W: UiWriter> Agent<W> {
|
||||
provider_name: &str,
|
||||
proposed_max_tokens: u32,
|
||||
) -> (u32, bool) {
|
||||
// Only applies to Anthropic provider with thinking enabled
|
||||
if provider_name != "anthropic" {
|
||||
// Parse provider type from provider_name (format: "provider_type.config_name")
|
||||
let provider_type = provider_name.split('.').next().unwrap_or(provider_name);
|
||||
|
||||
// Only applies to Anthropic provider
|
||||
if provider_type != "anthropic" {
|
||||
return (proposed_max_tokens, false);
|
||||
}
|
||||
|
||||
let budget_tokens = match self.get_thinking_budget_tokens() {
|
||||
let budget_tokens = match self.get_thinking_budget_tokens(provider_name) {
|
||||
Some(budget) => budget,
|
||||
None => return (proposed_max_tokens, false), // No thinking enabled
|
||||
};
|
||||
@@ -1702,14 +1748,23 @@ impl<W: UiWriter> Agent<W> {
|
||||
let provider_name = provider.name();
|
||||
let model_name = provider.model();
|
||||
|
||||
// Use provider-specific context length if available, otherwise fall back to agent config
|
||||
let context_length = match provider_name {
|
||||
"embedded" => {
|
||||
// Parse provider name to get type and config name
|
||||
let parts: Vec<&str> = provider_name.split('.').collect();
|
||||
let (provider_type, config_name) = if parts.len() == 2 {
|
||||
(parts[0], parts[1])
|
||||
} else {
|
||||
// Fallback for simple provider names
|
||||
(provider_name, "default")
|
||||
};
|
||||
|
||||
// Use provider-specific context length if available
|
||||
let context_length = match provider_type {
|
||||
"embedded" | "embedded." => {
|
||||
// For embedded models, use the configured context_length or model-specific defaults
|
||||
if let Some(embedded_config) = &config.providers.embedded {
|
||||
if let Some(embedded_config) = config.providers.embedded.get(config_name) {
|
||||
embedded_config.context_length.unwrap_or_else(|| {
|
||||
// Model-specific defaults for embedded models
|
||||
match embedded_config.model_type.to_lowercase().as_str() {
|
||||
match &embedded_config.model_type.to_lowercase()[..] {
|
||||
"codellama" => 16384, // CodeLlama supports 16k context
|
||||
"llama" => 4096, // Base Llama models
|
||||
"mistral" => 8192, // Mistral models
|
||||
@@ -1722,11 +1777,11 @@ impl<W: UiWriter> Agent<W> {
|
||||
}
|
||||
}
|
||||
"openai" => {
|
||||
// gpt-5 has 400k window
|
||||
if let Some(max_tokens) = Self::provider_max_tokens(config, "openai") {
|
||||
// OpenAI models have varying context windows
|
||||
if let Some(max_tokens) = Self::provider_max_tokens(config, provider_name) {
|
||||
warnings.push(format!(
|
||||
"Context length falling back to max_tokens ({}) for provider=openai",
|
||||
max_tokens
|
||||
"Context length falling back to max_tokens ({}) for provider={}",
|
||||
max_tokens, provider_name
|
||||
));
|
||||
max_tokens
|
||||
} else {
|
||||
@@ -1735,11 +1790,10 @@ impl<W: UiWriter> Agent<W> {
|
||||
}
|
||||
"anthropic" => {
|
||||
// Claude models have large context windows
|
||||
// Use configured max_tokens or fall back to default
|
||||
if let Some(max_tokens) = Self::provider_max_tokens(config, "anthropic") {
|
||||
if let Some(max_tokens) = Self::provider_max_tokens(config, provider_name) {
|
||||
warnings.push(format!(
|
||||
"Context length falling back to max_tokens ({}) for provider=anthropic",
|
||||
max_tokens
|
||||
"Context length falling back to max_tokens ({}) for provider={}",
|
||||
max_tokens, provider_name
|
||||
));
|
||||
max_tokens
|
||||
} else {
|
||||
@@ -1748,11 +1802,10 @@ impl<W: UiWriter> Agent<W> {
|
||||
}
|
||||
"databricks" => {
|
||||
// Databricks models have varying context windows depending on the model
|
||||
// Use configured max_tokens or fall back to model-specific defaults
|
||||
if let Some(max_tokens) = Self::provider_max_tokens(config, "databricks") {
|
||||
if let Some(max_tokens) = Self::provider_max_tokens(config, provider_name) {
|
||||
warnings.push(format!(
|
||||
"Context length falling back to max_tokens ({}) for provider=databricks",
|
||||
max_tokens
|
||||
"Context length falling back to max_tokens ({}) for provider={}",
|
||||
max_tokens, provider_name
|
||||
));
|
||||
max_tokens
|
||||
} else if model_name.contains("claude") {
|
||||
@@ -1948,14 +2001,18 @@ impl<W: UiWriter> Agent<W> {
|
||||
// Check if we should use cache control (every 10 tool calls)
|
||||
// But only if we haven't already added 4 cache_control annotations
|
||||
let provider = self.providers.get(None)?;
|
||||
if let Some(cache_config) = match provider.name() {
|
||||
"anthropic" => self
|
||||
.config
|
||||
.providers
|
||||
.anthropic
|
||||
.as_ref()
|
||||
.and_then(|c| c.cache_config.as_ref())
|
||||
.and_then(|config| Self::parse_cache_control(config)),
|
||||
let provider_name = provider.name();
|
||||
let provider_type = provider_name.split('.').next().unwrap_or("");
|
||||
let config_name = provider_name.split('.').nth(1).unwrap_or("default");
|
||||
if let Some(cache_config) = match provider_type {
|
||||
"anthropic" => {
|
||||
self.config
|
||||
.providers
|
||||
.anthropic
|
||||
.get(config_name)
|
||||
.and_then(|c| c.cache_config.as_ref())
|
||||
.and_then(|config| Self::parse_cache_control(config))
|
||||
}
|
||||
_ => None,
|
||||
} {
|
||||
Message::with_cache_control_validated(
|
||||
@@ -2451,7 +2508,7 @@ impl<W: UiWriter> Agent<W> {
|
||||
// Apply provider-specific caps
|
||||
// For Anthropic with thinking enabled, we need max_tokens > thinking.budget_tokens
|
||||
// So we set a higher cap when thinking is configured
|
||||
let anthropic_cap = match self.get_thinking_budget_tokens() {
|
||||
let anthropic_cap = match self.get_thinking_budget_tokens(&provider_name) {
|
||||
Some(budget) => (budget + 2000).max(10_000), // At least budget + 2000 for response
|
||||
None => 10_000,
|
||||
};
|
||||
@@ -3485,7 +3542,7 @@ impl<W: UiWriter> Agent<W> {
|
||||
// Apply provider-specific caps
|
||||
// For Anthropic with thinking enabled, we need max_tokens > thinking.budget_tokens
|
||||
// So we set a higher cap when thinking is configured
|
||||
let anthropic_cap = match self.get_thinking_budget_tokens() {
|
||||
let anthropic_cap = match self.get_thinking_budget_tokens(&provider_name) {
|
||||
Some(budget) => (budget + 2000).max(10_000), // At least budget + 2000 for response
|
||||
None => 10_000,
|
||||
};
|
||||
@@ -4078,14 +4135,18 @@ impl<W: UiWriter> Agent<W> {
|
||||
&& self.count_cache_controls_in_history() < 4
|
||||
{
|
||||
let provider = self.providers.get(None)?;
|
||||
if let Some(cache_config) = match provider.name() {
|
||||
"anthropic" => self
|
||||
.config
|
||||
.providers
|
||||
.anthropic
|
||||
.as_ref()
|
||||
.and_then(|c| c.cache_config.as_ref())
|
||||
.and_then(|config| Self::parse_cache_control(config)),
|
||||
let provider_name = provider.name();
|
||||
let provider_type = provider_name.split('.').next().unwrap_or("");
|
||||
let config_name = provider_name.split('.').nth(1).unwrap_or("default");
|
||||
if let Some(cache_config) = match provider_type {
|
||||
"anthropic" => {
|
||||
self.config
|
||||
.providers
|
||||
.anthropic
|
||||
.get(config_name)
|
||||
.and_then(|c| c.cache_config.as_ref())
|
||||
.and_then(|config| Self::parse_cache_control(config))
|
||||
}
|
||||
_ => None,
|
||||
} {
|
||||
Message::with_cache_control_validated(
|
||||
@@ -5118,8 +5179,8 @@ impl<W: UiWriter> Agent<W> {
|
||||
}
|
||||
"todo_read" => {
|
||||
debug!("Processing todo_read tool call");
|
||||
// Read from todo.g3.md file in current workspace directory
|
||||
let todo_path = std::env::current_dir()?.join("todo.g3.md");
|
||||
// Read from todo.g3.md file (uses G3_TODO_PATH env var if set, else current dir)
|
||||
let todo_path = get_todo_path();
|
||||
|
||||
if !todo_path.exists() {
|
||||
// Also update in-memory content to stay in sync
|
||||
@@ -5233,7 +5294,7 @@ impl<W: UiWriter> Agent<W> {
|
||||
|
||||
// If all todos are complete, delete the file instead of writing
|
||||
if !has_incomplete && (content_str.contains("- [x]") || content_str.contains("- [X]")) {
|
||||
let todo_path = std::env::current_dir()?.join("todo.g3.md");
|
||||
let todo_path = get_todo_path();
|
||||
if todo_path.exists() {
|
||||
match std::fs::remove_file(&todo_path) {
|
||||
Ok(_) => {
|
||||
@@ -5253,8 +5314,8 @@ impl<W: UiWriter> Agent<W> {
|
||||
}
|
||||
}
|
||||
|
||||
// Write to todo.g3.md file in current workspace directory
|
||||
let todo_path = std::env::current_dir()?.join("todo.g3.md");
|
||||
// Write to todo.g3.md file (uses G3_TODO_PATH env var if set, else current dir)
|
||||
let todo_path = get_todo_path();
|
||||
|
||||
match std::fs::write(&todo_path, content_str) {
|
||||
Ok(_) => {
|
||||
|
||||
Reference in New Issue
Block a user