databricks support

This commit is contained in:
Dhanji Prasanna
2025-09-27 17:28:02 +10:00
parent 258eb4fd54
commit c490228824
9 changed files with 1899 additions and 50 deletions

View File

@@ -16,3 +16,14 @@ async-trait = "0.1"
tokio-stream = "0.1"
futures-util = "0.3"
bytes = "1.0"
# OAuth dependencies
axum = "0.7"
base64 = "0.22"
chrono = { version = "0.4", features = ["serde"] }
sha2 = "0.10"
url = "2.5"
webbrowser = "1.0"
nanoid = "0.4"
serde_urlencoded = "0.7"
tokio-util = "0.7"
dirs = "5.0"

View File

@@ -0,0 +1,907 @@
//! Databricks LLM provider implementation for the g3-providers crate.
//!
//! This module provides an implementation of the `LLMProvider` trait for Databricks Foundation Model APIs,
//! supporting both completion and streaming modes with OAuth authentication.
//!
//! # Features
//!
//! - Support for Databricks Foundation Models (databricks-claude-sonnet-4, databricks-meta-llama-3-3-70b-instruct, etc.)
//! - Both completion and streaming response modes
//! - OAuth authentication with automatic token refresh
//! - Token-based authentication as fallback
//! - Native tool calling support for compatible models
//! - Automatic model discovery from Databricks workspace
//!
//! # Usage
//!
//! ```rust,no_run
//! use g3_providers::{DatabricksProvider, LLMProvider, CompletionRequest, Message, MessageRole};
//!
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//! // Create the provider with OAuth (recommended)
//! let provider = DatabricksProvider::from_oauth(
//! "https://your-workspace.cloud.databricks.com".to_string(),
//! "databricks-claude-sonnet-4".to_string(),
//! None, // Optional: max tokens
//! None, // Optional: temperature
//! ).await?;
//!
//! // Or create with token
//! let provider = DatabricksProvider::from_token(
//! "https://your-workspace.cloud.databricks.com".to_string(),
//! "your-databricks-token".to_string(),
//! "databricks-claude-sonnet-4".to_string(),
//! None,
//! None,
//! )?;
//!
//! // Create a completion request
//! let request = CompletionRequest {
//! messages: vec![
//! Message {
//! role: MessageRole::User,
//! content: "Hello! How are you?".to_string(),
//! },
//! ],
//! max_tokens: Some(1000),
//! temperature: Some(0.7),
//! stream: false,
//! tools: None,
//! };
//!
//! // Get a completion
//! let response = provider.complete(request).await?;
//! println!("Response: {}", response.content);
//!
//! Ok(())
//! }
//! ```
use anyhow::{anyhow, Result};
use bytes::Bytes;
use futures_util::stream::StreamExt;
use reqwest::{Client, RequestBuilder};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error, info, warn};
use crate::{
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
MessageRole, Tool, ToolCall, Usage,
};
const DEFAULT_CLIENT_ID: &str = "databricks-cli";
const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020";
const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"];
const DEFAULT_TIMEOUT_SECS: u64 = 600;
pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-sonnet-4";
const DATABRICKS_DEFAULT_FAST_MODEL: &str = "gemini-1-5-flash";
pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[
"databricks-claude-3-7-sonnet",
"databricks-meta-llama-3-3-70b-instruct",
"databricks-meta-llama-3-1-405b-instruct",
"databricks-dbrx-instruct",
"databricks-mixtral-8x7b-instruct",
];
#[derive(Debug, Clone)]
pub enum DatabricksAuth {
Token(String),
OAuth {
host: String,
client_id: String,
redirect_url: String,
scopes: Vec<String>,
cached_token: Option<String>,
},
}
impl DatabricksAuth {
pub fn oauth(host: String) -> Self {
Self::OAuth {
host,
client_id: DEFAULT_CLIENT_ID.to_string(),
redirect_url: DEFAULT_REDIRECT_URL.to_string(),
scopes: DEFAULT_SCOPES.iter().map(|s| s.to_string()).collect(),
cached_token: None,
}
}
pub fn token(token: String) -> Self {
Self::Token(token)
}
async fn get_token(&mut self) -> Result<String> {
match self {
DatabricksAuth::Token(token) => Ok(token.clone()),
DatabricksAuth::OAuth {
host,
client_id,
redirect_url,
scopes,
cached_token: _,
} => {
// Use the OAuth implementation
crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await
}
}
}
}
#[derive(Debug, Clone)]
pub struct DatabricksProvider {
client: Client,
host: String,
auth: DatabricksAuth,
model: String,
max_tokens: u32,
temperature: f32,
}
impl DatabricksProvider {
pub fn from_token(
host: String,
token: String,
model: String,
max_tokens: Option<u32>,
temperature: Option<f32>,
) -> Result<Self> {
let client = Client::builder()
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
.build()
.map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
info!("Initialized Databricks provider with model: {} on host: {}", model, host);
Ok(Self {
client,
host: host.trim_end_matches('/').to_string(),
auth: DatabricksAuth::token(token),
model,
max_tokens: max_tokens.unwrap_or(4096),
temperature: temperature.unwrap_or(0.1),
})
}
pub async fn from_oauth(
host: String,
model: String,
max_tokens: Option<u32>,
temperature: Option<f32>,
) -> Result<Self> {
let client = Client::builder()
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
.build()
.map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
info!("Initialized Databricks provider with OAuth for model: {} on host: {}", model, host);
Ok(Self {
client,
host: host.trim_end_matches('/').to_string(),
auth: DatabricksAuth::oauth(host.clone()),
model,
max_tokens: max_tokens.unwrap_or(4096),
temperature: temperature.unwrap_or(0.1),
})
}
async fn create_request_builder(&mut self, streaming: bool) -> Result<RequestBuilder> {
let token = self.auth.get_token().await?;
let mut builder = self
.client
.post(&format!("{}/serving-endpoints/{}/invocations", self.host, self.model))
.header("Authorization", format!("Bearer {}", token))
.header("Content-Type", "application/json");
if streaming {
builder = builder.header("Accept", "text/event-stream");
}
Ok(builder)
}
fn convert_tools(&self, tools: &[Tool]) -> Vec<DatabricksTool> {
tools
.iter()
.map(|tool| DatabricksTool {
r#type: "function".to_string(),
function: DatabricksFunction {
name: tool.name.clone(),
description: tool.description.clone(),
parameters: tool.input_schema.clone(),
},
})
.collect()
}
fn convert_messages(&self, messages: &[Message]) -> Result<Vec<DatabricksMessage>> {
let mut databricks_messages = Vec::new();
for message in messages {
let role = match message.role {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
};
databricks_messages.push(DatabricksMessage {
role: role.to_string(),
content: Some(message.content.clone()),
tool_calls: None, // Only used in responses, not requests
});
}
if databricks_messages.is_empty() {
return Err(anyhow!("At least one message is required"));
}
Ok(databricks_messages)
}
fn create_request_body(
&self,
messages: &[Message],
tools: Option<&[Tool]>,
streaming: bool,
max_tokens: u32,
temperature: f32,
) -> Result<DatabricksRequest> {
let databricks_messages = self.convert_messages(messages)?;
// Convert tools if provided
let databricks_tools = tools.map(|t| self.convert_tools(t));
let request = DatabricksRequest {
messages: databricks_messages,
max_tokens,
temperature,
tools: databricks_tools,
stream: streaming,
};
Ok(request)
}
async fn parse_streaming_response(
&self,
mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
tx: mpsc::Sender<Result<CompletionChunk>>,
) {
let mut buffer = String::new();
let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> = std::collections::HashMap::new(); // index -> (id, name, args)
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
let chunk_str = match std::str::from_utf8(&chunk) {
Ok(s) => s,
Err(e) => {
error!("Invalid UTF-8 in stream chunk: {}", e);
let _ = tx
.send(Err(anyhow!("Invalid UTF-8 in stream chunk: {}", e)))
.await;
return;
}
};
buffer.push_str(chunk_str);
// Process complete lines
while let Some(line_end) = buffer.find('\n') {
let line = buffer[..line_end].trim().to_string();
buffer.drain(..line_end + 1);
if line.is_empty() {
continue;
}
// Parse Server-Sent Events format
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
debug!("Received stream completion marker");
let final_tool_calls: Vec<ToolCall> = current_tool_calls.values()
.map(|(id, name, args)| ToolCall {
id: id.clone(),
tool: name.clone(),
args: serde_json::from_str(args).unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
})
.collect();
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
tool_calls: if final_tool_calls.is_empty() { None } else { Some(final_tool_calls) },
};
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
}
return;
}
debug!("Raw Databricks API JSON: {}", data);
match serde_json::from_str::<DatabricksStreamChunk>(data) {
Ok(chunk) => {
debug!("Parsed stream chunk: {:?}", chunk);
// Handle different types of chunks
if let Some(choices) = chunk.choices {
for choice in choices {
if let Some(delta) = choice.delta {
// Handle text content
if let Some(content) = delta.content {
debug!("Sending text chunk: '{}'", content);
let chunk = CompletionChunk {
content,
finished: false,
tool_calls: None,
};
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return;
}
}
// Handle tool calls - accumulate across chunks
if let Some(tool_calls) = delta.tool_calls {
for tool_call in tool_calls {
let index = tool_call.index.unwrap_or(0);
let entry = current_tool_calls.entry(index).or_insert_with(|| {
(String::new(), String::new(), String::new())
});
// Update ID if provided
if let Some(id) = tool_call.id {
entry.0 = id;
}
// Update name if provided and not empty
if !tool_call.function.name.is_empty() {
entry.1 = tool_call.function.name;
}
// Append arguments
entry.2.push_str(&tool_call.function.arguments);
debug!("Accumulated tool call {}: id='{}', name='{}', args='{}'",
index, entry.0, entry.1, entry.2);
}
}
}
// Check if this choice is finished
if choice.finish_reason.is_some() {
debug!("Choice finished with reason: {:?}", choice.finish_reason);
// Convert accumulated tool calls to final format
let final_tool_calls: Vec<ToolCall> = current_tool_calls.values()
.filter(|(_, name, _)| !name.is_empty()) // Only include tool calls with names
.map(|(id, name, args)| {
debug!("Converting tool call: id='{}', name='{}', args='{}'", id, name, args);
ToolCall {
id: if id.is_empty() { format!("tool_{}", name) } else { id.clone() },
tool: name.clone(),
args: serde_json::from_str(args).unwrap_or_else(|e| {
debug!("Failed to parse tool args '{}': {}", args, e);
serde_json::Value::Object(serde_json::Map::new())
}),
}
})
.collect();
debug!("Final tool calls: {:?}", final_tool_calls);
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
tool_calls: if final_tool_calls.is_empty() { None } else { Some(final_tool_calls) },
};
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
}
return;
}
}
}
}
Err(e) => {
debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
// Don't error out on parse failures, just continue
}
}
}
}
}
Err(e) => {
error!("Stream error: {}", e);
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
return;
}
}
}
// Send final chunk if we haven't already
let final_tool_calls: Vec<ToolCall> = current_tool_calls.values()
.filter(|(_, name, _)| !name.is_empty())
.map(|(id, name, args)| ToolCall {
id: if id.is_empty() { format!("tool_{}", name) } else { id.clone() },
tool: name.clone(),
args: serde_json::from_str(args).unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
})
.collect();
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
tool_calls: if final_tool_calls.is_empty() { None } else { Some(final_tool_calls) },
};
let _ = tx.send(Ok(final_chunk)).await;
}
pub async fn fetch_supported_models(&mut self) -> Result<Option<Vec<String>>> {
let token = self.auth.get_token().await?;
let response = match self
.client
.get(&format!("{}/api/2.0/serving-endpoints", self.host))
.header("Authorization", format!("Bearer {}", token))
.send()
.await
{
Ok(resp) => resp,
Err(e) => {
warn!("Failed to fetch Databricks models: {}", e);
return Ok(None);
}
};
if !response.status().is_success() {
let status = response.status();
if let Ok(error_text) = response.text().await {
warn!(
"Failed to fetch Databricks models: {} - {}",
status,
error_text
);
} else {
warn!("Failed to fetch Databricks models: {}", status);
}
return Ok(None);
}
let json: serde_json::Value = match response.json().await {
Ok(json) => json,
Err(e) => {
warn!("Failed to parse Databricks API response: {}", e);
return Ok(None);
}
};
let endpoints = match json.get("endpoints").and_then(|v| v.as_array()) {
Some(endpoints) => endpoints,
None => {
warn!(
"Unexpected response format from Databricks API: missing 'endpoints' array"
);
return Ok(None);
}
};
let models: Vec<String> = endpoints
.iter()
.filter_map(|endpoint| {
endpoint
.get("name")
.and_then(|v| v.as_str())
.map(|name| name.to_string())
})
.collect();
if models.is_empty() {
debug!("No serving endpoints found in Databricks workspace");
Ok(None)
} else {
debug!(
"Found {} serving endpoints in Databricks workspace",
models.len()
);
Ok(Some(models))
}
}
}
#[async_trait::async_trait]
impl LLMProvider for DatabricksProvider {
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
debug!(
"Processing Databricks completion request with {} messages",
request.messages.len()
);
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);
let request_body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
false,
max_tokens,
temperature
)?;
debug!("Sending request to Databricks API: model={}, max_tokens={}, temperature={}",
self.model, request_body.max_tokens, request_body.temperature);
// Debug: Log the full request body when tools are present
if request.tools.is_some() {
debug!("Full request body with tools: {}", serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "Failed to serialize".to_string()));
}
let mut provider_clone = self.clone();
let response = provider_clone
.create_request_builder(false)
.await?
.json(&request_body)
.send()
.await
.map_err(|e| anyhow!("Failed to send request to Databricks API: {}", e))?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
}
let response_text = response.text().await?;
debug!("Raw Databricks API response: {}", response_text);
let databricks_response: DatabricksResponse = serde_json::from_str(&response_text)
.map_err(|e| anyhow!("Failed to parse Databricks response: {} - Response: {}", e, response_text))?;
// Debug: Log the parsed response structure
debug!("Parsed Databricks response: {:#?}", databricks_response);
// Extract content from the first choice
let content = databricks_response
.choices
.first()
.and_then(|choice| choice.message.content.as_ref())
.cloned()
.unwrap_or_default();
// Check if there are tool calls in the response
if let Some(first_choice) = databricks_response.choices.first() {
if let Some(tool_calls) = &first_choice.message.tool_calls {
debug!("Found {} tool calls in Databricks response", tool_calls.len());
for (i, tool_call) in tool_calls.iter().enumerate() {
debug!("Tool call {}: {} with args: {}", i, tool_call.function.name, tool_call.function.arguments);
}
// For now, we'll return the content as-is since g3 handles tool calls via streaming
// In the future, we might need to convert these to the internal format
}
}
let usage = Usage {
prompt_tokens: databricks_response.usage.prompt_tokens,
completion_tokens: databricks_response.usage.completion_tokens,
total_tokens: databricks_response.usage.total_tokens,
};
debug!(
"Databricks completion successful: {} tokens generated",
usage.completion_tokens
);
Ok(CompletionResponse {
content,
usage,
model: self.model.clone(),
})
}
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
debug!(
"Processing Databricks streaming request with {} messages",
request.messages.len()
);
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);
let request_body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
true,
max_tokens,
temperature
)?;
debug!("Sending streaming request to Databricks API: model={}, max_tokens={}, temperature={}",
self.model, request_body.max_tokens, request_body.temperature);
// Debug: Log the full request body
debug!("Full request body: {}", serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "Failed to serialize".to_string()));
let mut provider_clone = self.clone();
let response = provider_clone
.create_request_builder(true)
.await?
.json(&request_body)
.send()
.await
.map_err(|e| anyhow!("Failed to send streaming request to Databricks API: {}", e))?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
}
let stream = response.bytes_stream();
let (tx, rx) = mpsc::channel(100);
// Spawn task to process the stream
let provider = self.clone();
tokio::spawn(async move {
provider.parse_streaming_response(stream, tx).await;
});
Ok(ReceiverStream::new(rx))
}
fn name(&self) -> &str {
"databricks"
}
fn model(&self) -> &str {
&self.model
}
fn has_native_tool_calling(&self) -> bool {
// Databricks Foundation Models support native tool calling
// This includes Claude, Llama, DBRX, and most other models on the platform
true
}
}
// Databricks API request/response structures
#[derive(Debug, Serialize)]
struct DatabricksRequest {
messages: Vec<DatabricksMessage>,
max_tokens: u32,
temperature: f32,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<DatabricksTool>>,
stream: bool,
}
#[derive(Debug, Serialize)]
struct DatabricksTool {
r#type: String,
function: DatabricksFunction,
}
#[derive(Debug, Serialize)]
struct DatabricksFunction {
name: String,
description: String,
parameters: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
struct DatabricksMessage {
role: String,
content: Option<String>, // Make content optional since tool calls might not have content
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<DatabricksToolCall>>, // Add tool_calls field for responses
}
#[derive(Debug, Serialize, Deserialize)]
struct DatabricksToolCall {
id: String,
r#type: String,
function: DatabricksToolCallFunction,
}
#[derive(Debug, Serialize, Deserialize)]
struct DatabricksToolCallFunction {
name: String,
arguments: String, // This will be a JSON string that needs parsing
}
#[derive(Debug, Deserialize)]
struct DatabricksResponse {
choices: Vec<DatabricksChoice>,
usage: DatabricksUsage,
}
#[derive(Debug, Deserialize)]
struct DatabricksChoice {
message: DatabricksMessage,
finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
struct DatabricksUsage {
prompt_tokens: u32,
completion_tokens: u32,
total_tokens: u32,
}
// Streaming response structures
#[derive(Debug, Deserialize)]
struct DatabricksStreamChunk {
choices: Option<Vec<DatabricksStreamChoice>>,
}
#[derive(Debug, Deserialize)]
struct DatabricksStreamChoice {
delta: Option<DatabricksStreamDelta>,
finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
struct DatabricksStreamDelta {
content: Option<String>,
tool_calls: Option<Vec<DatabricksStreamToolCall>>,
}
#[derive(Debug, Deserialize)]
struct DatabricksStreamToolCall {
index: Option<usize>,
id: Option<String>,
function: DatabricksStreamFunction,
}
#[derive(Debug, Deserialize)]
struct DatabricksStreamFunction {
#[serde(default)]
name: String,
arguments: String,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_message_conversion() {
let provider = DatabricksProvider::from_token(
"https://test.databricks.com".to_string(),
"test-token".to_string(),
"test-model".to_string(),
None,
None,
).unwrap();
let messages = vec![
Message {
role: MessageRole::System,
content: "You are a helpful assistant.".to_string(),
},
Message {
role: MessageRole::User,
content: "Hello!".to_string(),
},
Message {
role: MessageRole::Assistant,
content: "Hi there!".to_string(),
},
];
let databricks_messages = provider.convert_messages(&messages).unwrap();
assert_eq!(databricks_messages.len(), 3);
assert_eq!(databricks_messages[0].role, "system");
assert_eq!(databricks_messages[1].role, "user");
assert_eq!(databricks_messages[2].role, "assistant");
}
#[test]
fn test_request_body_creation() {
let provider = DatabricksProvider::from_token(
"https://test.databricks.com".to_string(),
"test-token".to_string(),
"databricks-claude-sonnet-4".to_string(),
Some(1000),
Some(0.5),
).unwrap();
let messages = vec![
Message {
role: MessageRole::User,
content: "Test message".to_string(),
},
];
let request_body = provider
.create_request_body(&messages, None, false, 1000, 0.5)
.unwrap();
assert_eq!(request_body.max_tokens, 1000);
assert_eq!(request_body.temperature, 0.5);
assert!(!request_body.stream);
assert_eq!(request_body.messages.len(), 1);
assert!(request_body.tools.is_none());
}
#[test]
fn test_tool_conversion() {
let provider = DatabricksProvider::from_token(
"https://test.databricks.com".to_string(),
"test-token".to_string(),
"test-model".to_string(),
None,
None,
).unwrap();
let tools = vec![
Tool {
name: "get_weather".to_string(),
description: "Get the current weather".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}),
},
];
let databricks_tools = provider.convert_tools(&tools);
assert_eq!(databricks_tools.len(), 1);
assert_eq!(databricks_tools[0].r#type, "function");
assert_eq!(databricks_tools[0].function.name, "get_weather");
assert_eq!(databricks_tools[0].function.description, "Get the current weather");
}
#[test]
fn test_has_native_tool_calling() {
let claude_provider = DatabricksProvider::from_token(
"https://test.databricks.com".to_string(),
"test-token".to_string(),
"databricks-claude-sonnet-4".to_string(),
None,
None,
).unwrap();
let llama_provider = DatabricksProvider::from_token(
"https://test.databricks.com".to_string(),
"test-token".to_string(),
"databricks-meta-llama-3-3-70b-instruct".to_string(),
None,
None,
).unwrap();
let dbrx_provider = DatabricksProvider::from_token(
"https://test.databricks.com".to_string(),
"test-token".to_string(),
"databricks-dbrx-instruct".to_string(),
None,
None,
).unwrap();
assert!(claude_provider.has_native_tool_calling());
assert!(llama_provider.has_native_tool_calling());
assert!(dbrx_provider.has_native_tool_calling());
}
}

View File

@@ -84,8 +84,11 @@ pub struct Tool {
}
pub mod anthropic;
pub mod databricks;
pub mod oauth;
pub use anthropic::AnthropicProvider;
pub use databricks::DatabricksProvider;
/// Provider registry for managing multiple LLM providers
pub struct ProviderRegistry {

View File

@@ -0,0 +1,457 @@
use anyhow::Result;
use axum::{extract::Query, response::Html, routing::get, Router};
use base64::Engine;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sha2::Digest;
use std::{collections::HashMap, fs, net::SocketAddr, path::PathBuf, sync::Arc};
use tokio::sync::{oneshot, Mutex as TokioMutex};
use url::Url;
#[derive(Debug, Clone)]
struct OidcEndpoints {
authorization_endpoint: String,
token_endpoint: String,
}
#[derive(Serialize, Deserialize)]
struct TokenData {
/// The access token used to authenticate API requests
access_token: String,
/// Optional refresh token that can be used to obtain a new access token
/// when the current one expires, enabling offline access without user interaction
refresh_token: Option<String>,
/// When the access token expires (if known)
/// Used to determine when a token needs to be refreshed
expires_at: Option<DateTime<Utc>>,
}
struct TokenCache {
cache_path: PathBuf,
}
fn get_base_path() -> PathBuf {
// Use a similar pattern to Goose but for g3
// macOS/Linux: ~/.config/g3/databricks/oauth
// Windows: ~\AppData\Roaming\g3\config\databricks\oauth\
let mut path = dirs::config_dir().unwrap_or_else(|| PathBuf::from("."));
path.push("g3");
path.push("databricks");
path.push("oauth");
path
}
impl TokenCache {
fn new(host: &str, client_id: &str, scopes: &[String]) -> Self {
let mut hasher = sha2::Sha256::new();
hasher.update(host.as_bytes());
hasher.update(client_id.as_bytes());
hasher.update(scopes.join(",").as_bytes());
let hash = format!("{:x}", hasher.finalize());
fs::create_dir_all(get_base_path()).unwrap_or_else(|_| {});
let cache_path = get_base_path().join(format!("{}.json", hash));
Self { cache_path }
}
fn load_token(&self) -> Option<TokenData> {
if let Ok(contents) = fs::read_to_string(&self.cache_path) {
if let Ok(token_data) = serde_json::from_str::<TokenData>(&contents) {
// Only return tokens that have a refresh token
if token_data.refresh_token.is_some() {
// If token is not expired, return it for immediate use
if let Some(expires_at) = token_data.expires_at {
if expires_at > Utc::now() {
return Some(token_data);
}
// If token is expired but has refresh token, return it so we can refresh
return Some(token_data);
}
// No expiration time but has refresh token, return it
return Some(token_data);
}
// Token doesn't have a refresh token, ignore it to force a new OAuth flow
}
}
None
}
fn save_token(&self, token_data: &TokenData) -> Result<()> {
if let Some(parent) = self.cache_path.parent() {
fs::create_dir_all(parent)?;
}
let contents = serde_json::to_string(token_data)?;
fs::write(&self.cache_path, contents)?;
Ok(())
}
}
async fn get_workspace_endpoints(host: &str) -> Result<OidcEndpoints> {
let base_url = Url::parse(host).expect("Invalid host URL");
let oidc_url = base_url
.join("oidc/.well-known/oauth-authorization-server")
.expect("Invalid OIDC URL");
let client = reqwest::Client::new();
let resp = client.get(oidc_url.clone()).send().await?;
if !resp.status().is_success() {
return Err(anyhow::anyhow!(
"Failed to get OIDC configuration from {}",
oidc_url.to_string()
));
}
let oidc_config: Value = resp.json().await?;
let authorization_endpoint = oidc_config
.get("authorization_endpoint")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("authorization_endpoint not found in OIDC configuration"))?
.to_string();
let token_endpoint = oidc_config
.get("token_endpoint")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("token_endpoint not found in OIDC configuration"))?
.to_string();
Ok(OidcEndpoints {
authorization_endpoint,
token_endpoint,
})
}
struct OAuthFlow {
endpoints: OidcEndpoints,
client_id: String,
redirect_url: String,
scopes: Vec<String>,
state: String,
verifier: String,
}
impl OAuthFlow {
fn new(
endpoints: OidcEndpoints,
client_id: String,
redirect_url: String,
scopes: Vec<String>,
) -> Self {
Self {
endpoints,
client_id,
redirect_url,
scopes,
state: nanoid::nanoid!(16),
verifier: nanoid::nanoid!(64),
}
}
/// Extracts token data from an OAuth 2.0 token response.
fn extract_token_data(
&self,
token_response: &Value,
old_refresh_token: Option<&str>,
) -> Result<TokenData> {
// Extract access token (required)
let access_token = token_response
.get("access_token")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("access_token not found in token response"))?
.to_string();
// Extract refresh token if available
let refresh_token = token_response
.get("refresh_token")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.or_else(|| old_refresh_token.map(|s| s.to_string()));
// Handle token expiration
let expires_at =
if let Some(expires_in) = token_response.get("expires_in").and_then(|v| v.as_u64()) {
// Traditional OAuth flow with expires_in seconds
Some(Utc::now() + chrono::Duration::seconds(expires_in as i64))
} else {
// If the server doesn't provide any expiration info, log it but don't set an expiration
tracing::debug!(
"No expiration information provided by server, token expiration unknown."
);
None
};
Ok(TokenData {
access_token,
refresh_token,
expires_at,
})
}
fn get_authorization_url(&self) -> String {
let challenge = {
let digest = sha2::Sha256::digest(self.verifier.as_bytes());
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest)
};
let params = [
("response_type", "code"),
("client_id", &self.client_id),
("redirect_uri", &self.redirect_url),
("scope", &self.scopes.join(" ")),
("state", &self.state),
("code_challenge", &challenge),
("code_challenge_method", "S256"),
];
format!(
"{}?{}",
self.endpoints.authorization_endpoint,
serde_urlencoded::to_string(params).unwrap()
)
}
async fn exchange_code_for_token(&self, code: &str) -> Result<TokenData> {
let params = [
("grant_type", "authorization_code"),
("code", code),
("redirect_uri", &self.redirect_url),
("code_verifier", &self.verifier),
("client_id", &self.client_id),
];
let client = reqwest::Client::new();
let resp = client
.post(&self.endpoints.token_endpoint)
.header("Content-Type", "application/x-www-form-urlencoded")
.form(&params)
.send()
.await?;
if !resp.status().is_success() {
let err_text = resp.text().await?;
return Err(anyhow::anyhow!(
"Failed to exchange code for token: {}",
err_text
));
}
let token_response: Value = resp.json().await?;
self.extract_token_data(&token_response, None)
}
async fn refresh_token(&self, refresh_token: &str) -> Result<TokenData> {
let params = [
("grant_type", "refresh_token"),
("refresh_token", refresh_token),
("client_id", &self.client_id),
];
tracing::debug!("Refreshing token using refresh_token");
let client = reqwest::Client::new();
let resp = client
.post(&self.endpoints.token_endpoint)
.header("Content-Type", "application/x-www-form-urlencoded")
.form(&params)
.send()
.await?;
if !resp.status().is_success() {
let err_text = resp.text().await?;
return Err(anyhow::anyhow!("Failed to refresh token: {}", err_text));
}
let token_response: Value = resp.json().await?;
self.extract_token_data(&token_response, Some(refresh_token))
}
async fn execute(&self) -> Result<TokenData> {
// Create a channel that will send the auth code from the app process
let (tx, rx) = oneshot::channel();
let state = self.state.clone();
let tx = Arc::new(TokioMutex::new(Some(tx)));
// Setup a server that will receive the redirect, capture the code, and display success/failure
let app = Router::new().route(
"/",
get(move |Query(params): Query<HashMap<String, String>>| {
let tx = Arc::clone(&tx);
let state = state.clone();
async move {
let code = params.get("code").cloned();
let received_state = params.get("state").cloned();
if let (Some(code), Some(received_state)) = (code, received_state) {
if received_state == state {
if let Some(sender) = tx.lock().await.take() {
if sender.send(code).is_ok() {
return Html(
"<h2>G3 Authentication Success</h2><p>You can close this window and return to your terminal.</p>",
);
}
}
Html("<h2>Error</h2><p>Authentication already completed.</p>")
} else {
Html("<h2>Error</h2><p>State mismatch.</p>")
}
} else {
Html("<h2>Error</h2><p>Authentication failed.</p>")
}
}
}),
);
// Start the server to accept the oauth code
let redirect_url = Url::parse(&self.redirect_url)?;
let port = redirect_url.port().unwrap_or(80);
let addr = SocketAddr::from(([127, 0, 0, 1], port));
let listener = tokio::net::TcpListener::bind(addr).await?;
let server_handle = tokio::spawn(async move {
let server = axum::serve(listener, app);
server.await.unwrap();
});
// Open the browser which will redirect with the code to the server
let authorization_url = self.get_authorization_url();
println!("🔐 Opening browser for Databricks authentication...");
if webbrowser::open(&authorization_url).is_err() {
println!(
"Please open this URL in your browser:\n{}",
authorization_url
);
}
// Wait for the authorization code with a timeout
let code = tokio::time::timeout(
std::time::Duration::from_secs(120), // 2 minute timeout
rx,
)
.await
.map_err(|_| anyhow::anyhow!("Authentication timed out after 2 minutes"))??;
// Stop the server
server_handle.abort();
println!("✅ Authentication successful! Exchanging code for token...");
// Exchange the code for a token
self.exchange_code_for_token(&code).await
}
}
pub async fn get_oauth_token_async(
host: &str,
client_id: &str,
redirect_url: &str,
scopes: &[String],
) -> Result<String> {
let token_cache = TokenCache::new(host, client_id, scopes);
// Try cache first
if let Some(token) = token_cache.load_token() {
// If token has an expiration time, check if it's expired
if let Some(expires_at) = token.expires_at {
if expires_at > Utc::now() {
tracing::debug!("Using cached token");
return Ok(token.access_token);
}
// Token is expired, will try to refresh below
tracing::debug!("Token is expired, attempting to refresh");
} else {
// No expiration time was provided by the server
tracing::debug!("Token has no expiration time, using cached token");
return Ok(token.access_token);
}
// Token is expired or has no expiration, try to refresh if we have a refresh token
if let Some(refresh_token) = token.refresh_token {
// Get endpoints for token refresh
match get_workspace_endpoints(host).await {
Ok(endpoints) => {
let flow = OAuthFlow::new(
endpoints,
client_id.to_string(),
redirect_url.to_string(),
scopes.to_vec(),
);
// Try to refresh the token
match flow.refresh_token(&refresh_token).await {
Ok(new_token) => {
if let Err(e) = token_cache.save_token(&new_token) {
tracing::warn!("Failed to save refreshed token: {}", e);
}
tracing::info!("Successfully refreshed token");
return Ok(new_token.access_token);
}
Err(e) => {
tracing::warn!(
"Failed to refresh token, will try new auth flow: {}",
e
);
// Continue to new auth flow
}
}
}
Err(e) => {
tracing::warn!("Failed to get endpoints for token refresh: {}", e);
// Continue to new auth flow
}
}
}
}
// Get endpoints and execute flow for a new token
let endpoints = get_workspace_endpoints(host).await?;
let flow = OAuthFlow::new(
endpoints,
client_id.to_string(),
redirect_url.to_string(),
scopes.to_vec(),
);
// Execute the OAuth flow and get token
let token = flow.execute().await?;
// Cache and return
token_cache.save_token(&token)?;
println!("🎉 Databricks authentication complete!");
Ok(token.access_token)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_token_cache() -> Result<()> {
let cache = TokenCache::new(
"https://example.com",
"test-client",
&["scope1".to_string()],
);
// Test with expiration time
let token_data = TokenData {
access_token: "test-token".to_string(),
refresh_token: Some("test-refresh-token".to_string()),
expires_at: Some(Utc::now() + chrono::Duration::hours(1)),
};
cache.save_token(&token_data)?;
let loaded_token = cache.load_token().unwrap();
assert_eq!(loaded_token.access_token, token_data.access_token);
assert_eq!(loaded_token.refresh_token, token_data.refresh_token);
assert!(loaded_token.expires_at.is_some());
Ok(())
}
}