context saving

This commit is contained in:
Dhanji Prasanna
2025-09-09 11:39:43 +10:00
parent 0c92a7c6b4
commit 02d95e01a0

View File

@@ -3,6 +3,8 @@ use g3_config::Config;
use g3_execution::CodeExecutor; use g3_execution::CodeExecutor;
use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry}; use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn}; use tracing::{error, info, warn};
@@ -58,8 +60,6 @@ impl StreamingToolParser {
let before_pos = &self.buffer[..pos]; let before_pos = &self.buffer[..pos];
let code_block_count = before_pos.matches("```").count(); let code_block_count = before_pos.matches("```").count();
//info!("Code block count before position {}: {}", pos, code_block_count);
// Accept tool calls both inside and outside code blocks // Accept tool calls both inside and outside code blocks
// The LLM might use either format despite our instructions // The LLM might use either format despite our instructions
//info!("Starting tool call parsing (code block status: {})", code_block_count % 2 == 1); //info!("Starting tool call parsing (code block status: {})", code_block_count % 2 == 1);
@@ -102,10 +102,8 @@ impl StreamingToolParser {
let end_pos = start_pos + i + 1; let end_pos = start_pos + i + 1;
let json_str = &self.buffer[start_pos..end_pos]; let json_str = &self.buffer[start_pos..end_pos];
//info!("Complete JSON found: {}", json_str);
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(json_str) { if let Ok(tool_call) = serde_json::from_str::<ToolCall>(json_str) {
info!("Successfully parsed tool call: {:?}", tool_call); //info!("Successfully parsed tool call: {:?}", tool_call);
// Reset parser state // Reset parser state
self.in_tool_call = false; self.in_tool_call = false;
self.tool_start_pos = None; self.tool_start_pos = None;
@@ -391,7 +389,7 @@ The tool will execute immediately and you'll receive the result to continue with
1. Break down tasks into small steps 1. Break down tasks into small steps
2. Execute ONE tool at a time 2. Execute ONE tool at a time
3. Wait for the result before proceeding 3. Wait for the result before proceeding
4. Use the actual file paths on the system (like ~/Downloads for Downloads folder) 4. Use the actual file paths on the system
5. End with final_output when done 5. End with final_output when done
Let's start with the first step of your task. Let's start with the first step of your task.
@@ -430,12 +428,24 @@ Let's start with the first step of your task.
// Time the LLM call with cancellation support and streaming // Time the LLM call with cancellation support and streaming
let llm_start = Instant::now(); let llm_start = Instant::now();
let (response_content, think_time) = tokio::select! { let result = tokio::select! {
result = self.stream_completion(request) => result?, result = self.stream_completion(request) => result,
_ = cancellation_token.cancelled() => { _ = cancellation_token.cancelled() => {
return Err(anyhow::anyhow!("Operation cancelled by user")); // Save context window on cancellation
self.save_context_window("cancelled");
Err(anyhow::anyhow!("Operation cancelled by user"))
} }
}; };
let (response_content, think_time) = match result {
Ok(content) => content,
Err(e) => {
// Save context window on error
self.save_context_window("error");
return Err(e);
}
};
let llm_duration = llm_start.elapsed(); let llm_duration = llm_start.elapsed();
// Create a mock usage for now (we'll need to track this during streaming) // Create a mock usage for now (we'll need to track this during streaming)
@@ -455,6 +465,9 @@ Let's start with the first step of your task.
}; };
self.context_window.add_message(assistant_message); self.context_window.add_message(assistant_message);
// Save context window at the end of successful interaction
self.save_context_window("completed");
// With streaming tool execution, we don't need separate code execution // With streaming tool execution, we don't need separate code execution
// The tools are already executed during streaming // The tools are already executed during streaming
if show_timing { if show_timing {
@@ -469,6 +482,40 @@ Let's start with the first step of your task.
} }
} }
/// Save the entire context window to a file for debugging purposes
fn save_context_window(&self, status: &str) {
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let filename = format!("g3_context_{}.json", timestamp);
let context_data = serde_json::json!({
"timestamp": timestamp,
"status": status,
"context_window": {
"used_tokens": self.context_window.used_tokens,
"total_tokens": self.context_window.total_tokens,
"percentage_used": self.context_window.percentage_used(),
"conversation_history": self.context_window.conversation_history
}
});
match serde_json::to_string_pretty(&context_data) {
Ok(json_content) => {
if let Err(e) = fs::write(&filename, json_content) {
error!("Failed to save context window to {}: {}", filename, e);
} else {
info!("Context window saved to {}", filename);
}
}
Err(e) => {
error!("Failed to serialize context window: {}", e);
}
}
}
pub fn get_context_window(&self) -> &ContextWindow { pub fn get_context_window(&self) -> &ContextWindow {
&self.context_window &self.context_window
} }
@@ -511,7 +558,10 @@ Let's start with the first step of your task.
Ok(s) => s, Ok(s) => s,
Err(e) => { Err(e) => {
if iteration_count > 1 && e.to_string().contains("busy") { if iteration_count > 1 && e.to_string().contains("busy") {
warn!("Model busy on iteration {}, retrying in 500ms", iteration_count); warn!(
"Model busy on iteration {}, retrying in 500ms",
iteration_count
);
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
match provider.stream(request.clone()).await { match provider.stream(request.clone()).await {
Ok(s) => s, Ok(s) => s,
@@ -539,10 +589,6 @@ Let's start with the first step of your task.
// Check for tool calls in the streaming content // Check for tool calls in the streaming content
if let Some((tool_call, tool_end_pos)) = parser.add_chunk(&chunk.content) { if let Some((tool_call, tool_end_pos)) = parser.add_chunk(&chunk.content) {
info!(
"🔧 Detected tool call: {:?} at position {}",
tool_call, tool_end_pos
);
// Found a complete tool call! Stop streaming and execute it // Found a complete tool call! Stop streaming and execute it
let content_before_tool = parser.get_content_before_tool(tool_end_pos); let content_before_tool = parser.get_content_before_tool(tool_end_pos);
@@ -559,22 +605,62 @@ Let's start with the first step of your task.
let new_content = if current_response.len() <= display_content.len() { let new_content = if current_response.len() <= display_content.len() {
// Use char indices to avoid UTF-8 boundary issues // Use char indices to avoid UTF-8 boundary issues
let chars_already_shown = current_response.chars().count(); let chars_already_shown = current_response.chars().count();
display_content.chars().skip(chars_already_shown).collect::<String>() display_content
.chars()
.skip(chars_already_shown)
.collect::<String>()
} else { } else {
String::new() String::new()
}; };
print!("{}", new_content); print!("{}", new_content);
io::stdout().flush()?; io::stdout().flush()?;
// Execute the tool // Execute the tool with formatted output
println!(); // New line before tool execution println!(); // New line before tool execution
// Tool call header
println!("┌─ {}", tool_call.tool);
if let Some(args_obj) = tool_call.args.as_object() {
for (key, value) in args_obj {
let value_str = match value {
serde_json::Value::String(s) => s.clone(),
_ => value.to_string(),
};
println!("{}", value_str);
}
}
println!("├─ output:");
let exec_start = Instant::now(); let exec_start = Instant::now();
let tool_result = self.execute_tool(&tool_call).await?; let tool_result = self.execute_tool(&tool_call).await?;
let exec_duration = exec_start.elapsed(); let exec_duration = exec_start.elapsed();
total_execution_time += exec_duration; total_execution_time += exec_duration;
// Display tool execution result // Display tool execution result with proper indentation
println!("🔧 {}: {}", tool_call.tool, tool_result); let output_lines: Vec<&str> = tool_result.lines().collect();
const MAX_LINES: usize = 5;
if output_lines.len() <= MAX_LINES {
// Show all lines if within limit
for line in output_lines {
println!("{}", line);
}
} else {
// Show first MAX_LINES and add truncation note
for line in output_lines.iter().take(MAX_LINES) {
println!("{}", line);
}
let hidden_count = output_lines.len() - MAX_LINES;
println!(
"│ ... ({} more line{} hidden)",
hidden_count,
if hidden_count == 1 { "" } else { "s" }
);
}
// Closure marker with timing
println!("└─ ⚡️ {}", Self::format_duration(exec_duration));
println!();
print!("🤖 "); // Continue response indicator print!("🤖 "); // Continue response indicator
io::stdout().flush()?; io::stdout().flush()?;
@@ -593,7 +679,7 @@ Let's start with the first step of your task.
content: format!("Tool result: {}", tool_result), content: format!("Tool result: {}", tool_result),
}; };
request.messages.push(tool_message); //request.messages.push(tool_message);
request.messages.push(result_message); request.messages.push(result_message);
full_response.push_str(display_content); full_response.push_str(display_content);
@@ -660,8 +746,11 @@ Let's start with the first step of your task.
"shell" => { "shell" => {
if let Some(command) = tool_call.args.get("command") { if let Some(command) = tool_call.args.get("command") {
if let Some(command_str) = command.as_str() { if let Some(command_str) = command.as_str() {
// Use shell escaping to handle filenames with spaces and special characters
let escaped_command = shell_escape_command(command_str);
let executor = CodeExecutor::new(); let executor = CodeExecutor::new();
match executor.execute_code("bash", command_str).await { match executor.execute_code("bash", &escaped_command).await {
Ok(result) => { Ok(result) => {
if result.success { if result.success {
Ok(if result.stdout.is_empty() { Ok(if result.stdout.is_empty() {
@@ -716,6 +805,90 @@ Let's start with the first step of your task.
} }
} }
// Helper function to properly escape shell commands
fn shell_escape_command(command: &str) -> String {
// Simple approach: if the command contains file paths with spaces,
// we need to be more intelligent about escaping
// For now, let's use a basic approach that handles common cases
// This is a simplified version - a full implementation would use proper shell parsing
let parts: Vec<&str> = command.split_whitespace().collect();
if parts.is_empty() {
return command.to_string();
}
let cmd = parts[0];
let args = &parts[1..];
// Commands that typically take file paths as arguments
let file_commands = [
"cat", "ls", "cp", "mv", "rm", "chmod", "chown", "file", "head", "tail", "wc", "grep",
];
if file_commands.contains(&cmd) {
// For file commands, we need to be smarter about escaping
// Let's use a different approach: use the original command but wrap it in quotes if needed
// Check if the command already has proper quoting
if command.contains('"') || command.contains('\'') {
// Already has some quoting, use as-is
return command.to_string();
}
// Look for file paths that need escaping (contain spaces but aren't quoted)
let mut escaped_command = String::new();
let mut in_quotes = false;
let mut current_word = String::new();
let mut words = Vec::new();
for ch in command.chars() {
match ch {
' ' if !in_quotes => {
if !current_word.is_empty() {
words.push(current_word.clone());
current_word.clear();
}
}
'"' => {
in_quotes = !in_quotes;
current_word.push(ch);
}
_ => {
current_word.push(ch);
}
}
}
if !current_word.is_empty() {
words.push(current_word);
}
// Reconstruct the command with proper escaping
for (i, word) in words.iter().enumerate() {
if i > 0 {
escaped_command.push(' ');
}
// If this word looks like a file path (contains / or ~) and has spaces, quote it
if word.contains('/') || word.starts_with('~') {
if word.contains(' ') && !word.starts_with('"') && !word.starts_with('\'') {
escaped_command.push_str(&format!("\"{}\"", word));
} else {
escaped_command.push_str(word);
}
} else {
escaped_command.push_str(word);
}
}
escaped_command
} else {
// For non-file commands, use the original command
command.to_string()
}
}
pub mod providers { pub mod providers {
pub mod anthropic; pub mod anthropic;
pub mod embedded; pub mod embedded;