Tool calling support for Anthropic
@@ -8,12 +8,12 @@ use llama_cpp::{
     LlamaModel, LlamaParams, LlamaSession, SessionParams,
 };
 use std::path::Path;
-use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
 use tokio::sync::mpsc;
 use tokio::sync::Mutex;
 use tokio_stream::wrappers::ReceiverStream;
-use tracing::{debug, error, info, warn};
+use tracing::{debug, error, info};
 
 pub struct EmbeddedProvider {
     model: Arc<LlamaModel>,
@@ -129,6 +129,9 @@ impl EmbeddedProvider {
         debug!("Context calculation: prompt_tokens={}, context_length={}, available_tokens={}, dynamic_max_tokens={}",
             prompt_tokens, self.context_length, available_tokens, dynamic_max_tokens);
 
+        // Get stop sequences before entering the closure
+        let stop_sequences = self.get_stop_sequences();
+
         // Add timeout to the entire operation
         let timeout_duration = std::time::Duration::from_secs(30); // Increased timeout for larger contexts
 
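The stop list is resolved before the generation closure because the closure runs detached from &self; the streaming path later in this diff states the same constraint explicitly. A hypothetical sketch of the pattern, assuming tokio's rt-multi-thread and macros features (none of these names come from the diff):

    use tokio::task;

    // Hypothetical helper, not from the diff: the stop list is resolved while
    // `&self` is still in scope, then moved into the 'static closure.
    async fn run_with_stops(stop_sequences: Vec<&'static str>) -> usize {
        task::spawn_blocking(move || stop_sequences.len())
            .await
            .expect("blocking task panicked")
    }

    #[tokio::main]
    async fn main() {
        assert_eq!(run_with_stops(vec!["</s>", "[/INST]"]).await, 2);
    }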
@@ -202,8 +205,16 @@ impl EmbeddedProvider {
             }
 
-            // Stop on completion markers
-            if generated_text.contains("</s>") || generated_text.contains("[/INST]") {
-                debug!("Hit CodeLlama stop sequence at {} tokens", token_count);
+            let mut hit_stop = false;
+            for stop_seq in &stop_sequences {
+                if generated_text.contains(stop_seq) {
+                    debug!("Hit stop sequence '{}' at {} tokens", stop_seq, token_count);
+                    hit_stop = true;
+                    break;
+                }
+            }
+
+            if hit_stop {
                 break;
             }
         }
@@ -213,7 +224,8 @@ impl EmbeddedProvider {
                 token_count,
                 start_time.elapsed()
             );
-            Ok((generated_text.trim().to_string(), token_count))
+
+            Ok((generated_text, token_count))
         }),
     )
     .await;
@@ -226,7 +238,8 @@ impl EmbeddedProvider {
                 "Completed generation: {} tokens (dynamic limit was {})",
                 token_count, dynamic_max_tokens
             );
-            Ok(text)
+            // Clean stop sequences from the generated text after the closure
+            Ok(self.clean_stop_sequences(&text))
         }
         Err(e) => Err(e),
     },
@@ -245,6 +258,78 @@ impl EmbeddedProvider {
         // This is conservative - actual tokenization might be different
         (text.len() as f32 / 4.0).ceil() as u32
     }
+
+    // Helper function to get stop sequences based on model type
+    fn get_stop_sequences(&self) -> Vec<&'static str> {
+        // Determine model type from model_name
+        let model_name_lower = self.model_name.to_lowercase();
+
+        if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama") {
+            vec![
+                "</s>", // End of sequence
+                "[/INST]", // End of instruction
+                "<</SYS>>", // End of system message
+                "[INST]", // Start of new instruction (shouldn't appear in response)
+                "<<SYS>>", // Start of system (shouldn't appear in response)
+            ]
+        } else if model_name_lower.contains("llama") {
+            vec![
+                "</s>", // End of sequence
+                "[/INST]", // End of instruction
+                "<</SYS>>", // End of system message
+                "### Human:", // Conversation format
+                "### Assistant:", // Conversation format
+                "[INST]", // Start of new instruction
+            ]
+        } else if model_name_lower.contains("mistral") {
+            vec![
+                "</s>", // End of sequence
+                "[/INST]", // End of instruction
+                "<|im_end|>", // ChatML format
+            ]
+        } else if model_name_lower.contains("vicuna") || model_name_lower.contains("wizard") {
+            vec![
+                "### Human:", // Conversation format
+                "### Assistant:", // Conversation format
+                "USER:", // Alternative format
+                "ASSISTANT:", // Alternative format
+                "</s>", // End of sequence
+            ]
+        } else if model_name_lower.contains("alpaca") {
+            vec![
+                "### Instruction:", // Alpaca format
+                "### Response:", // Alpaca format
+                "### Input:", // Alpaca format
+                "</s>", // End of sequence
+            ]
+        } else {
+            // Generic/unknown model - use common stop sequences
+            vec![
+                "</s>", // Most common end sequence
+                "<|endoftext|>", // GPT-style
+                "<|im_end|>", // ChatML
+                "### Human:", // Common conversation format
+                "### Assistant:", // Common conversation format
+                "[/INST]", // Instruction format
+                "<</SYS>>", // System format
+            ]
+        }
+    }
+
+    // Helper function to clean up stop sequences from generated text
+    fn clean_stop_sequences(&self, text: &str) -> String {
+        let mut cleaned = text.to_string();
+        let stop_sequences = self.get_stop_sequences();
+
+        for stop_seq in &stop_sequences {
+            if let Some(pos) = cleaned.find(stop_seq) {
+                cleaned.truncate(pos);
+                break; // Only remove the first occurrence to avoid over-truncation
+            }
+        }
+
+        cleaned.trim().to_string()
+    }
 }
 
 #[async_trait::async_trait]
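Note the truncation semantics of clean_stop_sequences as added above: the loop breaks after the first list entry that matches, so the cut happens at that entry's position, and a marker that occurs earlier in the text but later in the list is left alone. A minimal standalone sketch of the behavior (the free function and hardcoded stop list are illustrative stand-ins for the method and the per-model lookup):

    // Standalone sketch of the truncation logic above.
    fn clean_stop_sequences(text: &str, stop_sequences: &[&str]) -> String {
        let mut cleaned = text.to_string();
        for stop_seq in stop_sequences {
            if let Some(pos) = cleaned.find(stop_seq) {
                cleaned.truncate(pos);
                break; // only the first matching list entry truncates
            }
        }
        cleaned.trim().to_string()
    }

    fn main() {
        let raw = "fn add(a: i32, b: i32) -> i32 { a + b }</s>[INST] next question";
        let stops = ["</s>", "[/INST]", "[INST]"];
        // Everything from the first matching stop sequence onward is dropped:
        assert_eq!(
            clean_stop_sequences(raw, &stops),
            "fn add(a: i32, b: i32) -> i32 { a + b }"
        );
    }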
@@ -333,6 +418,17 @@ impl LLMProvider for EmbeddedProvider {
 
             let mut accumulated_text = String::new();
             let mut token_count = 0;
 
+            // Get stop sequences dynamically based on model type
+            // We need to create a temporary EmbeddedProvider instance to access the method
+            // Since we can't access self in the spawned task, we'll use a static approach
+            let stop_sequences = if prompt.contains("[INST]") || prompt.contains("<<SYS>>") {
+                // Llama/CodeLlama format detected
+                vec!["</s>", "[/INST]", "<</SYS>>", "[INST]", "<<SYS>>", "### Human:", "### Assistant:"]
+            } else {
+                // Generic format
+                vec!["</s>", "<|endoftext|>", "<|im_end|>", "### Human:", "### Assistant:", "[/INST]", "<</SYS>>"]
+            };
+
             // Stream tokens with proper limits
             while let Some(token) = completion_handle.next_token() {
@@ -341,13 +437,52 @@ impl LLMProvider for EmbeddedProvider {
                 accumulated_text.push_str(&token_string);
                 token_count += 1;
 
-                let chunk = CompletionChunk {
-                    content: token_string.clone(),
-                    finished: false,
-                };
+                // Check if we've hit a stop sequence
+                let mut hit_stop = false;
+                for stop_seq in &stop_sequences {
+                    if accumulated_text.contains(stop_seq) {
+                        debug!("Hit stop sequence in streaming: {}", stop_seq);
+                        hit_stop = true;
+                        break;
+                    }
+                }
 
-                if tx.blocking_send(Ok(chunk)).is_err() {
-                    break; // Receiver dropped
+                if hit_stop {
+                    // Don't send the token that contains the stop sequence
+                    // Instead, send only the part before the stop sequence
+                    let mut clean_accumulated = accumulated_text.clone();
+                    for stop_seq in &stop_sequences {
+                        if let Some(pos) = clean_accumulated.find(stop_seq) {
+                            clean_accumulated.truncate(pos);
+                            break;
+                        }
+                    }
+
+                    // Calculate what part we haven't sent yet
+                    let already_sent_len = accumulated_text.len() - token_string.len();
+                    if clean_accumulated.len() > already_sent_len {
+                        let remaining_to_send = &clean_accumulated[already_sent_len..];
+                        if !remaining_to_send.is_empty() {
+                            let chunk = CompletionChunk {
+                                content: remaining_to_send.to_string(),
+                                finished: false,
+                                tool_calls: None,
+                            };
+                            let _ = tx.blocking_send(Ok(chunk));
+                        }
+                    }
+                    break;
+                } else {
+                    // Normal token, send it
+                    let chunk = CompletionChunk {
+                        content: token_string.clone(),
+                        finished: false,
+                        tool_calls: None,
+                    };
+
+                    if tx.blocking_send(Ok(chunk)).is_err() {
+                        break; // Receiver dropped
+                    }
                 }
 
                 // Enforce token limit
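The slice arithmetic in the hit_stop branch above handles a stop sequence that only completes on the current token: everything before this token has already been streamed and cannot be recalled, so the code truncates the full accumulation at the stop marker and emits just the unsent clean remainder. A standalone walkthrough with concrete values (the byte indices stay on char boundaries because both lengths are measured on the same UTF-8 string at token edges):

    // Standalone sketch of the "send only the clean remainder" arithmetic.
    // The stream has already emitted "Hello wor"; the final token "ld</s>"
    // completes the stop sequence.
    fn main() {
        let token_string = "ld</s>";
        let accumulated_text = String::from("Hello world</s>");
        let stop_seq = "</s>";

        // Truncate the full accumulation at the stop sequence...
        let mut clean_accumulated = accumulated_text.clone();
        if let Some(pos) = clean_accumulated.find(stop_seq) {
            clean_accumulated.truncate(pos); // "Hello world"
        }

        // ...then skip the bytes that were already streamed (everything
        // accumulated before this token arrived).
        let already_sent_len = accumulated_text.len() - token_string.len(); // 9
        let remaining_to_send = &clean_accumulated[already_sent_len..];
        assert_eq!(remaining_to_send, "ld"); // the "</s>" suffix is withheld
    }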
@@ -355,22 +490,13 @@ impl LLMProvider for EmbeddedProvider {
                     debug!("Reached max token limit in streaming: {}", max_tokens);
                     break;
                 }
-
-                // Stop if we hit common stop sequences
-                if accumulated_text.contains("### Human")
-                    || accumulated_text.contains("### System")
-                    || accumulated_text.contains("<|end|>")
-                    || accumulated_text.contains("</s>")
-                {
-                    debug!("Hit stop sequence in streaming, stopping generation");
-                    break;
-                }
             }
 
             // Send final chunk
             let final_chunk = CompletionChunk {
                 content: String::new(),
                 finished: true,
+                tool_calls: None,
             };
             let _ = tx.blocking_send(Ok(final_chunk));
         });
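Every chunk the embedded provider emits sets tool_calls: None; the field itself is what this commit threads through so that providers which do produce tool calls (per the commit title, Anthropic) can populate it. A hedged sketch of the chunk shape implied by the diff: only the three field names appear above, and the ToolCall layout is an assumption modeled loosely on Anthropic-style tool-use blocks, not taken from this codebase.

    // Assumed shapes, for illustration only.
    #[derive(Debug, Clone)]
    pub struct ToolCall {
        pub id: String,        // assumed: provider-issued call id
        pub name: String,      // assumed: tool name requested by the model
        pub arguments: String, // assumed: JSON-encoded arguments
    }

    #[derive(Debug, Clone)]
    pub struct CompletionChunk {
        pub content: String,
        pub finished: bool,
        pub tool_calls: Option<Vec<ToolCall>>, // always None for embedded models, as in this diff
    }

    fn main() {
        // The embedded provider's final chunk, as sent in this diff:
        let final_chunk = CompletionChunk {
            content: String::new(),
            finished: true,
            tool_calls: None,
        };
        assert!(final_chunk.finished && final_chunk.tool_calls.is_none());
    }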