tool calling support for anthropic

Dhanji Prasanna
2025-09-09 14:25:39 +10:00
parent 02d95e01a0
commit fa34755851
9 changed files with 705 additions and 121 deletions


@@ -8,12 +8,12 @@ use llama_cpp::{
    LlamaModel, LlamaParams, LlamaSession, SessionParams,
};
use std::path::Path;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error, info, warn};
use tracing::{debug, error, info};
pub struct EmbeddedProvider {
    model: Arc<LlamaModel>,
@@ -129,6 +129,9 @@ impl EmbeddedProvider {
debug!("Context calculation: prompt_tokens={}, context_length={}, available_tokens={}, dynamic_max_tokens={}",
prompt_tokens, self.context_length, available_tokens, dynamic_max_tokens);
// Get stop sequences before entering the closure
let stop_sequences = self.get_stop_sequences();
// Add timeout to the entire operation
let timeout_duration = std::time::Duration::from_secs(30); // Increased timeout for larger contexts
@@ -202,8 +205,16 @@ impl EmbeddedProvider {
                    }

                    // Stop on completion markers
                    if generated_text.contains("</s>") || generated_text.contains("[/INST]") {
                        debug!("Hit CodeLlama stop sequence at {} tokens", token_count);
                    let mut hit_stop = false;
                    for stop_seq in &stop_sequences {
                        if generated_text.contains(stop_seq) {
                            debug!("Hit stop sequence '{}' at {} tokens", stop_seq, token_count);
                            hit_stop = true;
                            break;
                        }
                    }
                    if hit_stop {
                        break;
                    }
                }
@@ -213,7 +224,8 @@ impl EmbeddedProvider {
                    token_count,
                    start_time.elapsed()
                );
                Ok((generated_text.trim().to_string(), token_count))
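                // Stop-sequence stripping and trimming are handled by
                // clean_stop_sequences after the closure returns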
                Ok((generated_text, token_count))
            }),
        )
        .await;
@@ -226,7 +238,8 @@ impl EmbeddedProvider {
                    "Completed generation: {} tokens (dynamic limit was {})",
                    token_count, dynamic_max_tokens
                );
                Ok(text)
                // Clean stop sequences from the generated text after the closure
                Ok(self.clean_stop_sequences(&text))
            }
            Err(e) => Err(e),
        },
@@ -245,6 +258,78 @@ impl EmbeddedProvider {
        // This is conservative - actual tokenization might be different
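        // For example, a 400-character prompt is estimated at ceil(400 / 4.0) = 100 tokens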
        (text.len() as f32 / 4.0).ceil() as u32
    }

    // Helper function to get stop sequences based on model type
    fn get_stop_sequences(&self) -> Vec<&'static str> {
        // Determine model type from model_name
        let model_name_lower = self.model_name.to_lowercase();

        if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama") {
            vec![
                "</s>",      // End of sequence
                "[/INST]",   // End of instruction
                "<</SYS>>",  // End of system message
                "[INST]",    // Start of new instruction (shouldn't appear in response)
                "<<SYS>>",   // Start of system (shouldn't appear in response)
            ]
        } else if model_name_lower.contains("llama") {
            vec![
                "</s>",            // End of sequence
                "[/INST]",         // End of instruction
                "<</SYS>>",        // End of system message
                "### Human:",      // Conversation format
                "### Assistant:",  // Conversation format
                "[INST]",          // Start of new instruction
            ]
        } else if model_name_lower.contains("mistral") {
            vec![
                "</s>",        // End of sequence
                "[/INST]",     // End of instruction
                "<|im_end|>",  // ChatML format
            ]
        } else if model_name_lower.contains("vicuna") || model_name_lower.contains("wizard") {
            vec![
                "### Human:",      // Conversation format
                "### Assistant:",  // Conversation format
                "USER:",           // Alternative format
                "ASSISTANT:",      // Alternative format
                "</s>",            // End of sequence
            ]
        } else if model_name_lower.contains("alpaca") {
            vec![
                "### Instruction:",  // Alpaca format
                "### Response:",     // Alpaca format
                "### Input:",        // Alpaca format
                "</s>",              // End of sequence
            ]
        } else {
            // Generic/unknown model - use common stop sequences
            vec![
                "</s>",            // Most common end sequence
                "<|endoftext|>",   // GPT-style
                "<|im_end|>",      // ChatML
                "### Human:",      // Common conversation format
                "### Assistant:",  // Common conversation format
                "[/INST]",         // Instruction format
                "<</SYS>>",        // System format
            ]
        }
    }

    // Helper function to clean up stop sequences from generated text
    fn clean_stop_sequences(&self, text: &str) -> String {
        let mut cleaned = text.to_string();
        let stop_sequences = self.get_stop_sequences();

        for stop_seq in &stop_sequences {
            if let Some(pos) = cleaned.find(stop_seq) {
                cleaned.truncate(pos);
                break; // Only remove the first occurrence to avoid over-truncation
            }
        }

        cleaned.trim().to_string()
    }
}
#[async_trait::async_trait]
@@ -333,6 +418,17 @@ impl LLMProvider for EmbeddedProvider {
            let mut accumulated_text = String::new();
            let mut token_count = 0;

            // Choose stop sequences based on the prompt format rather than the model type:
            // self.get_stop_sequences() can't be reached from inside this spawned blocking
            // task, so fall back to a static selection keyed off the prompt contents
            let stop_sequences = if prompt.contains("[INST]") || prompt.contains("<<SYS>>") {
                // Llama/CodeLlama format detected
                vec!["</s>", "[/INST]", "<</SYS>>", "[INST]", "<<SYS>>", "### Human:", "### Assistant:"]
            } else {
                // Generic format
                vec!["</s>", "<|endoftext|>", "<|im_end|>", "### Human:", "### Assistant:", "[/INST]", "<</SYS>>"]
            };

            // Stream tokens with proper limits
            while let Some(token) = completion_handle.next_token() {
@@ -341,13 +437,52 @@ impl LLMProvider for EmbeddedProvider {
                accumulated_text.push_str(&token_string);
                token_count += 1;

                let chunk = CompletionChunk {
                    content: token_string.clone(),
                    finished: false,
                };
                // Check if we've hit a stop sequence
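                // (accumulated_text is scanned rather than just this token, so stop
                // sequences split across token boundaries are still caught)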
                let mut hit_stop = false;
                for stop_seq in &stop_sequences {
                    if accumulated_text.contains(stop_seq) {
                        debug!("Hit stop sequence in streaming: {}", stop_seq);
                        hit_stop = true;
                        break;
                    }
                }

                if tx.blocking_send(Ok(chunk)).is_err() {
                    break; // Receiver dropped
                if hit_stop {
                    // Don't send the token that contains the stop sequence
                    // Instead, send only the part before the stop sequence
                    let mut clean_accumulated = accumulated_text.clone();
                    for stop_seq in &stop_sequences {
                        if let Some(pos) = clean_accumulated.find(stop_seq) {
                            clean_accumulated.truncate(pos);
                            break;
                        }
                    }

                    // Calculate what part we haven't sent yet
                    let already_sent_len = accumulated_text.len() - token_string.len();
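                    // accumulated_text already includes token_string, so everything before
                    // this final token has already been streamed to the receiver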
                    if clean_accumulated.len() > already_sent_len {
                        let remaining_to_send = &clean_accumulated[already_sent_len..];
                        if !remaining_to_send.is_empty() {
                            let chunk = CompletionChunk {
                                content: remaining_to_send.to_string(),
                                finished: false,
                                tool_calls: None,
                            };
                            let _ = tx.blocking_send(Ok(chunk));
                        }
                    }
                    break;
                } else {
                    // Normal token, send it
                    let chunk = CompletionChunk {
                        content: token_string.clone(),
                        finished: false,
                        tool_calls: None,
                    };
                    if tx.blocking_send(Ok(chunk)).is_err() {
                        break; // Receiver dropped
                    }
                }

                // Enforce token limit
@@ -355,22 +490,13 @@ impl LLMProvider for EmbeddedProvider {
debug!("Reached max token limit in streaming: {}", max_tokens);
break;
}
// Stop if we hit common stop sequences
if accumulated_text.contains("### Human")
|| accumulated_text.contains("### System")
|| accumulated_text.contains("<|end|>")
|| accumulated_text.contains("</s>")
{
debug!("Hit stop sequence in streaming, stopping generation");
break;
}
}
// Send final chunk
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
tool_calls: None,
};
let _ = tx.blocking_send(Ok(final_chunk));
});