End tool call
This commit is contained in:
@@ -86,27 +86,51 @@ impl EmbeddedProvider {
|
||||
}
|
||||
|
||||
fn format_messages(&self, messages: &[Message]) -> String {
|
||||
// Use proper prompt format for CodeLlama
|
||||
let mut formatted = String::new();
|
||||
// Determine the appropriate format based on model type
|
||||
let model_name_lower = self.model_name.to_lowercase();
|
||||
|
||||
if model_name_lower.contains("qwen") {
|
||||
// Qwen format: <|im_start|>role\ncontent<|im_end|>
|
||||
let mut formatted = String::new();
|
||||
|
||||
for message in messages {
|
||||
let role = match message.role {
|
||||
MessageRole::System => "system",
|
||||
MessageRole::User => "user",
|
||||
MessageRole::Assistant => "assistant",
|
||||
};
|
||||
|
||||
formatted.push_str(&format!(
|
||||
"<|im_start|>{}\n{}<|im_end|>\n",
|
||||
role, message.content
|
||||
));
|
||||
}
|
||||
|
||||
// Add the start of assistant response
|
||||
formatted.push_str("<|im_start|>assistant\n");
|
||||
formatted
|
||||
} else {
|
||||
// Use Llama/CodeLlama format for other models
|
||||
let mut formatted = String::new();
|
||||
|
||||
for message in messages {
|
||||
match message.role {
|
||||
MessageRole::System => {
|
||||
formatted.push_str(&format!(
|
||||
"[INST] <<SYS>>\n{}\n<</SYS>>\n\n",
|
||||
message.content
|
||||
));
|
||||
}
|
||||
MessageRole::User => {
|
||||
formatted.push_str(&format!("{} [/INST] ", message.content));
|
||||
}
|
||||
MessageRole::Assistant => {
|
||||
formatted.push_str(&format!("{} </s><s>[INST] ", message.content));
|
||||
for message in messages {
|
||||
match message.role {
|
||||
MessageRole::System => {
|
||||
formatted.push_str(&format!(
|
||||
"[INST] <<SYS>>\n{}\n<</SYS>>\n\n",
|
||||
message.content
|
||||
));
|
||||
}
|
||||
MessageRole::User => {
|
||||
formatted.push_str(&format!("{} [/INST] ", message.content));
|
||||
}
|
||||
MessageRole::Assistant => {
|
||||
formatted.push_str(&format!("{} </s><s>[INST] ", message.content));
|
||||
}
|
||||
}
|
||||
}
|
||||
formatted
|
||||
}
|
||||
|
||||
formatted
|
||||
}
|
||||
|
||||
async fn generate_completion(
|
||||
@@ -138,10 +162,26 @@ impl EmbeddedProvider {
|
||||
let result = tokio::time::timeout(
|
||||
timeout_duration,
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let mut session = match session.try_lock() {
|
||||
Ok(ctx) => ctx,
|
||||
Err(_) => return Err(anyhow::anyhow!("Model is busy, please try again")),
|
||||
};
|
||||
// Retry logic for acquiring the session lock
|
||||
let mut session_guard = None;
|
||||
for attempt in 0..5 {
|
||||
match session.try_lock() {
|
||||
Ok(ctx) => {
|
||||
session_guard = Some(ctx);
|
||||
break;
|
||||
}
|
||||
Err(_) => {
|
||||
if attempt < 4 {
|
||||
debug!("Session busy, retrying in {}ms (attempt {}/5)", 100 * (attempt + 1), attempt + 1);
|
||||
std::thread::sleep(std::time::Duration::from_millis(100 * (attempt + 1) as u64));
|
||||
} else {
|
||||
return Err(anyhow::anyhow!("Model is busy after 5 attempts, please try again"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut session = session_guard.ok_or_else(|| anyhow::anyhow!("Failed to acquire session lock"))?;
|
||||
|
||||
debug!(
|
||||
"Starting inference with prompt length: {} chars, estimated {} tokens",
|
||||
@@ -264,7 +304,14 @@ impl EmbeddedProvider {
|
||||
// Determine model type from model_name
|
||||
let model_name_lower = self.model_name.to_lowercase();
|
||||
|
||||
if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama") {
|
||||
if model_name_lower.contains("qwen") {
|
||||
vec![
|
||||
"<|im_end|>", // Qwen ChatML format end token
|
||||
"<|endoftext|>", // Alternative end token
|
||||
"</s>", // Generic end of sequence
|
||||
"<|im_start|>", // Start of new message (shouldn't appear in response)
|
||||
]
|
||||
} else if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama") {
|
||||
vec![
|
||||
"</s>", // End of sequence
|
||||
"[/INST]", // End of instruction
|
||||
@@ -381,11 +428,30 @@ impl LLMProvider for EmbeddedProvider {
|
||||
|
||||
// Spawn streaming task
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let mut session = match session.try_lock() {
|
||||
Ok(ctx) => ctx,
|
||||
Err(_) => {
|
||||
let _ =
|
||||
tx.blocking_send(Err(anyhow::anyhow!("Model is busy, please try again")));
|
||||
// Retry logic for acquiring the session lock
|
||||
let mut session_guard = None;
|
||||
for attempt in 0..5 {
|
||||
match session.try_lock() {
|
||||
Ok(ctx) => {
|
||||
session_guard = Some(ctx);
|
||||
break;
|
||||
}
|
||||
Err(_) => {
|
||||
if attempt < 4 {
|
||||
debug!("Session busy, retrying in {}ms (attempt {}/5)", 100 * (attempt + 1), attempt + 1);
|
||||
std::thread::sleep(std::time::Duration::from_millis(100 * (attempt + 1) as u64));
|
||||
} else {
|
||||
let _ = tx.blocking_send(Err(anyhow::anyhow!("Model is busy after 5 attempts, please try again")));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut session = match session_guard {
|
||||
Some(ctx) => ctx,
|
||||
None => {
|
||||
let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to acquire session lock")));
|
||||
return;
|
||||
}
|
||||
};
|
||||
@@ -418,11 +484,13 @@ impl LLMProvider for EmbeddedProvider {
|
||||
|
||||
let mut accumulated_text = String::new();
|
||||
let mut token_count = 0;
|
||||
let mut unsent_tokens = String::new(); // Buffer for tokens we're holding back
|
||||
|
||||
// Get stop sequences dynamically based on model type
|
||||
// We need to create a temporary EmbeddedProvider instance to access the method
|
||||
// Since we can't access self in the spawned task, we'll use a static approach
|
||||
let stop_sequences = if prompt.contains("[INST]") || prompt.contains("<<SYS>>") {
|
||||
let stop_sequences = if prompt.contains("<|im_start|>") {
|
||||
// Qwen ChatML format detected
|
||||
vec!["<|im_end|>", "<|endoftext|>", "</s>", "<|im_start|>"]
|
||||
} else if prompt.contains("[INST]") || prompt.contains("<<SYS>>") {
|
||||
// Llama/CodeLlama format detected
|
||||
vec!["</s>", "[/INST]", "<</SYS>>", "[INST]", "<<SYS>>", "### Human:", "### Assistant:"]
|
||||
} else {
|
||||
@@ -435,21 +503,21 @@ impl LLMProvider for EmbeddedProvider {
|
||||
let token_string = session.model().token_to_piece(token);
|
||||
|
||||
accumulated_text.push_str(&token_string);
|
||||
unsent_tokens.push_str(&token_string);
|
||||
token_count += 1;
|
||||
|
||||
// Check if we've hit a stop sequence
|
||||
// Check if we've hit a complete stop sequence
|
||||
let mut hit_stop = false;
|
||||
for stop_seq in &stop_sequences {
|
||||
if accumulated_text.contains(stop_seq) {
|
||||
debug!("Hit stop sequence in streaming: {}", stop_seq);
|
||||
debug!("Hit complete stop sequence in streaming: {}", stop_seq);
|
||||
hit_stop = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if hit_stop {
|
||||
// Don't send the token that contains the stop sequence
|
||||
// Instead, send only the part before the stop sequence
|
||||
// Send any remaining clean content before stopping
|
||||
let mut clean_accumulated = accumulated_text.clone();
|
||||
for stop_seq in &stop_sequences {
|
||||
if let Some(pos) = clean_accumulated.find(stop_seq) {
|
||||
@@ -459,7 +527,7 @@ impl LLMProvider for EmbeddedProvider {
|
||||
}
|
||||
|
||||
// Calculate what part we haven't sent yet
|
||||
let already_sent_len = accumulated_text.len() - token_string.len();
|
||||
let already_sent_len = accumulated_text.len() - unsent_tokens.len();
|
||||
if clean_accumulated.len() > already_sent_len {
|
||||
let remaining_to_send = &clean_accumulated[already_sent_len..];
|
||||
if !remaining_to_send.is_empty() {
|
||||
@@ -472,16 +540,54 @@ impl LLMProvider for EmbeddedProvider {
|
||||
}
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
// Normal token, send it
|
||||
let chunk = CompletionChunk {
|
||||
content: token_string.clone(),
|
||||
finished: false,
|
||||
tool_calls: None,
|
||||
};
|
||||
}
|
||||
|
||||
if tx.blocking_send(Ok(chunk)).is_err() {
|
||||
break; // Receiver dropped
|
||||
// Check if we're building towards a stop sequence
|
||||
let mut might_be_stop = false;
|
||||
for stop_seq in &stop_sequences {
|
||||
for i in 1..stop_seq.len() {
|
||||
let partial = &stop_seq[..i];
|
||||
if accumulated_text.ends_with(partial) {
|
||||
debug!("Detected potential partial stop sequence: '{}'", partial);
|
||||
might_be_stop = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if might_be_stop {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if might_be_stop {
|
||||
// Hold back tokens, but only for a limited buffer size
|
||||
if unsent_tokens.len() > 20 { // Don't hold back more than 20 characters
|
||||
// Send the oldest part and keep only the recent part that might be a stop sequence
|
||||
let to_send = &unsent_tokens[..unsent_tokens.len() - 10];
|
||||
if !to_send.is_empty() {
|
||||
let chunk = CompletionChunk {
|
||||
content: to_send.to_string(),
|
||||
finished: false,
|
||||
tool_calls: None,
|
||||
};
|
||||
if tx.blocking_send(Ok(chunk)).is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
unsent_tokens = unsent_tokens[unsent_tokens.len() - 10..].to_string();
|
||||
}
|
||||
// Continue to next token without sending
|
||||
} else {
|
||||
// No potential stop sequence, send all unsent tokens
|
||||
if !unsent_tokens.is_empty() {
|
||||
let chunk = CompletionChunk {
|
||||
content: unsent_tokens.clone(),
|
||||
finished: false,
|
||||
tool_calls: None,
|
||||
};
|
||||
if tx.blocking_send(Ok(chunk)).is_err() {
|
||||
break;
|
||||
}
|
||||
unsent_tokens.clear();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user