context saving

This commit is contained in:
Dhanji Prasanna
2025-09-09 11:39:43 +10:00
parent 0c92a7c6b4
commit 02d95e01a0

View File

@@ -3,6 +3,8 @@ use g3_config::Config;
use g3_execution::CodeExecutor;
use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;
use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn};
@@ -58,8 +60,6 @@ impl StreamingToolParser {
let before_pos = &self.buffer[..pos];
let code_block_count = before_pos.matches("```").count();
//info!("Code block count before position {}: {}", pos, code_block_count);
// Accept tool calls both inside and outside code blocks
// The LLM might use either format despite our instructions
//info!("Starting tool call parsing (code block status: {})", code_block_count % 2 == 1);
@@ -102,10 +102,8 @@ impl StreamingToolParser {
let end_pos = start_pos + i + 1;
let json_str = &self.buffer[start_pos..end_pos];
//info!("Complete JSON found: {}", json_str);
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(json_str) {
info!("Successfully parsed tool call: {:?}", tool_call);
//info!("Successfully parsed tool call: {:?}", tool_call);
// Reset parser state
self.in_tool_call = false;
self.tool_start_pos = None;
@@ -391,7 +389,7 @@ The tool will execute immediately and you'll receive the result to continue with
1. Break down tasks into small steps
2. Execute ONE tool at a time
3. Wait for the result before proceeding
4. Use the actual file paths on the system (like ~/Downloads for Downloads folder)
4. Use the actual file paths on the system
5. End with final_output when done
Let's start with the first step of your task.
@@ -430,12 +428,24 @@ Let's start with the first step of your task.
// Time the LLM call with cancellation support and streaming
let llm_start = Instant::now();
let (response_content, think_time) = tokio::select! {
result = self.stream_completion(request) => result?,
let result = tokio::select! {
result = self.stream_completion(request) => result,
_ = cancellation_token.cancelled() => {
return Err(anyhow::anyhow!("Operation cancelled by user"));
// Save context window on cancellation
self.save_context_window("cancelled");
Err(anyhow::anyhow!("Operation cancelled by user"))
}
};
let (response_content, think_time) = match result {
Ok(content) => content,
Err(e) => {
// Save context window on error
self.save_context_window("error");
return Err(e);
}
};
let llm_duration = llm_start.elapsed();
// Create a mock usage for now (we'll need to track this during streaming)
@@ -455,6 +465,9 @@ Let's start with the first step of your task.
};
self.context_window.add_message(assistant_message);
// Save context window at the end of successful interaction
self.save_context_window("completed");
// With streaming tool execution, we don't need separate code execution
// The tools are already executed during streaming
if show_timing {
@@ -469,6 +482,40 @@ Let's start with the first step of your task.
}
}
/// Save the entire context window to a file for debugging purposes
fn save_context_window(&self, status: &str) {
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let filename = format!("g3_context_{}.json", timestamp);
let context_data = serde_json::json!({
"timestamp": timestamp,
"status": status,
"context_window": {
"used_tokens": self.context_window.used_tokens,
"total_tokens": self.context_window.total_tokens,
"percentage_used": self.context_window.percentage_used(),
"conversation_history": self.context_window.conversation_history
}
});
match serde_json::to_string_pretty(&context_data) {
Ok(json_content) => {
if let Err(e) = fs::write(&filename, json_content) {
error!("Failed to save context window to {}: {}", filename, e);
} else {
info!("Context window saved to {}", filename);
}
}
Err(e) => {
error!("Failed to serialize context window: {}", e);
}
}
}
pub fn get_context_window(&self) -> &ContextWindow {
&self.context_window
}
@@ -511,7 +558,10 @@ Let's start with the first step of your task.
Ok(s) => s,
Err(e) => {
if iteration_count > 1 && e.to_string().contains("busy") {
warn!("Model busy on iteration {}, retrying in 500ms", iteration_count);
warn!(
"Model busy on iteration {}, retrying in 500ms",
iteration_count
);
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
match provider.stream(request.clone()).await {
Ok(s) => s,
@@ -539,10 +589,6 @@ Let's start with the first step of your task.
// Check for tool calls in the streaming content
if let Some((tool_call, tool_end_pos)) = parser.add_chunk(&chunk.content) {
info!(
"🔧 Detected tool call: {:?} at position {}",
tool_call, tool_end_pos
);
// Found a complete tool call! Stop streaming and execute it
let content_before_tool = parser.get_content_before_tool(tool_end_pos);
@@ -559,22 +605,62 @@ Let's start with the first step of your task.
let new_content = if current_response.len() <= display_content.len() {
// Use char indices to avoid UTF-8 boundary issues
let chars_already_shown = current_response.chars().count();
display_content.chars().skip(chars_already_shown).collect::<String>()
display_content
.chars()
.skip(chars_already_shown)
.collect::<String>()
} else {
String::new()
};
print!("{}", new_content);
io::stdout().flush()?;
// Execute the tool
// Execute the tool with formatted output
println!(); // New line before tool execution
// Tool call header
println!("┌─ {}", tool_call.tool);
if let Some(args_obj) = tool_call.args.as_object() {
for (key, value) in args_obj {
let value_str = match value {
serde_json::Value::String(s) => s.clone(),
_ => value.to_string(),
};
println!("{}", value_str);
}
}
println!("├─ output:");
let exec_start = Instant::now();
let tool_result = self.execute_tool(&tool_call).await?;
let exec_duration = exec_start.elapsed();
total_execution_time += exec_duration;
// Display tool execution result
println!("🔧 {}: {}", tool_call.tool, tool_result);
// Display tool execution result with proper indentation
let output_lines: Vec<&str> = tool_result.lines().collect();
const MAX_LINES: usize = 5;
if output_lines.len() <= MAX_LINES {
// Show all lines if within limit
for line in output_lines {
println!("{}", line);
}
} else {
// Show first MAX_LINES and add truncation note
for line in output_lines.iter().take(MAX_LINES) {
println!("{}", line);
}
let hidden_count = output_lines.len() - MAX_LINES;
println!(
"│ ... ({} more line{} hidden)",
hidden_count,
if hidden_count == 1 { "" } else { "s" }
);
}
// Closure marker with timing
println!("└─ ⚡️ {}", Self::format_duration(exec_duration));
println!();
print!("🤖 "); // Continue response indicator
io::stdout().flush()?;
@@ -593,7 +679,7 @@ Let's start with the first step of your task.
content: format!("Tool result: {}", tool_result),
};
request.messages.push(tool_message);
//request.messages.push(tool_message);
request.messages.push(result_message);
full_response.push_str(display_content);
@@ -660,8 +746,11 @@ Let's start with the first step of your task.
"shell" => {
if let Some(command) = tool_call.args.get("command") {
if let Some(command_str) = command.as_str() {
// Use shell escaping to handle filenames with spaces and special characters
let escaped_command = shell_escape_command(command_str);
let executor = CodeExecutor::new();
match executor.execute_code("bash", command_str).await {
match executor.execute_code("bash", &escaped_command).await {
Ok(result) => {
if result.success {
Ok(if result.stdout.is_empty() {
@@ -716,6 +805,90 @@ Let's start with the first step of your task.
}
}
// Helper function to properly escape shell commands
fn shell_escape_command(command: &str) -> String {
// Simple approach: if the command contains file paths with spaces,
// we need to be more intelligent about escaping
// For now, let's use a basic approach that handles common cases
// This is a simplified version - a full implementation would use proper shell parsing
let parts: Vec<&str> = command.split_whitespace().collect();
if parts.is_empty() {
return command.to_string();
}
let cmd = parts[0];
let args = &parts[1..];
// Commands that typically take file paths as arguments
let file_commands = [
"cat", "ls", "cp", "mv", "rm", "chmod", "chown", "file", "head", "tail", "wc", "grep",
];
if file_commands.contains(&cmd) {
// For file commands, we need to be smarter about escaping
// Let's use a different approach: use the original command but wrap it in quotes if needed
// Check if the command already has proper quoting
if command.contains('"') || command.contains('\'') {
// Already has some quoting, use as-is
return command.to_string();
}
// Look for file paths that need escaping (contain spaces but aren't quoted)
let mut escaped_command = String::new();
let mut in_quotes = false;
let mut current_word = String::new();
let mut words = Vec::new();
for ch in command.chars() {
match ch {
' ' if !in_quotes => {
if !current_word.is_empty() {
words.push(current_word.clone());
current_word.clear();
}
}
'"' => {
in_quotes = !in_quotes;
current_word.push(ch);
}
_ => {
current_word.push(ch);
}
}
}
if !current_word.is_empty() {
words.push(current_word);
}
// Reconstruct the command with proper escaping
for (i, word) in words.iter().enumerate() {
if i > 0 {
escaped_command.push(' ');
}
// If this word looks like a file path (contains / or ~) and has spaces, quote it
if word.contains('/') || word.starts_with('~') {
if word.contains(' ') && !word.starts_with('"') && !word.starts_with('\'') {
escaped_command.push_str(&format!("\"{}\"", word));
} else {
escaped_command.push_str(word);
}
} else {
escaped_command.push_str(word);
}
}
escaped_command
} else {
// For non-file commands, use the original command
command.to_string()
}
}
pub mod providers {
pub mod anthropic;
pub mod embedded;