diff --git a/Cargo.lock b/Cargo.lock index c628cfa..9c5220b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -229,16 +229,14 @@ dependencies = [ [[package]] name = "bindgen" -version = "0.69.5" +version = "0.72.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895" dependencies = [ "bitflags 2.10.0", "cexpr", "clang-sys", - "itertools 0.12.1", - "lazy_static", - "lazycell", + "itertools", "log", "prettyplease", "proc-macro2", @@ -247,7 +245,6 @@ dependencies = [ "rustc-hash", "shlex", "syn", - "which", ] [[package]] @@ -438,6 +435,15 @@ dependencies = [ "error-code", ] +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] + [[package]] name = "cocoa" version = "0.25.0" @@ -605,12 +611,6 @@ dependencies = [ "unicode-xid", ] -[[package]] -name = "convert_case" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" - [[package]] name = "convert_case" version = "0.6.0" @@ -857,7 +857,7 @@ checksum = "d8b9f2e4c67f833b660cdb0a3523065869fb35570177239812ed4c905aeff87b" dependencies = [ "bitflags 2.10.0", "crossterm_winapi", - "derive_more 2.0.1", + "derive_more", "document-features", "mio", "parking_lot", @@ -936,19 +936,6 @@ dependencies = [ "powerfmt", ] -[[package]] -name = "derive_more" -version = "0.99.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6edb4b64a43d977b8e99788fe3a04d483834fba1215a7e02caa415b626497f7f" -dependencies = [ - "convert_case 0.4.0", - "proc-macro2", - "quote", - "rustc_version", - "syn", -] - [[package]] name = "derive_more" version = "2.0.1" @@ -1078,6 +1065,26 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" +[[package]] +name = "enumflags2" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1027f7680c853e056ebcec683615fb6fbbc07dbaa13b4d5d9442b146ded4ecef" +dependencies = [ + "enumflags2_derive", +] + +[[package]] +name = "enumflags2_derive" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67c78a4d8fdf9953a5c9d458f9efe940fd97a0cab0941c075a813ac594733827" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -1171,6 +1178,15 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "52051878f80a721bb68ebfbc930e07b65ba72f2da88968ea5c06fd6ca3d3a127" +[[package]] +name = "find_cuda_helper" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9f9e65c593dd01ac77daad909ea4ad17f0d6d1776193fc8ea766356177abdad" +dependencies = [ + "glob", +] + [[package]] name = "flate2" version = "1.1.5" @@ -1504,7 +1520,7 @@ dependencies = [ "chrono", "dirs 5.0.1", "futures-util", - "llama_cpp", + "llama-cpp-2", "nanoid", "rand", "reqwest", @@ -1643,12 +1659,6 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" -[[package]] -name = "hermit-abi" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" - [[package]] name = "hex" version = "0.4.3" @@ -2040,15 +2050,6 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" -[[package]] -name = "itertools" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.13.0" @@ -2155,12 +2156,6 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" -[[package]] -name = "lazycell" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" - [[package]] name = "lebe" version = "0.5.3" @@ -2193,15 +2188,6 @@ dependencies = [ "libc", ] -[[package]] -name = "link-cplusplus" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f78c730aaa7d0b9336a299029ea49f9ee53b0ed06e9202e8cb7db9bae7b8c82" -dependencies = [ - "cc", -] - [[package]] name = "linked-hash-map" version = "0.5.6" @@ -2233,30 +2219,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092" [[package]] -name = "llama_cpp" -version = "0.3.2" +name = "llama-cpp-2" +version = "0.1.125" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f126770a2ed5e0e4596119479dc56f56b99037246bf0e36c544f7581a9458fd" +checksum = "14cc99d19a12f372957e1ad1cb33c5459e6080c7914389e52f2464d8fb043175" dependencies = [ - "derive_more 0.99.20", - "futures", - "llama_cpp_sys", - "num_cpus", + "enumflags2", + "llama-cpp-sys-2", "thiserror 1.0.69", - "tokio", "tracing", + "tracing-core", ] [[package]] -name = "llama_cpp_sys" -version = "0.3.2" +name = "llama-cpp-sys-2" +version = "0.1.125" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "037a1881ada3592c6a922224d5177b4b4f452e6b2979eb97393b71989e48357f" +checksum = "cc9443103277a9808b0e7055966a39fd2de14c7877fecdec4daf7b8770c46ec3" dependencies = [ "bindgen", "cc", - "link-cplusplus", - "once_cell", + "cmake", + "find_cuda_helper", + "glob", + "walkdir", ] [[package]] @@ -2443,16 +2429,6 @@ dependencies = [ "autocfg", ] -[[package]] -name = "num_cpus" -version = "1.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" -dependencies = [ - "hermit-abi", - "libc", -] - [[package]] name = "number_prefix" version = "0.4.0" @@ -2860,7 +2836,7 @@ dependencies = [ "crossterm 0.28.1", "indoc", "instability", - "itertools 0.13.0", + "itertools", "lru", "paste", "strum", @@ -3015,18 +2991,9 @@ dependencies = [ [[package]] name = "rustc-hash" -version = "1.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" - -[[package]] -name = "rustc_version" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" -dependencies = [ - "semver", -] +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" [[package]] name = "rustix" @@ -3171,12 +3138,6 @@ dependencies = [ "libc", ] -[[package]] -name = "semver" -version = "1.0.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" - [[package]] name = "serde" version = "1.0.228" @@ -4048,7 +4009,7 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b3644627a5af5fa321c95b9b235a72fd24cd29c648c2c379431e6628655627bf" dependencies = [ - "itertools 0.13.0", + "itertools", "unicode-segmentation", "unicode-width 0.1.14", ] @@ -4311,18 +4272,6 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a751b3277700db47d3e574514de2eced5e54dc8a5436a3bf7a0b248b2cee16f3" -[[package]] -name = "which" -version = "4.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" -dependencies = [ - "either", - "home", - "once_cell", - "rustix 0.38.44", -] - [[package]] name = "winapi" version = "0.3.9" diff --git a/analysis/memory.md b/analysis/memory.md index b24b559..6463a36 100644 --- a/analysis/memory.md +++ b/analysis/memory.md @@ -1,5 +1,5 @@ # Workspace Memory -> Updated: 2026-01-27T00:12:18Z | Size: 19.5k chars +> Updated: 2026-01-27T02:55:20Z | Size: 21.9k chars ### Remember Tool Wiring - `crates/g3-core/src/tools/memory.rs` [0..5000] - `execute_remember()`, `get_memory_path()`, `merge_memory()` @@ -346,4 +346,58 @@ Tracks prompt/prefix caching efficacy across Anthropic and OpenAI providers. - `crates/g3-core/src/stats.rs` - `AgentStatsSnapshot.cache_stats` [20] - reference to cache stats for formatting - - `format_cache_stats()` [189..230] - formats cache statistics section with hit rate and efficiency metrics \ No newline at end of file + - `format_cache_stats()` [189..230] - formats cache statistics section with hit rate and efficiency metrics + +### Embedded Provider (Local LLM via llama.cpp) +Local model inference using llama-cpp-rs bindings with Metal acceleration on macOS. + +- `crates/g3-providers/src/embedded.rs` + - `EmbeddedProvider` [22..85] - struct with session, model_name, max_tokens, temperature, context_length + - `new()` [26..85] - constructor, handles tilde expansion, auto-downloads Qwen if missing + - `format_messages()` [87..175] - converts Message[] to prompt string, supports Qwen/Mistral/Llama templates + - `get_stop_sequences()` [280..340] - returns model-specific stop tokens + - `generate_completion()` [177..278] - non-streaming inference with timeout + - `stream()` [560..780] - streaming inference via spawn_blocking + mpsc channel + +**Known Issues (as of 2025-01):** +- Provider name hardcoded as `"embedded"` instead of `"embedded.{name}"` format +- Missing GLM-4 chat template (uses `[gMASK]<|role|>` NOT ChatML) +- Missing `has_native_tool_calling()` override (defaults to false) +- No streaming usage tracking (unlike Anthropic) +- No tool streaming hints (`make_tool_streaming_hint()` not used) + +### Chat Template Formats (Embedded Provider) +| Model | Format | Start Token | End Token | +|-------|--------|-------------|----------| +| Qwen | ChatML | `<\|im_start\|>role\n` | `<\|im_end\|>` | +| GLM-4 | ChatGLM4 | `[gMASK]<\|role\|>\n` | `<\|endoftext\|>` | +| Mistral | Instruct | `[INST]` | `[/INST]` | +| Llama | Llama2 | `<>` | `<>` | + +### GLM-4 Model Downloads +Recommended GGUF models for Mac M4 Max with 128GB unified memory. + +**Download commands:** +```bash +# GLM-4 9B Q8_0 (~10GB) - Very capable, fast +python3 -m huggingface_hub.commands.huggingface_cli download bartowski/THUDM_GLM-4-9B-0414-GGUF \ + --include "THUDM_GLM-4-9B-0414-Q8_0.gguf" --local-dir ~/.g3/models/ + +# GLM-4 32B Q6_K_L (~27GB) - TOP TIER for coding/reasoning +python3 -m huggingface_hub.commands.huggingface_cli download bartowski/THUDM_GLM-4-32B-0414-GGUF \ + --include "THUDM_GLM-4-32B-0414-Q6_K_L.gguf" --local-dir ~/.g3/models/ + +# Qwen3 4B Q4_K_M (~2.3GB) - Small but rivals 72B performance +python3 -m huggingface_hub.commands.huggingface_cli download Qwen/Qwen3-4B-GGUF \ + --include "qwen3-4b-q4_k_m.gguf" --local-dir ~/.g3/models/ +``` + +**Config example:** +```toml +[providers.embedded.glm4] +model_path = "~/.g3/models/THUDM_GLM-4-32B-0414-Q6_K_L.gguf" +model_type = "glm4" +context_length = 32768 +max_tokens = 4096 +gpu_layers = 99 +``` \ No newline at end of file diff --git a/config.example.toml b/config.example.toml index 49e3551..61aa000 100644 --- a/config.example.toml +++ b/config.example.toml @@ -52,6 +52,35 @@ model = "claude-sonnet-4-5" # model = "anthropic/claude-3.5-sonnet" # base_url = "https://openrouter.ai/api/v1" +# ============================================================================= +# Embedded providers (local models via llama.cpp with Metal acceleration) +# ============================================================================= +# Download models from Hugging Face: +# huggingface-cli download bartowski/THUDM_GLM-4-32B-0414-GGUF \ +# --include "THUDM_GLM-4-32B-0414-Q6_K_L.gguf" --local-dir ~/.g3/models/ +# +# GLM-4 32B - Top-tier local model for coding/reasoning (context_length auto-detected from GGUF) +# [providers.embedded.glm4] +# model_path = "~/.g3/models/THUDM_GLM-4-32B-0414-Q6_K_L.gguf" +# model_type = "glm4" # Required: glm4, qwen, mistral, llama, codellama +# context_length = 32768 # Optional: auto-detected from GGUF (GLM-4 = 32K) +# max_tokens = 4096 # Optional: defaults to min(4096, context/4) +# temperature = 0.1 +# gpu_layers = 99 # Use all GPU layers on Apple Silicon +# threads = 8 + +# GLM-4 9B - Smaller but very capable (minimal config - most settings auto-detected) +# [providers.embedded.glm4-9b] +# model_path = "~/.g3/models/THUDM_GLM-4-9B-0414-Q8_0.gguf" +# model_type = "glm4" +# gpu_layers = 99 # Optional but recommended for Apple Silicon + +# Qwen3 4B - Small but powerful, good for ensemble usage (minimal config) +# [providers.embedded.qwen3] +# model_path = "~/.g3/models/qwen3-4b-q4_k_m.gguf" +# model_type = "qwen" +# gpu_layers = 99 # Optional but recommended for Apple Silicon + # ============================================================================= # Agent settings (all optional - these are the defaults) # ============================================================================= diff --git a/crates/g3-core/src/lib.rs b/crates/g3-core/src/lib.rs index 1a77fa4..01bfa60 100644 --- a/crates/g3-core/src/lib.rs +++ b/crates/g3-core/src/lib.rs @@ -54,7 +54,7 @@ mod prompts; use anyhow::Result; use g3_config::Config; use g3_providers::{CacheControl, CompletionRequest, Message, MessageRole, ProviderRegistry}; -use prompts::{get_system_prompt_for_native, SYSTEM_PROMPT_FOR_NON_NATIVE_TOOL_USE}; +use prompts::{get_system_prompt_for_native, get_system_prompt_for_non_native}; use serde::{Deserialize, Serialize}; use std::time::{Duration, Instant}; use tokio_util::sync::CancellationToken; @@ -354,7 +354,7 @@ impl Agent { get_system_prompt_for_native() } else { // For non-native providers (embedded models), use JSON format instructions - SYSTEM_PROMPT_FOR_NON_NATIVE_TOOL_USE.to_string() + get_system_prompt_for_non_native() } }; @@ -426,12 +426,13 @@ impl Agent { } // Check for system prompt markers that are present in both standard and agent mode - // Agent mode replaces the identity line but keeps all other instructions + // Check for system prompt markers that are present in both native and non-native prompts + // Both prompts contain "You have access to tools" as a common marker let has_tool_instructions = first_message .content - .contains("IMPORTANT: You must call tools to achieve goals"); + .contains("You have access to tools"); if !has_tool_instructions { - panic!("FATAL: First system message does not contain the system prompt. This likely means the README was added before the system prompt."); + panic!("FATAL: First system message does not contain the system prompt marker 'You have access to tools'. This likely means the README was added before the system prompt."); } } diff --git a/crates/g3-core/src/prompts.rs b/crates/g3-core/src/prompts.rs index 15b2da6..48075c1 100644 --- a/crates/g3-core/src/prompts.rs +++ b/crates/g3-core/src/prompts.rs @@ -1,5 +1,10 @@ -const SYSTEM_NATIVE_TOOL_CALLS: &'static str = -"You are G3, an AI programming agent of the same skill level as a seasoned engineer at a major technology company. You analyze given tasks and write code to achieve goals. +// ============================================================================ +// SHARED PROMPT SECTIONS +// These are used by both native and non-native tool calling prompts +// ============================================================================ + +const SHARED_INTRO: &str = "\ +You are G3, an AI programming agent of the same skill level as a seasoned engineer at a major technology company. You analyze given tasks and write code to achieve goals. You have access to tools. When you need to accomplish a task, you MUST use the appropriate tool. Do not just describe what you would do - actually use the tools. @@ -11,8 +16,9 @@ IMPORTANT: You must call tools to achieve goals. When you receive a request: 5. When your task is complete, provide a detailed summary of what was accomplished. For shell commands: Use the shell tool with the exact command needed. Always use `rg` (ripgrep) instead of `grep` - it's faster, has better defaults, and respects .gitignore. Avoid commands that produce a large amount of output, and consider piping those outputs to files. Example: If asked to list files, immediately call the shell tool with command parameter \"ls\". -If you create temporary files for verification, place these in a subdir named 'tmp'. Do NOT pollute the current dir. +If you create temporary files for verification, place these in a subdir named 'tmp'. Do NOT pollute the current dir."; +const SHARED_TODO_SECTION: &str = "\ # Task Management with TODO Tools **REQUIRED for multi-step tasks.** Use TODO tools when your task involves ANY of: @@ -75,12 +81,14 @@ Keep items short, specific, and action-oriented. ✓ Helps recover from interruptions ✓ Creates better summaries -If you can complete it with 1-2 tool calls, skip TODO. +If you can complete it with 1-2 tool calls, skip TODO."; +const SHARED_TEMPORARY_FILES: &str = "\ # Temporary files -If you create temporary files for verification or investigation, place these in a subdir named 'tmp'. Do NOT pollute the current dir. +If you create temporary files for verification or investigation, place these in a subdir named 'tmp'. Do NOT pollute the current dir."; +const SHARED_WEB_RESEARCH: &str = "\ # Web Research When you need to look up documentation, search for resources, find data online, or research a topic to complete your task, use the `research` tool. @@ -95,13 +103,14 @@ Simply call `research` with a specific query describing what you need to know. T IMPORTANT: If the user asks you to just respond with text (like \"just say hello\" or \"tell me about X\"), do NOT use tools. Simply respond with the requested text directly. Only use tools when you need to execute commands or complete tasks that require action. -Do not explain what you're going to do - just do it by calling the tools. +Do not explain what you're going to do - just do it by calling the tools."; +const SHARED_WORKSPACE_MEMORY: &str = "\ # Workspace Memory Workspace memory is automatically loaded at startup alongside README.md and AGENTS.md. It contains an index of features -> code locations, patterns, and entry points. If you need to re-read memory from disk (e.g., after another agent updates it), use `read_file analysis/memory.md`. -**IMPORTANT**: After completing a task where you discovered code locations, you **MUST** call the `remember` tool to save them.. +**IMPORTANT**: After completing a task where you discovered code locations, you **MUST** call the `remember` tool to save them. ## Memory Format @@ -143,33 +152,27 @@ After discovering how session continuation works: After discovering a useful pattern: -{\"tool\": \"remember\", \"args\": {\"notes\": \"### UTF-8 Safe String Slicing\\nRust string slices use byte indices. Multi-byte chars (emoji, CJK) cause panics if sliced mid-character.\\n\\n1. Use `s.char_indices().nth(n)` to get byte index of Nth character\\n2. Use `s.chars().count()` for length, not `s.len()`\\n3. Danger zones: display truncation, user input, any non-ASCII text\"}} +{\"tool\": \"remember\", \"args\": {\"notes\": \"### UTF-8 Safe String Slicing\\nRust string slices use byte indices. Multi-byte chars (emoji, CJK) cause panics if sliced mid-character.\\n\\n1. Use `s.char_indices().nth(n)` to get byte index of Nth character\\n2. Use `s.chars().count()` for length, not `s.len()`\\n3. Danger zones: display truncation, user input, any non-ASCII text\"}}"; +const SHARED_RESPONSE_GUIDELINES: &str = "\ # Response Guidelines - Use Markdown formatting for all responses except tool calls. - Whenever taking actions, use the pronoun 'I' - When you discover features, patterns and code locations, call `remember` to save them. -- When showing example tool call JSON in prose or code blocks, use the fullwidth left curly bracket `{` (U+FF5B) instead of `{` to prevent parser confusion. -"; +- When showing example tool call JSON in prose or code blocks, use the fullwidth left curly bracket `{` (U+FF5B) instead of `{` to prevent parser confusion."; -pub const SYSTEM_PROMPT_FOR_NATIVE_TOOL_USE: &'static str = SYSTEM_NATIVE_TOOL_CALLS; - -/// Generate system prompt based on whether multiple tool calls are allowed -pub fn get_system_prompt_for_native() -> String { - SYSTEM_PROMPT_FOR_NATIVE_TOOL_USE.to_string() -} - -const SYSTEM_NON_NATIVE_TOOL_USE: &'static str = -"You are G3, a general-purpose AI agent. Your goal is to analyze and solve problems by writing code. - -You have access to tools. When you need to accomplish a task, you MUST use the appropriate tool. Do not just describe what you would do - actually use the tools. +// ============================================================================ +// NON-NATIVE SPECIFIC SECTIONS +// These are only used by providers without native tool calling +// ============================================================================ +const NON_NATIVE_TOOL_FORMAT: &str = "\ # Tool Call Format When you need to execute a tool, write ONLY the JSON tool call on a new line: -{\"tool\": \"tool_name\", \"args\": {\"param\": \"value\"} +{\"tool\": \"tool_name\", \"args\": {\"param\": \"value\"}} The tool will execute immediately and you'll receive the result (success or error) to continue with. @@ -178,8 +181,8 @@ The tool will execute immediately and you'll receive the result (success or erro Short description for providers without native calling specs: - **shell**: Execute shell commands - - Format: {\"tool\": \"shell\", \"args\": {\"command\": \"your_command_here\"} - - Example: {\"tool\": \"shell\", \"args\": {\"command\": \"ls ~/Downloads\"} + - Format: {\"tool\": \"shell\", \"args\": {\"command\": \"your_command_here\"}} + - Example: {\"tool\": \"shell\", \"args\": {\"command\": \"ls ~/Downloads\"}} - Always use `rg` (ripgrep) instead of `grep` - it's faster and respects .gitignore - **background_process**: Launch a long-running process in the background (e.g., game servers, dev servers) @@ -189,21 +192,21 @@ Short description for providers without native calling specs: - Note: Process runs independently; logs are captured to a file for later inspection - **read_file**: Read the contents of a file (supports partial reads via start/end) - - Format: {\"tool\": \"read_file\", \"args\": {\"file_path\": \"path/to/file\", \"start\": 0, \"end\": 100} - - Example: {\"tool\": \"read_file\", \"args\": {\"file_path\": \"src/main.rs\"} - - Example (partial): {\"tool\": \"read_file\", \"args\": {\"file_path\": \"large.log\", \"start\": 0, \"end\": 1000} + - Format: {\"tool\": \"read_file\", \"args\": {\"file_path\": \"path/to/file\", \"start\": 0, \"end\": 100}} + - Example: {\"tool\": \"read_file\", \"args\": {\"file_path\": \"src/main.rs\"}} + - Example (partial): {\"tool\": \"read_file\", \"args\": {\"file_path\": \"large.log\", \"start\": 0, \"end\": 1000}} - **read_image**: Read an image file for visual analysis (PNG, JPEG, GIF, WebP) - Format: {\"tool\": \"read_image\", \"args\": {\"file_paths\": [\"path/to/image.png\"]}} - Example: {\"tool\": \"read_image\", \"args\": {\"file_paths\": [\"sprites/fairy.png\"]}} - **write_file**: Write content to a file (creates or overwrites) - - Format: {\"tool\": \"write_file\", \"args\": {\"file_path\": \"path/to/file\", \"content\": \"file content\"} - - Example: {\"tool\": \"write_file\", \"args\": {\"file_path\": \"src/lib.rs\", \"content\": \"pub fn hello() {}\"} + - Format: {\"tool\": \"write_file\", \"args\": {\"file_path\": \"path/to/file\", \"content\": \"file content\"}} + - Example: {\"tool\": \"write_file\", \"args\": {\"file_path\": \"src/lib.rs\", \"content\": \"pub fn hello() {}\"}} - **str_replace**: Replace text in a file using a diff - - Format: {\"tool\": \"str_replace\", \"args\": {\"file_path\": \"path/to/file\", \"diff\": \"--- old\\n-old text\\n+++ new\\n+new text\"} - - Example: {\"tool\": \"str_replace\", \"args\": {\"file_path\": \"src/main.rs\", \"diff\": \"--- old\\n-old_code();\\n+++ new\\n+new_code();\"} + - Format: {\"tool\": \"str_replace\", \"args\": {\"file_path\": \"path/to/file\", \"diff\": \"--- old\\n-old text\\n+++ new\\n+new text\"}} + - Example: {\"tool\": \"str_replace\", \"args\": {\"file_path\": \"src/main.rs\", \"diff\": \"--- old\\n-old_code();\\n+++ new\\n+new_code();\"}} - **todo_read**: Read the current session's TODO list from todo.g3.md (session-scoped) - Format: {\"tool\": \"todo_read\", \"args\": {}} @@ -220,8 +223,6 @@ Short description for providers without native calling specs: - Find structs: {\"tool\": \"code_search\", \"args\": {\"searches\": [{\"name\": \"structs\", \"query\": \"(struct_item name: (type_identifier) @name)\", \"language\": \"rust\"}]}} - Multiple searches: {\"tool\": \"code_search\", \"args\": {\"searches\": [{\"name\": \"funcs\", \"query\": \"(function_item name: (identifier) @name)\", \"language\": \"rust\"}, {\"name\": \"structs\", \"query\": \"(struct_item name: (type_identifier) @name)\", \"language\": \"rust\"}]}} - With context lines: {\"tool\": \"code_search\", \"args\": {\"searches\": [{\"name\": \"funcs\", \"query\": \"(function_item name: (identifier) @name)\", \"language\": \"rust\", \"context_lines\": 3}]}} - - \"context\": 3 (show surrounding lines), - - \"json_style\": \"stream\" (for large results) - **research**: Perform web-based research and return a structured report - Format: {\"tool\": \"research\", \"args\": {\"query\": \"your research question\"}} @@ -230,9 +231,10 @@ Short description for providers without native calling specs: - **remember**: Save discovered code locations to workspace memory - Format: {\"tool\": \"remember\", \"args\": {\"notes\": \"markdown notes\"}} - - Example: {\"tool\": \"remember\", \"args\": {\"notes\": \"### Feature Name\\n- `file.rs` [0..100] - `function_name()`\"}} - - Use at the END of your turn after discovering code locations via search tools + - Example: {\"tool\": \"remember\", \"args\": {\"notes\": \"### Feature Name\\n- `file.rs` [0..100] - `function_name()\"}} + - Use at the END of your turn after discovering code locations via search tools"; +const NON_NATIVE_INSTRUCTIONS: &str = "\ # Instructions 1. Analyze the request and break down into smaller tasks if appropriate @@ -240,6 +242,10 @@ Short description for providers without native calling specs: 3. STOP when the original request was satisfied 4. When your task is complete, provide a detailed summary of what was accomplished +IMPORTANT: If the user asks you to just respond with text (like \"just say hello\" or \"tell me about X\"), do NOT use tools. Simply respond with the requested text directly. Only use tools when you need to execute commands or complete tasks that require action. + +Do not explain what you're going to do - just do it by calling the tools. + For reading files, prioritize use of code_search tool use with multiple search requests per call instead of read_file, if it makes sense. Exception to using ONE tool at a time: @@ -256,104 +262,53 @@ But NOT: write_file(\"helper.rs\", \"...\") write_file(\"file2.txt\", \"...\") write_file(\"helper.rs\", \"...\") -[DONE] +[DONE]"; -# Task Management with TODO Tools - -**REQUIRED for multi-step tasks.** Use TODO tools when your task involves ANY of: -- Multiple files to create/modify (2+) -- Multiple distinct steps (3+) -- Dependencies between steps -- Testing or verification needed -- Uncertainty about approach - -## Workflow - -Every multi-step task follows this pattern: -1. **Start**: Call todo_read, then todo_write to create your plan -2. **During**: Execute steps, then todo_read and todo_write to mark progress -3. **End**: Call todo_read to verify all items complete - -Note: todo_write replaces the entire list, so always read first to preserve content. +const NON_NATIVE_TODO_ADDENDUM: &str = " IMPORTANT: If you are provided with a SHA256 hash of the requirements file, you MUST include it as the very first line of the todo.g3.md file in the following format: `{{Based on the requirements file with SHA256: }}` -This ensures the TODO list is tracked against the specific version of requirements it was generated from. +This ensures the TODO list is tracked against the specific version of requirements it was generated from."; -## Examples +// ============================================================================ +// COMPOSED PROMPTS +// ============================================================================ -**Example 1: Feature Implementation** -User asks: \"Add user authentication with tests\" +/// System prompt for providers with native tool calling (Anthropic, OpenAI, etc.) +/// Note: This is kept for backwards compatibility but the function is preferred +pub const SYSTEM_PROMPT_FOR_NATIVE_TOOL_USE: &str = ""; -First action: -{\"tool\": \"todo_read\", \"args\": {}} +/// Generate system prompt for native tool calling providers +pub fn get_system_prompt_for_native() -> String { + format!( + "{}\n\n{}\n\n{}\n\n{}\n\n{}\n\n{}", + SHARED_INTRO, + SHARED_TODO_SECTION, + SHARED_TEMPORARY_FILES, + SHARED_WEB_RESEARCH, + SHARED_WORKSPACE_MEMORY, + SHARED_RESPONSE_GUIDELINES + ) +} -Then create plan: -{\"tool\": \"todo_write\", \"args\": {\"content\": \"- [ ] Add user authentication\\n - [ ] Create User struct\\n - [ ] Add login endpoint\\n - [ ] Add password hashing\\n - [ ] Write unit tests\\n - [ ] Write integration tests\"}} +/// System prompt for providers without native tool calling (embedded models) +/// Note: This is kept for backwards compatibility but the function is preferred +pub const SYSTEM_PROMPT_FOR_NON_NATIVE_TOOL_USE: &str = ""; -After completing User struct: -{\"tool\": \"todo_read\", \"args\": {}} -{\"tool\": \"todo_write\", \"args\": {\"content\": \"- [ ] Add user authentication\\n - [x] Create User struct\\n - [ ] Add login endpoint\\n - [ ] Add password hashing\\n - [ ] Write unit tests\\n - [ ] Write integration tests\"}} - -**Example 2: Bug Fix** -User asks: \"Fix the memory leak in cache module\" - -{\"tool\": \"todo_read\", \"args\": {}} -{\"tool\": \"todo_write\", \"args\": {\"content\": \"- [ ] Fix memory leak\\n - [ ] Review cache.rs\\n - [ ] Check for unclosed resources\\n - [ ] Add drop implementation\\n - [ ] Write test to verify fix\"}} - -**Example 3: Refactoring** -User asks: \"Refactor database layer to use async/await\" - -{\"tool\": \"todo_read\", \"args\": {}} -{\"tool\": \"todo_write\", \"args\": {\"content\": \"- [ ] Refactor to async\\n - [ ] Update function signatures\\n - [ ] Replace blocking calls\\n - [ ] Update all callers\\n - [ ] Update tests\"}} - -## Format - -Use markdown checkboxes: -- \"- [ ]\" for incomplete tasks -- \"- [x]\" for completed tasks -- Indent with 2 spaces for subtasks - -Keep items short, specific, and action-oriented. - -## Benefits - -✓ Prevents missed steps -✓ Makes progress visible -✓ Helps recover from interruptions -✓ Creates better summaries - -## When NOT to Use - -Skip TODO tools for simple single-step tasks: -- \"List files\" → just use shell -- \"Read config.json\" → just use read_file -- \"Search for functions\" → just use code_search - -If you can complete it with 1-2 tool calls, skip TODO. - -# Workspace Memory - -Workspace memory (if available) is automatically loaded at startup. It contains feature locations and patterns discovered in previous sessions. If you need to re-read memory from disk (e.g., after another agent updates it), use `read_file analysis/memory.md`. - -**ALWAYS** call `remember` at the END of your turn when you discovered: -- A feature's location (file + char range + function/struct names) -- A useful pattern or workflow -- An entry point for a subsystem - -This applies whenever you use search tools like `code_search`, `rg`, `grep`, `find`, or `read_file` to locate code. - -Do NOT save duplicates - check the Workspace Memory section (loaded at startup) to see what's already known. - -# Response Guidelines - -- Use Markdown formatting for all responses except tool calls. -- Whenever taking actions, use the pronoun 'I' -- After discovering code locations via search tools, call `remember` to save them. -- When showing example tool call JSON in prose or code blocks, use the fullwidth left curly bracket `{` (U+FF5B) instead of `{` to prevent parser confusion. -"; - -pub const SYSTEM_PROMPT_FOR_NON_NATIVE_TOOL_USE: &'static str = SYSTEM_NON_NATIVE_TOOL_USE; +/// Generate system prompt for non-native tool calling providers (embedded models) +pub fn get_system_prompt_for_non_native() -> String { + format!( + "{}\n\n{}\n\n{}\n\n{}{}\n\n{}\n\n{}\n\n{}", + SHARED_INTRO, + NON_NATIVE_TOOL_FORMAT, + NON_NATIVE_INSTRUCTIONS, + SHARED_TODO_SECTION, + NON_NATIVE_TODO_ADDENDUM, + SHARED_WEB_RESEARCH, + SHARED_WORKSPACE_MEMORY, + SHARED_RESPONSE_GUIDELINES + ) +} /// The G3 identity line that gets replaced in agent mode const G3_IDENTITY_LINE: &str = "You are G3, an AI programming agent of the same skill level as a seasoned engineer at a major technology company. You analyze given tasks and write code to achieve goals."; @@ -371,3 +326,80 @@ pub fn get_agent_system_prompt(agent_prompt: &str, allow_multiple_tool_calls: bo // Replace only the G3 identity line with the custom agent prompt full_prompt.replace(G3_IDENTITY_LINE, agent_prompt.trim()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_native_prompt_contains_validation_string() { + let prompt = get_system_prompt_for_native(); + assert!(prompt.contains("You have access to tools"), + "Native prompt must contain validation string"); + } + + #[test] + fn test_non_native_prompt_contains_validation_string() { + let prompt = get_system_prompt_for_non_native(); + assert!(prompt.contains("You have access to tools"), + "Non-native prompt must contain validation string"); + } + + #[test] + fn test_native_prompt_contains_important_directive() { + let prompt = get_system_prompt_for_native(); + assert!(prompt.contains("IMPORTANT: You must call tools to achieve goals"), + "Native prompt must contain IMPORTANT directive"); + } + + #[test] + fn test_non_native_prompt_contains_important_directive() { + let prompt = get_system_prompt_for_non_native(); + assert!(prompt.contains("IMPORTANT: You must call tools to achieve goals"), + "Non-native prompt must contain IMPORTANT directive"); + } + + #[test] + fn test_non_native_prompt_contains_tool_format() { + let prompt = get_system_prompt_for_non_native(); + assert!(prompt.contains("# Tool Call Format"), + "Non-native prompt must contain tool format section"); + assert!(prompt.contains("# Available Tools"), + "Non-native prompt must contain available tools section"); + } + + #[test] + fn test_agent_prompt_replaces_identity() { + let custom = "You are TestAgent, a specialized testing assistant."; + let prompt = get_agent_system_prompt(custom, true); + assert!(prompt.contains(custom), "Agent prompt should contain custom identity"); + assert!(!prompt.contains(G3_IDENTITY_LINE), "Agent prompt should not contain G3 identity"); + } + + #[test] + fn test_both_prompts_have_todo_section() { + let native = get_system_prompt_for_native(); + let non_native = get_system_prompt_for_non_native(); + + assert!(native.contains("# Task Management with TODO Tools")); + assert!(non_native.contains("# Task Management with TODO Tools")); + } + + #[test] + fn test_both_prompts_have_workspace_memory() { + let native = get_system_prompt_for_native(); + let non_native = get_system_prompt_for_non_native(); + + assert!(native.contains("# Workspace Memory")); + assert!(non_native.contains("# Workspace Memory")); + } + + #[test] + fn test_both_prompts_have_web_research() { + let native = get_system_prompt_for_native(); + let non_native = get_system_prompt_for_non_native(); + + assert!(native.contains("# Web Research")); + assert!(non_native.contains("# Web Research")); + } +} diff --git a/crates/g3-core/src/provider_registration.rs b/crates/g3-core/src/provider_registration.rs index 2a2841f..44baf39 100644 --- a/crates/g3-core/src/provider_registration.rs +++ b/crates/g3-core/src/provider_registration.rs @@ -78,7 +78,8 @@ fn register_embedded_providers( ) -> Result<()> { for (name, embedded_config) in &config.providers.embedded { if should_register(providers_to_register, "embedded", name) { - let embedded_provider = g3_providers::EmbeddedProvider::new( + let embedded_provider = g3_providers::EmbeddedProvider::new_with_name( + format!("embedded.{}", name), embedded_config.model_path.clone(), embedded_config.model_type.clone(), embedded_config.context_length, diff --git a/crates/g3-core/tests/project_context_test.rs b/crates/g3-core/tests/project_context_test.rs index 786c7ca..53951b6 100644 --- a/crates/g3-core/tests/project_context_test.rs +++ b/crates/g3-core/tests/project_context_test.rs @@ -40,7 +40,7 @@ async fn test_context_window_initial_structure() { // First message should be system prompt let system_msg = &context.conversation_history[0]; - assert!(system_msg.content.contains("IMPORTANT: You must call tools to achieve goals"), + assert!(system_msg.content.contains("You have access to tools"), "First message should be system prompt with tool instructions"); // Second message should be README content @@ -285,7 +285,7 @@ async fn test_full_context_order() { // Message 0: System prompt let system = &context.conversation_history[0].content; - assert!(system.contains("IMPORTANT: You must call tools"), + assert!(system.contains("You have access to tools"), "Message 0 should be system prompt"); // Message 1: Combined content with project appended diff --git a/crates/g3-providers/Cargo.toml b/crates/g3-providers/Cargo.toml index b4db302..017c2d4 100644 --- a/crates/g3-providers/Cargo.toml +++ b/crates/g3-providers/Cargo.toml @@ -27,6 +27,6 @@ nanoid = "0.4" serde_urlencoded = "0.7" tokio-util = "0.7" dirs = "5.0" -llama_cpp = { version = "0.3.2", features = ["metal"] } +llama-cpp-2 = { version = "0.1", features = ["metal"] } shellexpand = "3.1" rand = "0.8" diff --git a/crates/g3-providers/src/embedded.rs b/crates/g3-providers/src/embedded.rs index 9a196d2..8e47646 100644 --- a/crates/g3-providers/src/embedded.rs +++ b/crates/g3-providers/src/embedded.rs @@ -1,29 +1,58 @@ use crate::{ CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message, MessageRole, Usage, - streaming::{make_text_chunk, make_final_chunk}, + streaming::{make_text_chunk, make_final_chunk_with_reason}, }; use anyhow::Result; -use llama_cpp::{ - standard_sampler::{SamplerStage, StandardSampler}, - LlamaModel, LlamaParams, LlamaSession, SessionParams, +use llama_cpp_2::{ + context::params::LlamaContextParams, + llama_backend::LlamaBackend, + llama_batch::LlamaBatch, + model::{params::LlamaModelParams, AddBos, LlamaModel, Special}, + sampling::LlamaSampler, }; -use std::path::{Path, PathBuf}; +use std::num::NonZeroU32; +use std::path::PathBuf; use std::sync::Arc; use tokio::sync::mpsc; -use tokio::sync::Mutex; use tokio_stream::wrappers::ReceiverStream; use tracing::{debug, error}; +/// Embedded LLM provider using llama.cpp with Metal acceleration on macOS. +/// +/// Supports multiple model families with their native chat templates: +/// - Qwen (ChatML format) +/// - GLM-4 (ChatGLM4 format) +/// - Mistral (Instruct format) +/// - Llama/CodeLlama (Llama2 format) pub struct EmbeddedProvider { - session: Arc>, + /// Provider name in format "embedded.{config_name}" + name: String, + /// The loaded model + model: Arc, + /// The llama.cpp backend (must be kept alive) + backend: Arc, + /// Model type identifier (e.g., "qwen", "glm4", "mistral") + model_type: String, + /// Full model name for display model_name: String, - max_tokens: u32, + /// Maximum tokens to generate (None = auto-calculate) + max_tokens: Option, + /// Sampling temperature temperature: f32, + /// Context window size context_length: u32, + /// Number of GPU layers + gpu_layers: u32, + /// Number of threads + threads: Option, } impl EmbeddedProvider { + /// Create a new embedded provider with default naming. + /// + /// The provider will be registered as "embedded" (legacy behavior). + /// For proper multi-provider support, use `new_with_name()` instead. pub fn new( model_path: String, model_type: String, @@ -32,6 +61,39 @@ impl EmbeddedProvider { temperature: Option, gpu_layers: Option, threads: Option, + ) -> Result { + Self::new_with_name( + "embedded".to_string(), + model_path, + model_type, + context_length, + max_tokens, + temperature, + gpu_layers, + threads, + ) + } + + /// Create a new embedded provider with a custom name. + /// + /// # Arguments + /// * `name` - Provider name (e.g., "embedded.glm4", "embedded.qwen") + /// * `model_path` - Path to the GGUF model file (supports ~ expansion) + /// * `model_type` - Model family identifier ("qwen", "glm4", "glm", "mistral", "llama", etc.) + /// * `context_length` - Context window size (default: auto-detected from GGUF) + /// * `max_tokens` - Maximum tokens to generate (default: min(4096, context/4)) + /// * `temperature` - Sampling temperature (default: 0.1) + /// * `gpu_layers` - Number of layers to offload to GPU (default: 99 for Apple Silicon) + /// * `threads` - Number of CPU threads for inference + pub fn new_with_name( + name: String, + model_path: String, + model_type: String, + context_length: Option, + max_tokens: Option, + temperature: Option, + gpu_layers: Option, + threads: Option, ) -> Result { debug!("Loading embedded model from: {}", model_path); @@ -39,389 +101,247 @@ impl EmbeddedProvider { let expanded_path = shellexpand::tilde(&model_path); let model_path_buf = PathBuf::from(expanded_path.as_ref()); - // If model doesn't exist and it's the default Qwen model, offer to download it if !model_path_buf.exists() { - if model_path.contains("qwen2.5-7b-instruct-q3_k_m.gguf") { - debug!("Model file not found. Attempting to download Qwen 2.5 7B model..."); - Self::download_qwen_model(&model_path_buf)?; - } else { - anyhow::bail!("Model file not found: {}", model_path_buf.display()); - } + anyhow::bail!("Model file not found: {}", model_path_buf.display()); } - let model_path = model_path_buf.as_path(); + // Initialize the llama.cpp backend + let backend = LlamaBackend::init() + .map_err(|e| anyhow::anyhow!("Failed to initialize llama.cpp backend: {:?}", e))?; // Set up model parameters - let mut params = LlamaParams::default(); - - if let Some(gpu_layers) = gpu_layers { - params.n_gpu_layers = gpu_layers; - debug!("Using {} GPU layers", gpu_layers); - } - - let context_size = context_length.unwrap_or(4096); - debug!("Using context length: {}", context_size); + let n_gpu_layers = gpu_layers.unwrap_or(99); + let model_params = LlamaModelParams::default().with_n_gpu_layers(n_gpu_layers); + debug!("Using {} GPU layers", n_gpu_layers); // Load the model debug!("Loading model..."); - let model = LlamaModel::load_from_file(model_path, params) - .map_err(|e| anyhow::anyhow!("Failed to load model: {}", e))?; + let model = LlamaModel::load_from_file(&backend, &model_path_buf, &model_params) + .map_err(|e| anyhow::anyhow!("Failed to load model: {:?}", e))?; - // Create session with parameters - let mut session_params = SessionParams { - n_ctx: context_size, - ..Default::default() - }; - if let Some(threads) = threads { - session_params.n_threads = threads; - } + // Auto-detect context length from GGUF metadata, or use provided value + let model_ctx_train = model.n_ctx_train(); + let context_size = context_length.unwrap_or(model_ctx_train); + debug!( + "Context length: {} (model trained: {}, configured: {:?})", + context_size, model_ctx_train, context_length + ); - let session = model - .create_session(session_params) - .map_err(|e| anyhow::anyhow!("Failed to create session: {}", e))?; - - debug!("Successfully loaded {} model", model_type); + debug!("Successfully loaded {} model as '{}'", model_type, name); Ok(Self { - session: Arc::new(Mutex::new(session)), + name, + model: Arc::new(model), + backend: Arc::new(backend), + model_type: model_type.to_lowercase(), model_name: format!("embedded-{}", model_type), - max_tokens: max_tokens.unwrap_or(2048), + max_tokens, temperature: temperature.unwrap_or(0.1), context_length: context_size, + gpu_layers: n_gpu_layers, + threads, }) } + /// Format messages according to the model's native chat template. fn format_messages(&self, messages: &[Message]) -> String { - // Determine the appropriate format based on model type - let model_name_lower = self.model_name.to_lowercase(); + let model_type = &self.model_type; - if model_name_lower.contains("qwen") { - // Qwen format: <|im_start|>role\ncontent<|im_end|> - let mut formatted = String::new(); - - for message in messages { - let role = match message.role { - MessageRole::System => "system", - MessageRole::User => "user", - MessageRole::Assistant => "assistant", - }; - - formatted.push_str(&format!( - "<|im_start|>{}\n{}<|im_end|>\n", - role, message.content - )); - } - - // Add the start of assistant response - formatted.push_str("<|im_start|>assistant\n"); - formatted - } else if model_name_lower.contains("mistral") { - // Mistral Instruct format: [INST] ... [/INST] assistant_response - let mut formatted = String::new(); - let mut in_conversation = false; - - for (i, message) in messages.iter().enumerate() { - match message.role { - MessageRole::System => { - // Mistral doesn't have a special system token, include it at the start - if i == 0 { - formatted.push_str("[INST] "); - formatted.push_str(&message.content); - formatted.push_str("\n\n"); - in_conversation = true; - } - } - MessageRole::User => { - if !in_conversation { - formatted.push_str("[INST] "); - } - formatted.push_str(&message.content); - formatted.push_str(" [/INST]"); - in_conversation = false; - } - MessageRole::Assistant => { - formatted.push(' '); - formatted.push_str(&message.content); - formatted.push_str(" "); - in_conversation = false; - } - } - } - - // If the last message was from user, add a space for the assistant's response - if messages - .last() - .is_some_and(|m| matches!(m.role, MessageRole::User)) - { - formatted.push(' '); - } - - formatted + if model_type.contains("glm") { + self.format_glm4_messages(messages) + } else if model_type.contains("qwen") { + self.format_qwen_messages(messages) + } else if model_type.contains("mistral") { + self.format_mistral_messages(messages) } else { - // Use Llama/CodeLlama format for other models - let mut formatted = String::new(); - - for message in messages { - match message.role { - MessageRole::System => { - formatted.push_str(&format!( - "[INST] <>\n{}\n<>\n\n", - message.content - )); - } - MessageRole::User => { - formatted.push_str(&format!("{} [/INST] ", message.content)); - } - MessageRole::Assistant => { - formatted.push_str(&format!("{} [INST] ", message.content)); - } - } - } - formatted + // Default to Llama format + self.format_llama_messages(messages) } } - async fn generate_completion( - &self, - prompt: &str, - max_tokens: u32, - temperature: f32, - ) -> Result { - let session = self.session.clone(); - let prompt = prompt.to_string(); + /// GLM-4 ChatGLM4 format: [gMASK]<|role|>\ncontent + fn format_glm4_messages(&self, messages: &[Message]) -> String { + let mut formatted = String::from("[gMASK]"); - // Calculate dynamic max tokens based on available context headroom - let prompt_tokens = self.estimate_tokens(&prompt); - let available_tokens = self - .context_length - .saturating_sub(prompt_tokens) - .saturating_sub(50); // Reserve 50 tokens for safety - let dynamic_max_tokens = std::cmp::min(max_tokens as usize, available_tokens as usize); - - debug!("Context calculation: prompt_tokens={}, context_length={}, available_tokens={}, dynamic_max_tokens={}", - prompt_tokens, self.context_length, available_tokens, dynamic_max_tokens); - - // Get stop sequences before entering the closure - let stop_sequences = self.get_stop_sequences(); - - // Add timeout to the entire operation - let timeout_duration = std::time::Duration::from_secs(30); // Increased timeout for larger contexts - - let result = tokio::time::timeout( - timeout_duration, - tokio::task::spawn_blocking(move || { - // Retry logic for acquiring the session lock - let mut session_guard = None; - for attempt in 0..5 { - match session.try_lock() { - Ok(ctx) => { - session_guard = Some(ctx); - break; - } - Err(_) => { - if attempt < 4 { - debug!( - "Session busy, retrying in {}ms (attempt {}/5)", - 100 * (attempt + 1), - attempt + 1 - ); - std::thread::sleep(std::time::Duration::from_millis( - 100 * (attempt + 1) as u64, - )); - } else { - return Err(anyhow::anyhow!( - "Model is busy after 5 attempts, please try again" - )); - } - } - } - } - - let mut session = session_guard - .ok_or_else(|| anyhow::anyhow!("Failed to acquire session lock"))?; - - debug!( - "Starting inference with prompt length: {} chars, estimated {} tokens", - prompt.len(), - prompt_tokens - ); - - // Set context to the prompt - debug!("About to call set_context..."); - session - .set_context(&prompt) - .map_err(|e| anyhow::anyhow!("Failed to set context: {}", e))?; - debug!("set_context completed successfully"); - - // Create sampler with temperature - debug!("Creating sampler..."); - let stages = vec![ - SamplerStage::Temperature(temperature), - SamplerStage::TopK(40), - SamplerStage::TopP(0.9), - ]; - let sampler = StandardSampler::new_softmax(stages, 1); - debug!("Sampler created successfully"); - - // Start completion with dynamic max tokens - debug!( - "About to call start_completing_with with {} max tokens...", - dynamic_max_tokens - ); - let mut completion_handle = session - .start_completing_with(sampler, dynamic_max_tokens) - .map_err(|e| anyhow::anyhow!("Failed to start completion: {}", e))?; - debug!("start_completing_with completed successfully"); - - let mut generated_text = String::new(); - let mut token_count = 0; - let start_time = std::time::Instant::now(); - - debug!("Starting token generation loop..."); - - // Generate tokens with dynamic limits - while let Some(token) = completion_handle.next_token() { - // Check for timeout on each token - if start_time.elapsed() > std::time::Duration::from_secs(25) { - debug!("Token generation timeout after {} tokens", token_count); - break; - } - - let token_string = session.model().token_to_piece(token); - generated_text.push_str(&token_string); - token_count += 1; - - if token_count <= 10 || token_count % 50 == 0 { - debug!("Generated token {}: '{}'", token_count, token_string); - } - - // Use dynamic token limit - if token_count >= dynamic_max_tokens { - debug!("Reached dynamic token limit: {}", dynamic_max_tokens); - break; - } - - // Stop on completion markers - let mut hit_stop = false; - for stop_seq in &stop_sequences { - if generated_text.contains(stop_seq) { - debug!("Hit stop sequence '{}' at {} tokens", stop_seq, token_count); - hit_stop = true; - break; - } - } - - if hit_stop { - break; - } - } - - debug!( - "Token generation loop completed. Generated {} tokens in {:?}", - token_count, - start_time.elapsed() - ); - - Ok((generated_text, token_count)) - }), - ) - .await; - - match result { - Ok(inner_result) => match inner_result { - Ok(task_result) => match task_result { - Ok((text, token_count)) => { - debug!( - "Completed generation: {} tokens (dynamic limit was {})", - token_count, dynamic_max_tokens - ); - // Clean stop sequences from the generated text after the closure - Ok(self.clean_stop_sequences(&text)) - } - Err(e) => Err(e), - }, - Err(e) => Err(e.into()), - }, - Err(_) => { - error!("Generation timed out after 30 seconds"); - Err(anyhow::anyhow!("Generation timed out")) - } + for message in messages { + let role = match message.role { + MessageRole::System => "<|system|>", + MessageRole::User => "<|user|>", + MessageRole::Assistant => "<|assistant|>", + }; + formatted.push_str(&format!("{}\n{}", role, message.content)); } + + // Add the start of assistant response + formatted.push_str("<|assistant|>\n"); + formatted } - // Helper function to estimate token count from text + /// Qwen ChatML format: <|im_start|>role\ncontent<|im_end|> + fn format_qwen_messages(&self, messages: &[Message]) -> String { + let mut formatted = String::new(); + + for message in messages { + let role = match message.role { + MessageRole::System => "system", + MessageRole::User => "user", + MessageRole::Assistant => "assistant", + }; + + formatted.push_str(&format!( + "<|im_start|>{}\n{}<|im_end|>\n", + role, message.content + )); + } + + // Add the start of assistant response + formatted.push_str("<|im_start|>assistant\n"); + formatted + } + + /// Mistral Instruct format: [INST] ... [/INST] response + fn format_mistral_messages(&self, messages: &[Message]) -> String { + let mut formatted = String::new(); + let mut in_conversation = false; + + for (i, message) in messages.iter().enumerate() { + match message.role { + MessageRole::System => { + // Mistral doesn't have a special system token, include it at the start + if i == 0 { + formatted.push_str("[INST] "); + formatted.push_str(&message.content); + formatted.push_str("\n\n"); + in_conversation = true; + } + } + MessageRole::User => { + if !in_conversation { + formatted.push_str("[INST] "); + } + formatted.push_str(&message.content); + formatted.push_str(" [/INST]"); + in_conversation = false; + } + MessageRole::Assistant => { + formatted.push(' '); + formatted.push_str(&message.content); + formatted.push_str(" "); + in_conversation = false; + } + } + } + + // If the last message was from user, add a space for the assistant's response + if messages + .last() + .is_some_and(|m| matches!(m.role, MessageRole::User)) + { + formatted.push(' '); + } + + formatted + } + + /// Llama/CodeLlama format: [INST] <>\nsystem<>\n\nuser [/INST] + fn format_llama_messages(&self, messages: &[Message]) -> String { + let mut formatted = String::new(); + + for message in messages { + match message.role { + MessageRole::System => { + formatted.push_str(&format!( + "[INST] <>\n{}\n<>\n\n", + message.content + )); + } + MessageRole::User => { + formatted.push_str(&format!("{} [/INST] ", message.content)); + } + MessageRole::Assistant => { + formatted.push_str(&format!("{} [INST] ", message.content)); + } + } + } + formatted + } + + /// Estimate token count from text (rough approximation: ~4 chars per token) fn estimate_tokens(&self, text: &str) -> u32 { - // Rough estimation: average 4 characters per token - // This is conservative - actual tokenization might be different (text.len() as f32 / 4.0).ceil() as u32 } - // Helper function to get stop sequences based on model type + /// Get stop sequences based on model type. fn get_stop_sequences(&self) -> Vec<&'static str> { - // Determine model type from model_name - let model_name_lower = self.model_name.to_lowercase(); + let model_type = &self.model_type; - if model_name_lower.contains("qwen") { + if model_type.contains("glm") { vec![ - "<|im_end|>", // Qwen ChatML format end token - "<|endoftext|>", // Alternative end token - "", // Generic end of sequence - "<|im_start|>", // Start of new message (shouldn't appear in response) + "<|endoftext|>", // GLM end of text + "<|user|>", // Start of new user turn + "<|observation|>", // Tool observation (shouldn't appear in response) + "<|system|>", // System message (shouldn't appear in response) ] - } else if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama") - { + } else if model_type.contains("qwen") { vec![ - "", // End of sequence - "[/INST]", // End of instruction - "<>", // End of system message - "[INST]", // Start of new instruction (shouldn't appear in response) - "<>", // Start of system (shouldn't appear in response) + "<|im_end|>", // Qwen ChatML format end token + "<|endoftext|>", // Alternative end token + "", // Generic end of sequence + "<|im_start|>", // Start of new message (shouldn't appear in response) ] - } else if model_name_lower.contains("llama") { + } else if model_type.contains("codellama") || model_type.contains("code-llama") { vec![ - "", // End of sequence - "[/INST]", // End of instruction - "<>", // End of system message - "### Human:", // Conversation format - "### Assistant:", // Conversation format - "[INST]", // Start of new instruction + "", // End of sequence + "[/INST]", // End of instruction + "<>", // End of system message + "[INST]", // Start of new instruction + "<>", // Start of system ] - } else if model_name_lower.contains("mistral") { + } else if model_type.contains("llama") { vec![ - "", // End of sequence - "[/INST]", // End of instruction - "<|im_end|>", // ChatML format + "", // End of sequence + "[/INST]", // End of instruction + "<>", // End of system message + "### Human:", // Conversation format + "### Assistant:", // Conversation format + "[INST]", // Start of new instruction ] - } else if model_name_lower.contains("vicuna") || model_name_lower.contains("wizard") { + } else if model_type.contains("mistral") { vec![ - "### Human:", // Conversation format - "### Assistant:", // Conversation format - "USER:", // Alternative format - "ASSISTANT:", // Alternative format - "", // End of sequence + "", // End of sequence + "[/INST]", // End of instruction + "<|im_end|>", // ChatML format (some Mistral fine-tunes) ] - } else if model_name_lower.contains("alpaca") { + } else if model_type.contains("vicuna") || model_type.contains("wizard") { vec![ - "### Instruction:", // Alpaca format - "### Response:", // Alpaca format - "### Input:", // Alpaca format - "", // End of sequence + "### Human:", // Conversation format + "### Assistant:", // Conversation format + "USER:", // Alternative format + "ASSISTANT:", // Alternative format + "", // End of sequence + ] + } else if model_type.contains("alpaca") { + vec![ + "### Instruction:", // Alpaca format + "### Response:", // Alpaca format + "### Input:", // Alpaca format + "", // End of sequence ] } else { // Generic/unknown model - use common stop sequences vec![ - "", // Most common end sequence - "<|endoftext|>", // GPT-style - "<|im_end|>", // ChatML - "### Human:", // Common conversation format - "### Assistant:", // Common conversation format - "[/INST]", // Instruction format - "<>", // System format + "", // Most common end sequence + "<|endoftext|>", // GPT-style + "<|im_end|>", // ChatML + "### Human:", // Common conversation format + "### Assistant:", // Common conversation format + "[/INST]", // Instruction format + "<>", // System format ] } } - // Helper function to clean up stop sequences from generated text + /// Clean stop sequences from generated text. fn clean_stop_sequences(&self, text: &str) -> String { let mut cleaned = text.to_string(); let stop_sequences = self.get_stop_sequences(); @@ -436,70 +356,10 @@ impl EmbeddedProvider { cleaned.trim().to_string() } - // Download the Qwen 2.5 7B model if it doesn't exist - fn download_qwen_model(model_path: &Path) -> Result<()> { - use std::fs; - use std::process::Command; - - const MODEL_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-7B-Instruct-GGUF/resolve/main/qwen2.5-7b-instruct-q3_k_m.gguf"; - const MODEL_SIZE_MB: u64 = 3631; // Approximate size in MB - - // Create the parent directory if it doesn't exist - if let Some(parent) = model_path.parent() { - fs::create_dir_all(parent)?; - } - - debug!("Downloading Qwen 2.5 7B model (Q3_K_M quantization, ~3.5GB)..."); - debug!("This is a one-time download that may take several minutes depending on your connection."); - debug!("Downloading to: {}", model_path.display()); - - // Use curl with progress bar for download - let output = Command::new("curl") - .args([ - "-L", // Follow redirects - "-#", // Show progress bar - "-f", // Fail on HTTP errors - "-o", - model_path.to_str().unwrap(), - MODEL_URL, - ]) - .output()?; - - if !output.status.success() { - let stderr = String::from_utf8_lossy(&output.stderr); - - // If curl is not available, provide alternative instructions - if stderr.contains("command not found") || stderr.contains("not found") { - error!( - "curl is not installed. Please install curl or manually download the model." - ); - error!("Manual download instructions:"); - error!("1. Download from: {}", MODEL_URL); - error!("2. Save to: {}", model_path.display()); - anyhow::bail!( - "curl not found - please install curl or download the model manually" - ); - } - - anyhow::bail!("Failed to download model: {}", stderr); - } - - // Verify the file was created and has reasonable size - let metadata = fs::metadata(model_path)?; - let size_mb = metadata.len() / (1024 * 1024); - - if size_mb < MODEL_SIZE_MB - 100 { - // Allow some variance - fs::remove_file(model_path).ok(); // Clean up partial download - anyhow::bail!( - "Downloaded file appears incomplete ({}MB vs expected ~{}MB). Please try again.", - size_mb, - MODEL_SIZE_MB - ); - } - - debug!("Successfully downloaded Qwen 2.5 7B model ({}MB)", size_mb); - Ok(()) + /// Get the effective max tokens for generation + fn effective_max_tokens(&self) -> u32 { + self.max_tokens + .unwrap_or_else(|| std::cmp::min(4096, self.context_length / 4)) } } @@ -512,18 +372,119 @@ impl LLMProvider for EmbeddedProvider { ); let prompt = self.format_messages(&request.messages); - let max_tokens = request.max_tokens.unwrap_or(self.max_tokens); + let max_tokens = request.max_tokens.unwrap_or_else(|| self.effective_max_tokens()); let temperature = request.temperature.unwrap_or(self.temperature); debug!("Formatted prompt length: {} chars", prompt.len()); + + // Estimate prompt tokens before moving prompt into closure + let prompt_tokens = self.estimate_tokens(&prompt); - let content = self - .generate_completion(&prompt, max_tokens, temperature) - .await?; + // Clone what we need for the blocking task + let model = self.model.clone(); + let backend = self.backend.clone(); + let context_length = self.context_length; + let threads = self.threads; + let stop_sequences: Vec = self.get_stop_sequences().iter().map(|s| s.to_string()).collect(); - // Estimate token usage (rough approximation) - let prompt_tokens = (prompt.len() / 4) as u32; // Rough estimate: 4 chars per token - let completion_tokens = (content.len() / 4) as u32; + let (content, completion_tokens) = tokio::task::spawn_blocking(move || { + // Create context for this completion + let n_ctx = NonZeroU32::new(context_length).unwrap_or(NonZeroU32::new(4096).unwrap()); + let mut ctx_params = LlamaContextParams::default() + .with_n_ctx(Some(n_ctx)) + .with_n_batch(context_length); // Batch size must accommodate full prompt + if let Some(n_threads) = threads { + ctx_params = ctx_params.with_n_threads(n_threads as i32); + } + + let mut ctx = model + .new_context(&backend, ctx_params) + .map_err(|e| anyhow::anyhow!("Failed to create context: {:?}", e))?; + + // Tokenize the prompt + let tokens = model + .str_to_token(&prompt, AddBos::Always) + .map_err(|e| anyhow::anyhow!("Failed to tokenize: {:?}", e))?; + + debug!("Tokenized prompt: {} tokens", tokens.len()); + + // Create batch large enough for the prompt tokens + // The batch size must be at least as large as the number of tokens we're adding + let batch_size = std::cmp::max(512, tokens.len()); + let mut batch = LlamaBatch::new(batch_size, 1); + for (i, token) in tokens.iter().enumerate() { + batch + .add(*token, i as i32, &[0], i == tokens.len() - 1) + .map_err(|e| anyhow::anyhow!("Failed to add token to batch: {:?}", e))?; + } + + // Decode the prompt + ctx.decode(&mut batch) + .map_err(|e| anyhow::anyhow!("Failed to decode prompt: {:?}", e))?; + + // Set up sampler + let mut sampler = LlamaSampler::chain_simple([ + LlamaSampler::temp(temperature), + LlamaSampler::dist(1234), + ]); + + // Generate tokens + let mut generated_text = String::new(); + let mut n_cur = tokens.len() as i32; + let mut token_count = 0u32; + + for _ in 0..max_tokens { + let new_token = sampler.sample(&ctx, batch.n_tokens() - 1); + sampler.accept(new_token); + + // Check for end of generation + if model.is_eog_token(new_token) { + debug!("Hit end-of-generation token at {} tokens", token_count); + break; + } + + // Decode token to string + let token_str = model.token_to_str(new_token, Special::Tokenize) + .unwrap_or_default(); + generated_text.push_str(&token_str); + token_count += 1; + + // Check for stop sequences + let mut hit_stop = false; + for stop_seq in &stop_sequences { + if generated_text.contains(stop_seq) { + debug!("Hit stop sequence '{}' at {} tokens", stop_seq, token_count); + hit_stop = true; + break; + } + } + if hit_stop { + break; + } + + // Prepare next batch + batch.clear(); + batch + .add(new_token, n_cur, &[0], true) + .map_err(|e| anyhow::anyhow!("Failed to add token to batch: {:?}", e))?; + n_cur += 1; + + ctx.decode(&mut batch) + .map_err(|e| anyhow::anyhow!("Failed to decode: {:?}", e))?; + } + + // Clean stop sequences from output + for stop_seq in &stop_sequences { + if let Some(pos) = generated_text.find(stop_seq) { + generated_text.truncate(pos); + break; + } + } + + Ok::<_, anyhow::Error>((generated_text.trim().to_string(), token_count)) + }) + .await + .map_err(|e| anyhow::anyhow!("Task join error: {}", e))??; Ok(CompletionResponse { content, @@ -531,7 +492,7 @@ impl LLMProvider for EmbeddedProvider { prompt_tokens, completion_tokens, total_tokens: prompt_tokens + completion_tokens, - cache_creation_tokens: 0, // Embedded models don't support prompt caching + cache_creation_tokens: 0, cache_read_tokens: 0, }, model: self.model_name.clone(), @@ -545,243 +506,162 @@ impl LLMProvider for EmbeddedProvider { ); let prompt = self.format_messages(&request.messages); - let max_tokens = request.max_tokens.unwrap_or(self.max_tokens); + let max_tokens = request.max_tokens.unwrap_or_else(|| self.effective_max_tokens()); let temperature = request.temperature.unwrap_or(self.temperature); - let (tx, rx) = mpsc::channel(100); - let session = self.session.clone(); - let prompt = prompt.to_string(); + // Estimate prompt tokens for usage tracking + let prompt_tokens = self.estimate_tokens(&prompt); + + let (tx, rx) = mpsc::channel(100); + + // Clone what we need for the blocking task + let model = self.model.clone(); + let backend = self.backend.clone(); + let context_length = self.context_length; + let threads = self.threads; + let stop_sequences: Vec = self.get_stop_sequences().iter().map(|s| s.to_string()).collect(); - // Spawn streaming task tokio::task::spawn_blocking(move || { - // Retry logic for acquiring the session lock - let mut session_guard = None; - for attempt in 0..5 { - match session.try_lock() { - Ok(ctx) => { - session_guard = Some(ctx); - break; - } - Err(_) => { - if attempt < 4 { - debug!( - "Session busy, retrying in {}ms (attempt {}/5)", - 100 * (attempt + 1), - attempt + 1 - ); - std::thread::sleep(std::time::Duration::from_millis( - 100 * (attempt + 1) as u64, - )); - } else { - let _ = tx.blocking_send(Err(anyhow::anyhow!( - "Model is busy after 5 attempts, please try again" - ))); - return; - } - } - } + // Create context for this completion + let n_ctx = NonZeroU32::new(context_length).unwrap_or(NonZeroU32::new(4096).unwrap()); + let mut ctx_params = LlamaContextParams::default() + .with_n_ctx(Some(n_ctx)) + .with_n_batch(context_length); // Batch size must accommodate full prompt + if let Some(n_threads) = threads { + ctx_params = ctx_params.with_n_threads(n_threads as i32); } - let mut session = match session_guard { - Some(ctx) => ctx, - None => { - let _ = - tx.blocking_send(Err(anyhow::anyhow!("Failed to acquire session lock"))); + let mut ctx = match model.new_context(&backend, ctx_params) { + Ok(ctx) => ctx, + Err(e) => { + let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to create context: {:?}", e))); return; } }; - // Set context to the prompt - if let Err(e) = session.set_context(&prompt) { - let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to set context: {}", e))); + // Tokenize the prompt + let tokens = match model.str_to_token(&prompt, AddBos::Always) { + Ok(t) => t, + Err(e) => { + let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to tokenize: {:?}", e))); + return; + } + }; + + debug!("Tokenized prompt: {} tokens", tokens.len()); + + // Create batch large enough for the prompt tokens + // The batch size must be at least as large as the number of tokens we're adding + let batch_size = std::cmp::max(512, tokens.len()); + let mut batch = LlamaBatch::new(batch_size, 1); + for (i, token) in tokens.iter().enumerate() { + if let Err(e) = batch.add(*token, i as i32, &[0], i == tokens.len() - 1) { + let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to add token to batch: {:?}", e))); + return; + } + } + + // Decode the prompt + if let Err(e) = ctx.decode(&mut batch) { + let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to decode prompt: {:?}", e))); return; } - // Create sampler with temperature - let stages = vec![ - SamplerStage::Temperature(temperature), - SamplerStage::TopK(40), - SamplerStage::TopP(0.9), - ]; - let sampler = StandardSampler::new_softmax(stages, 1); - - // Start completion - let mut completion_handle = match session - .start_completing_with(sampler, max_tokens as usize) - { - Ok(handle) => handle, - Err(e) => { - let _ = - tx.blocking_send(Err(anyhow::anyhow!("Failed to start completion: {}", e))); - return; - } - }; + // Set up sampler + let mut sampler = LlamaSampler::chain_simple([ + LlamaSampler::temp(temperature), + LlamaSampler::dist(1234), + ]); + // Generate tokens let mut accumulated_text = String::new(); - let mut token_count = 0; - let mut unsent_tokens = String::new(); // Buffer for tokens we're holding back + let mut n_cur = tokens.len() as i32; + let mut token_count = 0u32; + let mut stop_reason: Option = None; - // Get stop sequences dynamically based on model type - let stop_sequences = if prompt.contains("<|im_start|>") { - // Qwen ChatML format detected - vec!["<|im_end|>", "<|endoftext|>", "", "<|im_start|>"] - } else if prompt.contains("[INST]") || prompt.contains("<>") { - // Llama/CodeLlama format detected - vec![ - "", - "[/INST]", - "<>", - "[INST]", - "<>", - "### Human:", - "### Assistant:", - ] - } else { - // Generic format - vec![ - "", - "<|endoftext|>", - "<|im_end|>", - "### Human:", - "### Assistant:", - "[/INST]", - "<>", - ] - }; + for _ in 0..max_tokens { + let new_token = sampler.sample(&ctx, batch.n_tokens() - 1); + sampler.accept(new_token); - // Stream tokens with proper limits - while let Some(token) = completion_handle.next_token() { - let token_string = session.model().token_to_piece(token); + // Check for end of generation + if model.is_eog_token(new_token) { + debug!("Hit end-of-generation token at {} tokens", token_count); + stop_reason = Some("end_turn".to_string()); + break; + } - accumulated_text.push_str(&token_string); - unsent_tokens.push_str(&token_string); + // Decode token to string + let token_str = model.token_to_str(new_token, Special::Tokenize) + .unwrap_or_default(); + + accumulated_text.push_str(&token_str); token_count += 1; - // Check if we've hit a complete stop sequence + // Check for stop sequences let mut hit_stop = false; for stop_seq in &stop_sequences { if accumulated_text.contains(stop_seq) { - debug!("Hit complete stop sequence in streaming: {}", stop_seq); + debug!("Hit stop sequence '{}' at {} tokens", stop_seq, token_count); hit_stop = true; + stop_reason = Some("stop_sequence".to_string()); break; } } if hit_stop { - // Before stopping, check if there might be an incomplete tool call - // Look for JSON tool call patterns that might be cut off by the stop sequence - let has_potential_tool_call = accumulated_text.contains(r#"{"tool":"#) - || accumulated_text.contains(r#"{"{""tool"":"#) - || accumulated_text.contains(r#"{{""tool"":"#); - - if has_potential_tool_call { - // Check if the tool call appears to be complete (has closing brace after the stop sequence) - let mut complete_tool_call = false; - for stop_seq in &stop_sequences { - if let Some(stop_pos) = accumulated_text.find(stop_seq) { - // Look for tool call pattern before the stop sequence - let before_stop = &accumulated_text[..stop_pos]; - if let Some(tool_start) = before_stop.rfind(r#"{"tool":"#) { - let tool_part = &before_stop[tool_start..]; - // Count braces to see if JSON is complete - let open_braces = tool_part.matches('{').count(); - let close_braces = tool_part.matches('}').count(); - if open_braces > 0 && open_braces == close_braces { - complete_tool_call = true; - break; - } - } - } - } - - // If tool call is incomplete, send the raw content including stop sequences - // so the main parser can handle it properly - if !complete_tool_call { - debug!("Found incomplete tool call, sending raw content with stop sequences"); - let already_sent_len = accumulated_text.len() - unsent_tokens.len(); - if accumulated_text.len() > already_sent_len { - let remaining_to_send = &accumulated_text[already_sent_len..]; - if !remaining_to_send.is_empty() { - let chunk = make_text_chunk(remaining_to_send.to_string()); - let _ = tx.blocking_send(Ok(chunk)); - } - } - break; - } - } - - // Send any remaining clean content before stopping (original behavior) - let mut clean_accumulated = accumulated_text.clone(); + // Send any remaining clean content + let mut clean_text = accumulated_text.clone(); for stop_seq in &stop_sequences { - if let Some(pos) = clean_accumulated.find(stop_seq) { - clean_accumulated.truncate(pos); + if let Some(pos) = clean_text.find(stop_seq) { + clean_text.truncate(pos); break; } } - - // Calculate what part we haven't sent yet - let already_sent_len = accumulated_text.len() - unsent_tokens.len(); - if clean_accumulated.len() > already_sent_len { - let remaining_to_send = &clean_accumulated[already_sent_len..]; - if !remaining_to_send.is_empty() { - let chunk = make_text_chunk(remaining_to_send.to_string()); - let _ = tx.blocking_send(Ok(chunk)); - } - } + // We've been sending incrementally, so just break break; } - // Check if we're building towards a stop sequence - let mut might_be_stop = false; - for stop_seq in &stop_sequences { - for i in 1..stop_seq.len() { - let partial = &stop_seq[..i]; - if accumulated_text.ends_with(partial) { - debug!("Detected potential partial stop sequence: '{}'", partial); - might_be_stop = true; - break; - } - } - if might_be_stop { - break; - } + // Send the token + let chunk = make_text_chunk(token_str); + if tx.blocking_send(Ok(chunk)).is_err() { + break; } - if might_be_stop { - // Hold back tokens, but only for a limited buffer size - if unsent_tokens.len() > 20 { - // Don't hold back more than 20 characters - // Send the oldest part and keep only the recent part that might be a stop sequence - let to_send = &unsent_tokens[..unsent_tokens.len() - 10]; - if !to_send.is_empty() { - let chunk = make_text_chunk(to_send.to_string()); - if tx.blocking_send(Ok(chunk)).is_err() { - break; - } - } - unsent_tokens = unsent_tokens[unsent_tokens.len() - 10..].to_string(); - } - // Continue to next token without sending - } else { - // No potential stop sequence, send all unsent tokens - if !unsent_tokens.is_empty() { - let chunk = make_text_chunk(unsent_tokens.clone()); - if tx.blocking_send(Ok(chunk)).is_err() { - break; - } - unsent_tokens.clear(); - } + // Check token limit + if token_count >= max_tokens { + debug!("Reached max token limit: {}", max_tokens); + stop_reason = Some("max_tokens".to_string()); + break; } - // Enforce token limit - if token_count >= max_tokens as usize { - debug!("Reached max token limit in streaming: {}", max_tokens); + // Prepare next batch + batch.clear(); + if let Err(e) = batch.add(new_token, n_cur, &[0], true) { + error!("Failed to add token to batch: {:?}", e); + break; + } + n_cur += 1; + + if let Err(e) = ctx.decode(&mut batch) { + error!("Failed to decode: {:?}", e); break; } } - // Send final chunk - let final_chunk = make_final_chunk(vec![], None); + // If no stop reason set, it was end_turn (natural completion) + if stop_reason.is_none() { + stop_reason = Some("end_turn".to_string()); + } + + // Send final chunk with usage information + let usage = Usage { + prompt_tokens, + completion_tokens: token_count, + total_tokens: prompt_tokens + token_count, + cache_creation_tokens: 0, + cache_read_tokens: 0, + }; + let final_chunk = make_final_chunk_with_reason(vec![], Some(usage), stop_reason); let _ = tx.blocking_send(Ok(final_chunk)); }); @@ -789,7 +669,7 @@ impl LLMProvider for EmbeddedProvider { } fn name(&self) -> &str { - "embedded" + &self.name } fn model(&self) -> &str { @@ -797,10 +677,240 @@ impl LLMProvider for EmbeddedProvider { } fn max_tokens(&self) -> u32 { - self.max_tokens + self.effective_max_tokens() } fn temperature(&self) -> f32 { self.temperature } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_format_glm4_messages() { + let messages = vec![ + Message::new(MessageRole::System, "You are a helpful assistant.".to_string()), + Message::new(MessageRole::User, "Hello!".to_string()), + ]; + + let formatted = format_glm4_messages_standalone(&messages); + + assert!(formatted.starts_with("[gMASK]")); + assert!(formatted.contains("<|system|>\nYou are a helpful assistant.")); + assert!(formatted.contains("<|user|>\nHello!")); + assert!(formatted.ends_with("<|assistant|>\n")); + } + + #[test] + fn test_format_qwen_messages() { + let messages = vec![ + Message::new(MessageRole::System, "You are a helpful assistant.".to_string()), + Message::new(MessageRole::User, "Hello!".to_string()), + ]; + + let formatted = format_qwen_messages_standalone(&messages); + + assert!(formatted.contains("<|im_start|>system\nYou are a helpful assistant.<|im_end|>")); + assert!(formatted.contains("<|im_start|>user\nHello!<|im_end|>")); + assert!(formatted.ends_with("<|im_start|>assistant\n")); + } + + #[test] + fn test_format_mistral_messages() { + let messages = vec![ + Message::new(MessageRole::System, "You are a helpful assistant.".to_string()), + Message::new(MessageRole::User, "Hello!".to_string()), + ]; + + let formatted = format_mistral_messages_standalone(&messages); + + assert!(formatted.starts_with("[INST] ")); + assert!(formatted.contains("You are a helpful assistant.")); + assert!(formatted.contains("Hello!")); + assert!(formatted.contains("[/INST]")); + } + + #[test] + fn test_format_llama_messages() { + let messages = vec![ + Message::new(MessageRole::System, "You are a helpful assistant.".to_string()), + Message::new(MessageRole::User, "Hello!".to_string()), + ]; + + let formatted = format_llama_messages_standalone(&messages); + + assert!(formatted.contains("<>")); + assert!(formatted.contains("You are a helpful assistant.")); + assert!(formatted.contains("<>")); + assert!(formatted.contains("Hello!")); + assert!(formatted.contains("[/INST]")); + } + + #[test] + fn test_glm4_stop_sequences() { + let stop_seqs = get_stop_sequences_for_model_type("glm4"); + + assert!(stop_seqs.contains(&"<|endoftext|>")); + assert!(stop_seqs.contains(&"<|user|>")); + assert!(stop_seqs.contains(&"<|observation|>")); + assert!(stop_seqs.contains(&"<|system|>")); + } + + #[test] + fn test_qwen_stop_sequences() { + let stop_seqs = get_stop_sequences_for_model_type("qwen"); + + assert!(stop_seqs.contains(&"<|im_end|>")); + assert!(stop_seqs.contains(&"<|endoftext|>")); + assert!(stop_seqs.contains(&"<|im_start|>")); + } + + #[test] + fn test_glm4_multi_turn_conversation() { + let messages = vec![ + Message::new(MessageRole::System, "You are a coding assistant.".to_string()), + Message::new(MessageRole::User, "Write a hello world in Python.".to_string()), + Message::new(MessageRole::Assistant, "print('Hello, World!')".to_string()), + Message::new(MessageRole::User, "Now in Rust.".to_string()), + ]; + + let formatted = format_glm4_messages_standalone(&messages); + + // Verify all parts are present in order + let system_pos = formatted.find("<|system|>").unwrap(); + let user1_pos = formatted.find("<|user|>\nWrite a hello world").unwrap(); + let assistant_pos = formatted.find("<|assistant|>\nprint").unwrap(); + let user2_pos = formatted.find("<|user|>\nNow in Rust").unwrap(); + let final_assistant_pos = formatted.rfind("<|assistant|>\n").unwrap(); + + assert!(system_pos < user1_pos); + assert!(user1_pos < assistant_pos); + assert!(assistant_pos < user2_pos); + assert!(user2_pos < final_assistant_pos); + } + + // Standalone formatting functions for testing without needing a full provider + fn format_glm4_messages_standalone(messages: &[Message]) -> String { + let mut formatted = String::from("[gMASK]"); + for message in messages { + let role = match message.role { + MessageRole::System => "<|system|>", + MessageRole::User => "<|user|>", + MessageRole::Assistant => "<|assistant|>", + }; + formatted.push_str(&format!("{}\n{}", role, message.content)); + } + formatted.push_str("<|assistant|>\n"); + formatted + } + + fn format_qwen_messages_standalone(messages: &[Message]) -> String { + let mut formatted = String::new(); + for message in messages { + let role = match message.role { + MessageRole::System => "system", + MessageRole::User => "user", + MessageRole::Assistant => "assistant", + }; + formatted.push_str(&format!( + "<|im_start|>{}\n{}<|im_end|>\n", + role, message.content + )); + } + formatted.push_str("<|im_start|>assistant\n"); + formatted + } + + fn format_mistral_messages_standalone(messages: &[Message]) -> String { + let mut formatted = String::new(); + let mut in_conversation = false; + for (i, message) in messages.iter().enumerate() { + match message.role { + MessageRole::System => { + if i == 0 { + formatted.push_str("[INST] "); + formatted.push_str(&message.content); + formatted.push_str("\n\n"); + in_conversation = true; + } + } + MessageRole::User => { + if !in_conversation { + formatted.push_str("[INST] "); + } + formatted.push_str(&message.content); + formatted.push_str(" [/INST]"); + in_conversation = false; + } + MessageRole::Assistant => { + formatted.push(' '); + formatted.push_str(&message.content); + formatted.push_str(" "); + in_conversation = false; + } + } + } + if messages.last().is_some_and(|m| matches!(m.role, MessageRole::User)) { + formatted.push(' '); + } + formatted + } + + fn format_llama_messages_standalone(messages: &[Message]) -> String { + let mut formatted = String::new(); + for message in messages { + match message.role { + MessageRole::System => { + formatted.push_str(&format!( + "[INST] <>\n{}\n<>\n\n", + message.content + )); + } + MessageRole::User => { + formatted.push_str(&format!("{} [/INST] ", message.content)); + } + MessageRole::Assistant => { + formatted.push_str(&format!("{} [INST] ", message.content)); + } + } + } + formatted + } + + fn get_stop_sequences_for_model_type(model_type: &str) -> Vec<&'static str> { + if model_type.contains("glm") { + vec![ + "<|endoftext|>", + "<|user|>", + "<|observation|>", + "<|system|>", + ] + } else if model_type.contains("qwen") { + vec![ + "<|im_end|>", + "<|endoftext|>", + "", + "<|im_start|>", + ] + } else if model_type.contains("mistral") { + vec![ + "", + "[/INST]", + "<|im_end|>", + ] + } else { + vec![ + "", + "<|endoftext|>", + "<|im_end|>", + "### Human:", + "### Assistant:", + "[/INST]", + "<>", + ] + } + } +}