embedded model support

Dhanji Prasanna
2025-09-06 13:32:37 +10:00
parent 80e5178a1f
commit 1834b8946c
8 changed files with 793 additions and 14 deletions

Cargo.lock (generated), 270 lines changed

@@ -144,6 +144,29 @@ version = "0.21.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
[[package]]
name = "bindgen"
version = "0.69.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
dependencies = [
"bitflags 2.9.4",
"cexpr",
"clang-sys",
"itertools",
"lazy_static",
"lazycell",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash",
"shlex",
"syn",
"which",
]
[[package]]
name = "bitflags"
version = "1.3.2"
@@ -187,15 +210,37 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "590f9024a68a8c40351881787f1934dc11afd69090f5edb6831464694d836ea3"
dependencies = [
"find-msvc-tools",
"jobserver",
"libc",
"shlex",
]
[[package]]
name = "cexpr"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [
"nom",
]
[[package]]
name = "cfg-if"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9"
[[package]]
name = "clang-sys"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
dependencies = [
"glob",
"libc",
"libloading",
]
[[package]]
name = "clap"
version = "4.5.47"
@@ -249,7 +294,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68578f196d2a33ff61b27fae256c3164f65e36382648e30666dde05b8cc9dfdf"
dependencies = [
"async-trait",
"convert_case",
"convert_case 0.6.0",
"json5",
"nom",
"pathdiff",
@@ -281,6 +326,12 @@ dependencies = [
"tiny-keccak",
]
[[package]]
name = "convert_case"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e"
[[package]]
name = "convert_case"
version = "0.6.0"
@@ -331,6 +382,19 @@ dependencies = [
"typenum",
]
[[package]]
name = "derive_more"
version = "0.99.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6edb4b64a43d977b8e99788fe3a04d483834fba1215a7e02caa415b626497f7f"
dependencies = [
"convert_case 0.4.0",
"proc-macro2",
"quote",
"rustc_version",
"syn",
]
[[package]]
name = "digest"
version = "0.10.7"
@@ -382,6 +446,12 @@ dependencies = [
"const-random",
]
[[package]]
name = "either"
version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
[[package]]
name = "encoding_rs"
version = "0.8.35"
@@ -449,6 +519,21 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "futures"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.31"
@@ -456,6 +541,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10"
dependencies = [
"futures-core",
"futures-sink",
]
[[package]]
@@ -464,6 +550,17 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
[[package]]
name = "futures-executor"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-io"
version = "0.3.31"
@@ -499,6 +596,7 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
@@ -555,9 +653,11 @@ dependencies = [
"g3-config",
"g3-execution",
"g3-providers",
"llama_cpp",
"reqwest",
"serde",
"serde_json",
"shellexpand",
"thiserror 1.0.69",
"tokio",
"tokio-stream",
@@ -631,6 +731,12 @@ version = "0.31.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]]
name = "glob"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
[[package]]
name = "h2"
version = "0.3.27"
@@ -681,6 +787,21 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hermit-abi"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
[[package]]
name = "home"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "http"
version = "0.2.12"
@@ -892,12 +1013,31 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "itertools"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
[[package]]
name = "jobserver"
version = "0.1.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33"
dependencies = [
"getrandom 0.3.3",
"libc",
]
[[package]]
name = "js-sys"
version = "0.3.78"
@@ -925,12 +1065,28 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "libc"
version = "0.2.175"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"
[[package]]
name = "libloading"
version = "0.8.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667"
dependencies = [
"cfg-if",
"windows-targets 0.53.3",
]
[[package]]
name = "libredox"
version = "0.1.9"
@@ -941,6 +1097,21 @@ dependencies = [
"libc",
]
[[package]]
name = "link-cplusplus"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c349c75e1ab4a03bd6b33fe6cbd3c479c5dd443e44ad732664d72cb0e755475"
dependencies = [
"cc",
]
[[package]]
name = "linux-raw-sys"
version = "0.4.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
[[package]]
name = "linux-raw-sys"
version = "0.9.4"
@@ -953,6 +1124,33 @@ version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956"
[[package]]
name = "llama_cpp"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f126770a2ed5e0e4596119479dc56f56b99037246bf0e36c544f7581a9458fd"
dependencies = [
"derive_more",
"futures",
"llama_cpp_sys",
"num_cpus",
"thiserror 1.0.69",
"tokio",
"tracing",
]
[[package]]
name = "llama_cpp_sys"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "037a1881ada3592c6a922224d5177b4b4f452e6b2979eb97393b71989e48357f"
dependencies = [
"bindgen",
"cc",
"link-cplusplus",
"once_cell",
]
[[package]]
name = "lock_api"
version = "0.4.13"
@@ -1043,6 +1241,16 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "num_cpus"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b"
dependencies = [
"hermit-abi",
"libc",
]
[[package]]
name = "object"
version = "0.36.7"
@@ -1230,6 +1438,16 @@ dependencies = [
"zerovec",
]
[[package]]
name = "prettyplease"
version = "0.2.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b"
dependencies = [
"proc-macro2",
"syn",
]
[[package]]
name = "proc-macro2"
version = "1.0.101"
@@ -1373,6 +1591,34 @@ version = "0.1.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace"
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc_version"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92"
dependencies = [
"semver",
]
[[package]]
name = "rustix"
version = "0.38.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
dependencies = [
"bitflags 2.9.4",
"errno",
"libc",
"linux-raw-sys 0.4.15",
"windows-sys 0.59.0",
]
[[package]]
name = "rustix"
version = "1.0.8"
@@ -1382,7 +1628,7 @@ dependencies = [
"bitflags 2.9.4",
"errno",
"libc",
"linux-raw-sys",
"linux-raw-sys 0.9.4",
"windows-sys 0.60.2",
]
@@ -1445,6 +1691,12 @@ dependencies = [
"libc",
]
[[package]]
name = "semver"
version = "1.0.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0"
[[package]]
name = "serde"
version = "1.0.219"
@@ -1644,7 +1896,7 @@ dependencies = [
"fastrand",
"getrandom 0.3.3",
"once_cell",
"rustix",
"rustix 1.0.8",
"windows-sys 0.60.2",
]
@@ -2087,6 +2339,18 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "which"
version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
dependencies = [
"either",
"home",
"once_cell",
"rustix 0.38.44",
]
[[package]]
name = "windows-link"
version = "0.1.3"


@@ -21,9 +21,9 @@ G3 is a **code-first AI agent** that helps you complete tasks by writing and exe
│ │ │ │ │ │
│ - Task commands │◄──►│ - Task │◄──►│ - OpenAI │
│ - Interactive │ │ interpretation│ │ - Anthropic │
-│ mode │ │ - Code │ │ - Local models │
-│ - Code exec │ │ generation │ │ - Custom APIs │
-│ approval │ │ - Script │ │ │
+│ mode │ │ - Code │ │ - Embedded │
+│ - Code exec │ │ generation │ │ (llama.cpp) │
+│ approval │ │ - Script │ │ - Custom APIs │
│ │ │ execution │ │ │
└─────────────────┘ └─────────────────┘ └─────────────────┘
│ │ │
@@ -58,11 +58,25 @@ G3 is a **code-first AI agent** that helps you complete tasks by writing and exe
- Autonomous execution of generated code
#### 3. LLM Providers (`g3-providers`)
-- **Responsibility**: LLM communication (unchanged)
+- **Responsibility**: LLM communication and model abstraction
- **Supported Providers**:
- **OpenAI**: GPT-4, GPT-3.5-turbo via API
- **Anthropic**: Claude models via API
- **Embedded**: Local open-weights models via llama.cpp
- **Enhanced Prompts**:
- Code-first system prompts
- Language-specific generation instructions
#### 5. Embedded Provider (`g3-core/providers/embedded`) - NEW
- **Responsibility**: Local model inference using llama.cpp (see the provider trait sketch below)
- **Features**:
- GGUF model support (Llama, CodeLlama, Mistral, etc.)
- GPU acceleration via CUDA/Metal
- Configurable context length and generation parameters
- Async-compatible inference without blocking
- Thread-safe model access
- Stop sequence detection
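The embedded provider plugs into the existing `g3-providers` abstraction rather than introducing a new interface. Inferred from the `impl LLMProvider for EmbeddedProvider` added later in this commit, the trait looks roughly like the sketch below; the authoritative definition lives in `g3-providers`, and the trait bounds and error type shown here are assumptions.

```rust
use anyhow::Result;

// Sketch of the g3-providers trait, reconstructed from the embedded
// provider's impl in this commit. CompletionRequest/Response/Stream are
// the g3-providers request types imported by the new module; the
// Send + Sync bounds and anyhow::Result are assumptions, not confirmed
// by this diff.
#[async_trait::async_trait]
pub trait LLMProvider: Send + Sync {
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;
    async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream>;
    fn name(&self) -> &str;  // e.g. "embedded"
    fn model(&self) -> &str; // e.g. "embedded-codellama"
}
```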
#### 4. Execution Engine (`g3-execution`) - NEW
- **Responsibility**: Safe code execution
- **Features**:
@@ -86,8 +100,73 @@ G3 is a **code-first AI agent** that helps you complete tasks by writing and exe
## Implementation Plan
### Phase 1: Core Refactoring
1. Update CLI commands for task-oriented interface
2. Enhance system prompts for code-first approach
3. Add basic code execution capabilities
4. Update interactive mode messaging
### Phase 2: Enhanced Provider Support ✅
1. ✅ Implement embedded model provider using llama.cpp
2. ✅ Add GGUF model support for local inference
3. ✅ Configure GPU acceleration and performance optimization
4. ✅ Add comprehensive logging and debugging support
### Phase 3: Advanced Features (Future)
1. Model quantization and optimization
2. Multi-model ensemble support
3. Advanced code execution sandboxing
4. Plugin system for custom providers
5. Web interface for remote access
## Provider Comparison
| Feature | OpenAI | Anthropic | Embedded |
|---------|--------|-----------|----------|
| **Cost** | Pay per token | Pay per token | Free after download |
| **Privacy** | Data sent to API | Data sent to API | Completely local |
| **Performance** | Very fast | Very fast | Depends on hardware |
| **Model Quality** | Excellent | Excellent | Good (varies by model) |
| **Offline Support** | No | No | Yes |
| **Setup Complexity** | API key only | API key only | Model download required |
| **Hardware Requirements** | None | None | 4-16GB RAM, optional GPU |
## Configuration Examples
### Cloud-First Setup
```toml
[providers]
default_provider = "openai"
[providers.openai]
api_key = "sk-..."
model = "gpt-4"
```
### Privacy-First Setup
```toml
[providers]
default_provider = "embedded"
[providers.embedded]
model_path = "~/.cache/g3/models/codellama-7b-instruct.Q4_K_M.gguf"
model_type = "codellama"
gpu_layers = 32
```
### Hybrid Setup
```toml
[providers]
default_provider = "embedded"
# Use embedded for most tasks
[providers.embedded]
model_path = "~/.cache/g3/models/codellama-7b-instruct.Q4_K_M.gguf"
model_type = "codellama"
gpu_layers = 32
# Fallback to cloud for complex tasks
[providers.openai]
api_key = "sk-..."
model = "gpt-4"
```


@@ -20,7 +20,8 @@ G3 will write the appropriate scripts (Python, Bash, JavaScript, etc.) and can e
- **Code-First Approach**: Always tries to solve problems with executable code
- **Multi-Language Support**: Generates Python, Bash, JavaScript, Rust, and more
- **Modular Architecture**: Clean separation between CLI, core engine, and LLM providers
-- **Multiple LLM Providers**: Support for OpenAI, Anthropic, and extensible to other providers
+- **Multiple LLM Providers**: Support for OpenAI, Anthropic, and embedded open-weights models
- **Local Model Support**: Run completely offline with embedded GGUF models via llama.cpp
- **Interactive Mode**: Chat with the AI and watch it solve problems in real-time
- **Task Automation**: Create reusable automation scripts
@@ -34,6 +35,8 @@ cargo install --path .
Create a configuration file at `~/.config/g3/config.toml`:
### Cloud Providers
```toml
[providers]
default_provider = "openai"
@@ -49,7 +52,37 @@ api_key = "your-anthropic-api-key"
model = "claude-3-sonnet-20240229"
max_tokens = 2048
temperature = 0.1
```
### Local Embedded Models
For completely offline operation with open-weights models:
```toml
[providers]
default_provider = "embedded"
[providers.embedded]
# Path to your GGUF model file
model_path = "~/.cache/g3/models/codellama-7b-instruct.Q4_K_M.gguf"
model_type = "codellama"
context_length = 4096
max_tokens = 2048
temperature = 0.1
# Number of layers to offload to GPU (0 for CPU only)
gpu_layers = 32
# Number of CPU threads to use
threads = 8
```
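Note that `max_tokens` is only an upper bound. The embedded provider added in this commit also caps each response by the context headroom left after the prompt, keeping a small safety reserve, so long prompts shrink the generation budget instead of overflowing the context window. A rough illustration of that calculation, with hypothetical numbers:

```rust
fn main() {
    // Mirrors the provider's dynamic cap; the concrete numbers are made up.
    let context_length: u32 = 4096; // [providers.embedded] context_length
    let max_tokens: u32 = 2048;     // [providers.embedded] max_tokens
    let prompt_tokens: u32 = 3000;  // estimated at roughly 4 characters per token
    let available = context_length
        .saturating_sub(prompt_tokens)
        .saturating_sub(50); // 50-token safety reserve
    let effective = max_tokens.min(available);
    println!("effective generation budget: {effective} tokens"); // min(2048, 1046) = 1046
}
```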
**Getting Models**: Download GGUF models from [Hugging Face](https://huggingface.co/models?library=gguf) (search for "GGUF"). Popular options:
- [CodeLlama 7B Instruct](https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GGUF)
- [Llama 2 7B Chat](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF)
- [Mistral 7B Instruct](https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF)
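Whichever model you download, `model_path` may use a leading `~`; the provider expands it and refuses to start if the file is missing. A simplified sketch of that path handling, mirroring the checks the embedded provider performs at load time (the function name is illustrative, not part of the g3 API):

```rust
use std::path::{Path, PathBuf};

// Expand `~` in model_path and fail early if the file does not exist.
fn resolve_model_path(model_path: &str) -> anyhow::Result<PathBuf> {
    let expanded = shellexpand::tilde(model_path);
    let path = Path::new(expanded.as_ref());
    if !path.exists() {
        anyhow::bail!("Model file not found: {}", path.display());
    }
    Ok(path.to_path_buf())
}
```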
### Agent Settings
```toml
[agent]
max_context_length = 8192
enable_streaming = true


@@ -2,7 +2,7 @@
# Copy to ~/.config/g3/config.toml and customize
[providers]
default_provider = "openai"
default_provider = "embedded"
[providers.openai]
# Get your API key from https://platform.openai.com/api-keys
@@ -20,6 +20,18 @@ model = "claude-3-sonnet-20240229"
max_tokens = 2048
temperature = 0.1
[providers.embedded]
# Path to your GGUF model file
model_path = "~/.cache/g3/models/codellama-7b-instruct.Q4_K_M.gguf"
model_type = "codellama"
context_length = 16384 # Use CodeLlama's full context capability
max_tokens = 2048 # Default fallback, but will be calculated dynamically
temperature = 0.1
# Number of layers to offload to GPU (0 for CPU only)
gpu_layers = 32
# Number of CPU threads to use
threads = 8
[agent]
max_context_length = 8192
enable_streaming = true


@@ -12,6 +12,7 @@ pub struct Config {
pub struct ProvidersConfig {
pub openai: Option<OpenAIConfig>,
pub anthropic: Option<AnthropicConfig>,
pub embedded: Option<EmbeddedConfig>,
pub default_provider: String,
}
@@ -32,6 +33,17 @@ pub struct AnthropicConfig {
pub temperature: Option<f32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddedConfig {
pub model_path: String,
pub model_type: String, // e.g., "llama", "mistral", "codellama"
pub context_length: Option<u32>,
pub max_tokens: Option<u32>,
pub temperature: Option<f32>,
pub gpu_layers: Option<u32>, // Number of layers to offload to GPU
pub threads: Option<u32>, // Number of CPU threads to use
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
pub max_context_length: usize,
@@ -45,6 +57,7 @@ impl Default for Config {
providers: ProvidersConfig {
openai: None,
anthropic: None,
embedded: None,
default_provider: "openai".to_string(),
},
agent: AgentConfig {


@@ -18,3 +18,5 @@ serde_json = { workspace = true }
uuid = { workspace = true }
async-trait = "0.1"
tokio-stream = "0.1"
llama_cpp = { version = "0.3.2", features = ["metal"] }
shellexpand = "3.1"


@@ -72,6 +72,19 @@ impl Agent {
providers.register(anthropic_provider);
}
if let Some(embedded_config) = &config.providers.embedded {
let embedded_provider = crate::providers::embedded::EmbeddedProvider::new(
embedded_config.model_path.clone(),
embedded_config.model_type.clone(),
embedded_config.context_length,
embedded_config.max_tokens,
embedded_config.temperature,
embedded_config.gpu_layers,
embedded_config.threads,
)?;
providers.register(embedded_provider);
}
// Set default provider
providers.set_default(&config.providers.default_provider)?;
@@ -522,4 +535,5 @@ impl std::fmt::Display for AnalysisResult {
pub mod providers {
pub mod anthropic;
pub mod openai;
pub mod embedded;
}


@@ -0,0 +1,362 @@
use g3_providers::{LLMProvider, CompletionRequest, CompletionResponse, CompletionStream, CompletionChunk, Usage, Message, MessageRole};
use anyhow::Result;
use llama_cpp::{LlamaModel, LlamaSession, LlamaParams, SessionParams, standard_sampler::{StandardSampler, SamplerStage}};
use std::path::Path;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{debug, info, error, warn};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use std::sync::atomic::{AtomicBool, Ordering};
pub struct EmbeddedProvider {
model: Arc<LlamaModel>,
session: Arc<Mutex<LlamaSession>>,
model_name: String,
max_tokens: u32,
temperature: f32,
context_length: u32,
generation_active: Arc<AtomicBool>,
}
impl EmbeddedProvider {
pub fn new(
model_path: String,
model_type: String,
context_length: Option<u32>,
max_tokens: Option<u32>,
temperature: Option<f32>,
gpu_layers: Option<u32>,
threads: Option<u32>,
) -> Result<Self> {
info!("Loading embedded model from: {}", model_path);
// Expand tilde in path
let expanded_path = shellexpand::tilde(&model_path);
let model_path = Path::new(expanded_path.as_ref());
if !model_path.exists() {
anyhow::bail!("Model file not found: {}", model_path.display());
}
// Set up model parameters
let mut params = LlamaParams::default();
if let Some(gpu_layers) = gpu_layers {
params.n_gpu_layers = gpu_layers;
info!("Using {} GPU layers", gpu_layers);
}
let context_size = context_length.unwrap_or(4096);
info!("Using context length: {}", context_size);
// Load the model
info!("Loading model...");
let model = LlamaModel::load_from_file(model_path, params)
.map_err(|e| anyhow::anyhow!("Failed to load model: {}", e))?;
// Create session with parameters
let mut session_params = SessionParams::default();
session_params.n_ctx = context_size;
if let Some(threads) = threads {
session_params.n_threads = threads;
}
let session = model.create_session(session_params)
.map_err(|e| anyhow::anyhow!("Failed to create session: {}", e))?;
info!("Successfully loaded {} model", model_type);
Ok(Self {
model: Arc::new(model),
session: Arc::new(Mutex::new(session)),
model_name: format!("embedded-{}", model_type),
max_tokens: max_tokens.unwrap_or(2048),
temperature: temperature.unwrap_or(0.1),
context_length: context_size,
generation_active: Arc::new(AtomicBool::new(false)),
})
}
fn format_messages(&self, messages: &[Message]) -> String {
// Use proper prompt format for CodeLlama
let mut formatted = String::new();
for message in messages {
match message.role {
MessageRole::System => {
formatted.push_str(&format!("[INST] <<SYS>>\n{}\n<</SYS>>\n\n", message.content));
}
MessageRole::User => {
formatted.push_str(&format!("{} [/INST] ", message.content));
}
MessageRole::Assistant => {
formatted.push_str(&format!("{} </s><s>[INST] ", message.content));
}
}
}
formatted
}
async fn generate_completion(&self, prompt: &str, max_tokens: u32, temperature: f32) -> Result<String> {
let session = self.session.clone();
let prompt = prompt.to_string();
// Calculate dynamic max tokens based on available context headroom
let prompt_tokens = self.estimate_tokens(&prompt);
let available_tokens = self.context_length.saturating_sub(prompt_tokens).saturating_sub(50); // Reserve 50 tokens for safety
let dynamic_max_tokens = std::cmp::min(max_tokens as usize, available_tokens as usize);
debug!("Context calculation: prompt_tokens={}, context_length={}, available_tokens={}, dynamic_max_tokens={}",
prompt_tokens, self.context_length, available_tokens, dynamic_max_tokens);
// Add timeout to the entire operation
let timeout_duration = std::time::Duration::from_secs(30); // Increased timeout for larger contexts
let result = tokio::time::timeout(timeout_duration, tokio::task::spawn_blocking(move || {
let mut session = match session.try_lock() {
Ok(ctx) => ctx,
Err(_) => return Err(anyhow::anyhow!("Model is busy, please try again")),
};
debug!("Starting inference with prompt length: {} chars, estimated {} tokens", prompt.len(), prompt_tokens);
// Set context to the prompt
debug!("About to call set_context...");
session.set_context(&prompt)
.map_err(|e| anyhow::anyhow!("Failed to set context: {}", e))?;
debug!("set_context completed successfully");
// Create sampler with temperature
debug!("Creating sampler...");
let stages = vec![
SamplerStage::Temperature(temperature),
SamplerStage::TopK(40),
SamplerStage::TopP(0.9),
];
let sampler = StandardSampler::new_softmax(stages, 1);
debug!("Sampler created successfully");
// Start completion with dynamic max tokens
debug!("About to call start_completing_with with {} max tokens...", dynamic_max_tokens);
let mut completion_handle = session.start_completing_with(sampler, dynamic_max_tokens)
.map_err(|e| anyhow::anyhow!("Failed to start completion: {}", e))?;
debug!("start_completing_with completed successfully");
let mut generated_text = String::new();
let mut token_count = 0;
let start_time = std::time::Instant::now();
debug!("Starting token generation loop...");
// Generate tokens with dynamic limits
while let Some(token) = completion_handle.next_token() {
// Check for timeout on each token
if start_time.elapsed() > std::time::Duration::from_secs(25) {
debug!("Token generation timeout after {} tokens", token_count);
break;
}
let token_string = session.model().token_to_piece(token);
generated_text.push_str(&token_string);
token_count += 1;
if token_count <= 10 || token_count % 50 == 0 {
debug!("Generated token {}: '{}'", token_count, token_string);
}
// Use dynamic token limit
if token_count >= dynamic_max_tokens {
debug!("Reached dynamic token limit: {}", dynamic_max_tokens);
break;
}
// Stop on completion markers
if generated_text.contains("</s>") || generated_text.contains("[/INST]") {
debug!("Hit CodeLlama stop sequence at {} tokens", token_count);
break;
}
// Stop on natural completion points after reasonable generation
if token_count >= 20 && (
generated_text.trim().ends_with("```") ||
(generated_text.contains("```") && generated_text.matches("```").count() % 2 == 0) // Complete code blocks
) {
debug!("Hit code block completion at {} tokens", token_count);
break;
}
}
debug!("Token generation loop completed. Generated {} tokens in {:?}", token_count, start_time.elapsed());
Ok((generated_text.trim().to_string(), token_count))
})).await;
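// `result` is nested three deep: the outer layer is the 30s timeout, the middle
// layer is the spawn_blocking join, and the innermost is the task's own Result.
// The match below unwraps them in that order.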
match result {
Ok(inner_result) => match inner_result {
Ok(task_result) => match task_result {
Ok((text, token_count)) => {
info!("Completed generation: {} tokens (dynamic limit was {})", token_count, dynamic_max_tokens);
Ok(text)
}
Err(e) => Err(e),
},
Err(e) => Err(e.into()),
},
Err(_) => {
error!("Generation timed out after 30 seconds");
Err(anyhow::anyhow!("Generation timed out"))
}
}
}
// Helper function to estimate token count from text
fn estimate_tokens(&self, text: &str) -> u32 {
// Rough estimation: average 4 characters per token
// This is conservative - actual tokenization might be different
(text.len() as f32 / 4.0).ceil() as u32
}
}
#[async_trait::async_trait]
impl LLMProvider for EmbeddedProvider {
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
debug!("Processing completion request with {} messages", request.messages.len());
let prompt = self.format_messages(&request.messages);
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);
debug!("Formatted prompt length: {} chars", prompt.len());
let content = self.generate_completion(&prompt, max_tokens, temperature).await?;
// Estimate token usage (rough approximation)
let prompt_tokens = (prompt.len() / 4) as u32; // Rough estimate: 4 chars per token
let completion_tokens = (content.len() / 4) as u32;
Ok(CompletionResponse {
content,
usage: Usage {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
},
model: self.model_name.clone(),
})
}
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
debug!("Processing streaming request with {} messages", request.messages.len());
let prompt = self.format_messages(&request.messages);
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);
let (tx, rx) = mpsc::channel(100);
let session = self.session.clone();
let prompt = prompt.to_string();
// Spawn streaming task
tokio::task::spawn_blocking(move || {
let mut session = match session.try_lock() {
Ok(ctx) => ctx,
Err(_) => {
let _ = tx.blocking_send(Err(anyhow::anyhow!("Model is busy, please try again")));
return;
}
};
// Set context to the prompt
if let Err(e) = session.set_context(&prompt) {
let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to set context: {}", e)));
return;
}
// Create sampler with temperature
let stages = vec![
SamplerStage::Temperature(temperature),
SamplerStage::TopK(40),
SamplerStage::TopP(0.9),
];
let sampler = StandardSampler::new_softmax(stages, 1);
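// Note: unlike generate_completion above, the streaming path uses the requested
// max_tokens directly instead of recomputing a context-based dynamic limit.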
// Start completion
let mut completion_handle = match session.start_completing_with(sampler, max_tokens as usize) {
Ok(handle) => handle,
Err(e) => {
let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to start completion: {}", e)));
return;
}
};
let mut accumulated_text = String::new();
let mut token_count = 0;
// Stream tokens with proper limits
while let Some(token) = completion_handle.next_token() {
let token_string = session.model().token_to_piece(token);
accumulated_text.push_str(&token_string);
token_count += 1;
let chunk = CompletionChunk {
content: token_string.clone(),
finished: false,
};
if tx.blocking_send(Ok(chunk)).is_err() {
break; // Receiver dropped
}
// Enforce token limit
if token_count >= max_tokens as usize {
debug!("Reached max token limit in streaming: {}", max_tokens);
break;
}
// Stop if we hit common stop sequences
if accumulated_text.contains("### Human") ||
accumulated_text.contains("### System") ||
accumulated_text.contains("<|end|>") ||
accumulated_text.contains("</s>") ||
accumulated_text.trim().ends_with("```") {
debug!("Hit stop sequence in streaming, stopping generation");
break;
}
// Emergency brake for streaming too
if token_count > 0 && token_count % 100 == 0 {
debug!("Streaming: Generated {} tokens so far", token_count);
if accumulated_text.trim().len() > 50 &&
(accumulated_text.contains('\n') || accumulated_text.len() > 200) {
if accumulated_text.trim().ends_with('.') ||
accumulated_text.trim().ends_with('!') ||
accumulated_text.trim().ends_with('?') ||
accumulated_text.trim().ends_with('\n') {
debug!("Found natural stopping point in streaming at {} tokens", token_count);
break;
}
}
}
}
// Send final chunk
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
};
let _ = tx.blocking_send(Ok(final_chunk));
});
Ok(ReceiverStream::new(rx))
}
fn name(&self) -> &str {
"embedded"
}
fn model(&self) -> &str {
&self.model_name
}
}