diff --git a/Cargo.lock b/Cargo.lock
index 437a81b..a684345 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -144,6 +144,29 @@ version = "0.21.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
 
+[[package]]
+name = "bindgen"
+version = "0.69.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
+dependencies = [
+ "bitflags 2.9.4",
+ "cexpr",
+ "clang-sys",
+ "itertools",
+ "lazy_static",
+ "lazycell",
+ "log",
+ "prettyplease",
+ "proc-macro2",
+ "quote",
+ "regex",
+ "rustc-hash",
+ "shlex",
+ "syn",
+ "which",
+]
+
 [[package]]
 name = "bitflags"
 version = "1.3.2"
@@ -187,15 +210,37 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "590f9024a68a8c40351881787f1934dc11afd69090f5edb6831464694d836ea3"
 dependencies = [
  "find-msvc-tools",
+ "jobserver",
+ "libc",
  "shlex",
 ]
 
+[[package]]
+name = "cexpr"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
+dependencies = [
+ "nom",
+]
+
 [[package]]
 name = "cfg-if"
 version = "1.0.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9"
 
+[[package]]
+name = "clang-sys"
+version = "1.8.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
+dependencies = [
+ "glob",
+ "libc",
+ "libloading",
+]
+
 [[package]]
 name = "clap"
 version = "4.5.47"
@@ -249,7 +294,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "68578f196d2a33ff61b27fae256c3164f65e36382648e30666dde05b8cc9dfdf"
 dependencies = [
  "async-trait",
- "convert_case",
+ "convert_case 0.6.0",
  "json5",
  "nom",
  "pathdiff",
@@ -281,6 +326,12 @@ dependencies = [
  "tiny-keccak",
 ]
 
+[[package]]
+name = "convert_case"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e"
+
 [[package]]
 name = "convert_case"
 version = "0.6.0"
@@ -331,6 +382,19 @@ dependencies = [
  "typenum",
 ]
 
+[[package]]
+name = "derive_more"
+version = "0.99.20"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6edb4b64a43d977b8e99788fe3a04d483834fba1215a7e02caa415b626497f7f"
+dependencies = [
+ "convert_case 0.4.0",
+ "proc-macro2",
+ "quote",
+ "rustc_version",
+ "syn",
+]
+
 [[package]]
 name = "digest"
 version = "0.10.7"
@@ -382,6 +446,12 @@ dependencies = [
  "const-random",
 ]
 
+[[package]]
+name = "either"
+version = "1.15.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
+
 [[package]]
 name = "encoding_rs"
 version = "0.8.35"
@@ -449,6 +519,21 @@ dependencies = [
  "percent-encoding",
 ]
 
+[[package]]
+name = "futures"
+version = "0.3.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
+dependencies = [
+ "futures-channel",
+ "futures-core",
+ "futures-executor",
+ "futures-io",
+ "futures-sink",
+ "futures-task",
+ "futures-util",
+]
+
 [[package]]
 name = "futures-channel"
 version = "0.3.31"
@@ -456,6 +541,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10"
 dependencies = [
  "futures-core",
+ "futures-sink",
 ]
 
 [[package]]
@@ -464,6 +550,17 @@ version = "0.3.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
 
+[[package]]
+name = "futures-executor"
+version = "0.3.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
+dependencies = [
+ "futures-core",
+ "futures-task",
+ "futures-util",
+]
+
 [[package]]
 name = "futures-io"
 version = "0.3.31"
@@ -499,6 +596,7 @@ version = "0.3.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
 dependencies = [
+ "futures-channel",
  "futures-core",
  "futures-io",
  "futures-macro",
@@ -555,9 +653,11 @@ dependencies = [
  "g3-config",
  "g3-execution",
  "g3-providers",
+ "llama_cpp",
  "reqwest",
  "serde",
  "serde_json",
+ "shellexpand",
  "thiserror 1.0.69",
  "tokio",
  "tokio-stream",
@@ -631,6 +731,12 @@ version = "0.31.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
 
+[[package]]
+name = "glob"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
+
 [[package]]
 name = "h2"
 version = "0.3.27"
@@ -681,6 +787,21 @@ version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
 
+[[package]]
+name = "hermit-abi"
+version = "0.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
+
+[[package]]
+name = "home"
+version = "0.5.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
+dependencies = [
+ "windows-sys 0.59.0",
+]
+
 [[package]]
 name = "http"
 version = "0.2.12"
@@ -892,12 +1013,31 @@ version = "1.70.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
 
+[[package]]
+name = "itertools"
+version = "0.12.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
+dependencies = [
+ "either",
+]
+
 [[package]]
 name = "itoa"
 version = "1.0.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
 
+[[package]]
+name = "jobserver"
+version = "0.1.34"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33"
+dependencies = [
+ "getrandom 0.3.3",
+ "libc",
+]
+
 [[package]]
 name = "js-sys"
 version = "0.3.78"
@@ -925,12 +1065,28 @@ version = "1.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
 
+[[package]]
+name = "lazycell"
+version = "1.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
+
 [[package]]
 name = "libc"
 version = "0.2.175"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"
 
+[[package]]
+name = "libloading"
+version = "0.8.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667"
+dependencies = [
+ "cfg-if",
+ "windows-targets 0.53.3",
+]
+
 [[package]]
 name = "libredox"
 version = "0.1.9"
@@ -941,6 +1097,21 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "link-cplusplus"
+version = "1.0.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8c349c75e1ab4a03bd6b33fe6cbd3c479c5dd443e44ad732664d72cb0e755475"
+dependencies = [
+ "cc",
+]
+
+[[package]]
+name = "linux-raw-sys"
+version = "0.4.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
+
 [[package]]
 name = "linux-raw-sys"
 version = "0.9.4"
@@ -953,6 +1124,33 @@ version = "0.8.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956"
 
+[[package]]
+name = "llama_cpp"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7f126770a2ed5e0e4596119479dc56f56b99037246bf0e36c544f7581a9458fd"
+dependencies = [
+ "derive_more",
+ "futures",
+ "llama_cpp_sys",
+ "num_cpus",
+ "thiserror 1.0.69",
+ "tokio",
+ "tracing",
+]
+
+[[package]]
+name = "llama_cpp_sys"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "037a1881ada3592c6a922224d5177b4b4f452e6b2979eb97393b71989e48357f"
+dependencies = [
+ "bindgen",
+ "cc",
+ "link-cplusplus",
+ "once_cell",
+]
+
 [[package]]
 name = "lock_api"
 version = "0.4.13"
@@ -1043,6 +1241,16 @@ dependencies = [
  "windows-sys 0.52.0",
 ]
 
+[[package]]
+name = "num_cpus"
+version = "1.17.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b"
+dependencies = [
+ "hermit-abi",
+ "libc",
+]
+
 [[package]]
 name = "object"
 version = "0.36.7"
@@ -1230,6 +1438,16 @@ dependencies = [
  "zerovec",
 ]
 
+[[package]]
+name = "prettyplease"
+version = "0.2.37"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b"
+dependencies = [
+ "proc-macro2",
+ "syn",
+]
+
 [[package]]
 name = "proc-macro2"
 version = "1.0.101"
@@ -1373,6 +1591,34 @@ version = "0.1.26"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace"
 
+[[package]]
+name = "rustc-hash"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
+
+[[package]]
+name = "rustc_version"
+version = "0.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92"
+dependencies = [
+ "semver",
+]
+
+[[package]]
+name = "rustix"
+version = "0.38.44"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
+dependencies = [
+ "bitflags 2.9.4",
+ "errno",
+ "libc",
+ "linux-raw-sys 0.4.15",
+ "windows-sys 0.59.0",
+]
+
 [[package]]
 name = "rustix"
 version = "1.0.8"
@@ -1382,7 +1628,7 @@ dependencies = [
"bitflags 2.9.4", "errno", "libc", - "linux-raw-sys", + "linux-raw-sys 0.9.4", "windows-sys 0.60.2", ] @@ -1445,6 +1691,12 @@ dependencies = [ "libc", ] +[[package]] +name = "semver" +version = "1.0.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" + [[package]] name = "serde" version = "1.0.219" @@ -1644,7 +1896,7 @@ dependencies = [ "fastrand", "getrandom 0.3.3", "once_cell", - "rustix", + "rustix 1.0.8", "windows-sys 0.60.2", ] @@ -2087,6 +2339,18 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix 0.38.44", +] + [[package]] name = "windows-link" version = "0.1.3" diff --git a/DESIGN.md b/DESIGN.md index 47b5896..45696ce 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -21,9 +21,9 @@ G3 is a **code-first AI agent** that helps you complete tasks by writing and exe │ │ │ │ │ │ │ - Task commands │◄──►│ - Task │◄──►│ - OpenAI │ │ - Interactive │ │ interpretation│ │ - Anthropic │ -│ mode │ │ - Code │ │ - Local models │ -│ - Code exec │ │ generation │ │ - Custom APIs │ -│ approval │ │ - Script │ │ │ +│ mode │ │ - Code │ │ - Embedded │ +│ - Code exec │ │ generation │ │ (llama.cpp) │ +│ approval │ │ - Script │ │ - Custom APIs │ │ │ │ execution │ │ │ └─────────────────┘ └─────────────────┘ └─────────────────┘ │ │ │ @@ -58,11 +58,25 @@ G3 is a **code-first AI agent** that helps you complete tasks by writing and exe - Autonomous execution of generated code #### 3. LLM Providers (`g3-providers`) -- **Responsibility**: LLM communication (unchanged) +- **Responsibility**: LLM communication and model abstraction +- **Supported Providers**: + - **OpenAI**: GPT-4, GPT-3.5-turbo via API + - **Anthropic**: Claude models via API + - **Embedded**: Local open-weights models via llama.cpp - **Enhanced Prompts**: - Code-first system prompts - Language-specific generation instructions +#### 5. Embedded Provider (`g3-core/providers/embedded`) - NEW +- **Responsibility**: Local model inference using llama.cpp +- **Features**: + - GGUF model support (Llama, CodeLlama, Mistral, etc.) + - GPU acceleration via CUDA/Metal + - Configurable context length and generation parameters + - Async-compatible inference without blocking + - Thread-safe model access + - Stop sequence detection + #### 4. Execution Engine (`g3-execution`) - NEW - **Responsibility**: Safe code execution - **Features**: @@ -86,8 +100,73 @@ G3 is a **code-first AI agent** that helps you complete tasks by writing and exe ## Implementation Plan -### Phase 1: Core Refactoring -1. Update CLI commands for task-oriented interface -2. Enhance system prompts for code-first approach -3. Add basic code execution capabilities -4. Update interactive mode messaging +### Phase 1: Core Refactoring ✅ +1. ✅ Update CLI commands for task-oriented interface +2. ✅ Enhance system prompts for code-first approach +3. ✅ Add basic code execution capabilities +4. ✅ Update interactive mode messaging + +### Phase 2: Enhanced Provider Support ✅ +1. ✅ Implement embedded model provider using llama.cpp +2. ✅ Add GGUF model support for local inference +3. ✅ Configure GPU acceleration and performance optimization +4. ✅ Add comprehensive logging and debugging support + +### Phase 3: Advanced Features (Future) +1. Model quantization and optimization +2. 
-#### 4. Execution Engine (`g3-execution`) - NEW
+#### 5. Execution Engine (`g3-execution`) - NEW
 - **Responsibility**: Safe code execution
 - **Features**:
@@ -86,8 +100,73 @@ G3 is a **code-first AI agent** that helps you complete tasks by writing and exe
 
 ## Implementation Plan
 
-### Phase 1: Core Refactoring
-1. Update CLI commands for task-oriented interface
-2. Enhance system prompts for code-first approach
-3. Add basic code execution capabilities
-4. Update interactive mode messaging
+### Phase 1: Core Refactoring ✅
+1. ✅ Update CLI commands for task-oriented interface
+2. ✅ Enhance system prompts for code-first approach
+3. ✅ Add basic code execution capabilities
+4. ✅ Update interactive mode messaging
+
+### Phase 2: Enhanced Provider Support ✅
+1. ✅ Implement embedded model provider using llama.cpp
+2. ✅ Add GGUF model support for local inference
+3. ✅ Configure GPU acceleration and performance optimization
+4. ✅ Add comprehensive logging and debugging support
+
+### Phase 3: Advanced Features (Future)
+1. Model quantization and optimization
+2. Multi-model ensemble support
+3. Advanced code execution sandboxing
+4. Plugin system for custom providers
+5. Web interface for remote access
+
+## Provider Comparison
+
+| Feature | OpenAI | Anthropic | Embedded |
+|---------|--------|-----------|----------|
+| **Cost** | Pay per token | Pay per token | Free after download |
+| **Privacy** | Data sent to API | Data sent to API | Completely local |
+| **Performance** | Very fast | Very fast | Depends on hardware |
+| **Model Quality** | Excellent | Excellent | Good (varies by model) |
+| **Offline Support** | No | No | Yes |
+| **Setup Complexity** | API key only | API key only | Model download required |
+| **Hardware Requirements** | None | None | 4-16GB RAM, optional GPU |
+
+## Configuration Examples
+
+### Cloud-First Setup
+```toml
+[providers]
+default_provider = "openai"
+
+[providers.openai]
+api_key = "sk-..."
+model = "gpt-4"
+```
+
+### Privacy-First Setup
+```toml
+[providers]
+default_provider = "embedded"
+
+[providers.embedded]
+model_path = "~/.cache/g3/models/codellama-7b-instruct.Q4_K_M.gguf"
+model_type = "codellama"
+gpu_layers = 32
+```
+
+### Hybrid Setup
+```toml
+[providers]
+default_provider = "embedded"
+
+# Use embedded for most tasks
+[providers.embedded]
+model_path = "~/.cache/g3/models/codellama-7b-instruct.Q4_K_M.gguf"
+model_type = "codellama"
+gpu_layers = 32
+
+# Fallback to cloud for complex tasks
+[providers.openai]
+api_key = "sk-..."
+model = "gpt-4"
+```
diff --git a/README.md b/README.md
index 7314828..aef8af3 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,8 @@ G3 will write the appropriate scripts (Python, Bash, JavaScript, etc.) and can e
 - **Code-First Approach**: Always tries to solve problems with executable code
 - **Multi-Language Support**: Generates Python, Bash, JavaScript, Rust, and more
 - **Modular Architecture**: Clean separation between CLI, core engine, and LLM providers
-- **Multiple LLM Providers**: Support for OpenAI, Anthropic, and extensible to other providers
+- **Multiple LLM Providers**: Support for OpenAI, Anthropic, and embedded open-weights models
+- **Local Model Support**: Run completely offline with embedded GGUF models via llama.cpp
 - **Interactive Mode**: Chat with the AI and watch it solve problems in real-time
 - **Task Automation**: Create reusable automation scripts
 
@@ -34,6 +35,8 @@ cargo install --path .
 
 Create a configuration file at `~/.config/g3/config.toml`:
 
+### Cloud Providers
+
 ```toml
 [providers]
 default_provider = "openai"
@@ -49,7 +52,37 @@ api_key = "your-anthropic-api-key"
 model = "claude-3-sonnet-20240229"
 max_tokens = 2048
 temperature = 0.1
+```
+### Local Embedded Models
+
+For completely offline operation with open-weights models:
+
+```toml
+[providers]
+default_provider = "embedded"
+
+[providers.embedded]
+# Path to your GGUF model file
+model_path = "~/.cache/g3/models/codellama-7b-instruct.Q4_K_M.gguf"
+model_type = "codellama"
+context_length = 4096
+max_tokens = 2048
+temperature = 0.1
+# Number of layers to offload to GPU (0 for CPU only)
+gpu_layers = 32
+# Number of CPU threads to use
+threads = 8
+```
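+
+The `model_path` value may begin with `~`; it is expanded to your home directory (via `shellexpand`) when the model is loaded.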
+
+**Getting Models**: Download GGUF models from [Hugging Face](https://huggingface.co/models?library=gguf) (search for "GGUF"). Popular options:
+- [CodeLlama 7B Instruct](https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GGUF)
+- [Llama 2 7B Chat](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF)
+- [Mistral 7B Instruct](https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF)
+
+### Agent Settings
+
+```toml
 [agent]
 max_context_length = 8192
 enable_streaming = true
diff --git a/config.example.toml b/config.example.toml
index 2acb512..cbc9409 100644
--- a/config.example.toml
+++ b/config.example.toml
@@ -2,7 +2,7 @@
 # Copy to ~/.config/g3/config.toml and customize
 
 [providers]
-default_provider = "openai"
+default_provider = "embedded"
 
 [providers.openai]
 # Get your API key from https://platform.openai.com/api-keys
@@ -20,6 +20,18 @@ model = "claude-3-sonnet-20240229"
 max_tokens = 2048
 temperature = 0.1
 
+[providers.embedded]
+# Path to your GGUF model file
+model_path = "~/.cache/g3/models/codellama-7b-instruct.Q4_K_M.gguf"
+model_type = "codellama"
+context_length = 16384 # Use CodeLlama's full context capability
+max_tokens = 2048 # Default fallback, but will be calculated dynamically
+temperature = 0.1
+# Number of layers to offload to GPU (0 for CPU only)
+gpu_layers = 32
+# Number of CPU threads to use
+threads = 8
+
 [agent]
 max_context_length = 8192
 enable_streaming = true
diff --git a/crates/g3-config/src/lib.rs b/crates/g3-config/src/lib.rs
index ad16093..5951985 100644
--- a/crates/g3-config/src/lib.rs
+++ b/crates/g3-config/src/lib.rs
@@ -12,6 +12,7 @@ pub struct Config {
 pub struct ProvidersConfig {
     pub openai: Option<OpenAIConfig>,
     pub anthropic: Option<AnthropicConfig>,
+    pub embedded: Option<EmbeddedConfig>,
     pub default_provider: String,
 }
@@ -32,6 +33,17 @@ pub struct AnthropicConfig {
     pub temperature: Option<f32>,
 }
 
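+/// Configuration for the embedded llama.cpp provider. Example (matching
+/// config.example.toml):
+///
+/// ```toml
+/// [providers.embedded]
+/// model_path = "~/.cache/g3/models/codellama-7b-instruct.Q4_K_M.gguf"
+/// model_type = "codellama"
+/// gpu_layers = 32
+/// ```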
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct EmbeddedConfig {
+    pub model_path: String,
+    pub model_type: String, // e.g., "llama", "mistral", "codellama"
+    pub context_length: Option<u32>,
+    pub max_tokens: Option<u32>,
+    pub temperature: Option<f32>,
+    pub gpu_layers: Option<u32>, // Number of layers to offload to GPU
+    pub threads: Option<u32>,    // Number of CPU threads to use
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct AgentConfig {
     pub max_context_length: usize,
@@ -45,6 +57,7 @@ providers: ProvidersConfig {
                 openai: None,
                 anthropic: None,
+                embedded: None,
                 default_provider: "openai".to_string(),
             },
             agent: AgentConfig {
diff --git a/crates/g3-core/Cargo.toml b/crates/g3-core/Cargo.toml
index fa9fe3a..eb6aabf 100644
--- a/crates/g3-core/Cargo.toml
+++ b/crates/g3-core/Cargo.toml
@@ -18,3 +18,5 @@ serde_json = { workspace = true }
 uuid = { workspace = true }
 async-trait = "0.1"
 tokio-stream = "0.1"
+llama_cpp = { version = "0.3.2", features = ["metal"] }
+shellexpand = "3.1"
diff --git a/crates/g3-core/src/lib.rs b/crates/g3-core/src/lib.rs
index 8acd303..03b1b33 100644
--- a/crates/g3-core/src/lib.rs
+++ b/crates/g3-core/src/lib.rs
@@ -72,6 +72,19 @@ impl Agent {
             providers.register(anthropic_provider);
         }
 
+        if let Some(embedded_config) = &config.providers.embedded {
+            let embedded_provider = crate::providers::embedded::EmbeddedProvider::new(
+                embedded_config.model_path.clone(),
+                embedded_config.model_type.clone(),
+                embedded_config.context_length,
+                embedded_config.max_tokens,
+                embedded_config.temperature,
+                embedded_config.gpu_layers,
+                embedded_config.threads,
+            )?;
+            providers.register(embedded_provider);
+        }
+
         // Set default provider
         providers.set_default(&config.providers.default_provider)?;
 
@@ -522,4 +535,5 @@ impl std::fmt::Display for AnalysisResult {
 pub mod providers {
     pub mod anthropic;
     pub mod openai;
+    pub mod embedded;
 }
diff --git a/crates/g3-core/src/providers/embedded.rs b/crates/g3-core/src/providers/embedded.rs
new file mode 100644
index 0000000..3fa9f1f
--- /dev/null
+++ b/crates/g3-core/src/providers/embedded.rs
@@ -0,0 +1,362 @@
+use g3_providers::{LLMProvider, CompletionRequest, CompletionResponse, CompletionStream, CompletionChunk, Usage, Message, MessageRole};
+use anyhow::Result;
+use llama_cpp::{LlamaModel, LlamaSession, LlamaParams, SessionParams, standard_sampler::{StandardSampler, SamplerStage}};
+use std::path::Path;
+use std::sync::Arc;
+use tokio::sync::Mutex;
+use tracing::{debug, info, error, warn};
+use tokio::sync::mpsc;
+use tokio_stream::wrappers::ReceiverStream;
+use std::sync::atomic::{AtomicBool, Ordering};
+
+pub struct EmbeddedProvider {
+    model: Arc<LlamaModel>,
+    session: Arc<Mutex<LlamaSession>>,
+    model_name: String,
+    max_tokens: u32,
+    temperature: f32,
+    context_length: u32,
+    generation_active: Arc<AtomicBool>,
+}
+
+impl EmbeddedProvider {
+    pub fn new(
+        model_path: String,
+        model_type: String,
+        context_length: Option<u32>,
+        max_tokens: Option<u32>,
+        temperature: Option<f32>,
+        gpu_layers: Option<u32>,
+        threads: Option<u32>,
+    ) -> Result<Self> {
+        info!("Loading embedded model from: {}", model_path);
+
+        // Expand tilde in path
+        let expanded_path = shellexpand::tilde(&model_path);
+        let model_path = Path::new(expanded_path.as_ref());
+
+        if !model_path.exists() {
+            anyhow::bail!("Model file not found: {}", model_path.display());
+        }
+
+        // Set up model parameters
+        let mut params = LlamaParams::default();
+
+        if let Some(gpu_layers) = gpu_layers {
+            params.n_gpu_layers = gpu_layers;
+            info!("Using {} GPU layers", gpu_layers);
+        }
+
+        let context_size = context_length.unwrap_or(4096);
+        info!("Using context length: {}", context_size);
+
+        // Load the model
+        info!("Loading model...");
+        let model = LlamaModel::load_from_file(model_path, params)
+            .map_err(|e| anyhow::anyhow!("Failed to load model: {}", e))?;
+
+        // Create session with parameters
+        let mut session_params = SessionParams::default();
+        session_params.n_ctx = context_size;
+        if let Some(threads) = threads {
+            session_params.n_threads = threads;
+        }
+
+        let session = model.create_session(session_params)
+            .map_err(|e| anyhow::anyhow!("Failed to create session: {}", e))?;
+
+        info!("Successfully loaded {} model", model_type);
+
+        Ok(Self {
+            model: Arc::new(model),
+            session: Arc::new(Mutex::new(session)),
+            model_name: format!("embedded-{}", model_type),
+            max_tokens: max_tokens.unwrap_or(2048),
+            temperature: temperature.unwrap_or(0.1),
+            context_length: context_size,
+            generation_active: Arc::new(AtomicBool::new(false)),
+        })
+    }
+
+    fn format_messages(&self, messages: &[Message]) -> String {
+        // Use proper prompt format for CodeLlama
+        let mut formatted = String::new();
+
+        for message in messages {
+            match message.role {
+                MessageRole::System => {
+                    formatted.push_str(&format!("[INST] <<SYS>>\n{}\n<</SYS>>\n\n", message.content));
+                }
+                MessageRole::User => {
+                    formatted.push_str(&format!("{} [/INST] ", message.content));
+                }
+                MessageRole::Assistant => {
+                    formatted.push_str(&format!("{} [INST] ", message.content));
+                }
+            }
+        }
+
+        formatted
+    }
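+
+    // NOTE: the Llama 2 / CodeLlama chat template approximated above renders as
+    //   <s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST] {answer} </s>
+    // BOS/EOS handling is left to the tokenizer, so the string built here is an
+    // approximation and may need adjusting for other model families.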
+
+    async fn generate_completion(&self, prompt: &str, max_tokens: u32, temperature: f32) -> Result<String> {
+        let session = self.session.clone();
+        let prompt = prompt.to_string();
+
+        // Calculate dynamic max tokens based on available context headroom
+        let prompt_tokens = self.estimate_tokens(&prompt);
+        let available_tokens = self.context_length.saturating_sub(prompt_tokens).saturating_sub(50); // Reserve 50 tokens for safety
+        let dynamic_max_tokens = std::cmp::min(max_tokens as usize, available_tokens as usize);
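+        // Illustrative example: with context_length = 4096 and a prompt
+        // estimated at 1000 tokens, available_tokens = 4096 - 1000 - 50 = 3046,
+        // so a requested max_tokens of 2048 passes through unchanged while a
+        // request of 4096 is clamped to 3046.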
+
+        debug!("Context calculation: prompt_tokens={}, context_length={}, available_tokens={}, dynamic_max_tokens={}",
+               prompt_tokens, self.context_length, available_tokens, dynamic_max_tokens);
+
+        // Add timeout to the entire operation
+        let timeout_duration = std::time::Duration::from_secs(30); // Increased timeout for larger contexts
+
+        let result = tokio::time::timeout(timeout_duration, tokio::task::spawn_blocking(move || {
+            let mut session = match session.try_lock() {
+                Ok(ctx) => ctx,
+                Err(_) => return Err(anyhow::anyhow!("Model is busy, please try again")),
+            };
+
+            debug!("Starting inference with prompt length: {} chars, estimated {} tokens", prompt.len(), prompt_tokens);
+
+            // Set context to the prompt
+            debug!("About to call set_context...");
+            session.set_context(&prompt)
+                .map_err(|e| anyhow::anyhow!("Failed to set context: {}", e))?;
+            debug!("set_context completed successfully");
+
+            // Create sampler with temperature
+            debug!("Creating sampler...");
+            let stages = vec![
+                SamplerStage::Temperature(temperature),
+                SamplerStage::TopK(40),
+                SamplerStage::TopP(0.9),
+            ];
+            let sampler = StandardSampler::new_softmax(stages, 1);
+            debug!("Sampler created successfully");
+
+            // Start completion with dynamic max tokens
+            debug!("About to call start_completing_with with {} max tokens...", dynamic_max_tokens);
+            let mut completion_handle = session.start_completing_with(sampler, dynamic_max_tokens)
+                .map_err(|e| anyhow::anyhow!("Failed to start completion: {}", e))?;
+            debug!("start_completing_with completed successfully");
+
+            let mut generated_text = String::new();
+            let mut token_count = 0;
+            let start_time = std::time::Instant::now();
+
+            debug!("Starting token generation loop...");
+
+            // Generate tokens with dynamic limits
+            while let Some(token) = completion_handle.next_token() {
+                // Check for timeout on each token
+                if start_time.elapsed() > std::time::Duration::from_secs(25) {
+                    debug!("Token generation timeout after {} tokens", token_count);
+                    break;
+                }
+
+                let token_string = session.model().token_to_piece(token);
+                generated_text.push_str(&token_string);
+                token_count += 1;
+
+                if token_count <= 10 || token_count % 50 == 0 {
+                    debug!("Generated token {}: '{}'", token_count, token_string);
+                }
+
+                // Use dynamic token limit
+                if token_count >= dynamic_max_tokens {
+                    debug!("Reached dynamic token limit: {}", dynamic_max_tokens);
+                    break;
+                }
+
+                // Stop on completion markers
+                if generated_text.contains("</s>") || generated_text.contains("[/INST]") {
+                    debug!("Hit CodeLlama stop sequence at {} tokens", token_count);
+                    break;
+                }
+
+                // Stop on natural completion points after reasonable generation
+                if token_count >= 20 && (
+                    generated_text.trim().ends_with("```") ||
+                    (generated_text.contains("```") && generated_text.matches("```").count() % 2 == 0) // Complete code blocks
+                ) {
+                    debug!("Hit code block completion at {} tokens", token_count);
+                    break;
+                }
+            }
+
+            debug!("Token generation loop completed. Generated {} tokens in {:?}", token_count, start_time.elapsed());
+            Ok((generated_text.trim().to_string(), token_count))
+        })).await;
+
+        match result {
+            Ok(inner_result) => match inner_result {
+                Ok(task_result) => match task_result {
+                    Ok((text, token_count)) => {
+                        info!("Completed generation: {} tokens (dynamic limit was {})", token_count, dynamic_max_tokens);
+                        Ok(text)
+                    }
+                    Err(e) => Err(e),
+                },
+                Err(e) => Err(e.into()),
+            },
+            Err(_) => {
+                error!("Generation timed out after 30 seconds");
+                Err(anyhow::anyhow!("Generation timed out"))
+            }
+        }
+    }
+
+    // Helper function to estimate token count from text
+    fn estimate_tokens(&self, text: &str) -> u32 {
+        // Rough estimation: average 4 characters per token
+        // This is conservative - actual tokenization might be different
+        (text.len() as f32 / 4.0).ceil() as u32
+    }
+}
+
+#[async_trait::async_trait]
+impl LLMProvider for EmbeddedProvider {
+    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
+        debug!("Processing completion request with {} messages", request.messages.len());
+
+        let prompt = self.format_messages(&request.messages);
+        let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
+        let temperature = request.temperature.unwrap_or(self.temperature);
+
+        debug!("Formatted prompt length: {} chars", prompt.len());
+
+        let content = self.generate_completion(&prompt, max_tokens, temperature).await?;
+
+        // Estimate token usage (rough approximation)
+        let prompt_tokens = (prompt.len() / 4) as u32; // Rough estimate: 4 chars per token
+        let completion_tokens = (content.len() / 4) as u32;
+
+        Ok(CompletionResponse {
+            content,
+            usage: Usage {
+                prompt_tokens,
+                completion_tokens,
+                total_tokens: prompt_tokens + completion_tokens,
+            },
+            model: self.model_name.clone(),
+        })
+    }
+
+    async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
+        debug!("Processing streaming request with {} messages", request.messages.len());
+
+        let prompt = self.format_messages(&request.messages);
+        let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
+        let temperature = request.temperature.unwrap_or(self.temperature);
+
+        let (tx, rx) = mpsc::channel(100);
+        let session = self.session.clone();
+        let prompt = prompt.to_string();
+
+        // Spawn streaming task
+        tokio::task::spawn_blocking(move || {
+            let mut session = match session.try_lock() {
+                Ok(ctx) => ctx,
+                Err(_) => {
+                    let _ = tx.blocking_send(Err(anyhow::anyhow!("Model is busy, please try again")));
+                    return;
+                }
+            };
+
+            // Set context to the prompt
+            if let Err(e) = session.set_context(&prompt) {
+                let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to set context: {}", e)));
+                return;
+            }
+
+            // Create sampler with temperature
+            let stages = vec![
+                SamplerStage::Temperature(temperature),
+                SamplerStage::TopK(40),
+                SamplerStage::TopP(0.9),
+            ];
+            let sampler = StandardSampler::new_softmax(stages, 1);
+
+            // Start completion
+            let mut completion_handle = match session.start_completing_with(sampler, max_tokens as usize) {
+                Ok(handle) => handle,
+                Err(e) => {
+                    let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to start completion: {}", e)));
+                    return;
+                }
+            };
+
+            let mut accumulated_text = String::new();
+            let mut token_count = 0;
+
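+            // The channel created above has capacity 100; blocking_send blocks
+            // this worker thread (backpressure) when the receiver falls behind
+            // and returns Err once the receiver has been dropped.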
+            // Stream tokens with proper limits
+            while let Some(token) = completion_handle.next_token() {
+                let token_string = session.model().token_to_piece(token);
+
+                accumulated_text.push_str(&token_string);
+                token_count += 1;
+
+                let chunk = CompletionChunk {
+                    content: token_string.clone(),
+                    finished: false,
+                };
+
+                if tx.blocking_send(Ok(chunk)).is_err() {
+                    break; // Receiver dropped
+                }
+
+                // Enforce token limit
+                if token_count >= max_tokens as usize {
+                    debug!("Reached max token limit in streaming: {}", max_tokens);
+                    break;
+                }
+
+                // Stop if we hit common stop sequences
+                if accumulated_text.contains("### Human") ||
+                   accumulated_text.contains("### System") ||
+                   accumulated_text.contains("<|end|>") ||
+                   accumulated_text.contains("</s>") ||
+                   accumulated_text.trim().ends_with("```") {
+                    debug!("Hit stop sequence in streaming, stopping generation");
+                    break;
+                }
+
+                // Emergency brake for streaming too
+                if token_count > 0 && token_count % 100 == 0 {
+                    debug!("Streaming: Generated {} tokens so far", token_count);
+                    if accumulated_text.trim().len() > 50 &&
+                       (accumulated_text.contains('\n') || accumulated_text.len() > 200) {
+                        if accumulated_text.trim().ends_with('.') ||
+                           accumulated_text.trim().ends_with('!') ||
+                           accumulated_text.trim().ends_with('?') ||
+                           accumulated_text.trim().ends_with('\n') {
+                            debug!("Found natural stopping point in streaming at {} tokens", token_count);
+                            break;
+                        }
+                    }
+                }
+            }
+
+            // Send final chunk
+            let final_chunk = CompletionChunk {
+                content: String::new(),
+                finished: true,
+            };
+            let _ = tx.blocking_send(Ok(final_chunk));
+        });
+
+        Ok(ReceiverStream::new(rx))
+    }
+
+    fn name(&self) -> &str {
+        "embedded"
+    }
+
+    fn model(&self) -> &str {
+        &self.model_name
+    }
+}