Suppress JSON tool calls in raw text

This commit is contained in:
Dhanji Prasanna
2025-10-01 13:20:13 +10:00
parent 3349a33106
commit a843ecc9d0
2 changed files with 227 additions and 73 deletions

View File

@@ -1,5 +1,5 @@
pub mod project;
pub mod error_handling;
pub mod project;
#[cfg(test)]
mod error_handling_test;
@@ -581,7 +581,6 @@ impl Agent {
"You are G3, an AI programming agent. Your goal is to analyze, write and modify code to achieve given goals.
You have access to tools. When you need to accomplish a task, you MUST use the appropriate tool. Do not just describe what you would do - actually use the tools.
Always start by reading the project's README. Create one if this is a new project or making major changes.
IMPORTANT: You must call tools to achieve goals. When you receive a request:
1. Analyze and identify what needs to be done
@@ -975,15 +974,15 @@ The tool will execute immediately and you'll receive the result (success or erro
request: &CompletionRequest,
error_context: &error_handling::ErrorContext,
) -> Result<g3_providers::CompletionStream> {
use crate::error_handling::{classify_error, calculate_retry_delay, ErrorType};
use crate::error_handling::{calculate_retry_delay, classify_error, ErrorType};
let mut attempt = 0;
const MAX_ATTEMPTS: u32 = 3;
loop {
attempt += 1;
let provider = self.providers.get(None)?;
match provider.stream(request.clone()).await {
Ok(stream) => {
if attempt > 1 {
@@ -994,7 +993,10 @@ The tool will execute immediately and you'll receive the result (success or erro
Err(e) if attempt < MAX_ATTEMPTS => {
if matches!(classify_error(&e), ErrorType::Recoverable(_)) {
let delay = calculate_retry_delay(attempt);
warn!("Recoverable error on attempt {}/{}: {}. Retrying in {:?}...", attempt, MAX_ATTEMPTS, e, delay);
warn!(
"Recoverable error on attempt {}/{}: {}. Retrying in {:?}...",
attempt, MAX_ATTEMPTS, e, delay
);
tokio::time::sleep(delay).await;
} else {
error_context.clone().log_error(&e);
@@ -1164,15 +1166,16 @@ The tool will execute immediately and you'll receive the result (success or erro
let provider = self.providers.get(None)?;
debug!("Got provider: {}", provider.name());
// Create error context for detailed logging
let last_prompt = request.messages
let last_prompt = request
.messages
.iter()
.rev()
.find(|m| matches!(m.role, MessageRole::User))
.map(|m| m.content.clone())
.unwrap_or_else(|| "No user message found".to_string());
let error_context = ErrorContext::new(
"stream_completion".to_string(),
provider.name().to_string(),
@@ -1180,10 +1183,12 @@ The tool will execute immediately and you'll receive the result (success or erro
last_prompt,
self.session_id.clone(),
self.context_window.used_tokens,
).with_request(
serde_json::to_string(&request).unwrap_or_else(|_| "Failed to serialize request".to_string())
)
.with_request(
serde_json::to_string(&request)
.unwrap_or_else(|_| "Failed to serialize request".to_string()),
);
// Try to get stream with retry logic
let mut stream = match self.stream_with_retry(&request, &error_context).await {
Ok(s) => s,
@@ -1195,7 +1200,7 @@ The tool will execute immediately and you'll receive the result (success or erro
iteration_count
);
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
match self.stream_with_retry(&request, &error_context).await {
Ok(s) => s,
Err(e2) => {
@@ -1231,14 +1236,17 @@ The tool will execute immediately and you'll receive the result (success or erro
// Get the text content accumulated so far
let text_content = parser.get_text_content();
// Clean and prepare display content
let clean_display_content = text_content
// Clean the content
let clean_content = text_content
.replace("<|im_end|>", "")
.replace("</s>", "")
.replace("[/INST]", "")
.replace("<</SYS>>", "");
let final_display_content = clean_display_content.trim();
// Filter out JSON tool calls from the display
let filtered_content = filter_json_tool_calls(&clean_content);
let final_display_content = filtered_content.trim();
// Display any new content before tool execution
let new_content =
if current_response.len() <= final_display_content.len() {
@@ -1861,7 +1869,10 @@ The tool will execute immediately and you'll receive the result (success or erro
.and_then(|v| v.as_u64())
.map(|n| n as usize);
debug!("str_replace: path={}, start={:?}, end={:?}", file_path, start_char, end_char);
debug!(
"str_replace: path={}, start={:?}, end={:?}",
file_path, start_char, end_char
);
// Read the existing file
let file_content = match std::fs::read_to_string(file_path) {
@@ -1870,16 +1881,18 @@ The tool will execute immediately and you'll receive the result (success or erro
};
// Apply unified diff to content
let result = match apply_unified_diff_to_string(&file_content, diff, start_char, end_char) {
Ok(r) => r,
Err(e) => return Ok(format!("{}", e)),
};
let result =
match apply_unified_diff_to_string(&file_content, diff, start_char, end_char) {
Ok(r) => r,
Err(e) => return Ok(format!("{}", e)),
};
// Write the result back to the file
match std::fs::write(file_path, &result) {
Ok(()) => {
Ok(format!("✅ Successfully applied unified diff to '{}'", file_path))
}
Ok(()) => Ok(format!(
"✅ Successfully applied unified diff to '{}'",
file_path
)),
Err(e) => Ok(format!("❌ Failed to write to file '{}': {}", file_path, e)),
}
}
@@ -1917,56 +1930,166 @@ The tool will execute immediately and you'll receive the result (success or erro
}
}
use std::cell::RefCell;
// Per-thread state for JSON tool call suppression.
// Each thread gets its own `JsonToolState`, built with `JsonToolState::new()`
// and wrapped in a `RefCell` so it can be mutated through `with`.
// NOTE(review): the state is keyed by thread, not by stream or session. If the
// chunks of one response can be processed on different threads, or a response
// ends in the middle of a JSON object, the state seen by the next chunk could
// be missing or stale - confirm against the caller.
thread_local! {
static JSON_TOOL_STATE: RefCell<JsonToolState> = RefCell::new(JsonToolState::new());
}
/// Bookkeeping used while a JSON tool call is being hidden from streamed output.
#[derive(Debug, Clone)]
struct JsonToolState {
    /// `true` while incoming text is being withheld from display.
    suppression_mode: bool,
    /// Running balance of `{` seen minus `}` seen for the object being withheld.
    brace_depth: i32,
    /// Text collected while tracking a (possible) tool call.
    buffer: String,
}

impl JsonToolState {
    /// Creates the idle state: not suppressing, depth zero, empty buffer.
    fn new() -> Self {
        let buffer = String::new();
        JsonToolState {
            suppression_mode: false,
            brace_depth: 0,
            buffer,
        }
    }

    /// Puts the state back to idle in place.
    ///
    /// The buffer is cleared rather than replaced, so its allocation is kept.
    fn reset(&mut self) {
        let Self {
            suppression_mode,
            brace_depth,
            buffer,
        } = self;
        buffer.clear();
        *brace_depth = 0;
        *suppression_mode = false;
    }
}
// Helper function to filter JSON tool calls from display content.
//
// Takes one chunk of display text and returns the part that should still be
// shown. Progress is kept in the thread-local `JSON_TOOL_STATE`, so one call
// can keep hiding an object that an earlier call started:
//   - while `suppression_mode` is set, the chunk is appended to `buffer`,
//     `{` / `}` are counted, and an empty string is returned unless the
//     depth reaches zero or below and text follows the last `}`;
//   - otherwise the chunk is searched for the `{"tool"`-style prefixes in
//     array order, and for the first prefix found only the text before it
//     is returned;
//   - chunks of at most 3 bytes are accumulated in `buffer` (while it holds
//     fewer than 20 bytes) and checked for a `{"tool"` prefix;
//   - a chunk that, after trimming, starts with `{` and contains `tool` or a
//     double quote also switches suppression on and yields an empty string;
//   - anything else is returned unchanged.
//
// NOTE(review): brace counting looks at every `{` / `}` character, including
// any that occur inside JSON string values.
// NOTE(review): the shortest prefix `{"t"` also matches objects whose first
// key merely begins with `t`.
// NOTE(review): `content.len()` and `buffer.len()` are byte lengths.
fn filter_json_tool_calls(content: &str) -> String {
// Check if content contains any JSON tool call patterns
let patterns = [
r#"{"tool":"#,
r#"{ "tool":"#,
r#"{"tool" :"#,
r#"{ "tool" :"#,
r#"{"tool": "#, // Added pattern with space after colon
r#"{ "tool": "#, // Added pattern with spaces
];
JSON_TOOL_STATE.with(|state| {
let mut state = state.borrow_mut();
// If we're already in suppression mode, continue tracking
if state.suppression_mode {
// Add content to buffer for tracking
state.buffer.push_str(content);
// Count braces to track JSON nesting depth
for ch in content.chars() {
match ch {
'{' => state.brace_depth += 1,
'}' => {
state.brace_depth -= 1;
// Exit suppression mode when we've closed all braces
if state.brace_depth <= 0 {
debug!("Exiting JSON tool suppression mode - completed JSON object");
state.reset();
// Check if there's any content after the JSON
// NOTE(review): `rfind` locates the LAST `}` in the chunk, which is not
// necessarily the brace that just closed the object.
if let Some(close_pos) = content.rfind('}') {
if close_pos + 1 < content.len() {
// Return any content after the JSON
return content[close_pos + 1..].to_string();
}
}
}
}
_ => {}
}
}
// While in suppression mode, return empty string
return String::new();
}
// Check if any pattern is found in the content
let has_tool_call_pattern = patterns.iter().any(|pattern| content.contains(pattern));
// Check if content contains any JSON tool call patterns
let patterns = [
r#"{"tool":"#,
r#"{"tool"#, // Partial pattern
r#"{"too"#, // Even more partial
r#"{"to"#, // Very partial
r#"{"t"#, // Extremely partial
r#"{ "tool":"#,
r#"{"tool" :"#,
r#"{ "tool" :"#,
r#"{"tool": "#, // Pattern with space after colon
r#"{ "tool": "#, // Pattern with spaces
];
// Check if any pattern is found in the content
for pattern in &patterns {
if let Some(pos) = content.find(pattern) {
debug!("Detected JSON tool call pattern '{}' at position {} - entering suppression mode", pattern, pos);
// Found a tool call pattern - enter suppression mode
state.suppression_mode = true;
state.brace_depth = 0;
state.buffer.clear();
state.buffer.push_str(&content[pos..]);
// Count braces in the remaining content after the pattern
for ch in content[pos..].chars() {
match ch {
'{' => state.brace_depth += 1,
'}' => {
state.brace_depth -= 1;
if state.brace_depth <= 0 {
debug!("JSON tool call completed in same chunk - exiting suppression mode");
state.reset();
break;
}
}
_ => {}
}
}
// Return any content before the JSON tool call
// NOTE(review): text that follows the object within this same chunk is not
// returned on this path.
if pos > 0 {
return content[..pos].to_string();
} else {
return String::new();
}
}
}
if has_tool_call_pattern {
// If we detect a JSON tool call pattern anywhere in the content,
// suppress it completely
"".to_string()
} else {
// Check for partial JSON patterns that might be split across chunks
let trimmed = content.trim();
// More comprehensive pattern matching for partial tool calls
if trimmed.starts_with(r#"{"tool"#)
|| trimmed.starts_with(r#"{ "tool"#)
|| trimmed.starts_with(r#"{"#) && (trimmed.contains("tool") || trimmed.contains("args"))
|| trimmed.contains(r#""tool":"#)
|| trimmed.contains(r#""tool": "#)
|| trimmed.contains(r#""args":"#)
|| trimmed.contains(r#""args": "#)
|| trimmed.contains(r#"file_path"#)
|| trimmed.contains(r#"command"#)
|| trimmed.contains(r#"content"#) && trimmed.contains(r#"""#) // Likely JSON string
|| trimmed.contains(r#"summary"#) && trimmed.contains(r#"""#) // Likely JSON string
|| (trimmed.starts_with('{')
&& trimmed.len() < 100 // Increased threshold
&& (trimmed.contains("tool") || trimmed.contains("args") || trimmed.contains(r#"""#)))
// Catch malformed tool calls like: {"tool": "write_file", "path
|| (trimmed.contains(r#""tool":"#) || trimmed.contains(r#""tool": "#))
|| (trimmed.starts_with(r#"{"#) && trimmed.contains(r#"", ""#))
// JSON with quoted comma pattern
{
// This looks like part of a JSON tool call, suppress it
"".to_string()
} else {
// Regular content, return as-is
content.to_string()
// Special case: single character chunks that might be part of a JSON tool call
if content.len() <= 3 && state.buffer.len() < 20 {
// Accumulate small chunks to check for patterns
state.buffer.push_str(content);
if state.buffer.contains(r#"{"tool"#) || state.buffer.contains(r#"{ "tool"#) {
state.suppression_mode = true;
state.brace_depth = state.buffer.chars().filter(|&c| c == '{').count() as i32;
return String::new();
}
}
}
// Check if this looks like the start of a JSON tool call (larger chunks)
if trimmed.starts_with('{') && (trimmed.contains("tool") || trimmed.contains('"')) {
// This might be the start of a JSON tool call
// Enter suppression mode preemptively
debug!("Detected potential JSON tool call start - entering suppression mode");
state.suppression_mode = true;
state.brace_depth = 0;
state.buffer.clear();
state.buffer.push_str(content);
// Count braces
for ch in content.chars() {
match ch {
'{' => state.brace_depth += 1,
'}' => {
state.brace_depth -= 1;
if state.brace_depth <= 0 {
state.reset();
break;
}
}
_ => {}
}
}
return String::new();
}
// No JSON tool call detected, return content as-is
content.to_string()
})
}
// Apply unified diff to an input string with optional [start, end) bounds
@@ -1985,9 +2108,7 @@ pub fn apply_unified_diff_to_string(
}
// Normalize line endings to avoid CRLF/CR mismatches
let content_norm = file_content
.replace("\r\n", "\n")
.replace('\r', "\n");
let content_norm = file_content.replace("\r\n", "\n").replace('\r', "\n");
// Determine and validate the search range
let search_start = start_char.unwrap_or(0);
@@ -2010,7 +2131,8 @@ pub fn apply_unified_diff_to_string(
if search_start > search_end {
anyhow::bail!(
"start position {} is greater than end position {}",
search_start, search_end
search_start,
search_end
);
}
@@ -2362,7 +2484,8 @@ mod integration_tests {
#[test]
fn apply_multi_hunk_unified_diff_to_string() {
let original = "line 1\nkeep\nold A\nkeep 2\nold B\nkeep 3\n";
let diff = "@@ -1,6 +1,6 @@\n line 1\n keep\n-old A\n+new A\n keep 2\n-old B\n+new B\n keep 3\n";
let diff =
"@@ -1,6 +1,6 @@\n line 1\n keep\n-old A\n+new A\n keep 2\n-old B\n+new B\n keep 3\n";
let result = apply_unified_diff_to_string(original, diff, None, None).unwrap();
let expected = "line 1\nkeep\nnew A\nkeep 2\nnew B\nkeep 3\n";
assert_eq!(result, expected);