refactor: improve readability of streaming and file ops code
Agent: carmack databricks.rs: - Extract ToolCallAccumulator struct to replace opaque (String, String, String) tuple - Add decode_utf8_streaming() helper for cleaner UTF-8 handling - Add is_incomplete_json_error() helper for JSON parse error detection - Add make_final_chunk() helper to reduce duplication - Add finalize_tool_calls() to convert accumulators to final format - Refactor parse_streaming_response from ~270 lines to ~100 lines - Reduce nesting depth from 8+ levels to 4 levels - Use early returns and let-else for cleaner control flow file_ops.rs: - Replace repetitive if-let chains with declarative PATH_CONTENT_KEYS table - Use match expression instead of nested if-else - Reduce extract_path_and_content from 44 lines to 20 lines All tests pass. Behavior unchanged.
This commit is contained in:
@@ -345,48 +345,35 @@ pub async fn execute_str_replace<W: UiWriter>(
|
|||||||
|
|
||||||
// Helper functions
|
// Helper functions
|
||||||
|
|
||||||
|
/// Known argument key pairs for path and content.
|
||||||
|
const PATH_CONTENT_KEYS: &[(&str, &str)] = &[
|
||||||
|
("file_path", "content"), // Standard format
|
||||||
|
("path", "content"), // Anthropic-style
|
||||||
|
("filename", "text"), // Alternative naming
|
||||||
|
("file", "data"), // Alternative naming
|
||||||
|
];
|
||||||
|
|
||||||
/// Extract path and content from various argument formats.
|
/// Extract path and content from various argument formats.
|
||||||
fn extract_path_and_content(args: &serde_json::Value) -> (Option<&str>, Option<&str>) {
|
fn extract_path_and_content(args: &serde_json::Value) -> (Option<&str>, Option<&str>) {
|
||||||
if let Some(args_obj) = args.as_object() {
|
match args {
|
||||||
// Format 1: Standard format with file_path and content
|
serde_json::Value::Object(obj) => {
|
||||||
if let (Some(path_val), Some(content_val)) =
|
for &(path_key, content_key) in PATH_CONTENT_KEYS {
|
||||||
(args_obj.get("file_path"), args_obj.get("content"))
|
if let (Some(p), Some(c)) = (obj.get(path_key), obj.get(content_key)) {
|
||||||
{
|
if let (Some(path), Some(content)) = (p.as_str(), c.as_str()) {
|
||||||
if let (Some(path), Some(content)) = (path_val.as_str(), content_val.as_str()) {
|
return (Some(path), Some(content));
|
||||||
return (Some(path), Some(content));
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Format 2: Anthropic-style with path and content
|
(None, None)
|
||||||
if let (Some(path_val), Some(content_val)) =
|
}
|
||||||
(args_obj.get("path"), args_obj.get("content"))
|
serde_json::Value::Array(arr) if arr.len() >= 2 => {
|
||||||
{
|
match (arr[0].as_str(), arr[1].as_str()) {
|
||||||
if let (Some(path), Some(content)) = (path_val.as_str(), content_val.as_str()) {
|
(Some(path), Some(content)) => (Some(path), Some(content)),
|
||||||
return (Some(path), Some(content));
|
_ => (None, None),
|
||||||
}
|
|
||||||
}
|
|
||||||
// Format 3: Alternative naming with filename and text
|
|
||||||
if let (Some(path_val), Some(content_val)) =
|
|
||||||
(args_obj.get("filename"), args_obj.get("text"))
|
|
||||||
{
|
|
||||||
if let (Some(path), Some(content)) = (path_val.as_str(), content_val.as_str()) {
|
|
||||||
return (Some(path), Some(content));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Format 4: Alternative naming with file and data
|
|
||||||
if let (Some(path_val), Some(content_val)) = (args_obj.get("file"), args_obj.get("data")) {
|
|
||||||
if let (Some(path), Some(content)) = (path_val.as_str(), content_val.as_str()) {
|
|
||||||
return (Some(path), Some(content));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if let Some(args_array) = args.as_array() {
|
|
||||||
// Format 5: Args might be an array [path, content]
|
|
||||||
if args_array.len() >= 2 {
|
|
||||||
if let (Some(path), Some(content)) = (args_array[0].as_str(), args_array[1].as_str()) {
|
|
||||||
return (Some(path), Some(content));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
_ => (None, None),
|
||||||
}
|
}
|
||||||
(None, None)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get image dimensions from raw bytes.
|
/// Get image dimensions from raw bytes.
|
||||||
|
|||||||
@@ -58,6 +58,7 @@
|
|||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
|
use std::collections::HashMap;
|
||||||
use futures_util::stream::StreamExt;
|
use futures_util::stream::StreamExt;
|
||||||
use reqwest::{Client, RequestBuilder};
|
use reqwest::{Client, RequestBuilder};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
@@ -71,6 +72,101 @@ use crate::{
|
|||||||
MessageRole, Tool, ToolCall, Usage,
|
MessageRole, Tool, ToolCall, Usage,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Streaming helpers
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Accumulated state for a single tool call being streamed in chunks.
|
||||||
|
#[derive(Default)]
|
||||||
|
struct ToolCallAccumulator {
|
||||||
|
id: String,
|
||||||
|
name: String,
|
||||||
|
args: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolCallAccumulator {
|
||||||
|
/// Update accumulator with a streaming delta.
|
||||||
|
fn apply_delta(&mut self, delta: &DatabricksStreamToolCall) {
|
||||||
|
if let Some(ref id) = delta.id {
|
||||||
|
self.id = id.clone();
|
||||||
|
}
|
||||||
|
if !delta.function.name.is_empty() {
|
||||||
|
self.name = delta.function.name.clone();
|
||||||
|
}
|
||||||
|
self.args.push_str(&delta.function.arguments);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert to final ToolCall if valid (has a name).
|
||||||
|
fn into_tool_call(self) -> Option<ToolCall> {
|
||||||
|
if self.name.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let id = if self.id.is_empty() {
|
||||||
|
format!("tool_{}", self.name)
|
||||||
|
} else {
|
||||||
|
self.id
|
||||||
|
};
|
||||||
|
let args = serde_json::from_str(&self.args)
|
||||||
|
.unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new()));
|
||||||
|
Some(ToolCall { id, tool: self.name, args })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert accumulated tool calls map to final Vec<ToolCall>.
|
||||||
|
fn finalize_tool_calls(accumulators: HashMap<usize, ToolCallAccumulator>) -> Vec<ToolCall> {
|
||||||
|
accumulators
|
||||||
|
.into_values()
|
||||||
|
.filter_map(|acc| acc.into_tool_call())
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to decode bytes as UTF-8, handling incomplete sequences at the end.
|
||||||
|
/// Returns the decoded string and leaves any incomplete bytes in the buffer.
|
||||||
|
fn decode_utf8_streaming(byte_buffer: &mut Vec<u8>) -> Option<String> {
|
||||||
|
match std::str::from_utf8(byte_buffer) {
|
||||||
|
Ok(s) => {
|
||||||
|
let result = s.to_string();
|
||||||
|
byte_buffer.clear();
|
||||||
|
Some(result)
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let valid_up_to = e.valid_up_to();
|
||||||
|
if valid_up_to > 0 {
|
||||||
|
let valid_bytes: Vec<u8> = byte_buffer.drain(..valid_up_to).collect();
|
||||||
|
// Safe: we just validated these bytes
|
||||||
|
Some(String::from_utf8(valid_bytes).unwrap())
|
||||||
|
} else {
|
||||||
|
None // No valid UTF-8 yet, wait for more bytes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if a JSON parse error indicates incomplete data (vs. malformed JSON).
|
||||||
|
fn is_incomplete_json_error(error: &serde_json::Error, data: &str) -> bool {
|
||||||
|
let msg = error.to_string().to_lowercase();
|
||||||
|
let looks_incomplete = msg.contains("eof")
|
||||||
|
|| msg.contains("unterminated")
|
||||||
|
|| msg.contains("unexpected end")
|
||||||
|
|| msg.contains("trailing");
|
||||||
|
let missing_terminator = !data.trim_end().ends_with('}') && !data.trim_end().ends_with(']');
|
||||||
|
looks_incomplete || missing_terminator
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a final completion chunk with tool calls and usage.
|
||||||
|
fn make_final_chunk(tool_calls: Vec<ToolCall>, usage: Option<Usage>) -> CompletionChunk {
|
||||||
|
CompletionChunk {
|
||||||
|
content: String::new(),
|
||||||
|
finished: true,
|
||||||
|
usage,
|
||||||
|
tool_calls: if tool_calls.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(tool_calls)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const DEFAULT_CLIENT_ID: &str = "databricks-cli";
|
const DEFAULT_CLIENT_ID: &str = "databricks-cli";
|
||||||
const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020";
|
const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020";
|
||||||
const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"];
|
const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"];
|
||||||
@@ -352,387 +448,131 @@ impl DatabricksProvider {
|
|||||||
tx: mpsc::Sender<Result<CompletionChunk>>,
|
tx: mpsc::Sender<Result<CompletionChunk>>,
|
||||||
) -> Option<Usage> {
|
) -> Option<Usage> {
|
||||||
let mut buffer = String::new();
|
let mut buffer = String::new();
|
||||||
let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> =
|
let mut tool_calls: HashMap<usize, ToolCallAccumulator> = HashMap::new();
|
||||||
std::collections::HashMap::new(); // index -> (id, name, args)
|
let mut incomplete_data_line = String::new();
|
||||||
let mut incomplete_data_line = String::new(); // Buffer for incomplete data: lines
|
|
||||||
let mut chunk_count = 0;
|
let mut chunk_count = 0;
|
||||||
let accumulated_usage: Option<Usage> = None;
|
let mut byte_buffer = Vec::new();
|
||||||
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
|
|
||||||
|
|
||||||
while let Some(chunk_result) = stream.next().await {
|
while let Some(chunk_result) = stream.next().await {
|
||||||
match chunk_result {
|
// Handle stream errors
|
||||||
Ok(chunk) => {
|
let chunk = match chunk_result {
|
||||||
// Debug: Log raw bytes received
|
Ok(c) => c,
|
||||||
chunk_count += 1;
|
|
||||||
debug!("Processing chunk #{}", chunk_count);
|
|
||||||
debug!("Raw SSE bytes received: {} bytes", chunk.len());
|
|
||||||
|
|
||||||
// Append new bytes to our buffer
|
|
||||||
byte_buffer.extend_from_slice(&chunk);
|
|
||||||
|
|
||||||
// Try to convert the entire buffer to UTF-8
|
|
||||||
let chunk_str = match std::str::from_utf8(&byte_buffer) {
|
|
||||||
Ok(s) => {
|
|
||||||
// Successfully converted entire buffer, clear it and use the string
|
|
||||||
let result = s.to_string();
|
|
||||||
byte_buffer.clear();
|
|
||||||
result
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
// Check if this is an incomplete sequence at the end
|
|
||||||
let valid_up_to = e.valid_up_to();
|
|
||||||
if valid_up_to > 0 {
|
|
||||||
// We have some valid UTF-8, extract it and keep the rest for next iteration
|
|
||||||
let valid_bytes =
|
|
||||||
byte_buffer.drain(..valid_up_to).collect::<Vec<_>>();
|
|
||||||
std::str::from_utf8(&valid_bytes).unwrap().to_string()
|
|
||||||
} else {
|
|
||||||
// No valid UTF-8 at all, skip this chunk and continue
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Debug: Log raw string content (truncated for large chunks)
|
|
||||||
if chunk_str.len() > 1000 {
|
|
||||||
debug!(
|
|
||||||
"Raw SSE string content (first 500 chars): {:?}...",
|
|
||||||
&chunk_str[..500]
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
debug!("Raw SSE string content: {:?}", chunk_str);
|
|
||||||
}
|
|
||||||
|
|
||||||
buffer.push_str(&chunk_str);
|
|
||||||
|
|
||||||
// Process complete lines, but handle incomplete data: lines specially
|
|
||||||
while let Some(line_end) = buffer.find('\n') {
|
|
||||||
let line = buffer[..line_end].trim().to_string();
|
|
||||||
buffer.drain(..line_end + 1);
|
|
||||||
|
|
||||||
if line.is_empty() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if we have an incomplete data line from previous chunk
|
|
||||||
let line = if !incomplete_data_line.is_empty() {
|
|
||||||
// We had an incomplete data: line, append this line to it
|
|
||||||
let complete_line = format!("{}{}", incomplete_data_line, line);
|
|
||||||
incomplete_data_line.clear();
|
|
||||||
complete_line
|
|
||||||
} else {
|
|
||||||
line
|
|
||||||
};
|
|
||||||
|
|
||||||
// Check if this is a data: line that might be incomplete
|
|
||||||
// SSE format requires double newline after data, so if we don't see another newline
|
|
||||||
// after this one in the buffer, and it's a data: line, it might be incomplete
|
|
||||||
if line.starts_with("data: ") {
|
|
||||||
// Check if there's a complete SSE event (should have double newline after data)
|
|
||||||
// But for streaming, single newline is often used, so we need to be careful
|
|
||||||
// The safest approach is to try parsing and if it fails due to incomplete JSON,
|
|
||||||
// we'll handle it below
|
|
||||||
}
|
|
||||||
|
|
||||||
// Debug: Log each SSE line (truncated for large lines)
|
|
||||||
if line.len() > 1000 {
|
|
||||||
debug!("SSE line (first 500 chars): {:?}...", &line[..500]);
|
|
||||||
} else {
|
|
||||||
debug!("SSE line: {:?}", line);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse Server-Sent Events format
|
|
||||||
if let Some(data) = line.strip_prefix("data: ") {
|
|
||||||
if data == "[DONE]" {
|
|
||||||
debug!("Received stream completion marker");
|
|
||||||
let final_tool_calls: Vec<ToolCall> = current_tool_calls
|
|
||||||
.values()
|
|
||||||
.map(|(id, name, args)| ToolCall {
|
|
||||||
id: id.clone(),
|
|
||||||
tool: name.clone(),
|
|
||||||
args: serde_json::from_str(args).unwrap_or(
|
|
||||||
serde_json::Value::Object(serde_json::Map::new()),
|
|
||||||
),
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let final_chunk = CompletionChunk {
|
|
||||||
content: String::new(),
|
|
||||||
finished: true,
|
|
||||||
usage: accumulated_usage.clone(),
|
|
||||||
tool_calls: if final_tool_calls.is_empty() {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(final_tool_calls)
|
|
||||||
},
|
|
||||||
};
|
|
||||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
|
||||||
debug!("Receiver dropped, stopping stream");
|
|
||||||
}
|
|
||||||
return accumulated_usage;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Debug: Log every raw JSON payload from Databricks API (truncated for large payloads)
|
|
||||||
if data.len() > 1000 {
|
|
||||||
debug!(
|
|
||||||
"Raw Databricks SSE JSON payload (first 500 chars): {}...",
|
|
||||||
&data[..500]
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
debug!("Raw Databricks SSE JSON payload: {}", data);
|
|
||||||
}
|
|
||||||
|
|
||||||
match serde_json::from_str::<DatabricksStreamChunk>(data) {
|
|
||||||
Ok(chunk) => {
|
|
||||||
debug!("Successfully parsed Databricks stream chunk");
|
|
||||||
|
|
||||||
// Handle different types of chunks
|
|
||||||
if let Some(choices) = chunk.choices {
|
|
||||||
for choice in choices {
|
|
||||||
if let Some(delta) = choice.delta {
|
|
||||||
// Handle text content
|
|
||||||
if let Some(content) = delta.content {
|
|
||||||
debug!("Sending text chunk: '{}'", content);
|
|
||||||
let chunk = CompletionChunk {
|
|
||||||
content,
|
|
||||||
finished: false,
|
|
||||||
usage: None,
|
|
||||||
tool_calls: None,
|
|
||||||
};
|
|
||||||
if tx.send(Ok(chunk)).await.is_err() {
|
|
||||||
debug!("Receiver dropped, stopping stream");
|
|
||||||
return accumulated_usage;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle tool calls - accumulate across chunks
|
|
||||||
if let Some(tool_calls) = delta.tool_calls {
|
|
||||||
debug!(
|
|
||||||
"Processing {} tool call deltas",
|
|
||||||
tool_calls.len()
|
|
||||||
);
|
|
||||||
for tool_call in tool_calls {
|
|
||||||
let index = tool_call.index.unwrap_or(0);
|
|
||||||
debug!("Tool call delta for index {}: id={:?}, name='{}', args_len={}",
|
|
||||||
index, tool_call.id, tool_call.function.name, tool_call.function.arguments.len());
|
|
||||||
|
|
||||||
let entry = current_tool_calls
|
|
||||||
.entry(index)
|
|
||||||
.or_insert_with(|| {
|
|
||||||
(
|
|
||||||
String::new(),
|
|
||||||
String::new(),
|
|
||||||
String::new(),
|
|
||||||
)
|
|
||||||
});
|
|
||||||
|
|
||||||
// Update ID if provided
|
|
||||||
if let Some(id) = tool_call.id {
|
|
||||||
debug!("Updating tool call {} ID from '{}' to '{}'", index, entry.0, id);
|
|
||||||
entry.0 = id;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update name if provided and not empty
|
|
||||||
if !tool_call.function.name.is_empty() {
|
|
||||||
debug!("Updating tool call {} name from '{}' to '{}'", index, entry.1, tool_call.function.name);
|
|
||||||
entry.1 = tool_call.function.name;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Append arguments
|
|
||||||
debug!("Appending {} chars to tool call {} args (current len: {})",
|
|
||||||
tool_call.function.arguments.len(), index, entry.2.len());
|
|
||||||
entry.2.push_str(
|
|
||||||
&tool_call.function.arguments,
|
|
||||||
);
|
|
||||||
|
|
||||||
debug!("Accumulated tool call {}: id='{}', name='{}', args_len={}",
|
|
||||||
index, entry.0, entry.1, entry.2.len());
|
|
||||||
|
|
||||||
// Debug: Show a sample of the accumulated args if they're getting long
|
|
||||||
if entry.2.len() > 100 {
|
|
||||||
debug!("Tool call {} args sample (first 100 chars): {}", index, &entry.2[..100]);
|
|
||||||
} else if !entry.2.is_empty() {
|
|
||||||
debug!(
|
|
||||||
"Tool call {} full args: {}",
|
|
||||||
index, entry.2
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if this choice is finished
|
|
||||||
if choice.finish_reason.is_some() {
|
|
||||||
debug!(
|
|
||||||
"Choice finished with reason: {:?}",
|
|
||||||
choice.finish_reason
|
|
||||||
);
|
|
||||||
|
|
||||||
// Convert accumulated tool calls to final format
|
|
||||||
let final_tool_calls: Vec<ToolCall> = current_tool_calls.values()
|
|
||||||
.filter(|(_, name, _)| !name.is_empty()) // Only include tool calls with names
|
|
||||||
.map(|(id, name, args)| {
|
|
||||||
debug!("Converting tool call: id='{}', name='{}', args_len={}", id, name, args.len());
|
|
||||||
ToolCall {
|
|
||||||
id: if id.is_empty() { format!("tool_{}", name) } else { id.clone() },
|
|
||||||
tool: name.clone(),
|
|
||||||
args: serde_json::from_str(args).unwrap_or_else(|e| {
|
|
||||||
debug!("Failed to parse tool args (len={}): {}", args.len(), e);
|
|
||||||
// For debugging, log a sample of the args if they're very long
|
|
||||||
if args.len() > 1000 {
|
|
||||||
debug!("Tool args sample (first 500 chars): {}", &args[..500]);
|
|
||||||
} else {
|
|
||||||
debug!("Full tool args: {}", args);
|
|
||||||
}
|
|
||||||
serde_json::Value::Object(serde_json::Map::new())
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
debug!(
|
|
||||||
"Final tool calls count: {}",
|
|
||||||
final_tool_calls.len()
|
|
||||||
);
|
|
||||||
|
|
||||||
let final_chunk = CompletionChunk {
|
|
||||||
content: String::new(),
|
|
||||||
finished: true,
|
|
||||||
usage: accumulated_usage.clone(),
|
|
||||||
tool_calls: if final_tool_calls.is_empty() {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(final_tool_calls)
|
|
||||||
},
|
|
||||||
};
|
|
||||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
|
||||||
debug!("Receiver dropped, stopping stream");
|
|
||||||
}
|
|
||||||
return accumulated_usage;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
// Check if this is likely an incomplete JSON due to line splitting
|
|
||||||
// Common indicators: unexpected EOF, unterminated string, etc.
|
|
||||||
let error_str = e.to_string().to_lowercase();
|
|
||||||
if line.starts_with("data: ")
|
|
||||||
&& (error_str.contains("eof") ||
|
|
||||||
error_str.contains("unterminated") ||
|
|
||||||
error_str.contains("unexpected end") ||
|
|
||||||
error_str.contains("trailing") ||
|
|
||||||
// Also check if the data doesn't end with a proper JSON terminator
|
|
||||||
(!data.trim_end().ends_with('}') && !data.trim_end().ends_with(']')))
|
|
||||||
{
|
|
||||||
// This looks like an incomplete data line, save it for the next chunk
|
|
||||||
debug!("Detected incomplete data line (len={}), buffering for next chunk", line.len());
|
|
||||||
incomplete_data_line = line.clone();
|
|
||||||
// Continue to next iteration without processing
|
|
||||||
continue;
|
|
||||||
} else {
|
|
||||||
// This is a real parse error, not due to line splitting
|
|
||||||
debug!("Failed to parse Databricks stream chunk JSON: {} - Data length: {}", e, data.len());
|
|
||||||
// For debugging large payloads, log a sample
|
|
||||||
if data.len() > 1000 {
|
|
||||||
debug!(
|
|
||||||
"JSON parse error - data sample: {}",
|
|
||||||
&data[..std::cmp::min(500, data.len())]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Don't error out on parse failures, just continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if line.starts_with("event: ") || line.starts_with("id: ") {
|
|
||||||
// Debug: Log non-data SSE lines (like event: or id:)
|
|
||||||
debug!("Non-data SSE line: {}", line);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Stream error at chunk {}: {}", chunk_count, e);
|
error!("Stream error at chunk {}: {}", chunk_count, e);
|
||||||
|
let is_connection_error = e.to_string().contains("unexpected EOF")
|
||||||
// Check if this is a connection error that might be recoverable
|
|| e.to_string().contains("connection");
|
||||||
let error_msg = e.to_string();
|
if is_connection_error {
|
||||||
if error_msg.contains("unexpected EOF") || error_msg.contains("connection") {
|
warn!("Connection terminated unexpectedly, treating as end of stream");
|
||||||
warn!("Connection terminated unexpectedly at chunk {}, treating as end of stream", chunk_count);
|
|
||||||
// Don't send error, just break and finalize
|
|
||||||
break;
|
break;
|
||||||
} else {
|
|
||||||
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
|
|
||||||
}
|
}
|
||||||
return accumulated_usage;
|
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
chunk_count += 1;
|
||||||
|
byte_buffer.extend_from_slice(&chunk);
|
||||||
|
|
||||||
|
// Decode UTF-8, handling incomplete sequences
|
||||||
|
let Some(chunk_str) = decode_utf8_streaming(&mut byte_buffer) else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
buffer.push_str(&chunk_str);
|
||||||
|
|
||||||
|
// Process complete lines
|
||||||
|
while let Some(line_end) = buffer.find('\n') {
|
||||||
|
let line = buffer[..line_end].trim().to_string();
|
||||||
|
buffer.drain(..line_end + 1);
|
||||||
|
|
||||||
|
if line.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reassemble lines split across chunks
|
||||||
|
let line = if !incomplete_data_line.is_empty() {
|
||||||
|
let complete = format!("{}{}", incomplete_data_line, line);
|
||||||
|
incomplete_data_line.clear();
|
||||||
|
complete
|
||||||
|
} else {
|
||||||
|
line
|
||||||
|
};
|
||||||
|
|
||||||
|
// Parse SSE data lines
|
||||||
|
let Some(data) = line.strip_prefix("data: ") else {
|
||||||
|
if line.starts_with("event: ") || line.starts_with("id: ") {
|
||||||
|
debug!("SSE control line: {}", line);
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Stream completion marker
|
||||||
|
if data == "[DONE]" {
|
||||||
|
debug!("Received stream completion marker");
|
||||||
|
let final_calls = finalize_tool_calls(tool_calls);
|
||||||
|
let _ = tx.send(Ok(make_final_chunk(final_calls, None))).await;
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse JSON payload
|
||||||
|
let parsed = match serde_json::from_str::<DatabricksStreamChunk>(data) {
|
||||||
|
Ok(c) => c,
|
||||||
|
Err(e) => {
|
||||||
|
if is_incomplete_json_error(&e, data) {
|
||||||
|
debug!("Incomplete JSON, buffering for next chunk");
|
||||||
|
incomplete_data_line = line;
|
||||||
|
} else {
|
||||||
|
debug!("JSON parse error: {}", e);
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Process choices from the chunk
|
||||||
|
let Some(choices) = parsed.choices else { continue };
|
||||||
|
for choice in choices {
|
||||||
|
// Handle delta content
|
||||||
|
if let Some(delta) = &choice.delta {
|
||||||
|
// Text content
|
||||||
|
if let Some(ref content) = delta.content {
|
||||||
|
let text_chunk = CompletionChunk {
|
||||||
|
content: content.clone(),
|
||||||
|
finished: false,
|
||||||
|
usage: None,
|
||||||
|
tool_calls: None,
|
||||||
|
};
|
||||||
|
if tx.send(Ok(text_chunk)).await.is_err() {
|
||||||
|
debug!("Receiver dropped");
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tool call deltas
|
||||||
|
if let Some(ref deltas) = delta.tool_calls {
|
||||||
|
for tc_delta in deltas {
|
||||||
|
let idx = tc_delta.index.unwrap_or(0);
|
||||||
|
tool_calls
|
||||||
|
.entry(idx)
|
||||||
|
.or_default()
|
||||||
|
.apply_delta(tc_delta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Choice finished
|
||||||
|
if choice.finish_reason.is_some() {
|
||||||
|
debug!("Choice finished: {:?}", choice.finish_reason);
|
||||||
|
let final_calls = finalize_tool_calls(std::mem::take(&mut tool_calls));
|
||||||
|
let _ = tx.send(Ok(make_final_chunk(final_calls, None))).await;
|
||||||
|
return None;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Log final state
|
|
||||||
debug!("Stream ended after {} chunks", chunk_count);
|
debug!("Stream ended after {} chunks", chunk_count);
|
||||||
debug!(
|
let final_calls = finalize_tool_calls(tool_calls);
|
||||||
"Final state: buffer_len={}, incomplete_data_line_len={}, byte_buffer_len={}",
|
let _ = tx.send(Ok(make_final_chunk(final_calls, None))).await;
|
||||||
buffer.len(),
|
None
|
||||||
incomplete_data_line.len(),
|
|
||||||
byte_buffer.len()
|
|
||||||
);
|
|
||||||
debug!("Accumulated tool calls: {}", current_tool_calls.len());
|
|
||||||
|
|
||||||
// If we have any remaining data in buffers, log it for debugging
|
|
||||||
if !buffer.is_empty() {
|
|
||||||
debug!("Remaining buffer content: {:?}", buffer);
|
|
||||||
}
|
|
||||||
if !byte_buffer.is_empty() {
|
|
||||||
debug!("Remaining byte buffer: {} bytes", byte_buffer.len());
|
|
||||||
}
|
|
||||||
if !incomplete_data_line.is_empty() {
|
|
||||||
debug!("Remaining incomplete data line: {:?}", incomplete_data_line);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we have any incomplete data line at the end, try to process it
|
|
||||||
if !incomplete_data_line.is_empty() {
|
|
||||||
debug!(
|
|
||||||
"Processing final incomplete data line (len={})",
|
|
||||||
incomplete_data_line.len()
|
|
||||||
);
|
|
||||||
if let Some(data) = incomplete_data_line.strip_prefix("data: ") {
|
|
||||||
// Try to parse it as-is, it might be complete
|
|
||||||
if let Ok(_chunk) = serde_json::from_str::<DatabricksStreamChunk>(data) {
|
|
||||||
// Process the chunk (code would be duplicated from above, so in practice
|
|
||||||
// we'd extract this to a helper function)
|
|
||||||
debug!("Successfully parsed final incomplete data line");
|
|
||||||
} else {
|
|
||||||
warn!("Failed to parse final incomplete data line");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send final chunk if we haven't already
|
|
||||||
let final_tool_calls: Vec<ToolCall> = current_tool_calls
|
|
||||||
.values()
|
|
||||||
.filter(|(_, name, _)| !name.is_empty())
|
|
||||||
.map(|(id, name, args)| ToolCall {
|
|
||||||
id: if id.is_empty() {
|
|
||||||
format!("tool_{}", name)
|
|
||||||
} else {
|
|
||||||
id.clone()
|
|
||||||
},
|
|
||||||
tool: name.clone(),
|
|
||||||
args: serde_json::from_str(args)
|
|
||||||
.unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let final_chunk = CompletionChunk {
|
|
||||||
content: String::new(),
|
|
||||||
finished: true,
|
|
||||||
usage: accumulated_usage.clone(),
|
|
||||||
tool_calls: if final_tool_calls.is_empty() {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(final_tool_calls)
|
|
||||||
},
|
|
||||||
};
|
|
||||||
let _ = tx.send(Ok(final_chunk)).await;
|
|
||||||
accumulated_usage
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn fetch_supported_models(&mut self) -> Result<Option<Vec<String>>> {
|
pub async fn fetch_supported_models(&mut self) -> Result<Option<Vec<String>>> {
|
||||||
|
|||||||
Reference in New Issue
Block a user