Compare commits

...

41 Commits

Author SHA1 Message Date
Michael Neale
a457d46446 Merge branch 'main' into micn/fix-anthropic-1p
* main:
  control commands for machine mode
  Fix duplicate dump at end
  minor
  --machine mode flag for verbose CLI output
  fixed x,y detection in vision click
  screenshotting bug fix
  test
  Native api for screen capture
  replace tesseract with apple vision
  more macax tooling
  coach rigor +++
  thinning message highlighted
  warnings fix
  macax tools
  control commands
  Add --interactive-requirements flag for AI-enhanced requirements mode
2025-10-28 13:55:01 +11:00
Dhanji Prasanna
7c2c433746 control commands for machine mode 2025-10-28 12:35:58 +11:00
Dhanji Prasanna
98f4220544 Fix duplicate dump at end 2025-10-27 13:48:46 +11:00
Dhanji Prasanna
a4476a555c minor 2025-10-27 13:32:14 +11:00
Dhanji Prasanna
5e08d6bbba --machine mode flag for verbose CLI output 2025-10-27 10:37:05 +11:00
Dhanji Prasanna
c3f3f79dc5 fixed x,y detection in vision click 2025-10-25 16:51:27 +11:00
Dhanji Prasanna
834153ea69 screenshotting bug fix 2025-10-24 20:40:43 +11:00
Dhanji Prasanna
65f25f840e test 2025-10-24 16:11:24 +11:00
Dhanji Prasanna
a8af5d7cc1 Native api for screen capture 2025-10-24 16:11:12 +11:00
Dhanji Prasanna
61d748034d replace tesseract with apple vision 2025-10-24 15:35:47 +11:00
Dhanji Prasanna
d0ac222e2e more macax tooling 2025-10-24 10:45:24 +11:00
Dhanji Prasanna
e1e732150a coach rigor +++ 2025-10-24 10:15:42 +11:00
Dhanji Prasanna
0be4829ca9 thinning message highlighted 2025-10-23 13:16:13 +11:00
Dhanji Prasanna
efd4eca755 warnings fix 2025-10-23 07:17:55 +11:00
Dhanji Prasanna
3ec65e38ee macax tools 2025-10-23 06:53:42 +11:00
Dhanji Prasanna
c5d6fbef08 control commands 2025-10-22 22:14:12 +11:00
Dhanji R. Prasanna
f93844d378 Merge pull request #10 from dhanji/micn/interactive-requirements
Add --interactive-requirements flag for AI-enhanced requirements mode
2025-10-22 15:37:16 +11:00
Michael Neale
b3d18d02ea prefer provider count 2025-10-22 15:09:47 +11:00
Michael Neale
442ca76cd6 Merge branch 'main' into micn/fix-anthropic-1p
* main:
  fix panic in CLI parser
  coach/player provider split + add OpenAI
2025-10-22 15:01:18 +11:00
Michael Neale
af6d37a8e2 Add --interactive-requirements flag for AI-enhanced requirements mode
- Adds new --interactive-requirements CLI flag for autonomous mode
- Prompts user for brief requirements input
- Uses AI to enhance and structure requirements into proper markdown
- Shows enhanced requirements and allows user to approve/edit/cancel
- Saves to requirements.md and proceeds with autonomous mode if approved
- Includes test script for manual verification
2025-10-22 14:58:35 +11:00
Dhanji R. Prasanna
c1c6680e03 Merge pull request #7 from jochenx/jochen-add-openai-and-multi-providers
coach/player provider split + add OpenAI
2025-10-22 13:46:16 +11:00
Jochen
f2d8e744bb fix panic in CLI parser 2025-10-22 13:20:45 +11:00
Michael Neale
738c3ac53e to get anthropic provider more reliable with tokens 2025-10-22 09:47:24 +11:00
Jochen
010a43d203 coach/player provider split + add OpenAI
Allows coach and player LLM providers to be separately specified.
Also adds OpenAI provider
2025-10-21 16:59:13 +11:00
Dhanji Prasanna
758e255af8 dont run safaridriver --enable each time 2025-10-21 16:00:58 +11:00
Dhanji Prasanna
393826ae02 webdriver tools 2025-10-21 14:34:41 +11:00
Dhanji Prasanna
3afad3d61f progressive context thinning 2025-10-20 15:29:44 +11:00
Dhanji Prasanna
2488cc54d5 docs: update README and DESIGN to reflect current project state
- Add g3-computer-control crate to architecture documentation
- Document all 13 tools including computer control and TODO management
- Add context thinning feature documentation (50-80% thresholds)
- Update tool ecosystem section with complete tool list
- Remove broken link to non-existent COMPUTER_CONTROL.md
- Update workspace count from 5 to 6 crates
- Add platform-specific implementation details for computer control
- Document OCR support via Tesseract
- Clarify setup instructions for computer control features
2025-10-20 15:03:22 +11:00
Dhanji Prasanna
2ad0c9a3fd todo list formatting 2025-10-20 14:27:53 +11:00
Dhanji Prasanna
2008a81193 fix to pass feedback to player (broken by todo system) 2025-10-20 14:12:08 +11:00
Dhanji Prasanna
776f5034b8 TODO tools 2025-10-20 10:50:53 +11:00
Dhanji Prasanna
92bece957b colorizing tool calls 2025-10-18 16:09:30 +11:00
Dhanji Prasanna
767299ff4e minor 2025-10-18 16:03:58 +11:00
Dhanji Prasanna
9d35449be8 ~ expansion for read_file and str_replace 2025-10-18 16:01:15 +11:00
Dhanji Prasanna
da652bf287 computer control tools 2025-10-18 14:16:50 +11:00
Dhanji Prasanna
a566171203 small turn completing bug 2025-10-18 13:25:23 +11:00
Dhanji Prasanna
347c9e1e00 colorize timing based on duration 2025-10-17 13:54:21 +11:00
Dhanji Prasanna
aa7eda0331 fix wall clock timing 2025-10-17 10:36:21 +11:00
Dhanji Prasanna
e42c76f3b9 Tune coach pickiness down 2025-10-17 10:28:08 +11:00
Dhanji Prasanna
dd211fab1c panic fix 2025-10-17 09:50:01 +11:00
Dhanji R. Prasanna
bcece38473 Merge pull request #5 from dhanji/micn/agent-tweaks
load AGENTS.md if there
2025-10-16 15:06:14 +11:00
64 changed files with 8936 additions and 1043 deletions

1
.gitignore vendored
View File

@@ -2,6 +2,7 @@
# will have compiled files and executables
debug
target
.build
# These are backup files generated by rustfmt
**/*.rs.bk

1154
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,8 @@ members = [
"crates/g3-core",
"crates/g3-providers",
"crates/g3-config",
"crates/g3-execution"
"crates/g3-execution",
"crates/g3-computer-control"
]
resolver = "2"

View File

@@ -29,7 +29,8 @@ g3/
│ ├── g3-core/ # Core agent engine, tools, and streaming logic
│ ├── g3-providers/ # LLM provider abstractions and implementations
│ ├── g3-config/ # Configuration management
── g3-execution/ # Code execution engine
── g3-execution/ # Code execution engine
│ └── g3-computer-control/ # Computer control and automation
├── logs/ # Session logs (auto-created)
├── README.md # Project documentation
└── DESIGN.md # This design document
@@ -48,6 +49,7 @@ g3/
│ • Retro TUI │ │ • Tool system │ │ • Embedded │
│ • Autonomous │ │ • Streaming │ │ (llama.cpp) │
│ mode │ │ • Task exec │ │ • OAuth flow │
│ │ │ • TODO mgmt │ │ │
└─────────────────┘ └─────────────────┘ └─────────────────┘
│ │ │
└───────────────────────┼───────────────────────┘
@@ -59,7 +61,18 @@ g3/
│ • Shell cmds │ │ • Env overrides │
│ • Streaming │ │ • Provider │
│ • Error hdlg │ │ settings │
└─────────────────┘ └─────────────────┘
└─────────────────┘ │ • Computer │
│ │ control cfg │
│ └─────────────────┘
│ │
┌─────────────────┐ │
│ g3-computer- │◄────────────┘
│ control │
│ • Mouse/kbd │
│ • Screenshots │
│ • OCR/Tesseract │
│ • Windows/UI │
└─────────────────┘
```
## Core Components
@@ -79,6 +92,7 @@ g3/
- **Streaming Parser**: Real-time parsing of LLM responses with tool call detection and execution
- **Session Management**: Automatic session logging with detailed conversation history and token usage
- **Error Recovery**: Sophisticated error classification and retry logic for recoverable errors
- **TODO Management**: In-memory TODO list with read/write tools for task tracking
**Available Tools:**
- `shell`: Execute shell commands with streaming output
@@ -86,7 +100,15 @@ g3/
- `write_file`: Create or overwrite files with content
- `str_replace`: Apply unified diffs to files with precise editing
- `final_output`: Signal task completion with detailed summaries
- **Project Management**: Workspace handling, requirements.md processing for autonomous mode
- `todo_read`: Read the entire TODO list content
- `todo_write`: Write or overwrite the entire TODO list
- `mouse_click`: Click the mouse at specific coordinates
- `type_text`: Type text at the current cursor position
- `find_element`: Find UI elements by text, role, or attributes
- `take_screenshot`: Capture screenshots of screen, region, or window
- `extract_text`: Extract text from images or screen regions using OCR
- `find_text_on_screen`: Find text visually on screen and return coordinates
- `list_windows`: List all open windows with IDs and titles
### 2. g3-providers: LLM Provider Abstraction
@@ -172,6 +194,26 @@ g3/
- **Validation**: Configuration validation with helpful error messages
- **Flexible Paths**: Support for shell expansion (`~`, environment variables)
### 6. g3-computer-control: Computer Control & Automation
**Primary Responsibilities:**
- Cross-platform computer control and automation
- Mouse and keyboard input simulation
- Window management and screenshot capture
- OCR text extraction from images and screen regions
**Platform Support:**
- **macOS**: Core Graphics, Cocoa, screencapture integration
- **Linux**: X11/Xtest for input, X11 for window management
- **Windows**: Win32 APIs for input and window control
**Key Features:**
- **OCR Integration**: Tesseract-based text extraction from images
- **Window Management**: List, identify, and capture specific application windows
- **UI Automation**: Find elements, simulate clicks, type text
- **Screenshot Capture**: Full screen, regions, or specific windows
- **Accessibility**: Requires OS-level permissions for automation
## Advanced Features
### Context Window Management
@@ -180,6 +222,7 @@ G3 implements sophisticated context window management:
- **Automatic Monitoring**: Tracks token usage with percentage-based thresholds
- **Smart Summarization**: Auto-triggers at 80% capacity to prevent context overflow
- **Context Thinning**: Progressive thinning at 50%, 60%, 70%, 80% thresholds - replaces large tool results with file references
- **Conversation Preservation**: Maintains conversation continuity through intelligent summaries
- **Provider-Specific Limits**: Adapts to different model context windows (4k to 200k+ tokens)
- **Cumulative Tracking**: Monitors total token usage across entire sessions
@@ -354,20 +397,23 @@ This design document reflects the current state of G3 as a mature, production-re
### Fully Implemented
-**Core Agent Engine**: Complete with streaming, tool execution, and context management
-**Provider System**: Anthropic, Databricks, and Embedded providers with OAuth support
-**Tool System**: All 5 core tools (shell, read_file, write_file, str_replace, final_output)
-**Tool System**: 13 tools including file ops, shell, TODO management, and computer control
-**CLI Interface**: Interactive mode, single-shot mode, retro TUI
-**Autonomous Mode**: Coach-player feedback loop with requirements.md processing
-**Configuration**: TOML-based config with environment overrides
-**Error Handling**: Comprehensive retry logic and error classification
-**Session Logging**: Automatic session tracking and JSON logs
-**Context Management**: Auto-summarization at 80% capacity
-**Context Management**: Context thinning (50-80%) and auto-summarization at 80% capacity
-**Computer Control**: Cross-platform automation with OCR support
-**TODO Management**: In-memory TODO list with read/write tools
### Architecture Highlights
- **Workspace**: 5 crates with clear separation of concerns
- **Workspace**: 6 crates with clear separation of concerns
- **Dependencies**: Modern Rust ecosystem (Tokio, Clap, Serde, etc.)
- **Streaming**: Real-time response processing with tool call detection
- **Cross-Platform**: Works on macOS, Linux, and Windows
- **GPU Support**: Metal acceleration for local models on macOS
- **GPU Support**: Metal acceleration for local models on macOS, CUDA on Linux
- **OCR Support**: Tesseract integration for text extraction from images
### Key Files
- `src/main.rs`: main entry point delegating to g3-cli
@@ -376,3 +422,5 @@ This design document reflects the current state of G3 as a mature, production-re
- `crates/g3-providers/src/lib.rs`: provider trait and registry
- `crates/g3-config/src/lib.rs`: configuration management
- `crates/g3-execution/src/lib.rs`: code execution engine
- `crates/g3-computer-control/src/lib.rs`: computer control and automation
- `crates/g3-computer-control/src/platform/`: platform-specific implementations

View File

@@ -11,8 +11,8 @@ G3 follows a modular architecture organized as a Rust workspace with multiple cr
#### **g3-core**
The heart of the agent system, containing:
- **Agent Engine**: Main orchestration logic for handling conversations, tool execution, and task management
- **Context Window Management**: Intelligent tracking of token usage with auto-summarization capabilities when approaching context limits (~80% capacity)
- **Tool System**: Built-in tools for file operations (read, write, edit), shell command execution, and structured output generation
- **Context Window Management**: Intelligent tracking of token usage with context thinning (50-80%) and auto-summarization at 80% capacity
- **Tool System**: Built-in tools for file operations, shell commands, computer control, TODO management, and structured output
- **Streaming Response Parser**: Real-time parsing of LLM responses with tool call detection and execution
- **Task Execution**: Support for single and iterative task execution with automatic retry logic
@@ -40,6 +40,13 @@ Task execution framework:
- Error handling and retry mechanisms
- Progress tracking and reporting
#### **g3-computer-control**
Computer control capabilities:
- Mouse and keyboard automation
- UI element inspection and interaction
- Screenshot capture and window management
- OCR text extraction via Tesseract
#### **g3-cli**
Command-line interface:
- Interactive terminal interface
@@ -61,13 +68,32 @@ G3 includes robust error handling with automatic retry logic:
### Intelligent Context Management
- Automatic context window monitoring with percentage-based tracking
- Smart auto-summarization when approaching token limits
- **Context thinning** at 50%, 60%, 70%, 80% thresholds - automatically replaces large tool results with file references
- Conversation history preservation through summaries
- Dynamic token allocation for different providers
- Dynamic token allocation for different providers (4k to 200k+ tokens)
### Interactive Control Commands
G3's interactive CLI includes control commands for manual context management:
- **`/compact`**: Manually trigger summarization to compact conversation history
- **`/thinnify`**: Manually trigger context thinning to replace large tool results with file references
- **`/readme`**: Reload README.md and AGENTS.md from disk without restarting
- **`/stats`**: Show detailed context and performance statistics
- **`/help`**: Display all available control commands
These commands give you fine-grained control over context management, allowing you to proactively optimize token usage and refresh project documentation. See [Control Commands Documentation](docs/CONTROL_COMMANDS.md) for detailed usage.
### Tool Ecosystem
- **File Operations**: Read, write, and edit files with line-range precision
- **Shell Integration**: Execute system commands with output capture
- **Code Generation**: Structured code generation with syntax awareness
- **TODO Management**: Read and write TODO lists with markdown checkbox format
- **Computer Control** (Experimental): Automate desktop applications
- Mouse and keyboard control
- macOS Accessibility API for native app automation (via `--macax` flag)
- UI element inspection
- Screenshot capture and window management
- OCR text extraction from images and screen regions
- Window listing and identification
- **Final Output**: Formatted result presentation
### Provider Flexibility
@@ -102,6 +128,7 @@ G3 is designed for:
- API integration and testing
- Documentation generation
- Complex multi-step workflows
- Desktop application automation and testing
## Getting Started
@@ -116,6 +143,54 @@ cargo run
g3 "implement a function to calculate fibonacci numbers"
```
## WebDriver Browser Automation
G3 includes WebDriver support for browser automation tasks using Safari.
**One-Time Setup** (macOS only):
Safari Remote Automation must be enabled before using WebDriver tools. Run this once:
```bash
# Option 1: Use the provided script
./scripts/enable-safari-automation.sh
# Option 2: Enable manually
safaridriver --enable # Requires password
# Option 3: Enable via Safari UI
# Safari → Preferences → Advanced → Show Develop menu
# Then: Develop → Allow Remote Automation
```
**For detailed setup instructions and troubleshooting**, see [WebDriver Setup Guide](docs/webdriver-setup.md).
**Usage**: Run G3 with the `--webdriver` flag to enable browser automation tools.
## macOS Accessibility API Tools
G3 includes support for controlling macOS applications via the Accessibility API, allowing you to automate native macOS apps.
**Available Tools**: `macax_list_apps`, `macax_get_frontmost_app`, `macax_activate_app`, `macax_get_ui_tree`, `macax_find_elements`, `macax_click`, `macax_set_value`, `macax_get_value`, `macax_press_key`
**Setup**: Enable with the `--macax` flag or in config with `macax.enabled = true`. Grant accessibility permissions:
- **macOS**: System Preferences → Security & Privacy → Privacy → Accessibility → Add your terminal app
**For detailed documentation**, see [macOS Accessibility Tools Guide](docs/macax-tools.md).
**Note**: This is particularly useful for testing and automating apps you're building with G3, as you can add accessibility identifiers to your UI elements.
## Computer Control (Experimental)
G3 can interact with your computer's GUI for automation tasks:
**Available Tools**: `mouse_click`, `type_text`, `find_element`, `take_screenshot`, `extract_text`, `find_text_on_screen`, `list_windows`
**Setup**: Enable in config with `computer_control.enabled = true` and grant OS accessibility permissions:
- **macOS**: System Preferences → Security & Privacy → Accessibility
- **Linux**: Ensure X11 or Wayland access
- **Windows**: Run as administrator (first time only)
## Session Logs
G3 automatically saves session logs for each interaction in the `logs/` directory. These logs contain:

View File

@@ -0,0 +1,24 @@
[providers]
default_provider = "databricks"
# Specify different providers for coach and player in autonomous mode
coach = "databricks" # Provider for coach (code reviewer) - can be more powerful/expensive
player = "anthropic" # Provider for player (code implementer) - can be faster/cheaper
[providers.databricks]
host = "https://your-workspace.cloud.databricks.com"
# token = "your-databricks-token" # Optional - will use OAuth if not provided
model = "databricks-claude-sonnet-4"
max_tokens = 4096
temperature = 0.1
use_oauth = true
[providers.anthropic]
api_key = "your-anthropic-api-key"
model = "claude-3-haiku-20240307" # Using a faster model for player
max_tokens = 4096
temperature = 0.3 # Slightly higher temperature for more creative implementations
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60

View File

@@ -1,5 +1,10 @@
[providers]
default_provider = "databricks"
# Optional: Specify different providers for coach and player in autonomous mode
# If not specified, will use default_provider for both
# coach = "databricks" # Provider for coach (code reviewer)
# player = "anthropic" # Provider for player (code implementer)
# Note: Make sure the specified providers are configured below
[providers.databricks]
host = "https://your-workspace.cloud.databricks.com"
@@ -13,3 +18,8 @@ use_oauth = true
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
[computer_control]
enabled = false # Set to true to enable computer control (requires OS permissions)
require_confirmation = true
max_actions_per_second = 5

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,94 @@
use g3_core::ui_writer::UiWriter;
use std::io::{self, Write};
/// Machine-mode implementation of UiWriter that prints plain, unformatted output
/// This is designed for programmatic consumption and outputs everything verbatim
pub struct MachineUiWriter;
impl MachineUiWriter {
pub fn new() -> Self {
Self
}
}
impl UiWriter for MachineUiWriter {
fn print(&self, message: &str) {
print!("{}", message);
}
fn println(&self, message: &str) {
println!("{}", message);
}
fn print_inline(&self, message: &str) {
print!("{}", message);
let _ = io::stdout().flush();
}
fn print_system_prompt(&self, prompt: &str) {
println!("SYSTEM_PROMPT:");
println!("{}", prompt);
println!("END_SYSTEM_PROMPT");
println!();
}
fn print_context_status(&self, message: &str) {
println!("CONTEXT_STATUS: {}", message);
}
fn print_context_thinning(&self, message: &str) {
println!("CONTEXT_THINNING: {}", message);
}
fn print_tool_header(&self, tool_name: &str) {
println!("TOOL_CALL: {}", tool_name);
}
fn print_tool_arg(&self, key: &str, value: &str) {
println!("TOOL_ARG: {} = {}", key, value);
}
fn print_tool_output_header(&self) {
println!("TOOL_OUTPUT:");
}
fn update_tool_output_line(&self, line: &str) {
println!("{}", line);
}
fn print_tool_output_line(&self, line: &str) {
println!("{}", line);
}
fn print_tool_output_summary(&self, count: usize) {
println!("TOOL_OUTPUT_LINES: {}", count);
}
fn print_tool_timing(&self, duration_str: &str) {
println!("TOOL_DURATION: {}", duration_str);
println!("END_TOOL_OUTPUT");
println!();
}
fn print_agent_prompt(&self) {
println!("AGENT_RESPONSE:");
let _ = io::stdout().flush();
}
fn print_agent_response(&self, content: &str) {
print!("{}", content);
let _ = io::stdout().flush();
}
fn notify_sse_received(&self) {
// No-op for machine mode
}
fn flush(&self) {
let _ = io::stdout().flush();
}
fn wants_full_output(&self) -> bool {
true // Machine mode wants complete, untruncated output
}
}

View File

@@ -267,23 +267,23 @@ impl TerminalState {
let mut current_text = String::new();
// Check for headers first
if line.starts_with("### ") {
if let Some(stripped) = line.strip_prefix("### ") {
return Line::from(Span::styled(
format!(" {}", &line[4..]),
format!(" {}", stripped),
Style::default()
.fg(self.theme.terminal_cyan.to_color())
.add_modifier(Modifier::BOLD | Modifier::UNDERLINED),
));
} else if line.starts_with("## ") {
} else if let Some(stripped) = line.strip_prefix("## ") {
return Line::from(Span::styled(
format!(" {}", &line[3..]),
format!(" {}", stripped),
Style::default()
.fg(self.theme.terminal_amber.to_color())
.add_modifier(Modifier::BOLD),
));
} else if line.starts_with("# ") {
} else if let Some(stripped) = line.strip_prefix("# ") {
return Line::from(Span::styled(
format!(" {}", &line[2..]),
format!(" {}", stripped),
Style::default()
.fg(self.theme.terminal_green.to_color())
.add_modifier(Modifier::BOLD),
@@ -343,7 +343,7 @@ impl TerminalState {
}
// Find closing *
let mut italic_text = String::new();
while let Some(ch) = chars.next() {
for ch in chars.by_ref() {
if ch == '*' {
break;
}
@@ -367,7 +367,7 @@ impl TerminalState {
}
// Find closing `
let mut code_text = String::new();
while let Some(ch) = chars.next() {
for ch in chars.by_ref() {
if ch == '`' {
break;
}
@@ -612,12 +612,10 @@ impl RetroTui {
}
// Update status blink only if status is "PROCESSING"
if state.status_line == "PROCESSING" {
if state.last_status_blink.elapsed() > Duration::from_millis(500) {
if state.status_line == "PROCESSING" && state.last_status_blink.elapsed() > Duration::from_millis(500) {
state.status_blink = !state.status_blink;
state.last_status_blink = Instant::now();
}
}
// Update activity area animation
let animation_speed = 0.15; // Adjust for faster/slower animation
@@ -771,12 +769,7 @@ impl RetroTui {
let total_cursor_pos = cursor_position;
// Determine the window into the buffer we should show
let window_start = if total_cursor_pos > available_width - 1 {
// Cursor is beyond the visible area, scroll the view
total_cursor_pos - (available_width - 1)
} else {
0
};
let window_start = total_cursor_pos.saturating_sub(available_width - 1);
// Get the visible portion of the buffer
let visible_buffer: String = input_buffer
@@ -1013,9 +1006,9 @@ impl RetroTui {
let fade_color = |color: Color| -> Color {
match color {
Color::Rgb(r, g, b) => {
let faded_r = ((r as f32 * opacity) as u8).max(0);
let faded_g = ((g as f32 * opacity) as u8).max(0);
let faded_b = ((b as f32 * opacity) as u8).max(0);
let faded_r = (r as f32 * opacity) as u8;
let faded_g = (g as f32 * opacity) as u8;
let faded_b = (b as f32 * opacity) as u8;
Color::Rgb(faded_r, faded_g, faded_b)
}
_ => color,
@@ -1098,9 +1091,9 @@ impl RetroTui {
let fade_color = |color: Color| -> Color {
match color {
Color::Rgb(r, g, b) => {
let faded_r = ((r as f32 * opacity) as u8).max(0);
let faded_g = ((g as f32 * opacity) as u8).max(0);
let faded_b = ((b as f32 * opacity) as u8).max(0);
let faded_r = (r as f32 * opacity) as u8;
let faded_g = (g as f32 * opacity) as u8;
let faded_b = (b as f32 * opacity) as u8;
Color::Rgb(faded_r, faded_g, faded_b)
}
_ => color,
@@ -1176,7 +1169,7 @@ impl RetroTui {
}
// Wave characters for smooth animation
let wave_chars = vec!['▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'];
let wave_chars = ['▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'];
// Build the wave line
let mut wave_line = String::new();
@@ -1190,7 +1183,7 @@ impl RetroTui {
let idx = wave_data.len().saturating_sub(display_width) + i;
if idx < wave_data.len() {
let value = wave_data[idx].min(1.0).max(0.0);
let value = wave_data[idx].clamp(0.0, 1.0);
let char_idx = ((value * 7.0) as usize).min(7);
wave_line.push(wave_chars[char_idx]);
} else {
@@ -1206,8 +1199,6 @@ impl RetroTui {
f.render_widget(wave_paragraph, area);
}
/// Draw the status bar
/// Draw the status bar
fn draw_status_bar(
f: &mut Frame,

View File

@@ -0,0 +1,32 @@
/// Simple output helper for printing messages
pub struct SimpleOutput {
machine_mode: bool,
}
impl SimpleOutput {
pub fn new() -> Self {
SimpleOutput { machine_mode: false }
}
pub fn new_with_mode(machine_mode: bool) -> Self {
SimpleOutput { machine_mode }
}
pub fn print(&self, message: &str) {
if !self.machine_mode {
println!("{}", message);
}
}
pub fn print_smart(&self, message: &str) {
if !self.machine_mode {
println!("{}", message);
}
}
}
impl Default for SimpleOutput {
fn default() -> Self {
Self::new()
}
}

View File

@@ -1,5 +1,6 @@
use crossterm::style::Color;
use crossterm::style::{SetForegroundColor, ResetColor};
use std::io::{self, Write};
use termimad::MadSkin;
/// Simple output handler with markdown support
@@ -40,7 +41,7 @@ impl SimpleOutput {
trimmed.starts_with("* ") ||
trimmed.starts_with("+ ") ||
(trimmed.len() > 2 &&
trimmed.chars().next().map_or(false, |c| c.is_ascii_digit()) &&
trimmed.chars().next().is_some_and(|c| c.is_ascii_digit()) &&
trimmed.chars().nth(1) == Some('.') &&
trimmed.chars().nth(2) == Some(' ')) ||
(trimmed.contains('[') && trimmed.contains("]("))
@@ -93,6 +94,37 @@ impl SimpleOutput {
print!("{}", ResetColor);
println!(" {:.1}% | {}/{} tokens", percentage, used, total);
}
pub fn print_context_thinning(&self, message: &str) {
// Animated highlight for context thinning
// Use bright cyan/green with a quick flash animation
// Flash animation: print with bright background, then normal
let frames = vec![
"\x1b[1;97;46m", // Frame 1: Bold white on cyan background
"\x1b[1;97;42m", // Frame 2: Bold white on green background
"\x1b[1;96;40m", // Frame 3: Bold cyan on black background
];
println!();
// Quick flash animation
for frame in &frames {
print!("\r{}{}\x1b[0m", frame, message);
let _ = io::stdout().flush();
std::thread::sleep(std::time::Duration::from_millis(80));
}
// Final display with bright cyan and sparkle emojis
print!("\r\x1b[1;96m✨ {}\x1b[0m", message);
println!();
// Add a subtle "success" indicator line
println!("\x1b[2;36m └─ Context optimized successfully\x1b[0m");
println!();
let _ = io::stdout().flush();
}
}
#[cfg(test)]

View File

@@ -1,8 +1,6 @@
use crate::retro_tui::RetroTui;
use g3_core::ui_writer::UiWriter;
use std::io::{self, Write};
use std::sync::Mutex;
use std::time::Instant;
/// Console implementation of UiWriter that prints to stdout
pub struct ConsoleUiWriter {
@@ -10,6 +8,7 @@ pub struct ConsoleUiWriter {
current_tool_args: Mutex<Vec<(String, String)>>,
current_output_line: Mutex<Option<String>>,
output_line_printed: Mutex<bool>,
in_todo_tool: Mutex<bool>,
}
impl ConsoleUiWriter {
@@ -19,6 +18,60 @@ impl ConsoleUiWriter {
current_tool_args: Mutex::new(Vec::new()),
current_output_line: Mutex::new(None),
output_line_printed: Mutex::new(false),
in_todo_tool: Mutex::new(false),
}
}
fn print_todo_line(&self, line: &str) {
// Transform and print todo list lines elegantly
let trimmed = line.trim();
// Skip the "📝 TODO list:" prefix line
if trimmed.starts_with("📝 TODO list:") || trimmed == "📝 TODO list is empty" {
return;
}
// Handle empty lines
if trimmed.is_empty() {
println!();
return;
}
// Detect indentation level
let indent_count = line.chars().take_while(|c| c.is_whitespace()).count();
let indent = " ".repeat(indent_count / 2); // Convert spaces to visual indent
// Format based on line type
if trimmed.starts_with("- [ ]") {
// Incomplete task
let task = trimmed.strip_prefix("- [ ]").unwrap_or(trimmed).trim();
println!("{}{}", indent, task);
} else if trimmed.starts_with("- [x]") || trimmed.starts_with("- [X]") {
// Completed task
let task = trimmed.strip_prefix("- [x]")
.or_else(|| trimmed.strip_prefix("- [X]"))
.unwrap_or(trimmed)
.trim();
println!("{}\x1b[2m☑ {}\x1b[0m", indent, task);
} else if trimmed.starts_with("- ") {
// Regular bullet point
let item = trimmed.strip_prefix("- ").unwrap_or(trimmed).trim();
println!("{}{}", indent, item);
} else if trimmed.starts_with("# ") {
// Heading
let heading = trimmed.strip_prefix("# ").unwrap_or(trimmed).trim();
println!("\n\x1b[1m{}\x1b[0m", heading);
} else if trimmed.starts_with("## ") {
// Subheading
let subheading = trimmed.strip_prefix("## ").unwrap_or(trimmed).trim();
println!("\n\x1b[1m{}\x1b[0m", subheading);
} else if trimmed.starts_with("**") && trimmed.ends_with("**") {
// Bold text (section marker)
let text = trimmed.trim_start_matches("**").trim_end_matches("**");
println!("{}\x1b[1m{}\x1b[0m", indent, text);
} else {
// Regular text or note
println!("{}{}", indent, trimmed);
}
}
}
@@ -49,10 +102,49 @@ impl UiWriter for ConsoleUiWriter {
println!("{}", message);
}
fn print_context_thinning(&self, message: &str) {
// Animated highlight for context thinning
// Use bright cyan/green with a quick flash animation
// Flash animation: print with bright background, then normal
let frames = vec![
"\x1b[1;97;46m", // Frame 1: Bold white on cyan background
"\x1b[1;97;42m", // Frame 2: Bold white on green background
"\x1b[1;96;40m", // Frame 3: Bold cyan on black background
];
println!();
// Quick flash animation
for frame in &frames {
print!("\r{}{}\x1b[0m", frame, message);
let _ = io::stdout().flush();
std::thread::sleep(std::time::Duration::from_millis(80));
}
// Final display with bright cyan and sparkle emojis
print!("\r\x1b[1;96m✨ {}\x1b[0m", message);
println!();
// Add a subtle "success" indicator line
println!("\x1b[2;36m └─ Context optimized successfully\x1b[0m");
println!();
let _ = io::stdout().flush();
}
fn print_tool_header(&self, tool_name: &str) {
// Store the tool name and clear args for collection
*self.current_tool_name.lock().unwrap() = Some(tool_name.to_string());
self.current_tool_args.lock().unwrap().clear();
// Check if this is a todo tool call
let is_todo = tool_name == "todo_read" || tool_name == "todo_write";
*self.in_todo_tool.lock().unwrap() = is_todo;
// For todo tools, we'll skip the normal header and print a custom one later
if is_todo {
}
}
fn print_tool_arg(&self, key: &str, value: &str) {
@@ -75,6 +167,12 @@ impl UiWriter for ConsoleUiWriter {
}
fn print_tool_output_header(&self) {
// Skip normal header for todo tools
if *self.in_todo_tool.lock().unwrap() {
println!(); // Just add a newline
return;
}
println!();
// Now print the tool header with the most important arg in bold green
if let Some(tool_name) = self.current_tool_name.lock().unwrap().as_ref() {
@@ -93,7 +191,12 @@ impl UiWriter for ConsoleUiWriter {
// Truncate long values for display
let display_value = if first_line.len() > 80 {
format!("{}...", &first_line[..77])
// Use char_indices to safely truncate at character boundary
let truncate_at = first_line.char_indices()
.nth(77)
.map(|(i, _)| i)
.unwrap_or(first_line.len());
format!("{}...", &first_line[..truncate_at])
} else {
first_line.to_string()
};
@@ -115,8 +218,8 @@ impl UiWriter for ConsoleUiWriter {
String::new()
};
// Print with bold green formatting using ANSI escape codes
println!("┌─\x1b[1;32m {} | {}{}\x1b[0m", tool_name, display_value, header_suffix);
// Print with bold green tool name, purple (non-bold) for pipe and args
println!("┌─\x1b[1;32m {}\x1b[0m\x1b[35m | {}{}\x1b[0m", tool_name, display_value, header_suffix);
} else {
// Print with bold green formatting using ANSI escape codes
println!("┌─\x1b[1;32m {}\x1b[0m", tool_name);
@@ -144,10 +247,21 @@ impl UiWriter for ConsoleUiWriter {
}
fn print_tool_output_line(&self, line: &str) {
// Special handling for todo tools
if *self.in_todo_tool.lock().unwrap() {
self.print_todo_line(line);
return;
}
println!("\x1b[2m{}\x1b[0m", line);
}
fn print_tool_output_summary(&self, count: usize) {
// Skip for todo tools
if *self.in_todo_tool.lock().unwrap() {
return;
}
println!(
"\x1b[2m({} line{})\x1b[0m",
count,
@@ -156,7 +270,55 @@ impl UiWriter for ConsoleUiWriter {
}
fn print_tool_timing(&self, duration_str: &str) {
println!("└─ ⚡️ {}", duration_str);
// For todo tools, just print a simple completion message
if *self.in_todo_tool.lock().unwrap() {
println!();
*self.in_todo_tool.lock().unwrap() = false;
return;
}
// Parse the duration string to determine color
// Format is like "1.5s", "500ms", "2m 30.0s"
let color_code = if duration_str.ends_with("ms") {
// Milliseconds - use default color (< 1s)
""
} else if duration_str.contains('m') {
// Contains minutes
// Extract minutes value
if let Some(m_pos) = duration_str.find('m') {
if let Ok(minutes) = duration_str[..m_pos].trim().parse::<u32>() {
if minutes >= 5 {
"\x1b[31m" // Red for >= 5 minutes
} else {
"\x1b[38;5;208m" // Orange for >= 1 minute but < 5 minutes
}
} else {
"" // Default color if parsing fails
}
} else {
"" // Default color if 'm' not found (shouldn't happen)
}
} else if duration_str.ends_with('s') {
// Seconds only
if let Some(s_value) = duration_str.strip_suffix('s') {
if let Ok(seconds) = s_value.trim().parse::<f64>() {
if seconds >= 1.0 {
"\x1b[33m" // Yellow for >= 1 second
} else {
"" // Default color for < 1 second
}
} else {
"" // Default color if parsing fails
}
} else {
"" // Default color
}
} else {
// Milliseconds or other format - use default color
""
};
println!("└─ ⚡️ {}{}\x1b[0m", color_code, duration_str);
println!();
// Clear the stored tool info
*self.current_tool_name.lock().unwrap() = None;
@@ -183,223 +345,3 @@ impl UiWriter for ConsoleUiWriter {
}
}
/// RetroTui implementation of UiWriter that sends output to the TUI
pub struct RetroTuiWriter {
tui: RetroTui,
current_tool_name: Mutex<Option<String>>,
current_tool_output: Mutex<Vec<String>>,
current_tool_start: Mutex<Option<Instant>>,
current_tool_caption: Mutex<String>,
}
impl RetroTuiWriter {
pub fn new(tui: RetroTui) -> Self {
Self {
tui,
current_tool_name: Mutex::new(None),
current_tool_output: Mutex::new(Vec::new()),
current_tool_start: Mutex::new(None),
current_tool_caption: Mutex::new(String::new()),
}
}
}
impl UiWriter for RetroTuiWriter {
fn print(&self, message: &str) {
self.tui.output(message);
}
fn println(&self, message: &str) {
self.tui.output(message);
}
fn print_inline(&self, message: &str) {
// For inline printing, we'll just append to the output
self.tui.output(message);
}
fn print_system_prompt(&self, prompt: &str) {
self.tui.output("🔍 System Prompt:");
self.tui.output("================");
for line in prompt.lines() {
self.tui.output(line);
}
self.tui.output("================");
self.tui.output("");
}
fn print_context_status(&self, message: &str) {
self.tui.output(message);
}
fn print_tool_header(&self, tool_name: &str) {
// Start collecting tool output
*self.current_tool_start.lock().unwrap() = Some(Instant::now());
*self.current_tool_name.lock().unwrap() = Some(tool_name.to_string());
self.current_tool_output.lock().unwrap().clear();
self.current_tool_output
.lock()
.unwrap()
.push(format!("Tool: {}", tool_name));
// Initialize caption
*self.current_tool_caption.lock().unwrap() = String::new();
}
fn print_tool_arg(&self, key: &str, value: &str) {
// Filter out any keys that look like they might be agent message content
// (e.g., keys that are suspiciously long or contain message-like content)
let is_valid_arg_key = key.len() < 50
&& !key.contains('\n')
&& !key.contains("I'll")
&& !key.contains("Let me")
&& !key.contains("Here's")
&& !key.contains("I can");
if is_valid_arg_key {
self.current_tool_output
.lock()
.unwrap()
.push(format!("{}: {}", key, value));
}
// Build caption from first argument (usually the most important one)
let mut caption = self.current_tool_caption.lock().unwrap();
if caption.is_empty() && (key == "file_path" || key == "command" || key == "path") {
// Truncate long values for the caption
let truncated = if value.len() > 50 {
format!("{}...", &value[..47])
} else {
value.to_string()
};
// Add range information for read_file tool calls
let tool_name = self.current_tool_name.lock().unwrap();
let range_suffix = if tool_name.as_ref().map_or(false, |name| name == "read_file") {
// We need to check if start/end args will be provided - for now just check if this is a partial read
// This is a simplified approach since we're building the caption incrementally
String::new() // We'll handle this in print_tool_output_header instead
} else {
String::new()
};
*caption = format!("{}{}", truncated, range_suffix);
}
}
fn print_tool_output_header(&self) {
// This is called right before tool execution starts
// Send the initial tool header to the TUI now
if let Some(tool_name) = self.current_tool_name.lock().unwrap().as_ref() {
let mut caption = self.current_tool_caption.lock().unwrap().clone();
// Add range information for read_file tool calls
if tool_name == "read_file" {
// Check the tool output for start/end parameters
let output = self.current_tool_output.lock().unwrap();
let has_start = output.iter().any(|line| line.starts_with("start:"));
let has_end = output.iter().any(|line| line.starts_with("end:"));
if has_start || has_end {
let start_val = output.iter().find(|line| line.starts_with("start:")).map(|line| line.split(':').nth(1).unwrap_or("0").trim()).unwrap_or("0");
let end_val = output.iter().find(|line| line.starts_with("end:")).map(|line| line.split(':').nth(1).unwrap_or("end").trim()).unwrap_or("end");
caption = format!("{} [{}..{}]", caption, start_val, end_val);
}
}
// Send the tool output with initial header
self.tui.tool_output(tool_name, &caption, "");
}
self.current_tool_output.lock().unwrap().push(String::new());
self.current_tool_output
.lock()
.unwrap()
.push("Output:".to_string());
}
fn update_tool_output_line(&self, line: &str) {
// For retro mode, we'll just add to the output buffer
self.current_tool_output
.lock()
.unwrap()
.push(line.to_string());
}
fn print_tool_output_line(&self, line: &str) {
self.current_tool_output
.lock()
.unwrap()
.push(line.to_string());
}
fn print_tool_output_summary(&self, hidden_count: usize) {
self.current_tool_output.lock().unwrap().push(format!(
"... ({} more line{})",
hidden_count,
if hidden_count == 1 { "" } else { "s" }
));
}
fn print_tool_timing(&self, duration_str: &str) {
self.current_tool_output
.lock()
.unwrap()
.push(format!("⚡️ {}", duration_str));
// Calculate the actual duration
let duration_ms = if let Some(start) = *self.current_tool_start.lock().unwrap() {
start.elapsed().as_millis()
} else {
0
};
// Get the tool name and caption
if let Some(tool_name) = self.current_tool_name.lock().unwrap().as_ref() {
let content = self.current_tool_output.lock().unwrap().join("\n");
let caption = self.current_tool_caption.lock().unwrap().clone();
let caption = if caption.is_empty() {
"Completed".to_string()
} else {
caption
};
// Update the tool detail panel with the complete output without adding a new header
// This keeps the original header in place to be updated by tool_complete
self.tui.update_tool_detail(tool_name, &content);
// Determine success based on whether there's an error in the output
// This is a simple heuristic - you might want to make this more sophisticated
let success = !content.contains("error")
&& !content.contains("Error")
&& !content.contains("ERROR");
// Send the completion status to update the header
self.tui
.tool_complete(tool_name, success, duration_ms, &caption);
}
// Clear the buffers
*self.current_tool_name.lock().unwrap() = None;
self.current_tool_output.lock().unwrap().clear();
*self.current_tool_start.lock().unwrap() = None;
*self.current_tool_caption.lock().unwrap() = String::new();
}
fn print_agent_prompt(&self) {
self.tui.output("\n💬 ");
}
fn print_agent_response(&self, content: &str) {
self.tui.output(content);
}
fn notify_sse_received(&self) {
// Notify the TUI that an SSE was received
self.tui.sse_received();
}
fn flush(&self) {
// No-op for TUI since it handles its own rendering
}
}

View File

@@ -0,0 +1,47 @@
[package]
name = "g3-computer-control"
version = "0.1.0"
edition = "2021"
[build-dependencies]
# Only needed for building Swift bridge on macOS
[dependencies]
# Workspace dependencies
tokio = { workspace = true }
anyhow = { workspace = true }
thiserror = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tracing = { workspace = true }
uuid = { workspace = true }
shellexpand = "3.1"
# Async trait support
async-trait = "0.1"
# WebDriver support
fantoccini = "0.21"
# macOS dependencies
[target.'cfg(target_os = "macos")'.dependencies]
core-graphics = "0.23"
core-foundation = "0.10"
cocoa = "0.25"
objc = "0.2"
accessibility = "0.2"
image = "0.24"
# Linux dependencies
[target.'cfg(target_os = "linux")'.dependencies]
x11 = { version = "2.21", features = ["xlib", "xtest"] }
image = "0.24"
# Windows dependencies
[target.'cfg(target_os = "windows")'.dependencies]
windows = { version = "0.52", features = [
"Win32_Foundation",
"Win32_UI_WindowsAndMessaging",
"Win32_UI_Input_KeyboardAndMouse",
"Win32_Graphics_Gdi",
] }

View File

@@ -0,0 +1,63 @@
use std::env;
use std::path::PathBuf;
use std::process::Command;
fn main() {
// Only build Vision bridge on macOS
if env::var("CARGO_CFG_TARGET_OS").unwrap() != "macos" {
return;
}
println!("cargo:rerun-if-changed=vision-bridge/Sources/VisionBridge/VisionOCR.swift");
println!("cargo:rerun-if-changed=vision-bridge/Sources/VisionBridge/VisionBridge.h");
println!("cargo:rerun-if-changed=vision-bridge/Package.swift");
let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap());
let vision_bridge_dir = manifest_dir.join("vision-bridge");
// Build Swift package
println!("cargo:warning=Building VisionBridge Swift package...");
let build_status = Command::new("swift")
.args(&["build", "-c", "release"])
.current_dir(&vision_bridge_dir)
.status()
.expect("Failed to build Swift package");
if !build_status.success() {
panic!("Swift build failed");
}
// Find the built library
let lib_path = vision_bridge_dir
.join(".build/release")
.canonicalize()
.expect("Failed to find .build/release directory");
// Copy the dylib to the output directory so it can be found at runtime
let target_dir = manifest_dir.parent().unwrap().parent().unwrap().join("target");
let profile = env::var("PROFILE").unwrap_or_else(|_| "debug".to_string());
let output_dir = target_dir.join(&profile);
let dylib_src = lib_path.join("libVisionBridge.dylib");
let dylib_dst = output_dir.join("libVisionBridge.dylib");
std::fs::copy(&dylib_src, &dylib_dst)
.expect(&format!("Failed to copy dylib from {} to {}", dylib_src.display(), dylib_dst.display()));
println!("cargo:warning=Copied libVisionBridge.dylib to {}", dylib_dst.display());
// Add rpath so the dylib can be found at runtime
println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path");
println!("cargo:rustc-link-arg=-Wl,-rpath,@loader_path");
println!("cargo:rustc-link-search=native={}", lib_path.display());
println!("cargo:rustc-link-lib=dylib=VisionBridge");
// Link required frameworks
println!("cargo:rustc-link-lib=framework=Vision");
println!("cargo:rustc-link-lib=framework=AppKit");
println!("cargo:rustc-link-lib=framework=Foundation");
println!("cargo:rustc-link-lib=framework=CoreGraphics");
println!("cargo:rustc-link-lib=framework=CoreImage");
println!("cargo:warning=VisionBridge built successfully at {}", lib_path.display());
}

View File

@@ -0,0 +1,46 @@
use core_graphics::display::CGDisplay;
fn main() {
let display = CGDisplay::main();
let image = display.image().expect("Failed to capture screen");
println!("CGImage properties:");
println!(" Width: {}", image.width());
println!(" Height: {}", image.height());
println!(" Bits per component: {}", image.bits_per_component());
println!(" Bits per pixel: {}", image.bits_per_pixel());
println!(" Bytes per row: {}", image.bytes_per_row());
let data = image.data();
let expected_size = image.width() * image.height() * 4;
println!(" Data length: {}", data.len());
println!(" Expected (w*h*4): {}", expected_size);
// Check if there's padding in rows
let bytes_per_row = image.bytes_per_row();
let width = image.width();
let expected_bytes_per_row = width * 4;
println!("\nRow alignment:");
println!(" Actual bytes per row: {}", bytes_per_row);
println!(" Expected (width * 4): {}", expected_bytes_per_row);
println!(" Padding per row: {}", bytes_per_row - expected_bytes_per_row);
// Sample some pixels from different locations
println!("\nFirst 3 pixels (raw bytes):");
for i in 0..3 {
let offset = i * 4;
println!(" Pixel {}: [{:3}, {:3}, {:3}, {:3}]",
i, data[offset], data[offset+1], data[offset+2], data[offset+3]);
}
// Check a pixel from the middle
let mid_row = image.height() / 2;
let mid_col = image.width() / 2;
let mid_offset = (mid_row * bytes_per_row + mid_col * 4) as usize;
println!("\nMiddle pixel (row {}, col {}):", mid_row, mid_col);
println!(" Offset: {}", mid_offset);
if mid_offset + 3 < data.len() as usize {
println!(" Bytes: [{:3}, {:3}, {:3}, {:3}]",
data[mid_offset], data[mid_offset+1], data[mid_offset+2], data[mid_offset+3]);
}
}

View File

@@ -0,0 +1,56 @@
use core_graphics::window::{kCGWindowListOptionOnScreenOnly, kCGNullWindowID, CGWindowListCopyWindowInfo};
use core_foundation::dictionary::CFDictionary;
use core_foundation::string::CFString;
use core_foundation::base::{TCFType, ToVoid};
fn main() {
println!("Listing all on-screen windows...");
println!("{:<10} {:<25} {}", "Window ID", "Owner", "Title");
println!("{}", "-".repeat(80));
unsafe {
let window_list = CGWindowListCopyWindowInfo(
kCGWindowListOptionOnScreenOnly,
kCGNullWindowID
);
let count = core_foundation::array::CFArray::<CFDictionary>::wrap_under_create_rule(window_list).len();
let array = core_foundation::array::CFArray::<CFDictionary>::wrap_under_create_rule(window_list);
for i in 0..count {
let dict = array.get(i).unwrap();
// Get window ID
let window_id_key = CFString::from_static_string("kCGWindowNumber");
let window_id: i64 = if let Some(value) = dict.find(window_id_key.to_void()) {
let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _);
num.to_i64().unwrap_or(0)
} else {
0
};
// Get owner name
let owner_key = CFString::from_static_string("kCGWindowOwnerName");
let owner: String = if let Some(value) = dict.find(owner_key.to_void()) {
let s: CFString = TCFType::wrap_under_get_rule(*value as *const _);
s.to_string()
} else {
"Unknown".to_string()
};
// Get window name/title
let name_key = CFString::from_static_string("kCGWindowName");
let title: String = if let Some(value) = dict.find(name_key.to_void()) {
let s: CFString = TCFType::wrap_under_get_rule(*value as *const _);
s.to_string()
} else {
"".to_string()
};
// Show all windows
if !owner.is_empty() {
println!("{:<10} {:<25} {}", window_id, owner, title);
}
}
}
}

View File

@@ -0,0 +1,74 @@
//! Example demonstrating macOS Accessibility API tools
//!
//! This example shows how to use the macax tools to control macOS applications.
//!
//! Run with: cargo run --example macax_demo
use anyhow::Result;
use g3_computer_control::MacAxController;
#[tokio::main]
async fn main() -> Result<()> {
println!("🍎 macOS Accessibility API Demo\n");
println!("This demo shows how to control macOS applications using the Accessibility API.\n");
// Create controller
let controller = MacAxController::new()?;
println!("✅ MacAxController initialized\n");
// List running applications
println!("📱 Listing running applications:");
match controller.list_applications() {
Ok(apps) => {
for app in apps.iter().take(10) {
println!(" - {}", app.name);
}
if apps.len() > 10 {
println!(" ... and {} more", apps.len() - 10);
}
}
Err(e) => println!(" ❌ Error: {}", e),
}
println!();
// Get frontmost app
println!("🎯 Getting frontmost application:");
match controller.get_frontmost_app() {
Ok(app) => println!(" Current: {}", app.name),
Err(e) => println!(" ❌ Error: {}", e),
}
println!();
// Example: Activate Finder and get its UI tree
println!("📂 Activating Finder and inspecting UI:");
match controller.activate_app("Finder") {
Ok(_) => {
println!(" ✅ Finder activated");
// Wait a moment for activation
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
// Get UI tree
match controller.get_ui_tree("Finder", 2) {
Ok(tree) => {
println!("\n UI Tree:");
for line in tree.lines().take(10) {
println!(" {}", line);
}
}
Err(e) => println!(" ❌ Error getting UI tree: {}", e),
}
}
Err(e) => println!(" ❌ Error: {}", e),
}
println!();
println!("✨ Demo complete!\n");
println!("💡 Tips:");
println!(" - Use --macax flag with g3 to enable these tools");
println!(" - Grant accessibility permissions in System Preferences");
println!(" - Add accessibility identifiers to your apps for easier automation");
println!(" - See docs/macax-tools.md for full documentation\n");
Ok(())
}

View File

@@ -0,0 +1,64 @@
use g3_computer_control::SafariDriver;
use g3_computer_control::webdriver::WebDriverController;
use anyhow::Result;
#[tokio::main]
async fn main() -> Result<()> {
println!("Safari WebDriver Demo");
println!("=====================\n");
println!("Make sure to:");
println!("1. Enable 'Allow Remote Automation' in Safari's Develop menu");
println!("2. Run: /usr/bin/safaridriver --enable");
println!("3. Start safaridriver in another terminal: safaridriver --port 4444\n");
println!("Connecting to SafariDriver...");
let mut driver = SafariDriver::new().await?;
println!("✅ Connected!\n");
// Navigate to a website
println!("Navigating to example.com...");
driver.navigate("https://example.com").await?;
println!("✅ Navigated\n");
// Get page title
let title = driver.title().await?;
println!("Page title: {}\n", title);
// Get current URL
let url = driver.current_url().await?;
println!("Current URL: {}\n", url);
// Find an element
println!("Finding h1 element...");
let h1 = driver.find_element("h1").await?;
let h1_text = h1.text().await?;
println!("H1 text: {}\n", h1_text);
// Find all paragraphs
println!("Finding all paragraphs...");
let paragraphs = driver.find_elements("p").await?;
println!("Found {} paragraphs\n", paragraphs.len());
// Get page source
println!("Getting page source...");
let source = driver.page_source().await?;
println!("Page source length: {} bytes\n", source.len());
// Execute JavaScript
println!("Executing JavaScript...");
let result = driver.execute_script("return document.title", vec![]).await?;
println!("JS result: {:?}\n", result);
// Take a screenshot
println!("Taking screenshot...");
driver.screenshot("/tmp/safari_demo.png").await?;
println!("✅ Screenshot saved to /tmp/safari_demo.png\n");
// Close the browser
println!("Closing browser...");
driver.quit().await?;
println!("✅ Done!");
Ok(())
}

View File

@@ -0,0 +1,21 @@
use g3_computer_control::create_controller;
#[tokio::main]
async fn main() {
println!("Testing screenshot with permission prompt...");
let controller = create_controller().expect("Failed to create controller");
match controller.take_screenshot("/tmp/test_with_prompt.png", None, None).await {
Ok(_) => {
println!("\n✅ Screenshot saved to /tmp/test_with_prompt.png");
println!("Opening screenshot...");
let _ = std::process::Command::new("open")
.arg("/tmp/test_with_prompt.png")
.spawn();
}
Err(e) => {
println!("❌ Screenshot failed: {}", e);
}
}
}

View File

@@ -0,0 +1,39 @@
use std::process::Command;
fn main() {
let path = "/tmp/rust_screencapture_test.png";
println!("Testing screencapture command from Rust...");
let mut cmd = Command::new("screencapture");
cmd.arg("-x"); // No sound
cmd.arg(path);
println!("Command: {:?}", cmd);
match cmd.output() {
Ok(output) => {
println!("Exit status: {}", output.status);
println!("Stdout: {}", String::from_utf8_lossy(&output.stdout));
println!("Stderr: {}", String::from_utf8_lossy(&output.stderr));
if output.status.success() {
println!("\n✅ Screenshot saved to: {}", path);
// Check file exists and size
if let Ok(metadata) = std::fs::metadata(path) {
println!("File size: {} bytes ({:.1} MB)", metadata.len(), metadata.len() as f64 / 1_000_000.0);
}
// Open it
let _ = Command::new("open").arg(path).spawn();
println!("\nOpened screenshot - please verify it looks correct!");
} else {
println!("\n❌ Screenshot failed!");
}
}
Err(e) => {
println!("❌ Failed to execute screencapture: {}", e);
}
}
}

View File

@@ -0,0 +1,68 @@
use core_graphics::display::CGDisplay;
use image::{ImageBuffer, RgbaImage};
fn main() {
let display = CGDisplay::main();
let image = display.image().expect("Failed to capture screen");
let width = image.width() as u32;
let height = image.height() as u32;
let bytes_per_row = image.bytes_per_row() as usize;
let data = image.data();
println!("Testing screenshot fix...");
println!("Image: {}x{}, bytes_per_row: {}", width, height, bytes_per_row);
println!("Expected bytes per row: {}", width * 4);
println!("Padding per row: {} bytes", bytes_per_row - (width as usize * 4));
// OLD METHOD (broken) - treating data as continuous
println!("\n=== OLD METHOD (BROKEN) ===");
let mut old_rgba = Vec::with_capacity(data.len() as usize);
for chunk in data.chunks_exact(4) {
old_rgba.push(chunk[2]); // R
old_rgba.push(chunk[1]); // G
old_rgba.push(chunk[0]); // B
old_rgba.push(chunk[3]); // A
}
println!("Converted {} pixels", old_rgba.len() / 4);
println!("Expected {} pixels", width * height);
// NEW METHOD (fixed) - handling row padding
println!("\n=== NEW METHOD (FIXED) ===");
let mut new_rgba = Vec::with_capacity((width * height * 4) as usize);
for row in 0..height as usize {
let row_start = row * bytes_per_row;
let row_end = row_start + (width as usize * 4);
for chunk in data[row_start..row_end].chunks_exact(4) {
new_rgba.push(chunk[2]); // R
new_rgba.push(chunk[1]); // G
new_rgba.push(chunk[0]); // B
new_rgba.push(chunk[3]); // A
}
}
println!("Converted {} pixels", new_rgba.len() / 4);
println!("Expected {} pixels", width * height);
// Save a small crop from both methods
let crop_size = 200;
// Old method crop
let old_crop: Vec<u8> = old_rgba.iter().take((crop_size * crop_size * 4) as usize).copied().collect();
if let Some(old_img) = ImageBuffer::from_raw(crop_size, crop_size, old_crop) {
let old_img: RgbaImage = old_img;
old_img.save("/tmp/screenshot_old_method.png").unwrap();
println!("\nSaved OLD method crop to: /tmp/screenshot_old_method.png");
}
// New method crop
let new_crop: Vec<u8> = new_rgba.iter().take((crop_size * crop_size * 4) as usize).copied().collect();
if let Some(new_img) = ImageBuffer::from_raw(crop_size, crop_size, new_crop) {
let new_img: RgbaImage = new_img;
new_img.save("/tmp/screenshot_new_method.png").unwrap();
println!("Saved NEW method crop to: /tmp/screenshot_new_method.png");
}
println!("\nOpen both images to compare:");
println!(" open /tmp/screenshot_old_method.png /tmp/screenshot_new_method.png");
}

View File

@@ -0,0 +1,48 @@
//! Test the new type_text functionality
use anyhow::Result;
use g3_computer_control::MacAxController;
#[tokio::main]
async fn main() -> Result<()> {
println!("🧪 Testing macax type_text functionality\n");
let controller = MacAxController::new()?;
println!("✅ Controller initialized\n");
// Test 1: Type simple text
println!("Test 1: Typing simple text into TextEdit");
println!(" Please open TextEdit and create a new document...");
std::thread::sleep(std::time::Duration::from_secs(3));
match controller.type_text("TextEdit", "Hello, World!") {
Ok(_) => println!(" ✅ Successfully typed simple text\n"),
Err(e) => println!(" ❌ Failed: {}\n", e),
}
std::thread::sleep(std::time::Duration::from_secs(1));
// Test 2: Type unicode and emojis
println!("Test 2: Typing unicode and emojis");
match controller.type_text("TextEdit", "\n🌟 Unicode test: café, naïve, 日本語 🎉") {
Ok(_) => println!(" ✅ Successfully typed unicode text\n"),
Err(e) => println!(" ❌ Failed: {}\n", e),
}
std::thread::sleep(std::time::Duration::from_secs(1));
// Test 3: Type special characters
println!("Test 3: Typing special characters");
match controller.type_text("TextEdit", "\nSpecial: @#$%^&*()_+-=[]{}|;':,.<>?/") {
Ok(_) => println!(" ✅ Successfully typed special characters\n"),
Err(e) => println!(" ❌ Failed: {}\n", e),
}
println!("\n✨ Tests complete!");
println!("\n💡 Now try with Things3:");
println!(" 1. Open Things3");
println!(" 2. Press Cmd+N to create a new task");
println!(" 3. Run: g3 --macax 'type \"🌟 My awesome task\" into Things'");
Ok(())
}

View File

@@ -0,0 +1,85 @@
use g3_computer_control::ocr::{OCREngine, DefaultOCR};
use anyhow::Result;
#[tokio::main]
async fn main() -> Result<()> {
println!("🧪 Testing Apple Vision OCR");
println!("===========================\n");
// Initialize OCR engine
println!("📦 Initializing OCR engine...");
let ocr = DefaultOCR::new()?;
println!("✅ OCR engine: {}\n", ocr.name());
// Check if test image exists
let test_image = "/tmp/safari_test.png";
if !std::path::Path::new(test_image).exists() {
println!("⚠️ Test image not found: {}", test_image);
println!(" Creating a screenshot...");
let status = std::process::Command::new("screencapture")
.arg("-x")
.arg("-R")
.arg("0,0,1200,800")
.arg(test_image)
.status()?;
if !status.success() {
anyhow::bail!("Failed to create screenshot");
}
println!("✅ Screenshot created\n");
}
// Run OCR
println!("🔍 Running Apple Vision OCR on {}...", test_image);
let start = std::time::Instant::now();
let locations = ocr.extract_text_with_locations(test_image).await?;
let duration = start.elapsed();
println!("✅ OCR completed in {:.3}s\n", duration.as_secs_f64());
// Display results
println!("📊 Results:");
println!(" Found {} text elements\n", locations.len());
if locations.is_empty() {
println!("⚠️ No text found in image");
} else {
println!(" Top 20 results:");
println!(" {:<4} {:<40} {:<15} {:<12} {:<8}", "#", "Text", "Position", "Size", "Conf");
println!(" {}", "-".repeat(85));
for (i, loc) in locations.iter().take(20).enumerate() {
let text = if loc.text.len() > 37 {
format!("{}...", &loc.text[..37])
} else {
loc.text.clone()
};
println!(" {:<4} {:<40} ({:>4},{:>4}) {:>4}x{:<4} {:.2}",
i + 1,
text,
loc.x,
loc.y,
loc.width,
loc.height,
loc.confidence
);
}
if locations.len() > 20 {
println!("\n ... and {} more", locations.len() - 20);
}
// Performance comparison
println!("\n📈 Performance:");
println!(" OCR Speed: {:.3}s", duration.as_secs_f64());
println!(" Text elements: {}", locations.len());
println!(" Avg per element: {:.1}ms", duration.as_millis() as f64 / locations.len() as f64);
}
println!("\n✅ Test complete!");
Ok(())
}

View File

@@ -0,0 +1,45 @@
use g3_computer_control::create_controller;
#[tokio::main]
async fn main() {
println!("Testing window-specific screenshot capture...");
let controller = create_controller().expect("Failed to create controller");
// Test 1: Capture iTerm2 window
println!("\n1. Capturing iTerm2 window...");
match controller.take_screenshot("/tmp/iterm_window.png", None, Some("iTerm2")).await {
Ok(_) => {
println!(" ✅ iTerm2 window captured to /tmp/iterm_window.png");
let _ = std::process::Command::new("open").arg("/tmp/iterm_window.png").spawn();
}
Err(e) => println!(" ❌ Failed: {}", e),
}
// Wait a moment for the image to open
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
// Test 2: Full screen capture for comparison
println!("\n2. Capturing full screen for comparison...");
match controller.take_screenshot("/tmp/fullscreen.png", None, None).await {
Ok(_) => {
println!(" ✅ Full screen captured to /tmp/fullscreen.png");
let _ = std::process::Command::new("open").arg("/tmp/fullscreen.png").spawn();
}
Err(e) => println!(" ❌ Failed: {}", e),
}
println!("\n=== Comparison ===");
println!("iTerm window: /tmp/iterm_window.png (should show ONLY iTerm window)");
println!("Full screen: /tmp/fullscreen.png (should show entire desktop)");
// Show file sizes
if let Ok(meta1) = std::fs::metadata("/tmp/iterm_window.png") {
if let Ok(meta2) = std::fs::metadata("/tmp/fullscreen.png") {
println!("\nFile sizes:");
println!(" iTerm window: {:.1} MB", meta1.len() as f64 / 1_000_000.0);
println!(" Full screen: {:.1} MB", meta2.len() as f64 / 1_000_000.0);
println!("\nWindow capture should be smaller than full screen.");
}
}
}

View File

@@ -0,0 +1,49 @@
// Suppress warnings from objc crate macros
#![allow(unexpected_cfgs)]
pub mod types;
pub mod platform;
pub mod ocr;
pub mod webdriver;
pub mod macax;
// Re-export webdriver types for convenience
pub use webdriver::{WebDriverController, WebElement, safari::SafariDriver};
// Re-export macax types for convenience
pub use macax::{MacAxController, AXElement, AXApplication};
use anyhow::Result;
use async_trait::async_trait;
use types::*;
#[async_trait]
pub trait ComputerController: Send + Sync {
// Screen capture
async fn take_screenshot(&self, path: &str, region: Option<Rect>, window_id: Option<&str>) -> Result<()>;
// OCR operations
async fn extract_text_from_screen(&self, region: Rect, window_id: &str) -> Result<String>;
async fn extract_text_from_image(&self, path: &str) -> Result<String>;
async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>>;
async fn find_text_in_app(&self, app_name: &str, search_text: &str) -> Result<Option<TextLocation>>;
// Mouse operations
fn move_mouse(&self, x: i32, y: i32) -> Result<()>;
fn click_at(&self, x: i32, y: i32, app_name: Option<&str>) -> Result<()>;
}
// Platform-specific constructor
pub fn create_controller() -> Result<Box<dyn ComputerController>> {
#[cfg(target_os = "macos")]
return Ok(Box::new(platform::macos::MacOSController::new()?));
#[cfg(target_os = "linux")]
return Ok(Box::new(platform::linux::LinuxController::new()?));
#[cfg(target_os = "windows")]
return Ok(Box::new(platform::windows::WindowsController::new()?));
#[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
anyhow::bail!("Unsupported platform")
}

View File

@@ -0,0 +1,822 @@
use super::{AXApplication, AXElement};
use anyhow::{Context, Result};
use std::collections::HashMap;
#[cfg(target_os = "macos")]
use accessibility::{AXUIElement, AXUIElementAttributes, ElementFinder, TreeVisitor, TreeWalker, TreeWalkerFlow};
#[cfg(target_os = "macos")]
use core_foundation::base::TCFType;
#[cfg(target_os = "macos")]
use core_foundation::string::CFString;
/// macOS Accessibility API controller using native APIs
pub struct MacAxController {
// Cache for application elements
app_cache: std::sync::Mutex<HashMap<String, AXUIElement>>,
}
impl MacAxController {
pub fn new() -> Result<Self> {
#[cfg(target_os = "macos")]
{
// Check if we have accessibility permissions by trying to get system-wide element
let _system = AXUIElement::system_wide();
Ok(Self {
app_cache: std::sync::Mutex::new(HashMap::new()),
})
}
#[cfg(not(target_os = "macos"))]
{
anyhow::bail!("macOS Accessibility API is only available on macOS")
}
}
/// List all running applications
#[cfg(target_os = "macos")]
pub fn list_applications(&self) -> Result<Vec<AXApplication>> {
let apps = Self::get_running_applications()?;
Ok(apps)
}
#[cfg(not(target_os = "macos"))]
pub fn list_applications(&self) -> Result<Vec<AXApplication>> {
anyhow::bail!("Not supported on this platform")
}
#[cfg(target_os = "macos")]
fn get_running_applications() -> Result<Vec<AXApplication>> {
use cocoa::appkit::NSApplicationActivationPolicy;
use cocoa::base::{id, nil};
use objc::{class, msg_send, sel, sel_impl};
unsafe {
let workspace: id = msg_send![class!(NSWorkspace), sharedWorkspace];
let running_apps: id = msg_send![workspace, runningApplications];
let count: usize = msg_send![running_apps, count];
let mut apps = Vec::new();
for i in 0..count {
let app: id = msg_send![running_apps, objectAtIndex: i];
// Get app name
let localized_name: id = msg_send![app, localizedName];
if localized_name == nil {
continue;
}
let name_ptr: *const i8 = msg_send![localized_name, UTF8String];
let name = if !name_ptr.is_null() {
std::ffi::CStr::from_ptr(name_ptr)
.to_string_lossy()
.to_string()
} else {
continue;
};
// Get bundle ID
let bundle_id_obj: id = msg_send![app, bundleIdentifier];
let bundle_id = if bundle_id_obj != nil {
let bundle_id_ptr: *const i8 = msg_send![bundle_id_obj, UTF8String];
if !bundle_id_ptr.is_null() {
Some(
std::ffi::CStr::from_ptr(bundle_id_ptr)
.to_string_lossy()
.to_string(),
)
} else {
None
}
} else {
None
};
// Get PID
let pid: i32 = msg_send![app, processIdentifier];
// Skip background-only apps
let activation_policy: i64 = msg_send![app, activationPolicy];
if activation_policy == NSApplicationActivationPolicy::NSApplicationActivationPolicyRegular as i64 {
apps.push(AXApplication {
name,
bundle_id,
pid,
});
}
}
Ok(apps)
}
}
/// Get the frontmost (active) application
#[cfg(target_os = "macos")]
pub fn get_frontmost_app(&self) -> Result<AXApplication> {
use cocoa::base::{id, nil};
use objc::{class, msg_send, sel, sel_impl};
unsafe {
let workspace: id = msg_send![class!(NSWorkspace), sharedWorkspace];
let frontmost_app: id = msg_send![workspace, frontmostApplication];
if frontmost_app == nil {
anyhow::bail!("No frontmost application");
}
// Get app name
let localized_name: id = msg_send![frontmost_app, localizedName];
let name_ptr: *const i8 = msg_send![localized_name, UTF8String];
let name = std::ffi::CStr::from_ptr(name_ptr)
.to_string_lossy()
.to_string();
// Get bundle ID
let bundle_id_obj: id = msg_send![frontmost_app, bundleIdentifier];
let bundle_id = if bundle_id_obj != nil {
let bundle_id_ptr: *const i8 = msg_send![bundle_id_obj, UTF8String];
if !bundle_id_ptr.is_null() {
Some(
std::ffi::CStr::from_ptr(bundle_id_ptr)
.to_string_lossy()
.to_string(),
)
} else {
None
}
} else {
None
};
// Get PID
let pid: i32 = msg_send![frontmost_app, processIdentifier];
Ok(AXApplication {
name,
bundle_id,
pid,
})
}
}
#[cfg(not(target_os = "macos"))]
pub fn get_frontmost_app(&self) -> Result<AXApplication> {
anyhow::bail!("Not supported on this platform")
}
/// Get AXUIElement for an application by name or PID
#[cfg(target_os = "macos")]
fn get_app_element(&self, app_name: &str) -> Result<AXUIElement> {
// Check cache first
{
let cache = self.app_cache.lock().unwrap();
if let Some(element) = cache.get(app_name) {
return Ok(element.clone());
}
}
// Find the app by name
let apps = Self::get_running_applications()?;
let app = apps
.iter()
.find(|a| a.name == app_name)
.ok_or_else(|| anyhow::anyhow!("Application '{}' not found", app_name))?;
// Create AXUIElement for the app
let element = AXUIElement::application(app.pid);
// Cache it
{
let mut cache = self.app_cache.lock().unwrap();
cache.insert(app_name.to_string(), element.clone());
}
Ok(element)
}
/// Activate (bring to front) an application
#[cfg(target_os = "macos")]
pub fn activate_app(&self, app_name: &str) -> Result<()> {
use cocoa::base::id;
use objc::{class, msg_send, sel, sel_impl};
// Find the app
let apps = Self::get_running_applications()?;
let app = apps
.iter()
.find(|a| a.name == app_name)
.ok_or_else(|| anyhow::anyhow!("Application '{}' not found", app_name))?;
unsafe {
let workspace: id = msg_send![class!(NSWorkspace), sharedWorkspace];
let running_apps: id = msg_send![workspace, runningApplications];
let count: usize = msg_send![running_apps, count];
for i in 0..count {
let running_app: id = msg_send![running_apps, objectAtIndex: i];
let pid: i32 = msg_send![running_app, processIdentifier];
if pid == app.pid {
let _: bool = msg_send![running_app, activateWithOptions: 0];
return Ok(());
}
}
}
anyhow::bail!("Failed to activate application")
}
#[cfg(not(target_os = "macos"))]
pub fn activate_app(&self, _app_name: &str) -> Result<()> {
anyhow::bail!("Not supported on this platform")
}
/// Get the UI hierarchy of an application
#[cfg(target_os = "macos")]
pub fn get_ui_tree(&self, app_name: &str, max_depth: usize) -> Result<String> {
let app_element = self.get_app_element(app_name)?;
let mut output = format!("Application: {}\n", app_name);
Self::build_ui_tree(&app_element, &mut output, 0, max_depth)?;
Ok(output)
}
#[cfg(not(target_os = "macos"))]
pub fn get_ui_tree(&self, _app_name: &str, _max_depth: usize) -> Result<String> {
anyhow::bail!("Not supported on this platform")
}
#[cfg(target_os = "macos")]
fn build_ui_tree(
element: &AXUIElement,
output: &mut String,
depth: usize,
max_depth: usize,
) -> Result<()> {
if depth >= max_depth {
return Ok(());
}
let indent = " ".repeat(depth);
// Get role
let role = element.role().ok().map(|s| s.to_string())
.unwrap_or_else(|| "Unknown".to_string());
// Get title
let title = element.title().ok()
.map(|s| s.to_string());
// Get identifier
let identifier = element.identifier().ok()
.map(|s| s.to_string());
// Format output
output.push_str(&format!("{}Role: {}", indent, role));
if let Some(t) = title {
output.push_str(&format!(", Title: {}", t));
}
if let Some(id) = identifier {
output.push_str(&format!(", ID: {}", id));
}
output.push('\n');
// Get children
if let Ok(children) = element.children() {
for i in 0..children.len() {
if let Some(child) = children.get(i) {
let _ = Self::build_ui_tree(&child, output, depth + 1, max_depth);
}
}
}
Ok(())
}
/// Find UI elements in an application
#[cfg(target_os = "macos")]
pub fn find_elements(
&self,
app_name: &str,
role: Option<&str>,
title: Option<&str>,
identifier: Option<&str>,
) -> Result<Vec<AXElement>> {
let app_element = self.get_app_element(app_name)?;
let mut found_elements = Vec::new();
let visitor = ElementCollector {
role_filter: role.map(|s| s.to_string()),
title_filter: title.map(|s| s.to_string()),
identifier_filter: identifier.map(|s| s.to_string()),
results: std::cell::RefCell::new(&mut found_elements),
depth: std::cell::Cell::new(0),
};
let walker = TreeWalker::new();
walker.walk(&app_element, &visitor);
Ok(found_elements)
}
#[cfg(not(target_os = "macos"))]
pub fn find_elements(
&self,
_app_name: &str,
_role: Option<&str>,
_title: Option<&str>,
_identifier: Option<&str>,
) -> Result<Vec<AXElement>> {
anyhow::bail!("Not supported on this platform")
}
/// Find a single element (helper for click, set_value, etc.)
#[cfg(target_os = "macos")]
fn find_element(
&self,
app_name: &str,
role: &str,
title: Option<&str>,
identifier: Option<&str>,
) -> Result<AXUIElement> {
let app_element = self.get_app_element(app_name)?;
let role_str = role.to_string();
let title_str = title.map(|s| s.to_string());
let identifier_str = identifier.map(|s| s.to_string());
let finder = ElementFinder::new(
&app_element,
move |element| {
// Check role
let elem_role = element.role()
.ok()
.map(|s| s.to_string());
if let Some(r) = elem_role {
if !r.contains(&role_str) {
return false;
}
} else {
return false;
}
// Check title if specified
if let Some(ref title_filter) = title_str {
let elem_title = element.title()
.ok()
.map(|s| s.to_string());
if let Some(t) = elem_title {
if !t.contains(title_filter) {
return false;
}
} else {
return false;
}
}
// Check identifier if specified
if let Some(ref id_filter) = identifier_str {
let elem_id = element.identifier()
.ok()
.map(|s| s.to_string());
if let Some(id) = elem_id {
if !id.contains(id_filter) {
return false;
}
} else {
return false;
}
}
true
},
Some(std::time::Duration::from_secs(2)),
);
finder.find().context("Element not found")
}
/// Click on a UI element
#[cfg(target_os = "macos")]
pub fn click_element(
&self,
app_name: &str,
role: &str,
title: Option<&str>,
identifier: Option<&str>,
) -> Result<()> {
let element = self.find_element(app_name, role, title, identifier)?;
// Perform the press action
let action_name = CFString::new("AXPress");
element
.perform_action(&action_name)
.map_err(|e| anyhow::anyhow!("Failed to perform press action: {:?}", e))?;
Ok(())
}
#[cfg(not(target_os = "macos"))]
pub fn click_element(
&self,
_app_name: &str,
_role: &str,
_title: Option<&str>,
_identifier: Option<&str>,
) -> Result<()> {
anyhow::bail!("Not supported on this platform")
}
/// Set the value of a UI element
#[cfg(target_os = "macos")]
pub fn set_value(
&self,
app_name: &str,
role: &str,
value: &str,
title: Option<&str>,
identifier: Option<&str>,
) -> Result<()> {
let element = self.find_element(app_name, role, title, identifier)?;
// Set the value - convert CFString to CFType
let cf_value = CFString::new(value);
element.set_value(cf_value.as_CFType())
.map_err(|e| anyhow::anyhow!("Failed to set value: {:?}", e))?;
Ok(())
}
#[cfg(not(target_os = "macos"))]
pub fn set_value(
&self,
_app_name: &str,
_role: &str,
_value: &str,
_title: Option<&str>,
_identifier: Option<&str>,
) -> Result<()> {
anyhow::bail!("Not supported on this platform")
}
/// Get the value of a UI element
#[cfg(target_os = "macos")]
pub fn get_value(
&self,
app_name: &str,
role: &str,
title: Option<&str>,
identifier: Option<&str>,
) -> Result<String> {
let element = self.find_element(app_name, role, title, identifier)?;
// Get the value
let value_type = element.value()
.map_err(|e| anyhow::anyhow!("Failed to get value: {:?}", e))?;
// Try to downcast to CFString
if let Some(cf_string) = value_type.downcast::<CFString>() {
Ok(cf_string.to_string())
} else {
// For non-string values, try to get a description
Ok(format!("<non-string value>"))
}
}
#[cfg(not(target_os = "macos"))]
pub fn get_value(
&self,
_app_name: &str,
_role: &str,
_title: Option<&str>,
_identifier: Option<&str>,
) -> Result<String> {
anyhow::bail!("Not supported on this platform")
}
/// Type text into the currently focused element (uses system text input)
#[cfg(target_os = "macos")]
pub fn type_text(&self, app_name: &str, text: &str) -> Result<()> {
use cocoa::base::{id, nil};
use cocoa::foundation::NSString;
use objc::{class, msg_send, sel, sel_impl};
// First, make sure the app is active
self.activate_app(app_name)?;
// Wait for app to fully activate
std::thread::sleep(std::time::Duration::from_millis(500));
// Send a Tab key to try to focus on a text field
// This helps ensure something is focused before we paste
let _ = self.press_key(app_name, "tab", vec![]);
std::thread::sleep(std::time::Duration::from_millis(800));
// Save old clipboard, set new content, paste, then restore
let old_content: id;
unsafe {
// Get the general pasteboard
let pasteboard: id = msg_send![class!(NSPasteboard), generalPasteboard];
// Save current clipboard content
let ns_string_type = NSString::alloc(nil).init_str("public.utf8-plain-text");
old_content = msg_send![pasteboard, stringForType: ns_string_type];
// Clear and set new content
let _: () = msg_send![pasteboard, clearContents];
let ns_string = NSString::alloc(nil).init_str(text);
let ns_type = NSString::alloc(nil).init_str("public.utf8-plain-text");
let _: bool = msg_send![pasteboard, setString:ns_string forType:ns_type];
}
// Wait a moment for clipboard to update
std::thread::sleep(std::time::Duration::from_millis(200));
// Paste using Cmd+V (outside unsafe block)
self.press_key(app_name, "v", vec!["command"])?;
// Wait for paste to complete
std::thread::sleep(std::time::Duration::from_millis(300));
// Restore old clipboard content if it existed
unsafe {
if old_content != nil {
let pasteboard: id = msg_send![class!(NSPasteboard), generalPasteboard];
let _: () = msg_send![pasteboard, clearContents];
let ns_type = NSString::alloc(nil).init_str("public.utf8-plain-text");
let _: bool = msg_send![pasteboard, setString:old_content forType:ns_type];
}
}
Ok(())
}
#[cfg(not(target_os = "macos"))]
pub fn type_text(&self, _app_name: &str, _text: &str) -> Result<()> {
anyhow::bail!("Not supported on this platform")
}
/// Focus on a text field or text area element
#[cfg(target_os = "macos")]
pub fn focus_element(
&self,
app_name: &str,
role: &str,
title: Option<&str>,
identifier: Option<&str>,
) -> Result<()> {
let element = self.find_element(app_name, role, title, identifier)?;
// Set focused attribute to true
use core_foundation::boolean::CFBoolean;
let cf_true = CFBoolean::true_value();
element.set_attribute(&accessibility::AXAttribute::focused(), cf_true)
.map_err(|e| anyhow::anyhow!("Failed to focus element: {:?}", e))?;
Ok(())
}
/// Press a keyboard shortcut
#[cfg(target_os = "macos")]
pub fn press_key(
&self,
app_name: &str,
key: &str,
modifiers: Vec<&str>,
) -> Result<()> {
use core_graphics::event::{
CGEvent, CGEventFlags, CGEventTapLocation,
};
use core_graphics::event_source::{CGEventSource, CGEventSourceStateID};
// First, make sure the app is active
self.activate_app(app_name)?;
// Wait a bit for activation
std::thread::sleep(std::time::Duration::from_millis(100));
// Map key string to key code
let key_code = Self::key_to_keycode(key)
.ok_or_else(|| anyhow::anyhow!("Unknown key: {}", key))?;
// Map modifiers to flags
let mut flags = CGEventFlags::CGEventFlagNull;
for modifier in modifiers {
match modifier.to_lowercase().as_str() {
"command" | "cmd" => flags |= CGEventFlags::CGEventFlagCommand,
"option" | "alt" => flags |= CGEventFlags::CGEventFlagAlternate,
"control" | "ctrl" => flags |= CGEventFlags::CGEventFlagControl,
"shift" => flags |= CGEventFlags::CGEventFlagShift,
_ => {}
}
}
// Create event source
let source = CGEventSource::new(CGEventSourceStateID::HIDSystemState)
.ok().context("Failed to create event source")?;
// Create key down event
let key_down = CGEvent::new_keyboard_event(source.clone(), key_code, true)
.ok().context("Failed to create key down event")?;
key_down.set_flags(flags);
// Create key up event
let key_up = CGEvent::new_keyboard_event(source, key_code, false)
.ok().context("Failed to create key up event")?;
key_up.set_flags(flags);
// Post events
key_down.post(CGEventTapLocation::HID);
std::thread::sleep(std::time::Duration::from_millis(50));
key_up.post(CGEventTapLocation::HID);
Ok(())
}
#[cfg(not(target_os = "macos"))]
pub fn press_key(
&self,
_app_name: &str,
_key: &str,
_modifiers: Vec<&str>,
) -> Result<()> {
anyhow::bail!("Not supported on this platform")
}
#[cfg(target_os = "macos")]
fn key_to_keycode(key: &str) -> Option<u16> {
// Map common keys to keycodes
// See: https://eastmanreference.com/complete-list-of-applescript-key-codes
match key.to_lowercase().as_str() {
"a" => Some(0x00),
"s" => Some(0x01),
"d" => Some(0x02),
"f" => Some(0x03),
"h" => Some(0x04),
"g" => Some(0x05),
"z" => Some(0x06),
"x" => Some(0x07),
"c" => Some(0x08),
"v" => Some(0x09),
"b" => Some(0x0B),
"q" => Some(0x0C),
"w" => Some(0x0D),
"e" => Some(0x0E),
"r" => Some(0x0F),
"y" => Some(0x10),
"t" => Some(0x11),
"1" => Some(0x12),
"2" => Some(0x13),
"3" => Some(0x14),
"4" => Some(0x15),
"6" => Some(0x16),
"5" => Some(0x17),
"=" => Some(0x18),
"9" => Some(0x19),
"7" => Some(0x1A),
"-" => Some(0x1B),
"8" => Some(0x1C),
"0" => Some(0x1D),
"]" => Some(0x1E),
"o" => Some(0x1F),
"u" => Some(0x20),
"[" => Some(0x21),
"i" => Some(0x22),
"p" => Some(0x23),
"return" | "enter" => Some(0x24),
"l" => Some(0x25),
"j" => Some(0x26),
"'" => Some(0x27),
"k" => Some(0x28),
";" => Some(0x29),
"\\" => Some(0x2A),
"," => Some(0x2B),
"/" => Some(0x2C),
"n" => Some(0x2D),
"m" => Some(0x2E),
"." => Some(0x2F),
"tab" => Some(0x30),
"space" => Some(0x31),
"`" => Some(0x32),
"delete" | "backspace" => Some(0x33),
"escape" | "esc" => Some(0x35),
"f1" => Some(0x7A),
"f2" => Some(0x78),
"f3" => Some(0x63),
"f4" => Some(0x76),
"f5" => Some(0x60),
"f6" => Some(0x61),
"f7" => Some(0x62),
"f8" => Some(0x64),
"f9" => Some(0x65),
"f10" => Some(0x6D),
"f11" => Some(0x67),
"f12" => Some(0x6F),
"left" => Some(0x7B),
"right" => Some(0x7C),
"down" => Some(0x7D),
"up" => Some(0x7E),
_ => None,
}
}
}
#[cfg(target_os = "macos")]
struct ElementCollector<'a> {
role_filter: Option<String>,
title_filter: Option<String>,
identifier_filter: Option<String>,
results: std::cell::RefCell<&'a mut Vec<AXElement>>,
depth: std::cell::Cell<usize>,
}
#[cfg(target_os = "macos")]
impl<'a> TreeVisitor for ElementCollector<'a> {
fn enter_element(&self, element: &AXUIElement) -> TreeWalkerFlow {
self.depth.set(self.depth.get() + 1);
if self.depth.get() > 20 {
return TreeWalkerFlow::SkipSubtree;
}
// Get element properties
let role = element.role()
.ok()
.map(|s| s.to_string())
.unwrap_or_else(|| "Unknown".to_string());
let title = element.title()
.ok()
.map(|s| s.to_string());
let identifier = element.identifier()
.ok()
.map(|s| s.to_string());
// Check if this element matches the filters
let role_matches = self.role_filter.as_ref().map_or(true, |r| role.contains(r));
let title_matches = self.title_filter.as_ref().map_or(true, |t| {
title.as_ref().map_or(false, |title_str| title_str.contains(t))
});
let identifier_matches = self.identifier_filter.as_ref().map_or(true, |id| {
identifier.as_ref().map_or(false, |id_str| id_str.contains(id))
});
if role_matches && title_matches && identifier_matches {
// Get additional properties
let value = element.value()
.ok()
.and_then(|v| {
v.downcast::<CFString>().map(|s| s.to_string())
});
let label = element.description()
.ok()
.map(|s| s.to_string());
let enabled = element.enabled()
.ok()
.map(|b| b.into())
.unwrap_or(false);
let focused = element.focused()
.ok()
.map(|b| b.into())
.unwrap_or(false);
// Count children
let children_count = element.children()
.ok()
.map(|arr| arr.len() as usize)
.unwrap_or(0);
self.results.borrow_mut().push(AXElement {
role,
title,
value,
label,
identifier,
enabled,
focused,
position: None,
size: None,
children_count,
});
}
TreeWalkerFlow::Continue
}
fn exit_element(&self, _element: &AXUIElement) {
self.depth.set(self.depth.get() - 1);
}
}

View File

@@ -0,0 +1,65 @@
pub mod controller;
pub use controller::MacAxController;
use serde::{Deserialize, Serialize};
#[cfg(test)]
mod tests;
/// Represents an accessibility element in the UI hierarchy
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AXElement {
pub role: String,
pub title: Option<String>,
pub value: Option<String>,
pub label: Option<String>,
pub identifier: Option<String>,
pub enabled: bool,
pub focused: bool,
pub position: Option<(f64, f64)>,
pub size: Option<(f64, f64)>,
pub children_count: usize,
}
/// Represents a macOS application
#[derive(Debug, Clone)]
pub struct AXApplication {
pub name: String,
pub bundle_id: Option<String>,
pub pid: i32,
}
impl AXElement {
/// Convert to a human-readable string representation
pub fn to_string(&self) -> String {
let mut parts = vec![format!("Role: {}", self.role)];
if let Some(ref title) = self.title {
parts.push(format!("Title: {}", title));
}
if let Some(ref value) = self.value {
parts.push(format!("Value: {}", value));
}
if let Some(ref label) = self.label {
parts.push(format!("Label: {}", label));
}
if let Some(ref id) = self.identifier {
parts.push(format!("ID: {}", id));
}
parts.push(format!("Enabled: {}", self.enabled));
parts.push(format!("Focused: {}", self.focused));
if let Some((x, y)) = self.position {
parts.push(format!("Position: ({:.0}, {:.0})", x, y));
}
if let Some((w, h)) = self.size {
parts.push(format!("Size: ({:.0}, {:.0})", w, h));
}
parts.push(format!("Children: {}", self.children_count));
parts.join(", ")
}
}

View File

@@ -0,0 +1,37 @@
#[cfg(test)]
mod tests {
use crate::{AXElement, MacAxController};
#[test]
fn test_ax_element_to_string() {
let element = AXElement {
role: "button".to_string(),
title: Some("Click Me".to_string()),
value: None,
label: Some("Submit Button".to_string()),
identifier: Some("submitBtn".to_string()),
enabled: true,
focused: false,
position: Some((100.0, 200.0)),
size: Some((80.0, 30.0)),
children_count: 0,
};
let string_repr = element.to_string();
assert!(string_repr.contains("Role: button"));
assert!(string_repr.contains("Title: Click Me"));
assert!(string_repr.contains("Label: Submit Button"));
assert!(string_repr.contains("ID: submitBtn"));
assert!(string_repr.contains("Enabled: true"));
assert!(string_repr.contains("Position: (100, 200)"));
assert!(string_repr.contains("Size: (80, 30)"));
}
#[test]
fn test_controller_creation() {
// Just test that we can create a controller
// Actual functionality requires macOS and permissions
let result = MacAxController::new();
assert!(result.is_ok());
}
}

View File

@@ -0,0 +1,26 @@
use crate::types::TextLocation;
use anyhow::Result;
use async_trait::async_trait;
/// OCR engine trait for text recognition with bounding boxes
#[async_trait]
pub trait OCREngine: Send + Sync {
/// Extract text with locations from an image file
async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>>;
/// Get the name of the OCR engine
fn name(&self) -> &str;
}
// Platform-specific modules
#[cfg(target_os = "macos")]
pub mod vision;
pub mod tesseract;
// Re-export the default OCR engine for the platform
#[cfg(target_os = "macos")]
pub use vision::AppleVisionOCR as DefaultOCR;
#[cfg(not(target_os = "macos"))]
pub use tesseract::TesseractOCR as DefaultOCR;

View File

@@ -0,0 +1,84 @@
use super::OCREngine;
use crate::types::TextLocation;
use anyhow::Result;
use async_trait::async_trait;
/// Tesseract OCR engine (fallback/cross-platform)
pub struct TesseractOCR;
impl TesseractOCR {
pub fn new() -> Result<Self> {
// Check if tesseract is available
let tesseract_check = std::process::Command::new("which")
.arg("tesseract")
.output();
if tesseract_check.is_err() || !tesseract_check.as_ref().unwrap().status.success() {
anyhow::bail!("Tesseract OCR is not installed on your system.\n\n\
To install tesseract:\n macOS: brew install tesseract\n \
Linux: sudo apt-get install tesseract-ocr (Ubuntu/Debian)\n \
sudo yum install tesseract (RHEL/CentOS)\n \
Windows: Download from https://github.com/UB-Mannheim/tesseract/wiki\n\n\
After installation, restart your terminal and try again.");
}
Ok(Self)
}
}
#[async_trait]
impl OCREngine for TesseractOCR {
async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>> {
// Use tesseract CLI with TSV output to get bounding boxes
let output = std::process::Command::new("tesseract")
.arg(path)
.arg("stdout")
.arg("tsv")
.output()
.map_err(|e| anyhow::anyhow!("Failed to run tesseract: {}", e))?;
if !output.status.success() {
anyhow::bail!("Tesseract failed: {}", String::from_utf8_lossy(&output.stderr));
}
let tsv_text = String::from_utf8_lossy(&output.stdout);
let mut locations = Vec::new();
// Parse TSV output (skip header line)
for (i, line) in tsv_text.lines().enumerate() {
if i == 0 { continue; } // Skip header
let parts: Vec<&str> = line.split('\t').collect();
if parts.len() >= 12 {
// TSV format: level, page_num, block_num, par_num, line_num, word_num,
// left, top, width, height, conf, text
if let (Ok(x), Ok(y), Ok(w), Ok(h), Ok(conf), text) = (
parts[6].parse::<i32>(),
parts[7].parse::<i32>(),
parts[8].parse::<i32>(),
parts[9].parse::<i32>(),
parts[10].parse::<f32>(),
parts[11],
) {
let trimmed = text.trim();
if !trimmed.is_empty() && conf > 0.0 {
locations.push(TextLocation {
text: trimmed.to_string(),
x,
y,
width: w,
height: h,
confidence: conf / 100.0, // Convert from 0-100 to 0-1
});
}
}
}
}
Ok(locations)
}
fn name(&self) -> &str {
"Tesseract OCR"
}
}

View File

@@ -0,0 +1,103 @@
use super::OCREngine;
use crate::types::TextLocation;
use anyhow::{Result, Context};
use async_trait::async_trait;
use std::ffi::{CStr, CString};
use std::os::raw::{c_char, c_float, c_uint};
// FFI bindings to Swift VisionBridge
#[repr(C)]
struct VisionTextBox {
text: *const c_char,
text_len: c_uint,
x: i32,
y: i32,
width: i32,
height: i32,
confidence: c_float,
}
extern "C" {
fn vision_recognize_text(
image_path: *const c_char,
image_path_len: c_uint,
out_boxes: *mut *mut std::ffi::c_void,
out_count: *mut c_uint,
) -> bool;
fn vision_free_boxes(boxes: *mut std::ffi::c_void, count: c_uint);
}
/// Apple Vision Framework OCR engine
pub struct AppleVisionOCR;
impl AppleVisionOCR {
pub fn new() -> Result<Self> {
Ok(Self)
}
}
#[async_trait]
impl OCREngine for AppleVisionOCR {
async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>> {
// Convert path to C string
let c_path = CString::new(path)
.context("Failed to convert path to C string")?;
let mut boxes_ptr: *mut std::ffi::c_void = std::ptr::null_mut();
let mut count: c_uint = 0;
// Call Swift Vision API
let success = unsafe {
vision_recognize_text(
c_path.as_ptr(),
path.len() as c_uint,
&mut boxes_ptr,
&mut count,
)
};
if !success || boxes_ptr.is_null() {
anyhow::bail!("Apple Vision OCR failed");
}
// Convert C array to Rust Vec
let mut locations = Vec::new();
unsafe {
let typed_boxes = boxes_ptr as *const VisionTextBox;
let boxes_slice = std::slice::from_raw_parts(typed_boxes, count as usize);
for box_data in boxes_slice {
// Convert C string to Rust String
let text = if !box_data.text.is_null() {
CStr::from_ptr(box_data.text)
.to_string_lossy()
.into_owned()
} else {
String::new()
};
if !text.is_empty() {
locations.push(TextLocation {
text,
x: box_data.x,
y: box_data.y,
width: box_data.width,
height: box_data.height,
confidence: box_data.confidence,
});
}
}
// Free the C array
vision_free_boxes(boxes_ptr, count);
}
Ok(locations)
}
fn name(&self) -> &str {
"Apple Vision Framework"
}
}

View File

@@ -0,0 +1,166 @@
use crate::{ComputerController, types::*};
use anyhow::Result;
use async_trait::async_trait;
use tesseract::Tesseract;
use uuid::Uuid;
pub struct LinuxController {
// Placeholder for X11 connection or other state
}
impl LinuxController {
pub fn new() -> Result<Self> {
// Initialize X11 connection
tracing::warn!("Linux computer control not fully implemented");
Ok(Self {})
}
}
#[async_trait]
impl ComputerController for LinuxController {
async fn move_mouse(&self, _x: i32, _y: i32) -> Result<()> {
anyhow::bail!("Linux implementation not yet available")
}
async fn click(&self, _button: MouseButton) -> Result<()> {
anyhow::bail!("Linux implementation not yet available")
}
async fn double_click(&self, _button: MouseButton) -> Result<()> {
anyhow::bail!("Linux implementation not yet available")
}
async fn type_text(&self, _text: &str) -> Result<()> {
anyhow::bail!("Linux implementation not yet available")
}
async fn press_key(&self, _key: &str) -> Result<()> {
anyhow::bail!("Linux implementation not yet available")
}
async fn list_windows(&self) -> Result<Vec<Window>> {
anyhow::bail!("Linux implementation not yet available")
}
async fn focus_window(&self, _window_id: &str) -> Result<()> {
anyhow::bail!("Linux implementation not yet available")
}
async fn get_window_bounds(&self, _window_id: &str) -> Result<Rect> {
anyhow::bail!("Linux implementation not yet available")
}
async fn find_element(&self, _selector: &ElementSelector) -> Result<Option<UIElement>> {
anyhow::bail!("Linux implementation not yet available")
}
async fn get_element_text(&self, _element_id: &str) -> Result<String> {
anyhow::bail!("Linux implementation not yet available")
}
async fn get_element_bounds(&self, _element_id: &str) -> Result<Rect> {
anyhow::bail!("Linux implementation not yet available")
}
async fn take_screenshot(&self, _path: &str, _region: Option<Rect>, _window_id: Option<&str>) -> Result<()> {
// Enforce that window_id must be provided
if _window_id.is_none() {
anyhow::bail!("window_id is required. You must specify which window to capture (e.g., 'Firefox', 'Terminal', 'gedit'). Use list_windows to see available windows.");
}
anyhow::bail!("Linux implementation not yet available")
}
async fn extract_text_from_screen(&self, _region: Rect, _window_id: &str) -> Result<String> {
anyhow::bail!("Linux implementation not yet available")
}
async fn extract_text_from_image(&self, _path: &str) -> Result<OCRResult> {
// Check if tesseract is available on the system
let tesseract_check = std::process::Command::new("which")
.arg("tesseract")
.output();
if tesseract_check.is_err() || !tesseract_check.as_ref().unwrap().status.success() {
anyhow::bail!("Tesseract OCR is not installed on your system.\n\n\
To install tesseract:\n \
Ubuntu/Debian: sudo apt-get install tesseract-ocr\n \
RHEL/CentOS: sudo yum install tesseract\n \
Arch Linux: sudo pacman -S tesseract\n\n\
After installation, restart your terminal and try again.");
}
// Initialize Tesseract
let tess = Tesseract::new(None, Some("eng"))
.map_err(|e| {
anyhow::anyhow!("Failed to initialize Tesseract: {}\n\n\
This usually means:\n1. Tesseract is not properly installed\n\
2. Language data files are missing\n\nTo fix:\n \
Ubuntu/Debian: sudo apt-get install tesseract-ocr-eng\n \
RHEL/CentOS: sudo yum install tesseract-langpack-eng\n \
Arch Linux: sudo pacman -S tesseract-data-eng", e)
})?;
let text = tess.set_image(_path)
.map_err(|e| anyhow::anyhow!("Failed to load image '{}': {}", _path, e))?
.get_text()
.map_err(|e| anyhow::anyhow!("Failed to extract text from image: {}", e))?;
// Get confidence (simplified - would need more complex API calls for per-word confidence)
let confidence = 0.85; // Placeholder
Ok(OCRResult {
text,
confidence,
bounds: Rect { x: 0, y: 0, width: 0, height: 0 }, // Would need image dimensions
})
}
async fn find_text_on_screen(&self, _text: &str) -> Result<Option<Point>> {
// Check if tesseract is available on the system
let tesseract_check = std::process::Command::new("which")
.arg("tesseract")
.output();
if tesseract_check.is_err() || !tesseract_check.as_ref().unwrap().status.success() {
anyhow::bail!("Tesseract OCR is not installed on your system.\n\n\
To install tesseract:\n \
Ubuntu/Debian: sudo apt-get install tesseract-ocr\n \
RHEL/CentOS: sudo yum install tesseract\n \
Arch Linux: sudo pacman -S tesseract\n\n\
After installation, restart your terminal and try again.");
}
// Take full screen screenshot
let temp_path = format!("/tmp/g3_ocr_search_{}.png", uuid::Uuid::new_v4());
self.take_screenshot(&temp_path, None, None).await?;
// Use Tesseract to find text with bounding boxes
let tess = Tesseract::new(None, Some("eng"))
.map_err(|e| {
anyhow::anyhow!("Failed to initialize Tesseract: {}\n\n\
This usually means:\n1. Tesseract is not properly installed\n\
2. Language data files are missing\n\nTo fix:\n \
Ubuntu/Debian: sudo apt-get install tesseract-ocr-eng\n \
RHEL/CentOS: sudo yum install tesseract-langpack-eng\n \
Arch Linux: sudo pacman -S tesseract-data-eng", e)
})?;
let full_text = tess.set_image(temp_path.as_str())
.map_err(|e| anyhow::anyhow!("Failed to load screenshot: {}", e))?
.get_text()
.map_err(|e| anyhow::anyhow!("Failed to extract text from screen: {}", e))?;
// Clean up temp file
let _ = std::fs::remove_file(&temp_path);
// Simple text search - full implementation would use get_component_images
// to get bounding boxes for each word
if full_text.contains(_text) {
tracing::warn!("Text found but precise coordinates not available in simplified implementation");
Ok(Some(Point { x: 0, y: 0 }))
} else {
Ok(None)
}
}
}

View File

@@ -0,0 +1,507 @@
use crate::{ComputerController, types::{Rect, TextLocation}};
use crate::ocr::{OCREngine, DefaultOCR};
use anyhow::{Result, Context};
use async_trait::async_trait;
use std::path::Path;
use core_graphics::window::{kCGWindowListOptionOnScreenOnly, kCGNullWindowID, CGWindowListCopyWindowInfo};
use core_foundation::dictionary::CFDictionary;
use core_foundation::string::CFString;
use core_foundation::base::{TCFType, ToVoid};
use core_foundation::array::CFArray;
pub struct MacOSController {
ocr_engine: Box<dyn OCREngine>,
#[allow(dead_code)]
ocr_name: String,
}
impl MacOSController {
pub fn new() -> Result<Self> {
let ocr = Box::new(DefaultOCR::new()?);
let ocr_name = ocr.name().to_string();
tracing::info!("Initialized macOS controller with OCR engine: {}", ocr_name);
Ok(Self { ocr_engine: ocr, ocr_name })
}
}
#[async_trait]
impl ComputerController for MacOSController {
async fn take_screenshot(&self, path: &str, region: Option<Rect>, window_id: Option<&str>) -> Result<()> {
// Enforce that window_id must be provided
if window_id.is_none() {
return Err(anyhow::anyhow!("window_id is required. You must specify which window to capture (e.g., 'Safari', 'Terminal', 'Google Chrome'). Use list_windows to see available windows."));
}
// Determine the temporary directory for screenshots
let temp_dir = std::env::var("TMPDIR")
.or_else(|_| std::env::var("HOME").map(|h| format!("{}/tmp", h)))
.unwrap_or_else(|_| "/tmp".to_string());
// Ensure temp directory exists
std::fs::create_dir_all(&temp_dir)?;
// If path is relative or doesn't specify a directory, use temp_dir
let final_path = if path.starts_with('/') {
path.to_string()
} else {
format!("{}/{}", temp_dir.trim_end_matches('/'), path)
};
let path_obj = Path::new(&final_path);
if let Some(parent) = path_obj.parent() {
std::fs::create_dir_all(parent)?;
}
let app_name = window_id.unwrap(); // Safe because we checked is_none() above
// Get the window ID for the specified application
let cg_window_id = unsafe {
let window_list = CGWindowListCopyWindowInfo(
kCGWindowListOptionOnScreenOnly,
kCGNullWindowID
);
let array = CFArray::<CFDictionary>::wrap_under_create_rule(window_list);
let count = array.len();
let mut found_window_id: Option<(u32, String)> = None; // (id, owner)
let app_name_lower = app_name.to_lowercase();
for i in 0..count {
let dict = array.get(i).unwrap();
// Get owner name
let owner_key = CFString::from_static_string("kCGWindowOwnerName");
let owner: String = if let Some(value) = dict.find(owner_key.to_void()) {
let s: CFString = TCFType::wrap_under_get_rule(*value as *const _);
s.to_string()
} else {
continue;
};
tracing::debug!("Checking window: owner='{}', looking for '{}'", owner, app_name);
let owner_lower = owner.to_lowercase();
// Normalize by removing spaces for exact matching
let app_name_normalized = app_name_lower.replace(" ", "");
let owner_normalized = owner_lower.replace(" ", "");
// ONLY accept exact matches (case-insensitive, with or without spaces)
// This prevents "Goose" from matching "GooseStudio"
let is_match = owner_lower == app_name_lower || owner_normalized == app_name_normalized;
if is_match {
// Get window ID
let window_id_key = CFString::from_static_string("kCGWindowNumber");
if let Some(value) = dict.find(window_id_key.to_void()) {
let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _);
if let Some(id) = num.to_i64() {
// Get window layer to filter out menu bar windows
let layer_key = CFString::from_static_string("kCGWindowLayer");
let layer: i32 = if let Some(value) = dict.find(layer_key.to_void()) {
let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _);
num.to_i32().unwrap_or(0)
} else {
0
};
// Get window bounds to verify it's a real window
let bounds_key = CFString::from_static_string("kCGWindowBounds");
let has_real_bounds = if let Some(value) = dict.find(bounds_key.to_void()) {
let bounds_dict: CFDictionary = TCFType::wrap_under_get_rule(*value as *const _);
let width_key = CFString::from_static_string("Width");
let height_key = CFString::from_static_string("Height");
if let (Some(w_val), Some(h_val)) = (
bounds_dict.find(width_key.to_void()),
bounds_dict.find(height_key.to_void()),
) {
let w_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*w_val as *const _);
let h_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*h_val as *const _);
let width = w_num.to_f64().unwrap_or(0.0);
let height = h_num.to_f64().unwrap_or(0.0);
// Real windows should be at least 100x100 pixels
width >= 100.0 && height >= 100.0
} else {
false
}
} else {
false
};
// Only accept windows that are:
// 1. At layer 0 (normal windows, not menu bar)
// 2. Have real bounds (width and height >= 100)
if layer == 0 && has_real_bounds {
tracing::info!("Found valid window: ID {} for app '{}' (layer={}, bounds valid)", id, owner, layer);
found_window_id = Some((id as u32, owner.clone()));
break;
} else {
tracing::debug!("Skipping window ID {} for '{}': layer={}, has_real_bounds={}", id, owner, layer, has_real_bounds);
}
}
}
}
}
found_window_id
};
let (cg_window_id, matched_owner) = cg_window_id.ok_or_else(|| {
anyhow::anyhow!("Could not find window for application '{}'. Use list_windows to see available windows.", app_name)
})?;
tracing::info!("Taking screenshot of window ID {} for app '{}'", cg_window_id, matched_owner);
// Use screencapture with the window ID for now
// TODO: Implement direct CGWindowListCreateImage approach with proper image saving
let mut cmd = std::process::Command::new("screencapture");
cmd.arg("-x"); // No sound
cmd.arg("-l");
cmd.arg(cg_window_id.to_string());
if let Some(region) = region {
cmd.arg("-R");
cmd.arg(format!("{},{},{},{}", region.x, region.y, region.width, region.height));
}
cmd.arg(&final_path);
let screenshot_result = cmd.output()?;
if !screenshot_result.status.success() {
let stderr = String::from_utf8_lossy(&screenshot_result.stderr);
return Err(anyhow::anyhow!("screencapture failed for window {}: {}", cg_window_id, stderr));
}
Ok(())
}
async fn extract_text_from_screen(&self, region: Rect, window_id: &str) -> Result<String> {
// Take screenshot of region first
let temp_path = format!("/tmp/g3_ocr_{}.png", uuid::Uuid::new_v4());
self.take_screenshot(&temp_path, Some(region), Some(window_id)).await?;
// Extract text from the screenshot
let result = self.extract_text_from_image(&temp_path).await?;
// Clean up temp file
let _ = std::fs::remove_file(&temp_path);
Ok(result)
}
async fn extract_text_from_image(&self, path: &str) -> Result<String> {
// Extract all text and concatenate
let locations = self.ocr_engine.extract_text_with_locations(path).await?;
Ok(locations.iter().map(|loc| loc.text.as_str()).collect::<Vec<_>>().join(" "))
}
async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>> {
// Use the OCR engine
self.ocr_engine.extract_text_with_locations(path).await
}
async fn find_text_in_app(&self, app_name: &str, search_text: &str) -> Result<Option<TextLocation>> {
// Take screenshot of specific app window
let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
let temp_path = format!("{}/tmp/g3_find_text_{}_{}.png", home, app_name, uuid::Uuid::new_v4());
self.take_screenshot(&temp_path, None, Some(app_name)).await?;
// Get screenshot dimensions before we delete it
let screenshot_dims = get_image_dimensions(&temp_path)?;
// Extract all text with locations
let locations = self.extract_text_with_locations(&temp_path).await?;
// Get window bounds to calculate coordinate transformation
let window_bounds = self.get_window_bounds(app_name)?;
// Clean up temp file
let _ = std::fs::remove_file(&temp_path);
// Find matching text (case-insensitive)
let search_lower = search_text.to_lowercase();
for location in locations {
if location.text.to_lowercase().contains(&search_lower) {
// Transform coordinates from screenshot space to screen space
let transformed = transform_screenshot_to_screen_coords(
location,
window_bounds,
screenshot_dims,
);
return Ok(Some(transformed));
}
}
Ok(None)
}
fn move_mouse(&self, x: i32, y: i32) -> Result<()> {
use core_graphics::event::{
CGEvent, CGEventTapLocation, CGEventType, CGMouseButton,
};
use core_graphics::event_source::{
CGEventSource, CGEventSourceStateID,
};
use core_graphics::geometry::CGPoint;
let source = CGEventSource::new(CGEventSourceStateID::HIDSystemState)
.ok().context("Failed to create event source")?;
let event = CGEvent::new_mouse_event(
source,
CGEventType::MouseMoved,
CGPoint::new(x as f64, y as f64),
CGMouseButton::Left,
).ok().context("Failed to create mouse event")?;
event.post(CGEventTapLocation::HID);
Ok(())
}
fn click_at(&self, x: i32, y: i32, _app_name: Option<&str>) -> Result<()> {
use core_graphics::event::{
CGEvent, CGEventTapLocation, CGEventType, CGMouseButton,
};
use core_graphics::event_source::{
CGEventSource, CGEventSourceStateID,
};
use core_graphics::geometry::CGPoint;
use core_graphics::display::CGDisplay;
// IMPORTANT: Coordinates passed here are in NSScreen/CGWindowListCopyWindowInfo space
// (Y=0 at BOTTOM, increases UPWARD)
// But CGEvent uses a different coordinate system (Y=0 at TOP, increases DOWNWARD)
// We need to convert: CGEvent.y = screenHeight - NSScreen.y
let screen_height = CGDisplay::main().pixels_high() as i32;
let cgevent_x = x;
let cgevent_y = screen_height - y;
tracing::debug!("click_at: NSScreen coords ({}, {}) -> CGEvent coords ({}, {}) [screen_height={}]",
x, y, cgevent_x, cgevent_y, screen_height);
let (global_x, global_y) = (cgevent_x, cgevent_y);
let point = CGPoint::new(global_x as f64, global_y as f64);
let source = CGEventSource::new(CGEventSourceStateID::HIDSystemState)
.ok().context("Failed to create event source")?;
// Move mouse to position first
let move_event = CGEvent::new_mouse_event(
source.clone(),
CGEventType::MouseMoved,
point,
CGMouseButton::Left,
).ok().context("Failed to create mouse move event")?;
move_event.post(CGEventTapLocation::HID);
std::thread::sleep(std::time::Duration::from_millis(100));
// Mouse down
let mouse_down = CGEvent::new_mouse_event(
source.clone(),
CGEventType::LeftMouseDown,
point,
CGMouseButton::Left,
).ok().context("Failed to create mouse down event")?;
mouse_down.post(CGEventTapLocation::HID);
std::thread::sleep(std::time::Duration::from_millis(50));
// Mouse up
let mouse_up = CGEvent::new_mouse_event(
source,
CGEventType::LeftMouseUp,
point,
CGMouseButton::Left,
).ok().context("Failed to create mouse up event")?;
mouse_up.post(CGEventTapLocation::HID);
Ok(())
}
}
impl MacOSController {
/// Get window bounds for an application (helper method)
fn get_window_bounds(&self, app_name: &str) -> Result<(i32, i32, i32, i32)> {
unsafe {
let window_list = CGWindowListCopyWindowInfo(
kCGWindowListOptionOnScreenOnly,
kCGNullWindowID
);
let array = CFArray::<CFDictionary>::wrap_under_create_rule(window_list);
let count = array.len();
let app_name_lower = app_name.to_lowercase();
for i in 0..count {
let dict = array.get(i).unwrap();
// Get owner name
let owner_key = CFString::from_static_string("kCGWindowOwnerName");
let owner: String = if let Some(value) = dict.find(owner_key.to_void()) {
let s: CFString = TCFType::wrap_under_get_rule(*value as *const _);
s.to_string()
} else {
continue;
};
let owner_lower = owner.to_lowercase();
// Normalize by removing spaces for exact matching
let app_name_normalized = app_name_lower.replace(" ", "");
let owner_normalized = owner_lower.replace(" ", "");
// ONLY accept exact matches (case-insensitive, with or without spaces)
// This prevents "Goose" from matching "GooseStudio"
let is_match = owner_lower == app_name_lower || owner_normalized == app_name_normalized;
if is_match {
// Get window layer to filter out menu bar windows
let layer_key = CFString::from_static_string("kCGWindowLayer");
let layer: i32 = if let Some(value) = dict.find(layer_key.to_void()) {
let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _);
num.to_i32().unwrap_or(0)
} else {
0
};
// Skip menu bar windows (layer >= 20)
if layer >= 20 {
tracing::debug!("Skipping window for '{}' at layer {} (menu bar)", owner, layer);
continue;
}
// Get window bounds to verify it's a real window
let bounds_key = CFString::from_static_string("kCGWindowBounds");
if let Some(value) = dict.find(bounds_key.to_void()) {
let bounds_dict: CFDictionary = TCFType::wrap_under_get_rule(*value as *const _);
let x_key = CFString::from_static_string("X");
let y_key = CFString::from_static_string("Y");
let width_key = CFString::from_static_string("Width");
let height_key = CFString::from_static_string("Height");
if let (Some(x_val), Some(y_val), Some(w_val), Some(h_val)) = (
bounds_dict.find(x_key.to_void()),
bounds_dict.find(y_key.to_void()),
bounds_dict.find(width_key.to_void()),
bounds_dict.find(height_key.to_void()),
) {
let x_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*x_val as *const _);
let y_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*y_val as *const _);
let w_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*w_val as *const _);
let h_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*h_val as *const _);
let x: i32 = x_num.to_i64().unwrap_or(0) as i32;
let y: i32 = y_num.to_i64().unwrap_or(0) as i32;
let w: i32 = w_num.to_i64().unwrap_or(0) as i32;
let h: i32 = h_num.to_i64().unwrap_or(0) as i32;
// Only accept windows with real bounds (>= 100x100 pixels)
if w >= 100 && h >= 100 {
tracing::info!("Found valid window bounds for '{}': x={}, y={}, w={}, h={} (layer={})", owner, x, y, w, h, layer);
return Ok((x, y, w, h));
} else {
tracing::debug!("Skipping window for '{}': too small ({}x{})", owner, w, h);
continue;
}
} else {
continue;
}
}
}
}
}
Err(anyhow::anyhow!("Could not find window bounds for '{}'", app_name))
}
}
/// Get image dimensions from a PNG file
fn get_image_dimensions(path: &str) -> Result<(i32, i32)> {
use std::fs::File;
use std::io::Read;
let mut file = File::open(path)?;
let mut buffer = vec![0u8; 24];
file.read_exact(&mut buffer)?;
// PNG signature check
if &buffer[0..8] != b"\x89PNG\r\n\x1a\n" {
anyhow::bail!("Not a valid PNG file");
}
// Read IHDR chunk (width and height are at bytes 16-23)
let width = u32::from_be_bytes([buffer[16], buffer[17], buffer[18], buffer[19]]) as i32;
let height = u32::from_be_bytes([buffer[20], buffer[21], buffer[22], buffer[23]]) as i32;
Ok((width, height))
}
/// Transform coordinates from screenshot space to screen space
///
/// The screenshot is taken of a window, and Vision OCR returns coordinates
/// relative to the screenshot image. We need to transform these to actual
/// screen coordinates for clicking.
///
/// On Retina displays, screenshots are taken at 2x resolution, so we need
/// to account for this scaling factor.
fn transform_screenshot_to_screen_coords(
location: TextLocation,
window_bounds: (i32, i32, i32, i32), // (x, y, width, height) in screen space
screenshot_dims: (i32, i32), // (width, height) in pixels
) -> TextLocation {
let (win_x, win_y, win_width, win_height) = window_bounds;
let (screenshot_width, screenshot_height) = screenshot_dims;
// Calculate scale factors
// On Retina displays, screenshot is typically 2x the window size
let scale_x = win_width as f64 / screenshot_width as f64;
let scale_y = win_height as f64 / screenshot_height as f64;
tracing::debug!("Transform: screenshot={}x{}, window={}x{} at ({},{}), scale=({:.2},{:.2})",
screenshot_width, screenshot_height, win_width, win_height, win_x, win_y, scale_x, scale_y);
// Transform coordinates from image space to screen space
// IMPORTANT: macOS screen coordinates have origin at BOTTOM-LEFT (Y increases upward)
// Image coordinates have origin at TOP-LEFT (Y increases downward)
// win_y is the BOTTOM of the window in screen coordinates
// So we need to: (win_y + win_height) to get window TOP, then subtract screenshot_y
let window_top_y = win_y + win_height;
tracing::debug!("[transform] Input location in image space: x={}, y={}, width={}, height={}",
location.x, location.y, location.width, location.height);
tracing::debug!("[transform] Scale factors: scale_x={:.4}, scale_y={:.4}", scale_x, scale_y);
let transformed_x = win_x + (location.x as f64 * scale_x) as i32;
let transformed_y = window_top_y - (location.y as f64 * scale_y) as i32;
let transformed_width = (location.width as f64 * scale_x) as i32;
let transformed_height = (location.height as f64 * scale_y) as i32;
tracing::debug!("[transform] Calculation details:");
tracing::debug!(" - transformed_x = {} + ({} * {:.4}) = {} + {:.2} = {}", win_x, location.x, scale_x, win_x, location.x as f64 * scale_x, transformed_x);
tracing::debug!(" - transformed_width = ({} * {:.4}) = {:.2} -> {}", location.width, scale_x, location.width as f64 * scale_x, transformed_width);
tracing::debug!(" - transformed_height = ({} * {:.4}) = {:.2} -> {}", location.height, scale_y, location.height as f64 * scale_y, transformed_height);
tracing::debug!("Transformed location: screenshot=({},{}) {}x{} -> screen=({},{}) {}x{}",
location.x, location.y, location.width, location.height,
transformed_x, transformed_y, transformed_width, transformed_height);
TextLocation {
text: location.text,
x: transformed_x,
y: transformed_y,
width: transformed_width,
height: transformed_height,
confidence: location.confidence,
}
}
#[path = "macos_window_matching_test.rs"]
#[cfg(test)]
mod tests;

View File

@@ -0,0 +1,45 @@
#[cfg(test)]
mod window_matching_tests {
/// Test that window name matching handles spaces correctly
///
/// Issue: When a user requests a screenshot of "Goose Studio" but the actual
/// application name is "GooseStudio" (no space), the fuzzy matching should
/// still find the window.
///
/// The fix normalizes both names by removing spaces before comparing.
#[test]
fn test_space_normalization() {
let test_cases = vec![
// (user_input, actual_app_name, should_match)
("Goose Studio", "GooseStudio", true),
("GooseStudio", "Goose Studio", true),
("Visual Studio Code", "VisualStudioCode", true),
("Google Chrome", "Google Chrome", true),
("Safari", "Safari", true),
("iTerm", "iTerm2", true), // fuzzy match
("Code", "Visual Studio Code", true), // fuzzy match
];
for (user_input, app_name, should_match) in test_cases {
let user_lower = user_input.to_lowercase();
let app_lower = app_name.to_lowercase();
let user_normalized = user_lower.replace(" ", "");
let app_normalized = app_lower.replace(" ", "");
let is_exact = app_lower == user_lower || app_normalized == user_normalized;
let is_fuzzy = app_lower.contains(&user_lower)
|| user_lower.contains(&app_lower)
|| app_normalized.contains(&user_normalized)
|| user_normalized.contains(&app_normalized);
let matches = is_exact || is_fuzzy;
assert_eq!(
matches, should_match,
"Expected '{}' vs '{}' to match={}, but got match={}",
user_input, app_name, should_match, matches
);
}
}
}

View File

@@ -0,0 +1,8 @@
#[cfg(target_os = "macos")]
pub mod macos;
#[cfg(target_os = "linux")]
pub mod linux;
#[cfg(target_os = "windows")]
pub mod windows;

View File

@@ -0,0 +1,167 @@
use crate::{ComputerController, types::*};
use anyhow::Result;
use async_trait::async_trait;
use tesseract::Tesseract;
use uuid::Uuid;
pub struct WindowsController {
// Placeholder for Windows-specific state
}
impl WindowsController {
pub fn new() -> Result<Self> {
tracing::warn!("Windows computer control not fully implemented");
Ok(Self {})
}
}
#[async_trait]
impl ComputerController for WindowsController {
async fn move_mouse(&self, _x: i32, _y: i32) -> Result<()> {
anyhow::bail!("Windows implementation not yet available")
}
async fn click(&self, _button: MouseButton) -> Result<()> {
anyhow::bail!("Windows implementation not yet available")
}
async fn double_click(&self, _button: MouseButton) -> Result<()> {
anyhow::bail!("Windows implementation not yet available")
}
async fn type_text(&self, _text: &str) -> Result<()> {
anyhow::bail!("Windows implementation not yet available")
}
async fn press_key(&self, _key: &str) -> Result<()> {
anyhow::bail!("Windows implementation not yet available")
}
async fn list_windows(&self) -> Result<Vec<Window>> {
anyhow::bail!("Windows implementation not yet available")
}
async fn focus_window(&self, _window_id: &str) -> Result<()> {
anyhow::bail!("Windows implementation not yet available")
}
async fn get_window_bounds(&self, _window_id: &str) -> Result<Rect> {
anyhow::bail!("Windows implementation not yet available")
}
async fn find_element(&self, _selector: &ElementSelector) -> Result<Option<UIElement>> {
anyhow::bail!("Windows implementation not yet available")
}
async fn get_element_text(&self, _element_id: &str) -> Result<String> {
anyhow::bail!("Windows implementation not yet available")
}
async fn get_element_bounds(&self, _element_id: &str) -> Result<Rect> {
anyhow::bail!("Windows implementation not yet available")
}
async fn take_screenshot(&self, _path: &str, _region: Option<Rect>, _window_id: Option<&str>) -> Result<()> {
// Enforce that window_id must be provided
if _window_id.is_none() {
anyhow::bail!("window_id is required. You must specify which window to capture (e.g., 'Chrome', 'Terminal', 'Notepad'). Use list_windows to see available windows.");
}
anyhow::bail!("Windows implementation not yet available")
}
async fn extract_text_from_screen(&self, _region: Rect, _window_id: &str) -> Result<String> {
anyhow::bail!("Windows implementation not yet available")
}
async fn extract_text_from_image(&self, _path: &str) -> Result<OCRResult> {
// Check if tesseract is available on the system
let tesseract_check = std::process::Command::new("where")
.arg("tesseract")
.output();
if tesseract_check.is_err() || !tesseract_check.as_ref().unwrap().status.success() {
anyhow::bail!("Tesseract OCR is not installed on your system.\n\n\
To install tesseract on Windows:\n \
1. Download the installer from: https://github.com/UB-Mannheim/tesseract/wiki\n \
2. Run the installer and follow the instructions\n \
3. Add tesseract to your PATH environment variable\n \
4. Restart your terminal/command prompt\n\n\
After installation, restart your terminal and try again.");
}
// Initialize Tesseract
let tess = Tesseract::new(None, Some("eng"))
.map_err(|e| {
anyhow::anyhow!("Failed to initialize Tesseract: {}\n\n\
This usually means:\n1. Tesseract is not properly installed\n\
2. Language data files are missing\n\nTo fix:\n \
1. Reinstall tesseract from https://github.com/UB-Mannheim/tesseract/wiki\n \
2. Make sure to select 'Additional language data' during installation\n \
3. Ensure tesseract is in your PATH", e)
})?;
let text = tess.set_image(_path)
.map_err(|e| anyhow::anyhow!("Failed to load image '{}': {}", _path, e))?
.get_text()
.map_err(|e| anyhow::anyhow!("Failed to extract text from image: {}", e))?;
// Get confidence (simplified - would need more complex API calls for per-word confidence)
let confidence = 0.85; // Placeholder
Ok(OCRResult {
text,
confidence,
bounds: Rect { x: 0, y: 0, width: 0, height: 0 }, // Would need image dimensions
})
}
async fn find_text_on_screen(&self, _text: &str) -> Result<Option<Point>> {
// Check if tesseract is available on the system
let tesseract_check = std::process::Command::new("where")
.arg("tesseract")
.output();
if tesseract_check.is_err() || !tesseract_check.as_ref().unwrap().status.success() {
anyhow::bail!("Tesseract OCR is not installed on your system.\n\n\
To install tesseract on Windows:\n \
1. Download the installer from: https://github.com/UB-Mannheim/tesseract/wiki\n \
2. Run the installer and follow the instructions\n \
3. Add tesseract to your PATH environment variable\n \
4. Restart your terminal/command prompt\n\n\
After installation, restart your terminal and try again.");
}
// Take full screen screenshot
let temp_path = format!("C:\\\\Temp\\\\g3_ocr_search_{}.png", uuid::Uuid::new_v4());
self.take_screenshot(&temp_path, None, None).await?;
// Use Tesseract to find text with bounding boxes
let tess = Tesseract::new(None, Some("eng"))
.map_err(|e| {
anyhow::anyhow!("Failed to initialize Tesseract: {}\n\n\
This usually means:\n1. Tesseract is not properly installed\n\
2. Language data files are missing\n\nTo fix:\n \
1. Reinstall tesseract from https://github.com/UB-Mannheim/tesseract/wiki\n \
2. Make sure to select 'Additional language data' during installation\n \
3. Ensure tesseract is in your PATH", e)
})?;
let full_text = tess.set_image(temp_path.as_str())
.map_err(|e| anyhow::anyhow!("Failed to load screenshot: {}", e))?
.get_text()
.map_err(|e| anyhow::anyhow!("Failed to extract text from screen: {}", e))?;
// Clean up temp file
let _ = std::fs::remove_file(&temp_path);
// Simple text search - full implementation would use get_component_images
// to get bounding boxes for each word
if full_text.contains(_text) {
tracing::warn!("Text found but precise coordinates not available in simplified implementation");
Ok(Some(Point { x: 0, y: 0 }))
} else {
Ok(None)
}
}
}

View File

@@ -0,0 +1,19 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Rect {
pub x: i32,
pub y: i32,
pub width: i32,
pub height: i32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextLocation {
pub text: String,
pub x: i32,
pub y: i32,
pub width: i32,
pub height: i32,
pub confidence: f32,
}

View File

@@ -0,0 +1,111 @@
pub mod safari;
use anyhow::Result;
use async_trait::async_trait;
use serde_json::Value;
/// WebDriver controller for browser automation
#[async_trait]
pub trait WebDriverController: Send + Sync {
/// Navigate to a URL
async fn navigate(&mut self, url: &str) -> Result<()>;
/// Get the current URL
async fn current_url(&self) -> Result<String>;
/// Get the page title
async fn title(&self) -> Result<String>;
/// Find an element by CSS selector
async fn find_element(&mut self, selector: &str) -> Result<WebElement>;
/// Find multiple elements by CSS selector
async fn find_elements(&mut self, selector: &str) -> Result<Vec<WebElement>>;
/// Execute JavaScript in the browser
async fn execute_script(&mut self, script: &str, args: Vec<Value>) -> Result<Value>;
/// Get the page source (HTML)
async fn page_source(&self) -> Result<String>;
/// Take a screenshot and save to path
async fn screenshot(&mut self, path: &str) -> Result<()>;
/// Close the current window/tab
async fn close(&mut self) -> Result<()>;
/// Quit the browser session
async fn quit(self) -> Result<()>;
}
/// Represents a web element in the DOM
pub struct WebElement {
pub(crate) inner: fantoccini::elements::Element,
}
impl WebElement {
/// Click the element
pub async fn click(&mut self) -> Result<()> {
self.inner.click().await?;
Ok(())
}
/// Send keys/text to the element
pub async fn send_keys(&mut self, text: &str) -> Result<()> {
self.inner.send_keys(text).await?;
Ok(())
}
/// Clear the element's content (for input fields)
pub async fn clear(&mut self) -> Result<()> {
self.inner.clear().await?;
Ok(())
}
/// Get the element's text content
pub async fn text(&self) -> Result<String> {
Ok(self.inner.text().await?)
}
/// Get an attribute value
pub async fn attr(&self, name: &str) -> Result<Option<String>> {
Ok(self.inner.attr(name).await?)
}
/// Get a property value
pub async fn prop(&self, name: &str) -> Result<Option<String>> {
Ok(self.inner.prop(name).await?)
}
/// Get the element's HTML
pub async fn html(&self, inner: bool) -> Result<String> {
Ok(self.inner.html(inner).await?)
}
/// Check if element is displayed
pub async fn is_displayed(&self) -> Result<bool> {
Ok(self.inner.is_displayed().await?)
}
/// Check if element is enabled
pub async fn is_enabled(&self) -> Result<bool> {
Ok(self.inner.is_enabled().await?)
}
/// Check if element is selected (for checkboxes/radio buttons)
pub async fn is_selected(&self) -> Result<bool> {
Ok(self.inner.is_selected().await?)
}
/// Find a child element by CSS selector
pub async fn find_element(&mut self, selector: &str) -> Result<WebElement> {
let elem = self.inner.find(fantoccini::Locator::Css(selector)).await?;
Ok(WebElement { inner: elem })
}
/// Find multiple child elements by CSS selector
pub async fn find_elements(&mut self, selector: &str) -> Result<Vec<WebElement>> {
let elems = self.inner.find_all(fantoccini::Locator::Css(selector)).await?;
Ok(elems.into_iter().map(|inner| WebElement { inner }).collect())
}
}

View File

@@ -0,0 +1,212 @@
use super::{WebDriverController, WebElement};
use anyhow::{Context, Result};
use async_trait::async_trait;
use fantoccini::{Client, ClientBuilder};
use serde_json::Value;
use std::time::Duration;
/// SafariDriver WebDriver controller
pub struct SafariDriver {
client: Client,
}
impl SafariDriver {
/// Create a new SafariDriver instance
///
/// This will connect to SafariDriver running on the default port (4444).
/// Make sure to enable "Allow Remote Automation" in Safari's Develop menu first.
///
/// You can start SafariDriver manually with:
/// ```bash
/// /usr/bin/safaridriver --enable
/// ```
pub async fn new() -> Result<Self> {
Self::with_port(4444).await
}
/// Create a new SafariDriver instance with a custom port
pub async fn with_port(port: u16) -> Result<Self> {
let url = format!("http://localhost:{}", port);
let mut caps = serde_json::Map::new();
caps.insert("browserName".to_string(), Value::String("safari".to_string()));
let client = ClientBuilder::native()
.capabilities(caps)
.connect(&url)
.await
.context("Failed to connect to SafariDriver. Make sure SafariDriver is running and 'Allow Remote Automation' is enabled in Safari's Develop menu.")?;
Ok(Self { client })
}
/// Go back in browser history
pub async fn back(&mut self) -> Result<()> {
self.client.back().await?;
Ok(())
}
/// Go forward in browser history
pub async fn forward(&mut self) -> Result<()> {
self.client.forward().await?;
Ok(())
}
/// Refresh the current page
pub async fn refresh(&mut self) -> Result<()> {
self.client.refresh().await?;
Ok(())
}
/// Get all window handles
pub async fn window_handles(&mut self) -> Result<Vec<String>> {
let handles = self.client.windows().await?;
Ok(handles.into_iter()
.map(|h| h.into())
.collect())
}
/// Switch to a window by handle
pub async fn switch_to_window(&mut self, handle: &str) -> Result<()> {
let window_handle: fantoccini::wd::WindowHandle = handle.to_string().try_into()?;
self.client.switch_to_window(window_handle).await?;
Ok(())
}
/// Get the current window handle
pub async fn current_window_handle(&mut self) -> Result<String> {
Ok(self.client.window().await?.into())
}
/// Close the current window
pub async fn close_window(&mut self) -> Result<()> {
self.client.close_window().await?;
Ok(())
}
/// Create a new window/tab
pub async fn new_window(&mut self, is_tab: bool) -> Result<String> {
let window_type = if is_tab { "tab" } else { "window" };
let response = self.client.new_window(window_type == "tab").await?;
Ok(response.handle.into())
}
/// Get cookies
pub async fn get_cookies(&mut self) -> Result<Vec<fantoccini::cookies::Cookie<'static>>> {
Ok(self.client.get_all_cookies().await?)
}
/// Add a cookie
pub async fn add_cookie(&mut self, cookie: fantoccini::cookies::Cookie<'static>) -> Result<()> {
self.client.add_cookie(cookie).await?;
Ok(())
}
/// Delete all cookies
pub async fn delete_all_cookies(&mut self) -> Result<()> {
self.client.delete_all_cookies().await?;
Ok(())
}
/// Wait for an element to appear (with timeout)
pub async fn wait_for_element(&mut self, selector: &str, timeout: Duration) -> Result<WebElement> {
let start = std::time::Instant::now();
let poll_interval = Duration::from_millis(100);
loop {
if let Ok(elem) = self.find_element(selector).await {
return Ok(elem);
}
if start.elapsed() >= timeout {
anyhow::bail!("Timeout waiting for element: {}", selector);
}
tokio::time::sleep(poll_interval).await;
}
}
/// Wait for an element to be visible (with timeout)
pub async fn wait_for_visible(&mut self, selector: &str, timeout: Duration) -> Result<WebElement> {
let start = std::time::Instant::now();
let poll_interval = Duration::from_millis(100);
loop {
if let Ok(elem) = self.find_element(selector).await {
if elem.is_displayed().await.unwrap_or(false) {
return Ok(elem);
}
}
if start.elapsed() >= timeout {
anyhow::bail!("Timeout waiting for element to be visible: {}", selector);
}
tokio::time::sleep(poll_interval).await;
}
}
}
#[async_trait]
impl WebDriverController for SafariDriver {
async fn navigate(&mut self, url: &str) -> Result<()> {
self.client.goto(url).await?;
Ok(())
}
async fn current_url(&self) -> Result<String> {
Ok(self.client.current_url().await?.to_string())
}
async fn title(&self) -> Result<String> {
Ok(self.client.title().await?)
}
async fn find_element(&mut self, selector: &str) -> Result<WebElement> {
let elem = self.client.find(fantoccini::Locator::Css(selector)).await
.context(format!("Failed to find element with selector: {}", selector))?;
Ok(WebElement { inner: elem })
}
async fn find_elements(&mut self, selector: &str) -> Result<Vec<WebElement>> {
let elems = self.client.find_all(fantoccini::Locator::Css(selector)).await?;
Ok(elems.into_iter().map(|inner| WebElement { inner }).collect())
}
async fn execute_script(&mut self, script: &str, args: Vec<Value>) -> Result<Value> {
Ok(self.client.execute(script, args).await?)
}
async fn page_source(&self) -> Result<String> {
Ok(self.client.source().await?)
}
async fn screenshot(&mut self, path: &str) -> Result<()> {
let screenshot_data = self.client.screenshot().await?;
// Expand tilde in path
let expanded_path = shellexpand::tilde(path);
let path_str = expanded_path.as_ref();
// Create parent directories if needed
if let Some(parent) = std::path::Path::new(path_str).parent() {
std::fs::create_dir_all(parent)
.context("Failed to create parent directories for screenshot")?;
}
std::fs::write(path_str, screenshot_data)
.context("Failed to write screenshot to file")?;
Ok(())
}
async fn close(&mut self) -> Result<()> {
self.client.close_window().await?;
Ok(())
}
async fn quit(mut self) -> Result<()> {
self.client.close().await?;
Ok(())
}
}

View File

@@ -0,0 +1,17 @@
use g3_computer_control::*;
#[tokio::test]
async fn test_screenshot() {
let controller = create_controller().expect("Failed to create controller");
// Take screenshot
let path = "/tmp/test_screenshot.png";
let result = controller.take_screenshot(path, None, None).await;
assert!(result.is_ok(), "Failed to take screenshot: {:?}", result.err());
// Verify file exists
assert!(std::path::Path::new(path).exists(), "Screenshot file was not created");
// Clean up
let _ = std::fs::remove_file(path);
}

View File

@@ -0,0 +1,24 @@
// swift-tools-version:5.9
import PackageDescription
let package = Package(
name: "VisionBridge",
platforms: [
.macOS(.v11)
],
products: [
.library(
name: "VisionBridge",
type: .dynamic,
targets: ["VisionBridge"]
),
],
targets: [
.target(
name: "VisionBridge",
dependencies: [],
path: "Sources/VisionBridge",
publicHeadersPath: "."
),
]
)

View File

@@ -0,0 +1,39 @@
#ifndef VisionBridge_h
#define VisionBridge_h
#include <stdint.h>
#include <stdbool.h>
#ifdef __cplusplus
extern "C" {
#endif
// Text box structure for FFI
typedef struct {
const char* text;
uint32_t text_len;
int32_t x;
int32_t y;
int32_t width;
int32_t height;
float confidence;
} VisionTextBox;
// Recognize text in an image and return bounding boxes
// Returns true on success, false on failure
// Caller must free the returned boxes using vision_free_boxes
bool vision_recognize_text(
const char* image_path,
uint32_t image_path_len,
VisionTextBox** out_boxes,
uint32_t* out_count
);
// Free memory allocated by vision_recognize_text
void vision_free_boxes(VisionTextBox* boxes, uint32_t count);
#ifdef __cplusplus
}
#endif
#endif /* VisionBridge_h */

View File

@@ -0,0 +1,145 @@
import Foundation
import Vision
import AppKit
import CoreGraphics
// MARK: - C Bridge Functions
@_cdecl("vision_recognize_text")
public func vision_recognize_text(
_ imagePath: UnsafePointer<CChar>,
_ imagePathLen: UInt32,
_ outBoxes: UnsafeMutablePointer<UnsafeMutableRawPointer?>,
_ outCount: UnsafeMutablePointer<UInt32>
) -> Bool {
// Convert C string to Swift String
guard let pathData = Data(bytes: imagePath, count: Int(imagePathLen)).withUnsafeBytes({
String(bytes: $0, encoding: .utf8)
}) else {
return false
}
let path = pathData.trimmingCharacters(in: .whitespaces)
// Load image
guard let image = NSImage(contentsOfFile: path),
let cgImage = image.cgImage(forProposedRect: nil, context: nil, hints: nil) else {
return false
}
// Perform OCR
var textBoxes: [CTextBox] = []
let semaphore = DispatchSemaphore(value: 0)
var success = false
let request = VNRecognizeTextRequest { request, error in
defer { semaphore.signal() }
if let error = error {
print("Vision OCR error: \(error.localizedDescription)")
return
}
guard let observations = request.results as? [VNRecognizedTextObservation] else {
return
}
let imageSize = CGSize(width: cgImage.width, height: cgImage.height)
for observation in observations {
guard let candidate = observation.topCandidates(1).first else { continue }
let text = candidate.string
let boundingBox = observation.boundingBox
// Convert normalized coordinates (bottom-left origin) to pixel coordinates (top-left origin)
let x = Int32(boundingBox.origin.x * imageSize.width)
let y = Int32((1.0 - boundingBox.origin.y - boundingBox.height) * imageSize.height)
let width = Int32(boundingBox.width * imageSize.width)
let height = Int32(boundingBox.height * imageSize.height)
// Allocate C string for text
let cString = strdup(text)
textBoxes.append(CTextBox(
text: cString,
text_len: UInt32(text.utf8.count),
x: x,
y: y,
width: width,
height: height,
confidence: observation.confidence
))
}
success = true
}
// Configure request for best accuracy
request.recognitionLevel = .accurate
request.usesLanguageCorrection = true
request.recognitionLanguages = ["en-US"]
// Perform request
let handler = VNImageRequestHandler(cgImage: cgImage, options: [:])
do {
try handler.perform([request])
} catch {
print("Vision request failed: \(error.localizedDescription)")
return false
}
// Wait for completion
semaphore.wait()
if !success {
return false
}
// Allocate array for results
let boxesPtr = UnsafeMutablePointer<CTextBox>.allocate(capacity: textBoxes.count)
for (index, box) in textBoxes.enumerated() {
boxesPtr[index] = box
}
outBoxes.pointee = UnsafeMutableRawPointer(boxesPtr)
outCount.pointee = UInt32(textBoxes.count)
return true
}
@_cdecl("vision_free_boxes")
public func vision_free_boxes(
_ boxes: UnsafeMutableRawPointer,
_ count: UInt32
) {
let typedBoxes = boxes.assumingMemoryBound(to: CTextBox.self)
for i in 0..<Int(count) {
if let text = typedBoxes[i].text {
free(UnsafeMutableRawPointer(mutating: text))
}
}
typedBoxes.deallocate()
}
// MARK: - C-Compatible Structure
public struct CTextBox {
public let text: UnsafePointer<CChar>?
public let text_len: UInt32
public let x: Int32
public let y: Int32
public let width: Int32
public let height: Int32
public let confidence: Float
public init(text: UnsafePointer<CChar>?, text_len: UInt32, x: Int32, y: Int32, width: Int32, height: Int32, confidence: Float) {
self.text = text
self.text_len = text_len
self.x = x
self.y = y
self.width = width
self.height = height
self.confidence = confidence
}
}

View File

@@ -12,3 +12,6 @@ thiserror = { workspace = true }
toml = "0.8"
shellexpand = "3.0"
dirs = "5.0"
[dev-dependencies]
tempfile = "3.8"

View File

@@ -6,6 +6,9 @@ use std::path::Path;
pub struct Config {
pub providers: ProvidersConfig,
pub agent: AgentConfig,
pub computer_control: ComputerControlConfig,
pub webdriver: WebDriverConfig,
pub macax: MacAxConfig,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -15,6 +18,8 @@ pub struct ProvidersConfig {
pub databricks: Option<DatabricksConfig>,
pub embedded: Option<EmbeddedConfig>,
pub default_provider: String,
pub coach: Option<String>, // Provider to use for coach in autonomous mode
pub player: Option<String>, // Provider to use for player in autonomous mode
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -62,6 +67,51 @@ pub struct AgentConfig {
pub timeout_seconds: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComputerControlConfig {
pub enabled: bool,
pub require_confirmation: bool,
pub max_actions_per_second: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebDriverConfig {
pub enabled: bool,
pub safari_port: u16,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MacAxConfig {
pub enabled: bool,
}
impl Default for MacAxConfig {
fn default() -> Self {
Self {
enabled: false,
}
}
}
impl Default for WebDriverConfig {
fn default() -> Self {
Self {
enabled: false,
safari_port: 4444,
}
}
}
impl Default for ComputerControlConfig {
fn default() -> Self {
Self {
enabled: false, // Disabled by default for safety
require_confirmation: true,
max_actions_per_second: 5,
}
}
}
impl Default for Config {
fn default() -> Self {
Self {
@@ -78,12 +128,17 @@ impl Default for Config {
}),
embedded: None,
default_provider: "databricks".to_string(),
coach: None, // Will use default_provider if not specified
player: None, // Will use default_provider if not specified
},
agent: AgentConfig {
max_context_length: 8192,
enable_streaming: true,
timeout_seconds: 60,
},
computer_control: ComputerControlConfig::default(),
webdriver: WebDriverConfig::default(),
macax: MacAxConfig::default(),
}
}
}
@@ -188,12 +243,17 @@ impl Config {
threads: Some(8),
}),
default_provider: "embedded".to_string(),
coach: None, // Will use default_provider if not specified
player: None, // Will use default_provider if not specified
},
agent: AgentConfig {
max_context_length: 8192,
enable_streaming: true,
timeout_seconds: 60,
},
computer_control: ComputerControlConfig::default(),
webdriver: WebDriverConfig::default(),
macax: MacAxConfig::default(),
}
}
@@ -262,4 +322,67 @@ impl Config {
Ok(config)
}
/// Get the provider to use for coach mode in autonomous execution
pub fn get_coach_provider(&self) -> &str {
self.providers.coach
.as_deref()
.unwrap_or(&self.providers.default_provider)
}
/// Get the provider to use for player mode in autonomous execution
pub fn get_player_provider(&self) -> &str {
self.providers.player
.as_deref()
.unwrap_or(&self.providers.default_provider)
}
/// Create a copy of the config with a different default provider
pub fn with_provider_override(&self, provider: &str) -> Result<Self> {
// Validate that the provider is configured
match provider {
"anthropic" if self.providers.anthropic.is_none() => {
return Err(anyhow::anyhow!(
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
provider, provider
));
}
"databricks" if self.providers.databricks.is_none() => {
return Err(anyhow::anyhow!(
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
provider, provider
));
}
"embedded" if self.providers.embedded.is_none() => {
return Err(anyhow::anyhow!(
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
provider, provider
));
}
"openai" if self.providers.openai.is_none() => {
return Err(anyhow::anyhow!(
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
provider, provider
));
}
_ => {} // Provider is configured or unknown (will be caught later)
}
let mut config = self.clone();
config.providers.default_provider = provider.to_string();
Ok(config)
}
/// Create a copy of the config for coach mode in autonomous execution
pub fn for_coach(&self) -> Result<Self> {
self.with_provider_override(self.get_coach_provider())
}
/// Create a copy of the config for player mode in autonomous execution
pub fn for_player(&self) -> Result<Self> {
self.with_provider_override(self.get_player_provider())
}
}
#[cfg(test)]
mod tests;

View File

@@ -0,0 +1,131 @@
#[cfg(test)]
mod tests {
use crate::Config;
use std::fs;
use tempfile::TempDir;
#[test]
fn test_coach_player_providers() {
// Create a temporary directory for the test config
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("test_config.toml");
// Write a test configuration with coach and player providers
let config_content = r#"
[providers]
default_provider = "databricks"
coach = "anthropic"
player = "embedded"
[providers.databricks]
host = "https://test.databricks.com"
token = "test-token"
model = "test-model"
[providers.anthropic]
api_key = "test-key"
model = "claude-3"
[providers.embedded]
model_path = "test.gguf"
model_type = "llama"
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
"#;
fs::write(&config_path, config_content).unwrap();
// Load the configuration
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
// Test that the providers are correctly identified
assert_eq!(config.providers.default_provider, "databricks");
assert_eq!(config.get_coach_provider(), "anthropic");
assert_eq!(config.get_player_provider(), "embedded");
// Test creating coach config
let coach_config = config.for_coach().unwrap();
assert_eq!(coach_config.providers.default_provider, "anthropic");
// Test creating player config
let player_config = config.for_player().unwrap();
assert_eq!(player_config.providers.default_provider, "embedded");
}
#[test]
fn test_coach_player_fallback_to_default() {
// Create a temporary directory for the test config
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("test_config.toml");
// Write a test configuration WITHOUT coach and player providers
let config_content = r#"
[providers]
default_provider = "databricks"
[providers.databricks]
host = "https://test.databricks.com"
token = "test-token"
model = "test-model"
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
"#;
fs::write(&config_path, config_content).unwrap();
// Load the configuration
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
// Test that coach and player fall back to default provider
assert_eq!(config.get_coach_provider(), "databricks");
assert_eq!(config.get_player_provider(), "databricks");
// Test creating coach config (should use default)
let coach_config = config.for_coach().unwrap();
assert_eq!(coach_config.providers.default_provider, "databricks");
// Test creating player config (should use default)
let player_config = config.for_player().unwrap();
assert_eq!(player_config.providers.default_provider, "databricks");
}
#[test]
fn test_invalid_provider_error() {
// Create a temporary directory for the test config
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("test_config.toml");
// Write a test configuration with an unconfigured provider
let config_content = r#"
[providers]
default_provider = "databricks"
coach = "openai" # OpenAI is not configured
[providers.databricks]
host = "https://test.databricks.com"
token = "test-token"
model = "test-model"
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
"#;
fs::write(&config_path, config_content).unwrap();
// Load the configuration
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
// Test that trying to create a coach config with unconfigured provider fails
let result = config.for_coach();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("not configured"));
}
}

View File

@@ -8,6 +8,7 @@ description = "Core engine for G3 AI coding agent"
g3-providers = { path = "../g3-providers" }
g3-config = { path = "../g3-config" }
g3-execution = { path = "../g3-execution" }
g3-computer-control = { path = "../g3-computer-control" }
tokio = { workspace = true }
reqwest = { workspace = true }
anyhow = { workspace = true }
@@ -23,3 +24,4 @@ futures-util = "0.3"
chrono = { version = "0.4", features = ["serde"] }
rand = "0.8"
regex = "1.0"
shellexpand = "3.1"

View File

@@ -156,15 +156,15 @@ pub fn fixed_filter_json_tool_calls(content: &str) -> String {
}
// No JSON tool call detected, return only the new content we haven't returned yet
let new_content = if state.buffer.len() > state.content_returned_up_to {
if state.buffer.len() > state.content_returned_up_to {
let result = state.buffer[state.content_returned_up_to..].to_string();
state.content_returned_up_to = state.buffer.len();
result
} else {
String::new()
};
new_content
}
})
}

File diff suppressed because it is too large Load Diff

View File

@@ -104,6 +104,7 @@ impl Project {
}
/// Recursively check a directory for implementation files
#[allow(clippy::only_used_in_recursion)]
fn check_dir_for_implementation_files(&self, dir: &Path) -> bool {
// Common source file extensions
let extensions = vec![

View File

@@ -0,0 +1,37 @@
// Test to verify take_screenshot requires window_id
#[cfg(test)]
mod take_screenshot_tests {
use super::*;
use serde_json::json;
#[test]
fn test_take_screenshot_requires_window_id() {
// Create a tool call without window_id
let tool_call = ToolCall {
tool: "take_screenshot".to_string(),
args: json!({
"path": "test.png"
}),
};
// Verify that window_id is missing
assert!(tool_call.args.get("window_id").is_none());
}
#[test]
fn test_take_screenshot_with_window_id() {
// Create a tool call with window_id
let tool_call = ToolCall {
tool: "take_screenshot".to_string(),
args: json!({
"path": "test.png",
"window_id": "Safari"
}),
};
// Verify that window_id is present
assert!(tool_call.args.get("window_id").is_some());
assert_eq!(tool_call.args.get("window_id").unwrap().as_str().unwrap(), "Safari");
}
}

View File

@@ -0,0 +1,36 @@
#[cfg(test)]
mod tilde_expansion_tests {
use std::env;
#[test]
fn test_tilde_expansion() {
// Test that shellexpand works
let path_with_tilde = "~/test.txt";
let expanded = shellexpand::tilde(path_with_tilde);
// Get the actual home directory
let home = env::var("HOME").expect("HOME environment variable not set");
// Verify expansion happened
assert_eq!(expanded.as_ref(), format!("{}/test.txt", home));
assert!(!expanded.contains("~"));
}
#[test]
fn test_tilde_expansion_with_subdirs() {
let path_with_tilde = "~/Documents/test.txt";
let expanded = shellexpand::tilde(path_with_tilde);
let home = env::var("HOME").expect("HOME environment variable not set");
assert_eq!(expanded.as_ref(), format!("{}/Documents/test.txt", home));
}
#[test]
fn test_no_tilde_unchanged() {
let path_without_tilde = "/absolute/path/test.txt";
let expanded = shellexpand::tilde(path_without_tilde);
assert_eq!(expanded.as_ref(), path_without_tilde);
}
}

View File

@@ -17,6 +17,9 @@ pub trait UiWriter: Send + Sync {
/// Print a context window status message
fn print_context_status(&self, message: &str);
/// Print a context thinning success message with highlight and animation
fn print_context_thinning(&self, message: &str);
/// Print a tool execution header
fn print_tool_header(&self, tool_name: &str);
@@ -49,6 +52,10 @@ pub trait UiWriter: Send + Sync {
/// Flush any buffered output
fn flush(&self);
/// Returns true if this UI writer wants full, untruncated output
/// Default is false (truncate for human readability)
fn wants_full_output(&self) -> bool { false }
}
/// A no-op implementation for when UI output is not needed
@@ -60,6 +67,7 @@ impl UiWriter for NullUiWriter {
fn print_inline(&self, _message: &str) {}
fn print_system_prompt(&self, _prompt: &str) {}
fn print_context_status(&self, _message: &str) {}
fn print_context_thinning(&self, _message: &str) {}
fn print_tool_header(&self, _tool_name: &str) {}
fn print_tool_arg(&self, _key: &str, _value: &str) {}
fn print_tool_output_header(&self) {}
@@ -71,4 +79,5 @@ impl UiWriter for NullUiWriter {
fn print_agent_response(&self, _content: &str) {}
fn notify_sse_received(&self) {}
fn flush(&self) {}
fn wants_full_output(&self) -> bool { false }
}

View File

@@ -0,0 +1,270 @@
use g3_core::ContextWindow;
use g3_providers::{Message, MessageRole};
#[test]
fn test_thinning_thresholds() {
let mut context = ContextWindow::new(10000);
// At 0%, should not thin
assert!(!context.should_thin());
// Simulate reaching 50% usage
context.used_tokens = 5000;
assert!(context.should_thin());
// After thinning at 50%, should not thin again until next threshold
context.last_thinning_percentage = 50;
assert!(!context.should_thin());
// At 60%, should thin again
context.used_tokens = 6000;
assert!(context.should_thin());
// After thinning at 60%, should not thin
context.last_thinning_percentage = 60;
assert!(!context.should_thin());
// At 70%, should thin
context.used_tokens = 7000;
assert!(context.should_thin());
// At 80%, should thin
context.last_thinning_percentage = 70;
context.used_tokens = 8000;
assert!(context.should_thin());
// After 80%, should not thin (compaction takes over)
context.last_thinning_percentage = 80;
context.used_tokens = 8500;
assert!(!context.should_thin());
}
#[test]
fn test_thin_context_basic() {
let mut context = ContextWindow::new(10000);
// Add some messages to the first third
for i in 0..9 {
if i % 2 == 0 {
context.add_message(Message {
role: MessageRole::Assistant,
content: format!("Assistant message {}", i),
});
} else {
// Add tool results with varying sizes
let content = if i == 1 {
// Large tool result (> 1000 chars)
format!("Tool result: {}", "x".repeat(1500))
} else if i == 3 {
// Another large tool result
format!("Tool result: {}", "y".repeat(2000))
} else {
// Small tool result (< 1000 chars)
format!("Tool result: small result {}", i)
};
context.add_message(Message {
role: MessageRole::User,
content,
});
}
}
// Trigger thinning at 50%
context.used_tokens = 5000;
let (summary, _chars_saved) = context.thin_context();
println!("Thinning summary: {}", summary);
// Should have thinned at least 1 large tool result in the first third
assert!(summary.contains("1 tool result"), "Summary was: {}", summary);
assert!(summary.contains("50%"));
// Check that the large tool results were replaced
let first_third_end = context.conversation_history.len() / 3;
for i in 0..first_third_end {
if let Some(msg) = context.conversation_history.get(i) {
if matches!(msg.role, MessageRole::User) && msg.content.starts_with("Tool result:") {
if msg.content.len() > 1000 {
panic!("Found un-thinned large tool result at index {}", i);
}
}
}
}
}
#[test]
fn test_thin_write_file_tool_calls() {
let mut context = ContextWindow::new(10000);
// Add some messages including a write_file tool call with large content
context.add_message(Message {
role: MessageRole::User,
content: "Please create a large file".to_string(),
});
// Add an assistant message with a write_file tool call containing large content
let large_content = "x".repeat(1500);
let tool_call_json = format!(
r#"{{"tool": "write_file", "args": {{"file_path": "test.txt", "content": "{}"}}}}"#,
large_content
);
context.add_message(Message {
role: MessageRole::Assistant,
content: format!("I'll create that file.\n\n{}", tool_call_json),
});
context.add_message(Message {
role: MessageRole::User,
content: "Tool result: ✅ Successfully wrote 1500 lines".to_string(),
});
// Add more messages to ensure we have enough for "first third" logic
for i in 0..6 {
context.add_message(Message {
role: MessageRole::Assistant,
content: format!("Response {}", i),
});
}
// Trigger thinning at 50%
context.used_tokens = 5000;
let (summary, _chars_saved) = context.thin_context();
println!("Thinning summary: {}", summary);
// Should have thinned the write_file tool call
assert!(summary.contains("tool call") || summary.contains("chars saved"));
// Check that the large content was replaced with a file reference
let first_third_end = context.conversation_history.len() / 3;
for i in 0..first_third_end {
if let Some(msg) = context.conversation_history.get(i) {
if matches!(msg.role, MessageRole::Assistant) && msg.content.contains("write_file") {
// The content should now reference an external file
assert!(msg.content.contains("<content saved to"));
assert!(!msg.content.contains(&large_content));
}
}
}
}
#[test]
fn test_thin_str_replace_tool_calls() {
let mut context = ContextWindow::new(10000);
// Add some messages including a str_replace tool call with large diff
context.add_message(Message {
role: MessageRole::User,
content: "Please update the file".to_string(),
});
// Add an assistant message with a str_replace tool call containing large diff
let large_diff = format!("--- old\n{}\n+++ new\n{}", "-old line\n".repeat(100), "+new line\n".repeat(100));
let tool_call_json = format!(
r#"{{"tool": "str_replace", "args": {{"file_path": "test.txt", "diff": "{}"}}}}"#,
large_diff.replace('\n', "\\n")
);
context.add_message(Message {
role: MessageRole::Assistant,
content: format!("I'll update that file.\n\n{}", tool_call_json),
});
context.add_message(Message {
role: MessageRole::User,
content: "Tool result: ✅ applied unified diff".to_string(),
});
// Add more messages to ensure we have enough for "first third" logic
for i in 0..6 {
context.add_message(Message {
role: MessageRole::Assistant,
content: format!("Response {}", i),
});
}
// Trigger thinning at 50%
context.used_tokens = 5000;
let (summary, _chars_saved) = context.thin_context();
println!("Thinning summary: {}", summary);
// Should have thinned the str_replace tool call
assert!(summary.contains("tool call") || summary.contains("chars saved"));
// Check that the large diff was replaced with a file reference
let first_third_end = context.conversation_history.len() / 3;
for i in 0..first_third_end {
if let Some(msg) = context.conversation_history.get(i) {
if matches!(msg.role, MessageRole::Assistant) && msg.content.contains("str_replace") {
// The diff should now reference an external file
assert!(msg.content.contains("<diff saved to"));
// Should not contain the large diff content
assert!(!msg.content.contains("old line"));
}
}
}
}
#[test]
fn test_thin_context_no_large_results() {
let mut context = ContextWindow::new(10000);
// Add only small messages
for i in 0..9 {
context.add_message(Message {
role: MessageRole::User,
content: format!("Tool result: small {}", i),
});
}
context.used_tokens = 5000;
let (summary, _chars_saved) = context.thin_context();
// Should report no large results found
assert!(summary.contains("no large tool results or tool calls found"));
}
#[test]
fn test_thin_context_only_affects_first_third() {
let mut context = ContextWindow::new(10000);
// Add 12 messages (first third = 4 messages)
for i in 0..12 {
let content = if i % 2 == 1 {
// All odd indices are large tool results
format!("Tool result: {}", "x".repeat(1500))
} else {
format!("Assistant message {}", i)
};
let role = if i % 2 == 1 {
MessageRole::User
} else {
MessageRole::Assistant
};
context.add_message(Message { role, content });
}
context.used_tokens = 5000;
let (summary, _chars_saved) = context.thin_context();
// First third is 4 messages (indices 0-3), so only indices 1 and 3 should be thinned
// That's 2 tool results
assert!(summary.contains("2 tool results"));
// Check that messages after the first third are NOT thinned
let first_third_end = context.conversation_history.len() / 3;
for i in first_third_end..context.conversation_history.len() {
if let Some(msg) = context.conversation_history.get(i) {
if matches!(msg.role, MessageRole::User) && msg.content.starts_with("Tool result:") {
// These should still be large (not thinned)
if i % 2 == 1 {
assert!(msg.content.len() > 1000,
"Message at index {} should not have been thinned", i);
}
}
}
}
}

View File

@@ -166,6 +166,31 @@ impl CodeExecutor {
/// Execute Bash code
async fn execute_bash(&self, code: &str) -> Result<ExecutionResult> {
// Check if this is a detached/daemon command that should run independently
let is_detached = code.trim_start().starts_with("setsid ")
|| code.trim_start().starts_with("nohup ")
|| code.contains(" disown")
|| (code.contains(" &") && (code.contains("nohup") || code.contains("setsid")));
if is_detached {
// For detached commands, just spawn and return immediately
use std::process::Stdio;
Command::new("bash")
.arg("-c")
.arg(code)
.stdin(Stdio::null())
.stdout(Stdio::null())
.stderr(Stdio::null())
.spawn()?;
return Ok(ExecutionResult {
stdout: "✅ Command launched in background (detached process)".to_string(),
stderr: String::new(),
exit_code: 0,
success: true,
});
}
let output = Command::new("bash")
.arg("-c")
.arg(code)
@@ -221,6 +246,29 @@ impl CodeExecutor {
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command as TokioCommand;
// Check if this is a detached/daemon command that should run independently
// Look for patterns like: setsid, nohup with &, or explicit backgrounding with disown
let is_detached = code.trim_start().starts_with("setsid ")
|| code.trim_start().starts_with("nohup ")
|| code.contains(" disown")
|| (code.contains(" &") && (code.contains("nohup") || code.contains("setsid")));
if is_detached {
// For detached commands, just spawn and return immediately
TokioCommand::new("bash")
.arg("-c")
.arg(code)
.spawn()?;
// Don't wait for the process - it's meant to run independently
return Ok(ExecutionResult {
stdout: "✅ Command launched in background (detached process)".to_string(),
stderr: String::new(),
exit_code: 0,
success: true,
});
}
let mut child = TokioCommand::new("bash")
.arg("-c")
.arg(code)
@@ -259,7 +307,7 @@ impl CodeExecutor {
line = stderr_lines.next_line() => {
match line {
Ok(Some(line)) => {
receiver.on_output_line(&format!("{}", line));
receiver.on_output_line(&line.to_string());
stderr_output.push(line);
}
Ok(None) => {}, // stderr EOF, continue

View File

@@ -156,8 +156,9 @@ impl AnthropicProvider {
.post(ANTHROPIC_API_URL)
.header("x-api-key", &self.api_key)
.header("anthropic-version", ANTHROPIC_VERSION)
// Anthropic beta 1m context window. Enable if needed. It costs extra, so check first.
// .header("anthropic-beta", "context-1m-2025-08-07")
.header("content-type", "application/json");
if streaming {
builder = builder.header("accept", "text/event-stream");
}
@@ -275,6 +276,7 @@ impl AnthropicProvider {
let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
let mut accumulated_usage: Option<Usage> = None;
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
let mut actual_completion_tokens: u32 = 0; // Track actual completion tokens
while let Some(chunk_result) = stream.next().await {
match chunk_result {
@@ -322,7 +324,12 @@ impl AnthropicProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
usage: accumulated_usage.as_ref().map(|u| Usage {
prompt_tokens: u.prompt_tokens,
// Use actual completion tokens if we tracked them, otherwise use the estimate
completion_tokens: if actual_completion_tokens > 0 { actual_completion_tokens } else { u.completion_tokens },
total_tokens: u.prompt_tokens + if actual_completion_tokens > 0 { actual_completion_tokens } else { u.completion_tokens },
}),
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
};
if tx.send(Ok(final_chunk)).await.is_err() {
@@ -336,6 +343,7 @@ impl AnthropicProvider {
match serde_json::from_str::<AnthropicStreamEvent>(data) {
Ok(event) => {
debug!("Parsed event type: {}, event: {:?}", event.event_type, event);
match event.event_type.as_str() {
"message_start" => {
// Extract usage data from message_start event
@@ -346,7 +354,10 @@ impl AnthropicProvider {
completion_tokens: usage.output_tokens,
total_tokens: usage.input_tokens + usage.output_tokens,
});
debug!("Captured usage from message_start: {:?}", accumulated_usage);
debug!("Captured initial usage from message_start - prompt: {}, completion: {} (estimated), total: {}",
usage.input_tokens,
usage.output_tokens,
usage.input_tokens + usage.output_tokens);
}
}
}
@@ -395,6 +406,9 @@ impl AnthropicProvider {
"content_block_delta" => {
if let Some(delta) = event.delta {
if let Some(text) = delta.text {
// Track actual completion tokens (rough estimate: 4 chars per token)
actual_completion_tokens += (text.len() as f32 / 4.0).ceil() as u32;
debug!("Sending text chunk of length {}: '{}'", text.len(), text);
let chunk = CompletionChunk {
content: text,
@@ -415,6 +429,19 @@ impl AnthropicProvider {
}
}
}
"message_delta" => {
// Check if message_delta contains updated usage data
if let Some(delta) = event.delta {
if let Some(usage) = delta.usage {
accumulated_usage = Some(Usage {
prompt_tokens: usage.input_tokens,
completion_tokens: usage.output_tokens,
total_tokens: usage.input_tokens + usage.output_tokens,
});
debug!("Updated usage from message_delta - prompt: {}, completion: {}, total: {}", usage.input_tokens, usage.output_tokens, usage.input_tokens + usage.output_tokens);
}
}
}
"content_block_stop" => {
// Tool call block is complete - now parse the accumulated JSON
if !current_tool_calls.is_empty() && !partial_tool_json.is_empty() {
@@ -449,11 +476,44 @@ impl AnthropicProvider {
}
}
"message_stop" => {
debug!("Received message stop event");
debug!("Received message_stop event: {:?}", event);
// Check if message_stop contains final usage data
if let Some(message) = event.message {
if let Some(usage) = message.usage {
// Update with final accurate usage data from message_stop
// This should have the actual completion token count
accumulated_usage = Some(Usage {
prompt_tokens: usage.input_tokens,
// Prefer the actual output_tokens from message_stop if available
// Otherwise use our tracked count, and as last resort the initial estimate
completion_tokens: if usage.output_tokens > 0 {
usage.output_tokens
} else if actual_completion_tokens > 0 {
actual_completion_tokens
} else { usage.output_tokens },
total_tokens: usage.input_tokens + usage.output_tokens,
});
debug!("Updated with final usage from message_stop - prompt: {}, completion: {}, total: {}",
usage.input_tokens,
usage.output_tokens,
usage.input_tokens + usage.output_tokens);
}
}
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
usage: accumulated_usage.as_ref().map(|u| Usage {
prompt_tokens: u.prompt_tokens,
// Use actual completion tokens if we tracked them and they're higher
completion_tokens: if actual_completion_tokens > u.completion_tokens {
actual_completion_tokens
} else {
u.completion_tokens
},
total_tokens: u.prompt_tokens + u32::max(actual_completion_tokens, u.completion_tokens),
}),
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
};
if tx.send(Ok(final_chunk)).await.is_err() {
@@ -495,10 +555,27 @@ impl AnthropicProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
usage: accumulated_usage.as_ref().map(|u| Usage {
prompt_tokens: u.prompt_tokens,
completion_tokens: if actual_completion_tokens > u.completion_tokens {
actual_completion_tokens
} else {
u.completion_tokens
},
total_tokens: u.prompt_tokens + u32::max(actual_completion_tokens, u.completion_tokens),
}),
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls) },
};
let _ = tx.send(Ok(final_chunk)).await;
// Log final usage for debugging
if let Some(ref usage) = accumulated_usage {
info!("Anthropic stream completed with final usage - prompt: {}, completion: {}, total: {}",
usage.prompt_tokens, usage.completion_tokens, usage.total_tokens);
} else {
warn!("Anthropic stream completed without usage data - token accounting will fall back to estimation");
}
accumulated_usage
}
}
@@ -736,6 +813,8 @@ struct AnthropicStreamMessage {
struct AnthropicDelta {
text: Option<String>,
partial_json: Option<String>,
#[serde(default)]
usage: Option<AnthropicUsage>,
}
#[derive(Debug, Deserialize)]

View File

@@ -213,7 +213,7 @@ impl DatabricksProvider {
let mut builder = self
.client
.post(&format!(
.post(format!(
"{}/serving-endpoints/{}/invocations",
self.host, self.model
))
@@ -882,6 +882,14 @@ impl LLMProvider for DatabricksProvider {
request.messages.len()
);
// Debug: Log tool count
if let Some(ref tools) = request.tools {
debug!("Request has {} tools", tools.len());
for tool in tools.iter().take(5) {
debug!(" Tool: {}", tool.name);
}
}
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);

View File

@@ -88,10 +88,12 @@ pub mod anthropic;
pub mod databricks;
pub mod embedded;
pub mod oauth;
pub mod openai;
pub use anthropic::AnthropicProvider;
pub use databricks::DatabricksProvider;
pub use embedded::EmbeddedProvider;
pub use openai::OpenAIProvider;
/// Provider registry for managing multiple LLM providers
pub struct ProviderRegistry {

View File

@@ -102,7 +102,7 @@ async fn get_workspace_endpoints(host: &str) -> Result<OidcEndpoints> {
if !resp.status().is_success() {
return Err(anyhow::anyhow!(
"Failed to get OIDC configuration from {}",
oidc_url.to_string()
oidc_url
));
}

View File

@@ -0,0 +1,495 @@
use anyhow::Result;
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::stream::StreamExt;
use reqwest::Client;
use serde::Deserialize;
use serde_json::json;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error};
use crate::{
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider,
Message, MessageRole, Tool, ToolCall, Usage,
};
#[derive(Clone)]
pub struct OpenAIProvider {
client: Client,
api_key: String,
model: String,
base_url: String,
max_tokens: Option<u32>,
_temperature: Option<f32>,
}
impl OpenAIProvider {
pub fn new(
api_key: String,
model: Option<String>,
base_url: Option<String>,
max_tokens: Option<u32>,
temperature: Option<f32>,
) -> Result<Self> {
Ok(Self {
client: Client::new(),
api_key,
model: model.unwrap_or_else(|| "gpt-4o".to_string()),
base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
max_tokens,
_temperature: temperature,
})
}
fn create_request_body(
&self,
messages: &[Message],
tools: Option<&[Tool]>,
stream: bool,
max_tokens: Option<u32>,
_temperature: Option<f32>,
) -> serde_json::Value {
let mut body = json!({
"model": self.model,
"messages": convert_messages(messages),
"stream": stream,
});
if let Some(max_tokens) = max_tokens.or(self.max_tokens) {
body["max_completion_tokens"] = json!(max_tokens);
}
// OpenAI calls with temp setting seem to fail, so don't send one.
// if let Some(temperature) = temperature.or(self.temperature) {
// body["temperature"] = json!(temperature);
// }
if let Some(tools) = tools {
if !tools.is_empty() {
body["tools"] = json!(convert_tools(tools));
}
}
if stream {
body["stream_options"] = json!({
"include_usage": true,
});
}
body
}
async fn parse_streaming_response(
&self,
mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
tx: mpsc::Sender<Result<CompletionChunk>>,
) -> Option<Usage> {
let mut buffer = String::new();
let mut accumulated_content = String::new();
let mut accumulated_usage: Option<Usage> = None;
let mut current_tool_calls: Vec<OpenAIStreamingToolCall> = Vec::new();
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
let chunk_str = match std::str::from_utf8(&chunk) {
Ok(s) => s,
Err(e) => {
error!("Failed to parse chunk as UTF-8: {}", e);
continue;
}
};
buffer.push_str(chunk_str);
// Process complete lines
while let Some(line_end) = buffer.find('\n') {
let line = buffer[..line_end].trim().to_string();
buffer.drain(..line_end + 1);
if line.is_empty() {
continue;
}
// Parse Server-Sent Events format
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
debug!("Received stream completion marker");
// Send final chunk with accumulated content and tool calls
if !accumulated_content.is_empty() || !current_tool_calls.is_empty() {
let tool_calls = if current_tool_calls.is_empty() {
None
} else {
Some(
current_tool_calls
.iter()
.filter_map(|tc| tc.to_tool_call())
.collect(),
)
};
let final_chunk = CompletionChunk {
content: accumulated_content.clone(),
finished: true,
tool_calls,
usage: accumulated_usage.clone(),
};
let _ = tx.send(Ok(final_chunk)).await;
}
return accumulated_usage;
}
// Parse the JSON data
match serde_json::from_str::<OpenAIStreamChunk>(data) {
Ok(chunk_data) => {
// Handle content
for choice in &chunk_data.choices {
if let Some(content) = &choice.delta.content {
accumulated_content.push_str(content);
let chunk = CompletionChunk {
content: content.clone(),
finished: false,
tool_calls: None,
usage: None,
};
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return accumulated_usage;
}
}
// Handle tool calls
if let Some(delta_tool_calls) = &choice.delta.tool_calls {
for delta_tool_call in delta_tool_calls {
if let Some(index) = delta_tool_call.index {
// Ensure we have enough tool calls in our vector
while current_tool_calls.len() <= index {
current_tool_calls
.push(OpenAIStreamingToolCall::default());
}
let tool_call = &mut current_tool_calls[index];
if let Some(id) = &delta_tool_call.id {
tool_call.id = Some(id.clone());
}
if let Some(function) = &delta_tool_call.function {
if let Some(name) = &function.name {
tool_call.name = Some(name.clone());
}
if let Some(arguments) = &function.arguments {
tool_call.arguments.push_str(arguments);
}
}
}
}
}
}
// Handle usage
if let Some(usage) = chunk_data.usage {
accumulated_usage = Some(Usage {
prompt_tokens: usage.prompt_tokens,
completion_tokens: usage.completion_tokens,
total_tokens: usage.total_tokens,
});
}
}
Err(e) => {
debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
}
}
}
}
}
Err(e) => {
error!("Stream error: {}", e);
let _ = tx.send(Err(anyhow::anyhow!("Stream error: {}", e))).await;
return accumulated_usage;
}
}
}
// Send final chunk if we haven't already
let tool_calls = if current_tool_calls.is_empty() {
None
} else {
Some(
current_tool_calls
.iter()
.filter_map(|tc| tc.to_tool_call())
.collect(),
)
};
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
tool_calls,
usage: accumulated_usage.clone(),
};
let _ = tx.send(Ok(final_chunk)).await;
accumulated_usage
}
}
#[async_trait]
impl LLMProvider for OpenAIProvider {
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
debug!(
"Processing OpenAI completion request with {} messages",
request.messages.len()
);
let body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
false,
request.max_tokens,
request.temperature,
);
debug!("Sending request to OpenAI API: model={}", self.model);
let response = self
.client
.post(format!("{}/chat/completions", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text));
}
let openai_response: OpenAIResponse = response.json().await?;
let content = openai_response
.choices
.first()
.and_then(|choice| choice.message.content.clone())
.unwrap_or_default();
let usage = Usage {
prompt_tokens: openai_response.usage.prompt_tokens,
completion_tokens: openai_response.usage.completion_tokens,
total_tokens: openai_response.usage.total_tokens,
};
debug!(
"OpenAI completion successful: {} tokens generated",
usage.completion_tokens
);
Ok(CompletionResponse {
content,
usage,
model: self.model.clone(),
})
}
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
debug!(
"Processing OpenAI streaming request with {} messages",
request.messages.len()
);
let body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
true,
request.max_tokens,
request.temperature,
);
debug!("Sending streaming request to OpenAI API: model={}", self.model);
let response = self
.client
.post(format!("{}/chat/completions", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text));
}
let stream = response.bytes_stream();
let (tx, rx) = mpsc::channel(100);
// Spawn task to process the stream
let provider = self.clone();
tokio::spawn(async move {
let usage = provider.parse_streaming_response(stream, tx).await;
// Log the final usage if available
if let Some(usage) = usage {
debug!(
"Stream completed with usage - prompt: {}, completion: {}, total: {}",
usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
);
}
});
Ok(ReceiverStream::new(rx))
}
fn name(&self) -> &str {
"openai"
}
fn model(&self) -> &str {
&self.model
}
fn has_native_tool_calling(&self) -> bool {
// OpenAI models support native tool calling
true
}
}
fn convert_messages(messages: &[Message]) -> Vec<serde_json::Value> {
messages
.iter()
.map(|msg| {
json!({
"role": match msg.role {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
},
"content": msg.content,
})
})
.collect()
}
fn convert_tools(tools: &[Tool]) -> Vec<serde_json::Value> {
tools
.iter()
.map(|tool| {
json!({
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.input_schema,
}
})
})
.collect()
}
// OpenAI API response structures
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
choices: Vec<OpenAIChoice>,
usage: OpenAIUsage,
}
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
message: OpenAIMessage,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct OpenAIMessage {
content: Option<String>,
#[serde(default)]
tool_calls: Option<Vec<OpenAIToolCall>>,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct OpenAIToolCall {
id: String,
function: OpenAIFunction,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct OpenAIFunction {
name: String,
arguments: String,
}
// Streaming tool call accumulator
#[derive(Debug, Default)]
struct OpenAIStreamingToolCall {
id: Option<String>,
name: Option<String>,
arguments: String,
}
impl OpenAIStreamingToolCall {
fn to_tool_call(&self) -> Option<ToolCall> {
let id = self.id.as_ref()?;
let name = self.name.as_ref()?;
let args = serde_json::from_str(&self.arguments).unwrap_or(serde_json::Value::Null);
Some(ToolCall {
id: id.clone(),
tool: name.clone(),
args,
})
}
}
#[derive(Debug, Deserialize)]
struct OpenAIUsage {
prompt_tokens: u32,
completion_tokens: u32,
total_tokens: u32,
}
// Streaming response structures
#[derive(Debug, Deserialize)]
struct OpenAIStreamChunk {
choices: Vec<OpenAIStreamChoice>,
usage: Option<OpenAIUsage>,
}
#[derive(Debug, Deserialize)]
struct OpenAIStreamChoice {
delta: OpenAIDelta,
}
#[derive(Debug, Deserialize)]
struct OpenAIDelta {
content: Option<String>,
#[serde(default)]
tool_calls: Option<Vec<OpenAIDeltaToolCall>>,
}
#[derive(Debug, Deserialize)]
struct OpenAIDeltaToolCall {
index: Option<usize>,
id: Option<String>,
function: Option<OpenAIDeltaFunction>,
}
#[derive(Debug, Deserialize)]
struct OpenAIDeltaFunction {
name: Option<String>,
arguments: Option<String>,
}

39
test-ai-requirements.sh Executable file
View File

@@ -0,0 +1,39 @@
#!/bin/bash
# Test script for AI-enhanced interactive requirements mode
echo "Testing AI-enhanced interactive requirements mode..."
echo ""
# Create a test workspace
TEST_WORKSPACE="/tmp/g3-test-interactive-$(date +%s)"
mkdir -p "$TEST_WORKSPACE"
echo "Test workspace: $TEST_WORKSPACE"
echo ""
# Create sample brief input
BRIEF_INPUT="build a calculator cli in rust with basic operations"
echo "Brief input:"
echo "---"
echo "$BRIEF_INPUT"
echo "---"
echo ""
echo "This will:"
echo "1. Send brief input to AI"
echo "2. AI generates structured requirements.md"
echo "3. Show enhanced requirements"
echo "4. Prompt for confirmation (y/e/n)"
echo ""
echo "To test manually, run:"
echo "cargo run -- --autonomous --interactive-requirements --workspace $TEST_WORKSPACE"
echo ""
echo "Then type: $BRIEF_INPUT"
echo "Press Ctrl+D"
echo "Review the AI-generated requirements"
echo "Choose 'y' to proceed, 'e' to edit, or 'n' to cancel"
echo ""
echo "Test workspace will be at: $TEST_WORKSPACE"

164
test_token_accounting.py Normal file
View File

@@ -0,0 +1,164 @@
#!/usr/bin/env python3
"""
Test script to verify token accounting is working correctly with the Anthropic provider.
This script will send multiple messages and verify that token counts accumulate properly.
"""
import subprocess
import json
import re
import sys
import time
def run_g3_command(prompt, provider="anthropic"):
"""Run a g3 command and capture the output."""
cmd = [
"cargo", "run", "--release", "--",
"--provider", provider,
prompt
]
env = {
"RUST_LOG": "g3_providers=debug,g3_core=info",
"RUST_BACKTRACE": "1"
}
result = subprocess.run(
cmd,
capture_output=True,
text=True,
env={**subprocess.os.environ, **env}
)
return result.stdout + result.stderr
def extract_token_info(output):
"""Extract token usage information from the output."""
token_info = {}
# Look for token usage updates
usage_pattern = r"Updated token usage.*was: (\d+), now: (\d+).*prompt=(\d+), completion=(\d+), total=(\d+)"
matches = re.findall(usage_pattern, output)
if matches:
last_match = matches[-1]
token_info['was'] = int(last_match[0])
token_info['now'] = int(last_match[1])
token_info['prompt'] = int(last_match[2])
token_info['completion'] = int(last_match[3])
token_info['total'] = int(last_match[4])
# Look for context percentage
context_pattern = r"Context usage at (\d+)%.*\((\d+)/(\d+) tokens\)"
matches = re.findall(context_pattern, output)
if matches:
last_match = matches[-1]
token_info['percentage'] = int(last_match[0])
token_info['used'] = int(last_match[1])
token_info['total_context'] = int(last_match[2])
# Look for thinning triggers
thinning_pattern = r"Context thinning triggered.*usage: (\d+)%.*\((\d+)/(\d+) tokens\)"
matches = re.findall(thinning_pattern, output)
if matches:
token_info['thinning_triggered'] = True
token_info['thinning_percentage'] = int(matches[-1][0])
# Look for final usage from Anthropic
final_usage_pattern = r"Anthropic stream completed with final usage.*prompt: (\d+), completion: (\d+), total: (\d+)"
matches = re.findall(final_usage_pattern, output)
if matches:
last_match = matches[-1]
token_info['final_prompt'] = int(last_match[0])
token_info['final_completion'] = int(last_match[1])
token_info['final_total'] = int(last_match[2])
return token_info
def main():
print("Testing Anthropic Provider Token Accounting")
print("="*50)
# Build the project first
print("Building project...")
subprocess.run(["cargo", "build", "--release"], capture_output=True)
# Test 1: Simple prompt
print("\nTest 1: Simple prompt")
print("-"*30)
output = run_g3_command("Say 'Hello, World!' and nothing else.")
tokens = extract_token_info(output)
if tokens:
print(f"Token usage: {tokens.get('now', 'N/A')} tokens")
print(f" Prompt tokens: {tokens.get('prompt', 'N/A')}")
print(f" Completion tokens: {tokens.get('completion', 'N/A')}")
print(f" Total from provider: {tokens.get('total', 'N/A')}")
if 'final_total' in tokens:
print(f" Final total from stream: {tokens['final_total']}")
if tokens.get('now') != tokens['final_total']:
print(f" ⚠️ WARNING: Mismatch between tracked ({tokens.get('now')}) and final ({tokens['final_total']})")
# Check if the completion tokens are reasonable (should be small for "Hello, World!")
if tokens.get('completion', 0) > 50:
print(f" ⚠️ WARNING: Completion tokens seem high for a simple response: {tokens.get('completion')}")
else:
print(" ❌ No token information found in output")
# Test 2: Longer response
print("\nTest 2: Longer response")
print("-"*30)
output = run_g3_command("Write a 3-paragraph essay about the importance of accurate token counting in LLM applications.")
tokens = extract_token_info(output)
if tokens:
print(f"Token usage: {tokens.get('now', 'N/A')} tokens")
print(f" Prompt tokens: {tokens.get('prompt', 'N/A')}")
print(f" Completion tokens: {tokens.get('completion', 'N/A')}")
print(f" Total from provider: {tokens.get('total', 'N/A')}")
if 'final_total' in tokens:
print(f" Final total from stream: {tokens['final_total']}")
if tokens.get('now') != tokens['final_total']:
print(f" ⚠️ WARNING: Mismatch between tracked ({tokens.get('now')}) and final ({tokens['final_total']})")
# Check if completion tokens are reasonable for a longer response
if tokens.get('completion', 0) < 100:
print(f" ⚠️ WARNING: Completion tokens seem low for a 3-paragraph essay: {tokens.get('completion')}")
else:
print(" ❌ No token information found in output")
# Test 3: Check for proper accumulation
print("\nTest 3: Token accumulation (multiple messages)")
print("-"*30)
# First message
output1 = run_g3_command("Count from 1 to 5.")
tokens1 = extract_token_info(output1)
# Second message (this would need to be in a conversation, but for now we test separately)
output2 = run_g3_command("Now count from 6 to 10.")
tokens2 = extract_token_info(output2)
if tokens1 and tokens2:
print(f"First message: {tokens1.get('now', 'N/A')} tokens")
print(f"Second message: {tokens2.get('now', 'N/A')} tokens")
# In a real conversation, tokens2['now'] should be greater than tokens1['now']
# But since these are separate invocations, we just check they're both reasonable
if tokens1.get('now', 0) > 0 and tokens2.get('now', 0) > 0:
print(" ✅ Both messages have token counts")
else:
print(" ❌ Missing token counts")
print("\n" + "="*50)
print("Test Summary:")
print("Check the output above for any warnings or errors.")
print("Key things to verify:")
print(" 1. Token counts are being captured from the provider")
print(" 2. Completion tokens are reasonable for the response length")
print(" 3. No mismatch between tracked and final token counts")
print(" 4. Context thinning triggers at appropriate thresholds")
if __name__ == "__main__":
main()

46
test_token_accounting.sh Executable file
View File

@@ -0,0 +1,46 @@
#!/bin/bash
# Test script to verify token accounting with Anthropic provider
echo "Testing token accounting with Anthropic provider..."
echo "This test will send a few messages and check if token counts are properly tracked."
echo ""
# Set up environment for testing
export RUST_LOG=g3_providers=debug,g3_core=info
export RUST_BACKTRACE=1
# Build the project first
echo "Building project..."
cargo build --release 2>&1 | grep -E "(Compiling|Finished)" || true
echo ""
echo "Running test with Anthropic provider..."
echo "Watch for these log messages:"
echo " - 'Captured initial usage from message_start'"
echo " - 'Updated usage from message_delta' (if available)"
echo " - 'Updated with final usage from message_stop' (if available)"
echo " - 'Anthropic stream completed with final usage'"
echo " - 'Updated token usage from provider'"
echo " - 'Context thinning triggered' (when reaching thresholds)"
echo ""
# Create a simple test that will generate some tokens
cat << 'EOF' > /tmp/test_prompt.txt
Please write a short paragraph about the importance of accurate token counting in LLM applications. Then list 3 reasons why token accounting might fail.
EOF
# Run the test
echo "Sending test prompt..."
cargo run --release -- --provider anthropic "$(cat /tmp/test_prompt.txt)" 2>&1 | tee /tmp/token_test.log
echo ""
echo "Analyzing results..."
echo ""
# Check for token accounting messages
echo "Token accounting messages found:"
grep -E "(usage from|token usage|Context thinning|Context usage)" /tmp/token_test.log | head -20
echo ""
echo "Test complete. Check /tmp/token_test.log for full output."