embedded model support

Dhanji Prasanna
2025-09-06 13:32:37 +10:00
parent 80e5178a1f
commit 1834b8946c
8 changed files with 793 additions and 14 deletions

View File

@@ -12,6 +12,7 @@ pub struct Config {
pub struct ProvidersConfig {
pub openai: Option<OpenAIConfig>,
pub anthropic: Option<AnthropicConfig>,
pub embedded: Option<EmbeddedConfig>,
pub default_provider: String,
}
@@ -32,6 +33,17 @@ pub struct AnthropicConfig {
pub temperature: Option<f32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddedConfig {
pub model_path: String,
pub model_type: String, // e.g., "llama", "mistral", "codellama"
pub context_length: Option<u32>,
pub max_tokens: Option<u32>,
pub temperature: Option<f32>,
pub gpu_layers: Option<u32>, // Number of layers to offload to GPU
pub threads: Option<u32>, // Number of CPU threads to use
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
pub max_context_length: usize,
@@ -45,6 +57,7 @@ impl Default for Config {
providers: ProvidersConfig {
openai: None,
anthropic: None,
embedded: None,
default_provider: "openai".to_string(),
},
agent: AgentConfig {

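With these additions, enabling the embedded provider is a matter of populating the new field and switching the default. A minimal sketch, using only the fields shown in the structs above (the model path and numeric values are illustrative, not part of this commit):

let mut config = Config::default();
config.providers.embedded = Some(EmbeddedConfig {
    model_path: "~/models/codellama-7b-instruct.Q4_K_M.gguf".to_string(), // illustrative path
    model_type: "codellama".to_string(),
    context_length: Some(8192),
    max_tokens: Some(1024),
    temperature: Some(0.1),
    gpu_layers: Some(32), // layers offloaded to the GPU (Metal on macOS)
    threads: Some(8),
});
config.providers.default_provider = "embedded".to_string();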
View File

@@ -18,3 +18,5 @@ serde_json = { workspace = true }
uuid = { workspace = true }
async-trait = "0.1"
tokio-stream = "0.1"
llama_cpp = { version = "0.3.2", features = ["metal"] }
shellexpand = "3.1"

View File

@@ -72,6 +72,19 @@ impl Agent {
providers.register(anthropic_provider);
}
if let Some(embedded_config) = &config.providers.embedded {
let embedded_provider = crate::providers::embedded::EmbeddedProvider::new(
embedded_config.model_path.clone(),
embedded_config.model_type.clone(),
embedded_config.context_length,
embedded_config.max_tokens,
embedded_config.temperature,
embedded_config.gpu_layers,
embedded_config.threads,
)?;
providers.register(embedded_provider);
}
// Set default provider
providers.set_default(&config.providers.default_provider)?;
@@ -522,4 +535,5 @@ impl std::fmt::Display for AnalysisResult {
pub mod providers {
pub mod anthropic;
pub mod openai;
pub mod embedded;
}

View File

@@ -0,0 +1,362 @@
use g3_providers::{LLMProvider, CompletionRequest, CompletionResponse, CompletionStream, CompletionChunk, Usage, Message, MessageRole};
use anyhow::Result;
use llama_cpp::{LlamaModel, LlamaSession, LlamaParams, SessionParams, standard_sampler::{StandardSampler, SamplerStage}};
use std::path::Path;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{debug, info, error, warn};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use std::sync::atomic::{AtomicBool, Ordering};
pub struct EmbeddedProvider {
model: Arc<LlamaModel>,
session: Arc<Mutex<LlamaSession>>,
model_name: String,
max_tokens: u32,
temperature: f32,
context_length: u32,
generation_active: Arc<AtomicBool>,
}
impl EmbeddedProvider {
pub fn new(
model_path: String,
model_type: String,
context_length: Option<u32>,
max_tokens: Option<u32>,
temperature: Option<f32>,
gpu_layers: Option<u32>,
threads: Option<u32>,
) -> Result<Self> {
info!("Loading embedded model from: {}", model_path);
// Expand tilde in path
let expanded_path = shellexpand::tilde(&model_path);
let model_path = Path::new(expanded_path.as_ref());
if !model_path.exists() {
anyhow::bail!("Model file not found: {}", model_path.display());
}
// Set up model parameters
let mut params = LlamaParams::default();
if let Some(gpu_layers) = gpu_layers {
params.n_gpu_layers = gpu_layers;
info!("Using {} GPU layers", gpu_layers);
}
let context_size = context_length.unwrap_or(4096);
info!("Using context length: {}", context_size);
// Load the model
info!("Loading model...");
let model = LlamaModel::load_from_file(model_path, params)
.map_err(|e| anyhow::anyhow!("Failed to load model: {}", e))?;
// Create session with parameters
let mut session_params = SessionParams::default();
session_params.n_ctx = context_size;
if let Some(threads) = threads {
session_params.n_threads = threads;
}
let session = model.create_session(session_params)
.map_err(|e| anyhow::anyhow!("Failed to create session: {}", e))?;
info!("Successfully loaded {} model", model_type);
Ok(Self {
model: Arc::new(model),
session: Arc::new(Mutex::new(session)),
model_name: format!("embedded-{}", model_type),
max_tokens: max_tokens.unwrap_or(2048),
temperature: temperature.unwrap_or(0.1),
context_length: context_size,
generation_active: Arc::new(AtomicBool::new(false)),
})
}
fn format_messages(&self, messages: &[Message]) -> String {
// Use proper prompt format for CodeLlama
let mut formatted = String::new();
for message in messages {
match message.role {
MessageRole::System => {
formatted.push_str(&format!("[INST] <<SYS>>\n{}\n<</SYS>>\n\n", message.content));
}
MessageRole::User => {
formatted.push_str(&format!("{} [/INST] ", message.content));
}
MessageRole::Assistant => {
formatted.push_str(&format!("{} </s><s>[INST] ", message.content));
}
}
}
formatted
}
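// Worked example (illustrative): for a system message followed by a user
// message, the formatter above yields a Llama 2 / CodeLlama-style prompt:
//
//   [INST] <<SYS>>
//   You are a concise coding assistant.
//   <</SYS>>
//
//   Write a hello world program in Rust. [/INST]
//
// An assistant turn would be appended as "{content} </s><s>[INST] ", closing
// the previous instruction block and opening the next one.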
async fn generate_completion(&self, prompt: &str, max_tokens: u32, temperature: f32) -> Result<String> {
let session = self.session.clone();
let prompt = prompt.to_string();
// Calculate dynamic max tokens based on available context headroom
let prompt_tokens = self.estimate_tokens(&prompt);
let available_tokens = self.context_length.saturating_sub(prompt_tokens).saturating_sub(50); // Reserve 50 tokens for safety
let dynamic_max_tokens = std::cmp::min(max_tokens as usize, available_tokens as usize);
debug!("Context calculation: prompt_tokens={}, context_length={}, available_tokens={}, dynamic_max_tokens={}",
prompt_tokens, self.context_length, available_tokens, dynamic_max_tokens);
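// Worked example of the headroom math (illustrative numbers, default 4096
// context): a ~6_000-char prompt estimates to 1_500 tokens, leaving
// 4096 - 1500 - 50 = 2546 tokens of headroom, so a 2048-token request keeps
// its full budget; a ~12_000-char prompt (~3_000 tokens) would be clamped to
// 4096 - 3000 - 50 = 1046 tokens.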
// Add timeout to the entire operation
let timeout_duration = std::time::Duration::from_secs(30); // Increased timeout for larger contexts
let result = tokio::time::timeout(timeout_duration, tokio::task::spawn_blocking(move || {
let mut session = match session.try_lock() {
Ok(ctx) => ctx,
Err(_) => return Err(anyhow::anyhow!("Model is busy, please try again")),
};
debug!("Starting inference with prompt length: {} chars, estimated {} tokens", prompt.len(), prompt_tokens);
// Set context to the prompt
debug!("About to call set_context...");
session.set_context(&prompt)
.map_err(|e| anyhow::anyhow!("Failed to set context: {}", e))?;
debug!("set_context completed successfully");
// Create sampler with temperature
debug!("Creating sampler...");
let stages = vec![
SamplerStage::Temperature(temperature),
SamplerStage::TopK(40),
SamplerStage::TopP(0.9),
];
let sampler = StandardSampler::new_softmax(stages, 1);
debug!("Sampler created successfully");
// Start completion with dynamic max tokens
debug!("About to call start_completing_with with {} max tokens...", dynamic_max_tokens);
let mut completion_handle = session.start_completing_with(sampler, dynamic_max_tokens)
.map_err(|e| anyhow::anyhow!("Failed to start completion: {}", e))?;
debug!("start_completing_with completed successfully");
let mut generated_text = String::new();
let mut token_count = 0;
let start_time = std::time::Instant::now();
debug!("Starting token generation loop...");
// Generate tokens with dynamic limits
while let Some(token) = completion_handle.next_token() {
// Check for timeout on each token
if start_time.elapsed() > std::time::Duration::from_secs(25) {
debug!("Token generation timeout after {} tokens", token_count);
break;
}
let token_string = session.model().token_to_piece(token);
generated_text.push_str(&token_string);
token_count += 1;
if token_count <= 10 || token_count % 50 == 0 {
debug!("Generated token {}: '{}'", token_count, token_string);
}
// Use dynamic token limit
if token_count >= dynamic_max_tokens {
debug!("Reached dynamic token limit: {}", dynamic_max_tokens);
break;
}
// Stop on completion markers
if generated_text.contains("</s>") || generated_text.contains("[/INST]") {
debug!("Hit CodeLlama stop sequence at {} tokens", token_count);
break;
}
// Stop on natural completion points after reasonable generation
if token_count >= 20 && (
generated_text.trim().ends_with("```") ||
(generated_text.contains("```") && generated_text.matches("```").count() % 2 == 0) // Complete code blocks
) {
debug!("Hit code block completion at {} tokens", token_count);
break;
}
}
debug!("Token generation loop completed. Generated {} tokens in {:?}", token_count, start_time.elapsed());
Ok((generated_text.trim().to_string(), token_count))
})).await;
match result {
Ok(inner_result) => match inner_result {
Ok(task_result) => match task_result {
Ok((text, token_count)) => {
info!("Completed generation: {} tokens (dynamic limit was {})", token_count, dynamic_max_tokens);
Ok(text)
}
Err(e) => Err(e),
},
Err(e) => Err(e.into()),
},
Err(_) => {
error!("Generation timed out after 30 seconds");
Err(anyhow::anyhow!("Generation timed out"))
}
}
}
// Helper function to estimate token count from text
fn estimate_tokens(&self, text: &str) -> u32 {
// Rough estimation: average 4 characters per token
// This is conservative - actual tokenization might be different
(text.len() as f32 / 4.0).ceil() as u32
}
}
#[async_trait::async_trait]
impl LLMProvider for EmbeddedProvider {
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
debug!("Processing completion request with {} messages", request.messages.len());
let prompt = self.format_messages(&request.messages);
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);
debug!("Formatted prompt length: {} chars", prompt.len());
let content = self.generate_completion(&prompt, max_tokens, temperature).await?;
// Estimate token usage (rough approximation)
let prompt_tokens = (prompt.len() / 4) as u32; // Rough estimate: 4 chars per token
let completion_tokens = (content.len() / 4) as u32;
Ok(CompletionResponse {
content,
usage: Usage {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
},
model: self.model_name.clone(),
})
}
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
debug!("Processing streaming request with {} messages", request.messages.len());
let prompt = self.format_messages(&request.messages);
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);
let (tx, rx) = mpsc::channel(100);
let session = self.session.clone();
let prompt = prompt.to_string();
// Spawn streaming task
tokio::task::spawn_blocking(move || {
let mut session = match session.try_lock() {
Ok(ctx) => ctx,
Err(_) => {
let _ = tx.blocking_send(Err(anyhow::anyhow!("Model is busy, please try again")));
return;
}
};
// Set context to the prompt
if let Err(e) = session.set_context(&prompt) {
let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to set context: {}", e)));
return;
}
// Create sampler with temperature
let stages = vec![
SamplerStage::Temperature(temperature),
SamplerStage::TopK(40),
SamplerStage::TopP(0.9),
];
let sampler = StandardSampler::new_softmax(stages, 1);
// Start completion
let mut completion_handle = match session.start_completing_with(sampler, max_tokens as usize) {
Ok(handle) => handle,
Err(e) => {
let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to start completion: {}", e)));
return;
}
};
let mut accumulated_text = String::new();
let mut token_count = 0;
// Stream tokens with proper limits
while let Some(token) = completion_handle.next_token() {
let token_string = session.model().token_to_piece(token);
accumulated_text.push_str(&token_string);
token_count += 1;
let chunk = CompletionChunk {
content: token_string.clone(),
finished: false,
};
if tx.blocking_send(Ok(chunk)).is_err() {
break; // Receiver dropped
}
// Enforce token limit
if token_count >= max_tokens as usize {
debug!("Reached max token limit in streaming: {}", max_tokens);
break;
}
// Stop if we hit common stop sequences
if accumulated_text.contains("### Human") ||
accumulated_text.contains("### System") ||
accumulated_text.contains("<|end|>") ||
accumulated_text.contains("</s>") ||
accumulated_text.trim().ends_with("```") {
debug!("Hit stop sequence in streaming, stopping generation");
break;
}
// Periodically check for a natural stopping point, mirroring the guard in the non-streaming path
if token_count > 0 && token_count % 100 == 0 {
debug!("Streaming: Generated {} tokens so far", token_count);
if accumulated_text.trim().len() > 50 &&
(accumulated_text.contains('\n') || accumulated_text.len() > 200) {
if accumulated_text.trim().ends_with('.') ||
accumulated_text.trim().ends_with('!') ||
accumulated_text.trim().ends_with('?') ||
accumulated_text.trim().ends_with('\n') {
debug!("Found natural stopping point in streaming at {} tokens", token_count);
break;
}
}
}
}
// Send final chunk
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
};
let _ = tx.blocking_send(Ok(final_chunk));
});
Ok(ReceiverStream::new(rx))
}
fn name(&self) -> &str {
"embedded"
}
fn model(&self) -> &str {
&self.model_name
}
}
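A minimal usage sketch for the new provider, assuming CompletionRequest and Message can be constructed directly from the fields this file reads (messages, max_tokens, temperature, role, content); any other fields on those types are not shown in this commit and would need to be filled in as well:

// Hypothetical usage; constructors and remaining fields are assumptions.
let provider = EmbeddedProvider::new(
    "~/models/codellama-7b-instruct.Q4_K_M.gguf".to_string(), // model_path (illustrative)
    "codellama".to_string(),                                  // model_type
    Some(8192),                                               // context_length
    Some(512),                                                // max_tokens
    Some(0.1),                                                // temperature
    Some(32),                                                 // gpu_layers
    Some(8),                                                  // threads
)?;
let request = CompletionRequest {
    messages: vec![
        Message { role: MessageRole::System, content: "You are a concise coding assistant.".to_string() },
        Message { role: MessageRole::User, content: "Write a hello world program in Rust.".to_string() },
    ],
    max_tokens: Some(256),
    temperature: Some(0.2),
    ..Default::default() // assumes CompletionRequest provides defaults for any other fields
};
let response = provider.complete(request).await?;
println!("{} ({} tokens)", response.content, response.usage.total_tokens);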