Files
g3/crates/g3-providers/src/databricks.rs
Jochen 52f78653b4 add context window monitor
Writes the current context window to logs/current_context_window (uses a symlink to a session ID).

This PR was unfortunately generated by a different LLM and did a ton of superficial reformating, it's actually a fairly small and benign change, but I don't want to roll back everything. Hope that's ok.
2025-11-27 21:00:02 +11:00

1368 lines
55 KiB
Rust

//! Databricks LLM provider implementation for the g3-providers crate.
//!
//! This module provides an implementation of the `LLMProvider` trait for Databricks Foundation Model APIs,
//! supporting both completion and streaming modes with OAuth authentication.
//!
//! # Features
//!
//! - Support for Databricks Foundation Models (databricks-claude-sonnet-4, databricks-meta-llama-3-3-70b-instruct, etc.)
//! - Both completion and streaming response modes
//! - OAuth authentication with automatic token refresh
//! - Token-based authentication as fallback
//! - Native tool calling support for compatible models
//! - Automatic model discovery from Databricks workspace
//!
//! # Usage
//!
//! ```rust,no_run
//! use g3_providers::{DatabricksProvider, LLMProvider, CompletionRequest, Message, MessageRole};
//!
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//! // Create the provider with OAuth (recommended)
//! let provider = DatabricksProvider::from_oauth(
//! "https://your-workspace.cloud.databricks.com".to_string(),
//! "databricks-claude-sonnet-4".to_string(),
//! None, // Optional: max tokens
//! None, // Optional: temperature
//! ).await?;
//!
//! // Or create with token
//! let provider = DatabricksProvider::from_token(
//! "https://your-workspace.cloud.databricks.com".to_string(),
//! "your-databricks-token".to_string(),
//! "databricks-claude-sonnet-4".to_string(),
//! None,
//! None,
//! )?;
//!
//! // Create a completion request
//! let request = CompletionRequest {
//! messages: vec![
//! Message::new(MessageRole::User, "Hello! How are you?".to_string()),
//! ],
//! max_tokens: Some(1000),
//! temperature: Some(0.7),
//! stream: false,
//! tools: None,
//! };
//!
//! // Get a completion
//! let response = provider.complete(request).await?;
//! println!("Response: {}", response.content);
//!
//! Ok(())
//! }
//! ```
use anyhow::{anyhow, Result};
use bytes::Bytes;
use futures_util::stream::StreamExt;
use reqwest::{Client, RequestBuilder};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error, info, warn};
use crate::{
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
MessageRole, Tool, ToolCall, Usage,
};
const DEFAULT_CLIENT_ID: &str = "databricks-cli";
const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020";
const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"];
const DEFAULT_TIMEOUT_SECS: u64 = 600;
pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-sonnet-4";
pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[
"databricks-claude-3-7-sonnet",
"databricks-meta-llama-3-3-70b-instruct",
"databricks-meta-llama-3-1-405b-instruct",
"databricks-dbrx-instruct",
"databricks-mixtral-8x7b-instruct",
];
#[derive(Debug, Clone)]
pub enum DatabricksAuth {
Token(String),
OAuth {
host: String,
client_id: String,
redirect_url: String,
scopes: Vec<String>,
cached_token: Option<String>,
},
}
impl DatabricksAuth {
pub fn oauth(host: String) -> Self {
Self::OAuth {
host,
client_id: DEFAULT_CLIENT_ID.to_string(),
redirect_url: DEFAULT_REDIRECT_URL.to_string(),
scopes: DEFAULT_SCOPES.iter().map(|s| s.to_string()).collect(),
cached_token: None,
}
}
pub fn token(token: String) -> Self {
Self::Token(token)
}
async fn get_token(&mut self) -> Result<String> {
match self {
DatabricksAuth::Token(token) => Ok(token.clone()),
DatabricksAuth::OAuth {
host,
client_id,
redirect_url,
scopes,
cached_token,
} => {
// Use the OAuth implementation with automatic refresh
let token =
crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes)
.await?;
// Cache the token for potential reuse within the same session
*cached_token = Some(token.clone());
Ok(token)
}
}
}
/// Force a token refresh by clearing any cached token
/// This is useful when we get a 403 Invalid Token error
pub fn clear_cached_token(&mut self) {
if let DatabricksAuth::OAuth { cached_token, .. } = self {
*cached_token = None;
}
}
}
#[derive(Debug, Clone)]
pub struct DatabricksProvider {
client: Client,
host: String,
auth: DatabricksAuth,
model: String,
max_tokens: u32,
temperature: f32,
}
impl DatabricksProvider {
pub fn from_token(
host: String,
token: String,
model: String,
max_tokens: Option<u32>,
temperature: Option<f32>,
) -> Result<Self> {
let client = Client::builder()
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
.build()
.map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
info!(
"Initialized Databricks provider with model: {} on host: {}",
model, host
);
Ok(Self {
client,
host: host.trim_end_matches('/').to_string(),
auth: DatabricksAuth::token(token),
model,
max_tokens: max_tokens.unwrap_or(50000),
temperature: temperature.unwrap_or(0.1),
})
}
pub async fn from_oauth(
host: String,
model: String,
max_tokens: Option<u32>,
temperature: Option<f32>,
) -> Result<Self> {
let client = Client::builder()
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
.build()
.map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
info!(
"Initialized Databricks provider with OAuth for model: {} on host: {}",
model, host
);
Ok(Self {
client,
host: host.trim_end_matches('/').to_string(),
auth: DatabricksAuth::oauth(host.clone()),
model,
max_tokens: max_tokens.unwrap_or(50000),
temperature: temperature.unwrap_or(0.1),
})
}
async fn create_request_builder(&mut self, streaming: bool) -> Result<RequestBuilder> {
let token = self.auth.get_token().await?;
let mut builder = self
.client
.post(format!(
"{}/serving-endpoints/{}/invocations",
self.host, self.model
))
.header("Authorization", format!("Bearer {}", token))
.header("Content-Type", "application/json");
if streaming {
builder = builder.header("Accept", "text/event-stream");
}
Ok(builder)
}
fn convert_tools(&self, tools: &[Tool]) -> Vec<DatabricksTool> {
tools
.iter()
.map(|tool| DatabricksTool {
r#type: "function".to_string(),
function: DatabricksFunction {
name: tool.name.clone(),
description: tool.description.clone(),
parameters: tool.input_schema.clone(),
},
})
.collect()
}
fn convert_messages(&self, messages: &[Message]) -> Result<Vec<DatabricksMessage>> {
let mut databricks_messages = Vec::new();
for message in messages {
let role = match message.role {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
};
// Always use simple string format (Databricks doesn't support cache_control)
let content = serde_json::Value::String(message.content.clone());
databricks_messages.push(DatabricksMessage {
role: role.to_string(),
content: Some(content),
tool_calls: None, // Only used in responses, not requests
});
}
if databricks_messages.is_empty() {
return Err(anyhow!("At least one message is required"));
}
Ok(databricks_messages)
}
fn create_request_body(
&self,
messages: &[Message],
tools: Option<&[Tool]>,
streaming: bool,
max_tokens: u32,
temperature: f32,
) -> Result<DatabricksRequest> {
let databricks_messages = self.convert_messages(messages)?;
// Convert tools if provided
let databricks_tools = tools.map(|t| self.convert_tools(t));
let request = DatabricksRequest {
messages: databricks_messages,
max_tokens,
temperature,
tools: databricks_tools,
stream: streaming,
};
Ok(request)
}
async fn parse_streaming_response(
&self,
mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
tx: mpsc::Sender<Result<CompletionChunk>>,
) -> Option<Usage> {
let mut buffer = String::new();
let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> =
std::collections::HashMap::new(); // index -> (id, name, args)
let mut incomplete_data_line = String::new(); // Buffer for incomplete data: lines
let mut chunk_count = 0;
let accumulated_usage: Option<Usage> = None;
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
// Debug: Log raw bytes received
chunk_count += 1;
debug!("Processing chunk #{}", chunk_count);
debug!("Raw SSE bytes received: {} bytes", chunk.len());
// Append new bytes to our buffer
byte_buffer.extend_from_slice(&chunk);
// Try to convert the entire buffer to UTF-8
let chunk_str = match std::str::from_utf8(&byte_buffer) {
Ok(s) => {
// Successfully converted entire buffer, clear it and use the string
let result = s.to_string();
byte_buffer.clear();
result
}
Err(e) => {
// Check if this is an incomplete sequence at the end
let valid_up_to = e.valid_up_to();
if valid_up_to > 0 {
// We have some valid UTF-8, extract it and keep the rest for next iteration
let valid_bytes =
byte_buffer.drain(..valid_up_to).collect::<Vec<_>>();
std::str::from_utf8(&valid_bytes).unwrap().to_string()
} else {
// No valid UTF-8 at all, skip this chunk and continue
continue;
}
}
};
// Debug: Log raw string content (truncated for large chunks)
if chunk_str.len() > 1000 {
debug!(
"Raw SSE string content (first 500 chars): {:?}...",
&chunk_str[..500]
);
} else {
debug!("Raw SSE string content: {:?}", chunk_str);
}
buffer.push_str(&chunk_str);
// Process complete lines, but handle incomplete data: lines specially
while let Some(line_end) = buffer.find('\n') {
let line = buffer[..line_end].trim().to_string();
buffer.drain(..line_end + 1);
if line.is_empty() {
continue;
}
// Check if we have an incomplete data line from previous chunk
let line = if !incomplete_data_line.is_empty() {
// We had an incomplete data: line, append this line to it
let complete_line = format!("{}{}", incomplete_data_line, line);
incomplete_data_line.clear();
complete_line
} else {
line
};
// Check if this is a data: line that might be incomplete
// SSE format requires double newline after data, so if we don't see another newline
// after this one in the buffer, and it's a data: line, it might be incomplete
if line.starts_with("data: ") {
// Check if there's a complete SSE event (should have double newline after data)
// But for streaming, single newline is often used, so we need to be careful
// The safest approach is to try parsing and if it fails due to incomplete JSON,
// we'll handle it below
}
// Debug: Log each SSE line (truncated for large lines)
if line.len() > 1000 {
debug!("SSE line (first 500 chars): {:?}...", &line[..500]);
} else {
debug!("SSE line: {:?}", line);
}
// Parse Server-Sent Events format
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
debug!("Received stream completion marker");
let final_tool_calls: Vec<ToolCall> = current_tool_calls
.values()
.map(|(id, name, args)| ToolCall {
id: id.clone(),
tool: name.clone(),
args: serde_json::from_str(args).unwrap_or(
serde_json::Value::Object(serde_json::Map::new()),
),
})
.collect();
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
tool_calls: if final_tool_calls.is_empty() {
None
} else {
Some(final_tool_calls)
},
};
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
}
return accumulated_usage;
}
// Debug: Log every raw JSON payload from Databricks API (truncated for large payloads)
if data.len() > 1000 {
debug!(
"Raw Databricks SSE JSON payload (first 500 chars): {}...",
&data[..500]
);
} else {
debug!("Raw Databricks SSE JSON payload: {}", data);
}
match serde_json::from_str::<DatabricksStreamChunk>(data) {
Ok(chunk) => {
debug!("Successfully parsed Databricks stream chunk");
// Handle different types of chunks
if let Some(choices) = chunk.choices {
for choice in choices {
if let Some(delta) = choice.delta {
// Handle text content
if let Some(content) = delta.content {
debug!("Sending text chunk: '{}'", content);
let chunk = CompletionChunk {
content,
finished: false,
usage: None,
tool_calls: None,
};
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return accumulated_usage;
}
}
// Handle tool calls - accumulate across chunks
if let Some(tool_calls) = delta.tool_calls {
debug!(
"Processing {} tool call deltas",
tool_calls.len()
);
for tool_call in tool_calls {
let index = tool_call.index.unwrap_or(0);
debug!("Tool call delta for index {}: id={:?}, name='{}', args_len={}",
index, tool_call.id, tool_call.function.name, tool_call.function.arguments.len());
let entry = current_tool_calls
.entry(index)
.or_insert_with(|| {
(
String::new(),
String::new(),
String::new(),
)
});
// Update ID if provided
if let Some(id) = tool_call.id {
debug!("Updating tool call {} ID from '{}' to '{}'", index, entry.0, id);
entry.0 = id;
}
// Update name if provided and not empty
if !tool_call.function.name.is_empty() {
debug!("Updating tool call {} name from '{}' to '{}'", index, entry.1, tool_call.function.name);
entry.1 = tool_call.function.name;
}
// Append arguments
debug!("Appending {} chars to tool call {} args (current len: {})",
tool_call.function.arguments.len(), index, entry.2.len());
entry.2.push_str(
&tool_call.function.arguments,
);
debug!("Accumulated tool call {}: id='{}', name='{}', args_len={}",
index, entry.0, entry.1, entry.2.len());
// Debug: Show a sample of the accumulated args if they're getting long
if entry.2.len() > 100 {
debug!("Tool call {} args sample (first 100 chars): {}", index, &entry.2[..100]);
} else if !entry.2.is_empty() {
debug!(
"Tool call {} full args: {}",
index, entry.2
);
}
}
}
}
// Check if this choice is finished
if choice.finish_reason.is_some() {
debug!(
"Choice finished with reason: {:?}",
choice.finish_reason
);
// Convert accumulated tool calls to final format
let final_tool_calls: Vec<ToolCall> = current_tool_calls.values()
.filter(|(_, name, _)| !name.is_empty()) // Only include tool calls with names
.map(|(id, name, args)| {
debug!("Converting tool call: id='{}', name='{}', args_len={}", id, name, args.len());
ToolCall {
id: if id.is_empty() { format!("tool_{}", name) } else { id.clone() },
tool: name.clone(),
args: serde_json::from_str(args).unwrap_or_else(|e| {
debug!("Failed to parse tool args (len={}): {}", args.len(), e);
// For debugging, log a sample of the args if they're very long
if args.len() > 1000 {
debug!("Tool args sample (first 500 chars): {}", &args[..500]);
} else {
debug!("Full tool args: {}", args);
}
serde_json::Value::Object(serde_json::Map::new())
}),
}
})
.collect();
debug!(
"Final tool calls count: {}",
final_tool_calls.len()
);
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
tool_calls: if final_tool_calls.is_empty() {
None
} else {
Some(final_tool_calls)
},
};
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
}
return accumulated_usage;
}
}
}
}
Err(e) => {
// Check if this is likely an incomplete JSON due to line splitting
// Common indicators: unexpected EOF, unterminated string, etc.
let error_str = e.to_string().to_lowercase();
if line.starts_with("data: ")
&& (error_str.contains("eof") ||
error_str.contains("unterminated") ||
error_str.contains("unexpected end") ||
error_str.contains("trailing") ||
// Also check if the data doesn't end with a proper JSON terminator
(!data.trim_end().ends_with('}') && !data.trim_end().ends_with(']')))
{
// This looks like an incomplete data line, save it for the next chunk
debug!("Detected incomplete data line (len={}), buffering for next chunk", line.len());
incomplete_data_line = line.clone();
// Continue to next iteration without processing
continue;
} else {
// This is a real parse error, not due to line splitting
debug!("Failed to parse Databricks stream chunk JSON: {} - Data length: {}", e, data.len());
// For debugging large payloads, log a sample
if data.len() > 1000 {
debug!(
"JSON parse error - data sample: {}",
&data[..std::cmp::min(500, data.len())]
);
}
}
// Don't error out on parse failures, just continue
}
}
} else if line.starts_with("event: ") || line.starts_with("id: ") {
// Debug: Log non-data SSE lines (like event: or id:)
debug!("Non-data SSE line: {}", line);
}
}
}
Err(e) => {
error!("Stream error at chunk {}: {}", chunk_count, e);
// Check if this is a connection error that might be recoverable
let error_msg = e.to_string();
if error_msg.contains("unexpected EOF") || error_msg.contains("connection") {
warn!("Connection terminated unexpectedly at chunk {}, treating as end of stream", chunk_count);
// Don't send error, just break and finalize
break;
} else {
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
}
return accumulated_usage;
}
}
}
// Log final state
debug!("Stream ended after {} chunks", chunk_count);
debug!(
"Final state: buffer_len={}, incomplete_data_line_len={}, byte_buffer_len={}",
buffer.len(),
incomplete_data_line.len(),
byte_buffer.len()
);
debug!("Accumulated tool calls: {}", current_tool_calls.len());
// If we have any remaining data in buffers, log it for debugging
if !buffer.is_empty() {
debug!("Remaining buffer content: {:?}", buffer);
}
if !byte_buffer.is_empty() {
debug!("Remaining byte buffer: {} bytes", byte_buffer.len());
}
if !incomplete_data_line.is_empty() {
debug!("Remaining incomplete data line: {:?}", incomplete_data_line);
}
// If we have any incomplete data line at the end, try to process it
if !incomplete_data_line.is_empty() {
debug!(
"Processing final incomplete data line (len={})",
incomplete_data_line.len()
);
if let Some(data) = incomplete_data_line.strip_prefix("data: ") {
// Try to parse it as-is, it might be complete
if let Ok(_chunk) = serde_json::from_str::<DatabricksStreamChunk>(data) {
// Process the chunk (code would be duplicated from above, so in practice
// we'd extract this to a helper function)
debug!("Successfully parsed final incomplete data line");
} else {
warn!("Failed to parse final incomplete data line");
}
}
}
// Send final chunk if we haven't already
let final_tool_calls: Vec<ToolCall> = current_tool_calls
.values()
.filter(|(_, name, _)| !name.is_empty())
.map(|(id, name, args)| ToolCall {
id: if id.is_empty() {
format!("tool_{}", name)
} else {
id.clone()
},
tool: name.clone(),
args: serde_json::from_str(args)
.unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
})
.collect();
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
tool_calls: if final_tool_calls.is_empty() {
None
} else {
Some(final_tool_calls)
},
};
let _ = tx.send(Ok(final_chunk)).await;
accumulated_usage
}
pub async fn fetch_supported_models(&mut self) -> Result<Option<Vec<String>>> {
let token = self.auth.get_token().await?;
let response = match self
.client
.get(format!("{}/api/2.0/serving-endpoints", self.host))
.header("Authorization", format!("Bearer {}", token))
.send()
.await
{
Ok(resp) => resp,
Err(e) => {
warn!("Failed to fetch Databricks models: {}", e);
return Ok(None);
}
};
if !response.status().is_success() {
let status = response.status();
if let Ok(error_text) = response.text().await {
warn!(
"Failed to fetch Databricks models: {} - {}",
status, error_text
);
} else {
warn!("Failed to fetch Databricks models: {}", status);
}
return Ok(None);
}
let json: serde_json::Value = match response.json().await {
Ok(json) => json,
Err(e) => {
warn!("Failed to parse Databricks API response: {}", e);
return Ok(None);
}
};
let endpoints = match json.get("endpoints").and_then(|v| v.as_array()) {
Some(endpoints) => endpoints,
None => {
warn!("Unexpected response format from Databricks API: missing 'endpoints' array");
return Ok(None);
}
};
let models: Vec<String> = endpoints
.iter()
.filter_map(|endpoint| {
endpoint
.get("name")
.and_then(|v| v.as_str())
.map(|name| name.to_string())
})
.collect();
if models.is_empty() {
debug!("No serving endpoints found in Databricks workspace");
Ok(None)
} else {
debug!(
"Found {} serving endpoints in Databricks workspace",
models.len()
);
Ok(Some(models))
}
}
}
#[async_trait::async_trait]
impl LLMProvider for DatabricksProvider {
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
debug!(
"Processing Databricks completion request with {} messages",
request.messages.len()
);
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);
let request_body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
false,
max_tokens,
temperature,
)?;
debug!(
"Sending request to Databricks API: model={}, max_tokens={}, temperature={}",
self.model, request_body.max_tokens, request_body.temperature
);
// Debug: Log the full request body when tools are present
if request.tools.is_some() {
debug!(
"Full request body with tools: {}",
serde_json::to_string_pretty(&request_body)
.unwrap_or_else(|_| "Failed to serialize".to_string())
);
}
let mut provider_clone = self.clone();
let mut response = provider_clone
.create_request_builder(false)
.await?
.json(&request_body)
.send()
.await
.map_err(|e| anyhow!("Failed to send request to Databricks API: {}", e))?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
// Check if this is a 403 Invalid Token error that we can retry with token refresh
if status == reqwest::StatusCode::FORBIDDEN
&& (error_text.contains("Invalid Token") || error_text.contains("invalid_token"))
{
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
// Try to refresh the token if we're using OAuth
if let DatabricksAuth::OAuth { .. } = &provider_clone.auth {
// Clear any cached token to force a refresh
provider_clone.auth.clear_cached_token();
// Try to get a new token (will attempt refresh or new OAuth flow)
match provider_clone.auth.get_token().await {
Ok(_new_token) => {
info!("Successfully refreshed OAuth token, retrying request");
// Retry the request with the new token
response = provider_clone
.create_request_builder(false)
.await?
.json(&request_body)
.send()
.await
.map_err(|e| anyhow!("Failed to send request to Databricks API after token refresh: {}", e))?;
let retry_status = response.status();
if !retry_status.is_success() {
let retry_error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!(
"Databricks API error {} after token refresh: {}",
retry_status,
retry_error_text
));
}
}
Err(e) => {
return Err(anyhow!(
"Failed to refresh OAuth token: {}. Original error: {}",
e,
error_text
));
}
}
} else {
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
}
} else {
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
}
}
let response_text = response.text().await?;
debug!("Raw Databricks API response: {}", response_text);
let databricks_response: DatabricksResponse = serde_json::from_str(&response_text)
.map_err(|e| {
anyhow!(
"Failed to parse Databricks response: {} - Response: {}",
e,
response_text
)
})?;
// Debug: Log the parsed response structure
debug!("Parsed Databricks response: {:#?}", databricks_response);
// Extract content from the first choice
let content = databricks_response
.choices
.first()
.and_then(|choice| {
choice.message.content.as_ref().map(|c| {
// Handle both string and array formats
if let Some(s) = c.as_str() {
s.to_string()
} else if let Some(arr) = c.as_array() {
// Extract text from content blocks
arr.iter()
.filter_map(|block| block.get("text").and_then(|t| t.as_str()))
.collect::<Vec<_>>()
.join("")
} else {
String::new()
}
})
})
.unwrap_or_default();
// Check if there are tool calls in the response
if let Some(first_choice) = databricks_response.choices.first() {
if let Some(tool_calls) = &first_choice.message.tool_calls {
debug!(
"Found {} tool calls in Databricks response",
tool_calls.len()
);
for (i, tool_call) in tool_calls.iter().enumerate() {
debug!(
"Tool call {}: {} with args: {}",
i, tool_call.function.name, tool_call.function.arguments
);
}
// For now, we'll return the content as-is since g3 handles tool calls via streaming
// In the future, we might need to convert these to the internal format
}
}
let usage = Usage {
prompt_tokens: databricks_response.usage.prompt_tokens,
completion_tokens: databricks_response.usage.completion_tokens,
total_tokens: databricks_response.usage.total_tokens,
};
debug!(
"Databricks completion successful: {} tokens generated",
usage.completion_tokens
);
Ok(CompletionResponse {
content,
usage,
model: self.model.clone(),
})
}
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
debug!(
"Processing Databricks streaming request with {} messages",
request.messages.len()
);
// Debug: Log tool count
if let Some(ref tools) = request.tools {
debug!("Request has {} tools", tools.len());
for tool in tools.iter().take(5) {
debug!(" Tool: {}", tool.name);
}
}
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);
let request_body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
true,
max_tokens,
temperature,
)?;
debug!(
"Sending streaming request to Databricks API: model={}, max_tokens={}, temperature={}",
self.model, request_body.max_tokens, request_body.temperature
);
// Debug: Log the full request body
debug!(
"Full request body: {}",
serde_json::to_string_pretty(&request_body)
.unwrap_or_else(|_| "Failed to serialize".to_string())
);
let mut provider_clone = self.clone();
let mut response = provider_clone
.create_request_builder(true)
.await?
.json(&request_body)
.send()
.await
.map_err(|e| anyhow!("Failed to send streaming request to Databricks API: {}", e))?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
// Check if this is a 403 Invalid Token error that we can retry with token refresh
if status == reqwest::StatusCode::FORBIDDEN
&& (error_text.contains("Invalid Token") || error_text.contains("invalid_token"))
{
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
// Try to refresh the token if we're using OAuth
if let DatabricksAuth::OAuth { .. } = &provider_clone.auth {
// Clear any cached token to force a refresh
provider_clone.auth.clear_cached_token();
// Try to get a new token (will attempt refresh or new OAuth flow)
match provider_clone.auth.get_token().await {
Ok(_new_token) => {
info!("Successfully refreshed OAuth token, retrying streaming request");
// Retry the request with the new token
response = provider_clone
.create_request_builder(true)
.await?
.json(&request_body)
.send()
.await
.map_err(|e| anyhow!("Failed to send streaming request to Databricks API after token refresh: {}", e))?;
let retry_status = response.status();
if !retry_status.is_success() {
let retry_error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!(
"Databricks API error {} after token refresh: {}",
retry_status,
retry_error_text
));
}
}
Err(e) => {
return Err(anyhow!(
"Failed to refresh OAuth token: {}. Original error: {}",
e,
error_text
));
}
}
} else {
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
}
} else {
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
}
}
let stream = response.bytes_stream();
let (tx, rx) = mpsc::channel(100);
// Spawn task to process the stream
let provider = self.clone();
tokio::spawn(async move {
provider.parse_streaming_response(stream, tx).await;
});
Ok(ReceiverStream::new(rx))
}
fn name(&self) -> &str {
"databricks"
}
fn model(&self) -> &str {
&self.model
}
fn has_native_tool_calling(&self) -> bool {
// Databricks Foundation Models support native tool calling
// This includes Claude, Llama, DBRX, and most other models on the platform
true
}
fn supports_cache_control(&self) -> bool {
false
}
fn max_tokens(&self) -> u32 {
self.max_tokens
}
fn temperature(&self) -> f32 {
self.temperature
}
}
// Databricks API request/response structures
#[derive(Debug, Serialize)]
struct DatabricksRequest {
messages: Vec<DatabricksMessage>,
max_tokens: u32,
temperature: f32,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<DatabricksTool>>,
stream: bool,
}
#[derive(Debug, Serialize)]
struct DatabricksTool {
r#type: String,
function: DatabricksFunction,
}
#[derive(Debug, Serialize)]
struct DatabricksFunction {
name: String,
description: String,
parameters: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
struct DatabricksMessage {
role: String,
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<serde_json::Value>, // Can be string or array of content blocks
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<DatabricksToolCall>>, // Add tool_calls field for responses
}
#[derive(Debug, Serialize, Deserialize)]
struct DatabricksToolCall {
id: String,
r#type: String,
function: DatabricksToolCallFunction,
}
#[derive(Debug, Serialize, Deserialize)]
struct DatabricksToolCallFunction {
name: String,
arguments: String, // This will be a JSON string that needs parsing
}
#[derive(Debug, Deserialize)]
struct DatabricksResponse {
choices: Vec<DatabricksChoice>,
usage: DatabricksUsage,
}
#[derive(Debug, Deserialize)]
struct DatabricksChoice {
message: DatabricksMessage,
#[allow(dead_code)]
finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
struct DatabricksUsage {
prompt_tokens: u32,
completion_tokens: u32,
total_tokens: u32,
}
// Streaming response structures
#[derive(Debug, Deserialize)]
struct DatabricksStreamChunk {
choices: Option<Vec<DatabricksStreamChoice>>,
}
#[derive(Debug, Deserialize)]
struct DatabricksStreamChoice {
delta: Option<DatabricksStreamDelta>,
finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
struct DatabricksStreamDelta {
content: Option<String>,
tool_calls: Option<Vec<DatabricksStreamToolCall>>,
}
#[derive(Debug, Deserialize)]
struct DatabricksStreamToolCall {
index: Option<usize>,
id: Option<String>,
function: DatabricksStreamFunction,
}
#[derive(Debug, Deserialize)]
struct DatabricksStreamFunction {
#[serde(default)]
name: String,
arguments: String,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_message_conversion() {
let provider = DatabricksProvider::from_token(
"https://test.databricks.com".to_string(),
"test-token".to_string(),
"test-model".to_string(),
None,
None,
)
.unwrap();
let messages = vec![
Message::new(
MessageRole::System,
"You are a helpful assistant.".to_string(),
),
Message::new(MessageRole::User, "Hello!".to_string()),
Message::new(MessageRole::Assistant, "Hi there!".to_string()),
];
let databricks_messages = provider.convert_messages(&messages).unwrap();
assert_eq!(databricks_messages.len(), 3);
assert_eq!(databricks_messages[0].role, "system");
assert_eq!(databricks_messages[1].role, "user");
assert_eq!(databricks_messages[2].role, "assistant");
}
#[test]
fn test_request_body_creation() {
let provider = DatabricksProvider::from_token(
"https://test.databricks.com".to_string(),
"test-token".to_string(),
"databricks-claude-sonnet-4".to_string(),
Some(1000),
Some(0.5),
)
.unwrap();
let messages = vec![Message::new(MessageRole::User, "Test message".to_string())];
let request_body = provider
.create_request_body(&messages, None, false, 1000, 0.5)
.unwrap();
assert_eq!(request_body.max_tokens, 1000);
assert_eq!(request_body.temperature, 0.5);
assert!(!request_body.stream);
assert_eq!(request_body.messages.len(), 1);
assert!(request_body.tools.is_none());
}
#[test]
fn test_tool_conversion() {
let provider = DatabricksProvider::from_token(
"https://test.databricks.com".to_string(),
"test-token".to_string(),
"test-model".to_string(),
None,
None,
)
.unwrap();
let tools = vec![Tool {
name: "get_weather".to_string(),
description: "Get the current weather".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}),
}];
let databricks_tools = provider.convert_tools(&tools);
assert_eq!(databricks_tools.len(), 1);
assert_eq!(databricks_tools[0].r#type, "function");
assert_eq!(databricks_tools[0].function.name, "get_weather");
assert_eq!(
databricks_tools[0].function.description,
"Get the current weather"
);
}
#[test]
fn test_has_native_tool_calling() {
let claude_provider = DatabricksProvider::from_token(
"https://test.databricks.com".to_string(),
"test-token".to_string(),
"databricks-claude-sonnet-4".to_string(),
None,
None,
)
.unwrap();
let llama_provider = DatabricksProvider::from_token(
"https://test.databricks.com".to_string(),
"test-token".to_string(),
"databricks-meta-llama-3-3-70b-instruct".to_string(),
None,
None,
)
.unwrap();
let dbrx_provider = DatabricksProvider::from_token(
"https://test.databricks.com".to_string(),
"test-token".to_string(),
"databricks-dbrx-instruct".to_string(),
None,
None,
)
.unwrap();
assert!(claude_provider.has_native_tool_calling());
assert!(llama_provider.has_native_tool_calling());
assert!(dbrx_provider.has_native_tool_calling());
}
#[test]
fn test_cache_control_serialization() {
let provider = DatabricksProvider::from_token(
"https://test.databricks.com".to_string(),
"test-token".to_string(),
"databricks-claude-sonnet-4".to_string(),
None,
None,
)
.unwrap();
// Test message WITHOUT cache_control
let messages_without = vec![Message::new(MessageRole::User, "Hello".to_string())];
let databricks_messages_without = provider.convert_messages(&messages_without).unwrap();
let json_without = serde_json::to_string(&databricks_messages_without).unwrap();
println!("JSON without cache_control: {}", json_without);
assert!(
!json_without.contains("cache_control"),
"JSON should not contain 'cache_control' field when not configured"
);
// Test message WITH cache_control - should still NOT include it (Databricks doesn't support it)
let messages_with = vec![Message::with_cache_control(
MessageRole::User,
"Hello".to_string(),
crate::CacheControl::ephemeral(),
)];
let databricks_messages_with = provider.convert_messages(&messages_with).unwrap();
let json_with = serde_json::to_string(&databricks_messages_with).unwrap();
println!("JSON with cache_control: {}", json_with);
assert!(
!json_with.contains("cache_control"),
"JSON should NOT contain 'cache_control' field - Databricks doesn't support it"
);
}
#[test]
fn test_databricks_does_not_support_cache_control() {
let claude_provider = DatabricksProvider::from_token(
"https://test.databricks.com".to_string(),
"test-token".to_string(),
"databricks-claude-sonnet-4".to_string(),
None,
None,
)
.unwrap();
let llama_provider = DatabricksProvider::from_token(
"https://test.databricks.com".to_string(),
"test-token".to_string(),
"databricks-meta-llama-3-3-70b-instruct".to_string(),
None,
None,
)
.unwrap();
assert!(
!claude_provider.supports_cache_control(),
"Databricks should not support cache_control even for Claude models"
);
assert!(
!llama_provider.supports_cache_control(),
"Databricks should not support cache_control for Llama models"
);
}
}