// g3/crates/g3-core/src/lib.rs

use anyhow::Result;
use g3_config::Config;
use g3_execution::CodeExecutor;
use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;
use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
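
/// A single tool invocation emitted by the model, e.g.
/// `{"tool": "shell", "args": {"command": "ls"}}`.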
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    pub tool: String,
    pub args: serde_json::Value,
}
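
/// Lifecycle states for the streaming loop (not yet referenced in this
/// module).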
#[derive(Debug, Clone)]
pub enum StreamState {
    Generating,
    ToolDetected(ToolCall),
    Executing,
    Resuming,
}
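
/// Incrementally scans streamed model output for a complete JSON tool
/// call by tracking brace depth from the first `{"tool":`-style marker.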
#[derive(Debug)]
pub struct StreamingToolParser {
    buffer: String,
    brace_count: i32,
    in_tool_call: bool,
    tool_start_pos: Option<usize>,
}

impl StreamingToolParser {
    pub fn new() -> Self {
        Self {
            buffer: String::new(),
            brace_count: 0,
            in_tool_call: false,
            tool_start_pos: None,
        }
    }

    pub fn add_chunk(&mut self, chunk: &str) -> Option<(ToolCall, usize)> {
        self.buffer.push_str(chunk);
        self.detect_tool_call()
    }

    fn detect_tool_call(&mut self) -> Option<(ToolCall, usize)> {
        if !self.in_tool_call {
            // Look for the start of a tool call. The model may emit the JSON
            // raw or inside a code block, and occasionally in malformed
            // variants with doubled quotes or an extra brace.
            let patterns = [
                r#"{"tool":"#,      // normal pattern
                r#"{"{""tool"":"#,  // malformed: extra brace and doubled quotes
                r#"{{""tool"":"#,   // alternative malformed pattern
            ];
            for pattern in &patterns {
                if let Some(pos) = self.buffer.rfind(pattern) {
                    info!("Found tool call pattern '{}' at position: {}", pattern, pos);
                    // Tool calls are accepted both inside and outside code
                    // blocks; the LLM might use either format despite our
                    // instructions.
                    self.in_tool_call = true;
                    self.tool_start_pos = Some(pos);
                    // Start counting from 0; the opening brace is counted
                    // during parsing.
                    self.brace_count = 0;
                    return self.parse_from_start_pos(pos);
                }
            }
        } else {
            // Already inside a tool call: keep parsing from where it started.
            let start_pos = self.tool_start_pos.unwrap();
            return self.parse_from_start_pos(start_pos);
        }
        None
    }

    fn parse_from_start_pos(&mut self, start_pos: usize) -> Option<(ToolCall, usize)> {
        let remaining = self.buffer[start_pos..].to_string();
        self.parse_from_position(&remaining, start_pos)
    }

    fn parse_from_position(&mut self, text: &str, start_pos: usize) -> Option<(ToolCall, usize)> {
        // Always start fresh for each parsing attempt.
        let mut current_brace_count = 0;
        for (i, ch) in text.char_indices() {
            match ch {
                '{' => current_brace_count += 1,
                '}' => {
                    current_brace_count -= 1;
                    if current_brace_count == 0 {
                        // Found a complete JSON object.
                        let end_pos = start_pos + i + 1;
                        let mut json_str = self.buffer[start_pos..end_pos].to_string();
                        // Clean up malformed JSON patterns.
                        json_str = json_str
                            .replace(r#"{"{""#, r#"{"#)  // fix {"{" -> {"
                            .replace(r#"""}"#, r#""}"#)  // fix ""} -> "}
                            .replace(r#"{{""#, r#"{"#);  // fix {{" -> {"
                        if let Ok(tool_call) = serde_json::from_str::<ToolCall>(&json_str) {
                            info!("Successfully parsed tool call: {:?}", tool_call);
                            // Reset parser state.
                            self.in_tool_call = false;
                            self.tool_start_pos = None;
                            self.brace_count = 0;
                            return Some((tool_call, end_pos));
                        } else {
                            info!("Failed to parse JSON after cleanup: {}", json_str);
                            // Invalid JSON: reset and continue looking.
                            self.in_tool_call = false;
                            self.tool_start_pos = None;
                            self.brace_count = 0;
                        }
                    }
                }
                _ => {}
            }
        }
        // Remember the running brace count for the next chunk.
        self.brace_count = current_brace_count;
        None
    }

    pub fn get_content_before_tool(&self, tool_end_pos: usize) -> String {
        if tool_end_pos <= self.buffer.len() {
            self.buffer[..tool_end_pos].to_string()
        } else {
            self.buffer.clone()
        }
    }

    pub fn get_remaining_content(&self, from_pos: usize) -> String {
        if from_pos < self.buffer.len() {
            self.buffer[from_pos..].to_string()
        } else {
            String::new()
        }
    }
}
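
// An illustrative sketch of how the parser is driven: chunks may split a
// tool call arbitrarily, and the parser reports the call once the braces
// balance.
#[cfg(test)]
mod streaming_parser_tests {
    use super::*;

    #[test]
    fn detects_tool_call_split_across_chunks() {
        let mut parser = StreamingToolParser::new();
        // The first chunk opens the call but leaves the JSON incomplete.
        assert!(parser.add_chunk(r#"{"tool": "shell", "ar"#).is_none());
        // The second chunk completes it.
        let (call, end_pos) = parser
            .add_chunk(r#"gs": {"command": "ls"}}"#)
            .expect("complete tool call should be detected");
        assert_eq!(call.tool, "shell");
        assert_eq!(call.args["command"], "ls");
        // The call ended at the end of the buffer, so nothing remains.
        assert_eq!(parser.get_remaining_content(end_pos), "");
    }
}

/// Tracks conversation history and (estimated) token usage against the
/// active model's context length.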
#[derive(Debug, Clone)]
pub struct ContextWindow {
    pub used_tokens: u32,
    pub total_tokens: u32,
    pub conversation_history: Vec<Message>,
}

impl ContextWindow {
    pub fn new(total_tokens: u32) -> Self {
        Self {
            used_tokens: 0,
            total_tokens,
            conversation_history: Vec::new(),
        }
    }

    pub fn add_message(&mut self, message: Message) {
        // Simple token estimation: ~4 characters per token.
        let estimated_tokens = (message.content.len() as f32 / 4.0).ceil() as u32;
        self.used_tokens += estimated_tokens;
        self.conversation_history.push(message);
    }

    pub fn update_usage(&mut self, usage: &g3_providers::Usage) {
        // Replace the estimate with actual token usage from the provider.
        self.used_tokens = usage.total_tokens;
    }

    pub fn percentage_used(&self) -> f32 {
        if self.total_tokens == 0 {
            0.0
        } else {
            (self.used_tokens as f32 / self.total_tokens as f32) * 100.0
        }
    }

    pub fn remaining_tokens(&self) -> u32 {
        self.total_tokens.saturating_sub(self.used_tokens)
    }
}
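
// An illustrative check of the ~4-characters-per-token estimate used by
// `add_message` and the derived percentage/remaining figures.
#[cfg(test)]
mod context_window_tests {
    use super::*;

    #[test]
    fn tracks_estimated_usage() {
        let mut window = ContextWindow::new(1_000);
        window.add_message(Message {
            role: MessageRole::User,
            content: "x".repeat(400), // 400 chars ≈ 100 tokens
        });
        assert_eq!(window.used_tokens, 100);
        assert_eq!(window.remaining_tokens(), 900);
        assert!((window.percentage_used() - 10.0).abs() < 0.01);
    }
}

/// The core agent: routes requests through the configured provider
/// registry and maintains the conversation's context window.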
pub struct Agent {
    providers: ProviderRegistry,
    config: Config,
    context_window: ContextWindow,
}

impl Agent {
    pub async fn new(config: Config) -> Result<Self> {
        let mut providers = ProviderRegistry::new();

        // Register providers based on configuration.
        if let Some(openai_config) = &config.providers.openai {
            let openai_provider = crate::providers::openai::OpenAIProvider::new(
                openai_config.api_key.clone(),
                openai_config.model.clone(),
                openai_config.base_url.clone(),
            )?;
            providers.register(openai_provider);
        }
        if let Some(anthropic_config) = &config.providers.anthropic {
            let anthropic_provider = crate::providers::anthropic::AnthropicProvider::new(
                anthropic_config.api_key.clone(),
                anthropic_config.model.clone(),
            )?;
            providers.register(anthropic_provider);
        }
        if let Some(embedded_config) = &config.providers.embedded {
            let embedded_provider = crate::providers::embedded::EmbeddedProvider::new(
                embedded_config.model_path.clone(),
                embedded_config.model_type.clone(),
                embedded_config.context_length,
                embedded_config.max_tokens,
                embedded_config.temperature,
                embedded_config.gpu_layers,
                embedded_config.threads,
            )?;
            providers.register(embedded_provider);
        }

        // Set the default provider.
        debug!("Setting default provider to: {}", config.providers.default_provider);
        providers.set_default(&config.providers.default_provider)?;
        debug!("Default provider set successfully");

        // Determine the context window size based on the active provider.
        let context_length = Self::determine_context_length(&config, &providers)?;
        let context_window = ContextWindow::new(context_length);

        Ok(Self {
            providers,
            config,
            context_window,
        })
    }

    fn determine_context_length(config: &Config, providers: &ProviderRegistry) -> Result<u32> {
        // Get the active provider to determine context length.
        let provider = providers.get(None)?;
        let provider_name = provider.name();
        let model_name = provider.model();

        // Use a provider-specific context length if available; otherwise fall
        // back to the agent config.
        let context_length = match provider_name {
            "embedded" => {
                // For embedded models, use the configured context_length or
                // model-specific defaults.
                if let Some(embedded_config) = &config.providers.embedded {
                    embedded_config.context_length.unwrap_or_else(|| {
                        match embedded_config.model_type.to_lowercase().as_str() {
                            "codellama" => 16384, // CodeLlama supports 16k context
                            "llama" => 4096,      // base Llama models
                            "mistral" => 8192,    // Mistral models
                            "qwen" => 32768,      // Qwen2.5 supports 32k context
                            _ => 4096,            // conservative default
                        }
                    })
                } else {
                    config.agent.max_context_length as u32
                }
            }
            "openai" => {
                // OpenAI model-specific context lengths.
                match model_name {
                    m if m.contains("gpt-4") => 128000,  // GPT-4 models have 128k context
                    m if m.contains("gpt-3.5") => 16384, // GPT-3.5-turbo has 16k context
                    _ => 4096,                           // conservative default
                }
            }
            "anthropic" => {
                // Anthropic model-specific context lengths.
                match model_name {
                    m if m.contains("claude-3") => 200000, // Claude-3 has 200k context
                    m if m.contains("claude-2") => 100000, // Claude-2 has 100k context
                    _ => 100000,                           // conservative default for Claude
                }
            }
            _ => config.agent.max_context_length as u32,
        };

        info!(
            "Using context length: {} tokens for provider: {} (model: {})",
            context_length, provider_name, model_name
        );

        Ok(context_length)
    }

    pub fn get_provider_info(&self) -> Result<(String, String)> {
        let provider = self.providers.get(None)?;
        Ok((provider.name().to_string(), provider.model().to_string()))
    }

    pub async fn execute_task(
        &mut self,
        description: &str,
        language: Option<&str>,
        _auto_execute: bool,
    ) -> Result<String> {
        self.execute_task_with_options(description, language, _auto_execute, false, false)
            .await
    }

    pub async fn execute_task_with_options(
        &mut self,
        description: &str,
        language: Option<&str>,
        _auto_execute: bool,
        show_prompt: bool,
        show_code: bool,
    ) -> Result<String> {
        self.execute_task_with_timing(
            description,
            language,
            _auto_execute,
            show_prompt,
            show_code,
            false,
        )
        .await
    }

    pub async fn execute_task_with_timing(
        &mut self,
        description: &str,
        language: Option<&str>,
        _auto_execute: bool,
        show_prompt: bool,
        show_code: bool,
        show_timing: bool,
    ) -> Result<String> {
        // Create a cancellation token that never cancels, for backward
        // compatibility.
        let cancellation_token = CancellationToken::new();
        self.execute_task_with_timing_cancellable(
            description,
            language,
            _auto_execute,
            show_prompt,
            show_code,
            show_timing,
            cancellation_token,
        )
        .await
    }

    pub async fn execute_task_with_timing_cancellable(
        &mut self,
        description: &str,
        _language: Option<&str>,
        _auto_execute: bool,
        show_prompt: bool,
        _show_code: bool,
        show_timing: bool,
        cancellation_token: CancellationToken,
    ) -> Result<String> {
        info!("Executing task: {}", description);
        let _provider = self.providers.get(None)?;

        // Only add the system message on the first interaction (empty
        // conversation history).
        if self.context_window.conversation_history.is_empty() {
            let system_prompt = format!(
                "You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems by writing code.

# Tool Call Format
When you need to execute a tool, write ONLY the JSON tool call on a new line:
{{\"tool\": \"tool_name\", \"args\": {{\"param\": \"value\"}}}}
The tool will execute immediately and you'll receive the result to continue with.

# Available Tools
- **shell**: Execute shell commands
  - Format: {{\"tool\": \"shell\", \"args\": {{\"command\": \"your_command_here\"}}}}
  - Example: {{\"tool\": \"shell\", \"args\": {{\"command\": \"ls ~/Downloads\"}}}}
- **final_output**: Signal task completion with a summary of work done in markdown format
  - Format: {{\"tool\": \"final_output\", \"args\": {{\"summary\": \"what_was_accomplished\"}}}}

# Instructions
1. Analyze the request and break it down into smaller tasks if appropriate
2. Execute ONE tool at a time
3. STOP when the original request has been satisfied
4. End with final_output when done

# Response Guidelines
- Use Markdown formatting for all responses except tool calls.
- When calling tools, use the pronoun 'I'.
"
            );

            if show_prompt {
                println!("🔍 System Prompt:");
                println!("================");
                println!("{}", system_prompt);
                println!("================");
                println!();
            }

            // Add the system message to the context window.
            let system_message = Message {
                role: MessageRole::System,
                content: system_prompt,
            };
            self.context_window.add_message(system_message);
        }

        // Add the user message to the context window.
        let user_message = Message {
            role: MessageRole::User,
            content: format!("Task: {}", description),
        };
        self.context_window.add_message(user_message);

        // Use the complete conversation history for the request.
        let messages = self.context_window.conversation_history.clone();
        let request = CompletionRequest {
            messages,
            max_tokens: Some(2048),
            temperature: Some(0.1),
            stream: true, // enable streaming
        };

        // Time the LLM call, with cancellation support and streaming.
        let llm_start = Instant::now();
        let result = tokio::select! {
            result = self.stream_completion(request) => result,
            _ = cancellation_token.cancelled() => {
                // Save the context window on cancellation.
                self.save_context_window("cancelled");
                Err(anyhow::anyhow!("Operation cancelled by user"))
            }
        };
        let (response_content, think_time) = match result {
            Ok(content) => content,
            Err(e) => {
                // Save the context window on error.
                self.save_context_window("error");
                return Err(e);
            }
        };
        let llm_duration = llm_start.elapsed();

        // Create a mock usage for now (actual usage should be tracked during
        // streaming).
        let mock_usage = g3_providers::Usage {
            prompt_tokens: 100, // estimate
            completion_tokens: response_content.len() as u32 / 4, // rough estimate
            total_tokens: 100 + (response_content.len() as u32 / 4),
        };

        // Update the context window with the estimated token usage.
        self.context_window.update_usage(&mock_usage);

        // Add the assistant response to the context window.
        let assistant_message = Message {
            role: MessageRole::Assistant,
            content: response_content.clone(),
        };
        self.context_window.add_message(assistant_message);

        // Save the context window at the end of a successful interaction.
        self.save_context_window("completed");

        // With streaming tool execution there is no separate code-execution
        // step: tools have already run during streaming.
        if show_timing {
            let timing_summary = format!(
                "\n⏱️ {} | 💭 {}",
                Self::format_duration(llm_duration),
                Self::format_duration(think_time)
            );
            Ok(format!("{}\n{}", response_content, timing_summary))
        } else {
            Ok(response_content)
        }
    }

    /// Save the entire context window to a file for debugging purposes.
    fn save_context_window(&self, status: &str) {
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let filename = format!("g3_context_{}.json", timestamp);
        let context_data = serde_json::json!({
            "timestamp": timestamp,
            "status": status,
            "context_window": {
                "used_tokens": self.context_window.used_tokens,
                "total_tokens": self.context_window.total_tokens,
                "percentage_used": self.context_window.percentage_used(),
                "conversation_history": self.context_window.conversation_history
            }
        });
        match serde_json::to_string_pretty(&context_data) {
            Ok(json_content) => {
                if let Err(e) = fs::write(&filename, json_content) {
                    error!("Failed to save context window to {}: {}", filename, e);
                } else {
                    info!("Context window saved to {}", filename);
                }
            }
            Err(e) => {
                error!("Failed to serialize context window: {}", e);
            }
        }
    }

    pub fn get_context_window(&self) -> &ContextWindow {
        &self.context_window
    }

    async fn stream_completion(
        &mut self,
        request: CompletionRequest,
    ) -> Result<(String, Duration)> {
        self.stream_completion_with_tools(request).await
    }
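
    /// Drives the streaming loop: stream tokens from the provider, pause
    /// when a complete tool call appears (native from the provider, or
    /// parsed out of the text for embedded models), execute it, fold the
    /// result back into the context window, and re-issue the request, up
    /// to `MAX_ITERATIONS` times.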
    async fn stream_completion_with_tools(
        &mut self,
        mut request: CompletionRequest,
    ) -> Result<(String, Duration)> {
        use std::io::{self, Write};
        use tokio_stream::StreamExt;

        let mut full_response = String::new();
        let mut first_token_time: Option<Duration> = None;
        let stream_start = Instant::now();
        let mut total_execution_time = Duration::new(0, 0);
        let mut iteration_count = 0;
        const MAX_ITERATIONS: usize = 10; // prevent infinite loops

        print!("🤖 "); // start the response indicator
        io::stdout().flush()?;

        loop {
            iteration_count += 1;
            if iteration_count > MAX_ITERATIONS {
                warn!("Maximum iterations reached, stopping stream");
                break;
            }

            // Add a small delay between iterations to prevent "model busy"
            // errors.
            if iteration_count > 1 {
                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
            }

            let provider = self.providers.get(None)?;
            let mut stream = match provider.stream(request.clone()).await {
                Ok(s) => s,
                Err(e) => {
                    if iteration_count > 1 && e.to_string().contains("busy") {
                        warn!(
                            "Model busy on iteration {}, retrying in 500ms",
                            iteration_count
                        );
                        tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
                        match provider.stream(request.clone()).await {
                            Ok(s) => s,
                            Err(e2) => {
                                error!("Failed to start stream after retry: {}", e2);
                                return Err(e2);
                            }
                        }
                    } else {
                        return Err(e);
                    }
                }
            };

            let mut parser = StreamingToolParser::new();
            let mut current_response = String::new();
            let mut tool_executed = false;

            while let Some(chunk_result) = stream.next().await {
                match chunk_result {
                    Ok(chunk) => {
                        // Record time to first token.
                        if first_token_time.is_none() && !chunk.content.is_empty() {
                            first_token_time = Some(stream_start.elapsed());
                        }

                        // Check for tool calls, either native (Anthropic,
                        // OpenAI, etc.) or parsed from JSON in the text
                        // (embedded models).
                        let mut detected_tool_call = None;

                        // First check for native tool calls in the chunk.
                        if let Some(ref tool_calls) = chunk.tool_calls {
                            debug!("Found native tool calls in chunk: {:?}", tool_calls);
                            if let Some(first_tool) = tool_calls.first() {
                                // Convert the native tool call to our internal format.
                                detected_tool_call = Some((
                                    crate::ToolCall {
                                        tool: first_tool.tool.clone(),
                                        args: first_tool.args.clone(),
                                    },
                                    current_response.len(), // position doesn't matter for native calls
                                ));
                                debug!("Converted native tool call: {:?}", detected_tool_call);
                            }
                        } else {
                            debug!("No native tool calls in chunk, chunk.tool_calls is None");
                        }

                        // If there were no native tool calls, check for JSON
                        // tool calls in the text (embedded models).
                        if detected_tool_call.is_none() {
                            detected_tool_call = parser.add_chunk(&chunk.content);
                        }

                        if let Some((tool_call, tool_end_pos)) = detected_tool_call {
                            // Found a complete tool call: stop streaming and execute it.
                            let content_before_tool = parser.get_content_before_tool(tool_end_pos);

                            // Display content up to the tool call, excluding
                            // the JSON and any stop tokens.
                            let display_content = if let Some(json_start) =
                                content_before_tool.rfind(r#"{"tool":"#)
                            {
                                // Only show content before the JSON tool call.
                                content_before_tool[..json_start].trim()
                            } else {
                                // Fallback: clean any stop tokens from the content.
                                content_before_tool.trim()
                            };

                            // Clean stop tokens from the display content.
                            let clean_display_content = display_content
                                .replace("<|im_end|>", "")
                                .replace("</s>", "")
                                .replace("[/INST]", "")
                                .replace("<</SYS>>", "");
                            let final_display_content = clean_display_content.trim();

                            // Safely compute the new content to display.
                            let new_content = if current_response.len() <= final_display_content.len() {
                                // Use char counts to avoid UTF-8 boundary issues.
                                let chars_already_shown = current_response.chars().count();
                                final_display_content
                                    .chars()
                                    .skip(chars_already_shown)
                                    .collect::<String>()
                            } else {
                                String::new()
                            };

                            // Only print if there is actually new content to show.
                            if !new_content.trim().is_empty() {
                                print!("{}", new_content);
                                io::stdout().flush()?;
                            }

                            // Execute the tool with formatted output.
                            println!(); // new line before tool execution

                            // Tool call header.
                            println!("┌─ {}", tool_call.tool);
                            if let Some(args_obj) = tool_call.args.as_object() {
                                for (_key, value) in args_obj {
                                    let value_str = match value {
                                        serde_json::Value::String(s) => s.clone(),
                                        _ => value.to_string(),
                                    };
                                    println!("{}", value_str);
                                }
                            }
                            println!("├─ output:");

                            let exec_start = Instant::now();
                            let tool_result = self.execute_tool(&tool_call).await?;
                            let exec_duration = exec_start.elapsed();
                            total_execution_time += exec_duration;

                            // Display the tool execution result, truncated
                            // after a few lines.
                            let output_lines: Vec<&str> = tool_result.lines().collect();
                            const MAX_LINES: usize = 5;
                            if output_lines.len() <= MAX_LINES {
                                // Show all lines if within the limit.
                                for line in output_lines {
                                    println!("{}", line);
                                }
                            } else {
                                // Show the first MAX_LINES and add a truncation note.
                                for line in output_lines.iter().take(MAX_LINES) {
                                    println!("{}", line);
                                }
                                let hidden_count = output_lines.len() - MAX_LINES;
                                println!(
                                    "│ ... ({} more line{} hidden)",
                                    hidden_count,
                                    if hidden_count == 1 { "" } else { "s" }
                                );
                            }

                            // Closing marker with timing.
                            println!("└─ ⚡️ {}", Self::format_duration(exec_duration));
                            println!();
                            print!("🤖 "); // continue the response indicator
                            io::stdout().flush()?;

                            // Add the tool call and result to the context window immediately.
                            let tool_message = Message {
                                role: MessageRole::Assistant,
                                content: format!(
                                    "{}\n\n{{\"tool\": \"{}\", \"args\": {}}}",
                                    display_content.trim(),
                                    tool_call.tool,
                                    tool_call.args
                                ),
                            };
                            let result_message = Message {
                                role: MessageRole::User, // tool results come back as user messages
                                content: format!("Tool result: {}", tool_result),
                            };
                            self.context_window.add_message(tool_message);
                            self.context_window.add_message(result_message);

                            // Update the request with the new context for the next iteration.
                            request.messages = self.context_window.conversation_history.clone();

                            full_response.push_str(final_display_content);
                            full_response.push_str(&format!(
                                "\n\nTool executed: {} -> {}\n\n",
                                tool_call.tool, tool_result
                            ));

                            // A final_output tool call ends the conversation.
                            if tool_call.tool == "final_output" {
                                println!(); // new line after final output
                                let ttft = first_token_time.unwrap_or_else(|| stream_start.elapsed());
                                return Ok((full_response, ttft));
                            }

                            tool_executed = true;
                            // Break out of the current stream to start a new
                            // one with the updated context.
                            break;
                        } else {
                            // No tool call detected: continue streaming
                            // normally, filtering stop tokens from the output.
                            let clean_content = chunk
                                .content
                                .replace("<|im_end|>", "")
                                .replace("</s>", "")
                                .replace("[/INST]", "")
                                .replace("<</SYS>>", "");
                            if !clean_content.is_empty() {
                                print!("{}", clean_content);
                                io::stdout().flush()?;
                                current_response.push_str(&clean_content);
                            }
                        }

                        if chunk.finished {
                            // Stream finished naturally without tool calls.
                            full_response.push_str(&current_response);
                            println!(); // new line after streaming completes
                            let ttft = first_token_time.unwrap_or_else(|| stream_start.elapsed());
                            return Ok((full_response, ttft));
                        }
                    }
                    Err(e) => {
                        error!("Streaming error: {}", e);
                        // If a tool already ran, try to continue with a new stream.
                        if tool_executed {
                            warn!("Stream error after tool execution, attempting to continue");
                            break; // break to the outer loop to start a new stream
                        } else {
                            return Err(e);
                        }
                    }
                }
            }

            // If we get here and no tool was executed, we're done.
            if !tool_executed {
                full_response.push_str(&current_response);
                println!(); // new line after streaming completes
                let ttft = first_token_time.unwrap_or_else(|| stream_start.elapsed());
                return Ok((full_response, ttft));
            }

            // Continue the loop to start a new stream with the updated context.
            info!(
                "Starting new stream iteration {} with {} messages",
                iteration_count,
                request.messages.len()
            );
        }

        // We exited the loop by hitting the iteration limit.
        let ttft = first_token_time.unwrap_or_else(|| stream_start.elapsed());
        Ok((full_response, ttft))
    }
    async fn execute_tool(&self, tool_call: &ToolCall) -> Result<String> {
        match tool_call.tool.as_str() {
            "shell" => {
                if let Some(command) = tool_call.args.get("command") {
                    if let Some(command_str) = command.as_str() {
                        // Escape the command to handle file names with spaces
                        // and special characters.
                        let escaped_command = shell_escape_command(command_str);
                        let executor = CodeExecutor::new();
                        match executor.execute_code("bash", &escaped_command).await {
                            Ok(result) => {
                                if result.success {
                                    Ok(if result.stdout.is_empty() {
                                        "✅ Command executed successfully".to_string()
                                    } else {
                                        result.stdout.trim().to_string()
                                    })
                                } else {
                                    Ok(format!("❌ Command failed: {}", result.stderr.trim()))
                                }
                            }
                            Err(e) => Ok(format!("❌ Execution error: {}", e)),
                        }
                    } else {
                        Ok("❌ Invalid command argument".to_string())
                    }
                } else {
                    Ok("❌ Missing command argument".to_string())
                }
            }
            "final_output" => {
                if let Some(summary) = tool_call.args.get("summary") {
                    if let Some(summary_str) = summary.as_str() {
                        Ok(format!("📋 Final Output: {}", summary_str))
                    } else {
                        Ok("📋 Task completed".to_string())
                    }
                } else {
                    Ok("📋 Task completed".to_string())
                }
            }
            _ => {
                warn!("Unknown tool: {}", tool_call.tool);
                Ok(format!("❓ Unknown tool: {}", tool_call.tool))
            }
        }
    }

    fn format_duration(duration: Duration) -> String {
        let total_ms = duration.as_millis();
        if total_ms < 1000 {
            format!("{}ms", total_ms)
        } else if total_ms < 60_000 {
            let seconds = duration.as_secs_f64();
            format!("{:.1}s", seconds)
        } else {
            let minutes = total_ms / 60_000;
            let remaining_seconds = (total_ms % 60_000) as f64 / 1000.0;
            format!("{}m {:.1}s", minutes, remaining_seconds)
        }
    }
}
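
// Illustrative checks for the three formatting branches above: sub-second,
// sub-minute, and minute-plus durations.
#[cfg(test)]
mod duration_format_tests {
    use super::*;

    #[test]
    fn formats_each_magnitude() {
        assert_eq!(Agent::format_duration(Duration::from_millis(250)), "250ms");
        assert_eq!(Agent::format_duration(Duration::from_millis(1_500)), "1.5s");
        assert_eq!(Agent::format_duration(Duration::from_millis(90_500)), "1m 30.5s");
    }
}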

/// Helper to escape shell commands. This is a simplified heuristic, not a
/// full shell parser: it only tries to handle common cases such as file
/// paths passed to file-oriented commands.
fn shell_escape_command(command: &str) -> String {
    let parts: Vec<&str> = command.split_whitespace().collect();
    if parts.is_empty() {
        return command.to_string();
    }

    let cmd = parts[0];

    // Commands that typically take file paths as arguments.
    let file_commands = [
        "cat", "ls", "cp", "mv", "rm", "chmod", "chown", "file", "head", "tail", "wc", "grep",
    ];

    if file_commands.contains(&cmd) {
        // If the command already has some quoting, use it as-is.
        if command.contains('"') || command.contains('\'') {
            return command.to_string();
        }

        // Split the command into words on unquoted spaces.
        let mut escaped_command = String::new();
        let mut in_quotes = false;
        let mut current_word = String::new();
        let mut words = Vec::new();
        for ch in command.chars() {
            match ch {
                ' ' if !in_quotes => {
                    if !current_word.is_empty() {
                        words.push(current_word.clone());
                        current_word.clear();
                    }
                }
                '"' => {
                    in_quotes = !in_quotes;
                    current_word.push(ch);
                }
                _ => {
                    current_word.push(ch);
                }
            }
        }
        if !current_word.is_empty() {
            words.push(current_word);
        }

        // Reconstruct the command, quoting path-like words that contain
        // spaces and are not already quoted.
        for (i, word) in words.iter().enumerate() {
            if i > 0 {
                escaped_command.push(' ');
            }
            if word.contains('/') || word.starts_with('~') {
                if word.contains(' ') && !word.starts_with('"') && !word.starts_with('\'') {
                    escaped_command.push_str(&format!("\"{}\"", word));
                } else {
                    escaped_command.push_str(word);
                }
            } else {
                escaped_command.push_str(word);
            }
        }
        escaped_command
    } else {
        // Non-file commands pass through unchanged.
        command.to_string()
    }
}
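
// A small sketch documenting the pass-through behavior of the heuristic
// above: commands that already carry quotes, and commands that are not
// file-oriented, are returned unchanged.
#[cfg(test)]
mod shell_escape_tests {
    use super::*;

    #[test]
    fn passes_through_quoted_and_non_file_commands() {
        assert_eq!(
            shell_escape_command(r#"cat "My File.txt""#),
            r#"cat "My File.txt""#
        );
        assert_eq!(shell_escape_command("echo hello world"), "echo hello world");
    }
}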

pub mod providers {
    pub mod anthropic;
    pub mod embedded;
    pub mod openai;
}