diff --git a/Cargo.lock b/Cargo.lock index 8af1cd2..1db310c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -850,6 +850,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "chrono", "futures-util", "g3-config", "g3-execution", diff --git a/crates/g3-config/src/lib.rs b/crates/g3-config/src/lib.rs index 6d02005..2138250 100644 --- a/crates/g3-config/src/lib.rs +++ b/crates/g3-config/src/lib.rs @@ -170,7 +170,8 @@ impl Config { let config = settings.build()?.try_deserialize()?; Ok(config) } - + + #[allow(dead_code)] fn default_qwen_config() -> Self { Self { providers: ProvidersConfig { diff --git a/crates/g3-core/Cargo.toml b/crates/g3-core/Cargo.toml index 26788b8..0370647 100644 --- a/crates/g3-core/Cargo.toml +++ b/crates/g3-core/Cargo.toml @@ -22,3 +22,4 @@ llama_cpp = { version = "0.3.2", features = ["metal"] } shellexpand = "3.1" tokio-util = "0.7" futures-util = "0.3" +chrono = { version = "0.4", features = ["serde"] } diff --git a/crates/g3-core/src/lib.rs b/crates/g3-core/src/lib.rs index bf95b2a..6106bcf 100644 --- a/crates/g3-core/src/lib.rs +++ b/crates/g3-core/src/lib.rs @@ -4,7 +4,6 @@ use g3_execution::CodeExecutor; use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry, Tool}; use serde::{Deserialize, Serialize}; use serde_json::json; -use std::fs; use std::time::{Duration, Instant}; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, warn}; @@ -23,187 +22,195 @@ pub enum StreamState { Resuming, } +/// Modern streaming tool parser that properly handles native tool calls and SSE chunks #[derive(Debug)] pub struct StreamingToolParser { - buffer: String, - brace_count: i32, - in_tool_call: bool, - tool_start_pos: Option, + /// Buffer for accumulating text content + text_buffer: String, + /// Buffer for accumulating native tool calls + native_tool_calls: Vec, + /// Whether we've received a message_stop event + message_stopped: bool, + /// Whether we're currently in a JSON tool call (for fallback parsing) + in_json_tool_call: bool, + /// Start position of JSON tool call (for fallback parsing) + json_tool_start: Option, } impl StreamingToolParser { pub fn new() -> Self { Self { - buffer: String::new(), - brace_count: 0, - in_tool_call: false, - tool_start_pos: None, + text_buffer: String::new(), + native_tool_calls: Vec::new(), + message_stopped: false, + in_json_tool_call: false, + json_tool_start: None, } } - pub fn add_chunk(&mut self, chunk: &str) -> Option<(ToolCall, usize)> { - self.buffer.push_str(chunk); - //info!("Parser buffer now: {:?}", self.buffer); - self.detect_tool_call() + /// Process a streaming chunk and return completed tool calls if any + pub fn process_chunk(&mut self, chunk: &g3_providers::CompletionChunk) -> Vec { + let mut completed_tools = Vec::new(); + + // Add text content to buffer + if !chunk.content.is_empty() { + self.text_buffer.push_str(&chunk.content); + } + + // Handle native tool calls + if let Some(ref tool_calls) = chunk.tool_calls { + debug!("Received native tool calls: {:?}", tool_calls); + + // Accumulate native tool calls + for tool_call in tool_calls { + self.native_tool_calls.push(tool_call.clone()); + } + } + + // Check if message is finished/stopped + if chunk.finished { + self.message_stopped = true; + debug!("Message finished, processing accumulated tool calls"); + } + + // If we have native tool calls and the message is stopped, return them + if self.message_stopped && !self.native_tool_calls.is_empty() { + debug!( + "Converting {} native tool calls", + self.native_tool_calls.len() + ); + + for native_tool in &self.native_tool_calls { + let converted_tool = ToolCall { + tool: native_tool.tool.clone(), + args: native_tool.args.clone(), + }; + completed_tools.push(converted_tool); + } + + // Clear native tool calls after processing + self.native_tool_calls.clear(); + } + + // Fallback: Try to parse JSON tool calls from text if no native tool calls + if completed_tools.is_empty() && !chunk.content.is_empty() { + if let Some(json_tool) = self.try_parse_json_tool_call(&chunk.content) { + completed_tools.push(json_tool); + } + } + + completed_tools } - fn detect_tool_call(&mut self) -> Option<(ToolCall, usize)> { - //info!("Detecting tool call in buffer: {:?}", self.buffer); - - // Look for the start of a tool call pattern: {"tool": - if !self.in_tool_call { - // Look for JSON tool call pattern - check both raw JSON and inside code blocks - // Also handle malformed patterns and whitespace variations - let patterns = [ - r#"{"tool":"#, // Normal pattern - r#"{ "tool":"#, // Pattern with space after opening brace - r#"{"tool" :"#, // Pattern with space before colon - r#"{ "tool" :"#, // Pattern with spaces around tool - r#"{"{""tool"":"#, // Malformed pattern with extra brace and doubled quotes - r#"{{""tool"":"#, // Alternative malformed pattern - ]; + /// Fallback method to parse JSON tool calls from text content + fn try_parse_json_tool_call(&mut self, _content: &str) -> Option { + // Look for JSON tool call patterns + let patterns = [ + r#"{"tool":"#, + r#"{ "tool":"#, + r#"{"tool" :"#, + r#"{ "tool" :"#, + ]; + // If we're not currently in a JSON tool call, look for the start + if !self.in_json_tool_call { for pattern in &patterns { - if let Some(pos) = self.buffer.rfind(pattern) { - // Check if this is inside a code block - let before_pos = &self.buffer[..pos]; - let _code_block_count = before_pos.matches("```").count(); - - // Accept tool calls both inside and outside code blocks - // The LLM might use either format despite our instructions - //info!("Starting tool call parsing (code block status: {})", code_block_count % 2 == 1); - self.in_tool_call = true; - self.tool_start_pos = Some(pos); - self.brace_count = 0; // Start counting from 0, we'll count the opening brace in parsing - - // Continue parsing from after the opening brace - return self.parse_from_start_pos(pos); + if let Some(pos) = self.text_buffer.rfind(pattern) { + debug!( + "Found JSON tool call pattern '{}' at position {}", + pattern, pos + ); + self.in_json_tool_call = true; + self.json_tool_start = Some(pos); + break; } } - } else { - //info!("Already in tool call, continuing parsing"); - // We're already in a tool call, continue parsing - let start_pos = self.tool_start_pos.unwrap(); - return self.parse_from_start_pos(start_pos); } - None - } + // If we're in a JSON tool call, try to find the end and parse it + if self.in_json_tool_call { + if let Some(start_pos) = self.json_tool_start { + let json_text = &self.text_buffer[start_pos..]; - fn parse_from_start_pos(&mut self, start_pos: usize) -> Option<(ToolCall, usize)> { - let remaining = self.buffer[start_pos..].to_string(); - self.parse_from_position(&remaining, start_pos) - } + // Try to find a complete JSON object + let mut brace_count = 0; + let mut in_string = false; + let mut escape_next = false; - fn parse_from_position(&mut self, text: &str, start_pos: usize) -> Option<(ToolCall, usize)> { - let mut current_brace_count = 0; // Always start fresh for each parsing attempt + for (i, ch) in json_text.char_indices() { + if escape_next { + escape_next = false; + continue; + } - //info!("Parsing from position {} with text: {:?}", start_pos, text); - //info!("Starting brace count: {}", current_brace_count); + match ch { + '\\' => escape_next = true, + '"' if !escape_next => in_string = !in_string, + '{' if !in_string => brace_count += 1, + '}' if !in_string => { + brace_count -= 1; + if brace_count == 0 { + // Found complete JSON object + let json_str = &json_text[..=i]; + debug!("Attempting to parse JSON tool call: {}", json_str); - for (i, ch) in text.char_indices() { - match ch { - '{' => current_brace_count += 1, - '}' => { - current_brace_count -= 1; - //info!("Found '}}' at position {}, brace count now: {}", i, current_brace_count); - if current_brace_count == 0 { - // Found complete JSON object - let end_pos = start_pos + i + 1; - let mut json_str = self.buffer[start_pos..end_pos].to_string(); + if let Ok(tool_call) = serde_json::from_str::(json_str) { + debug!("Successfully parsed JSON tool call: {:?}", tool_call); - // Clean up malformed JSON patterns - json_str = json_str - .replace(r#"{"{""#, r#"{"#) // Fix {"{" -> {" - .replace(r#"""}"#, r#""}"#) // Fix ""} -> "} - .replace(r#"{{""#, r#"{"#) // Fix {{" -> {" - .replace(r#"""}"#, r#""}"#); // Fix ""} -> "} + // Reset JSON parsing state + self.in_json_tool_call = false; + self.json_tool_start = None; - // First, try to parse the JSON as-is - if let Ok(tool_call) = serde_json::from_str::(&json_str) { - info!("Successfully parsed tool call: {:?}", tool_call); - // Reset parser state - self.in_tool_call = false; - self.tool_start_pos = None; - self.brace_count = 0; - - return Some((tool_call, end_pos)); - } - - // If parsing failed and this is a shell command, try to fix nested quotes - if json_str.contains(r#""tool": "shell""#) { - let fixed_json = fix_nested_quotes_in_shell_command(&json_str); - if let Ok(tool_call) = serde_json::from_str::(&fixed_json) { - // Reset parser state - self.in_tool_call = false; - self.tool_start_pos = None; - self.brace_count = 0; - - return Some((tool_call, end_pos)); - } else { - info!( - "Failed to parse JSON even after fixing nested quotes: {}", - fixed_json - ); + return Some(tool_call); + } else { + debug!("Failed to parse JSON tool call: {}", json_str); + // Reset and continue looking + self.in_json_tool_call = false; + self.json_tool_start = None; + } + break; } } - - // Try to fix mixed quote issues (single quotes in JSON) - let fixed_mixed_quotes = fix_mixed_quotes_in_json(&json_str); - if fixed_mixed_quotes != json_str { - if let Ok(tool_call) = - serde_json::from_str::(&fixed_mixed_quotes) - { - info!( - "Successfully parsed tool call after fixing mixed quotes: {:?}", - tool_call - ); - // Reset parser state - self.in_tool_call = false; - self.tool_start_pos = None; - self.brace_count = 0; - - return Some((tool_call, end_pos)); - } else { - info!( - "Failed to parse JSON even after fixing mixed quotes: {}", - fixed_mixed_quotes - ); - } - } else { - info!("Failed to parse JSON (no fixes applied): {}", json_str); - } - - // Invalid JSON, reset and continue looking - self.in_tool_call = false; - self.tool_start_pos = None; - self.brace_count = 0; + _ => {} } } - _ => {} } } - // Update brace count for next iteration - self.brace_count = current_brace_count; - //info!("End of parsing, final brace count: {}", current_brace_count); None } - pub fn get_content_before_tool(&self, tool_end_pos: usize) -> String { - if tool_end_pos <= self.buffer.len() { - self.buffer[..tool_end_pos].to_string() + /// Get the accumulated text content (excluding tool calls) + pub fn get_text_content(&self) -> &str { + &self.text_buffer + } + + /// Get content before a specific position (for display purposes) + pub fn get_content_before_position(&self, pos: usize) -> String { + if pos <= self.text_buffer.len() { + self.text_buffer[..pos].to_string() } else { - self.buffer.clone() + self.text_buffer.clone() } } - pub fn get_remaining_content(&self, from_pos: usize) -> String { - if from_pos < self.buffer.len() { - self.buffer[from_pos..].to_string() - } else { - String::new() - } + /// Check if the message has been stopped/finished + pub fn is_message_stopped(&self) -> bool { + self.message_stopped + } + + /// Reset the parser state for a new message + pub fn reset(&mut self) { + self.text_buffer.clear(); + self.native_tool_calls.clear(); + self.message_stopped = false; + self.in_json_tool_call = false; + self.json_tool_start = None; + } + + /// Get the current text buffer length (for position tracking) + pub fn text_buffer_len(&self) -> usize { + self.text_buffer.len() } } @@ -606,9 +613,22 @@ The tool will execute immediately and you'll receive the result (success or erro None }; + // Get max_tokens from provider configuration + // For Databricks, this should be much higher to support large file generation + let max_tokens = match provider.name() { + "databricks" => { + // Use the model's maximum limit for Databricks to allow large file generation + Some(32000) + } + _ => { + // Default for other providers + Some(16000) + } + }; + let request = CompletionRequest { messages, - max_tokens: Some(2048), + max_tokens, temperature: Some(0.1), stream: true, // Enable streaming tools, @@ -723,7 +743,7 @@ The tool will execute immediately and you'll receive the result (success or erro match serde_json::to_string_pretty(&context_data) { Ok(json_content) => { - if let Err(e) = fs::write(&filename, json_content) { + if let Err(e) = std::fs::write(&filename, json_content) { error!("Failed to save context window to {}: {}", filename, e); } } @@ -797,32 +817,6 @@ The tool will execute immediately and you'll receive the result (success or erro "required": ["file_path", "content"] }), }, - // Tool { - // name: "edit_file".to_string(), - // description: "Edit a specific range of lines in a file".to_string(), - // input_schema: json!({ - // "type": "object", - // "properties": { - // "file_path": { - // "type": "string", - // "description": "The path to the file to edit" - // }, - // "start_line": { - // "type": "integer", - // "description": "The starting line number (1-based) of the range to replace" - // }, - // "end_line": { - // "type": "integer", - // "description": "The ending line number (1-based) of the range to replace" - // }, - // "new_text": { - // "type": "string", - // "description": "The new text to replace the specified range" - // } - // }, - // "required": ["file_path", "start_line", "end_line", "new_text"] - // }), - // }, Tool { name: "final_output".to_string(), description: "Signal task completion with a detailed summary".to_string(), @@ -847,6 +841,8 @@ The tool will execute immediately and you'll receive the result (success or erro use std::io::{self, Write}; use tokio_stream::StreamExt; + debug!("Starting stream_completion_with_tools"); + let mut full_response = String::new(); let mut first_token_time: Option = None; let stream_start = Instant::now(); @@ -857,6 +853,7 @@ The tool will execute immediately and you'll receive the result (success or erro loop { iteration_count += 1; + debug!("Starting iteration {}", iteration_count); if iteration_count > MAX_ITERATIONS { warn!("Maximum iterations reached, stopping stream"); break; @@ -868,6 +865,7 @@ The tool will execute immediately and you'll receive the result (success or erro } let provider = self.providers.get(None)?; + debug!("Got provider: {}", provider.name()); let mut stream = match provider.stream(request.clone()).await { Ok(s) => s, Err(e) => { @@ -889,6 +887,7 @@ The tool will execute immediately and you'll receive the result (success or erro } } }; + let mut parser = StreamingToolParser::new(); let mut current_response = String::new(); let mut tool_executed = false; @@ -901,66 +900,27 @@ The tool will execute immediately and you'll receive the result (success or erro first_token_time = Some(stream_start.elapsed()); } - // Check for tool calls - prioritize native tool calls over JSON parsing - let mut detected_tool_call = None; - let mut is_text_based_tool_call = false; + // Process chunk with the new parser + let completed_tools = parser.process_chunk(&chunk); - // First check for native tool calls in the chunk - if let Some(ref tool_calls) = chunk.tool_calls { - debug!("Found native tool calls in chunk: {:?}", tool_calls); - if let Some(first_tool) = tool_calls.first() { - // Convert native tool call to our internal format - detected_tool_call = Some(( - crate::ToolCall { - tool: first_tool.tool.clone(), - args: first_tool.args.clone(), - }, - current_response.len(), // Position doesn't matter for native calls - )); - debug!("Converted native tool call: {:?}", detected_tool_call); - } - } else { - debug!("No native tool calls in chunk, chunk.tool_calls is None"); - } + // Handle completed tool calls + for tool_call in completed_tools { + debug!("Processing completed tool call: {:?}", tool_call); - // Always try JSON parsing as fallback, even for native providers - // This handles cases where Anthropic returns tool calls as text instead of native format - if detected_tool_call.is_none() { - // Try to parse JSON tool calls from text content - detected_tool_call = parser.add_chunk(&chunk.content); - if detected_tool_call.is_some() { - debug!("Found JSON tool call in text content for native provider"); - is_text_based_tool_call = true; - } - } + // Get the text content accumulated so far + let text_content = parser.get_text_content(); - if let Some((tool_call, tool_end_pos)) = detected_tool_call { - // Found a complete tool call! Stop streaming and execute it - let content_before_tool = parser.get_content_before_tool(tool_end_pos); - - // Display content up to the tool call (excluding the JSON and any stop tokens) - let display_content = if let Some(json_start) = - content_before_tool.rfind(r#"{"tool":"#) - { - // Only show content before the JSON tool call - content_before_tool[..json_start].trim() - } else { - // Fallback: clean any stop tokens from the content - content_before_tool.trim() - }; - - // Clean stop tokens from display content - let clean_display_content = display_content + // Clean and prepare display content + let clean_display_content = text_content .replace("<|im_end|>", "") .replace("", "") .replace("[/INST]", "") .replace("<>", ""); let final_display_content = clean_display_content.trim(); - // Safely get the new content to display + // Display any new content before tool execution let new_content = if current_response.len() <= final_display_content.len() { - // Use char indices to avoid UTF-8 boundary issues let chars_already_shown = current_response.chars().count(); final_display_content .chars() @@ -970,43 +930,24 @@ The tool will execute immediately and you'll receive the result (success or erro String::new() }; - // Only print if there's actually new content to show if !new_content.trim().is_empty() { - // Replace thinking indicator with response indicator if not already done if !response_started { - print!("\ršŸ¤– "); // Clear thinking indicator and show response indicator + print!("\ršŸ¤– "); response_started = true; } print!("{}", new_content); io::stdout().flush()?; } - // Check if this was a JSON tool call detected from text (not native) - // If so, show a brief indicator that we detected a text-based tool call - let provider = self.providers.get(None)?; - if provider.has_native_tool_calling() && is_text_based_tool_call { - // This means we detected a JSON tool call in text for a native provider - // Show a brief indicator instead of the raw JSON - if !response_started { - print!("\ršŸ¤– "); // Clear thinking indicator and show response indicator - response_started = true; - } - print!("šŸ”§ "); // Brief tool call indicator - io::stdout().flush()?; - } - // Execute the tool with formatted output println!(); // New line before tool execution // Tool call header println!("ā”Œā”€ {}", tool_call.tool); - debug!("Tool call args object: {:?}", tool_call.args.as_object()); if let Some(args_obj) = tool_call.args.as_object() { for (key, value) in args_obj { - debug!("Processing arg: {} = {:?}", key, value); let value_str = match value { serde_json::Value::String(s) => { - // For shell commands, truncate at newlines to keep display clean if tool_call.tool == "shell" && key == "command" { if let Some(first_line) = s.lines().next() { if s.lines().count() > 1 { @@ -1018,7 +959,6 @@ The tool will execute immediately and you'll receive the result (success or erro s.clone() } } else { - // For other tools, show first 100 chars to avoid huge displays if s.len() > 100 { format!("{}...", &s[..100]) } else { @@ -1030,8 +970,6 @@ The tool will execute immediately and you'll receive the result (success or erro }; println!("│ {}: {}", key, value_str); } - } else { - debug!("No args object found in tool call"); } println!("ā”œā”€ output:"); @@ -1053,12 +991,10 @@ The tool will execute immediately and you'll receive the result (success or erro const MAX_LINES: usize = 5; if output_lines.len() <= MAX_LINES { - // Show all lines if within limit for line in output_lines { println!("│ {}", line); } } else { - // Show first MAX_LINES and add truncation note for line in output_lines.iter().take(MAX_LINES) { println!("│ {}", line); } @@ -1070,17 +1006,15 @@ The tool will execute immediately and you'll receive the result (success or erro ); } - // Check if this was a final_output tool call - if so, stop the conversation + // Check if this was a final_output tool call if tool_call.tool == "final_output" { - // For final_output, don't add the tool call and result to context - // Just add the display content and return immediately full_response.push_str(final_display_content); if let Some(summary) = tool_call.args.get("summary") { if let Some(summary_str) = summary.as_str() { full_response.push_str(&format!("\n\n=> {}", summary_str)); } } - println!(); // New line after final output + println!(); let ttft = first_token_time.unwrap_or_else(|| stream_start.elapsed()); return Ok((full_response, ttft)); @@ -1089,25 +1023,24 @@ The tool will execute immediately and you'll receive the result (success or erro // Closure marker with timing println!("└─ āš”ļø {}", Self::format_duration(exec_duration)); println!(); - print!("šŸ¤– "); // Continue response indicator + print!("šŸ¤– "); io::stdout().flush()?; - // Add the tool call and result to the context window immediately + // Add the tool call and result to the context window let tool_message = Message { role: MessageRole::Assistant, content: format!( "{}\n\n{{\"tool\": \"{}\", \"args\": {}}}", - display_content.trim(), + final_display_content.trim(), tool_call.tool, tool_call.args ), }; let result_message = Message { - role: MessageRole::User, // Tool results come back as user messages + role: MessageRole::User, content: format!("Tool result: {}", tool_result), }; - // Add to context window for persistence self.context_window.add_message(tool_message); self.context_window.add_message(result_message); @@ -1120,17 +1053,15 @@ The tool will execute immediately and you'll receive the result (success or erro } full_response.push_str(final_display_content); - // full_response.push_str(&format!( - // "\n\nTool executed: {} -> {}\n\n", - // tool_call.tool, tool_result - // )); - tool_executed = true; - // Break out of current stream to start a new one with updated context - break; - } else { - // No tool call detected, continue streaming normally - // But first check if we need to filter JSON tool calls from display + + // Reset parser for next iteration + parser.reset(); + break; // Break out of current stream to start a new one + } + + // If no tool calls were completed, continue streaming normally + if !tool_executed { let clean_content = chunk .content .replace("<|im_end|>", "") @@ -1139,24 +1070,16 @@ The tool will execute immediately and you'll receive the result (success or erro .replace("<>", ""); if !clean_content.is_empty() { - // Filter out JSON tool calls from display let filtered_content = filter_json_tool_calls(&clean_content); - // If we have any content to display if !filtered_content.is_empty() { - // Replace thinking indicator with response indicator on first content if !response_started { - print!("\ršŸ¤– "); // Clear thinking indicator and show response indicator + print!("\ršŸ¤– "); response_started = true; } - debug!("Printing filtered content: '{}'", filtered_content); print!("{}", filtered_content); - let _ = io::stdout().flush(); // Force immediate output - debug!( - "Flushed {} characters to stdout", - filtered_content.len() - ); + let _ = io::stdout().flush(); current_response.push_str(&filtered_content); } } @@ -1165,7 +1088,7 @@ The tool will execute immediately and you'll receive the result (success or erro if chunk.finished { // Stream finished naturally without tool calls full_response.push_str(¤t_response); - println!(); // New line after streaming completes + println!(); let ttft = first_token_time.unwrap_or_else(|| stream_start.elapsed()); return Ok((full_response, ttft)); } @@ -1173,7 +1096,6 @@ The tool will execute immediately and you'll receive the result (success or erro Err(e) => { error!("Streaming error: {}", e); - // If we executed a tool, try to continue with a new stream if tool_executed { warn!("Stream error after tool execution, attempting to continue"); break; // Break to outer loop to start new stream @@ -1187,7 +1109,7 @@ The tool will execute immediately and you'll receive the result (success or erro // If we get here and no tool was executed, we're done if !tool_executed { full_response.push_str(¤t_response); - println!(); // New line after streaming completes + println!(); let ttft = first_token_time.unwrap_or_else(|| stream_start.elapsed()); return Ok((full_response, ttft)); } @@ -1201,13 +1123,15 @@ The tool will execute immediately and you'll receive the result (success or erro } async fn execute_tool(&self, tool_call: &ToolCall) -> Result { - debug!("Executing tool: {}", tool_call.tool); - debug!("Tool call args: {:?}", tool_call.args); + debug!("=== EXECUTING TOOL ==="); + debug!("Tool name: {}", tool_call.tool); + debug!("Tool args (raw): {:?}", tool_call.args); debug!( - "Tool call args JSON: {}", + "Tool args (JSON): {}", serde_json::to_string(&tool_call.args) .unwrap_or_else(|_| "failed to serialize".to_string()) ); + debug!("======================"); match tool_call.tool.as_str() { "shell" => { @@ -1280,175 +1204,177 @@ The tool will execute immediately and you'll receive the result (success or erro serde_json::to_string(&tool_call.args) .unwrap_or_else(|_| "failed to serialize".to_string()) ); + debug!( + "Args type: {:?}", + std::any::type_name_of_val(&tool_call.args) + ); + debug!("Args is_object: {}", tool_call.args.is_object()); + debug!("Args is_array: {}", tool_call.args.is_array()); + debug!("Args is_null: {}", tool_call.args.is_null()); - let file_path = tool_call.args.get("file_path"); - let content = tool_call.args.get("content"); - - debug!("file_path: {:?}", file_path); - debug!("content: {:?}", content); - - if let (Some(path_val), Some(content_val)) = (file_path, content) { - if let (Some(path_str), Some(content_str)) = - (path_val.as_str(), content_val.as_str()) - { - debug!("Writing to file: {}", path_str); - - // Create parent directories if they don't exist - if let Some(parent) = std::path::Path::new(path_str).parent() { - if let Err(e) = std::fs::create_dir_all(parent) { - return Ok(format!( - "āŒ Failed to create parent directories for '{}': {}", - path_str, e - )); - } - } - - match std::fs::write(path_str, content_str) { - Ok(()) => { - let line_count = content_str.lines().count(); - Ok(format!( - "āœ… Successfully wrote {} lines to '{}'", - line_count, path_str - )) - } - Err(e) => { - Ok(format!("āŒ Failed to write to file '{}': {}", path_str, e)) - } - } - } else { - Ok("āŒ Invalid file_path or content argument".to_string()) - } - } else { - Ok("āŒ Missing file_path or content argument".to_string()) - } - } - "edit_file" => { - debug!("Processing edit_file tool call"); - debug!("Raw tool_call.args: {:?}", tool_call.args); - - let file_path = tool_call.args.get("file_path"); - let start_line = tool_call.args.get("start_line"); - let end_line = tool_call.args.get("end_line"); - let new_text = tool_call.args.get("new_text"); - - debug!("Extracted values - file_path: {:?}, start_line: {:?}, end_line: {:?}, new_text: {:?}", - file_path, start_line, end_line, new_text); - - if let (Some(path_val), Some(start_val), Some(end_val), Some(text_val)) = - (file_path, start_line, end_line, new_text) - { - debug!("All required arguments present"); + // Try multiple argument formats that different providers might use + let (path_str, content_str) = if let Some(args_obj) = tool_call.args.as_object() { debug!( - "path_val: {:?}, start_val: {:?}, end_val: {:?}, text_val: {:?}", - path_val, start_val, end_val, text_val + "Args object keys: {:?}", + args_obj.keys().collect::>() ); - if let (Some(path_str), Some(start_num), Some(end_num), Some(text_str)) = ( - path_val.as_str(), - start_val.as_i64(), - end_val.as_i64(), - text_val.as_str(), - ) { - debug!("Successfully converted types - path: {}, start: {}, end: {}, text_len: {}", - path_str, start_num, end_num, text_str.len()); - - // Validate line numbers - if start_num < 1 || end_num < 1 || start_num > end_num { - return Ok("āŒ Invalid line numbers: start_line and end_line must be >= 1 and start_line <= end_line".to_string()); - } - - // Read the current file content - let original_content = match std::fs::read_to_string(path_str) { - Ok(content) => content, - Err(e) => { - return Ok(format!("āŒ Failed to read file '{}': {}", path_str, e)) - } - }; - - let lines: Vec<&str> = original_content.lines().collect(); - let total_lines = lines.len(); - debug!("File has {} lines", total_lines); - - // Convert to 0-based indexing - let start_idx = (start_num - 1) as usize; - let end_idx = (end_num - 1) as usize; - debug!( - "Using 0-based indices: start_idx={}, end_idx={}", - start_idx, end_idx - ); - - // Validate line ranges - if start_idx >= total_lines { - return Ok(format!( - "āŒ start_line {} is beyond file length ({} lines)", - start_num, total_lines - )); - } - if end_idx >= total_lines { - return Ok(format!( - "āŒ end_line {} is beyond file length ({} lines)", - end_num, total_lines - )); - } - - // Split new_text into lines - let new_lines: Vec<&str> = if text_str.is_empty() { - vec![] + // Format 1: Standard format with file_path and content + if let (Some(path_val), Some(content_val)) = + (args_obj.get("file_path"), args_obj.get("content")) + { + debug!("Found file_path and content keys"); + if let (Some(path), Some(content)) = + (path_val.as_str(), content_val.as_str()) + { + debug!( + "Successfully extracted file_path='{}', content_len={}", + path, + content.len() + ); + (Some(path), Some(content)) } else { - text_str.lines().collect() - }; - - let new_lines_count = new_lines.len(); - debug!("New text has {} lines", new_lines_count); - - // Create the new content - let mut new_content_lines = Vec::new(); - - // Add lines before the edit range - new_content_lines.extend_from_slice(&lines[..start_idx]); - - // Add the new lines - new_content_lines.extend(new_lines); - - // Add lines after the edit range - if end_idx + 1 < lines.len() { - new_content_lines.extend_from_slice(&lines[end_idx + 1..]); + debug!("file_path or content values are not strings: path_val={:?}, content_val={:?}", path_val, content_val); + (None, None) } - - // Join the lines back together - let new_content = new_content_lines.join("\n"); - debug!("New content length: {} characters", new_content.len()); - - // Write the modified content back to the file - match std::fs::write(path_str, &new_content) { - Ok(()) => { - let old_range_size = end_idx - start_idx + 1; - Ok(format!("āœ… Successfully edited '{}': replaced {} lines ({}:{}) with {} lines", - path_str, old_range_size, start_num, end_num, new_lines_count)) - } - Err(e) => Ok(format!( - "āŒ Failed to write edited content to '{}': {}", - path_str, e - )), + } + // Format 2: Anthropic-style with path and content + else if let (Some(path_val), Some(content_val)) = + (args_obj.get("path"), args_obj.get("content")) + { + debug!("Found path and content keys (Anthropic style)"); + if let (Some(path), Some(content)) = + (path_val.as_str(), content_val.as_str()) + { + debug!( + "Successfully extracted path='{}', content_len={}", + path, + content.len() + ); + (Some(path), Some(content)) + } else { + debug!("path or content values are not strings: path_val={:?}, content_val={:?}", path_val, content_val); + (None, None) + } + } + // Format 3: Alternative naming with filename and text + else if let (Some(path_val), Some(content_val)) = + (args_obj.get("filename"), args_obj.get("text")) + { + debug!("Found filename and text keys"); + if let (Some(path), Some(content)) = + (path_val.as_str(), content_val.as_str()) + { + debug!( + "Successfully extracted filename='{}', text_len={}", + path, + content.len() + ); + (Some(path), Some(content)) + } else { + debug!("filename or text values are not strings: path_val={:?}, content_val={:?}", path_val, content_val); + (None, None) + } + } + // Format 4: Alternative naming with file and data + else if let (Some(path_val), Some(content_val)) = + (args_obj.get("file"), args_obj.get("data")) + { + debug!("Found file and data keys"); + if let (Some(path), Some(content)) = + (path_val.as_str(), content_val.as_str()) + { + debug!( + "Successfully extracted file='{}', data_len={}", + path, + content.len() + ); + (Some(path), Some(content)) + } else { + debug!("file or data values are not strings: path_val={:?}, content_val={:?}", path_val, content_val); + (None, None) } } else { - debug!("Type conversion failed:"); - debug!(" path_val.as_str(): {:?}", path_val.as_str()); - debug!(" start_val.as_i64(): {:?}", start_val.as_i64()); - debug!(" end_val.as_i64(): {:?}", end_val.as_i64()); - debug!(" text_val.as_str(): {:?}", text_val.as_str()); - Ok("āŒ Invalid argument types: file_path must be string, start_line and end_line must be integers, new_text must be string".to_string()) + debug!( + "No matching key patterns found. Available argument keys: {:?}", + args_obj.keys().collect::>() + ); + (None, None) } } else { - debug!("Missing required arguments:"); - debug!(" file_path present: {}", file_path.is_some()); - debug!(" start_line present: {}", start_line.is_some()); - debug!(" end_line present: {}", end_line.is_some()); - debug!(" new_text present: {}", new_text.is_some()); - Ok( - "āŒ Missing required arguments: file_path, start_line, end_line, new_text" - .to_string(), - ) + debug!("Args is not an object, checking if it's an array"); + // Format 5: Args might be an array [path, content] + if let Some(args_array) = tool_call.args.as_array() { + debug!("Args is an array with {} elements", args_array.len()); + if args_array.len() >= 2 { + if let (Some(path), Some(content)) = + (args_array[0].as_str(), args_array[1].as_str()) + { + debug!( + "Successfully extracted from array: path='{}', content_len={}", + path, + content.len() + ); + (Some(path), Some(content)) + } else { + debug!( + "Array elements are not strings: [0]={:?}, [1]={:?}", + args_array[0], args_array[1] + ); + (None, None) + } + } else { + debug!("Array has insufficient elements: {}", args_array.len()); + (None, None) + } + } else { + debug!("Args is neither object nor array"); + (None, None) + } + }; + + debug!( + "Final extracted values: path_str={:?}, content_str_len={:?}", + path_str, + content_str.map(|c| c.len()) + ); + + if let (Some(path), Some(content)) = (path_str, content_str) { + debug!("Writing to file: {}", path); + + // Create parent directories if they don't exist + if let Some(parent) = std::path::Path::new(path).parent() { + if let Err(e) = std::fs::create_dir_all(parent) { + return Ok(format!( + "āŒ Failed to create parent directories for '{}': {}", + path, e + )); + } + } + + match std::fs::write(path, content) { + Ok(()) => { + let line_count = content.lines().count(); + let char_count = content.len(); + Ok(format!( + "āœ… Successfully wrote {} lines ({} characters) to '{}'", + line_count, char_count, path + )) + } + Err(e) => Ok(format!("āŒ Failed to write to file '{}': {}", path, e)), + } + } else { + // Provide more detailed error information + let available_keys = if let Some(obj) = tool_call.args.as_object() { + obj.keys().collect::>() + } else { + vec![] + }; + + Ok(format!( + "āŒ Missing file_path or content argument. Available keys: {:?}. Expected formats: {{\"file_path\": \"...\", \"content\": \"...\"}}, {{\"path\": \"...\", \"content\": \"...\"}}, {{\"filename\": \"...\", \"text\": \"...\"}}, or {{\"file\": \"...\", \"data\": \"...\"}}", + available_keys + )) } } "final_output" => { @@ -1525,7 +1451,8 @@ fn filter_json_tool_calls(content: &str) -> String { && (trimmed.contains("tool") || trimmed.contains("args") || trimmed.contains(r#"""#))) // Catch malformed tool calls like: {"tool": "write_file", "path || (trimmed.contains(r#""tool":"#) || trimmed.contains(r#""tool": "#)) - || (trimmed.starts_with(r#"{"#) && trimmed.contains(r#"", ""#)) // JSON with quoted comma pattern + || (trimmed.starts_with(r#"{"#) && trimmed.contains(r#"", ""#)) + // JSON with quoted comma pattern { // This looks like part of a JSON tool call, suppress it "".to_string() @@ -1620,9 +1547,11 @@ fn shell_escape_command(command: &str) -> String { } } -// Helper function to fix nested quotes in shell command JSON +// Helper function to fix mixed quotes in JSON strings +#[allow(dead_code)] fn fix_nested_quotes_in_shell_command(json_str: &str) -> String { - // This handles cases where shell commands contain nested quotes that break JSON parsing + let mut _result = String::new(); + let _chars = json_str.chars().peekable(); // Example: {"tool": "shell", "args": {"command": "python -c 'import os; print("hello")'"}} // Look for the pattern: "command": " @@ -1680,10 +1609,8 @@ fn fix_nested_quotes_in_shell_command(json_str: &str) -> String { } // Helper function to fix mixed quotes in JSON (single quotes where double quotes should be) +#[allow(dead_code)] fn fix_mixed_quotes_in_json(json_str: &str) -> String { - // This handles cases where the LLM uses single quotes in JSON values - // Example: {"tool": "shell", "args": {"command": 'echo "hello"'}} - let mut result = String::new(); let mut chars = json_str.chars().peekable(); let mut in_string = false; diff --git a/crates/g3-providers/src/databricks.rs b/crates/g3-providers/src/databricks.rs index 7da1ebc..b38d755 100644 --- a/crates/g3-providers/src/databricks.rs +++ b/crates/g3-providers/src/databricks.rs @@ -79,10 +79,9 @@ const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"]; const DEFAULT_TIMEOUT_SECS: u64 = 600; pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-sonnet-4"; -const DATABRICKS_DEFAULT_FAST_MODEL: &str = "gemini-1-5-flash"; pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[ "databricks-claude-3-7-sonnet", - "databricks-meta-llama-3-3-70b-instruct", + "databricks-meta-llama-3-3-70b-instruct", "databricks-meta-llama-3-1-405b-instruct", "databricks-dbrx-instruct", "databricks-mixtral-8x7b-instruct", @@ -155,14 +154,17 @@ impl DatabricksProvider { .build() .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?; - info!("Initialized Databricks provider with model: {} on host: {}", model, host); + info!( + "Initialized Databricks provider with model: {} on host: {}", + model, host + ); Ok(Self { client, host: host.trim_end_matches('/').to_string(), auth: DatabricksAuth::token(token), model, - max_tokens: max_tokens.unwrap_or(4096), + max_tokens: max_tokens.unwrap_or(50000), temperature: temperature.unwrap_or(0.1), }) } @@ -178,24 +180,30 @@ impl DatabricksProvider { .build() .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?; - info!("Initialized Databricks provider with OAuth for model: {} on host: {}", model, host); + info!( + "Initialized Databricks provider with OAuth for model: {} on host: {}", + model, host + ); Ok(Self { client, host: host.trim_end_matches('/').to_string(), auth: DatabricksAuth::oauth(host.clone()), model, - max_tokens: max_tokens.unwrap_or(4096), + max_tokens: max_tokens.unwrap_or(50000), temperature: temperature.unwrap_or(0.1), }) } async fn create_request_builder(&mut self, streaming: bool) -> Result { let token = self.auth.get_token().await?; - + let mut builder = self .client - .post(&format!("{}/serving-endpoints/{}/invocations", self.host, self.model)) + .post(&format!( + "{}/serving-endpoints/{}/invocations", + self.host, self.model + )) .header("Authorization", format!("Bearer {}", token)) .header("Content-Type", "application/json"); @@ -226,7 +234,7 @@ impl DatabricksProvider { for message in messages { let role = match message.role { MessageRole::System => "system", - MessageRole::User => "user", + MessageRole::User => "user", MessageRole::Assistant => "assistant", }; @@ -274,13 +282,26 @@ impl DatabricksProvider { tx: mpsc::Sender>, ) { let mut buffer = String::new(); - let mut current_tool_calls: std::collections::HashMap = std::collections::HashMap::new(); // index -> (id, name, args) + let mut current_tool_calls: std::collections::HashMap = + std::collections::HashMap::new(); // index -> (id, name, args) + let mut incomplete_data_line = String::new(); // Buffer for incomplete data: lines while let Some(chunk_result) = stream.next().await { match chunk_result { Ok(chunk) => { + // Debug: Log raw bytes received + debug!("Raw SSE bytes received: {} bytes", chunk.len()); + let chunk_str = match std::str::from_utf8(&chunk) { - Ok(s) => s, + Ok(s) => { + // Debug: Log raw string content (truncated for large chunks) + if s.len() > 1000 { + debug!("Raw SSE string content (first 500 chars): {:?}...", &s[..500]); + } else { + debug!("Raw SSE string content: {:?}", s); + } + s + }, Err(e) => { error!("Invalid UTF-8 in stream chunk: {}", e); let _ = tx @@ -292,7 +313,7 @@ impl DatabricksProvider { buffer.push_str(chunk_str); - // Process complete lines + // Process complete lines, but handle incomplete data: lines specially while let Some(line_end) = buffer.find('\n') { let line = buffer[..line_end].trim().to_string(); buffer.drain(..line_end + 1); @@ -301,21 +322,55 @@ impl DatabricksProvider { continue; } + // Check if we have an incomplete data line from previous chunk + let line = if !incomplete_data_line.is_empty() { + // We had an incomplete data: line, append this line to it + let complete_line = format!("{}{}", incomplete_data_line, line); + incomplete_data_line.clear(); + complete_line + } else { + line + }; + + // Check if this is a data: line that might be incomplete + // SSE format requires double newline after data, so if we don't see another newline + // after this one in the buffer, and it's a data: line, it might be incomplete + if line.starts_with("data: ") { + // Check if there's a complete SSE event (should have double newline after data) + // But for streaming, single newline is often used, so we need to be careful + // The safest approach is to try parsing and if it fails due to incomplete JSON, + // we'll handle it below + } + + // Debug: Log each SSE line (truncated for large lines) + if line.len() > 1000 { + debug!("SSE line (first 500 chars): {:?}...", &line[..500]); + } else { + debug!("SSE line: {:?}", line); + } + // Parse Server-Sent Events format if let Some(data) = line.strip_prefix("data: ") { if data == "[DONE]" { debug!("Received stream completion marker"); - let final_tool_calls: Vec = current_tool_calls.values() + let final_tool_calls: Vec = current_tool_calls + .values() .map(|(id, name, args)| ToolCall { id: id.clone(), tool: name.clone(), - args: serde_json::from_str(args).unwrap_or(serde_json::Value::Object(serde_json::Map::new())), + args: serde_json::from_str(args).unwrap_or( + serde_json::Value::Object(serde_json::Map::new()), + ), }) .collect(); let final_chunk = CompletionChunk { content: String::new(), finished: true, - tool_calls: if final_tool_calls.is_empty() { None } else { Some(final_tool_calls) }, + tool_calls: if final_tool_calls.is_empty() { + None + } else { + Some(final_tool_calls) + }, }; if tx.send(Ok(final_chunk)).await.is_err() { debug!("Receiver dropped, stopping stream"); @@ -323,12 +378,17 @@ impl DatabricksProvider { return; } - debug!("Raw Databricks API JSON: {}", data); + // Debug: Log every raw JSON payload from Databricks API (truncated for large payloads) + if data.len() > 1000 { + debug!("Raw Databricks SSE JSON payload (first 500 chars): {}...", &data[..500]); + } else { + debug!("Raw Databricks SSE JSON payload: {}", data); + } match serde_json::from_str::(data) { Ok(chunk) => { - debug!("Parsed stream chunk: {:?}", chunk); - + debug!("Successfully parsed Databricks stream chunk"); + // Handle different types of chunks if let Some(choices) = chunk.choices { for choice in choices { @@ -349,57 +409,93 @@ impl DatabricksProvider { // Handle tool calls - accumulate across chunks if let Some(tool_calls) = delta.tool_calls { + debug!("Processing {} tool call deltas", tool_calls.len()); for tool_call in tool_calls { let index = tool_call.index.unwrap_or(0); - let entry = current_tool_calls.entry(index).or_insert_with(|| { - (String::new(), String::new(), String::new()) - }); + debug!("Tool call delta for index {}: id={:?}, name='{}', args_len={}", + index, tool_call.id, tool_call.function.name, tool_call.function.arguments.len()); + + let entry = current_tool_calls + .entry(index) + .or_insert_with(|| { + ( + String::new(), + String::new(), + String::new(), + ) + }); // Update ID if provided if let Some(id) = tool_call.id { + debug!("Updating tool call {} ID from '{}' to '{}'", index, entry.0, id); entry.0 = id; } // Update name if provided and not empty if !tool_call.function.name.is_empty() { + debug!("Updating tool call {} name from '{}' to '{}'", index, entry.1, tool_call.function.name); entry.1 = tool_call.function.name; } // Append arguments - entry.2.push_str(&tool_call.function.arguments); + debug!("Appending {} chars to tool call {} args (current len: {})", + tool_call.function.arguments.len(), index, entry.2.len()); + entry.2.push_str( + &tool_call.function.arguments, + ); - debug!("Accumulated tool call {}: id='{}', name='{}', args='{}'", - index, entry.0, entry.1, entry.2); + debug!("Accumulated tool call {}: id='{}', name='{}', args_len={}", + index, entry.0, entry.1, entry.2.len()); + + // Debug: Show a sample of the accumulated args if they're getting long + if entry.2.len() > 100 { + debug!("Tool call {} args sample (first 100 chars): {}", index, &entry.2[..100]); + } else if !entry.2.is_empty() { + debug!("Tool call {} full args: {}", index, entry.2); + } } } } // Check if this choice is finished if choice.finish_reason.is_some() { - debug!("Choice finished with reason: {:?}", choice.finish_reason); - + debug!( + "Choice finished with reason: {:?}", + choice.finish_reason + ); + // Convert accumulated tool calls to final format let final_tool_calls: Vec = current_tool_calls.values() .filter(|(_, name, _)| !name.is_empty()) // Only include tool calls with names .map(|(id, name, args)| { - debug!("Converting tool call: id='{}', name='{}', args='{}'", id, name, args); + debug!("Converting tool call: id='{}', name='{}', args_len={}", id, name, args.len()); ToolCall { id: if id.is_empty() { format!("tool_{}", name) } else { id.clone() }, tool: name.clone(), args: serde_json::from_str(args).unwrap_or_else(|e| { - debug!("Failed to parse tool args '{}': {}", args, e); + debug!("Failed to parse tool args (len={}): {}", args.len(), e); + // For debugging, log a sample of the args if they're very long + if args.len() > 1000 { + debug!("Tool args sample (first 500 chars): {}", &args[..500]); + } else { + debug!("Full tool args: {}", args); + } serde_json::Value::Object(serde_json::Map::new()) }), } }) .collect(); - debug!("Final tool calls: {:?}", final_tool_calls); + debug!("Final tool calls count: {}", final_tool_calls.len()); let final_chunk = CompletionChunk { content: String::new(), finished: true, - tool_calls: if final_tool_calls.is_empty() { None } else { Some(final_tool_calls) }, + tool_calls: if final_tool_calls.is_empty() { + None + } else { + Some(final_tool_calls) + }, }; if tx.send(Ok(final_chunk)).await.is_err() { debug!("Receiver dropped, stopping stream"); @@ -410,10 +506,36 @@ impl DatabricksProvider { } } Err(e) => { - debug!("Failed to parse stream chunk: {} - Data: {}", e, data); + // Check if this is likely an incomplete JSON due to line splitting + // Common indicators: unexpected EOF, unterminated string, etc. + let error_str = e.to_string().to_lowercase(); + if line.starts_with("data: ") && ( + error_str.contains("eof") || + error_str.contains("unterminated") || + error_str.contains("unexpected end") || + error_str.contains("trailing") || + // Also check if the data doesn't end with a proper JSON terminator + (!data.trim_end().ends_with('}') && !data.trim_end().ends_with(']')) + ) { + // This looks like an incomplete data line, save it for the next chunk + debug!("Detected incomplete data line (len={}), buffering for next chunk", line.len()); + incomplete_data_line = line.clone(); + // Continue to next iteration without processing + continue; + } else { + // This is a real parse error, not due to line splitting + debug!("Failed to parse Databricks stream chunk JSON: {} - Data length: {}", e, data.len()); + // For debugging large payloads, log a sample + if data.len() > 1000 { + debug!("JSON parse error - data sample: {}", &data[..std::cmp::min(500, data.len())]); + } + } // Don't error out on parse failures, just continue } } + } else if line.starts_with("event: ") || line.starts_with("id: ") { + // Debug: Log non-data SSE lines (like event: or id:) + debug!("Non-data SSE line: {}", line); } } } @@ -425,27 +547,52 @@ impl DatabricksProvider { } } + // If we have any incomplete data line at the end, try to process it + if !incomplete_data_line.is_empty() { + debug!("Processing final incomplete data line (len={})", incomplete_data_line.len()); + if let Some(data) = incomplete_data_line.strip_prefix("data: ") { + // Try to parse it as-is, it might be complete + if let Ok(_chunk) = serde_json::from_str::(data) { + // Process the chunk (code would be duplicated from above, so in practice + // we'd extract this to a helper function) + debug!("Successfully parsed final incomplete data line"); + } else { + warn!("Failed to parse final incomplete data line"); + } + } + } + // Send final chunk if we haven't already - let final_tool_calls: Vec = current_tool_calls.values() + let final_tool_calls: Vec = current_tool_calls + .values() .filter(|(_, name, _)| !name.is_empty()) .map(|(id, name, args)| ToolCall { - id: if id.is_empty() { format!("tool_{}", name) } else { id.clone() }, + id: if id.is_empty() { + format!("tool_{}", name) + } else { + id.clone() + }, tool: name.clone(), - args: serde_json::from_str(args).unwrap_or(serde_json::Value::Object(serde_json::Map::new())), + args: serde_json::from_str(args) + .unwrap_or(serde_json::Value::Object(serde_json::Map::new())), }) .collect(); let final_chunk = CompletionChunk { content: String::new(), finished: true, - tool_calls: if final_tool_calls.is_empty() { None } else { Some(final_tool_calls) }, + tool_calls: if final_tool_calls.is_empty() { + None + } else { + Some(final_tool_calls) + }, }; let _ = tx.send(Ok(final_chunk)).await; } pub async fn fetch_supported_models(&mut self) -> Result>> { let token = self.auth.get_token().await?; - + let response = match self .client .get(&format!("{}/api/2.0/serving-endpoints", self.host)) @@ -465,8 +612,7 @@ impl DatabricksProvider { if let Ok(error_text) = response.text().await { warn!( "Failed to fetch Databricks models: {} - {}", - status, - error_text + status, error_text ); } else { warn!("Failed to fetch Databricks models: {}", status); @@ -485,9 +631,7 @@ impl DatabricksProvider { let endpoints = match json.get("endpoints").and_then(|v| v.as_array()) { Some(endpoints) => endpoints, None => { - warn!( - "Unexpected response format from Databricks API: missing 'endpoints' array" - ); + warn!("Unexpected response format from Databricks API: missing 'endpoints' array"); return Ok(None); } }; @@ -527,19 +671,25 @@ impl LLMProvider for DatabricksProvider { let temperature = request.temperature.unwrap_or(self.temperature); let request_body = self.create_request_body( - &request.messages, - request.tools.as_deref(), - false, - max_tokens, - temperature + &request.messages, + request.tools.as_deref(), + false, + max_tokens, + temperature, )?; - debug!("Sending request to Databricks API: model={}, max_tokens={}, temperature={}", - self.model, request_body.max_tokens, request_body.temperature); - + debug!( + "Sending request to Databricks API: model={}, max_tokens={}, temperature={}", + self.model, request_body.max_tokens, request_body.temperature + ); + // Debug: Log the full request body when tools are present if request.tools.is_some() { - debug!("Full request body with tools: {}", serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "Failed to serialize".to_string())); + debug!( + "Full request body with tools: {}", + serde_json::to_string_pretty(&request_body) + .unwrap_or_else(|_| "Failed to serialize".to_string()) + ); } let mut provider_clone = self.clone(); @@ -564,7 +714,13 @@ impl LLMProvider for DatabricksProvider { debug!("Raw Databricks API response: {}", response_text); let databricks_response: DatabricksResponse = serde_json::from_str(&response_text) - .map_err(|e| anyhow!("Failed to parse Databricks response: {} - Response: {}", e, response_text))?; + .map_err(|e| { + anyhow!( + "Failed to parse Databricks response: {} - Response: {}", + e, + response_text + ) + })?; // Debug: Log the parsed response structure debug!("Parsed Databricks response: {:#?}", databricks_response); @@ -580,11 +736,17 @@ impl LLMProvider for DatabricksProvider { // Check if there are tool calls in the response if let Some(first_choice) = databricks_response.choices.first() { if let Some(tool_calls) = &first_choice.message.tool_calls { - debug!("Found {} tool calls in Databricks response", tool_calls.len()); + debug!( + "Found {} tool calls in Databricks response", + tool_calls.len() + ); for (i, tool_call) in tool_calls.iter().enumerate() { - debug!("Tool call {}: {} with args: {}", i, tool_call.function.name, tool_call.function.arguments); + debug!( + "Tool call {}: {} with args: {}", + i, tool_call.function.name, tool_call.function.arguments + ); } - + // For now, we'll return the content as-is since g3 handles tool calls via streaming // In the future, we might need to convert these to the internal format } @@ -618,18 +780,24 @@ impl LLMProvider for DatabricksProvider { let temperature = request.temperature.unwrap_or(self.temperature); let request_body = self.create_request_body( - &request.messages, - request.tools.as_deref(), - true, - max_tokens, - temperature + &request.messages, + request.tools.as_deref(), + true, + max_tokens, + temperature, )?; - debug!("Sending streaming request to Databricks API: model={}, max_tokens={}, temperature={}", - self.model, request_body.max_tokens, request_body.temperature); - + debug!( + "Sending streaming request to Databricks API: model={}, max_tokens={}, temperature={}", + self.model, request_body.max_tokens, request_body.temperature + ); + // Debug: Log the full request body - debug!("Full request body: {}", serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "Failed to serialize".to_string())); + debug!( + "Full request body: {}", + serde_json::to_string_pretty(&request_body) + .unwrap_or_else(|_| "Failed to serialize".to_string()) + ); let mut provider_clone = self.clone(); let response = provider_clone @@ -731,6 +899,7 @@ struct DatabricksResponse { #[derive(Debug, Deserialize)] struct DatabricksChoice { message: DatabricksMessage, + #[allow(dead_code)] finish_reason: Option, } @@ -786,7 +955,8 @@ mod tests { "test-model".to_string(), None, None, - ).unwrap(); + ) + .unwrap(); let messages = vec![ Message { @@ -819,14 +989,13 @@ mod tests { "databricks-claude-sonnet-4".to_string(), Some(1000), Some(0.5), - ).unwrap(); + ) + .unwrap(); - let messages = vec![ - Message { - role: MessageRole::User, - content: "Test message".to_string(), - }, - ]; + let messages = vec![Message { + role: MessageRole::User, + content: "Test message".to_string(), + }]; let request_body = provider .create_request_body(&messages, None, false, 1000, 0.5) @@ -847,31 +1016,33 @@ mod tests { "test-model".to_string(), None, None, - ).unwrap(); + ) + .unwrap(); - let tools = vec![ - Tool { - name: "get_weather".to_string(), - description: "Get the current weather".to_string(), - input_schema: serde_json::json!({ - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state" - } - }, - "required": ["location"] - }), - }, - ]; + let tools = vec![Tool { + name: "get_weather".to_string(), + description: "Get the current weather".to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state" + } + }, + "required": ["location"] + }), + }]; let databricks_tools = provider.convert_tools(&tools); assert_eq!(databricks_tools.len(), 1); assert_eq!(databricks_tools[0].r#type, "function"); assert_eq!(databricks_tools[0].function.name, "get_weather"); - assert_eq!(databricks_tools[0].function.description, "Get the current weather"); + assert_eq!( + databricks_tools[0].function.description, + "Get the current weather" + ); } #[test] @@ -882,7 +1053,8 @@ mod tests { "databricks-claude-sonnet-4".to_string(), None, None, - ).unwrap(); + ) + .unwrap(); let llama_provider = DatabricksProvider::from_token( "https://test.databricks.com".to_string(), @@ -890,7 +1062,8 @@ mod tests { "databricks-meta-llama-3-3-70b-instruct".to_string(), None, None, - ).unwrap(); + ) + .unwrap(); let dbrx_provider = DatabricksProvider::from_token( "https://test.databricks.com".to_string(), @@ -898,7 +1071,8 @@ mod tests { "databricks-dbrx-instruct".to_string(), None, None, - ).unwrap(); + ) + .unwrap(); assert!(claude_provider.has_native_tool_calling()); assert!(llama_provider.has_native_tool_calling());