Compare commits

..

183 Commits

Author SHA1 Message Date
Michael Neale
79b375519b use proper context for qwen3-coder 2025-11-05 13:55:45 +11:00
Michael Neale
88c3cc23fe works better without streaming 2025-11-05 12:55:21 +11:00
Michael Neale
4622507f37 gpt context aware 2025-11-05 12:25:02 +11:00
Michael Neale
217df2f2af ollama support 2025-11-05 12:17:01 +11:00
Dhanji Prasanna
22a0090cdc fix unexpected EOF on streams 2025-11-04 16:28:41 +11:00
Dhanji Prasanna
631f3c16ca compact on tool call if > 90% 2025-11-04 14:35:11 +11:00
Dhanji Prasanna
1f9fef5f18 more json filtering 2025-11-03 11:56:16 +11:00
Dhanji Prasanna
57d473c19d mild json filtering improvement 2025-11-03 11:54:27 +11:00
Jochen
e59ce2f93f Merge pull request #16 from dhanji/jochen-ast-tool
adds ast-grep tool for faster code exploration
2025-11-02 21:04:11 +11:00
Jochen
a1ad94ed75 Added comment & example for native flow
detailed examples for using code_search tool for native tool use.
2025-11-02 21:02:43 +11:00
Jochen
982c0bbfb3 amend instructions for tool use 2025-11-01 15:52:08 +11:00
Jochen
ad9ba5e5d8 added ast-grep use
g3 tool use of ast-grep command with batching for faster code exploration.
2025-11-01 14:59:55 +11:00
Dhanji Prasanna
f89bbfc89a fix final_output bug 2025-10-31 14:48:36 +11:00
Dhanji Prasanna
11eb01e04d add back progress bar to cli 2025-10-31 14:28:35 +11:00
Dhanji Prasanna
bdaacfd051 fix for duplicate messages at end 2025-10-31 13:34:36 +11:00
Dhanji R. Prasanna
92ae776510 Merge pull request #14 from dhanji/micn/always-coach-player
always coach player
2025-10-31 12:52:17 +11:00
Michael Neale
c42e0bce54 use --auto flag 2025-10-31 11:42:42 +11:00
Michael Neale
b529d7f814 add ability to use slash commands, and also go back to chat in context of player auto mode 2025-10-29 18:09:13 +11:00
Michael Neale
9752e81489 cleanup 2025-10-29 14:53:10 +11:00
Michael Neale
63c2aff7ba clearer 2025-10-29 14:47:25 +11:00
Michael Neale
aa4a0267ea can interrupt now 2025-10-29 13:29:03 +11:00
Michael Neale
6cfa1e225c can cancell acc mode 2025-10-29 13:13:41 +11:00
Michael Neale
f53cd8e8f3 requirements always 2025-10-29 13:09:15 +11:00
Michael Neale
45bffc40da coach player always when starting 2025-10-29 13:04:16 +11:00
Dhanji R. Prasanna
4bf0f71bbd Merge pull request #12 from dhanji/libvision
will need this for it to work
2025-10-28 15:12:51 +11:00
Michael Neale
c1ce3038d8 will need this for it to work 2025-10-28 15:07:24 +11:00
Dhanji Prasanna
4b1694b308 machine mode 2025-10-28 14:51:32 +11:00
Dhanji Prasanna
5e08d6bbba --machine mode flag for verbose CLI output 2025-10-27 10:37:05 +11:00
Dhanji Prasanna
c3f3f79dc5 fixed x,y detection in vision click 2025-10-25 16:51:27 +11:00
Dhanji Prasanna
834153ea69 screenshotting bug fix 2025-10-24 20:40:43 +11:00
Dhanji Prasanna
65f25f840e test 2025-10-24 16:11:24 +11:00
Dhanji Prasanna
a8af5d7cc1 Native api for screen capture 2025-10-24 16:11:12 +11:00
Dhanji Prasanna
61d748034d replace tesseract with apple vision 2025-10-24 15:35:47 +11:00
Dhanji Prasanna
d0ac222e2e more macax tooling 2025-10-24 10:45:24 +11:00
Dhanji Prasanna
e1e732150a coach rigor +++ 2025-10-24 10:15:42 +11:00
Dhanji Prasanna
0be4829ca9 thinning message highlighted 2025-10-23 13:16:13 +11:00
Dhanji Prasanna
efd4eca755 warnings fix 2025-10-23 07:17:55 +11:00
Dhanji Prasanna
3ec65e38ee macax tools 2025-10-23 06:53:42 +11:00
Dhanji Prasanna
c5d6fbef08 control commands 2025-10-22 22:14:12 +11:00
Dhanji R. Prasanna
f93844d378 Merge pull request #10 from dhanji/micn/interactive-requirements
Add --interactive-requirements flag for AI-enhanced requirements mode
2025-10-22 15:37:16 +11:00
Michael Neale
af6d37a8e2 Add --interactive-requirements flag for AI-enhanced requirements mode
- Adds new --interactive-requirements CLI flag for autonomous mode
- Prompts user for brief requirements input
- Uses AI to enhance and structure requirements into proper markdown
- Shows enhanced requirements and allows user to approve/edit/cancel
- Saves to requirements.md and proceeds with autonomous mode if approved
- Includes test script for manual verification
2025-10-22 14:58:35 +11:00
Dhanji R. Prasanna
c1c6680e03 Merge pull request #7 from jochenx/jochen-add-openai-and-multi-providers
coach/player provider split + add OpenAI
2025-10-22 13:46:16 +11:00
Jochen
f2d8e744bb fix panic in CLI parser 2025-10-22 13:20:45 +11:00
Jochen
010a43d203 coach/player provider split + add OpenAI
Allows coach and player LLM providers to be separately specified.
Also adds OpenAI provider
2025-10-21 16:59:13 +11:00
Dhanji Prasanna
758e255af8 dont run safaridriver --enable each time 2025-10-21 16:00:58 +11:00
Dhanji Prasanna
393826ae02 webdriver tools 2025-10-21 14:34:41 +11:00
Dhanji Prasanna
3afad3d61f progressive context thinning 2025-10-20 15:29:44 +11:00
Dhanji Prasanna
2488cc54d5 docs: update README and DESIGN to reflect current project state
- Add g3-computer-control crate to architecture documentation
- Document all 13 tools including computer control and TODO management
- Add context thinning feature documentation (50-80% thresholds)
- Update tool ecosystem section with complete tool list
- Remove broken link to non-existent COMPUTER_CONTROL.md
- Update workspace count from 5 to 6 crates
- Add platform-specific implementation details for computer control
- Document OCR support via Tesseract
- Clarify setup instructions for computer control features
2025-10-20 15:03:22 +11:00
Dhanji Prasanna
2ad0c9a3fd todo list formatting 2025-10-20 14:27:53 +11:00
Dhanji Prasanna
2008a81193 fix to pass feedback to player (broken by todo system) 2025-10-20 14:12:08 +11:00
Dhanji Prasanna
776f5034b8 TODO tools 2025-10-20 10:50:53 +11:00
Dhanji Prasanna
92bece957b colorizing tool calls 2025-10-18 16:09:30 +11:00
Dhanji Prasanna
767299ff4e minor 2025-10-18 16:03:58 +11:00
Dhanji Prasanna
9d35449be8 ~ expansion for read_file and str_replace 2025-10-18 16:01:15 +11:00
Dhanji Prasanna
da652bf287 computer control tools 2025-10-18 14:16:50 +11:00
Dhanji Prasanna
a566171203 small turn completing bug 2025-10-18 13:25:23 +11:00
Dhanji Prasanna
347c9e1e00 colorize timing based on duration 2025-10-17 13:54:21 +11:00
Dhanji Prasanna
aa7eda0331 fix wall clock timing 2025-10-17 10:36:21 +11:00
Dhanji Prasanna
e42c76f3b9 Tune coach pickiness down 2025-10-17 10:28:08 +11:00
Dhanji Prasanna
dd211fab1c panic fix 2025-10-17 09:50:01 +11:00
Dhanji R. Prasanna
bcece38473 Merge pull request #5 from dhanji/micn/agent-tweaks
load AGENTS.md if there
2025-10-16 15:06:14 +11:00
Michael Neale
3ff8413538 loading agents 2025-10-16 15:03:23 +11:00
Michael Neale
de2a761dbd Merge branch 'main' into micn/agent-tweaks 2025-10-16 14:49:16 +11:00
Dhanji Prasanna
e5a6ab66d7 turn histogram from autonomous mode 2025-10-16 14:35:47 +11:00
Dhanji Prasanna
444c0bc6c6 --quiet flag suppresses logs 2025-10-16 13:08:26 +11:00
Michael Neale
758a6b18c8 load agents if there 2025-10-16 12:00:50 +11:00
Dhanji Prasanna
41c1363fb5 guard case to ensure approval terminates run 2025-10-16 11:01:46 +11:00
Dhanji Prasanna
52ada78151 requirements flag 2025-10-16 10:08:04 +11:00
Dhanji Prasanna
662748ed23 better formatting cli 2025-10-15 22:04:39 +11:00
Dhanji Prasanna
beccc8fa15 reset filter suppression state between tool calls (still broken) 2025-10-15 21:15:24 +11:00
Dhanji Prasanna
c9037ede22 fixed feedback handoff in autonomous mode 2025-10-15 14:07:25 +11:00
Dhanji Prasanna
793fc544c0 some cleanup 2025-10-15 11:12:26 +11:00
Dhanji Prasanna
fb64b7fe32 fixed filtering and tool call timeouts 2025-10-15 10:18:20 +11:00
Dhanji Prasanna
befc55152d fixed tool call cli output 2025-10-15 09:55:59 +11:00
Dhanji Prasanna
bb90cc7826 some fixes 2025-10-14 12:44:02 +11:00
Dhanji Prasanna
5110da0c61 design doc 2025-10-14 12:33:36 +11:00
Dhanji Prasanna
bfd256db3b fix tool output 2025-10-14 12:21:22 +11:00
Dhanji R. Prasanna
cef4d12d36 small cleanup to shell 2025-10-13 21:52:47 +11:00
Dhanji R. Prasanna
45eb0a4b63 small compile error 2025-10-13 21:51:44 +11:00
Dhanji R. Prasanna
a914afedd8 panic fix in tui 2025-10-13 21:49:22 +11:00
Dhanji Prasanna
627fdcd9bf streaming tool call attempt 1 2025-10-13 20:25:12 +11:00
Dhanji Prasanna
b43b693b60 small tweak to tmp prompting 2025-10-13 13:38:34 +11:00
Dhanji Prasanna
062e6de63f fix for buffered messages at end, colorized context bars 2025-10-13 13:36:37 +11:00
Dhanji Prasanna
318355e864 Added --provider and --model flags 2025-10-12 17:05:58 +11:00
Dhanji Prasanna
037bff7021 UTF-8 decoding bug 2025-10-12 14:54:28 +11:00
Dhanji Prasanna
05c21b61df coach mode feedback fix 2025-10-11 16:13:39 +11:00
Dhanji Prasanna
f42e43a0d6 auto mode report 2025-10-11 15:11:07 +11:00
Dhanji Prasanna
658a335615 prompt change 2025-10-11 15:07:47 +11:00
Dhanji Prasanna
e89e1acf41 auto mode and message fix 2025-10-11 15:06:37 +11:00
Dhanji Prasanna
7dd4fbf9b6 restart turn on error 2025-10-11 13:47:13 +11:00
Dhanji Prasanna
5fb631d5c3 cosmetic fixes to tool call headers 2025-10-11 13:32:35 +11:00
Dhanji Prasanna
13236a1be5 ui writer fixes 2025-10-10 15:39:42 +11:00
Dhanji Prasanna
1bae19abd4 Revert "fix for tool args and missing msgs"
This reverts commit 1e9ff972d9.
2025-10-10 15:39:42 +11:00
Dhanji Prasanna
d16a694862 show messages fix 2025-10-10 15:36:57 +11:00
Dhanji Prasanna
4a819e8f27 context window counting bug 2025-10-10 14:40:10 +11:00
Dhanji Prasanna
1e9ff972d9 fix for tool args and missing msgs 2025-10-10 14:28:02 +11:00
Dhanji Prasanna
57b7bcb0de cosmetic tool call stuff 2025-10-10 14:18:35 +11:00
Dhanji Prasanna
426a9b88a9 readme tweaks 2025-10-10 14:08:37 +11:00
Dhanji Prasanna
2d959b3d63 cosmetic 2025-10-10 13:52:04 +11:00
Dhanji Prasanna
16216532d0 newline 2025-10-10 13:46:08 +11:00
Dhanji Prasanna
3ef7ec0d9f colorize 2025-10-10 13:38:38 +11:00
Dhanji Prasanna
0ad52a2eb2 tighten tool output in normal cli 2025-10-10 10:03:15 +11:00
Dhanji Prasanna
1e44971cf8 error recovery and tests 2025-10-10 09:35:03 +11:00
Dhanji Prasanna
ef01226ee1 auto readme 2025-10-09 14:56:25 +11:00
Dhanji Prasanna
260c949576 token counting fixes 2025-10-09 12:11:21 +11:00
Dhanji Prasanna
9d1eef82b9 final output fix for auto mode 2025-10-09 11:16:21 +11:00
Dhanji Prasanna
cd489fb235 partial readfile support 2025-10-09 11:08:02 +11:00
Dhanji Prasanna
0973b83d3a fix build warnings 2025-10-08 14:06:25 +11:00
Dhanji Prasanna
5e6ac4e5f5 tweak to colors 2025-10-08 13:43:29 +11:00
Dhanji Prasanna
e1b1ed560a dracula theme tweaks 2025-10-08 12:39:14 +11:00
Dhanji Prasanna
8e4d0a3975 dracula theme tweak 2025-10-08 11:19:00 +11:00
Dhanji Prasanna
b369a1f5c3 fixes for coach mode 2025-10-08 11:17:24 +11:00
Dhanji Prasanna
e11a287acc color schemes 2025-10-08 11:14:56 +11:00
Dhanji Prasanna
ed769bd58a some graphing updates 2025-10-07 15:13:45 +11:00
Dhanji Prasanna
e6cec5ef0f retry on errors 2025-10-07 11:20:19 +11:00
Dhanji Prasanna
5a83e1b7e0 input box fixes 2025-10-06 14:48:27 +11:00
Dhanji Prasanna
c9487db5e7 only show tool detail when running 2025-10-06 14:33:15 +11:00
Dhanji Prasanna
340ba78eb3 remove output box border 2025-10-06 14:23:17 +11:00
Dhanji Prasanna
4a25191c77 bug fix on end of agent turn 2025-10-06 13:25:19 +11:00
Dhanji Prasanna
bcba99ec6c auto refresh token 2025-10-04 17:32:48 +10:00
Dhanji Prasanna
1a57dd3b1d tool window scrolling 2025-10-04 16:34:59 +10:00
Dhanji Prasanna
1379af7159 tool headers working 2025-10-04 16:24:33 +10:00
Dhanji Prasanna
9b7c228134 scroll hack 2025-10-04 15:05:06 +10:00
Dhanji Prasanna
f562301aa2 tweaks to newline 2025-10-04 13:30:11 +10:00
Dhanji Prasanna
cdfca615e3 tail cursor 2025-10-03 14:23:56 +10:00
Dhanji Prasanna
54e2a66b7d fixed newline character messup 2025-10-03 14:04:17 +10:00
Dhanji Prasanna
dfa54f20ec tmp file directive 2025-10-03 13:09:02 +10:00
Dhanji Prasanna
213dfd28d4 colors 2025-10-03 11:05:01 +10:00
Dhanji Prasanna
b39fd02603 scrolling fixed 2025-10-03 11:01:39 +10:00
Dhanji Prasanna
56e13ced64 tool calling boxes 2025-10-02 15:50:04 +10:00
Dhanji Prasanna
4e457960ed processing blink 2025-10-02 15:38:27 +10:00
Dhanji Prasanna
1faf16b23a tweaks 2025-10-02 15:34:37 +10:00
Dhanji Prasanna
4de994a2a7 tweak 2025-10-02 15:23:15 +10:00
Dhanji Prasanna
dd89067ac1 minor 2025-10-02 15:18:56 +10:00
Dhanji Prasanna
c065532c41 softer colors 2025-10-02 15:15:21 +10:00
Dhanji Prasanna
7ce1bfc8e2 tweaks to ui 2025-10-02 15:03:23 +10:00
Dhanji Prasanna
cd7f8d3fc7 model only 2025-10-02 14:58:03 +10:00
Dhanji Prasanna
bf5efde06e show model and provider 2025-10-02 14:53:38 +10:00
Dhanji Prasanna
57b1b51e65 retro mode ui! 2025-10-02 14:47:19 +10:00
Dhanji Prasanna
a87f81042a remove edit_file 2025-10-02 13:58:57 +10:00
Dhanji Prasanna
8c7dd146f8 UI writer abstraction instead of printlns everywhere 2025-10-02 11:06:14 +10:00
Dhanji Prasanna
e324ddd99d hopefully a bit better tool call detection 2025-10-02 10:27:58 +10:00
Dhanji Prasanna
9638f40cfb some autonomous mode fixes 2025-10-02 09:45:18 +10:00
Dhanji Prasanna
98cf72c12a suppress printout of final_output 2025-10-01 15:26:55 +10:00
Dhanji Prasanna
046b54c49b move embedded provider to a better crate 2025-10-01 15:19:37 +10:00
Dhanji Prasanna
b9679e14dc force update of head 2025-10-01 14:12:23 +10:00
Dhanji Prasanna
a843ecc9d0 suppress json tool calls in raw text 2025-10-01 13:20:13 +10:00
Dhanji Prasanna
3349a33106 dracula theme 2025-10-01 11:24:30 +10:00
Dhanji Prasanna
1621d081ec tui lib for nicer cli 2025-10-01 11:19:34 +10:00
Dhanji Prasanna
5f642061de error handling in autonomous mode 2025-10-01 11:01:23 +10:00
Dhanji Prasanna
f0ddfdc3d2 move logs into subdir 2025-09-30 22:29:49 +10:00
Dhanji Prasanna
92318ff51c str_replace fixes 2025-09-30 22:24:54 +10:00
Dhanji Prasanna
03229effba increase max iterations to 400 2025-09-30 21:35:42 +10:00
Dhanji Prasanna
f99c61331c str_replace instead of edit_file much better 2025-09-30 21:15:28 +10:00
Dhanji Prasanna
b3c2c0ad30 edit file fixes 2025-09-30 20:41:14 +10:00
Dhanji Prasanna
3c4da6f974 update readme 2025-09-30 14:00:04 +10:00
Dhanji Prasanna
270cbae1e6 edit_file 2025-09-30 13:50:02 +10:00
Dhanji Prasanna
69fc3e90dc max_tokens fix 2025-09-29 11:05:57 +10:00
Dhanji Prasanna
ce273ba3fb multiline input with \ 2025-09-29 10:23:41 +10:00
Dhanji Prasanna
c4ee4a6cde basic project model 2025-09-29 09:23:27 +10:00
Dhanji Prasanna
315596e316 only emit final response once 2025-09-29 08:22:26 +10:00
Dhanji Prasanna
39ef13e317 fix a looping error in iterations 2025-09-29 06:54:55 +10:00
Dhanji Prasanna
4e64555008 max tokens fix for databricks 2025-09-29 06:45:53 +10:00
Dhanji Prasanna
f3cf9b688e tool call cosmetic cleanup 2025-09-27 20:40:06 +10:00
Dhanji Prasanna
e2354b0679 allow more iterations per turn 2025-09-27 20:34:39 +10:00
Dhanji Prasanna
c490228824 databricks support 2025-09-27 17:28:02 +10:00
Dhanji Prasanna
258eb4fd54 minor 2025-09-27 15:49:55 +10:00
Dhanji Prasanna
091b824b1e more debug logging 2025-09-27 15:43:33 +10:00
Dhanji Prasanna
2b561516b6 cleanup 2025-09-27 15:16:42 +10:00
Dhanji Prasanna
1046b30138 print report at end 2025-09-27 15:01:59 +10:00
Dhanji Prasanna
7fbfec50d8 working much simpler 2025-09-27 14:46:53 +10:00
Dhanji Prasanna
3c74cd410e remove subtasks 2025-09-27 14:39:08 +10:00
Dhanji Prasanna
811c642b17 suppress json text tool calls a bit jankily 2025-09-27 14:34:54 +10:00
Dhanji Prasanna
016ee80554 cap total subtasks 2025-09-27 14:28:13 +10:00
Dhanji Prasanna
622de9d540 go straight to coach turn if files exist 2025-09-27 14:19:54 +10:00
Dhanji Prasanna
e82821189b write/read file support 2025-09-27 13:43:09 +10:00
Dhanji Prasanna
7595ee083e logging optimization 2025-09-27 12:18:27 +10:00
Dhanji Prasanna
fb114cfcf5 imrpv 2025-09-27 06:29:33 +10:00
Dhanji Prasanna
e97614df76 better output 2025-09-26 22:37:30 +10:00
Dhanji Prasanna
58052fd0fe autonomous mode 2025-09-26 22:34:47 +10:00
Dhanji Prasanna
6ec596ae4d minor 2025-09-26 21:55:15 +10:00
Dhanji Prasanna
5ef4a74468 minor 2025-09-26 21:38:01 +10:00
Dhanji Prasanna
dd20e0bb01 some cleanup of converstation mgmt 2025-09-22 20:38:44 +10:00
83 changed files with 22667 additions and 1157 deletions

.cargo/config.toml Normal file (5 changed lines)

@@ -0,0 +1,5 @@
[target.aarch64-apple-darwin]
rustflags = ["-C", "link-args=-Wl,-rpath,@executable_path"]
[target.x86_64-apple-darwin]
rustflags = ["-C", "link-args=-Wl,-rpath,@executable_path"]

.gitignore vendored (8 changed lines)

@@ -2,10 +2,14 @@
# will have compiled files and executables
debug
target
.build
appy/
# These are backup files generated by rustfmt
**/*.rs.bk
**/.DS_Store
# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
@@ -19,3 +23,7 @@ target
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Session logs directory
logs/
*.json

Cargo.lock generated (2133 changed lines)

File diff suppressed because it is too large.

Cargo.toml

@@ -4,7 +4,8 @@ members = [
"crates/g3-core",
"crates/g3-providers",
"crates/g3-config",
"crates/g3-execution"
"crates/g3-execution",
"crates/g3-computer-control"
]
resolver = "2"

DESIGN.md (466 changed lines)

@@ -1,157 +1,316 @@
# G3 - AI Coding Agent - Design Document
## Overview
G3 is a **modular, composable AI coding agent** built in Rust that helps you complete tasks by writing and executing code. It provides a flexible architecture for interacting with various Large Language Model (LLM) providers while offering powerful code generation, file manipulation, and task automation capabilities.
The agent follows a **tool-first philosophy**: instead of just providing advice, G3 actively uses tools to read files, write code, execute commands, and complete tasks autonomously.
## Core Principles
1. **Tool-First Philosophy**: Solve problems by actively using tools rather than just providing advice
2. **Modular Architecture**: Clear separation of concerns across multiple Rust crates
3. **Provider Flexibility**: Support multiple LLM providers through a unified interface
4. **Modularity**: Clear separation of concerns
5. **Composability**: Components can be combined in different ways
6. **Performance**: Built in Rust for speed and reliability
7. **Context Intelligence**: Smart context window management with auto-summarization
8. **Error Resilience**: Robust error handling with automatic retry logic
## Project Structure
G3 is organized as a Rust workspace with the following crates:
```
g3/
├── src/main.rs # Main entry point (delegates to g3-cli)
├── crates/
│ ├── g3-cli/ # Command-line interface, TUI, and retro mode
│ ├── g3-core/ # Core agent engine, tools, and streaming logic
│ ├── g3-providers/ # LLM provider abstractions and implementations
│ ├── g3-config/ # Configuration management
│ ├── g3-execution/ # Code execution engine
│ └── g3-computer-control/ # Computer control and automation
├── logs/ # Session logs (auto-created)
├── README.md # Project documentation
└── DESIGN.md # This design document
```
## Architecture Overview
### High-Level Architecture
```
┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐
│     g3-cli      │    │     g3-core     │    │  g3-providers   │
│                 │    │                 │    │                 │
│ • CLI parsing   │◄──►│ • Agent engine  │◄──►│ • Anthropic     │
│ • Interactive   │    │ • Context mgmt  │    │ • Databricks    │
│ • Retro TUI     │    │ • Tool system   │    │ • Embedded      │
│ • Autonomous    │    │ • Streaming     │    │   (llama.cpp)   │
│   mode          │    │ • Task exec     │    │ • OAuth flow    │
│                 │    │ • TODO mgmt     │    │                 │
└─────────────────┘    └─────────────────┘    └─────────────────┘
         │                      │                      │
         └──────────────────────┼──────────────────────┘
                                │
           ┌─────────────────┐  │  ┌─────────────────┐
           │  g3-execution   │◄─┴─►│    g3-config    │
           │                 │     │                 │
           │ • Code exec     │     │ • TOML config   │
           │ • Shell cmds    │     │ • Env overrides │
           │ • Streaming     │     │ • Provider      │
           │ • Error hdlg    │     │   settings      │
           └─────────────────┘     │ • Computer      │
                    │              │   control cfg   │
                    │              └─────────────────┘
           ┌─────────────────┐              │
           │  g3-computer-   │◄─────────────┘
           │    control      │
           │ • Mouse/kbd     │
           │ • Screenshots   │
           │ • OCR/Tesseract │
           │ • Windows/UI    │
           └─────────────────┘
```
## Core Components
### 1. g3-core: Agent Engine
**Primary Responsibilities:**
- Main orchestration logic for handling conversations and task execution
- Context window management with intelligent token tracking
- Built-in tool system for file operations and command execution
- Streaming response parsing with real-time tool call detection
- Error handling with automatic retry logic

**Key Features:**
- **Context Window Intelligence**: Automatic monitoring with percentage-based tracking; 80% capacity triggers auto-summarization (see the sketch at the end of this section)
- **Tool System**: Built-in tools for file operations (read, write, edit), shell commands, and structured output
- **Streaming Parser**: Real-time parsing of LLM responses with tool call detection and execution
- **Session Management**: Automatic session logging with detailed conversation history and token usage
- **Error Recovery**: Sophisticated error classification and retry logic for recoverable errors
- **TODO Management**: In-memory TODO list with read/write tools for task tracking

**Available Tools:**
- `shell`: Execute shell commands with streaming output
- `read_file`: Read file contents with optional character range support
- `write_file`: Create or overwrite files with content
- `str_replace`: Apply unified diffs to files with precise editing
- `final_output`: Signal task completion with detailed summaries
- `todo_read`: Read the entire TODO list content
- `todo_write`: Write or overwrite the entire TODO list
- `mouse_click`: Click the mouse at specific coordinates
- `type_text`: Type text at the current cursor position
- `find_element`: Find UI elements by text, role, or attributes
- `take_screenshot`: Capture screenshots of screen, region, or window
- `extract_text`: Extract text from images or screen regions using OCR
- `find_text_on_screen`: Find text visually on screen and return coordinates
- `list_windows`: List all open windows with IDs and titles
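To make the context-management behavior concrete, here is a minimal sketch of percentage-based tracking. It is illustrative only: the struct and method names are assumptions, and the thresholds paraphrase this document rather than g3-core's actual types.

```rust
/// Illustrative sketch, not g3-core's real API: track usage as a
/// percentage of the model's context window and decide what to do next.
struct ContextWindow {
    max_tokens: usize,
    used_tokens: usize,
}

impl ContextWindow {
    fn percent_used(&self) -> f64 {
        self.used_tokens as f64 / self.max_tokens as f64 * 100.0
    }

    /// Thresholds from this document: progressive thinning from 50%,
    /// auto-summarization once 80% of capacity is reached.
    fn next_action(&self) -> Option<&'static str> {
        match self.percent_used() {
            p if p >= 80.0 => Some("auto-summarize"),
            p if p >= 50.0 => Some("thin large tool results"),
            _ => None,
        }
    }
}

fn main() {
    let ctx = ContextWindow { max_tokens: 200_000, used_tokens: 165_000 };
    println!("{:.1}% used -> {:?}", ctx.percent_used(), ctx.next_action());
}
```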
### 2. g3-providers: LLM Provider Abstraction
**Primary Responsibilities:**
- Unified interface for multiple LLM providers
- Provider-specific optimizations and feature support
- OAuth authentication flows
- Streaming and non-streaming completion support

**Supported Providers:**
- **Anthropic**: Claude models via API with native tool calling support
- **Databricks**: Foundation Model APIs with OAuth and token-based authentication (default provider)
- **Embedded**: Local models via llama.cpp with GPU acceleration (Metal/CUDA)
- **Provider Registry**: Dynamic provider management and hot-swapping

**Key Features:**
- **Native Tool Calling**: Full support for structured tool calls where available
- **Fallback Parsing**: JSON tool call parsing for providers without native support
- **OAuth Integration**: Built-in OAuth flow for secure provider authentication
- **Context-Aware**: Provider-specific context length and token limit handling
- **Streaming Support**: Real-time response streaming with tool call detection (a minimal provider-trait sketch follows)
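As a rough illustration of the unified interface idea, a provider trait could look like the sketch below. Names and signatures are assumptions for illustration; the real trait in g3-providers (see `crates/g3-providers/src/lib.rs`) also covers streaming and tool calls.

```rust
use async_trait::async_trait;

// Hypothetical types standing in for the real request/response shapes.
pub struct Message { pub role: String, pub content: String }
pub struct Completion { pub content: String, pub total_tokens: u32 }

/// Hedged sketch of a unified provider interface.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Registry key, e.g. "anthropic", "databricks", "embedded".
    fn name(&self) -> &str;
    /// Context window size so the core can do percentage-based tracking.
    fn context_length(&self) -> usize;
    /// One non-streaming completion round trip.
    async fn complete(&self, messages: Vec<Message>) -> anyhow::Result<Completion>;
}
```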
### 3. g3-cli: Command-Line Interface
**Primary Responsibilities:**
- Command-line argument parsing and validation
- Interactive terminal interface with history support
- Retro-style terminal UI (80s sci-fi inspired)
- Autonomous mode with coach-player feedback loops
- Session management and workspace handling

**Execution Modes:**
- **Single-shot**: Execute one task and exit
- **Interactive**: REPL-style conversation with the agent (default mode)
- **Autonomous**: Coach-player feedback loop for complex projects
- **Retro TUI**: Full-screen terminal interface with real-time updates

**Key Features:**
- **Multi-line Input**: Support for complex, multi-line prompts with backslash continuation
- **Context Progress**: Real-time display of token usage and context window status
- **Error Recovery**: Automatic retry logic for timeout and recoverable errors
- **History Management**: Persistent command history across sessions
- **Theme Support**: Customizable color themes for retro mode
- **Cancellation**: Ctrl+C support for graceful operation cancellation
### 4. g3-execution: Code Execution Engine
**Primary Responsibilities:**
- Safe execution of shell commands and scripts
- Streaming output capture and display
- Multi-language code execution support
- Error handling and result formatting
**Supported Execution:**
- **Bash/Shell**: Direct command execution with streaming output (primary use case)
- **Python**: Script execution via temporary files (legacy support)
- **JavaScript**: Node.js-based execution (legacy support)
**Key Features:**
- **Streaming Output**: Real-time command output display (see the sketch after this list)
- **Error Capture**: Comprehensive stderr and stdout handling
- **Exit Code Tracking**: Proper success/failure detection
- **Async Execution**: Non-blocking command execution
- **Output Formatting**: Clean, user-friendly result presentation
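The streaming behavior can be sketched with Tokio, which the project already depends on. This is an assumption-laden miniature, not g3-execution's actual API:

```rust
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Spawn a shell command with a piped stdout so we can stream it.
    let mut child = Command::new("sh")
        .arg("-c")
        .arg("echo one; sleep 1; echo two")
        .stdout(std::process::Stdio::piped())
        .spawn()?;

    let stdout = child.stdout.take().expect("stdout was piped");
    let mut lines = BufReader::new(stdout).lines();

    // Print each line the moment the child emits it (streaming output).
    while let Some(line) = lines.next_line().await? {
        println!("{line}");
    }

    // Exit-code tracking for proper success/failure detection.
    let status = child.wait().await?;
    println!("exit code: {:?}", status.code());
    Ok(())
}
```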
### 5. g3-config: Configuration Management
**Primary Responsibilities:**
- TOML-based configuration file management
- Environment variable overrides
- Provider-specific settings and credentials
- CLI argument integration
**Configuration Hierarchy** (precedence sketched after this section):
1. Default configuration (Databricks provider with OAuth)
2. Configuration files (`~/.config/g3/config.toml`, `./g3.toml`)
3. Environment variables (`G3_*`)
4. CLI arguments (highest priority)
**Key Features:**
- **Auto-generation**: Creates default configuration files if none exist
- **Provider Overrides**: Runtime provider and model selection
- **Validation**: Configuration validation with helpful error messages
- **Flexible Paths**: Support for shell expansion (`~`, environment variables)
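The precedence rule reduces to a chain of fallbacks. A hypothetical one-function sketch (not g3-config's API) makes the order explicit:

```rust
// CLI argument beats environment variable beats config file beats default.
fn resolve_model(cli: Option<&str>, env: Option<&str>, file: Option<&str>, default: &str) -> String {
    cli.or(env).or(file).unwrap_or(default).to_string()
}

fn main() {
    // No CLI flag given, but a G3_* env var is set: the env value wins over the file.
    let model = resolve_model(None, Some("qwen2.5:7b"), Some("llama3.2"), "databricks-claude-sonnet-4");
    assert_eq!(model, "qwen2.5:7b");
}
```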
### 6. g3-computer-control: Computer Control & Automation
**Primary Responsibilities:**
- Cross-platform computer control and automation
- Mouse and keyboard input simulation
- Window management and screenshot capture
- OCR text extraction from images and screen regions
**Platform Support:**
- **macOS**: Core Graphics, Cocoa, screencapture integration
- **Linux**: X11/Xtest for input, X11 for window management
- **Windows**: Win32 APIs for input and window control
**Key Features:**
- **OCR Integration**: Tesseract-based text extraction from images
- **Window Management**: List, identify, and capture specific application windows
- **UI Automation**: Find elements, simulate clicks, type text
- **Screenshot Capture**: Full screen, regions, or specific windows
- **Accessibility**: Requires OS-level permissions for automation
## Advanced Features
### Context Window Management
G3 implements sophisticated context window management:
- **Automatic Monitoring**: Tracks token usage with percentage-based thresholds
- **Smart Summarization**: Auto-triggers at 80% capacity to prevent context overflow
- **Context Thinning**: Progressive thinning at 50%, 60%, 70%, and 80% thresholds that replaces large tool results with file references (sketched below)
- **Conversation Preservation**: Maintains conversation continuity through intelligent summaries
- **Provider-Specific Limits**: Adapts to different model context windows (4k to 200k+ tokens)
- **Cumulative Tracking**: Monitors total token usage across entire sessions
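Concretely, thinning swaps bulky tool results for pointers. A toy sketch (the entry shape and wording are invented for illustration):

```rust
struct HistoryEntry { role: &'static str, content: String }

// Replace oversized tool results with a short reference, keeping the
// turn structure intact so the conversation still reads coherently.
fn thin(history: &mut [HistoryEntry], max_len: usize) {
    for entry in history.iter_mut() {
        if entry.role == "tool" && entry.content.len() > max_len {
            entry.content = format!(
                "[tool result elided ({} bytes); full output preserved in the session log]",
                entry.content.len()
            );
        }
    }
}

fn main() {
    let mut history = vec![HistoryEntry { role: "tool", content: "x".repeat(10_000) }];
    thin(&mut history, 2_000);
    println!("{}", history[0].content);
}
```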
### Error Handling & Resilience
Comprehensive error handling system:
- **Error Classification**: Distinguishes between recoverable and non-recoverable errors
- **Automatic Retry**: Exponential backoff with jitter for rate limits, timeouts, and server errors (see the sketch below)
- **Detailed Logging**: Comprehensive error context including stack traces and session data
- **Error Persistence**: Saves detailed error logs to `logs/errors/` for analysis
- **Graceful Degradation**: Continues operation when possible, fails gracefully when not
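A backoff-with-jitter schedule fits in a few lines; the base, cap, and jitter range below are illustrative, not G3's actual tuning:

```rust
use rand::Rng;
use std::time::Duration;

// Exponential backoff capped at 60s, plus up to 1s of random jitter so
// simultaneous retries don't stampede the provider.
fn retry_delay(attempt: u32) -> Duration {
    let base = Duration::from_secs(2u64.saturating_pow(attempt).min(60));
    let jitter = Duration::from_millis(rand::thread_rng().gen_range(0..1000));
    base + jitter
}

fn main() {
    for attempt in 0..5 {
        println!("attempt {attempt}: waiting {:?}", retry_delay(attempt));
    }
}
```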
### Session Management
Automatic session tracking and logging:
- **Session IDs**: Generated based on initial prompts for easy identification
- **Complete Logs**: Full conversation history, token usage, and timing data
- **JSON Format**: Structured logs for easy parsing and analysis
- **Automatic Cleanup**: Organized in `logs/` directory with timestamps
- **Status Tracking**: Records session completion status (completed, cancelled, error)
### Autonomous Mode
Advanced autonomous operation with coach-player feedback (a minimal control-flow sketch follows the list):
- **Requirements-Driven**: Reads `requirements.md` for project specifications
- **Dual-Agent System**: Separate player (implementation) and coach (review) agents
- **Iterative Improvement**: Multiple rounds of implementation and feedback
- **Progress Tracking**: Detailed reporting of turns, token usage, and final status
- **Workspace Management**: Automatic workspace setup and file organization
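Stripped of the LLM calls, the coach-player loop is a bounded feedback cycle. A deliberately toy sketch of the control flow, with stub functions standing in for the real agents:

```rust
enum Verdict { Approve, Feedback(String) }

// Stand-ins for real LLM calls: the player implements, the coach reviews.
fn player_turn(requirements: &str, feedback: Option<&str>) -> String {
    format!("work on '{requirements}' addressing {feedback:?}")
}

fn coach_review(work: &str) -> Verdict {
    if work.contains("None") {
        Verdict::Feedback("add tests for the edge cases".into())
    } else {
        Verdict::Approve
    }
}

fn main() {
    let requirements = "contents of requirements.md";
    let mut feedback: Option<String> = None;
    for turn in 1..=10 { // --max-turns style cap
        let work = player_turn(requirements, feedback.as_deref());
        match coach_review(&work) {
            Verdict::Approve => {
                println!("coach approved on turn {turn}");
                break;
            }
            Verdict::Feedback(f) => feedback = Some(f),
        }
    }
}
```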
## Provider Comparison
| Feature | Anthropic | Databricks (Default) | Embedded |
|---------|-----------|------------|----------|
| **Cost** | Pay per token | Pay per token | Free after download |
| **Privacy** | Data sent to API | Data sent to API | Completely local |
| **Performance** | Very fast | Very fast | Depends on hardware |
| **Model Quality** | Excellent | Excellent | Good (varies by model) |
| **Offline Support** | No | No | Yes |
| **Setup Complexity** | API key only | OAuth or token | Model download required |
| **Context Window** | 200k tokens | Varies by model | 4k-32k tokens |
| **Tool Calling** | Native support | Native support | JSON fallback |
| **Hardware Requirements** | None | None | 4-16GB RAM, optional GPU |
## Configuration Examples
### Cloud-First Setup (Anthropic)
```toml
[providers]
default_provider = "anthropic"

[providers.anthropic]
api_key = "sk-ant-..."
model = "claude-3-5-sonnet-20241022"
max_tokens = 8192
temperature = 0.1
```
### Enterprise Setup (Databricks - Default)
```toml
[providers]
default_provider = "databricks"

[providers.databricks]
host = "https://your-workspace.cloud.databricks.com"
model = "databricks-claude-sonnet-4"
max_tokens = 32000
temperature = 0.1
use_oauth = true
```
### Privacy-First Setup (Local Models)
```toml
[providers]
default_provider = "embedded"

[providers.embedded]
model_path = "~/.cache/g3/models/qwen2.5-7b-instruct-q3_k_m.gguf"
model_type = "qwen"
context_length = 32768
max_tokens = 2048
temperature = 0.1
gpu_layers = 32
threads = 8
```
### Hybrid Setup
@@ -159,14 +318,109 @@ gpu_layers = 32
```toml
[providers]
default_provider = "embedded"

# Local model for most tasks
[providers.embedded]
model_path = "~/.cache/g3/models/codellama-7b-instruct.Q4_K_M.gguf"
model_type = "codellama"
context_length = 16384
gpu_layers = 32

# Cloud fallback for complex tasks
[providers.anthropic]
api_key = "sk-ant-..."
model = "claude-3-5-sonnet-20241022"
```
## Usage Examples
### Single-Shot Mode
```bash
g3 "implement a fibonacci function in Rust"
```
### Interactive Mode
```bash
g3
g3> read the README and suggest improvements
g3> implement the suggestions you made
```
### Autonomous Mode
```bash
g3 --autonomous --max-turns 10
# Reads requirements.md and implements iteratively
```
### Retro TUI Mode
```bash
g3 --retro --theme dracula
# Full-screen terminal interface
```
## Implementation Details
### Planned Features
- **Plugin System**: Custom tool and provider plugins
- **Web Interface**: Browser-based UI for remote access
- **Model Quantization**: Optimized local model deployment
- **Multi-Model Ensemble**: Combine multiple models for better results
- **Advanced Sandboxing**: Enhanced security for code execution
- **Collaborative Mode**: Multi-user sessions and shared workspaces
### Technical Improvements
- **Performance Optimization**: Faster streaming and tool execution
- **Memory Management**: Better handling of large contexts and files
- **Caching System**: Intelligent caching of model responses and computations
- **Monitoring**: Built-in metrics and performance monitoring
- **Testing**: Comprehensive test suite and CI/CD integration
## Development Guidelines
### Code Organization
- **Modular Design**: Each crate has a single, well-defined responsibility
- **Trait-Based**: Use traits for abstraction and testability
- **Error Handling**: Comprehensive error types with context
- **Documentation**: Inline docs and examples for all public APIs
- **Testing**: Unit tests, integration tests, and property-based testing
### Performance Considerations
- **Async-First**: All I/O operations are asynchronous (Tokio runtime)
- **Streaming**: Real-time response processing where possible
- **Memory Efficiency**: Careful memory management for large contexts
- **Caching**: Strategic caching of expensive operations
- **Profiling**: Regular performance profiling and optimization
This design document reflects the current state of G3 as a mature, production-ready AI coding agent with a sophisticated architecture and a comprehensive feature set.
## Current Implementation Status
### Fully Implemented
- **Core Agent Engine**: Complete with streaming, tool execution, and context management
- **Provider System**: Anthropic, Databricks, and Embedded providers with OAuth support
- **Tool System**: 13 tools including file ops, shell, TODO management, and computer control
- **CLI Interface**: Interactive mode, single-shot mode, retro TUI
- **Autonomous Mode**: Coach-player feedback loop with requirements.md processing
- **Configuration**: TOML-based config with environment overrides
- **Error Handling**: Comprehensive retry logic and error classification
- **Session Logging**: Automatic session tracking and JSON logs
- **Context Management**: Context thinning (50-80%) and auto-summarization at 80% capacity
- **Computer Control**: Cross-platform automation with OCR support
- **TODO Management**: In-memory TODO list with read/write tools
### Architecture Highlights
- **Workspace**: 6 crates with clear separation of concerns
- **Dependencies**: Modern Rust ecosystem (Tokio, Clap, Serde, etc.)
- **Streaming**: Real-time response processing with tool call detection
- **Cross-Platform**: Works on macOS, Linux, and Windows
- **GPU Support**: Metal acceleration for local models on macOS, CUDA on Linux
- **OCR Support**: Tesseract integration for text extraction from images
### Key Files
- `src/main.rs`: main entry point delegating to g3-cli
- `crates/g3-core/src/lib.rs`: main agent implementation
- `crates/g3-cli/src/lib.rs`: CLI and interaction modes
- `crates/g3-providers/src/lib.rs`: provider trait and registry
- `crates/g3-config/src/lib.rs`: configuration management
- `crates/g3-execution/src/lib.rs`: code execution engine
- `crates/g3-computer-control/src/lib.rs`: computer control and automation
- `crates/g3-computer-control/src/platform/`: platform-specific implementations

OLLAMA_CONFIG.md Normal file (456 changed lines)

@@ -0,0 +1,456 @@
# Configuring Ollama Provider in G3
This guide shows you how to configure G3 to use Ollama as your LLM provider.
## Quick Start
### 1. Install Ollama
```bash
# Visit https://ollama.ai to download and install
# Or use curl:
curl https://ollama.ai/install.sh | sh
```
### 2. Pull a Model
```bash
ollama pull llama3.2
# or any other model you prefer
```
### 3. Create Configuration File
Copy the example configuration:
```bash
cp config.ollama.example.toml ~/.config/g3/config.toml
```
Or create it manually:
```toml
[providers]
default_provider = "ollama"
[providers.ollama]
model = "llama3.2"
```
### 4. Run G3
```bash
g3
# G3 will now use Ollama with llama3.2!
```
## Configuration Options
### Basic Configuration
```toml
[providers]
default_provider = "ollama"
[providers.ollama]
model = "llama3.2"
```
This is the minimal configuration needed. It uses all defaults:
- Base URL: `http://localhost:11434`
- Temperature: `0.7`
- Max tokens: Not limited (uses model default)
### Full Configuration
```toml
[providers]
default_provider = "ollama"
[providers.ollama]
model = "llama3.2"
base_url = "http://localhost:11434"
max_tokens = 2048
temperature = 0.7
```
### Custom Ollama Host
If you're running Ollama on a different machine or port:
```toml
[providers.ollama]
model = "llama3.2"
base_url = "http://192.168.1.100:11434"
```
### Different Models
You can use any Ollama model:
```toml
[providers.ollama]
model = "qwen2.5:7b" # Alibaba's Qwen model
```
```toml
[providers.ollama]
model = "mistral" # Mistral AI
```
```toml
[providers.ollama]
model = "llama3.1:70b" # Larger Llama model
```
## Multiple Provider Configuration
You can configure multiple providers and switch between them:
```toml
[providers]
default_provider = "ollama" # Default for most operations
# Ollama for local, fast responses
[providers.ollama]
model = "llama3.2:3b"
temperature = 0.7
# Databricks for more complex tasks
[providers.databricks]
host = "https://your-workspace.cloud.databricks.com"
model = "databricks-claude-sonnet-4"
max_tokens = 4096
temperature = 0.1
use_oauth = true
```
Then switch providers with:
```bash
g3 --provider databricks
```
## Autonomous Mode (Coach-Player)
Use different providers for code review (coach) and implementation (player):
```toml
[providers]
default_provider = "ollama"
coach = "databricks" # Use powerful cloud model for review
player = "ollama" # Use local model for implementation
[providers.ollama]
model = "qwen2.5:14b" # Larger local model for coding
[providers.databricks]
host = "https://your-workspace.cloud.databricks.com"
model = "databricks-claude-sonnet-4"
use_oauth = true
```
This gives you the best of both worlds:
- Fast local execution for coding tasks
- Powerful cloud review for quality assurance
## Recommended Models
### For Coding Tasks
| Model | Size | Speed | Quality | Notes |
|-------|------|-------|---------|-------|
| **qwen2.5:7b** | 7B | Fast | Excellent | Best balance for coding |
| **llama3.2:3b** | 3B | Very Fast | Good | Great for quick tasks |
| **llama3.1:8b** | 8B | Medium | Very Good | Solid all-rounder |
| **mistral** | 7B | Fast | Good | Good for general use |
### For Complex Tasks
| Model | Size | Speed | Quality | Notes |
|-------|------|-------|---------|-------|
| **qwen2.5:14b** | 14B | Medium | Excellent | Best local model for coding |
| **qwen2.5:32b** | 32B | Slow | Outstanding | If you have the resources |
| **llama3.1:70b** | 70B | Very Slow | Outstanding | Requires significant RAM/GPU |
## Temperature Settings
Temperature controls randomness in responses:
- **0.1-0.3**: Deterministic, good for code generation
- **0.5-0.7**: Balanced, good for most tasks
- **0.8-1.0**: Creative, good for brainstorming
```toml
[providers.ollama]
model = "qwen2.5:7b"
temperature = 0.2 # Focused code generation
```
## Max Tokens
Control response length:
```toml
[providers.ollama]
model = "llama3.2"
max_tokens = 1024 # Shorter responses
```
```toml
[providers.ollama]
model = "qwen2.5:7b"
max_tokens = 4096 # Longer, detailed responses
```
Leave it unset for model defaults (recommended).
## Performance Tuning
### GPU Acceleration
Ollama automatically uses the GPU when one is available. To check:
```bash
ollama ps
```
### Quantized Models
For faster responses with less RAM:
```toml
[providers.ollama]
model = "llama3.2:3b-q4_0" # 4-bit quantization
```
Quantization options:
- `q4_0`: 4-bit, fastest, lowest quality
- `q5_0`: 5-bit, balanced
- `q8_0`: 8-bit, slower, better quality
### Multiple Models
You can pull multiple models and switch easily:
```bash
ollama pull llama3.2:3b # Fast for chat
ollama pull qwen2.5:7b # Better for code
ollama pull mistral # General purpose
```
Then change your config:
```toml
[providers.ollama]
model = "qwen2.5:7b" # Just change this line
```
## Troubleshooting
### Ollama Not Running
```bash
# Check if Ollama is running
curl http://localhost:11434/api/version
# Start Ollama (macOS/Linux)
ollama serve
# Or just run a model (auto-starts)
ollama run llama3.2
```
### Model Not Found
```bash
# List available models
ollama list
# Pull the model
ollama pull llama3.2
```
### Slow Responses
1. Use a smaller model:
```toml
model = "llama3.2:1b" # Smallest, fastest
```
2. Use quantized version:
```toml
model = "llama3.2:3b-q4_0"
```
3. Reduce max_tokens:
```toml
max_tokens = 512
```
### Out of Memory
1. Switch to smaller model
2. Use quantized version
3. Close other applications
4. Check GPU memory: `ollama ps`
### Connection Refused
Check base_url is correct:
```toml
[providers.ollama]
model = "llama3.2"
base_url = "http://localhost:11434" # Default
```
For remote Ollama:
```toml
base_url = "http://your-server:11434"
```
## Complete Example Configs
### Minimal Local Setup
```toml
[providers]
default_provider = "ollama"
[providers.ollama]
model = "llama3.2"
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
```
### Optimized for Coding
```toml
[providers]
default_provider = "ollama"
[providers.ollama]
model = "qwen2.5:7b"
temperature = 0.2
max_tokens = 2048
[agent]
max_context_length = 16384
enable_streaming = true
timeout_seconds = 120
```
### Fast Responses
```toml
[providers]
default_provider = "ollama"
[providers.ollama]
model = "llama3.2:3b-q4_0"
temperature = 0.7
max_tokens = 1024
[agent]
max_context_length = 4096
enable_streaming = true
timeout_seconds = 30
```
### High Quality (Requires Good Hardware)
```toml
[providers]
default_provider = "ollama"
[providers.ollama]
model = "qwen2.5:32b"
temperature = 0.3
max_tokens = 4096
[agent]
max_context_length = 32768
enable_streaming = true
timeout_seconds = 300
```
### Hybrid (Local + Cloud)
```toml
[providers]
default_provider = "ollama"
coach = "databricks"
player = "ollama"
[providers.ollama]
model = "qwen2.5:14b"
temperature = 0.2
[providers.databricks]
host = "https://your-workspace.cloud.databricks.com"
model = "databricks-claude-sonnet-4"
use_oauth = true
[agent]
max_context_length = 16384
enable_streaming = true
timeout_seconds = 120
```
## Environment Variables
You can override config with environment variables:
```bash
# Override model
G3_PROVIDERS_OLLAMA_MODEL=qwen2.5:7b g3
# Override base URL
G3_PROVIDERS_OLLAMA_BASE_URL=http://192.168.1.100:11434 g3
# Override default provider
G3_PROVIDERS_DEFAULT_PROVIDER=ollama g3
```
## Best Practices
1. **Start Small**: Begin with llama3.2:3b, scale up if needed
2. **Use Quantization**: q4_0 or q5_0 for best speed/quality balance
3. **Match Task to Model**:
- Quick edits: 1B-3B models
- Code generation: 7B-14B models
- Complex refactoring: 14B-32B models
4. **Temperature for Code**: Use 0.1-0.3 for deterministic output
5. **Enable Streaming**: Always enable for better UX
6. **Local First**: Use Ollama by default, cloud for special cases
## Comparison with Other Providers
| Feature | Ollama | Databricks | OpenAI | Anthropic |
|---------|--------|------------|--------|-----------|
| Cost | Free | Paid | Paid | Paid |
| Privacy | Full | Medium | Low | Low |
| Speed (small models) | Fast | Fast | Medium | Medium |
| Speed (large models) | Slow | Fast | Fast | Fast |
| Setup Complexity | Low | Medium | Low | Low |
| Authentication | None | OAuth/Token | API Key | API Key |
| Offline Support | Yes | No | No | No |
| Tool Calling | Yes | Yes | Yes | Yes |
## Next Steps
1. Try different models: `ollama pull mistral`, `ollama pull qwen2.5`
2. Experiment with temperature settings
3. Set up hybrid config with cloud provider for complex tasks
4. Share your config in the community!
## Getting Help
- Ollama docs: https://ollama.ai/docs
- G3 issues: https://github.com/your-repo/issues
- Test your config: `g3 --help`

OLLAMA_EXAMPLE.md Normal file (315 changed lines)

@@ -0,0 +1,315 @@
# Ollama Provider for g3
A simple, local LLM provider implementation for g3 that connects to Ollama.
## Features
- **Simple Setup**: No API keys or authentication required
- **Local Execution**: Runs entirely on your machine
- **Tool Calling Support**: Native tool calling for compatible models
- **Streaming**: Full streaming support with real-time responses
- **Flexible Configuration**: Custom base URL, temperature, and max tokens
- **Model Discovery**: Automatic detection of available models
## Quick Start
### Prerequisites
1. Install and start Ollama: https://ollama.ai
2. Pull a model: `ollama pull llama3.2`
### Basic Usage
```rust
use g3_providers::{OllamaProvider, LLMProvider, CompletionRequest, Message, MessageRole};
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Create provider with default settings (localhost:11434)
let provider = OllamaProvider::new(
"llama3.2".to_string(),
None, // base_url: defaults to http://localhost:11434
None, // max_tokens: optional
None, // temperature: defaults to 0.7
)?;
// Create a simple request
let request = CompletionRequest {
messages: vec![
Message {
role: MessageRole::User,
content: "What is the capital of France?".to_string(),
},
],
max_tokens: Some(1000),
temperature: Some(0.7),
stream: false,
tools: None,
};
// Get completion
let response = provider.complete(request).await?;
println!("Response: {}", response.content);
println!("Tokens: {}", response.usage.total_tokens);
Ok(())
}
```
### Streaming Example
```rust
use futures_util::StreamExt;
let request = CompletionRequest {
messages: vec![
Message {
role: MessageRole::User,
content: "Write a short poem about coding".to_string(),
},
],
max_tokens: Some(500),
temperature: Some(0.8),
stream: true,
tools: None,
};
let mut stream = provider.stream(request).await?;
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
print!("{}", chunk.content);
if chunk.finished {
println!("\n\nDone!");
if let Some(usage) = chunk.usage {
println!("Total tokens: {}", usage.total_tokens);
}
}
}
Err(e) => eprintln!("Error: {}", e),
}
}
```
### Tool Calling Example
```rust
use serde_json::json;
// Tool lives in the same provider crate (not shown in the earlier import).
use g3_providers::Tool;
let tools = vec![Tool {
name: "get_weather".to_string(),
description: "Get current weather for a location".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "City name"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "Temperature unit"
}
},
"required": ["location"]
}),
}];
let request = CompletionRequest {
messages: vec![
Message {
role: MessageRole::User,
content: "What's the weather in Paris?".to_string(),
},
],
max_tokens: Some(500),
temperature: Some(0.5),
stream: false,
tools: Some(tools),
};
let response = provider.complete(request).await?;
println!("Response: {}", response.content);
```
### Custom Ollama Host
```rust
// Connect to remote Ollama instance
let provider = OllamaProvider::new(
"llama3.2".to_string(),
Some("http://192.168.1.100:11434".to_string()),
None,
None,
)?;
```
### Fetch Available Models
```rust
// Discover what models are available
let models = provider.fetch_available_models().await?;
println!("Available models:");
for model in models {
println!(" - {}", model);
}
```
## Supported Models
The provider works with any Ollama model, including:
- **llama3.2** (1B, 3B) - Meta's latest Llama models
- **llama3.1** (8B, 70B, 405B) - Previous generation
- **qwen2.5** (7B, 14B, 32B) - Alibaba's Qwen models
- **mistral** - Mistral AI models
- **mixtral** - Mixture of experts model
- **phi3** - Microsoft's Phi-3
- **gemma2** - Google's Gemma 2
## Configuration
### Constructor Parameters
```rust
OllamaProvider::new(
model: String, // Model name (e.g., "llama3.2")
base_url: Option<String>, // Ollama API URL (default: http://localhost:11434)
max_tokens: Option<u32>, // Maximum tokens to generate (optional)
temperature: Option<f32>, // Sampling temperature (default: 0.7)
)
```
### Request Options
```rust
CompletionRequest {
messages: Vec<Message>, // Conversation history
max_tokens: Option<u32>, // Override provider's max_tokens
temperature: Option<f32>, // Override provider's temperature
stream: bool, // Enable streaming responses
tools: Option<Vec<Tool>>, // Tools for function calling
}
```
## Comparison with Other Providers
| Feature | Ollama | OpenAI | Anthropic | Databricks |
|---------|--------|--------|-----------|------------|
| Local Execution | ✅ | ❌ | ❌ | ❌ |
| Authentication | None | API Key | API Key | OAuth/Token |
| Tool Calling | ✅ | ✅ | ✅ | ✅ |
| Streaming | ✅ | ✅ | ✅ | ✅ |
| Cost | Free | Paid | Paid | Paid |
| Privacy | High | Low | Low | Medium |
## Implementation Details
### API Endpoints
- **Chat Completion**: `POST /api/chat`
- **Model List**: `GET /api/tags`
### Response Format
Ollama uses a simple JSON-per-line streaming format:
```json
{"message":{"role":"assistant","content":"Hello"},"done":false}
{"message":{"role":"assistant","content":" there"},"done":false}
{"done":true,"prompt_eval_count":10,"eval_count":20}
```
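A hedged sketch of consuming that format with serde_json: the `Chunk` struct below mirrors only the fields shown in the sample, while the provider's real parser handles more.

```rust
use serde::Deserialize;

#[derive(Deserialize)]
struct Chunk {
    message: Option<Msg>, // absent on the final bookkeeping line
    done: bool,
    eval_count: Option<u64>,
}

#[derive(Deserialize)]
struct Msg { content: String }

fn main() {
    let body = r#"{"message":{"role":"assistant","content":"Hello"},"done":false}
{"message":{"role":"assistant","content":" there"},"done":false}
{"done":true,"prompt_eval_count":10,"eval_count":20}"#;

    // One JSON object per line: parse each line independently.
    for line in body.lines() {
        let chunk: Chunk = serde_json::from_str(line).expect("valid JSON line");
        if let Some(msg) = chunk.message {
            print!("{}", msg.content);
        }
        if chunk.done {
            println!("\n(generated tokens: {:?})", chunk.eval_count);
        }
    }
}
```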
### Tool Call Format
Tool calls are returned in the message structure:
```json
{
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": "get_weather",
"arguments": {"location": "Paris", "unit": "celsius"}
}
}
]
},
"done": true
}
```
## Troubleshooting
### Connection Errors
If you see connection errors, ensure Ollama is running:
```bash
# Check if Ollama is running
curl http://localhost:11434/api/version
# Start Ollama (if needed)
ollama serve
```
### Model Not Found
Pull the model first:
```bash
ollama pull llama3.2
ollama list # Check available models
```
### Performance Issues
- Use smaller models (1B, 3B) for faster responses
- Reduce `max_tokens` to limit generation length
- Enable GPU acceleration if available
- Consider quantized models (e.g., `llama3.2:3b-q4_0`)
## Testing
Run the included tests:
```bash
cargo test --package g3-providers ollama
```
All tests should pass:
```
running 4 tests
test ollama::tests::test_custom_base_url ... ok
test ollama::tests::test_message_conversion ... ok
test ollama::tests::test_provider_creation ... ok
test ollama::tests::test_tool_conversion ... ok
```
## Architecture
The provider follows the same architecture as other g3 providers:
1. **OllamaProvider**: Main struct implementing `LLMProvider` trait
2. **Request/Response Structures**: Internal types for Ollama API
3. **Streaming Parser**: Handles line-by-line JSON parsing
4. **Tool Call Handling**: Accumulates and converts tool calls
5. **Error Handling**: Robust error handling with retries
## Contributing
The provider is part of the g3-providers crate. To contribute:
1. Add features to `ollama.rs`
2. Update tests
3. Run `cargo test --package g3-providers`
4. Update this documentation
## License
Same as the g3 project.

README.md (248 changed lines)

@@ -1,3 +1,247 @@
# G3 - AI Coding Agent
G3 is a coding AI agent designed to help you complete tasks by writing code and executing commands. Built in Rust, it provides a flexible architecture for interacting with various Large Language Model (LLM) providers while offering powerful code generation and task automation capabilities.
## Architecture Overview
G3 follows a modular architecture organized as a Rust workspace with multiple crates, each responsible for specific functionality:
### Core Components
#### **g3-core**
The heart of the agent system, containing:
- **Agent Engine**: Main orchestration logic for handling conversations, tool execution, and task management
- **Context Window Management**: Intelligent tracking of token usage with context thinning (50-80%) and auto-summarization at 80% capacity
- **Tool System**: Built-in tools for file operations, shell commands, computer control, TODO management, and structured output
- **Streaming Response Parser**: Real-time parsing of LLM responses with tool call detection and execution
- **Task Execution**: Support for single and iterative task execution with automatic retry logic
#### **g3-providers**
Abstraction layer for LLM providers:
- **Provider Interface**: Common trait-based API for different LLM backends
- **Multiple Provider Support**:
- Anthropic (Claude models)
- Databricks (DBRX and other models)
- Local/embedded models via llama.cpp with Metal acceleration on macOS
- **OAuth Authentication**: Built-in OAuth flow support for secure provider authentication
- **Provider Registry**: Dynamic provider management and selection
#### **g3-config**
Configuration management system:
- Environment-based configuration
- Provider credentials and settings
- Model selection and parameters
- Runtime configuration options
#### **g3-execution**
Task execution framework:
- Task planning and decomposition
- Execution strategies (sequential, parallel)
- Error handling and retry mechanisms
- Progress tracking and reporting
#### **g3-computer-control**
Computer control capabilities:
- Mouse and keyboard automation
- UI element inspection and interaction
- Screenshot capture and window management
- OCR text extraction via Tesseract
#### **g3-cli**
Command-line interface:
- Interactive terminal interface
- Task submission and monitoring
- Configuration management commands
- Session management
### Error Handling & Resilience
G3 includes robust error handling with automatic retry logic:
- **Recoverable Error Detection**: Automatically identifies recoverable errors (rate limits, network issues, server errors, timeouts)
- **Exponential Backoff with Jitter**: Implements intelligent retry delays to avoid overwhelming services
- **Detailed Error Logging**: Captures comprehensive error context including stack traces, request/response data, and session information
- **Error Persistence**: Saves detailed error logs to `logs/errors/` for post-mortem analysis
- **Graceful Degradation**: Non-recoverable errors are logged with full context before terminating
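As a rough sketch of the retry delay calculation (illustrative; g3's actual constants and jitter strategy may differ):
```rust
use rand::Rng; // rand = "0.8"
use std::time::Duration;

/// Exponential backoff with full jitter: the delay doubles each attempt up to
/// a cap, then a uniform random value in [0, capped] de-synchronizes clients.
fn backoff_delay(attempt: u32, base_ms: u64, cap_ms: u64) -> Duration {
    let exp = base_ms.saturating_mul(2u64.saturating_pow(attempt));
    let capped = exp.min(cap_ms);
    let jittered = rand::thread_rng().gen_range(0..=capped);
    Duration::from_millis(jittered)
}
```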
## Key Features
### Intelligent Context Management
- Automatic context window monitoring with percentage-based tracking
- Smart auto-summarization when approaching token limits
- **Context thinning** at 50%, 60%, 70%, 80% thresholds - automatically replaces large tool results with file references
- Conversation history preservation through summaries
- Dynamic token allocation for different providers (4k to 200k+ tokens)
### Interactive Control Commands
G3's interactive CLI includes control commands for manual context management:
- **`/compact`**: Manually trigger summarization to compact conversation history
- **`/thinnify`**: Manually trigger context thinning to replace large tool results with file references
- **`/readme`**: Reload README.md and AGENTS.md from disk without restarting
- **`/stats`**: Show detailed context and performance statistics
- **`/help`**: Display all available control commands
These commands give you fine-grained control over context management, allowing you to proactively optimize token usage and refresh project documentation. See [Control Commands Documentation](docs/CONTROL_COMMANDS.md) for detailed usage.
### Tool Ecosystem
- **File Operations**: Read, write, and edit files with line-range precision
- **Shell Integration**: Execute system commands with output capture
- **Code Generation**: Structured code generation with syntax awareness
- **TODO Management**: Read and write TODO lists with markdown checkbox format
- **Computer Control** (Experimental): Automate desktop applications
- Mouse and keyboard control
- macOS Accessibility API for native app automation (via `--macax` flag)
- UI element inspection
- Screenshot capture and window management
- OCR text extraction from images and screen regions
- Window listing and identification
- **Final Output**: Formatted result presentation
### Provider Flexibility
- Support for multiple LLM providers through a unified interface
- Hot-swappable providers without code changes
- Provider-specific optimizations and feature support
- Local model support for offline operation
### Task Automation
- Single-shot task execution for quick operations
- Iterative task mode for complex, multi-step workflows
- Automatic error recovery and retry logic
- Progress tracking and intermediate result handling
## Language & Technology Stack
- **Language**: Rust (2021 edition)
- **Async Runtime**: Tokio for concurrent operations
- **HTTP Client**: Reqwest for API communications
- **Serialization**: Serde for JSON handling
- **CLI Framework**: Clap for command-line parsing
- **Logging**: Tracing for structured logging
- **Local Models**: llama.cpp with Metal acceleration support
## Use Cases
G3 is designed for:
- Automated code generation and refactoring
- File manipulation and project scaffolding
- System administration tasks
- Data processing and transformation
- API integration and testing
- Documentation generation
- Complex multi-step workflows
- Desktop application automation and testing
## Getting Started
### Default Mode: Accumulative Autonomous
The default interactive mode now uses **accumulative autonomous mode**, which combines the best of interactive and autonomous workflows:
```bash
# Simply run g3 in any directory
g3
# You'll be prompted to describe what you want to build
# Each input you provide:
# 1. Gets added to accumulated requirements
# 2. Automatically triggers autonomous mode (coach-player loop)
# 3. Implements your requirements iteratively
# Example session:
requirement> create a simple web server in Python with Flask
# ... autonomous mode runs and implements it ...
requirement> add a /health endpoint that returns JSON
# ... autonomous mode runs again with both requirements ...
```
### Other Modes
```bash
# Single-shot mode (one task, then exit)
g3 "implement a function to calculate fibonacci numbers"
# Traditional autonomous mode (reads requirements.md)
g3 --autonomous
# Traditional chat mode (simple interactive chat without autonomous runs)
g3 --chat
```
### Building and Installing
```bash
# Build the project
cargo build --release
# Run from the build directory
./target/release/g3
# Or copy the binary to somewhere in your PATH (macOS also needs the dylib below)
cp target/release/g3 ~/.local/bin/
cp target/release/libVisionBridge.dylib ~/.local/bin/ # macOS only
# Execute a task
g3 "implement a function to calculate fibonacci numbers"
```
## WebDriver Browser Automation
G3 includes WebDriver support for browser automation tasks using Safari.
**One-Time Setup** (macOS only):
Safari Remote Automation must be enabled before using WebDriver tools. Run this once:
```bash
# Option 1: Use the provided script
./scripts/enable-safari-automation.sh
# Option 2: Enable manually
safaridriver --enable # Requires password
# Option 3: Enable via Safari UI
# Safari → Preferences → Advanced → Show Develop menu
# Then: Develop → Allow Remote Automation
```
**For detailed setup instructions and troubleshooting**, see [WebDriver Setup Guide](docs/webdriver-setup.md).
**Usage**: Run G3 with the `--webdriver` flag to enable browser automation tools.
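For example (illustrative task):
```bash
g3 --webdriver "open https://example.com and report the page title"
```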
## macOS Accessibility API Tools
G3 includes support for controlling macOS applications via the Accessibility API, allowing you to automate native macOS apps.
**Available Tools**: `macax_list_apps`, `macax_get_frontmost_app`, `macax_activate_app`, `macax_get_ui_tree`, `macax_find_elements`, `macax_click`, `macax_set_value`, `macax_get_value`, `macax_press_key`
**Setup**: Enable with the `--macax` flag or in config with `macax.enabled = true`. Grant accessibility permissions:
- **macOS**: System Preferences → Security & Privacy → Privacy → Accessibility → Add your terminal app
**For detailed documentation**, see [macOS Accessibility Tools Guide](docs/macax-tools.md).
**Note**: This is particularly useful for testing and automating apps you're building with G3, as you can add accessibility identifiers to your UI elements.
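For example (illustrative task):
```bash
g3 --macax "list the buttons in the frontmost Finder window"
```
Or enable the tools persistently in `~/.config/g3/config.toml`:
```toml
[macax]
enabled = true
```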
## Computer Control (Experimental)
G3 can interact with your computer's GUI for automation tasks:
**Available Tools**: `mouse_click`, `type_text`, `find_element`, `take_screenshot`, `extract_text`, `find_text_on_screen`, `list_windows`
**Setup**: Enable in config with `computer_control.enabled = true` and grant OS accessibility permissions:
- **macOS**: System Preferences → Security & Privacy → Accessibility
- **Linux**: Ensure X11 or Wayland access
- **Windows**: Run as administrator (first time only)
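For reference, the corresponding section of `~/.config/g3/config.toml`:
```toml
[computer_control]
enabled = true              # requires OS accessibility permissions
require_confirmation = true
max_actions_per_second = 5
```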
## Session Logs
G3 automatically saves session logs for each interaction in the `logs/` directory. These logs contain:
- Complete conversation history
- Token usage statistics
- Timestamps and session status
The `logs/` directory is created automatically on first use and is excluded from version control.
## License
MIT License - see LICENSE file for details
## Contributing
G3 is an open-source project. Contributions are welcome! Please see CONTRIBUTING.md for guidelines.

TODO

@@ -0,0 +1,19 @@
next tasks
x get something working with autonomous mode
- g3d
- bug where it prints everything in a conversation turn all over again before final_output
x ui abstraction from core
- context token counting bug
- embedded model
- prompt rewriting
- generates status messages "ruffling feathers..."
- project description?
- treesitter + friends
x error where it just gives up turn
- "project" behaviors (read readme first)
- advance project mgmt
- git for reverting
- swarm
- ui tests / computer controller


@@ -0,0 +1,24 @@
[providers]
default_provider = "databricks"
# Specify different providers for coach and player in autonomous mode
coach = "databricks" # Provider for coach (code reviewer) - can be more powerful/expensive
player = "anthropic" # Provider for player (code implementer) - can be faster/cheaper
[providers.databricks]
host = "https://your-workspace.cloud.databricks.com"
# token = "your-databricks-token" # Optional - will use OAuth if not provided
model = "databricks-claude-sonnet-4"
max_tokens = 4096
temperature = 0.1
use_oauth = true
[providers.anthropic]
api_key = "your-anthropic-api-key"
model = "claude-3-haiku-20240307" # Using a faster model for player
max_tokens = 4096
temperature = 0.3 # Slightly higher temperature for more creative implementations
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60


@@ -1,38 +1,25 @@
# Example configuration file for G3
# Copy to ~/.config/g3/config.toml and customize
[providers]
default_provider = "embedded"
default_provider = "databricks"
# Optional: Specify different providers for coach and player in autonomous mode
# If not specified, will use default_provider for both
# coach = "databricks" # Provider for coach (code reviewer)
# player = "anthropic" # Provider for player (code implementer)
# Note: Make sure the specified providers are configured below
[providers.openai]
# Get your API key from https://platform.openai.com/api-keys
api_key = "sk-your-openai-api-key-here"
model = "gpt-4"
# Optional: custom base URL for OpenAI-compatible APIs
# base_url = "https://api.openai.com/v1"
max_tokens = 2048
temperature = 0.1
[providers.anthropic]
# Get your API key from https://console.anthropic.com/
api_key = "your-anthropic-api-key-here"
model = "claude-3-5-sonnet-20241022"
[providers.databricks]
host = "https://your-workspace.cloud.databricks.com"
# token = "your-databricks-token" # Optional - will use OAuth if not provided
model = "databricks-claude-sonnet-4"
max_tokens = 4096
temperature = 0.1
[providers.embedded]
# Path to your GGUF model file
model_path = "~/.cache/g3/models/codellama-7b-instruct.Q4_K_M.gguf"
model_type = "codellama"
context_length = 16384 # Use CodeLlama's full context capability
max_tokens = 2048 # Default fallback, but will be calculated dynamically
temperature = 0.1
# Number of layers to offload to GPU (0 for CPU only)
gpu_layers = 32
# Number of CPU threads to use
threads = 8
use_oauth = true
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
[computer_control]
enabled = false # Set to true to enable computer control (requires OS permissions)
require_confirmation = true
max_actions_per_second = 5


@@ -0,0 +1,26 @@
# Example G3 configuration using Ollama provider
# Copy this to ~/.config/g3/config.toml or ./g3.toml to use it
[providers]
default_provider = "ollama"
# Ollama configuration (local LLM)
[providers.ollama]
model = "llama3.2" # or qwen2.5, mistral, etc.
# base_url = "http://localhost:11434" # Optional, defaults to localhost
# max_tokens = 2048 # Optional
# temperature = 0.7 # Optional
# Optional: Specify different providers for coach and player in autonomous mode
# coach = "ollama" # Provider for coach (code reviewer)
# player = "ollama" # Provider for player (code implementer)
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
[computer_control]
enabled = false # Set to true to enable computer control (requires OS permissions)
require_confirmation = true
max_actions_per_second = 5


@@ -12,9 +12,13 @@ tokio = { workspace = true }
anyhow = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter"] }
serde = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
rustyline = "17.0.1"
dirs = "5.0"
tokio-util = "0.7"
indicatif = "0.17"
chrono = { version = "0.4", features = ["serde"] }
crossterm = "0.29.0"
ratatui = "0.29"
termimad = "0.34.0"

File diff suppressed because it is too large


@@ -0,0 +1,94 @@
use g3_core::ui_writer::UiWriter;
use std::io::{self, Write};
/// Machine-mode implementation of UiWriter that prints plain, unformatted output
/// This is designed for programmatic consumption and outputs everything verbatim
pub struct MachineUiWriter;
impl MachineUiWriter {
pub fn new() -> Self {
Self
}
}
impl UiWriter for MachineUiWriter {
fn print(&self, message: &str) {
print!("{}", message);
}
fn println(&self, message: &str) {
println!("{}", message);
}
fn print_inline(&self, message: &str) {
print!("{}", message);
let _ = io::stdout().flush();
}
fn print_system_prompt(&self, prompt: &str) {
println!("SYSTEM_PROMPT:");
println!("{}", prompt);
println!("END_SYSTEM_PROMPT");
println!();
}
fn print_context_status(&self, message: &str) {
println!("CONTEXT_STATUS: {}", message);
}
fn print_context_thinning(&self, message: &str) {
println!("CONTEXT_THINNING: {}", message);
}
fn print_tool_header(&self, tool_name: &str) {
println!("TOOL_CALL: {}", tool_name);
}
fn print_tool_arg(&self, key: &str, value: &str) {
println!("TOOL_ARG: {} = {}", key, value);
}
fn print_tool_output_header(&self) {
println!("TOOL_OUTPUT:");
}
fn update_tool_output_line(&self, line: &str) {
println!("{}", line);
}
fn print_tool_output_line(&self, line: &str) {
println!("{}", line);
}
fn print_tool_output_summary(&self, count: usize) {
println!("TOOL_OUTPUT_LINES: {}", count);
}
fn print_tool_timing(&self, duration_str: &str) {
println!("TOOL_DURATION: {}", duration_str);
println!("END_TOOL_OUTPUT");
println!();
}
fn print_agent_prompt(&self) {
println!("AGENT_RESPONSE:");
let _ = io::stdout().flush();
}
fn print_agent_response(&self, content: &str) {
print!("{}", content);
let _ = io::stdout().flush();
}
fn notify_sse_received(&self) {
// No-op for machine mode
}
fn flush(&self) {
let _ = io::stdout().flush();
}
fn wants_full_output(&self) -> bool {
true // Machine mode wants complete, untruncated output
}
}

File diff suppressed because it is too large


@@ -0,0 +1,32 @@
/// Simple output helper for printing messages
pub struct SimpleOutput {
machine_mode: bool,
}
impl SimpleOutput {
pub fn new() -> Self {
SimpleOutput { machine_mode: false }
}
pub fn new_with_mode(machine_mode: bool) -> Self {
SimpleOutput { machine_mode }
}
pub fn print(&self, message: &str) {
if !self.machine_mode {
println!("{}", message);
}
}
pub fn print_smart(&self, message: &str) {
if !self.machine_mode {
println!("{}", message);
}
}
}
impl Default for SimpleOutput {
fn default() -> Self {
Self::new()
}
}

crates/g3-cli/src/theme.rs

@@ -0,0 +1,147 @@
use ratatui::style::Color;
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;
use anyhow::Result;
/// Color theme configuration for the retro TUI
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColorTheme {
/// Name of the theme
pub name: String,
/// Main terminal text color (for general output)
pub terminal_green: ColorValue,
/// Warning/system messages color
pub terminal_amber: ColorValue,
/// Border and dim text color
pub terminal_dim_green: ColorValue,
/// Background color
pub terminal_bg: ColorValue,
/// Highlight/emphasis color
pub terminal_cyan: ColorValue,
/// Error/negative diff color
pub terminal_red: ColorValue,
/// READY status color
pub terminal_pale_blue: ColorValue,
/// PROCESSING status color
pub terminal_dark_amber: ColorValue,
/// Bright/punchy text color
pub terminal_white: ColorValue,
/// Success status color (for tool completions)
pub terminal_success: ColorValue,
}
/// Represents a color value that can be serialized/deserialized
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ColorValue {
/// RGB color with r, g, b components
Rgb { r: u8, g: u8, b: u8 },
/// Named color
Named(String),
}
impl ColorValue {
/// Convert to ratatui Color
pub fn to_color(&self) -> Color {
match self {
ColorValue::Rgb { r, g, b } => Color::Rgb(*r, *g, *b),
ColorValue::Named(name) => match name.to_lowercase().as_str() {
"black" => Color::Black,
"red" => Color::Red,
"green" => Color::Green,
"yellow" => Color::Yellow,
"blue" => Color::Blue,
"magenta" => Color::Magenta,
"cyan" => Color::Cyan,
"gray" | "grey" => Color::Gray,
"darkgray" | "darkgrey" => Color::DarkGray,
"lightred" => Color::LightRed,
"lightgreen" => Color::LightGreen,
"lightyellow" => Color::LightYellow,
"lightblue" => Color::LightBlue,
"lightmagenta" => Color::LightMagenta,
"lightcyan" => Color::LightCyan,
"white" => Color::White,
_ => Color::White, // Default fallback
},
}
}
}
impl ColorTheme {
/// Load a theme from a JSON file
pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
let content = fs::read_to_string(path)?;
let theme: ColorTheme = serde_json::from_str(&content)?;
Ok(theme)
}
/// Get the default retro sci-fi theme (inspired by Alien terminals)
pub fn default() -> Self {
ColorTheme {
name: "Retro Sci-Fi".to_string(),
terminal_green: ColorValue::Rgb { r: 136, g: 244, b: 152 },
terminal_amber: ColorValue::Rgb { r: 242, g: 204, b: 148 },
terminal_dim_green: ColorValue::Rgb { r: 154, g: 174, b: 135 },
terminal_bg: ColorValue::Rgb { r: 0, g: 10, b: 0 },
terminal_cyan: ColorValue::Rgb { r: 0, g: 255, b: 255 },
terminal_red: ColorValue::Rgb { r: 239, g: 119, b: 109 },
terminal_pale_blue: ColorValue::Rgb { r: 173, g: 234, b: 251 },
terminal_dark_amber: ColorValue::Rgb { r: 204, g: 119, b: 34 },
terminal_white: ColorValue::Rgb { r: 218, g: 218, b: 219 },
terminal_success: ColorValue::Rgb { r: 136, g: 244, b: 152 }, // Same as terminal_green for retro theme
}
}
/// Get the Dracula theme
pub fn dracula() -> Self {
ColorTheme {
name: "Dracula".to_string(),
terminal_green: ColorValue::Rgb { r: 248, g: 248, b: 242 }, // Use Dracula foreground (white) for main text
terminal_amber: ColorValue::Rgb { r: 255, g: 184, b: 108 }, // Dracula orange
terminal_dim_green: ColorValue::Rgb { r: 98, g: 114, b: 164 }, // Dracula comment
terminal_bg: ColorValue::Rgb { r: 40, g: 42, b: 54 }, // Dracula background
terminal_cyan: ColorValue::Rgb { r: 139, g: 233, b: 253 }, // Dracula cyan
terminal_red: ColorValue::Rgb { r: 255, g: 85, b: 85 }, // Dracula red
terminal_pale_blue: ColorValue::Rgb { r: 189, g: 147, b: 249 }, // Dracula purple
terminal_dark_amber: ColorValue::Rgb { r: 255, g: 121, b: 198 }, // Dracula pink
terminal_white: ColorValue::Rgb { r: 248, g: 248, b: 242 }, // Dracula foreground
terminal_success: ColorValue::Rgb { r: 80, g: 250, b: 123 }, // Dracula green for success
}
}
/// Get a theme by name or from file
pub fn load(theme_name: Option<&str>) -> Result<Self> {
match theme_name {
None => Ok(Self::default()),
Some("default") | Some("retro") => Ok(Self::default()),
Some("dracula") => Ok(Self::dracula()),
Some(path) => {
// Try to load from file
if Path::new(path).exists() {
Self::from_file(path)
} else {
// Try to find in standard locations
let home = dirs::home_dir().ok_or_else(|| anyhow::anyhow!("Could not find home directory"))?;
let theme_file = home.join(".config").join("g3").join("themes").join(format!("{}.json", path));
if theme_file.exists() {
Self::from_file(theme_file)
} else {
Err(anyhow::anyhow!("Theme '{}' not found", path))
}
}
}
}
}
}
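For reference, a custom theme file at `~/.config/g3/themes/mytheme.json` could look like this (field names come from `ColorTheme` above; the color values are illustrative, and named colors are accepted via the untagged `ColorValue`). It would be loaded with `ColorTheme::load(Some("mytheme"))`:
```json
{
  "name": "My Theme",
  "terminal_green": { "r": 136, "g": 244, "b": 152 },
  "terminal_amber": "yellow",
  "terminal_dim_green": { "r": 154, "g": 174, "b": 135 },
  "terminal_bg": "black",
  "terminal_cyan": "cyan",
  "terminal_red": { "r": 255, "g": 85, "b": 85 },
  "terminal_pale_blue": { "r": 173, "g": 234, "b": 251 },
  "terminal_dark_amber": { "r": 204, "g": 119, "b": 34 },
  "terminal_white": "white",
  "terminal_success": { "r": 80, "g": 250, "b": 123 }
}
```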

crates/g3-cli/src/tui.rs

@@ -0,0 +1,160 @@
use crossterm::style::Color;
use crossterm::style::{SetForegroundColor, ResetColor};
use std::io::{self, Write};
use termimad::MadSkin;
/// Simple output handler with markdown support
pub struct SimpleOutput {
mad_skin: MadSkin,
}
impl SimpleOutput {
pub fn new() -> Self {
let mut mad_skin = MadSkin::default();
// Dracula color scheme
// Background: #282a36, Foreground: #f8f8f2
// Colors: Cyan #8be9fd, Green #50fa7b, Orange #ffb86c, Pink #ff79c6, Purple #bd93f9, Red #ff5555, Yellow #f1fa8c
mad_skin.set_headers_fg(Color::Rgb { r: 189, g: 147, b: 249 }); // Purple for headers
mad_skin.bold.set_fg(Color::Rgb { r: 255, g: 121, b: 198 }); // Pink for bold
mad_skin.italic.set_fg(Color::Rgb { r: 139, g: 233, b: 253 }); // Cyan for italic
mad_skin.code_block.set_bg(Color::Rgb { r: 68, g: 71, b: 90 }); // Dracula background variant
mad_skin.code_block.set_fg(Color::Rgb { r: 80, g: 250, b: 123 }); // Green for code text
mad_skin.inline_code.set_bg(Color::Rgb { r: 68, g: 71, b: 90 }); // Same background for inline code
mad_skin.inline_code.set_fg(Color::Rgb { r: 241, g: 250, b: 140 }); // Yellow for inline code
mad_skin.quote_mark.set_fg(Color::Rgb { r: 98, g: 114, b: 164 }); // Comment purple for quote marks
mad_skin.strikeout.set_fg(Color::Rgb { r: 255, g: 85, b: 85 }); // Red for strikethrough
Self { mad_skin }
}
/// Detect if text contains markdown formatting
fn has_markdown(&self, text: &str) -> bool {
// Check for common markdown patterns
text.contains("**") ||
text.contains("```") ||
text.contains("`") ||
text.lines().any(|line| {
let trimmed = line.trim();
trimmed.starts_with('#') ||
trimmed.starts_with("- ") ||
trimmed.starts_with("* ") ||
trimmed.starts_with("+ ") ||
(trimmed.len() > 2 &&
trimmed.chars().next().is_some_and(|c| c.is_ascii_digit()) &&
trimmed.chars().nth(1) == Some('.') &&
trimmed.chars().nth(2) == Some(' ')) ||
(trimmed.contains('[') && trimmed.contains("]("))
}) ||
(text.matches('*').count() >= 2 && !text.contains("/*") && !text.contains("*/"))
}
pub fn print(&self, text: &str) {
println!("{}", text);
}
/// Smart print that automatically detects and renders markdown
pub fn print_smart(&self, text: &str) {
if self.has_markdown(text) {
self.print_markdown(text);
} else {
self.print(text);
}
}
pub fn print_markdown(&self, markdown: &str) {
self.mad_skin.print_text(markdown);
}
pub fn _print_status(&self, status: &str) {
println!("📊 {}", status);
}
pub fn print_context(&self, used: u32, total: u32, percentage: f32) {
let total_dots = 10;
let filled_dots = ((percentage / 100.0) * total_dots as f32) as usize;
let empty_dots = total_dots.saturating_sub(filled_dots);
let filled_str = "●".repeat(filled_dots);
let empty_str = "○".repeat(empty_dots);
// Determine color based on percentage
let color = if percentage < 40.0 {
crossterm::style::Color::Green
} else if percentage < 60.0 {
crossterm::style::Color::Yellow
} else if percentage < 80.0 {
crossterm::style::Color::Rgb { r: 255, g: 165, b: 0 } // Orange
} else {
crossterm::style::Color::Red
};
// Print with colored progress bar
print!("Context: ");
print!("{}", SetForegroundColor(color));
print!("{}{}", filled_str, empty_str);
print!("{}", ResetColor);
println!(" {:.0}% ({}/{} tokens)", percentage, used, total);
}
pub fn print_context_thinning(&self, message: &str) {
// Animated highlight for context thinning
// Use bright cyan/green with a quick flash animation
// Flash animation: print with bright background, then normal
let frames = vec![
"\x1b[1;97;46m", // Frame 1: Bold white on cyan background
"\x1b[1;97;42m", // Frame 2: Bold white on green background
"\x1b[1;96;40m", // Frame 3: Bold cyan on black background
];
println!();
// Quick flash animation
for frame in &frames {
print!("\r{}{}\x1b[0m", frame, message);
let _ = io::stdout().flush();
std::thread::sleep(std::time::Duration::from_millis(80));
}
// Final display with bright cyan and sparkle emojis
print!("\r\x1b[1;96m✨ {}\x1b[0m", message);
println!();
// Add a subtle "success" indicator line
println!("\x1b[2;36m └─ Context optimized successfully\x1b[0m");
println!();
let _ = io::stdout().flush();
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_markdown_detection() {
let output = SimpleOutput::new();
// Should detect markdown
assert!(output.has_markdown("**bold text**"));
assert!(output.has_markdown("`code`"));
assert!(output.has_markdown("```\ncode block\n```"));
assert!(output.has_markdown("# Header"));
assert!(output.has_markdown("- list item"));
assert!(output.has_markdown("* list item"));
assert!(output.has_markdown("+ list item"));
assert!(output.has_markdown("1. numbered item"));
assert!(output.has_markdown("[link](url)"));
assert!(output.has_markdown("*italic* text"));
// Should NOT detect markdown
assert!(!output.has_markdown("plain text"));
assert!(!output.has_markdown("file.txt"));
assert!(!output.has_markdown("/* comment */"));
assert!(!output.has_markdown("just one * asterisk"));
assert!(!output.has_markdown("📁 Workspace: /path/to/dir"));
assert!(!output.has_markdown("✅ Success message"));
}
}


@@ -0,0 +1,347 @@
use g3_core::ui_writer::UiWriter;
use std::io::{self, Write};
use std::sync::Mutex;
/// Console implementation of UiWriter that prints to stdout
pub struct ConsoleUiWriter {
current_tool_name: Mutex<Option<String>>,
current_tool_args: Mutex<Vec<(String, String)>>,
current_output_line: Mutex<Option<String>>,
output_line_printed: Mutex<bool>,
in_todo_tool: Mutex<bool>,
}
impl ConsoleUiWriter {
pub fn new() -> Self {
Self {
current_tool_name: Mutex::new(None),
current_tool_args: Mutex::new(Vec::new()),
current_output_line: Mutex::new(None),
output_line_printed: Mutex::new(false),
in_todo_tool: Mutex::new(false),
}
}
fn print_todo_line(&self, line: &str) {
// Transform and print todo list lines elegantly
let trimmed = line.trim();
// Skip the "📝 TODO list:" prefix line
if trimmed.starts_with("📝 TODO list:") || trimmed == "📝 TODO list is empty" {
return;
}
// Handle empty lines
if trimmed.is_empty() {
println!();
return;
}
// Detect indentation level
let indent_count = line.chars().take_while(|c| c.is_whitespace()).count();
let indent = " ".repeat(indent_count / 2); // Convert spaces to visual indent
// Format based on line type
if trimmed.starts_with("- [ ]") {
// Incomplete task
let task = trimmed.strip_prefix("- [ ]").unwrap_or(trimmed).trim();
println!("{}{}", indent, task);
} else if trimmed.starts_with("- [x]") || trimmed.starts_with("- [X]") {
// Completed task
let task = trimmed.strip_prefix("- [x]")
.or_else(|| trimmed.strip_prefix("- [X]"))
.unwrap_or(trimmed)
.trim();
println!("{}\x1b[2m☑ {}\x1b[0m", indent, task);
} else if trimmed.starts_with("- ") {
// Regular bullet point
let item = trimmed.strip_prefix("- ").unwrap_or(trimmed).trim();
println!("{}{}", indent, item);
} else if trimmed.starts_with("# ") {
// Heading
let heading = trimmed.strip_prefix("# ").unwrap_or(trimmed).trim();
println!("\n\x1b[1m{}\x1b[0m", heading);
} else if trimmed.starts_with("## ") {
// Subheading
let subheading = trimmed.strip_prefix("## ").unwrap_or(trimmed).trim();
println!("\n\x1b[1m{}\x1b[0m", subheading);
} else if trimmed.starts_with("**") && trimmed.ends_with("**") {
// Bold text (section marker)
let text = trimmed.trim_start_matches("**").trim_end_matches("**");
println!("{}\x1b[1m{}\x1b[0m", indent, text);
} else {
// Regular text or note
println!("{}{}", indent, trimmed);
}
}
}
impl UiWriter for ConsoleUiWriter {
fn print(&self, message: &str) {
print!("{}", message);
}
fn println(&self, message: &str) {
println!("{}", message);
}
fn print_inline(&self, message: &str) {
print!("{}", message);
let _ = io::stdout().flush();
}
fn print_system_prompt(&self, prompt: &str) {
println!("🔍 System Prompt:");
println!("================");
println!("{}", prompt);
println!("================");
println!();
}
fn print_context_status(&self, message: &str) {
println!("{}", message);
}
fn print_context_thinning(&self, message: &str) {
// Animated highlight for context thinning
// Use bright cyan/green with a quick flash animation
// Flash animation: print with bright background, then normal
let frames = vec![
"\x1b[1;97;46m", // Frame 1: Bold white on cyan background
"\x1b[1;97;42m", // Frame 2: Bold white on green background
"\x1b[1;96;40m", // Frame 3: Bold cyan on black background
];
println!();
// Quick flash animation
for frame in &frames {
print!("\r{}{}\x1b[0m", frame, message);
let _ = io::stdout().flush();
std::thread::sleep(std::time::Duration::from_millis(80));
}
// Final display with bright cyan and sparkle emojis
print!("\r\x1b[1;96m✨ {}\x1b[0m", message);
println!();
// Add a subtle "success" indicator line
println!("\x1b[2;36m └─ Context optimized successfully\x1b[0m");
println!();
let _ = io::stdout().flush();
}
fn print_tool_header(&self, tool_name: &str) {
// Store the tool name and clear args for collection
*self.current_tool_name.lock().unwrap() = Some(tool_name.to_string());
self.current_tool_args.lock().unwrap().clear();
// Check if this is a todo tool call
let is_todo = tool_name == "todo_read" || tool_name == "todo_write";
*self.in_todo_tool.lock().unwrap() = is_todo;
// For todo tools, we skip the normal header; custom output is printed later
}
fn print_tool_arg(&self, key: &str, value: &str) {
// Collect arguments instead of printing immediately
// Filter out any keys that look like they might be agent message content
// (e.g., keys that are suspiciously long or contain message-like content)
let is_valid_arg_key = key.len() < 50
&& !key.contains('\n')
&& !key.contains("I'll")
&& !key.contains("Let me")
&& !key.contains("Here's")
&& !key.contains("I can");
if is_valid_arg_key {
self.current_tool_args
.lock()
.unwrap()
.push((key.to_string(), value.to_string()));
}
}
fn print_tool_output_header(&self) {
// Skip normal header for todo tools
if *self.in_todo_tool.lock().unwrap() {
println!(); // Just add a newline
return;
}
println!();
// Now print the tool header with the most important arg in bold green
if let Some(tool_name) = self.current_tool_name.lock().unwrap().as_ref() {
let args = self.current_tool_args.lock().unwrap();
// Find the most important argument - prioritize file_path if available
let important_arg = args
.iter()
.find(|(k, _)| k == "file_path")
.or_else(|| args.iter().find(|(k, _)| k == "command" || k == "path"))
.or_else(|| args.first());
if let Some((_, value)) = important_arg {
// For multi-line values, only show the first line
let first_line = value.lines().next().unwrap_or("");
// Truncate long values for display
let display_value = if first_line.len() > 80 {
// Use char_indices to safely truncate at character boundary
let truncate_at = first_line.char_indices()
.nth(77)
.map(|(i, _)| i)
.unwrap_or(first_line.len());
format!("{}...", &first_line[..truncate_at])
} else {
first_line.to_string()
};
// Add range information for read_file tool calls
let header_suffix = if tool_name == "read_file" {
// Check if start or end parameters are present
let has_start = args.iter().any(|(k, _)| k == "start");
let has_end = args.iter().any(|(k, _)| k == "end");
if has_start || has_end {
let start_val = args.iter().find(|(k, _)| k == "start").map(|(_, v)| v.as_str()).unwrap_or("0");
let end_val = args.iter().find(|(k, _)| k == "end").map(|(_, v)| v.as_str()).unwrap_or("end");
format!(" [{}..{}]", start_val, end_val)
} else {
String::new()
}
} else {
String::new()
};
// Print with bold green tool name, purple (non-bold) for pipe and args
println!("┌─\x1b[1;32m {}\x1b[0m\x1b[35m | {}{}\x1b[0m", tool_name, display_value, header_suffix);
} else {
// Print with bold green formatting using ANSI escape codes
println!("┌─\x1b[1;32m {}\x1b[0m", tool_name);
}
}
}
fn update_tool_output_line(&self, line: &str) {
let mut current_line = self.current_output_line.lock().unwrap();
let mut line_printed = self.output_line_printed.lock().unwrap();
// If we've already printed a line, clear it first
if *line_printed {
// Move cursor up one line and clear it
print!("\x1b[1A\x1b[2K");
}
// Print the new line
println!("\x1b[2m{}\x1b[0m", line);
let _ = io::stdout().flush();
// Update state
*current_line = Some(line.to_string());
*line_printed = true;
}
fn print_tool_output_line(&self, line: &str) {
// Special handling for todo tools
if *self.in_todo_tool.lock().unwrap() {
self.print_todo_line(line);
return;
}
println!("\x1b[2m{}\x1b[0m", line);
}
fn print_tool_output_summary(&self, count: usize) {
// Skip for todo tools
if *self.in_todo_tool.lock().unwrap() {
return;
}
println!(
"\x1b[2m({} line{})\x1b[0m",
count,
if count == 1 { "" } else { "s" }
);
}
fn print_tool_timing(&self, duration_str: &str) {
// For todo tools, just print a simple completion message
if *self.in_todo_tool.lock().unwrap() {
println!();
*self.in_todo_tool.lock().unwrap() = false;
return;
}
// Parse the duration string to determine color
// Format is like "1.5s", "500ms", "2m 30.0s"
let color_code = if duration_str.ends_with("ms") {
// Milliseconds - use default color (< 1s)
""
} else if duration_str.contains('m') {
// Contains minutes
// Extract minutes value
if let Some(m_pos) = duration_str.find('m') {
if let Ok(minutes) = duration_str[..m_pos].trim().parse::<u32>() {
if minutes >= 5 {
"\x1b[31m" // Red for >= 5 minutes
} else {
"\x1b[38;5;208m" // Orange for >= 1 minute but < 5 minutes
}
} else {
"" // Default color if parsing fails
}
} else {
"" // Default color if 'm' not found (shouldn't happen)
}
} else if duration_str.ends_with('s') {
// Seconds only
if let Some(s_value) = duration_str.strip_suffix('s') {
if let Ok(seconds) = s_value.trim().parse::<f64>() {
if seconds >= 1.0 {
"\x1b[33m" // Yellow for >= 1 second
} else {
"" // Default color for < 1 second
}
} else {
"" // Default color if parsing fails
}
} else {
"" // Default color
}
} else {
// Milliseconds or other format - use default color
""
};
println!("└─ ⚡️ {}{}\x1b[0m", color_code, duration_str);
println!();
// Clear the stored tool info
*self.current_tool_name.lock().unwrap() = None;
self.current_tool_args.lock().unwrap().clear();
*self.current_output_line.lock().unwrap() = None;
*self.output_line_printed.lock().unwrap() = false;
}
fn print_agent_prompt(&self) {
let _ = io::stdout().flush();
}
fn print_agent_response(&self, content: &str) {
print!("{}", content);
let _ = io::stdout().flush();
}
fn notify_sse_received(&self) {
// No-op for console - we don't track SSEs in console mode
}
fn flush(&self) {
let _ = io::stdout().flush();
}
}


@@ -0,0 +1,47 @@
[package]
name = "g3-computer-control"
version = "0.1.0"
edition = "2021"
[build-dependencies]
# Only needed for building Swift bridge on macOS
[dependencies]
# Workspace dependencies
tokio = { workspace = true }
anyhow = { workspace = true }
thiserror = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tracing = { workspace = true }
uuid = { workspace = true }
shellexpand = "3.1"
# Async trait support
async-trait = "0.1"
# WebDriver support
fantoccini = "0.21"
# macOS dependencies
[target.'cfg(target_os = "macos")'.dependencies]
core-graphics = "0.23"
core-foundation = "0.10"
cocoa = "0.25"
objc = "0.2"
accessibility = "0.2"
image = "0.24"
# Linux dependencies
[target.'cfg(target_os = "linux")'.dependencies]
x11 = { version = "2.21", features = ["xlib", "xtest"] }
image = "0.24"
# Windows dependencies
[target.'cfg(target_os = "windows")'.dependencies]
windows = { version = "0.52", features = [
"Win32_Foundation",
"Win32_UI_WindowsAndMessaging",
"Win32_UI_Input_KeyboardAndMouse",
"Win32_Graphics_Gdi",
] }


@@ -0,0 +1,63 @@
use std::env;
use std::path::PathBuf;
use std::process::Command;
fn main() {
// Only build Vision bridge on macOS
if env::var("CARGO_CFG_TARGET_OS").unwrap() != "macos" {
return;
}
println!("cargo:rerun-if-changed=vision-bridge/Sources/VisionBridge/VisionOCR.swift");
println!("cargo:rerun-if-changed=vision-bridge/Sources/VisionBridge/VisionBridge.h");
println!("cargo:rerun-if-changed=vision-bridge/Package.swift");
let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap());
let vision_bridge_dir = manifest_dir.join("vision-bridge");
// Build Swift package
println!("cargo:warning=Building VisionBridge Swift package...");
let build_status = Command::new("swift")
.args(&["build", "-c", "release"])
.current_dir(&vision_bridge_dir)
.status()
.expect("Failed to build Swift package");
if !build_status.success() {
panic!("Swift build failed");
}
// Find the built library
let lib_path = vision_bridge_dir
.join(".build/release")
.canonicalize()
.expect("Failed to find .build/release directory");
// Copy the dylib to the output directory so it can be found at runtime
let target_dir = manifest_dir.parent().unwrap().parent().unwrap().join("target");
let profile = env::var("PROFILE").unwrap_or_else(|_| "debug".to_string());
let output_dir = target_dir.join(&profile);
let dylib_src = lib_path.join("libVisionBridge.dylib");
let dylib_dst = output_dir.join("libVisionBridge.dylib");
std::fs::copy(&dylib_src, &dylib_dst)
.expect(&format!("Failed to copy dylib from {} to {}", dylib_src.display(), dylib_dst.display()));
println!("cargo:warning=Copied libVisionBridge.dylib to {}", dylib_dst.display());
// Add rpath so the dylib can be found at runtime
println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path");
println!("cargo:rustc-link-arg=-Wl,-rpath,@loader_path");
println!("cargo:rustc-link-search=native={}", lib_path.display());
println!("cargo:rustc-link-lib=dylib=VisionBridge");
// Link required frameworks
println!("cargo:rustc-link-lib=framework=Vision");
println!("cargo:rustc-link-lib=framework=AppKit");
println!("cargo:rustc-link-lib=framework=Foundation");
println!("cargo:rustc-link-lib=framework=CoreGraphics");
println!("cargo:rustc-link-lib=framework=CoreImage");
println!("cargo:warning=VisionBridge built successfully at {}", lib_path.display());
}


@@ -0,0 +1,46 @@
use core_graphics::display::CGDisplay;
fn main() {
let display = CGDisplay::main();
let image = display.image().expect("Failed to capture screen");
println!("CGImage properties:");
println!(" Width: {}", image.width());
println!(" Height: {}", image.height());
println!(" Bits per component: {}", image.bits_per_component());
println!(" Bits per pixel: {}", image.bits_per_pixel());
println!(" Bytes per row: {}", image.bytes_per_row());
let data = image.data();
let expected_size = image.width() * image.height() * 4;
println!(" Data length: {}", data.len());
println!(" Expected (w*h*4): {}", expected_size);
// Check if there's padding in rows
let bytes_per_row = image.bytes_per_row();
let width = image.width();
let expected_bytes_per_row = width * 4;
println!("\nRow alignment:");
println!(" Actual bytes per row: {}", bytes_per_row);
println!(" Expected (width * 4): {}", expected_bytes_per_row);
println!(" Padding per row: {}", bytes_per_row - expected_bytes_per_row);
// Sample some pixels from different locations
println!("\nFirst 3 pixels (raw bytes):");
for i in 0..3 {
let offset = i * 4;
println!(" Pixel {}: [{:3}, {:3}, {:3}, {:3}]",
i, data[offset], data[offset+1], data[offset+2], data[offset+3]);
}
// Check a pixel from the middle
let mid_row = image.height() / 2;
let mid_col = image.width() / 2;
let mid_offset = (mid_row * bytes_per_row + mid_col * 4) as usize;
println!("\nMiddle pixel (row {}, col {}):", mid_row, mid_col);
println!(" Offset: {}", mid_offset);
if mid_offset + 3 < data.len() as usize {
println!(" Bytes: [{:3}, {:3}, {:3}, {:3}]",
data[mid_offset], data[mid_offset+1], data[mid_offset+2], data[mid_offset+3]);
}
}


@@ -0,0 +1,56 @@
use core_graphics::window::{kCGWindowListOptionOnScreenOnly, kCGNullWindowID, CGWindowListCopyWindowInfo};
use core_foundation::dictionary::CFDictionary;
use core_foundation::string::CFString;
use core_foundation::base::{TCFType, ToVoid};
fn main() {
println!("Listing all on-screen windows...");
println!("{:<10} {:<25} {}", "Window ID", "Owner", "Title");
println!("{}", "-".repeat(80));
unsafe {
let window_list = CGWindowListCopyWindowInfo(
kCGWindowListOptionOnScreenOnly,
kCGNullWindowID
);
// Wrap the array exactly once under the create rule to avoid a double release
let array = core_foundation::array::CFArray::<CFDictionary>::wrap_under_create_rule(window_list);
let count = array.len();
for i in 0..count {
let dict = array.get(i).unwrap();
// Get window ID
let window_id_key = CFString::from_static_string("kCGWindowNumber");
let window_id: i64 = if let Some(value) = dict.find(window_id_key.to_void()) {
let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _);
num.to_i64().unwrap_or(0)
} else {
0
};
// Get owner name
let owner_key = CFString::from_static_string("kCGWindowOwnerName");
let owner: String = if let Some(value) = dict.find(owner_key.to_void()) {
let s: CFString = TCFType::wrap_under_get_rule(*value as *const _);
s.to_string()
} else {
"Unknown".to_string()
};
// Get window name/title
let name_key = CFString::from_static_string("kCGWindowName");
let title: String = if let Some(value) = dict.find(name_key.to_void()) {
let s: CFString = TCFType::wrap_under_get_rule(*value as *const _);
s.to_string()
} else {
"".to_string()
};
// Show all windows
if !owner.is_empty() {
println!("{:<10} {:<25} {}", window_id, owner, title);
}
}
}
}


@@ -0,0 +1,74 @@
//! Example demonstrating macOS Accessibility API tools
//!
//! This example shows how to use the macax tools to control macOS applications.
//!
//! Run with: cargo run --example macax_demo
use anyhow::Result;
use g3_computer_control::MacAxController;
#[tokio::main]
async fn main() -> Result<()> {
println!("🍎 macOS Accessibility API Demo\n");
println!("This demo shows how to control macOS applications using the Accessibility API.\n");
// Create controller
let controller = MacAxController::new()?;
println!("✅ MacAxController initialized\n");
// List running applications
println!("📱 Listing running applications:");
match controller.list_applications() {
Ok(apps) => {
for app in apps.iter().take(10) {
println!(" - {}", app.name);
}
if apps.len() > 10 {
println!(" ... and {} more", apps.len() - 10);
}
}
Err(e) => println!(" ❌ Error: {}", e),
}
println!();
// Get frontmost app
println!("🎯 Getting frontmost application:");
match controller.get_frontmost_app() {
Ok(app) => println!(" Current: {}", app.name),
Err(e) => println!(" ❌ Error: {}", e),
}
println!();
// Example: Activate Finder and get its UI tree
println!("📂 Activating Finder and inspecting UI:");
match controller.activate_app("Finder") {
Ok(_) => {
println!(" ✅ Finder activated");
// Wait a moment for activation
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
// Get UI tree
match controller.get_ui_tree("Finder", 2) {
Ok(tree) => {
println!("\n UI Tree:");
for line in tree.lines().take(10) {
println!(" {}", line);
}
}
Err(e) => println!(" ❌ Error getting UI tree: {}", e),
}
}
Err(e) => println!(" ❌ Error: {}", e),
}
println!();
println!("✨ Demo complete!\n");
println!("💡 Tips:");
println!(" - Use --macax flag with g3 to enable these tools");
println!(" - Grant accessibility permissions in System Preferences");
println!(" - Add accessibility identifiers to your apps for easier automation");
println!(" - See docs/macax-tools.md for full documentation\n");
Ok(())
}


@@ -0,0 +1,64 @@
use g3_computer_control::SafariDriver;
use g3_computer_control::webdriver::WebDriverController;
use anyhow::Result;
#[tokio::main]
async fn main() -> Result<()> {
println!("Safari WebDriver Demo");
println!("=====================\n");
println!("Make sure to:");
println!("1. Enable 'Allow Remote Automation' in Safari's Develop menu");
println!("2. Run: /usr/bin/safaridriver --enable");
println!("3. Start safaridriver in another terminal: safaridriver --port 4444\n");
println!("Connecting to SafariDriver...");
let mut driver = SafariDriver::new().await?;
println!("✅ Connected!\n");
// Navigate to a website
println!("Navigating to example.com...");
driver.navigate("https://example.com").await?;
println!("✅ Navigated\n");
// Get page title
let title = driver.title().await?;
println!("Page title: {}\n", title);
// Get current URL
let url = driver.current_url().await?;
println!("Current URL: {}\n", url);
// Find an element
println!("Finding h1 element...");
let h1 = driver.find_element("h1").await?;
let h1_text = h1.text().await?;
println!("H1 text: {}\n", h1_text);
// Find all paragraphs
println!("Finding all paragraphs...");
let paragraphs = driver.find_elements("p").await?;
println!("Found {} paragraphs\n", paragraphs.len());
// Get page source
println!("Getting page source...");
let source = driver.page_source().await?;
println!("Page source length: {} bytes\n", source.len());
// Execute JavaScript
println!("Executing JavaScript...");
let result = driver.execute_script("return document.title", vec![]).await?;
println!("JS result: {:?}\n", result);
// Take a screenshot
println!("Taking screenshot...");
driver.screenshot("/tmp/safari_demo.png").await?;
println!("✅ Screenshot saved to /tmp/safari_demo.png\n");
// Close the browser
println!("Closing browser...");
driver.quit().await?;
println!("✅ Done!");
Ok(())
}


@@ -0,0 +1,21 @@
use g3_computer_control::create_controller;
#[tokio::main]
async fn main() {
println!("Testing screenshot with permission prompt...");
let controller = create_controller().expect("Failed to create controller");
match controller.take_screenshot("/tmp/test_with_prompt.png", None, None).await {
Ok(_) => {
println!("\n✅ Screenshot saved to /tmp/test_with_prompt.png");
println!("Opening screenshot...");
let _ = std::process::Command::new("open")
.arg("/tmp/test_with_prompt.png")
.spawn();
}
Err(e) => {
println!("❌ Screenshot failed: {}", e);
}
}
}


@@ -0,0 +1,39 @@
use std::process::Command;
fn main() {
let path = "/tmp/rust_screencapture_test.png";
println!("Testing screencapture command from Rust...");
let mut cmd = Command::new("screencapture");
cmd.arg("-x"); // No sound
cmd.arg(path);
println!("Command: {:?}", cmd);
match cmd.output() {
Ok(output) => {
println!("Exit status: {}", output.status);
println!("Stdout: {}", String::from_utf8_lossy(&output.stdout));
println!("Stderr: {}", String::from_utf8_lossy(&output.stderr));
if output.status.success() {
println!("\n✅ Screenshot saved to: {}", path);
// Check file exists and size
if let Ok(metadata) = std::fs::metadata(path) {
println!("File size: {} bytes ({:.1} MB)", metadata.len(), metadata.len() as f64 / 1_000_000.0);
}
// Open it
let _ = Command::new("open").arg(path).spawn();
println!("\nOpened screenshot - please verify it looks correct!");
} else {
println!("\n❌ Screenshot failed!");
}
}
Err(e) => {
println!("❌ Failed to execute screencapture: {}", e);
}
}
}


@@ -0,0 +1,68 @@
use core_graphics::display::CGDisplay;
use image::{ImageBuffer, RgbaImage};
fn main() {
let display = CGDisplay::main();
let image = display.image().expect("Failed to capture screen");
let width = image.width() as u32;
let height = image.height() as u32;
let bytes_per_row = image.bytes_per_row() as usize;
let data = image.data();
println!("Testing screenshot fix...");
println!("Image: {}x{}, bytes_per_row: {}", width, height, bytes_per_row);
println!("Expected bytes per row: {}", width * 4);
println!("Padding per row: {} bytes", bytes_per_row - (width as usize * 4));
// OLD METHOD (broken) - treating data as continuous
println!("\n=== OLD METHOD (BROKEN) ===");
let mut old_rgba = Vec::with_capacity(data.len() as usize);
for chunk in data.chunks_exact(4) {
old_rgba.push(chunk[2]); // R
old_rgba.push(chunk[1]); // G
old_rgba.push(chunk[0]); // B
old_rgba.push(chunk[3]); // A
}
println!("Converted {} pixels", old_rgba.len() / 4);
println!("Expected {} pixels", width * height);
// NEW METHOD (fixed) - handling row padding
println!("\n=== NEW METHOD (FIXED) ===");
let mut new_rgba = Vec::with_capacity((width * height * 4) as usize);
for row in 0..height as usize {
let row_start = row * bytes_per_row;
let row_end = row_start + (width as usize * 4);
for chunk in data[row_start..row_end].chunks_exact(4) {
new_rgba.push(chunk[2]); // R
new_rgba.push(chunk[1]); // G
new_rgba.push(chunk[0]); // B
new_rgba.push(chunk[3]); // A
}
}
println!("Converted {} pixels", new_rgba.len() / 4);
println!("Expected {} pixels", width * height);
// Save a small crop from both methods
let crop_size = 200;
// Old method crop
let old_crop: Vec<u8> = old_rgba.iter().take((crop_size * crop_size * 4) as usize).copied().collect();
if let Some(old_img) = ImageBuffer::from_raw(crop_size, crop_size, old_crop) {
let old_img: RgbaImage = old_img;
old_img.save("/tmp/screenshot_old_method.png").unwrap();
println!("\nSaved OLD method crop to: /tmp/screenshot_old_method.png");
}
// New method crop
let new_crop: Vec<u8> = new_rgba.iter().take((crop_size * crop_size * 4) as usize).copied().collect();
if let Some(new_img) = ImageBuffer::from_raw(crop_size, crop_size, new_crop) {
let new_img: RgbaImage = new_img;
new_img.save("/tmp/screenshot_new_method.png").unwrap();
println!("Saved NEW method crop to: /tmp/screenshot_new_method.png");
}
println!("\nOpen both images to compare:");
println!(" open /tmp/screenshot_old_method.png /tmp/screenshot_new_method.png");
}


@@ -0,0 +1,48 @@
//! Test the new type_text functionality
use anyhow::Result;
use g3_computer_control::MacAxController;
#[tokio::main]
async fn main() -> Result<()> {
println!("🧪 Testing macax type_text functionality\n");
let controller = MacAxController::new()?;
println!("✅ Controller initialized\n");
// Test 1: Type simple text
println!("Test 1: Typing simple text into TextEdit");
println!(" Please open TextEdit and create a new document...");
std::thread::sleep(std::time::Duration::from_secs(3));
match controller.type_text("TextEdit", "Hello, World!") {
Ok(_) => println!(" ✅ Successfully typed simple text\n"),
Err(e) => println!(" ❌ Failed: {}\n", e),
}
std::thread::sleep(std::time::Duration::from_secs(1));
// Test 2: Type unicode and emojis
println!("Test 2: Typing unicode and emojis");
match controller.type_text("TextEdit", "\n🌟 Unicode test: café, naïve, 日本語 🎉") {
Ok(_) => println!(" ✅ Successfully typed unicode text\n"),
Err(e) => println!(" ❌ Failed: {}\n", e),
}
std::thread::sleep(std::time::Duration::from_secs(1));
// Test 3: Type special characters
println!("Test 3: Typing special characters");
match controller.type_text("TextEdit", "\nSpecial: @#$%^&*()_+-=[]{}|;':,.<>?/") {
Ok(_) => println!(" ✅ Successfully typed special characters\n"),
Err(e) => println!(" ❌ Failed: {}\n", e),
}
println!("\n✨ Tests complete!");
println!("\n💡 Now try with Things3:");
println!(" 1. Open Things3");
println!(" 2. Press Cmd+N to create a new task");
println!(" 3. Run: g3 --macax 'type \"🌟 My awesome task\" into Things'");
Ok(())
}


@@ -0,0 +1,85 @@
use g3_computer_control::ocr::{OCREngine, DefaultOCR};
use anyhow::Result;
#[tokio::main]
async fn main() -> Result<()> {
println!("🧪 Testing Apple Vision OCR");
println!("===========================\n");
// Initialize OCR engine
println!("📦 Initializing OCR engine...");
let ocr = DefaultOCR::new()?;
println!("✅ OCR engine: {}\n", ocr.name());
// Check if test image exists
let test_image = "/tmp/safari_test.png";
if !std::path::Path::new(test_image).exists() {
println!("⚠️ Test image not found: {}", test_image);
println!(" Creating a screenshot...");
let status = std::process::Command::new("screencapture")
.arg("-x")
.arg("-R")
.arg("0,0,1200,800")
.arg(test_image)
.status()?;
if !status.success() {
anyhow::bail!("Failed to create screenshot");
}
println!("✅ Screenshot created\n");
}
// Run OCR
println!("🔍 Running Apple Vision OCR on {}...", test_image);
let start = std::time::Instant::now();
let locations = ocr.extract_text_with_locations(test_image).await?;
let duration = start.elapsed();
println!("✅ OCR completed in {:.3}s\n", duration.as_secs_f64());
// Display results
println!("📊 Results:");
println!(" Found {} text elements\n", locations.len());
if locations.is_empty() {
println!("⚠️ No text found in image");
} else {
println!(" Top 20 results:");
println!(" {:<4} {:<40} {:<15} {:<12} {:<8}", "#", "Text", "Position", "Size", "Conf");
println!(" {}", "-".repeat(85));
for (i, loc) in locations.iter().take(20).enumerate() {
let text = if loc.text.chars().count() > 37 {
// Truncate on a char boundary so multi-byte text cannot panic the slice
let head: String = loc.text.chars().take(37).collect();
format!("{}...", head)
} else {
loc.text.clone()
};
println!(" {:<4} {:<40} ({:>4},{:>4}) {:>4}x{:<4} {:.2}",
i + 1,
text,
loc.x,
loc.y,
loc.width,
loc.height,
loc.confidence
);
}
if locations.len() > 20 {
println!("\n ... and {} more", locations.len() - 20);
}
// Performance comparison
println!("\n📈 Performance:");
println!(" OCR Speed: {:.3}s", duration.as_secs_f64());
println!(" Text elements: {}", locations.len());
println!(" Avg per element: {:.1}ms", duration.as_millis() as f64 / locations.len() as f64);
}
println!("\n✅ Test complete!");
Ok(())
}


@@ -0,0 +1,45 @@
use g3_computer_control::create_controller;
#[tokio::main]
async fn main() {
println!("Testing window-specific screenshot capture...");
let controller = create_controller().expect("Failed to create controller");
// Test 1: Capture iTerm2 window
println!("\n1. Capturing iTerm2 window...");
match controller.take_screenshot("/tmp/iterm_window.png", None, Some("iTerm2")).await {
Ok(_) => {
println!(" ✅ iTerm2 window captured to /tmp/iterm_window.png");
let _ = std::process::Command::new("open").arg("/tmp/iterm_window.png").spawn();
}
Err(e) => println!(" ❌ Failed: {}", e),
}
// Wait a moment for the image to open
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
// Test 2: Full screen capture for comparison
println!("\n2. Capturing full screen for comparison...");
match controller.take_screenshot("/tmp/fullscreen.png", None, None).await {
Ok(_) => {
println!(" ✅ Full screen captured to /tmp/fullscreen.png");
let _ = std::process::Command::new("open").arg("/tmp/fullscreen.png").spawn();
}
Err(e) => println!(" ❌ Failed: {}", e),
}
println!("\n=== Comparison ===");
println!("iTerm window: /tmp/iterm_window.png (should show ONLY iTerm window)");
println!("Full screen: /tmp/fullscreen.png (should show entire desktop)");
// Show file sizes
if let Ok(meta1) = std::fs::metadata("/tmp/iterm_window.png") {
if let Ok(meta2) = std::fs::metadata("/tmp/fullscreen.png") {
println!("\nFile sizes:");
println!(" iTerm window: {:.1} MB", meta1.len() as f64 / 1_000_000.0);
println!(" Full screen: {:.1} MB", meta2.len() as f64 / 1_000_000.0);
println!("\nWindow capture should be smaller than full screen.");
}
}
}


@@ -0,0 +1,49 @@
// Suppress warnings from objc crate macros
#![allow(unexpected_cfgs)]
pub mod types;
pub mod platform;
pub mod ocr;
pub mod webdriver;
pub mod macax;
// Re-export webdriver types for convenience
pub use webdriver::{WebDriverController, WebElement, safari::SafariDriver};
// Re-export macax types for convenience
pub use macax::{MacAxController, AXElement, AXApplication};
use anyhow::Result;
use async_trait::async_trait;
use types::*;
#[async_trait]
pub trait ComputerController: Send + Sync {
// Screen capture
async fn take_screenshot(&self, path: &str, region: Option<Rect>, window_id: Option<&str>) -> Result<()>;
// OCR operations
async fn extract_text_from_screen(&self, region: Rect, window_id: &str) -> Result<String>;
async fn extract_text_from_image(&self, path: &str) -> Result<String>;
async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>>;
async fn find_text_in_app(&self, app_name: &str, search_text: &str) -> Result<Option<TextLocation>>;
// Mouse operations
fn move_mouse(&self, x: i32, y: i32) -> Result<()>;
fn click_at(&self, x: i32, y: i32, app_name: Option<&str>) -> Result<()>;
}
// Platform-specific constructor
pub fn create_controller() -> Result<Box<dyn ComputerController>> {
#[cfg(target_os = "macos")]
return Ok(Box::new(platform::macos::MacOSController::new()?));
#[cfg(target_os = "linux")]
return Ok(Box::new(platform::linux::LinuxController::new()?));
#[cfg(target_os = "windows")]
return Ok(Box::new(platform::windows::WindowsController::new()?));
#[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
anyhow::bail!("Unsupported platform")
}


@@ -0,0 +1,822 @@
use super::{AXApplication, AXElement};
use anyhow::{Context, Result};
use std::collections::HashMap;
#[cfg(target_os = "macos")]
use accessibility::{AXUIElement, AXUIElementAttributes, ElementFinder, TreeVisitor, TreeWalker, TreeWalkerFlow};
#[cfg(target_os = "macos")]
use core_foundation::base::TCFType;
#[cfg(target_os = "macos")]
use core_foundation::string::CFString;
/// macOS Accessibility API controller using native APIs
pub struct MacAxController {
// Cache for application elements
app_cache: std::sync::Mutex<HashMap<String, AXUIElement>>,
}
impl MacAxController {
pub fn new() -> Result<Self> {
#[cfg(target_os = "macos")]
{
// Check if we have accessibility permissions by trying to get system-wide element
let _system = AXUIElement::system_wide();
Ok(Self {
app_cache: std::sync::Mutex::new(HashMap::new()),
})
}
#[cfg(not(target_os = "macos"))]
{
anyhow::bail!("macOS Accessibility API is only available on macOS")
}
}
/// List all running applications
#[cfg(target_os = "macos")]
pub fn list_applications(&self) -> Result<Vec<AXApplication>> {
let apps = Self::get_running_applications()?;
Ok(apps)
}
#[cfg(not(target_os = "macos"))]
pub fn list_applications(&self) -> Result<Vec<AXApplication>> {
anyhow::bail!("Not supported on this platform")
}
#[cfg(target_os = "macos")]
fn get_running_applications() -> Result<Vec<AXApplication>> {
use cocoa::appkit::NSApplicationActivationPolicy;
use cocoa::base::{id, nil};
use objc::{class, msg_send, sel, sel_impl};
unsafe {
let workspace: id = msg_send![class!(NSWorkspace), sharedWorkspace];
let running_apps: id = msg_send![workspace, runningApplications];
let count: usize = msg_send![running_apps, count];
let mut apps = Vec::new();
for i in 0..count {
let app: id = msg_send![running_apps, objectAtIndex: i];
// Get app name
let localized_name: id = msg_send![app, localizedName];
if localized_name == nil {
continue;
}
let name_ptr: *const i8 = msg_send![localized_name, UTF8String];
let name = if !name_ptr.is_null() {
std::ffi::CStr::from_ptr(name_ptr)
.to_string_lossy()
.to_string()
} else {
continue;
};
// Get bundle ID
let bundle_id_obj: id = msg_send![app, bundleIdentifier];
let bundle_id = if bundle_id_obj != nil {
let bundle_id_ptr: *const i8 = msg_send![bundle_id_obj, UTF8String];
if !bundle_id_ptr.is_null() {
Some(
std::ffi::CStr::from_ptr(bundle_id_ptr)
.to_string_lossy()
.to_string(),
)
} else {
None
}
} else {
None
};
// Get PID
let pid: i32 = msg_send![app, processIdentifier];
// Skip background-only apps
let activation_policy: i64 = msg_send![app, activationPolicy];
if activation_policy == NSApplicationActivationPolicy::NSApplicationActivationPolicyRegular as i64 {
apps.push(AXApplication {
name,
bundle_id,
pid,
});
}
}
Ok(apps)
}
}
/// Get the frontmost (active) application
#[cfg(target_os = "macos")]
pub fn get_frontmost_app(&self) -> Result<AXApplication> {
use cocoa::base::{id, nil};
use objc::{class, msg_send, sel, sel_impl};
unsafe {
let workspace: id = msg_send![class!(NSWorkspace), sharedWorkspace];
let frontmost_app: id = msg_send![workspace, frontmostApplication];
if frontmost_app == nil {
anyhow::bail!("No frontmost application");
}
// Get app name
let localized_name: id = msg_send![frontmost_app, localizedName];
let name_ptr: *const i8 = msg_send![localized_name, UTF8String];
let name = std::ffi::CStr::from_ptr(name_ptr)
.to_string_lossy()
.to_string();
// Get bundle ID
let bundle_id_obj: id = msg_send![frontmost_app, bundleIdentifier];
let bundle_id = if bundle_id_obj != nil {
let bundle_id_ptr: *const i8 = msg_send![bundle_id_obj, UTF8String];
if !bundle_id_ptr.is_null() {
Some(
std::ffi::CStr::from_ptr(bundle_id_ptr)
.to_string_lossy()
.to_string(),
)
} else {
None
}
} else {
None
};
// Get PID
let pid: i32 = msg_send![frontmost_app, processIdentifier];
Ok(AXApplication {
name,
bundle_id,
pid,
})
}
}
#[cfg(not(target_os = "macos"))]
pub fn get_frontmost_app(&self) -> Result<AXApplication> {
anyhow::bail!("Not supported on this platform")
}
/// Get AXUIElement for an application by name or PID
#[cfg(target_os = "macos")]
fn get_app_element(&self, app_name: &str) -> Result<AXUIElement> {
// Check cache first
{
let cache = self.app_cache.lock().unwrap();
if let Some(element) = cache.get(app_name) {
return Ok(element.clone());
}
}
// Find the app by name
let apps = Self::get_running_applications()?;
let app = apps
.iter()
.find(|a| a.name == app_name)
.ok_or_else(|| anyhow::anyhow!("Application '{}' not found", app_name))?;
// Create AXUIElement for the app
let element = AXUIElement::application(app.pid);
// Cache it
{
let mut cache = self.app_cache.lock().unwrap();
cache.insert(app_name.to_string(), element.clone());
}
Ok(element)
}
/// Activate (bring to front) an application
#[cfg(target_os = "macos")]
pub fn activate_app(&self, app_name: &str) -> Result<()> {
use cocoa::base::id;
use objc::{class, msg_send, sel, sel_impl};
// Find the app
let apps = Self::get_running_applications()?;
let app = apps
.iter()
.find(|a| a.name == app_name)
.ok_or_else(|| anyhow::anyhow!("Application '{}' not found", app_name))?;
unsafe {
let workspace: id = msg_send![class!(NSWorkspace), sharedWorkspace];
let running_apps: id = msg_send![workspace, runningApplications];
let count: usize = msg_send![running_apps, count];
for i in 0..count {
let running_app: id = msg_send![running_apps, objectAtIndex: i];
let pid: i32 = msg_send![running_app, processIdentifier];
if pid == app.pid {
let _: bool = msg_send![running_app, activateWithOptions: 0];
return Ok(());
}
}
}
anyhow::bail!("Failed to activate application")
}
#[cfg(not(target_os = "macos"))]
pub fn activate_app(&self, _app_name: &str) -> Result<()> {
anyhow::bail!("Not supported on this platform")
}
/// Get the UI hierarchy of an application
#[cfg(target_os = "macos")]
pub fn get_ui_tree(&self, app_name: &str, max_depth: usize) -> Result<String> {
let app_element = self.get_app_element(app_name)?;
let mut output = format!("Application: {}\n", app_name);
Self::build_ui_tree(&app_element, &mut output, 0, max_depth)?;
Ok(output)
}
#[cfg(not(target_os = "macos"))]
pub fn get_ui_tree(&self, _app_name: &str, _max_depth: usize) -> Result<String> {
anyhow::bail!("Not supported on this platform")
}
#[cfg(target_os = "macos")]
fn build_ui_tree(
element: &AXUIElement,
output: &mut String,
depth: usize,
max_depth: usize,
) -> Result<()> {
if depth >= max_depth {
return Ok(());
}
let indent = " ".repeat(depth);
// Get role
let role = element.role().ok().map(|s| s.to_string())
.unwrap_or_else(|| "Unknown".to_string());
// Get title
let title = element.title().ok()
.map(|s| s.to_string());
// Get identifier
let identifier = element.identifier().ok()
.map(|s| s.to_string());
// Format output
output.push_str(&format!("{}Role: {}", indent, role));
if let Some(t) = title {
output.push_str(&format!(", Title: {}", t));
}
if let Some(id) = identifier {
output.push_str(&format!(", ID: {}", id));
}
output.push('\n');
// Get children
if let Ok(children) = element.children() {
for i in 0..children.len() {
if let Some(child) = children.get(i) {
let _ = Self::build_ui_tree(&child, output, depth + 1, max_depth);
}
}
}
Ok(())
}
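// Illustrative shape of `get_ui_tree("Safari", 3)` output (a sketch; actual
// roles and titles depend on the running app):
//
//   Application: Safari
//   Role: AXApplication, Title: Safari
//     Role: AXWindow, Title: Start Page
//       Role: AXToolbar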
/// Find UI elements in an application
#[cfg(target_os = "macos")]
pub fn find_elements(
&self,
app_name: &str,
role: Option<&str>,
title: Option<&str>,
identifier: Option<&str>,
) -> Result<Vec<AXElement>> {
let app_element = self.get_app_element(app_name)?;
let mut found_elements = Vec::new();
let visitor = ElementCollector {
role_filter: role.map(|s| s.to_string()),
title_filter: title.map(|s| s.to_string()),
identifier_filter: identifier.map(|s| s.to_string()),
results: std::cell::RefCell::new(&mut found_elements),
depth: std::cell::Cell::new(0),
};
let walker = TreeWalker::new();
walker.walk(&app_element, &visitor);
Ok(found_elements)
}
#[cfg(not(target_os = "macos"))]
pub fn find_elements(
&self,
_app_name: &str,
_role: Option<&str>,
_title: Option<&str>,
_identifier: Option<&str>,
) -> Result<Vec<AXElement>> {
anyhow::bail!("Not supported on this platform")
}
/// Find a single element (helper for click, set_value, etc.)
#[cfg(target_os = "macos")]
fn find_element(
&self,
app_name: &str,
role: &str,
title: Option<&str>,
identifier: Option<&str>,
) -> Result<AXUIElement> {
let app_element = self.get_app_element(app_name)?;
let role_str = role.to_string();
let title_str = title.map(|s| s.to_string());
let identifier_str = identifier.map(|s| s.to_string());
let finder = ElementFinder::new(
&app_element,
move |element| {
// Check role
let elem_role = element.role()
.ok()
.map(|s| s.to_string());
if let Some(r) = elem_role {
if !r.contains(&role_str) {
return false;
}
} else {
return false;
}
// Check title if specified
if let Some(ref title_filter) = title_str {
let elem_title = element.title()
.ok()
.map(|s| s.to_string());
if let Some(t) = elem_title {
if !t.contains(title_filter) {
return false;
}
} else {
return false;
}
}
// Check identifier if specified
if let Some(ref id_filter) = identifier_str {
let elem_id = element.identifier()
.ok()
.map(|s| s.to_string());
if let Some(id) = elem_id {
if !id.contains(id_filter) {
return false;
}
} else {
return false;
}
}
true
},
Some(std::time::Duration::from_secs(2)),
);
finder.find().context("Element not found")
}
/// Click on a UI element
#[cfg(target_os = "macos")]
pub fn click_element(
&self,
app_name: &str,
role: &str,
title: Option<&str>,
identifier: Option<&str>,
) -> Result<()> {
let element = self.find_element(app_name, role, title, identifier)?;
// Perform the press action
let action_name = CFString::new("AXPress");
element
.perform_action(&action_name)
.map_err(|e| anyhow::anyhow!("Failed to perform press action: {:?}", e))?;
Ok(())
}
#[cfg(not(target_os = "macos"))]
pub fn click_element(
&self,
_app_name: &str,
_role: &str,
_title: Option<&str>,
_identifier: Option<&str>,
) -> Result<()> {
anyhow::bail!("Not supported on this platform")
}
/// Set the value of a UI element
#[cfg(target_os = "macos")]
pub fn set_value(
&self,
app_name: &str,
role: &str,
value: &str,
title: Option<&str>,
identifier: Option<&str>,
) -> Result<()> {
let element = self.find_element(app_name, role, title, identifier)?;
// Set the value - convert CFString to CFType
let cf_value = CFString::new(value);
element.set_value(cf_value.as_CFType())
.map_err(|e| anyhow::anyhow!("Failed to set value: {:?}", e))?;
Ok(())
}
#[cfg(not(target_os = "macos"))]
pub fn set_value(
&self,
_app_name: &str,
_role: &str,
_value: &str,
_title: Option<&str>,
_identifier: Option<&str>,
) -> Result<()> {
anyhow::bail!("Not supported on this platform")
}
/// Get the value of a UI element
#[cfg(target_os = "macos")]
pub fn get_value(
&self,
app_name: &str,
role: &str,
title: Option<&str>,
identifier: Option<&str>,
) -> Result<String> {
let element = self.find_element(app_name, role, title, identifier)?;
// Get the value
let value_type = element.value()
.map_err(|e| anyhow::anyhow!("Failed to get value: {:?}", e))?;
// Try to downcast to CFString
if let Some(cf_string) = value_type.downcast::<CFString>() {
Ok(cf_string.to_string())
} else {
// For non-string values, fall back to a placeholder description
Ok("<non-string value>".to_string())
}
}
#[cfg(not(target_os = "macos"))]
pub fn get_value(
&self,
_app_name: &str,
_role: &str,
_title: Option<&str>,
_identifier: Option<&str>,
) -> Result<String> {
anyhow::bail!("Not supported on this platform")
}
/// Type text into the currently focused element (uses system text input)
#[cfg(target_os = "macos")]
pub fn type_text(&self, app_name: &str, text: &str) -> Result<()> {
use cocoa::base::{id, nil};
use cocoa::foundation::NSString;
use objc::{class, msg_send, sel, sel_impl};
// First, make sure the app is active
self.activate_app(app_name)?;
// Wait for app to fully activate
std::thread::sleep(std::time::Duration::from_millis(500));
// Send a Tab key to try to focus on a text field
// This helps ensure something is focused before we paste
let _ = self.press_key(app_name, "tab", vec![]);
std::thread::sleep(std::time::Duration::from_millis(800));
// Save old clipboard, set new content, paste, then restore
let old_content: id;
unsafe {
// Get the general pasteboard
let pasteboard: id = msg_send![class!(NSPasteboard), generalPasteboard];
// Save current clipboard content
let ns_string_type = NSString::alloc(nil).init_str("public.utf8-plain-text");
old_content = msg_send![pasteboard, stringForType: ns_string_type];
// Clear and set new content
let _: () = msg_send![pasteboard, clearContents];
let ns_string = NSString::alloc(nil).init_str(text);
let ns_type = NSString::alloc(nil).init_str("public.utf8-plain-text");
let _: bool = msg_send![pasteboard, setString:ns_string forType:ns_type];
}
// Wait a moment for clipboard to update
std::thread::sleep(std::time::Duration::from_millis(200));
// Paste using Cmd+V (outside unsafe block)
self.press_key(app_name, "v", vec!["command"])?;
// Wait for paste to complete
std::thread::sleep(std::time::Duration::from_millis(300));
// Restore old clipboard content if it existed
unsafe {
if old_content != nil {
let pasteboard: id = msg_send![class!(NSPasteboard), generalPasteboard];
let _: () = msg_send![pasteboard, clearContents];
let ns_type = NSString::alloc(nil).init_str("public.utf8-plain-text");
let _: bool = msg_send![pasteboard, setString:old_content forType:ns_type];
}
}
Ok(())
}
#[cfg(not(target_os = "macos"))]
pub fn type_text(&self, _app_name: &str, _text: &str) -> Result<()> {
anyhow::bail!("Not supported on this platform")
}
/// Focus on a text field or text area element
#[cfg(target_os = "macos")]
pub fn focus_element(
&self,
app_name: &str,
role: &str,
title: Option<&str>,
identifier: Option<&str>,
) -> Result<()> {
let element = self.find_element(app_name, role, title, identifier)?;
// Set focused attribute to true
use core_foundation::boolean::CFBoolean;
let cf_true = CFBoolean::true_value();
element.set_attribute(&accessibility::AXAttribute::focused(), cf_true)
.map_err(|e| anyhow::anyhow!("Failed to focus element: {:?}", e))?;
Ok(())
}
/// Press a keyboard shortcut
#[cfg(target_os = "macos")]
pub fn press_key(
&self,
app_name: &str,
key: &str,
modifiers: Vec<&str>,
) -> Result<()> {
use core_graphics::event::{
CGEvent, CGEventFlags, CGEventTapLocation,
};
use core_graphics::event_source::{CGEventSource, CGEventSourceStateID};
// First, make sure the app is active
self.activate_app(app_name)?;
// Wait a bit for activation
std::thread::sleep(std::time::Duration::from_millis(100));
// Map key string to key code
let key_code = Self::key_to_keycode(key)
.ok_or_else(|| anyhow::anyhow!("Unknown key: {}", key))?;
// Map modifiers to flags
let mut flags = CGEventFlags::CGEventFlagNull;
for modifier in modifiers {
match modifier.to_lowercase().as_str() {
"command" | "cmd" => flags |= CGEventFlags::CGEventFlagCommand,
"option" | "alt" => flags |= CGEventFlags::CGEventFlagAlternate,
"control" | "ctrl" => flags |= CGEventFlags::CGEventFlagControl,
"shift" => flags |= CGEventFlags::CGEventFlagShift,
_ => {}
}
}
// Create event source
let source = CGEventSource::new(CGEventSourceStateID::HIDSystemState)
.ok().context("Failed to create event source")?;
// Create key down event
let key_down = CGEvent::new_keyboard_event(source.clone(), key_code, true)
.ok().context("Failed to create key down event")?;
key_down.set_flags(flags);
// Create key up event
let key_up = CGEvent::new_keyboard_event(source, key_code, false)
.ok().context("Failed to create key up event")?;
key_up.set_flags(flags);
// Post events
key_down.post(CGEventTapLocation::HID);
std::thread::sleep(std::time::Duration::from_millis(50));
key_up.post(CGEventTapLocation::HID);
Ok(())
}
#[cfg(not(target_os = "macos"))]
pub fn press_key(
&self,
_app_name: &str,
_key: &str,
_modifiers: Vec<&str>,
) -> Result<()> {
anyhow::bail!("Not supported on this platform")
}
#[cfg(target_os = "macos")]
fn key_to_keycode(key: &str) -> Option<u16> {
// Map common keys to keycodes
// See: https://eastmanreference.com/complete-list-of-applescript-key-codes
match key.to_lowercase().as_str() {
"a" => Some(0x00),
"s" => Some(0x01),
"d" => Some(0x02),
"f" => Some(0x03),
"h" => Some(0x04),
"g" => Some(0x05),
"z" => Some(0x06),
"x" => Some(0x07),
"c" => Some(0x08),
"v" => Some(0x09),
"b" => Some(0x0B),
"q" => Some(0x0C),
"w" => Some(0x0D),
"e" => Some(0x0E),
"r" => Some(0x0F),
"y" => Some(0x10),
"t" => Some(0x11),
"1" => Some(0x12),
"2" => Some(0x13),
"3" => Some(0x14),
"4" => Some(0x15),
"6" => Some(0x16),
"5" => Some(0x17),
"=" => Some(0x18),
"9" => Some(0x19),
"7" => Some(0x1A),
"-" => Some(0x1B),
"8" => Some(0x1C),
"0" => Some(0x1D),
"]" => Some(0x1E),
"o" => Some(0x1F),
"u" => Some(0x20),
"[" => Some(0x21),
"i" => Some(0x22),
"p" => Some(0x23),
"return" | "enter" => Some(0x24),
"l" => Some(0x25),
"j" => Some(0x26),
"'" => Some(0x27),
"k" => Some(0x28),
";" => Some(0x29),
"\\" => Some(0x2A),
"," => Some(0x2B),
"/" => Some(0x2C),
"n" => Some(0x2D),
"m" => Some(0x2E),
"." => Some(0x2F),
"tab" => Some(0x30),
"space" => Some(0x31),
"`" => Some(0x32),
"delete" | "backspace" => Some(0x33),
"escape" | "esc" => Some(0x35),
"f1" => Some(0x7A),
"f2" => Some(0x78),
"f3" => Some(0x63),
"f4" => Some(0x76),
"f5" => Some(0x60),
"f6" => Some(0x61),
"f7" => Some(0x62),
"f8" => Some(0x64),
"f9" => Some(0x65),
"f10" => Some(0x6D),
"f11" => Some(0x67),
"f12" => Some(0x6F),
"left" => Some(0x7B),
"right" => Some(0x7C),
"down" => Some(0x7D),
"up" => Some(0x7E),
_ => None,
}
}
}
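// Illustrative sketch: how the input helpers above compose (assumes the
// target app is named exactly "TextEdit" in NSWorkspace):
//
//   let ctl = MacAxController::new()?;
//   ctl.type_text("TextEdit", "Hello from g3")?;      // clipboard + Cmd+V path
//   ctl.press_key("TextEdit", "s", vec!["command"])?; // Cmd+S via CGEvent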
#[cfg(target_os = "macos")]
struct ElementCollector<'a> {
role_filter: Option<String>,
title_filter: Option<String>,
identifier_filter: Option<String>,
results: std::cell::RefCell<&'a mut Vec<AXElement>>,
depth: std::cell::Cell<usize>,
}
#[cfg(target_os = "macos")]
impl<'a> TreeVisitor for ElementCollector<'a> {
fn enter_element(&self, element: &AXUIElement) -> TreeWalkerFlow {
self.depth.set(self.depth.get() + 1);
if self.depth.get() > 20 {
return TreeWalkerFlow::SkipSubtree;
}
// Get element properties
let role = element.role()
.ok()
.map(|s| s.to_string())
.unwrap_or_else(|| "Unknown".to_string());
let title = element.title()
.ok()
.map(|s| s.to_string());
let identifier = element.identifier()
.ok()
.map(|s| s.to_string());
// Check if this element matches the filters
let role_matches = self.role_filter.as_ref().map_or(true, |r| role.contains(r));
let title_matches = self.title_filter.as_ref().map_or(true, |t| {
title.as_ref().map_or(false, |title_str| title_str.contains(t))
});
let identifier_matches = self.identifier_filter.as_ref().map_or(true, |id| {
identifier.as_ref().map_or(false, |id_str| id_str.contains(id))
});
if role_matches && title_matches && identifier_matches {
// Get additional properties
let value = element.value()
.ok()
.and_then(|v| {
v.downcast::<CFString>().map(|s| s.to_string())
});
let label = element.description()
.ok()
.map(|s| s.to_string());
let enabled = element.enabled()
.ok()
.map(|b| b.into())
.unwrap_or(false);
let focused = element.focused()
.ok()
.map(|b| b.into())
.unwrap_or(false);
// Count children
let children_count = element.children()
.ok()
.map(|arr| arr.len() as usize)
.unwrap_or(0);
self.results.borrow_mut().push(AXElement {
role,
title,
value,
label,
identifier,
enabled,
focused,
position: None,
size: None,
children_count,
});
}
TreeWalkerFlow::Continue
}
fn exit_element(&self, _element: &AXUIElement) {
self.depth.set(self.depth.get() - 1);
}
}
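// Illustrative sketch (not from the original commits): enumerating every
// enabled button in an app via the collector above.
#[cfg(target_os = "macos")]
#[allow(dead_code)]
fn list_enabled_buttons(ctl: &MacAxController, app: &str) -> Result<Vec<String>> {
    let buttons = ctl.find_elements(app, Some("AXButton"), None, None)?;
    Ok(buttons
        .iter()
        .filter(|b| b.enabled)
        .map(|b| b.to_string())
        .collect())
}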

View File

@@ -0,0 +1,65 @@
pub mod controller;
pub use controller::MacAxController;
use serde::{Deserialize, Serialize};
#[cfg(test)]
mod tests;
/// Represents an accessibility element in the UI hierarchy
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AXElement {
pub role: String,
pub title: Option<String>,
pub value: Option<String>,
pub label: Option<String>,
pub identifier: Option<String>,
pub enabled: bool,
pub focused: bool,
pub position: Option<(f64, f64)>,
pub size: Option<(f64, f64)>,
pub children_count: usize,
}
/// Represents a macOS application
#[derive(Debug, Clone)]
pub struct AXApplication {
pub name: String,
pub bundle_id: Option<String>,
pub pid: i32,
}
impl AXElement {
/// Convert to a human-readable string representation
pub fn to_string(&self) -> String {
let mut parts = vec![format!("Role: {}", self.role)];
if let Some(ref title) = self.title {
parts.push(format!("Title: {}", title));
}
if let Some(ref value) = self.value {
parts.push(format!("Value: {}", value));
}
if let Some(ref label) = self.label {
parts.push(format!("Label: {}", label));
}
if let Some(ref id) = self.identifier {
parts.push(format!("ID: {}", id));
}
parts.push(format!("Enabled: {}", self.enabled));
parts.push(format!("Focused: {}", self.focused));
if let Some((x, y)) = self.position {
parts.push(format!("Position: ({:.0}, {:.0})", x, y));
}
if let Some((w, h)) = self.size {
parts.push(format!("Size: ({:.0}, {:.0})", w, h));
}
parts.push(format!("Children: {}", self.children_count));
parts.join(", ")
}
}

View File

@@ -0,0 +1,37 @@
#[cfg(test)]
mod tests {
use crate::{AXElement, MacAxController};
#[test]
fn test_ax_element_to_string() {
let element = AXElement {
role: "button".to_string(),
title: Some("Click Me".to_string()),
value: None,
label: Some("Submit Button".to_string()),
identifier: Some("submitBtn".to_string()),
enabled: true,
focused: false,
position: Some((100.0, 200.0)),
size: Some((80.0, 30.0)),
children_count: 0,
};
let string_repr = element.to_string();
assert!(string_repr.contains("Role: button"));
assert!(string_repr.contains("Title: Click Me"));
assert!(string_repr.contains("Label: Submit Button"));
assert!(string_repr.contains("ID: submitBtn"));
assert!(string_repr.contains("Enabled: true"));
assert!(string_repr.contains("Position: (100, 200)"));
assert!(string_repr.contains("Size: (80, 30)"));
}
#[test]
#[cfg(target_os = "macos")]
fn test_controller_creation() {
// Just test that we can create a controller; on other platforms new()
// intentionally bails, so this test is macOS-only.
// Actual functionality requires macOS and accessibility permissions.
let result = MacAxController::new();
assert!(result.is_ok());
}
}

View File

@@ -0,0 +1,26 @@
use crate::types::TextLocation;
use anyhow::Result;
use async_trait::async_trait;
/// OCR engine trait for text recognition with bounding boxes
#[async_trait]
pub trait OCREngine: Send + Sync {
/// Extract text with locations from an image file
async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>>;
/// Get the name of the OCR engine
fn name(&self) -> &str;
}
// Platform-specific modules
#[cfg(target_os = "macos")]
pub mod vision;
pub mod tesseract;
// Re-export the default OCR engine for the platform
#[cfg(target_os = "macos")]
pub use vision::AppleVisionOCR as DefaultOCR;
#[cfg(not(target_os = "macos"))]
pub use tesseract::TesseractOCR as DefaultOCR;
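// Usage sketch: callers stay platform-agnostic through the DefaultOCR alias.
//
//   let engine = DefaultOCR::new()?;
//   let boxes = engine.extract_text_with_locations("/tmp/shot.png").await?;
//   // each box carries text, pixel bounds, and a 0..1 confidence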

View File

@@ -0,0 +1,84 @@
use super::OCREngine;
use crate::types::TextLocation;
use anyhow::Result;
use async_trait::async_trait;
/// Tesseract OCR engine (fallback/cross-platform)
pub struct TesseractOCR;
impl TesseractOCR {
pub fn new() -> Result<Self> {
// Check if tesseract is available (`which` on Unix, `where` on Windows)
let locator = if cfg!(windows) { "where" } else { "which" };
let tesseract_available = std::process::Command::new(locator)
.arg("tesseract")
.output()
.map(|o| o.status.success())
.unwrap_or(false);
if !tesseract_available {
anyhow::bail!("Tesseract OCR is not installed on your system.\n\n\
To install tesseract:\n  macOS: brew install tesseract\n  \
Linux: sudo apt-get install tesseract-ocr (Ubuntu/Debian)\n  \
sudo yum install tesseract (RHEL/CentOS)\n  \
Windows: Download from https://github.com/UB-Mannheim/tesseract/wiki\n\n\
After installation, restart your terminal and try again.");
}
Ok(Self)
}
}
#[async_trait]
impl OCREngine for TesseractOCR {
async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>> {
// Use tesseract CLI with TSV output to get bounding boxes
let output = std::process::Command::new("tesseract")
.arg(path)
.arg("stdout")
.arg("tsv")
.output()
.map_err(|e| anyhow::anyhow!("Failed to run tesseract: {}", e))?;
if !output.status.success() {
anyhow::bail!("Tesseract failed: {}", String::from_utf8_lossy(&output.stderr));
}
let tsv_text = String::from_utf8_lossy(&output.stdout);
let mut locations = Vec::new();
// Parse TSV output (skip header line)
for (i, line) in tsv_text.lines().enumerate() {
if i == 0 { continue; } // Skip header
let parts: Vec<&str> = line.split('\t').collect();
if parts.len() >= 12 {
// TSV format: level, page_num, block_num, par_num, line_num, word_num,
// left, top, width, height, conf, text
if let (Ok(x), Ok(y), Ok(w), Ok(h), Ok(conf), text) = (
parts[6].parse::<i32>(),
parts[7].parse::<i32>(),
parts[8].parse::<i32>(),
parts[9].parse::<i32>(),
parts[10].parse::<f32>(),
parts[11],
) {
let trimmed = text.trim();
if !trimmed.is_empty() && conf > 0.0 {
locations.push(TextLocation {
text: trimmed.to_string(),
x,
y,
width: w,
height: h,
confidence: conf / 100.0, // Convert from 0-100 to 0-1
});
}
}
}
}
Ok(locations)
}
fn name(&self) -> &str {
"Tesseract OCR"
}
}
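// Illustrative `tesseract image.png stdout tsv` row (a sketch), showing why
// columns 6..=11 are the ones parsed above:
//
//   level page block par line word left top width height conf  text
//   5     1    1     1   1    1    12   8   64    18     96.5  Hello
//
// conf is reported as 0-100 in the TSV, hence the division by 100.0 above.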

View File

@@ -0,0 +1,103 @@
use super::OCREngine;
use crate::types::TextLocation;
use anyhow::{Result, Context};
use async_trait::async_trait;
use std::ffi::{CStr, CString};
use std::os::raw::{c_char, c_float, c_uint};
// FFI bindings to Swift VisionBridge
#[repr(C)]
struct VisionTextBox {
text: *const c_char,
text_len: c_uint,
x: i32,
y: i32,
width: i32,
height: i32,
confidence: c_float,
}
extern "C" {
fn vision_recognize_text(
image_path: *const c_char,
image_path_len: c_uint,
out_boxes: *mut *mut std::ffi::c_void,
out_count: *mut c_uint,
) -> bool;
fn vision_free_boxes(boxes: *mut std::ffi::c_void, count: c_uint);
}
/// Apple Vision Framework OCR engine
pub struct AppleVisionOCR;
impl AppleVisionOCR {
pub fn new() -> Result<Self> {
Ok(Self)
}
}
#[async_trait]
impl OCREngine for AppleVisionOCR {
async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>> {
// Convert path to C string
let c_path = CString::new(path)
.context("Failed to convert path to C string")?;
let mut boxes_ptr: *mut std::ffi::c_void = std::ptr::null_mut();
let mut count: c_uint = 0;
// Call Swift Vision API
let success = unsafe {
vision_recognize_text(
c_path.as_ptr(),
path.len() as c_uint,
&mut boxes_ptr,
&mut count,
)
};
if !success || boxes_ptr.is_null() {
anyhow::bail!("Apple Vision OCR failed");
}
// Convert C array to Rust Vec
let mut locations = Vec::new();
unsafe {
let typed_boxes = boxes_ptr as *const VisionTextBox;
let boxes_slice = std::slice::from_raw_parts(typed_boxes, count as usize);
for box_data in boxes_slice {
// Convert C string to Rust String
let text = if !box_data.text.is_null() {
CStr::from_ptr(box_data.text)
.to_string_lossy()
.into_owned()
} else {
String::new()
};
if !text.is_empty() {
locations.push(TextLocation {
text,
x: box_data.x,
y: box_data.y,
width: box_data.width,
height: box_data.height,
confidence: box_data.confidence,
});
}
}
// Free the C array
vision_free_boxes(boxes_ptr, count);
}
Ok(locations)
}
fn name(&self) -> &str {
"Apple Vision Framework"
}
}
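// Note on the FFI contract (inferred from the signatures above; the Swift
// bridge source is not in this diff): the bridge owns the VisionTextBox array
// and every `text` pointer inside it, so this side copies the strings first
// and then releases everything with a single vision_free_boxes call.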

View File

@@ -0,0 +1,166 @@
use crate::ocr::{tesseract::TesseractOCR, OCREngine};
use crate::{types::*, ComputerController};
use anyhow::Result;
use async_trait::async_trait;
pub struct LinuxController {
    // Placeholder for X11 connection or other state; OCR is delegated to the
    // shared Tesseract engine
    ocr_engine: TesseractOCR,
}
impl LinuxController {
    pub fn new() -> Result<Self> {
        // X11 input/window control is not yet implemented. Tesseract packages:
        // Ubuntu/Debian: tesseract-ocr, RHEL/CentOS: tesseract, Arch: tesseract
        tracing::warn!("Linux computer control not fully implemented");
        Ok(Self {
            ocr_engine: TesseractOCR::new()?,
        })
    }
}
#[async_trait]
impl ComputerController for LinuxController {
    async fn take_screenshot(&self, _path: &str, _region: Option<Rect>, window_id: Option<&str>) -> Result<()> {
        // Enforce that window_id must be provided, matching the macOS behavior
        if window_id.is_none() {
            anyhow::bail!("window_id is required. You must specify which window to capture (e.g., 'Firefox', 'Terminal', 'gedit').");
        }
        anyhow::bail!("Linux implementation not yet available")
    }
    async fn extract_text_from_screen(&self, _region: Rect, _window_id: &str) -> Result<String> {
        anyhow::bail!("Linux implementation not yet available")
    }
    async fn extract_text_from_image(&self, path: &str) -> Result<String> {
        // Concatenate all recognized words, mirroring the macOS implementation
        let locations = self.ocr_engine.extract_text_with_locations(path).await?;
        Ok(locations.iter().map(|l| l.text.as_str()).collect::<Vec<_>>().join(" "))
    }
    async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>> {
        self.ocr_engine.extract_text_with_locations(path).await
    }
    async fn find_text_in_app(&self, _app_name: &str, _search_text: &str) -> Result<Option<TextLocation>> {
        anyhow::bail!("Linux implementation not yet available")
    }
    fn move_mouse(&self, _x: i32, _y: i32) -> Result<()> {
        anyhow::bail!("Linux implementation not yet available")
    }
    fn click_at(&self, _x: i32, _y: i32, _app_name: Option<&str>) -> Result<()> {
        anyhow::bail!("Linux implementation not yet available")
    }
}

View File

@@ -0,0 +1,507 @@
use crate::{ComputerController, types::{Rect, TextLocation}};
use crate::ocr::{OCREngine, DefaultOCR};
use anyhow::{Result, Context};
use async_trait::async_trait;
use std::path::Path;
use core_graphics::window::{kCGWindowListOptionOnScreenOnly, kCGNullWindowID, CGWindowListCopyWindowInfo};
use core_foundation::dictionary::CFDictionary;
use core_foundation::string::CFString;
use core_foundation::base::{TCFType, ToVoid};
use core_foundation::array::CFArray;
pub struct MacOSController {
ocr_engine: Box<dyn OCREngine>,
#[allow(dead_code)]
ocr_name: String,
}
impl MacOSController {
pub fn new() -> Result<Self> {
let ocr = Box::new(DefaultOCR::new()?);
let ocr_name = ocr.name().to_string();
tracing::info!("Initialized macOS controller with OCR engine: {}", ocr_name);
Ok(Self { ocr_engine: ocr, ocr_name })
}
}
#[async_trait]
impl ComputerController for MacOSController {
async fn take_screenshot(&self, path: &str, region: Option<Rect>, window_id: Option<&str>) -> Result<()> {
// Enforce that window_id must be provided
if window_id.is_none() {
return Err(anyhow::anyhow!("window_id is required. You must specify which window to capture (e.g., 'Safari', 'Terminal', 'Google Chrome'). Use list_windows to see available windows."));
}
// Determine the temporary directory for screenshots
let temp_dir = std::env::var("TMPDIR")
.or_else(|_| std::env::var("HOME").map(|h| format!("{}/tmp", h)))
.unwrap_or_else(|_| "/tmp".to_string());
// Ensure temp directory exists
std::fs::create_dir_all(&temp_dir)?;
// If path is relative or doesn't specify a directory, use temp_dir
let final_path = if path.starts_with('/') {
path.to_string()
} else {
format!("{}/{}", temp_dir.trim_end_matches('/'), path)
};
let path_obj = Path::new(&final_path);
if let Some(parent) = path_obj.parent() {
std::fs::create_dir_all(parent)?;
}
let app_name = window_id.unwrap(); // Safe because we checked is_none() above
// Get the window ID for the specified application
let cg_window_id = unsafe {
let window_list = CGWindowListCopyWindowInfo(
kCGWindowListOptionOnScreenOnly,
kCGNullWindowID
);
let array = CFArray::<CFDictionary>::wrap_under_create_rule(window_list);
let count = array.len();
let mut found_window_id: Option<(u32, String)> = None; // (id, owner)
let app_name_lower = app_name.to_lowercase();
for i in 0..count {
let dict = array.get(i).unwrap();
// Get owner name
let owner_key = CFString::from_static_string("kCGWindowOwnerName");
let owner: String = if let Some(value) = dict.find(owner_key.to_void()) {
let s: CFString = TCFType::wrap_under_get_rule(*value as *const _);
s.to_string()
} else {
continue;
};
tracing::debug!("Checking window: owner='{}', looking for '{}'", owner, app_name);
let owner_lower = owner.to_lowercase();
// Normalize by removing spaces for exact matching
let app_name_normalized = app_name_lower.replace(" ", "");
let owner_normalized = owner_lower.replace(" ", "");
// ONLY accept exact matches (case-insensitive, with or without spaces)
// This prevents "Goose" from matching "GooseStudio"
let is_match = owner_lower == app_name_lower || owner_normalized == app_name_normalized;
if is_match {
// Get window ID
let window_id_key = CFString::from_static_string("kCGWindowNumber");
if let Some(value) = dict.find(window_id_key.to_void()) {
let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _);
if let Some(id) = num.to_i64() {
// Get window layer to filter out menu bar windows
let layer_key = CFString::from_static_string("kCGWindowLayer");
let layer: i32 = if let Some(value) = dict.find(layer_key.to_void()) {
let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _);
num.to_i32().unwrap_or(0)
} else {
0
};
// Get window bounds to verify it's a real window
let bounds_key = CFString::from_static_string("kCGWindowBounds");
let has_real_bounds = if let Some(value) = dict.find(bounds_key.to_void()) {
let bounds_dict: CFDictionary = TCFType::wrap_under_get_rule(*value as *const _);
let width_key = CFString::from_static_string("Width");
let height_key = CFString::from_static_string("Height");
if let (Some(w_val), Some(h_val)) = (
bounds_dict.find(width_key.to_void()),
bounds_dict.find(height_key.to_void()),
) {
let w_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*w_val as *const _);
let h_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*h_val as *const _);
let width = w_num.to_f64().unwrap_or(0.0);
let height = h_num.to_f64().unwrap_or(0.0);
// Real windows should be at least 100x100 pixels
width >= 100.0 && height >= 100.0
} else {
false
}
} else {
false
};
// Only accept windows that are:
// 1. At layer 0 (normal windows, not menu bar)
// 2. Have real bounds (width and height >= 100)
if layer == 0 && has_real_bounds {
tracing::info!("Found valid window: ID {} for app '{}' (layer={}, bounds valid)", id, owner, layer);
found_window_id = Some((id as u32, owner.clone()));
break;
} else {
tracing::debug!("Skipping window ID {} for '{}': layer={}, has_real_bounds={}", id, owner, layer, has_real_bounds);
}
}
}
}
}
found_window_id
};
let (cg_window_id, matched_owner) = cg_window_id.ok_or_else(|| {
anyhow::anyhow!("Could not find window for application '{}'. Use list_windows to see available windows.", app_name)
})?;
tracing::info!("Taking screenshot of window ID {} for app '{}'", cg_window_id, matched_owner);
// Use screencapture with the window ID for now
// TODO: Implement direct CGWindowListCreateImage approach with proper image saving
let mut cmd = std::process::Command::new("screencapture");
cmd.arg("-x"); // No sound
cmd.arg("-l");
cmd.arg(cg_window_id.to_string());
if let Some(region) = region {
cmd.arg("-R");
cmd.arg(format!("{},{},{},{}", region.x, region.y, region.width, region.height));
}
cmd.arg(&final_path);
let screenshot_result = cmd.output()?;
if !screenshot_result.status.success() {
let stderr = String::from_utf8_lossy(&screenshot_result.stderr);
return Err(anyhow::anyhow!("screencapture failed for window {}: {}", cg_window_id, stderr));
}
Ok(())
}
async fn extract_text_from_screen(&self, region: Rect, window_id: &str) -> Result<String> {
// Take screenshot of region first
let temp_path = format!("/tmp/g3_ocr_{}.png", uuid::Uuid::new_v4());
self.take_screenshot(&temp_path, Some(region), Some(window_id)).await?;
// Extract text from the screenshot
let result = self.extract_text_from_image(&temp_path).await?;
// Clean up temp file
let _ = std::fs::remove_file(&temp_path);
Ok(result)
}
async fn extract_text_from_image(&self, path: &str) -> Result<String> {
// Extract all text and concatenate
let locations = self.ocr_engine.extract_text_with_locations(path).await?;
Ok(locations.iter().map(|loc| loc.text.as_str()).collect::<Vec<_>>().join(" "))
}
async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>> {
// Use the OCR engine
self.ocr_engine.extract_text_with_locations(path).await
}
async fn find_text_in_app(&self, app_name: &str, search_text: &str) -> Result<Option<TextLocation>> {
// Take screenshot of specific app window
let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
let temp_path = format!("{}/tmp/g3_find_text_{}_{}.png", home, app_name, uuid::Uuid::new_v4());
self.take_screenshot(&temp_path, None, Some(app_name)).await?;
// Get screenshot dimensions before we delete it
let screenshot_dims = get_image_dimensions(&temp_path)?;
// Extract all text with locations
let locations = self.extract_text_with_locations(&temp_path).await?;
// Get window bounds to calculate coordinate transformation
let window_bounds = self.get_window_bounds(app_name)?;
// Clean up temp file
let _ = std::fs::remove_file(&temp_path);
// Find matching text (case-insensitive)
let search_lower = search_text.to_lowercase();
for location in locations {
if location.text.to_lowercase().contains(&search_lower) {
// Transform coordinates from screenshot space to screen space
let transformed = transform_screenshot_to_screen_coords(
location,
window_bounds,
screenshot_dims,
);
return Ok(Some(transformed));
}
}
Ok(None)
}
fn move_mouse(&self, x: i32, y: i32) -> Result<()> {
use core_graphics::event::{
CGEvent, CGEventTapLocation, CGEventType, CGMouseButton,
};
use core_graphics::event_source::{
CGEventSource, CGEventSourceStateID,
};
use core_graphics::geometry::CGPoint;
let source = CGEventSource::new(CGEventSourceStateID::HIDSystemState)
.ok().context("Failed to create event source")?;
let event = CGEvent::new_mouse_event(
source,
CGEventType::MouseMoved,
CGPoint::new(x as f64, y as f64),
CGMouseButton::Left,
).ok().context("Failed to create mouse event")?;
event.post(CGEventTapLocation::HID);
Ok(())
}
fn click_at(&self, x: i32, y: i32, _app_name: Option<&str>) -> Result<()> {
use core_graphics::event::{
CGEvent, CGEventTapLocation, CGEventType, CGMouseButton,
};
use core_graphics::event_source::{
CGEventSource, CGEventSourceStateID,
};
use core_graphics::geometry::CGPoint;
use core_graphics::display::CGDisplay;
// IMPORTANT: Coordinates passed here are in NSScreen/CGWindowListCopyWindowInfo space
// (Y=0 at BOTTOM, increases UPWARD)
// But CGEvent uses a different coordinate system (Y=0 at TOP, increases DOWNWARD)
// We need to convert: CGEvent.y = screenHeight - NSScreen.y
let screen_height = CGDisplay::main().pixels_high() as i32;
let cgevent_x = x;
let cgevent_y = screen_height - y;
tracing::debug!("click_at: NSScreen coords ({}, {}) -> CGEvent coords ({}, {}) [screen_height={}]",
x, y, cgevent_x, cgevent_y, screen_height);
let (global_x, global_y) = (cgevent_x, cgevent_y);
let point = CGPoint::new(global_x as f64, global_y as f64);
let source = CGEventSource::new(CGEventSourceStateID::HIDSystemState)
.ok().context("Failed to create event source")?;
// Move mouse to position first
let move_event = CGEvent::new_mouse_event(
source.clone(),
CGEventType::MouseMoved,
point,
CGMouseButton::Left,
).ok().context("Failed to create mouse move event")?;
move_event.post(CGEventTapLocation::HID);
std::thread::sleep(std::time::Duration::from_millis(100));
// Mouse down
let mouse_down = CGEvent::new_mouse_event(
source.clone(),
CGEventType::LeftMouseDown,
point,
CGMouseButton::Left,
).ok().context("Failed to create mouse down event")?;
mouse_down.post(CGEventTapLocation::HID);
std::thread::sleep(std::time::Duration::from_millis(50));
// Mouse up
let mouse_up = CGEvent::new_mouse_event(
source,
CGEventType::LeftMouseUp,
point,
CGMouseButton::Left,
).ok().context("Failed to create mouse up event")?;
mouse_up.post(CGEventTapLocation::HID);
Ok(())
}
}
impl MacOSController {
/// Get window bounds for an application (helper method)
fn get_window_bounds(&self, app_name: &str) -> Result<(i32, i32, i32, i32)> {
unsafe {
let window_list = CGWindowListCopyWindowInfo(
kCGWindowListOptionOnScreenOnly,
kCGNullWindowID
);
let array = CFArray::<CFDictionary>::wrap_under_create_rule(window_list);
let count = array.len();
let app_name_lower = app_name.to_lowercase();
for i in 0..count {
let dict = array.get(i).unwrap();
// Get owner name
let owner_key = CFString::from_static_string("kCGWindowOwnerName");
let owner: String = if let Some(value) = dict.find(owner_key.to_void()) {
let s: CFString = TCFType::wrap_under_get_rule(*value as *const _);
s.to_string()
} else {
continue;
};
let owner_lower = owner.to_lowercase();
// Normalize by removing spaces for exact matching
let app_name_normalized = app_name_lower.replace(" ", "");
let owner_normalized = owner_lower.replace(" ", "");
// ONLY accept exact matches (case-insensitive, with or without spaces)
// This prevents "Goose" from matching "GooseStudio"
let is_match = owner_lower == app_name_lower || owner_normalized == app_name_normalized;
if is_match {
// Get window layer to filter out menu bar windows
let layer_key = CFString::from_static_string("kCGWindowLayer");
let layer: i32 = if let Some(value) = dict.find(layer_key.to_void()) {
let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _);
num.to_i32().unwrap_or(0)
} else {
0
};
// Skip menu bar windows (layer >= 20)
if layer >= 20 {
tracing::debug!("Skipping window for '{}' at layer {} (menu bar)", owner, layer);
continue;
}
// Get window bounds to verify it's a real window
let bounds_key = CFString::from_static_string("kCGWindowBounds");
if let Some(value) = dict.find(bounds_key.to_void()) {
let bounds_dict: CFDictionary = TCFType::wrap_under_get_rule(*value as *const _);
let x_key = CFString::from_static_string("X");
let y_key = CFString::from_static_string("Y");
let width_key = CFString::from_static_string("Width");
let height_key = CFString::from_static_string("Height");
if let (Some(x_val), Some(y_val), Some(w_val), Some(h_val)) = (
bounds_dict.find(x_key.to_void()),
bounds_dict.find(y_key.to_void()),
bounds_dict.find(width_key.to_void()),
bounds_dict.find(height_key.to_void()),
) {
let x_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*x_val as *const _);
let y_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*y_val as *const _);
let w_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*w_val as *const _);
let h_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*h_val as *const _);
let x: i32 = x_num.to_i64().unwrap_or(0) as i32;
let y: i32 = y_num.to_i64().unwrap_or(0) as i32;
let w: i32 = w_num.to_i64().unwrap_or(0) as i32;
let h: i32 = h_num.to_i64().unwrap_or(0) as i32;
// Only accept windows with real bounds (>= 100x100 pixels)
if w >= 100 && h >= 100 {
tracing::info!("Found valid window bounds for '{}': x={}, y={}, w={}, h={} (layer={})", owner, x, y, w, h, layer);
return Ok((x, y, w, h));
} else {
tracing::debug!("Skipping window for '{}': too small ({}x{})", owner, w, h);
continue;
}
} else {
continue;
}
}
}
}
}
Err(anyhow::anyhow!("Could not find window bounds for '{}'", app_name))
}
}
/// Get image dimensions from a PNG file
fn get_image_dimensions(path: &str) -> Result<(i32, i32)> {
use std::fs::File;
use std::io::Read;
let mut file = File::open(path)?;
let mut buffer = vec![0u8; 24];
file.read_exact(&mut buffer)?;
// PNG signature check
if &buffer[0..8] != b"\x89PNG\r\n\x1a\n" {
anyhow::bail!("Not a valid PNG file");
}
// Read IHDR chunk (width and height are at bytes 16-23)
let width = u32::from_be_bytes([buffer[16], buffer[17], buffer[18], buffer[19]]) as i32;
let height = u32::from_be_bytes([buffer[20], buffer[21], buffer[22], buffer[23]]) as i32;
Ok((width, height))
}
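// Worked example (a sketch): a 2880x1800 Retina screenshot has IHDR bytes
// 16..24 = 00 00 0B 40 00 00 07 08, i.e. width 0x0B40 = 2880 and
// height 0x0708 = 1800.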
/// Transform coordinates from screenshot space to screen space
///
/// The screenshot is taken of a window, and Vision OCR returns coordinates
/// relative to the screenshot image. We need to transform these to actual
/// screen coordinates for clicking.
///
/// On Retina displays, screenshots are taken at 2x resolution, so we need
/// to account for this scaling factor.
fn transform_screenshot_to_screen_coords(
location: TextLocation,
window_bounds: (i32, i32, i32, i32), // (x, y, width, height) in screen space
screenshot_dims: (i32, i32), // (width, height) in pixels
) -> TextLocation {
let (win_x, win_y, win_width, win_height) = window_bounds;
let (screenshot_width, screenshot_height) = screenshot_dims;
// Calculate scale factors
// On Retina displays, screenshot is typically 2x the window size
let scale_x = win_width as f64 / screenshot_width as f64;
let scale_y = win_height as f64 / screenshot_height as f64;
tracing::debug!("Transform: screenshot={}x{}, window={}x{} at ({},{}), scale=({:.2},{:.2})",
screenshot_width, screenshot_height, win_width, win_height, win_x, win_y, scale_x, scale_y);
// Transform coordinates from image space to screen space
// IMPORTANT: macOS screen coordinates have origin at BOTTOM-LEFT (Y increases upward)
// Image coordinates have origin at TOP-LEFT (Y increases downward)
// win_y is the BOTTOM of the window in screen coordinates
// So we need to: (win_y + win_height) to get window TOP, then subtract screenshot_y
let window_top_y = win_y + win_height;
tracing::debug!("[transform] Input location in image space: x={}, y={}, width={}, height={}",
location.x, location.y, location.width, location.height);
tracing::debug!("[transform] Scale factors: scale_x={:.4}, scale_y={:.4}", scale_x, scale_y);
let transformed_x = win_x + (location.x as f64 * scale_x) as i32;
let transformed_y = window_top_y - (location.y as f64 * scale_y) as i32;
let transformed_width = (location.width as f64 * scale_x) as i32;
let transformed_height = (location.height as f64 * scale_y) as i32;
tracing::debug!("[transform] Calculation details:");
tracing::debug!(" - transformed_x = {} + ({} * {:.4}) = {} + {:.2} = {}", win_x, location.x, scale_x, win_x, location.x as f64 * scale_x, transformed_x);
tracing::debug!(" - transformed_width = ({} * {:.4}) = {:.2} -> {}", location.width, scale_x, location.width as f64 * scale_x, transformed_width);
tracing::debug!(" - transformed_height = ({} * {:.4}) = {:.2} -> {}", location.height, scale_y, location.height as f64 * scale_y, transformed_height);
tracing::debug!("Transformed location: screenshot=({},{}) {}x{} -> screen=({},{}) {}x{}",
location.x, location.y, location.width, location.height,
transformed_x, transformed_y, transformed_width, transformed_height);
TextLocation {
text: location.text,
x: transformed_x,
y: transformed_y,
width: transformed_width,
height: transformed_height,
confidence: location.confidence,
}
}
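// Worked example (a sketch): a 1440x900 window whose bottom-left sits at
// (0, 77) on a 2x Retina display produces a 2880x1800 screenshot, so
// scale_x = scale_y = 0.5. OCR text at image (400, 300) then maps to
// x = 0 + 400 * 0.5 = 200 and y = (77 + 900) - 300 * 0.5 = 827 in screen space.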
#[path = "macos_window_matching_test.rs"]
#[cfg(test)]
mod tests;

View File

@@ -0,0 +1,45 @@
#[cfg(test)]
mod window_matching_tests {
/// Test that window name matching handles spaces correctly
///
/// Issue: When a user requests a screenshot of "Goose Studio" but the actual
/// application name is "GooseStudio" (no space), the fuzzy matching should
/// still find the window.
///
/// The fix normalizes both names by removing spaces before comparing.
#[test]
fn test_space_normalization() {
let test_cases = vec![
// (user_input, actual_app_name, should_match)
("Goose Studio", "GooseStudio", true),
("GooseStudio", "Goose Studio", true),
("Visual Studio Code", "VisualStudioCode", true),
("Google Chrome", "Google Chrome", true),
("Safari", "Safari", true),
("iTerm", "iTerm2", true), // fuzzy match
("Code", "Visual Studio Code", true), // fuzzy match
];
for (user_input, app_name, should_match) in test_cases {
let user_lower = user_input.to_lowercase();
let app_lower = app_name.to_lowercase();
let user_normalized = user_lower.replace(" ", "");
let app_normalized = app_lower.replace(" ", "");
let is_exact = app_lower == user_lower || app_normalized == user_normalized;
let is_fuzzy = app_lower.contains(&user_lower)
|| user_lower.contains(&app_lower)
|| app_normalized.contains(&user_normalized)
|| user_normalized.contains(&app_normalized);
let matches = is_exact || is_fuzzy;
assert_eq!(
matches, should_match,
"Expected '{}' vs '{}' to match={}, but got match={}",
user_input, app_name, should_match, matches
);
}
}
}

View File

@@ -0,0 +1,8 @@
#[cfg(target_os = "macos")]
pub mod macos;
#[cfg(target_os = "linux")]
pub mod linux;
#[cfg(target_os = "windows")]
pub mod windows;

View File

@@ -0,0 +1,167 @@
use crate::ocr::{tesseract::TesseractOCR, OCREngine};
use crate::{types::*, ComputerController};
use anyhow::Result;
use async_trait::async_trait;
pub struct WindowsController {
    // Placeholder for Windows-specific state; OCR is delegated to the shared
    // Tesseract engine
    ocr_engine: TesseractOCR,
}
impl WindowsController {
    pub fn new() -> Result<Self> {
        // Win32 input/window control is not yet implemented. On Windows,
        // install Tesseract from https://github.com/UB-Mannheim/tesseract/wiki
        // and make sure it is on PATH.
        tracing::warn!("Windows computer control not fully implemented");
        Ok(Self {
            ocr_engine: TesseractOCR::new()?,
        })
    }
}
#[async_trait]
impl ComputerController for WindowsController {
    async fn take_screenshot(&self, _path: &str, _region: Option<Rect>, window_id: Option<&str>) -> Result<()> {
        // Enforce that window_id must be provided, matching the macOS behavior
        if window_id.is_none() {
            anyhow::bail!("window_id is required. You must specify which window to capture (e.g., 'Chrome', 'Terminal', 'Notepad').");
        }
        anyhow::bail!("Windows implementation not yet available")
    }
    async fn extract_text_from_screen(&self, _region: Rect, _window_id: &str) -> Result<String> {
        anyhow::bail!("Windows implementation not yet available")
    }
    async fn extract_text_from_image(&self, path: &str) -> Result<String> {
        // Concatenate all recognized words, mirroring the macOS implementation
        let locations = self.ocr_engine.extract_text_with_locations(path).await?;
        Ok(locations.iter().map(|l| l.text.as_str()).collect::<Vec<_>>().join(" "))
    }
    async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>> {
        self.ocr_engine.extract_text_with_locations(path).await
    }
    async fn find_text_in_app(&self, _app_name: &str, _search_text: &str) -> Result<Option<TextLocation>> {
        anyhow::bail!("Windows implementation not yet available")
    }
    fn move_mouse(&self, _x: i32, _y: i32) -> Result<()> {
        anyhow::bail!("Windows implementation not yet available")
    }
    fn click_at(&self, _x: i32, _y: i32, _app_name: Option<&str>) -> Result<()> {
        anyhow::bail!("Windows implementation not yet available")
    }
}

View File

@@ -0,0 +1,19 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Rect {
pub x: i32,
pub y: i32,
pub width: i32,
pub height: i32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextLocation {
pub text: String,
pub x: i32,
pub y: i32,
pub width: i32,
pub height: i32,
pub confidence: f32,
}

View File

@@ -0,0 +1,111 @@
pub mod safari;
use anyhow::Result;
use async_trait::async_trait;
use serde_json::Value;
/// WebDriver controller for browser automation
#[async_trait]
pub trait WebDriverController: Send + Sync {
/// Navigate to a URL
async fn navigate(&mut self, url: &str) -> Result<()>;
/// Get the current URL
async fn current_url(&self) -> Result<String>;
/// Get the page title
async fn title(&self) -> Result<String>;
/// Find an element by CSS selector
async fn find_element(&mut self, selector: &str) -> Result<WebElement>;
/// Find multiple elements by CSS selector
async fn find_elements(&mut self, selector: &str) -> Result<Vec<WebElement>>;
/// Execute JavaScript in the browser
async fn execute_script(&mut self, script: &str, args: Vec<Value>) -> Result<Value>;
/// Get the page source (HTML)
async fn page_source(&self) -> Result<String>;
/// Take a screenshot and save to path
async fn screenshot(&mut self, path: &str) -> Result<()>;
/// Close the current window/tab
async fn close(&mut self) -> Result<()>;
/// Quit the browser session
async fn quit(self) -> Result<()>;
}
/// Represents a web element in the DOM
pub struct WebElement {
pub(crate) inner: fantoccini::elements::Element,
}
impl WebElement {
/// Click the element
pub async fn click(&mut self) -> Result<()> {
self.inner.click().await?;
Ok(())
}
/// Send keys/text to the element
pub async fn send_keys(&mut self, text: &str) -> Result<()> {
self.inner.send_keys(text).await?;
Ok(())
}
/// Clear the element's content (for input fields)
pub async fn clear(&mut self) -> Result<()> {
self.inner.clear().await?;
Ok(())
}
/// Get the element's text content
pub async fn text(&self) -> Result<String> {
Ok(self.inner.text().await?)
}
/// Get an attribute value
pub async fn attr(&self, name: &str) -> Result<Option<String>> {
Ok(self.inner.attr(name).await?)
}
/// Get a property value
pub async fn prop(&self, name: &str) -> Result<Option<String>> {
Ok(self.inner.prop(name).await?)
}
/// Get the element's HTML
pub async fn html(&self, inner: bool) -> Result<String> {
Ok(self.inner.html(inner).await?)
}
/// Check if element is displayed
pub async fn is_displayed(&self) -> Result<bool> {
Ok(self.inner.is_displayed().await?)
}
/// Check if element is enabled
pub async fn is_enabled(&self) -> Result<bool> {
Ok(self.inner.is_enabled().await?)
}
/// Check if element is selected (for checkboxes/radio buttons)
pub async fn is_selected(&self) -> Result<bool> {
Ok(self.inner.is_selected().await?)
}
/// Find a child element by CSS selector
pub async fn find_element(&mut self, selector: &str) -> Result<WebElement> {
let elem = self.inner.find(fantoccini::Locator::Css(selector)).await?;
Ok(WebElement { inner: elem })
}
/// Find multiple child elements by CSS selector
pub async fn find_elements(&mut self, selector: &str) -> Result<Vec<WebElement>> {
let elems = self.inner.find_all(fantoccini::Locator::Css(selector)).await?;
Ok(elems.into_iter().map(|inner| WebElement { inner }).collect())
}
}
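A brief usage sketch of these wrappers (selectors and values are placeholders; because `quit(self)` takes the controller by value, the trait is not object-safe, so the sketch is generic over the controller):

async fn fill_login<D: WebDriverController>(driver: &mut D) -> anyhow::Result<()> {
    // Type into a field, then submit the form.
    let mut user = driver.find_element("input#username").await?;
    user.clear().await?;
    user.send_keys("demo").await?;
    let mut submit = driver.find_element("button[type=submit]").await?;
    submit.click().await?;
    Ok(())
}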

View File

@@ -0,0 +1,212 @@
use super::{WebDriverController, WebElement};
use anyhow::{Context, Result};
use async_trait::async_trait;
use fantoccini::{Client, ClientBuilder};
use serde_json::Value;
use std::time::Duration;
/// SafariDriver WebDriver controller
pub struct SafariDriver {
client: Client,
}
impl SafariDriver {
/// Create a new SafariDriver instance
///
/// This will connect to SafariDriver running on the default port (4444).
/// Make sure to enable "Allow Remote Automation" in Safari's Develop menu first.
///
/// You can authorize and start SafariDriver manually with:
/// ```bash
/// /usr/bin/safaridriver --enable   # one-time authorization
/// /usr/bin/safaridriver -p 4444    # listen on the default port
/// ```
pub async fn new() -> Result<Self> {
Self::with_port(4444).await
}
/// Create a new SafariDriver instance with a custom port
pub async fn with_port(port: u16) -> Result<Self> {
let url = format!("http://localhost:{}", port);
let mut caps = serde_json::Map::new();
caps.insert("browserName".to_string(), Value::String("safari".to_string()));
let client = ClientBuilder::native()
.capabilities(caps)
.connect(&url)
.await
.context("Failed to connect to SafariDriver. Make sure SafariDriver is running and 'Allow Remote Automation' is enabled in Safari's Develop menu.")?;
Ok(Self { client })
}
/// Go back in browser history
pub async fn back(&mut self) -> Result<()> {
self.client.back().await?;
Ok(())
}
/// Go forward in browser history
pub async fn forward(&mut self) -> Result<()> {
self.client.forward().await?;
Ok(())
}
/// Refresh the current page
pub async fn refresh(&mut self) -> Result<()> {
self.client.refresh().await?;
Ok(())
}
/// Get all window handles
pub async fn window_handles(&mut self) -> Result<Vec<String>> {
let handles = self.client.windows().await?;
Ok(handles.into_iter()
.map(|h| h.into())
.collect())
}
/// Switch to a window by handle
pub async fn switch_to_window(&mut self, handle: &str) -> Result<()> {
let window_handle: fantoccini::wd::WindowHandle = handle.to_string().try_into()?;
self.client.switch_to_window(window_handle).await?;
Ok(())
}
/// Get the current window handle
pub async fn current_window_handle(&mut self) -> Result<String> {
Ok(self.client.window().await?.into())
}
/// Close the current window
pub async fn close_window(&mut self) -> Result<()> {
self.client.close_window().await?;
Ok(())
}
/// Create a new window/tab
pub async fn new_window(&mut self, is_tab: bool) -> Result<String> {
let response = self.client.new_window(is_tab).await?;
Ok(response.handle.into())
}
/// Get cookies
pub async fn get_cookies(&mut self) -> Result<Vec<fantoccini::cookies::Cookie<'static>>> {
Ok(self.client.get_all_cookies().await?)
}
/// Add a cookie
pub async fn add_cookie(&mut self, cookie: fantoccini::cookies::Cookie<'static>) -> Result<()> {
self.client.add_cookie(cookie).await?;
Ok(())
}
/// Delete all cookies
pub async fn delete_all_cookies(&mut self) -> Result<()> {
self.client.delete_all_cookies().await?;
Ok(())
}
/// Wait for an element to appear (with timeout)
pub async fn wait_for_element(&mut self, selector: &str, timeout: Duration) -> Result<WebElement> {
let start = std::time::Instant::now();
let poll_interval = Duration::from_millis(100);
loop {
if let Ok(elem) = self.find_element(selector).await {
return Ok(elem);
}
if start.elapsed() >= timeout {
anyhow::bail!("Timeout waiting for element: {}", selector);
}
tokio::time::sleep(poll_interval).await;
}
}
/// Wait for an element to be visible (with timeout)
pub async fn wait_for_visible(&mut self, selector: &str, timeout: Duration) -> Result<WebElement> {
let start = std::time::Instant::now();
let poll_interval = Duration::from_millis(100);
loop {
if let Ok(elem) = self.find_element(selector).await {
if elem.is_displayed().await.unwrap_or(false) {
return Ok(elem);
}
}
if start.elapsed() >= timeout {
anyhow::bail!("Timeout waiting for element to be visible: {}", selector);
}
tokio::time::sleep(poll_interval).await;
}
}
}
#[async_trait]
impl WebDriverController for SafariDriver {
async fn navigate(&mut self, url: &str) -> Result<()> {
self.client.goto(url).await?;
Ok(())
}
async fn current_url(&self) -> Result<String> {
Ok(self.client.current_url().await?.to_string())
}
async fn title(&self) -> Result<String> {
Ok(self.client.title().await?)
}
async fn find_element(&mut self, selector: &str) -> Result<WebElement> {
let elem = self.client.find(fantoccini::Locator::Css(selector)).await
.context(format!("Failed to find element with selector: {}", selector))?;
Ok(WebElement { inner: elem })
}
async fn find_elements(&mut self, selector: &str) -> Result<Vec<WebElement>> {
let elems = self.client.find_all(fantoccini::Locator::Css(selector)).await?;
Ok(elems.into_iter().map(|inner| WebElement { inner }).collect())
}
async fn execute_script(&mut self, script: &str, args: Vec<Value>) -> Result<Value> {
Ok(self.client.execute(script, args).await?)
}
async fn page_source(&self) -> Result<String> {
Ok(self.client.source().await?)
}
async fn screenshot(&mut self, path: &str) -> Result<()> {
let screenshot_data = self.client.screenshot().await?;
// Expand tilde in path
let expanded_path = shellexpand::tilde(path);
let path_str = expanded_path.as_ref();
// Create parent directories if needed
if let Some(parent) = std::path::Path::new(path_str).parent() {
std::fs::create_dir_all(parent)
.context("Failed to create parent directories for screenshot")?;
}
std::fs::write(path_str, screenshot_data)
.context("Failed to write screenshot to file")?;
Ok(())
}
async fn close(&mut self) -> Result<()> {
self.client.close_window().await?;
Ok(())
}
async fn quit(mut self) -> Result<()> {
self.client.close().await?;
Ok(())
}
}
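For orientation, a minimal end-to-end sketch, assuming `safaridriver` is already listening on port 4444 (the URL and output path are placeholders; `Duration` and the `WebDriverController` trait are already in scope in this file):

async fn demo() -> anyhow::Result<()> {
    let mut driver = SafariDriver::new().await?;
    driver.navigate("https://example.com").await?;
    // Poll until the heading renders, then read it.
    let heading = driver.wait_for_visible("h1", Duration::from_secs(10)).await?;
    println!("{} / {}", driver.title().await?, heading.text().await?);
    driver.screenshot("~/Desktop/example.png").await?;
    driver.quit().await // consumes the session
}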

View File

@@ -0,0 +1,17 @@
use g3_computer_control::*;
#[tokio::test]
async fn test_screenshot() {
let controller = create_controller().expect("Failed to create controller");
// Take screenshot
let path = "/tmp/test_screenshot.png";
let result = controller.take_screenshot(path, None, None).await;
assert!(result.is_ok(), "Failed to take screenshot: {:?}", result.err());
// Verify file exists
assert!(std::path::Path::new(path).exists(), "Screenshot file was not created");
// Clean up
let _ = std::fs::remove_file(path);
}

View File

@@ -0,0 +1,24 @@
// swift-tools-version:5.9
import PackageDescription
let package = Package(
name: "VisionBridge",
platforms: [
.macOS(.v11)
],
products: [
.library(
name: "VisionBridge",
type: .dynamic,
targets: ["VisionBridge"]
),
],
targets: [
.target(
name: "VisionBridge",
dependencies: [],
path: "Sources/VisionBridge",
publicHeadersPath: "."
),
]
)

View File

@@ -0,0 +1,39 @@
#ifndef VisionBridge_h
#define VisionBridge_h
#include <stdint.h>
#include <stdbool.h>
#ifdef __cplusplus
extern "C" {
#endif
// Text box structure for FFI
typedef struct {
const char* text;
uint32_t text_len;
int32_t x;
int32_t y;
int32_t width;
int32_t height;
float confidence;
} VisionTextBox;
// Recognize text in an image and return bounding boxes
// Returns true on success, false on failure
// Caller must free the returned boxes using vision_free_boxes
bool vision_recognize_text(
const char* image_path,
uint32_t image_path_len,
VisionTextBox** out_boxes,
uint32_t* out_count
);
// Free memory allocated by vision_recognize_text
void vision_free_boxes(VisionTextBox* boxes, uint32_t count);
#ifdef __cplusplus
}
#endif
#endif /* VisionBridge_h */
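For reference, a sketch of how the Rust side might declare these symbols; the actual binding code is not shown in this compare view, but any `#[repr(C)]` mirror must match the header field for field:

use std::os::raw::c_char;

#[repr(C)]
pub struct VisionTextBox {
    pub text: *const c_char,
    pub text_len: u32,
    pub x: i32,
    pub y: i32,
    pub width: i32,
    pub height: i32,
    pub confidence: f32,
}

extern "C" {
    pub fn vision_recognize_text(
        image_path: *const c_char,
        image_path_len: u32,
        out_boxes: *mut *mut VisionTextBox,
        out_count: *mut u32,
    ) -> bool;
    pub fn vision_free_boxes(boxes: *mut VisionTextBox, count: u32);
}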

View File

@@ -0,0 +1,145 @@
import Foundation
import Vision
import AppKit
import CoreGraphics
// MARK: - C Bridge Functions
@_cdecl("vision_recognize_text")
public func vision_recognize_text(
_ imagePath: UnsafePointer<CChar>,
_ imagePathLen: UInt32,
_ outBoxes: UnsafeMutablePointer<UnsafeMutableRawPointer?>,
_ outCount: UnsafeMutablePointer<UInt32>
) -> Bool {
// Convert C string to Swift String
guard let pathData = Data(bytes: imagePath, count: Int(imagePathLen)).withUnsafeBytes({
String(bytes: $0, encoding: .utf8)
}) else {
return false
}
let path = pathData.trimmingCharacters(in: .whitespaces)
// Load image
guard let image = NSImage(contentsOfFile: path),
let cgImage = image.cgImage(forProposedRect: nil, context: nil, hints: nil) else {
return false
}
// Perform OCR
var textBoxes: [CTextBox] = []
let semaphore = DispatchSemaphore(value: 0)
var success = false
let request = VNRecognizeTextRequest { request, error in
defer { semaphore.signal() }
if let error = error {
print("Vision OCR error: \(error.localizedDescription)")
return
}
guard let observations = request.results as? [VNRecognizedTextObservation] else {
return
}
let imageSize = CGSize(width: cgImage.width, height: cgImage.height)
for observation in observations {
guard let candidate = observation.topCandidates(1).first else { continue }
let text = candidate.string
let boundingBox = observation.boundingBox
// Convert normalized coordinates (bottom-left origin) to pixel coordinates (top-left origin)
let x = Int32(boundingBox.origin.x * imageSize.width)
let y = Int32((1.0 - boundingBox.origin.y - boundingBox.height) * imageSize.height)
let width = Int32(boundingBox.width * imageSize.width)
let height = Int32(boundingBox.height * imageSize.height)
// Allocate C string for text
let cString = strdup(text)
textBoxes.append(CTextBox(
text: cString,
text_len: UInt32(text.utf8.count),
x: x,
y: y,
width: width,
height: height,
confidence: observation.confidence
))
}
success = true
}
// Configure request for best accuracy
request.recognitionLevel = .accurate
request.usesLanguageCorrection = true
request.recognitionLanguages = ["en-US"]
// Perform request
let handler = VNImageRequestHandler(cgImage: cgImage, options: [:])
do {
try handler.perform([request])
} catch {
print("Vision request failed: \(error.localizedDescription)")
return false
}
// Wait for completion
semaphore.wait()
if !success {
return false
}
// Allocate array for results
let boxesPtr = UnsafeMutablePointer<CTextBox>.allocate(capacity: textBoxes.count)
for (index, box) in textBoxes.enumerated() {
boxesPtr[index] = box
}
outBoxes.pointee = UnsafeMutableRawPointer(boxesPtr)
outCount.pointee = UInt32(textBoxes.count)
return true
}
@_cdecl("vision_free_boxes")
public func vision_free_boxes(
_ boxes: UnsafeMutableRawPointer,
_ count: UInt32
) {
let typedBoxes = boxes.assumingMemoryBound(to: CTextBox.self)
for i in 0..<Int(count) {
if let text = typedBoxes[i].text {
free(UnsafeMutableRawPointer(mutating: text))
}
}
typedBoxes.deallocate()
}
// MARK: - C-Compatible Structure
public struct CTextBox {
public let text: UnsafePointer<CChar>?
public let text_len: UInt32
public let x: Int32
public let y: Int32
public let width: Int32
public let height: Int32
public let confidence: Float
public init(text: UnsafePointer<CChar>?, text_len: UInt32, x: Int32, y: Int32, width: Int32, height: Int32, confidence: Float) {
self.text = text
self.text_len = text_len
self.x = x
self.y = y
self.width = width
self.height = height
self.confidence = confidence
}
}

View File

@@ -12,3 +12,6 @@ thiserror = { workspace = true }
toml = "0.8"
shellexpand = "3.0"
dirs = "5.0"
[dev-dependencies]
tempfile = "3.8"

View File

@@ -6,14 +6,21 @@ use std::path::Path;
pub struct Config {
pub providers: ProvidersConfig,
pub agent: AgentConfig,
pub computer_control: ComputerControlConfig,
pub webdriver: WebDriverConfig,
pub macax: MacAxConfig,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProvidersConfig {
pub openai: Option<OpenAIConfig>,
pub anthropic: Option<AnthropicConfig>,
pub databricks: Option<DatabricksConfig>,
pub embedded: Option<EmbeddedConfig>,
pub ollama: Option<OllamaConfig>,
pub default_provider: String,
pub coach: Option<String>, // Provider to use for coach in autonomous mode
pub player: Option<String>, // Provider to use for player in autonomous mode
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -33,6 +40,16 @@ pub struct AnthropicConfig {
pub temperature: Option<f32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabricksConfig {
pub host: String,
pub token: Option<String>, // Optional - will use OAuth if not provided
pub model: String,
pub max_tokens: Option<u32>,
pub temperature: Option<f32>,
pub use_oauth: Option<bool>, // Default to true if token not provided
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddedConfig {
pub model_path: String,
@@ -44,11 +61,65 @@ pub struct EmbeddedConfig {
pub threads: Option<u32>, // Number of CPU threads to use
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaConfig {
pub model: String,
pub base_url: Option<String>, // Default: http://localhost:11434
pub max_tokens: Option<u32>,
pub temperature: Option<f32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
pub max_context_length: usize,
pub enable_streaming: bool,
pub timeout_seconds: u64,
pub auto_compact: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComputerControlConfig {
pub enabled: bool,
pub require_confirmation: bool,
pub max_actions_per_second: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebDriverConfig {
pub enabled: bool,
pub safari_port: u16,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MacAxConfig {
pub enabled: bool,
}
impl Default for MacAxConfig {
fn default() -> Self {
Self {
enabled: false,
}
}
}
impl Default for WebDriverConfig {
fn default() -> Self {
Self {
enabled: false,
safari_port: 4444,
}
}
}
impl Default for ComputerControlConfig {
fn default() -> Self {
Self {
enabled: false, // Disabled by default for safety
require_confirmation: true,
max_actions_per_second: 5,
}
}
}
impl Default for Config {
@@ -57,14 +128,29 @@ impl Default for Config {
providers: ProvidersConfig {
openai: None,
anthropic: None,
databricks: Some(DatabricksConfig {
host: "https://your-workspace.cloud.databricks.com".to_string(),
token: None, // Will use OAuth by default
model: "databricks-claude-sonnet-4".to_string(),
max_tokens: Some(4096),
temperature: Some(0.1),
use_oauth: Some(true),
}),
embedded: None,
default_provider: "anthropic".to_string(),
ollama: None,
default_provider: "databricks".to_string(),
coach: None, // Will use default_provider if not specified
player: None, // Will use default_provider if not specified
},
agent: AgentConfig {
max_context_length: 8192,
enable_streaming: true,
timeout_seconds: 60,
auto_compact: true,
},
computer_control: ComputerControlConfig::default(),
webdriver: WebDriverConfig::default(),
macax: MacAxConfig::default(),
}
}
}
@@ -88,9 +174,9 @@ impl Config {
})
};
// If no config exists, create and save a default Qwen config
// If no config exists, create and save a default Databricks config
if !config_exists {
let qwen_config = Self::default_qwen_config();
let databricks_config = Self::default();
// Save to default location
let config_dir = dirs::home_dir()
@@ -105,13 +191,13 @@ impl Config {
std::fs::create_dir_all(&config_dir).ok();
let config_file = config_dir.join("config.toml");
if let Err(e) = qwen_config.save(config_file.to_str().unwrap()) {
if let Err(e) = databricks_config.save(config_file.to_str().unwrap()) {
eprintln!("Warning: Could not save default config: {}", e);
} else {
println!("Created default Qwen configuration at: {}", config_file.display());
println!("Created default Databricks configuration at: {}", config_file.display());
}
return Ok(qwen_config);
return Ok(databricks_config);
}
// Existing config loading logic
@@ -151,12 +237,14 @@ impl Config {
let config = settings.build()?.try_deserialize()?;
Ok(config)
}
#[allow(dead_code)]
fn default_qwen_config() -> Self {
Self {
providers: ProvidersConfig {
openai: None,
anthropic: None,
databricks: None,
embedded: Some(EmbeddedConfig {
model_path: "~/.cache/g3/models/qwen2.5-7b-instruct-q3_k_m.gguf".to_string(),
model_type: "qwen".to_string(),
@@ -166,13 +254,20 @@ impl Config {
gpu_layers: Some(32),
threads: Some(8),
}),
ollama: None,
default_provider: "embedded".to_string(),
coach: None, // Will use default_provider if not specified
player: None, // Will use default_provider if not specified
},
agent: AgentConfig {
max_context_length: 8192,
enable_streaming: true,
timeout_seconds: 60,
auto_compact: true,
},
computer_control: ComputerControlConfig::default(),
webdriver: WebDriverConfig::default(),
macax: MacAxConfig::default(),
}
}
@@ -181,4 +276,127 @@ impl Config {
std::fs::write(path, toml_string)?;
Ok(())
}
pub fn load_with_overrides(
config_path: Option<&str>,
provider_override: Option<String>,
model_override: Option<String>,
) -> Result<Self> {
// Load the base configuration
let mut config = Self::load(config_path)?;
// Apply provider override
if let Some(provider) = provider_override {
config.providers.default_provider = provider;
}
// Apply model override to the active provider
if let Some(model) = model_override {
match config.providers.default_provider.as_str() {
"anthropic" => {
if let Some(ref mut anthropic) = config.providers.anthropic {
anthropic.model = model;
} else {
return Err(anyhow::anyhow!(
"Provider 'anthropic' is not configured. Please add anthropic configuration to your config file."
));
}
}
"databricks" => {
if let Some(ref mut databricks) = config.providers.databricks {
databricks.model = model;
} else {
return Err(anyhow::anyhow!(
"Provider 'databricks' is not configured. Please add databricks configuration to your config file."
));
}
}
"embedded" => {
if let Some(ref mut embedded) = config.providers.embedded {
embedded.model_path = model;
} else {
return Err(anyhow::anyhow!(
"Provider 'embedded' is not configured. Please add embedded configuration to your config file."
));
}
}
"openai" => {
if let Some(ref mut openai) = config.providers.openai {
openai.model = model;
} else {
return Err(anyhow::anyhow!(
"Provider 'openai' is not configured. Please add openai configuration to your config file."
));
}
}
_ => return Err(anyhow::anyhow!("Unknown provider: {}",
config.providers.default_provider)),
}
}
Ok(config)
}
/// Get the provider to use for coach mode in autonomous execution
pub fn get_coach_provider(&self) -> &str {
self.providers.coach
.as_deref()
.unwrap_or(&self.providers.default_provider)
}
/// Get the provider to use for player mode in autonomous execution
pub fn get_player_provider(&self) -> &str {
self.providers.player
.as_deref()
.unwrap_or(&self.providers.default_provider)
}
/// Create a copy of the config with a different default provider
pub fn with_provider_override(&self, provider: &str) -> Result<Self> {
// Validate that the provider is configured
match provider {
"anthropic" if self.providers.anthropic.is_none() => {
return Err(anyhow::anyhow!(
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
provider, provider
));
}
"databricks" if self.providers.databricks.is_none() => {
return Err(anyhow::anyhow!(
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
provider, provider
));
}
"embedded" if self.providers.embedded.is_none() => {
return Err(anyhow::anyhow!(
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
provider, provider
));
}
"openai" if self.providers.openai.is_none() => {
return Err(anyhow::anyhow!(
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
provider, provider
));
}
"ollama" if self.providers.ollama.is_none() => {
return Err(anyhow::anyhow!(
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
provider, provider
));
}
_ => {} // Provider is configured or unknown (will be caught later)
}
let mut config = self.clone();
config.providers.default_provider = provider.to_string();
Ok(config)
}
/// Create a copy of the config for coach mode in autonomous execution
pub fn for_coach(&self) -> Result<Self> {
self.with_provider_override(self.get_coach_provider())
}
/// Create a copy of the config for player mode in autonomous execution
pub fn for_player(&self) -> Result<Self> {
self.with_provider_override(self.get_player_provider())
}
}
#[cfg(test)]
mod tests;
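A hedged usage sketch of the override and coach/player helpers above (the provider name is a placeholder):

fn demo() -> anyhow::Result<()> {
    // CLI-style overrides: switch provider and model without editing config.toml.
    let config = Config::load_with_overrides(None, Some("databricks".to_string()), None)?;

    // Autonomous mode can split coach and player across providers; both fall
    // back to default_provider when unset.
    let coach = config.for_coach()?;
    let player = config.for_player()?;
    println!(
        "coach={} player={}",
        coach.providers.default_provider, player.providers.default_provider
    );
    Ok(())
}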

View File

@@ -0,0 +1,131 @@
#[cfg(test)]
mod tests {
use crate::Config;
use std::fs;
use tempfile::TempDir;
#[test]
fn test_coach_player_providers() {
// Create a temporary directory for the test config
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("test_config.toml");
// Write a test configuration with coach and player providers
let config_content = r#"
[providers]
default_provider = "databricks"
coach = "anthropic"
player = "embedded"
[providers.databricks]
host = "https://test.databricks.com"
token = "test-token"
model = "test-model"
[providers.anthropic]
api_key = "test-key"
model = "claude-3"
[providers.embedded]
model_path = "test.gguf"
model_type = "llama"
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
"#;
fs::write(&config_path, config_content).unwrap();
// Load the configuration
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
// Test that the providers are correctly identified
assert_eq!(config.providers.default_provider, "databricks");
assert_eq!(config.get_coach_provider(), "anthropic");
assert_eq!(config.get_player_provider(), "embedded");
// Test creating coach config
let coach_config = config.for_coach().unwrap();
assert_eq!(coach_config.providers.default_provider, "anthropic");
// Test creating player config
let player_config = config.for_player().unwrap();
assert_eq!(player_config.providers.default_provider, "embedded");
}
#[test]
fn test_coach_player_fallback_to_default() {
// Create a temporary directory for the test config
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("test_config.toml");
// Write a test configuration WITHOUT coach and player providers
let config_content = r#"
[providers]
default_provider = "databricks"
[providers.databricks]
host = "https://test.databricks.com"
token = "test-token"
model = "test-model"
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
"#;
fs::write(&config_path, config_content).unwrap();
// Load the configuration
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
// Test that coach and player fall back to default provider
assert_eq!(config.get_coach_provider(), "databricks");
assert_eq!(config.get_player_provider(), "databricks");
// Test creating coach config (should use default)
let coach_config = config.for_coach().unwrap();
assert_eq!(coach_config.providers.default_provider, "databricks");
// Test creating player config (should use default)
let player_config = config.for_player().unwrap();
assert_eq!(player_config.providers.default_provider, "databricks");
}
#[test]
fn test_invalid_provider_error() {
// Create a temporary directory for the test config
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("test_config.toml");
// Write a test configuration with an unconfigured provider
let config_content = r#"
[providers]
default_provider = "databricks"
coach = "openai" # OpenAI is not configured
[providers.databricks]
host = "https://test.databricks.com"
token = "test-token"
model = "test-model"
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
"#;
fs::write(&config_path, config_content).unwrap();
// Load the configuration
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
// Test that trying to create a coach config with unconfigured provider fails
let result = config.for_coach();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("not configured"));
}
}

View File

@@ -8,6 +8,7 @@ description = "Core engine for G3 AI coding agent"
g3-providers = { path = "../g3-providers" }
g3-config = { path = "../g3-config" }
g3-execution = { path = "../g3-execution" }
g3-computer-control = { path = "../g3-computer-control" }
tokio = { workspace = true }
reqwest = { workspace = true }
anyhow = { workspace = true }
@@ -18,7 +19,10 @@ serde_json = { workspace = true }
uuid = { workspace = true }
async-trait = "0.1"
tokio-stream = "0.1"
llama_cpp = { version = "0.3.2", features = ["metal"] }
shellexpand = "3.1"
tokio-util = "0.7"
futures-util = "0.3"
chrono = { version = "0.4", features = ["serde"] }
rand = "0.8"
regex = "1.0"
shellexpand = "3.1"
serde_yaml = "0.9"

View File

@@ -0,0 +1,787 @@
//! Code search functionality using ast-grep for syntax-aware semantic searches
use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::process::Stdio;
use std::time::{Duration, Instant};
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command;
use tokio::sync::Semaphore;
use tracing::{debug, error, info, warn};
/// Maximum number of searches allowed per request
const MAX_SEARCHES: usize = 20;
/// Default timeout for individual searches in seconds
const DEFAULT_TIMEOUT_SECS: u64 = 60;
/// Default maximum concurrency
const DEFAULT_MAX_CONCURRENCY: usize = 4;
/// Default maximum matches per search
const DEFAULT_MAX_MATCHES: usize = 500;
/// Search specification for a single ast-grep search
#[derive(Debug, Clone, Deserialize)]
pub struct SearchSpec {
pub name: String,
pub mode: SearchMode,
// Pattern mode fields
pub pattern: Option<String>,
pub language: Option<String>,
// YAML mode fields
pub rule_yaml: Option<String>,
// Common fields
pub paths: Option<Vec<String>>,
pub globs: Option<Vec<String>>,
pub json_style: Option<JsonStyle>,
pub context: Option<u32>,
pub threads: Option<u32>,
pub include_metadata: Option<bool>,
pub no_ignore: Option<Vec<NoIgnoreType>>,
pub severity: Option<HashMap<String, SeverityLevel>>,
pub timeout_secs: Option<u64>,
}
/// Search mode: pattern or yaml
#[derive(Debug, Clone, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SearchMode {
Pattern,
Yaml,
}
/// JSON output style
#[derive(Debug, Clone, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum JsonStyle {
Pretty,
Stream,
Compact,
}
impl Default for JsonStyle {
fn default() -> Self {
JsonStyle::Stream
}
}
/// No-ignore types
#[derive(Debug, Clone, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum NoIgnoreType {
Hidden,
Dot,
Exclude,
Global,
Parent,
Vcs,
}
/// Severity levels for YAML rules
#[derive(Debug, Clone, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SeverityLevel {
Error,
Warning,
Info,
Hint,
Off,
}
/// Request structure for code search
#[derive(Debug, Deserialize)]
pub struct CodeSearchRequest {
pub searches: Vec<SearchSpec>,
pub max_concurrency: Option<usize>,
pub max_matches_per_search: Option<usize>,
}
/// Result of a single search
#[derive(Debug, Serialize)]
pub struct SearchResult {
pub name: String,
pub mode: String,
pub status: String,
pub cmd: Vec<String>,
pub match_count: Option<usize>,
pub truncated: Option<bool>,
pub matches: Option<Vec<Value>>,
pub stderr: Option<String>,
pub exit_code: Option<i32>,
pub duration_ms: u64,
}
/// Summary of all searches
#[derive(Debug, Serialize)]
pub struct SearchSummary {
pub completed: usize,
pub total: usize,
pub total_matches: usize,
pub duration_ms: u64,
}
/// Complete response structure
#[derive(Debug, Serialize)]
pub struct CodeSearchResponse {
pub summary: SearchSummary,
pub searches: Vec<SearchResult>,
}
/// YAML rule structure for validation
#[derive(Debug, Deserialize)]
struct YamlRule {
pub id: String,
pub language: String,
pub rule: Value,
}
/// Execute a batch of code searches using ast-grep
pub async fn execute_code_search(request: CodeSearchRequest) -> Result<CodeSearchResponse> {
let start_time = Instant::now();
// Validate request
if request.searches.is_empty() {
return Err(anyhow!("No searches specified"));
}
if request.searches.len() > MAX_SEARCHES {
return Err(anyhow!(
"Too many searches: {} (max: {})",
request.searches.len(),
MAX_SEARCHES
));
}
// Check if ast-grep is available
check_ast_grep_available().await?;
let max_concurrency = request.max_concurrency.unwrap_or(DEFAULT_MAX_CONCURRENCY);
let max_matches = request.max_matches_per_search.unwrap_or(DEFAULT_MAX_MATCHES);
// Create semaphore for concurrency control
let semaphore = std::sync::Arc::new(Semaphore::new(max_concurrency));
// Execute searches concurrently
let mut tasks = Vec::new();
for search in request.searches {
let sem = semaphore.clone();
let task = tokio::spawn(async move {
let _permit = sem.acquire().await.unwrap();
execute_single_search(search, max_matches).await
});
tasks.push(task);
}
// Wait for all searches to complete
let mut results = Vec::new();
let mut total_matches = 0;
let mut completed = 0;
for task in tasks {
match task.await {
Ok(result) => {
if result.status == "ok" {
completed += 1;
if let Some(count) = result.match_count {
total_matches += count;
}
}
results.push(result);
}
Err(e) => {
error!("Task join error: {}", e);
// Create an error result
results.push(SearchResult {
name: "unknown".to_string(),
mode: "unknown".to_string(),
status: "error".to_string(),
cmd: vec![],
match_count: None,
truncated: None,
matches: None,
stderr: Some(format!("Task execution error: {}", e)),
exit_code: None,
duration_ms: 0,
});
}
}
}
let total_duration = start_time.elapsed();
Ok(CodeSearchResponse {
summary: SearchSummary {
completed,
total: results.len(),
total_matches,
duration_ms: total_duration.as_millis() as u64,
},
searches: results,
})
}
/// Execute a single search
async fn execute_single_search(search: SearchSpec, max_matches: usize) -> SearchResult {
let start_time = Instant::now();
let timeout_secs = search.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
// Validate the search specification
if let Err(e) = validate_search_spec(&search) {
return SearchResult {
name: search.name,
mode: format!("{:?}", search.mode).to_lowercase(),
status: "error".to_string(),
cmd: vec![],
match_count: None,
truncated: None,
matches: None,
stderr: Some(format!("Validation error: {}", e)),
exit_code: None,
duration_ms: start_time.elapsed().as_millis() as u64,
};
}
// Build command
let cmd_args = match build_ast_grep_command(&search) {
Ok(args) => args,
Err(e) => {
return SearchResult {
name: search.name,
mode: format!("{:?}", search.mode).to_lowercase(),
status: "error".to_string(),
cmd: vec![],
match_count: None,
truncated: None,
matches: None,
stderr: Some(format!("Command build error: {}", e)),
exit_code: None,
duration_ms: start_time.elapsed().as_millis() as u64,
};
}
};
debug!("Executing ast-grep command: {:?}", cmd_args);
// Execute with timeout
let timeout_duration = Duration::from_secs(timeout_secs);
match tokio::time::timeout(timeout_duration, run_ast_grep_command(&cmd_args)).await {
Ok(Ok((stdout, stderr, exit_code))) => {
let duration_ms = start_time.elapsed().as_millis() as u64;
if exit_code == 0 {
// Parse JSON output
match parse_ast_grep_output(&stdout, max_matches) {
Ok((matches, truncated)) => {
SearchResult {
name: search.name,
mode: format!("{:?}", search.mode).to_lowercase(),
status: "ok".to_string(),
cmd: cmd_args,
match_count: Some(matches.len()),
truncated: Some(truncated),
matches: Some(matches),
stderr: if stderr.is_empty() { None } else { Some(stderr) },
exit_code: None,
duration_ms,
}
}
Err(e) => {
SearchResult {
name: search.name,
mode: format!("{:?}", search.mode).to_lowercase(),
status: "error".to_string(),
cmd: cmd_args,
match_count: None,
truncated: None,
matches: None,
stderr: Some(format!("JSON parse error: {}\nRaw output: {}", e, stdout)),
exit_code: Some(exit_code),
duration_ms,
}
}
}
} else {
SearchResult {
name: search.name,
mode: format!("{:?}", search.mode).to_lowercase(),
status: "error".to_string(),
cmd: cmd_args,
match_count: None,
truncated: None,
matches: None,
stderr: Some(stderr),
exit_code: Some(exit_code),
duration_ms,
}
}
}
Ok(Err(e)) => {
SearchResult {
name: search.name,
mode: format!("{:?}", search.mode).to_lowercase(),
status: "error".to_string(),
cmd: cmd_args,
match_count: None,
truncated: None,
matches: None,
stderr: Some(format!("Execution error: {}", e)),
exit_code: None,
duration_ms: start_time.elapsed().as_millis() as u64,
}
}
Err(_) => {
SearchResult {
name: search.name,
mode: format!("{:?}", search.mode).to_lowercase(),
status: "timeout".to_string(),
cmd: cmd_args,
match_count: None,
truncated: None,
matches: None,
stderr: Some(format!("Search timed out after {} seconds", timeout_secs)),
exit_code: None,
duration_ms: start_time.elapsed().as_millis() as u64,
}
}
}
}
/// Validate a search specification
fn validate_search_spec(search: &SearchSpec) -> Result<()> {
match search.mode {
SearchMode::Pattern => {
if search.pattern.is_none() || search.pattern.as_ref().unwrap().is_empty() {
return Err(anyhow!("Pattern mode requires non-empty 'pattern' field"));
}
}
SearchMode::Yaml => {
let rule_yaml = search.rule_yaml.as_ref()
.ok_or_else(|| anyhow!("YAML mode requires 'rule_yaml' field"))?;
if rule_yaml.is_empty() {
return Err(anyhow!("YAML mode requires non-empty 'rule_yaml' field"));
}
// Parse and validate YAML structure
let parsed: YamlRule = serde_yaml::from_str(rule_yaml)
.map_err(|e| anyhow!("Invalid YAML rule: {}", e))?;
if parsed.id.is_empty() {
return Err(anyhow!("YAML rule must have non-empty 'id' field"));
}
if parsed.language.is_empty() {
return Err(anyhow!("YAML rule must have non-empty 'language' field"));
}
// Validate language is supported (basic check)
validate_language(&parsed.language)?;
}
}
// Validate context range
if let Some(context) = search.context {
if context > 20 {
return Err(anyhow!("Context lines cannot exceed 20"));
}
}
Ok(())
}
/// Validate that a language is supported by ast-grep
fn validate_language(language: &str) -> Result<()> {
let supported_languages = [
"rust", "javascript", "typescript", "python", "java", "c", "cpp", "csharp",
"go", "html", "css", "json", "yaml", "xml", "bash", "kotlin", "swift",
"php", "ruby", "scala", "dart", "lua", "r", "sql", "dockerfile",
"Rust", "JavaScript", "TypeScript", "Python", "Java", "C", "Cpp", "CSharp",
"Go", "Html", "Css", "Json", "Yaml", "Xml", "Bash", "Kotlin", "Swift",
"Php", "Ruby", "Scala", "Dart", "Lua", "R", "Sql", "Dockerfile"
];
if !supported_languages.contains(&language) {
warn!("Language '{}' may not be supported by ast-grep", language);
}
Ok(())
}
/// Build ast-grep command arguments
fn build_ast_grep_command(search: &SearchSpec) -> Result<Vec<String>> {
let mut args = vec!["ast-grep".to_string()];
match search.mode {
SearchMode::Pattern => {
args.push("run".to_string());
// Add pattern
args.push("-p".to_string());
args.push(search.pattern.as_ref().unwrap().clone());
// Add language if specified
if let Some(ref lang) = search.language {
args.push("-l".to_string());
args.push(lang.clone());
}
}
SearchMode::Yaml => {
args.push("scan".to_string());
// Add inline rules
args.push("--inline-rules".to_string());
args.push(search.rule_yaml.as_ref().unwrap().clone());
// Add include-metadata if requested
if search.include_metadata.unwrap_or(false) {
args.push("--include-metadata".to_string());
}
// Add severity overrides
if let Some(ref severity_map) = search.severity {
for (rule_id, severity) in severity_map {
match severity {
SeverityLevel::Error => {
args.push("--error".to_string());
args.push(rule_id.clone());
}
SeverityLevel::Warning => {
args.push("--warning".to_string());
args.push(rule_id.clone());
}
SeverityLevel::Info => {
args.push("--info".to_string());
args.push(rule_id.clone());
}
SeverityLevel::Hint => {
args.push("--hint".to_string());
args.push(rule_id.clone());
}
SeverityLevel::Off => {
args.push("--off".to_string());
args.push(rule_id.clone());
}
}
}
}
}
}
// Add common arguments
// Add globs if specified
if let Some(ref globs) = search.globs {
if !globs.is_empty() {
args.push("--globs".to_string());
args.push(globs.join(","));
}
}
// Add context
if let Some(context) = search.context {
args.push("-C".to_string());
args.push(context.to_string());
}
// Add threads
if let Some(threads) = search.threads {
args.push("-j".to_string());
args.push(threads.to_string());
}
// Add JSON output style
let json_style = search.json_style.as_ref().unwrap_or(&JsonStyle::Stream);
let json_arg = match json_style {
JsonStyle::Pretty => "--json=pretty",
JsonStyle::Stream => "--json=stream",
JsonStyle::Compact => "--json=compact",
};
args.push(json_arg.to_string());
// Add no-ignore options
if let Some(ref no_ignore_list) = search.no_ignore {
for no_ignore_type in no_ignore_list {
let flag = match no_ignore_type {
NoIgnoreType::Hidden => "--no-ignore=hidden",
NoIgnoreType::Dot => "--no-ignore=dot",
NoIgnoreType::Exclude => "--no-ignore=exclude",
NoIgnoreType::Global => "--no-ignore=global",
NoIgnoreType::Parent => "--no-ignore=parent",
NoIgnoreType::Vcs => "--no-ignore=vcs",
};
args.push(flag.to_string());
}
}
// Add paths (default to current directory if none specified)
if let Some(ref paths) = search.paths {
if !paths.is_empty() {
args.extend(paths.clone());
} else {
args.push(".".to_string());
}
} else {
args.push(".".to_string());
}
Ok(args)
}
/// Run ast-grep command and capture output
async fn run_ast_grep_command(args: &[String]) -> Result<(String, String, i32)> {
let mut cmd = Command::new(&args[0]);
cmd.args(&args[1..]);
cmd.stdout(Stdio::piped());
cmd.stderr(Stdio::piped());
debug!("Running command: {:?}", args);
let mut child = cmd.spawn()
.map_err(|e| anyhow!("Failed to spawn ast-grep process: {}", e))?;
let stdout = child.stdout.take().unwrap();
let stderr = child.stderr.take().unwrap();
let stdout_reader = BufReader::new(stdout);
let stderr_reader = BufReader::new(stderr);
let stdout_task = tokio::spawn(async move {
let mut lines = stdout_reader.lines();
let mut output = String::new();
while let Ok(Some(line)) = lines.next_line().await {
if !output.is_empty() {
output.push('\n');
}
output.push_str(&line);
}
output
});
let stderr_task = tokio::spawn(async move {
let mut lines = stderr_reader.lines();
let mut output = String::new();
while let Ok(Some(line)) = lines.next_line().await {
if !output.is_empty() {
output.push('\n');
}
output.push_str(&line);
}
output
});
let status = child.wait().await
.map_err(|e| anyhow!("Failed to wait for ast-grep process: {}", e))?;
let stdout_output = stdout_task.await
.map_err(|e| anyhow!("Failed to read stdout: {}", e))?;
let stderr_output = stderr_task.await
.map_err(|e| anyhow!("Failed to read stderr: {}", e))?;
let exit_code = status.code().unwrap_or(-1);
Ok((stdout_output, stderr_output, exit_code))
}
/// Parse ast-grep JSON output
fn parse_ast_grep_output(output: &str, max_matches: usize) -> Result<(Vec<Value>, bool)> {
if output.trim().is_empty() {
return Ok((vec![], false));
}
let mut matches = Vec::new();
let mut truncated = false;
// Handle stream format (line-delimited JSON)
for line in output.lines() {
let line = line.trim();
if line.is_empty() {
continue;
}
match serde_json::from_str::<Value>(line) {
Ok(match_obj) => {
if matches.len() >= max_matches {
truncated = true;
break;
}
matches.push(match_obj);
}
Err(e) => {
debug!("Failed to parse JSON line '{}': {}", line, e);
// Try to parse the entire output as a single JSON array
match serde_json::from_str::<Vec<Value>>(output) {
Ok(array_matches) => {
let take_count = array_matches.len().min(max_matches);
let total_count = array_matches.len();
matches = array_matches.into_iter().take(take_count).collect();
truncated = take_count < total_count;
break;
}
Err(e2) => {
return Err(anyhow!(
"Failed to parse ast-grep output as line-delimited JSON or JSON array. Line error: {}, Array error: {}",
e, e2
));
}
}
}
}
}
Ok((matches, truncated))
}
/// Check if ast-grep is available and provide installation hints if not
async fn check_ast_grep_available() -> Result<()> {
match Command::new("ast-grep")
.arg("--version")
.output()
.await
{
Ok(output) => {
if output.status.success() {
let version = String::from_utf8_lossy(&output.stdout);
info!("Found ast-grep: {}", version.trim());
Ok(())
} else {
Err(anyhow!("ast-grep command failed: {}", String::from_utf8_lossy(&output.stderr)))
}
}
Err(_) => {
Err(anyhow!(
"ast-grep not found. Please install it using one of these methods:\n\n\
• Homebrew (macOS): brew install ast-grep\n\
• MacPorts (macOS): sudo port install ast-grep\n\
• Nix: nix-env -iA nixpkgs.ast-grep\n\
• Cargo: cargo install ast-grep\n\
• npm: npm install -g @ast-grep/cli\n\
• pip: pip install ast-grep-cli\n\n\
For more installation options, visit: https://ast-grep.github.io/guide/quick-start.html"
))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_validate_pattern_search() {
let search = SearchSpec {
name: "test".to_string(),
mode: SearchMode::Pattern,
pattern: Some("fn $NAME() {}".to_string()),
language: Some("rust".to_string()),
rule_yaml: None,
paths: None,
globs: None,
json_style: None,
context: None,
threads: None,
include_metadata: None,
no_ignore: None,
severity: None,
timeout_secs: None,
};
assert!(validate_search_spec(&search).is_ok());
}
#[test]
fn test_validate_yaml_search() {
let yaml_rule = r#"
id: test-rule
language: Rust
rule:
pattern: "fn $NAME() {}"
"#;
let search = SearchSpec {
name: "test".to_string(),
mode: SearchMode::Yaml,
pattern: None,
language: None,
rule_yaml: Some(yaml_rule.to_string()),
paths: None,
globs: None,
json_style: None,
context: None,
threads: None,
include_metadata: None,
no_ignore: None,
severity: None,
timeout_secs: None,
};
assert!(validate_search_spec(&search).is_ok());
}
#[test]
fn test_build_pattern_command() {
let search = SearchSpec {
name: "test".to_string(),
mode: SearchMode::Pattern,
pattern: Some("fn $NAME() {}".to_string()),
language: Some("rust".to_string()),
rule_yaml: None,
paths: Some(vec!["src/".to_string()]),
globs: None,
json_style: Some(JsonStyle::Stream),
context: Some(2),
threads: Some(4),
include_metadata: None,
no_ignore: None,
severity: None,
timeout_secs: None,
};
let cmd = build_ast_grep_command(&search).unwrap();
assert_eq!(cmd[0], "ast-grep");
assert_eq!(cmd[1], "run");
assert!(cmd.contains(&"-p".to_string()));
assert!(cmd.contains(&"fn $NAME() {}".to_string()));
assert!(cmd.contains(&"-l".to_string()));
assert!(cmd.contains(&"rust".to_string()));
assert!(cmd.contains(&"--json=stream".to_string()));
assert!(cmd.contains(&"-C".to_string()));
assert!(cmd.contains(&"2".to_string()));
assert!(cmd.contains(&"-j".to_string()));
assert!(cmd.contains(&"4".to_string()));
assert!(cmd.contains(&"src/".to_string()));
}
#[test]
fn test_parse_stream_json() {
let output = r#"{"file":"test.rs","text":"fn hello() {}"}
{"file":"test2.rs","text":"fn world() {}"}"#;
let (matches, truncated) = parse_ast_grep_output(output, 10).unwrap();
assert_eq!(matches.len(), 2);
assert!(!truncated);
assert_eq!(matches[0]["file"], "test.rs");
assert_eq!(matches[1]["file"], "test2.rs");
}
#[test]
fn test_parse_truncated_output() {
let output = r#"{"file":"test1.rs","text":"fn a() {}"}
{"file":"test2.rs","text":"fn b() {}"}
{"file":"test3.rs","text":"fn c() {}"}"#;
let (matches, truncated) = parse_ast_grep_output(output, 2).unwrap();
assert_eq!(matches.len(), 2);
assert!(truncated);
}
}
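As an illustrative request (search name, pattern, and paths are placeholders), a batch can be deserialized from JSON and run through `execute_code_search`:

async fn demo_search() -> anyhow::Result<()> {
    let request: CodeSearchRequest = serde_json::from_value(serde_json::json!({
        "searches": [{
            "name": "async-fns",
            "mode": "pattern",
            "pattern": "async fn $NAME($$$ARGS)",
            "language": "rust",
            "paths": ["src/"],
            "context": 2
        }],
        "max_concurrency": 2
    }))?;
    let response = execute_code_search(request).await?;
    println!(
        "{}/{} searches ok, {} matches",
        response.summary.completed, response.summary.total, response.summary.total_matches
    );
    Ok(())
}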

View File

@@ -0,0 +1,501 @@
//! Error handling module for G3 with retry logic and detailed logging
//!
//! This module provides:
//! - Classification of errors as recoverable or non-recoverable
//! - Retry logic with exponential backoff and jitter for recoverable errors
//! - Detailed error logging with context information
//! - Request/response capture for debugging
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tracing::{error, info, warn};
/// Maximum number of retry attempts for recoverable errors (default mode)
const DEFAULT_MAX_RETRY_ATTEMPTS: u32 = 3;
/// Maximum number of retry attempts for autonomous mode
const AUTONOMOUS_MAX_RETRY_ATTEMPTS: u32 = 6;
/// Base delay for exponential backoff (in milliseconds)
const BASE_RETRY_DELAY_MS: u64 = 1000;
/// Maximum delay between retries (in milliseconds) for default mode
const DEFAULT_MAX_RETRY_DELAY_MS: u64 = 10000;
/// Maximum delay between retries (in milliseconds) for autonomous mode
/// Spread over 10 minutes (600 seconds) with 6 retries
const AUTONOMOUS_MAX_RETRY_DELAY_MS: u64 = 120000; // 2 minutes max per retry
// Removed unused constants AUTONOMOUS_RETRY_BUDGET_MS and DEFAULT_JITTER_FACTOR
/// Jitter factor (±30%) applied to retry delays to avoid thundering-herd retries
const JITTER_FACTOR: f64 = 0.3;
/// Error context information for detailed logging
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorContext {
/// The operation that was being performed
pub operation: String,
/// The provider being used
pub provider: String,
/// The model being used
pub model: String,
/// The last prompt sent (truncated for logging)
pub last_prompt: String,
/// Raw request data (if available)
pub raw_request: Option<String>,
/// Raw response data (if available)
pub raw_response: Option<String>,
/// Stack trace
pub stack_trace: String,
/// Timestamp
pub timestamp: u64,
/// Number of tokens in context
pub context_tokens: u32,
/// Session ID if available
pub session_id: Option<String>,
/// Whether to skip file logging (quiet mode)
pub quiet: bool,
}
impl ErrorContext {
pub fn new(
operation: String,
provider: String,
model: String,
last_prompt: String,
session_id: Option<String>,
context_tokens: u32,
quiet: bool,
) -> Self {
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
// Capture stack trace
let stack_trace = std::backtrace::Backtrace::force_capture().to_string();
Self {
operation,
provider,
model,
last_prompt: truncate_for_logging(&last_prompt, 1000),
raw_request: None,
raw_response: None,
stack_trace,
timestamp,
context_tokens,
session_id,
quiet,
}
}
pub fn with_request(mut self, request: String) -> Self {
self.raw_request = Some(truncate_for_logging(&request, 5000));
self
}
pub fn with_response(mut self, response: String) -> Self {
self.raw_response = Some(truncate_for_logging(&response, 5000));
self
}
/// Log the error context with ERROR level
pub fn log_error(&self, error: &anyhow::Error) {
error!("=== G3 ERROR DETAILS ===");
error!("Operation: {}", self.operation);
error!("Provider: {} | Model: {}", self.provider, self.model);
error!("Error: {}", error);
error!("Timestamp: {}", self.timestamp);
error!("Session ID: {:?}", self.session_id);
error!("Context Tokens: {}", self.context_tokens);
error!("Last Prompt: {}", self.last_prompt);
if let Some(ref req) = self.raw_request {
error!("Raw Request: {}", req);
}
if let Some(ref resp) = self.raw_response {
error!("Raw Response: {}", resp);
}
error!("Stack Trace:\n{}", self.stack_trace);
error!("=== END ERROR DETAILS ===");
// Also save to error log file
self.save_to_file();
}
/// Save error context to a file for later analysis
fn save_to_file(&self) {
// Skip file logging if quiet mode is enabled
if self.quiet {
return;
}
let logs_dir = std::path::Path::new("logs/errors");
if !logs_dir.exists() {
if let Err(e) = std::fs::create_dir_all(logs_dir) {
error!("Failed to create error logs directory: {}", e);
return;
}
}
let filename = format!(
"logs/errors/error_{}_{}.json",
self.timestamp,
self.session_id.as_deref().unwrap_or("unknown")
);
match serde_json::to_string_pretty(self) {
Ok(json_content) => {
if let Err(e) = std::fs::write(&filename, json_content) {
error!("Failed to save error context to {}: {}", filename, e);
} else {
info!("Error details saved to: {}", filename);
}
}
Err(e) => {
error!("Failed to serialize error context: {}", e);
}
}
}
}
/// Classification of error types
#[derive(Debug, Clone, PartialEq)]
pub enum ErrorType {
/// Recoverable errors that should be retried
Recoverable(RecoverableError),
/// Non-recoverable errors that should terminate execution
NonRecoverable,
}
/// Types of recoverable errors
#[derive(Debug, Clone, PartialEq)]
pub enum RecoverableError {
/// Rate limit exceeded
RateLimit,
/// Temporary network error
NetworkError,
/// Server error (5xx)
ServerError,
/// Model is busy/overloaded
ModelBusy,
/// Timeout
Timeout,
/// Token limit exceeded (might be recoverable with summarization)
TokenLimit,
}
/// Classify an error as recoverable or non-recoverable
pub fn classify_error(error: &anyhow::Error) -> ErrorType {
let error_str = error.to_string().to_lowercase();
// Check for recoverable error patterns
if error_str.contains("rate limit") || error_str.contains("rate_limit") || error_str.contains("429") {
return ErrorType::Recoverable(RecoverableError::RateLimit);
}
if error_str.contains("network") || error_str.contains("connection") ||
error_str.contains("dns") || error_str.contains("refused") {
return ErrorType::Recoverable(RecoverableError::NetworkError);
}
if error_str.contains("500") || error_str.contains("502") ||
error_str.contains("503") || error_str.contains("504") ||
error_str.contains("server error") || error_str.contains("internal error") {
return ErrorType::Recoverable(RecoverableError::ServerError);
}
if error_str.contains("busy") || error_str.contains("overloaded") ||
error_str.contains("capacity") || error_str.contains("unavailable") {
return ErrorType::Recoverable(RecoverableError::ModelBusy);
}
// Enhanced timeout detection - check for various timeout patterns
if error_str.contains("timeout") ||
error_str.contains("timed out") ||
error_str.contains("operation timed out") ||
error_str.contains("request or response body error") || // Common timeout pattern
error_str.contains("stream error") && error_str.contains("timed out") {
return ErrorType::Recoverable(RecoverableError::Timeout);
}
if error_str.contains("token") && (error_str.contains("limit") || error_str.contains("exceeded")) {
return ErrorType::Recoverable(RecoverableError::TokenLimit);
}
// Default to non-recoverable for unknown errors
ErrorType::NonRecoverable
}
/// Calculate retry delay for autonomous mode with better distribution over 10 minutes
fn calculate_autonomous_retry_delay(attempt: u32) -> Duration {
use rand::Rng;
let mut rng = rand::thread_rng();
// Distribute 6 retries over 10 minutes (600 seconds)
// Base delays: 10s, 30s, 60s, 120s, 180s, 200s = 600s total
let base_delays_ms = [10000, 30000, 60000, 120000, 180000, 200000];
let base_delay = base_delays_ms.get(attempt.saturating_sub(1) as usize).unwrap_or(&200000);
// Add jitter of ±30% to prevent thundering herd
let jitter = (*base_delay as f64 * JITTER_FACTOR * rng.gen::<f64>()) as u64;
let final_delay = if rng.gen_bool(0.5) {
base_delay + jitter
} else {
base_delay.saturating_sub(jitter)
};
Duration::from_millis(final_delay)
}
/// Calculate retry delay with exponential backoff and jitter
pub fn calculate_retry_delay(attempt: u32, is_autonomous: bool) -> Duration {
if is_autonomous {
return calculate_autonomous_retry_delay(attempt);
}
use rand::Rng;
let max_retry_delay_ms = if is_autonomous { AUTONOMOUS_MAX_RETRY_DELAY_MS } else { DEFAULT_MAX_RETRY_DELAY_MS };
// Exponential backoff: delay = base * 2^attempt
let base_delay = BASE_RETRY_DELAY_MS * (2_u64.pow(attempt.saturating_sub(1)));
let capped_delay = base_delay.min(max_retry_delay_ms);
// Add jitter to prevent thundering herd
let mut rng = rand::thread_rng();
let jitter = (capped_delay as f64 * JITTER_FACTOR * rng.gen::<f64>()) as u64;
let final_delay = if rng.gen_bool(0.5) {
capped_delay + jitter
} else {
capped_delay.saturating_sub(jitter)
};
Duration::from_millis(final_delay)
}
/// Retry logic for async operations
pub async fn retry_with_backoff<F, Fut, T>(
operation_name: &str,
mut operation: F,
context: &ErrorContext,
is_autonomous: bool,
) -> Result<T>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = Result<T>>,
{
let mut attempt = 0;
let mut _last_error = None;
loop {
attempt += 1;
match operation().await {
Ok(result) => {
if attempt > 1 {
info!(
"Operation '{}' succeeded after {} attempts",
operation_name, attempt
);
}
return Ok(result);
}
Err(error) => {
let error_type = classify_error(&error);
let max_attempts = if is_autonomous { AUTONOMOUS_MAX_RETRY_ATTEMPTS } else { DEFAULT_MAX_RETRY_ATTEMPTS };
match error_type {
ErrorType::Recoverable(recoverable_type) => {
if attempt >= max_attempts {
error!(
"Operation '{}' failed after {} attempts. Giving up.",
operation_name, attempt
);
context.clone().log_error(&error);
return Err(error);
}
let delay = calculate_retry_delay(attempt, is_autonomous);
warn!(
"Recoverable error ({:?}) in '{}' (attempt {}/{}). Retrying in {:?}...",
recoverable_type, operation_name, attempt, max_attempts, delay
);
warn!("Error details: {}", error);
// Special handling for token limit errors
if matches!(recoverable_type, RecoverableError::TokenLimit) {
info!("Token limit error detected. Consider triggering summarization.");
}
tokio::time::sleep(delay).await;
_last_error = Some(error);
}
ErrorType::NonRecoverable => {
error!(
"Non-recoverable error in '{}' (attempt {}). Terminating.",
operation_name, attempt
);
context.clone().log_error(&error);
return Err(error);
}
}
}
}
}
}
/// Helper function to truncate strings for logging
fn truncate_for_logging(s: &str, max_len: usize) -> String {
if s.len() <= max_len {
s.to_string()
} else {
// Find a safe UTF-8 boundary to truncate at
// We need to ensure we don't cut in the middle of a multi-byte character
let mut truncate_at = max_len;
// Walk backwards from max_len to find a character boundary
while truncate_at > 0 && !s.is_char_boundary(truncate_at) {
truncate_at -= 1;
}
// If we couldn't find a boundary (shouldn't happen), use a safe default
if truncate_at == 0 {
truncate_at = max_len.min(s.len());
}
format!("{}... (truncated, {} total bytes)", &s[..truncate_at], s.len())
}
}
/// Macro for creating an `ErrorContext` with less boilerplate.
///
/// Forwards all seven constructor arguments, including the `quiet` flag
/// that suppresses writing error logs to disk.
#[macro_export]
macro_rules! error_context {
($operation:expr, $provider:expr, $model:expr, $prompt:expr, $session_id:expr, $tokens:expr, $quiet:expr) => {
$crate::error_handling::ErrorContext::new(
$operation.to_string(),
$provider.to_string(),
$model.to_string(),
$prompt.to_string(),
$session_id,
$tokens,
$quiet,
)
};
}
#[cfg(test)]
mod tests {
use super::*;
use anyhow::anyhow;
#[test]
fn test_error_classification() {
// Rate limit errors
let error = anyhow!("Rate limit exceeded");
assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::RateLimit));
let error = anyhow!("HTTP 429 Too Many Requests");
assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::RateLimit));
// Network errors
let error = anyhow!("Network connection failed");
assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::NetworkError));
// Server errors
let error = anyhow!("HTTP 503 Service Unavailable");
assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::ServerError));
// Model busy
let error = anyhow!("Model is busy, please try again");
assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::ModelBusy));
// Timeout
let error = anyhow!("Request timed out");
assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::Timeout));
// Token limit
let error = anyhow!("Token limit exceeded");
assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::TokenLimit));
// Non-recoverable
let error = anyhow!("Invalid API key");
assert_eq!(classify_error(&error), ErrorType::NonRecoverable);
let error = anyhow!("Malformed request");
assert_eq!(classify_error(&error), ErrorType::NonRecoverable);
}
#[test]
fn test_retry_delay_calculation() {
// Test that delays increase exponentially
let delay1 = calculate_retry_delay(1, false);
let delay2 = calculate_retry_delay(2, false);
let delay3 = calculate_retry_delay(3, false);
// Due to jitter, we can't test exact values, but the base should increase
assert!(delay1.as_millis() >= (BASE_RETRY_DELAY_MS as f64 * 0.7) as u128);
assert!(delay1.as_millis() <= (BASE_RETRY_DELAY_MS as f64 * 1.3) as u128);
// Delay 2 should be roughly 2x delay 1 (minus jitter)
assert!(delay2.as_millis() >= delay1.as_millis());
// Delay 3 should be roughly 2x delay 2 (minus jitter)
assert!(delay3.as_millis() >= delay2.as_millis());
// Test max cap
let delay_max = calculate_retry_delay(10, false);
assert!(delay_max.as_millis() <= (DEFAULT_MAX_RETRY_DELAY_MS as f64 * 1.3) as u128);
}
#[test]
fn test_autonomous_retry_delay_calculation() {
// Test autonomous mode delays are distributed over 10 minutes
let delay1 = calculate_retry_delay(1, true);
let delay2 = calculate_retry_delay(2, true);
let delay3 = calculate_retry_delay(3, true);
let delay4 = calculate_retry_delay(4, true);
let delay5 = calculate_retry_delay(5, true);
let delay6 = calculate_retry_delay(6, true);
// Base delays should be around: 10s, 30s, 60s, 120s, 180s, 200s
// With ±30% jitter
assert!(delay1.as_millis() >= 7000 && delay1.as_millis() <= 13000);
assert!(delay2.as_millis() >= 21000 && delay2.as_millis() <= 39000);
assert!(delay3.as_millis() >= 42000 && delay3.as_millis() <= 78000);
assert!(delay4.as_millis() >= 84000 && delay4.as_millis() <= 156000);
assert!(delay5.as_millis() >= 126000 && delay5.as_millis() <= 234000);
assert!(delay6.as_millis() >= 140000 && delay6.as_millis() <= 260000);
}
#[test]
fn test_truncate_for_logging() {
let short_text = "Hello, world!";
assert_eq!(truncate_for_logging(short_text, 20), "Hello, world!");
let long_text = "This is a very long text that should be truncated for logging purposes";
let truncated = truncate_for_logging(long_text, 20);
assert!(truncated.starts_with("This is a very long "));
assert!(truncated.contains("truncated"));
assert!(truncated.contains("total bytes"));
}
#[test]
fn test_truncate_with_multibyte_chars() {
// Test with multi-byte UTF-8 characters
let text_with_emoji = "Hello 👋 World 🌍 Test ✨ More text here";
let truncated = truncate_for_logging(text_with_emoji, 10);
// Should truncate at a valid UTF-8 boundary
assert!(truncated.starts_with("Hello "));
// Test with box-drawing characters like the one causing the panic
let text_with_box = "Some text ┌─────┐ more text";
let truncated = truncate_for_logging(text_with_box, 12);
// Should not panic and should truncate at a valid boundary
assert!(truncated.contains("Some text"));
assert!(truncated.contains("truncated"));
}
}

View File

@@ -0,0 +1,154 @@
//! Integration tests for error handling with retry logic
#[cfg(test)]
mod tests {
use crate::error_handling::*;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
#[tokio::test]
async fn test_retry_with_recoverable_error() {
let attempt_count = Arc::new(AtomicU32::new(0));
let context = ErrorContext::new(
"test_operation".to_string(),
"test_provider".to_string(),
"test_model".to_string(),
"test prompt".to_string(),
None,
100,
false, // quiet parameter
);
let result = retry_with_backoff(
"test_operation",
|| {
let counter = Arc::clone(&attempt_count);
async move {
let count = counter.fetch_add(1, Ordering::SeqCst);
if count < 2 {
// Fail with recoverable error on first two attempts
Err(anyhow::anyhow!("Rate limit exceeded"))
} else {
// Succeed on third attempt
Ok("Success")
}
}
},
&context,
false, // not autonomous mode
)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "Success");
assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
}
#[tokio::test]
async fn test_retry_with_non_recoverable_error() {
let attempt_count = Arc::new(AtomicU32::new(0));
let context = ErrorContext::new(
"test_operation".to_string(),
"test_provider".to_string(),
"test_model".to_string(),
"test prompt".to_string(),
None,
100,
false, // quiet parameter
);
let result: Result<&str, _> = retry_with_backoff(
"test_operation",
|| {
let counter = Arc::clone(&attempt_count);
async move {
counter.fetch_add(1, Ordering::SeqCst);
// Always fail with non-recoverable error
Err(anyhow::anyhow!("Invalid API key"))
}
},
&context,
false, // not autonomous mode
)
.await;
assert!(result.is_err());
assert_eq!(attempt_count.load(Ordering::SeqCst), 1); // Should only try once
}
#[tokio::test]
async fn test_retry_exhaustion() {
let attempt_count = Arc::new(AtomicU32::new(0));
let context = ErrorContext::new(
"test_operation".to_string(),
"test_provider".to_string(),
"test_model".to_string(),
"test prompt".to_string(),
None,
100,
false, // quiet parameter
);
let result: Result<&str, _> = retry_with_backoff(
"test_operation",
|| {
let counter = Arc::clone(&attempt_count);
async move {
counter.fetch_add(1, Ordering::SeqCst);
// Always fail with recoverable error
Err(anyhow::anyhow!("Network connection failed"))
}
},
&context,
false, // not autonomous mode
)
.await;
assert!(result.is_err());
assert_eq!(attempt_count.load(Ordering::SeqCst), 3); // Should try DEFAULT_MAX_RETRY_ATTEMPTS (3) times
}
#[test]
fn test_error_context_truncation() {
let long_prompt = "a".repeat(2000);
let context = ErrorContext::new(
"test_op".to_string(),
"provider".to_string(),
"model".to_string(),
long_prompt,
None,
100,
false, // quiet parameter
);
// The prompt should be truncated to 1000 chars
assert!(context.last_prompt.len() < 1100); // Some buffer for the truncation message
assert!(context.last_prompt.contains("truncated"));
}
#[test]
fn test_retry_delay_increases() {
let delay1 = calculate_retry_delay(1, false);
let delay2 = calculate_retry_delay(2, false);
let delay3 = calculate_retry_delay(3, false);
// Delays should generally increase (though jitter can affect this)
// We'll test the base delays without jitter
let base1 = 1000u64; // BASE_RETRY_DELAY_MS
let base2 = 1000u64 * 2;
let base3 = 1000u64 * 4;
// Check that delays are within expected ranges (accounting for jitter)
assert!(delay1.as_millis() >= (base1 as f64 * 0.7) as u128);
assert!(delay1.as_millis() <= (base1 as f64 * 1.3) as u128);
assert!(delay2.as_millis() >= (base2 as f64 * 0.7) as u128);
assert!(delay2.as_millis() <= (base2 as f64 * 1.3) as u128);
assert!(delay3.as_millis() >= (base3 as f64 * 0.7) as u128);
assert!(delay3.as_millis() <= (base3 as f64 * 1.3) as u128);
}
}

View File

@@ -0,0 +1,463 @@
//! JSON tool call filtering for streaming LLM responses.
//!
//! This module filters out JSON tool calls from LLM output streams while preserving
//! regular text content. It uses a state machine to handle streaming chunks.
//!
//! Specification:
//! 1. Detect a tool call start on the very next newline with the regex
//!    `\s*\{\s*"tool"\s*:\s*"` (the spec text writes `\w*` where flexible
//!    whitespace is clearly intended; the implementation uses `\s*`)
//! 2. Enter suppression mode and use brace counting to find the complete JSON
//! 3. Elide only the JSON content between the first '{' and its matching '}' (inclusive)
//! 4. Return everything else as the final filtered string
use regex::Regex;
use std::cell::RefCell;
use tracing::debug;
// Thread-local state for tracking JSON tool call suppression
thread_local! {
static FIXED_JSON_TOOL_STATE: RefCell<FixedJsonToolState> = RefCell::new(FixedJsonToolState::new());
}
/// Internal state for tracking JSON tool call filtering across streaming chunks.
#[derive(Debug, Clone)]
struct FixedJsonToolState {
/// True when actively suppressing a confirmed tool call
suppression_mode: bool,
/// True when buffering potential JSON (saw { but not yet confirmed as tool call)
potential_json_mode: bool,
/// Tracks nesting depth of braces within JSON
brace_depth: i32,
buffer: String,
json_start_in_buffer: Option<usize>, // Position where confirmed JSON tool call starts
content_returned_up_to: usize, // Track how much content we've already returned
potential_json_start: Option<usize>, // Where the potential JSON started
}
impl FixedJsonToolState {
fn new() -> Self {
Self {
suppression_mode: false,
potential_json_mode: false,
brace_depth: 0,
buffer: String::new(),
json_start_in_buffer: None,
content_returned_up_to: 0,
potential_json_start: None,
}
}
fn reset(&mut self) {
self.suppression_mode = false;
self.potential_json_mode = false;
self.brace_depth = 0;
self.buffer.clear();
self.json_start_in_buffer = None;
self.content_returned_up_to = 0;
self.potential_json_start = None;
}
}
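// State machine (summary of the flow implemented below):
//
//   normal ----'{' at line start----> potential_json
//   potential_json --"tool" key confirmed--> suppression
//   potential_json --ruled out ('}', newline, or non-"tool" key)--> normal (buffer released)
//   suppression ----brace depth back to 0----> normal (JSON elided from output)
//   suppression --new tool-call pattern seen mid-JSON--> restart at the new '{' (old JSON discarded)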
/// Filters JSON tool calls from streaming LLM content.
///
/// Processes content chunks and removes JSON tool calls while preserving regular text.
/// Maintains state across calls to handle tool calls spanning multiple chunks.
pub fn fixed_filter_json_tool_calls(content: &str) -> String {
if content.is_empty() {
return String::new();
}
FIXED_JSON_TOOL_STATE.with(|state| {
let mut state = state.borrow_mut();
// Add new content to buffer
state.buffer.push_str(content);
// If we're already in suppression mode, continue brace counting
if state.suppression_mode {
// Count braces in the new content only
for ch in content.chars() {
match ch {
'{' => state.brace_depth += 1,
'}' => {
state.brace_depth -= 1;
// Exit suppression mode when all braces are closed
if state.brace_depth <= 0 {
debug!("JSON tool call completed - exiting suppression mode");
// Extract the complete result with JSON filtered out
let result = extract_fixed_content(
&state.buffer,
state.json_start_in_buffer.unwrap_or(0),
);
// Return only the part we haven't returned yet
let new_content = if result.len() > state.content_returned_up_to {
result[state.content_returned_up_to..].to_string()
} else {
String::new()
};
state.reset();
return new_content;
}
}
_ => {}
}
}
// CRITICAL FIX: After counting braces, if still in suppression mode,
// check if a new tool call pattern appears. This handles truncated JSON
// followed by complete JSON.
if state.suppression_mode {
let current_json_start = state.json_start_in_buffer.unwrap();
// Don't require newline - the new JSON might be concatenated directly
let tool_call_regex = Regex::new(r#"\{\s*"tool"\s*:\s*""#).unwrap();
// Look for new tool call patterns after the current one
if let Some(m) = tool_call_regex.find(&state.buffer[current_json_start + 1..]) {
let new_json_start = current_json_start + 1 + m.start() + m.as_str().find('{').unwrap();
debug!("Detected new tool call at position {} while processing incomplete one at {} - discarding old", new_json_start, current_json_start);
// The previous JSON was incomplete/malformed
// Return content before the old JSON (if any)
let content_before_old_json = if current_json_start > state.content_returned_up_to {
state.buffer[state.content_returned_up_to..current_json_start].to_string()
} else {
String::new()
};
// Update state to skip the incomplete JSON and position at the new one
// We'll process the new JSON on the next call
state.content_returned_up_to = new_json_start;
state.suppression_mode = false;
state.json_start_in_buffer = None;
state.brace_depth = 0;
return content_before_old_json;
}
}
// Still in suppression mode, return empty string (content is being accumulated)
return String::new();
}
// Check if we're in potential JSON mode (saw { but waiting to confirm it's a tool call)
if state.potential_json_mode {
// Check if the buffer contains a confirmed tool call pattern
let tool_call_regex = Regex::new(r#"(?m)^\s*\{\s*"tool"\s*:\s*""#).unwrap();
if let Some(m) = tool_call_regex.find(&state.buffer) {
// Confirmed! This is a tool call - enter suppression mode
let match_text = m.as_str();
if let Some(brace_offset) = match_text.find('{') {
let json_start = m.start() + brace_offset;
debug!("Confirmed JSON tool call at position {} - entering suppression mode", json_start);
state.potential_json_mode = false;
state.suppression_mode = true;
state.brace_depth = 0;
state.json_start_in_buffer = Some(json_start);
// Count braces from json_start to see if JSON is complete
let buffer_slice = state.buffer[json_start..].to_string();
for ch in buffer_slice.chars() {
match ch {
'{' => state.brace_depth += 1,
'}' => {
state.brace_depth -= 1;
if state.brace_depth <= 0 {
debug!("JSON tool call completed immediately");
let result = extract_fixed_content(&state.buffer, json_start);
let new_content = if result.len() > state.content_returned_up_to {
result[state.content_returned_up_to..].to_string()
} else {
String::new()
};
state.reset();
return new_content;
}
}
_ => {}
}
}
// JSON incomplete, stay in suppression mode, return nothing
return String::new();
}
}
// Check if we can rule out this being a tool call
// If we have enough content after the { and it doesn't match the pattern, release it
if let Some(potential_start) = state.potential_json_start {
let content_after_brace = &state.buffer[potential_start..];
// Rule out as a tool call if:
// 1. Closing } appears before we see the full pattern
// 2. Content clearly doesn't match the tool call pattern
// 3. Newline appears after the opening brace (tool calls should be compact)
let has_closing_brace = content_after_brace.contains('}');
let has_newline = content_after_brace[1..].contains('\n'); // Skip first char which is {
let long_enough = content_after_brace.len() >= 10;
// Detect non-tool JSON patterns:
// - { followed by " and a key that doesn't start with "tool"
// - { followed by "t" but not "to"
// - { followed by "to" but not "too", etc.
let not_tool_pattern = Regex::new(r#"^\{\s*"(?:[^t]|t(?:[^o]|o(?:[^o]|o(?:[^l]|l[^"\s:]))))"#).unwrap();
let definitely_not_tool = not_tool_pattern.is_match(content_after_brace);
if has_closing_brace || has_newline || (long_enough && definitely_not_tool) {
debug!("Potential JSON ruled out - not a tool call");
state.potential_json_mode = false;
state.potential_json_start = None;
// Return the buffered content we've been holding
let new_content = if state.buffer.len() > state.content_returned_up_to {
state.buffer[state.content_returned_up_to..].to_string()
} else {
String::new()
};
state.content_returned_up_to = state.buffer.len();
return new_content;
}
}
// Still in potential mode, keep buffering
return String::new();
}
// Detect potential JSON start: { at the beginning of a line
let potential_json_regex = Regex::new(r"(?m)^\s*\{\s*").unwrap();
if let Some(m) = potential_json_regex.find(&state.buffer[state.content_returned_up_to..]) {
let match_start = state.content_returned_up_to + m.start();
let brace_pos = match_start + m.as_str().find('{').unwrap();
debug!("Potential JSON detected at position {} - entering buffering mode", brace_pos);
// Fast path: check if this is already a confirmed tool call
let tool_call_regex = Regex::new(r#"(?m)^\s*\{\s*"tool"\s*:\s*""#).unwrap();
if tool_call_regex.is_match(&state.buffer[brace_pos..]) {
// This is a confirmed tool call! Process it immediately
let json_start = brace_pos;
debug!("Immediately confirmed tool call at position {}", json_start);
// Return content before JSON
let content_before = if json_start > state.content_returned_up_to {
state.buffer[state.content_returned_up_to..json_start].to_string()
} else {
String::new()
};
state.content_returned_up_to = json_start;
state.suppression_mode = true;
state.brace_depth = 0;
state.json_start_in_buffer = Some(json_start);
// Count braces to see if JSON is complete
let buffer_slice = state.buffer[json_start..].to_string();
for ch in buffer_slice.chars() {
match ch {
'{' => state.brace_depth += 1,
'}' => {
state.brace_depth -= 1;
if state.brace_depth <= 0 {
debug!("JSON tool call completed in same chunk");
let result = extract_fixed_content(&state.buffer, json_start);
let content_after = if result.len() > json_start {
&result[json_start..]
} else {
""
};
let final_result = format!("{}{}", content_before, content_after);
state.reset();
return final_result;
}
}
_ => {}
}
}
// JSON incomplete, return content before and stay in suppression mode
return content_before;
}
// Return content before the potential JSON
let content_before = if brace_pos > state.content_returned_up_to {
state.buffer[state.content_returned_up_to..brace_pos].to_string()
} else {
String::new()
};
state.content_returned_up_to = brace_pos;
state.potential_json_mode = true;
state.potential_json_start = Some(brace_pos);
// Optimization: immediately check if we can rule this out for single-chunk processing
let content_after_brace = &state.buffer[brace_pos..];
let has_closing_brace = content_after_brace.contains('}');
let has_newline = content_after_brace.len() > 1 && content_after_brace[1..].contains('\n');
let long_enough = content_after_brace.len() >= 10;
let not_tool_pattern = Regex::new(r#"^\{\s*"(?:[^t]|t(?:[^o]|o(?:[^o]|o(?:[^l]|l[^"\s:]))))"#).unwrap();
let definitely_not_tool = not_tool_pattern.is_match(content_after_brace);
if has_closing_brace || has_newline || (long_enough && definitely_not_tool) {
debug!("Immediately ruled out as not a tool call");
state.potential_json_mode = false;
state.potential_json_start = None;
// Return all the buffered content
let new_content = if state.buffer.len() > state.content_returned_up_to {
state.buffer[state.content_returned_up_to..].to_string()
} else {
String::new()
};
state.content_returned_up_to = state.buffer.len();
return format!("{}{}", content_before, new_content);
}
return content_before;
}
// Check for tool call pattern using corrected regex
let tool_call_regex = Regex::new(r#"(?m)^\s*\{\s*"tool"\s*:\s*"[^"]*""#).unwrap();
if let Some(m) = tool_call_regex.find(&state.buffer) {
let match_text = m.as_str();
// Find the position of the opening brace in the match
if let Some(brace_offset) = match_text.find('{') {
let json_start = m.start() + brace_offset;
debug!(
"Detected JSON tool call at position {} - entering suppression mode",
json_start
);
// Return content before JSON that we haven't returned yet
let content_before_json = if json_start >= state.content_returned_up_to {
state.buffer[state.content_returned_up_to..json_start].to_string()
} else {
String::new()
};
state.content_returned_up_to = json_start;
// Enter suppression mode
state.suppression_mode = true;
state.brace_depth = 0;
state.json_start_in_buffer = Some(json_start);
// Count braces from the JSON start to see if it's complete
let buffer_clone = state.buffer.clone();
for ch in buffer_clone[json_start..].chars() {
match ch {
'{' => state.brace_depth += 1,
'}' => {
state.brace_depth -= 1;
if state.brace_depth <= 0 {
// JSON is complete in this chunk
debug!("JSON tool call completed in same chunk");
let result = extract_fixed_content(&buffer_clone, json_start);
// Return content before JSON plus content after JSON
let content_after_json = if result.len() > json_start {
&result[json_start..]
} else {
""
};
let final_result =
format!("{}{}", content_before_json, content_after_json);
state.reset();
return final_result;
}
}
_ => {}
}
}
// JSON is incomplete, return only the content before JSON
return content_before_json;
}
}
// No JSON tool call detected, return only the new content we haven't returned yet
if state.buffer.len() > state.content_returned_up_to {
let result = state.buffer[state.content_returned_up_to..].to_string();
state.content_returned_up_to = state.buffer.len();
result
} else {
String::new()
}
})
}
/// Extracts content from buffer, removing the JSON tool call.
///
/// Given a buffer and the start position of a JSON tool call, this function:
/// 1. Extracts all content before the JSON
/// 2. Finds the end of the JSON (matching closing brace)
/// 3. Extracts all content after the JSON
/// 4. Returns the concatenation of before + after (JSON removed)
///
/// # Arguments
/// * `full_content` - The full content buffer
/// * `json_start` - Position where the JSON tool call begins
fn extract_fixed_content(full_content: &str, json_start: usize) -> String {
// Find the end of the JSON using proper brace counting with string handling
let mut brace_depth = 0;
let mut json_end = json_start;
let mut in_string = false;
let mut escape_next = false;
for (i, ch) in full_content[json_start..].char_indices() {
if escape_next {
escape_next = false;
continue;
}
match ch {
'\\' if in_string => escape_next = true,
'"' if !escape_next => in_string = !in_string,
'{' if !in_string => {
brace_depth += 1;
}
'}' if !in_string => {
brace_depth -= 1;
if brace_depth == 0 {
json_end = json_start + i + 1; // +1 to include the closing brace
break;
}
}
_ => {}
}
}
// Return content before and after the JSON (excluding the JSON itself)
let before = &full_content[..json_start];
let after = if json_end < full_content.len() {
&full_content[json_end..]
} else {
""
};
format!("{}{}", before, after)
}
/// Resets the global JSON filtering state.
///
/// Call this between independent filtering sessions to ensure clean state.
/// This is particularly important in tests and when starting new conversations.
pub fn reset_fixed_json_tool_state() {
FIXED_JSON_TOOL_STATE.with(|state| {
let mut state = state.borrow_mut();
state.reset();
});
}
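// Example (illustrative): driving the filter over streaming chunks. Chunk
// boundaries are arbitrary; state lives in the thread-local above, so call
// reset_fixed_json_tool_state() between independent streams.
//
//     reset_fixed_json_tool_state();
//     let mut visible = String::new();
//     for chunk in ["Text\n", "{\"tool\": \"shell\", \"args\": {}}", "\nAfter"] {
//         visible.push_str(&fixed_filter_json_tool_calls(chunk));
//     }
//     assert_eq!(visible, "Text\n\nAfter");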

View File

@@ -0,0 +1,384 @@
//! Tests for JSON tool call filtering.
//!
//! These tests verify that the filter correctly identifies and removes JSON tool calls
//! from LLM output streams while preserving all other content.
#[cfg(test)]
mod fixed_filter_tests {
use crate::fixed_filter_json::{fixed_filter_json_tool_calls, reset_fixed_json_tool_state};
use regex::Regex;
/// Test that regular text without tool calls passes through unchanged.
#[test]
fn test_no_tool_call_passthrough() {
reset_fixed_json_tool_state();
let input = "This is regular text without any tool calls.";
let result = fixed_filter_json_tool_calls(input);
assert_eq!(result, input);
}
/// Test detection and removal of a complete tool call in a single chunk.
#[test]
fn test_simple_tool_call_detection() {
reset_fixed_json_tool_state();
let input = r#"Some text before
{"tool": "shell", "args": {"command": "ls"}}
Some text after"#;
let result = fixed_filter_json_tool_calls(input);
let expected = "Some text before\n\nSome text after";
assert_eq!(result, expected);
}
/// Test handling of tool calls that arrive across multiple streaming chunks.
#[test]
fn test_streaming_chunks() {
reset_fixed_json_tool_state();
// Simulate streaming where the tool call comes in multiple chunks
let chunks = vec![
"Some text before\n",
"{\"tool\": \"",
"shell\", \"args\": {",
"\"command\": \"ls\"",
"}}\nText after",
];
let mut results = Vec::new();
for chunk in chunks {
let result = fixed_filter_json_tool_calls(chunk);
results.push(result);
}
// The final accumulated result should have the JSON filtered out
let final_result: String = results.join("");
let expected = "Some text before\n\nText after";
assert_eq!(final_result, expected);
}
/// Test correct handling of nested braces within JSON strings.
#[test]
fn test_nested_braces_in_tool_call() {
reset_fixed_json_tool_state();
let input = r#"Text before
{"tool": "write_file", "args": {"file_path": "test.json", "content": "{\"nested\": \"value\"}"}}
Text after"#;
let result = fixed_filter_json_tool_calls(input);
let expected = "Text before\n\nText after";
assert_eq!(result, expected);
}
/// Verify the regex pattern matches the specification with flexible whitespace.
#[test]
fn test_regex_pattern_specification() {
// Test the corrected regex pattern that's more flexible with whitespace
let pattern = Regex::new(r#"(?m)^\s*\{\s*"tool"\s*:"#).unwrap();
let test_cases = vec![
(
r#"line
{"tool":"#,
true,
),
(
r#"line
{"tool" :"#,
true,
),
(
r#"line
{ "tool":"#,
true,
), // Space after { DOES match with \s*
(
r#"line
{"tool123":"#,
false,
), // "tool123" is not exactly "tool"
(
r#"line
{"tool" : "#,
true,
),
];
for (input, should_match) in test_cases {
let matches = pattern.is_match(input);
assert_eq!(
matches, should_match,
"Pattern matching failed for: {}",
input
);
}
}
/// Test that tool calls must appear at the start of a line (after newline).
#[test]
fn test_newline_requirement() {
reset_fixed_json_tool_state();
// According to spec, a tool call should be detected "on the very next newline".
// The (?m)^ anchor means the pattern only matches a '{' at the start of a line,
// so the no-newline variant below must pass through unchanged.
let input_with_newline = "Text\n{\"tool\": \"shell\", \"args\": {\"command\": \"ls\"}}";
let input_without_newline = "Text {\"tool\": \"shell\", \"args\": {\"command\": \"ls\"}}";
let result1 = fixed_filter_json_tool_calls(input_with_newline);
reset_fixed_json_tool_state();
let result2 = fixed_filter_json_tool_calls(input_without_newline);
// With the new aggressive filtering, only the newline case should trigger suppression
// The pattern requires { to be at the start of a line (after ^)
assert_eq!(result1, "Text\n");
// Without newline before {, it should pass through unchanged
assert_eq!(result2, input_without_newline);
}
/// Test handling of escaped quotes within JSON strings.
#[test]
fn test_json_with_escaped_quotes() {
reset_fixed_json_tool_state();
let input = r#"Text
{"tool": "write_file", "args": {"content": "He said \"hello\" to me"}}
More text"#;
let result = fixed_filter_json_tool_calls(input);
let expected = "Text\n\nMore text";
assert_eq!(result, expected);
}
/// Test graceful handling of incomplete/malformed JSON.
#[test]
fn test_edge_case_malformed_json() {
reset_fixed_json_tool_state();
// Test what happens with malformed JSON that starts like a tool call
let input = r#"Text
{"tool": "shell", "args": {"command": "ls"
More text"#;
let result = fixed_filter_json_tool_calls(input);
// Should handle gracefully - since JSON is incomplete, it should return content before JSON
let expected = "Text\n";
assert_eq!(result, expected);
}
/// Test processing multiple independent tool calls sequentially.
#[test]
fn test_multiple_tool_calls_sequential() {
reset_fixed_json_tool_state();
// Test processing multiple tool calls one at a time
let input1 = r#"First text
{"tool": "shell", "args": {"command": "ls"}}
Middle text"#;
let result1 = fixed_filter_json_tool_calls(input1);
let expected1 = "First text\n\nMiddle text";
assert_eq!(result1, expected1);
// Reset and process second tool call
reset_fixed_json_tool_state();
let input2 = r#"More text
{"tool": "read_file", "args": {"file_path": "test.txt"}}
Final text"#;
let result2 = fixed_filter_json_tool_calls(input2);
let expected2 = "More text\n\nFinal text";
assert_eq!(result2, expected2);
}
/// Test tool calls with complex multi-line arguments.
#[test]
fn test_tool_call_with_complex_args() {
reset_fixed_json_tool_state();
let input = r#"Before
{"tool": "str_replace", "args": {"file_path": "test.rs", "diff": "--- old\n-old line\n+++ new\n+new line", "start": 0, "end": 100}}
After"#;
let result = fixed_filter_json_tool_calls(input);
let expected = "Before\n\nAfter";
assert_eq!(result, expected);
}
/// Test input containing only a tool call with no surrounding text.
#[test]
fn test_tool_call_only() {
reset_fixed_json_tool_state();
let input = r#"
{"tool": "final_output", "args": {"summary": "Task completed successfully"}}"#;
let result = fixed_filter_json_tool_calls(input);
let expected = "\n";
assert_eq!(result, expected);
}
/// Test accurate brace counting with deeply nested structures.
#[test]
fn test_brace_counting_accuracy() {
reset_fixed_json_tool_state();
// Test complex nested structure
let input = r#"Start
{"tool": "write_file", "args": {"content": "function() { return {a: 1, b: {c: 2}}; }", "file_path": "test.js"}}
End"#;
let result = fixed_filter_json_tool_calls(input);
let expected = "Start\n\nEnd";
assert_eq!(result, expected);
}
/// Test that braces within strings don't affect brace counting.
#[test]
fn test_string_escaping_in_json() {
reset_fixed_json_tool_state();
// Test JSON with escaped quotes and braces in strings
let input = r#"Text
{"tool": "shell", "args": {"command": "echo \"Hello {world}\" > file.txt"}}
More"#;
let result = fixed_filter_json_tool_calls(input);
let expected = "Text\n\nMore";
assert_eq!(result, expected);
}
/// Verify compliance with the exact specification requirements.
#[test]
fn test_specification_compliance() {
reset_fixed_json_tool_state();
// Test the exact specification requirements:
// 1. Detect start with regex '\s*{\s*"tool"\s*:\s*"' at the start of a line
// 2. Enter suppression mode and use brace counting
// 3. Elide only JSON between first '{' and last '}' (inclusive)
// 4. Return everything else
let input = "Before text\nSome more text\n{\"tool\": \"test\", \"args\": {}}\nAfter text\nMore after";
let result = fixed_filter_json_tool_calls(input);
let expected = "Before text\nSome more text\n\nAfter text\nMore after";
assert_eq!(result, expected);
}
/// Test that non-tool JSON objects are not filtered.
#[test]
fn test_no_false_positives() {
reset_fixed_json_tool_state();
// Test that we don't incorrectly identify non-tool JSON as tool calls
let input = r#"Some text
{"not_tool": "value", "other": "data"}
More text"#;
let result = fixed_filter_json_tool_calls(input);
// Should pass through unchanged since it doesn't match the tool pattern
assert_eq!(result, input);
}
/// Test patterns that look similar to tool calls but aren't exact matches.
#[test]
fn test_partial_tool_patterns() {
reset_fixed_json_tool_state();
// Test patterns that look like tool calls but aren't complete
let test_cases = vec![
"Text\n{\"too\": \"value\"}", // "too" not "tool"
"Text\n{\"tools\": \"value\"}", // "tools" not "tool"
"Text\n{\"tool\": }", // Missing value after colon
];
for input in test_cases {
reset_fixed_json_tool_state();
let result = fixed_filter_json_tool_calls(input);
// These should all pass through unchanged
assert_eq!(result, input, "Input should pass through: {}", input);
}
}
/// Test streaming with very small chunks (character-by-character).
#[test]
fn test_streaming_edge_cases() {
reset_fixed_json_tool_state();
// Test streaming with very small chunks
let chunks = vec![
"Text\n", "{", "\"", "tool", "\"", ":", " ", "\"", "test", "\"", "}", "\nAfter",
];
let mut results = Vec::new();
for chunk in chunks {
let result = fixed_filter_json_tool_calls(chunk);
results.push(result);
}
let final_result: String = results.join("");
// With the new aggressive filtering, the JSON should be completely filtered out
// even when it arrives in very small chunks
let expected = "Text\n\nAfter";
assert_eq!(final_result, expected);
}
/// Debug test with detailed logging for streaming behavior.
#[test]
fn test_streaming_debug() {
reset_fixed_json_tool_state();
// Debug the exact failing case
let chunks = vec![
"Some text before\n",
"{\"tool\": \"",
"shell\", \"args\": {",
"\"command\": \"ls\"",
"}}\nText after",
];
let mut results = Vec::new();
for (i, chunk) in chunks.iter().enumerate() {
let result = fixed_filter_json_tool_calls(chunk);
println!("Chunk {}: {:?} -> {:?}", i, chunk, result);
results.push(result);
}
let final_result: String = results.join("");
println!("Final result: {:?}", final_result);
println!("Expected: {:?}", "Some text before\n\nText after");
let expected = "Some text before\n\nText after";
assert_eq!(final_result, expected);
}
/// Test handling of truncated JSON followed by complete JSON (the json_err pattern)
#[test]
fn test_truncated_then_complete_json() {
reset_fixed_json_tool_state();
// Simulate the pattern from json_err trace:
// 1. Incomplete/truncated JSON appears
// 2. Then the same complete JSON appears
let chunks = vec![
"Some text\n",
r#"{"tool": "str_replace", "args": {"diff":"...","file_path":"./crates/g3-cli"#, // Truncated
r#"{"tool": "str_replace", "args": {"diff":"...","file_path":"./crates/g3-cli/src/lib.rs"}}"#, // Complete
"\nMore text",
];
let mut results = Vec::new();
for (i, chunk) in chunks.iter().enumerate() {
let result = fixed_filter_json_tool_calls(chunk);
println!("Chunk {}: {:?} -> {:?}", i, chunk, result);
results.push(result);
}
let final_result: String = results.join("");
println!("Final result: {:?}", final_result);
// The truncated JSON should be discarded when the complete one appears
// Both JSONs should be filtered out, leaving only the text
let expected = "Some text\n\nMore text";
assert_eq!(
final_result, expected,
"Failed to handle truncated JSON followed by complete JSON"
);
}
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,184 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
/// Represents a G3 project with workspace configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Project {
/// The workspace directory for the project
pub workspace_dir: PathBuf,
/// Path to the requirements document (for autonomous mode)
pub requirements_path: Option<PathBuf>,
/// Override requirements text (takes precedence over requirements_path)
pub requirements_text: Option<String>,
/// Whether the project is in autonomous mode
pub autonomous: bool,
/// Project name (derived from workspace directory name)
pub name: String,
/// Session ID for tracking
pub session_id: Option<String>,
}
impl Project {
/// Create a new project with the given workspace directory
pub fn new(workspace_dir: PathBuf) -> Self {
let name = workspace_dir
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("unnamed")
.to_string();
Self {
workspace_dir,
requirements_path: None,
requirements_text: None,
autonomous: false,
name,
session_id: None,
}
}
/// Create a project for autonomous mode
pub fn new_autonomous(workspace_dir: PathBuf) -> Result<Self> {
let mut project = Self::new(workspace_dir.clone());
project.autonomous = true;
// Look for requirements.md in the workspace directory
let requirements_path = workspace_dir.join("requirements.md");
if requirements_path.exists() {
project.requirements_path = Some(requirements_path);
}
Ok(project)
}
/// Create a project for autonomous mode with requirements text override
pub fn new_autonomous_with_requirements(workspace_dir: PathBuf, requirements_text: String) -> Result<Self> {
let mut project = Self::new(workspace_dir.clone());
project.autonomous = true;
project.requirements_text = Some(requirements_text);
// Don't look for requirements.md file when text is provided
// The text override takes precedence
Ok(project)
}
/// Set the workspace directory and update related paths
pub fn set_workspace(&mut self, workspace_dir: PathBuf) {
self.workspace_dir = workspace_dir.clone();
self.name = workspace_dir
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("unnamed")
.to_string();
// Update requirements path if in autonomous mode
if self.autonomous {
let requirements_path = workspace_dir.join("requirements.md");
if requirements_path.exists() {
self.requirements_path = Some(requirements_path);
}
}
}
/// Get the workspace directory
pub fn workspace(&self) -> &Path {
&self.workspace_dir
}
/// Check if requirements file exists
pub fn has_requirements(&self) -> bool {
// Has requirements if either text override is provided or requirements file exists
self.requirements_text.is_some() || self.requirements_path.is_some()
}
/// Check if implementation files exist in the workspace
pub fn has_implementation_files(&self) -> bool {
self.check_dir_for_implementation_files(&self.workspace_dir)
}
/// Recursively check a directory for implementation files
#[allow(clippy::only_used_in_recursion)]
fn check_dir_for_implementation_files(&self, dir: &Path) -> bool {
// Common source file extensions
let extensions = vec![
"swift", "rs", "py", "js", "ts", "java", "cpp", "c",
"go", "rb", "php", "cs", "kt", "scala", "m", "h"
];
if let Ok(entries) = std::fs::read_dir(dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.is_file() {
// Check if it's a source file
if let Some(ext) = path.extension() {
if let Some(ext_str) = ext.to_str() {
if extensions.contains(&ext_str) {
return true;
}
}
}
} else if path.is_dir() {
// Skip hidden directories and common non-source directories
if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
if !name.starts_with('.') && name != "logs" && name != "target" && name != "node_modules" {
// Recursively check subdirectories
if self.check_dir_for_implementation_files(&path) {
return true;
}
}
}
}
}
}
false
}
/// Read the requirements file content
pub fn read_requirements(&self) -> Result<Option<String>> {
// Prioritize requirements text override
if let Some(ref text) = self.requirements_text {
Ok(Some(text.clone()))
} else if let Some(ref path) = self.requirements_path {
// Fall back to reading from file
Ok(Some(std::fs::read_to_string(path)?))
} else {
Ok(None)
}
}
/// Create the workspace directory if it doesn't exist
pub fn ensure_workspace_exists(&self) -> Result<()> {
if !self.workspace_dir.exists() {
std::fs::create_dir_all(&self.workspace_dir)?;
}
Ok(())
}
/// Change to the workspace directory
pub fn enter_workspace(&self) -> Result<()> {
std::env::set_current_dir(&self.workspace_dir)?;
Ok(())
}
/// Get the logs directory for the project
pub fn logs_dir(&self) -> PathBuf {
self.workspace_dir.join("logs")
}
/// Ensure the logs directory exists
pub fn ensure_logs_dir(&self) -> Result<()> {
let logs_dir = self.logs_dir();
if !logs_dir.exists() {
std::fs::create_dir_all(&logs_dir)?;
}
Ok(())
}
}
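// Example (illustrative): typical autonomous-mode setup; the path is a placeholder.
//
//     let project = Project::new_autonomous(PathBuf::from("/tmp/my-app"))?;
//     project.ensure_workspace_exists()?;
//     project.ensure_logs_dir()?;
//     if let Some(reqs) = project.read_requirements()? {
//         println!("loaded {} bytes of requirements", reqs.len());
//     }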

View File

@@ -1 +0,0 @@

View File

@@ -0,0 +1,37 @@
// Tests for the take_screenshot tool's window_id argument handling
#[cfg(test)]
mod take_screenshot_tests {
use super::*;
use serde_json::json;
#[test]
fn test_take_screenshot_requires_window_id() {
// Create a tool call without window_id
let tool_call = ToolCall {
tool: "take_screenshot".to_string(),
args: json!({
"path": "test.png"
}),
};
// Verify that window_id is missing
assert!(tool_call.args.get("window_id").is_none());
}
#[test]
fn test_take_screenshot_with_window_id() {
// Create a tool call with window_id
let tool_call = ToolCall {
tool: "take_screenshot".to_string(),
args: json!({
"path": "test.png",
"window_id": "Safari"
}),
};
// Verify that window_id is present
assert!(tool_call.args.get("window_id").is_some());
assert_eq!(tool_call.args.get("window_id").unwrap().as_str().unwrap(), "Safari");
}
}

View File

@@ -0,0 +1,168 @@
use crate::ContextWindow;
/// Result of a task execution containing both the response and the context window
#[derive(Debug, Clone)]
pub struct TaskResult {
/// The actual response content from the task execution
pub response: String,
/// The complete context window at the time of completion
pub context_window: ContextWindow,
}
impl TaskResult {
pub fn new(response: String, context_window: ContextWindow) -> Self {
Self {
response,
context_window,
}
}
/// Extract the final_output content from the response (for coach feedback in autonomous mode)
/// This looks for the complete final_output content, not just the last block
pub fn extract_final_output(&self) -> String {
// Remove any timing information at the end
let content_without_timing = if let Some(timing_pos) = self.response.rfind("\n⏱️") {
&self.response[..timing_pos]
} else {
&self.response
};
// Look for the final_output marker pattern
// The final_output content typically appears after the tool is called
// and is the substantive content that follows
// First, try to find if there's a clear final_output section
// This would be the content after the last tool execution
if let Some(final_output_pos) = content_without_timing.rfind("final_output") {
// Find the content that follows the final_output call
// Skip past the tool call line and any immediate formatting
if let Some(content_start) = content_without_timing[final_output_pos..].find('\n') {
let start_pos = final_output_pos + content_start + 1;
let final_content = &content_without_timing[start_pos..];
// Trim and return the complete content
let trimmed = final_content.trim();
if !trimmed.is_empty() {
return trimmed.to_string();
}
}
}
// Fallback to the original extract_last_block behavior if we can't find final_output
// This maintains backward compatibility
self.extract_last_block()
}
/// Extract the last block from the response (for coach feedback in autonomous mode)
/// This looks for the final_output content which is the last substantial block
pub fn extract_last_block(&self) -> String {
// Remove any timing information at the end
let content_without_timing = if let Some(timing_pos) = self.response.rfind("\n⏱️") {
&self.response[..timing_pos]
} else {
&self.response
};
// Split by double newlines to find the last substantial block
let blocks: Vec<&str> = content_without_timing.split("\n\n").collect();
// Find the last non-empty block that isn't just whitespace
blocks.iter()
.rev()
.find(|block| !block.trim().is_empty())
.map(|block| block.trim().to_string())
.unwrap_or_else(|| {
// Fallback: if we can't find a clear block, take the whole thing
content_without_timing.trim().to_string()
})
}
/// Check if the response contains an approval (for autonomous mode)
pub fn is_approved(&self) -> bool {
self.extract_final_output().contains("IMPLEMENTATION_APPROVED")
}
}
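// Example (illustrative): how an autonomous coach/player loop might consume a
// TaskResult; the loop itself lives elsewhere and `task_result` is a placeholder.
//
//     let feedback = task_result.extract_final_output();
//     if task_result.is_approved() {
//         println!("coach approved: {}", feedback);
//     } else {
//         println!("coach feedback for next iteration: {}", feedback);
//     }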
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_extract_last_block() {
// Test case 1: Response with timing info
let context_window = ContextWindow::new(1000);
let response_with_timing = "Some initial content\n\nFinal block content\n\n⏱️ 2.3s | 💭 1.2s".to_string();
let result = TaskResult::new(response_with_timing, context_window.clone());
assert_eq!(result.extract_last_block(), "Final block content");
// Test case 2: Response without timing
let response_no_timing = "Some initial content\n\nFinal block content".to_string();
let result = TaskResult::new(response_no_timing, context_window.clone());
assert_eq!(result.extract_last_block(), "Final block content");
// Test case 3: Response with IMPLEMENTATION_APPROVED
let response_approved = "Some content\n\nIMPLEMENTATION_APPROVED".to_string();
let result = TaskResult::new(response_approved, context_window.clone());
assert!(result.is_approved());
// Test case 4: Response without approval
let response_not_approved = "Some content\n\nNeeds more work".to_string();
let result = TaskResult::new(response_not_approved, context_window);
assert!(!result.is_approved());
}
#[test]
fn test_extract_last_block_edge_cases() {
let context_window = ContextWindow::new(1000);
// Test empty response
let empty_response = "".to_string();
let result = TaskResult::new(empty_response, context_window.clone());
assert_eq!(result.extract_last_block(), "");
// Test single block
let single_block = "Just one block".to_string();
let result = TaskResult::new(single_block, context_window.clone());
assert_eq!(result.extract_last_block(), "Just one block");
// Test multiple empty blocks
let multiple_empty = "\n\n\n\nSome content\n\n\n\n".to_string();
let result = TaskResult::new(multiple_empty, context_window);
assert_eq!(result.extract_last_block(), "Some content");
}
#[test]
fn test_extract_final_output() {
let context_window = ContextWindow::new(1000);
// Test case 1: Response with final_output tool call
let response_with_final_output = "Analyzing files...\n\nCalling final_output\n\nThis is the complete feedback\nwith multiple lines\nand important details\n\n⏱️ 2.3s".to_string();
let result = TaskResult::new(response_with_final_output, context_window.clone());
assert_eq!(result.extract_final_output(), "This is the complete feedback\nwith multiple lines\nand important details");
// Test case 2: Response with IMPLEMENTATION_APPROVED in final_output
let response_approved = "Review complete\n\nfinal_output called\n\nIMPLEMENTATION_APPROVED".to_string();
let result = TaskResult::new(response_approved, context_window.clone());
assert_eq!(result.extract_final_output(), "IMPLEMENTATION_APPROVED");
assert!(result.is_approved());
// Test case 3: Response with detailed feedback in final_output
let response_feedback = "Checking implementation...\n\nfinal_output\n\nThe following issues need to be addressed:\n1. Missing error handling in main.rs\n2. Tests are not comprehensive\n3. Documentation needs improvement\n\nPlease fix these issues.".to_string();
let result = TaskResult::new(response_feedback, context_window.clone());
let extracted = result.extract_final_output();
assert!(extracted.contains("The following issues need to be addressed:"));
assert!(extracted.contains("1. Missing error handling"));
assert!(extracted.contains("Please fix these issues."));
assert!(!result.is_approved());
// Test case 4: Response without final_output (fallback to extract_last_block)
let response_no_final_output = "Some analysis\n\nFinal thoughts here".to_string();
let result = TaskResult::new(response_no_final_output, context_window.clone());
assert_eq!(result.extract_final_output(), "Final thoughts here");
// Test case 5: Empty response
let empty_response = "".to_string();
let result = TaskResult::new(empty_response, context_window);
assert_eq!(result.extract_final_output(), "");
}
}

View File

@@ -0,0 +1,247 @@
use crate::{ContextWindow, TaskResult};
use g3_providers::{Message, MessageRole};
use std::sync::Arc;
#[test]
fn test_task_result_basic_functionality() {
// Create a context window with some messages
let mut context = ContextWindow::new(10000);
context.add_message(Message {
role: MessageRole::User,
content: "Test message 1".to_string(),
});
context.add_message(Message {
role: MessageRole::Assistant,
content: "Response 1".to_string(),
});
// Create a TaskResult
let response = "This is the response\n\nFinal output block".to_string();
let result = TaskResult::new(response.clone(), context.clone());
// Test basic properties
assert_eq!(result.response, response);
assert_eq!(result.context_window.conversation_history.len(), 2);
assert_eq!(result.context_window.total_tokens, 10000);
}
#[test]
fn test_extract_last_block_various_formats() {
let context = ContextWindow::new(1000);
// Test 1: Standard format with multiple blocks
let response1 = "First block\n\nSecond block\n\nThird block".to_string();
let result1 = TaskResult::new(response1, context.clone());
assert_eq!(result1.extract_last_block(), "Third block");
// Test 2: With timing information
let response2 = "Content\n\nFinal block\n\n⏱️ 2.3s | 💭 1.2s".to_string();
let result2 = TaskResult::new(response2, context.clone());
assert_eq!(result2.extract_last_block(), "Final block");
// Test 3: Single line response
let response3 = "Single line response".to_string();
let result3 = TaskResult::new(response3, context.clone());
assert_eq!(result3.extract_last_block(), "Single line response");
// Test 4: Empty response
let response4 = "".to_string();
let result4 = TaskResult::new(response4, context.clone());
assert_eq!(result4.extract_last_block(), "");
// Test 5: Only whitespace
let response5 = "\n\n\n \n\n".to_string();
let result5 = TaskResult::new(response5, context.clone());
assert_eq!(result5.extract_last_block(), "");
// Test 6: Multiple blocks with empty ones
let response6 = "First\n\n\n\n\n\nLast block here".to_string();
let result6 = TaskResult::new(response6, context.clone());
assert_eq!(result6.extract_last_block(), "Last block here");
}
#[test]
fn test_is_approved_detection() {
let context = ContextWindow::new(1000);
// Test approved cases
let approved_responses = vec![
"Analysis complete\n\nIMPLEMENTATION_APPROVED",
"Some content\n\nThe implementation is good. IMPLEMENTATION_APPROVED",
"IMPLEMENTATION_APPROVED",
"Review done\n\n✅ IMPLEMENTATION_APPROVED - All tests pass",
];
for response in approved_responses {
let result = TaskResult::new(response.to_string(), context.clone());
assert!(result.is_approved(), "Failed to detect approval in: {}", response);
}
// Test not approved cases
let not_approved_responses = vec![
"Needs more work",
"Implementation needs fixes",
"IMPLEMENTATION_REJECTED",
"Almost there but not APPROVED",
"",
];
for response in not_approved_responses {
let result = TaskResult::new(response.to_string(), context.clone());
assert!(!result.is_approved(), "Incorrectly detected approval in: {}", response);
}
}
#[test]
fn test_context_window_preservation() {
// Create a context window with specific state
let mut context = ContextWindow::new(5000);
context.used_tokens = 1234;
// Add some messages
for i in 0..5 {
context.add_message(Message {
role: if i % 2 == 0 { MessageRole::User } else { MessageRole::Assistant },
content: format!("Message {}", i),
});
}
// Create TaskResult
let result = TaskResult::new("Response".to_string(), context.clone());
// Verify context is preserved
assert_eq!(result.context_window.total_tokens, 5000);
assert!(result.context_window.used_tokens > 1234); // Should have increased
assert_eq!(result.context_window.conversation_history.len(), 5);
// Verify messages are preserved correctly
for i in 0..5 {
let is_user = matches!(result.context_window.conversation_history[i].role, MessageRole::User);
let expected_is_user = i % 2 == 0;
assert_eq!(is_user, expected_is_user, "Message {} has wrong role", i);
assert_eq!(result.context_window.conversation_history[i].content, format!("Message {}", i));
}
}
#[test]
fn test_coach_feedback_extraction_scenarios() {
let context = ContextWindow::new(1000);
// Scenario 1: Coach feedback with file operations and analysis
let coach_response = r#"Reading file: src/main.rs
📄 File content (23 lines):
fn main() {
println!("Hello");
}
Analyzing implementation...
The implementation needs the following fixes:
1. Add error handling
2. Implement missing functions
3. Add tests"#;
let result = TaskResult::new(coach_response.to_string(), context.clone());
let feedback = result.extract_last_block();
assert!(feedback.contains("Add error handling"));
assert!(feedback.contains("Implement missing functions"));
assert!(feedback.contains("Add tests"));
// Scenario 2: Coach approval
let approval_response = r#"Checking compilation...
✅ Build successful
Running tests...
✅ All tests pass
IMPLEMENTATION_APPROVED"#;
let result = TaskResult::new(approval_response.to_string(), context.clone());
assert!(result.is_approved());
assert_eq!(result.extract_last_block(), "IMPLEMENTATION_APPROVED");
// Scenario 3: Complex feedback with timing
let complex_response = r#"Tool execution log...
Analysis complete.
The following issues were found:
- Memory leak in process_data()
- Missing input validation
⏱️ 5.2s | 💭 2.1s"#;
let result = TaskResult::new(complex_response.to_string(), context.clone());
let feedback = result.extract_last_block();
assert!(feedback.contains("Memory leak"));
assert!(feedback.contains("Missing input validation"));
assert!(!feedback.contains("⏱️")); // Timing should be stripped
}
#[test]
fn test_edge_cases_and_special_characters() {
let context = ContextWindow::new(1000);
// Test with special characters and emojis
let response_with_emojis = "First part 🚀\n\n✅ Final part with emojis 🎉".to_string();
let result = TaskResult::new(response_with_emojis, context.clone());
assert_eq!(result.extract_last_block(), "✅ Final part with emojis 🎉");
// Test with code blocks
let response_with_code = "Explanation\n\n```rust\nfn main() {}\n```\n\nFinal comment".to_string();
let result = TaskResult::new(response_with_code, context.clone());
assert_eq!(result.extract_last_block(), "Final comment");
// Test with mixed newlines
let mixed_newlines = "Part 1\r\n\r\nPart 2\n\nPart 3".to_string();
let result = TaskResult::new(mixed_newlines, context.clone());
assert_eq!(result.extract_last_block(), "Part 3");
}
#[test]
fn test_large_response_handling() {
let context = ContextWindow::new(100000);
// Create a large response
let mut large_response = String::new();
for i in 0..100 {
large_response.push_str(&format!("Block {} with some content\n\n", i));
}
large_response.push_str("This is the final block after 100 other blocks");
let result = TaskResult::new(large_response, context);
assert_eq!(result.extract_last_block(), "This is the final block after 100 other blocks");
}
#[test]
fn test_concurrent_access() {
use std::thread;
let context = ContextWindow::new(1000);
let result = Arc::new(TaskResult::new(
"Concurrent test\n\nFinal block".to_string(),
context,
));
let mut handles = vec![];
// Spawn multiple threads to access the TaskResult
for _ in 0..10 {
let result_clone = Arc::clone(&result);
let handle = thread::spawn(move || {
// Each thread extracts the last block
let block = result_clone.extract_last_block();
assert_eq!(block, "Final block");
// Check approval status
assert!(!result_clone.is_approved());
});
handles.push(handle);
}
// Wait for all threads to complete
for handle in handles {
handle.join().unwrap();
}
}

View File

@@ -0,0 +1,48 @@
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_extract_last_block() {
// Test case 1: Response with timing info
let context_window = ContextWindow::new(1000);
let response_with_timing = "Some initial content\n\nFinal block content\n\n⏱️ 2.3s | 💭 1.2s".to_string();
let result = TaskResult::new(response_with_timing, context_window.clone());
assert_eq!(result.extract_last_block(), "Final block content");
// Test case 2: Response without timing
let response_no_timing = "Some initial content\n\nFinal block content".to_string();
let result = TaskResult::new(response_no_timing, context_window.clone());
assert_eq!(result.extract_last_block(), "Final block content");
// Test case 3: Response with IMPLEMENTATION_APPROVED
let response_approved = "Some content\n\nIMPLEMENTATION_APPROVED".to_string();
let result = TaskResult::new(response_approved, context_window.clone());
assert!(result.is_approved());
// Test case 4: Response without approval
let response_not_approved = "Some content\n\nNeeds more work".to_string();
let result = TaskResult::new(response_not_approved, context_window);
assert!(!result.is_approved());
}
#[test]
fn test_extract_last_block_edge_cases() {
let context_window = ContextWindow::new(1000);
// Test empty response
let empty_response = "".to_string();
let result = TaskResult::new(empty_response, context_window.clone());
assert_eq!(result.extract_last_block(), "");
// Test single block
let single_block = "Just one block".to_string();
let result = TaskResult::new(single_block, context_window.clone());
assert_eq!(result.extract_last_block(), "Just one block");
// Test multiple empty blocks
let multiple_empty = "\n\n\n\nSome content\n\n\n\n".to_string();
let result = TaskResult::new(multiple_empty, context_window);
assert_eq!(result.extract_last_block(), "Some content");
}
}

View File

@@ -0,0 +1,36 @@
#[cfg(test)]
mod tilde_expansion_tests {
use std::env;
#[test]
fn test_tilde_expansion() {
// Test that shellexpand works
let path_with_tilde = "~/test.txt";
let expanded = shellexpand::tilde(path_with_tilde);
// Get the actual home directory
let home = env::var("HOME").expect("HOME environment variable not set");
// Verify expansion happened
assert_eq!(expanded.as_ref(), format!("{}/test.txt", home));
assert!(!expanded.contains("~"));
}
#[test]
fn test_tilde_expansion_with_subdirs() {
let path_with_tilde = "~/Documents/test.txt";
let expanded = shellexpand::tilde(path_with_tilde);
let home = env::var("HOME").expect("HOME environment variable not set");
assert_eq!(expanded.as_ref(), format!("{}/Documents/test.txt", home));
}
#[test]
fn test_no_tilde_unchanged() {
let path_without_tilde = "/absolute/path/test.txt";
let expanded = shellexpand::tilde(path_without_tilde);
assert_eq!(expanded.as_ref(), path_without_tilde);
}
}

View File

@@ -0,0 +1,83 @@
/// Interface for UI output operations
/// This trait abstracts all UI operations to allow different implementations
/// (console, TUI, web, etc.) without coupling the core logic to specific output methods.
pub trait UiWriter: Send + Sync {
/// Print a simple message
fn print(&self, message: &str);
/// Print a message with a newline
fn println(&self, message: &str);
/// Print without newline (for progress indicators)
fn print_inline(&self, message: &str);
/// Print a system prompt section
fn print_system_prompt(&self, prompt: &str);
/// Print a context window status message
fn print_context_status(&self, message: &str);
/// Print a context thinning success message with highlight and animation
fn print_context_thinning(&self, message: &str);
/// Print a tool execution header
fn print_tool_header(&self, tool_name: &str);
/// Print a tool argument
fn print_tool_arg(&self, key: &str, value: &str);
/// Print tool output header
fn print_tool_output_header(&self);
/// Update the current tool output line (replaces previous line)
fn update_tool_output_line(&self, line: &str);
/// Print a tool output line
fn print_tool_output_line(&self, line: &str);
/// Print tool output summary (when output is truncated)
fn print_tool_output_summary(&self, hidden_count: usize);
/// Print tool execution timing
fn print_tool_timing(&self, duration_str: &str);
/// Print the agent prompt indicator
fn print_agent_prompt(&self);
/// Print agent response inline (for streaming)
fn print_agent_response(&self, content: &str);
/// Notify that an SSE event was received (including pings)
fn notify_sse_received(&self);
/// Flush any buffered output
fn flush(&self);
/// Returns true if this UI writer wants full, untruncated output
/// Default is false (truncate for human readability)
fn wants_full_output(&self) -> bool { false }
}
/// A no-op implementation for when UI output is not needed
pub struct NullUiWriter;
impl UiWriter for NullUiWriter {
fn print(&self, _message: &str) {}
fn println(&self, _message: &str) {}
fn print_inline(&self, _message: &str) {}
fn print_system_prompt(&self, _prompt: &str) {}
fn print_context_status(&self, _message: &str) {}
fn print_context_thinning(&self, _message: &str) {}
fn print_tool_header(&self, _tool_name: &str) {}
fn print_tool_arg(&self, _key: &str, _value: &str) {}
fn print_tool_output_header(&self) {}
fn update_tool_output_line(&self, _line: &str) {}
fn print_tool_output_line(&self, _line: &str) {}
fn print_tool_output_summary(&self, _hidden_count: usize) {}
fn print_tool_timing(&self, _duration_str: &str) {}
fn print_agent_prompt(&self) {}
fn print_agent_response(&self, _content: &str) {}
fn notify_sse_received(&self) {}
fn flush(&self) {}
fn wants_full_output(&self) -> bool { false }
}
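To show how a second implementation slots in behind the same trait, here is a minimal console-backed writer. It is a sketch only: the real CLI writer adds styling, truncation, and animations, and the formatting prefixes below are invented for illustration.

use std::io::Write;

pub struct ConsoleUiWriter;

impl UiWriter for ConsoleUiWriter {
    fn print(&self, message: &str) { print!("{}", message); }
    fn println(&self, message: &str) { println!("{}", message); }
    fn print_inline(&self, message: &str) { print!("{}", message); self.flush(); }
    fn print_system_prompt(&self, prompt: &str) { println!("[system]\n{}", prompt); }
    fn print_context_status(&self, message: &str) { println!("[context] {}", message); }
    fn print_context_thinning(&self, message: &str) { println!("[thinning] {}", message); }
    fn print_tool_header(&self, tool_name: &str) { println!("→ {}", tool_name); }
    fn print_tool_arg(&self, key: &str, value: &str) { println!("  {} = {}", key, value); }
    fn print_tool_output_header(&self) { println!("  output:"); }
    fn update_tool_output_line(&self, line: &str) { print!("\r  {}", line); self.flush(); }
    fn print_tool_output_line(&self, line: &str) { println!("  {}", line); }
    fn print_tool_output_summary(&self, hidden_count: usize) { println!("  … {} more lines", hidden_count); }
    fn print_tool_timing(&self, duration_str: &str) { println!("  took {}", duration_str); }
    fn print_agent_prompt(&self) { print!("> "); self.flush(); }
    fn print_agent_response(&self, content: &str) { print!("{}", content); self.flush(); }
    fn notify_sse_received(&self) {}
    fn flush(&self) { let _ = std::io::stdout().flush(); }
}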

View File

@@ -0,0 +1,270 @@
use g3_core::ContextWindow;
use g3_providers::{Message, MessageRole};
#[test]
fn test_thinning_thresholds() {
let mut context = ContextWindow::new(10000);
// At 0%, should not thin
assert!(!context.should_thin());
// Simulate reaching 50% usage
context.used_tokens = 5000;
assert!(context.should_thin());
// After thinning at 50%, should not thin again until next threshold
context.last_thinning_percentage = 50;
assert!(!context.should_thin());
// At 60%, should thin again
context.used_tokens = 6000;
assert!(context.should_thin());
// After thinning at 60%, should not thin
context.last_thinning_percentage = 60;
assert!(!context.should_thin());
// At 70%, should thin
context.used_tokens = 7000;
assert!(context.should_thin());
// At 80%, should thin
context.last_thinning_percentage = 70;
context.used_tokens = 8000;
assert!(context.should_thin());
// After 80%, should not thin (compaction takes over)
context.last_thinning_percentage = 80;
context.used_tokens = 8500;
assert!(!context.should_thin());
}
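The assertions above encode a simple stepping rule: thin once at each new 10% step from 50% through 80%, after which compaction takes over. A free-function sketch of that rule, with field names taken from the test (the real method lives on ContextWindow and reads its own fields):

fn should_thin(used_tokens: u64, max_tokens: u64, last_thinning_percentage: u32) -> bool {
    let pct = (used_tokens * 100 / max_tokens) as u32;
    // Thin at 50/60/70/80%, once per step; beyond 80% compaction handles it.
    (50..=80).contains(&pct) && (pct / 10) * 10 > last_thinning_percentage
}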
#[test]
fn test_thin_context_basic() {
let mut context = ContextWindow::new(10000);
// Add some messages to the first third
for i in 0..9 {
if i % 2 == 0 {
context.add_message(Message {
role: MessageRole::Assistant,
content: format!("Assistant message {}", i),
});
} else {
// Add tool results with varying sizes
let content = if i == 1 {
// Large tool result (> 1000 chars)
format!("Tool result: {}", "x".repeat(1500))
} else if i == 3 {
// Another large tool result
format!("Tool result: {}", "y".repeat(2000))
} else {
// Small tool result (< 1000 chars)
format!("Tool result: small result {}", i)
};
context.add_message(Message {
role: MessageRole::User,
content,
});
}
}
// Trigger thinning at 50%
context.used_tokens = 5000;
let (summary, _chars_saved) = context.thin_context();
println!("Thinning summary: {}", summary);
// Should have thinned at least 1 large tool result in the first third
assert!(summary.contains("1 tool result"), "Summary was: {}", summary);
assert!(summary.contains("50%"));
// Check that the large tool results were replaced
let first_third_end = context.conversation_history.len() / 3;
for i in 0..first_third_end {
if let Some(msg) = context.conversation_history.get(i) {
if matches!(msg.role, MessageRole::User) && msg.content.starts_with("Tool result:") {
if msg.content.len() > 1000 {
panic!("Found un-thinned large tool result at index {}", i);
}
}
}
}
}
#[test]
fn test_thin_write_file_tool_calls() {
let mut context = ContextWindow::new(10000);
// Add some messages including a write_file tool call with large content
context.add_message(Message {
role: MessageRole::User,
content: "Please create a large file".to_string(),
});
// Add an assistant message with a write_file tool call containing large content
let large_content = "x".repeat(1500);
let tool_call_json = format!(
r#"{{"tool": "write_file", "args": {{"file_path": "test.txt", "content": "{}"}}}}"#,
large_content
);
context.add_message(Message {
role: MessageRole::Assistant,
content: format!("I'll create that file.\n\n{}", tool_call_json),
});
context.add_message(Message {
role: MessageRole::User,
content: "Tool result: ✅ Successfully wrote 1500 lines".to_string(),
});
// Add more messages to ensure we have enough for "first third" logic
for i in 0..6 {
context.add_message(Message {
role: MessageRole::Assistant,
content: format!("Response {}", i),
});
}
// Trigger thinning at 50%
context.used_tokens = 5000;
let (summary, _chars_saved) = context.thin_context();
println!("Thinning summary: {}", summary);
// Should have thinned the write_file tool call
assert!(summary.contains("tool call") || summary.contains("chars saved"));
// Check that the large content was replaced with a file reference
let first_third_end = context.conversation_history.len() / 3;
for i in 0..first_third_end {
if let Some(msg) = context.conversation_history.get(i) {
if matches!(msg.role, MessageRole::Assistant) && msg.content.contains("write_file") {
// The content should now reference an external file
assert!(msg.content.contains("<content saved to"));
assert!(!msg.content.contains(&large_content));
}
}
}
}
#[test]
fn test_thin_str_replace_tool_calls() {
let mut context = ContextWindow::new(10000);
// Add some messages including a str_replace tool call with large diff
context.add_message(Message {
role: MessageRole::User,
content: "Please update the file".to_string(),
});
// Add an assistant message with a str_replace tool call containing large diff
let large_diff = format!("--- old\n{}\n+++ new\n{}", "-old line\n".repeat(100), "+new line\n".repeat(100));
let tool_call_json = format!(
r#"{{"tool": "str_replace", "args": {{"file_path": "test.txt", "diff": "{}"}}}}"#,
large_diff.replace('\n', "\\n")
);
context.add_message(Message {
role: MessageRole::Assistant,
content: format!("I'll update that file.\n\n{}", tool_call_json),
});
context.add_message(Message {
role: MessageRole::User,
content: "Tool result: ✅ applied unified diff".to_string(),
});
// Add more messages to ensure we have enough for "first third" logic
for i in 0..6 {
context.add_message(Message {
role: MessageRole::Assistant,
content: format!("Response {}", i),
});
}
// Trigger thinning at 50%
context.used_tokens = 5000;
let (summary, _chars_saved) = context.thin_context();
println!("Thinning summary: {}", summary);
// Should have thinned the str_replace tool call
assert!(summary.contains("tool call") || summary.contains("chars saved"));
// Check that the large diff was replaced with a file reference
let first_third_end = context.conversation_history.len() / 3;
for i in 0..first_third_end {
if let Some(msg) = context.conversation_history.get(i) {
if matches!(msg.role, MessageRole::Assistant) && msg.content.contains("str_replace") {
// The diff should now reference an external file
assert!(msg.content.contains("<diff saved to"));
// Should not contain the large diff content
assert!(!msg.content.contains("old line"));
}
}
}
}
#[test]
fn test_thin_context_no_large_results() {
let mut context = ContextWindow::new(10000);
// Add only small messages
for i in 0..9 {
context.add_message(Message {
role: MessageRole::User,
content: format!("Tool result: small {}", i),
});
}
context.used_tokens = 5000;
let (summary, _chars_saved) = context.thin_context();
// Should report no large results found
assert!(summary.contains("no large tool results or tool calls found"));
}
#[test]
fn test_thin_context_only_affects_first_third() {
let mut context = ContextWindow::new(10000);
// Add 12 messages (first third = 4 messages)
for i in 0..12 {
let content = if i % 2 == 1 {
// All odd indices are large tool results
format!("Tool result: {}", "x".repeat(1500))
} else {
format!("Assistant message {}", i)
};
let role = if i % 2 == 1 {
MessageRole::User
} else {
MessageRole::Assistant
};
context.add_message(Message { role, content });
}
context.used_tokens = 5000;
let (summary, _chars_saved) = context.thin_context();
// First third is 4 messages (indices 0-3), so only indices 1 and 3 should be thinned
// That's 2 tool results
assert!(summary.contains("2 tool results"));
// Check that messages after the first third are NOT thinned
let first_third_end = context.conversation_history.len() / 3;
for i in first_third_end..context.conversation_history.len() {
if let Some(msg) = context.conversation_history.get(i) {
if matches!(msg.role, MessageRole::User) && msg.content.starts_with("Tool result:") {
// These should still be large (not thinned)
if i % 2 == 1 {
assert!(msg.content.len() > 1000,
"Message at index {} should not have been thinned", i);
}
}
}
}
}
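Taken together, these tests pin down the selection rule: only messages in the first third of conversation_history are candidates, and only tool results (or assistant messages embedding write_file/str_replace tool calls) larger than roughly 1,000 characters get replaced. A hedged sketch of that predicate (the real thin_context also rewrites the content and tallies the characters saved):

fn is_thinning_candidate(index: usize, history_len: usize, msg: &Message) -> bool {
    let in_first_third = index < history_len / 3;
    let large_tool_result = matches!(msg.role, MessageRole::User)
        && msg.content.starts_with("Tool result:")
        && msg.content.len() > 1000;
    let large_tool_call = matches!(msg.role, MessageRole::Assistant)
        && (msg.content.contains("write_file") || msg.content.contains("str_replace"))
        && msg.content.len() > 1000;
    in_first_third && (large_tool_result || large_tool_call)
}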

View File

@@ -0,0 +1,94 @@
use g3_core::ContextWindow;
use g3_providers::Usage;
#[test]
fn test_token_accumulation() {
let mut window = ContextWindow::new(10000);
// First API call: 100 prompt + 50 completion = 150 total
let usage1 = Usage {
prompt_tokens: 100,
completion_tokens: 50,
total_tokens: 150,
};
window.update_usage_from_response(&usage1);
assert_eq!(window.used_tokens, 150, "First call should have 150 tokens");
assert_eq!(window.cumulative_tokens, 150, "Cumulative should be 150");
// Second API call: 200 prompt + 75 completion = 275 total
let usage2 = Usage {
prompt_tokens: 200,
completion_tokens: 75,
total_tokens: 275,
};
window.update_usage_from_response(&usage2);
assert_eq!(window.used_tokens, 425, "Second call should accumulate to 425 tokens");
assert_eq!(window.cumulative_tokens, 425, "Cumulative should be 425");
// Third API call with SMALLER token count: 50 prompt + 25 completion = 75 total
let usage3 = Usage {
prompt_tokens: 50,
completion_tokens: 25,
total_tokens: 75,
};
window.update_usage_from_response(&usage3);
assert_eq!(window.used_tokens, 500, "Third call should accumulate to 500 tokens");
assert_eq!(window.cumulative_tokens, 500, "Cumulative should be 500");
// Verify tokens never decrease
assert!(window.used_tokens >= 425, "Token count should never decrease!");
}
#[test]
fn test_add_streaming_tokens() {
let mut window = ContextWindow::new(10000);
// Add some streaming tokens
window.add_streaming_tokens(100);
assert_eq!(window.used_tokens, 100);
assert_eq!(window.cumulative_tokens, 100);
// Add more
window.add_streaming_tokens(50);
assert_eq!(window.used_tokens, 150);
assert_eq!(window.cumulative_tokens, 150);
// Now update from provider response
let usage = Usage {
prompt_tokens: 80,
completion_tokens: 40,
total_tokens: 120,
};
window.update_usage_from_response(&usage);
// Should ADD to existing, not replace
assert_eq!(window.used_tokens, 270, "Should add 120 to existing 150");
assert_eq!(window.cumulative_tokens, 270);
}
#[test]
fn test_percentage_calculation() {
let mut window = ContextWindow::new(1000);
// Add tokens via provider response
let usage = Usage {
prompt_tokens: 150,
completion_tokens: 100,
total_tokens: 250,
};
window.update_usage_from_response(&usage);
assert_eq!(window.percentage_used(), 25.0);
assert_eq!(window.remaining_tokens(), 750);
// Add more tokens
let usage2 = Usage {
prompt_tokens: 300,
completion_tokens: 200,
total_tokens: 500,
};
window.update_usage_from_response(&usage2);
assert_eq!(window.percentage_used(), 75.0);
assert_eq!(window.remaining_tokens(), 250);
}
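The three tests together imply a small, monotonic accounting model: provider usage and streaming estimates are always added to both counters, never assigned, and the percentage/remaining figures are derived from used_tokens. A standalone sketch under those assumptions (names taken from the tests; the real type is g3_core::ContextWindow):

struct Window {
    max_tokens: u64,
    used_tokens: u64,
    cumulative_tokens: u64,
}

impl Window {
    // Totals are added, never replaced, so the count stays monotonic even
    // when a later response reports fewer tokens than an earlier one.
    fn update_usage_from_response(&mut self, total_tokens: u64) {
        self.used_tokens += total_tokens;
        self.cumulative_tokens += total_tokens;
    }
    fn add_streaming_tokens(&mut self, n: u64) {
        self.used_tokens += n;
        self.cumulative_tokens += n;
    }
    fn percentage_used(&self) -> f64 {
        self.used_tokens as f64 / self.max_tokens as f64 * 100.0
    }
    fn remaining_tokens(&self) -> u64 {
        self.max_tokens.saturating_sub(self.used_tokens)
    }
}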

View File

@@ -7,6 +7,7 @@ description = "Code execution engine for G3 AI agent"
[dependencies]
tokio = { workspace = true }
anyhow = { workspace = true }
futures = "0.3"
thiserror = { workspace = true }
tracing = { workspace = true }
regex = "1.0"

View File

@@ -166,6 +166,31 @@ impl CodeExecutor {
/// Execute Bash code
async fn execute_bash(&self, code: &str) -> Result<ExecutionResult> {
// Check if this is a detached/daemon command that should run independently
let is_detached = code.trim_start().starts_with("setsid ")
|| code.trim_start().starts_with("nohup ")
|| code.contains(" disown")
|| (code.contains(" &") && (code.contains("nohup") || code.contains("setsid")));
if is_detached {
// For detached commands, just spawn and return immediately
use std::process::Stdio;
Command::new("bash")
.arg("-c")
.arg(code)
.stdin(Stdio::null())
.stdout(Stdio::null())
.stderr(Stdio::null())
.spawn()?;
return Ok(ExecutionResult {
stdout: "✅ Command launched in background (detached process)".to_string(),
stderr: String::new(),
exit_code: 0,
success: true,
});
}
let output = Command::new("bash")
.arg("-c")
.arg(code)
@@ -203,3 +228,105 @@ impl Default for CodeExecutor {
Self::new()
}
}
/// Trait for receiving streaming output from command execution
pub trait OutputReceiver: Send + Sync {
/// Called when a new line of output is available
fn on_output_line(&self, line: &str);
}
impl CodeExecutor {
/// Execute bash command with streaming output
pub async fn execute_bash_streaming<R: OutputReceiver>(
&self,
code: &str,
receiver: &R
) -> Result<ExecutionResult> {
use std::process::Stdio;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command as TokioCommand;
// Check if this is a detached/daemon command that should run independently
// Look for patterns like: setsid, nohup with &, or explicit backgrounding with disown
let is_detached = code.trim_start().starts_with("setsid ")
|| code.trim_start().starts_with("nohup ")
|| code.contains(" disown")
|| (code.contains(" &") && (code.contains("nohup") || code.contains("setsid")));
if is_detached {
// For detached commands, just spawn and return immediately
TokioCommand::new("bash")
.arg("-c")
.arg(code)
.spawn()?;
// Don't wait for the process - it's meant to run independently
return Ok(ExecutionResult {
stdout: "✅ Command launched in background (detached process)".to_string(),
stderr: String::new(),
exit_code: 0,
success: true,
});
}
let mut child = TokioCommand::new("bash")
.arg("-c")
.arg(code)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()?;
let stdout = child.stdout.take().unwrap();
let stderr = child.stderr.take().unwrap();
let stdout_reader = BufReader::new(stdout);
let stderr_reader = BufReader::new(stderr);
let mut stdout_lines = stdout_reader.lines();
let mut stderr_lines = stderr_reader.lines();
let mut stdout_output = Vec::new();
let mut stderr_output = Vec::new();
// Read output lines as they come, until both streams reach EOF.
// Tracking EOF per stream avoids a busy loop when stderr closes first,
// and keeps draining stderr after stdout has finished.
let mut stdout_done = false;
let mut stderr_done = false;
while !(stdout_done && stderr_done) {
tokio::select! {
line = stdout_lines.next_line(), if !stdout_done => {
match line {
Ok(Some(line)) => {
receiver.on_output_line(&line);
stdout_output.push(line);
}
Ok(None) => stdout_done = true, // EOF
Err(e) => {
error!("Error reading stdout: {}", e);
stdout_done = true;
}
}
}
line = stderr_lines.next_line(), if !stderr_done => {
match line {
Ok(Some(line)) => {
receiver.on_output_line(&line);
stderr_output.push(line);
}
Ok(None) => stderr_done = true, // EOF
Err(e) => {
error!("Error reading stderr: {}", e);
stderr_done = true;
}
}
}
else => break,
}
}
let status = child.wait().await?;
Ok(ExecutionResult {
stdout: stdout_output.join("\n"),
stderr: stderr_output.join("\n"),
exit_code: status.code().unwrap_or(-1),
success: status.success(),
})
}
}
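A minimal caller for the streaming path, assuming only what the trait above requires and the anyhow::Result alias used elsewhere in this crate; the StdoutReceiver name and the demo command are invented for illustration:

struct StdoutReceiver;

impl OutputReceiver for StdoutReceiver {
    fn on_output_line(&self, line: &str) {
        // Forward each line as it arrives; a real UI would route this
        // through something like UiWriter::print_tool_output_line instead.
        println!("{}", line);
    }
}

async fn run_streaming_demo() -> Result<()> {
    let executor = CodeExecutor::new();
    let result = executor
        .execute_bash_streaming("for i in 1 2 3; do echo line $i; done", &StdoutReceiver)
        .await?;
    assert!(result.success);
    Ok(())
}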

View File

@@ -16,3 +16,16 @@ async-trait = "0.1"
tokio-stream = "0.1"
futures-util = "0.3"
bytes = "1.0"
# OAuth dependencies
axum = "0.7"
base64 = "0.22"
chrono = { version = "0.4", features = ["serde"] }
sha2 = "0.10"
url = "2.5"
webbrowser = "1.0"
nanoid = "0.4"
serde_urlencoded = "0.7"
tokio-util = "0.7"
dirs = "5.0"
llama_cpp = { version = "0.3.2", features = ["metal"] }
shellexpand = "3.1"

View File

@@ -41,6 +41,7 @@
//! max_tokens: Some(1000),
//! temperature: Some(0.7),
//! stream: false,
//! tools: None,
//! };
//!
//! // Get a completion
@@ -74,6 +75,7 @@
//! max_tokens: Some(1000),
//! temperature: Some(0.7),
//! stream: true,
//! tools: None,
//! };
//!
//! let mut stream = provider.stream(request).await?;
@@ -154,8 +156,9 @@ impl AnthropicProvider {
.post(ANTHROPIC_API_URL)
.header("x-api-key", &self.api_key)
.header("anthropic-version", ANTHROPIC_VERSION)
// Anthropic beta 1M-token context window. Enable if needed; it costs extra, so check pricing first.
// .header("anthropic-beta", "context-1m-2025-08-07")
.header("content-type", "application/json");
if streaming {
builder = builder.header("accept", "text/event-stream");
}
@@ -194,20 +197,6 @@ impl AnthropicProvider {
.collect()
}
fn convert_anthropic_tool_calls(&self, content: &[AnthropicContent]) -> Vec<ToolCall> {
content
.iter()
.filter_map(|c| match c {
AnthropicContent::ToolUse { id, name, input } => Some(ToolCall {
id: id.clone(),
tool: name.clone(),
args: input.clone(),
}),
_ => None,
})
.collect()
}
fn convert_messages(&self, messages: &[Message]) -> Result<(Option<String>, Vec<AnthropicMessage>)> {
let mut system_message = None;
let mut anthropic_messages = Vec::new();
@@ -281,26 +270,43 @@ impl AnthropicProvider {
&self,
mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
tx: mpsc::Sender<Result<CompletionChunk>>,
) {
) -> Option<Usage> {
let mut buffer = String::new();
let mut current_tool_calls: Vec<ToolCall> = Vec::new();
let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
let mut accumulated_usage: Option<Usage> = None;
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
let mut message_stopped = false; // Track if we've received message_stop
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
let chunk_str = match std::str::from_utf8(&chunk) {
Ok(s) => s,
// Append new bytes to our buffer
byte_buffer.extend_from_slice(&chunk);
// Try to convert the entire buffer to UTF-8
let chunk_str = match std::str::from_utf8(&byte_buffer) {
Ok(s) => {
// Successfully converted entire buffer, clear it and use the string
let result = s.to_string();
byte_buffer.clear();
result
}
Err(e) => {
error!("Invalid UTF-8 in stream chunk: {}", e);
let _ = tx
.send(Err(anyhow!("Invalid UTF-8 in stream chunk: {}", e)))
.await;
return;
// Check if this is an incomplete sequence at the end
let valid_up_to = e.valid_up_to();
if valid_up_to > 0 {
// We have some valid UTF-8, extract it and keep the rest for next iteration
let valid_bytes = byte_buffer.drain(..valid_up_to).collect::<Vec<_>>();
std::str::from_utf8(&valid_bytes).unwrap().to_string()
} else {
// No complete UTF-8 sequence yet; keep the bytes buffered until more arrive
continue;
}
}
};
buffer.push_str(chunk_str);
buffer.push_str(&chunk_str);
// Process complete lines
while let Some(line_end) = buffer.find('\n') {
@@ -311,6 +317,12 @@ impl AnthropicProvider {
continue;
}
// If we've already sent the final chunk, skip processing more events
if message_stopped {
debug!("Skipping event after message_stop: {}", line);
continue;
}
// Parse Server-Sent Events format
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
@@ -318,20 +330,34 @@ impl AnthropicProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
};
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
}
return;
return accumulated_usage;
}
debug!("Raw Claude API JSON: {}", data);
match serde_json::from_str::<AnthropicStreamEvent>(data) {
Ok(event) => {
debug!("Parsed event: {:?}", event);
debug!("Parsed event type: {}, event: {:?}", event.event_type, event);
match event.event_type.as_str() {
"message_start" => {
// Extract usage data from message_start event
if let Some(message) = event.message {
if let Some(usage) = message.usage {
accumulated_usage = Some(Usage {
prompt_tokens: usage.input_tokens,
completion_tokens: usage.output_tokens,
total_tokens: usage.input_tokens + usage.output_tokens,
});
debug!("Captured usage from message_start: {:?}", accumulated_usage);
}
}
}
"content_block_start" => {
debug!("Received content_block_start event: {:?}", event);
if let Some(content_block) = event.content_block {
@@ -354,11 +380,12 @@ impl AnthropicProvider {
let chunk = CompletionChunk {
content: String::new(),
finished: false,
usage: None,
tool_calls: Some(vec![tool_call]),
};
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return;
return accumulated_usage;
}
} else {
// Arguments are empty, we'll accumulate them from partial_json
@@ -380,11 +407,12 @@ impl AnthropicProvider {
let chunk = CompletionChunk {
content: text,
finished: false,
usage: None,
tool_calls: None,
};
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return;
return accumulated_usage;
}
}
// Handle partial JSON for tool calls
@@ -419,25 +447,29 @@ impl AnthropicProvider {
let chunk = CompletionChunk {
content: String::new(),
finished: false,
usage: None,
tool_calls: Some(current_tool_calls.clone()),
};
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return;
return accumulated_usage;
}
}
}
"message_stop" => {
debug!("Received message stop event");
message_stopped = true;
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
};
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
}
return;
// Don't return here - let the stream naturally exhaust
// This prevents dropping the sender prematurely
}
"error" => {
if let Some(error) = event.error {
@@ -445,7 +477,7 @@ impl AnthropicProvider {
let _ = tx
.send(Err(anyhow!("Anthropic API error: {:?}", error)))
.await;
return;
break; // Break to let stream exhaust naturally
}
}
_ => {
@@ -464,7 +496,10 @@ impl AnthropicProvider {
Err(e) => {
error!("Stream error: {}", e);
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
return;
// Don't return here either - let the stream exhaust naturally
// The error has been sent to the receiver, so it will handle it
// Breaking here ensures we clean up properly
break;
}
}
}
@@ -473,9 +508,11 @@ impl AnthropicProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls) },
};
let _ = tx.send(Ok(final_chunk)).await;
accumulated_usage
}
}
@@ -596,7 +633,14 @@ impl LLMProvider for AnthropicProvider {
// Spawn task to process the stream
let provider = self.clone();
tokio::spawn(async move {
provider.parse_streaming_response(stream, tx).await;
let usage = provider.parse_streaming_response(stream, tx).await;
// Log the final usage if available
if let Some(usage) = usage {
debug!(
"Stream completed with usage - prompt: {}, completion: {}, total: {}",
usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
);
}
});
Ok(ReceiverStream::new(rx))
@@ -668,14 +712,8 @@ enum AnthropicContent {
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
id: String,
#[serde(rename = "type")]
response_type: String,
role: String,
content: Vec<AnthropicContent>,
model: String,
stop_reason: Option<String>,
stop_sequence: Option<String>,
usage: AnthropicUsage,
}
@@ -697,12 +735,18 @@ struct AnthropicStreamEvent {
error: Option<AnthropicError>,
#[serde(default)]
content_block: Option<AnthropicContent>,
#[serde(default)]
message: Option<AnthropicStreamMessage>,
}
#[derive(Debug, Deserialize)]
struct AnthropicStreamMessage {
#[serde(default)]
usage: Option<AnthropicUsage>,
}
#[derive(Debug, Deserialize)]
struct AnthropicDelta {
#[serde(rename = "type")]
delta_type: Option<String>,
text: Option<String>,
partial_json: Option<String>,
}
@@ -710,7 +754,9 @@ struct AnthropicDelta {
#[derive(Debug, Deserialize)]
struct AnthropicError {
#[serde(rename = "type")]
#[allow(dead_code)]
error_type: String,
#[allow(dead_code)]
message: String,
}
@@ -813,32 +859,4 @@ mod tests {
assert!(anthropic_tools[0].input_schema.required.is_some());
assert_eq!(anthropic_tools[0].input_schema.required.as_ref().unwrap()[0], "location");
}
#[test]
fn test_tool_call_conversion() {
let provider = AnthropicProvider::new(
"test-key".to_string(),
None,
None,
None,
).unwrap();
let content = vec![
AnthropicContent::Text {
text: "I'll help you get the weather.".to_string(),
},
AnthropicContent::ToolUse {
id: "toolu_123".to_string(),
name: "get_weather".to_string(),
input: serde_json::json!({"location": "San Francisco, CA"}),
},
];
let tool_calls = provider.convert_anthropic_tool_calls(&content);
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].id, "toolu_123");
assert_eq!(tool_calls[0].tool, "get_weather");
assert_eq!(tool_calls[0].args["location"], "San Francisco, CA");
}
}
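The byte-buffer change above exists because multi-byte UTF-8 sequences can be split across network chunks, and decoding each chunk independently then fails. A self-contained illustration of the valid_up_to technique, separate from the provider code:

fn drain_valid_utf8(byte_buffer: &mut Vec<u8>) -> Option<String> {
    match std::str::from_utf8(byte_buffer) {
        Ok(s) => {
            // Whole buffer is valid: take it and clear.
            let out = s.to_string();
            byte_buffer.clear();
            Some(out)
        }
        Err(e) => {
            let valid_up_to = e.valid_up_to();
            if valid_up_to == 0 {
                return None; // nothing decodable yet; wait for more bytes
            }
            // Take the valid prefix, leaving the partial sequence buffered.
            let valid: Vec<u8> = byte_buffer.drain(..valid_up_to).collect();
            Some(String::from_utf8(valid).expect("prefix validated above"))
        }
    }
}

fn main() {
    // "é" is two bytes (0xC3 0xA9); deliver them in separate chunks.
    let mut buf: Vec<u8> = Vec::new();
    buf.extend_from_slice(&[0xC3]);
    assert_eq!(drain_valid_utf8(&mut buf), None); // incomplete, retained
    buf.extend_from_slice(&[0xA9]);
    assert_eq!(drain_valid_utf8(&mut buf), Some("é".to_string()));
}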

File diff suppressed because it is too large

View File

@@ -1,5 +1,5 @@
use anyhow::Result;
use g3_providers::{
use crate::{
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
MessageRole, Usage,
};
@@ -8,22 +8,18 @@ use llama_cpp::{
LlamaModel, LlamaParams, LlamaSession, SessionParams,
};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error, info, warn};
use tracing::{debug, error, info};
pub struct EmbeddedProvider {
model: Arc<LlamaModel>,
session: Arc<Mutex<LlamaSession>>,
model_name: String,
max_tokens: u32,
temperature: f32,
context_length: u32,
generation_active: Arc<AtomicBool>,
}
impl EmbeddedProvider {
@@ -71,8 +67,10 @@ impl EmbeddedProvider {
.map_err(|e| anyhow::anyhow!("Failed to load model: {}", e))?;
// Create session with parameters
let mut session_params = SessionParams::default();
session_params.n_ctx = context_size;
let mut session_params = SessionParams {
n_ctx: context_size,
..Default::default()
};
if let Some(threads) = threads {
session_params.n_threads = threads;
}
@@ -84,13 +82,11 @@ impl EmbeddedProvider {
info!("Successfully loaded {} model", model_type);
Ok(Self {
model: Arc::new(model),
session: Arc::new(Mutex::new(session)),
model_name: format!("embedded-{}", model_type),
max_tokens: max_tokens.unwrap_or(2048),
temperature: temperature.unwrap_or(0.1),
context_length: context_size,
generation_active: Arc::new(AtomicBool::new(false)),
})
}
@@ -143,7 +139,7 @@ impl EmbeddedProvider {
in_conversation = false;
}
MessageRole::Assistant => {
formatted.push_str(" ");
formatted.push(' ');
formatted.push_str(&message.content);
formatted.push_str("</s> ");
in_conversation = false;
@@ -152,8 +148,8 @@ impl EmbeddedProvider {
}
// If the last message was from user, add a space for the assistant's response
if messages.last().map_or(false, |m| matches!(m.role, MessageRole::User)) {
formatted.push_str(" ");
if messages.last().is_some_and(|m| matches!(m.role, MessageRole::User)) {
formatted.push(' ');
}
formatted
@@ -429,7 +425,6 @@ impl EmbeddedProvider {
// Download the Qwen 2.5 7B model if it doesn't exist
fn download_qwen_model(model_path: &Path) -> Result<()> {
use std::fs;
use std::io::Write;
use std::process::Command;
const MODEL_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-7B-Instruct-GGUF/resolve/main/qwen2.5-7b-instruct-q3_k_m.gguf";
@@ -446,7 +441,7 @@ impl EmbeddedProvider {
// Use curl with progress bar for download
let output = Command::new("curl")
.args(&[
.args([
"-L", // Follow redirects
"-#", // Show progress bar
"-f", // Fail on HTTP errors
@@ -662,6 +657,7 @@ impl LLMProvider for EmbeddedProvider {
let chunk = CompletionChunk {
content: remaining_to_send.to_string(),
finished: false,
usage: None,
tool_calls: None,
};
let _ = tx.blocking_send(Ok(chunk));
@@ -688,6 +684,7 @@ impl LLMProvider for EmbeddedProvider {
let chunk = CompletionChunk {
content: remaining_to_send.to_string(),
finished: false,
usage: None,
tool_calls: None,
};
let _ = tx.blocking_send(Ok(chunk));
@@ -721,6 +718,7 @@ impl LLMProvider for EmbeddedProvider {
let chunk = CompletionChunk {
content: to_send.to_string(),
finished: false,
usage: None,
tool_calls: None,
};
if tx.blocking_send(Ok(chunk)).is_err() {
@@ -736,6 +734,7 @@ impl LLMProvider for EmbeddedProvider {
let chunk = CompletionChunk {
content: unsent_tokens.clone(),
finished: false,
usage: None,
tool_calls: None,
};
if tx.blocking_send(Ok(chunk)).is_err() {
@@ -756,6 +755,7 @@ impl LLMProvider for EmbeddedProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: None, // Embedded models calculate usage differently
tool_calls: None,
};
let _ = tx.blocking_send(Ok(final_chunk));

View File

@@ -67,6 +67,7 @@ pub struct CompletionChunk {
pub content: String,
pub finished: bool,
pub tool_calls: Option<Vec<ToolCall>>,
pub usage: Option<Usage>, // Add usage tracking for streaming
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -84,8 +85,17 @@ pub struct Tool {
}
pub mod anthropic;
pub mod databricks;
pub mod embedded;
pub mod oauth;
pub mod ollama;
pub mod openai;
pub use anthropic::AnthropicProvider;
pub use databricks::DatabricksProvider;
pub use embedded::EmbeddedProvider;
pub use ollama::OllamaProvider;
pub use openai::OpenAIProvider;
/// Provider registry for managing multiple LLM providers
pub struct ProviderRegistry {

View File

@@ -0,0 +1,463 @@
use anyhow::Result;
use axum::{extract::Query, response::Html, routing::get, Router};
use base64::Engine;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sha2::Digest;
use std::{collections::HashMap, fs, net::SocketAddr, path::PathBuf, sync::Arc};
use tokio::sync::{oneshot, Mutex as TokioMutex};
use url::Url;
#[derive(Debug, Clone)]
struct OidcEndpoints {
authorization_endpoint: String,
token_endpoint: String,
}
#[derive(Serialize, Deserialize)]
struct TokenData {
/// The access token used to authenticate API requests
access_token: String,
/// Optional refresh token that can be used to obtain a new access token
/// when the current one expires, enabling offline access without user interaction
refresh_token: Option<String>,
/// When the access token expires (if known)
/// Used to determine when a token needs to be refreshed
expires_at: Option<DateTime<Utc>>,
}
struct TokenCache {
cache_path: PathBuf,
}
fn get_base_path() -> PathBuf {
// Use a similar pattern to Goose but for g3
// macOS/Linux: ~/.config/g3/databricks/oauth
// Windows: ~\AppData\Roaming\g3\config\databricks\oauth\
let mut path = dirs::config_dir().unwrap_or_else(|| PathBuf::from("."));
path.push("g3");
path.push("databricks");
path.push("oauth");
path
}
impl TokenCache {
fn new(host: &str, client_id: &str, scopes: &[String]) -> Self {
let mut hasher = sha2::Sha256::new();
hasher.update(host.as_bytes());
hasher.update(client_id.as_bytes());
hasher.update(scopes.join(",").as_bytes());
let hash = format!("{:x}", hasher.finalize());
fs::create_dir_all(get_base_path()).ok();
let cache_path = get_base_path().join(format!("{}.json", hash));
Self { cache_path }
}
fn load_token(&self) -> Option<TokenData> {
if let Ok(contents) = fs::read_to_string(&self.cache_path) {
if let Ok(token_data) = serde_json::from_str::<TokenData>(&contents) {
// Only return tokens that have a refresh token
if token_data.refresh_token.is_some() {
// If token is not expired, return it for immediate use
if let Some(expires_at) = token_data.expires_at {
if expires_at > Utc::now() {
return Some(token_data);
}
// If token is expired but has refresh token, return it so we can refresh
return Some(token_data);
}
// No expiration time but has refresh token, return it
return Some(token_data);
}
// Token doesn't have a refresh token, ignore it to force a new OAuth flow
}
}
None
}
fn save_token(&self, token_data: &TokenData) -> Result<()> {
if let Some(parent) = self.cache_path.parent() {
fs::create_dir_all(parent)?;
}
let contents = serde_json::to_string(token_data)?;
fs::write(&self.cache_path, contents)?;
Ok(())
}
}
async fn get_workspace_endpoints(host: &str) -> Result<OidcEndpoints> {
// Propagate parse failures instead of panicking; this fn already returns Result
let base_url = Url::parse(host)?;
let oidc_url = base_url.join("oidc/.well-known/oauth-authorization-server")?;
let client = reqwest::Client::new();
let resp = client.get(oidc_url.clone()).send().await?;
if !resp.status().is_success() {
return Err(anyhow::anyhow!(
"Failed to get OIDC configuration from {}",
oidc_url
));
}
let oidc_config: Value = resp.json().await?;
let authorization_endpoint = oidc_config
.get("authorization_endpoint")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("authorization_endpoint not found in OIDC configuration"))?
.to_string();
let token_endpoint = oidc_config
.get("token_endpoint")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("token_endpoint not found in OIDC configuration"))?
.to_string();
Ok(OidcEndpoints {
authorization_endpoint,
token_endpoint,
})
}
struct OAuthFlow {
endpoints: OidcEndpoints,
client_id: String,
redirect_url: String,
scopes: Vec<String>,
state: String,
verifier: String,
}
impl OAuthFlow {
fn new(
endpoints: OidcEndpoints,
client_id: String,
redirect_url: String,
scopes: Vec<String>,
) -> Self {
Self {
endpoints,
client_id,
redirect_url,
scopes,
state: nanoid::nanoid!(16),
verifier: nanoid::nanoid!(64),
}
}
/// Extracts token data from an OAuth 2.0 token response.
fn extract_token_data(
&self,
token_response: &Value,
old_refresh_token: Option<&str>,
) -> Result<TokenData> {
// Extract access token (required)
let access_token = token_response
.get("access_token")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("access_token not found in token response"))?
.to_string();
// Extract refresh token if available
let refresh_token = token_response
.get("refresh_token")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.or_else(|| old_refresh_token.map(|s| s.to_string()));
// Handle token expiration
let expires_at =
if let Some(expires_in) = token_response.get("expires_in").and_then(|v| v.as_u64()) {
// Traditional OAuth flow with expires_in seconds
Some(Utc::now() + chrono::Duration::seconds(expires_in as i64))
} else {
// If the server doesn't provide any expiration info, log it but don't set an expiration
tracing::debug!(
"No expiration information provided by server, token expiration unknown."
);
None
};
Ok(TokenData {
access_token,
refresh_token,
expires_at,
})
}
fn get_authorization_url(&self) -> String {
let challenge = {
let digest = sha2::Sha256::digest(self.verifier.as_bytes());
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest)
};
let params = [
("response_type", "code"),
("client_id", &self.client_id),
("redirect_uri", &self.redirect_url),
("scope", &self.scopes.join(" ")),
("state", &self.state),
("code_challenge", &challenge),
("code_challenge_method", "S256"),
];
format!(
"{}?{}",
self.endpoints.authorization_endpoint,
serde_urlencoded::to_string(params).unwrap()
)
}
async fn exchange_code_for_token(&self, code: &str) -> Result<TokenData> {
let params = [
("grant_type", "authorization_code"),
("code", code),
("redirect_uri", &self.redirect_url),
("code_verifier", &self.verifier),
("client_id", &self.client_id),
];
let client = reqwest::Client::new();
let resp = client
.post(&self.endpoints.token_endpoint)
.header("Content-Type", "application/x-www-form-urlencoded")
.form(&params)
.send()
.await?;
if !resp.status().is_success() {
let err_text = resp.text().await?;
return Err(anyhow::anyhow!(
"Failed to exchange code for token: {}",
err_text
));
}
let token_response: Value = resp.json().await?;
self.extract_token_data(&token_response, None)
}
async fn refresh_token(&self, refresh_token: &str) -> Result<TokenData> {
let params = [
("grant_type", "refresh_token"),
("refresh_token", refresh_token),
("client_id", &self.client_id),
];
tracing::debug!("Refreshing token using refresh_token");
let client = reqwest::Client::new();
let resp = client
.post(&self.endpoints.token_endpoint)
.header("Content-Type", "application/x-www-form-urlencoded")
.form(&params)
.send()
.await?;
if !resp.status().is_success() {
let err_text = resp.text().await?;
return Err(anyhow::anyhow!("Failed to refresh token: {}", err_text));
}
let token_response: Value = resp.json().await?;
self.extract_token_data(&token_response, Some(refresh_token))
}
async fn execute(&self) -> Result<TokenData> {
// Create a channel that will send the auth code from the app process
let (tx, rx) = oneshot::channel();
let state = self.state.clone();
let tx = Arc::new(TokioMutex::new(Some(tx)));
// Setup a server that will receive the redirect, capture the code, and display success/failure
let app = Router::new().route(
"/",
get(move |Query(params): Query<HashMap<String, String>>| {
let tx = Arc::clone(&tx);
let state = state.clone();
async move {
let code = params.get("code").cloned();
let received_state = params.get("state").cloned();
if let (Some(code), Some(received_state)) = (code, received_state) {
if received_state == state {
if let Some(sender) = tx.lock().await.take() {
if sender.send(code).is_ok() {
return Html(
"<h2>G3 Authentication Success</h2><p>You can close this window and return to your terminal.</p>",
);
}
}
Html("<h2>Error</h2><p>Authentication already completed.</p>")
} else {
Html("<h2>Error</h2><p>State mismatch.</p>")
}
} else {
Html("<h2>Error</h2><p>Authentication failed.</p>")
}
}
}),
);
// Start the server to accept the oauth code
let redirect_url = Url::parse(&self.redirect_url)?;
let port = redirect_url.port().unwrap_or(80);
let addr = SocketAddr::from(([127, 0, 0, 1], port));
let listener = tokio::net::TcpListener::bind(addr).await?;
let server_handle = tokio::spawn(async move {
let server = axum::serve(listener, app);
server.await.unwrap();
});
// Open the browser which will redirect with the code to the server
let authorization_url = self.get_authorization_url();
if std::env::var("G3_RETRO_MODE").is_err() {
println!("🔐 Opening browser for Databricks authentication...");
}
if webbrowser::open(&authorization_url).is_err() {
println!(
"Please open this URL in your browser:\n{}",
authorization_url
);
}
// Wait for the authorization code with a timeout
let code = tokio::time::timeout(
std::time::Duration::from_secs(120), // 2 minute timeout
rx,
)
.await
.map_err(|_| anyhow::anyhow!("Authentication timed out after 2 minutes"))??;
// Stop the server
server_handle.abort();
if std::env::var("G3_RETRO_MODE").is_err() {
println!("✅ Authentication successful! Exchanging code for token...");
}
// Exchange the code for a token
self.exchange_code_for_token(&code).await
}
}
pub async fn get_oauth_token_async(
host: &str,
client_id: &str,
redirect_url: &str,
scopes: &[String],
) -> Result<String> {
let token_cache = TokenCache::new(host, client_id, scopes);
// Try cache first
if let Some(token) = token_cache.load_token() {
// If token has an expiration time, check if it's expired
if let Some(expires_at) = token.expires_at {
if expires_at > Utc::now() {
tracing::debug!("Using cached token");
return Ok(token.access_token);
}
// Token is expired, will try to refresh below
tracing::debug!("Token is expired, attempting to refresh");
} else {
// No expiration time was provided by the server
tracing::debug!("Token has no expiration time, using cached token");
return Ok(token.access_token);
}
// Token is expired or has no expiration, try to refresh if we have a refresh token
if let Some(refresh_token) = token.refresh_token {
// Get endpoints for token refresh
match get_workspace_endpoints(host).await {
Ok(endpoints) => {
let flow = OAuthFlow::new(
endpoints,
client_id.to_string(),
redirect_url.to_string(),
scopes.to_vec(),
);
// Try to refresh the token
match flow.refresh_token(&refresh_token).await {
Ok(new_token) => {
if let Err(e) = token_cache.save_token(&new_token) {
tracing::warn!("Failed to save refreshed token: {}", e);
}
tracing::info!("Successfully refreshed token");
return Ok(new_token.access_token);
}
Err(e) => {
tracing::warn!(
"Failed to refresh token, will try new auth flow: {}",
e
);
// Continue to new auth flow
}
}
}
Err(e) => {
tracing::warn!("Failed to get endpoints for token refresh: {}", e);
// Continue to new auth flow
}
}
}
}
// Get endpoints and execute flow for a new token
let endpoints = get_workspace_endpoints(host).await?;
let flow = OAuthFlow::new(
endpoints,
client_id.to_string(),
redirect_url.to_string(),
scopes.to_vec(),
);
// Execute the OAuth flow and get token
let token = flow.execute().await?;
// Cache and return
token_cache.save_token(&token)?;
if std::env::var("G3_RETRO_MODE").is_err() {
println!("🎉 Databricks authentication complete!");
}
Ok(token.access_token)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_token_cache() -> Result<()> {
let cache = TokenCache::new(
"https://example.com",
"test-client",
&["scope1".to_string()],
);
// Test with expiration time
let token_data = TokenData {
access_token: "test-token".to_string(),
refresh_token: Some("test-refresh-token".to_string()),
expires_at: Some(Utc::now() + chrono::Duration::hours(1)),
};
cache.save_token(&token_data)?;
let loaded_token = cache.load_token().unwrap();
assert_eq!(loaded_token.access_token, token_data.access_token);
assert_eq!(loaded_token.refresh_token, token_data.refresh_token);
assert!(loaded_token.expires_at.is_some());
Ok(())
}
}
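For orientation, a hedged end-to-end usage sketch; the host, client id, redirect URL, and scope below are placeholders, not values this module ships with:

async fn authenticate() -> Result<String> {
    // First run opens a browser; later runs hit the on-disk token cache and
    // refresh silently whenever a refresh token is available.
    get_oauth_token_async(
        "https://my-workspace.cloud.databricks.com", // hypothetical workspace host
        "my-client-id",                              // hypothetical OAuth client id
        "http://localhost:8020",                     // must match the registered redirect URL
        &["all-apis".to_string()],                   // placeholder scope
    )
    .await
}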

View File

@@ -0,0 +1,751 @@
//! Ollama LLM provider implementation for the g3-providers crate.
//!
//! This module provides an implementation of the `LLMProvider` trait for Ollama,
//! supporting both completion and streaming modes with native tool calling.
//!
//! # Features
//!
//! - Support for any Ollama model (llama3.2, mistral, qwen, etc.)
//! - Both completion and streaming response modes
//! - Native tool calling support for compatible models
//! - Configurable base URL (defaults to http://localhost:11434)
//! - Simple configuration with no authentication required
//!
//! # Usage
//!
//! ```rust,no_run
//! use g3_providers::{OllamaProvider, LLMProvider, CompletionRequest, Message, MessageRole};
//!
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//! // Create the provider with default settings (localhost:11434)
//! let provider = OllamaProvider::new(
//! "llama3.2".to_string(),
//! None, // Optional: base_url
//! None, // Optional: max tokens
//! None, // Optional: temperature
//! )?;
//!
//! // Create a completion request
//! let request = CompletionRequest {
//! messages: vec![
//! Message {
//! role: MessageRole::User,
//! content: "Hello! How are you?".to_string(),
//! },
//! ],
//! max_tokens: Some(1000),
//! temperature: Some(0.7),
//! stream: false,
//! tools: None,
//! };
//!
//! // Get a completion
//! let response = provider.complete(request).await?;
//! println!("Response: {}", response.content);
//!
//! Ok(())
//! }
//! ```
use anyhow::{anyhow, Result};
use bytes::Bytes;
use futures_util::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error, info, warn};
use crate::{
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
MessageRole, Tool, ToolCall, Usage,
};
const DEFAULT_BASE_URL: &str = "http://localhost:11434";
const DEFAULT_TIMEOUT_SECS: u64 = 600;
pub const OLLAMA_DEFAULT_MODEL: &str = "llama3.2";
pub const OLLAMA_KNOWN_MODELS: &[&str] = &[
"llama3.2",
"llama3.2:1b",
"llama3.2:3b",
"llama3.1",
"llama3.1:8b",
"llama3.1:70b",
"mistral",
"mistral-nemo",
"mixtral",
"qwen2.5",
"qwen2.5:7b",
"qwen2.5:14b",
"qwen2.5:32b",
"qwen2.5-coder",
"qwen2.5-coder:7b",
"qwen3-coder",
"phi3",
"gemma2",
];
#[derive(Debug, Clone)]
pub struct OllamaProvider {
client: Client,
base_url: String,
model: String,
max_tokens: Option<u32>,
temperature: f32,
}
impl OllamaProvider {
pub fn new(
model: String,
base_url: Option<String>,
max_tokens: Option<u32>,
temperature: Option<f32>,
) -> Result<Self> {
let client = Client::builder()
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
.build()
.map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
let base_url = base_url
.unwrap_or_else(|| DEFAULT_BASE_URL.to_string())
.trim_end_matches('/')
.to_string();
info!(
"Initialized Ollama provider with model: {} at {}",
model, base_url
);
Ok(Self {
client,
base_url,
model,
max_tokens,
temperature: temperature.unwrap_or(0.7),
})
}
fn convert_tools(&self, tools: &[Tool]) -> Vec<OllamaTool> {
tools
.iter()
.map(|tool| OllamaTool {
r#type: "function".to_string(),
function: OllamaFunction {
name: tool.name.clone(),
description: tool.description.clone(),
parameters: tool.input_schema.clone(),
},
})
.collect()
}
fn convert_messages(&self, messages: &[Message]) -> Result<Vec<OllamaMessage>> {
let mut ollama_messages = Vec::new();
for message in messages {
let role = match message.role {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
};
ollama_messages.push(OllamaMessage {
role: role.to_string(),
content: message.content.clone(),
tool_calls: None, // Only used in responses
});
}
if ollama_messages.is_empty() {
return Err(anyhow!("At least one message is required"));
}
Ok(ollama_messages)
}
fn create_request_body(
&self,
messages: &[Message],
tools: Option<&[Tool]>,
streaming: bool,
max_tokens: Option<u32>,
temperature: f32,
) -> Result<OllamaRequest> {
let ollama_messages = self.convert_messages(messages)?;
let ollama_tools = tools.map(|t| self.convert_tools(t));
// Prefer the per-request max_tokens; fall back to the instance default
let options = OllamaOptions {
temperature,
num_predict: max_tokens.or(self.max_tokens),
};
let request = OllamaRequest {
model: self.model.clone(),
messages: ollama_messages,
tools: ollama_tools,
stream: streaming,
options,
};
Ok(request)
}
async fn parse_streaming_response(
&self,
mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
tx: mpsc::Sender<Result<CompletionChunk>>,
) -> Option<Usage> {
let mut buffer = String::new();
let mut accumulated_usage: Option<Usage> = None;
let mut current_tool_calls: Vec<OllamaToolCall> = Vec::new();
let mut byte_buffer = Vec::new();
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
// Append new bytes to our buffer
byte_buffer.extend_from_slice(&chunk);
// Try to convert the entire buffer to UTF-8
let chunk_str = match std::str::from_utf8(&byte_buffer) {
Ok(s) => {
let result = s.to_string();
byte_buffer.clear();
result
}
Err(e) => {
let valid_up_to = e.valid_up_to();
if valid_up_to > 0 {
let valid_bytes =
byte_buffer.drain(..valid_up_to).collect::<Vec<_>>();
std::str::from_utf8(&valid_bytes).unwrap().to_string()
} else {
continue;
}
}
};
buffer.push_str(&chunk_str);
// Process complete lines
while let Some(line_end) = buffer.find('\n') {
let line = buffer[..line_end].trim().to_string();
buffer.drain(..line_end + 1);
if line.is_empty() {
continue;
}
// Ollama streaming sends JSON objects per line
match serde_json::from_str::<OllamaStreamChunk>(&line) {
Ok(chunk) => {
// Handle text content
if let Some(message) = &chunk.message {
let content = &message.content;
if !content.is_empty() {
debug!("Sending text chunk: '{}'", content);
let chunk = CompletionChunk {
content: content.clone(),
finished: false,
usage: None,
tool_calls: None,
};
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return accumulated_usage;
}
}
// Handle tool calls
if let Some(tool_calls) = &message.tool_calls {
current_tool_calls.extend(tool_calls.clone());
}
}
// Check if stream is done
if chunk.done.unwrap_or(false) {
debug!("Stream completed");
// Update usage if available
if let Some(eval_count) = chunk.eval_count {
accumulated_usage = Some(Usage {
prompt_tokens: chunk.prompt_eval_count.unwrap_or(0),
completion_tokens: eval_count,
total_tokens: chunk.prompt_eval_count.unwrap_or(0)
+ eval_count,
});
}
// Send final chunk with tool calls if any
let final_tool_calls: Vec<ToolCall> = current_tool_calls
.iter()
.map(|tc| ToolCall {
id: tc.function.name.clone(), // Ollama doesn't provide IDs
tool: tc.function.name.clone(),
args: tc.function.arguments.clone(),
})
.collect();
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
tool_calls: if final_tool_calls.is_empty() {
None
} else {
Some(final_tool_calls)
},
};
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
}
return accumulated_usage;
}
}
Err(e) => {
debug!("Failed to parse Ollama stream chunk: {} - Line: {}", e, line);
// Don't error out, just continue
}
}
}
}
Err(e) => {
error!("Stream error: {}", e);
let error_msg = e.to_string();
if error_msg.contains("unexpected EOF") || error_msg.contains("connection") {
warn!("Connection terminated unexpectedly, treating as end of stream");
break;
}
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
return accumulated_usage;
}
}
}
// Send final chunk if we haven't already
let final_tool_calls: Vec<ToolCall> = current_tool_calls
.iter()
.map(|tc| ToolCall {
id: tc.function.name.clone(),
tool: tc.function.name.clone(),
args: tc.function.arguments.clone(),
})
.collect();
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
tool_calls: if final_tool_calls.is_empty() {
None
} else {
Some(final_tool_calls)
},
};
let _ = tx.send(Ok(final_chunk)).await;
accumulated_usage
}
/// Fetch available models from the Ollama instance
pub async fn fetch_available_models(&self) -> Result<Vec<String>> {
let response = self
.client
.get(format!("{}/api/tags", self.base_url))
.send()
.await
.map_err(|e| anyhow!("Failed to fetch Ollama models: {}", e))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!(
"Failed to fetch Ollama models: {} - {}",
status,
error_text
));
}
let json: serde_json::Value = response.json().await?;
let models = json
.get("models")
.and_then(|v| v.as_array())
.ok_or_else(|| anyhow!("Unexpected response format: missing 'models' array"))?;
let model_names: Vec<String> = models
.iter()
.filter_map(|model| model.get("name").and_then(|n| n.as_str()).map(String::from))
.collect();
debug!("Found {} models in Ollama", model_names.len());
Ok(model_names)
}
}
#[async_trait::async_trait]
impl LLMProvider for OllamaProvider {
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
debug!(
"Processing Ollama completion request with {} messages",
request.messages.len()
);
let max_tokens = request.max_tokens.or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);
let request_body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
false,
max_tokens,
temperature,
)?;
debug!(
"Sending request to Ollama API: model={}, temperature={}",
self.model, request_body.options.temperature
);
let response = self
.client
.post(format!("{}/api/chat", self.base_url))
.json(&request_body)
.send()
.await
.map_err(|e| anyhow!("Failed to send request to Ollama API: {}", e))?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!("Ollama API error {}: {}", status, error_text));
}
let response_text = response.text().await?;
debug!("Raw Ollama API response: {}", response_text);
let ollama_response: OllamaResponse =
serde_json::from_str(&response_text).map_err(|e| {
anyhow!(
"Failed to parse Ollama response: {} - Response: {}",
e,
response_text
)
})?;
let content = ollama_response.message.content.clone();
let usage = Usage {
prompt_tokens: ollama_response.prompt_eval_count.unwrap_or(0),
completion_tokens: ollama_response.eval_count.unwrap_or(0),
total_tokens: ollama_response.prompt_eval_count.unwrap_or(0)
+ ollama_response.eval_count.unwrap_or(0),
};
debug!(
"Ollama completion successful: {} tokens generated",
usage.completion_tokens
);
Ok(CompletionResponse {
content,
usage,
model: self.model.clone(),
})
}
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
debug!(
"Processing Ollama request (non-streaming) with {} messages",
request.messages.len()
);
if let Some(ref tools) = request.tools {
debug!("Request has {} tools", tools.len());
for tool in tools.iter().take(5) {
debug!(" Tool: {}", tool.name);
}
}
let max_tokens = request.max_tokens.or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);
let request_body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
false, // Use non-streaming mode to avoid streaming bugs
max_tokens,
temperature,
)?;
debug!(
"Sending request to Ollama API (stream=false): model={}, temperature={}",
self.model, request_body.options.temperature
);
let response = self
.client
.post(format!("{}/api/chat", self.base_url))
.json(&request_body)
.send()
.await
.map_err(|e| anyhow!("Failed to send request to Ollama API: {}", e))?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!("Ollama API error {}: {}", status, error_text));
}
// For non-streaming, parse the complete JSON response
let response_text = response.text().await?;
debug!("Raw Ollama API response: {}", response_text);
let ollama_response: OllamaResponse =
serde_json::from_str(&response_text).map_err(|e| {
anyhow!(
"Failed to parse Ollama response: {} - Response: {}",
e,
response_text
)
})?;
let (tx, rx) = mpsc::channel(100);
tokio::spawn(async move {
let content = ollama_response.message.content;
let usage = Usage {
prompt_tokens: ollama_response.prompt_eval_count.unwrap_or(0),
completion_tokens: ollama_response.eval_count.unwrap_or(0),
total_tokens: ollama_response.prompt_eval_count.unwrap_or(0)
+ ollama_response.eval_count.unwrap_or(0),
};
// Extract tool calls if present
let tool_calls: Option<Vec<ToolCall>> = ollama_response.message.tool_calls.map(|tcs| {
tcs.iter()
.map(|tc| ToolCall {
id: tc.function.name.clone(),
tool: tc.function.name.clone(),
args: tc.function.arguments.clone(),
})
.collect()
});
// Send content if any
if !content.is_empty() {
let _ = tx.send(Ok(CompletionChunk {
content,
finished: false,
usage: None,
tool_calls: None,
})).await;
}
// Send final chunk with usage and tool calls
let _ = tx.send(Ok(CompletionChunk {
content: String::new(),
finished: true,
usage: Some(usage),
tool_calls,
})).await;
});
Ok(ReceiverStream::new(rx))
}
fn name(&self) -> &str {
"ollama"
}
fn model(&self) -> &str {
&self.model
}
fn has_native_tool_calling(&self) -> bool {
// Most modern Ollama models support tool calling
// Models like llama3.2, qwen2.5, mistral, etc. have good tool support
true
}
}
// Ollama API request/response structures
#[derive(Debug, Serialize)]
struct OllamaRequest {
model: String,
messages: Vec<OllamaMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<OllamaTool>>,
stream: bool,
options: OllamaOptions,
}
#[derive(Debug, Serialize)]
struct OllamaOptions {
temperature: f32,
#[serde(skip_serializing_if = "Option::is_none")]
num_predict: Option<u32>, // Ollama's equivalent of max_tokens
}
#[derive(Debug, Serialize)]
struct OllamaTool {
r#type: String,
function: OllamaFunction,
}
#[derive(Debug, Serialize)]
struct OllamaFunction {
name: String,
description: String,
parameters: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaMessage {
role: String,
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<OllamaToolCall>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaToolCall {
function: OllamaToolCallFunction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaToolCallFunction {
name: String,
arguments: serde_json::Value,
}
#[derive(Debug, Deserialize)]
struct OllamaResponse {
message: OllamaMessage,
#[allow(dead_code)]
done: bool,
#[allow(dead_code)]
total_duration: Option<u64>,
#[allow(dead_code)]
load_duration: Option<u64>,
prompt_eval_count: Option<u32>,
eval_count: Option<u32>,
}
#[derive(Debug, Deserialize)]
struct OllamaStreamChunk {
message: Option<OllamaMessage>,
done: Option<bool>,
prompt_eval_count: Option<u32>,
eval_count: Option<u32>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_provider_creation() {
let provider = OllamaProvider::new(
"llama3.2".to_string(),
None,
Some(1000),
Some(0.7),
)
.unwrap();
assert_eq!(provider.model(), "llama3.2");
assert_eq!(provider.name(), "ollama");
assert!(provider.has_native_tool_calling());
}
#[test]
fn test_message_conversion() {
let provider = OllamaProvider::new(
"llama3.2".to_string(),
None,
None,
None,
)
.unwrap();
let messages = vec![
Message {
role: MessageRole::System,
content: "You are a helpful assistant.".to_string(),
},
Message {
role: MessageRole::User,
content: "Hello!".to_string(),
},
];
let ollama_messages = provider.convert_messages(&messages).unwrap();
assert_eq!(ollama_messages.len(), 2);
assert_eq!(ollama_messages[0].role, "system");
assert_eq!(ollama_messages[1].role, "user");
}
#[test]
fn test_tool_conversion() {
let provider = OllamaProvider::new(
"llama3.2".to_string(),
None,
None,
None,
)
.unwrap();
let tools = vec![Tool {
name: "get_weather".to_string(),
description: "Get the current weather".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}),
}];
let ollama_tools = provider.convert_tools(&tools);
assert_eq!(ollama_tools.len(), 1);
assert_eq!(ollama_tools[0].r#type, "function");
assert_eq!(ollama_tools[0].function.name, "get_weather");
}
#[test]
fn test_custom_base_url() {
let provider = OllamaProvider::new(
"llama3.2".to_string(),
Some("http://custom:11434".to_string()),
None,
None,
)
.unwrap();
assert_eq!(provider.base_url, "http://custom:11434");
}
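}
#[test]
fn test_response_deserialization() {
// A minimal sketch of a /api/chat response carrying a native tool
// call, matching the OllamaResponse/OllamaToolCall structs defined
// above; the field values are illustrative only.
let raw = r#"{
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{"function": {"name": "get_weather", "arguments": {"location": "Sydney"}}}
]
},
"done": true,
"prompt_eval_count": 12,
"eval_count": 5
}"#;
let response: OllamaResponse = serde_json::from_str(raw).unwrap();
let calls = response.message.tool_calls.unwrap();
assert_eq!(calls[0].function.name, "get_weather");
assert_eq!(calls[0].function.arguments["location"], "Sydney");
assert_eq!(response.prompt_eval_count, Some(12));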
}


@@ -0,0 +1,495 @@
use anyhow::Result;
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::stream::StreamExt;
use reqwest::Client;
use serde::Deserialize;
use serde_json::json;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error};
use crate::{
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider,
Message, MessageRole, Tool, ToolCall, Usage,
};
#[derive(Clone)]
pub struct OpenAIProvider {
client: Client,
api_key: String,
model: String,
base_url: String,
max_tokens: Option<u32>,
_temperature: Option<f32>,
}
impl OpenAIProvider {
pub fn new(
api_key: String,
model: Option<String>,
base_url: Option<String>,
max_tokens: Option<u32>,
temperature: Option<f32>,
) -> Result<Self> {
Ok(Self {
client: Client::new(),
api_key,
model: model.unwrap_or_else(|| "gpt-4o".to_string()),
base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
max_tokens,
_temperature: temperature,
})
}
fn create_request_body(
&self,
messages: &[Message],
tools: Option<&[Tool]>,
stream: bool,
max_tokens: Option<u32>,
_temperature: Option<f32>,
) -> serde_json::Value {
let mut body = json!({
"model": self.model,
"messages": convert_messages(messages),
"stream": stream,
});
if let Some(max_tokens) = max_tokens.or(self.max_tokens) {
body["max_completion_tokens"] = json!(max_tokens);
}
// OpenAI calls that set a temperature seem to fail, so we don't send one.
// if let Some(temperature) = temperature.or(self.temperature) {
// body["temperature"] = json!(temperature);
// }
if let Some(tools) = tools {
if !tools.is_empty() {
body["tools"] = json!(convert_tools(tools));
}
}
if stream {
body["stream_options"] = json!({
"include_usage": true,
});
}
body
}
async fn parse_streaming_response(
&self,
mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
tx: mpsc::Sender<Result<CompletionChunk>>,
) -> Option<Usage> {
let mut buffer = String::new();
let mut accumulated_content = String::new();
let mut accumulated_usage: Option<Usage> = None;
let mut current_tool_calls: Vec<OpenAIStreamingToolCall> = Vec::new();
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
let chunk_str = match std::str::from_utf8(&chunk) {
Ok(s) => s,
Err(e) => {
error!("Failed to parse chunk as UTF-8: {}", e);
continue;
}
};
buffer.push_str(chunk_str);
// Process complete lines
while let Some(line_end) = buffer.find('\n') {
let line = buffer[..line_end].trim().to_string();
buffer.drain(..line_end + 1);
if line.is_empty() {
continue;
}
// Parse Server-Sent Events format
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
debug!("Received stream completion marker");
// Send final chunk with accumulated content and tool calls
if !accumulated_content.is_empty() || !current_tool_calls.is_empty() {
let tool_calls = if current_tool_calls.is_empty() {
None
} else {
Some(
current_tool_calls
.iter()
.filter_map(|tc| tc.to_tool_call())
.collect(),
)
};
let final_chunk = CompletionChunk {
content: accumulated_content.clone(),
finished: true,
tool_calls,
usage: accumulated_usage.clone(),
};
let _ = tx.send(Ok(final_chunk)).await;
}
return accumulated_usage;
}
// Parse the JSON data
match serde_json::from_str::<OpenAIStreamChunk>(data) {
Ok(chunk_data) => {
// Handle content
for choice in &chunk_data.choices {
if let Some(content) = &choice.delta.content {
accumulated_content.push_str(content);
let chunk = CompletionChunk {
content: content.clone(),
finished: false,
tool_calls: None,
usage: None,
};
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return accumulated_usage;
}
}
// Handle tool calls
if let Some(delta_tool_calls) = &choice.delta.tool_calls {
for delta_tool_call in delta_tool_calls {
if let Some(index) = delta_tool_call.index {
// Ensure we have enough tool calls in our vector
while current_tool_calls.len() <= index {
current_tool_calls
.push(OpenAIStreamingToolCall::default());
}
let tool_call = &mut current_tool_calls[index];
if let Some(id) = &delta_tool_call.id {
tool_call.id = Some(id.clone());
}
if let Some(function) = &delta_tool_call.function {
if let Some(name) = &function.name {
tool_call.name = Some(name.clone());
}
if let Some(arguments) = &function.arguments {
tool_call.arguments.push_str(arguments);
}
}
}
}
}
}
// Handle usage
if let Some(usage) = chunk_data.usage {
accumulated_usage = Some(Usage {
prompt_tokens: usage.prompt_tokens,
completion_tokens: usage.completion_tokens,
total_tokens: usage.total_tokens,
});
}
}
Err(e) => {
debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
}
}
}
}
}
Err(e) => {
error!("Stream error: {}", e);
let _ = tx.send(Err(anyhow::anyhow!("Stream error: {}", e))).await;
return accumulated_usage;
}
}
}
// Send final chunk if we haven't already
let tool_calls = if current_tool_calls.is_empty() {
None
} else {
Some(
current_tool_calls
.iter()
.filter_map(|tc| tc.to_tool_call())
.collect(),
)
};
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
tool_calls,
usage: accumulated_usage.clone(),
};
let _ = tx.send(Ok(final_chunk)).await;
accumulated_usage
}
}
#[async_trait]
impl LLMProvider for OpenAIProvider {
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
debug!(
"Processing OpenAI completion request with {} messages",
request.messages.len()
);
let body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
false,
request.max_tokens,
request.temperature,
);
debug!("Sending request to OpenAI API: model={}", self.model);
let response = self
.client
.post(format!("{}/chat/completions", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text));
}
let openai_response: OpenAIResponse = response.json().await?;
let content = openai_response
.choices
.first()
.and_then(|choice| choice.message.content.clone())
.unwrap_or_default();
let usage = Usage {
prompt_tokens: openai_response.usage.prompt_tokens,
completion_tokens: openai_response.usage.completion_tokens,
total_tokens: openai_response.usage.total_tokens,
};
debug!(
"OpenAI completion successful: {} tokens generated",
usage.completion_tokens
);
Ok(CompletionResponse {
content,
usage,
model: self.model.clone(),
})
}
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
debug!(
"Processing OpenAI streaming request with {} messages",
request.messages.len()
);
let body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
true,
request.max_tokens,
request.temperature,
);
debug!("Sending streaming request to OpenAI API: model={}", self.model);
let response = self
.client
.post(format!("{}/chat/completions", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text));
}
let stream = response.bytes_stream();
let (tx, rx) = mpsc::channel(100);
// Spawn task to process the stream
let provider = self.clone();
tokio::spawn(async move {
let usage = provider.parse_streaming_response(stream, tx).await;
// Log the final usage if available
if let Some(usage) = usage {
debug!(
"Stream completed with usage - prompt: {}, completion: {}, total: {}",
usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
);
}
});
Ok(ReceiverStream::new(rx))
}
fn name(&self) -> &str {
"openai"
}
fn model(&self) -> &str {
&self.model
}
fn has_native_tool_calling(&self) -> bool {
// OpenAI models support native tool calling
true
}
}
fn convert_messages(messages: &[Message]) -> Vec<serde_json::Value> {
messages
.iter()
.map(|msg| {
json!({
"role": match msg.role {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
},
"content": msg.content,
})
})
.collect()
}
fn convert_tools(tools: &[Tool]) -> Vec<serde_json::Value> {
tools
.iter()
.map(|tool| {
json!({
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.input_schema,
}
})
})
.collect()
}
// OpenAI API response structures
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
choices: Vec<OpenAIChoice>,
usage: OpenAIUsage,
}
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
message: OpenAIMessage,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct OpenAIMessage {
content: Option<String>,
#[serde(default)]
tool_calls: Option<Vec<OpenAIToolCall>>,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct OpenAIToolCall {
id: String,
function: OpenAIFunction,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct OpenAIFunction {
name: String,
arguments: String,
}
// Streaming tool call accumulator
#[derive(Debug, Default)]
struct OpenAIStreamingToolCall {
id: Option<String>,
name: Option<String>,
arguments: String,
}
impl OpenAIStreamingToolCall {
fn to_tool_call(&self) -> Option<ToolCall> {
let id = self.id.as_ref()?;
let name = self.name.as_ref()?;
let args = serde_json::from_str(&self.arguments).unwrap_or(serde_json::Value::Null);
Some(ToolCall {
id: id.clone(),
tool: name.clone(),
args,
})
}
}
#[derive(Debug, Deserialize)]
struct OpenAIUsage {
prompt_tokens: u32,
completion_tokens: u32,
total_tokens: u32,
}
// Streaming response structures
#[derive(Debug, Deserialize)]
struct OpenAIStreamChunk {
choices: Vec<OpenAIStreamChoice>,
usage: Option<OpenAIUsage>,
}
#[derive(Debug, Deserialize)]
struct OpenAIStreamChoice {
delta: OpenAIDelta,
}
#[derive(Debug, Deserialize)]
struct OpenAIDelta {
content: Option<String>,
#[serde(default)]
tool_calls: Option<Vec<OpenAIDeltaToolCall>>,
}
#[derive(Debug, Deserialize)]
struct OpenAIDeltaToolCall {
index: Option<usize>,
id: Option<String>,
function: Option<OpenAIDeltaFunction>,
}
#[derive(Debug, Deserialize)]
struct OpenAIDeltaFunction {
name: Option<String>,
arguments: Option<String>,
}
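// Minimal test sketches for the streaming tool-call accumulator and the
// request body builder above; literal values like "call_123" are
// illustrative, not part of the OpenAI API contract.
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_streaming_tool_call_assembly() {
// Simulates id/name/argument fragments arriving across several SSE
// deltas, then checks that they assemble into a single ToolCall.
let mut tc = OpenAIStreamingToolCall::default();
tc.id = Some("call_123".to_string());
tc.name = Some("get_weather".to_string());
tc.arguments.push_str("{\"location\":");
tc.arguments.push_str("\"Sydney\"}");
let call = tc.to_tool_call().expect("id and name are set");
assert_eq!(call.id, "call_123");
assert_eq!(call.tool, "get_weather");
assert_eq!(call.args["location"], "Sydney");
}
#[test]
fn test_incomplete_tool_call_is_dropped() {
// A fragment without an id yields None, so filter_map in the
// stream parser silently drops it.
let tc = OpenAIStreamingToolCall {
id: None,
name: Some("get_weather".to_string()),
arguments: String::new(),
};
assert!(tc.to_tool_call().is_none());
}
#[test]
fn test_request_body_shape() {
let provider = OpenAIProvider::new(
"sk-test".to_string(),
None,
None,
Some(512),
None,
)
.unwrap();
let messages = vec![Message {
role: MessageRole::User,
content: "Hello!".to_string(),
}];
let body = provider.create_request_body(&messages, None, true, None, None);
assert_eq!(body["model"], "gpt-4o");
assert_eq!(body["max_completion_tokens"], 512);
assert_eq!(body["stream_options"]["include_usage"], true);
}
}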

docs/ACCUMULATIVE_MODE.md Normal file

@@ -0,0 +1,389 @@
# Accumulative Autonomous Mode
## Overview
Accumulative Autonomous Mode is the **new default interactive mode** for G3. It combines the ease of interactive chat with the power of autonomous implementation, allowing you to build projects iteratively by describing what you want, one requirement at a time.
## How It Works
### The Flow
1. **Start G3** in any directory (no arguments needed)
2. **Describe** what you want to build
3. **G3 automatically**:
- Adds your input to accumulated requirements
- Runs autonomous mode (coach-player feedback loop)
- Implements your requirements with quality checks
4. **Continue** adding more requirements or refinements
5. **Repeat** until your project is complete
### Example Session
```bash
$ cd ~/projects/my-new-app
$ g3
🪿 G3 AI Coding Agent - Accumulative Mode
>> describe what you want, I'll build it iteratively
📁 Workspace: /Users/you/projects/my-new-app
💡 Each input you provide will be added to requirements
and I'll automatically work on implementing them.
Type 'exit' or 'quit' to stop, Ctrl+D to finish
============================================================
📝 What would you like me to build? (describe your requirements)
============================================================
requirement> create a simple web server in Python with Flask that serves a homepage
📋 Current instructions and requirements (Turn 1):
create a simple web server in Python with Flask that serves a homepage
🚀 Starting autonomous implementation...
🤖 G3 AI Coding Agent - Autonomous Mode
📁 Using workspace: /Users/you/projects/my-new-app
📋 Requirements loaded from --requirements flag
🔄 Starting coach-player feedback loop...
📂 No existing implementation files detected
🎯 Starting with player implementation
=== TURN 1/5 - PLAYER MODE ===
🎯 Starting player implementation...
📋 Player starting initial implementation (no prior coach feedback)
[Player creates files, writes code...]
=== TURN 1/5 - COACH MODE ===
🎓 Starting coach review...
🎓 Coach review completed
Coach feedback:
The Flask server is implemented correctly with a homepage route.
The code follows best practices and meets the requirements.
IMPLEMENTATION_APPROVED
=== SESSION COMPLETED - IMPLEMENTATION APPROVED ===
✅ Coach approved the implementation!
============================================================
📊 AUTONOMOUS MODE SESSION REPORT
============================================================
⏱️ Total Duration: 12.34s
🔄 Turns Taken: 1/5
📝 Final Status: ✅ APPROVED
...
============================================================
✅ Autonomous run completed
============================================================
📝 Turn 2 - What's next? (add more requirements or refinements)
============================================================
requirement> add a /api/users endpoint that returns a list of users as JSON
📋 Current instructions and requirements (Turn 2):
add a /api/users endpoint that returns a list of users as JSON
🚀 Starting autonomous implementation...
[Autonomous mode runs again with BOTH requirements...]
============================================================
📝 Turn 3 - What's next? (add more requirements or refinements)
============================================================
requirement> exit
👋 Goodbye!
```
## Key Features
### 1. Requirement Accumulation
Each input you provide is:
- **Numbered sequentially** (1, 2, 3, ...)
- **Stored in memory** for the session
- **Included in all subsequent runs**
This means the agent always has the full context of what you've asked for.
### 2. Automatic Requirements Document
G3 automatically generates a structured requirements document:
```markdown
# Project Requirements
## Current Instructions and Requirements:
1. create a simple web server in Python with Flask that serves a homepage
2. add a /api/users endpoint that returns a list of users as JSON
3. add error handling for 404 and 500 errors
## Latest Requirement (Turn 3):
add error handling for 404 and 500 errors
```
This document is passed to autonomous mode, ensuring the agent:
- Knows all previous requirements
- Focuses on the latest addition
- Maintains consistency across iterations
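For the curious, generating this document is straightforward string assembly. The sketch below shows one way it could work; the function and variable names are illustrative, not the actual G3 internals:
```rust
/// Formats accumulated requirements as the markdown document passed to
/// autonomous mode. `requirements` holds every input so far, in order.
fn format_requirements(requirements: &[String]) -> String {
    let mut doc = String::from("# Project Requirements\n\n");
    doc.push_str("## Current Instructions and Requirements:\n\n");
    for (i, req) in requirements.iter().enumerate() {
        doc.push_str(&format!("{}. {}\n", i + 1, req));
    }
    if let Some(latest) = requirements.last() {
        doc.push_str(&format!(
            "\n## Latest Requirement (Turn {}):\n\n{}\n",
            requirements.len(),
            latest
        ));
    }
    doc
}
```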
### 3. Full Autonomous Quality
Each requirement triggers a complete autonomous run with:
- **Coach-Player Feedback Loop**: Quality assurance built-in
- **Multiple Turns**: Up to 5 iterations per requirement (configurable with `--max-turns`)
- **Compilation Checks**: Ensures code actually works
- **Testing**: Coach can run tests to verify functionality
### 4. Error Recovery
If an autonomous run fails:
- You're notified of the error
- You can provide additional requirements to fix issues
- The session continues (doesn't crash)
### 5. Workspace Management
- Uses **current directory** as workspace
- All files created in current directory
- No need to specify workspace path
- Works with existing projects or empty directories
## Command-Line Options
### Default (Accumulative Mode)
```bash
g3
```
Starts accumulative autonomous mode in the current directory.
### With Options
```bash
# Use a specific workspace
g3 --workspace ~/projects/my-app
# Limit autonomous turns per requirement
g3 --max-turns 3
# Enable macOS Accessibility tools
g3 --macax
# Enable WebDriver browser automation
g3 --webdriver
# Use a specific provider/model
g3 --provider anthropic --model claude-3-5-sonnet-20241022
# Show prompts and code during execution
g3 --show-prompt --show-code
# Disable log files
g3 --quiet
```
### Disable Accumulative Mode
To use the traditional chat mode (without automatic autonomous runs):
```bash
g3 --chat
# Alternative: legacy flag also works
g3 --accumulative
```
This restores the traditional behavior: you chat with the agent and no autonomous runs are triggered automatically.
## Use Cases
### 1. Rapid Prototyping
```bash
requirement> create a REST API for a todo app
requirement> add SQLite database storage
requirement> add authentication with JWT
requirement> add rate limiting
```
### 2. Iterative Refinement
```bash
requirement> create a data visualization dashboard
requirement> make the charts interactive
requirement> add dark mode support
requirement> optimize for mobile devices
```
### 3. Bug Fixing
```bash
requirement> fix the login form validation
requirement> handle edge case when username is empty
requirement> add better error messages
```
### 4. Feature Addition
```bash
requirement> add export to CSV functionality
requirement> add email notifications
requirement> add admin dashboard
```
## Tips and Best Practices
### 1. Start Simple
Begin with a basic requirement, let it be implemented, then add complexity:
```bash
✅ Good:
requirement> create a basic Flask web server
requirement> add a homepage with a form
requirement> add form validation
❌ Too Complex:
requirement> create a full-stack web app with authentication, database, API, and frontend
```
### 2. Be Specific
The more specific you are, the better the results:
```bash
✅ Good:
requirement> add a /api/users endpoint that returns JSON with id, name, and email fields
❌ Vague:
requirement> add users
```
### 3. One Thing at a Time
Focus each requirement on a single feature or fix:
```bash
✅ Good:
requirement> add error handling for database connections
requirement> add logging for all API requests
❌ Multiple Things:
requirement> add error handling and logging and monitoring and alerts
```
### 4. Review Between Turns
After each autonomous run completes:
- Check the generated files
- Test the functionality
- Decide what to add or fix next
### 5. Use Exit Commands
When done:
- Type `exit` or `quit`
- Press `Ctrl+D` (EOF)
- Press `Ctrl+C` to cancel current input
## Comparison with Other Modes
| Feature | Accumulative (Default) | Traditional Interactive | Autonomous | Single-Shot |
|---------|----------------------|------------------------|------------|-------------|
| **Command** | `g3` | `g3 --chat` | `g3 --autonomous` | `g3 "task"` |
| **Input Style** | Iterative prompts | Chat messages | requirements.md file | Command-line arg |
| **Auto-Autonomous** | ✅ Yes | ❌ No | ✅ Yes | ❌ No |
| **Coach-Player Loop** | ✅ Yes | ❌ No | ✅ Yes | ❌ No |
| **Accumulates Requirements** | ✅ Yes | ❌ No | ❌ No | ❌ No |
| **Multiple Iterations** | ✅ Yes | ✅ Yes | ✅ Yes | ❌ No |
| **Best For** | Iterative development | Quick questions | Pre-planned projects | One-off tasks |
## Technical Details
### Requirements Storage
- Stored in memory (not persisted to disk)
- Numbered sequentially starting from 1
- Formatted as markdown list
- Passed to autonomous mode as `--requirements` override
### History
- Saved to `~/.g3_accumulative_history`
- Separate from traditional interactive history
- Persists across sessions
- Uses rustyline for readline support
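The history handling is a conventional rustyline loop. Here is a minimal sketch, assuming rustyline's `DefaultEditor` API and the `dirs` crate for locating the home directory (the real G3 code may differ):
```rust
use rustyline::{error::ReadlineError, DefaultEditor};

fn main() -> anyhow::Result<()> {
    let history = dirs::home_dir()
        .unwrap_or_default()
        .join(".g3_accumulative_history");
    let mut rl = DefaultEditor::new()?;
    let _ = rl.load_history(&history); // ignore a missing file on first run
    loop {
        match rl.readline("requirement> ") {
            Ok(line) if line.trim() == "exit" || line.trim() == "quit" => break,
            Ok(line) => {
                let _ = rl.add_history_entry(&line);
                // ... accumulate the requirement and run autonomous mode
            }
            Err(ReadlineError::Interrupted) => continue, // Ctrl+C cancels input
            Err(ReadlineError::Eof) => break,            // Ctrl+D finishes
            Err(e) => return Err(e.into()),
        }
    }
    rl.save_history(&history)?;
    Ok(())
}
```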
### Workspace
- Defaults to current directory
- Can be overridden with `--workspace`
- All files created in workspace
- Logs saved to `workspace/logs/`
### Autonomous Execution
- Full coach-player feedback loop
- Configurable max turns (default: 5)
- Respects all CLI flags (--macax, --webdriver, etc.)
- Error handling allows continuation
## Troubleshooting
### "No requirements provided"
This shouldn't happen in accumulative mode, but if it does:
- Check that you entered a requirement
- Ensure the requirement isn't empty
- Try restarting G3
### "Autonomous run failed"
If an autonomous run fails:
- Read the error message
- Provide a new requirement to fix the issue
- Or type `exit` and investigate manually
### "Context window full"
If you hit token limits:
- The agent will auto-summarize
- Or you can start a new session
- Consider using `--max-turns` to limit iterations
### "Coach never approves"
If the coach keeps rejecting:
- Check the coach feedback for specific issues
- Provide more specific requirements
- Consider increasing `--max-turns`
## Future Enhancements
Planned improvements:
1. **Persistence**: Save accumulated requirements to disk
2. **Editing**: Edit or remove previous requirements
3. **Branching**: Try different approaches
4. **Templates**: Pre-defined requirement sets
5. **Review**: Show all accumulated requirements
6. **Export**: Save to requirements.md
7. **Undo**: Remove last requirement
8. **Replay**: Re-run with same requirements
## Feedback
This is a new feature! Please provide feedback:
- What works well?
- What's confusing?
- What features would you like?
- Any bugs or issues?
Open an issue on GitHub or contribute improvements!

test-ai-requirements.sh Executable file

@@ -0,0 +1,39 @@
#!/bin/bash
# Test script for AI-enhanced interactive requirements mode
echo "Testing AI-enhanced interactive requirements mode..."
echo ""
# Create a test workspace
TEST_WORKSPACE="/tmp/g3-test-interactive-$(date +%s)"
mkdir -p "$TEST_WORKSPACE"
echo "Test workspace: $TEST_WORKSPACE"
echo ""
# Create sample brief input
BRIEF_INPUT="build a calculator cli in rust with basic operations"
echo "Brief input:"
echo "---"
echo "$BRIEF_INPUT"
echo "---"
echo ""
echo "This will:"
echo "1. Send brief input to AI"
echo "2. AI generates structured requirements.md"
echo "3. Show enhanced requirements"
echo "4. Prompt for confirmation (y/e/n)"
echo ""
echo "To test manually, run:"
echo "cargo run -- --autonomous --interactive-requirements --workspace $TEST_WORKSPACE"
echo ""
echo "Then type: $BRIEF_INPUT"
echo "Press Ctrl+D"
echo "Review the AI-generated requirements"
echo "Choose 'y' to proceed, 'e' to edit, or 'n' to cancel"
echo ""
echo "Test workspace will be at: $TEST_WORKSPACE"