End tool call
This commit is contained in:
@@ -53,8 +53,16 @@ impl StreamingToolParser {
|
|||||||
// Look for the start of a tool call pattern: {"tool":
|
// Look for the start of a tool call pattern: {"tool":
|
||||||
if !self.in_tool_call {
|
if !self.in_tool_call {
|
||||||
// Look for JSON tool call pattern - check both raw JSON and inside code blocks
|
// Look for JSON tool call pattern - check both raw JSON and inside code blocks
|
||||||
if let Some(pos) = self.buffer.rfind(r#"{"tool":"#) {
|
// Also handle malformed patterns like {"{""tool"":
|
||||||
//info!("Found tool call pattern at position: {}", pos);
|
let patterns = [
|
||||||
|
r#"{"tool":"#, // Normal pattern
|
||||||
|
r#"{"{""tool"":"#, // Malformed pattern with extra brace and doubled quotes
|
||||||
|
r#"{{""tool"":"#, // Alternative malformed pattern
|
||||||
|
];
|
||||||
|
|
||||||
|
for pattern in &patterns {
|
||||||
|
if let Some(pos) = self.buffer.rfind(pattern) {
|
||||||
|
info!("Found tool call pattern '{}' at position: {}", pattern, pos);
|
||||||
|
|
||||||
// Check if this is inside a code block
|
// Check if this is inside a code block
|
||||||
let before_pos = &self.buffer[..pos];
|
let before_pos = &self.buffer[..pos];
|
||||||
@@ -70,6 +78,7 @@ impl StreamingToolParser {
|
|||||||
// Continue parsing from after the opening brace
|
// Continue parsing from after the opening brace
|
||||||
return self.parse_from_start_pos(pos);
|
return self.parse_from_start_pos(pos);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
//info!("Already in tool call, continuing parsing");
|
//info!("Already in tool call, continuing parsing");
|
||||||
// We're already in a tool call, continue parsing
|
// We're already in a tool call, continue parsing
|
||||||
@@ -100,10 +109,17 @@ impl StreamingToolParser {
|
|||||||
if current_brace_count == 0 {
|
if current_brace_count == 0 {
|
||||||
// Found complete JSON object
|
// Found complete JSON object
|
||||||
let end_pos = start_pos + i + 1;
|
let end_pos = start_pos + i + 1;
|
||||||
let json_str = &self.buffer[start_pos..end_pos];
|
let mut json_str = self.buffer[start_pos..end_pos].to_string();
|
||||||
|
|
||||||
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(json_str) {
|
// Clean up malformed JSON patterns
|
||||||
//info!("Successfully parsed tool call: {:?}", tool_call);
|
json_str = json_str
|
||||||
|
.replace(r#"{"{""#, r#"{"#) // Fix {"{" -> {"
|
||||||
|
.replace(r#"""}"#, r#""}"#) // Fix ""} -> "}
|
||||||
|
.replace(r#"{{""#, r#"{"#) // Fix {{" -> {"
|
||||||
|
.replace(r#"""}"#, r#""}"#); // Fix ""} -> "}
|
||||||
|
|
||||||
|
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(&json_str) {
|
||||||
|
info!("Successfully parsed tool call: {:?}", tool_call);
|
||||||
// Reset parser state
|
// Reset parser state
|
||||||
self.in_tool_call = false;
|
self.in_tool_call = false;
|
||||||
self.tool_start_pos = None;
|
self.tool_start_pos = None;
|
||||||
@@ -111,7 +127,7 @@ impl StreamingToolParser {
|
|||||||
|
|
||||||
return Some((tool_call, end_pos));
|
return Some((tool_call, end_pos));
|
||||||
} else {
|
} else {
|
||||||
info!("Failed to parse JSON: {}", json_str);
|
info!("Failed to parse JSON after cleanup: {}", json_str);
|
||||||
// Invalid JSON, reset and continue looking
|
// Invalid JSON, reset and continue looking
|
||||||
self.in_tool_call = false;
|
self.in_tool_call = false;
|
||||||
self.tool_start_pos = None;
|
self.tool_start_pos = None;
|
||||||
@@ -261,6 +277,7 @@ impl Agent {
|
|||||||
"codellama" => 16384, // CodeLlama supports 16k context
|
"codellama" => 16384, // CodeLlama supports 16k context
|
||||||
"llama" => 4096, // Base Llama models
|
"llama" => 4096, // Base Llama models
|
||||||
"mistral" => 8192, // Mistral models
|
"mistral" => 8192, // Mistral models
|
||||||
|
"qwen" => 32768, // Qwen2.5 supports 32k context
|
||||||
_ => 4096, // Conservative default
|
_ => 4096, // Conservative default
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -630,28 +647,42 @@ The tool will execute immediately and you'll receive the result to continue with
|
|||||||
// Found a complete tool call! Stop streaming and execute it
|
// Found a complete tool call! Stop streaming and execute it
|
||||||
let content_before_tool = parser.get_content_before_tool(tool_end_pos);
|
let content_before_tool = parser.get_content_before_tool(tool_end_pos);
|
||||||
|
|
||||||
// Display content up to the tool call (excluding the JSON)
|
// Display content up to the tool call (excluding the JSON and any stop tokens)
|
||||||
let display_content = if let Some(json_start) =
|
let display_content = if let Some(json_start) =
|
||||||
content_before_tool.rfind(r#"{"tool":"#)
|
content_before_tool.rfind(r#"{"tool":"#)
|
||||||
{
|
{
|
||||||
&content_before_tool[..json_start]
|
// Only show content before the JSON tool call
|
||||||
|
content_before_tool[..json_start].trim()
|
||||||
} else {
|
} else {
|
||||||
&content_before_tool
|
// Fallback: clean any stop tokens from the content
|
||||||
|
content_before_tool.trim()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Clean stop tokens from display content
|
||||||
|
let clean_display_content = display_content
|
||||||
|
.replace("<|im_end|>", "")
|
||||||
|
.replace("</s>", "")
|
||||||
|
.replace("[/INST]", "")
|
||||||
|
.replace("<</SYS>>", "");
|
||||||
|
let final_display_content = clean_display_content.trim();
|
||||||
|
|
||||||
// Safely get the new content to display
|
// Safely get the new content to display
|
||||||
let new_content = if current_response.len() <= display_content.len() {
|
let new_content = if current_response.len() <= final_display_content.len() {
|
||||||
// Use char indices to avoid UTF-8 boundary issues
|
// Use char indices to avoid UTF-8 boundary issues
|
||||||
let chars_already_shown = current_response.chars().count();
|
let chars_already_shown = current_response.chars().count();
|
||||||
display_content
|
final_display_content
|
||||||
.chars()
|
.chars()
|
||||||
.skip(chars_already_shown)
|
.skip(chars_already_shown)
|
||||||
.collect::<String>()
|
.collect::<String>()
|
||||||
} else {
|
} else {
|
||||||
String::new()
|
String::new()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Only print if there's actually new content to show
|
||||||
|
if !new_content.trim().is_empty() {
|
||||||
print!("{}", new_content);
|
print!("{}", new_content);
|
||||||
io::stdout().flush()?;
|
io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
|
||||||
// Execute the tool with formatted output
|
// Execute the tool with formatted output
|
||||||
println!(); // New line before tool execution
|
println!(); // New line before tool execution
|
||||||
@@ -724,20 +755,36 @@ The tool will execute immediately and you'll receive the result to continue with
|
|||||||
// Update the request with the new context for next iteration
|
// Update the request with the new context for next iteration
|
||||||
request.messages = self.context_window.conversation_history.clone();
|
request.messages = self.context_window.conversation_history.clone();
|
||||||
|
|
||||||
full_response.push_str(display_content);
|
full_response.push_str(final_display_content);
|
||||||
full_response.push_str(&format!(
|
full_response.push_str(&format!(
|
||||||
"\n\nTool executed: {} -> {}\n\n",
|
"\n\nTool executed: {} -> {}\n\n",
|
||||||
tool_call.tool, tool_result
|
tool_call.tool, tool_result
|
||||||
));
|
));
|
||||||
|
|
||||||
|
// Check if this was a final_output tool call - if so, stop the conversation
|
||||||
|
if tool_call.tool == "final_output" {
|
||||||
|
println!(); // New line after final output
|
||||||
|
let ttft = first_token_time.unwrap_or_else(|| stream_start.elapsed());
|
||||||
|
return Ok((full_response, ttft));
|
||||||
|
}
|
||||||
|
|
||||||
tool_executed = true;
|
tool_executed = true;
|
||||||
// Break out of current stream to start a new one with updated context
|
// Break out of current stream to start a new one with updated context
|
||||||
break;
|
break;
|
||||||
} else {
|
} else {
|
||||||
// No tool call detected, continue streaming normally
|
// No tool call detected, continue streaming normally
|
||||||
print!("{}", chunk.content);
|
// Filter out stop tokens from the streaming output
|
||||||
|
let clean_content = chunk.content
|
||||||
|
.replace("<|im_end|>", "")
|
||||||
|
.replace("</s>", "")
|
||||||
|
.replace("[/INST]", "")
|
||||||
|
.replace("<</SYS>>", "");
|
||||||
|
|
||||||
|
if !clean_content.is_empty() {
|
||||||
|
print!("{}", clean_content);
|
||||||
io::stdout().flush()?;
|
io::stdout().flush()?;
|
||||||
current_response.push_str(&chunk.content);
|
current_response.push_str(&clean_content);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if chunk.finished {
|
if chunk.finished {
|
||||||
|
|||||||
@@ -86,7 +86,31 @@ impl EmbeddedProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn format_messages(&self, messages: &[Message]) -> String {
|
fn format_messages(&self, messages: &[Message]) -> String {
|
||||||
// Use proper prompt format for CodeLlama
|
// Determine the appropriate format based on model type
|
||||||
|
let model_name_lower = self.model_name.to_lowercase();
|
||||||
|
|
||||||
|
if model_name_lower.contains("qwen") {
|
||||||
|
// Qwen format: <|im_start|>role\ncontent<|im_end|>
|
||||||
|
let mut formatted = String::new();
|
||||||
|
|
||||||
|
for message in messages {
|
||||||
|
let role = match message.role {
|
||||||
|
MessageRole::System => "system",
|
||||||
|
MessageRole::User => "user",
|
||||||
|
MessageRole::Assistant => "assistant",
|
||||||
|
};
|
||||||
|
|
||||||
|
formatted.push_str(&format!(
|
||||||
|
"<|im_start|>{}\n{}<|im_end|>\n",
|
||||||
|
role, message.content
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the start of assistant response
|
||||||
|
formatted.push_str("<|im_start|>assistant\n");
|
||||||
|
formatted
|
||||||
|
} else {
|
||||||
|
// Use Llama/CodeLlama format for other models
|
||||||
let mut formatted = String::new();
|
let mut formatted = String::new();
|
||||||
|
|
||||||
for message in messages {
|
for message in messages {
|
||||||
@@ -105,9 +129,9 @@ impl EmbeddedProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
formatted
|
formatted
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn generate_completion(
|
async fn generate_completion(
|
||||||
&self,
|
&self,
|
||||||
@@ -138,10 +162,26 @@ impl EmbeddedProvider {
|
|||||||
let result = tokio::time::timeout(
|
let result = tokio::time::timeout(
|
||||||
timeout_duration,
|
timeout_duration,
|
||||||
tokio::task::spawn_blocking(move || {
|
tokio::task::spawn_blocking(move || {
|
||||||
let mut session = match session.try_lock() {
|
// Retry logic for acquiring the session lock
|
||||||
Ok(ctx) => ctx,
|
let mut session_guard = None;
|
||||||
Err(_) => return Err(anyhow::anyhow!("Model is busy, please try again")),
|
for attempt in 0..5 {
|
||||||
};
|
match session.try_lock() {
|
||||||
|
Ok(ctx) => {
|
||||||
|
session_guard = Some(ctx);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
if attempt < 4 {
|
||||||
|
debug!("Session busy, retrying in {}ms (attempt {}/5)", 100 * (attempt + 1), attempt + 1);
|
||||||
|
std::thread::sleep(std::time::Duration::from_millis(100 * (attempt + 1) as u64));
|
||||||
|
} else {
|
||||||
|
return Err(anyhow::anyhow!("Model is busy after 5 attempts, please try again"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut session = session_guard.ok_or_else(|| anyhow::anyhow!("Failed to acquire session lock"))?;
|
||||||
|
|
||||||
debug!(
|
debug!(
|
||||||
"Starting inference with prompt length: {} chars, estimated {} tokens",
|
"Starting inference with prompt length: {} chars, estimated {} tokens",
|
||||||
@@ -264,7 +304,14 @@ impl EmbeddedProvider {
|
|||||||
// Determine model type from model_name
|
// Determine model type from model_name
|
||||||
let model_name_lower = self.model_name.to_lowercase();
|
let model_name_lower = self.model_name.to_lowercase();
|
||||||
|
|
||||||
if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama") {
|
if model_name_lower.contains("qwen") {
|
||||||
|
vec![
|
||||||
|
"<|im_end|>", // Qwen ChatML format end token
|
||||||
|
"<|endoftext|>", // Alternative end token
|
||||||
|
"</s>", // Generic end of sequence
|
||||||
|
"<|im_start|>", // Start of new message (shouldn't appear in response)
|
||||||
|
]
|
||||||
|
} else if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama") {
|
||||||
vec![
|
vec![
|
||||||
"</s>", // End of sequence
|
"</s>", // End of sequence
|
||||||
"[/INST]", // End of instruction
|
"[/INST]", // End of instruction
|
||||||
@@ -381,11 +428,30 @@ impl LLMProvider for EmbeddedProvider {
|
|||||||
|
|
||||||
// Spawn streaming task
|
// Spawn streaming task
|
||||||
tokio::task::spawn_blocking(move || {
|
tokio::task::spawn_blocking(move || {
|
||||||
let mut session = match session.try_lock() {
|
// Retry logic for acquiring the session lock
|
||||||
Ok(ctx) => ctx,
|
let mut session_guard = None;
|
||||||
|
for attempt in 0..5 {
|
||||||
|
match session.try_lock() {
|
||||||
|
Ok(ctx) => {
|
||||||
|
session_guard = Some(ctx);
|
||||||
|
break;
|
||||||
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
let _ =
|
if attempt < 4 {
|
||||||
tx.blocking_send(Err(anyhow::anyhow!("Model is busy, please try again")));
|
debug!("Session busy, retrying in {}ms (attempt {}/5)", 100 * (attempt + 1), attempt + 1);
|
||||||
|
std::thread::sleep(std::time::Duration::from_millis(100 * (attempt + 1) as u64));
|
||||||
|
} else {
|
||||||
|
let _ = tx.blocking_send(Err(anyhow::anyhow!("Model is busy after 5 attempts, please try again")));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut session = match session_guard {
|
||||||
|
Some(ctx) => ctx,
|
||||||
|
None => {
|
||||||
|
let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to acquire session lock")));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -418,11 +484,13 @@ impl LLMProvider for EmbeddedProvider {
|
|||||||
|
|
||||||
let mut accumulated_text = String::new();
|
let mut accumulated_text = String::new();
|
||||||
let mut token_count = 0;
|
let mut token_count = 0;
|
||||||
|
let mut unsent_tokens = String::new(); // Buffer for tokens we're holding back
|
||||||
|
|
||||||
// Get stop sequences dynamically based on model type
|
// Get stop sequences dynamically based on model type
|
||||||
// We need to create a temporary EmbeddedProvider instance to access the method
|
let stop_sequences = if prompt.contains("<|im_start|>") {
|
||||||
// Since we can't access self in the spawned task, we'll use a static approach
|
// Qwen ChatML format detected
|
||||||
let stop_sequences = if prompt.contains("[INST]") || prompt.contains("<<SYS>>") {
|
vec!["<|im_end|>", "<|endoftext|>", "</s>", "<|im_start|>"]
|
||||||
|
} else if prompt.contains("[INST]") || prompt.contains("<<SYS>>") {
|
||||||
// Llama/CodeLlama format detected
|
// Llama/CodeLlama format detected
|
||||||
vec!["</s>", "[/INST]", "<</SYS>>", "[INST]", "<<SYS>>", "### Human:", "### Assistant:"]
|
vec!["</s>", "[/INST]", "<</SYS>>", "[INST]", "<<SYS>>", "### Human:", "### Assistant:"]
|
||||||
} else {
|
} else {
|
||||||
@@ -435,21 +503,21 @@ impl LLMProvider for EmbeddedProvider {
|
|||||||
let token_string = session.model().token_to_piece(token);
|
let token_string = session.model().token_to_piece(token);
|
||||||
|
|
||||||
accumulated_text.push_str(&token_string);
|
accumulated_text.push_str(&token_string);
|
||||||
|
unsent_tokens.push_str(&token_string);
|
||||||
token_count += 1;
|
token_count += 1;
|
||||||
|
|
||||||
// Check if we've hit a stop sequence
|
// Check if we've hit a complete stop sequence
|
||||||
let mut hit_stop = false;
|
let mut hit_stop = false;
|
||||||
for stop_seq in &stop_sequences {
|
for stop_seq in &stop_sequences {
|
||||||
if accumulated_text.contains(stop_seq) {
|
if accumulated_text.contains(stop_seq) {
|
||||||
debug!("Hit stop sequence in streaming: {}", stop_seq);
|
debug!("Hit complete stop sequence in streaming: {}", stop_seq);
|
||||||
hit_stop = true;
|
hit_stop = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if hit_stop {
|
if hit_stop {
|
||||||
// Don't send the token that contains the stop sequence
|
// Send any remaining clean content before stopping
|
||||||
// Instead, send only the part before the stop sequence
|
|
||||||
let mut clean_accumulated = accumulated_text.clone();
|
let mut clean_accumulated = accumulated_text.clone();
|
||||||
for stop_seq in &stop_sequences {
|
for stop_seq in &stop_sequences {
|
||||||
if let Some(pos) = clean_accumulated.find(stop_seq) {
|
if let Some(pos) = clean_accumulated.find(stop_seq) {
|
||||||
@@ -459,7 +527,7 @@ impl LLMProvider for EmbeddedProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Calculate what part we haven't sent yet
|
// Calculate what part we haven't sent yet
|
||||||
let already_sent_len = accumulated_text.len() - token_string.len();
|
let already_sent_len = accumulated_text.len() - unsent_tokens.len();
|
||||||
if clean_accumulated.len() > already_sent_len {
|
if clean_accumulated.len() > already_sent_len {
|
||||||
let remaining_to_send = &clean_accumulated[already_sent_len..];
|
let remaining_to_send = &clean_accumulated[already_sent_len..];
|
||||||
if !remaining_to_send.is_empty() {
|
if !remaining_to_send.is_empty() {
|
||||||
@@ -472,16 +540,54 @@ impl LLMProvider for EmbeddedProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
} else {
|
}
|
||||||
// Normal token, send it
|
|
||||||
|
// Check if we're building towards a stop sequence
|
||||||
|
let mut might_be_stop = false;
|
||||||
|
for stop_seq in &stop_sequences {
|
||||||
|
for i in 1..stop_seq.len() {
|
||||||
|
let partial = &stop_seq[..i];
|
||||||
|
if accumulated_text.ends_with(partial) {
|
||||||
|
debug!("Detected potential partial stop sequence: '{}'", partial);
|
||||||
|
might_be_stop = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if might_be_stop {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if might_be_stop {
|
||||||
|
// Hold back tokens, but only for a limited buffer size
|
||||||
|
if unsent_tokens.len() > 20 { // Don't hold back more than 20 characters
|
||||||
|
// Send the oldest part and keep only the recent part that might be a stop sequence
|
||||||
|
let to_send = &unsent_tokens[..unsent_tokens.len() - 10];
|
||||||
|
if !to_send.is_empty() {
|
||||||
let chunk = CompletionChunk {
|
let chunk = CompletionChunk {
|
||||||
content: token_string.clone(),
|
content: to_send.to_string(),
|
||||||
finished: false,
|
finished: false,
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
if tx.blocking_send(Ok(chunk)).is_err() {
|
if tx.blocking_send(Ok(chunk)).is_err() {
|
||||||
break; // Receiver dropped
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
unsent_tokens = unsent_tokens[unsent_tokens.len() - 10..].to_string();
|
||||||
|
}
|
||||||
|
// Continue to next token without sending
|
||||||
|
} else {
|
||||||
|
// No potential stop sequence, send all unsent tokens
|
||||||
|
if !unsent_tokens.is_empty() {
|
||||||
|
let chunk = CompletionChunk {
|
||||||
|
content: unsent_tokens.clone(),
|
||||||
|
finished: false,
|
||||||
|
tool_calls: None,
|
||||||
|
};
|
||||||
|
if tx.blocking_send(Ok(chunk)).is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
unsent_tokens.clear();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user