add context window monitor

Writes the current context window to logs/current_context_window (a symlink that points at the current session ID). A rough sketch of the idea is below.
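Minimal sketch only, under assumptions: the function name, the per-session file naming, and the output format here are all illustrative, not the actual implementation in this PR.

use std::fs;
use std::path::Path;

// Sketch: record context-window usage in a per-session file and keep
// logs/current_context_window as a symlink pointing at the active session's file.
// Names and output format are illustrative assumptions, not the real code.
fn write_context_window(session_id: &str, used_tokens: usize, max_tokens: usize) -> std::io::Result<()> {
    let log_dir = Path::new("logs");
    fs::create_dir_all(log_dir)?;

    // One file per session, named after the session ID.
    let session_file = log_dir.join(format!("context_window_{session_id}"));
    fs::write(&session_file, format!("{used_tokens}/{max_tokens} tokens\n"))?;

    // Replace the stable symlink so `cat logs/current_context_window` always
    // shows the current session (Unix-only in this sketch).
    let link = log_dir.join("current_context_window");
    if fs::symlink_metadata(&link).is_ok() {
        fs::remove_file(&link)?;
    }
    std::os::unix::fs::symlink(session_file.file_name().unwrap(), &link)?;

    Ok(())
}

The point of the symlink is presumably a stable path: external tools can just watch logs/current_context_window while the per-session target underneath it changes.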

This PR was unfortunately generated by a different LLM, which did a ton of superficial reformatting. The actual change is fairly small and benign, but I don't want to roll back everything. Hope that's OK.
Jochen
2025-11-27 21:00:02 +11:00
parent 93dc4acf86
commit 52f78653b4
89 changed files with 4040 additions and 2576 deletions

@@ -1,8 +1,8 @@
use anyhow::Result;
use crate::{
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
MessageRole, Usage,
};
use anyhow::Result;
use llama_cpp::{
standard_sampler::{SamplerStage, StandardSampler},
LlamaModel, LlamaParams, LlamaSession, SessionParams,
@@ -37,7 +37,7 @@ impl EmbeddedProvider {
// Expand tilde in path
let expanded_path = shellexpand::tilde(&model_path);
let model_path_buf = PathBuf::from(expanded_path.as_ref());
// If model doesn't exist and it's the default Qwen model, offer to download it
if !model_path_buf.exists() {
if model_path.contains("qwen2.5-7b-instruct-q3_k_m.gguf") {
@@ -47,7 +47,7 @@ impl EmbeddedProvider {
anyhow::bail!("Model file not found: {}", model_path_buf.display());
}
}
let model_path = model_path_buf.as_path();
// Set up model parameters
@@ -93,24 +93,24 @@ impl EmbeddedProvider {
fn format_messages(&self, messages: &[Message]) -> String {
// Determine the appropriate format based on model type
let model_name_lower = self.model_name.to_lowercase();
if model_name_lower.contains("qwen") {
// Qwen format: <|im_start|>role\ncontent<|im_end|>
let mut formatted = String::new();
for message in messages {
let role = match message.role {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
};
formatted.push_str(&format!(
"<|im_start|>{}\n{}<|im_end|>\n",
role, message.content
));
}
// Add the start of assistant response
formatted.push_str("<|im_start|>assistant\n");
formatted
@@ -118,7 +118,7 @@ impl EmbeddedProvider {
// Mistral Instruct format: <s>[INST] ... [/INST] assistant_response</s>
let mut formatted = String::new();
let mut in_conversation = false;
for (i, message) in messages.iter().enumerate() {
match message.role {
MessageRole::System => {
@@ -146,12 +146,15 @@ impl EmbeddedProvider {
}
}
}
// If the last message was from user, add a space for the assistant's response
if messages.last().is_some_and(|m| matches!(m.role, MessageRole::User)) {
if messages
.last()
.is_some_and(|m| matches!(m.role, MessageRole::User))
{
formatted.push(' ');
}
formatted
} else {
// Use Llama/CodeLlama format for other models
@@ -216,16 +219,25 @@ impl EmbeddedProvider {
}
Err(_) => {
if attempt < 4 {
debug!("Session busy, retrying in {}ms (attempt {}/5)", 100 * (attempt + 1), attempt + 1);
std::thread::sleep(std::time::Duration::from_millis(100 * (attempt + 1) as u64));
debug!(
"Session busy, retrying in {}ms (attempt {}/5)",
100 * (attempt + 1),
attempt + 1
);
std::thread::sleep(std::time::Duration::from_millis(
100 * (attempt + 1) as u64,
));
} else {
return Err(anyhow::anyhow!("Model is busy after 5 attempts, please try again"));
return Err(anyhow::anyhow!(
"Model is busy after 5 attempts, please try again"
));
}
}
}
}
let mut session = session_guard.ok_or_else(|| anyhow::anyhow!("Failed to acquire session lock"))?;
let mut session = session_guard
.ok_or_else(|| anyhow::anyhow!("Failed to acquire session lock"))?;
debug!(
"Starting inference with prompt length: {} chars, estimated {} tokens",
@@ -297,7 +309,7 @@ impl EmbeddedProvider {
break;
}
}
if hit_stop {
break;
}
@@ -308,7 +320,7 @@ impl EmbeddedProvider {
token_count,
start_time.elapsed()
);
Ok((generated_text, token_count))
}),
)
@@ -347,21 +359,22 @@ impl EmbeddedProvider {
fn get_stop_sequences(&self) -> Vec<&'static str> {
// Determine model type from model_name
let model_name_lower = self.model_name.to_lowercase();
if model_name_lower.contains("qwen") {
vec![
"<|im_end|>", // Qwen ChatML format end token
"<|endoftext|>", // Alternative end token
"</s>", // Generic end of sequence
"<|im_start|>", // Start of new message (shouldn't appear in response)
"<|im_end|>", // Qwen ChatML format end token
"<|endoftext|>", // Alternative end token
"</s>", // Generic end of sequence
"<|im_start|>", // Start of new message (shouldn't appear in response)
]
} else if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama") {
} else if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama")
{
vec![
"</s>", // End of sequence
"[/INST]", // End of instruction
"<</SYS>>", // End of system message
"[INST]", // Start of new instruction (shouldn't appear in response)
"<<SYS>>", // Start of system (shouldn't appear in response)
"</s>", // End of sequence
"[/INST]", // End of instruction
"<</SYS>>", // End of system message
"[INST]", // Start of new instruction (shouldn't appear in response)
"<<SYS>>", // Start of system (shouldn't appear in response)
]
} else if model_name_lower.contains("llama") {
vec![
@@ -374,9 +387,9 @@ impl EmbeddedProvider {
]
} else if model_name_lower.contains("mistral") {
vec![
"</s>", // End of sequence
"[/INST]", // End of instruction
"<|im_end|>", // ChatML format
"</s>", // End of sequence
"[/INST]", // End of instruction
"<|im_end|>", // ChatML format
]
} else if model_name_lower.contains("vicuna") || model_name_lower.contains("wizard") {
vec![
@@ -391,7 +404,7 @@ impl EmbeddedProvider {
"### Instruction:", // Alpaca format
"### Response:", // Alpaca format
"### Input:", // Alpaca format
"</s>", // End of sequence
"</s>", // End of sequence
]
} else {
// Generic/unknown model - use common stop sequences
@@ -411,14 +424,14 @@ impl EmbeddedProvider {
fn clean_stop_sequences(&self, text: &str) -> String {
let mut cleaned = text.to_string();
let stop_sequences = self.get_stop_sequences();
for stop_seq in &stop_sequences {
if let Some(pos) = cleaned.find(stop_seq) {
cleaned.truncate(pos);
break; // Only remove the first occurrence to avoid over-truncation
}
}
cleaned.trim().to_string()
}
@@ -426,57 +439,64 @@ impl EmbeddedProvider {
fn download_qwen_model(model_path: &Path) -> Result<()> {
use std::fs;
use std::process::Command;
const MODEL_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-7B-Instruct-GGUF/resolve/main/qwen2.5-7b-instruct-q3_k_m.gguf";
const MODEL_SIZE_MB: u64 = 3631; // Approximate size in MB
// Create the parent directory if it doesn't exist
if let Some(parent) = model_path.parent() {
fs::create_dir_all(parent)?;
}
info!("Downloading Qwen 2.5 7B model (Q3_K_M quantization, ~3.5GB)...");
info!("This is a one-time download that may take several minutes depending on your connection.");
info!("Downloading to: {}", model_path.display());
// Use curl with progress bar for download
let output = Command::new("curl")
.args([
"-L", // Follow redirects
"-#", // Show progress bar
"-f", // Fail on HTTP errors
"-o", model_path.to_str().unwrap(),
"-L", // Follow redirects
"-#", // Show progress bar
"-f", // Fail on HTTP errors
"-o",
model_path.to_str().unwrap(),
MODEL_URL,
])
.output()?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
// If curl is not available, provide alternative instructions
if stderr.contains("command not found") || stderr.contains("not found") {
error!("curl is not installed. Please install curl or manually download the model.");
error!(
"curl is not installed. Please install curl or manually download the model."
);
error!("Manual download instructions:");
error!("1. Download from: {}", MODEL_URL);
error!("2. Save to: {}", model_path.display());
anyhow::bail!("curl not found - please install curl or download the model manually");
anyhow::bail!(
"curl not found - please install curl or download the model manually"
);
}
anyhow::bail!("Failed to download model: {}", stderr);
}
// Verify the file was created and has reasonable size
let metadata = fs::metadata(model_path)?;
let size_mb = metadata.len() / (1024 * 1024);
if size_mb < MODEL_SIZE_MB - 100 { // Allow some variance
fs::remove_file(model_path).ok(); // Clean up partial download
if size_mb < MODEL_SIZE_MB - 100 {
// Allow some variance
fs::remove_file(model_path).ok(); // Clean up partial download
anyhow::bail!(
"Downloaded file appears incomplete ({}MB vs expected ~{}MB). Please try again.",
size_mb, MODEL_SIZE_MB
size_mb,
MODEL_SIZE_MB
);
}
info!("Successfully downloaded Qwen 2.5 7B model ({}MB)", size_mb);
Ok(())
}
@@ -541,20 +561,29 @@ impl LLMProvider for EmbeddedProvider {
}
Err(_) => {
if attempt < 4 {
debug!("Session busy, retrying in {}ms (attempt {}/5)", 100 * (attempt + 1), attempt + 1);
std::thread::sleep(std::time::Duration::from_millis(100 * (attempt + 1) as u64));
debug!(
"Session busy, retrying in {}ms (attempt {}/5)",
100 * (attempt + 1),
attempt + 1
);
std::thread::sleep(std::time::Duration::from_millis(
100 * (attempt + 1) as u64,
));
} else {
let _ = tx.blocking_send(Err(anyhow::anyhow!("Model is busy after 5 attempts, please try again")));
let _ = tx.blocking_send(Err(anyhow::anyhow!(
"Model is busy after 5 attempts, please try again"
)));
return;
}
}
}
}
let mut session = match session_guard {
Some(ctx) => ctx,
None => {
let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to acquire session lock")));
let _ =
tx.blocking_send(Err(anyhow::anyhow!("Failed to acquire session lock")));
return;
}
};
@@ -588,17 +617,33 @@ impl LLMProvider for EmbeddedProvider {
let mut accumulated_text = String::new();
let mut token_count = 0;
let mut unsent_tokens = String::new(); // Buffer for tokens we're holding back
// Get stop sequences dynamically based on model type
let stop_sequences = if prompt.contains("<|im_start|>") {
// Qwen ChatML format detected
vec!["<|im_end|>", "<|endoftext|>", "</s>", "<|im_start|>"]
} else if prompt.contains("[INST]") || prompt.contains("<<SYS>>") {
// Llama/CodeLlama format detected
vec!["</s>", "[/INST]", "<</SYS>>", "[INST]", "<<SYS>>", "### Human:", "### Assistant:"]
vec![
"</s>",
"[/INST]",
"<</SYS>>",
"[INST]",
"<<SYS>>",
"### Human:",
"### Assistant:",
]
} else {
// Generic format
vec!["</s>", "<|endoftext|>", "<|im_end|>", "### Human:", "### Assistant:", "[/INST]", "<</SYS>>"]
vec![
"</s>",
"<|endoftext|>",
"<|im_end|>",
"### Human:",
"### Assistant:",
"[/INST]",
"<</SYS>>",
]
};
// Stream tokens with proper limits
@@ -622,10 +667,10 @@ impl LLMProvider for EmbeddedProvider {
if hit_stop {
// Before stopping, check if there might be an incomplete tool call
// Look for JSON tool call patterns that might be cut off by the stop sequence
let has_potential_tool_call = accumulated_text.contains(r#"{"tool":"#) ||
accumulated_text.contains(r#"{"{""tool"":"#) ||
accumulated_text.contains(r#"{{""tool"":"#);
let has_potential_tool_call = accumulated_text.contains(r#"{"tool":"#)
|| accumulated_text.contains(r#"{"{""tool"":"#)
|| accumulated_text.contains(r#"{{""tool"":"#);
if has_potential_tool_call {
// Check if the tool call appears to be complete (has closing brace after the stop sequence)
let mut complete_tool_call = false;
@@ -645,7 +690,7 @@ impl LLMProvider for EmbeddedProvider {
}
}
}
// If tool call is incomplete, send the raw content including stop sequences
// so the main parser can handle it properly
if !complete_tool_call {
@@ -666,7 +711,7 @@ impl LLMProvider for EmbeddedProvider {
break;
}
}
// Send any remaining clean content before stopping (original behavior)
let mut clean_accumulated = accumulated_text.clone();
for stop_seq in &stop_sequences {
@@ -675,7 +720,7 @@ impl LLMProvider for EmbeddedProvider {
break;
}
}
// Calculate what part we haven't sent yet
let already_sent_len = accumulated_text.len() - unsent_tokens.len();
if clean_accumulated.len() > already_sent_len {
@@ -711,7 +756,8 @@ impl LLMProvider for EmbeddedProvider {
if might_be_stop {
// Hold back tokens, but only for a limited buffer size
if unsent_tokens.len() > 20 { // Don't hold back more than 20 characters
if unsent_tokens.len() > 20 {
// Don't hold back more than 20 characters
// Send the oldest part and keep only the recent part that might be a stop sequence
let to_send = &unsent_tokens[..unsent_tokens.len() - 10];
if !to_send.is_empty() {
@@ -755,7 +801,7 @@ impl LLMProvider for EmbeddedProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: None, // Embedded models calculate usage differently
usage: None, // Embedded models calculate usage differently
tool_calls: None,
};
let _ = tx.blocking_send(Ok(final_chunk));
@@ -771,11 +817,11 @@ impl LLMProvider for EmbeddedProvider {
fn model(&self) -> &str {
&self.model_name
}
fn max_tokens(&self) -> u32 {
self.max_tokens
}
fn temperature(&self) -> f32 {
self.temperature
}