Fix multiple tool call handling and improve auto-continue logic

- Add last_consumed_position tracking to StreamingToolParser to prevent
  re-detecting already-executed tool calls
- Add mark_tool_calls_consumed() method to mark tool calls as processed
- Add find_first_tool_call_start() for forward scanning of tool patterns
- Replace try_parse_json_tool_call_from_buffer() with
  try_parse_all_json_tool_calls_from_buffer() to find ALL tool calls
- Update has_incomplete_tool_call() and has_unexecuted_tool_call() to
  only check unconsumed portion of buffer
- Fix tool execution loop to not reset parser when unexecuted tools remain
- Simplify should_auto_continue logic (remove redundant condition)
- Add comprehensive tests for auto-continue condition logic
This commit is contained in:
Dhanji R. Prasanna
2025-12-22 16:08:57 +11:00
parent a755301cf9
commit 8070147a0c
2 changed files with 235 additions and 41 deletions

View File

@@ -260,8 +260,9 @@ const TOOL_CALL_PATTERNS: [&str; 4] = [
pub struct StreamingToolParser { pub struct StreamingToolParser {
/// Buffer for accumulating text content /// Buffer for accumulating text content
text_buffer: String, text_buffer: String,
/// Buffer for accumulating native tool calls /// Position in text_buffer up to which tool calls have been consumed/executed
native_tool_calls: Vec<g3_providers::ToolCall>, /// This prevents has_unexecuted_tool_call() from returning true for already-executed tools
last_consumed_position: usize,
/// Whether we've received a message_stop event /// Whether we've received a message_stop event
message_stopped: bool, message_stopped: bool,
/// Whether we're currently in a JSON tool call (for fallback parsing) /// Whether we're currently in a JSON tool call (for fallback parsing)
@@ -280,7 +281,7 @@ impl StreamingToolParser {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
text_buffer: String::new(), text_buffer: String::new(),
native_tool_calls: Vec::new(), last_consumed_position: 0,
message_stopped: false, message_stopped: false,
in_json_tool_call: false, in_json_tool_call: false,
json_tool_start: None, json_tool_start: None,
@@ -301,6 +302,20 @@ impl StreamingToolParser {
best_start best_start
} }
/// Find the starting position of the FIRST tool call pattern in the given text
/// Returns None if no tool call pattern is found
fn find_first_tool_call_start(text: &str) -> Option<usize> {
let mut best_start: Option<usize> = None;
for pattern in &TOOL_CALL_PATTERNS {
if let Some(pos) = text.find(pattern) {
if best_start.map_or(true, |best| pos < best) {
best_start = Some(pos);
}
}
}
best_start
}
/// Validate that tool call args don't contain message-like content /// Validate that tool call args don't contain message-like content
/// This detects malformed tool calls where agent messages got mixed into args /// This detects malformed tool calls where agent messages got mixed into args
fn has_message_like_keys(args: &serde_json::Map<String, serde_json::Value>) -> bool { fn has_message_like_keys(args: &serde_json::Map<String, serde_json::Value>) -> bool {
@@ -348,10 +363,12 @@ impl StreamingToolParser {
self.message_stopped = true; self.message_stopped = true;
debug!("Message finished, processing accumulated tool calls"); debug!("Message finished, processing accumulated tool calls");
// When stream finishes, do a final check for JSON tool calls in the accumulated buffer // When stream finishes, find ALL JSON tool calls in the accumulated buffer
if completed_tools.is_empty() && !self.text_buffer.is_empty() { if completed_tools.is_empty() && !self.text_buffer.is_empty() {
if let Some(json_tool) = self.try_parse_json_tool_call_from_buffer() { let all_tools = self.try_parse_all_json_tool_calls_from_buffer();
completed_tools.push(json_tool); if !all_tools.is_empty() {
debug!("Found {} JSON tool calls in buffer at stream end", all_tools.len());
completed_tools.extend(all_tools);
} }
} }
} }
@@ -417,32 +434,45 @@ impl StreamingToolParser {
None None
} }
/// Parse JSON tool call from the accumulated text buffer (called when stream finishes) /// Parse ALL JSON tool calls from the accumulated text buffer
/// This is similar to try_parse_json_tool_call but operates on the full buffer /// This finds all complete tool calls, not just the last one
fn try_parse_json_tool_call_from_buffer(&mut self) -> Option<ToolCall> { fn try_parse_all_json_tool_calls_from_buffer(&self) -> Vec<ToolCall> {
if let Some(start_pos) = Self::find_last_tool_call_start(&self.text_buffer) { let mut tool_calls = Vec::new();
let json_text = &self.text_buffer[start_pos..]; let mut search_start = 0;
debug!("Found potential JSON tool call at position {}: {:?}", start_pos,
if json_text.len() > 200 { &json_text[..200] } else { json_text }); while search_start < self.text_buffer.len() {
let search_text = &self.text_buffer[search_start..];
// Try to find a complete JSON object using the shared helper
if let Some(end_pos) = Self::find_complete_json_object_end(json_text) { // Find the next tool call pattern
let json_str = &json_text[..=end_pos]; if let Some(relative_pos) = Self::find_first_tool_call_start(search_text) {
debug!("Attempting to parse JSON tool call from buffer: {}", json_str); let abs_start = search_start + relative_pos;
let json_text = &self.text_buffer[abs_start..];
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(json_str) {
if let Some(args_obj) = tool_call.args.as_object() { // Try to find a complete JSON object
// Use the same validation as try_parse_json_tool_call if let Some(end_pos) = Self::find_complete_json_object_end(json_text) {
if !Self::has_message_like_keys(args_obj) { let json_str = &json_text[..=end_pos];
debug!("Successfully parsed JSON tool call from buffer: {:?}", tool_call);
return Some(tool_call); if let Ok(tool_call) = serde_json::from_str::<ToolCall>(json_str) {
if let Some(args_obj) = tool_call.args.as_object() {
if !Self::has_message_like_keys(args_obj) {
debug!("Found tool call at position {}: {:?}", abs_start, tool_call.tool);
tool_calls.push(tool_call);
}
} }
} }
// Move past this tool call
search_start = abs_start + end_pos + 1;
} else {
// Incomplete JSON, stop searching
break;
} }
} else {
// No more tool call patterns found
break;
} }
} }
None tool_calls
} }
/// Get the accumulated text content (excluding tool calls) /// Get the accumulated text content (excluding tool calls)
@@ -468,8 +498,10 @@ impl StreamingToolParser {
/// This detects cases where the LLM started emitting a tool call but the stream ended /// This detects cases where the LLM started emitting a tool call but the stream ended
/// before the JSON was complete (truncated output) /// before the JSON was complete (truncated output)
pub fn has_incomplete_tool_call(&self) -> bool { pub fn has_incomplete_tool_call(&self) -> bool {
if let Some(start_pos) = Self::find_last_tool_call_start(&self.text_buffer) { // Only check the unconsumed portion of the buffer
let json_text = &self.text_buffer[start_pos..]; let unchecked_buffer = &self.text_buffer[self.last_consumed_position..];
if let Some(start_pos) = Self::find_last_tool_call_start(unchecked_buffer) {
let json_text = &unchecked_buffer[start_pos..];
// If NOT complete, it's an incomplete tool call // If NOT complete, it's an incomplete tool call
Self::find_complete_json_object_end(json_text).is_none() Self::find_complete_json_object_end(json_text).is_none()
} else { } else {
@@ -481,8 +513,10 @@ impl StreamingToolParser {
/// This detects cases where the LLM emitted a complete tool call JSON /// This detects cases where the LLM emitted a complete tool call JSON
/// but it wasn't parsed/executed (e.g., due to parsing issues) /// but it wasn't parsed/executed (e.g., due to parsing issues)
pub fn has_unexecuted_tool_call(&self) -> bool { pub fn has_unexecuted_tool_call(&self) -> bool {
if let Some(start_pos) = Self::find_last_tool_call_start(&self.text_buffer) { // Only check the unconsumed portion of the buffer
let json_text = &self.text_buffer[start_pos..]; let unchecked_buffer = &self.text_buffer[self.last_consumed_position..];
if let Some(start_pos) = Self::find_last_tool_call_start(unchecked_buffer) {
let json_text = &unchecked_buffer[start_pos..];
// If the JSON IS complete, it means there's an unexecuted tool call // If the JSON IS complete, it means there's an unexecuted tool call
if let Some(json_end) = Self::find_complete_json_object_end(json_text) { if let Some(json_end) = Self::find_complete_json_object_end(json_text) {
let json_only = &json_text[..=json_end]; let json_only = &json_text[..=json_end];
@@ -492,6 +526,12 @@ impl StreamingToolParser {
false false
} }
/// Mark all tool calls up to the current buffer position as consumed/executed
/// This prevents has_unexecuted_tool_call() from returning true for already-executed tools
pub fn mark_tool_calls_consumed(&mut self) {
self.last_consumed_position = self.text_buffer.len();
}
/// Find the end position (byte index) of a complete JSON object in the text /// Find the end position (byte index) of a complete JSON object in the text
/// Returns None if no complete JSON object is found /// Returns None if no complete JSON object is found
fn find_complete_json_object_end(text: &str) -> Option<usize> { fn find_complete_json_object_end(text: &str) -> Option<usize> {
@@ -529,7 +569,7 @@ impl StreamingToolParser {
/// Reset the parser state for a new message /// Reset the parser state for a new message
pub fn reset(&mut self) { pub fn reset(&mut self) {
self.text_buffer.clear(); self.text_buffer.clear();
self.native_tool_calls.clear(); self.last_consumed_position = 0;
self.message_stopped = false; self.message_stopped = false;
self.in_json_tool_call = false; self.in_json_tool_call = false;
self.json_tool_start = None; self.json_tool_start = None;
@@ -4178,6 +4218,12 @@ impl<W: UiWriter> Agent<W> {
completed_tools.into_iter().take(1).collect() completed_tools.into_iter().take(1).collect()
}; };
// Mark tool calls as consumed so has_unexecuted_tool_call() won't
// return true for tools we're about to execute
if !tools_to_process.is_empty() {
parser.mark_tool_calls_consumed();
}
// Helper function to check if two tool calls are duplicates // Helper function to check if two tool calls are duplicates
let are_duplicates = |tc1: &ToolCall, tc2: &ToolCall| -> bool { let are_duplicates = |tc1: &ToolCall, tc2: &ToolCall| -> bool {
tc1.tool == tc2.tool && tc1.args == tc2.args tc1.tool == tc2.tool && tc1.args == tc2.args
@@ -4645,8 +4691,16 @@ impl<W: UiWriter> Agent<W> {
// This ensures the filter doesn't stay in suppression mode for subsequent streaming content // This ensures the filter doesn't stay in suppression mode for subsequent streaming content
self.ui_writer.reset_json_filter(); self.ui_writer.reset_json_filter();
// Reset parser for next iteration - this clears the text buffer // Only reset parser if there are no more unexecuted tool calls in the buffer
parser.reset(); // This handles the case where the LLM emits multiple tool calls in one response
if parser.has_unexecuted_tool_call() {
debug!("Parser still has unexecuted tool calls, not resetting buffer");
// Mark current tool as consumed so we don't re-detect it
parser.mark_tool_calls_consumed();
} else {
// Reset parser for next iteration - this clears the text buffer
parser.reset();
}
// Clear current_response for next iteration to prevent buffered text // Clear current_response for next iteration to prevent buffered text
// from being incorrectly displayed after tool execution // from being incorrectly displayed after tool execution
@@ -4661,8 +4715,14 @@ impl<W: UiWriter> Agent<W> {
} // End of for loop processing each tool call } // End of for loop processing each tool call
// If we processed any tools in multiple mode, break out to start new stream // If we processed any tools in multiple mode, break out to start new stream
// BUT only if there are no more unexecuted tool calls in the buffer
if tool_executed && self.config.agent.allow_multiple_tool_calls { if tool_executed && self.config.agent.allow_multiple_tool_calls {
break; if parser.has_unexecuted_tool_call() {
debug!("Tool executed but parser still has unexecuted tool calls, continuing to process");
// Don't break - continue processing to pick up remaining tool calls
} else {
break;
}
} }
// If no tool calls were completed, continue streaming normally // If no tool calls were completed, continue streaming normally
@@ -4752,7 +4812,7 @@ impl<W: UiWriter> Agent<W> {
" - Text buffer content: {:?}", " - Text buffer content: {:?}",
parser.get_text_content() parser.get_text_content()
); );
error!(" - Native tool calls: {:?}", parser.native_tool_calls); error!(" - Has incomplete tool call: {}", parser.has_incomplete_tool_call());
error!(" - Message stopped: {}", parser.is_message_stopped()); error!(" - Message stopped: {}", parser.is_message_stopped());
error!(" - In JSON tool call: {}", parser.in_json_tool_call); error!(" - In JSON tool call: {}", parser.in_json_tool_call);
error!(" - JSON tool start: {:?}", parser.json_tool_start); error!(" - JSON tool start: {:?}", parser.json_tool_start);
@@ -4872,8 +4932,8 @@ impl<W: UiWriter> Agent<W> {
); );
error!("Error type: {}", std::any::type_name_of_val(&e)); error!("Error type: {}", std::any::type_name_of_val(&e));
error!("Parser state at error: text_buffer_len={}, native_tool_calls={}, message_stopped={}", error!("Parser state at error: text_buffer_len={}, has_incomplete={}, message_stopped={}",
parser.text_buffer_len(), parser.native_tool_calls.len(), parser.is_message_stopped()); parser.text_buffer_len(), parser.has_incomplete_tool_call(), parser.is_message_stopped());
// Store the error for potential logging later // Store the error for potential logging later
_last_error = Some(error_details.clone()); _last_error = Some(error_details.clone());
@@ -4892,7 +4952,7 @@ impl<W: UiWriter> Agent<W> {
// If we have any content or tool calls, treat this as a graceful end // If we have any content or tool calls, treat this as a graceful end
if chunks_received > 0 if chunks_received > 0
&& (!parser.get_text_content().is_empty() && (!parser.get_text_content().is_empty()
|| parser.native_tool_calls.len() > 0) || parser.has_unexecuted_tool_call())
{ {
warn!("Stream terminated unexpectedly but we have content, continuing"); warn!("Stream terminated unexpectedly but we have content, continuing");
break; // Break to process what we have break; // Break to process what we have
@@ -4956,12 +5016,25 @@ impl<W: UiWriter> Agent<W> {
// Check if there's a complete but unexecuted tool call in the buffer // Check if there's a complete but unexecuted tool call in the buffer
let has_unexecuted_tool_call = parser.has_unexecuted_tool_call(); let has_unexecuted_tool_call = parser.has_unexecuted_tool_call();
// Log when we detect unexecuted or incomplete tool calls for debugging
if has_incomplete_tool_call {
debug!("Detected incomplete tool call in buffer (buffer_len={}, consumed_up_to={})",
parser.text_buffer_len(), parser.text_buffer_len());
}
if has_unexecuted_tool_call {
debug!("Detected unexecuted tool call in buffer - this may indicate a parsing issue");
warn!("Unexecuted tool call detected in buffer after stream ended");
}
// Auto-continue if tools were executed but final_output was never called // Auto-continue if tools were executed but final_output was never called
// OR if the LLM emitted an incomplete tool call (truncated JSON) // OR if the LLM emitted an incomplete tool call (truncated JSON)
// OR if the LLM emitted a complete tool call that wasn't executed // OR if the LLM emitted a complete tool call that wasn't executed
// OR if the LLM emitted an empty/trivial response (just timing lines)
// This ensures we don't return control when the LLM clearly intended to call a tool // This ensures we don't return control when the LLM clearly intended to call a tool
let should_auto_continue = (any_tool_executed && !final_output_called) || has_incomplete_tool_call || has_unexecuted_tool_call || (any_tool_executed && is_empty_response); // Note: We removed the redundant condition (any_tool_executed && is_empty_response)
// because it's already covered by (any_tool_executed && !final_output_called)
let should_auto_continue = (any_tool_executed && !final_output_called)
|| has_incomplete_tool_call
|| has_unexecuted_tool_call;
if should_auto_continue { if should_auto_continue {
if auto_summary_attempts < MAX_AUTO_SUMMARY_ATTEMPTS { if auto_summary_attempts < MAX_AUTO_SUMMARY_ATTEMPTS {
auto_summary_attempts += 1; auto_summary_attempts += 1;

View File

@@ -111,3 +111,124 @@ fn test_max_auto_summary_attempts_is_reasonable() {
assert!(CURRENT_VALUE <= EXPECTED_MAX_ATTEMPTS, assert!(CURRENT_VALUE <= EXPECTED_MAX_ATTEMPTS,
"MAX_AUTO_SUMMARY_ATTEMPTS should not exceed {} to avoid infinite loops", EXPECTED_MAX_ATTEMPTS); "MAX_AUTO_SUMMARY_ATTEMPTS should not exceed {} to avoid infinite loops", EXPECTED_MAX_ATTEMPTS);
} }
// =============================================================================
// Test: Auto-continue condition logic
// =============================================================================
/// Simulates the should_auto_continue logic from lib.rs
fn should_auto_continue(
any_tool_executed: bool,
final_output_called: bool,
has_incomplete_tool_call: bool,
has_unexecuted_tool_call: bool,
is_empty_response: bool,
) -> bool {
(any_tool_executed && !final_output_called)
|| has_incomplete_tool_call
|| has_unexecuted_tool_call
|| (any_tool_executed && is_empty_response)
}
#[test]
fn test_auto_continue_after_tool_no_final_output() {
// Tool executed but no final_output - should continue
assert!(should_auto_continue(
true, // any_tool_executed
false, // final_output_called
false, // has_incomplete_tool_call
false, // has_unexecuted_tool_call
false, // is_empty_response
));
}
#[test]
fn test_auto_continue_with_final_output() {
// Tool executed AND final_output called - should NOT continue
assert!(!should_auto_continue(
true, // any_tool_executed
true, // final_output_called
false, // has_incomplete_tool_call
false, // has_unexecuted_tool_call
false, // is_empty_response
));
}
#[test]
fn test_auto_continue_incomplete_tool_call() {
// Incomplete tool call - should continue regardless of other flags
assert!(should_auto_continue(
false, // any_tool_executed
false, // final_output_called
true, // has_incomplete_tool_call
false, // has_unexecuted_tool_call
false, // is_empty_response
));
}
#[test]
fn test_auto_continue_unexecuted_tool_call() {
// Unexecuted tool call - should continue
assert!(should_auto_continue(
false, // any_tool_executed
false, // final_output_called
false, // has_incomplete_tool_call
true, // has_unexecuted_tool_call
false, // is_empty_response
));
}
#[test]
fn test_auto_continue_empty_response_after_tool() {
// Empty response after tool execution - should continue
assert!(should_auto_continue(
true, // any_tool_executed
false, // final_output_called
false, // has_incomplete_tool_call
false, // has_unexecuted_tool_call
true, // is_empty_response
));
}
#[test]
fn test_auto_continue_empty_response_no_tool() {
// Empty response but no tool executed - should NOT continue
// (This is a normal case where LLM just didn't respond)
assert!(!should_auto_continue(
false, // any_tool_executed
false, // final_output_called
false, // has_incomplete_tool_call
false, // has_unexecuted_tool_call
true, // is_empty_response
));
}
#[test]
fn test_auto_continue_no_conditions_met() {
// No tools, no incomplete calls, substantive response - should NOT continue
assert!(!should_auto_continue(
false, // any_tool_executed
false, // final_output_called
false, // has_incomplete_tool_call
false, // has_unexecuted_tool_call
false, // is_empty_response
));
}
// =============================================================================
// Test: Redundant condition detection
// =============================================================================
#[test]
fn test_redundant_empty_response_condition() {
// This test documents that (any_tool_executed && is_empty_response) is redundant
// when (any_tool_executed && !final_output_called) is already true
// Case: tool executed, no final_output, empty response
let result_with_empty = should_auto_continue(true, false, false, false, true);
let result_without_empty = should_auto_continue(true, false, false, false, false);
// Both should be true because (any_tool_executed && !final_output_called) is true
assert_eq!(result_with_empty, result_without_empty,
"The is_empty_response condition is redundant when any_tool_executed && !final_output_called");
}