Merge pull request #25 from dhanji/fix_max_tokens

fix bad max_tokens and context_window logic
Dhanji R. Prasanna
2025-11-19 15:55:34 +11:00
committed by GitHub
4 changed files with 119 additions and 44 deletions


@@ -22,12 +22,13 @@ use_oauth = true
 
 [providers.anthropic]
 api_key = "your-anthropic-api-key"
-model = "claude-3-haiku-20240307" # Using a faster model for player
+model = "claude-sonnet-4-5"
 max_tokens = 4096
 temperature = 0.3 # Slightly higher temperature for more creative implementations
 # cache_config = "ephemeral" # Optional: Enable prompt caching
 # Options: "ephemeral", "5minute", "1hour"
 # Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs.
+# enable_1m_context = true # optional, more expensive
 
 [agent]
 fallback_default_max_tokens = 8192
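
Note: the new `enable_1m_context` key ships commented out, so it must parse as absent by default. A minimal sketch of how the `[providers.anthropic]` table could be modeled with serde so optional keys deserialize to `None` (field names mirror the sample config; the actual struct in this repo may differ):

    use serde::Deserialize;

    // Hypothetical mirror of [providers.anthropic]; every tunable is
    // Option<_> so a commented-out key simply deserializes to None.
    #[derive(Debug, Deserialize)]
    struct AnthropicConfig {
        api_key: String,
        model: String,
        max_tokens: Option<u32>,
        temperature: Option<f32>,
        cache_config: Option<String>,
        enable_1m_context: Option<bool>, // new optional flag in this PR
    }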


@@ -14,14 +14,16 @@ max_tokens = 4096 # Per-request output limit (how many tokens the model can generate)
 # Note: This is different from max_context_length (total conversation history size)
 temperature = 0.1
 use_oauth = true
 
-# cache_config = "ephemeral" # Optional: Enable prompt caching for Claude models on Databricks
-# Options: "ephemeral", "5minute", "1hour"
-# Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs.
-# The cache control will be automatically applied to:
-# - The system prompt at the start of each session
-# - Assistant responses after every 10 tool calls
-# - 5minute costs $3/mtok, more details below
-# https://docs.claude.com/en/docs/build-with-claude/prompt-caching#pricing
+
+[providers.anthropic]
+api_key = "your-anthropic-api-key"
+model = "claude-sonnet-4-5"
+max_tokens = 4096
+temperature = 0.3 # Slightly higher temperature for more creative implementations
+# cache_config = "ephemeral" # Optional: Enable prompt caching
+# Options: "ephemeral", "5minute", "1hour"
+# Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs.
+# enable_1m_context = true # optional, more expensive
 
 # Multiple OpenAI-compatible providers can be configured with custom names


@@ -1686,6 +1686,9 @@ async fn run_autonomous(
         turn, max_turns
     ));
 
+    // Surface provider info for player agent
+    agent.print_provider_banner("Player");
+
     // Player mode: implement requirements (with coach feedback if available)
     let player_prompt = if coach_feedback.is_empty() {
         format!(
@@ -1879,6 +1882,9 @@
     let mut coach_agent =
         Agent::new_autonomous_with_readme_and_quiet(coach_config, ui_writer, None, quiet).await?;
 
+    // Surface provider info for coach agent
+    coach_agent.print_provider_banner("Coach");
+
     // Ensure coach agent is also in the workspace directory
     project.enter_workspace()?;
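
Note: with the banner wired into both roles, each agent now prints a one-line diagnostic at startup. Given the `details.join(", ")` format in `print_provider_banner` (added in the hunk below), the output should look roughly like this, with illustrative values taken from the sample config:

    Player: provider=anthropic, model=claude-sonnet-4-5, max_tokens=4096, context_window_length=200000, native_tools=yes, cache_control=yes
    Coach: provider=anthropic, model=claude-sonnet-4-5, max_tokens=4096, context_window_length=200000, native_tools=yes, cache_control=yes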


@@ -938,9 +938,16 @@ impl<W: UiWriter> Agent<W> {
         debug!("Default provider set successfully");
 
         // Determine context window size based on active provider
-        let context_length = Self::get_configured_context_length(&config, &providers)?;
+        let mut context_warnings = Vec::new();
+        let context_length =
+            Self::get_configured_context_length(&config, &providers, &mut context_warnings)?;
         let mut context_window = ContextWindow::new(context_length);
 
+        // Surface any context warnings to the user via UI
+        for warning in context_warnings {
+            ui_writer.print_context_status(&format!("⚠️ {}", warning));
+        }
+
         // If README content is provided, add it as the first system message
         if let Some(readme) = readme_content {
             let readme_message = Message::new(MessageRole::System, readme);
@@ -1016,24 +1023,72 @@ impl<W: UiWriter> Agent<W> {
         }
     }
 
-    fn get_configured_context_length(config: &Config, providers: &ProviderRegistry) -> Result<u32> {
+    /// Get the configured max_tokens for a provider from top-level config
+    fn provider_max_tokens(config: &Config, provider_name: &str) -> Option<u32> {
+        match provider_name {
+            "anthropic" => config.providers.anthropic.as_ref()?.max_tokens,
+            "openai" => config.providers.openai.as_ref()?.max_tokens,
+            "databricks" => config.providers.databricks.as_ref()?.max_tokens,
+            "embedded" => config.providers.embedded.as_ref()?.max_tokens,
+            _ => None,
+        }
+    }
+
+    /// Resolve the max_tokens to use for a given provider, applying fallbacks
+    fn resolve_max_tokens(&self, provider_name: &str) -> u32 {
+        match provider_name {
+            "databricks" => Self::provider_max_tokens(&self.config, "databricks")
+                .or(Some(self.config.agent.fallback_default_max_tokens as u32))
+                .unwrap_or(32000),
+            other => Self::provider_max_tokens(&self.config, other)
+                .or(Some(self.config.agent.fallback_default_max_tokens as u32))
+                .unwrap_or(16000),
+        }
+    }
+
+    /// Print provider diagnostics through the UiWriter for visibility
+    pub fn print_provider_banner(&self, role_label: &str) {
+        if let Ok((provider_name, model)) = self.get_provider_info() {
+            let max_tokens = self.resolve_max_tokens(&provider_name);
+            let context_len = self.context_window.total_tokens;
+            let mut details = vec![
+                format!("provider={}", provider_name),
+                format!("model={}", model),
+                format!("max_tokens={}", max_tokens),
+                format!("context_window_length={}", context_len),
+            ];
+            if let Ok(provider) = self.providers.get(None) {
+                details.push(format!(
+                    "native_tools={}",
+                    if provider.has_native_tool_calling() {
+                        "yes"
+                    } else {
+                        "no"
+                    }
+                ));
+                if provider.supports_cache_control() {
+                    details.push("cache_control=yes".to_string());
+                }
+            }
+            self.ui_writer
+                .print_context_status(&format!("{}: {}", role_label, details.join(", ")));
+        }
+    }
+
+    fn get_configured_context_length(
+        config: &Config,
+        providers: &ProviderRegistry,
+        warnings: &mut Vec<String>,
+    ) -> Result<u32> {
         // First, check if there's a global max_context_length override in agent config
         if let Some(max_context_length) = config.agent.max_context_length {
             debug!("Using configured agent.max_context_length: {}", max_context_length);
             return Ok(max_context_length);
         }
 
-        // Get the configured max_tokens for the current provider
-        fn get_provider_max_tokens(config: &Config, provider_name: &str) -> Option<u32> {
-            match provider_name {
-                "anthropic" => config.providers.anthropic.as_ref()?.max_tokens,
-                "openai" => config.providers.openai.as_ref()?.max_tokens,
-                "databricks" => config.providers.databricks.as_ref()?.max_tokens,
-                "embedded" => config.providers.embedded.as_ref()?.max_tokens,
-                _ => None,
-            }
-        }
-
         // Get the active provider to determine context length
         let provider = providers.get(None)?;
         let provider_name = provider.name();
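
Note: in `resolve_max_tokens`, `.or(Some(fallback)).unwrap_or(literal)` always short-circuits at the `.or`, so the trailing `32000`/`16000` literals are unreachable and the two match arms behave identically. A behavior-preserving simplification (a sketch, not what this PR ships):

    // Sketch only: equivalent to both arms of resolve_max_tokens above.
    // `.or(Some(fallback)).unwrap_or(literal)` == `.unwrap_or(fallback)`,
    // because `.or(Some(_))` always yields Some.
    fn resolve(configured: Option<u32>, fallback_default: u32) -> u32 {
        configured.unwrap_or(fallback_default)
    }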
@@ -1060,25 +1115,45 @@
             }
             "openai" => {
                 // gpt-5 has 400k window
-                get_provider_max_tokens(config, "openai").unwrap_or(400000)
+                if let Some(max_tokens) = Self::provider_max_tokens(config, "openai") {
+                    warnings.push(format!(
+                        "Context length falling back to max_tokens ({}) for provider=openai",
+                        max_tokens
+                    ));
+                    max_tokens
+                } else {
+                    400000
+                }
             }
             "anthropic" => {
                 // Claude models have large context windows
                 // Use configured max_tokens or fall back to default
-                get_provider_max_tokens(config, "anthropic").unwrap_or(200000)
+                if let Some(max_tokens) = Self::provider_max_tokens(config, "anthropic") {
+                    warnings.push(format!(
+                        "Context length falling back to max_tokens ({}) for provider=anthropic",
+                        max_tokens
+                    ));
+                    max_tokens
+                } else {
+                    200000
+                }
             }
             "databricks" => {
                 // Databricks models have varying context windows depending on the model
                 // Use configured max_tokens or fall back to model-specific defaults
-                get_provider_max_tokens(config, "databricks").unwrap_or_else(|| {
-                    if model_name.contains("claude") {
-                        200000 // Claude models on Databricks have large context windows
-                    } else if model_name.contains("llama") || model_name.contains("dbrx") {
-                        32768 // DBRX supports 32k context
-                    } else {
-                        16384 // Conservative default for other Databricks models
-                    }
-                })
+                if let Some(max_tokens) = Self::provider_max_tokens(config, "databricks") {
+                    warnings.push(format!(
+                        "Context length falling back to max_tokens ({}) for provider=databricks",
+                        max_tokens
+                    ));
+                    max_tokens
+                } else if model_name.contains("claude") {
+                    200000 // Claude models on Databricks have large context windows
+                } else if model_name.contains("llama") || model_name.contains("dbrx") {
+                    32768 // DBRX supports 32k context
+                } else {
+                    16384 // Conservative default for other Databricks models
+                }
             }
             _ => config.agent.fallback_default_max_tokens as u32,
         };
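
Note: this hunk is the heart of the fix. A configured per-request `max_tokens` (e.g. 4096) still becomes the total context length for openai/anthropic/databricks when no `agent.max_context_length` is set, but the collapse is now surfaced to the user instead of happening silently. A minimal unit-test sketch of the warning path (`TestWriter`, `make_config`, and `make_registry` are hypothetical helpers, not part of this PR):

    #[test]
    fn context_length_falls_back_to_max_tokens_with_warning() {
        // Assumed setup: providers.anthropic.max_tokens = Some(4096),
        // no agent.max_context_length override.
        let config = make_config(Some(4096), None);
        let providers = make_registry("anthropic", "claude-sonnet-4-5");

        let mut warnings = Vec::new();
        let len =
            Agent::<TestWriter>::get_configured_context_length(&config, &providers, &mut warnings)
                .expect("context length should resolve");

        assert_eq!(len, 4096); // max_tokens reused as the context length
        assert_eq!(warnings.len(), 1);
        assert!(warnings[0].contains("provider=anthropic"));
    }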
@@ -1548,17 +1623,8 @@ If you can complete it with 1-2 tool calls, skip TODO.
         };
         drop(provider); // Drop the provider reference to avoid borrowing issues
 
-        // Get max_tokens from provider configuration
-        let max_tokens = match provider_name.as_str() {
-            "databricks" => {
-                // Use the model's maximum limit for Databricks to allow large file generation
-                Some(32000)
-            }
-            _ => {
-                // Default for other providers
-                Some(16000)
-            }
-        };
+        // Get max_tokens from provider configuration, falling back to sensible defaults
+        let max_tokens = Some(self.resolve_max_tokens(&provider_name));
 
         let request = CompletionRequest {
             messages,
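
Note: this last hunk is a behavioral change as well as a cleanup. Databricks requests previously always used a hardcoded `Some(32000)`; with no `max_tokens` configured under `[providers.databricks]`, `resolve_max_tokens` now yields `fallback_default_max_tokens` (8192 in the sample configs), so large single-file generations may truncate sooner unless `max_tokens` is set explicitly.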