add context window monitor
Writes the current context window to logs/current_context_window (a symlink to a per-session-ID file). This PR was unfortunately generated by a different LLM and did a ton of superficial reformatting; the actual change is fairly small and benign, but I don't want to roll back everything. Hope that's OK.
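For reviewers, a minimal sketch of the monitor's write path (hypothetical: the monitor itself is not shown in the diff below, and the file layout, function name, and session-ID plumbing are assumptions based on the description above):

```rust
// Hypothetical sketch of the context window monitor described above.
// Assumes a session ID string and an already-serialized context window.
use std::fs;
use std::path::Path;

fn write_context_window(session_id: &str, context: &str) -> std::io::Result<()> {
    let logs = Path::new("logs");
    fs::create_dir_all(logs)?;

    // Write the context window to a per-session file.
    let session_file = logs.join(format!("context_window_{session_id}"));
    fs::write(&session_file, context)?;

    // Repoint logs/current_context_window at the session file.
    let link = logs.join("current_context_window");
    let _ = fs::remove_file(&link); // ignore "not found" on first run
    #[cfg(unix)]
    std::os::unix::fs::symlink(session_file.file_name().unwrap(), &link)?;

    Ok(())
}
```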
@@ -1,8 +1,8 @@
+use anyhow::Result;
 use crate::{
     CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
     MessageRole, Usage,
 };
-use anyhow::Result;
 use llama_cpp::{
     standard_sampler::{SamplerStage, StandardSampler},
     LlamaModel, LlamaParams, LlamaSession, SessionParams,
@@ -37,7 +37,7 @@ impl EmbeddedProvider {
         // Expand tilde in path
         let expanded_path = shellexpand::tilde(&model_path);
         let model_path_buf = PathBuf::from(expanded_path.as_ref());
-
+
         // If model doesn't exist and it's the default Qwen model, offer to download it
         if !model_path_buf.exists() {
             if model_path.contains("qwen2.5-7b-instruct-q3_k_m.gguf") {
@@ -47,7 +47,7 @@ impl EmbeddedProvider {
                 anyhow::bail!("Model file not found: {}", model_path_buf.display());
             }
         }
-
+
         let model_path = model_path_buf.as_path();

         // Set up model parameters
@@ -93,24 +93,24 @@ impl EmbeddedProvider {
     fn format_messages(&self, messages: &[Message]) -> String {
         // Determine the appropriate format based on model type
         let model_name_lower = self.model_name.to_lowercase();
-
+
         if model_name_lower.contains("qwen") {
             // Qwen format: <|im_start|>role\ncontent<|im_end|>
             let mut formatted = String::new();
-
+
             for message in messages {
                 let role = match message.role {
                     MessageRole::System => "system",
-                    MessageRole::User => "user",
+                    MessageRole::User => "user",
                     MessageRole::Assistant => "assistant",
                 };
-
+
                 formatted.push_str(&format!(
                     "<|im_start|>{}\n{}<|im_end|>\n",
                     role, message.content
                 ));
             }
-
+
             // Add the start of assistant response
             formatted.push_str("<|im_start|>assistant\n");
             formatted
@@ -118,7 +118,7 @@ impl EmbeddedProvider {
             // Mistral Instruct format: <s>[INST] ... [/INST] assistant_response</s>
             let mut formatted = String::new();
             let mut in_conversation = false;
-
+
             for (i, message) in messages.iter().enumerate() {
                 match message.role {
                     MessageRole::System => {
@@ -146,12 +146,15 @@ impl EmbeddedProvider {
                     }
                 }
             }
-
+
             // If the last message was from user, add a space for the assistant's response
-            if messages.last().is_some_and(|m| matches!(m.role, MessageRole::User)) {
+            if messages
+                .last()
+                .is_some_and(|m| matches!(m.role, MessageRole::User))
+            {
                 formatted.push(' ');
             }
-
+
             formatted
         } else {
             // Use Llama/CodeLlama format for other models
@@ -216,16 +219,25 @@ impl EmbeddedProvider {
                 }
                 Err(_) => {
                     if attempt < 4 {
-                        debug!("Session busy, retrying in {}ms (attempt {}/5)", 100 * (attempt + 1), attempt + 1);
-                        std::thread::sleep(std::time::Duration::from_millis(100 * (attempt + 1) as u64));
+                        debug!(
+                            "Session busy, retrying in {}ms (attempt {}/5)",
+                            100 * (attempt + 1),
+                            attempt + 1
+                        );
+                        std::thread::sleep(std::time::Duration::from_millis(
+                            100 * (attempt + 1) as u64,
+                        ));
                     } else {
-                        return Err(anyhow::anyhow!("Model is busy after 5 attempts, please try again"));
+                        return Err(anyhow::anyhow!(
+                            "Model is busy after 5 attempts, please try again"
+                        ));
                     }
                 }
             }
         }

-        let mut session = session_guard.ok_or_else(|| anyhow::anyhow!("Failed to acquire session lock"))?;
-
+        let mut session = session_guard
+            .ok_or_else(|| anyhow::anyhow!("Failed to acquire session lock"))?;
+
         debug!(
             "Starting inference with prompt length: {} chars, estimated {} tokens",
@@ -297,7 +309,7 @@ impl EmbeddedProvider {
                         break;
                     }
                 }
-
+
                 if hit_stop {
                     break;
                 }
@@ -308,7 +320,7 @@ impl EmbeddedProvider {
                     token_count,
                     start_time.elapsed()
                 );
-
+
                 Ok((generated_text, token_count))
             }),
         )
@@ -347,21 +359,22 @@ impl EmbeddedProvider {
     fn get_stop_sequences(&self) -> Vec<&'static str> {
         // Determine model type from model_name
         let model_name_lower = self.model_name.to_lowercase();
-
+
         if model_name_lower.contains("qwen") {
             vec![
-                "<|im_end|>",     // Qwen ChatML format end token
-                "<|endoftext|>",  // Alternative end token
-                "</s>",           // Generic end of sequence
-                "<|im_start|>",   // Start of new message (shouldn't appear in response)
+                "<|im_end|>",    // Qwen ChatML format end token
+                "<|endoftext|>", // Alternative end token
+                "</s>",          // Generic end of sequence
+                "<|im_start|>",  // Start of new message (shouldn't appear in response)
             ]
-        } else if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama") {
+        } else if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama")
+        {
             vec![
-                "</s>",      // End of sequence
-                "[/INST]",   // End of instruction
-                "<</SYS>>",  // End of system message
-                "[INST]",    // Start of new instruction (shouldn't appear in response)
-                "<<SYS>>",   // Start of system (shouldn't appear in response)
+                "</s>",     // End of sequence
+                "[/INST]",  // End of instruction
+                "<</SYS>>", // End of system message
+                "[INST]",   // Start of new instruction (shouldn't appear in response)
+                "<<SYS>>",  // Start of system (shouldn't appear in response)
             ]
         } else if model_name_lower.contains("llama") {
             vec![
@@ -374,9 +387,9 @@ impl EmbeddedProvider {
             ]
         } else if model_name_lower.contains("mistral") {
             vec![
-                "</s>",        // End of sequence
-                "[/INST]",     // End of instruction
-                "<|im_end|>",  // ChatML format
+                "</s>",       // End of sequence
+                "[/INST]",    // End of instruction
+                "<|im_end|>", // ChatML format
             ]
         } else if model_name_lower.contains("vicuna") || model_name_lower.contains("wizard") {
             vec![
@@ -391,7 +404,7 @@ impl EmbeddedProvider {
                 "### Instruction:", // Alpaca format
                 "### Response:",    // Alpaca format
                 "### Input:",       // Alpaca format
-                "</s>",            // End of sequence
+                "</s>",             // End of sequence
             ]
         } else {
             // Generic/unknown model - use common stop sequences
@@ -411,14 +424,14 @@ impl EmbeddedProvider {
     fn clean_stop_sequences(&self, text: &str) -> String {
         let mut cleaned = text.to_string();
         let stop_sequences = self.get_stop_sequences();
-
+
         for stop_seq in &stop_sequences {
             if let Some(pos) = cleaned.find(stop_seq) {
                 cleaned.truncate(pos);
                 break; // Only remove the first occurrence to avoid over-truncation
             }
         }
-
+
         cleaned.trim().to_string()
     }

@@ -426,57 +439,64 @@ impl EmbeddedProvider {
     fn download_qwen_model(model_path: &Path) -> Result<()> {
         use std::fs;
         use std::process::Command;
-
+
         const MODEL_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-7B-Instruct-GGUF/resolve/main/qwen2.5-7b-instruct-q3_k_m.gguf";
         const MODEL_SIZE_MB: u64 = 3631; // Approximate size in MB
-
+
         // Create the parent directory if it doesn't exist
         if let Some(parent) = model_path.parent() {
             fs::create_dir_all(parent)?;
         }
-
+
         info!("Downloading Qwen 2.5 7B model (Q3_K_M quantization, ~3.5GB)...");
         info!("This is a one-time download that may take several minutes depending on your connection.");
         info!("Downloading to: {}", model_path.display());
-
+
         // Use curl with progress bar for download
         let output = Command::new("curl")
             .args([
-                "-L",             // Follow redirects
-                "-#",             // Show progress bar
-                "-f",             // Fail on HTTP errors
-                "-o", model_path.to_str().unwrap(),
+                "-L", // Follow redirects
+                "-#", // Show progress bar
+                "-f", // Fail on HTTP errors
+                "-o",
+                model_path.to_str().unwrap(),
                 MODEL_URL,
             ])
             .output()?;
-
+
         if !output.status.success() {
             let stderr = String::from_utf8_lossy(&output.stderr);
-
+
             // If curl is not available, provide alternative instructions
             if stderr.contains("command not found") || stderr.contains("not found") {
-                error!("curl is not installed. Please install curl or manually download the model.");
+                error!(
+                    "curl is not installed. Please install curl or manually download the model."
+                );
                 error!("Manual download instructions:");
                 error!("1. Download from: {}", MODEL_URL);
                 error!("2. Save to: {}", model_path.display());
-                anyhow::bail!("curl not found - please install curl or download the model manually");
+                anyhow::bail!(
+                    "curl not found - please install curl or download the model manually"
+                );
             }
-
+
             anyhow::bail!("Failed to download model: {}", stderr);
         }
-
+
         // Verify the file was created and has reasonable size
         let metadata = fs::metadata(model_path)?;
         let size_mb = metadata.len() / (1024 * 1024);
-
-        if size_mb < MODEL_SIZE_MB - 100 { // Allow some variance
-            fs::remove_file(model_path).ok(); // Clean up partial download
+
+        if size_mb < MODEL_SIZE_MB - 100 {
+            // Allow some variance
+            fs::remove_file(model_path).ok(); // Clean up partial download
             anyhow::bail!(
                 "Downloaded file appears incomplete ({}MB vs expected ~{}MB). Please try again.",
-                size_mb, MODEL_SIZE_MB
+                size_mb,
+                MODEL_SIZE_MB
             );
         }
-
+
         info!("Successfully downloaded Qwen 2.5 7B model ({}MB)", size_mb);
         Ok(())
     }
@@ -541,20 +561,29 @@ impl LLMProvider for EmbeddedProvider {
                    }
                    Err(_) => {
                        if attempt < 4 {
-                            debug!("Session busy, retrying in {}ms (attempt {}/5)", 100 * (attempt + 1), attempt + 1);
-                            std::thread::sleep(std::time::Duration::from_millis(100 * (attempt + 1) as u64));
+                            debug!(
+                                "Session busy, retrying in {}ms (attempt {}/5)",
+                                100 * (attempt + 1),
+                                attempt + 1
+                            );
+                            std::thread::sleep(std::time::Duration::from_millis(
+                                100 * (attempt + 1) as u64,
+                            ));
                        } else {
-                            let _ = tx.blocking_send(Err(anyhow::anyhow!("Model is busy after 5 attempts, please try again")));
+                            let _ = tx.blocking_send(Err(anyhow::anyhow!(
+                                "Model is busy after 5 attempts, please try again"
+                            )));
                            return;
                        }
                    }
                }
            }
-
+
            let mut session = match session_guard {
                Some(ctx) => ctx,
                None => {
-                    let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to acquire session lock")));
+                    let _ =
+                        tx.blocking_send(Err(anyhow::anyhow!("Failed to acquire session lock")));
                    return;
                }
            };
@@ -588,17 +617,33 @@ impl LLMProvider for EmbeddedProvider {
            let mut accumulated_text = String::new();
            let mut token_count = 0;
            let mut unsent_tokens = String::new(); // Buffer for tokens we're holding back
-
+
            // Get stop sequences dynamically based on model type
            let stop_sequences = if prompt.contains("<|im_start|>") {
                // Qwen ChatML format detected
                vec!["<|im_end|>", "<|endoftext|>", "</s>", "<|im_start|>"]
            } else if prompt.contains("[INST]") || prompt.contains("<<SYS>>") {
                // Llama/CodeLlama format detected
-                vec!["</s>", "[/INST]", "<</SYS>>", "[INST]", "<<SYS>>", "### Human:", "### Assistant:"]
+                vec![
+                    "</s>",
+                    "[/INST]",
+                    "<</SYS>>",
+                    "[INST]",
+                    "<<SYS>>",
+                    "### Human:",
+                    "### Assistant:",
+                ]
            } else {
                // Generic format
-                vec!["</s>", "<|endoftext|>", "<|im_end|>", "### Human:", "### Assistant:", "[/INST]", "<</SYS>>"]
+                vec![
+                    "</s>",
+                    "<|endoftext|>",
+                    "<|im_end|>",
+                    "### Human:",
+                    "### Assistant:",
+                    "[/INST]",
+                    "<</SYS>>",
+                ]
            };

            // Stream tokens with proper limits
@@ -622,10 +667,10 @@ impl LLMProvider for EmbeddedProvider {
                if hit_stop {
                    // Before stopping, check if there might be an incomplete tool call
                    // Look for JSON tool call patterns that might be cut off by the stop sequence
-                    let has_potential_tool_call = accumulated_text.contains(r#"{"tool":"#) ||
-                        accumulated_text.contains(r#"{"{""tool"":"#) ||
-                        accumulated_text.contains(r#"{{""tool"":"#);
-
+                    let has_potential_tool_call = accumulated_text.contains(r#"{"tool":"#)
+                        || accumulated_text.contains(r#"{"{""tool"":"#)
+                        || accumulated_text.contains(r#"{{""tool"":"#);
+
                    if has_potential_tool_call {
                        // Check if the tool call appears to be complete (has closing brace after the stop sequence)
                        let mut complete_tool_call = false;
@@ -645,7 +690,7 @@ impl LLMProvider for EmbeddedProvider {
                            }
                        }
                    }
-
+
                    // If tool call is incomplete, send the raw content including stop sequences
                    // so the main parser can handle it properly
                    if !complete_tool_call {
@@ -666,7 +711,7 @@ impl LLMProvider for EmbeddedProvider {
                            break;
                        }
                    }
-
+
                    // Send any remaining clean content before stopping (original behavior)
                    let mut clean_accumulated = accumulated_text.clone();
                    for stop_seq in &stop_sequences {
@@ -675,7 +720,7 @@ impl LLMProvider for EmbeddedProvider {
                            break;
                        }
                    }
-
+
                    // Calculate what part we haven't sent yet
                    let already_sent_len = accumulated_text.len() - unsent_tokens.len();
                    if clean_accumulated.len() > already_sent_len {
@@ -711,7 +756,8 @@ impl LLMProvider for EmbeddedProvider {

                if might_be_stop {
                    // Hold back tokens, but only for a limited buffer size
-                    if unsent_tokens.len() > 20 { // Don't hold back more than 20 characters
+                    if unsent_tokens.len() > 20 {
+                        // Don't hold back more than 20 characters
                        // Send the oldest part and keep only the recent part that might be a stop sequence
                        let to_send = &unsent_tokens[..unsent_tokens.len() - 10];
                        if !to_send.is_empty() {
@@ -755,7 +801,7 @@ impl LLMProvider for EmbeddedProvider {
                let final_chunk = CompletionChunk {
                    content: String::new(),
                    finished: true,
-                    usage: None,      // Embedded models calculate usage differently
+                    usage: None, // Embedded models calculate usage differently
                    tool_calls: None,
                };
                let _ = tx.blocking_send(Ok(final_chunk));
@@ -771,11 +817,11 @@ impl LLMProvider for EmbeddedProvider {
     fn model(&self) -> &str {
         &self.model_name
     }
-
+
     fn max_tokens(&self) -> u32 {
         self.max_tokens
     }
-
+
     fn temperature(&self) -> f32 {
         self.temperature
     }