From c4902288244670740b3719640ebbf0ebdc2f9b92 Mon Sep 17 00:00:00 2001 From: Dhanji Prasanna Date: Sat, 27 Sep 2025 17:28:02 +1000 Subject: [PATCH] databricks support --- .gitignore | 2 + Cargo.lock | 459 ++++++++++++- config.example.toml | 35 +- crates/g3-config/src/lib.rs | 32 +- crates/g3-core/src/lib.rs | 43 +- crates/g3-providers/Cargo.toml | 11 + crates/g3-providers/src/databricks.rs | 907 ++++++++++++++++++++++++++ crates/g3-providers/src/lib.rs | 3 + crates/g3-providers/src/oauth.rs | 457 +++++++++++++ 9 files changed, 1899 insertions(+), 50 deletions(-) create mode 100644 crates/g3-providers/src/databricks.rs create mode 100644 crates/g3-providers/src/oauth.rs diff --git a/.gitignore b/.gitignore index ad67955..b8eecf8 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,8 @@ target # These are backup files generated by rustfmt **/*.rs.bk +**/.DS_Store + # MSVC Windows builds of rustc generate these, which store debugging information *.pdb diff --git a/Cargo.lock b/Cargo.lock index e5072cd..8af1cd2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -126,12 +126,73 @@ dependencies = [ "syn", ] +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "autocfg" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "axum" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" +dependencies = [ + "async-trait", + "axum-core", + "bytes", + "futures-util", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "hyper 1.7.0", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper 1.0.2", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper 1.0.2", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "backtrace" version = "0.3.75" @@ -153,6 +214,12 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "bindgen" version = "0.69.5" @@ -224,6 +291,12 @@ dependencies = [ "shlex", ] +[[package]] +name = "cesu8" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" + [[package]] name = "cexpr" version = "0.6.0" @@ -325,6 +398,16 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + [[package]] name = "config" version = "0.14.1" @@ -402,6 +485,16 @@ dependencies = [ "libc", ] +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -792,15 +885,25 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "axum", + "base64 0.22.1", "bytes", + "chrono", + "dirs 5.0.1", "futures-util", + "nanoid", "reqwest", "serde", "serde_json", + "serde_urlencoded", + "sha2", "thiserror 1.0.69", "tokio", "tokio-stream", + "tokio-util", "tracing", + "url", + "webbrowser", ] [[package]] @@ -859,7 +962,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.12", "indexmap", "slab", "tokio", @@ -924,6 +1027,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-body" version = "0.4.6" @@ -931,7 +1045,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", - "http", + "http 0.2.12", + "pin-project-lite", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http 1.3.1", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http 1.3.1", + "http-body 1.0.1", "pin-project-lite", ] @@ -958,8 +1095,8 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", - "http-body", + "http 0.2.12", + "http-body 0.4.6", "httparse", "httpdate", "itoa", @@ -971,6 +1108,27 @@ dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb3aa54a13a0dfe7fbe3a59e0c76093041720fdc77b110cc0fc260fafb4dc51e" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "http 1.3.1", + "http-body 1.0.1", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "pin-utils", + "smallvec", + "tokio", +] + [[package]] name = "hyper-tls" version = "0.5.0" @@ -978,12 +1136,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ "bytes", - "hyper", + "hyper 0.14.32", "native-tls", "tokio", "tokio-native-tls", ] +[[package]] +name = "hyper-util" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8" +dependencies = [ + "bytes", + "futures-core", + "http 1.3.1", + "http-body 1.0.1", + "hyper 1.7.0", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "iana-time-zone" version = "0.1.64" @@ -1176,6 +1350,28 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +[[package]] +name = "jni" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" +dependencies = [ + "cesu8", + "cfg-if", + "combine", + "jni-sys", + "log", + "thiserror 1.0.69", + "walkdir", + "windows-sys 0.45.0", +] + +[[package]] +name = "jni-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" + [[package]] name = "jobserver" version = "0.1.34" @@ -1324,6 +1520,12 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + [[package]] name = "memchr" version = "2.7.5" @@ -1362,6 +1564,15 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "nanoid" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ffa00dec017b5b1a8b7cf5e2c008bfda1aa7e0697ac1508b491fdf2622fb4d8" +dependencies = [ + "rand", +] + [[package]] name = "native-tls" version = "0.2.14" @@ -1379,6 +1590,12 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndk-context" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27b02d87554356db9e9a873add8782d4ea6e3e58ea071a9adb9a2e8ddb884a8b" + [[package]] name = "nibble_vec" version = "0.1.0" @@ -1444,6 +1661,31 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +[[package]] +name = "objc2" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "561f357ba7f3a2a61563a186a163d0a3a5247e1089524a3981d49adb775078bc" +dependencies = [ + "objc2-encode", +] + +[[package]] +name = "objc2-encode" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" + +[[package]] +name = "objc2-foundation" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "900831247d2fe1a09a683278e5384cfb8c80c79fe6b166f9d14bfdde0ea1b03c" +dependencies = [ + "bitflags 2.9.4", + "objc2", +] + [[package]] name = "object" version = "0.36.7" @@ -1637,6 +1879,15 @@ dependencies = [ "zerovec", ] +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + [[package]] name = "prettyplease" version = "0.2.37" @@ -1681,6 +1932,36 @@ dependencies = [ "nibble_vec", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.16", +] + [[package]] name = "redox_syscall" version = "0.5.17" @@ -1747,15 +2028,15 @@ version = "0.11.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" dependencies = [ - "base64", + "base64 0.21.7", "bytes", "encoding_rs", "futures-core", "futures-util", "h2", - "http", - "http-body", - "hyper", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.32", "hyper-tls", "ipnet", "js-sys", @@ -1769,7 +2050,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", - "sync_wrapper", + "sync_wrapper 0.1.2", "system-configuration", "tokio", "tokio-native-tls", @@ -1789,7 +2070,7 @@ version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b91f7eff05f748767f183df4320a63d6936e9c6107d97c9e6bdd9784f4289c94" dependencies = [ - "base64", + "base64 0.21.7", "bitflags 2.9.4", "serde", "serde_derive", @@ -1858,7 +2139,7 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" dependencies = [ - "base64", + "base64 0.21.7", ] [[package]] @@ -1895,6 +2176,15 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "schannel" version = "0.1.27" @@ -1917,7 +2207,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags 2.9.4", - "core-foundation", + "core-foundation 0.9.4", "core-foundation-sys", "libc", "security-framework-sys", @@ -1981,6 +2271,17 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + [[package]] name = "serde_spanned" version = "0.6.9" @@ -2107,6 +2408,12 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" + [[package]] name = "synstructure" version = "0.13.2" @@ -2125,7 +2432,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ "bitflags 1.3.2", - "core-foundation", + "core-foundation 0.9.4", "system-configuration-sys", ] @@ -2326,6 +2633,28 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" +[[package]] +name = "tower" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper 1.0.2", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + [[package]] name = "tower-service" version = "0.3.3" @@ -2338,6 +2667,7 @@ version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -2482,6 +2812,16 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "want" version = "0.3.1" @@ -2611,6 +2951,22 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webbrowser" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aaf4f3c0ba838e82b4e5ccc4157003fb8c324ee24c058470ffb82820becbde98" +dependencies = [ + "core-foundation 0.10.1", + "jni", + "log", + "ndk-context", + "objc2", + "objc2-foundation", + "url", + "web-sys", +] + [[package]] name = "which" version = "4.4.2" @@ -2623,6 +2979,15 @@ dependencies = [ "rustix 0.38.44", ] +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys 0.61.0", +] + [[package]] name = "windows-core" version = "0.62.1" @@ -2688,6 +3053,15 @@ dependencies = [ "windows-link 0.2.0", ] +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -2733,6 +3107,21 @@ dependencies = [ "windows-link 0.2.0", ] +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + [[package]] name = "windows-targets" version = "0.48.5" @@ -2781,6 +3170,12 @@ dependencies = [ "windows_x86_64_msvc 0.53.0", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -2799,6 +3194,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -2817,6 +3218,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -2847,6 +3254,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -2865,6 +3278,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -2883,6 +3302,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -2901,6 +3326,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" diff --git a/config.example.toml b/config.example.toml index b4fdcb5..6104d26 100644 --- a/config.example.toml +++ b/config.example.toml @@ -1,36 +1,13 @@ -# Example configuration file for G3 -# Copy to ~/.config/g3/config.toml and customize - [providers] -default_provider = "embedded" +default_provider = "databricks" -[providers.openai] -# Get your API key from https://platform.openai.com/api-keys -api_key = "sk-your-openai-api-key-here" -model = "gpt-4" -# Optional: custom base URL for OpenAI-compatible APIs -# base_url = "https://api.openai.com/v1" -max_tokens = 2048 -temperature = 0.1 - -[providers.anthropic] -# Get your API key from https://console.anthropic.com/ -api_key = "your-anthropic-api-key-here" -model = "claude-3-5-sonnet-20241022" +[providers.databricks] +host = "https://your-workspace.cloud.databricks.com" +# token = "your-databricks-token" # Optional - will use OAuth if not provided +model = "databricks-claude-sonnet-4" max_tokens = 4096 temperature = 0.1 - -[providers.embedded] -# Path to your GGUF model file -model_path = "~/.cache/g3/models/codellama-7b-instruct.Q4_K_M.gguf" -model_type = "codellama" -context_length = 16384 # Use CodeLlama's full context capability -max_tokens = 2048 # Default fallback, but will be calculated dynamically -temperature = 0.1 -# Number of layers to offload to GPU (0 for CPU only) -gpu_layers = 32 -# Number of CPU threads to use -threads = 8 +use_oauth = true [agent] max_context_length = 8192 diff --git a/crates/g3-config/src/lib.rs b/crates/g3-config/src/lib.rs index 46856f9..6d02005 100644 --- a/crates/g3-config/src/lib.rs +++ b/crates/g3-config/src/lib.rs @@ -12,6 +12,7 @@ pub struct Config { pub struct ProvidersConfig { pub openai: Option, pub anthropic: Option, + pub databricks: Option, pub embedded: Option, pub default_provider: String, } @@ -33,6 +34,16 @@ pub struct AnthropicConfig { pub temperature: Option, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DatabricksConfig { + pub host: String, + pub token: Option, // Optional - will use OAuth if not provided + pub model: String, + pub max_tokens: Option, + pub temperature: Option, + pub use_oauth: Option, // Default to true if token not provided +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct EmbeddedConfig { pub model_path: String, @@ -57,8 +68,16 @@ impl Default for Config { providers: ProvidersConfig { openai: None, anthropic: None, + databricks: Some(DatabricksConfig { + host: "https://your-workspace.cloud.databricks.com".to_string(), + token: None, // Will use OAuth by default + model: "databricks-claude-sonnet-4".to_string(), + max_tokens: Some(4096), + temperature: Some(0.1), + use_oauth: Some(true), + }), embedded: None, - default_provider: "anthropic".to_string(), + default_provider: "databricks".to_string(), }, agent: AgentConfig { max_context_length: 8192, @@ -88,9 +107,9 @@ impl Config { }) }; - // If no config exists, create and save a default Qwen config + // If no config exists, create and save a default Databricks config if !config_exists { - let qwen_config = Self::default_qwen_config(); + let databricks_config = Self::default(); // Save to default location let config_dir = dirs::home_dir() @@ -105,13 +124,13 @@ impl Config { std::fs::create_dir_all(&config_dir).ok(); let config_file = config_dir.join("config.toml"); - if let Err(e) = qwen_config.save(config_file.to_str().unwrap()) { + if let Err(e) = databricks_config.save(config_file.to_str().unwrap()) { eprintln!("Warning: Could not save default config: {}", e); } else { - println!("Created default Qwen configuration at: {}", config_file.display()); + println!("Created default Databricks configuration at: {}", config_file.display()); } - return Ok(qwen_config); + return Ok(databricks_config); } // Existing config loading logic @@ -157,6 +176,7 @@ impl Config { providers: ProvidersConfig { openai: None, anthropic: None, + databricks: None, embedded: Some(EmbeddedConfig { model_path: "~/.cache/g3/models/qwen2.5-7b-instruct-q3_k_m.gguf".to_string(), model_type: "qwen".to_string(), diff --git a/crates/g3-core/src/lib.rs b/crates/g3-core/src/lib.rs index d5b0a30..cbc3b3e 100644 --- a/crates/g3-core/src/lib.rs +++ b/crates/g3-core/src/lib.rs @@ -303,6 +303,36 @@ impl Agent { } } + // Register Databricks provider if configured AND it's the default provider + if let Some(databricks_config) = &config.providers.databricks { + if config.providers.default_provider == "databricks" { + info!("Initializing Databricks provider (selected as default)"); + + let databricks_provider = if let Some(token) = &databricks_config.token { + // Use token-based authentication + g3_providers::DatabricksProvider::from_token( + databricks_config.host.clone(), + token.clone(), + databricks_config.model.clone(), + databricks_config.max_tokens, + databricks_config.temperature, + )? + } else { + // Use OAuth authentication + g3_providers::DatabricksProvider::from_oauth( + databricks_config.host.clone(), + databricks_config.model.clone(), + databricks_config.max_tokens, + databricks_config.temperature, + ).await? + }; + + providers.register(databricks_provider); + } else { + info!("Databricks provider configured but not selected as default, skipping initialization"); + } + } + // Set default provider debug!( "Setting default provider to: {}", @@ -352,7 +382,18 @@ impl Agent { // Claude models have large context windows 200000 // Default for Claude models } - + "databricks" => { + // Databricks models have varying context windows depending on the model + if model_name.contains("claude") { + 200000 // Claude models on Databricks have large context windows + } else if model_name.contains("llama") { + 32768 // Llama models typically support 32k context + } else if model_name.contains("dbrx") { + 32768 // DBRX supports 32k context + } else { + 16384 // Conservative default for other Databricks models + } + } _ => config.agent.max_context_length as u32, }; diff --git a/crates/g3-providers/Cargo.toml b/crates/g3-providers/Cargo.toml index 83422cf..50225cc 100644 --- a/crates/g3-providers/Cargo.toml +++ b/crates/g3-providers/Cargo.toml @@ -16,3 +16,14 @@ async-trait = "0.1" tokio-stream = "0.1" futures-util = "0.3" bytes = "1.0" +# OAuth dependencies +axum = "0.7" +base64 = "0.22" +chrono = { version = "0.4", features = ["serde"] } +sha2 = "0.10" +url = "2.5" +webbrowser = "1.0" +nanoid = "0.4" +serde_urlencoded = "0.7" +tokio-util = "0.7" +dirs = "5.0" diff --git a/crates/g3-providers/src/databricks.rs b/crates/g3-providers/src/databricks.rs new file mode 100644 index 0000000..7da1ebc --- /dev/null +++ b/crates/g3-providers/src/databricks.rs @@ -0,0 +1,907 @@ +//! Databricks LLM provider implementation for the g3-providers crate. +//! +//! This module provides an implementation of the `LLMProvider` trait for Databricks Foundation Model APIs, +//! supporting both completion and streaming modes with OAuth authentication. +//! +//! # Features +//! +//! - Support for Databricks Foundation Models (databricks-claude-sonnet-4, databricks-meta-llama-3-3-70b-instruct, etc.) +//! - Both completion and streaming response modes +//! - OAuth authentication with automatic token refresh +//! - Token-based authentication as fallback +//! - Native tool calling support for compatible models +//! - Automatic model discovery from Databricks workspace +//! +//! # Usage +//! +//! ```rust,no_run +//! use g3_providers::{DatabricksProvider, LLMProvider, CompletionRequest, Message, MessageRole}; +//! +//! #[tokio::main] +//! async fn main() -> anyhow::Result<()> { +//! // Create the provider with OAuth (recommended) +//! let provider = DatabricksProvider::from_oauth( +//! "https://your-workspace.cloud.databricks.com".to_string(), +//! "databricks-claude-sonnet-4".to_string(), +//! None, // Optional: max tokens +//! None, // Optional: temperature +//! ).await?; +//! +//! // Or create with token +//! let provider = DatabricksProvider::from_token( +//! "https://your-workspace.cloud.databricks.com".to_string(), +//! "your-databricks-token".to_string(), +//! "databricks-claude-sonnet-4".to_string(), +//! None, +//! None, +//! )?; +//! +//! // Create a completion request +//! let request = CompletionRequest { +//! messages: vec![ +//! Message { +//! role: MessageRole::User, +//! content: "Hello! How are you?".to_string(), +//! }, +//! ], +//! max_tokens: Some(1000), +//! temperature: Some(0.7), +//! stream: false, +//! tools: None, +//! }; +//! +//! // Get a completion +//! let response = provider.complete(request).await?; +//! println!("Response: {}", response.content); +//! +//! Ok(()) +//! } +//! ``` + +use anyhow::{anyhow, Result}; +use bytes::Bytes; +use futures_util::stream::StreamExt; +use reqwest::{Client, RequestBuilder}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tracing::{debug, error, info, warn}; + +use crate::{ + CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message, + MessageRole, Tool, ToolCall, Usage, +}; + +const DEFAULT_CLIENT_ID: &str = "databricks-cli"; +const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020"; +const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"]; +const DEFAULT_TIMEOUT_SECS: u64 = 600; + +pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-sonnet-4"; +const DATABRICKS_DEFAULT_FAST_MODEL: &str = "gemini-1-5-flash"; +pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[ + "databricks-claude-3-7-sonnet", + "databricks-meta-llama-3-3-70b-instruct", + "databricks-meta-llama-3-1-405b-instruct", + "databricks-dbrx-instruct", + "databricks-mixtral-8x7b-instruct", +]; + +#[derive(Debug, Clone)] +pub enum DatabricksAuth { + Token(String), + OAuth { + host: String, + client_id: String, + redirect_url: String, + scopes: Vec, + cached_token: Option, + }, +} + +impl DatabricksAuth { + pub fn oauth(host: String) -> Self { + Self::OAuth { + host, + client_id: DEFAULT_CLIENT_ID.to_string(), + redirect_url: DEFAULT_REDIRECT_URL.to_string(), + scopes: DEFAULT_SCOPES.iter().map(|s| s.to_string()).collect(), + cached_token: None, + } + } + + pub fn token(token: String) -> Self { + Self::Token(token) + } + + async fn get_token(&mut self) -> Result { + match self { + DatabricksAuth::Token(token) => Ok(token.clone()), + DatabricksAuth::OAuth { + host, + client_id, + redirect_url, + scopes, + cached_token: _, + } => { + // Use the OAuth implementation + crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await + } + } + } +} + +#[derive(Debug, Clone)] +pub struct DatabricksProvider { + client: Client, + host: String, + auth: DatabricksAuth, + model: String, + max_tokens: u32, + temperature: f32, +} + +impl DatabricksProvider { + pub fn from_token( + host: String, + token: String, + model: String, + max_tokens: Option, + temperature: Option, + ) -> Result { + let client = Client::builder() + .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS)) + .build() + .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?; + + info!("Initialized Databricks provider with model: {} on host: {}", model, host); + + Ok(Self { + client, + host: host.trim_end_matches('/').to_string(), + auth: DatabricksAuth::token(token), + model, + max_tokens: max_tokens.unwrap_or(4096), + temperature: temperature.unwrap_or(0.1), + }) + } + + pub async fn from_oauth( + host: String, + model: String, + max_tokens: Option, + temperature: Option, + ) -> Result { + let client = Client::builder() + .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS)) + .build() + .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?; + + info!("Initialized Databricks provider with OAuth for model: {} on host: {}", model, host); + + Ok(Self { + client, + host: host.trim_end_matches('/').to_string(), + auth: DatabricksAuth::oauth(host.clone()), + model, + max_tokens: max_tokens.unwrap_or(4096), + temperature: temperature.unwrap_or(0.1), + }) + } + + async fn create_request_builder(&mut self, streaming: bool) -> Result { + let token = self.auth.get_token().await?; + + let mut builder = self + .client + .post(&format!("{}/serving-endpoints/{}/invocations", self.host, self.model)) + .header("Authorization", format!("Bearer {}", token)) + .header("Content-Type", "application/json"); + + if streaming { + builder = builder.header("Accept", "text/event-stream"); + } + + Ok(builder) + } + + fn convert_tools(&self, tools: &[Tool]) -> Vec { + tools + .iter() + .map(|tool| DatabricksTool { + r#type: "function".to_string(), + function: DatabricksFunction { + name: tool.name.clone(), + description: tool.description.clone(), + parameters: tool.input_schema.clone(), + }, + }) + .collect() + } + + fn convert_messages(&self, messages: &[Message]) -> Result> { + let mut databricks_messages = Vec::new(); + + for message in messages { + let role = match message.role { + MessageRole::System => "system", + MessageRole::User => "user", + MessageRole::Assistant => "assistant", + }; + + databricks_messages.push(DatabricksMessage { + role: role.to_string(), + content: Some(message.content.clone()), + tool_calls: None, // Only used in responses, not requests + }); + } + + if databricks_messages.is_empty() { + return Err(anyhow!("At least one message is required")); + } + + Ok(databricks_messages) + } + + fn create_request_body( + &self, + messages: &[Message], + tools: Option<&[Tool]>, + streaming: bool, + max_tokens: u32, + temperature: f32, + ) -> Result { + let databricks_messages = self.convert_messages(messages)?; + + // Convert tools if provided + let databricks_tools = tools.map(|t| self.convert_tools(t)); + + let request = DatabricksRequest { + messages: databricks_messages, + max_tokens, + temperature, + tools: databricks_tools, + stream: streaming, + }; + + Ok(request) + } + + async fn parse_streaming_response( + &self, + mut stream: impl futures_util::Stream> + Unpin, + tx: mpsc::Sender>, + ) { + let mut buffer = String::new(); + let mut current_tool_calls: std::collections::HashMap = std::collections::HashMap::new(); // index -> (id, name, args) + + while let Some(chunk_result) = stream.next().await { + match chunk_result { + Ok(chunk) => { + let chunk_str = match std::str::from_utf8(&chunk) { + Ok(s) => s, + Err(e) => { + error!("Invalid UTF-8 in stream chunk: {}", e); + let _ = tx + .send(Err(anyhow!("Invalid UTF-8 in stream chunk: {}", e))) + .await; + return; + } + }; + + buffer.push_str(chunk_str); + + // Process complete lines + while let Some(line_end) = buffer.find('\n') { + let line = buffer[..line_end].trim().to_string(); + buffer.drain(..line_end + 1); + + if line.is_empty() { + continue; + } + + // Parse Server-Sent Events format + if let Some(data) = line.strip_prefix("data: ") { + if data == "[DONE]" { + debug!("Received stream completion marker"); + let final_tool_calls: Vec = current_tool_calls.values() + .map(|(id, name, args)| ToolCall { + id: id.clone(), + tool: name.clone(), + args: serde_json::from_str(args).unwrap_or(serde_json::Value::Object(serde_json::Map::new())), + }) + .collect(); + let final_chunk = CompletionChunk { + content: String::new(), + finished: true, + tool_calls: if final_tool_calls.is_empty() { None } else { Some(final_tool_calls) }, + }; + if tx.send(Ok(final_chunk)).await.is_err() { + debug!("Receiver dropped, stopping stream"); + } + return; + } + + debug!("Raw Databricks API JSON: {}", data); + + match serde_json::from_str::(data) { + Ok(chunk) => { + debug!("Parsed stream chunk: {:?}", chunk); + + // Handle different types of chunks + if let Some(choices) = chunk.choices { + for choice in choices { + if let Some(delta) = choice.delta { + // Handle text content + if let Some(content) = delta.content { + debug!("Sending text chunk: '{}'", content); + let chunk = CompletionChunk { + content, + finished: false, + tool_calls: None, + }; + if tx.send(Ok(chunk)).await.is_err() { + debug!("Receiver dropped, stopping stream"); + return; + } + } + + // Handle tool calls - accumulate across chunks + if let Some(tool_calls) = delta.tool_calls { + for tool_call in tool_calls { + let index = tool_call.index.unwrap_or(0); + let entry = current_tool_calls.entry(index).or_insert_with(|| { + (String::new(), String::new(), String::new()) + }); + + // Update ID if provided + if let Some(id) = tool_call.id { + entry.0 = id; + } + + // Update name if provided and not empty + if !tool_call.function.name.is_empty() { + entry.1 = tool_call.function.name; + } + + // Append arguments + entry.2.push_str(&tool_call.function.arguments); + + debug!("Accumulated tool call {}: id='{}', name='{}', args='{}'", + index, entry.0, entry.1, entry.2); + } + } + } + + // Check if this choice is finished + if choice.finish_reason.is_some() { + debug!("Choice finished with reason: {:?}", choice.finish_reason); + + // Convert accumulated tool calls to final format + let final_tool_calls: Vec = current_tool_calls.values() + .filter(|(_, name, _)| !name.is_empty()) // Only include tool calls with names + .map(|(id, name, args)| { + debug!("Converting tool call: id='{}', name='{}', args='{}'", id, name, args); + ToolCall { + id: if id.is_empty() { format!("tool_{}", name) } else { id.clone() }, + tool: name.clone(), + args: serde_json::from_str(args).unwrap_or_else(|e| { + debug!("Failed to parse tool args '{}': {}", args, e); + serde_json::Value::Object(serde_json::Map::new()) + }), + } + }) + .collect(); + + debug!("Final tool calls: {:?}", final_tool_calls); + + let final_chunk = CompletionChunk { + content: String::new(), + finished: true, + tool_calls: if final_tool_calls.is_empty() { None } else { Some(final_tool_calls) }, + }; + if tx.send(Ok(final_chunk)).await.is_err() { + debug!("Receiver dropped, stopping stream"); + } + return; + } + } + } + } + Err(e) => { + debug!("Failed to parse stream chunk: {} - Data: {}", e, data); + // Don't error out on parse failures, just continue + } + } + } + } + } + Err(e) => { + error!("Stream error: {}", e); + let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await; + return; + } + } + } + + // Send final chunk if we haven't already + let final_tool_calls: Vec = current_tool_calls.values() + .filter(|(_, name, _)| !name.is_empty()) + .map(|(id, name, args)| ToolCall { + id: if id.is_empty() { format!("tool_{}", name) } else { id.clone() }, + tool: name.clone(), + args: serde_json::from_str(args).unwrap_or(serde_json::Value::Object(serde_json::Map::new())), + }) + .collect(); + + let final_chunk = CompletionChunk { + content: String::new(), + finished: true, + tool_calls: if final_tool_calls.is_empty() { None } else { Some(final_tool_calls) }, + }; + let _ = tx.send(Ok(final_chunk)).await; + } + + pub async fn fetch_supported_models(&mut self) -> Result>> { + let token = self.auth.get_token().await?; + + let response = match self + .client + .get(&format!("{}/api/2.0/serving-endpoints", self.host)) + .header("Authorization", format!("Bearer {}", token)) + .send() + .await + { + Ok(resp) => resp, + Err(e) => { + warn!("Failed to fetch Databricks models: {}", e); + return Ok(None); + } + }; + + if !response.status().is_success() { + let status = response.status(); + if let Ok(error_text) = response.text().await { + warn!( + "Failed to fetch Databricks models: {} - {}", + status, + error_text + ); + } else { + warn!("Failed to fetch Databricks models: {}", status); + } + return Ok(None); + } + + let json: serde_json::Value = match response.json().await { + Ok(json) => json, + Err(e) => { + warn!("Failed to parse Databricks API response: {}", e); + return Ok(None); + } + }; + + let endpoints = match json.get("endpoints").and_then(|v| v.as_array()) { + Some(endpoints) => endpoints, + None => { + warn!( + "Unexpected response format from Databricks API: missing 'endpoints' array" + ); + return Ok(None); + } + }; + + let models: Vec = endpoints + .iter() + .filter_map(|endpoint| { + endpoint + .get("name") + .and_then(|v| v.as_str()) + .map(|name| name.to_string()) + }) + .collect(); + + if models.is_empty() { + debug!("No serving endpoints found in Databricks workspace"); + Ok(None) + } else { + debug!( + "Found {} serving endpoints in Databricks workspace", + models.len() + ); + Ok(Some(models)) + } + } +} + +#[async_trait::async_trait] +impl LLMProvider for DatabricksProvider { + async fn complete(&self, request: CompletionRequest) -> Result { + debug!( + "Processing Databricks completion request with {} messages", + request.messages.len() + ); + + let max_tokens = request.max_tokens.unwrap_or(self.max_tokens); + let temperature = request.temperature.unwrap_or(self.temperature); + + let request_body = self.create_request_body( + &request.messages, + request.tools.as_deref(), + false, + max_tokens, + temperature + )?; + + debug!("Sending request to Databricks API: model={}, max_tokens={}, temperature={}", + self.model, request_body.max_tokens, request_body.temperature); + + // Debug: Log the full request body when tools are present + if request.tools.is_some() { + debug!("Full request body with tools: {}", serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "Failed to serialize".to_string())); + } + + let mut provider_clone = self.clone(); + let response = provider_clone + .create_request_builder(false) + .await? + .json(&request_body) + .send() + .await + .map_err(|e| anyhow!("Failed to send request to Databricks API: {}", e))?; + + let status = response.status(); + if !status.is_success() { + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(anyhow!("Databricks API error {}: {}", status, error_text)); + } + + let response_text = response.text().await?; + debug!("Raw Databricks API response: {}", response_text); + + let databricks_response: DatabricksResponse = serde_json::from_str(&response_text) + .map_err(|e| anyhow!("Failed to parse Databricks response: {} - Response: {}", e, response_text))?; + + // Debug: Log the parsed response structure + debug!("Parsed Databricks response: {:#?}", databricks_response); + + // Extract content from the first choice + let content = databricks_response + .choices + .first() + .and_then(|choice| choice.message.content.as_ref()) + .cloned() + .unwrap_or_default(); + + // Check if there are tool calls in the response + if let Some(first_choice) = databricks_response.choices.first() { + if let Some(tool_calls) = &first_choice.message.tool_calls { + debug!("Found {} tool calls in Databricks response", tool_calls.len()); + for (i, tool_call) in tool_calls.iter().enumerate() { + debug!("Tool call {}: {} with args: {}", i, tool_call.function.name, tool_call.function.arguments); + } + + // For now, we'll return the content as-is since g3 handles tool calls via streaming + // In the future, we might need to convert these to the internal format + } + } + + let usage = Usage { + prompt_tokens: databricks_response.usage.prompt_tokens, + completion_tokens: databricks_response.usage.completion_tokens, + total_tokens: databricks_response.usage.total_tokens, + }; + + debug!( + "Databricks completion successful: {} tokens generated", + usage.completion_tokens + ); + + Ok(CompletionResponse { + content, + usage, + model: self.model.clone(), + }) + } + + async fn stream(&self, request: CompletionRequest) -> Result { + debug!( + "Processing Databricks streaming request with {} messages", + request.messages.len() + ); + + let max_tokens = request.max_tokens.unwrap_or(self.max_tokens); + let temperature = request.temperature.unwrap_or(self.temperature); + + let request_body = self.create_request_body( + &request.messages, + request.tools.as_deref(), + true, + max_tokens, + temperature + )?; + + debug!("Sending streaming request to Databricks API: model={}, max_tokens={}, temperature={}", + self.model, request_body.max_tokens, request_body.temperature); + + // Debug: Log the full request body + debug!("Full request body: {}", serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "Failed to serialize".to_string())); + + let mut provider_clone = self.clone(); + let response = provider_clone + .create_request_builder(true) + .await? + .json(&request_body) + .send() + .await + .map_err(|e| anyhow!("Failed to send streaming request to Databricks API: {}", e))?; + + let status = response.status(); + if !status.is_success() { + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(anyhow!("Databricks API error {}: {}", status, error_text)); + } + + let stream = response.bytes_stream(); + let (tx, rx) = mpsc::channel(100); + + // Spawn task to process the stream + let provider = self.clone(); + tokio::spawn(async move { + provider.parse_streaming_response(stream, tx).await; + }); + + Ok(ReceiverStream::new(rx)) + } + + fn name(&self) -> &str { + "databricks" + } + + fn model(&self) -> &str { + &self.model + } + + fn has_native_tool_calling(&self) -> bool { + // Databricks Foundation Models support native tool calling + // This includes Claude, Llama, DBRX, and most other models on the platform + true + } +} + +// Databricks API request/response structures + +#[derive(Debug, Serialize)] +struct DatabricksRequest { + messages: Vec, + max_tokens: u32, + temperature: f32, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, + stream: bool, +} + +#[derive(Debug, Serialize)] +struct DatabricksTool { + r#type: String, + function: DatabricksFunction, +} + +#[derive(Debug, Serialize)] +struct DatabricksFunction { + name: String, + description: String, + parameters: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize)] +struct DatabricksMessage { + role: String, + content: Option, // Make content optional since tool calls might not have content + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, // Add tool_calls field for responses +} + +#[derive(Debug, Serialize, Deserialize)] +struct DatabricksToolCall { + id: String, + r#type: String, + function: DatabricksToolCallFunction, +} + +#[derive(Debug, Serialize, Deserialize)] +struct DatabricksToolCallFunction { + name: String, + arguments: String, // This will be a JSON string that needs parsing +} + +#[derive(Debug, Deserialize)] +struct DatabricksResponse { + choices: Vec, + usage: DatabricksUsage, +} + +#[derive(Debug, Deserialize)] +struct DatabricksChoice { + message: DatabricksMessage, + finish_reason: Option, +} + +#[derive(Debug, Deserialize)] +struct DatabricksUsage { + prompt_tokens: u32, + completion_tokens: u32, + total_tokens: u32, +} + +// Streaming response structures + +#[derive(Debug, Deserialize)] +struct DatabricksStreamChunk { + choices: Option>, +} + +#[derive(Debug, Deserialize)] +struct DatabricksStreamChoice { + delta: Option, + finish_reason: Option, +} + +#[derive(Debug, Deserialize)] +struct DatabricksStreamDelta { + content: Option, + tool_calls: Option>, +} + +#[derive(Debug, Deserialize)] +struct DatabricksStreamToolCall { + index: Option, + id: Option, + function: DatabricksStreamFunction, +} + +#[derive(Debug, Deserialize)] +struct DatabricksStreamFunction { + #[serde(default)] + name: String, + arguments: String, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_message_conversion() { + let provider = DatabricksProvider::from_token( + "https://test.databricks.com".to_string(), + "test-token".to_string(), + "test-model".to_string(), + None, + None, + ).unwrap(); + + let messages = vec![ + Message { + role: MessageRole::System, + content: "You are a helpful assistant.".to_string(), + }, + Message { + role: MessageRole::User, + content: "Hello!".to_string(), + }, + Message { + role: MessageRole::Assistant, + content: "Hi there!".to_string(), + }, + ]; + + let databricks_messages = provider.convert_messages(&messages).unwrap(); + + assert_eq!(databricks_messages.len(), 3); + assert_eq!(databricks_messages[0].role, "system"); + assert_eq!(databricks_messages[1].role, "user"); + assert_eq!(databricks_messages[2].role, "assistant"); + } + + #[test] + fn test_request_body_creation() { + let provider = DatabricksProvider::from_token( + "https://test.databricks.com".to_string(), + "test-token".to_string(), + "databricks-claude-sonnet-4".to_string(), + Some(1000), + Some(0.5), + ).unwrap(); + + let messages = vec![ + Message { + role: MessageRole::User, + content: "Test message".to_string(), + }, + ]; + + let request_body = provider + .create_request_body(&messages, None, false, 1000, 0.5) + .unwrap(); + + assert_eq!(request_body.max_tokens, 1000); + assert_eq!(request_body.temperature, 0.5); + assert!(!request_body.stream); + assert_eq!(request_body.messages.len(), 1); + assert!(request_body.tools.is_none()); + } + + #[test] + fn test_tool_conversion() { + let provider = DatabricksProvider::from_token( + "https://test.databricks.com".to_string(), + "test-token".to_string(), + "test-model".to_string(), + None, + None, + ).unwrap(); + + let tools = vec![ + Tool { + name: "get_weather".to_string(), + description: "Get the current weather".to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state" + } + }, + "required": ["location"] + }), + }, + ]; + + let databricks_tools = provider.convert_tools(&tools); + + assert_eq!(databricks_tools.len(), 1); + assert_eq!(databricks_tools[0].r#type, "function"); + assert_eq!(databricks_tools[0].function.name, "get_weather"); + assert_eq!(databricks_tools[0].function.description, "Get the current weather"); + } + + #[test] + fn test_has_native_tool_calling() { + let claude_provider = DatabricksProvider::from_token( + "https://test.databricks.com".to_string(), + "test-token".to_string(), + "databricks-claude-sonnet-4".to_string(), + None, + None, + ).unwrap(); + + let llama_provider = DatabricksProvider::from_token( + "https://test.databricks.com".to_string(), + "test-token".to_string(), + "databricks-meta-llama-3-3-70b-instruct".to_string(), + None, + None, + ).unwrap(); + + let dbrx_provider = DatabricksProvider::from_token( + "https://test.databricks.com".to_string(), + "test-token".to_string(), + "databricks-dbrx-instruct".to_string(), + None, + None, + ).unwrap(); + + assert!(claude_provider.has_native_tool_calling()); + assert!(llama_provider.has_native_tool_calling()); + assert!(dbrx_provider.has_native_tool_calling()); + } +} diff --git a/crates/g3-providers/src/lib.rs b/crates/g3-providers/src/lib.rs index 18dc8ce..a778f6d 100644 --- a/crates/g3-providers/src/lib.rs +++ b/crates/g3-providers/src/lib.rs @@ -84,8 +84,11 @@ pub struct Tool { } pub mod anthropic; +pub mod databricks; +pub mod oauth; pub use anthropic::AnthropicProvider; +pub use databricks::DatabricksProvider; /// Provider registry for managing multiple LLM providers pub struct ProviderRegistry { diff --git a/crates/g3-providers/src/oauth.rs b/crates/g3-providers/src/oauth.rs new file mode 100644 index 0000000..3508a6e --- /dev/null +++ b/crates/g3-providers/src/oauth.rs @@ -0,0 +1,457 @@ +use anyhow::Result; +use axum::{extract::Query, response::Html, routing::get, Router}; +use base64::Engine; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use sha2::Digest; +use std::{collections::HashMap, fs, net::SocketAddr, path::PathBuf, sync::Arc}; +use tokio::sync::{oneshot, Mutex as TokioMutex}; +use url::Url; + +#[derive(Debug, Clone)] +struct OidcEndpoints { + authorization_endpoint: String, + token_endpoint: String, +} + +#[derive(Serialize, Deserialize)] +struct TokenData { + /// The access token used to authenticate API requests + access_token: String, + + /// Optional refresh token that can be used to obtain a new access token + /// when the current one expires, enabling offline access without user interaction + refresh_token: Option, + + /// When the access token expires (if known) + /// Used to determine when a token needs to be refreshed + expires_at: Option>, +} + +struct TokenCache { + cache_path: PathBuf, +} + +fn get_base_path() -> PathBuf { + // Use a similar pattern to Goose but for g3 + // macOS/Linux: ~/.config/g3/databricks/oauth + // Windows: ~\AppData\Roaming\g3\config\databricks\oauth\ + let mut path = dirs::config_dir().unwrap_or_else(|| PathBuf::from(".")); + path.push("g3"); + path.push("databricks"); + path.push("oauth"); + path +} + +impl TokenCache { + fn new(host: &str, client_id: &str, scopes: &[String]) -> Self { + let mut hasher = sha2::Sha256::new(); + hasher.update(host.as_bytes()); + hasher.update(client_id.as_bytes()); + hasher.update(scopes.join(",").as_bytes()); + let hash = format!("{:x}", hasher.finalize()); + + fs::create_dir_all(get_base_path()).unwrap_or_else(|_| {}); + let cache_path = get_base_path().join(format!("{}.json", hash)); + + Self { cache_path } + } + + fn load_token(&self) -> Option { + if let Ok(contents) = fs::read_to_string(&self.cache_path) { + if let Ok(token_data) = serde_json::from_str::(&contents) { + // Only return tokens that have a refresh token + if token_data.refresh_token.is_some() { + // If token is not expired, return it for immediate use + if let Some(expires_at) = token_data.expires_at { + if expires_at > Utc::now() { + return Some(token_data); + } + // If token is expired but has refresh token, return it so we can refresh + return Some(token_data); + } + // No expiration time but has refresh token, return it + return Some(token_data); + } + // Token doesn't have a refresh token, ignore it to force a new OAuth flow + } + } + None + } + + fn save_token(&self, token_data: &TokenData) -> Result<()> { + if let Some(parent) = self.cache_path.parent() { + fs::create_dir_all(parent)?; + } + let contents = serde_json::to_string(token_data)?; + fs::write(&self.cache_path, contents)?; + Ok(()) + } +} + +async fn get_workspace_endpoints(host: &str) -> Result { + let base_url = Url::parse(host).expect("Invalid host URL"); + let oidc_url = base_url + .join("oidc/.well-known/oauth-authorization-server") + .expect("Invalid OIDC URL"); + + let client = reqwest::Client::new(); + let resp = client.get(oidc_url.clone()).send().await?; + + if !resp.status().is_success() { + return Err(anyhow::anyhow!( + "Failed to get OIDC configuration from {}", + oidc_url.to_string() + )); + } + + let oidc_config: Value = resp.json().await?; + + let authorization_endpoint = oidc_config + .get("authorization_endpoint") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("authorization_endpoint not found in OIDC configuration"))? + .to_string(); + + let token_endpoint = oidc_config + .get("token_endpoint") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("token_endpoint not found in OIDC configuration"))? + .to_string(); + + Ok(OidcEndpoints { + authorization_endpoint, + token_endpoint, + }) +} + +struct OAuthFlow { + endpoints: OidcEndpoints, + client_id: String, + redirect_url: String, + scopes: Vec, + state: String, + verifier: String, +} + +impl OAuthFlow { + fn new( + endpoints: OidcEndpoints, + client_id: String, + redirect_url: String, + scopes: Vec, + ) -> Self { + Self { + endpoints, + client_id, + redirect_url, + scopes, + state: nanoid::nanoid!(16), + verifier: nanoid::nanoid!(64), + } + } + + /// Extracts token data from an OAuth 2.0 token response. + fn extract_token_data( + &self, + token_response: &Value, + old_refresh_token: Option<&str>, + ) -> Result { + // Extract access token (required) + let access_token = token_response + .get("access_token") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("access_token not found in token response"))? + .to_string(); + + // Extract refresh token if available + let refresh_token = token_response + .get("refresh_token") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .or_else(|| old_refresh_token.map(|s| s.to_string())); + + // Handle token expiration + let expires_at = + if let Some(expires_in) = token_response.get("expires_in").and_then(|v| v.as_u64()) { + // Traditional OAuth flow with expires_in seconds + Some(Utc::now() + chrono::Duration::seconds(expires_in as i64)) + } else { + // If the server doesn't provide any expiration info, log it but don't set an expiration + tracing::debug!( + "No expiration information provided by server, token expiration unknown." + ); + None + }; + + Ok(TokenData { + access_token, + refresh_token, + expires_at, + }) + } + + fn get_authorization_url(&self) -> String { + let challenge = { + let digest = sha2::Sha256::digest(self.verifier.as_bytes()); + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest) + }; + + let params = [ + ("response_type", "code"), + ("client_id", &self.client_id), + ("redirect_uri", &self.redirect_url), + ("scope", &self.scopes.join(" ")), + ("state", &self.state), + ("code_challenge", &challenge), + ("code_challenge_method", "S256"), + ]; + + format!( + "{}?{}", + self.endpoints.authorization_endpoint, + serde_urlencoded::to_string(params).unwrap() + ) + } + + async fn exchange_code_for_token(&self, code: &str) -> Result { + let params = [ + ("grant_type", "authorization_code"), + ("code", code), + ("redirect_uri", &self.redirect_url), + ("code_verifier", &self.verifier), + ("client_id", &self.client_id), + ]; + + let client = reqwest::Client::new(); + let resp = client + .post(&self.endpoints.token_endpoint) + .header("Content-Type", "application/x-www-form-urlencoded") + .form(¶ms) + .send() + .await?; + + if !resp.status().is_success() { + let err_text = resp.text().await?; + return Err(anyhow::anyhow!( + "Failed to exchange code for token: {}", + err_text + )); + } + + let token_response: Value = resp.json().await?; + self.extract_token_data(&token_response, None) + } + + async fn refresh_token(&self, refresh_token: &str) -> Result { + let params = [ + ("grant_type", "refresh_token"), + ("refresh_token", refresh_token), + ("client_id", &self.client_id), + ]; + + tracing::debug!("Refreshing token using refresh_token"); + + let client = reqwest::Client::new(); + let resp = client + .post(&self.endpoints.token_endpoint) + .header("Content-Type", "application/x-www-form-urlencoded") + .form(¶ms) + .send() + .await?; + + if !resp.status().is_success() { + let err_text = resp.text().await?; + return Err(anyhow::anyhow!("Failed to refresh token: {}", err_text)); + } + + let token_response: Value = resp.json().await?; + self.extract_token_data(&token_response, Some(refresh_token)) + } + + async fn execute(&self) -> Result { + // Create a channel that will send the auth code from the app process + let (tx, rx) = oneshot::channel(); + let state = self.state.clone(); + let tx = Arc::new(TokioMutex::new(Some(tx))); + + // Setup a server that will receive the redirect, capture the code, and display success/failure + let app = Router::new().route( + "/", + get(move |Query(params): Query>| { + let tx = Arc::clone(&tx); + let state = state.clone(); + async move { + let code = params.get("code").cloned(); + let received_state = params.get("state").cloned(); + + if let (Some(code), Some(received_state)) = (code, received_state) { + if received_state == state { + if let Some(sender) = tx.lock().await.take() { + if sender.send(code).is_ok() { + return Html( + "

G3 Authentication Success

You can close this window and return to your terminal.

", + ); + } + } + Html("

Error

Authentication already completed.

") + } else { + Html("

Error

State mismatch.

") + } + } else { + Html("

Error

Authentication failed.

") + } + } + }), + ); + + // Start the server to accept the oauth code + let redirect_url = Url::parse(&self.redirect_url)?; + let port = redirect_url.port().unwrap_or(80); + let addr = SocketAddr::from(([127, 0, 0, 1], port)); + + let listener = tokio::net::TcpListener::bind(addr).await?; + + let server_handle = tokio::spawn(async move { + let server = axum::serve(listener, app); + server.await.unwrap(); + }); + + // Open the browser which will redirect with the code to the server + let authorization_url = self.get_authorization_url(); + println!("🔐 Opening browser for Databricks authentication..."); + if webbrowser::open(&authorization_url).is_err() { + println!( + "Please open this URL in your browser:\n{}", + authorization_url + ); + } + + // Wait for the authorization code with a timeout + let code = tokio::time::timeout( + std::time::Duration::from_secs(120), // 2 minute timeout + rx, + ) + .await + .map_err(|_| anyhow::anyhow!("Authentication timed out after 2 minutes"))??; + + // Stop the server + server_handle.abort(); + + println!("✅ Authentication successful! Exchanging code for token..."); + + // Exchange the code for a token + self.exchange_code_for_token(&code).await + } +} + +pub async fn get_oauth_token_async( + host: &str, + client_id: &str, + redirect_url: &str, + scopes: &[String], +) -> Result { + let token_cache = TokenCache::new(host, client_id, scopes); + + // Try cache first + if let Some(token) = token_cache.load_token() { + // If token has an expiration time, check if it's expired + if let Some(expires_at) = token.expires_at { + if expires_at > Utc::now() { + tracing::debug!("Using cached token"); + return Ok(token.access_token); + } + // Token is expired, will try to refresh below + tracing::debug!("Token is expired, attempting to refresh"); + } else { + // No expiration time was provided by the server + tracing::debug!("Token has no expiration time, using cached token"); + return Ok(token.access_token); + } + + // Token is expired or has no expiration, try to refresh if we have a refresh token + if let Some(refresh_token) = token.refresh_token { + // Get endpoints for token refresh + match get_workspace_endpoints(host).await { + Ok(endpoints) => { + let flow = OAuthFlow::new( + endpoints, + client_id.to_string(), + redirect_url.to_string(), + scopes.to_vec(), + ); + + // Try to refresh the token + match flow.refresh_token(&refresh_token).await { + Ok(new_token) => { + if let Err(e) = token_cache.save_token(&new_token) { + tracing::warn!("Failed to save refreshed token: {}", e); + } + tracing::info!("Successfully refreshed token"); + return Ok(new_token.access_token); + } + Err(e) => { + tracing::warn!( + "Failed to refresh token, will try new auth flow: {}", + e + ); + // Continue to new auth flow + } + } + } + Err(e) => { + tracing::warn!("Failed to get endpoints for token refresh: {}", e); + // Continue to new auth flow + } + } + } + } + + // Get endpoints and execute flow for a new token + let endpoints = get_workspace_endpoints(host).await?; + let flow = OAuthFlow::new( + endpoints, + client_id.to_string(), + redirect_url.to_string(), + scopes.to_vec(), + ); + + // Execute the OAuth flow and get token + let token = flow.execute().await?; + + // Cache and return + token_cache.save_token(&token)?; + println!("🎉 Databricks authentication complete!"); + Ok(token.access_token) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_token_cache() -> Result<()> { + let cache = TokenCache::new( + "https://example.com", + "test-client", + &["scope1".to_string()], + ); + + // Test with expiration time + let token_data = TokenData { + access_token: "test-token".to_string(), + refresh_token: Some("test-refresh-token".to_string()), + expires_at: Some(Utc::now() + chrono::Duration::hours(1)), + }; + + cache.save_token(&token_data)?; + + let loaded_token = cache.load_token().unwrap(); + assert_eq!(loaded_token.access_token, token_data.access_token); + assert_eq!(loaded_token.refresh_token, token_data.refresh_token); + assert!(loaded_token.expires_at.is_some()); + + Ok(()) + } +}