max tokens fix for databricks
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -850,6 +850,7 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
|
"chrono",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"g3-config",
|
"g3-config",
|
||||||
"g3-execution",
|
"g3-execution",
|
||||||
|
|||||||
@@ -171,6 +171,7 @@ impl Config {
|
|||||||
Ok(config)
|
Ok(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
fn default_qwen_config() -> Self {
|
fn default_qwen_config() -> Self {
|
||||||
Self {
|
Self {
|
||||||
providers: ProvidersConfig {
|
providers: ProvidersConfig {
|
||||||
|
|||||||
@@ -22,3 +22,4 @@ llama_cpp = { version = "0.3.2", features = ["metal"] }
|
|||||||
shellexpand = "3.1"
|
shellexpand = "3.1"
|
||||||
tokio-util = "0.7"
|
tokio-util = "0.7"
|
||||||
futures-util = "0.3"
|
futures-util = "0.3"
|
||||||
|
chrono = { version = "0.4", features = ["serde"] }
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -79,7 +79,6 @@ const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"];
|
|||||||
const DEFAULT_TIMEOUT_SECS: u64 = 600;
|
const DEFAULT_TIMEOUT_SECS: u64 = 600;
|
||||||
|
|
||||||
pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-sonnet-4";
|
pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-sonnet-4";
|
||||||
const DATABRICKS_DEFAULT_FAST_MODEL: &str = "gemini-1-5-flash";
|
|
||||||
pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[
|
pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[
|
||||||
"databricks-claude-3-7-sonnet",
|
"databricks-claude-3-7-sonnet",
|
||||||
"databricks-meta-llama-3-3-70b-instruct",
|
"databricks-meta-llama-3-3-70b-instruct",
|
||||||
@@ -155,14 +154,17 @@ impl DatabricksProvider {
|
|||||||
.build()
|
.build()
|
||||||
.map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
|
.map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
|
||||||
|
|
||||||
info!("Initialized Databricks provider with model: {} on host: {}", model, host);
|
info!(
|
||||||
|
"Initialized Databricks provider with model: {} on host: {}",
|
||||||
|
model, host
|
||||||
|
);
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
client,
|
client,
|
||||||
host: host.trim_end_matches('/').to_string(),
|
host: host.trim_end_matches('/').to_string(),
|
||||||
auth: DatabricksAuth::token(token),
|
auth: DatabricksAuth::token(token),
|
||||||
model,
|
model,
|
||||||
max_tokens: max_tokens.unwrap_or(4096),
|
max_tokens: max_tokens.unwrap_or(50000),
|
||||||
temperature: temperature.unwrap_or(0.1),
|
temperature: temperature.unwrap_or(0.1),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -178,14 +180,17 @@ impl DatabricksProvider {
|
|||||||
.build()
|
.build()
|
||||||
.map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
|
.map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
|
||||||
|
|
||||||
info!("Initialized Databricks provider with OAuth for model: {} on host: {}", model, host);
|
info!(
|
||||||
|
"Initialized Databricks provider with OAuth for model: {} on host: {}",
|
||||||
|
model, host
|
||||||
|
);
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
client,
|
client,
|
||||||
host: host.trim_end_matches('/').to_string(),
|
host: host.trim_end_matches('/').to_string(),
|
||||||
auth: DatabricksAuth::oauth(host.clone()),
|
auth: DatabricksAuth::oauth(host.clone()),
|
||||||
model,
|
model,
|
||||||
max_tokens: max_tokens.unwrap_or(4096),
|
max_tokens: max_tokens.unwrap_or(50000),
|
||||||
temperature: temperature.unwrap_or(0.1),
|
temperature: temperature.unwrap_or(0.1),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -195,7 +200,10 @@ impl DatabricksProvider {
|
|||||||
|
|
||||||
let mut builder = self
|
let mut builder = self
|
||||||
.client
|
.client
|
||||||
.post(&format!("{}/serving-endpoints/{}/invocations", self.host, self.model))
|
.post(&format!(
|
||||||
|
"{}/serving-endpoints/{}/invocations",
|
||||||
|
self.host, self.model
|
||||||
|
))
|
||||||
.header("Authorization", format!("Bearer {}", token))
|
.header("Authorization", format!("Bearer {}", token))
|
||||||
.header("Content-Type", "application/json");
|
.header("Content-Type", "application/json");
|
||||||
|
|
||||||
@@ -274,13 +282,26 @@ impl DatabricksProvider {
|
|||||||
tx: mpsc::Sender<Result<CompletionChunk>>,
|
tx: mpsc::Sender<Result<CompletionChunk>>,
|
||||||
) {
|
) {
|
||||||
let mut buffer = String::new();
|
let mut buffer = String::new();
|
||||||
let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> = std::collections::HashMap::new(); // index -> (id, name, args)
|
let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> =
|
||||||
|
std::collections::HashMap::new(); // index -> (id, name, args)
|
||||||
|
let mut incomplete_data_line = String::new(); // Buffer for incomplete data: lines
|
||||||
|
|
||||||
while let Some(chunk_result) = stream.next().await {
|
while let Some(chunk_result) = stream.next().await {
|
||||||
match chunk_result {
|
match chunk_result {
|
||||||
Ok(chunk) => {
|
Ok(chunk) => {
|
||||||
|
// Debug: Log raw bytes received
|
||||||
|
debug!("Raw SSE bytes received: {} bytes", chunk.len());
|
||||||
|
|
||||||
let chunk_str = match std::str::from_utf8(&chunk) {
|
let chunk_str = match std::str::from_utf8(&chunk) {
|
||||||
Ok(s) => s,
|
Ok(s) => {
|
||||||
|
// Debug: Log raw string content (truncated for large chunks)
|
||||||
|
if s.len() > 1000 {
|
||||||
|
debug!("Raw SSE string content (first 500 chars): {:?}...", &s[..500]);
|
||||||
|
} else {
|
||||||
|
debug!("Raw SSE string content: {:?}", s);
|
||||||
|
}
|
||||||
|
s
|
||||||
|
},
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Invalid UTF-8 in stream chunk: {}", e);
|
error!("Invalid UTF-8 in stream chunk: {}", e);
|
||||||
let _ = tx
|
let _ = tx
|
||||||
@@ -292,7 +313,7 @@ impl DatabricksProvider {
|
|||||||
|
|
||||||
buffer.push_str(chunk_str);
|
buffer.push_str(chunk_str);
|
||||||
|
|
||||||
// Process complete lines
|
// Process complete lines, but handle incomplete data: lines specially
|
||||||
while let Some(line_end) = buffer.find('\n') {
|
while let Some(line_end) = buffer.find('\n') {
|
||||||
let line = buffer[..line_end].trim().to_string();
|
let line = buffer[..line_end].trim().to_string();
|
||||||
buffer.drain(..line_end + 1);
|
buffer.drain(..line_end + 1);
|
||||||
@@ -301,21 +322,55 @@ impl DatabricksProvider {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if we have an incomplete data line from previous chunk
|
||||||
|
let line = if !incomplete_data_line.is_empty() {
|
||||||
|
// We had an incomplete data: line, append this line to it
|
||||||
|
let complete_line = format!("{}{}", incomplete_data_line, line);
|
||||||
|
incomplete_data_line.clear();
|
||||||
|
complete_line
|
||||||
|
} else {
|
||||||
|
line
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check if this is a data: line that might be incomplete
|
||||||
|
// SSE format requires double newline after data, so if we don't see another newline
|
||||||
|
// after this one in the buffer, and it's a data: line, it might be incomplete
|
||||||
|
if line.starts_with("data: ") {
|
||||||
|
// Check if there's a complete SSE event (should have double newline after data)
|
||||||
|
// But for streaming, single newline is often used, so we need to be careful
|
||||||
|
// The safest approach is to try parsing and if it fails due to incomplete JSON,
|
||||||
|
// we'll handle it below
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug: Log each SSE line (truncated for large lines)
|
||||||
|
if line.len() > 1000 {
|
||||||
|
debug!("SSE line (first 500 chars): {:?}...", &line[..500]);
|
||||||
|
} else {
|
||||||
|
debug!("SSE line: {:?}", line);
|
||||||
|
}
|
||||||
|
|
||||||
// Parse Server-Sent Events format
|
// Parse Server-Sent Events format
|
||||||
if let Some(data) = line.strip_prefix("data: ") {
|
if let Some(data) = line.strip_prefix("data: ") {
|
||||||
if data == "[DONE]" {
|
if data == "[DONE]" {
|
||||||
debug!("Received stream completion marker");
|
debug!("Received stream completion marker");
|
||||||
let final_tool_calls: Vec<ToolCall> = current_tool_calls.values()
|
let final_tool_calls: Vec<ToolCall> = current_tool_calls
|
||||||
|
.values()
|
||||||
.map(|(id, name, args)| ToolCall {
|
.map(|(id, name, args)| ToolCall {
|
||||||
id: id.clone(),
|
id: id.clone(),
|
||||||
tool: name.clone(),
|
tool: name.clone(),
|
||||||
args: serde_json::from_str(args).unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
|
args: serde_json::from_str(args).unwrap_or(
|
||||||
|
serde_json::Value::Object(serde_json::Map::new()),
|
||||||
|
),
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
let final_chunk = CompletionChunk {
|
let final_chunk = CompletionChunk {
|
||||||
content: String::new(),
|
content: String::new(),
|
||||||
finished: true,
|
finished: true,
|
||||||
tool_calls: if final_tool_calls.is_empty() { None } else { Some(final_tool_calls) },
|
tool_calls: if final_tool_calls.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(final_tool_calls)
|
||||||
|
},
|
||||||
};
|
};
|
||||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||||
debug!("Receiver dropped, stopping stream");
|
debug!("Receiver dropped, stopping stream");
|
||||||
@@ -323,11 +378,16 @@ impl DatabricksProvider {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
debug!("Raw Databricks API JSON: {}", data);
|
// Debug: Log every raw JSON payload from Databricks API (truncated for large payloads)
|
||||||
|
if data.len() > 1000 {
|
||||||
|
debug!("Raw Databricks SSE JSON payload (first 500 chars): {}...", &data[..500]);
|
||||||
|
} else {
|
||||||
|
debug!("Raw Databricks SSE JSON payload: {}", data);
|
||||||
|
}
|
||||||
|
|
||||||
match serde_json::from_str::<DatabricksStreamChunk>(data) {
|
match serde_json::from_str::<DatabricksStreamChunk>(data) {
|
||||||
Ok(chunk) => {
|
Ok(chunk) => {
|
||||||
debug!("Parsed stream chunk: {:?}", chunk);
|
debug!("Successfully parsed Databricks stream chunk");
|
||||||
|
|
||||||
// Handle different types of chunks
|
// Handle different types of chunks
|
||||||
if let Some(choices) = chunk.choices {
|
if let Some(choices) = chunk.choices {
|
||||||
@@ -349,57 +409,93 @@ impl DatabricksProvider {
|
|||||||
|
|
||||||
// Handle tool calls - accumulate across chunks
|
// Handle tool calls - accumulate across chunks
|
||||||
if let Some(tool_calls) = delta.tool_calls {
|
if let Some(tool_calls) = delta.tool_calls {
|
||||||
|
debug!("Processing {} tool call deltas", tool_calls.len());
|
||||||
for tool_call in tool_calls {
|
for tool_call in tool_calls {
|
||||||
let index = tool_call.index.unwrap_or(0);
|
let index = tool_call.index.unwrap_or(0);
|
||||||
let entry = current_tool_calls.entry(index).or_insert_with(|| {
|
debug!("Tool call delta for index {}: id={:?}, name='{}', args_len={}",
|
||||||
(String::new(), String::new(), String::new())
|
index, tool_call.id, tool_call.function.name, tool_call.function.arguments.len());
|
||||||
});
|
|
||||||
|
let entry = current_tool_calls
|
||||||
|
.entry(index)
|
||||||
|
.or_insert_with(|| {
|
||||||
|
(
|
||||||
|
String::new(),
|
||||||
|
String::new(),
|
||||||
|
String::new(),
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
// Update ID if provided
|
// Update ID if provided
|
||||||
if let Some(id) = tool_call.id {
|
if let Some(id) = tool_call.id {
|
||||||
|
debug!("Updating tool call {} ID from '{}' to '{}'", index, entry.0, id);
|
||||||
entry.0 = id;
|
entry.0 = id;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update name if provided and not empty
|
// Update name if provided and not empty
|
||||||
if !tool_call.function.name.is_empty() {
|
if !tool_call.function.name.is_empty() {
|
||||||
|
debug!("Updating tool call {} name from '{}' to '{}'", index, entry.1, tool_call.function.name);
|
||||||
entry.1 = tool_call.function.name;
|
entry.1 = tool_call.function.name;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Append arguments
|
// Append arguments
|
||||||
entry.2.push_str(&tool_call.function.arguments);
|
debug!("Appending {} chars to tool call {} args (current len: {})",
|
||||||
|
tool_call.function.arguments.len(), index, entry.2.len());
|
||||||
|
entry.2.push_str(
|
||||||
|
&tool_call.function.arguments,
|
||||||
|
);
|
||||||
|
|
||||||
debug!("Accumulated tool call {}: id='{}', name='{}', args='{}'",
|
debug!("Accumulated tool call {}: id='{}', name='{}', args_len={}",
|
||||||
index, entry.0, entry.1, entry.2);
|
index, entry.0, entry.1, entry.2.len());
|
||||||
|
|
||||||
|
// Debug: Show a sample of the accumulated args if they're getting long
|
||||||
|
if entry.2.len() > 100 {
|
||||||
|
debug!("Tool call {} args sample (first 100 chars): {}", index, &entry.2[..100]);
|
||||||
|
} else if !entry.2.is_empty() {
|
||||||
|
debug!("Tool call {} full args: {}", index, entry.2);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if this choice is finished
|
// Check if this choice is finished
|
||||||
if choice.finish_reason.is_some() {
|
if choice.finish_reason.is_some() {
|
||||||
debug!("Choice finished with reason: {:?}", choice.finish_reason);
|
debug!(
|
||||||
|
"Choice finished with reason: {:?}",
|
||||||
|
choice.finish_reason
|
||||||
|
);
|
||||||
|
|
||||||
// Convert accumulated tool calls to final format
|
// Convert accumulated tool calls to final format
|
||||||
let final_tool_calls: Vec<ToolCall> = current_tool_calls.values()
|
let final_tool_calls: Vec<ToolCall> = current_tool_calls.values()
|
||||||
.filter(|(_, name, _)| !name.is_empty()) // Only include tool calls with names
|
.filter(|(_, name, _)| !name.is_empty()) // Only include tool calls with names
|
||||||
.map(|(id, name, args)| {
|
.map(|(id, name, args)| {
|
||||||
debug!("Converting tool call: id='{}', name='{}', args='{}'", id, name, args);
|
debug!("Converting tool call: id='{}', name='{}', args_len={}", id, name, args.len());
|
||||||
ToolCall {
|
ToolCall {
|
||||||
id: if id.is_empty() { format!("tool_{}", name) } else { id.clone() },
|
id: if id.is_empty() { format!("tool_{}", name) } else { id.clone() },
|
||||||
tool: name.clone(),
|
tool: name.clone(),
|
||||||
args: serde_json::from_str(args).unwrap_or_else(|e| {
|
args: serde_json::from_str(args).unwrap_or_else(|e| {
|
||||||
debug!("Failed to parse tool args '{}': {}", args, e);
|
debug!("Failed to parse tool args (len={}): {}", args.len(), e);
|
||||||
|
// For debugging, log a sample of the args if they're very long
|
||||||
|
if args.len() > 1000 {
|
||||||
|
debug!("Tool args sample (first 500 chars): {}", &args[..500]);
|
||||||
|
} else {
|
||||||
|
debug!("Full tool args: {}", args);
|
||||||
|
}
|
||||||
serde_json::Value::Object(serde_json::Map::new())
|
serde_json::Value::Object(serde_json::Map::new())
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
debug!("Final tool calls: {:?}", final_tool_calls);
|
debug!("Final tool calls count: {}", final_tool_calls.len());
|
||||||
|
|
||||||
let final_chunk = CompletionChunk {
|
let final_chunk = CompletionChunk {
|
||||||
content: String::new(),
|
content: String::new(),
|
||||||
finished: true,
|
finished: true,
|
||||||
tool_calls: if final_tool_calls.is_empty() { None } else { Some(final_tool_calls) },
|
tool_calls: if final_tool_calls.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(final_tool_calls)
|
||||||
|
},
|
||||||
};
|
};
|
||||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||||
debug!("Receiver dropped, stopping stream");
|
debug!("Receiver dropped, stopping stream");
|
||||||
@@ -410,10 +506,36 @@ impl DatabricksProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
|
// Check if this is likely an incomplete JSON due to line splitting
|
||||||
|
// Common indicators: unexpected EOF, unterminated string, etc.
|
||||||
|
let error_str = e.to_string().to_lowercase();
|
||||||
|
if line.starts_with("data: ") && (
|
||||||
|
error_str.contains("eof") ||
|
||||||
|
error_str.contains("unterminated") ||
|
||||||
|
error_str.contains("unexpected end") ||
|
||||||
|
error_str.contains("trailing") ||
|
||||||
|
// Also check if the data doesn't end with a proper JSON terminator
|
||||||
|
(!data.trim_end().ends_with('}') && !data.trim_end().ends_with(']'))
|
||||||
|
) {
|
||||||
|
// This looks like an incomplete data line, save it for the next chunk
|
||||||
|
debug!("Detected incomplete data line (len={}), buffering for next chunk", line.len());
|
||||||
|
incomplete_data_line = line.clone();
|
||||||
|
// Continue to next iteration without processing
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
// This is a real parse error, not due to line splitting
|
||||||
|
debug!("Failed to parse Databricks stream chunk JSON: {} - Data length: {}", e, data.len());
|
||||||
|
// For debugging large payloads, log a sample
|
||||||
|
if data.len() > 1000 {
|
||||||
|
debug!("JSON parse error - data sample: {}", &data[..std::cmp::min(500, data.len())]);
|
||||||
|
}
|
||||||
|
}
|
||||||
// Don't error out on parse failures, just continue
|
// Don't error out on parse failures, just continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if line.starts_with("event: ") || line.starts_with("id: ") {
|
||||||
|
// Debug: Log non-data SSE lines (like event: or id:)
|
||||||
|
debug!("Non-data SSE line: {}", line);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -425,20 +547,45 @@ impl DatabricksProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we have any incomplete data line at the end, try to process it
|
||||||
|
if !incomplete_data_line.is_empty() {
|
||||||
|
debug!("Processing final incomplete data line (len={})", incomplete_data_line.len());
|
||||||
|
if let Some(data) = incomplete_data_line.strip_prefix("data: ") {
|
||||||
|
// Try to parse it as-is, it might be complete
|
||||||
|
if let Ok(_chunk) = serde_json::from_str::<DatabricksStreamChunk>(data) {
|
||||||
|
// Process the chunk (code would be duplicated from above, so in practice
|
||||||
|
// we'd extract this to a helper function)
|
||||||
|
debug!("Successfully parsed final incomplete data line");
|
||||||
|
} else {
|
||||||
|
warn!("Failed to parse final incomplete data line");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Send final chunk if we haven't already
|
// Send final chunk if we haven't already
|
||||||
let final_tool_calls: Vec<ToolCall> = current_tool_calls.values()
|
let final_tool_calls: Vec<ToolCall> = current_tool_calls
|
||||||
|
.values()
|
||||||
.filter(|(_, name, _)| !name.is_empty())
|
.filter(|(_, name, _)| !name.is_empty())
|
||||||
.map(|(id, name, args)| ToolCall {
|
.map(|(id, name, args)| ToolCall {
|
||||||
id: if id.is_empty() { format!("tool_{}", name) } else { id.clone() },
|
id: if id.is_empty() {
|
||||||
|
format!("tool_{}", name)
|
||||||
|
} else {
|
||||||
|
id.clone()
|
||||||
|
},
|
||||||
tool: name.clone(),
|
tool: name.clone(),
|
||||||
args: serde_json::from_str(args).unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
|
args: serde_json::from_str(args)
|
||||||
|
.unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let final_chunk = CompletionChunk {
|
let final_chunk = CompletionChunk {
|
||||||
content: String::new(),
|
content: String::new(),
|
||||||
finished: true,
|
finished: true,
|
||||||
tool_calls: if final_tool_calls.is_empty() { None } else { Some(final_tool_calls) },
|
tool_calls: if final_tool_calls.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(final_tool_calls)
|
||||||
|
},
|
||||||
};
|
};
|
||||||
let _ = tx.send(Ok(final_chunk)).await;
|
let _ = tx.send(Ok(final_chunk)).await;
|
||||||
}
|
}
|
||||||
@@ -465,8 +612,7 @@ impl DatabricksProvider {
|
|||||||
if let Ok(error_text) = response.text().await {
|
if let Ok(error_text) = response.text().await {
|
||||||
warn!(
|
warn!(
|
||||||
"Failed to fetch Databricks models: {} - {}",
|
"Failed to fetch Databricks models: {} - {}",
|
||||||
status,
|
status, error_text
|
||||||
error_text
|
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
warn!("Failed to fetch Databricks models: {}", status);
|
warn!("Failed to fetch Databricks models: {}", status);
|
||||||
@@ -485,9 +631,7 @@ impl DatabricksProvider {
|
|||||||
let endpoints = match json.get("endpoints").and_then(|v| v.as_array()) {
|
let endpoints = match json.get("endpoints").and_then(|v| v.as_array()) {
|
||||||
Some(endpoints) => endpoints,
|
Some(endpoints) => endpoints,
|
||||||
None => {
|
None => {
|
||||||
warn!(
|
warn!("Unexpected response format from Databricks API: missing 'endpoints' array");
|
||||||
"Unexpected response format from Databricks API: missing 'endpoints' array"
|
|
||||||
);
|
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -531,15 +675,21 @@ impl LLMProvider for DatabricksProvider {
|
|||||||
request.tools.as_deref(),
|
request.tools.as_deref(),
|
||||||
false,
|
false,
|
||||||
max_tokens,
|
max_tokens,
|
||||||
temperature
|
temperature,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
debug!("Sending request to Databricks API: model={}, max_tokens={}, temperature={}",
|
debug!(
|
||||||
self.model, request_body.max_tokens, request_body.temperature);
|
"Sending request to Databricks API: model={}, max_tokens={}, temperature={}",
|
||||||
|
self.model, request_body.max_tokens, request_body.temperature
|
||||||
|
);
|
||||||
|
|
||||||
// Debug: Log the full request body when tools are present
|
// Debug: Log the full request body when tools are present
|
||||||
if request.tools.is_some() {
|
if request.tools.is_some() {
|
||||||
debug!("Full request body with tools: {}", serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "Failed to serialize".to_string()));
|
debug!(
|
||||||
|
"Full request body with tools: {}",
|
||||||
|
serde_json::to_string_pretty(&request_body)
|
||||||
|
.unwrap_or_else(|_| "Failed to serialize".to_string())
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut provider_clone = self.clone();
|
let mut provider_clone = self.clone();
|
||||||
@@ -564,7 +714,13 @@ impl LLMProvider for DatabricksProvider {
|
|||||||
debug!("Raw Databricks API response: {}", response_text);
|
debug!("Raw Databricks API response: {}", response_text);
|
||||||
|
|
||||||
let databricks_response: DatabricksResponse = serde_json::from_str(&response_text)
|
let databricks_response: DatabricksResponse = serde_json::from_str(&response_text)
|
||||||
.map_err(|e| anyhow!("Failed to parse Databricks response: {} - Response: {}", e, response_text))?;
|
.map_err(|e| {
|
||||||
|
anyhow!(
|
||||||
|
"Failed to parse Databricks response: {} - Response: {}",
|
||||||
|
e,
|
||||||
|
response_text
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
// Debug: Log the parsed response structure
|
// Debug: Log the parsed response structure
|
||||||
debug!("Parsed Databricks response: {:#?}", databricks_response);
|
debug!("Parsed Databricks response: {:#?}", databricks_response);
|
||||||
@@ -580,9 +736,15 @@ impl LLMProvider for DatabricksProvider {
|
|||||||
// Check if there are tool calls in the response
|
// Check if there are tool calls in the response
|
||||||
if let Some(first_choice) = databricks_response.choices.first() {
|
if let Some(first_choice) = databricks_response.choices.first() {
|
||||||
if let Some(tool_calls) = &first_choice.message.tool_calls {
|
if let Some(tool_calls) = &first_choice.message.tool_calls {
|
||||||
debug!("Found {} tool calls in Databricks response", tool_calls.len());
|
debug!(
|
||||||
|
"Found {} tool calls in Databricks response",
|
||||||
|
tool_calls.len()
|
||||||
|
);
|
||||||
for (i, tool_call) in tool_calls.iter().enumerate() {
|
for (i, tool_call) in tool_calls.iter().enumerate() {
|
||||||
debug!("Tool call {}: {} with args: {}", i, tool_call.function.name, tool_call.function.arguments);
|
debug!(
|
||||||
|
"Tool call {}: {} with args: {}",
|
||||||
|
i, tool_call.function.name, tool_call.function.arguments
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// For now, we'll return the content as-is since g3 handles tool calls via streaming
|
// For now, we'll return the content as-is since g3 handles tool calls via streaming
|
||||||
@@ -622,14 +784,20 @@ impl LLMProvider for DatabricksProvider {
|
|||||||
request.tools.as_deref(),
|
request.tools.as_deref(),
|
||||||
true,
|
true,
|
||||||
max_tokens,
|
max_tokens,
|
||||||
temperature
|
temperature,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
debug!("Sending streaming request to Databricks API: model={}, max_tokens={}, temperature={}",
|
debug!(
|
||||||
self.model, request_body.max_tokens, request_body.temperature);
|
"Sending streaming request to Databricks API: model={}, max_tokens={}, temperature={}",
|
||||||
|
self.model, request_body.max_tokens, request_body.temperature
|
||||||
|
);
|
||||||
|
|
||||||
// Debug: Log the full request body
|
// Debug: Log the full request body
|
||||||
debug!("Full request body: {}", serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "Failed to serialize".to_string()));
|
debug!(
|
||||||
|
"Full request body: {}",
|
||||||
|
serde_json::to_string_pretty(&request_body)
|
||||||
|
.unwrap_or_else(|_| "Failed to serialize".to_string())
|
||||||
|
);
|
||||||
|
|
||||||
let mut provider_clone = self.clone();
|
let mut provider_clone = self.clone();
|
||||||
let response = provider_clone
|
let response = provider_clone
|
||||||
@@ -731,6 +899,7 @@ struct DatabricksResponse {
|
|||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct DatabricksChoice {
|
struct DatabricksChoice {
|
||||||
message: DatabricksMessage,
|
message: DatabricksMessage,
|
||||||
|
#[allow(dead_code)]
|
||||||
finish_reason: Option<String>,
|
finish_reason: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -786,7 +955,8 @@ mod tests {
|
|||||||
"test-model".to_string(),
|
"test-model".to_string(),
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
Message {
|
Message {
|
||||||
@@ -819,14 +989,13 @@ mod tests {
|
|||||||
"databricks-claude-sonnet-4".to_string(),
|
"databricks-claude-sonnet-4".to_string(),
|
||||||
Some(1000),
|
Some(1000),
|
||||||
Some(0.5),
|
Some(0.5),
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let messages = vec![
|
let messages = vec![Message {
|
||||||
Message {
|
role: MessageRole::User,
|
||||||
role: MessageRole::User,
|
content: "Test message".to_string(),
|
||||||
content: "Test message".to_string(),
|
}];
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
let request_body = provider
|
let request_body = provider
|
||||||
.create_request_body(&messages, None, false, 1000, 0.5)
|
.create_request_body(&messages, None, false, 1000, 0.5)
|
||||||
@@ -847,31 +1016,33 @@ mod tests {
|
|||||||
"test-model".to_string(),
|
"test-model".to_string(),
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let tools = vec![
|
let tools = vec![Tool {
|
||||||
Tool {
|
name: "get_weather".to_string(),
|
||||||
name: "get_weather".to_string(),
|
description: "Get the current weather".to_string(),
|
||||||
description: "Get the current weather".to_string(),
|
input_schema: serde_json::json!({
|
||||||
input_schema: serde_json::json!({
|
"type": "object",
|
||||||
"type": "object",
|
"properties": {
|
||||||
"properties": {
|
"location": {
|
||||||
"location": {
|
"type": "string",
|
||||||
"type": "string",
|
"description": "The city and state"
|
||||||
"description": "The city and state"
|
}
|
||||||
}
|
},
|
||||||
},
|
"required": ["location"]
|
||||||
"required": ["location"]
|
}),
|
||||||
}),
|
}];
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
let databricks_tools = provider.convert_tools(&tools);
|
let databricks_tools = provider.convert_tools(&tools);
|
||||||
|
|
||||||
assert_eq!(databricks_tools.len(), 1);
|
assert_eq!(databricks_tools.len(), 1);
|
||||||
assert_eq!(databricks_tools[0].r#type, "function");
|
assert_eq!(databricks_tools[0].r#type, "function");
|
||||||
assert_eq!(databricks_tools[0].function.name, "get_weather");
|
assert_eq!(databricks_tools[0].function.name, "get_weather");
|
||||||
assert_eq!(databricks_tools[0].function.description, "Get the current weather");
|
assert_eq!(
|
||||||
|
databricks_tools[0].function.description,
|
||||||
|
"Get the current weather"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -882,7 +1053,8 @@ mod tests {
|
|||||||
"databricks-claude-sonnet-4".to_string(),
|
"databricks-claude-sonnet-4".to_string(),
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let llama_provider = DatabricksProvider::from_token(
|
let llama_provider = DatabricksProvider::from_token(
|
||||||
"https://test.databricks.com".to_string(),
|
"https://test.databricks.com".to_string(),
|
||||||
@@ -890,7 +1062,8 @@ mod tests {
|
|||||||
"databricks-meta-llama-3-3-70b-instruct".to_string(),
|
"databricks-meta-llama-3-3-70b-instruct".to_string(),
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let dbrx_provider = DatabricksProvider::from_token(
|
let dbrx_provider = DatabricksProvider::from_token(
|
||||||
"https://test.databricks.com".to_string(),
|
"https://test.databricks.com".to_string(),
|
||||||
@@ -898,7 +1071,8 @@ mod tests {
|
|||||||
"databricks-dbrx-instruct".to_string(),
|
"databricks-dbrx-instruct".to_string(),
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert!(claude_provider.has_native_tool_calling());
|
assert!(claude_provider.has_native_tool_calling());
|
||||||
assert!(llama_provider.has_native_tool_calling());
|
assert!(llama_provider.has_native_tool_calling());
|
||||||
|
|||||||
Reference in New Issue
Block a user