fix bad max_tokens and context_window logic

for non-Databricks code
Jochen
2025-11-19 13:51:16 +11:00
parent 3f21bdc7b2
commit 1069664e16
4 changed files with 119 additions and 44 deletions
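The fix turns on the distinction between two limits that are easy to conflate: max_tokens caps how many tokens a single completion may generate, while the context window bounds the total tokens (prompt plus output) the model can attend to. A standalone sketch of the relationship, with illustrative numbers that are not taken from this repo:

```rust
// Standalone sketch, not code from this repo: the two limits the commit
// distinguishes. All numbers are illustrative assumptions.
fn main() {
    let context_window_length: u32 = 200_000; // total tokens the model can attend to
    let max_tokens: u32 = 16_000; // cap on tokens generated per completion

    // Prompt plus output must fit in the window, so the usable prompt
    // budget is the window minus the output cap.
    let prompt_budget = context_window_length - max_tokens;
    println!("prompt budget: {} tokens", prompt_budget);
}
```

When no context length is configured, the code below falls back to the provider's max_tokens and now emits a warning, since the two values are not interchangeable.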


@@ -938,9 +938,16 @@ impl<W: UiWriter> Agent<W> {
debug!("Default provider set successfully");
// Determine context window size based on active provider
let context_length = Self::get_configured_context_length(&config, &providers)?;
let mut context_warnings = Vec::new();
let context_length =
Self::get_configured_context_length(&config, &providers, &mut context_warnings)?;
let mut context_window = ContextWindow::new(context_length);
// Surface any context warnings to the user via UI
for warning in context_warnings {
ui_writer.print_context_status(&format!("⚠️ {}", warning));
}
// If README content is provided, add it as the first system message
if let Some(readme) = readme_content {
let readme_message = Message::new(MessageRole::System, readme);
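Threading a `&mut Vec<String>` through `get_configured_context_length` keeps the resolver's return type a plain `Result<u32>` while still letting non-fatal diagnostics reach the UI afterwards. A minimal sketch of the same pattern, detached from this codebase (the `resolve` helper and its values are hypothetical stand-ins):

```rust
// Minimal sketch of the warnings out-parameter pattern used above.
// `resolve` and its inputs are hypothetical, not this repo's API.
fn resolve(configured: Option<u32>, warnings: &mut Vec<String>) -> Result<u32, String> {
    match configured {
        Some(len) => Ok(len),
        None => {
            warnings.push("no context length configured; using 16384".to_string());
            Ok(16384)
        }
    }
}

fn main() -> Result<(), String> {
    let mut warnings = Vec::new();
    let context_length = resolve(None, &mut warnings)?;
    // The caller decides how to surface the diagnostics, e.g. via a UiWriter.
    for w in &warnings {
        eprintln!("⚠️ {}", w);
    }
    println!("context_length = {}", context_length);
    Ok(())
}
```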
@@ -1016,24 +1023,72 @@ impl<W: UiWriter> Agent<W> {
         }
     }

-    fn get_configured_context_length(config: &Config, providers: &ProviderRegistry) -> Result<u32> {
+    /// Get the configured max_tokens for a provider from top-level config
+    fn provider_max_tokens(config: &Config, provider_name: &str) -> Option<u32> {
+        match provider_name {
+            "anthropic" => config.providers.anthropic.as_ref()?.max_tokens,
+            "openai" => config.providers.openai.as_ref()?.max_tokens,
+            "databricks" => config.providers.databricks.as_ref()?.max_tokens,
+            "embedded" => config.providers.embedded.as_ref()?.max_tokens,
+            _ => None,
+        }
+    }
+
+    /// Resolve the max_tokens to use for a given provider, applying fallbacks
+    fn resolve_max_tokens(&self, provider_name: &str) -> u32 {
+        match provider_name {
+            "databricks" => Self::provider_max_tokens(&self.config, "databricks")
+                .or(Some(self.config.agent.fallback_default_max_tokens as u32))
+                .unwrap_or(32000),
+            other => Self::provider_max_tokens(&self.config, other)
+                .or(Some(self.config.agent.fallback_default_max_tokens as u32))
+                .unwrap_or(16000),
+        }
+    }
+
+    /// Print provider diagnostics through the UiWriter for visibility
+    pub fn print_provider_banner(&self, role_label: &str) {
+        if let Ok((provider_name, model)) = self.get_provider_info() {
+            let max_tokens = self.resolve_max_tokens(&provider_name);
+            let context_len = self.context_window.total_tokens;
+            let mut details = vec![
+                format!("provider={}", provider_name),
+                format!("model={}", model),
+                format!("max_tokens={}", max_tokens),
+                format!("context_window_length={}", context_len),
+            ];
+
+            if let Ok(provider) = self.providers.get(None) {
+                details.push(format!(
+                    "native_tools={}",
+                    if provider.has_native_tool_calling() {
+                        "yes"
+                    } else {
+                        "no"
+                    }
+                ));
+                if provider.supports_cache_control() {
+                    details.push("cache_control=yes".to_string());
+                }
+            }
+
+            self.ui_writer
+                .print_context_status(&format!("{}: {}", role_label, details.join(", ")));
+        }
+    }
+
+    fn get_configured_context_length(
+        config: &Config,
+        providers: &ProviderRegistry,
+        warnings: &mut Vec<String>,
+    ) -> Result<u32> {
         // First, check if there's a global max_context_length override in agent config
         if let Some(max_context_length) = config.agent.max_context_length {
             debug!("Using configured agent.max_context_length: {}", max_context_length);
             return Ok(max_context_length);
         }

-        // Get the configured max_tokens for the current provider
-        fn get_provider_max_tokens(config: &Config, provider_name: &str) -> Option<u32> {
-            match provider_name {
-                "anthropic" => config.providers.anthropic.as_ref()?.max_tokens,
-                "openai" => config.providers.openai.as_ref()?.max_tokens,
-                "databricks" => config.providers.databricks.as_ref()?.max_tokens,
-                "embedded" => config.providers.embedded.as_ref()?.max_tokens,
-                _ => None,
-            }
-        }

         // Get the active provider to determine context length
         let provider = providers.get(None)?;
         let provider_name = provider.name();
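The new `resolve_max_tokens` chains `.or(Some(fallback)).unwrap_or(literal)`; because `.or(Some(..))` always yields `Some`, the trailing `32000`/`16000` literals are unreachable, and the chain behaves like a plain `unwrap_or` on the agent-level default. A detached sketch of the effective resolution order:

```rust
// Standalone sketch of the effective fallback order in resolve_max_tokens.
// Since `.or(Some(fallback))` always produces Some, the trailing
// `.unwrap_or(32000)` / `.unwrap_or(16000)` literals are never reached.
fn resolve_max_tokens(provider_cfg: Option<u32>, agent_fallback: u32) -> u32 {
    provider_cfg // 1. explicit per-provider max_tokens, if configured
        .unwrap_or(agent_fallback) // 2. agent-level fallback_default_max_tokens
}

fn main() {
    assert_eq!(resolve_max_tokens(Some(8192), 16000), 8192);
    assert_eq!(resolve_max_tokens(None, 16000), 16000);
}
```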
@@ -1060,25 +1115,45 @@ impl<W: UiWriter> Agent<W> {
             }
             "openai" => {
                 // gpt-5 has 400k window
-                get_provider_max_tokens(config, "openai").unwrap_or(400000)
+                if let Some(max_tokens) = Self::provider_max_tokens(config, "openai") {
+                    warnings.push(format!(
+                        "Context length falling back to max_tokens ({}) for provider=openai",
+                        max_tokens
+                    ));
+                    max_tokens
+                } else {
+                    400000
+                }
             }
             "anthropic" => {
                 // Claude models have large context windows
                 // Use configured max_tokens or fall back to default
-                get_provider_max_tokens(config, "anthropic").unwrap_or(200000)
+                if let Some(max_tokens) = Self::provider_max_tokens(config, "anthropic") {
+                    warnings.push(format!(
+                        "Context length falling back to max_tokens ({}) for provider=anthropic",
+                        max_tokens
+                    ));
+                    max_tokens
+                } else {
+                    200000
+                }
             }
             "databricks" => {
                 // Databricks models have varying context windows depending on the model
                 // Use configured max_tokens or fall back to model-specific defaults
-                get_provider_max_tokens(config, "databricks").unwrap_or_else(|| {
-                    if model_name.contains("claude") {
-                        200000 // Claude models on Databricks have large context windows
-                    } else if model_name.contains("llama") || model_name.contains("dbrx") {
-                        32768 // DBRX supports 32k context
-                    } else {
-                        16384 // Conservative default for other Databricks models
-                    }
-                })
+                if let Some(max_tokens) = Self::provider_max_tokens(config, "databricks") {
+                    warnings.push(format!(
+                        "Context length falling back to max_tokens ({}) for provider=databricks",
+                        max_tokens
+                    ));
+                    max_tokens
+                } else if model_name.contains("claude") {
+                    200000 // Claude models on Databricks have large context windows
+                } else if model_name.contains("llama") || model_name.contains("dbrx") {
+                    32768 // DBRX supports 32k context
+                } else {
+                    16384 // Conservative default for other Databricks models
+                }
             }
             _ => config.agent.fallback_default_max_tokens as u32,
         };
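For the Databricks arm, the default context length is keyed off substrings of the model name. A detached sketch of just that heuristic, with example model names that are illustrative assumptions rather than a tested list:

```rust
// Standalone sketch of the Databricks model-name heuristic above,
// detached from the repo's Config types. Example names are assumptions.
fn databricks_default_context(model_name: &str) -> u32 {
    if model_name.contains("claude") {
        200_000 // Claude models on Databricks have large context windows
    } else if model_name.contains("llama") || model_name.contains("dbrx") {
        32_768 // DBRX supports 32k context
    } else {
        16_384 // conservative default for other models
    }
}

fn main() {
    assert_eq!(databricks_default_context("databricks-claude-3-7-sonnet"), 200_000);
    assert_eq!(databricks_default_context("databricks-dbrx-instruct"), 32_768);
    assert_eq!(databricks_default_context("databricks-mixtral-8x7b"), 16_384);
}
```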
@@ -1548,17 +1623,8 @@ If you can complete it with 1-2 tool calls, skip TODO.
         };
         drop(provider); // Drop the provider reference to avoid borrowing issues

-        // Get max_tokens from provider configuration
-        let max_tokens = match provider_name.as_str() {
-            "databricks" => {
-                // Use the model's maximum limit for Databricks to allow large file generation
-                Some(32000)
-            }
-            _ => {
-                // Default for other providers
-                Some(16000)
-            }
-        };
+        // Get max_tokens from provider configuration, falling back to sensible defaults
+        let max_tokens = Some(self.resolve_max_tokens(&provider_name));

         let request = CompletionRequest {
             messages,