Files
g3/crates/g3-core/src/provider_registration.rs
Dhanji R. Prasanna a902be1562 Refactor system prompts to eliminate duplication; upgrade embedded provider
- Refactor prompts.rs: extract shared sections (intro, TODO, workspace memory,
  web research, response guidelines) used by both native and non-native prompts
- Fix typo in native prompt: "save them.." -> "save them."
- Fix non-native prompt: add missing closing braces in JSON examples,
  add IMPORTANT steps section, align with native prompt quality
- Add 9 unit tests to verify both prompts contain required sections
- Upgrade llama-cpp-2 dependency and refactor embedded provider
- Update config.example.toml with embedded model examples
- Update workspace memory
2026-01-28 09:56:39 +11:00

235 lines
8.7 KiB
Rust

//! Provider registration logic for the Agent.
//!
//! This module handles the registration of LLM providers (Embedded, OpenAI, OpenAI-compatible,
//! Anthropic, Databricks)
//! based on configuration. It consolidates the duplicated registration patterns into a single
//! cohesive module.
use anyhow::Result;
use g3_config::Config;
use g3_providers::ProviderRegistry;
use tracing::debug;
/// Computes the list of provider references that need to be registered.
///
/// The configured default provider is always the first entry. When
/// `is_autonomous` is true, the configured coach and player providers (if any)
/// are appended in that order, skipping any reference already in the list.
/// When `is_autonomous` is false, the result contains only the default provider.
pub fn determine_providers_to_register(config: &Config, is_autonomous: bool) -> Vec<String> {
    let mut selected = vec![config.providers.default_provider.clone()];
    if !is_autonomous {
        return selected;
    }

    // Coach first, then player; each is optional and must not be duplicated.
    let role_providers = config
        .providers
        .coach
        .iter()
        .chain(config.providers.player.iter());
    for role_provider in role_providers {
        if !selected.contains(role_provider) {
            selected.push(role_provider.clone());
        }
    }
    selected
}
/// Checks if a provider reference should be registered.
///
/// A provider should be registered if:
/// - Its full reference (e.g., "openai.default") is in the list, OR
/// - Any provider of that type is in the list, i.e. some entry starts with
///   `"<provider_type>."` (e.g., "openai.gpt4" enables every "openai.*" config).
///
/// Note that the second rule subsumes the first: the full reference always
/// starts with the type prefix. The exact-match check is kept to make the
/// intent explicit. An entry that is only the bare type name with no dot
/// (e.g., "openai") matches neither rule.
///
/// # Arguments
/// * `providers_to_register` - requested provider references.
/// * `provider_type` - provider family used as the prefix before the dot.
/// * `config_name` - name of the specific configuration within that family.
///
/// # Returns
/// `true` if any entry in `providers_to_register` satisfies either rule.
fn should_register(providers_to_register: &[String], provider_type: &str, config_name: &str) -> bool {
    // Both strings are invariant across the list, so build them once instead
    // of allocating a fresh prefix for every element inside the closure.
    let type_prefix = format!("{}.", provider_type);
    let full_ref = format!("{}{}", type_prefix, config_name);
    providers_to_register
        .iter()
        .any(|p| *p == full_ref || p.starts_with(&type_prefix))
}
/// Builds a provider registry containing every configured provider selected by
/// `providers_to_register`, then marks the configured default provider.
///
/// Provider families are registered in this order: embedded, OpenAI,
/// OpenAI-compatible, Anthropic, Databricks. The function is async because the
/// Databricks registration step is awaited (its OAuth path requires async).
///
/// # Errors
/// Returns the first error produced by any registration step or by setting the
/// default provider on the registry.
pub async fn register_providers(
    config: &Config,
    providers_to_register: &[String],
) -> Result<ProviderRegistry> {
    let mut registry = ProviderRegistry::new();

    // Synchronous provider families.
    register_embedded_providers(config, providers_to_register, &mut registry)?;
    register_openai_providers(config, providers_to_register, &mut registry)?;
    register_openai_compatible_providers(config, providers_to_register, &mut registry)?;
    register_anthropic_providers(config, providers_to_register, &mut registry)?;

    // Databricks is the only family whose registration must be awaited.
    register_databricks_providers(config, providers_to_register, &mut registry).await?;

    // Select the default only after all registrations have run.
    let default_provider = &config.providers.default_provider;
    debug!("Setting default provider to: {}", default_provider);
    registry.set_default(default_provider)?;
    debug!("Default provider set successfully");

    Ok(registry)
}
/// Registers every embedded provider from the configuration that is selected
/// by `providers_to_register`.
///
/// Each selected entry is registered under the name `embedded.<name>`.
///
/// # Errors
/// Propagates the error from constructing an embedded provider; entries after
/// the failing one are not processed.
fn register_embedded_providers(
    config: &Config,
    providers_to_register: &[String],
    registry: &mut ProviderRegistry,
) -> Result<()> {
    for (name, settings) in &config.providers.embedded {
        // Skip configurations that were not requested.
        if !should_register(providers_to_register, "embedded", name) {
            continue;
        }

        let provider = g3_providers::EmbeddedProvider::new_with_name(
            format!("embedded.{}", name),
            settings.model_path.clone(),
            settings.model_type.clone(),
            settings.context_length,
            settings.max_tokens,
            settings.temperature,
            settings.gpu_layers,
            settings.threads,
        )?;
        registry.register(provider);
    }

    Ok(())
}
/// Registers every OpenAI provider from the configuration that is selected by
/// `providers_to_register`.
///
/// Each selected entry is registered under the name `openai.<name>`.
///
/// # Errors
/// Propagates the error from constructing an OpenAI provider; entries after
/// the failing one are not processed.
fn register_openai_providers(
    config: &Config,
    providers_to_register: &[String],
    registry: &mut ProviderRegistry,
) -> Result<()> {
    for (name, settings) in &config.providers.openai {
        // Skip configurations that were not requested.
        if !should_register(providers_to_register, "openai", name) {
            continue;
        }

        let provider = g3_providers::OpenAIProvider::new_with_name(
            format!("openai.{}", name),
            settings.api_key.clone(),
            Some(settings.model.clone()),
            settings.base_url.clone(),
            settings.max_tokens,
            settings.temperature,
        )?;
        registry.register(provider);
    }

    Ok(())
}
/// Registers OpenAI-compatible providers (e.g., OpenRouter, Groq) from the
/// configuration that are selected by `providers_to_register`.
///
/// Unlike the other provider families, the configuration key itself acts as
/// the provider type: an entry is selected via
/// `should_register(.., name, "default")`, and it is registered under the bare
/// key `name` (no `<type>.` prefix is added).
///
/// # Errors
/// Propagates the error from constructing a provider; entries after the
/// failing one are not processed.
fn register_openai_compatible_providers(
    config: &Config,
    providers_to_register: &[String],
    registry: &mut ProviderRegistry,
) -> Result<()> {
    for (name, settings) in &config.providers.openai_compatible {
        // The map key doubles as the provider type for the selection check.
        if !should_register(providers_to_register, name, "default") {
            continue;
        }

        let provider = g3_providers::OpenAIProvider::new_with_name(
            name.clone(),
            settings.api_key.clone(),
            Some(settings.model.clone()),
            settings.base_url.clone(),
            settings.max_tokens,
            settings.temperature,
        )?;
        registry.register(provider);
    }

    Ok(())
}
/// Registers every Anthropic provider from the configuration that is selected
/// by `providers_to_register`.
///
/// Each selected entry is registered under the name `anthropic.<name>`.
///
/// # Errors
/// Propagates the error from constructing an Anthropic provider; entries after
/// the failing one are not processed.
fn register_anthropic_providers(
    config: &Config,
    providers_to_register: &[String],
    registry: &mut ProviderRegistry,
) -> Result<()> {
    for (name, settings) in &config.providers.anthropic {
        // Skip configurations that were not requested.
        if !should_register(providers_to_register, "anthropic", name) {
            continue;
        }

        let provider = g3_providers::AnthropicProvider::new_with_name(
            format!("anthropic.{}", name),
            settings.api_key.clone(),
            Some(settings.model.clone()),
            settings.max_tokens,
            settings.temperature,
            settings.cache_config.clone(),
            settings.enable_1m_context,
            settings.thinking_budget_tokens,
        )?;
        registry.register(provider);
    }

    Ok(())
}
/// Registers every Databricks provider from the configuration that is selected
/// by `providers_to_register`.
///
/// Each selected entry is registered under the name `databricks.<name>`. When
/// the entry has a token, the provider is built with token authentication;
/// otherwise it is built through the OAuth constructor, which is awaited. That
/// awaited path is why this function is async.
///
/// # Errors
/// Propagates the error from either constructor; entries after the failing one
/// are not processed.
async fn register_databricks_providers(
    config: &Config,
    providers_to_register: &[String],
    registry: &mut ProviderRegistry,
) -> Result<()> {
    for (name, settings) in &config.providers.databricks {
        // Skip configurations that were not requested.
        if !should_register(providers_to_register, "databricks", name) {
            continue;
        }

        let provider_id = format!("databricks.{}", name);
        let provider = match &settings.token {
            // Token-based authentication.
            Some(token) => g3_providers::DatabricksProvider::from_token_with_name(
                provider_id,
                settings.host.clone(),
                token.clone(),
                settings.model.clone(),
                settings.max_tokens,
                settings.temperature,
            )?,
            // No token configured: fall back to OAuth authentication.
            None => {
                g3_providers::DatabricksProvider::from_oauth_with_name(
                    provider_id,
                    settings.host.clone(),
                    settings.model.clone(),
                    settings.max_tokens,
                    settings.temperature,
                )
                .await?
            }
        };
        registry.register(provider);
    }

    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned list of provider references from string literals.
    fn refs(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_should_register_exact_match() {
        let requested = refs(&["openai.default"]);

        assert!(should_register(&requested, "openai", "default"));
        // With openai.default requested, ALL openai.* configurations are
        // registered. This is intentional: the original code registered every
        // provider of a requested type.
        assert!(should_register(&requested, "openai", "other"));
        assert!(!should_register(&requested, "anthropic", "default"));
    }

    #[test]
    fn test_should_register_type_prefix() {
        let requested = refs(&["openai.gpt4"]);

        // Requesting openai.gpt4 enables any openai.* configuration.
        assert!(should_register(&requested, "openai", "gpt4"));
        assert!(should_register(&requested, "openai", "other")); // prefix match
        assert!(!should_register(&requested, "anthropic", "default"));
    }

    #[test]
    fn test_determine_providers_normal_mode() {
        // A default config is enough: normal mode only reads the default provider.
        let config = Config::default();

        let selected = determine_providers_to_register(&config, false);

        assert_eq!(selected.len(), 1);
        assert_eq!(selected[0], config.providers.default_provider);
    }
}