From 52f78653b418564029e55bcd8d54383ab0fbf80e Mon Sep 17 00:00:00 2001 From: Jochen Date: Thu, 27 Nov 2025 21:00:02 +1100 Subject: [PATCH] add context window monitor Writes the current context window to logs/current_context_window (uses a symlink to a session ID). This PR was unfortunately generated by a different LLM and did a ton of superficial reformatting; it's actually a fairly small and benign change, but I don't want to roll back everything. Hope that's ok. --- Cargo.lock | 2 + Cargo.toml | 6 + crates/g3-cli/src/lib.rs | 548 ++++++++----- crates/g3-cli/src/machine_ui_writer.rs | 4 +- crates/g3-cli/src/simple_output.rs | 4 +- crates/g3-cli/src/ui_writer_impl.rs | 67 +- crates/g3-computer-control/build.rs | 46 +- .../examples/debug_screenshot.rs | 34 +- .../examples/list_windows.rs | 36 +- .../examples/macax_demo.rs | 16 +- .../examples/safari_demo.rs | 32 +- .../examples/test_permission_prompt.rs | 9 +- .../examples/test_screencapture_direct.rs | 20 +- .../examples/test_screenshot_fix.rs | 40 +- .../examples/test_type_text.rs | 18 +- .../examples/test_vision.rs | 45 +- .../examples/test_window_capture.rs | 30 +- crates/g3-computer-control/src/lib.rs | 35 +- .../src/macax/controller.rs | 346 ++++---- crates/g3-computer-control/src/macax/mod.rs | 10 +- crates/g3-computer-control/src/ocr/mod.rs | 2 +- .../g3-computer-control/src/ocr/tesseract.rs | 31 +- crates/g3-computer-control/src/ocr/vision.rs | 31 +- .../g3-computer-control/src/platform/linux.rs | 112 +-- .../g3-computer-control/src/platform/macos.rs | 490 ++++++----- .../platform/macos_window_matching_test.rs | 16 +- .../src/platform/windows.rs | 112 +-- .../g3-computer-control/src/webdriver/mod.rs | 50 +- .../src/webdriver/safari.rs | 120 +-- .../tests/integration_test.rs | 22 +- crates/g3-config/src/lib.rs | 92 +-- crates/g3-config/src/tests.rs | 38 +- .../tests/test_multiple_tool_calls.rs | 14 +- crates/g3-console/examples/debug_detector.rs | 16 +- crates/g3-console/examples/test_api.rs | 8 +- 
crates/g3-console/examples/test_detector.rs | 6 +- crates/g3-console/src/api/control.rs | 97 ++- crates/g3-console/src/api/instances.rs | 91 ++- crates/g3-console/src/api/logs.rs | 4 +- crates/g3-console/src/api/mod.rs | 2 +- crates/g3-console/src/api/state.rs | 32 +- crates/g3-console/src/launch.rs | 18 +- crates/g3-console/src/lib.rs | 2 +- crates/g3-console/src/logs.rs | 48 +- crates/g3-console/src/main.rs | 10 +- crates/g3-console/src/models/instance.rs | 2 +- crates/g3-console/src/models/message.rs | 2 +- crates/g3-console/src/process/controller.rs | 93 ++- crates/g3-console/src/process/detector.rs | 52 +- crates/g3-console/src/process/mod.rs | 4 +- crates/g3-core/examples/inspect_ast.rs | 6 +- crates/g3-core/examples/inspect_python_ast.rs | 6 +- crates/g3-core/examples/test_python_query.rs | 2 +- crates/g3-core/src/code_search/searcher.rs | 25 +- crates/g3-core/src/error_handling.rs | 176 ++-- crates/g3-core/src/error_handling_test.rs | 8 +- crates/g3-core/src/fixed_filter_json.rs | 4 +- crates/g3-core/src/fixed_filter_tests.rs | 4 +- crates/g3-core/src/lib.rs | 764 ++++++++++++------ crates/g3-core/src/project.rs | 47 +- crates/g3-core/src/prompts.rs | 2 +- crates/g3-core/src/task_result.rs | 49 +- .../src/task_result_comprehensive_tests.rs | 116 ++- crates/g3-core/src/tilde_expansion_tests.rs | 10 +- crates/g3-core/src/ui_writer.rs | 52 +- crates/g3-core/tests/code_search_test.rs | 101 ++- crates/g3-core/tests/test_context_thinning.rs | 102 +-- .../tests/test_todo_context_thinning.rs | 96 ++- crates/g3-core/tests/test_todo_persistence.rs | 189 +++-- crates/g3-core/tests/test_token_counting.rs | 35 +- crates/g3-core/tests/todo_staleness_test.rs | 45 +- crates/g3-ensembles/src/flock.rs | 315 +++++--- crates/g3-ensembles/src/status.rs | 136 ++-- crates/g3-ensembles/src/tests.rs | 3 +- .../g3-ensembles/tests/integration_tests.rs | 12 +- .../examples/setup_coverage_tools.rs | 2 +- crates/g3-execution/src/lib.rs | 226 +++--- crates/g3-planner/src/code_explore.rs 
| 5 +- crates/g3-planner/src/lib.rs | 9 +- crates/g3-planner/tests/logging_test.rs | 30 +- crates/g3-providers/src/anthropic.rs | 302 ++++--- crates/g3-providers/src/databricks.rs | 56 +- crates/g3-providers/src/embedded.rs | 194 +++-- crates/g3-providers/src/lib.rs | 180 +++-- crates/g3-providers/src/openai.rs | 48 +- .../cache_control_error_regression_test.rs | 137 ++-- .../tests/cache_control_integration_test.rs | 126 +-- examples/verify_message_id.rs | 8 +- monitor_context_window.sh | 23 + 89 files changed, 4040 insertions(+), 2576 deletions(-) create mode 100755 monitor_context_window.sh diff --git a/Cargo.lock b/Cargo.lock index 161dba1..383c54a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1351,6 +1351,8 @@ version = "0.1.0" dependencies = [ "anyhow", "g3-cli", + "g3-providers", + "serde_json", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 9842a21..6036f89 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,3 +45,9 @@ license = "MIT" g3-cli = { path = "crates/g3-cli" } tokio = { workspace = true } anyhow = { workspace = true } +g3-providers = { path = "crates/g3-providers" } +serde_json = { workspace = true } + +[[example]] +name = "verify_message_id" +path = "examples/verify_message_id.rs" diff --git a/crates/g3-cli/src/lib.rs b/crates/g3-cli/src/lib.rs index 6dfe2ee..0ea24ca 100644 --- a/crates/g3-cli/src/lib.rs +++ b/crates/g3-cli/src/lib.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use crossterm::style::{Color, SetForegroundColor, ResetColor}; +use crossterm::style::{Color, ResetColor, SetForegroundColor}; use std::time::{Duration, Instant}; #[derive(Debug, Clone)] @@ -16,39 +16,51 @@ fn generate_turn_histogram(turn_metrics: &[TurnMetrics]) -> String { } let mut histogram = String::new(); - + // Find max values for scaling - let max_tokens = turn_metrics.iter().map(|t| t.tokens_used).max().unwrap_or(1); - let max_time_ms = turn_metrics.iter() + let max_tokens = turn_metrics + .iter() + .map(|t| t.tokens_used) + .max() + .unwrap_or(1); + let max_time_ms = 
turn_metrics + .iter() .map(|t| t.wall_clock_time.as_millis().min(u32::MAX as u128) as u32) .max() .unwrap_or(1); - + // Constants for histogram display const MAX_BAR_WIDTH: usize = 40; const TOKEN_CHAR: char = '█'; const TIME_CHAR: char = '▓'; - + histogram.push_str("\n📊 Per-Turn Performance Histogram:\n"); - histogram.push_str(&format!(" {} = Tokens Used (max: {})\n", TOKEN_CHAR, max_tokens)); - histogram.push_str(&format!(" {} = Wall Clock Time (max: {:.1}s)\n\n", TIME_CHAR, max_time_ms as f64 / 1000.0)); - + histogram.push_str(&format!( + " {} = Tokens Used (max: {})\n", + TOKEN_CHAR, max_tokens + )); + histogram.push_str(&format!( + " {} = Wall Clock Time (max: {:.1}s)\n\n", + TIME_CHAR, + max_time_ms as f64 / 1000.0 + )); + for metrics in turn_metrics { let turn_time_ms = metrics.wall_clock_time.as_millis().min(u32::MAX as u128) as u32; - + // Calculate bar lengths (proportional to max values) let token_bar_len = if max_tokens > 0 { ((metrics.tokens_used as f64 / max_tokens as f64) * MAX_BAR_WIDTH as f64) as usize } else { 0 }; - + let time_bar_len = if max_time_ms > 0 { ((turn_time_ms as f64 / max_time_ms as f64) * MAX_BAR_WIDTH as f64) as usize } else { 0 }; - + // Format time duration let time_str = if turn_time_ms < 1000 { format!("{}ms", turn_time_ms) @@ -59,42 +71,50 @@ fn generate_turn_histogram(turn_metrics: &[TurnMetrics]) -> String { let seconds = (turn_time_ms % 60_000) as f64 / 1000.0; format!("{}m{:.1}s", minutes, seconds) }; - + // Create the bars let token_bar = TOKEN_CHAR.to_string().repeat(token_bar_len); let time_bar = TIME_CHAR.to_string().repeat(time_bar_len); - + // Add turn information histogram.push_str(&format!( " Turn {:2}: {:>6} tokens │{:<40}│\n", - metrics.turn_number, - metrics.tokens_used, - token_bar + metrics.turn_number, metrics.tokens_used, token_bar )); histogram.push_str(&format!( " {:>6} │{:<40}│\n", - time_str, - time_bar + time_str, time_bar )); - + // Add separator line between turns (except for last turn) if 
metrics.turn_number != turn_metrics.last().unwrap().turn_number { - histogram.push_str(" ────────────┼────────────────────────────────────────┤\n"); + histogram + .push_str(" ────────────┼────────────────────────────────────────┤\n"); } } - + // Add summary statistics let total_tokens: u32 = turn_metrics.iter().map(|t| t.tokens_used).sum(); let total_time: Duration = turn_metrics.iter().map(|t| t.wall_clock_time).sum(); let avg_tokens = total_tokens as f64 / turn_metrics.len() as f64; let avg_time_ms = total_time.as_millis() as f64 / turn_metrics.len() as f64; - + histogram.push_str("\n📈 Summary Statistics:\n"); - histogram.push_str(&format!(" • Total Tokens: {} across {} turns\n", total_tokens, turn_metrics.len())); + histogram.push_str(&format!( + " • Total Tokens: {} across {} turns\n", + total_tokens, + turn_metrics.len() + )); histogram.push_str(&format!(" • Average Tokens/Turn: {:.1}\n", avg_tokens)); - histogram.push_str(&format!(" • Total Time: {:.1}s\n", total_time.as_secs_f64())); - histogram.push_str(&format!(" • Average Time/Turn: {:.1}s\n", avg_time_ms / 1000.0)); - + histogram.push_str(&format!( + " • Total Time: {:.1}s\n", + total_time.as_secs_f64() + )); + histogram.push_str(&format!( + " • Average Time/Turn: {:.1}s\n", + avg_time_ms / 1000.0 + )); + histogram } @@ -181,15 +201,15 @@ use g3_config::Config; use g3_core::{project::Project, ui_writer::UiWriter, Agent, DiscoveryOptions}; use rustyline::error::ReadlineError; use rustyline::DefaultEditor; +use sha2::{Digest, Sha256}; use std::path::Path; use std::path::PathBuf; -use sha2::{Digest, Sha256}; use tokio_util::sync::CancellationToken; use tracing::{error, info}; use g3_core::error_handling::{classify_error, ErrorType, RecoverableError}; -mod ui_writer_impl; mod simple_output; +mod ui_writer_impl; use simple_output::SimpleOutput; mod machine_ui_writer; use machine_ui_writer::MachineUiWriter; @@ -285,7 +305,7 @@ pub struct Cli { /// Enable fast codebase discovery before first LLM turn 
#[arg(long, value_name = "PATH")] - pub codebase_fast_start: Option + pub codebase_fast_start: Option, } pub async fn run() -> Result<()> { @@ -293,9 +313,16 @@ pub async fn run() -> Result<()> { // Check if flock mode is enabled if let (Some(project_dir), Some(flock_workspace), Some(num_segments)) = - (&cli.project, &cli.flock_workspace, cli.segments) { + (&cli.project, &cli.flock_workspace, cli.segments) + { // Run flock mode - return run_flock_mode(project_dir.clone(), flock_workspace.clone(), num_segments, cli.flock_max_turns).await; + return run_flock_mode( + project_dir.clone(), + flock_workspace.clone(), + num_segments, + cli.flock_max_turns, + ) + .await; } // Otherwise, continue with normal mode @@ -353,7 +380,7 @@ pub async fn run() -> Result<()> { // Check if we're in a project directory and read README and AGENTS.md if available // Load AGENTS.md first (if present) to provide agent-specific instructions let agents_content = read_agents_config(&workspace_dir); - + // Then load README for project context let readme_content = read_project_readme(&workspace_dir); @@ -361,7 +388,10 @@ pub async fn run() -> Result<()> { let project = if cli.autonomous { if let Some(requirements_text) = &cli.requirements { // Use requirements text override - Project::new_autonomous_with_requirements(workspace_dir.clone(), requirements_text.clone())? + Project::new_autonomous_with_requirements( + workspace_dir.clone(), + requirements_text.clone(), + )? } else { // Use traditional requirements.md file Project::new_autonomous(workspace_dir.clone())? 
@@ -410,23 +440,21 @@ pub async fn run() -> Result<()> { // Initialize agent // ui_writer will be created conditionally based on machine mode - + // Combine AGENTS.md and README content if both exist let combined_content = match (agents_content.clone(), readme_content.clone()) { - (Some(agents), Some(readme)) => { - Some(format!("{}\n\n{}", agents, readme)) - } + (Some(agents), Some(readme)) => Some(format!("{}\n\n{}", agents, readme)), (Some(agents), None) => Some(agents), (None, Some(readme)) => Some(readme), (None, None) => None, }; - + // Execute task, autonomous mode, or start interactive mode based on machine mode if cli.machine { // Machine mode - use MachineUiWriter - + let ui_writer = MachineUiWriter::new(); - + let agent = if cli.autonomous { Agent::new_autonomous_with_readme_and_quiet( config.clone(), @@ -444,26 +472,27 @@ pub async fn run() -> Result<()> { ) .await? }; - + run_with_machine_mode(agent, cli, project).await?; } else { // Normal mode - use ConsoleUiWriter - + // DEFAULT: Chat mode for interactive sessions // It runs when: // 1. No task is provided (not single-shot) // 2. Not in autonomous mode // 3. Not explicitly enabled with --auto flag let use_accumulative = cli.task.is_none() && !cli.autonomous && cli.auto; - + if use_accumulative { // Run accumulative mode and return early - run_accumulative_mode(workspace_dir.clone(), cli.clone(), combined_content.clone()).await?; + run_accumulative_mode(workspace_dir.clone(), cli.clone(), combined_content.clone()) + .await?; return Ok(()); } - + let ui_writer = ConsoleUiWriter::new(); - + let agent = if cli.autonomous { Agent::new_autonomous_with_readme_and_quiet( config.clone(), @@ -481,10 +510,10 @@ pub async fn run() -> Result<()> { ) .await? 
}; - + run_with_console_mode(agent, cli, project, combined_content).await?; } - + Ok(()) } @@ -529,14 +558,17 @@ async fn run_accumulative_mode( combined_content: Option, ) -> Result<()> { let output = SimpleOutput::new(); - + output.print(""); output.print("g3 programming agent - autonomous mode"); output.print(" >> describe what you want, I'll build it iteratively"); output.print(""); - print!("{}workspace: {}{}\n", + print!( + "{}workspace: {}{}\n", SetForegroundColor(Color::DarkGrey), - workspace_dir.display(), ResetColor); + workspace_dir.display(), + ResetColor + ); output.print(""); output.print("💡 Each input you provide will be added to requirements"); output.print(" and I'll automatically work on implementing them. You can"); @@ -544,45 +576,48 @@ async fn run_accumulative_mode( output.print(""); output.print(" Type '/help' for commands, 'exit' or 'quit' to stop, Ctrl+D to finish"); output.print(""); - + // Initialize rustyline editor with history let mut rl = DefaultEditor::new()?; let history_file = dirs::home_dir().map(|mut path| { path.push(".g3_accumulative_history"); path }); - + if let Some(ref history_path) = history_file { let _ = rl.load_history(history_path); } - + // Accumulated requirements stored in memory let mut accumulated_requirements = Vec::new(); let mut turn_number = 0; - + loop { output.print(&format!("\n{}", "=".repeat(60))); if accumulated_requirements.is_empty() { output.print("📝 What would you like me to build? (describe your requirements)"); } else { - output.print(&format!("📝 Turn {} - What's next? (add more requirements or refinements)", turn_number + 1)); + output.print(&format!( + "📝 Turn {} - What's next? 
(add more requirements or refinements)", + turn_number + 1 + )); } output.print(&format!("{}", "=".repeat(60))); - + let readline = rl.readline("requirement> "); match readline { Ok(line) => { let input = line.trim().to_string(); - + if input.is_empty() { continue; } - + if input == "exit" || input == "quit" { output.print("\n👋 Goodbye!"); break; } - + // Check for slash commands if input.starts_with('/') { match input.as_str() { @@ -614,7 +649,7 @@ async fn run_accumulative_mode( output.print(""); output.print("🔄 Switching to interactive chat mode..."); output.print(""); - + // Build context message with accumulated requirements let requirements_context = if accumulated_requirements.is_empty() { None @@ -626,36 +661,39 @@ async fn run_accumulative_mode( accumulated_requirements.join("\n") )) }; - + // Combine with existing content (README/AGENTS.md) - let chat_combined_content = match (requirements_context, combined_content.clone()) { - (Some(req_ctx), Some(existing)) => Some(format!("{}\n\n{}", req_ctx, existing)), - (Some(req_ctx), None) => Some(req_ctx), - (None, existing) => existing, - }; - + let chat_combined_content = + match (requirements_context, combined_content.clone()) { + (Some(req_ctx), Some(existing)) => { + Some(format!("{}\n\n{}", req_ctx, existing)) + } + (Some(req_ctx), None) => Some(req_ctx), + (None, existing) => existing, + }; + // Load configuration let mut config = Config::load_with_overrides( cli.config.as_deref(), cli.provider.clone(), cli.model.clone(), )?; - + // Apply macax flag override if cli.macax { config.macax.enabled = true; } - + // Apply webdriver flag override if cli.webdriver { config.webdriver.enabled = true; } - + // Apply no-auto-compact flag override if cli.manual_compact { config.agent.auto_compact = false; } - + // Create agent for interactive mode with requirements context let ui_writer = ConsoleUiWriter::new(); let agent = Agent::new_with_readme_and_quiet( @@ -665,28 +703,38 @@ async fn run_accumulative_mode( 
cli.quiet, ) .await?; - + // Run interactive mode - run_interactive(agent, cli.show_prompt, cli.show_code, chat_combined_content, &workspace_dir).await?; - + run_interactive( + agent, + cli.show_prompt, + cli.show_code, + chat_combined_content, + &workspace_dir, + ) + .await?; + // After returning from interactive mode, exit output.print("\n👋 Goodbye!"); break; } _ => { - output.print(&format!("❌ Unknown command: {}. Type /help for available commands.", input)); + output.print(&format!( + "❌ Unknown command: {}. Type /help for available commands.", + input + )); continue; } } } - + // Add to history rl.add_history_entry(&input)?; - + // Add this requirement to accumulated list turn_number += 1; accumulated_requirements.push(format!("{}. {}", turn_number, input)); - + // Build the complete requirements document let requirements_doc = format!( "# Project Requirements\n\n\ @@ -698,46 +746,49 @@ async fn run_accumulative_mode( turn_number, input ); - + output.print(""); - output.print(&format!("📋 Current instructions and requirements (Turn {}):", turn_number)); + output.print(&format!( + "📋 Current instructions and requirements (Turn {}):", + turn_number + )); output.print(&format!(" {}", input)); output.print(""); output.print("🚀 Starting autonomous implementation..."); output.print(""); - + // Create a project with the accumulated requirements let project = Project::new_autonomous_with_requirements( workspace_dir.clone(), - requirements_doc.clone() + requirements_doc.clone(), )?; - + // Ensure workspace exists and enter it project.ensure_workspace_exists()?; project.enter_workspace()?; - + // Load configuration with CLI overrides let mut config = Config::load_with_overrides( cli.config.as_deref(), cli.provider.clone(), cli.model.clone(), )?; - + // Apply macax flag override if cli.macax { config.macax.enabled = true; } - + // Apply webdriver flag override if cli.webdriver { config.webdriver.enabled = true; } - + // Apply no-auto-compact flag override if 
cli.manual_compact { config.agent.auto_compact = false; } - + // Create agent for this autonomous run let ui_writer = ConsoleUiWriter::new(); let agent = Agent::new_autonomous_with_readme_and_quiet( @@ -747,7 +798,7 @@ async fn run_accumulative_mode( cli.quiet, ) .await?; - + // Run autonomous mode with the accumulated requirements let autonomous_result = tokio::select! { result = run_autonomous( @@ -764,9 +815,8 @@ async fn run_accumulative_mode( Ok(()) } }; - - match autonomous_result - { + + match autonomous_result { Ok(_) => { output.print(""); output.print("✅ Autonomous run completed"); @@ -792,12 +842,12 @@ async fn run_accumulative_mode( } } } - + // Save history before exiting if let Some(ref history_path) = history_file { let _ = rl.save_history(history_path); } - + Ok(()) } @@ -840,7 +890,9 @@ async fn run_autonomous_machine( ); println!("TASK_START"); - let result = agent.execute_task_with_timing(&task, None, false, show_prompt, show_code, true, None).await?; + let result = agent + .execute_task_with_timing(&task, None, false, show_prompt, show_code, true, None) + .await?; println!("AGENT_RESPONSE:"); println!("{}", result.response); println!("END_AGENT_RESPONSE"); @@ -856,7 +908,6 @@ async fn run_with_console_mode( project: Project, combined_content: Option, ) -> Result<()> { - // Execute task, autonomous mode, or start interactive mode if cli.autonomous { // Autonomous mode with coach-player feedback loop @@ -874,12 +925,27 @@ async fn run_with_console_mode( // Single-shot mode let output = SimpleOutput::new(); let result = agent - .execute_task_with_timing(&task, None, false, cli.show_prompt, cli.show_code, true, None) + .execute_task_with_timing( + &task, + None, + false, + cli.show_prompt, + cli.show_code, + true, + None, + ) .await?; output.print_smart(&result.response); } else { // Interactive mode (default) - run_interactive(agent, cli.show_prompt, cli.show_code, combined_content, project.workspace()).await?; + run_interactive( + agent, + 
cli.show_prompt, + cli.show_code, + combined_content, + project.workspace(), + ) + .await?; } Ok(()) @@ -905,7 +971,15 @@ async fn run_with_machine_mode( } else if let Some(task) = cli.task { // Single-shot mode let result = agent - .execute_task_with_timing(&task, None, false, cli.show_prompt, cli.show_code, true, None) + .execute_task_with_timing( + &task, + None, + false, + cli.show_prompt, + cli.show_code, + true, + None, + ) .await?; println!("AGENT_RESPONSE:"); println!("{}", result.response); @@ -922,7 +996,7 @@ async fn run_with_machine_mode( fn read_agents_config(workspace_dir: &Path) -> Option { // Look for AGENTS.md in the current directory let agents_path = workspace_dir.join("AGENTS.md"); - + if agents_path.exists() { match std::fs::read_to_string(&agents_path) { Ok(content) => { @@ -943,9 +1017,10 @@ fn read_agents_config(workspace_dir: &Path) -> Option { let alt_path = workspace_dir.join("agents.md"); if alt_path.exists() { match std::fs::read_to_string(&alt_path) { - Ok(content) => { - Some(format!("🤖 Agent Configuration (from agents.md):\n\n{}", content)) - } + Ok(content) => Some(format!( + "🤖 Agent Configuration (from agents.md):\n\n{}", + content + )), Err(e) => { error!("Failed to read agents.md: {}", e); None @@ -1069,9 +1144,14 @@ async fn run_interactive( // Display provider and model information match agent.get_provider_info() { Ok((provider, model)) => { - print!("🔧 {}{}{} | {}{}{}\n", - SetForegroundColor(Color::Cyan), provider, ResetColor, - SetForegroundColor(Color::Yellow), model, ResetColor + print!( + "🔧 {}{}{} | {}{}{}\n", + SetForegroundColor(Color::Cyan), + provider, + ResetColor, + SetForegroundColor(Color::Yellow), + model, + ResetColor ); } Err(e) => { @@ -1084,28 +1164,36 @@ async fn run_interactive( // Check what was loaded let has_agents = content.contains("Agent Configuration"); let has_readme = content.contains("Project README"); - + if has_agents { - print!("{}🤖 AGENTS.md configuration loaded{}\n", - 
SetForegroundColor(Color::DarkGrey), ResetColor); + print!( + "{}🤖 AGENTS.md configuration loaded{}\n", + SetForegroundColor(Color::DarkGrey), + ResetColor + ); } - + if has_readme { // Extract the first heading or title from the README let readme_snippet = extract_readme_heading(content) .unwrap_or_else(|| "Project documentation loaded".to_string()); - print!("{}📚 detected: {}{}\n", + print!( + "{}📚 detected: {}{}\n", SetForegroundColor(Color::DarkGrey), readme_snippet, - ResetColor); + ResetColor + ); } } // Display workspace path - print!("{}workspace: {}{}\n", + print!( + "{}workspace: {}{}\n", SetForegroundColor(Color::DarkGrey), - workspace_path.display(), ResetColor); + workspace_path.display(), + ResetColor + ); output.print(""); // Initialize rustyline editor with history @@ -1190,7 +1278,9 @@ async fn run_interactive( output.print("📖 Control Commands:"); output.print(" /compact - Trigger auto-summarization (compacts conversation history)"); output.print(" /thinnify - Trigger context thinning (replaces large tool results with file references)"); - output.print(" /readme - Reload README.md and AGENTS.md from disk"); + output.print( + " /readme - Reload README.md and AGENTS.md from disk", + ); output.print(" /stats - Show detailed context and performance statistics"); output.print(" /help - Show this help message"); output.print(" exit/quit - Exit the interactive session"); @@ -1207,7 +1297,10 @@ async fn run_interactive( output.print("⚠️ Summarization failed"); } Err(e) => { - output.print(&format!("❌ Error during summarization: {}", e)); + output.print(&format!( + "❌ Error during summarization: {}", + e + )); } } continue; @@ -1220,9 +1313,14 @@ async fn run_interactive( "/readme" => { output.print("📚 Reloading README.md and AGENTS.md..."); match agent.reload_readme() { - Ok(true) => output.print("✅ README content reloaded successfully"), - Ok(false) => output.print("⚠️ No README was loaded at startup, cannot reload"), - Err(e) => output.print(&format!("❌ 
Error reloading README: {}", e)), + Ok(true) => { + output.print("✅ README content reloaded successfully") + } + Ok(false) => output + .print("⚠️ No README was loaded at startup, cannot reload"), + Err(e) => { + output.print(&format!("❌ Error reloading README: {}", e)) + } } continue; } @@ -1232,7 +1330,10 @@ async fn run_interactive( continue; } _ => { - output.print(&format!("❌ Unknown command: {}. Type /help for available commands.", input)); + output.print(&format!( + "❌ Unknown command: {}. Type /help for available commands.", + input + )); continue; } } @@ -1421,8 +1522,12 @@ async fn run_interactive_machine( "/readme" => { println!("COMMAND: readme"); match agent.reload_readme() { - Ok(true) => println!("RESULT: README content reloaded successfully"), - Ok(false) => println!("RESULT: No README was loaded at startup, cannot reload"), + Ok(true) => { + println!("RESULT: README content reloaded successfully") + } + Ok(false) => println!( + "RESULT: No README was loaded at startup, cannot reload" + ), Err(e) => println!("ERROR: {}", e), } continue; @@ -1527,7 +1632,10 @@ async fn execute_task_machine( let delay_ms = 1000 * (2_u64.pow(attempt - 1)); let delay = std::time::Duration::from_millis(delay_ms); - println!("TIMEOUT: attempt {} of {}, retrying in {:?}", attempt, MAX_TIMEOUT_RETRIES, delay); + println!( + "TIMEOUT: attempt {} of {}, retrying in {:?}", + attempt, MAX_TIMEOUT_RETRIES, delay + ); // Wait before retrying tokio::time::sleep(delay).await; @@ -1579,29 +1687,41 @@ fn handle_execution_error(e: &anyhow::Error, input: &str, output: &SimpleOutput, fn display_context_progress(agent: &Agent, _output: &SimpleOutput) { let context = agent.get_context_window(); let percentage = context.percentage_used(); - + // Create 10 dots representing context fullness let total_dots: usize = 10; let filled_dots = ((percentage / 100.0) * total_dots as f32).round() as usize; let empty_dots = total_dots.saturating_sub(filled_dots); - + let filled_str = 
"●".repeat(filled_dots); let empty_str = "○".repeat(empty_dots); - + // Determine color based on percentage let color = if percentage < 40.0 { Color::Green } else if percentage < 60.0 { Color::Yellow } else if percentage < 80.0 { - Color::Rgb { r: 255, g: 165, b: 0 } // Orange + Color::Rgb { + r: 255, + g: 165, + b: 0, + } // Orange } else { Color::Red }; - + // Print with colored dots (using print! directly to handle color codes) - print!("Context: {}{}{}{} {:.0}% ({}/{} tokens)\n", - SetForegroundColor(color), filled_str, empty_str, ResetColor, percentage, context.used_tokens, context.total_tokens); + print!( + "Context: {}{}{}{} {:.0}% ({}/{} tokens)\n", + SetForegroundColor(color), + filled_str, + empty_str, + ResetColor, + percentage, + context.used_tokens, + context.total_tokens + ); } /// Set up the workspace directory for autonomous mode @@ -1752,9 +1872,9 @@ async fn run_autonomous( let mut hasher = Sha256::new(); hasher.update(requirements.as_bytes()); let requirements_sha = hex::encode(hasher.finalize()); - + output.print(&format!("🔒 Requirements SHA256: {}", requirements_sha)); - + // Pass SHA to agent for staleness checking agent.set_requirements_sha(requirements_sha.clone()); @@ -1763,35 +1883,53 @@ async fn run_autonomous( // Load fast-discovery messages before the loop starts (if enabled) let (discovery_messages, discovery_working_dir): (Vec, Option) = - if let Some(ref codebase_path) = codebase_fast_start { - // Canonicalize the path to ensure it's absolute - let canonical_path = codebase_path.canonicalize().unwrap_or_else(|_| codebase_path.clone()); - let path_str = canonical_path.to_string_lossy(); - output.print(&format!("🔍 Fast-discovery mode: will explore codebase at {}", path_str)); - // Get the provider from the agent and use async LLM-based discovery - match agent.get_provider() { - Ok(provider) => { - // Create a status callback that prints to output - let output_clone = output.clone(); - let status_callback: g3_planner::StatusCallback = 
Box::new(move |msg: &str| { - output_clone.print(msg); - }); - match g3_planner::get_initial_discovery_messages(&path_str, Some(&requirements), provider, Some(&status_callback)).await { - Ok(messages) => (messages, Some(path_str.to_string())), - Err(e) => { - output.print(&format!("⚠️ LLM discovery failed: {}, skipping fast-start", e)); - (Vec::new(), None) + if let Some(ref codebase_path) = codebase_fast_start { + // Canonicalize the path to ensure it's absolute + let canonical_path = codebase_path + .canonicalize() + .unwrap_or_else(|_| codebase_path.clone()); + let path_str = canonical_path.to_string_lossy(); + output.print(&format!( + "🔍 Fast-discovery mode: will explore codebase at {}", + path_str + )); + // Get the provider from the agent and use async LLM-based discovery + match agent.get_provider() { + Ok(provider) => { + // Create a status callback that prints to output + let output_clone = output.clone(); + let status_callback: g3_planner::StatusCallback = Box::new(move |msg: &str| { + output_clone.print(msg); + }); + match g3_planner::get_initial_discovery_messages( + &path_str, + Some(&requirements), + provider, + Some(&status_callback), + ) + .await + { + Ok(messages) => (messages, Some(path_str.to_string())), + Err(e) => { + output.print(&format!( + "⚠️ LLM discovery failed: {}, skipping fast-start", + e + )); + (Vec::new(), None) + } } } + Err(e) => { + output.print(&format!( + "⚠️ Could not get provider: {}, skipping fast-start", + e + )); + (Vec::new(), None) + } } - Err(e) => { - output.print(&format!("⚠️ Could not get provider: {}, skipping fast-start", e)); - (Vec::new(), None) - } - } - } else { - (Vec::new(), None) - }; + } else { + (Vec::new(), None) + }; let has_discovery = !discovery_messages.is_empty(); let mut turn = 1; @@ -1823,7 +1961,10 @@ async fn run_autonomous( ) }; - output.print(&format!("🎯 Starting player implementation... 
(elapsed: {})", format_elapsed_time(loop_start.elapsed()))); + output.print(&format!( + "🎯 Starting player implementation... (elapsed: {})", + format_elapsed_time(loop_start.elapsed()) + )); // Display what feedback the player is receiving // If there's no coach feedback on subsequent turns, this is an error @@ -1863,7 +2004,9 @@ async fn run_autonomous( messages: &discovery_messages, fast_start_path: discovery_working_dir.as_deref(), }) - } else { None }, + } else { + None + }, ) .await { @@ -1878,7 +2021,10 @@ async fn run_autonomous( use g3_core::error_handling::{classify_error, ErrorType, RecoverableError}; let error_type = classify_error(&e); - if matches!(error_type, ErrorType::Recoverable(RecoverableError::ContextLengthExceeded)) { + if matches!( + error_type, + ErrorType::Recoverable(RecoverableError::ContextLengthExceeded) + ) { output.print(&format!("⚠️ Context length exceeded in player turn: {}", e)); output.print("📝 Logging error to session and ending current turn..."); @@ -1924,10 +2070,7 @@ async fn run_autonomous( output.print("📝 Final Status: 💥 PLAYER PANIC"); output.print("\n📈 Token Usage Statistics:"); - output.print(&format!( - " • Used Tokens: {}", - context_window.used_tokens - )); + output.print(&format!(" • Used Tokens: {}", context_window.used_tokens)); output.print(&format!( " • Total Available: {}", context_window.total_tokens @@ -1954,9 +2097,8 @@ async fn run_autonomous( )); if _player_retry_count >= MAX_PLAYER_RETRIES { - output.print( - "🔄 Max retries reached for player, marking turn as failed...", - ); + output + .print("🔄 Max retries reached for player, marking turn as failed..."); player_failed = true; break; // Exit retry loop } @@ -1973,7 +2115,10 @@ async fn run_autonomous( )); // Record turn metrics before incrementing let turn_duration = turn_start_time.elapsed(); - let turn_tokens = agent.get_context_window().used_tokens.saturating_sub(turn_start_tokens); + let turn_tokens = agent + .get_context_window() + .used_tokens + 
.saturating_sub(turn_start_tokens); turn_metrics.push(TurnMetrics { turn_number: turn, tokens_used: turn_tokens, @@ -2006,7 +2151,8 @@ async fn run_autonomous( let ui_writer = ConsoleUiWriter::new(); let mut coach_agent = - Agent::new_autonomous_with_readme_and_quiet(coach_config, ui_writer, None, quiet).await?; + Agent::new_autonomous_with_readme_and_quiet(coach_config, ui_writer, None, quiet) + .await?; // Surface provider info for coach agent coach_agent.print_provider_banner("Coach"); @@ -2050,7 +2196,10 @@ Remember: Be clear in your review and concise in your feedback. APPROVE iff the requirements ); - output.print(&format!("🎓 Starting coach review... (elapsed: {})", format_elapsed_time(loop_start.elapsed()))); + output.print(&format!( + "🎓 Starting coach review... (elapsed: {})", + format_elapsed_time(loop_start.elapsed()) + )); // Execute coach task with retry on error let mut coach_retry_count = 0; @@ -2060,13 +2209,22 @@ Remember: Be clear in your review and concise in your feedback. APPROVE iff the loop { match coach_agent - .execute_task_with_timing(&coach_prompt, None, false, show_prompt, show_code, true, + .execute_task_with_timing( + &coach_prompt, + None, + false, + show_prompt, + show_code, + true, if has_discovery { Some(DiscoveryOptions { messages: &discovery_messages, fast_start_path: discovery_working_dir.as_deref(), }) - } else { None }) + } else { + None + }, + ) .await { Ok(result) => { @@ -2077,11 +2235,14 @@ Remember: Be clear in your review and concise in your feedback. 
APPROVE iff the // Check if this is a context length exceeded error use g3_core::error_handling::{classify_error, ErrorType, RecoverableError}; let error_type = classify_error(&e); - - if matches!(error_type, ErrorType::Recoverable(RecoverableError::ContextLengthExceeded)) { + + if matches!( + error_type, + ErrorType::Recoverable(RecoverableError::ContextLengthExceeded) + ) { output.print(&format!("⚠️ Context length exceeded in coach turn: {}", e)); output.print("📝 Logging error to session and ending current turn..."); - + // Build forensic context let forensic_context = format!( "Turn: {}\n\ @@ -2098,10 +2259,10 @@ Remember: Be clear in your review and concise in your feedback. APPROVE iff the coach_prompt.len(), chrono::Utc::now().to_rfc3339() ); - + // Log to coach's session JSON coach_agent.log_error_to_session(&e, "assistant", Some(forensic_context)); - + // Mark turn as failed and continue to next turn coach_result_opt = None; coach_failed = true; @@ -2174,7 +2335,10 @@ Remember: Be clear in your review and concise in your feedback. APPROVE iff the coach_feedback = "The implementation needs review. Please ensure all requirements are met and the code compiles without errors.".to_string(); // Record turn metrics before incrementing let turn_duration = turn_start_time.elapsed(); - let turn_tokens = agent.get_context_window().used_tokens.saturating_sub(turn_start_tokens); + let turn_tokens = agent + .get_context_window() + .used_tokens + .saturating_sub(turn_start_tokens); turn_metrics.push(TurnMetrics { turn_number: turn, tokens_used: turn_tokens, @@ -2210,7 +2374,10 @@ Remember: Be clear in your review and concise in your feedback. APPROVE iff the coach_feedback = "The implementation needs review. 
Please ensure all requirements are met and the code compiles without errors.".to_string(); // Record turn metrics before incrementing let turn_duration = turn_start_time.elapsed(); - let turn_tokens = agent.get_context_window().used_tokens.saturating_sub(turn_start_tokens); + let turn_tokens = agent + .get_context_window() + .used_tokens + .saturating_sub(turn_start_tokens); turn_metrics.push(TurnMetrics { turn_number: turn, tokens_used: turn_tokens, @@ -2241,7 +2408,10 @@ Remember: Be clear in your review and concise in your feedback. APPROVE iff the coach_feedback = coach_feedback_text; // Record turn metrics before incrementing let turn_duration = turn_start_time.elapsed(); - let turn_tokens = agent.get_context_window().used_tokens.saturating_sub(turn_start_tokens); + let turn_tokens = agent + .get_context_window() + .used_tokens + .saturating_sub(turn_start_tokens); turn_metrics.push(TurnMetrics { turn_number: turn, tokens_used: turn_tokens, @@ -2290,15 +2460,21 @@ Remember: Be clear in your review and concise in your feedback. 
APPROVE iff the " • Usage Percentage: {:.1}%", context_window.percentage_used() )); - + // Add per-turn histogram output.print(&generate_turn_histogram(&turn_metrics)); output.print(&"=".repeat(60)); if implementation_approved { - output.print(&format!("\n🎉 Autonomous mode completed successfully (total loop time: {})", format_elapsed_time(loop_start.elapsed()))); + output.print(&format!( + "\n🎉 Autonomous mode completed successfully (total loop time: {})", + format_elapsed_time(loop_start.elapsed()) + )); } else { - output.print(&format!("\n🔄 Autonomous mode terminated (max iterations) (total loop time: {})", format_elapsed_time(loop_start.elapsed()))); + output.print(&format!( + "\n🔄 Autonomous mode terminated (max iterations) (total loop time: {})", + format_elapsed_time(loop_start.elapsed()) + )); } Ok(()) diff --git a/crates/g3-cli/src/machine_ui_writer.rs b/crates/g3-cli/src/machine_ui_writer.rs index 6b70837..f3cc288 100644 --- a/crates/g3-cli/src/machine_ui_writer.rs +++ b/crates/g3-cli/src/machine_ui_writer.rs @@ -87,9 +87,9 @@ impl UiWriter for MachineUiWriter { fn flush(&self) { let _ = io::stdout().flush(); } - + fn wants_full_output(&self) -> bool { - true // Machine mode wants complete, untruncated output + true // Machine mode wants complete, untruncated output } fn prompt_user_yes_no(&self, message: &str) -> bool { diff --git a/crates/g3-cli/src/simple_output.rs b/crates/g3-cli/src/simple_output.rs index a917097..3838f29 100644 --- a/crates/g3-cli/src/simple_output.rs +++ b/crates/g3-cli/src/simple_output.rs @@ -6,7 +6,9 @@ pub struct SimpleOutput { impl SimpleOutput { pub fn new() -> Self { - SimpleOutput { machine_mode: false } + SimpleOutput { + machine_mode: false, + } } pub fn new_with_mode(machine_mode: bool) -> Self { diff --git a/crates/g3-cli/src/ui_writer_impl.rs b/crates/g3-cli/src/ui_writer_impl.rs index f9c844f..898ed5e 100644 --- a/crates/g3-cli/src/ui_writer_impl.rs +++ b/crates/g3-cli/src/ui_writer_impl.rs @@ -25,22 +25,22 @@ impl 
ConsoleUiWriter { fn print_todo_line(&self, line: &str) { // Transform and print todo list lines elegantly let trimmed = line.trim(); - + // Skip the "📝 TODO list:" prefix line if trimmed.starts_with("📝 TODO list:") || trimmed == "📝 TODO list is empty" { return; } - + // Handle empty lines if trimmed.is_empty() { println!(); return; } - + // Detect indentation level let indent_count = line.chars().take_while(|c| c.is_whitespace()).count(); let indent = " ".repeat(indent_count / 2); // Convert spaces to visual indent - + // Format based on line type if trimmed.starts_with("- [ ]") { // Incomplete task @@ -48,7 +48,8 @@ impl ConsoleUiWriter { println!("{}☐ {}", indent, task); } else if trimmed.starts_with("- [x]") || trimmed.starts_with("- [X]") { // Completed task - let task = trimmed.strip_prefix("- [x]") + let task = trimmed + .strip_prefix("- [x]") .or_else(|| trimmed.strip_prefix("- [X]")) .unwrap_or(trimmed) .trim(); @@ -105,31 +106,31 @@ impl UiWriter for ConsoleUiWriter { fn print_context_thinning(&self, message: &str) { // Animated highlight for context thinning // Use bright cyan/green with a quick flash animation - + // Flash animation: print with bright background, then normal let frames = vec![ - "\x1b[1;97;46m", // Frame 1: Bold white on cyan background - "\x1b[1;97;42m", // Frame 2: Bold white on green background - "\x1b[1;96;40m", // Frame 3: Bold cyan on black background + "\x1b[1;97;46m", // Frame 1: Bold white on cyan background + "\x1b[1;97;42m", // Frame 2: Bold white on green background + "\x1b[1;96;40m", // Frame 3: Bold cyan on black background ]; - + println!(); - + // Quick flash animation for frame in &frames { print!("\r{} ✨ {} ✨\x1b[0m", frame, message); let _ = io::stdout().flush(); std::thread::sleep(std::time::Duration::from_millis(80)); } - + // Final display with bright cyan and sparkle emojis print!("\r\x1b[1;96m✨ {} ✨\x1b[0m", message); println!(); - + // Add a subtle "success" indicator line println!("\x1b[2;36m └─ Context 
optimized successfully\x1b[0m"); println!(); - + let _ = io::stdout().flush(); } @@ -137,14 +138,13 @@ impl UiWriter for ConsoleUiWriter { // Store the tool name and clear args for collection *self.current_tool_name.lock().unwrap() = Some(tool_name.to_string()); self.current_tool_args.lock().unwrap().clear(); - + // Check if this is a todo tool call let is_todo = tool_name == "todo_read" || tool_name == "todo_write"; *self.in_todo_tool.lock().unwrap() = is_todo; - + // For todo tools, we'll skip the normal header and print a custom one later - if is_todo { - } + if is_todo {} } fn print_tool_arg(&self, key: &str, value: &str) { @@ -172,7 +172,7 @@ impl UiWriter for ConsoleUiWriter { println!(); // Just add a newline return; } - + println!(); // Now print the tool header with the most important arg in bold green if let Some(tool_name) = self.current_tool_name.lock().unwrap().as_ref() { @@ -192,7 +192,8 @@ impl UiWriter for ConsoleUiWriter { // Truncate long values for display let display_value = if first_line.len() > 80 { // Use char_indices to safely truncate at character boundary - let truncate_at = first_line.char_indices() + let truncate_at = first_line + .char_indices() .nth(77) .map(|(i, _)| i) .unwrap_or(first_line.len()); @@ -206,10 +207,18 @@ impl UiWriter for ConsoleUiWriter { // Check if start or end parameters are present let has_start = args.iter().any(|(k, _)| k == "start"); let has_end = args.iter().any(|(k, _)| k == "end"); - + if has_start || has_end { - let start_val = args.iter().find(|(k, _)| k == "start").map(|(_, v)| v.as_str()).unwrap_or("0"); - let end_val = args.iter().find(|(k, _)| k == "end").map(|(_, v)| v.as_str()).unwrap_or("end"); + let start_val = args + .iter() + .find(|(k, _)| k == "start") + .map(|(_, v)| v.as_str()) + .unwrap_or("0"); + let end_val = args + .iter() + .find(|(k, _)| k == "end") + .map(|(_, v)| v.as_str()) + .unwrap_or("end"); format!(" [{}..{}]", start_val, end_val) } else { String::new() @@ -219,7 +228,10 @@ impl 
UiWriter for ConsoleUiWriter { }; // Print with bold green tool name, purple (non-bold) for pipe and args - println!("┌─\x1b[1;32m {}\x1b[0m\x1b[35m | {}{}\x1b[0m", tool_name, display_value, header_suffix); + println!( + "┌─\x1b[1;32m {}\x1b[0m\x1b[35m | {}{}\x1b[0m", + tool_name, display_value, header_suffix + ); } else { // Print with bold green formatting using ANSI escape codes println!("┌─\x1b[1;32m {}\x1b[0m", tool_name); @@ -252,7 +264,7 @@ impl UiWriter for ConsoleUiWriter { self.print_todo_line(line); return; } - + println!("│ \x1b[2m{}\x1b[0m", line); } @@ -261,7 +273,7 @@ impl UiWriter for ConsoleUiWriter { if *self.in_todo_tool.lock().unwrap() { return; } - + println!( "│ \x1b[2m({} line{})\x1b[0m", count, @@ -276,7 +288,7 @@ impl UiWriter for ConsoleUiWriter { *self.in_todo_tool.lock().unwrap() = false; return; } - + // Parse the duration string to determine color // Format is like "1.5s", "500ms", "2m 30.0s" let color_code = if duration_str.ends_with("ms") { @@ -379,4 +391,3 @@ impl UiWriter for ConsoleUiWriter { } } } - diff --git a/crates/g3-computer-control/build.rs b/crates/g3-computer-control/build.rs index d7d6f63..b7760a2 100644 --- a/crates/g3-computer-control/build.rs +++ b/crates/g3-computer-control/build.rs @@ -34,27 +34,40 @@ fn main() { .expect("Failed to find .build/release directory"); // Copy the dylib to the output directory so it can be found at runtime - let target_dir = manifest_dir.parent().unwrap().parent().unwrap().join("target"); + let target_dir = manifest_dir + .parent() + .unwrap() + .parent() + .unwrap() + .join("target"); let profile = env::var("PROFILE").unwrap_or_else(|_| "debug".to_string()); - + // Determine the actual target directory (could be llvm-cov-target or regular target) - let target_dir_name = env::var("CARGO_TARGET_DIR") - .unwrap_or_else(|_| target_dir.to_string_lossy().to_string()); + let target_dir_name = + env::var("CARGO_TARGET_DIR").unwrap_or_else(|_| target_dir.to_string_lossy().to_string()); let 
actual_target_dir = PathBuf::from(&target_dir_name); let output_dir = actual_target_dir.join(&profile); - + let dylib_src = lib_path.join("libVisionBridge.dylib"); let dylib_dst = output_dir.join("libVisionBridge.dylib"); - + // Create output directory if it doesn't exist - std::fs::create_dir_all(&output_dir) - .expect(&format!("Failed to create output directory {}", output_dir.display())); - - std::fs::copy(&dylib_src, &dylib_dst) - .expect(&format!("Failed to copy dylib from {} to {}", dylib_src.display(), dylib_dst.display())); - - println!("cargo:warning=Copied libVisionBridge.dylib to {}", dylib_dst.display()); - + std::fs::create_dir_all(&output_dir).expect(&format!( + "Failed to create output directory {}", + output_dir.display() + )); + + std::fs::copy(&dylib_src, &dylib_dst).expect(&format!( + "Failed to copy dylib from {} to {}", + dylib_src.display(), + dylib_dst.display() + )); + + println!( + "cargo:warning=Copied libVisionBridge.dylib to {}", + dylib_dst.display() + ); + // Add rpath so the dylib can be found at runtime println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path"); println!("cargo:rustc-link-arg=-Wl,-rpath,@loader_path"); @@ -68,5 +81,8 @@ fn main() { println!("cargo:rustc-link-lib=framework=CoreGraphics"); println!("cargo:rustc-link-lib=framework=CoreImage"); - println!("cargo:warning=VisionBridge built successfully at {}", lib_path.display()); + println!( + "cargo:warning=VisionBridge built successfully at {}", + lib_path.display() + ); } diff --git a/crates/g3-computer-control/examples/debug_screenshot.rs b/crates/g3-computer-control/examples/debug_screenshot.rs index 5e0930f..7ff7db7 100644 --- a/crates/g3-computer-control/examples/debug_screenshot.rs +++ b/crates/g3-computer-control/examples/debug_screenshot.rs @@ -3,19 +3,19 @@ use core_graphics::display::CGDisplay; fn main() { let display = CGDisplay::main(); let image = display.image().expect("Failed to capture screen"); - + println!("CGImage properties:"); println!(" Width: 
{}", image.width()); println!(" Height: {}", image.height()); println!(" Bits per component: {}", image.bits_per_component()); println!(" Bits per pixel: {}", image.bits_per_pixel()); println!(" Bytes per row: {}", image.bytes_per_row()); - + let data = image.data(); let expected_size = image.width() * image.height() * 4; println!(" Data length: {}", data.len()); println!(" Expected (w*h*4): {}", expected_size); - + // Check if there's padding in rows let bytes_per_row = image.bytes_per_row(); let width = image.width(); @@ -23,16 +23,25 @@ fn main() { println!("\nRow alignment:"); println!(" Actual bytes per row: {}", bytes_per_row); println!(" Expected (width * 4): {}", expected_bytes_per_row); - println!(" Padding per row: {}", bytes_per_row - expected_bytes_per_row); - + println!( + " Padding per row: {}", + bytes_per_row - expected_bytes_per_row + ); + // Sample some pixels from different locations println!("\nFirst 3 pixels (raw bytes):"); for i in 0..3 { let offset = i * 4; - println!(" Pixel {}: [{:3}, {:3}, {:3}, {:3}]", - i, data[offset], data[offset+1], data[offset+2], data[offset+3]); + println!( + " Pixel {}: [{:3}, {:3}, {:3}, {:3}]", + i, + data[offset], + data[offset + 1], + data[offset + 2], + data[offset + 3] + ); } - + // Check a pixel from the middle let mid_row = image.height() / 2; let mid_col = image.width() / 2; @@ -40,7 +49,12 @@ fn main() { println!("\nMiddle pixel (row {}, col {}):", mid_row, mid_col); println!(" Offset: {}", mid_offset); if mid_offset + 3 < data.len() as usize { - println!(" Bytes: [{:3}, {:3}, {:3}, {:3}]", - data[mid_offset], data[mid_offset+1], data[mid_offset+2], data[mid_offset+3]); + println!( + " Bytes: [{:3}, {:3}, {:3}, {:3}]", + data[mid_offset], + data[mid_offset + 1], + data[mid_offset + 2], + data[mid_offset + 3] + ); } } diff --git a/crates/g3-computer-control/examples/list_windows.rs b/crates/g3-computer-control/examples/list_windows.rs index f1681ff..983ab92 100644 --- 
a/crates/g3-computer-control/examples/list_windows.rs +++ b/crates/g3-computer-control/examples/list_windows.rs @@ -1,34 +1,38 @@ -use core_graphics::window::{kCGWindowListOptionOnScreenOnly, kCGNullWindowID, CGWindowListCopyWindowInfo}; +use core_foundation::base::{TCFType, ToVoid}; use core_foundation::dictionary::CFDictionary; use core_foundation::string::CFString; -use core_foundation::base::{TCFType, ToVoid}; +use core_graphics::window::{ + kCGNullWindowID, kCGWindowListOptionOnScreenOnly, CGWindowListCopyWindowInfo, +}; fn main() { println!("Listing all on-screen windows..."); println!("{:<10} {:<25} {}", "Window ID", "Owner", "Title"); println!("{}", "-".repeat(80)); - + unsafe { - let window_list = CGWindowListCopyWindowInfo( - kCGWindowListOptionOnScreenOnly, - kCGNullWindowID - ); - - let count = core_foundation::array::CFArray::::wrap_under_create_rule(window_list).len(); - let array = core_foundation::array::CFArray::::wrap_under_create_rule(window_list); - + let window_list = + CGWindowListCopyWindowInfo(kCGWindowListOptionOnScreenOnly, kCGNullWindowID); + + let count = + core_foundation::array::CFArray::::wrap_under_create_rule(window_list) + .len(); + let array = + core_foundation::array::CFArray::::wrap_under_create_rule(window_list); + for i in 0..count { let dict = array.get(i).unwrap(); - + // Get window ID let window_id_key = CFString::from_static_string("kCGWindowNumber"); let window_id: i64 = if let Some(value) = dict.find(window_id_key.to_void()) { - let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _); + let num: core_foundation::number::CFNumber = + TCFType::wrap_under_get_rule(*value as *const _); num.to_i64().unwrap_or(0) } else { 0 }; - + // Get owner name let owner_key = CFString::from_static_string("kCGWindowOwnerName"); let owner: String = if let Some(value) = dict.find(owner_key.to_void()) { @@ -37,7 +41,7 @@ fn main() { } else { "Unknown".to_string() }; - + // Get window name/title let 
name_key = CFString::from_static_string("kCGWindowName"); let title: String = if let Some(value) = dict.find(name_key.to_void()) { @@ -46,7 +50,7 @@ fn main() { } else { "".to_string() }; - + // Show all windows if !owner.is_empty() { println!("{:<10} {:<25} {}", window_id, owner, title); diff --git a/crates/g3-computer-control/examples/macax_demo.rs b/crates/g3-computer-control/examples/macax_demo.rs index ff1398d..1eefe34 100644 --- a/crates/g3-computer-control/examples/macax_demo.rs +++ b/crates/g3-computer-control/examples/macax_demo.rs @@ -11,11 +11,11 @@ use g3_computer_control::MacAxController; async fn main() -> Result<()> { println!("🍎 macOS Accessibility API Demo\n"); println!("This demo shows how to control macOS applications using the Accessibility API.\n"); - + // Create controller let controller = MacAxController::new()?; println!("✅ MacAxController initialized\n"); - + // List running applications println!("📱 Listing running applications:"); match controller.list_applications() { @@ -30,7 +30,7 @@ async fn main() -> Result<()> { Err(e) => println!(" ❌ Error: {}", e), } println!(); - + // Get frontmost app println!("🎯 Getting frontmost application:"); match controller.get_frontmost_app() { @@ -38,16 +38,16 @@ async fn main() -> Result<()> { Err(e) => println!(" ❌ Error: {}", e), } println!(); - + // Example: Activate Finder and get its UI tree println!("📂 Activating Finder and inspecting UI:"); match controller.activate_app("Finder") { Ok(_) => { println!(" ✅ Finder activated"); - + // Wait a moment for activation tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; - + // Get UI tree match controller.get_ui_tree("Finder", 2) { Ok(tree) => { @@ -62,13 +62,13 @@ async fn main() -> Result<()> { Err(e) => println!(" ❌ Error: {}", e), } println!(); - + println!("✨ Demo complete!\n"); println!("💡 Tips:"); println!(" - Use --macax flag with g3 to enable these tools"); println!(" - Grant accessibility permissions in System Preferences"); 
println!(" - Add accessibility identifiers to your apps for easier automation"); println!(" - See docs/macax-tools.md for full documentation\n"); - + Ok(()) } diff --git a/crates/g3-computer-control/examples/safari_demo.rs b/crates/g3-computer-control/examples/safari_demo.rs index b28ebd6..73535dc 100644 --- a/crates/g3-computer-control/examples/safari_demo.rs +++ b/crates/g3-computer-control/examples/safari_demo.rs @@ -1,64 +1,66 @@ -use g3_computer_control::SafariDriver; -use g3_computer_control::webdriver::WebDriverController; use anyhow::Result; +use g3_computer_control::webdriver::WebDriverController; +use g3_computer_control::SafariDriver; #[tokio::main] async fn main() -> Result<()> { println!("Safari WebDriver Demo"); println!("=====================\n"); - + println!("Make sure to:"); println!("1. Enable 'Allow Remote Automation' in Safari's Develop menu"); println!("2. Run: /usr/bin/safaridriver --enable"); println!("3. Start safaridriver in another terminal: safaridriver --port 4444\n"); - + println!("Connecting to SafariDriver..."); let mut driver = SafariDriver::new().await?; println!("✅ Connected!\n"); - + // Navigate to a website println!("Navigating to example.com..."); driver.navigate("https://example.com").await?; println!("✅ Navigated\n"); - + // Get page title let title = driver.title().await?; println!("Page title: {}\n", title); - + // Get current URL let url = driver.current_url().await?; println!("Current URL: {}\n", url); - + // Find an element println!("Finding h1 element..."); let h1 = driver.find_element("h1").await?; let h1_text = h1.text().await?; println!("H1 text: {}\n", h1_text); - + // Find all paragraphs println!("Finding all paragraphs..."); let paragraphs = driver.find_elements("p").await?; println!("Found {} paragraphs\n", paragraphs.len()); - + // Get page source println!("Getting page source..."); let source = driver.page_source().await?; println!("Page source length: {} bytes\n", source.len()); - + // Execute JavaScript 
println!("Executing JavaScript..."); - let result = driver.execute_script("return document.title", vec![]).await?; + let result = driver + .execute_script("return document.title", vec![]) + .await?; println!("JS result: {:?}\n", result); - + // Take a screenshot println!("Taking screenshot..."); driver.screenshot("/tmp/safari_demo.png").await?; println!("✅ Screenshot saved to /tmp/safari_demo.png\n"); - + // Close the browser println!("Closing browser..."); driver.quit().await?; println!("✅ Done!"); - + Ok(()) } diff --git a/crates/g3-computer-control/examples/test_permission_prompt.rs b/crates/g3-computer-control/examples/test_permission_prompt.rs index fdd5a4b..6d0f78d 100644 --- a/crates/g3-computer-control/examples/test_permission_prompt.rs +++ b/crates/g3-computer-control/examples/test_permission_prompt.rs @@ -3,10 +3,13 @@ use g3_computer_control::create_controller; #[tokio::main] async fn main() { println!("Testing screenshot with permission prompt..."); - + let controller = create_controller().expect("Failed to create controller"); - - match controller.take_screenshot("/tmp/test_with_prompt.png", None, None).await { + + match controller + .take_screenshot("/tmp/test_with_prompt.png", None, None) + .await + { Ok(_) => { println!("\n✅ Screenshot saved to /tmp/test_with_prompt.png"); println!("Opening screenshot..."); diff --git a/crates/g3-computer-control/examples/test_screencapture_direct.rs b/crates/g3-computer-control/examples/test_screencapture_direct.rs index 7b72125..6c1ce25 100644 --- a/crates/g3-computer-control/examples/test_screencapture_direct.rs +++ b/crates/g3-computer-control/examples/test_screencapture_direct.rs @@ -2,29 +2,33 @@ use std::process::Command; fn main() { let path = "/tmp/rust_screencapture_test.png"; - + println!("Testing screencapture command from Rust..."); - + let mut cmd = Command::new("screencapture"); cmd.arg("-x"); // No sound cmd.arg(path); - + println!("Command: {:?}", cmd); - + match cmd.output() { Ok(output) => { 
println!("Exit status: {}", output.status); println!("Stdout: {}", String::from_utf8_lossy(&output.stdout)); println!("Stderr: {}", String::from_utf8_lossy(&output.stderr)); - + if output.status.success() { println!("\n✅ Screenshot saved to: {}", path); - + // Check file exists and size if let Ok(metadata) = std::fs::metadata(path) { - println!("File size: {} bytes ({:.1} MB)", metadata.len(), metadata.len() as f64 / 1_000_000.0); + println!( + "File size: {} bytes ({:.1} MB)", + metadata.len(), + metadata.len() as f64 / 1_000_000.0 + ); } - + // Open it let _ = Command::new("open").arg(path).spawn(); println!("\nOpened screenshot - please verify it looks correct!"); diff --git a/crates/g3-computer-control/examples/test_screenshot_fix.rs b/crates/g3-computer-control/examples/test_screenshot_fix.rs index 467da49..beaffe3 100644 --- a/crates/g3-computer-control/examples/test_screenshot_fix.rs +++ b/crates/g3-computer-control/examples/test_screenshot_fix.rs @@ -4,17 +4,23 @@ use image::{ImageBuffer, RgbaImage}; fn main() { let display = CGDisplay::main(); let image = display.image().expect("Failed to capture screen"); - + let width = image.width() as u32; let height = image.height() as u32; let bytes_per_row = image.bytes_per_row() as usize; let data = image.data(); - + println!("Testing screenshot fix..."); - println!("Image: {}x{}, bytes_per_row: {}", width, height, bytes_per_row); + println!( + "Image: {}x{}, bytes_per_row: {}", + width, height, bytes_per_row + ); println!("Expected bytes per row: {}", width * 4); - println!("Padding per row: {} bytes", bytes_per_row - (width as usize * 4)); - + println!( + "Padding per row: {} bytes", + bytes_per_row - (width as usize * 4) + ); + // OLD METHOD (broken) - treating data as continuous println!("\n=== OLD METHOD (BROKEN) ==="); let mut old_rgba = Vec::with_capacity(data.len() as usize); @@ -26,14 +32,14 @@ fn main() { } println!("Converted {} pixels", old_rgba.len() / 4); println!("Expected {} pixels", width * 
height); - + // NEW METHOD (fixed) - handling row padding println!("\n=== NEW METHOD (FIXED) ==="); let mut new_rgba = Vec::with_capacity((width * height * 4) as usize); for row in 0..height as usize { let row_start = row * bytes_per_row; let row_end = row_start + (width as usize * 4); - + for chunk in data[row_start..row_end].chunks_exact(4) { new_rgba.push(chunk[2]); // R new_rgba.push(chunk[1]); // G @@ -43,26 +49,34 @@ fn main() { } println!("Converted {} pixels", new_rgba.len() / 4); println!("Expected {} pixels", width * height); - + // Save a small crop from both methods let crop_size = 200; - + // Old method crop - let old_crop: Vec = old_rgba.iter().take((crop_size * crop_size * 4) as usize).copied().collect(); + let old_crop: Vec = old_rgba + .iter() + .take((crop_size * crop_size * 4) as usize) + .copied() + .collect(); if let Some(old_img) = ImageBuffer::from_raw(crop_size, crop_size, old_crop) { let old_img: RgbaImage = old_img; old_img.save("/tmp/screenshot_old_method.png").unwrap(); println!("\nSaved OLD method crop to: /tmp/screenshot_old_method.png"); } - + // New method crop - let new_crop: Vec = new_rgba.iter().take((crop_size * crop_size * 4) as usize).copied().collect(); + let new_crop: Vec = new_rgba + .iter() + .take((crop_size * crop_size * 4) as usize) + .copied() + .collect(); if let Some(new_img) = ImageBuffer::from_raw(crop_size, crop_size, new_crop) { let new_img: RgbaImage = new_img; new_img.save("/tmp/screenshot_new_method.png").unwrap(); println!("Saved NEW method crop to: /tmp/screenshot_new_method.png"); } - + println!("\nOpen both images to compare:"); println!(" open /tmp/screenshot_old_method.png /tmp/screenshot_new_method.png"); } diff --git a/crates/g3-computer-control/examples/test_type_text.rs b/crates/g3-computer-control/examples/test_type_text.rs index 2d1aea0..eb5d416 100644 --- a/crates/g3-computer-control/examples/test_type_text.rs +++ b/crates/g3-computer-control/examples/test_type_text.rs @@ -6,43 +6,43 @@ use 
g3_computer_control::MacAxController; #[tokio::main] async fn main() -> Result<()> { println!("🧪 Testing macax type_text functionality\n"); - + let controller = MacAxController::new()?; println!("✅ Controller initialized\n"); - + // Test 1: Type simple text println!("Test 1: Typing simple text into TextEdit"); println!(" Please open TextEdit and create a new document..."); std::thread::sleep(std::time::Duration::from_secs(3)); - + match controller.type_text("TextEdit", "Hello, World!") { Ok(_) => println!(" ✅ Successfully typed simple text\n"), Err(e) => println!(" ❌ Failed: {}\n", e), } - + std::thread::sleep(std::time::Duration::from_secs(1)); - + // Test 2: Type unicode and emojis println!("Test 2: Typing unicode and emojis"); match controller.type_text("TextEdit", "\n🌟 Unicode test: café, naïve, 日本語 🎉") { Ok(_) => println!(" ✅ Successfully typed unicode text\n"), Err(e) => println!(" ❌ Failed: {}\n", e), } - + std::thread::sleep(std::time::Duration::from_secs(1)); - + // Test 3: Type special characters println!("Test 3: Typing special characters"); match controller.type_text("TextEdit", "\nSpecial: @#$%^&*()_+-=[]{}|;':,.<>?/") { Ok(_) => println!(" ✅ Successfully typed special characters\n"), Err(e) => println!(" ❌ Failed: {}\n", e), } - + println!("\n✨ Tests complete!"); println!("\n💡 Now try with Things3:"); println!(" 1. Open Things3"); println!(" 2. Press Cmd+N to create a new task"); println!(" 3. 
Run: g3 --macax 'type \"🌟 My awesome task\" into Things'"); - + Ok(()) } diff --git a/crates/g3-computer-control/examples/test_vision.rs b/crates/g3-computer-control/examples/test_vision.rs index 5ff09a5..3b65f93 100644 --- a/crates/g3-computer-control/examples/test_vision.rs +++ b/crates/g3-computer-control/examples/test_vision.rs @@ -1,63 +1,67 @@ -use g3_computer_control::ocr::{OCREngine, DefaultOCR}; use anyhow::Result; +use g3_computer_control::ocr::{DefaultOCR, OCREngine}; #[tokio::main] async fn main() -> Result<()> { println!("🧪 Testing Apple Vision OCR"); println!("===========================\n"); - + // Initialize OCR engine println!("📦 Initializing OCR engine..."); let ocr = DefaultOCR::new()?; println!("✅ OCR engine: {}\n", ocr.name()); - + // Check if test image exists let test_image = "/tmp/safari_test.png"; if !std::path::Path::new(test_image).exists() { println!("⚠️ Test image not found: {}", test_image); println!(" Creating a screenshot..."); - + let status = std::process::Command::new("screencapture") .arg("-x") .arg("-R") .arg("0,0,1200,800") .arg(test_image) .status()?; - + if !status.success() { anyhow::bail!("Failed to create screenshot"); } - + println!("✅ Screenshot created\n"); } - + // Run OCR println!("🔍 Running Apple Vision OCR on {}...", test_image); let start = std::time::Instant::now(); let locations = ocr.extract_text_with_locations(test_image).await?; let duration = start.elapsed(); - + println!("✅ OCR completed in {:.3}s\n", duration.as_secs_f64()); - + // Display results println!("📊 Results:"); println!(" Found {} text elements\n", locations.len()); - + if locations.is_empty() { println!("⚠️ No text found in image"); } else { println!(" Top 20 results:"); - println!(" {:<4} {:<40} {:<15} {:<12} {:<8}", "#", "Text", "Position", "Size", "Conf"); + println!( + " {:<4} {:<40} {:<15} {:<12} {:<8}", + "#", "Text", "Position", "Size", "Conf" + ); println!(" {}", "-".repeat(85)); - + for (i, loc) in locations.iter().take(20).enumerate() { 
let text = if loc.text.len() > 37 { format!("{}...", &loc.text[..37]) } else { loc.text.clone() }; - - println!(" {:<4} {:<40} ({:>4},{:>4}) {:>4}x{:<4} {:.2}", + + println!( + " {:<4} {:<40} ({:>4},{:>4}) {:>4}x{:<4} {:.2}", i + 1, text, loc.x, @@ -67,19 +71,22 @@ async fn main() -> Result<()> { loc.confidence ); } - + if locations.len() > 20 { println!("\n ... and {} more", locations.len() - 20); } - + // Performance comparison println!("\n📈 Performance:"); println!(" OCR Speed: {:.3}s", duration.as_secs_f64()); println!(" Text elements: {}", locations.len()); - println!(" Avg per element: {:.1}ms", duration.as_millis() as f64 / locations.len() as f64); + println!( + " Avg per element: {:.1}ms", + duration.as_millis() as f64 / locations.len() as f64 + ); } - + println!("\n✅ Test complete!"); - + Ok(()) } diff --git a/crates/g3-computer-control/examples/test_window_capture.rs b/crates/g3-computer-control/examples/test_window_capture.rs index 51ce5bc..dc075ca 100644 --- a/crates/g3-computer-control/examples/test_window_capture.rs +++ b/crates/g3-computer-control/examples/test_window_capture.rs @@ -3,36 +3,46 @@ use g3_computer_control::create_controller; #[tokio::main] async fn main() { println!("Testing window-specific screenshot capture..."); - + let controller = create_controller().expect("Failed to create controller"); - + // Test 1: Capture iTerm2 window println!("\n1. 
Capturing iTerm2 window..."); - match controller.take_screenshot("/tmp/iterm_window.png", None, Some("iTerm2")).await { + match controller + .take_screenshot("/tmp/iterm_window.png", None, Some("iTerm2")) + .await + { Ok(_) => { println!(" ✅ iTerm2 window captured to /tmp/iterm_window.png"); - let _ = std::process::Command::new("open").arg("/tmp/iterm_window.png").spawn(); + let _ = std::process::Command::new("open") + .arg("/tmp/iterm_window.png") + .spawn(); } Err(e) => println!(" ❌ Failed: {}", e), } - + // Wait a moment for the image to open tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; - + // Test 2: Full screen capture for comparison println!("\n2. Capturing full screen for comparison..."); - match controller.take_screenshot("/tmp/fullscreen.png", None, None).await { + match controller + .take_screenshot("/tmp/fullscreen.png", None, None) + .await + { Ok(_) => { println!(" ✅ Full screen captured to /tmp/fullscreen.png"); - let _ = std::process::Command::new("open").arg("/tmp/fullscreen.png").spawn(); + let _ = std::process::Command::new("open") + .arg("/tmp/fullscreen.png") + .spawn(); } Err(e) => println!(" ❌ Failed: {}", e), } - + println!("\n=== Comparison ==="); println!("iTerm window: /tmp/iterm_window.png (should show ONLY iTerm window)"); println!("Full screen: /tmp/fullscreen.png (should show entire desktop)"); - + // Show file sizes if let Ok(meta1) = std::fs::metadata("/tmp/iterm_window.png") { if let Ok(meta2) = std::fs::metadata("/tmp/fullscreen.png") { diff --git a/crates/g3-computer-control/src/lib.rs b/crates/g3-computer-control/src/lib.rs index b1cbc36..8133373 100644 --- a/crates/g3-computer-control/src/lib.rs +++ b/crates/g3-computer-control/src/lib.rs @@ -1,17 +1,17 @@ // Suppress warnings from objc crate macros #![allow(unexpected_cfgs)] -pub mod types; -pub mod platform; -pub mod ocr; -pub mod webdriver; pub mod macax; +pub mod ocr; +pub mod platform; +pub mod types; +pub mod webdriver; // Re-export webdriver types for 
convenience -pub use webdriver::{WebDriverController, WebElement, safari::SafariDriver}; +pub use webdriver::{safari::SafariDriver, WebDriverController, WebElement}; // Re-export macax types for convenience -pub use macax::{MacAxController, AXElement, AXApplication}; +pub use macax::{AXApplication, AXElement, MacAxController}; use anyhow::Result; use async_trait::async_trait; @@ -20,14 +20,23 @@ use types::*; #[async_trait] pub trait ComputerController: Send + Sync { // Screen capture - async fn take_screenshot(&self, path: &str, region: Option, window_id: Option<&str>) -> Result<()>; - + async fn take_screenshot( + &self, + path: &str, + region: Option, + window_id: Option<&str>, + ) -> Result<()>; + // OCR operations async fn extract_text_from_screen(&self, region: Rect, window_id: &str) -> Result; async fn extract_text_from_image(&self, path: &str) -> Result; async fn extract_text_with_locations(&self, path: &str) -> Result>; - async fn find_text_in_app(&self, app_name: &str, search_text: &str) -> Result>; - + async fn find_text_in_app( + &self, + app_name: &str, + search_text: &str, + ) -> Result>; + // Mouse operations fn move_mouse(&self, x: i32, y: i32) -> Result<()>; fn click_at(&self, x: i32, y: i32, app_name: Option<&str>) -> Result<()>; @@ -37,13 +46,13 @@ pub trait ComputerController: Send + Sync { pub fn create_controller() -> Result> { #[cfg(target_os = "macos")] return Ok(Box::new(platform::macos::MacOSController::new()?)); - + #[cfg(target_os = "linux")] return Ok(Box::new(platform::linux::LinuxController::new()?)); - + #[cfg(target_os = "windows")] return Ok(Box::new(platform::windows::WindowsController::new()?)); - + #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))] anyhow::bail!("Unsupported platform") } diff --git a/crates/g3-computer-control/src/macax/controller.rs b/crates/g3-computer-control/src/macax/controller.rs index ac91ac1..aa3d529 100644 --- a/crates/g3-computer-control/src/macax/controller.rs +++ 
b/crates/g3-computer-control/src/macax/controller.rs @@ -3,7 +3,9 @@ use anyhow::{Context, Result}; use std::collections::HashMap; #[cfg(target_os = "macos")] -use accessibility::{AXUIElement, AXUIElementAttributes, ElementFinder, TreeVisitor, TreeWalker, TreeWalkerFlow}; +use accessibility::{ + AXUIElement, AXUIElementAttributes, ElementFinder, TreeVisitor, TreeWalker, TreeWalkerFlow, +}; #[cfg(target_os = "macos")] use core_foundation::base::TCFType; @@ -23,46 +25,46 @@ impl MacAxController { { // Check if we have accessibility permissions by trying to get system-wide element let _system = AXUIElement::system_wide(); - + Ok(Self { app_cache: std::sync::Mutex::new(HashMap::new()), }) } - + #[cfg(not(target_os = "macos"))] { anyhow::bail!("macOS Accessibility API is only available on macOS") } } - + /// List all running applications #[cfg(target_os = "macos")] pub fn list_applications(&self) -> Result> { let apps = Self::get_running_applications()?; Ok(apps) } - + #[cfg(not(target_os = "macos"))] pub fn list_applications(&self) -> Result> { anyhow::bail!("Not supported on this platform") } - + #[cfg(target_os = "macos")] fn get_running_applications() -> Result> { use cocoa::appkit::NSApplicationActivationPolicy; use cocoa::base::{id, nil}; use objc::{class, msg_send, sel, sel_impl}; - + unsafe { let workspace: id = msg_send![class!(NSWorkspace), sharedWorkspace]; let running_apps: id = msg_send![workspace, runningApplications]; let count: usize = msg_send![running_apps, count]; - + let mut apps = Vec::new(); - + for i in 0..count { let app: id = msg_send![running_apps, objectAtIndex: i]; - + // Get app name let localized_name: id = msg_send![app, localizedName]; if localized_name == nil { @@ -76,7 +78,7 @@ impl MacAxController { } else { continue; }; - + // Get bundle ID let bundle_id_obj: id = msg_send![app, bundleIdentifier]; let bundle_id = if bundle_id_obj != nil { @@ -93,13 +95,15 @@ impl MacAxController { } else { None }; - + // Get PID let pid: i32 = 
msg_send![app, processIdentifier]; - + // Skip background-only apps let activation_policy: i64 = msg_send![app, activationPolicy]; - if activation_policy == NSApplicationActivationPolicy::NSApplicationActivationPolicyRegular as i64 { + if activation_policy + == NSApplicationActivationPolicy::NSApplicationActivationPolicyRegular as i64 + { apps.push(AXApplication { name, bundle_id, @@ -107,32 +111,32 @@ impl MacAxController { }); } } - + Ok(apps) } } - + /// Get the frontmost (active) application #[cfg(target_os = "macos")] pub fn get_frontmost_app(&self) -> Result { use cocoa::base::{id, nil}; use objc::{class, msg_send, sel, sel_impl}; - + unsafe { let workspace: id = msg_send![class!(NSWorkspace), sharedWorkspace]; let frontmost_app: id = msg_send![workspace, frontmostApplication]; - + if frontmost_app == nil { anyhow::bail!("No frontmost application"); } - + // Get app name let localized_name: id = msg_send![frontmost_app, localizedName]; let name_ptr: *const i8 = msg_send![localized_name, UTF8String]; let name = std::ffi::CStr::from_ptr(name_ptr) .to_string_lossy() .to_string(); - + // Get bundle ID let bundle_id_obj: id = msg_send![frontmost_app, bundleIdentifier]; let bundle_id = if bundle_id_obj != nil { @@ -149,10 +153,10 @@ impl MacAxController { } else { None }; - + // Get PID let pid: i32 = msg_send![frontmost_app, processIdentifier]; - + Ok(AXApplication { name, bundle_id, @@ -160,12 +164,12 @@ impl MacAxController { }) } } - + #[cfg(not(target_os = "macos"))] pub fn get_frontmost_app(&self) -> Result { anyhow::bail!("Not supported on this platform") } - + /// Get AXUIElement for an application by name or PID #[cfg(target_os = "macos")] fn get_app_element(&self, app_name: &str) -> Result { @@ -176,79 +180,79 @@ impl MacAxController { return Ok(element.clone()); } } - + // Find the app by name let apps = Self::get_running_applications()?; let app = apps .iter() .find(|a| a.name == app_name) .ok_or_else(|| anyhow::anyhow!("Application '{}' not found", 
app_name))?; - + // Create AXUIElement for the app let element = AXUIElement::application(app.pid); - + // Cache it { let mut cache = self.app_cache.lock().unwrap(); cache.insert(app_name.to_string(), element.clone()); } - + Ok(element) } - + /// Activate (bring to front) an application #[cfg(target_os = "macos")] pub fn activate_app(&self, app_name: &str) -> Result<()> { use cocoa::base::id; use objc::{class, msg_send, sel, sel_impl}; - + // Find the app let apps = Self::get_running_applications()?; let app = apps .iter() .find(|a| a.name == app_name) .ok_or_else(|| anyhow::anyhow!("Application '{}' not found", app_name))?; - + unsafe { let workspace: id = msg_send![class!(NSWorkspace), sharedWorkspace]; let running_apps: id = msg_send![workspace, runningApplications]; let count: usize = msg_send![running_apps, count]; - + for i in 0..count { let running_app: id = msg_send![running_apps, objectAtIndex: i]; let pid: i32 = msg_send![running_app, processIdentifier]; - + if pid == app.pid { let _: bool = msg_send![running_app, activateWithOptions: 0]; return Ok(()); } } } - + anyhow::bail!("Failed to activate application") } - + #[cfg(not(target_os = "macos"))] pub fn activate_app(&self, _app_name: &str) -> Result<()> { anyhow::bail!("Not supported on this platform") } - + /// Get the UI hierarchy of an application #[cfg(target_os = "macos")] pub fn get_ui_tree(&self, app_name: &str, max_depth: usize) -> Result { let app_element = self.get_app_element(app_name)?; let mut output = format!("Application: {}\n", app_name); - + Self::build_ui_tree(&app_element, &mut output, 0, max_depth)?; - + Ok(output) } - + #[cfg(not(target_os = "macos"))] pub fn get_ui_tree(&self, _app_name: &str, _max_depth: usize) -> Result { anyhow::bail!("Not supported on this platform") } - + #[cfg(target_os = "macos")] fn build_ui_tree( element: &AXUIElement, @@ -259,21 +263,22 @@ impl MacAxController { if depth >= max_depth { return Ok(()); } - + let indent = " ".repeat(depth); - + // Get role - 
let role = element.role().ok().map(|s| s.to_string()) + let role = element + .role() + .ok() + .map(|s| s.to_string()) .unwrap_or_else(|| "Unknown".to_string()); - + // Get title - let title = element.title().ok() - .map(|s| s.to_string()); - + let title = element.title().ok().map(|s| s.to_string()); + // Get identifier - let identifier = element.identifier().ok() - .map(|s| s.to_string()); - + let identifier = element.identifier().ok().map(|s| s.to_string()); + // Format output output.push_str(&format!("{}Role: {}", indent, role)); if let Some(t) = title { @@ -283,7 +288,7 @@ impl MacAxController { output.push_str(&format!(", ID: {}", id)); } output.push('\n'); - + // Get children if let Ok(children) = element.children() { for i in 0..children.len() { @@ -292,10 +297,10 @@ impl MacAxController { } } } - + Ok(()) } - + /// Find UI elements in an application #[cfg(target_os = "macos")] pub fn find_elements( @@ -307,7 +312,7 @@ impl MacAxController { ) -> Result> { let app_element = self.get_app_element(app_name)?; let mut found_elements = Vec::new(); - + let visitor = ElementCollector { role_filter: role.map(|s| s.to_string()), title_filter: title.map(|s| s.to_string()), @@ -315,13 +320,13 @@ impl MacAxController { results: std::cell::RefCell::new(&mut found_elements), depth: std::cell::Cell::new(0), }; - + let walker = TreeWalker::new(); walker.walk(&app_element, &visitor); - + Ok(found_elements) } - + #[cfg(not(target_os = "macos"))] pub fn find_elements( &self, @@ -332,7 +337,7 @@ impl MacAxController { ) -> Result> { anyhow::bail!("Not supported on this platform") } - + /// Find a single element (helper for click, set_value, etc.) 
#[cfg(target_os = "macos")] fn find_element( @@ -343,19 +348,17 @@ impl MacAxController { identifier: Option<&str>, ) -> Result { let app_element = self.get_app_element(app_name)?; - + let role_str = role.to_string(); let title_str = title.map(|s| s.to_string()); let identifier_str = identifier.map(|s| s.to_string()); - + let finder = ElementFinder::new( &app_element, move |element| { // Check role - let elem_role = element.role() - .ok() - .map(|s| s.to_string()); - + let elem_role = element.role().ok().map(|s| s.to_string()); + if let Some(r) = elem_role { if !r.contains(&role_str) { return false; @@ -363,13 +366,11 @@ impl MacAxController { } else { return false; } - + // Check title if specified if let Some(ref title_filter) = title_str { - let elem_title = element.title() - .ok() - .map(|s| s.to_string()); - + let elem_title = element.title().ok().map(|s| s.to_string()); + if let Some(t) = elem_title { if !t.contains(title_filter) { return false; @@ -378,13 +379,11 @@ impl MacAxController { return false; } } - + // Check identifier if specified if let Some(ref id_filter) = identifier_str { - let elem_id = element.identifier() - .ok() - .map(|s| s.to_string()); - + let elem_id = element.identifier().ok().map(|s| s.to_string()); + if let Some(id) = elem_id { if !id.contains(id_filter) { return false; @@ -393,15 +392,15 @@ impl MacAxController { return false; } } - + true }, Some(std::time::Duration::from_secs(2)), ); - + finder.find().context("Element not found") } - + /// Click on a UI element #[cfg(target_os = "macos")] pub fn click_element( @@ -412,16 +411,16 @@ impl MacAxController { identifier: Option<&str>, ) -> Result<()> { let element = self.find_element(app_name, role, title, identifier)?; - + // Perform the press action let action_name = CFString::new("AXPress"); element .perform_action(&action_name) .map_err(|e| anyhow::anyhow!("Failed to perform press action: {:?}", e))?; - + Ok(()) } - + #[cfg(not(target_os = "macos"))] pub fn click_element( &self, 
@@ -432,7 +431,7 @@ impl MacAxController { ) -> Result<()> { anyhow::bail!("Not supported on this platform") } - + /// Set the value of a UI element #[cfg(target_os = "macos")] pub fn set_value( @@ -444,16 +443,17 @@ impl MacAxController { identifier: Option<&str>, ) -> Result<()> { let element = self.find_element(app_name, role, title, identifier)?; - + // Set the value - convert CFString to CFType let cf_value = CFString::new(value); - - element.set_value(cf_value.as_CFType()) + + element + .set_value(cf_value.as_CFType()) .map_err(|e| anyhow::anyhow!("Failed to set value: {:?}", e))?; - + Ok(()) } - + #[cfg(not(target_os = "macos"))] pub fn set_value( &self, @@ -465,7 +465,7 @@ impl MacAxController { ) -> Result<()> { anyhow::bail!("Not supported on this platform") } - + /// Get the value of a UI element #[cfg(target_os = "macos")] pub fn get_value( @@ -476,11 +476,12 @@ impl MacAxController { identifier: Option<&str>, ) -> Result { let element = self.find_element(app_name, role, title, identifier)?; - + // Get the value - let value_type = element.value() + let value_type = element + .value() .map_err(|e| anyhow::anyhow!("Failed to get value: {:?}", e))?; - + // Try to downcast to CFString if let Some(cf_string) = value_type.downcast::() { Ok(cf_string.to_string()) @@ -489,7 +490,7 @@ impl MacAxController { Ok(format!("")) } } - + #[cfg(not(target_os = "macos"))] pub fn get_value( &self, @@ -500,52 +501,52 @@ impl MacAxController { ) -> Result { anyhow::bail!("Not supported on this platform") } - + /// Type text into the currently focused element (uses system text input) #[cfg(target_os = "macos")] pub fn type_text(&self, app_name: &str, text: &str) -> Result<()> { use cocoa::base::{id, nil}; use cocoa::foundation::NSString; use objc::{class, msg_send, sel, sel_impl}; - + // First, make sure the app is active self.activate_app(app_name)?; - + // Wait for app to fully activate std::thread::sleep(std::time::Duration::from_millis(500)); - + // Send a Tab key to try 
to focus on a text field // This helps ensure something is focused before we paste let _ = self.press_key(app_name, "tab", vec![]); std::thread::sleep(std::time::Duration::from_millis(800)); - + // Save old clipboard, set new content, paste, then restore let old_content: id; unsafe { // Get the general pasteboard let pasteboard: id = msg_send![class!(NSPasteboard), generalPasteboard]; - + // Save current clipboard content let ns_string_type = NSString::alloc(nil).init_str("public.utf8-plain-text"); old_content = msg_send![pasteboard, stringForType: ns_string_type]; - + // Clear and set new content let _: () = msg_send![pasteboard, clearContents]; - + let ns_string = NSString::alloc(nil).init_str(text); let ns_type = NSString::alloc(nil).init_str("public.utf8-plain-text"); let _: bool = msg_send![pasteboard, setString:ns_string forType:ns_type]; } - + // Wait a moment for clipboard to update std::thread::sleep(std::time::Duration::from_millis(200)); - + // Paste using Cmd+V (outside unsafe block) self.press_key(app_name, "v", vec!["command"])?; - + // Wait for paste to complete std::thread::sleep(std::time::Duration::from_millis(300)); - + // Restore old clipboard content if it existed unsafe { if old_content != nil { @@ -555,15 +556,15 @@ impl MacAxController { let _: bool = msg_send![pasteboard, setString:old_content forType:ns_type]; } } - + Ok(()) } - + #[cfg(not(target_os = "macos"))] pub fn type_text(&self, _app_name: &str, _text: &str) -> Result<()> { anyhow::bail!("Not supported on this platform") } - + /// Focus on a text field or text area element #[cfg(target_os = "macos")] pub fn focus_element( @@ -574,40 +575,34 @@ impl MacAxController { identifier: Option<&str>, ) -> Result<()> { let element = self.find_element(app_name, role, title, identifier)?; - + // Set focused attribute to true use core_foundation::boolean::CFBoolean; let cf_true = CFBoolean::true_value(); - - element.set_attribute(&accessibility::AXAttribute::focused(), cf_true) + + element + 
.set_attribute(&accessibility::AXAttribute::focused(), cf_true) .map_err(|e| anyhow::anyhow!("Failed to focus element: {:?}", e))?; - + Ok(()) } - + /// Press a keyboard shortcut #[cfg(target_os = "macos")] - pub fn press_key( - &self, - app_name: &str, - key: &str, - modifiers: Vec<&str>, - ) -> Result<()> { - use core_graphics::event::{ - CGEvent, CGEventFlags, CGEventTapLocation, - }; + pub fn press_key(&self, app_name: &str, key: &str, modifiers: Vec<&str>) -> Result<()> { + use core_graphics::event::{CGEvent, CGEventFlags, CGEventTapLocation}; use core_graphics::event_source::{CGEventSource, CGEventSourceStateID}; - + // First, make sure the app is active self.activate_app(app_name)?; - + // Wait a bit for activation std::thread::sleep(std::time::Duration::from_millis(100)); - + // Map key string to key code - let key_code = Self::key_to_keycode(key) - .ok_or_else(|| anyhow::anyhow!("Unknown key: {}", key))?; - + let key_code = + Self::key_to_keycode(key).ok_or_else(|| anyhow::anyhow!("Unknown key: {}", key))?; + // Map modifiers to flags let mut flags = CGEventFlags::CGEventFlagNull; for modifier in modifiers { @@ -619,39 +614,37 @@ impl MacAxController { _ => {} } } - + // Create event source let source = CGEventSource::new(CGEventSourceStateID::HIDSystemState) - .ok().context("Failed to create event source")?; - + .ok() + .context("Failed to create event source")?; + // Create key down event let key_down = CGEvent::new_keyboard_event(source.clone(), key_code, true) - .ok().context("Failed to create key down event")?; + .ok() + .context("Failed to create key down event")?; key_down.set_flags(flags); - + // Create key up event let key_up = CGEvent::new_keyboard_event(source, key_code, false) - .ok().context("Failed to create key up event")?; + .ok() + .context("Failed to create key up event")?; key_up.set_flags(flags); - + // Post events key_down.post(CGEventTapLocation::HID); std::thread::sleep(std::time::Duration::from_millis(50)); 
key_up.post(CGEventTapLocation::HID); - + Ok(()) } - + #[cfg(not(target_os = "macos"))] - pub fn press_key( - &self, - _app_name: &str, - _key: &str, - _modifiers: Vec<&str>, - ) -> Result<()> { + pub fn press_key(&self, _app_name: &str, _key: &str, _modifiers: Vec<&str>) -> Result<()> { anyhow::bail!("Not supported on this platform") } - + #[cfg(target_os = "macos")] fn key_to_keycode(key: &str) -> Option { // Map common keys to keycodes @@ -743,62 +736,55 @@ struct ElementCollector<'a> { impl<'a> TreeVisitor for ElementCollector<'a> { fn enter_element(&self, element: &AXUIElement) -> TreeWalkerFlow { self.depth.set(self.depth.get() + 1); - + if self.depth.get() > 20 { return TreeWalkerFlow::SkipSubtree; } - + // Get element properties - let role = element.role() + let role = element + .role() .ok() .map(|s| s.to_string()) .unwrap_or_else(|| "Unknown".to_string()); - - let title = element.title() - .ok() - .map(|s| s.to_string()); - - let identifier = element.identifier() - .ok() - .map(|s| s.to_string()); - + + let title = element.title().ok().map(|s| s.to_string()); + + let identifier = element.identifier().ok().map(|s| s.to_string()); + // Check if this element matches the filters let role_matches = self.role_filter.as_ref().map_or(true, |r| role.contains(r)); let title_matches = self.title_filter.as_ref().map_or(true, |t| { - title.as_ref().map_or(false, |title_str| title_str.contains(t)) + title + .as_ref() + .map_or(false, |title_str| title_str.contains(t)) }); let identifier_matches = self.identifier_filter.as_ref().map_or(true, |id| { - identifier.as_ref().map_or(false, |id_str| id_str.contains(id)) + identifier + .as_ref() + .map_or(false, |id_str| id_str.contains(id)) }); - + if role_matches && title_matches && identifier_matches { // Get additional properties - let value = element.value() + let value = element + .value() .ok() - .and_then(|v| { - v.downcast::().map(|s| s.to_string()) - }); - - let label = element.description() - .ok() - .map(|s| 
s.to_string()); - - let enabled = element.enabled() - .ok() - .map(|b| b.into()) - .unwrap_or(false); - - let focused = element.focused() - .ok() - .map(|b| b.into()) - .unwrap_or(false); - + .and_then(|v| v.downcast::().map(|s| s.to_string())); + + let label = element.description().ok().map(|s| s.to_string()); + + let enabled = element.enabled().ok().map(|b| b.into()).unwrap_or(false); + + let focused = element.focused().ok().map(|b| b.into()).unwrap_or(false); + // Count children - let children_count = element.children() + let children_count = element + .children() .ok() .map(|arr| arr.len() as usize) .unwrap_or(0); - + self.results.borrow_mut().push(AXElement { role, title, @@ -812,10 +798,10 @@ impl<'a> TreeVisitor for ElementCollector<'a> { children_count, }); } - + TreeWalkerFlow::Continue } - + fn exit_element(&self, _element: &AXUIElement) { self.depth.set(self.depth.get() - 1); } diff --git a/crates/g3-computer-control/src/macax/mod.rs b/crates/g3-computer-control/src/macax/mod.rs index b62e87d..afe29a2 100644 --- a/crates/g3-computer-control/src/macax/mod.rs +++ b/crates/g3-computer-control/src/macax/mod.rs @@ -34,7 +34,7 @@ impl AXElement { /// Convert to a human-readable string representation pub fn to_string(&self) -> String { let mut parts = vec![format!("Role: {}", self.role)]; - + if let Some(ref title) = self.title { parts.push(format!("Title: {}", title)); } @@ -47,19 +47,19 @@ impl AXElement { if let Some(ref id) = self.identifier { parts.push(format!("ID: {}", id)); } - + parts.push(format!("Enabled: {}", self.enabled)); parts.push(format!("Focused: {}", self.focused)); - + if let Some((x, y)) = self.position { parts.push(format!("Position: ({:.0}, {:.0})", x, y)); } if let Some((w, h)) = self.size { parts.push(format!("Size: ({:.0}, {:.0})", w, h)); } - + parts.push(format!("Children: {}", self.children_count)); - + parts.join(", ") } } diff --git a/crates/g3-computer-control/src/ocr/mod.rs b/crates/g3-computer-control/src/ocr/mod.rs index 
b651da3..a5c59d4 100644 --- a/crates/g3-computer-control/src/ocr/mod.rs +++ b/crates/g3-computer-control/src/ocr/mod.rs @@ -7,7 +7,7 @@ use async_trait::async_trait; pub trait OCREngine: Send + Sync { /// Extract text with locations from an image file async fn extract_text_with_locations(&self, path: &str) -> Result>; - + /// Get the name of the OCR engine fn name(&self) -> &str; } diff --git a/crates/g3-computer-control/src/ocr/tesseract.rs b/crates/g3-computer-control/src/ocr/tesseract.rs index d55fc3f..7c11129 100644 --- a/crates/g3-computer-control/src/ocr/tesseract.rs +++ b/crates/g3-computer-control/src/ocr/tesseract.rs @@ -12,16 +12,18 @@ impl TesseractOCR { let tesseract_check = std::process::Command::new("which") .arg("tesseract") .output(); - + if tesseract_check.is_err() || !tesseract_check.as_ref().unwrap().status.success() { - anyhow::bail!("Tesseract OCR is not installed on your system.\n\n\ + anyhow::bail!( + "Tesseract OCR is not installed on your system.\n\n\ To install tesseract:\n macOS: brew install tesseract\n \ Linux: sudo apt-get install tesseract-ocr (Ubuntu/Debian)\n \ sudo yum install tesseract (RHEL/CentOS)\n \ Windows: Download from https://github.com/UB-Mannheim/tesseract/wiki\n\n\ - After installation, restart your terminal and try again."); + After installation, restart your terminal and try again." 
+ ); } - + Ok(Self) } } @@ -36,18 +38,23 @@ impl OCREngine for TesseractOCR { .arg("tsv") .output() .map_err(|e| anyhow::anyhow!("Failed to run tesseract: {}", e))?; - + if !output.status.success() { - anyhow::bail!("Tesseract failed: {}", String::from_utf8_lossy(&output.stderr)); + anyhow::bail!( + "Tesseract failed: {}", + String::from_utf8_lossy(&output.stderr) + ); } - + let tsv_text = String::from_utf8_lossy(&output.stdout); let mut locations = Vec::new(); - + // Parse TSV output (skip header line) for (i, line) in tsv_text.lines().enumerate() { - if i == 0 { continue; } // Skip header - + if i == 0 { + continue; + } // Skip header + let parts: Vec<&str> = line.split('\t').collect(); if parts.len() >= 12 { // TSV format: level, page_num, block_num, par_num, line_num, word_num, @@ -74,10 +81,10 @@ impl OCREngine for TesseractOCR { } } } - + Ok(locations) } - + fn name(&self) -> &str { "Tesseract OCR" } diff --git a/crates/g3-computer-control/src/ocr/vision.rs b/crates/g3-computer-control/src/ocr/vision.rs index d35491d..acc93e9 100644 --- a/crates/g3-computer-control/src/ocr/vision.rs +++ b/crates/g3-computer-control/src/ocr/vision.rs @@ -1,6 +1,6 @@ use super::OCREngine; use crate::types::TextLocation; -use anyhow::{Result, Context}; +use anyhow::{Context, Result}; use async_trait::async_trait; use std::ffi::{CStr, CString}; use std::os::raw::{c_char, c_float, c_uint}; @@ -24,7 +24,7 @@ extern "C" { out_boxes: *mut *mut std::ffi::c_void, out_count: *mut c_uint, ) -> bool; - + fn vision_free_boxes(boxes: *mut std::ffi::c_void, count: c_uint); } @@ -41,12 +41,11 @@ impl AppleVisionOCR { impl OCREngine for AppleVisionOCR { async fn extract_text_with_locations(&self, path: &str) -> Result> { // Convert path to C string - let c_path = CString::new(path) - .context("Failed to convert path to C string")?; - + let c_path = CString::new(path).context("Failed to convert path to C string")?; + let mut boxes_ptr: *mut std::ffi::c_void = std::ptr::null_mut(); let mut 
count: c_uint = 0; - + // Call Swift Vision API let success = unsafe { vision_recognize_text( @@ -56,28 +55,26 @@ impl OCREngine for AppleVisionOCR { &mut count, ) }; - + if !success || boxes_ptr.is_null() { anyhow::bail!("Apple Vision OCR failed"); } - + // Convert C array to Rust Vec let mut locations = Vec::new(); - + unsafe { let typed_boxes = boxes_ptr as *const VisionTextBox; let boxes_slice = std::slice::from_raw_parts(typed_boxes, count as usize); - + for box_data in boxes_slice { // Convert C string to Rust String let text = if !box_data.text.is_null() { - CStr::from_ptr(box_data.text) - .to_string_lossy() - .into_owned() + CStr::from_ptr(box_data.text).to_string_lossy().into_owned() } else { String::new() }; - + if !text.is_empty() { locations.push(TextLocation { text, @@ -89,14 +86,14 @@ impl OCREngine for AppleVisionOCR { }); } } - + // Free the C array vision_free_boxes(boxes_ptr, count); } - + Ok(locations) } - + fn name(&self) -> &str { "Apple Vision Framework" } diff --git a/crates/g3-computer-control/src/platform/linux.rs b/crates/g3-computer-control/src/platform/linux.rs index cf485ed..cdaf64e 100644 --- a/crates/g3-computer-control/src/platform/linux.rs +++ b/crates/g3-computer-control/src/platform/linux.rs @@ -1,4 +1,4 @@ -use crate::{ComputerController, types::*}; +use crate::{types::*, ComputerController}; use anyhow::Result; use async_trait::async_trait; use tesseract::Tesseract; @@ -21,48 +21,53 @@ impl ComputerController for LinuxController { async fn move_mouse(&self, _x: i32, _y: i32) -> Result<()> { anyhow::bail!("Linux implementation not yet available") } - + async fn click(&self, _button: MouseButton) -> Result<()> { anyhow::bail!("Linux implementation not yet available") } - + async fn double_click(&self, _button: MouseButton) -> Result<()> { anyhow::bail!("Linux implementation not yet available") } - + async fn type_text(&self, _text: &str) -> Result<()> { anyhow::bail!("Linux implementation not yet available") } - + async fn 
press_key(&self, _key: &str) -> Result<()> { anyhow::bail!("Linux implementation not yet available") } - + async fn list_windows(&self) -> Result> { anyhow::bail!("Linux implementation not yet available") } - + async fn focus_window(&self, _window_id: &str) -> Result<()> { anyhow::bail!("Linux implementation not yet available") } - + async fn get_window_bounds(&self, _window_id: &str) -> Result { anyhow::bail!("Linux implementation not yet available") } - + async fn find_element(&self, _selector: &ElementSelector) -> Result> { anyhow::bail!("Linux implementation not yet available") } - + async fn get_element_text(&self, _element_id: &str) -> Result { anyhow::bail!("Linux implementation not yet available") } - + async fn get_element_bounds(&self, _element_id: &str) -> Result { anyhow::bail!("Linux implementation not yet available") } - - async fn take_screenshot(&self, _path: &str, _region: Option, _window_id: Option<&str>) -> Result<()> { + + async fn take_screenshot( + &self, + _path: &str, + _region: Option, + _window_id: Option<&str>, + ) -> Result<()> { // Enforce that window_id must be provided if _window_id.is_none() { anyhow::bail!("window_id is required. You must specify which window to capture (e.g., 'Firefox', 'Terminal', 'gedit'). 
Use list_windows to see available windows."); @@ -70,94 +75,111 @@ impl ComputerController for LinuxController { anyhow::bail!("Linux implementation not yet available") } - + async fn extract_text_from_screen(&self, _region: Rect, _window_id: &str) -> Result { anyhow::bail!("Linux implementation not yet available") } - + async fn extract_text_from_image(&self, _path: &str) -> Result { // Check if tesseract is available on the system let tesseract_check = std::process::Command::new("which") .arg("tesseract") .output(); - + if tesseract_check.is_err() || !tesseract_check.as_ref().unwrap().status.success() { - anyhow::bail!("Tesseract OCR is not installed on your system.\n\n\ + anyhow::bail!( + "Tesseract OCR is not installed on your system.\n\n\ To install tesseract:\n \ Ubuntu/Debian: sudo apt-get install tesseract-ocr\n \ RHEL/CentOS: sudo yum install tesseract\n \ Arch Linux: sudo pacman -S tesseract\n\n\ - After installation, restart your terminal and try again."); + After installation, restart your terminal and try again." + ); } - + // Initialize Tesseract - let tess = Tesseract::new(None, Some("eng")) - .map_err(|e| { - anyhow::anyhow!("Failed to initialize Tesseract: {}\n\n\ + let tess = Tesseract::new(None, Some("eng")).map_err(|e| { + anyhow::anyhow!( + "Failed to initialize Tesseract: {}\n\n\ This usually means:\n1. Tesseract is not properly installed\n\ 2. Language data files are missing\n\nTo fix:\n \ Ubuntu/Debian: sudo apt-get install tesseract-ocr-eng\n \ RHEL/CentOS: sudo yum install tesseract-langpack-eng\n \ - Arch Linux: sudo pacman -S tesseract-data-eng", e) - })?; - - let text = tess.set_image(_path) + Arch Linux: sudo pacman -S tesseract-data-eng", + e + ) + })?; + + let text = tess + .set_image(_path) .map_err(|e| anyhow::anyhow!("Failed to load image '{}': {}", _path, e))? 
.get_text() .map_err(|e| anyhow::anyhow!("Failed to extract text from image: {}", e))?; - + // Get confidence (simplified - would need more complex API calls for per-word confidence) let confidence = 0.85; // Placeholder - + Ok(OCRResult { text, confidence, - bounds: Rect { x: 0, y: 0, width: 0, height: 0 }, // Would need image dimensions + bounds: Rect { + x: 0, + y: 0, + width: 0, + height: 0, + }, // Would need image dimensions }) } - + async fn find_text_on_screen(&self, _text: &str) -> Result> { // Check if tesseract is available on the system let tesseract_check = std::process::Command::new("which") .arg("tesseract") .output(); - + if tesseract_check.is_err() || !tesseract_check.as_ref().unwrap().status.success() { - anyhow::bail!("Tesseract OCR is not installed on your system.\n\n\ + anyhow::bail!( + "Tesseract OCR is not installed on your system.\n\n\ To install tesseract:\n \ Ubuntu/Debian: sudo apt-get install tesseract-ocr\n \ RHEL/CentOS: sudo yum install tesseract\n \ Arch Linux: sudo pacman -S tesseract\n\n\ - After installation, restart your terminal and try again."); + After installation, restart your terminal and try again." + ); } - + // Take full screen screenshot let temp_path = format!("/tmp/g3_ocr_search_{}.png", uuid::Uuid::new_v4()); self.take_screenshot(&temp_path, None, None).await?; - + // Use Tesseract to find text with bounding boxes - let tess = Tesseract::new(None, Some("eng")) - .map_err(|e| { - anyhow::anyhow!("Failed to initialize Tesseract: {}\n\n\ + let tess = Tesseract::new(None, Some("eng")).map_err(|e| { + anyhow::anyhow!( + "Failed to initialize Tesseract: {}\n\n\ This usually means:\n1. Tesseract is not properly installed\n\ 2. 
Language data files are missing\n\nTo fix:\n \ Ubuntu/Debian: sudo apt-get install tesseract-ocr-eng\n \ RHEL/CentOS: sudo yum install tesseract-langpack-eng\n \ - Arch Linux: sudo pacman -S tesseract-data-eng", e) - })?; - - let full_text = tess.set_image(temp_path.as_str()) + Arch Linux: sudo pacman -S tesseract-data-eng", + e + ) + })?; + + let full_text = tess + .set_image(temp_path.as_str()) .map_err(|e| anyhow::anyhow!("Failed to load screenshot: {}", e))? .get_text() .map_err(|e| anyhow::anyhow!("Failed to extract text from screen: {}", e))?; - + // Clean up temp file let _ = std::fs::remove_file(&temp_path); - + // Simple text search - full implementation would use get_component_images // to get bounding boxes for each word if full_text.contains(_text) { - tracing::warn!("Text found but precise coordinates not available in simplified implementation"); + tracing::warn!( + "Text found but precise coordinates not available in simplified implementation" + ); Ok(Some(Point { x: 0, y: 0 })) } else { Ok(None) diff --git a/crates/g3-computer-control/src/platform/macos.rs b/crates/g3-computer-control/src/platform/macos.rs index da9c81b..dda2c0f 100644 --- a/crates/g3-computer-control/src/platform/macos.rs +++ b/crates/g3-computer-control/src/platform/macos.rs @@ -1,13 +1,18 @@ -use crate::{ComputerController, types::{Rect, TextLocation}}; -use crate::ocr::{OCREngine, DefaultOCR}; -use anyhow::{Result, Context}; +use crate::ocr::{DefaultOCR, OCREngine}; +use crate::{ + types::{Rect, TextLocation}, + ComputerController, +}; +use anyhow::{Context, Result}; use async_trait::async_trait; -use std::path::Path; -use core_graphics::window::{kCGWindowListOptionOnScreenOnly, kCGNullWindowID, CGWindowListCopyWindowInfo}; +use core_foundation::array::CFArray; +use core_foundation::base::{TCFType, ToVoid}; use core_foundation::dictionary::CFDictionary; use core_foundation::string::CFString; -use core_foundation::base::{TCFType, ToVoid}; -use core_foundation::array::CFArray; +use 
core_graphics::window::{ + kCGNullWindowID, kCGWindowListOptionOnScreenOnly, CGWindowListCopyWindowInfo, +}; +use std::path::Path; pub struct MacOSController { ocr_engine: Box, @@ -20,13 +25,21 @@ impl MacOSController { let ocr = Box::new(DefaultOCR::new()?); let ocr_name = ocr.name().to_string(); tracing::info!("Initialized macOS controller with OCR engine: {}", ocr_name); - Ok(Self { ocr_engine: ocr, ocr_name }) + Ok(Self { + ocr_engine: ocr, + ocr_name, + }) } } #[async_trait] impl ComputerController for MacOSController { - async fn take_screenshot(&self, path: &str, region: Option, window_id: Option<&str>) -> Result<()> { + async fn take_screenshot( + &self, + path: &str, + region: Option, + window_id: Option<&str>, + ) -> Result<()> { // Enforce that window_id must be provided if window_id.is_none() { return Err(anyhow::anyhow!("window_id is required. You must specify which window to capture (e.g., 'Safari', 'Terminal', 'Google Chrome'). Use list_windows to see available windows.")); @@ -36,40 +49,38 @@ impl ComputerController for MacOSController { let temp_dir = std::env::var("TMPDIR") .or_else(|_| std::env::var("HOME").map(|h| format!("{}/tmp", h))) .unwrap_or_else(|_| "/tmp".to_string()); - + // Ensure temp directory exists std::fs::create_dir_all(&temp_dir)?; - + // If path is relative or doesn't specify a directory, use temp_dir let final_path = if path.starts_with('/') { path.to_string() } else { format!("{}/{}", temp_dir.trim_end_matches('/'), path) }; - + let path_obj = Path::new(&final_path); if let Some(parent) = path_obj.parent() { std::fs::create_dir_all(parent)?; } - + let app_name = window_id.unwrap(); // Safe because we checked is_none() above - + // Get the window ID for the specified application let cg_window_id = unsafe { - let window_list = CGWindowListCopyWindowInfo( - kCGWindowListOptionOnScreenOnly, - kCGNullWindowID - ); - + let window_list = + CGWindowListCopyWindowInfo(kCGWindowListOptionOnScreenOnly, kCGNullWindowID); + let array = 
CFArray::::wrap_under_create_rule(window_list); let count = array.len(); - + let mut found_window_id: Option<(u32, String)> = None; // (id, owner) let app_name_lower = app_name.to_lowercase(); - + for i in 0..count { let dict = array.get(i).unwrap(); - + // Get owner name let owner_key = CFString::from_static_string("kCGWindowOwnerName"); let owner: String = if let Some(value) = dict.find(owner_key.to_void()) { @@ -78,57 +89,68 @@ impl ComputerController for MacOSController { } else { continue; }; - - tracing::debug!("Checking window: owner='{}', looking for '{}'", owner, app_name); + + tracing::debug!( + "Checking window: owner='{}', looking for '{}'", + owner, + app_name + ); let owner_lower = owner.to_lowercase(); - + // Normalize by removing spaces for exact matching let app_name_normalized = app_name_lower.replace(" ", ""); let owner_normalized = owner_lower.replace(" ", ""); - + // ONLY accept exact matches (case-insensitive, with or without spaces) // This prevents "Goose" from matching "GooseStudio" - let is_match = owner_lower == app_name_lower || owner_normalized == app_name_normalized; - + let is_match = + owner_lower == app_name_lower || owner_normalized == app_name_normalized; + if is_match { // Get window ID let window_id_key = CFString::from_static_string("kCGWindowNumber"); if let Some(value) = dict.find(window_id_key.to_void()) { - let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _); + let num: core_foundation::number::CFNumber = + TCFType::wrap_under_get_rule(*value as *const _); if let Some(id) = num.to_i64() { // Get window layer to filter out menu bar windows let layer_key = CFString::from_static_string("kCGWindowLayer"); let layer: i32 = if let Some(value) = dict.find(layer_key.to_void()) { - let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _); + let num: core_foundation::number::CFNumber = + TCFType::wrap_under_get_rule(*value as *const _); 
num.to_i32().unwrap_or(0) } else { 0 }; - + // Get window bounds to verify it's a real window let bounds_key = CFString::from_static_string("kCGWindowBounds"); - let has_real_bounds = if let Some(value) = dict.find(bounds_key.to_void()) { - let bounds_dict: CFDictionary = TCFType::wrap_under_get_rule(*value as *const _); - let width_key = CFString::from_static_string("Width"); - let height_key = CFString::from_static_string("Height"); - - if let (Some(w_val), Some(h_val)) = ( - bounds_dict.find(width_key.to_void()), - bounds_dict.find(height_key.to_void()), - ) { - let w_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*w_val as *const _); - let h_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*h_val as *const _); - let width = w_num.to_f64().unwrap_or(0.0); - let height = h_num.to_f64().unwrap_or(0.0); - // Real windows should be at least 100x100 pixels - width >= 100.0 && height >= 100.0 + let has_real_bounds = + if let Some(value) = dict.find(bounds_key.to_void()) { + let bounds_dict: CFDictionary = + TCFType::wrap_under_get_rule(*value as *const _); + let width_key = CFString::from_static_string("Width"); + let height_key = CFString::from_static_string("Height"); + + if let (Some(w_val), Some(h_val)) = ( + bounds_dict.find(width_key.to_void()), + bounds_dict.find(height_key.to_void()), + ) { + let w_num: core_foundation::number::CFNumber = + TCFType::wrap_under_get_rule(*w_val as *const _); + let h_num: core_foundation::number::CFNumber = + TCFType::wrap_under_get_rule(*h_val as *const _); + let width = w_num.to_f64().unwrap_or(0.0); + let height = h_num.to_f64().unwrap_or(0.0); + // Real windows should be at least 100x100 pixels + width >= 100.0 && height >= 100.0 + } else { + false + } } else { false - } - } else { - false - }; - + }; + // Only accept windows that are: // 1. At layer 0 (normal windows, not menu bar) // 2. 
Have real bounds (width and height >= 100) @@ -137,189 +159,222 @@ impl ComputerController for MacOSController { found_window_id = Some((id as u32, owner.clone())); break; } else { - tracing::debug!("Skipping window ID {} for '{}': layer={}, has_real_bounds={}", id, owner, layer, has_real_bounds); + tracing::debug!( + "Skipping window ID {} for '{}': layer={}, has_real_bounds={}", + id, + owner, + layer, + has_real_bounds + ); } } } } } - + found_window_id }; - + let (cg_window_id, matched_owner) = cg_window_id.ok_or_else(|| { anyhow::anyhow!("Could not find window for application '{}'. Use list_windows to see available windows.", app_name) })?; - tracing::info!("Taking screenshot of window ID {} for app '{}'", cg_window_id, matched_owner); - + tracing::info!( + "Taking screenshot of window ID {} for app '{}'", + cg_window_id, + matched_owner + ); + // Use screencapture with the window ID for now // TODO: Implement direct CGWindowListCreateImage approach with proper image saving let mut cmd = std::process::Command::new("screencapture"); cmd.arg("-x"); // No sound cmd.arg("-l"); cmd.arg(cg_window_id.to_string()); - + if let Some(region) = region { cmd.arg("-R"); - cmd.arg(format!("{},{},{},{}", region.x, region.y, region.width, region.height)); + cmd.arg(format!( + "{},{},{},{}", + region.x, region.y, region.width, region.height + )); } - + cmd.arg(&final_path); - + let screenshot_result = cmd.output()?; - + if !screenshot_result.status.success() { let stderr = String::from_utf8_lossy(&screenshot_result.stderr); - return Err(anyhow::anyhow!("screencapture failed for window {}: {}", cg_window_id, stderr)); + return Err(anyhow::anyhow!( + "screencapture failed for window {}: {}", + cg_window_id, + stderr + )); } - + Ok(()) } - + async fn extract_text_from_screen(&self, region: Rect, window_id: &str) -> Result { // Take screenshot of region first let temp_path = format!("/tmp/g3_ocr_{}.png", uuid::Uuid::new_v4()); - self.take_screenshot(&temp_path, Some(region), 
Some(window_id)).await?; - + self.take_screenshot(&temp_path, Some(region), Some(window_id)) + .await?; + // Extract text from the screenshot let result = self.extract_text_from_image(&temp_path).await?; - + // Clean up temp file let _ = std::fs::remove_file(&temp_path); - + Ok(result) } - + async fn extract_text_from_image(&self, path: &str) -> Result { // Extract all text and concatenate let locations = self.ocr_engine.extract_text_with_locations(path).await?; - Ok(locations.iter().map(|loc| loc.text.as_str()).collect::>().join(" ")) + Ok(locations + .iter() + .map(|loc| loc.text.as_str()) + .collect::>() + .join(" ")) } - + async fn extract_text_with_locations(&self, path: &str) -> Result> { // Use the OCR engine self.ocr_engine.extract_text_with_locations(path).await } - - async fn find_text_in_app(&self, app_name: &str, search_text: &str) -> Result> { + + async fn find_text_in_app( + &self, + app_name: &str, + search_text: &str, + ) -> Result> { // Take screenshot of specific app window let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string()); - let temp_path = format!("{}/tmp/g3_find_text_{}_{}.png", home, app_name, uuid::Uuid::new_v4()); - self.take_screenshot(&temp_path, None, Some(app_name)).await?; - + let temp_path = format!( + "{}/tmp/g3_find_text_{}_{}.png", + home, + app_name, + uuid::Uuid::new_v4() + ); + self.take_screenshot(&temp_path, None, Some(app_name)) + .await?; + // Get screenshot dimensions before we delete it let screenshot_dims = get_image_dimensions(&temp_path)?; - + // Extract all text with locations let locations = self.extract_text_with_locations(&temp_path).await?; - + // Get window bounds to calculate coordinate transformation let window_bounds = self.get_window_bounds(app_name)?; - + // Clean up temp file let _ = std::fs::remove_file(&temp_path); - + // Find matching text (case-insensitive) let search_lower = search_text.to_lowercase(); for location in locations { if 
location.text.to_lowercase().contains(&search_lower) { // Transform coordinates from screenshot space to screen space - let transformed = transform_screenshot_to_screen_coords( - location, - window_bounds, - screenshot_dims, - ); + let transformed = + transform_screenshot_to_screen_coords(location, window_bounds, screenshot_dims); return Ok(Some(transformed)); } } - + Ok(None) } - + fn move_mouse(&self, x: i32, y: i32) -> Result<()> { - use core_graphics::event::{ - CGEvent, CGEventTapLocation, CGEventType, CGMouseButton, - }; - use core_graphics::event_source::{ - CGEventSource, CGEventSourceStateID, - }; + use core_graphics::event::{CGEvent, CGEventTapLocation, CGEventType, CGMouseButton}; + use core_graphics::event_source::{CGEventSource, CGEventSourceStateID}; use core_graphics::geometry::CGPoint; - + let source = CGEventSource::new(CGEventSourceStateID::HIDSystemState) - .ok().context("Failed to create event source")?; - + .ok() + .context("Failed to create event source")?; + let event = CGEvent::new_mouse_event( source, CGEventType::MouseMoved, CGPoint::new(x as f64, y as f64), CGMouseButton::Left, - ).ok().context("Failed to create mouse event")?; - + ) + .ok() + .context("Failed to create mouse event")?; + event.post(CGEventTapLocation::HID); - + Ok(()) } - + fn click_at(&self, x: i32, y: i32, _app_name: Option<&str>) -> Result<()> { - use core_graphics::event::{ - CGEvent, CGEventTapLocation, CGEventType, CGMouseButton, - }; - use core_graphics::event_source::{ - CGEventSource, CGEventSourceStateID, - }; - use core_graphics::geometry::CGPoint; use core_graphics::display::CGDisplay; - + use core_graphics::event::{CGEvent, CGEventTapLocation, CGEventType, CGMouseButton}; + use core_graphics::event_source::{CGEventSource, CGEventSourceStateID}; + use core_graphics::geometry::CGPoint; + // IMPORTANT: Coordinates passed here are in NSScreen/CGWindowListCopyWindowInfo space // (Y=0 at BOTTOM, increases UPWARD) // But CGEvent uses a different coordinate system 
(Y=0 at TOP, increases DOWNWARD) // We need to convert: CGEvent.y = screenHeight - NSScreen.y - + let screen_height = CGDisplay::main().pixels_high() as i32; let cgevent_x = x; let cgevent_y = screen_height - y; - - tracing::debug!("click_at: NSScreen coords ({}, {}) -> CGEvent coords ({}, {}) [screen_height={}]", - x, y, cgevent_x, cgevent_y, screen_height); - + + tracing::debug!( + "click_at: NSScreen coords ({}, {}) -> CGEvent coords ({}, {}) [screen_height={}]", + x, + y, + cgevent_x, + cgevent_y, + screen_height + ); + let (global_x, global_y) = (cgevent_x, cgevent_y); - + let point = CGPoint::new(global_x as f64, global_y as f64); - + let source = CGEventSource::new(CGEventSourceStateID::HIDSystemState) - .ok().context("Failed to create event source")?; - + .ok() + .context("Failed to create event source")?; + // Move mouse to position first let move_event = CGEvent::new_mouse_event( source.clone(), CGEventType::MouseMoved, point, CGMouseButton::Left, - ).ok().context("Failed to create mouse move event")?; + ) + .ok() + .context("Failed to create mouse move event")?; move_event.post(CGEventTapLocation::HID); - + std::thread::sleep(std::time::Duration::from_millis(100)); - + // Mouse down let mouse_down = CGEvent::new_mouse_event( source.clone(), CGEventType::LeftMouseDown, point, CGMouseButton::Left, - ).ok().context("Failed to create mouse down event")?; + ) + .ok() + .context("Failed to create mouse down event")?; mouse_down.post(CGEventTapLocation::HID); - + std::thread::sleep(std::time::Duration::from_millis(50)); - + // Mouse up - let mouse_up = CGEvent::new_mouse_event( - source, - CGEventType::LeftMouseUp, - point, - CGMouseButton::Left, - ).ok().context("Failed to create mouse up event")?; + let mouse_up = + CGEvent::new_mouse_event(source, CGEventType::LeftMouseUp, point, CGMouseButton::Left) + .ok() + .context("Failed to create mouse up event")?; mouse_up.post(CGEventTapLocation::HID); - + Ok(()) } } @@ -328,19 +383,17 @@ impl MacOSController { /// 
Get window bounds for an application (helper method) fn get_window_bounds(&self, app_name: &str) -> Result<(i32, i32, i32, i32)> { unsafe { - let window_list = CGWindowListCopyWindowInfo( - kCGWindowListOptionOnScreenOnly, - kCGNullWindowID - ); - + let window_list = + CGWindowListCopyWindowInfo(kCGWindowListOptionOnScreenOnly, kCGNullWindowID); + let array = CFArray::::wrap_under_create_rule(window_list); let count = array.len(); - + let app_name_lower = app_name.to_lowercase(); - + for i in 0..count { let dict = array.get(i).unwrap(); - + // Get owner name let owner_key = CFString::from_static_string("kCGWindowOwnerName"); let owner: String = if let Some(value) = dict.find(owner_key.to_void()) { @@ -349,65 +402,81 @@ impl MacOSController { } else { continue; }; - + let owner_lower = owner.to_lowercase(); - + // Normalize by removing spaces for exact matching let app_name_normalized = app_name_lower.replace(" ", ""); let owner_normalized = owner_lower.replace(" ", ""); - + // ONLY accept exact matches (case-insensitive, with or without spaces) // This prevents "Goose" from matching "GooseStudio" - let is_match = owner_lower == app_name_lower || owner_normalized == app_name_normalized; - + let is_match = + owner_lower == app_name_lower || owner_normalized == app_name_normalized; + if is_match { // Get window layer to filter out menu bar windows let layer_key = CFString::from_static_string("kCGWindowLayer"); let layer: i32 = if let Some(value) = dict.find(layer_key.to_void()) { - let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _); + let num: core_foundation::number::CFNumber = + TCFType::wrap_under_get_rule(*value as *const _); num.to_i32().unwrap_or(0) } else { 0 }; - + // Skip menu bar windows (layer >= 20) if layer >= 20 { - tracing::debug!("Skipping window for '{}' at layer {} (menu bar)", owner, layer); + tracing::debug!( + "Skipping window for '{}' at layer {} (menu bar)", + owner, + layer + ); continue; } - + // Get 
window bounds to verify it's a real window let bounds_key = CFString::from_static_string("kCGWindowBounds"); if let Some(value) = dict.find(bounds_key.to_void()) { - let bounds_dict: CFDictionary = TCFType::wrap_under_get_rule(*value as *const _); - + let bounds_dict: CFDictionary = + TCFType::wrap_under_get_rule(*value as *const _); + let x_key = CFString::from_static_string("X"); let y_key = CFString::from_static_string("Y"); let width_key = CFString::from_static_string("Width"); let height_key = CFString::from_static_string("Height"); - + if let (Some(x_val), Some(y_val), Some(w_val), Some(h_val)) = ( bounds_dict.find(x_key.to_void()), bounds_dict.find(y_key.to_void()), bounds_dict.find(width_key.to_void()), bounds_dict.find(height_key.to_void()), ) { - let x_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*x_val as *const _); - let y_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*y_val as *const _); - let w_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*w_val as *const _); - let h_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*h_val as *const _); - + let x_num: core_foundation::number::CFNumber = + TCFType::wrap_under_get_rule(*x_val as *const _); + let y_num: core_foundation::number::CFNumber = + TCFType::wrap_under_get_rule(*y_val as *const _); + let w_num: core_foundation::number::CFNumber = + TCFType::wrap_under_get_rule(*w_val as *const _); + let h_num: core_foundation::number::CFNumber = + TCFType::wrap_under_get_rule(*h_val as *const _); + let x: i32 = x_num.to_i64().unwrap_or(0) as i32; let y: i32 = y_num.to_i64().unwrap_or(0) as i32; let w: i32 = w_num.to_i64().unwrap_or(0) as i32; let h: i32 = h_num.to_i64().unwrap_or(0) as i32; - + // Only accept windows with real bounds (>= 100x100 pixels) if w >= 100 && h >= 100 { tracing::info!("Found valid window bounds for '{}': x={}, y={}, w={}, h={} (layer={})", owner, x, y, w, h, layer); return Ok((x, y, w, h)); } 
else { - tracing::debug!("Skipping window for '{}': too small ({}x{})", owner, w, h); + tracing::debug!( + "Skipping window for '{}': too small ({}x{})", + owner, + w, + h + ); continue; } } else { @@ -417,8 +486,11 @@ impl MacOSController { } } } - - Err(anyhow::anyhow!("Could not find window bounds for '{}'", app_name)) + + Err(anyhow::anyhow!( + "Could not find window bounds for '{}'", + app_name + )) } } @@ -426,72 +498,118 @@ impl MacOSController { fn get_image_dimensions(path: &str) -> Result<(i32, i32)> { use std::fs::File; use std::io::Read; - + let mut file = File::open(path)?; let mut buffer = vec![0u8; 24]; file.read_exact(&mut buffer)?; - + // PNG signature check if &buffer[0..8] != b"\x89PNG\r\n\x1a\n" { anyhow::bail!("Not a valid PNG file"); } - + // Read IHDR chunk (width and height are at bytes 16-23) let width = u32::from_be_bytes([buffer[16], buffer[17], buffer[18], buffer[19]]) as i32; let height = u32::from_be_bytes([buffer[20], buffer[21], buffer[22], buffer[23]]) as i32; - + Ok((width, height)) } /// Transform coordinates from screenshot space to screen space -/// +/// /// The screenshot is taken of a window, and Vision OCR returns coordinates /// relative to the screenshot image. We need to transform these to actual /// screen coordinates for clicking. -/// +/// /// On Retina displays, screenshots are taken at 2x resolution, so we need /// to account for this scaling factor. 
fn transform_screenshot_to_screen_coords( location: TextLocation, window_bounds: (i32, i32, i32, i32), // (x, y, width, height) in screen space - screenshot_dims: (i32, i32), // (width, height) in pixels + screenshot_dims: (i32, i32), // (width, height) in pixels ) -> TextLocation { let (win_x, win_y, win_width, win_height) = window_bounds; let (screenshot_width, screenshot_height) = screenshot_dims; - + // Calculate scale factors // On Retina displays, screenshot is typically 2x the window size let scale_x = win_width as f64 / screenshot_width as f64; let scale_y = win_height as f64 / screenshot_height as f64; - - tracing::debug!("Transform: screenshot={}x{}, window={}x{} at ({},{}), scale=({:.2},{:.2})", - screenshot_width, screenshot_height, win_width, win_height, win_x, win_y, scale_x, scale_y); - + + tracing::debug!( + "Transform: screenshot={}x{}, window={}x{} at ({},{}), scale=({:.2},{:.2})", + screenshot_width, + screenshot_height, + win_width, + win_height, + win_x, + win_y, + scale_x, + scale_y + ); + // Transform coordinates from image space to screen space // IMPORTANT: macOS screen coordinates have origin at BOTTOM-LEFT (Y increases upward) // Image coordinates have origin at TOP-LEFT (Y increases downward) // win_y is the BOTTOM of the window in screen coordinates // So we need to: (win_y + win_height) to get window TOP, then subtract screenshot_y let window_top_y = win_y + win_height; - - tracing::debug!("[transform] Input location in image space: x={}, y={}, width={}, height={}", - location.x, location.y, location.width, location.height); - tracing::debug!("[transform] Scale factors: scale_x={:.4}, scale_y={:.4}", scale_x, scale_y); - + + tracing::debug!( + "[transform] Input location in image space: x={}, y={}, width={}, height={}", + location.x, + location.y, + location.width, + location.height + ); + tracing::debug!( + "[transform] Scale factors: scale_x={:.4}, scale_y={:.4}", + scale_x, + scale_y + ); + let transformed_x = win_x + (location.x as 
f64 * scale_x) as i32; let transformed_y = window_top_y - (location.y as f64 * scale_y) as i32; let transformed_width = (location.width as f64 * scale_x) as i32; let transformed_height = (location.height as f64 * scale_y) as i32; - + tracing::debug!("[transform] Calculation details:"); - tracing::debug!(" - transformed_x = {} + ({} * {:.4}) = {} + {:.2} = {}", win_x, location.x, scale_x, win_x, location.x as f64 * scale_x, transformed_x); - tracing::debug!(" - transformed_width = ({} * {:.4}) = {:.2} -> {}", location.width, scale_x, location.width as f64 * scale_x, transformed_width); - tracing::debug!(" - transformed_height = ({} * {:.4}) = {:.2} -> {}", location.height, scale_y, location.height as f64 * scale_y, transformed_height); - - tracing::debug!("Transformed location: screenshot=({},{}) {}x{} -> screen=({},{}) {}x{}", - location.x, location.y, location.width, location.height, - transformed_x, transformed_y, transformed_width, transformed_height); - + tracing::debug!( + " - transformed_x = {} + ({} * {:.4}) = {} + {:.2} = {}", + win_x, + location.x, + scale_x, + win_x, + location.x as f64 * scale_x, + transformed_x + ); + tracing::debug!( + " - transformed_width = ({} * {:.4}) = {:.2} -> {}", + location.width, + scale_x, + location.width as f64 * scale_x, + transformed_width + ); + tracing::debug!( + " - transformed_height = ({} * {:.4}) = {:.2} -> {}", + location.height, + scale_y, + location.height as f64 * scale_y, + transformed_height + ); + + tracing::debug!( + "Transformed location: screenshot=({},{}) {}x{} -> screen=({},{}) {}x{}", + location.x, + location.y, + location.width, + location.height, + transformed_x, + transformed_y, + transformed_width, + transformed_height + ); + TextLocation { text: location.text, x: transformed_x, @@ -504,4 +622,4 @@ fn transform_screenshot_to_screen_coords( #[path = "macos_window_matching_test.rs"] #[cfg(test)] -mod tests; \ No newline at end of file +mod tests; diff --git 
a/crates/g3-computer-control/src/platform/macos_window_matching_test.rs b/crates/g3-computer-control/src/platform/macos_window_matching_test.rs index 387988f..16bbbcf 100644 --- a/crates/g3-computer-control/src/platform/macos_window_matching_test.rs +++ b/crates/g3-computer-control/src/platform/macos_window_matching_test.rs @@ -1,11 +1,11 @@ #[cfg(test)] mod window_matching_tests { /// Test that window name matching handles spaces correctly - /// + /// /// Issue: When a user requests a screenshot of "Goose Studio" but the actual /// application name is "GooseStudio" (no space), the fuzzy matching should /// still find the window. - /// + /// /// The fix normalizes both names by removing spaces before comparing. #[test] fn test_space_normalization() { @@ -16,25 +16,25 @@ mod window_matching_tests { ("Visual Studio Code", "VisualStudioCode", true), ("Google Chrome", "Google Chrome", true), ("Safari", "Safari", true), - ("iTerm", "iTerm2", true), // fuzzy match + ("iTerm", "iTerm2", true), // fuzzy match ("Code", "Visual Studio Code", true), // fuzzy match ]; for (user_input, app_name, should_match) in test_cases { let user_lower = user_input.to_lowercase(); let app_lower = app_name.to_lowercase(); - + let user_normalized = user_lower.replace(" ", ""); let app_normalized = app_lower.replace(" ", ""); - + let is_exact = app_lower == user_lower || app_normalized == user_normalized; - let is_fuzzy = app_lower.contains(&user_lower) + let is_fuzzy = app_lower.contains(&user_lower) || user_lower.contains(&app_lower) || app_normalized.contains(&user_normalized) || user_normalized.contains(&app_normalized); - + let matches = is_exact || is_fuzzy; - + assert_eq!( matches, should_match, "Expected '{}' vs '{}' to match={}, but got match={}", diff --git a/crates/g3-computer-control/src/platform/windows.rs b/crates/g3-computer-control/src/platform/windows.rs index f3250f7..1209084 100644 --- a/crates/g3-computer-control/src/platform/windows.rs +++ 
b/crates/g3-computer-control/src/platform/windows.rs @@ -1,4 +1,4 @@ -use crate::{ComputerController, types::*}; +use crate::{types::*, ComputerController}; use anyhow::Result; use async_trait::async_trait; use tesseract::Tesseract; @@ -20,48 +20,53 @@ impl ComputerController for WindowsController { async fn move_mouse(&self, _x: i32, _y: i32) -> Result<()> { anyhow::bail!("Windows implementation not yet available") } - + async fn click(&self, _button: MouseButton) -> Result<()> { anyhow::bail!("Windows implementation not yet available") } - + async fn double_click(&self, _button: MouseButton) -> Result<()> { anyhow::bail!("Windows implementation not yet available") } - + async fn type_text(&self, _text: &str) -> Result<()> { anyhow::bail!("Windows implementation not yet available") } - + async fn press_key(&self, _key: &str) -> Result<()> { anyhow::bail!("Windows implementation not yet available") } - + async fn list_windows(&self) -> Result> { anyhow::bail!("Windows implementation not yet available") } - + async fn focus_window(&self, _window_id: &str) -> Result<()> { anyhow::bail!("Windows implementation not yet available") } - + async fn get_window_bounds(&self, _window_id: &str) -> Result { anyhow::bail!("Windows implementation not yet available") } - + async fn find_element(&self, _selector: &ElementSelector) -> Result> { anyhow::bail!("Windows implementation not yet available") } - + async fn get_element_text(&self, _element_id: &str) -> Result { anyhow::bail!("Windows implementation not yet available") } - + async fn get_element_bounds(&self, _element_id: &str) -> Result { anyhow::bail!("Windows implementation not yet available") } - - async fn take_screenshot(&self, _path: &str, _region: Option, _window_id: Option<&str>) -> Result<()> { + + async fn take_screenshot( + &self, + _path: &str, + _region: Option, + _window_id: Option<&str>, + ) -> Result<()> { // Enforce that window_id must be provided if _window_id.is_none() { anyhow::bail!("window_id is 
required. You must specify which window to capture (e.g., 'Chrome', 'Terminal', 'Notepad'). Use list_windows to see available windows."); @@ -69,96 +74,113 @@ impl ComputerController for WindowsController { anyhow::bail!("Windows implementation not yet available") } - + async fn extract_text_from_screen(&self, _region: Rect, _window_id: &str) -> Result { anyhow::bail!("Windows implementation not yet available") } - + async fn extract_text_from_image(&self, _path: &str) -> Result { // Check if tesseract is available on the system let tesseract_check = std::process::Command::new("where") .arg("tesseract") .output(); - + if tesseract_check.is_err() || !tesseract_check.as_ref().unwrap().status.success() { - anyhow::bail!("Tesseract OCR is not installed on your system.\n\n\ + anyhow::bail!( + "Tesseract OCR is not installed on your system.\n\n\ To install tesseract on Windows:\n \ 1. Download the installer from: https://github.com/UB-Mannheim/tesseract/wiki\n \ 2. Run the installer and follow the instructions\n \ 3. Add tesseract to your PATH environment variable\n \ 4. Restart your terminal/command prompt\n\n\ - After installation, restart your terminal and try again."); + After installation, restart your terminal and try again." + ); } - + // Initialize Tesseract - let tess = Tesseract::new(None, Some("eng")) - .map_err(|e| { - anyhow::anyhow!("Failed to initialize Tesseract: {}\n\n\ + let tess = Tesseract::new(None, Some("eng")).map_err(|e| { + anyhow::anyhow!( + "Failed to initialize Tesseract: {}\n\n\ This usually means:\n1. Tesseract is not properly installed\n\ 2. Language data files are missing\n\nTo fix:\n \ 1. Reinstall tesseract from https://github.com/UB-Mannheim/tesseract/wiki\n \ 2. Make sure to select 'Additional language data' during installation\n \ - 3. Ensure tesseract is in your PATH", e) - })?; - - let text = tess.set_image(_path) + 3. 
Ensure tesseract is in your PATH", + e + ) + })?; + + let text = tess + .set_image(_path) .map_err(|e| anyhow::anyhow!("Failed to load image '{}': {}", _path, e))? .get_text() .map_err(|e| anyhow::anyhow!("Failed to extract text from image: {}", e))?; - + // Get confidence (simplified - would need more complex API calls for per-word confidence) let confidence = 0.85; // Placeholder - + Ok(OCRResult { text, confidence, - bounds: Rect { x: 0, y: 0, width: 0, height: 0 }, // Would need image dimensions + bounds: Rect { + x: 0, + y: 0, + width: 0, + height: 0, + }, // Would need image dimensions }) } - + async fn find_text_on_screen(&self, _text: &str) -> Result> { // Check if tesseract is available on the system let tesseract_check = std::process::Command::new("where") .arg("tesseract") .output(); - + if tesseract_check.is_err() || !tesseract_check.as_ref().unwrap().status.success() { - anyhow::bail!("Tesseract OCR is not installed on your system.\n\n\ + anyhow::bail!( + "Tesseract OCR is not installed on your system.\n\n\ To install tesseract on Windows:\n \ 1. Download the installer from: https://github.com/UB-Mannheim/tesseract/wiki\n \ 2. Run the installer and follow the instructions\n \ 3. Add tesseract to your PATH environment variable\n \ 4. Restart your terminal/command prompt\n\n\ - After installation, restart your terminal and try again."); + After installation, restart your terminal and try again." + ); } - + // Take full screen screenshot let temp_path = format!("C:\\\\Temp\\\\g3_ocr_search_{}.png", uuid::Uuid::new_v4()); self.take_screenshot(&temp_path, None, None).await?; - + // Use Tesseract to find text with bounding boxes - let tess = Tesseract::new(None, Some("eng")) - .map_err(|e| { - anyhow::anyhow!("Failed to initialize Tesseract: {}\n\n\ + let tess = Tesseract::new(None, Some("eng")).map_err(|e| { + anyhow::anyhow!( + "Failed to initialize Tesseract: {}\n\n\ This usually means:\n1. Tesseract is not properly installed\n\ 2. 
Language data files are missing\n\nTo fix:\n \ 1. Reinstall tesseract from https://github.com/UB-Mannheim/tesseract/wiki\n \ 2. Make sure to select 'Additional language data' during installation\n \ - 3. Ensure tesseract is in your PATH", e) - })?; - - let full_text = tess.set_image(temp_path.as_str()) + 3. Ensure tesseract is in your PATH", + e + ) + })?; + + let full_text = tess + .set_image(temp_path.as_str()) .map_err(|e| anyhow::anyhow!("Failed to load screenshot: {}", e))? .get_text() .map_err(|e| anyhow::anyhow!("Failed to extract text from screen: {}", e))?; - + // Clean up temp file let _ = std::fs::remove_file(&temp_path); - + // Simple text search - full implementation would use get_component_images // to get bounding boxes for each word if full_text.contains(_text) { - tracing::warn!("Text found but precise coordinates not available in simplified implementation"); + tracing::warn!( + "Text found but precise coordinates not available in simplified implementation" + ); Ok(Some(Point { x: 0, y: 0 })) } else { Ok(None) diff --git a/crates/g3-computer-control/src/webdriver/mod.rs b/crates/g3-computer-control/src/webdriver/mod.rs index 1951bef..ac25f00 100644 --- a/crates/g3-computer-control/src/webdriver/mod.rs +++ b/crates/g3-computer-control/src/webdriver/mod.rs @@ -9,31 +9,31 @@ use serde_json::Value; pub trait WebDriverController: Send + Sync { /// Navigate to a URL async fn navigate(&mut self, url: &str) -> Result<()>; - + /// Get the current URL async fn current_url(&self) -> Result; - + /// Get the page title async fn title(&self) -> Result; - + /// Find an element by CSS selector async fn find_element(&mut self, selector: &str) -> Result; - + /// Find multiple elements by CSS selector async fn find_elements(&mut self, selector: &str) -> Result>; - + /// Execute JavaScript in the browser async fn execute_script(&mut self, script: &str, args: Vec) -> Result; - + /// Get the page source (HTML) async fn page_source(&self) -> Result; - + /// Take a 
screenshot and save to path async fn screenshot(&mut self, path: &str) -> Result<()>; - + /// Close the current window/tab async fn close(&mut self) -> Result<()>; - + /// Quit the browser session async fn quit(self) -> Result<()>; } @@ -49,63 +49,69 @@ impl WebElement { self.inner.click().await?; Ok(()) } - + /// Send keys/text to the element pub async fn send_keys(&mut self, text: &str) -> Result<()> { self.inner.send_keys(text).await?; Ok(()) } - + /// Clear the element's content (for input fields) pub async fn clear(&mut self) -> Result<()> { self.inner.clear().await?; Ok(()) } - + /// Get the element's text content pub async fn text(&self) -> Result { Ok(self.inner.text().await?) } - + /// Get an attribute value pub async fn attr(&self, name: &str) -> Result> { Ok(self.inner.attr(name).await?) } - + /// Get a property value pub async fn prop(&self, name: &str) -> Result> { Ok(self.inner.prop(name).await?) } - + /// Get the element's HTML pub async fn html(&self, inner: bool) -> Result { Ok(self.inner.html(inner).await?) } - + /// Check if element is displayed pub async fn is_displayed(&self) -> Result { Ok(self.inner.is_displayed().await?) } - + /// Check if element is enabled pub async fn is_enabled(&self) -> Result { Ok(self.inner.is_enabled().await?) } - + /// Check if element is selected (for checkboxes/radio buttons) pub async fn is_selected(&self) -> Result { Ok(self.inner.is_selected().await?) 
} - + /// Find a child element by CSS selector pub async fn find_element(&mut self, selector: &str) -> Result { let elem = self.inner.find(fantoccini::Locator::Css(selector)).await?; Ok(WebElement { inner: elem }) } - + /// Find multiple child elements by CSS selector pub async fn find_elements(&mut self, selector: &str) -> Result> { - let elems = self.inner.find_all(fantoccini::Locator::Css(selector)).await?; - Ok(elems.into_iter().map(|inner| WebElement { inner }).collect()) + let elems = self + .inner + .find_all(fantoccini::Locator::Css(selector)) + .await?; + Ok(elems + .into_iter() + .map(|inner| WebElement { inner }) + .collect()) } } diff --git a/crates/g3-computer-control/src/webdriver/safari.rs b/crates/g3-computer-control/src/webdriver/safari.rs index 762bd8d..6df8290 100644 --- a/crates/g3-computer-control/src/webdriver/safari.rs +++ b/crates/g3-computer-control/src/webdriver/safari.rs @@ -12,10 +12,10 @@ pub struct SafariDriver { impl SafariDriver { /// Create a new SafariDriver instance - /// + /// /// This will connect to SafariDriver running on the default port (4444). /// Make sure to enable "Allow Remote Automation" in Safari's Develop menu first. - /// + /// /// You can start SafariDriver manually with: /// ```bash /// /usr/bin/safaridriver --enable @@ -23,125 +23,134 @@ impl SafariDriver { pub async fn new() -> Result { Self::with_port(4444).await } - + /// Create a new SafariDriver instance with a custom port pub async fn with_port(port: u16) -> Result { let url = format!("http://localhost:{}", port); - + let mut caps = serde_json::Map::new(); - caps.insert("browserName".to_string(), Value::String("safari".to_string())); - + caps.insert( + "browserName".to_string(), + Value::String("safari".to_string()), + ); + let client = ClientBuilder::native() .capabilities(caps) .connect(&url) .await .context("Failed to connect to SafariDriver. 
Make sure SafariDriver is running and 'Allow Remote Automation' is enabled in Safari's Develop menu.")?; - + Ok(Self { client }) } - + /// Go back in browser history pub async fn back(&mut self) -> Result<()> { self.client.back().await?; Ok(()) } - + /// Go forward in browser history pub async fn forward(&mut self) -> Result<()> { self.client.forward().await?; Ok(()) } - + /// Refresh the current page pub async fn refresh(&mut self) -> Result<()> { self.client.refresh().await?; Ok(()) } - + /// Get all window handles pub async fn window_handles(&mut self) -> Result> { let handles = self.client.windows().await?; - Ok(handles.into_iter() - .map(|h| h.into()) - .collect()) + Ok(handles.into_iter().map(|h| h.into()).collect()) } - + /// Switch to a window by handle pub async fn switch_to_window(&mut self, handle: &str) -> Result<()> { let window_handle: fantoccini::wd::WindowHandle = handle.to_string().try_into()?; self.client.switch_to_window(window_handle).await?; Ok(()) } - + /// Get the current window handle pub async fn current_window_handle(&mut self) -> Result { Ok(self.client.window().await?.into()) } - + /// Close the current window pub async fn close_window(&mut self) -> Result<()> { self.client.close_window().await?; Ok(()) } - + /// Create a new window/tab pub async fn new_window(&mut self, is_tab: bool) -> Result { let window_type = if is_tab { "tab" } else { "window" }; let response = self.client.new_window(window_type == "tab").await?; Ok(response.handle.into()) } - + /// Get cookies pub async fn get_cookies(&mut self) -> Result>> { Ok(self.client.get_all_cookies().await?) 
} - + /// Add a cookie pub async fn add_cookie(&mut self, cookie: fantoccini::cookies::Cookie<'static>) -> Result<()> { self.client.add_cookie(cookie).await?; Ok(()) } - + /// Delete all cookies pub async fn delete_all_cookies(&mut self) -> Result<()> { self.client.delete_all_cookies().await?; Ok(()) } - + /// Wait for an element to appear (with timeout) - pub async fn wait_for_element(&mut self, selector: &str, timeout: Duration) -> Result { + pub async fn wait_for_element( + &mut self, + selector: &str, + timeout: Duration, + ) -> Result { let start = std::time::Instant::now(); let poll_interval = Duration::from_millis(100); - + loop { if let Ok(elem) = self.find_element(selector).await { return Ok(elem); } - + if start.elapsed() >= timeout { anyhow::bail!("Timeout waiting for element: {}", selector); } - + tokio::time::sleep(poll_interval).await; } } - + /// Wait for an element to be visible (with timeout) - pub async fn wait_for_visible(&mut self, selector: &str, timeout: Duration) -> Result { + pub async fn wait_for_visible( + &mut self, + selector: &str, + timeout: Duration, + ) -> Result { let start = std::time::Instant::now(); let poll_interval = Duration::from_millis(100); - + loop { if let Ok(elem) = self.find_element(selector).await { if elem.is_displayed().await.unwrap_or(false) { return Ok(elem); } } - + if start.elapsed() >= timeout { anyhow::bail!("Timeout waiting for element to be visible: {}", selector); } - + tokio::time::sleep(poll_interval).await; } } @@ -153,58 +162,69 @@ impl WebDriverController for SafariDriver { self.client.goto(url).await?; Ok(()) } - + async fn current_url(&self) -> Result { Ok(self.client.current_url().await?.to_string()) } - + async fn title(&self) -> Result { Ok(self.client.title().await?) 
} - + async fn find_element(&mut self, selector: &str) -> Result { - let elem = self.client.find(fantoccini::Locator::Css(selector)).await - .context(format!("Failed to find element with selector: {}", selector))?; + let elem = self + .client + .find(fantoccini::Locator::Css(selector)) + .await + .context(format!( + "Failed to find element with selector: {}", + selector + ))?; Ok(WebElement { inner: elem }) } - + async fn find_elements(&mut self, selector: &str) -> Result> { - let elems = self.client.find_all(fantoccini::Locator::Css(selector)).await?; - Ok(elems.into_iter().map(|inner| WebElement { inner }).collect()) + let elems = self + .client + .find_all(fantoccini::Locator::Css(selector)) + .await?; + Ok(elems + .into_iter() + .map(|inner| WebElement { inner }) + .collect()) } - + async fn execute_script(&mut self, script: &str, args: Vec) -> Result { Ok(self.client.execute(script, args).await?) } - + async fn page_source(&self) -> Result { Ok(self.client.source().await?) } - + async fn screenshot(&mut self, path: &str) -> Result<()> { let screenshot_data = self.client.screenshot().await?; - + // Expand tilde in path let expanded_path = shellexpand::tilde(path); let path_str = expanded_path.as_ref(); - + // Create parent directories if needed if let Some(parent) = std::path::Path::new(path_str).parent() { std::fs::create_dir_all(parent) .context("Failed to create parent directories for screenshot")?; } - - std::fs::write(path_str, screenshot_data) - .context("Failed to write screenshot to file")?; - + + std::fs::write(path_str, screenshot_data).context("Failed to write screenshot to file")?; + Ok(()) } - + async fn close(&mut self) -> Result<()> { self.client.close_window().await?; Ok(()) } - + async fn quit(mut self) -> Result<()> { self.client.close().await?; Ok(()) diff --git a/crates/g3-computer-control/tests/integration_test.rs b/crates/g3-computer-control/tests/integration_test.rs index 17aaf1f..ebee24a 100644 --- 
a/crates/g3-computer-control/tests/integration_test.rs +++ b/crates/g3-computer-control/tests/integration_test.rs @@ -3,29 +3,35 @@ use g3_computer_control::*; #[tokio::test] async fn test_screenshot() { let controller = create_controller().expect("Failed to create controller"); - + // Test that screenshot without window_id fails with appropriate error let path = "/tmp/test_screenshot.png"; let result = controller.take_screenshot(path, None, None).await; - assert!(result.is_err(), "Expected error when window_id is not provided"); - + assert!( + result.is_err(), + "Expected error when window_id is not provided" + ); + let error_msg = result.unwrap_err().to_string(); - assert!(error_msg.contains("window_id is required"), - "Expected error message about window_id being required, got: {}", error_msg); + assert!( + error_msg.contains("window_id is required"), + "Expected error message about window_id being required, got: {}", + error_msg + ); } #[tokio::test] async fn test_screenshot_with_window() { let controller = create_controller().expect("Failed to create controller"); - + // Take screenshot of Finder (should always be available on macOS) let path = "/tmp/test_screenshot_finder.png"; let result = controller.take_screenshot(path, None, Some("Finder")).await; - + // This test may fail if Finder is not running, so we just check it doesn't panic // and returns a proper Result let _ = result; // Don't assert success since Finder might not be visible - + // Clean up let _ = std::fs::remove_file(path); } diff --git a/crates/g3-config/src/lib.rs b/crates/g3-config/src/lib.rs index 0c034de..f556ac7 100644 --- a/crates/g3-config/src/lib.rs +++ b/crates/g3-config/src/lib.rs @@ -1,5 +1,5 @@ -use serde::{Deserialize, Serialize}; use anyhow::Result; +use serde::{Deserialize, Serialize}; use std::path::Path; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -21,7 +21,7 @@ pub struct ProvidersConfig { pub databricks: Option, pub embedded: Option, pub default_provider: String, - 
pub coach: Option, // Provider to use for coach in autonomous mode + pub coach: Option, // Provider to use for coach in autonomous mode pub player: Option, // Provider to use for player in autonomous mode } @@ -103,9 +103,7 @@ pub struct MacAxConfig { impl Default for MacAxConfig { fn default() -> Self { - Self { - enabled: false, - } + Self { enabled: false } } } @@ -173,22 +171,18 @@ impl Config { Path::new(path).exists() } else { // Check default locations - let default_paths = [ - "./g3.toml", - "~/.config/g3/config.toml", - "~/.g3.toml", - ]; - + let default_paths = ["./g3.toml", "~/.config/g3/config.toml", "~/.g3.toml"]; + default_paths.iter().any(|path| { let expanded_path = shellexpand::tilde(path); Path::new(expanded_path.as_ref()).exists() }) }; - + // If no config exists, create and save a default Databricks config if !config_exists { let databricks_config = Self::default(); - + // Save to default location let config_dir = dirs::home_dir() .map(|mut path| { @@ -197,26 +191,29 @@ impl Config { path }) .unwrap_or_else(|| std::path::PathBuf::from(".")); - + // Create directory if it doesn't exist std::fs::create_dir_all(&config_dir).ok(); - + let config_file = config_dir.join("config.toml"); if let Err(e) = databricks_config.save(config_file.to_str().unwrap()) { eprintln!("Warning: Could not save default config: {}", e); } else { - println!("Created default Databricks configuration at: {}", config_file.display()); + println!( + "Created default Databricks configuration at: {}", + config_file.display() + ); } - + return Ok(databricks_config); } - + // Existing config loading logic let mut settings = config::Config::builder(); - + // Load default configuration settings = settings.add_source(config::Config::try_from(&Config::default())?); - + // Load from config file if provided if let Some(path) = config_path { if Path::new(path).exists() { @@ -224,12 +221,8 @@ impl Config { } } else { // Try to load from default locations - let default_paths = [ - 
"./g3.toml", - "~/.config/g3/config.toml", - "~/.g3.toml", - ]; - + let default_paths = ["./g3.toml", "~/.config/g3/config.toml", "~/.g3.toml"]; + for path in &default_paths { let expanded_path = shellexpand::tilde(path); if Path::new(expanded_path.as_ref()).exists() { @@ -238,13 +231,10 @@ impl Config { } } } - + // Override with environment variables - settings = settings.add_source( - config::Environment::with_prefix("G3") - .separator("_") - ); - + settings = settings.add_source(config::Environment::with_prefix("G3").separator("_")); + let config = settings.build()?.try_deserialize()?; Ok(config) } @@ -260,7 +250,7 @@ impl Config { embedded: Some(EmbeddedConfig { model_path: "~/.cache/g3/models/qwen2.5-7b-instruct-q3_k_m.gguf".to_string(), model_type: "qwen".to_string(), - context_length: Some(32768), // Qwen2.5 supports 32k context + context_length: Some(32768), // Qwen2.5 supports 32k context max_tokens: Some(2048), temperature: Some(0.1), gpu_layers: Some(32), @@ -286,13 +276,13 @@ impl Config { macax: MacAxConfig::default(), } } - + pub fn save(&self, path: &str) -> Result<()> { let toml_string = toml::to_string_pretty(self)?; std::fs::write(path, toml_string)?; Ok(()) } - + pub fn load_with_overrides( config_path: Option<&str>, provider_override: Option, @@ -300,12 +290,12 @@ impl Config { ) -> Result { // Load the base configuration let mut config = Self::load(config_path)?; - + // Apply provider override if let Some(provider) = provider_override { config.providers.default_provider = provider; } - + // Apply model override to the active provider if let Some(model) = model_override { match config.providers.default_provider.as_str() { @@ -345,28 +335,34 @@ impl Config { )); } } - _ => return Err(anyhow::anyhow!("Unknown provider: {}", - config.providers.default_provider)), + _ => { + return Err(anyhow::anyhow!( + "Unknown provider: {}", + config.providers.default_provider + )) + } } } - + Ok(config) } - + /// Get the provider to use for coach mode in 
autonomous execution pub fn get_coach_provider(&self) -> &str { - self.providers.coach + self.providers + .coach .as_deref() .unwrap_or(&self.providers.default_provider) } - + /// Get the provider to use for player mode in autonomous execution pub fn get_player_provider(&self) -> &str { - self.providers.player + self.providers + .player .as_deref() .unwrap_or(&self.providers.default_provider) } - + /// Create a copy of the config with a different default provider pub fn with_provider_override(&self, provider: &str) -> Result { // Validate that the provider is configured @@ -397,17 +393,17 @@ impl Config { } _ => {} // Provider is configured or unknown (will be caught later) } - + let mut config = self.clone(); config.providers.default_provider = provider.to_string(); Ok(config) } - + /// Create a copy of the config for coach mode in autonomous execution pub fn for_coach(&self) -> Result { self.with_provider_override(self.get_coach_provider()) } - + /// Create a copy of the config for player mode in autonomous execution pub fn for_player(&self) -> Result { self.with_provider_override(self.get_player_provider()) diff --git a/crates/g3-config/src/tests.rs b/crates/g3-config/src/tests.rs index 6899a8b..d97604b 100644 --- a/crates/g3-config/src/tests.rs +++ b/crates/g3-config/src/tests.rs @@ -9,7 +9,7 @@ mod tests { // Create a temporary directory for the test config let temp_dir = TempDir::new().unwrap(); let config_path = temp_dir.path().join("test_config.toml"); - + // Write a test configuration with coach and player providers let config_content = r#" [providers] @@ -35,32 +35,32 @@ fallback_default_max_tokens = 8192 enable_streaming = true timeout_seconds = 60 "#; - + fs::write(&config_path, config_content).unwrap(); - + // Load the configuration let config = Config::load(Some(config_path.to_str().unwrap())).unwrap(); - + // Test that the providers are correctly identified assert_eq!(config.providers.default_provider, "databricks"); 
assert_eq!(config.get_coach_provider(), "anthropic"); assert_eq!(config.get_player_provider(), "embedded"); - + // Test creating coach config let coach_config = config.for_coach().unwrap(); assert_eq!(coach_config.providers.default_provider, "anthropic"); - + // Test creating player config let player_config = config.for_player().unwrap(); assert_eq!(player_config.providers.default_provider, "embedded"); } - + #[test] fn test_coach_player_fallback_to_default() { // Create a temporary directory for the test config let temp_dir = TempDir::new().unwrap(); let config_path = temp_dir.path().join("test_config.toml"); - + // Write a test configuration WITHOUT coach and player providers let config_content = r#" [providers] @@ -76,31 +76,31 @@ fallback_default_max_tokens = 8192 enable_streaming = true timeout_seconds = 60 "#; - + fs::write(&config_path, config_content).unwrap(); - + // Load the configuration let config = Config::load(Some(config_path.to_str().unwrap())).unwrap(); - + // Test that coach and player fall back to default provider assert_eq!(config.get_coach_provider(), "databricks"); assert_eq!(config.get_player_provider(), "databricks"); - + // Test creating coach config (should use default) let coach_config = config.for_coach().unwrap(); assert_eq!(coach_config.providers.default_provider, "databricks"); - + // Test creating player config (should use default) let player_config = config.for_player().unwrap(); assert_eq!(player_config.providers.default_provider, "databricks"); } - + #[test] fn test_invalid_provider_error() { // Create a temporary directory for the test config let temp_dir = TempDir::new().unwrap(); let config_path = temp_dir.path().join("test_config.toml"); - + // Write a test configuration with an unconfigured provider let config_content = r#" [providers] @@ -117,15 +117,15 @@ fallback_default_max_tokens = 8192 enable_streaming = true timeout_seconds = 60 "#; - + fs::write(&config_path, config_content).unwrap(); - + // Load the configuration let 
config = Config::load(Some(config_path.to_str().unwrap())).unwrap(); - + // Test that trying to create a coach config with unconfigured provider fails let result = config.for_coach(); assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("not configured")); } -} \ No newline at end of file +} diff --git a/crates/g3-config/tests/test_multiple_tool_calls.rs b/crates/g3-config/tests/test_multiple_tool_calls.rs index 53ce8a3..0f774e2 100644 --- a/crates/g3-config/tests/test_multiple_tool_calls.rs +++ b/crates/g3-config/tests/test_multiple_tool_calls.rs @@ -1,20 +1,20 @@ #[cfg(test)] mod test_multiple_tool_calls { - use g3_config::{Config, AgentConfig}; - + use g3_config::{AgentConfig, Config}; + #[test] fn test_config_has_multiple_tool_calls_field() { let config = Config::default(); - + // Test that the field exists and defaults to false assert_eq!(config.agent.allow_multiple_tool_calls, false); - + // Test that we can create a config with the field set to true let mut custom_config = Config::default(); custom_config.agent.allow_multiple_tool_calls = true; assert_eq!(custom_config.agent.allow_multiple_tool_calls, true); } - + #[test] fn test_agent_config_serialization() { let agent_config = AgentConfig { @@ -28,11 +28,11 @@ mod test_multiple_tool_calls { autonomous_max_retry_attempts: 6, check_todo_staleness: true, }; - + // Test serialization let json = serde_json::to_string(&agent_config).unwrap(); assert!(json.contains("\"allow_multiple_tool_calls\":true")); - + // Test deserialization let deserialized: AgentConfig = serde_json::from_str(&json).unwrap(); assert_eq!(deserialized.allow_multiple_tool_calls, true); diff --git a/crates/g3-console/examples/debug_detector.rs b/crates/g3-console/examples/debug_detector.rs index aefa22d..ea6e1d3 100644 --- a/crates/g3-console/examples/debug_detector.rs +++ b/crates/g3-console/examples/debug_detector.rs @@ -1,19 +1,19 @@ -use sysinfo::{System, Pid}; +use sysinfo::{Pid, System}; fn main() { let mut sys = 
System::new_all(); sys.refresh_processes(); - + println!("Looking for g3 processes..."); - + for (pid, process) in sys.processes() { let cmd = process.cmd(); if cmd.is_empty() { continue; } - + let cmd_str = cmd.join(" "); - + // Check if this contains 'g3' if cmd_str.contains("g3") { println!("\nFound potential g3 process:"); @@ -21,15 +21,15 @@ fn main() { println!(" Name: {}", process.name()); println!(" Cmd[0]: {:?}", cmd.get(0)); println!(" Full cmd: {:?}", cmd); - + // Check detection logic let is_g3_binary = cmd.get(0).map(|s| s.ends_with("g3")).unwrap_or(false); let is_cargo_run = cmd.get(0).map(|s| s.contains("cargo")).unwrap_or(false) && cmd.iter().any(|s| s == "run" || s.contains("g3")); - + println!(" is_g3_binary: {}", is_g3_binary); println!(" is_cargo_run: {}", is_cargo_run); - + // Check workspace let has_workspace = cmd.iter().any(|s| s == "--workspace" || s == "-w"); println!(" has_workspace: {}", has_workspace); diff --git a/crates/g3-console/examples/test_api.rs b/crates/g3-console/examples/test_api.rs index 1cf1bb3..3b06936 100644 --- a/crates/g3-console/examples/test_api.rs +++ b/crates/g3-console/examples/test_api.rs @@ -3,13 +3,15 @@ use g3_console::process::ProcessDetector; fn main() { let mut detector = ProcessDetector::new(); - + match detector.detect_instances() { Ok(instances) => { println!("Found {} instances:", instances.len()); for instance in instances { - println!(" - PID: {}, Workspace: {:?}, Type: {:?}", - instance.pid, instance.workspace, instance.instance_type); + println!( + " - PID: {}, Workspace: {:?}, Type: {:?}", + instance.pid, instance.workspace, instance.instance_type + ); } } Err(e) => { diff --git a/crates/g3-console/examples/test_detector.rs b/crates/g3-console/examples/test_detector.rs index 6e4e74d..b7ac3b1 100644 --- a/crates/g3-console/examples/test_detector.rs +++ b/crates/g3-console/examples/test_detector.rs @@ -1,12 +1,12 @@ -use sysinfo::{System, Pid}; +use sysinfo::{Pid, System}; fn main() { let mut sys = 
System::new_all(); sys.refresh_processes(); - + // Test with known PIDs let pids = vec![68123, 72749]; - + for pid_num in pids { let pid = Pid::from_u32(pid_num); if let Some(process) = sys.process(pid) { diff --git a/crates/g3-console/src/api/control.rs b/crates/g3-console/src/api/control.rs index 384198e..f5eb489 100644 --- a/crates/g3-console/src/api/control.rs +++ b/crates/g3-console/src/api/control.rs @@ -19,7 +19,7 @@ pub async fn kill_instance( .ok_or(StatusCode::BAD_REQUEST)?; let mut controller = controller.lock().await; - + match controller.kill_process(pid) { Ok(_) => { info!("Successfully killed process {}", pid); @@ -39,35 +39,38 @@ pub async fn restart_instance( axum::extract::Path(id): axum::extract::Path, ) -> Result, StatusCode> { info!("Restarting instance: {}", id); - + // Extract PID from instance ID (format: pid_timestamp) let pid: u32 = id .split('_') .next() .and_then(|s| s.parse().ok()) .ok_or(StatusCode::BAD_REQUEST)?; - + let mut controller = controller.lock().await; - + // Get stored launch params - let params = controller.get_launch_params(pid) + let params = controller + .get_launch_params(pid) .ok_or(StatusCode::NOT_FOUND)?; - + // Launch new instance with same parameters - let new_pid = controller.launch_g3( - params.workspace.to_str().unwrap(), - ¶ms.provider, - ¶ms.model, - ¶ms.prompt, - params.autonomous, - params.g3_binary_path.as_deref(), - ).map_err(|e| { - error!("Failed to restart instance: {}", e); - StatusCode::INTERNAL_SERVER_ERROR - })?; - + let new_pid = controller + .launch_g3( + params.workspace.to_str().unwrap(), + ¶ms.provider, + ¶ms.model, + ¶ms.prompt, + params.autonomous, + params.g3_binary_path.as_deref(), + ) + .map_err(|e| { + error!("Failed to restart instance: {}", e); + StatusCode::INTERNAL_SERVER_ERROR + })?; + let new_id = format!("{}_{}", new_pid, chrono::Utc::now().timestamp()); - + Ok(Json(LaunchResponse { id: new_id, status: "starting".to_string(), @@ -79,7 +82,7 @@ pub async fn launch_instance( 
Json(request): Json, ) -> Result, (StatusCode, Json)> { info!("Launching new g3 instance: {:?}", request); - + // Validate binary path if provided if let Some(ref binary_path) = request.g3_binary_path { // Expand relative paths and resolve to absolute @@ -90,16 +93,19 @@ pub async fn launch_instance( } else { std::path::PathBuf::from(binary_path) }; - + // Check if file exists if !path.exists() { error!("G3 binary not found: {}", binary_path); - return Err((StatusCode::BAD_REQUEST, Json(serde_json::json!({ - "error": "G3 binary not found", - "message": format!("The specified g3 binary does not exist: {}", binary_path) - })))); + return Err(( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "G3 binary not found", + "message": format!("The specified g3 binary does not exist: {}", binary_path) + })), + )); } - + // Check if file is executable (Unix only) #[cfg(unix)] { @@ -107,26 +113,32 @@ pub async fn launch_instance( if let Ok(metadata) = std::fs::metadata(path) { if metadata.permissions().mode() & 0o111 == 0 { error!("G3 binary is not executable: {}", binary_path); - return Err((StatusCode::BAD_REQUEST, Json(serde_json::json!({ - "error": "G3 binary is not executable", - "message": format!("The specified g3 binary is not executable: {}", binary_path) - })))); + return Err(( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "G3 binary is not executable", + "message": format!("The specified g3 binary is not executable: {}", binary_path) + })), + )); } } } } - + let workspace = request.workspace.to_str().ok_or_else(|| { - (StatusCode::BAD_REQUEST, Json(serde_json::json!({ - "error": "Invalid workspace path", - "message": "The workspace path contains invalid characters" - }))) + ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "Invalid workspace path", + "message": "The workspace path contains invalid characters" + })), + ) })?; let autonomous = request.mode == LaunchMode::Ensemble; let g3_binary_path = 
request.g3_binary_path.as_deref(); - + let mut controller = controller.lock().await; - + match controller.launch_g3( workspace, &request.provider, @@ -145,10 +157,13 @@ pub async fn launch_instance( } Err(e) => { error!("Failed to launch g3 instance: {}", e); - Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ - "error": "Failed to launch instance", - "message": format!("Error: {}", e) - })))) + Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": "Failed to launch instance", + "message": format!("Error: {}", e) + })), + )) } } } diff --git a/crates/g3-console/src/api/instances.rs b/crates/g3-console/src/api/instances.rs index cd29521..a5a565c 100644 --- a/crates/g3-console/src/api/instances.rs +++ b/crates/g3-console/src/api/instances.rs @@ -1,7 +1,11 @@ use crate::logs::{LogParser, StatsAggregator}; use crate::models::*; use crate::process::ProcessDetector; -use axum::{extract::{Query, State}, http::StatusCode, Json}; +use axum::{ + extract::{Query, State}, + http::StatusCode, + Json, +}; use serde::Deserialize; use std::sync::Arc; use tokio::sync::Mutex; @@ -13,11 +17,11 @@ pub async fn list_instances( State(detector): State, ) -> Result>, StatusCode> { let mut detector = detector.lock().await; - + match detector.detect_instances() { Ok(instances) => { let mut details = Vec::new(); - + for instance in instances { match get_instance_detail(&instance) { Ok(detail) => details.push(detail), @@ -27,7 +31,7 @@ pub async fn list_instances( } } } - + Ok(Json(details)) } Err(e) => { @@ -42,7 +46,7 @@ pub async fn get_instance( axum::extract::Path(id): axum::extract::Path, ) -> Result, StatusCode> { let mut detector = detector.lock().await; - + match detector.detect_instances() { Ok(instances) => { if let Some(instance) = instances.into_iter().find(|i| i.id == id) { @@ -69,30 +73,36 @@ fn get_instance_detail(instance: &Instance) -> anyhow::Result { let log_entries = match LogParser::parse_logs(&instance.workspace) { Ok(entries) 
=> entries, Err(e) => { - warn!("Failed to parse logs for instance {}: {}. Instance may be newly started.", instance.id, e); + warn!( + "Failed to parse logs for instance {}: {}. Instance may be newly started.", + instance.id, e + ); Vec::new() } }; - + // Aggregate stats let is_ensemble = instance.instance_type == crate::models::InstanceType::Ensemble; let stats = StatsAggregator::aggregate_stats(&log_entries, instance.start_time, is_ensemble); - + // Get latest message let latest_message = StatsAggregator::get_latest_message(&log_entries); - + // Get git status - don't fail if not a git repo let git_status = match get_git_status(&instance.workspace) { Some(status) => Some(status), None => { - debug!("No git status available for workspace: {:?}", instance.workspace); + debug!( + "No git status available for workspace: {:?}", + instance.workspace + ); None } }; - + // Get project files let project_files = get_project_files(&instance.workspace); - + Ok(InstanceDetail { instance: instance.clone(), stats, @@ -104,7 +114,7 @@ fn get_instance_detail(instance: &Instance) -> anyhow::Result { fn get_git_status(workspace: &std::path::Path) -> Option { use std::process::Command; - + // Get current branch let branch = Command::new("git") .arg("-C") @@ -115,7 +125,7 @@ fn get_git_status(workspace: &std::path::Path) -> Option { .ok() .and_then(|output| String::from_utf8(output.stdout).ok()) .map(|s| s.trim().to_string())?; - + // Get status let status_output = Command::new("git") .arg("-C") @@ -125,19 +135,19 @@ fn get_git_status(workspace: &std::path::Path) -> Option { .output() .ok() .and_then(|output| String::from_utf8(output.stdout).ok())?; - + let mut modified_files = Vec::new(); let mut added_files = Vec::new(); let mut deleted_files = Vec::new(); - + for line in status_output.lines() { if line.len() < 4 { continue; } - + let status = &line[0..2]; let file = line[3..].trim(); - + match status.trim() { "M" | "MM" => modified_files.push(file.to_string()), "A" | "AM" => 
added_files.push(file.to_string()), @@ -145,9 +155,9 @@ fn get_git_status(workspace: &std::path::Path) -> Option { _ => modified_files.push(file.to_string()), } } - + let uncommitted_changes = modified_files.len() + added_files.len() + deleted_files.len(); - + Some(GitStatus { branch, uncommitted_changes, @@ -161,7 +171,7 @@ fn get_project_files(workspace: &std::path::Path) -> ProjectFiles { let requirements = read_file_snippet(workspace, "requirements.md"); let readme = read_file_snippet(workspace, "README.md"); let agents = read_file_snippet(workspace, "AGENTS.md"); - + ProjectFiles { requirements, readme, @@ -171,22 +181,16 @@ fn get_project_files(workspace: &std::path::Path) -> ProjectFiles { fn read_file_snippet(workspace: &std::path::Path, filename: &str) -> Option { use std::fs; - + let path = workspace.join(filename); if !path.exists() { return None; } - - fs::read_to_string(&path) - .ok() - .map(|content| { - // Return first 10 lines - content - .lines() - .take(10) - .collect::>() - .join("\n") - }) + + fs::read_to_string(&path).ok().map(|content| { + // Return first 10 lines + content.lines().take(10).collect::>().join("\n") + }) } #[derive(Deserialize)] @@ -200,20 +204,25 @@ pub async fn get_file_content( State(detector): State, ) -> Result, StatusCode> { let mut detector = detector.lock().await; - + // Find the instance - let instances = detector.detect_instances().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - let instance = instances.iter().find(|i| i.id == id).ok_or(StatusCode::NOT_FOUND)?; - + let instances = detector + .detect_instances() + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let instance = instances + .iter() + .find(|i| i.id == id) + .ok_or(StatusCode::NOT_FOUND)?; + // Read the full file let file_path = instance.workspace.join(&query.name); if !file_path.exists() { return Err(StatusCode::NOT_FOUND); } - - let content = std::fs::read_to_string(&file_path) - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - + + let content = 
+ std::fs::read_to_string(&file_path).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + Ok(Json(serde_json::json!({ "name": query.name, "content": content, diff --git a/crates/g3-console/src/api/logs.rs b/crates/g3-console/src/api/logs.rs index 4001bbc..088f4aa 100644 --- a/crates/g3-console/src/api/logs.rs +++ b/crates/g3-console/src/api/logs.rs @@ -12,7 +12,7 @@ pub async fn get_instance_logs( axum::extract::Path(id): axum::extract::Path, ) -> Result, StatusCode> { let mut detector = detector.lock().await; - + match detector.detect_instances() { Ok(instances) => { if let Some(instance) = instances.into_iter().find(|i| i.id == id) { @@ -20,7 +20,7 @@ pub async fn get_instance_logs( Ok(entries) => { let messages = LogParser::extract_chat_messages(&entries); let tool_calls = LogParser::extract_tool_calls(&entries); - + Ok(Json(serde_json::json!({ "messages": messages, "tool_calls": tool_calls, diff --git a/crates/g3-console/src/api/mod.rs b/crates/g3-console/src/api/mod.rs index eb5ac59..80e5599 100644 --- a/crates/g3-console/src/api/mod.rs +++ b/crates/g3-console/src/api/mod.rs @@ -1,4 +1,4 @@ -pub mod instances; pub mod control; +pub mod instances; pub mod logs; pub mod state; diff --git a/crates/g3-console/src/api/state.rs b/crates/g3-console/src/api/state.rs index 57aa627..fc31a19 100644 --- a/crates/g3-console/src/api/state.rs +++ b/crates/g3-console/src/api/state.rs @@ -1,8 +1,8 @@ use crate::launch::ConsoleState; use axum::{http::StatusCode, Json}; use serde::{Deserialize, Serialize}; -use std::path::PathBuf; use std::os::unix::fs::PermissionsExt; +use std::path::PathBuf; use tracing::{error, info}; pub async fn get_state() -> Result, StatusCode> { @@ -52,24 +52,26 @@ pub async fn browse_filesystem( Json(request): Json, ) -> Result, StatusCode> { use std::fs; - + let path = if let Some(p) = request.path { PathBuf::from(p) } else { std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")) }; - - let current_path = path.canonicalize() + + let current_path 
= path + .canonicalize() .map_err(|_| StatusCode::BAD_REQUEST)? .to_string_lossy() .to_string(); - - let parent_path = path.parent() + + let parent_path = path + .parent() .and_then(|p| p.to_str()) .map(|s| s.to_string()); - + let mut entries = Vec::new(); - + if let Ok(read_dir) = fs::read_dir(&path) { for entry in read_dir.flatten() { if let Ok(metadata) = entry.metadata() { @@ -82,15 +84,13 @@ pub async fn browse_filesystem( } } } - - entries.sort_by(|a, b| { - match (a.is_dir, b.is_dir) { - (true, false) => std::cmp::Ordering::Less, - (false, true) => std::cmp::Ordering::Greater, - _ => a.name.cmp(&b.name), - } + + entries.sort_by(|a, b| match (a.is_dir, b.is_dir) { + (true, false) => std::cmp::Ordering::Less, + (false, true) => std::cmp::Ordering::Greater, + _ => a.name.cmp(&b.name), }); - + Ok(Json(BrowseResponse { current_path, parent_path, diff --git a/crates/g3-console/src/launch.rs b/crates/g3-console/src/launch.rs index c241903..cd46f4f 100644 --- a/crates/g3-console/src/launch.rs +++ b/crates/g3-console/src/launch.rs @@ -27,7 +27,7 @@ impl Default for ConsoleState { impl ConsoleState { pub fn load() -> Self { let config_path = Self::config_path(); - + if config_path.exists() { if let Ok(content) = fs::read_to_string(&config_path) { return serde_json::from_str(&content).unwrap_or_else(|e| { @@ -36,31 +36,29 @@ impl ConsoleState { }); } } - + Self::default() } - + pub fn save(&self) -> anyhow::Result<()> { let config_path = Self::config_path(); info!("Saving console state to: {:?}", config_path); - + // Create parent directory if it doesn't exist if let Some(parent) = config_path.parent() { fs::create_dir_all(parent)?; } - + let content = serde_json::to_string_pretty(self)?; fs::write(&config_path, content)?; info!("Console state saved successfully to: {:?}", config_path); - + Ok(()) } - + fn config_path() -> PathBuf { // Use explicit ~/.config/g3/console.json path as per requirements let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from(".")); - 
home.join(".config") - .join("g3") - .join("console.json") + home.join(".config").join("g3").join("console.json") } } diff --git a/crates/g3-console/src/lib.rs b/crates/g3-console/src/lib.rs index aa6f18a..5a6ee14 100644 --- a/crates/g3-console/src/lib.rs +++ b/crates/g3-console/src/lib.rs @@ -1,5 +1,5 @@ pub mod api; +pub mod launch; pub mod logs; pub mod models; pub mod process; -pub mod launch; diff --git a/crates/g3-console/src/logs.rs b/crates/g3-console/src/logs.rs index 59207e7..1330e90 100644 --- a/crates/g3-console/src/logs.rs +++ b/crates/g3-console/src/logs.rs @@ -36,7 +36,7 @@ impl LogParser { /// Parse logs from a workspace directory pub fn parse_logs(workspace: &Path) -> Result> { let logs_dir = workspace.join("logs"); - + if !logs_dir.exists() { return Ok(Vec::new()); } @@ -47,7 +47,7 @@ impl LogParser { for entry in fs::read_dir(&logs_dir).context("Failed to read logs directory")? { let entry = entry?; let path = entry.path(); - + if path.extension().and_then(|s| s.to_str()) == Some("json") { if let Ok(content) = fs::read_to_string(&path) { if let Ok(json) = serde_json::from_str::(&content) { @@ -55,17 +55,21 @@ impl LogParser { if let Some(messages) = json.get("messages").and_then(|m| m.as_array()) { for msg in messages { entries.push(LogEntry { - timestamp: msg.get("timestamp") + timestamp: msg + .get("timestamp") .and_then(|t| t.as_str()) .and_then(|s| DateTime::parse_from_rfc3339(s).ok()) .map(|dt| dt.with_timezone(&Utc)), - role: msg.get("role") + role: msg + .get("role") .and_then(|r| r.as_str()) .map(String::from), - content: msg.get("content") + content: msg + .get("content") .and_then(|c| c.as_str()) .map(String::from), - tool_calls: msg.get("tool_calls") + tool_calls: msg + .get("tool_calls") .and_then(|tc| tc.as_array()) .map(|arr| arr.clone()), raw: msg.clone(), @@ -78,13 +82,11 @@ impl LogParser { } // Sort by timestamp - entries.sort_by(|a, b| { - match (&a.timestamp, &b.timestamp) { - (Some(t1), Some(t2)) => t1.cmp(t2), - (Some(_), 
None) => std::cmp::Ordering::Less, - (None, Some(_)) => std::cmp::Ordering::Greater, - (None, None) => std::cmp::Ordering::Equal, - } + entries.sort_by(|a, b| match (&a.timestamp, &b.timestamp) { + (Some(t1), Some(t2)) => t1.cmp(t2), + (Some(_), None) => std::cmp::Ordering::Less, + (None, Some(_)) => std::cmp::Ordering::Greater, + (None, None) => std::cmp::Ordering::Equal, }); Ok(entries) @@ -97,7 +99,7 @@ impl LogParser { .filter_map(|entry| { let role = entry.role.clone()?; let content = entry.content.clone()?; - + Some(ChatMessage { role, content, @@ -117,10 +119,12 @@ impl LogParser { if let Some(name) = call.get("name").and_then(|n| n.as_str()) { tool_calls.push(ToolCall { name: name.to_string(), - parameters: call.get("parameters") + parameters: call + .get("parameters") .cloned() .unwrap_or(Value::Object(serde_json::Map::new())), - result: call.get("result") + result: call + .get("result") .and_then(|r| r.as_str()) .map(String::from), timestamp: entry.timestamp, @@ -146,7 +150,7 @@ impl StatsAggregator { let total_tokens = Self::count_tokens(entries); let tool_calls = Self::count_tool_calls(entries); let errors = Self::count_errors(entries); - + let duration_secs = if let Some(last_entry) = entries.last() { if let Some(last_time) = last_entry.timestamp { (last_time - start_time).num_seconds().max(0) as u64 @@ -193,7 +197,9 @@ impl StatsAggregator { entries .iter() .filter_map(|entry| { - entry.raw.get("usage") + entry + .raw + .get("usage") .and_then(|u| u.get("total_tokens")) .and_then(|t| t.as_u64()) }) @@ -213,7 +219,11 @@ impl StatsAggregator { .iter() .filter(|entry| { entry.raw.get("error").is_some() - || entry.content.as_ref().map(|c| c.to_lowercase().contains("error")).unwrap_or(false) + || entry + .content + .as_ref() + .map(|c| c.to_lowercase().contains("error")) + .unwrap_or(false) }) .count() as u64 } diff --git a/crates/g3-console/src/main.rs b/crates/g3-console/src/main.rs index f7af8f7..182a316 100644 --- a/crates/g3-console/src/main.rs +++ 
b/crates/g3-console/src/main.rs @@ -1,11 +1,11 @@ use g3_console::api; -use g3_console::process; use g3_console::launch; +use g3_console::process; use api::control::{kill_instance, launch_instance, restart_instance}; -use api::instances::{get_instance, get_file_content, list_instances}; +use api::instances::{get_file_content, get_instance, list_instances}; use api::logs::get_instance_logs; -use api::state::{get_state, save_state, browse_filesystem}; +use api::state::{browse_filesystem, get_state, save_state}; use axum::{ routing::{get, post}, Router, @@ -39,9 +39,7 @@ struct Args { #[tokio::main] async fn main() -> anyhow::Result<()> { // Initialize tracing - tracing_subscriber::fmt() - .with_max_level(Level::INFO) - .init(); + tracing_subscriber::fmt().with_max_level(Level::INFO).init(); let args = Args::parse(); diff --git a/crates/g3-console/src/models/instance.rs b/crates/g3-console/src/models/instance.rs index f2dd634..c924637 100644 --- a/crates/g3-console/src/models/instance.rs +++ b/crates/g3-console/src/models/instance.rs @@ -1,6 +1,6 @@ +use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use std::path::PathBuf; -use chrono::{DateTime, Utc}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Instance { diff --git a/crates/g3-console/src/models/message.rs b/crates/g3-console/src/models/message.rs index 5a2b93a..0532086 100644 --- a/crates/g3-console/src/models/message.rs +++ b/crates/g3-console/src/models/message.rs @@ -1,5 +1,5 @@ -use serde::{Deserialize, Serialize}; use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatMessage { diff --git a/crates/g3-console/src/process/controller.rs b/crates/g3-console/src/process/controller.rs index a402370..3a1f8f9 100644 --- a/crates/g3-console/src/process/controller.rs +++ b/crates/g3-console/src/process/controller.rs @@ -1,12 +1,12 @@ -use anyhow::{anyhow, Context, Result}; -use std::process::{Command, Stdio}; -use 
std::os::unix::process::CommandExt; -use std::collections::HashMap; -use std::sync::Mutex; -use std::path::PathBuf; -use sysinfo::{Pid, Signal, System, Process}; -use tracing::{debug, info}; use crate::models::LaunchParams; +use anyhow::{anyhow, Context, Result}; +use std::collections::HashMap; +use std::os::unix::process::CommandExt; +use std::path::PathBuf; +use std::process::{Command, Stdio}; +use std::sync::Mutex; +use sysinfo::{Pid, Process, Signal, System}; +use tracing::{debug, info}; pub struct ProcessController { system: System, @@ -27,15 +27,15 @@ impl ProcessController { if let Some(process) = self.system.process(sysinfo_pid) { info!("Killing process {} ({})", pid, process.name()); - + // Try SIGTERM first if process.kill_with(Signal::Term).is_some() { debug!("Sent SIGTERM to process {}", pid); - + // Wait a bit and check if it's still running std::thread::sleep(std::time::Duration::from_secs(2)); self.system.refresh_processes(); - + if self.system.process(sysinfo_pid).is_some() { // Still running, send SIGKILL if let Some(proc) = self.system.process(sysinfo_pid) { @@ -43,7 +43,7 @@ impl ProcessController { debug!("Sent SIGKILL to process {}", pid); } } - + Ok(()) } else { Err(anyhow!("Failed to send signal to process {}", pid)) @@ -64,7 +64,7 @@ impl ProcessController { g3_binary_path: Option<&str>, ) -> Result { let binary = g3_binary_path.unwrap_or("g3"); - + let mut cmd = Command::new(binary); cmd.arg("--workspace") .arg(workspace) @@ -108,36 +108,41 @@ impl ProcessController { } info!("Launching g3: {:?}", cmd); - + // Spawn and wait for the intermediate process to exit let mut child = cmd.spawn().context("Failed to spawn g3 process")?; let intermediate_pid = child.id(); - + // Wait for intermediate process (it will exit immediately after forking) - child.wait().context("Failed to wait for intermediate process")?; - + child + .wait() + .context("Failed to wait for intermediate process")?; + // The actual g3 process is now running as orphan // We 
need to scan for it by matching workspace and recent start time - info!("Scanning for newly launched g3 process in workspace: {}", workspace); - + info!( + "Scanning for newly launched g3 process in workspace: {}", + workspace + ); + // Wait even longer for the process to fully start and appear in process list std::thread::sleep(std::time::Duration::from_millis(2500)); - + // Refresh and scan for the process self.system.refresh_processes(); let workspace_path = PathBuf::from(workspace); let mut found_pid = None; - + for (pid, process) in self.system.processes() { let cmd = process.cmd(); let cmd_str = cmd.join(" "); - + // Check if this is a g3 process let is_g3 = process.name().contains("g3") || cmd_str.contains("g3"); if !is_g3 { continue; } - + // Check if it has our workspace let has_workspace = cmd.iter().any(|arg| { if let Ok(path) = PathBuf::from(arg).canonicalize() { @@ -147,11 +152,12 @@ impl ProcessController { } false }); - + if has_workspace { // Check if it's recent (started within last 10 seconds) let now = std::time::SystemTime::now(); - let start_time = std::time::UNIX_EPOCH + std::time::Duration::from_secs(process.start_time()); + let start_time = + std::time::UNIX_EPOCH + std::time::Duration::from_secs(process.start_time()); if let Ok(duration) = now.duration_since(start_time) { if duration.as_secs() < 10 { found_pid = Some(pid.as_u32()); @@ -160,7 +166,7 @@ impl ProcessController { } } } - + let pid = if let Some(found) = found_pid { found } else { @@ -168,18 +174,18 @@ impl ProcessController { info!("Process not found on first scan, trying again..."); std::thread::sleep(std::time::Duration::from_millis(2000)); self.system.refresh_processes(); - + // Try the scan again with full logic let mut retry_found = None; for (pid, process) in self.system.processes() { let cmd = process.cmd(); let cmd_str = cmd.join(" "); - + let is_g3 = process.name().contains("g3") || cmd_str.contains("g3"); if !is_g3 { continue; } - + let has_workspace = 
cmd.iter().any(|arg| { if let Ok(path) = PathBuf::from(arg).canonicalize() { if let Ok(ws) = workspace_path.canonicalize() { @@ -188,18 +194,18 @@ impl ProcessController { } false }); - + if has_workspace { retry_found = Some(pid.as_u32()); break; } } - + retry_found.unwrap_or(intermediate_pid) }; info!("Launched g3 process with PID {}", pid); - + // Store launch params for restart let params = LaunchParams { workspace: workspace.into(), @@ -209,14 +215,14 @@ impl ProcessController { autonomous, g3_binary_path: g3_binary_path.map(|s| s.to_string()), }; - + if let Ok(mut map) = self.launch_params.lock() { map.insert(pid, params); } - + Ok(pid) } - + pub fn get_launch_params(&mut self, pid: u32) -> Option { // First check if we have stored params (for console-launched instances) if let Ok(map) = self.launch_params.lock() { @@ -224,19 +230,19 @@ impl ProcessController { return Some(params.clone()); } } - + // If not found, try to parse from process command line (for detected instances) self.system.refresh_processes(); let sysinfo_pid = Pid::from_u32(pid); - + if let Some(process) = self.system.process(sysinfo_pid) { let cmd = process.cmd(); return self.parse_launch_params_from_cmd(cmd); } - + None } - + fn parse_launch_params_from_cmd(&self, cmd: &[String]) -> Option { let mut workspace = None; let mut provider = None; @@ -244,7 +250,7 @@ impl ProcessController { let mut prompt = None; let mut autonomous = false; let mut g3_binary_path = None; - + let mut i = 0; while i < cmd.len() { match cmd[i].as_str() { @@ -273,7 +279,7 @@ impl ProcessController { } } } - + // Try to determine binary path from cmd[0] if !cmd.is_empty() { let first = &cmd[0]; @@ -281,9 +287,10 @@ impl ProcessController { g3_binary_path = Some(first.clone()); } } - + // Only return params if we have the minimum required fields - if let (Some(ws), Some(prov), Some(mdl), Some(prmt)) = (workspace, provider, model, prompt) { + if let (Some(ws), Some(prov), Some(mdl), Some(prmt)) = (workspace, provider, 
model, prompt) + { Some(LaunchParams { workspace: ws, provider: prov, diff --git a/crates/g3-console/src/process/detector.rs b/crates/g3-console/src/process/detector.rs index 9b488f7..43b7b5c 100644 --- a/crates/g3-console/src/process/detector.rs +++ b/crates/g3-console/src/process/detector.rs @@ -2,7 +2,7 @@ use crate::models::{ExecutionMethod, Instance, InstanceStatus, InstanceType}; use anyhow::Result; use chrono::{DateTime, Utc}; use std::path::PathBuf; -use sysinfo::{System, Pid, Process}; +use sysinfo::{Pid, Process, System}; use tracing::{debug, info, warn}; pub struct ProcessDetector { @@ -41,36 +41,37 @@ impl ProcessDetector { Ok(instances) } - fn parse_g3_process( - &self, - pid: Pid, - process: &Process, - cmd: &[String], - ) -> Option { + fn parse_g3_process(&self, pid: Pid, process: &Process, cmd: &[String]) -> Option { let cmd_str = cmd.join(" "); - + // Exclude g3-console itself if cmd_str.contains("g3-console") { return None; } - + // Check if this is a g3 binary (more comprehensive check) - let is_g3_binary = cmd.get(0).map(|s| { - (s.ends_with("g3") || s.ends_with("/g3") || s.contains("/target/release/g3") || s.contains("/target/debug/g3")) - && !s.contains("g3-") // Exclude other g3-* binaries - }).unwrap_or(false); - + let is_g3_binary = cmd + .get(0) + .map(|s| { + (s.ends_with("g3") + || s.ends_with("/g3") + || s.contains("/target/release/g3") + || s.contains("/target/debug/g3")) + && !s.contains("g3-") // Exclude other g3-* binaries + }) + .unwrap_or(false); + // Check if this is cargo run with g3 (not g3-console or other variants) - let is_cargo_run = cmd.get(0).map(|s| s.contains("cargo")).unwrap_or(false) + let is_cargo_run = cmd.get(0).map(|s| s.contains("cargo")).unwrap_or(false) && cmd.iter().any(|s| s == "run") && !cmd_str.contains("g3-console"); - + // Also check if command line has g3-specific flags let has_g3_flags = cmd_str.contains("--workspace") || cmd_str.contains("--autonomous"); - + // Accept if it's a g3 binary or cargo run 
with g3, and has typical g3 patterns let is_g3_process = is_g3_binary || (is_cargo_run && has_g3_flags); - + if !is_g3_process { return None; } @@ -97,8 +98,8 @@ impl ProcessDetector { let model = self.extract_flag_value(cmd, "--model"); // Get start time - let start_time = DateTime::from_timestamp(process.start_time() as i64, 0) - .unwrap_or_else(Utc::now); + let start_time = + DateTime::from_timestamp(process.start_time() as i64, 0).unwrap_or_else(Utc::now); // Generate instance ID from PID and start time let id = format!("{}_{}", pid, start_time.timestamp()); @@ -139,7 +140,7 @@ impl ProcessDetector { return Some(cwd); } } - + #[cfg(target_os = "macos")] { // On macOS, use lsof to get the current working directory @@ -156,9 +157,12 @@ impl ProcessDetector { } } } - + // Final fallback: use current directory of console - warn!("Could not determine workspace for PID {}, using current directory", pid); + warn!( + "Could not determine workspace for PID {}, using current directory", + pid + ); std::env::current_dir().ok() } @@ -173,7 +177,7 @@ impl ProcessDetector { pub fn get_process_status(&mut self, pid: u32) -> Option { self.system.refresh_all(); - + let sysinfo_pid = Pid::from_u32(pid); if self.system.process(sysinfo_pid).is_some() { Some(InstanceStatus::Running) diff --git a/crates/g3-console/src/process/mod.rs b/crates/g3-console/src/process/mod.rs index ccfeba6..0239b22 100644 --- a/crates/g3-console/src/process/mod.rs +++ b/crates/g3-console/src/process/mod.rs @@ -1,5 +1,5 @@ -pub mod detector; pub mod controller; +pub mod detector; -pub use detector::*; pub use controller::*; +pub use detector::*; diff --git a/crates/g3-core/examples/inspect_ast.rs b/crates/g3-core/examples/inspect_ast.rs index b4e7fff..58ba065 100644 --- a/crates/g3-core/examples/inspect_ast.rs +++ b/crates/g3-core/examples/inspect_ast.rs @@ -1,6 +1,6 @@ //! 
Inspect tree-sitter AST structure for Rust code -use tree_sitter::{Parser, Language}; +use tree_sitter::{Language, Parser}; fn print_tree(node: tree_sitter::Node, source: &str, indent: usize) { let indent_str = " ".repeat(indent); @@ -10,7 +10,7 @@ fn print_tree(node: tree_sitter::Node, source: &str, indent: usize) { } else { node_text.to_string() }; - + println!( "{}{} [{}:{}] '{}'", indent_str, @@ -19,7 +19,7 @@ fn print_tree(node: tree_sitter::Node, source: &str, indent: usize) { node.start_position().column + 1, preview.replace('\n', "\\n") ); - + let mut cursor = node.walk(); for child in node.children(&mut cursor) { print_tree(child, source, indent + 1); diff --git a/crates/g3-core/examples/inspect_python_ast.rs b/crates/g3-core/examples/inspect_python_ast.rs index 78675b1..3f4e8e8 100644 --- a/crates/g3-core/examples/inspect_python_ast.rs +++ b/crates/g3-core/examples/inspect_python_ast.rs @@ -1,6 +1,6 @@ //! Inspect tree-sitter AST structure for Python code -use tree_sitter::{Parser, Language}; +use tree_sitter::{Language, Parser}; fn print_tree(node: tree_sitter::Node, source: &str, indent: usize) { let indent_str = " ".repeat(indent); @@ -10,7 +10,7 @@ fn print_tree(node: tree_sitter::Node, source: &str, indent: usize) { } else { node_text.to_string() }; - + println!( "{}{} [{}:{}] '{}'", indent_str, @@ -19,7 +19,7 @@ fn print_tree(node: tree_sitter::Node, source: &str, indent: usize) { node.start_position().column + 1, preview.replace('\n', "\\n") ); - + let mut cursor = node.walk(); for child in node.children(&mut cursor) { print_tree(child, source, indent + 1); diff --git a/crates/g3-core/examples/test_python_query.rs b/crates/g3-core/examples/test_python_query.rs index 51220c3..c0eb865 100644 --- a/crates/g3-core/examples/test_python_query.rs +++ b/crates/g3-core/examples/test_python_query.rs @@ -1,7 +1,7 @@ //! 
Test Python async query -use tree_sitter::{Parser, Query, QueryCursor, Language}; use streaming_iterator::StreamingIterator; +use tree_sitter::{Language, Parser, Query, QueryCursor}; fn main() -> anyhow::Result<()> { let source_code = r#" diff --git a/crates/g3-core/src/code_search/searcher.rs b/crates/g3-core/src/code_search/searcher.rs index 6e2f913..b38cd72 100644 --- a/crates/g3-core/src/code_search/searcher.rs +++ b/crates/g3-core/src/code_search/searcher.rs @@ -3,8 +3,8 @@ use anyhow::{anyhow, Result}; use std::collections::HashMap; use std::fs; use std::path::Path; -use tree_sitter::{Language, Parser, Query, QueryCursor}; use streaming_iterator::StreamingIterator; +use tree_sitter::{Language, Parser, Query, QueryCursor}; use walkdir::WalkDir; pub struct TreeSitterSearcher { @@ -47,10 +47,11 @@ impl TreeSitterSearcher { .set_language(&language) .map_err(|e| anyhow!("Failed to set JavaScript language: {}", e))?; parsers.insert("javascript".to_string(), parser); - + // Create separate parser for "js" alias let mut parser_js = Parser::new(); - parser_js.set_language(&language) + parser_js + .set_language(&language) .map_err(|e| anyhow!("Failed to set JavaScript language: {}", e))?; parsers.insert("js".to_string(), parser_js); languages.insert("javascript".to_string(), language.clone()); @@ -65,10 +66,11 @@ impl TreeSitterSearcher { .set_language(&language) .map_err(|e| anyhow!("Failed to set TypeScript language: {}", e))?; parsers.insert("typescript".to_string(), parser); - + // Create separate parser for "ts" alias let mut parser_ts = Parser::new(); - parser_ts.set_language(&language) + parser_ts + .set_language(&language) .map_err(|e| anyhow!("Failed to set TypeScript language: {}", e))?; parsers.insert("ts".to_string(), parser_ts); languages.insert("typescript".to_string(), language.clone()); @@ -215,8 +217,8 @@ impl TreeSitterSearcher { .ok_or_else(|| anyhow!("Language not found: {}", spec.language))?; // Parse query - let query = Query::new(language, 
&spec.query) - .map_err(|e| anyhow!("Invalid query: {}", e))?; + let query = + Query::new(language, &spec.query).map_err(|e| anyhow!("Invalid query: {}", e))?; let mut matches = Vec::new(); let mut files_searched = 0; @@ -255,11 +257,8 @@ impl TreeSitterSearcher { if let Ok(source_code) = fs::read_to_string(path) { if let Some(tree) = parser.parse(&source_code, None) { let mut cursor = QueryCursor::new(); - let mut query_matches = cursor.matches( - &query, - tree.root_node(), - source_code.as_bytes(), - ); + let mut query_matches = + cursor.matches(&query, tree.root_node(), source_code.as_bytes()); query_matches.advance(); while let Some(query_match) = query_matches.get() { @@ -308,7 +307,7 @@ impl TreeSitterSearcher { captures: captures_map, context, }); - + query_matches.advance(); } } diff --git a/crates/g3-core/src/error_handling.rs b/crates/g3-core/src/error_handling.rs index 7646df6..c2aea36 100644 --- a/crates/g3-core/src/error_handling.rs +++ b/crates/g3-core/src/error_handling.rs @@ -106,15 +106,15 @@ impl ErrorContext { error!("Session ID: {:?}", self.session_id); error!("Context Tokens: {}", self.context_tokens); error!("Last Prompt: {}", self.last_prompt); - + if let Some(ref req) = self.raw_request { error!("Raw Request: {}", req); } - + if let Some(ref resp) = self.raw_response { error!("Raw Response: {}", resp); } - + error!("Stack Trace:\n{}", self.stack_trace); error!("=== END ERROR DETAILS ==="); @@ -191,23 +191,36 @@ pub fn classify_error(error: &anyhow::Error) -> ErrorType { let error_str = error.to_string().to_lowercase(); // Check for recoverable error patterns - if error_str.contains("rate limit") || error_str.contains("rate_limit") || error_str.contains("429") { + if error_str.contains("rate limit") + || error_str.contains("rate_limit") + || error_str.contains("429") + { return ErrorType::Recoverable(RecoverableError::RateLimit); } - if error_str.contains("network") || error_str.contains("connection") || - error_str.contains("dns") || 
error_str.contains("refused") { + if error_str.contains("network") + || error_str.contains("connection") + || error_str.contains("dns") + || error_str.contains("refused") + { return ErrorType::Recoverable(RecoverableError::NetworkError); } - if error_str.contains("500") || error_str.contains("502") || - error_str.contains("503") || error_str.contains("504") || - error_str.contains("server error") || error_str.contains("internal error") { + if error_str.contains("500") + || error_str.contains("502") + || error_str.contains("503") + || error_str.contains("504") + || error_str.contains("server error") + || error_str.contains("internal error") + { return ErrorType::Recoverable(RecoverableError::ServerError); } - if error_str.contains("busy") || error_str.contains("overloaded") || - error_str.contains("capacity") || error_str.contains("unavailable") { + if error_str.contains("busy") + || error_str.contains("overloaded") + || error_str.contains("capacity") + || error_str.contains("unavailable") + { return ErrorType::Recoverable(RecoverableError::ModelBusy); } @@ -216,18 +229,24 @@ pub fn classify_error(error: &anyhow::Error) -> ErrorType { error_str.contains("timed out") || error_str.contains("operation timed out") || error_str.contains("request or response body error") || // Common timeout pattern - error_str.contains("stream error") && error_str.contains("timed out") { + error_str.contains("stream error") && error_str.contains("timed out") + { return ErrorType::Recoverable(RecoverableError::Timeout); } // Check for context length exceeded errors (HTTP 400 with specific messages) - if (error_str.contains("400") || error_str.contains("bad request")) && - (error_str.contains("context length") || error_str.contains("prompt is too long") || - error_str.contains("maximum context length") || error_str.contains("context_length_exceeded")) { + if (error_str.contains("400") || error_str.contains("bad request")) + && (error_str.contains("context length") + || 
error_str.contains("prompt is too long") + || error_str.contains("maximum context length") + || error_str.contains("context_length_exceeded")) + { return ErrorType::Recoverable(RecoverableError::ContextLengthExceeded); } - if error_str.contains("token") && (error_str.contains("limit") || error_str.contains("exceeded")) { + if error_str.contains("token") + && (error_str.contains("limit") || error_str.contains("exceeded")) + { return ErrorType::Recoverable(RecoverableError::TokenLimit); } @@ -239,12 +258,14 @@ pub fn classify_error(error: &anyhow::Error) -> ErrorType { fn calculate_autonomous_retry_delay(attempt: u32) -> Duration { use rand::Rng; let mut rng = rand::thread_rng(); - + // Distribute 6 retries over 10 minutes (600 seconds) // Base delays: 10s, 30s, 60s, 120s, 180s, 200s = 600s total let base_delays_ms = [10000, 30000, 60000, 120000, 180000, 200000]; - let base_delay = base_delays_ms.get(attempt.saturating_sub(1) as usize).unwrap_or(&200000); - + let base_delay = base_delays_ms + .get(attempt.saturating_sub(1) as usize) + .unwrap_or(&200000); + // Add jitter of ±30% to prevent thundering herd let jitter = (*base_delay as f64 * 0.3 * rng.gen::()) as u64; let final_delay = if rng.gen_bool(0.5) { @@ -252,7 +273,7 @@ fn calculate_autonomous_retry_delay(attempt: u32) -> Duration { } else { base_delay.saturating_sub(jitter) }; - + Duration::from_millis(final_delay) } @@ -261,14 +282,18 @@ pub fn calculate_retry_delay(attempt: u32, is_autonomous: bool) -> Duration { if is_autonomous { return calculate_autonomous_retry_delay(attempt); } - + use rand::Rng; - let max_retry_delay_ms = if is_autonomous { AUTONOMOUS_MAX_RETRY_DELAY_MS } else { DEFAULT_MAX_RETRY_DELAY_MS }; - + let max_retry_delay_ms = if is_autonomous { + AUTONOMOUS_MAX_RETRY_DELAY_MS + } else { + DEFAULT_MAX_RETRY_DELAY_MS + }; + // Exponential backoff: delay = base * 2^attempt let base_delay = BASE_RETRY_DELAY_MS * (2_u64.pow(attempt.saturating_sub(1))); let capped_delay = 
base_delay.min(max_retry_delay_ms); - + // Add jitter to prevent thundering herd let mut rng = rand::thread_rng(); let jitter = (capped_delay as f64 * JITTER_FACTOR * rng.gen::()) as u64; @@ -277,7 +302,7 @@ pub fn calculate_retry_delay(attempt: u32, is_autonomous: bool) -> Duration { } else { capped_delay.saturating_sub(jitter) }; - + Duration::from_millis(final_delay) } @@ -298,7 +323,7 @@ where loop { attempt += 1; - + match operation().await { Ok(result) => { if attempt > 1 { @@ -321,19 +346,19 @@ where context.clone().log_error(&error); return Err(error); } - + let delay = calculate_retry_delay(attempt, is_autonomous); warn!( "Recoverable error ({:?}) in '{}' (attempt {}/{}). Retrying in {:?}...", recoverable_type, operation_name, attempt, max_attempts, delay ); warn!("Error details: {}", error); - + // Special handling for token limit errors if matches!(recoverable_type, RecoverableError::TokenLimit) { info!("Token limit error detected. Consider triggering summarization."); } - + tokio::time::sleep(delay).await; _last_error = Some(error); } @@ -359,18 +384,22 @@ fn truncate_for_logging(s: &str, max_len: usize) -> String { // Find a safe UTF-8 boundary to truncate at // We need to ensure we don't cut in the middle of a multi-byte character let mut truncate_at = max_len; - + // Walk backwards from max_len to find a character boundary while truncate_at > 0 && !s.is_char_boundary(truncate_at) { truncate_at -= 1; } - + // If we couldn't find a boundary (shouldn't happen), use a safe default if truncate_at == 0 { truncate_at = max_len.min(s.len()); } - - format!("{}... (truncated, {} total bytes)", &s[..truncate_at], s.len()) + + format!( + "{}... 
(truncated, {} total bytes)", + &s[..truncate_at], + s.len() + ) } } @@ -398,42 +427,69 @@ mod tests { fn test_error_classification() { // Rate limit errors let error = anyhow!("Rate limit exceeded"); - assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::RateLimit)); - + assert_eq!( + classify_error(&error), + ErrorType::Recoverable(RecoverableError::RateLimit) + ); + let error = anyhow!("HTTP 429 Too Many Requests"); - assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::RateLimit)); - + assert_eq!( + classify_error(&error), + ErrorType::Recoverable(RecoverableError::RateLimit) + ); + // Network errors let error = anyhow!("Network connection failed"); - assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::NetworkError)); - + assert_eq!( + classify_error(&error), + ErrorType::Recoverable(RecoverableError::NetworkError) + ); + // Server errors let error = anyhow!("HTTP 503 Service Unavailable"); - assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::ServerError)); - + assert_eq!( + classify_error(&error), + ErrorType::Recoverable(RecoverableError::ServerError) + ); + // Model busy let error = anyhow!("Model is busy, please try again"); - assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::ModelBusy)); - + assert_eq!( + classify_error(&error), + ErrorType::Recoverable(RecoverableError::ModelBusy) + ); + // Timeout let error = anyhow!("Request timed out"); - assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::Timeout)); - + assert_eq!( + classify_error(&error), + ErrorType::Recoverable(RecoverableError::Timeout) + ); + // Token limit let error = anyhow!("Token limit exceeded"); - assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::TokenLimit)); - + assert_eq!( + classify_error(&error), + ErrorType::Recoverable(RecoverableError::TokenLimit) + ); + // Context length exceeded let error = anyhow!("HTTP 400 Bad 
Request: context length exceeded"); - assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::ContextLengthExceeded)); - + assert_eq!( + classify_error(&error), + ErrorType::Recoverable(RecoverableError::ContextLengthExceeded) + ); + let error = anyhow!("Error 400: prompt is too long"); - assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::ContextLengthExceeded)); - + assert_eq!( + classify_error(&error), + ErrorType::Recoverable(RecoverableError::ContextLengthExceeded) + ); + // Non-recoverable let error = anyhow!("Invalid API key"); assert_eq!(classify_error(&error), ErrorType::NonRecoverable); - + let error = anyhow!("Malformed request"); assert_eq!(classify_error(&error), ErrorType::NonRecoverable); } @@ -444,17 +500,17 @@ mod tests { let delay1 = calculate_retry_delay(1, false); let delay2 = calculate_retry_delay(2, false); let delay3 = calculate_retry_delay(3, false); - + // Due to jitter, we can't test exact values, but the base should increase assert!(delay1.as_millis() >= (BASE_RETRY_DELAY_MS as f64 * 0.7) as u128); assert!(delay1.as_millis() <= (BASE_RETRY_DELAY_MS as f64 * 1.3) as u128); - + // Delay 2 should be roughly 2x delay 1 (minus jitter) assert!(delay2.as_millis() >= delay1.as_millis()); - + // Delay 3 should be roughly 2x delay 2 (minus jitter) assert!(delay3.as_millis() >= delay2.as_millis()); - + // Test max cap let delay_max = calculate_retry_delay(10, false); assert!(delay_max.as_millis() <= (DEFAULT_MAX_RETRY_DELAY_MS as f64 * 1.3) as u128); @@ -469,7 +525,7 @@ mod tests { let delay4 = calculate_retry_delay(4, true); let delay5 = calculate_retry_delay(5, true); let delay6 = calculate_retry_delay(6, true); - + // Base delays should be around: 10s, 30s, 60s, 120s, 180s, 200s // With ±30% jitter assert!(delay1.as_millis() >= 7000 && delay1.as_millis() <= 13000); @@ -484,14 +540,14 @@ mod tests { fn test_truncate_for_logging() { let short_text = "Hello, world!"; 
assert_eq!(truncate_for_logging(short_text, 20), "Hello, world!"); - + let long_text = "This is a very long text that should be truncated for logging purposes"; let truncated = truncate_for_logging(long_text, 20); assert!(truncated.starts_with("This is a very long ")); assert!(truncated.contains("truncated")); assert!(truncated.contains("total bytes")); } - + #[test] fn test_truncate_with_multibyte_chars() { // Test with multi-byte UTF-8 characters @@ -499,7 +555,7 @@ mod tests { let truncated = truncate_for_logging(text_with_emoji, 10); // Should truncate at a valid UTF-8 boundary assert!(truncated.starts_with("Hello ")); - + // Test with box-drawing characters like the one causing the panic let text_with_box = "Some text ┌─────┐ more text"; let truncated = truncate_for_logging(text_with_box, 12); diff --git a/crates/g3-core/src/error_handling_test.rs b/crates/g3-core/src/error_handling_test.rs index 488db08..9e26bad 100644 --- a/crates/g3-core/src/error_handling_test.rs +++ b/crates/g3-core/src/error_handling_test.rs @@ -17,7 +17,7 @@ mod tests { "test prompt".to_string(), None, 100, - false, // quiet parameter + false, // quiet parameter ); let result = retry_with_backoff( @@ -57,7 +57,7 @@ mod tests { "test prompt".to_string(), None, 100, - false, // quiet parameter + false, // quiet parameter ); let result: Result<&str, _> = retry_with_backoff( @@ -91,7 +91,7 @@ mod tests { "test prompt".to_string(), None, 100, - false, // quiet parameter + false, // quiet parameter ); let result: Result<&str, _> = retry_with_backoff( @@ -124,7 +124,7 @@ mod tests { long_prompt, None, 100, - false, // quiet parameter + false, // quiet parameter ); // The prompt should be truncated to 1000 chars diff --git a/crates/g3-core/src/fixed_filter_json.rs b/crates/g3-core/src/fixed_filter_json.rs index b86db42..dac5bfb 100644 --- a/crates/g3-core/src/fixed_filter_json.rs +++ b/crates/g3-core/src/fixed_filter_json.rs @@ -5,7 +5,7 @@ // 4. 
Return everything else as the final filtered string //! JSON tool call filtering for streaming LLM responses. -//! +//! //! This module filters out JSON tool calls from LLM output streams while preserving //! regular text content. It uses a state machine to handle streaming chunks. @@ -29,7 +29,7 @@ struct FixedJsonToolState { brace_depth: i32, buffer: String, json_start_in_buffer: Option, // Position where confirmed JSON tool call starts - content_returned_up_to: usize, // Track how much content we've already returned + content_returned_up_to: usize, // Track how much content we've already returned potential_json_start: Option, // Where the potential JSON started } diff --git a/crates/g3-core/src/fixed_filter_tests.rs b/crates/g3-core/src/fixed_filter_tests.rs index 39ae617..5d021fb 100644 --- a/crates/g3-core/src/fixed_filter_tests.rs +++ b/crates/g3-core/src/fixed_filter_tests.rs @@ -358,8 +358,8 @@ More text"#; // 2. Then the same complete JSON appears let chunks = vec![ "Some text\n", - r#"{"tool": "str_replace", "args": {"diff":"...","file_path":"./crates/g3-cli"#, // Truncated - r#"{"tool": "str_replace", "args": {"diff":"...","file_path":"./crates/g3-cli/src/lib.rs"}}"#, // Complete + r#"{"tool": "str_replace", "args": {"diff":"...","file_path":"./crates/g3-cli"#, // Truncated + r#"{"tool": "str_replace", "args": {"diff":"...","file_path":"./crates/g3-cli/src/lib.rs"}}"#, // Complete "\nMore text", ]; diff --git a/crates/g3-core/src/lib.rs b/crates/g3-core/src/lib.rs index a4ef797..07574e5 100644 --- a/crates/g3-core/src/lib.rs +++ b/crates/g3-core/src/lib.rs @@ -23,11 +23,12 @@ mod error_handling_test; mod prompts; use anyhow::Result; +use chrono::Local; use g3_computer_control::WebDriverController; use g3_config::Config; use g3_execution::CodeExecutor; use g3_providers::{CacheControl, CompletionRequest, Message, MessageRole, ProviderRegistry, Tool}; -use chrono::Local; +use prompts::{get_system_prompt_for_native, SYSTEM_PROMPT_FOR_NON_NATIVE_TOOL_USE}; 
#[allow(unused_imports)] use regex::Regex; use serde::{Deserialize, Serialize}; @@ -38,7 +39,6 @@ use std::sync::{Mutex, OnceLock}; use std::time::{Duration, Instant}; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, warn}; -use prompts::{SYSTEM_PROMPT_FOR_NON_NATIVE_TOOL_USE, get_system_prompt_for_native}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolCall { @@ -438,7 +438,10 @@ Format this as a detailed but concise summary that can be used to resume the con self.used_tokens = 0; // Add the summary as a system message - let summary_message = Message::new(MessageRole::System, format!("Previous conversation summary:\n\n{}", summary)); + let summary_message = Message::new( + MessageRole::System, + format!("Previous conversation summary:\n\n{}", summary), + ); self.add_message(summary_message); // Add the latest user message if provided @@ -503,21 +506,27 @@ Format this as a detailed but concise summary that can be used to resume the con let is_todo_result = if i > 0 { if let Some(prev_message) = self.conversation_history.get(i - 1) { if matches!(prev_message.role, MessageRole::Assistant) { - prev_message.content.contains(r#""tool":"todo_read""#) || - prev_message.content.contains(r#""tool":"todo_write""#) || - prev_message.content.contains(r#""tool": "todo_read""#) || - prev_message.content.contains(r#""tool": "todo_write""#) - } else { false } - } else { false } - } else { false }; - + prev_message.content.contains(r#""tool":"todo_read""#) + || prev_message.content.contains(r#""tool":"todo_write""#) + || prev_message.content.contains(r#""tool": "todo_read""#) + || prev_message.content.contains(r#""tool": "todo_write""#) + } else { + false + } + } else { + false + } + } else { + false + }; + if let Some(message) = self.conversation_history.get_mut(i) { // Process User messages that look like tool results if matches!(message.role, MessageRole::User) && message.content.starts_with("Tool result:") { let content_len = 
message.content.len(); - + // Only thin if the content is greater than 500 chars and not a TODO tool result if !is_todo_result && content_len > 500 { // Generate a unique filename based on timestamp and index @@ -744,9 +753,9 @@ Format this as a detailed but concise summary that can be used to resume the con pub struct Agent { providers: ProviderRegistry, context_window: ContextWindow, - thinning_events: Vec, // chars saved per thinning event - pending_90_summarization: bool, // flag to trigger summarization at 90% - auto_compact: bool, // whether to auto-compact at 90% before tool calls + thinning_events: Vec, // chars saved per thinning event + pending_90_summarization: bool, // flag to trigger summarization at 90% + auto_compact: bool, // whether to auto-compact at 90% before tool calls summarization_events: Vec, // chars saved per summarization event first_token_times: Vec, // time to first token for each completion config: Config, @@ -977,10 +986,10 @@ impl Agent { // For non-native providers (embedded models), use JSON format instructions SYSTEM_PROMPT_FOR_NON_NATIVE_TOOL_USE.to_string() }; - + let system_message = Message::new(MessageRole::System, system_prompt); context_window.add_message(system_message); - + // If README content is provided, add it as a second system message (after the main system prompt) if let Some(readme) = readme_content { let readme_message = Message::new(MessageRole::System, readme); @@ -1017,10 +1026,8 @@ impl Agent { ui_writer, todo_content: std::sync::Arc::new(tokio::sync::RwLock::new({ // Initialize from TODO.md file if it exists - let todo_path = std::env::current_dir() - .ok() - .map(|p| p.join("todo.g3.md")); - + let todo_path = std::env::current_dir().ok().map(|p| p.join("todo.g3.md")); + if let Some(path) = todo_path { std::fs::read_to_string(&path).unwrap_or_default() } else { @@ -1047,7 +1054,7 @@ impl Agent { /// Validate that the system prompt is the first message in the conversation history. 
/// This is a critical invariant that must be maintained for proper agent operation. - /// + /// /// # Panics /// Panics if: /// - The conversation history is empty @@ -1061,7 +1068,7 @@ impl Agent { } let first_message = &self.context_window.conversation_history[0]; - + if !matches!(first_message.role, MessageRole::System) { panic!( "FATAL: First message is not a System message. Found: {:?}", @@ -1081,7 +1088,10 @@ impl Agent { "5minute" => Some(CacheControl::five_minute()), "1hour" => Some(CacheControl::one_hour()), _ => { - warn!("Invalid cache_config value: '{}'. Valid values are: ephemeral, 5minute, 1hour", cache_config); + warn!( + "Invalid cache_config value: '{}'. Valid values are: ephemeral, 5minute, 1hour", + cache_config + ); None } } @@ -1089,7 +1099,8 @@ impl Agent { /// Count how many cache_control annotations exist in the conversation history fn count_cache_controls_in_history(&self) -> usize { - self.context_window.conversation_history + self.context_window + .conversation_history .iter() .filter(|msg| msg.cache_control.is_some()) .count() @@ -1157,7 +1168,10 @@ impl Agent { ) -> Result { // First, check if there's a global max_context_length override in agent config if let Some(max_context_length) = config.agent.max_context_length { - debug!("Using configured agent.max_context_length: {}", max_context_length); + debug!( + "Using configured agent.max_context_length: {}", + max_context_length + ); return Ok(max_context_length); } @@ -1251,11 +1265,7 @@ impl Agent { let ts = Local::now().format("%Y%m%d_%H%M%S").to_string(); let path = format!("logs/tool_calls_{}.log", ts); - match OpenOptions::new() - .create(true) - .append(true) - .open(&path) - { + match OpenOptions::new().create(true).append(true).open(&path) { Ok(file) => Some(Mutex::new(file)), Err(e) => { error!("Failed to open tool log file {}: {}", path, e); @@ -1402,7 +1412,7 @@ impl Agent { // Reset the JSON tool call filter state at the start of each new task // This prevents the filter 
from staying in suppression mode between user interactions fixed_filter_json::reset_fixed_json_tool_state(); - + // Validate that the system prompt is the first message (critical invariant) self.validate_system_prompt_is_first(); @@ -1417,12 +1427,21 @@ impl Agent { // But only if we haven't already added 4 cache_control annotations let provider = self.providers.get(None)?; if let Some(cache_config) = match provider.name() { - "anthropic" => self.config.providers.anthropic.as_ref() + "anthropic" => self + .config + .providers + .anthropic + .as_ref() .and_then(|c| c.cache_config.as_ref()) .and_then(|config| Self::parse_cache_control(config)), _ => None, } { - Message::with_cache_control_validated(MessageRole::User, format!("Task: {}", description), cache_config, provider) + Message::with_cache_control_validated( + MessageRole::User, + format!("Task: {}", description), + cache_config, + provider, + ) } else { Message::new(MessageRole::User, format!("Task: {}", description)) } @@ -1431,7 +1450,8 @@ impl Agent { // Execute fast-discovery tool calls if provided (immediately after user message) if let Some(ref options) = discovery_options { - self.ui_writer.println("▶️ Playing back discovery commands..."); + self.ui_writer + .println("▶️ Playing back discovery commands..."); // Store the working directory for subsequent tool calls in the streaming loop if let Some(path) = options.fast_start_path { self.working_dir = Some(path.to_string()); @@ -1439,16 +1459,21 @@ impl Agent { let provider = self.providers.get(None)?; let supports_cache = provider.supports_cache_control(); let message_count = options.messages.len(); - + for (idx, discovery_msg) in options.messages.iter().enumerate() { if let Ok(tool_call) = serde_json::from_str::(&discovery_msg.content) { self.add_message_to_context(discovery_msg.clone()); - let result = self.execute_tool_call_in_dir(&tool_call, options.fast_start_path).await + let result = self + .execute_tool_call_in_dir(&tool_call, 
options.fast_start_path) + .await .unwrap_or_else(|e| format!("Error: {}", e)); - + // Add cache_control to the last user message if provider supports it (anthropic) let is_last = idx == message_count - 1; - let result_message = if supports_cache && is_last && self.count_cache_controls_in_history() < 4 { + let result_message = if supports_cache + && is_last + && self.count_cache_controls_in_history() < 4 + { Message::with_cache_control( MessageRole::User, format!("Tool result: {}", result), @@ -1539,9 +1564,8 @@ impl Agent { // Check if we need to do 90% auto-compaction if self.pending_90_summarization { - self.ui_writer.print_context_status( - "\n⚡ Context window reached 90% - auto-compacting...\n" - ); + self.ui_writer + .print_context_status("\n⚡ Context window reached 90% - auto-compacting...\n"); if let Err(e) = self.force_summarize().await { warn!("Failed to auto-compact at 90%: {}", e); } else { @@ -1631,6 +1655,144 @@ impl Agent { } } + /// Format token count in compact form (e.g., 1K, 2M, 100b, 200K) and clamp to 4 chars right-aligned + fn format_token_count(tokens: u32) -> String { + let mut raw = if tokens >= 1_000_000_000 { + format!("{}b", tokens / 1_000_000_000) + } else if tokens >= 1_000_000 { + format!("{}M", tokens / 1_000_000) + } else if tokens >= 1_000 { + format!("{}K", tokens / 1_000) + } else { + format!("0K") + }; + + if raw.len() > 4 { + raw.truncate(4); + } + + format!("{:>4}", raw) + } + + /// Pick a single Unicode indicator for token magnitude (maps to previous color bands) + fn token_indicator(tokens: u32) -> &'static str { + if tokens <= 1_000 { + "🟢" + } else if tokens <= 5_000 { + "🟡" + } else if tokens <= 10_000 { + "🟠" + } else if tokens <= 20_000 { + "🔴" + } else { + "🟣" + } + } + + /// Write context window summary to file + /// Format: date&time, token_count, message_id, role, first_100_chars + fn write_context_window_summary(&self) { + // Skip if quiet mode is enabled + if self.quiet { + return; + } + + // Skip if no session ID 
+ let session_id = match &self.session_id { + Some(id) => id, + None => return, + }; + + // Create logs directory if it doesn't exist + let logs_dir = std::path::Path::new("logs"); + if !logs_dir.exists() { + if let Err(e) = std::fs::create_dir_all(logs_dir) { + error!("Failed to create logs directory: {}", e); + return; + } + } + + // Generate filename using same pattern as save_context_window + let filename = format!("logs/context_window_{}.txt", session_id); + let symlink_path = "logs/current_context_window"; + + // Build the summary content + let mut summary_lines = Vec::new(); + + for message in &self.context_window.conversation_history { + let timestamp = chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string(); + + // Estimate tokens for this message + let message_tokens = ContextWindow::estimate_tokens(&message.content); + + // Format token count + let token_str = Self::format_token_count(message_tokens); + + // Get token indicator + let indicator = Self::token_indicator(message_tokens); + + // Get role as string + let role = match message.role { + MessageRole::System => "sys", + MessageRole::User => "usr", + MessageRole::Assistant => "ass", + }; + + // Get first 100 characters of content + let content_preview: String = message.content.chars().take(100).collect(); + + // Replace newlines with spaces for single-line format + let content_preview = content_preview.replace('\n', " ").replace('\r', " "); + + // Format: date&time, indicator + token_count (fixed width), message_id, role, first_100_chars + let line = format!( + "{}, {} {}, {}, {}, {}\n", + timestamp, token_str, indicator, message.id, role, content_preview + ); + + summary_lines.push(line); + } + + // Write to file + let summary_content = summary_lines.join(""); + if let Err(e) = std::fs::write(&filename, summary_content) { + error!( + "Failed to write context window summary to {}: {}", + filename, e + ); + return; + } + + // Update symlink + // Remove old symlink if it exists + let _ = 
std::fs::remove_file(symlink_path); + + // Create new symlink + #[cfg(unix)] + { + use std::os::unix::fs::symlink; + let target = format!("context_window_{}.txt", session_id); + if let Err(e) = symlink(&target, symlink_path) { + error!("Failed to create symlink {}: {}", symlink_path, e); + } + } + + #[cfg(windows)] + { + use std::os::windows::fs::symlink_file; + let target = format!("context_window_{}.txt", session_id); + if let Err(e) = symlink_file(&target, symlink_path) { + error!("Failed to create symlink {}: {}", symlink_path, e); + } + } + + debug!( + "Context window summary written to {} ({} messages)", + filename, + self.context_window.conversation_history.len() + ); + } + pub fn get_context_window(&self) -> &ContextWindow { &self.context_window } @@ -1649,7 +1811,11 @@ impl Agent { } /// Execute a tool call with an optional working directory (for discovery commands) - pub async fn execute_tool_call_in_dir(&mut self, tool_call: &ToolCall, working_dir: Option<&str>) -> Result { + pub async fn execute_tool_call_in_dir( + &mut self, + tool_call: &ToolCall, + working_dir: Option<&str>, + ) -> Result { self.execute_tool_in_dir(tool_call, working_dir).await } @@ -1748,11 +1914,17 @@ impl Agent { .join("\n\n"); let summary_messages = vec![ - Message::new(MessageRole::System, "You are a helpful assistant that creates concise summaries.".to_string()), - Message::new(MessageRole::User, format!( + Message::new( + MessageRole::System, + "You are a helpful assistant that creates concise summaries.".to_string(), + ), + Message::new( + MessageRole::User, + format!( "Based on this conversation history, {}\n\nConversation:\n{}", summary_prompt, conversation_text - )), + ), + ), ]; let provider = self.providers.get(None)?; @@ -1845,14 +2017,14 @@ impl Agent { let has_readme = self .context_window .conversation_history - .get(1) // Check the SECOND message (index 1) + .get(1) // Check the SECOND message (index 1) .map(|m| { matches!(m.role, MessageRole::System) && 
(m.content.contains("Project README") || m.content.contains("Agent Configuration")) }) .unwrap_or(false); - + // Validate that the system prompt is still first self.validate_system_prompt_is_first(); @@ -2717,158 +2889,175 @@ impl Agent { "\n🥒 Context window at {}%. Trying thinning first...", self.context_window.percentage_used() as u32 )); - + let (thin_summary, chars_saved) = self.context_window.thin_context(); self.thinning_events.push(chars_saved); self.ui_writer.print_context_thinning(&thin_summary); - + // Check if thinning was sufficient if !self.context_window.should_summarize() { - self.ui_writer.print_context_status("✅ Thinning resolved capacity issue. Continuing...\n"); + self.ui_writer.print_context_status( + "✅ Thinning resolved capacity issue. Continuing...\n", + ); // Continue with the original request without summarization } else { - self.ui_writer.print_context_status("⚠️ Thinning insufficient. Proceeding with summarization...\n"); + self.ui_writer.print_context_status( + "⚠️ Thinning insufficient. Proceeding with summarization...\n", + ); } } - + // Only proceed with summarization if still needed after thinning if self.context_window.should_summarize() { - // Notify user about summarization - self.ui_writer.print_context_status(&format!( - "\n🗜️ Context window reaching capacity ({}%). Creating summary...", - self.context_window.percentage_used() as u32 - )); + // Notify user about summarization + self.ui_writer.print_context_status(&format!( + "\n🗜️ Context window reaching capacity ({}%). 
Creating summary...", + self.context_window.percentage_used() as u32 + )); - // Create summary request with FULL history - let summary_prompt = self.context_window.create_summary_prompt(); + // Create summary request with FULL history + let summary_prompt = self.context_window.create_summary_prompt(); - // Get the full conversation history - let conversation_text = self - .context_window - .conversation_history - .iter() - .map(|m| format!("{:?}: {}", m.role, m.content)) - .collect::>() - .join("\n\n"); + // Get the full conversation history + let conversation_text = self + .context_window + .conversation_history + .iter() + .map(|m| format!("{:?}: {}", m.role, m.content)) + .collect::>() + .join("\n\n"); - let summary_messages = vec![ - Message::new(MessageRole::System, "You are a helpful assistant that creates concise summaries.".to_string()), - Message::new(MessageRole::User, format!( - "Based on this conversation history, {}\n\nConversation:\n{}", - summary_prompt, conversation_text - )), - ]; + let summary_messages = vec![ + Message::new( + MessageRole::System, + "You are a helpful assistant that creates concise summaries.".to_string(), + ), + Message::new( + MessageRole::User, + format!( + "Based on this conversation history, {}\n\nConversation:\n{}", + summary_prompt, conversation_text + ), + ), + ]; - let provider = self.providers.get(None)?; + let provider = self.providers.get(None)?; - // Dynamically calculate max_tokens for summary based on what's left - // We need to ensure: used_tokens + max_tokens <= total_context_limit - let summary_max_tokens = match provider.name() { - "databricks" | "anthropic" => { - // Use the actual configured context window size - let model_limit = self.context_window.total_tokens; - let current_usage = self.context_window.used_tokens; - - // Check if we have enough capacity for summarization - if current_usage >= model_limit.saturating_sub(1000) { - error!("Context window at capacity ({}%), cannot summarize. 
Current: {}, Limit: {}", + // Dynamically calculate max_tokens for summary based on what's left + // We need to ensure: used_tokens + max_tokens <= total_context_limit + let summary_max_tokens = match provider.name() { + "databricks" | "anthropic" => { + // Use the actual configured context window size + let model_limit = self.context_window.total_tokens; + let current_usage = self.context_window.used_tokens; + + // Check if we have enough capacity for summarization + if current_usage >= model_limit.saturating_sub(1000) { + error!("Context window at capacity ({}%), cannot summarize. Current: {}, Limit: {}", self.context_window.percentage_used(), current_usage, model_limit); - return Err(anyhow::anyhow!("Context window at capacity. Try using /thinnify or /compact commands to reduce context size, or start a new session.")); + return Err(anyhow::anyhow!("Context window at capacity. Try using /thinnify or /compact commands to reduce context size, or start a new session.")); + } + + // Leave buffer proportional to model size (min 1k, max 10k) + let buffer = (model_limit / 40).clamp(1000, 10000); // 2.5% buffer + let available = model_limit + .saturating_sub(current_usage) + .saturating_sub(buffer); + // Cap at a reasonable summary size (10k tokens max) + Some(available.min(10_000)) } - - // Leave buffer proportional to model size (min 1k, max 10k) - let buffer = (model_limit / 40).clamp(1000, 10000); // 2.5% buffer - let available = model_limit - .saturating_sub(current_usage) - .saturating_sub(buffer); - // Cap at a reasonable summary size (10k tokens max) - Some(available.min(10_000)) - } - "embedded" => { - // For smaller context models, be more conservative - let model_limit = self.context_window.total_tokens; - let current_usage = self.context_window.used_tokens; - - // Check capacity for embedded models too - if current_usage >= model_limit.saturating_sub(500) { - error!("Embedded model context window at capacity ({}%)", self.context_window.percentage_used()); - 
return Err(anyhow::anyhow!("Context window at capacity. Try using /thinnify command to reduce context size, or start a new session.")); + "embedded" => { + // For smaller context models, be more conservative + let model_limit = self.context_window.total_tokens; + let current_usage = self.context_window.used_tokens; + + // Check capacity for embedded models too + if current_usage >= model_limit.saturating_sub(500) { + error!( + "Embedded model context window at capacity ({}%)", + self.context_window.percentage_used() + ); + return Err(anyhow::anyhow!("Context window at capacity. Try using /thinnify command to reduce context size, or start a new session.")); + } + + // Leave 1k buffer + let available = model_limit + .saturating_sub(current_usage) + .saturating_sub(1000); + // Cap at 3k for embedded models + Some(available.min(3000)) } - - // Leave 1k buffer - let available = model_limit - .saturating_sub(current_usage) - .saturating_sub(1000); - // Cap at 3k for embedded models - Some(available.min(3000)) - } - _ => { - // Default: conservative approach - let model_limit = self.context_window.total_tokens; - let current_usage = self.context_window.used_tokens; - - if current_usage >= model_limit.saturating_sub(1000) { - error!("Context window at capacity ({}%)", self.context_window.percentage_used()); - return Err(anyhow::anyhow!("Context window at capacity. Try using /thinnify or /compact commands, or start a new session.")); + _ => { + // Default: conservative approach + let model_limit = self.context_window.total_tokens; + let current_usage = self.context_window.used_tokens; + + if current_usage >= model_limit.saturating_sub(1000) { + error!( + "Context window at capacity ({}%)", + self.context_window.percentage_used() + ); + return Err(anyhow::anyhow!("Context window at capacity. 
Try using /thinnify or /compact commands, or start a new session.")); + } + + let available = self.context_window.remaining_tokens().saturating_sub(2000); + Some(available.min(5000)) } - - let available = self.context_window.remaining_tokens().saturating_sub(2000); - Some(available.min(5000)) + }; + + debug!( + "Requesting summary with max_tokens: {:?} (current usage: {} tokens)", + summary_max_tokens, self.context_window.used_tokens + ); + + // Final safety check + if summary_max_tokens.unwrap_or(0) == 0 { + error!("No tokens available for summarization"); + return Err(anyhow::anyhow!("No context window capacity left for summarization. Use /thinnify to reduce context size or start a new session.")); } - }; - debug!( - "Requesting summary with max_tokens: {:?} (current usage: {} tokens)", - summary_max_tokens, self.context_window.used_tokens - ); - - // Final safety check - if summary_max_tokens.unwrap_or(0) == 0 { - error!("No tokens available for summarization"); - return Err(anyhow::anyhow!("No context window capacity left for summarization. Use /thinnify to reduce context size or start a new session.")); - } + let summary_request = CompletionRequest { + messages: summary_messages, + max_tokens: summary_max_tokens, + temperature: Some(0.3), // Lower temperature for factual summary + stream: false, + tools: None, + }; - let summary_request = CompletionRequest { - messages: summary_messages, - max_tokens: summary_max_tokens, - temperature: Some(0.3), // Lower temperature for factual summary - stream: false, - tools: None, - }; + // Get the summary + match provider.complete(summary_request).await { + Ok(summary_response) => { + self.ui_writer.print_context_status( + "✅ Context compacted successfully. Continuing...\n", + ); - // Get the summary - match provider.complete(summary_request).await { - Ok(summary_response) => { - self.ui_writer - .print_context_status("✅ Context compacted successfully. 
Continuing...\n"); + // Extract the latest user message from the request + let latest_user_msg = request + .messages + .iter() + .rev() + .find(|m| matches!(m.role, MessageRole::User)) + .map(|m| m.content.clone()); - // Extract the latest user message from the request - let latest_user_msg = request - .messages - .iter() - .rev() - .find(|m| matches!(m.role, MessageRole::User)) - .map(|m| m.content.clone()); + // Reset context with summary + let chars_saved = self + .context_window + .reset_with_summary(summary_response.content, latest_user_msg); + self.summarization_events.push(chars_saved); - // Reset context with summary - let chars_saved = self - .context_window - .reset_with_summary(summary_response.content, latest_user_msg); - self.summarization_events.push(chars_saved); - - // Update the request with new context - request.messages = self.context_window.conversation_history.clone(); - } - Err(e) => { - error!("Failed to create summary: {}", e); - self.ui_writer.print_context_status("⚠️ Unable to create summary. Consider starting a new session if you continue to see errors.\n"); - // Don't continue with the original request if summarization failed - // as we're likely at token limit - return Err(anyhow::anyhow!("Context window at capacity and summarization failed. Please start a new session.")); + // Update the request with new context + request.messages = self.context_window.conversation_history.clone(); + } + Err(e) => { + error!("Failed to create summary: {}", e); + self.ui_writer.print_context_status("⚠️ Unable to create summary. Consider starting a new session if you continue to see errors.\n"); + // Don't continue with the original request if summarization failed + // as we're likely at token limit + return Err(anyhow::anyhow!("Context window at capacity and summarization failed. 
Please start a new session.")); + } } } } - } loop { iteration_count += 1; @@ -2949,6 +3138,9 @@ impl Agent { } }; + // Write context window summary every time we send messages to LLM + self.write_context_window_summary(); + let mut parser = StreamingToolParser::new(); let mut current_response = String::new(); let mut tool_executed = false; @@ -3007,12 +3199,13 @@ impl Agent { let completed_tools = parser.process_chunk(&chunk); // Handle completed tool calls - process all if multiple calls enabled - let tools_to_process: Vec = if self.config.agent.allow_multiple_tool_calls { - completed_tools - } else { - // Original behavior - only take the first tool - completed_tools.into_iter().take(1).collect() - }; + let tools_to_process: Vec = + if self.config.agent.allow_multiple_tool_calls { + completed_tools + } else { + // Original behavior - only take the first tool + completed_tools.into_iter().take(1).collect() + }; // Helper function to check if two tool calls are duplicates let are_duplicates = |tc1: &ToolCall, tc2: &ToolCall| -> bool { @@ -3022,12 +3215,15 @@ impl Agent { // De-duplicate tool calls and track duplicates let mut seen_in_chunk: Vec = Vec::new(); let mut deduplicated_tools: Vec<(ToolCall, Option)> = Vec::new(); - + for tool_call in tools_to_process { let mut duplicate_type = None; - + // Check for duplicates in current chunk - if seen_in_chunk.iter().any(|tc| are_duplicates(tc, &tool_call)) { + if seen_in_chunk + .iter() + .any(|tc| are_duplicates(tc, &tool_call)) + { duplicate_type = Some("DUP IN CHUNK".to_string()); } else { // Check for duplicate against previous message in history @@ -3040,15 +3236,18 @@ impl Agent { // Simple JSON extraction for tool calls let content = &msg.content; let mut start_idx = 0; - while let Some(tool_start) = content[start_idx..].find(r#"{\"tool\""#) { + while let Some(tool_start) = + content[start_idx..].find(r#"{\"tool\""#) + { let tool_start = start_idx + tool_start; // Find the end of this JSON object let mut 
brace_count = 0; let mut in_string = false; let mut escape_next = false; let mut end_idx = tool_start; - - for (i, ch) in content[tool_start..].char_indices() { + + for (i, ch) in content[tool_start..].char_indices() + { if escape_next { escape_next = false; continue; @@ -3072,10 +3271,12 @@ impl Agent { } } } - + if end_idx > tool_start { let tool_json = &content[tool_start..end_idx]; - if let Ok(prev_tool) = serde_json::from_str::(tool_json) { + if let Ok(prev_tool) = + serde_json::from_str::(tool_json) + { if are_duplicates(&prev_tool, &tool_call) { found_in_prev = true; break; @@ -3089,42 +3290,46 @@ impl Agent { break; } } - + if found_in_prev { duplicate_type = Some("DUP IN MSG".to_string()); } } - + // Add to seen list if not a duplicate in chunk - if duplicate_type.as_ref().map_or(true, |s| s != "DUP IN CHUNK") { + if duplicate_type + .as_ref() + .map_or(true, |s| s != "DUP IN CHUNK") + { seen_in_chunk.push(tool_call.clone()); } - + deduplicated_tools.push((tool_call, duplicate_type)); } // Process each tool call for (tool_call, duplicate_type) in deduplicated_tools { debug!("Processing completed tool call: {:?}", tool_call); - + // If it's a duplicate, log it and return a warning if let Some(dup_type) = &duplicate_type { // Log the duplicate with red prefix - let prefixed_tool_name = format!("🟥 {} {}", tool_call.tool, dup_type); + let prefixed_tool_name = + format!("🟥 {} {}", tool_call.tool, dup_type); let warning_msg = format!( "⚠️ Duplicate tool call detected ({}): Skipping execution of {} with args {}", dup_type, tool_call.tool, serde_json::to_string(&tool_call.args).unwrap_or_else(|_| "".to_string()) ); - + // Log to tool log with red prefix let mut modified_tool_call = tool_call.clone(); modified_tool_call.tool = prefixed_tool_name; self.log_tool_call(&modified_tool_call, &warning_msg); continue; // Skip execution of duplicate } - + // Check if we should auto-compact at 90% BEFORE executing the tool // We need to do this before any borrows of self if 
self.auto_compact && self.context_window.percentage_used() >= 90.0 { @@ -3132,7 +3337,7 @@ impl Agent { // We can't do it now due to borrow checker constraints self.pending_90_summarization = true; } - + // Check if we should thin the context BEFORE executing the tool if self.context_window.should_thin() { let (thin_summary, chars_saved) = @@ -3142,7 +3347,6 @@ impl Agent { self.ui_writer.print_context_thinning(&thin_summary); } - // Track what we've already displayed before getting new text // This prevents re-displaying old content after tool execution let already_displayed_chars = current_response.chars().count(); @@ -3312,8 +3516,12 @@ impl Agent { break; } // Clip line to max width (but not for todo tools) - let clipped_line = truncate_line(line, MAX_LINE_WIDTH, !wants_full && !is_todo_tool); - + let clipped_line = truncate_line( + line, + MAX_LINE_WIDTH, + !wants_full && !is_todo_tool, + ); + // Use print_tool_output_line for todo tools to get special formatting if is_todo_tool { self.ui_writer.print_tool_output_line(&clipped_line); @@ -3364,36 +3572,60 @@ impl Agent { // Add the tool call and result to the context window using RAW unfiltered content // This ensures the log file contains the true raw content including JSON tool calls let tool_message = if !raw_content_for_log.trim().is_empty() { - Message::new(MessageRole::Assistant, format!( + Message::new( + MessageRole::Assistant, + format!( "{}\n\n{{\"tool\": \"{}\", \"args\": {}}}", raw_content_for_log.trim(), tool_call.tool, tool_call.args - )) + ), + ) } else { // No text content before tool call, just include the tool call - Message::new(MessageRole::Assistant, format!( + Message::new( + MessageRole::Assistant, + format!( "{{\"tool\": \"{}\", \"args\": {}}}", tool_call.tool, tool_call.args - )) + ), + ) }; let result_message = { // Check if we should use cache control (every 10 tool calls) // But only if we haven't already added 4 cache_control annotations - if self.tool_call_count > 0 && 
self.tool_call_count % 10 == 0 && self.count_cache_controls_in_history() < 4 { + if self.tool_call_count > 0 + && self.tool_call_count % 10 == 0 + && self.count_cache_controls_in_history() < 4 + { let provider = self.providers.get(None)?; if let Some(cache_config) = match provider.name() { - "anthropic" => self.config.providers.anthropic.as_ref() + "anthropic" => self + .config + .providers + .anthropic + .as_ref() .and_then(|c| c.cache_config.as_ref()) .and_then(|config| Self::parse_cache_control(config)), _ => None, } { - Message::with_cache_control_validated(MessageRole::User, format!("Tool result: {}", tool_result), cache_config, provider) + Message::with_cache_control_validated( + MessageRole::User, + format!("Tool result: {}", tool_result), + cache_config, + provider, + ) } else { - Message::new(MessageRole::User, format!("Tool result: {}", tool_result)) + Message::new( + MessageRole::User, + format!("Tool result: {}", tool_result), + ) } } else { - Message::new(MessageRole::User, format!("Tool result: {}", tool_result)) + Message::new( + MessageRole::User, + format!("Tool result: {}", tool_result), + ) } }; @@ -3434,7 +3666,7 @@ impl Agent { current_response.clear(); // Reset response_started flag for next iteration response_started = false; - + // For single tool mode, break immediately if !self.config.agent.allow_multiple_tool_calls { break; // Break out of current stream to start a new one @@ -3526,8 +3758,7 @@ impl Agent { error!("Iteration: {}/{}", iteration_count, MAX_ITERATIONS); error!( "Provider: {} (model: {})", - provider_name, - provider_model + provider_name, provider_model ); error!("Chunks received: {}", chunks_received); error!("Parser state:"); @@ -3648,25 +3879,35 @@ impl Agent { Err(e) => { // Capture detailed streaming error information let error_msg = e.to_string(); - let error_details = format!("Streaming error at chunk {}: {}", chunks_received + 1, error_msg); - + let error_details = format!( + "Streaming error at chunk {}: {}", + 
chunks_received + 1, + error_msg + ); + error!("Error type: {}", std::any::type_name_of_val(&e)); error!("Parser state at error: text_buffer_len={}, native_tool_calls={}, message_stopped={}", parser.text_buffer_len(), parser.native_tool_calls.len(), parser.is_message_stopped()); // Store the error for potential logging later _last_error = Some(error_details.clone()); - + // Check if this is a recoverable connection error - let is_connection_error = error_msg.contains("unexpected EOF") - || error_msg.contains("connection") + let is_connection_error = error_msg.contains("unexpected EOF") + || error_msg.contains("connection") || error_msg.contains("chunk size line") || error_msg.contains("body error"); - + if is_connection_error { - warn!("Connection error at chunk {}, treating as end of stream", chunks_received + 1); + warn!( + "Connection error at chunk {}, treating as end of stream", + chunks_received + 1 + ); // If we have any content or tool calls, treat this as a graceful end - if chunks_received > 0 && (!parser.get_text_content().is_empty() || parser.native_tool_calls.len() > 0) { + if chunks_received > 0 + && (!parser.get_text_content().is_empty() + || parser.native_tool_calls.len() > 0) + { warn!("Stream terminated unexpectedly but we have content, continuing"); break; // Break to process what we have } @@ -3793,7 +4034,11 @@ impl Agent { } /// Execute a tool with an optional working directory (for discovery commands) - pub async fn execute_tool_in_dir(&mut self, tool_call: &ToolCall, working_dir: Option<&str>) -> Result { + pub async fn execute_tool_in_dir( + &mut self, + tool_call: &ToolCall, + working_dir: Option<&str>, + ) -> Result { // Only increment tool call count if not already incremented by execute_tool if working_dir.is_some() { self.tool_call_count += 1; @@ -3808,10 +4053,17 @@ impl Agent { result } - async fn execute_tool_inner_in_dir(&mut self, tool_call: &ToolCall, working_dir: Option<&str>) -> Result { + async fn execute_tool_inner_in_dir( + 
&mut self, + tool_call: &ToolCall, + working_dir: Option<&str>, + ) -> Result { debug!("=== EXECUTING TOOL ==="); debug!("Tool name: {}", tool_call.tool); - debug!("Working directory passed to execute_tool_inner_in_dir: {:?}", working_dir); + debug!( + "Working directory passed to execute_tool_inner_in_dir: {:?}", + working_dir + ); debug!("Tool args (raw): {:?}", tool_call.args); debug!( "Tool args (JSON): {}", @@ -3846,7 +4098,7 @@ impl Agent { let receiver = ToolOutputReceiver { ui_writer: &self.ui_writer, }; - + debug!("ABOUT TO CALL execute_bash_streaming_in_dir: escaped_command='{}', working_dir={:?}", escaped_command, working_dir); match executor @@ -4346,7 +4598,7 @@ impl Agent { debug!("Processing todo_read tool call"); // Read from todo.g3.md file in current workspace directory let todo_path = std::env::current_dir()?.join("todo.g3.md"); - + if !todo_path.exists() { // Also update in-memory content to stay in sync let mut todo = self.todo_content.write().await; @@ -4358,34 +4610,44 @@ impl Agent { // Update in-memory content to stay in sync let mut todo = self.todo_content.write().await; *todo = content.clone(); - + // Check for staleness if enabled and we have a requirements SHA if self.config.agent.check_todo_staleness { if let Some(req_sha) = &self.requirements_sha { // Parse the first line for the SHA header if let Some(first_line) = content.lines().next() { - if first_line.starts_with("{{Based on the requirements file with SHA256:") { - let parts: Vec<&str> = first_line.split("SHA256:").collect(); + if first_line.starts_with( + "{{Based on the requirements file with SHA256:", + ) { + let parts: Vec<&str> = + first_line.split("SHA256:").collect(); if parts.len() > 1 { - let todo_sha = parts[1].trim().trim_end_matches("}}").trim(); + let todo_sha = + parts[1].trim().trim_end_matches("}}").trim(); if todo_sha != req_sha { let warning = format!( "⚠️ TODO list is stale! 
It was generated from a different requirements file.\nExpected SHA: {}\nFound SHA: {}", req_sha, todo_sha ); self.ui_writer.print_context_status(&warning); - + // Beep 6 times print!("\x07\x07\x07\x07\x07\x07"); let _ = std::io::stdout().flush(); - - let options = ["Ignore and Continue", "Mark as Stale", "Quit Application"]; + + let options = [ + "Ignore and Continue", + "Mark as Stale", + "Quit Application", + ]; let choice = self.ui_writer.prompt_user_choice("Requirements have changed! What would you like to do?", &options); - + match choice { 0 => { // Ignore and Continue - self.ui_writer.print_context_status("⚠️ Ignoring staleness warning."); + self.ui_writer.print_context_status( + "⚠️ Ignoring staleness warning.", + ); } 1 => { // Mark as Stale @@ -4438,13 +4700,16 @@ impl Agent { // Write to todo.g3.md file in current workspace directory let todo_path = std::env::current_dir()?.join("todo.g3.md"); - + match std::fs::write(&todo_path, content_str) { Ok(_) => { // Also update in-memory content to stay in sync let mut todo = self.todo_content.write().await; *todo = content_str.to_string(); - Ok(format!("✅ TODO list updated ({} chars) and saved to todo.g3.md", char_count)) + Ok(format!( + "✅ TODO list updated ({} chars) and saved to todo.g3.md", + char_count + )) } Err(e) => Ok(format!("❌ Failed to write todo.g3.md: {}", e)), } @@ -4457,32 +4722,35 @@ impl Agent { } "code_coverage" => { debug!("Processing code_coverage tool call"); - self.ui_writer.print_context_status("🔍 Generating code coverage report..."); - + self.ui_writer + .print_context_status("🔍 Generating code coverage report..."); + // Ensure coverage tools are installed match g3_execution::ensure_coverage_tools_installed() { Ok(already_installed) => { if !already_installed { - self.ui_writer.print_context_status("✅ Coverage tools installed successfully"); + self.ui_writer + .print_context_status("✅ Coverage tools installed successfully"); } } Err(e) => { return Ok(format!("❌ Failed to install 
coverage tools: {}", e)); } } - + // Run cargo llvm-cov --workspace let output = std::process::Command::new("cargo") .args(&["llvm-cov", "--workspace"]) .current_dir(std::env::current_dir()?) .output()?; - + if output.status.success() { let stdout = String::from_utf8_lossy(&output.stdout); let stderr = String::from_utf8_lossy(&output.stderr); - + // Combine output - let mut result = String::from("✅ Code coverage report generated successfully\n\n"); + let mut result = + String::from("✅ Code coverage report generated successfully\n\n"); result.push_str("## Coverage Summary\n"); result.push_str(&stdout); if !stderr.is_empty() { @@ -4492,7 +4760,10 @@ impl Agent { Ok(result) } else { let stderr = String::from_utf8_lossy(&output.stderr); - Ok(format!("❌ Failed to generate coverage report:\n{}", stderr)) + Ok(format!( + "❌ Failed to generate coverage report:\n{}", + stderr + )) } } "webdriver_start" => { @@ -5426,9 +5697,7 @@ impl Agent { Err(e) => Ok(format!("❌ Failed to serialize response: {}", e)), } } - Err(e) => { - Ok(format!("❌ Code search failed: {}", e)) - } + Err(e) => Ok(format!("❌ Code search failed: {}", e)), } } _ => { @@ -5891,10 +6160,13 @@ impl Drop for Agent { if let Err(e) = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { self.validate_system_prompt_is_first(); })) { - eprintln!("\n⚠️ FATAL ERROR ON EXIT: System prompt validation failed: {:?}", e); + eprintln!( + "\n⚠️ FATAL ERROR ON EXIT: System prompt validation failed: {:?}", + e + ); } } - + // Try to kill safaridriver process if it's still running // We need to use try_lock since we can't await in Drop if let Ok(mut process_guard) = self.safaridriver_process.try_write() { diff --git a/crates/g3-core/src/project.rs b/crates/g3-core/src/project.rs index 0263b7c..17e933c 100644 --- a/crates/g3-core/src/project.rs +++ b/crates/g3-core/src/project.rs @@ -7,19 +7,19 @@ use std::path::{Path, PathBuf}; pub struct Project { /// The workspace directory for the project pub workspace_dir: PathBuf, 
- + /// Path to the requirements document (for autonomous mode) pub requirements_path: Option, - + /// Override requirements text (takes precedence over requirements_path) pub requirements_text: Option, - + /// Whether the project is in autonomous mode pub autonomous: bool, - + /// Project name (derived from workspace directory name) pub name: String, - + /// Session ID for tracking pub session_id: Option, } @@ -32,7 +32,7 @@ impl Project { .and_then(|n| n.to_str()) .unwrap_or("unnamed") .to_string(); - + Self { workspace_dir, requirements_path: None, @@ -42,33 +42,36 @@ impl Project { session_id: None, } } - + /// Create a project for autonomous mode pub fn new_autonomous(workspace_dir: PathBuf) -> Result { let mut project = Self::new(workspace_dir.clone()); project.autonomous = true; - + // Look for requirements.md in the workspace directory let requirements_path = workspace_dir.join("requirements.md"); if requirements_path.exists() { project.requirements_path = Some(requirements_path); } - + Ok(project) } - + /// Create a project for autonomous mode with requirements text override - pub fn new_autonomous_with_requirements(workspace_dir: PathBuf, requirements_text: String) -> Result { + pub fn new_autonomous_with_requirements( + workspace_dir: PathBuf, + requirements_text: String, + ) -> Result { let mut project = Self::new(workspace_dir.clone()); project.autonomous = true; project.requirements_text = Some(requirements_text); - + // Don't look for requirements.md file when text is provided // The text override takes precedence - + Ok(project) } - + /// Set the workspace directory and update related paths pub fn set_workspace(&mut self, workspace_dir: PathBuf) { self.workspace_dir = workspace_dir.clone(); @@ -77,7 +80,7 @@ impl Project { .and_then(|n| n.to_str()) .unwrap_or("unnamed") .to_string(); - + // Update requirements path if in autonomous mode if self.autonomous { let requirements_path = workspace_dir.join("requirements.md"); @@ -86,18 +89,18 @@ impl 
Project { } } } - + /// Get the workspace directory pub fn workspace(&self) -> &Path { &self.workspace_dir } - + /// Check if requirements file exists pub fn has_requirements(&self) -> bool { // Has requirements if either text override is provided or requirements file exists self.requirements_text.is_some() || self.requirements_path.is_some() } - + /// Read the requirements file content pub fn read_requirements(&self) -> Result> { // Prioritize requirements text override @@ -110,7 +113,7 @@ impl Project { Ok(None) } } - + /// Create the workspace directory if it doesn't exist pub fn ensure_workspace_exists(&self) -> Result<()> { if !self.workspace_dir.exists() { @@ -118,18 +121,18 @@ impl Project { } Ok(()) } - + /// Change to the workspace directory pub fn enter_workspace(&self) -> Result<()> { std::env::set_current_dir(&self.workspace_dir)?; Ok(()) } - + /// Get the logs directory for the project pub fn logs_dir(&self) -> PathBuf { self.workspace_dir.join("logs") } - + /// Ensure the logs directory exists pub fn ensure_logs_dir(&self) -> Result<()> { let logs_dir = self.logs_dir(); diff --git a/crates/g3-core/src/prompts.rs b/crates/g3-core/src/prompts.rs index 206f41d..79c323b 100644 --- a/crates/g3-core/src/prompts.rs +++ b/crates/g3-core/src/prompts.rs @@ -189,7 +189,7 @@ Do not explain what you're going to do - just do it by calling the tools. 
"; pub const SYSTEM_PROMPT_FOR_NATIVE_TOOL_USE: &'static str = -concatcp!(SYSTEM_NATIVE_TOOL_CALLS, CODING_STYLE); + concatcp!(SYSTEM_NATIVE_TOOL_CALLS, CODING_STYLE); /// Generate system prompt based on whether multiple tool calls are allowed pub fn get_system_prompt_for_native(allow_multiple: bool) -> String { diff --git a/crates/g3-core/src/task_result.rs b/crates/g3-core/src/task_result.rs index b896c9f..bf8fe48 100644 --- a/crates/g3-core/src/task_result.rs +++ b/crates/g3-core/src/task_result.rs @@ -30,7 +30,7 @@ impl TaskResult { // Look for the final_output marker pattern // The final_output content typically appears after the tool is called // and is the substantive content that follows - + // First, try to find if there's a clear final_output section // This would be the content after the last tool execution if let Some(final_output_pos) = content_without_timing.rfind("final_output") { @@ -39,7 +39,7 @@ impl TaskResult { if let Some(content_start) = content_without_timing[final_output_pos..].find('\n') { let start_pos = final_output_pos + content_start + 1; let final_content = &content_without_timing[start_pos..]; - + // Trim and return the complete content let trimmed = final_content.trim(); if !trimmed.is_empty() { @@ -47,7 +47,7 @@ impl TaskResult { } } } - + // Fallback to the original extract_last_block behavior if we can't find final_output // This maintains backward compatibility self.extract_last_block() @@ -62,12 +62,13 @@ impl TaskResult { } else { &self.response }; - + // Split by double newlines to find the last substantial block let blocks: Vec<&str> = content_without_timing.split("\n\n").collect(); - + // Find the last non-empty block that isn't just whitespace - blocks.iter() + blocks + .iter() .rev() .find(|block| !block.trim().is_empty()) .map(|block| block.trim().to_string()) @@ -79,7 +80,8 @@ impl TaskResult { /// Check if the response contains an approval (for autonomous mode) pub fn is_approved(&self) -> bool { - 
self.extract_final_output().contains("IMPLEMENTATION_APPROVED") + self.extract_final_output() + .contains("IMPLEMENTATION_APPROVED") } } @@ -91,20 +93,21 @@ mod tests { fn test_extract_last_block() { // Test case 1: Response with timing info let context_window = ContextWindow::new(1000); - let response_with_timing = "Some initial content\n\nFinal block content\n\n⏱️ 2.3s | 💭 1.2s".to_string(); + let response_with_timing = + "Some initial content\n\nFinal block content\n\n⏱️ 2.3s | 💭 1.2s".to_string(); let result = TaskResult::new(response_with_timing, context_window.clone()); assert_eq!(result.extract_last_block(), "Final block content"); - + // Test case 2: Response without timing let response_no_timing = "Some initial content\n\nFinal block content".to_string(); let result = TaskResult::new(response_no_timing, context_window.clone()); assert_eq!(result.extract_last_block(), "Final block content"); - + // Test case 3: Response with IMPLEMENTATION_APPROVED let response_approved = "Some content\n\nIMPLEMENTATION_APPROVED".to_string(); let result = TaskResult::new(response_approved, context_window.clone()); assert!(result.is_approved()); - + // Test case 4: Response without approval let response_not_approved = "Some content\n\nNeeds more work".to_string(); let result = TaskResult::new(response_not_approved, context_window); @@ -114,17 +117,17 @@ mod tests { #[test] fn test_extract_last_block_edge_cases() { let context_window = ContextWindow::new(1000); - + // Test empty response let empty_response = "".to_string(); let result = TaskResult::new(empty_response, context_window.clone()); assert_eq!(result.extract_last_block(), ""); - + // Test single block let single_block = "Just one block".to_string(); let result = TaskResult::new(single_block, context_window.clone()); assert_eq!(result.extract_last_block(), "Just one block"); - + // Test multiple empty blocks let multiple_empty = "\n\n\n\nSome content\n\n\n\n".to_string(); let result = TaskResult::new(multiple_empty, 
context_window); @@ -134,18 +137,22 @@ mod tests { #[test] fn test_extract_final_output() { let context_window = ContextWindow::new(1000); - + // Test case 1: Response with final_output tool call let response_with_final_output = "Analyzing files...\n\nCalling final_output\n\nThis is the complete feedback\nwith multiple lines\nand important details\n\n⏱️ 2.3s".to_string(); let result = TaskResult::new(response_with_final_output, context_window.clone()); - assert_eq!(result.extract_final_output(), "This is the complete feedback\nwith multiple lines\nand important details"); - + assert_eq!( + result.extract_final_output(), + "This is the complete feedback\nwith multiple lines\nand important details" + ); + // Test case 2: Response with IMPLEMENTATION_APPROVED in final_output - let response_approved = "Review complete\n\nfinal_output called\n\nIMPLEMENTATION_APPROVED".to_string(); + let response_approved = + "Review complete\n\nfinal_output called\n\nIMPLEMENTATION_APPROVED".to_string(); let result = TaskResult::new(response_approved, context_window.clone()); assert_eq!(result.extract_final_output(), "IMPLEMENTATION_APPROVED"); assert!(result.is_approved()); - + // Test case 3: Response with detailed feedback in final_output let response_feedback = "Checking implementation...\n\nfinal_output\n\nThe following issues need to be addressed:\n1. Missing error handling in main.rs\n2. Tests are not comprehensive\n3. Documentation needs improvement\n\nPlease fix these issues.".to_string(); let result = TaskResult::new(response_feedback, context_window.clone()); @@ -154,12 +161,12 @@ mod tests { assert!(extracted.contains("1. 
Missing error handling")); assert!(extracted.contains("Please fix these issues.")); assert!(!result.is_approved()); - + // Test case 4: Response without final_output (fallback to extract_last_block) let response_no_final_output = "Some analysis\n\nFinal thoughts here".to_string(); let result = TaskResult::new(response_no_final_output, context_window.clone()); assert_eq!(result.extract_final_output(), "Final thoughts here"); - + // Test case 5: Empty response let empty_response = "".to_string(); let result = TaskResult::new(empty_response, context_window); diff --git a/crates/g3-core/src/task_result_comprehensive_tests.rs b/crates/g3-core/src/task_result_comprehensive_tests.rs index 0f15f49..8d59459 100644 --- a/crates/g3-core/src/task_result_comprehensive_tests.rs +++ b/crates/g3-core/src/task_result_comprehensive_tests.rs @@ -6,15 +6,19 @@ use std::sync::Arc; fn test_task_result_basic_functionality() { // Create a context window with some messages let mut context = ContextWindow::new(10000); - context.add_message(Message::new(MessageRole::User, "Test message 1".to_string()) - ); - context.add_message(Message::new(MessageRole::Assistant, "Response 1".to_string()) - ); - + context.add_message(Message::new( + MessageRole::User, + "Test message 1".to_string(), + )); + context.add_message(Message::new( + MessageRole::Assistant, + "Response 1".to_string(), + )); + // Create a TaskResult let response = "This is the response\n\nFinal output block".to_string(); let result = TaskResult::new(response.clone(), context.clone()); - + // Test basic properties assert_eq!(result.response, response); assert_eq!(result.context_window.conversation_history.len(), 2); @@ -24,32 +28,32 @@ fn test_task_result_basic_functionality() { #[test] fn test_extract_last_block_various_formats() { let context = ContextWindow::new(1000); - + // Test 1: Standard format with multiple blocks let response1 = "First block\n\nSecond block\n\nThird block".to_string(); let result1 = 
TaskResult::new(response1, context.clone()); assert_eq!(result1.extract_last_block(), "Third block"); - + // Test 2: With timing information let response2 = "Content\n\nFinal block\n\n⏱️ 2.3s | 💭 1.2s".to_string(); let result2 = TaskResult::new(response2, context.clone()); assert_eq!(result2.extract_last_block(), "Final block"); - + // Test 3: Single line response let response3 = "Single line response".to_string(); let result3 = TaskResult::new(response3, context.clone()); assert_eq!(result3.extract_last_block(), "Single line response"); - + // Test 4: Empty response let response4 = "".to_string(); let result4 = TaskResult::new(response4, context.clone()); assert_eq!(result4.extract_last_block(), ""); - + // Test 5: Only whitespace let response5 = "\n\n\n \n\n".to_string(); let result5 = TaskResult::new(response5, context.clone()); assert_eq!(result5.extract_last_block(), ""); - + // Test 6: Multiple blocks with empty ones let response6 = "First\n\n\n\n\n\nLast block here".to_string(); let result6 = TaskResult::new(response6, context.clone()); @@ -59,7 +63,7 @@ fn test_extract_last_block_various_formats() { #[test] fn test_is_approved_detection() { let context = ContextWindow::new(1000); - + // Test approved cases let approved_responses = vec![ "Analysis complete\n\nIMPLEMENTATION_APPROVED", @@ -67,12 +71,16 @@ fn test_is_approved_detection() { "IMPLEMENTATION_APPROVED", "Review done\n\n✅ IMPLEMENTATION_APPROVED - All tests pass", ]; - + for response in approved_responses { let result = TaskResult::new(response.to_string(), context.clone()); - assert!(result.is_approved(), "Failed to detect approval in: {}", response); + assert!( + result.is_approved(), + "Failed to detect approval in: {}", + response + ); } - + // Test not approved cases let not_approved_responses = vec![ "Needs more work", @@ -81,10 +89,14 @@ fn test_is_approved_detection() { "Almost there but not APPROVED", "", ]; - + for response in not_approved_responses { let result = 
TaskResult::new(response.to_string(), context.clone()); - assert!(!result.is_approved(), "Incorrectly detected approval in: {}", response); + assert!( + !result.is_approved(), + "Incorrectly detected approval in: {}", + response + ); } } @@ -93,33 +105,46 @@ fn test_context_window_preservation() { // Create a context window with specific state let mut context = ContextWindow::new(5000); context.used_tokens = 1234; - + // Add some messages for i in 0..5 { - context.add_message(Message::new(if i % 2 == 0 { MessageRole::User } else { MessageRole::Assistant }, format!("Message {}", i))); + context.add_message(Message::new( + if i % 2 == 0 { + MessageRole::User + } else { + MessageRole::Assistant + }, + format!("Message {}", i), + )); } - + // Create TaskResult let result = TaskResult::new("Response".to_string(), context.clone()); - + // Verify context is preserved assert_eq!(result.context_window.total_tokens, 5000); assert!(result.context_window.used_tokens > 1234); // Should have increased assert_eq!(result.context_window.conversation_history.len(), 5); - + // Verify messages are preserved correctly for i in 0..5 { - let is_user = matches!(result.context_window.conversation_history[i].role, MessageRole::User); + let is_user = matches!( + result.context_window.conversation_history[i].role, + MessageRole::User + ); let expected_is_user = i % 2 == 0; assert_eq!(is_user, expected_is_user, "Message {} has wrong role", i); - assert_eq!(result.context_window.conversation_history[i].content, format!("Message {}", i)); + assert_eq!( + result.context_window.conversation_history[i].content, + format!("Message {}", i) + ); } } #[test] fn test_coach_feedback_extraction_scenarios() { let context = ContextWindow::new(1000); - + // Scenario 1: Coach feedback with file operations and analysis let coach_response = r#"Reading file: src/main.rs 📄 File content (23 lines): @@ -133,13 +158,13 @@ The implementation needs the following fixes: 1. Add error handling 2. 
Implement missing functions 3. Add tests"#; - + let result = TaskResult::new(coach_response.to_string(), context.clone()); let feedback = result.extract_last_block(); assert!(feedback.contains("Add error handling")); assert!(feedback.contains("Implement missing functions")); assert!(feedback.contains("Add tests")); - + // Scenario 2: Coach approval let approval_response = r#"Checking compilation... ✅ Build successful @@ -148,11 +173,11 @@ Running tests... ✅ All tests pass IMPLEMENTATION_APPROVED"#; - + let result = TaskResult::new(approval_response.to_string(), context.clone()); assert!(result.is_approved()); assert_eq!(result.extract_last_block(), "IMPLEMENTATION_APPROVED"); - + // Scenario 3: Complex feedback with timing let complex_response = r#"Tool execution log... @@ -163,7 +188,7 @@ The following issues were found: - Missing input validation ⏱️ 5.2s | 💭 2.1s"#; - + let result = TaskResult::new(complex_response.to_string(), context.clone()); let feedback = result.extract_last_block(); assert!(feedback.contains("Memory leak")); @@ -174,17 +199,18 @@ The following issues were found: #[test] fn test_edge_cases_and_special_characters() { let context = ContextWindow::new(1000); - + // Test with special characters and emojis let response_with_emojis = "First part 🚀\n\n✅ Final part with emojis 🎉".to_string(); let result = TaskResult::new(response_with_emojis, context.clone()); assert_eq!(result.extract_last_block(), "✅ Final part with emojis 🎉"); - + // Test with code blocks - let response_with_code = "Explanation\n\n```rust\nfn main() {}\n```\n\nFinal comment".to_string(); + let response_with_code = + "Explanation\n\n```rust\nfn main() {}\n```\n\nFinal comment".to_string(); let result = TaskResult::new(response_with_code, context.clone()); assert_eq!(result.extract_last_block(), "Final comment"); - + // Test with mixed newlines let mixed_newlines = "Part 1\r\n\r\nPart 2\n\nPart 3".to_string(); let result = TaskResult::new(mixed_newlines, context.clone()); @@ 
-194,30 +220,33 @@ fn test_edge_cases_and_special_characters() { #[test] fn test_large_response_handling() { let context = ContextWindow::new(100000); - + // Create a large response let mut large_response = String::new(); for i in 0..100 { large_response.push_str(&format!("Block {} with some content\n\n", i)); } large_response.push_str("This is the final block after 100 other blocks"); - + let result = TaskResult::new(large_response, context); - assert_eq!(result.extract_last_block(), "This is the final block after 100 other blocks"); + assert_eq!( + result.extract_last_block(), + "This is the final block after 100 other blocks" + ); } #[test] fn test_concurrent_access() { use std::thread; - + let context = ContextWindow::new(1000); let result = Arc::new(TaskResult::new( "Concurrent test\n\nFinal block".to_string(), context, )); - + let mut handles = vec![]; - + // Spawn multiple threads to access the TaskResult for _ in 0..10 { let result_clone = Arc::clone(&result); @@ -225,16 +254,15 @@ fn test_concurrent_access() { // Each thread extracts the last block let block = result_clone.extract_last_block(); assert_eq!(block, "Final block"); - + // Check approval status assert!(!result_clone.is_approved()); }); handles.push(handle); } - + // Wait for all threads to complete for handle in handles { handle.join().unwrap(); } } - diff --git a/crates/g3-core/src/tilde_expansion_tests.rs b/crates/g3-core/src/tilde_expansion_tests.rs index e3a0ebc..5320efe 100644 --- a/crates/g3-core/src/tilde_expansion_tests.rs +++ b/crates/g3-core/src/tilde_expansion_tests.rs @@ -7,10 +7,10 @@ mod tilde_expansion_tests { // Test that shellexpand works let path_with_tilde = "~/test.txt"; let expanded = shellexpand::tilde(path_with_tilde); - + // Get the actual home directory let home = env::var("HOME").expect("HOME environment variable not set"); - + // Verify expansion happened assert_eq!(expanded.as_ref(), format!("{}/test.txt", home)); assert!(!expanded.contains("~")); @@ -20,9 +20,9 @@ 
mod tilde_expansion_tests { fn test_tilde_expansion_with_subdirs() { let path_with_tilde = "~/Documents/test.txt"; let expanded = shellexpand::tilde(path_with_tilde); - + let home = env::var("HOME").expect("HOME environment variable not set"); - + assert_eq!(expanded.as_ref(), format!("{}/Documents/test.txt", home)); } @@ -30,7 +30,7 @@ mod tilde_expansion_tests { fn test_no_tilde_unchanged() { let path_without_tilde = "/absolute/path/test.txt"; let expanded = shellexpand::tilde(path_without_tilde); - + assert_eq!(expanded.as_ref(), path_without_tilde); } } diff --git a/crates/g3-core/src/ui_writer.rs b/crates/g3-core/src/ui_writer.rs index e817b49..3d66d37 100644 --- a/crates/g3-core/src/ui_writer.rs +++ b/crates/g3-core/src/ui_writer.rs @@ -4,58 +4,60 @@ pub trait UiWriter: Send + Sync { /// Print a simple message fn print(&self, message: &str); - + /// Print a message with a newline fn println(&self, message: &str); - + /// Print without newline (for progress indicators) fn print_inline(&self, message: &str); - + /// Print a system prompt section fn print_system_prompt(&self, prompt: &str); - + /// Print a context window status message fn print_context_status(&self, message: &str); - + /// Print a context thinning success message with highlight and animation fn print_context_thinning(&self, message: &str); - + /// Print a tool execution header fn print_tool_header(&self, tool_name: &str); - + /// Print a tool argument fn print_tool_arg(&self, key: &str, value: &str); - + /// Print tool output header fn print_tool_output_header(&self); - + /// Update the current tool output line (replaces previous line) fn update_tool_output_line(&self, line: &str); - + /// Print a tool output line fn print_tool_output_line(&self, line: &str); - + /// Print tool output summary (when output is truncated) fn print_tool_output_summary(&self, hidden_count: usize); - + /// Print tool execution timing fn print_tool_timing(&self, duration_str: &str); - + /// Print the agent prompt 
indicator fn print_agent_prompt(&self); - + /// Print agent response inline (for streaming) fn print_agent_response(&self, content: &str); - + /// Notify that an SSE event was received (including pings) fn notify_sse_received(&self); - + /// Flush any buffered output fn flush(&self); - + /// Returns true if this UI writer wants full, untruncated output /// Default is false (truncate for human readability) - fn wants_full_output(&self) -> bool { false } + fn wants_full_output(&self) -> bool { + false + } /// Prompt the user for a yes/no confirmation fn prompt_user_yes_no(&self, message: &str) -> bool; @@ -86,7 +88,13 @@ impl UiWriter for NullUiWriter { fn print_agent_response(&self, _content: &str) {} fn notify_sse_received(&self) {} fn flush(&self) {} - fn wants_full_output(&self) -> bool { false } - fn prompt_user_yes_no(&self, _message: &str) -> bool { true } - fn prompt_user_choice(&self, _message: &str, _options: &[&str]) -> usize { 0 } -} \ No newline at end of file + fn wants_full_output(&self) -> bool { + false + } + fn prompt_user_yes_no(&self, _message: &str) -> bool { + true + } + fn prompt_user_choice(&self, _message: &str, _options: &[&str]) -> usize { + 0 + } +} diff --git a/crates/g3-core/tests/code_search_test.rs b/crates/g3-core/tests/code_search_test.rs index 6685009..c42dad6 100644 --- a/crates/g3-core/tests/code_search_test.rs +++ b/crates/g3-core/tests/code_search_test.rs @@ -8,7 +8,7 @@ async fn test_find_async_functions() { // Create a temporary test file let test_dir = std::env::temp_dir().join("g3_test_code_search"); fs::create_dir_all(&test_dir).unwrap(); - + let test_file = test_dir.join("test.rs"); fs::write( &test_file, @@ -47,7 +47,10 @@ pub async fn another_async(x: i32) -> Result<(), ()> { assert_eq!(response.searches.len(), 1); let search_result = &response.searches[0]; assert_eq!(search_result.name, "find_async_functions"); - assert_eq!(search_result.match_count, 2, "Should find 2 async functions"); + assert_eq!( + 
search_result.match_count, 2, + "Should find 2 async functions" + ); assert!(search_result.error.is_none()); // Check that we found the right functions @@ -69,7 +72,7 @@ async fn test_find_all_functions() { // Create a temporary test file let test_dir = std::env::temp_dir().join("g3_test_code_search_2"); fs::create_dir_all(&test_dir).unwrap(); - + let test_file = test_dir.join("test.rs"); fs::write( &test_file, @@ -107,7 +110,10 @@ pub async fn another_async(x: i32) -> Result<(), ()> { assert_eq!(response.searches.len(), 1); let search_result = &response.searches[0]; assert_eq!(search_result.name, "find_all_functions"); - assert_eq!(search_result.match_count, 3, "Should find 3 functions total"); + assert_eq!( + search_result.match_count, 3, + "Should find 3 functions total" + ); assert!(search_result.error.is_none()); // Check that we found all functions @@ -130,7 +136,7 @@ async fn test_find_structs() { // Create a temporary test file let test_dir = std::env::temp_dir().join("g3_test_code_search_3"); fs::create_dir_all(&test_dir).unwrap(); - + let test_file = test_dir.join("test.rs"); fs::write( &test_file, @@ -188,7 +194,7 @@ async fn test_context_lines() { // Create a temporary test file let test_dir = std::env::temp_dir().join("g3_test_code_search_4"); fs::create_dir_all(&test_dir).unwrap(); - + let test_file = test_dir.join("test.rs"); fs::write( &test_file, @@ -223,16 +229,22 @@ pub fn target_function() { assert_eq!(response.searches.len(), 1); let search_result = &response.searches[0]; assert_eq!(search_result.match_count, 1); - + let match_result = &search_result.matches[0]; assert!(match_result.context.is_some()); - + let context = match_result.context.as_ref().unwrap(); assert!(context.contains("Line 2"), "Should include 2 lines before"); - assert!(context.contains("target_function"), "Should include the function"); + assert!( + context.contains("target_function"), + "Should include the function" + ); // Note: context_lines=2 means 2 lines before and 
after the match line (line 4) // So we get lines 2-6, which includes up to println but not the closing brace - assert!(context.contains("println"), "Should include 2 lines after the match"); + assert!( + context.contains("println"), + "Should include 2 lines after the match" + ); // Cleanup fs::remove_dir_all(&test_dir).ok(); @@ -243,7 +255,7 @@ async fn test_multiple_searches() { // Create a temporary test file let test_dir = std::env::temp_dir().join("g3_test_code_search_5"); fs::create_dir_all(&test_dir).unwrap(); - + let test_file = test_dir.join("test.rs"); fs::write( &test_file, @@ -301,7 +313,7 @@ async fn test_python_search() { // Create a temporary Python test file let test_dir = std::env::temp_dir().join("g3_test_code_search_python"); fs::create_dir_all(&test_dir).unwrap(); - + let test_file = test_dir.join("test.py"); fs::write( &test_file, @@ -338,14 +350,17 @@ class MyClass: assert_eq!(response.searches.len(), 1); let search_result = &response.searches[0]; - assert_eq!(search_result.match_count, 3, "Should find 3 functions in Python (2 regular + 1 async + 1 method)"); - + assert_eq!( + search_result.match_count, 3, + "Should find 3 functions in Python (2 regular + 1 async + 1 method)" + ); + let function_names: Vec = search_result .matches .iter() .filter_map(|m| m.captures.get("name").cloned()) .collect(); - + assert!(function_names.contains(&"regular_function".to_string())); assert!(function_names.contains(&"async_function".to_string())); assert!(function_names.contains(&"method".to_string())); @@ -359,7 +374,7 @@ async fn test_javascript_search() { // Create a temporary JavaScript test file let test_dir = std::env::temp_dir().join("g3_test_code_search_js"); fs::create_dir_all(&test_dir).unwrap(); - + let test_file = test_dir.join("test.js"); fs::write( &test_file, @@ -396,14 +411,17 @@ class MyClass { assert_eq!(response.searches.len(), 1); let search_result = &response.searches[0]; - assert_eq!(search_result.match_count, 2, "Should find 2 functions 
in JavaScript"); - + assert_eq!( + search_result.match_count, 2, + "Should find 2 functions in JavaScript" + ); + let function_names: Vec = search_result .matches .iter() .filter_map(|m| m.captures.get("name").cloned()) .collect(); - + assert!(function_names.contains(&"regularFunction".to_string())); assert!(function_names.contains(&"asyncFunction".to_string())); @@ -420,7 +438,7 @@ async fn test_go_search() { .and_then(|p| p.parent()) .unwrap(); let test_code_path = workspace_root.join("examples/test_code"); - + let request = CodeSearchRequest { searches: vec![SearchSpec { name: "go_functions".to_string(), @@ -435,14 +453,19 @@ async fn test_go_search() { let response = execute_code_search(request).await.unwrap(); assert_eq!(response.searches.len(), 1); - + eprintln!("Go search result: {:?}", response.searches[0]); eprintln!("Match count: {}", response.searches[0].matches.len()); eprintln!("Error: {:?}", response.searches[0].error); - assert!(response.searches[0].matches.len() > 0, "No matches found for Go search"); - + assert!( + response.searches[0].matches.len() > 0, + "No matches found for Go search" + ); + // Should find main and greet functions - let names: Vec<&str> = response.searches[0].matches.iter() + let names: Vec<&str> = response.searches[0] + .matches + .iter() .filter_map(|m| m.captures.get("name").map(|s| s.as_str())) .collect(); assert!(names.contains(&"main")); @@ -458,7 +481,7 @@ async fn test_java_search() { .and_then(|p| p.parent()) .unwrap(); let test_code_path = workspace_root.join("examples/test_code"); - + let request = CodeSearchRequest { searches: vec![SearchSpec { name: "java_classes".to_string(), @@ -474,9 +497,11 @@ async fn test_java_search() { let response = execute_code_search(request).await.unwrap(); assert_eq!(response.searches.len(), 1); assert!(response.searches[0].matches.len() > 0); - + // Should find Example class - let names: Vec<&str> = response.searches[0].matches.iter() + let names: Vec<&str> = response.searches[0] + 
.matches + .iter() .filter_map(|m| m.captures.get("name").map(|s| s.as_str())) .collect(); assert!(names.contains(&"Example")); @@ -491,7 +516,7 @@ async fn test_c_search() { .and_then(|p| p.parent()) .unwrap(); let test_code_path = workspace_root.join("examples/test_code"); - + let request = CodeSearchRequest { searches: vec![SearchSpec { name: "c_functions".to_string(), @@ -507,9 +532,11 @@ async fn test_c_search() { let response = execute_code_search(request).await.unwrap(); assert_eq!(response.searches.len(), 1); assert!(response.searches[0].matches.len() > 0); - + // Should find greet, add, and main functions - let names: Vec<&str> = response.searches[0].matches.iter() + let names: Vec<&str> = response.searches[0] + .matches + .iter() .filter_map(|m| m.captures.get("name").map(|s| s.as_str())) .collect(); assert!(names.contains(&"greet")); @@ -526,7 +553,7 @@ async fn test_cpp_search() { .and_then(|p| p.parent()) .unwrap(); let test_code_path = workspace_root.join("examples/test_code"); - + let request = CodeSearchRequest { searches: vec![SearchSpec { name: "cpp_classes".to_string(), @@ -542,9 +569,11 @@ async fn test_cpp_search() { let response = execute_code_search(request).await.unwrap(); assert_eq!(response.searches.len(), 1); assert!(response.searches[0].matches.len() > 0); - + // Should find Person class - let names: Vec<&str> = response.searches[0].matches.iter() + let names: Vec<&str> = response.searches[0] + .matches + .iter() .filter_map(|m| m.captures.get("name").map(|s| s.as_str())) .collect(); assert!(names.contains(&"Person")); @@ -568,9 +597,11 @@ async fn test_kotlin_search() { let response = execute_code_search(request).await.unwrap(); assert_eq!(response.searches.len(), 1); assert!(response.searches[0].matches.len() > 0); - + // Should find Person class - let names: Vec<&str> = response.searches[0].matches.iter() + let names: Vec<&str> = response.searches[0] + .matches + .iter() .filter_map(|m| m.captures.get("name").map(|s| s.as_str())) 
.collect(); assert!(names.contains(&"Person")); diff --git a/crates/g3-core/tests/test_context_thinning.rs b/crates/g3-core/tests/test_context_thinning.rs index f0ef2a1..fbdc5a0 100644 --- a/crates/g3-core/tests/test_context_thinning.rs +++ b/crates/g3-core/tests/test_context_thinning.rs @@ -4,35 +4,35 @@ use g3_providers::{Message, MessageRole}; #[test] fn test_thinning_thresholds() { let mut context = ContextWindow::new(10000); - + // At 0%, should not thin assert!(!context.should_thin()); - + // Simulate reaching 50% usage context.used_tokens = 5000; assert!(context.should_thin()); - + // After thinning at 50%, should not thin again until next threshold context.last_thinning_percentage = 50; assert!(!context.should_thin()); - + // At 60%, should thin again context.used_tokens = 6000; assert!(context.should_thin()); - + // After thinning at 60%, should not thin context.last_thinning_percentage = 60; assert!(!context.should_thin()); - + // At 70%, should thin context.used_tokens = 7000; assert!(context.should_thin()); - + // At 80%, should thin context.last_thinning_percentage = 70; context.used_tokens = 8000; assert!(context.should_thin()); - + // After 80%, should not thin (compaction takes over) context.last_thinning_percentage = 80; context.used_tokens = 8500; @@ -42,7 +42,7 @@ fn test_thinning_thresholds() { #[test] fn test_thin_context_basic() { let mut context = ContextWindow::new(10000); - + // Add some messages to the first third for i in 0..9 { if i % 2 == 0 { @@ -62,24 +62,25 @@ fn test_thin_context_basic() { // Small tool result (< 1000 chars) format!("Tool result: small result {}", i) }; - - context.add_message(Message::new( - MessageRole::User, - content, - )); + + context.add_message(Message::new(MessageRole::User, content)); } } - + // Trigger thinning at 50% context.used_tokens = 5000; let (summary, _chars_saved) = context.thin_context(); - + println!("Thinning summary: {}", summary); - + // Should have thinned at least 1 large tool result in the 
first third - assert!(summary.contains("1 tool result"), "Summary was: {}", summary); + assert!( + summary.contains("1 tool result"), + "Summary was: {}", + summary + ); assert!(summary.contains("50%")); - + // Check that the large tool results were replaced let first_third_end = context.conversation_history.len() / 3; for i in 0..first_third_end { @@ -96,13 +97,13 @@ fn test_thin_context_basic() { #[test] fn test_thin_write_file_tool_calls() { let mut context = ContextWindow::new(10000); - + // Add some messages including a write_file tool call with large content context.add_message(Message::new( MessageRole::User, "Please create a large file".to_string(), )); - + // Add an assistant message with a write_file tool call containing large content let large_content = "x".repeat(1500); let tool_call_json = format!( @@ -113,12 +114,12 @@ fn test_thin_write_file_tool_calls() { MessageRole::Assistant, format!("I'll create that file.\n\n{}", tool_call_json), )); - + context.add_message(Message::new( MessageRole::User, "Tool result: ✅ Successfully wrote 1500 lines".to_string(), )); - + // Add more messages to ensure we have enough for "first third" logic for i in 0..6 { context.add_message(Message::new( @@ -126,16 +127,16 @@ fn test_thin_write_file_tool_calls() { format!("Response {}", i), )); } - + // Trigger thinning at 50% context.used_tokens = 5000; let (summary, _chars_saved) = context.thin_context(); - + println!("Thinning summary: {}", summary); - + // Should have thinned the write_file tool call assert!(summary.contains("tool call") || summary.contains("chars saved")); - + // Check that the large content was replaced with a file reference let first_third_end = context.conversation_history.len() / 3; for i in 0..first_third_end { @@ -152,15 +153,19 @@ fn test_thin_write_file_tool_calls() { #[test] fn test_thin_str_replace_tool_calls() { let mut context = ContextWindow::new(10000); - + // Add some messages including a str_replace tool call with large diff 
context.add_message(Message::new( MessageRole::User, "Please update the file".to_string(), )); - + // Add an assistant message with a str_replace tool call containing large diff - let large_diff = format!("--- old\n{}\n+++ new\n{}", "-old line\n".repeat(100), "+new line\n".repeat(100)); + let large_diff = format!( + "--- old\n{}\n+++ new\n{}", + "-old line\n".repeat(100), + "+new line\n".repeat(100) + ); let tool_call_json = format!( r#"{{"tool": "str_replace", "args": {{"file_path": "test.txt", "diff": "{}"}}}}"#, large_diff.replace('\n', "\\n") @@ -169,12 +174,12 @@ fn test_thin_str_replace_tool_calls() { MessageRole::Assistant, format!("I'll update that file.\n\n{}", tool_call_json), )); - + context.add_message(Message::new( MessageRole::User, "Tool result: ✅ applied unified diff".to_string(), )); - + // Add more messages to ensure we have enough for "first third" logic for i in 0..6 { context.add_message(Message::new( @@ -182,16 +187,16 @@ fn test_thin_str_replace_tool_calls() { format!("Response {}", i), )); } - + // Trigger thinning at 50% context.used_tokens = 5000; let (summary, _chars_saved) = context.thin_context(); - + println!("Thinning summary: {}", summary); - + // Should have thinned the str_replace tool call assert!(summary.contains("tool call") || summary.contains("chars saved")); - + // Check that the large diff was replaced with a file reference let first_third_end = context.conversation_history.len() / 3; for i in 0..first_third_end { @@ -209,7 +214,7 @@ fn test_thin_str_replace_tool_calls() { #[test] fn test_thin_context_no_large_results() { let mut context = ContextWindow::new(10000); - + // Add only small messages for i in 0..9 { context.add_message(Message::new( @@ -217,10 +222,10 @@ fn test_thin_context_no_large_results() { format!("Tool result: small {}", i), )); } - + context.used_tokens = 5000; let (summary, _chars_saved) = context.thin_context(); - + // Should report no large results found assert!(summary.contains("no large tool results 
or tool calls found")); } @@ -228,7 +233,7 @@ fn test_thin_context_no_large_results() { #[test] fn test_thin_context_only_affects_first_third() { let mut context = ContextWindow::new(10000); - + // Add 12 messages (first third = 4 messages) for i in 0..12 { let content = if i % 2 == 1 { @@ -237,23 +242,23 @@ fn test_thin_context_only_affects_first_third() { } else { format!("Assistant message {}", i) }; - + let role = if i % 2 == 1 { MessageRole::User } else { MessageRole::Assistant }; - + context.add_message(Message::new(role, content)); } - + context.used_tokens = 5000; let (summary, _chars_saved) = context.thin_context(); - + // First third is 4 messages (indices 0-3), so only indices 1 and 3 should be thinned // That's 2 tool results assert!(summary.contains("2 tool results")); - + // Check that messages after the first third are NOT thinned let first_third_end = context.conversation_history.len() / 3; for i in first_third_end..context.conversation_history.len() { @@ -261,8 +266,11 @@ fn test_thin_context_only_affects_first_third() { if matches!(msg.role, MessageRole::User) && msg.content.starts_with("Tool result:") { // These should still be large (not thinned) if i % 2 == 1 { - assert!(msg.content.len() > 1000, - "Message at index {} should not have been thinned", i); + assert!( + msg.content.len() > 1000, + "Message at index {} should not have been thinned", + i + ); } } } diff --git a/crates/g3-core/tests/test_todo_context_thinning.rs b/crates/g3-core/tests/test_todo_context_thinning.rs index 27443a9..815fa6e 100644 --- a/crates/g3-core/tests/test_todo_context_thinning.rs +++ b/crates/g3-core/tests/test_todo_context_thinning.rs @@ -6,28 +6,34 @@ use serial_test::serial; #[serial] fn test_todo_read_results_not_thinned() { let mut context = ContextWindow::new(10000); - + // Add a todo_read tool call - context.add_message(Message::new(MessageRole::Assistant, r#"{"tool": "todo_read", "args": {}}"#.to_string())); - + context.add_message(Message::new( + 
MessageRole::Assistant, + r#"{"tool": "todo_read", "args": {}}"#.to_string(), + )); + // Add a large TODO result (> 500 chars) let large_todo_result = format!( "Tool result: 📝 TODO list:\n{}", "- [ ] Task with long description\n".repeat(50) ); context.add_message(Message::new(MessageRole::User, large_todo_result.clone())); - + // Add more messages to ensure we have enough for "first third" logic for i in 0..6 { - context.add_message(Message::new(MessageRole::Assistant, format!("Response {}", i))) + context.add_message(Message::new( + MessageRole::Assistant, + format!("Response {}", i), + )) } - + // Trigger thinning at 50% context.used_tokens = 5000; let (summary, _chars_saved) = context.thin_context(); - + println!("Thinning summary: {}", summary); - + // Check that the TODO result was NOT thinned let first_third_end = context.conversation_history.len() / 3; for i in 0..first_third_end { @@ -53,29 +59,38 @@ fn test_todo_read_results_not_thinned() { #[serial] fn test_todo_write_results_not_thinned() { let mut context = ContextWindow::new(10000); - + // Add a todo_write tool call let large_content = "- [ ] Task\n".repeat(100); - context.add_message(Message::new(MessageRole::Assistant, format!(r#"{{"tool": "todo_write", "args": {{"content": "{}"}}}}"#, large_content))); - + context.add_message(Message::new( + MessageRole::Assistant, + format!( + r#"{{"tool": "todo_write", "args": {{"content": "{}"}}}}"#, + large_content + ), + )); + // Add a large TODO write result let large_todo_result = format!( "Tool result: ✅ TODO list updated ({} chars) and saved to todo.g3.md", large_content.len() ); context.add_message(Message::new(MessageRole::User, large_todo_result.clone())); - + // Add more messages for i in 0..6 { - context.add_message(Message::new(MessageRole::Assistant, format!("Response {}", i))) + context.add_message(Message::new( + MessageRole::Assistant, + format!("Response {}", i), + )) } - + // Trigger thinning at 50% context.used_tokens = 5000; let (summary, 
_chars_saved) = context.thin_context(); - + println!("Thinning summary: {}", summary); - + // Check that the TODO write result was NOT thinned let first_third_end = context.conversation_history.len() / 3; for i in 0..first_third_end { @@ -99,31 +114,37 @@ fn test_todo_write_results_not_thinned() { #[serial] fn test_non_todo_results_still_thinned() { let mut context = ContextWindow::new(10000); - + // Add a non-TODO tool call (e.g., read_file) - context.add_message(Message::new(MessageRole::Assistant, r#"{"tool": "read_file", "args": {"file_path": "test.txt"}}"#.to_string())); - + context.add_message(Message::new( + MessageRole::Assistant, + r#"{"tool": "read_file", "args": {"file_path": "test.txt"}}"#.to_string(), + )); + // Add a large read_file result (> 500 chars) let large_result = format!("Tool result: {}", "x".repeat(1500)); context.add_message(Message::new(MessageRole::User, large_result)); - + // Add more messages for i in 0..6 { - context.add_message(Message::new(MessageRole::Assistant, format!("Response {}", i))) + context.add_message(Message::new( + MessageRole::Assistant, + format!("Response {}", i), + )) } - + // Trigger thinning at 50% context.used_tokens = 5000; let (summary, _chars_saved) = context.thin_context(); - + println!("Thinning summary: {}", summary); - + // Should have thinned the non-TODO result assert!( summary.contains("1 tool result") || summary.contains("chars saved"), "Non-TODO results should be thinned" ); - + // Check that the result was actually thinned let first_third_end = context.conversation_history.len() / 3; for i in 0..first_third_end { @@ -143,26 +164,29 @@ fn test_non_todo_results_still_thinned() { #[serial] fn test_todo_read_with_spaces_in_tool_name() { let mut context = ContextWindow::new(10000); - + // Add a todo_read tool call with spaces (JSON formatting variation) - context.add_message(Message::new(MessageRole::Assistant, r#"{"tool": "todo_read", "args": {}}"#.to_string())); - + context.add_message(Message::new( + 
MessageRole::Assistant, + r#"{"tool": "todo_read", "args": {}}"#.to_string(), + )); + // Add a large TODO result - let large_todo_result = format!( - "Tool result: 📝 TODO list:\n{}", - "- [ ] Task\n".repeat(50) - ); + let large_todo_result = format!("Tool result: 📝 TODO list:\n{}", "- [ ] Task\n".repeat(50)); context.add_message(Message::new(MessageRole::User, large_todo_result.clone())); - + // Add more messages for i in 0..6 { - context.add_message(Message::new(MessageRole::Assistant, format!("Response {}", i))) + context.add_message(Message::new( + MessageRole::Assistant, + format!("Response {}", i), + )) } - + // Trigger thinning context.used_tokens = 5000; let (_summary, _chars_saved) = context.thin_context(); - + // Verify TODO result was not thinned let first_third_end = context.conversation_history.len() / 3; for i in 0..first_third_end { diff --git a/crates/g3-core/tests/test_todo_persistence.rs b/crates/g3-core/tests/test_todo_persistence.rs index 69baabd..bca8d6b 100644 --- a/crates/g3-core/tests/test_todo_persistence.rs +++ b/crates/g3-core/tests/test_todo_persistence.rs @@ -1,20 +1,19 @@ -use g3_core::Agent; use g3_core::ui_writer::NullUiWriter; +use g3_core::Agent; use serial_test::serial; use std::fs; use std::path::PathBuf; use tempfile::TempDir; - /// Helper to create a test agent in a temporary directory async fn create_test_agent_in_dir(temp_dir: &TempDir) -> Agent { // Change to temp directory std::env::set_current_dir(temp_dir.path()).unwrap(); - + // Create a minimal config let config = g3_config::Config::default(); let ui_writer = NullUiWriter; - + Agent::new(config, ui_writer).await.unwrap() } @@ -29,10 +28,10 @@ async fn test_todo_write_creates_file() { let temp_dir = TempDir::new().unwrap(); let mut agent = create_test_agent_in_dir(&temp_dir).await; let todo_path = get_todo_path(&temp_dir); - + // Initially, todo.g3.md should not exist assert!(!todo_path.exists(), "todo.g3.md should not exist initially"); - + // Create a tool call to write 
TODO let tool_call = g3_core::ToolCall { tool: "todo_write".to_string(), @@ -40,17 +39,21 @@ async fn test_todo_write_creates_file() { "content": "- [ ] Task 1\n- [ ] Task 2\n- [x] Task 3" }), }; - + // Execute the tool let result = agent.execute_tool(&tool_call).await.unwrap(); - + // Should report success assert!(result.contains("✅"), "Should report success: {}", result); - assert!(result.contains("todo.g3.md"), "Should mention todo.g3.md: {}", result); - + assert!( + result.contains("todo.g3.md"), + "Should mention todo.g3.md: {}", + result + ); + // File should now exist assert!(todo_path.exists(), "todo.g3.md should exist after write"); - + // File should contain the correct content let content = fs::read_to_string(&todo_path).unwrap(); assert_eq!(content, "- [ ] Task 1\n- [ ] Task 2\n- [x] Task 3"); @@ -61,27 +64,39 @@ async fn test_todo_write_creates_file() { async fn test_todo_read_from_file() { let temp_dir = TempDir::new().unwrap(); let todo_path = get_todo_path(&temp_dir); - + // Pre-create a todo.g3.md file let test_content = "# My TODO\n\n- [ ] First task\n- [x] Completed task"; fs::write(&todo_path, test_content).unwrap(); - + // Create agent (should load from file) let mut agent = create_test_agent_in_dir(&temp_dir).await; - + // Create a tool call to read TODO let tool_call = g3_core::ToolCall { tool: "todo_read".to_string(), args: serde_json::json!({}), }; - + // Execute the tool let result = agent.execute_tool(&tool_call).await.unwrap(); - + // Should contain the TODO content - assert!(result.contains("📝 TODO list:"), "Should have TODO list header: {}", result); - assert!(result.contains("First task"), "Should contain first task: {}", result); - assert!(result.contains("Completed task"), "Should contain completed task: {}", result); + assert!( + result.contains("📝 TODO list:"), + "Should have TODO list header: {}", + result + ); + assert!( + result.contains("First task"), + "Should contain first task: {}", + result + ); + assert!( + 
result.contains("Completed task"), + "Should contain completed task: {}", + result + ); } #[tokio::test] @@ -89,16 +104,16 @@ async fn test_todo_read_from_file() { async fn test_todo_read_empty_file() { let temp_dir = TempDir::new().unwrap(); let mut agent = create_test_agent_in_dir(&temp_dir).await; - + // Create a tool call to read TODO (file doesn't exist) let tool_call = g3_core::ToolCall { tool: "todo_read".to_string(), args: serde_json::json!({}), }; - + // Execute the tool let result = agent.execute_tool(&tool_call).await.unwrap(); - + // Should report empty assert!(result.contains("empty"), "Should report empty: {}", result); } @@ -108,7 +123,7 @@ async fn test_todo_read_empty_file() { async fn test_todo_persistence_across_agents() { let temp_dir = TempDir::new().unwrap(); let todo_path = get_todo_path(&temp_dir); - + // Agent 1: Write TODO { let mut agent = create_test_agent_in_dir(&temp_dir).await; @@ -120,10 +135,13 @@ async fn test_todo_persistence_across_agents() { }; agent.execute_tool(&tool_call).await.unwrap(); } - + // Verify file exists - assert!(todo_path.exists(), "todo.g3.md should persist after agent drops"); - + assert!( + todo_path.exists(), + "todo.g3.md should persist after agent drops" + ); + // Agent 2: Read TODO (new agent instance) { let mut agent = create_test_agent_in_dir(&temp_dir).await; @@ -132,10 +150,18 @@ async fn test_todo_persistence_across_agents() { args: serde_json::json!({}), }; let result = agent.execute_tool(&tool_call).await.unwrap(); - + // Should read the persisted content - assert!(result.contains("Persistent task"), "Should read persisted task: {}", result); - assert!(result.contains("Done task"), "Should read done task: {}", result); + assert!( + result.contains("Persistent task"), + "Should read persisted task: {}", + result + ); + assert!( + result.contains("Done task"), + "Should read done task: {}", + result + ); } } @@ -145,7 +171,7 @@ async fn test_todo_update_preserves_file() { let temp_dir = 
TempDir::new().unwrap(); let mut agent = create_test_agent_in_dir(&temp_dir).await; let todo_path = get_todo_path(&temp_dir); - + // Write initial TODO let write_call = g3_core::ToolCall { tool: "todo_write".to_string(), @@ -154,7 +180,7 @@ async fn test_todo_update_preserves_file() { }), }; agent.execute_tool(&write_call).await.unwrap(); - + // Update TODO let update_call = g3_core::ToolCall { tool: "todo_write".to_string(), @@ -163,7 +189,7 @@ async fn test_todo_update_preserves_file() { }), }; agent.execute_tool(&update_call).await.unwrap(); - + // Verify file has updated content let content = fs::read_to_string(&todo_path).unwrap(); assert_eq!(content, "- [x] Task 1\n- [ ] Task 2\n- [ ] Task 3"); @@ -175,23 +201,30 @@ async fn test_todo_handles_large_content() { let temp_dir = TempDir::new().unwrap(); let mut agent = create_test_agent_in_dir(&temp_dir).await; let todo_path = get_todo_path(&temp_dir); - + // Create a large TODO (but under the 50k limit) let mut large_content = String::from("# Large TODO\n\n"); for i in 0..100 { - large_content.push_str(&format!("- [ ] Task {} with a long description that exceeds normal line lengths\n", i)); + large_content.push_str(&format!( + "- [ ] Task {} with a long description that exceeds normal line lengths\n", + i + )); } - + let tool_call = g3_core::ToolCall { tool: "todo_write".to_string(), args: serde_json::json!({ "content": large_content }), }; - + let result = agent.execute_tool(&tool_call).await.unwrap(); - assert!(result.contains("✅"), "Should handle large content: {}", result); - + assert!( + result.contains("✅"), + "Should handle large content: {}", + result + ); + // Verify file contains all content let file_content = fs::read_to_string(&todo_path).unwrap(); assert_eq!(file_content, large_content); @@ -203,22 +236,30 @@ async fn test_todo_handles_large_content() { async fn test_todo_respects_size_limit() { let temp_dir = TempDir::new().unwrap(); let mut agent = create_test_agent_in_dir(&temp_dir).await; - + // 
Create content that exceeds the default 50k limit let huge_content = "x".repeat(60_000); - + let tool_call = g3_core::ToolCall { tool: "todo_write".to_string(), args: serde_json::json!({ "content": huge_content }), }; - + let result = agent.execute_tool(&tool_call).await.unwrap(); - + // Should reject content that's too large - assert!(result.contains("❌"), "Should reject oversized content: {}", result); - assert!(result.contains("too large"), "Should mention size limit: {}", result); + assert!( + result.contains("❌"), + "Should reject oversized content: {}", + result + ); + assert!( + result.contains("too large"), + "Should mention size limit: {}", + result + ); } #[tokio::test] @@ -226,22 +267,26 @@ async fn test_todo_respects_size_limit() { async fn test_todo_agent_initialization_loads_file() { let temp_dir = TempDir::new().unwrap(); let todo_path = get_todo_path(&temp_dir); - + // Pre-create todo.g3.md before agent initialization let initial_content = "- [ ] Pre-existing task"; fs::write(&todo_path, initial_content).unwrap(); - + // Create agent - should load the file during initialization let mut agent = create_test_agent_in_dir(&temp_dir).await; - + // Read TODO - should return the pre-existing content let tool_call = g3_core::ToolCall { tool: "todo_read".to_string(), args: serde_json::json!({}), }; - + let result = agent.execute_tool(&tool_call).await.unwrap(); - assert!(result.contains("Pre-existing task"), "Should load file on init: {}", result); + assert!( + result.contains("Pre-existing task"), + "Should load file on init: {}", + result + ); } #[tokio::test] @@ -250,33 +295,41 @@ async fn test_todo_handles_unicode_content() { let temp_dir = TempDir::new().unwrap(); let mut agent = create_test_agent_in_dir(&temp_dir).await; let todo_path = get_todo_path(&temp_dir); - + // Create TODO with unicode characters let unicode_content = "- [ ] 日本語タスク\n- [ ] Émoji task 🚀\n- [x] Ελληνικά task"; - + let tool_call = g3_core::ToolCall { tool: "todo_write".to_string(), 
args: serde_json::json!({ "content": unicode_content }), }; - + agent.execute_tool(&tool_call).await.unwrap(); - + // Verify file preserves unicode let file_content = fs::read_to_string(&todo_path).unwrap(); assert_eq!(file_content, unicode_content); - + // Verify reading back works let read_call = g3_core::ToolCall { tool: "todo_read".to_string(), args: serde_json::json!({}), }; - + let result = agent.execute_tool(&read_call).await.unwrap(); - assert!(result.contains("日本語"), "Should preserve Japanese: {}", result); + assert!( + result.contains("日本語"), + "Should preserve Japanese: {}", + result + ); assert!(result.contains("🚀"), "Should preserve emoji: {}", result); - assert!(result.contains("Ελληνικά"), "Should preserve Greek: {}", result); + assert!( + result.contains("Ελληνικά"), + "Should preserve Greek: {}", + result + ); } #[tokio::test] @@ -285,7 +338,7 @@ async fn test_todo_empty_content_creates_empty_file() { let temp_dir = TempDir::new().unwrap(); let mut agent = create_test_agent_in_dir(&temp_dir).await; let todo_path = get_todo_path(&temp_dir); - + // Write empty TODO let tool_call = g3_core::ToolCall { tool: "todo_write".to_string(), @@ -293,9 +346,9 @@ async fn test_todo_empty_content_creates_empty_file() { "content": "" }), }; - + agent.execute_tool(&tool_call).await.unwrap(); - + // File should exist but be empty assert!(todo_path.exists(), "Empty todo.g3.md should create file"); let content = fs::read_to_string(&todo_path).unwrap(); @@ -307,7 +360,7 @@ async fn test_todo_empty_content_creates_empty_file() { async fn test_todo_whitespace_only_content() { let temp_dir = TempDir::new().unwrap(); let mut agent = create_test_agent_in_dir(&temp_dir).await; - + // Write whitespace-only TODO let tool_call = g3_core::ToolCall { tool: "todo_write".to_string(), @@ -315,17 +368,21 @@ async fn test_todo_whitespace_only_content() { "content": " \n\n \t \n" }), }; - + agent.execute_tool(&tool_call).await.unwrap(); - + // Read it back let read_call = 
g3_core::ToolCall { tool: "todo_read".to_string(), args: serde_json::json!({}), }; - + let result = agent.execute_tool(&read_call).await.unwrap(); - + // Should report as empty (whitespace is trimmed) - assert!(result.contains("empty"), "Whitespace-only should be empty: {}", result); + assert!( + result.contains("empty"), + "Whitespace-only should be empty: {}", + result + ); } diff --git a/crates/g3-core/tests/test_token_counting.rs b/crates/g3-core/tests/test_token_counting.rs index 5ba13e5..02646fa 100644 --- a/crates/g3-core/tests/test_token_counting.rs +++ b/crates/g3-core/tests/test_token_counting.rs @@ -4,7 +4,7 @@ use g3_providers::Usage; #[test] fn test_token_accumulation() { let mut window = ContextWindow::new(10000); - + // First API call: 100 prompt + 50 completion = 150 total let usage1 = Usage { prompt_tokens: 100, @@ -22,7 +22,10 @@ fn test_token_accumulation() { total_tokens: 275, }; window.update_usage_from_response(&usage2); - assert_eq!(window.used_tokens, 425, "Second call should accumulate to 425 tokens"); + assert_eq!( + window.used_tokens, 425, + "Second call should accumulate to 425 tokens" + ); assert_eq!(window.cumulative_tokens, 425, "Cumulative should be 425"); // Third API call with SMALLER token count: 50 prompt + 25 completion = 75 total @@ -32,27 +35,33 @@ fn test_token_accumulation() { total_tokens: 75, }; window.update_usage_from_response(&usage3); - assert_eq!(window.used_tokens, 500, "Third call should accumulate to 500 tokens"); + assert_eq!( + window.used_tokens, 500, + "Third call should accumulate to 500 tokens" + ); assert_eq!(window.cumulative_tokens, 500, "Cumulative should be 500"); - + // Verify tokens never decrease - assert!(window.used_tokens >= 425, "Token count should never decrease!"); + assert!( + window.used_tokens >= 425, + "Token count should never decrease!" 
+ ); } #[test] fn test_add_streaming_tokens() { let mut window = ContextWindow::new(10000); - + // Add some streaming tokens window.add_streaming_tokens(100); assert_eq!(window.used_tokens, 100); assert_eq!(window.cumulative_tokens, 100); - + // Add more window.add_streaming_tokens(50); assert_eq!(window.used_tokens, 150); assert_eq!(window.cumulative_tokens, 150); - + // Now update from provider response let usage = Usage { prompt_tokens: 80, @@ -60,7 +69,7 @@ fn test_add_streaming_tokens() { total_tokens: 120, }; window.update_usage_from_response(&usage); - + // Should ADD to existing, not replace assert_eq!(window.used_tokens, 270, "Should add 120 to existing 150"); assert_eq!(window.cumulative_tokens, 270); @@ -69,7 +78,7 @@ fn test_add_streaming_tokens() { #[test] fn test_percentage_calculation() { let mut window = ContextWindow::new(1000); - + // Add tokens via provider response let usage = Usage { prompt_tokens: 150, @@ -77,10 +86,10 @@ fn test_percentage_calculation() { total_tokens: 250, }; window.update_usage_from_response(&usage); - + assert_eq!(window.percentage_used(), 25.0); assert_eq!(window.remaining_tokens(), 750); - + // Add more tokens let usage2 = Usage { prompt_tokens: 300, @@ -88,7 +97,7 @@ fn test_percentage_calculation() { total_tokens: 500, }; window.update_usage_from_response(&usage2); - + assert_eq!(window.percentage_used(), 75.0); assert_eq!(window.remaining_tokens(), 250); } diff --git a/crates/g3-core/tests/todo_staleness_test.rs b/crates/g3-core/tests/todo_staleness_test.rs index 6e54855..1ce9ddf 100644 --- a/crates/g3-core/tests/todo_staleness_test.rs +++ b/crates/g3-core/tests/todo_staleness_test.rs @@ -1,9 +1,9 @@ -use g3_core::{Agent, ToolCall}; -use g3_core::ui_writer::UiWriter; use g3_config::Config; +use g3_core::ui_writer::UiWriter; +use g3_core::{Agent, ToolCall}; +use serial_test::serial; use std::sync::{Arc, Mutex}; use tempfile::TempDir; -use serial_test::serial; // Mock UI Writer for testing #[derive(Clone)] @@ -47,7 
+47,10 @@ impl UiWriter for MockUiWriter { } fn print_system_prompt(&self, _prompt: &str) {} fn print_context_status(&self, message: &str) { - self.output.lock().unwrap().push(format!("STATUS: {}", message)); + self.output + .lock() + .unwrap() + .push(format!("STATUS: {}", message)); } fn print_context_thinning(&self, _message: &str) {} fn print_tool_header(&self, _tool_name: &str) {} @@ -61,13 +64,21 @@ impl UiWriter for MockUiWriter { fn print_agent_response(&self, _content: &str) {} fn notify_sse_received(&self) {} fn flush(&self) {} - fn wants_full_output(&self) -> bool { false } + fn wants_full_output(&self) -> bool { + false + } fn prompt_user_yes_no(&self, message: &str) -> bool { - self.output.lock().unwrap().push(format!("PROMPT: {}", message)); + self.output + .lock() + .unwrap() + .push(format!("PROMPT: {}", message)); self.prompt_responses.lock().unwrap().pop().unwrap_or(true) } fn prompt_user_choice(&self, message: &str, options: &[&str]) -> usize { - self.output.lock().unwrap().push(format!("CHOICE: {} Options: {:?}", message, options)); + self.output + .lock() + .unwrap() + .push(format!("CHOICE: {} Options: {:?}", message, options)); self.choice_responses.lock().unwrap().pop().unwrap_or(0) } } @@ -80,7 +91,10 @@ async fn test_todo_staleness_check_matching_sha() { std::env::set_current_dir(&temp_dir).unwrap(); let sha = "abc123hash"; - let content = format!("{{{{Based on the requirements file with SHA256: {}}}}}\n- [ ] Task 1", sha); + let content = format!( + "{{{{Based on the requirements file with SHA256: {}}}}}\n- [ ] Task 1", + sha + ); std::fs::write(&todo_path, content).unwrap(); let mut config = Config::default(); @@ -109,7 +123,10 @@ async fn test_todo_staleness_check_mismatch_sha_ignore() { let sha_file = "old_sha"; let sha_req = "new_sha"; - let content = format!("{{{{Based on the requirements file with SHA256: {}}}}}\n- [ ] Task 1", sha_file); + let content = format!( + "{{{{Based on the requirements file with SHA256: {}}}}}\n- [ ] Task 
1", + sha_file + ); std::fs::write(&todo_path, content).unwrap(); let mut config = Config::default(); @@ -139,7 +156,10 @@ async fn test_todo_staleness_check_mismatch_sha_mark_stale() { let sha_file = "old_sha"; let sha_req = "new_sha"; - let content = format!("{{{{Based on the requirements file with SHA256: {}}}}}\n- [ ] Task 1", sha_file); + let content = format!( + "{{{{Based on the requirements file with SHA256: {}}}}}\n- [ ] Task 1", + sha_file + ); std::fs::write(&todo_path, content).unwrap(); let mut config = Config::default(); @@ -173,7 +193,10 @@ async fn test_todo_staleness_check_disabled() { let sha_file = "old_sha"; let sha_req = "new_sha"; - let content = format!("{{{{Based on the requirements file with SHA256: {}}}}}\n- [ ] Task 1", sha_file); + let content = format!( + "{{{{Based on the requirements file with SHA256: {}}}}}\n- [ ] Task 1", + sha_file + ); std::fs::write(&todo_path, content).unwrap(); let mut config = Config::default(); diff --git a/crates/g3-ensembles/src/flock.rs b/crates/g3-ensembles/src/flock.rs index d37ac83..8b7ad0c 100644 --- a/crates/g3-ensembles/src/flock.rs +++ b/crates/g3-ensembles/src/flock.rs @@ -17,19 +17,19 @@ use crate::status::{FlockStatus, SegmentState, SegmentStatus}; pub struct FlockConfig { /// Project directory (must be a git repo with flock-requirements.md) pub project_dir: PathBuf, - + /// Flock workspace directory where segments will be created pub flock_workspace: PathBuf, - + /// Number of segments to partition work into pub num_segments: usize, - + /// Maximum turns per segment (for autonomous mode) pub max_turns: usize, - + /// G3 configuration to use for agents pub g3_config: Config, - + /// Path to g3 binary (defaults to current executable) pub g3_binary: Option, } @@ -43,14 +43,20 @@ impl FlockConfig { ) -> Result { // Validate project directory if !project_dir.exists() { - anyhow::bail!("Project directory does not exist: {}", project_dir.display()); + anyhow::bail!( + "Project directory does not exist: 
{}", + project_dir.display() + ); } - + // Check if it's a git repo if !project_dir.join(".git").exists() { - anyhow::bail!("Project directory must be a git repository: {}", project_dir.display()); + anyhow::bail!( + "Project directory must be a git repository: {}", + project_dir.display() + ); } - + // Check for flock-requirements.md let requirements_path = project_dir.join("flock-requirements.md"); if !requirements_path.exists() { @@ -59,10 +65,10 @@ impl FlockConfig { project_dir.display() ); } - + // Load default config let g3_config = Config::load(None)?; - + Ok(Self { project_dir, flock_workspace, @@ -72,19 +78,19 @@ impl FlockConfig { g3_binary: None, }) } - + /// Set maximum turns per segment pub fn with_max_turns(mut self, max_turns: usize) -> Self { self.max_turns = max_turns; self } - + /// Set custom g3 binary path pub fn with_g3_binary(mut self, binary: PathBuf) -> Self { self.g3_binary = Some(binary); self } - + /// Set custom g3 config pub fn with_config(mut self, config: Config) -> Self { self.g3_config = config; @@ -103,58 +109,67 @@ impl FlockMode { /// Create a new flock mode instance pub fn new(config: FlockConfig) -> Result { let session_id = Uuid::new_v4().to_string(); - + let status = FlockStatus::new( session_id.clone(), config.project_dir.clone(), config.flock_workspace.clone(), config.num_segments, ); - + Ok(Self { config, status, session_id, }) } - + /// Run flock mode pub async fn run(&mut self) -> Result<()> { - info!("Starting flock mode with {} segments", self.config.num_segments); - + info!( + "Starting flock mode with {} segments", + self.config.num_segments + ); + // Step 1: Partition requirements - println!("\n🧠 Step 1: Partitioning requirements into {} segments...", self.config.num_segments); + println!( + "\n🧠 Step 1: Partitioning requirements into {} segments...", + self.config.num_segments + ); let partitions = self.partition_requirements().await?; - + // Step 2: Create segment workspaces println!("\n📁 Step 2: Creating segment 
workspaces..."); self.create_segment_workspaces(&partitions).await?; - + // Step 3: Run segments in parallel - println!("\n🚀 Step 3: Running {} segments in parallel...", self.config.num_segments); + println!( + "\n🚀 Step 3: Running {} segments in parallel...", + self.config.num_segments + ); self.run_segments_parallel().await?; - + // Step 4: Generate final report println!("\n📊 Step 4: Generating final report..."); self.status.completed_at = Some(Utc::now()); self.save_status()?; - + let report = self.status.generate_report(); println!("{}", report); - + Ok(()) } - + /// Partition requirements using an AI agent async fn partition_requirements(&mut self) -> Result> { let requirements_path = self.config.project_dir.join("flock-requirements.md"); let requirements_content = std::fs::read_to_string(&requirements_path) .context("Failed to read flock-requirements.md")?; - + // Create a temporary workspace for the partitioning agent let partition_workspace = self.config.flock_workspace.join("_partition"); std::fs::create_dir_all(&partition_workspace)?; - + // Create the partitioning prompt let partition_prompt = format!( "You are a software architect tasked with partitioning project requirements into {} logical, \ @@ -198,10 +213,10 @@ impl FlockMode { requirements_content, self.config.num_segments ); - + // Get g3 binary path let g3_binary = self.get_g3_binary()?; - + // Run g3 in single-shot mode to partition requirements println!(" Analyzing requirements and creating partitions..."); let output = Command::new(&g3_binary) @@ -212,23 +227,23 @@ impl FlockMode { .output() .await .context("Failed to run g3 for partitioning")?; - + if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); anyhow::bail!("Partitioning agent failed: {}", stderr); } - + let stdout = String::from_utf8_lossy(&output.stdout); debug!("Partitioning agent output: {}", stdout); - + // Extract JSON from the output let partitions_json = Self::extract_json_from_output(&stdout) 
.context("Failed to extract partition JSON from agent output")?; - + // Parse the partitions - let partitions: Vec = serde_json::from_str(&partitions_json) - .context("Failed to parse partition JSON")?; - + let partitions: Vec = + serde_json::from_str(&partitions_json).context("Failed to parse partition JSON")?; + if partitions.len() != self.config.num_segments { warn!( "Expected {} partitions but got {}. Adjusting...", @@ -236,14 +251,12 @@ impl FlockMode { partitions.len() ); } - + // Extract requirements text from each partition let mut partition_texts = Vec::new(); for (i, partition) in partitions.iter().enumerate() { let default_name = format!("module-{}", i + 1); - let module_name = partition["module_name"] - .as_str() - .unwrap_or(&default_name); + let module_name = partition["module_name"].as_str().unwrap_or(&default_name); let requirements = partition["requirements"] .as_str() .context("Missing requirements field in partition")?; @@ -256,7 +269,7 @@ impl FlockMode { .join(", ") }) .unwrap_or_default(); - + let partition_text = format!( "# Module: {}\n\n## Dependencies\n{}\n\n## Requirements\n\n{}", module_name, @@ -267,69 +280,80 @@ impl FlockMode { }, requirements ); - + partition_texts.push(partition_text); println!(" ✓ Created partition {}: {}", i + 1, module_name); } - + Ok(partition_texts) } - + /// Extract JSON from agent output (looks for JSON array in output) fn extract_json_from_output(output: &str) -> Result { // Try to find all occurrences of partition markers and extract valid JSON const MARKERS: &[&str] = &["{{PARTITION JSON}}", "{PARTITION JSON}"]; - + let mut candidates = Vec::new(); - + // Find all marker occurrences for &marker in MARKERS { let mut search_start = 0; while let Some(marker_index) = output[search_start..].find(marker) { let absolute_index = search_start + marker_index; let after_marker = &output[absolute_index + marker.len()..]; - + // Try to find a code fence after this marker if let Some(fence_start) = 
after_marker.find("```") { let after_fence = &after_marker[fence_start + 3..]; - + // Skip optional "json" language identifier let content_start = after_fence .strip_prefix("json") .unwrap_or(after_fence) .trim_start_matches(|c: char| c.is_whitespace()); - + // Find closing fence if let Some(fence_end) = content_start.find("```") { let json_candidate = content_start[..fence_end].trim(); candidates.push(json_candidate.to_string()); } } - + // Move search position forward search_start = absolute_index + marker.len(); } } - + if candidates.is_empty() { - anyhow::bail!("Could not find any partition JSON markers with code fences in agent output"); + anyhow::bail!( + "Could not find any partition JSON markers with code fences in agent output" + ); } - + // Try to parse each candidate and return the first valid JSON let mut last_error = None; for (i, candidate) in candidates.iter().enumerate() { match serde_json::from_str::(candidate) { Ok(_) => { - debug!("Successfully parsed JSON from candidate {} of {}", i + 1, candidates.len()); + debug!( + "Successfully parsed JSON from candidate {} of {}", + i + 1, + candidates.len() + ); return Ok(candidate.clone()); } Err(e) => { - debug!("Failed to parse candidate {} of {}: {}", i + 1, candidates.len(), e); + debug!( + "Failed to parse candidate {} of {}: {}", + i + 1, + candidates.len(), + e + ); last_error = Some(e); } } } - + // If we get here, none of the candidates were valid JSON if let Some(err) = last_error { anyhow::bail!( @@ -338,37 +362,46 @@ impl FlockMode { err ); } - + anyhow::bail!("No valid JSON found in output") } - + /// Create segment workspaces by copying project directory async fn create_segment_workspaces(&mut self, partitions: &[String]) -> Result<()> { // Ensure flock workspace exists std::fs::create_dir_all(&self.config.flock_workspace)?; - + for (i, partition) in partitions.iter().enumerate() { let segment_id = i + 1; - let segment_dir = self.config.flock_workspace.join(format!("segment-{}", 
segment_id)); - + let segment_dir = self + .config + .flock_workspace + .join(format!("segment-{}", segment_id)); + println!(" Creating segment {} workspace...", segment_id); - + // Copy project directory to segment directory self.copy_git_repo(&self.config.project_dir, &segment_dir) .await .context(format!("Failed to copy project to segment {}", segment_id))?; - + // Write segment-requirements.md let requirements_path = segment_dir.join("segment-requirements.md"); - std::fs::write(&requirements_path, partition) - .context(format!("Failed to write requirements for segment {}", segment_id))?; - - println!(" ✓ Segment {} workspace ready at {}", segment_id, segment_dir.display()); + std::fs::write(&requirements_path, partition).context(format!( + "Failed to write requirements for segment {}", + segment_id + ))?; + + println!( + " ✓ Segment {} workspace ready at {}", + segment_id, + segment_dir.display() + ); } - + Ok(()) } - + /// Copy a git repository to a new location async fn copy_git_repo(&self, source: &Path, dest: &Path) -> Result<()> { // Use git clone for efficient copying @@ -379,26 +412,29 @@ impl FlockMode { .output() .await .context("Failed to run git clone")?; - + if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); anyhow::bail!("Git clone failed: {}", stderr); } - + Ok(()) } - + /// Run all segments in parallel async fn run_segments_parallel(&mut self) -> Result<()> { let mut handles = Vec::new(); - + for segment_id in 1..=self.config.num_segments { - let segment_dir = self.config.flock_workspace.join(format!("segment-{}", segment_id)); + let segment_dir = self + .config + .flock_workspace + .join(format!("segment-{}", segment_id)); let max_turns = self.config.max_turns; let g3_binary = self.get_g3_binary()?; let status_file = self.get_status_file_path(); let session_id = self.session_id.clone(); - + // Initialize segment status let segment_status = SegmentStatus { segment_id, @@ -414,10 +450,10 @@ impl FlockMode { 
last_message: Some("Starting...".to_string()), error_message: None, }; - + self.status.update_segment(segment_id, segment_status); self.save_status()?; - + // Spawn a task for this segment let handle = tokio::spawn(async move { run_segment( @@ -430,10 +466,10 @@ impl FlockMode { ) .await }); - + handles.push((segment_id, handle)); } - + // Wait for all segments to complete for (segment_id, handle) in handles { match handle.await { @@ -444,10 +480,17 @@ impl FlockMode { } Ok(Err(e)) => { error!("Segment {} failed: {}", segment_id, e); - let mut segment_status = self.status.segments.get(&segment_id).cloned() + let mut segment_status = self + .status + .segments + .get(&segment_id) + .cloned() .unwrap_or_else(|| SegmentStatus { segment_id, - workspace: self.config.flock_workspace.join(format!("segment-{}", segment_id)), + workspace: self + .config + .flock_workspace + .join(format!("segment-{}", segment_id)), state: SegmentState::Failed, started_at: Utc::now(), completed_at: Some(Utc::now()), @@ -468,10 +511,17 @@ impl FlockMode { } Err(e) => { error!("Segment {} task panicked: {}", segment_id, e); - let mut segment_status = self.status.segments.get(&segment_id).cloned() + let mut segment_status = self + .status + .segments + .get(&segment_id) + .cloned() .unwrap_or_else(|| SegmentStatus { segment_id, - workspace: self.config.flock_workspace.join(format!("segment-{}", segment_id)), + workspace: self + .config + .flock_workspace + .join(format!("segment-{}", segment_id)), state: SegmentState::Failed, started_at: Utc::now(), completed_at: Some(Utc::now()), @@ -492,10 +542,10 @@ impl FlockMode { } } } - + Ok(()) } - + /// Get the g3 binary path fn get_g3_binary(&self) -> Result { if let Some(ref binary) = self.config.g3_binary { @@ -505,12 +555,12 @@ impl FlockMode { std::env::current_exe().context("Failed to get current executable path") } } - + /// Get the status file path fn get_status_file_path(&self) -> PathBuf { self.config.flock_workspace.join("flock-status.json") 
} - + /// Save current status to file fn save_status(&self) -> Result<()> { let status_file = self.get_status_file_path(); @@ -527,8 +577,12 @@ async fn run_segment( status_file: PathBuf, session_id: String, ) -> Result { - info!("Starting segment {} in {}", segment_id, segment_dir.display()); - + info!( + "Starting segment {} in {}", + segment_id, + segment_dir.display() + ); + let mut segment_status = SegmentStatus { segment_id, workspace: segment_dir.clone(), @@ -543,7 +597,7 @@ async fn run_segment( last_message: Some("Starting autonomous mode...".to_string()), error_message: None, }; - + // Run g3 in autonomous mode with segment-requirements.md let mut child = Command::new(&g3_binary) .arg("--workspace") @@ -552,23 +606,25 @@ async fn run_segment( .arg("--max-turns") .arg(max_turns.to_string()) .arg("--requirements") - .arg(std::fs::read_to_string(segment_dir.join("segment-requirements.md"))?) + .arg(std::fs::read_to_string( + segment_dir.join("segment-requirements.md"), + )?) .arg("--quiet") // Disable session logging for workers .stdout(Stdio::piped()) .stderr(Stdio::piped()) .spawn() .context("Failed to spawn g3 process")?; - + // Stream output and update status let stdout = child.stdout.take().context("Failed to get stdout")?; let stderr = child.stderr.take().context("Failed to get stderr")?; - + let stdout_reader = BufReader::new(stdout); let stderr_reader = BufReader::new(stderr); - + let mut stdout_lines = stdout_reader.lines(); let mut stderr_lines = stderr_reader.lines(); - + // Read output and update status loop { tokio::select! 
{ @@ -576,7 +632,7 @@ async fn run_segment( match line { Ok(Some(line)) => { println!("[Segment {}] {}", segment_id, line); - + // Parse output for status updates if line.contains("TURN") { // Extract turn number if possible @@ -586,7 +642,7 @@ async fn run_segment( } } } - + segment_status.last_message = Some(line); update_status_file(&status_file, &session_id, segment_status.clone())?; } @@ -613,12 +669,15 @@ async fn run_segment( } } } - + // Wait for process to complete - let status = child.wait().await.context("Failed to wait for g3 process")?; - + let status = child + .wait() + .await + .context("Failed to wait for g3 process")?; + segment_status.completed_at = Some(Utc::now()); - + if status.success() { segment_status.state = SegmentState::Completed; segment_status.last_message = Some("Completed successfully".to_string()); @@ -627,7 +686,7 @@ async fn run_segment( segment_status.error_message = Some(format!("Process exited with status: {}", status)); segment_status.errors += 1; } - + // Try to extract metrics from session log if available let log_dir = segment_dir.join("logs"); if log_dir.exists() { @@ -636,7 +695,9 @@ async fn run_segment( let path = entry.path(); if path.extension().and_then(|s| s.to_str()) == Some("json") { if let Ok(log_content) = std::fs::read_to_string(&path) { - if let Ok(log_json) = serde_json::from_str::(&log_content) { + if let Ok(log_json) = + serde_json::from_str::(&log_content) + { // Extract token usage if let Some(context) = log_json.get("context_window") { if let Some(cumulative) = context.get("cumulative_tokens") { @@ -645,7 +706,7 @@ async fn run_segment( } } } - + // Count tool calls from conversation history if let Some(context) = log_json.get("context_window") { if let Some(history) = context.get("conversation_history") { @@ -653,8 +714,7 @@ async fn run_segment( let tool_call_count = messages .iter() .filter(|msg| { - msg.get("role") - .and_then(|r| r.as_str()) + msg.get("role").and_then(|r| r.as_str()) == Some("tool") 
}) .count(); @@ -668,9 +728,9 @@ async fn run_segment( } } } - + update_status_file(&status_file, &session_id, segment_status.clone())?; - + Ok(segment_status) } @@ -685,24 +745,19 @@ fn update_status_file( FlockStatus::load_from_file(status_file)? } else { // This shouldn't happen, but handle it gracefully - FlockStatus::new( - session_id.to_string(), - PathBuf::new(), - PathBuf::new(), - 0, - ) + FlockStatus::new(session_id.to_string(), PathBuf::new(), PathBuf::new(), 0) }; - + flock_status.update_segment(segment_status.segment_id, segment_status); flock_status.save_to_file(status_file)?; - + Ok(()) } #[cfg(test)] mod tests { use super::FlockMode; - + #[test] fn extract_json_from_output_handles_partition_marker_and_fences() { const NOISY_PREFIX: &str = concat!( @@ -730,7 +785,7 @@ mod tests { "## Module Partitioning\n", "\n" ); - + let expected_json = r#"[ { "module_name": "message-protocol", @@ -743,18 +798,18 @@ mod tests { "dependencies": ["message-protocol"] } ]"#; - + let mut output = String::from(NOISY_PREFIX); output.push_str("{{PARTITION JSON}}\n```json\n"); output.push_str(expected_json); output.push_str("```"); - + let extracted = FlockMode::extract_json_from_output(&output) .expect("should extract JSON between markers"); - + assert_eq!(extracted, expected_json); } - + #[test] fn extract_json_from_output_handles_multiple_markers_and_invalid_json() { // This is the actual output from the LLM that was failing @@ -891,19 +946,19 @@ The requirements have been partitioned into two logical, largely non-overlapping 4. **Maintainability**: Changes to logging/monitoring don't affect core message handling 5. **Scalability**: Observability could be extracted to a separate service for distributed systems 6. 
**Dependency Direction**: Clean one-way dependency (observability → message-protocol) prevents circular dependencies"#; - + let extracted = FlockMode::extract_json_from_output(output) .expect("should extract valid JSON from output with multiple markers"); - + // Should be able to parse as JSON - let parsed: serde_json::Value = serde_json::from_str(&extracted) - .expect("extracted content should be valid JSON"); - + let parsed: serde_json::Value = + serde_json::from_str(&extracted).expect("extracted content should be valid JSON"); + // Verify it's an array with 2 elements assert!(parsed.is_array()); let arr = parsed.as_array().unwrap(); assert_eq!(arr.len(), 2); - + // Verify the structure assert_eq!(arr[0]["module_name"], "message-protocol"); assert_eq!(arr[1]["module_name"], "observability"); diff --git a/crates/g3-ensembles/src/status.rs b/crates/g3-ensembles/src/status.rs index 54d3529..474dcc5 100644 --- a/crates/g3-ensembles/src/status.rs +++ b/crates/g3-ensembles/src/status.rs @@ -10,37 +10,37 @@ use std::path::PathBuf; pub struct SegmentStatus { /// Segment number pub segment_id: usize, - + /// Segment workspace directory pub workspace: PathBuf, - + /// Current state of the segment pub state: SegmentState, - + /// Start time pub started_at: DateTime, - + /// Completion time (if finished) pub completed_at: Option>, - + /// Total tokens used pub tokens_used: u64, - + /// Number of tool calls made pub tool_calls: u64, - + /// Number of errors encountered pub errors: u64, - + /// Current turn number (for autonomous mode) pub current_turn: usize, - + /// Maximum turns allowed pub max_turns: usize, - + /// Last status message pub last_message: Option, - + /// Error message (if failed) pub error_message: Option, } @@ -50,16 +50,16 @@ pub struct SegmentStatus { pub enum SegmentState { /// Waiting to start Pending, - + /// Currently running Running, - + /// Completed successfully Completed, - + /// Failed with error Failed, - + /// Cancelled by user Cancelled, } @@ 
-81,31 +81,31 @@ impl std::fmt::Display for SegmentState { pub struct FlockStatus { /// Flock session ID pub session_id: String, - + /// Project directory pub project_dir: PathBuf, - + /// Flock workspace directory pub flock_workspace: PathBuf, - + /// Number of segments pub num_segments: usize, - + /// Start time pub started_at: DateTime, - + /// Completion time (if finished) pub completed_at: Option>, - + /// Status of each segment pub segments: HashMap, - + /// Total tokens used across all segments pub total_tokens: u64, - + /// Total tool calls across all segments pub total_tool_calls: u64, - + /// Total errors across all segments pub total_errors: u64, } @@ -131,20 +131,20 @@ impl FlockStatus { total_errors: 0, } } - + /// Update segment status pub fn update_segment(&mut self, segment_id: usize, status: SegmentStatus) { self.segments.insert(segment_id, status); self.recalculate_totals(); } - + /// Recalculate total metrics fn recalculate_totals(&mut self) { self.total_tokens = self.segments.values().map(|s| s.tokens_used).sum(); self.total_tool_calls = self.segments.values().map(|s| s.tool_calls).sum(); self.total_errors = self.segments.values().map(|s| s.errors).sum(); } - + /// Check if all segments are complete pub fn is_complete(&self) -> bool { self.segments.len() == self.num_segments @@ -155,86 +155,116 @@ impl FlockStatus { ) }) } - + /// Get count of segments by state pub fn count_by_state(&self, state: SegmentState) -> usize { self.segments.values().filter(|s| s.state == state).count() } - + /// Save status to file pub fn save_to_file(&self, path: &PathBuf) -> anyhow::Result<()> { let json = serde_json::to_string_pretty(self)?; std::fs::write(path, json)?; Ok(()) } - + /// Load status from file pub fn load_from_file(path: &PathBuf) -> anyhow::Result { let json = std::fs::read_to_string(path)?; let status = serde_json::from_str(&json)?; Ok(status) } - + /// Generate a summary report pub fn generate_report(&self) -> String { let mut report = 
String::new(); - + report.push_str(&format!("\n{}", "=".repeat(80))); report.push_str(&format!("\n📊 FLOCK MODE SESSION REPORT")); report.push_str(&format!("\n{}", "=".repeat(80))); - + report.push_str(&format!("\n\n🆔 Session ID: {}", self.session_id)); report.push_str(&format!("\n📁 Project: {}", self.project_dir.display())); - report.push_str(&format!("\n🗂️ Workspace: {}", self.flock_workspace.display())); + report.push_str(&format!( + "\n🗂️ Workspace: {}", + self.flock_workspace.display() + )); report.push_str(&format!("\n🔢 Segments: {}", self.num_segments)); - + let duration = if let Some(completed) = self.completed_at { completed.signed_duration_since(self.started_at) } else { Utc::now().signed_duration_since(self.started_at) }; - - report.push_str(&format!("\n⏱️ Duration: {:.2}s", duration.num_milliseconds() as f64 / 1000.0)); - + + report.push_str(&format!( + "\n⏱️ Duration: {:.2}s", + duration.num_milliseconds() as f64 / 1000.0 + )); + // Segment status summary report.push_str(&format!("\n\n📈 Segment Status:")); - report.push_str(&format!("\n • Completed: {}", self.count_by_state(SegmentState::Completed))); - report.push_str(&format!("\n • Running: {}", self.count_by_state(SegmentState::Running))); - report.push_str(&format!("\n • Failed: {}", self.count_by_state(SegmentState::Failed))); - report.push_str(&format!("\n • Pending: {}", self.count_by_state(SegmentState::Pending))); - report.push_str(&format!("\n • Cancelled: {}", self.count_by_state(SegmentState::Cancelled))); - + report.push_str(&format!( + "\n • Completed: {}", + self.count_by_state(SegmentState::Completed) + )); + report.push_str(&format!( + "\n • Running: {}", + self.count_by_state(SegmentState::Running) + )); + report.push_str(&format!( + "\n • Failed: {}", + self.count_by_state(SegmentState::Failed) + )); + report.push_str(&format!( + "\n • Pending: {}", + self.count_by_state(SegmentState::Pending) + )); + report.push_str(&format!( + "\n • Cancelled: {}", + 
self.count_by_state(SegmentState::Cancelled) + )); + // Metrics report.push_str(&format!("\n\n📊 Aggregate Metrics:")); report.push_str(&format!("\n • Total Tokens: {}", self.total_tokens)); - report.push_str(&format!("\n • Total Tool Calls: {}", self.total_tool_calls)); + report.push_str(&format!( + "\n • Total Tool Calls: {}", + self.total_tool_calls + )); report.push_str(&format!("\n • Total Errors: {}", self.total_errors)); - + // Per-segment details report.push_str(&format!("\n\n🔍 Segment Details:")); let mut segments: Vec<_> = self.segments.iter().collect(); segments.sort_by_key(|(id, _)| *id); - + for (id, segment) in segments { report.push_str(&format!("\n\n Segment {}:", id)); report.push_str(&format!("\n Status: {}", segment.state)); - report.push_str(&format!("\n Workspace: {}", segment.workspace.display())); + report.push_str(&format!( + "\n Workspace: {}", + segment.workspace.display() + )); report.push_str(&format!("\n Tokens: {}", segment.tokens_used)); report.push_str(&format!("\n Tool Calls: {}", segment.tool_calls)); report.push_str(&format!("\n Errors: {}", segment.errors)); - report.push_str(&format!("\n Turn: {}/{}", segment.current_turn, segment.max_turns)); - + report.push_str(&format!( + "\n Turn: {}/{}", + segment.current_turn, segment.max_turns + )); + if let Some(ref msg) = segment.last_message { report.push_str(&format!("\n Last Message: {}", msg)); } - + if let Some(ref err) = segment.error_message { report.push_str(&format!("\n Error: {}", err)); } } - + report.push_str(&format!("\n\n{}", "=".repeat(80))); - + report } } diff --git a/crates/g3-ensembles/src/tests.rs b/crates/g3-ensembles/src/tests.rs index 7907757..ec89b96 100644 --- a/crates/g3-ensembles/src/tests.rs +++ b/crates/g3-ensembles/src/tests.rs @@ -283,8 +283,7 @@ mod tests { assert!(json.contains("Completed")); // Deserialize back - let deserialized: FlockStatus = - serde_json::from_str(&json).expect("Failed to deserialize"); + let deserialized: FlockStatus = 
serde_json::from_str(&json).expect("Failed to deserialize"); assert_eq!(deserialized.session_id, "test-session"); assert_eq!(deserialized.segments.len(), 1); assert_eq!(deserialized.total_tokens, 1000); diff --git a/crates/g3-ensembles/tests/integration_tests.rs b/crates/g3-ensembles/tests/integration_tests.rs index 1a7a895..c47a448 100644 --- a/crates/g3-ensembles/tests/integration_tests.rs +++ b/crates/g3-ensembles/tests/integration_tests.rs @@ -71,7 +71,7 @@ fn create_test_project(name: &str) -> TempDir { } #[test] - fn test_flock_config_validation() { +fn test_flock_config_validation() { let temp_dir = TempDir::new().unwrap(); let project_path = temp_dir.path().to_path_buf(); let workspace_path = temp_dir.path().join("workspace"); @@ -213,8 +213,7 @@ fn test_multiple_segment_clones() { assert!(segment2.exists()); // Modify segment 1 - fs::write(segment1.join("test.txt"), "segment 1") - .expect("Failed to write to segment 1"); + fs::write(segment1.join("test.txt"), "segment 1").expect("Failed to write to segment 1"); // Verify segment 2 is unaffected assert!(!segment2.join("test.txt").exists()); @@ -236,8 +235,11 @@ fn test_segment_requirements_creation() { // Create segment-requirements.md (what flock mode does) let segment_requirements = "# Module A\n\nImplement module A functionality\n"; - fs::write(segment_dir.join("segment-requirements.md"), segment_requirements) - .expect("Failed to write segment requirements"); + fs::write( + segment_dir.join("segment-requirements.md"), + segment_requirements, + ) + .expect("Failed to write segment requirements"); // Verify it was created assert!(segment_dir.join("segment-requirements.md").exists()); diff --git a/crates/g3-execution/examples/setup_coverage_tools.rs b/crates/g3-execution/examples/setup_coverage_tools.rs index 9495800..2cc5b81 100644 --- a/crates/g3-execution/examples/setup_coverage_tools.rs +++ b/crates/g3-execution/examples/setup_coverage_tools.rs @@ -3,7 +3,7 @@ use 
g3_execution::ensure_coverage_tools_installed; fn main() -> anyhow::Result<()> { // Ensure coverage tools are installed let already_installed = ensure_coverage_tools_installed()?; - + if already_installed { println!("All coverage tools are already installed!"); } else { diff --git a/crates/g3-execution/src/lib.rs b/crates/g3-execution/src/lib.rs index 0994831..2629932 100644 --- a/crates/g3-execution/src/lib.rs +++ b/crates/g3-execution/src/lib.rs @@ -1,9 +1,9 @@ use anyhow::Result; use regex::Regex; +use std::io::Write; use std::process::Command; use tempfile::NamedTempFile; -use std::io::Write; -use tracing::{info, debug, error}; +use tracing::{debug, error, info}; /// Expand tilde (~) in a path to the user's home directory fn expand_tilde(path: &str) -> String { @@ -32,40 +32,52 @@ impl CodeExecutor { pub fn new() -> Self { Self {} } - + /// Extract code blocks from LLM response and execute them pub async fn execute_from_response(&self, response: &str) -> Result { - self.execute_from_response_with_options(response, true).await + self.execute_from_response_with_options(response, true) + .await } - + /// Extract code blocks from LLM response and execute them with UI options - pub async fn execute_from_response_with_options(&self, response: &str, show_code: bool) -> Result { - debug!("CodeExecutor received response ({} chars): {}", response.len(), response); + pub async fn execute_from_response_with_options( + &self, + response: &str, + show_code: bool, + ) -> Result { + debug!( + "CodeExecutor received response ({} chars): {}", + response.len(), + response + ); let code_blocks = self.extract_code_blocks(response)?; - + if code_blocks.is_empty() { if show_code { - return Ok(format!("⚠️ No executable code blocks found in response.\n\n{}", response)); + return Ok(format!( + "⚠️ No executable code blocks found in response.\n\n{}", + response + )); } else { return Ok("⚠️ No executable code found.".to_string()); } } - + let mut results = Vec::new(); - + // Only show the 
original LLM response if show_code is true if show_code { results.push(response.to_string()); results.push("\n🚀 Executing code...\n".to_string()); } - + for (language, code) in code_blocks { info!("Executing {} code", language); - + if show_code { results.push(format!("📋 Running {} code:", language)); } - + match self.execute_code(&language, &code).await { Ok(result) => { if result.success { @@ -89,8 +101,8 @@ impl CodeExecutor { } } } - - // If no results were added (e.g., successful execution with no output), + + // If no results were added (e.g., successful execution with no output), // return a simple success message when show_code is false if results.is_empty() && !show_code { Ok("✅ Done".to_string()) @@ -98,51 +110,58 @@ impl CodeExecutor { Ok(results.join("\n")) } } - + /// Extract code blocks from markdown-formatted text fn extract_code_blocks(&self, text: &str) -> Result> { let mut blocks = Vec::new(); - + debug!("Extracting code blocks from text: {}", text); - + // Pattern 1: Standard markdown format ```language\ncode``` let markdown_re = Regex::new(r"(?s)```(\w+)?\n(.*?)```")?; for cap in markdown_re.captures_iter(text) { - let language = cap.get(1) + let language = cap + .get(1) .map(|m| m.as_str().to_lowercase()) .unwrap_or_else(|| "bash".to_string()); // Default to bash let code = cap.get(2).map(|m| m.as_str()).unwrap_or("").trim(); - - debug!("Found markdown code block - language: '{}', code: '{}'", language, code); - + + debug!( + "Found markdown code block - language: '{}', code: '{}'", + language, code + ); + if !code.is_empty() { blocks.push((language, code.to_string())); } } - + // Pattern 2: Bracket format [Language]code[/Language] let bracket_re = Regex::new(r"(?s)\[(\w+)\]\s*(.*?)\s*\[/(\w+)\]")?; for cap in bracket_re.captures_iter(text) { let open_lang = cap.get(1).map(|m| m.as_str()).unwrap_or(""); let close_lang = cap.get(3).map(|m| m.as_str()).unwrap_or(""); - + // Only match if opening and closing tags are the same (case insensitive) if 
open_lang.to_lowercase() == close_lang.to_lowercase() { let language = open_lang.to_lowercase(); let code = cap.get(2).map(|m| m.as_str()).unwrap_or("").trim(); - - debug!("Found bracket code block - language: '{}', code: '{}'", language, code); - + + debug!( + "Found bracket code block - language: '{}', code: '{}'", + language, code + ); + if !code.is_empty() { blocks.push((language, code.to_string())); } } } - + debug!("Total code blocks found: {}", blocks.len()); Ok(blocks) } - + /// Execute code in the specified language pub async fn execute_code(&self, language: &str, code: &str) -> Result { match language.to_lowercase().as_str() { @@ -156,17 +175,15 @@ impl CodeExecutor { } } } - + /// Execute Python code async fn execute_python(&self, code: &str) -> Result { let mut temp_file = NamedTempFile::new()?; temp_file.write_all(code.as_bytes())?; let temp_path = temp_file.path(); - - let output = Command::new("python3") - .arg(temp_path) - .output()?; - + + let output = Command::new("python3").arg(temp_path).output()?; + Ok(ExecutionResult { stdout: String::from_utf8_lossy(&output.stdout).to_string(), stderr: String::from_utf8_lossy(&output.stderr).to_string(), @@ -174,15 +191,15 @@ impl CodeExecutor { success: output.status.success(), }) } - + /// Execute Bash code async fn execute_bash(&self, code: &str) -> Result { // Check if this is a detached/daemon command that should run independently - let is_detached = code.trim_start().starts_with("setsid ") + let is_detached = code.trim_start().starts_with("setsid ") || code.trim_start().starts_with("nohup ") || code.contains(" disown") || (code.contains(" &") && (code.contains("nohup") || code.contains("setsid"))); - + if is_detached { // For detached commands, just spawn and return immediately use std::process::Stdio; @@ -193,7 +210,7 @@ impl CodeExecutor { .stdout(Stdio::null()) .stderr(Stdio::null()) .spawn()?; - + return Ok(ExecutionResult { stdout: "✅ Command launched in background (detached process)".to_string(), 
stderr: String::new(), @@ -201,12 +218,9 @@ impl CodeExecutor { success: true, }); } - - let output = Command::new("bash") - .arg("-c") - .arg(code) - .output()?; - + + let output = Command::new("bash").arg("-c").arg(code).output()?; + Ok(ExecutionResult { stdout: String::from_utf8_lossy(&output.stdout).to_string(), stderr: String::from_utf8_lossy(&output.stderr).to_string(), @@ -214,17 +228,15 @@ impl CodeExecutor { success: output.status.success(), }) } - + /// Execute JavaScript code (requires Node.js) async fn execute_javascript(&self, code: &str) -> Result { let mut temp_file = NamedTempFile::new()?; temp_file.write_all(code.as_bytes())?; let temp_path = temp_file.path(); - - let output = Command::new("node") - .arg(temp_path) - .output()?; - + + let output = Command::new("node").arg(temp_path).output()?; + Ok(ExecutionResult { stdout: String::from_utf8_lossy(&output.stdout).to_string(), stderr: String::from_utf8_lossy(&output.stderr).to_string(), @@ -249,57 +261,69 @@ pub trait OutputReceiver: Send + Sync { impl CodeExecutor { /// Execute bash command with streaming output pub async fn execute_bash_streaming( - &self, - code: &str, - receiver: &R + &self, + code: &str, + receiver: &R, ) -> Result { - self.execute_bash_streaming_in_dir(code, receiver, None).await + self.execute_bash_streaming_in_dir(code, receiver, None) + .await } /// Execute bash command with streaming output in a specific directory pub async fn execute_bash_streaming_in_dir( - &self, - code: &str, + &self, + code: &str, receiver: &R, working_dir: Option<&str>, ) -> Result { use std::process::Stdio; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command as TokioCommand; - + // CRITICAL DEBUG: Print to stderr so it's always visible debug!("========== execute_bash_streaming_in_dir START =========="); debug!("Code to execute: {}", code); debug!("Working directory parameter: {:?}", working_dir); - debug!("FULL DIAGNOSTIC: code='{}', working_dir={:?}", code, working_dir); - + 
debug!( + "FULL DIAGNOSTIC: code='{}', working_dir={:?}", + code, working_dir + ); + if let Some(dir) = working_dir { - debug!("Working dir exists check: {}", std::path::Path::new(dir).exists()); - debug!("Working dir is_dir check: {}", std::path::Path::new(dir).is_dir()); + debug!( + "Working dir exists check: {}", + std::path::Path::new(dir).exists() + ); + debug!( + "Working dir is_dir check: {}", + std::path::Path::new(dir).is_dir() + ); } - debug!("Current process working directory: {:?}", std::env::current_dir()); - + debug!( + "Current process working directory: {:?}", + std::env::current_dir() + ); + // Check if this is a detached/daemon command that should run independently // Look for patterns like: setsid, nohup with &, or explicit backgrounding with disown - let is_detached = code.trim_start().starts_with("setsid ") + let is_detached = code.trim_start().starts_with("setsid ") || code.trim_start().starts_with("nohup ") || code.contains(" disown") || (code.contains(" &") && (code.contains("nohup") || code.contains("setsid"))); - + if is_detached { // For detached commands, just spawn and return immediately let mut cmd = TokioCommand::new("bash"); - cmd.arg("-c") - .arg(code); - + cmd.arg("-c").arg(code); + // Set working directory if provided if let Some(dir) = working_dir { let expanded_dir = expand_tilde(dir); cmd.current_dir(&expanded_dir); } - + cmd.spawn()?; - + // Don't wait for the process - it's meant to run independently return Ok(ExecutionResult { stdout: "✅ Command launched in background (detached process)".to_string(), @@ -308,23 +332,29 @@ impl CodeExecutor { success: true, }); } - + let mut cmd = TokioCommand::new("bash"); cmd.arg("-c") .arg(code) .stdout(Stdio::piped()) .stderr(Stdio::piped()); - + // Set working directory if provided if let Some(dir) = working_dir { debug!("Setting current_dir on command to: {}", dir); let expanded_dir = expand_tilde(dir); debug!("Expanded working dir: {}", expanded_dir); - debug!("Expanded dir exists: 
{}", std::path::Path::new(&expanded_dir).exists()); - debug!("Expanded dir is_dir: {}", std::path::Path::new(&expanded_dir).is_dir()); + debug!( + "Expanded dir exists: {}", + std::path::Path::new(&expanded_dir).exists() + ); + debug!( + "Expanded dir is_dir: {}", + std::path::Path::new(&expanded_dir).is_dir() + ); cmd.current_dir(&expanded_dir); } - + debug!("About to spawn command..."); let spawn_result = cmd.spawn(); debug!("Spawn result: {:?}", spawn_result.is_ok()); @@ -336,19 +366,19 @@ impl CodeExecutor { } }; debug!("Command spawned successfully"); - + let stdout = child.stdout.take().unwrap(); let stderr = child.stderr.take().unwrap(); - + let stdout_reader = BufReader::new(stdout); let stderr_reader = BufReader::new(stderr); - + let mut stdout_lines = stdout_reader.lines(); let mut stderr_lines = stderr_reader.lines(); - + let mut stdout_output = Vec::new(); let mut stderr_output = Vec::new(); - + // Read output lines as they come loop { tokio::select! { @@ -380,16 +410,16 @@ impl CodeExecutor { else => break } } - + let status = child.wait().await?; - + let result = ExecutionResult { stdout: stdout_output.join("\n"), stderr: stderr_output.join("\n"), exit_code: status.code().unwrap_or(-1), success: status.success(), }; - + debug!("========== execute_bash_streaming_in_dir END =========="); debug!("Exit code: {}", result.exit_code); debug!("Success: {}", result.success); @@ -408,24 +438,22 @@ pub fn is_llvm_tools_installed() -> Result { let output = Command::new("rustup") .args(&["component", "list", "--installed"]) .output()?; - + let installed = String::from_utf8_lossy(&output.stdout) .lines() .any(|line| line.trim() == "llvm-tools-preview" || line.starts_with("llvm-tools")); - + Ok(installed) } /// Check if cargo-llvm-cov is installed pub fn is_cargo_llvm_cov_installed() -> Result { - let output = Command::new("cargo") - .args(&["--list"]) - .output()?; - + let output = Command::new("cargo").args(&["--list"]).output()?; + let installed = 
String::from_utf8_lossy(&output.stdout) .lines() .any(|line| line.trim().starts_with("llvm-cov")); - + Ok(installed) } @@ -435,12 +463,12 @@ pub fn install_llvm_tools() -> Result<()> { let output = Command::new("rustup") .args(&["component", "add", "llvm-tools-preview"]) .output()?; - + if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); anyhow::bail!("Failed to install llvm-tools-preview: {}", stderr); } - + info!("✅ llvm-tools-preview installed successfully"); Ok(()) } @@ -451,12 +479,12 @@ pub fn install_cargo_llvm_cov() -> Result<()> { let output = Command::new("cargo") .args(&["install", "cargo-llvm-cov"]) .output()?; - + if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); anyhow::bail!("Failed to install cargo-llvm-cov: {}", stderr); } - + info!("✅ cargo-llvm-cov installed successfully"); Ok(()) } @@ -465,7 +493,7 @@ pub fn install_cargo_llvm_cov() -> Result<()> { /// Returns Ok(true) if tools were already installed, Ok(false) if they were installed by this function pub fn ensure_coverage_tools_installed() -> Result { let mut already_installed = true; - + // Check and install llvm-tools-preview if !is_llvm_tools_installed()? { info!("llvm-tools-preview not found, installing..."); @@ -474,7 +502,7 @@ pub fn ensure_coverage_tools_installed() -> Result { } else { info!("✅ llvm-tools-preview is already installed"); } - + // Check and install cargo-llvm-cov if !is_cargo_llvm_cov_installed()? 
{ info!("cargo-llvm-cov not found, installing..."); @@ -483,6 +511,6 @@ pub fn ensure_coverage_tools_installed() -> Result { } else { info!("✅ cargo-llvm-cov is already installed"); } - + Ok(already_installed) } diff --git a/crates/g3-planner/src/code_explore.rs b/crates/g3-planner/src/code_explore.rs index ee8c3da..ad930da 100644 --- a/crates/g3-planner/src/code_explore.rs +++ b/crates/g3-planner/src/code_explore.rs @@ -291,7 +291,10 @@ pub fn explore_kotlin(path: &str) -> String { // Build files report.push_str("--- Build Configuration ---\n"); - let build = run_command("cat build.gradle.kts 2>/dev/null | head -50 || cat build.gradle 2>/dev/null | head -50", path); + let build = run_command( + "cat build.gradle.kts 2>/dev/null | head -50 || cat build.gradle 2>/dev/null | head -50", + path, + ); report.push_str(&build); report.push('\n'); diff --git a/crates/g3-planner/src/lib.rs b/crates/g3-planner/src/lib.rs index 3bf4ac9..8317b64 100644 --- a/crates/g3-planner/src/lib.rs +++ b/crates/g3-planner/src/lib.rs @@ -9,11 +9,11 @@ pub mod prompts; pub use code_explore::explore_codebase; use anyhow::Result; -use g3_providers::{CompletionRequest, LLMProvider, Message, MessageRole}; use chrono::Local; +use g3_providers::{CompletionRequest, LLMProvider, Message, MessageRole}; +use prompts::{DISCOVERY_REQUIREMENTS_PROMPT, DISCOVERY_SYSTEM_PROMPT}; use std::fs::{self, OpenOptions}; use std::io::Write; -use prompts::{DISCOVERY_REQUIREMENTS_PROMPT, DISCOVERY_SYSTEM_PROMPT}; /// Type alias for a status callback function pub type StatusCallback = Box; @@ -94,7 +94,10 @@ pub async fn get_initial_discovery_messages( // Step 5: Extract shell commands from the response let shell_commands = extract_shell_commands(&response.content); - status(&format!("📋 Extracted {} discovery commands", shell_commands.len())); + status(&format!( + "📋 Extracted {} discovery commands", + shell_commands.len() + )); // Write the discovery commands to logs directory 
write_discovery_commands(&shell_commands)?; diff --git a/crates/g3-planner/tests/logging_test.rs b/crates/g3-planner/tests/logging_test.rs index 59ff045..4d42722 100644 --- a/crates/g3-planner/tests/logging_test.rs +++ b/crates/g3-planner/tests/logging_test.rs @@ -7,38 +7,40 @@ use std::path::Path; fn test_log_files_created() { // This test verifies that the logging functions work correctly // by checking that files can be created in the logs directory - + // Clean up any existing test logs let _ = fs::remove_dir_all("logs"); - + // Create logs directory fs::create_dir_all("logs").expect("Failed to create logs directory"); - + // Verify directory exists assert!(Path::new("logs").exists()); assert!(Path::new("logs").is_dir()); - + // Test writing a code report let test_report = "Test codebase report\nLine 2\nLine 3"; let timestamp = chrono::Local::now().format("%Y%m%d_%H%M%S").to_string(); let report_filename = format!("logs/code_report_{}.log", timestamp); - + fs::write(&report_filename, test_report).expect("Failed to write code report"); assert!(Path::new(&report_filename).exists()); - + let content = fs::read_to_string(&report_filename).expect("Failed to read code report"); assert_eq!(content, test_report); - + // Test writing discovery commands let commands_filename = format!("logs/discovery_commands_{}.log", timestamp); - let test_commands = "# Discovery Commands\n# Generated by g3-planner\n\nls -la\ncat README.md\n"; - + let test_commands = + "# Discovery Commands\n# Generated by g3-planner\n\nls -la\ncat README.md\n"; + fs::write(&commands_filename, test_commands).expect("Failed to write discovery commands"); assert!(Path::new(&commands_filename).exists()); - - let content = fs::read_to_string(&commands_filename).expect("Failed to read discovery commands"); + + let content = + fs::read_to_string(&commands_filename).expect("Failed to read discovery commands"); assert_eq!(content, test_commands); - + // Clean up let _ = fs::remove_file(&report_filename); let _ 
= fs::remove_file(&commands_filename); @@ -48,11 +50,11 @@ fn test_log_files_created() { fn test_filename_format() { // Verify the filename format matches the tool_calls log format let timestamp = chrono::Local::now().format("%Y%m%d_%H%M%S").to_string(); - + // Check format: YYYYMMDD_HHMMSS assert_eq!(timestamp.len(), 15); // 8 digits + underscore + 6 digits assert!(timestamp.contains('_')); - + let parts: Vec<&str> = timestamp.split('_').collect(); assert_eq!(parts.len(), 2); assert_eq!(parts[0].len(), 8); // YYYYMMDD diff --git a/crates/g3-providers/src/anthropic.rs b/crates/g3-providers/src/anthropic.rs index f83b855..fc8fcbf 100644 --- a/crates/g3-providers/src/anthropic.rs +++ b/crates/g3-providers/src/anthropic.rs @@ -139,7 +139,7 @@ impl AnthropicProvider { .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?; let model = model.unwrap_or_else(|| "claude-3-5-sonnet-20241022".to_string()); - + debug!("Initialized Anthropic provider with model: {}", model); Ok(Self { @@ -160,11 +160,11 @@ impl AnthropicProvider { .header("x-api-key", &self.api_key) .header("anthropic-version", ANTHROPIC_VERSION) .header("content-type", "application/json"); - + if self.enable_1m_context { builder = builder.header("anthropic-beta", "context-1m-2025-08-07"); } - + if streaming { builder = builder.header("accept", "text/event-stream"); } @@ -188,12 +188,17 @@ impl AnthropicProvider { }; // Extract properties and required fields from the input schema - if let Ok(schema_obj) = serde_json::from_value::>(tool.input_schema.clone()) { + if let Ok(schema_obj) = serde_json::from_value::< + serde_json::Map, + >(tool.input_schema.clone()) + { if let Some(properties) = schema_obj.get("properties") { schema.properties = properties.clone(); } if let Some(required) = schema_obj.get("required") { - if let Ok(required_vec) = serde_json::from_value::>(required.clone()) { + if let Ok(required_vec) = + serde_json::from_value::>(required.clone()) + { schema.required = Some(required_vec); } } 
@@ -208,7 +213,10 @@ impl AnthropicProvider { .collect() } - fn convert_messages(&self, messages: &[Message]) -> Result<(Option, Vec)> { + fn convert_messages( + &self, + messages: &[Message], + ) -> Result<(Option, Vec)> { let mut system_message = None; let mut anthropic_messages = Vec::new(); @@ -225,7 +233,9 @@ impl AnthropicProvider { role: "user".to_string(), content: vec![AnthropicContent::Text { text: message.content.clone(), - cache_control: message.cache_control.as_ref() + cache_control: message + .cache_control + .as_ref() .map(Self::convert_cache_control), }], }); @@ -235,7 +245,9 @@ impl AnthropicProvider { role: "assistant".to_string(), content: vec![AnthropicContent::Text { text: message.content.clone(), - cache_control: message.cache_control.as_ref() + cache_control: message + .cache_control + .as_ref() .map(Self::convert_cache_control), }], }); @@ -257,7 +269,9 @@ impl AnthropicProvider { let (system, anthropic_messages) = self.convert_messages(messages)?; if anthropic_messages.is_empty() { - return Err(anyhow!("At least one user or assistant message is required")); + return Err(anyhow!( + "At least one user or assistant message is required" + )); } // Convert tools if provided @@ -292,13 +306,13 @@ impl AnthropicProvider { let mut accumulated_usage: Option = None; let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences let mut message_stopped = false; // Track if we've received message_stop - + while let Some(chunk_result) = stream.next().await { match chunk_result { Ok(chunk) => { // Append new bytes to our buffer byte_buffer.extend_from_slice(&chunk); - + // Try to convert the entire buffer to UTF-8 let chunk_str = match std::str::from_utf8(&byte_buffer) { Ok(s) => { @@ -312,7 +326,8 @@ impl AnthropicProvider { let valid_up_to = e.valid_up_to(); if valid_up_to > 0 { // We have some valid UTF-8, extract it and keep the rest for next iteration - let valid_bytes = byte_buffer.drain(..valid_up_to).collect::>(); + let valid_bytes = 
+ byte_buffer.drain(..valid_up_to).collect::>(); std::str::from_utf8(&valid_bytes).unwrap().to_string() } else { // No valid UTF-8 at all, skip this chunk and continue @@ -346,7 +361,11 @@ impl AnthropicProvider { content: String::new(), finished: true, usage: accumulated_usage.clone(), - tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) }, + tool_calls: if current_tool_calls.is_empty() { + None + } else { + Some(current_tool_calls.clone()) + }, }; if tx.send(Ok(final_chunk)).await.is_err() { debug!("Receiver dropped, stopping stream"); @@ -358,7 +377,10 @@ impl AnthropicProvider { match serde_json::from_str::(data) { Ok(event) => { - debug!("Parsed event type: {}, event: {:?}", event.event_type, event); + debug!( + "Parsed event type: {}, event: {:?}", + event.event_type, event + ); match event.event_type.as_str() { "message_start" => { // Extract usage data from message_start event @@ -367,19 +389,30 @@ impl AnthropicProvider { accumulated_usage = Some(Usage { prompt_tokens: usage.input_tokens, completion_tokens: usage.output_tokens, - total_tokens: usage.input_tokens + usage.output_tokens, + total_tokens: usage.input_tokens + + usage.output_tokens, }); - debug!("Captured usage from message_start: {:?}", accumulated_usage); + debug!( + "Captured usage from message_start: {:?}", + accumulated_usage + ); } } } "content_block_start" => { - debug!("Received content_block_start event: {:?}", event); + debug!( + "Received content_block_start event: {:?}", + event + ); if let Some(content_block) = event.content_block { match content_block { - AnthropicContent::ToolUse { id, name, input } => { + AnthropicContent::ToolUse { + id, + name, + input, + } => { debug!("Found tool use in content_block_start: id={}, name={}, input={:?}", id, name, input); - + // For native tool calls, create the tool call immediately if we have complete args // If args are empty, we'll wait for partial_json to accumulate them let tool_call = ToolCall { 
@@ -387,9 +420,14 @@ impl AnthropicProvider { tool: name.clone(), args: input.clone(), }; - + // Check if we already have complete arguments - if !input.is_null() && input != serde_json::Value::Object(serde_json::Map::new()) { + if !input.is_null() + && input + != serde_json::Value::Object( + serde_json::Map::new(), + ) + { // We have complete arguments, send the tool call immediately debug!("Tool call has complete args, sending immediately: {:?}", tool_call); let chunk = CompletionChunk { @@ -410,7 +448,10 @@ impl AnthropicProvider { } } _ => { - debug!("Non-tool content block: {:?}", content_block); + debug!( + "Non-tool content block: {:?}", + content_block + ); } } } @@ -418,7 +459,11 @@ impl AnthropicProvider { "content_block_delta" => { if let Some(delta) = event.delta { if let Some(text) = delta.text { - debug!("Sending text chunk of length {}: '{}'", text.len(), text); + debug!( + "Sending text chunk of length {}: '{}'", + text.len(), + text + ); let chunk = CompletionChunk { content: text, finished: false, @@ -432,31 +477,51 @@ impl AnthropicProvider { } // Handle partial JSON for tool calls if let Some(partial_json) = delta.partial_json { - debug!("Received partial JSON: {}", partial_json); + debug!( + "Received partial JSON: {}", + partial_json + ); partial_tool_json.push_str(&partial_json); - debug!("Accumulated tool JSON: {}", partial_tool_json); + debug!( + "Accumulated tool JSON: {}", + partial_tool_json + ); } } } "content_block_stop" => { // Tool call block is complete - now parse the accumulated JSON - if !current_tool_calls.is_empty() && !partial_tool_json.is_empty() { - debug!("Parsing complete tool JSON: {}", partial_tool_json); - + if !current_tool_calls.is_empty() + && !partial_tool_json.is_empty() + { + debug!( + "Parsing complete tool JSON: {}", + partial_tool_json + ); + // Parse the accumulated JSON and update the last tool call - if let Ok(parsed_args) = serde_json::from_str::(&partial_tool_json) { - if let Some(last_tool) = 
current_tool_calls.last_mut() { + if let Ok(parsed_args) = + serde_json::from_str::( + &partial_tool_json, + ) + { + if let Some(last_tool) = + current_tool_calls.last_mut() + { last_tool.args = parsed_args; debug!("Updated tool call with complete args: {:?}", last_tool); } } else { - debug!("Failed to parse accumulated JSON: {}", partial_tool_json); + debug!( + "Failed to parse accumulated JSON: {}", + partial_tool_json + ); } - + // Clear the accumulator partial_tool_json.clear(); } - + // Send the complete tool call if !current_tool_calls.is_empty() { let chunk = CompletionChunk { @@ -478,7 +543,11 @@ impl AnthropicProvider { content: String::new(), finished: true, usage: accumulated_usage.clone(), - tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) }, + tool_calls: if current_tool_calls.is_empty() { + None + } else { + Some(current_tool_calls.clone()) + }, }; if tx.send(Ok(final_chunk)).await.is_err() { debug!("Receiver dropped, stopping stream"); @@ -490,7 +559,10 @@ impl AnthropicProvider { if let Some(error) = event.error { error!("Anthropic API error: {:?}", error); let _ = tx - .send(Err(anyhow!("Anthropic API error: {:?}", error))) + .send(Err(anyhow!( + "Anthropic API error: {:?}", + error + ))) .await; break; // Break to let stream exhaust naturally } @@ -524,7 +596,11 @@ impl AnthropicProvider { content: String::new(), finished: true, usage: accumulated_usage.clone(), - tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls) }, + tool_calls: if current_tool_calls.is_empty() { + None + } else { + Some(current_tool_calls) + }, }; let _ = tx.send(Ok(final_chunk)).await; accumulated_usage @@ -543,15 +619,17 @@ impl LLMProvider for AnthropicProvider { let temperature = request.temperature.unwrap_or(self.temperature); let request_body = self.create_request_body( - &request.messages, - request.tools.as_deref(), - false, - max_tokens, - temperature + &request.messages, + 
request.tools.as_deref(), + false, + max_tokens, + temperature, )?; - debug!("Sending request to Anthropic API: model={}, max_tokens={}, temperature={}", - request_body.model, request_body.max_tokens, request_body.temperature); + debug!( + "Sending request to Anthropic API: model={}, max_tokens={}, temperature={}", + request_body.model, request_body.max_tokens, request_body.temperature + ); let response = self .create_request_builder(false) @@ -588,7 +666,8 @@ impl LLMProvider for AnthropicProvider { let usage = Usage { prompt_tokens: anthropic_response.usage.input_tokens, completion_tokens: anthropic_response.usage.output_tokens, - total_tokens: anthropic_response.usage.input_tokens + anthropic_response.usage.output_tokens, + total_tokens: anthropic_response.usage.input_tokens + + anthropic_response.usage.output_tokens, }; debug!( @@ -613,18 +692,24 @@ impl LLMProvider for AnthropicProvider { let temperature = request.temperature.unwrap_or(self.temperature); let request_body = self.create_request_body( - &request.messages, - request.tools.as_deref(), - true, - max_tokens, - temperature + &request.messages, + request.tools.as_deref(), + true, + max_tokens, + temperature, )?; - debug!("Sending streaming request to Anthropic API: model={}, max_tokens={}, temperature={}", - request_body.model, request_body.max_tokens, request_body.temperature); - + debug!( + "Sending streaming request to Anthropic API: model={}, max_tokens={}, temperature={}", + request_body.model, request_body.max_tokens, request_body.temperature + ); + // Debug: Log the full request body - debug!("Full request body: {}", serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "Failed to serialize".to_string())); + debug!( + "Full request body: {}", + serde_json::to_string_pretty(&request_body) + .unwrap_or_else(|_| "Failed to serialize".to_string()) + ); let response = self .create_request_builder(true) @@ -673,16 +758,16 @@ impl LLMProvider for AnthropicProvider { // Claude models support 
native tool calling true } - + fn supports_cache_control(&self) -> bool { // Anthropic supports cache control true } - + fn max_tokens(&self) -> u32 { self.max_tokens } - + fn temperature(&self) -> f32 { self.temperature } @@ -729,7 +814,7 @@ struct AnthropicMessage { #[serde(tag = "type")] enum AnthropicContent { #[serde(rename = "text")] - Text { + Text { text: String, #[serde(skip_serializing_if = "Option::is_none")] cache_control: Option, @@ -798,17 +883,14 @@ mod tests { #[test] fn test_message_conversion() { - let provider = AnthropicProvider::new( - "test-key".to_string(), - None, - None, - None, - None, - None, - ).unwrap(); + let provider = + AnthropicProvider::new("test-key".to_string(), None, None, None, None, None).unwrap(); let messages = vec![ - Message::new(MessageRole::System, "You are a helpful assistant.".to_string()), + Message::new( + MessageRole::System, + "You are a helpful assistant.".to_string(), + ), Message::new(MessageRole::User, "Hello!".to_string()), Message::new(MessageRole::Assistant, "Hi there!".to_string()), ]; @@ -830,7 +912,8 @@ mod tests { Some(0.5), None, None, - ).unwrap(); + ) + .unwrap(); let messages = vec![Message::new(MessageRole::User, "Test message".to_string())]; @@ -848,31 +931,23 @@ mod tests { #[test] fn test_tool_conversion() { - let provider = AnthropicProvider::new( - "test-key".to_string(), - None, - None, - None, - None, - None, - ).unwrap(); + let provider = + AnthropicProvider::new("test-key".to_string(), None, None, None, None, None).unwrap(); - let tools = vec![ - Tool { - name: "get_weather".to_string(), - description: "Get the current weather".to_string(), - input_schema: serde_json::json!({ - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state" - } - }, - "required": ["location"] - }), - }, - ]; + let tools = vec![Tool { + name: "get_weather".to_string(), + description: "Get the current weather".to_string(), + input_schema: serde_json::json!({ + 
"type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state" + } + }, + "required": ["location"] + }), + }]; let anthropic_tools = provider.convert_tools(&tools); @@ -881,31 +956,30 @@ mod tests { assert_eq!(anthropic_tools[0].description, "Get the current weather"); assert_eq!(anthropic_tools[0].input_schema.schema_type, "object"); assert!(anthropic_tools[0].input_schema.required.is_some()); - assert_eq!(anthropic_tools[0].input_schema.required.as_ref().unwrap()[0], "location"); + assert_eq!( + anthropic_tools[0].input_schema.required.as_ref().unwrap()[0], + "location" + ); } #[test] fn test_cache_control_serialization() { - let provider = AnthropicProvider::new( - "test-key".to_string(), - None, - None, - None, - None, - None, - ).unwrap(); + let provider = + AnthropicProvider::new("test-key".to_string(), None, None, None, None, None).unwrap(); // Test message WITHOUT cache_control let messages_without = vec![Message::new(MessageRole::User, "Hello".to_string())]; let (_, anthropic_messages_without) = provider.convert_messages(&messages_without).unwrap(); let json_without = serde_json::to_string(&anthropic_messages_without).unwrap(); - + println!("Anthropic JSON without cache_control: {}", json_without); // Check if cache_control appears in the JSON if json_without.contains("cache_control") { println!("WARNING: JSON contains 'cache_control' field when not configured!"); - assert!(!json_without.contains("\"cache_control\":null"), - "JSON should not contain 'cache_control: null'"); + assert!( + !json_without.contains("\"cache_control\":null"), + "JSON should not contain 'cache_control: null'" + ); } // Test message WITH cache_control @@ -916,15 +990,21 @@ mod tests { )]; let (_, anthropic_messages_with) = provider.convert_messages(&messages_with).unwrap(); let json_with = serde_json::to_string(&anthropic_messages_with).unwrap(); - + println!("Anthropic JSON with cache_control: {}", json_with); - 
assert!(json_with.contains("cache_control"), - "JSON should contain 'cache_control' field when configured"); - assert!(json_with.contains("ephemeral"), - "JSON should contain 'ephemeral' type"); - + assert!( + json_with.contains("cache_control"), + "JSON should contain 'cache_control' field when configured" + ); + assert!( + json_with.contains("ephemeral"), + "JSON should contain 'ephemeral' type" + ); + // The key assertion: when cache_control is None, it should not appear in JSON - assert!(!json_without.contains("cache_control") || !json_without.contains("null"), - "JSON should not contain 'cache_control' field or null values when not configured"); + assert!( + !json_without.contains("cache_control") || !json_without.contains("null"), + "JSON should not contain 'cache_control' field or null values when not configured" + ); } } diff --git a/crates/g3-providers/src/databricks.rs b/crates/g3-providers/src/databricks.rs index fb826d2..c5e9dc6 100644 --- a/crates/g3-providers/src/databricks.rs +++ b/crates/g3-providers/src/databricks.rs @@ -312,7 +312,7 @@ impl DatabricksProvider { // Append new bytes to our buffer byte_buffer.extend_from_slice(&chunk); - + // Try to convert the entire buffer to UTF-8 let chunk_str = match std::str::from_utf8(&byte_buffer) { Ok(s) => { @@ -326,7 +326,8 @@ impl DatabricksProvider { let valid_up_to = e.valid_up_to(); if valid_up_to > 0 { // We have some valid UTF-8, extract it and keep the rest for next iteration - let valid_bytes = byte_buffer.drain(..valid_up_to).collect::>(); + let valid_bytes = + byte_buffer.drain(..valid_up_to).collect::>(); std::str::from_utf8(&valid_bytes).unwrap().to_string() } else { // No valid UTF-8 at all, skip this chunk and continue @@ -593,7 +594,7 @@ impl DatabricksProvider { } Err(e) => { error!("Stream error at chunk {}: {}", chunk_count, e); - + // Check if this is a connection error that might be recoverable let error_msg = e.to_string(); if error_msg.contains("unexpected EOF") || 
error_msg.contains("connection") { @@ -610,10 +611,14 @@ impl DatabricksProvider { // Log final state debug!("Stream ended after {} chunks", chunk_count); - debug!("Final state: buffer_len={}, incomplete_data_line_len={}, byte_buffer_len={}", - buffer.len(), incomplete_data_line.len(), byte_buffer.len()); + debug!( + "Final state: buffer_len={}, incomplete_data_line_len={}, byte_buffer_len={}", + buffer.len(), + incomplete_data_line.len(), + byte_buffer.len() + ); debug!("Accumulated tool calls: {}", current_tool_calls.len()); - + // If we have any remaining data in buffers, log it for debugging if !buffer.is_empty() { debug!("Remaining buffer content: {:?}", buffer); @@ -924,7 +929,7 @@ impl LLMProvider for DatabricksProvider { "Processing Databricks streaming request with {} messages", request.messages.len() ); - + // Debug: Log tool count if let Some(ref tools) = request.tools { debug!("Request has {} tools", tools.len()); @@ -1051,15 +1056,15 @@ impl LLMProvider for DatabricksProvider { // This includes Claude, Llama, DBRX, and most other models on the platform true } - + fn supports_cache_control(&self) -> bool { false } - + fn max_tokens(&self) -> u32 { self.max_tokens } - + fn temperature(&self) -> f32 { self.temperature } @@ -1181,7 +1186,10 @@ mod tests { .unwrap(); let messages = vec![ - Message::new(MessageRole::System, "You are a helpful assistant.".to_string()), + Message::new( + MessageRole::System, + "You are a helpful assistant.".to_string(), + ), Message::new(MessageRole::User, "Hello!".to_string()), Message::new(MessageRole::Assistant, "Hi there!".to_string()), ]; @@ -1304,10 +1312,12 @@ mod tests { let messages_without = vec![Message::new(MessageRole::User, "Hello".to_string())]; let databricks_messages_without = provider.convert_messages(&messages_without).unwrap(); let json_without = serde_json::to_string(&databricks_messages_without).unwrap(); - + println!("JSON without cache_control: {}", json_without); - 
assert!(!json_without.contains("cache_control"), - "JSON should not contain 'cache_control' field when not configured"); + assert!( + !json_without.contains("cache_control"), + "JSON should not contain 'cache_control' field when not configured" + ); // Test message WITH cache_control - should still NOT include it (Databricks doesn't support it) let messages_with = vec![Message::with_cache_control( @@ -1317,10 +1327,12 @@ mod tests { )]; let databricks_messages_with = provider.convert_messages(&messages_with).unwrap(); let json_with = serde_json::to_string(&databricks_messages_with).unwrap(); - + println!("JSON with cache_control: {}", json_with); - assert!(!json_with.contains("cache_control"), - "JSON should NOT contain 'cache_control' field - Databricks doesn't support it"); + assert!( + !json_with.contains("cache_control"), + "JSON should NOT contain 'cache_control' field - Databricks doesn't support it" + ); } #[test] @@ -1343,7 +1355,13 @@ mod tests { ) .unwrap(); - assert!(!claude_provider.supports_cache_control(), "Databricks should not support cache_control even for Claude models"); - assert!(!llama_provider.supports_cache_control(), "Databricks should not support cache_control for Llama models"); + assert!( + !claude_provider.supports_cache_control(), + "Databricks should not support cache_control even for Claude models" + ); + assert!( + !llama_provider.supports_cache_control(), + "Databricks should not support cache_control for Llama models" + ); } } diff --git a/crates/g3-providers/src/embedded.rs b/crates/g3-providers/src/embedded.rs index 1c29bc0..3bf8e1b 100644 --- a/crates/g3-providers/src/embedded.rs +++ b/crates/g3-providers/src/embedded.rs @@ -1,8 +1,8 @@ -use anyhow::Result; use crate::{ CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message, MessageRole, Usage, }; +use anyhow::Result; use llama_cpp::{ standard_sampler::{SamplerStage, StandardSampler}, LlamaModel, LlamaParams, LlamaSession, SessionParams, 
@@ -37,7 +37,7 @@ impl EmbeddedProvider { // Expand tilde in path let expanded_path = shellexpand::tilde(&model_path); let model_path_buf = PathBuf::from(expanded_path.as_ref()); - + // If model doesn't exist and it's the default Qwen model, offer to download it if !model_path_buf.exists() { if model_path.contains("qwen2.5-7b-instruct-q3_k_m.gguf") { @@ -47,7 +47,7 @@ impl EmbeddedProvider { anyhow::bail!("Model file not found: {}", model_path_buf.display()); } } - + let model_path = model_path_buf.as_path(); // Set up model parameters @@ -93,24 +93,24 @@ impl EmbeddedProvider { fn format_messages(&self, messages: &[Message]) -> String { // Determine the appropriate format based on model type let model_name_lower = self.model_name.to_lowercase(); - + if model_name_lower.contains("qwen") { // Qwen format: <|im_start|>role\ncontent<|im_end|> let mut formatted = String::new(); - + for message in messages { let role = match message.role { MessageRole::System => "system", - MessageRole::User => "user", + MessageRole::User => "user", MessageRole::Assistant => "assistant", }; - + formatted.push_str(&format!( "<|im_start|>{}\n{}<|im_end|>\n", role, message.content )); } - + // Add the start of assistant response formatted.push_str("<|im_start|>assistant\n"); formatted @@ -118,7 +118,7 @@ impl EmbeddedProvider { // Mistral Instruct format: [INST] ... 
[/INST] assistant_response let mut formatted = String::new(); let mut in_conversation = false; - + for (i, message) in messages.iter().enumerate() { match message.role { MessageRole::System => { @@ -146,12 +146,15 @@ impl EmbeddedProvider { } } } - + // If the last message was from user, add a space for the assistant's response - if messages.last().is_some_and(|m| matches!(m.role, MessageRole::User)) { + if messages + .last() + .is_some_and(|m| matches!(m.role, MessageRole::User)) + { formatted.push(' '); } - + formatted } else { // Use Llama/CodeLlama format for other models @@ -216,16 +219,25 @@ impl EmbeddedProvider { } Err(_) => { if attempt < 4 { - debug!("Session busy, retrying in {}ms (attempt {}/5)", 100 * (attempt + 1), attempt + 1); - std::thread::sleep(std::time::Duration::from_millis(100 * (attempt + 1) as u64)); + debug!( + "Session busy, retrying in {}ms (attempt {}/5)", + 100 * (attempt + 1), + attempt + 1 + ); + std::thread::sleep(std::time::Duration::from_millis( + 100 * (attempt + 1) as u64, + )); } else { - return Err(anyhow::anyhow!("Model is busy after 5 attempts, please try again")); + return Err(anyhow::anyhow!( + "Model is busy after 5 attempts, please try again" + )); } } } } - - let mut session = session_guard.ok_or_else(|| anyhow::anyhow!("Failed to acquire session lock"))?; + + let mut session = session_guard + .ok_or_else(|| anyhow::anyhow!("Failed to acquire session lock"))?; debug!( "Starting inference with prompt length: {} chars, estimated {} tokens", @@ -297,7 +309,7 @@ impl EmbeddedProvider { break; } } - + if hit_stop { break; } @@ -308,7 +320,7 @@ impl EmbeddedProvider { token_count, start_time.elapsed() ); - + Ok((generated_text, token_count)) }), ) @@ -347,21 +359,22 @@ impl EmbeddedProvider { fn get_stop_sequences(&self) -> Vec<&'static str> { // Determine model type from model_name let model_name_lower = self.model_name.to_lowercase(); - + if model_name_lower.contains("qwen") { vec![ - "<|im_end|>", // Qwen ChatML format end 
token - "<|endoftext|>", // Alternative end token - "", // Generic end of sequence - "<|im_start|>", // Start of new message (shouldn't appear in response) + "<|im_end|>", // Qwen ChatML format end token + "<|endoftext|>", // Alternative end token + "", // Generic end of sequence + "<|im_start|>", // Start of new message (shouldn't appear in response) ] - } else if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama") { + } else if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama") + { vec![ - "", // End of sequence - "[/INST]", // End of instruction - "<>", // End of system message - "[INST]", // Start of new instruction (shouldn't appear in response) - "<>", // Start of system (shouldn't appear in response) + "", // End of sequence + "[/INST]", // End of instruction + "<>", // End of system message + "[INST]", // Start of new instruction (shouldn't appear in response) + "<>", // Start of system (shouldn't appear in response) ] } else if model_name_lower.contains("llama") { vec![ @@ -374,9 +387,9 @@ impl EmbeddedProvider { ] } else if model_name_lower.contains("mistral") { vec![ - "", // End of sequence - "[/INST]", // End of instruction - "<|im_end|>", // ChatML format + "", // End of sequence + "[/INST]", // End of instruction + "<|im_end|>", // ChatML format ] } else if model_name_lower.contains("vicuna") || model_name_lower.contains("wizard") { vec![ @@ -391,7 +404,7 @@ impl EmbeddedProvider { "### Instruction:", // Alpaca format "### Response:", // Alpaca format "### Input:", // Alpaca format - "", // End of sequence + "", // End of sequence ] } else { // Generic/unknown model - use common stop sequences @@ -411,14 +424,14 @@ impl EmbeddedProvider { fn clean_stop_sequences(&self, text: &str) -> String { let mut cleaned = text.to_string(); let stop_sequences = self.get_stop_sequences(); - + for stop_seq in &stop_sequences { if let Some(pos) = cleaned.find(stop_seq) { cleaned.truncate(pos); break; // 
Only remove the first occurrence to avoid over-truncation } } - + cleaned.trim().to_string() } @@ -426,57 +439,64 @@ impl EmbeddedProvider { fn download_qwen_model(model_path: &Path) -> Result<()> { use std::fs; use std::process::Command; - + const MODEL_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-7B-Instruct-GGUF/resolve/main/qwen2.5-7b-instruct-q3_k_m.gguf"; const MODEL_SIZE_MB: u64 = 3631; // Approximate size in MB - + // Create the parent directory if it doesn't exist if let Some(parent) = model_path.parent() { fs::create_dir_all(parent)?; } - + info!("Downloading Qwen 2.5 7B model (Q3_K_M quantization, ~3.5GB)..."); info!("This is a one-time download that may take several minutes depending on your connection."); info!("Downloading to: {}", model_path.display()); - + // Use curl with progress bar for download let output = Command::new("curl") .args([ - "-L", // Follow redirects - "-#", // Show progress bar - "-f", // Fail on HTTP errors - "-o", model_path.to_str().unwrap(), + "-L", // Follow redirects + "-#", // Show progress bar + "-f", // Fail on HTTP errors + "-o", + model_path.to_str().unwrap(), MODEL_URL, ]) .output()?; - + if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); - + // If curl is not available, provide alternative instructions if stderr.contains("command not found") || stderr.contains("not found") { - error!("curl is not installed. Please install curl or manually download the model."); + error!( + "curl is not installed. Please install curl or manually download the model." + ); error!("Manual download instructions:"); error!("1. Download from: {}", MODEL_URL); error!("2. 
Save to: {}", model_path.display()); - anyhow::bail!("curl not found - please install curl or download the model manually"); + anyhow::bail!( + "curl not found - please install curl or download the model manually" + ); } - + anyhow::bail!("Failed to download model: {}", stderr); } - + // Verify the file was created and has reasonable size let metadata = fs::metadata(model_path)?; let size_mb = metadata.len() / (1024 * 1024); - - if size_mb < MODEL_SIZE_MB - 100 { // Allow some variance - fs::remove_file(model_path).ok(); // Clean up partial download + + if size_mb < MODEL_SIZE_MB - 100 { + // Allow some variance + fs::remove_file(model_path).ok(); // Clean up partial download anyhow::bail!( "Downloaded file appears incomplete ({}MB vs expected ~{}MB). Please try again.", - size_mb, MODEL_SIZE_MB + size_mb, + MODEL_SIZE_MB ); } - + info!("Successfully downloaded Qwen 2.5 7B model ({}MB)", size_mb); Ok(()) } @@ -541,20 +561,29 @@ impl LLMProvider for EmbeddedProvider { } Err(_) => { if attempt < 4 { - debug!("Session busy, retrying in {}ms (attempt {}/5)", 100 * (attempt + 1), attempt + 1); - std::thread::sleep(std::time::Duration::from_millis(100 * (attempt + 1) as u64)); + debug!( + "Session busy, retrying in {}ms (attempt {}/5)", + 100 * (attempt + 1), + attempt + 1 + ); + std::thread::sleep(std::time::Duration::from_millis( + 100 * (attempt + 1) as u64, + )); } else { - let _ = tx.blocking_send(Err(anyhow::anyhow!("Model is busy after 5 attempts, please try again"))); + let _ = tx.blocking_send(Err(anyhow::anyhow!( + "Model is busy after 5 attempts, please try again" + ))); return; } } } } - + let mut session = match session_guard { Some(ctx) => ctx, None => { - let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to acquire session lock"))); + let _ = + tx.blocking_send(Err(anyhow::anyhow!("Failed to acquire session lock"))); return; } }; @@ -588,17 +617,33 @@ impl LLMProvider for EmbeddedProvider { let mut accumulated_text = String::new(); let mut token_count = 
0; let mut unsent_tokens = String::new(); // Buffer for tokens we're holding back - + // Get stop sequences dynamically based on model type let stop_sequences = if prompt.contains("<|im_start|>") { // Qwen ChatML format detected vec!["<|im_end|>", "<|endoftext|>", "", "<|im_start|>"] } else if prompt.contains("[INST]") || prompt.contains("<>") { // Llama/CodeLlama format detected - vec!["", "[/INST]", "<>", "[INST]", "<>", "### Human:", "### Assistant:"] + vec![ + "", + "[/INST]", + "<>", + "[INST]", + "<>", + "### Human:", + "### Assistant:", + ] } else { // Generic format - vec!["", "<|endoftext|>", "<|im_end|>", "### Human:", "### Assistant:", "[/INST]", "<>"] + vec![ + "", + "<|endoftext|>", + "<|im_end|>", + "### Human:", + "### Assistant:", + "[/INST]", + "<>", + ] }; // Stream tokens with proper limits @@ -622,10 +667,10 @@ impl LLMProvider for EmbeddedProvider { if hit_stop { // Before stopping, check if there might be an incomplete tool call // Look for JSON tool call patterns that might be cut off by the stop sequence - let has_potential_tool_call = accumulated_text.contains(r#"{"tool":"#) || - accumulated_text.contains(r#"{"{""tool"":"#) || - accumulated_text.contains(r#"{{""tool"":"#); - + let has_potential_tool_call = accumulated_text.contains(r#"{"tool":"#) + || accumulated_text.contains(r#"{"{""tool"":"#) + || accumulated_text.contains(r#"{{""tool"":"#); + if has_potential_tool_call { // Check if the tool call appears to be complete (has closing brace after the stop sequence) let mut complete_tool_call = false; @@ -645,7 +690,7 @@ impl LLMProvider for EmbeddedProvider { } } } - + // If tool call is incomplete, send the raw content including stop sequences // so the main parser can handle it properly if !complete_tool_call { @@ -666,7 +711,7 @@ impl LLMProvider for EmbeddedProvider { break; } } - + // Send any remaining clean content before stopping (original behavior) let mut clean_accumulated = accumulated_text.clone(); for stop_seq in 
&stop_sequences { @@ -675,7 +720,7 @@ impl LLMProvider for EmbeddedProvider { break; } } - + // Calculate what part we haven't sent yet let already_sent_len = accumulated_text.len() - unsent_tokens.len(); if clean_accumulated.len() > already_sent_len { @@ -711,7 +756,8 @@ impl LLMProvider for EmbeddedProvider { if might_be_stop { // Hold back tokens, but only for a limited buffer size - if unsent_tokens.len() > 20 { // Don't hold back more than 20 characters + if unsent_tokens.len() > 20 { + // Don't hold back more than 20 characters // Send the oldest part and keep only the recent part that might be a stop sequence let to_send = &unsent_tokens[..unsent_tokens.len() - 10]; if !to_send.is_empty() { @@ -755,7 +801,7 @@ impl LLMProvider for EmbeddedProvider { let final_chunk = CompletionChunk { content: String::new(), finished: true, - usage: None, // Embedded models calculate usage differently + usage: None, // Embedded models calculate usage differently tool_calls: None, }; let _ = tx.blocking_send(Ok(final_chunk)); @@ -771,11 +817,11 @@ impl LLMProvider for EmbeddedProvider { fn model(&self) -> &str { &self.model_name } - + fn max_tokens(&self) -> u32 { self.max_tokens } - + fn temperature(&self) -> f32 { self.temperature } diff --git a/crates/g3-providers/src/lib.rs b/crates/g3-providers/src/lib.rs index 6759662..180400d 100644 --- a/crates/g3-providers/src/lib.rs +++ b/crates/g3-providers/src/lib.rs @@ -1,36 +1,36 @@ -use serde::{Deserialize, Serialize}; use anyhow::Result; -use std::collections::HashMap; use rand::Rng; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; /// Trait for LLM providers #[async_trait::async_trait] pub trait LLMProvider: Send + Sync { /// Generate a completion for the given messages async fn complete(&self, request: CompletionRequest) -> Result; - + /// Stream a completion for the given messages async fn stream(&self, request: CompletionRequest) -> Result; - + /// Get the provider name fn name(&self) -> &str; - + /// 
Get the model name fn model(&self) -> &str; - + /// Check if the provider supports native tool calling fn has_native_tool_calling(&self) -> bool { false } - + /// Check if the provider supports cache control fn supports_cache_control(&self) -> bool { false } - + /// Get the configured max_tokens for this provider fn max_tokens(&self) -> u32; - + /// Get the configured temperature for this provider fn temperature(&self) -> f32; } @@ -60,15 +60,24 @@ pub enum CacheType { impl CacheControl { pub fn ephemeral() -> Self { - Self { cache_type: CacheType::Ephemeral, ttl: None } + Self { + cache_type: CacheType::Ephemeral, + ttl: None, + } } - + pub fn five_minute() -> Self { - Self { cache_type: CacheType::Ephemeral, ttl: Some("5m".to_string()) } + Self { + cache_type: CacheType::Ephemeral, + ttl: Some("5m".to_string()), + } } - + pub fn one_hour() -> Self { - Self { cache_type: CacheType::Ephemeral, ttl: Some("1h".to_string()) } + Self { + cache_type: CacheType::Ephemeral, + ttl: Some("1h".to_string()), + } } } @@ -76,6 +85,7 @@ impl CacheControl { pub struct Message { pub role: MessageRole, pub content: String, + #[serde(skip)] pub id: String, #[serde(skip_serializing_if = "Option::is_none")] pub cache_control: Option, @@ -110,7 +120,7 @@ pub struct CompletionChunk { pub content: String, pub finished: bool, pub tool_calls: Option>, - pub usage: Option, // Add usage tracking for streaming + pub usage: Option, // Add usage tracking for streaming } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -144,7 +154,7 @@ impl Message { fn generate_id() -> String { let now = chrono::Local::now(); let timestamp = now.format("%H%M%S").to_string(); - + let mut rng = rand::thread_rng(); let random_chars: String = (0..3) .map(|_| { @@ -153,10 +163,10 @@ impl Message { chars[idx] as char }) .collect(); - + format!("{}-{}", timestamp, random_chars) } - + /// Create a new message with optional cache control pub fn new(role: MessageRole, content: String) -> Self { Self { @@ -168,7 +178,11 
@@ impl Message { } /// Create a new message with cache control - pub fn with_cache_control(role: MessageRole, content: String, cache_control: CacheControl) -> Self { + pub fn with_cache_control( + role: MessageRole, + content: String, + cache_control: CacheControl, + ) -> Self { Self { role, content, @@ -176,13 +190,13 @@ impl Message { cache_control: Some(cache_control), } } - + /// Create a message with cache control, with provider validation pub fn with_cache_control_validated( - role: MessageRole, - content: String, + role: MessageRole, + content: String, cache_control: CacheControl, - provider: &dyn LLMProvider + provider: &dyn LLMProvider, ) -> Self { if !provider.supports_cache_control() { tracing::warn!( @@ -192,7 +206,7 @@ impl Message { ); return Self::new(role, content); } - + Self::with_cache_control(role, content, cache_control) } } @@ -210,16 +224,16 @@ impl ProviderRegistry { default_provider: String::new(), } } - + pub fn register(&mut self, provider: P) { let name = provider.name().to_string(); self.providers.insert(name.clone(), Box::new(provider)); - + if self.default_provider.is_empty() { self.default_provider = name; } } - + pub fn set_default(&mut self, provider_name: &str) -> Result<()> { if !self.providers.contains_key(provider_name) { anyhow::bail!("Provider '{}' not found", provider_name); @@ -227,7 +241,7 @@ impl ProviderRegistry { self.default_provider = provider_name.to_string(); Ok(()) } - + pub fn get(&self, provider_name: Option<&str>) -> Result<&dyn LLMProvider> { let name = provider_name.unwrap_or(&self.default_provider); self.providers @@ -235,7 +249,7 @@ impl ProviderRegistry { .map(|p| p.as_ref()) .ok_or_else(|| anyhow::anyhow!("Provider '{}' not found", name)) } - + pub fn list_providers(&self) -> Vec<&str> { self.providers.keys().map(|s| s.as_str()).collect() } @@ -255,10 +269,12 @@ mod tests { fn test_message_serialization_without_cache_control() { let msg = Message::new(MessageRole::User, "Hello".to_string()); let json = 
serde_json::to_string(&msg).unwrap(); - + println!("Message JSON without cache_control: {}", json); - assert!(!json.contains("cache_control"), - "JSON should not contain 'cache_control' field when not configured"); + assert!( + !json.contains("cache_control"), + "JSON should not contain 'cache_control' field when not configured" + ); } #[test] @@ -269,16 +285,24 @@ mod tests { CacheControl::ephemeral(), ); let json = serde_json::to_string(&msg).unwrap(); - + println!("Message JSON with cache_control: {}", json); - assert!(json.contains("cache_control"), - "JSON should contain 'cache_control' field when configured"); - assert!(json.contains("ephemeral"), - "JSON should contain 'ephemeral' value"); - assert!(json.contains("\"type\":"), - "JSON should contain 'type' field in cache_control"); - assert!(!json.contains("null"), - "JSON should not contain null values"); + assert!( + json.contains("cache_control"), + "JSON should contain 'cache_control' field when configured" + ); + assert!( + json.contains("ephemeral"), + "JSON should contain 'ephemeral' value" + ); + assert!( + json.contains("\"type\":"), + "JSON should contain 'type' field in cache_control" + ); + assert!( + !json.contains("null"), + "JSON should not contain null values" + ); } #[test] @@ -289,11 +313,20 @@ mod tests { CacheControl::five_minute(), ); let json = serde_json::to_string(&msg).unwrap(); - + println!("Message JSON with 5-minute cache_control: {}", json); - assert!(json.contains("cache_control"), "JSON should contain 'cache_control' field"); - assert!(json.contains("ephemeral"), "JSON should contain 'ephemeral' type"); - assert!(json.contains("\"ttl\":\"5m\""), "JSON should contain ttl field with 5m value"); + assert!( + json.contains("cache_control"), + "JSON should contain 'cache_control' field" + ); + assert!( + json.contains("ephemeral"), + "JSON should contain 'ephemeral' type" + ); + assert!( + json.contains("\"ttl\":\"5m\""), + "JSON should contain ttl field with 5m value" + ); } 
#[test] @@ -304,39 +337,53 @@ mod tests { CacheControl::one_hour(), ); let json = serde_json::to_string(&msg).unwrap(); - + println!("Message JSON with 1-hour cache_control: {}", json); - assert!(json.contains("cache_control"), "JSON should contain 'cache_control' field"); - assert!(json.contains("ephemeral"), "JSON should contain 'ephemeral' type"); - assert!(json.contains("\"ttl\":\"1h\""), "JSON should contain ttl field with 1h value"); + assert!( + json.contains("cache_control"), + "JSON should contain 'cache_control' field" + ); + assert!( + json.contains("ephemeral"), + "JSON should contain 'ephemeral' type" + ); + assert!( + json.contains("\"ttl\":\"1h\""), + "JSON should contain ttl field with 1h value" + ); } #[test] fn test_message_id_generation() { let msg = Message::new(MessageRole::User, "Hello".to_string()); - + // Check that id is not empty assert!(!msg.id.is_empty(), "Message ID should not be empty"); - + // Check format: HHMMSS-XXX let parts: Vec<&str> = msg.id.split('-').collect(); assert_eq!(parts.len(), 2, "Message ID should have format HHMMSS-XXX"); - + // Check timestamp part is 6 digits assert_eq!(parts[0].len(), 6, "Timestamp should be 6 digits (HHMMSS)"); - assert!(parts[0].chars().all(|c| c.is_ascii_digit()), "Timestamp should be all digits"); - + assert!( + parts[0].chars().all(|c| c.is_ascii_digit()), + "Timestamp should be all digits" + ); + // Check random part is 3 alpha characters assert_eq!(parts[1].len(), 3, "Random part should be 3 characters"); - assert!(parts[1].chars().all(|c| c.is_ascii_alphabetic()), - "Random part should be all alphabetic characters"); + assert!( + parts[1].chars().all(|c| c.is_ascii_alphabetic()), + "Random part should be all alphabetic characters" + ); } #[test] fn test_message_id_uniqueness() { let msg1 = Message::new(MessageRole::User, "Hello".to_string()); let msg2 = Message::new(MessageRole::User, "Hello".to_string()); - + // IDs should be different (due to random component) // Note: There's a tiny 
chance they could be the same, but very unlikely println!("msg1.id: {}, msg2.id: {}", msg1.id, msg2.id); @@ -346,9 +393,12 @@ mod tests { fn test_message_id_not_serialized() { let msg = Message::new(MessageRole::User, "Hello".to_string()); let json = serde_json::to_string(&msg).unwrap(); - + println!("Message JSON: {}", json); - assert!(!json.contains("\"id\""), "JSON should not contain 'id' field"); + assert!( + !json.contains("\"id\""), + "JSON should not contain 'id' field" + ); } #[test] @@ -358,8 +408,14 @@ mod tests { "Hello".to_string(), CacheControl::ephemeral(), ); - - assert!(!msg.id.is_empty(), "Message with cache control should have an ID"); - assert!(msg.id.contains('-'), "Message ID should contain hyphen separator"); + + assert!( + !msg.id.is_empty(), + "Message with cache control should have an ID" + ); + assert!( + msg.id.contains('-'), + "Message ID should contain hyphen separator" + ); } } diff --git a/crates/g3-providers/src/openai.rs b/crates/g3-providers/src/openai.rs index 3704d62..d322663 100644 --- a/crates/g3-providers/src/openai.rs +++ b/crates/g3-providers/src/openai.rs @@ -10,8 +10,8 @@ use tokio_stream::wrappers::ReceiverStream; use tracing::{debug, error}; use crate::{ - CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, - Message, MessageRole, Tool, ToolCall, Usage, + CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message, + MessageRole, Tool, ToolCall, Usage, }; #[derive(Clone)] @@ -138,7 +138,8 @@ impl OpenAIProvider { debug!("Received stream completion marker"); // Send final chunk with accumulated content and tool calls - if !accumulated_content.is_empty() || !current_tool_calls.is_empty() { + if !accumulated_content.is_empty() || !current_tool_calls.is_empty() + { let tool_calls = if current_tool_calls.is_empty() { None } else { @@ -188,8 +189,9 @@ impl OpenAIProvider { if let Some(index) = delta_tool_call.index { // Ensure we have enough tool calls in 
our vector while current_tool_calls.len() <= index { - current_tool_calls - .push(OpenAIStreamingToolCall::default()); + current_tool_calls.push( + OpenAIStreamingToolCall::default(), + ); } let tool_call = &mut current_tool_calls[index]; @@ -198,11 +200,14 @@ impl OpenAIProvider { tool_call.id = Some(id.clone()); } - if let Some(function) = &delta_tool_call.function { + if let Some(function) = + &delta_tool_call.function + { if let Some(name) = &function.name { tool_call.name = Some(name.clone()); } - if let Some(arguments) = &function.arguments { + if let Some(arguments) = &function.arguments + { tool_call.arguments.push_str(arguments); } } @@ -246,7 +251,7 @@ impl OpenAIProvider { .collect(), ) }; - + let final_chunk = CompletionChunk { content: String::new(), finished: true, @@ -254,7 +259,7 @@ impl OpenAIProvider { usage: accumulated_usage.clone(), }; let _ = tx.send(Ok(final_chunk)).await; - + accumulated_usage } } @@ -291,7 +296,11 @@ impl LLMProvider for OpenAIProvider { .text() .await .unwrap_or_else(|_| "Unknown error".to_string()); - return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text)); + return Err(anyhow::anyhow!( + "OpenAI API error {}: {}", + status, + error_text + )); } let openai_response: OpenAIResponse = response.json().await?; @@ -334,7 +343,10 @@ impl LLMProvider for OpenAIProvider { request.temperature, ); - debug!("Sending streaming request to OpenAI API: model={}", self.model); + debug!( + "Sending streaming request to OpenAI API: model={}", + self.model + ); let response = self .client @@ -350,7 +362,11 @@ impl LLMProvider for OpenAIProvider { .text() .await .unwrap_or_else(|_| "Unknown error".to_string()); - return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text)); + return Err(anyhow::anyhow!( + "OpenAI API error {}: {}", + status, + error_text + )); } let stream = response.bytes_stream(); @@ -384,11 +400,11 @@ impl LLMProvider for OpenAIProvider { // OpenAI models support native tool calling true } 
- + fn max_tokens(&self) -> u32 { self.max_tokens.unwrap_or(16000) } - + fn temperature(&self) -> f32 { self._temperature.unwrap_or(0.1) } @@ -472,9 +488,9 @@ impl OpenAIStreamingToolCall { fn to_tool_call(&self) -> Option { let id = self.id.as_ref()?; let name = self.name.as_ref()?; - + let args = serde_json::from_str(&self.arguments).unwrap_or(serde_json::Value::Null); - + Some(ToolCall { id: id.clone(), tool: name.clone(), diff --git a/crates/g3-providers/tests/cache_control_error_regression_test.rs b/crates/g3-providers/tests/cache_control_error_regression_test.rs index 533c943..03b4dea 100644 --- a/crates/g3-providers/tests/cache_control_error_regression_test.rs +++ b/crates/g3-providers/tests/cache_control_error_regression_test.rs @@ -20,18 +20,24 @@ fn test_no_wrong_serialization_format() { CacheControl::ephemeral(), ); let json = serde_json::to_string(&msg).unwrap(); - + println!("Ephemeral message JSON: {}", json); - + // Should NOT contain the wrong format - assert!(!json.contains("system.0.cache_control"), - "JSON should not contain 'system.0.cache_control' path"); - assert!(!json.contains("cache_control.ephemeral"), - "JSON should not contain 'cache_control.ephemeral' path"); - + assert!( + !json.contains("system.0.cache_control"), + "JSON should not contain 'system.0.cache_control' path" + ); + assert!( + !json.contains("cache_control.ephemeral"), + "JSON should not contain 'cache_control.ephemeral' path" + ); + // Should contain the correct format - assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#), - "JSON should contain correct cache_control format"); + assert!( + json.contains(r#""cache_control":{"type":"ephemeral"}"#), + "JSON should contain correct cache_control format" + ); } #[test] @@ -42,20 +48,28 @@ fn test_five_minute_no_wrong_format() { CacheControl::five_minute(), ); let json = serde_json::to_string(&msg).unwrap(); - + println!("5-minute message JSON: {}", json); - + // Should NOT contain the wrong format - 
assert!(!json.contains("system.0.cache_control"), - "JSON should not contain 'system.0.cache_control' path"); - assert!(!json.contains("cache_control.ephemeral.ttl"), - "JSON should not contain 'cache_control.ephemeral.ttl' path"); - + assert!( + !json.contains("system.0.cache_control"), + "JSON should not contain 'system.0.cache_control' path" + ); + assert!( + !json.contains("cache_control.ephemeral.ttl"), + "JSON should not contain 'cache_control.ephemeral.ttl' path" + ); + // Should contain the correct format with ttl as a direct field - assert!(json.contains(r#""type":"ephemeral""#), - "JSON should contain type field"); - assert!(json.contains(r#""ttl":"5m""#), - "JSON should contain ttl field with value 5m"); + assert!( + json.contains(r#""type":"ephemeral""#), + "JSON should contain type field" + ); + assert!( + json.contains(r#""ttl":"5m""#), + "JSON should contain ttl field with value 5m" + ); } #[test] @@ -66,44 +80,59 @@ fn test_one_hour_no_wrong_format() { CacheControl::one_hour(), ); let json = serde_json::to_string(&msg).unwrap(); - + println!("1-hour message JSON: {}", json); - + // Should NOT contain the wrong format - assert!(!json.contains("system.0.cache_control"), - "JSON should not contain 'system.0.cache_control' path"); - assert!(!json.contains("cache_control.ephemeral.ttl"), - "JSON should not contain 'cache_control.ephemeral.ttl' path"); - + assert!( + !json.contains("system.0.cache_control"), + "JSON should not contain 'system.0.cache_control' path" + ); + assert!( + !json.contains("cache_control.ephemeral.ttl"), + "JSON should not contain 'cache_control.ephemeral.ttl' path" + ); + // Should contain the correct format with ttl as a direct field - assert!(json.contains(r#""type":"ephemeral""#), - "JSON should contain type field"); - assert!(json.contains(r#""ttl":"1h""#), - "JSON should contain ttl field with value 1h"); + assert!( + json.contains(r#""type":"ephemeral""#), + "JSON should contain type field" + ); + assert!( + 
json.contains(r#""ttl":"1h""#), + "JSON should contain ttl field with value 1h" + ); } #[test] fn test_cache_control_structure_is_flat() { // Verify that the cache_control object has a flat structure // with 'type' and optional 'ttl' at the same level - + let cache_control = CacheControl::five_minute(); let json_value = serde_json::to_value(&cache_control).unwrap(); - - println!("Cache control as JSON value: {}", serde_json::to_string_pretty(&json_value).unwrap()); - + + println!( + "Cache control as JSON value: {}", + serde_json::to_string_pretty(&json_value).unwrap() + ); + let obj = json_value.as_object().expect("Should be an object"); - + // Should have exactly 2 keys at the top level - assert_eq!(obj.len(), 2, "Cache control should have exactly 2 top-level fields"); - + assert_eq!( + obj.len(), + 2, + "Cache control should have exactly 2 top-level fields" + ); + // Both 'type' and 'ttl' should be at the same level assert!(obj.contains_key("type"), "Should have 'type' field"); assert!(obj.contains_key("ttl"), "Should have 'ttl' field"); - + // 'type' should be a string, not an object assert!(obj["type"].is_string(), "'type' should be a string value"); - + // 'ttl' should be a string, not nested assert!(obj["ttl"].is_string(), "'ttl' should be a string value"); } @@ -112,20 +141,30 @@ fn test_cache_control_structure_is_flat() { fn test_ephemeral_cache_control_structure() { let cache_control = CacheControl::ephemeral(); let json_value = serde_json::to_value(&cache_control).unwrap(); - - println!("Ephemeral cache control as JSON value: {}", serde_json::to_string_pretty(&json_value).unwrap()); - + + println!( + "Ephemeral cache control as JSON value: {}", + serde_json::to_string_pretty(&json_value).unwrap() + ); + let obj = json_value.as_object().expect("Should be an object"); - + // Should have exactly 1 key (only 'type', no 'ttl') - assert_eq!(obj.len(), 1, "Ephemeral cache control should have exactly 1 top-level field"); - + assert_eq!( + obj.len(), + 1, + 
"Ephemeral cache control should have exactly 1 top-level field" + ); + // Should have 'type' field assert!(obj.contains_key("type"), "Should have 'type' field"); - + // Should NOT have 'ttl' field - assert!(!obj.contains_key("ttl"), "Ephemeral should not have 'ttl' field"); - + assert!( + !obj.contains_key("ttl"), + "Ephemeral should not have 'ttl' field" + ); + // 'type' should be a string with value "ephemeral" assert_eq!(obj["type"].as_str().unwrap(), "ephemeral"); } diff --git a/crates/g3-providers/tests/cache_control_integration_test.rs b/crates/g3-providers/tests/cache_control_integration_test.rs index 5ec365c..71ad89b 100644 --- a/crates/g3-providers/tests/cache_control_integration_test.rs +++ b/crates/g3-providers/tests/cache_control_integration_test.rs @@ -10,13 +10,19 @@ use serde_json::json; fn test_ephemeral_cache_control_serialization() { let cache_control = CacheControl::ephemeral(); let json = serde_json::to_value(&cache_control).unwrap(); - - println!("Ephemeral cache_control JSON: {}", serde_json::to_string(&json).unwrap()); - - assert_eq!(json, json!({ - "type": "ephemeral" - })); - + + println!( + "Ephemeral cache_control JSON: {}", + serde_json::to_string(&json).unwrap() + ); + + assert_eq!( + json, + json!({ + "type": "ephemeral" + }) + ); + // Verify no ttl field is present assert!(!json.as_object().unwrap().contains_key("ttl")); } @@ -25,26 +31,38 @@ fn test_ephemeral_cache_control_serialization() { fn test_five_minute_cache_control_serialization() { let cache_control = CacheControl::five_minute(); let json = serde_json::to_value(&cache_control).unwrap(); - - println!("5-minute cache_control JSON: {}", serde_json::to_string(&json).unwrap()); - - assert_eq!(json, json!({ - "type": "ephemeral", - "ttl": "5m" - })); + + println!( + "5-minute cache_control JSON: {}", + serde_json::to_string(&json).unwrap() + ); + + assert_eq!( + json, + json!({ + "type": "ephemeral", + "ttl": "5m" + }) + ); } #[test] fn test_one_hour_cache_control_serialization() 
{ let cache_control = CacheControl::one_hour(); let json = serde_json::to_value(&cache_control).unwrap(); - - println!("1-hour cache_control JSON: {}", serde_json::to_string(&json).unwrap()); - - assert_eq!(json, json!({ - "type": "ephemeral", - "ttl": "1h" - })); + + println!( + "1-hour cache_control JSON: {}", + serde_json::to_string(&json).unwrap() + ); + + assert_eq!( + json, + json!({ + "type": "ephemeral", + "ttl": "1h" + }) + ); } #[test] @@ -54,11 +72,16 @@ fn test_message_with_ephemeral_cache_control() { "System prompt".to_string(), CacheControl::ephemeral(), ); - + let json = serde_json::to_value(&msg).unwrap(); - println!("Message with ephemeral cache_control: {}", serde_json::to_string(&json).unwrap()); - - let cache_control = json.get("cache_control").expect("cache_control field should exist"); + println!( + "Message with ephemeral cache_control: {}", + serde_json::to_string(&json).unwrap() + ); + + let cache_control = json + .get("cache_control") + .expect("cache_control field should exist"); assert_eq!(cache_control.get("type").unwrap(), "ephemeral"); assert!(!cache_control.as_object().unwrap().contains_key("ttl")); } @@ -70,11 +93,16 @@ fn test_message_with_five_minute_cache_control() { "System prompt".to_string(), CacheControl::five_minute(), ); - + let json = serde_json::to_value(&msg).unwrap(); - println!("Message with 5-minute cache_control: {}", serde_json::to_string(&json).unwrap()); - - let cache_control = json.get("cache_control").expect("cache_control field should exist"); + println!( + "Message with 5-minute cache_control: {}", + serde_json::to_string(&json).unwrap() + ); + + let cache_control = json + .get("cache_control") + .expect("cache_control field should exist"); assert_eq!(cache_control.get("type").unwrap(), "ephemeral"); assert_eq!(cache_control.get("ttl").unwrap(), "5m"); } @@ -86,11 +114,16 @@ fn test_message_with_one_hour_cache_control() { "System prompt".to_string(), CacheControl::one_hour(), ); - + let json = 
serde_json::to_value(&msg).unwrap(); - println!("Message with 1-hour cache_control: {}", serde_json::to_string(&json).unwrap()); - - let cache_control = json.get("cache_control").expect("cache_control field should exist"); + println!( + "Message with 1-hour cache_control: {}", + serde_json::to_string(&json).unwrap() + ); + + let cache_control = json + .get("cache_control") + .expect("cache_control field should exist"); assert_eq!(cache_control.get("type").unwrap(), "ephemeral"); assert_eq!(cache_control.get("ttl").unwrap(), "1h"); } @@ -98,10 +131,13 @@ fn test_message_with_one_hour_cache_control() { #[test] fn test_message_without_cache_control() { let msg = Message::new(MessageRole::User, "Hello".to_string()); - + let json = serde_json::to_value(&msg).unwrap(); - println!("Message without cache_control: {}", serde_json::to_string(&json).unwrap()); - + println!( + "Message without cache_control: {}", + serde_json::to_string(&json).unwrap() + ); + // cache_control field should not be present when not set assert!(!json.as_object().unwrap().contains_key("cache_control")); } @@ -110,9 +146,9 @@ fn test_message_without_cache_control() { fn test_cache_control_json_format_ephemeral() { let cache_control = CacheControl::ephemeral(); let json_str = serde_json::to_string(&cache_control).unwrap(); - + println!("Ephemeral JSON string: {}", json_str); - + // Verify exact JSON format assert_eq!(json_str, r#"{"type":"ephemeral"}"#); } @@ -121,9 +157,9 @@ fn test_cache_control_json_format_ephemeral() { fn test_cache_control_json_format_five_minute() { let cache_control = CacheControl::five_minute(); let json_str = serde_json::to_string(&cache_control).unwrap(); - + println!("5-minute JSON string: {}", json_str); - + // Verify exact JSON format assert_eq!(json_str, r#"{"type":"ephemeral","ttl":"5m"}"#); } @@ -132,9 +168,9 @@ fn test_cache_control_json_format_five_minute() { fn test_cache_control_json_format_one_hour() { let cache_control = CacheControl::one_hour(); let json_str = 
serde_json::to_string(&cache_control).unwrap(); - + println!("1-hour JSON string: {}", json_str); - + // Verify exact JSON format assert_eq!(json_str, r#"{"type":"ephemeral","ttl":"1h"}"#); } @@ -143,7 +179,7 @@ fn test_cache_control_json_format_one_hour() { fn test_deserialization_ephemeral() { let json_str = r#"{"type":"ephemeral"}"#; let cache_control: CacheControl = serde_json::from_str(json_str).unwrap(); - + assert_eq!(cache_control.ttl, None); } @@ -151,7 +187,7 @@ fn test_deserialization_ephemeral() { fn test_deserialization_five_minute() { let json_str = r#"{"type":"ephemeral","ttl":"5m"}"#; let cache_control: CacheControl = serde_json::from_str(json_str).unwrap(); - + assert_eq!(cache_control.ttl, Some("5m".to_string())); } @@ -159,6 +195,6 @@ fn test_deserialization_five_minute() { fn test_deserialization_one_hour() { let json_str = r#"{"type":"ephemeral","ttl":"1h"}"#; let cache_control: CacheControl = serde_json::from_str(json_str).unwrap(); - + assert_eq!(cache_control.ttl, Some("1h".to_string())); } diff --git a/examples/verify_message_id.rs b/examples/verify_message_id.rs index 85cb3c5..6652010 100644 --- a/examples/verify_message_id.rs +++ b/examples/verify_message_id.rs @@ -6,19 +6,19 @@ use g3_providers::{Message, MessageRole}; fn main() { println!("=== Message ID Implementation Verification ==="); println!(); - + // Create several messages to show ID generation println!("Creating 5 messages to demonstrate ID generation:"); for i in 1..=5 { let msg = Message::new(MessageRole::User, format!("Test message {}", i)); println!(" Message {}: id = '{}'", i, msg.id); } - + println!(); println!("ID Format: HHMMSS-XXX"); println!(" - HHMMSS: Current time (hours, minutes, seconds)"); println!(" - XXX: 3 random alphabetic characters (a-z, A-Z)"); - + println!(); println!("Verifying ID is NOT serialized to JSON:"); let msg = Message::new(MessageRole::User, "Hello World".to_string()); @@ -26,7 +26,7 @@ fn main() { println!(" Message ID: {}", msg.id); 
println!(" JSON output: {}", json); println!(" Contains 'id' field: {}", json.contains("\"id\"")); - + println!(); println!("✅ Implementation complete!"); }
diff --git a/monitor_context_window.sh b/monitor_context_window.sh
new file mode 100755
index 0000000..1bae991
--- /dev/null
+++ b/monitor_context_window.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+
+# Watch the current context window file and redraw it whenever it changes.
+#
+# The file is <workspace>/logs/current_context_window, where <workspace> is
+# $G3_WORKSPACE when set and $HOME/tmp/workspace otherwise.
+
+if [[ -n "$G3_WORKSPACE" ]]; then
+    TARGET_DIR="$G3_WORKSPACE/logs"
+else
+    TARGET_DIR="$HOME/tmp/workspace/logs"
+fi
+
+if [[ ! -d "$TARGET_DIR" ]]; then
+    echo "Error: Directory '$TARGET_DIR' does not exist." >&2
+    exit 1
+fi
+
+cd "$TARGET_DIR" || exit 1
+
+NAME="$TARGET_DIR/current_context_window"
+
+# mtime FILE: print FILE's modification time in epoch seconds, following
+# symlinks (current_context_window is a symlink to the session's file).
+# Prints nothing while FILE does not exist. GNU stat takes -c, BSD/macOS stat
+# takes -f, so probe once on a path that always exists.
+if stat -c %Y / >/dev/null 2>&1; then
+    mtime() { stat -L -c %Y "$1" 2>/dev/null; }
+else
+    mtime() { stat -L -f %m "$1" 2>/dev/null; }
+fi
+
+echo "Monitoring file '$NAME' for current context window (waits for first update)"
+
+LAST=$(mtime "$NAME")
+while sleep 0.5; do
+    NOW=$(mtime "$NAME")
+    # Skip polls where the file is missing; redraw only on a real change.
+    if [[ -n "$NOW" && "$NOW" != "$LAST" ]]; then
+        clear
+        cat "$NAME"
+        LAST=$NOW
+    fi
+done