context saving
This commit is contained in:
@@ -3,6 +3,8 @@ use g3_config::Config;
|
||||
use g3_execution::CodeExecutor;
|
||||
use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, warn};
|
||||
@@ -58,8 +60,6 @@ impl StreamingToolParser {
|
||||
let before_pos = &self.buffer[..pos];
|
||||
let code_block_count = before_pos.matches("```").count();
|
||||
|
||||
//info!("Code block count before position {}: {}", pos, code_block_count);
|
||||
|
||||
// Accept tool calls both inside and outside code blocks
|
||||
// The LLM might use either format despite our instructions
|
||||
//info!("Starting tool call parsing (code block status: {})", code_block_count % 2 == 1);
|
||||
@@ -102,10 +102,8 @@ impl StreamingToolParser {
|
||||
let end_pos = start_pos + i + 1;
|
||||
let json_str = &self.buffer[start_pos..end_pos];
|
||||
|
||||
//info!("Complete JSON found: {}", json_str);
|
||||
|
||||
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(json_str) {
|
||||
info!("Successfully parsed tool call: {:?}", tool_call);
|
||||
//info!("Successfully parsed tool call: {:?}", tool_call);
|
||||
// Reset parser state
|
||||
self.in_tool_call = false;
|
||||
self.tool_start_pos = None;
|
||||
@@ -391,7 +389,7 @@ The tool will execute immediately and you'll receive the result to continue with
|
||||
1. Break down tasks into small steps
|
||||
2. Execute ONE tool at a time
|
||||
3. Wait for the result before proceeding
|
||||
4. Use the actual file paths on the system (like ~/Downloads for Downloads folder)
|
||||
4. Use the actual file paths on the system
|
||||
5. End with final_output when done
|
||||
|
||||
Let's start with the first step of your task.
|
||||
@@ -430,12 +428,24 @@ Let's start with the first step of your task.
|
||||
|
||||
// Time the LLM call with cancellation support and streaming
|
||||
let llm_start = Instant::now();
|
||||
let (response_content, think_time) = tokio::select! {
|
||||
result = self.stream_completion(request) => result?,
|
||||
let result = tokio::select! {
|
||||
result = self.stream_completion(request) => result,
|
||||
_ = cancellation_token.cancelled() => {
|
||||
return Err(anyhow::anyhow!("Operation cancelled by user"));
|
||||
// Save context window on cancellation
|
||||
self.save_context_window("cancelled");
|
||||
Err(anyhow::anyhow!("Operation cancelled by user"))
|
||||
}
|
||||
};
|
||||
|
||||
let (response_content, think_time) = match result {
|
||||
Ok(content) => content,
|
||||
Err(e) => {
|
||||
// Save context window on error
|
||||
self.save_context_window("error");
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
let llm_duration = llm_start.elapsed();
|
||||
|
||||
// Create a mock usage for now (we'll need to track this during streaming)
|
||||
@@ -455,6 +465,9 @@ Let's start with the first step of your task.
|
||||
};
|
||||
self.context_window.add_message(assistant_message);
|
||||
|
||||
// Save context window at the end of successful interaction
|
||||
self.save_context_window("completed");
|
||||
|
||||
// With streaming tool execution, we don't need separate code execution
|
||||
// The tools are already executed during streaming
|
||||
if show_timing {
|
||||
@@ -469,6 +482,40 @@ Let's start with the first step of your task.
|
||||
}
|
||||
}
|
||||
|
||||
/// Save the entire context window to a file for debugging purposes
|
||||
fn save_context_window(&self, status: &str) {
|
||||
let timestamp = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
let filename = format!("g3_context_{}.json", timestamp);
|
||||
|
||||
let context_data = serde_json::json!({
|
||||
"timestamp": timestamp,
|
||||
"status": status,
|
||||
"context_window": {
|
||||
"used_tokens": self.context_window.used_tokens,
|
||||
"total_tokens": self.context_window.total_tokens,
|
||||
"percentage_used": self.context_window.percentage_used(),
|
||||
"conversation_history": self.context_window.conversation_history
|
||||
}
|
||||
});
|
||||
|
||||
match serde_json::to_string_pretty(&context_data) {
|
||||
Ok(json_content) => {
|
||||
if let Err(e) = fs::write(&filename, json_content) {
|
||||
error!("Failed to save context window to {}: {}", filename, e);
|
||||
} else {
|
||||
info!("Context window saved to {}", filename);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to serialize context window: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_context_window(&self) -> &ContextWindow {
|
||||
&self.context_window
|
||||
}
|
||||
@@ -511,7 +558,10 @@ Let's start with the first step of your task.
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
if iteration_count > 1 && e.to_string().contains("busy") {
|
||||
warn!("Model busy on iteration {}, retrying in 500ms", iteration_count);
|
||||
warn!(
|
||||
"Model busy on iteration {}, retrying in 500ms",
|
||||
iteration_count
|
||||
);
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
||||
match provider.stream(request.clone()).await {
|
||||
Ok(s) => s,
|
||||
@@ -539,10 +589,6 @@ Let's start with the first step of your task.
|
||||
|
||||
// Check for tool calls in the streaming content
|
||||
if let Some((tool_call, tool_end_pos)) = parser.add_chunk(&chunk.content) {
|
||||
info!(
|
||||
"🔧 Detected tool call: {:?} at position {}",
|
||||
tool_call, tool_end_pos
|
||||
);
|
||||
// Found a complete tool call! Stop streaming and execute it
|
||||
let content_before_tool = parser.get_content_before_tool(tool_end_pos);
|
||||
|
||||
@@ -559,22 +605,62 @@ Let's start with the first step of your task.
|
||||
let new_content = if current_response.len() <= display_content.len() {
|
||||
// Use char indices to avoid UTF-8 boundary issues
|
||||
let chars_already_shown = current_response.chars().count();
|
||||
display_content.chars().skip(chars_already_shown).collect::<String>()
|
||||
display_content
|
||||
.chars()
|
||||
.skip(chars_already_shown)
|
||||
.collect::<String>()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
print!("{}", new_content);
|
||||
io::stdout().flush()?;
|
||||
|
||||
// Execute the tool
|
||||
// Execute the tool with formatted output
|
||||
println!(); // New line before tool execution
|
||||
|
||||
// Tool call header
|
||||
println!("┌─ {}", tool_call.tool);
|
||||
if let Some(args_obj) = tool_call.args.as_object() {
|
||||
for (key, value) in args_obj {
|
||||
let value_str = match value {
|
||||
serde_json::Value::String(s) => s.clone(),
|
||||
_ => value.to_string(),
|
||||
};
|
||||
println!("│ {}", value_str);
|
||||
}
|
||||
}
|
||||
println!("├─ output:");
|
||||
|
||||
let exec_start = Instant::now();
|
||||
let tool_result = self.execute_tool(&tool_call).await?;
|
||||
let exec_duration = exec_start.elapsed();
|
||||
total_execution_time += exec_duration;
|
||||
|
||||
// Display tool execution result
|
||||
println!("🔧 {}: {}", tool_call.tool, tool_result);
|
||||
// Display tool execution result with proper indentation
|
||||
let output_lines: Vec<&str> = tool_result.lines().collect();
|
||||
const MAX_LINES: usize = 5;
|
||||
|
||||
if output_lines.len() <= MAX_LINES {
|
||||
// Show all lines if within limit
|
||||
for line in output_lines {
|
||||
println!("│ {}", line);
|
||||
}
|
||||
} else {
|
||||
// Show first MAX_LINES and add truncation note
|
||||
for line in output_lines.iter().take(MAX_LINES) {
|
||||
println!("│ {}", line);
|
||||
}
|
||||
let hidden_count = output_lines.len() - MAX_LINES;
|
||||
println!(
|
||||
"│ ... ({} more line{} hidden)",
|
||||
hidden_count,
|
||||
if hidden_count == 1 { "" } else { "s" }
|
||||
);
|
||||
}
|
||||
|
||||
// Closure marker with timing
|
||||
println!("└─ ⚡️ {}", Self::format_duration(exec_duration));
|
||||
println!();
|
||||
print!("🤖 "); // Continue response indicator
|
||||
io::stdout().flush()?;
|
||||
|
||||
@@ -593,7 +679,7 @@ Let's start with the first step of your task.
|
||||
content: format!("Tool result: {}", tool_result),
|
||||
};
|
||||
|
||||
request.messages.push(tool_message);
|
||||
//request.messages.push(tool_message);
|
||||
request.messages.push(result_message);
|
||||
|
||||
full_response.push_str(display_content);
|
||||
@@ -660,8 +746,11 @@ Let's start with the first step of your task.
|
||||
"shell" => {
|
||||
if let Some(command) = tool_call.args.get("command") {
|
||||
if let Some(command_str) = command.as_str() {
|
||||
// Use shell escaping to handle filenames with spaces and special characters
|
||||
let escaped_command = shell_escape_command(command_str);
|
||||
|
||||
let executor = CodeExecutor::new();
|
||||
match executor.execute_code("bash", command_str).await {
|
||||
match executor.execute_code("bash", &escaped_command).await {
|
||||
Ok(result) => {
|
||||
if result.success {
|
||||
Ok(if result.stdout.is_empty() {
|
||||
@@ -716,6 +805,90 @@ Let's start with the first step of your task.
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to properly escape shell commands
|
||||
fn shell_escape_command(command: &str) -> String {
|
||||
// Simple approach: if the command contains file paths with spaces,
|
||||
// we need to be more intelligent about escaping
|
||||
|
||||
// For now, let's use a basic approach that handles common cases
|
||||
// This is a simplified version - a full implementation would use proper shell parsing
|
||||
|
||||
let parts: Vec<&str> = command.split_whitespace().collect();
|
||||
if parts.is_empty() {
|
||||
return command.to_string();
|
||||
}
|
||||
|
||||
let cmd = parts[0];
|
||||
let args = &parts[1..];
|
||||
|
||||
// Commands that typically take file paths as arguments
|
||||
let file_commands = [
|
||||
"cat", "ls", "cp", "mv", "rm", "chmod", "chown", "file", "head", "tail", "wc", "grep",
|
||||
];
|
||||
|
||||
if file_commands.contains(&cmd) {
|
||||
// For file commands, we need to be smarter about escaping
|
||||
// Let's use a different approach: use the original command but wrap it in quotes if needed
|
||||
|
||||
// Check if the command already has proper quoting
|
||||
if command.contains('"') || command.contains('\'') {
|
||||
// Already has some quoting, use as-is
|
||||
return command.to_string();
|
||||
}
|
||||
|
||||
// Look for file paths that need escaping (contain spaces but aren't quoted)
|
||||
let mut escaped_command = String::new();
|
||||
let mut in_quotes = false;
|
||||
let mut current_word = String::new();
|
||||
let mut words = Vec::new();
|
||||
|
||||
for ch in command.chars() {
|
||||
match ch {
|
||||
' ' if !in_quotes => {
|
||||
if !current_word.is_empty() {
|
||||
words.push(current_word.clone());
|
||||
current_word.clear();
|
||||
}
|
||||
}
|
||||
'"' => {
|
||||
in_quotes = !in_quotes;
|
||||
current_word.push(ch);
|
||||
}
|
||||
_ => {
|
||||
current_word.push(ch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !current_word.is_empty() {
|
||||
words.push(current_word);
|
||||
}
|
||||
|
||||
// Reconstruct the command with proper escaping
|
||||
for (i, word) in words.iter().enumerate() {
|
||||
if i > 0 {
|
||||
escaped_command.push(' ');
|
||||
}
|
||||
|
||||
// If this word looks like a file path (contains / or ~) and has spaces, quote it
|
||||
if word.contains('/') || word.starts_with('~') {
|
||||
if word.contains(' ') && !word.starts_with('"') && !word.starts_with('\'') {
|
||||
escaped_command.push_str(&format!("\"{}\"", word));
|
||||
} else {
|
||||
escaped_command.push_str(word);
|
||||
}
|
||||
} else {
|
||||
escaped_command.push_str(word);
|
||||
}
|
||||
}
|
||||
|
||||
escaped_command
|
||||
} else {
|
||||
// For non-file commands, use the original command
|
||||
command.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub mod providers {
|
||||
pub mod anthropic;
|
||||
pub mod embedded;
|
||||
|
||||
Reference in New Issue
Block a user