From bb630507795a237ed188e618dcd68eabbb5b71e1 Mon Sep 17 00:00:00 2001
From: "Dhanji R. Prasanna"
Date: Wed, 7 Jan 2026 12:39:05 +1100
Subject: [PATCH] refactor: improve readability of streaming and file ops code

Agent: carmack

databricks.rs:
- Extract ToolCallAccumulator struct to replace opaque (String, String, String) tuple
- Add decode_utf8_streaming() helper for cleaner UTF-8 handling
- Add is_incomplete_json_error() helper for JSON parse error detection
- Add make_final_chunk() helper to reduce duplication
- Add finalize_tool_calls() to convert accumulators to the final format
- Refactor parse_streaming_response from ~270 lines to ~100 lines
- Reduce nesting depth from 8+ levels to 4 levels
- Use early returns and let-else for cleaner control flow

file_ops.rs:
- Replace repetitive if-let chains with a declarative PATH_CONTENT_KEYS table
- Use a match expression instead of nested if-else
- Reduce extract_path_and_content from 44 lines to 20 lines

All tests pass. Behavior unchanged.
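For reviewers, a rough sketch of how the new accumulator is meant to be
driven (the delta payloads are invented for illustration; the real
definitions are in the diff below):

    // Each streamed fragment lands in the accumulator for its index;
    // a final ToolCall is produced only once a name has arrived.
    let mut acc = ToolCallAccumulator::default();
    acc.apply_delta(&first_delta);   // carries id + tool name
    acc.apply_delta(&second_delta);  // appends another slice of the JSON args
    let call = acc.into_tool_call(); // Some(ToolCall { .. }) or None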
---
 crates/g3-core/src/tools/file_ops.rs  |  61 ++-
 crates/g3-providers/src/databricks.rs | 578 ++++++++++----------
 2 files changed, 233 insertions(+), 406 deletions(-)

diff --git a/crates/g3-core/src/tools/file_ops.rs b/crates/g3-core/src/tools/file_ops.rs
index 6ed3e73..81164d9 100644
--- a/crates/g3-core/src/tools/file_ops.rs
+++ b/crates/g3-core/src/tools/file_ops.rs
@@ -345,48 +345,35 @@ pub async fn execute_str_replace(
 
 // Helper functions
 
+/// Known argument key pairs for path and content.
+const PATH_CONTENT_KEYS: &[(&str, &str)] = &[
+    ("file_path", "content"), // Standard format
+    ("path", "content"),      // Anthropic-style
+    ("filename", "text"),     // Alternative naming
+    ("file", "data"),         // Alternative naming
+];
+
 /// Extract path and content from various argument formats.
 fn extract_path_and_content(args: &serde_json::Value) -> (Option<&str>, Option<&str>) {
-    if let Some(args_obj) = args.as_object() {
-        // Format 1: Standard format with file_path and content
-        if let (Some(path_val), Some(content_val)) =
-            (args_obj.get("file_path"), args_obj.get("content"))
-        {
-            if let (Some(path), Some(content)) = (path_val.as_str(), content_val.as_str()) {
-                return (Some(path), Some(content));
-            }
-        }
-        // Format 2: Anthropic-style with path and content
-        if let (Some(path_val), Some(content_val)) =
-            (args_obj.get("path"), args_obj.get("content"))
-        {
-            if let (Some(path), Some(content)) = (path_val.as_str(), content_val.as_str()) {
-                return (Some(path), Some(content));
-            }
-        }
-        // Format 3: Alternative naming with filename and text
-        if let (Some(path_val), Some(content_val)) =
-            (args_obj.get("filename"), args_obj.get("text"))
-        {
-            if let (Some(path), Some(content)) = (path_val.as_str(), content_val.as_str()) {
-                return (Some(path), Some(content));
-            }
-        }
-        // Format 4: Alternative naming with file and data
-        if let (Some(path_val), Some(content_val)) = (args_obj.get("file"), args_obj.get("data")) {
-            if let (Some(path), Some(content)) = (path_val.as_str(), content_val.as_str()) {
-                return (Some(path), Some(content));
-            }
-        }
-    } else if let Some(args_array) = args.as_array() {
-        // Format 5: Args might be an array [path, content]
-        if args_array.len() >= 2 {
-            if let (Some(path), Some(content)) = (args_array[0].as_str(), args_array[1].as_str()) {
-                return (Some(path), Some(content));
+    match args {
+        serde_json::Value::Object(obj) => {
+            for &(path_key, content_key) in PATH_CONTENT_KEYS {
+                if let (Some(p), Some(c)) = (obj.get(path_key), obj.get(content_key)) {
+                    if let (Some(path), Some(content)) = (p.as_str(), c.as_str()) {
+                        return (Some(path), Some(content));
+                    }
+                }
+            }
+            (None, None)
+        }
+        serde_json::Value::Array(arr) if arr.len() >= 2 => {
+            match (arr[0].as_str(), arr[1].as_str()) {
+                (Some(path), Some(content)) => (Some(path), Some(content)),
+                _ => (None, None),
             }
         }
+        _ => (None, None),
     }
-    (None, None)
 }
 
 /// Get image dimensions from raw bytes.
diff --git a/crates/g3-providers/src/databricks.rs b/crates/g3-providers/src/databricks.rs
index 95fda57..ec70be2 100644
--- a/crates/g3-providers/src/databricks.rs
+++ b/crates/g3-providers/src/databricks.rs
@@ -58,6 +58,7 @@
 use anyhow::{anyhow, Result};
 use bytes::Bytes;
+use std::collections::HashMap;
 use futures_util::stream::StreamExt;
 use reqwest::{Client, RequestBuilder};
 use serde::{Deserialize, Serialize};
@@ -71,6 +72,101 @@ use crate::{
     MessageRole, Tool, ToolCall, Usage,
 };
 
+// ─────────────────────────────────────────────────────────────────────────────
+// Streaming helpers
+// ─────────────────────────────────────────────────────────────────────────────
+
+/// Accumulated state for a single tool call being streamed in chunks.
+#[derive(Default)]
+struct ToolCallAccumulator {
+    id: String,
+    name: String,
+    args: String,
+}
+
+impl ToolCallAccumulator {
+    /// Update accumulator with a streaming delta.
+    fn apply_delta(&mut self, delta: &DatabricksStreamToolCall) {
+        if let Some(ref id) = delta.id {
+            self.id = id.clone();
+        }
+        if !delta.function.name.is_empty() {
+            self.name = delta.function.name.clone();
+        }
+        self.args.push_str(&delta.function.arguments);
+    }
+
+    /// Convert to final ToolCall if valid (has a name).
+    fn into_tool_call(self) -> Option<ToolCall> {
+        if self.name.is_empty() {
+            return None;
+        }
+        let id = if self.id.is_empty() {
+            format!("tool_{}", self.name)
+        } else {
+            self.id
+        };
+        let args = serde_json::from_str(&self.args)
+            .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new()));
+        Some(ToolCall { id, tool: self.name, args })
+    }
+}
+
+/// Convert accumulated tool calls map to final Vec.
+fn finalize_tool_calls(accumulators: HashMap<u32, ToolCallAccumulator>) -> Vec<ToolCall> {
+    accumulators
+        .into_values()
+        .filter_map(|acc| acc.into_tool_call())
+        .collect()
+}
+
+/// Try to decode bytes as UTF-8, handling incomplete sequences at the end.
+/// Returns the decoded string and leaves any incomplete bytes in the buffer.
+fn decode_utf8_streaming(byte_buffer: &mut Vec<u8>) -> Option<String> {
+    match std::str::from_utf8(byte_buffer) {
+        Ok(s) => {
+            let result = s.to_string();
+            byte_buffer.clear();
+            Some(result)
+        }
+        Err(e) => {
+            let valid_up_to = e.valid_up_to();
+            if valid_up_to > 0 {
+                let valid_bytes: Vec<u8> = byte_buffer.drain(..valid_up_to).collect();
+                // Safe: we just validated these bytes
+                Some(String::from_utf8(valid_bytes).unwrap())
+            } else {
+                None // No valid UTF-8 yet, wait for more bytes
+            }
+        }
+    }
+}
+
+/// Check if a JSON parse error indicates incomplete data (vs. malformed JSON).
+fn is_incomplete_json_error(error: &serde_json::Error, data: &str) -> bool {
+    let msg = error.to_string().to_lowercase();
+    let looks_incomplete = msg.contains("eof")
+        || msg.contains("unterminated")
+        || msg.contains("unexpected end")
+        || msg.contains("trailing");
+    let missing_terminator = !data.trim_end().ends_with('}') && !data.trim_end().ends_with(']');
+    looks_incomplete || missing_terminator
+}
+
+/// Create a final completion chunk with tool calls and usage.
+fn make_final_chunk(tool_calls: Vec<ToolCall>, usage: Option<Usage>) -> CompletionChunk {
+    CompletionChunk {
+        content: String::new(),
+        finished: true,
+        usage,
+        tool_calls: if tool_calls.is_empty() {
+            None
+        } else {
+            Some(tool_calls)
+        },
+    }
+}
+
 const DEFAULT_CLIENT_ID: &str = "databricks-cli";
 const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020";
 const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"];
@@ -352,387 +448,131 @@ impl DatabricksProvider {
         tx: mpsc::Sender<Result<CompletionChunk>>,
     ) -> Option<Usage> {
         let mut buffer = String::new();
-        let mut current_tool_calls: std::collections::HashMap<u32, (String, String, String)> =
-            std::collections::HashMap::new(); // index -> (id, name, args)
-        let mut incomplete_data_line = String::new(); // Buffer for incomplete data: lines
+        let mut tool_calls: HashMap<u32, ToolCallAccumulator> = HashMap::new();
+        let mut incomplete_data_line = String::new();
         let mut chunk_count = 0;
-        let accumulated_usage: Option<Usage> = None;
-        let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
+        let mut byte_buffer = Vec::new();
 
         while let Some(chunk_result) = stream.next().await {
-            match chunk_result {
-                Ok(chunk) => {
-                    // Debug: Log raw bytes received
-                    chunk_count += 1;
-                    debug!("Processing chunk #{}", chunk_count);
-                    debug!("Raw SSE bytes received: {} bytes", chunk.len());
-
-                    // Append new bytes to our buffer
-                    byte_buffer.extend_from_slice(&chunk);
-
-                    // Try to convert the entire buffer to UTF-8
-                    let chunk_str = match std::str::from_utf8(&byte_buffer) {
-                        Ok(s) => {
-                            // Successfully converted entire buffer, clear it and use the string
-                            let result = s.to_string();
-                            byte_buffer.clear();
-                            result
-                        }
-                        Err(e) => {
-                            // Check if this is an incomplete sequence at the end
-                            let valid_up_to = e.valid_up_to();
-                            if valid_up_to > 0 {
-                                // We have some valid UTF-8, extract it and keep the rest for next iteration
-                                let valid_bytes =
-                                    byte_buffer.drain(..valid_up_to).collect::<Vec<u8>>();
-                                std::str::from_utf8(&valid_bytes).unwrap().to_string()
-                            } else {
-                                // No valid UTF-8 at all, skip this chunk and continue
-                                continue;
-                            }
-                        }
-                    };
-
-                    // Debug: Log raw string content (truncated for large chunks)
-                    if chunk_str.len() > 1000 {
-                        debug!(
-                            "Raw SSE string content (first 500 chars): {:?}...",
-                            &chunk_str[..500]
-                        );
-                    } else {
-                        debug!("Raw SSE string content: {:?}", chunk_str);
-                    }
-
-                    buffer.push_str(&chunk_str);
-
-                    // Process complete lines, but handle incomplete data: lines specially
-                    while let Some(line_end) = buffer.find('\n') {
-                        let line = buffer[..line_end].trim().to_string();
-                        buffer.drain(..line_end + 1);
-
-                        if line.is_empty() {
-                            continue;
-                        }
-
-                        // Check if we have an incomplete data line from previous chunk
-                        let line = if !incomplete_data_line.is_empty() {
-                            // We had an incomplete data: line, append this line to it
-                            let complete_line = format!("{}{}", incomplete_data_line, line);
-                            incomplete_data_line.clear();
-                            complete_line
-                        } else {
-                            line
-                        };
-
-                        // Check if this is a data: line that might be incomplete
-                        // SSE format requires double newline after data, so if we don't see another newline
-                        // after this one in the buffer, and it's a data: line, it might be incomplete
-                        if line.starts_with("data: ") {
-                            // Check if there's a complete SSE event (should have double newline after data)
-                            // But for streaming, single newline is often used, so we need to be careful
-                            // The safest approach is to try parsing and if it fails due to incomplete JSON,
-                            // we'll handle it below
-                        }
-
-                        // Debug: Log each SSE line (truncated for large lines)
-                        if line.len() > 1000 {
-                            debug!("SSE line (first 500 chars): {:?}...", &line[..500]);
-                        } else {
-                            debug!("SSE line: {:?}", line);
-                        }
-
-                        // Parse Server-Sent Events format
-                        if let Some(data) = line.strip_prefix("data: ") {
-                            if data == "[DONE]" {
-                                debug!("Received stream completion marker");
-                                let final_tool_calls: Vec<ToolCall> = current_tool_calls
-                                    .values()
-                                    .map(|(id, name, args)| ToolCall {
-                                        id: id.clone(),
-                                        tool: name.clone(),
-                                        args: serde_json::from_str(args).unwrap_or(
-                                            serde_json::Value::Object(serde_json::Map::new()),
-                                        ),
-                                    })
-                                    .collect();
-                                let final_chunk = CompletionChunk {
-                                    content: String::new(),
-                                    finished: true,
-                                    usage: accumulated_usage.clone(),
-                                    tool_calls: if final_tool_calls.is_empty() {
-                                        None
-                                    } else {
-                                        Some(final_tool_calls)
-                                    },
-                                };
-                                if tx.send(Ok(final_chunk)).await.is_err() {
-                                    debug!("Receiver dropped, stopping stream");
-                                }
-                                return accumulated_usage;
-                            }
-
-                            // Debug: Log every raw JSON payload from Databricks API (truncated for large payloads)
-                            if data.len() > 1000 {
-                                debug!(
-                                    "Raw Databricks SSE JSON payload (first 500 chars): {}...",
-                                    &data[..500]
-                                );
-                            } else {
-                                debug!("Raw Databricks SSE JSON payload: {}", data);
-                            }
-
-                            match serde_json::from_str::<DatabricksStreamChunk>(data) {
-                                Ok(chunk) => {
-                                    debug!("Successfully parsed Databricks stream chunk");
-
-                                    // Handle different types of chunks
-                                    if let Some(choices) = chunk.choices {
-                                        for choice in choices {
-                                            if let Some(delta) = choice.delta {
-                                                // Handle text content
-                                                if let Some(content) = delta.content {
-                                                    debug!("Sending text chunk: '{}'", content);
-                                                    let chunk = CompletionChunk {
-                                                        content,
-                                                        finished: false,
-                                                        usage: None,
-                                                        tool_calls: None,
-                                                    };
-                                                    if tx.send(Ok(chunk)).await.is_err() {
-                                                        debug!("Receiver dropped, stopping stream");
-                                                        return accumulated_usage;
-                                                    }
-                                                }
-
-                                                // Handle tool calls - accumulate across chunks
-                                                if let Some(tool_calls) = delta.tool_calls {
-                                                    debug!(
-                                                        "Processing {} tool call deltas",
-                                                        tool_calls.len()
-                                                    );
-                                                    for tool_call in tool_calls {
-                                                        let index = tool_call.index.unwrap_or(0);
-                                                        debug!("Tool call delta for index {}: id={:?}, name='{}', args_len={}",
-                                                            index, tool_call.id, tool_call.function.name, tool_call.function.arguments.len());
-
-                                                        let entry = current_tool_calls
-                                                            .entry(index)
-                                                            .or_insert_with(|| {
-                                                                (
-                                                                    String::new(),
-                                                                    String::new(),
-                                                                    String::new(),
-                                                                )
-                                                            });
-
-                                                        // Update ID if provided
-                                                        if let Some(id) = tool_call.id {
-                                                            debug!("Updating tool call {} ID from '{}' to '{}'", index, entry.0, id);
-                                                            entry.0 = id;
-                                                        }
-
-                                                        // Update name if provided and not empty
-                                                        if !tool_call.function.name.is_empty() {
-                                                            debug!("Updating tool call {} name from '{}' to '{}'", index, entry.1, tool_call.function.name);
-                                                            entry.1 = tool_call.function.name;
-                                                        }
-
-                                                        // Append arguments
-                                                        debug!("Appending {} chars to tool call {} args (current len: {})",
-                                                            tool_call.function.arguments.len(), index, entry.2.len());
-                                                        entry.2.push_str(
-                                                            &tool_call.function.arguments,
-                                                        );
-
-                                                        debug!("Accumulated tool call {}: id='{}', name='{}', args_len={}",
-                                                            index, entry.0, entry.1, entry.2.len());
-
-                                                        // Debug: Show a sample of the accumulated args if they're getting long
-                                                        if entry.2.len() > 100 {
-                                                            debug!("Tool call {} args sample (first 100 chars): {}", index, &entry.2[..100]);
-                                                        } else if !entry.2.is_empty() {
-                                                            debug!(
-                                                                "Tool call {} full args: {}",
-                                                                index, entry.2
-                                                            );
-                                                        }
-                                                    }
-                                                }
-                                            }
-
-                                            // Check if this choice is finished
-                                            if choice.finish_reason.is_some() {
-                                                debug!(
-                                                    "Choice finished with reason: {:?}",
-                                                    choice.finish_reason
-                                                );
-
-                                                // Convert accumulated tool calls to final format
-                                                let final_tool_calls: Vec<ToolCall> = current_tool_calls.values()
-                                                    .filter(|(_, name, _)| !name.is_empty()) // Only include tool calls with names
-                                                    .map(|(id, name, args)| {
-                                                        debug!("Converting tool call: id='{}', name='{}', args_len={}", id, name, args.len());
-                                                        ToolCall {
-                                                            id: if id.is_empty() { format!("tool_{}", name) } else { id.clone() },
-                                                            tool: name.clone(),
-                                                            args: serde_json::from_str(args).unwrap_or_else(|e| {
-                                                                debug!("Failed to parse tool args (len={}): {}", args.len(), e);
-                                                                // For debugging, log a sample of the args if they're very long
-                                                                if args.len() > 1000 {
-                                                                    debug!("Tool args sample (first 500 chars): {}", &args[..500]);
-                                                                } else {
-                                                                    debug!("Full tool args: {}", args);
-                                                                }
-                                                                serde_json::Value::Object(serde_json::Map::new())
-                                                            }),
-                                                        }
-                                                    })
-                                                    .collect();
-
-                                                debug!(
-                                                    "Final tool calls count: {}",
-                                                    final_tool_calls.len()
-                                                );
-
-                                                let final_chunk = CompletionChunk {
-                                                    content: String::new(),
-                                                    finished: true,
-                                                    usage: accumulated_usage.clone(),
-                                                    tool_calls: if final_tool_calls.is_empty() {
-                                                        None
-                                                    } else {
-                                                        Some(final_tool_calls)
-                                                    },
-                                                };
-                                                if tx.send(Ok(final_chunk)).await.is_err() {
-                                                    debug!("Receiver dropped, stopping stream");
-                                                }
-                                                return accumulated_usage;
-                                            }
-                                        }
-                                    }
-                                }
-                                Err(e) => {
-                                    // Check if this is likely an incomplete JSON due to line splitting
-                                    // Common indicators: unexpected EOF, unterminated string, etc.
-                                    let error_str = e.to_string().to_lowercase();
-                                    if line.starts_with("data: ")
-                                        && (error_str.contains("eof") ||
-                                            error_str.contains("unterminated") ||
-                                            error_str.contains("unexpected end") ||
-                                            error_str.contains("trailing") ||
-                                            // Also check if the data doesn't end with a proper JSON terminator
-                                            (!data.trim_end().ends_with('}') && !data.trim_end().ends_with(']')))
-                                    {
-                                        // This looks like an incomplete data line, save it for the next chunk
-                                        debug!("Detected incomplete data line (len={}), buffering for next chunk", line.len());
-                                        incomplete_data_line = line.clone();
-                                        // Continue to next iteration without processing
-                                        continue;
-                                    } else {
-                                        // This is a real parse error, not due to line splitting
-                                        debug!("Failed to parse Databricks stream chunk JSON: {} - Data length: {}", e, data.len());
-                                        // For debugging large payloads, log a sample
-                                        if data.len() > 1000 {
-                                            debug!(
-                                                "JSON parse error - data sample: {}",
-                                                &data[..std::cmp::min(500, data.len())]
-                                            );
-                                        }
-                                    }
-                                    // Don't error out on parse failures, just continue
-                                }
-                            }
-                        } else if line.starts_with("event: ") || line.starts_with("id: ") {
-                            // Debug: Log non-data SSE lines (like event: or id:)
-                            debug!("Non-data SSE line: {}", line);
-                        }
-                    }
-                }
+            // Handle stream errors
+            let chunk = match chunk_result {
+                Ok(c) => c,
                 Err(e) => {
                     error!("Stream error at chunk {}: {}", chunk_count, e);
-
-                    // Check if this is a connection error that might be recoverable
-                    let error_msg = e.to_string();
-                    if error_msg.contains("unexpected EOF") || error_msg.contains("connection") {
-                        warn!("Connection terminated unexpectedly at chunk {}, treating as end of stream", chunk_count);
-                        // Don't send error, just break and finalize
+                    let is_connection_error = e.to_string().contains("unexpected EOF")
+                        || e.to_string().contains("connection");
+                    if is_connection_error {
+                        warn!("Connection terminated unexpectedly, treating as end of stream");
                         break;
-                    } else {
-                        let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
                     }
-                    return accumulated_usage;
+                    let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
+                    return None;
+                }
+            };
+
+            chunk_count += 1;
+            byte_buffer.extend_from_slice(&chunk);
+
+            // Decode UTF-8, handling incomplete sequences
+            let Some(chunk_str) = decode_utf8_streaming(&mut byte_buffer) else {
+                continue;
+            };
+            buffer.push_str(&chunk_str);
+
+            // Process complete lines
+            while let Some(line_end) = buffer.find('\n') {
+                let line = buffer[..line_end].trim().to_string();
+                buffer.drain(..line_end + 1);
+
+                if line.is_empty() {
+                    continue;
+                }
+
+                // Reassemble lines split across chunks
+                let line = if !incomplete_data_line.is_empty() {
+                    let complete = format!("{}{}", incomplete_data_line, line);
+                    incomplete_data_line.clear();
+                    complete
+                } else {
+                    line
+                };
+
+                // Parse SSE data lines
+                let Some(data) = line.strip_prefix("data: ") else {
+                    if line.starts_with("event: ") || line.starts_with("id: ") {
+                        debug!("SSE control line: {}", line);
+                    }
+                    continue;
+                };
+
+                // Stream completion marker
+                if data == "[DONE]" {
+                    debug!("Received stream completion marker");
+                    let final_calls = finalize_tool_calls(tool_calls);
+                    let _ = tx.send(Ok(make_final_chunk(final_calls, None))).await;
+                    return None;
+                }
+
+                // Parse JSON payload
+                let parsed = match serde_json::from_str::<DatabricksStreamChunk>(data) {
+                    Ok(c) => c,
+                    Err(e) => {
+                        if is_incomplete_json_error(&e, data) {
+                            debug!("Incomplete JSON, buffering for next chunk");
+                            incomplete_data_line = line;
+                        } else {
+                            debug!("JSON parse error: {}", e);
+                        }
+                        continue;
+                    }
+                };
+
+                // Process choices from the chunk
+                let Some(choices) = parsed.choices else { continue };
+                for choice in choices {
+                    // Handle delta content
+                    if let Some(delta) = &choice.delta {
+                        // Text content
+                        if let Some(ref content) = delta.content {
+                            let text_chunk = CompletionChunk {
+                                content: content.clone(),
+                                finished: false,
+                                usage: None,
+                                tool_calls: None,
+                            };
+                            if tx.send(Ok(text_chunk)).await.is_err() {
+                                debug!("Receiver dropped");
+                                return None;
+                            }
+                        }
+
+                        // Tool call deltas
+                        if let Some(ref deltas) = delta.tool_calls {
+                            for tc_delta in deltas {
+                                let idx = tc_delta.index.unwrap_or(0);
+                                tool_calls
+                                    .entry(idx)
+                                    .or_default()
+                                    .apply_delta(tc_delta);
+                            }
+                        }
+                    }
+
+                    // Choice finished
+                    if choice.finish_reason.is_some() {
+                        debug!("Choice finished: {:?}", choice.finish_reason);
+                        let final_calls = finalize_tool_calls(std::mem::take(&mut tool_calls));
+                        let _ = tx.send(Ok(make_final_chunk(final_calls, None))).await;
+                        return None;
+                    }
                 }
             }
         }
 
-        // Log final state
         debug!("Stream ended after {} chunks", chunk_count);
-        debug!(
-            "Final state: buffer_len={}, incomplete_data_line_len={}, byte_buffer_len={}",
-            buffer.len(),
-            incomplete_data_line.len(),
-            byte_buffer.len()
-        );
-        debug!("Accumulated tool calls: {}", current_tool_calls.len());
-
-        // If we have any remaining data in buffers, log it for debugging
-        if !buffer.is_empty() {
-            debug!("Remaining buffer content: {:?}", buffer);
-        }
-        if !byte_buffer.is_empty() {
-            debug!("Remaining byte buffer: {} bytes", byte_buffer.len());
-        }
-        if !incomplete_data_line.is_empty() {
-            debug!("Remaining incomplete data line: {:?}", incomplete_data_line);
-        }
-
-        // If we have any incomplete data line at the end, try to process it
-        if !incomplete_data_line.is_empty() {
-            debug!(
-                "Processing final incomplete data line (len={})",
-                incomplete_data_line.len()
-            );
-            if let Some(data) = incomplete_data_line.strip_prefix("data: ") {
-                // Try to parse it as-is, it might be complete
-                if let Ok(_chunk) = serde_json::from_str::<DatabricksStreamChunk>(data) {
-                    // Process the chunk (code would be duplicated from above, so in practice
-                    // we'd extract this to a helper function)
-                    debug!("Successfully parsed final incomplete data line");
-                } else {
-                    warn!("Failed to parse final incomplete data line");
-                }
-            }
-        }
-
-        // Send final chunk if we haven't already
-        let final_tool_calls: Vec<ToolCall> = current_tool_calls
-            .values()
-            .filter(|(_, name, _)| !name.is_empty())
-            .map(|(id, name, args)| ToolCall {
-                id: if id.is_empty() {
-                    format!("tool_{}", name)
-                } else {
-                    id.clone()
-                },
-                tool: name.clone(),
-                args: serde_json::from_str(args)
-                    .unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
-            })
-            .collect();
-
-        let final_chunk = CompletionChunk {
-            content: String::new(),
-            finished: true,
-            usage: accumulated_usage.clone(),
-            tool_calls: if final_tool_calls.is_empty() {
-                None
-            } else {
-                Some(final_tool_calls)
-            },
-        };
-        let _ = tx.send(Ok(final_chunk)).await;
-        accumulated_usage
+        let final_calls = finalize_tool_calls(tool_calls);
+        let _ = tx.send(Ok(make_final_chunk(final_calls, None))).await;
+        None
     }
 
     pub async fn fetch_supported_models(&mut self) -> Result<Option<Vec<String>>> {
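
Reviewer note (not part of the patch): two test sketches that could sit next to
the new helpers to pin down the behavior described above. The assertions are
derived from the functions as written in the diff; the test names are invented.

    // In g3-providers, alongside decode_utf8_streaming:
    #[test]
    fn decode_utf8_streaming_handles_split_sequences() {
        // A lone leading byte of a multi-byte char decodes to nothing
        // and is retained for the next chunk.
        let mut buf: Vec<u8> = vec![0xC3];
        assert_eq!(decode_utf8_streaming(&mut buf), None);
        assert_eq!(buf, vec![0xC3]);

        // Once the second byte of "é" (0xC3 0xA9) arrives, the char
        // decodes and the buffer is drained.
        buf.push(0xA9);
        assert_eq!(decode_utf8_streaming(&mut buf), Some("é".to_string()));
        assert!(buf.is_empty());

        // Valid prefix followed by an incomplete sequence: the valid part
        // is returned and the partial bytes stay buffered.
        let mut buf = b"ok\xC3".to_vec();
        assert_eq!(decode_utf8_streaming(&mut buf), Some("ok".to_string()));
        assert_eq!(buf, vec![0xC3]);
    }

    // In g3-core, alongside extract_path_and_content:
    #[test]
    fn extract_path_and_content_accepts_known_shapes() {
        // Any key pair from the PATH_CONTENT_KEYS table is accepted.
        let obj = serde_json::json!({"path": "a.txt", "content": "hi"});
        assert_eq!(extract_path_and_content(&obj), (Some("a.txt"), Some("hi")));

        // Positional [path, content] arrays still work.
        let arr = serde_json::json!(["b.txt", "data"]);
        assert_eq!(extract_path_and_content(&arr), (Some("b.txt"), Some("data")));

        // Unknown shapes fall through to (None, None).
        let other = serde_json::json!({"unrelated": true});
        assert_eq!(extract_path_and_content(&other), (None, None));
    }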