Implement planning mode

This commit is contained in:
Jochen
2025-12-09 09:59:28 +11:00
parent 4aa84e2144
commit ff8b3e7c7b
24 changed files with 3817 additions and 346 deletions

View File

@@ -40,6 +40,18 @@ use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
/// Get the path to the todo.g3.md file.
///
/// Checks for the `G3_TODO_PATH` environment variable first (used by planning
/// mode to redirect the TODO list), then falls back to `todo.g3.md` in the
/// current working directory.
///
/// A `G3_TODO_PATH` that is set but empty is treated as unset: an empty value
/// would otherwise yield an empty `PathBuf`, silently pointing every
/// subsequent todo read/write/delete at the wrong location.
fn get_todo_path() -> std::path::PathBuf {
    match std::env::var("G3_TODO_PATH") {
        // Ignore empty values so `G3_TODO_PATH=""` behaves like "not set".
        Ok(custom_path) if !custom_path.is_empty() => std::path::PathBuf::from(custom_path),
        // If the current directory is unavailable, `unwrap_or_default()` gives
        // an empty base, so the join degrades to the relative "todo.g3.md".
        _ => std::env::current_dir().unwrap_or_default().join("todo.g3.md"),
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
pub tool: String,
@@ -1119,12 +1131,18 @@ impl<W: UiWriter> Agent<W> {
vec![config.providers.default_provider.clone()]
};
// Only register providers that are configured AND selected as the default provider
// Only register providers that are configured AND selected
// This prevents unnecessary initialization of heavy providers like embedded models
// Register embedded provider if configured AND it's the default provider
if let Some(embedded_config) = &config.providers.embedded {
if providers_to_register.contains(&"embedded".to_string()) {
// Helper to check if a provider ref should be registered
let should_register = |provider_type: &str, config_name: &str| -> bool {
let full_ref = format!("{}.{}", provider_type, config_name);
providers_to_register.iter().any(|p| p == &full_ref || p.starts_with(&format!("{}.", provider_type)))
};
// Register embedded providers from HashMap
for (name, embedded_config) in &config.providers.embedded {
if should_register("embedded", name) {
let embedded_provider = g3_providers::EmbeddedProvider::new(
embedded_config.model_path.clone(),
embedded_config.model_type.clone(),
@@ -1138,10 +1156,11 @@ impl<W: UiWriter> Agent<W> {
}
}
// Register OpenAI provider if configured AND it's the default provider
if let Some(openai_config) = &config.providers.openai {
if providers_to_register.contains(&"openai".to_string()) {
let openai_provider = g3_providers::OpenAIProvider::new(
// Register OpenAI providers from HashMap
for (name, openai_config) in &config.providers.openai {
if should_register("openai", name) {
let openai_provider = g3_providers::OpenAIProvider::new_with_name(
format!("openai.{}", name),
openai_config.api_key.clone(),
Some(openai_config.model.clone()),
openai_config.base_url.clone(),
@@ -1154,7 +1173,7 @@ impl<W: UiWriter> Agent<W> {
// Register OpenAI-compatible providers (e.g., OpenRouter, Groq, etc.)
for (name, openai_config) in &config.providers.openai_compatible {
if providers_to_register.contains(name) {
if should_register(name, "default") {
let openai_provider = g3_providers::OpenAIProvider::new_with_name(
name.clone(),
openai_config.api_key.clone(),
@@ -1167,10 +1186,11 @@ impl<W: UiWriter> Agent<W> {
}
}
// Register Anthropic provider if configured AND it's the default provider
if let Some(anthropic_config) = &config.providers.anthropic {
if providers_to_register.contains(&"anthropic".to_string()) {
let anthropic_provider = g3_providers::AnthropicProvider::new(
// Register Anthropic providers from HashMap
for (name, anthropic_config) in &config.providers.anthropic {
if should_register("anthropic", name) {
let anthropic_provider = g3_providers::AnthropicProvider::new_with_name(
format!("anthropic.{}", name),
anthropic_config.api_key.clone(),
Some(anthropic_config.model.clone()),
anthropic_config.max_tokens,
@@ -1183,12 +1203,13 @@ impl<W: UiWriter> Agent<W> {
}
}
// Register Databricks provider if configured AND it's the default provider
if let Some(databricks_config) = &config.providers.databricks {
if providers_to_register.contains(&"databricks".to_string()) {
// Register Databricks providers from HashMap
for (name, databricks_config) in &config.providers.databricks {
if should_register("databricks", name) {
let databricks_provider = if let Some(token) = &databricks_config.token {
// Use token-based authentication
g3_providers::DatabricksProvider::from_token(
g3_providers::DatabricksProvider::from_token_with_name(
format!("databricks.{}", name),
databricks_config.host.clone(),
token.clone(),
databricks_config.model.clone(),
@@ -1197,7 +1218,8 @@ impl<W: UiWriter> Agent<W> {
)?
} else {
// Use OAuth authentication
g3_providers::DatabricksProvider::from_oauth(
g3_providers::DatabricksProvider::from_oauth_with_name(
format!("databricks.{}", name),
databricks_config.host.clone(),
databricks_config.model.clone(),
databricks_config.max_tokens,
@@ -1253,13 +1275,9 @@ impl<W: UiWriter> Agent<W> {
}
// Load existing TODO list if present (after system prompt and README)
let todo_path = std::env::current_dir().ok().map(|p| p.join("todo.g3.md"));
let initial_todo_content = if let Some(ref path) = todo_path {
if path.exists() {
std::fs::read_to_string(path).ok()
} else {
None
}
let todo_path = get_todo_path();
let initial_todo_content = if todo_path.exists() {
std::fs::read_to_string(&todo_path).ok()
} else {
None
};
@@ -1304,13 +1322,8 @@ impl<W: UiWriter> Agent<W> {
ui_writer,
todo_content: std::sync::Arc::new(tokio::sync::RwLock::new({
// Initialize from the todo.g3.md file (resolved via get_todo_path) if it exists
let todo_path = std::env::current_dir().ok().map(|p| p.join("todo.g3.md"));
if let Some(path) = todo_path {
std::fs::read_to_string(&path).unwrap_or_default()
} else {
String::new()
}
let todo_path = get_todo_path();
std::fs::read_to_string(&todo_path).unwrap_or_default()
})),
is_autonomous,
quiet,
@@ -1386,22 +1399,40 @@ impl<W: UiWriter> Agent<W> {
/// Get the configured max_tokens for a provider from top-level config
fn provider_max_tokens(config: &Config, provider_name: &str) -> Option<u32> {
match provider_name {
"anthropic" => config.providers.anthropic.as_ref()?.max_tokens,
"openai" => config.providers.openai.as_ref()?.max_tokens,
"databricks" => config.providers.databricks.as_ref()?.max_tokens,
"embedded" => config.providers.embedded.as_ref()?.max_tokens,
// Parse provider reference (format: "provider_type.config_name")
let parts: Vec<&str> = provider_name.split('.').collect();
let (provider_type, config_name) = if parts.len() == 2 {
(parts[0], parts[1])
} else {
// Fallback for simple provider names - assume "default" config
(provider_name, "default")
};
match provider_type {
"anthropic" => config.providers.anthropic.get(config_name)?.max_tokens,
"openai" => config.providers.openai.get(config_name)?.max_tokens,
"databricks" => config.providers.databricks.get(config_name)?.max_tokens,
"embedded" => config.providers.embedded.get(config_name)?.max_tokens,
_ => None,
}
}
/// Get the configured temperature for a provider from top-level config
fn provider_temperature(config: &Config, provider_name: &str) -> Option<f32> {
match provider_name {
"anthropic" => config.providers.anthropic.as_ref()?.temperature,
"openai" => config.providers.openai.as_ref()?.temperature,
"databricks" => config.providers.databricks.as_ref()?.temperature,
"embedded" => config.providers.embedded.as_ref()?.temperature,
// Parse provider reference (format: "provider_type.config_name")
let parts: Vec<&str> = provider_name.split('.').collect();
let (provider_type, config_name) = if parts.len() == 2 {
(parts[0], parts[1])
} else {
// Fallback for simple provider names - assume "default" config
(provider_name, "default")
};
match provider_type {
"anthropic" => config.providers.anthropic.get(config_name)?.temperature,
"openai" => config.providers.openai.get(config_name)?.temperature,
"databricks" => config.providers.databricks.get(config_name)?.temperature,
"embedded" => config.providers.embedded.get(config_name)?.temperature,
_ => None,
}
}
@@ -1430,11 +1461,23 @@ impl<W: UiWriter> Agent<W> {
}
/// Get the thinking budget tokens for Anthropic provider, if configured
fn get_thinking_budget_tokens(&self) -> Option<u32> {
self.config
.providers
.anthropic
.as_ref()
fn get_thinking_budget_tokens(&self, provider_name: &str) -> Option<u32> {
// Parse provider reference (format: "provider_type.config_name")
let parts: Vec<&str> = provider_name.split('.').collect();
let (provider_type, config_name) = if parts.len() == 2 {
(parts[0], parts[1])
} else {
// Fallback for simple provider names - assume "default" config
(provider_name, "default")
};
// Only Anthropic has thinking_budget_tokens
if provider_type != "anthropic" {
return None;
}
self.config.providers.anthropic
.get(config_name)
.and_then(|c| c.thinking_budget_tokens)
}
@@ -1448,12 +1491,15 @@ impl<W: UiWriter> Agent<W> {
provider_name: &str,
proposed_max_tokens: u32,
) -> (u32, bool) {
// Only applies to Anthropic provider with thinking enabled
if provider_name != "anthropic" {
// Parse provider type from provider_name (format: "provider_type.config_name")
let provider_type = provider_name.split('.').next().unwrap_or(provider_name);
// Only applies to Anthropic provider
if provider_type != "anthropic" {
return (proposed_max_tokens, false);
}
let budget_tokens = match self.get_thinking_budget_tokens() {
let budget_tokens = match self.get_thinking_budget_tokens(provider_name) {
Some(budget) => budget,
None => return (proposed_max_tokens, false), // No thinking enabled
};
@@ -1702,14 +1748,23 @@ impl<W: UiWriter> Agent<W> {
let provider_name = provider.name();
let model_name = provider.model();
// Use provider-specific context length if available, otherwise fall back to agent config
let context_length = match provider_name {
"embedded" => {
// Parse provider name to get type and config name
let parts: Vec<&str> = provider_name.split('.').collect();
let (provider_type, config_name) = if parts.len() == 2 {
(parts[0], parts[1])
} else {
// Fallback for simple provider names
(provider_name, "default")
};
// Use provider-specific context length if available
let context_length = match provider_type {
"embedded" | "embedded." => {
// For embedded models, use the configured context_length or model-specific defaults
if let Some(embedded_config) = &config.providers.embedded {
if let Some(embedded_config) = config.providers.embedded.get(config_name) {
embedded_config.context_length.unwrap_or_else(|| {
// Model-specific defaults for embedded models
match embedded_config.model_type.to_lowercase().as_str() {
match &embedded_config.model_type.to_lowercase()[..] {
"codellama" => 16384, // CodeLlama supports 16k context
"llama" => 4096, // Base Llama models
"mistral" => 8192, // Mistral models
@@ -1722,11 +1777,11 @@ impl<W: UiWriter> Agent<W> {
}
}
"openai" => {
// gpt-5 has 400k window
if let Some(max_tokens) = Self::provider_max_tokens(config, "openai") {
// OpenAI models have varying context windows
if let Some(max_tokens) = Self::provider_max_tokens(config, provider_name) {
warnings.push(format!(
"Context length falling back to max_tokens ({}) for provider=openai",
max_tokens
"Context length falling back to max_tokens ({}) for provider={}",
max_tokens, provider_name
));
max_tokens
} else {
@@ -1735,11 +1790,10 @@ impl<W: UiWriter> Agent<W> {
}
"anthropic" => {
// Claude models have large context windows
// Use configured max_tokens or fall back to default
if let Some(max_tokens) = Self::provider_max_tokens(config, "anthropic") {
if let Some(max_tokens) = Self::provider_max_tokens(config, provider_name) {
warnings.push(format!(
"Context length falling back to max_tokens ({}) for provider=anthropic",
max_tokens
"Context length falling back to max_tokens ({}) for provider={}",
max_tokens, provider_name
));
max_tokens
} else {
@@ -1748,11 +1802,10 @@ impl<W: UiWriter> Agent<W> {
}
"databricks" => {
// Databricks models have varying context windows depending on the model
// Use configured max_tokens or fall back to model-specific defaults
if let Some(max_tokens) = Self::provider_max_tokens(config, "databricks") {
if let Some(max_tokens) = Self::provider_max_tokens(config, provider_name) {
warnings.push(format!(
"Context length falling back to max_tokens ({}) for provider=databricks",
max_tokens
"Context length falling back to max_tokens ({}) for provider={}",
max_tokens, provider_name
));
max_tokens
} else if model_name.contains("claude") {
@@ -1948,14 +2001,18 @@ impl<W: UiWriter> Agent<W> {
// Check if we should use cache control (every 10 tool calls)
// But only if we haven't already added 4 cache_control annotations
let provider = self.providers.get(None)?;
if let Some(cache_config) = match provider.name() {
"anthropic" => self
.config
.providers
.anthropic
.as_ref()
.and_then(|c| c.cache_config.as_ref())
.and_then(|config| Self::parse_cache_control(config)),
let provider_name = provider.name();
let provider_type = provider_name.split('.').next().unwrap_or("");
let config_name = provider_name.split('.').nth(1).unwrap_or("default");
if let Some(cache_config) = match provider_type {
"anthropic" => {
self.config
.providers
.anthropic
.get(config_name)
.and_then(|c| c.cache_config.as_ref())
.and_then(|config| Self::parse_cache_control(config))
}
_ => None,
} {
Message::with_cache_control_validated(
@@ -2451,7 +2508,7 @@ impl<W: UiWriter> Agent<W> {
// Apply provider-specific caps
// For Anthropic with thinking enabled, we need max_tokens > thinking.budget_tokens
// So we set a higher cap when thinking is configured
let anthropic_cap = match self.get_thinking_budget_tokens() {
let anthropic_cap = match self.get_thinking_budget_tokens(&provider_name) {
Some(budget) => (budget + 2000).max(10_000), // At least budget + 2000 for response
None => 10_000,
};
@@ -3485,7 +3542,7 @@ impl<W: UiWriter> Agent<W> {
// Apply provider-specific caps
// For Anthropic with thinking enabled, we need max_tokens > thinking.budget_tokens
// So we set a higher cap when thinking is configured
let anthropic_cap = match self.get_thinking_budget_tokens() {
let anthropic_cap = match self.get_thinking_budget_tokens(&provider_name) {
Some(budget) => (budget + 2000).max(10_000), // At least budget + 2000 for response
None => 10_000,
};
@@ -4078,14 +4135,18 @@ impl<W: UiWriter> Agent<W> {
&& self.count_cache_controls_in_history() < 4
{
let provider = self.providers.get(None)?;
if let Some(cache_config) = match provider.name() {
"anthropic" => self
.config
.providers
.anthropic
.as_ref()
.and_then(|c| c.cache_config.as_ref())
.and_then(|config| Self::parse_cache_control(config)),
let provider_name = provider.name();
let provider_type = provider_name.split('.').next().unwrap_or("");
let config_name = provider_name.split('.').nth(1).unwrap_or("default");
if let Some(cache_config) = match provider_type {
"anthropic" => {
self.config
.providers
.anthropic
.get(config_name)
.and_then(|c| c.cache_config.as_ref())
.and_then(|config| Self::parse_cache_control(config))
}
_ => None,
} {
Message::with_cache_control_validated(
@@ -5118,8 +5179,8 @@ impl<W: UiWriter> Agent<W> {
}
"todo_read" => {
debug!("Processing todo_read tool call");
// Read from todo.g3.md file in current workspace directory
let todo_path = std::env::current_dir()?.join("todo.g3.md");
// Read from todo.g3.md file (uses G3_TODO_PATH env var if set, else current dir)
let todo_path = get_todo_path();
if !todo_path.exists() {
// Also update in-memory content to stay in sync
@@ -5233,7 +5294,7 @@ impl<W: UiWriter> Agent<W> {
// If all todos are complete, delete the file instead of writing
if !has_incomplete && (content_str.contains("- [x]") || content_str.contains("- [X]")) {
let todo_path = std::env::current_dir()?.join("todo.g3.md");
let todo_path = get_todo_path();
if todo_path.exists() {
match std::fs::remove_file(&todo_path) {
Ok(_) => {
@@ -5253,8 +5314,8 @@ impl<W: UiWriter> Agent<W> {
}
}
// Write to todo.g3.md file in current workspace directory
let todo_path = std::env::current_dir()?.join("todo.g3.md");
// Write to todo.g3.md file (uses G3_TODO_PATH env var if set, else current dir)
let todo_path = get_todo_path();
match std::fs::write(&todo_path, content_str) {
Ok(_) => {