move embedded provider to the g3-providers crate

@@ -18,8 +18,6 @@ serde_json = { workspace = true }
uuid = { workspace = true }
async-trait = "0.1"
tokio-stream = "0.1"
llama_cpp = { version = "0.3.2", features = ["metal"] }
shellexpand = "3.1"
tokio-util = "0.7"
futures-util = "0.3"
chrono = { version = "0.4", features = ["serde"] }
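
Of the dependencies listed, llama_cpp and shellexpand serve only the embedded provider (both are used by the deleted file below), so the two lines this hunk drops are presumably those, reappearing in the destination crate's manifest. A sketch of that g3-providers Cargo.toml entry — an assumption, since the receiving side is not shown in this commit:

# g3-providers/Cargo.toml (assumed destination; not part of this diff)
[dependencies]
llama_cpp = { version = "0.3.2", features = ["metal"] }
shellexpand = "3.1"
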
@@ -408,7 +408,6 @@ mod tests {
        let truncated = truncate_for_logging(long_text, 20);
        assert!(truncated.starts_with("This is a very long "));
        assert!(truncated.contains("truncated"));
        assert!(truncated.contains("total chars"));
        assert!(truncated.contains("total bytes"));
    }
@@ -348,7 +348,7 @@ impl Agent {
         if let Some(embedded_config) = &config.providers.embedded {
             if config.providers.default_provider == "embedded" {
                 info!("Initializing embedded provider (selected as default)");
-                let embedded_provider = crate::providers::embedded::EmbeddedProvider::new(
+                let embedded_provider = g3_providers::EmbeddedProvider::new(
                     embedded_config.model_path.clone(),
                     embedded_config.model_type.clone(),
                     embedded_config.context_length,
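
The hunk cuts off after context_length; EmbeddedProvider::new (see the removed file below) takes four more optional knobs, so the call site presumably continues in kind. A sketch of the full call — every embedded_config field name past context_length is assumed, not shown in this diff:

    // Sketch only: field names after context_length are assumptions.
    let embedded_provider = g3_providers::EmbeddedProvider::new(
        embedded_config.model_path.clone(),
        embedded_config.model_type.clone(),
        embedded_config.context_length, // Option<u32>
        embedded_config.max_tokens,     // Option<u32>
        embedded_config.temperature,    // Option<f32>
        embedded_config.gpu_layers,     // Option<u32>
        embedded_config.threads,        // Option<u32>
    )?;
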
@@ -736,12 +736,17 @@ The tool will execute immediately and you'll receive the result (success or error
         // Update context window with estimated token usage
         self.context_window.update_usage(&mock_usage);

-        // Add assistant response to context window
-        let assistant_message = Message {
-            role: MessageRole::Assistant,
-            content: response_content.clone(),
-        };
-        self.context_window.add_message(assistant_message);
+        // Add assistant response to context window only if not empty
+        // This prevents the "Skipping empty message" warning when only tools were executed
+        if !response_content.trim().is_empty() {
+            let assistant_message = Message {
+                role: MessageRole::Assistant,
+                content: response_content.clone(),
+            };
+            self.context_window.add_message(assistant_message);
+        } else {
+            debug!("Assistant response was empty (likely only tool execution), skipping message addition");
+        }

         // Save context window at the end of successful interaction
         self.save_context_window("completed");
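
The new guard is a plain trim-and-check on the accumulated response. A minimal standalone sketch with a quick test, using a hypothetical helper name that does not appear in the diff:

    // Hypothetical extraction of the guard above.
    fn should_record_assistant_message(response_content: &str) -> bool {
        !response_content.trim().is_empty()
    }

    #[test]
    fn tool_only_turns_add_no_message() {
        assert!(!should_record_assistant_message("  \n\t"));
        assert!(should_record_assistant_message("Here is the summary."));
    }
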
@@ -2450,10 +2455,6 @@ fn fix_mixed_quotes_in_json(json_str: &str) -> String {
     result
 }

-pub mod providers {
-    pub mod embedded;
-}
-
 #[cfg(test)]
 mod tests {
     use super::parse_unified_diff_hunks;
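
With the local providers module gone, downstream code reaches the provider through the shared crate instead. The replacement import, as implied by the call-site hunk above:

    use g3_providers::EmbeddedProvider;
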
@@ -1,767 +0,0 @@
use anyhow::Result;
use g3_providers::{
    CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
    MessageRole, Usage,
};
use llama_cpp::{
    standard_sampler::{SamplerStage, StandardSampler},
    LlamaModel, LlamaParams, LlamaSession, SessionParams,
};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error, info};

pub struct EmbeddedProvider {
    session: Arc<Mutex<LlamaSession>>,
    model_name: String,
    max_tokens: u32,
    temperature: f32,
    context_length: u32,
}

impl EmbeddedProvider {
    pub fn new(
        model_path: String,
        model_type: String,
        context_length: Option<u32>,
        max_tokens: Option<u32>,
        temperature: Option<f32>,
        gpu_layers: Option<u32>,
        threads: Option<u32>,
    ) -> Result<Self> {
        info!("Loading embedded model from: {}", model_path);

        // Expand tilde in path
        let expanded_path = shellexpand::tilde(&model_path);
        let model_path_buf = PathBuf::from(expanded_path.as_ref());

        // If model doesn't exist and it's the default Qwen model, offer to download it
        if !model_path_buf.exists() {
            if model_path.contains("qwen2.5-7b-instruct-q3_k_m.gguf") {
                info!("Model file not found. Attempting to download Qwen 2.5 7B model...");
                Self::download_qwen_model(&model_path_buf)?;
            } else {
                anyhow::bail!("Model file not found: {}", model_path_buf.display());
            }
        }

        let model_path = model_path_buf.as_path();

        // Set up model parameters
        let mut params = LlamaParams::default();

        if let Some(gpu_layers) = gpu_layers {
            params.n_gpu_layers = gpu_layers;
            info!("Using {} GPU layers", gpu_layers);
        }

        let context_size = context_length.unwrap_or(4096);
        info!("Using context length: {}", context_size);

        // Load the model
        info!("Loading model...");
        let model = LlamaModel::load_from_file(model_path, params)
            .map_err(|e| anyhow::anyhow!("Failed to load model: {}", e))?;

        // Create session with parameters
        let mut session_params = SessionParams::default();
        session_params.n_ctx = context_size;
        if let Some(threads) = threads {
            session_params.n_threads = threads;
        }

        let session = model
            .create_session(session_params)
            .map_err(|e| anyhow::anyhow!("Failed to create session: {}", e))?;

        info!("Successfully loaded {} model", model_type);

        Ok(Self {
            session: Arc::new(Mutex::new(session)),
            model_name: format!("embedded-{}", model_type),
            max_tokens: max_tokens.unwrap_or(2048),
            temperature: temperature.unwrap_or(0.1),
            context_length: context_size,
        })
    }

    fn format_messages(&self, messages: &[Message]) -> String {
        // Determine the appropriate format based on model type
        let model_name_lower = self.model_name.to_lowercase();

        if model_name_lower.contains("qwen") {
            // Qwen format: <|im_start|>role\ncontent<|im_end|>
            let mut formatted = String::new();

            for message in messages {
                let role = match message.role {
                    MessageRole::System => "system",
                    MessageRole::User => "user",
                    MessageRole::Assistant => "assistant",
                };

                formatted.push_str(&format!(
                    "<|im_start|>{}\n{}<|im_end|>\n",
                    role, message.content
                ));
            }

            // Add the start of assistant response
            formatted.push_str("<|im_start|>assistant\n");
            formatted
        } else if model_name_lower.contains("mistral") {
            // Mistral Instruct format: <s>[INST] ... [/INST] assistant_response</s>
            let mut formatted = String::new();
            let mut in_conversation = false;

            for (i, message) in messages.iter().enumerate() {
                match message.role {
                    MessageRole::System => {
                        // Mistral doesn't have a special system token, include it at the start
                        if i == 0 {
                            formatted.push_str("<s>[INST] ");
                            formatted.push_str(&message.content);
                            formatted.push_str("\n\n");
                            in_conversation = true;
                        }
                    }
                    MessageRole::User => {
                        if !in_conversation {
                            formatted.push_str("<s>[INST] ");
                        }
                        formatted.push_str(&message.content);
                        formatted.push_str(" [/INST]");
                        in_conversation = false;
                    }
                    MessageRole::Assistant => {
                        formatted.push_str(" ");
                        formatted.push_str(&message.content);
                        formatted.push_str("</s> ");
                        in_conversation = false;
                    }
                }
            }

            // If the last message was from user, add a space for the assistant's response
            if messages.last().map_or(false, |m| matches!(m.role, MessageRole::User)) {
                formatted.push_str(" ");
            }

            formatted
        } else {
            // Use Llama/CodeLlama format for other models
            let mut formatted = String::new();

            for message in messages {
                match message.role {
                    MessageRole::System => {
                        formatted.push_str(&format!(
                            "[INST] <<SYS>>\n{}\n<</SYS>>\n\n",
                            message.content
                        ));
                    }
                    MessageRole::User => {
                        formatted.push_str(&format!("{} [/INST] ", message.content));
                    }
                    MessageRole::Assistant => {
                        formatted.push_str(&format!("{} </s><s>[INST] ", message.content));
                    }
                }
            }
            formatted
        }
    }
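
    // Worked example (added for illustration; not part of the removed file):
    // for messages [system: "You are helpful.", user: "Hi"], the Qwen branch
    // above produces the ChatML prompt
    //
    //     <|im_start|>system
    //     You are helpful.<|im_end|>
    //     <|im_start|>user
    //     Hi<|im_end|>
    //     <|im_start|>assistant
    //
    // with generation continuing after the trailing "assistant" header line.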

    async fn generate_completion(
        &self,
        prompt: &str,
        max_tokens: u32,
        temperature: f32,
    ) -> Result<String> {
        let session = self.session.clone();
        let prompt = prompt.to_string();

        // Calculate dynamic max tokens based on available context headroom
        let prompt_tokens = self.estimate_tokens(&prompt);
        let available_tokens = self
            .context_length
            .saturating_sub(prompt_tokens)
            .saturating_sub(50); // Reserve 50 tokens for safety
        let dynamic_max_tokens = std::cmp::min(max_tokens as usize, available_tokens as usize);

        debug!("Context calculation: prompt_tokens={}, context_length={}, available_tokens={}, dynamic_max_tokens={}",
            prompt_tokens, self.context_length, available_tokens, dynamic_max_tokens);

        // Get stop sequences before entering the closure
        let stop_sequences = self.get_stop_sequences();

        // Add timeout to the entire operation
        let timeout_duration = std::time::Duration::from_secs(30); // Increased timeout for larger contexts

        let result = tokio::time::timeout(
            timeout_duration,
            tokio::task::spawn_blocking(move || {
                // Retry logic for acquiring the session lock
                let mut session_guard = None;
                for attempt in 0..5 {
                    match session.try_lock() {
                        Ok(ctx) => {
                            session_guard = Some(ctx);
                            break;
                        }
                        Err(_) => {
                            if attempt < 4 {
                                debug!("Session busy, retrying in {}ms (attempt {}/5)", 100 * (attempt + 1), attempt + 1);
                                std::thread::sleep(std::time::Duration::from_millis(100 * (attempt + 1) as u64));
                            } else {
                                return Err(anyhow::anyhow!("Model is busy after 5 attempts, please try again"));
                            }
                        }
                    }
                }

                let mut session = session_guard.ok_or_else(|| anyhow::anyhow!("Failed to acquire session lock"))?;

                debug!(
                    "Starting inference with prompt length: {} chars, estimated {} tokens",
                    prompt.len(),
                    prompt_tokens
                );

                // Set context to the prompt
                debug!("About to call set_context...");
                session
                    .set_context(&prompt)
                    .map_err(|e| anyhow::anyhow!("Failed to set context: {}", e))?;
                debug!("set_context completed successfully");

                // Create sampler with temperature
                debug!("Creating sampler...");
                let stages = vec![
                    SamplerStage::Temperature(temperature),
                    SamplerStage::TopK(40),
                    SamplerStage::TopP(0.9),
                ];
                let sampler = StandardSampler::new_softmax(stages, 1);
                debug!("Sampler created successfully");

                // Start completion with dynamic max tokens
                debug!(
                    "About to call start_completing_with with {} max tokens...",
                    dynamic_max_tokens
                );
                let mut completion_handle = session
                    .start_completing_with(sampler, dynamic_max_tokens)
                    .map_err(|e| anyhow::anyhow!("Failed to start completion: {}", e))?;
                debug!("start_completing_with completed successfully");

                let mut generated_text = String::new();
                let mut token_count = 0;
                let start_time = std::time::Instant::now();

                debug!("Starting token generation loop...");

                // Generate tokens with dynamic limits
                while let Some(token) = completion_handle.next_token() {
                    // Check for timeout on each token
                    if start_time.elapsed() > std::time::Duration::from_secs(25) {
                        debug!("Token generation timeout after {} tokens", token_count);
                        break;
                    }

                    let token_string = session.model().token_to_piece(token);
                    generated_text.push_str(&token_string);
                    token_count += 1;

                    if token_count <= 10 || token_count % 50 == 0 {
                        debug!("Generated token {}: '{}'", token_count, token_string);
                    }

                    // Use dynamic token limit
                    if token_count >= dynamic_max_tokens {
                        debug!("Reached dynamic token limit: {}", dynamic_max_tokens);
                        break;
                    }

                    // Stop on completion markers
                    let mut hit_stop = false;
                    for stop_seq in &stop_sequences {
                        if generated_text.contains(stop_seq) {
                            debug!("Hit stop sequence '{}' at {} tokens", stop_seq, token_count);
                            hit_stop = true;
                            break;
                        }
                    }

                    if hit_stop {
                        break;
                    }
                }

                debug!(
                    "Token generation loop completed. Generated {} tokens in {:?}",
                    token_count,
                    start_time.elapsed()
                );

                Ok((generated_text, token_count))
            }),
        )
        .await;

        match result {
            Ok(inner_result) => match inner_result {
                Ok(task_result) => match task_result {
                    Ok((text, token_count)) => {
                        info!(
                            "Completed generation: {} tokens (dynamic limit was {})",
                            token_count, dynamic_max_tokens
                        );
                        // Clean stop sequences from the generated text after the closure
                        Ok(self.clean_stop_sequences(&text))
                    }
                    Err(e) => Err(e),
                },
                Err(e) => Err(e.into()),
            },
            Err(_) => {
                error!("Generation timed out after 30 seconds");
                Err(anyhow::anyhow!("Generation timed out"))
            }
        }
    }

    // Helper function to estimate token count from text
    fn estimate_tokens(&self, text: &str) -> u32 {
        // Rough estimation: average 4 characters per token
        // This is conservative - actual tokenization might be different
        (text.len() as f32 / 4.0).ceil() as u32
    }
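
    // Worked example (added for illustration; not part of the removed file):
    // a 1,000-character prompt estimates to ceil(1000 / 4.0) = 250 tokens, so
    // with the default 4096-token context, generate_completion budgets
    // 4096 - 250 - 50 (safety reserve) = 3796 tokens for output.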

    // Helper function to get stop sequences based on model type
    fn get_stop_sequences(&self) -> Vec<&'static str> {
        // Determine model type from model_name
        let model_name_lower = self.model_name.to_lowercase();

        if model_name_lower.contains("qwen") {
            vec![
                "<|im_end|>",    // Qwen ChatML format end token
                "<|endoftext|>", // Alternative end token
                "</s>",          // Generic end of sequence
                "<|im_start|>",  // Start of new message (shouldn't appear in response)
            ]
        } else if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama") {
            vec![
                "</s>",     // End of sequence
                "[/INST]",  // End of instruction
                "<</SYS>>", // End of system message
                "[INST]",   // Start of new instruction (shouldn't appear in response)
                "<<SYS>>",  // Start of system (shouldn't appear in response)
            ]
        } else if model_name_lower.contains("llama") {
            vec![
                "</s>",           // End of sequence
                "[/INST]",        // End of instruction
                "<</SYS>>",       // End of system message
                "### Human:",     // Conversation format
                "### Assistant:", // Conversation format
                "[INST]",         // Start of new instruction
            ]
        } else if model_name_lower.contains("mistral") {
            vec![
                "</s>",       // End of sequence
                "[/INST]",    // End of instruction
                "<|im_end|>", // ChatML format
            ]
        } else if model_name_lower.contains("vicuna") || model_name_lower.contains("wizard") {
            vec![
                "### Human:",     // Conversation format
                "### Assistant:", // Conversation format
                "USER:",          // Alternative format
                "ASSISTANT:",     // Alternative format
                "</s>",           // End of sequence
            ]
        } else if model_name_lower.contains("alpaca") {
            vec![
                "### Instruction:", // Alpaca format
                "### Response:",    // Alpaca format
                "### Input:",       // Alpaca format
                "</s>",             // End of sequence
            ]
        } else {
            // Generic/unknown model - use common stop sequences
            vec![
                "</s>",           // Most common end sequence
                "<|endoftext|>",  // GPT-style
                "<|im_end|>",     // ChatML
                "### Human:",     // Common conversation format
                "### Assistant:", // Common conversation format
                "[/INST]",        // Instruction format
                "<</SYS>>",       // System format
            ]
        }
    }

    // Helper function to clean up stop sequences from generated text
    fn clean_stop_sequences(&self, text: &str) -> String {
        let mut cleaned = text.to_string();
        let stop_sequences = self.get_stop_sequences();

        for stop_seq in &stop_sequences {
            if let Some(pos) = cleaned.find(stop_seq) {
                cleaned.truncate(pos);
                break; // Only remove the first occurrence to avoid over-truncation
            }
        }

        cleaned.trim().to_string()
    }

    // Download the Qwen 2.5 7B model if it doesn't exist
    fn download_qwen_model(model_path: &Path) -> Result<()> {
        use std::fs;
        use std::process::Command;

        const MODEL_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-7B-Instruct-GGUF/resolve/main/qwen2.5-7b-instruct-q3_k_m.gguf";
        const MODEL_SIZE_MB: u64 = 3631; // Approximate size in MB

        // Create the parent directory if it doesn't exist
        if let Some(parent) = model_path.parent() {
            fs::create_dir_all(parent)?;
        }

        info!("Downloading Qwen 2.5 7B model (Q3_K_M quantization, ~3.5GB)...");
        info!("This is a one-time download that may take several minutes depending on your connection.");
        info!("Downloading to: {}", model_path.display());

        // Use curl with progress bar for download
        let output = Command::new("curl")
            .args(&[
                "-L", // Follow redirects
                "-#", // Show progress bar
                "-f", // Fail on HTTP errors
                "-o", model_path.to_str().unwrap(),
                MODEL_URL,
            ])
            .output()?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);

            // If curl is not available, provide alternative instructions
            if stderr.contains("command not found") || stderr.contains("not found") {
                error!("curl is not installed. Please install curl or manually download the model.");
                error!("Manual download instructions:");
                error!("1. Download from: {}", MODEL_URL);
                error!("2. Save to: {}", model_path.display());
                anyhow::bail!("curl not found - please install curl or download the model manually");
            }

            anyhow::bail!("Failed to download model: {}", stderr);
        }

        // Verify the file was created and has reasonable size
        let metadata = fs::metadata(model_path)?;
        let size_mb = metadata.len() / (1024 * 1024);

        if size_mb < MODEL_SIZE_MB - 100 { // Allow some variance
            fs::remove_file(model_path).ok(); // Clean up partial download
            anyhow::bail!(
                "Downloaded file appears incomplete ({}MB vs expected ~{}MB). Please try again.",
                size_mb, MODEL_SIZE_MB
            );
        }

        info!("Successfully downloaded Qwen 2.5 7B model ({}MB)", size_mb);
        Ok(())
    }
}

#[async_trait::async_trait]
impl LLMProvider for EmbeddedProvider {
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        debug!(
            "Processing completion request with {} messages",
            request.messages.len()
        );

        let prompt = self.format_messages(&request.messages);
        let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
        let temperature = request.temperature.unwrap_or(self.temperature);

        debug!("Formatted prompt length: {} chars", prompt.len());

        let content = self
            .generate_completion(&prompt, max_tokens, temperature)
            .await?;

        // Estimate token usage (rough approximation)
        let prompt_tokens = (prompt.len() / 4) as u32; // Rough estimate: 4 chars per token
        let completion_tokens = (content.len() / 4) as u32;

        Ok(CompletionResponse {
            content,
            usage: Usage {
                prompt_tokens,
                completion_tokens,
                total_tokens: prompt_tokens + completion_tokens,
            },
            model: self.model_name.clone(),
        })
    }

    async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
        debug!(
            "Processing streaming request with {} messages",
            request.messages.len()
        );

        let prompt = self.format_messages(&request.messages);
        let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
        let temperature = request.temperature.unwrap_or(self.temperature);

        let (tx, rx) = mpsc::channel(100);
        let session = self.session.clone();
        let prompt = prompt.to_string();

        // Spawn streaming task
        tokio::task::spawn_blocking(move || {
            // Retry logic for acquiring the session lock
            let mut session_guard = None;
            for attempt in 0..5 {
                match session.try_lock() {
                    Ok(ctx) => {
                        session_guard = Some(ctx);
                        break;
                    }
                    Err(_) => {
                        if attempt < 4 {
                            debug!("Session busy, retrying in {}ms (attempt {}/5)", 100 * (attempt + 1), attempt + 1);
                            std::thread::sleep(std::time::Duration::from_millis(100 * (attempt + 1) as u64));
                        } else {
                            let _ = tx.blocking_send(Err(anyhow::anyhow!("Model is busy after 5 attempts, please try again")));
                            return;
                        }
                    }
                }
            }

            let mut session = match session_guard {
                Some(ctx) => ctx,
                None => {
                    let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to acquire session lock")));
                    return;
                }
            };

            // Set context to the prompt
            if let Err(e) = session.set_context(&prompt) {
                let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to set context: {}", e)));
                return;
            }

            // Create sampler with temperature
            let stages = vec![
                SamplerStage::Temperature(temperature),
                SamplerStage::TopK(40),
                SamplerStage::TopP(0.9),
            ];
            let sampler = StandardSampler::new_softmax(stages, 1);

            // Start completion
            let mut completion_handle = match session
                .start_completing_with(sampler, max_tokens as usize)
            {
                Ok(handle) => handle,
                Err(e) => {
                    let _ =
                        tx.blocking_send(Err(anyhow::anyhow!("Failed to start completion: {}", e)));
                    return;
                }
            };

            let mut accumulated_text = String::new();
            let mut token_count = 0;
            let mut unsent_tokens = String::new(); // Buffer for tokens we're holding back

            // Get stop sequences dynamically based on model type
            let stop_sequences = if prompt.contains("<|im_start|>") {
                // Qwen ChatML format detected
                vec!["<|im_end|>", "<|endoftext|>", "</s>", "<|im_start|>"]
            } else if prompt.contains("[INST]") || prompt.contains("<<SYS>>") {
                // Llama/CodeLlama format detected
                vec!["</s>", "[/INST]", "<</SYS>>", "[INST]", "<<SYS>>", "### Human:", "### Assistant:"]
            } else {
                // Generic format
                vec!["</s>", "<|endoftext|>", "<|im_end|>", "### Human:", "### Assistant:", "[/INST]", "<</SYS>>"]
            };

            // Stream tokens with proper limits
            while let Some(token) = completion_handle.next_token() {
                let token_string = session.model().token_to_piece(token);

                accumulated_text.push_str(&token_string);
                unsent_tokens.push_str(&token_string);
                token_count += 1;

                // Check if we've hit a complete stop sequence
                let mut hit_stop = false;
                for stop_seq in &stop_sequences {
                    if accumulated_text.contains(stop_seq) {
                        debug!("Hit complete stop sequence in streaming: {}", stop_seq);
                        hit_stop = true;
                        break;
                    }
                }

                if hit_stop {
                    // Before stopping, check if there might be an incomplete tool call
                    // Look for JSON tool call patterns that might be cut off by the stop sequence
                    let has_potential_tool_call = accumulated_text.contains(r#"{"tool":"#) ||
                        accumulated_text.contains(r#"{"{""tool"":"#) ||
                        accumulated_text.contains(r#"{{""tool"":"#);

                    if has_potential_tool_call {
                        // Check if the tool call appears to be complete (has closing brace after the stop sequence)
                        let mut complete_tool_call = false;
                        for stop_seq in &stop_sequences {
                            if let Some(stop_pos) = accumulated_text.find(stop_seq) {
                                // Look for tool call pattern before the stop sequence
                                let before_stop = &accumulated_text[..stop_pos];
                                if let Some(tool_start) = before_stop.rfind(r#"{"tool":"#) {
                                    let tool_part = &before_stop[tool_start..];
                                    // Count braces to see if JSON is complete
                                    let open_braces = tool_part.matches('{').count();
                                    let close_braces = tool_part.matches('}').count();
                                    if open_braces > 0 && open_braces == close_braces {
                                        complete_tool_call = true;
                                        break;
                                    }
                                }
                            }
                        }

                        // If tool call is incomplete, send the raw content including stop sequences
                        // so the main parser can handle it properly
                        if !complete_tool_call {
                            debug!("Found incomplete tool call, sending raw content with stop sequences");
                            let already_sent_len = accumulated_text.len() - unsent_tokens.len();
                            if accumulated_text.len() > already_sent_len {
                                let remaining_to_send = &accumulated_text[already_sent_len..];
                                if !remaining_to_send.is_empty() {
                                    let chunk = CompletionChunk {
                                        content: remaining_to_send.to_string(),
                                        finished: false,
                                        tool_calls: None,
                                    };
                                    let _ = tx.blocking_send(Ok(chunk));
                                }
                            }
                            break;
                        }
                    }

                    // Send any remaining clean content before stopping (original behavior)
                    let mut clean_accumulated = accumulated_text.clone();
                    for stop_seq in &stop_sequences {
                        if let Some(pos) = clean_accumulated.find(stop_seq) {
                            clean_accumulated.truncate(pos);
                            break;
                        }
                    }

                    // Calculate what part we haven't sent yet
                    let already_sent_len = accumulated_text.len() - unsent_tokens.len();
                    if clean_accumulated.len() > already_sent_len {
                        let remaining_to_send = &clean_accumulated[already_sent_len..];
                        if !remaining_to_send.is_empty() {
                            let chunk = CompletionChunk {
                                content: remaining_to_send.to_string(),
                                finished: false,
                                tool_calls: None,
                            };
                            let _ = tx.blocking_send(Ok(chunk));
                        }
                    }
                    break;
                }

                // Check if we're building towards a stop sequence
                let mut might_be_stop = false;
                for stop_seq in &stop_sequences {
                    for i in 1..stop_seq.len() {
                        let partial = &stop_seq[..i];
                        if accumulated_text.ends_with(partial) {
                            debug!("Detected potential partial stop sequence: '{}'", partial);
                            might_be_stop = true;
                            break;
                        }
                    }
                    if might_be_stop {
                        break;
                    }
                }

                if might_be_stop {
                    // Hold back tokens, but only for a limited buffer size
                    if unsent_tokens.len() > 20 { // Don't hold back more than 20 characters
                        // Send the oldest part and keep only the recent part that might be a stop sequence
                        let to_send = &unsent_tokens[..unsent_tokens.len() - 10];
                        if !to_send.is_empty() {
                            let chunk = CompletionChunk {
                                content: to_send.to_string(),
                                finished: false,
                                tool_calls: None,
                            };
                            if tx.blocking_send(Ok(chunk)).is_err() {
                                break;
                            }
                        }
                        unsent_tokens = unsent_tokens[unsent_tokens.len() - 10..].to_string();
                    }
                    // Continue to next token without sending
                } else {
                    // No potential stop sequence, send all unsent tokens
                    if !unsent_tokens.is_empty() {
                        let chunk = CompletionChunk {
                            content: unsent_tokens.clone(),
                            finished: false,
                            tool_calls: None,
                        };
                        if tx.blocking_send(Ok(chunk)).is_err() {
                            break;
                        }
                        unsent_tokens.clear();
                    }
                }

                // Enforce token limit
                if token_count >= max_tokens as usize {
                    debug!("Reached max token limit in streaming: {}", max_tokens);
                    break;
                }
            }
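
            // Worked example of the holdback above (added for illustration; not
            // part of the removed file): with stop sequence "<|im_end|>", output
            // ending in "...Hello<|im" matches a partial prefix, so "<|im" stays
            // in unsent_tokens instead of being sent. If later tokens complete
            // "<|im_end|>", the hit_stop branch truncates it away; if they do
            // not, the buffer is flushed on the next non-matching iteration.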

            // Send final chunk
            let final_chunk = CompletionChunk {
                content: String::new(),
                finished: true,
                tool_calls: None,
            };
            let _ = tx.blocking_send(Ok(final_chunk));
        });

        Ok(ReceiverStream::new(rx))
    }

    fn name(&self) -> &str {
        "embedded"
    }

    fn model(&self) -> &str {
        &self.model_name
    }
}
@@ -1 +0,0 @@