feat(embedded): add GLM tool format adapter for code fence stripping
GLM-4 models wrap tool calls in markdown code fences and inline backticks, which prevents the streaming parser from detecting them. This adapter: - Strips ```json and ``` code fence markers during streaming - Strips inline backticks from tool call JSON - Handles chunked streaming correctly (buffers potential fence lines) - Transforms GLM native format (<|assistant|>tool_name) to g3 JSON format Also refactors embedded provider into module structure: - embedded/mod.rs - module exports - embedded/provider.rs - main EmbeddedProvider (moved from embedded.rs) - embedded/adapters/mod.rs - ToolFormatAdapter trait - embedded/adapters/glm.rs - GLM-specific adapter Includes 22 unit tests covering edge cases like nested JSON in strings, chunk boundary handling, and false pattern detection. Updates README to show GLM-4 9B now works (⭐⭐) for agentic tasks.
This commit is contained in:
733
crates/g3-providers/src/embedded/adapters/glm.rs
Normal file
733
crates/g3-providers/src/embedded/adapters/glm.rs
Normal file
@@ -0,0 +1,733 @@
|
||||
//! GLM/Z-AI tool format adapter
|
||||
//!
|
||||
//! GLM models can use two tool calling formats:
|
||||
//!
|
||||
//! 1. Native format:
|
||||
//! ```text
|
||||
//! <|assistant|>tool_name
|
||||
//! {"arg": "value"}
|
||||
//! ```
|
||||
//!
|
||||
//! 2. Code-fenced JSON (when following system prompt instructions):
|
||||
//! ```text
|
||||
//! ```json
|
||||
//! {"tool": "shell", "args": {"command": "ls"}}
|
||||
//! ```
|
||||
//! ```
|
||||
//!
|
||||
//! This adapter handles both formats and strips code fences when present.
|
||||
|
||||
use super::{AdapterOutput, ToolFormatAdapter};
|
||||
|
||||
/// Safety limits to prevent unbounded buffering.
/// Maximum bytes held while confirming a potential `<|assistant|>` marker.
const MAX_PATTERN_BUFFER: usize = 20; // `<|assistant|>` is 13 chars
/// Maximum accepted tool-name length, in bytes.
const MAX_TOOL_NAME: usize = 64;
/// Maximum buffered tool-call JSON before bailing out and emitting as prose.
const MAX_JSON_BUFFER: usize = 65536; // 64KB
/// Maximum newlines tolerated between the tool name and its opening `{`.
const MAX_NEWLINES_BEFORE_JSON: usize = 2;

/// The pattern that indicates a tool call in GLM format
const ASSISTANT_PATTERN: &str = "<|assistant|>";
|
||||
|
||||
/// Parser state for the main state machine.
#[derive(Debug, Clone, PartialEq)]
enum ParseState {
    /// Normal prose, watching for `<|assistant|>`
    Prose,
    /// Saw start of potential pattern (e.g., "<|"), buffering to confirm
    MaybePattern,
    /// Confirmed `<|assistant|>`, now reading tool name until newline
    InToolName,
    /// Got tool name, waiting for `{` to start JSON (allowing whitespace and
    /// up to `MAX_NEWLINES_BEFORE_JSON` newlines)
    AwaitingJson { tool_name: String, newline_count: usize },
    /// Inside JSON body, tracking brace depth to find end
    InToolJson { tool_name: String },
}
|
||||
|
||||
/// State for JSON parsing (to handle strings correctly).
///
/// Needed so braces inside string literals (and escaped quotes) do not
/// confuse the brace-depth counter in `process_json_char`.
#[derive(Debug, Clone, Copy, PartialEq)]
enum JsonState {
    /// Normal JSON, counting braces
    Normal,
    /// Inside a string literal, ignore braces
    InString,
    /// Just saw backslash in string, next char is escaped
    InStringEscape,
}
|
||||
|
||||
/// Adapter for GLM/Z-AI model tool calling format.
///
/// Stateful streaming parser: characters accumulate in the buffers below
/// until the state machine can decide whether they are prose, a code fence
/// marker, or part of a tool call.
#[derive(Debug)]
pub struct GlmToolAdapter {
    /// Buffer for accumulating content (pattern prefix, tool name, or JSON,
    /// depending on `state`)
    buffer: String,
    /// Buffer for current line (to detect code fences)
    line_buffer: String,
    /// Whether we're currently inside a code fence
    /// NOTE(review): this flag is only ever assigned `false` (in `new`/`reset`)
    /// and never read in the visible code - confirm whether it is dead.
    in_code_fence: bool,
    /// Current parse state
    state: ParseState,
    /// JSON parsing state (when in InToolJson)
    json_state: JsonState,
    /// Brace depth for JSON parsing
    json_depth: i32,
    /// Content to emit that's been confirmed as prose (or a transformed tool call)
    pending_emit: String,
}
|
||||
|
||||
impl GlmToolAdapter {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buffer: String::new(),
|
||||
line_buffer: String::new(),
|
||||
in_code_fence: false,
|
||||
state: ParseState::Prose,
|
||||
json_state: JsonState::Normal,
|
||||
json_depth: 0,
|
||||
pending_emit: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a character for code fence detection (streaming-safe)
|
||||
/// Returns the string to emit (empty if content should be suppressed)
|
||||
fn process_for_code_fence(&mut self, c: char) -> String {
|
||||
if c == '\n' {
|
||||
// End of line - check if it's a code fence
|
||||
let trimmed = self.line_buffer.trim();
|
||||
if trimmed.starts_with("```") {
|
||||
let after_fence = trimmed.trim_start_matches('`').trim();
|
||||
if after_fence.is_empty() || after_fence.chars().all(|c| c.is_ascii_alphanumeric()) {
|
||||
// This is a code fence marker line - suppress it
|
||||
self.line_buffer.clear();
|
||||
return String::new(); // Don't emit anything for fence lines
|
||||
}
|
||||
}
|
||||
// Not a fence line - just emit the newline
|
||||
// (buffered content was already emitted char-by-char)
|
||||
self.line_buffer.clear();
|
||||
c.to_string()
|
||||
} else {
|
||||
self.line_buffer.push(c);
|
||||
// Only suppress output if the line looks like it could be a code fence
|
||||
// A code fence line starts with optional whitespace then ```
|
||||
let trimmed = self.line_buffer.trim_start();
|
||||
if trimmed.starts_with('`') && trimmed.len() <= 10 {
|
||||
// Potentially a fence marker - buffer until we see newline
|
||||
String::new()
|
||||
} else {
|
||||
// Not a fence - emit the entire buffer (which includes current char)
|
||||
// and clear it since we've emitted everything
|
||||
let result = std::mem::take(&mut self.line_buffer);
|
||||
result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Strip markdown code fence markers from output.
///
/// GLM models sometimes wrap tool calls in code fences like:
/// ```json
/// {"tool": "shell", ...}
/// ```
///
/// Lines consisting solely of a fence marker (optionally followed by a
/// language tag such as "json") are removed; every other line is kept,
/// and the survivors are re-joined with newlines.
fn strip_code_fences(text: &str) -> String {
    let keep = |line: &&str| {
        let trimmed = line.trim();
        if !trimmed.starts_with("```") {
            return true;
        }
        // Content after the backticks: empty or purely alphanumeric means
        // this is just a fence marker line - drop it.
        let rest = trimmed.trim_start_matches('`').trim();
        !(rest.is_empty() || rest.chars().all(|c| c.is_ascii_alphanumeric()))
    };
    text.lines().filter(keep).collect::<Vec<_>>().join("\n")
}
|
||||
|
||||
/// Strip a single pair of wrapping inline-code backticks from text.
///
/// GLM models sometimes wrap tool calls in inline backticks like:
/// `{"tool": "shell", ...}`
///
/// Triple-backtick fences are left alone (handled by `strip_code_fences`),
/// and text that is not wrapped is returned unchanged.
fn strip_inline_backticks(text: &str) -> String {
    let trimmed = text.trim();
    // BUG FIX: require at least two characters, otherwise a lone "`"
    // (starts_with and ends_with both matching the same byte) would make
    // the slice below `[1..0]` and panic.
    if trimmed.len() >= 2
        && trimmed.starts_with('`')
        && trimmed.ends_with('`')
        && !trimmed.starts_with("```")
    {
        // Backticks are ASCII, so these byte offsets are char boundaries.
        trimmed[1..trimmed.len() - 1].to_string()
    } else {
        text.to_string()
    }
}
|
||||
|
||||
/// Check if a string is a valid tool name
|
||||
/// Pattern: starts with letter or underscore, followed by alphanumeric or underscore
|
||||
fn is_valid_tool_name(name: &str) -> bool {
|
||||
if name.is_empty() || name.len() > MAX_TOOL_NAME {
|
||||
return false;
|
||||
}
|
||||
let mut chars = name.chars();
|
||||
match chars.next() {
|
||||
Some(c) if c.is_ascii_alphabetic() || c == '_' => {}
|
||||
_ => return false,
|
||||
}
|
||||
chars.all(|c| c.is_ascii_alphanumeric() || c == '_')
|
||||
}
|
||||
|
||||
/// Process a single character in Prose state.
///
/// The character first passes through code fence filtering; whatever
/// survives is scanned for a potential `<|assistant|>` pattern start.
fn process_prose_char(&mut self, c: char) {
    // First, filter through code fence detection (may emit zero or more
    // previously buffered characters).
    let filtered = self.process_for_code_fence(c);
    for filtered_c in filtered.chars() {
        if filtered_c == '<' {
            // Potential start of pattern - buffer it and switch modes.
            self.buffer.push(filtered_c);
            self.state = ParseState::MaybePattern;
        } else {
            self.pending_emit.push(filtered_c);
        }
    }
    // If empty string, the character is being buffered for code fence detection.
    // NOTE(review): characters after a '<' within the same flushed batch are
    // still handled by this loop (Prose logic) rather than the MaybePattern
    // handler - confirm a pattern split across a fence-buffered line cannot
    // be missed.
}
|
||||
|
||||
/// Process a single character in MaybePattern state.
///
/// Accumulates characters while they remain a prefix of `<|assistant|>`;
/// a full match switches to InToolName, a mismatch re-emits the buffered
/// characters as ordinary prose.
fn process_maybe_pattern_char(&mut self, c: char) {
    self.buffer.push(c);

    // Check if buffer matches start of pattern
    if ASSISTANT_PATTERN.starts_with(&self.buffer) {
        // Still could be the pattern
        if self.buffer == ASSISTANT_PATTERN {
            // Complete pattern match!
            self.buffer.clear();
            self.state = ParseState::InToolName;
        }
        // else: keep buffering
    } else {
        // Not the pattern, emit buffer as prose.
        // NOTE(review): the mismatching char (now last in `buffer`) is
        // emitted as prose even if it is itself '<', so an input like
        // "<<|assistant|>" would miss the second pattern start - confirm
        // whether that can occur in practice.
        self.pending_emit.push_str(&self.buffer);
        self.buffer.clear();
        self.state = ParseState::Prose;
    }

    // Safety: if buffer gets too long, it's not our pattern.
    // (With a 13-byte pattern and MAX_PATTERN_BUFFER = 20 this branch looks
    // unreachable - the prefix branch caps the buffer and the else branch
    // clears it - but it guards against future pattern changes.)
    if self.buffer.len() > MAX_PATTERN_BUFFER {
        self.pending_emit.push_str(&self.buffer);
        self.buffer.clear();
        self.state = ParseState::Prose;
    }
}
|
||||
|
||||
/// Process a single character in InToolName state.
///
/// Reads the tool name up to the first newline. A valid name advances to
/// AwaitingJson; an invalid or over-long name re-emits the consumed text
/// (with the `<|assistant|>` marker restored) as prose.
fn process_tool_name_char(&mut self, c: char) {
    if c == '\n' {
        // End of tool name
        let tool_name = self.buffer.trim().to_string();
        self.buffer.clear();

        if Self::is_valid_tool_name(&tool_name) {
            self.state = ParseState::AwaitingJson {
                tool_name,
                // The newline just consumed counts toward the newline limit.
                newline_count: 1,
            };
        } else {
            // Invalid tool name, emit as prose.
            // NOTE(review): leading whitespace skipped below is not restored
            // here, so reconstruction is not byte-exact - confirm acceptable.
            self.pending_emit.push_str(ASSISTANT_PATTERN);
            self.pending_emit.push_str(&tool_name);
            self.pending_emit.push(c);
            self.state = ParseState::Prose;
        }
    } else if c.is_whitespace() && self.buffer.is_empty() {
        // Skip leading whitespace after <|assistant|>
    } else {
        self.buffer.push(c);

        // Safety: tool name too long - bail out and emit as prose.
        if self.buffer.len() > MAX_TOOL_NAME {
            self.pending_emit.push_str(ASSISTANT_PATTERN);
            self.pending_emit.push_str(&self.buffer);
            self.buffer.clear();
            self.state = ParseState::Prose;
        }
    }
}
|
||||
|
||||
/// Process a single character in AwaitingJson state.
///
/// Waits (across limited whitespace and newlines) for the `{` that opens
/// the tool's argument JSON. Anything else cancels the tool call and
/// re-emits the consumed text as prose.
fn process_awaiting_json_char(&mut self, c: char, tool_name: String, newline_count: usize) {
    if c == '{' {
        // Start of JSON!
        self.buffer.push(c);
        self.json_depth = 1;
        self.json_state = JsonState::Normal;
        self.state = ParseState::InToolJson { tool_name };
    } else if c == '\n' {
        let new_count = newline_count + 1;
        if new_count > MAX_NEWLINES_BEFORE_JSON {
            // Too many newlines, not a tool call - reconstruct as prose
            // (the counted newlines are restored).
            self.pending_emit.push_str(ASSISTANT_PATTERN);
            self.pending_emit.push_str(&tool_name);
            for _ in 0..new_count {
                self.pending_emit.push('\n');
            }
            self.state = ParseState::Prose;
        } else {
            self.state = ParseState::AwaitingJson {
                tool_name,
                newline_count: new_count,
            };
        }
    } else if c.is_whitespace() {
        // Skip (non-newline) whitespace while waiting for JSON.
        // NOTE(review): skipped spaces are not restored if the call is
        // later rejected - confirm this fidelity loss is acceptable.
        self.state = ParseState::AwaitingJson {
            tool_name,
            newline_count,
        };
    } else {
        // Non-JSON character, not a tool call - reconstruct as prose.
        // NOTE(review): exactly one newline is restored here regardless of
        // how many were counted - confirm intended.
        self.pending_emit.push_str(ASSISTANT_PATTERN);
        self.pending_emit.push_str(&tool_name);
        self.pending_emit.push('\n');
        self.pending_emit.push(c);
        self.state = ParseState::Prose;
    }
}
|
||||
|
||||
/// Process a single character in InToolJson state
|
||||
fn process_json_char(&mut self, c: char, tool_name: String) -> Option<String> {
|
||||
self.buffer.push(c);
|
||||
|
||||
// Update JSON state machine
|
||||
match self.json_state {
|
||||
JsonState::Normal => {
|
||||
match c {
|
||||
'{' => self.json_depth += 1,
|
||||
'}' => {
|
||||
self.json_depth -= 1;
|
||||
if self.json_depth == 0 {
|
||||
// JSON complete!
|
||||
let json_args = self.buffer.clone();
|
||||
self.buffer.clear();
|
||||
self.state = ParseState::Prose;
|
||||
|
||||
// Transform to g3 format
|
||||
let transformed = format!(
|
||||
"{{\"tool\": \"{}\", \"args\": {}}}",
|
||||
tool_name, json_args
|
||||
);
|
||||
return Some(transformed);
|
||||
}
|
||||
}
|
||||
'"' => self.json_state = JsonState::InString,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
JsonState::InString => {
|
||||
match c {
|
||||
'\\' => self.json_state = JsonState::InStringEscape,
|
||||
'"' => self.json_state = JsonState::Normal,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
JsonState::InStringEscape => {
|
||||
// Any character after backslash, return to InString
|
||||
self.json_state = JsonState::InString;
|
||||
}
|
||||
}
|
||||
|
||||
// Safety: JSON buffer too large
|
||||
if self.buffer.len() > MAX_JSON_BUFFER {
|
||||
// Emit as malformed - let downstream handle it
|
||||
self.pending_emit.push_str(ASSISTANT_PATTERN);
|
||||
self.pending_emit.push_str(&tool_name);
|
||||
self.pending_emit.push('\n');
|
||||
self.pending_emit.push_str(&self.buffer);
|
||||
self.buffer.clear();
|
||||
self.state = ParseState::Prose;
|
||||
self.json_state = JsonState::Normal;
|
||||
self.json_depth = 0;
|
||||
}
|
||||
|
||||
// Keep state for next iteration
|
||||
self.state = ParseState::InToolJson { tool_name };
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for GlmToolAdapter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolFormatAdapter for GlmToolAdapter {
|
||||
fn handles(&self, model_type: &str) -> bool {
|
||||
model_type.contains("glm")
|
||||
}
|
||||
|
||||
fn process_chunk(&mut self, chunk: &str) -> AdapterOutput {
|
||||
let mut has_tool_call = false;
|
||||
|
||||
for c in chunk.chars() {
|
||||
match self.state.clone() {
|
||||
ParseState::Prose => {
|
||||
self.process_prose_char(c);
|
||||
}
|
||||
ParseState::MaybePattern => {
|
||||
self.process_maybe_pattern_char(c);
|
||||
}
|
||||
ParseState::InToolName => {
|
||||
self.process_tool_name_char(c);
|
||||
}
|
||||
ParseState::AwaitingJson { tool_name, newline_count } => {
|
||||
self.process_awaiting_json_char(c, tool_name, newline_count);
|
||||
}
|
||||
ParseState::InToolJson { tool_name } => {
|
||||
if let Some(transformed) = self.process_json_char(c, tool_name) {
|
||||
self.pending_emit.push('\n');
|
||||
self.pending_emit.push_str(&transformed);
|
||||
has_tool_call = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return accumulated emit content, stripping any code fence markers
|
||||
let raw_emit = std::mem::take(&mut self.pending_emit);
|
||||
let stripped_fences = Self::strip_code_fences(&raw_emit);
|
||||
let emit = Self::strip_inline_backticks(&stripped_fences);
|
||||
AdapterOutput {
|
||||
emit: emit.to_string(),
|
||||
has_tool_call,
|
||||
}
|
||||
}
|
||||
|
||||
/// Flush any remaining buffered content (call at end of stream).
///
/// Whatever the parser was mid-way through (partial pattern, tool name,
/// or incomplete JSON) is emitted verbatim as prose, then all state is
/// reset. Fence/backtick stripping is applied to the final output.
fn flush(&mut self) -> AdapterOutput {
    let mut emit = std::mem::take(&mut self.pending_emit);

    // Emit any buffered content as-is, restoring whatever marker text was
    // consumed while the state machine was mid-pattern.
    match &self.state {
        ParseState::Prose => {
            // Nothing extra to emit
        }
        ParseState::MaybePattern => {
            // Partial `<|assistant|>` prefix is still in `buffer`.
            emit.push_str(&self.buffer);
        }
        ParseState::InToolName => {
            // The full pattern was consumed; restore it before the partial name.
            emit.push_str(ASSISTANT_PATTERN);
            emit.push_str(&self.buffer);
        }
        ParseState::AwaitingJson { tool_name, newline_count } => {
            emit.push_str(ASSISTANT_PATTERN);
            emit.push_str(tool_name);
            for _ in 0..*newline_count {
                emit.push('\n');
            }
        }
        ParseState::InToolJson { tool_name } => {
            // Incomplete JSON: emit the raw text for downstream handling.
            emit.push_str(ASSISTANT_PATTERN);
            emit.push_str(tool_name);
            emit.push('\n');
            emit.push_str(&self.buffer);
        }
    }

    // Flush any remaining line buffer content (if not a code fence)
    if !self.line_buffer.is_empty() {
        let trimmed = self.line_buffer.trim();
        // Same fence test as process_for_code_fence: ``` plus an optional
        // alphanumeric language tag.
        let is_fence = trimmed.starts_with("```") &&
            (trimmed.trim_start_matches('`').trim().is_empty() ||
             trimmed.trim_start_matches('`').trim().chars().all(|c| c.is_ascii_alphanumeric()));
        if !is_fence {
            emit.push_str(&self.line_buffer);
        }
    }

    self.reset();

    // Strip code fences and inline backticks from final output
    let stripped_fences = Self::strip_code_fences(&emit);
    let stripped = Self::strip_inline_backticks(&stripped_fences);
    AdapterOutput {
        emit: stripped,
        has_tool_call: false,
    }
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.buffer.clear();
|
||||
self.line_buffer.clear();
|
||||
self.in_code_fence = false;
|
||||
self.state = ParseState::Prose;
|
||||
self.json_state = JsonState::Normal;
|
||||
self.json_depth = 0;
|
||||
self.pending_emit.clear();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Adapter selection is keyed on "glm" appearing in the model type.
    #[test]
    fn test_handles_glm_models() {
        let adapter = GlmToolAdapter::new();
        assert!(adapter.handles("glm4"));
        assert!(adapter.handles("glm"));
        assert!(adapter.handles("some-glm-variant"));
        assert!(!adapter.handles("qwen"));
        assert!(!adapter.handles("llama"));
    }

    /// Tool names must be identifier-like (no digits first, no '-' or spaces).
    #[test]
    fn test_valid_tool_names() {
        assert!(GlmToolAdapter::is_valid_tool_name("shell"));
        assert!(GlmToolAdapter::is_valid_tool_name("read_file"));
        assert!(GlmToolAdapter::is_valid_tool_name("_private"));
        assert!(GlmToolAdapter::is_valid_tool_name("tool123"));
        assert!(!GlmToolAdapter::is_valid_tool_name(""));
        assert!(!GlmToolAdapter::is_valid_tool_name("123tool"));
        assert!(!GlmToolAdapter::is_valid_tool_name("tool-name"));
        assert!(!GlmToolAdapter::is_valid_tool_name("tool name"));
    }

    /// Single-chunk native-format tool call is detected and transformed.
    #[test]
    fn test_basic_tool_call() {
        let mut adapter = GlmToolAdapter::new();

        let input = "Let me list files.<|assistant|>shell\n{\"command\": \"ls\"}";
        let output = adapter.process_chunk(input);

        assert!(output.has_tool_call);
        assert!(output.emit.contains("Let me list files."));
        assert!(output.emit.contains(r#"{"tool": "shell", "args": {"command": "ls"}}"#));
    }

    /// The pattern and JSON may be split across arbitrary chunk boundaries.
    #[test]
    fn test_tool_call_chunked() {
        let mut adapter = GlmToolAdapter::new();

        // Simulate chunked input
        let chunks = vec![
            "Let me ",
            "list.<|assis",
            "tant|>shell\n{\"co",
            "mmand\": \"ls\"}",
        ];

        let mut full_output = String::new();
        let mut found_tool = false;

        for chunk in chunks {
            let output = adapter.process_chunk(chunk);
            full_output.push_str(&output.emit);
            if output.has_tool_call {
                found_tool = true;
            }
        }

        let final_output = adapter.flush();
        full_output.push_str(&final_output.emit);

        assert!(found_tool);
        assert!(full_output.contains("Let me list."));
        assert!(full_output.contains(r#"{"tool": "shell", "args": {"command": "ls"}}"#));
    }

    /// Braces inside JSON string literals must not end the call early.
    #[test]
    fn test_nested_json_in_string() {
        let mut adapter = GlmToolAdapter::new();

        let input = r#"<|assistant|>shell
{"command": "echo '{\"nested\": true}'"}
Done."#;

        let output = adapter.process_chunk(input);
        let final_output = adapter.flush();

        assert!(output.has_tool_call);
        let full = format!("{}{}", output.emit, final_output.emit);
        assert!(full.contains(r#""args": {"command": "echo '{\"nested\": true}'"}}"#));
    }

    /// Escaped quotes inside strings must not flip the in-string state.
    #[test]
    fn test_escaped_quotes_in_string() {
        let mut adapter = GlmToolAdapter::new();

        let input = r#"<|assistant|>shell
{"command": "echo \"hello\""}
Done."#;

        let output = adapter.process_chunk(input);

        assert!(output.has_tool_call);
        assert!(output.emit.contains(r#""args": {"command": "echo \"hello\""}"#));
    }

    /// Mentioning the pattern in prose (no JSON following) is not a call.
    #[test]
    fn test_false_pattern_in_prose() {
        let mut adapter = GlmToolAdapter::new();

        let input = "The format is <|assistant|>tool_name for GLM models.";
        let output = adapter.process_chunk(input);
        let final_output = adapter.flush();

        // Should not detect as tool call since no JSON follows
        assert!(!output.has_tool_call);
        let full = format!("{}{}", output.emit, final_output.emit);
        assert!(full.contains("<|assistant|>"));
    }

    #[test]
    fn test_multiple_tool_calls() {
        let mut adapter = GlmToolAdapter::new();

        let input = r#"First:<|assistant|>shell
{"command": "ls"}
Second:<|assistant|>read_file
{"path": "test.txt"}"#;

        let output = adapter.process_chunk(input);

        assert!(output.has_tool_call);
        assert!(output.emit.contains(r#"{"tool": "shell"#));
        assert!(output.emit.contains(r#"{"tool": "read_file"#));
    }

    #[test]
    fn test_whitespace_before_json() {
        let mut adapter = GlmToolAdapter::new();

        let input = "<|assistant|>shell\n {\"command\": \"ls\"}";
        let output = adapter.process_chunk(input);

        assert!(output.has_tool_call);
    }

    /// Up to MAX_NEWLINES_BEFORE_JSON newlines are tolerated before `{`.
    #[test]
    fn test_extra_newline_before_json() {
        let mut adapter = GlmToolAdapter::new();

        let input = "<|assistant|>shell\n\n{\"command\": \"ls\"}";
        let output = adapter.process_chunk(input);

        assert!(output.has_tool_call);
    }

    #[test]
    fn test_too_many_newlines_before_json() {
        let mut adapter = GlmToolAdapter::new();

        let input = "<|assistant|>shell\n\n\n{\"command\": \"ls\"}";
        let output = adapter.process_chunk(input);
        let final_output = adapter.flush();

        // Should not detect as tool call - too many newlines
        assert!(!output.has_tool_call);
        let full = format!("{}{}", output.emit, final_output.emit);
        assert!(full.contains("<|assistant|>shell"));
    }

    #[test]
    fn test_invalid_tool_name() {
        let mut adapter = GlmToolAdapter::new();

        let input = "<|assistant|>123invalid\n{\"command\": \"ls\"}";
        let output = adapter.process_chunk(input);
        let final_output = adapter.flush();

        // Should not detect as tool call - invalid name
        assert!(!output.has_tool_call);
        let full = format!("{}{}", output.emit, final_output.emit);
        assert!(full.contains("<|assistant|>123invalid"));
    }

    /// flush() must surface a partial pattern held in the buffer.
    #[test]
    fn test_stream_ends_mid_pattern() {
        let mut adapter = GlmToolAdapter::new();

        let output = adapter.process_chunk("text<|assis");
        let final_output = adapter.flush();

        assert!(!output.has_tool_call);
        let full = format!("{}{}", output.emit, final_output.emit);
        assert_eq!(full, "text<|assis");
    }

    /// flush() must surface incomplete tool-call JSON verbatim.
    #[test]
    fn test_stream_ends_mid_json() {
        let mut adapter = GlmToolAdapter::new();

        let output = adapter.process_chunk("<|assistant|>shell\n{\"command\": \"ls");
        let final_output = adapter.flush();

        assert!(!output.has_tool_call);
        let full = format!("{}{}", output.emit, final_output.emit);
        // Should emit the incomplete content
        assert!(full.contains("<|assistant|>shell"));
        assert!(full.contains("{\"command\": \"ls"));
    }

    /// Unrelated angle-bracket text passes through untouched.
    #[test]
    fn test_prose_with_angle_brackets() {
        let mut adapter = GlmToolAdapter::new();

        let input = "Use <html> tags and <|other|> patterns.";
        let output = adapter.process_chunk(input);
        let final_output = adapter.flush();

        assert!(!output.has_tool_call);
        let full = format!("{}{}", output.emit, final_output.emit);
        assert_eq!(full, input);
    }

    /// reset() returns the adapter to a clean state mid-parse.
    #[test]
    fn test_reset() {
        let mut adapter = GlmToolAdapter::new();

        // Start processing but don't finish
        adapter.process_chunk("<|assistant|>shell\n{\"cmd");

        // Reset
        adapter.reset();

        // Should be back to clean state
        let output = adapter.process_chunk("Normal text");
        assert_eq!(output.emit, "Normal text");
        assert!(!output.has_tool_call);
    }
}
|
||||
|
||||
// NOTE(review): in this rendering these #[test] functions sit OUTSIDE the
// `#[cfg(test)] mod tests` block above. They still build correctly (items
// with #[test] are only compiled in test mode), but confirm against the
// actual file whether they belong inside the test module like the others.

/// Fence markers (with or without a language tag) are removed; other
/// lines, including those after the closing fence, are preserved.
#[test]
fn test_strip_code_fences() {
    assert_eq!(
        GlmToolAdapter::strip_code_fences("```json\n{\"tool\": \"shell\"}\n```"),
        "{\"tool\": \"shell\"}"
    );
    assert_eq!(
        GlmToolAdapter::strip_code_fences("```\n{\"tool\": \"shell\"}\n```"),
        "{\"tool\": \"shell\"}"
    );
    assert_eq!(
        GlmToolAdapter::strip_code_fences("normal text"),
        "normal text"
    );
    assert_eq!(
        GlmToolAdapter::strip_code_fences("```json\ncode\n```\nmore text"),
        "code\nmore text"
    );
}

/// End-to-end: a fenced g3-format tool call streams through with the
/// fence markers removed.
#[test]
fn test_code_fenced_tool_call() {
    let mut adapter = GlmToolAdapter::new();

    let input = "```json\n{\"tool\": \"shell\", \"args\": {\"command\": \"ls\"}}\n```";
    let output = adapter.process_chunk(input);
    let final_output = adapter.flush();

    let full = format!("{}{}", output.emit, final_output.emit);
    // Should strip the code fences
    assert!(!full.contains("```"));
    assert!(full.contains("{\"tool\": \"shell\""));
}
|
||||
98
crates/g3-providers/src/embedded/adapters/mod.rs
Normal file
98
crates/g3-providers/src/embedded/adapters/mod.rs
Normal file
@@ -0,0 +1,98 @@
|
||||
//! Tool format adapters for embedded models
|
||||
//!
|
||||
//! Different model families use different formats for tool calling.
|
||||
//! Adapters transform model-specific formats to g3's standard JSON format:
|
||||
//! `{"tool": "name", "args": {...}}`
|
||||
//!
|
||||
//! This module provides:
|
||||
//! - `ToolFormatAdapter` trait for implementing format transformations
|
||||
//! - `GlmToolAdapter` for GLM/Z-AI models that use `<|assistant|>tool_name` format
|
||||
|
||||
mod glm;
|
||||
|
||||
pub use glm::GlmToolAdapter;
|
||||
|
||||
/// Output from processing a chunk through an adapter.
#[derive(Debug, Clone, Default)]
pub struct AdapterOutput {
    /// Text safe to emit downstream (prose and/or complete tool calls)
    pub emit: String,
    /// True if a complete tool call was detected and transformed into `emit`
    pub has_tool_call: bool,
}
|
||||
|
||||
impl AdapterOutput {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn with_emit(emit: String) -> Self {
|
||||
Self {
|
||||
emit,
|
||||
has_tool_call: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_tool_call(emit: String) -> Self {
|
||||
Self {
|
||||
emit,
|
||||
has_tool_call: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for adapting model-specific tool call formats to g3's standard format.
///
/// Adapters are stateful to handle streaming - they buffer incomplete patterns
/// and emit complete chunks as soon as they're ready. `Send + Sync` so an
/// adapter can be moved into streaming tasks.
pub trait ToolFormatAdapter: Send + Sync {
    /// Check if this adapter handles the given model type.
    fn handles(&self, model_type: &str) -> bool;

    /// Process a chunk of model output.
    ///
    /// The adapter may buffer content if it's in the middle of a potential
    /// pattern. Returns content that's safe to emit downstream.
    fn process_chunk(&mut self, chunk: &str) -> AdapterOutput;

    /// Flush any remaining buffered content (call at end of stream).
    ///
    /// This should emit any buffered content, even if incomplete.
    fn flush(&mut self) -> AdapterOutput;

    /// Reset the adapter state (call between conversations).
    fn reset(&mut self);
}
|
||||
|
||||
/// Create an adapter for the given model type, if one exists
|
||||
pub fn create_adapter_for_model(model_type: &str) -> Option<Box<dyn ToolFormatAdapter>> {
|
||||
let glm_adapter = GlmToolAdapter::new();
|
||||
if glm_adapter.handles(model_type) {
|
||||
return Some(Box::new(glm_adapter));
|
||||
}
|
||||
|
||||
// Add other adapters here as needed:
|
||||
// let mistral_adapter = MistralToolAdapter::new();
|
||||
// if mistral_adapter.handles(model_type) { ... }
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// GLM model-type strings should get an adapter.
    #[test]
    fn test_create_adapter_for_glm() {
        assert!(create_adapter_for_model("glm4").is_some());
        assert!(create_adapter_for_model("glm").is_some());
        assert!(create_adapter_for_model("some-glm-variant").is_some());
    }

    /// Non-GLM families currently need no format adaptation.
    #[test]
    fn test_no_adapter_for_unknown() {
        assert!(create_adapter_for_model("qwen").is_none());
        assert!(create_adapter_for_model("llama").is_none());
        assert!(create_adapter_for_model("mistral").is_none());
    }
}
|
||||
12
crates/g3-providers/src/embedded/mod.rs
Normal file
12
crates/g3-providers/src/embedded/mod.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
//! Embedded LLM provider using llama.cpp
|
||||
//!
|
||||
//! This module provides local model inference via llama.cpp with Metal acceleration.
|
||||
|
||||
pub mod adapters;
|
||||
mod provider;
|
||||
|
||||
// Re-export adapter types
|
||||
pub use adapters::{create_adapter_for_model, AdapterOutput, ToolFormatAdapter};
|
||||
|
||||
// Re-export the main provider
|
||||
pub use provider::EmbeddedProvider;
|
||||
813
crates/g3-providers/src/embedded/provider.rs
Normal file
813
crates/g3-providers/src/embedded/provider.rs
Normal file
@@ -0,0 +1,813 @@
|
||||
//! Embedded LLM provider using llama.cpp with Metal acceleration on macOS.
|
||||
//!
|
||||
//! Supports multiple model families with their native chat templates:
|
||||
//! - Qwen (ChatML format)
|
||||
//! - GLM-4 (ChatGLM4 format)
|
||||
//! - Mistral (Instruct format)
|
||||
//! - Llama/CodeLlama (Llama2 format)
|
||||
|
||||
use crate::{
|
||||
CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message, MessageRole,
|
||||
Usage,
|
||||
streaming::{make_final_chunk_with_reason, make_text_chunk},
|
||||
};
|
||||
use anyhow::Result;
|
||||
use llama_cpp_2::{
|
||||
context::LlamaContext,
|
||||
context::params::LlamaContextParams,
|
||||
llama_backend::LlamaBackend,
|
||||
llama_batch::LlamaBatch,
|
||||
model::{AddBos, LlamaModel, Special, params::LlamaModelParams},
|
||||
sampling::LlamaSampler,
|
||||
};
|
||||
use std::num::NonZeroU32;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tracing::{debug, error};
|
||||
|
||||
// ============================================================================
|
||||
// Global Backend
|
||||
// ============================================================================
|
||||
|
||||
/// Global llama.cpp backend - can only be initialized once per process.
static LLAMA_BACKEND: OnceLock<Arc<LlamaBackend>> = OnceLock::new();

/// Get or initialize the global llama.cpp backend.
///
/// Repeated calls return a clone of the same `Arc`. llama.cpp's stderr
/// logging is silenced before the first initialization.
///
/// NOTE(review): two threads racing past the initial `get()` will both call
/// `LlamaBackend::init()`; only one result is stored, but the backend is
/// documented above as once-per-process - confirm a second `init()` is
/// harmless, or serialize initialization with a lock.
fn get_or_init_backend() -> Result<Arc<LlamaBackend>> {
    if let Some(backend) = LLAMA_BACKEND.get() {
        return Ok(Arc::clone(backend));
    }

    // Suppress llama.cpp's verbose logging to stderr
    suppress_llama_logging();

    debug!("Initializing llama.cpp backend...");
    let backend = LlamaBackend::init()
        .map_err(|e| anyhow::anyhow!("Failed to initialize llama.cpp backend: {:?}", e))?;

    // Store it (ignore if another thread beat us to it)
    let _ = LLAMA_BACKEND.set(Arc::new(backend));
    Ok(Arc::clone(LLAMA_BACKEND.get().expect("backend was just set")))
}
|
||||
|
||||
/// Silence llama.cpp's verbose logging by installing a no-op log callback.
///
/// llama.cpp writes progress/diagnostic output to stderr by default; this
/// routes all of it into an empty callback instead. Called once before
/// backend initialization (see `get_or_init_backend`).
fn suppress_llama_logging() {
    // SAFETY: `llama_log_set` is declared here to match llama.cpp's C API;
    // it stores the callback pointer for later invocation. `void_log` is a
    // valid `extern "C"` function for the whole program lifetime, and the
    // null `user_data` pointer is never dereferenced by our callback.
    unsafe {
        // No-op sink: signature must match llama.cpp's log callback type.
        unsafe extern "C" fn void_log(
            _level: std::ffi::c_int,
            _text: *const std::os::raw::c_char,
            _user_data: *mut std::os::raw::c_void,
        ) {
            // Intentionally empty
        }
        // Raw FFI declaration — the symbol is provided by the linked
        // llama.cpp library (via the llama_cpp_2 crate's native build).
        extern "C" {
            fn llama_log_set(
                log_callback: Option<
                    unsafe extern "C" fn(
                        std::ffi::c_int,
                        *const std::os::raw::c_char,
                        *mut std::os::raw::c_void,
                    ),
                >,
                user_data: *mut std::os::raw::c_void,
            );
        }
        llama_log_set(Some(void_log), std::ptr::null_mut());
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Provider Struct
|
||||
// ============================================================================
|
||||
|
||||
use super::adapters::create_adapter_for_model;
|
||||
|
||||
/// LLM provider backed by an in-process llama.cpp model (GGUF file).
///
/// Implements `LLMProvider` for local inference; see the `impl` blocks below
/// for construction, chat-template formatting, and generation.
pub struct EmbeddedProvider {
    // Provider instance name (defaults to "embedded"; see `new_with_name`).
    name: String,
    // Loaded model, shared with blocking inference tasks via `Arc`.
    model: Arc<LlamaModel>,
    // Process-wide llama.cpp backend (see `LLAMA_BACKEND`).
    backend: Arc<LlamaBackend>,
    // Lowercased model family string ("glm", "qwen", ...); drives chat
    // template, stop sequences, and tool-format adapter selection.
    model_type: String,
    // Reported model identifier, formatted as "embedded-<model_type>".
    model_name: String,
    // Optional cap on generated tokens; `effective_max_tokens` supplies a
    // default when unset.
    max_tokens: Option<u32>,
    // Default sampling temperature (request values override per call).
    temperature: f32,
    // Context window in tokens (configured or the model's trained size).
    context_length: u32,
    // Optional CPU thread count forwarded to llama.cpp context params.
    threads: Option<u32>,
}
|
||||
|
||||
impl EmbeddedProvider {
    /// Create a new embedded provider with default naming ("embedded").
    ///
    /// Thin wrapper over [`Self::new_with_name`]; see that constructor for
    /// parameter semantics and defaults.
    pub fn new(
        model_path: String,
        model_type: String,
        context_length: Option<u32>,
        max_tokens: Option<u32>,
        temperature: Option<f32>,
        gpu_layers: Option<u32>,
        threads: Option<u32>,
    ) -> Result<Self> {
        Self::new_with_name(
            "embedded".to_string(),
            model_path,
            model_type,
            context_length,
            max_tokens,
            temperature,
            gpu_layers,
            threads,
        )
    }

    /// Create a new embedded provider with a custom name.
    ///
    /// Loads the GGUF model at `model_path` (tilde-expanded) into the shared
    /// llama.cpp backend.
    ///
    /// Defaults applied here:
    /// - `gpu_layers`: 99 when unset (i.e. offload as many layers as possible)
    /// - `context_length`: the model's trained context size when unset
    /// - `temperature`: 0.1 when unset
    ///
    /// # Errors
    /// Fails if the model file does not exist, the backend cannot be
    /// initialized, or llama.cpp fails to load the model.
    pub fn new_with_name(
        name: String,
        model_path: String,
        model_type: String,
        context_length: Option<u32>,
        max_tokens: Option<u32>,
        temperature: Option<f32>,
        gpu_layers: Option<u32>,
        threads: Option<u32>,
    ) -> Result<Self> {
        debug!("Loading embedded model from: {}", model_path);

        // Expand a leading `~` so configs can use home-relative paths.
        let expanded_path = shellexpand::tilde(&model_path);
        let model_path_buf = PathBuf::from(expanded_path.as_ref());

        if !model_path_buf.exists() {
            anyhow::bail!("Model file not found: {}", model_path_buf.display());
        }

        let backend = get_or_init_backend()?;

        // 99 effectively means "offload every layer the hardware allows".
        let n_gpu_layers = gpu_layers.unwrap_or(99);
        let model_params = LlamaModelParams::default().with_n_gpu_layers(n_gpu_layers);
        debug!("Using {} GPU layers", n_gpu_layers);

        debug!("Loading model...");
        let model = LlamaModel::load_from_file(&backend, &model_path_buf, &model_params)
            .map_err(|e| anyhow::anyhow!("Failed to load model: {:?}", e))?;

        // Fall back to the context size the model was trained with.
        let model_ctx_train = model.n_ctx_train();
        let context_size = context_length.unwrap_or(model_ctx_train);
        debug!(
            "Context length: {} (model trained: {}, configured: {:?})",
            context_size, model_ctx_train, context_length
        );

        debug!("Successfully loaded {} model as '{}'", model_type, name);

        Ok(Self {
            name,
            model: Arc::new(model),
            backend,
            // Lowercased so the `contains`-based dispatch below is
            // case-insensitive; `model_name` keeps the caller's casing.
            model_type: model_type.to_lowercase(),
            model_name: format!("embedded-{}", model_type),
            max_tokens,
            temperature: temperature.unwrap_or(0.1),
            context_length: context_size,
            threads,
        })
    }

    /// Effective generation cap: the configured `max_tokens`, or a quarter
    /// of the context window capped at 4096.
    fn effective_max_tokens(&self) -> u32 {
        self.max_tokens
            .unwrap_or_else(|| std::cmp::min(4096, self.context_length / 4))
    }

    /// Estimate token count from text (~4 chars per token)
    ///
    /// Rough heuristic used only for usage reporting, not for budgeting.
    fn estimate_tokens(&self, text: &str) -> u32 {
        (text.len() as f32 / 4.0).ceil() as u32
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Chat Template Formatting
|
||||
// ============================================================================
|
||||
|
||||
impl EmbeddedProvider {
|
||||
/// Format messages according to the model's native chat template.
|
||||
fn format_messages(&self, messages: &[Message]) -> String {
|
||||
match self.model_type.as_str() {
|
||||
t if t.contains("glm") => format_glm4(messages),
|
||||
t if t.contains("qwen") => format_qwen(messages),
|
||||
t if t.contains("mistral") => format_mistral(messages),
|
||||
_ => format_llama(messages),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get stop sequences based on model type.
|
||||
fn get_stop_sequences(&self) -> &'static [&'static str] {
|
||||
get_stop_sequences_for_model(&self.model_type)
|
||||
}
|
||||
}
|
||||
|
||||
/// GLM-4 ChatGLM4 format: [gMASK]<sop><|role|>\ncontent
|
||||
fn format_glm4(messages: &[Message]) -> String {
|
||||
let mut out = String::from("[gMASK]<sop>");
|
||||
for msg in messages {
|
||||
let role = match msg.role {
|
||||
MessageRole::System => "<|system|>",
|
||||
MessageRole::User => "<|user|>",
|
||||
MessageRole::Assistant => "<|assistant|>",
|
||||
};
|
||||
out.push_str(&format!("{}\n{}", role, msg.content));
|
||||
}
|
||||
out.push_str("<|assistant|>\n");
|
||||
out
|
||||
}
|
||||
|
||||
/// Qwen ChatML format: <|im_start|>role\ncontent<|im_end|>
|
||||
fn format_qwen(messages: &[Message]) -> String {
|
||||
let mut out = String::new();
|
||||
for msg in messages {
|
||||
let role = match msg.role {
|
||||
MessageRole::System => "system",
|
||||
MessageRole::User => "user",
|
||||
MessageRole::Assistant => "assistant",
|
||||
};
|
||||
out.push_str(&format!("<|im_start|>{}\n{}<|im_end|>\n", role, msg.content));
|
||||
}
|
||||
out.push_str("<|im_start|>assistant\n");
|
||||
out
|
||||
}
|
||||
|
||||
/// Mistral Instruct format: <s>[INST] ... [/INST] response</s>
|
||||
fn format_mistral(messages: &[Message]) -> String {
|
||||
let mut out = String::new();
|
||||
let mut in_inst = false;
|
||||
|
||||
for (i, msg) in messages.iter().enumerate() {
|
||||
match msg.role {
|
||||
MessageRole::System if i == 0 => {
|
||||
out.push_str("<s>[INST] ");
|
||||
out.push_str(&msg.content);
|
||||
out.push_str("\n\n");
|
||||
in_inst = true;
|
||||
}
|
||||
MessageRole::System => {} // Ignore non-first system messages
|
||||
MessageRole::User => {
|
||||
if !in_inst {
|
||||
out.push_str("<s>[INST] ");
|
||||
}
|
||||
out.push_str(&msg.content);
|
||||
out.push_str(" [/INST]");
|
||||
in_inst = false;
|
||||
}
|
||||
MessageRole::Assistant => {
|
||||
out.push(' ');
|
||||
out.push_str(&msg.content);
|
||||
out.push_str("</s> ");
|
||||
in_inst = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if messages.last().is_some_and(|m| matches!(m.role, MessageRole::User)) {
|
||||
out.push(' ');
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Llama/CodeLlama format: [INST] <<SYS>>\nsystem<</SYS>>\n\nuser [/INST]
|
||||
fn format_llama(messages: &[Message]) -> String {
|
||||
let mut out = String::new();
|
||||
for msg in messages {
|
||||
match msg.role {
|
||||
MessageRole::System => {
|
||||
out.push_str(&format!("[INST] <<SYS>>\n{}\n<</SYS>>\n\n", msg.content));
|
||||
}
|
||||
MessageRole::User => {
|
||||
out.push_str(&format!("{} [/INST] ", msg.content));
|
||||
}
|
||||
MessageRole::Assistant => {
|
||||
out.push_str(&format!("{} </s><s>[INST] ", msg.content));
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Get stop sequences for a model type.
///
/// Matching is substring-based on the (lowercased) model-type string; order
/// matters — e.g. "code-llama"/"codellama" must be tested before "llama".
/// Unknown families get a broad generic fallback set.
fn get_stop_sequences_for_model(model_type: &str) -> &'static [&'static str] {
    let has = |needle: &str| model_type.contains(needle);

    if has("glm") {
        &["<|endoftext|>", "<|user|>", "<|observation|>", "<|system|>"]
    } else if has("qwen") {
        &["<|im_end|>", "<|endoftext|>", "</s>", "<|im_start|>"]
    } else if has("code-llama") || has("codellama") {
        &["</s>", "[/INST]", "<</SYS>>", "[INST]", "<<SYS>>"]
    } else if has("llama") {
        &[
            "</s>",
            "[/INST]",
            "<</SYS>>",
            "### Human:",
            "### Assistant:",
            "[INST]",
        ]
    } else if has("mistral") {
        &["</s>", "[/INST]", "<|im_end|>"]
    } else if has("vicuna") || has("wizard") {
        &[
            "### Human:",
            "### Assistant:",
            "USER:",
            "ASSISTANT:",
            "</s>",
        ]
    } else if has("alpaca") {
        &["### Instruction:", "### Response:", "### Input:", "</s>"]
    } else {
        // Generic fallback covering the common template families.
        &[
            "</s>",
            "<|endoftext|>",
            "<|im_end|>",
            "### Human:",
            "### Assistant:",
            "[/INST]",
            "<</SYS>>",
        ]
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Inference Helpers
|
||||
// ============================================================================
|
||||
|
||||
/// Parameters for inference, extracted from request and provider defaults.
///
/// Built by `EmbeddedProvider::extract_params` and moved into the blocking
/// generation task, so it owns all of its data.
struct InferenceParams {
    // Fully formatted chat-template prompt.
    prompt: String,
    // Hard cap on generated tokens for this request.
    max_tokens: u32,
    // Sampling temperature for this request.
    temperature: f32,
    // Owned copies of the model family's stop sequences.
    stop_sequences: Vec<String>,
}
|
||||
|
||||
/// Prepared inference context with tokenized prompt ready for generation.
///
/// Produced by `prepare_context`; the lifetime ties the context to the model
/// and backend it was created from.
struct PreparedContext<'a> {
    // llama.cpp evaluation context holding the decoded prompt.
    ctx: LlamaContext<'a>,
    // Reusable batch; after prompt decode it is cleared and refilled with
    // one token per generation step.
    batch: LlamaBatch,
    // Sampler chain (temperature + seeded distribution sampling).
    sampler: LlamaSampler,
    // Absolute position of the next token (starts at prompt length).
    token_count: i32,
}
|
||||
|
||||
impl EmbeddedProvider {
|
||||
/// Extract inference parameters from a completion request.
|
||||
fn extract_params(&self, request: &CompletionRequest) -> InferenceParams {
|
||||
InferenceParams {
|
||||
prompt: self.format_messages(&request.messages),
|
||||
max_tokens: request.max_tokens.unwrap_or_else(|| self.effective_max_tokens()),
|
||||
temperature: request.temperature.unwrap_or(self.temperature),
|
||||
stop_sequences: self
|
||||
.get_stop_sequences()
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Prepare the inference context: create context, tokenize prompt, decode initial batch.
///
/// On success the returned `PreparedContext` has the full prompt evaluated,
/// with logits requested only for the final prompt token, so generation can
/// immediately sample the first new token.
///
/// # Errors
/// Fails if context creation, tokenization, batch construction, or the
/// initial decode fails in llama.cpp.
fn prepare_context<'a>(
    model: &'a LlamaModel,
    backend: &'a LlamaBackend,
    prompt: &str,
    temperature: f32,
    context_length: u32,
    threads: Option<u32>,
) -> Result<PreparedContext<'a>> {
    // Guard against a zero context size; fall back to 4096.
    let n_ctx = NonZeroU32::new(context_length).unwrap_or(NonZeroU32::new(4096).unwrap());
    // n_batch is sized to the whole context so the prompt can be decoded in
    // one call. NOTE(review): a prompt longer than `context_length` tokens
    // would exceed this — confirm callers truncate upstream.
    let mut ctx_params = LlamaContextParams::default()
        .with_n_ctx(Some(n_ctx))
        .with_n_batch(context_length);
    if let Some(n_threads) = threads {
        ctx_params = ctx_params.with_n_threads(n_threads as i32);
    }

    let mut ctx = model
        .new_context(backend, ctx_params)
        .map_err(|e| anyhow::anyhow!("Failed to create context: {:?}", e))?;

    // Tokenize with a BOS token prepended, as the chat templates expect.
    let tokens = model
        .str_to_token(prompt, AddBos::Always)
        .map_err(|e| anyhow::anyhow!("Failed to tokenize: {:?}", e))?;

    debug!("Tokenized prompt: {} tokens", tokens.len());

    let batch_size = std::cmp::max(512, tokens.len());
    let mut batch = LlamaBatch::new(batch_size, 1);
    for (i, token) in tokens.iter().enumerate() {
        // Request logits only for the last prompt token (the `true` flag):
        // that is the only position we sample from.
        batch
            .add(*token, i as i32, &[0], i == tokens.len() - 1)
            .map_err(|e| anyhow::anyhow!("Failed to add token to batch: {:?}", e))?;
    }

    ctx.decode(&mut batch)
        .map_err(|e| anyhow::anyhow!("Failed to decode prompt: {:?}", e))?;

    // Temperature scaling followed by seeded distribution sampling; the
    // fixed seed (1234) makes sampling reproducible across runs.
    let sampler = LlamaSampler::chain_simple([
        LlamaSampler::temp(temperature),
        LlamaSampler::dist(1234),
    ]);

    Ok(PreparedContext {
        ctx,
        batch,
        sampler,
        // Next token position = number of prompt tokens already decoded.
        token_count: tokens.len() as i32,
    })
}
|
||||
|
||||
/// Check if text contains any stop sequence. Returns the truncation position if found.
///
/// Scans for the *earliest* occurrence in `text` across all sequences —
/// not merely the first entry of `stop_sequences` that happens to match.
/// The previous list-order behavior could truncate at a later marker and
/// leave an earlier one embedded in the emitted text.
///
/// Returns `None` when no stop sequence occurs in `text`.
fn find_stop_sequence(text: &str, stop_sequences: &[String]) -> Option<usize> {
    stop_sequences
        .iter()
        .filter_map(|stop_seq| text.find(stop_seq.as_str()))
        .min()
}
|
||||
|
||||
/// Truncate text at the first stop sequence, if any.
|
||||
fn truncate_at_stop_sequence(text: &mut String, stop_sequences: &[String]) {
|
||||
if let Some(pos) = find_stop_sequence(text, stop_sequences) {
|
||||
text.truncate(pos);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// LLMProvider Implementation
|
||||
// ============================================================================
|
||||
|
||||
#[async_trait::async_trait]
impl LLMProvider for EmbeddedProvider {
    /// Run a full (non-streaming) completion.
    ///
    /// Formats the prompt, runs the token-generation loop on a blocking
    /// thread via `spawn_blocking` (llama.cpp calls are synchronous), and
    /// returns the trimmed text with estimated usage numbers.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        debug!(
            "Processing completion request with {} messages",
            request.messages.len()
        );

        let params = self.extract_params(&request);
        // Estimated, not exact: see `estimate_tokens` (~4 chars/token).
        let prompt_tokens = self.estimate_tokens(&params.prompt);

        debug!("Formatted prompt length: {} chars", params.prompt.len());

        // Clone what we need for the blocking task
        let model = self.model.clone();
        let backend = self.backend.clone();
        let context_length = self.context_length;
        let threads = self.threads;
        let model_name = self.model_name.clone();

        let (content, completion_tokens) = tokio::task::spawn_blocking(move || {
            let mut prepared = prepare_context(
                &model,
                &backend,
                &params.prompt,
                params.temperature,
                context_length,
                threads,
            )?;

            let mut generated_text = String::new();
            let mut token_count = 0u32;

            // One token per iteration: sample from the last decoded logits,
            // check termination conditions, then decode the new token.
            for _ in 0..params.max_tokens {
                let new_token = prepared.sampler.sample(&prepared.ctx, prepared.batch.n_tokens() - 1);
                prepared.sampler.accept(new_token);

                // End-of-generation token: stop before emitting anything.
                if model.is_eog_token(new_token) {
                    debug!("Hit end-of-generation token at {} tokens", token_count);
                    break;
                }

                let token_str = model
                    .token_to_str(new_token, Special::Tokenize)
                    .unwrap_or_default();
                generated_text.push_str(&token_str);
                token_count += 1;

                // Stop-sequence text is removed by the truncation below.
                if find_stop_sequence(&generated_text, &params.stop_sequences).is_some() {
                    debug!("Hit stop sequence at {} tokens", token_count);
                    break;
                }

                // Prepare next batch
                prepared.batch.clear();
                prepared
                    .batch
                    .add(new_token, prepared.token_count, &[0], true)
                    .map_err(|e| anyhow::anyhow!("Failed to add token to batch: {:?}", e))?;
                prepared.token_count += 1;

                prepared
                    .ctx
                    .decode(&mut prepared.batch)
                    .map_err(|e| anyhow::anyhow!("Failed to decode: {:?}", e))?;
            }

            // Remove any stop-sequence text that made it into the buffer.
            truncate_at_stop_sequence(&mut generated_text, &params.stop_sequences);

            Ok::<_, anyhow::Error>((generated_text.trim().to_string(), token_count))
        })
        .await
        // First `?`: task panicked/cancelled; second `?`: inference error.
        .map_err(|e| anyhow::anyhow!("Task join error: {}", e))??;

        Ok(CompletionResponse {
            content,
            usage: Usage {
                prompt_tokens,
                completion_tokens,
                total_tokens: prompt_tokens + completion_tokens,
                cache_creation_tokens: 0,
                cache_read_tokens: 0,
            },
            model: model_name,
        })
    }

    /// Run a streaming completion.
    ///
    /// Tokens are generated on a blocking thread and forwarded through an
    /// mpsc channel as chunks; model-specific tool-format adapters (e.g.
    /// GLM code-fence stripping) may rewrite or buffer text before emission.
    async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
        debug!(
            "Processing streaming request with {} messages",
            request.messages.len()
        );

        let params = self.extract_params(&request);
        let prompt_tokens = self.estimate_tokens(&params.prompt);

        let (tx, rx) = mpsc::channel(100);

        let model = self.model.clone();
        let backend = self.backend.clone();
        let context_length = self.context_length;
        let threads = self.threads;
        let model_type = self.model_type.clone();

        tokio::task::spawn_blocking(move || {
            // Create adapter for model-specific tool format transformation (e.g., GLM)
            let mut adapter = create_adapter_for_model(&model_type);

            let mut prepared = match prepare_context(
                &model,
                &backend,
                &params.prompt,
                params.temperature,
                context_length,
                threads,
            ) {
                Ok(p) => p,
                Err(e) => {
                    // Surface setup failure to the consumer, then stop.
                    let _ = tx.blocking_send(Err(e));
                    return;
                }
            };

            let mut accumulated_text = String::new();
            let mut token_count = 0u32;
            let mut stop_reason: Option<String> = None;

            for _ in 0..params.max_tokens {
                let new_token = prepared.sampler.sample(&prepared.ctx, prepared.batch.n_tokens() - 1);
                prepared.sampler.accept(new_token);

                if model.is_eog_token(new_token) {
                    debug!("Hit end-of-generation token at {} tokens", token_count);
                    stop_reason = Some("end_turn".to_string());
                    break;
                }

                let token_str = model
                    .token_to_str(new_token, Special::Tokenize)
                    .unwrap_or_default();

                // Full text is accumulated only for stop-sequence detection;
                // emission goes through the adapter below.
                accumulated_text.push_str(&token_str);
                token_count += 1;

                if find_stop_sequence(&accumulated_text, &params.stop_sequences).is_some() {
                    debug!("Hit stop sequence at {} tokens", token_count);
                    stop_reason = Some("stop_sequence".to_string());
                    break;
                }

                // Stream the token (through adapter if present)
                let output_text = if let Some(ref mut adapt) = adapter {
                    // Adapter may buffer (e.g. a potential code-fence line)
                    // and emit nothing for this chunk.
                    let output = adapt.process_chunk(&token_str);
                    output.emit
                } else {
                    token_str
                };
                if !output_text.is_empty() {
                    if tx.blocking_send(Ok(make_text_chunk(output_text))).is_err() {
                        return; // Receiver dropped
                    }
                }

                if token_count >= params.max_tokens {
                    debug!("Reached max token limit: {}", params.max_tokens);
                    stop_reason = Some("max_tokens".to_string());
                    break;
                }

                // Prepare next batch
                prepared.batch.clear();
                if let Err(e) = prepared.batch.add(new_token, prepared.token_count, &[0], true) {
                    error!("Failed to add token to batch: {:?}", e);
                    break;
                }
                prepared.token_count += 1;

                if let Err(e) = prepared.ctx.decode(&mut prepared.batch) {
                    error!("Failed to decode: {:?}", e);
                    break;
                }
            }

            // Flush any remaining content from the adapter
            if let Some(ref mut adapt) = adapter {
                let final_output = adapt.flush();
                if !final_output.emit.is_empty() {
                    if tx.blocking_send(Ok(make_text_chunk(final_output.emit))).is_err() {
                        return;
                    }
                }
            }

            let usage = Usage {
                prompt_tokens,
                completion_tokens: token_count,
                total_tokens: prompt_tokens + token_count,
                cache_creation_tokens: 0,
                cache_read_tokens: 0,
            };
            // Default to "end_turn" if the loop exited without setting a
            // reason (e.g. a batch/decode error broke out early).
            let final_chunk =
                make_final_chunk_with_reason(vec![], Some(usage), stop_reason.or(Some("end_turn".to_string())));
            let _ = tx.blocking_send(Ok(final_chunk));
        });

        Ok(ReceiverStream::new(rx))
    }

    /// Provider instance name (as configured; defaults to "embedded").
    fn name(&self) -> &str {
        &self.name
    }

    /// Reported model identifier ("embedded-<model_type>").
    fn model(&self) -> &str {
        &self.model_name
    }

    /// Effective generation cap (configured or derived from context size).
    fn max_tokens(&self) -> u32 {
        self.effective_max_tokens()
    }

    /// Default sampling temperature.
    fn temperature(&self) -> f32 {
        self.temperature
    }

    /// Context window size in tokens.
    fn context_window_size(&self) -> Option<u32> {
        Some(self.context_length)
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Unit tests for chat-template formatting and stop-sequence handling.
    use super::*;

    #[test]
    fn test_format_glm4_messages() {
        let messages = vec![
            Message::new(MessageRole::System, "You are a helpful assistant.".to_string()),
            Message::new(MessageRole::User, "Hello!".to_string()),
        ];

        let formatted = format_glm4(&messages);

        // GLM-4 prompts must open with the [gMASK]<sop> prefix and leave an
        // open assistant turn at the end.
        assert!(formatted.starts_with("[gMASK]<sop>"));
        assert!(formatted.contains("<|system|>\nYou are a helpful assistant."))
;
        assert!(formatted.contains("<|user|>\nHello!"));
        assert!(formatted.ends_with("<|assistant|>\n"));
    }

    #[test]
    fn test_format_qwen_messages() {
        let messages = vec![
            Message::new(MessageRole::System, "You are a helpful assistant.".to_string()),
            Message::new(MessageRole::User, "Hello!".to_string()),
        ];

        let formatted = format_qwen(&messages);

        // ChatML wraps each turn in im_start/im_end and ends with an open
        // assistant turn.
        assert!(formatted.contains("<|im_start|>system\nYou are a helpful assistant.<|im_end|>"));
        assert!(formatted.contains("<|im_start|>user\nHello!<|im_end|>"));
        assert!(formatted.ends_with("<|im_start|>assistant\n"));
    }

    #[test]
    fn test_format_mistral_messages() {
        let messages = vec![
            Message::new(MessageRole::System, "You are a helpful assistant.".to_string()),
            Message::new(MessageRole::User, "Hello!".to_string()),
        ];

        let formatted = format_mistral(&messages);

        assert!(formatted.starts_with("<s>[INST] "));
        assert!(formatted.contains("You are a helpful assistant."));
        assert!(formatted.contains("Hello!"));
        assert!(formatted.contains("[/INST]"));
    }

    #[test]
    fn test_format_llama_messages() {
        let messages = vec![
            Message::new(MessageRole::System, "You are a helpful assistant.".to_string()),
            Message::new(MessageRole::User, "Hello!".to_string()),
        ];

        let formatted = format_llama(&messages);

        assert!(formatted.contains("<<SYS>>"));
        assert!(formatted.contains("You are a helpful assistant."));
        assert!(formatted.contains("<</SYS>>"));
        assert!(formatted.contains("Hello!"));
        assert!(formatted.contains("[/INST]"));
    }

    #[test]
    fn test_glm4_stop_sequences() {
        let stop_seqs = get_stop_sequences_for_model("glm4");

        assert!(stop_seqs.contains(&"<|endoftext|>"));
        assert!(stop_seqs.contains(&"<|user|>"));
        assert!(stop_seqs.contains(&"<|observation|>"));
        assert!(stop_seqs.contains(&"<|system|>"));
    }

    #[test]
    fn test_qwen_stop_sequences() {
        let stop_seqs = get_stop_sequences_for_model("qwen");

        assert!(stop_seqs.contains(&"<|im_end|>"));
        assert!(stop_seqs.contains(&"<|endoftext|>"));
        assert!(stop_seqs.contains(&"<|im_start|>"));
    }

    #[test]
    fn test_glm4_multi_turn_conversation() {
        let messages = vec![
            Message::new(MessageRole::System, "You are a coding assistant.".to_string()),
            Message::new(
                MessageRole::User,
                "Write a hello world in Python.".to_string(),
            ),
            Message::new(
                MessageRole::Assistant,
                "print('Hello, World!')".to_string(),
            ),
            Message::new(MessageRole::User, "Now in Rust.".to_string()),
        ];

        let formatted = format_glm4(&messages);

        // Verify all parts are present in order
        let system_pos = formatted.find("<|system|>").unwrap();
        let user1_pos = formatted.find("<|user|>\nWrite a hello world").unwrap();
        let assistant_pos = formatted.find("<|assistant|>\nprint").unwrap();
        let user2_pos = formatted.find("<|user|>\nNow in Rust").unwrap();
        let final_assistant_pos = formatted.rfind("<|assistant|>\n").unwrap();

        assert!(system_pos < user1_pos);
        assert!(user1_pos < assistant_pos);
        assert!(assistant_pos < user2_pos);
        assert!(user2_pos < final_assistant_pos);
    }

    #[test]
    fn test_find_stop_sequence() {
        let stop_seqs = vec!["</s>".to_string(), "<|im_end|>".to_string()];

        assert_eq!(find_stop_sequence("hello world", &stop_seqs), None);
        assert_eq!(find_stop_sequence("hello</s>world", &stop_seqs), Some(5));
        assert_eq!(
            find_stop_sequence("hello<|im_end|>world", &stop_seqs),
            Some(5)
        );
    }

    #[test]
    fn test_truncate_at_stop_sequence() {
        let stop_seqs = vec!["</s>".to_string()];

        let mut text = "hello</s>world".to_string();
        truncate_at_stop_sequence(&mut text, &stop_seqs);
        assert_eq!(text, "hello");

        // No stop sequence present: text must be left unchanged.
        let mut text2 = "no stop here".to_string();
        truncate_at_stop_sequence(&mut text2, &stop_seqs);
        assert_eq!(text2, "no stop here");
    }
}
|
||||
Reference in New Issue
Block a user