suppress json tool calls in raw text

This commit is contained in:
Dhanji Prasanna
2025-10-01 13:20:13 +10:00
parent 3349a33106
commit a843ecc9d0
2 changed files with 227 additions and 73 deletions

View File

@@ -303,7 +303,21 @@ fn truncate_for_logging(s: &str, max_len: usize) -> String {
if s.len() <= max_len { if s.len() <= max_len {
s.to_string() s.to_string()
} else { } else {
format!("{}... (truncated, {} total chars)", &s[..max_len], s.len()) // Find a safe UTF-8 boundary to truncate at
// We need to ensure we don't cut in the middle of a multi-byte character
let mut truncate_at = max_len;
// Walk backwards from max_len to find a character boundary
while truncate_at > 0 && !s.is_char_boundary(truncate_at) {
truncate_at -= 1;
}
// If we couldn't find a boundary (shouldn't happen), use a safe default
if truncate_at == 0 {
truncate_at = max_len.min(s.len());
}
format!("{}... (truncated, {} total bytes)", &s[..truncate_at], s.len())
} }
} }
@@ -395,5 +409,22 @@ mod tests {
assert!(truncated.starts_with("This is a very long ")); assert!(truncated.starts_with("This is a very long "));
assert!(truncated.contains("truncated")); assert!(truncated.contains("truncated"));
assert!(truncated.contains("total chars")); assert!(truncated.contains("total chars"));
assert!(truncated.contains("total bytes"));
}
#[test]
fn test_truncate_with_multibyte_chars() {
// Test with multi-byte UTF-8 characters
let text_with_emoji = "Hello 👋 World 🌍 Test ✨ More text here";
let truncated = truncate_for_logging(text_with_emoji, 10);
// Should truncate at a valid UTF-8 boundary
assert!(truncated.starts_with("Hello "));
// Test with box-drawing characters like the one causing the panic
let text_with_box = "Some text ┌─────┐ more text";
let truncated = truncate_for_logging(text_with_box, 12);
// Should not panic and should truncate at a valid boundary
assert!(truncated.contains("Some text"));
assert!(truncated.contains("truncated"));
} }
} }

View File

@@ -1,5 +1,5 @@
pub mod project;
pub mod error_handling; pub mod error_handling;
pub mod project;
#[cfg(test)] #[cfg(test)]
mod error_handling_test; mod error_handling_test;
@@ -581,7 +581,6 @@ impl Agent {
"You are G3, an AI programming agent. Your goal is to analyze, write and modify code to achieve given goals. "You are G3, an AI programming agent. Your goal is to analyze, write and modify code to achieve given goals.
You have access to tools. When you need to accomplish a task, you MUST use the appropriate tool. Do not just describe what you would do - actually use the tools. You have access to tools. When you need to accomplish a task, you MUST use the appropriate tool. Do not just describe what you would do - actually use the tools.
Always start by reading the project's README. Create one if this is a new project or making major changes.
IMPORTANT: You must call tools to achieve goals. When you receive a request: IMPORTANT: You must call tools to achieve goals. When you receive a request:
1. Analyze and identify what needs to be done 1. Analyze and identify what needs to be done
@@ -975,7 +974,7 @@ The tool will execute immediately and you'll receive the result (success or erro
request: &CompletionRequest, request: &CompletionRequest,
error_context: &error_handling::ErrorContext, error_context: &error_handling::ErrorContext,
) -> Result<g3_providers::CompletionStream> { ) -> Result<g3_providers::CompletionStream> {
use crate::error_handling::{classify_error, calculate_retry_delay, ErrorType}; use crate::error_handling::{calculate_retry_delay, classify_error, ErrorType};
let mut attempt = 0; let mut attempt = 0;
const MAX_ATTEMPTS: u32 = 3; const MAX_ATTEMPTS: u32 = 3;
@@ -994,7 +993,10 @@ The tool will execute immediately and you'll receive the result (success or erro
Err(e) if attempt < MAX_ATTEMPTS => { Err(e) if attempt < MAX_ATTEMPTS => {
if matches!(classify_error(&e), ErrorType::Recoverable(_)) { if matches!(classify_error(&e), ErrorType::Recoverable(_)) {
let delay = calculate_retry_delay(attempt); let delay = calculate_retry_delay(attempt);
warn!("Recoverable error on attempt {}/{}: {}. Retrying in {:?}...", attempt, MAX_ATTEMPTS, e, delay); warn!(
"Recoverable error on attempt {}/{}: {}. Retrying in {:?}...",
attempt, MAX_ATTEMPTS, e, delay
);
tokio::time::sleep(delay).await; tokio::time::sleep(delay).await;
} else { } else {
error_context.clone().log_error(&e); error_context.clone().log_error(&e);
@@ -1166,7 +1168,8 @@ The tool will execute immediately and you'll receive the result (success or erro
debug!("Got provider: {}", provider.name()); debug!("Got provider: {}", provider.name());
// Create error context for detailed logging // Create error context for detailed logging
let last_prompt = request.messages let last_prompt = request
.messages
.iter() .iter()
.rev() .rev()
.find(|m| matches!(m.role, MessageRole::User)) .find(|m| matches!(m.role, MessageRole::User))
@@ -1180,8 +1183,10 @@ The tool will execute immediately and you'll receive the result (success or erro
last_prompt, last_prompt,
self.session_id.clone(), self.session_id.clone(),
self.context_window.used_tokens, self.context_window.used_tokens,
).with_request( )
serde_json::to_string(&request).unwrap_or_else(|_| "Failed to serialize request".to_string()) .with_request(
serde_json::to_string(&request)
.unwrap_or_else(|_| "Failed to serialize request".to_string()),
); );
// Try to get stream with retry logic // Try to get stream with retry logic
@@ -1231,13 +1236,16 @@ The tool will execute immediately and you'll receive the result (success or erro
// Get the text content accumulated so far // Get the text content accumulated so far
let text_content = parser.get_text_content(); let text_content = parser.get_text_content();
// Clean and prepare display content // Clean the content
let clean_display_content = text_content let clean_content = text_content
.replace("<|im_end|>", "") .replace("<|im_end|>", "")
.replace("</s>", "") .replace("</s>", "")
.replace("[/INST]", "") .replace("[/INST]", "")
.replace("<</SYS>>", ""); .replace("<</SYS>>", "");
let final_display_content = clean_display_content.trim();
// Filter out JSON tool calls from the display
let filtered_content = filter_json_tool_calls(&clean_content);
let final_display_content = filtered_content.trim();
// Display any new content before tool execution // Display any new content before tool execution
let new_content = let new_content =
@@ -1861,7 +1869,10 @@ The tool will execute immediately and you'll receive the result (success or erro
.and_then(|v| v.as_u64()) .and_then(|v| v.as_u64())
.map(|n| n as usize); .map(|n| n as usize);
debug!("str_replace: path={}, start={:?}, end={:?}", file_path, start_char, end_char); debug!(
"str_replace: path={}, start={:?}, end={:?}",
file_path, start_char, end_char
);
// Read the existing file // Read the existing file
let file_content = match std::fs::read_to_string(file_path) { let file_content = match std::fs::read_to_string(file_path) {
@@ -1870,16 +1881,18 @@ The tool will execute immediately and you'll receive the result (success or erro
}; };
// Apply unified diff to content // Apply unified diff to content
let result = match apply_unified_diff_to_string(&file_content, diff, start_char, end_char) { let result =
match apply_unified_diff_to_string(&file_content, diff, start_char, end_char) {
Ok(r) => r, Ok(r) => r,
Err(e) => return Ok(format!("{}", e)), Err(e) => return Ok(format!("{}", e)),
}; };
// Write the result back to the file // Write the result back to the file
match std::fs::write(file_path, &result) { match std::fs::write(file_path, &result) {
Ok(()) => { Ok(()) => Ok(format!(
Ok(format!("✅ Successfully applied unified diff to '{}'", file_path)) "✅ Successfully applied unified diff to '{}'",
} file_path
)),
Err(e) => Ok(format!("❌ Failed to write to file '{}': {}", file_path, e)), Err(e) => Ok(format!("❌ Failed to write to file '{}': {}", file_path, e)),
} }
} }
@@ -1917,56 +1930,166 @@ The tool will execute immediately and you'll receive the result (success or erro
} }
} }
use std::cell::RefCell;
// Thread-local state for tracking JSON tool call suppression
thread_local! {
static JSON_TOOL_STATE: RefCell<JsonToolState> = RefCell::new(JsonToolState::new());
}
#[derive(Debug, Clone)]
struct JsonToolState {
suppression_mode: bool,
brace_depth: i32,
buffer: String,
}
impl JsonToolState {
fn new() -> Self {
Self {
suppression_mode: false,
brace_depth: 0,
buffer: String::new(),
}
}
fn reset(&mut self) {
self.suppression_mode = false;
self.brace_depth = 0;
self.buffer.clear();
}
}
// Helper function to filter JSON tool calls from display content // Helper function to filter JSON tool calls from display content
fn filter_json_tool_calls(content: &str) -> String { fn filter_json_tool_calls(content: &str) -> String {
JSON_TOOL_STATE.with(|state| {
let mut state = state.borrow_mut();
// If we're already in suppression mode, continue tracking
if state.suppression_mode {
// Add content to buffer for tracking
state.buffer.push_str(content);
// Count braces to track JSON nesting depth
for ch in content.chars() {
match ch {
'{' => state.brace_depth += 1,
'}' => {
state.brace_depth -= 1;
// Exit suppression mode when we've closed all braces
if state.brace_depth <= 0 {
debug!("Exiting JSON tool suppression mode - completed JSON object");
state.reset();
// Check if there's any content after the JSON
if let Some(close_pos) = content.rfind('}') {
if close_pos + 1 < content.len() {
// Return any content after the JSON
return content[close_pos + 1..].to_string();
}
}
}
}
_ => {}
}
}
// While in suppression mode, return empty string
return String::new();
}
// Check if content contains any JSON tool call patterns // Check if content contains any JSON tool call patterns
let patterns = [ let patterns = [
r#"{"tool":"#, r#"{"tool":"#,
r#"{"tool"#, // Partial pattern
r#"{"too"#, // Even more partial
r#"{"to"#, // Very partial
r#"{"t"#, // Extremely partial
r#"{ "tool":"#, r#"{ "tool":"#,
r#"{"tool" :"#, r#"{"tool" :"#,
r#"{ "tool" :"#, r#"{ "tool" :"#,
r#"{"tool": "#, // Added pattern with space after colon r#"{"tool": "#, // Pattern with space after colon
r#"{ "tool": "#, // Added pattern with spaces r#"{ "tool": "#, // Pattern with spaces
]; ];
// Check if any pattern is found in the content // Check if any pattern is found in the content
let has_tool_call_pattern = patterns.iter().any(|pattern| content.contains(pattern)); for pattern in &patterns {
if let Some(pos) = content.find(pattern) {
debug!("Detected JSON tool call pattern '{}' at position {} - entering suppression mode", pattern, pos);
// Found a tool call pattern - enter suppression mode
state.suppression_mode = true;
state.brace_depth = 0;
state.buffer.clear();
state.buffer.push_str(&content[pos..]);
if has_tool_call_pattern { // Count braces in the remaining content after the pattern
// If we detect a JSON tool call pattern anywhere in the content, for ch in content[pos..].chars() {
// suppress it completely match ch {
"".to_string() '{' => state.brace_depth += 1,
'}' => {
state.brace_depth -= 1;
if state.brace_depth <= 0 {
debug!("JSON tool call completed in same chunk - exiting suppression mode");
state.reset();
break;
}
}
_ => {}
}
}
// Return any content before the JSON tool call
if pos > 0 {
return content[..pos].to_string();
} else { } else {
return String::new();
}
}
}
// Check for partial JSON patterns that might be split across chunks // Check for partial JSON patterns that might be split across chunks
let trimmed = content.trim(); let trimmed = content.trim();
// More comprehensive pattern matching for partial tool calls // Special case: single character chunks that might be part of a JSON tool call
if trimmed.starts_with(r#"{"tool"#) if content.len() <= 3 && state.buffer.len() < 20 {
|| trimmed.starts_with(r#"{ "tool"#) // Accumulate small chunks to check for patterns
|| trimmed.starts_with(r#"{"#) && (trimmed.contains("tool") || trimmed.contains("args")) state.buffer.push_str(content);
|| trimmed.contains(r#""tool":"#) if state.buffer.contains(r#"{"tool"#) || state.buffer.contains(r#"{ "tool"#) {
|| trimmed.contains(r#""tool": "#) state.suppression_mode = true;
|| trimmed.contains(r#""args":"#) state.brace_depth = state.buffer.chars().filter(|&c| c == '{').count() as i32;
|| trimmed.contains(r#""args": "#) return String::new();
|| trimmed.contains(r#"file_path"#) }
|| trimmed.contains(r#"command"#) }
|| trimmed.contains(r#"content"#) && trimmed.contains(r#"""#) // Likely JSON string
|| trimmed.contains(r#"summary"#) && trimmed.contains(r#"""#) // Likely JSON string // Check if this looks like the start of a JSON tool call (larger chunks)
|| (trimmed.starts_with('{') if trimmed.starts_with('{') && (trimmed.contains("tool") || trimmed.contains('"')) {
&& trimmed.len() < 100 // Increased threshold // This might be the start of a JSON tool call
&& (trimmed.contains("tool") || trimmed.contains("args") || trimmed.contains(r#"""#))) // Enter suppression mode preemptively
// Catch malformed tool calls like: {"tool": "write_file", "path debug!("Detected potential JSON tool call start - entering suppression mode");
|| (trimmed.contains(r#""tool":"#) || trimmed.contains(r#""tool": "#)) state.suppression_mode = true;
|| (trimmed.starts_with(r#"{"#) && trimmed.contains(r#"", ""#)) state.brace_depth = 0;
// JSON with quoted comma pattern state.buffer.clear();
{ state.buffer.push_str(content);
// This looks like part of a JSON tool call, suppress it
"".to_string() // Count braces
} else { for ch in content.chars() {
// Regular content, return as-is match ch {
'{' => state.brace_depth += 1,
'}' => {
state.brace_depth -= 1;
if state.brace_depth <= 0 {
state.reset();
break;
}
}
_ => {}
}
}
return String::new();
}
// No JSON tool call detected, return content as-is
content.to_string() content.to_string()
} })
}
} }
// Apply unified diff to an input string with optional [start, end) bounds // Apply unified diff to an input string with optional [start, end) bounds
@@ -1985,9 +2108,7 @@ pub fn apply_unified_diff_to_string(
} }
// Normalize line endings to avoid CRLF/CR mismatches // Normalize line endings to avoid CRLF/CR mismatches
let content_norm = file_content let content_norm = file_content.replace("\r\n", "\n").replace('\r', "\n");
.replace("\r\n", "\n")
.replace('\r', "\n");
// Determine and validate the search range // Determine and validate the search range
let search_start = start_char.unwrap_or(0); let search_start = start_char.unwrap_or(0);
@@ -2010,7 +2131,8 @@ pub fn apply_unified_diff_to_string(
if search_start > search_end { if search_start > search_end {
anyhow::bail!( anyhow::bail!(
"start position {} is greater than end position {}", "start position {} is greater than end position {}",
search_start, search_end search_start,
search_end
); );
} }
@@ -2362,7 +2484,8 @@ mod integration_tests {
#[test] #[test]
fn apply_multi_hunk_unified_diff_to_string() { fn apply_multi_hunk_unified_diff_to_string() {
let original = "line 1\nkeep\nold A\nkeep 2\nold B\nkeep 3\n"; let original = "line 1\nkeep\nold A\nkeep 2\nold B\nkeep 3\n";
let diff = "@@ -1,6 +1,6 @@\n line 1\n keep\n-old A\n+new A\n keep 2\n-old B\n+new B\n keep 3\n"; let diff =
"@@ -1,6 +1,6 @@\n line 1\n keep\n-old A\n+new A\n keep 2\n-old B\n+new B\n keep 3\n";
let result = apply_unified_diff_to_string(original, diff, None, None).unwrap(); let result = apply_unified_diff_to_string(original, diff, None, None).unwrap();
let expected = "line 1\nkeep\nnew A\nkeep 2\nnew B\nkeep 3\n"; let expected = "line 1\nkeep\nnew A\nkeep 2\nnew B\nkeep 3\n";
assert_eq!(result, expected); assert_eq!(result, expected);