tool calling support for anthropic
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -732,6 +732,7 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
|
"futures-util",
|
||||||
"g3-config",
|
"g3-config",
|
||||||
"g3-execution",
|
"g3-execution",
|
||||||
"g3-providers",
|
"g3-providers",
|
||||||
|
|||||||
14
README.md
14
README.md
@@ -39,19 +39,19 @@ Create a configuration file at `~/.config/g3/config.toml`:
|
|||||||
|
|
||||||
```toml
|
```toml
|
||||||
[providers]
|
[providers]
|
||||||
default_provider = "openai"
|
default_provider = "anthropic"
|
||||||
|
|
||||||
|
[providers.anthropic]
|
||||||
|
api_key = "your-anthropic-api-key"
|
||||||
|
model = "claude-3-5-sonnet-20241022"
|
||||||
|
max_tokens = 4096
|
||||||
|
temperature = 0.1
|
||||||
|
|
||||||
[providers.openai]
|
[providers.openai]
|
||||||
api_key = "your-openai-api-key"
|
api_key = "your-openai-api-key"
|
||||||
model = "gpt-4"
|
model = "gpt-4"
|
||||||
max_tokens = 2048
|
max_tokens = 2048
|
||||||
temperature = 0.1
|
temperature = 0.1
|
||||||
|
|
||||||
[providers.anthropic]
|
|
||||||
api_key = "your-anthropic-api-key"
|
|
||||||
model = "claude-3-sonnet-20240229"
|
|
||||||
max_tokens = 2048
|
|
||||||
temperature = 0.1
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Local Embedded Models
|
### Local Embedded Models
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ impl Default for Config {
|
|||||||
openai: None,
|
openai: None,
|
||||||
anthropic: None,
|
anthropic: None,
|
||||||
embedded: None,
|
embedded: None,
|
||||||
default_provider: "openai".to_string(),
|
default_provider: "anthropic".to_string(),
|
||||||
},
|
},
|
||||||
agent: AgentConfig {
|
agent: AgentConfig {
|
||||||
max_context_length: 8192,
|
max_context_length: 8192,
|
||||||
|
|||||||
@@ -21,3 +21,4 @@ tokio-stream = "0.1"
|
|||||||
llama_cpp = { version = "0.3.2", features = ["metal"] }
|
llama_cpp = { version = "0.3.2", features = ["metal"] }
|
||||||
shellexpand = "3.1"
|
shellexpand = "3.1"
|
||||||
tokio-util = "0.7"
|
tokio-util = "0.7"
|
||||||
|
futures-util = "0.3"
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ use std::fs;
|
|||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
use tracing::{error, info, warn};
|
use tracing::{error, info, warn, debug};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ToolCall {
|
pub struct ToolCall {
|
||||||
@@ -229,7 +229,9 @@ impl Agent {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set default provider
|
// Set default provider
|
||||||
|
debug!("Setting default provider to: {}", config.providers.default_provider);
|
||||||
providers.set_default(&config.providers.default_provider)?;
|
providers.set_default(&config.providers.default_provider)?;
|
||||||
|
debug!("Default provider set successfully");
|
||||||
|
|
||||||
// Determine context window size based on active provider
|
// Determine context window size based on active provider
|
||||||
let context_length = Self::determine_context_length(&config, &providers)?;
|
let context_length = Self::determine_context_length(&config, &providers)?;
|
||||||
@@ -364,8 +366,10 @@ impl Agent {
|
|||||||
|
|
||||||
let _provider = self.providers.get(None)?;
|
let _provider = self.providers.get(None)?;
|
||||||
|
|
||||||
let system_prompt = format!(
|
// Only add system message if this is the first interaction (empty conversation history)
|
||||||
"You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems step by step.
|
if self.context_window.conversation_history.is_empty() {
|
||||||
|
let system_prompt = format!(
|
||||||
|
"You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems by writing code.
|
||||||
|
|
||||||
# Tool Call Format
|
# Tool Call Format
|
||||||
|
|
||||||
@@ -381,48 +385,53 @@ The tool will execute immediately and you'll receive the result to continue with
|
|||||||
- Format: {{\"tool\": \"shell\", \"args\": {{\"command\": \"your_command_here\"}}}}
|
- Format: {{\"tool\": \"shell\", \"args\": {{\"command\": \"your_command_here\"}}}}
|
||||||
- Example: {{\"tool\": \"shell\", \"args\": {{\"command\": \"ls ~/Downloads\"}}}}
|
- Example: {{\"tool\": \"shell\", \"args\": {{\"command\": \"ls ~/Downloads\"}}}}
|
||||||
|
|
||||||
- **final_output**: Signal task completion
|
- **final_output**: Signal task completion with a summary of work done in markdown format
|
||||||
- Format: {{\"tool\": \"final_output\", \"args\": {{\"summary\": \"what_was_accomplished\"}}}}
|
- Format: {{\"tool\": \"final_output\", \"args\": {{\"summary\": \"what_was_accomplished\"}}}}
|
||||||
|
|
||||||
# Instructions
|
# Instructions
|
||||||
|
|
||||||
1. Break down tasks into small steps
|
1. Analyze the request and break down into smaller tasks if appropriate
|
||||||
2. Execute ONE tool at a time
|
2. Execute ONE tool at a time
|
||||||
3. Wait for the result before proceeding
|
3. STOP when the original request was satisfied
|
||||||
4. Use the actual file paths on the system
|
4. End with final_output when done
|
||||||
5. End with final_output when done
|
|
||||||
|
# Response Guidelines
|
||||||
|
|
||||||
|
- Use Markdown formatting for all responses except tool calls.
|
||||||
|
- Whenever calling tools, use the pronoun 'I'
|
||||||
|
|
||||||
Let's start with the first step of your task.
|
|
||||||
");
|
");
|
||||||
|
|
||||||
if show_prompt {
|
if show_prompt {
|
||||||
println!("🔍 System Prompt:");
|
println!("🔍 System Prompt:");
|
||||||
println!("================");
|
println!("================");
|
||||||
println!("{}", system_prompt);
|
println!("{}", system_prompt);
|
||||||
println!("================");
|
println!("================");
|
||||||
println!();
|
println!();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add system message to context window
|
// Add system message to context window
|
||||||
let system_message = Message {
|
let system_message = Message {
|
||||||
role: MessageRole::System,
|
role: MessageRole::System,
|
||||||
content: system_prompt.clone(),
|
content: system_prompt,
|
||||||
};
|
};
|
||||||
self.context_window.add_message(system_message.clone());
|
self.context_window.add_message(system_message);
|
||||||
|
}
|
||||||
|
|
||||||
// Add user message to context window
|
// Add user message to context window
|
||||||
let user_message = Message {
|
let user_message = Message {
|
||||||
role: MessageRole::User,
|
role: MessageRole::User,
|
||||||
content: format!("Task: {}", description),
|
content: format!("Task: {}", description),
|
||||||
};
|
};
|
||||||
self.context_window.add_message(user_message.clone());
|
self.context_window.add_message(user_message);
|
||||||
|
|
||||||
let messages = vec![system_message, user_message];
|
// Use the complete conversation history for the request
|
||||||
|
let messages = self.context_window.conversation_history.clone();
|
||||||
|
|
||||||
let request = CompletionRequest {
|
let request = CompletionRequest {
|
||||||
messages,
|
messages,
|
||||||
max_tokens: Some(2048),
|
max_tokens: Some(2048),
|
||||||
temperature: Some(0.2),
|
temperature: Some(0.1),
|
||||||
stream: true, // Enable streaming
|
stream: true, // Enable streaming
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -520,12 +529,15 @@ Let's start with the first step of your task.
|
|||||||
&self.context_window
|
&self.context_window
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn stream_completion(&self, request: CompletionRequest) -> Result<(String, Duration)> {
|
async fn stream_completion(
|
||||||
|
&mut self,
|
||||||
|
request: CompletionRequest,
|
||||||
|
) -> Result<(String, Duration)> {
|
||||||
self.stream_completion_with_tools(request).await
|
self.stream_completion_with_tools(request).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn stream_completion_with_tools(
|
async fn stream_completion_with_tools(
|
||||||
&self,
|
&mut self,
|
||||||
mut request: CompletionRequest,
|
mut request: CompletionRequest,
|
||||||
) -> Result<(String, Duration)> {
|
) -> Result<(String, Duration)> {
|
||||||
use std::io::{self, Write};
|
use std::io::{self, Write};
|
||||||
@@ -587,8 +599,34 @@ Let's start with the first step of your task.
|
|||||||
first_token_time = Some(stream_start.elapsed());
|
first_token_time = Some(stream_start.elapsed());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for tool calls in the streaming content
|
// Check for tool calls - either from JSON parsing (embedded models)
|
||||||
if let Some((tool_call, tool_end_pos)) = parser.add_chunk(&chunk.content) {
|
// or from native tool calls (Anthropic, OpenAI, etc.)
|
||||||
|
let mut detected_tool_call = None;
|
||||||
|
|
||||||
|
// First check for native tool calls in the chunk
|
||||||
|
if let Some(ref tool_calls) = chunk.tool_calls {
|
||||||
|
debug!("Found native tool calls in chunk: {:?}", tool_calls);
|
||||||
|
if let Some(first_tool) = tool_calls.first() {
|
||||||
|
// Convert native tool call to our internal format
|
||||||
|
detected_tool_call = Some((
|
||||||
|
crate::ToolCall {
|
||||||
|
tool: first_tool.tool.clone(),
|
||||||
|
args: first_tool.args.clone(),
|
||||||
|
},
|
||||||
|
current_response.len(), // Position doesn't matter for native calls
|
||||||
|
));
|
||||||
|
debug!("Converted native tool call: {:?}", detected_tool_call);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
debug!("No native tool calls in chunk, chunk.tool_calls is None");
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no native tool calls, check for JSON tool calls in text (embedded models)
|
||||||
|
if detected_tool_call.is_none() {
|
||||||
|
detected_tool_call = parser.add_chunk(&chunk.content);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some((tool_call, tool_end_pos)) = detected_tool_call {
|
||||||
// Found a complete tool call! Stop streaming and execute it
|
// Found a complete tool call! Stop streaming and execute it
|
||||||
let content_before_tool = parser.get_content_before_tool(tool_end_pos);
|
let content_before_tool = parser.get_content_before_tool(tool_end_pos);
|
||||||
|
|
||||||
@@ -621,7 +659,7 @@ Let's start with the first step of your task.
|
|||||||
// Tool call header
|
// Tool call header
|
||||||
println!("┌─ {}", tool_call.tool);
|
println!("┌─ {}", tool_call.tool);
|
||||||
if let Some(args_obj) = tool_call.args.as_object() {
|
if let Some(args_obj) = tool_call.args.as_object() {
|
||||||
for (key, value) in args_obj {
|
for (_key, value) in args_obj {
|
||||||
let value_str = match value {
|
let value_str = match value {
|
||||||
serde_json::Value::String(s) => s.clone(),
|
serde_json::Value::String(s) => s.clone(),
|
||||||
_ => value.to_string(),
|
_ => value.to_string(),
|
||||||
@@ -664,7 +702,7 @@ Let's start with the first step of your task.
|
|||||||
print!("🤖 "); // Continue response indicator
|
print!("🤖 "); // Continue response indicator
|
||||||
io::stdout().flush()?;
|
io::stdout().flush()?;
|
||||||
|
|
||||||
// Update the conversation with the tool call and result
|
// Add the tool call and result to the context window immediately
|
||||||
let tool_message = Message {
|
let tool_message = Message {
|
||||||
role: MessageRole::Assistant,
|
role: MessageRole::Assistant,
|
||||||
content: format!(
|
content: format!(
|
||||||
@@ -679,8 +717,12 @@ Let's start with the first step of your task.
|
|||||||
content: format!("Tool result: {}", tool_result),
|
content: format!("Tool result: {}", tool_result),
|
||||||
};
|
};
|
||||||
|
|
||||||
//request.messages.push(tool_message);
|
// Add to context window for persistence
|
||||||
request.messages.push(result_message);
|
self.context_window.add_message(tool_message);
|
||||||
|
self.context_window.add_message(result_message);
|
||||||
|
|
||||||
|
// Update the request with the new context for next iteration
|
||||||
|
request.messages = self.context_window.conversation_history.clone();
|
||||||
|
|
||||||
full_response.push_str(display_content);
|
full_response.push_str(display_content);
|
||||||
full_response.push_str(&format!(
|
full_response.push_str(&format!(
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
use g3_providers::{LLMProvider, CompletionRequest, CompletionResponse, CompletionStream, CompletionChunk, Usage, Message, MessageRole};
|
use g3_providers::{LLMProvider, CompletionRequest, CompletionResponse, CompletionStream, CompletionChunk, Usage, Message, MessageRole, ToolCall};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tracing::{debug, error};
|
use serde_json::Value;
|
||||||
|
use tracing::{debug, error, info};
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use tokio_stream::wrappers::ReceiverStream;
|
use tokio_stream::wrappers::ReceiverStream;
|
||||||
|
use tokio_stream::StreamExt;
|
||||||
|
use futures_util::stream::Stream;
|
||||||
|
use std::pin::Pin;
|
||||||
|
|
||||||
pub struct AnthropicProvider {
|
pub struct AnthropicProvider {
|
||||||
client: Client,
|
client: Client,
|
||||||
@@ -22,26 +26,68 @@ struct AnthropicRequest {
|
|||||||
max_tokens: Option<u32>,
|
max_tokens: Option<u32>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
temperature: Option<f32>,
|
temperature: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tools: Option<Vec<AnthropicTool>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct AnthropicTool {
|
||||||
|
name: String,
|
||||||
|
description: String,
|
||||||
|
input_schema: Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
struct AnthropicMessage {
|
struct AnthropicMessage {
|
||||||
role: String,
|
role: String,
|
||||||
content: String,
|
content: AnthropicMessageContent,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
enum AnthropicMessageContent {
|
||||||
|
Text(String),
|
||||||
|
Blocks(Vec<AnthropicContentBlock>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
enum AnthropicContentBlock {
|
||||||
|
#[serde(rename = "text")]
|
||||||
|
Text { text: String },
|
||||||
|
#[serde(rename = "tool_use")]
|
||||||
|
ToolUse {
|
||||||
|
id: String,
|
||||||
|
name: String,
|
||||||
|
input: Value,
|
||||||
|
},
|
||||||
|
#[serde(rename = "tool_result")]
|
||||||
|
ToolResult {
|
||||||
|
tool_use_id: String,
|
||||||
|
content: String,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct AnthropicResponse {
|
struct AnthropicResponse {
|
||||||
content: Vec<AnthropicContent>,
|
content: Vec<AnthropicResponseContent>,
|
||||||
usage: AnthropicUsage,
|
usage: AnthropicUsage,
|
||||||
model: String,
|
model: String,
|
||||||
|
#[serde(default)]
|
||||||
|
stop_reason: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct AnthropicContent {
|
#[serde(tag = "type")]
|
||||||
#[serde(rename = "type")]
|
enum AnthropicResponseContent {
|
||||||
content_type: String,
|
#[serde(rename = "text")]
|
||||||
text: String,
|
Text { text: String },
|
||||||
|
#[serde(rename = "tool_use")]
|
||||||
|
ToolUse {
|
||||||
|
id: String,
|
||||||
|
name: String,
|
||||||
|
input: Value,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
@@ -50,6 +96,24 @@ struct AnthropicUsage {
|
|||||||
output_tokens: u32,
|
output_tokens: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Streaming response structures
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct AnthropicStreamEvent {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
event_type: String,
|
||||||
|
#[serde(flatten)]
|
||||||
|
data: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct AnthropicStreamDelta {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
delta_type: String,
|
||||||
|
text: Option<String>,
|
||||||
|
#[serde(flatten)]
|
||||||
|
other: Value,
|
||||||
|
}
|
||||||
|
|
||||||
impl AnthropicProvider {
|
impl AnthropicProvider {
|
||||||
pub fn new(api_key: String, model: String) -> Result<Self> {
|
pub fn new(api_key: String, model: String) -> Result<Self> {
|
||||||
let client = Client::new();
|
let client = Client::new();
|
||||||
@@ -68,15 +132,209 @@ impl AnthropicProvider {
|
|||||||
MessageRole::User => "user".to_string(),
|
MessageRole::User => "user".to_string(),
|
||||||
MessageRole::Assistant => "assistant".to_string(),
|
MessageRole::Assistant => "assistant".to_string(),
|
||||||
},
|
},
|
||||||
content: message.content.clone(),
|
content: AnthropicMessageContent::Text(message.content.clone()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn create_tools() -> Vec<AnthropicTool> {
|
||||||
|
vec![
|
||||||
|
AnthropicTool {
|
||||||
|
name: "shell".to_string(),
|
||||||
|
description: "Execute a shell command and return the output".to_string(),
|
||||||
|
input_schema: serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The shell command to execute"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["command"]
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
AnthropicTool {
|
||||||
|
name: "final_output".to_string(),
|
||||||
|
description: "Provide a final summary or output for the task".to_string(),
|
||||||
|
input_schema: serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"summary": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "A summary of what was accomplished"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["summary"]
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_content_and_tools(&self, response: &AnthropicResponse) -> (String, Vec<(String, String, Value)>) {
|
||||||
|
let mut text_content = String::new();
|
||||||
|
let mut tool_calls = Vec::new();
|
||||||
|
|
||||||
|
for content in &response.content {
|
||||||
|
match content {
|
||||||
|
AnthropicResponseContent::Text { text } => {
|
||||||
|
if !text_content.is_empty() {
|
||||||
|
text_content.push('\n');
|
||||||
|
}
|
||||||
|
text_content.push_str(text);
|
||||||
|
}
|
||||||
|
AnthropicResponseContent::ToolUse { id, name, input } => {
|
||||||
|
tool_calls.push((id.clone(), name.clone(), input.clone()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(text_content, tool_calls)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute_tool(&self, tool_name: &str, input: &Value) -> Result<String> {
|
||||||
|
match tool_name {
|
||||||
|
"shell" => {
|
||||||
|
if let Some(command) = input.get("command").and_then(|v| v.as_str()) {
|
||||||
|
info!("Executing shell command via Anthropic tool: {}", command);
|
||||||
|
|
||||||
|
// Import the CodeExecutor from g3-execution
|
||||||
|
use g3_execution::CodeExecutor;
|
||||||
|
|
||||||
|
let executor = CodeExecutor::new();
|
||||||
|
match executor.execute_code("bash", command).await {
|
||||||
|
Ok(result) => {
|
||||||
|
if result.success {
|
||||||
|
Ok(if result.stdout.is_empty() {
|
||||||
|
"✅ Command executed successfully".to_string()
|
||||||
|
} else {
|
||||||
|
result.stdout
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Ok(format!("❌ Command failed: {}", result.stderr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Shell execution error: {}", e);
|
||||||
|
Ok(format!("❌ Execution error: {}", e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Ok("❌ Missing command argument".to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"final_output" => {
|
||||||
|
if let Some(summary) = input.get("summary").and_then(|v| v.as_str()) {
|
||||||
|
Ok(format!("📋 Final Output: {}", summary))
|
||||||
|
} else {
|
||||||
|
Ok("📋 Task completed".to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
error!("Unknown tool: {}", tool_name);
|
||||||
|
Ok(format!("❓ Unknown tool: {}", tool_name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn complete_with_tools(&self, request: CompletionRequest) -> Result<CompletionResponse> {
|
||||||
|
// Separate system messages from other messages
|
||||||
|
let mut system_content: Option<String> = None;
|
||||||
|
let mut non_system_messages = Vec::new();
|
||||||
|
|
||||||
|
for message in &request.messages {
|
||||||
|
match message.role {
|
||||||
|
MessageRole::System => {
|
||||||
|
// Combine multiple system messages if present
|
||||||
|
if let Some(existing) = &system_content {
|
||||||
|
system_content = Some(format!("{}\n\n{}", existing, message.content));
|
||||||
|
} else {
|
||||||
|
system_content = Some(message.content.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
non_system_messages.push(self.convert_message(message));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let anthropic_request = AnthropicRequest {
|
||||||
|
model: self.model.clone(),
|
||||||
|
system: system_content,
|
||||||
|
messages: non_system_messages,
|
||||||
|
max_tokens: request.max_tokens,
|
||||||
|
temperature: request.temperature,
|
||||||
|
tools: Some(Self::create_tools()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post("https://api.anthropic.com/v1/messages")
|
||||||
|
.header("x-api-key", &self.api_key)
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("anthropic-version", "2023-06-01")
|
||||||
|
.json(&anthropic_request)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
let error_text = response.text().await?;
|
||||||
|
error!("Anthropic API error: {}", error_text);
|
||||||
|
anyhow::bail!("Anthropic API error: {}", error_text);
|
||||||
|
}
|
||||||
|
|
||||||
|
let anthropic_response: AnthropicResponse = response.json().await?;
|
||||||
|
debug!("Anthropic response: {:?}", anthropic_response);
|
||||||
|
|
||||||
|
let (text_content, tool_calls) = self.extract_content_and_tools(&anthropic_response);
|
||||||
|
|
||||||
|
// For the completion API, we'll execute tools and return the combined result
|
||||||
|
let final_content = if !tool_calls.is_empty() {
|
||||||
|
info!("Anthropic response contains {} tool calls", tool_calls.len());
|
||||||
|
|
||||||
|
let mut content_with_tools = text_content.clone();
|
||||||
|
for (_id, name, input) in tool_calls {
|
||||||
|
// Execute the tool call
|
||||||
|
let tool_result = match self.execute_tool(&name, &input).await {
|
||||||
|
Ok(result) => result,
|
||||||
|
Err(e) => format!("Error executing tool {}: {}", name, e),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Append tool execution info to content
|
||||||
|
content_with_tools.push_str(&format!(
|
||||||
|
"\n\nTool executed: {} -> {}\n",
|
||||||
|
name, tool_result
|
||||||
|
));
|
||||||
|
}
|
||||||
|
content_with_tools
|
||||||
|
} else {
|
||||||
|
text_content
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(CompletionResponse {
|
||||||
|
content: final_content,
|
||||||
|
usage: Usage {
|
||||||
|
prompt_tokens: anthropic_response.usage.input_tokens,
|
||||||
|
completion_tokens: anthropic_response.usage.output_tokens,
|
||||||
|
total_tokens: anthropic_response.usage.input_tokens + anthropic_response.usage.output_tokens,
|
||||||
|
},
|
||||||
|
model: anthropic_response.model,
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl LLMProvider for AnthropicProvider {
|
impl LLMProvider for AnthropicProvider {
|
||||||
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
|
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
|
||||||
debug!("Making Anthropic completion request");
|
debug!("Making Anthropic completion request with tools");
|
||||||
|
|
||||||
|
// This is a simplified implementation - for full tool support,
|
||||||
|
// we should use the streaming method with proper tool handling
|
||||||
|
self.complete_with_tools(request).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
|
||||||
|
debug!("Making Anthropic streaming request with tools");
|
||||||
|
|
||||||
|
let (tx, rx) = mpsc::channel(100);
|
||||||
|
|
||||||
// Separate system messages from other messages
|
// Separate system messages from other messages
|
||||||
let mut system_content: Option<String> = None;
|
let mut system_content: Option<String> = None;
|
||||||
@@ -104,58 +362,196 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
messages: non_system_messages,
|
messages: non_system_messages,
|
||||||
max_tokens: request.max_tokens,
|
max_tokens: request.max_tokens,
|
||||||
temperature: request.temperature,
|
temperature: request.temperature,
|
||||||
|
tools: Some(Self::create_tools()),
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = self
|
// Add stream parameter
|
||||||
.client
|
let mut request_json = serde_json::to_value(&anthropic_request)?;
|
||||||
.post("https://api.anthropic.com/v1/messages")
|
request_json["stream"] = serde_json::Value::Bool(true);
|
||||||
.header("x-api-key", &self.api_key)
|
|
||||||
.header("Content-Type", "application/json")
|
|
||||||
.header("anthropic-version", "2023-06-01")
|
|
||||||
.json(&anthropic_request)
|
|
||||||
.send()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if !response.status().is_success() {
|
let client = self.client.clone();
|
||||||
let error_text = response.text().await?;
|
let api_key = self.api_key.clone();
|
||||||
error!("Anthropic API error: {}", error_text);
|
|
||||||
anyhow::bail!("Anthropic API error: {}", error_text);
|
|
||||||
}
|
|
||||||
|
|
||||||
let anthropic_response: AnthropicResponse = response.json().await?;
|
tokio::spawn(async move {
|
||||||
|
debug!("Sending Anthropic streaming request with tools: {:?}", request_json);
|
||||||
|
let response = client
|
||||||
|
.post("https://api.anthropic.com/v1/messages")
|
||||||
|
.header("x-api-key", &api_key)
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("anthropic-version", "2023-06-01")
|
||||||
|
.json(&request_json)
|
||||||
|
.send()
|
||||||
|
.await;
|
||||||
|
|
||||||
let content = anthropic_response
|
let response = match response {
|
||||||
.content
|
Ok(resp) => {
|
||||||
.first()
|
if !resp.status().is_success() {
|
||||||
.map(|content| content.text.clone())
|
let error_text = resp.text().await.unwrap_or_default();
|
||||||
.unwrap_or_default();
|
let _ = tx.send(Err(anyhow::anyhow!("Anthropic API error: {}", error_text))).await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
resp
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let _ = tx.send(Err(e.into())).await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
Ok(CompletionResponse {
|
// Handle Server-Sent Events
|
||||||
content,
|
let mut stream = response.bytes_stream();
|
||||||
usage: Usage {
|
let mut buffer = String::new();
|
||||||
prompt_tokens: anthropic_response.usage.input_tokens,
|
let mut pending_tool_calls = Vec::new();
|
||||||
completion_tokens: anthropic_response.usage.output_tokens,
|
|
||||||
total_tokens: anthropic_response.usage.input_tokens + anthropic_response.usage.output_tokens,
|
|
||||||
},
|
|
||||||
model: anthropic_response.model,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
|
while let Some(chunk_result) = stream.next().await {
|
||||||
debug!("Making Anthropic streaming request");
|
let chunk = match chunk_result {
|
||||||
|
Ok(bytes) => bytes,
|
||||||
|
Err(e) => {
|
||||||
|
let _ = tx.send(Err(e.into())).await;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let (tx, rx) = mpsc::channel(100);
|
let chunk_str = match std::str::from_utf8(&chunk) {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
|
||||||
// For now, just send the complete response as a single chunk
|
buffer.push_str(chunk_str);
|
||||||
// In a real implementation, we'd handle Server-Sent Events
|
|
||||||
let completion = self.complete(request).await?;
|
|
||||||
|
|
||||||
let chunk = CompletionChunk {
|
// Process complete lines
|
||||||
content: completion.content,
|
while let Some(line_end) = buffer.find('\n') {
|
||||||
finished: true,
|
let line = buffer[..line_end].trim().to_string();
|
||||||
};
|
buffer.drain(..line_end + 1);
|
||||||
|
|
||||||
tx.send(Ok(chunk)).await.map_err(|_| anyhow::anyhow!("Failed to send chunk"))?;
|
if line.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse SSE format: "data: {...}"
|
||||||
|
if let Some(data) = line.strip_prefix("data: ") {
|
||||||
|
debug!("Raw SSE data: {}", data);
|
||||||
|
if data == "[DONE]" {
|
||||||
|
// Send any pending tool calls first
|
||||||
|
if !pending_tool_calls.is_empty() {
|
||||||
|
let tool_chunk = CompletionChunk {
|
||||||
|
content: String::new(),
|
||||||
|
finished: false,
|
||||||
|
tool_calls: Some(pending_tool_calls.clone()),
|
||||||
|
};
|
||||||
|
let _ = tx.send(Ok(tool_chunk)).await;
|
||||||
|
pending_tool_calls.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send final chunk
|
||||||
|
let final_chunk = CompletionChunk {
|
||||||
|
content: String::new(),
|
||||||
|
finished: true,
|
||||||
|
tool_calls: None,
|
||||||
|
};
|
||||||
|
let _ = tx.send(Ok(final_chunk)).await;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the JSON event
|
||||||
|
match serde_json::from_str::<AnthropicStreamEvent>(data) {
|
||||||
|
Ok(event) => {
|
||||||
|
debug!("Received Anthropic event: type={}, data={:?}", event.event_type, event.data);
|
||||||
|
match event.event_type.as_str() {
|
||||||
|
"content_block_start" => {
|
||||||
|
// Check if this is a tool use block
|
||||||
|
if let Some(content_block) = event.data.get("content_block") {
|
||||||
|
if let Some(block_type) = content_block.get("type").and_then(|t| t.as_str()) {
|
||||||
|
if block_type == "tool_use" {
|
||||||
|
// Extract tool call information immediately
|
||||||
|
if let (Some(id), Some(name), Some(input)) = (
|
||||||
|
content_block.get("id").and_then(|v| v.as_str()),
|
||||||
|
content_block.get("name").and_then(|v| v.as_str()),
|
||||||
|
content_block.get("input")
|
||||||
|
) {
|
||||||
|
let tool_call = ToolCall {
|
||||||
|
id: id.to_string(),
|
||||||
|
tool: name.to_string(),
|
||||||
|
args: input.clone(),
|
||||||
|
};
|
||||||
|
debug!("Added tool call from content_block_start: {:?}", tool_call);
|
||||||
|
pending_tool_calls.push(tool_call);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"content_block_delta" => {
|
||||||
|
// Extract text from delta
|
||||||
|
if let Some(delta) = event.data.get("delta") {
|
||||||
|
if let Some(text) = delta.get("text").and_then(|t| t.as_str()) {
|
||||||
|
let chunk = CompletionChunk {
|
||||||
|
content: text.to_string(),
|
||||||
|
finished: false,
|
||||||
|
tool_calls: None,
|
||||||
|
};
|
||||||
|
if tx.send(Ok(chunk)).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"content_block_stop" => {
|
||||||
|
// Check if we have a complete tool use block
|
||||||
|
if let Some(content_block) = event.data.get("content_block") {
|
||||||
|
if let Some(block_type) = content_block.get("type").and_then(|t| t.as_str()) {
|
||||||
|
if block_type == "tool_use" {
|
||||||
|
// Extract tool call information
|
||||||
|
if let (Some(id), Some(name), Some(input)) = (
|
||||||
|
content_block.get("id").and_then(|v| v.as_str()),
|
||||||
|
content_block.get("name").and_then(|v| v.as_str()),
|
||||||
|
content_block.get("input")
|
||||||
|
) {
|
||||||
|
let tool_call = ToolCall {
|
||||||
|
id: id.to_string(),
|
||||||
|
tool: name.to_string(),
|
||||||
|
args: input.clone(),
|
||||||
|
};
|
||||||
|
pending_tool_calls.push(tool_call);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
debug!("Content block finished");
|
||||||
|
}
|
||||||
|
"message_stop" => {
|
||||||
|
// Send any pending tool calls first
|
||||||
|
if !pending_tool_calls.is_empty() {
|
||||||
|
let tool_chunk = CompletionChunk {
|
||||||
|
content: String::new(),
|
||||||
|
finished: false,
|
||||||
|
tool_calls: Some(pending_tool_calls.clone()),
|
||||||
|
};
|
||||||
|
let _ = tx.send(Ok(tool_chunk)).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Message finished
|
||||||
|
let final_chunk = CompletionChunk {
|
||||||
|
content: String::new(),
|
||||||
|
finished: true,
|
||||||
|
tool_calls: None,
|
||||||
|
};
|
||||||
|
let _ = tx.send(Ok(final_chunk)).await;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
debug!("Unhandled event type: {}", event.event_type);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
debug!("Failed to parse streaming event: {} - Data: {}", e, data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
Ok(ReceiverStream::new(rx))
|
Ok(ReceiverStream::new(rx))
|
||||||
}
|
}
|
||||||
@@ -167,4 +563,8 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
fn model(&self) -> &str {
|
fn model(&self) -> &str {
|
||||||
&self.model
|
&self.model
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn has_native_tool_calling(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,12 +8,12 @@ use llama_cpp::{
|
|||||||
LlamaModel, LlamaParams, LlamaSession, SessionParams,
|
LlamaModel, LlamaParams, LlamaSession, SessionParams,
|
||||||
};
|
};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::AtomicBool;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tokio_stream::wrappers::ReceiverStream;
|
use tokio_stream::wrappers::ReceiverStream;
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info};
|
||||||
|
|
||||||
pub struct EmbeddedProvider {
|
pub struct EmbeddedProvider {
|
||||||
model: Arc<LlamaModel>,
|
model: Arc<LlamaModel>,
|
||||||
@@ -129,6 +129,9 @@ impl EmbeddedProvider {
|
|||||||
debug!("Context calculation: prompt_tokens={}, context_length={}, available_tokens={}, dynamic_max_tokens={}",
|
debug!("Context calculation: prompt_tokens={}, context_length={}, available_tokens={}, dynamic_max_tokens={}",
|
||||||
prompt_tokens, self.context_length, available_tokens, dynamic_max_tokens);
|
prompt_tokens, self.context_length, available_tokens, dynamic_max_tokens);
|
||||||
|
|
||||||
|
// Get stop sequences before entering the closure
|
||||||
|
let stop_sequences = self.get_stop_sequences();
|
||||||
|
|
||||||
// Add timeout to the entire operation
|
// Add timeout to the entire operation
|
||||||
let timeout_duration = std::time::Duration::from_secs(30); // Increased timeout for larger contexts
|
let timeout_duration = std::time::Duration::from_secs(30); // Increased timeout for larger contexts
|
||||||
|
|
||||||
@@ -202,8 +205,16 @@ impl EmbeddedProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Stop on completion markers
|
// Stop on completion markers
|
||||||
if generated_text.contains("</s>") || generated_text.contains("[/INST]") {
|
let mut hit_stop = false;
|
||||||
debug!("Hit CodeLlama stop sequence at {} tokens", token_count);
|
for stop_seq in &stop_sequences {
|
||||||
|
if generated_text.contains(stop_seq) {
|
||||||
|
debug!("Hit stop sequence '{}' at {} tokens", stop_seq, token_count);
|
||||||
|
hit_stop = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if hit_stop {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -213,7 +224,8 @@ impl EmbeddedProvider {
|
|||||||
token_count,
|
token_count,
|
||||||
start_time.elapsed()
|
start_time.elapsed()
|
||||||
);
|
);
|
||||||
Ok((generated_text.trim().to_string(), token_count))
|
|
||||||
|
Ok((generated_text, token_count))
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
@@ -226,7 +238,8 @@ impl EmbeddedProvider {
|
|||||||
"Completed generation: {} tokens (dynamic limit was {})",
|
"Completed generation: {} tokens (dynamic limit was {})",
|
||||||
token_count, dynamic_max_tokens
|
token_count, dynamic_max_tokens
|
||||||
);
|
);
|
||||||
Ok(text)
|
// Clean stop sequences from the generated text after the closure
|
||||||
|
Ok(self.clean_stop_sequences(&text))
|
||||||
}
|
}
|
||||||
Err(e) => Err(e),
|
Err(e) => Err(e),
|
||||||
},
|
},
|
||||||
@@ -245,6 +258,78 @@ impl EmbeddedProvider {
|
|||||||
// This is conservative - actual tokenization might be different
|
// This is conservative - actual tokenization might be different
|
||||||
(text.len() as f32 / 4.0).ceil() as u32
|
(text.len() as f32 / 4.0).ceil() as u32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper function to get stop sequences based on model type
|
||||||
|
fn get_stop_sequences(&self) -> Vec<&'static str> {
|
||||||
|
// Determine model type from model_name
|
||||||
|
let model_name_lower = self.model_name.to_lowercase();
|
||||||
|
|
||||||
|
if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama") {
|
||||||
|
vec![
|
||||||
|
"</s>", // End of sequence
|
||||||
|
"[/INST]", // End of instruction
|
||||||
|
"<</SYS>>", // End of system message
|
||||||
|
"[INST]", // Start of new instruction (shouldn't appear in response)
|
||||||
|
"<<SYS>>", // Start of system (shouldn't appear in response)
|
||||||
|
]
|
||||||
|
} else if model_name_lower.contains("llama") {
|
||||||
|
vec![
|
||||||
|
"</s>", // End of sequence
|
||||||
|
"[/INST]", // End of instruction
|
||||||
|
"<</SYS>>", // End of system message
|
||||||
|
"### Human:", // Conversation format
|
||||||
|
"### Assistant:", // Conversation format
|
||||||
|
"[INST]", // Start of new instruction
|
||||||
|
]
|
||||||
|
} else if model_name_lower.contains("mistral") {
|
||||||
|
vec![
|
||||||
|
"</s>", // End of sequence
|
||||||
|
"[/INST]", // End of instruction
|
||||||
|
"<|im_end|>", // ChatML format
|
||||||
|
]
|
||||||
|
} else if model_name_lower.contains("vicuna") || model_name_lower.contains("wizard") {
|
||||||
|
vec![
|
||||||
|
"### Human:", // Conversation format
|
||||||
|
"### Assistant:", // Conversation format
|
||||||
|
"USER:", // Alternative format
|
||||||
|
"ASSISTANT:", // Alternative format
|
||||||
|
"</s>", // End of sequence
|
||||||
|
]
|
||||||
|
} else if model_name_lower.contains("alpaca") {
|
||||||
|
vec![
|
||||||
|
"### Instruction:", // Alpaca format
|
||||||
|
"### Response:", // Alpaca format
|
||||||
|
"### Input:", // Alpaca format
|
||||||
|
"</s>", // End of sequence
|
||||||
|
]
|
||||||
|
} else {
|
||||||
|
// Generic/unknown model - use common stop sequences
|
||||||
|
vec![
|
||||||
|
"</s>", // Most common end sequence
|
||||||
|
"<|endoftext|>", // GPT-style
|
||||||
|
"<|im_end|>", // ChatML
|
||||||
|
"### Human:", // Common conversation format
|
||||||
|
"### Assistant:", // Common conversation format
|
||||||
|
"[/INST]", // Instruction format
|
||||||
|
"<</SYS>>", // System format
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to clean up stop sequences from generated text
|
||||||
|
fn clean_stop_sequences(&self, text: &str) -> String {
|
||||||
|
let mut cleaned = text.to_string();
|
||||||
|
let stop_sequences = self.get_stop_sequences();
|
||||||
|
|
||||||
|
for stop_seq in &stop_sequences {
|
||||||
|
if let Some(pos) = cleaned.find(stop_seq) {
|
||||||
|
cleaned.truncate(pos);
|
||||||
|
break; // Only remove the first occurrence to avoid over-truncation
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cleaned.trim().to_string()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
@@ -334,6 +419,17 @@ impl LLMProvider for EmbeddedProvider {
|
|||||||
let mut accumulated_text = String::new();
|
let mut accumulated_text = String::new();
|
||||||
let mut token_count = 0;
|
let mut token_count = 0;
|
||||||
|
|
||||||
|
// Get stop sequences dynamically based on model type
|
||||||
|
// We need to create a temporary EmbeddedProvider instance to access the method
|
||||||
|
// Since we can't access self in the spawned task, we'll use a static approach
|
||||||
|
let stop_sequences = if prompt.contains("[INST]") || prompt.contains("<<SYS>>") {
|
||||||
|
// Llama/CodeLlama format detected
|
||||||
|
vec!["</s>", "[/INST]", "<</SYS>>", "[INST]", "<<SYS>>", "### Human:", "### Assistant:"]
|
||||||
|
} else {
|
||||||
|
// Generic format
|
||||||
|
vec!["</s>", "<|endoftext|>", "<|im_end|>", "### Human:", "### Assistant:", "[/INST]", "<</SYS>>"]
|
||||||
|
};
|
||||||
|
|
||||||
// Stream tokens with proper limits
|
// Stream tokens with proper limits
|
||||||
while let Some(token) = completion_handle.next_token() {
|
while let Some(token) = completion_handle.next_token() {
|
||||||
let token_string = session.model().token_to_piece(token);
|
let token_string = session.model().token_to_piece(token);
|
||||||
@@ -341,13 +437,52 @@ impl LLMProvider for EmbeddedProvider {
|
|||||||
accumulated_text.push_str(&token_string);
|
accumulated_text.push_str(&token_string);
|
||||||
token_count += 1;
|
token_count += 1;
|
||||||
|
|
||||||
let chunk = CompletionChunk {
|
// Check if we've hit a stop sequence
|
||||||
content: token_string.clone(),
|
let mut hit_stop = false;
|
||||||
finished: false,
|
for stop_seq in &stop_sequences {
|
||||||
};
|
if accumulated_text.contains(stop_seq) {
|
||||||
|
debug!("Hit stop sequence in streaming: {}", stop_seq);
|
||||||
|
hit_stop = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if tx.blocking_send(Ok(chunk)).is_err() {
|
if hit_stop {
|
||||||
break; // Receiver dropped
|
// Don't send the token that contains the stop sequence
|
||||||
|
// Instead, send only the part before the stop sequence
|
||||||
|
let mut clean_accumulated = accumulated_text.clone();
|
||||||
|
for stop_seq in &stop_sequences {
|
||||||
|
if let Some(pos) = clean_accumulated.find(stop_seq) {
|
||||||
|
clean_accumulated.truncate(pos);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate what part we haven't sent yet
|
||||||
|
let already_sent_len = accumulated_text.len() - token_string.len();
|
||||||
|
if clean_accumulated.len() > already_sent_len {
|
||||||
|
let remaining_to_send = &clean_accumulated[already_sent_len..];
|
||||||
|
if !remaining_to_send.is_empty() {
|
||||||
|
let chunk = CompletionChunk {
|
||||||
|
content: remaining_to_send.to_string(),
|
||||||
|
finished: false,
|
||||||
|
tool_calls: None,
|
||||||
|
};
|
||||||
|
let _ = tx.blocking_send(Ok(chunk));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
// Normal token, send it
|
||||||
|
let chunk = CompletionChunk {
|
||||||
|
content: token_string.clone(),
|
||||||
|
finished: false,
|
||||||
|
tool_calls: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
if tx.blocking_send(Ok(chunk)).is_err() {
|
||||||
|
break; // Receiver dropped
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enforce token limit
|
// Enforce token limit
|
||||||
@@ -355,22 +490,13 @@ impl LLMProvider for EmbeddedProvider {
|
|||||||
debug!("Reached max token limit in streaming: {}", max_tokens);
|
debug!("Reached max token limit in streaming: {}", max_tokens);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop if we hit common stop sequences
|
|
||||||
if accumulated_text.contains("### Human")
|
|
||||||
|| accumulated_text.contains("### System")
|
|
||||||
|| accumulated_text.contains("<|end|>")
|
|
||||||
|| accumulated_text.contains("</s>")
|
|
||||||
{
|
|
||||||
debug!("Hit stop sequence in streaming, stopping generation");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send final chunk
|
// Send final chunk
|
||||||
let final_chunk = CompletionChunk {
|
let final_chunk = CompletionChunk {
|
||||||
content: String::new(),
|
content: String::new(),
|
||||||
finished: true,
|
finished: true,
|
||||||
|
tool_calls: None,
|
||||||
};
|
};
|
||||||
let _ = tx.blocking_send(Ok(final_chunk));
|
let _ = tx.blocking_send(Ok(final_chunk));
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -140,6 +140,7 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
let chunk = CompletionChunk {
|
let chunk = CompletionChunk {
|
||||||
content: completion.content,
|
content: completion.content,
|
||||||
finished: true,
|
finished: true,
|
||||||
|
tool_calls: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
tx.send(Ok(chunk)).await.map_err(|_| anyhow::anyhow!("Failed to send chunk"))?;
|
tx.send(Ok(chunk)).await.map_err(|_| anyhow::anyhow!("Failed to send chunk"))?;
|
||||||
|
|||||||
@@ -16,6 +16,11 @@ pub trait LLMProvider: Send + Sync {
|
|||||||
|
|
||||||
/// Get the model name
|
/// Get the model name
|
||||||
fn model(&self) -> &str;
|
fn model(&self) -> &str;
|
||||||
|
|
||||||
|
/// Check if the provider supports native tool calling
|
||||||
|
fn has_native_tool_calling(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -60,6 +65,14 @@ pub type CompletionStream = tokio_stream::wrappers::ReceiverStream<Result<Comple
|
|||||||
pub struct CompletionChunk {
|
pub struct CompletionChunk {
|
||||||
pub content: String,
|
pub content: String,
|
||||||
pub finished: bool,
|
pub finished: bool,
|
||||||
|
pub tool_calls: Option<Vec<ToolCall>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ToolCall {
|
||||||
|
pub id: String,
|
||||||
|
pub tool: String,
|
||||||
|
pub args: serde_json::Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Provider registry for managing multiple LLM providers
|
/// Provider registry for managing multiple LLM providers
|
||||||
|
|||||||
Reference in New Issue
Block a user