context saving
This commit is contained in:
@@ -3,6 +3,8 @@ use g3_config::Config;
|
|||||||
use g3_execution::CodeExecutor;
|
use g3_execution::CodeExecutor;
|
||||||
use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry};
|
use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::fs;
|
||||||
|
use std::path::Path;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
use tracing::{error, info, warn};
|
use tracing::{error, info, warn};
|
||||||
@@ -58,8 +60,6 @@ impl StreamingToolParser {
|
|||||||
let before_pos = &self.buffer[..pos];
|
let before_pos = &self.buffer[..pos];
|
||||||
let code_block_count = before_pos.matches("```").count();
|
let code_block_count = before_pos.matches("```").count();
|
||||||
|
|
||||||
//info!("Code block count before position {}: {}", pos, code_block_count);
|
|
||||||
|
|
||||||
// Accept tool calls both inside and outside code blocks
|
// Accept tool calls both inside and outside code blocks
|
||||||
// The LLM might use either format despite our instructions
|
// The LLM might use either format despite our instructions
|
||||||
//info!("Starting tool call parsing (code block status: {})", code_block_count % 2 == 1);
|
//info!("Starting tool call parsing (code block status: {})", code_block_count % 2 == 1);
|
||||||
@@ -102,10 +102,8 @@ impl StreamingToolParser {
|
|||||||
let end_pos = start_pos + i + 1;
|
let end_pos = start_pos + i + 1;
|
||||||
let json_str = &self.buffer[start_pos..end_pos];
|
let json_str = &self.buffer[start_pos..end_pos];
|
||||||
|
|
||||||
//info!("Complete JSON found: {}", json_str);
|
|
||||||
|
|
||||||
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(json_str) {
|
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(json_str) {
|
||||||
info!("Successfully parsed tool call: {:?}", tool_call);
|
//info!("Successfully parsed tool call: {:?}", tool_call);
|
||||||
// Reset parser state
|
// Reset parser state
|
||||||
self.in_tool_call = false;
|
self.in_tool_call = false;
|
||||||
self.tool_start_pos = None;
|
self.tool_start_pos = None;
|
||||||
@@ -391,7 +389,7 @@ The tool will execute immediately and you'll receive the result to continue with
|
|||||||
1. Break down tasks into small steps
|
1. Break down tasks into small steps
|
||||||
2. Execute ONE tool at a time
|
2. Execute ONE tool at a time
|
||||||
3. Wait for the result before proceeding
|
3. Wait for the result before proceeding
|
||||||
4. Use the actual file paths on the system (like ~/Downloads for Downloads folder)
|
4. Use the actual file paths on the system
|
||||||
5. End with final_output when done
|
5. End with final_output when done
|
||||||
|
|
||||||
Let's start with the first step of your task.
|
Let's start with the first step of your task.
|
||||||
@@ -430,12 +428,24 @@ Let's start with the first step of your task.
|
|||||||
|
|
||||||
// Time the LLM call with cancellation support and streaming
|
// Time the LLM call with cancellation support and streaming
|
||||||
let llm_start = Instant::now();
|
let llm_start = Instant::now();
|
||||||
let (response_content, think_time) = tokio::select! {
|
let result = tokio::select! {
|
||||||
result = self.stream_completion(request) => result?,
|
result = self.stream_completion(request) => result,
|
||||||
_ = cancellation_token.cancelled() => {
|
_ = cancellation_token.cancelled() => {
|
||||||
return Err(anyhow::anyhow!("Operation cancelled by user"));
|
// Save context window on cancellation
|
||||||
|
self.save_context_window("cancelled");
|
||||||
|
Err(anyhow::anyhow!("Operation cancelled by user"))
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let (response_content, think_time) = match result {
|
||||||
|
Ok(content) => content,
|
||||||
|
Err(e) => {
|
||||||
|
// Save context window on error
|
||||||
|
self.save_context_window("error");
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let llm_duration = llm_start.elapsed();
|
let llm_duration = llm_start.elapsed();
|
||||||
|
|
||||||
// Create a mock usage for now (we'll need to track this during streaming)
|
// Create a mock usage for now (we'll need to track this during streaming)
|
||||||
@@ -455,6 +465,9 @@ Let's start with the first step of your task.
|
|||||||
};
|
};
|
||||||
self.context_window.add_message(assistant_message);
|
self.context_window.add_message(assistant_message);
|
||||||
|
|
||||||
|
// Save context window at the end of successful interaction
|
||||||
|
self.save_context_window("completed");
|
||||||
|
|
||||||
// With streaming tool execution, we don't need separate code execution
|
// With streaming tool execution, we don't need separate code execution
|
||||||
// The tools are already executed during streaming
|
// The tools are already executed during streaming
|
||||||
if show_timing {
|
if show_timing {
|
||||||
@@ -469,6 +482,40 @@ Let's start with the first step of your task.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Save the entire context window to a file for debugging purposes
|
||||||
|
fn save_context_window(&self, status: &str) {
|
||||||
|
let timestamp = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs();
|
||||||
|
|
||||||
|
let filename = format!("g3_context_{}.json", timestamp);
|
||||||
|
|
||||||
|
let context_data = serde_json::json!({
|
||||||
|
"timestamp": timestamp,
|
||||||
|
"status": status,
|
||||||
|
"context_window": {
|
||||||
|
"used_tokens": self.context_window.used_tokens,
|
||||||
|
"total_tokens": self.context_window.total_tokens,
|
||||||
|
"percentage_used": self.context_window.percentage_used(),
|
||||||
|
"conversation_history": self.context_window.conversation_history
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
match serde_json::to_string_pretty(&context_data) {
|
||||||
|
Ok(json_content) => {
|
||||||
|
if let Err(e) = fs::write(&filename, json_content) {
|
||||||
|
error!("Failed to save context window to {}: {}", filename, e);
|
||||||
|
} else {
|
||||||
|
info!("Context window saved to {}", filename);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to serialize context window: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn get_context_window(&self) -> &ContextWindow {
|
pub fn get_context_window(&self) -> &ContextWindow {
|
||||||
&self.context_window
|
&self.context_window
|
||||||
}
|
}
|
||||||
@@ -511,7 +558,10 @@ Let's start with the first step of your task.
|
|||||||
Ok(s) => s,
|
Ok(s) => s,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
if iteration_count > 1 && e.to_string().contains("busy") {
|
if iteration_count > 1 && e.to_string().contains("busy") {
|
||||||
warn!("Model busy on iteration {}, retrying in 500ms", iteration_count);
|
warn!(
|
||||||
|
"Model busy on iteration {}, retrying in 500ms",
|
||||||
|
iteration_count
|
||||||
|
);
|
||||||
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
||||||
match provider.stream(request.clone()).await {
|
match provider.stream(request.clone()).await {
|
||||||
Ok(s) => s,
|
Ok(s) => s,
|
||||||
@@ -539,10 +589,6 @@ Let's start with the first step of your task.
|
|||||||
|
|
||||||
// Check for tool calls in the streaming content
|
// Check for tool calls in the streaming content
|
||||||
if let Some((tool_call, tool_end_pos)) = parser.add_chunk(&chunk.content) {
|
if let Some((tool_call, tool_end_pos)) = parser.add_chunk(&chunk.content) {
|
||||||
info!(
|
|
||||||
"🔧 Detected tool call: {:?} at position {}",
|
|
||||||
tool_call, tool_end_pos
|
|
||||||
);
|
|
||||||
// Found a complete tool call! Stop streaming and execute it
|
// Found a complete tool call! Stop streaming and execute it
|
||||||
let content_before_tool = parser.get_content_before_tool(tool_end_pos);
|
let content_before_tool = parser.get_content_before_tool(tool_end_pos);
|
||||||
|
|
||||||
@@ -559,22 +605,62 @@ Let's start with the first step of your task.
|
|||||||
let new_content = if current_response.len() <= display_content.len() {
|
let new_content = if current_response.len() <= display_content.len() {
|
||||||
// Use char indices to avoid UTF-8 boundary issues
|
// Use char indices to avoid UTF-8 boundary issues
|
||||||
let chars_already_shown = current_response.chars().count();
|
let chars_already_shown = current_response.chars().count();
|
||||||
display_content.chars().skip(chars_already_shown).collect::<String>()
|
display_content
|
||||||
|
.chars()
|
||||||
|
.skip(chars_already_shown)
|
||||||
|
.collect::<String>()
|
||||||
} else {
|
} else {
|
||||||
String::new()
|
String::new()
|
||||||
};
|
};
|
||||||
print!("{}", new_content);
|
print!("{}", new_content);
|
||||||
io::stdout().flush()?;
|
io::stdout().flush()?;
|
||||||
|
|
||||||
// Execute the tool
|
// Execute the tool with formatted output
|
||||||
println!(); // New line before tool execution
|
println!(); // New line before tool execution
|
||||||
|
|
||||||
|
// Tool call header
|
||||||
|
println!("┌─ {}", tool_call.tool);
|
||||||
|
if let Some(args_obj) = tool_call.args.as_object() {
|
||||||
|
for (key, value) in args_obj {
|
||||||
|
let value_str = match value {
|
||||||
|
serde_json::Value::String(s) => s.clone(),
|
||||||
|
_ => value.to_string(),
|
||||||
|
};
|
||||||
|
println!("│ {}", value_str);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
println!("├─ output:");
|
||||||
|
|
||||||
let exec_start = Instant::now();
|
let exec_start = Instant::now();
|
||||||
let tool_result = self.execute_tool(&tool_call).await?;
|
let tool_result = self.execute_tool(&tool_call).await?;
|
||||||
let exec_duration = exec_start.elapsed();
|
let exec_duration = exec_start.elapsed();
|
||||||
total_execution_time += exec_duration;
|
total_execution_time += exec_duration;
|
||||||
|
|
||||||
// Display tool execution result
|
// Display tool execution result with proper indentation
|
||||||
println!("🔧 {}: {}", tool_call.tool, tool_result);
|
let output_lines: Vec<&str> = tool_result.lines().collect();
|
||||||
|
const MAX_LINES: usize = 5;
|
||||||
|
|
||||||
|
if output_lines.len() <= MAX_LINES {
|
||||||
|
// Show all lines if within limit
|
||||||
|
for line in output_lines {
|
||||||
|
println!("│ {}", line);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Show first MAX_LINES and add truncation note
|
||||||
|
for line in output_lines.iter().take(MAX_LINES) {
|
||||||
|
println!("│ {}", line);
|
||||||
|
}
|
||||||
|
let hidden_count = output_lines.len() - MAX_LINES;
|
||||||
|
println!(
|
||||||
|
"│ ... ({} more line{} hidden)",
|
||||||
|
hidden_count,
|
||||||
|
if hidden_count == 1 { "" } else { "s" }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Closure marker with timing
|
||||||
|
println!("└─ ⚡️ {}", Self::format_duration(exec_duration));
|
||||||
|
println!();
|
||||||
print!("🤖 "); // Continue response indicator
|
print!("🤖 "); // Continue response indicator
|
||||||
io::stdout().flush()?;
|
io::stdout().flush()?;
|
||||||
|
|
||||||
@@ -593,7 +679,7 @@ Let's start with the first step of your task.
|
|||||||
content: format!("Tool result: {}", tool_result),
|
content: format!("Tool result: {}", tool_result),
|
||||||
};
|
};
|
||||||
|
|
||||||
request.messages.push(tool_message);
|
//request.messages.push(tool_message);
|
||||||
request.messages.push(result_message);
|
request.messages.push(result_message);
|
||||||
|
|
||||||
full_response.push_str(display_content);
|
full_response.push_str(display_content);
|
||||||
@@ -660,8 +746,11 @@ Let's start with the first step of your task.
|
|||||||
"shell" => {
|
"shell" => {
|
||||||
if let Some(command) = tool_call.args.get("command") {
|
if let Some(command) = tool_call.args.get("command") {
|
||||||
if let Some(command_str) = command.as_str() {
|
if let Some(command_str) = command.as_str() {
|
||||||
|
// Use shell escaping to handle filenames with spaces and special characters
|
||||||
|
let escaped_command = shell_escape_command(command_str);
|
||||||
|
|
||||||
let executor = CodeExecutor::new();
|
let executor = CodeExecutor::new();
|
||||||
match executor.execute_code("bash", command_str).await {
|
match executor.execute_code("bash", &escaped_command).await {
|
||||||
Ok(result) => {
|
Ok(result) => {
|
||||||
if result.success {
|
if result.success {
|
||||||
Ok(if result.stdout.is_empty() {
|
Ok(if result.stdout.is_empty() {
|
||||||
@@ -716,6 +805,90 @@ Let's start with the first step of your task.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Prepare a tool-supplied shell command for execution by `bash`.
///
/// The command is returned exactly as received.
///
/// `command` is already a complete piece of shell syntax, so any quoting a
/// path needs has to be supplied by whoever composed the command. Once the
/// string is flat there is no reliable way to tell a space that separates
/// arguments from a space that belongs to a file name.
///
/// An earlier version tried to quote path-like words for a fixed list of
/// file commands (`cat`, `ls`, `cp`, ...). It split the command on every
/// space first, so no word could ever contain a space and the quoting
/// branch was unreachable. Its only observable effect was to collapse runs
/// of spaces, which silently changed the meaning of valid commands:
///
/// * heredoc bodies lost their indentation
///   (`cat > f.py << EOF` followed by indented code),
/// * a backslash-escaped space followed by a separator (`a\  b`) was merged
///   into a single argument.
///
/// Passing the command through untouched removes both corruptions and keeps
/// every other input behaving as before.
fn shell_escape_command(command: &str) -> String {
    command.to_string()
}
|
||||||
|
|
||||||
pub mod providers {
|
pub mod providers {
|
||||||
pub mod anthropic;
|
pub mod anthropic;
|
||||||
pub mod embedded;
|
pub mod embedded;
|
||||||
|
|||||||
Reference in New Issue
Block a user