Compare commits
44 Commits
micn/auton
...
micn/ollam
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
79b375519b | ||
|
|
88c3cc23fe | ||
|
|
4622507f37 | ||
|
|
217df2f2af | ||
|
|
22a0090cdc | ||
|
|
631f3c16ca | ||
|
|
1f9fef5f18 | ||
|
|
57d473c19d | ||
|
|
e59ce2f93f | ||
|
|
a1ad94ed75 | ||
|
|
982c0bbfb3 | ||
|
|
ad9ba5e5d8 | ||
|
|
f89bbfc89a | ||
|
|
11eb01e04d | ||
|
|
bdaacfd051 | ||
|
|
92ae776510 | ||
|
|
c42e0bce54 | ||
|
|
b529d7f814 | ||
|
|
9752e81489 | ||
|
|
63c2aff7ba | ||
|
|
aa4a0267ea | ||
|
|
6cfa1e225c | ||
|
|
f53cd8e8f3 | ||
|
|
45bffc40da | ||
|
|
4bf0f71bbd | ||
|
|
c1ce3038d8 | ||
|
|
4b1694b308 | ||
|
|
5e08d6bbba | ||
|
|
c3f3f79dc5 | ||
|
|
834153ea69 | ||
|
|
65f25f840e | ||
|
|
a8af5d7cc1 | ||
|
|
61d748034d | ||
|
|
d0ac222e2e | ||
|
|
e1e732150a | ||
|
|
0be4829ca9 | ||
|
|
efd4eca755 | ||
|
|
3ec65e38ee | ||
|
|
c5d6fbef08 | ||
|
|
f93844d378 | ||
|
|
af6d37a8e2 | ||
|
|
c1c6680e03 | ||
|
|
f2d8e744bb | ||
|
|
010a43d203 |
5
.cargo/config.toml
Normal file
5
.cargo/config.toml
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
[target.aarch64-apple-darwin]
|
||||||
|
rustflags = ["-C", "link-args=-Wl,-rpath,@executable_path"]
|
||||||
|
|
||||||
|
[target.x86_64-apple-darwin]
|
||||||
|
rustflags = ["-C", "link-args=-Wl,-rpath,@executable_path"]
|
||||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -2,6 +2,8 @@
|
|||||||
# will have compiled files and executables
|
# will have compiled files and executables
|
||||||
debug
|
debug
|
||||||
target
|
target
|
||||||
|
.build
|
||||||
|
appy/
|
||||||
|
|
||||||
# These are backup files generated by rustfmt
|
# These are backup files generated by rustfmt
|
||||||
**/*.rs.bk
|
**/*.rs.bk
|
||||||
|
|||||||
334
Cargo.lock
generated
334
Cargo.lock
generated
@@ -2,6 +2,28 @@
|
|||||||
# It is not intended for manual editing.
|
# It is not intended for manual editing.
|
||||||
version = 4
|
version = 4
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "accessibility"
|
||||||
|
version = "0.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1ac9f33ffc1ef16eddb2451c03c983e56a5182ac760c3f2733da55ba8f48eac4"
|
||||||
|
dependencies = [
|
||||||
|
"accessibility-sys",
|
||||||
|
"cocoa 0.26.1",
|
||||||
|
"core-foundation 0.10.1",
|
||||||
|
"objc",
|
||||||
|
"thiserror 1.0.69",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "accessibility-sys"
|
||||||
|
version = "0.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "46a6a8e90a1d8b96a48249e7c8f5b4058447bea8847280db7bfccb6dcab6b8e1"
|
||||||
|
dependencies = [
|
||||||
|
"core-foundation-sys",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "adler2"
|
name = "adler2"
|
||||||
version = "2.0.1"
|
version = "2.0.1"
|
||||||
@@ -114,7 +136,7 @@ checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -196,28 +218,6 @@ version = "0.22.1"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
|
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "bindgen"
|
|
||||||
version = "0.64.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "c4243e6031260db77ede97ad86c27e501d646a27ab57b59a574f725d98ab1fb4"
|
|
||||||
dependencies = [
|
|
||||||
"bitflags 1.3.2",
|
|
||||||
"cexpr",
|
|
||||||
"clang-sys",
|
|
||||||
"lazy_static",
|
|
||||||
"lazycell",
|
|
||||||
"log",
|
|
||||||
"peeking_take_while",
|
|
||||||
"proc-macro2",
|
|
||||||
"quote",
|
|
||||||
"regex",
|
|
||||||
"rustc-hash",
|
|
||||||
"shlex",
|
|
||||||
"syn 1.0.109",
|
|
||||||
"which",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bindgen"
|
name = "bindgen"
|
||||||
version = "0.69.5"
|
version = "0.69.5"
|
||||||
@@ -237,7 +237,7 @@ dependencies = [
|
|||||||
"regex",
|
"regex",
|
||||||
"rustc-hash",
|
"rustc-hash",
|
||||||
"shlex",
|
"shlex",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
"which",
|
"which",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -318,9 +318,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cc"
|
name = "cc"
|
||||||
version = "1.2.41"
|
version = "1.2.43"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ac9fe6cdbb24b6ade63616c0a0688e45bb56732262c158df3c0c4bea4ca47cb7"
|
checksum = "739eb0f94557554b3ca9a86d2d37bebd49c5e6d0c1d2bda35ba5bdac830befc2"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"find-msvc-tools",
|
"find-msvc-tools",
|
||||||
"jobserver",
|
"jobserver",
|
||||||
@@ -411,7 +411,7 @@ dependencies = [
|
|||||||
"heck",
|
"heck",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -437,9 +437,25 @@ checksum = "f6140449f97a6e97f9511815c5632d84c8aacf8ac271ad77c559218161a1373c"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"bitflags 1.3.2",
|
"bitflags 1.3.2",
|
||||||
"block",
|
"block",
|
||||||
"cocoa-foundation",
|
"cocoa-foundation 0.1.2",
|
||||||
"core-foundation 0.9.4",
|
"core-foundation 0.9.4",
|
||||||
"core-graphics",
|
"core-graphics 0.23.2",
|
||||||
|
"foreign-types 0.5.0",
|
||||||
|
"libc",
|
||||||
|
"objc",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cocoa"
|
||||||
|
version = "0.26.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ad36507aeb7e16159dfe68db81ccc27571c3ccd4b76fb2fb72fc59e7a4b1b64c"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 2.10.0",
|
||||||
|
"block",
|
||||||
|
"cocoa-foundation 0.2.1",
|
||||||
|
"core-foundation 0.10.1",
|
||||||
|
"core-graphics 0.24.0",
|
||||||
"foreign-types 0.5.0",
|
"foreign-types 0.5.0",
|
||||||
"libc",
|
"libc",
|
||||||
"objc",
|
"objc",
|
||||||
@@ -454,11 +470,24 @@ dependencies = [
|
|||||||
"bitflags 1.3.2",
|
"bitflags 1.3.2",
|
||||||
"block",
|
"block",
|
||||||
"core-foundation 0.9.4",
|
"core-foundation 0.9.4",
|
||||||
"core-graphics-types",
|
"core-graphics-types 0.1.3",
|
||||||
"libc",
|
"libc",
|
||||||
"objc",
|
"objc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cocoa-foundation"
|
||||||
|
version = "0.2.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "81411967c50ee9a1fc11365f8c585f863a22a9697c89239c452292c40ba79b0d"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 2.10.0",
|
||||||
|
"block",
|
||||||
|
"core-foundation 0.10.1",
|
||||||
|
"core-graphics-types 0.2.0",
|
||||||
|
"objc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "color_quant"
|
name = "color_quant"
|
||||||
version = "1.1.0"
|
version = "1.1.0"
|
||||||
@@ -635,7 +664,20 @@ checksum = "c07782be35f9e1140080c6b96f0d44b739e2278479f64e02fdab4e32dfd8b081"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"bitflags 1.3.2",
|
"bitflags 1.3.2",
|
||||||
"core-foundation 0.9.4",
|
"core-foundation 0.9.4",
|
||||||
"core-graphics-types",
|
"core-graphics-types 0.1.3",
|
||||||
|
"foreign-types 0.5.0",
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "core-graphics"
|
||||||
|
version = "0.24.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "fa95a34622365fa5bbf40b20b75dba8dfa8c94c734aea8ac9a5ca38af14316f1"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 2.10.0",
|
||||||
|
"core-foundation 0.10.1",
|
||||||
|
"core-graphics-types 0.2.0",
|
||||||
"foreign-types 0.5.0",
|
"foreign-types 0.5.0",
|
||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
@@ -651,6 +693,17 @@ dependencies = [
|
|||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "core-graphics-types"
|
||||||
|
version = "0.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3d44a101f213f6c4cdc1853d4b78aef6db6bdfa3468798cc1d9912f4735013eb"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 2.10.0",
|
||||||
|
"core-foundation 0.10.1",
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cpufeatures"
|
name = "cpufeatures"
|
||||||
version = "0.2.17"
|
version = "0.2.17"
|
||||||
@@ -692,7 +745,7 @@ dependencies = [
|
|||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"strict",
|
"strict",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -831,7 +884,7 @@ dependencies = [
|
|||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"strsim",
|
"strsim",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -842,14 +895,14 @@ checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"darling_core",
|
"darling_core",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "deranged"
|
name = "deranged"
|
||||||
version = "0.5.4"
|
version = "0.5.5"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a41953f86f8a05768a6cda24def994fd2f424b04ec5c719cf89989779f199071"
|
checksum = "ececcb659e7ba858fb4f10388c250a7252eb0a27373f1a72b8748afdd248e587"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"powerfmt",
|
"powerfmt",
|
||||||
]
|
]
|
||||||
@@ -864,7 +917,7 @@ dependencies = [
|
|||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"rustc_version",
|
"rustc_version",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -885,7 +938,7 @@ dependencies = [
|
|||||||
"convert_case 0.7.1",
|
"convert_case 0.7.1",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -948,7 +1001,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -962,9 +1015,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "document-features"
|
name = "document-features"
|
||||||
version = "0.2.11"
|
version = "0.2.12"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "95249b50c6c185bee49034bcb378a49dc2b5dff0be90ff6616d31d64febab05d"
|
checksum = "d4b8a88685455ed29a21542a33abd9cb6510b6b129abadabdcef0f4c55bc8f61"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"litrs",
|
"litrs",
|
||||||
]
|
]
|
||||||
@@ -1091,9 +1144,9 @@ checksum = "52051878f80a721bb68ebfbc930e07b65ba72f2da88968ea5c06fd6ca3d3a127"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "flate2"
|
name = "flate2"
|
||||||
version = "1.1.4"
|
version = "1.1.5"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "dc5a4e564e38c699f2880d3fda590bedc2e69f3f84cd48b457bd892ce61d0aa9"
|
checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"crc32fast",
|
"crc32fast",
|
||||||
"miniz_oxide",
|
"miniz_oxide",
|
||||||
@@ -1138,7 +1191,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1218,7 +1271,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1287,18 +1340,18 @@ dependencies = [
|
|||||||
name = "g3-computer-control"
|
name = "g3-computer-control"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"accessibility",
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"cocoa",
|
"cocoa 0.25.0",
|
||||||
"core-foundation 0.9.4",
|
"core-foundation 0.10.1",
|
||||||
"core-graphics",
|
"core-graphics 0.23.2",
|
||||||
"fantoccini",
|
"fantoccini",
|
||||||
"image",
|
"image",
|
||||||
"objc",
|
"objc",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"shellexpand",
|
"shellexpand",
|
||||||
"tesseract",
|
|
||||||
"thiserror 1.0.69",
|
"thiserror 1.0.69",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tracing",
|
"tracing",
|
||||||
@@ -1316,6 +1369,7 @@ dependencies = [
|
|||||||
"dirs 5.0.1",
|
"dirs 5.0.1",
|
||||||
"serde",
|
"serde",
|
||||||
"shellexpand",
|
"shellexpand",
|
||||||
|
"tempfile",
|
||||||
"thiserror 1.0.69",
|
"thiserror 1.0.69",
|
||||||
"toml",
|
"toml",
|
||||||
]
|
]
|
||||||
@@ -1337,6 +1391,7 @@ dependencies = [
|
|||||||
"reqwest",
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"serde_yaml",
|
||||||
"shellexpand",
|
"shellexpand",
|
||||||
"thiserror 1.0.69",
|
"thiserror 1.0.69",
|
||||||
"tokio",
|
"tokio",
|
||||||
@@ -1517,11 +1572,11 @@ checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "home"
|
name = "home"
|
||||||
version = "0.5.11"
|
version = "0.5.9"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
|
checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"windows-sys 0.59.0",
|
"windows-sys 0.52.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1868,9 +1923,12 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "indoc"
|
name = "indoc"
|
||||||
version = "2.0.6"
|
version = "2.0.7"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd"
|
checksum = "79cf5c93f93228cf8efb3ba362535fb11199ac548a09ce117c9b1adc3030d706"
|
||||||
|
dependencies = [
|
||||||
|
"rustversion",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "instability"
|
name = "instability"
|
||||||
@@ -1882,7 +1940,7 @@ dependencies = [
|
|||||||
"indoc",
|
"indoc",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1893,9 +1951,9 @@ checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "is_terminal_polyfill"
|
name = "is_terminal_polyfill"
|
||||||
version = "1.70.1"
|
version = "1.70.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
|
checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "itertools"
|
name = "itertools"
|
||||||
@@ -2003,7 +2061,7 @@ dependencies = [
|
|||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"regex",
|
"regex",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -2024,28 +2082,6 @@ version = "0.5.3"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7a79a3332a6609480d7d0c9eab957bca6b455b91bb84e66d19f5ff66294b85b8"
|
checksum = "7a79a3332a6609480d7d0c9eab957bca6b455b91bb84e66d19f5ff66294b85b8"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "leptonica-plumbing"
|
|
||||||
version = "1.4.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "cc7a74c43d6f090d39158d233f326f47cd8bba545217595c93662b4e31156f42"
|
|
||||||
dependencies = [
|
|
||||||
"leptonica-sys",
|
|
||||||
"libc",
|
|
||||||
"thiserror 1.0.69",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "leptonica-sys"
|
|
||||||
version = "0.4.9"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "da627c72b2499a8106f4dd33143843015e4a631f445d561f3481f7fba35b6151"
|
|
||||||
dependencies = [
|
|
||||||
"bindgen 0.64.0",
|
|
||||||
"pkg-config",
|
|
||||||
"vcpkg",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "libc"
|
name = "libc"
|
||||||
version = "0.2.177"
|
version = "0.2.177"
|
||||||
@@ -2101,9 +2137,9 @@ checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "litrs"
|
name = "litrs"
|
||||||
version = "0.4.2"
|
version = "1.0.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f5e54036fe321fd421e10d732f155734c4e4afd610dd556d9a82833ab3ee0bed"
|
checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "llama_cpp"
|
name = "llama_cpp"
|
||||||
@@ -2126,7 +2162,7 @@ version = "0.3.2"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "037a1881ada3592c6a922224d5177b4b4f452e6b2979eb97393b71989e48357f"
|
checksum = "037a1881ada3592c6a922224d5177b4b4f452e6b2979eb97393b71989e48357f"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bindgen 0.69.5",
|
"bindgen",
|
||||||
"cc",
|
"cc",
|
||||||
"link-cplusplus",
|
"link-cplusplus",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
@@ -2219,14 +2255,14 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mio"
|
name = "mio"
|
||||||
version = "1.0.4"
|
version = "1.1.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c"
|
checksum = "69d83b0086dc8ecf3ce9ae2874b2d1290252e2a30720bea58a5c6639b0092873"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
"log",
|
"log",
|
||||||
"wasi",
|
"wasi",
|
||||||
"windows-sys 0.59.0",
|
"windows-sys 0.61.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -2374,9 +2410,9 @@ checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "once_cell_polyfill"
|
name = "once_cell_polyfill"
|
||||||
version = "1.70.1"
|
version = "1.70.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad"
|
checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "openssl"
|
name = "openssl"
|
||||||
@@ -2401,7 +2437,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -2473,12 +2509,6 @@ version = "0.2.3"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3"
|
checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "peeking_take_while"
|
|
||||||
version = "0.1.2"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "percent-encoding"
|
name = "percent-encoding"
|
||||||
version = "2.3.2"
|
version = "2.3.2"
|
||||||
@@ -2515,7 +2545,7 @@ dependencies = [
|
|||||||
"pest_meta",
|
"pest_meta",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -2596,14 +2626,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b"
|
checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "proc-macro2"
|
name = "proc-macro2"
|
||||||
version = "1.0.101"
|
version = "1.0.103"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de"
|
checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"unicode-ident",
|
"unicode-ident",
|
||||||
]
|
]
|
||||||
@@ -3001,7 +3031,7 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -3049,6 +3079,19 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_yaml"
|
||||||
|
version = "0.9.34+deprecated"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47"
|
||||||
|
dependencies = [
|
||||||
|
"indexmap",
|
||||||
|
"itoa",
|
||||||
|
"ryu",
|
||||||
|
"serde",
|
||||||
|
"unsafe-libyaml",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "sha2"
|
name = "sha2"
|
||||||
version = "0.10.9"
|
version = "0.10.9"
|
||||||
@@ -3096,9 +3139,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "signal-hook-mio"
|
name = "signal-hook-mio"
|
||||||
version = "0.2.4"
|
version = "0.2.5"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd"
|
checksum = "b75a19a7a740b25bc7944bdee6172368f988763b744e3d4dfe753f6b4ece40cc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
"mio",
|
"mio",
|
||||||
@@ -3195,25 +3238,14 @@ dependencies = [
|
|||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"rustversion",
|
"rustversion",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "syn"
|
name = "syn"
|
||||||
version = "1.0.109"
|
version = "2.0.108"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
|
checksum = "da58917d35242480a05c2897064da0a80589a2a0476c9a3f2fdc83b53502e917"
|
||||||
dependencies = [
|
|
||||||
"proc-macro2",
|
|
||||||
"quote",
|
|
||||||
"unicode-ident",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "syn"
|
|
||||||
version = "2.0.107"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "2a26dbd934e5451d21ef060c018dae56fc073894c5a7896f882928a76e6d081b"
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
@@ -3240,7 +3272,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -3293,40 +3325,6 @@ dependencies = [
|
|||||||
"unicode-width 0.1.14",
|
"unicode-width 0.1.14",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "tesseract"
|
|
||||||
version = "0.14.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "2ee0c2c608b63817b095f7fded5c50add36a29e2be2b2fc4901357163329290a"
|
|
||||||
dependencies = [
|
|
||||||
"tesseract-plumbing",
|
|
||||||
"tesseract-sys",
|
|
||||||
"thiserror 1.0.69",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "tesseract-plumbing"
|
|
||||||
version = "0.10.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "3e496d3e29eba540a276975394b85dccb5fd344b3eefb743d9286c8150f766d5"
|
|
||||||
dependencies = [
|
|
||||||
"leptonica-plumbing",
|
|
||||||
"tesseract-sys",
|
|
||||||
"thiserror 1.0.69",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "tesseract-sys"
|
|
||||||
version = "0.5.15"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "bd33f6f216124cfaf0fa86c2c0cdf04da39b6257bd78c5e44fa4fa98c3a5857b"
|
|
||||||
dependencies = [
|
|
||||||
"bindgen 0.64.0",
|
|
||||||
"leptonica-sys",
|
|
||||||
"pkg-config",
|
|
||||||
"vcpkg",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "thiserror"
|
name = "thiserror"
|
||||||
version = "1.0.69"
|
version = "1.0.69"
|
||||||
@@ -3353,7 +3351,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -3364,7 +3362,7 @@ checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -3462,7 +3460,7 @@ checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -3588,7 +3586,7 @@ checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -3650,9 +3648,9 @@ checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode-ident"
|
name = "unicode-ident"
|
||||||
version = "1.0.19"
|
version = "1.0.20"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f63a545481291138910575129486daeaf8ac54aee4387fe7906919f7830c7d9d"
|
checksum = "462eeb75aeb73aea900253ce739c8e18a67423fadf006037cd3ff27e82748a06"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode-segmentation"
|
name = "unicode-segmentation"
|
||||||
@@ -3683,6 +3681,12 @@ version = "0.2.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd"
|
checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unsafe-libyaml"
|
||||||
|
version = "0.2.11"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "url"
|
name = "url"
|
||||||
version = "2.5.7"
|
version = "2.5.7"
|
||||||
@@ -3793,7 +3797,7 @@ dependencies = [
|
|||||||
"log",
|
"log",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
"wasm-bindgen-shared",
|
"wasm-bindgen-shared",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -3828,7 +3832,7 @@ checksum = "9f07d2f20d4da7b26400c9f4a0511e6e0345b040694e8a75bd41d578fa4421d7"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
"wasm-bindgen-backend",
|
"wasm-bindgen-backend",
|
||||||
"wasm-bindgen-shared",
|
"wasm-bindgen-shared",
|
||||||
]
|
]
|
||||||
@@ -4000,7 +4004,7 @@ checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -4011,7 +4015,7 @@ checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -4407,7 +4411,7 @@ checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
"synstructure",
|
"synstructure",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -4428,7 +4432,7 @@ checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -4448,7 +4452,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
"synstructure",
|
"synstructure",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -4482,7 +4486,7 @@ checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.107",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|||||||
456
OLLAMA_CONFIG.md
Normal file
456
OLLAMA_CONFIG.md
Normal file
@@ -0,0 +1,456 @@
|
|||||||
|
# Configuring Ollama Provider in G3
|
||||||
|
|
||||||
|
This guide shows you how to configure G3 to use Ollama as your LLM provider.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### 1. Install Ollama
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Visit https://ollama.ai to download and install
|
||||||
|
# Or use curl:
|
||||||
|
curl https://ollama.ai/install.sh | sh
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Pull a Model
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ollama pull llama3.2
|
||||||
|
# or any other model you prefer
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Create Configuration File
|
||||||
|
|
||||||
|
Copy the example configuration:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp config.ollama.example.toml ~/.config/g3/config.toml
|
||||||
|
```
|
||||||
|
|
||||||
|
Or create it manually:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama"
|
||||||
|
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Run G3
|
||||||
|
|
||||||
|
```bash
|
||||||
|
g3
|
||||||
|
# G3 will now use Ollama with llama3.2!
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration Options
|
||||||
|
|
||||||
|
### Basic Configuration
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama"
|
||||||
|
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2"
|
||||||
|
```
|
||||||
|
|
||||||
|
This is the minimal configuration needed. It uses all defaults:
|
||||||
|
- Base URL: `http://localhost:11434`
|
||||||
|
- Temperature: `0.7`
|
||||||
|
- Max tokens: Not limited (uses model default)
|
||||||
|
|
||||||
|
### Full Configuration
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama"
|
||||||
|
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2"
|
||||||
|
base_url = "http://localhost:11434"
|
||||||
|
max_tokens = 2048
|
||||||
|
temperature = 0.7
|
||||||
|
```
|
||||||
|
|
||||||
|
### Custom Ollama Host
|
||||||
|
|
||||||
|
If you're running Ollama on a different machine or port:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2"
|
||||||
|
base_url = "http://192.168.1.100:11434"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Different Models
|
||||||
|
|
||||||
|
You can use any Ollama model:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers.ollama]
|
||||||
|
model = "qwen2.5:7b" # Alibaba's Qwen model
|
||||||
|
```
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers.ollama]
|
||||||
|
model = "mistral" # Mistral AI
|
||||||
|
```
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.1:70b" # Larger Llama model
|
||||||
|
```
|
||||||
|
|
||||||
|
## Multiple Provider Configuration
|
||||||
|
|
||||||
|
You can configure multiple providers and switch between them:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama" # Default for most operations
|
||||||
|
|
||||||
|
# Ollama for local, fast responses
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2:3b"
|
||||||
|
temperature = 0.7
|
||||||
|
|
||||||
|
# Databricks for more complex tasks
|
||||||
|
[providers.databricks]
|
||||||
|
host = "https://your-workspace.cloud.databricks.com"
|
||||||
|
model = "databricks-claude-sonnet-4"
|
||||||
|
max_tokens = 4096
|
||||||
|
temperature = 0.1
|
||||||
|
use_oauth = true
|
||||||
|
```
|
||||||
|
|
||||||
|
Then switch providers with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
g3 --provider databricks
|
||||||
|
```
|
||||||
|
|
||||||
|
## Autonomous Mode (Coach-Player)
|
||||||
|
|
||||||
|
Use different providers for code review (coach) and implementation (player):
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama"
|
||||||
|
coach = "databricks" # Use powerful cloud model for review
|
||||||
|
player = "ollama" # Use local model for implementation
|
||||||
|
|
||||||
|
[providers.ollama]
|
||||||
|
model = "qwen2.5:14b" # Larger local model for coding
|
||||||
|
|
||||||
|
[providers.databricks]
|
||||||
|
host = "https://your-workspace.cloud.databricks.com"
|
||||||
|
model = "databricks-claude-sonnet-4"
|
||||||
|
use_oauth = true
|
||||||
|
```
|
||||||
|
|
||||||
|
This gives you the best of both worlds:
|
||||||
|
- Fast local execution for coding tasks
|
||||||
|
- Powerful cloud review for quality assurance
|
||||||
|
|
||||||
|
## Recommended Models
|
||||||
|
|
||||||
|
### For Coding Tasks
|
||||||
|
|
||||||
|
| Model | Size | Speed | Quality | Notes |
|
||||||
|
|-------|------|-------|---------|-------|
|
||||||
|
| **qwen2.5:7b** | 7B | Fast | Excellent | Best balance for coding |
|
||||||
|
| **llama3.2:3b** | 3B | Very Fast | Good | Great for quick tasks |
|
||||||
|
| **llama3.1:8b** | 8B | Medium | Very Good | Solid all-rounder |
|
||||||
|
| **mistral** | 7B | Fast | Good | Good for general use |
|
||||||
|
|
||||||
|
### For Complex Tasks
|
||||||
|
|
||||||
|
| Model | Size | Speed | Quality | Notes |
|
||||||
|
|-------|------|-------|---------|-------|
|
||||||
|
| **qwen2.5:14b** | 14B | Medium | Excellent | Best local model for coding |
|
||||||
|
| **qwen2.5:32b** | 32B | Slow | Outstanding | If you have the resources |
|
||||||
|
| **llama3.1:70b** | 70B | Very Slow | Outstanding | Requires significant RAM/GPU |
|
||||||
|
|
||||||
|
## Temperature Settings
|
||||||
|
|
||||||
|
Temperature controls randomness in responses:
|
||||||
|
|
||||||
|
- **0.1-0.3**: Deterministic, good for code generation
|
||||||
|
- **0.5-0.7**: Balanced, good for most tasks
|
||||||
|
- **0.8-1.0**: Creative, good for brainstorming
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers.ollama]
|
||||||
|
model = "qwen2.5:7b"
|
||||||
|
temperature = 0.2 # Focused code generation
|
||||||
|
```
|
||||||
|
|
||||||
|
## Max Tokens
|
||||||
|
|
||||||
|
Control response length:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2"
|
||||||
|
max_tokens = 1024 # Shorter responses
|
||||||
|
```
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers.ollama]
|
||||||
|
model = "qwen2.5:7b"
|
||||||
|
max_tokens = 4096 # Longer, detailed responses
|
||||||
|
```
|
||||||
|
|
||||||
|
Leave it unset for model defaults (recommended).
|
||||||
|
|
||||||
|
## Performance Tuning
|
||||||
|
|
||||||
|
### GPU Acceleration
|
||||||
|
|
||||||
|
Ollama automatically uses GPU if available. To check:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ollama ps
|
||||||
|
```
|
||||||
|
|
||||||
|
### Quantized Models
|
||||||
|
|
||||||
|
For faster responses with less RAM:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2:3b-q4_0" # 4-bit quantization
|
||||||
|
```
|
||||||
|
|
||||||
|
Quantization options:
|
||||||
|
- `q4_0`: 4-bit, fastest, lowest quality
|
||||||
|
- `q5_0`: 5-bit, balanced
|
||||||
|
- `q8_0`: 8-bit, slower, better quality
|
||||||
|
|
||||||
|
### Multiple Models
|
||||||
|
|
||||||
|
You can pull multiple models and switch easily:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ollama pull llama3.2:3b # Fast for chat
|
||||||
|
ollama pull qwen2.5:7b # Better for code
|
||||||
|
ollama pull mistral # General purpose
|
||||||
|
```
|
||||||
|
|
||||||
|
Then change your config:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers.ollama]
|
||||||
|
model = "qwen2.5:7b" # Just change this line
|
||||||
|
```
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Ollama Not Running
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Check if Ollama is running
|
||||||
|
curl http://localhost:11434/api/version
|
||||||
|
|
||||||
|
# Start Ollama (macOS/Linux)
|
||||||
|
ollama serve
|
||||||
|
|
||||||
|
# Or just run a model (auto-starts)
|
||||||
|
ollama run llama3.2
|
||||||
|
```
|
||||||
|
|
||||||
|
### Model Not Found
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# List available models
|
||||||
|
ollama list
|
||||||
|
|
||||||
|
# Pull the model
|
||||||
|
ollama pull llama3.2
|
||||||
|
```
|
||||||
|
|
||||||
|
### Slow Responses
|
||||||
|
|
||||||
|
1. Use a smaller model:
|
||||||
|
```toml
|
||||||
|
model = "llama3.2:1b" # Smallest, fastest
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Use quantized version:
|
||||||
|
```toml
|
||||||
|
model = "llama3.2:3b-q4_0"
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Reduce max_tokens:
|
||||||
|
```toml
|
||||||
|
max_tokens = 512
|
||||||
|
```
|
||||||
|
|
||||||
|
### Out of Memory
|
||||||
|
|
||||||
|
1. Switch to smaller model
|
||||||
|
2. Use quantized version
|
||||||
|
3. Close other applications
|
||||||
|
4. Check GPU memory: `ollama ps`
|
||||||
|
|
||||||
|
### Connection Refused
|
||||||
|
|
||||||
|
Check base_url is correct:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2"
|
||||||
|
base_url = "http://localhost:11434" # Default
|
||||||
|
```
|
||||||
|
|
||||||
|
For remote Ollama:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
base_url = "http://your-server:11434"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Complete Example Configs
|
||||||
|
|
||||||
|
### Minimal Local Setup
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama"
|
||||||
|
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2"
|
||||||
|
|
||||||
|
[agent]
|
||||||
|
max_context_length = 8192
|
||||||
|
enable_streaming = true
|
||||||
|
timeout_seconds = 60
|
||||||
|
```
|
||||||
|
|
||||||
|
### Optimized for Coding
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama"
|
||||||
|
|
||||||
|
[providers.ollama]
|
||||||
|
model = "qwen2.5:7b"
|
||||||
|
temperature = 0.2
|
||||||
|
max_tokens = 2048
|
||||||
|
|
||||||
|
[agent]
|
||||||
|
max_context_length = 16384
|
||||||
|
enable_streaming = true
|
||||||
|
timeout_seconds = 120
|
||||||
|
```
|
||||||
|
|
||||||
|
### Fast Responses
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama"
|
||||||
|
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2:3b-q4_0"
|
||||||
|
temperature = 0.7
|
||||||
|
max_tokens = 1024
|
||||||
|
|
||||||
|
[agent]
|
||||||
|
max_context_length = 4096
|
||||||
|
enable_streaming = true
|
||||||
|
timeout_seconds = 30
|
||||||
|
```
|
||||||
|
|
||||||
|
### High Quality (Requires Good Hardware)
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama"
|
||||||
|
|
||||||
|
[providers.ollama]
|
||||||
|
model = "qwen2.5:32b"
|
||||||
|
temperature = 0.3
|
||||||
|
max_tokens = 4096
|
||||||
|
|
||||||
|
[agent]
|
||||||
|
max_context_length = 32768
|
||||||
|
enable_streaming = true
|
||||||
|
timeout_seconds = 300
|
||||||
|
```
|
||||||
|
|
||||||
|
### Hybrid (Local + Cloud)
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama"
|
||||||
|
coach = "databricks"
|
||||||
|
player = "ollama"
|
||||||
|
|
||||||
|
[providers.ollama]
|
||||||
|
model = "qwen2.5:14b"
|
||||||
|
temperature = 0.2
|
||||||
|
|
||||||
|
[providers.databricks]
|
||||||
|
host = "https://your-workspace.cloud.databricks.com"
|
||||||
|
model = "databricks-claude-sonnet-4"
|
||||||
|
use_oauth = true
|
||||||
|
|
||||||
|
[agent]
|
||||||
|
max_context_length = 16384
|
||||||
|
enable_streaming = true
|
||||||
|
timeout_seconds = 120
|
||||||
|
```
|
||||||
|
|
||||||
|
## Environment Variables
|
||||||
|
|
||||||
|
You can override config with environment variables:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Override model
|
||||||
|
G3_PROVIDERS_OLLAMA_MODEL=qwen2.5:7b g3
|
||||||
|
|
||||||
|
# Override base URL
|
||||||
|
G3_PROVIDERS_OLLAMA_BASE_URL=http://192.168.1.100:11434 g3
|
||||||
|
|
||||||
|
# Override default provider
|
||||||
|
G3_PROVIDERS_DEFAULT_PROVIDER=ollama g3
|
||||||
|
```
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. **Start Small**: Begin with llama3.2:3b, scale up if needed
|
||||||
|
2. **Use Quantization**: q4_0 or q5_0 for best speed/quality balance
|
||||||
|
3. **Match Task to Model**:
|
||||||
|
- Quick edits: 1B-3B models
|
||||||
|
- Code generation: 7B-14B models
|
||||||
|
- Complex refactoring: 14B-32B models
|
||||||
|
4. **Temperature for Code**: Use 0.1-0.3 for deterministic output
|
||||||
|
5. **Enable Streaming**: Always enable for better UX
|
||||||
|
6. **Local First**: Use Ollama by default, cloud for special cases
|
||||||
|
|
||||||
|
## Comparison with Other Providers
|
||||||
|
|
||||||
|
| Feature | Ollama | Databricks | OpenAI | Anthropic |
|
||||||
|
|---------|--------|------------|--------|-----------|
|
||||||
|
| Cost | Free | Paid | Paid | Paid |
|
||||||
|
| Privacy | Full | Medium | Low | Low |
|
||||||
|
| Speed (small models) | Fast | Fast | Medium | Medium |
|
||||||
|
| Speed (large models) | Slow | Fast | Fast | Fast |
|
||||||
|
| Setup Complexity | Low | Medium | Low | Low |
|
||||||
|
| Authentication | None | OAuth/Token | API Key | API Key |
|
||||||
|
| Offline Support | Yes | No | No | No |
|
||||||
|
| Tool Calling | Yes | Yes | Yes | Yes |
|
||||||
|
|
||||||
|
## Next Steps
|
||||||
|
|
||||||
|
1. Try different models: `ollama pull mistral`, `ollama pull qwen2.5`
|
||||||
|
2. Experiment with temperature settings
|
||||||
|
3. Set up hybrid config with cloud provider for complex tasks
|
||||||
|
4. Share your config in the community!
|
||||||
|
|
||||||
|
## Getting Help
|
||||||
|
|
||||||
|
- Ollama docs: https://ollama.ai/docs
|
||||||
|
- G3 issues: https://github.com/your-repo/issues
|
||||||
|
- Test your config: `g3 --help`
|
||||||
315
OLLAMA_EXAMPLE.md
Normal file
315
OLLAMA_EXAMPLE.md
Normal file
@@ -0,0 +1,315 @@
|
|||||||
|
# Ollama Provider for g3
|
||||||
|
|
||||||
|
A simple, local LLM provider implementation for g3 that connects to Ollama.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- ✅ **Simple Setup**: No API keys or authentication required
|
||||||
|
- ✅ **Local Execution**: Runs entirely on your machine
|
||||||
|
- ✅ **Tool Calling Support**: Native tool calling for compatible models
|
||||||
|
- ✅ **Streaming**: Full streaming support with real-time responses
|
||||||
|
- ✅ **Flexible Configuration**: Custom base URL, temperature, and max tokens
|
||||||
|
- ✅ **Model Discovery**: Automatic detection of available models
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
1. Install and start Ollama: https://ollama.ai
|
||||||
|
2. Pull a model: `ollama pull llama3.2`
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use g3_providers::{OllamaProvider, LLMProvider, CompletionRequest, Message, MessageRole};
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> anyhow::Result<()> {
|
||||||
|
// Create provider with default settings (localhost:11434)
|
||||||
|
let provider = OllamaProvider::new(
|
||||||
|
"llama3.2".to_string(),
|
||||||
|
None, // base_url: defaults to http://localhost:11434
|
||||||
|
None, // max_tokens: optional
|
||||||
|
None, // temperature: defaults to 0.7
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Create a simple request
|
||||||
|
let request = CompletionRequest {
|
||||||
|
messages: vec![
|
||||||
|
Message {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: "What is the capital of France?".to_string(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
max_tokens: Some(1000),
|
||||||
|
temperature: Some(0.7),
|
||||||
|
stream: false,
|
||||||
|
tools: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get completion
|
||||||
|
let response = provider.complete(request).await?;
|
||||||
|
println!("Response: {}", response.content);
|
||||||
|
println!("Tokens: {}", response.usage.total_tokens);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Streaming Example
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
|
||||||
|
let request = CompletionRequest {
|
||||||
|
messages: vec![
|
||||||
|
Message {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: "Write a short poem about coding".to_string(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
max_tokens: Some(500),
|
||||||
|
temperature: Some(0.8),
|
||||||
|
stream: true,
|
||||||
|
tools: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut stream = provider.stream(request).await?;
|
||||||
|
|
||||||
|
while let Some(chunk_result) = stream.next().await {
|
||||||
|
match chunk_result {
|
||||||
|
Ok(chunk) => {
|
||||||
|
print!("{}", chunk.content);
|
||||||
|
if chunk.finished {
|
||||||
|
println!("\n\nDone!");
|
||||||
|
if let Some(usage) = chunk.usage {
|
||||||
|
println!("Total tokens: {}", usage.total_tokens);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => eprintln!("Error: {}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Tool Calling Example
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
let tools = vec![Tool {
|
||||||
|
name: "get_weather".to_string(),
|
||||||
|
description: "Get current weather for a location".to_string(),
|
||||||
|
input_schema: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "City name"
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"description": "Temperature unit"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location"]
|
||||||
|
}),
|
||||||
|
}];
|
||||||
|
|
||||||
|
let request = CompletionRequest {
|
||||||
|
messages: vec![
|
||||||
|
Message {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: "What's the weather in Paris?".to_string(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
max_tokens: Some(500),
|
||||||
|
temperature: Some(0.5),
|
||||||
|
stream: false,
|
||||||
|
tools: Some(tools),
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = provider.complete(request).await?;
|
||||||
|
println!("Response: {}", response.content);
|
||||||
|
```
|
||||||
|
|
||||||
|
### Custom Ollama Host
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// Connect to remote Ollama instance
|
||||||
|
let provider = OllamaProvider::new(
|
||||||
|
"llama3.2".to_string(),
|
||||||
|
Some("http://192.168.1.100:11434".to_string()),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)?;
|
||||||
|
```
|
||||||
|
|
||||||
|
### Fetch Available Models
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// Discover what models are available
|
||||||
|
let models = provider.fetch_available_models().await?;
|
||||||
|
println!("Available models:");
|
||||||
|
for model in models {
|
||||||
|
println!(" - {}", model);
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Supported Models
|
||||||
|
|
||||||
|
The provider works with any Ollama model, including:
|
||||||
|
|
||||||
|
- **llama3.2** (1B, 3B) - Meta's latest Llama models
|
||||||
|
- **llama3.1** (8B, 70B, 405B) - Previous generation
|
||||||
|
- **qwen2.5** (7B, 14B, 32B) - Alibaba's Qwen models
|
||||||
|
- **mistral** - Mistral AI models
|
||||||
|
- **mixtral** - Mixture of experts model
|
||||||
|
- **phi3** - Microsoft's Phi-3
|
||||||
|
- **gemma2** - Google's Gemma 2
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
### Constructor Parameters
|
||||||
|
|
||||||
|
```rust
|
||||||
|
OllamaProvider::new(
|
||||||
|
model: String, // Model name (e.g., "llama3.2")
|
||||||
|
base_url: Option<String>, // Ollama API URL (default: http://localhost:11434)
|
||||||
|
max_tokens: Option<u32>, // Maximum tokens to generate (optional)
|
||||||
|
temperature: Option<f32>, // Sampling temperature (default: 0.7)
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Request Options
|
||||||
|
|
||||||
|
```rust
|
||||||
|
CompletionRequest {
|
||||||
|
messages: Vec<Message>, // Conversation history
|
||||||
|
max_tokens: Option<u32>, // Override provider's max_tokens
|
||||||
|
temperature: Option<f32>, // Override provider's temperature
|
||||||
|
stream: bool, // Enable streaming responses
|
||||||
|
tools: Option<Vec<Tool>>, // Tools for function calling
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Comparison with Other Providers
|
||||||
|
|
||||||
|
| Feature | Ollama | OpenAI | Anthropic | Databricks |
|
||||||
|
|---------|--------|--------|-----------|------------|
|
||||||
|
| Local Execution | ✅ | ❌ | ❌ | ❌ |
|
||||||
|
| Authentication | None | API Key | API Key | OAuth/Token |
|
||||||
|
| Tool Calling | ✅ | ✅ | ✅ | ✅ |
|
||||||
|
| Streaming | ✅ | ✅ | ✅ | ✅ |
|
||||||
|
| Cost | Free | Paid | Paid | Paid |
|
||||||
|
| Privacy | High | Low | Low | Medium |
|
||||||
|
|
||||||
|
## Implementation Details
|
||||||
|
|
||||||
|
### API Endpoints
|
||||||
|
|
||||||
|
- **Chat Completion**: `POST /api/chat`
|
||||||
|
- **Model List**: `GET /api/tags`
|
||||||
|
|
||||||
|
### Response Format
|
||||||
|
|
||||||
|
Ollama uses a simple JSON-per-line streaming format:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"message":{"role":"assistant","content":"Hello"},"done":false}
|
||||||
|
{"message":{"role":"assistant","content":" there"},"done":false}
|
||||||
|
{"done":true,"prompt_eval_count":10,"eval_count":20}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Tool Call Format
|
||||||
|
|
||||||
|
Tool calls are returned in the message structure:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": {"location": "Paris", "unit": "celsius"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"done": true
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Connection Errors
|
||||||
|
|
||||||
|
If you see connection errors, ensure Ollama is running:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Check if Ollama is running
|
||||||
|
curl http://localhost:11434/api/version
|
||||||
|
|
||||||
|
# Start Ollama (if needed)
|
||||||
|
ollama serve
|
||||||
|
```
|
||||||
|
|
||||||
|
### Model Not Found
|
||||||
|
|
||||||
|
Pull the model first:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ollama pull llama3.2
|
||||||
|
ollama list # Check available models
|
||||||
|
```
|
||||||
|
|
||||||
|
### Performance Issues
|
||||||
|
|
||||||
|
- Use smaller models (1B, 3B) for faster responses
|
||||||
|
- Reduce `max_tokens` to limit generation length
|
||||||
|
- Enable GPU acceleration if available
|
||||||
|
- Consider quantized models (e.g., `llama3.2:3b-q4_0`)
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
Run the included tests:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo test --package g3-providers ollama
|
||||||
|
```
|
||||||
|
|
||||||
|
All tests should pass:
|
||||||
|
```
|
||||||
|
running 4 tests
|
||||||
|
test ollama::tests::test_custom_base_url ... ok
|
||||||
|
test ollama::tests::test_message_conversion ... ok
|
||||||
|
test ollama::tests::test_provider_creation ... ok
|
||||||
|
test ollama::tests::test_tool_conversion ... ok
|
||||||
|
```
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
The provider follows the same architecture as other g3 providers:
|
||||||
|
|
||||||
|
1. **OllamaProvider**: Main struct implementing `LLMProvider` trait
|
||||||
|
2. **Request/Response Structures**: Internal types for Ollama API
|
||||||
|
3. **Streaming Parser**: Handles line-by-line JSON parsing
|
||||||
|
4. **Tool Call Handling**: Accumulates and converts tool calls
|
||||||
|
5. **Error Handling**: Robust error handling with retries
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
The provider is part of the g3-providers crate. To contribute:
|
||||||
|
|
||||||
|
1. Add features to `ollama.rs`
|
||||||
|
2. Update tests
|
||||||
|
3. Run `cargo test --package g3-providers`
|
||||||
|
4. Update this documentation
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
Same as the g3 project.
|
||||||
66
README.md
66
README.md
@@ -72,6 +72,16 @@ G3 includes robust error handling with automatic retry logic:
|
|||||||
- Conversation history preservation through summaries
|
- Conversation history preservation through summaries
|
||||||
- Dynamic token allocation for different providers (4k to 200k+ tokens)
|
- Dynamic token allocation for different providers (4k to 200k+ tokens)
|
||||||
|
|
||||||
|
### Interactive Control Commands
|
||||||
|
G3's interactive CLI includes control commands for manual context management:
|
||||||
|
- **`/compact`**: Manually trigger summarization to compact conversation history
|
||||||
|
- **`/thinnify`**: Manually trigger context thinning to replace large tool results with file references
|
||||||
|
- **`/readme`**: Reload README.md and AGENTS.md from disk without restarting
|
||||||
|
- **`/stats`**: Show detailed context and performance statistics
|
||||||
|
- **`/help`**: Display all available control commands
|
||||||
|
|
||||||
|
These commands give you fine-grained control over context management, allowing you to proactively optimize token usage and refresh project documentation. See [Control Commands Documentation](docs/CONTROL_COMMANDS.md) for detailed usage.
|
||||||
|
|
||||||
### Tool Ecosystem
|
### Tool Ecosystem
|
||||||
- **File Operations**: Read, write, and edit files with line-range precision
|
- **File Operations**: Read, write, and edit files with line-range precision
|
||||||
- **Shell Integration**: Execute system commands with output capture
|
- **Shell Integration**: Execute system commands with output capture
|
||||||
@@ -79,6 +89,7 @@ G3 includes robust error handling with automatic retry logic:
|
|||||||
- **TODO Management**: Read and write TODO lists with markdown checkbox format
|
- **TODO Management**: Read and write TODO lists with markdown checkbox format
|
||||||
- **Computer Control** (Experimental): Automate desktop applications
|
- **Computer Control** (Experimental): Automate desktop applications
|
||||||
- Mouse and keyboard control
|
- Mouse and keyboard control
|
||||||
|
- macOS Accessibility API for native app automation (via `--macax` flag)
|
||||||
- UI element inspection
|
- UI element inspection
|
||||||
- Screenshot capture and window management
|
- Screenshot capture and window management
|
||||||
- OCR text extraction from images and screen regions
|
- OCR text extraction from images and screen regions
|
||||||
@@ -121,12 +132,50 @@ G3 is designed for:
|
|||||||
|
|
||||||
## Getting Started
|
## Getting Started
|
||||||
|
|
||||||
|
### Default Mode: Accumulative Autonomous
|
||||||
|
|
||||||
|
The default interactive mode now uses **accumulative autonomous mode**, which combines the best of interactive and autonomous workflows:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Simply run g3 in any directory
|
||||||
|
g3
|
||||||
|
|
||||||
|
# You'll be prompted to describe what you want to build
|
||||||
|
# Each input you provide:
|
||||||
|
# 1. Gets added to accumulated requirements
|
||||||
|
# 2. Automatically triggers autonomous mode (coach-player loop)
|
||||||
|
# 3. Implements your requirements iteratively
|
||||||
|
|
||||||
|
# Example session:
|
||||||
|
requirement> create a simple web server in Python with Flask
|
||||||
|
# ... autonomous mode runs and implements it ...
|
||||||
|
requirement> add a /health endpoint that returns JSON
|
||||||
|
# ... autonomous mode runs again with both requirements ...
|
||||||
|
```
|
||||||
|
|
||||||
|
### Other Modes
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Single-shot mode (one task, then exit)
|
||||||
|
g3 "implement a function to calculate fibonacci numbers"
|
||||||
|
|
||||||
|
# Traditional autonomous mode (reads requirements.md)
|
||||||
|
g3 --autonomous
|
||||||
|
|
||||||
|
# Traditional chat mode (simple interactive chat without autonomous runs)
|
||||||
|
g3 --chat
|
||||||
|
```
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Build the project
|
# Build the project
|
||||||
cargo build --release
|
cargo build --release
|
||||||
|
|
||||||
# Run G3
|
# Run from the build directory
|
||||||
cargo run
|
./target/release/g3
|
||||||
|
|
||||||
|
# Or copy both files to somewhere in your PATH (macOS only needs both files)
|
||||||
|
cp target/release/g3 ~/.local/bin/
|
||||||
|
cp target/release/libVisionBridge.dylib ~/.local/bin/ # macOS only
|
||||||
|
|
||||||
# Execute a task
|
# Execute a task
|
||||||
g3 "implement a function to calculate fibonacci numbers"
|
g3 "implement a function to calculate fibonacci numbers"
|
||||||
@@ -156,6 +205,19 @@ safaridriver --enable # Requires password
|
|||||||
|
|
||||||
**Usage**: Run G3 with the `--webdriver` flag to enable browser automation tools.
|
**Usage**: Run G3 with the `--webdriver` flag to enable browser automation tools.
|
||||||
|
|
||||||
|
## macOS Accessibility API Tools
|
||||||
|
|
||||||
|
G3 includes support for controlling macOS applications via the Accessibility API, allowing you to automate native macOS apps.
|
||||||
|
|
||||||
|
**Available Tools**: `macax_list_apps`, `macax_get_frontmost_app`, `macax_activate_app`, `macax_get_ui_tree`, `macax_find_elements`, `macax_click`, `macax_set_value`, `macax_get_value`, `macax_press_key`
|
||||||
|
|
||||||
|
**Setup**: Enable with the `--macax` flag or in config with `macax.enabled = true`. Grant accessibility permissions:
|
||||||
|
- **macOS**: System Preferences → Security & Privacy → Privacy → Accessibility → Add your terminal app
|
||||||
|
|
||||||
|
**For detailed documentation**, see [macOS Accessibility Tools Guide](docs/macax-tools.md).
|
||||||
|
|
||||||
|
**Note**: This is particularly useful for testing and automating apps you're building with G3, as you can add accessibility identifiers to your UI elements.
|
||||||
|
|
||||||
## Computer Control (Experimental)
|
## Computer Control (Experimental)
|
||||||
|
|
||||||
G3 can interact with your computer's GUI for automation tasks:
|
G3 can interact with your computer's GUI for automation tasks:
|
||||||
|
|||||||
24
config.coach-player.example.toml
Normal file
24
config.coach-player.example.toml
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
[providers]
|
||||||
|
default_provider = "databricks"
|
||||||
|
# Specify different providers for coach and player in autonomous mode
|
||||||
|
coach = "databricks" # Provider for coach (code reviewer) - can be more powerful/expensive
|
||||||
|
player = "anthropic" # Provider for player (code implementer) - can be faster/cheaper
|
||||||
|
|
||||||
|
[providers.databricks]
|
||||||
|
host = "https://your-workspace.cloud.databricks.com"
|
||||||
|
# token = "your-databricks-token" # Optional - will use OAuth if not provided
|
||||||
|
model = "databricks-claude-sonnet-4"
|
||||||
|
max_tokens = 4096
|
||||||
|
temperature = 0.1
|
||||||
|
use_oauth = true
|
||||||
|
|
||||||
|
[providers.anthropic]
|
||||||
|
api_key = "your-anthropic-api-key"
|
||||||
|
model = "claude-3-haiku-20240307" # Using a faster model for player
|
||||||
|
max_tokens = 4096
|
||||||
|
temperature = 0.3 # Slightly higher temperature for more creative implementations
|
||||||
|
|
||||||
|
[agent]
|
||||||
|
max_context_length = 8192
|
||||||
|
enable_streaming = true
|
||||||
|
timeout_seconds = 60
|
||||||
@@ -1,5 +1,10 @@
|
|||||||
[providers]
|
[providers]
|
||||||
default_provider = "databricks"
|
default_provider = "databricks"
|
||||||
|
# Optional: Specify different providers for coach and player in autonomous mode
|
||||||
|
# If not specified, will use default_provider for both
|
||||||
|
# coach = "databricks" # Provider for coach (code reviewer)
|
||||||
|
# player = "anthropic" # Provider for player (code implementer)
|
||||||
|
# Note: Make sure the specified providers are configured below
|
||||||
|
|
||||||
[providers.databricks]
|
[providers.databricks]
|
||||||
host = "https://your-workspace.cloud.databricks.com"
|
host = "https://your-workspace.cloud.databricks.com"
|
||||||
|
|||||||
26
config.ollama.example.toml
Normal file
26
config.ollama.example.toml
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# Example G3 configuration using Ollama provider
|
||||||
|
# Copy this to ~/.config/g3/config.toml or ./g3.toml to use it
|
||||||
|
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama"
|
||||||
|
|
||||||
|
# Ollama configuration (local LLM)
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2" # or qwen2.5, mistral, etc.
|
||||||
|
# base_url = "http://localhost:11434" # Optional, defaults to localhost
|
||||||
|
# max_tokens = 2048 # Optional
|
||||||
|
# temperature = 0.7 # Optional
|
||||||
|
|
||||||
|
# Optional: Specify different providers for coach and player in autonomous mode
|
||||||
|
# coach = "ollama" # Provider for coach (code reviewer)
|
||||||
|
# player = "ollama" # Provider for player (code implementer)
|
||||||
|
|
||||||
|
[agent]
|
||||||
|
max_context_length = 8192
|
||||||
|
enable_streaming = true
|
||||||
|
timeout_seconds = 60
|
||||||
|
|
||||||
|
[computer_control]
|
||||||
|
enabled = false # Set to true to enable computer control (requires OS permissions)
|
||||||
|
require_confirmation = true
|
||||||
|
max_actions_per_second = 5
|
||||||
File diff suppressed because it is too large
Load Diff
94
crates/g3-cli/src/machine_ui_writer.rs
Normal file
94
crates/g3-cli/src/machine_ui_writer.rs
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
use g3_core::ui_writer::UiWriter;
|
||||||
|
use std::io::{self, Write};
|
||||||
|
|
||||||
|
/// Machine-mode implementation of UiWriter that prints plain, unformatted output
|
||||||
|
/// This is designed for programmatic consumption and outputs everything verbatim
|
||||||
|
pub struct MachineUiWriter;
|
||||||
|
|
||||||
|
impl MachineUiWriter {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UiWriter for MachineUiWriter {
|
||||||
|
fn print(&self, message: &str) {
|
||||||
|
print!("{}", message);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn println(&self, message: &str) {
|
||||||
|
println!("{}", message);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_inline(&self, message: &str) {
|
||||||
|
print!("{}", message);
|
||||||
|
let _ = io::stdout().flush();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_system_prompt(&self, prompt: &str) {
|
||||||
|
println!("SYSTEM_PROMPT:");
|
||||||
|
println!("{}", prompt);
|
||||||
|
println!("END_SYSTEM_PROMPT");
|
||||||
|
println!();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_context_status(&self, message: &str) {
|
||||||
|
println!("CONTEXT_STATUS: {}", message);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_context_thinning(&self, message: &str) {
|
||||||
|
println!("CONTEXT_THINNING: {}", message);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_tool_header(&self, tool_name: &str) {
|
||||||
|
println!("TOOL_CALL: {}", tool_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_tool_arg(&self, key: &str, value: &str) {
|
||||||
|
println!("TOOL_ARG: {} = {}", key, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_tool_output_header(&self) {
|
||||||
|
println!("TOOL_OUTPUT:");
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update_tool_output_line(&self, line: &str) {
|
||||||
|
println!("{}", line);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_tool_output_line(&self, line: &str) {
|
||||||
|
println!("{}", line);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_tool_output_summary(&self, count: usize) {
|
||||||
|
println!("TOOL_OUTPUT_LINES: {}", count);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_tool_timing(&self, duration_str: &str) {
|
||||||
|
println!("TOOL_DURATION: {}", duration_str);
|
||||||
|
println!("END_TOOL_OUTPUT");
|
||||||
|
println!();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_agent_prompt(&self) {
|
||||||
|
println!("AGENT_RESPONSE:");
|
||||||
|
let _ = io::stdout().flush();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_agent_response(&self, content: &str) {
|
||||||
|
print!("{}", content);
|
||||||
|
let _ = io::stdout().flush();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn notify_sse_received(&self) {
|
||||||
|
// No-op for machine mode
|
||||||
|
}
|
||||||
|
|
||||||
|
fn flush(&self) {
|
||||||
|
let _ = io::stdout().flush();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn wants_full_output(&self) -> bool {
|
||||||
|
true // Machine mode wants complete, untruncated output
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -267,23 +267,23 @@ impl TerminalState {
|
|||||||
let mut current_text = String::new();
|
let mut current_text = String::new();
|
||||||
|
|
||||||
// Check for headers first
|
// Check for headers first
|
||||||
if line.starts_with("### ") {
|
if let Some(stripped) = line.strip_prefix("### ") {
|
||||||
return Line::from(Span::styled(
|
return Line::from(Span::styled(
|
||||||
format!(" {}", &line[4..]),
|
format!(" {}", stripped),
|
||||||
Style::default()
|
Style::default()
|
||||||
.fg(self.theme.terminal_cyan.to_color())
|
.fg(self.theme.terminal_cyan.to_color())
|
||||||
.add_modifier(Modifier::BOLD | Modifier::UNDERLINED),
|
.add_modifier(Modifier::BOLD | Modifier::UNDERLINED),
|
||||||
));
|
));
|
||||||
} else if line.starts_with("## ") {
|
} else if let Some(stripped) = line.strip_prefix("## ") {
|
||||||
return Line::from(Span::styled(
|
return Line::from(Span::styled(
|
||||||
format!(" {}", &line[3..]),
|
format!(" {}", stripped),
|
||||||
Style::default()
|
Style::default()
|
||||||
.fg(self.theme.terminal_amber.to_color())
|
.fg(self.theme.terminal_amber.to_color())
|
||||||
.add_modifier(Modifier::BOLD),
|
.add_modifier(Modifier::BOLD),
|
||||||
));
|
));
|
||||||
} else if line.starts_with("# ") {
|
} else if let Some(stripped) = line.strip_prefix("# ") {
|
||||||
return Line::from(Span::styled(
|
return Line::from(Span::styled(
|
||||||
format!(" {}", &line[2..]),
|
format!(" {}", stripped),
|
||||||
Style::default()
|
Style::default()
|
||||||
.fg(self.theme.terminal_green.to_color())
|
.fg(self.theme.terminal_green.to_color())
|
||||||
.add_modifier(Modifier::BOLD),
|
.add_modifier(Modifier::BOLD),
|
||||||
@@ -343,7 +343,7 @@ impl TerminalState {
|
|||||||
}
|
}
|
||||||
// Find closing *
|
// Find closing *
|
||||||
let mut italic_text = String::new();
|
let mut italic_text = String::new();
|
||||||
while let Some(ch) = chars.next() {
|
for ch in chars.by_ref() {
|
||||||
if ch == '*' {
|
if ch == '*' {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -367,7 +367,7 @@ impl TerminalState {
|
|||||||
}
|
}
|
||||||
// Find closing `
|
// Find closing `
|
||||||
let mut code_text = String::new();
|
let mut code_text = String::new();
|
||||||
while let Some(ch) = chars.next() {
|
for ch in chars.by_ref() {
|
||||||
if ch == '`' {
|
if ch == '`' {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -612,11 +612,9 @@ impl RetroTui {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update status blink only if status is "PROCESSING"
|
// Update status blink only if status is "PROCESSING"
|
||||||
if state.status_line == "PROCESSING" {
|
if state.status_line == "PROCESSING" && state.last_status_blink.elapsed() > Duration::from_millis(500) {
|
||||||
if state.last_status_blink.elapsed() > Duration::from_millis(500) {
|
state.status_blink = !state.status_blink;
|
||||||
state.status_blink = !state.status_blink;
|
state.last_status_blink = Instant::now();
|
||||||
state.last_status_blink = Instant::now();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update activity area animation
|
// Update activity area animation
|
||||||
@@ -771,12 +769,7 @@ impl RetroTui {
|
|||||||
let total_cursor_pos = cursor_position;
|
let total_cursor_pos = cursor_position;
|
||||||
|
|
||||||
// Determine the window into the buffer we should show
|
// Determine the window into the buffer we should show
|
||||||
let window_start = if total_cursor_pos > available_width - 1 {
|
let window_start = total_cursor_pos.saturating_sub(available_width - 1);
|
||||||
// Cursor is beyond the visible area, scroll the view
|
|
||||||
total_cursor_pos - (available_width - 1)
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
};
|
|
||||||
|
|
||||||
// Get the visible portion of the buffer
|
// Get the visible portion of the buffer
|
||||||
let visible_buffer: String = input_buffer
|
let visible_buffer: String = input_buffer
|
||||||
@@ -1013,9 +1006,9 @@ impl RetroTui {
|
|||||||
let fade_color = |color: Color| -> Color {
|
let fade_color = |color: Color| -> Color {
|
||||||
match color {
|
match color {
|
||||||
Color::Rgb(r, g, b) => {
|
Color::Rgb(r, g, b) => {
|
||||||
let faded_r = ((r as f32 * opacity) as u8).max(0);
|
let faded_r = (r as f32 * opacity) as u8;
|
||||||
let faded_g = ((g as f32 * opacity) as u8).max(0);
|
let faded_g = (g as f32 * opacity) as u8;
|
||||||
let faded_b = ((b as f32 * opacity) as u8).max(0);
|
let faded_b = (b as f32 * opacity) as u8;
|
||||||
Color::Rgb(faded_r, faded_g, faded_b)
|
Color::Rgb(faded_r, faded_g, faded_b)
|
||||||
}
|
}
|
||||||
_ => color,
|
_ => color,
|
||||||
@@ -1098,9 +1091,9 @@ impl RetroTui {
|
|||||||
let fade_color = |color: Color| -> Color {
|
let fade_color = |color: Color| -> Color {
|
||||||
match color {
|
match color {
|
||||||
Color::Rgb(r, g, b) => {
|
Color::Rgb(r, g, b) => {
|
||||||
let faded_r = ((r as f32 * opacity) as u8).max(0);
|
let faded_r = (r as f32 * opacity) as u8;
|
||||||
let faded_g = ((g as f32 * opacity) as u8).max(0);
|
let faded_g = (g as f32 * opacity) as u8;
|
||||||
let faded_b = ((b as f32 * opacity) as u8).max(0);
|
let faded_b = (b as f32 * opacity) as u8;
|
||||||
Color::Rgb(faded_r, faded_g, faded_b)
|
Color::Rgb(faded_r, faded_g, faded_b)
|
||||||
}
|
}
|
||||||
_ => color,
|
_ => color,
|
||||||
@@ -1176,7 +1169,7 @@ impl RetroTui {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Wave characters for smooth animation
|
// Wave characters for smooth animation
|
||||||
let wave_chars = vec!['▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'];
|
let wave_chars = ['▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'];
|
||||||
|
|
||||||
// Build the wave line
|
// Build the wave line
|
||||||
let mut wave_line = String::new();
|
let mut wave_line = String::new();
|
||||||
@@ -1190,7 +1183,7 @@ impl RetroTui {
|
|||||||
let idx = wave_data.len().saturating_sub(display_width) + i;
|
let idx = wave_data.len().saturating_sub(display_width) + i;
|
||||||
|
|
||||||
if idx < wave_data.len() {
|
if idx < wave_data.len() {
|
||||||
let value = wave_data[idx].min(1.0).max(0.0);
|
let value = wave_data[idx].clamp(0.0, 1.0);
|
||||||
let char_idx = ((value * 7.0) as usize).min(7);
|
let char_idx = ((value * 7.0) as usize).min(7);
|
||||||
wave_line.push(wave_chars[char_idx]);
|
wave_line.push(wave_chars[char_idx]);
|
||||||
} else {
|
} else {
|
||||||
@@ -1206,8 +1199,6 @@ impl RetroTui {
|
|||||||
f.render_widget(wave_paragraph, area);
|
f.render_widget(wave_paragraph, area);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Draw the status bar
|
|
||||||
|
|
||||||
/// Draw the status bar
|
/// Draw the status bar
|
||||||
fn draw_status_bar(
|
fn draw_status_bar(
|
||||||
f: &mut Frame,
|
f: &mut Frame,
|
||||||
|
|||||||
32
crates/g3-cli/src/simple_output.rs
Normal file
32
crates/g3-cli/src/simple_output.rs
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
/// Simple output helper for printing messages
|
||||||
|
pub struct SimpleOutput {
|
||||||
|
machine_mode: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SimpleOutput {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
SimpleOutput { machine_mode: false }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_with_mode(machine_mode: bool) -> Self {
|
||||||
|
SimpleOutput { machine_mode }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn print(&self, message: &str) {
|
||||||
|
if !self.machine_mode {
|
||||||
|
println!("{}", message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn print_smart(&self, message: &str) {
|
||||||
|
if !self.machine_mode {
|
||||||
|
println!("{}", message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for SimpleOutput {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
use crossterm::style::Color;
|
use crossterm::style::Color;
|
||||||
use crossterm::style::{SetForegroundColor, ResetColor};
|
use crossterm::style::{SetForegroundColor, ResetColor};
|
||||||
|
use std::io::{self, Write};
|
||||||
use termimad::MadSkin;
|
use termimad::MadSkin;
|
||||||
|
|
||||||
/// Simple output handler with markdown support
|
/// Simple output handler with markdown support
|
||||||
@@ -40,7 +41,7 @@ impl SimpleOutput {
|
|||||||
trimmed.starts_with("* ") ||
|
trimmed.starts_with("* ") ||
|
||||||
trimmed.starts_with("+ ") ||
|
trimmed.starts_with("+ ") ||
|
||||||
(trimmed.len() > 2 &&
|
(trimmed.len() > 2 &&
|
||||||
trimmed.chars().next().map_or(false, |c| c.is_ascii_digit()) &&
|
trimmed.chars().next().is_some_and(|c| c.is_ascii_digit()) &&
|
||||||
trimmed.chars().nth(1) == Some('.') &&
|
trimmed.chars().nth(1) == Some('.') &&
|
||||||
trimmed.chars().nth(2) == Some(' ')) ||
|
trimmed.chars().nth(2) == Some(' ')) ||
|
||||||
(trimmed.contains('[') && trimmed.contains("]("))
|
(trimmed.contains('[') && trimmed.contains("]("))
|
||||||
@@ -70,18 +71,20 @@ impl SimpleOutput {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn print_context(&self, used: u32, total: u32, percentage: f32) {
|
pub fn print_context(&self, used: u32, total: u32, percentage: f32) {
|
||||||
let bar_width: usize = 10;
|
let total_dots = 10;
|
||||||
let filled_width = ((percentage / 100.0) * bar_width as f32) as usize;
|
let filled_dots = ((percentage / 100.0) * total_dots as f32) as usize;
|
||||||
let empty_width = bar_width.saturating_sub(filled_width);
|
let empty_dots = total_dots.saturating_sub(filled_dots);
|
||||||
|
|
||||||
let filled_chars = "●".repeat(filled_width);
|
let filled_str = "●".repeat(filled_dots);
|
||||||
let empty_chars = "○".repeat(empty_width);
|
let empty_str = "○".repeat(empty_dots);
|
||||||
|
|
||||||
// Determine color based on percentage
|
// Determine color based on percentage
|
||||||
let color = if percentage < 60.0 {
|
let color = if percentage < 40.0 {
|
||||||
crossterm::style::Color::Green
|
crossterm::style::Color::Green
|
||||||
} else if percentage < 80.0 {
|
} else if percentage < 60.0 {
|
||||||
crossterm::style::Color::Yellow
|
crossterm::style::Color::Yellow
|
||||||
|
} else if percentage < 80.0 {
|
||||||
|
crossterm::style::Color::Rgb { r: 255, g: 165, b: 0 } // Orange
|
||||||
} else {
|
} else {
|
||||||
crossterm::style::Color::Red
|
crossterm::style::Color::Red
|
||||||
};
|
};
|
||||||
@@ -89,9 +92,40 @@ impl SimpleOutput {
|
|||||||
// Print with colored progress bar
|
// Print with colored progress bar
|
||||||
print!("Context: ");
|
print!("Context: ");
|
||||||
print!("{}", SetForegroundColor(color));
|
print!("{}", SetForegroundColor(color));
|
||||||
print!("{}{}", filled_chars, empty_chars);
|
print!("{}{}", filled_str, empty_str);
|
||||||
print!("{}", ResetColor);
|
print!("{}", ResetColor);
|
||||||
println!(" {:.1}% | {}/{} tokens", percentage, used, total);
|
println!(" {:.0}% ({}/{} tokens)", percentage, used, total);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn print_context_thinning(&self, message: &str) {
|
||||||
|
// Animated highlight for context thinning
|
||||||
|
// Use bright cyan/green with a quick flash animation
|
||||||
|
|
||||||
|
// Flash animation: print with bright background, then normal
|
||||||
|
let frames = vec![
|
||||||
|
"\x1b[1;97;46m", // Frame 1: Bold white on cyan background
|
||||||
|
"\x1b[1;97;42m", // Frame 2: Bold white on green background
|
||||||
|
"\x1b[1;96;40m", // Frame 3: Bold cyan on black background
|
||||||
|
];
|
||||||
|
|
||||||
|
println!();
|
||||||
|
|
||||||
|
// Quick flash animation
|
||||||
|
for frame in &frames {
|
||||||
|
print!("\r{} ✨ {} ✨\x1b[0m", frame, message);
|
||||||
|
let _ = io::stdout().flush();
|
||||||
|
std::thread::sleep(std::time::Duration::from_millis(80));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Final display with bright cyan and sparkle emojis
|
||||||
|
print!("\r\x1b[1;96m✨ {} ✨\x1b[0m", message);
|
||||||
|
println!();
|
||||||
|
|
||||||
|
// Add a subtle "success" indicator line
|
||||||
|
println!("\x1b[2;36m └─ Context optimized successfully\x1b[0m");
|
||||||
|
println!();
|
||||||
|
|
||||||
|
let _ = io::stdout().flush();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
use crate::retro_tui::RetroTui;
|
|
||||||
use g3_core::ui_writer::UiWriter;
|
use g3_core::ui_writer::UiWriter;
|
||||||
use std::io::{self, Write};
|
use std::io::{self, Write};
|
||||||
use std::sync::Mutex;
|
use std::sync::Mutex;
|
||||||
use std::time::Instant;
|
|
||||||
|
|
||||||
/// Console implementation of UiWriter that prints to stdout
|
/// Console implementation of UiWriter that prints to stdout
|
||||||
pub struct ConsoleUiWriter {
|
pub struct ConsoleUiWriter {
|
||||||
@@ -104,6 +102,37 @@ impl UiWriter for ConsoleUiWriter {
|
|||||||
println!("{}", message);
|
println!("{}", message);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn print_context_thinning(&self, message: &str) {
|
||||||
|
// Animated highlight for context thinning
|
||||||
|
// Use bright cyan/green with a quick flash animation
|
||||||
|
|
||||||
|
// Flash animation: print with bright background, then normal
|
||||||
|
let frames = vec![
|
||||||
|
"\x1b[1;97;46m", // Frame 1: Bold white on cyan background
|
||||||
|
"\x1b[1;97;42m", // Frame 2: Bold white on green background
|
||||||
|
"\x1b[1;96;40m", // Frame 3: Bold cyan on black background
|
||||||
|
];
|
||||||
|
|
||||||
|
println!();
|
||||||
|
|
||||||
|
// Quick flash animation
|
||||||
|
for frame in &frames {
|
||||||
|
print!("\r{} ✨ {} ✨\x1b[0m", frame, message);
|
||||||
|
let _ = io::stdout().flush();
|
||||||
|
std::thread::sleep(std::time::Duration::from_millis(80));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Final display with bright cyan and sparkle emojis
|
||||||
|
print!("\r\x1b[1;96m✨ {} ✨\x1b[0m", message);
|
||||||
|
println!();
|
||||||
|
|
||||||
|
// Add a subtle "success" indicator line
|
||||||
|
println!("\x1b[2;36m └─ Context optimized successfully\x1b[0m");
|
||||||
|
println!();
|
||||||
|
|
||||||
|
let _ = io::stdout().flush();
|
||||||
|
}
|
||||||
|
|
||||||
fn print_tool_header(&self, tool_name: &str) {
|
fn print_tool_header(&self, tool_name: &str) {
|
||||||
// Store the tool name and clear args for collection
|
// Store the tool name and clear args for collection
|
||||||
*self.current_tool_name.lock().unwrap() = Some(tool_name.to_string());
|
*self.current_tool_name.lock().unwrap() = Some(tool_name.to_string());
|
||||||
@@ -115,7 +144,6 @@ impl UiWriter for ConsoleUiWriter {
|
|||||||
|
|
||||||
// For todo tools, we'll skip the normal header and print a custom one later
|
// For todo tools, we'll skip the normal header and print a custom one later
|
||||||
if is_todo {
|
if is_todo {
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -163,7 +191,12 @@ impl UiWriter for ConsoleUiWriter {
|
|||||||
|
|
||||||
// Truncate long values for display
|
// Truncate long values for display
|
||||||
let display_value = if first_line.len() > 80 {
|
let display_value = if first_line.len() > 80 {
|
||||||
format!("{}...", &first_line[..77])
|
// Use char_indices to safely truncate at character boundary
|
||||||
|
let truncate_at = first_line.char_indices()
|
||||||
|
.nth(77)
|
||||||
|
.map(|(i, _)| i)
|
||||||
|
.unwrap_or(first_line.len());
|
||||||
|
format!("{}...", &first_line[..truncate_at])
|
||||||
} else {
|
} else {
|
||||||
first_line.to_string()
|
first_line.to_string()
|
||||||
};
|
};
|
||||||
@@ -312,223 +345,3 @@ impl UiWriter for ConsoleUiWriter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// RetroTui implementation of UiWriter that sends output to the TUI
|
|
||||||
pub struct RetroTuiWriter {
|
|
||||||
tui: RetroTui,
|
|
||||||
current_tool_name: Mutex<Option<String>>,
|
|
||||||
current_tool_output: Mutex<Vec<String>>,
|
|
||||||
current_tool_start: Mutex<Option<Instant>>,
|
|
||||||
current_tool_caption: Mutex<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RetroTuiWriter {
|
|
||||||
pub fn new(tui: RetroTui) -> Self {
|
|
||||||
Self {
|
|
||||||
tui,
|
|
||||||
current_tool_name: Mutex::new(None),
|
|
||||||
current_tool_output: Mutex::new(Vec::new()),
|
|
||||||
current_tool_start: Mutex::new(None),
|
|
||||||
current_tool_caption: Mutex::new(String::new()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl UiWriter for RetroTuiWriter {
|
|
||||||
fn print(&self, message: &str) {
|
|
||||||
self.tui.output(message);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn println(&self, message: &str) {
|
|
||||||
self.tui.output(message);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn print_inline(&self, message: &str) {
|
|
||||||
// For inline printing, we'll just append to the output
|
|
||||||
self.tui.output(message);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn print_system_prompt(&self, prompt: &str) {
|
|
||||||
self.tui.output("🔍 System Prompt:");
|
|
||||||
self.tui.output("================");
|
|
||||||
for line in prompt.lines() {
|
|
||||||
self.tui.output(line);
|
|
||||||
}
|
|
||||||
self.tui.output("================");
|
|
||||||
self.tui.output("");
|
|
||||||
}
|
|
||||||
|
|
||||||
fn print_context_status(&self, message: &str) {
|
|
||||||
self.tui.output(message);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn print_tool_header(&self, tool_name: &str) {
|
|
||||||
// Start collecting tool output
|
|
||||||
*self.current_tool_start.lock().unwrap() = Some(Instant::now());
|
|
||||||
*self.current_tool_name.lock().unwrap() = Some(tool_name.to_string());
|
|
||||||
self.current_tool_output.lock().unwrap().clear();
|
|
||||||
self.current_tool_output
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.push(format!("Tool: {}", tool_name));
|
|
||||||
|
|
||||||
// Initialize caption
|
|
||||||
*self.current_tool_caption.lock().unwrap() = String::new();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn print_tool_arg(&self, key: &str, value: &str) {
|
|
||||||
// Filter out any keys that look like they might be agent message content
|
|
||||||
// (e.g., keys that are suspiciously long or contain message-like content)
|
|
||||||
let is_valid_arg_key = key.len() < 50
|
|
||||||
&& !key.contains('\n')
|
|
||||||
&& !key.contains("I'll")
|
|
||||||
&& !key.contains("Let me")
|
|
||||||
&& !key.contains("Here's")
|
|
||||||
&& !key.contains("I can");
|
|
||||||
|
|
||||||
if is_valid_arg_key {
|
|
||||||
self.current_tool_output
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.push(format!("{}: {}", key, value));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build caption from first argument (usually the most important one)
|
|
||||||
let mut caption = self.current_tool_caption.lock().unwrap();
|
|
||||||
if caption.is_empty() && (key == "file_path" || key == "command" || key == "path") {
|
|
||||||
// Truncate long values for the caption
|
|
||||||
let truncated = if value.len() > 50 {
|
|
||||||
format!("{}...", &value[..47])
|
|
||||||
} else {
|
|
||||||
value.to_string()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Add range information for read_file tool calls
|
|
||||||
let tool_name = self.current_tool_name.lock().unwrap();
|
|
||||||
let range_suffix = if tool_name.as_ref().map_or(false, |name| name == "read_file") {
|
|
||||||
// We need to check if start/end args will be provided - for now just check if this is a partial read
|
|
||||||
// This is a simplified approach since we're building the caption incrementally
|
|
||||||
String::new() // We'll handle this in print_tool_output_header instead
|
|
||||||
} else {
|
|
||||||
String::new()
|
|
||||||
};
|
|
||||||
|
|
||||||
*caption = format!("{}{}", truncated, range_suffix);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn print_tool_output_header(&self) {
|
|
||||||
// This is called right before tool execution starts
|
|
||||||
// Send the initial tool header to the TUI now
|
|
||||||
if let Some(tool_name) = self.current_tool_name.lock().unwrap().as_ref() {
|
|
||||||
let mut caption = self.current_tool_caption.lock().unwrap().clone();
|
|
||||||
|
|
||||||
// Add range information for read_file tool calls
|
|
||||||
if tool_name == "read_file" {
|
|
||||||
// Check the tool output for start/end parameters
|
|
||||||
let output = self.current_tool_output.lock().unwrap();
|
|
||||||
let has_start = output.iter().any(|line| line.starts_with("start:"));
|
|
||||||
let has_end = output.iter().any(|line| line.starts_with("end:"));
|
|
||||||
|
|
||||||
if has_start || has_end {
|
|
||||||
let start_val = output.iter().find(|line| line.starts_with("start:")).map(|line| line.split(':').nth(1).unwrap_or("0").trim()).unwrap_or("0");
|
|
||||||
let end_val = output.iter().find(|line| line.starts_with("end:")).map(|line| line.split(':').nth(1).unwrap_or("end").trim()).unwrap_or("end");
|
|
||||||
caption = format!("{} [{}..{}]", caption, start_val, end_val);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send the tool output with initial header
|
|
||||||
self.tui.tool_output(tool_name, &caption, "");
|
|
||||||
}
|
|
||||||
|
|
||||||
self.current_tool_output.lock().unwrap().push(String::new());
|
|
||||||
self.current_tool_output
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.push("Output:".to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
fn update_tool_output_line(&self, line: &str) {
|
|
||||||
// For retro mode, we'll just add to the output buffer
|
|
||||||
self.current_tool_output
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.push(line.to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
fn print_tool_output_line(&self, line: &str) {
|
|
||||||
self.current_tool_output
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.push(line.to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
fn print_tool_output_summary(&self, hidden_count: usize) {
|
|
||||||
self.current_tool_output.lock().unwrap().push(format!(
|
|
||||||
"... ({} more line{})",
|
|
||||||
hidden_count,
|
|
||||||
if hidden_count == 1 { "" } else { "s" }
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
fn print_tool_timing(&self, duration_str: &str) {
|
|
||||||
self.current_tool_output
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.push(format!("⚡️ {}", duration_str));
|
|
||||||
|
|
||||||
// Calculate the actual duration
|
|
||||||
let duration_ms = if let Some(start) = *self.current_tool_start.lock().unwrap() {
|
|
||||||
start.elapsed().as_millis()
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
};
|
|
||||||
|
|
||||||
// Get the tool name and caption
|
|
||||||
if let Some(tool_name) = self.current_tool_name.lock().unwrap().as_ref() {
|
|
||||||
let content = self.current_tool_output.lock().unwrap().join("\n");
|
|
||||||
let caption = self.current_tool_caption.lock().unwrap().clone();
|
|
||||||
let caption = if caption.is_empty() {
|
|
||||||
"Completed".to_string()
|
|
||||||
} else {
|
|
||||||
caption
|
|
||||||
};
|
|
||||||
|
|
||||||
// Update the tool detail panel with the complete output without adding a new header
|
|
||||||
// This keeps the original header in place to be updated by tool_complete
|
|
||||||
self.tui.update_tool_detail(tool_name, &content);
|
|
||||||
|
|
||||||
// Determine success based on whether there's an error in the output
|
|
||||||
// This is a simple heuristic - you might want to make this more sophisticated
|
|
||||||
let success = !content.contains("error")
|
|
||||||
&& !content.contains("Error")
|
|
||||||
&& !content.contains("ERROR");
|
|
||||||
|
|
||||||
// Send the completion status to update the header
|
|
||||||
self.tui
|
|
||||||
.tool_complete(tool_name, success, duration_ms, &caption);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear the buffers
|
|
||||||
*self.current_tool_name.lock().unwrap() = None;
|
|
||||||
self.current_tool_output.lock().unwrap().clear();
|
|
||||||
*self.current_tool_start.lock().unwrap() = None;
|
|
||||||
*self.current_tool_caption.lock().unwrap() = String::new();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn print_agent_prompt(&self) {
|
|
||||||
self.tui.output("\n💬 ");
|
|
||||||
}
|
|
||||||
|
|
||||||
fn print_agent_response(&self, content: &str) {
|
|
||||||
self.tui.output(content);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn notify_sse_received(&self) {
|
|
||||||
// Notify the TUI that an SSE was received
|
|
||||||
self.tui.sse_received();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn flush(&self) {
|
|
||||||
// No-op for TUI since it handles its own rendering
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -3,6 +3,9 @@ name = "g3-computer-control"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
|
[build-dependencies]
|
||||||
|
# Only needed for building Swift bridge on macOS
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
# Workspace dependencies
|
# Workspace dependencies
|
||||||
tokio = { workspace = true }
|
tokio = { workspace = true }
|
||||||
@@ -20,15 +23,13 @@ async-trait = "0.1"
|
|||||||
# WebDriver support
|
# WebDriver support
|
||||||
fantoccini = "0.21"
|
fantoccini = "0.21"
|
||||||
|
|
||||||
# OCR dependencies
|
|
||||||
tesseract = "0.14"
|
|
||||||
|
|
||||||
# macOS dependencies
|
# macOS dependencies
|
||||||
[target.'cfg(target_os = "macos")'.dependencies]
|
[target.'cfg(target_os = "macos")'.dependencies]
|
||||||
core-graphics = "0.23"
|
core-graphics = "0.23"
|
||||||
core-foundation = "0.9"
|
core-foundation = "0.10"
|
||||||
cocoa = "0.25"
|
cocoa = "0.25"
|
||||||
objc = "0.2"
|
objc = "0.2"
|
||||||
|
accessibility = "0.2"
|
||||||
image = "0.24"
|
image = "0.24"
|
||||||
|
|
||||||
# Linux dependencies
|
# Linux dependencies
|
||||||
|
|||||||
63
crates/g3-computer-control/build.rs
Normal file
63
crates/g3-computer-control/build.rs
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
use std::env;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::process::Command;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
// Only build Vision bridge on macOS
|
||||||
|
if env::var("CARGO_CFG_TARGET_OS").unwrap() != "macos" {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("cargo:rerun-if-changed=vision-bridge/Sources/VisionBridge/VisionOCR.swift");
|
||||||
|
println!("cargo:rerun-if-changed=vision-bridge/Sources/VisionBridge/VisionBridge.h");
|
||||||
|
println!("cargo:rerun-if-changed=vision-bridge/Package.swift");
|
||||||
|
|
||||||
|
let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap());
|
||||||
|
let vision_bridge_dir = manifest_dir.join("vision-bridge");
|
||||||
|
|
||||||
|
// Build Swift package
|
||||||
|
println!("cargo:warning=Building VisionBridge Swift package...");
|
||||||
|
let build_status = Command::new("swift")
|
||||||
|
.args(&["build", "-c", "release"])
|
||||||
|
.current_dir(&vision_bridge_dir)
|
||||||
|
.status()
|
||||||
|
.expect("Failed to build Swift package");
|
||||||
|
|
||||||
|
if !build_status.success() {
|
||||||
|
panic!("Swift build failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the built library
|
||||||
|
let lib_path = vision_bridge_dir
|
||||||
|
.join(".build/release")
|
||||||
|
.canonicalize()
|
||||||
|
.expect("Failed to find .build/release directory");
|
||||||
|
|
||||||
|
// Copy the dylib to the output directory so it can be found at runtime
|
||||||
|
let target_dir = manifest_dir.parent().unwrap().parent().unwrap().join("target");
|
||||||
|
let profile = env::var("PROFILE").unwrap_or_else(|_| "debug".to_string());
|
||||||
|
let output_dir = target_dir.join(&profile);
|
||||||
|
|
||||||
|
let dylib_src = lib_path.join("libVisionBridge.dylib");
|
||||||
|
let dylib_dst = output_dir.join("libVisionBridge.dylib");
|
||||||
|
|
||||||
|
std::fs::copy(&dylib_src, &dylib_dst)
|
||||||
|
.expect(&format!("Failed to copy dylib from {} to {}", dylib_src.display(), dylib_dst.display()));
|
||||||
|
|
||||||
|
println!("cargo:warning=Copied libVisionBridge.dylib to {}", dylib_dst.display());
|
||||||
|
|
||||||
|
// Add rpath so the dylib can be found at runtime
|
||||||
|
println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path");
|
||||||
|
println!("cargo:rustc-link-arg=-Wl,-rpath,@loader_path");
|
||||||
|
println!("cargo:rustc-link-search=native={}", lib_path.display());
|
||||||
|
println!("cargo:rustc-link-lib=dylib=VisionBridge");
|
||||||
|
|
||||||
|
// Link required frameworks
|
||||||
|
println!("cargo:rustc-link-lib=framework=Vision");
|
||||||
|
println!("cargo:rustc-link-lib=framework=AppKit");
|
||||||
|
println!("cargo:rustc-link-lib=framework=Foundation");
|
||||||
|
println!("cargo:rustc-link-lib=framework=CoreGraphics");
|
||||||
|
println!("cargo:rustc-link-lib=framework=CoreImage");
|
||||||
|
|
||||||
|
println!("cargo:warning=VisionBridge built successfully at {}", lib_path.display());
|
||||||
|
}
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
use core_graphics::window::{kCGWindowListOptionOnScreenOnly, kCGNullWindowID, CGWindowListCopyWindowInfo};
|
use core_graphics::window::{kCGWindowListOptionOnScreenOnly, kCGNullWindowID, CGWindowListCopyWindowInfo};
|
||||||
use core_foundation::dictionary::CFDictionary;
|
use core_foundation::dictionary::CFDictionary;
|
||||||
use core_foundation::string::CFString;
|
use core_foundation::string::CFString;
|
||||||
use core_foundation::base::TCFType;
|
use core_foundation::base::{TCFType, ToVoid};
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
println!("Listing all on-screen windows...");
|
println!("Listing all on-screen windows...");
|
||||||
@@ -22,7 +22,7 @@ fn main() {
|
|||||||
|
|
||||||
// Get window ID
|
// Get window ID
|
||||||
let window_id_key = CFString::from_static_string("kCGWindowNumber");
|
let window_id_key = CFString::from_static_string("kCGWindowNumber");
|
||||||
let window_id: i64 = if let Some(value) = dict.find(window_id_key.as_concrete_TypeRef()) {
|
let window_id: i64 = if let Some(value) = dict.find(window_id_key.to_void()) {
|
||||||
let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _);
|
let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _);
|
||||||
num.to_i64().unwrap_or(0)
|
num.to_i64().unwrap_or(0)
|
||||||
} else {
|
} else {
|
||||||
@@ -31,7 +31,7 @@ fn main() {
|
|||||||
|
|
||||||
// Get owner name
|
// Get owner name
|
||||||
let owner_key = CFString::from_static_string("kCGWindowOwnerName");
|
let owner_key = CFString::from_static_string("kCGWindowOwnerName");
|
||||||
let owner: String = if let Some(value) = dict.find(owner_key.as_concrete_TypeRef()) {
|
let owner: String = if let Some(value) = dict.find(owner_key.to_void()) {
|
||||||
let s: CFString = TCFType::wrap_under_get_rule(*value as *const _);
|
let s: CFString = TCFType::wrap_under_get_rule(*value as *const _);
|
||||||
s.to_string()
|
s.to_string()
|
||||||
} else {
|
} else {
|
||||||
@@ -40,15 +40,15 @@ fn main() {
|
|||||||
|
|
||||||
// Get window name/title
|
// Get window name/title
|
||||||
let name_key = CFString::from_static_string("kCGWindowName");
|
let name_key = CFString::from_static_string("kCGWindowName");
|
||||||
let title: String = if let Some(value) = dict.find(name_key.as_concrete_TypeRef()) {
|
let title: String = if let Some(value) = dict.find(name_key.to_void()) {
|
||||||
let s: CFString = TCFType::wrap_under_get_rule(*value as *const _);
|
let s: CFString = TCFType::wrap_under_get_rule(*value as *const _);
|
||||||
s.to_string()
|
s.to_string()
|
||||||
} else {
|
} else {
|
||||||
"".to_string()
|
"".to_string()
|
||||||
};
|
};
|
||||||
|
|
||||||
// Filter for iTerm or show all
|
// Show all windows
|
||||||
if owner.contains("iTerm") || owner.contains("Terminal") {
|
if !owner.is_empty() {
|
||||||
println!("{:<10} {:<25} {}", window_id, owner, title);
|
println!("{:<10} {:<25} {}", window_id, owner, title);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
74
crates/g3-computer-control/examples/macax_demo.rs
Normal file
74
crates/g3-computer-control/examples/macax_demo.rs
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
//! Example demonstrating macOS Accessibility API tools
|
||||||
|
//!
|
||||||
|
//! This example shows how to use the macax tools to control macOS applications.
|
||||||
|
//!
|
||||||
|
//! Run with: cargo run --example macax_demo
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use g3_computer_control::MacAxController;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<()> {
|
||||||
|
println!("🍎 macOS Accessibility API Demo\n");
|
||||||
|
println!("This demo shows how to control macOS applications using the Accessibility API.\n");
|
||||||
|
|
||||||
|
// Create controller
|
||||||
|
let controller = MacAxController::new()?;
|
||||||
|
println!("✅ MacAxController initialized\n");
|
||||||
|
|
||||||
|
// List running applications
|
||||||
|
println!("📱 Listing running applications:");
|
||||||
|
match controller.list_applications() {
|
||||||
|
Ok(apps) => {
|
||||||
|
for app in apps.iter().take(10) {
|
||||||
|
println!(" - {}", app.name);
|
||||||
|
}
|
||||||
|
if apps.len() > 10 {
|
||||||
|
println!(" ... and {} more", apps.len() - 10);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => println!(" ❌ Error: {}", e),
|
||||||
|
}
|
||||||
|
println!();
|
||||||
|
|
||||||
|
// Get frontmost app
|
||||||
|
println!("🎯 Getting frontmost application:");
|
||||||
|
match controller.get_frontmost_app() {
|
||||||
|
Ok(app) => println!(" Current: {}", app.name),
|
||||||
|
Err(e) => println!(" ❌ Error: {}", e),
|
||||||
|
}
|
||||||
|
println!();
|
||||||
|
|
||||||
|
// Example: Activate Finder and get its UI tree
|
||||||
|
println!("📂 Activating Finder and inspecting UI:");
|
||||||
|
match controller.activate_app("Finder") {
|
||||||
|
Ok(_) => {
|
||||||
|
println!(" ✅ Finder activated");
|
||||||
|
|
||||||
|
// Wait a moment for activation
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
||||||
|
|
||||||
|
// Get UI tree
|
||||||
|
match controller.get_ui_tree("Finder", 2) {
|
||||||
|
Ok(tree) => {
|
||||||
|
println!("\n UI Tree:");
|
||||||
|
for line in tree.lines().take(10) {
|
||||||
|
println!(" {}", line);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => println!(" ❌ Error getting UI tree: {}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => println!(" ❌ Error: {}", e),
|
||||||
|
}
|
||||||
|
println!();
|
||||||
|
|
||||||
|
println!("✨ Demo complete!\n");
|
||||||
|
println!("💡 Tips:");
|
||||||
|
println!(" - Use --macax flag with g3 to enable these tools");
|
||||||
|
println!(" - Grant accessibility permissions in System Preferences");
|
||||||
|
println!(" - Add accessibility identifiers to your apps for easier automation");
|
||||||
|
println!(" - See docs/macax-tools.md for full documentation\n");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -31,7 +31,7 @@ async fn main() -> Result<()> {
|
|||||||
|
|
||||||
// Find an element
|
// Find an element
|
||||||
println!("Finding h1 element...");
|
println!("Finding h1 element...");
|
||||||
let mut h1 = driver.find_element("h1").await?;
|
let h1 = driver.find_element("h1").await?;
|
||||||
let h1_text = h1.text().await?;
|
let h1_text = h1.text().await?;
|
||||||
println!("H1 text: {}\n", h1_text);
|
println!("H1 text: {}\n", h1_text);
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use g3_computer_control::{create_controller, ComputerController};
|
use g3_computer_control::create_controller;
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
use core_graphics::display::CGDisplay;
|
use core_graphics::display::CGDisplay;
|
||||||
use image::{ImageBuffer, RgbaImage};
|
use image::{ImageBuffer, RgbaImage};
|
||||||
use std::path::Path;
|
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
let display = CGDisplay::main();
|
let display = CGDisplay::main();
|
||||||
|
|||||||
48
crates/g3-computer-control/examples/test_type_text.rs
Normal file
48
crates/g3-computer-control/examples/test_type_text.rs
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
//! Test the new type_text functionality
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use g3_computer_control::MacAxController;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<()> {
|
||||||
|
println!("🧪 Testing macax type_text functionality\n");
|
||||||
|
|
||||||
|
let controller = MacAxController::new()?;
|
||||||
|
println!("✅ Controller initialized\n");
|
||||||
|
|
||||||
|
// Test 1: Type simple text
|
||||||
|
println!("Test 1: Typing simple text into TextEdit");
|
||||||
|
println!(" Please open TextEdit and create a new document...");
|
||||||
|
std::thread::sleep(std::time::Duration::from_secs(3));
|
||||||
|
|
||||||
|
match controller.type_text("TextEdit", "Hello, World!") {
|
||||||
|
Ok(_) => println!(" ✅ Successfully typed simple text\n"),
|
||||||
|
Err(e) => println!(" ❌ Failed: {}\n", e),
|
||||||
|
}
|
||||||
|
|
||||||
|
std::thread::sleep(std::time::Duration::from_secs(1));
|
||||||
|
|
||||||
|
// Test 2: Type unicode and emojis
|
||||||
|
println!("Test 2: Typing unicode and emojis");
|
||||||
|
match controller.type_text("TextEdit", "\n🌟 Unicode test: café, naïve, 日本語 🎉") {
|
||||||
|
Ok(_) => println!(" ✅ Successfully typed unicode text\n"),
|
||||||
|
Err(e) => println!(" ❌ Failed: {}\n", e),
|
||||||
|
}
|
||||||
|
|
||||||
|
std::thread::sleep(std::time::Duration::from_secs(1));
|
||||||
|
|
||||||
|
// Test 3: Type special characters
|
||||||
|
println!("Test 3: Typing special characters");
|
||||||
|
match controller.type_text("TextEdit", "\nSpecial: @#$%^&*()_+-=[]{}|;':,.<>?/") {
|
||||||
|
Ok(_) => println!(" ✅ Successfully typed special characters\n"),
|
||||||
|
Err(e) => println!(" ❌ Failed: {}\n", e),
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("\n✨ Tests complete!");
|
||||||
|
println!("\n💡 Now try with Things3:");
|
||||||
|
println!(" 1. Open Things3");
|
||||||
|
println!(" 2. Press Cmd+N to create a new task");
|
||||||
|
println!(" 3. Run: g3 --macax 'type \"🌟 My awesome task\" into Things'");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
85
crates/g3-computer-control/examples/test_vision.rs
Normal file
85
crates/g3-computer-control/examples/test_vision.rs
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
use g3_computer_control::ocr::{OCREngine, DefaultOCR};
|
||||||
|
use anyhow::Result;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<()> {
|
||||||
|
println!("🧪 Testing Apple Vision OCR");
|
||||||
|
println!("===========================\n");
|
||||||
|
|
||||||
|
// Initialize OCR engine
|
||||||
|
println!("📦 Initializing OCR engine...");
|
||||||
|
let ocr = DefaultOCR::new()?;
|
||||||
|
println!("✅ OCR engine: {}\n", ocr.name());
|
||||||
|
|
||||||
|
// Check if test image exists
|
||||||
|
let test_image = "/tmp/safari_test.png";
|
||||||
|
if !std::path::Path::new(test_image).exists() {
|
||||||
|
println!("⚠️ Test image not found: {}", test_image);
|
||||||
|
println!(" Creating a screenshot...");
|
||||||
|
|
||||||
|
let status = std::process::Command::new("screencapture")
|
||||||
|
.arg("-x")
|
||||||
|
.arg("-R")
|
||||||
|
.arg("0,0,1200,800")
|
||||||
|
.arg(test_image)
|
||||||
|
.status()?;
|
||||||
|
|
||||||
|
if !status.success() {
|
||||||
|
anyhow::bail!("Failed to create screenshot");
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("✅ Screenshot created\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run OCR
|
||||||
|
println!("🔍 Running Apple Vision OCR on {}...", test_image);
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let locations = ocr.extract_text_with_locations(test_image).await?;
|
||||||
|
let duration = start.elapsed();
|
||||||
|
|
||||||
|
println!("✅ OCR completed in {:.3}s\n", duration.as_secs_f64());
|
||||||
|
|
||||||
|
// Display results
|
||||||
|
println!("📊 Results:");
|
||||||
|
println!(" Found {} text elements\n", locations.len());
|
||||||
|
|
||||||
|
if locations.is_empty() {
|
||||||
|
println!("⚠️ No text found in image");
|
||||||
|
} else {
|
||||||
|
println!(" Top 20 results:");
|
||||||
|
println!(" {:<4} {:<40} {:<15} {:<12} {:<8}", "#", "Text", "Position", "Size", "Conf");
|
||||||
|
println!(" {}", "-".repeat(85));
|
||||||
|
|
||||||
|
for (i, loc) in locations.iter().take(20).enumerate() {
|
||||||
|
let text = if loc.text.len() > 37 {
|
||||||
|
format!("{}...", &loc.text[..37])
|
||||||
|
} else {
|
||||||
|
loc.text.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
println!(" {:<4} {:<40} ({:>4},{:>4}) {:>4}x{:<4} {:.2}",
|
||||||
|
i + 1,
|
||||||
|
text,
|
||||||
|
loc.x,
|
||||||
|
loc.y,
|
||||||
|
loc.width,
|
||||||
|
loc.height,
|
||||||
|
loc.confidence
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if locations.len() > 20 {
|
||||||
|
println!("\n ... and {} more", locations.len() - 20);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Performance comparison
|
||||||
|
println!("\n📈 Performance:");
|
||||||
|
println!(" OCR Speed: {:.3}s", duration.as_secs_f64());
|
||||||
|
println!(" Text elements: {}", locations.len());
|
||||||
|
println!(" Avg per element: {:.1}ms", duration.as_millis() as f64 / locations.len() as f64);
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("\n✅ Test complete!");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -1,10 +1,18 @@
|
|||||||
|
// Suppress warnings from objc crate macros
|
||||||
|
#![allow(unexpected_cfgs)]
|
||||||
|
|
||||||
pub mod types;
|
pub mod types;
|
||||||
pub mod platform;
|
pub mod platform;
|
||||||
|
pub mod ocr;
|
||||||
pub mod webdriver;
|
pub mod webdriver;
|
||||||
|
pub mod macax;
|
||||||
|
|
||||||
// Re-export webdriver types for convenience
|
// Re-export webdriver types for convenience
|
||||||
pub use webdriver::{WebDriverController, WebElement, safari::SafariDriver};
|
pub use webdriver::{WebDriverController, WebElement, safari::SafariDriver};
|
||||||
|
|
||||||
|
// Re-export macax types for convenience
|
||||||
|
pub use macax::{MacAxController, AXElement, AXApplication};
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use types::*;
|
use types::*;
|
||||||
@@ -15,8 +23,14 @@ pub trait ComputerController: Send + Sync {
|
|||||||
async fn take_screenshot(&self, path: &str, region: Option<Rect>, window_id: Option<&str>) -> Result<()>;
|
async fn take_screenshot(&self, path: &str, region: Option<Rect>, window_id: Option<&str>) -> Result<()>;
|
||||||
|
|
||||||
// OCR operations
|
// OCR operations
|
||||||
async fn extract_text_from_screen(&self, region: Rect) -> Result<String>;
|
async fn extract_text_from_screen(&self, region: Rect, window_id: &str) -> Result<String>;
|
||||||
async fn extract_text_from_image(&self, path: &str) -> Result<String>;
|
async fn extract_text_from_image(&self, path: &str) -> Result<String>;
|
||||||
|
async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>>;
|
||||||
|
async fn find_text_in_app(&self, app_name: &str, search_text: &str) -> Result<Option<TextLocation>>;
|
||||||
|
|
||||||
|
// Mouse operations
|
||||||
|
fn move_mouse(&self, x: i32, y: i32) -> Result<()>;
|
||||||
|
fn click_at(&self, x: i32, y: i32, app_name: Option<&str>) -> Result<()>;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Platform-specific constructor
|
// Platform-specific constructor
|
||||||
|
|||||||
822
crates/g3-computer-control/src/macax/controller.rs
Normal file
822
crates/g3-computer-control/src/macax/controller.rs
Normal file
@@ -0,0 +1,822 @@
|
|||||||
|
use super::{AXApplication, AXElement};
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
use accessibility::{AXUIElement, AXUIElementAttributes, ElementFinder, TreeVisitor, TreeWalker, TreeWalkerFlow};
|
||||||
|
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
use core_foundation::base::TCFType;
|
||||||
|
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
use core_foundation::string::CFString;
|
||||||
|
|
||||||
|
/// macOS Accessibility API controller using native APIs
|
||||||
|
pub struct MacAxController {
|
||||||
|
// Cache for application elements
|
||||||
|
app_cache: std::sync::Mutex<HashMap<String, AXUIElement>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MacAxController {
|
||||||
|
pub fn new() -> Result<Self> {
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
{
|
||||||
|
// Check if we have accessibility permissions by trying to get system-wide element
|
||||||
|
let _system = AXUIElement::system_wide();
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
app_cache: std::sync::Mutex::new(HashMap::new()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(target_os = "macos"))]
|
||||||
|
{
|
||||||
|
anyhow::bail!("macOS Accessibility API is only available on macOS")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List all running applications
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
pub fn list_applications(&self) -> Result<Vec<AXApplication>> {
|
||||||
|
let apps = Self::get_running_applications()?;
|
||||||
|
Ok(apps)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(target_os = "macos"))]
|
||||||
|
pub fn list_applications(&self) -> Result<Vec<AXApplication>> {
|
||||||
|
anyhow::bail!("Not supported on this platform")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
fn get_running_applications() -> Result<Vec<AXApplication>> {
|
||||||
|
use cocoa::appkit::NSApplicationActivationPolicy;
|
||||||
|
use cocoa::base::{id, nil};
|
||||||
|
use objc::{class, msg_send, sel, sel_impl};
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
let workspace: id = msg_send![class!(NSWorkspace), sharedWorkspace];
|
||||||
|
let running_apps: id = msg_send![workspace, runningApplications];
|
||||||
|
let count: usize = msg_send![running_apps, count];
|
||||||
|
|
||||||
|
let mut apps = Vec::new();
|
||||||
|
|
||||||
|
for i in 0..count {
|
||||||
|
let app: id = msg_send![running_apps, objectAtIndex: i];
|
||||||
|
|
||||||
|
// Get app name
|
||||||
|
let localized_name: id = msg_send![app, localizedName];
|
||||||
|
if localized_name == nil {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let name_ptr: *const i8 = msg_send![localized_name, UTF8String];
|
||||||
|
let name = if !name_ptr.is_null() {
|
||||||
|
std::ffi::CStr::from_ptr(name_ptr)
|
||||||
|
.to_string_lossy()
|
||||||
|
.to_string()
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get bundle ID
|
||||||
|
let bundle_id_obj: id = msg_send![app, bundleIdentifier];
|
||||||
|
let bundle_id = if bundle_id_obj != nil {
|
||||||
|
let bundle_id_ptr: *const i8 = msg_send![bundle_id_obj, UTF8String];
|
||||||
|
if !bundle_id_ptr.is_null() {
|
||||||
|
Some(
|
||||||
|
std::ffi::CStr::from_ptr(bundle_id_ptr)
|
||||||
|
.to_string_lossy()
|
||||||
|
.to_string(),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get PID
|
||||||
|
let pid: i32 = msg_send![app, processIdentifier];
|
||||||
|
|
||||||
|
// Skip background-only apps
|
||||||
|
let activation_policy: i64 = msg_send![app, activationPolicy];
|
||||||
|
if activation_policy == NSApplicationActivationPolicy::NSApplicationActivationPolicyRegular as i64 {
|
||||||
|
apps.push(AXApplication {
|
||||||
|
name,
|
||||||
|
bundle_id,
|
||||||
|
pid,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(apps)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the frontmost (active) application
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
pub fn get_frontmost_app(&self) -> Result<AXApplication> {
|
||||||
|
use cocoa::base::{id, nil};
|
||||||
|
use objc::{class, msg_send, sel, sel_impl};
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
let workspace: id = msg_send![class!(NSWorkspace), sharedWorkspace];
|
||||||
|
let frontmost_app: id = msg_send![workspace, frontmostApplication];
|
||||||
|
|
||||||
|
if frontmost_app == nil {
|
||||||
|
anyhow::bail!("No frontmost application");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get app name
|
||||||
|
let localized_name: id = msg_send![frontmost_app, localizedName];
|
||||||
|
let name_ptr: *const i8 = msg_send![localized_name, UTF8String];
|
||||||
|
let name = std::ffi::CStr::from_ptr(name_ptr)
|
||||||
|
.to_string_lossy()
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
// Get bundle ID
|
||||||
|
let bundle_id_obj: id = msg_send![frontmost_app, bundleIdentifier];
|
||||||
|
let bundle_id = if bundle_id_obj != nil {
|
||||||
|
let bundle_id_ptr: *const i8 = msg_send![bundle_id_obj, UTF8String];
|
||||||
|
if !bundle_id_ptr.is_null() {
|
||||||
|
Some(
|
||||||
|
std::ffi::CStr::from_ptr(bundle_id_ptr)
|
||||||
|
.to_string_lossy()
|
||||||
|
.to_string(),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get PID
|
||||||
|
let pid: i32 = msg_send![frontmost_app, processIdentifier];
|
||||||
|
|
||||||
|
Ok(AXApplication {
|
||||||
|
name,
|
||||||
|
bundle_id,
|
||||||
|
pid,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(target_os = "macos"))]
|
||||||
|
pub fn get_frontmost_app(&self) -> Result<AXApplication> {
|
||||||
|
anyhow::bail!("Not supported on this platform")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get AXUIElement for an application by name or PID
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
fn get_app_element(&self, app_name: &str) -> Result<AXUIElement> {
|
||||||
|
// Check cache first
|
||||||
|
{
|
||||||
|
let cache = self.app_cache.lock().unwrap();
|
||||||
|
if let Some(element) = cache.get(app_name) {
|
||||||
|
return Ok(element.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the app by name
|
||||||
|
let apps = Self::get_running_applications()?;
|
||||||
|
let app = apps
|
||||||
|
.iter()
|
||||||
|
.find(|a| a.name == app_name)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Application '{}' not found", app_name))?;
|
||||||
|
|
||||||
|
// Create AXUIElement for the app
|
||||||
|
let element = AXUIElement::application(app.pid);
|
||||||
|
|
||||||
|
// Cache it
|
||||||
|
{
|
||||||
|
let mut cache = self.app_cache.lock().unwrap();
|
||||||
|
cache.insert(app_name.to_string(), element.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(element)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Activate (bring to front) an application
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
pub fn activate_app(&self, app_name: &str) -> Result<()> {
|
||||||
|
use cocoa::base::id;
|
||||||
|
use objc::{class, msg_send, sel, sel_impl};
|
||||||
|
|
||||||
|
// Find the app
|
||||||
|
let apps = Self::get_running_applications()?;
|
||||||
|
let app = apps
|
||||||
|
.iter()
|
||||||
|
.find(|a| a.name == app_name)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Application '{}' not found", app_name))?;
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
let workspace: id = msg_send![class!(NSWorkspace), sharedWorkspace];
|
||||||
|
let running_apps: id = msg_send![workspace, runningApplications];
|
||||||
|
let count: usize = msg_send![running_apps, count];
|
||||||
|
|
||||||
|
for i in 0..count {
|
||||||
|
let running_app: id = msg_send![running_apps, objectAtIndex: i];
|
||||||
|
let pid: i32 = msg_send![running_app, processIdentifier];
|
||||||
|
|
||||||
|
if pid == app.pid {
|
||||||
|
let _: bool = msg_send![running_app, activateWithOptions: 0];
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::bail!("Failed to activate application")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(target_os = "macos"))]
|
||||||
|
pub fn activate_app(&self, _app_name: &str) -> Result<()> {
|
||||||
|
anyhow::bail!("Not supported on this platform")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the UI hierarchy of an application
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
pub fn get_ui_tree(&self, app_name: &str, max_depth: usize) -> Result<String> {
|
||||||
|
let app_element = self.get_app_element(app_name)?;
|
||||||
|
let mut output = format!("Application: {}\n", app_name);
|
||||||
|
|
||||||
|
Self::build_ui_tree(&app_element, &mut output, 0, max_depth)?;
|
||||||
|
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(target_os = "macos"))]
|
||||||
|
pub fn get_ui_tree(&self, _app_name: &str, _max_depth: usize) -> Result<String> {
|
||||||
|
anyhow::bail!("Not supported on this platform")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
fn build_ui_tree(
|
||||||
|
element: &AXUIElement,
|
||||||
|
output: &mut String,
|
||||||
|
depth: usize,
|
||||||
|
max_depth: usize,
|
||||||
|
) -> Result<()> {
|
||||||
|
if depth >= max_depth {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let indent = " ".repeat(depth);
|
||||||
|
|
||||||
|
// Get role
|
||||||
|
let role = element.role().ok().map(|s| s.to_string())
|
||||||
|
.unwrap_or_else(|| "Unknown".to_string());
|
||||||
|
|
||||||
|
// Get title
|
||||||
|
let title = element.title().ok()
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
|
||||||
|
// Get identifier
|
||||||
|
let identifier = element.identifier().ok()
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
|
||||||
|
// Format output
|
||||||
|
output.push_str(&format!("{}Role: {}", indent, role));
|
||||||
|
if let Some(t) = title {
|
||||||
|
output.push_str(&format!(", Title: {}", t));
|
||||||
|
}
|
||||||
|
if let Some(id) = identifier {
|
||||||
|
output.push_str(&format!(", ID: {}", id));
|
||||||
|
}
|
||||||
|
output.push('\n');
|
||||||
|
|
||||||
|
// Get children
|
||||||
|
if let Ok(children) = element.children() {
|
||||||
|
for i in 0..children.len() {
|
||||||
|
if let Some(child) = children.get(i) {
|
||||||
|
let _ = Self::build_ui_tree(&child, output, depth + 1, max_depth);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find UI elements in an application
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
pub fn find_elements(
|
||||||
|
&self,
|
||||||
|
app_name: &str,
|
||||||
|
role: Option<&str>,
|
||||||
|
title: Option<&str>,
|
||||||
|
identifier: Option<&str>,
|
||||||
|
) -> Result<Vec<AXElement>> {
|
||||||
|
let app_element = self.get_app_element(app_name)?;
|
||||||
|
let mut found_elements = Vec::new();
|
||||||
|
|
||||||
|
let visitor = ElementCollector {
|
||||||
|
role_filter: role.map(|s| s.to_string()),
|
||||||
|
title_filter: title.map(|s| s.to_string()),
|
||||||
|
identifier_filter: identifier.map(|s| s.to_string()),
|
||||||
|
results: std::cell::RefCell::new(&mut found_elements),
|
||||||
|
depth: std::cell::Cell::new(0),
|
||||||
|
};
|
||||||
|
|
||||||
|
let walker = TreeWalker::new();
|
||||||
|
walker.walk(&app_element, &visitor);
|
||||||
|
|
||||||
|
Ok(found_elements)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(target_os = "macos"))]
|
||||||
|
pub fn find_elements(
|
||||||
|
&self,
|
||||||
|
_app_name: &str,
|
||||||
|
_role: Option<&str>,
|
||||||
|
_title: Option<&str>,
|
||||||
|
_identifier: Option<&str>,
|
||||||
|
) -> Result<Vec<AXElement>> {
|
||||||
|
anyhow::bail!("Not supported on this platform")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find a single element (helper for click, set_value, etc.)
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
fn find_element(
|
||||||
|
&self,
|
||||||
|
app_name: &str,
|
||||||
|
role: &str,
|
||||||
|
title: Option<&str>,
|
||||||
|
identifier: Option<&str>,
|
||||||
|
) -> Result<AXUIElement> {
|
||||||
|
let app_element = self.get_app_element(app_name)?;
|
||||||
|
|
||||||
|
let role_str = role.to_string();
|
||||||
|
let title_str = title.map(|s| s.to_string());
|
||||||
|
let identifier_str = identifier.map(|s| s.to_string());
|
||||||
|
|
||||||
|
let finder = ElementFinder::new(
|
||||||
|
&app_element,
|
||||||
|
move |element| {
|
||||||
|
// Check role
|
||||||
|
let elem_role = element.role()
|
||||||
|
.ok()
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
|
||||||
|
if let Some(r) = elem_role {
|
||||||
|
if !r.contains(&role_str) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check title if specified
|
||||||
|
if let Some(ref title_filter) = title_str {
|
||||||
|
let elem_title = element.title()
|
||||||
|
.ok()
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
|
||||||
|
if let Some(t) = elem_title {
|
||||||
|
if !t.contains(title_filter) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check identifier if specified
|
||||||
|
if let Some(ref id_filter) = identifier_str {
|
||||||
|
let elem_id = element.identifier()
|
||||||
|
.ok()
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
|
||||||
|
if let Some(id) = elem_id {
|
||||||
|
if !id.contains(id_filter) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
true
|
||||||
|
},
|
||||||
|
Some(std::time::Duration::from_secs(2)),
|
||||||
|
);
|
||||||
|
|
||||||
|
finder.find().context("Element not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Click on a UI element
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
pub fn click_element(
|
||||||
|
&self,
|
||||||
|
app_name: &str,
|
||||||
|
role: &str,
|
||||||
|
title: Option<&str>,
|
||||||
|
identifier: Option<&str>,
|
||||||
|
) -> Result<()> {
|
||||||
|
let element = self.find_element(app_name, role, title, identifier)?;
|
||||||
|
|
||||||
|
// Perform the press action
|
||||||
|
let action_name = CFString::new("AXPress");
|
||||||
|
element
|
||||||
|
.perform_action(&action_name)
|
||||||
|
.map_err(|e| anyhow::anyhow!("Failed to perform press action: {:?}", e))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(target_os = "macos"))]
|
||||||
|
pub fn click_element(
|
||||||
|
&self,
|
||||||
|
_app_name: &str,
|
||||||
|
_role: &str,
|
||||||
|
_title: Option<&str>,
|
||||||
|
_identifier: Option<&str>,
|
||||||
|
) -> Result<()> {
|
||||||
|
anyhow::bail!("Not supported on this platform")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the value of a UI element
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
pub fn set_value(
|
||||||
|
&self,
|
||||||
|
app_name: &str,
|
||||||
|
role: &str,
|
||||||
|
value: &str,
|
||||||
|
title: Option<&str>,
|
||||||
|
identifier: Option<&str>,
|
||||||
|
) -> Result<()> {
|
||||||
|
let element = self.find_element(app_name, role, title, identifier)?;
|
||||||
|
|
||||||
|
// Set the value - convert CFString to CFType
|
||||||
|
let cf_value = CFString::new(value);
|
||||||
|
|
||||||
|
element.set_value(cf_value.as_CFType())
|
||||||
|
.map_err(|e| anyhow::anyhow!("Failed to set value: {:?}", e))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(target_os = "macos"))]
|
||||||
|
pub fn set_value(
|
||||||
|
&self,
|
||||||
|
_app_name: &str,
|
||||||
|
_role: &str,
|
||||||
|
_value: &str,
|
||||||
|
_title: Option<&str>,
|
||||||
|
_identifier: Option<&str>,
|
||||||
|
) -> Result<()> {
|
||||||
|
anyhow::bail!("Not supported on this platform")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the value of a UI element
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
pub fn get_value(
|
||||||
|
&self,
|
||||||
|
app_name: &str,
|
||||||
|
role: &str,
|
||||||
|
title: Option<&str>,
|
||||||
|
identifier: Option<&str>,
|
||||||
|
) -> Result<String> {
|
||||||
|
let element = self.find_element(app_name, role, title, identifier)?;
|
||||||
|
|
||||||
|
// Get the value
|
||||||
|
let value_type = element.value()
|
||||||
|
.map_err(|e| anyhow::anyhow!("Failed to get value: {:?}", e))?;
|
||||||
|
|
||||||
|
// Try to downcast to CFString
|
||||||
|
if let Some(cf_string) = value_type.downcast::<CFString>() {
|
||||||
|
Ok(cf_string.to_string())
|
||||||
|
} else {
|
||||||
|
// For non-string values, try to get a description
|
||||||
|
Ok(format!("<non-string value>"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(target_os = "macos"))]
|
||||||
|
pub fn get_value(
|
||||||
|
&self,
|
||||||
|
_app_name: &str,
|
||||||
|
_role: &str,
|
||||||
|
_title: Option<&str>,
|
||||||
|
_identifier: Option<&str>,
|
||||||
|
) -> Result<String> {
|
||||||
|
anyhow::bail!("Not supported on this platform")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Type text into the currently focused element (uses system text input)
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
pub fn type_text(&self, app_name: &str, text: &str) -> Result<()> {
|
||||||
|
use cocoa::base::{id, nil};
|
||||||
|
use cocoa::foundation::NSString;
|
||||||
|
use objc::{class, msg_send, sel, sel_impl};
|
||||||
|
|
||||||
|
// First, make sure the app is active
|
||||||
|
self.activate_app(app_name)?;
|
||||||
|
|
||||||
|
// Wait for app to fully activate
|
||||||
|
std::thread::sleep(std::time::Duration::from_millis(500));
|
||||||
|
|
||||||
|
// Send a Tab key to try to focus on a text field
|
||||||
|
// This helps ensure something is focused before we paste
|
||||||
|
let _ = self.press_key(app_name, "tab", vec![]);
|
||||||
|
std::thread::sleep(std::time::Duration::from_millis(800));
|
||||||
|
|
||||||
|
// Save old clipboard, set new content, paste, then restore
|
||||||
|
let old_content: id;
|
||||||
|
unsafe {
|
||||||
|
// Get the general pasteboard
|
||||||
|
let pasteboard: id = msg_send![class!(NSPasteboard), generalPasteboard];
|
||||||
|
|
||||||
|
// Save current clipboard content
|
||||||
|
let ns_string_type = NSString::alloc(nil).init_str("public.utf8-plain-text");
|
||||||
|
old_content = msg_send![pasteboard, stringForType: ns_string_type];
|
||||||
|
|
||||||
|
// Clear and set new content
|
||||||
|
let _: () = msg_send![pasteboard, clearContents];
|
||||||
|
|
||||||
|
let ns_string = NSString::alloc(nil).init_str(text);
|
||||||
|
let ns_type = NSString::alloc(nil).init_str("public.utf8-plain-text");
|
||||||
|
let _: bool = msg_send![pasteboard, setString:ns_string forType:ns_type];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait a moment for clipboard to update
|
||||||
|
std::thread::sleep(std::time::Duration::from_millis(200));
|
||||||
|
|
||||||
|
// Paste using Cmd+V (outside unsafe block)
|
||||||
|
self.press_key(app_name, "v", vec!["command"])?;
|
||||||
|
|
||||||
|
// Wait for paste to complete
|
||||||
|
std::thread::sleep(std::time::Duration::from_millis(300));
|
||||||
|
|
||||||
|
// Restore old clipboard content if it existed
|
||||||
|
unsafe {
|
||||||
|
if old_content != nil {
|
||||||
|
let pasteboard: id = msg_send![class!(NSPasteboard), generalPasteboard];
|
||||||
|
let _: () = msg_send![pasteboard, clearContents];
|
||||||
|
let ns_type = NSString::alloc(nil).init_str("public.utf8-plain-text");
|
||||||
|
let _: bool = msg_send![pasteboard, setString:old_content forType:ns_type];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(target_os = "macos"))]
|
||||||
|
pub fn type_text(&self, _app_name: &str, _text: &str) -> Result<()> {
|
||||||
|
anyhow::bail!("Not supported on this platform")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Focus on a text field or text area element
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
pub fn focus_element(
|
||||||
|
&self,
|
||||||
|
app_name: &str,
|
||||||
|
role: &str,
|
||||||
|
title: Option<&str>,
|
||||||
|
identifier: Option<&str>,
|
||||||
|
) -> Result<()> {
|
||||||
|
let element = self.find_element(app_name, role, title, identifier)?;
|
||||||
|
|
||||||
|
// Set focused attribute to true
|
||||||
|
use core_foundation::boolean::CFBoolean;
|
||||||
|
let cf_true = CFBoolean::true_value();
|
||||||
|
|
||||||
|
element.set_attribute(&accessibility::AXAttribute::focused(), cf_true)
|
||||||
|
.map_err(|e| anyhow::anyhow!("Failed to focus element: {:?}", e))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Press a keyboard shortcut
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
pub fn press_key(
|
||||||
|
&self,
|
||||||
|
app_name: &str,
|
||||||
|
key: &str,
|
||||||
|
modifiers: Vec<&str>,
|
||||||
|
) -> Result<()> {
|
||||||
|
use core_graphics::event::{
|
||||||
|
CGEvent, CGEventFlags, CGEventTapLocation,
|
||||||
|
};
|
||||||
|
use core_graphics::event_source::{CGEventSource, CGEventSourceStateID};
|
||||||
|
|
||||||
|
// First, make sure the app is active
|
||||||
|
self.activate_app(app_name)?;
|
||||||
|
|
||||||
|
// Wait a bit for activation
|
||||||
|
std::thread::sleep(std::time::Duration::from_millis(100));
|
||||||
|
|
||||||
|
// Map key string to key code
|
||||||
|
let key_code = Self::key_to_keycode(key)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Unknown key: {}", key))?;
|
||||||
|
|
||||||
|
// Map modifiers to flags
|
||||||
|
let mut flags = CGEventFlags::CGEventFlagNull;
|
||||||
|
for modifier in modifiers {
|
||||||
|
match modifier.to_lowercase().as_str() {
|
||||||
|
"command" | "cmd" => flags |= CGEventFlags::CGEventFlagCommand,
|
||||||
|
"option" | "alt" => flags |= CGEventFlags::CGEventFlagAlternate,
|
||||||
|
"control" | "ctrl" => flags |= CGEventFlags::CGEventFlagControl,
|
||||||
|
"shift" => flags |= CGEventFlags::CGEventFlagShift,
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create event source
|
||||||
|
let source = CGEventSource::new(CGEventSourceStateID::HIDSystemState)
|
||||||
|
.ok().context("Failed to create event source")?;
|
||||||
|
|
||||||
|
// Create key down event
|
||||||
|
let key_down = CGEvent::new_keyboard_event(source.clone(), key_code, true)
|
||||||
|
.ok().context("Failed to create key down event")?;
|
||||||
|
key_down.set_flags(flags);
|
||||||
|
|
||||||
|
// Create key up event
|
||||||
|
let key_up = CGEvent::new_keyboard_event(source, key_code, false)
|
||||||
|
.ok().context("Failed to create key up event")?;
|
||||||
|
key_up.set_flags(flags);
|
||||||
|
|
||||||
|
// Post events
|
||||||
|
key_down.post(CGEventTapLocation::HID);
|
||||||
|
std::thread::sleep(std::time::Duration::from_millis(50));
|
||||||
|
key_up.post(CGEventTapLocation::HID);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(target_os = "macos"))]
|
||||||
|
pub fn press_key(
|
||||||
|
&self,
|
||||||
|
_app_name: &str,
|
||||||
|
_key: &str,
|
||||||
|
_modifiers: Vec<&str>,
|
||||||
|
) -> Result<()> {
|
||||||
|
anyhow::bail!("Not supported on this platform")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
fn key_to_keycode(key: &str) -> Option<u16> {
|
||||||
|
// Map common keys to keycodes
|
||||||
|
// See: https://eastmanreference.com/complete-list-of-applescript-key-codes
|
||||||
|
match key.to_lowercase().as_str() {
|
||||||
|
"a" => Some(0x00),
|
||||||
|
"s" => Some(0x01),
|
||||||
|
"d" => Some(0x02),
|
||||||
|
"f" => Some(0x03),
|
||||||
|
"h" => Some(0x04),
|
||||||
|
"g" => Some(0x05),
|
||||||
|
"z" => Some(0x06),
|
||||||
|
"x" => Some(0x07),
|
||||||
|
"c" => Some(0x08),
|
||||||
|
"v" => Some(0x09),
|
||||||
|
"b" => Some(0x0B),
|
||||||
|
"q" => Some(0x0C),
|
||||||
|
"w" => Some(0x0D),
|
||||||
|
"e" => Some(0x0E),
|
||||||
|
"r" => Some(0x0F),
|
||||||
|
"y" => Some(0x10),
|
||||||
|
"t" => Some(0x11),
|
||||||
|
"1" => Some(0x12),
|
||||||
|
"2" => Some(0x13),
|
||||||
|
"3" => Some(0x14),
|
||||||
|
"4" => Some(0x15),
|
||||||
|
"6" => Some(0x16),
|
||||||
|
"5" => Some(0x17),
|
||||||
|
"=" => Some(0x18),
|
||||||
|
"9" => Some(0x19),
|
||||||
|
"7" => Some(0x1A),
|
||||||
|
"-" => Some(0x1B),
|
||||||
|
"8" => Some(0x1C),
|
||||||
|
"0" => Some(0x1D),
|
||||||
|
"]" => Some(0x1E),
|
||||||
|
"o" => Some(0x1F),
|
||||||
|
"u" => Some(0x20),
|
||||||
|
"[" => Some(0x21),
|
||||||
|
"i" => Some(0x22),
|
||||||
|
"p" => Some(0x23),
|
||||||
|
"return" | "enter" => Some(0x24),
|
||||||
|
"l" => Some(0x25),
|
||||||
|
"j" => Some(0x26),
|
||||||
|
"'" => Some(0x27),
|
||||||
|
"k" => Some(0x28),
|
||||||
|
";" => Some(0x29),
|
||||||
|
"\\" => Some(0x2A),
|
||||||
|
"," => Some(0x2B),
|
||||||
|
"/" => Some(0x2C),
|
||||||
|
"n" => Some(0x2D),
|
||||||
|
"m" => Some(0x2E),
|
||||||
|
"." => Some(0x2F),
|
||||||
|
"tab" => Some(0x30),
|
||||||
|
"space" => Some(0x31),
|
||||||
|
"`" => Some(0x32),
|
||||||
|
"delete" | "backspace" => Some(0x33),
|
||||||
|
"escape" | "esc" => Some(0x35),
|
||||||
|
"f1" => Some(0x7A),
|
||||||
|
"f2" => Some(0x78),
|
||||||
|
"f3" => Some(0x63),
|
||||||
|
"f4" => Some(0x76),
|
||||||
|
"f5" => Some(0x60),
|
||||||
|
"f6" => Some(0x61),
|
||||||
|
"f7" => Some(0x62),
|
||||||
|
"f8" => Some(0x64),
|
||||||
|
"f9" => Some(0x65),
|
||||||
|
"f10" => Some(0x6D),
|
||||||
|
"f11" => Some(0x67),
|
||||||
|
"f12" => Some(0x6F),
|
||||||
|
"left" => Some(0x7B),
|
||||||
|
"right" => Some(0x7C),
|
||||||
|
"down" => Some(0x7D),
|
||||||
|
"up" => Some(0x7E),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
struct ElementCollector<'a> {
|
||||||
|
role_filter: Option<String>,
|
||||||
|
title_filter: Option<String>,
|
||||||
|
identifier_filter: Option<String>,
|
||||||
|
results: std::cell::RefCell<&'a mut Vec<AXElement>>,
|
||||||
|
depth: std::cell::Cell<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
impl<'a> TreeVisitor for ElementCollector<'a> {
|
||||||
|
fn enter_element(&self, element: &AXUIElement) -> TreeWalkerFlow {
|
||||||
|
self.depth.set(self.depth.get() + 1);
|
||||||
|
|
||||||
|
if self.depth.get() > 20 {
|
||||||
|
return TreeWalkerFlow::SkipSubtree;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get element properties
|
||||||
|
let role = element.role()
|
||||||
|
.ok()
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
.unwrap_or_else(|| "Unknown".to_string());
|
||||||
|
|
||||||
|
let title = element.title()
|
||||||
|
.ok()
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
|
||||||
|
let identifier = element.identifier()
|
||||||
|
.ok()
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
|
||||||
|
// Check if this element matches the filters
|
||||||
|
let role_matches = self.role_filter.as_ref().map_or(true, |r| role.contains(r));
|
||||||
|
let title_matches = self.title_filter.as_ref().map_or(true, |t| {
|
||||||
|
title.as_ref().map_or(false, |title_str| title_str.contains(t))
|
||||||
|
});
|
||||||
|
let identifier_matches = self.identifier_filter.as_ref().map_or(true, |id| {
|
||||||
|
identifier.as_ref().map_or(false, |id_str| id_str.contains(id))
|
||||||
|
});
|
||||||
|
|
||||||
|
if role_matches && title_matches && identifier_matches {
|
||||||
|
// Get additional properties
|
||||||
|
let value = element.value()
|
||||||
|
.ok()
|
||||||
|
.and_then(|v| {
|
||||||
|
v.downcast::<CFString>().map(|s| s.to_string())
|
||||||
|
});
|
||||||
|
|
||||||
|
let label = element.description()
|
||||||
|
.ok()
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
|
||||||
|
let enabled = element.enabled()
|
||||||
|
.ok()
|
||||||
|
.map(|b| b.into())
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
let focused = element.focused()
|
||||||
|
.ok()
|
||||||
|
.map(|b| b.into())
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
// Count children
|
||||||
|
let children_count = element.children()
|
||||||
|
.ok()
|
||||||
|
.map(|arr| arr.len() as usize)
|
||||||
|
.unwrap_or(0);
|
||||||
|
|
||||||
|
self.results.borrow_mut().push(AXElement {
|
||||||
|
role,
|
||||||
|
title,
|
||||||
|
value,
|
||||||
|
label,
|
||||||
|
identifier,
|
||||||
|
enabled,
|
||||||
|
focused,
|
||||||
|
position: None,
|
||||||
|
size: None,
|
||||||
|
children_count,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
TreeWalkerFlow::Continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fn exit_element(&self, _element: &AXUIElement) {
|
||||||
|
self.depth.set(self.depth.get() - 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
65
crates/g3-computer-control/src/macax/mod.rs
Normal file
65
crates/g3-computer-control/src/macax/mod.rs
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
pub mod controller;
|
||||||
|
|
||||||
|
pub use controller::MacAxController;
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests;
|
||||||
|
|
||||||
|
/// Represents an accessibility element in the UI hierarchy
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct AXElement {
|
||||||
|
pub role: String,
|
||||||
|
pub title: Option<String>,
|
||||||
|
pub value: Option<String>,
|
||||||
|
pub label: Option<String>,
|
||||||
|
pub identifier: Option<String>,
|
||||||
|
pub enabled: bool,
|
||||||
|
pub focused: bool,
|
||||||
|
pub position: Option<(f64, f64)>,
|
||||||
|
pub size: Option<(f64, f64)>,
|
||||||
|
pub children_count: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Represents a macOS application
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct AXApplication {
|
||||||
|
pub name: String,
|
||||||
|
pub bundle_id: Option<String>,
|
||||||
|
pub pid: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AXElement {
|
||||||
|
/// Convert to a human-readable string representation
|
||||||
|
pub fn to_string(&self) -> String {
|
||||||
|
let mut parts = vec![format!("Role: {}", self.role)];
|
||||||
|
|
||||||
|
if let Some(ref title) = self.title {
|
||||||
|
parts.push(format!("Title: {}", title));
|
||||||
|
}
|
||||||
|
if let Some(ref value) = self.value {
|
||||||
|
parts.push(format!("Value: {}", value));
|
||||||
|
}
|
||||||
|
if let Some(ref label) = self.label {
|
||||||
|
parts.push(format!("Label: {}", label));
|
||||||
|
}
|
||||||
|
if let Some(ref id) = self.identifier {
|
||||||
|
parts.push(format!("ID: {}", id));
|
||||||
|
}
|
||||||
|
|
||||||
|
parts.push(format!("Enabled: {}", self.enabled));
|
||||||
|
parts.push(format!("Focused: {}", self.focused));
|
||||||
|
|
||||||
|
if let Some((x, y)) = self.position {
|
||||||
|
parts.push(format!("Position: ({:.0}, {:.0})", x, y));
|
||||||
|
}
|
||||||
|
if let Some((w, h)) = self.size {
|
||||||
|
parts.push(format!("Size: ({:.0}, {:.0})", w, h));
|
||||||
|
}
|
||||||
|
|
||||||
|
parts.push(format!("Children: {}", self.children_count));
|
||||||
|
|
||||||
|
parts.join(", ")
|
||||||
|
}
|
||||||
|
}
|
||||||
37
crates/g3-computer-control/src/macax/tests.rs
Normal file
37
crates/g3-computer-control/src/macax/tests.rs
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use crate::{AXElement, MacAxController};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ax_element_to_string() {
|
||||||
|
let element = AXElement {
|
||||||
|
role: "button".to_string(),
|
||||||
|
title: Some("Click Me".to_string()),
|
||||||
|
value: None,
|
||||||
|
label: Some("Submit Button".to_string()),
|
||||||
|
identifier: Some("submitBtn".to_string()),
|
||||||
|
enabled: true,
|
||||||
|
focused: false,
|
||||||
|
position: Some((100.0, 200.0)),
|
||||||
|
size: Some((80.0, 30.0)),
|
||||||
|
children_count: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
let string_repr = element.to_string();
|
||||||
|
assert!(string_repr.contains("Role: button"));
|
||||||
|
assert!(string_repr.contains("Title: Click Me"));
|
||||||
|
assert!(string_repr.contains("Label: Submit Button"));
|
||||||
|
assert!(string_repr.contains("ID: submitBtn"));
|
||||||
|
assert!(string_repr.contains("Enabled: true"));
|
||||||
|
assert!(string_repr.contains("Position: (100, 200)"));
|
||||||
|
assert!(string_repr.contains("Size: (80, 30)"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_controller_creation() {
|
||||||
|
// Just test that we can create a controller
|
||||||
|
// Actual functionality requires macOS and permissions
|
||||||
|
let result = MacAxController::new();
|
||||||
|
assert!(result.is_ok());
|
||||||
|
}
|
||||||
|
}
|
||||||
26
crates/g3-computer-control/src/ocr/mod.rs
Normal file
26
crates/g3-computer-control/src/ocr/mod.rs
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
use crate::types::TextLocation;
|
||||||
|
use anyhow::Result;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
/// OCR engine trait for text recognition with bounding boxes
|
||||||
|
#[async_trait]
|
||||||
|
pub trait OCREngine: Send + Sync {
|
||||||
|
/// Extract text with locations from an image file
|
||||||
|
async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>>;
|
||||||
|
|
||||||
|
/// Get the name of the OCR engine
|
||||||
|
fn name(&self) -> &str;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Platform-specific modules
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
pub mod vision;
|
||||||
|
|
||||||
|
pub mod tesseract;
|
||||||
|
|
||||||
|
// Re-export the default OCR engine for the platform
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
pub use vision::AppleVisionOCR as DefaultOCR;
|
||||||
|
|
||||||
|
#[cfg(not(target_os = "macos"))]
|
||||||
|
pub use tesseract::TesseractOCR as DefaultOCR;
|
||||||
84
crates/g3-computer-control/src/ocr/tesseract.rs
Normal file
84
crates/g3-computer-control/src/ocr/tesseract.rs
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
use super::OCREngine;
|
||||||
|
use crate::types::TextLocation;
|
||||||
|
use anyhow::Result;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
/// Tesseract OCR engine (fallback/cross-platform)
|
||||||
|
pub struct TesseractOCR;
|
||||||
|
|
||||||
|
impl TesseractOCR {
|
||||||
|
pub fn new() -> Result<Self> {
|
||||||
|
// Check if tesseract is available
|
||||||
|
let tesseract_check = std::process::Command::new("which")
|
||||||
|
.arg("tesseract")
|
||||||
|
.output();
|
||||||
|
|
||||||
|
if tesseract_check.is_err() || !tesseract_check.as_ref().unwrap().status.success() {
|
||||||
|
anyhow::bail!("Tesseract OCR is not installed on your system.\n\n\
|
||||||
|
To install tesseract:\n macOS: brew install tesseract\n \
|
||||||
|
Linux: sudo apt-get install tesseract-ocr (Ubuntu/Debian)\n \
|
||||||
|
sudo yum install tesseract (RHEL/CentOS)\n \
|
||||||
|
Windows: Download from https://github.com/UB-Mannheim/tesseract/wiki\n\n\
|
||||||
|
After installation, restart your terminal and try again.");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl OCREngine for TesseractOCR {
|
||||||
|
async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>> {
|
||||||
|
// Use tesseract CLI with TSV output to get bounding boxes
|
||||||
|
let output = std::process::Command::new("tesseract")
|
||||||
|
.arg(path)
|
||||||
|
.arg("stdout")
|
||||||
|
.arg("tsv")
|
||||||
|
.output()
|
||||||
|
.map_err(|e| anyhow::anyhow!("Failed to run tesseract: {}", e))?;
|
||||||
|
|
||||||
|
if !output.status.success() {
|
||||||
|
anyhow::bail!("Tesseract failed: {}", String::from_utf8_lossy(&output.stderr));
|
||||||
|
}
|
||||||
|
|
||||||
|
let tsv_text = String::from_utf8_lossy(&output.stdout);
|
||||||
|
let mut locations = Vec::new();
|
||||||
|
|
||||||
|
// Parse TSV output (skip header line)
|
||||||
|
for (i, line) in tsv_text.lines().enumerate() {
|
||||||
|
if i == 0 { continue; } // Skip header
|
||||||
|
|
||||||
|
let parts: Vec<&str> = line.split('\t').collect();
|
||||||
|
if parts.len() >= 12 {
|
||||||
|
// TSV format: level, page_num, block_num, par_num, line_num, word_num,
|
||||||
|
// left, top, width, height, conf, text
|
||||||
|
if let (Ok(x), Ok(y), Ok(w), Ok(h), Ok(conf), text) = (
|
||||||
|
parts[6].parse::<i32>(),
|
||||||
|
parts[7].parse::<i32>(),
|
||||||
|
parts[8].parse::<i32>(),
|
||||||
|
parts[9].parse::<i32>(),
|
||||||
|
parts[10].parse::<f32>(),
|
||||||
|
parts[11],
|
||||||
|
) {
|
||||||
|
let trimmed = text.trim();
|
||||||
|
if !trimmed.is_empty() && conf > 0.0 {
|
||||||
|
locations.push(TextLocation {
|
||||||
|
text: trimmed.to_string(),
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
width: w,
|
||||||
|
height: h,
|
||||||
|
confidence: conf / 100.0, // Convert from 0-100 to 0-1
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(locations)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"Tesseract OCR"
|
||||||
|
}
|
||||||
|
}
|
||||||
103
crates/g3-computer-control/src/ocr/vision.rs
Normal file
103
crates/g3-computer-control/src/ocr/vision.rs
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
use super::OCREngine;
|
||||||
|
use crate::types::TextLocation;
|
||||||
|
use anyhow::{Result, Context};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use std::ffi::{CStr, CString};
|
||||||
|
use std::os::raw::{c_char, c_float, c_uint};
|
||||||
|
|
||||||
|
// FFI bindings to Swift VisionBridge
|
||||||
|
#[repr(C)]
|
||||||
|
struct VisionTextBox {
|
||||||
|
text: *const c_char,
|
||||||
|
text_len: c_uint,
|
||||||
|
x: i32,
|
||||||
|
y: i32,
|
||||||
|
width: i32,
|
||||||
|
height: i32,
|
||||||
|
confidence: c_float,
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
fn vision_recognize_text(
|
||||||
|
image_path: *const c_char,
|
||||||
|
image_path_len: c_uint,
|
||||||
|
out_boxes: *mut *mut std::ffi::c_void,
|
||||||
|
out_count: *mut c_uint,
|
||||||
|
) -> bool;
|
||||||
|
|
||||||
|
fn vision_free_boxes(boxes: *mut std::ffi::c_void, count: c_uint);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apple Vision Framework OCR engine
|
||||||
|
pub struct AppleVisionOCR;
|
||||||
|
|
||||||
|
impl AppleVisionOCR {
|
||||||
|
pub fn new() -> Result<Self> {
|
||||||
|
Ok(Self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl OCREngine for AppleVisionOCR {
|
||||||
|
async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>> {
|
||||||
|
// Convert path to C string
|
||||||
|
let c_path = CString::new(path)
|
||||||
|
.context("Failed to convert path to C string")?;
|
||||||
|
|
||||||
|
let mut boxes_ptr: *mut std::ffi::c_void = std::ptr::null_mut();
|
||||||
|
let mut count: c_uint = 0;
|
||||||
|
|
||||||
|
// Call Swift Vision API
|
||||||
|
let success = unsafe {
|
||||||
|
vision_recognize_text(
|
||||||
|
c_path.as_ptr(),
|
||||||
|
path.len() as c_uint,
|
||||||
|
&mut boxes_ptr,
|
||||||
|
&mut count,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
if !success || boxes_ptr.is_null() {
|
||||||
|
anyhow::bail!("Apple Vision OCR failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert C array to Rust Vec
|
||||||
|
let mut locations = Vec::new();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
let typed_boxes = boxes_ptr as *const VisionTextBox;
|
||||||
|
let boxes_slice = std::slice::from_raw_parts(typed_boxes, count as usize);
|
||||||
|
|
||||||
|
for box_data in boxes_slice {
|
||||||
|
// Convert C string to Rust String
|
||||||
|
let text = if !box_data.text.is_null() {
|
||||||
|
CStr::from_ptr(box_data.text)
|
||||||
|
.to_string_lossy()
|
||||||
|
.into_owned()
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
if !text.is_empty() {
|
||||||
|
locations.push(TextLocation {
|
||||||
|
text,
|
||||||
|
x: box_data.x,
|
||||||
|
y: box_data.y,
|
||||||
|
width: box_data.width,
|
||||||
|
height: box_data.height,
|
||||||
|
confidence: box_data.confidence,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Free the C array
|
||||||
|
vision_free_boxes(boxes_ptr, count);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(locations)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"Apple Vision Framework"
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -63,10 +63,15 @@ impl ComputerController for LinuxController {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn take_screenshot(&self, _path: &str, _region: Option<Rect>, _window_id: Option<&str>) -> Result<()> {
|
async fn take_screenshot(&self, _path: &str, _region: Option<Rect>, _window_id: Option<&str>) -> Result<()> {
|
||||||
|
// Enforce that window_id must be provided
|
||||||
|
if _window_id.is_none() {
|
||||||
|
anyhow::bail!("window_id is required. You must specify which window to capture (e.g., 'Firefox', 'Terminal', 'gedit'). Use list_windows to see available windows.");
|
||||||
|
}
|
||||||
|
|
||||||
anyhow::bail!("Linux implementation not yet available")
|
anyhow::bail!("Linux implementation not yet available")
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn extract_text_from_screen(&self, _region: Rect) -> Result<OCRResult> {
|
async fn extract_text_from_screen(&self, _region: Rect, _window_id: &str) -> Result<String> {
|
||||||
anyhow::bail!("Linux implementation not yet available")
|
anyhow::bail!("Linux implementation not yet available")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,22 +1,37 @@
|
|||||||
use crate::{ComputerController, types::Rect};
|
use crate::{ComputerController, types::{Rect, TextLocation}};
|
||||||
use anyhow::Result;
|
use crate::ocr::{OCREngine, DefaultOCR};
|
||||||
|
use anyhow::{Result, Context};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use tesseract::Tesseract;
|
use core_graphics::window::{kCGWindowListOptionOnScreenOnly, kCGNullWindowID, CGWindowListCopyWindowInfo};
|
||||||
|
use core_foundation::dictionary::CFDictionary;
|
||||||
|
use core_foundation::string::CFString;
|
||||||
|
use core_foundation::base::{TCFType, ToVoid};
|
||||||
|
use core_foundation::array::CFArray;
|
||||||
|
|
||||||
pub struct MacOSController {
|
pub struct MacOSController {
|
||||||
// Empty struct for now
|
ocr_engine: Box<dyn OCREngine>,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
ocr_name: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MacOSController {
|
impl MacOSController {
|
||||||
pub fn new() -> Result<Self> {
|
pub fn new() -> Result<Self> {
|
||||||
Ok(Self {})
|
let ocr = Box::new(DefaultOCR::new()?);
|
||||||
|
let ocr_name = ocr.name().to_string();
|
||||||
|
tracing::info!("Initialized macOS controller with OCR engine: {}", ocr_name);
|
||||||
|
Ok(Self { ocr_engine: ocr, ocr_name })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl ComputerController for MacOSController {
|
impl ComputerController for MacOSController {
|
||||||
async fn take_screenshot(&self, path: &str, region: Option<Rect>, window_id: Option<&str>) -> Result<()> {
|
async fn take_screenshot(&self, path: &str, region: Option<Rect>, window_id: Option<&str>) -> Result<()> {
|
||||||
|
// Enforce that window_id must be provided
|
||||||
|
if window_id.is_none() {
|
||||||
|
return Err(anyhow::anyhow!("window_id is required. You must specify which window to capture (e.g., 'Safari', 'Terminal', 'Google Chrome'). Use list_windows to see available windows."));
|
||||||
|
}
|
||||||
|
|
||||||
// Determine the temporary directory for screenshots
|
// Determine the temporary directory for screenshots
|
||||||
let temp_dir = std::env::var("TMPDIR")
|
let temp_dir = std::env::var("TMPDIR")
|
||||||
.or_else(|_| std::env::var("HOME").map(|h| format!("{}/tmp", h)))
|
.or_else(|_| std::env::var("HOME").map(|h| format!("{}/tmp", h)))
|
||||||
@@ -37,48 +52,134 @@ impl ComputerController for MacOSController {
|
|||||||
std::fs::create_dir_all(parent)?;
|
std::fs::create_dir_all(parent)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut cmd = std::process::Command::new("screencapture");
|
let app_name = window_id.unwrap(); // Safe because we checked is_none() above
|
||||||
|
|
||||||
// Add flags
|
// Get the window ID for the specified application
|
||||||
|
let cg_window_id = unsafe {
|
||||||
|
let window_list = CGWindowListCopyWindowInfo(
|
||||||
|
kCGWindowListOptionOnScreenOnly,
|
||||||
|
kCGNullWindowID
|
||||||
|
);
|
||||||
|
|
||||||
|
let array = CFArray::<CFDictionary>::wrap_under_create_rule(window_list);
|
||||||
|
let count = array.len();
|
||||||
|
|
||||||
|
let mut found_window_id: Option<(u32, String)> = None; // (id, owner)
|
||||||
|
let app_name_lower = app_name.to_lowercase();
|
||||||
|
|
||||||
|
for i in 0..count {
|
||||||
|
let dict = array.get(i).unwrap();
|
||||||
|
|
||||||
|
// Get owner name
|
||||||
|
let owner_key = CFString::from_static_string("kCGWindowOwnerName");
|
||||||
|
let owner: String = if let Some(value) = dict.find(owner_key.to_void()) {
|
||||||
|
let s: CFString = TCFType::wrap_under_get_rule(*value as *const _);
|
||||||
|
s.to_string()
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::debug!("Checking window: owner='{}', looking for '{}'", owner, app_name);
|
||||||
|
let owner_lower = owner.to_lowercase();
|
||||||
|
|
||||||
|
// Normalize by removing spaces for exact matching
|
||||||
|
let app_name_normalized = app_name_lower.replace(" ", "");
|
||||||
|
let owner_normalized = owner_lower.replace(" ", "");
|
||||||
|
|
||||||
|
// ONLY accept exact matches (case-insensitive, with or without spaces)
|
||||||
|
// This prevents "Goose" from matching "GooseStudio"
|
||||||
|
let is_match = owner_lower == app_name_lower || owner_normalized == app_name_normalized;
|
||||||
|
|
||||||
|
if is_match {
|
||||||
|
// Get window ID
|
||||||
|
let window_id_key = CFString::from_static_string("kCGWindowNumber");
|
||||||
|
if let Some(value) = dict.find(window_id_key.to_void()) {
|
||||||
|
let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _);
|
||||||
|
if let Some(id) = num.to_i64() {
|
||||||
|
// Get window layer to filter out menu bar windows
|
||||||
|
let layer_key = CFString::from_static_string("kCGWindowLayer");
|
||||||
|
let layer: i32 = if let Some(value) = dict.find(layer_key.to_void()) {
|
||||||
|
let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _);
|
||||||
|
num.to_i32().unwrap_or(0)
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get window bounds to verify it's a real window
|
||||||
|
let bounds_key = CFString::from_static_string("kCGWindowBounds");
|
||||||
|
let has_real_bounds = if let Some(value) = dict.find(bounds_key.to_void()) {
|
||||||
|
let bounds_dict: CFDictionary = TCFType::wrap_under_get_rule(*value as *const _);
|
||||||
|
let width_key = CFString::from_static_string("Width");
|
||||||
|
let height_key = CFString::from_static_string("Height");
|
||||||
|
|
||||||
|
if let (Some(w_val), Some(h_val)) = (
|
||||||
|
bounds_dict.find(width_key.to_void()),
|
||||||
|
bounds_dict.find(height_key.to_void()),
|
||||||
|
) {
|
||||||
|
let w_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*w_val as *const _);
|
||||||
|
let h_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*h_val as *const _);
|
||||||
|
let width = w_num.to_f64().unwrap_or(0.0);
|
||||||
|
let height = h_num.to_f64().unwrap_or(0.0);
|
||||||
|
// Real windows should be at least 100x100 pixels
|
||||||
|
width >= 100.0 && height >= 100.0
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
};
|
||||||
|
|
||||||
|
// Only accept windows that are:
|
||||||
|
// 1. At layer 0 (normal windows, not menu bar)
|
||||||
|
// 2. Have real bounds (width and height >= 100)
|
||||||
|
if layer == 0 && has_real_bounds {
|
||||||
|
tracing::info!("Found valid window: ID {} for app '{}' (layer={}, bounds valid)", id, owner, layer);
|
||||||
|
found_window_id = Some((id as u32, owner.clone()));
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
tracing::debug!("Skipping window ID {} for '{}': layer={}, has_real_bounds={}", id, owner, layer, has_real_bounds);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
found_window_id
|
||||||
|
};
|
||||||
|
|
||||||
|
let (cg_window_id, matched_owner) = cg_window_id.ok_or_else(|| {
|
||||||
|
anyhow::anyhow!("Could not find window for application '{}'. Use list_windows to see available windows.", app_name)
|
||||||
|
})?;
|
||||||
|
tracing::info!("Taking screenshot of window ID {} for app '{}'", cg_window_id, matched_owner);
|
||||||
|
|
||||||
|
// Use screencapture with the window ID for now
|
||||||
|
// TODO: Implement direct CGWindowListCreateImage approach with proper image saving
|
||||||
|
let mut cmd = std::process::Command::new("screencapture");
|
||||||
cmd.arg("-x"); // No sound
|
cmd.arg("-x"); // No sound
|
||||||
|
cmd.arg("-l");
|
||||||
|
cmd.arg(cg_window_id.to_string());
|
||||||
|
|
||||||
if let Some(region) = region {
|
if let Some(region) = region {
|
||||||
// Capture specific region: -R x,y,width,height
|
|
||||||
cmd.arg("-R");
|
cmd.arg("-R");
|
||||||
cmd.arg(format!("{},{},{},{}", region.x, region.y, region.width, region.height));
|
cmd.arg(format!("{},{},{},{}", region.x, region.y, region.width, region.height));
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(app_name) = window_id {
|
|
||||||
// Capture specific window by app name
|
|
||||||
// Use AppleScript to get window ID
|
|
||||||
let script = format!(r#"tell application "{}" to id of window 1"#, app_name);
|
|
||||||
let output = std::process::Command::new("osascript")
|
|
||||||
.arg("-e")
|
|
||||||
.arg(&script)
|
|
||||||
.output()?;
|
|
||||||
|
|
||||||
if output.status.success() {
|
|
||||||
let window_id_str = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
|
||||||
cmd.arg(format!("-l{}", window_id_str));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd.arg(&final_path);
|
cmd.arg(&final_path);
|
||||||
|
|
||||||
let screenshot_result = cmd.output()?;
|
let screenshot_result = cmd.output()?;
|
||||||
|
|
||||||
if !screenshot_result.status.success() {
|
if !screenshot_result.status.success() {
|
||||||
let stderr = String::from_utf8_lossy(&screenshot_result.stderr);
|
let stderr = String::from_utf8_lossy(&screenshot_result.stderr);
|
||||||
return Err(anyhow::anyhow!("screencapture failed: {}", stderr));
|
return Err(anyhow::anyhow!("screencapture failed for window {}: {}", cg_window_id, stderr));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn extract_text_from_screen(&self, region: Rect) -> Result<String> {
|
async fn extract_text_from_screen(&self, region: Rect, window_id: &str) -> Result<String> {
|
||||||
// Take screenshot of region first
|
// Take screenshot of region first
|
||||||
let temp_path = format!("/tmp/g3_ocr_{}.png", uuid::Uuid::new_v4());
|
let temp_path = format!("/tmp/g3_ocr_{}.png", uuid::Uuid::new_v4());
|
||||||
self.take_screenshot(&temp_path, Some(region), None).await?;
|
self.take_screenshot(&temp_path, Some(region), Some(window_id)).await?;
|
||||||
|
|
||||||
// Extract text from the screenshot
|
// Extract text from the screenshot
|
||||||
let result = self.extract_text_from_image(&temp_path).await?;
|
let result = self.extract_text_from_image(&temp_path).await?;
|
||||||
@@ -90,36 +191,317 @@ impl ComputerController for MacOSController {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn extract_text_from_image(&self, path: &str) -> Result<String> {
|
async fn extract_text_from_image(&self, path: &str) -> Result<String> {
|
||||||
// Check if tesseract is available on the system
|
// Extract all text and concatenate
|
||||||
let tesseract_check = std::process::Command::new("which")
|
let locations = self.ocr_engine.extract_text_with_locations(path).await?;
|
||||||
.arg("tesseract")
|
Ok(locations.iter().map(|loc| loc.text.as_str()).collect::<Vec<_>>().join(" "))
|
||||||
.output();
|
}
|
||||||
|
|
||||||
|
async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>> {
|
||||||
|
// Use the OCR engine
|
||||||
|
self.ocr_engine.extract_text_with_locations(path).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn find_text_in_app(&self, app_name: &str, search_text: &str) -> Result<Option<TextLocation>> {
|
||||||
|
// Take screenshot of specific app window
|
||||||
|
let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
|
||||||
|
let temp_path = format!("{}/tmp/g3_find_text_{}_{}.png", home, app_name, uuid::Uuid::new_v4());
|
||||||
|
self.take_screenshot(&temp_path, None, Some(app_name)).await?;
|
||||||
|
|
||||||
if tesseract_check.is_err() || !tesseract_check.as_ref().unwrap().status.success() {
|
// Get screenshot dimensions before we delete it
|
||||||
anyhow::bail!("Tesseract OCR is not installed on your system.\n\n\
|
let screenshot_dims = get_image_dimensions(&temp_path)?;
|
||||||
To install tesseract:\n macOS: brew install tesseract\n \
|
|
||||||
Linux: sudo apt-get install tesseract-ocr (Ubuntu/Debian)\n \
|
// Extract all text with locations
|
||||||
sudo yum install tesseract (RHEL/CentOS)\n \
|
let locations = self.extract_text_with_locations(&temp_path).await?;
|
||||||
Windows: Download from https://github.com/UB-Mannheim/tesseract/wiki\n\n\
|
|
||||||
After installation, restart your terminal and try again.");
|
// Get window bounds to calculate coordinate transformation
|
||||||
|
let window_bounds = self.get_window_bounds(app_name)?;
|
||||||
|
|
||||||
|
// Clean up temp file
|
||||||
|
let _ = std::fs::remove_file(&temp_path);
|
||||||
|
|
||||||
|
// Find matching text (case-insensitive)
|
||||||
|
let search_lower = search_text.to_lowercase();
|
||||||
|
for location in locations {
|
||||||
|
if location.text.to_lowercase().contains(&search_lower) {
|
||||||
|
// Transform coordinates from screenshot space to screen space
|
||||||
|
let transformed = transform_screenshot_to_screen_coords(
|
||||||
|
location,
|
||||||
|
window_bounds,
|
||||||
|
screenshot_dims,
|
||||||
|
);
|
||||||
|
return Ok(Some(transformed));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize Tesseract
|
Ok(None)
|
||||||
let tess = Tesseract::new(None, Some("eng"))
|
|
||||||
.map_err(|e| {
|
|
||||||
anyhow::anyhow!("Failed to initialize Tesseract: {}\n\n\
|
|
||||||
This usually means:\n1. Tesseract is not properly installed\n\
|
|
||||||
2. Language data files are missing\n\nTo fix:\n \
|
|
||||||
macOS: brew reinstall tesseract\n \
|
|
||||||
Linux: sudo apt-get install tesseract-ocr-eng\n \
|
|
||||||
Windows: Reinstall tesseract and ensure language files are included", e)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let text = tess.set_image(path)
|
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to load image '{}': {}", path, e))?
|
|
||||||
.get_text()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to extract text from image: {}", e))?;
|
|
||||||
|
|
||||||
Ok(text)
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
fn move_mouse(&self, x: i32, y: i32) -> Result<()> {
|
||||||
|
use core_graphics::event::{
|
||||||
|
CGEvent, CGEventTapLocation, CGEventType, CGMouseButton,
|
||||||
|
};
|
||||||
|
use core_graphics::event_source::{
|
||||||
|
CGEventSource, CGEventSourceStateID,
|
||||||
|
};
|
||||||
|
use core_graphics::geometry::CGPoint;
|
||||||
|
|
||||||
|
let source = CGEventSource::new(CGEventSourceStateID::HIDSystemState)
|
||||||
|
.ok().context("Failed to create event source")?;
|
||||||
|
|
||||||
|
let event = CGEvent::new_mouse_event(
|
||||||
|
source,
|
||||||
|
CGEventType::MouseMoved,
|
||||||
|
CGPoint::new(x as f64, y as f64),
|
||||||
|
CGMouseButton::Left,
|
||||||
|
).ok().context("Failed to create mouse event")?;
|
||||||
|
|
||||||
|
event.post(CGEventTapLocation::HID);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn click_at(&self, x: i32, y: i32, _app_name: Option<&str>) -> Result<()> {
|
||||||
|
use core_graphics::event::{
|
||||||
|
CGEvent, CGEventTapLocation, CGEventType, CGMouseButton,
|
||||||
|
};
|
||||||
|
use core_graphics::event_source::{
|
||||||
|
CGEventSource, CGEventSourceStateID,
|
||||||
|
};
|
||||||
|
use core_graphics::geometry::CGPoint;
|
||||||
|
use core_graphics::display::CGDisplay;
|
||||||
|
|
||||||
|
// IMPORTANT: Coordinates passed here are in NSScreen/CGWindowListCopyWindowInfo space
|
||||||
|
// (Y=0 at BOTTOM, increases UPWARD)
|
||||||
|
// But CGEvent uses a different coordinate system (Y=0 at TOP, increases DOWNWARD)
|
||||||
|
// We need to convert: CGEvent.y = screenHeight - NSScreen.y
|
||||||
|
|
||||||
|
let screen_height = CGDisplay::main().pixels_high() as i32;
|
||||||
|
let cgevent_x = x;
|
||||||
|
let cgevent_y = screen_height - y;
|
||||||
|
|
||||||
|
tracing::debug!("click_at: NSScreen coords ({}, {}) -> CGEvent coords ({}, {}) [screen_height={}]",
|
||||||
|
x, y, cgevent_x, cgevent_y, screen_height);
|
||||||
|
|
||||||
|
let (global_x, global_y) = (cgevent_x, cgevent_y);
|
||||||
|
|
||||||
|
let point = CGPoint::new(global_x as f64, global_y as f64);
|
||||||
|
|
||||||
|
let source = CGEventSource::new(CGEventSourceStateID::HIDSystemState)
|
||||||
|
.ok().context("Failed to create event source")?;
|
||||||
|
|
||||||
|
// Move mouse to position first
|
||||||
|
let move_event = CGEvent::new_mouse_event(
|
||||||
|
source.clone(),
|
||||||
|
CGEventType::MouseMoved,
|
||||||
|
point,
|
||||||
|
CGMouseButton::Left,
|
||||||
|
).ok().context("Failed to create mouse move event")?;
|
||||||
|
move_event.post(CGEventTapLocation::HID);
|
||||||
|
|
||||||
|
std::thread::sleep(std::time::Duration::from_millis(100));
|
||||||
|
|
||||||
|
// Mouse down
|
||||||
|
let mouse_down = CGEvent::new_mouse_event(
|
||||||
|
source.clone(),
|
||||||
|
CGEventType::LeftMouseDown,
|
||||||
|
point,
|
||||||
|
CGMouseButton::Left,
|
||||||
|
).ok().context("Failed to create mouse down event")?;
|
||||||
|
mouse_down.post(CGEventTapLocation::HID);
|
||||||
|
|
||||||
|
std::thread::sleep(std::time::Duration::from_millis(50));
|
||||||
|
|
||||||
|
// Mouse up
|
||||||
|
let mouse_up = CGEvent::new_mouse_event(
|
||||||
|
source,
|
||||||
|
CGEventType::LeftMouseUp,
|
||||||
|
point,
|
||||||
|
CGMouseButton::Left,
|
||||||
|
).ok().context("Failed to create mouse up event")?;
|
||||||
|
mouse_up.post(CGEventTapLocation::HID);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MacOSController {
|
||||||
|
/// Get window bounds for an application (helper method)
|
||||||
|
fn get_window_bounds(&self, app_name: &str) -> Result<(i32, i32, i32, i32)> {
|
||||||
|
unsafe {
|
||||||
|
let window_list = CGWindowListCopyWindowInfo(
|
||||||
|
kCGWindowListOptionOnScreenOnly,
|
||||||
|
kCGNullWindowID
|
||||||
|
);
|
||||||
|
|
||||||
|
let array = CFArray::<CFDictionary>::wrap_under_create_rule(window_list);
|
||||||
|
let count = array.len();
|
||||||
|
|
||||||
|
let app_name_lower = app_name.to_lowercase();
|
||||||
|
|
||||||
|
for i in 0..count {
|
||||||
|
let dict = array.get(i).unwrap();
|
||||||
|
|
||||||
|
// Get owner name
|
||||||
|
let owner_key = CFString::from_static_string("kCGWindowOwnerName");
|
||||||
|
let owner: String = if let Some(value) = dict.find(owner_key.to_void()) {
|
||||||
|
let s: CFString = TCFType::wrap_under_get_rule(*value as *const _);
|
||||||
|
s.to_string()
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
let owner_lower = owner.to_lowercase();
|
||||||
|
|
||||||
|
// Normalize by removing spaces for exact matching
|
||||||
|
let app_name_normalized = app_name_lower.replace(" ", "");
|
||||||
|
let owner_normalized = owner_lower.replace(" ", "");
|
||||||
|
|
||||||
|
// ONLY accept exact matches (case-insensitive, with or without spaces)
|
||||||
|
// This prevents "Goose" from matching "GooseStudio"
|
||||||
|
let is_match = owner_lower == app_name_lower || owner_normalized == app_name_normalized;
|
||||||
|
|
||||||
|
if is_match {
|
||||||
|
// Get window layer to filter out menu bar windows
|
||||||
|
let layer_key = CFString::from_static_string("kCGWindowLayer");
|
||||||
|
let layer: i32 = if let Some(value) = dict.find(layer_key.to_void()) {
|
||||||
|
let num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*value as *const _);
|
||||||
|
num.to_i32().unwrap_or(0)
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
|
||||||
|
// Skip menu bar windows (layer >= 20)
|
||||||
|
if layer >= 20 {
|
||||||
|
tracing::debug!("Skipping window for '{}' at layer {} (menu bar)", owner, layer);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get window bounds to verify it's a real window
|
||||||
|
let bounds_key = CFString::from_static_string("kCGWindowBounds");
|
||||||
|
if let Some(value) = dict.find(bounds_key.to_void()) {
|
||||||
|
let bounds_dict: CFDictionary = TCFType::wrap_under_get_rule(*value as *const _);
|
||||||
|
|
||||||
|
let x_key = CFString::from_static_string("X");
|
||||||
|
let y_key = CFString::from_static_string("Y");
|
||||||
|
let width_key = CFString::from_static_string("Width");
|
||||||
|
let height_key = CFString::from_static_string("Height");
|
||||||
|
|
||||||
|
if let (Some(x_val), Some(y_val), Some(w_val), Some(h_val)) = (
|
||||||
|
bounds_dict.find(x_key.to_void()),
|
||||||
|
bounds_dict.find(y_key.to_void()),
|
||||||
|
bounds_dict.find(width_key.to_void()),
|
||||||
|
bounds_dict.find(height_key.to_void()),
|
||||||
|
) {
|
||||||
|
let x_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*x_val as *const _);
|
||||||
|
let y_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*y_val as *const _);
|
||||||
|
let w_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*w_val as *const _);
|
||||||
|
let h_num: core_foundation::number::CFNumber = TCFType::wrap_under_get_rule(*h_val as *const _);
|
||||||
|
|
||||||
|
let x: i32 = x_num.to_i64().unwrap_or(0) as i32;
|
||||||
|
let y: i32 = y_num.to_i64().unwrap_or(0) as i32;
|
||||||
|
let w: i32 = w_num.to_i64().unwrap_or(0) as i32;
|
||||||
|
let h: i32 = h_num.to_i64().unwrap_or(0) as i32;
|
||||||
|
|
||||||
|
// Only accept windows with real bounds (>= 100x100 pixels)
|
||||||
|
if w >= 100 && h >= 100 {
|
||||||
|
tracing::info!("Found valid window bounds for '{}': x={}, y={}, w={}, h={} (layer={})", owner, x, y, w, h, layer);
|
||||||
|
return Ok((x, y, w, h));
|
||||||
|
} else {
|
||||||
|
tracing::debug!("Skipping window for '{}': too small ({}x{})", owner, w, h);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(anyhow::anyhow!("Could not find window bounds for '{}'", app_name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get image dimensions from a PNG file
|
||||||
|
fn get_image_dimensions(path: &str) -> Result<(i32, i32)> {
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io::Read;
|
||||||
|
|
||||||
|
let mut file = File::open(path)?;
|
||||||
|
let mut buffer = vec![0u8; 24];
|
||||||
|
file.read_exact(&mut buffer)?;
|
||||||
|
|
||||||
|
// PNG signature check
|
||||||
|
if &buffer[0..8] != b"\x89PNG\r\n\x1a\n" {
|
||||||
|
anyhow::bail!("Not a valid PNG file");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read IHDR chunk (width and height are at bytes 16-23)
|
||||||
|
let width = u32::from_be_bytes([buffer[16], buffer[17], buffer[18], buffer[19]]) as i32;
|
||||||
|
let height = u32::from_be_bytes([buffer[20], buffer[21], buffer[22], buffer[23]]) as i32;
|
||||||
|
|
||||||
|
Ok((width, height))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Transform coordinates from screenshot space to screen space
|
||||||
|
///
|
||||||
|
/// The screenshot is taken of a window, and Vision OCR returns coordinates
|
||||||
|
/// relative to the screenshot image. We need to transform these to actual
|
||||||
|
/// screen coordinates for clicking.
|
||||||
|
///
|
||||||
|
/// On Retina displays, screenshots are taken at 2x resolution, so we need
|
||||||
|
/// to account for this scaling factor.
|
||||||
|
fn transform_screenshot_to_screen_coords(
|
||||||
|
location: TextLocation,
|
||||||
|
window_bounds: (i32, i32, i32, i32), // (x, y, width, height) in screen space
|
||||||
|
screenshot_dims: (i32, i32), // (width, height) in pixels
|
||||||
|
) -> TextLocation {
|
||||||
|
let (win_x, win_y, win_width, win_height) = window_bounds;
|
||||||
|
let (screenshot_width, screenshot_height) = screenshot_dims;
|
||||||
|
|
||||||
|
// Calculate scale factors
|
||||||
|
// On Retina displays, screenshot is typically 2x the window size
|
||||||
|
let scale_x = win_width as f64 / screenshot_width as f64;
|
||||||
|
let scale_y = win_height as f64 / screenshot_height as f64;
|
||||||
|
|
||||||
|
tracing::debug!("Transform: screenshot={}x{}, window={}x{} at ({},{}), scale=({:.2},{:.2})",
|
||||||
|
screenshot_width, screenshot_height, win_width, win_height, win_x, win_y, scale_x, scale_y);
|
||||||
|
|
||||||
|
// Transform coordinates from image space to screen space
|
||||||
|
// IMPORTANT: macOS screen coordinates have origin at BOTTOM-LEFT (Y increases upward)
|
||||||
|
// Image coordinates have origin at TOP-LEFT (Y increases downward)
|
||||||
|
// win_y is the BOTTOM of the window in screen coordinates
|
||||||
|
// So we need to: (win_y + win_height) to get window TOP, then subtract screenshot_y
|
||||||
|
let window_top_y = win_y + win_height;
|
||||||
|
|
||||||
|
tracing::debug!("[transform] Input location in image space: x={}, y={}, width={}, height={}",
|
||||||
|
location.x, location.y, location.width, location.height);
|
||||||
|
tracing::debug!("[transform] Scale factors: scale_x={:.4}, scale_y={:.4}", scale_x, scale_y);
|
||||||
|
|
||||||
|
let transformed_x = win_x + (location.x as f64 * scale_x) as i32;
|
||||||
|
let transformed_y = window_top_y - (location.y as f64 * scale_y) as i32;
|
||||||
|
let transformed_width = (location.width as f64 * scale_x) as i32;
|
||||||
|
let transformed_height = (location.height as f64 * scale_y) as i32;
|
||||||
|
|
||||||
|
tracing::debug!("[transform] Calculation details:");
|
||||||
|
tracing::debug!(" - transformed_x = {} + ({} * {:.4}) = {} + {:.2} = {}", win_x, location.x, scale_x, win_x, location.x as f64 * scale_x, transformed_x);
|
||||||
|
tracing::debug!(" - transformed_width = ({} * {:.4}) = {:.2} -> {}", location.width, scale_x, location.width as f64 * scale_x, transformed_width);
|
||||||
|
tracing::debug!(" - transformed_height = ({} * {:.4}) = {:.2} -> {}", location.height, scale_y, location.height as f64 * scale_y, transformed_height);
|
||||||
|
|
||||||
|
tracing::debug!("Transformed location: screenshot=({},{}) {}x{} -> screen=({},{}) {}x{}",
|
||||||
|
location.x, location.y, location.width, location.height,
|
||||||
|
transformed_x, transformed_y, transformed_width, transformed_height);
|
||||||
|
|
||||||
|
TextLocation {
|
||||||
|
text: location.text,
|
||||||
|
x: transformed_x,
|
||||||
|
y: transformed_y,
|
||||||
|
width: transformed_width,
|
||||||
|
height: transformed_height,
|
||||||
|
confidence: location.confidence,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[path = "macos_window_matching_test.rs"]
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests;
|
||||||
@@ -1,425 +0,0 @@
|
|||||||
use crate::{ComputerController, types::*};
|
|
||||||
use anyhow::Result;
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use core_graphics::display::CGPoint;
|
|
||||||
use core_graphics::event::{CGEvent, CGEventType, CGMouseButton, CGEventTapLocation};
|
|
||||||
use core_graphics::event_source::{CGEventSource, CGEventSourceStateID};
|
|
||||||
use std::path::Path;
|
|
||||||
use tesseract::Tesseract;
|
|
||||||
|
|
||||||
// MacOSController doesn't store CGEventSource to avoid Send/Sync issues
|
|
||||||
// We create it fresh for each operation
|
|
||||||
pub struct MacOSController {
|
|
||||||
// Empty struct - event source created per operation
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MacOSController {
|
|
||||||
pub fn new() -> Result<Self> {
|
|
||||||
// Test that we can create an event source
|
|
||||||
let _event_source = CGEventSource::new(CGEventSourceStateID::CombinedSessionState)
|
|
||||||
.map_err(|_| anyhow::anyhow!("Failed to create event source. Make sure Accessibility permissions are granted."))?;
|
|
||||||
Ok(Self {})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn key_to_keycode(&self, key: &str) -> Result<u16> {
|
|
||||||
// Map key names to macOS keycodes
|
|
||||||
let keycode = match key.to_lowercase().as_str() {
|
|
||||||
"return" | "enter" => 36,
|
|
||||||
"tab" => 48,
|
|
||||||
"space" => 49,
|
|
||||||
"delete" | "backspace" => 51,
|
|
||||||
"escape" | "esc" => 53,
|
|
||||||
"command" | "cmd" => 55,
|
|
||||||
"shift" => 56,
|
|
||||||
"capslock" => 57,
|
|
||||||
"option" | "alt" => 58,
|
|
||||||
"control" | "ctrl" => 59,
|
|
||||||
"left" => 123,
|
|
||||||
"right" => 124,
|
|
||||||
"down" => 125,
|
|
||||||
"up" => 126,
|
|
||||||
_ => anyhow::bail!("Unknown key: {}", key),
|
|
||||||
};
|
|
||||||
Ok(keycode)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl ComputerController for MacOSController {
|
|
||||||
async fn move_mouse(&self, x: i32, y: i32) -> Result<()> {
|
|
||||||
let event_source = CGEventSource::new(CGEventSourceStateID::CombinedSessionState)
|
|
||||||
.map_err(|_| anyhow::anyhow!("Failed to create event source"))?;
|
|
||||||
let point = CGPoint::new(x as f64, y as f64);
|
|
||||||
let event = CGEvent::new_mouse_event(
|
|
||||||
event_source,
|
|
||||||
CGEventType::MouseMoved,
|
|
||||||
point,
|
|
||||||
CGMouseButton::Left,
|
|
||||||
).map_err(|_| anyhow::anyhow!("Failed to create mouse move event"))?;
|
|
||||||
|
|
||||||
event.post(CGEventTapLocation::HID);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn click(&self, button: MouseButton) -> Result<()> {
|
|
||||||
let (cg_button, down_type, up_type) = match button {
|
|
||||||
MouseButton::Left => (CGMouseButton::Left, CGEventType::LeftMouseDown, CGEventType::LeftMouseUp),
|
|
||||||
MouseButton::Right => (CGMouseButton::Right, CGEventType::RightMouseDown, CGEventType::RightMouseUp),
|
|
||||||
MouseButton::Middle => (CGMouseButton::Center, CGEventType::OtherMouseDown, CGEventType::OtherMouseUp),
|
|
||||||
};
|
|
||||||
|
|
||||||
let point = {
|
|
||||||
// Get current mouse position
|
|
||||||
let temp_source = CGEventSource::new(CGEventSourceStateID::CombinedSessionState)
|
|
||||||
.map_err(|_| anyhow::anyhow!("Failed to create event source"))?;
|
|
||||||
let event = CGEvent::new(temp_source)
|
|
||||||
.map_err(|_| anyhow::anyhow!("Failed to get mouse position"))?;
|
|
||||||
let p = event.location();
|
|
||||||
p
|
|
||||||
};
|
|
||||||
|
|
||||||
{
|
|
||||||
let event_source = CGEventSource::new(CGEventSourceStateID::CombinedSessionState)
|
|
||||||
.map_err(|_| anyhow::anyhow!("Failed to create event source"))?;
|
|
||||||
|
|
||||||
// Mouse down
|
|
||||||
let down_event = CGEvent::new_mouse_event(
|
|
||||||
event_source,
|
|
||||||
down_type,
|
|
||||||
point,
|
|
||||||
cg_button,
|
|
||||||
).map_err(|_| anyhow::anyhow!("Failed to create mouse down event"))?;
|
|
||||||
down_event.post(CGEventTapLocation::HID);
|
|
||||||
} // event_source and down_event dropped here
|
|
||||||
|
|
||||||
// Small delay
|
|
||||||
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
|
|
||||||
|
|
||||||
{
|
|
||||||
let event_source = CGEventSource::new(CGEventSourceStateID::CombinedSessionState)
|
|
||||||
.map_err(|_| anyhow::anyhow!("Failed to create event source"))?;
|
|
||||||
|
|
||||||
let up_event = CGEvent::new_mouse_event(
|
|
||||||
event_source,
|
|
||||||
up_type,
|
|
||||||
point,
|
|
||||||
cg_button,
|
|
||||||
).map_err(|_| anyhow::anyhow!("Failed to create mouse up event"))?;
|
|
||||||
up_event.post(CGEventTapLocation::HID);
|
|
||||||
} // event_source and up_event dropped here
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn double_click(&self, button: MouseButton) -> Result<()> {
|
|
||||||
self.click(button).await?;
|
|
||||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
|
||||||
self.click(button).await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn type_text(&self, text: &str) -> Result<()> {
|
|
||||||
for ch in text.chars() {
|
|
||||||
{
|
|
||||||
let event_source = CGEventSource::new(CGEventSourceStateID::CombinedSessionState)
|
|
||||||
.map_err(|_| anyhow::anyhow!("Failed to create event source"))?;
|
|
||||||
|
|
||||||
// Create keyboard event for character
|
|
||||||
let event = CGEvent::new_keyboard_event(
|
|
||||||
event_source,
|
|
||||||
0, // keycode (0 for unicode)
|
|
||||||
true,
|
|
||||||
).map_err(|_| anyhow::anyhow!("Failed to create keyboard event"))?;
|
|
||||||
|
|
||||||
// Set unicode string
|
|
||||||
let mut utf16_buf = [0u16; 2];
|
|
||||||
let utf16_slice = ch.encode_utf16(&mut utf16_buf);
|
|
||||||
let utf16_chars: Vec<u16> = utf16_slice.iter().copied().collect();
|
|
||||||
|
|
||||||
event.set_string_from_utf16_unchecked(utf16_chars.as_slice());
|
|
||||||
event.post(CGEventTapLocation::HID);
|
|
||||||
} // event_source and event dropped here
|
|
||||||
|
|
||||||
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn press_key(&self, key: &str) -> Result<()> {
|
|
||||||
let keycode = self.key_to_keycode(key)?;
|
|
||||||
|
|
||||||
{
|
|
||||||
let event_source = CGEventSource::new(CGEventSourceStateID::CombinedSessionState)
|
|
||||||
.map_err(|_| anyhow::anyhow!("Failed to create event source"))?;
|
|
||||||
|
|
||||||
// Key down
|
|
||||||
let down_event = CGEvent::new_keyboard_event(
|
|
||||||
event_source,
|
|
||||||
keycode,
|
|
||||||
true,
|
|
||||||
).map_err(|_| anyhow::anyhow!("Failed to create key down event"))?;
|
|
||||||
down_event.post(CGEventTapLocation::HID);
|
|
||||||
} // event_source and down_event dropped here
|
|
||||||
|
|
||||||
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
|
|
||||||
|
|
||||||
{
|
|
||||||
let event_source = CGEventSource::new(CGEventSourceStateID::CombinedSessionState)
|
|
||||||
.map_err(|_| anyhow::anyhow!("Failed to create event source"))?;
|
|
||||||
|
|
||||||
// Key up
|
|
||||||
let up_event = CGEvent::new_keyboard_event(
|
|
||||||
event_source,
|
|
||||||
keycode,
|
|
||||||
false,
|
|
||||||
).map_err(|_| anyhow::anyhow!("Failed to create key up event"))?;
|
|
||||||
up_event.post(CGEventTapLocation::HID);
|
|
||||||
} // event_source and up_event dropped here
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn list_windows(&self) -> Result<Vec<Window>> {
|
|
||||||
// Note: Full implementation would use CGWindowListCopyWindowInfo
|
|
||||||
// For now, return empty list as this requires more complex FFI
|
|
||||||
tracing::warn!("list_windows not fully implemented on macOS");
|
|
||||||
Ok(vec![])
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn focus_window(&self, _window_id: &str) -> Result<()> {
|
|
||||||
// Note: Full implementation would use NSWorkspace to activate application
|
|
||||||
tracing::warn!("focus_window not fully implemented on macOS");
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_window_bounds(&self, _window_id: &str) -> Result<Rect> {
|
|
||||||
// Note: Full implementation would use Accessibility API
|
|
||||||
tracing::warn!("get_window_bounds not fully implemented on macOS");
|
|
||||||
Ok(Rect { x: 0, y: 0, width: 800, height: 600 })
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn find_element(&self, _selector: &ElementSelector) -> Result<Option<UIElement>> {
|
|
||||||
// Note: Full implementation would use macOS Accessibility API
|
|
||||||
tracing::warn!("find_element not fully implemented on macOS");
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_element_text(&self, _element_id: &str) -> Result<String> {
|
|
||||||
// Note: Full implementation would use Accessibility API
|
|
||||||
tracing::warn!("get_element_text not fully implemented on macOS");
|
|
||||||
Ok(String::new())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_element_bounds(&self, _element_id: &str) -> Result<Rect> {
|
|
||||||
// Note: Full implementation would use Accessibility API
|
|
||||||
tracing::warn!("get_element_bounds not fully implemented on macOS");
|
|
||||||
Ok(Rect { x: 0, y: 0, width: 100, height: 30 })
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn take_screenshot(&self, path: &str, _region: Option<Rect>, window_id: Option<&str>) -> Result<()> {
|
|
||||||
// Use native macOS screencapture command which handles all the format complexities
|
|
||||||
|
|
||||||
// Check if we have Screen Recording permission by attempting a test capture
|
|
||||||
// If we only get wallpaper/menubar but no windows, we need permission
|
|
||||||
let needs_permission_check = std::env::var("G3_SKIP_PERMISSION_CHECK").is_err();
|
|
||||||
|
|
||||||
if needs_permission_check {
|
|
||||||
// Try to open Screen Recording settings if this is the first screenshot
|
|
||||||
static PERMISSION_PROMPTED: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
|
|
||||||
|
|
||||||
if !PERMISSION_PROMPTED.swap(true, std::sync::atomic::Ordering::Relaxed) {
|
|
||||||
tracing::warn!("\n=== Screen Recording Permission Required ===\n\
|
|
||||||
macOS requires explicit permission to capture window content.\n\
|
|
||||||
If screenshots only show wallpaper/menubar (no windows):\n\n\
|
|
||||||
1. Open System Settings > Privacy & Security > Screen Recording\n\
|
|
||||||
2. Enable permission for your terminal (iTerm/Terminal) or g3\n\
|
|
||||||
3. Restart your terminal if needed\n\n\
|
|
||||||
Opening Screen Recording settings now...\n");
|
|
||||||
|
|
||||||
// Try to open the settings (non-blocking)
|
|
||||||
let _ = std::process::Command::new("open")
|
|
||||||
.arg("x-apple.systempreferences:com.apple.preference.security?Privacy_ScreenCapture")
|
|
||||||
.spawn();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let path_obj = Path::new(path);
|
|
||||||
if let Some(parent) = path_obj.parent() {
|
|
||||||
std::fs::create_dir_all(parent)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut cmd = std::process::Command::new("screencapture");
|
|
||||||
|
|
||||||
// Add flags
|
|
||||||
cmd.arg("-x"); // No sound
|
|
||||||
|
|
||||||
if let Some(window_id) = window_id {
|
|
||||||
// Capture specific window by getting its bounds and using region capture
|
|
||||||
// window_id format: "AppName" or "AppName:WindowTitle"
|
|
||||||
let app_name = window_id.split(':').next().unwrap_or(window_id);
|
|
||||||
|
|
||||||
// Use AppleScript to get window bounds
|
|
||||||
let script = format!(
|
|
||||||
r#"tell application "{}"
|
|
||||||
tell current window
|
|
||||||
get bounds
|
|
||||||
end tell
|
|
||||||
end tell"#,
|
|
||||||
app_name
|
|
||||||
);
|
|
||||||
|
|
||||||
let output = std::process::Command::new("osascript")
|
|
||||||
.arg("-e")
|
|
||||||
.arg(&script)
|
|
||||||
.output()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to get window bounds: {}", e))?;
|
|
||||||
|
|
||||||
if output.status.success() {
|
|
||||||
let bounds_str = String::from_utf8_lossy(&output.stdout);
|
|
||||||
let bounds: Vec<i32> = bounds_str
|
|
||||||
.trim()
|
|
||||||
.split(',')
|
|
||||||
.filter_map(|s| s.trim().parse().ok())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
if bounds.len() == 4 {
|
|
||||||
let (left, top, right, bottom) = (bounds[0], bounds[1], bounds[2], bounds[3]);
|
|
||||||
let width = right - left;
|
|
||||||
let height = bottom - top;
|
|
||||||
|
|
||||||
cmd.arg("-R");
|
|
||||||
cmd.arg(format!("{},{},{},{}", left, top, width, height));
|
|
||||||
|
|
||||||
tracing::debug!("Capturing window '{}' at region: {},{} {}x{}", app_name, left, top, width, height);
|
|
||||||
} else {
|
|
||||||
tracing::warn!("Failed to parse window bounds, capturing full screen");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
tracing::warn!("Failed to get window bounds for '{}', capturing full screen", app_name);
|
|
||||||
}
|
|
||||||
} else if let Some(region) = _region {
|
|
||||||
// Capture specific region: -R x,y,width,height
|
|
||||||
cmd.arg("-R");
|
|
||||||
cmd.arg(format!("{},{},{},{}", region.x, region.y, region.width, region.height));
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd.arg(path);
|
|
||||||
|
|
||||||
let output = cmd.output()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to execute screencapture: {}", e))?;
|
|
||||||
|
|
||||||
if !output.status.success() {
|
|
||||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
|
||||||
anyhow::bail!("screencapture failed: {}", stderr);
|
|
||||||
}
|
|
||||||
|
|
||||||
tracing::debug!("Screenshot saved using screencapture: {}", path);
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn extract_text_from_screen(&self, region: Rect) -> Result<OCRResult> {
|
|
||||||
// Take screenshot of region first
|
|
||||||
let temp_path = format!("/tmp/g3_ocr_{}.png", uuid::Uuid::new_v4());
|
|
||||||
self.take_screenshot(&temp_path, Some(region), None).await?;
|
|
||||||
|
|
||||||
// Extract text from the screenshot
|
|
||||||
let result = self.extract_text_from_image(&temp_path).await?;
|
|
||||||
|
|
||||||
// Clean up temp file
|
|
||||||
let _ = std::fs::remove_file(&temp_path);
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn extract_text_from_image(&self, _path: &str) -> Result<OCRResult> {
|
|
||||||
// Check if tesseract is available on the system
|
|
||||||
let tesseract_check = std::process::Command::new("which")
|
|
||||||
.arg("tesseract")
|
|
||||||
.output();
|
|
||||||
|
|
||||||
if tesseract_check.is_err() || !tesseract_check.as_ref().unwrap().status.success() {
|
|
||||||
anyhow::bail!("Tesseract OCR is not installed on your system.\n\n\
|
|
||||||
To install tesseract:\n macOS: brew install tesseract\n \
|
|
||||||
Linux: sudo apt-get install tesseract-ocr (Ubuntu/Debian)\n \
|
|
||||||
sudo yum install tesseract (RHEL/CentOS)\n \
|
|
||||||
Windows: Download from https://github.com/UB-Mannheim/tesseract/wiki\n\n\
|
|
||||||
After installation, restart your terminal and try again.");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize Tesseract
|
|
||||||
let tess = Tesseract::new(None, Some("eng"))
|
|
||||||
.map_err(|e| {
|
|
||||||
anyhow::anyhow!("Failed to initialize Tesseract: {}\n\n\
|
|
||||||
This usually means:\n1. Tesseract is not properly installed\n\
|
|
||||||
2. Language data files are missing\n\nTo fix:\n \
|
|
||||||
macOS: brew reinstall tesseract\n \
|
|
||||||
Linux: sudo apt-get install tesseract-ocr-eng\n \
|
|
||||||
Windows: Reinstall tesseract and ensure language files are included", e)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let text = tess.set_image(_path)
|
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to load image '{}': {}", _path, e))?
|
|
||||||
.get_text()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to extract text from image: {}", e))?;
|
|
||||||
|
|
||||||
// Get confidence (simplified - would need more complex API calls for per-word confidence)
|
|
||||||
let confidence = 0.85; // Placeholder
|
|
||||||
|
|
||||||
Ok(OCRResult {
|
|
||||||
text,
|
|
||||||
confidence,
|
|
||||||
bounds: Rect { x: 0, y: 0, width: 0, height: 0 }, // Would need image dimensions
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn find_text_on_screen(&self, _text: &str) -> Result<Option<Point>> {
|
|
||||||
// Check if tesseract is available on the system
|
|
||||||
let tesseract_check = std::process::Command::new("which")
|
|
||||||
.arg("tesseract")
|
|
||||||
.output();
|
|
||||||
|
|
||||||
if tesseract_check.is_err() || !tesseract_check.as_ref().unwrap().status.success() {
|
|
||||||
anyhow::bail!("Tesseract OCR is not installed on your system.\n\n\
|
|
||||||
To install tesseract:\n macOS: brew install tesseract\n \
|
|
||||||
Linux: sudo apt-get install tesseract-ocr (Ubuntu/Debian)\n \
|
|
||||||
sudo yum install tesseract (RHEL/CentOS)\n \
|
|
||||||
Windows: Download from https://github.com/UB-Mannheim/tesseract/wiki\n\n\
|
|
||||||
After installation, restart your terminal and try again.");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Take full screen screenshot
|
|
||||||
let temp_path = format!("/tmp/g3_ocr_search_{}.png", uuid::Uuid::new_v4());
|
|
||||||
self.take_screenshot(&temp_path, None, None).await?;
|
|
||||||
|
|
||||||
// Use Tesseract to find text with bounding boxes
|
|
||||||
let tess = Tesseract::new(None, Some("eng"))
|
|
||||||
.map_err(|e| {
|
|
||||||
anyhow::anyhow!("Failed to initialize Tesseract: {}\n\n\
|
|
||||||
This usually means:\n1. Tesseract is not properly installed\n\
|
|
||||||
2. Language data files are missing\n\nTo fix:\n \
|
|
||||||
macOS: brew reinstall tesseract\n \
|
|
||||||
Linux: sudo apt-get install tesseract-ocr-eng\n \
|
|
||||||
Windows: Reinstall tesseract and ensure language files are included", e)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let full_text = tess.set_image(temp_path.as_str())
|
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to load screenshot: {}", e))?
|
|
||||||
.get_text()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to extract text from screen: {}", e))?;
|
|
||||||
|
|
||||||
// Clean up temp file
|
|
||||||
let _ = std::fs::remove_file(&temp_path);
|
|
||||||
|
|
||||||
// Simple text search - full implementation would use get_component_images
|
|
||||||
// to get bounding boxes for each word
|
|
||||||
if full_text.contains(_text) {
|
|
||||||
tracing::warn!("Text found but precise coordinates not available in simplified implementation");
|
|
||||||
Ok(Some(Point { x: 0, y: 0 }))
|
|
||||||
} else {
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
#[cfg(test)]
|
||||||
|
mod window_matching_tests {
|
||||||
|
/// Test that window name matching handles spaces correctly
|
||||||
|
///
|
||||||
|
/// Issue: When a user requests a screenshot of "Goose Studio" but the actual
|
||||||
|
/// application name is "GooseStudio" (no space), the fuzzy matching should
|
||||||
|
/// still find the window.
|
||||||
|
///
|
||||||
|
/// The fix normalizes both names by removing spaces before comparing.
|
||||||
|
#[test]
|
||||||
|
fn test_space_normalization() {
|
||||||
|
let test_cases = vec![
|
||||||
|
// (user_input, actual_app_name, should_match)
|
||||||
|
("Goose Studio", "GooseStudio", true),
|
||||||
|
("GooseStudio", "Goose Studio", true),
|
||||||
|
("Visual Studio Code", "VisualStudioCode", true),
|
||||||
|
("Google Chrome", "Google Chrome", true),
|
||||||
|
("Safari", "Safari", true),
|
||||||
|
("iTerm", "iTerm2", true), // fuzzy match
|
||||||
|
("Code", "Visual Studio Code", true), // fuzzy match
|
||||||
|
];
|
||||||
|
|
||||||
|
for (user_input, app_name, should_match) in test_cases {
|
||||||
|
let user_lower = user_input.to_lowercase();
|
||||||
|
let app_lower = app_name.to_lowercase();
|
||||||
|
|
||||||
|
let user_normalized = user_lower.replace(" ", "");
|
||||||
|
let app_normalized = app_lower.replace(" ", "");
|
||||||
|
|
||||||
|
let is_exact = app_lower == user_lower || app_normalized == user_normalized;
|
||||||
|
let is_fuzzy = app_lower.contains(&user_lower)
|
||||||
|
|| user_lower.contains(&app_lower)
|
||||||
|
|| app_normalized.contains(&user_normalized)
|
||||||
|
|| user_normalized.contains(&app_normalized);
|
||||||
|
|
||||||
|
let matches = is_exact || is_fuzzy;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
matches, should_match,
|
||||||
|
"Expected '{}' vs '{}' to match={}, but got match={}",
|
||||||
|
user_input, app_name, should_match, matches
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -62,10 +62,15 @@ impl ComputerController for WindowsController {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn take_screenshot(&self, _path: &str, _region: Option<Rect>, _window_id: Option<&str>) -> Result<()> {
|
async fn take_screenshot(&self, _path: &str, _region: Option<Rect>, _window_id: Option<&str>) -> Result<()> {
|
||||||
|
// Enforce that window_id must be provided
|
||||||
|
if _window_id.is_none() {
|
||||||
|
anyhow::bail!("window_id is required. You must specify which window to capture (e.g., 'Chrome', 'Terminal', 'Notepad'). Use list_windows to see available windows.");
|
||||||
|
}
|
||||||
|
|
||||||
anyhow::bail!("Windows implementation not yet available")
|
anyhow::bail!("Windows implementation not yet available")
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn extract_text_from_screen(&self, _region: Rect) -> Result<OCRResult> {
|
async fn extract_text_from_screen(&self, _region: Rect, _window_id: &str) -> Result<String> {
|
||||||
anyhow::bail!("Windows implementation not yet available")
|
anyhow::bail!("Windows implementation not yet available")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,3 +7,13 @@ pub struct Rect {
|
|||||||
pub width: i32,
|
pub width: i32,
|
||||||
pub height: i32,
|
pub height: i32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct TextLocation {
|
||||||
|
pub text: String,
|
||||||
|
pub x: i32,
|
||||||
|
pub y: i32,
|
||||||
|
pub width: i32,
|
||||||
|
pub height: i32,
|
||||||
|
pub confidence: f32,
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,23 +1,5 @@
|
|||||||
use g3_computer_control::*;
|
use g3_computer_control::*;
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_mouse_movement() {
|
|
||||||
let controller = create_controller().expect("Failed to create controller");
|
|
||||||
|
|
||||||
// Move mouse to center of screen (assuming 1920x1080)
|
|
||||||
let result = controller.move_mouse(960, 540).await;
|
|
||||||
assert!(result.is_ok(), "Failed to move mouse: {:?}", result.err());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_typing() {
|
|
||||||
let controller = create_controller().expect("Failed to create controller");
|
|
||||||
|
|
||||||
// Type some text
|
|
||||||
let result = controller.type_text("Hello, World!").await;
|
|
||||||
assert!(result.is_ok(), "Failed to type text: {:?}", result.err());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_screenshot() {
|
async fn test_screenshot() {
|
||||||
let controller = create_controller().expect("Failed to create controller");
|
let controller = create_controller().expect("Failed to create controller");
|
||||||
@@ -33,30 +15,3 @@ async fn test_screenshot() {
|
|||||||
// Clean up
|
// Clean up
|
||||||
let _ = std::fs::remove_file(path);
|
let _ = std::fs::remove_file(path);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_click() {
|
|
||||||
let controller = create_controller().expect("Failed to create controller");
|
|
||||||
|
|
||||||
// Click at a safe location
|
|
||||||
let result = controller.click(types::MouseButton::Left).await;
|
|
||||||
assert!(result.is_ok(), "Failed to click: {:?}", result.err());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_double_click() {
|
|
||||||
let controller = create_controller().expect("Failed to create controller");
|
|
||||||
|
|
||||||
// Double click
|
|
||||||
let result = controller.double_click(types::MouseButton::Left).await;
|
|
||||||
assert!(result.is_ok(), "Failed to double click: {:?}", result.err());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_press_key() {
|
|
||||||
let controller = create_controller().expect("Failed to create controller");
|
|
||||||
|
|
||||||
// Press escape key
|
|
||||||
let result = controller.press_key("escape").await;
|
|
||||||
assert!(result.is_ok(), "Failed to press key: {:?}", result.err());
|
|
||||||
}
|
|
||||||
|
|||||||
24
crates/g3-computer-control/vision-bridge/Package.swift
Normal file
24
crates/g3-computer-control/vision-bridge/Package.swift
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
// swift-tools-version:5.9
|
||||||
|
import PackageDescription
|
||||||
|
|
||||||
|
let package = Package(
|
||||||
|
name: "VisionBridge",
|
||||||
|
platforms: [
|
||||||
|
.macOS(.v11)
|
||||||
|
],
|
||||||
|
products: [
|
||||||
|
.library(
|
||||||
|
name: "VisionBridge",
|
||||||
|
type: .dynamic,
|
||||||
|
targets: ["VisionBridge"]
|
||||||
|
),
|
||||||
|
],
|
||||||
|
targets: [
|
||||||
|
.target(
|
||||||
|
name: "VisionBridge",
|
||||||
|
dependencies: [],
|
||||||
|
path: "Sources/VisionBridge",
|
||||||
|
publicHeadersPath: "."
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
#ifndef VisionBridge_h
|
||||||
|
#define VisionBridge_h
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Text box structure for FFI
|
||||||
|
typedef struct {
|
||||||
|
const char* text;
|
||||||
|
uint32_t text_len;
|
||||||
|
int32_t x;
|
||||||
|
int32_t y;
|
||||||
|
int32_t width;
|
||||||
|
int32_t height;
|
||||||
|
float confidence;
|
||||||
|
} VisionTextBox;
|
||||||
|
|
||||||
|
// Recognize text in an image and return bounding boxes
|
||||||
|
// Returns true on success, false on failure
|
||||||
|
// Caller must free the returned boxes using vision_free_boxes
|
||||||
|
bool vision_recognize_text(
|
||||||
|
const char* image_path,
|
||||||
|
uint32_t image_path_len,
|
||||||
|
VisionTextBox** out_boxes,
|
||||||
|
uint32_t* out_count
|
||||||
|
);
|
||||||
|
|
||||||
|
// Free memory allocated by vision_recognize_text
|
||||||
|
void vision_free_boxes(VisionTextBox* boxes, uint32_t count);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif /* VisionBridge_h */
|
||||||
@@ -0,0 +1,145 @@
|
|||||||
|
import Foundation
|
||||||
|
import Vision
|
||||||
|
import AppKit
|
||||||
|
import CoreGraphics
|
||||||
|
|
||||||
|
// MARK: - C Bridge Functions
|
||||||
|
|
||||||
|
@_cdecl("vision_recognize_text")
|
||||||
|
public func vision_recognize_text(
|
||||||
|
_ imagePath: UnsafePointer<CChar>,
|
||||||
|
_ imagePathLen: UInt32,
|
||||||
|
_ outBoxes: UnsafeMutablePointer<UnsafeMutableRawPointer?>,
|
||||||
|
_ outCount: UnsafeMutablePointer<UInt32>
|
||||||
|
) -> Bool {
|
||||||
|
// Convert C string to Swift String
|
||||||
|
guard let pathData = Data(bytes: imagePath, count: Int(imagePathLen)).withUnsafeBytes({
|
||||||
|
String(bytes: $0, encoding: .utf8)
|
||||||
|
}) else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
let path = pathData.trimmingCharacters(in: .whitespaces)
|
||||||
|
|
||||||
|
// Load image
|
||||||
|
guard let image = NSImage(contentsOfFile: path),
|
||||||
|
let cgImage = image.cgImage(forProposedRect: nil, context: nil, hints: nil) else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform OCR
|
||||||
|
var textBoxes: [CTextBox] = []
|
||||||
|
let semaphore = DispatchSemaphore(value: 0)
|
||||||
|
var success = false
|
||||||
|
|
||||||
|
let request = VNRecognizeTextRequest { request, error in
|
||||||
|
defer { semaphore.signal() }
|
||||||
|
|
||||||
|
if let error = error {
|
||||||
|
print("Vision OCR error: \(error.localizedDescription)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
guard let observations = request.results as? [VNRecognizedTextObservation] else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
let imageSize = CGSize(width: cgImage.width, height: cgImage.height)
|
||||||
|
|
||||||
|
for observation in observations {
|
||||||
|
guard let candidate = observation.topCandidates(1).first else { continue }
|
||||||
|
|
||||||
|
let text = candidate.string
|
||||||
|
let boundingBox = observation.boundingBox
|
||||||
|
|
||||||
|
// Convert normalized coordinates (bottom-left origin) to pixel coordinates (top-left origin)
|
||||||
|
let x = Int32(boundingBox.origin.x * imageSize.width)
|
||||||
|
let y = Int32((1.0 - boundingBox.origin.y - boundingBox.height) * imageSize.height)
|
||||||
|
let width = Int32(boundingBox.width * imageSize.width)
|
||||||
|
let height = Int32(boundingBox.height * imageSize.height)
|
||||||
|
|
||||||
|
// Allocate C string for text
|
||||||
|
let cString = strdup(text)
|
||||||
|
|
||||||
|
textBoxes.append(CTextBox(
|
||||||
|
text: cString,
|
||||||
|
text_len: UInt32(text.utf8.count),
|
||||||
|
x: x,
|
||||||
|
y: y,
|
||||||
|
width: width,
|
||||||
|
height: height,
|
||||||
|
confidence: observation.confidence
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
success = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure request for best accuracy
|
||||||
|
request.recognitionLevel = .accurate
|
||||||
|
request.usesLanguageCorrection = true
|
||||||
|
request.recognitionLanguages = ["en-US"]
|
||||||
|
|
||||||
|
// Perform request
|
||||||
|
let handler = VNImageRequestHandler(cgImage: cgImage, options: [:])
|
||||||
|
do {
|
||||||
|
try handler.perform([request])
|
||||||
|
} catch {
|
||||||
|
print("Vision request failed: \(error.localizedDescription)")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for completion
|
||||||
|
semaphore.wait()
|
||||||
|
|
||||||
|
if !success {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate array for results
|
||||||
|
let boxesPtr = UnsafeMutablePointer<CTextBox>.allocate(capacity: textBoxes.count)
|
||||||
|
for (index, box) in textBoxes.enumerated() {
|
||||||
|
boxesPtr[index] = box
|
||||||
|
}
|
||||||
|
|
||||||
|
outBoxes.pointee = UnsafeMutableRawPointer(boxesPtr)
|
||||||
|
outCount.pointee = UInt32(textBoxes.count)
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
@_cdecl("vision_free_boxes")
|
||||||
|
public func vision_free_boxes(
|
||||||
|
_ boxes: UnsafeMutableRawPointer,
|
||||||
|
_ count: UInt32
|
||||||
|
) {
|
||||||
|
let typedBoxes = boxes.assumingMemoryBound(to: CTextBox.self)
|
||||||
|
for i in 0..<Int(count) {
|
||||||
|
if let text = typedBoxes[i].text {
|
||||||
|
free(UnsafeMutableRawPointer(mutating: text))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
typedBoxes.deallocate()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - C-Compatible Structure
|
||||||
|
|
||||||
|
public struct CTextBox {
|
||||||
|
public let text: UnsafePointer<CChar>?
|
||||||
|
public let text_len: UInt32
|
||||||
|
public let x: Int32
|
||||||
|
public let y: Int32
|
||||||
|
public let width: Int32
|
||||||
|
public let height: Int32
|
||||||
|
public let confidence: Float
|
||||||
|
|
||||||
|
public init(text: UnsafePointer<CChar>?, text_len: UInt32, x: Int32, y: Int32, width: Int32, height: Int32, confidence: Float) {
|
||||||
|
self.text = text
|
||||||
|
self.text_len = text_len
|
||||||
|
self.x = x
|
||||||
|
self.y = y
|
||||||
|
self.width = width
|
||||||
|
self.height = height
|
||||||
|
self.confidence = confidence
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,3 +12,6 @@ thiserror = { workspace = true }
|
|||||||
toml = "0.8"
|
toml = "0.8"
|
||||||
shellexpand = "3.0"
|
shellexpand = "3.0"
|
||||||
dirs = "5.0"
|
dirs = "5.0"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
tempfile = "3.8"
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ pub struct Config {
|
|||||||
pub agent: AgentConfig,
|
pub agent: AgentConfig,
|
||||||
pub computer_control: ComputerControlConfig,
|
pub computer_control: ComputerControlConfig,
|
||||||
pub webdriver: WebDriverConfig,
|
pub webdriver: WebDriverConfig,
|
||||||
|
pub macax: MacAxConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -16,7 +17,10 @@ pub struct ProvidersConfig {
|
|||||||
pub anthropic: Option<AnthropicConfig>,
|
pub anthropic: Option<AnthropicConfig>,
|
||||||
pub databricks: Option<DatabricksConfig>,
|
pub databricks: Option<DatabricksConfig>,
|
||||||
pub embedded: Option<EmbeddedConfig>,
|
pub embedded: Option<EmbeddedConfig>,
|
||||||
|
pub ollama: Option<OllamaConfig>,
|
||||||
pub default_provider: String,
|
pub default_provider: String,
|
||||||
|
pub coach: Option<String>, // Provider to use for coach in autonomous mode
|
||||||
|
pub player: Option<String>, // Provider to use for player in autonomous mode
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -57,11 +61,20 @@ pub struct EmbeddedConfig {
|
|||||||
pub threads: Option<u32>, // Number of CPU threads to use
|
pub threads: Option<u32>, // Number of CPU threads to use
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct OllamaConfig {
|
||||||
|
pub model: String,
|
||||||
|
pub base_url: Option<String>, // Default: http://localhost:11434
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct AgentConfig {
|
pub struct AgentConfig {
|
||||||
pub max_context_length: usize,
|
pub max_context_length: usize,
|
||||||
pub enable_streaming: bool,
|
pub enable_streaming: bool,
|
||||||
pub timeout_seconds: u64,
|
pub timeout_seconds: u64,
|
||||||
|
pub auto_compact: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -77,6 +90,19 @@ pub struct WebDriverConfig {
|
|||||||
pub safari_port: u16,
|
pub safari_port: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct MacAxConfig {
|
||||||
|
pub enabled: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for MacAxConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
enabled: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Default for WebDriverConfig {
|
impl Default for WebDriverConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -111,15 +137,20 @@ impl Default for Config {
|
|||||||
use_oauth: Some(true),
|
use_oauth: Some(true),
|
||||||
}),
|
}),
|
||||||
embedded: None,
|
embedded: None,
|
||||||
|
ollama: None,
|
||||||
default_provider: "databricks".to_string(),
|
default_provider: "databricks".to_string(),
|
||||||
|
coach: None, // Will use default_provider if not specified
|
||||||
|
player: None, // Will use default_provider if not specified
|
||||||
},
|
},
|
||||||
agent: AgentConfig {
|
agent: AgentConfig {
|
||||||
max_context_length: 8192,
|
max_context_length: 8192,
|
||||||
enable_streaming: true,
|
enable_streaming: true,
|
||||||
timeout_seconds: 60,
|
timeout_seconds: 60,
|
||||||
|
auto_compact: true,
|
||||||
},
|
},
|
||||||
computer_control: ComputerControlConfig::default(),
|
computer_control: ComputerControlConfig::default(),
|
||||||
webdriver: WebDriverConfig::default(),
|
webdriver: WebDriverConfig::default(),
|
||||||
|
macax: MacAxConfig::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -223,15 +254,20 @@ impl Config {
|
|||||||
gpu_layers: Some(32),
|
gpu_layers: Some(32),
|
||||||
threads: Some(8),
|
threads: Some(8),
|
||||||
}),
|
}),
|
||||||
|
ollama: None,
|
||||||
default_provider: "embedded".to_string(),
|
default_provider: "embedded".to_string(),
|
||||||
|
coach: None, // Will use default_provider if not specified
|
||||||
|
player: None, // Will use default_provider if not specified
|
||||||
},
|
},
|
||||||
agent: AgentConfig {
|
agent: AgentConfig {
|
||||||
max_context_length: 8192,
|
max_context_length: 8192,
|
||||||
enable_streaming: true,
|
enable_streaming: true,
|
||||||
timeout_seconds: 60,
|
timeout_seconds: 60,
|
||||||
|
auto_compact: true,
|
||||||
},
|
},
|
||||||
computer_control: ComputerControlConfig::default(),
|
computer_control: ComputerControlConfig::default(),
|
||||||
webdriver: WebDriverConfig::default(),
|
webdriver: WebDriverConfig::default(),
|
||||||
|
macax: MacAxConfig::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -300,4 +336,67 @@ impl Config {
|
|||||||
|
|
||||||
Ok(config)
|
Ok(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the provider to use for coach mode in autonomous execution
|
||||||
|
pub fn get_coach_provider(&self) -> &str {
|
||||||
|
self.providers.coach
|
||||||
|
.as_deref()
|
||||||
|
.unwrap_or(&self.providers.default_provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the provider to use for player mode in autonomous execution
|
||||||
|
pub fn get_player_provider(&self) -> &str {
|
||||||
|
self.providers.player
|
||||||
|
.as_deref()
|
||||||
|
.unwrap_or(&self.providers.default_provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a copy of the config with a different default provider
|
||||||
|
pub fn with_provider_override(&self, provider: &str) -> Result<Self> {
|
||||||
|
// Validate that the provider is configured
|
||||||
|
match provider {
|
||||||
|
"anthropic" if self.providers.anthropic.is_none() => {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
|
||||||
|
provider, provider
|
||||||
|
));
|
||||||
|
}
|
||||||
|
"databricks" if self.providers.databricks.is_none() => {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
|
||||||
|
provider, provider
|
||||||
|
));
|
||||||
|
}
|
||||||
|
"embedded" if self.providers.embedded.is_none() => {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
|
||||||
|
provider, provider
|
||||||
|
));
|
||||||
|
}
|
||||||
|
"openai" if self.providers.openai.is_none() => {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
|
||||||
|
provider, provider
|
||||||
|
));
|
||||||
|
}
|
||||||
|
_ => {} // Provider is configured or unknown (will be caught later)
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut config = self.clone();
|
||||||
|
config.providers.default_provider = provider.to_string();
|
||||||
|
Ok(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a copy of the config for coach mode in autonomous execution
|
||||||
|
pub fn for_coach(&self) -> Result<Self> {
|
||||||
|
self.with_provider_override(self.get_coach_provider())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a copy of the config for player mode in autonomous execution
|
||||||
|
pub fn for_player(&self) -> Result<Self> {
|
||||||
|
self.with_provider_override(self.get_player_provider())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests;
|
||||||
|
|||||||
131
crates/g3-config/src/tests.rs
Normal file
131
crates/g3-config/src/tests.rs
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use crate::Config;
|
||||||
|
use std::fs;
|
||||||
|
use tempfile::TempDir;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_coach_player_providers() {
|
||||||
|
// Create a temporary directory for the test config
|
||||||
|
let temp_dir = TempDir::new().unwrap();
|
||||||
|
let config_path = temp_dir.path().join("test_config.toml");
|
||||||
|
|
||||||
|
// Write a test configuration with coach and player providers
|
||||||
|
let config_content = r#"
|
||||||
|
[providers]
|
||||||
|
default_provider = "databricks"
|
||||||
|
coach = "anthropic"
|
||||||
|
player = "embedded"
|
||||||
|
|
||||||
|
[providers.databricks]
|
||||||
|
host = "https://test.databricks.com"
|
||||||
|
token = "test-token"
|
||||||
|
model = "test-model"
|
||||||
|
|
||||||
|
[providers.anthropic]
|
||||||
|
api_key = "test-key"
|
||||||
|
model = "claude-3"
|
||||||
|
|
||||||
|
[providers.embedded]
|
||||||
|
model_path = "test.gguf"
|
||||||
|
model_type = "llama"
|
||||||
|
|
||||||
|
[agent]
|
||||||
|
max_context_length = 8192
|
||||||
|
enable_streaming = true
|
||||||
|
timeout_seconds = 60
|
||||||
|
"#;
|
||||||
|
|
||||||
|
fs::write(&config_path, config_content).unwrap();
|
||||||
|
|
||||||
|
// Load the configuration
|
||||||
|
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
|
||||||
|
|
||||||
|
// Test that the providers are correctly identified
|
||||||
|
assert_eq!(config.providers.default_provider, "databricks");
|
||||||
|
assert_eq!(config.get_coach_provider(), "anthropic");
|
||||||
|
assert_eq!(config.get_player_provider(), "embedded");
|
||||||
|
|
||||||
|
// Test creating coach config
|
||||||
|
let coach_config = config.for_coach().unwrap();
|
||||||
|
assert_eq!(coach_config.providers.default_provider, "anthropic");
|
||||||
|
|
||||||
|
// Test creating player config
|
||||||
|
let player_config = config.for_player().unwrap();
|
||||||
|
assert_eq!(player_config.providers.default_provider, "embedded");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_coach_player_fallback_to_default() {
|
||||||
|
// Create a temporary directory for the test config
|
||||||
|
let temp_dir = TempDir::new().unwrap();
|
||||||
|
let config_path = temp_dir.path().join("test_config.toml");
|
||||||
|
|
||||||
|
// Write a test configuration WITHOUT coach and player providers
|
||||||
|
let config_content = r#"
|
||||||
|
[providers]
|
||||||
|
default_provider = "databricks"
|
||||||
|
|
||||||
|
[providers.databricks]
|
||||||
|
host = "https://test.databricks.com"
|
||||||
|
token = "test-token"
|
||||||
|
model = "test-model"
|
||||||
|
|
||||||
|
[agent]
|
||||||
|
max_context_length = 8192
|
||||||
|
enable_streaming = true
|
||||||
|
timeout_seconds = 60
|
||||||
|
"#;
|
||||||
|
|
||||||
|
fs::write(&config_path, config_content).unwrap();
|
||||||
|
|
||||||
|
// Load the configuration
|
||||||
|
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
|
||||||
|
|
||||||
|
// Test that coach and player fall back to default provider
|
||||||
|
assert_eq!(config.get_coach_provider(), "databricks");
|
||||||
|
assert_eq!(config.get_player_provider(), "databricks");
|
||||||
|
|
||||||
|
// Test creating coach config (should use default)
|
||||||
|
let coach_config = config.for_coach().unwrap();
|
||||||
|
assert_eq!(coach_config.providers.default_provider, "databricks");
|
||||||
|
|
||||||
|
// Test creating player config (should use default)
|
||||||
|
let player_config = config.for_player().unwrap();
|
||||||
|
assert_eq!(player_config.providers.default_provider, "databricks");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_invalid_provider_error() {
|
||||||
|
// Create a temporary directory for the test config
|
||||||
|
let temp_dir = TempDir::new().unwrap();
|
||||||
|
let config_path = temp_dir.path().join("test_config.toml");
|
||||||
|
|
||||||
|
// Write a test configuration with an unconfigured provider
|
||||||
|
let config_content = r#"
|
||||||
|
[providers]
|
||||||
|
default_provider = "databricks"
|
||||||
|
coach = "openai" # OpenAI is not configured
|
||||||
|
|
||||||
|
[providers.databricks]
|
||||||
|
host = "https://test.databricks.com"
|
||||||
|
token = "test-token"
|
||||||
|
model = "test-model"
|
||||||
|
|
||||||
|
[agent]
|
||||||
|
max_context_length = 8192
|
||||||
|
enable_streaming = true
|
||||||
|
timeout_seconds = 60
|
||||||
|
"#;
|
||||||
|
|
||||||
|
fs::write(&config_path, config_content).unwrap();
|
||||||
|
|
||||||
|
// Load the configuration
|
||||||
|
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
|
||||||
|
|
||||||
|
// Test that trying to create a coach config with unconfigured provider fails
|
||||||
|
let result = config.for_coach();
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(result.unwrap_err().to_string().contains("not configured"));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -25,3 +25,4 @@ chrono = { version = "0.4", features = ["serde"] }
|
|||||||
rand = "0.8"
|
rand = "0.8"
|
||||||
regex = "1.0"
|
regex = "1.0"
|
||||||
shellexpand = "3.1"
|
shellexpand = "3.1"
|
||||||
|
serde_yaml = "0.9"
|
||||||
|
|||||||
787
crates/g3-core/src/code_search.rs
Normal file
787
crates/g3-core/src/code_search.rs
Normal file
@@ -0,0 +1,787 @@
|
|||||||
|
//! Code search functionality using ast-grep for syntax-aware semantic searches
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::process::Stdio;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||||
|
use tokio::process::Command;
|
||||||
|
use tokio::sync::Semaphore;
|
||||||
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
|
/// Maximum number of searches allowed per request
|
||||||
|
const MAX_SEARCHES: usize = 20;
|
||||||
|
|
||||||
|
/// Default timeout for individual searches in seconds
|
||||||
|
const DEFAULT_TIMEOUT_SECS: u64 = 60;
|
||||||
|
|
||||||
|
/// Default maximum concurrency
|
||||||
|
const DEFAULT_MAX_CONCURRENCY: usize = 4;
|
||||||
|
|
||||||
|
/// Default maximum matches per search
|
||||||
|
const DEFAULT_MAX_MATCHES: usize = 500;
|
||||||
|
|
||||||
|
/// Search specification for a single ast-grep search
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct SearchSpec {
|
||||||
|
pub name: String,
|
||||||
|
pub mode: SearchMode,
|
||||||
|
|
||||||
|
// Pattern mode fields
|
||||||
|
pub pattern: Option<String>,
|
||||||
|
pub language: Option<String>,
|
||||||
|
|
||||||
|
// YAML mode fields
|
||||||
|
pub rule_yaml: Option<String>,
|
||||||
|
|
||||||
|
// Common fields
|
||||||
|
pub paths: Option<Vec<String>>,
|
||||||
|
pub globs: Option<Vec<String>>,
|
||||||
|
pub json_style: Option<JsonStyle>,
|
||||||
|
pub context: Option<u32>,
|
||||||
|
pub threads: Option<u32>,
|
||||||
|
pub include_metadata: Option<bool>,
|
||||||
|
pub no_ignore: Option<Vec<NoIgnoreType>>,
|
||||||
|
pub severity: Option<HashMap<String, SeverityLevel>>,
|
||||||
|
pub timeout_secs: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Search mode: pattern or yaml
|
||||||
|
#[derive(Debug, Clone, Deserialize, PartialEq)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum SearchMode {
|
||||||
|
Pattern,
|
||||||
|
Yaml,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// JSON output style
|
||||||
|
#[derive(Debug, Clone, Deserialize, PartialEq)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum JsonStyle {
|
||||||
|
Pretty,
|
||||||
|
Stream,
|
||||||
|
Compact,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for JsonStyle {
|
||||||
|
fn default() -> Self {
|
||||||
|
JsonStyle::Stream
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// No-ignore types
|
||||||
|
#[derive(Debug, Clone, Deserialize, PartialEq)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum NoIgnoreType {
|
||||||
|
Hidden,
|
||||||
|
Dot,
|
||||||
|
Exclude,
|
||||||
|
Global,
|
||||||
|
Parent,
|
||||||
|
Vcs,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Severity levels for YAML rules
|
||||||
|
#[derive(Debug, Clone, Deserialize, PartialEq)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum SeverityLevel {
|
||||||
|
Error,
|
||||||
|
Warning,
|
||||||
|
Info,
|
||||||
|
Hint,
|
||||||
|
Off,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Request structure for code search
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct CodeSearchRequest {
|
||||||
|
pub searches: Vec<SearchSpec>,
|
||||||
|
pub max_concurrency: Option<usize>,
|
||||||
|
pub max_matches_per_search: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result of a single search
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct SearchResult {
|
||||||
|
pub name: String,
|
||||||
|
pub mode: String,
|
||||||
|
pub status: String,
|
||||||
|
pub cmd: Vec<String>,
|
||||||
|
pub match_count: Option<usize>,
|
||||||
|
pub truncated: Option<bool>,
|
||||||
|
pub matches: Option<Vec<Value>>,
|
||||||
|
pub stderr: Option<String>,
|
||||||
|
pub exit_code: Option<i32>,
|
||||||
|
pub duration_ms: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Summary of all searches
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct SearchSummary {
|
||||||
|
pub completed: usize,
|
||||||
|
pub total: usize,
|
||||||
|
pub total_matches: usize,
|
||||||
|
pub duration_ms: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Complete response structure
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct CodeSearchResponse {
|
||||||
|
pub summary: SearchSummary,
|
||||||
|
pub searches: Vec<SearchResult>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// YAML rule structure for validation
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct YamlRule {
|
||||||
|
pub id: String,
|
||||||
|
pub language: String,
|
||||||
|
pub rule: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute a batch of code searches using ast-grep
|
||||||
|
pub async fn execute_code_search(request: CodeSearchRequest) -> Result<CodeSearchResponse> {
|
||||||
|
let start_time = Instant::now();
|
||||||
|
|
||||||
|
// Validate request
|
||||||
|
if request.searches.is_empty() {
|
||||||
|
return Err(anyhow!("No searches specified"));
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.searches.len() > MAX_SEARCHES {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"Too many searches: {} (max: {})",
|
||||||
|
request.searches.len(),
|
||||||
|
MAX_SEARCHES
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if ast-grep is available
|
||||||
|
check_ast_grep_available().await?;
|
||||||
|
|
||||||
|
let max_concurrency = request.max_concurrency.unwrap_or(DEFAULT_MAX_CONCURRENCY);
|
||||||
|
let max_matches = request.max_matches_per_search.unwrap_or(DEFAULT_MAX_MATCHES);
|
||||||
|
|
||||||
|
// Create semaphore for concurrency control
|
||||||
|
let semaphore = std::sync::Arc::new(Semaphore::new(max_concurrency));
|
||||||
|
|
||||||
|
// Execute searches concurrently
|
||||||
|
let mut tasks = Vec::new();
|
||||||
|
|
||||||
|
for search in request.searches {
|
||||||
|
let sem = semaphore.clone();
|
||||||
|
let task = tokio::spawn(async move {
|
||||||
|
let _permit = sem.acquire().await.unwrap();
|
||||||
|
execute_single_search(search, max_matches).await
|
||||||
|
});
|
||||||
|
tasks.push(task);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all searches to complete
|
||||||
|
let mut results = Vec::new();
|
||||||
|
let mut total_matches = 0;
|
||||||
|
let mut completed = 0;
|
||||||
|
|
||||||
|
for task in tasks {
|
||||||
|
match task.await {
|
||||||
|
Ok(result) => {
|
||||||
|
if result.status == "ok" {
|
||||||
|
completed += 1;
|
||||||
|
if let Some(count) = result.match_count {
|
||||||
|
total_matches += count;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results.push(result);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Task join error: {}", e);
|
||||||
|
// Create an error result
|
||||||
|
results.push(SearchResult {
|
||||||
|
name: "unknown".to_string(),
|
||||||
|
mode: "unknown".to_string(),
|
||||||
|
status: "error".to_string(),
|
||||||
|
cmd: vec![],
|
||||||
|
match_count: None,
|
||||||
|
truncated: None,
|
||||||
|
matches: None,
|
||||||
|
stderr: Some(format!("Task execution error: {}", e)),
|
||||||
|
exit_code: None,
|
||||||
|
duration_ms: 0,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let total_duration = start_time.elapsed();
|
||||||
|
|
||||||
|
Ok(CodeSearchResponse {
|
||||||
|
summary: SearchSummary {
|
||||||
|
completed,
|
||||||
|
total: results.len(),
|
||||||
|
total_matches,
|
||||||
|
duration_ms: total_duration.as_millis() as u64,
|
||||||
|
},
|
||||||
|
searches: results,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute a single search
|
||||||
|
async fn execute_single_search(search: SearchSpec, max_matches: usize) -> SearchResult {
|
||||||
|
let start_time = Instant::now();
|
||||||
|
let timeout_secs = search.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
|
||||||
|
|
||||||
|
// Validate the search specification
|
||||||
|
if let Err(e) = validate_search_spec(&search) {
|
||||||
|
return SearchResult {
|
||||||
|
name: search.name,
|
||||||
|
mode: format!("{:?}", search.mode).to_lowercase(),
|
||||||
|
status: "error".to_string(),
|
||||||
|
cmd: vec![],
|
||||||
|
match_count: None,
|
||||||
|
truncated: None,
|
||||||
|
matches: None,
|
||||||
|
stderr: Some(format!("Validation error: {}", e)),
|
||||||
|
exit_code: None,
|
||||||
|
duration_ms: start_time.elapsed().as_millis() as u64,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build command
|
||||||
|
let cmd_args = match build_ast_grep_command(&search) {
|
||||||
|
Ok(args) => args,
|
||||||
|
Err(e) => {
|
||||||
|
return SearchResult {
|
||||||
|
name: search.name,
|
||||||
|
mode: format!("{:?}", search.mode).to_lowercase(),
|
||||||
|
status: "error".to_string(),
|
||||||
|
cmd: vec![],
|
||||||
|
match_count: None,
|
||||||
|
truncated: None,
|
||||||
|
matches: None,
|
||||||
|
stderr: Some(format!("Command build error: {}", e)),
|
||||||
|
exit_code: None,
|
||||||
|
duration_ms: start_time.elapsed().as_millis() as u64,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!("Executing ast-grep command: {:?}", cmd_args);
|
||||||
|
|
||||||
|
// Execute with timeout
|
||||||
|
let timeout_duration = Duration::from_secs(timeout_secs);
|
||||||
|
|
||||||
|
match tokio::time::timeout(timeout_duration, run_ast_grep_command(&cmd_args)).await {
|
||||||
|
Ok(Ok((stdout, stderr, exit_code))) => {
|
||||||
|
let duration_ms = start_time.elapsed().as_millis() as u64;
|
||||||
|
|
||||||
|
if exit_code == 0 {
|
||||||
|
// Parse JSON output
|
||||||
|
match parse_ast_grep_output(&stdout, max_matches) {
|
||||||
|
Ok((matches, truncated)) => {
|
||||||
|
SearchResult {
|
||||||
|
name: search.name,
|
||||||
|
mode: format!("{:?}", search.mode).to_lowercase(),
|
||||||
|
status: "ok".to_string(),
|
||||||
|
cmd: cmd_args,
|
||||||
|
match_count: Some(matches.len()),
|
||||||
|
truncated: Some(truncated),
|
||||||
|
matches: Some(matches),
|
||||||
|
stderr: if stderr.is_empty() { None } else { Some(stderr) },
|
||||||
|
exit_code: None,
|
||||||
|
duration_ms,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
SearchResult {
|
||||||
|
name: search.name,
|
||||||
|
mode: format!("{:?}", search.mode).to_lowercase(),
|
||||||
|
status: "error".to_string(),
|
||||||
|
cmd: cmd_args,
|
||||||
|
match_count: None,
|
||||||
|
truncated: None,
|
||||||
|
matches: None,
|
||||||
|
stderr: Some(format!("JSON parse error: {}\nRaw output: {}", e, stdout)),
|
||||||
|
exit_code: Some(exit_code),
|
||||||
|
duration_ms,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
SearchResult {
|
||||||
|
name: search.name,
|
||||||
|
mode: format!("{:?}", search.mode).to_lowercase(),
|
||||||
|
status: "error".to_string(),
|
||||||
|
cmd: cmd_args,
|
||||||
|
match_count: None,
|
||||||
|
truncated: None,
|
||||||
|
matches: None,
|
||||||
|
stderr: Some(stderr),
|
||||||
|
exit_code: Some(exit_code),
|
||||||
|
duration_ms,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
SearchResult {
|
||||||
|
name: search.name,
|
||||||
|
mode: format!("{:?}", search.mode).to_lowercase(),
|
||||||
|
status: "error".to_string(),
|
||||||
|
cmd: cmd_args,
|
||||||
|
match_count: None,
|
||||||
|
truncated: None,
|
||||||
|
matches: None,
|
||||||
|
stderr: Some(format!("Execution error: {}", e)),
|
||||||
|
exit_code: None,
|
||||||
|
duration_ms: start_time.elapsed().as_millis() as u64,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
SearchResult {
|
||||||
|
name: search.name,
|
||||||
|
mode: format!("{:?}", search.mode).to_lowercase(),
|
||||||
|
status: "timeout".to_string(),
|
||||||
|
cmd: cmd_args,
|
||||||
|
match_count: None,
|
||||||
|
truncated: None,
|
||||||
|
matches: None,
|
||||||
|
stderr: Some(format!("Search timed out after {} seconds", timeout_secs)),
|
||||||
|
exit_code: None,
|
||||||
|
duration_ms: start_time.elapsed().as_millis() as u64,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate a search specification
|
||||||
|
fn validate_search_spec(search: &SearchSpec) -> Result<()> {
|
||||||
|
match search.mode {
|
||||||
|
SearchMode::Pattern => {
|
||||||
|
if search.pattern.is_none() || search.pattern.as_ref().unwrap().is_empty() {
|
||||||
|
return Err(anyhow!("Pattern mode requires non-empty 'pattern' field"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
SearchMode::Yaml => {
|
||||||
|
let rule_yaml = search.rule_yaml.as_ref()
|
||||||
|
.ok_or_else(|| anyhow!("YAML mode requires 'rule_yaml' field"))?;
|
||||||
|
|
||||||
|
if rule_yaml.is_empty() {
|
||||||
|
return Err(anyhow!("YAML mode requires non-empty 'rule_yaml' field"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse and validate YAML structure
|
||||||
|
let parsed: YamlRule = serde_yaml::from_str(rule_yaml)
|
||||||
|
.map_err(|e| anyhow!("Invalid YAML rule: {}", e))?;
|
||||||
|
|
||||||
|
if parsed.id.is_empty() {
|
||||||
|
return Err(anyhow!("YAML rule must have non-empty 'id' field"));
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsed.language.is_empty() {
|
||||||
|
return Err(anyhow!("YAML rule must have non-empty 'language' field"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate language is supported (basic check)
|
||||||
|
validate_language(&parsed.language)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate context range
|
||||||
|
if let Some(context) = search.context {
|
||||||
|
if context > 20 {
|
||||||
|
return Err(anyhow!("Context lines cannot exceed 20"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate that a language is supported by ast-grep
|
||||||
|
fn validate_language(language: &str) -> Result<()> {
|
||||||
|
let supported_languages = [
|
||||||
|
"rust", "javascript", "typescript", "python", "java", "c", "cpp", "csharp",
|
||||||
|
"go", "html", "css", "json", "yaml", "xml", "bash", "kotlin", "swift",
|
||||||
|
"php", "ruby", "scala", "dart", "lua", "r", "sql", "dockerfile",
|
||||||
|
"Rust", "JavaScript", "TypeScript", "Python", "Java", "C", "Cpp", "CSharp",
|
||||||
|
"Go", "Html", "Css", "Json", "Yaml", "Xml", "Bash", "Kotlin", "Swift",
|
||||||
|
"Php", "Ruby", "Scala", "Dart", "Lua", "R", "Sql", "Dockerfile"
|
||||||
|
];
|
||||||
|
|
||||||
|
if !supported_languages.contains(&language) {
|
||||||
|
warn!("Language '{}' may not be supported by ast-grep", language);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build ast-grep command arguments
|
||||||
|
fn build_ast_grep_command(search: &SearchSpec) -> Result<Vec<String>> {
|
||||||
|
let mut args = vec!["ast-grep".to_string()];
|
||||||
|
|
||||||
|
match search.mode {
|
||||||
|
SearchMode::Pattern => {
|
||||||
|
args.push("run".to_string());
|
||||||
|
|
||||||
|
// Add pattern
|
||||||
|
args.push("-p".to_string());
|
||||||
|
args.push(search.pattern.as_ref().unwrap().clone());
|
||||||
|
|
||||||
|
// Add language if specified
|
||||||
|
if let Some(ref lang) = search.language {
|
||||||
|
args.push("-l".to_string());
|
||||||
|
args.push(lang.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
SearchMode::Yaml => {
|
||||||
|
args.push("scan".to_string());
|
||||||
|
|
||||||
|
// Add inline rules
|
||||||
|
args.push("--inline-rules".to_string());
|
||||||
|
args.push(search.rule_yaml.as_ref().unwrap().clone());
|
||||||
|
|
||||||
|
// Add include-metadata if requested
|
||||||
|
if search.include_metadata.unwrap_or(false) {
|
||||||
|
args.push("--include-metadata".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add severity overrides
|
||||||
|
if let Some(ref severity_map) = search.severity {
|
||||||
|
for (rule_id, severity) in severity_map {
|
||||||
|
match severity {
|
||||||
|
SeverityLevel::Error => {
|
||||||
|
args.push("--error".to_string());
|
||||||
|
args.push(rule_id.clone());
|
||||||
|
}
|
||||||
|
SeverityLevel::Warning => {
|
||||||
|
args.push("--warning".to_string());
|
||||||
|
args.push(rule_id.clone());
|
||||||
|
}
|
||||||
|
SeverityLevel::Info => {
|
||||||
|
args.push("--info".to_string());
|
||||||
|
args.push(rule_id.clone());
|
||||||
|
}
|
||||||
|
SeverityLevel::Hint => {
|
||||||
|
args.push("--hint".to_string());
|
||||||
|
args.push(rule_id.clone());
|
||||||
|
}
|
||||||
|
SeverityLevel::Off => {
|
||||||
|
args.push("--off".to_string());
|
||||||
|
args.push(rule_id.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add common arguments
|
||||||
|
|
||||||
|
// Add globs if specified
|
||||||
|
if let Some(ref globs) = search.globs {
|
||||||
|
if !globs.is_empty() {
|
||||||
|
args.push("--globs".to_string());
|
||||||
|
args.push(globs.join(","));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add context
|
||||||
|
if let Some(context) = search.context {
|
||||||
|
args.push("-C".to_string());
|
||||||
|
args.push(context.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add threads
|
||||||
|
if let Some(threads) = search.threads {
|
||||||
|
args.push("-j".to_string());
|
||||||
|
args.push(threads.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add JSON output style
|
||||||
|
let json_style = search.json_style.as_ref().unwrap_or(&JsonStyle::Stream);
|
||||||
|
let json_arg = match json_style {
|
||||||
|
JsonStyle::Pretty => "--json=pretty",
|
||||||
|
JsonStyle::Stream => "--json=stream",
|
||||||
|
JsonStyle::Compact => "--json=compact",
|
||||||
|
};
|
||||||
|
args.push(json_arg.to_string());
|
||||||
|
|
||||||
|
// Add no-ignore options
|
||||||
|
if let Some(ref no_ignore_list) = search.no_ignore {
|
||||||
|
for no_ignore_type in no_ignore_list {
|
||||||
|
let flag = match no_ignore_type {
|
||||||
|
NoIgnoreType::Hidden => "--no-ignore=hidden",
|
||||||
|
NoIgnoreType::Dot => "--no-ignore=dot",
|
||||||
|
NoIgnoreType::Exclude => "--no-ignore=exclude",
|
||||||
|
NoIgnoreType::Global => "--no-ignore=global",
|
||||||
|
NoIgnoreType::Parent => "--no-ignore=parent",
|
||||||
|
NoIgnoreType::Vcs => "--no-ignore=vcs",
|
||||||
|
};
|
||||||
|
args.push(flag.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add paths (default to current directory if none specified)
|
||||||
|
if let Some(ref paths) = search.paths {
|
||||||
|
if !paths.is_empty() {
|
||||||
|
args.extend(paths.clone());
|
||||||
|
} else {
|
||||||
|
args.push(".".to_string());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
args.push(".".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(args)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run ast-grep command and capture output
|
||||||
|
async fn run_ast_grep_command(args: &[String]) -> Result<(String, String, i32)> {
|
||||||
|
let mut cmd = Command::new(&args[0]);
|
||||||
|
cmd.args(&args[1..]);
|
||||||
|
cmd.stdout(Stdio::piped());
|
||||||
|
cmd.stderr(Stdio::piped());
|
||||||
|
|
||||||
|
debug!("Running command: {:?}", args);
|
||||||
|
|
||||||
|
let mut child = cmd.spawn()
|
||||||
|
.map_err(|e| anyhow!("Failed to spawn ast-grep process: {}", e))?;
|
||||||
|
|
||||||
|
let stdout = child.stdout.take().unwrap();
|
||||||
|
let stderr = child.stderr.take().unwrap();
|
||||||
|
|
||||||
|
let stdout_reader = BufReader::new(stdout);
|
||||||
|
let stderr_reader = BufReader::new(stderr);
|
||||||
|
|
||||||
|
let stdout_task = tokio::spawn(async move {
|
||||||
|
let mut lines = stdout_reader.lines();
|
||||||
|
let mut output = String::new();
|
||||||
|
while let Ok(Some(line)) = lines.next_line().await {
|
||||||
|
if !output.is_empty() {
|
||||||
|
output.push('\n');
|
||||||
|
}
|
||||||
|
output.push_str(&line);
|
||||||
|
}
|
||||||
|
output
|
||||||
|
});
|
||||||
|
|
||||||
|
let stderr_task = tokio::spawn(async move {
|
||||||
|
let mut lines = stderr_reader.lines();
|
||||||
|
let mut output = String::new();
|
||||||
|
while let Ok(Some(line)) = lines.next_line().await {
|
||||||
|
if !output.is_empty() {
|
||||||
|
output.push('\n');
|
||||||
|
}
|
||||||
|
output.push_str(&line);
|
||||||
|
}
|
||||||
|
output
|
||||||
|
});
|
||||||
|
|
||||||
|
let status = child.wait().await
|
||||||
|
.map_err(|e| anyhow!("Failed to wait for ast-grep process: {}", e))?;
|
||||||
|
|
||||||
|
let stdout_output = stdout_task.await
|
||||||
|
.map_err(|e| anyhow!("Failed to read stdout: {}", e))?;
|
||||||
|
let stderr_output = stderr_task.await
|
||||||
|
.map_err(|e| anyhow!("Failed to read stderr: {}", e))?;
|
||||||
|
|
||||||
|
let exit_code = status.code().unwrap_or(-1);
|
||||||
|
|
||||||
|
Ok((stdout_output, stderr_output, exit_code))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse ast-grep JSON output
|
||||||
|
fn parse_ast_grep_output(output: &str, max_matches: usize) -> Result<(Vec<Value>, bool)> {
|
||||||
|
if output.trim().is_empty() {
|
||||||
|
return Ok((vec![], false));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut matches = Vec::new();
|
||||||
|
let mut truncated = false;
|
||||||
|
|
||||||
|
// Handle stream format (line-delimited JSON)
|
||||||
|
for line in output.lines() {
|
||||||
|
let line = line.trim();
|
||||||
|
if line.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
match serde_json::from_str::<Value>(line) {
|
||||||
|
Ok(match_obj) => {
|
||||||
|
if matches.len() >= max_matches {
|
||||||
|
truncated = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
matches.push(match_obj);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
debug!("Failed to parse JSON line '{}': {}", line, e);
|
||||||
|
// Try to parse the entire output as a single JSON array
|
||||||
|
match serde_json::from_str::<Vec<Value>>(output) {
|
||||||
|
Ok(array_matches) => {
|
||||||
|
let take_count = array_matches.len().min(max_matches);
|
||||||
|
let total_count = array_matches.len();
|
||||||
|
matches = array_matches.into_iter().take(take_count).collect();
|
||||||
|
truncated = take_count < total_count;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Err(e2) => {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"Failed to parse ast-grep output as line-delimited JSON or JSON array. Line error: {}, Array error: {}",
|
||||||
|
e, e2
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok((matches, truncated))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if ast-grep is available and provide installation hints if not
|
||||||
|
async fn check_ast_grep_available() -> Result<()> {
|
||||||
|
match Command::new("ast-grep")
|
||||||
|
.arg("--version")
|
||||||
|
.output()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(output) => {
|
||||||
|
if output.status.success() {
|
||||||
|
let version = String::from_utf8_lossy(&output.stdout);
|
||||||
|
info!("Found ast-grep: {}", version.trim());
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(anyhow!("ast-grep command failed: {}", String::from_utf8_lossy(&output.stderr)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
Err(anyhow!(
|
||||||
|
"ast-grep not found. Please install it using one of these methods:\n\n\
|
||||||
|
• Homebrew (macOS): brew install ast-grep\n\
|
||||||
|
• MacPorts (macOS): sudo port install ast-grep\n\
|
||||||
|
• Nix: nix-env -iA nixpkgs.ast-grep\n\
|
||||||
|
• Cargo: cargo install ast-grep\n\
|
||||||
|
• npm: npm install -g @ast-grep/cli\n\
|
||||||
|
• pip: pip install ast-grep\n\n\
|
||||||
|
For more installation options, visit: https://ast-grep.github.io/guide/quick-start.html"
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_pattern_search() {
|
||||||
|
let search = SearchSpec {
|
||||||
|
name: "test".to_string(),
|
||||||
|
mode: SearchMode::Pattern,
|
||||||
|
pattern: Some("fn $NAME() {}".to_string()),
|
||||||
|
language: Some("rust".to_string()),
|
||||||
|
rule_yaml: None,
|
||||||
|
paths: None,
|
||||||
|
globs: None,
|
||||||
|
json_style: None,
|
||||||
|
context: None,
|
||||||
|
threads: None,
|
||||||
|
include_metadata: None,
|
||||||
|
no_ignore: None,
|
||||||
|
severity: None,
|
||||||
|
timeout_secs: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(validate_search_spec(&search).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_yaml_search() {
|
||||||
|
let yaml_rule = r#"
|
||||||
|
id: test-rule
|
||||||
|
language: Rust
|
||||||
|
rule:
|
||||||
|
pattern: "fn $NAME() {}"
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let search = SearchSpec {
|
||||||
|
name: "test".to_string(),
|
||||||
|
mode: SearchMode::Yaml,
|
||||||
|
pattern: None,
|
||||||
|
language: None,
|
||||||
|
rule_yaml: Some(yaml_rule.to_string()),
|
||||||
|
paths: None,
|
||||||
|
globs: None,
|
||||||
|
json_style: None,
|
||||||
|
context: None,
|
||||||
|
threads: None,
|
||||||
|
include_metadata: None,
|
||||||
|
no_ignore: None,
|
||||||
|
severity: None,
|
||||||
|
timeout_secs: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(validate_search_spec(&search).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_pattern_command() {
|
||||||
|
let search = SearchSpec {
|
||||||
|
name: "test".to_string(),
|
||||||
|
mode: SearchMode::Pattern,
|
||||||
|
pattern: Some("fn $NAME() {}".to_string()),
|
||||||
|
language: Some("rust".to_string()),
|
||||||
|
rule_yaml: None,
|
||||||
|
paths: Some(vec!["src/".to_string()]),
|
||||||
|
globs: None,
|
||||||
|
json_style: Some(JsonStyle::Stream),
|
||||||
|
context: Some(2),
|
||||||
|
threads: Some(4),
|
||||||
|
include_metadata: None,
|
||||||
|
no_ignore: None,
|
||||||
|
severity: None,
|
||||||
|
timeout_secs: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let cmd = build_ast_grep_command(&search).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(cmd[0], "ast-grep");
|
||||||
|
assert_eq!(cmd[1], "run");
|
||||||
|
assert!(cmd.contains(&"-p".to_string()));
|
||||||
|
assert!(cmd.contains(&"fn $NAME() {}".to_string()));
|
||||||
|
assert!(cmd.contains(&"-l".to_string()));
|
||||||
|
assert!(cmd.contains(&"rust".to_string()));
|
||||||
|
assert!(cmd.contains(&"--json=stream".to_string()));
|
||||||
|
assert!(cmd.contains(&"-C".to_string()));
|
||||||
|
assert!(cmd.contains(&"2".to_string()));
|
||||||
|
assert!(cmd.contains(&"-j".to_string()));
|
||||||
|
assert!(cmd.contains(&"4".to_string()));
|
||||||
|
assert!(cmd.contains(&"src/".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_stream_json() {
|
||||||
|
let output = r#"{"file":"test.rs","text":"fn hello() {}"}
|
||||||
|
{"file":"test2.rs","text":"fn world() {}"}"#;
|
||||||
|
|
||||||
|
let (matches, truncated) = parse_ast_grep_output(output, 10).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(matches.len(), 2);
|
||||||
|
assert!(!truncated);
|
||||||
|
assert_eq!(matches[0]["file"], "test.rs");
|
||||||
|
assert_eq!(matches[1]["file"], "test2.rs");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_truncated_output() {
|
||||||
|
let output = r#"{"file":"test1.rs","text":"fn a() {}"}
|
||||||
|
{"file":"test2.rs","text":"fn b() {}"}
|
||||||
|
{"file":"test3.rs","text":"fn c() {}"}"#;
|
||||||
|
|
||||||
|
let (matches, truncated) = parse_ast_grep_output(output, 2).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(matches.len(), 2);
|
||||||
|
assert!(truncated);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,6 +4,11 @@
|
|||||||
// 3. Only elide JSON content between first '{' and last '}' (inclusive)
|
// 3. Only elide JSON content between first '{' and last '}' (inclusive)
|
||||||
// 4. Return everything else as the final filtered string
|
// 4. Return everything else as the final filtered string
|
||||||
|
|
||||||
|
//! JSON tool call filtering for streaming LLM responses.
|
||||||
|
//!
|
||||||
|
//! This module filters out JSON tool calls from LLM output streams while preserving
|
||||||
|
//! regular text content. It uses a state machine to handle streaming chunks.
|
||||||
|
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
use std::cell::RefCell;
|
use std::cell::RefCell;
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
@@ -13,37 +18,51 @@ thread_local! {
|
|||||||
static FIXED_JSON_TOOL_STATE: RefCell<FixedJsonToolState> = RefCell::new(FixedJsonToolState::new());
|
static FIXED_JSON_TOOL_STATE: RefCell<FixedJsonToolState> = RefCell::new(FixedJsonToolState::new());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Internal state for tracking JSON tool call filtering across streaming chunks.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct FixedJsonToolState {
|
struct FixedJsonToolState {
|
||||||
|
/// True when actively suppressing a confirmed tool call
|
||||||
suppression_mode: bool,
|
suppression_mode: bool,
|
||||||
|
/// True when buffering potential JSON (saw { but not yet confirmed as tool call)
|
||||||
|
potential_json_mode: bool,
|
||||||
|
/// Tracks nesting depth of braces within JSON
|
||||||
brace_depth: i32,
|
brace_depth: i32,
|
||||||
buffer: String,
|
buffer: String,
|
||||||
json_start_in_buffer: Option<usize>,
|
json_start_in_buffer: Option<usize>, // Position where confirmed JSON tool call starts
|
||||||
content_returned_up_to: usize, // Track how much content we've already returned
|
content_returned_up_to: usize, // Track how much content we've already returned
|
||||||
|
potential_json_start: Option<usize>, // Where the potential JSON started
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FixedJsonToolState {
|
impl FixedJsonToolState {
|
||||||
fn new() -> Self {
|
fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
suppression_mode: false,
|
suppression_mode: false,
|
||||||
|
potential_json_mode: false,
|
||||||
brace_depth: 0,
|
brace_depth: 0,
|
||||||
buffer: String::new(),
|
buffer: String::new(),
|
||||||
json_start_in_buffer: None,
|
json_start_in_buffer: None,
|
||||||
content_returned_up_to: 0,
|
content_returned_up_to: 0,
|
||||||
|
potential_json_start: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reset(&mut self) {
|
fn reset(&mut self) {
|
||||||
self.suppression_mode = false;
|
self.suppression_mode = false;
|
||||||
|
self.potential_json_mode = false;
|
||||||
self.brace_depth = 0;
|
self.brace_depth = 0;
|
||||||
self.buffer.clear();
|
self.buffer.clear();
|
||||||
self.json_start_in_buffer = None;
|
self.json_start_in_buffer = None;
|
||||||
self.content_returned_up_to = 0;
|
self.content_returned_up_to = 0;
|
||||||
|
self.potential_json_start = None;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// FINAL CORRECTED implementation according to specification
|
// FINAL CORRECTED implementation according to specification
|
||||||
|
|
||||||
|
/// Filters JSON tool calls from streaming LLM content.
|
||||||
|
///
|
||||||
|
/// Processes content chunks and removes JSON tool calls while preserving regular text.
|
||||||
|
/// Maintains state across calls to handle tool calls spanning multiple chunks.
|
||||||
pub fn fixed_filter_json_tool_calls(content: &str) -> String {
|
pub fn fixed_filter_json_tool_calls(content: &str) -> String {
|
||||||
if content.is_empty() {
|
if content.is_empty() {
|
||||||
return String::new();
|
return String::new();
|
||||||
@@ -87,13 +106,225 @@ pub fn fixed_filter_json_tool_calls(content: &str) -> String {
|
|||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CRITICAL FIX: After counting braces, if still in suppression mode,
|
||||||
|
// check if a new tool call pattern appears. This handles truncated JSON
|
||||||
|
// followed by complete JSON.
|
||||||
|
if state.suppression_mode {
|
||||||
|
let current_json_start = state.json_start_in_buffer.unwrap();
|
||||||
|
// Don't require newline - the new JSON might be concatenated directly
|
||||||
|
let tool_call_regex = Regex::new(r#"\{\s*"tool"\s*:\s*""#).unwrap();
|
||||||
|
|
||||||
|
// Look for new tool call patterns after the current one
|
||||||
|
if let Some(captures) = tool_call_regex.find(&state.buffer[current_json_start + 1..]) {
|
||||||
|
let new_json_start = current_json_start + 1 + captures.start() + captures.as_str().find('{').unwrap();
|
||||||
|
|
||||||
|
debug!("Detected new tool call at position {} while processing incomplete one at {} - discarding old", new_json_start, current_json_start);
|
||||||
|
|
||||||
|
// The previous JSON was incomplete/malformed
|
||||||
|
// Return content before the old JSON (if any)
|
||||||
|
let content_before_old_json = if current_json_start > state.content_returned_up_to {
|
||||||
|
state.buffer[state.content_returned_up_to..current_json_start].to_string()
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Update state to skip the incomplete JSON and position at the new one
|
||||||
|
// We'll process the new JSON on the next call
|
||||||
|
state.content_returned_up_to = new_json_start;
|
||||||
|
state.suppression_mode = false;
|
||||||
|
state.json_start_in_buffer = None;
|
||||||
|
state.brace_depth = 0;
|
||||||
|
|
||||||
|
return content_before_old_json;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Still in suppression mode, return empty string (content is being accumulated)
|
// Still in suppression mode, return empty string (content is being accumulated)
|
||||||
return String::new();
|
return String::new();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if we're in potential JSON mode (saw { but waiting to confirm it's a tool call)
|
||||||
|
if state.potential_json_mode {
|
||||||
|
// Check if the buffer contains a confirmed tool call pattern
|
||||||
|
let tool_call_regex = Regex::new(r#"(?m)^\s*\{\s*"tool"\s*:\s*""#).unwrap();
|
||||||
|
|
||||||
|
if let Some(captures) = tool_call_regex.find(&state.buffer) {
|
||||||
|
// Confirmed! This is a tool call - enter suppression mode
|
||||||
|
let match_text = captures.as_str();
|
||||||
|
if let Some(brace_offset) = match_text.find('{') {
|
||||||
|
let json_start = captures.start() + brace_offset;
|
||||||
|
|
||||||
|
debug!("Confirmed JSON tool call at position {} - entering suppression mode", json_start);
|
||||||
|
|
||||||
|
state.potential_json_mode = false;
|
||||||
|
state.suppression_mode = true;
|
||||||
|
state.brace_depth = 0;
|
||||||
|
state.json_start_in_buffer = Some(json_start);
|
||||||
|
|
||||||
|
// Count braces from json_start to see if JSON is complete
|
||||||
|
let buffer_slice = state.buffer[json_start..].to_string();
|
||||||
|
for ch in buffer_slice.chars() {
|
||||||
|
match ch {
|
||||||
|
'{' => state.brace_depth += 1,
|
||||||
|
'}' => {
|
||||||
|
state.brace_depth -= 1;
|
||||||
|
if state.brace_depth <= 0 {
|
||||||
|
debug!("JSON tool call completed immediately");
|
||||||
|
let result = extract_fixed_content(&state.buffer, json_start);
|
||||||
|
let new_content = if result.len() > state.content_returned_up_to {
|
||||||
|
result[state.content_returned_up_to..].to_string()
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
};
|
||||||
|
state.reset();
|
||||||
|
return new_content;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// JSON incomplete, stay in suppression mode, return nothing
|
||||||
|
return String::new();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we can rule out this being a tool call
|
||||||
|
// If we have enough content after the { and it doesn't match the pattern, release it
|
||||||
|
if let Some(potential_start) = state.potential_json_start {
|
||||||
|
let content_after_brace = &state.buffer[potential_start..];
|
||||||
|
|
||||||
|
// Rule out as a tool call if:
|
||||||
|
// 1. Closing } appears before we see the full pattern
|
||||||
|
// 2. Content clearly doesn't match the tool call pattern
|
||||||
|
// 3. Newline appears after the opening brace (tool calls should be compact)
|
||||||
|
|
||||||
|
let has_closing_brace = content_after_brace.contains('}');
|
||||||
|
let has_newline = content_after_brace[1..].contains('\n'); // Skip first char which is {
|
||||||
|
let long_enough = content_after_brace.len() >= 10;
|
||||||
|
|
||||||
|
// Detect non-tool JSON patterns:
|
||||||
|
// - { followed by " and a key that doesn't start with "tool"
|
||||||
|
// - { followed by "t" but not "to"
|
||||||
|
// - { followed by "to" but not "too", etc.
|
||||||
|
let not_tool_pattern = Regex::new(r#"^\{\s*"(?:[^t]|t(?:[^o]|o(?:[^o]|o(?:[^l]|l[^"\s:]))))"#).unwrap();
|
||||||
|
let definitely_not_tool = not_tool_pattern.is_match(content_after_brace);
|
||||||
|
|
||||||
|
if has_closing_brace || has_newline || (long_enough && definitely_not_tool) {
|
||||||
|
debug!("Potential JSON ruled out - not a tool call");
|
||||||
|
state.potential_json_mode = false;
|
||||||
|
state.potential_json_start = None;
|
||||||
|
|
||||||
|
// Return the buffered content we've been holding
|
||||||
|
let new_content = if state.buffer.len() > state.content_returned_up_to {
|
||||||
|
state.buffer[state.content_returned_up_to..].to_string()
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
};
|
||||||
|
state.content_returned_up_to = state.buffer.len();
|
||||||
|
return new_content;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Still in potential mode, keep buffering
|
||||||
|
return String::new();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect potential JSON start: { at the beginning of a line
|
||||||
|
let potential_json_regex = Regex::new(r"(?m)^\s*\{\s*").unwrap();
|
||||||
|
|
||||||
|
if let Some(captures) = potential_json_regex.find(&state.buffer[state.content_returned_up_to..]) {
|
||||||
|
let match_start = state.content_returned_up_to + captures.start();
|
||||||
|
let brace_pos = match_start + captures.as_str().find('{').unwrap();
|
||||||
|
|
||||||
|
debug!("Potential JSON detected at position {} - entering buffering mode", brace_pos);
|
||||||
|
|
||||||
|
// Fast path: check if this is already a confirmed tool call
|
||||||
|
let tool_call_regex = Regex::new(r#"(?m)^\s*\{\s*"tool"\s*:\s*""#).unwrap();
|
||||||
|
if tool_call_regex.is_match(&state.buffer[brace_pos..]) {
|
||||||
|
// This is a confirmed tool call! Process it immediately
|
||||||
|
let json_start = brace_pos;
|
||||||
|
debug!("Immediately confirmed tool call at position {}", json_start);
|
||||||
|
|
||||||
|
// Return content before JSON
|
||||||
|
let content_before = if json_start > state.content_returned_up_to {
|
||||||
|
state.buffer[state.content_returned_up_to..json_start].to_string()
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
state.content_returned_up_to = json_start;
|
||||||
|
state.suppression_mode = true;
|
||||||
|
state.brace_depth = 0;
|
||||||
|
state.json_start_in_buffer = Some(json_start);
|
||||||
|
|
||||||
|
// Count braces to see if JSON is complete
|
||||||
|
let buffer_slice = state.buffer[json_start..].to_string();
|
||||||
|
for ch in buffer_slice.chars() {
|
||||||
|
match ch {
|
||||||
|
'{' => state.brace_depth += 1,
|
||||||
|
'}' => {
|
||||||
|
state.brace_depth -= 1;
|
||||||
|
if state.brace_depth <= 0 {
|
||||||
|
debug!("JSON tool call completed in same chunk");
|
||||||
|
let result = extract_fixed_content(&state.buffer, json_start);
|
||||||
|
let content_after = if result.len() > json_start {
|
||||||
|
&result[json_start..]
|
||||||
|
} else {
|
||||||
|
""
|
||||||
|
};
|
||||||
|
let final_result = format!("{}{}", content_before, content_after);
|
||||||
|
state.reset();
|
||||||
|
return final_result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// JSON incomplete, return content before and stay in suppression mode
|
||||||
|
return content_before;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return content before the potential JSON
|
||||||
|
let content_before = if brace_pos > state.content_returned_up_to {
|
||||||
|
state.buffer[state.content_returned_up_to..brace_pos].to_string()
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
state.content_returned_up_to = brace_pos;
|
||||||
|
state.potential_json_mode = true;
|
||||||
|
state.potential_json_start = Some(brace_pos);
|
||||||
|
|
||||||
|
// Optimization: immediately check if we can rule this out for single-chunk processing
|
||||||
|
let content_after_brace = &state.buffer[brace_pos..];
|
||||||
|
let has_closing_brace = content_after_brace.contains('}');
|
||||||
|
let has_newline = content_after_brace.len() > 1 && content_after_brace[1..].contains('\n');
|
||||||
|
let long_enough = content_after_brace.len() >= 10;
|
||||||
|
|
||||||
|
let not_tool_pattern = Regex::new(r#"^\{\s*"(?:[^t]|t(?:[^o]|o(?:[^o]|o(?:[^l]|l[^"\s:]))))"#).unwrap();
|
||||||
|
let definitely_not_tool = not_tool_pattern.is_match(content_after_brace);
|
||||||
|
|
||||||
|
if has_closing_brace || has_newline || (long_enough && definitely_not_tool) {
|
||||||
|
debug!("Immediately ruled out as not a tool call");
|
||||||
|
state.potential_json_mode = false;
|
||||||
|
state.potential_json_start = None;
|
||||||
|
|
||||||
|
// Return all the buffered content
|
||||||
|
let new_content = if state.buffer.len() > state.content_returned_up_to {
|
||||||
|
state.buffer[state.content_returned_up_to..].to_string()
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
};
|
||||||
|
state.content_returned_up_to = state.buffer.len();
|
||||||
|
return format!("{}{}", content_before, new_content);
|
||||||
|
}
|
||||||
|
|
||||||
|
return content_before;
|
||||||
|
}
|
||||||
|
|
||||||
// Check for tool call pattern using corrected regex
|
// Check for tool call pattern using corrected regex
|
||||||
// More flexible than the strict specification to handle real-world JSON
|
let tool_call_regex = Regex::new(r#"(?m)^\s*\{\s*"tool"\s*:\s*"[^"]*""#).unwrap();
|
||||||
let tool_call_regex = Regex::new(r#"(?m)^\s*\{\s*"tool"\s*:\s*""#).unwrap();
|
|
||||||
|
|
||||||
if let Some(captures) = tool_call_regex.find(&state.buffer) {
|
if let Some(captures) = tool_call_regex.find(&state.buffer) {
|
||||||
let match_text = captures.as_str();
|
let match_text = captures.as_str();
|
||||||
@@ -156,21 +387,29 @@ pub fn fixed_filter_json_tool_calls(content: &str) -> String {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// No JSON tool call detected, return only the new content we haven't returned yet
|
// No JSON tool call detected, return only the new content we haven't returned yet
|
||||||
let new_content = if state.buffer.len() > state.content_returned_up_to {
|
|
||||||
|
|
||||||
|
if state.buffer.len() > state.content_returned_up_to {
|
||||||
let result = state.buffer[state.content_returned_up_to..].to_string();
|
let result = state.buffer[state.content_returned_up_to..].to_string();
|
||||||
state.content_returned_up_to = state.buffer.len();
|
state.content_returned_up_to = state.buffer.len();
|
||||||
result
|
result
|
||||||
} else {
|
} else {
|
||||||
String::new()
|
String::new()
|
||||||
};
|
}
|
||||||
|
|
||||||
new_content
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to extract content with JSON tool call filtered out
|
/// Extracts content from buffer, removing the JSON tool call.
|
||||||
// Returns everything except the JSON between the first '{' and last '}' (inclusive)
|
///
|
||||||
|
/// Given a buffer and the start position of a JSON tool call, this function:
|
||||||
|
/// 1. Extracts all content before the JSON
|
||||||
|
/// 2. Finds the end of the JSON (matching closing brace)
|
||||||
|
/// 3. Extracts all content after the JSON
|
||||||
|
/// 4. Returns the concatenation of before + after (JSON removed)
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `full_content` - The full content buffer
|
||||||
|
/// * `json_start` - Position where the JSON tool call begins
|
||||||
fn extract_fixed_content(full_content: &str, json_start: usize) -> String {
|
fn extract_fixed_content(full_content: &str, json_start: usize) -> String {
|
||||||
// Find the end of the JSON using proper brace counting with string handling
|
// Find the end of the JSON using proper brace counting with string handling
|
||||||
let mut brace_depth = 0;
|
let mut brace_depth = 0;
|
||||||
@@ -212,8 +451,10 @@ fn extract_fixed_content(full_content: &str, json_start: usize) -> String {
|
|||||||
format!("{}{}", before, after)
|
format!("{}{}", before, after)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset function for testing
|
/// Resets the global JSON filtering state.
|
||||||
|
///
|
||||||
|
/// Call this between independent filtering sessions to ensure clean state.
|
||||||
|
/// This is particularly important in tests and when starting new conversations.
|
||||||
pub fn reset_fixed_json_tool_state() {
|
pub fn reset_fixed_json_tool_state() {
|
||||||
FIXED_JSON_TOOL_STATE.with(|state| {
|
FIXED_JSON_TOOL_STATE.with(|state| {
|
||||||
let mut state = state.borrow_mut();
|
let mut state = state.borrow_mut();
|
||||||
|
|||||||
@@ -1,8 +1,14 @@
|
|||||||
|
//! Tests for JSON tool call filtering.
|
||||||
|
//!
|
||||||
|
//! These tests verify that the filter correctly identifies and removes JSON tool calls
|
||||||
|
//! from LLM output streams while preserving all other content.
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod fixed_filter_tests {
|
mod fixed_filter_tests {
|
||||||
use crate::fixed_filter_json::{fixed_filter_json_tool_calls, reset_fixed_json_tool_state};
|
use crate::fixed_filter_json::{fixed_filter_json_tool_calls, reset_fixed_json_tool_state};
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
|
|
||||||
|
/// Test that regular text without tool calls passes through unchanged.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_no_tool_call_passthrough() {
|
fn test_no_tool_call_passthrough() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -11,6 +17,7 @@ mod fixed_filter_tests {
|
|||||||
assert_eq!(result, input);
|
assert_eq!(result, input);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test detection and removal of a complete tool call in a single chunk.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_simple_tool_call_detection() {
|
fn test_simple_tool_call_detection() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -23,6 +30,7 @@ Some text after"#;
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test handling of tool calls that arrive across multiple streaming chunks.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_streaming_chunks() {
|
fn test_streaming_chunks() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -48,6 +56,7 @@ Some text after"#;
|
|||||||
assert_eq!(final_result, expected);
|
assert_eq!(final_result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test correct handling of nested braces within JSON strings.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_nested_braces_in_tool_call() {
|
fn test_nested_braces_in_tool_call() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -61,6 +70,7 @@ Text after"#;
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Verify the regex pattern matches the specification with flexible whitespace.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_regex_pattern_specification() {
|
fn test_regex_pattern_specification() {
|
||||||
// Test the corrected regex pattern that's more flexible with whitespace
|
// Test the corrected regex pattern that's more flexible with whitespace
|
||||||
@@ -84,11 +94,6 @@ Text after"#;
|
|||||||
), // Space after { DOES match with \s*
|
), // Space after { DOES match with \s*
|
||||||
(
|
(
|
||||||
r#"line
|
r#"line
|
||||||
abc{"tool":"#,
|
|
||||||
true,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r#"line
|
|
||||||
{"tool123":"#,
|
{"tool123":"#,
|
||||||
false,
|
false,
|
||||||
), // "tool123" is not exactly "tool"
|
), // "tool123" is not exactly "tool"
|
||||||
@@ -109,6 +114,7 @@ abc{"tool":"#,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test that tool calls must appear at the start of a line (after newline).
|
||||||
#[test]
|
#[test]
|
||||||
fn test_newline_requirement() {
|
fn test_newline_requirement() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -122,13 +128,14 @@ abc{"tool":"#,
|
|||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
let result2 = fixed_filter_json_tool_calls(input_without_newline);
|
let result2 = fixed_filter_json_tool_calls(input_without_newline);
|
||||||
|
|
||||||
// Both cases currently trigger suppression due to regex pattern
|
// With the new aggressive filtering, only the newline case should trigger suppression
|
||||||
// TODO: Fix regex to only match after actual newlines
|
// The pattern requires { to be at the start of a line (after ^)
|
||||||
assert_eq!(result1, "Text\n");
|
assert_eq!(result1, "Text\n");
|
||||||
// This currently fails because our regex matches both cases
|
// Without newline before {, it should pass through unchanged
|
||||||
assert_eq!(result2, "Text ");
|
assert_eq!(result2, input_without_newline);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test handling of escaped quotes within JSON strings.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_json_with_escaped_quotes() {
|
fn test_json_with_escaped_quotes() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -142,6 +149,7 @@ More text"#;
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test graceful handling of incomplete/malformed JSON.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_edge_case_malformed_json() {
|
fn test_edge_case_malformed_json() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -157,6 +165,7 @@ More text"#;
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test processing multiple independent tool calls sequentially.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_multiple_tool_calls_sequential() {
|
fn test_multiple_tool_calls_sequential() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -179,6 +188,7 @@ Final text"#;
|
|||||||
assert_eq!(result2, expected2);
|
assert_eq!(result2, expected2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test tool calls with complex multi-line arguments.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_tool_call_with_complex_args() {
|
fn test_tool_call_with_complex_args() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -192,6 +202,7 @@ After"#;
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test input containing only a tool call with no surrounding text.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_tool_call_only() {
|
fn test_tool_call_only() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -204,6 +215,7 @@ After"#;
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test accurate brace counting with deeply nested structures.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_brace_counting_accuracy() {
|
fn test_brace_counting_accuracy() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -218,6 +230,7 @@ End"#;
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test that braces within strings don't affect brace counting.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_string_escaping_in_json() {
|
fn test_string_escaping_in_json() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -232,6 +245,7 @@ More"#;
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Verify compliance with the exact specification requirements.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_specification_compliance() {
|
fn test_specification_compliance() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -248,6 +262,7 @@ More"#;
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test that non-tool JSON objects are not filtered.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_no_false_positives() {
|
fn test_no_false_positives() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -261,6 +276,7 @@ More text"#;
|
|||||||
assert_eq!(result, input);
|
assert_eq!(result, input);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test patterns that look similar to tool calls but aren't exact matches.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_partial_tool_patterns() {
|
fn test_partial_tool_patterns() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -280,6 +296,7 @@ More text"#;
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test streaming with very small chunks (character-by-character).
|
||||||
#[test]
|
#[test]
|
||||||
fn test_streaming_edge_cases() {
|
fn test_streaming_edge_cases() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -296,12 +313,13 @@ More text"#;
|
|||||||
}
|
}
|
||||||
|
|
||||||
let final_result: String = results.join("");
|
let final_result: String = results.join("");
|
||||||
// This test currently fails because the JSON is incomplete across chunks
|
// With the new aggressive filtering, the JSON should be completely filtered out
|
||||||
// The function doesn't handle this edge case properly yet
|
// even when it arrives in very small chunks
|
||||||
let expected = "Text\n{\"tool\": \nAfter";
|
let expected = "Text\n\nAfter";
|
||||||
assert_eq!(final_result, expected);
|
assert_eq!(final_result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Debug test with detailed logging for streaming behavior.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_streaming_debug() {
|
fn test_streaming_debug() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -329,4 +347,38 @@ More text"#;
|
|||||||
let expected = "Some text before\n\nText after";
|
let expected = "Some text before\n\nText after";
|
||||||
assert_eq!(final_result, expected);
|
assert_eq!(final_result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test handling of truncated JSON followed by complete JSON (the json_err pattern)
|
||||||
|
#[test]
|
||||||
|
fn test_truncated_then_complete_json() {
|
||||||
|
reset_fixed_json_tool_state();
|
||||||
|
|
||||||
|
// Simulate the pattern from json_err trace:
|
||||||
|
// 1. Incomplete/truncated JSON appears
|
||||||
|
// 2. Then the same complete JSON appears
|
||||||
|
let chunks = vec![
|
||||||
|
"Some text\n",
|
||||||
|
r#"{"tool": "str_replace", "args": {"diff":"...","file_path":"./crates/g3-cli"#, // Truncated
|
||||||
|
r#"{"tool": "str_replace", "args": {"diff":"...","file_path":"./crates/g3-cli/src/lib.rs"}}"#, // Complete
|
||||||
|
"\nMore text",
|
||||||
|
];
|
||||||
|
|
||||||
|
let mut results = Vec::new();
|
||||||
|
for (i, chunk) in chunks.iter().enumerate() {
|
||||||
|
let result = fixed_filter_json_tool_calls(chunk);
|
||||||
|
println!("Chunk {}: {:?} -> {:?}", i, chunk, result);
|
||||||
|
results.push(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
let final_result: String = results.join("");
|
||||||
|
println!("Final result: {:?}", final_result);
|
||||||
|
|
||||||
|
// The truncated JSON should be discarded when the complete one appears
|
||||||
|
// Both JSONs should be filtered out, leaving only the text
|
||||||
|
let expected = "Some text\n\nMore text";
|
||||||
|
assert_eq!(
|
||||||
|
final_result, expected,
|
||||||
|
"Failed to handle truncated JSON followed by complete JSON"
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -104,6 +104,7 @@ impl Project {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Recursively check a directory for implementation files
|
/// Recursively check a directory for implementation files
|
||||||
|
#[allow(clippy::only_used_in_recursion)]
|
||||||
fn check_dir_for_implementation_files(&self, dir: &Path) -> bool {
|
fn check_dir_for_implementation_files(&self, dir: &Path) -> bool {
|
||||||
// Common source file extensions
|
// Common source file extensions
|
||||||
let extensions = vec![
|
let extensions = vec![
|
||||||
|
|||||||
37
crates/g3-core/src/take_screenshot_test.rs
Normal file
37
crates/g3-core/src/take_screenshot_test.rs
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
// Test to verify take_screenshot requires window_id
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod take_screenshot_tests {
|
||||||
|
use super::*;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_take_screenshot_requires_window_id() {
|
||||||
|
// Create a tool call without window_id
|
||||||
|
let tool_call = ToolCall {
|
||||||
|
tool: "take_screenshot".to_string(),
|
||||||
|
args: json!({
|
||||||
|
"path": "test.png"
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Verify that window_id is missing
|
||||||
|
assert!(tool_call.args.get("window_id").is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_take_screenshot_with_window_id() {
|
||||||
|
// Create a tool call with window_id
|
||||||
|
let tool_call = ToolCall {
|
||||||
|
tool: "take_screenshot".to_string(),
|
||||||
|
args: json!({
|
||||||
|
"path": "test.png",
|
||||||
|
"window_id": "Safari"
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Verify that window_id is present
|
||||||
|
assert!(tool_call.args.get("window_id").is_some());
|
||||||
|
assert_eq!(tool_call.args.get("window_id").unwrap().as_str().unwrap(), "Safari");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -17,6 +17,9 @@ pub trait UiWriter: Send + Sync {
|
|||||||
/// Print a context window status message
|
/// Print a context window status message
|
||||||
fn print_context_status(&self, message: &str);
|
fn print_context_status(&self, message: &str);
|
||||||
|
|
||||||
|
/// Print a context thinning success message with highlight and animation
|
||||||
|
fn print_context_thinning(&self, message: &str);
|
||||||
|
|
||||||
/// Print a tool execution header
|
/// Print a tool execution header
|
||||||
fn print_tool_header(&self, tool_name: &str);
|
fn print_tool_header(&self, tool_name: &str);
|
||||||
|
|
||||||
@@ -49,6 +52,10 @@ pub trait UiWriter: Send + Sync {
|
|||||||
|
|
||||||
/// Flush any buffered output
|
/// Flush any buffered output
|
||||||
fn flush(&self);
|
fn flush(&self);
|
||||||
|
|
||||||
|
/// Returns true if this UI writer wants full, untruncated output
|
||||||
|
/// Default is false (truncate for human readability)
|
||||||
|
fn wants_full_output(&self) -> bool { false }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A no-op implementation for when UI output is not needed
|
/// A no-op implementation for when UI output is not needed
|
||||||
@@ -60,6 +67,7 @@ impl UiWriter for NullUiWriter {
|
|||||||
fn print_inline(&self, _message: &str) {}
|
fn print_inline(&self, _message: &str) {}
|
||||||
fn print_system_prompt(&self, _prompt: &str) {}
|
fn print_system_prompt(&self, _prompt: &str) {}
|
||||||
fn print_context_status(&self, _message: &str) {}
|
fn print_context_status(&self, _message: &str) {}
|
||||||
|
fn print_context_thinning(&self, _message: &str) {}
|
||||||
fn print_tool_header(&self, _tool_name: &str) {}
|
fn print_tool_header(&self, _tool_name: &str) {}
|
||||||
fn print_tool_arg(&self, _key: &str, _value: &str) {}
|
fn print_tool_arg(&self, _key: &str, _value: &str) {}
|
||||||
fn print_tool_output_header(&self) {}
|
fn print_tool_output_header(&self) {}
|
||||||
@@ -71,4 +79,5 @@ impl UiWriter for NullUiWriter {
|
|||||||
fn print_agent_response(&self, _content: &str) {}
|
fn print_agent_response(&self, _content: &str) {}
|
||||||
fn notify_sse_received(&self) {}
|
fn notify_sse_received(&self) {}
|
||||||
fn flush(&self) {}
|
fn flush(&self) {}
|
||||||
|
fn wants_full_output(&self) -> bool { false }
|
||||||
}
|
}
|
||||||
@@ -72,7 +72,7 @@ fn test_thin_context_basic() {
|
|||||||
|
|
||||||
// Trigger thinning at 50%
|
// Trigger thinning at 50%
|
||||||
context.used_tokens = 5000;
|
context.used_tokens = 5000;
|
||||||
let summary = context.thin_context();
|
let (summary, _chars_saved) = context.thin_context();
|
||||||
|
|
||||||
println!("Thinning summary: {}", summary);
|
println!("Thinning summary: {}", summary);
|
||||||
|
|
||||||
@@ -93,6 +93,119 @@ fn test_thin_context_basic() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_thin_write_file_tool_calls() {
|
||||||
|
let mut context = ContextWindow::new(10000);
|
||||||
|
|
||||||
|
// Add some messages including a write_file tool call with large content
|
||||||
|
context.add_message(Message {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: "Please create a large file".to_string(),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Add an assistant message with a write_file tool call containing large content
|
||||||
|
let large_content = "x".repeat(1500);
|
||||||
|
let tool_call_json = format!(
|
||||||
|
r#"{{"tool": "write_file", "args": {{"file_path": "test.txt", "content": "{}"}}}}"#,
|
||||||
|
large_content
|
||||||
|
);
|
||||||
|
context.add_message(Message {
|
||||||
|
role: MessageRole::Assistant,
|
||||||
|
content: format!("I'll create that file.\n\n{}", tool_call_json),
|
||||||
|
});
|
||||||
|
|
||||||
|
context.add_message(Message {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: "Tool result: ✅ Successfully wrote 1500 lines".to_string(),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Add more messages to ensure we have enough for "first third" logic
|
||||||
|
for i in 0..6 {
|
||||||
|
context.add_message(Message {
|
||||||
|
role: MessageRole::Assistant,
|
||||||
|
content: format!("Response {}", i),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trigger thinning at 50%
|
||||||
|
context.used_tokens = 5000;
|
||||||
|
let (summary, _chars_saved) = context.thin_context();
|
||||||
|
|
||||||
|
println!("Thinning summary: {}", summary);
|
||||||
|
|
||||||
|
// Should have thinned the write_file tool call
|
||||||
|
assert!(summary.contains("tool call") || summary.contains("chars saved"));
|
||||||
|
|
||||||
|
// Check that the large content was replaced with a file reference
|
||||||
|
let first_third_end = context.conversation_history.len() / 3;
|
||||||
|
for i in 0..first_third_end {
|
||||||
|
if let Some(msg) = context.conversation_history.get(i) {
|
||||||
|
if matches!(msg.role, MessageRole::Assistant) && msg.content.contains("write_file") {
|
||||||
|
// The content should now reference an external file
|
||||||
|
assert!(msg.content.contains("<content saved to"));
|
||||||
|
assert!(!msg.content.contains(&large_content));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_thin_str_replace_tool_calls() {
|
||||||
|
let mut context = ContextWindow::new(10000);
|
||||||
|
|
||||||
|
// Add some messages including a str_replace tool call with large diff
|
||||||
|
context.add_message(Message {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: "Please update the file".to_string(),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Add an assistant message with a str_replace tool call containing large diff
|
||||||
|
let large_diff = format!("--- old\n{}\n+++ new\n{}", "-old line\n".repeat(100), "+new line\n".repeat(100));
|
||||||
|
let tool_call_json = format!(
|
||||||
|
r#"{{"tool": "str_replace", "args": {{"file_path": "test.txt", "diff": "{}"}}}}"#,
|
||||||
|
large_diff.replace('\n', "\\n")
|
||||||
|
);
|
||||||
|
context.add_message(Message {
|
||||||
|
role: MessageRole::Assistant,
|
||||||
|
content: format!("I'll update that file.\n\n{}", tool_call_json),
|
||||||
|
});
|
||||||
|
|
||||||
|
context.add_message(Message {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: "Tool result: ✅ applied unified diff".to_string(),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Add more messages to ensure we have enough for "first third" logic
|
||||||
|
for i in 0..6 {
|
||||||
|
context.add_message(Message {
|
||||||
|
role: MessageRole::Assistant,
|
||||||
|
content: format!("Response {}", i),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trigger thinning at 50%
|
||||||
|
context.used_tokens = 5000;
|
||||||
|
let (summary, _chars_saved) = context.thin_context();
|
||||||
|
|
||||||
|
println!("Thinning summary: {}", summary);
|
||||||
|
|
||||||
|
// Should have thinned the str_replace tool call
|
||||||
|
assert!(summary.contains("tool call") || summary.contains("chars saved"));
|
||||||
|
|
||||||
|
// Check that the large diff was replaced with a file reference
|
||||||
|
let first_third_end = context.conversation_history.len() / 3;
|
||||||
|
for i in 0..first_third_end {
|
||||||
|
if let Some(msg) = context.conversation_history.get(i) {
|
||||||
|
if matches!(msg.role, MessageRole::Assistant) && msg.content.contains("str_replace") {
|
||||||
|
// The diff should now reference an external file
|
||||||
|
assert!(msg.content.contains("<diff saved to"));
|
||||||
|
// Should not contain the large diff content
|
||||||
|
assert!(!msg.content.contains("old line"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_thin_context_no_large_results() {
|
fn test_thin_context_no_large_results() {
|
||||||
let mut context = ContextWindow::new(10000);
|
let mut context = ContextWindow::new(10000);
|
||||||
@@ -106,10 +219,10 @@ fn test_thin_context_no_large_results() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
context.used_tokens = 5000;
|
context.used_tokens = 5000;
|
||||||
let summary = context.thin_context();
|
let (summary, _chars_saved) = context.thin_context();
|
||||||
|
|
||||||
// Should report no large results found
|
// Should report no large results found
|
||||||
assert!(summary.contains("no large tool results found"));
|
assert!(summary.contains("no large tool results or tool calls found"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -135,7 +248,7 @@ fn test_thin_context_only_affects_first_third() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
context.used_tokens = 5000;
|
context.used_tokens = 5000;
|
||||||
let summary = context.thin_context();
|
let (summary, _chars_saved) = context.thin_context();
|
||||||
|
|
||||||
// First third is 4 messages (indices 0-3), so only indices 1 and 3 should be thinned
|
// First third is 4 messages (indices 0-3), so only indices 1 and 3 should be thinned
|
||||||
// That's 2 tool results
|
// That's 2 tool results
|
||||||
|
|||||||
@@ -166,6 +166,31 @@ impl CodeExecutor {
|
|||||||
|
|
||||||
/// Execute Bash code
|
/// Execute Bash code
|
||||||
async fn execute_bash(&self, code: &str) -> Result<ExecutionResult> {
|
async fn execute_bash(&self, code: &str) -> Result<ExecutionResult> {
|
||||||
|
// Check if this is a detached/daemon command that should run independently
|
||||||
|
let is_detached = code.trim_start().starts_with("setsid ")
|
||||||
|
|| code.trim_start().starts_with("nohup ")
|
||||||
|
|| code.contains(" disown")
|
||||||
|
|| (code.contains(" &") && (code.contains("nohup") || code.contains("setsid")));
|
||||||
|
|
||||||
|
if is_detached {
|
||||||
|
// For detached commands, just spawn and return immediately
|
||||||
|
use std::process::Stdio;
|
||||||
|
Command::new("bash")
|
||||||
|
.arg("-c")
|
||||||
|
.arg(code)
|
||||||
|
.stdin(Stdio::null())
|
||||||
|
.stdout(Stdio::null())
|
||||||
|
.stderr(Stdio::null())
|
||||||
|
.spawn()?;
|
||||||
|
|
||||||
|
return Ok(ExecutionResult {
|
||||||
|
stdout: "✅ Command launched in background (detached process)".to_string(),
|
||||||
|
stderr: String::new(),
|
||||||
|
exit_code: 0,
|
||||||
|
success: true,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
let output = Command::new("bash")
|
let output = Command::new("bash")
|
||||||
.arg("-c")
|
.arg("-c")
|
||||||
.arg(code)
|
.arg(code)
|
||||||
@@ -221,6 +246,29 @@ impl CodeExecutor {
|
|||||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||||
use tokio::process::Command as TokioCommand;
|
use tokio::process::Command as TokioCommand;
|
||||||
|
|
||||||
|
// Check if this is a detached/daemon command that should run independently
|
||||||
|
// Look for patterns like: setsid, nohup with &, or explicit backgrounding with disown
|
||||||
|
let is_detached = code.trim_start().starts_with("setsid ")
|
||||||
|
|| code.trim_start().starts_with("nohup ")
|
||||||
|
|| code.contains(" disown")
|
||||||
|
|| (code.contains(" &") && (code.contains("nohup") || code.contains("setsid")));
|
||||||
|
|
||||||
|
if is_detached {
|
||||||
|
// For detached commands, just spawn and return immediately
|
||||||
|
TokioCommand::new("bash")
|
||||||
|
.arg("-c")
|
||||||
|
.arg(code)
|
||||||
|
.spawn()?;
|
||||||
|
|
||||||
|
// Don't wait for the process - it's meant to run independently
|
||||||
|
return Ok(ExecutionResult {
|
||||||
|
stdout: "✅ Command launched in background (detached process)".to_string(),
|
||||||
|
stderr: String::new(),
|
||||||
|
exit_code: 0,
|
||||||
|
success: true,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
let mut child = TokioCommand::new("bash")
|
let mut child = TokioCommand::new("bash")
|
||||||
.arg("-c")
|
.arg("-c")
|
||||||
.arg(code)
|
.arg(code)
|
||||||
@@ -259,7 +307,7 @@ impl CodeExecutor {
|
|||||||
line = stderr_lines.next_line() => {
|
line = stderr_lines.next_line() => {
|
||||||
match line {
|
match line {
|
||||||
Ok(Some(line)) => {
|
Ok(Some(line)) => {
|
||||||
receiver.on_output_line(&format!("{}", line));
|
receiver.on_output_line(&line.to_string());
|
||||||
stderr_output.push(line);
|
stderr_output.push(line);
|
||||||
}
|
}
|
||||||
Ok(None) => {}, // stderr EOF, continue
|
Ok(None) => {}, // stderr EOF, continue
|
||||||
|
|||||||
@@ -156,8 +156,9 @@ impl AnthropicProvider {
|
|||||||
.post(ANTHROPIC_API_URL)
|
.post(ANTHROPIC_API_URL)
|
||||||
.header("x-api-key", &self.api_key)
|
.header("x-api-key", &self.api_key)
|
||||||
.header("anthropic-version", ANTHROPIC_VERSION)
|
.header("anthropic-version", ANTHROPIC_VERSION)
|
||||||
|
// Anthropic beta 1m context window. Enable if needed. It costs extra, so check first.
|
||||||
|
// .header("anthropic-beta", "context-1m-2025-08-07")
|
||||||
.header("content-type", "application/json");
|
.header("content-type", "application/json");
|
||||||
|
|
||||||
if streaming {
|
if streaming {
|
||||||
builder = builder.header("accept", "text/event-stream");
|
builder = builder.header("accept", "text/event-stream");
|
||||||
}
|
}
|
||||||
@@ -275,6 +276,7 @@ impl AnthropicProvider {
|
|||||||
let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
|
let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
|
||||||
let mut accumulated_usage: Option<Usage> = None;
|
let mut accumulated_usage: Option<Usage> = None;
|
||||||
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
|
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
|
||||||
|
let mut message_stopped = false; // Track if we've received message_stop
|
||||||
|
|
||||||
while let Some(chunk_result) = stream.next().await {
|
while let Some(chunk_result) = stream.next().await {
|
||||||
match chunk_result {
|
match chunk_result {
|
||||||
@@ -315,6 +317,12 @@ impl AnthropicProvider {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we've already sent the final chunk, skip processing more events
|
||||||
|
if message_stopped {
|
||||||
|
debug!("Skipping event after message_stop: {}", line);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// Parse Server-Sent Events format
|
// Parse Server-Sent Events format
|
||||||
if let Some(data) = line.strip_prefix("data: ") {
|
if let Some(data) = line.strip_prefix("data: ") {
|
||||||
if data == "[DONE]" {
|
if data == "[DONE]" {
|
||||||
@@ -450,6 +458,7 @@ impl AnthropicProvider {
|
|||||||
}
|
}
|
||||||
"message_stop" => {
|
"message_stop" => {
|
||||||
debug!("Received message stop event");
|
debug!("Received message stop event");
|
||||||
|
message_stopped = true;
|
||||||
let final_chunk = CompletionChunk {
|
let final_chunk = CompletionChunk {
|
||||||
content: String::new(),
|
content: String::new(),
|
||||||
finished: true,
|
finished: true,
|
||||||
@@ -459,7 +468,8 @@ impl AnthropicProvider {
|
|||||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||||
debug!("Receiver dropped, stopping stream");
|
debug!("Receiver dropped, stopping stream");
|
||||||
}
|
}
|
||||||
return accumulated_usage;
|
// Don't return here - let the stream naturally exhaust
|
||||||
|
// This prevents dropping the sender prematurely
|
||||||
}
|
}
|
||||||
"error" => {
|
"error" => {
|
||||||
if let Some(error) = event.error {
|
if let Some(error) = event.error {
|
||||||
@@ -467,7 +477,7 @@ impl AnthropicProvider {
|
|||||||
let _ = tx
|
let _ = tx
|
||||||
.send(Err(anyhow!("Anthropic API error: {:?}", error)))
|
.send(Err(anyhow!("Anthropic API error: {:?}", error)))
|
||||||
.await;
|
.await;
|
||||||
return accumulated_usage;
|
break; // Break to let stream exhaust naturally
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
@@ -486,7 +496,10 @@ impl AnthropicProvider {
|
|||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Stream error: {}", e);
|
error!("Stream error: {}", e);
|
||||||
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
|
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
|
||||||
return accumulated_usage;
|
// Don't return here either - let the stream exhaust naturally
|
||||||
|
// The error has been sent to the receiver, so it will handle it
|
||||||
|
// Breaking here ensures we clean up properly
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -213,7 +213,7 @@ impl DatabricksProvider {
|
|||||||
|
|
||||||
let mut builder = self
|
let mut builder = self
|
||||||
.client
|
.client
|
||||||
.post(&format!(
|
.post(format!(
|
||||||
"{}/serving-endpoints/{}/invocations",
|
"{}/serving-endpoints/{}/invocations",
|
||||||
self.host, self.model
|
self.host, self.model
|
||||||
))
|
))
|
||||||
@@ -298,6 +298,7 @@ impl DatabricksProvider {
|
|||||||
let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> =
|
let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> =
|
||||||
std::collections::HashMap::new(); // index -> (id, name, args)
|
std::collections::HashMap::new(); // index -> (id, name, args)
|
||||||
let mut incomplete_data_line = String::new(); // Buffer for incomplete data: lines
|
let mut incomplete_data_line = String::new(); // Buffer for incomplete data: lines
|
||||||
|
let mut chunk_count = 0;
|
||||||
let accumulated_usage: Option<Usage> = None;
|
let accumulated_usage: Option<Usage> = None;
|
||||||
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
|
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
|
||||||
|
|
||||||
@@ -305,6 +306,8 @@ impl DatabricksProvider {
|
|||||||
match chunk_result {
|
match chunk_result {
|
||||||
Ok(chunk) => {
|
Ok(chunk) => {
|
||||||
// Debug: Log raw bytes received
|
// Debug: Log raw bytes received
|
||||||
|
chunk_count += 1;
|
||||||
|
debug!("Processing chunk #{}", chunk_count);
|
||||||
debug!("Raw SSE bytes received: {} bytes", chunk.len());
|
debug!("Raw SSE bytes received: {} bytes", chunk.len());
|
||||||
|
|
||||||
// Append new bytes to our buffer
|
// Append new bytes to our buffer
|
||||||
@@ -589,13 +592,39 @@ impl DatabricksProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Stream error: {}", e);
|
error!("Stream error at chunk {}: {}", chunk_count, e);
|
||||||
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
|
|
||||||
|
// Check if this is a connection error that might be recoverable
|
||||||
|
let error_msg = e.to_string();
|
||||||
|
if error_msg.contains("unexpected EOF") || error_msg.contains("connection") {
|
||||||
|
warn!("Connection terminated unexpectedly at chunk {}, treating as end of stream", chunk_count);
|
||||||
|
// Don't send error, just break and finalize
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
|
||||||
|
}
|
||||||
return accumulated_usage;
|
return accumulated_usage;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Log final state
|
||||||
|
debug!("Stream ended after {} chunks", chunk_count);
|
||||||
|
debug!("Final state: buffer_len={}, incomplete_data_line_len={}, byte_buffer_len={}",
|
||||||
|
buffer.len(), incomplete_data_line.len(), byte_buffer.len());
|
||||||
|
debug!("Accumulated tool calls: {}", current_tool_calls.len());
|
||||||
|
|
||||||
|
// If we have any remaining data in buffers, log it for debugging
|
||||||
|
if !buffer.is_empty() {
|
||||||
|
debug!("Remaining buffer content: {:?}", buffer);
|
||||||
|
}
|
||||||
|
if !byte_buffer.is_empty() {
|
||||||
|
debug!("Remaining byte buffer: {} bytes", byte_buffer.len());
|
||||||
|
}
|
||||||
|
if !incomplete_data_line.is_empty() {
|
||||||
|
debug!("Remaining incomplete data line: {:?}", incomplete_data_line);
|
||||||
|
}
|
||||||
|
|
||||||
// If we have any incomplete data line at the end, try to process it
|
// If we have any incomplete data line at the end, try to process it
|
||||||
if !incomplete_data_line.is_empty() {
|
if !incomplete_data_line.is_empty() {
|
||||||
debug!(
|
debug!(
|
||||||
@@ -881,6 +910,14 @@ impl LLMProvider for DatabricksProvider {
|
|||||||
"Processing Databricks streaming request with {} messages",
|
"Processing Databricks streaming request with {} messages",
|
||||||
request.messages.len()
|
request.messages.len()
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Debug: Log tool count
|
||||||
|
if let Some(ref tools) = request.tools {
|
||||||
|
debug!("Request has {} tools", tools.len());
|
||||||
|
for tool in tools.iter().take(5) {
|
||||||
|
debug!(" Tool: {}", tool.name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
|
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
|
||||||
let temperature = request.temperature.unwrap_or(self.temperature);
|
let temperature = request.temperature.unwrap_or(self.temperature);
|
||||||
|
|||||||
@@ -88,10 +88,14 @@ pub mod anthropic;
|
|||||||
pub mod databricks;
|
pub mod databricks;
|
||||||
pub mod embedded;
|
pub mod embedded;
|
||||||
pub mod oauth;
|
pub mod oauth;
|
||||||
|
pub mod ollama;
|
||||||
|
pub mod openai;
|
||||||
|
|
||||||
pub use anthropic::AnthropicProvider;
|
pub use anthropic::AnthropicProvider;
|
||||||
pub use databricks::DatabricksProvider;
|
pub use databricks::DatabricksProvider;
|
||||||
pub use embedded::EmbeddedProvider;
|
pub use embedded::EmbeddedProvider;
|
||||||
|
pub use ollama::OllamaProvider;
|
||||||
|
pub use openai::OpenAIProvider;
|
||||||
|
|
||||||
/// Provider registry for managing multiple LLM providers
|
/// Provider registry for managing multiple LLM providers
|
||||||
pub struct ProviderRegistry {
|
pub struct ProviderRegistry {
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ async fn get_workspace_endpoints(host: &str) -> Result<OidcEndpoints> {
|
|||||||
if !resp.status().is_success() {
|
if !resp.status().is_success() {
|
||||||
return Err(anyhow::anyhow!(
|
return Err(anyhow::anyhow!(
|
||||||
"Failed to get OIDC configuration from {}",
|
"Failed to get OIDC configuration from {}",
|
||||||
oidc_url.to_string()
|
oidc_url
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
751
crates/g3-providers/src/ollama.rs
Normal file
751
crates/g3-providers/src/ollama.rs
Normal file
@@ -0,0 +1,751 @@
|
|||||||
|
//! Ollama LLM provider implementation for the g3-providers crate.
|
||||||
|
//!
|
||||||
|
//! This module provides an implementation of the `LLMProvider` trait for Ollama,
|
||||||
|
//! supporting both completion and streaming modes with native tool calling.
|
||||||
|
//!
|
||||||
|
//! # Features
|
||||||
|
//!
|
||||||
|
//! - Support for any Ollama model (llama3.2, mistral, qwen, etc.)
|
||||||
|
//! - Both completion and streaming response modes
|
||||||
|
//! - Native tool calling support for compatible models
|
||||||
|
//! - Configurable base URL (defaults to http://localhost:11434)
|
||||||
|
//! - Simple configuration with no authentication required
|
||||||
|
//!
|
||||||
|
//! # Usage
|
||||||
|
//!
|
||||||
|
//! ```rust,no_run
|
||||||
|
//! use g3_providers::{OllamaProvider, LLMProvider, CompletionRequest, Message, MessageRole};
|
||||||
|
//!
|
||||||
|
//! #[tokio::main]
|
||||||
|
//! async fn main() -> anyhow::Result<()> {
|
||||||
|
//! // Create the provider with default settings (localhost:11434)
|
||||||
|
//! let provider = OllamaProvider::new(
|
||||||
|
//! "llama3.2".to_string(),
|
||||||
|
//! None, // Optional: base_url
|
||||||
|
//! None, // Optional: max tokens
|
||||||
|
//! None, // Optional: temperature
|
||||||
|
//! )?;
|
||||||
|
//!
|
||||||
|
//! // Create a completion request
|
||||||
|
//! let request = CompletionRequest {
|
||||||
|
//! messages: vec![
|
||||||
|
//! Message {
|
||||||
|
//! role: MessageRole::User,
|
||||||
|
//! content: "Hello! How are you?".to_string(),
|
||||||
|
//! },
|
||||||
|
//! ],
|
||||||
|
//! max_tokens: Some(1000),
|
||||||
|
//! temperature: Some(0.7),
|
||||||
|
//! stream: false,
|
||||||
|
//! tools: None,
|
||||||
|
//! };
|
||||||
|
//!
|
||||||
|
//! // Get a completion
|
||||||
|
//! let response = provider.complete(request).await?;
|
||||||
|
//! println!("Response: {}", response.content);
|
||||||
|
//!
|
||||||
|
//! Ok(())
|
||||||
|
//! }
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use bytes::Bytes;
|
||||||
|
use futures_util::stream::StreamExt;
|
||||||
|
use reqwest::Client;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tokio_stream::wrappers::ReceiverStream;
|
||||||
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
|
||||||
|
MessageRole, Tool, ToolCall, Usage,
|
||||||
|
};
|
||||||
|
|
||||||
|
const DEFAULT_BASE_URL: &str = "http://localhost:11434";
|
||||||
|
const DEFAULT_TIMEOUT_SECS: u64 = 600;
|
||||||
|
|
||||||
|
pub const OLLAMA_DEFAULT_MODEL: &str = "llama3.2";
|
||||||
|
pub const OLLAMA_KNOWN_MODELS: &[&str] = &[
|
||||||
|
"llama3.2",
|
||||||
|
"llama3.2:1b",
|
||||||
|
"llama3.2:3b",
|
||||||
|
"llama3.1",
|
||||||
|
"llama3.1:8b",
|
||||||
|
"llama3.1:70b",
|
||||||
|
"mistral",
|
||||||
|
"mistral-nemo",
|
||||||
|
"mixtral",
|
||||||
|
"qwen2.5",
|
||||||
|
"qwen2.5:7b",
|
||||||
|
"qwen2.5:14b",
|
||||||
|
"qwen2.5:32b",
|
||||||
|
"qwen2.5-coder",
|
||||||
|
"qwen2.5-coder:7b",
|
||||||
|
"qwen3-coder",
|
||||||
|
"phi3",
|
||||||
|
"gemma2",
|
||||||
|
];
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct OllamaProvider {
|
||||||
|
client: Client,
|
||||||
|
base_url: String,
|
||||||
|
model: String,
|
||||||
|
max_tokens: Option<u32>,
|
||||||
|
temperature: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OllamaProvider {
|
||||||
|
pub fn new(
|
||||||
|
model: String,
|
||||||
|
base_url: Option<String>,
|
||||||
|
max_tokens: Option<u32>,
|
||||||
|
temperature: Option<f32>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let client = Client::builder()
|
||||||
|
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
|
||||||
|
.build()
|
||||||
|
.map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
|
||||||
|
|
||||||
|
let base_url = base_url
|
||||||
|
.unwrap_or_else(|| DEFAULT_BASE_URL.to_string())
|
||||||
|
.trim_end_matches('/')
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Initialized Ollama provider with model: {} at {}",
|
||||||
|
model, base_url
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
client,
|
||||||
|
base_url,
|
||||||
|
model,
|
||||||
|
max_tokens,
|
||||||
|
temperature: temperature.unwrap_or(0.7),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_tools(&self, tools: &[Tool]) -> Vec<OllamaTool> {
|
||||||
|
tools
|
||||||
|
.iter()
|
||||||
|
.map(|tool| OllamaTool {
|
||||||
|
r#type: "function".to_string(),
|
||||||
|
function: OllamaFunction {
|
||||||
|
name: tool.name.clone(),
|
||||||
|
description: tool.description.clone(),
|
||||||
|
parameters: tool.input_schema.clone(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_messages(&self, messages: &[Message]) -> Result<Vec<OllamaMessage>> {
|
||||||
|
let mut ollama_messages = Vec::new();
|
||||||
|
|
||||||
|
for message in messages {
|
||||||
|
let role = match message.role {
|
||||||
|
MessageRole::System => "system",
|
||||||
|
MessageRole::User => "user",
|
||||||
|
MessageRole::Assistant => "assistant",
|
||||||
|
};
|
||||||
|
|
||||||
|
ollama_messages.push(OllamaMessage {
|
||||||
|
role: role.to_string(),
|
||||||
|
content: message.content.clone(),
|
||||||
|
tool_calls: None, // Only used in responses
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if ollama_messages.is_empty() {
|
||||||
|
return Err(anyhow!("At least one message is required"));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ollama_messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_request_body(
|
||||||
|
&self,
|
||||||
|
messages: &[Message],
|
||||||
|
tools: Option<&[Tool]>,
|
||||||
|
streaming: bool,
|
||||||
|
max_tokens: Option<u32>,
|
||||||
|
temperature: f32,
|
||||||
|
) -> Result<OllamaRequest> {
|
||||||
|
let ollama_messages = self.convert_messages(messages)?;
|
||||||
|
let ollama_tools = tools.map(|t| self.convert_tools(t));
|
||||||
|
|
||||||
|
let mut options = OllamaOptions {
|
||||||
|
temperature,
|
||||||
|
num_predict: max_tokens,
|
||||||
|
};
|
||||||
|
|
||||||
|
// If max_tokens is provided, use it; otherwise use the instance default
|
||||||
|
if max_tokens.is_none() {
|
||||||
|
options.num_predict = self.max_tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
let request = OllamaRequest {
|
||||||
|
model: self.model.clone(),
|
||||||
|
messages: ollama_messages,
|
||||||
|
tools: ollama_tools,
|
||||||
|
stream: streaming,
|
||||||
|
options,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn parse_streaming_response(
|
||||||
|
&self,
|
||||||
|
mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
|
||||||
|
tx: mpsc::Sender<Result<CompletionChunk>>,
|
||||||
|
) -> Option<Usage> {
|
||||||
|
let mut buffer = String::new();
|
||||||
|
let mut accumulated_usage: Option<Usage> = None;
|
||||||
|
let mut current_tool_calls: Vec<OllamaToolCall> = Vec::new();
|
||||||
|
let mut byte_buffer = Vec::new();
|
||||||
|
|
||||||
|
while let Some(chunk_result) = stream.next().await {
|
||||||
|
match chunk_result {
|
||||||
|
Ok(chunk) => {
|
||||||
|
// Append new bytes to our buffer
|
||||||
|
byte_buffer.extend_from_slice(&chunk);
|
||||||
|
|
||||||
|
// Try to convert the entire buffer to UTF-8
|
||||||
|
let chunk_str = match std::str::from_utf8(&byte_buffer) {
|
||||||
|
Ok(s) => {
|
||||||
|
let result = s.to_string();
|
||||||
|
byte_buffer.clear();
|
||||||
|
result
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let valid_up_to = e.valid_up_to();
|
||||||
|
if valid_up_to > 0 {
|
||||||
|
let valid_bytes =
|
||||||
|
byte_buffer.drain(..valid_up_to).collect::<Vec<_>>();
|
||||||
|
std::str::from_utf8(&valid_bytes).unwrap().to_string()
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
buffer.push_str(&chunk_str);
|
||||||
|
|
||||||
|
// Process complete lines
|
||||||
|
while let Some(line_end) = buffer.find('\n') {
|
||||||
|
let line = buffer[..line_end].trim().to_string();
|
||||||
|
buffer.drain(..line_end + 1);
|
||||||
|
|
||||||
|
if line.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ollama streaming sends JSON objects per line
|
||||||
|
match serde_json::from_str::<OllamaStreamChunk>(&line) {
|
||||||
|
Ok(chunk) => {
|
||||||
|
// Handle text content
|
||||||
|
if let Some(message) = &chunk.message {
|
||||||
|
let content = &message.content;
|
||||||
|
if !content.is_empty() {
|
||||||
|
debug!("Sending text chunk: '{}'", content);
|
||||||
|
let chunk = CompletionChunk {
|
||||||
|
content: content.clone(),
|
||||||
|
finished: false,
|
||||||
|
usage: None,
|
||||||
|
tool_calls: None,
|
||||||
|
};
|
||||||
|
if tx.send(Ok(chunk)).await.is_err() {
|
||||||
|
debug!("Receiver dropped, stopping stream");
|
||||||
|
return accumulated_usage;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle tool calls
|
||||||
|
if let Some(tool_calls) = &message.tool_calls {
|
||||||
|
current_tool_calls.extend(tool_calls.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if stream is done
|
||||||
|
if chunk.done.unwrap_or(false) {
|
||||||
|
debug!("Stream completed");
|
||||||
|
|
||||||
|
// Update usage if available
|
||||||
|
if let Some(eval_count) = chunk.eval_count {
|
||||||
|
accumulated_usage = Some(Usage {
|
||||||
|
prompt_tokens: chunk.prompt_eval_count.unwrap_or(0),
|
||||||
|
completion_tokens: eval_count,
|
||||||
|
total_tokens: chunk.prompt_eval_count.unwrap_or(0)
|
||||||
|
+ eval_count,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send final chunk with tool calls if any
|
||||||
|
let final_tool_calls: Vec<ToolCall> = current_tool_calls
|
||||||
|
.iter()
|
||||||
|
.map(|tc| ToolCall {
|
||||||
|
id: tc.function.name.clone(), // Ollama doesn't provide IDs
|
||||||
|
tool: tc.function.name.clone(),
|
||||||
|
args: tc.function.arguments.clone(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let final_chunk = CompletionChunk {
|
||||||
|
content: String::new(),
|
||||||
|
finished: true,
|
||||||
|
usage: accumulated_usage.clone(),
|
||||||
|
tool_calls: if final_tool_calls.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(final_tool_calls)
|
||||||
|
},
|
||||||
|
};
|
||||||
|
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||||
|
debug!("Receiver dropped, stopping stream");
|
||||||
|
}
|
||||||
|
return accumulated_usage;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
debug!("Failed to parse Ollama stream chunk: {} - Line: {}", e, line);
|
||||||
|
// Don't error out, just continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Stream error: {}", e);
|
||||||
|
let error_msg = e.to_string();
|
||||||
|
if error_msg.contains("unexpected EOF") || error_msg.contains("connection") {
|
||||||
|
warn!("Connection terminated unexpectedly, treating as end of stream");
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
|
||||||
|
}
|
||||||
|
return accumulated_usage;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send final chunk if we haven't already
|
||||||
|
let final_tool_calls: Vec<ToolCall> = current_tool_calls
|
||||||
|
.iter()
|
||||||
|
.map(|tc| ToolCall {
|
||||||
|
id: tc.function.name.clone(),
|
||||||
|
tool: tc.function.name.clone(),
|
||||||
|
args: tc.function.arguments.clone(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let final_chunk = CompletionChunk {
|
||||||
|
content: String::new(),
|
||||||
|
finished: true,
|
||||||
|
usage: accumulated_usage.clone(),
|
||||||
|
tool_calls: if final_tool_calls.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(final_tool_calls)
|
||||||
|
},
|
||||||
|
};
|
||||||
|
let _ = tx.send(Ok(final_chunk)).await;
|
||||||
|
accumulated_usage
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fetch available models from the Ollama instance
|
||||||
|
pub async fn fetch_available_models(&self) -> Result<Vec<String>> {
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.get(format!("{}/api/tags", self.base_url))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("Failed to fetch Ollama models: {}", e))?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
let status = response.status();
|
||||||
|
let error_text = response
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||||
|
return Err(anyhow!(
|
||||||
|
"Failed to fetch Ollama models: {} - {}",
|
||||||
|
status,
|
||||||
|
error_text
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let json: serde_json::Value = response.json().await?;
|
||||||
|
let models = json
|
||||||
|
.get("models")
|
||||||
|
.and_then(|v| v.as_array())
|
||||||
|
.ok_or_else(|| anyhow!("Unexpected response format: missing 'models' array"))?;
|
||||||
|
|
||||||
|
let model_names: Vec<String> = models
|
||||||
|
.iter()
|
||||||
|
.filter_map(|model| model.get("name").and_then(|n| n.as_str()).map(String::from))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
debug!("Found {} models in Ollama", model_names.len());
|
||||||
|
Ok(model_names)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl LLMProvider for OllamaProvider {
|
||||||
|
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
|
||||||
|
debug!(
|
||||||
|
"Processing Ollama completion request with {} messages",
|
||||||
|
request.messages.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
let max_tokens = request.max_tokens.or(self.max_tokens);
|
||||||
|
let temperature = request.temperature.unwrap_or(self.temperature);
|
||||||
|
|
||||||
|
let request_body = self.create_request_body(
|
||||||
|
&request.messages,
|
||||||
|
request.tools.as_deref(),
|
||||||
|
false,
|
||||||
|
max_tokens,
|
||||||
|
temperature,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Sending request to Ollama API: model={}, temperature={}",
|
||||||
|
self.model, request_body.options.temperature
|
||||||
|
);
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post(format!("{}/api/chat", self.base_url))
|
||||||
|
.json(&request_body)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("Failed to send request to Ollama API: {}", e))?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
if !status.is_success() {
|
||||||
|
let error_text = response
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||||
|
return Err(anyhow!("Ollama API error {}: {}", status, error_text));
|
||||||
|
}
|
||||||
|
|
||||||
|
let response_text = response.text().await?;
|
||||||
|
debug!("Raw Ollama API response: {}", response_text);
|
||||||
|
|
||||||
|
let ollama_response: OllamaResponse =
|
||||||
|
serde_json::from_str(&response_text).map_err(|e| {
|
||||||
|
anyhow!(
|
||||||
|
"Failed to parse Ollama response: {} - Response: {}",
|
||||||
|
e,
|
||||||
|
response_text
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let content = ollama_response.message.content.clone();
|
||||||
|
|
||||||
|
let usage = Usage {
|
||||||
|
prompt_tokens: ollama_response.prompt_eval_count.unwrap_or(0),
|
||||||
|
completion_tokens: ollama_response.eval_count.unwrap_or(0),
|
||||||
|
total_tokens: ollama_response.prompt_eval_count.unwrap_or(0)
|
||||||
|
+ ollama_response.eval_count.unwrap_or(0),
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Ollama completion successful: {} tokens generated",
|
||||||
|
usage.completion_tokens
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(CompletionResponse {
|
||||||
|
content,
|
||||||
|
usage,
|
||||||
|
model: self.model.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
|
||||||
|
debug!(
|
||||||
|
"Processing Ollama request (non-streaming) with {} messages",
|
||||||
|
request.messages.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(ref tools) = request.tools {
|
||||||
|
debug!("Request has {} tools", tools.len());
|
||||||
|
for tool in tools.iter().take(5) {
|
||||||
|
debug!(" Tool: {}", tool.name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let max_tokens = request.max_tokens.or(self.max_tokens);
|
||||||
|
let temperature = request.temperature.unwrap_or(self.temperature);
|
||||||
|
|
||||||
|
let request_body = self.create_request_body(
|
||||||
|
&request.messages,
|
||||||
|
request.tools.as_deref(),
|
||||||
|
false, // Use non-streaming mode to avoid streaming bugs
|
||||||
|
max_tokens,
|
||||||
|
temperature,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Sending request to Ollama API (stream=false): model={}, temperature={}",
|
||||||
|
self.model, request_body.options.temperature
|
||||||
|
);
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post(format!("{}/api/chat", self.base_url))
|
||||||
|
.json(&request_body)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("Failed to send request to Ollama API: {}", e))?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
if !status.is_success() {
|
||||||
|
let error_text = response
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||||
|
return Err(anyhow!("Ollama API error {}: {}", status, error_text));
|
||||||
|
}
|
||||||
|
|
||||||
|
// For non-streaming, parse the complete JSON response
|
||||||
|
let response_text = response.text().await?;
|
||||||
|
debug!("Raw Ollama API response: {}", response_text);
|
||||||
|
|
||||||
|
let ollama_response: OllamaResponse =
|
||||||
|
serde_json::from_str(&response_text).map_err(|e| {
|
||||||
|
anyhow!(
|
||||||
|
"Failed to parse Ollama response: {} - Response: {}",
|
||||||
|
e,
|
||||||
|
response_text
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let (tx, rx) = mpsc::channel(100);
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let content = ollama_response.message.content;
|
||||||
|
let usage = Usage {
|
||||||
|
prompt_tokens: ollama_response.prompt_eval_count.unwrap_or(0),
|
||||||
|
completion_tokens: ollama_response.eval_count.unwrap_or(0),
|
||||||
|
total_tokens: ollama_response.prompt_eval_count.unwrap_or(0)
|
||||||
|
+ ollama_response.eval_count.unwrap_or(0),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Extract tool calls if present
|
||||||
|
let tool_calls: Option<Vec<ToolCall>> = ollama_response.message.tool_calls.map(|tcs| {
|
||||||
|
tcs.iter()
|
||||||
|
.map(|tc| ToolCall {
|
||||||
|
id: tc.function.name.clone(),
|
||||||
|
tool: tc.function.name.clone(),
|
||||||
|
args: tc.function.arguments.clone(),
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
});
|
||||||
|
|
||||||
|
// Send content if any
|
||||||
|
if !content.is_empty() {
|
||||||
|
let _ = tx.send(Ok(CompletionChunk {
|
||||||
|
content,
|
||||||
|
finished: false,
|
||||||
|
usage: None,
|
||||||
|
tool_calls: None,
|
||||||
|
})).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send final chunk with usage and tool calls
|
||||||
|
let _ = tx.send(Ok(CompletionChunk {
|
||||||
|
content: String::new(),
|
||||||
|
finished: true,
|
||||||
|
usage: Some(usage),
|
||||||
|
tool_calls,
|
||||||
|
})).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(ReceiverStream::new(rx))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"ollama"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn model(&self) -> &str {
|
||||||
|
&self.model
|
||||||
|
}
|
||||||
|
|
||||||
|
fn has_native_tool_calling(&self) -> bool {
|
||||||
|
// Most modern Ollama models support tool calling
|
||||||
|
// Models like llama3.2, qwen2.5, mistral, etc. have good tool support
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ollama API request/response structures
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct OllamaRequest {
|
||||||
|
model: String,
|
||||||
|
messages: Vec<OllamaMessage>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tools: Option<Vec<OllamaTool>>,
|
||||||
|
stream: bool,
|
||||||
|
options: OllamaOptions,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct OllamaOptions {
|
||||||
|
temperature: f32,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
num_predict: Option<u32>, // Ollama's equivalent of max_tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct OllamaTool {
|
||||||
|
r#type: String,
|
||||||
|
function: OllamaFunction,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct OllamaFunction {
|
||||||
|
name: String,
|
||||||
|
description: String,
|
||||||
|
parameters: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
struct OllamaMessage {
|
||||||
|
role: String,
|
||||||
|
content: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tool_calls: Option<Vec<OllamaToolCall>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
struct OllamaToolCall {
|
||||||
|
function: OllamaToolCallFunction,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
struct OllamaToolCallFunction {
|
||||||
|
name: String,
|
||||||
|
arguments: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OllamaResponse {
|
||||||
|
message: OllamaMessage,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
done: bool,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
total_duration: Option<u64>,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
load_duration: Option<u64>,
|
||||||
|
prompt_eval_count: Option<u32>,
|
||||||
|
eval_count: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OllamaStreamChunk {
|
||||||
|
message: Option<OllamaMessage>,
|
||||||
|
done: Option<bool>,
|
||||||
|
prompt_eval_count: Option<u32>,
|
||||||
|
eval_count: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_provider_creation() {
|
||||||
|
let provider = OllamaProvider::new(
|
||||||
|
"llama3.2".to_string(),
|
||||||
|
None,
|
||||||
|
Some(1000),
|
||||||
|
Some(0.7),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(provider.model(), "llama3.2");
|
||||||
|
assert_eq!(provider.name(), "ollama");
|
||||||
|
assert!(provider.has_native_tool_calling());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_message_conversion() {
|
||||||
|
let provider = OllamaProvider::new(
|
||||||
|
"llama3.2".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let messages = vec![
|
||||||
|
Message {
|
||||||
|
role: MessageRole::System,
|
||||||
|
content: "You are a helpful assistant.".to_string(),
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: "Hello!".to_string(),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
let ollama_messages = provider.convert_messages(&messages).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(ollama_messages.len(), 2);
|
||||||
|
assert_eq!(ollama_messages[0].role, "system");
|
||||||
|
assert_eq!(ollama_messages[1].role, "user");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tool_conversion() {
|
||||||
|
let provider = OllamaProvider::new(
|
||||||
|
"llama3.2".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let tools = vec![Tool {
|
||||||
|
name: "get_weather".to_string(),
|
||||||
|
description: "Get the current weather".to_string(),
|
||||||
|
input_schema: serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location"]
|
||||||
|
}),
|
||||||
|
}];
|
||||||
|
|
||||||
|
let ollama_tools = provider.convert_tools(&tools);
|
||||||
|
|
||||||
|
assert_eq!(ollama_tools.len(), 1);
|
||||||
|
assert_eq!(ollama_tools[0].r#type, "function");
|
||||||
|
assert_eq!(ollama_tools[0].function.name, "get_weather");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_custom_base_url() {
|
||||||
|
let provider = OllamaProvider::new(
|
||||||
|
"llama3.2".to_string(),
|
||||||
|
Some("http://custom:11434".to_string()),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(provider.base_url, "http://custom:11434");
|
||||||
|
}
|
||||||
|
}
|
||||||
495
crates/g3-providers/src/openai.rs
Normal file
495
crates/g3-providers/src/openai.rs
Normal file
@@ -0,0 +1,495 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use bytes::Bytes;
|
||||||
|
use futures_util::stream::StreamExt;
|
||||||
|
use reqwest::Client;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use serde_json::json;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tokio_stream::wrappers::ReceiverStream;
|
||||||
|
use tracing::{debug, error};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider,
|
||||||
|
Message, MessageRole, Tool, ToolCall, Usage,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct OpenAIProvider {
|
||||||
|
client: Client,
|
||||||
|
api_key: String,
|
||||||
|
model: String,
|
||||||
|
base_url: String,
|
||||||
|
max_tokens: Option<u32>,
|
||||||
|
_temperature: Option<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenAIProvider {
|
||||||
|
pub fn new(
|
||||||
|
api_key: String,
|
||||||
|
model: Option<String>,
|
||||||
|
base_url: Option<String>,
|
||||||
|
max_tokens: Option<u32>,
|
||||||
|
temperature: Option<f32>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Ok(Self {
|
||||||
|
client: Client::new(),
|
||||||
|
api_key,
|
||||||
|
model: model.unwrap_or_else(|| "gpt-4o".to_string()),
|
||||||
|
base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
|
||||||
|
max_tokens,
|
||||||
|
_temperature: temperature,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_request_body(
|
||||||
|
&self,
|
||||||
|
messages: &[Message],
|
||||||
|
tools: Option<&[Tool]>,
|
||||||
|
stream: bool,
|
||||||
|
max_tokens: Option<u32>,
|
||||||
|
_temperature: Option<f32>,
|
||||||
|
) -> serde_json::Value {
|
||||||
|
let mut body = json!({
|
||||||
|
"model": self.model,
|
||||||
|
"messages": convert_messages(messages),
|
||||||
|
"stream": stream,
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Some(max_tokens) = max_tokens.or(self.max_tokens) {
|
||||||
|
body["max_completion_tokens"] = json!(max_tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenAI calls with temp setting seem to fail, so don't send one.
|
||||||
|
// if let Some(temperature) = temperature.or(self.temperature) {
|
||||||
|
// body["temperature"] = json!(temperature);
|
||||||
|
// }
|
||||||
|
|
||||||
|
if let Some(tools) = tools {
|
||||||
|
if !tools.is_empty() {
|
||||||
|
body["tools"] = json!(convert_tools(tools));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if stream {
|
||||||
|
body["stream_options"] = json!({
|
||||||
|
"include_usage": true,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
body
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn parse_streaming_response(
|
||||||
|
&self,
|
||||||
|
mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
|
||||||
|
tx: mpsc::Sender<Result<CompletionChunk>>,
|
||||||
|
) -> Option<Usage> {
|
||||||
|
let mut buffer = String::new();
|
||||||
|
let mut accumulated_content = String::new();
|
||||||
|
let mut accumulated_usage: Option<Usage> = None;
|
||||||
|
let mut current_tool_calls: Vec<OpenAIStreamingToolCall> = Vec::new();
|
||||||
|
|
||||||
|
while let Some(chunk_result) = stream.next().await {
|
||||||
|
match chunk_result {
|
||||||
|
Ok(chunk) => {
|
||||||
|
let chunk_str = match std::str::from_utf8(&chunk) {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to parse chunk as UTF-8: {}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
buffer.push_str(chunk_str);
|
||||||
|
|
||||||
|
// Process complete lines
|
||||||
|
while let Some(line_end) = buffer.find('\n') {
|
||||||
|
let line = buffer[..line_end].trim().to_string();
|
||||||
|
buffer.drain(..line_end + 1);
|
||||||
|
|
||||||
|
if line.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse Server-Sent Events format
|
||||||
|
if let Some(data) = line.strip_prefix("data: ") {
|
||||||
|
if data == "[DONE]" {
|
||||||
|
debug!("Received stream completion marker");
|
||||||
|
|
||||||
|
// Send final chunk with accumulated content and tool calls
|
||||||
|
if !accumulated_content.is_empty() || !current_tool_calls.is_empty() {
|
||||||
|
let tool_calls = if current_tool_calls.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(
|
||||||
|
current_tool_calls
|
||||||
|
.iter()
|
||||||
|
.filter_map(|tc| tc.to_tool_call())
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let final_chunk = CompletionChunk {
|
||||||
|
content: accumulated_content.clone(),
|
||||||
|
finished: true,
|
||||||
|
tool_calls,
|
||||||
|
usage: accumulated_usage.clone(),
|
||||||
|
};
|
||||||
|
let _ = tx.send(Ok(final_chunk)).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
return accumulated_usage;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the JSON data
|
||||||
|
match serde_json::from_str::<OpenAIStreamChunk>(data) {
|
||||||
|
Ok(chunk_data) => {
|
||||||
|
// Handle content
|
||||||
|
for choice in &chunk_data.choices {
|
||||||
|
if let Some(content) = &choice.delta.content {
|
||||||
|
accumulated_content.push_str(content);
|
||||||
|
|
||||||
|
let chunk = CompletionChunk {
|
||||||
|
content: content.clone(),
|
||||||
|
finished: false,
|
||||||
|
tool_calls: None,
|
||||||
|
usage: None,
|
||||||
|
};
|
||||||
|
if tx.send(Ok(chunk)).await.is_err() {
|
||||||
|
debug!("Receiver dropped, stopping stream");
|
||||||
|
return accumulated_usage;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle tool calls
|
||||||
|
if let Some(delta_tool_calls) = &choice.delta.tool_calls {
|
||||||
|
for delta_tool_call in delta_tool_calls {
|
||||||
|
if let Some(index) = delta_tool_call.index {
|
||||||
|
// Ensure we have enough tool calls in our vector
|
||||||
|
while current_tool_calls.len() <= index {
|
||||||
|
current_tool_calls
|
||||||
|
.push(OpenAIStreamingToolCall::default());
|
||||||
|
}
|
||||||
|
|
||||||
|
let tool_call = &mut current_tool_calls[index];
|
||||||
|
|
||||||
|
if let Some(id) = &delta_tool_call.id {
|
||||||
|
tool_call.id = Some(id.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(function) = &delta_tool_call.function {
|
||||||
|
if let Some(name) = &function.name {
|
||||||
|
tool_call.name = Some(name.clone());
|
||||||
|
}
|
||||||
|
if let Some(arguments) = &function.arguments {
|
||||||
|
tool_call.arguments.push_str(arguments);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle usage
|
||||||
|
if let Some(usage) = chunk_data.usage {
|
||||||
|
accumulated_usage = Some(Usage {
|
||||||
|
prompt_tokens: usage.prompt_tokens,
|
||||||
|
completion_tokens: usage.completion_tokens,
|
||||||
|
total_tokens: usage.total_tokens,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Stream error: {}", e);
|
||||||
|
let _ = tx.send(Err(anyhow::anyhow!("Stream error: {}", e))).await;
|
||||||
|
return accumulated_usage;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send final chunk if we haven't already
|
||||||
|
let tool_calls = if current_tool_calls.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(
|
||||||
|
current_tool_calls
|
||||||
|
.iter()
|
||||||
|
.filter_map(|tc| tc.to_tool_call())
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let final_chunk = CompletionChunk {
|
||||||
|
content: String::new(),
|
||||||
|
finished: true,
|
||||||
|
tool_calls,
|
||||||
|
usage: accumulated_usage.clone(),
|
||||||
|
};
|
||||||
|
let _ = tx.send(Ok(final_chunk)).await;
|
||||||
|
|
||||||
|
accumulated_usage
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl LLMProvider for OpenAIProvider {
|
||||||
|
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
|
||||||
|
debug!(
|
||||||
|
"Processing OpenAI completion request with {} messages",
|
||||||
|
request.messages.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
let body = self.create_request_body(
|
||||||
|
&request.messages,
|
||||||
|
request.tools.as_deref(),
|
||||||
|
false,
|
||||||
|
request.max_tokens,
|
||||||
|
request.temperature,
|
||||||
|
);
|
||||||
|
|
||||||
|
debug!("Sending request to OpenAI API: model={}", self.model);
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post(format!("{}/chat/completions", self.base_url))
|
||||||
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
|
.json(&body)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
if !status.is_success() {
|
||||||
|
let error_text = response
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||||
|
return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text));
|
||||||
|
}
|
||||||
|
|
||||||
|
let openai_response: OpenAIResponse = response.json().await?;
|
||||||
|
|
||||||
|
let content = openai_response
|
||||||
|
.choices
|
||||||
|
.first()
|
||||||
|
.and_then(|choice| choice.message.content.clone())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
let usage = Usage {
|
||||||
|
prompt_tokens: openai_response.usage.prompt_tokens,
|
||||||
|
completion_tokens: openai_response.usage.completion_tokens,
|
||||||
|
total_tokens: openai_response.usage.total_tokens,
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"OpenAI completion successful: {} tokens generated",
|
||||||
|
usage.completion_tokens
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(CompletionResponse {
|
||||||
|
content,
|
||||||
|
usage,
|
||||||
|
model: self.model.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
|
||||||
|
debug!(
|
||||||
|
"Processing OpenAI streaming request with {} messages",
|
||||||
|
request.messages.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
let body = self.create_request_body(
|
||||||
|
&request.messages,
|
||||||
|
request.tools.as_deref(),
|
||||||
|
true,
|
||||||
|
request.max_tokens,
|
||||||
|
request.temperature,
|
||||||
|
);
|
||||||
|
|
||||||
|
debug!("Sending streaming request to OpenAI API: model={}", self.model);
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post(format!("{}/chat/completions", self.base_url))
|
||||||
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
|
.json(&body)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
if !status.is_success() {
|
||||||
|
let error_text = response
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||||
|
return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text));
|
||||||
|
}
|
||||||
|
|
||||||
|
let stream = response.bytes_stream();
|
||||||
|
let (tx, rx) = mpsc::channel(100);
|
||||||
|
|
||||||
|
// Spawn task to process the stream
|
||||||
|
let provider = self.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let usage = provider.parse_streaming_response(stream, tx).await;
|
||||||
|
// Log the final usage if available
|
||||||
|
if let Some(usage) = usage {
|
||||||
|
debug!(
|
||||||
|
"Stream completed with usage - prompt: {}, completion: {}, total: {}",
|
||||||
|
usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(ReceiverStream::new(rx))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"openai"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn model(&self) -> &str {
|
||||||
|
&self.model
|
||||||
|
}
|
||||||
|
|
||||||
|
fn has_native_tool_calling(&self) -> bool {
|
||||||
|
// OpenAI models support native tool calling
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_messages(messages: &[Message]) -> Vec<serde_json::Value> {
|
||||||
|
messages
|
||||||
|
.iter()
|
||||||
|
.map(|msg| {
|
||||||
|
json!({
|
||||||
|
"role": match msg.role {
|
||||||
|
MessageRole::System => "system",
|
||||||
|
MessageRole::User => "user",
|
||||||
|
MessageRole::Assistant => "assistant",
|
||||||
|
},
|
||||||
|
"content": msg.content,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_tools(tools: &[Tool]) -> Vec<serde_json::Value> {
|
||||||
|
tools
|
||||||
|
.iter()
|
||||||
|
.map(|tool| {
|
||||||
|
json!({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tool.name,
|
||||||
|
"description": tool.description,
|
||||||
|
"parameters": tool.input_schema,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenAI API response structures
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OpenAIResponse {
|
||||||
|
choices: Vec<OpenAIChoice>,
|
||||||
|
usage: OpenAIUsage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OpenAIChoice {
|
||||||
|
message: OpenAIMessage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OpenAIMessage {
|
||||||
|
content: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
tool_calls: Option<Vec<OpenAIToolCall>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OpenAIToolCall {
|
||||||
|
id: String,
|
||||||
|
function: OpenAIFunction,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OpenAIFunction {
|
||||||
|
name: String,
|
||||||
|
arguments: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Streaming tool call accumulator
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
struct OpenAIStreamingToolCall {
|
||||||
|
id: Option<String>,
|
||||||
|
name: Option<String>,
|
||||||
|
arguments: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenAIStreamingToolCall {
|
||||||
|
fn to_tool_call(&self) -> Option<ToolCall> {
|
||||||
|
let id = self.id.as_ref()?;
|
||||||
|
let name = self.name.as_ref()?;
|
||||||
|
|
||||||
|
let args = serde_json::from_str(&self.arguments).unwrap_or(serde_json::Value::Null);
|
||||||
|
|
||||||
|
Some(ToolCall {
|
||||||
|
id: id.clone(),
|
||||||
|
tool: name.clone(),
|
||||||
|
args,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OpenAIUsage {
|
||||||
|
prompt_tokens: u32,
|
||||||
|
completion_tokens: u32,
|
||||||
|
total_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Streaming response structures
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OpenAIStreamChunk {
|
||||||
|
choices: Vec<OpenAIStreamChoice>,
|
||||||
|
usage: Option<OpenAIUsage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OpenAIStreamChoice {
|
||||||
|
delta: OpenAIDelta,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OpenAIDelta {
|
||||||
|
content: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
tool_calls: Option<Vec<OpenAIDeltaToolCall>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OpenAIDeltaToolCall {
|
||||||
|
index: Option<usize>,
|
||||||
|
id: Option<String>,
|
||||||
|
function: Option<OpenAIDeltaFunction>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OpenAIDeltaFunction {
|
||||||
|
name: Option<String>,
|
||||||
|
arguments: Option<String>,
|
||||||
|
}
|
||||||
389
docs/ACCUMULATIVE_MODE.md
Normal file
389
docs/ACCUMULATIVE_MODE.md
Normal file
@@ -0,0 +1,389 @@
|
|||||||
|
# Accumulative Autonomous Mode
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Accumulative Autonomous Mode is the **new default interactive mode** for G3. It combines the ease of interactive chat with the power of autonomous implementation, allowing you to build projects iteratively by describing what you want, one requirement at a time.
|
||||||
|
|
||||||
|
## How It Works
|
||||||
|
|
||||||
|
### The Flow
|
||||||
|
|
||||||
|
1. **Start G3** in any directory (no arguments needed)
|
||||||
|
2. **Describe** what you want to build
|
||||||
|
3. **G3 automatically**:
|
||||||
|
- Adds your input to accumulated requirements
|
||||||
|
- Runs autonomous mode (coach-player feedback loop)
|
||||||
|
- Implements your requirements with quality checks
|
||||||
|
4. **Continue** adding more requirements or refinements
|
||||||
|
5. **Repeat** until your project is complete
|
||||||
|
|
||||||
|
### Example Session
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cd ~/projects/my-new-app
|
||||||
|
$ g3
|
||||||
|
|
||||||
|
🪿 G3 AI Coding Agent - Accumulative Mode
|
||||||
|
>> describe what you want, I'll build it iteratively
|
||||||
|
|
||||||
|
📁 Workspace: /Users/you/projects/my-new-app
|
||||||
|
|
||||||
|
💡 Each input you provide will be added to requirements
|
||||||
|
and I'll automatically work on implementing them.
|
||||||
|
|
||||||
|
Type 'exit' or 'quit' to stop, Ctrl+D to finish
|
||||||
|
|
||||||
|
============================================================
|
||||||
|
📝 What would you like me to build? (describe your requirements)
|
||||||
|
============================================================
|
||||||
|
requirement> create a simple web server in Python with Flask that serves a homepage
|
||||||
|
|
||||||
|
📋 Current instructions and requirements (Turn 1):
|
||||||
|
create a simple web server in Python with Flask that serves a homepage
|
||||||
|
|
||||||
|
🚀 Starting autonomous implementation...
|
||||||
|
|
||||||
|
🤖 G3 AI Coding Agent - Autonomous Mode
|
||||||
|
📁 Using workspace: /Users/you/projects/my-new-app
|
||||||
|
📋 Requirements loaded from --requirements flag
|
||||||
|
🔄 Starting coach-player feedback loop...
|
||||||
|
📂 No existing implementation files detected
|
||||||
|
🎯 Starting with player implementation
|
||||||
|
|
||||||
|
=== TURN 1/5 - PLAYER MODE ===
|
||||||
|
🎯 Starting player implementation...
|
||||||
|
📋 Player starting initial implementation (no prior coach feedback)
|
||||||
|
|
||||||
|
[Player creates files, writes code...]
|
||||||
|
|
||||||
|
=== TURN 1/5 - COACH MODE ===
|
||||||
|
🎓 Starting coach review...
|
||||||
|
🎓 Coach review completed
|
||||||
|
Coach feedback:
|
||||||
|
The Flask server is implemented correctly with a homepage route.
|
||||||
|
The code follows best practices and meets the requirements.
|
||||||
|
IMPLEMENTATION_APPROVED
|
||||||
|
|
||||||
|
=== SESSION COMPLETED - IMPLEMENTATION APPROVED ===
|
||||||
|
✅ Coach approved the implementation!
|
||||||
|
|
||||||
|
============================================================
|
||||||
|
📊 AUTONOMOUS MODE SESSION REPORT
|
||||||
|
============================================================
|
||||||
|
⏱️ Total Duration: 12.34s
|
||||||
|
🔄 Turns Taken: 1/5
|
||||||
|
📝 Final Status: ✅ APPROVED
|
||||||
|
...
|
||||||
|
============================================================
|
||||||
|
|
||||||
|
✅ Autonomous run completed
|
||||||
|
|
||||||
|
============================================================
|
||||||
|
📝 Turn 2 - What's next? (add more requirements or refinements)
|
||||||
|
============================================================
|
||||||
|
requirement> add a /api/users endpoint that returns a list of users as JSON
|
||||||
|
|
||||||
|
📋 Current instructions and requirements (Turn 2):
|
||||||
|
add a /api/users endpoint that returns a list of users as JSON
|
||||||
|
|
||||||
|
🚀 Starting autonomous implementation...
|
||||||
|
|
||||||
|
[Autonomous mode runs again with BOTH requirements...]
|
||||||
|
|
||||||
|
============================================================
|
||||||
|
📝 Turn 3 - What's next? (add more requirements or refinements)
|
||||||
|
============================================================
|
||||||
|
requirement> exit
|
||||||
|
|
||||||
|
👋 Goodbye!
|
||||||
|
```
|
||||||
|
|
||||||
|
## Key Features
|
||||||
|
|
||||||
|
### 1. Requirement Accumulation
|
||||||
|
|
||||||
|
Each input you provide is:
|
||||||
|
- **Numbered sequentially** (1, 2, 3, ...)
|
||||||
|
- **Stored in memory** for the session
|
||||||
|
- **Included in all subsequent runs**
|
||||||
|
|
||||||
|
This means the agent always has the full context of what you've asked for.
|
||||||
|
|
||||||
|
### 2. Automatic Requirements Document
|
||||||
|
|
||||||
|
G3 automatically generates a structured requirements document:
|
||||||
|
|
||||||
|
```markdown
|
||||||
|
# Project Requirements
|
||||||
|
|
||||||
|
## Current Instructions and Requirements:
|
||||||
|
|
||||||
|
1. create a simple web server in Python with Flask that serves a homepage
|
||||||
|
2. add a /api/users endpoint that returns a list of users as JSON
|
||||||
|
3. add error handling for 404 and 500 errors
|
||||||
|
|
||||||
|
## Latest Requirement (Turn 3):
|
||||||
|
|
||||||
|
add error handling for 404 and 500 errors
|
||||||
|
```
|
||||||
|
|
||||||
|
This document is passed to autonomous mode, ensuring the agent:
|
||||||
|
- Knows all previous requirements
|
||||||
|
- Focuses on the latest addition
|
||||||
|
- Maintains consistency across iterations
|
||||||
|
|
||||||
|
### 3. Full Autonomous Quality
|
||||||
|
|
||||||
|
Each requirement triggers a complete autonomous run with:
|
||||||
|
- **Coach-Player Feedback Loop**: Quality assurance built-in
|
||||||
|
- **Multiple Turns**: Up to 5 iterations per requirement (configurable with `--max-turns`)
|
||||||
|
- **Compilation Checks**: Ensures code actually works
|
||||||
|
- **Testing**: Coach can run tests to verify functionality
|
||||||
|
|
||||||
|
### 4. Error Recovery
|
||||||
|
|
||||||
|
If an autonomous run fails:
|
||||||
|
- You're notified of the error
|
||||||
|
- You can provide additional requirements to fix issues
|
||||||
|
- The session continues (doesn't crash)
|
||||||
|
|
||||||
|
### 5. Workspace Management
|
||||||
|
|
||||||
|
- Uses **current directory** as workspace
|
||||||
|
- All files created in current directory
|
||||||
|
- No need to specify workspace path
|
||||||
|
- Works with existing projects or empty directories
|
||||||
|
|
||||||
|
## Command-Line Options
|
||||||
|
|
||||||
|
### Default (Accumulative Mode)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
g3
|
||||||
|
```
|
||||||
|
|
||||||
|
Starts accumulative autonomous mode in the current directory.
|
||||||
|
|
||||||
|
### With Options
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Use a specific workspace
|
||||||
|
g3 --workspace ~/projects/my-app
|
||||||
|
|
||||||
|
# Limit autonomous turns per requirement
|
||||||
|
g3 --max-turns 3
|
||||||
|
|
||||||
|
# Enable macOS Accessibility tools
|
||||||
|
g3 --macax
|
||||||
|
|
||||||
|
# Enable WebDriver browser automation
|
||||||
|
g3 --webdriver
|
||||||
|
|
||||||
|
# Use a specific provider/model
|
||||||
|
g3 --provider anthropic --model claude-3-5-sonnet-20241022
|
||||||
|
|
||||||
|
# Show prompts and code during execution
|
||||||
|
g3 --show-prompt --show-code
|
||||||
|
|
||||||
|
# Disable log files
|
||||||
|
g3 --quiet
|
||||||
|
```
|
||||||
|
|
||||||
|
### Disable Accumulative Mode
|
||||||
|
|
||||||
|
To use the traditional chat mode (without automatic autonomous runs):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
g3 --chat
|
||||||
|
|
||||||
|
# Alternative: legacy flag also works
|
||||||
|
g3 --accumulative
|
||||||
|
```
|
||||||
|
|
||||||
|
This gives you the old behavior where you chat with the agent without automatic autonomous runs.
|
||||||
|
|
||||||
|
## Use Cases
|
||||||
|
|
||||||
|
### 1. Rapid Prototyping
|
||||||
|
|
||||||
|
```bash
|
||||||
|
requirement> create a REST API for a todo app
|
||||||
|
requirement> add SQLite database storage
|
||||||
|
requirement> add authentication with JWT
|
||||||
|
requirement> add rate limiting
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Iterative Refinement
|
||||||
|
|
||||||
|
```bash
|
||||||
|
requirement> create a data visualization dashboard
|
||||||
|
requirement> make the charts interactive
|
||||||
|
requirement> add dark mode support
|
||||||
|
requirement> optimize for mobile devices
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Bug Fixing
|
||||||
|
|
||||||
|
```bash
|
||||||
|
requirement> fix the login form validation
|
||||||
|
requirement> handle edge case when username is empty
|
||||||
|
requirement> add better error messages
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Feature Addition
|
||||||
|
|
||||||
|
```bash
|
||||||
|
requirement> add export to CSV functionality
|
||||||
|
requirement> add email notifications
|
||||||
|
requirement> add admin dashboard
|
||||||
|
```
|
||||||
|
|
||||||
|
## Tips and Best Practices
|
||||||
|
|
||||||
|
### 1. Start Simple
|
||||||
|
|
||||||
|
Begin with a basic requirement, let it be implemented, then add complexity:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
✅ Good:
|
||||||
|
requirement> create a basic Flask web server
|
||||||
|
requirement> add a homepage with a form
|
||||||
|
requirement> add form validation
|
||||||
|
|
||||||
|
❌ Too Complex:
|
||||||
|
requirement> create a full-stack web app with authentication, database, API, and frontend
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Be Specific
|
||||||
|
|
||||||
|
The more specific you are, the better the results:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
✅ Good:
|
||||||
|
requirement> add a /api/users endpoint that returns JSON with id, name, and email fields
|
||||||
|
|
||||||
|
❌ Vague:
|
||||||
|
requirement> add users
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. One Thing at a Time
|
||||||
|
|
||||||
|
Focus each requirement on a single feature or fix:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
✅ Good:
|
||||||
|
requirement> add error handling for database connections
|
||||||
|
requirement> add logging for all API requests
|
||||||
|
|
||||||
|
❌ Multiple Things:
|
||||||
|
requirement> add error handling and logging and monitoring and alerts
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Review Between Turns
|
||||||
|
|
||||||
|
After each autonomous run completes:
|
||||||
|
- Check the generated files
|
||||||
|
- Test the functionality
|
||||||
|
- Decide what to add or fix next
|
||||||
|
|
||||||
|
### 5. Use Exit Commands
|
||||||
|
|
||||||
|
When done:
|
||||||
|
- Type `exit` or `quit`
|
||||||
|
- Press `Ctrl+D` (EOF)
|
||||||
|
- Press `Ctrl+C` to cancel current input
|
||||||
|
|
||||||
|
## Comparison with Other Modes
|
||||||
|
|
||||||
|
| Feature | Accumulative (Default) | Traditional Interactive | Autonomous | Single-Shot |
|
||||||
|
|---------|----------------------|------------------------|------------|-------------|
|
||||||
|
| **Command** | `g3` | `g3 --accumulative` | `g3 --autonomous` | `g3 "task"` |
|
||||||
|
| **Input Style** | Iterative prompts | Chat messages | requirements.md file | Command-line arg |
|
||||||
|
| **Auto-Autonomous** | ✅ Yes | ❌ No | ✅ Yes | ❌ No |
|
||||||
|
| **Coach-Player Loop** | ✅ Yes | ❌ No | ✅ Yes | ❌ No |
|
||||||
|
| **Accumulates Requirements** | ✅ Yes | ❌ No | ❌ No | ❌ No |
|
||||||
|
| **Multiple Iterations** | ✅ Yes | ✅ Yes | ✅ Yes | ❌ No |
|
||||||
|
| **Best For** | Iterative development | Quick questions | Pre-planned projects | One-off tasks |
|
||||||
|
|
||||||
|
## Technical Details
|
||||||
|
|
||||||
|
### Requirements Storage
|
||||||
|
|
||||||
|
- Stored in memory (not persisted to disk)
|
||||||
|
- Numbered sequentially starting from 1
|
||||||
|
- Formatted as markdown list
|
||||||
|
- Passed to autonomous mode as `--requirements` override
|
||||||
|
|
||||||
|
### History
|
||||||
|
|
||||||
|
- Saved to `~/.g3_accumulative_history`
|
||||||
|
- Separate from traditional interactive history
|
||||||
|
- Persists across sessions
|
||||||
|
- Uses rustyline for readline support
|
||||||
|
|
||||||
|
### Workspace
|
||||||
|
|
||||||
|
- Defaults to current directory
|
||||||
|
- Can be overridden with `--workspace`
|
||||||
|
- All files created in workspace
|
||||||
|
- Logs saved to `workspace/logs/`
|
||||||
|
|
||||||
|
### Autonomous Execution
|
||||||
|
|
||||||
|
- Full coach-player feedback loop
|
||||||
|
- Configurable max turns (default: 5)
|
||||||
|
- Respects all CLI flags (--macax, --webdriver, etc.)
|
||||||
|
- Error handling allows continuation
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### "No requirements provided"
|
||||||
|
|
||||||
|
This shouldn't happen in accumulative mode, but if it does:
|
||||||
|
- Check that you entered a requirement
|
||||||
|
- Ensure the requirement isn't empty
|
||||||
|
- Try restarting G3
|
||||||
|
|
||||||
|
### "Autonomous run failed"
|
||||||
|
|
||||||
|
If an autonomous run fails:
|
||||||
|
- Read the error message
|
||||||
|
- Provide a new requirement to fix the issue
|
||||||
|
- Or type `exit` and investigate manually
|
||||||
|
|
||||||
|
### "Context window full"
|
||||||
|
|
||||||
|
If you hit token limits:
|
||||||
|
- The agent will auto-summarize
|
||||||
|
- Or you can start a new session
|
||||||
|
- Consider using `--max-turns` to limit iterations
|
||||||
|
|
||||||
|
### "Coach never approves"
|
||||||
|
|
||||||
|
If the coach keeps rejecting:
|
||||||
|
- Check the coach feedback for specific issues
|
||||||
|
- Provide more specific requirements
|
||||||
|
- Consider increasing `--max-turns`
|
||||||
|
|
||||||
|
## Future Enhancements
|
||||||
|
|
||||||
|
Planned improvements:
|
||||||
|
|
||||||
|
1. **Persistence**: Save accumulated requirements to disk
|
||||||
|
2. **Editing**: Edit or remove previous requirements
|
||||||
|
3. **Branching**: Try different approaches
|
||||||
|
4. **Templates**: Pre-defined requirement sets
|
||||||
|
5. **Review**: Show all accumulated requirements
|
||||||
|
6. **Export**: Save to requirements.md
|
||||||
|
7. **Undo**: Remove last requirement
|
||||||
|
8. **Replay**: Re-run with same requirements
|
||||||
|
|
||||||
|
## Feedback
|
||||||
|
|
||||||
|
This is a new feature! Please provide feedback:
|
||||||
|
- What works well?
|
||||||
|
- What's confusing?
|
||||||
|
- What features would you like?
|
||||||
|
- Any bugs or issues?
|
||||||
|
|
||||||
|
Open an issue on GitHub or contribute improvements!
|
||||||
39
test-ai-requirements.sh
Executable file
39
test-ai-requirements.sh
Executable file
@@ -0,0 +1,39 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Test script for AI-enhanced interactive requirements mode
|
||||||
|
|
||||||
|
echo "Testing AI-enhanced interactive requirements mode..."
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Create a test workspace
|
||||||
|
TEST_WORKSPACE="/tmp/g3-test-interactive-$(date +%s)"
|
||||||
|
mkdir -p "$TEST_WORKSPACE"
|
||||||
|
|
||||||
|
echo "Test workspace: $TEST_WORKSPACE"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Create sample brief input
|
||||||
|
BRIEF_INPUT="build a calculator cli in rust with basic operations"
|
||||||
|
|
||||||
|
echo "Brief input:"
|
||||||
|
echo "---"
|
||||||
|
echo "$BRIEF_INPUT"
|
||||||
|
echo "---"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
echo "This will:"
|
||||||
|
echo "1. Send brief input to AI"
|
||||||
|
echo "2. AI generates structured requirements.md"
|
||||||
|
echo "3. Show enhanced requirements"
|
||||||
|
echo "4. Prompt for confirmation (y/e/n)"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
echo "To test manually, run:"
|
||||||
|
echo "cargo run -- --autonomous --interactive-requirements --workspace $TEST_WORKSPACE"
|
||||||
|
echo ""
|
||||||
|
echo "Then type: $BRIEF_INPUT"
|
||||||
|
echo "Press Ctrl+D"
|
||||||
|
echo "Review the AI-generated requirements"
|
||||||
|
echo "Choose 'y' to proceed, 'e' to edit, or 'n' to cancel"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
echo "Test workspace will be at: $TEST_WORKSPACE"
|
||||||
Reference in New Issue
Block a user