max tokens fix for databricks

This commit is contained in:
Dhanji Prasanna
2025-09-29 06:45:53 +10:00
parent f3cf9b688e
commit 4e64555008
5 changed files with 643 additions and 539 deletions

1
Cargo.lock generated
View File

@@ -850,6 +850,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-trait", "async-trait",
"chrono",
"futures-util", "futures-util",
"g3-config", "g3-config",
"g3-execution", "g3-execution",

View File

@@ -170,7 +170,8 @@ impl Config {
let config = settings.build()?.try_deserialize()?; let config = settings.build()?.try_deserialize()?;
Ok(config) Ok(config)
} }
#[allow(dead_code)]
fn default_qwen_config() -> Self { fn default_qwen_config() -> Self {
Self { Self {
providers: ProvidersConfig { providers: ProvidersConfig {

View File

@@ -22,3 +22,4 @@ llama_cpp = { version = "0.3.2", features = ["metal"] }
shellexpand = "3.1" shellexpand = "3.1"
tokio-util = "0.7" tokio-util = "0.7"
futures-util = "0.3" futures-util = "0.3"
chrono = { version = "0.4", features = ["serde"] }

File diff suppressed because it is too large Load Diff

View File

@@ -79,10 +79,9 @@ const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"];
const DEFAULT_TIMEOUT_SECS: u64 = 600; const DEFAULT_TIMEOUT_SECS: u64 = 600;
pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-sonnet-4"; pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-sonnet-4";
const DATABRICKS_DEFAULT_FAST_MODEL: &str = "gemini-1-5-flash";
pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[ pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[
"databricks-claude-3-7-sonnet", "databricks-claude-3-7-sonnet",
"databricks-meta-llama-3-3-70b-instruct", "databricks-meta-llama-3-3-70b-instruct",
"databricks-meta-llama-3-1-405b-instruct", "databricks-meta-llama-3-1-405b-instruct",
"databricks-dbrx-instruct", "databricks-dbrx-instruct",
"databricks-mixtral-8x7b-instruct", "databricks-mixtral-8x7b-instruct",
@@ -155,14 +154,17 @@ impl DatabricksProvider {
.build() .build()
.map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?; .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
info!("Initialized Databricks provider with model: {} on host: {}", model, host); info!(
"Initialized Databricks provider with model: {} on host: {}",
model, host
);
Ok(Self { Ok(Self {
client, client,
host: host.trim_end_matches('/').to_string(), host: host.trim_end_matches('/').to_string(),
auth: DatabricksAuth::token(token), auth: DatabricksAuth::token(token),
model, model,
max_tokens: max_tokens.unwrap_or(4096), max_tokens: max_tokens.unwrap_or(50000),
temperature: temperature.unwrap_or(0.1), temperature: temperature.unwrap_or(0.1),
}) })
} }
@@ -178,24 +180,30 @@ impl DatabricksProvider {
.build() .build()
.map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?; .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
info!("Initialized Databricks provider with OAuth for model: {} on host: {}", model, host); info!(
"Initialized Databricks provider with OAuth for model: {} on host: {}",
model, host
);
Ok(Self { Ok(Self {
client, client,
host: host.trim_end_matches('/').to_string(), host: host.trim_end_matches('/').to_string(),
auth: DatabricksAuth::oauth(host.clone()), auth: DatabricksAuth::oauth(host.clone()),
model, model,
max_tokens: max_tokens.unwrap_or(4096), max_tokens: max_tokens.unwrap_or(50000),
temperature: temperature.unwrap_or(0.1), temperature: temperature.unwrap_or(0.1),
}) })
} }
async fn create_request_builder(&mut self, streaming: bool) -> Result<RequestBuilder> { async fn create_request_builder(&mut self, streaming: bool) -> Result<RequestBuilder> {
let token = self.auth.get_token().await?; let token = self.auth.get_token().await?;
let mut builder = self let mut builder = self
.client .client
.post(&format!("{}/serving-endpoints/{}/invocations", self.host, self.model)) .post(&format!(
"{}/serving-endpoints/{}/invocations",
self.host, self.model
))
.header("Authorization", format!("Bearer {}", token)) .header("Authorization", format!("Bearer {}", token))
.header("Content-Type", "application/json"); .header("Content-Type", "application/json");
@@ -226,7 +234,7 @@ impl DatabricksProvider {
for message in messages { for message in messages {
let role = match message.role { let role = match message.role {
MessageRole::System => "system", MessageRole::System => "system",
MessageRole::User => "user", MessageRole::User => "user",
MessageRole::Assistant => "assistant", MessageRole::Assistant => "assistant",
}; };
@@ -274,13 +282,26 @@ impl DatabricksProvider {
tx: mpsc::Sender<Result<CompletionChunk>>, tx: mpsc::Sender<Result<CompletionChunk>>,
) { ) {
let mut buffer = String::new(); let mut buffer = String::new();
let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> = std::collections::HashMap::new(); // index -> (id, name, args) let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> =
std::collections::HashMap::new(); // index -> (id, name, args)
let mut incomplete_data_line = String::new(); // Buffer for incomplete data: lines
while let Some(chunk_result) = stream.next().await { while let Some(chunk_result) = stream.next().await {
match chunk_result { match chunk_result {
Ok(chunk) => { Ok(chunk) => {
// Debug: Log raw bytes received
debug!("Raw SSE bytes received: {} bytes", chunk.len());
let chunk_str = match std::str::from_utf8(&chunk) { let chunk_str = match std::str::from_utf8(&chunk) {
Ok(s) => s, Ok(s) => {
// Debug: Log raw string content (truncated for large chunks)
if s.len() > 1000 {
debug!("Raw SSE string content (first 500 chars): {:?}...", &s[..500]);
} else {
debug!("Raw SSE string content: {:?}", s);
}
s
},
Err(e) => { Err(e) => {
error!("Invalid UTF-8 in stream chunk: {}", e); error!("Invalid UTF-8 in stream chunk: {}", e);
let _ = tx let _ = tx
@@ -292,7 +313,7 @@ impl DatabricksProvider {
buffer.push_str(chunk_str); buffer.push_str(chunk_str);
// Process complete lines // Process complete lines, but handle incomplete data: lines specially
while let Some(line_end) = buffer.find('\n') { while let Some(line_end) = buffer.find('\n') {
let line = buffer[..line_end].trim().to_string(); let line = buffer[..line_end].trim().to_string();
buffer.drain(..line_end + 1); buffer.drain(..line_end + 1);
@@ -301,21 +322,55 @@ impl DatabricksProvider {
continue; continue;
} }
// Check if we have an incomplete data line from previous chunk
let line = if !incomplete_data_line.is_empty() {
// We had an incomplete data: line, append this line to it
let complete_line = format!("{}{}", incomplete_data_line, line);
incomplete_data_line.clear();
complete_line
} else {
line
};
// Check if this is a data: line that might be incomplete
// SSE format requires double newline after data, so if we don't see another newline
// after this one in the buffer, and it's a data: line, it might be incomplete
if line.starts_with("data: ") {
// Check if there's a complete SSE event (should have double newline after data)
// But for streaming, single newline is often used, so we need to be careful
// The safest approach is to try parsing and if it fails due to incomplete JSON,
// we'll handle it below
}
// Debug: Log each SSE line (truncated for large lines)
if line.len() > 1000 {
debug!("SSE line (first 500 chars): {:?}...", &line[..500]);
} else {
debug!("SSE line: {:?}", line);
}
// Parse Server-Sent Events format // Parse Server-Sent Events format
if let Some(data) = line.strip_prefix("data: ") { if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" { if data == "[DONE]" {
debug!("Received stream completion marker"); debug!("Received stream completion marker");
let final_tool_calls: Vec<ToolCall> = current_tool_calls.values() let final_tool_calls: Vec<ToolCall> = current_tool_calls
.values()
.map(|(id, name, args)| ToolCall { .map(|(id, name, args)| ToolCall {
id: id.clone(), id: id.clone(),
tool: name.clone(), tool: name.clone(),
args: serde_json::from_str(args).unwrap_or(serde_json::Value::Object(serde_json::Map::new())), args: serde_json::from_str(args).unwrap_or(
serde_json::Value::Object(serde_json::Map::new()),
),
}) })
.collect(); .collect();
let final_chunk = CompletionChunk { let final_chunk = CompletionChunk {
content: String::new(), content: String::new(),
finished: true, finished: true,
tool_calls: if final_tool_calls.is_empty() { None } else { Some(final_tool_calls) }, tool_calls: if final_tool_calls.is_empty() {
None
} else {
Some(final_tool_calls)
},
}; };
if tx.send(Ok(final_chunk)).await.is_err() { if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream"); debug!("Receiver dropped, stopping stream");
@@ -323,12 +378,17 @@ impl DatabricksProvider {
return; return;
} }
debug!("Raw Databricks API JSON: {}", data); // Debug: Log every raw JSON payload from Databricks API (truncated for large payloads)
if data.len() > 1000 {
debug!("Raw Databricks SSE JSON payload (first 500 chars): {}...", &data[..500]);
} else {
debug!("Raw Databricks SSE JSON payload: {}", data);
}
match serde_json::from_str::<DatabricksStreamChunk>(data) { match serde_json::from_str::<DatabricksStreamChunk>(data) {
Ok(chunk) => { Ok(chunk) => {
debug!("Parsed stream chunk: {:?}", chunk); debug!("Successfully parsed Databricks stream chunk");
// Handle different types of chunks // Handle different types of chunks
if let Some(choices) = chunk.choices { if let Some(choices) = chunk.choices {
for choice in choices { for choice in choices {
@@ -349,57 +409,93 @@ impl DatabricksProvider {
// Handle tool calls - accumulate across chunks // Handle tool calls - accumulate across chunks
if let Some(tool_calls) = delta.tool_calls { if let Some(tool_calls) = delta.tool_calls {
debug!("Processing {} tool call deltas", tool_calls.len());
for tool_call in tool_calls { for tool_call in tool_calls {
let index = tool_call.index.unwrap_or(0); let index = tool_call.index.unwrap_or(0);
let entry = current_tool_calls.entry(index).or_insert_with(|| { debug!("Tool call delta for index {}: id={:?}, name='{}', args_len={}",
(String::new(), String::new(), String::new()) index, tool_call.id, tool_call.function.name, tool_call.function.arguments.len());
});
let entry = current_tool_calls
.entry(index)
.or_insert_with(|| {
(
String::new(),
String::new(),
String::new(),
)
});
// Update ID if provided // Update ID if provided
if let Some(id) = tool_call.id { if let Some(id) = tool_call.id {
debug!("Updating tool call {} ID from '{}' to '{}'", index, entry.0, id);
entry.0 = id; entry.0 = id;
} }
// Update name if provided and not empty // Update name if provided and not empty
if !tool_call.function.name.is_empty() { if !tool_call.function.name.is_empty() {
debug!("Updating tool call {} name from '{}' to '{}'", index, entry.1, tool_call.function.name);
entry.1 = tool_call.function.name; entry.1 = tool_call.function.name;
} }
// Append arguments // Append arguments
entry.2.push_str(&tool_call.function.arguments); debug!("Appending {} chars to tool call {} args (current len: {})",
tool_call.function.arguments.len(), index, entry.2.len());
entry.2.push_str(
&tool_call.function.arguments,
);
debug!("Accumulated tool call {}: id='{}', name='{}', args='{}'", debug!("Accumulated tool call {}: id='{}', name='{}', args_len={}",
index, entry.0, entry.1, entry.2); index, entry.0, entry.1, entry.2.len());
// Debug: Show a sample of the accumulated args if they're getting long
if entry.2.len() > 100 {
debug!("Tool call {} args sample (first 100 chars): {}", index, &entry.2[..100]);
} else if !entry.2.is_empty() {
debug!("Tool call {} full args: {}", index, entry.2);
}
} }
} }
} }
// Check if this choice is finished // Check if this choice is finished
if choice.finish_reason.is_some() { if choice.finish_reason.is_some() {
debug!("Choice finished with reason: {:?}", choice.finish_reason); debug!(
"Choice finished with reason: {:?}",
choice.finish_reason
);
// Convert accumulated tool calls to final format // Convert accumulated tool calls to final format
let final_tool_calls: Vec<ToolCall> = current_tool_calls.values() let final_tool_calls: Vec<ToolCall> = current_tool_calls.values()
.filter(|(_, name, _)| !name.is_empty()) // Only include tool calls with names .filter(|(_, name, _)| !name.is_empty()) // Only include tool calls with names
.map(|(id, name, args)| { .map(|(id, name, args)| {
debug!("Converting tool call: id='{}', name='{}', args='{}'", id, name, args); debug!("Converting tool call: id='{}', name='{}', args_len={}", id, name, args.len());
ToolCall { ToolCall {
id: if id.is_empty() { format!("tool_{}", name) } else { id.clone() }, id: if id.is_empty() { format!("tool_{}", name) } else { id.clone() },
tool: name.clone(), tool: name.clone(),
args: serde_json::from_str(args).unwrap_or_else(|e| { args: serde_json::from_str(args).unwrap_or_else(|e| {
debug!("Failed to parse tool args '{}': {}", args, e); debug!("Failed to parse tool args (len={}): {}", args.len(), e);
// For debugging, log a sample of the args if they're very long
if args.len() > 1000 {
debug!("Tool args sample (first 500 chars): {}", &args[..500]);
} else {
debug!("Full tool args: {}", args);
}
serde_json::Value::Object(serde_json::Map::new()) serde_json::Value::Object(serde_json::Map::new())
}), }),
} }
}) })
.collect(); .collect();
debug!("Final tool calls: {:?}", final_tool_calls); debug!("Final tool calls count: {}", final_tool_calls.len());
let final_chunk = CompletionChunk { let final_chunk = CompletionChunk {
content: String::new(), content: String::new(),
finished: true, finished: true,
tool_calls: if final_tool_calls.is_empty() { None } else { Some(final_tool_calls) }, tool_calls: if final_tool_calls.is_empty() {
None
} else {
Some(final_tool_calls)
},
}; };
if tx.send(Ok(final_chunk)).await.is_err() { if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream"); debug!("Receiver dropped, stopping stream");
@@ -410,10 +506,36 @@ impl DatabricksProvider {
} }
} }
Err(e) => { Err(e) => {
debug!("Failed to parse stream chunk: {} - Data: {}", e, data); // Check if this is likely an incomplete JSON due to line splitting
// Common indicators: unexpected EOF, unterminated string, etc.
let error_str = e.to_string().to_lowercase();
if line.starts_with("data: ") && (
error_str.contains("eof") ||
error_str.contains("unterminated") ||
error_str.contains("unexpected end") ||
error_str.contains("trailing") ||
// Also check if the data doesn't end with a proper JSON terminator
(!data.trim_end().ends_with('}') && !data.trim_end().ends_with(']'))
) {
// This looks like an incomplete data line, save it for the next chunk
debug!("Detected incomplete data line (len={}), buffering for next chunk", line.len());
incomplete_data_line = line.clone();
// Continue to next iteration without processing
continue;
} else {
// This is a real parse error, not due to line splitting
debug!("Failed to parse Databricks stream chunk JSON: {} - Data length: {}", e, data.len());
// For debugging large payloads, log a sample
if data.len() > 1000 {
debug!("JSON parse error - data sample: {}", &data[..std::cmp::min(500, data.len())]);
}
}
// Don't error out on parse failures, just continue // Don't error out on parse failures, just continue
} }
} }
} else if line.starts_with("event: ") || line.starts_with("id: ") {
// Debug: Log non-data SSE lines (like event: or id:)
debug!("Non-data SSE line: {}", line);
} }
} }
} }
@@ -425,27 +547,52 @@ impl DatabricksProvider {
} }
} }
// If we have any incomplete data line at the end, try to process it
if !incomplete_data_line.is_empty() {
debug!("Processing final incomplete data line (len={})", incomplete_data_line.len());
if let Some(data) = incomplete_data_line.strip_prefix("data: ") {
// Try to parse it as-is, it might be complete
if let Ok(_chunk) = serde_json::from_str::<DatabricksStreamChunk>(data) {
// Process the chunk (code would be duplicated from above, so in practice
// we'd extract this to a helper function)
debug!("Successfully parsed final incomplete data line");
} else {
warn!("Failed to parse final incomplete data line");
}
}
}
// Send final chunk if we haven't already // Send final chunk if we haven't already
let final_tool_calls: Vec<ToolCall> = current_tool_calls.values() let final_tool_calls: Vec<ToolCall> = current_tool_calls
.values()
.filter(|(_, name, _)| !name.is_empty()) .filter(|(_, name, _)| !name.is_empty())
.map(|(id, name, args)| ToolCall { .map(|(id, name, args)| ToolCall {
id: if id.is_empty() { format!("tool_{}", name) } else { id.clone() }, id: if id.is_empty() {
format!("tool_{}", name)
} else {
id.clone()
},
tool: name.clone(), tool: name.clone(),
args: serde_json::from_str(args).unwrap_or(serde_json::Value::Object(serde_json::Map::new())), args: serde_json::from_str(args)
.unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
}) })
.collect(); .collect();
let final_chunk = CompletionChunk { let final_chunk = CompletionChunk {
content: String::new(), content: String::new(),
finished: true, finished: true,
tool_calls: if final_tool_calls.is_empty() { None } else { Some(final_tool_calls) }, tool_calls: if final_tool_calls.is_empty() {
None
} else {
Some(final_tool_calls)
},
}; };
let _ = tx.send(Ok(final_chunk)).await; let _ = tx.send(Ok(final_chunk)).await;
} }
pub async fn fetch_supported_models(&mut self) -> Result<Option<Vec<String>>> { pub async fn fetch_supported_models(&mut self) -> Result<Option<Vec<String>>> {
let token = self.auth.get_token().await?; let token = self.auth.get_token().await?;
let response = match self let response = match self
.client .client
.get(&format!("{}/api/2.0/serving-endpoints", self.host)) .get(&format!("{}/api/2.0/serving-endpoints", self.host))
@@ -465,8 +612,7 @@ impl DatabricksProvider {
if let Ok(error_text) = response.text().await { if let Ok(error_text) = response.text().await {
warn!( warn!(
"Failed to fetch Databricks models: {} - {}", "Failed to fetch Databricks models: {} - {}",
status, status, error_text
error_text
); );
} else { } else {
warn!("Failed to fetch Databricks models: {}", status); warn!("Failed to fetch Databricks models: {}", status);
@@ -485,9 +631,7 @@ impl DatabricksProvider {
let endpoints = match json.get("endpoints").and_then(|v| v.as_array()) { let endpoints = match json.get("endpoints").and_then(|v| v.as_array()) {
Some(endpoints) => endpoints, Some(endpoints) => endpoints,
None => { None => {
warn!( warn!("Unexpected response format from Databricks API: missing 'endpoints' array");
"Unexpected response format from Databricks API: missing 'endpoints' array"
);
return Ok(None); return Ok(None);
} }
}; };
@@ -527,19 +671,25 @@ impl LLMProvider for DatabricksProvider {
let temperature = request.temperature.unwrap_or(self.temperature); let temperature = request.temperature.unwrap_or(self.temperature);
let request_body = self.create_request_body( let request_body = self.create_request_body(
&request.messages, &request.messages,
request.tools.as_deref(), request.tools.as_deref(),
false, false,
max_tokens, max_tokens,
temperature temperature,
)?; )?;
debug!("Sending request to Databricks API: model={}, max_tokens={}, temperature={}", debug!(
self.model, request_body.max_tokens, request_body.temperature); "Sending request to Databricks API: model={}, max_tokens={}, temperature={}",
self.model, request_body.max_tokens, request_body.temperature
);
// Debug: Log the full request body when tools are present // Debug: Log the full request body when tools are present
if request.tools.is_some() { if request.tools.is_some() {
debug!("Full request body with tools: {}", serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "Failed to serialize".to_string())); debug!(
"Full request body with tools: {}",
serde_json::to_string_pretty(&request_body)
.unwrap_or_else(|_| "Failed to serialize".to_string())
);
} }
let mut provider_clone = self.clone(); let mut provider_clone = self.clone();
@@ -564,7 +714,13 @@ impl LLMProvider for DatabricksProvider {
debug!("Raw Databricks API response: {}", response_text); debug!("Raw Databricks API response: {}", response_text);
let databricks_response: DatabricksResponse = serde_json::from_str(&response_text) let databricks_response: DatabricksResponse = serde_json::from_str(&response_text)
.map_err(|e| anyhow!("Failed to parse Databricks response: {} - Response: {}", e, response_text))?; .map_err(|e| {
anyhow!(
"Failed to parse Databricks response: {} - Response: {}",
e,
response_text
)
})?;
// Debug: Log the parsed response structure // Debug: Log the parsed response structure
debug!("Parsed Databricks response: {:#?}", databricks_response); debug!("Parsed Databricks response: {:#?}", databricks_response);
@@ -580,11 +736,17 @@ impl LLMProvider for DatabricksProvider {
// Check if there are tool calls in the response // Check if there are tool calls in the response
if let Some(first_choice) = databricks_response.choices.first() { if let Some(first_choice) = databricks_response.choices.first() {
if let Some(tool_calls) = &first_choice.message.tool_calls { if let Some(tool_calls) = &first_choice.message.tool_calls {
debug!("Found {} tool calls in Databricks response", tool_calls.len()); debug!(
"Found {} tool calls in Databricks response",
tool_calls.len()
);
for (i, tool_call) in tool_calls.iter().enumerate() { for (i, tool_call) in tool_calls.iter().enumerate() {
debug!("Tool call {}: {} with args: {}", i, tool_call.function.name, tool_call.function.arguments); debug!(
"Tool call {}: {} with args: {}",
i, tool_call.function.name, tool_call.function.arguments
);
} }
// For now, we'll return the content as-is since g3 handles tool calls via streaming // For now, we'll return the content as-is since g3 handles tool calls via streaming
// In the future, we might need to convert these to the internal format // In the future, we might need to convert these to the internal format
} }
@@ -618,18 +780,24 @@ impl LLMProvider for DatabricksProvider {
let temperature = request.temperature.unwrap_or(self.temperature); let temperature = request.temperature.unwrap_or(self.temperature);
let request_body = self.create_request_body( let request_body = self.create_request_body(
&request.messages, &request.messages,
request.tools.as_deref(), request.tools.as_deref(),
true, true,
max_tokens, max_tokens,
temperature temperature,
)?; )?;
debug!("Sending streaming request to Databricks API: model={}, max_tokens={}, temperature={}", debug!(
self.model, request_body.max_tokens, request_body.temperature); "Sending streaming request to Databricks API: model={}, max_tokens={}, temperature={}",
self.model, request_body.max_tokens, request_body.temperature
);
// Debug: Log the full request body // Debug: Log the full request body
debug!("Full request body: {}", serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "Failed to serialize".to_string())); debug!(
"Full request body: {}",
serde_json::to_string_pretty(&request_body)
.unwrap_or_else(|_| "Failed to serialize".to_string())
);
let mut provider_clone = self.clone(); let mut provider_clone = self.clone();
let response = provider_clone let response = provider_clone
@@ -731,6 +899,7 @@ struct DatabricksResponse {
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct DatabricksChoice { struct DatabricksChoice {
message: DatabricksMessage, message: DatabricksMessage,
#[allow(dead_code)]
finish_reason: Option<String>, finish_reason: Option<String>,
} }
@@ -786,7 +955,8 @@ mod tests {
"test-model".to_string(), "test-model".to_string(),
None, None,
None, None,
).unwrap(); )
.unwrap();
let messages = vec![ let messages = vec![
Message { Message {
@@ -819,14 +989,13 @@ mod tests {
"databricks-claude-sonnet-4".to_string(), "databricks-claude-sonnet-4".to_string(),
Some(1000), Some(1000),
Some(0.5), Some(0.5),
).unwrap(); )
.unwrap();
let messages = vec![ let messages = vec![Message {
Message { role: MessageRole::User,
role: MessageRole::User, content: "Test message".to_string(),
content: "Test message".to_string(), }];
},
];
let request_body = provider let request_body = provider
.create_request_body(&messages, None, false, 1000, 0.5) .create_request_body(&messages, None, false, 1000, 0.5)
@@ -847,31 +1016,33 @@ mod tests {
"test-model".to_string(), "test-model".to_string(),
None, None,
None, None,
).unwrap(); )
.unwrap();
let tools = vec![ let tools = vec![Tool {
Tool { name: "get_weather".to_string(),
name: "get_weather".to_string(), description: "Get the current weather".to_string(),
description: "Get the current weather".to_string(), input_schema: serde_json::json!({
input_schema: serde_json::json!({ "type": "object",
"type": "object", "properties": {
"properties": { "location": {
"location": { "type": "string",
"type": "string", "description": "The city and state"
"description": "The city and state" }
} },
}, "required": ["location"]
"required": ["location"] }),
}), }];
},
];
let databricks_tools = provider.convert_tools(&tools); let databricks_tools = provider.convert_tools(&tools);
assert_eq!(databricks_tools.len(), 1); assert_eq!(databricks_tools.len(), 1);
assert_eq!(databricks_tools[0].r#type, "function"); assert_eq!(databricks_tools[0].r#type, "function");
assert_eq!(databricks_tools[0].function.name, "get_weather"); assert_eq!(databricks_tools[0].function.name, "get_weather");
assert_eq!(databricks_tools[0].function.description, "Get the current weather"); assert_eq!(
databricks_tools[0].function.description,
"Get the current weather"
);
} }
#[test] #[test]
@@ -882,7 +1053,8 @@ mod tests {
"databricks-claude-sonnet-4".to_string(), "databricks-claude-sonnet-4".to_string(),
None, None,
None, None,
).unwrap(); )
.unwrap();
let llama_provider = DatabricksProvider::from_token( let llama_provider = DatabricksProvider::from_token(
"https://test.databricks.com".to_string(), "https://test.databricks.com".to_string(),
@@ -890,7 +1062,8 @@ mod tests {
"databricks-meta-llama-3-3-70b-instruct".to_string(), "databricks-meta-llama-3-3-70b-instruct".to_string(),
None, None,
None, None,
).unwrap(); )
.unwrap();
let dbrx_provider = DatabricksProvider::from_token( let dbrx_provider = DatabricksProvider::from_token(
"https://test.databricks.com".to_string(), "https://test.databricks.com".to_string(),
@@ -898,7 +1071,8 @@ mod tests {
"databricks-dbrx-instruct".to_string(), "databricks-dbrx-instruct".to_string(),
None, None,
None, None,
).unwrap(); )
.unwrap();
assert!(claude_provider.has_native_tool_calling()); assert!(claude_provider.has_native_tool_calling());
assert!(llama_provider.has_native_tool_calling()); assert!(llama_provider.has_native_tool_calling());