Compare commits
16 Commits
micn/alway
...
micn/ollam
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
79b375519b | ||
|
|
88c3cc23fe | ||
|
|
4622507f37 | ||
|
|
217df2f2af | ||
|
|
22a0090cdc | ||
|
|
631f3c16ca | ||
|
|
1f9fef5f18 | ||
|
|
57d473c19d | ||
|
|
e59ce2f93f | ||
|
|
a1ad94ed75 | ||
|
|
982c0bbfb3 | ||
|
|
ad9ba5e5d8 | ||
|
|
f89bbfc89a | ||
|
|
11eb01e04d | ||
|
|
bdaacfd051 | ||
|
|
92ae776510 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -3,6 +3,7 @@
|
|||||||
debug
|
debug
|
||||||
target
|
target
|
||||||
.build
|
.build
|
||||||
|
appy/
|
||||||
|
|
||||||
# These are backup files generated by rustfmt
|
# These are backup files generated by rustfmt
|
||||||
**/*.rs.bk
|
**/*.rs.bk
|
||||||
|
|||||||
20
Cargo.lock
generated
20
Cargo.lock
generated
@@ -1391,6 +1391,7 @@ dependencies = [
|
|||||||
"reqwest",
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"serde_yaml",
|
||||||
"shellexpand",
|
"shellexpand",
|
||||||
"thiserror 1.0.69",
|
"thiserror 1.0.69",
|
||||||
"tokio",
|
"tokio",
|
||||||
@@ -3078,6 +3079,19 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_yaml"
|
||||||
|
version = "0.9.34+deprecated"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47"
|
||||||
|
dependencies = [
|
||||||
|
"indexmap",
|
||||||
|
"itoa",
|
||||||
|
"ryu",
|
||||||
|
"serde",
|
||||||
|
"unsafe-libyaml",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "sha2"
|
name = "sha2"
|
||||||
version = "0.10.9"
|
version = "0.10.9"
|
||||||
@@ -3667,6 +3681,12 @@ version = "0.2.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd"
|
checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unsafe-libyaml"
|
||||||
|
version = "0.2.11"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "url"
|
name = "url"
|
||||||
version = "2.5.7"
|
version = "2.5.7"
|
||||||
|
|||||||
456
OLLAMA_CONFIG.md
Normal file
456
OLLAMA_CONFIG.md
Normal file
@@ -0,0 +1,456 @@
|
|||||||
|
# Configuring Ollama Provider in G3
|
||||||
|
|
||||||
|
This guide shows you how to configure G3 to use Ollama as your LLM provider.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### 1. Install Ollama
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Visit https://ollama.ai to download and install
|
||||||
|
# Or use curl:
|
||||||
|
curl https://ollama.ai/install.sh | sh
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Pull a Model
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ollama pull llama3.2
|
||||||
|
# or any other model you prefer
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Create Configuration File
|
||||||
|
|
||||||
|
Copy the example configuration:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp config.ollama.example.toml ~/.config/g3/config.toml
|
||||||
|
```
|
||||||
|
|
||||||
|
Or create it manually:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama"
|
||||||
|
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Run G3
|
||||||
|
|
||||||
|
```bash
|
||||||
|
g3
|
||||||
|
# G3 will now use Ollama with llama3.2!
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration Options
|
||||||
|
|
||||||
|
### Basic Configuration
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama"
|
||||||
|
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2"
|
||||||
|
```
|
||||||
|
|
||||||
|
This is the minimal configuration needed. It uses all defaults:
|
||||||
|
- Base URL: `http://localhost:11434`
|
||||||
|
- Temperature: `0.7`
|
||||||
|
- Max tokens: Not limited (uses model default)
|
||||||
|
|
||||||
|
### Full Configuration
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama"
|
||||||
|
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2"
|
||||||
|
base_url = "http://localhost:11434"
|
||||||
|
max_tokens = 2048
|
||||||
|
temperature = 0.7
|
||||||
|
```
|
||||||
|
|
||||||
|
### Custom Ollama Host
|
||||||
|
|
||||||
|
If you're running Ollama on a different machine or port:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2"
|
||||||
|
base_url = "http://192.168.1.100:11434"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Different Models
|
||||||
|
|
||||||
|
You can use any Ollama model:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers.ollama]
|
||||||
|
model = "qwen2.5:7b" # Alibaba's Qwen model
|
||||||
|
```
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers.ollama]
|
||||||
|
model = "mistral" # Mistral AI
|
||||||
|
```
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.1:70b" # Larger Llama model
|
||||||
|
```
|
||||||
|
|
||||||
|
## Multiple Provider Configuration
|
||||||
|
|
||||||
|
You can configure multiple providers and switch between them:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama" # Default for most operations
|
||||||
|
|
||||||
|
# Ollama for local, fast responses
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2:3b"
|
||||||
|
temperature = 0.7
|
||||||
|
|
||||||
|
# Databricks for more complex tasks
|
||||||
|
[providers.databricks]
|
||||||
|
host = "https://your-workspace.cloud.databricks.com"
|
||||||
|
model = "databricks-claude-sonnet-4"
|
||||||
|
max_tokens = 4096
|
||||||
|
temperature = 0.1
|
||||||
|
use_oauth = true
|
||||||
|
```
|
||||||
|
|
||||||
|
Then switch providers with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
g3 --provider databricks
|
||||||
|
```
|
||||||
|
|
||||||
|
## Autonomous Mode (Coach-Player)
|
||||||
|
|
||||||
|
Use different providers for code review (coach) and implementation (player):
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama"
|
||||||
|
coach = "databricks" # Use powerful cloud model for review
|
||||||
|
player = "ollama" # Use local model for implementation
|
||||||
|
|
||||||
|
[providers.ollama]
|
||||||
|
model = "qwen2.5:14b" # Larger local model for coding
|
||||||
|
|
||||||
|
[providers.databricks]
|
||||||
|
host = "https://your-workspace.cloud.databricks.com"
|
||||||
|
model = "databricks-claude-sonnet-4"
|
||||||
|
use_oauth = true
|
||||||
|
```
|
||||||
|
|
||||||
|
This gives you the best of both worlds:
|
||||||
|
- Fast local execution for coding tasks
|
||||||
|
- Powerful cloud review for quality assurance
|
||||||
|
|
||||||
|
## Recommended Models
|
||||||
|
|
||||||
|
### For Coding Tasks
|
||||||
|
|
||||||
|
| Model | Size | Speed | Quality | Notes |
|
||||||
|
|-------|------|-------|---------|-------|
|
||||||
|
| **qwen2.5:7b** | 7B | Fast | Excellent | Best balance for coding |
|
||||||
|
| **llama3.2:3b** | 3B | Very Fast | Good | Great for quick tasks |
|
||||||
|
| **llama3.1:8b** | 8B | Medium | Very Good | Solid all-rounder |
|
||||||
|
| **mistral** | 7B | Fast | Good | Good for general use |
|
||||||
|
|
||||||
|
### For Complex Tasks
|
||||||
|
|
||||||
|
| Model | Size | Speed | Quality | Notes |
|
||||||
|
|-------|------|-------|---------|-------|
|
||||||
|
| **qwen2.5:14b** | 14B | Medium | Excellent | Best local model for coding |
|
||||||
|
| **qwen2.5:32b** | 32B | Slow | Outstanding | If you have the resources |
|
||||||
|
| **llama3.1:70b** | 70B | Very Slow | Outstanding | Requires significant RAM/GPU |
|
||||||
|
|
||||||
|
## Temperature Settings
|
||||||
|
|
||||||
|
Temperature controls randomness in responses:
|
||||||
|
|
||||||
|
- **0.1-0.3**: Deterministic, good for code generation
|
||||||
|
- **0.5-0.7**: Balanced, good for most tasks
|
||||||
|
- **0.8-1.0**: Creative, good for brainstorming
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers.ollama]
|
||||||
|
model = "qwen2.5:7b"
|
||||||
|
temperature = 0.2 # Focused code generation
|
||||||
|
```
|
||||||
|
|
||||||
|
## Max Tokens
|
||||||
|
|
||||||
|
Control response length:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2"
|
||||||
|
max_tokens = 1024 # Shorter responses
|
||||||
|
```
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers.ollama]
|
||||||
|
model = "qwen2.5:7b"
|
||||||
|
max_tokens = 4096 # Longer, detailed responses
|
||||||
|
```
|
||||||
|
|
||||||
|
Leave it unset for model defaults (recommended).
|
||||||
|
|
||||||
|
## Performance Tuning
|
||||||
|
|
||||||
|
### GPU Acceleration
|
||||||
|
|
||||||
|
Ollama automatically uses GPU if available. To check:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ollama ps
|
||||||
|
```
|
||||||
|
|
||||||
|
### Quantized Models
|
||||||
|
|
||||||
|
For faster responses with less RAM:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2:3b-q4_0" # 4-bit quantization
|
||||||
|
```
|
||||||
|
|
||||||
|
Quantization options:
|
||||||
|
- `q4_0`: 4-bit, fastest, lowest quality
|
||||||
|
- `q5_0`: 5-bit, balanced
|
||||||
|
- `q8_0`: 8-bit, slower, better quality
|
||||||
|
|
||||||
|
### Multiple Models
|
||||||
|
|
||||||
|
You can pull multiple models and switch easily:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ollama pull llama3.2:3b # Fast for chat
|
||||||
|
ollama pull qwen2.5:7b # Better for code
|
||||||
|
ollama pull mistral # General purpose
|
||||||
|
```
|
||||||
|
|
||||||
|
Then change your config:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers.ollama]
|
||||||
|
model = "qwen2.5:7b" # Just change this line
|
||||||
|
```
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Ollama Not Running
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Check if Ollama is running
|
||||||
|
curl http://localhost:11434/api/version
|
||||||
|
|
||||||
|
# Start Ollama (macOS/Linux)
|
||||||
|
ollama serve
|
||||||
|
|
||||||
|
# Or just run a model (auto-starts)
|
||||||
|
ollama run llama3.2
|
||||||
|
```
|
||||||
|
|
||||||
|
### Model Not Found
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# List available models
|
||||||
|
ollama list
|
||||||
|
|
||||||
|
# Pull the model
|
||||||
|
ollama pull llama3.2
|
||||||
|
```
|
||||||
|
|
||||||
|
### Slow Responses
|
||||||
|
|
||||||
|
1. Use a smaller model:
|
||||||
|
```toml
|
||||||
|
model = "llama3.2:1b" # Smallest, fastest
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Use quantized version:
|
||||||
|
```toml
|
||||||
|
model = "llama3.2:3b-q4_0"
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Reduce max_tokens:
|
||||||
|
```toml
|
||||||
|
max_tokens = 512
|
||||||
|
```
|
||||||
|
|
||||||
|
### Out of Memory
|
||||||
|
|
||||||
|
1. Switch to smaller model
|
||||||
|
2. Use quantized version
|
||||||
|
3. Close other applications
|
||||||
|
4. Check GPU memory: `ollama ps`
|
||||||
|
|
||||||
|
### Connection Refused
|
||||||
|
|
||||||
|
Check that the `base_url` setting is correct:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2"
|
||||||
|
base_url = "http://localhost:11434" # Default
|
||||||
|
```
|
||||||
|
|
||||||
|
For remote Ollama:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
base_url = "http://your-server:11434"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Complete Example Configs
|
||||||
|
|
||||||
|
### Minimal Local Setup
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama"
|
||||||
|
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2"
|
||||||
|
|
||||||
|
[agent]
|
||||||
|
max_context_length = 8192
|
||||||
|
enable_streaming = true
|
||||||
|
timeout_seconds = 60
|
||||||
|
```
|
||||||
|
|
||||||
|
### Optimized for Coding
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama"
|
||||||
|
|
||||||
|
[providers.ollama]
|
||||||
|
model = "qwen2.5:7b"
|
||||||
|
temperature = 0.2
|
||||||
|
max_tokens = 2048
|
||||||
|
|
||||||
|
[agent]
|
||||||
|
max_context_length = 16384
|
||||||
|
enable_streaming = true
|
||||||
|
timeout_seconds = 120
|
||||||
|
```
|
||||||
|
|
||||||
|
### Fast Responses
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama"
|
||||||
|
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2:3b-q4_0"
|
||||||
|
temperature = 0.7
|
||||||
|
max_tokens = 1024
|
||||||
|
|
||||||
|
[agent]
|
||||||
|
max_context_length = 4096
|
||||||
|
enable_streaming = true
|
||||||
|
timeout_seconds = 30
|
||||||
|
```
|
||||||
|
|
||||||
|
### High Quality (Requires Good Hardware)
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama"
|
||||||
|
|
||||||
|
[providers.ollama]
|
||||||
|
model = "qwen2.5:32b"
|
||||||
|
temperature = 0.3
|
||||||
|
max_tokens = 4096
|
||||||
|
|
||||||
|
[agent]
|
||||||
|
max_context_length = 32768
|
||||||
|
enable_streaming = true
|
||||||
|
timeout_seconds = 300
|
||||||
|
```
|
||||||
|
|
||||||
|
### Hybrid (Local + Cloud)
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama"
|
||||||
|
coach = "databricks"
|
||||||
|
player = "ollama"
|
||||||
|
|
||||||
|
[providers.ollama]
|
||||||
|
model = "qwen2.5:14b"
|
||||||
|
temperature = 0.2
|
||||||
|
|
||||||
|
[providers.databricks]
|
||||||
|
host = "https://your-workspace.cloud.databricks.com"
|
||||||
|
model = "databricks-claude-sonnet-4"
|
||||||
|
use_oauth = true
|
||||||
|
|
||||||
|
[agent]
|
||||||
|
max_context_length = 16384
|
||||||
|
enable_streaming = true
|
||||||
|
timeout_seconds = 120
|
||||||
|
```
|
||||||
|
|
||||||
|
## Environment Variables
|
||||||
|
|
||||||
|
You can override config with environment variables:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Override model
|
||||||
|
G3_PROVIDERS_OLLAMA_MODEL=qwen2.5:7b g3
|
||||||
|
|
||||||
|
# Override base URL
|
||||||
|
G3_PROVIDERS_OLLAMA_BASE_URL=http://192.168.1.100:11434 g3
|
||||||
|
|
||||||
|
# Override default provider
|
||||||
|
G3_PROVIDERS_DEFAULT_PROVIDER=ollama g3
|
||||||
|
```
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. **Start Small**: Begin with llama3.2:3b, scale up if needed
|
||||||
|
2. **Use Quantization**: q4_0 or q5_0 for best speed/quality balance
|
||||||
|
3. **Match Task to Model**:
|
||||||
|
- Quick edits: 1B-3B models
|
||||||
|
- Code generation: 7B-14B models
|
||||||
|
- Complex refactoring: 14B-32B models
|
||||||
|
4. **Temperature for Code**: Use 0.1-0.3 for deterministic output
|
||||||
|
5. **Enable Streaming**: Always enable for better UX
|
||||||
|
6. **Local First**: Use Ollama by default, cloud for special cases
|
||||||
|
|
||||||
|
## Comparison with Other Providers
|
||||||
|
|
||||||
|
| Feature | Ollama | Databricks | OpenAI | Anthropic |
|
||||||
|
|---------|--------|------------|--------|-----------|
|
||||||
|
| Cost | Free | Paid | Paid | Paid |
|
||||||
|
| Privacy | High (fully local) | Medium | Low | Low |
|
||||||
|
| Speed (small models) | Fast | Fast | Medium | Medium |
|
||||||
|
| Speed (large models) | Slow | Fast | Fast | Fast |
|
||||||
|
| Setup Complexity | Low | Medium | Low | Low |
|
||||||
|
| Authentication | None | OAuth/Token | API Key | API Key |
|
||||||
|
| Offline Support | Yes | No | No | No |
|
||||||
|
| Tool Calling | Yes | Yes | Yes | Yes |
|
||||||
|
|
||||||
|
## Next Steps
|
||||||
|
|
||||||
|
1. Try different models: `ollama pull mistral`, `ollama pull qwen2.5`
|
||||||
|
2. Experiment with temperature settings
|
||||||
|
3. Set up hybrid config with cloud provider for complex tasks
|
||||||
|
4. Share your config in the community!
|
||||||
|
|
||||||
|
## Getting Help
|
||||||
|
|
||||||
|
- Ollama docs: https://ollama.ai/docs
|
||||||
|
- G3 issues: https://github.com/your-repo/issues
|
||||||
|
- See available CLI options: `g3 --help` (to verify your config loads, just run `g3` and check which provider starts)
|
||||||
315
OLLAMA_EXAMPLE.md
Normal file
315
OLLAMA_EXAMPLE.md
Normal file
@@ -0,0 +1,315 @@
|
|||||||
|
# Ollama Provider for g3
|
||||||
|
|
||||||
|
A simple, local LLM provider implementation for g3 that connects to Ollama.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- ✅ **Simple Setup**: No API keys or authentication required
|
||||||
|
- ✅ **Local Execution**: Runs entirely on your machine
|
||||||
|
- ✅ **Tool Calling Support**: Native tool calling for compatible models
|
||||||
|
- ✅ **Streaming**: Full streaming support with real-time responses
|
||||||
|
- ✅ **Flexible Configuration**: Custom base URL, temperature, and max tokens
|
||||||
|
- ✅ **Model Discovery**: Automatic detection of available models
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
1. Install and start Ollama: https://ollama.ai
|
||||||
|
2. Pull a model: `ollama pull llama3.2`
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use g3_providers::{OllamaProvider, LLMProvider, CompletionRequest, Message, MessageRole};
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> anyhow::Result<()> {
|
||||||
|
// Create provider with default settings (localhost:11434)
|
||||||
|
let provider = OllamaProvider::new(
|
||||||
|
"llama3.2".to_string(),
|
||||||
|
None, // base_url: defaults to http://localhost:11434
|
||||||
|
None, // max_tokens: optional
|
||||||
|
None, // temperature: defaults to 0.7
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Create a simple request
|
||||||
|
let request = CompletionRequest {
|
||||||
|
messages: vec![
|
||||||
|
Message {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: "What is the capital of France?".to_string(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
max_tokens: Some(1000),
|
||||||
|
temperature: Some(0.7),
|
||||||
|
stream: false,
|
||||||
|
tools: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get completion
|
||||||
|
let response = provider.complete(request).await?;
|
||||||
|
println!("Response: {}", response.content);
|
||||||
|
println!("Tokens: {}", response.usage.total_tokens);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Streaming Example
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
|
||||||
|
let request = CompletionRequest {
|
||||||
|
messages: vec![
|
||||||
|
Message {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: "Write a short poem about coding".to_string(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
max_tokens: Some(500),
|
||||||
|
temperature: Some(0.8),
|
||||||
|
stream: true,
|
||||||
|
tools: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut stream = provider.stream(request).await?;
|
||||||
|
|
||||||
|
while let Some(chunk_result) = stream.next().await {
|
||||||
|
match chunk_result {
|
||||||
|
Ok(chunk) => {
|
||||||
|
print!("{}", chunk.content);
|
||||||
|
if chunk.finished {
|
||||||
|
println!("\n\nDone!");
|
||||||
|
if let Some(usage) = chunk.usage {
|
||||||
|
println!("Total tokens: {}", usage.total_tokens);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => eprintln!("Error: {}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Tool Calling Example
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
let tools = vec![Tool {
|
||||||
|
name: "get_weather".to_string(),
|
||||||
|
description: "Get current weather for a location".to_string(),
|
||||||
|
input_schema: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "City name"
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"description": "Temperature unit"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location"]
|
||||||
|
}),
|
||||||
|
}];
|
||||||
|
|
||||||
|
let request = CompletionRequest {
|
||||||
|
messages: vec![
|
||||||
|
Message {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: "What's the weather in Paris?".to_string(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
max_tokens: Some(500),
|
||||||
|
temperature: Some(0.5),
|
||||||
|
stream: false,
|
||||||
|
tools: Some(tools),
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = provider.complete(request).await?;
|
||||||
|
println!("Response: {}", response.content);
|
||||||
|
```
|
||||||
|
|
||||||
|
### Custom Ollama Host
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// Connect to remote Ollama instance
|
||||||
|
let provider = OllamaProvider::new(
|
||||||
|
"llama3.2".to_string(),
|
||||||
|
Some("http://192.168.1.100:11434".to_string()),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)?;
|
||||||
|
```
|
||||||
|
|
||||||
|
### Fetch Available Models
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// Discover what models are available
|
||||||
|
let models = provider.fetch_available_models().await?;
|
||||||
|
println!("Available models:");
|
||||||
|
for model in models {
|
||||||
|
println!(" - {}", model);
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Supported Models
|
||||||
|
|
||||||
|
The provider works with any Ollama model, including:
|
||||||
|
|
||||||
|
- **llama3.2** (1B, 3B) - Meta's latest Llama models
|
||||||
|
- **llama3.1** (8B, 70B, 405B) - Previous generation
|
||||||
|
- **qwen2.5** (7B, 14B, 32B) - Alibaba's Qwen models
|
||||||
|
- **mistral** - Mistral AI models
|
||||||
|
- **mixtral** - Mixture of experts model
|
||||||
|
- **phi3** - Microsoft's Phi-3
|
||||||
|
- **gemma2** - Google's Gemma 2
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
### Constructor Parameters
|
||||||
|
|
||||||
|
```rust
|
||||||
|
OllamaProvider::new(
|
||||||
|
model: String, // Model name (e.g., "llama3.2")
|
||||||
|
base_url: Option<String>, // Ollama API URL (default: http://localhost:11434)
|
||||||
|
max_tokens: Option<u32>, // Maximum tokens to generate (optional)
|
||||||
|
temperature: Option<f32>, // Sampling temperature (default: 0.7)
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Request Options
|
||||||
|
|
||||||
|
```rust
|
||||||
|
CompletionRequest {
|
||||||
|
messages: Vec<Message>, // Conversation history
|
||||||
|
max_tokens: Option<u32>, // Override provider's max_tokens
|
||||||
|
temperature: Option<f32>, // Override provider's temperature
|
||||||
|
stream: bool, // Enable streaming responses
|
||||||
|
tools: Option<Vec<Tool>>, // Tools for function calling
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Comparison with Other Providers
|
||||||
|
|
||||||
|
| Feature | Ollama | OpenAI | Anthropic | Databricks |
|
||||||
|
|---------|--------|--------|-----------|------------|
|
||||||
|
| Local Execution | ✅ | ❌ | ❌ | ❌ |
|
||||||
|
| Authentication | None | API Key | API Key | OAuth/Token |
|
||||||
|
| Tool Calling | ✅ | ✅ | ✅ | ✅ |
|
||||||
|
| Streaming | ✅ | ✅ | ✅ | ✅ |
|
||||||
|
| Cost | Free | Paid | Paid | Paid |
|
||||||
|
| Privacy | High | Low | Low | Medium |
|
||||||
|
|
||||||
|
## Implementation Details
|
||||||
|
|
||||||
|
### API Endpoints
|
||||||
|
|
||||||
|
- **Chat Completion**: `POST /api/chat`
|
||||||
|
- **Model List**: `GET /api/tags`
|
||||||
|
|
||||||
|
### Response Format
|
||||||
|
|
||||||
|
Ollama uses a simple JSON-per-line streaming format:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"message":{"role":"assistant","content":"Hello"},"done":false}
|
||||||
|
{"message":{"role":"assistant","content":" there"},"done":false}
|
||||||
|
{"done":true,"prompt_eval_count":10,"eval_count":20}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Tool Call Format
|
||||||
|
|
||||||
|
Tool calls are returned in the message structure:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": {"location": "Paris", "unit": "celsius"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"done": true
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Connection Errors
|
||||||
|
|
||||||
|
If you see connection errors, ensure Ollama is running:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Check if Ollama is running
|
||||||
|
curl http://localhost:11434/api/version
|
||||||
|
|
||||||
|
# Start Ollama (if needed)
|
||||||
|
ollama serve
|
||||||
|
```
|
||||||
|
|
||||||
|
### Model Not Found
|
||||||
|
|
||||||
|
Pull the model first:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ollama pull llama3.2
|
||||||
|
ollama list # Check available models
|
||||||
|
```
|
||||||
|
|
||||||
|
### Performance Issues
|
||||||
|
|
||||||
|
- Use smaller models (1B, 3B) for faster responses
|
||||||
|
- Reduce `max_tokens` to limit generation length
|
||||||
|
- Enable GPU acceleration if available
|
||||||
|
- Consider quantized models (e.g., `llama3.2:3b-q4_0`)
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
Run the included tests:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo test --package g3-providers ollama
|
||||||
|
```
|
||||||
|
|
||||||
|
All tests should pass:
|
||||||
|
```
|
||||||
|
running 4 tests
|
||||||
|
test ollama::tests::test_custom_base_url ... ok
|
||||||
|
test ollama::tests::test_message_conversion ... ok
|
||||||
|
test ollama::tests::test_provider_creation ... ok
|
||||||
|
test ollama::tests::test_tool_conversion ... ok
|
||||||
|
```
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
The provider follows the same architecture as other g3 providers:
|
||||||
|
|
||||||
|
1. **OllamaProvider**: Main struct implementing `LLMProvider` trait
|
||||||
|
2. **Request/Response Structures**: Internal types for Ollama API
|
||||||
|
3. **Streaming Parser**: Handles line-by-line JSON parsing
|
||||||
|
4. **Tool Call Handling**: Accumulates and converts tool calls
|
||||||
|
5. **Error Handling**: Robust error handling with retries
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
The provider is part of the g3-providers crate. To contribute:
|
||||||
|
|
||||||
|
1. Add features to `ollama.rs`
|
||||||
|
2. Update tests
|
||||||
|
3. Run `cargo test --package g3-providers`
|
||||||
|
4. Update this documentation
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
Same as the g3 project.
|
||||||
26
config.ollama.example.toml
Normal file
26
config.ollama.example.toml
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# Example G3 configuration using Ollama provider
|
||||||
|
# Copy this to ~/.config/g3/config.toml or ./g3.toml to use it
|
||||||
|
|
||||||
|
[providers]
|
||||||
|
default_provider = "ollama"
|
||||||
|
|
||||||
|
# Ollama configuration (local LLM)
|
||||||
|
[providers.ollama]
|
||||||
|
model = "llama3.2" # or qwen2.5, mistral, etc.
|
||||||
|
# base_url = "http://localhost:11434" # Optional, defaults to localhost
|
||||||
|
# max_tokens = 2048 # Optional
|
||||||
|
# temperature = 0.7 # Optional
|
||||||
|
|
||||||
|
# Optional: Specify different providers for coach and player in autonomous mode
|
||||||
|
# coach = "ollama" # Provider for coach (code reviewer)
|
||||||
|
# player = "ollama" # Provider for player (code implementer)
|
||||||
|
|
||||||
|
[agent]
|
||||||
|
max_context_length = 8192
|
||||||
|
enable_streaming = true
|
||||||
|
timeout_seconds = 60
|
||||||
|
|
||||||
|
[computer_control]
|
||||||
|
enabled = false # Set to true to enable computer control (requires OS permissions)
|
||||||
|
require_confirmation = true
|
||||||
|
max_actions_per_second = 5
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
use crossterm::style::{Color, SetForegroundColor, ResetColor};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@@ -183,6 +184,10 @@ pub struct Cli {
|
|||||||
#[arg(short, long)]
|
#[arg(short, long)]
|
||||||
pub verbose: bool,
|
pub verbose: bool,
|
||||||
|
|
||||||
|
/// Enable manual control of context compaction (disables auto-compact at 90%)
|
||||||
|
#[arg(long = "manual-compact")]
|
||||||
|
pub manual_compact: bool,
|
||||||
|
|
||||||
/// Show the system prompt being sent to the LLM
|
/// Show the system prompt being sent to the LLM
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub show_prompt: bool,
|
pub show_prompt: bool,
|
||||||
@@ -285,10 +290,6 @@ pub async fn run() -> Result<()> {
|
|||||||
tracing_subscriber::registry().with(filter).init();
|
tracing_subscriber::registry().with(filter).init();
|
||||||
}
|
}
|
||||||
|
|
||||||
if !cli.machine {
|
|
||||||
info!("Starting G3 AI Coding Agent");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up workspace directory
|
// Set up workspace directory
|
||||||
let workspace_dir = if let Some(ws) = &cli.workspace {
|
let workspace_dir = if let Some(ws) = &cli.workspace {
|
||||||
ws.clone()
|
ws.clone()
|
||||||
@@ -324,10 +325,6 @@ pub async fn run() -> Result<()> {
|
|||||||
project.ensure_workspace_exists()?;
|
project.ensure_workspace_exists()?;
|
||||||
project.enter_workspace()?;
|
project.enter_workspace()?;
|
||||||
|
|
||||||
if !cli.machine {
|
|
||||||
info!("Using workspace: {}", project.workspace().display());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load configuration with CLI overrides
|
// Load configuration with CLI overrides
|
||||||
let mut config = Config::load_with_overrides(
|
let mut config = Config::load_with_overrides(
|
||||||
cli.config.as_deref(),
|
cli.config.as_deref(),
|
||||||
@@ -338,9 +335,6 @@ pub async fn run() -> Result<()> {
|
|||||||
// Apply macax flag override
|
// Apply macax flag override
|
||||||
if cli.macax {
|
if cli.macax {
|
||||||
config.macax.enabled = true;
|
config.macax.enabled = true;
|
||||||
if !cli.machine {
|
|
||||||
info!("macOS Accessibility API tools enabled");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply webdriver flag override
|
// Apply webdriver flag override
|
||||||
@@ -348,6 +342,11 @@ pub async fn run() -> Result<()> {
|
|||||||
config.webdriver.enabled = true;
|
config.webdriver.enabled = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apply no-auto-compact flag override
|
||||||
|
if cli.manual_compact {
|
||||||
|
config.agent.auto_compact = false;
|
||||||
|
}
|
||||||
|
|
||||||
// Validate provider if specified
|
// Validate provider if specified
|
||||||
if let Some(ref provider) = cli.provider {
|
if let Some(ref provider) = cli.provider {
|
||||||
let valid_providers = ["anthropic", "databricks", "embedded", "openai"];
|
let valid_providers = ["anthropic", "databricks", "embedded", "openai"];
|
||||||
@@ -568,6 +567,11 @@ async fn run_accumulative_mode(
|
|||||||
config.webdriver.enabled = true;
|
config.webdriver.enabled = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apply no-auto-compact flag override
|
||||||
|
if cli.manual_compact {
|
||||||
|
config.agent.auto_compact = false;
|
||||||
|
}
|
||||||
|
|
||||||
// Create agent for interactive mode with requirements context
|
// Create agent for interactive mode with requirements context
|
||||||
let ui_writer = ConsoleUiWriter::new();
|
let ui_writer = ConsoleUiWriter::new();
|
||||||
let agent = Agent::new_with_readme_and_quiet(
|
let agent = Agent::new_with_readme_and_quiet(
|
||||||
@@ -645,6 +649,11 @@ async fn run_accumulative_mode(
|
|||||||
config.webdriver.enabled = true;
|
config.webdriver.enabled = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apply no-auto-compact flag override
|
||||||
|
if cli.manual_compact {
|
||||||
|
config.agent.auto_compact = false;
|
||||||
|
}
|
||||||
|
|
||||||
// Create agent for this autonomous run
|
// Create agent for this autonomous run
|
||||||
let ui_writer = ConsoleUiWriter::new();
|
let ui_writer = ConsoleUiWriter::new();
|
||||||
let agent = Agent::new_autonomous_with_readme_and_quiet(
|
let agent = Agent::new_autonomous_with_readme_and_quiet(
|
||||||
@@ -765,9 +774,6 @@ async fn run_with_console_mode(
|
|||||||
// Execute task, autonomous mode, or start interactive mode
|
// Execute task, autonomous mode, or start interactive mode
|
||||||
if cli.autonomous {
|
if cli.autonomous {
|
||||||
// Autonomous mode with coach-player feedback loop
|
// Autonomous mode with coach-player feedback loop
|
||||||
if !cli.machine {
|
|
||||||
info!("Starting autonomous mode");
|
|
||||||
}
|
|
||||||
run_autonomous(
|
run_autonomous(
|
||||||
agent,
|
agent,
|
||||||
project,
|
project,
|
||||||
@@ -779,9 +785,6 @@ async fn run_with_console_mode(
|
|||||||
.await?;
|
.await?;
|
||||||
} else if let Some(task) = cli.task {
|
} else if let Some(task) = cli.task {
|
||||||
// Single-shot mode
|
// Single-shot mode
|
||||||
if !cli.machine {
|
|
||||||
info!("Executing task: {}", task);
|
|
||||||
}
|
|
||||||
let output = SimpleOutput::new();
|
let output = SimpleOutput::new();
|
||||||
let result = agent
|
let result = agent
|
||||||
.execute_task_with_timing(&task, None, false, cli.show_prompt, cli.show_code, true)
|
.execute_task_with_timing(&task, None, false, cli.show_prompt, cli.show_code, true)
|
||||||
@@ -789,9 +792,6 @@ async fn run_with_console_mode(
|
|||||||
output.print_smart(&result.response);
|
output.print_smart(&result.response);
|
||||||
} else {
|
} else {
|
||||||
// Interactive mode (default)
|
// Interactive mode (default)
|
||||||
if !cli.machine {
|
|
||||||
info!("Starting interactive mode");
|
|
||||||
}
|
|
||||||
println!("📁 Workspace: {}", project.workspace().display());
|
println!("📁 Workspace: {}", project.workspace().display());
|
||||||
run_interactive(agent, cli.show_prompt, cli.show_code, combined_content).await?;
|
run_interactive(agent, cli.show_prompt, cli.show_code, combined_content).await?;
|
||||||
}
|
}
|
||||||
@@ -840,7 +840,6 @@ fn read_agents_config(workspace_dir: &Path) -> Option<String> {
|
|||||||
match std::fs::read_to_string(&agents_path) {
|
match std::fs::read_to_string(&agents_path) {
|
||||||
Ok(content) => {
|
Ok(content) => {
|
||||||
// Return the content with a note about which file was read
|
// Return the content with a note about which file was read
|
||||||
info!("Loaded AGENTS.md from {}", agents_path.display());
|
|
||||||
Some(format!(
|
Some(format!(
|
||||||
"🤖 Agent Configuration (from AGENTS.md):\n\n{}",
|
"🤖 Agent Configuration (from AGENTS.md):\n\n{}",
|
||||||
content
|
content
|
||||||
@@ -858,7 +857,6 @@ fn read_agents_config(workspace_dir: &Path) -> Option<String> {
|
|||||||
if alt_path.exists() {
|
if alt_path.exists() {
|
||||||
match std::fs::read_to_string(&alt_path) {
|
match std::fs::read_to_string(&alt_path) {
|
||||||
Ok(content) => {
|
Ok(content) => {
|
||||||
info!("Loaded agents.md from {}", alt_path.display());
|
|
||||||
Some(format!("🤖 Agent Configuration (from agents.md):\n\n{}", content))
|
Some(format!("🤖 Agent Configuration (from agents.md):\n\n{}", content))
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -1479,10 +1477,32 @@ fn handle_execution_error(e: &anyhow::Error, input: &str, output: &SimpleOutput,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn display_context_progress<W: UiWriter>(agent: &Agent<W>, output: &SimpleOutput) {
|
fn display_context_progress<W: UiWriter>(agent: &Agent<W>, _output: &SimpleOutput) {
|
||||||
let context = agent.get_context_window();
|
let context = agent.get_context_window();
|
||||||
output.print(&format!("Context: {}/{} tokens ({:.1}%)",
|
let percentage = context.percentage_used();
|
||||||
context.used_tokens, context.total_tokens, context.percentage_used()));
|
|
||||||
|
// Create 10 dots representing context fullness
|
||||||
|
let total_dots: usize = 10;
|
||||||
|
let filled_dots = ((percentage / 100.0) * total_dots as f32).round() as usize;
|
||||||
|
let empty_dots = total_dots.saturating_sub(filled_dots);
|
||||||
|
|
||||||
|
let filled_str = "●".repeat(filled_dots);
|
||||||
|
let empty_str = "○".repeat(empty_dots);
|
||||||
|
|
||||||
|
// Determine color based on percentage
|
||||||
|
let color = if percentage < 40.0 {
|
||||||
|
Color::Green
|
||||||
|
} else if percentage < 60.0 {
|
||||||
|
Color::Yellow
|
||||||
|
} else if percentage < 80.0 {
|
||||||
|
Color::Rgb { r: 255, g: 165, b: 0 } // Orange
|
||||||
|
} else {
|
||||||
|
Color::Red
|
||||||
|
};
|
||||||
|
|
||||||
|
// Print with colored dots (using print! directly to handle color codes)
|
||||||
|
print!("Context: {}{}{}{} {:.0}% ({}/{} tokens)\n",
|
||||||
|
SetForegroundColor(color), filled_str, empty_str, ResetColor, percentage, context.used_tokens, context.total_tokens);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set up the workspace directory for autonomous mode
|
/// Set up the workspace directory for autonomous mode
|
||||||
|
|||||||
@@ -71,18 +71,20 @@ impl SimpleOutput {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn print_context(&self, used: u32, total: u32, percentage: f32) {
|
pub fn print_context(&self, used: u32, total: u32, percentage: f32) {
|
||||||
let bar_width: usize = 10;
|
let total_dots = 10;
|
||||||
let filled_width = ((percentage / 100.0) * bar_width as f32) as usize;
|
let filled_dots = ((percentage / 100.0) * total_dots as f32) as usize;
|
||||||
let empty_width = bar_width.saturating_sub(filled_width);
|
let empty_dots = total_dots.saturating_sub(filled_dots);
|
||||||
|
|
||||||
let filled_chars = "●".repeat(filled_width);
|
let filled_str = "●".repeat(filled_dots);
|
||||||
let empty_chars = "○".repeat(empty_width);
|
let empty_str = "○".repeat(empty_dots);
|
||||||
|
|
||||||
// Determine color based on percentage
|
// Determine color based on percentage
|
||||||
let color = if percentage < 60.0 {
|
let color = if percentage < 40.0 {
|
||||||
crossterm::style::Color::Green
|
crossterm::style::Color::Green
|
||||||
} else if percentage < 80.0 {
|
} else if percentage < 60.0 {
|
||||||
crossterm::style::Color::Yellow
|
crossterm::style::Color::Yellow
|
||||||
|
} else if percentage < 80.0 {
|
||||||
|
crossterm::style::Color::Rgb { r: 255, g: 165, b: 0 } // Orange
|
||||||
} else {
|
} else {
|
||||||
crossterm::style::Color::Red
|
crossterm::style::Color::Red
|
||||||
};
|
};
|
||||||
@@ -90,9 +92,9 @@ impl SimpleOutput {
|
|||||||
// Print with colored progress bar
|
// Print with colored progress bar
|
||||||
print!("Context: ");
|
print!("Context: ");
|
||||||
print!("{}", SetForegroundColor(color));
|
print!("{}", SetForegroundColor(color));
|
||||||
print!("{}{}", filled_chars, empty_chars);
|
print!("{}{}", filled_str, empty_str);
|
||||||
print!("{}", ResetColor);
|
print!("{}", ResetColor);
|
||||||
println!(" {:.1}% | {}/{} tokens", percentage, used, total);
|
println!(" {:.0}% ({}/{} tokens)", percentage, used, total);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn print_context_thinning(&self, message: &str) {
|
pub fn print_context_thinning(&self, message: &str) {
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ pub struct ProvidersConfig {
|
|||||||
pub anthropic: Option<AnthropicConfig>,
|
pub anthropic: Option<AnthropicConfig>,
|
||||||
pub databricks: Option<DatabricksConfig>,
|
pub databricks: Option<DatabricksConfig>,
|
||||||
pub embedded: Option<EmbeddedConfig>,
|
pub embedded: Option<EmbeddedConfig>,
|
||||||
|
pub ollama: Option<OllamaConfig>,
|
||||||
pub default_provider: String,
|
pub default_provider: String,
|
||||||
pub coach: Option<String>, // Provider to use for coach in autonomous mode
|
pub coach: Option<String>, // Provider to use for coach in autonomous mode
|
||||||
pub player: Option<String>, // Provider to use for player in autonomous mode
|
pub player: Option<String>, // Provider to use for player in autonomous mode
|
||||||
@@ -60,11 +61,20 @@ pub struct EmbeddedConfig {
|
|||||||
pub threads: Option<u32>, // Number of CPU threads to use
|
pub threads: Option<u32>, // Number of CPU threads to use
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct OllamaConfig {
|
||||||
|
pub model: String,
|
||||||
|
pub base_url: Option<String>, // Default: http://localhost:11434
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct AgentConfig {
|
pub struct AgentConfig {
|
||||||
pub max_context_length: usize,
|
pub max_context_length: usize,
|
||||||
pub enable_streaming: bool,
|
pub enable_streaming: bool,
|
||||||
pub timeout_seconds: u64,
|
pub timeout_seconds: u64,
|
||||||
|
pub auto_compact: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -127,6 +137,7 @@ impl Default for Config {
|
|||||||
use_oauth: Some(true),
|
use_oauth: Some(true),
|
||||||
}),
|
}),
|
||||||
embedded: None,
|
embedded: None,
|
||||||
|
ollama: None,
|
||||||
default_provider: "databricks".to_string(),
|
default_provider: "databricks".to_string(),
|
||||||
coach: None, // Will use default_provider if not specified
|
coach: None, // Will use default_provider if not specified
|
||||||
player: None, // Will use default_provider if not specified
|
player: None, // Will use default_provider if not specified
|
||||||
@@ -135,6 +146,7 @@ impl Default for Config {
|
|||||||
max_context_length: 8192,
|
max_context_length: 8192,
|
||||||
enable_streaming: true,
|
enable_streaming: true,
|
||||||
timeout_seconds: 60,
|
timeout_seconds: 60,
|
||||||
|
auto_compact: true,
|
||||||
},
|
},
|
||||||
computer_control: ComputerControlConfig::default(),
|
computer_control: ComputerControlConfig::default(),
|
||||||
webdriver: WebDriverConfig::default(),
|
webdriver: WebDriverConfig::default(),
|
||||||
@@ -242,6 +254,7 @@ impl Config {
|
|||||||
gpu_layers: Some(32),
|
gpu_layers: Some(32),
|
||||||
threads: Some(8),
|
threads: Some(8),
|
||||||
}),
|
}),
|
||||||
|
ollama: None,
|
||||||
default_provider: "embedded".to_string(),
|
default_provider: "embedded".to_string(),
|
||||||
coach: None, // Will use default_provider if not specified
|
coach: None, // Will use default_provider if not specified
|
||||||
player: None, // Will use default_provider if not specified
|
player: None, // Will use default_provider if not specified
|
||||||
@@ -250,6 +263,7 @@ impl Config {
|
|||||||
max_context_length: 8192,
|
max_context_length: 8192,
|
||||||
enable_streaming: true,
|
enable_streaming: true,
|
||||||
timeout_seconds: 60,
|
timeout_seconds: 60,
|
||||||
|
auto_compact: true,
|
||||||
},
|
},
|
||||||
computer_control: ComputerControlConfig::default(),
|
computer_control: ComputerControlConfig::default(),
|
||||||
webdriver: WebDriverConfig::default(),
|
webdriver: WebDriverConfig::default(),
|
||||||
|
|||||||
@@ -25,3 +25,4 @@ chrono = { version = "0.4", features = ["serde"] }
|
|||||||
rand = "0.8"
|
rand = "0.8"
|
||||||
regex = "1.0"
|
regex = "1.0"
|
||||||
shellexpand = "3.1"
|
shellexpand = "3.1"
|
||||||
|
serde_yaml = "0.9"
|
||||||
|
|||||||
787
crates/g3-core/src/code_search.rs
Normal file
787
crates/g3-core/src/code_search.rs
Normal file
@@ -0,0 +1,787 @@
|
|||||||
|
//! Code search functionality using ast-grep for syntax-aware semantic searches
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::process::Stdio;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||||
|
use tokio::process::Command;
|
||||||
|
use tokio::sync::Semaphore;
|
||||||
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
|
/// Maximum number of searches allowed per request
|
||||||
|
const MAX_SEARCHES: usize = 20;
|
||||||
|
|
||||||
|
/// Default timeout for individual searches in seconds
|
||||||
|
const DEFAULT_TIMEOUT_SECS: u64 = 60;
|
||||||
|
|
||||||
|
/// Default maximum concurrency
|
||||||
|
const DEFAULT_MAX_CONCURRENCY: usize = 4;
|
||||||
|
|
||||||
|
/// Default maximum matches per search
|
||||||
|
const DEFAULT_MAX_MATCHES: usize = 500;
|
||||||
|
|
||||||
|
/// Search specification for a single ast-grep search
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct SearchSpec {
|
||||||
|
pub name: String,
|
||||||
|
pub mode: SearchMode,
|
||||||
|
|
||||||
|
// Pattern mode fields
|
||||||
|
pub pattern: Option<String>,
|
||||||
|
pub language: Option<String>,
|
||||||
|
|
||||||
|
// YAML mode fields
|
||||||
|
pub rule_yaml: Option<String>,
|
||||||
|
|
||||||
|
// Common fields
|
||||||
|
pub paths: Option<Vec<String>>,
|
||||||
|
pub globs: Option<Vec<String>>,
|
||||||
|
pub json_style: Option<JsonStyle>,
|
||||||
|
pub context: Option<u32>,
|
||||||
|
pub threads: Option<u32>,
|
||||||
|
pub include_metadata: Option<bool>,
|
||||||
|
pub no_ignore: Option<Vec<NoIgnoreType>>,
|
||||||
|
pub severity: Option<HashMap<String, SeverityLevel>>,
|
||||||
|
pub timeout_secs: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Search mode: pattern or yaml
|
||||||
|
#[derive(Debug, Clone, Deserialize, PartialEq)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum SearchMode {
|
||||||
|
Pattern,
|
||||||
|
Yaml,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// JSON output style
|
||||||
|
#[derive(Debug, Clone, Deserialize, PartialEq)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum JsonStyle {
|
||||||
|
Pretty,
|
||||||
|
Stream,
|
||||||
|
Compact,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for JsonStyle {
|
||||||
|
fn default() -> Self {
|
||||||
|
JsonStyle::Stream
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// No-ignore types
|
||||||
|
#[derive(Debug, Clone, Deserialize, PartialEq)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum NoIgnoreType {
|
||||||
|
Hidden,
|
||||||
|
Dot,
|
||||||
|
Exclude,
|
||||||
|
Global,
|
||||||
|
Parent,
|
||||||
|
Vcs,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Severity levels for YAML rules
|
||||||
|
#[derive(Debug, Clone, Deserialize, PartialEq)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum SeverityLevel {
|
||||||
|
Error,
|
||||||
|
Warning,
|
||||||
|
Info,
|
||||||
|
Hint,
|
||||||
|
Off,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Request structure for code search
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct CodeSearchRequest {
|
||||||
|
pub searches: Vec<SearchSpec>,
|
||||||
|
pub max_concurrency: Option<usize>,
|
||||||
|
pub max_matches_per_search: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result of a single search
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct SearchResult {
|
||||||
|
pub name: String,
|
||||||
|
pub mode: String,
|
||||||
|
pub status: String,
|
||||||
|
pub cmd: Vec<String>,
|
||||||
|
pub match_count: Option<usize>,
|
||||||
|
pub truncated: Option<bool>,
|
||||||
|
pub matches: Option<Vec<Value>>,
|
||||||
|
pub stderr: Option<String>,
|
||||||
|
pub exit_code: Option<i32>,
|
||||||
|
pub duration_ms: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Summary of all searches
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct SearchSummary {
|
||||||
|
pub completed: usize,
|
||||||
|
pub total: usize,
|
||||||
|
pub total_matches: usize,
|
||||||
|
pub duration_ms: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Complete response structure
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct CodeSearchResponse {
|
||||||
|
pub summary: SearchSummary,
|
||||||
|
pub searches: Vec<SearchResult>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// YAML rule structure for validation
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct YamlRule {
|
||||||
|
pub id: String,
|
||||||
|
pub language: String,
|
||||||
|
pub rule: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute a batch of code searches using ast-grep
|
||||||
|
pub async fn execute_code_search(request: CodeSearchRequest) -> Result<CodeSearchResponse> {
|
||||||
|
let start_time = Instant::now();
|
||||||
|
|
||||||
|
// Validate request
|
||||||
|
if request.searches.is_empty() {
|
||||||
|
return Err(anyhow!("No searches specified"));
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.searches.len() > MAX_SEARCHES {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"Too many searches: {} (max: {})",
|
||||||
|
request.searches.len(),
|
||||||
|
MAX_SEARCHES
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if ast-grep is available
|
||||||
|
check_ast_grep_available().await?;
|
||||||
|
|
||||||
|
let max_concurrency = request.max_concurrency.unwrap_or(DEFAULT_MAX_CONCURRENCY);
|
||||||
|
let max_matches = request.max_matches_per_search.unwrap_or(DEFAULT_MAX_MATCHES);
|
||||||
|
|
||||||
|
// Create semaphore for concurrency control
|
||||||
|
let semaphore = std::sync::Arc::new(Semaphore::new(max_concurrency));
|
||||||
|
|
||||||
|
// Execute searches concurrently
|
||||||
|
let mut tasks = Vec::new();
|
||||||
|
|
||||||
|
for search in request.searches {
|
||||||
|
let sem = semaphore.clone();
|
||||||
|
let task = tokio::spawn(async move {
|
||||||
|
let _permit = sem.acquire().await.unwrap();
|
||||||
|
execute_single_search(search, max_matches).await
|
||||||
|
});
|
||||||
|
tasks.push(task);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all searches to complete
|
||||||
|
let mut results = Vec::new();
|
||||||
|
let mut total_matches = 0;
|
||||||
|
let mut completed = 0;
|
||||||
|
|
||||||
|
for task in tasks {
|
||||||
|
match task.await {
|
||||||
|
Ok(result) => {
|
||||||
|
if result.status == "ok" {
|
||||||
|
completed += 1;
|
||||||
|
if let Some(count) = result.match_count {
|
||||||
|
total_matches += count;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results.push(result);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Task join error: {}", e);
|
||||||
|
// Create an error result
|
||||||
|
results.push(SearchResult {
|
||||||
|
name: "unknown".to_string(),
|
||||||
|
mode: "unknown".to_string(),
|
||||||
|
status: "error".to_string(),
|
||||||
|
cmd: vec![],
|
||||||
|
match_count: None,
|
||||||
|
truncated: None,
|
||||||
|
matches: None,
|
||||||
|
stderr: Some(format!("Task execution error: {}", e)),
|
||||||
|
exit_code: None,
|
||||||
|
duration_ms: 0,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let total_duration = start_time.elapsed();
|
||||||
|
|
||||||
|
Ok(CodeSearchResponse {
|
||||||
|
summary: SearchSummary {
|
||||||
|
completed,
|
||||||
|
total: results.len(),
|
||||||
|
total_matches,
|
||||||
|
duration_ms: total_duration.as_millis() as u64,
|
||||||
|
},
|
||||||
|
searches: results,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute a single search
|
||||||
|
async fn execute_single_search(search: SearchSpec, max_matches: usize) -> SearchResult {
|
||||||
|
let start_time = Instant::now();
|
||||||
|
let timeout_secs = search.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
|
||||||
|
|
||||||
|
// Validate the search specification
|
||||||
|
if let Err(e) = validate_search_spec(&search) {
|
||||||
|
return SearchResult {
|
||||||
|
name: search.name,
|
||||||
|
mode: format!("{:?}", search.mode).to_lowercase(),
|
||||||
|
status: "error".to_string(),
|
||||||
|
cmd: vec![],
|
||||||
|
match_count: None,
|
||||||
|
truncated: None,
|
||||||
|
matches: None,
|
||||||
|
stderr: Some(format!("Validation error: {}", e)),
|
||||||
|
exit_code: None,
|
||||||
|
duration_ms: start_time.elapsed().as_millis() as u64,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build command
|
||||||
|
let cmd_args = match build_ast_grep_command(&search) {
|
||||||
|
Ok(args) => args,
|
||||||
|
Err(e) => {
|
||||||
|
return SearchResult {
|
||||||
|
name: search.name,
|
||||||
|
mode: format!("{:?}", search.mode).to_lowercase(),
|
||||||
|
status: "error".to_string(),
|
||||||
|
cmd: vec![],
|
||||||
|
match_count: None,
|
||||||
|
truncated: None,
|
||||||
|
matches: None,
|
||||||
|
stderr: Some(format!("Command build error: {}", e)),
|
||||||
|
exit_code: None,
|
||||||
|
duration_ms: start_time.elapsed().as_millis() as u64,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!("Executing ast-grep command: {:?}", cmd_args);
|
||||||
|
|
||||||
|
// Execute with timeout
|
||||||
|
let timeout_duration = Duration::from_secs(timeout_secs);
|
||||||
|
|
||||||
|
match tokio::time::timeout(timeout_duration, run_ast_grep_command(&cmd_args)).await {
|
||||||
|
Ok(Ok((stdout, stderr, exit_code))) => {
|
||||||
|
let duration_ms = start_time.elapsed().as_millis() as u64;
|
||||||
|
|
||||||
|
if exit_code == 0 {
|
||||||
|
// Parse JSON output
|
||||||
|
match parse_ast_grep_output(&stdout, max_matches) {
|
||||||
|
Ok((matches, truncated)) => {
|
||||||
|
SearchResult {
|
||||||
|
name: search.name,
|
||||||
|
mode: format!("{:?}", search.mode).to_lowercase(),
|
||||||
|
status: "ok".to_string(),
|
||||||
|
cmd: cmd_args,
|
||||||
|
match_count: Some(matches.len()),
|
||||||
|
truncated: Some(truncated),
|
||||||
|
matches: Some(matches),
|
||||||
|
stderr: if stderr.is_empty() { None } else { Some(stderr) },
|
||||||
|
exit_code: None,
|
||||||
|
duration_ms,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
SearchResult {
|
||||||
|
name: search.name,
|
||||||
|
mode: format!("{:?}", search.mode).to_lowercase(),
|
||||||
|
status: "error".to_string(),
|
||||||
|
cmd: cmd_args,
|
||||||
|
match_count: None,
|
||||||
|
truncated: None,
|
||||||
|
matches: None,
|
||||||
|
stderr: Some(format!("JSON parse error: {}\nRaw output: {}", e, stdout)),
|
||||||
|
exit_code: Some(exit_code),
|
||||||
|
duration_ms,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
SearchResult {
|
||||||
|
name: search.name,
|
||||||
|
mode: format!("{:?}", search.mode).to_lowercase(),
|
||||||
|
status: "error".to_string(),
|
||||||
|
cmd: cmd_args,
|
||||||
|
match_count: None,
|
||||||
|
truncated: None,
|
||||||
|
matches: None,
|
||||||
|
stderr: Some(stderr),
|
||||||
|
exit_code: Some(exit_code),
|
||||||
|
duration_ms,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
SearchResult {
|
||||||
|
name: search.name,
|
||||||
|
mode: format!("{:?}", search.mode).to_lowercase(),
|
||||||
|
status: "error".to_string(),
|
||||||
|
cmd: cmd_args,
|
||||||
|
match_count: None,
|
||||||
|
truncated: None,
|
||||||
|
matches: None,
|
||||||
|
stderr: Some(format!("Execution error: {}", e)),
|
||||||
|
exit_code: None,
|
||||||
|
duration_ms: start_time.elapsed().as_millis() as u64,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
SearchResult {
|
||||||
|
name: search.name,
|
||||||
|
mode: format!("{:?}", search.mode).to_lowercase(),
|
||||||
|
status: "timeout".to_string(),
|
||||||
|
cmd: cmd_args,
|
||||||
|
match_count: None,
|
||||||
|
truncated: None,
|
||||||
|
matches: None,
|
||||||
|
stderr: Some(format!("Search timed out after {} seconds", timeout_secs)),
|
||||||
|
exit_code: None,
|
||||||
|
duration_ms: start_time.elapsed().as_millis() as u64,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate a search specification
|
||||||
|
fn validate_search_spec(search: &SearchSpec) -> Result<()> {
|
||||||
|
match search.mode {
|
||||||
|
SearchMode::Pattern => {
|
||||||
|
if search.pattern.is_none() || search.pattern.as_ref().unwrap().is_empty() {
|
||||||
|
return Err(anyhow!("Pattern mode requires non-empty 'pattern' field"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
SearchMode::Yaml => {
|
||||||
|
let rule_yaml = search.rule_yaml.as_ref()
|
||||||
|
.ok_or_else(|| anyhow!("YAML mode requires 'rule_yaml' field"))?;
|
||||||
|
|
||||||
|
if rule_yaml.is_empty() {
|
||||||
|
return Err(anyhow!("YAML mode requires non-empty 'rule_yaml' field"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse and validate YAML structure
|
||||||
|
let parsed: YamlRule = serde_yaml::from_str(rule_yaml)
|
||||||
|
.map_err(|e| anyhow!("Invalid YAML rule: {}", e))?;
|
||||||
|
|
||||||
|
if parsed.id.is_empty() {
|
||||||
|
return Err(anyhow!("YAML rule must have non-empty 'id' field"));
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsed.language.is_empty() {
|
||||||
|
return Err(anyhow!("YAML rule must have non-empty 'language' field"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate language is supported (basic check)
|
||||||
|
validate_language(&parsed.language)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate context range
|
||||||
|
if let Some(context) = search.context {
|
||||||
|
if context > 20 {
|
||||||
|
return Err(anyhow!("Context lines cannot exceed 20"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate that a language is supported by ast-grep
|
||||||
|
fn validate_language(language: &str) -> Result<()> {
|
||||||
|
let supported_languages = [
|
||||||
|
"rust", "javascript", "typescript", "python", "java", "c", "cpp", "csharp",
|
||||||
|
"go", "html", "css", "json", "yaml", "xml", "bash", "kotlin", "swift",
|
||||||
|
"php", "ruby", "scala", "dart", "lua", "r", "sql", "dockerfile",
|
||||||
|
"Rust", "JavaScript", "TypeScript", "Python", "Java", "C", "Cpp", "CSharp",
|
||||||
|
"Go", "Html", "Css", "Json", "Yaml", "Xml", "Bash", "Kotlin", "Swift",
|
||||||
|
"Php", "Ruby", "Scala", "Dart", "Lua", "R", "Sql", "Dockerfile"
|
||||||
|
];
|
||||||
|
|
||||||
|
if !supported_languages.contains(&language) {
|
||||||
|
warn!("Language '{}' may not be supported by ast-grep", language);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build ast-grep command arguments
|
||||||
|
fn build_ast_grep_command(search: &SearchSpec) -> Result<Vec<String>> {
|
||||||
|
let mut args = vec!["ast-grep".to_string()];
|
||||||
|
|
||||||
|
match search.mode {
|
||||||
|
SearchMode::Pattern => {
|
||||||
|
args.push("run".to_string());
|
||||||
|
|
||||||
|
// Add pattern
|
||||||
|
args.push("-p".to_string());
|
||||||
|
args.push(search.pattern.as_ref().unwrap().clone());
|
||||||
|
|
||||||
|
// Add language if specified
|
||||||
|
if let Some(ref lang) = search.language {
|
||||||
|
args.push("-l".to_string());
|
||||||
|
args.push(lang.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
SearchMode::Yaml => {
|
||||||
|
args.push("scan".to_string());
|
||||||
|
|
||||||
|
// Add inline rules
|
||||||
|
args.push("--inline-rules".to_string());
|
||||||
|
args.push(search.rule_yaml.as_ref().unwrap().clone());
|
||||||
|
|
||||||
|
// Add include-metadata if requested
|
||||||
|
if search.include_metadata.unwrap_or(false) {
|
||||||
|
args.push("--include-metadata".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add severity overrides
|
||||||
|
if let Some(ref severity_map) = search.severity {
|
||||||
|
for (rule_id, severity) in severity_map {
|
||||||
|
match severity {
|
||||||
|
SeverityLevel::Error => {
|
||||||
|
args.push("--error".to_string());
|
||||||
|
args.push(rule_id.clone());
|
||||||
|
}
|
||||||
|
SeverityLevel::Warning => {
|
||||||
|
args.push("--warning".to_string());
|
||||||
|
args.push(rule_id.clone());
|
||||||
|
}
|
||||||
|
SeverityLevel::Info => {
|
||||||
|
args.push("--info".to_string());
|
||||||
|
args.push(rule_id.clone());
|
||||||
|
}
|
||||||
|
SeverityLevel::Hint => {
|
||||||
|
args.push("--hint".to_string());
|
||||||
|
args.push(rule_id.clone());
|
||||||
|
}
|
||||||
|
SeverityLevel::Off => {
|
||||||
|
args.push("--off".to_string());
|
||||||
|
args.push(rule_id.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add common arguments
|
||||||
|
|
||||||
|
// Add globs if specified
|
||||||
|
if let Some(ref globs) = search.globs {
|
||||||
|
if !globs.is_empty() {
|
||||||
|
args.push("--globs".to_string());
|
||||||
|
args.push(globs.join(","));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add context
|
||||||
|
if let Some(context) = search.context {
|
||||||
|
args.push("-C".to_string());
|
||||||
|
args.push(context.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add threads
|
||||||
|
if let Some(threads) = search.threads {
|
||||||
|
args.push("-j".to_string());
|
||||||
|
args.push(threads.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add JSON output style
|
||||||
|
let json_style = search.json_style.as_ref().unwrap_or(&JsonStyle::Stream);
|
||||||
|
let json_arg = match json_style {
|
||||||
|
JsonStyle::Pretty => "--json=pretty",
|
||||||
|
JsonStyle::Stream => "--json=stream",
|
||||||
|
JsonStyle::Compact => "--json=compact",
|
||||||
|
};
|
||||||
|
args.push(json_arg.to_string());
|
||||||
|
|
||||||
|
// Add no-ignore options
|
||||||
|
if let Some(ref no_ignore_list) = search.no_ignore {
|
||||||
|
for no_ignore_type in no_ignore_list {
|
||||||
|
let flag = match no_ignore_type {
|
||||||
|
NoIgnoreType::Hidden => "--no-ignore=hidden",
|
||||||
|
NoIgnoreType::Dot => "--no-ignore=dot",
|
||||||
|
NoIgnoreType::Exclude => "--no-ignore=exclude",
|
||||||
|
NoIgnoreType::Global => "--no-ignore=global",
|
||||||
|
NoIgnoreType::Parent => "--no-ignore=parent",
|
||||||
|
NoIgnoreType::Vcs => "--no-ignore=vcs",
|
||||||
|
};
|
||||||
|
args.push(flag.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add paths (default to current directory if none specified)
|
||||||
|
if let Some(ref paths) = search.paths {
|
||||||
|
if !paths.is_empty() {
|
||||||
|
args.extend(paths.clone());
|
||||||
|
} else {
|
||||||
|
args.push(".".to_string());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
args.push(".".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(args)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run ast-grep command and capture output
|
||||||
|
async fn run_ast_grep_command(args: &[String]) -> Result<(String, String, i32)> {
|
||||||
|
let mut cmd = Command::new(&args[0]);
|
||||||
|
cmd.args(&args[1..]);
|
||||||
|
cmd.stdout(Stdio::piped());
|
||||||
|
cmd.stderr(Stdio::piped());
|
||||||
|
|
||||||
|
debug!("Running command: {:?}", args);
|
||||||
|
|
||||||
|
let mut child = cmd.spawn()
|
||||||
|
.map_err(|e| anyhow!("Failed to spawn ast-grep process: {}", e))?;
|
||||||
|
|
||||||
|
let stdout = child.stdout.take().unwrap();
|
||||||
|
let stderr = child.stderr.take().unwrap();
|
||||||
|
|
||||||
|
let stdout_reader = BufReader::new(stdout);
|
||||||
|
let stderr_reader = BufReader::new(stderr);
|
||||||
|
|
||||||
|
let stdout_task = tokio::spawn(async move {
|
||||||
|
let mut lines = stdout_reader.lines();
|
||||||
|
let mut output = String::new();
|
||||||
|
while let Ok(Some(line)) = lines.next_line().await {
|
||||||
|
if !output.is_empty() {
|
||||||
|
output.push('\n');
|
||||||
|
}
|
||||||
|
output.push_str(&line);
|
||||||
|
}
|
||||||
|
output
|
||||||
|
});
|
||||||
|
|
||||||
|
let stderr_task = tokio::spawn(async move {
|
||||||
|
let mut lines = stderr_reader.lines();
|
||||||
|
let mut output = String::new();
|
||||||
|
while let Ok(Some(line)) = lines.next_line().await {
|
||||||
|
if !output.is_empty() {
|
||||||
|
output.push('\n');
|
||||||
|
}
|
||||||
|
output.push_str(&line);
|
||||||
|
}
|
||||||
|
output
|
||||||
|
});
|
||||||
|
|
||||||
|
let status = child.wait().await
|
||||||
|
.map_err(|e| anyhow!("Failed to wait for ast-grep process: {}", e))?;
|
||||||
|
|
||||||
|
let stdout_output = stdout_task.await
|
||||||
|
.map_err(|e| anyhow!("Failed to read stdout: {}", e))?;
|
||||||
|
let stderr_output = stderr_task.await
|
||||||
|
.map_err(|e| anyhow!("Failed to read stderr: {}", e))?;
|
||||||
|
|
||||||
|
let exit_code = status.code().unwrap_or(-1);
|
||||||
|
|
||||||
|
Ok((stdout_output, stderr_output, exit_code))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse ast-grep JSON output
|
||||||
|
fn parse_ast_grep_output(output: &str, max_matches: usize) -> Result<(Vec<Value>, bool)> {
|
||||||
|
if output.trim().is_empty() {
|
||||||
|
return Ok((vec![], false));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut matches = Vec::new();
|
||||||
|
let mut truncated = false;
|
||||||
|
|
||||||
|
// Handle stream format (line-delimited JSON)
|
||||||
|
for line in output.lines() {
|
||||||
|
let line = line.trim();
|
||||||
|
if line.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
match serde_json::from_str::<Value>(line) {
|
||||||
|
Ok(match_obj) => {
|
||||||
|
if matches.len() >= max_matches {
|
||||||
|
truncated = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
matches.push(match_obj);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
debug!("Failed to parse JSON line '{}': {}", line, e);
|
||||||
|
// Try to parse the entire output as a single JSON array
|
||||||
|
match serde_json::from_str::<Vec<Value>>(output) {
|
||||||
|
Ok(array_matches) => {
|
||||||
|
let take_count = array_matches.len().min(max_matches);
|
||||||
|
let total_count = array_matches.len();
|
||||||
|
matches = array_matches.into_iter().take(take_count).collect();
|
||||||
|
truncated = take_count < total_count;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Err(e2) => {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"Failed to parse ast-grep output as line-delimited JSON or JSON array. Line error: {}, Array error: {}",
|
||||||
|
e, e2
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok((matches, truncated))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if ast-grep is available and provide installation hints if not
|
||||||
|
async fn check_ast_grep_available() -> Result<()> {
|
||||||
|
match Command::new("ast-grep")
|
||||||
|
.arg("--version")
|
||||||
|
.output()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(output) => {
|
||||||
|
if output.status.success() {
|
||||||
|
let version = String::from_utf8_lossy(&output.stdout);
|
||||||
|
info!("Found ast-grep: {}", version.trim());
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(anyhow!("ast-grep command failed: {}", String::from_utf8_lossy(&output.stderr)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
Err(anyhow!(
|
||||||
|
"ast-grep not found. Please install it using one of these methods:\n\n\
|
||||||
|
• Homebrew (macOS): brew install ast-grep\n\
|
||||||
|
• MacPorts (macOS): sudo port install ast-grep\n\
|
||||||
|
• Nix: nix-env -iA nixpkgs.ast-grep\n\
|
||||||
|
• Cargo: cargo install ast-grep\n\
|
||||||
|
• npm: npm install -g @ast-grep/cli\n\
|
||||||
|
• pip: pip install ast-grep\n\n\
|
||||||
|
For more installation options, visit: https://ast-grep.github.io/guide/quick-start.html"
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_pattern_search() {
|
||||||
|
let search = SearchSpec {
|
||||||
|
name: "test".to_string(),
|
||||||
|
mode: SearchMode::Pattern,
|
||||||
|
pattern: Some("fn $NAME() {}".to_string()),
|
||||||
|
language: Some("rust".to_string()),
|
||||||
|
rule_yaml: None,
|
||||||
|
paths: None,
|
||||||
|
globs: None,
|
||||||
|
json_style: None,
|
||||||
|
context: None,
|
||||||
|
threads: None,
|
||||||
|
include_metadata: None,
|
||||||
|
no_ignore: None,
|
||||||
|
severity: None,
|
||||||
|
timeout_secs: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(validate_search_spec(&search).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_yaml_search() {
|
||||||
|
let yaml_rule = r#"
|
||||||
|
id: test-rule
|
||||||
|
language: Rust
|
||||||
|
rule:
|
||||||
|
pattern: "fn $NAME() {}"
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let search = SearchSpec {
|
||||||
|
name: "test".to_string(),
|
||||||
|
mode: SearchMode::Yaml,
|
||||||
|
pattern: None,
|
||||||
|
language: None,
|
||||||
|
rule_yaml: Some(yaml_rule.to_string()),
|
||||||
|
paths: None,
|
||||||
|
globs: None,
|
||||||
|
json_style: None,
|
||||||
|
context: None,
|
||||||
|
threads: None,
|
||||||
|
include_metadata: None,
|
||||||
|
no_ignore: None,
|
||||||
|
severity: None,
|
||||||
|
timeout_secs: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(validate_search_spec(&search).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_pattern_command() {
|
||||||
|
let search = SearchSpec {
|
||||||
|
name: "test".to_string(),
|
||||||
|
mode: SearchMode::Pattern,
|
||||||
|
pattern: Some("fn $NAME() {}".to_string()),
|
||||||
|
language: Some("rust".to_string()),
|
||||||
|
rule_yaml: None,
|
||||||
|
paths: Some(vec!["src/".to_string()]),
|
||||||
|
globs: None,
|
||||||
|
json_style: Some(JsonStyle::Stream),
|
||||||
|
context: Some(2),
|
||||||
|
threads: Some(4),
|
||||||
|
include_metadata: None,
|
||||||
|
no_ignore: None,
|
||||||
|
severity: None,
|
||||||
|
timeout_secs: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let cmd = build_ast_grep_command(&search).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(cmd[0], "ast-grep");
|
||||||
|
assert_eq!(cmd[1], "run");
|
||||||
|
assert!(cmd.contains(&"-p".to_string()));
|
||||||
|
assert!(cmd.contains(&"fn $NAME() {}".to_string()));
|
||||||
|
assert!(cmd.contains(&"-l".to_string()));
|
||||||
|
assert!(cmd.contains(&"rust".to_string()));
|
||||||
|
assert!(cmd.contains(&"--json=stream".to_string()));
|
||||||
|
assert!(cmd.contains(&"-C".to_string()));
|
||||||
|
assert!(cmd.contains(&"2".to_string()));
|
||||||
|
assert!(cmd.contains(&"-j".to_string()));
|
||||||
|
assert!(cmd.contains(&"4".to_string()));
|
||||||
|
assert!(cmd.contains(&"src/".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_stream_json() {
|
||||||
|
let output = r#"{"file":"test.rs","text":"fn hello() {}"}
|
||||||
|
{"file":"test2.rs","text":"fn world() {}"}"#;
|
||||||
|
|
||||||
|
let (matches, truncated) = parse_ast_grep_output(output, 10).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(matches.len(), 2);
|
||||||
|
assert!(!truncated);
|
||||||
|
assert_eq!(matches[0]["file"], "test.rs");
|
||||||
|
assert_eq!(matches[1]["file"], "test2.rs");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_truncated_output() {
|
||||||
|
let output = r#"{"file":"test1.rs","text":"fn a() {}"}
|
||||||
|
{"file":"test2.rs","text":"fn b() {}"}
|
||||||
|
{"file":"test3.rs","text":"fn c() {}"}"#;
|
||||||
|
|
||||||
|
let (matches, truncated) = parse_ast_grep_output(output, 2).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(matches.len(), 2);
|
||||||
|
assert!(truncated);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,6 +4,11 @@
|
|||||||
// 3. Only elide JSON content between first '{' and last '}' (inclusive)
|
// 3. Only elide JSON content between first '{' and last '}' (inclusive)
|
||||||
// 4. Return everything else as the final filtered string
|
// 4. Return everything else as the final filtered string
|
||||||
|
|
||||||
|
//! JSON tool call filtering for streaming LLM responses.
|
||||||
|
//!
|
||||||
|
//! This module filters out JSON tool calls from LLM output streams while preserving
|
||||||
|
//! regular text content. It uses a state machine to handle streaming chunks.
|
||||||
|
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
use std::cell::RefCell;
|
use std::cell::RefCell;
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
@@ -13,37 +18,51 @@ thread_local! {
|
|||||||
static FIXED_JSON_TOOL_STATE: RefCell<FixedJsonToolState> = RefCell::new(FixedJsonToolState::new());
|
static FIXED_JSON_TOOL_STATE: RefCell<FixedJsonToolState> = RefCell::new(FixedJsonToolState::new());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Internal state for tracking JSON tool call filtering across streaming chunks.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct FixedJsonToolState {
|
struct FixedJsonToolState {
|
||||||
|
/// True when actively suppressing a confirmed tool call
|
||||||
suppression_mode: bool,
|
suppression_mode: bool,
|
||||||
|
/// True when buffering potential JSON (saw { but not yet confirmed as tool call)
|
||||||
|
potential_json_mode: bool,
|
||||||
|
/// Tracks nesting depth of braces within JSON
|
||||||
brace_depth: i32,
|
brace_depth: i32,
|
||||||
buffer: String,
|
buffer: String,
|
||||||
json_start_in_buffer: Option<usize>,
|
json_start_in_buffer: Option<usize>, // Position where confirmed JSON tool call starts
|
||||||
content_returned_up_to: usize, // Track how much content we've already returned
|
content_returned_up_to: usize, // Track how much content we've already returned
|
||||||
|
potential_json_start: Option<usize>, // Where the potential JSON started
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FixedJsonToolState {
|
impl FixedJsonToolState {
|
||||||
fn new() -> Self {
|
fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
suppression_mode: false,
|
suppression_mode: false,
|
||||||
|
potential_json_mode: false,
|
||||||
brace_depth: 0,
|
brace_depth: 0,
|
||||||
buffer: String::new(),
|
buffer: String::new(),
|
||||||
json_start_in_buffer: None,
|
json_start_in_buffer: None,
|
||||||
content_returned_up_to: 0,
|
content_returned_up_to: 0,
|
||||||
|
potential_json_start: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reset(&mut self) {
|
fn reset(&mut self) {
|
||||||
self.suppression_mode = false;
|
self.suppression_mode = false;
|
||||||
|
self.potential_json_mode = false;
|
||||||
self.brace_depth = 0;
|
self.brace_depth = 0;
|
||||||
self.buffer.clear();
|
self.buffer.clear();
|
||||||
self.json_start_in_buffer = None;
|
self.json_start_in_buffer = None;
|
||||||
self.content_returned_up_to = 0;
|
self.content_returned_up_to = 0;
|
||||||
|
self.potential_json_start = None;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// FINAL CORRECTED implementation according to specification
|
// FINAL CORRECTED implementation according to specification
|
||||||
|
|
||||||
|
/// Filters JSON tool calls from streaming LLM content.
|
||||||
|
///
|
||||||
|
/// Processes content chunks and removes JSON tool calls while preserving regular text.
|
||||||
|
/// Maintains state across calls to handle tool calls spanning multiple chunks.
|
||||||
pub fn fixed_filter_json_tool_calls(content: &str) -> String {
|
pub fn fixed_filter_json_tool_calls(content: &str) -> String {
|
||||||
if content.is_empty() {
|
if content.is_empty() {
|
||||||
return String::new();
|
return String::new();
|
||||||
@@ -87,13 +106,225 @@ pub fn fixed_filter_json_tool_calls(content: &str) -> String {
|
|||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CRITICAL FIX: After counting braces, if still in suppression mode,
|
||||||
|
// check if a new tool call pattern appears. This handles truncated JSON
|
||||||
|
// followed by complete JSON.
|
||||||
|
if state.suppression_mode {
|
||||||
|
let current_json_start = state.json_start_in_buffer.unwrap();
|
||||||
|
// Don't require newline - the new JSON might be concatenated directly
|
||||||
|
let tool_call_regex = Regex::new(r#"\{\s*"tool"\s*:\s*""#).unwrap();
|
||||||
|
|
||||||
|
// Look for new tool call patterns after the current one
|
||||||
|
if let Some(captures) = tool_call_regex.find(&state.buffer[current_json_start + 1..]) {
|
||||||
|
let new_json_start = current_json_start + 1 + captures.start() + captures.as_str().find('{').unwrap();
|
||||||
|
|
||||||
|
debug!("Detected new tool call at position {} while processing incomplete one at {} - discarding old", new_json_start, current_json_start);
|
||||||
|
|
||||||
|
// The previous JSON was incomplete/malformed
|
||||||
|
// Return content before the old JSON (if any)
|
||||||
|
let content_before_old_json = if current_json_start > state.content_returned_up_to {
|
||||||
|
state.buffer[state.content_returned_up_to..current_json_start].to_string()
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Update state to skip the incomplete JSON and position at the new one
|
||||||
|
// We'll process the new JSON on the next call
|
||||||
|
state.content_returned_up_to = new_json_start;
|
||||||
|
state.suppression_mode = false;
|
||||||
|
state.json_start_in_buffer = None;
|
||||||
|
state.brace_depth = 0;
|
||||||
|
|
||||||
|
return content_before_old_json;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Still in suppression mode, return empty string (content is being accumulated)
|
// Still in suppression mode, return empty string (content is being accumulated)
|
||||||
return String::new();
|
return String::new();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if we're in potential JSON mode (saw { but waiting to confirm it's a tool call)
|
||||||
|
if state.potential_json_mode {
|
||||||
|
// Check if the buffer contains a confirmed tool call pattern
|
||||||
|
let tool_call_regex = Regex::new(r#"(?m)^\s*\{\s*"tool"\s*:\s*""#).unwrap();
|
||||||
|
|
||||||
|
if let Some(captures) = tool_call_regex.find(&state.buffer) {
|
||||||
|
// Confirmed! This is a tool call - enter suppression mode
|
||||||
|
let match_text = captures.as_str();
|
||||||
|
if let Some(brace_offset) = match_text.find('{') {
|
||||||
|
let json_start = captures.start() + brace_offset;
|
||||||
|
|
||||||
|
debug!("Confirmed JSON tool call at position {} - entering suppression mode", json_start);
|
||||||
|
|
||||||
|
state.potential_json_mode = false;
|
||||||
|
state.suppression_mode = true;
|
||||||
|
state.brace_depth = 0;
|
||||||
|
state.json_start_in_buffer = Some(json_start);
|
||||||
|
|
||||||
|
// Count braces from json_start to see if JSON is complete
|
||||||
|
let buffer_slice = state.buffer[json_start..].to_string();
|
||||||
|
for ch in buffer_slice.chars() {
|
||||||
|
match ch {
|
||||||
|
'{' => state.brace_depth += 1,
|
||||||
|
'}' => {
|
||||||
|
state.brace_depth -= 1;
|
||||||
|
if state.brace_depth <= 0 {
|
||||||
|
debug!("JSON tool call completed immediately");
|
||||||
|
let result = extract_fixed_content(&state.buffer, json_start);
|
||||||
|
let new_content = if result.len() > state.content_returned_up_to {
|
||||||
|
result[state.content_returned_up_to..].to_string()
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
};
|
||||||
|
state.reset();
|
||||||
|
return new_content;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// JSON incomplete, stay in suppression mode, return nothing
|
||||||
|
return String::new();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we can rule out this being a tool call
|
||||||
|
// If we have enough content after the { and it doesn't match the pattern, release it
|
||||||
|
if let Some(potential_start) = state.potential_json_start {
|
||||||
|
let content_after_brace = &state.buffer[potential_start..];
|
||||||
|
|
||||||
|
// Rule out as a tool call if:
|
||||||
|
// 1. Closing } appears before we see the full pattern
|
||||||
|
// 2. Content clearly doesn't match the tool call pattern
|
||||||
|
// 3. Newline appears after the opening brace (tool calls should be compact)
|
||||||
|
|
||||||
|
let has_closing_brace = content_after_brace.contains('}');
|
||||||
|
let has_newline = content_after_brace[1..].contains('\n'); // Skip first char which is {
|
||||||
|
let long_enough = content_after_brace.len() >= 10;
|
||||||
|
|
||||||
|
// Detect non-tool JSON patterns:
|
||||||
|
// - { followed by " and a key that doesn't start with "tool"
|
||||||
|
// - { followed by "t" but not "to"
|
||||||
|
// - { followed by "to" but not "too", etc.
|
||||||
|
let not_tool_pattern = Regex::new(r#"^\{\s*"(?:[^t]|t(?:[^o]|o(?:[^o]|o(?:[^l]|l[^"\s:]))))"#).unwrap();
|
||||||
|
let definitely_not_tool = not_tool_pattern.is_match(content_after_brace);
|
||||||
|
|
||||||
|
if has_closing_brace || has_newline || (long_enough && definitely_not_tool) {
|
||||||
|
debug!("Potential JSON ruled out - not a tool call");
|
||||||
|
state.potential_json_mode = false;
|
||||||
|
state.potential_json_start = None;
|
||||||
|
|
||||||
|
// Return the buffered content we've been holding
|
||||||
|
let new_content = if state.buffer.len() > state.content_returned_up_to {
|
||||||
|
state.buffer[state.content_returned_up_to..].to_string()
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
};
|
||||||
|
state.content_returned_up_to = state.buffer.len();
|
||||||
|
return new_content;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Still in potential mode, keep buffering
|
||||||
|
return String::new();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect potential JSON start: { at the beginning of a line
|
||||||
|
let potential_json_regex = Regex::new(r"(?m)^\s*\{\s*").unwrap();
|
||||||
|
|
||||||
|
if let Some(captures) = potential_json_regex.find(&state.buffer[state.content_returned_up_to..]) {
|
||||||
|
let match_start = state.content_returned_up_to + captures.start();
|
||||||
|
let brace_pos = match_start + captures.as_str().find('{').unwrap();
|
||||||
|
|
||||||
|
debug!("Potential JSON detected at position {} - entering buffering mode", brace_pos);
|
||||||
|
|
||||||
|
// Fast path: check if this is already a confirmed tool call
|
||||||
|
let tool_call_regex = Regex::new(r#"(?m)^\s*\{\s*"tool"\s*:\s*""#).unwrap();
|
||||||
|
if tool_call_regex.is_match(&state.buffer[brace_pos..]) {
|
||||||
|
// This is a confirmed tool call! Process it immediately
|
||||||
|
let json_start = brace_pos;
|
||||||
|
debug!("Immediately confirmed tool call at position {}", json_start);
|
||||||
|
|
||||||
|
// Return content before JSON
|
||||||
|
let content_before = if json_start > state.content_returned_up_to {
|
||||||
|
state.buffer[state.content_returned_up_to..json_start].to_string()
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
state.content_returned_up_to = json_start;
|
||||||
|
state.suppression_mode = true;
|
||||||
|
state.brace_depth = 0;
|
||||||
|
state.json_start_in_buffer = Some(json_start);
|
||||||
|
|
||||||
|
// Count braces to see if JSON is complete
|
||||||
|
let buffer_slice = state.buffer[json_start..].to_string();
|
||||||
|
for ch in buffer_slice.chars() {
|
||||||
|
match ch {
|
||||||
|
'{' => state.brace_depth += 1,
|
||||||
|
'}' => {
|
||||||
|
state.brace_depth -= 1;
|
||||||
|
if state.brace_depth <= 0 {
|
||||||
|
debug!("JSON tool call completed in same chunk");
|
||||||
|
let result = extract_fixed_content(&state.buffer, json_start);
|
||||||
|
let content_after = if result.len() > json_start {
|
||||||
|
&result[json_start..]
|
||||||
|
} else {
|
||||||
|
""
|
||||||
|
};
|
||||||
|
let final_result = format!("{}{}", content_before, content_after);
|
||||||
|
state.reset();
|
||||||
|
return final_result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// JSON incomplete, return content before and stay in suppression mode
|
||||||
|
return content_before;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return content before the potential JSON
|
||||||
|
let content_before = if brace_pos > state.content_returned_up_to {
|
||||||
|
state.buffer[state.content_returned_up_to..brace_pos].to_string()
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
state.content_returned_up_to = brace_pos;
|
||||||
|
state.potential_json_mode = true;
|
||||||
|
state.potential_json_start = Some(brace_pos);
|
||||||
|
|
||||||
|
// Optimization: immediately check if we can rule this out for single-chunk processing
|
||||||
|
let content_after_brace = &state.buffer[brace_pos..];
|
||||||
|
let has_closing_brace = content_after_brace.contains('}');
|
||||||
|
let has_newline = content_after_brace.len() > 1 && content_after_brace[1..].contains('\n');
|
||||||
|
let long_enough = content_after_brace.len() >= 10;
|
||||||
|
|
||||||
|
let not_tool_pattern = Regex::new(r#"^\{\s*"(?:[^t]|t(?:[^o]|o(?:[^o]|o(?:[^l]|l[^"\s:]))))"#).unwrap();
|
||||||
|
let definitely_not_tool = not_tool_pattern.is_match(content_after_brace);
|
||||||
|
|
||||||
|
if has_closing_brace || has_newline || (long_enough && definitely_not_tool) {
|
||||||
|
debug!("Immediately ruled out as not a tool call");
|
||||||
|
state.potential_json_mode = false;
|
||||||
|
state.potential_json_start = None;
|
||||||
|
|
||||||
|
// Return all the buffered content
|
||||||
|
let new_content = if state.buffer.len() > state.content_returned_up_to {
|
||||||
|
state.buffer[state.content_returned_up_to..].to_string()
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
};
|
||||||
|
state.content_returned_up_to = state.buffer.len();
|
||||||
|
return format!("{}{}", content_before, new_content);
|
||||||
|
}
|
||||||
|
|
||||||
|
return content_before;
|
||||||
|
}
|
||||||
|
|
||||||
// Check for tool call pattern using corrected regex
|
// Check for tool call pattern using corrected regex
|
||||||
// More flexible than the strict specification to handle real-world JSON
|
let tool_call_regex = Regex::new(r#"(?m)^\s*\{\s*"tool"\s*:\s*"[^"]*""#).unwrap();
|
||||||
let tool_call_regex = Regex::new(r#"(?m)^\s*\{\s*"tool"\s*:\s*""#).unwrap();
|
|
||||||
|
|
||||||
if let Some(captures) = tool_call_regex.find(&state.buffer) {
|
if let Some(captures) = tool_call_regex.find(&state.buffer) {
|
||||||
let match_text = captures.as_str();
|
let match_text = captures.as_str();
|
||||||
@@ -168,9 +399,17 @@ pub fn fixed_filter_json_tool_calls(content: &str) -> String {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to extract content with JSON tool call filtered out
|
/// Extracts content from buffer, removing the JSON tool call.
|
||||||
// Returns everything except the JSON between the first '{' and last '}' (inclusive)
|
///
|
||||||
|
/// Given a buffer and the start position of a JSON tool call, this function:
|
||||||
|
/// 1. Extracts all content before the JSON
|
||||||
|
/// 2. Finds the end of the JSON (matching closing brace)
|
||||||
|
/// 3. Extracts all content after the JSON
|
||||||
|
/// 4. Returns the concatenation of before + after (JSON removed)
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `full_content` - The full content buffer
|
||||||
|
/// * `json_start` - Position where the JSON tool call begins
|
||||||
fn extract_fixed_content(full_content: &str, json_start: usize) -> String {
|
fn extract_fixed_content(full_content: &str, json_start: usize) -> String {
|
||||||
// Find the end of the JSON using proper brace counting with string handling
|
// Find the end of the JSON using proper brace counting with string handling
|
||||||
let mut brace_depth = 0;
|
let mut brace_depth = 0;
|
||||||
@@ -212,8 +451,10 @@ fn extract_fixed_content(full_content: &str, json_start: usize) -> String {
|
|||||||
format!("{}{}", before, after)
|
format!("{}{}", before, after)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset function for testing
|
/// Resets the global JSON filtering state.
|
||||||
|
///
|
||||||
|
/// Call this between independent filtering sessions to ensure clean state.
|
||||||
|
/// This is particularly important in tests and when starting new conversations.
|
||||||
pub fn reset_fixed_json_tool_state() {
|
pub fn reset_fixed_json_tool_state() {
|
||||||
FIXED_JSON_TOOL_STATE.with(|state| {
|
FIXED_JSON_TOOL_STATE.with(|state| {
|
||||||
let mut state = state.borrow_mut();
|
let mut state = state.borrow_mut();
|
||||||
|
|||||||
@@ -1,8 +1,14 @@
|
|||||||
|
//! Tests for JSON tool call filtering.
|
||||||
|
//!
|
||||||
|
//! These tests verify that the filter correctly identifies and removes JSON tool calls
|
||||||
|
//! from LLM output streams while preserving all other content.
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod fixed_filter_tests {
|
mod fixed_filter_tests {
|
||||||
use crate::fixed_filter_json::{fixed_filter_json_tool_calls, reset_fixed_json_tool_state};
|
use crate::fixed_filter_json::{fixed_filter_json_tool_calls, reset_fixed_json_tool_state};
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
|
|
||||||
|
/// Test that regular text without tool calls passes through unchanged.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_no_tool_call_passthrough() {
|
fn test_no_tool_call_passthrough() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -11,6 +17,7 @@ mod fixed_filter_tests {
|
|||||||
assert_eq!(result, input);
|
assert_eq!(result, input);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test detection and removal of a complete tool call in a single chunk.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_simple_tool_call_detection() {
|
fn test_simple_tool_call_detection() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -23,6 +30,7 @@ Some text after"#;
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test handling of tool calls that arrive across multiple streaming chunks.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_streaming_chunks() {
|
fn test_streaming_chunks() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -48,6 +56,7 @@ Some text after"#;
|
|||||||
assert_eq!(final_result, expected);
|
assert_eq!(final_result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test correct handling of nested braces within JSON strings.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_nested_braces_in_tool_call() {
|
fn test_nested_braces_in_tool_call() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -61,6 +70,7 @@ Text after"#;
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Verify the regex pattern matches the specification with flexible whitespace.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_regex_pattern_specification() {
|
fn test_regex_pattern_specification() {
|
||||||
// Test the corrected regex pattern that's more flexible with whitespace
|
// Test the corrected regex pattern that's more flexible with whitespace
|
||||||
@@ -84,11 +94,6 @@ Text after"#;
|
|||||||
), // Space after { DOES match with \s*
|
), // Space after { DOES match with \s*
|
||||||
(
|
(
|
||||||
r#"line
|
r#"line
|
||||||
abc{"tool":"#,
|
|
||||||
true,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r#"line
|
|
||||||
{"tool123":"#,
|
{"tool123":"#,
|
||||||
false,
|
false,
|
||||||
), // "tool123" is not exactly "tool"
|
), // "tool123" is not exactly "tool"
|
||||||
@@ -109,6 +114,7 @@ abc{"tool":"#,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test that tool calls must appear at the start of a line (after newline).
|
||||||
#[test]
|
#[test]
|
||||||
fn test_newline_requirement() {
|
fn test_newline_requirement() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -122,13 +128,14 @@ abc{"tool":"#,
|
|||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
let result2 = fixed_filter_json_tool_calls(input_without_newline);
|
let result2 = fixed_filter_json_tool_calls(input_without_newline);
|
||||||
|
|
||||||
// Both cases currently trigger suppression due to regex pattern
|
// With the new aggressive filtering, only the newline case should trigger suppression
|
||||||
// TODO: Fix regex to only match after actual newlines
|
// The pattern requires { to be at the start of a line (after ^)
|
||||||
assert_eq!(result1, "Text\n");
|
assert_eq!(result1, "Text\n");
|
||||||
// This currently fails because our regex matches both cases
|
// Without newline before {, it should pass through unchanged
|
||||||
assert_eq!(result2, "Text ");
|
assert_eq!(result2, input_without_newline);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test handling of escaped quotes within JSON strings.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_json_with_escaped_quotes() {
|
fn test_json_with_escaped_quotes() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -142,6 +149,7 @@ More text"#;
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test graceful handling of incomplete/malformed JSON.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_edge_case_malformed_json() {
|
fn test_edge_case_malformed_json() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -157,6 +165,7 @@ More text"#;
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test processing multiple independent tool calls sequentially.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_multiple_tool_calls_sequential() {
|
fn test_multiple_tool_calls_sequential() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -179,6 +188,7 @@ Final text"#;
|
|||||||
assert_eq!(result2, expected2);
|
assert_eq!(result2, expected2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test tool calls with complex multi-line arguments.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_tool_call_with_complex_args() {
|
fn test_tool_call_with_complex_args() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -192,6 +202,7 @@ After"#;
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test input containing only a tool call with no surrounding text.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_tool_call_only() {
|
fn test_tool_call_only() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -204,6 +215,7 @@ After"#;
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test accurate brace counting with deeply nested structures.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_brace_counting_accuracy() {
|
fn test_brace_counting_accuracy() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -218,6 +230,7 @@ End"#;
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test that braces within strings don't affect brace counting.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_string_escaping_in_json() {
|
fn test_string_escaping_in_json() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -232,6 +245,7 @@ More"#;
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Verify compliance with the exact specification requirements.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_specification_compliance() {
|
fn test_specification_compliance() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -248,6 +262,7 @@ More"#;
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test that non-tool JSON objects are not filtered.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_no_false_positives() {
|
fn test_no_false_positives() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -261,6 +276,7 @@ More text"#;
|
|||||||
assert_eq!(result, input);
|
assert_eq!(result, input);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test patterns that look similar to tool calls but aren't exact matches.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_partial_tool_patterns() {
|
fn test_partial_tool_patterns() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -280,6 +296,7 @@ More text"#;
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test streaming with very small chunks (character-by-character).
|
||||||
#[test]
|
#[test]
|
||||||
fn test_streaming_edge_cases() {
|
fn test_streaming_edge_cases() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -296,12 +313,13 @@ More text"#;
|
|||||||
}
|
}
|
||||||
|
|
||||||
let final_result: String = results.join("");
|
let final_result: String = results.join("");
|
||||||
// This test currently fails because the JSON is incomplete across chunks
|
// With the new aggressive filtering, the JSON should be completely filtered out
|
||||||
// The function doesn't handle this edge case properly yet
|
// even when it arrives in very small chunks
|
||||||
let expected = "Text\n{\"tool\": \nAfter";
|
let expected = "Text\n\nAfter";
|
||||||
assert_eq!(final_result, expected);
|
assert_eq!(final_result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Debug test with detailed logging for streaming behavior.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_streaming_debug() {
|
fn test_streaming_debug() {
|
||||||
reset_fixed_json_tool_state();
|
reset_fixed_json_tool_state();
|
||||||
@@ -329,4 +347,38 @@ More text"#;
|
|||||||
let expected = "Some text before\n\nText after";
|
let expected = "Some text before\n\nText after";
|
||||||
assert_eq!(final_result, expected);
|
assert_eq!(final_result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test handling of truncated JSON followed by complete JSON (the json_err pattern)
|
||||||
|
#[test]
|
||||||
|
fn test_truncated_then_complete_json() {
|
||||||
|
reset_fixed_json_tool_state();
|
||||||
|
|
||||||
|
// Simulate the pattern from json_err trace:
|
||||||
|
// 1. Incomplete/truncated JSON appears
|
||||||
|
// 2. Then the same complete JSON appears
|
||||||
|
let chunks = vec![
|
||||||
|
"Some text\n",
|
||||||
|
r#"{"tool": "str_replace", "args": {"diff":"...","file_path":"./crates/g3-cli"#, // Truncated
|
||||||
|
r#"{"tool": "str_replace", "args": {"diff":"...","file_path":"./crates/g3-cli/src/lib.rs"}}"#, // Complete
|
||||||
|
"\nMore text",
|
||||||
|
];
|
||||||
|
|
||||||
|
let mut results = Vec::new();
|
||||||
|
for (i, chunk) in chunks.iter().enumerate() {
|
||||||
|
let result = fixed_filter_json_tool_calls(chunk);
|
||||||
|
println!("Chunk {}: {:?} -> {:?}", i, chunk, result);
|
||||||
|
results.push(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
let final_result: String = results.join("");
|
||||||
|
println!("Final result: {:?}", final_result);
|
||||||
|
|
||||||
|
// The truncated JSON should be discarded when the complete one appears
|
||||||
|
// Both JSONs should be filtered out, leaving only the text
|
||||||
|
let expected = "Some text\n\nMore text";
|
||||||
|
assert_eq!(
|
||||||
|
final_result, expected,
|
||||||
|
"Failed to handle truncated JSON followed by complete JSON"
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ pub mod error_handling;
|
|||||||
pub mod project;
|
pub mod project;
|
||||||
pub mod task_result;
|
pub mod task_result;
|
||||||
pub mod ui_writer;
|
pub mod ui_writer;
|
||||||
|
pub mod code_search;
|
||||||
pub use task_result::TaskResult;
|
pub use task_result::TaskResult;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -680,6 +681,8 @@ pub struct Agent<W: UiWriter> {
|
|||||||
providers: ProviderRegistry,
|
providers: ProviderRegistry,
|
||||||
context_window: ContextWindow,
|
context_window: ContextWindow,
|
||||||
thinning_events: Vec<usize>, // chars saved per thinning event
|
thinning_events: Vec<usize>, // chars saved per thinning event
|
||||||
|
pending_90_summarization: bool, // flag to trigger summarization at 90%
|
||||||
|
auto_compact: bool, // whether to auto-compact at 90% before tool calls
|
||||||
summarization_events: Vec<usize>, // chars saved per summarization event
|
summarization_events: Vec<usize>, // chars saved per summarization event
|
||||||
first_token_times: Vec<Duration>, // time to first token for each completion
|
first_token_times: Vec<Duration>, // time to first token for each completion
|
||||||
config: Config,
|
config: Config,
|
||||||
@@ -785,7 +788,6 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
// Register embedded provider if configured AND it's the default provider
|
// Register embedded provider if configured AND it's the default provider
|
||||||
if let Some(embedded_config) = &config.providers.embedded {
|
if let Some(embedded_config) = &config.providers.embedded {
|
||||||
if providers_to_register.contains(&"embedded".to_string()) {
|
if providers_to_register.contains(&"embedded".to_string()) {
|
||||||
info!("Initializing embedded provider");
|
|
||||||
let embedded_provider = g3_providers::EmbeddedProvider::new(
|
let embedded_provider = g3_providers::EmbeddedProvider::new(
|
||||||
embedded_config.model_path.clone(),
|
embedded_config.model_path.clone(),
|
||||||
embedded_config.model_type.clone(),
|
embedded_config.model_type.clone(),
|
||||||
@@ -796,15 +798,12 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
embedded_config.threads,
|
embedded_config.threads,
|
||||||
)?;
|
)?;
|
||||||
providers.register(embedded_provider);
|
providers.register(embedded_provider);
|
||||||
} else {
|
|
||||||
info!("Embedded provider configured but not needed, skipping initialization");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register OpenAI provider if configured AND it's the default provider
|
// Register OpenAI provider if configured AND it's the default provider
|
||||||
if let Some(openai_config) = &config.providers.openai {
|
if let Some(openai_config) = &config.providers.openai {
|
||||||
if providers_to_register.contains(&"openai".to_string()) {
|
if providers_to_register.contains(&"openai".to_string()) {
|
||||||
info!("Initializing OpenAI provider");
|
|
||||||
let openai_provider = g3_providers::OpenAIProvider::new(
|
let openai_provider = g3_providers::OpenAIProvider::new(
|
||||||
openai_config.api_key.clone(),
|
openai_config.api_key.clone(),
|
||||||
Some(openai_config.model.clone()),
|
Some(openai_config.model.clone()),
|
||||||
@@ -813,15 +812,12 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
openai_config.temperature,
|
openai_config.temperature,
|
||||||
)?;
|
)?;
|
||||||
providers.register(openai_provider);
|
providers.register(openai_provider);
|
||||||
} else {
|
|
||||||
info!("OpenAI provider configured but not needed, skipping initialization");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register Anthropic provider if configured AND it's the default provider
|
// Register Anthropic provider if configured AND it's the default provider
|
||||||
if let Some(anthropic_config) = &config.providers.anthropic {
|
if let Some(anthropic_config) = &config.providers.anthropic {
|
||||||
if providers_to_register.contains(&"anthropic".to_string()) {
|
if providers_to_register.contains(&"anthropic".to_string()) {
|
||||||
info!("Initializing Anthropic provider");
|
|
||||||
let anthropic_provider = g3_providers::AnthropicProvider::new(
|
let anthropic_provider = g3_providers::AnthropicProvider::new(
|
||||||
anthropic_config.api_key.clone(),
|
anthropic_config.api_key.clone(),
|
||||||
Some(anthropic_config.model.clone()),
|
Some(anthropic_config.model.clone()),
|
||||||
@@ -829,15 +825,12 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
anthropic_config.temperature,
|
anthropic_config.temperature,
|
||||||
)?;
|
)?;
|
||||||
providers.register(anthropic_provider);
|
providers.register(anthropic_provider);
|
||||||
} else {
|
|
||||||
info!("Anthropic provider configured but not needed, skipping initialization");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register Databricks provider if configured AND it's the default provider
|
// Register Databricks provider if configured AND it's the default provider
|
||||||
if let Some(databricks_config) = &config.providers.databricks {
|
if let Some(databricks_config) = &config.providers.databricks {
|
||||||
if providers_to_register.contains(&"databricks".to_string()) {
|
if providers_to_register.contains(&"databricks".to_string()) {
|
||||||
info!("Initializing Databricks provider");
|
|
||||||
|
|
||||||
let databricks_provider = if let Some(token) = &databricks_config.token {
|
let databricks_provider = if let Some(token) = &databricks_config.token {
|
||||||
// Use token-based authentication
|
// Use token-based authentication
|
||||||
@@ -860,8 +853,19 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
providers.register(databricks_provider);
|
providers.register(databricks_provider);
|
||||||
} else {
|
}
|
||||||
info!("Databricks provider configured but not needed, skipping initialization");
|
}
|
||||||
|
|
||||||
|
// Register Ollama provider if configured AND it's the default provider
|
||||||
|
if let Some(ollama_config) = &config.providers.ollama {
|
||||||
|
if providers_to_register.contains(&"ollama".to_string()) {
|
||||||
|
let ollama_provider = g3_providers::OllamaProvider::new(
|
||||||
|
ollama_config.model.clone(),
|
||||||
|
ollama_config.base_url.clone(),
|
||||||
|
ollama_config.max_tokens,
|
||||||
|
ollama_config.temperature,
|
||||||
|
)?;
|
||||||
|
providers.register(ollama_provider);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -884,16 +888,12 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
content: readme,
|
content: readme,
|
||||||
};
|
};
|
||||||
context_window.add_message(readme_message);
|
context_window.add_message(readme_message);
|
||||||
info!("Added project README to context window");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize computer controller if enabled
|
// Initialize computer controller if enabled
|
||||||
let computer_controller = if config.computer_control.enabled {
|
let computer_controller = if config.computer_control.enabled {
|
||||||
match g3_computer_control::create_controller() {
|
match g3_computer_control::create_controller() {
|
||||||
Ok(controller) => {
|
Ok(controller) => Some(controller),
|
||||||
info!("Computer control enabled");
|
|
||||||
Some(controller)
|
|
||||||
}
|
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!("Failed to initialize computer control: {}", e);
|
warn!("Failed to initialize computer control: {}", e);
|
||||||
None
|
None
|
||||||
@@ -909,6 +909,8 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
Ok(Self {
|
Ok(Self {
|
||||||
providers,
|
providers,
|
||||||
context_window,
|
context_window,
|
||||||
|
auto_compact: config.agent.auto_compact,
|
||||||
|
pending_90_summarization: false,
|
||||||
thinning_events: Vec::new(),
|
thinning_events: Vec::new(),
|
||||||
summarization_events: Vec::new(),
|
summarization_events: Vec::new(),
|
||||||
first_token_times: Vec::new(),
|
first_token_times: Vec::new(),
|
||||||
@@ -973,6 +975,30 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
16384 // Conservative default for other Databricks models
|
16384 // Conservative default for other Databricks models
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
"ollama" => {
|
||||||
|
// Ollama model context windows based on model name
|
||||||
|
if model_name.contains("qwen3-coder") {
|
||||||
|
262144 // Qwen3-coder supports 256k context
|
||||||
|
} else if model_name.contains("qwen") {
|
||||||
|
32768 // Qwen2.5 supports 32k context
|
||||||
|
} else if model_name.contains("gpt-oss") {
|
||||||
|
131072 // GPT-OSS supports 128k context
|
||||||
|
} else if model_name.contains("llama3") || model_name.contains("llama-3") {
|
||||||
|
if model_name.contains("3.2") || model_name.contains("3.1") {
|
||||||
|
128000 // Llama 3.1/3.2 support 128k context
|
||||||
|
} else {
|
||||||
|
8192 // Llama 3.0
|
||||||
|
}
|
||||||
|
} else if model_name.contains("mistral") || model_name.contains("mixtral") {
|
||||||
|
32768 // Mistral/Mixtral support 32k
|
||||||
|
} else if model_name.contains("gemma") {
|
||||||
|
8192 // Gemma 2
|
||||||
|
} else if model_name.contains("phi") {
|
||||||
|
4096 // Phi-3
|
||||||
|
} else {
|
||||||
|
8192 // Conservative default for Ollama models
|
||||||
|
}
|
||||||
|
}
|
||||||
_ => config.agent.max_context_length as u32,
|
_ => config.agent.max_context_length as u32,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -1104,6 +1130,18 @@ IMPORTANT: You must call tools to achieve goals. When you receive a request:
|
|||||||
For shell commands: Use the shell tool with the exact command needed. Avoid commands that produce a large amount of output, and consider piping those outputs to files. Example: If asked to list files, immediately call the shell tool with command parameter \"ls\".
|
For shell commands: Use the shell tool with the exact command needed. Avoid commands that produce a large amount of output, and consider piping those outputs to files. Example: If asked to list files, immediately call the shell tool with command parameter \"ls\".
|
||||||
If you create temporary files for verification, place these in a subdir named 'tmp'. Do NOT pollute the current dir.
|
If you create temporary files for verification, place these in a subdir named 'tmp'. Do NOT pollute the current dir.
|
||||||
|
|
||||||
|
For reading files, prioritize use of code_search tool use with multiple search requests per call instead of read_file, if it makes sense.
|
||||||
|
|
||||||
|
Additional examples for the 'code_search' tool:
|
||||||
|
- Example for pattern mode: {\"tool\": \"code_search\", \"args\": {\"searches\": [{\"name\": \"find_functions\", \"mode\": \"pattern\", \"pattern\": \"fn $NAME($$$ARGS) { $$$ }\", \"language\": \"rust\", \"paths\": [\"src/\"]}]}}
|
||||||
|
- Example for YAML mode: {\"tool\": \"code_search\", \"args\": {\"searches\": [{\"name\": \"find_async\", \"mode\": \"yaml\", \"rule_yaml\": \"id: async-fn\nlanguage: Rust\nrule:\n pattern: async fn $NAME($$$) { $$$ }\"}]}}
|
||||||
|
- Example for multiple searches: {\"tool\": \"code_search\", \"args\": {\"searches\": [{\"name\": \"funcs\", \"mode\": \"pattern\", \"pattern\": \"fn $NAME\", \"language\": \"rust\"}, {\"name\": \"structs\", \"mode\": \"pattern\", \"pattern\": \"struct $NAME\", \"language\": \"rust\"}]}}
|
||||||
|
- Example for passing optional args like \"context\": {\"tool\": \"code_search\", \"args\": {\"searches\": [{\"name\": \"funcs\", \"mode\": \"pattern\", \"context\": 3, \"pattern\": \"fn $NAME\", \"language\": \"rust\"}]}
|
||||||
|
- Common optional args for searches:
|
||||||
|
- \"context\": 3 (show surrounding lines),
|
||||||
|
- \"json_style\": \"stream\" (for large results)
|
||||||
|
|
||||||
|
|
||||||
IMPORTANT: If the user asks you to just respond with text (like \"just say hello\" or \"tell me about X\"), do NOT use tools. Simply respond with the requested text directly. Only use tools when you need to execute commands or complete tasks that require action.
|
IMPORTANT: If the user asks you to just respond with text (like \"just say hello\" or \"tell me about X\"), do NOT use tools. Simply respond with the requested text directly. Only use tools when you need to execute commands or complete tasks that require action.
|
||||||
|
|
||||||
When taking screenshots of specific windows (like \"my Safari window\" or \"my terminal\"), ALWAYS use list_windows first to identify the correct window ID, then use take_screenshot with the window_id parameter.
|
When taking screenshots of specific windows (like \"my Safari window\" or \"my terminal\"), ALWAYS use list_windows first to identify the correct window ID, then use take_screenshot with the window_id parameter.
|
||||||
@@ -1153,6 +1191,8 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
|
|
||||||
# Available Tools
|
# Available Tools
|
||||||
|
|
||||||
|
Short description for providers without native calling specs:
|
||||||
|
|
||||||
- **shell**: Execute shell commands
|
- **shell**: Execute shell commands
|
||||||
- Format: {\"tool\": \"shell\", \"args\": {\"command\": \"your_command_here\"}
|
- Format: {\"tool\": \"shell\", \"args\": {\"command\": \"your_command_here\"}
|
||||||
- Example: {\"tool\": \"shell\", \"args\": {\"command\": \"ls ~/Downloads\"}
|
- Example: {\"tool\": \"shell\", \"args\": {\"command\": \"ls ~/Downloads\"}
|
||||||
@@ -1181,13 +1221,41 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
- Format: {\"tool\": \"todo_write\", \"args\": {\"content\": \"- [ ] Task 1\\n- [ ] Task 2\"}}
|
- Format: {\"tool\": \"todo_write\", \"args\": {\"content\": \"- [ ] Task 1\\n- [ ] Task 2\"}}
|
||||||
- Example: {\"tool\": \"todo_write\", \"args\": {\"content\": \"- [ ] Implement feature\\n - [ ] Write tests\\n - [ ] Run tests\"}}
|
- Example: {\"tool\": \"todo_write\", \"args\": {\"content\": \"- [ ] Implement feature\\n - [ ] Write tests\\n - [ ] Run tests\"}}
|
||||||
|
|
||||||
|
- **code_search**: Batch syntax-aware searches via ast-grep. Supports up to 20 pattern or YAML-rule searches in parallel.
|
||||||
|
- Format: {\"tool\": \"code_search\", \"args\": {\"searches\": [{\"name\": \"search_label\", \"mode\": \"pattern|yaml\", ...}], \"max_concurrency\": 4, \"max_matches_per_search\": 500}}
|
||||||
|
- Example for pattern mode: {\"tool\": \"code_search\", \"args\": {\"searches\": [{\"name\": \"find_functions\", \"mode\": \"pattern\", \"pattern\": \"fn $NAME($$$ARGS) { $$$ }\", \"language\": \"rust\", \"paths\": [\"src/\"]}]}}
|
||||||
|
- Example for YAML mode: {\"tool\": \"code_search\", \"args\": {\"searches\": [{\"name\": \"find_async\", \"mode\": \"yaml\", \"rule_yaml\": \"id: async-fn\nlanguage: Rust\nrule:\n pattern: async fn $NAME($$$) { $$$ }\"}]}}
|
||||||
|
- Example for multiple searches: {\"tool\": \"code_search\", \"args\": {\"searches\": [{\"name\": \"funcs\", \"mode\": \"pattern\", \"pattern\": \"fn $NAME\", \"language\": \"rust\"}, {\"name\": \"structs\", \"mode\": \"pattern\", \"pattern\": \"struct $NAME\", \"language\": \"rust\"}]}}
|
||||||
|
- Example for passing optional args like \"context\": {\"tool\": \"code_search\", \"args\": {\"searches\": [{\"name\": \"funcs\", \"mode\": \"pattern\", \"context\": 3, \"pattern\": \"fn $NAME\", \"language\": \"rust\"}]}
|
||||||
|
- Common optional args for searches:
|
||||||
|
- \"context\": 3 (show surrounding lines),
|
||||||
|
- \"json_style\": \"stream\" (for large results)
|
||||||
|
|
||||||
# Instructions
|
# Instructions
|
||||||
|
|
||||||
1. Analyze the request and break down into smaller tasks if appropriate
|
1. Analyze the request and break down into smaller tasks if appropriate
|
||||||
2. Execute ONE tool at a time
|
2. Execute ONE tool at a time. An exception exists for when you're writing files. See below.
|
||||||
3. STOP when the original request was satisfied
|
3. STOP when the original request was satisfied
|
||||||
4. Call the final_output tool when done
|
4. Call the final_output tool when done
|
||||||
|
|
||||||
|
For reading files, prioritize use of code_search tool use with multiple search requests per call instead of read_file, if it makes sense.
|
||||||
|
|
||||||
|
Exception to using ONE tool at a time:
|
||||||
|
If all you’re doing is WRITING files, and you don’t need to do anything else between each step.
|
||||||
|
You can issue MULTIPLE write_file tool calls in a request, however you may ONLY make a SINGLE write_file call for any file in that request.
|
||||||
|
For example you may call:
|
||||||
|
[START OF REQUEST]
|
||||||
|
write_file(\"helper.rs\", \"...\")
|
||||||
|
write_file(\"file2.txt\", \"...\")
|
||||||
|
[DONE]
|
||||||
|
|
||||||
|
But NOT:
|
||||||
|
[START OF REQUEST]
|
||||||
|
write_file(\"helper.rs\", \"...\")
|
||||||
|
write_file(\"file2.txt\", \"...\")
|
||||||
|
write_file(\"helper.rs\", \"...\")
|
||||||
|
[DONE]
|
||||||
|
|
||||||
# Task Management
|
# Task Management
|
||||||
|
|
||||||
Use todo_read and todo_write for tasks with 3+ steps, multiple files/components, or uncertain scope.
|
Use todo_read and todo_write for tasks with 3+ steps, multiple files/components, or uncertain scope.
|
||||||
@@ -1312,6 +1380,19 @@ Template:
|
|||||||
// Save context window at the end of successful interaction
|
// Save context window at the end of successful interaction
|
||||||
self.save_context_window("completed");
|
self.save_context_window("completed");
|
||||||
|
|
||||||
|
// Check if we need to do 90% auto-compaction
|
||||||
|
if self.pending_90_summarization {
|
||||||
|
self.ui_writer.print_context_status(
|
||||||
|
"\n⚡ Context window reached 90% - auto-compacting...\n"
|
||||||
|
);
|
||||||
|
if let Err(e) = self.force_summarize().await {
|
||||||
|
warn!("Failed to auto-compact at 90%: {}", e);
|
||||||
|
} else {
|
||||||
|
self.ui_writer.println("");
|
||||||
|
}
|
||||||
|
self.pending_90_summarization = false;
|
||||||
|
}
|
||||||
|
|
||||||
// Return the task result which already includes timing if needed
|
// Return the task result which already includes timing if needed
|
||||||
Ok(task_result)
|
Ok(task_result)
|
||||||
}
|
}
|
||||||
@@ -1860,6 +1941,64 @@ Template:
|
|||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
|
// Add code_search tool
|
||||||
|
tools.push(Tool {
|
||||||
|
name: "code_search".to_string(),
|
||||||
|
description: "Batch syntax-aware searches via ast-grep. Supports up to 20 pattern or YAML-rule searches in parallel; returns JSON matches (stream-collated).".to_string(),
|
||||||
|
input_schema: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"searches": {
|
||||||
|
"type": "array",
|
||||||
|
"maxItems": 20,
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": { "type": "string", "description": "Label for this search." },
|
||||||
|
"mode": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["pattern", "yaml"],
|
||||||
|
"description": "`pattern` uses `ast-grep run`; `yaml` uses `ast-grep scan --inline-rules`."
|
||||||
|
},
|
||||||
|
// pattern mode (fast one-off)
|
||||||
|
"pattern": { "type": "string", "description": "ast-grep pattern code (e.g., \"async fn $NAME($$$ARGS) { $$$ }\")"},
|
||||||
|
"language": { "type": "string", "description": "Optional language for pattern mode; ast-grep may infer from file extensions if omitted." },
|
||||||
|
// yaml mode (full rule object)
|
||||||
|
"rule_yaml": { "type": "string", "description": "A full YAML rule object text. Must include `id`, `language`, and `rule`." },
|
||||||
|
// targeting
|
||||||
|
"paths": { "type": "array", "items": { "type": "string" }, "description": "Paths/dirs to search. Defaults to current dir if empty." },
|
||||||
|
"globs": { "type": "array", "items": { "type": "string" }, "description": "Optional include/exclude globs for CLI --globs." },
|
||||||
|
// result formatting & performance knobs
|
||||||
|
"json_style": { "type": "string", "enum": ["pretty","stream","compact"], "default": "stream", "description": "Use stream for large codebases." },
|
||||||
|
"context": { "type": "integer", "minimum": 0, "maximum": 20, "default": 0, "description": "CLI -C context lines in text output; also affects JSON `lines` field." },
|
||||||
|
"threads": { "type": "integer", "minimum": 1, "description": "Optional override for ast-grep -j (per process)." },
|
||||||
|
"include_metadata": { "type": "boolean", "default": false, "description": "If yaml mode and rule has metadata, add --include-metadata." },
|
||||||
|
// robustness
|
||||||
|
"no_ignore": {
|
||||||
|
"type": "array",
|
||||||
|
"items": { "type": "string", "enum": ["hidden","dot","exclude","global","parent","vcs"] },
|
||||||
|
"description": "Forwarded to --no-ignore to bypass ignore files/hidden."
|
||||||
|
},
|
||||||
|
// severity overrides for yaml mode
|
||||||
|
"severity": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": { "type": "string", "enum": ["error","warning","info","hint","off"] },
|
||||||
|
"description": "Optional map<ruleId, severity> -> passed via --error/--warning/--info/--hint/--off."
|
||||||
|
},
|
||||||
|
// per-search timeout seconds (default 60)
|
||||||
|
"timeout_secs": { "type": "integer", "minimum": 1, "default": 60 }
|
||||||
|
},
|
||||||
|
"required": ["name","mode"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
// global concurrency & truncation
|
||||||
|
"max_concurrency": { "type": "integer", "minimum": 1, "default": 4 },
|
||||||
|
"max_matches_per_search": { "type": "integer", "minimum": 1, "default": 500 }
|
||||||
|
},
|
||||||
|
"required": ["searches"]
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
// Add WebDriver tools if enabled
|
// Add WebDriver tools if enabled
|
||||||
if enable_webdriver {
|
if enable_webdriver {
|
||||||
@@ -2550,6 +2689,14 @@ Template:
|
|||||||
if let Some(tool_call) = completed_tools.into_iter().next() {
|
if let Some(tool_call) = completed_tools.into_iter().next() {
|
||||||
debug!("Processing completed tool call: {:?}", tool_call);
|
debug!("Processing completed tool call: {:?}", tool_call);
|
||||||
|
|
||||||
|
// Check if we should auto-compact at 90% BEFORE executing the tool
|
||||||
|
// We need to do this before any borrows of self
|
||||||
|
if self.auto_compact && self.context_window.percentage_used() >= 90.0 {
|
||||||
|
// Set flag to trigger summarization after this turn completes
|
||||||
|
// We can't do it now due to borrow checker constraints
|
||||||
|
self.pending_90_summarization = true;
|
||||||
|
}
|
||||||
|
|
||||||
// Check if we should thin the context BEFORE executing the tool
|
// Check if we should thin the context BEFORE executing the tool
|
||||||
if self.context_window.should_thin() {
|
if self.context_window.should_thin() {
|
||||||
let (thin_summary, chars_saved) = self.context_window.thin_context();
|
let (thin_summary, chars_saved) = self.context_window.thin_context();
|
||||||
@@ -2558,6 +2705,7 @@ Template:
|
|||||||
self.ui_writer.print_context_thinning(&thin_summary);
|
self.ui_writer.print_context_thinning(&thin_summary);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Track what we've already displayed before getting new text
|
// Track what we've already displayed before getting new text
|
||||||
// This prevents re-displaying old content after tool execution
|
// This prevents re-displaying old content after tool execution
|
||||||
let already_displayed_chars = current_response.chars().count();
|
let already_displayed_chars = current_response.chars().count();
|
||||||
@@ -2596,7 +2744,8 @@ Template:
|
|||||||
String::new()
|
String::new()
|
||||||
};
|
};
|
||||||
|
|
||||||
if !new_content.trim().is_empty() {
|
// Don't display text before final_output - it will be in the summary
|
||||||
|
if !new_content.trim().is_empty() && tool_call.tool != "final_output" {
|
||||||
#[allow(unused_assignments)]
|
#[allow(unused_assignments)]
|
||||||
if !response_started {
|
if !response_started {
|
||||||
self.ui_writer.print_agent_prompt();
|
self.ui_writer.print_agent_prompt();
|
||||||
@@ -2674,7 +2823,13 @@ Template:
|
|||||||
));
|
));
|
||||||
|
|
||||||
// Display tool execution result with proper indentation
|
// Display tool execution result with proper indentation
|
||||||
if tool_call.tool != "final_output" {
|
if tool_call.tool == "final_output" {
|
||||||
|
// For final_output, display the summary without truncation
|
||||||
|
for line in tool_result.lines() {
|
||||||
|
self.ui_writer.update_tool_output_line(line);
|
||||||
|
}
|
||||||
|
self.ui_writer.println("");
|
||||||
|
} else {
|
||||||
let output_lines: Vec<&str> = tool_result.lines().collect();
|
let output_lines: Vec<&str> = tool_result.lines().collect();
|
||||||
|
|
||||||
// Check if UI wants full output (machine mode) or truncated (human mode)
|
// Check if UI wants full output (machine mode) or truncated (human mode)
|
||||||
@@ -2722,13 +2877,9 @@ Template:
|
|||||||
|
|
||||||
// Check if this was a final_output tool call
|
// Check if this was a final_output tool call
|
||||||
if tool_call.tool == "final_output" {
|
if tool_call.tool == "final_output" {
|
||||||
// Don't add final_display_content here - it was already added before tool execution
|
// The summary was displayed above when we printed the tool result
|
||||||
// Adding it again would duplicate the output
|
// Add it to full_response so it's included in the TaskResult
|
||||||
if let Some(summary) = tool_call.args.get("summary") {
|
full_response.push_str(&tool_result);
|
||||||
if let Some(summary_str) = summary.as_str() {
|
|
||||||
full_response.push_str(&format!("\n\n{}", summary_str));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
self.ui_writer.println("");
|
self.ui_writer.println("");
|
||||||
let _ttft =
|
let _ttft =
|
||||||
first_token_time.unwrap_or_else(|| stream_start.elapsed());
|
first_token_time.unwrap_or_else(|| stream_start.elapsed());
|
||||||
@@ -2760,7 +2911,7 @@ Template:
|
|||||||
|
|
||||||
// Add the tool call and result to the context window using RAW unfiltered content
|
// Add the tool call and result to the context window using RAW unfiltered content
|
||||||
// This ensures the log file contains the true raw content including JSON tool calls
|
// This ensures the log file contains the true raw content including JSON tool calls
|
||||||
let tool_message = if !full_response.contains(final_display_content) && !raw_content_for_log.trim().is_empty() {
|
let tool_message = if !raw_content_for_log.trim().is_empty() {
|
||||||
Message {
|
Message {
|
||||||
role: MessageRole::Assistant,
|
role: MessageRole::Assistant,
|
||||||
content: format!(
|
content: format!(
|
||||||
@@ -2771,7 +2922,7 @@ Template:
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// If we've already added the text or there's no text, just include the tool call
|
// No text content before tool call, just include the tool call
|
||||||
Message {
|
Message {
|
||||||
role: MessageRole::Assistant,
|
role: MessageRole::Assistant,
|
||||||
content: format!(
|
content: format!(
|
||||||
@@ -2796,18 +2947,22 @@ Template:
|
|||||||
request.tools = Some(Self::create_tool_definitions(self.config.webdriver.enabled, self.config.macax.enabled, self.config.computer_control.enabled));
|
request.tools = Some(Self::create_tool_definitions(self.config.webdriver.enabled, self.config.macax.enabled, self.config.computer_control.enabled));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only add to full_response if we haven't already added it
|
// DO NOT add final_display_content to full_response here!
|
||||||
if !full_response.contains(final_display_content) {
|
// The content was already displayed during streaming and added to current_response.
|
||||||
full_response.push_str(final_display_content);
|
// Adding it again would cause duplication when the agent message is printed.
|
||||||
}
|
// The only time we should add to full_response is:
|
||||||
|
// 1. For final_output tool (handled separately)
|
||||||
|
// 2. At the end when no tools were executed (handled in the "no tool executed" branch)
|
||||||
|
|
||||||
tool_executed = true;
|
tool_executed = true;
|
||||||
|
|
||||||
// Reset the JSON tool call filter state after each tool execution
|
// Reset the JSON tool call filter state after each tool execution
|
||||||
// This ensures the filter doesn't stay in suppression mode for subsequent streaming content
|
// This ensures the filter doesn't stay in suppression mode for subsequent streaming content
|
||||||
fixed_filter_json::reset_fixed_json_tool_state();
|
fixed_filter_json::reset_fixed_json_tool_state();
|
||||||
|
|
||||||
// Reset parser for next iteration
|
// Reset parser for next iteration - this clears the text buffer
|
||||||
parser.reset();
|
parser.reset();
|
||||||
|
|
||||||
// Clear current_response for next iteration to prevent buffered text
|
// Clear current_response for next iteration to prevent buffered text
|
||||||
// from being incorrectly displayed after tool execution
|
// from being incorrectly displayed after tool execution
|
||||||
current_response.clear();
|
current_response.clear();
|
||||||
@@ -2883,7 +3038,8 @@ Template:
|
|||||||
"Using filtered parser text as last resort: {} chars",
|
"Using filtered parser text as last resort: {} chars",
|
||||||
filtered_text.len()
|
filtered_text.len()
|
||||||
);
|
);
|
||||||
current_response = filtered_text;
|
// Note: This assignment is currently unused but kept for potential future use
|
||||||
|
let _ = filtered_text;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2985,11 +3141,10 @@ Template:
|
|||||||
|
|
||||||
// Set full_response to current_response (don't append)
|
// Set full_response to current_response (don't append)
|
||||||
// current_response already contains everything that was displayed
|
// current_response already contains everything that was displayed
|
||||||
// Appending would duplicate the output
|
// Don't set full_response here - it would duplicate the output
|
||||||
if !current_response.is_empty() && full_response.is_empty() {
|
// The text was already displayed during streaming
|
||||||
full_response = current_response.clone();
|
// Return empty string to avoid duplication
|
||||||
debug!("Set full_response from current_response (no tool): {} chars", full_response.len());
|
full_response = String::new();
|
||||||
}
|
|
||||||
|
|
||||||
self.ui_writer.println("");
|
self.ui_writer.println("");
|
||||||
let _ttft =
|
let _ttft =
|
||||||
@@ -3017,17 +3172,33 @@ Template:
|
|||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// Capture detailed streaming error information
|
// Capture detailed streaming error information
|
||||||
let error_details =
|
let error_msg = e.to_string();
|
||||||
format!("Streaming error at chunk {}: {}", chunks_received + 1, e);
|
let error_details = format!("Streaming error at chunk {}: {}", chunks_received + 1, error_msg);
|
||||||
error!("{}", error_details);
|
|
||||||
error!("Error type: {}", std::any::type_name_of_val(&e));
|
error!("Error type: {}", std::any::type_name_of_val(&e));
|
||||||
error!("Parser state at error: text_buffer_len={}, native_tool_calls={}, message_stopped={}",
|
error!("Parser state at error: text_buffer_len={}, native_tool_calls={}, message_stopped={}",
|
||||||
parser.text_buffer_len(), parser.native_tool_calls.len(), parser.is_message_stopped());
|
parser.text_buffer_len(), parser.native_tool_calls.len(), parser.is_message_stopped());
|
||||||
|
|
||||||
// Store the error for potential logging later
|
// Store the error for potential logging later
|
||||||
_last_error = Some(error_details);
|
_last_error = Some(error_details.clone());
|
||||||
|
|
||||||
|
// Check if this is a recoverable connection error
|
||||||
|
let is_connection_error = error_msg.contains("unexpected EOF")
|
||||||
|
|| error_msg.contains("connection")
|
||||||
|
|| error_msg.contains("chunk size line")
|
||||||
|
|| error_msg.contains("body error");
|
||||||
|
|
||||||
|
if is_connection_error {
|
||||||
|
warn!("Connection error at chunk {}, treating as end of stream", chunks_received + 1);
|
||||||
|
// If we have any content or tool calls, treat this as a graceful end
|
||||||
|
if chunks_received > 0 && (!parser.get_text_content().is_empty() || parser.native_tool_calls.len() > 0) {
|
||||||
|
warn!("Stream terminated unexpectedly but we have content, continuing");
|
||||||
|
break; // Break to process what we have
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if tool_executed {
|
if tool_executed {
|
||||||
|
error!("{}", error_details);
|
||||||
warn!("Stream error after tool execution, attempting to continue");
|
warn!("Stream error after tool execution, attempting to continue");
|
||||||
break; // Break to outer loop to start new stream
|
break; // Break to outer loop to start new stream
|
||||||
} else {
|
} else {
|
||||||
@@ -4431,6 +4602,41 @@ Template:
|
|||||||
Ok("❌ Computer control not enabled. Set computer_control.enabled = true in config.".to_string())
|
Ok("❌ Computer control not enabled. Set computer_control.enabled = true in config.".to_string())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
"code_search" => {
|
||||||
|
debug!("Processing code_search tool call");
|
||||||
|
|
||||||
|
// Parse the request
|
||||||
|
let request: crate::code_search::CodeSearchRequest = match serde_json::from_value(tool_call.args.clone()) {
|
||||||
|
Ok(req) => req,
|
||||||
|
Err(e) => {
|
||||||
|
return Ok(format!("❌ Invalid code_search arguments: {}", e));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Execute the code search
|
||||||
|
match crate::code_search::execute_code_search(request).await {
|
||||||
|
Ok(response) => {
|
||||||
|
// Serialize the response to JSON
|
||||||
|
match serde_json::to_string_pretty(&response) {
|
||||||
|
Ok(json_output) => {
|
||||||
|
Ok(format!("✅ Code search completed\n{}", json_output))
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
Ok(format!("❌ Failed to serialize response: {}", e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
// Check if it's an ast-grep not found error and provide helpful message
|
||||||
|
let error_msg = e.to_string();
|
||||||
|
if error_msg.contains("ast-grep not found") {
|
||||||
|
Ok(format!("❌ {}", error_msg))
|
||||||
|
} else {
|
||||||
|
Ok(format!("❌ Code search failed: {}", error_msg))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
_ => {
|
_ => {
|
||||||
warn!("Unknown tool: {}", tool_call.tool);
|
warn!("Unknown tool: {}", tool_call.tool);
|
||||||
Ok(format!("❓ Unknown tool: {}", tool_call.tool))
|
Ok(format!("❓ Unknown tool: {}", tool_call.tool))
|
||||||
|
|||||||
@@ -276,6 +276,7 @@ impl AnthropicProvider {
|
|||||||
let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
|
let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
|
||||||
let mut accumulated_usage: Option<Usage> = None;
|
let mut accumulated_usage: Option<Usage> = None;
|
||||||
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
|
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
|
||||||
|
let mut message_stopped = false; // Track if we've received message_stop
|
||||||
|
|
||||||
while let Some(chunk_result) = stream.next().await {
|
while let Some(chunk_result) = stream.next().await {
|
||||||
match chunk_result {
|
match chunk_result {
|
||||||
@@ -316,6 +317,12 @@ impl AnthropicProvider {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we've already sent the final chunk, skip processing more events
|
||||||
|
if message_stopped {
|
||||||
|
debug!("Skipping event after message_stop: {}", line);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// Parse Server-Sent Events format
|
// Parse Server-Sent Events format
|
||||||
if let Some(data) = line.strip_prefix("data: ") {
|
if let Some(data) = line.strip_prefix("data: ") {
|
||||||
if data == "[DONE]" {
|
if data == "[DONE]" {
|
||||||
@@ -451,6 +458,7 @@ impl AnthropicProvider {
|
|||||||
}
|
}
|
||||||
"message_stop" => {
|
"message_stop" => {
|
||||||
debug!("Received message stop event");
|
debug!("Received message stop event");
|
||||||
|
message_stopped = true;
|
||||||
let final_chunk = CompletionChunk {
|
let final_chunk = CompletionChunk {
|
||||||
content: String::new(),
|
content: String::new(),
|
||||||
finished: true,
|
finished: true,
|
||||||
@@ -460,7 +468,8 @@ impl AnthropicProvider {
|
|||||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||||
debug!("Receiver dropped, stopping stream");
|
debug!("Receiver dropped, stopping stream");
|
||||||
}
|
}
|
||||||
return accumulated_usage;
|
// Don't return here - let the stream naturally exhaust
|
||||||
|
// This prevents dropping the sender prematurely
|
||||||
}
|
}
|
||||||
"error" => {
|
"error" => {
|
||||||
if let Some(error) = event.error {
|
if let Some(error) = event.error {
|
||||||
@@ -468,7 +477,7 @@ impl AnthropicProvider {
|
|||||||
let _ = tx
|
let _ = tx
|
||||||
.send(Err(anyhow!("Anthropic API error: {:?}", error)))
|
.send(Err(anyhow!("Anthropic API error: {:?}", error)))
|
||||||
.await;
|
.await;
|
||||||
return accumulated_usage;
|
break; // Break to let stream exhaust naturally
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
@@ -487,7 +496,10 @@ impl AnthropicProvider {
|
|||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Stream error: {}", e);
|
error!("Stream error: {}", e);
|
||||||
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
|
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
|
||||||
return accumulated_usage;
|
// Don't return here either - let the stream exhaust naturally
|
||||||
|
// The error has been sent to the receiver, so it will handle it
|
||||||
|
// Breaking here ensures we clean up properly
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -298,6 +298,7 @@ impl DatabricksProvider {
|
|||||||
let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> =
|
let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> =
|
||||||
std::collections::HashMap::new(); // index -> (id, name, args)
|
std::collections::HashMap::new(); // index -> (id, name, args)
|
||||||
let mut incomplete_data_line = String::new(); // Buffer for incomplete data: lines
|
let mut incomplete_data_line = String::new(); // Buffer for incomplete data: lines
|
||||||
|
let mut chunk_count = 0;
|
||||||
let accumulated_usage: Option<Usage> = None;
|
let accumulated_usage: Option<Usage> = None;
|
||||||
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
|
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
|
||||||
|
|
||||||
@@ -305,6 +306,8 @@ impl DatabricksProvider {
|
|||||||
match chunk_result {
|
match chunk_result {
|
||||||
Ok(chunk) => {
|
Ok(chunk) => {
|
||||||
// Debug: Log raw bytes received
|
// Debug: Log raw bytes received
|
||||||
|
chunk_count += 1;
|
||||||
|
debug!("Processing chunk #{}", chunk_count);
|
||||||
debug!("Raw SSE bytes received: {} bytes", chunk.len());
|
debug!("Raw SSE bytes received: {} bytes", chunk.len());
|
||||||
|
|
||||||
// Append new bytes to our buffer
|
// Append new bytes to our buffer
|
||||||
@@ -589,13 +592,39 @@ impl DatabricksProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Stream error: {}", e);
|
error!("Stream error at chunk {}: {}", chunk_count, e);
|
||||||
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
|
|
||||||
|
// Check if this is a connection error that might be recoverable
|
||||||
|
let error_msg = e.to_string();
|
||||||
|
if error_msg.contains("unexpected EOF") || error_msg.contains("connection") {
|
||||||
|
warn!("Connection terminated unexpectedly at chunk {}, treating as end of stream", chunk_count);
|
||||||
|
// Don't send error, just break and finalize
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
|
||||||
|
}
|
||||||
return accumulated_usage;
|
return accumulated_usage;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Log final state
|
||||||
|
debug!("Stream ended after {} chunks", chunk_count);
|
||||||
|
debug!("Final state: buffer_len={}, incomplete_data_line_len={}, byte_buffer_len={}",
|
||||||
|
buffer.len(), incomplete_data_line.len(), byte_buffer.len());
|
||||||
|
debug!("Accumulated tool calls: {}", current_tool_calls.len());
|
||||||
|
|
||||||
|
// If we have any remaining data in buffers, log it for debugging
|
||||||
|
if !buffer.is_empty() {
|
||||||
|
debug!("Remaining buffer content: {:?}", buffer);
|
||||||
|
}
|
||||||
|
if !byte_buffer.is_empty() {
|
||||||
|
debug!("Remaining byte buffer: {} bytes", byte_buffer.len());
|
||||||
|
}
|
||||||
|
if !incomplete_data_line.is_empty() {
|
||||||
|
debug!("Remaining incomplete data line: {:?}", incomplete_data_line);
|
||||||
|
}
|
||||||
|
|
||||||
// If we have any incomplete data line at the end, try to process it
|
// If we have any incomplete data line at the end, try to process it
|
||||||
if !incomplete_data_line.is_empty() {
|
if !incomplete_data_line.is_empty() {
|
||||||
debug!(
|
debug!(
|
||||||
|
|||||||
@@ -88,11 +88,13 @@ pub mod anthropic;
|
|||||||
pub mod databricks;
|
pub mod databricks;
|
||||||
pub mod embedded;
|
pub mod embedded;
|
||||||
pub mod oauth;
|
pub mod oauth;
|
||||||
|
pub mod ollama;
|
||||||
pub mod openai;
|
pub mod openai;
|
||||||
|
|
||||||
pub use anthropic::AnthropicProvider;
|
pub use anthropic::AnthropicProvider;
|
||||||
pub use databricks::DatabricksProvider;
|
pub use databricks::DatabricksProvider;
|
||||||
pub use embedded::EmbeddedProvider;
|
pub use embedded::EmbeddedProvider;
|
||||||
|
pub use ollama::OllamaProvider;
|
||||||
pub use openai::OpenAIProvider;
|
pub use openai::OpenAIProvider;
|
||||||
|
|
||||||
/// Provider registry for managing multiple LLM providers
|
/// Provider registry for managing multiple LLM providers
|
||||||
|
|||||||
751
crates/g3-providers/src/ollama.rs
Normal file
751
crates/g3-providers/src/ollama.rs
Normal file
@@ -0,0 +1,751 @@
|
|||||||
|
//! Ollama LLM provider implementation for the g3-providers crate.
|
||||||
|
//!
|
||||||
|
//! This module provides an implementation of the `LLMProvider` trait for Ollama,
|
||||||
|
//! supporting both completion and streaming modes with native tool calling.
|
||||||
|
//!
|
||||||
|
//! # Features
|
||||||
|
//!
|
||||||
|
//! - Support for any Ollama model (llama3.2, mistral, qwen, etc.)
|
||||||
|
//! - Both completion and streaming response modes
|
||||||
|
//! - Native tool calling support for compatible models
|
||||||
|
//! - Configurable base URL (defaults to http://localhost:11434)
|
||||||
|
//! - Simple configuration with no authentication required
|
||||||
|
//!
|
||||||
|
//! # Usage
|
||||||
|
//!
|
||||||
|
//! ```rust,no_run
|
||||||
|
//! use g3_providers::{OllamaProvider, LLMProvider, CompletionRequest, Message, MessageRole};
|
||||||
|
//!
|
||||||
|
//! #[tokio::main]
|
||||||
|
//! async fn main() -> anyhow::Result<()> {
|
||||||
|
//! // Create the provider with default settings (localhost:11434)
|
||||||
|
//! let provider = OllamaProvider::new(
|
||||||
|
//! "llama3.2".to_string(),
|
||||||
|
//! None, // Optional: base_url
|
||||||
|
//! None, // Optional: max tokens
|
||||||
|
//! None, // Optional: temperature
|
||||||
|
//! )?;
|
||||||
|
//!
|
||||||
|
//! // Create a completion request
|
||||||
|
//! let request = CompletionRequest {
|
||||||
|
//! messages: vec![
|
||||||
|
//! Message {
|
||||||
|
//! role: MessageRole::User,
|
||||||
|
//! content: "Hello! How are you?".to_string(),
|
||||||
|
//! },
|
||||||
|
//! ],
|
||||||
|
//! max_tokens: Some(1000),
|
||||||
|
//! temperature: Some(0.7),
|
||||||
|
//! stream: false,
|
||||||
|
//! tools: None,
|
||||||
|
//! };
|
||||||
|
//!
|
||||||
|
//! // Get a completion
|
||||||
|
//! let response = provider.complete(request).await?;
|
||||||
|
//! println!("Response: {}", response.content);
|
||||||
|
//!
|
||||||
|
//! Ok(())
|
||||||
|
//! }
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use bytes::Bytes;
|
||||||
|
use futures_util::stream::StreamExt;
|
||||||
|
use reqwest::Client;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tokio_stream::wrappers::ReceiverStream;
|
||||||
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
|
||||||
|
MessageRole, Tool, ToolCall, Usage,
|
||||||
|
};
|
||||||
|
|
||||||
|
const DEFAULT_BASE_URL: &str = "http://localhost:11434";
|
||||||
|
const DEFAULT_TIMEOUT_SECS: u64 = 600;
|
||||||
|
|
||||||
|
pub const OLLAMA_DEFAULT_MODEL: &str = "llama3.2";
|
||||||
|
pub const OLLAMA_KNOWN_MODELS: &[&str] = &[
|
||||||
|
"llama3.2",
|
||||||
|
"llama3.2:1b",
|
||||||
|
"llama3.2:3b",
|
||||||
|
"llama3.1",
|
||||||
|
"llama3.1:8b",
|
||||||
|
"llama3.1:70b",
|
||||||
|
"mistral",
|
||||||
|
"mistral-nemo",
|
||||||
|
"mixtral",
|
||||||
|
"qwen2.5",
|
||||||
|
"qwen2.5:7b",
|
||||||
|
"qwen2.5:14b",
|
||||||
|
"qwen2.5:32b",
|
||||||
|
"qwen2.5-coder",
|
||||||
|
"qwen2.5-coder:7b",
|
||||||
|
"qwen3-coder",
|
||||||
|
"phi3",
|
||||||
|
"gemma2",
|
||||||
|
];
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct OllamaProvider {
|
||||||
|
client: Client,
|
||||||
|
base_url: String,
|
||||||
|
model: String,
|
||||||
|
max_tokens: Option<u32>,
|
||||||
|
temperature: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OllamaProvider {
|
||||||
|
pub fn new(
|
||||||
|
model: String,
|
||||||
|
base_url: Option<String>,
|
||||||
|
max_tokens: Option<u32>,
|
||||||
|
temperature: Option<f32>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let client = Client::builder()
|
||||||
|
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
|
||||||
|
.build()
|
||||||
|
.map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
|
||||||
|
|
||||||
|
let base_url = base_url
|
||||||
|
.unwrap_or_else(|| DEFAULT_BASE_URL.to_string())
|
||||||
|
.trim_end_matches('/')
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Initialized Ollama provider with model: {} at {}",
|
||||||
|
model, base_url
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
client,
|
||||||
|
base_url,
|
||||||
|
model,
|
||||||
|
max_tokens,
|
||||||
|
temperature: temperature.unwrap_or(0.7),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_tools(&self, tools: &[Tool]) -> Vec<OllamaTool> {
|
||||||
|
tools
|
||||||
|
.iter()
|
||||||
|
.map(|tool| OllamaTool {
|
||||||
|
r#type: "function".to_string(),
|
||||||
|
function: OllamaFunction {
|
||||||
|
name: tool.name.clone(),
|
||||||
|
description: tool.description.clone(),
|
||||||
|
parameters: tool.input_schema.clone(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_messages(&self, messages: &[Message]) -> Result<Vec<OllamaMessage>> {
|
||||||
|
let mut ollama_messages = Vec::new();
|
||||||
|
|
||||||
|
for message in messages {
|
||||||
|
let role = match message.role {
|
||||||
|
MessageRole::System => "system",
|
||||||
|
MessageRole::User => "user",
|
||||||
|
MessageRole::Assistant => "assistant",
|
||||||
|
};
|
||||||
|
|
||||||
|
ollama_messages.push(OllamaMessage {
|
||||||
|
role: role.to_string(),
|
||||||
|
content: message.content.clone(),
|
||||||
|
tool_calls: None, // Only used in responses
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if ollama_messages.is_empty() {
|
||||||
|
return Err(anyhow!("At least one message is required"));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ollama_messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_request_body(
|
||||||
|
&self,
|
||||||
|
messages: &[Message],
|
||||||
|
tools: Option<&[Tool]>,
|
||||||
|
streaming: bool,
|
||||||
|
max_tokens: Option<u32>,
|
||||||
|
temperature: f32,
|
||||||
|
) -> Result<OllamaRequest> {
|
||||||
|
let ollama_messages = self.convert_messages(messages)?;
|
||||||
|
let ollama_tools = tools.map(|t| self.convert_tools(t));
|
||||||
|
|
||||||
|
let mut options = OllamaOptions {
|
||||||
|
temperature,
|
||||||
|
num_predict: max_tokens,
|
||||||
|
};
|
||||||
|
|
||||||
|
// If max_tokens is provided, use it; otherwise use the instance default
|
||||||
|
if max_tokens.is_none() {
|
||||||
|
options.num_predict = self.max_tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
let request = OllamaRequest {
|
||||||
|
model: self.model.clone(),
|
||||||
|
messages: ollama_messages,
|
||||||
|
tools: ollama_tools,
|
||||||
|
stream: streaming,
|
||||||
|
options,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn parse_streaming_response(
|
||||||
|
&self,
|
||||||
|
mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
|
||||||
|
tx: mpsc::Sender<Result<CompletionChunk>>,
|
||||||
|
) -> Option<Usage> {
|
||||||
|
let mut buffer = String::new();
|
||||||
|
let mut accumulated_usage: Option<Usage> = None;
|
||||||
|
let mut current_tool_calls: Vec<OllamaToolCall> = Vec::new();
|
||||||
|
let mut byte_buffer = Vec::new();
|
||||||
|
|
||||||
|
while let Some(chunk_result) = stream.next().await {
|
||||||
|
match chunk_result {
|
||||||
|
Ok(chunk) => {
|
||||||
|
// Append new bytes to our buffer
|
||||||
|
byte_buffer.extend_from_slice(&chunk);
|
||||||
|
|
||||||
|
// Try to convert the entire buffer to UTF-8
|
||||||
|
let chunk_str = match std::str::from_utf8(&byte_buffer) {
|
||||||
|
Ok(s) => {
|
||||||
|
let result = s.to_string();
|
||||||
|
byte_buffer.clear();
|
||||||
|
result
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let valid_up_to = e.valid_up_to();
|
||||||
|
if valid_up_to > 0 {
|
||||||
|
let valid_bytes =
|
||||||
|
byte_buffer.drain(..valid_up_to).collect::<Vec<_>>();
|
||||||
|
std::str::from_utf8(&valid_bytes).unwrap().to_string()
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
buffer.push_str(&chunk_str);
|
||||||
|
|
||||||
|
// Process complete lines
|
||||||
|
while let Some(line_end) = buffer.find('\n') {
|
||||||
|
let line = buffer[..line_end].trim().to_string();
|
||||||
|
buffer.drain(..line_end + 1);
|
||||||
|
|
||||||
|
if line.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ollama streaming sends JSON objects per line
|
||||||
|
match serde_json::from_str::<OllamaStreamChunk>(&line) {
|
||||||
|
Ok(chunk) => {
|
||||||
|
// Handle text content
|
||||||
|
if let Some(message) = &chunk.message {
|
||||||
|
let content = &message.content;
|
||||||
|
if !content.is_empty() {
|
||||||
|
debug!("Sending text chunk: '{}'", content);
|
||||||
|
let chunk = CompletionChunk {
|
||||||
|
content: content.clone(),
|
||||||
|
finished: false,
|
||||||
|
usage: None,
|
||||||
|
tool_calls: None,
|
||||||
|
};
|
||||||
|
if tx.send(Ok(chunk)).await.is_err() {
|
||||||
|
debug!("Receiver dropped, stopping stream");
|
||||||
|
return accumulated_usage;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle tool calls
|
||||||
|
if let Some(tool_calls) = &message.tool_calls {
|
||||||
|
current_tool_calls.extend(tool_calls.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if stream is done
|
||||||
|
if chunk.done.unwrap_or(false) {
|
||||||
|
debug!("Stream completed");
|
||||||
|
|
||||||
|
// Update usage if available
|
||||||
|
if let Some(eval_count) = chunk.eval_count {
|
||||||
|
accumulated_usage = Some(Usage {
|
||||||
|
prompt_tokens: chunk.prompt_eval_count.unwrap_or(0),
|
||||||
|
completion_tokens: eval_count,
|
||||||
|
total_tokens: chunk.prompt_eval_count.unwrap_or(0)
|
||||||
|
+ eval_count,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send final chunk with tool calls if any
|
||||||
|
let final_tool_calls: Vec<ToolCall> = current_tool_calls
|
||||||
|
.iter()
|
||||||
|
.map(|tc| ToolCall {
|
||||||
|
id: tc.function.name.clone(), // Ollama doesn't provide IDs
|
||||||
|
tool: tc.function.name.clone(),
|
||||||
|
args: tc.function.arguments.clone(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let final_chunk = CompletionChunk {
|
||||||
|
content: String::new(),
|
||||||
|
finished: true,
|
||||||
|
usage: accumulated_usage.clone(),
|
||||||
|
tool_calls: if final_tool_calls.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(final_tool_calls)
|
||||||
|
},
|
||||||
|
};
|
||||||
|
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||||
|
debug!("Receiver dropped, stopping stream");
|
||||||
|
}
|
||||||
|
return accumulated_usage;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
debug!("Failed to parse Ollama stream chunk: {} - Line: {}", e, line);
|
||||||
|
// Don't error out, just continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Stream error: {}", e);
|
||||||
|
let error_msg = e.to_string();
|
||||||
|
if error_msg.contains("unexpected EOF") || error_msg.contains("connection") {
|
||||||
|
warn!("Connection terminated unexpectedly, treating as end of stream");
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
|
||||||
|
}
|
||||||
|
return accumulated_usage;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send final chunk if we haven't already
|
||||||
|
let final_tool_calls: Vec<ToolCall> = current_tool_calls
|
||||||
|
.iter()
|
||||||
|
.map(|tc| ToolCall {
|
||||||
|
id: tc.function.name.clone(),
|
||||||
|
tool: tc.function.name.clone(),
|
||||||
|
args: tc.function.arguments.clone(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let final_chunk = CompletionChunk {
|
||||||
|
content: String::new(),
|
||||||
|
finished: true,
|
||||||
|
usage: accumulated_usage.clone(),
|
||||||
|
tool_calls: if final_tool_calls.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(final_tool_calls)
|
||||||
|
},
|
||||||
|
};
|
||||||
|
let _ = tx.send(Ok(final_chunk)).await;
|
||||||
|
accumulated_usage
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fetch available models from the Ollama instance
|
||||||
|
pub async fn fetch_available_models(&self) -> Result<Vec<String>> {
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.get(format!("{}/api/tags", self.base_url))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("Failed to fetch Ollama models: {}", e))?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
let status = response.status();
|
||||||
|
let error_text = response
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||||
|
return Err(anyhow!(
|
||||||
|
"Failed to fetch Ollama models: {} - {}",
|
||||||
|
status,
|
||||||
|
error_text
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let json: serde_json::Value = response.json().await?;
|
||||||
|
let models = json
|
||||||
|
.get("models")
|
||||||
|
.and_then(|v| v.as_array())
|
||||||
|
.ok_or_else(|| anyhow!("Unexpected response format: missing 'models' array"))?;
|
||||||
|
|
||||||
|
let model_names: Vec<String> = models
|
||||||
|
.iter()
|
||||||
|
.filter_map(|model| model.get("name").and_then(|n| n.as_str()).map(String::from))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
debug!("Found {} models in Ollama", model_names.len());
|
||||||
|
Ok(model_names)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl LLMProvider for OllamaProvider {
|
||||||
|
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
|
||||||
|
debug!(
|
||||||
|
"Processing Ollama completion request with {} messages",
|
||||||
|
request.messages.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
let max_tokens = request.max_tokens.or(self.max_tokens);
|
||||||
|
let temperature = request.temperature.unwrap_or(self.temperature);
|
||||||
|
|
||||||
|
let request_body = self.create_request_body(
|
||||||
|
&request.messages,
|
||||||
|
request.tools.as_deref(),
|
||||||
|
false,
|
||||||
|
max_tokens,
|
||||||
|
temperature,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Sending request to Ollama API: model={}, temperature={}",
|
||||||
|
self.model, request_body.options.temperature
|
||||||
|
);
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post(format!("{}/api/chat", self.base_url))
|
||||||
|
.json(&request_body)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("Failed to send request to Ollama API: {}", e))?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
if !status.is_success() {
|
||||||
|
let error_text = response
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||||
|
return Err(anyhow!("Ollama API error {}: {}", status, error_text));
|
||||||
|
}
|
||||||
|
|
||||||
|
let response_text = response.text().await?;
|
||||||
|
debug!("Raw Ollama API response: {}", response_text);
|
||||||
|
|
||||||
|
let ollama_response: OllamaResponse =
|
||||||
|
serde_json::from_str(&response_text).map_err(|e| {
|
||||||
|
anyhow!(
|
||||||
|
"Failed to parse Ollama response: {} - Response: {}",
|
||||||
|
e,
|
||||||
|
response_text
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let content = ollama_response.message.content.clone();
|
||||||
|
|
||||||
|
let usage = Usage {
|
||||||
|
prompt_tokens: ollama_response.prompt_eval_count.unwrap_or(0),
|
||||||
|
completion_tokens: ollama_response.eval_count.unwrap_or(0),
|
||||||
|
total_tokens: ollama_response.prompt_eval_count.unwrap_or(0)
|
||||||
|
+ ollama_response.eval_count.unwrap_or(0),
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Ollama completion successful: {} tokens generated",
|
||||||
|
usage.completion_tokens
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(CompletionResponse {
|
||||||
|
content,
|
||||||
|
usage,
|
||||||
|
model: self.model.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
|
||||||
|
debug!(
|
||||||
|
"Processing Ollama request (non-streaming) with {} messages",
|
||||||
|
request.messages.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(ref tools) = request.tools {
|
||||||
|
debug!("Request has {} tools", tools.len());
|
||||||
|
for tool in tools.iter().take(5) {
|
||||||
|
debug!(" Tool: {}", tool.name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let max_tokens = request.max_tokens.or(self.max_tokens);
|
||||||
|
let temperature = request.temperature.unwrap_or(self.temperature);
|
||||||
|
|
||||||
|
let request_body = self.create_request_body(
|
||||||
|
&request.messages,
|
||||||
|
request.tools.as_deref(),
|
||||||
|
false, // Use non-streaming mode to avoid streaming bugs
|
||||||
|
max_tokens,
|
||||||
|
temperature,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Sending request to Ollama API (stream=false): model={}, temperature={}",
|
||||||
|
self.model, request_body.options.temperature
|
||||||
|
);
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post(format!("{}/api/chat", self.base_url))
|
||||||
|
.json(&request_body)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("Failed to send request to Ollama API: {}", e))?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
if !status.is_success() {
|
||||||
|
let error_text = response
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||||
|
return Err(anyhow!("Ollama API error {}: {}", status, error_text));
|
||||||
|
}
|
||||||
|
|
||||||
|
// For non-streaming, parse the complete JSON response
|
||||||
|
let response_text = response.text().await?;
|
||||||
|
debug!("Raw Ollama API response: {}", response_text);
|
||||||
|
|
||||||
|
let ollama_response: OllamaResponse =
|
||||||
|
serde_json::from_str(&response_text).map_err(|e| {
|
||||||
|
anyhow!(
|
||||||
|
"Failed to parse Ollama response: {} - Response: {}",
|
||||||
|
e,
|
||||||
|
response_text
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let (tx, rx) = mpsc::channel(100);
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let content = ollama_response.message.content;
|
||||||
|
let usage = Usage {
|
||||||
|
prompt_tokens: ollama_response.prompt_eval_count.unwrap_or(0),
|
||||||
|
completion_tokens: ollama_response.eval_count.unwrap_or(0),
|
||||||
|
total_tokens: ollama_response.prompt_eval_count.unwrap_or(0)
|
||||||
|
+ ollama_response.eval_count.unwrap_or(0),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Extract tool calls if present
|
||||||
|
let tool_calls: Option<Vec<ToolCall>> = ollama_response.message.tool_calls.map(|tcs| {
|
||||||
|
tcs.iter()
|
||||||
|
.map(|tc| ToolCall {
|
||||||
|
id: tc.function.name.clone(),
|
||||||
|
tool: tc.function.name.clone(),
|
||||||
|
args: tc.function.arguments.clone(),
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
});
|
||||||
|
|
||||||
|
// Send content if any
|
||||||
|
if !content.is_empty() {
|
||||||
|
let _ = tx.send(Ok(CompletionChunk {
|
||||||
|
content,
|
||||||
|
finished: false,
|
||||||
|
usage: None,
|
||||||
|
tool_calls: None,
|
||||||
|
})).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send final chunk with usage and tool calls
|
||||||
|
let _ = tx.send(Ok(CompletionChunk {
|
||||||
|
content: String::new(),
|
||||||
|
finished: true,
|
||||||
|
usage: Some(usage),
|
||||||
|
tool_calls,
|
||||||
|
})).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(ReceiverStream::new(rx))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Provider identifier used to select/report this backend (always "ollama").
fn name(&self) -> &str {
    "ollama"
}
|
||||||
|
|
||||||
|
/// Name of the configured model (e.g. "llama3.2").
fn model(&self) -> &str {
    &self.model
}
|
||||||
|
|
||||||
|
/// Whether this provider supports structured (native) tool calling.
///
/// Always `true`: tool definitions are forwarded to the Ollama API as-is.
/// NOTE(review): not every model served by Ollama actually supports tools
/// (the original comment assumed "most modern" models do, e.g. llama3.2,
/// qwen2.5, mistral) — requests against a non-tool model may fail at the
/// API level; confirm whether a per-model capability check is wanted.
fn has_native_tool_calling(&self) -> bool {
    true
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ollama API request/response structures
|
||||||
|
|
||||||
|
/// Request body serialized and sent to the Ollama chat API.
#[derive(Debug, Serialize)]
struct OllamaRequest {
    /// Model identifier, e.g. "llama3.2".
    model: String,
    /// Conversation history in Ollama's wire format.
    messages: Vec<OllamaMessage>,
    /// Tool definitions offered to the model; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OllamaTool>>,
    /// Whether the server should stream the response chunk-by-chunk.
    stream: bool,
    /// Sampling/generation parameters forwarded to the model.
    options: OllamaOptions,
}
|
||||||
|
|
||||||
|
/// Generation options nested under `options` in an [`OllamaRequest`].
#[derive(Debug, Serialize)]
struct OllamaOptions {
    /// Sampling temperature.
    temperature: f32,
    /// Maximum number of tokens to generate — Ollama's equivalent of
    /// `max_tokens`. Omitted from the JSON when `None` (server default).
    #[serde(skip_serializing_if = "Option::is_none")]
    num_predict: Option<u32>,
}
|
||||||
|
|
||||||
|
/// Tool definition in Ollama's (OpenAI-style) tool envelope.
#[derive(Debug, Serialize)]
struct OllamaTool {
    /// Tool kind; set to "function" by the conversion code (see tests).
    r#type: String,
    /// The function the model may call.
    function: OllamaFunction,
}
|
||||||
|
|
||||||
|
/// Callable-function metadata carried inside an [`OllamaTool`].
#[derive(Debug, Serialize)]
struct OllamaFunction {
    /// Name the model uses when emitting a call to this function.
    name: String,
    /// Human-readable description presented to the model.
    description: String,
    /// JSON Schema describing the accepted arguments.
    parameters: serde_json::Value,
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
struct OllamaMessage {
|
||||||
|
role: String,
|
||||||
|
content: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tool_calls: Option<Vec<OllamaToolCall>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A tool invocation attached to an assistant message.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaToolCall {
    /// The function the model wants to invoke, with its arguments.
    function: OllamaToolCallFunction,
}
|
||||||
|
|
||||||
|
/// Name-and-arguments payload of a tool call emitted by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaToolCallFunction {
    /// Name of the tool/function being called.
    name: String,
    /// Arguments as structured JSON (not a stringified blob).
    arguments: serde_json::Value,
}
|
||||||
|
|
||||||
|
/// Complete (non-streaming) response from the Ollama chat API.
#[derive(Debug, Deserialize)]
struct OllamaResponse {
    /// The assistant's reply, possibly carrying tool calls.
    message: OllamaMessage,
    /// Completion flag from the server; unused because a non-streaming
    /// response is always final.
    #[allow(dead_code)]
    done: bool,
    /// Timing metadata reported by the server — currently unused.
    #[allow(dead_code)]
    total_duration: Option<u64>,
    /// Model-load timing metadata — currently unused.
    #[allow(dead_code)]
    load_duration: Option<u64>,
    /// Prompt token count; mapped to `Usage::prompt_tokens` (0 when absent).
    prompt_eval_count: Option<u32>,
    /// Generated token count; mapped to `Usage::completion_tokens` (0 when absent).
    eval_count: Option<u32>,
}
|
||||||
|
|
||||||
|
/// A single chunk of a streaming Ollama chat response. All fields are
/// optional since not every chunk carries every field.
#[derive(Debug, Deserialize)]
struct OllamaStreamChunk {
    /// Incremental message content, when present.
    message: Option<OllamaMessage>,
    /// Set to `true` on the terminating chunk.
    done: Option<bool>,
    /// Prompt token count — per Ollama, reported on the final chunk.
    prompt_eval_count: Option<u32>,
    /// Generated token count — per Ollama, reported on the final chunk.
    eval_count: Option<u32>,
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_creation() {
        // A fully-specified provider exposes the configured model and
        // provider identity through its accessors.
        let p = OllamaProvider::new("llama3.2".to_string(), None, Some(1000), Some(0.7)).unwrap();

        assert_eq!(p.model(), "llama3.2");
        assert_eq!(p.name(), "ollama");
        assert!(p.has_native_tool_calling());
    }

    #[test]
    fn test_message_conversion() {
        let p = OllamaProvider::new("llama3.2".to_string(), None, None, None).unwrap();

        let history = vec![
            Message {
                role: MessageRole::System,
                content: "You are a helpful assistant.".to_string(),
            },
            Message {
                role: MessageRole::User,
                content: "Hello!".to_string(),
            },
        ];

        let converted = p.convert_messages(&history).unwrap();

        // One Ollama message per input, with roles mapped to wire strings.
        assert_eq!(converted.len(), 2);
        assert_eq!(converted[0].role, "system");
        assert_eq!(converted[1].role, "user");
    }

    #[test]
    fn test_tool_conversion() {
        let p = OllamaProvider::new("llama3.2".to_string(), None, None, None).unwrap();

        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state"
                }
            },
            "required": ["location"]
        });
        let tools = vec![Tool {
            name: "get_weather".to_string(),
            description: "Get the current weather".to_string(),
            input_schema: schema,
        }];

        let converted = p.convert_tools(&tools);

        // Each tool is wrapped in Ollama's "function" envelope.
        assert_eq!(converted.len(), 1);
        assert_eq!(converted[0].r#type, "function");
        assert_eq!(converted[0].function.name, "get_weather");
    }

    #[test]
    fn test_custom_base_url() {
        // An explicit base URL overrides the default endpoint.
        let p = OllamaProvider::new(
            "llama3.2".to_string(),
            Some("http://custom:11434".to_string()),
            None,
            None,
        )
        .unwrap();

        assert_eq!(p.base_url, "http://custom:11434");
    }
}
|
||||||
Reference in New Issue
Block a user