fix bad max_tokens and context_window logic
for non-databricks code
This commit is contained in:
@@ -22,12 +22,13 @@ use_oauth = true
|
|||||||
|
|
||||||
[providers.anthropic]
|
[providers.anthropic]
|
||||||
api_key = "your-anthropic-api-key"
|
api_key = "your-anthropic-api-key"
|
||||||
model = "claude-3-haiku-20240307" # Using a faster model for player
|
model = "claude-sonnet-4-5"
|
||||||
max_tokens = 4096
|
max_tokens = 4096
|
||||||
temperature = 0.3 # Slightly higher temperature for more creative implementations
|
temperature = 0.3 # Slightly higher temperature for more creative implementations
|
||||||
# cache_config = "ephemeral" # Optional: Enable prompt caching
|
# cache_config = "ephemeral" # Optional: Enable prompt caching
|
||||||
# Options: "ephemeral", "5minute", "1hour"
|
# Options: "ephemeral", "5minute", "1hour"
|
||||||
# Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs.
|
# Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs.
|
||||||
|
# enable_1m_context = true # optional, more expensive
|
||||||
|
|
||||||
[agent]
|
[agent]
|
||||||
fallback_default_max_tokens = 8192
|
fallback_default_max_tokens = 8192
|
||||||
|
|||||||
@@ -14,14 +14,16 @@ max_tokens = 4096 # Per-request output limit (how many tokens the model can gen
|
|||||||
# Note: This is different from max_context_length (total conversation history size)
|
# Note: This is different from max_context_length (total conversation history size)
|
||||||
temperature = 0.1
|
temperature = 0.1
|
||||||
use_oauth = true
|
use_oauth = true
|
||||||
# cache_config = "ephemeral" # Optional: Enable prompt caching for Claude models on Databricks
|
|
||||||
# Options: "ephemeral", "5minute", "1hour"
|
[providers.anthropic]
|
||||||
# Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs.
|
api_key = "your-anthropic-api-key"
|
||||||
# The cache control will be automatically applied to:
|
model = "claude-sonnet-4-5"
|
||||||
# - The system prompt at the start of each session
|
max_tokens = 4096
|
||||||
# - Assistant responses after every 10 tool calls
|
temperature = 0.3 # Slightly higher temperature for more creative implementations
|
||||||
# - 5minute costs $3/mtok, more details below
|
# cache_config = "ephemeral" # Optional: Enable prompt caching
|
||||||
# https://docs.claude.com/en/docs/build-with-claude/prompt-caching#pricing
|
# Options: "ephemeral", "5minute", "1hour"
|
||||||
|
# Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs.
|
||||||
|
# enable_1m_context = true # optional, more expensive
|
||||||
|
|
||||||
|
|
||||||
# Multiple OpenAI-compatible providers can be configured with custom names
|
# Multiple OpenAI-compatible providers can be configured with custom names
|
||||||
|
|||||||
@@ -1686,6 +1686,9 @@ async fn run_autonomous(
|
|||||||
turn, max_turns
|
turn, max_turns
|
||||||
));
|
));
|
||||||
|
|
||||||
|
// Surface provider info for player agent
|
||||||
|
agent.print_provider_banner("Player");
|
||||||
|
|
||||||
// Player mode: implement requirements (with coach feedback if available)
|
// Player mode: implement requirements (with coach feedback if available)
|
||||||
let player_prompt = if coach_feedback.is_empty() {
|
let player_prompt = if coach_feedback.is_empty() {
|
||||||
format!(
|
format!(
|
||||||
@@ -1879,6 +1882,9 @@ async fn run_autonomous(
|
|||||||
let mut coach_agent =
|
let mut coach_agent =
|
||||||
Agent::new_autonomous_with_readme_and_quiet(coach_config, ui_writer, None, quiet).await?;
|
Agent::new_autonomous_with_readme_and_quiet(coach_config, ui_writer, None, quiet).await?;
|
||||||
|
|
||||||
|
// Surface provider info for coach agent
|
||||||
|
coach_agent.print_provider_banner("Coach");
|
||||||
|
|
||||||
// Ensure coach agent is also in the workspace directory
|
// Ensure coach agent is also in the workspace directory
|
||||||
project.enter_workspace()?;
|
project.enter_workspace()?;
|
||||||
|
|
||||||
|
|||||||
@@ -938,9 +938,16 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
debug!("Default provider set successfully");
|
debug!("Default provider set successfully");
|
||||||
|
|
||||||
// Determine context window size based on active provider
|
// Determine context window size based on active provider
|
||||||
let context_length = Self::get_configured_context_length(&config, &providers)?;
|
let mut context_warnings = Vec::new();
|
||||||
|
let context_length =
|
||||||
|
Self::get_configured_context_length(&config, &providers, &mut context_warnings)?;
|
||||||
let mut context_window = ContextWindow::new(context_length);
|
let mut context_window = ContextWindow::new(context_length);
|
||||||
|
|
||||||
|
// Surface any context warnings to the user via UI
|
||||||
|
for warning in context_warnings {
|
||||||
|
ui_writer.print_context_status(&format!("⚠️ {}", warning));
|
||||||
|
}
|
||||||
|
|
||||||
// If README content is provided, add it as the first system message
|
// If README content is provided, add it as the first system message
|
||||||
if let Some(readme) = readme_content {
|
if let Some(readme) = readme_content {
|
||||||
let readme_message = Message::new(MessageRole::System, readme);
|
let readme_message = Message::new(MessageRole::System, readme);
|
||||||
@@ -1016,15 +1023,8 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_configured_context_length(config: &Config, providers: &ProviderRegistry) -> Result<u32> {
|
/// Get the configured max_tokens for a provider from top-level config
|
||||||
// First, check if there's a global max_context_length override in agent config
|
fn provider_max_tokens(config: &Config, provider_name: &str) -> Option<u32> {
|
||||||
if let Some(max_context_length) = config.agent.max_context_length {
|
|
||||||
debug!("Using configured agent.max_context_length: {}", max_context_length);
|
|
||||||
return Ok(max_context_length);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the configured max_tokens for the current provider
|
|
||||||
fn get_provider_max_tokens(config: &Config, provider_name: &str) -> Option<u32> {
|
|
||||||
match provider_name {
|
match provider_name {
|
||||||
"anthropic" => config.providers.anthropic.as_ref()?.max_tokens,
|
"anthropic" => config.providers.anthropic.as_ref()?.max_tokens,
|
||||||
"openai" => config.providers.openai.as_ref()?.max_tokens,
|
"openai" => config.providers.openai.as_ref()?.max_tokens,
|
||||||
@@ -1034,6 +1034,61 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Resolve the max_tokens to use for a given provider, applying fallbacks
|
||||||
|
fn resolve_max_tokens(&self, provider_name: &str) -> u32 {
|
||||||
|
match provider_name {
|
||||||
|
"databricks" => Self::provider_max_tokens(&self.config, "databricks")
|
||||||
|
.or(Some(self.config.agent.fallback_default_max_tokens as u32))
|
||||||
|
.unwrap_or(32000),
|
||||||
|
other => Self::provider_max_tokens(&self.config, other)
|
||||||
|
.or(Some(self.config.agent.fallback_default_max_tokens as u32))
|
||||||
|
.unwrap_or(16000),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Print provider diagnostics through the UiWriter for visibility
|
||||||
|
pub fn print_provider_banner(&self, role_label: &str) {
|
||||||
|
if let Ok((provider_name, model)) = self.get_provider_info() {
|
||||||
|
let max_tokens = self.resolve_max_tokens(&provider_name);
|
||||||
|
let context_len = self.context_window.total_tokens;
|
||||||
|
|
||||||
|
let mut details = vec![
|
||||||
|
format!("provider={}", provider_name),
|
||||||
|
format!("model={}", model),
|
||||||
|
format!("max_tokens={}", max_tokens),
|
||||||
|
format!("context_window_length={}", context_len),
|
||||||
|
];
|
||||||
|
|
||||||
|
if let Ok(provider) = self.providers.get(None) {
|
||||||
|
details.push(format!(
|
||||||
|
"native_tools={}",
|
||||||
|
if provider.has_native_tool_calling() {
|
||||||
|
"yes"
|
||||||
|
} else {
|
||||||
|
"no"
|
||||||
|
}
|
||||||
|
));
|
||||||
|
if provider.supports_cache_control() {
|
||||||
|
details.push("cache_control=yes".to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.ui_writer
|
||||||
|
.print_context_status(&format!("{}: {}", role_label, details.join(", ")));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_configured_context_length(
|
||||||
|
config: &Config,
|
||||||
|
providers: &ProviderRegistry,
|
||||||
|
warnings: &mut Vec<String>,
|
||||||
|
) -> Result<u32> {
|
||||||
|
// First, check if there's a global max_context_length override in agent config
|
||||||
|
if let Some(max_context_length) = config.agent.max_context_length {
|
||||||
|
debug!("Using configured agent.max_context_length: {}", max_context_length);
|
||||||
|
return Ok(max_context_length);
|
||||||
|
}
|
||||||
|
|
||||||
// Get the active provider to determine context length
|
// Get the active provider to determine context length
|
||||||
let provider = providers.get(None)?;
|
let provider = providers.get(None)?;
|
||||||
let provider_name = provider.name();
|
let provider_name = provider.name();
|
||||||
@@ -1060,25 +1115,45 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
}
|
}
|
||||||
"openai" => {
|
"openai" => {
|
||||||
// gpt-5 has 400k window
|
// gpt-5 has 400k window
|
||||||
get_provider_max_tokens(config, "openai").unwrap_or(400000)
|
if let Some(max_tokens) = Self::provider_max_tokens(config, "openai") {
|
||||||
|
warnings.push(format!(
|
||||||
|
"Context length falling back to max_tokens ({}) for provider=openai",
|
||||||
|
max_tokens
|
||||||
|
));
|
||||||
|
max_tokens
|
||||||
|
} else {
|
||||||
|
400000
|
||||||
|
}
|
||||||
}
|
}
|
||||||
"anthropic" => {
|
"anthropic" => {
|
||||||
// Claude models have large context windows
|
// Claude models have large context windows
|
||||||
// Use configured max_tokens or fall back to default
|
// Use configured max_tokens or fall back to default
|
||||||
get_provider_max_tokens(config, "anthropic").unwrap_or(200000)
|
if let Some(max_tokens) = Self::provider_max_tokens(config, "anthropic") {
|
||||||
|
warnings.push(format!(
|
||||||
|
"Context length falling back to max_tokens ({}) for provider=anthropic",
|
||||||
|
max_tokens
|
||||||
|
));
|
||||||
|
max_tokens
|
||||||
|
} else {
|
||||||
|
200000
|
||||||
|
}
|
||||||
}
|
}
|
||||||
"databricks" => {
|
"databricks" => {
|
||||||
// Databricks models have varying context windows depending on the model
|
// Databricks models have varying context windows depending on the model
|
||||||
// Use configured max_tokens or fall back to model-specific defaults
|
// Use configured max_tokens or fall back to model-specific defaults
|
||||||
get_provider_max_tokens(config, "databricks").unwrap_or_else(|| {
|
if let Some(max_tokens) = Self::provider_max_tokens(config, "databricks") {
|
||||||
if model_name.contains("claude") {
|
warnings.push(format!(
|
||||||
|
"Context length falling back to max_tokens ({}) for provider=databricks",
|
||||||
|
max_tokens
|
||||||
|
));
|
||||||
|
max_tokens
|
||||||
|
} else if model_name.contains("claude") {
|
||||||
200000 // Claude models on Databricks have large context windows
|
200000 // Claude models on Databricks have large context windows
|
||||||
} else if model_name.contains("llama") || model_name.contains("dbrx") {
|
} else if model_name.contains("llama") || model_name.contains("dbrx") {
|
||||||
32768 // DBRX supports 32k context
|
32768 // DBRX supports 32k context
|
||||||
} else {
|
} else {
|
||||||
16384 // Conservative default for other Databricks models
|
16384 // Conservative default for other Databricks models
|
||||||
}
|
}
|
||||||
})
|
|
||||||
}
|
}
|
||||||
_ => config.agent.fallback_default_max_tokens as u32,
|
_ => config.agent.fallback_default_max_tokens as u32,
|
||||||
};
|
};
|
||||||
@@ -1548,17 +1623,8 @@ If you can complete it with 1-2 tool calls, skip TODO.
|
|||||||
};
|
};
|
||||||
drop(provider); // Drop the provider reference to avoid borrowing issues
|
drop(provider); // Drop the provider reference to avoid borrowing issues
|
||||||
|
|
||||||
// Get max_tokens from provider configuration
|
// Get max_tokens from provider configuration, falling back to sensible defaults
|
||||||
let max_tokens = match provider_name.as_str() {
|
let max_tokens = Some(self.resolve_max_tokens(&provider_name));
|
||||||
"databricks" => {
|
|
||||||
// Use the model's maximum limit for Databricks to allow large file generation
|
|
||||||
Some(32000)
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
// Default for other providers
|
|
||||||
Some(16000)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let request = CompletionRequest {
|
let request = CompletionRequest {
|
||||||
messages,
|
messages,
|
||||||
|
|||||||
Reference in New Issue
Block a user