diff --git a/Cargo.lock b/Cargo.lock
index 8cec021..6e2b33a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -732,6 +732,7 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "async-trait",
+ "futures-util",
  "g3-config",
  "g3-execution",
  "g3-providers",
diff --git a/README.md b/README.md
index aef8af3..a9b5ef2 100644
--- a/README.md
+++ b/README.md
@@ -39,19 +39,19 @@ Create a configuration file at `~/.config/g3/config.toml`:
 
 ```toml
 [providers]
-default_provider = "openai"
+default_provider = "anthropic"
+
+[providers.anthropic]
+api_key = "your-anthropic-api-key"
+model = "claude-3-5-sonnet-20241022"
+max_tokens = 4096
+temperature = 0.1
 
 [providers.openai]
 api_key = "your-openai-api-key"
 model = "gpt-4"
 max_tokens = 2048
 temperature = 0.1
-
-[providers.anthropic]
-api_key = "your-anthropic-api-key"
-model = "claude-3-sonnet-20240229"
-max_tokens = 2048
-temperature = 0.1
 ```
 
 ### Local Embedded Models
diff --git a/crates/g3-config/src/lib.rs b/crates/g3-config/src/lib.rs
index 5951985..5b3596f 100644
--- a/crates/g3-config/src/lib.rs
+++ b/crates/g3-config/src/lib.rs
@@ -58,7 +58,7 @@ impl Default for Config {
                 openai: None,
                 anthropic: None,
                 embedded: None,
-                default_provider: "openai".to_string(),
+                default_provider: "anthropic".to_string(),
             },
             agent: AgentConfig {
                 max_context_length: 8192,
diff --git a/crates/g3-core/Cargo.toml b/crates/g3-core/Cargo.toml
index fa2af3f..26788b8 100644
--- a/crates/g3-core/Cargo.toml
+++ b/crates/g3-core/Cargo.toml
@@ -21,3 +21,4 @@ tokio-stream = "0.1"
 llama_cpp = { version = "0.3.2", features = ["metal"] }
 shellexpand = "3.1"
 tokio-util = "0.7"
+futures-util = "0.3"
diff --git a/crates/g3-core/src/lib.rs b/crates/g3-core/src/lib.rs
index 5114a21..164582e 100644
--- a/crates/g3-core/src/lib.rs
+++ b/crates/g3-core/src/lib.rs
@@ -7,7 +7,7 @@ use std::fs;
 use std::path::Path;
 use std::time::{Duration, Instant};
 use tokio_util::sync::CancellationToken;
-use tracing::{error, info, warn};
+use tracing::{error, info, warn, debug};
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct ToolCall {
@@ -229,7 +229,9 @@ impl Agent {
         }
 
         // Set default provider
+        debug!("Setting default provider to: {}", config.providers.default_provider);
         providers.set_default(&config.providers.default_provider)?;
+        debug!("Default provider set successfully");
 
         // Determine context window size based on active provider
         let context_length = Self::determine_context_length(&config, &providers)?;
@@ -364,8 +366,10 @@
 
         let _provider = self.providers.get(None)?;
 
-        let system_prompt = format!(
-            "You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems step by step.
+        // Only add system message if this is the first interaction (empty conversation history)
+        if self.context_window.conversation_history.is_empty() {
+            let system_prompt = format!(
+                "You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems by writing code.
 
 # Tool Call Format
 
@@ -381,48 +385,53 @@ The tool will execute immediately and you'll receive the result to continue with
 
 - Format: {{\"tool\": \"shell\", \"args\": {{\"command\": \"your_command_here\"}}}}
 - Example: {{\"tool\": \"shell\", \"args\": {{\"command\": \"ls ~/Downloads\"}}}}
 
-- **final_output**: Signal task completion
+- **final_output**: Signal task completion with a summary of work done in markdown format
 - Format: {{\"tool\": \"final_output\", \"args\": {{\"summary\": \"what_was_accomplished\"}}}}
 
 # Instructions
 
-1. Break down tasks into small steps
+1. Analyze the request and break down into smaller tasks if appropriate
 2. Execute ONE tool at a time
-3. Wait for the result before proceeding
-4. Use the actual file paths on the system
-5. End with final_output when done
+3. STOP when the original request was satisfied
+4. End with final_output when done
+
+# Response Guidelines
+
+- Use Markdown formatting for all responses except tool calls.
+- Whenever calling tools, use the pronoun 'I'
 
-Let's start with the first step of your task.
 ");
 
-        if show_prompt {
-            println!("🔍 System Prompt:");
-            println!("================");
-            println!("{}", system_prompt);
-            println!("================");
-            println!();
-        }
+            if show_prompt {
+                println!("🔍 System Prompt:");
+                println!("================");
+                println!("{}", system_prompt);
+                println!("================");
+                println!();
+            }
 
-        // Add system message to context window
-        let system_message = Message {
-            role: MessageRole::System,
-            content: system_prompt.clone(),
-        };
-        self.context_window.add_message(system_message.clone());
+            // Add system message to context window
+            let system_message = Message {
+                role: MessageRole::System,
+                content: system_prompt,
+            };
+            self.context_window.add_message(system_message);
+        }
 
         // Add user message to context window
         let user_message = Message {
             role: MessageRole::User,
             content: format!("Task: {}", description),
         };
-        self.context_window.add_message(user_message.clone());
+        self.context_window.add_message(user_message);
 
-        let messages = vec![system_message, user_message];
+        // Use the complete conversation history for the request
+        let messages = self.context_window.conversation_history.clone();
 
         let request = CompletionRequest {
             messages,
             max_tokens: Some(2048),
-            temperature: Some(0.2),
+            temperature: Some(0.1),
             stream: true, // Enable streaming
         };
@@ -520,12 +529,15 @@
         &self.context_window
     }
 
-    async fn stream_completion(&self, request: CompletionRequest) -> Result<(String, Duration)> {
+    async fn stream_completion(
+        &mut self,
+        request: CompletionRequest,
+    ) -> Result<(String, Duration)> {
         self.stream_completion_with_tools(request).await
     }
 
     async fn stream_completion_with_tools(
-        &self,
+        &mut self,
         mut request: CompletionRequest,
     ) -> Result<(String, Duration)> {
         use std::io::{self, Write};
@@ -587,8 +599,34 @@
                         first_token_time = Some(stream_start.elapsed());
                     }
 
-                    // Check for tool calls in the streaming content
-                    if let Some((tool_call, tool_end_pos)) = parser.add_chunk(&chunk.content) {
+                    // Check for tool calls - either from JSON parsing (embedded models)
+                    // or from native tool calls (Anthropic, OpenAI, etc.)
+                    let mut detected_tool_call = None;
+
+                    // First check for native tool calls in the chunk
+                    if let Some(ref tool_calls) = chunk.tool_calls {
+                        debug!("Found native tool calls in chunk: {:?}", tool_calls);
+                        if let Some(first_tool) = tool_calls.first() {
+                            // Convert native tool call to our internal format
+                            detected_tool_call = Some((
+                                crate::ToolCall {
+                                    tool: first_tool.tool.clone(),
+                                    args: first_tool.args.clone(),
+                                },
+                                current_response.len(), // Position doesn't matter for native calls
+                            ));
+                            debug!("Converted native tool call: {:?}", detected_tool_call);
+                        }
+                    } else {
+                        debug!("No native tool calls in chunk, chunk.tool_calls is None");
+                    }
+
+                    // If no native tool calls, check for JSON tool calls in text (embedded models)
+                    if detected_tool_call.is_none() {
+                        detected_tool_call = parser.add_chunk(&chunk.content);
+                    }
+
+                    if let Some((tool_call, tool_end_pos)) = detected_tool_call {
                         // Found a complete tool call! Stop streaming and execute it
                         let content_before_tool = parser.get_content_before_tool(tool_end_pos);
 
@@ -621,7 +659,7 @@
                         // Tool call header
                         println!("┌─ {}", tool_call.tool);
                         if let Some(args_obj) = tool_call.args.as_object() {
-                            for (key, value) in args_obj {
+                            for (_key, value) in args_obj {
                                 let value_str = match value {
                                     serde_json::Value::String(s) => s.clone(),
                                     _ => value.to_string(),
@@ -664,7 +702,7 @@
                         print!("🤖 "); // Continue response indicator
                         io::stdout().flush()?;
 
-                        // Update the conversation with the tool call and result
+                        // Add the tool call and result to the context window immediately
                         let tool_message = Message {
                             role: MessageRole::Assistant,
                             content: format!(
@@ -679,8 +717,12 @@
content: format!("Tool result: {}", tool_result), }; - //request.messages.push(tool_message); - request.messages.push(result_message); + // Add to context window for persistence + self.context_window.add_message(tool_message); + self.context_window.add_message(result_message); + + // Update the request with the new context for next iteration + request.messages = self.context_window.conversation_history.clone(); full_response.push_str(display_content); full_response.push_str(&format!( diff --git a/crates/g3-core/src/providers/anthropic.rs b/crates/g3-core/src/providers/anthropic.rs index 70f3bd3..a563da1 100644 --- a/crates/g3-core/src/providers/anthropic.rs +++ b/crates/g3-core/src/providers/anthropic.rs @@ -1,10 +1,14 @@ -use g3_providers::{LLMProvider, CompletionRequest, CompletionResponse, CompletionStream, CompletionChunk, Usage, Message, MessageRole}; +use g3_providers::{LLMProvider, CompletionRequest, CompletionResponse, CompletionStream, CompletionChunk, Usage, Message, MessageRole, ToolCall}; use anyhow::Result; use reqwest::Client; use serde::{Deserialize, Serialize}; -use tracing::{debug, error}; +use serde_json::Value; +use tracing::{debug, error, info}; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; +use tokio_stream::StreamExt; +use futures_util::stream::Stream; +use std::pin::Pin; pub struct AnthropicProvider { client: Client, @@ -22,26 +26,68 @@ struct AnthropicRequest { max_tokens: Option, #[serde(skip_serializing_if = "Option::is_none")] temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, +} + +#[derive(Debug, Serialize)] +struct AnthropicTool { + name: String, + description: String, + input_schema: Value, } #[derive(Debug, Serialize)] struct AnthropicMessage { role: String, - content: String, + content: AnthropicMessageContent, +} + +#[derive(Debug, Serialize)] +#[serde(untagged)] +enum AnthropicMessageContent { + Text(String), + Blocks(Vec), +} + +#[derive(Debug, Serialize)] +#[serde(tag = "type")] +enum AnthropicContentBlock { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "tool_use")] + ToolUse { + id: String, + name: String, + input: Value, + }, + #[serde(rename = "tool_result")] + ToolResult { + tool_use_id: String, + content: String, + }, } #[derive(Debug, Deserialize)] struct AnthropicResponse { - content: Vec, + content: Vec, usage: AnthropicUsage, model: String, + #[serde(default)] + stop_reason: Option, } #[derive(Debug, Deserialize)] -struct AnthropicContent { - #[serde(rename = "type")] - content_type: String, - text: String, +#[serde(tag = "type")] +enum AnthropicResponseContent { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "tool_use")] + ToolUse { + id: String, + name: String, + input: Value, + }, } #[derive(Debug, Deserialize)] @@ -50,6 +96,24 @@ struct AnthropicUsage { output_tokens: u32, } +// Streaming response structures +#[derive(Debug, Deserialize)] +struct AnthropicStreamEvent { + #[serde(rename = "type")] + event_type: String, + #[serde(flatten)] + data: Value, +} + +#[derive(Debug, Deserialize)] +struct AnthropicStreamDelta { + #[serde(rename = "type")] + delta_type: String, + text: Option, + #[serde(flatten)] + other: Value, +} + impl AnthropicProvider { pub fn new(api_key: String, model: String) -> Result { let client = Client::new(); @@ -68,15 +132,209 @@ impl AnthropicProvider { MessageRole::User => "user".to_string(), MessageRole::Assistant => "assistant".to_string(), }, - content: message.content.clone(), + content: 
         }
     }
+
+    fn create_tools() -> Vec<AnthropicTool> {
+        vec![
+            AnthropicTool {
+                name: "shell".to_string(),
+                description: "Execute a shell command and return the output".to_string(),
+                input_schema: serde_json::json!({
+                    "type": "object",
+                    "properties": {
+                        "command": {
+                            "type": "string",
+                            "description": "The shell command to execute"
+                        }
+                    },
+                    "required": ["command"]
+                }),
+            },
+            AnthropicTool {
+                name: "final_output".to_string(),
+                description: "Provide a final summary or output for the task".to_string(),
+                input_schema: serde_json::json!({
+                    "type": "object",
+                    "properties": {
+                        "summary": {
+                            "type": "string",
+                            "description": "A summary of what was accomplished"
+                        }
+                    },
+                    "required": ["summary"]
+                }),
+            },
+        ]
+    }
+
+    fn extract_content_and_tools(&self, response: &AnthropicResponse) -> (String, Vec<(String, String, Value)>) {
+        let mut text_content = String::new();
+        let mut tool_calls = Vec::new();
+
+        for content in &response.content {
+            match content {
+                AnthropicResponseContent::Text { text } => {
+                    if !text_content.is_empty() {
+                        text_content.push('\n');
+                    }
+                    text_content.push_str(text);
+                }
+                AnthropicResponseContent::ToolUse { id, name, input } => {
+                    tool_calls.push((id.clone(), name.clone(), input.clone()));
+                }
+            }
+        }
+
+        (text_content, tool_calls)
+    }
+
+    async fn execute_tool(&self, tool_name: &str, input: &Value) -> Result<String> {
+        match tool_name {
+            "shell" => {
+                if let Some(command) = input.get("command").and_then(|v| v.as_str()) {
+                    info!("Executing shell command via Anthropic tool: {}", command);
+
+                    // Import the CodeExecutor from g3-execution
+                    use g3_execution::CodeExecutor;
+
+                    let executor = CodeExecutor::new();
+                    match executor.execute_code("bash", command).await {
+                        Ok(result) => {
+                            if result.success {
+                                Ok(if result.stdout.is_empty() {
+                                    "✅ Command executed successfully".to_string()
+                                } else {
+                                    result.stdout
+                                })
+                            } else {
+                                Ok(format!("❌ Command failed: {}", result.stderr))
+                            }
+                        }
+                        Err(e) => {
+                            error!("Shell execution error: {}", e);
+                            Ok(format!("❌ Execution error: {}", e))
+                        }
+                    }
+                } else {
+                    Ok("❌ Missing command argument".to_string())
+                }
+            }
+            "final_output" => {
+                if let Some(summary) = input.get("summary").and_then(|v| v.as_str()) {
+                    Ok(format!("📋 Final Output: {}", summary))
+                } else {
+                    Ok("📋 Task completed".to_string())
+                }
+            }
+            _ => {
+                error!("Unknown tool: {}", tool_name);
+                Ok(format!("❓ Unknown tool: {}", tool_name))
+            }
+        }
+    }
+
+    async fn complete_with_tools(&self, request: CompletionRequest) -> Result<CompletionResponse> {
+        // Separate system messages from other messages
+        let mut system_content: Option<String> = None;
+        let mut non_system_messages = Vec::new();
+
+        for message in &request.messages {
+            match message.role {
+                MessageRole::System => {
+                    // Combine multiple system messages if present
+                    if let Some(existing) = &system_content {
+                        system_content = Some(format!("{}\n\n{}", existing, message.content));
+                    } else {
+                        system_content = Some(message.content.clone());
+                    }
+                }
+                _ => {
+                    non_system_messages.push(self.convert_message(message));
+                }
+            }
+        }
+
+        let anthropic_request = AnthropicRequest {
+            model: self.model.clone(),
+            system: system_content,
+            messages: non_system_messages,
+            max_tokens: request.max_tokens,
+            temperature: request.temperature,
+            tools: Some(Self::create_tools()),
+        };
+
+        let response = self
+            .client
+            .post("https://api.anthropic.com/v1/messages")
+            .header("x-api-key", &self.api_key)
+            .header("Content-Type", "application/json")
+            .header("anthropic-version", "2023-06-01")
+            .json(&anthropic_request)
+            .send()
+            .await?;
+
+        if !response.status().is_success() {
+            let error_text = response.text().await?;
+            error!("Anthropic API error: {}", error_text);
+            anyhow::bail!("Anthropic API error: {}", error_text);
+        }
+
+        let anthropic_response: AnthropicResponse = response.json().await?;
+        debug!("Anthropic response: {:?}", anthropic_response);
+
+        let (text_content, tool_calls) = self.extract_content_and_tools(&anthropic_response);
+
+        // For the completion API, we'll execute tools and return the combined result
+        let final_content = if !tool_calls.is_empty() {
+            info!("Anthropic response contains {} tool calls", tool_calls.len());
+
+            let mut content_with_tools = text_content.clone();
+            for (_id, name, input) in tool_calls {
+                // Execute the tool call
+                let tool_result = match self.execute_tool(&name, &input).await {
+                    Ok(result) => result,
+                    Err(e) => format!("Error executing tool {}: {}", name, e),
+                };
+
+                // Append tool execution info to content
+                content_with_tools.push_str(&format!(
+                    "\n\nTool executed: {} -> {}\n",
+                    name, tool_result
+                ));
+            }
+            content_with_tools
+        } else {
+            text_content
+        };
+
+        Ok(CompletionResponse {
+            content: final_content,
+            usage: Usage {
+                prompt_tokens: anthropic_response.usage.input_tokens,
+                completion_tokens: anthropic_response.usage.output_tokens,
+                total_tokens: anthropic_response.usage.input_tokens + anthropic_response.usage.output_tokens,
+            },
+            model: anthropic_response.model,
+        })
+    }
 }
 
 #[async_trait::async_trait]
 impl LLMProvider for AnthropicProvider {
     async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
-        debug!("Making Anthropic completion request");
+        debug!("Making Anthropic completion request with tools");
+
+        // This is a simplified implementation - for full tool support,
+        // we should use the streaming method with proper tool handling
+        self.complete_with_tools(request).await
+    }
+
+    async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
+        debug!("Making Anthropic streaming request with tools");
+
+        let (tx, rx) = mpsc::channel(100);
 
         // Separate system messages from other messages
         let mut system_content: Option<String> = None;
         let mut non_system_messages = Vec::new();
@@ -104,58 +362,196 @@
             messages: non_system_messages,
             max_tokens: request.max_tokens,
             temperature: request.temperature,
+            tools: Some(Self::create_tools()),
         };
 
-        let response = self
-            .client
-            .post("https://api.anthropic.com/v1/messages")
-            .header("x-api-key", &self.api_key)
-            .header("Content-Type", "application/json")
-            .header("anthropic-version", "2023-06-01")
-            .json(&anthropic_request)
-            .send()
-            .await?;
+        // Add stream parameter
+        let mut request_json = serde_json::to_value(&anthropic_request)?;
+        request_json["stream"] = serde_json::Value::Bool(true);
 
-        if !response.status().is_success() {
-            let error_text = response.text().await?;
-            error!("Anthropic API error: {}", error_text);
-            anyhow::bail!("Anthropic API error: {}", error_text);
-        }
+        let client = self.client.clone();
+        let api_key = self.api_key.clone();
 
-        let anthropic_response: AnthropicResponse = response.json().await?;
-
-        let content = anthropic_response
-            .content
-            .first()
-            .map(|content| content.text.clone())
-            .unwrap_or_default();
-
-        Ok(CompletionResponse {
-            content,
-            usage: Usage {
-                prompt_tokens: anthropic_response.usage.input_tokens,
-                completion_tokens: anthropic_response.usage.output_tokens,
-                total_tokens: anthropic_response.usage.input_tokens + anthropic_response.usage.output_tokens,
-            },
-            model: anthropic_response.model,
-        })
-    }
-
-    async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
-        debug!("Making Anthropic streaming request");
-
-        let (tx, rx) = mpsc::channel(100);
-
-        // For now, just send the complete response as a single chunk
-        // In a real implementation, we'd handle Server-Sent Events
-        let completion = self.complete(request).await?;
-
-        let chunk = CompletionChunk {
-            content: completion.content,
-            finished: true,
-        };
-
-        tx.send(Ok(chunk)).await.map_err(|_| anyhow::anyhow!("Failed to send chunk"))?;
+        tokio::spawn(async move {
+            debug!("Sending Anthropic streaming request with tools: {:?}", request_json);
+            let response = client
+                .post("https://api.anthropic.com/v1/messages")
+                .header("x-api-key", &api_key)
+                .header("Content-Type", "application/json")
+                .header("anthropic-version", "2023-06-01")
+                .json(&request_json)
+                .send()
+                .await;
+
+            let response = match response {
+                Ok(resp) => {
+                    if !resp.status().is_success() {
+                        let error_text = resp.text().await.unwrap_or_default();
+                        let _ = tx.send(Err(anyhow::anyhow!("Anthropic API error: {}", error_text))).await;
+                        return;
+                    }
+                    resp
+                }
+                Err(e) => {
+                    let _ = tx.send(Err(e.into())).await;
+                    return;
+                }
+            };
+
+            // Handle Server-Sent Events
+            let mut stream = response.bytes_stream();
+            let mut buffer = String::new();
+            let mut pending_tool_calls = Vec::new();
+
+            while let Some(chunk_result) = stream.next().await {
+                let chunk = match chunk_result {
+                    Ok(bytes) => bytes,
+                    Err(e) => {
+                        let _ = tx.send(Err(e.into())).await;
+                        break;
+                    }
+                };
+
+                let chunk_str = match std::str::from_utf8(&chunk) {
+                    Ok(s) => s,
+                    Err(_) => continue,
+                };
+
+                buffer.push_str(chunk_str);
+
+                // Process complete lines
+                while let Some(line_end) = buffer.find('\n') {
+                    let line = buffer[..line_end].trim().to_string();
+                    buffer.drain(..line_end + 1);
+
+                    if line.is_empty() {
+                        continue;
+                    }
+
+                    // Parse SSE format: "data: {...}"
+                    if let Some(data) = line.strip_prefix("data: ") {
+                        debug!("Raw SSE data: {}", data);
+                        if data == "[DONE]" {
+                            // Send any pending tool calls first
+                            if !pending_tool_calls.is_empty() {
+                                let tool_chunk = CompletionChunk {
+                                    content: String::new(),
+                                    finished: false,
+                                    tool_calls: Some(pending_tool_calls.clone()),
+                                };
+                                let _ = tx.send(Ok(tool_chunk)).await;
+                                pending_tool_calls.clear();
+                            }
+
+                            // Send final chunk
+                            let final_chunk = CompletionChunk {
+                                content: String::new(),
+                                finished: true,
+                                tool_calls: None,
+                            };
+                            let _ = tx.send(Ok(final_chunk)).await;
+                            break;
+                        }
+
+                        // Parse the JSON event
+                        match serde_json::from_str::<AnthropicStreamEvent>(data) {
+                            Ok(event) => {
+                                debug!("Received Anthropic event: type={}, data={:?}", event.event_type, event.data);
+                                match event.event_type.as_str() {
+                                    "content_block_start" => {
+                                        // Check if this is a tool use block
+                                        if let Some(content_block) = event.data.get("content_block") {
+                                            if let Some(block_type) = content_block.get("type").and_then(|t| t.as_str()) {
+                                                if block_type == "tool_use" {
+                                                    // Extract tool call information immediately
+                                                    if let (Some(id), Some(name), Some(input)) = (
+                                                        content_block.get("id").and_then(|v| v.as_str()),
+                                                        content_block.get("name").and_then(|v| v.as_str()),
+                                                        content_block.get("input")
+                                                    ) {
+                                                        let tool_call = ToolCall {
+                                                            id: id.to_string(),
+                                                            tool: name.to_string(),
+                                                            args: input.clone(),
+                                                        };
+                                                        debug!("Added tool call from content_block_start: {:?}", tool_call);
+                                                        pending_tool_calls.push(tool_call);
+                                                    }
+                                                }
+                                            }
+                                        }
+                                    }
+                                    "content_block_delta" => {
+                                        // Extract text from delta
+                                        if let Some(delta) = event.data.get("delta") {
+                                            if let Some(text) = delta.get("text").and_then(|t| t.as_str()) {
+                                                let chunk = CompletionChunk {
+                                                    content: text.to_string(),
+                                                    finished: false,
+                                                    tool_calls: None,
+                                                };
+                                                if tx.send(Ok(chunk)).await.is_err() {
+                                                    break;
+                                                }
+                                            }
+                                        }
+                                    }
+                                    "content_block_stop" => {
+                                        // Check if we have a complete tool use block
+                                        if let Some(content_block) = event.data.get("content_block") {
+                                            if let Some(block_type) = content_block.get("type").and_then(|t| t.as_str()) {
+                                                if block_type == "tool_use" {
+                                                    // Extract tool call information
+                                                    if let (Some(id), Some(name), Some(input)) = (
+                                                        content_block.get("id").and_then(|v| v.as_str()),
+                                                        content_block.get("name").and_then(|v| v.as_str()),
+                                                        content_block.get("input")
+                                                    ) {
+                                                        let tool_call = ToolCall {
+                                                            id: id.to_string(),
+                                                            tool: name.to_string(),
+                                                            args: input.clone(),
+                                                        };
+                                                        pending_tool_calls.push(tool_call);
+                                                    }
+                                                }
+                                            }
+                                        }
+                                        debug!("Content block finished");
+                                    }
+                                    "message_stop" => {
+                                        // Send any pending tool calls first
+                                        if !pending_tool_calls.is_empty() {
+                                            let tool_chunk = CompletionChunk {
+                                                content: String::new(),
+                                                finished: false,
+                                                tool_calls: Some(pending_tool_calls.clone()),
+                                            };
+                                            let _ = tx.send(Ok(tool_chunk)).await;
+                                        }
+
+                                        // Message finished
+                                        let final_chunk = CompletionChunk {
+                                            content: String::new(),
+                                            finished: true,
+                                            tool_calls: None,
+                                        };
+                                        let _ = tx.send(Ok(final_chunk)).await;
+                                        break;
+                                    }
+                                    _ => {
+                                        debug!("Unhandled event type: {}", event.event_type);
+                                    }
+                                }
+                            }
+                            Err(e) => {
+                                debug!("Failed to parse streaming event: {} - Data: {}", e, data);
+                            }
+                        }
+                    }
+                }
+            }
+        });
 
         Ok(ReceiverStream::new(rx))
     }
@@ -167,4 +563,8 @@
     fn model(&self) -> &str {
         &self.model
     }
+
+    fn has_native_tool_calling(&self) -> bool {
+        true
+    }
 }
diff --git a/crates/g3-core/src/providers/embedded.rs b/crates/g3-core/src/providers/embedded.rs
index 8a8874a..c10cdec 100644
--- a/crates/g3-core/src/providers/embedded.rs
+++ b/crates/g3-core/src/providers/embedded.rs
@@ -8,12 +8,12 @@ use llama_cpp::{
     LlamaModel, LlamaParams, LlamaSession, SessionParams,
 };
 use std::path::Path;
-use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
 use tokio::sync::mpsc;
 use tokio::sync::Mutex;
 use tokio_stream::wrappers::ReceiverStream;
-use tracing::{debug, error, info, warn};
+use tracing::{debug, error, info};
 
 pub struct EmbeddedProvider {
     model: Arc<LlamaModel>,
@@ -129,6 +129,9 @@
         debug!("Context calculation: prompt_tokens={}, context_length={}, available_tokens={}, dynamic_max_tokens={}",
                prompt_tokens, self.context_length, available_tokens, dynamic_max_tokens);
 
+        // Get stop sequences before entering the closure
+        let stop_sequences = self.get_stop_sequences();
+
         // Add timeout to the entire operation
         let timeout_duration = std::time::Duration::from_secs(30); // Increased timeout for larger contexts
@@ -202,8 +205,16 @@
                     }
 
                     // Stop on completion markers
-                    if generated_text.contains("</s>") || generated_text.contains("[/INST]") {
-                        debug!("Hit CodeLlama stop sequence at {} tokens", token_count);
+                    let mut hit_stop = false;
+                    for stop_seq in &stop_sequences {
+                        if generated_text.contains(stop_seq) {
+                            debug!("Hit stop sequence '{}' at {} tokens", stop_seq, token_count);
+                            hit_stop = true;
+                            break;
+                        }
+                    }
+
+                    if hit_stop {
                         break;
                     }
                 }
@@ -213,7 +224,8 @@
                     token_count,
                     start_time.elapsed()
                 );
-                Ok((generated_text.trim().to_string(), token_count))
+
+                Ok((generated_text, token_count))
             }),
         )
         .await;
@@ -226,7 +238,8 @@
(dynamic limit was {})", token_count, dynamic_max_tokens ); - Ok(text) + // Clean stop sequences from the generated text after the closure + Ok(self.clean_stop_sequences(&text)) } Err(e) => Err(e), }, @@ -245,6 +258,78 @@ impl EmbeddedProvider { // This is conservative - actual tokenization might be different (text.len() as f32 / 4.0).ceil() as u32 } + + // Helper function to get stop sequences based on model type + fn get_stop_sequences(&self) -> Vec<&'static str> { + // Determine model type from model_name + let model_name_lower = self.model_name.to_lowercase(); + + if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama") { + vec![ + "", // End of sequence + "[/INST]", // End of instruction + "<>", // End of system message + "[INST]", // Start of new instruction (shouldn't appear in response) + "<>", // Start of system (shouldn't appear in response) + ] + } else if model_name_lower.contains("llama") { + vec![ + "", // End of sequence + "[/INST]", // End of instruction + "<>", // End of system message + "### Human:", // Conversation format + "### Assistant:", // Conversation format + "[INST]", // Start of new instruction + ] + } else if model_name_lower.contains("mistral") { + vec![ + "", // End of sequence + "[/INST]", // End of instruction + "<|im_end|>", // ChatML format + ] + } else if model_name_lower.contains("vicuna") || model_name_lower.contains("wizard") { + vec![ + "### Human:", // Conversation format + "### Assistant:", // Conversation format + "USER:", // Alternative format + "ASSISTANT:", // Alternative format + "", // End of sequence + ] + } else if model_name_lower.contains("alpaca") { + vec![ + "### Instruction:", // Alpaca format + "### Response:", // Alpaca format + "### Input:", // Alpaca format + "", // End of sequence + ] + } else { + // Generic/unknown model - use common stop sequences + vec![ + "", // Most common end sequence + "<|endoftext|>", // GPT-style + "<|im_end|>", // ChatML + "### Human:", // Common conversation format + "### Assistant:", // Common conversation format + "[/INST]", // Instruction format + "<>", // System format + ] + } + } + + // Helper function to clean up stop sequences from generated text + fn clean_stop_sequences(&self, text: &str) -> String { + let mut cleaned = text.to_string(); + let stop_sequences = self.get_stop_sequences(); + + for stop_seq in &stop_sequences { + if let Some(pos) = cleaned.find(stop_seq) { + cleaned.truncate(pos); + break; // Only remove the first occurrence to avoid over-truncation + } + } + + cleaned.trim().to_string() + } } #[async_trait::async_trait] @@ -333,6 +418,17 @@ impl LLMProvider for EmbeddedProvider { let mut accumulated_text = String::new(); let mut token_count = 0; + + // Get stop sequences dynamically based on model type + // We need to create a temporary EmbeddedProvider instance to access the method + // Since we can't access self in the spawned task, we'll use a static approach + let stop_sequences = if prompt.contains("[INST]") || prompt.contains("<>") { + // Llama/CodeLlama format detected + vec!["", "[/INST]", "<>", "[INST]", "<>", "### Human:", "### Assistant:"] + } else { + // Generic format + vec!["", "<|endoftext|>", "<|im_end|>", "### Human:", "### Assistant:", "[/INST]", "<>"] + }; // Stream tokens with proper limits while let Some(token) = completion_handle.next_token() { @@ -341,13 +437,52 @@ impl LLMProvider for EmbeddedProvider { accumulated_text.push_str(&token_string); token_count += 1; - let chunk = CompletionChunk { - content: token_string.clone(), - 
-                    finished: false,
-                };
+                // Check if we've hit a stop sequence
+                let mut hit_stop = false;
+                for stop_seq in &stop_sequences {
+                    if accumulated_text.contains(stop_seq) {
+                        debug!("Hit stop sequence in streaming: {}", stop_seq);
+                        hit_stop = true;
+                        break;
+                    }
+                }
 
-                if tx.blocking_send(Ok(chunk)).is_err() {
-                    break; // Receiver dropped
+                if hit_stop {
+                    // Don't send the token that contains the stop sequence
+                    // Instead, send only the part before the stop sequence
+                    let mut clean_accumulated = accumulated_text.clone();
+                    for stop_seq in &stop_sequences {
+                        if let Some(pos) = clean_accumulated.find(stop_seq) {
+                            clean_accumulated.truncate(pos);
+                            break;
+                        }
+                    }
+
+                    // Calculate what part we haven't sent yet
+                    let already_sent_len = accumulated_text.len() - token_string.len();
+                    if clean_accumulated.len() > already_sent_len {
+                        let remaining_to_send = &clean_accumulated[already_sent_len..];
+                        if !remaining_to_send.is_empty() {
+                            let chunk = CompletionChunk {
+                                content: remaining_to_send.to_string(),
+                                finished: false,
+                                tool_calls: None,
+                            };
+                            let _ = tx.blocking_send(Ok(chunk));
+                        }
+                    }
+                    break;
+                } else {
+                    // Normal token, send it
+                    let chunk = CompletionChunk {
+                        content: token_string.clone(),
+                        finished: false,
+                        tool_calls: None,
+                    };
+
+                    if tx.blocking_send(Ok(chunk)).is_err() {
+                        break; // Receiver dropped
+                    }
                 }
 
                 // Enforce token limit
@@ -355,22 +490,13 @@
                     debug!("Reached max token limit in streaming: {}", max_tokens);
                     break;
                 }
-
-                // Stop if we hit common stop sequences
-                if accumulated_text.contains("### Human")
-                    || accumulated_text.contains("### System")
-                    || accumulated_text.contains("<|end|>")
-                    || accumulated_text.contains("</s>")
-                {
-                    debug!("Hit stop sequence in streaming, stopping generation");
-                    break;
-                }
             }
 
             // Send final chunk
             let final_chunk = CompletionChunk {
                 content: String::new(),
                 finished: true,
+                tool_calls: None,
             };
             let _ = tx.blocking_send(Ok(final_chunk));
         });
diff --git a/crates/g3-core/src/providers/openai.rs b/crates/g3-core/src/providers/openai.rs
index 265e946..f18eadb 100644
--- a/crates/g3-core/src/providers/openai.rs
+++ b/crates/g3-core/src/providers/openai.rs
@@ -140,6 +140,7 @@ impl LLMProvider for OpenAIProvider {
         let chunk = CompletionChunk {
             content: completion.content,
             finished: true,
+            tool_calls: None,
         };
 
         tx.send(Ok(chunk)).await.map_err(|_| anyhow::anyhow!("Failed to send chunk"))?;
diff --git a/crates/g3-providers/src/lib.rs b/crates/g3-providers/src/lib.rs
index 8a3bafd..6824ac4 100644
--- a/crates/g3-providers/src/lib.rs
+++ b/crates/g3-providers/src/lib.rs
@@ -16,6 +16,11 @@ pub trait LLMProvider: Send + Sync {
 
     /// Get the model name
     fn model(&self) -> &str;
+
+    /// Check if the provider supports native tool calling
+    fn has_native_tool_calling(&self) -> bool {
+        false
+    }
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -60,6 +65,14 @@ pub type CompletionStream = tokio_stream::wrappers::ReceiverStream<Result<CompletionChunk>>;
 pub struct CompletionChunk {
     pub content: String,
     pub finished: bool,
+    pub tool_calls: Option<Vec<ToolCall>>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ToolCall {
+    pub id: String,
+    pub tool: String,
+    pub args: serde_json::Value,
 }
 
 /// Provider registry for managing multiple LLM providers