Compare commits

..

4 Commits

Author SHA1 Message Date
Michael Neale
f2ed303550 Revert "don't need this"
This reverts commit 93121c18e0.
2025-10-22 14:53:25 +11:00
Michael Neale
93121c18e0 don't need this 2025-10-22 14:30:13 +11:00
Michael Neale
ed84a940f9 tweak auto mode 2025-10-22 14:27:17 +11:00
Michael Neale
3128b5d8b9 can choose per mode models for auto mode 2025-10-22 14:19:00 +11:00
56 changed files with 2028 additions and 5497 deletions

1
.gitignore vendored
View File

@@ -2,7 +2,6 @@
# will have compiled files and executables # will have compiled files and executables
debug debug
target target
.build
# These are backup files generated by rustfmt # These are backup files generated by rustfmt
**/*.rs.bk **/*.rs.bk

33
CHANGELOG.md Normal file
View File

@@ -0,0 +1,33 @@
# Changelog
## [Unreleased]
### Added
**Interactive Requirements Mode**
- **AI-Enhanced Interactive Requirements**: New `--interactive-requirements` flag for autonomous mode
- User enters brief description of what they want to build
- AI automatically enhances input into structured requirements.md document
- Generates professional markdown with:
- Project title and overview
- Organized requirements (functional, technical, quality)
- Acceptance criteria
- User can review, accept, edit manually, or cancel before proceeding
- Seamlessly transitions to autonomous mode
**Autonomous Mode Configuration**
- **Autonomous Mode Configuration**: Added ability to specify different models for coach and player agents in autonomous mode
- New `[autonomous]` configuration section in `g3.toml`
- `coach_provider` and `coach_model` options for coach agent
- `player_provider` and `player_model` options for player agent
- `Config::for_coach()` and `Config::for_player()` methods to generate role-specific configurations
- Comprehensive test suite for autonomous configuration
### Changed
- Autonomous mode now uses `config.for_player()` for the player agent
- Coach agent creation now uses `config.for_coach()` for the coach agent
### Benefits
- **Cost Optimization**: Use cheaper models for execution, expensive models for review
- **Speed Optimization**: Use faster models for iteration, thorough models for validation
- **Specialization**: Leverage different providers' strengths for different roles

326
Cargo.lock generated
View File

@@ -2,28 +2,6 @@
# It is not intended for manual editing. # It is not intended for manual editing.
version = 4 version = 4
[[package]]
name = "accessibility"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ac9f33ffc1ef16eddb2451c03c983e56a5182ac760c3f2733da55ba8f48eac4"
dependencies = [
"accessibility-sys",
"cocoa 0.26.1",
"core-foundation 0.10.1",
"objc",
"thiserror 1.0.69",
]
[[package]]
name = "accessibility-sys"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46a6a8e90a1d8b96a48249e7c8f5b4058447bea8847280db7bfccb6dcab6b8e1"
dependencies = [
"core-foundation-sys",
]
[[package]] [[package]]
name = "adler2" name = "adler2"
version = "2.0.1" version = "2.0.1"
@@ -136,7 +114,7 @@ checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -218,6 +196,28 @@ version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
name = "bindgen"
version = "0.64.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4243e6031260db77ede97ad86c27e501d646a27ab57b59a574f725d98ab1fb4"
dependencies = [
"bitflags 1.3.2",
"cexpr",
"clang-sys",
"lazy_static",
"lazycell",
"log",
"peeking_take_while",
"proc-macro2",
"quote",
"regex",
"rustc-hash",
"shlex",
"syn 1.0.109",
"which",
]
[[package]] [[package]]
name = "bindgen" name = "bindgen"
version = "0.69.5" version = "0.69.5"
@@ -237,7 +237,7 @@ dependencies = [
"regex", "regex",
"rustc-hash", "rustc-hash",
"shlex", "shlex",
"syn", "syn 2.0.107",
"which", "which",
] ]
@@ -318,9 +318,9 @@ dependencies = [
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.2.43" version = "1.2.41"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "739eb0f94557554b3ca9a86d2d37bebd49c5e6d0c1d2bda35ba5bdac830befc2" checksum = "ac9fe6cdbb24b6ade63616c0a0688e45bb56732262c158df3c0c4bea4ca47cb7"
dependencies = [ dependencies = [
"find-msvc-tools", "find-msvc-tools",
"jobserver", "jobserver",
@@ -411,7 +411,7 @@ dependencies = [
"heck", "heck",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -437,25 +437,9 @@ checksum = "f6140449f97a6e97f9511815c5632d84c8aacf8ac271ad77c559218161a1373c"
dependencies = [ dependencies = [
"bitflags 1.3.2", "bitflags 1.3.2",
"block", "block",
"cocoa-foundation 0.1.2", "cocoa-foundation",
"core-foundation 0.9.4", "core-foundation 0.9.4",
"core-graphics 0.23.2", "core-graphics",
"foreign-types 0.5.0",
"libc",
"objc",
]
[[package]]
name = "cocoa"
version = "0.26.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad36507aeb7e16159dfe68db81ccc27571c3ccd4b76fb2fb72fc59e7a4b1b64c"
dependencies = [
"bitflags 2.10.0",
"block",
"cocoa-foundation 0.2.1",
"core-foundation 0.10.1",
"core-graphics 0.24.0",
"foreign-types 0.5.0", "foreign-types 0.5.0",
"libc", "libc",
"objc", "objc",
@@ -470,24 +454,11 @@ dependencies = [
"bitflags 1.3.2", "bitflags 1.3.2",
"block", "block",
"core-foundation 0.9.4", "core-foundation 0.9.4",
"core-graphics-types 0.1.3", "core-graphics-types",
"libc", "libc",
"objc", "objc",
] ]
[[package]]
name = "cocoa-foundation"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81411967c50ee9a1fc11365f8c585f863a22a9697c89239c452292c40ba79b0d"
dependencies = [
"bitflags 2.10.0",
"block",
"core-foundation 0.10.1",
"core-graphics-types 0.2.0",
"objc",
]
[[package]] [[package]]
name = "color_quant" name = "color_quant"
version = "1.1.0" version = "1.1.0"
@@ -664,20 +635,7 @@ checksum = "c07782be35f9e1140080c6b96f0d44b739e2278479f64e02fdab4e32dfd8b081"
dependencies = [ dependencies = [
"bitflags 1.3.2", "bitflags 1.3.2",
"core-foundation 0.9.4", "core-foundation 0.9.4",
"core-graphics-types 0.1.3", "core-graphics-types",
"foreign-types 0.5.0",
"libc",
]
[[package]]
name = "core-graphics"
version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa95a34622365fa5bbf40b20b75dba8dfa8c94c734aea8ac9a5ca38af14316f1"
dependencies = [
"bitflags 2.10.0",
"core-foundation 0.10.1",
"core-graphics-types 0.2.0",
"foreign-types 0.5.0", "foreign-types 0.5.0",
"libc", "libc",
] ]
@@ -693,17 +651,6 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "core-graphics-types"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d44a101f213f6c4cdc1853d4b78aef6db6bdfa3468798cc1d9912f4735013eb"
dependencies = [
"bitflags 2.10.0",
"core-foundation 0.10.1",
"libc",
]
[[package]] [[package]]
name = "cpufeatures" name = "cpufeatures"
version = "0.2.17" version = "0.2.17"
@@ -745,7 +692,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"strict", "strict",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -884,7 +831,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"strsim", "strsim",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -895,14 +842,14 @@ checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead"
dependencies = [ dependencies = [
"darling_core", "darling_core",
"quote", "quote",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
name = "deranged" name = "deranged"
version = "0.5.5" version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ececcb659e7ba858fb4f10388c250a7252eb0a27373f1a72b8748afdd248e587" checksum = "a41953f86f8a05768a6cda24def994fd2f424b04ec5c719cf89989779f199071"
dependencies = [ dependencies = [
"powerfmt", "powerfmt",
] ]
@@ -917,7 +864,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"rustc_version", "rustc_version",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -938,7 +885,7 @@ dependencies = [
"convert_case 0.7.1", "convert_case 0.7.1",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -990,7 +937,7 @@ dependencies = [
"libc", "libc",
"option-ext", "option-ext",
"redox_users 0.5.2", "redox_users 0.5.2",
"windows-sys 0.59.0", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@@ -1001,7 +948,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -1015,9 +962,9 @@ dependencies = [
[[package]] [[package]]
name = "document-features" name = "document-features"
version = "0.2.12" version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d4b8a88685455ed29a21542a33abd9cb6510b6b129abadabdcef0f4c55bc8f61" checksum = "95249b50c6c185bee49034bcb378a49dc2b5dff0be90ff6616d31d64febab05d"
dependencies = [ dependencies = [
"litrs", "litrs",
] ]
@@ -1062,7 +1009,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [ dependencies = [
"libc", "libc",
"windows-sys 0.52.0", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@@ -1144,9 +1091,9 @@ checksum = "52051878f80a721bb68ebfbc930e07b65ba72f2da88968ea5c06fd6ca3d3a127"
[[package]] [[package]]
name = "flate2" name = "flate2"
version = "1.1.5" version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb" checksum = "dc5a4e564e38c699f2880d3fda590bedc2e69f3f84cd48b457bd892ce61d0aa9"
dependencies = [ dependencies = [
"crc32fast", "crc32fast",
"miniz_oxide", "miniz_oxide",
@@ -1191,7 +1138,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -1271,7 +1218,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -1340,18 +1287,18 @@ dependencies = [
name = "g3-computer-control" name = "g3-computer-control"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"accessibility",
"anyhow", "anyhow",
"async-trait", "async-trait",
"cocoa 0.25.0", "cocoa",
"core-foundation 0.10.1", "core-foundation 0.9.4",
"core-graphics 0.23.2", "core-graphics",
"fantoccini", "fantoccini",
"image", "image",
"objc", "objc",
"serde", "serde",
"serde_json", "serde_json",
"shellexpand", "shellexpand",
"tesseract",
"thiserror 1.0.69", "thiserror 1.0.69",
"tokio", "tokio",
"tracing", "tracing",
@@ -1369,7 +1316,6 @@ dependencies = [
"dirs 5.0.1", "dirs 5.0.1",
"serde", "serde",
"shellexpand", "shellexpand",
"tempfile",
"thiserror 1.0.69", "thiserror 1.0.69",
"toml", "toml",
] ]
@@ -1571,11 +1517,11 @@ checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
[[package]] [[package]]
name = "home" name = "home"
version = "0.5.9" version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
dependencies = [ dependencies = [
"windows-sys 0.52.0", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@@ -1922,12 +1868,9 @@ dependencies = [
[[package]] [[package]]
name = "indoc" name = "indoc"
version = "2.0.7" version = "2.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79cf5c93f93228cf8efb3ba362535fb11199ac548a09ce117c9b1adc3030d706" checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd"
dependencies = [
"rustversion",
]
[[package]] [[package]]
name = "instability" name = "instability"
@@ -1939,7 +1882,7 @@ dependencies = [
"indoc", "indoc",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -1950,9 +1893,9 @@ checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
[[package]] [[package]]
name = "is_terminal_polyfill" name = "is_terminal_polyfill"
version = "1.70.2" version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]] [[package]]
name = "itertools" name = "itertools"
@@ -2060,7 +2003,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"regex", "regex",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -2081,6 +2024,28 @@ version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a79a3332a6609480d7d0c9eab957bca6b455b91bb84e66d19f5ff66294b85b8" checksum = "7a79a3332a6609480d7d0c9eab957bca6b455b91bb84e66d19f5ff66294b85b8"
[[package]]
name = "leptonica-plumbing"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc7a74c43d6f090d39158d233f326f47cd8bba545217595c93662b4e31156f42"
dependencies = [
"leptonica-sys",
"libc",
"thiserror 1.0.69",
]
[[package]]
name = "leptonica-sys"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da627c72b2499a8106f4dd33143843015e4a631f445d561f3481f7fba35b6151"
dependencies = [
"bindgen 0.64.0",
"pkg-config",
"vcpkg",
]
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.177" version = "0.2.177"
@@ -2136,9 +2101,9 @@ checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956"
[[package]] [[package]]
name = "litrs" name = "litrs"
version = "1.0.0" version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092" checksum = "f5e54036fe321fd421e10d732f155734c4e4afd610dd556d9a82833ab3ee0bed"
[[package]] [[package]]
name = "llama_cpp" name = "llama_cpp"
@@ -2161,7 +2126,7 @@ version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "037a1881ada3592c6a922224d5177b4b4f452e6b2979eb97393b71989e48357f" checksum = "037a1881ada3592c6a922224d5177b4b4f452e6b2979eb97393b71989e48357f"
dependencies = [ dependencies = [
"bindgen", "bindgen 0.69.5",
"cc", "cc",
"link-cplusplus", "link-cplusplus",
"once_cell", "once_cell",
@@ -2254,14 +2219,14 @@ dependencies = [
[[package]] [[package]]
name = "mio" name = "mio"
version = "1.1.0" version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69d83b0086dc8ecf3ce9ae2874b2d1290252e2a30720bea58a5c6639b0092873" checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c"
dependencies = [ dependencies = [
"libc", "libc",
"log", "log",
"wasi", "wasi",
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@@ -2333,7 +2298,7 @@ version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [ dependencies = [
"windows-sys 0.59.0", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@@ -2409,9 +2374,9 @@ checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]] [[package]]
name = "once_cell_polyfill" name = "once_cell_polyfill"
version = "1.70.2" version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad"
[[package]] [[package]]
name = "openssl" name = "openssl"
@@ -2436,7 +2401,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -2508,6 +2473,12 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3"
[[package]]
name = "peeking_take_while"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
[[package]] [[package]]
name = "percent-encoding" name = "percent-encoding"
version = "2.3.2" version = "2.3.2"
@@ -2544,7 +2515,7 @@ dependencies = [
"pest_meta", "pest_meta",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -2625,14 +2596,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.103" version = "1.0.101"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8" checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de"
dependencies = [ dependencies = [
"unicode-ident", "unicode-ident",
] ]
@@ -2904,7 +2875,7 @@ dependencies = [
"errno", "errno",
"libc", "libc",
"linux-raw-sys 0.11.0", "linux-raw-sys 0.11.0",
"windows-sys 0.52.0", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@@ -3030,7 +3001,7 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -3125,9 +3096,9 @@ dependencies = [
[[package]] [[package]]
name = "signal-hook-mio" name = "signal-hook-mio"
version = "0.2.5" version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b75a19a7a740b25bc7944bdee6172368f988763b744e3d4dfe753f6b4ece40cc" checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd"
dependencies = [ dependencies = [
"libc", "libc",
"mio", "mio",
@@ -3224,14 +3195,25 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"rustversion", "rustversion",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.108" version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da58917d35242480a05c2897064da0a80589a2a0476c9a3f2fdc83b53502e917" checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "syn"
version = "2.0.107"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a26dbd934e5451d21ef060c018dae56fc073894c5a7896f882928a76e6d081b"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@@ -3258,7 +3240,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -3292,7 +3274,7 @@ dependencies = [
"getrandom 0.3.4", "getrandom 0.3.4",
"once_cell", "once_cell",
"rustix 1.1.2", "rustix 1.1.2",
"windows-sys 0.52.0", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@@ -3311,6 +3293,40 @@ dependencies = [
"unicode-width 0.1.14", "unicode-width 0.1.14",
] ]
[[package]]
name = "tesseract"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2ee0c2c608b63817b095f7fded5c50add36a29e2be2b2fc4901357163329290a"
dependencies = [
"tesseract-plumbing",
"tesseract-sys",
"thiserror 1.0.69",
]
[[package]]
name = "tesseract-plumbing"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e496d3e29eba540a276975394b85dccb5fd344b3eefb743d9286c8150f766d5"
dependencies = [
"leptonica-plumbing",
"tesseract-sys",
"thiserror 1.0.69",
]
[[package]]
name = "tesseract-sys"
version = "0.5.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd33f6f216124cfaf0fa86c2c0cdf04da39b6257bd78c5e44fa4fa98c3a5857b"
dependencies = [
"bindgen 0.64.0",
"leptonica-sys",
"pkg-config",
"vcpkg",
]
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.69" version = "1.0.69"
@@ -3337,7 +3353,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -3348,7 +3364,7 @@ checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -3446,7 +3462,7 @@ checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -3572,7 +3588,7 @@ checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -3634,9 +3650,9 @@ checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971"
[[package]] [[package]]
name = "unicode-ident" name = "unicode-ident"
version = "1.0.20" version = "1.0.19"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "462eeb75aeb73aea900253ce739c8e18a67423fadf006037cd3ff27e82748a06" checksum = "f63a545481291138910575129486daeaf8ac54aee4387fe7906919f7830c7d9d"
[[package]] [[package]]
name = "unicode-segmentation" name = "unicode-segmentation"
@@ -3777,7 +3793,7 @@ dependencies = [
"log", "log",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]
@@ -3812,7 +3828,7 @@ checksum = "9f07d2f20d4da7b26400c9f4a0511e6e0345b040694e8a75bd41d578fa4421d7"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
"wasm-bindgen-backend", "wasm-bindgen-backend",
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]
@@ -3935,7 +3951,7 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [ dependencies = [
"windows-sys 0.48.0", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@@ -3984,7 +4000,7 @@ checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -3995,7 +4011,7 @@ checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -4391,7 +4407,7 @@ checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
"synstructure", "synstructure",
] ]
@@ -4412,7 +4428,7 @@ checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]
@@ -4432,7 +4448,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
"synstructure", "synstructure",
] ]
@@ -4466,7 +4482,7 @@ checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.107",
] ]
[[package]] [[package]]

368
README.md
View File

@@ -2,133 +2,14 @@
G3 is a coding AI agent designed to help you complete tasks by writing code and executing commands. Built in Rust, it provides a flexible architecture for interacting with various Large Language Model (LLM) providers while offering powerful code generation and task automation capabilities. G3 is a coding AI agent designed to help you complete tasks by writing code and executing commands. Built in Rust, it provides a flexible architecture for interacting with various Large Language Model (LLM) providers while offering powerful code generation and task automation capabilities.
## Architecture Overview
G3 follows a modular architecture organized as a Rust workspace with multiple crates, each responsible for specific functionality:
### Core Components
#### **g3-core**
The heart of the agent system, containing:
- **Agent Engine**: Main orchestration logic for handling conversations, tool execution, and task management
- **Context Window Management**: Intelligent tracking of token usage with context thinning (50-80%) and auto-summarization at 80% capacity
- **Tool System**: Built-in tools for file operations, shell commands, computer control, TODO management, and structured output
- **Streaming Response Parser**: Real-time parsing of LLM responses with tool call detection and execution
- **Task Execution**: Support for single and iterative task execution with automatic retry logic
#### **g3-providers**
Abstraction layer for LLM providers:
- **Provider Interface**: Common trait-based API for different LLM backends
- **Multiple Provider Support**:
- Anthropic (Claude models)
- Databricks (DBRX and other models)
- Local/embedded models via llama.cpp with Metal acceleration on macOS
- **OAuth Authentication**: Built-in OAuth flow support for secure provider authentication
- **Provider Registry**: Dynamic provider management and selection
#### **g3-config**
Configuration management system:
- Environment-based configuration
- Provider credentials and settings
- Model selection and parameters
- Runtime configuration options
#### **g3-execution**
Task execution framework:
- Task planning and decomposition
- Execution strategies (sequential, parallel)
- Error handling and retry mechanisms
- Progress tracking and reporting
#### **g3-computer-control**
Computer control capabilities:
- Mouse and keyboard automation
- UI element inspection and interaction
- Screenshot capture and window management
- OCR text extraction via Tesseract
#### **g3-cli**
Command-line interface:
- Interactive terminal interface
- Task submission and monitoring
- Configuration management commands
- Session management
### Error Handling & Resilience
G3 includes robust error handling with automatic retry logic:
- **Recoverable Error Detection**: Automatically identifies recoverable errors (rate limits, network issues, server errors, timeouts)
- **Exponential Backoff with Jitter**: Implements intelligent retry delays to avoid overwhelming services
- **Detailed Error Logging**: Captures comprehensive error context including stack traces, request/response data, and session information
- **Error Persistence**: Saves detailed error logs to `logs/errors/` for post-mortem analysis
- **Graceful Degradation**: Non-recoverable errors are logged with full context before terminating
## Key Features ## Key Features
### Intelligent Context Management - **Multiple LLM Providers**: Anthropic (Claude), Databricks, OpenAI, and local models via llama.cpp
- Automatic context window monitoring with percentage-based tracking - **Autonomous Mode**: Coach-player feedback loop for complex tasks
- Smart auto-summarization when approaching token limits - **Intelligent Context Management**: Auto-summarization and context thinning at 50-80% thresholds
- **Context thinning** at 50%, 60%, 70%, 80% thresholds - automatically replaces large tool results with file references - **Rich Tool Ecosystem**: File operations, shell commands, computer control, browser automation
- Conversation history preservation through summaries - **Streaming Responses**: Real-time output with tool call detection
- Dynamic token allocation for different providers (4k to 200k+ tokens) - **Error Recovery**: Automatic retry logic with exponential backoff
### Interactive Control Commands
G3's interactive CLI includes control commands for manual context management:
- **`/compact`**: Manually trigger summarization to compact conversation history
- **`/thinnify`**: Manually trigger context thinning to replace large tool results with file references
- **`/readme`**: Reload README.md and AGENTS.md from disk without restarting
- **`/stats`**: Show detailed context and performance statistics
- **`/help`**: Display all available control commands
These commands give you fine-grained control over context management, allowing you to proactively optimize token usage and refresh project documentation. See [Control Commands Documentation](docs/CONTROL_COMMANDS.md) for detailed usage.
### Tool Ecosystem
- **File Operations**: Read, write, and edit files with line-range precision
- **Shell Integration**: Execute system commands with output capture
- **Code Generation**: Structured code generation with syntax awareness
- **TODO Management**: Read and write TODO lists with markdown checkbox format
- **Computer Control** (Experimental): Automate desktop applications
- Mouse and keyboard control
- macOS Accessibility API for native app automation (via `--macax` flag)
- UI element inspection
- Screenshot capture and window management
- OCR text extraction from images and screen regions
- Window listing and identification
- **Final Output**: Formatted result presentation
### Provider Flexibility
- Support for multiple LLM providers through a unified interface
- Hot-swappable providers without code changes
- Provider-specific optimizations and feature support
- Local model support for offline operation
### Task Automation
- Single-shot task execution for quick operations
- Iterative task mode for complex, multi-step workflows
- Automatic error recovery and retry logic
- Progress tracking and intermediate result handling
## Language & Technology Stack
- **Language**: Rust (2021 edition)
- **Async Runtime**: Tokio for concurrent operations
- **HTTP Client**: Reqwest for API communications
- **Serialization**: Serde for JSON handling
- **CLI Framework**: Clap for command-line parsing
- **Logging**: Tracing for structured logging
- **Local Models**: llama.cpp with Metal acceleration support
## Use Cases
G3 is designed for:
- Automated code generation and refactoring
- File manipulation and project scaffolding
- System administration tasks
- Data processing and transformation
- API integration and testing
- Documentation generation
- Complex multi-step workflows
- Desktop application automation and testing
## Getting Started ## Getting Started
@@ -136,69 +17,234 @@ G3 is designed for:
# Build the project
cargo build --release
# Run G3 # Execute a single task
cargo run
# Execute a task
g3 "implement a function to calculate fibonacci numbers"
# Start autonomous mode with interactive requirements
g3 --autonomous --interactive-requirements
```
## Configuration
Create `~/.config/g3/config.toml`:
```toml
[providers]
default_provider = "databricks"
[providers.anthropic]
api_key = "sk-ant-..."
model = "claude-3-5-sonnet-20241022"
max_tokens = 4096
[providers.databricks]
host = "https://your-workspace.cloud.databricks.com"
model = "databricks-meta-llama-3-1-70b-instruct"
max_tokens = 4096
use_oauth = true
[agent]
max_context_length = 8192
enable_streaming = true
# Optional: Use different models for coach and player in autonomous mode
[autonomous]
coach_provider = "anthropic"
coach_model = "claude-3-5-sonnet-20241022" # Thorough review
player_provider = "databricks"
player_model = "databricks-meta-llama-3-1-70b-instruct" # Fast execution
```
## Autonomous Mode (Coach-Player Loop)
G3 features an autonomous mode where two agents collaborate:
- **Player Agent**: Executes tasks and implements solutions
- **Coach Agent**: Reviews work and provides feedback
### Option 1: Interactive Requirements with AI Enhancement (Recommended)
```bash
g3 --autonomous --interactive-requirements
```
**How it works:**
1. Describe what you want to build (can be brief)
2. Press **Ctrl+D** (Unix/Mac) or **Ctrl+Z** (Windows)
3. AI enhances your input into a structured requirements document
4. Review the enhanced requirements
5. Choose to proceed, edit manually, or cancel
6. If accepted, autonomous mode starts automatically
**Example:**
```
You type: "build a todo app with cli in python"
AI generates:
# Todo List CLI Application
## Overview
A command-line todo list application built in Python...
## Functional Requirements
1. Add tasks with descriptions
2. Mark tasks as complete
3. Delete tasks
...
```
### Option 2: Direct Requirements
```bash
g3 --autonomous --requirements "Build a REST API with CRUD operations for user management"
```
### Option 3: Requirements File
Create `requirements.md` in your workspace:
```markdown
# Project Requirements
1. Create a REST API with user endpoints
2. Use SQLite for storage
3. Include input validation
4. Write unit tests
```
Then run:
```bash
g3 --autonomous
```
### Why Different Models for Coach and Player?
Configure different models in the `[autonomous]` section to:
- **Optimize Cost**: Use cheaper model for execution, expensive for review
- **Optimize Speed**: Use fast model for iteration, thorough for validation
- **Specialize**: Leverage provider strengths (e.g., Claude for analysis, Llama for code)
If not configured, both agents use the `default_provider` and its model.
## Command-Line Options
```bash
# Autonomous mode
g3 --autonomous --interactive-requirements
g3 --autonomous --requirements "Your requirements"
g3 --autonomous --max-turns 10
# Single-shot mode
g3 "your task here"
# Options
--workspace <DIR> # Set workspace directory
--provider <NAME> # Override provider (anthropic, databricks, openai)
--model <NAME> # Override model
--quiet # Disable log files
--webdriver # Enable browser automation
--show-prompt # Show system prompt
--show-code # Show generated code
```
## Architecture Overview
G3 is organized as a Rust workspace with multiple crates:
- **g3-core**: Agent engine, context management, tool system, streaming parser
- **g3-providers**: LLM provider abstraction (Anthropic, Databricks, OpenAI, local models)
- **g3-config**: Configuration management
- **g3-execution**: Task execution framework
- **g3-computer-control**: Mouse/keyboard automation, OCR, screenshots
- **g3-cli**: Command-line interface
### Key Capabilities
**Intelligent Context Management**
- Automatic context window monitoring with percentage-based tracking
- Smart auto-summarization when approaching token limits
- Context thinning at 50%, 60%, 70%, 80% thresholds
- Dynamic token allocation (4k to 200k+ tokens)
**Tool Ecosystem**
- File operations (read, write, edit with line-range precision)
- Shell command execution
- TODO management
- Computer control (experimental): mouse, keyboard, OCR, screenshots
- Browser automation via WebDriver (Safari)
**Error Handling**
- Automatic retry logic with exponential backoff
- Recoverable error detection (rate limits, network issues, timeouts)
- Detailed error logging to `logs/errors/`
## WebDriver Browser Automation
G3 includes WebDriver support for browser automation tasks using Safari. **One-Time Setup** (macOS):
**One-Time Setup** (macOS only):
Safari Remote Automation must be enabled before using WebDriver tools. Run this once:
```bash ```bash
# Option 1: Use the provided script # Enable Safari Remote Automation
./scripts/enable-safari-automation.sh
# Option 2: Enable manually
safaridriver --enable # Requires password safaridriver --enable # Requires password
# Option 3: Enable via Safari UI # Or via Safari UI:
# Safari → Preferences → Advanced → Show Develop menu # Safari → Preferences → Advanced → Show Develop menu
# Then: Develop → Allow Remote Automation # Then: Develop → Allow Remote Automation
``` ```
**For detailed setup instructions and troubleshooting**, see [WebDriver Setup Guide](docs/webdriver-setup.md). **Usage**:
**Usage**: Run G3 with the `--webdriver` flag to enable browser automation tools. ```bash
g3 --webdriver "scrape the top stories from Hacker News"
```
## macOS Accessibility API Tools See [docs/webdriver-setup.md](docs/webdriver-setup.md) for detailed setup.
G3 includes support for controlling macOS applications via the Accessibility API, allowing you to automate native macOS apps.
**Available Tools**: `macax_list_apps`, `macax_get_frontmost_app`, `macax_activate_app`, `macax_get_ui_tree`, `macax_find_elements`, `macax_click`, `macax_set_value`, `macax_get_value`, `macax_press_key`
**Setup**: Enable with the `--macax` flag or in config with `macax.enabled = true`. Grant accessibility permissions:
- **macOS**: System Preferences → Security & Privacy → Privacy → Accessibility → Add your terminal app
**For detailed documentation**, see [macOS Accessibility Tools Guide](docs/macax-tools.md).
**Note**: This is particularly useful for testing and automating apps you're building with G3, as you can add accessibility identifiers to your UI elements.
## Computer Control (Experimental) ## Computer Control (Experimental)
G3 can interact with your computer's GUI for automation tasks: Enable in config:
```toml
[computer_control]
enabled = true
require_confirmation = true
```
Grant accessibility permissions:
- **macOS**: System Preferences → Security & Privacy → Accessibility
- **Linux**: Ensure X11 or Wayland access
- **Windows**: Run as administrator (first time)
**Available Tools**: `mouse_click`, `type_text`, `find_element`, `take_screenshot`, `extract_text`, `find_text_on_screen`, `list_windows` **Available Tools**: `mouse_click`, `type_text`, `find_element`, `take_screenshot`, `extract_text`, `find_text_on_screen`, `list_windows`
**Setup**: Enable in config with `computer_control.enabled = true` and grant OS accessibility permissions: ## Use Cases
- **macOS**: System Preferences → Security & Privacy → Accessibility
- **Linux**: Ensure X11 or Wayland access - Automated code generation and refactoring
- **Windows**: Run as administrator (first time only) - File manipulation and project scaffolding
- System administration tasks
- Data processing and transformation
- API integration and testing
- Documentation generation
- Complex multi-step workflows
- Desktop application automation
## Session Logs
G3 automatically saves session logs for each interaction in the `logs/` directory. These logs contain: G3 automatically saves session logs to `logs/` directory:
- Complete conversation history - Complete conversation history
- Token usage statistics - Token usage statistics
- Timestamps and session status - Timestamps and session status
The `logs/` directory is created automatically on first use and is excluded from version control. Disable with `--quiet` flag.
## Technology Stack
- **Language**: Rust (2021 edition)
- **Async Runtime**: Tokio
- **HTTP Client**: Reqwest
- **Serialization**: Serde
- **CLI Framework**: Clap
- **Logging**: Tracing
- **Local Models**: llama.cpp with Metal acceleration
## License ## License
@@ -206,4 +252,4 @@ MIT License - see LICENSE file for details
## Contributing
G3 is an open-source project. Contributions are welcome! Please see CONTRIBUTING.md for guidelines. Contributions welcome! Please see CONTRIBUTING.md for guidelines.

View File

@@ -1,24 +0,0 @@
[providers]
default_provider = "databricks"
# Specify different providers for coach and player in autonomous mode
coach = "databricks" # Provider for coach (code reviewer) - can be more powerful/expensive
player = "anthropic" # Provider for player (code implementer) - can be faster/cheaper
[providers.databricks]
host = "https://your-workspace.cloud.databricks.com"
# token = "your-databricks-token" # Optional - will use OAuth if not provided
model = "databricks-claude-sonnet-4"
max_tokens = 4096
temperature = 0.1
use_oauth = true
[providers.anthropic]
api_key = "your-anthropic-api-key"
model = "claude-3-haiku-20240307" # Using a faster model for player
max_tokens = 4096
temperature = 0.3 # Slightly higher temperature for more creative implementations
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60

View File

@@ -1,10 +1,5 @@
[providers] [providers]
default_provider = "databricks" default_provider = "databricks"
# Optional: Specify different providers for coach and player in autonomous mode
# If not specified, will use default_provider for both
# coach = "databricks" # Provider for coach (code reviewer)
# player = "anthropic" # Provider for player (code implementer)
# Note: Make sure the specified providers are configured below
[providers.databricks] [providers.databricks]
host = "https://your-workspace.cloud.databricks.com" host = "https://your-workspace.cloud.databricks.com"

File diff suppressed because it is too large Load Diff

View File

@@ -1,94 +0,0 @@
use g3_core::ui_writer::UiWriter;
use std::io::{self, Write};
/// Machine-mode implementation of UiWriter that prints plain, unformatted output
/// This is designed for programmatic consumption and outputs everything verbatim
pub struct MachineUiWriter;
impl MachineUiWriter {
pub fn new() -> Self {
Self
}
}
impl UiWriter for MachineUiWriter {
fn print(&self, message: &str) {
print!("{}", message);
}
fn println(&self, message: &str) {
println!("{}", message);
}
fn print_inline(&self, message: &str) {
print!("{}", message);
let _ = io::stdout().flush();
}
fn print_system_prompt(&self, prompt: &str) {
println!("SYSTEM_PROMPT:");
println!("{}", prompt);
println!("END_SYSTEM_PROMPT");
println!();
}
fn print_context_status(&self, message: &str) {
println!("CONTEXT_STATUS: {}", message);
}
fn print_context_thinning(&self, message: &str) {
println!("CONTEXT_THINNING: {}", message);
}
fn print_tool_header(&self, tool_name: &str) {
println!("TOOL_CALL: {}", tool_name);
}
fn print_tool_arg(&self, key: &str, value: &str) {
println!("TOOL_ARG: {} = {}", key, value);
}
fn print_tool_output_header(&self) {
println!("TOOL_OUTPUT:");
}
fn update_tool_output_line(&self, line: &str) {
println!("{}", line);
}
fn print_tool_output_line(&self, line: &str) {
println!("{}", line);
}
fn print_tool_output_summary(&self, count: usize) {
println!("TOOL_OUTPUT_LINES: {}", count);
}
fn print_tool_timing(&self, duration_str: &str) {
println!("TOOL_DURATION: {}", duration_str);
println!("END_TOOL_OUTPUT");
println!();
}
fn print_agent_prompt(&self) {
println!("AGENT_RESPONSE:");
let _ = io::stdout().flush();
}
fn print_agent_response(&self, content: &str) {
print!("{}", content);
let _ = io::stdout().flush();
}
fn notify_sse_received(&self) {
// No-op for machine mode
}
fn flush(&self) {
let _ = io::stdout().flush();
}
fn wants_full_output(&self) -> bool {
true // Machine mode wants complete, untruncated output
}
}

View File

@@ -267,23 +267,23 @@ impl TerminalState {
let mut current_text = String::new(); let mut current_text = String::new();
// Check for headers first // Check for headers first
if let Some(stripped) = line.strip_prefix("### ") { if line.starts_with("### ") {
return Line::from(Span::styled( return Line::from(Span::styled(
format!(" {}", stripped), format!(" {}", &line[4..]),
Style::default() Style::default()
.fg(self.theme.terminal_cyan.to_color()) .fg(self.theme.terminal_cyan.to_color())
.add_modifier(Modifier::BOLD | Modifier::UNDERLINED), .add_modifier(Modifier::BOLD | Modifier::UNDERLINED),
)); ));
} else if let Some(stripped) = line.strip_prefix("## ") { } else if line.starts_with("## ") {
return Line::from(Span::styled( return Line::from(Span::styled(
format!(" {}", stripped), format!(" {}", &line[3..]),
Style::default() Style::default()
.fg(self.theme.terminal_amber.to_color()) .fg(self.theme.terminal_amber.to_color())
.add_modifier(Modifier::BOLD), .add_modifier(Modifier::BOLD),
)); ));
} else if let Some(stripped) = line.strip_prefix("# ") { } else if line.starts_with("# ") {
return Line::from(Span::styled( return Line::from(Span::styled(
format!(" {}", stripped), format!(" {}", &line[2..]),
Style::default() Style::default()
.fg(self.theme.terminal_green.to_color()) .fg(self.theme.terminal_green.to_color())
.add_modifier(Modifier::BOLD), .add_modifier(Modifier::BOLD),
@@ -343,7 +343,7 @@ impl TerminalState {
} }
// Find closing * // Find closing *
let mut italic_text = String::new(); let mut italic_text = String::new();
for ch in chars.by_ref() { while let Some(ch) = chars.next() {
if ch == '*' { if ch == '*' {
break; break;
} }
@@ -367,7 +367,7 @@ impl TerminalState {
} }
// Find closing ` // Find closing `
let mut code_text = String::new(); let mut code_text = String::new();
for ch in chars.by_ref() { while let Some(ch) = chars.next() {
if ch == '`' { if ch == '`' {
break; break;
} }
@@ -612,9 +612,11 @@ impl RetroTui {
} }
// Update status blink only if status is "PROCESSING" // Update status blink only if status is "PROCESSING"
if state.status_line == "PROCESSING" && state.last_status_blink.elapsed() > Duration::from_millis(500) { if state.status_line == "PROCESSING" {
state.status_blink = !state.status_blink; if state.last_status_blink.elapsed() > Duration::from_millis(500) {
state.last_status_blink = Instant::now(); state.status_blink = !state.status_blink;
state.last_status_blink = Instant::now();
}
} }
// Update activity area animation // Update activity area animation
@@ -769,7 +771,12 @@ impl RetroTui {
let total_cursor_pos = cursor_position; let total_cursor_pos = cursor_position;
// Determine the window into the buffer we should show // Determine the window into the buffer we should show
let window_start = total_cursor_pos.saturating_sub(available_width - 1); let window_start = if total_cursor_pos > available_width - 1 {
// Cursor is beyond the visible area, scroll the view
total_cursor_pos - (available_width - 1)
} else {
0
};
// Get the visible portion of the buffer // Get the visible portion of the buffer
let visible_buffer: String = input_buffer let visible_buffer: String = input_buffer
@@ -1006,9 +1013,9 @@ impl RetroTui {
let fade_color = |color: Color| -> Color { let fade_color = |color: Color| -> Color {
match color { match color {
Color::Rgb(r, g, b) => { Color::Rgb(r, g, b) => {
let faded_r = (r as f32 * opacity) as u8; let faded_r = ((r as f32 * opacity) as u8).max(0);
let faded_g = (g as f32 * opacity) as u8; let faded_g = ((g as f32 * opacity) as u8).max(0);
let faded_b = (b as f32 * opacity) as u8; let faded_b = ((b as f32 * opacity) as u8).max(0);
Color::Rgb(faded_r, faded_g, faded_b) Color::Rgb(faded_r, faded_g, faded_b)
} }
_ => color, _ => color,
@@ -1091,9 +1098,9 @@ impl RetroTui {
let fade_color = |color: Color| -> Color { let fade_color = |color: Color| -> Color {
match color { match color {
Color::Rgb(r, g, b) => { Color::Rgb(r, g, b) => {
let faded_r = (r as f32 * opacity) as u8; let faded_r = ((r as f32 * opacity) as u8).max(0);
let faded_g = (g as f32 * opacity) as u8; let faded_g = ((g as f32 * opacity) as u8).max(0);
let faded_b = (b as f32 * opacity) as u8; let faded_b = ((b as f32 * opacity) as u8).max(0);
Color::Rgb(faded_r, faded_g, faded_b) Color::Rgb(faded_r, faded_g, faded_b)
} }
_ => color, _ => color,
@@ -1169,7 +1176,7 @@ impl RetroTui {
} }
// Wave characters for smooth animation // Wave characters for smooth animation
let wave_chars = ['▁', '▂', '▃', '▄', '▅', '▆', '▇', '█']; let wave_chars = vec!['▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'];
// Build the wave line // Build the wave line
let mut wave_line = String::new(); let mut wave_line = String::new();
@@ -1183,7 +1190,7 @@ impl RetroTui {
let idx = wave_data.len().saturating_sub(display_width) + i; let idx = wave_data.len().saturating_sub(display_width) + i;
if idx < wave_data.len() { if idx < wave_data.len() {
let value = wave_data[idx].clamp(0.0, 1.0); let value = wave_data[idx].min(1.0).max(0.0);
let char_idx = ((value * 7.0) as usize).min(7); let char_idx = ((value * 7.0) as usize).min(7);
wave_line.push(wave_chars[char_idx]); wave_line.push(wave_chars[char_idx]);
} else { } else {
@@ -1199,6 +1206,8 @@ impl RetroTui {
f.render_widget(wave_paragraph, area); f.render_widget(wave_paragraph, area);
} }
/// Draw the status bar
/// Draw the status bar /// Draw the status bar
fn draw_status_bar( fn draw_status_bar(
f: &mut Frame, f: &mut Frame,

View File

@@ -1,32 +0,0 @@
/// Simple output helper for printing messages
pub struct SimpleOutput {
machine_mode: bool,
}
impl SimpleOutput {
pub fn new() -> Self {
SimpleOutput { machine_mode: false }
}
pub fn new_with_mode(machine_mode: bool) -> Self {
SimpleOutput { machine_mode }
}
pub fn print(&self, message: &str) {
if !self.machine_mode {
println!("{}", message);
}
}
pub fn print_smart(&self, message: &str) {
if !self.machine_mode {
println!("{}", message);
}
}
}
impl Default for SimpleOutput {
fn default() -> Self {
Self::new()
}
}

View File

@@ -1,6 +1,5 @@
use crossterm::style::Color; use crossterm::style::Color;
use crossterm::style::{SetForegroundColor, ResetColor}; use crossterm::style::{SetForegroundColor, ResetColor};
use std::io::{self, Write};
use termimad::MadSkin; use termimad::MadSkin;
/// Simple output handler with markdown support /// Simple output handler with markdown support
@@ -41,7 +40,7 @@ impl SimpleOutput {
trimmed.starts_with("* ") || trimmed.starts_with("* ") ||
trimmed.starts_with("+ ") || trimmed.starts_with("+ ") ||
(trimmed.len() > 2 && (trimmed.len() > 2 &&
trimmed.chars().next().is_some_and(|c| c.is_ascii_digit()) && trimmed.chars().next().map_or(false, |c| c.is_ascii_digit()) &&
trimmed.chars().nth(1) == Some('.') && trimmed.chars().nth(1) == Some('.') &&
trimmed.chars().nth(2) == Some(' ')) || trimmed.chars().nth(2) == Some(' ')) ||
(trimmed.contains('[') && trimmed.contains("](")) (trimmed.contains('[') && trimmed.contains("]("))
@@ -94,37 +93,6 @@ impl SimpleOutput {
print!("{}", ResetColor); print!("{}", ResetColor);
println!(" {:.1}% | {}/{} tokens", percentage, used, total); println!(" {:.1}% | {}/{} tokens", percentage, used, total);
} }
pub fn print_context_thinning(&self, message: &str) {
// Animated highlight for context thinning
// Use bright cyan/green with a quick flash animation
// Flash animation: print with bright background, then normal
let frames = vec![
"\x1b[1;97;46m", // Frame 1: Bold white on cyan background
"\x1b[1;97;42m", // Frame 2: Bold white on green background
"\x1b[1;96;40m", // Frame 3: Bold cyan on black background
];
println!();
// Quick flash animation
for frame in &frames {
print!("\r{}{}\x1b[0m", frame, message);
let _ = io::stdout().flush();
std::thread::sleep(std::time::Duration::from_millis(80));
}
// Final display with bright cyan and sparkle emojis
print!("\r\x1b[1;96m✨ {}\x1b[0m", message);
println!();
// Add a subtle "success" indicator line
println!("\x1b[2;36m └─ Context optimized successfully\x1b[0m");
println!();
let _ = io::stdout().flush();
}
} }
#[cfg(test)] #[cfg(test)]

View File

@@ -1,6 +1,8 @@
use crate::retro_tui::RetroTui;
use g3_core::ui_writer::UiWriter; use g3_core::ui_writer::UiWriter;
use std::io::{self, Write}; use std::io::{self, Write};
use std::sync::Mutex; use std::sync::Mutex;
use std::time::Instant;
/// Console implementation of UiWriter that prints to stdout /// Console implementation of UiWriter that prints to stdout
pub struct ConsoleUiWriter { pub struct ConsoleUiWriter {
@@ -102,37 +104,6 @@ impl UiWriter for ConsoleUiWriter {
println!("{}", message); println!("{}", message);
} }
fn print_context_thinning(&self, message: &str) {
// Animated highlight for context thinning
// Use bright cyan/green with a quick flash animation
// Flash animation: print with bright background, then normal
let frames = vec![
"\x1b[1;97;46m", // Frame 1: Bold white on cyan background
"\x1b[1;97;42m", // Frame 2: Bold white on green background
"\x1b[1;96;40m", // Frame 3: Bold cyan on black background
];
println!();
// Quick flash animation
for frame in &frames {
print!("\r{}{}\x1b[0m", frame, message);
let _ = io::stdout().flush();
std::thread::sleep(std::time::Duration::from_millis(80));
}
// Final display with bright cyan and sparkle emojis
print!("\r\x1b[1;96m✨ {}\x1b[0m", message);
println!();
// Add a subtle "success" indicator line
println!("\x1b[2;36m └─ Context optimized successfully\x1b[0m");
println!();
let _ = io::stdout().flush();
}
fn print_tool_header(&self, tool_name: &str) { fn print_tool_header(&self, tool_name: &str) {
// Store the tool name and clear args for collection // Store the tool name and clear args for collection
*self.current_tool_name.lock().unwrap() = Some(tool_name.to_string()); *self.current_tool_name.lock().unwrap() = Some(tool_name.to_string());
@@ -144,6 +115,7 @@ impl UiWriter for ConsoleUiWriter {
// For todo tools, we'll skip the normal header and print a custom one later // For todo tools, we'll skip the normal header and print a custom one later
if is_todo { if is_todo {
return;
} }
} }
@@ -191,12 +163,7 @@ impl UiWriter for ConsoleUiWriter {
// Truncate long values for display // Truncate long values for display
let display_value = if first_line.len() > 80 { let display_value = if first_line.len() > 80 {
// Use char_indices to safely truncate at character boundary format!("{}...", &first_line[..77])
let truncate_at = first_line.char_indices()
.nth(77)
.map(|(i, _)| i)
.unwrap_or(first_line.len());
format!("{}...", &first_line[..truncate_at])
} else { } else {
first_line.to_string() first_line.to_string()
}; };
@@ -345,3 +312,223 @@ impl UiWriter for ConsoleUiWriter {
} }
} }
/// RetroTui implementation of UiWriter that sends output to the TUI
pub struct RetroTuiWriter {
tui: RetroTui,
current_tool_name: Mutex<Option<String>>,
current_tool_output: Mutex<Vec<String>>,
current_tool_start: Mutex<Option<Instant>>,
current_tool_caption: Mutex<String>,
}
impl RetroTuiWriter {
pub fn new(tui: RetroTui) -> Self {
Self {
tui,
current_tool_name: Mutex::new(None),
current_tool_output: Mutex::new(Vec::new()),
current_tool_start: Mutex::new(None),
current_tool_caption: Mutex::new(String::new()),
}
}
}
impl UiWriter for RetroTuiWriter {
fn print(&self, message: &str) {
self.tui.output(message);
}
fn println(&self, message: &str) {
self.tui.output(message);
}
fn print_inline(&self, message: &str) {
// For inline printing, we'll just append to the output
self.tui.output(message);
}
fn print_system_prompt(&self, prompt: &str) {
self.tui.output("🔍 System Prompt:");
self.tui.output("================");
for line in prompt.lines() {
self.tui.output(line);
}
self.tui.output("================");
self.tui.output("");
}
fn print_context_status(&self, message: &str) {
self.tui.output(message);
}
fn print_tool_header(&self, tool_name: &str) {
// Start collecting tool output
*self.current_tool_start.lock().unwrap() = Some(Instant::now());
*self.current_tool_name.lock().unwrap() = Some(tool_name.to_string());
self.current_tool_output.lock().unwrap().clear();
self.current_tool_output
.lock()
.unwrap()
.push(format!("Tool: {}", tool_name));
// Initialize caption
*self.current_tool_caption.lock().unwrap() = String::new();
}
fn print_tool_arg(&self, key: &str, value: &str) {
// Filter out any keys that look like they might be agent message content
// (e.g., keys that are suspiciously long or contain message-like content)
let is_valid_arg_key = key.len() < 50
&& !key.contains('\n')
&& !key.contains("I'll")
&& !key.contains("Let me")
&& !key.contains("Here's")
&& !key.contains("I can");
if is_valid_arg_key {
self.current_tool_output
.lock()
.unwrap()
.push(format!("{}: {}", key, value));
}
// Build caption from first argument (usually the most important one)
let mut caption = self.current_tool_caption.lock().unwrap();
if caption.is_empty() && (key == "file_path" || key == "command" || key == "path") {
// Truncate long values for the caption
let truncated = if value.len() > 50 {
format!("{}...", &value[..47])
} else {
value.to_string()
};
// Add range information for read_file tool calls
let tool_name = self.current_tool_name.lock().unwrap();
let range_suffix = if tool_name.as_ref().map_or(false, |name| name == "read_file") {
// We need to check if start/end args will be provided - for now just check if this is a partial read
// This is a simplified approach since we're building the caption incrementally
String::new() // We'll handle this in print_tool_output_header instead
} else {
String::new()
};
*caption = format!("{}{}", truncated, range_suffix);
}
}
fn print_tool_output_header(&self) {
// This is called right before tool execution starts
// Send the initial tool header to the TUI now
if let Some(tool_name) = self.current_tool_name.lock().unwrap().as_ref() {
let mut caption = self.current_tool_caption.lock().unwrap().clone();
// Add range information for read_file tool calls
if tool_name == "read_file" {
// Check the tool output for start/end parameters
let output = self.current_tool_output.lock().unwrap();
let has_start = output.iter().any(|line| line.starts_with("start:"));
let has_end = output.iter().any(|line| line.starts_with("end:"));
if has_start || has_end {
let start_val = output.iter().find(|line| line.starts_with("start:")).map(|line| line.split(':').nth(1).unwrap_or("0").trim()).unwrap_or("0");
let end_val = output.iter().find(|line| line.starts_with("end:")).map(|line| line.split(':').nth(1).unwrap_or("end").trim()).unwrap_or("end");
caption = format!("{} [{}..{}]", caption, start_val, end_val);
}
}
// Send the tool output with initial header
self.tui.tool_output(tool_name, &caption, "");
}
self.current_tool_output.lock().unwrap().push(String::new());
self.current_tool_output
.lock()
.unwrap()
.push("Output:".to_string());
}
fn update_tool_output_line(&self, line: &str) {
// For retro mode, we'll just add to the output buffer
self.current_tool_output
.lock()
.unwrap()
.push(line.to_string());
}
fn print_tool_output_line(&self, line: &str) {
self.current_tool_output
.lock()
.unwrap()
.push(line.to_string());
}
fn print_tool_output_summary(&self, hidden_count: usize) {
self.current_tool_output.lock().unwrap().push(format!(
"... ({} more line{})",
hidden_count,
if hidden_count == 1 { "" } else { "s" }
));
}
fn print_tool_timing(&self, duration_str: &str) {
self.current_tool_output
.lock()
.unwrap()
.push(format!("⚡️ {}", duration_str));
// Calculate the actual duration
let duration_ms = if let Some(start) = *self.current_tool_start.lock().unwrap() {
start.elapsed().as_millis()
} else {
0
};
// Get the tool name and caption
if let Some(tool_name) = self.current_tool_name.lock().unwrap().as_ref() {
let content = self.current_tool_output.lock().unwrap().join("\n");
let caption = self.current_tool_caption.lock().unwrap().clone();
let caption = if caption.is_empty() {
"Completed".to_string()
} else {
caption
};
// Update the tool detail panel with the complete output without adding a new header
// This keeps the original header in place to be updated by tool_complete
self.tui.update_tool_detail(tool_name, &content);
// Determine success based on whether there's an error in the output
// This is a simple heuristic - you might want to make this more sophisticated
let success = !content.contains("error")
&& !content.contains("Error")
&& !content.contains("ERROR");
// Send the completion status to update the header
self.tui
.tool_complete(tool_name, success, duration_ms, &caption);
}
// Clear the buffers
*self.current_tool_name.lock().unwrap() = None;
self.current_tool_output.lock().unwrap().clear();
*self.current_tool_start.lock().unwrap() = None;
*self.current_tool_caption.lock().unwrap() = String::new();
}
fn print_agent_prompt(&self) {
self.tui.output("\n💬 ");
}
fn print_agent_response(&self, content: &str) {
self.tui.output(content);
}
fn notify_sse_received(&self) {
// Notify the TUI that an SSE was received
self.tui.sse_received();
}
fn flush(&self) {
// No-op for TUI since it handles its own rendering
}
}

View File

@@ -3,9 +3,6 @@ name = "g3-computer-control"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
[build-dependencies]
# Only needed for building Swift bridge on macOS
[dependencies] [dependencies]
# Workspace dependencies # Workspace dependencies
tokio = { workspace = true } tokio = { workspace = true }
@@ -23,13 +20,15 @@ async-trait = "0.1"
# WebDriver support # WebDriver support
fantoccini = "0.21" fantoccini = "0.21"
# OCR dependencies
tesseract = "0.14"
# macOS dependencies # macOS dependencies
[target.'cfg(target_os = "macos")'.dependencies] [target.'cfg(target_os = "macos")'.dependencies]
core-graphics = "0.23" core-graphics = "0.23"
core-foundation = "0.10" core-foundation = "0.9"
cocoa = "0.25" cocoa = "0.25"
objc = "0.2" objc = "0.2"
accessibility = "0.2"
image = "0.24" image = "0.24"
# Linux dependencies # Linux dependencies

View File

@@ -1,63 +0,0 @@
use std::env;
use std::path::PathBuf;
use std::process::Command;

/// Build script: compiles the VisionBridge Swift package on macOS and wires
/// the resulting dylib into the crate's link and runtime search paths.
fn main() {
    // The Vision bridge is macOS-only; other targets need no native build step.
    if env::var("CARGO_CFG_TARGET_OS").unwrap() != "macos" {
        return;
    }

    // Re-run this script whenever the Swift sources or package manifest change.
    for tracked in [
        "vision-bridge/Sources/VisionBridge/VisionOCR.swift",
        "vision-bridge/Sources/VisionBridge/VisionBridge.h",
        "vision-bridge/Package.swift",
    ] {
        println!("cargo:rerun-if-changed={}", tracked);
    }

    let crate_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap());
    let bridge_dir = crate_dir.join("vision-bridge");

    // Compile the Swift package in release mode.
    println!("cargo:warning=Building VisionBridge Swift package...");
    let swift_build = Command::new("swift")
        .args(&["build", "-c", "release"])
        .current_dir(&bridge_dir)
        .status()
        .expect("Failed to build Swift package");
    if !swift_build.success() {
        panic!("Swift build failed");
    }

    // Locate the compiled Swift products.
    let release_dir = bridge_dir
        .join(".build/release")
        .canonicalize()
        .expect("Failed to find .build/release directory");

    // Copy the dylib next to the final binaries so it can be found at runtime.
    let workspace_target = crate_dir.parent().unwrap().parent().unwrap().join("target");
    let build_profile = env::var("PROFILE").unwrap_or_else(|_| "debug".to_string());
    let artifact_dir = workspace_target.join(&build_profile);
    let src = release_dir.join("libVisionBridge.dylib");
    let dst = artifact_dir.join("libVisionBridge.dylib");
    std::fs::copy(&src, &dst)
        .expect(&format!("Failed to copy dylib from {} to {}", src.display(), dst.display()));
    println!("cargo:warning=Copied libVisionBridge.dylib to {}", dst.display());

    // Embed rpaths so the dylib resolves relative to the executable at runtime.
    println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path");
    println!("cargo:rustc-link-arg=-Wl,-rpath,@loader_path");
    println!("cargo:rustc-link-search=native={}", release_dir.display());
    println!("cargo:rustc-link-lib=dylib=VisionBridge");

    // System frameworks required by the bridge.
    for framework in ["Vision", "AppKit", "Foundation", "CoreGraphics", "CoreImage"] {
        println!("cargo:rustc-link-lib=framework={}", framework);
    }

    println!("cargo:warning=VisionBridge built successfully at {}", release_dir.display());
}

View File

@@ -1,7 +1,7 @@
use core_graphics::window::{kCGWindowListOptionOnScreenOnly, kCGNullWindowID, CGWindowListCopyWindowInfo}; use core_graphics::window::{kCGWindowListOptionOnScreenOnly, kCGNullWindowID, CGWindowListCopyWindowInfo};
use core_foundation::dictionary::CFDictionary; use core_foundation::dictionary::CFDictionary;
use core_foundation::string::CFString; use core_foundation::string::CFString;
use core_foundation::base::{TCFType, ToVoid}; use core_foundation::base::TCFType;
fn main() { fn main() {
println!("Listing all on-screen windows..."); println!("Listing all on-screen windows...");
@@ -22,7 +22,7 @@ fn main() {
// Get window ID // Get window ID
let window_id_key = CFString::from_static_string("kCGWindowNumber"); let window_id_key = CFString::from_static_string("kCGWindowNumber");
let window_id: i64 = if let Some(value) = dict.find(window_id_key.to_void()) { let window_id: i64 = if let Some(value) = dict.find(window_id_key.as_concrete_TypeRef()) {
let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _); let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _);
num.to_i64().unwrap_or(0) num.to_i64().unwrap_or(0)
} else { } else {
@@ -31,7 +31,7 @@ fn main() {
// Get owner name // Get owner name
let owner_key = CFString::from_static_string("kCGWindowOwnerName"); let owner_key = CFString::from_static_string("kCGWindowOwnerName");
let owner: String = if let Some(value) = dict.find(owner_key.to_void()) { let owner: String = if let Some(value) = dict.find(owner_key.as_concrete_TypeRef()) {
let s: CFString = TCFType::wrap_under_get_rule(*value as *const _); let s: CFString = TCFType::wrap_under_get_rule(*value as *const _);
s.to_string() s.to_string()
} else { } else {
@@ -40,15 +40,15 @@ fn main() {
// Get window name/title // Get window name/title
let name_key = CFString::from_static_string("kCGWindowName"); let name_key = CFString::from_static_string("kCGWindowName");
let title: String = if let Some(value) = dict.find(name_key.to_void()) { let title: String = if let Some(value) = dict.find(name_key.as_concrete_TypeRef()) {
let s: CFString = TCFType::wrap_under_get_rule(*value as *const _); let s: CFString = TCFType::wrap_under_get_rule(*value as *const _);
s.to_string() s.to_string()
} else { } else {
"".to_string() "".to_string()
}; };
// Show all windows // Filter for iTerm or show all
if !owner.is_empty() { if owner.contains("iTerm") || owner.contains("Terminal") {
println!("{:<10} {:<25} {}", window_id, owner, title); println!("{:<10} {:<25} {}", window_id, owner, title);
} }
} }

View File

@@ -1,74 +0,0 @@
//! Example demonstrating macOS Accessibility API tools
//!
//! This example shows how to use the macax tools to control macOS applications.
//!
//! Run with: cargo run --example macax_demo
use anyhow::Result;
use g3_computer_control::MacAxController;
#[tokio::main]
async fn main() -> Result<()> {
println!("🍎 macOS Accessibility API Demo\n");
println!("This demo shows how to control macOS applications using the Accessibility API.\n");
// Create controller.
// NOTE(review): construction appears to succeed even without accessibility
// permissions granted; individual calls below may still fail — confirm.
let controller = MacAxController::new()?;
println!("✅ MacAxController initialized\n");
// List running applications (regular, foreground apps only) and print the
// first ten to keep the output short.
println!("📱 Listing running applications:");
match controller.list_applications() {
Ok(apps) => {
for app in apps.iter().take(10) {
println!(" - {}", app.name);
}
if apps.len() > 10 {
println!(" ... and {} more", apps.len() - 10);
}
}
Err(e) => println!(" ❌ Error: {}", e),
}
println!();
// Get frontmost app — errors are reported, not propagated, so the demo
// continues through the remaining steps.
println!("🎯 Getting frontmost application:");
match controller.get_frontmost_app() {
Ok(app) => println!(" Current: {}", app.name),
Err(e) => println!(" ❌ Error: {}", e),
}
println!();
// Example: Activate Finder and dump a shallow (depth 2) view of its UI tree.
println!("📂 Activating Finder and inspecting UI:");
match controller.activate_app("Finder") {
Ok(_) => {
println!(" ✅ Finder activated");
// Wait a moment for activation to take effect before querying the tree.
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
// Get UI tree; only the first ten lines are shown.
match controller.get_ui_tree("Finder", 2) {
Ok(tree) => {
println!("\n UI Tree:");
for line in tree.lines().take(10) {
println!(" {}", line);
}
}
Err(e) => println!(" ❌ Error getting UI tree: {}", e),
}
}
Err(e) => println!(" ❌ Error: {}", e),
}
println!();
println!("✨ Demo complete!\n");
println!("💡 Tips:");
println!(" - Use --macax flag with g3 to enable these tools");
println!(" - Grant accessibility permissions in System Preferences");
println!(" - Add accessibility identifiers to your apps for easier automation");
println!(" - See docs/macax-tools.md for full documentation\n");
Ok(())
}

View File

@@ -31,7 +31,7 @@ async fn main() -> Result<()> {
// Find an element // Find an element
println!("Finding h1 element..."); println!("Finding h1 element...");
let h1 = driver.find_element("h1").await?; let mut h1 = driver.find_element("h1").await?;
let h1_text = h1.text().await?; let h1_text = h1.text().await?;
println!("H1 text: {}\n", h1_text); println!("H1 text: {}\n", h1_text);

View File

@@ -1,4 +1,4 @@
use g3_computer_control::create_controller; use g3_computer_control::{create_controller, ComputerController};
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {

View File

@@ -1,5 +1,6 @@
use core_graphics::display::CGDisplay; use core_graphics::display::CGDisplay;
use image::{ImageBuffer, RgbaImage}; use image::{ImageBuffer, RgbaImage};
use std::path::Path;
fn main() { fn main() {
let display = CGDisplay::main(); let display = CGDisplay::main();

View File

@@ -1,48 +0,0 @@
//! Test the new type_text functionality
//!
//! Interactive manual test: requires macOS, accessibility permissions, and an
//! open TextEdit document. `type_text` works by saving the clipboard, pasting
//! the text with Cmd+V, and restoring the clipboard afterwards.
use anyhow::Result;
use g3_computer_control::MacAxController;
#[tokio::main]
async fn main() -> Result<()> {
println!("🧪 Testing macax type_text functionality\n");
let controller = MacAxController::new()?;
println!("✅ Controller initialized\n");
// Test 1: Type simple ASCII text. The sleep gives the operator time to
// bring TextEdit to a usable state before input starts.
println!("Test 1: Typing simple text into TextEdit");
println!(" Please open TextEdit and create a new document...");
std::thread::sleep(std::time::Duration::from_secs(3));
match controller.type_text("TextEdit", "Hello, World!") {
Ok(_) => println!(" ✅ Successfully typed simple text\n"),
Err(e) => println!(" ❌ Failed: {}\n", e),
}
std::thread::sleep(std::time::Duration::from_secs(1));
// Test 2: Type unicode and emojis — exercises the clipboard-paste path with
// multi-byte characters.
println!("Test 2: Typing unicode and emojis");
match controller.type_text("TextEdit", "\n🌟 Unicode test: café, naïve, 日本語 🎉") {
Ok(_) => println!(" ✅ Successfully typed unicode text\n"),
Err(e) => println!(" ❌ Failed: {}\n", e),
}
std::thread::sleep(std::time::Duration::from_secs(1));
// Test 3: Type special characters that would need escaping in key-event
// based approaches.
println!("Test 3: Typing special characters");
match controller.type_text("TextEdit", "\nSpecial: @#$%^&*()_+-=[]{}|;':,.<>?/") {
Ok(_) => println!(" ✅ Successfully typed special characters\n"),
Err(e) => println!(" ❌ Failed: {}\n", e),
}
println!("\n✨ Tests complete!");
println!("\n💡 Now try with Things3:");
println!(" 1. Open Things3");
println!(" 2. Press Cmd+N to create a new task");
println!(" 3. Run: g3 --macax 'type \"🌟 My awesome task\" into Things'");
Ok(())
}

View File

@@ -1,85 +0,0 @@
use g3_computer_control::ocr::{OCREngine, DefaultOCR};
use anyhow::Result;

/// Smoke test for the platform OCR engine (Apple Vision on macOS).
///
/// Reuses `/tmp/safari_test.png` if present, otherwise captures a fresh
/// screenshot via `screencapture`, then runs OCR and prints the located
/// text elements with basic timing statistics.
#[tokio::main]
async fn main() -> Result<()> {
    println!("🧪 Testing Apple Vision OCR");
    println!("===========================\n");

    // Initialize OCR engine
    println!("📦 Initializing OCR engine...");
    let ocr = DefaultOCR::new()?;
    println!("✅ OCR engine: {}\n", ocr.name());

    // Check if the test image exists; capture a screenshot when it does not.
    let test_image = "/tmp/safari_test.png";
    if !std::path::Path::new(test_image).exists() {
        println!("⚠️ Test image not found: {}", test_image);
        println!(" Creating a screenshot...");
        let status = std::process::Command::new("screencapture")
            .arg("-x")
            .arg("-R")
            .arg("0,0,1200,800")
            .arg(test_image)
            .status()?;
        if !status.success() {
            anyhow::bail!("Failed to create screenshot");
        }
        println!("✅ Screenshot created\n");
    }

    // Run OCR and time the call.
    println!("🔍 Running Apple Vision OCR on {}...", test_image);
    let start = std::time::Instant::now();
    let locations = ocr.extract_text_with_locations(test_image).await?;
    let duration = start.elapsed();
    println!("✅ OCR completed in {:.3}s\n", duration.as_secs_f64());

    // Display results
    println!("📊 Results:");
    println!(" Found {} text elements\n", locations.len());
    if locations.is_empty() {
        println!("⚠️ No text found in image");
    } else {
        println!(" Top 20 results:");
        println!(" {:<4} {:<40} {:<15} {:<12} {:<8}", "#", "Text", "Position", "Size", "Conf");
        println!(" {}", "-".repeat(85));
        for (i, loc) in locations.iter().take(20).enumerate() {
            // Truncate long text for display by characters, not bytes: OCR
            // output is frequently non-ASCII, and the previous byte slice
            // (`&loc.text[..37]`) panicked whenever byte index 37 fell inside
            // a multi-byte UTF-8 character.
            let text = if loc.text.chars().count() > 37 {
                let head: String = loc.text.chars().take(37).collect();
                format!("{}...", head)
            } else {
                loc.text.clone()
            };
            println!(" {:<4} {:<40} ({:>4},{:>4}) {:>4}x{:<4} {:.2}",
                i + 1,
                text,
                loc.x,
                loc.y,
                loc.width,
                loc.height,
                loc.confidence
            );
        }
        if locations.len() > 20 {
            println!("\n ... and {} more", locations.len() - 20);
        }
        // Performance summary. Division is safe: this branch only runs when
        // `locations` is non-empty.
        println!("\n📈 Performance:");
        println!(" OCR Speed: {:.3}s", duration.as_secs_f64());
        println!(" Text elements: {}", locations.len());
        println!(" Avg per element: {:.1}ms", duration.as_millis() as f64 / locations.len() as f64);
    }

    println!("\n✅ Test complete!");
    Ok(())
}

View File

@@ -1,18 +1,10 @@
// Suppress warnings from objc crate macros
#![allow(unexpected_cfgs)]
pub mod types; pub mod types;
pub mod platform; pub mod platform;
pub mod ocr;
pub mod webdriver; pub mod webdriver;
pub mod macax;
// Re-export webdriver types for convenience // Re-export webdriver types for convenience
pub use webdriver::{WebDriverController, WebElement, safari::SafariDriver}; pub use webdriver::{WebDriverController, WebElement, safari::SafariDriver};
// Re-export macax types for convenience
pub use macax::{MacAxController, AXElement, AXApplication};
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use types::*; use types::*;
@@ -23,14 +15,8 @@ pub trait ComputerController: Send + Sync {
async fn take_screenshot(&self, path: &str, region: Option<Rect>, window_id: Option<&str>) -> Result<()>; async fn take_screenshot(&self, path: &str, region: Option<Rect>, window_id: Option<&str>) -> Result<()>;
// OCR operations // OCR operations
async fn extract_text_from_screen(&self, region: Rect, window_id: &str) -> Result<String>; async fn extract_text_from_screen(&self, region: Rect) -> Result<String>;
async fn extract_text_from_image(&self, path: &str) -> Result<String>; async fn extract_text_from_image(&self, path: &str) -> Result<String>;
async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>>;
async fn find_text_in_app(&self, app_name: &str, search_text: &str) -> Result<Option<TextLocation>>;
// Mouse operations
fn move_mouse(&self, x: i32, y: i32) -> Result<()>;
fn click_at(&self, x: i32, y: i32, app_name: Option<&str>) -> Result<()>;
} }
// Platform-specific constructor // Platform-specific constructor

View File

@@ -1,822 +0,0 @@
use super::{AXApplication, AXElement};
use anyhow::{Context, Result};
use std::collections::HashMap;
#[cfg(target_os = "macos")]
use accessibility::{AXUIElement, AXUIElementAttributes, ElementFinder, TreeVisitor, TreeWalker, TreeWalkerFlow};
#[cfg(target_os = "macos")]
use core_foundation::base::TCFType;
#[cfg(target_os = "macos")]
use core_foundation::string::CFString;
/// macOS Accessibility API controller built on the native Accessibility (AX)
/// framework. Gated to macOS; on other targets `new()` bails.
pub struct MacAxController {
// Cache of application name -> AXUIElement so repeated operations on the
// same app do not re-enumerate the running-application list each time.
app_cache: std::sync::Mutex<HashMap<String, AXUIElement>>,
}
impl MacAxController {
pub fn new() -> Result<Self> {
#[cfg(target_os = "macos")]
{
// Check if we have accessibility permissions by trying to get system-wide element
let _system = AXUIElement::system_wide();
Ok(Self {
app_cache: std::sync::Mutex::new(HashMap::new()),
})
}
#[cfg(not(target_os = "macos"))]
{
anyhow::bail!("macOS Accessibility API is only available on macOS")
}
}
/// List all running applications
#[cfg(target_os = "macos")]
pub fn list_applications(&self) -> Result<Vec<AXApplication>> {
let apps = Self::get_running_applications()?;
Ok(apps)
}
#[cfg(not(target_os = "macos"))]
pub fn list_applications(&self) -> Result<Vec<AXApplication>> {
anyhow::bail!("Not supported on this platform")
}
#[cfg(target_os = "macos")]
fn get_running_applications() -> Result<Vec<AXApplication>> {
use cocoa::appkit::NSApplicationActivationPolicy;
use cocoa::base::{id, nil};
use objc::{class, msg_send, sel, sel_impl};
unsafe {
let workspace: id = msg_send![class!(NSWorkspace), sharedWorkspace];
let running_apps: id = msg_send![workspace, runningApplications];
let count: usize = msg_send![running_apps, count];
let mut apps = Vec::new();
for i in 0..count {
let app: id = msg_send![running_apps, objectAtIndex: i];
// Get app name
let localized_name: id = msg_send![app, localizedName];
if localized_name == nil {
continue;
}
let name_ptr: *const i8 = msg_send![localized_name, UTF8String];
let name = if !name_ptr.is_null() {
std::ffi::CStr::from_ptr(name_ptr)
.to_string_lossy()
.to_string()
} else {
continue;
};
// Get bundle ID
let bundle_id_obj: id = msg_send![app, bundleIdentifier];
let bundle_id = if bundle_id_obj != nil {
let bundle_id_ptr: *const i8 = msg_send![bundle_id_obj, UTF8String];
if !bundle_id_ptr.is_null() {
Some(
std::ffi::CStr::from_ptr(bundle_id_ptr)
.to_string_lossy()
.to_string(),
)
} else {
None
}
} else {
None
};
// Get PID
let pid: i32 = msg_send![app, processIdentifier];
// Skip background-only apps
let activation_policy: i64 = msg_send![app, activationPolicy];
if activation_policy == NSApplicationActivationPolicy::NSApplicationActivationPolicyRegular as i64 {
apps.push(AXApplication {
name,
bundle_id,
pid,
});
}
}
Ok(apps)
}
}
/// Get the frontmost (active) application
#[cfg(target_os = "macos")]
pub fn get_frontmost_app(&self) -> Result<AXApplication> {
use cocoa::base::{id, nil};
use objc::{class, msg_send, sel, sel_impl};
unsafe {
let workspace: id = msg_send![class!(NSWorkspace), sharedWorkspace];
let frontmost_app: id = msg_send![workspace, frontmostApplication];
if frontmost_app == nil {
anyhow::bail!("No frontmost application");
}
// Get app name
let localized_name: id = msg_send![frontmost_app, localizedName];
let name_ptr: *const i8 = msg_send![localized_name, UTF8String];
let name = std::ffi::CStr::from_ptr(name_ptr)
.to_string_lossy()
.to_string();
// Get bundle ID
let bundle_id_obj: id = msg_send![frontmost_app, bundleIdentifier];
let bundle_id = if bundle_id_obj != nil {
let bundle_id_ptr: *const i8 = msg_send![bundle_id_obj, UTF8String];
if !bundle_id_ptr.is_null() {
Some(
std::ffi::CStr::from_ptr(bundle_id_ptr)
.to_string_lossy()
.to_string(),
)
} else {
None
}
} else {
None
};
// Get PID
let pid: i32 = msg_send![frontmost_app, processIdentifier];
Ok(AXApplication {
name,
bundle_id,
pid,
})
}
}
#[cfg(not(target_os = "macos"))]
pub fn get_frontmost_app(&self) -> Result<AXApplication> {
anyhow::bail!("Not supported on this platform")
}
/// Get AXUIElement for an application by name or PID
#[cfg(target_os = "macos")]
fn get_app_element(&self, app_name: &str) -> Result<AXUIElement> {
// Check cache first
{
let cache = self.app_cache.lock().unwrap();
if let Some(element) = cache.get(app_name) {
return Ok(element.clone());
}
}
// Find the app by name
let apps = Self::get_running_applications()?;
let app = apps
.iter()
.find(|a| a.name == app_name)
.ok_or_else(|| anyhow::anyhow!("Application '{}' not found", app_name))?;
// Create AXUIElement for the app
let element = AXUIElement::application(app.pid);
// Cache it
{
let mut cache = self.app_cache.lock().unwrap();
cache.insert(app_name.to_string(), element.clone());
}
Ok(element)
}
/// Activate (bring to front) an application
#[cfg(target_os = "macos")]
pub fn activate_app(&self, app_name: &str) -> Result<()> {
use cocoa::base::id;
use objc::{class, msg_send, sel, sel_impl};
// Find the app
let apps = Self::get_running_applications()?;
let app = apps
.iter()
.find(|a| a.name == app_name)
.ok_or_else(|| anyhow::anyhow!("Application '{}' not found", app_name))?;
unsafe {
let workspace: id = msg_send![class!(NSWorkspace), sharedWorkspace];
let running_apps: id = msg_send![workspace, runningApplications];
let count: usize = msg_send![running_apps, count];
for i in 0..count {
let running_app: id = msg_send![running_apps, objectAtIndex: i];
let pid: i32 = msg_send![running_app, processIdentifier];
if pid == app.pid {
let _: bool = msg_send![running_app, activateWithOptions: 0];
return Ok(());
}
}
}
anyhow::bail!("Failed to activate application")
}
#[cfg(not(target_os = "macos"))]
pub fn activate_app(&self, _app_name: &str) -> Result<()> {
anyhow::bail!("Not supported on this platform")
}
/// Get the UI hierarchy of an application
#[cfg(target_os = "macos")]
pub fn get_ui_tree(&self, app_name: &str, max_depth: usize) -> Result<String> {
let app_element = self.get_app_element(app_name)?;
let mut output = format!("Application: {}\n", app_name);
Self::build_ui_tree(&app_element, &mut output, 0, max_depth)?;
Ok(output)
}
#[cfg(not(target_os = "macos"))]
pub fn get_ui_tree(&self, _app_name: &str, _max_depth: usize) -> Result<String> {
anyhow::bail!("Not supported on this platform")
}
#[cfg(target_os = "macos")]
fn build_ui_tree(
element: &AXUIElement,
output: &mut String,
depth: usize,
max_depth: usize,
) -> Result<()> {
if depth >= max_depth {
return Ok(());
}
let indent = " ".repeat(depth);
// Get role
let role = element.role().ok().map(|s| s.to_string())
.unwrap_or_else(|| "Unknown".to_string());
// Get title
let title = element.title().ok()
.map(|s| s.to_string());
// Get identifier
let identifier = element.identifier().ok()
.map(|s| s.to_string());
// Format output
output.push_str(&format!("{}Role: {}", indent, role));
if let Some(t) = title {
output.push_str(&format!(", Title: {}", t));
}
if let Some(id) = identifier {
output.push_str(&format!(", ID: {}", id));
}
output.push('\n');
// Get children
if let Ok(children) = element.children() {
for i in 0..children.len() {
if let Some(child) = children.get(i) {
let _ = Self::build_ui_tree(&child, output, depth + 1, max_depth);
}
}
}
Ok(())
}
/// Find UI elements in an application
#[cfg(target_os = "macos")]
pub fn find_elements(
&self,
app_name: &str,
role: Option<&str>,
title: Option<&str>,
identifier: Option<&str>,
) -> Result<Vec<AXElement>> {
let app_element = self.get_app_element(app_name)?;
let mut found_elements = Vec::new();
let visitor = ElementCollector {
role_filter: role.map(|s| s.to_string()),
title_filter: title.map(|s| s.to_string()),
identifier_filter: identifier.map(|s| s.to_string()),
results: std::cell::RefCell::new(&mut found_elements),
depth: std::cell::Cell::new(0),
};
let walker = TreeWalker::new();
walker.walk(&app_element, &visitor);
Ok(found_elements)
}
#[cfg(not(target_os = "macos"))]
pub fn find_elements(
&self,
_app_name: &str,
_role: Option<&str>,
_title: Option<&str>,
_identifier: Option<&str>,
) -> Result<Vec<AXElement>> {
anyhow::bail!("Not supported on this platform")
}
/// Find a single element (helper for click, set_value, etc.)
#[cfg(target_os = "macos")]
fn find_element(
&self,
app_name: &str,
role: &str,
title: Option<&str>,
identifier: Option<&str>,
) -> Result<AXUIElement> {
let app_element = self.get_app_element(app_name)?;
let role_str = role.to_string();
let title_str = title.map(|s| s.to_string());
let identifier_str = identifier.map(|s| s.to_string());
let finder = ElementFinder::new(
&app_element,
move |element| {
// Check role
let elem_role = element.role()
.ok()
.map(|s| s.to_string());
if let Some(r) = elem_role {
if !r.contains(&role_str) {
return false;
}
} else {
return false;
}
// Check title if specified
if let Some(ref title_filter) = title_str {
let elem_title = element.title()
.ok()
.map(|s| s.to_string());
if let Some(t) = elem_title {
if !t.contains(title_filter) {
return false;
}
} else {
return false;
}
}
// Check identifier if specified
if let Some(ref id_filter) = identifier_str {
let elem_id = element.identifier()
.ok()
.map(|s| s.to_string());
if let Some(id) = elem_id {
if !id.contains(id_filter) {
return false;
}
} else {
return false;
}
}
true
},
Some(std::time::Duration::from_secs(2)),
);
finder.find().context("Element not found")
}
/// Click on a UI element
#[cfg(target_os = "macos")]
pub fn click_element(
&self,
app_name: &str,
role: &str,
title: Option<&str>,
identifier: Option<&str>,
) -> Result<()> {
let element = self.find_element(app_name, role, title, identifier)?;
// Perform the press action
let action_name = CFString::new("AXPress");
element
.perform_action(&action_name)
.map_err(|e| anyhow::anyhow!("Failed to perform press action: {:?}", e))?;
Ok(())
}
#[cfg(not(target_os = "macos"))]
pub fn click_element(
&self,
_app_name: &str,
_role: &str,
_title: Option<&str>,
_identifier: Option<&str>,
) -> Result<()> {
anyhow::bail!("Not supported on this platform")
}
/// Set the value of a UI element
#[cfg(target_os = "macos")]
pub fn set_value(
&self,
app_name: &str,
role: &str,
value: &str,
title: Option<&str>,
identifier: Option<&str>,
) -> Result<()> {
let element = self.find_element(app_name, role, title, identifier)?;
// Set the value - convert CFString to CFType
let cf_value = CFString::new(value);
element.set_value(cf_value.as_CFType())
.map_err(|e| anyhow::anyhow!("Failed to set value: {:?}", e))?;
Ok(())
}
#[cfg(not(target_os = "macos"))]
pub fn set_value(
&self,
_app_name: &str,
_role: &str,
_value: &str,
_title: Option<&str>,
_identifier: Option<&str>,
) -> Result<()> {
anyhow::bail!("Not supported on this platform")
}
/// Get the value of a UI element
#[cfg(target_os = "macos")]
pub fn get_value(
&self,
app_name: &str,
role: &str,
title: Option<&str>,
identifier: Option<&str>,
) -> Result<String> {
let element = self.find_element(app_name, role, title, identifier)?;
// Get the value
let value_type = element.value()
.map_err(|e| anyhow::anyhow!("Failed to get value: {:?}", e))?;
// Try to downcast to CFString
if let Some(cf_string) = value_type.downcast::<CFString>() {
Ok(cf_string.to_string())
} else {
// For non-string values, try to get a description
Ok(format!("<non-string value>"))
}
}
#[cfg(not(target_os = "macos"))]
pub fn get_value(
&self,
_app_name: &str,
_role: &str,
_title: Option<&str>,
_identifier: Option<&str>,
) -> Result<String> {
anyhow::bail!("Not supported on this platform")
}
/// Type text into the currently focused element (uses system text input)
#[cfg(target_os = "macos")]
pub fn type_text(&self, app_name: &str, text: &str) -> Result<()> {
use cocoa::base::{id, nil};
use cocoa::foundation::NSString;
use objc::{class, msg_send, sel, sel_impl};
// First, make sure the app is active
self.activate_app(app_name)?;
// Wait for app to fully activate
std::thread::sleep(std::time::Duration::from_millis(500));
// Send a Tab key to try to focus on a text field
// This helps ensure something is focused before we paste
let _ = self.press_key(app_name, "tab", vec![]);
std::thread::sleep(std::time::Duration::from_millis(800));
// Save old clipboard, set new content, paste, then restore
let old_content: id;
unsafe {
// Get the general pasteboard
let pasteboard: id = msg_send![class!(NSPasteboard), generalPasteboard];
// Save current clipboard content
let ns_string_type = NSString::alloc(nil).init_str("public.utf8-plain-text");
old_content = msg_send![pasteboard, stringForType: ns_string_type];
// Clear and set new content
let _: () = msg_send![pasteboard, clearContents];
let ns_string = NSString::alloc(nil).init_str(text);
let ns_type = NSString::alloc(nil).init_str("public.utf8-plain-text");
let _: bool = msg_send![pasteboard, setString:ns_string forType:ns_type];
}
// Wait a moment for clipboard to update
std::thread::sleep(std::time::Duration::from_millis(200));
// Paste using Cmd+V (outside unsafe block)
self.press_key(app_name, "v", vec!["command"])?;
// Wait for paste to complete
std::thread::sleep(std::time::Duration::from_millis(300));
// Restore old clipboard content if it existed
unsafe {
if old_content != nil {
let pasteboard: id = msg_send![class!(NSPasteboard), generalPasteboard];
let _: () = msg_send![pasteboard, clearContents];
let ns_type = NSString::alloc(nil).init_str("public.utf8-plain-text");
let _: bool = msg_send![pasteboard, setString:old_content forType:ns_type];
}
}
Ok(())
}
#[cfg(not(target_os = "macos"))]
pub fn type_text(&self, _app_name: &str, _text: &str) -> Result<()> {
anyhow::bail!("Not supported on this platform")
}
/// Focus on a text field or text area element
#[cfg(target_os = "macos")]
pub fn focus_element(
&self,
app_name: &str,
role: &str,
title: Option<&str>,
identifier: Option<&str>,
) -> Result<()> {
let element = self.find_element(app_name, role, title, identifier)?;
// Set focused attribute to true
use core_foundation::boolean::CFBoolean;
let cf_true = CFBoolean::true_value();
element.set_attribute(&accessibility::AXAttribute::focused(), cf_true)
.map_err(|e| anyhow::anyhow!("Failed to focus element: {:?}", e))?;
Ok(())
}
/// Press a keyboard shortcut
#[cfg(target_os = "macos")]
pub fn press_key(
&self,
app_name: &str,
key: &str,
modifiers: Vec<&str>,
) -> Result<()> {
use core_graphics::event::{
CGEvent, CGEventFlags, CGEventTapLocation,
};
use core_graphics::event_source::{CGEventSource, CGEventSourceStateID};
// First, make sure the app is active
self.activate_app(app_name)?;
// Wait a bit for activation
std::thread::sleep(std::time::Duration::from_millis(100));
// Map key string to key code
let key_code = Self::key_to_keycode(key)
.ok_or_else(|| anyhow::anyhow!("Unknown key: {}", key))?;
// Map modifiers to flags
let mut flags = CGEventFlags::CGEventFlagNull;
for modifier in modifiers {
match modifier.to_lowercase().as_str() {
"command" | "cmd" => flags |= CGEventFlags::CGEventFlagCommand,
"option" | "alt" => flags |= CGEventFlags::CGEventFlagAlternate,
"control" | "ctrl" => flags |= CGEventFlags::CGEventFlagControl,
"shift" => flags |= CGEventFlags::CGEventFlagShift,
_ => {}
}
}
// Create event source
let source = CGEventSource::new(CGEventSourceStateID::HIDSystemState)
.ok().context("Failed to create event source")?;
// Create key down event
let key_down = CGEvent::new_keyboard_event(source.clone(), key_code, true)
.ok().context("Failed to create key down event")?;
key_down.set_flags(flags);
// Create key up event
let key_up = CGEvent::new_keyboard_event(source, key_code, false)
.ok().context("Failed to create key up event")?;
key_up.set_flags(flags);
// Post events
key_down.post(CGEventTapLocation::HID);
std::thread::sleep(std::time::Duration::from_millis(50));
key_up.post(CGEventTapLocation::HID);
Ok(())
}
#[cfg(not(target_os = "macos"))]
pub fn press_key(
&self,
_app_name: &str,
_key: &str,
_modifiers: Vec<&str>,
) -> Result<()> {
anyhow::bail!("Not supported on this platform")
}
#[cfg(target_os = "macos")]
fn key_to_keycode(key: &str) -> Option<u16> {
// Map common keys to keycodes
// See: https://eastmanreference.com/complete-list-of-applescript-key-codes
match key.to_lowercase().as_str() {
"a" => Some(0x00),
"s" => Some(0x01),
"d" => Some(0x02),
"f" => Some(0x03),
"h" => Some(0x04),
"g" => Some(0x05),
"z" => Some(0x06),
"x" => Some(0x07),
"c" => Some(0x08),
"v" => Some(0x09),
"b" => Some(0x0B),
"q" => Some(0x0C),
"w" => Some(0x0D),
"e" => Some(0x0E),
"r" => Some(0x0F),
"y" => Some(0x10),
"t" => Some(0x11),
"1" => Some(0x12),
"2" => Some(0x13),
"3" => Some(0x14),
"4" => Some(0x15),
"6" => Some(0x16),
"5" => Some(0x17),
"=" => Some(0x18),
"9" => Some(0x19),
"7" => Some(0x1A),
"-" => Some(0x1B),
"8" => Some(0x1C),
"0" => Some(0x1D),
"]" => Some(0x1E),
"o" => Some(0x1F),
"u" => Some(0x20),
"[" => Some(0x21),
"i" => Some(0x22),
"p" => Some(0x23),
"return" | "enter" => Some(0x24),
"l" => Some(0x25),
"j" => Some(0x26),
"'" => Some(0x27),
"k" => Some(0x28),
";" => Some(0x29),
"\\" => Some(0x2A),
"," => Some(0x2B),
"/" => Some(0x2C),
"n" => Some(0x2D),
"m" => Some(0x2E),
"." => Some(0x2F),
"tab" => Some(0x30),
"space" => Some(0x31),
"`" => Some(0x32),
"delete" | "backspace" => Some(0x33),
"escape" | "esc" => Some(0x35),
"f1" => Some(0x7A),
"f2" => Some(0x78),
"f3" => Some(0x63),
"f4" => Some(0x76),
"f5" => Some(0x60),
"f6" => Some(0x61),
"f7" => Some(0x62),
"f8" => Some(0x64),
"f9" => Some(0x65),
"f10" => Some(0x6D),
"f11" => Some(0x67),
"f12" => Some(0x6F),
"left" => Some(0x7B),
"right" => Some(0x7C),
"down" => Some(0x7D),
"up" => Some(0x7E),
_ => None,
}
}
}
#[cfg(target_os = "macos")]
// Tree visitor that collects AXElements matching optional role/title/identifier
// substring filters while walking an application's accessibility tree.
struct ElementCollector<'a> {
// Substring filters; `None` matches every element.
role_filter: Option<String>,
title_filter: Option<String>,
identifier_filter: Option<String>,
// Matching elements are appended here; RefCell is needed because the
// TreeVisitor callbacks take `&self`.
results: std::cell::RefCell<&'a mut Vec<AXElement>>,
// Current walk depth (Cell for interior mutability); the visitor skips
// subtrees deeper than 20 to bound the walk.
depth: std::cell::Cell<usize>,
}
#[cfg(target_os = "macos")]
impl<'a> TreeVisitor for ElementCollector<'a> {
fn enter_element(&self, element: &AXUIElement) -> TreeWalkerFlow {
self.depth.set(self.depth.get() + 1);
if self.depth.get() > 20 {
return TreeWalkerFlow::SkipSubtree;
}
// Get element properties
let role = element.role()
.ok()
.map(|s| s.to_string())
.unwrap_or_else(|| "Unknown".to_string());
let title = element.title()
.ok()
.map(|s| s.to_string());
let identifier = element.identifier()
.ok()
.map(|s| s.to_string());
// Check if this element matches the filters
let role_matches = self.role_filter.as_ref().map_or(true, |r| role.contains(r));
let title_matches = self.title_filter.as_ref().map_or(true, |t| {
title.as_ref().map_or(false, |title_str| title_str.contains(t))
});
let identifier_matches = self.identifier_filter.as_ref().map_or(true, |id| {
identifier.as_ref().map_or(false, |id_str| id_str.contains(id))
});
if role_matches && title_matches && identifier_matches {
// Get additional properties
let value = element.value()
.ok()
.and_then(|v| {
v.downcast::<CFString>().map(|s| s.to_string())
});
let label = element.description()
.ok()
.map(|s| s.to_string());
let enabled = element.enabled()
.ok()
.map(|b| b.into())
.unwrap_or(false);
let focused = element.focused()
.ok()
.map(|b| b.into())
.unwrap_or(false);
// Count children
let children_count = element.children()
.ok()
.map(|arr| arr.len() as usize)
.unwrap_or(0);
self.results.borrow_mut().push(AXElement {
role,
title,
value,
label,
identifier,
enabled,
focused,
position: None,
size: None,
children_count,
});
}
TreeWalkerFlow::Continue
}
fn exit_element(&self, _element: &AXUIElement) {
self.depth.set(self.depth.get() - 1);
}
}

View File

@@ -1,65 +0,0 @@
pub mod controller;
pub use controller::MacAxController;
use serde::{Deserialize, Serialize};
#[cfg(test)]
mod tests;
/// Represents an accessibility element in the UI hierarchy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AXElement {
/// AX role string (e.g. a button, text field, window role).
pub role: String,
/// Element title, if the element exposes one.
pub title: Option<String>,
/// Current value; only string-valued elements are captured here.
pub value: Option<String>,
/// Accessibility description/label, if any.
pub label: Option<String>,
/// Developer-assigned accessibility identifier, if any.
pub identifier: Option<String>,
/// Whether the element is enabled for interaction.
pub enabled: bool,
/// Whether the element currently has keyboard focus.
pub focused: bool,
// Position/size are optional screen coordinates; the collector currently
// leaves them unset (see controller::ElementCollector).
pub position: Option<(f64, f64)>,
pub size: Option<(f64, f64)>,
/// Number of direct children in the accessibility tree.
pub children_count: usize,
}
/// Represents a macOS application.
#[derive(Debug, Clone)]
pub struct AXApplication {
/// Localized application name (e.g. "Finder").
pub name: String,
/// Bundle identifier, when the app has one.
pub bundle_id: Option<String>,
/// Unix process identifier.
pub pid: i32,
}
impl AXElement {
    /// Render the element as one human-readable, comma-separated line,
    /// including only the optional attributes that are present.
    ///
    /// NOTE: this is an inherent `to_string`, which shadows the blanket
    /// `ToString` from `Display`; implementing `fmt::Display` would be the
    /// more idiomatic shape for this API.
    pub fn to_string(&self) -> String {
        let mut out = format!("Role: {}", self.role);
        if let Some(title) = &self.title {
            out.push_str(&format!(", Title: {}", title));
        }
        if let Some(value) = &self.value {
            out.push_str(&format!(", Value: {}", value));
        }
        if let Some(label) = &self.label {
            out.push_str(&format!(", Label: {}", label));
        }
        if let Some(id) = &self.identifier {
            out.push_str(&format!(", ID: {}", id));
        }
        out.push_str(&format!(", Enabled: {}", self.enabled));
        out.push_str(&format!(", Focused: {}", self.focused));
        if let Some((x, y)) = self.position {
            out.push_str(&format!(", Position: ({:.0}, {:.0})", x, y));
        }
        if let Some((w, h)) = self.size {
            out.push_str(&format!(", Size: ({:.0}, {:.0})", w, h));
        }
        out.push_str(&format!(", Children: {}", self.children_count));
        out
    }
}

View File

@@ -1,37 +0,0 @@
#[cfg(test)]
mod tests {
    use crate::{AXElement, MacAxController};

    /// The string rendering must surface every populated field.
    #[test]
    fn test_ax_element_to_string() {
        let el = AXElement {
            role: "button".to_string(),
            title: Some("Click Me".to_string()),
            value: None,
            label: Some("Submit Button".to_string()),
            identifier: Some("submitBtn".to_string()),
            enabled: true,
            focused: false,
            position: Some((100.0, 200.0)),
            size: Some((80.0, 30.0)),
            children_count: 0,
        };

        let rendered = el.to_string();
        // Same checks as before, driven from a table of expected fragments.
        let expected_fragments = [
            "Role: button",
            "Title: Click Me",
            "Label: Submit Button",
            "ID: submitBtn",
            "Enabled: true",
            "Position: (100, 200)",
            "Size: (80, 30)",
        ];
        for fragment in expected_fragments {
            assert!(rendered.contains(fragment));
        }
    }

    /// Construction alone must succeed; real functionality requires
    /// macOS and accessibility permissions.
    #[test]
    fn test_controller_creation() {
        assert!(MacAxController::new().is_ok());
    }
}

View File

@@ -1,26 +0,0 @@
use crate::types::TextLocation;
use anyhow::Result;
use async_trait::async_trait;
/// OCR engine trait for text recognition with bounding boxes.
#[async_trait]
pub trait OCREngine: Send + Sync {
    /// Extract text with locations from an image file.
    ///
    /// `path` is a filesystem path to the image. Returns one
    /// `TextLocation` per recognized fragment; implementations decide
    /// the granularity (word vs line).
    async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>>;
    /// Get the human-readable name of the OCR engine (used in logs).
    fn name(&self) -> &str;
}
// Platform-specific modules
#[cfg(target_os = "macos")]
pub mod vision;
pub mod tesseract;
// Re-export the default OCR engine for the platform
#[cfg(target_os = "macos")]
pub use vision::AppleVisionOCR as DefaultOCR;
#[cfg(not(target_os = "macos"))]
pub use tesseract::TesseractOCR as DefaultOCR;

View File

@@ -1,84 +0,0 @@
use super::OCREngine;
use crate::types::TextLocation;
use anyhow::Result;
use async_trait::async_trait;
/// Tesseract OCR engine (fallback/cross-platform)
pub struct TesseractOCR;

impl TesseractOCR {
    /// Creates the engine after verifying that the `tesseract` binary is
    /// reachable on PATH; bails with installation instructions otherwise.
    pub fn new() -> Result<Self> {
        // `which` exits non-zero when the binary is absent; treat a failed
        // spawn of `which` itself the same as "not installed".
        let installed = std::process::Command::new("which")
            .arg("tesseract")
            .output()
            .map(|out| out.status.success())
            .unwrap_or(false);
        if !installed {
            anyhow::bail!("Tesseract OCR is not installed on your system.\n\n\
                To install tesseract:\n macOS: brew install tesseract\n \
                Linux: sudo apt-get install tesseract-ocr (Ubuntu/Debian)\n \
                sudo yum install tesseract (RHEL/CentOS)\n \
                Windows: Download from https://github.com/UB-Mannheim/tesseract/wiki\n\n\
                After installation, restart your terminal and try again.");
        }
        Ok(Self)
    }
}
#[async_trait]
impl OCREngine for TesseractOCR {
    /// Runs the `tesseract` CLI in TSV mode and parses word-level
    /// bounding boxes out of its tab-separated output.
    async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>> {
        // Use tesseract CLI with TSV output to get bounding boxes.
        let output = std::process::Command::new("tesseract")
            .arg(path)
            .arg("stdout")
            .arg("tsv")
            .output()
            .map_err(|e| anyhow::anyhow!("Failed to run tesseract: {}", e))?;
        if !output.status.success() {
            anyhow::bail!("Tesseract failed: {}", String::from_utf8_lossy(&output.stderr));
        }

        let tsv_text = String::from_utf8_lossy(&output.stdout);
        let mut locations = Vec::new();

        // TSV columns: level, page_num, block_num, par_num, line_num,
        // word_num, left, top, width, height, conf, text.
        // The first line is the header row, so skip it.
        for line in tsv_text.lines().skip(1) {
            let fields: Vec<&str> = line.split('\t').collect();
            if fields.len() < 12 {
                continue;
            }
            let parsed = (
                fields[6].parse::<i32>(),
                fields[7].parse::<i32>(),
                fields[8].parse::<i32>(),
                fields[9].parse::<i32>(),
                fields[10].parse::<f32>(),
            );
            if let (Ok(x), Ok(y), Ok(w), Ok(h), Ok(conf)) = parsed {
                let text = fields[11].trim();
                // Structural rows carry empty text or non-positive confidence.
                if !text.is_empty() && conf > 0.0 {
                    locations.push(TextLocation {
                        text: text.to_string(),
                        x,
                        y,
                        width: w,
                        height: h,
                        // Tesseract reports 0-100; normalize to 0-1.
                        confidence: conf / 100.0,
                    });
                }
            }
        }
        Ok(locations)
    }

    fn name(&self) -> &str {
        "Tesseract OCR"
    }
}

View File

@@ -1,103 +0,0 @@
use super::OCREngine;
use crate::types::TextLocation;
use anyhow::{Result, Context};
use async_trait::async_trait;
use std::ffi::{CStr, CString};
use std::os::raw::{c_char, c_float, c_uint};
// FFI bindings to Swift VisionBridge.
/// C-layout text box as produced by the Swift Vision bridge.
#[repr(C)]
struct VisionTextBox {
    // UTF-8 text owned by the Swift side; may be null (checked before use).
    text: *const c_char,
    // Length of `text` in bytes.
    text_len: c_uint,
    x: i32,
    y: i32,
    width: i32,
    height: i32,
    // Recognition confidence. NOTE(review): the range (0-1 vs 0-100) is
    // determined by the Swift side; the Rust caller passes it through
    // unchanged — confirm against VisionBridge.
    confidence: c_float,
}
extern "C" {
    // Runs Vision text recognition on the image at `image_path`
    // (`image_path_len` bytes). On success writes a heap-allocated array
    // of `VisionTextBox` to `out_boxes` and its length to `out_count`;
    // the caller must release it with `vision_free_boxes`.
    fn vision_recognize_text(
        image_path: *const c_char,
        image_path_len: c_uint,
        out_boxes: *mut *mut std::ffi::c_void,
        out_count: *mut c_uint,
    ) -> bool;
    // Frees an array previously returned through `vision_recognize_text`.
    fn vision_free_boxes(boxes: *mut std::ffi::c_void, count: c_uint);
}
/// Apple Vision Framework OCR engine.
pub struct AppleVisionOCR;

impl AppleVisionOCR {
    /// Creates the engine. No up-front initialization is performed: the
    /// Vision framework is invoked per call through the C bridge.
    pub fn new() -> Result<Self> {
        Ok(Self)
    }
}
#[async_trait]
impl OCREngine for AppleVisionOCR {
    /// Recognizes text in the image at `path` via the Swift Vision bridge,
    /// returning one `TextLocation` per non-empty recognized box.
    async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>> {
        // Convert path to a NUL-terminated C string (fails on interior NULs).
        let c_path = CString::new(path)
            .context("Failed to convert path to C string")?;
        let mut boxes_ptr: *mut std::ffi::c_void = std::ptr::null_mut();
        let mut count: c_uint = 0;
        // Call the Swift Vision API. The length passed is the original
        // path's byte length (the C string also carries a NUL terminator).
        let success = unsafe {
            vision_recognize_text(
                c_path.as_ptr(),
                path.len() as c_uint,
                &mut boxes_ptr,
                &mut count,
            )
        };
        // On failure nothing was allocated, so there is nothing to free here.
        if !success || boxes_ptr.is_null() {
            anyhow::bail!("Apple Vision OCR failed");
        }
        // Convert the C array to a Rust Vec, then release the Swift-owned
        // memory before leaving the unsafe block.
        let mut locations = Vec::new();
        unsafe {
            // Safety: on success the bridge hands back `count` contiguous
            // `VisionTextBox` values behind `boxes_ptr`.
            let typed_boxes = boxes_ptr as *const VisionTextBox;
            let boxes_slice = std::slice::from_raw_parts(typed_boxes, count as usize);
            for box_data in boxes_slice {
                // Copy the C string into an owned Rust String (lossy on
                // invalid UTF-8); null pointers become empty strings.
                let text = if !box_data.text.is_null() {
                    CStr::from_ptr(box_data.text)
                        .to_string_lossy()
                        .into_owned()
                } else {
                    String::new()
                };
                if !text.is_empty() {
                    locations.push(TextLocation {
                        text,
                        x: box_data.x,
                        y: box_data.y,
                        width: box_data.width,
                        height: box_data.height,
                        confidence: box_data.confidence,
                    });
                }
            }
            // Free the C array exactly once, after all reads are done.
            vision_free_boxes(boxes_ptr, count);
        }
        Ok(locations)
    }
    fn name(&self) -> &str {
        "Apple Vision Framework"
    }
}

View File

@@ -63,15 +63,10 @@ impl ComputerController for LinuxController {
} }
async fn take_screenshot(&self, _path: &str, _region: Option<Rect>, _window_id: Option<&str>) -> Result<()> { async fn take_screenshot(&self, _path: &str, _region: Option<Rect>, _window_id: Option<&str>) -> Result<()> {
// Enforce that window_id must be provided
if _window_id.is_none() {
anyhow::bail!("window_id is required. You must specify which window to capture (e.g., 'Firefox', 'Terminal', 'gedit'). Use list_windows to see available windows.");
}
anyhow::bail!("Linux implementation not yet available") anyhow::bail!("Linux implementation not yet available")
} }
async fn extract_text_from_screen(&self, _region: Rect, _window_id: &str) -> Result<String> { async fn extract_text_from_screen(&self, _region: Rect) -> Result<OCRResult> {
anyhow::bail!("Linux implementation not yet available") anyhow::bail!("Linux implementation not yet available")
} }

View File

@@ -1,37 +1,22 @@
use crate::{ComputerController, types::{Rect, TextLocation}}; use crate::{ComputerController, types::Rect};
use crate::ocr::{OCREngine, DefaultOCR}; use anyhow::Result;
use anyhow::{Result, Context};
use async_trait::async_trait; use async_trait::async_trait;
use std::path::Path; use std::path::Path;
use core_graphics::window::{kCGWindowListOptionOnScreenOnly, kCGNullWindowID, CGWindowListCopyWindowInfo}; use tesseract::Tesseract;
use core_foundation::dictionary::CFDictionary;
use core_foundation::string::CFString;
use core_foundation::base::{TCFType, ToVoid};
use core_foundation::array::CFArray;
pub struct MacOSController { pub struct MacOSController {
ocr_engine: Box<dyn OCREngine>, // Empty struct for now
#[allow(dead_code)]
ocr_name: String,
} }
impl MacOSController { impl MacOSController {
pub fn new() -> Result<Self> { pub fn new() -> Result<Self> {
let ocr = Box::new(DefaultOCR::new()?); Ok(Self {})
let ocr_name = ocr.name().to_string();
tracing::info!("Initialized macOS controller with OCR engine: {}", ocr_name);
Ok(Self { ocr_engine: ocr, ocr_name })
} }
} }
#[async_trait] #[async_trait]
impl ComputerController for MacOSController { impl ComputerController for MacOSController {
async fn take_screenshot(&self, path: &str, region: Option<Rect>, window_id: Option<&str>) -> Result<()> { async fn take_screenshot(&self, path: &str, region: Option<Rect>, window_id: Option<&str>) -> Result<()> {
// Enforce that window_id must be provided
if window_id.is_none() {
return Err(anyhow::anyhow!("window_id is required. You must specify which window to capture (e.g., 'Safari', 'Terminal', 'Google Chrome'). Use list_windows to see available windows."));
}
// Determine the temporary directory for screenshots // Determine the temporary directory for screenshots
let temp_dir = std::env::var("TMPDIR") let temp_dir = std::env::var("TMPDIR")
.or_else(|_| std::env::var("HOME").map(|h| format!("{}/tmp", h))) .or_else(|_| std::env::var("HOME").map(|h| format!("{}/tmp", h)))
@@ -52,134 +37,48 @@ impl ComputerController for MacOSController {
std::fs::create_dir_all(parent)?; std::fs::create_dir_all(parent)?;
} }
let app_name = window_id.unwrap(); // Safe because we checked is_none() above
// Get the window ID for the specified application
let cg_window_id = unsafe {
let window_list = CGWindowListCopyWindowInfo(
kCGWindowListOptionOnScreenOnly,
kCGNullWindowID
);
let array = CFArray::<CFDictionary>::wrap_under_create_rule(window_list);
let count = array.len();
let mut found_window_id: Option<(u32, String)> = None; // (id, owner)
let app_name_lower = app_name.to_lowercase();
for i in 0..count {
let dict = array.get(i).unwrap();
// Get owner name
let owner_key = CFString::from_static_string("kCGWindowOwnerName");
let owner: String = if let Some(value) = dict.find(owner_key.to_void()) {
let s: CFString = TCFType::wrap_under_get_rule(*value as *const _);
s.to_string()
} else {
continue;
};
tracing::debug!("Checking window: owner='{}', looking for '{}'", owner, app_name);
let owner_lower = owner.to_lowercase();
// Normalize by removing spaces for exact matching
let app_name_normalized = app_name_lower.replace(" ", "");
let owner_normalized = owner_lower.replace(" ", "");
// ONLY accept exact matches (case-insensitive, with or without spaces)
// This prevents "Goose" from matching "GooseStudio"
let is_match = owner_lower == app_name_lower || owner_normalized == app_name_normalized;
if is_match {
// Get window ID
let window_id_key = CFString::from_static_string("kCGWindowNumber");
if let Some(value) = dict.find(window_id_key.to_void()) {
let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _);
if let Some(id) = num.to_i64() {
// Get window layer to filter out menu bar windows
let layer_key = CFString::from_static_string("kCGWindowLayer");
let layer: i32 = if let Some(value) = dict.find(layer_key.to_void()) {
let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _);
num.to_i32().unwrap_or(0)
} else {
0
};
// Get window bounds to verify it's a real window
let bounds_key = CFString::from_static_string("kCGWindowBounds");
let has_real_bounds = if let Some(value) = dict.find(bounds_key.to_void()) {
let bounds_dict: CFDictionary = TCFType::wrap_under_get_rule(*value as *const _);
let width_key = CFString::from_static_string("Width");
let height_key = CFString::from_static_string("Height");
if let (Some(w_val), Some(h_val)) = (
bounds_dict.find(width_key.to_void()),
bounds_dict.find(height_key.to_void()),
) {
let w_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*w_val as *const _);
let h_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*h_val as *const _);
let width = w_num.to_f64().unwrap_or(0.0);
let height = h_num.to_f64().unwrap_or(0.0);
// Real windows should be at least 100x100 pixels
width >= 100.0 && height >= 100.0
} else {
false
}
} else {
false
};
// Only accept windows that are:
// 1. At layer 0 (normal windows, not menu bar)
// 2. Have real bounds (width and height >= 100)
if layer == 0 && has_real_bounds {
tracing::info!("Found valid window: ID {} for app '{}' (layer={}, bounds valid)", id, owner, layer);
found_window_id = Some((id as u32, owner.clone()));
break;
} else {
tracing::debug!("Skipping window ID {} for '{}': layer={}, has_real_bounds={}", id, owner, layer, has_real_bounds);
}
}
}
}
}
found_window_id
};
let (cg_window_id, matched_owner) = cg_window_id.ok_or_else(|| {
anyhow::anyhow!("Could not find window for application '{}'. Use list_windows to see available windows.", app_name)
})?;
tracing::info!("Taking screenshot of window ID {} for app '{}'", cg_window_id, matched_owner);
// Use screencapture with the window ID for now
// TODO: Implement direct CGWindowListCreateImage approach with proper image saving
let mut cmd = std::process::Command::new("screencapture"); let mut cmd = std::process::Command::new("screencapture");
// Add flags
cmd.arg("-x"); // No sound cmd.arg("-x"); // No sound
cmd.arg("-l");
cmd.arg(cg_window_id.to_string());
if let Some(region) = region { if let Some(region) = region {
// Capture specific region: -R x,y,width,height
cmd.arg("-R"); cmd.arg("-R");
cmd.arg(format!("{},{},{},{}", region.x, region.y, region.width, region.height)); cmd.arg(format!("{},{},{},{}", region.x, region.y, region.width, region.height));
} }
if let Some(app_name) = window_id {
// Capture specific window by app name
// Use AppleScript to get window ID
let script = format!(r#"tell application "{}" to id of window 1"#, app_name);
let output = std::process::Command::new("osascript")
.arg("-e")
.arg(&script)
.output()?;
if output.status.success() {
let window_id_str = String::from_utf8_lossy(&output.stdout).trim().to_string();
cmd.arg(format!("-l{}", window_id_str));
}
}
cmd.arg(&final_path); cmd.arg(&final_path);
let screenshot_result = cmd.output()?; let screenshot_result = cmd.output()?;
if !screenshot_result.status.success() { if !screenshot_result.status.success() {
let stderr = String::from_utf8_lossy(&screenshot_result.stderr); let stderr = String::from_utf8_lossy(&screenshot_result.stderr);
return Err(anyhow::anyhow!("screencapture failed for window {}: {}", cg_window_id, stderr)); return Err(anyhow::anyhow!("screencapture failed: {}", stderr));
} }
Ok(()) Ok(())
} }
async fn extract_text_from_screen(&self, region: Rect, window_id: &str) -> Result<String> { async fn extract_text_from_screen(&self, region: Rect) -> Result<String> {
// Take screenshot of region first // Take screenshot of region first
let temp_path = format!("/tmp/g3_ocr_{}.png", uuid::Uuid::new_v4()); let temp_path = format!("/tmp/g3_ocr_{}.png", uuid::Uuid::new_v4());
self.take_screenshot(&temp_path, Some(region), Some(window_id)).await?; self.take_screenshot(&temp_path, Some(region), None).await?;
// Extract text from the screenshot // Extract text from the screenshot
let result = self.extract_text_from_image(&temp_path).await?; let result = self.extract_text_from_image(&temp_path).await?;
@@ -191,317 +90,36 @@ impl ComputerController for MacOSController {
} }
async fn extract_text_from_image(&self, path: &str) -> Result<String> { async fn extract_text_from_image(&self, path: &str) -> Result<String> {
// Extract all text and concatenate // Check if tesseract is available on the system
let locations = self.ocr_engine.extract_text_with_locations(path).await?; let tesseract_check = std::process::Command::new("which")
Ok(locations.iter().map(|loc| loc.text.as_str()).collect::<Vec<_>>().join(" ")) .arg("tesseract")
} .output();
async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>> {
// Use the OCR engine
self.ocr_engine.extract_text_with_locations(path).await
}
async fn find_text_in_app(&self, app_name: &str, search_text: &str) -> Result<Option<TextLocation>> {
// Take screenshot of specific app window
let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
let temp_path = format!("{}/tmp/g3_find_text_{}_{}.png", home, app_name, uuid::Uuid::new_v4());
self.take_screenshot(&temp_path, None, Some(app_name)).await?;
// Get screenshot dimensions before we delete it if tesseract_check.is_err() || !tesseract_check.as_ref().unwrap().status.success() {
let screenshot_dims = get_image_dimensions(&temp_path)?; anyhow::bail!("Tesseract OCR is not installed on your system.\n\n\
To install tesseract:\n macOS: brew install tesseract\n \
// Extract all text with locations Linux: sudo apt-get install tesseract-ocr (Ubuntu/Debian)\n \
let locations = self.extract_text_with_locations(&temp_path).await?; sudo yum install tesseract (RHEL/CentOS)\n \
Windows: Download from https://github.com/UB-Mannheim/tesseract/wiki\n\n\
// Get window bounds to calculate coordinate transformation After installation, restart your terminal and try again.");
let window_bounds = self.get_window_bounds(app_name)?;
// Clean up temp file
let _ = std::fs::remove_file(&temp_path);
// Find matching text (case-insensitive)
let search_lower = search_text.to_lowercase();
for location in locations {
if location.text.to_lowercase().contains(&search_lower) {
// Transform coordinates from screenshot space to screen space
let transformed = transform_screenshot_to_screen_coords(
location,
window_bounds,
screenshot_dims,
);
return Ok(Some(transformed));
}
} }
Ok(None) // Initialize Tesseract
let tess = Tesseract::new(None, Some("eng"))
.map_err(|e| {
anyhow::anyhow!("Failed to initialize Tesseract: {}\n\n\
This usually means:\n1. Tesseract is not properly installed\n\
2. Language data files are missing\n\nTo fix:\n \
macOS: brew reinstall tesseract\n \
Linux: sudo apt-get install tesseract-ocr-eng\n \
Windows: Reinstall tesseract and ensure language files are included", e)
})?;
let text = tess.set_image(path)
.map_err(|e| anyhow::anyhow!("Failed to load image '{}': {}", path, e))?
.get_text()
.map_err(|e| anyhow::anyhow!("Failed to extract text from image: {}", e))?;
Ok(text)
} }
}
fn move_mouse(&self, x: i32, y: i32) -> Result<()> {
    use core_graphics::event::{
        CGEvent, CGEventTapLocation, CGEventType, CGMouseButton,
    };
    use core_graphics::event_source::{
        CGEventSource, CGEventSourceStateID,
    };
    use core_graphics::geometry::CGPoint;

    // Synthesize a single hardware-level "mouse moved" event at (x, y).
    let event_source = CGEventSource::new(CGEventSourceStateID::HIDSystemState)
        .ok()
        .context("Failed to create event source")?;
    let destination = CGPoint::new(x as f64, y as f64);
    let move_event = CGEvent::new_mouse_event(
        event_source,
        CGEventType::MouseMoved,
        destination,
        CGMouseButton::Left,
    )
    .ok()
    .context("Failed to create mouse event")?;
    move_event.post(CGEventTapLocation::HID);
    Ok(())
}
/// Clicks the left mouse button at (x, y), given in NSScreen coordinates.
///
/// Posts a move, then mouse-down, then mouse-up as separate HID events
/// with short sleeps between them; the ordering and delays matter for the
/// OS to register the gesture as a click.
fn click_at(&self, x: i32, y: i32, _app_name: Option<&str>) -> Result<()> {
    use core_graphics::event::{
        CGEvent, CGEventTapLocation, CGEventType, CGMouseButton,
    };
    use core_graphics::event_source::{
        CGEventSource, CGEventSourceStateID,
    };
    use core_graphics::geometry::CGPoint;
    use core_graphics::display::CGDisplay;
    // IMPORTANT: Coordinates passed here are in NSScreen/CGWindowListCopyWindowInfo space
    // (Y=0 at BOTTOM, increases UPWARD)
    // But CGEvent uses a different coordinate system (Y=0 at TOP, increases DOWNWARD)
    // We need to convert: CGEvent.y = screenHeight - NSScreen.y
    let screen_height = CGDisplay::main().pixels_high() as i32;
    let cgevent_x = x;
    let cgevent_y = screen_height - y;
    tracing::debug!("click_at: NSScreen coords ({}, {}) -> CGEvent coords ({}, {}) [screen_height={}]",
        x, y, cgevent_x, cgevent_y, screen_height);
    let (global_x, global_y) = (cgevent_x, cgevent_y);
    let point = CGPoint::new(global_x as f64, global_y as f64);
    let source = CGEventSource::new(CGEventSourceStateID::HIDSystemState)
        .ok().context("Failed to create event source")?;
    // Move mouse to position first
    let move_event = CGEvent::new_mouse_event(
        source.clone(),
        CGEventType::MouseMoved,
        point,
        CGMouseButton::Left,
    ).ok().context("Failed to create mouse move event")?;
    move_event.post(CGEventTapLocation::HID);
    // Give the move time to settle before pressing the button.
    std::thread::sleep(std::time::Duration::from_millis(100));
    // Mouse down
    let mouse_down = CGEvent::new_mouse_event(
        source.clone(),
        CGEventType::LeftMouseDown,
        point,
        CGMouseButton::Left,
    ).ok().context("Failed to create mouse down event")?;
    mouse_down.post(CGEventTapLocation::HID);
    std::thread::sleep(std::time::Duration::from_millis(50));
    // Mouse up
    let mouse_up = CGEvent::new_mouse_event(
        source,
        CGEventType::LeftMouseUp,
        point,
        CGMouseButton::Left,
    ).ok().context("Failed to create mouse up event")?;
    mouse_up.post(CGEventTapLocation::HID);
    Ok(())
}
}
impl MacOSController {
    /// Get window bounds for an application (helper method).
    ///
    /// Walks the on-screen window list and returns `(x, y, width, height)`
    /// for the first window whose owner name exactly matches `app_name`
    /// (case-insensitive, spaces ignored), skipping menu-bar layers and
    /// tiny phantom windows. Errors if no suitable window is found.
    fn get_window_bounds(&self, app_name: &str) -> Result<(i32, i32, i32, i32)> {
        unsafe {
            let window_list = CGWindowListCopyWindowInfo(
                kCGWindowListOptionOnScreenOnly,
                kCGNullWindowID
            );
            // wrap_under_create_rule takes ownership of the returned array.
            let array = CFArray::<CFDictionary>::wrap_under_create_rule(window_list);
            let count = array.len();
            let app_name_lower = app_name.to_lowercase();
            for i in 0..count {
                let dict = array.get(i).unwrap();
                // Get owner name; windows without one are skipped.
                let owner_key = CFString::from_static_string("kCGWindowOwnerName");
                let owner: String = if let Some(value) = dict.find(owner_key.to_void()) {
                    let s: CFString = TCFType::wrap_under_get_rule(*value as *const _);
                    s.to_string()
                } else {
                    continue;
                };
                let owner_lower = owner.to_lowercase();
                // Normalize by removing spaces for exact matching
                let app_name_normalized = app_name_lower.replace(" ", "");
                let owner_normalized = owner_lower.replace(" ", "");
                // ONLY accept exact matches (case-insensitive, with or without spaces)
                // This prevents "Goose" from matching "GooseStudio"
                let is_match = owner_lower == app_name_lower || owner_normalized == app_name_normalized;
                if is_match {
                    // Get window layer to filter out menu bar windows
                    let layer_key = CFString::from_static_string("kCGWindowLayer");
                    let layer: i32 = if let Some(value) = dict.find(layer_key.to_void()) {
                        let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _);
                        num.to_i32().unwrap_or(0)
                    } else {
                        0
                    };
                    // Skip menu bar windows (layer >= 20)
                    if layer >= 20 {
                        tracing::debug!("Skipping window for '{}' at layer {} (menu bar)", owner, layer);
                        continue;
                    }
                    // Get window bounds to verify it's a real window
                    let bounds_key = CFString::from_static_string("kCGWindowBounds");
                    if let Some(value) = dict.find(bounds_key.to_void()) {
                        let bounds_dict: CFDictionary = TCFType::wrap_under_get_rule(*value as *const _);
                        let x_key = CFString::from_static_string("X");
                        let y_key = CFString::from_static_string("Y");
                        let width_key = CFString::from_static_string("Width");
                        let height_key = CFString::from_static_string("Height");
                        if let (Some(x_val), Some(y_val), Some(w_val), Some(h_val)) = (
                            bounds_dict.find(x_key.to_void()),
                            bounds_dict.find(y_key.to_void()),
                            bounds_dict.find(width_key.to_void()),
                            bounds_dict.find(height_key.to_void()),
                        ) {
                            let x_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*x_val as *const _);
                            let y_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*y_val as *const _);
                            let w_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*w_val as *const _);
                            let h_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*h_val as *const _);
                            let x: i32 = x_num.to_i64().unwrap_or(0) as i32;
                            let y: i32 = y_num.to_i64().unwrap_or(0) as i32;
                            let w: i32 = w_num.to_i64().unwrap_or(0) as i32;
                            let h: i32 = h_num.to_i64().unwrap_or(0) as i32;
                            // Only accept windows with real bounds (>= 100x100 pixels)
                            if w >= 100 && h >= 100 {
                                tracing::info!("Found valid window bounds for '{}': x={}, y={}, w={}, h={} (layer={})", owner, x, y, w, h, layer);
                                return Ok((x, y, w, h));
                            } else {
                                tracing::debug!("Skipping window for '{}': too small ({}x{})", owner, w, h);
                                continue;
                            }
                        } else {
                            continue;
                        }
                    }
                }
            }
        }
        Err(anyhow::anyhow!("Could not find window bounds for '{}'", app_name))
    }
}
/// Get image dimensions (width, height) from a PNG file.
///
/// Reads only the fixed-size prefix: the 8-byte PNG signature, the IHDR
/// chunk length (bytes 8-11) and type (bytes 12-15), then width and height
/// as big-endian u32s at bytes 16-19 and 20-23.
///
/// Errors if the file is shorter than 24 bytes, lacks the PNG signature,
/// or its first chunk is not IHDR (a well-formed PNG must start with IHDR,
/// so reading offsets 16-23 without this check could return garbage).
fn get_image_dimensions(path: &str) -> Result<(i32, i32)> {
    use std::fs::File;
    use std::io::Read;
    let mut file = File::open(path)?;
    let mut buffer = vec![0u8; 24];
    file.read_exact(&mut buffer)?;
    // PNG signature check
    if &buffer[0..8] != b"\x89PNG\r\n\x1a\n" {
        anyhow::bail!("Not a valid PNG file");
    }
    // The first chunk of a valid PNG must be IHDR; verify before trusting
    // the width/height fields.
    if &buffer[12..16] != b"IHDR" {
        anyhow::bail!("Not a valid PNG file");
    }
    // Read IHDR chunk (width and height are at bytes 16-23)
    let width = u32::from_be_bytes([buffer[16], buffer[17], buffer[18], buffer[19]]) as i32;
    let height = u32::from_be_bytes([buffer[20], buffer[21], buffer[22], buffer[23]]) as i32;
    Ok((width, height))
}
/// Transform coordinates from screenshot space to screen space.
///
/// The screenshot is taken of a window, and Vision OCR returns coordinates
/// relative to the screenshot image. We need to transform these to actual
/// screen coordinates for clicking.
///
/// On Retina displays, screenshots are taken at 2x resolution, so we need
/// to account for this scaling factor.
///
/// `window_bounds` is `(x, y, width, height)` of the window in screen
/// space; `screenshot_dims` is `(width, height)` of the image in pixels.
/// Returns the location with x/y/width/height rewritten into screen space
/// (text and confidence unchanged).
fn transform_screenshot_to_screen_coords(
    location: TextLocation,
    window_bounds: (i32, i32, i32, i32), // (x, y, width, height) in screen space
    screenshot_dims: (i32, i32), // (width, height) in pixels
) -> TextLocation {
    let (win_x, win_y, win_width, win_height) = window_bounds;
    let (screenshot_width, screenshot_height) = screenshot_dims;
    // Calculate scale factors
    // On Retina displays, screenshot is typically 2x the window size
    let scale_x = win_width as f64 / screenshot_width as f64;
    let scale_y = win_height as f64 / screenshot_height as f64;
    tracing::debug!("Transform: screenshot={}x{}, window={}x{} at ({},{}), scale=({:.2},{:.2})",
        screenshot_width, screenshot_height, win_width, win_height, win_x, win_y, scale_x, scale_y);
    // Transform coordinates from image space to screen space
    // IMPORTANT: macOS screen coordinates have origin at BOTTOM-LEFT (Y increases upward)
    // Image coordinates have origin at TOP-LEFT (Y increases downward)
    // win_y is the BOTTOM of the window in screen coordinates
    // So we need to: (win_y + win_height) to get window TOP, then subtract screenshot_y
    let window_top_y = win_y + win_height;
    tracing::debug!("[transform] Input location in image space: x={}, y={}, width={}, height={}",
        location.x, location.y, location.width, location.height);
    tracing::debug!("[transform] Scale factors: scale_x={:.4}, scale_y={:.4}", scale_x, scale_y);
    let transformed_x = win_x + (location.x as f64 * scale_x) as i32;
    let transformed_y = window_top_y - (location.y as f64 * scale_y) as i32;
    let transformed_width = (location.width as f64 * scale_x) as i32;
    let transformed_height = (location.height as f64 * scale_y) as i32;
    tracing::debug!("[transform] Calculation details:");
    tracing::debug!("  - transformed_x = {} + ({} * {:.4}) = {} + {:.2} = {}", win_x, location.x, scale_x, win_x, location.x as f64 * scale_x, transformed_x);
    tracing::debug!("  - transformed_width = ({} * {:.4}) = {:.2} -> {}", location.width, scale_x, location.width as f64 * scale_x, transformed_width);
    tracing::debug!("  - transformed_height = ({} * {:.4}) = {:.2} -> {}", location.height, scale_y, location.height as f64 * scale_y, transformed_height);
    tracing::debug!("Transformed location: screenshot=({},{}) {}x{} -> screen=({},{}) {}x{}",
        location.x, location.y, location.width, location.height,
        transformed_x, transformed_y, transformed_width, transformed_height);
    TextLocation {
        text: location.text,
        x: transformed_x,
        y: transformed_y,
        width: transformed_width,
        height: transformed_height,
        confidence: location.confidence,
    }
}
#[path = "macos_window_matching_test.rs"]
#[cfg(test)]
mod tests;

View File

@@ -0,0 +1,425 @@
use crate::{ComputerController, types::*};
use anyhow::Result;
use async_trait::async_trait;
use core_graphics::display::CGPoint;
use core_graphics::event::{CGEvent, CGEventType, CGMouseButton, CGEventTapLocation};
use core_graphics::event_source::{CGEventSource, CGEventSourceStateID};
use std::path::Path;
use tesseract::Tesseract;
// MacOSController doesn't store CGEventSource to avoid Send/Sync issues.
// We create it fresh for each operation.
pub struct MacOSController {
    // Intentionally empty: an event source is created per operation; the
    // one built in `new()` serves only as a permissions probe.
}
impl MacOSController {
    /// Builds a controller, probing once that a `CGEventSource` can be
    /// created (this fails when Accessibility permissions are missing).
    pub fn new() -> Result<Self> {
        // The probe result itself is discarded; only success matters.
        CGEventSource::new(CGEventSourceStateID::CombinedSessionState)
            .map_err(|_| anyhow::anyhow!("Failed to create event source. Make sure Accessibility permissions are granted."))?;
        Ok(Self {})
    }

    /// Maps a human-readable key name (case-insensitive) to its macOS
    /// virtual keycode; errors on names it does not know.
    fn key_to_keycode(&self, key: &str) -> Result<u16> {
        let normalized = key.to_lowercase();
        let keycode = match normalized.as_str() {
            "return" | "enter" => 36,
            "tab" => 48,
            "space" => 49,
            "delete" | "backspace" => 51,
            "escape" | "esc" => 53,
            "command" | "cmd" => 55,
            "shift" => 56,
            "capslock" => 57,
            "option" | "alt" => 58,
            "control" | "ctrl" => 59,
            "left" => 123,
            "right" => 124,
            "down" => 125,
            "up" => 126,
            _ => anyhow::bail!("Unknown key: {}", key),
        };
        Ok(keycode)
    }
}
#[async_trait]
impl ComputerController for MacOSController {
async fn move_mouse(&self, x: i32, y: i32) -> Result<()> {
    // A fresh event source is created per call; the controller holds none.
    let source = CGEventSource::new(CGEventSourceStateID::CombinedSessionState)
        .map_err(|_| anyhow::anyhow!("Failed to create event source"))?;
    let destination = CGPoint::new(x as f64, y as f64);
    let mouse_move = CGEvent::new_mouse_event(
        source,
        CGEventType::MouseMoved,
        destination,
        CGMouseButton::Left,
    )
    .map_err(|_| anyhow::anyhow!("Failed to create mouse move event"))?;
    mouse_move.post(CGEventTapLocation::HID);
    Ok(())
}
/// Clicks the given button at the current cursor position by posting a
/// down event, waiting briefly, then posting an up event.
///
/// Events and their sources are scoped in blocks so they are dropped
/// before the `.await`, keeping the future Send-safe.
async fn click(&self, button: MouseButton) -> Result<()> {
    // Map the abstract button to the CG button plus its down/up event types.
    let (cg_button, down_type, up_type) = match button {
        MouseButton::Left => (CGMouseButton::Left, CGEventType::LeftMouseDown, CGEventType::LeftMouseUp),
        MouseButton::Right => (CGMouseButton::Right, CGEventType::RightMouseDown, CGEventType::RightMouseUp),
        MouseButton::Middle => (CGMouseButton::Center, CGEventType::OtherMouseDown, CGEventType::OtherMouseUp),
    };
    let point = {
        // Get current mouse position (a fresh CGEvent reports the cursor location).
        let temp_source = CGEventSource::new(CGEventSourceStateID::CombinedSessionState)
            .map_err(|_| anyhow::anyhow!("Failed to create event source"))?;
        let event = CGEvent::new(temp_source)
            .map_err(|_| anyhow::anyhow!("Failed to get mouse position"))?;
        let p = event.location();
        p
    };
    {
        let event_source = CGEventSource::new(CGEventSourceStateID::CombinedSessionState)
            .map_err(|_| anyhow::anyhow!("Failed to create event source"))?;
        // Mouse down
        let down_event = CGEvent::new_mouse_event(
            event_source,
            down_type,
            point,
            cg_button,
        ).map_err(|_| anyhow::anyhow!("Failed to create mouse down event"))?;
        down_event.post(CGEventTapLocation::HID);
    } // event_source and down_event dropped here
    // Small delay between press and release so the OS registers a click.
    tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
    {
        let event_source = CGEventSource::new(CGEventSourceStateID::CombinedSessionState)
            .map_err(|_| anyhow::anyhow!("Failed to create event source"))?;
        let up_event = CGEvent::new_mouse_event(
            event_source,
            up_type,
            point,
            cg_button,
        ).map_err(|_| anyhow::anyhow!("Failed to create mouse up event"))?;
        up_event.post(CGEventTapLocation::HID);
    } // event_source and up_event dropped here
    Ok(())
}
async fn double_click(&self, button: MouseButton) -> Result<()> {
    // Two clicks with a short pause; the OS pairs them into a double-click
    // when they land within its double-click interval.
    self.click(button).await?;
    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
    self.click(button).await
}
/// Types `text` by posting one keyboard event per character with the
/// character attached as a unicode string (keycode 0), pausing 10ms
/// between characters.
async fn type_text(&self, text: &str) -> Result<()> {
    for ch in text.chars() {
        {
            let event_source = CGEventSource::new(CGEventSourceStateID::CombinedSessionState)
                .map_err(|_| anyhow::anyhow!("Failed to create event source"))?;
            // Create keyboard event for character
            let event = CGEvent::new_keyboard_event(
                event_source,
                0, // keycode (0 for unicode)
                true,
            ).map_err(|_| anyhow::anyhow!("Failed to create keyboard event"))?;
            // Set unicode string. A 2-element buffer covers surrogate pairs
            // (chars outside the BMP encode to two UTF-16 units).
            let mut utf16_buf = [0u16; 2];
            let utf16_slice = ch.encode_utf16(&mut utf16_buf);
            let utf16_chars: Vec<u16> = utf16_slice.iter().copied().collect();
            event.set_string_from_utf16_unchecked(utf16_chars.as_slice());
            event.post(CGEventTapLocation::HID);
        } // event_source and event dropped here (before the await, for Send)
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
    }
    Ok(())
}
/// Presses and releases a named key (see `key_to_keycode` for accepted
/// names), with a 50ms gap between down and up events.
async fn press_key(&self, key: &str) -> Result<()> {
    let keycode = self.key_to_keycode(key)?;
    {
        let event_source = CGEventSource::new(CGEventSourceStateID::CombinedSessionState)
            .map_err(|_| anyhow::anyhow!("Failed to create event source"))?;
        // Key down
        let down_event = CGEvent::new_keyboard_event(
            event_source,
            keycode,
            true,
        ).map_err(|_| anyhow::anyhow!("Failed to create key down event"))?;
        down_event.post(CGEventTapLocation::HID);
    } // event_source and down_event dropped here (before the await, for Send)
    tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
    {
        let event_source = CGEventSource::new(CGEventSourceStateID::CombinedSessionState)
            .map_err(|_| anyhow::anyhow!("Failed to create event source"))?;
        // Key up
        let up_event = CGEvent::new_keyboard_event(
            event_source,
            keycode,
            false,
        ).map_err(|_| anyhow::anyhow!("Failed to create key up event"))?;
        up_event.post(CGEventTapLocation::HID);
    } // event_source and up_event dropped here
    Ok(())
}
/// Enumerate on-screen windows.
/// Stub: always returns an empty list on macOS for now.
async fn list_windows(&self) -> Result<Vec<Window>> {
    // Note: Full implementation would use CGWindowListCopyWindowInfo
    // For now, return empty list as this requires more complex FFI
    tracing::warn!("list_windows not fully implemented on macOS");
    Ok(vec![])
}
/// Bring the window identified by `_window_id` to the foreground.
/// Stub: currently a no-op that only logs a warning.
async fn focus_window(&self, _window_id: &str) -> Result<()> {
    // Note: Full implementation would use NSWorkspace to activate application
    tracing::warn!("focus_window not fully implemented on macOS");
    Ok(())
}
/// Return the bounds of the given window.
/// Stub: always returns a fixed 800x600 rect at the origin.
async fn get_window_bounds(&self, _window_id: &str) -> Result<Rect> {
    // Note: Full implementation would use Accessibility API
    tracing::warn!("get_window_bounds not fully implemented on macOS");
    Ok(Rect { x: 0, y: 0, width: 800, height: 600 })
}
/// Locate a UI element matching `_selector`.
/// Stub: always reports "not found" (`Ok(None)`).
async fn find_element(&self, _selector: &ElementSelector) -> Result<Option<UIElement>> {
    // Note: Full implementation would use macOS Accessibility API
    tracing::warn!("find_element not fully implemented on macOS");
    Ok(None)
}
/// Read the text content of a UI element.
/// Stub: always returns an empty string.
async fn get_element_text(&self, _element_id: &str) -> Result<String> {
    // Note: Full implementation would use Accessibility API
    tracing::warn!("get_element_text not fully implemented on macOS");
    Ok(String::new())
}
/// Return the bounds of a UI element.
/// Stub: always returns a fixed 100x30 rect at the origin.
async fn get_element_bounds(&self, _element_id: &str) -> Result<Rect> {
    // Note: Full implementation would use Accessibility API
    tracing::warn!("get_element_bounds not fully implemented on macOS");
    Ok(Rect { x: 0, y: 0, width: 100, height: 30 })
}
/// Capture a screenshot to `path` using the native macOS `screencapture` tool.
///
/// Target precedence:
/// 1. `window_id` ("AppName" or "AppName:WindowTitle") — the app's window
///    bounds are looked up via AppleScript and captured as a region;
/// 2. `region` — captured directly via `-R`;
/// 3. otherwise the full screen.
///
/// On first use (unless `G3_SKIP_PERMISSION_CHECK` is set) a one-time warning
/// is emitted and the Screen Recording privacy pane is opened, since window
/// content capture requires that permission.
async fn take_screenshot(&self, path: &str, region: Option<Rect>, window_id: Option<&str>) -> Result<()> {
    // Use native macOS screencapture command which handles all the format complexities.
    // If we only get wallpaper/menubar but no windows, we need permission.
    let needs_permission_check = std::env::var("G3_SKIP_PERMISSION_CHECK").is_err();
    if needs_permission_check {
        // Prompt at most once per process: `swap` returns the previous value,
        // so only the first caller sees `false` here.
        static PERMISSION_PROMPTED: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
        if !PERMISSION_PROMPTED.swap(true, std::sync::atomic::Ordering::Relaxed) {
            tracing::warn!("\n=== Screen Recording Permission Required ===\n\
                macOS requires explicit permission to capture window content.\n\
                If screenshots only show wallpaper/menubar (no windows):\n\n\
                1. Open System Settings > Privacy & Security > Screen Recording\n\
                2. Enable permission for your terminal (iTerm/Terminal) or g3\n\
                3. Restart your terminal if needed\n\n\
                Opening Screen Recording settings now...\n");
            // Try to open the settings (non-blocking); failure is ignored.
            let _ = std::process::Command::new("open")
                .arg("x-apple.systempreferences:com.apple.preference.security?Privacy_ScreenCapture")
                .spawn();
        }
    }
    // Ensure the destination directory exists before screencapture writes to it.
    if let Some(parent) = Path::new(path).parent() {
        std::fs::create_dir_all(parent)?;
    }
    let mut cmd = std::process::Command::new("screencapture");
    cmd.arg("-x"); // No sound
    if let Some(window_id) = window_id {
        // Capture a specific window by resolving its bounds and using region capture.
        // window_id format: "AppName" or "AppName:WindowTitle"
        let app_name = window_id.split(':').next().unwrap_or(window_id);
        // NOTE(review): "current window" is not standard AppleScript for every
        // application ("front window" is the common term) — confirm this works
        // for the apps we target.
        let script = format!(
            r#"tell application "{}"
tell current window
get bounds
end tell
end tell"#,
            app_name
        );
        let output = std::process::Command::new("osascript")
            .arg("-e")
            .arg(&script)
            .output()
            .map_err(|e| anyhow::anyhow!("Failed to get window bounds: {}", e))?;
        if output.status.success() {
            let bounds_str = String::from_utf8_lossy(&output.stdout);
            // AppleScript reports "left, top, right, bottom".
            let bounds: Vec<i32> = bounds_str
                .trim()
                .split(',')
                .filter_map(|s| s.trim().parse().ok())
                .collect();
            if bounds.len() == 4 {
                let (left, top, right, bottom) = (bounds[0], bounds[1], bounds[2], bounds[3]);
                let width = right - left;
                let height = bottom - top;
                cmd.arg("-R");
                cmd.arg(format!("{},{},{},{}", left, top, width, height));
                tracing::debug!("Capturing window '{}' at region: {},{} {}x{}", app_name, left, top, width, height);
            } else {
                // Fall back to full screen rather than failing the capture.
                tracing::warn!("Failed to parse window bounds, capturing full screen");
            }
        } else {
            tracing::warn!("Failed to get window bounds for '{}', capturing full screen", app_name);
        }
    } else if let Some(region) = region {
        // Capture specific region: -R x,y,width,height
        cmd.arg("-R");
        cmd.arg(format!("{},{},{},{}", region.x, region.y, region.width, region.height));
    }
    cmd.arg(path);
    let output = cmd.output()
        .map_err(|e| anyhow::anyhow!("Failed to execute screencapture: {}", e))?;
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        anyhow::bail!("screencapture failed: {}", stderr);
    }
    tracing::debug!("Screenshot saved using screencapture: {}", path);
    Ok(())
}
}
/// OCR the given screen `region` by screenshotting it to a temp file and
/// running text extraction over the image.
///
/// The temp file is removed whether or not extraction succeeds (the original
/// code leaked it when `extract_text_from_image` returned an error).
async fn extract_text_from_screen(&self, region: Rect) -> Result<OCRResult> {
    // Unique name so concurrent OCR calls cannot clobber each other.
    let temp_path = format!("/tmp/g3_ocr_{}.png", uuid::Uuid::new_v4());
    self.take_screenshot(&temp_path, Some(region), None).await?;
    // Run OCR, then clean up the temp file even on failure before
    // propagating the result.
    let result = self.extract_text_from_image(&temp_path).await;
    let _ = std::fs::remove_file(&temp_path);
    result
}
/// Run Tesseract OCR over the image file at `path`.
///
/// Returns the recognized text with a placeholder confidence and zeroed
/// bounds (per-word confidence and image dimensions are not yet wired up).
/// Fails with installation guidance when the `tesseract` binary is missing.
async fn extract_text_from_image(&self, path: &str) -> Result<OCRResult> {
    // Probe for the tesseract binary; a failed lookup and a missing binary
    // are treated the same way.
    let tesseract_available = std::process::Command::new("which")
        .arg("tesseract")
        .output()
        .map(|out| out.status.success())
        .unwrap_or(false);
    if !tesseract_available {
        anyhow::bail!("Tesseract OCR is not installed on your system.\n\n\
            To install tesseract:\n macOS: brew install tesseract\n \
            Linux: sudo apt-get install tesseract-ocr (Ubuntu/Debian)\n \
            sudo yum install tesseract (RHEL/CentOS)\n \
            Windows: Download from https://github.com/UB-Mannheim/tesseract/wiki\n\n\
            After installation, restart your terminal and try again.");
    }
    // Initialize Tesseract with English language data.
    let tess = Tesseract::new(None, Some("eng"))
        .map_err(|e| {
            anyhow::anyhow!("Failed to initialize Tesseract: {}\n\n\
                This usually means:\n1. Tesseract is not properly installed\n\
                2. Language data files are missing\n\nTo fix:\n \
                macOS: brew reinstall tesseract\n \
                Linux: sudo apt-get install tesseract-ocr-eng\n \
                Windows: Reinstall tesseract and ensure language files are included", e)
        })?;
    let text = tess.set_image(path)
        .map_err(|e| anyhow::anyhow!("Failed to load image '{}': {}", path, e))?
        .get_text()
        .map_err(|e| anyhow::anyhow!("Failed to extract text from image: {}", e))?;
    // Placeholder confidence — per-word confidence would need more complex
    // Tesseract API calls.
    let confidence = 0.85;
    Ok(OCRResult {
        text,
        confidence,
        // Zeroed bounds: computing real bounds would need the image dimensions.
        bounds: Rect { x: 0, y: 0, width: 0, height: 0 },
    })
}
/// Search the full screen for `text` via OCR.
///
/// Simplified implementation: returns `Some(Point { x: 0, y: 0 })` when the
/// text appears anywhere on screen (precise coordinates would require
/// per-word bounding boxes), `None` otherwise.
///
/// The temp screenshot is now removed even when OCR fails (the original code
/// leaked it on `set_image`/`get_text` errors).
async fn find_text_on_screen(&self, text: &str) -> Result<Option<Point>> {
    // Probe for the tesseract binary; a failed lookup and a missing binary
    // are treated the same way.
    let tesseract_available = std::process::Command::new("which")
        .arg("tesseract")
        .output()
        .map(|out| out.status.success())
        .unwrap_or(false);
    if !tesseract_available {
        anyhow::bail!("Tesseract OCR is not installed on your system.\n\n\
            To install tesseract:\n macOS: brew install tesseract\n \
            Linux: sudo apt-get install tesseract-ocr (Ubuntu/Debian)\n \
            sudo yum install tesseract (RHEL/CentOS)\n \
            Windows: Download from https://github.com/UB-Mannheim/tesseract/wiki\n\n\
            After installation, restart your terminal and try again.");
    }
    // Take full screen screenshot into a uniquely named temp file.
    let temp_path = format!("/tmp/g3_ocr_search_{}.png", uuid::Uuid::new_v4());
    self.take_screenshot(&temp_path, None, None).await?;
    // Run OCR, capturing the result so the temp file can be removed before
    // any error is propagated.
    let ocr = Tesseract::new(None, Some("eng"))
        .map_err(|e| {
            anyhow::anyhow!("Failed to initialize Tesseract: {}\n\n\
                This usually means:\n1. Tesseract is not properly installed\n\
                2. Language data files are missing\n\nTo fix:\n \
                macOS: brew reinstall tesseract\n \
                Linux: sudo apt-get install tesseract-ocr-eng\n \
                Windows: Reinstall tesseract and ensure language files are included", e)
        })
        .and_then(|tess| {
            tess.set_image(temp_path.as_str())
                .map_err(|e| anyhow::anyhow!("Failed to load screenshot: {}", e))?
                .get_text()
                .map_err(|e| anyhow::anyhow!("Failed to extract text from screen: {}", e))
        });
    // Clean up temp file regardless of OCR outcome.
    let _ = std::fs::remove_file(&temp_path);
    let full_text = ocr?;
    // Simple text search — a full implementation would use per-word bounding
    // boxes (e.g. get_component_images) to return real coordinates.
    if full_text.contains(text) {
        tracing::warn!("Text found but precise coordinates not available in simplified implementation");
        Ok(Some(Point { x: 0, y: 0 }))
    } else {
        Ok(None)
    }
}
}

View File

@@ -1,45 +0,0 @@
#[cfg(test)]
mod window_matching_tests {
    /// Replicates the production fuzzy window-name matcher:
    /// case-insensitive exact match (with and without spaces), or substring
    /// containment in either direction (again with and without spaces).
    fn names_match(user_input: &str, app_name: &str) -> bool {
        let user = user_input.to_lowercase();
        let app = app_name.to_lowercase();
        let user_compact = user.replace(" ", "");
        let app_compact = app.replace(" ", "");
        let exact = app == user || app_compact == user_compact;
        let fuzzy = app.contains(&user)
            || user.contains(&app)
            || app_compact.contains(&user_compact)
            || user_compact.contains(&app_compact);
        exact || fuzzy
    }

    /// Window name matching must tolerate spacing differences.
    ///
    /// Issue: when a user requests a screenshot of "Goose Studio" but the
    /// actual application name is "GooseStudio" (no space), the fuzzy
    /// matching should still find the window. The fix normalizes both names
    /// by removing spaces before comparing.
    #[test]
    fn test_space_normalization() {
        let test_cases = vec![
            // (user_input, actual_app_name, should_match)
            ("Goose Studio", "GooseStudio", true),
            ("GooseStudio", "Goose Studio", true),
            ("Visual Studio Code", "VisualStudioCode", true),
            ("Google Chrome", "Google Chrome", true),
            ("Safari", "Safari", true),
            ("iTerm", "iTerm2", true),            // fuzzy match
            ("Code", "Visual Studio Code", true), // fuzzy match
        ];
        for (user_input, app_name, should_match) in test_cases {
            let matches = names_match(user_input, app_name);
            assert_eq!(
                matches, should_match,
                "Expected '{}' vs '{}' to match={}, but got match={}",
                user_input, app_name, should_match, matches
            );
        }
    }
}

View File

@@ -62,15 +62,10 @@ impl ComputerController for WindowsController {
} }
async fn take_screenshot(&self, _path: &str, _region: Option<Rect>, _window_id: Option<&str>) -> Result<()> { async fn take_screenshot(&self, _path: &str, _region: Option<Rect>, _window_id: Option<&str>) -> Result<()> {
// Enforce that window_id must be provided
if _window_id.is_none() {
anyhow::bail!("window_id is required. You must specify which window to capture (e.g., 'Chrome', 'Terminal', 'Notepad'). Use list_windows to see available windows.");
}
anyhow::bail!("Windows implementation not yet available") anyhow::bail!("Windows implementation not yet available")
} }
async fn extract_text_from_screen(&self, _region: Rect, _window_id: &str) -> Result<String> { async fn extract_text_from_screen(&self, _region: Rect) -> Result<OCRResult> {
anyhow::bail!("Windows implementation not yet available") anyhow::bail!("Windows implementation not yet available")
} }

View File

@@ -7,13 +7,3 @@ pub struct Rect {
pub width: i32, pub width: i32,
pub height: i32, pub height: i32,
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextLocation {
pub text: String,
pub x: i32,
pub y: i32,
pub width: i32,
pub height: i32,
pub confidence: f32,
}

View File

@@ -1,5 +1,23 @@
use g3_computer_control::*; use g3_computer_control::*;
/// Smoke test: moving the cursor should succeed on a controller-capable host.
#[tokio::test]
async fn test_mouse_movement() {
    let controller = create_controller().expect("Failed to create controller");
    // Move mouse to center of screen (assuming 1920x1080)
    let result = controller.move_mouse(960, 540).await;
    assert!(result.is_ok(), "Failed to move mouse: {:?}", result.err());
}
/// Smoke test: synthesizing keystrokes for a short string should succeed.
/// NOTE(review): this types into whatever window currently has focus.
#[tokio::test]
async fn test_typing() {
    let controller = create_controller().expect("Failed to create controller");
    // Type some text
    let result = controller.type_text("Hello, World!").await;
    assert!(result.is_ok(), "Failed to type text: {:?}", result.err());
}
#[tokio::test] #[tokio::test]
async fn test_screenshot() { async fn test_screenshot() {
let controller = create_controller().expect("Failed to create controller"); let controller = create_controller().expect("Failed to create controller");
@@ -15,3 +33,30 @@ async fn test_screenshot() {
// Clean up // Clean up
let _ = std::fs::remove_file(path); let _ = std::fs::remove_file(path);
} }
/// Smoke test: a left click at the current cursor position should succeed.
#[tokio::test]
async fn test_click() {
    let controller = create_controller().expect("Failed to create controller");
    // Click at a safe location
    let result = controller.click(types::MouseButton::Left).await;
    assert!(result.is_ok(), "Failed to click: {:?}", result.err());
}
/// Smoke test: a left double-click should succeed.
#[tokio::test]
async fn test_double_click() {
    let controller = create_controller().expect("Failed to create controller");
    // Double click
    let result = controller.double_click(types::MouseButton::Left).await;
    assert!(result.is_ok(), "Failed to double click: {:?}", result.err());
}
/// Smoke test: pressing a named key ("escape", chosen as low-impact) should succeed.
#[tokio::test]
async fn test_press_key() {
    let controller = create_controller().expect("Failed to create controller");
    // Press escape key
    let result = controller.press_key("escape").await;
    assert!(result.is_ok(), "Failed to press key: {:?}", result.err());
}

View File

@@ -1,24 +0,0 @@
// swift-tools-version:5.9
import PackageDescription

// VisionBridge: exposes Apple's Vision text recognition through C-callable
// entry points (see vision_recognize_text / vision_free_boxes). Requires
// macOS 11+ for the Vision APIs used.
let package = Package(
    name: "VisionBridge",
    platforms: [
        .macOS(.v11)
    ],
    products: [
        // Built as a dynamic library so a non-Swift host can link against it.
        .library(
            name: "VisionBridge",
            type: .dynamic,
            targets: ["VisionBridge"]
        ),
    ],
    targets: [
        .target(
            name: "VisionBridge",
            dependencies: [],
            path: "Sources/VisionBridge",
            // The C header lives next to the sources and is exported as-is.
            publicHeadersPath: "."
        ),
    ]
)

View File

@@ -1,39 +0,0 @@
#ifndef VisionBridge_h
#define VisionBridge_h

#include <stdint.h>
#include <stdbool.h>

#ifdef __cplusplus
extern "C" {
#endif

// Text box structure for FFI.
// `text` is a heap-allocated, NUL-terminated UTF-8 string owned by the
// library; `text_len` is its length in bytes (excluding the terminator).
// (x, y) is the top-left corner in image pixel coordinates; `confidence`
// is the recognizer's score for this text candidate.
typedef struct {
    const char* text;
    uint32_t text_len;
    int32_t x;
    int32_t y;
    int32_t width;
    int32_t height;
    float confidence;
} VisionTextBox;

// Recognize text in an image and return bounding boxes.
// `image_path` is a UTF-8 path of `image_path_len` bytes (not necessarily
// NUL-terminated). On success, *out_boxes points to an array of *out_count
// VisionTextBox entries.
// Returns true on success, false on failure.
// Caller must free the returned boxes using vision_free_boxes.
bool vision_recognize_text(
    const char* image_path,
    uint32_t image_path_len,
    VisionTextBox** out_boxes,
    uint32_t* out_count
);

// Free memory allocated by vision_recognize_text: releases each box's
// `text` string and then the array itself. Safe only on pointers returned
// by a successful vision_recognize_text call.
void vision_free_boxes(VisionTextBox* boxes, uint32_t count);

#ifdef __cplusplus
}
#endif

#endif /* VisionBridge_h */

View File

@@ -1,145 +0,0 @@
import Foundation
import Vision
import AppKit
import CoreGraphics
// MARK: - C Bridge Functions
// Run Apple Vision text recognition over the image at `imagePath` and hand
// the results back to C as an array of CTextBox.
//
// Ownership: on success, the box array (UnsafeMutablePointer.allocate) and
// every `text` pointer (strdup) are allocated here and MUST be released by
// the caller via vision_free_boxes. Returns false on any failure, in which
// case nothing is written to outBoxes/outCount.
@_cdecl("vision_recognize_text")
public func vision_recognize_text(
    _ imagePath: UnsafePointer<CChar>,
    _ imagePathLen: UInt32,
    _ outBoxes: UnsafeMutablePointer<UnsafeMutableRawPointer?>,
    _ outCount: UnsafeMutablePointer<UInt32>
) -> Bool {
    // Convert C string to Swift String (the path is length-delimited, not
    // necessarily NUL-terminated).
    guard let pathData = Data(bytes: imagePath, count: Int(imagePathLen)).withUnsafeBytes({
        String(bytes: $0, encoding: .utf8)
    }) else {
        return false
    }
    let path = pathData.trimmingCharacters(in: .whitespaces)
    // Load image
    guard let image = NSImage(contentsOfFile: path),
        let cgImage = image.cgImage(forProposedRect: nil, context: nil, hints: nil) else {
        return false
    }
    // Perform OCR
    var textBoxes: [CTextBox] = []
    let semaphore = DispatchSemaphore(value: 0)
    var success = false
    let request = VNRecognizeTextRequest { request, error in
        defer { semaphore.signal() }
        if let error = error {
            print("Vision OCR error: \(error.localizedDescription)")
            return
        }
        guard let observations = request.results as? [VNRecognizedTextObservation] else {
            return
        }
        let imageSize = CGSize(width: cgImage.width, height: cgImage.height)
        for observation in observations {
            // Take only the single best candidate per observation.
            guard let candidate = observation.topCandidates(1).first else { continue }
            let text = candidate.string
            let boundingBox = observation.boundingBox
            // Convert normalized coordinates (bottom-left origin) to pixel coordinates (top-left origin)
            let x = Int32(boundingBox.origin.x * imageSize.width)
            let y = Int32((1.0 - boundingBox.origin.y - boundingBox.height) * imageSize.height)
            let width = Int32(boundingBox.width * imageSize.width)
            let height = Int32(boundingBox.height * imageSize.height)
            // Allocate C string for text; ownership passes to the caller via
            // the returned box (freed in vision_free_boxes).
            let cString = strdup(text)
            textBoxes.append(CTextBox(
                text: cString,
                text_len: UInt32(text.utf8.count),
                x: x,
                y: y,
                width: width,
                height: height,
                confidence: observation.confidence
            ))
        }
        success = true
    }
    // Configure request for best accuracy
    request.recognitionLevel = .accurate
    request.usesLanguageCorrection = true
    request.recognitionLanguages = ["en-US"]
    // Perform request
    let handler = VNImageRequestHandler(cgImage: cgImage, options: [:])
    do {
        try handler.perform([request])
    } catch {
        print("Vision request failed: \(error.localizedDescription)")
        return false
    }
    // Wait for completion.
    // NOTE(review): VNImageRequestHandler.perform is documented as
    // synchronous, so the completion handler should already have run by this
    // point; the semaphore looks redundant — confirm before removing.
    semaphore.wait()
    if !success {
        return false
    }
    // Allocate the C-visible result array and copy the boxes into it.
    let boxesPtr = UnsafeMutablePointer<CTextBox>.allocate(capacity: textBoxes.count)
    for (index, box) in textBoxes.enumerated() {
        boxesPtr[index] = box
    }
    outBoxes.pointee = UnsafeMutableRawPointer(boxesPtr)
    outCount.pointee = UInt32(textBoxes.count)
    return true
}
// Release an array previously returned by a successful vision_recognize_text
// call: free each box's strdup'd `text`, then deallocate the array itself.
// Must be called exactly once per returned array.
@_cdecl("vision_free_boxes")
public func vision_free_boxes(
    _ boxes: UnsafeMutableRawPointer,
    _ count: UInt32
) {
    // Rebind the raw pointer to the element type written by the allocator.
    let typedBoxes = boxes.assumingMemoryBound(to: CTextBox.self)
    for i in 0..<Int(count) {
        // Each text pointer was allocated with strdup(); release it first.
        if let text = typedBoxes[i].text {
            free(UnsafeMutableRawPointer(mutating: text))
        }
    }
    // Then release the box array itself.
    typedBoxes.deallocate()
}
// MARK: - C-Compatible Structure

// Swift mirror of the C `VisionTextBox` struct declared in VisionBridge.h.
// `text` is a strdup'd, NUL-terminated UTF-8 string; `text_len` is its
// length in bytes excluding the terminator. (x, y) is the top-left corner
// in image pixel coordinates.
public struct CTextBox {
    public let text: UnsafePointer<CChar>?
    public let text_len: UInt32
    public let x: Int32
    public let y: Int32
    public let width: Int32
    public let height: Int32
    public let confidence: Float

    public init(text: UnsafePointer<CChar>?, text_len: UInt32, x: Int32, y: Int32, width: Int32, height: Int32, confidence: Float) {
        self.text = text
        self.text_len = text_len
        self.x = x
        self.y = y
        self.width = width
        self.height = height
        self.confidence = confidence
    }
}

View File

@@ -12,6 +12,3 @@ thiserror = { workspace = true }
toml = "0.8" toml = "0.8"
shellexpand = "3.0" shellexpand = "3.0"
dirs = "5.0" dirs = "5.0"
[dev-dependencies]
tempfile = "3.8"

View File

@@ -0,0 +1,131 @@
/// Tests for the `[autonomous]` config section: per-role provider/model
/// overrides consumed by `Config::for_coach` / `Config::for_player`.
#[cfg(test)]
mod autonomous_config_tests {
    use crate::{Config, AnthropicConfig, DatabricksConfig};

    /// A default config carries no coach/player overrides.
    #[test]
    fn test_default_autonomous_config() {
        let config = Config::default();
        assert!(config.autonomous.coach_provider.is_none());
        assert!(config.autonomous.coach_model.is_none());
        assert!(config.autonomous.player_provider.is_none());
        assert!(config.autonomous.player_model.is_none());
    }

    /// Provider AND model overrides for the coach are both applied.
    #[test]
    fn test_for_coach_with_overrides() {
        let mut config = Config::default();
        // Set up base config with anthropic
        config.providers.anthropic = Some(AnthropicConfig {
            api_key: "test-key".to_string(),
            model: "claude-3-5-sonnet-20241022".to_string(),
            max_tokens: Some(4096),
            temperature: Some(0.1),
        });
        // Set coach overrides
        config.autonomous.coach_provider = Some("anthropic".to_string());
        config.autonomous.coach_model = Some("claude-3-opus-20240229".to_string());
        let coach_config = config.for_coach().unwrap();
        // Verify coach uses overridden provider and model
        assert_eq!(coach_config.providers.default_provider, "anthropic");
        assert_eq!(
            coach_config.providers.anthropic.as_ref().unwrap().model,
            "claude-3-opus-20240229"
        );
    }

    /// Provider AND model overrides for the player are both applied.
    #[test]
    fn test_for_player_with_overrides() {
        let mut config = Config::default();
        // Set up base config with databricks
        config.providers.databricks = Some(DatabricksConfig {
            host: "https://test.databricks.com".to_string(),
            token: Some("test-token".to_string()),
            model: "databricks-meta-llama-3-1-70b-instruct".to_string(),
            max_tokens: Some(4096),
            temperature: Some(0.1),
            use_oauth: Some(false),
        });
        // Set player overrides
        config.autonomous.player_provider = Some("databricks".to_string());
        config.autonomous.player_model = Some("databricks-dbrx-instruct".to_string());
        let player_config = config.for_player().unwrap();
        // Verify player uses overridden provider and model
        assert_eq!(player_config.providers.default_provider, "databricks");
        assert_eq!(
            player_config.providers.databricks.as_ref().unwrap().model,
            "databricks-dbrx-instruct"
        );
    }

    /// Without overrides, both roles fall back to the default provider.
    #[test]
    fn test_no_overrides_uses_defaults() {
        let mut config = Config::default();
        config.providers.default_provider = "databricks".to_string();
        let coach_config = config.for_coach().unwrap();
        let player_config = config.for_player().unwrap();
        // Both should use the default provider when no overrides
        assert_eq!(coach_config.providers.default_provider, "databricks");
        assert_eq!(player_config.providers.default_provider, "databricks");
    }

    /// Overriding only the provider keeps that provider's configured model.
    #[test]
    fn test_provider_override_only() {
        let mut config = Config::default();
        config.providers.anthropic = Some(AnthropicConfig {
            api_key: "test-key".to_string(),
            model: "claude-3-5-sonnet-20241022".to_string(),
            max_tokens: Some(4096),
            temperature: Some(0.1),
        });
        // Only override provider, not model
        config.autonomous.coach_provider = Some("anthropic".to_string());
        let coach_config = config.for_coach().unwrap();
        // Should use overridden provider with its default model
        assert_eq!(coach_config.providers.default_provider, "anthropic");
        assert_eq!(
            coach_config.providers.anthropic.as_ref().unwrap().model,
            "claude-3-5-sonnet-20241022"
        );
    }

    /// Overriding only the model applies it to the default provider.
    #[test]
    fn test_model_override_only() {
        let mut config = Config::default();
        config.providers.default_provider = "databricks".to_string();
        config.providers.databricks = Some(DatabricksConfig {
            host: "https://test.databricks.com".to_string(),
            token: Some("test-token".to_string()),
            model: "databricks-meta-llama-3-1-70b-instruct".to_string(),
            max_tokens: Some(4096),
            temperature: Some(0.1),
            use_oauth: Some(false),
        });
        // Only override model, not provider
        config.autonomous.player_model = Some("databricks-dbrx-instruct".to_string());
        let player_config = config.for_player().unwrap();
        // Should use default provider with overridden model
        assert_eq!(player_config.providers.default_provider, "databricks");
        assert_eq!(
            player_config.providers.databricks.as_ref().unwrap().model,
            "databricks-dbrx-instruct"
        );
    }
}

View File

@@ -2,13 +2,16 @@ use serde::{Deserialize, Serialize};
use anyhow::Result; use anyhow::Result;
use std::path::Path; use std::path::Path;
#[cfg(test)]
mod autonomous_config_tests;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config { pub struct Config {
pub providers: ProvidersConfig, pub providers: ProvidersConfig,
pub agent: AgentConfig, pub agent: AgentConfig,
pub computer_control: ComputerControlConfig, pub computer_control: ComputerControlConfig,
pub webdriver: WebDriverConfig, pub webdriver: WebDriverConfig,
pub macax: MacAxConfig, pub autonomous: AutonomousConfig,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -18,8 +21,6 @@ pub struct ProvidersConfig {
pub databricks: Option<DatabricksConfig>, pub databricks: Option<DatabricksConfig>,
pub embedded: Option<EmbeddedConfig>, pub embedded: Option<EmbeddedConfig>,
pub default_provider: String, pub default_provider: String,
pub coach: Option<String>, // Provider to use for coach in autonomous mode
pub player: Option<String>, // Provider to use for player in autonomous mode
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -80,19 +81,6 @@ pub struct WebDriverConfig {
pub safari_port: u16, pub safari_port: u16,
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MacAxConfig {
pub enabled: bool,
}
impl Default for MacAxConfig {
fn default() -> Self {
Self {
enabled: false,
}
}
}
impl Default for WebDriverConfig { impl Default for WebDriverConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
@@ -102,6 +90,20 @@ impl Default for WebDriverConfig {
} }
} }
/// Per-role provider/model overrides for autonomous (coach/player) mode.
/// Any field left as `None` falls back to `providers.default_provider`
/// (and that provider's own configured model).
///
/// `Default` is derived: all `Option` fields default to `None`, which is
/// exactly what the previous hand-written impl produced.
///
/// NOTE(review): the `Config.autonomous` field does not appear to carry
/// `#[serde(default)]` — confirm that config files omitting an
/// `[autonomous]` section still deserialize.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AutonomousConfig {
    pub coach_provider: Option<String>,
    pub coach_model: Option<String>,
    pub player_provider: Option<String>,
    pub player_model: Option<String>,
}
impl Default for ComputerControlConfig { impl Default for ComputerControlConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
@@ -128,8 +130,6 @@ impl Default for Config {
}), }),
embedded: None, embedded: None,
default_provider: "databricks".to_string(), default_provider: "databricks".to_string(),
coach: None, // Will use default_provider if not specified
player: None, // Will use default_provider if not specified
}, },
agent: AgentConfig { agent: AgentConfig {
max_context_length: 8192, max_context_length: 8192,
@@ -138,7 +138,7 @@ impl Default for Config {
}, },
computer_control: ComputerControlConfig::default(), computer_control: ComputerControlConfig::default(),
webdriver: WebDriverConfig::default(), webdriver: WebDriverConfig::default(),
macax: MacAxConfig::default(), autonomous: AutonomousConfig::default(),
} }
} }
} }
@@ -243,8 +243,6 @@ impl Config {
threads: Some(8), threads: Some(8),
}), }),
default_provider: "embedded".to_string(), default_provider: "embedded".to_string(),
coach: None, // Will use default_provider if not specified
player: None, // Will use default_provider if not specified
}, },
agent: AgentConfig { agent: AgentConfig {
max_context_length: 8192, max_context_length: 8192,
@@ -253,7 +251,7 @@ impl Config {
}, },
computer_control: ComputerControlConfig::default(), computer_control: ComputerControlConfig::default(),
webdriver: WebDriverConfig::default(), webdriver: WebDriverConfig::default(),
macax: MacAxConfig::default(), autonomous: AutonomousConfig::default(),
} }
} }
@@ -323,66 +321,77 @@ impl Config {
Ok(config) Ok(config)
} }
/// Get the provider to use for coach mode in autonomous execution /// Create a config for the coach agent in autonomous mode
pub fn get_coach_provider(&self) -> &str { pub fn for_coach(&self) -> Result<Self> {
self.providers.coach let mut config = self.clone();
.as_deref()
.unwrap_or(&self.providers.default_provider) // Apply coach-specific overrides if configured
} if let Some(ref coach_provider) = self.autonomous.coach_provider {
config.providers.default_provider = coach_provider.clone();
/// Get the provider to use for player mode in autonomous execution }
pub fn get_player_provider(&self) -> &str {
self.providers.player if let Some(ref coach_model) = self.autonomous.coach_model {
.as_deref() // Apply model override to the coach's provider
.unwrap_or(&self.providers.default_provider) match config.providers.default_provider.as_str() {
} "anthropic" => {
if let Some(ref mut anthropic) = config.providers.anthropic {
/// Create a copy of the config with a different default provider anthropic.model = coach_model.clone();
pub fn with_provider_override(&self, provider: &str) -> Result<Self> { } else {
// Validate that the provider is configured return Err(anyhow::anyhow!(
match provider { "Coach provider 'anthropic' is not configured. Please add anthropic configuration to your config file."
"anthropic" if self.providers.anthropic.is_none() => { ));
return Err(anyhow::anyhow!( }
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.", }
provider, provider "databricks" => {
)); if let Some(ref mut databricks) = config.providers.databricks {
} databricks.model = coach_model.clone();
"databricks" if self.providers.databricks.is_none() => { } else {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.", "Coach provider 'databricks' is not configured. Please add databricks configuration to your config file."
provider, provider ));
)); }
} }
"embedded" if self.providers.embedded.is_none() => { _ => {}
return Err(anyhow::anyhow!( }
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
provider, provider
));
}
"openai" if self.providers.openai.is_none() => {
return Err(anyhow::anyhow!(
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
provider, provider
));
}
_ => {} // Provider is configured or unknown (will be caught later)
} }
let mut config = self.clone();
config.providers.default_provider = provider.to_string();
Ok(config) Ok(config)
} }
/// Create a copy of the config for coach mode in autonomous execution /// Create a config for the player agent in autonomous mode
pub fn for_coach(&self) -> Result<Self> {
self.with_provider_override(self.get_coach_provider())
}
/// Create a copy of the config for player mode in autonomous execution
pub fn for_player(&self) -> Result<Self> { pub fn for_player(&self) -> Result<Self> {
self.with_provider_override(self.get_player_provider()) let mut config = self.clone();
// Apply player-specific overrides if configured
if let Some(ref player_provider) = self.autonomous.player_provider {
config.providers.default_provider = player_provider.clone();
}
if let Some(ref player_model) = self.autonomous.player_model {
// Apply model override to the player's provider
match config.providers.default_provider.as_str() {
"anthropic" => {
if let Some(ref mut anthropic) = config.providers.anthropic {
anthropic.model = player_model.clone();
} else {
return Err(anyhow::anyhow!(
"Player provider 'anthropic' is not configured. Please add anthropic configuration to your config file."
));
}
}
"databricks" => {
if let Some(ref mut databricks) = config.providers.databricks {
databricks.model = player_model.clone();
} else {
return Err(anyhow::anyhow!(
"Player provider 'databricks' is not configured. Please add databricks configuration to your config file."
));
}
}
_ => {}
}
}
Ok(config)
} }
} }
#[cfg(test)]
mod tests;

View File

@@ -1,131 +0,0 @@
/// Legacy tests for the older `providers.coach` / `providers.player` keys
/// (superseded by the `[autonomous]` section). Each test writes a TOML file
/// to a tempdir and loads it through `Config::load`.
#[cfg(test)]
mod tests {
    use crate::Config;
    use std::fs;
    use tempfile::TempDir;

    /// Explicit coach/player providers are picked up and applied by
    /// for_coach()/for_player().
    #[test]
    fn test_coach_player_providers() {
        // Create a temporary directory for the test config
        let temp_dir = TempDir::new().unwrap();
        let config_path = temp_dir.path().join("test_config.toml");
        // Write a test configuration with coach and player providers
        let config_content = r#"
[providers]
default_provider = "databricks"
coach = "anthropic"
player = "embedded"
[providers.databricks]
host = "https://test.databricks.com"
token = "test-token"
model = "test-model"
[providers.anthropic]
api_key = "test-key"
model = "claude-3"
[providers.embedded]
model_path = "test.gguf"
model_type = "llama"
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
"#;
        fs::write(&config_path, config_content).unwrap();
        // Load the configuration
        let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
        // Test that the providers are correctly identified
        assert_eq!(config.providers.default_provider, "databricks");
        assert_eq!(config.get_coach_provider(), "anthropic");
        assert_eq!(config.get_player_provider(), "embedded");
        // Test creating coach config
        let coach_config = config.for_coach().unwrap();
        assert_eq!(coach_config.providers.default_provider, "anthropic");
        // Test creating player config
        let player_config = config.for_player().unwrap();
        assert_eq!(player_config.providers.default_provider, "embedded");
    }

    /// Omitting coach/player keys falls back to the default provider.
    #[test]
    fn test_coach_player_fallback_to_default() {
        // Create a temporary directory for the test config
        let temp_dir = TempDir::new().unwrap();
        let config_path = temp_dir.path().join("test_config.toml");
        // Write a test configuration WITHOUT coach and player providers
        let config_content = r#"
[providers]
default_provider = "databricks"
[providers.databricks]
host = "https://test.databricks.com"
token = "test-token"
model = "test-model"
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
"#;
        fs::write(&config_path, config_content).unwrap();
        // Load the configuration
        let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
        // Test that coach and player fall back to default provider
        assert_eq!(config.get_coach_provider(), "databricks");
        assert_eq!(config.get_player_provider(), "databricks");
        // Test creating coach config (should use default)
        let coach_config = config.for_coach().unwrap();
        assert_eq!(coach_config.providers.default_provider, "databricks");
        // Test creating player config (should use default)
        let player_config = config.for_player().unwrap();
        assert_eq!(player_config.providers.default_provider, "databricks");
    }

    /// Naming an unconfigured provider for the coach must fail loudly.
    #[test]
    fn test_invalid_provider_error() {
        // Create a temporary directory for the test config
        let temp_dir = TempDir::new().unwrap();
        let config_path = temp_dir.path().join("test_config.toml");
        // Write a test configuration with an unconfigured provider
        let config_content = r#"
[providers]
default_provider = "databricks"
coach = "openai" # OpenAI is not configured
[providers.databricks]
host = "https://test.databricks.com"
token = "test-token"
model = "test-model"
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
"#;
        fs::write(&config_path, config_content).unwrap();
        // Load the configuration
        let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
        // Test that trying to create a coach config with unconfigured provider fails
        let result = config.for_coach();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not configured"));
    }
}

View File

@@ -156,15 +156,15 @@ pub fn fixed_filter_json_tool_calls(content: &str) -> String {
} }
// No JSON tool call detected, return only the new content we haven't returned yet // No JSON tool call detected, return only the new content we haven't returned yet
let new_content = if state.buffer.len() > state.content_returned_up_to {
if state.buffer.len() > state.content_returned_up_to {
let result = state.buffer[state.content_returned_up_to..].to_string(); let result = state.buffer[state.content_returned_up_to..].to_string();
state.content_returned_up_to = state.buffer.len(); state.content_returned_up_to = state.buffer.len();
result result
} else { } else {
String::new() String::new()
} };
new_content
}) })
} }

File diff suppressed because it is too large Load Diff

View File

@@ -104,7 +104,6 @@ impl Project {
} }
/// Recursively check a directory for implementation files /// Recursively check a directory for implementation files
#[allow(clippy::only_used_in_recursion)]
fn check_dir_for_implementation_files(&self, dir: &Path) -> bool { fn check_dir_for_implementation_files(&self, dir: &Path) -> bool {
// Common source file extensions // Common source file extensions
let extensions = vec![ let extensions = vec![

View File

@@ -1,37 +0,0 @@
// Test to verify take_screenshot requires window_id
#[cfg(test)]
mod take_screenshot_tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_take_screenshot_requires_window_id() {
        // A call built without window_id must expose no such argument.
        let call = ToolCall {
            tool: "take_screenshot".to_string(),
            args: json!({ "path": "test.png" }),
        };
        assert!(call.args.get("window_id").is_none());
    }

    #[test]
    fn test_take_screenshot_with_window_id() {
        // When window_id is supplied it must round-trip through the args.
        let call = ToolCall {
            tool: "take_screenshot".to_string(),
            args: json!({ "path": "test.png", "window_id": "Safari" }),
        };
        let window_id = call.args.get("window_id");
        assert!(window_id.is_some());
        assert_eq!(window_id.unwrap().as_str().unwrap(), "Safari");
    }
}

View File

@@ -17,9 +17,6 @@ pub trait UiWriter: Send + Sync {
/// Print a context window status message /// Print a context window status message
fn print_context_status(&self, message: &str); fn print_context_status(&self, message: &str);
/// Print a context thinning success message with highlight and animation
fn print_context_thinning(&self, message: &str);
/// Print a tool execution header /// Print a tool execution header
fn print_tool_header(&self, tool_name: &str); fn print_tool_header(&self, tool_name: &str);
@@ -52,10 +49,6 @@ pub trait UiWriter: Send + Sync {
/// Flush any buffered output /// Flush any buffered output
fn flush(&self); fn flush(&self);
/// Returns true if this UI writer wants full, untruncated output
/// Default is false (truncate for human readability)
fn wants_full_output(&self) -> bool { false }
} }
/// A no-op implementation for when UI output is not needed /// A no-op implementation for when UI output is not needed
@@ -67,7 +60,6 @@ impl UiWriter for NullUiWriter {
fn print_inline(&self, _message: &str) {} fn print_inline(&self, _message: &str) {}
fn print_system_prompt(&self, _prompt: &str) {} fn print_system_prompt(&self, _prompt: &str) {}
fn print_context_status(&self, _message: &str) {} fn print_context_status(&self, _message: &str) {}
fn print_context_thinning(&self, _message: &str) {}
fn print_tool_header(&self, _tool_name: &str) {} fn print_tool_header(&self, _tool_name: &str) {}
fn print_tool_arg(&self, _key: &str, _value: &str) {} fn print_tool_arg(&self, _key: &str, _value: &str) {}
fn print_tool_output_header(&self) {} fn print_tool_output_header(&self) {}
@@ -79,5 +71,4 @@ impl UiWriter for NullUiWriter {
fn print_agent_response(&self, _content: &str) {} fn print_agent_response(&self, _content: &str) {}
fn notify_sse_received(&self) {} fn notify_sse_received(&self) {}
fn flush(&self) {} fn flush(&self) {}
fn wants_full_output(&self) -> bool { false }
} }

View File

@@ -72,7 +72,7 @@ fn test_thin_context_basic() {
// Trigger thinning at 50% // Trigger thinning at 50%
context.used_tokens = 5000; context.used_tokens = 5000;
let (summary, _chars_saved) = context.thin_context(); let summary = context.thin_context();
println!("Thinning summary: {}", summary); println!("Thinning summary: {}", summary);
@@ -93,119 +93,6 @@ fn test_thin_context_basic() {
} }
} }
#[test]
fn test_thin_write_file_tool_calls() {
let mut context = ContextWindow::new(10000);
// Add some messages including a write_file tool call with large content
context.add_message(Message {
role: MessageRole::User,
content: "Please create a large file".to_string(),
});
// Add an assistant message with a write_file tool call containing large content
let large_content = "x".repeat(1500);
let tool_call_json = format!(
r#"{{"tool": "write_file", "args": {{"file_path": "test.txt", "content": "{}"}}}}"#,
large_content
);
context.add_message(Message {
role: MessageRole::Assistant,
content: format!("I'll create that file.\n\n{}", tool_call_json),
});
context.add_message(Message {
role: MessageRole::User,
content: "Tool result: ✅ Successfully wrote 1500 lines".to_string(),
});
// Add more messages to ensure we have enough for "first third" logic
for i in 0..6 {
context.add_message(Message {
role: MessageRole::Assistant,
content: format!("Response {}", i),
});
}
// Trigger thinning at 50%
context.used_tokens = 5000;
let (summary, _chars_saved) = context.thin_context();
println!("Thinning summary: {}", summary);
// Should have thinned the write_file tool call
assert!(summary.contains("tool call") || summary.contains("chars saved"));
// Check that the large content was replaced with a file reference
let first_third_end = context.conversation_history.len() / 3;
for i in 0..first_third_end {
if let Some(msg) = context.conversation_history.get(i) {
if matches!(msg.role, MessageRole::Assistant) && msg.content.contains("write_file") {
// The content should now reference an external file
assert!(msg.content.contains("<content saved to"));
assert!(!msg.content.contains(&large_content));
}
}
}
}
#[test]
fn test_thin_str_replace_tool_calls() {
let mut context = ContextWindow::new(10000);
// Add some messages including a str_replace tool call with large diff
context.add_message(Message {
role: MessageRole::User,
content: "Please update the file".to_string(),
});
// Add an assistant message with a str_replace tool call containing large diff
let large_diff = format!("--- old\n{}\n+++ new\n{}", "-old line\n".repeat(100), "+new line\n".repeat(100));
let tool_call_json = format!(
r#"{{"tool": "str_replace", "args": {{"file_path": "test.txt", "diff": "{}"}}}}"#,
large_diff.replace('\n', "\\n")
);
context.add_message(Message {
role: MessageRole::Assistant,
content: format!("I'll update that file.\n\n{}", tool_call_json),
});
context.add_message(Message {
role: MessageRole::User,
content: "Tool result: ✅ applied unified diff".to_string(),
});
// Add more messages to ensure we have enough for "first third" logic
for i in 0..6 {
context.add_message(Message {
role: MessageRole::Assistant,
content: format!("Response {}", i),
});
}
// Trigger thinning at 50%
context.used_tokens = 5000;
let (summary, _chars_saved) = context.thin_context();
println!("Thinning summary: {}", summary);
// Should have thinned the str_replace tool call
assert!(summary.contains("tool call") || summary.contains("chars saved"));
// Check that the large diff was replaced with a file reference
let first_third_end = context.conversation_history.len() / 3;
for i in 0..first_third_end {
if let Some(msg) = context.conversation_history.get(i) {
if matches!(msg.role, MessageRole::Assistant) && msg.content.contains("str_replace") {
// The diff should now reference an external file
assert!(msg.content.contains("<diff saved to"));
// Should not contain the large diff content
assert!(!msg.content.contains("old line"));
}
}
}
}
#[test] #[test]
fn test_thin_context_no_large_results() { fn test_thin_context_no_large_results() {
let mut context = ContextWindow::new(10000); let mut context = ContextWindow::new(10000);
@@ -219,10 +106,10 @@ fn test_thin_context_no_large_results() {
} }
context.used_tokens = 5000; context.used_tokens = 5000;
let (summary, _chars_saved) = context.thin_context(); let summary = context.thin_context();
// Should report no large results found // Should report no large results found
assert!(summary.contains("no large tool results or tool calls found")); assert!(summary.contains("no large tool results found"));
} }
#[test] #[test]
@@ -248,7 +135,7 @@ fn test_thin_context_only_affects_first_third() {
} }
context.used_tokens = 5000; context.used_tokens = 5000;
let (summary, _chars_saved) = context.thin_context(); let summary = context.thin_context();
// First third is 4 messages (indices 0-3), so only indices 1 and 3 should be thinned // First third is 4 messages (indices 0-3), so only indices 1 and 3 should be thinned
// That's 2 tool results // That's 2 tool results

View File

@@ -166,31 +166,6 @@ impl CodeExecutor {
/// Execute Bash code /// Execute Bash code
async fn execute_bash(&self, code: &str) -> Result<ExecutionResult> { async fn execute_bash(&self, code: &str) -> Result<ExecutionResult> {
// Check if this is a detached/daemon command that should run independently
let is_detached = code.trim_start().starts_with("setsid ")
|| code.trim_start().starts_with("nohup ")
|| code.contains(" disown")
|| (code.contains(" &") && (code.contains("nohup") || code.contains("setsid")));
if is_detached {
// For detached commands, just spawn and return immediately
use std::process::Stdio;
Command::new("bash")
.arg("-c")
.arg(code)
.stdin(Stdio::null())
.stdout(Stdio::null())
.stderr(Stdio::null())
.spawn()?;
return Ok(ExecutionResult {
stdout: "✅ Command launched in background (detached process)".to_string(),
stderr: String::new(),
exit_code: 0,
success: true,
});
}
let output = Command::new("bash") let output = Command::new("bash")
.arg("-c") .arg("-c")
.arg(code) .arg(code)
@@ -246,29 +221,6 @@ impl CodeExecutor {
use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command as TokioCommand; use tokio::process::Command as TokioCommand;
// Check if this is a detached/daemon command that should run independently
// Look for patterns like: setsid, nohup with &, or explicit backgrounding with disown
let is_detached = code.trim_start().starts_with("setsid ")
|| code.trim_start().starts_with("nohup ")
|| code.contains(" disown")
|| (code.contains(" &") && (code.contains("nohup") || code.contains("setsid")));
if is_detached {
// For detached commands, just spawn and return immediately
TokioCommand::new("bash")
.arg("-c")
.arg(code)
.spawn()?;
// Don't wait for the process - it's meant to run independently
return Ok(ExecutionResult {
stdout: "✅ Command launched in background (detached process)".to_string(),
stderr: String::new(),
exit_code: 0,
success: true,
});
}
let mut child = TokioCommand::new("bash") let mut child = TokioCommand::new("bash")
.arg("-c") .arg("-c")
.arg(code) .arg(code)
@@ -307,7 +259,7 @@ impl CodeExecutor {
line = stderr_lines.next_line() => { line = stderr_lines.next_line() => {
match line { match line {
Ok(Some(line)) => { Ok(Some(line)) => {
receiver.on_output_line(&line.to_string()); receiver.on_output_line(&format!("{}", line));
stderr_output.push(line); stderr_output.push(line);
} }
Ok(None) => {}, // stderr EOF, continue Ok(None) => {}, // stderr EOF, continue

View File

@@ -156,9 +156,8 @@ impl AnthropicProvider {
.post(ANTHROPIC_API_URL) .post(ANTHROPIC_API_URL)
.header("x-api-key", &self.api_key) .header("x-api-key", &self.api_key)
.header("anthropic-version", ANTHROPIC_VERSION) .header("anthropic-version", ANTHROPIC_VERSION)
// Anthropic beta 1m context window. Enable if needed. It costs extra, so check first.
// .header("anthropic-beta", "context-1m-2025-08-07")
.header("content-type", "application/json"); .header("content-type", "application/json");
if streaming { if streaming {
builder = builder.header("accept", "text/event-stream"); builder = builder.header("accept", "text/event-stream");
} }
@@ -276,7 +275,6 @@ impl AnthropicProvider {
let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
let mut accumulated_usage: Option<Usage> = None; let mut accumulated_usage: Option<Usage> = None;
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
let mut actual_completion_tokens: u32 = 0; // Track actual completion tokens
while let Some(chunk_result) = stream.next().await { while let Some(chunk_result) = stream.next().await {
match chunk_result { match chunk_result {
@@ -324,12 +322,7 @@ impl AnthropicProvider {
let final_chunk = CompletionChunk { let final_chunk = CompletionChunk {
content: String::new(), content: String::new(),
finished: true, finished: true,
usage: accumulated_usage.as_ref().map(|u| Usage { usage: accumulated_usage.clone(),
prompt_tokens: u.prompt_tokens,
// Use actual completion tokens if we tracked them, otherwise use the estimate
completion_tokens: if actual_completion_tokens > 0 { actual_completion_tokens } else { u.completion_tokens },
total_tokens: u.prompt_tokens + if actual_completion_tokens > 0 { actual_completion_tokens } else { u.completion_tokens },
}),
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) }, tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
}; };
if tx.send(Ok(final_chunk)).await.is_err() { if tx.send(Ok(final_chunk)).await.is_err() {
@@ -343,7 +336,6 @@ impl AnthropicProvider {
match serde_json::from_str::<AnthropicStreamEvent>(data) { match serde_json::from_str::<AnthropicStreamEvent>(data) {
Ok(event) => { Ok(event) => {
debug!("Parsed event type: {}, event: {:?}", event.event_type, event); debug!("Parsed event type: {}, event: {:?}", event.event_type, event);
match event.event_type.as_str() { match event.event_type.as_str() {
"message_start" => { "message_start" => {
// Extract usage data from message_start event // Extract usage data from message_start event
@@ -354,10 +346,7 @@ impl AnthropicProvider {
completion_tokens: usage.output_tokens, completion_tokens: usage.output_tokens,
total_tokens: usage.input_tokens + usage.output_tokens, total_tokens: usage.input_tokens + usage.output_tokens,
}); });
debug!("Captured initial usage from message_start - prompt: {}, completion: {} (estimated), total: {}", debug!("Captured usage from message_start: {:?}", accumulated_usage);
usage.input_tokens,
usage.output_tokens,
usage.input_tokens + usage.output_tokens);
} }
} }
} }
@@ -406,9 +395,6 @@ impl AnthropicProvider {
"content_block_delta" => { "content_block_delta" => {
if let Some(delta) = event.delta { if let Some(delta) = event.delta {
if let Some(text) = delta.text { if let Some(text) = delta.text {
// Track actual completion tokens (rough estimate: 4 chars per token)
actual_completion_tokens += (text.len() as f32 / 4.0).ceil() as u32;
debug!("Sending text chunk of length {}: '{}'", text.len(), text); debug!("Sending text chunk of length {}: '{}'", text.len(), text);
let chunk = CompletionChunk { let chunk = CompletionChunk {
content: text, content: text,
@@ -429,19 +415,6 @@ impl AnthropicProvider {
} }
} }
} }
"message_delta" => {
// Check if message_delta contains updated usage data
if let Some(delta) = event.delta {
if let Some(usage) = delta.usage {
accumulated_usage = Some(Usage {
prompt_tokens: usage.input_tokens,
completion_tokens: usage.output_tokens,
total_tokens: usage.input_tokens + usage.output_tokens,
});
debug!("Updated usage from message_delta - prompt: {}, completion: {}, total: {}", usage.input_tokens, usage.output_tokens, usage.input_tokens + usage.output_tokens);
}
}
}
"content_block_stop" => { "content_block_stop" => {
// Tool call block is complete - now parse the accumulated JSON // Tool call block is complete - now parse the accumulated JSON
if !current_tool_calls.is_empty() && !partial_tool_json.is_empty() { if !current_tool_calls.is_empty() && !partial_tool_json.is_empty() {
@@ -476,44 +449,11 @@ impl AnthropicProvider {
} }
} }
"message_stop" => { "message_stop" => {
debug!("Received message_stop event: {:?}", event); debug!("Received message stop event");
// Check if message_stop contains final usage data
if let Some(message) = event.message {
if let Some(usage) = message.usage {
// Update with final accurate usage data from message_stop
// This should have the actual completion token count
accumulated_usage = Some(Usage {
prompt_tokens: usage.input_tokens,
// Prefer the actual output_tokens from message_stop if available
// Otherwise use our tracked count, and as last resort the initial estimate
completion_tokens: if usage.output_tokens > 0 {
usage.output_tokens
} else if actual_completion_tokens > 0 {
actual_completion_tokens
} else { usage.output_tokens },
total_tokens: usage.input_tokens + usage.output_tokens,
});
debug!("Updated with final usage from message_stop - prompt: {}, completion: {}, total: {}",
usage.input_tokens,
usage.output_tokens,
usage.input_tokens + usage.output_tokens);
}
}
let final_chunk = CompletionChunk { let final_chunk = CompletionChunk {
content: String::new(), content: String::new(),
finished: true, finished: true,
usage: accumulated_usage.as_ref().map(|u| Usage { usage: accumulated_usage.clone(),
prompt_tokens: u.prompt_tokens,
// Use actual completion tokens if we tracked them and they're higher
completion_tokens: if actual_completion_tokens > u.completion_tokens {
actual_completion_tokens
} else {
u.completion_tokens
},
total_tokens: u.prompt_tokens + u32::max(actual_completion_tokens, u.completion_tokens),
}),
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) }, tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
}; };
if tx.send(Ok(final_chunk)).await.is_err() { if tx.send(Ok(final_chunk)).await.is_err() {
@@ -555,27 +495,10 @@ impl AnthropicProvider {
let final_chunk = CompletionChunk { let final_chunk = CompletionChunk {
content: String::new(), content: String::new(),
finished: true, finished: true,
usage: accumulated_usage.as_ref().map(|u| Usage { usage: accumulated_usage.clone(),
prompt_tokens: u.prompt_tokens,
completion_tokens: if actual_completion_tokens > u.completion_tokens {
actual_completion_tokens
} else {
u.completion_tokens
},
total_tokens: u.prompt_tokens + u32::max(actual_completion_tokens, u.completion_tokens),
}),
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls) }, tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls) },
}; };
let _ = tx.send(Ok(final_chunk)).await; let _ = tx.send(Ok(final_chunk)).await;
// Log final usage for debugging
if let Some(ref usage) = accumulated_usage {
info!("Anthropic stream completed with final usage - prompt: {}, completion: {}, total: {}",
usage.prompt_tokens, usage.completion_tokens, usage.total_tokens);
} else {
warn!("Anthropic stream completed without usage data - token accounting will fall back to estimation");
}
accumulated_usage accumulated_usage
} }
} }
@@ -813,8 +736,6 @@ struct AnthropicStreamMessage {
struct AnthropicDelta { struct AnthropicDelta {
text: Option<String>, text: Option<String>,
partial_json: Option<String>, partial_json: Option<String>,
#[serde(default)]
usage: Option<AnthropicUsage>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]

View File

@@ -213,7 +213,7 @@ impl DatabricksProvider {
let mut builder = self let mut builder = self
.client .client
.post(format!( .post(&format!(
"{}/serving-endpoints/{}/invocations", "{}/serving-endpoints/{}/invocations",
self.host, self.model self.host, self.model
)) ))
@@ -881,14 +881,6 @@ impl LLMProvider for DatabricksProvider {
"Processing Databricks streaming request with {} messages", "Processing Databricks streaming request with {} messages",
request.messages.len() request.messages.len()
); );
// Debug: Log tool count
if let Some(ref tools) = request.tools {
debug!("Request has {} tools", tools.len());
for tool in tools.iter().take(5) {
debug!(" Tool: {}", tool.name);
}
}
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens); let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature); let temperature = request.temperature.unwrap_or(self.temperature);

View File

@@ -88,12 +88,10 @@ pub mod anthropic;
pub mod databricks; pub mod databricks;
pub mod embedded; pub mod embedded;
pub mod oauth; pub mod oauth;
pub mod openai;
pub use anthropic::AnthropicProvider; pub use anthropic::AnthropicProvider;
pub use databricks::DatabricksProvider; pub use databricks::DatabricksProvider;
pub use embedded::EmbeddedProvider; pub use embedded::EmbeddedProvider;
pub use openai::OpenAIProvider;
/// Provider registry for managing multiple LLM providers /// Provider registry for managing multiple LLM providers
pub struct ProviderRegistry { pub struct ProviderRegistry {

View File

@@ -102,7 +102,7 @@ async fn get_workspace_endpoints(host: &str) -> Result<OidcEndpoints> {
if !resp.status().is_success() { if !resp.status().is_success() {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
"Failed to get OIDC configuration from {}", "Failed to get OIDC configuration from {}",
oidc_url oidc_url.to_string()
)); ));
} }

View File

@@ -1,495 +0,0 @@
use anyhow::Result;
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::stream::StreamExt;
use reqwest::Client;
use serde::Deserialize;
use serde_json::json;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error};
use crate::{
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider,
Message, MessageRole, Tool, ToolCall, Usage,
};
/// LLM provider backed by the OpenAI chat-completions HTTP API.
///
/// Cloneable so streaming tasks can own a copy (see `stream`).
#[derive(Clone)]
pub struct OpenAIProvider {
    client: Client,
    // Sent as a "Bearer" token in the Authorization header.
    api_key: String,
    // Model identifier; defaults to "gpt-4o" in `new`.
    model: String,
    // API root, e.g. "https://api.openai.com/v1"; overridable for proxies.
    base_url: String,
    // Default completion-token cap; a per-request value overrides it.
    max_tokens: Option<u32>,
    // Kept for interface parity but never sent (see create_request_body).
    _temperature: Option<f32>,
}
impl OpenAIProvider {
    /// Build a provider for the OpenAI chat-completions API.
    ///
    /// `model` defaults to "gpt-4o" and `base_url` to the public OpenAI
    /// endpoint when not supplied. No network traffic happens here; the
    /// HTTP client is merely constructed.
    pub fn new(
        api_key: String,
        model: Option<String>,
        base_url: Option<String>,
        max_tokens: Option<u32>,
        temperature: Option<f32>,
    ) -> Result<Self> {
        Ok(Self {
            client: Client::new(),
            api_key,
            model: model.unwrap_or_else(|| "gpt-4o".to_string()),
            base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
            max_tokens,
            _temperature: temperature,
        })
    }

    /// Assemble the JSON request body for a chat-completions call.
    ///
    /// Per-request `max_tokens` overrides the provider default. Tools are
    /// attached only when the slice is non-empty. For streaming requests,
    /// `stream_options.include_usage` is set so the final SSE chunk carries
    /// token usage. Temperature is deliberately never sent (see below).
    fn create_request_body(
        &self,
        messages: &[Message],
        tools: Option<&[Tool]>,
        stream: bool,
        max_tokens: Option<u32>,
        _temperature: Option<f32>,
    ) -> serde_json::Value {
        let mut body = json!({
            "model": self.model,
            "messages": convert_messages(messages),
            "stream": stream,
        });
        if let Some(max_tokens) = max_tokens.or(self.max_tokens) {
            body["max_completion_tokens"] = json!(max_tokens);
        }
        // OpenAI calls with temp setting seem to fail, so don't send one.
        // if let Some(temperature) = temperature.or(self.temperature) {
        //     body["temperature"] = json!(temperature);
        // }
        if let Some(tools) = tools {
            if !tools.is_empty() {
                body["tools"] = json!(convert_tools(tools));
            }
        }
        if stream {
            body["stream_options"] = json!({
                "include_usage": true,
            });
        }
        body
    }

    /// Drain an OpenAI SSE byte stream, forwarding parsed chunks on `tx`.
    ///
    /// Text deltas are sent immediately as unfinished chunks; tool-call
    /// deltas are accumulated per index until the stream ends; usage (when
    /// present thanks to `include_usage`) is captured. A final chunk with
    /// `finished: true` is emitted either at the `[DONE]` marker or when the
    /// stream closes. Returns the last usage seen, if any. Send errors on
    /// `tx` mean the receiver is gone and terminate processing early.
    async fn parse_streaming_response(
        &self,
        mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
        tx: mpsc::Sender<Result<CompletionChunk>>,
    ) -> Option<Usage> {
        // Holds partial SSE lines that straddle network chunk boundaries.
        let mut buffer = String::new();
        let mut accumulated_content = String::new();
        let mut accumulated_usage: Option<Usage> = None;
        // Indexed by the tool-call `index` field from the deltas.
        let mut current_tool_calls: Vec<OpenAIStreamingToolCall> = Vec::new();
        while let Some(chunk_result) = stream.next().await {
            match chunk_result {
                Ok(chunk) => {
                    // NOTE(review): a multi-byte UTF-8 sequence split across
                    // network chunks would fail here and drop the chunk.
                    let chunk_str = match std::str::from_utf8(&chunk) {
                        Ok(s) => s,
                        Err(e) => {
                            error!("Failed to parse chunk as UTF-8: {}", e);
                            continue;
                        }
                    };
                    buffer.push_str(chunk_str);
                    // Process complete lines
                    while let Some(line_end) = buffer.find('\n') {
                        let line = buffer[..line_end].trim().to_string();
                        buffer.drain(..line_end + 1);
                        if line.is_empty() {
                            continue;
                        }
                        // Parse Server-Sent Events format
                        if let Some(data) = line.strip_prefix("data: ") {
                            if data == "[DONE]" {
                                debug!("Received stream completion marker");
                                // Send final chunk with accumulated content and tool calls
                                if !accumulated_content.is_empty() || !current_tool_calls.is_empty() {
                                    let tool_calls = if current_tool_calls.is_empty() {
                                        None
                                    } else {
                                        Some(
                                            current_tool_calls
                                                .iter()
                                                .filter_map(|tc| tc.to_tool_call())
                                                .collect(),
                                        )
                                    };
                                    let final_chunk = CompletionChunk {
                                        content: accumulated_content.clone(),
                                        finished: true,
                                        tool_calls,
                                        usage: accumulated_usage.clone(),
                                    };
                                    let _ = tx.send(Ok(final_chunk)).await;
                                }
                                return accumulated_usage;
                            }
                            // Parse the JSON data
                            match serde_json::from_str::<OpenAIStreamChunk>(data) {
                                Ok(chunk_data) => {
                                    // Handle content
                                    for choice in &chunk_data.choices {
                                        if let Some(content) = &choice.delta.content {
                                            accumulated_content.push_str(content);
                                            let chunk = CompletionChunk {
                                                content: content.clone(),
                                                finished: false,
                                                tool_calls: None,
                                                usage: None,
                                            };
                                            if tx.send(Ok(chunk)).await.is_err() {
                                                debug!("Receiver dropped, stopping stream");
                                                return accumulated_usage;
                                            }
                                        }
                                        // Handle tool calls
                                        if let Some(delta_tool_calls) = &choice.delta.tool_calls {
                                            for delta_tool_call in delta_tool_calls {
                                                if let Some(index) = delta_tool_call.index {
                                                    // Ensure we have enough tool calls in our vector
                                                    while current_tool_calls.len() <= index {
                                                        current_tool_calls
                                                            .push(OpenAIStreamingToolCall::default());
                                                    }
                                                    let tool_call = &mut current_tool_calls[index];
                                                    if let Some(id) = &delta_tool_call.id {
                                                        tool_call.id = Some(id.clone());
                                                    }
                                                    if let Some(function) = &delta_tool_call.function {
                                                        if let Some(name) = &function.name {
                                                            tool_call.name = Some(name.clone());
                                                        }
                                                        // Argument JSON arrives fragmented; concatenate.
                                                        if let Some(arguments) = &function.arguments {
                                                            tool_call.arguments.push_str(arguments);
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                    // Handle usage
                                    if let Some(usage) = chunk_data.usage {
                                        accumulated_usage = Some(Usage {
                                            prompt_tokens: usage.prompt_tokens,
                                            completion_tokens: usage.completion_tokens,
                                            total_tokens: usage.total_tokens,
                                        });
                                    }
                                }
                                Err(e) => {
                                    // Unknown event shapes are tolerated by design.
                                    debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
                                }
                            }
                        }
                    }
                }
                Err(e) => {
                    error!("Stream error: {}", e);
                    let _ = tx.send(Err(anyhow::anyhow!("Stream error: {}", e))).await;
                    return accumulated_usage;
                }
            }
        }
        // Stream ended without a [DONE] marker.
        // Send final chunk if we haven't already
        let tool_calls = if current_tool_calls.is_empty() {
            None
        } else {
            Some(
                current_tool_calls
                    .iter()
                    .filter_map(|tc| tc.to_tool_call())
                    .collect(),
            )
        };
        let final_chunk = CompletionChunk {
            content: String::new(),
            finished: true,
            tool_calls,
            usage: accumulated_usage.clone(),
        };
        let _ = tx.send(Ok(final_chunk)).await;
        accumulated_usage
    }
}
#[async_trait]
impl LLMProvider for OpenAIProvider {
    /// One-shot (non-streaming) completion.
    ///
    /// POSTs to `{base_url}/chat/completions`; non-2xx responses become an
    /// error carrying the status and response body. Only the first choice's
    /// content is returned (empty string when absent).
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        debug!(
            "Processing OpenAI completion request with {} messages",
            request.messages.len()
        );
        let body = self.create_request_body(
            &request.messages,
            request.tools.as_deref(),
            false,
            request.max_tokens,
            request.temperature,
        );
        debug!("Sending request to OpenAI API: model={}", self.model);
        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&body)
            .send()
            .await?;
        let status = response.status();
        if !status.is_success() {
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text));
        }
        let openai_response: OpenAIResponse = response.json().await?;
        let content = openai_response
            .choices
            .first()
            .and_then(|choice| choice.message.content.clone())
            .unwrap_or_default();
        let usage = Usage {
            prompt_tokens: openai_response.usage.prompt_tokens,
            completion_tokens: openai_response.usage.completion_tokens,
            total_tokens: openai_response.usage.total_tokens,
        };
        debug!(
            "OpenAI completion successful: {} tokens generated",
            usage.completion_tokens
        );
        Ok(CompletionResponse {
            content,
            usage,
            model: self.model.clone(),
        })
    }

    /// Streaming completion.
    ///
    /// Issues the same POST with `stream: true`, then spawns a task that
    /// parses the SSE byte stream (`parse_streaming_response`) and feeds a
    /// bounded channel; the returned stream is the channel's receiver side.
    /// HTTP-level errors are surfaced here, before any chunk is produced.
    async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
        debug!(
            "Processing OpenAI streaming request with {} messages",
            request.messages.len()
        );
        let body = self.create_request_body(
            &request.messages,
            request.tools.as_deref(),
            true,
            request.max_tokens,
            request.temperature,
        );
        debug!("Sending streaming request to OpenAI API: model={}", self.model);
        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&body)
            .send()
            .await?;
        let status = response.status();
        if !status.is_success() {
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text));
        }
        let stream = response.bytes_stream();
        let (tx, rx) = mpsc::channel(100);
        // Spawn task to process the stream
        let provider = self.clone();
        tokio::spawn(async move {
            let usage = provider.parse_streaming_response(stream, tx).await;
            // Log the final usage if available
            if let Some(usage) = usage {
                debug!(
                    "Stream completed with usage - prompt: {}, completion: {}, total: {}",
                    usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
                );
            }
        });
        Ok(ReceiverStream::new(rx))
    }

    /// Registry key for this provider.
    fn name(&self) -> &str {
        "openai"
    }

    /// The configured model identifier.
    fn model(&self) -> &str {
        &self.model
    }

    fn has_native_tool_calling(&self) -> bool {
        // OpenAI models support native tool calling
        true
    }
}
fn convert_messages(messages: &[Message]) -> Vec<serde_json::Value> {
messages
.iter()
.map(|msg| {
json!({
"role": match msg.role {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
},
"content": msg.content,
})
})
.collect()
}
/// Wrap each internal `Tool` in OpenAI's `{"type": "function", ...}`
/// envelope, forwarding name, description, and the JSON-schema parameters.
fn convert_tools(tools: &[Tool]) -> Vec<serde_json::Value> {
    let mut converted = Vec::with_capacity(tools.len());
    for tool in tools {
        converted.push(json!({
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.input_schema,
            }
        }));
    }
    converted
}
// OpenAI API response structures

/// Top-level non-streaming chat-completion response.
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
    usage: OpenAIUsage,
}

/// One completion choice; only the message payload is consumed here.
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessage,
}

/// Assistant message body. Both fields are optional because a response may
/// carry plain text, tool invocations, or both.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct OpenAIMessage {
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAIToolCall>>,
}

/// A completed (non-streaming) tool call from the model.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct OpenAIToolCall {
    id: String,
    function: OpenAIFunction,
}

/// Function name plus its arguments.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct OpenAIFunction {
    name: String,
    // Arguments arrive as a JSON-encoded string, not a parsed object.
    arguments: String,
}
// Streaming tool call accumulator

/// Accumulates tool-call fragments across streaming deltas: `id` and `name`
/// are set when they first arrive, while `arguments` grows by concatenating
/// partial JSON text from successive chunks.
#[derive(Debug, Default)]
struct OpenAIStreamingToolCall {
    id: Option<String>,
    name: Option<String>,
    arguments: String,
}
impl OpenAIStreamingToolCall {
    /// Finalizes the accumulated fragments into a `ToolCall`.
    ///
    /// Returns `None` while either the id or the function name has not yet
    /// arrived. Argument text that fails to parse as JSON degrades to
    /// `Value::Null` rather than discarding the call.
    fn to_tool_call(&self) -> Option<ToolCall> {
        let (id, name) = match (&self.id, &self.name) {
            (Some(id), Some(name)) => (id.clone(), name.clone()),
            _ => return None,
        };
        let args = serde_json::from_str(&self.arguments)
            .unwrap_or(serde_json::Value::Null);
        Some(ToolCall {
            id,
            tool: name,
            args,
        })
    }
}
/// Token-usage block reported by the OpenAI API.
#[derive(Debug, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}
// Streaming response structures

/// One SSE chunk of a streaming completion; `usage` is only populated on the
/// final chunk, when the API includes it.
#[derive(Debug, Deserialize)]
struct OpenAIStreamChunk {
    choices: Vec<OpenAIStreamChoice>,
    usage: Option<OpenAIUsage>,
}

/// A streamed choice carrying an incremental delta.
#[derive(Debug, Deserialize)]
struct OpenAIStreamChoice {
    delta: OpenAIDelta,
}

/// Incremental message payload: text content and/or tool-call fragments.
#[derive(Debug, Deserialize)]
struct OpenAIDelta {
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAIDeltaToolCall>>,
}

/// Fragment of a tool call; `index` correlates fragments belonging to the
/// same call across successive chunks.
#[derive(Debug, Deserialize)]
struct OpenAIDeltaToolCall {
    index: Option<usize>,
    id: Option<String>,
    function: Option<OpenAIDeltaFunction>,
}

/// Partial function data within a tool-call fragment.
#[derive(Debug, Deserialize)]
struct OpenAIDeltaFunction {
    name: Option<String>,
    arguments: Option<String>,
}

View File

@@ -1,164 +0,0 @@
#!/usr/bin/env python3
"""
Test script to verify token accounting is working correctly with the Anthropic provider.
This script will send multiple messages and verify that token counts accumulate properly.
"""
import json
import os
import re
import subprocess
import sys
import time
def run_g3_command(prompt, provider="anthropic"):
    """Run one g3 invocation and return its combined stdout+stderr.

    Args:
        prompt: The user prompt passed to g3 as the final argument.
        provider: LLM provider name forwarded via --provider.

    Returns:
        Combined stdout and stderr text so callers can grep both program
        output and debug log lines in a single string.
    """
    cmd = [
        "cargo", "run", "--release", "--",
        "--provider", provider,
        prompt
    ]
    # Debug logging is required so the token-accounting lines appear.
    env = {
        "RUST_LOG": "g3_providers=debug,g3_core=info",
        "RUST_BACKTRACE": "1"
    }
    result = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        # Merge with the real environment instead of replacing it. The
        # previous code reached os via `subprocess.os`, which is an
        # undocumented implementation detail of the subprocess module.
        env={**os.environ, **env},
    )
    return result.stdout + result.stderr
def extract_token_info(output):
    """Parse token-accounting debug lines out of a g3 log dump.

    Scans ``output`` for each known log pattern, keeps only the final
    occurrence of each, and returns a dict containing whichever fields
    were found (empty dict when nothing matched).
    """
    def last_groups(pattern):
        # Capture groups of the final match for `pattern`, or None.
        hits = re.findall(pattern, output)
        return hits[-1] if hits else None

    token_info = {}

    # Look for token usage updates
    groups = last_groups(
        r"Updated token usage.*was: (\d+), now: (\d+).*prompt=(\d+), completion=(\d+), total=(\d+)"
    )
    if groups:
        for key, value in zip(('was', 'now', 'prompt', 'completion', 'total'), groups):
            token_info[key] = int(value)

    # Look for context percentage
    groups = last_groups(r"Context usage at (\d+)%.*\((\d+)/(\d+) tokens\)")
    if groups:
        for key, value in zip(('percentage', 'used', 'total_context'), groups):
            token_info[key] = int(value)

    # Look for thinning triggers
    groups = last_groups(r"Context thinning triggered.*usage: (\d+)%.*\((\d+)/(\d+) tokens\)")
    if groups:
        token_info['thinning_triggered'] = True
        token_info['thinning_percentage'] = int(groups[0])

    # Look for final usage from Anthropic
    groups = last_groups(
        r"Anthropic stream completed with final usage.*prompt: (\d+), completion: (\d+), total: (\d+)"
    )
    if groups:
        for key, value in zip(('final_prompt', 'final_completion', 'final_total'), groups):
            token_info[key] = int(value)

    return token_info
def main():
    """Drive three manual checks of Anthropic token accounting.

    Builds the project, runs g3 with prompts of increasing size, and prints
    the parsed token counts plus heuristic warnings. Output is intended for
    human review rather than automated pass/fail.
    """
    print("Testing Anthropic Provider Token Accounting")
    print("="*50)
    # Build the project first
    print("Building project...")
    subprocess.run(["cargo", "build", "--release"], capture_output=True)
    # Test 1: Simple prompt
    print("\nTest 1: Simple prompt")
    print("-"*30)
    output = run_g3_command("Say 'Hello, World!' and nothing else.")
    tokens = extract_token_info(output)
    if tokens:
        print(f"Token usage: {tokens.get('now', 'N/A')} tokens")
        print(f"  Prompt tokens: {tokens.get('prompt', 'N/A')}")
        print(f"  Completion tokens: {tokens.get('completion', 'N/A')}")
        print(f"  Total from provider: {tokens.get('total', 'N/A')}")
        if 'final_total' in tokens:
            print(f"  Final total from stream: {tokens['final_total']}")
            # Tracked count should agree with the provider's final report.
            if tokens.get('now') != tokens['final_total']:
                print(f"  ⚠️  WARNING: Mismatch between tracked ({tokens.get('now')}) and final ({tokens['final_total']})")
        # Check if the completion tokens are reasonable (should be small for "Hello, World!")
        if tokens.get('completion', 0) > 50:
            print(f"  ⚠️  WARNING: Completion tokens seem high for a simple response: {tokens.get('completion')}")
    else:
        print("  ❌ No token information found in output")
    # Test 2: Longer response
    print("\nTest 2: Longer response")
    print("-"*30)
    output = run_g3_command("Write a 3-paragraph essay about the importance of accurate token counting in LLM applications.")
    tokens = extract_token_info(output)
    if tokens:
        print(f"Token usage: {tokens.get('now', 'N/A')} tokens")
        print(f"  Prompt tokens: {tokens.get('prompt', 'N/A')}")
        print(f"  Completion tokens: {tokens.get('completion', 'N/A')}")
        print(f"  Total from provider: {tokens.get('total', 'N/A')}")
        if 'final_total' in tokens:
            print(f"  Final total from stream: {tokens['final_total']}")
            if tokens.get('now') != tokens['final_total']:
                print(f"  ⚠️  WARNING: Mismatch between tracked ({tokens.get('now')}) and final ({tokens['final_total']})")
        # Check if completion tokens are reasonable for a longer response
        if tokens.get('completion', 0) < 100:
            print(f"  ⚠️  WARNING: Completion tokens seem low for a 3-paragraph essay: {tokens.get('completion')}")
    else:
        print("  ❌ No token information found in output")
    # Test 3: Check for proper accumulation
    print("\nTest 3: Token accumulation (multiple messages)")
    print("-"*30)
    # First message
    output1 = run_g3_command("Count from 1 to 5.")
    tokens1 = extract_token_info(output1)
    # Second message (this would need to be in a conversation, but for now we test separately)
    output2 = run_g3_command("Now count from 6 to 10.")
    tokens2 = extract_token_info(output2)
    if tokens1 and tokens2:
        print(f"First message: {tokens1.get('now', 'N/A')} tokens")
        print(f"Second message: {tokens2.get('now', 'N/A')} tokens")
        # In a real conversation, tokens2['now'] should be greater than tokens1['now']
        # But since these are separate invocations, we just check they're both reasonable
        if tokens1.get('now', 0) > 0 and tokens2.get('now', 0) > 0:
            print("  ✅ Both messages have token counts")
        else:
            print("  ❌ Missing token counts")
    print("\n" + "="*50)
    print("Test Summary:")
    print("Check the output above for any warnings or errors.")
    print("Key things to verify:")
    print("  1. Token counts are being captured from the provider")
    print("  2. Completion tokens are reasonable for the response length")
    print("  3. No mismatch between tracked and final token counts")
    print("  4. Context thinning triggers at appropriate thresholds")
if __name__ == "__main__":
    main()

View File

@@ -1,46 +0,0 @@
#!/bin/bash
# Test script to verify token accounting with Anthropic provider
#
# Builds g3, sends one prompt through the Anthropic provider with debug
# logging enabled, then greps the captured log for token-accounting
# messages. Output is saved to /tmp/token_test.log for manual review.
echo "Testing token accounting with Anthropic provider..."
echo "This test will send a few messages and check if token counts are properly tracked."
echo ""
# Set up environment for testing
# Debug logging on g3_providers is required for the usage log lines.
export RUST_LOG=g3_providers=debug,g3_core=info
export RUST_BACKTRACE=1
# Build the project first
echo "Building project..."
# Keep build output quiet except for progress lines; `|| true` so a
# non-matching grep does not abort the script.
cargo build --release 2>&1 | grep -E "(Compiling|Finished)" || true
echo ""
echo "Running test with Anthropic provider..."
echo "Watch for these log messages:"
echo "  - 'Captured initial usage from message_start'"
echo "  - 'Updated usage from message_delta' (if available)"
echo "  - 'Updated with final usage from message_stop' (if available)"
echo "  - 'Anthropic stream completed with final usage'"
echo "  - 'Updated token usage from provider'"
echo "  - 'Context thinning triggered' (when reaching thresholds)"
echo ""
# Create a simple test that will generate some tokens
cat << 'EOF' > /tmp/test_prompt.txt
Please write a short paragraph about the importance of accurate token counting in LLM applications. Then list 3 reasons why token accounting might fail.
EOF
# Run the test
echo "Sending test prompt..."
# tee preserves the full log for the analysis step below.
cargo run --release -- --provider anthropic "$(cat /tmp/test_prompt.txt)" 2>&1 | tee /tmp/token_test.log
echo ""
echo "Analyzing results..."
echo ""
# Check for token accounting messages
echo "Token accounting messages found:"
grep -E "(usage from|token usage|Context thinning|Context usage)" /tmp/token_test.log | head -20
echo ""
echo "Test complete. Check /tmp/token_test.log for full output."