databricks support
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -6,6 +6,8 @@ target
|
||||
# These are backup files generated by rustfmt
|
||||
**/*.rs.bk
|
||||
|
||||
**/.DS_Store
|
||||
|
||||
# MSVC Windows builds of rustc generate these, which store debugging information
|
||||
*.pdb
|
||||
|
||||
|
||||
459
Cargo.lock
generated
459
Cargo.lock
generated
@@ -126,12 +126,73 @@ dependencies = [
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atomic-waker"
|
||||
version = "1.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
|
||||
|
||||
[[package]]
|
||||
name = "axum"
|
||||
version = "0.7.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"axum-core",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http 1.3.1",
|
||||
"http-body 1.0.1",
|
||||
"http-body-util",
|
||||
"hyper 1.7.0",
|
||||
"hyper-util",
|
||||
"itoa",
|
||||
"matchit",
|
||||
"memchr",
|
||||
"mime",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"rustversion",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_path_to_error",
|
||||
"serde_urlencoded",
|
||||
"sync_wrapper 1.0.2",
|
||||
"tokio",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-core"
|
||||
version = "0.4.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http 1.3.1",
|
||||
"http-body 1.0.1",
|
||||
"http-body-util",
|
||||
"mime",
|
||||
"pin-project-lite",
|
||||
"rustversion",
|
||||
"sync_wrapper 1.0.2",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backtrace"
|
||||
version = "0.3.75"
|
||||
@@ -153,6 +214,12 @@ version = "0.21.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.22.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
|
||||
|
||||
[[package]]
|
||||
name = "bindgen"
|
||||
version = "0.69.5"
|
||||
@@ -224,6 +291,12 @@ dependencies = [
|
||||
"shlex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cesu8"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c"
|
||||
|
||||
[[package]]
|
||||
name = "cexpr"
|
||||
version = "0.6.0"
|
||||
@@ -325,6 +398,16 @@ version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75"
|
||||
|
||||
[[package]]
|
||||
name = "combine"
|
||||
version = "4.6.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "config"
|
||||
version = "0.14.1"
|
||||
@@ -402,6 +485,16 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation-sys"
|
||||
version = "0.8.7"
|
||||
@@ -792,15 +885,25 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"axum",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"dirs 5.0.1",
|
||||
"futures-util",
|
||||
"nanoid",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"sha2",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"url",
|
||||
"webbrowser",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -859,7 +962,7 @@ dependencies = [
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http 0.2.12",
|
||||
"indexmap",
|
||||
"slab",
|
||||
"tokio",
|
||||
@@ -924,6 +1027,17 @@ dependencies = [
|
||||
"itoa",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http"
|
||||
version = "1.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fnv",
|
||||
"itoa",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-body"
|
||||
version = "0.4.6"
|
||||
@@ -931,7 +1045,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"http",
|
||||
"http 0.2.12",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-body"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"http 1.3.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-body-util"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"http 1.3.1",
|
||||
"http-body 1.0.1",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
@@ -958,8 +1095,8 @@ dependencies = [
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"h2",
|
||||
"http",
|
||||
"http-body",
|
||||
"http 0.2.12",
|
||||
"http-body 0.4.6",
|
||||
"httparse",
|
||||
"httpdate",
|
||||
"itoa",
|
||||
@@ -971,6 +1108,27 @@ dependencies = [
|
||||
"want",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper"
|
||||
version = "1.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eb3aa54a13a0dfe7fbe3a59e0c76093041720fdc77b110cc0fc260fafb4dc51e"
|
||||
dependencies = [
|
||||
"atomic-waker",
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"http 1.3.1",
|
||||
"http-body 1.0.1",
|
||||
"httparse",
|
||||
"httpdate",
|
||||
"itoa",
|
||||
"pin-project-lite",
|
||||
"pin-utils",
|
||||
"smallvec",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-tls"
|
||||
version = "0.5.0"
|
||||
@@ -978,12 +1136,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"hyper",
|
||||
"hyper 0.14.32",
|
||||
"native-tls",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-util"
|
||||
version = "0.1.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"http 1.3.1",
|
||||
"http-body 1.0.1",
|
||||
"hyper 1.7.0",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone"
|
||||
version = "0.1.64"
|
||||
@@ -1176,6 +1350,28 @@ version = "1.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
|
||||
|
||||
[[package]]
|
||||
name = "jni"
|
||||
version = "0.21.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97"
|
||||
dependencies = [
|
||||
"cesu8",
|
||||
"cfg-if",
|
||||
"combine",
|
||||
"jni-sys",
|
||||
"log",
|
||||
"thiserror 1.0.69",
|
||||
"walkdir",
|
||||
"windows-sys 0.45.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jni-sys"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"
|
||||
|
||||
[[package]]
|
||||
name = "jobserver"
|
||||
version = "0.1.34"
|
||||
@@ -1324,6 +1520,12 @@ dependencies = [
|
||||
"regex-automata",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "matchit"
|
||||
version = "0.7.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.7.5"
|
||||
@@ -1362,6 +1564,15 @@ dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nanoid"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3ffa00dec017b5b1a8b7cf5e2c008bfda1aa7e0697ac1508b491fdf2622fb4d8"
|
||||
dependencies = [
|
||||
"rand",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "native-tls"
|
||||
version = "0.2.14"
|
||||
@@ -1379,6 +1590,12 @@ dependencies = [
|
||||
"tempfile",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ndk-context"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "27b02d87554356db9e9a873add8782d4ea6e3e58ea071a9adb9a2e8ddb884a8b"
|
||||
|
||||
[[package]]
|
||||
name = "nibble_vec"
|
||||
version = "0.1.0"
|
||||
@@ -1444,6 +1661,31 @@ version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
|
||||
|
||||
[[package]]
|
||||
name = "objc2"
|
||||
version = "0.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "561f357ba7f3a2a61563a186a163d0a3a5247e1089524a3981d49adb775078bc"
|
||||
dependencies = [
|
||||
"objc2-encode",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "objc2-encode"
|
||||
version = "4.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33"
|
||||
|
||||
[[package]]
|
||||
name = "objc2-foundation"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "900831247d2fe1a09a683278e5384cfb8c80c79fe6b166f9d14bfdde0ea1b03c"
|
||||
dependencies = [
|
||||
"bitflags 2.9.4",
|
||||
"objc2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "object"
|
||||
version = "0.36.7"
|
||||
@@ -1637,6 +1879,15 @@ dependencies = [
|
||||
"zerovec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9"
|
||||
dependencies = [
|
||||
"zerocopy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "prettyplease"
|
||||
version = "0.2.37"
|
||||
@@ -1681,6 +1932,36 @@ dependencies = [
|
||||
"nibble_vec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.8.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"rand_chacha",
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
|
||||
dependencies = [
|
||||
"ppv-lite86",
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.6.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
||||
dependencies = [
|
||||
"getrandom 0.2.16",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.5.17"
|
||||
@@ -1747,15 +2028,15 @@ version = "0.11.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"base64 0.21.7",
|
||||
"bytes",
|
||||
"encoding_rs",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"h2",
|
||||
"http",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"http 0.2.12",
|
||||
"http-body 0.4.6",
|
||||
"hyper 0.14.32",
|
||||
"hyper-tls",
|
||||
"ipnet",
|
||||
"js-sys",
|
||||
@@ -1769,7 +2050,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"sync_wrapper",
|
||||
"sync_wrapper 0.1.2",
|
||||
"system-configuration",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
@@ -1789,7 +2070,7 @@ version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b91f7eff05f748767f183df4320a63d6936e9c6107d97c9e6bdd9784f4289c94"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"base64 0.21.7",
|
||||
"bitflags 2.9.4",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
@@ -1858,7 +2139,7 @@ version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"base64 0.21.7",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1895,6 +2176,15 @@ version = "1.0.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
|
||||
|
||||
[[package]]
|
||||
name = "same-file"
|
||||
version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
|
||||
dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "schannel"
|
||||
version = "0.1.27"
|
||||
@@ -1917,7 +2207,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
|
||||
dependencies = [
|
||||
"bitflags 2.9.4",
|
||||
"core-foundation",
|
||||
"core-foundation 0.9.4",
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
"security-framework-sys",
|
||||
@@ -1981,6 +2271,17 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_path_to_error"
|
||||
version = "0.1.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"serde",
|
||||
"serde_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_spanned"
|
||||
version = "0.6.9"
|
||||
@@ -2107,6 +2408,12 @@ version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
|
||||
|
||||
[[package]]
|
||||
name = "sync_wrapper"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
|
||||
|
||||
[[package]]
|
||||
name = "synstructure"
|
||||
version = "0.13.2"
|
||||
@@ -2125,7 +2432,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"core-foundation",
|
||||
"core-foundation 0.9.4",
|
||||
"system-configuration-sys",
|
||||
]
|
||||
|
||||
@@ -2326,6 +2633,28 @@ version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801"
|
||||
|
||||
[[package]]
|
||||
name = "tower"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"pin-project-lite",
|
||||
"sync_wrapper 1.0.2",
|
||||
"tokio",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-layer"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e"
|
||||
|
||||
[[package]]
|
||||
name = "tower-service"
|
||||
version = "0.3.3"
|
||||
@@ -2338,6 +2667,7 @@ version = "0.1.41"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0"
|
||||
dependencies = [
|
||||
"log",
|
||||
"pin-project-lite",
|
||||
"tracing-attributes",
|
||||
"tracing-core",
|
||||
@@ -2482,6 +2812,16 @@ version = "0.9.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
|
||||
|
||||
[[package]]
|
||||
name = "walkdir"
|
||||
version = "2.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
|
||||
dependencies = [
|
||||
"same-file",
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "want"
|
||||
version = "0.3.1"
|
||||
@@ -2611,6 +2951,22 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webbrowser"
|
||||
version = "1.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "aaf4f3c0ba838e82b4e5ccc4157003fb8c324ee24c058470ffb82820becbde98"
|
||||
dependencies = [
|
||||
"core-foundation 0.10.1",
|
||||
"jni",
|
||||
"log",
|
||||
"ndk-context",
|
||||
"objc2",
|
||||
"objc2-foundation",
|
||||
"url",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "which"
|
||||
version = "4.4.2"
|
||||
@@ -2623,6 +2979,15 @@ dependencies = [
|
||||
"rustix 0.38.44",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi-util"
|
||||
version = "0.1.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
|
||||
dependencies = [
|
||||
"windows-sys 0.61.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-core"
|
||||
version = "0.62.1"
|
||||
@@ -2688,6 +3053,15 @@ dependencies = [
|
||||
"windows-link 0.2.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.45.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0"
|
||||
dependencies = [
|
||||
"windows-targets 0.42.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.48.0"
|
||||
@@ -2733,6 +3107,21 @@ dependencies = [
|
||||
"windows-link 0.2.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-targets"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071"
|
||||
dependencies = [
|
||||
"windows_aarch64_gnullvm 0.42.2",
|
||||
"windows_aarch64_msvc 0.42.2",
|
||||
"windows_i686_gnu 0.42.2",
|
||||
"windows_i686_msvc 0.42.2",
|
||||
"windows_x86_64_gnu 0.42.2",
|
||||
"windows_x86_64_gnullvm 0.42.2",
|
||||
"windows_x86_64_msvc 0.42.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-targets"
|
||||
version = "0.48.5"
|
||||
@@ -2781,6 +3170,12 @@ dependencies = [
|
||||
"windows_x86_64_msvc 0.53.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.48.5"
|
||||
@@ -2799,6 +3194,12 @@ version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.48.5"
|
||||
@@ -2817,6 +3218,12 @@ version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.48.5"
|
||||
@@ -2847,6 +3254,12 @@ version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.48.5"
|
||||
@@ -2865,6 +3278,12 @@ version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.48.5"
|
||||
@@ -2883,6 +3302,12 @@ version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.48.5"
|
||||
@@ -2901,6 +3326,12 @@ version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.48.5"
|
||||
|
||||
@@ -1,36 +1,13 @@
|
||||
# Example configuration file for G3
|
||||
# Copy to ~/.config/g3/config.toml and customize
|
||||
|
||||
[providers]
|
||||
default_provider = "embedded"
|
||||
default_provider = "databricks"
|
||||
|
||||
[providers.openai]
|
||||
# Get your API key from https://platform.openai.com/api-keys
|
||||
api_key = "sk-your-openai-api-key-here"
|
||||
model = "gpt-4"
|
||||
# Optional: custom base URL for OpenAI-compatible APIs
|
||||
# base_url = "https://api.openai.com/v1"
|
||||
max_tokens = 2048
|
||||
temperature = 0.1
|
||||
|
||||
[providers.anthropic]
|
||||
# Get your API key from https://console.anthropic.com/
|
||||
api_key = "your-anthropic-api-key-here"
|
||||
model = "claude-3-5-sonnet-20241022"
|
||||
[providers.databricks]
|
||||
host = "https://your-workspace.cloud.databricks.com"
|
||||
# token = "your-databricks-token" # Optional - will use OAuth if not provided
|
||||
model = "databricks-claude-sonnet-4"
|
||||
max_tokens = 4096
|
||||
temperature = 0.1
|
||||
|
||||
[providers.embedded]
|
||||
# Path to your GGUF model file
|
||||
model_path = "~/.cache/g3/models/codellama-7b-instruct.Q4_K_M.gguf"
|
||||
model_type = "codellama"
|
||||
context_length = 16384 # Use CodeLlama's full context capability
|
||||
max_tokens = 2048 # Default fallback, but will be calculated dynamically
|
||||
temperature = 0.1
|
||||
# Number of layers to offload to GPU (0 for CPU only)
|
||||
gpu_layers = 32
|
||||
# Number of CPU threads to use
|
||||
threads = 8
|
||||
use_oauth = true
|
||||
|
||||
[agent]
|
||||
max_context_length = 8192
|
||||
|
||||
@@ -12,6 +12,7 @@ pub struct Config {
|
||||
pub struct ProvidersConfig {
|
||||
pub openai: Option<OpenAIConfig>,
|
||||
pub anthropic: Option<AnthropicConfig>,
|
||||
pub databricks: Option<DatabricksConfig>,
|
||||
pub embedded: Option<EmbeddedConfig>,
|
||||
pub default_provider: String,
|
||||
}
|
||||
@@ -33,6 +34,16 @@ pub struct AnthropicConfig {
|
||||
pub temperature: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DatabricksConfig {
|
||||
pub host: String,
|
||||
pub token: Option<String>, // Optional - will use OAuth if not provided
|
||||
pub model: String,
|
||||
pub max_tokens: Option<u32>,
|
||||
pub temperature: Option<f32>,
|
||||
pub use_oauth: Option<bool>, // Default to true if token not provided
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EmbeddedConfig {
|
||||
pub model_path: String,
|
||||
@@ -57,8 +68,16 @@ impl Default for Config {
|
||||
providers: ProvidersConfig {
|
||||
openai: None,
|
||||
anthropic: None,
|
||||
databricks: Some(DatabricksConfig {
|
||||
host: "https://your-workspace.cloud.databricks.com".to_string(),
|
||||
token: None, // Will use OAuth by default
|
||||
model: "databricks-claude-sonnet-4".to_string(),
|
||||
max_tokens: Some(4096),
|
||||
temperature: Some(0.1),
|
||||
use_oauth: Some(true),
|
||||
}),
|
||||
embedded: None,
|
||||
default_provider: "anthropic".to_string(),
|
||||
default_provider: "databricks".to_string(),
|
||||
},
|
||||
agent: AgentConfig {
|
||||
max_context_length: 8192,
|
||||
@@ -88,9 +107,9 @@ impl Config {
|
||||
})
|
||||
};
|
||||
|
||||
// If no config exists, create and save a default Qwen config
|
||||
// If no config exists, create and save a default Databricks config
|
||||
if !config_exists {
|
||||
let qwen_config = Self::default_qwen_config();
|
||||
let databricks_config = Self::default();
|
||||
|
||||
// Save to default location
|
||||
let config_dir = dirs::home_dir()
|
||||
@@ -105,13 +124,13 @@ impl Config {
|
||||
std::fs::create_dir_all(&config_dir).ok();
|
||||
|
||||
let config_file = config_dir.join("config.toml");
|
||||
if let Err(e) = qwen_config.save(config_file.to_str().unwrap()) {
|
||||
if let Err(e) = databricks_config.save(config_file.to_str().unwrap()) {
|
||||
eprintln!("Warning: Could not save default config: {}", e);
|
||||
} else {
|
||||
println!("Created default Qwen configuration at: {}", config_file.display());
|
||||
println!("Created default Databricks configuration at: {}", config_file.display());
|
||||
}
|
||||
|
||||
return Ok(qwen_config);
|
||||
return Ok(databricks_config);
|
||||
}
|
||||
|
||||
// Existing config loading logic
|
||||
@@ -157,6 +176,7 @@ impl Config {
|
||||
providers: ProvidersConfig {
|
||||
openai: None,
|
||||
anthropic: None,
|
||||
databricks: None,
|
||||
embedded: Some(EmbeddedConfig {
|
||||
model_path: "~/.cache/g3/models/qwen2.5-7b-instruct-q3_k_m.gguf".to_string(),
|
||||
model_type: "qwen".to_string(),
|
||||
|
||||
@@ -303,6 +303,36 @@ impl Agent {
|
||||
}
|
||||
}
|
||||
|
||||
// Register Databricks provider if configured AND it's the default provider
|
||||
if let Some(databricks_config) = &config.providers.databricks {
|
||||
if config.providers.default_provider == "databricks" {
|
||||
info!("Initializing Databricks provider (selected as default)");
|
||||
|
||||
let databricks_provider = if let Some(token) = &databricks_config.token {
|
||||
// Use token-based authentication
|
||||
g3_providers::DatabricksProvider::from_token(
|
||||
databricks_config.host.clone(),
|
||||
token.clone(),
|
||||
databricks_config.model.clone(),
|
||||
databricks_config.max_tokens,
|
||||
databricks_config.temperature,
|
||||
)?
|
||||
} else {
|
||||
// Use OAuth authentication
|
||||
g3_providers::DatabricksProvider::from_oauth(
|
||||
databricks_config.host.clone(),
|
||||
databricks_config.model.clone(),
|
||||
databricks_config.max_tokens,
|
||||
databricks_config.temperature,
|
||||
).await?
|
||||
};
|
||||
|
||||
providers.register(databricks_provider);
|
||||
} else {
|
||||
info!("Databricks provider configured but not selected as default, skipping initialization");
|
||||
}
|
||||
}
|
||||
|
||||
// Set default provider
|
||||
debug!(
|
||||
"Setting default provider to: {}",
|
||||
@@ -352,7 +382,18 @@ impl Agent {
|
||||
// Claude models have large context windows
|
||||
200000 // Default for Claude models
|
||||
}
|
||||
|
||||
"databricks" => {
|
||||
// Databricks models have varying context windows depending on the model
|
||||
if model_name.contains("claude") {
|
||||
200000 // Claude models on Databricks have large context windows
|
||||
} else if model_name.contains("llama") {
|
||||
32768 // Llama models typically support 32k context
|
||||
} else if model_name.contains("dbrx") {
|
||||
32768 // DBRX supports 32k context
|
||||
} else {
|
||||
16384 // Conservative default for other Databricks models
|
||||
}
|
||||
}
|
||||
_ => config.agent.max_context_length as u32,
|
||||
};
|
||||
|
||||
|
||||
@@ -16,3 +16,14 @@ async-trait = "0.1"
|
||||
tokio-stream = "0.1"
|
||||
futures-util = "0.3"
|
||||
bytes = "1.0"
|
||||
# OAuth dependencies
|
||||
axum = "0.7"
|
||||
base64 = "0.22"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
sha2 = "0.10"
|
||||
url = "2.5"
|
||||
webbrowser = "1.0"
|
||||
nanoid = "0.4"
|
||||
serde_urlencoded = "0.7"
|
||||
tokio-util = "0.7"
|
||||
dirs = "5.0"
|
||||
|
||||
907
crates/g3-providers/src/databricks.rs
Normal file
907
crates/g3-providers/src/databricks.rs
Normal file
@@ -0,0 +1,907 @@
|
||||
//! Databricks LLM provider implementation for the g3-providers crate.
|
||||
//!
|
||||
//! This module provides an implementation of the `LLMProvider` trait for Databricks Foundation Model APIs,
|
||||
//! supporting both completion and streaming modes with OAuth authentication.
|
||||
//!
|
||||
//! # Features
|
||||
//!
|
||||
//! - Support for Databricks Foundation Models (databricks-claude-sonnet-4, databricks-meta-llama-3-3-70b-instruct, etc.)
|
||||
//! - Both completion and streaming response modes
|
||||
//! - OAuth authentication with automatic token refresh
|
||||
//! - Token-based authentication as fallback
|
||||
//! - Native tool calling support for compatible models
|
||||
//! - Automatic model discovery from Databricks workspace
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! use g3_providers::{DatabricksProvider, LLMProvider, CompletionRequest, Message, MessageRole};
|
||||
//!
|
||||
//! #[tokio::main]
|
||||
//! async fn main() -> anyhow::Result<()> {
|
||||
//! // Create the provider with OAuth (recommended)
|
||||
//! let provider = DatabricksProvider::from_oauth(
|
||||
//! "https://your-workspace.cloud.databricks.com".to_string(),
|
||||
//! "databricks-claude-sonnet-4".to_string(),
|
||||
//! None, // Optional: max tokens
|
||||
//! None, // Optional: temperature
|
||||
//! ).await?;
|
||||
//!
|
||||
//! // Or create with token
|
||||
//! let provider = DatabricksProvider::from_token(
|
||||
//! "https://your-workspace.cloud.databricks.com".to_string(),
|
||||
//! "your-databricks-token".to_string(),
|
||||
//! "databricks-claude-sonnet-4".to_string(),
|
||||
//! None,
|
||||
//! None,
|
||||
//! )?;
|
||||
//!
|
||||
//! // Create a completion request
|
||||
//! let request = CompletionRequest {
|
||||
//! messages: vec![
|
||||
//! Message {
|
||||
//! role: MessageRole::User,
|
||||
//! content: "Hello! How are you?".to_string(),
|
||||
//! },
|
||||
//! ],
|
||||
//! max_tokens: Some(1000),
|
||||
//! temperature: Some(0.7),
|
||||
//! stream: false,
|
||||
//! tools: None,
|
||||
//! };
|
||||
//!
|
||||
//! // Get a completion
|
||||
//! let response = provider.complete(request).await?;
|
||||
//! println!("Response: {}", response.content);
|
||||
//!
|
||||
//! Ok(())
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use bytes::Bytes;
|
||||
use futures_util::stream::StreamExt;
|
||||
use reqwest::{Client, RequestBuilder};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::{
|
||||
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
|
||||
MessageRole, Tool, ToolCall, Usage,
|
||||
};
|
||||
|
||||
const DEFAULT_CLIENT_ID: &str = "databricks-cli";
|
||||
const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020";
|
||||
const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"];
|
||||
const DEFAULT_TIMEOUT_SECS: u64 = 600;
|
||||
|
||||
pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-sonnet-4";
|
||||
const DATABRICKS_DEFAULT_FAST_MODEL: &str = "gemini-1-5-flash";
|
||||
pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[
|
||||
"databricks-claude-3-7-sonnet",
|
||||
"databricks-meta-llama-3-3-70b-instruct",
|
||||
"databricks-meta-llama-3-1-405b-instruct",
|
||||
"databricks-dbrx-instruct",
|
||||
"databricks-mixtral-8x7b-instruct",
|
||||
];
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum DatabricksAuth {
|
||||
Token(String),
|
||||
OAuth {
|
||||
host: String,
|
||||
client_id: String,
|
||||
redirect_url: String,
|
||||
scopes: Vec<String>,
|
||||
cached_token: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
impl DatabricksAuth {
|
||||
pub fn oauth(host: String) -> Self {
|
||||
Self::OAuth {
|
||||
host,
|
||||
client_id: DEFAULT_CLIENT_ID.to_string(),
|
||||
redirect_url: DEFAULT_REDIRECT_URL.to_string(),
|
||||
scopes: DEFAULT_SCOPES.iter().map(|s| s.to_string()).collect(),
|
||||
cached_token: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn token(token: String) -> Self {
|
||||
Self::Token(token)
|
||||
}
|
||||
|
||||
async fn get_token(&mut self) -> Result<String> {
|
||||
match self {
|
||||
DatabricksAuth::Token(token) => Ok(token.clone()),
|
||||
DatabricksAuth::OAuth {
|
||||
host,
|
||||
client_id,
|
||||
redirect_url,
|
||||
scopes,
|
||||
cached_token: _,
|
||||
} => {
|
||||
// Use the OAuth implementation
|
||||
crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DatabricksProvider {
|
||||
client: Client,
|
||||
host: String,
|
||||
auth: DatabricksAuth,
|
||||
model: String,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
}
|
||||
|
||||
impl DatabricksProvider {
|
||||
pub fn from_token(
|
||||
host: String,
|
||||
token: String,
|
||||
model: String,
|
||||
max_tokens: Option<u32>,
|
||||
temperature: Option<f32>,
|
||||
) -> Result<Self> {
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
|
||||
.build()
|
||||
.map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
|
||||
|
||||
info!("Initialized Databricks provider with model: {} on host: {}", model, host);
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
host: host.trim_end_matches('/').to_string(),
|
||||
auth: DatabricksAuth::token(token),
|
||||
model,
|
||||
max_tokens: max_tokens.unwrap_or(4096),
|
||||
temperature: temperature.unwrap_or(0.1),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn from_oauth(
|
||||
host: String,
|
||||
model: String,
|
||||
max_tokens: Option<u32>,
|
||||
temperature: Option<f32>,
|
||||
) -> Result<Self> {
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
|
||||
.build()
|
||||
.map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
|
||||
|
||||
info!("Initialized Databricks provider with OAuth for model: {} on host: {}", model, host);
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
host: host.trim_end_matches('/').to_string(),
|
||||
auth: DatabricksAuth::oauth(host.clone()),
|
||||
model,
|
||||
max_tokens: max_tokens.unwrap_or(4096),
|
||||
temperature: temperature.unwrap_or(0.1),
|
||||
})
|
||||
}
|
||||
|
||||
async fn create_request_builder(&mut self, streaming: bool) -> Result<RequestBuilder> {
|
||||
let token = self.auth.get_token().await?;
|
||||
|
||||
let mut builder = self
|
||||
.client
|
||||
.post(&format!("{}/serving-endpoints/{}/invocations", self.host, self.model))
|
||||
.header("Authorization", format!("Bearer {}", token))
|
||||
.header("Content-Type", "application/json");
|
||||
|
||||
if streaming {
|
||||
builder = builder.header("Accept", "text/event-stream");
|
||||
}
|
||||
|
||||
Ok(builder)
|
||||
}
|
||||
|
||||
fn convert_tools(&self, tools: &[Tool]) -> Vec<DatabricksTool> {
|
||||
tools
|
||||
.iter()
|
||||
.map(|tool| DatabricksTool {
|
||||
r#type: "function".to_string(),
|
||||
function: DatabricksFunction {
|
||||
name: tool.name.clone(),
|
||||
description: tool.description.clone(),
|
||||
parameters: tool.input_schema.clone(),
|
||||
},
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn convert_messages(&self, messages: &[Message]) -> Result<Vec<DatabricksMessage>> {
|
||||
let mut databricks_messages = Vec::new();
|
||||
|
||||
for message in messages {
|
||||
let role = match message.role {
|
||||
MessageRole::System => "system",
|
||||
MessageRole::User => "user",
|
||||
MessageRole::Assistant => "assistant",
|
||||
};
|
||||
|
||||
databricks_messages.push(DatabricksMessage {
|
||||
role: role.to_string(),
|
||||
content: Some(message.content.clone()),
|
||||
tool_calls: None, // Only used in responses, not requests
|
||||
});
|
||||
}
|
||||
|
||||
if databricks_messages.is_empty() {
|
||||
return Err(anyhow!("At least one message is required"));
|
||||
}
|
||||
|
||||
Ok(databricks_messages)
|
||||
}
|
||||
|
||||
fn create_request_body(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
tools: Option<&[Tool]>,
|
||||
streaming: bool,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
) -> Result<DatabricksRequest> {
|
||||
let databricks_messages = self.convert_messages(messages)?;
|
||||
|
||||
// Convert tools if provided
|
||||
let databricks_tools = tools.map(|t| self.convert_tools(t));
|
||||
|
||||
let request = DatabricksRequest {
|
||||
messages: databricks_messages,
|
||||
max_tokens,
|
||||
temperature,
|
||||
tools: databricks_tools,
|
||||
stream: streaming,
|
||||
};
|
||||
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
async fn parse_streaming_response(
|
||||
&self,
|
||||
mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
|
||||
tx: mpsc::Sender<Result<CompletionChunk>>,
|
||||
) {
|
||||
let mut buffer = String::new();
|
||||
let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> = std::collections::HashMap::new(); // index -> (id, name, args)
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
let chunk_str = match std::str::from_utf8(&chunk) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
error!("Invalid UTF-8 in stream chunk: {}", e);
|
||||
let _ = tx
|
||||
.send(Err(anyhow!("Invalid UTF-8 in stream chunk: {}", e)))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
buffer.push_str(chunk_str);
|
||||
|
||||
// Process complete lines
|
||||
while let Some(line_end) = buffer.find('\n') {
|
||||
let line = buffer[..line_end].trim().to_string();
|
||||
buffer.drain(..line_end + 1);
|
||||
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Parse Server-Sent Events format
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
if data == "[DONE]" {
|
||||
debug!("Received stream completion marker");
|
||||
let final_tool_calls: Vec<ToolCall> = current_tool_calls.values()
|
||||
.map(|(id, name, args)| ToolCall {
|
||||
id: id.clone(),
|
||||
tool: name.clone(),
|
||||
args: serde_json::from_str(args).unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
|
||||
})
|
||||
.collect();
|
||||
let final_chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
tool_calls: if final_tool_calls.is_empty() { None } else { Some(final_tool_calls) },
|
||||
};
|
||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
debug!("Raw Databricks API JSON: {}", data);
|
||||
|
||||
match serde_json::from_str::<DatabricksStreamChunk>(data) {
|
||||
Ok(chunk) => {
|
||||
debug!("Parsed stream chunk: {:?}", chunk);
|
||||
|
||||
// Handle different types of chunks
|
||||
if let Some(choices) = chunk.choices {
|
||||
for choice in choices {
|
||||
if let Some(delta) = choice.delta {
|
||||
// Handle text content
|
||||
if let Some(content) = delta.content {
|
||||
debug!("Sending text chunk: '{}'", content);
|
||||
let chunk = CompletionChunk {
|
||||
content,
|
||||
finished: false,
|
||||
tool_calls: None,
|
||||
};
|
||||
if tx.send(Ok(chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle tool calls - accumulate across chunks
|
||||
if let Some(tool_calls) = delta.tool_calls {
|
||||
for tool_call in tool_calls {
|
||||
let index = tool_call.index.unwrap_or(0);
|
||||
let entry = current_tool_calls.entry(index).or_insert_with(|| {
|
||||
(String::new(), String::new(), String::new())
|
||||
});
|
||||
|
||||
// Update ID if provided
|
||||
if let Some(id) = tool_call.id {
|
||||
entry.0 = id;
|
||||
}
|
||||
|
||||
// Update name if provided and not empty
|
||||
if !tool_call.function.name.is_empty() {
|
||||
entry.1 = tool_call.function.name;
|
||||
}
|
||||
|
||||
// Append arguments
|
||||
entry.2.push_str(&tool_call.function.arguments);
|
||||
|
||||
debug!("Accumulated tool call {}: id='{}', name='{}', args='{}'",
|
||||
index, entry.0, entry.1, entry.2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this choice is finished
|
||||
if choice.finish_reason.is_some() {
|
||||
debug!("Choice finished with reason: {:?}", choice.finish_reason);
|
||||
|
||||
// Convert accumulated tool calls to final format
|
||||
let final_tool_calls: Vec<ToolCall> = current_tool_calls.values()
|
||||
.filter(|(_, name, _)| !name.is_empty()) // Only include tool calls with names
|
||||
.map(|(id, name, args)| {
|
||||
debug!("Converting tool call: id='{}', name='{}', args='{}'", id, name, args);
|
||||
ToolCall {
|
||||
id: if id.is_empty() { format!("tool_{}", name) } else { id.clone() },
|
||||
tool: name.clone(),
|
||||
args: serde_json::from_str(args).unwrap_or_else(|e| {
|
||||
debug!("Failed to parse tool args '{}': {}", args, e);
|
||||
serde_json::Value::Object(serde_json::Map::new())
|
||||
}),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
debug!("Final tool calls: {:?}", final_tool_calls);
|
||||
|
||||
let final_chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
tool_calls: if final_tool_calls.is_empty() { None } else { Some(final_tool_calls) },
|
||||
};
|
||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
|
||||
// Don't error out on parse failures, just continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Stream error: {}", e);
|
||||
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send final chunk if we haven't already
|
||||
let final_tool_calls: Vec<ToolCall> = current_tool_calls.values()
|
||||
.filter(|(_, name, _)| !name.is_empty())
|
||||
.map(|(id, name, args)| ToolCall {
|
||||
id: if id.is_empty() { format!("tool_{}", name) } else { id.clone() },
|
||||
tool: name.clone(),
|
||||
args: serde_json::from_str(args).unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let final_chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
tool_calls: if final_tool_calls.is_empty() { None } else { Some(final_tool_calls) },
|
||||
};
|
||||
let _ = tx.send(Ok(final_chunk)).await;
|
||||
}
|
||||
|
||||
pub async fn fetch_supported_models(&mut self) -> Result<Option<Vec<String>>> {
|
||||
let token = self.auth.get_token().await?;
|
||||
|
||||
let response = match self
|
||||
.client
|
||||
.get(&format!("{}/api/2.0/serving-endpoints", self.host))
|
||||
.header("Authorization", format!("Bearer {}", token))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch Databricks models: {}", e);
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
if let Ok(error_text) = response.text().await {
|
||||
warn!(
|
||||
"Failed to fetch Databricks models: {} - {}",
|
||||
status,
|
||||
error_text
|
||||
);
|
||||
} else {
|
||||
warn!("Failed to fetch Databricks models: {}", status);
|
||||
}
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let json: serde_json::Value = match response.json().await {
|
||||
Ok(json) => json,
|
||||
Err(e) => {
|
||||
warn!("Failed to parse Databricks API response: {}", e);
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
|
||||
let endpoints = match json.get("endpoints").and_then(|v| v.as_array()) {
|
||||
Some(endpoints) => endpoints,
|
||||
None => {
|
||||
warn!(
|
||||
"Unexpected response format from Databricks API: missing 'endpoints' array"
|
||||
);
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
|
||||
let models: Vec<String> = endpoints
|
||||
.iter()
|
||||
.filter_map(|endpoint| {
|
||||
endpoint
|
||||
.get("name")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|name| name.to_string())
|
||||
})
|
||||
.collect();
|
||||
|
||||
if models.is_empty() {
|
||||
debug!("No serving endpoints found in Databricks workspace");
|
||||
Ok(None)
|
||||
} else {
|
||||
debug!(
|
||||
"Found {} serving endpoints in Databricks workspace",
|
||||
models.len()
|
||||
);
|
||||
Ok(Some(models))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl LLMProvider for DatabricksProvider {
|
||||
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
|
||||
debug!(
|
||||
"Processing Databricks completion request with {} messages",
|
||||
request.messages.len()
|
||||
);
|
||||
|
||||
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
|
||||
let temperature = request.temperature.unwrap_or(self.temperature);
|
||||
|
||||
let request_body = self.create_request_body(
|
||||
&request.messages,
|
||||
request.tools.as_deref(),
|
||||
false,
|
||||
max_tokens,
|
||||
temperature
|
||||
)?;
|
||||
|
||||
debug!("Sending request to Databricks API: model={}, max_tokens={}, temperature={}",
|
||||
self.model, request_body.max_tokens, request_body.temperature);
|
||||
|
||||
// Debug: Log the full request body when tools are present
|
||||
if request.tools.is_some() {
|
||||
debug!("Full request body with tools: {}", serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "Failed to serialize".to_string()));
|
||||
}
|
||||
|
||||
let mut provider_clone = self.clone();
|
||||
let response = provider_clone
|
||||
.create_request_builder(false)
|
||||
.await?
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to send request to Databricks API: {}", e))?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
}
|
||||
|
||||
let response_text = response.text().await?;
|
||||
debug!("Raw Databricks API response: {}", response_text);
|
||||
|
||||
let databricks_response: DatabricksResponse = serde_json::from_str(&response_text)
|
||||
.map_err(|e| anyhow!("Failed to parse Databricks response: {} - Response: {}", e, response_text))?;
|
||||
|
||||
// Debug: Log the parsed response structure
|
||||
debug!("Parsed Databricks response: {:#?}", databricks_response);
|
||||
|
||||
// Extract content from the first choice
|
||||
let content = databricks_response
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.message.content.as_ref())
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
|
||||
// Check if there are tool calls in the response
|
||||
if let Some(first_choice) = databricks_response.choices.first() {
|
||||
if let Some(tool_calls) = &first_choice.message.tool_calls {
|
||||
debug!("Found {} tool calls in Databricks response", tool_calls.len());
|
||||
for (i, tool_call) in tool_calls.iter().enumerate() {
|
||||
debug!("Tool call {}: {} with args: {}", i, tool_call.function.name, tool_call.function.arguments);
|
||||
}
|
||||
|
||||
// For now, we'll return the content as-is since g3 handles tool calls via streaming
|
||||
// In the future, we might need to convert these to the internal format
|
||||
}
|
||||
}
|
||||
|
||||
let usage = Usage {
|
||||
prompt_tokens: databricks_response.usage.prompt_tokens,
|
||||
completion_tokens: databricks_response.usage.completion_tokens,
|
||||
total_tokens: databricks_response.usage.total_tokens,
|
||||
};
|
||||
|
||||
debug!(
|
||||
"Databricks completion successful: {} tokens generated",
|
||||
usage.completion_tokens
|
||||
);
|
||||
|
||||
Ok(CompletionResponse {
|
||||
content,
|
||||
usage,
|
||||
model: self.model.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
|
||||
debug!(
|
||||
"Processing Databricks streaming request with {} messages",
|
||||
request.messages.len()
|
||||
);
|
||||
|
||||
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
|
||||
let temperature = request.temperature.unwrap_or(self.temperature);
|
||||
|
||||
let request_body = self.create_request_body(
|
||||
&request.messages,
|
||||
request.tools.as_deref(),
|
||||
true,
|
||||
max_tokens,
|
||||
temperature
|
||||
)?;
|
||||
|
||||
debug!("Sending streaming request to Databricks API: model={}, max_tokens={}, temperature={}",
|
||||
self.model, request_body.max_tokens, request_body.temperature);
|
||||
|
||||
// Debug: Log the full request body
|
||||
debug!("Full request body: {}", serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "Failed to serialize".to_string()));
|
||||
|
||||
let mut provider_clone = self.clone();
|
||||
let response = provider_clone
|
||||
.create_request_builder(true)
|
||||
.await?
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to send streaming request to Databricks API: {}", e))?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
}
|
||||
|
||||
let stream = response.bytes_stream();
|
||||
let (tx, rx) = mpsc::channel(100);
|
||||
|
||||
// Spawn task to process the stream
|
||||
let provider = self.clone();
|
||||
tokio::spawn(async move {
|
||||
provider.parse_streaming_response(stream, tx).await;
|
||||
});
|
||||
|
||||
Ok(ReceiverStream::new(rx))
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"databricks"
|
||||
}
|
||||
|
||||
fn model(&self) -> &str {
|
||||
&self.model
|
||||
}
|
||||
|
||||
fn has_native_tool_calling(&self) -> bool {
|
||||
// Databricks Foundation Models support native tool calling
|
||||
// This includes Claude, Llama, DBRX, and most other models on the platform
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
// Databricks API request/response structures
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct DatabricksRequest {
|
||||
messages: Vec<DatabricksMessage>,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<DatabricksTool>>,
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct DatabricksTool {
|
||||
r#type: String,
|
||||
function: DatabricksFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct DatabricksFunction {
|
||||
name: String,
|
||||
description: String,
|
||||
parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct DatabricksMessage {
|
||||
role: String,
|
||||
content: Option<String>, // Make content optional since tool calls might not have content
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_calls: Option<Vec<DatabricksToolCall>>, // Add tool_calls field for responses
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct DatabricksToolCall {
|
||||
id: String,
|
||||
r#type: String,
|
||||
function: DatabricksToolCallFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct DatabricksToolCallFunction {
|
||||
name: String,
|
||||
arguments: String, // This will be a JSON string that needs parsing
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DatabricksResponse {
|
||||
choices: Vec<DatabricksChoice>,
|
||||
usage: DatabricksUsage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DatabricksChoice {
|
||||
message: DatabricksMessage,
|
||||
finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DatabricksUsage {
|
||||
prompt_tokens: u32,
|
||||
completion_tokens: u32,
|
||||
total_tokens: u32,
|
||||
}
|
||||
|
||||
// Streaming response structures
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DatabricksStreamChunk {
|
||||
choices: Option<Vec<DatabricksStreamChoice>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DatabricksStreamChoice {
|
||||
delta: Option<DatabricksStreamDelta>,
|
||||
finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DatabricksStreamDelta {
|
||||
content: Option<String>,
|
||||
tool_calls: Option<Vec<DatabricksStreamToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DatabricksStreamToolCall {
|
||||
index: Option<usize>,
|
||||
id: Option<String>,
|
||||
function: DatabricksStreamFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DatabricksStreamFunction {
|
||||
#[serde(default)]
|
||||
name: String,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_message_conversion() {
|
||||
let provider = DatabricksProvider::from_token(
|
||||
"https://test.databricks.com".to_string(),
|
||||
"test-token".to_string(),
|
||||
"test-model".to_string(),
|
||||
None,
|
||||
None,
|
||||
).unwrap();
|
||||
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: MessageRole::System,
|
||||
content: "You are a helpful assistant.".to_string(),
|
||||
},
|
||||
Message {
|
||||
role: MessageRole::User,
|
||||
content: "Hello!".to_string(),
|
||||
},
|
||||
Message {
|
||||
role: MessageRole::Assistant,
|
||||
content: "Hi there!".to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
let databricks_messages = provider.convert_messages(&messages).unwrap();
|
||||
|
||||
assert_eq!(databricks_messages.len(), 3);
|
||||
assert_eq!(databricks_messages[0].role, "system");
|
||||
assert_eq!(databricks_messages[1].role, "user");
|
||||
assert_eq!(databricks_messages[2].role, "assistant");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_body_creation() {
|
||||
let provider = DatabricksProvider::from_token(
|
||||
"https://test.databricks.com".to_string(),
|
||||
"test-token".to_string(),
|
||||
"databricks-claude-sonnet-4".to_string(),
|
||||
Some(1000),
|
||||
Some(0.5),
|
||||
).unwrap();
|
||||
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: MessageRole::User,
|
||||
content: "Test message".to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
let request_body = provider
|
||||
.create_request_body(&messages, None, false, 1000, 0.5)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(request_body.max_tokens, 1000);
|
||||
assert_eq!(request_body.temperature, 0.5);
|
||||
assert!(!request_body.stream);
|
||||
assert_eq!(request_body.messages.len(), 1);
|
||||
assert!(request_body.tools.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_conversion() {
|
||||
let provider = DatabricksProvider::from_token(
|
||||
"https://test.databricks.com".to_string(),
|
||||
"test-token".to_string(),
|
||||
"test-model".to_string(),
|
||||
None,
|
||||
None,
|
||||
).unwrap();
|
||||
|
||||
let tools = vec![
|
||||
Tool {
|
||||
name: "get_weather".to_string(),
|
||||
description: "Get the current weather".to_string(),
|
||||
input_schema: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}),
|
||||
},
|
||||
];
|
||||
|
||||
let databricks_tools = provider.convert_tools(&tools);
|
||||
|
||||
assert_eq!(databricks_tools.len(), 1);
|
||||
assert_eq!(databricks_tools[0].r#type, "function");
|
||||
assert_eq!(databricks_tools[0].function.name, "get_weather");
|
||||
assert_eq!(databricks_tools[0].function.description, "Get the current weather");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_has_native_tool_calling() {
|
||||
let claude_provider = DatabricksProvider::from_token(
|
||||
"https://test.databricks.com".to_string(),
|
||||
"test-token".to_string(),
|
||||
"databricks-claude-sonnet-4".to_string(),
|
||||
None,
|
||||
None,
|
||||
).unwrap();
|
||||
|
||||
let llama_provider = DatabricksProvider::from_token(
|
||||
"https://test.databricks.com".to_string(),
|
||||
"test-token".to_string(),
|
||||
"databricks-meta-llama-3-3-70b-instruct".to_string(),
|
||||
None,
|
||||
None,
|
||||
).unwrap();
|
||||
|
||||
let dbrx_provider = DatabricksProvider::from_token(
|
||||
"https://test.databricks.com".to_string(),
|
||||
"test-token".to_string(),
|
||||
"databricks-dbrx-instruct".to_string(),
|
||||
None,
|
||||
None,
|
||||
).unwrap();
|
||||
|
||||
assert!(claude_provider.has_native_tool_calling());
|
||||
assert!(llama_provider.has_native_tool_calling());
|
||||
assert!(dbrx_provider.has_native_tool_calling());
|
||||
}
|
||||
}
|
||||
@@ -84,8 +84,11 @@ pub struct Tool {
|
||||
}
|
||||
|
||||
pub mod anthropic;
|
||||
pub mod databricks;
|
||||
pub mod oauth;
|
||||
|
||||
pub use anthropic::AnthropicProvider;
|
||||
pub use databricks::DatabricksProvider;
|
||||
|
||||
/// Provider registry for managing multiple LLM providers
|
||||
pub struct ProviderRegistry {
|
||||
|
||||
457
crates/g3-providers/src/oauth.rs
Normal file
457
crates/g3-providers/src/oauth.rs
Normal file
@@ -0,0 +1,457 @@
|
||||
use anyhow::Result;
|
||||
use axum::{extract::Query, response::Html, routing::get, Router};
|
||||
use base64::Engine;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use sha2::Digest;
|
||||
use std::{collections::HashMap, fs, net::SocketAddr, path::PathBuf, sync::Arc};
|
||||
use tokio::sync::{oneshot, Mutex as TokioMutex};
|
||||
use url::Url;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct OidcEndpoints {
|
||||
authorization_endpoint: String,
|
||||
token_endpoint: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct TokenData {
|
||||
/// The access token used to authenticate API requests
|
||||
access_token: String,
|
||||
|
||||
/// Optional refresh token that can be used to obtain a new access token
|
||||
/// when the current one expires, enabling offline access without user interaction
|
||||
refresh_token: Option<String>,
|
||||
|
||||
/// When the access token expires (if known)
|
||||
/// Used to determine when a token needs to be refreshed
|
||||
expires_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
struct TokenCache {
|
||||
cache_path: PathBuf,
|
||||
}
|
||||
|
||||
fn get_base_path() -> PathBuf {
|
||||
// Use a similar pattern to Goose but for g3
|
||||
// macOS/Linux: ~/.config/g3/databricks/oauth
|
||||
// Windows: ~\AppData\Roaming\g3\config\databricks\oauth\
|
||||
let mut path = dirs::config_dir().unwrap_or_else(|| PathBuf::from("."));
|
||||
path.push("g3");
|
||||
path.push("databricks");
|
||||
path.push("oauth");
|
||||
path
|
||||
}
|
||||
|
||||
impl TokenCache {
|
||||
fn new(host: &str, client_id: &str, scopes: &[String]) -> Self {
|
||||
let mut hasher = sha2::Sha256::new();
|
||||
hasher.update(host.as_bytes());
|
||||
hasher.update(client_id.as_bytes());
|
||||
hasher.update(scopes.join(",").as_bytes());
|
||||
let hash = format!("{:x}", hasher.finalize());
|
||||
|
||||
fs::create_dir_all(get_base_path()).unwrap_or_else(|_| {});
|
||||
let cache_path = get_base_path().join(format!("{}.json", hash));
|
||||
|
||||
Self { cache_path }
|
||||
}
|
||||
|
||||
fn load_token(&self) -> Option<TokenData> {
|
||||
if let Ok(contents) = fs::read_to_string(&self.cache_path) {
|
||||
if let Ok(token_data) = serde_json::from_str::<TokenData>(&contents) {
|
||||
// Only return tokens that have a refresh token
|
||||
if token_data.refresh_token.is_some() {
|
||||
// If token is not expired, return it for immediate use
|
||||
if let Some(expires_at) = token_data.expires_at {
|
||||
if expires_at > Utc::now() {
|
||||
return Some(token_data);
|
||||
}
|
||||
// If token is expired but has refresh token, return it so we can refresh
|
||||
return Some(token_data);
|
||||
}
|
||||
// No expiration time but has refresh token, return it
|
||||
return Some(token_data);
|
||||
}
|
||||
// Token doesn't have a refresh token, ignore it to force a new OAuth flow
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn save_token(&self, token_data: &TokenData) -> Result<()> {
|
||||
if let Some(parent) = self.cache_path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
let contents = serde_json::to_string(token_data)?;
|
||||
fs::write(&self.cache_path, contents)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_workspace_endpoints(host: &str) -> Result<OidcEndpoints> {
|
||||
let base_url = Url::parse(host).expect("Invalid host URL");
|
||||
let oidc_url = base_url
|
||||
.join("oidc/.well-known/oauth-authorization-server")
|
||||
.expect("Invalid OIDC URL");
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client.get(oidc_url.clone()).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Failed to get OIDC configuration from {}",
|
||||
oidc_url.to_string()
|
||||
));
|
||||
}
|
||||
|
||||
let oidc_config: Value = resp.json().await?;
|
||||
|
||||
let authorization_endpoint = oidc_config
|
||||
.get("authorization_endpoint")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("authorization_endpoint not found in OIDC configuration"))?
|
||||
.to_string();
|
||||
|
||||
let token_endpoint = oidc_config
|
||||
.get("token_endpoint")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("token_endpoint not found in OIDC configuration"))?
|
||||
.to_string();
|
||||
|
||||
Ok(OidcEndpoints {
|
||||
authorization_endpoint,
|
||||
token_endpoint,
|
||||
})
|
||||
}
|
||||
|
||||
struct OAuthFlow {
|
||||
endpoints: OidcEndpoints,
|
||||
client_id: String,
|
||||
redirect_url: String,
|
||||
scopes: Vec<String>,
|
||||
state: String,
|
||||
verifier: String,
|
||||
}
|
||||
|
||||
impl OAuthFlow {
|
||||
fn new(
|
||||
endpoints: OidcEndpoints,
|
||||
client_id: String,
|
||||
redirect_url: String,
|
||||
scopes: Vec<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
endpoints,
|
||||
client_id,
|
||||
redirect_url,
|
||||
scopes,
|
||||
state: nanoid::nanoid!(16),
|
||||
verifier: nanoid::nanoid!(64),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extracts token data from an OAuth 2.0 token response.
|
||||
fn extract_token_data(
|
||||
&self,
|
||||
token_response: &Value,
|
||||
old_refresh_token: Option<&str>,
|
||||
) -> Result<TokenData> {
|
||||
// Extract access token (required)
|
||||
let access_token = token_response
|
||||
.get("access_token")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("access_token not found in token response"))?
|
||||
.to_string();
|
||||
|
||||
// Extract refresh token if available
|
||||
let refresh_token = token_response
|
||||
.get("refresh_token")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
.or_else(|| old_refresh_token.map(|s| s.to_string()));
|
||||
|
||||
// Handle token expiration
|
||||
let expires_at =
|
||||
if let Some(expires_in) = token_response.get("expires_in").and_then(|v| v.as_u64()) {
|
||||
// Traditional OAuth flow with expires_in seconds
|
||||
Some(Utc::now() + chrono::Duration::seconds(expires_in as i64))
|
||||
} else {
|
||||
// If the server doesn't provide any expiration info, log it but don't set an expiration
|
||||
tracing::debug!(
|
||||
"No expiration information provided by server, token expiration unknown."
|
||||
);
|
||||
None
|
||||
};
|
||||
|
||||
Ok(TokenData {
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_at,
|
||||
})
|
||||
}
|
||||
|
||||
fn get_authorization_url(&self) -> String {
|
||||
let challenge = {
|
||||
let digest = sha2::Sha256::digest(self.verifier.as_bytes());
|
||||
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest)
|
||||
};
|
||||
|
||||
let params = [
|
||||
("response_type", "code"),
|
||||
("client_id", &self.client_id),
|
||||
("redirect_uri", &self.redirect_url),
|
||||
("scope", &self.scopes.join(" ")),
|
||||
("state", &self.state),
|
||||
("code_challenge", &challenge),
|
||||
("code_challenge_method", "S256"),
|
||||
];
|
||||
|
||||
format!(
|
||||
"{}?{}",
|
||||
self.endpoints.authorization_endpoint,
|
||||
serde_urlencoded::to_string(params).unwrap()
|
||||
)
|
||||
}
|
||||
|
||||
async fn exchange_code_for_token(&self, code: &str) -> Result<TokenData> {
|
||||
let params = [
|
||||
("grant_type", "authorization_code"),
|
||||
("code", code),
|
||||
("redirect_uri", &self.redirect_url),
|
||||
("code_verifier", &self.verifier),
|
||||
("client_id", &self.client_id),
|
||||
];
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(&self.endpoints.token_endpoint)
|
||||
.header("Content-Type", "application/x-www-form-urlencoded")
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let err_text = resp.text().await?;
|
||||
return Err(anyhow::anyhow!(
|
||||
"Failed to exchange code for token: {}",
|
||||
err_text
|
||||
));
|
||||
}
|
||||
|
||||
let token_response: Value = resp.json().await?;
|
||||
self.extract_token_data(&token_response, None)
|
||||
}
|
||||
|
||||
async fn refresh_token(&self, refresh_token: &str) -> Result<TokenData> {
|
||||
let params = [
|
||||
("grant_type", "refresh_token"),
|
||||
("refresh_token", refresh_token),
|
||||
("client_id", &self.client_id),
|
||||
];
|
||||
|
||||
tracing::debug!("Refreshing token using refresh_token");
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(&self.endpoints.token_endpoint)
|
||||
.header("Content-Type", "application/x-www-form-urlencoded")
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let err_text = resp.text().await?;
|
||||
return Err(anyhow::anyhow!("Failed to refresh token: {}", err_text));
|
||||
}
|
||||
|
||||
let token_response: Value = resp.json().await?;
|
||||
self.extract_token_data(&token_response, Some(refresh_token))
|
||||
}
|
||||
|
||||
async fn execute(&self) -> Result<TokenData> {
|
||||
// Create a channel that will send the auth code from the app process
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let state = self.state.clone();
|
||||
let tx = Arc::new(TokioMutex::new(Some(tx)));
|
||||
|
||||
// Setup a server that will receive the redirect, capture the code, and display success/failure
|
||||
let app = Router::new().route(
|
||||
"/",
|
||||
get(move |Query(params): Query<HashMap<String, String>>| {
|
||||
let tx = Arc::clone(&tx);
|
||||
let state = state.clone();
|
||||
async move {
|
||||
let code = params.get("code").cloned();
|
||||
let received_state = params.get("state").cloned();
|
||||
|
||||
if let (Some(code), Some(received_state)) = (code, received_state) {
|
||||
if received_state == state {
|
||||
if let Some(sender) = tx.lock().await.take() {
|
||||
if sender.send(code).is_ok() {
|
||||
return Html(
|
||||
"<h2>G3 Authentication Success</h2><p>You can close this window and return to your terminal.</p>",
|
||||
);
|
||||
}
|
||||
}
|
||||
Html("<h2>Error</h2><p>Authentication already completed.</p>")
|
||||
} else {
|
||||
Html("<h2>Error</h2><p>State mismatch.</p>")
|
||||
}
|
||||
} else {
|
||||
Html("<h2>Error</h2><p>Authentication failed.</p>")
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
// Start the server to accept the oauth code
|
||||
let redirect_url = Url::parse(&self.redirect_url)?;
|
||||
let port = redirect_url.port().unwrap_or(80);
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], port));
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
|
||||
let server_handle = tokio::spawn(async move {
|
||||
let server = axum::serve(listener, app);
|
||||
server.await.unwrap();
|
||||
});
|
||||
|
||||
// Open the browser which will redirect with the code to the server
|
||||
let authorization_url = self.get_authorization_url();
|
||||
println!("🔐 Opening browser for Databricks authentication...");
|
||||
if webbrowser::open(&authorization_url).is_err() {
|
||||
println!(
|
||||
"Please open this URL in your browser:\n{}",
|
||||
authorization_url
|
||||
);
|
||||
}
|
||||
|
||||
// Wait for the authorization code with a timeout
|
||||
let code = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(120), // 2 minute timeout
|
||||
rx,
|
||||
)
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("Authentication timed out after 2 minutes"))??;
|
||||
|
||||
// Stop the server
|
||||
server_handle.abort();
|
||||
|
||||
println!("✅ Authentication successful! Exchanging code for token...");
|
||||
|
||||
// Exchange the code for a token
|
||||
self.exchange_code_for_token(&code).await
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_oauth_token_async(
|
||||
host: &str,
|
||||
client_id: &str,
|
||||
redirect_url: &str,
|
||||
scopes: &[String],
|
||||
) -> Result<String> {
|
||||
let token_cache = TokenCache::new(host, client_id, scopes);
|
||||
|
||||
// Try cache first
|
||||
if let Some(token) = token_cache.load_token() {
|
||||
// If token has an expiration time, check if it's expired
|
||||
if let Some(expires_at) = token.expires_at {
|
||||
if expires_at > Utc::now() {
|
||||
tracing::debug!("Using cached token");
|
||||
return Ok(token.access_token);
|
||||
}
|
||||
// Token is expired, will try to refresh below
|
||||
tracing::debug!("Token is expired, attempting to refresh");
|
||||
} else {
|
||||
// No expiration time was provided by the server
|
||||
tracing::debug!("Token has no expiration time, using cached token");
|
||||
return Ok(token.access_token);
|
||||
}
|
||||
|
||||
// Token is expired or has no expiration, try to refresh if we have a refresh token
|
||||
if let Some(refresh_token) = token.refresh_token {
|
||||
// Get endpoints for token refresh
|
||||
match get_workspace_endpoints(host).await {
|
||||
Ok(endpoints) => {
|
||||
let flow = OAuthFlow::new(
|
||||
endpoints,
|
||||
client_id.to_string(),
|
||||
redirect_url.to_string(),
|
||||
scopes.to_vec(),
|
||||
);
|
||||
|
||||
// Try to refresh the token
|
||||
match flow.refresh_token(&refresh_token).await {
|
||||
Ok(new_token) => {
|
||||
if let Err(e) = token_cache.save_token(&new_token) {
|
||||
tracing::warn!("Failed to save refreshed token: {}", e);
|
||||
}
|
||||
tracing::info!("Successfully refreshed token");
|
||||
return Ok(new_token.access_token);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"Failed to refresh token, will try new auth flow: {}",
|
||||
e
|
||||
);
|
||||
// Continue to new auth flow
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to get endpoints for token refresh: {}", e);
|
||||
// Continue to new auth flow
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get endpoints and execute flow for a new token
|
||||
let endpoints = get_workspace_endpoints(host).await?;
|
||||
let flow = OAuthFlow::new(
|
||||
endpoints,
|
||||
client_id.to_string(),
|
||||
redirect_url.to_string(),
|
||||
scopes.to_vec(),
|
||||
);
|
||||
|
||||
// Execute the OAuth flow and get token
|
||||
let token = flow.execute().await?;
|
||||
|
||||
// Cache and return
|
||||
token_cache.save_token(&token)?;
|
||||
println!("🎉 Databricks authentication complete!");
|
||||
Ok(token.access_token)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_token_cache() -> Result<()> {
|
||||
let cache = TokenCache::new(
|
||||
"https://example.com",
|
||||
"test-client",
|
||||
&["scope1".to_string()],
|
||||
);
|
||||
|
||||
// Test with expiration time
|
||||
let token_data = TokenData {
|
||||
access_token: "test-token".to_string(),
|
||||
refresh_token: Some("test-refresh-token".to_string()),
|
||||
expires_at: Some(Utc::now() + chrono::Duration::hours(1)),
|
||||
};
|
||||
|
||||
cache.save_token(&token_data)?;
|
||||
|
||||
let loaded_token = cache.load_token().unwrap();
|
||||
assert_eq!(loaded_token.access_token, token_data.access_token);
|
||||
assert_eq!(loaded_token.refresh_token, token_data.refresh_token);
|
||||
assert!(loaded_token.expires_at.is_some());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user