Fix embedded provider initialization and logging

- Use global OnceLock for llama.cpp backend to prevent BackendAlreadyInitialized error
- Suppress verbose llama.cpp stderr logging during model loading
- Fix provider validation to accept "embedded.name" format (extract type before dot)
Author: Dhanji R. Prasanna
Date:   2026-01-28 10:33:10 +11:00
Parent: ba6e1f9896
Commit: e32c302023

2 changed files with 42 additions and 6 deletions


@@ -139,9 +139,10 @@ pub fn load_config_with_cli_overrides(cli: &Cli) -> Result<Config> {
     // Validate provider if specified
     if let Some(ref provider) = cli.provider {
         let valid_providers = ["anthropic", "databricks", "embedded", "openai"];
-        if !valid_providers.contains(&provider.as_str()) {
+        let provider_type = provider.split('.').next().unwrap_or(provider);
+        if !valid_providers.contains(&provider_type) {
             return Err(anyhow::anyhow!(
-                "Invalid provider '{}'. Valid options: {:?}",
+                "Invalid provider '{}'. Provider type must be one of: {:?}",
                 provider,
                 valid_providers
             ));
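
For illustration only, a minimal standalone sketch of the new rule: only the text before the first '.' is checked against the allow-list, so scoped names pass validation. The `provider_type` helper and the "embedded.qwen" string are assumptions for this sketch, not part of the diff.

```rust
/// Hypothetical helper mirroring the validation logic above.
fn provider_type(provider: &str) -> &str {
    provider.split('.').next().unwrap_or(provider)
}

fn main() {
    let valid_providers = ["anthropic", "databricks", "embedded", "openai"];
    // Scoped names such as "embedded.qwen" now pass validation...
    assert!(valid_providers.contains(&provider_type("embedded.qwen")));
    // ...plain provider types still work...
    assert!(valid_providers.contains(&provider_type("openai")));
    // ...and unknown types are still rejected.
    assert!(!valid_providers.contains(&provider_type("unknown.model")));
    println!("provider validation sketch passed");
}
```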


@@ -14,10 +14,46 @@ use llama_cpp_2::{
 use std::num::NonZeroU32;
 use std::path::PathBuf;
 use std::sync::Arc;
+use std::sync::OnceLock;
 use tokio::sync::mpsc;
 use tokio_stream::wrappers::ReceiverStream;
 use tracing::{debug, error};
 
+/// Global llama.cpp backend - can only be initialized once per process
+static LLAMA_BACKEND: OnceLock<Arc<LlamaBackend>> = OnceLock::new();
+/// Get or initialize the global llama.cpp backend
+fn get_or_init_backend() -> Result<Arc<LlamaBackend>> {
+    // Check if already initialized
+    if let Some(backend) = LLAMA_BACKEND.get() {
+        return Ok(Arc::clone(backend));
+    }
+
+    // Suppress llama.cpp's verbose logging to stderr before initialization
+    unsafe {
+        unsafe extern "C" fn void_log(
+            _level: std::ffi::c_int,
+            _text: *const std::os::raw::c_char,
+            _user_data: *mut std::os::raw::c_void,
+        ) {
+            // Intentionally empty - suppress all llama.cpp logging
+        }
+
+        // Call the underlying C function directly
+        extern "C" { fn llama_log_set(log_callback: Option<unsafe extern "C" fn(std::ffi::c_int, *const std::os::raw::c_char, *mut std::os::raw::c_void)>, user_data: *mut std::os::raw::c_void); }
+        llama_log_set(Some(void_log), std::ptr::null_mut());
+    }
+
+    // Try to initialize
+    debug!("Initializing llama.cpp backend...");
+    let backend = LlamaBackend::init()
+        .map_err(|e| anyhow::anyhow!("Failed to initialize llama.cpp backend: {:?}", e))?;
+
+    // Store it (ignore if another thread beat us to it)
+    let _ = LLAMA_BACKEND.set(Arc::new(backend));
+    let backend = LLAMA_BACKEND.get().expect("backend was just set");
+    Ok(Arc::clone(backend))
+}
+
 /// Embedded LLM provider using llama.cpp with Metal acceleration on macOS.
 ///
 /// Supports multiple model families with their native chat templates:
@@ -103,9 +139,8 @@ impl EmbeddedProvider {
             anyhow::bail!("Model file not found: {}", model_path_buf.display());
         }
 
-        // Initialize the llama.cpp backend
-        let backend = LlamaBackend::init()
-            .map_err(|e| anyhow::anyhow!("Failed to initialize llama.cpp backend: {:?}", e))?;
+        // Get or initialize the global llama.cpp backend
+        let backend = get_or_init_backend()?;
 
         // Set up model parameters
         let n_gpu_layers = gpu_layers.unwrap_or(99);
@@ -130,7 +165,7 @@ impl EmbeddedProvider {
         Ok(Self {
             name,
             model: Arc::new(model),
-            backend: Arc::new(backend),
+            backend,
             model_type: model_type.to_lowercase(),
             model_name: format!("embedded-{}", model_type),
             max_tokens,
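
As a usage note on the OnceLock change: every EmbeddedProvider constructed in the same process now clones the one stored backend instead of calling LlamaBackend::init() again, which is what previously raised BackendAlreadyInitialized. A hedged sketch of that guarantee as a test; the test module and name are hypothetical and assume llama.cpp is linked into the test binary.

```rust
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    // Hypothetical test: two calls to get_or_init_backend() must hand back
    // clones of the same Arc, so LlamaBackend::init() runs at most once.
    #[test]
    fn backend_initialized_once_per_process() {
        let first = get_or_init_backend().expect("first call initializes the backend");
        let second = get_or_init_backend().expect("second call reuses the stored backend");
        assert!(Arc::ptr_eq(&first, &second));
    }
}
```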