Auto-detect context window size from GGUF for embedded providers
- Add `context_window_size()` method to the `LLMProvider` trait
- Implement it for `EmbeddedProvider` to return the auto-detected context length
- Update `Agent` to query the provider directly instead of using hardcoded defaults
- Remove the need for model-specific context length mappings
This commit is contained in:
@@ -651,24 +651,19 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
let model_name = provider.model();
|
let model_name = provider.model();
|
||||||
|
|
||||||
// Parse provider name to get type and config name
|
// Parse provider name to get type and config name
|
||||||
let (provider_type, config_name) = provider_config::parse_provider_ref(provider_name);
|
let (provider_type, _config_name) = provider_config::parse_provider_ref(provider_name);
|
||||||
|
|
||||||
// Use provider-specific context length if available
|
// Use provider-specific context length if available
|
||||||
let context_length = match provider_type {
|
let context_length = match provider_type {
|
||||||
"embedded" | "embedded." => {
|
"embedded" | "embedded." => {
|
||||||
// For embedded models, use the configured context_length or model-specific defaults
|
// For embedded models, query the provider directly for its context window
|
||||||
if let Some(embedded_config) = config.providers.embedded.get(config_name) {
|
// The provider auto-detects this from the GGUF file
|
||||||
embedded_config.context_length.unwrap_or_else(|| {
|
if let Some(ctx_size) = provider.context_window_size() {
|
||||||
// Model-specific defaults for embedded models
|
debug!(
|
||||||
match &embedded_config.model_type.to_lowercase()[..] {
|
"Using context window size {} from embedded provider",
|
||||||
"codellama" => 16384, // CodeLlama supports 16k context
|
ctx_size
|
||||||
"llama" => 4096, // Base Llama models
|
);
|
||||||
"glm4" => 32768, // GLM-4 supports 32k context
|
ctx_size
|
||||||
"mistral" => 8192, // Mistral models
|
|
||||||
"qwen" => 32768, // Qwen2.5 supports 32k context
|
|
||||||
_ => 4096, // Conservative default
|
|
||||||
}
|
|
||||||
})
|
|
||||||
} else {
|
} else {
|
||||||
config.agent.fallback_default_max_tokens as u32
|
config.agent.fallback_default_max_tokens as u32
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -700,6 +700,10 @@ impl LLMProvider for EmbeddedProvider {
|
|||||||
fn temperature(&self) -> f32 {
|
fn temperature(&self) -> f32 {
|
||||||
self.temperature
|
self.temperature
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn context_window_size(&self) -> Option<u32> {
|
||||||
|
Some(self.context_length)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@@ -39,6 +39,12 @@ pub trait LLMProvider: Send + Sync {
|
|||||||
|
|
||||||
/// Get the configured temperature for this provider
|
/// Get the configured temperature for this provider
|
||||||
fn temperature(&self) -> f32;
|
fn temperature(&self) -> f32;
|
||||||
|
|
||||||
|
/// Get the context window size for this provider
|
||||||
|
/// Returns None if the provider doesn't have a fixed context window
|
||||||
|
fn context_window_size(&self) -> Option<u32> {
|
||||||
|
None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|||||||
Reference in New Issue
Block a user