fix bad max_tokens and context_window logic
for non-databricks code
This commit is contained in:
@@ -938,9 +938,16 @@ impl<W: UiWriter> Agent<W> {
|
||||
debug!("Default provider set successfully");
|
||||
|
||||
// Determine context window size based on active provider
|
||||
let context_length = Self::get_configured_context_length(&config, &providers)?;
|
||||
let mut context_warnings = Vec::new();
|
||||
let context_length =
|
||||
Self::get_configured_context_length(&config, &providers, &mut context_warnings)?;
|
||||
let mut context_window = ContextWindow::new(context_length);
|
||||
|
||||
// Surface any context warnings to the user via UI
|
||||
for warning in context_warnings {
|
||||
ui_writer.print_context_status(&format!("⚠️ {}", warning));
|
||||
}
|
||||
|
||||
// If README content is provided, add it as the first system message
|
||||
if let Some(readme) = readme_content {
|
||||
let readme_message = Message::new(MessageRole::System, readme);
|
||||
@@ -1016,24 +1023,72 @@ impl<W: UiWriter> Agent<W> {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_configured_context_length(config: &Config, providers: &ProviderRegistry) -> Result<u32> {
|
||||
/// Get the configured max_tokens for a provider from top-level config
|
||||
fn provider_max_tokens(config: &Config, provider_name: &str) -> Option<u32> {
|
||||
match provider_name {
|
||||
"anthropic" => config.providers.anthropic.as_ref()?.max_tokens,
|
||||
"openai" => config.providers.openai.as_ref()?.max_tokens,
|
||||
"databricks" => config.providers.databricks.as_ref()?.max_tokens,
|
||||
"embedded" => config.providers.embedded.as_ref()?.max_tokens,
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve the max_tokens to use for a given provider, applying fallbacks
|
||||
fn resolve_max_tokens(&self, provider_name: &str) -> u32 {
|
||||
match provider_name {
|
||||
"databricks" => Self::provider_max_tokens(&self.config, "databricks")
|
||||
.or(Some(self.config.agent.fallback_default_max_tokens as u32))
|
||||
.unwrap_or(32000),
|
||||
other => Self::provider_max_tokens(&self.config, other)
|
||||
.or(Some(self.config.agent.fallback_default_max_tokens as u32))
|
||||
.unwrap_or(16000),
|
||||
}
|
||||
}
|
||||
|
||||
/// Print provider diagnostics through the UiWriter for visibility
|
||||
pub fn print_provider_banner(&self, role_label: &str) {
|
||||
if let Ok((provider_name, model)) = self.get_provider_info() {
|
||||
let max_tokens = self.resolve_max_tokens(&provider_name);
|
||||
let context_len = self.context_window.total_tokens;
|
||||
|
||||
let mut details = vec![
|
||||
format!("provider={}", provider_name),
|
||||
format!("model={}", model),
|
||||
format!("max_tokens={}", max_tokens),
|
||||
format!("context_window_length={}", context_len),
|
||||
];
|
||||
|
||||
if let Ok(provider) = self.providers.get(None) {
|
||||
details.push(format!(
|
||||
"native_tools={}",
|
||||
if provider.has_native_tool_calling() {
|
||||
"yes"
|
||||
} else {
|
||||
"no"
|
||||
}
|
||||
));
|
||||
if provider.supports_cache_control() {
|
||||
details.push("cache_control=yes".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
self.ui_writer
|
||||
.print_context_status(&format!("{}: {}", role_label, details.join(", ")));
|
||||
}
|
||||
}
|
||||
|
||||
fn get_configured_context_length(
|
||||
config: &Config,
|
||||
providers: &ProviderRegistry,
|
||||
warnings: &mut Vec<String>,
|
||||
) -> Result<u32> {
|
||||
// First, check if there's a global max_context_length override in agent config
|
||||
if let Some(max_context_length) = config.agent.max_context_length {
|
||||
debug!("Using configured agent.max_context_length: {}", max_context_length);
|
||||
return Ok(max_context_length);
|
||||
}
|
||||
|
||||
// Get the configured max_tokens for the current provider
|
||||
fn get_provider_max_tokens(config: &Config, provider_name: &str) -> Option<u32> {
|
||||
match provider_name {
|
||||
"anthropic" => config.providers.anthropic.as_ref()?.max_tokens,
|
||||
"openai" => config.providers.openai.as_ref()?.max_tokens,
|
||||
"databricks" => config.providers.databricks.as_ref()?.max_tokens,
|
||||
"embedded" => config.providers.embedded.as_ref()?.max_tokens,
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
// Get the active provider to determine context length
|
||||
let provider = providers.get(None)?;
|
||||
let provider_name = provider.name();
|
||||
@@ -1060,25 +1115,45 @@ impl<W: UiWriter> Agent<W> {
|
||||
}
|
||||
"openai" => {
|
||||
// gpt-5 has 400k window
|
||||
get_provider_max_tokens(config, "openai").unwrap_or(400000)
|
||||
if let Some(max_tokens) = Self::provider_max_tokens(config, "openai") {
|
||||
warnings.push(format!(
|
||||
"Context length falling back to max_tokens ({}) for provider=openai",
|
||||
max_tokens
|
||||
));
|
||||
max_tokens
|
||||
} else {
|
||||
400000
|
||||
}
|
||||
}
|
||||
"anthropic" => {
|
||||
// Claude models have large context windows
|
||||
// Use configured max_tokens or fall back to default
|
||||
get_provider_max_tokens(config, "anthropic").unwrap_or(200000)
|
||||
if let Some(max_tokens) = Self::provider_max_tokens(config, "anthropic") {
|
||||
warnings.push(format!(
|
||||
"Context length falling back to max_tokens ({}) for provider=anthropic",
|
||||
max_tokens
|
||||
));
|
||||
max_tokens
|
||||
} else {
|
||||
200000
|
||||
}
|
||||
}
|
||||
"databricks" => {
|
||||
// Databricks models have varying context windows depending on the model
|
||||
// Use configured max_tokens or fall back to model-specific defaults
|
||||
get_provider_max_tokens(config, "databricks").unwrap_or_else(|| {
|
||||
if model_name.contains("claude") {
|
||||
200000 // Claude models on Databricks have large context windows
|
||||
} else if model_name.contains("llama") || model_name.contains("dbrx") {
|
||||
32768 // DBRX supports 32k context
|
||||
} else {
|
||||
16384 // Conservative default for other Databricks models
|
||||
}
|
||||
})
|
||||
if let Some(max_tokens) = Self::provider_max_tokens(config, "databricks") {
|
||||
warnings.push(format!(
|
||||
"Context length falling back to max_tokens ({}) for provider=databricks",
|
||||
max_tokens
|
||||
));
|
||||
max_tokens
|
||||
} else if model_name.contains("claude") {
|
||||
200000 // Claude models on Databricks have large context windows
|
||||
} else if model_name.contains("llama") || model_name.contains("dbrx") {
|
||||
32768 // DBRX supports 32k context
|
||||
} else {
|
||||
16384 // Conservative default for other Databricks models
|
||||
}
|
||||
}
|
||||
_ => config.agent.fallback_default_max_tokens as u32,
|
||||
};
|
||||
@@ -1548,17 +1623,8 @@ If you can complete it with 1-2 tool calls, skip TODO.
|
||||
};
|
||||
drop(provider); // Drop the provider reference to avoid borrowing issues
|
||||
|
||||
// Get max_tokens from provider configuration
|
||||
let max_tokens = match provider_name.as_str() {
|
||||
"databricks" => {
|
||||
// Use the model's maximum limit for Databricks to allow large file generation
|
||||
Some(32000)
|
||||
}
|
||||
_ => {
|
||||
// Default for other providers
|
||||
Some(16000)
|
||||
}
|
||||
};
|
||||
// Get max_tokens from provider configuration, falling back to sensible defaults
|
||||
let max_tokens = Some(self.resolve_max_tokens(&provider_name));
|
||||
|
||||
let request = CompletionRequest {
|
||||
messages,
|
||||
|
||||
Reference in New Issue
Block a user