fix bad max_tokens and context_window logic for non-databricks code
@@ -22,12 +22,13 @@ use_oauth = true
 [providers.anthropic]
 api_key = "your-anthropic-api-key"
-model = "claude-3-haiku-20240307" # Using a faster model for player
+model = "claude-sonnet-4-5"
 max_tokens = 4096
 temperature = 0.3 # Slightly higher temperature for more creative implementations
 # cache_config = "ephemeral" # Optional: Enable prompt caching
 # Options: "ephemeral", "5minute", "1hour"
 # Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs.
 # enable_1m_context = true # optional, more expensive

 [agent]
+fallback_default_max_tokens = 8192
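
Taken together with the resolve_max_tokens hunk further down, the new fallback_default_max_tokens key gives max_tokens a three-step resolution order: provider-level max_tokens, then agent.fallback_default_max_tokens, then a hardcoded per-provider default. A minimal standalone sketch of that precedence (the helper name and the Option-typed fallback are illustrative assumptions, not this codebase's API):

    // Precedence sketch: provider max_tokens > agent fallback > hard default.
    // In the committed code the agent fallback is a plain numeric field, so
    // the hard default only applies when neither option yields a value here.
    fn effective_max_tokens(
        provider_max: Option<u32>,   // e.g. providers.anthropic.max_tokens
        agent_fallback: Option<u32>, // e.g. agent.fallback_default_max_tokens
        hard_default: u32,           // 32000 for databricks, 16000 otherwise
    ) -> u32 {
        provider_max.or(agent_fallback).unwrap_or(hard_default)
    }

    fn main() {
        assert_eq!(effective_max_tokens(Some(4096), Some(8192), 16000), 4096);
        assert_eq!(effective_max_tokens(None, Some(8192), 16000), 8192);
        assert_eq!(effective_max_tokens(None, None, 16000), 16000);
    }
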
@@ -14,14 +14,16 @@ max_tokens = 4096 # Per-request output limit (how many tokens the model can generate per response)
+# Note: This is different from max_context_length (total conversation history size)
 temperature = 0.1
 use_oauth = true
 # cache_config = "ephemeral" # Optional: Enable prompt caching for Claude models on Databricks

 [providers.anthropic]
 api_key = "your-anthropic-api-key"
 model = "claude-sonnet-4-5"
 max_tokens = 4096
 temperature = 0.3 # Slightly higher temperature for more creative implementations
 # cache_config = "ephemeral" # Optional: Enable prompt caching
 # Options: "ephemeral", "5minute", "1hour"
 # Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs.
+# The cache control will be automatically applied to:
+# - The system prompt at the start of each session
+# - Assistant responses after every 10 tool calls
+# - 5minute costs $3/mtok, more details below
+# https://docs.claude.com/en/docs/build-with-claude/prompt-caching#pricing
+# enable_1m_context = true # optional, more expensive

 # Multiple OpenAI-compatible providers can be configured with custom names
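
The documented cache_config values map onto Anthropic prompt-cache TTLs. A hypothetical sketch of how the three strings could be parsed (the enum and function are illustrative assumptions, not types from this repo):

    // Illustrative mapping of the documented cache_config values.
    #[derive(Debug, PartialEq)]
    enum CacheConfig {
        Ephemeral,  // "ephemeral", default TTL
        FiveMinute, // "5minute", $3/mtok per the comment above
        OneHour,    // "1hour"
    }

    fn parse_cache_config(s: &str) -> Option<CacheConfig> {
        match s {
            "ephemeral" => Some(CacheConfig::Ephemeral),
            "5minute" => Some(CacheConfig::FiveMinute),
            "1hour" => Some(CacheConfig::OneHour),
            _ => None,
        }
    }

    fn main() {
        assert_eq!(parse_cache_config("5minute"), Some(CacheConfig::FiveMinute));
        assert_eq!(parse_cache_config("forever"), None);
    }
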
@@ -1686,6 +1686,9 @@ async fn run_autonomous(
         turn, max_turns
     ));

+    // Surface provider info for player agent
+    agent.print_provider_banner("Player");
+
     // Player mode: implement requirements (with coach feedback if available)
     let player_prompt = if coach_feedback.is_empty() {
         format!(
@@ -1879,6 +1882,9 @@ async fn run_autonomous(
     let mut coach_agent =
         Agent::new_autonomous_with_readme_and_quiet(coach_config, ui_writer, None, quiet).await?;

+    // Surface provider info for coach agent
+    coach_agent.print_provider_banner("Coach");
+
     // Ensure coach agent is also in the workspace directory
     project.enter_workspace()?;
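
Given the print_provider_banner implementation added further down, each of these calls emits one status line per agent through the UiWriter, along the lines of (values illustrative):

    Player: provider=anthropic, model=claude-sonnet-4-5, max_tokens=4096, context_window_length=200000, native_tools=yes, cache_control=yes
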
@@ -938,9 +938,16 @@ impl<W: UiWriter> Agent<W> {
         debug!("Default provider set successfully");

         // Determine context window size based on active provider
-        let context_length = Self::get_configured_context_length(&config, &providers)?;
+        let mut context_warnings = Vec::new();
+        let context_length =
+            Self::get_configured_context_length(&config, &providers, &mut context_warnings)?;
         let mut context_window = ContextWindow::new(context_length);
+
+        // Surface any context warnings to the user via UI
+        for warning in context_warnings {
+            ui_writer.print_context_status(&format!("⚠️ {}", warning));
+        }

         // If README content is provided, add it as the first system message
         if let Some(readme) = readme_content {
             let readme_message = Message::new(MessageRole::System, readme);
@@ -1016,15 +1023,8 @@ impl<W: UiWriter> Agent<W> {
         }
     }

-    fn get_configured_context_length(config: &Config, providers: &ProviderRegistry) -> Result<u32> {
-        // First, check if there's a global max_context_length override in agent config
-        if let Some(max_context_length) = config.agent.max_context_length {
-            debug!("Using configured agent.max_context_length: {}", max_context_length);
-            return Ok(max_context_length);
-        }
-
-        // Get the configured max_tokens for the current provider
-        fn get_provider_max_tokens(config: &Config, provider_name: &str) -> Option<u32> {
+    /// Get the configured max_tokens for a provider from top-level config
+    fn provider_max_tokens(config: &Config, provider_name: &str) -> Option<u32> {
         match provider_name {
             "anthropic" => config.providers.anthropic.as_ref()?.max_tokens,
             "openai" => config.providers.openai.as_ref()?.max_tokens,
@@ -1034,6 +1034,61 @@ impl<W: UiWriter> Agent<W> {
         }
     }

+    /// Resolve the max_tokens to use for a given provider, applying fallbacks
+    fn resolve_max_tokens(&self, provider_name: &str) -> u32 {
+        match provider_name {
+            "databricks" => Self::provider_max_tokens(&self.config, "databricks")
+                .or(Some(self.config.agent.fallback_default_max_tokens as u32))
+                .unwrap_or(32000),
+            other => Self::provider_max_tokens(&self.config, other)
+                .or(Some(self.config.agent.fallback_default_max_tokens as u32))
+                .unwrap_or(16000),
+        }
+    }
+
+    /// Print provider diagnostics through the UiWriter for visibility
+    pub fn print_provider_banner(&self, role_label: &str) {
+        if let Ok((provider_name, model)) = self.get_provider_info() {
+            let max_tokens = self.resolve_max_tokens(&provider_name);
+            let context_len = self.context_window.total_tokens;
+
+            let mut details = vec![
+                format!("provider={}", provider_name),
+                format!("model={}", model),
+                format!("max_tokens={}", max_tokens),
+                format!("context_window_length={}", context_len),
+            ];
+
+            if let Ok(provider) = self.providers.get(None) {
+                details.push(format!(
+                    "native_tools={}",
+                    if provider.has_native_tool_calling() {
+                        "yes"
+                    } else {
+                        "no"
+                    }
+                ));
+                if provider.supports_cache_control() {
+                    details.push("cache_control=yes".to_string());
+                }
+            }
+
+            self.ui_writer
+                .print_context_status(&format!("{}: {}", role_label, details.join(", ")));
+        }
+    }
+
+    fn get_configured_context_length(
+        config: &Config,
+        providers: &ProviderRegistry,
+        warnings: &mut Vec<String>,
+    ) -> Result<u32> {
+        // First, check if there's a global max_context_length override in agent config
+        if let Some(max_context_length) = config.agent.max_context_length {
+            debug!("Using configured agent.max_context_length: {}", max_context_length);
+            return Ok(max_context_length);
+        }
+
         // Get the active provider to determine context length
         let provider = providers.get(None)?;
         let provider_name = provider.name();
@@ -1060,25 +1115,45 @@ impl<W: UiWriter> Agent<W> {
        }
        "openai" => {
            // gpt-5 has 400k window
-           get_provider_max_tokens(config, "openai").unwrap_or(400000)
+           if let Some(max_tokens) = Self::provider_max_tokens(config, "openai") {
+               warnings.push(format!(
+                   "Context length falling back to max_tokens ({}) for provider=openai",
+                   max_tokens
+               ));
+               max_tokens
+           } else {
+               400000
+           }
        }
        "anthropic" => {
            // Claude models have large context windows
            // Use configured max_tokens or fall back to default
-           get_provider_max_tokens(config, "anthropic").unwrap_or(200000)
+           if let Some(max_tokens) = Self::provider_max_tokens(config, "anthropic") {
+               warnings.push(format!(
+                   "Context length falling back to max_tokens ({}) for provider=anthropic",
+                   max_tokens
+               ));
+               max_tokens
+           } else {
+               200000
+           }
        }
        "databricks" => {
            // Databricks models have varying context windows depending on the model
            // Use configured max_tokens or fall back to model-specific defaults
-           get_provider_max_tokens(config, "databricks").unwrap_or_else(|| {
-               if model_name.contains("claude") {
+           if let Some(max_tokens) = Self::provider_max_tokens(config, "databricks") {
+               warnings.push(format!(
+                   "Context length falling back to max_tokens ({}) for provider=databricks",
+                   max_tokens
+               ));
+               max_tokens
+           } else if model_name.contains("claude") {
                200000 // Claude models on Databricks have large context windows
            } else if model_name.contains("llama") || model_name.contains("dbrx") {
                32768 // DBRX supports 32k context
            } else {
                16384 // Conservative default for other Databricks models
            }
-           })
        }
        _ => config.agent.fallback_default_max_tokens as u32,
    };
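
Each non-databricks arm now follows the same pattern: prefer the provider's configured max_tokens as the context length, recording a warning since a per-request output limit is standing in for a window size, and otherwise use a known model default. A condensed standalone sketch of one arm (simplified; the real databricks arm also matches on model names):

    // Condensed sketch of the per-provider fallback pattern used above.
    fn arm_context_length(
        configured_max_tokens: Option<u32>, // provider-level max_tokens, if any
        default_window: u32,                // e.g. 400000 for openai
        provider: &str,
        warnings: &mut Vec<String>,
    ) -> u32 {
        if let Some(max_tokens) = configured_max_tokens {
            warnings.push(format!(
                "Context length falling back to max_tokens ({}) for provider={}",
                max_tokens, provider
            ));
            max_tokens
        } else {
            default_window
        }
    }

    fn main() {
        let mut warnings = Vec::new();
        assert_eq!(arm_context_length(None, 400000, "openai", &mut warnings), 400000);
        assert_eq!(arm_context_length(Some(4096), 400000, "openai", &mut warnings), 4096);
        assert_eq!(warnings.len(), 1);
    }
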
@@ -1548,17 +1623,8 @@ If you can complete it with 1-2 tool calls, skip TODO.
        };
        drop(provider); // Drop the provider reference to avoid borrowing issues

-       // Get max_tokens from provider configuration
-       let max_tokens = match provider_name.as_str() {
-           "databricks" => {
-               // Use the model's maximum limit for Databricks to allow large file generation
-               Some(32000)
-           }
-           _ => {
-               // Default for other providers
-               Some(16000)
-           }
-       };
+       // Get max_tokens from provider configuration, falling back to sensible defaults
+       let max_tokens = Some(self.resolve_max_tokens(&provider_name));

        let request = CompletionRequest {
            messages,
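
With this change the completion request and the provider banner agree on the limit: both go through resolve_max_tokens, so unconfigured setups keep the old hardcoded defaults (32000 for Databricks, 16000 elsewhere), while a provider-level max_tokens or agent.fallback_default_max_tokens now takes effect on the request path as well.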