Compare commits

..

16 Commits

Author SHA1 Message Date
Michael Neale
79b375519b use proper context for qwen3-coder 2025-11-05 13:55:45 +11:00
Michael Neale
88c3cc23fe works better without streaming 2025-11-05 12:55:21 +11:00
Michael Neale
4622507f37 gpt context aware 2025-11-05 12:25:02 +11:00
Michael Neale
217df2f2af ollama support 2025-11-05 12:17:01 +11:00
Dhanji Prasanna
22a0090cdc fix unexpected EOF on streams 2025-11-04 16:28:41 +11:00
Dhanji Prasanna
631f3c16ca compact on tool call if > 90% 2025-11-04 14:35:11 +11:00
Dhanji Prasanna
1f9fef5f18 more json filtering 2025-11-03 11:56:16 +11:00
Dhanji Prasanna
57d473c19d mild json filtering improvement 2025-11-03 11:54:27 +11:00
Jochen
e59ce2f93f Merge pull request #16 from dhanji/jochen-ast-tool
adds ast-grep tool for faster code exploration
2025-11-02 21:04:11 +11:00
Jochen
a1ad94ed75 Added comment & example for native flow
detailed examples for using code_search tool for native tool use.
2025-11-02 21:02:43 +11:00
Jochen
982c0bbfb3 amend instructions for tool use 2025-11-01 15:52:08 +11:00
Jochen
ad9ba5e5d8 added ast-grep use
g3 tool use of ast-grep command with batching for faster code exploration.
2025-11-01 14:59:55 +11:00
Dhanji Prasanna
f89bbfc89a fix final_output bug 2025-10-31 14:48:36 +11:00
Dhanji Prasanna
11eb01e04d add back progress bar to cli 2025-10-31 14:28:35 +11:00
Dhanji Prasanna
bdaacfd051 fix for duplicate messages at end 2025-10-31 13:34:36 +11:00
Dhanji R. Prasanna
92ae776510 Merge pull request #14 from dhanji/micn/always-coach-player
always coach player
2025-10-31 12:52:17 +11:00
17 changed files with 3038 additions and 103 deletions

1
.gitignore vendored
View File

@@ -3,6 +3,7 @@
debug
target
.build
appy/
# These are backup files generated by rustfmt
**/*.rs.bk

20
Cargo.lock generated
View File

@@ -1391,6 +1391,7 @@ dependencies = [
"reqwest",
"serde",
"serde_json",
"serde_yaml",
"shellexpand",
"thiserror 1.0.69",
"tokio",
@@ -3078,6 +3079,19 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_yaml"
version = "0.9.34+deprecated"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47"
dependencies = [
"indexmap",
"itoa",
"ryu",
"serde",
"unsafe-libyaml",
]
[[package]]
name = "sha2"
version = "0.10.9"
@@ -3667,6 +3681,12 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd"
[[package]]
name = "unsafe-libyaml"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861"
[[package]]
name = "url"
version = "2.5.7"

456
OLLAMA_CONFIG.md Normal file
View File

@@ -0,0 +1,456 @@
# Configuring Ollama Provider in G3
This guide shows you how to configure G3 to use Ollama as your LLM provider.
## Quick Start
### 1. Install Ollama
```bash
# Visit https://ollama.ai to download and install
# Or use curl:
curl https://ollama.ai/install.sh | sh
```
### 2. Pull a Model
```bash
ollama pull llama3.2
# or any other model you prefer
```
### 3. Create Configuration File
Copy the example configuration:
```bash
cp config.ollama.example.toml ~/.config/g3/config.toml
```
Or create it manually:
```toml
[providers]
default_provider = "ollama"
[providers.ollama]
model = "llama3.2"
```
### 4. Run G3
```bash
g3
# G3 will now use Ollama with llama3.2!
```
## Configuration Options
### Basic Configuration
```toml
[providers]
default_provider = "ollama"
[providers.ollama]
model = "llama3.2"
```
This is the minimal configuration needed. It uses all defaults:
- Base URL: `http://localhost:11434`
- Temperature: `0.7`
- Max tokens: Not limited (uses model default)
### Full Configuration
```toml
[providers]
default_provider = "ollama"
[providers.ollama]
model = "llama3.2"
base_url = "http://localhost:11434"
max_tokens = 2048
temperature = 0.7
```
### Custom Ollama Host
If you're running Ollama on a different machine or port:
```toml
[providers.ollama]
model = "llama3.2"
base_url = "http://192.168.1.100:11434"
```
### Different Models
You can use any Ollama model:
```toml
[providers.ollama]
model = "qwen2.5:7b" # Alibaba's Qwen model
```
```toml
[providers.ollama]
model = "mistral" # Mistral AI
```
```toml
[providers.ollama]
model = "llama3.1:70b" # Larger Llama model
```
## Multiple Provider Configuration
You can configure multiple providers and switch between them:
```toml
[providers]
default_provider = "ollama" # Default for most operations
# Ollama for local, fast responses
[providers.ollama]
model = "llama3.2:3b"
temperature = 0.7
# Databricks for more complex tasks
[providers.databricks]
host = "https://your-workspace.cloud.databricks.com"
model = "databricks-claude-sonnet-4"
max_tokens = 4096
temperature = 0.1
use_oauth = true
```
Then switch providers with:
```bash
g3 --provider databricks
```
## Autonomous Mode (Coach-Player)
Use different providers for code review (coach) and implementation (player):
```toml
[providers]
default_provider = "ollama"
coach = "databricks" # Use powerful cloud model for review
player = "ollama" # Use local model for implementation
[providers.ollama]
model = "qwen2.5:14b" # Larger local model for coding
[providers.databricks]
host = "https://your-workspace.cloud.databricks.com"
model = "databricks-claude-sonnet-4"
use_oauth = true
```
This gives you the best of both worlds:
- Fast local execution for coding tasks
- Powerful cloud review for quality assurance
## Recommended Models
### For Coding Tasks
| Model | Size | Speed | Quality | Notes |
|-------|------|-------|---------|-------|
| **qwen2.5:7b** | 7B | Fast | Excellent | Best balance for coding |
| **llama3.2:3b** | 3B | Very Fast | Good | Great for quick tasks |
| **llama3.1:8b** | 8B | Medium | Very Good | Solid all-rounder |
| **mistral** | 7B | Fast | Good | Good for general use |
### For Complex Tasks
| Model | Size | Speed | Quality | Notes |
|-------|------|-------|---------|-------|
| **qwen2.5:14b** | 14B | Medium | Excellent | Best local model for coding |
| **qwen2.5:32b** | 32B | Slow | Outstanding | If you have the resources |
| **llama3.1:70b** | 70B | Very Slow | Outstanding | Requires significant RAM/GPU |
## Temperature Settings
Temperature controls randomness in responses:
- **0.1-0.3**: Deterministic, good for code generation
- **0.5-0.7**: Balanced, good for most tasks
- **0.8-1.0**: Creative, good for brainstorming
```toml
[providers.ollama]
model = "qwen2.5:7b"
temperature = 0.2 # Focused code generation
```
## Max Tokens
Control response length:
```toml
[providers.ollama]
model = "llama3.2"
max_tokens = 1024 # Shorter responses
```
```toml
[providers.ollama]
model = "qwen2.5:7b"
max_tokens = 4096 # Longer, detailed responses
```
Leave it unset for model defaults (recommended).
## Performance Tuning
### GPU Acceleration
Ollama automatically uses GPU if available. To check:
```bash
ollama ps
```
### Quantized Models
For faster responses with less RAM:
```toml
[providers.ollama]
model = "llama3.2:3b-q4_0" # 4-bit quantization
```
Quantization options:
- `q4_0`: 4-bit, fastest, lowest quality
- `q5_0`: 5-bit, balanced
- `q8_0`: 8-bit, slower, better quality
### Multiple Models
You can pull multiple models and switch easily:
```bash
ollama pull llama3.2:3b # Fast for chat
ollama pull qwen2.5:7b # Better for code
ollama pull mistral # General purpose
```
Then change your config:
```toml
[providers.ollama]
model = "qwen2.5:7b" # Just change this line
```
## Troubleshooting
### Ollama Not Running
```bash
# Check if Ollama is running
curl http://localhost:11434/api/version
# Start Ollama (macOS/Linux)
ollama serve
# Or just run a model (auto-starts)
ollama run llama3.2
```
### Model Not Found
```bash
# List available models
ollama list
# Pull the model
ollama pull llama3.2
```
### Slow Responses
1. Use a smaller model:
```toml
model = "llama3.2:1b" # Smallest, fastest
```
2. Use quantized version:
```toml
model = "llama3.2:3b-q4_0"
```
3. Reduce max_tokens:
```toml
max_tokens = 512
```
### Out of Memory
1. Switch to smaller model
2. Use quantized version
3. Close other applications
4. Check GPU memory: `ollama ps`
### Connection Refused
Check base_url is correct:
```toml
[providers.ollama]
model = "llama3.2"
base_url = "http://localhost:11434" # Default
```
For remote Ollama:
```toml
base_url = "http://your-server:11434"
```
## Complete Example Configs
### Minimal Local Setup
```toml
[providers]
default_provider = "ollama"
[providers.ollama]
model = "llama3.2"
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
```
### Optimized for Coding
```toml
[providers]
default_provider = "ollama"
[providers.ollama]
model = "qwen2.5:7b"
temperature = 0.2
max_tokens = 2048
[agent]
max_context_length = 16384
enable_streaming = true
timeout_seconds = 120
```
### Fast Responses
```toml
[providers]
default_provider = "ollama"
[providers.ollama]
model = "llama3.2:3b-q4_0"
temperature = 0.7
max_tokens = 1024
[agent]
max_context_length = 4096
enable_streaming = true
timeout_seconds = 30
```
### High Quality (Requires Good Hardware)
```toml
[providers]
default_provider = "ollama"
[providers.ollama]
model = "qwen2.5:32b"
temperature = 0.3
max_tokens = 4096
[agent]
max_context_length = 32768
enable_streaming = true
timeout_seconds = 300
```
### Hybrid (Local + Cloud)
```toml
[providers]
default_provider = "ollama"
coach = "databricks"
player = "ollama"
[providers.ollama]
model = "qwen2.5:14b"
temperature = 0.2
[providers.databricks]
host = "https://your-workspace.cloud.databricks.com"
model = "databricks-claude-sonnet-4"
use_oauth = true
[agent]
max_context_length = 16384
enable_streaming = true
timeout_seconds = 120
```
## Environment Variables
You can override config with environment variables:
```bash
# Override model
G3_PROVIDERS_OLLAMA_MODEL=qwen2.5:7b g3
# Override base URL
G3_PROVIDERS_OLLAMA_BASE_URL=http://192.168.1.100:11434 g3
# Override default provider
G3_PROVIDERS_DEFAULT_PROVIDER=ollama g3
```
## Best Practices
1. **Start Small**: Begin with llama3.2:3b, scale up if needed
2. **Use Quantization**: q4_0 or q5_0 for best speed/quality balance
3. **Match Task to Model**:
- Quick edits: 1B-3B models
- Code generation: 7B-14B models
- Complex refactoring: 14B-32B models
4. **Temperature for Code**: Use 0.1-0.3 for deterministic output
5. **Enable Streaming**: Always enable for better UX
6. **Local First**: Use Ollama by default, cloud for special cases
## Comparison with Other Providers
| Feature | Ollama | Databricks | OpenAI | Anthropic |
|---------|--------|------------|--------|-----------|
| Cost | Free | Paid | Paid | Paid |
| Privacy | Full | Medium | Low | Low |
| Speed (small models) | Fast | Fast | Medium | Medium |
| Speed (large models) | Slow | Fast | Fast | Fast |
| Setup Complexity | Low | Medium | Low | Low |
| Authentication | None | OAuth/Token | API Key | API Key |
| Offline Support | Yes | No | No | No |
| Tool Calling | Yes | Yes | Yes | Yes |
## Next Steps
1. Try different models: `ollama pull mistral`, `ollama pull qwen2.5`
2. Experiment with temperature settings
3. Set up hybrid config with cloud provider for complex tasks
4. Share your config in the community!
## Getting Help
- Ollama docs: https://ollama.ai/docs
- G3 issues: https://github.com/your-repo/issues
- Check available CLI options: `g3 --help`

315
OLLAMA_EXAMPLE.md Normal file
View File

@@ -0,0 +1,315 @@
# Ollama Provider for g3
A simple, local LLM provider implementation for g3 that connects to Ollama.
## Features
- **Simple Setup**: No API keys or authentication required
- **Local Execution**: Runs entirely on your machine
- **Tool Calling Support**: Native tool calling for compatible models
- **Streaming**: Full streaming support with real-time responses
- **Flexible Configuration**: Custom base URL, temperature, and max tokens
- **Model Discovery**: Automatic detection of available models
## Quick Start
### Prerequisites
1. Install and start Ollama: https://ollama.ai
2. Pull a model: `ollama pull llama3.2`
### Basic Usage
```rust
use g3_providers::{OllamaProvider, LLMProvider, CompletionRequest, Message, MessageRole};
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Create provider with default settings (localhost:11434)
let provider = OllamaProvider::new(
"llama3.2".to_string(),
None, // base_url: defaults to http://localhost:11434
None, // max_tokens: optional
None, // temperature: defaults to 0.7
)?;
// Create a simple request
let request = CompletionRequest {
messages: vec![
Message {
role: MessageRole::User,
content: "What is the capital of France?".to_string(),
},
],
max_tokens: Some(1000),
temperature: Some(0.7),
stream: false,
tools: None,
};
// Get completion
let response = provider.complete(request).await?;
println!("Response: {}", response.content);
println!("Tokens: {}", response.usage.total_tokens);
Ok(())
}
```
### Streaming Example
```rust
use futures_util::StreamExt;
let request = CompletionRequest {
messages: vec![
Message {
role: MessageRole::User,
content: "Write a short poem about coding".to_string(),
},
],
max_tokens: Some(500),
temperature: Some(0.8),
stream: true,
tools: None,
};
let mut stream = provider.stream(request).await?;
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
print!("{}", chunk.content);
if chunk.finished {
println!("\n\nDone!");
if let Some(usage) = chunk.usage {
println!("Total tokens: {}", usage.total_tokens);
}
}
}
Err(e) => eprintln!("Error: {}", e),
}
}
```
### Tool Calling Example
```rust
use serde_json::json;
let tools = vec![Tool {
name: "get_weather".to_string(),
description: "Get current weather for a location".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "City name"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "Temperature unit"
}
},
"required": ["location"]
}),
}];
let request = CompletionRequest {
messages: vec![
Message {
role: MessageRole::User,
content: "What's the weather in Paris?".to_string(),
},
],
max_tokens: Some(500),
temperature: Some(0.5),
stream: false,
tools: Some(tools),
};
let response = provider.complete(request).await?;
println!("Response: {}", response.content);
```
### Custom Ollama Host
```rust
// Connect to remote Ollama instance
let provider = OllamaProvider::new(
"llama3.2".to_string(),
Some("http://192.168.1.100:11434".to_string()),
None,
None,
)?;
```
### Fetch Available Models
```rust
// Discover what models are available
let models = provider.fetch_available_models().await?;
println!("Available models:");
for model in models {
println!(" - {}", model);
}
```
## Supported Models
The provider works with any Ollama model, including:
- **llama3.2** (1B, 3B) - Meta's latest Llama models
- **llama3.1** (8B, 70B, 405B) - Previous generation
- **qwen2.5** (7B, 14B, 32B) - Alibaba's Qwen models
- **mistral** - Mistral AI models
- **mixtral** - Mixture of experts model
- **phi3** - Microsoft's Phi-3
- **gemma2** - Google's Gemma 2
## Configuration
### Constructor Parameters
```rust
OllamaProvider::new(
model: String, // Model name (e.g., "llama3.2")
base_url: Option<String>, // Ollama API URL (default: http://localhost:11434)
max_tokens: Option<u32>, // Maximum tokens to generate (optional)
temperature: Option<f32>, // Sampling temperature (default: 0.7)
)
```
### Request Options
```rust
CompletionRequest {
messages: Vec<Message>, // Conversation history
max_tokens: Option<u32>, // Override provider's max_tokens
temperature: Option<f32>, // Override provider's temperature
stream: bool, // Enable streaming responses
tools: Option<Vec<Tool>>, // Tools for function calling
}
```
## Comparison with Other Providers
| Feature | Ollama | OpenAI | Anthropic | Databricks |
|---------|--------|--------|-----------|------------|
| Local Execution | ✅ | ❌ | ❌ | ❌ |
| Authentication | None | API Key | API Key | OAuth/Token |
| Tool Calling | ✅ | ✅ | ✅ | ✅ |
| Streaming | ✅ | ✅ | ✅ | ✅ |
| Cost | Free | Paid | Paid | Paid |
| Privacy | High | Low | Low | Medium |
## Implementation Details
### API Endpoints
- **Chat Completion**: `POST /api/chat`
- **Model List**: `GET /api/tags`
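For a quick check outside of g3, these endpoints can also be called directly. The sketch below is illustrative only and is not part of the provider code; it assumes `reqwest` (with its `json` feature), `tokio`, `serde_json`, and `anyhow` as dependencies and a locally running Ollama:
```rust
use serde_json::{json, Value};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Non-streaming chat completion against the default local Ollama port
    let body = json!({
        "model": "llama3.2",
        "messages": [{"role": "user", "content": "Say hello"}],
        "stream": false
    });
    let response: Value = reqwest::Client::new()
        .post("http://localhost:11434/api/chat")
        .json(&body)
        .send()
        .await?
        .json()
        .await?;
    // The assistant's reply sits under message.content in the non-streaming response
    println!("{}", response["message"]["content"]);

    // List locally installed models via the tags endpoint
    let tags: Value = reqwest::get("http://localhost:11434/api/tags")
        .await?
        .json()
        .await?;
    println!("{}", tags["models"]);
    Ok(())
}
```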
### Response Format
Ollama uses a simple JSON-per-line streaming format:
```json
{"message":{"role":"assistant","content":"Hello"},"done":false}
{"message":{"role":"assistant","content":" there"},"done":false}
{"done":true,"prompt_eval_count":10,"eval_count":20}
```
### Tool Call Format
Tool calls are returned in the message structure:
```json
{
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": "get_weather",
"arguments": {"location": "Paris", "unit": "celsius"}
}
}
]
},
"done": true
}
```
## Troubleshooting
### Connection Errors
If you see connection errors, ensure Ollama is running:
```bash
# Check if Ollama is running
curl http://localhost:11434/api/version
# Start Ollama (if needed)
ollama serve
```
### Model Not Found
Pull the model first:
```bash
ollama pull llama3.2
ollama list # Check available models
```
### Performance Issues
- Use smaller models (1B, 3B) for faster responses
- Reduce `max_tokens` to limit generation length
- Enable GPU acceleration if available
- Consider quantized models (e.g., `llama3.2:3b-q4_0`)
## Testing
Run the included tests:
```bash
cargo test --package g3-providers ollama
```
All tests should pass:
```
running 4 tests
test ollama::tests::test_custom_base_url ... ok
test ollama::tests::test_message_conversion ... ok
test ollama::tests::test_provider_creation ... ok
test ollama::tests::test_tool_conversion ... ok
```
## Architecture
The provider follows the same architecture as other g3 providers:
1. **OllamaProvider**: Main struct implementing `LLMProvider` trait
2. **Request/Response Structures**: Internal types for Ollama API
3. **Streaming Parser**: Handles line-by-line JSON parsing (see the sketch after this list)
4. **Tool Call Handling**: Accumulates and converts tool calls
5. **Error Handling**: Robust error handling with retries
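The sketch below illustrates item 3 in isolation: reducing Ollama's line-delimited stream to content text plus a completion flag. It is not the provider's actual parser; it only assumes `serde_json` and the response fields shown above:
```rust
use serde_json::Value;

// Each stream line is a standalone JSON object; "done": true marks the final
// line, which also carries the token counts.
fn parse_stream_line(line: &str) -> Option<(String, bool)> {
    let v: Value = serde_json::from_str(line).ok()?;
    let content = v["message"]["content"].as_str().unwrap_or("").to_string();
    let done = v["done"].as_bool().unwrap_or(false);
    Some((content, done))
}

fn main() {
    let lines = [
        r#"{"message":{"role":"assistant","content":"Hello"},"done":false}"#,
        r#"{"message":{"role":"assistant","content":" there"},"done":false}"#,
        r#"{"done":true,"prompt_eval_count":10,"eval_count":20}"#,
    ];
    for line in lines {
        if let Some((content, done)) = parse_stream_line(line) {
            print!("{}", content);
            if done {
                println!("\n(stream finished)");
            }
        }
    }
}
```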
## Contributing
The provider is part of the g3-providers crate. To contribute:
1. Add features to `ollama.rs`
2. Update tests
3. Run `cargo test --package g3-providers`
4. Update this documentation
## License
Same as the g3 project.

View File

@@ -0,0 +1,26 @@
# Example G3 configuration using Ollama provider
# Copy this to ~/.config/g3/config.toml or ./g3.toml to use it
[providers]
default_provider = "ollama"
# Ollama configuration (local LLM)
[providers.ollama]
model = "llama3.2" # or qwen2.5, mistral, etc.
# base_url = "http://localhost:11434" # Optional, defaults to localhost
# max_tokens = 2048 # Optional
# temperature = 0.7 # Optional
# Optional: Specify different providers for coach and player in autonomous mode
# coach = "ollama" # Provider for coach (code reviewer)
# player = "ollama" # Provider for player (code implementer)
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
[computer_control]
enabled = false # Set to true to enable computer control (requires OS permissions)
require_confirmation = true
max_actions_per_second = 5

View File

@@ -1,4 +1,5 @@
use anyhow::Result;
use crossterm::style::{Color, SetForegroundColor, ResetColor};
use std::time::{Duration, Instant};
#[derive(Debug, Clone)]
@@ -183,6 +184,10 @@ pub struct Cli {
#[arg(short, long)]
pub verbose: bool,
/// Enable manual control of context compaction (disables auto-compact at 90%)
#[arg(long = "manual-compact")]
pub manual_compact: bool,
/// Show the system prompt being sent to the LLM
#[arg(long)]
pub show_prompt: bool,
@@ -285,10 +290,6 @@ pub async fn run() -> Result<()> {
tracing_subscriber::registry().with(filter).init();
}
if !cli.machine {
info!("Starting G3 AI Coding Agent");
}
// Set up workspace directory
let workspace_dir = if let Some(ws) = &cli.workspace {
ws.clone()
@@ -324,10 +325,6 @@ pub async fn run() -> Result<()> {
project.ensure_workspace_exists()?;
project.enter_workspace()?;
if !cli.machine {
info!("Using workspace: {}", project.workspace().display());
}
// Load configuration with CLI overrides
let mut config = Config::load_with_overrides(
cli.config.as_deref(),
@@ -338,9 +335,6 @@ pub async fn run() -> Result<()> {
// Apply macax flag override
if cli.macax {
config.macax.enabled = true;
if !cli.machine {
info!("macOS Accessibility API tools enabled");
}
}
// Apply webdriver flag override
@@ -348,6 +342,11 @@ pub async fn run() -> Result<()> {
config.webdriver.enabled = true;
}
// Apply no-auto-compact flag override
if cli.manual_compact {
config.agent.auto_compact = false;
}
// Validate provider if specified
if let Some(ref provider) = cli.provider {
let valid_providers = ["anthropic", "databricks", "embedded", "openai"];
@@ -568,6 +567,11 @@ async fn run_accumulative_mode(
config.webdriver.enabled = true;
}
// Apply no-auto-compact flag override
if cli.manual_compact {
config.agent.auto_compact = false;
}
// Create agent for interactive mode with requirements context
let ui_writer = ConsoleUiWriter::new();
let agent = Agent::new_with_readme_and_quiet(
@@ -645,6 +649,11 @@ async fn run_accumulative_mode(
config.webdriver.enabled = true;
}
// Apply no-auto-compact flag override
if cli.manual_compact {
config.agent.auto_compact = false;
}
// Create agent for this autonomous run
let ui_writer = ConsoleUiWriter::new();
let agent = Agent::new_autonomous_with_readme_and_quiet(
@@ -765,9 +774,6 @@ async fn run_with_console_mode(
// Execute task, autonomous mode, or start interactive mode
if cli.autonomous {
// Autonomous mode with coach-player feedback loop
if !cli.machine {
info!("Starting autonomous mode");
}
run_autonomous(
agent,
project,
@@ -779,9 +785,6 @@ async fn run_with_console_mode(
.await?;
} else if let Some(task) = cli.task {
// Single-shot mode
if !cli.machine {
info!("Executing task: {}", task);
}
let output = SimpleOutput::new();
let result = agent
.execute_task_with_timing(&task, None, false, cli.show_prompt, cli.show_code, true)
@@ -789,9 +792,6 @@ async fn run_with_console_mode(
output.print_smart(&result.response);
} else {
// Interactive mode (default)
if !cli.machine {
info!("Starting interactive mode");
}
println!("📁 Workspace: {}", project.workspace().display());
run_interactive(agent, cli.show_prompt, cli.show_code, combined_content).await?;
}
@@ -840,7 +840,6 @@ fn read_agents_config(workspace_dir: &Path) -> Option<String> {
match std::fs::read_to_string(&agents_path) {
Ok(content) => {
// Return the content with a note about which file was read
info!("Loaded AGENTS.md from {}", agents_path.display());
Some(format!(
"🤖 Agent Configuration (from AGENTS.md):\n\n{}",
content
@@ -858,7 +857,6 @@ fn read_agents_config(workspace_dir: &Path) -> Option<String> {
if alt_path.exists() {
match std::fs::read_to_string(&alt_path) {
Ok(content) => {
info!("Loaded agents.md from {}", alt_path.display());
Some(format!("🤖 Agent Configuration (from agents.md):\n\n{}", content))
}
Err(e) => {
@@ -1479,10 +1477,32 @@ fn handle_execution_error(e: &anyhow::Error, input: &str, output: &SimpleOutput,
}
}
fn display_context_progress<W: UiWriter>(agent: &Agent<W>, output: &SimpleOutput) {
fn display_context_progress<W: UiWriter>(agent: &Agent<W>, _output: &SimpleOutput) {
let context = agent.get_context_window();
output.print(&format!("Context: {}/{} tokens ({:.1}%)",
context.used_tokens, context.total_tokens, context.percentage_used()));
let percentage = context.percentage_used();
// Create 10 dots representing context fullness
let total_dots: usize = 10;
let filled_dots = ((percentage / 100.0) * total_dots as f32).round() as usize;
let empty_dots = total_dots.saturating_sub(filled_dots);
let filled_str = "".repeat(filled_dots);
let empty_str = "".repeat(empty_dots);
// Determine color based on percentage
let color = if percentage < 40.0 {
Color::Green
} else if percentage < 60.0 {
Color::Yellow
} else if percentage < 80.0 {
Color::Rgb { r: 255, g: 165, b: 0 } // Orange
} else {
Color::Red
};
// Print with colored dots (using print! directly to handle color codes)
print!("Context: {}{}{}{} {:.0}% ({}/{} tokens)\n",
SetForegroundColor(color), filled_str, empty_str, ResetColor, percentage, context.used_tokens, context.total_tokens);
}
/// Set up the workspace directory for autonomous mode

View File

@@ -71,18 +71,20 @@ impl SimpleOutput {
}
pub fn print_context(&self, used: u32, total: u32, percentage: f32) {
let bar_width: usize = 10;
let filled_width = ((percentage / 100.0) * bar_width as f32) as usize;
let empty_width = bar_width.saturating_sub(filled_width);
let total_dots = 10;
let filled_dots = ((percentage / 100.0) * total_dots as f32) as usize;
let empty_dots = total_dots.saturating_sub(filled_dots);
let filled_chars = "".repeat(filled_width);
let empty_chars = "".repeat(empty_width);
let filled_str = "".repeat(filled_dots);
let empty_str = "".repeat(empty_dots);
// Determine color based on percentage
let color = if percentage < 60.0 {
let color = if percentage < 40.0 {
crossterm::style::Color::Green
} else if percentage < 80.0 {
} else if percentage < 60.0 {
crossterm::style::Color::Yellow
} else if percentage < 80.0 {
crossterm::style::Color::Rgb { r: 255, g: 165, b: 0 } // Orange
} else {
crossterm::style::Color::Red
};
@@ -90,9 +92,9 @@ impl SimpleOutput {
// Print with colored progress bar
print!("Context: ");
print!("{}", SetForegroundColor(color));
print!("{}{}", filled_chars, empty_chars);
print!("{}{}", filled_str, empty_str);
print!("{}", ResetColor);
println!(" {:.1}% | {}/{} tokens", percentage, used, total);
println!(" {:.0}% ({}/{} tokens)", percentage, used, total);
}
pub fn print_context_thinning(&self, message: &str) {

View File

@@ -17,6 +17,7 @@ pub struct ProvidersConfig {
pub anthropic: Option<AnthropicConfig>,
pub databricks: Option<DatabricksConfig>,
pub embedded: Option<EmbeddedConfig>,
pub ollama: Option<OllamaConfig>,
pub default_provider: String,
pub coach: Option<String>, // Provider to use for coach in autonomous mode
pub player: Option<String>, // Provider to use for player in autonomous mode
@@ -60,11 +61,20 @@ pub struct EmbeddedConfig {
pub threads: Option<u32>, // Number of CPU threads to use
}
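/// Configuration for the Ollama provider, populated from the `[providers.ollama]`
/// section of config.toml (see OLLAMA_CONFIG.md for complete examples).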
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaConfig {
pub model: String,
pub base_url: Option<String>, // Default: http://localhost:11434
pub max_tokens: Option<u32>,
pub temperature: Option<f32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
pub max_context_length: usize,
pub enable_streaming: bool,
pub timeout_seconds: u64,
pub auto_compact: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -127,6 +137,7 @@ impl Default for Config {
use_oauth: Some(true),
}),
embedded: None,
ollama: None,
default_provider: "databricks".to_string(),
coach: None, // Will use default_provider if not specified
player: None, // Will use default_provider if not specified
@@ -135,6 +146,7 @@ impl Default for Config {
max_context_length: 8192,
enable_streaming: true,
timeout_seconds: 60,
auto_compact: true,
},
computer_control: ComputerControlConfig::default(),
webdriver: WebDriverConfig::default(),
@@ -242,6 +254,7 @@ impl Config {
gpu_layers: Some(32),
threads: Some(8),
}),
ollama: None,
default_provider: "embedded".to_string(),
coach: None, // Will use default_provider if not specified
player: None, // Will use default_provider if not specified
@@ -250,6 +263,7 @@ impl Config {
max_context_length: 8192,
enable_streaming: true,
timeout_seconds: 60,
auto_compact: true,
},
computer_control: ComputerControlConfig::default(),
webdriver: WebDriverConfig::default(),

View File

@@ -25,3 +25,4 @@ chrono = { version = "0.4", features = ["serde"] }
rand = "0.8"
regex = "1.0"
shellexpand = "3.1"
serde_yaml = "0.9"

View File

@@ -0,0 +1,787 @@
//! Code search functionality using ast-grep for syntax-aware semantic searches
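//!
//! A request is a batch of named searches deserialized into [`CodeSearchRequest`].
//! An example request body (hypothetical values) looks like:
//!
//! ```json
//! {
//!   "searches": [
//!     {"name": "fns", "mode": "pattern", "pattern": "fn $NAME() {}", "language": "rust"}
//!   ],
//!   "max_concurrency": 2
//! }
//! ```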
use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::process::Stdio;
use std::time::{Duration, Instant};
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command;
use tokio::sync::Semaphore;
use tracing::{debug, error, info, warn};
/// Maximum number of searches allowed per request
const MAX_SEARCHES: usize = 20;
/// Default timeout for individual searches in seconds
const DEFAULT_TIMEOUT_SECS: u64 = 60;
/// Default maximum concurrency
const DEFAULT_MAX_CONCURRENCY: usize = 4;
/// Default maximum matches per search
const DEFAULT_MAX_MATCHES: usize = 500;
/// Search specification for a single ast-grep search
#[derive(Debug, Clone, Deserialize)]
pub struct SearchSpec {
pub name: String,
pub mode: SearchMode,
// Pattern mode fields
pub pattern: Option<String>,
pub language: Option<String>,
// YAML mode fields
pub rule_yaml: Option<String>,
// Common fields
pub paths: Option<Vec<String>>,
pub globs: Option<Vec<String>>,
pub json_style: Option<JsonStyle>,
pub context: Option<u32>,
pub threads: Option<u32>,
pub include_metadata: Option<bool>,
pub no_ignore: Option<Vec<NoIgnoreType>>,
pub severity: Option<HashMap<String, SeverityLevel>>,
pub timeout_secs: Option<u64>,
}
/// Search mode: pattern or yaml
#[derive(Debug, Clone, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SearchMode {
Pattern,
Yaml,
}
/// JSON output style
#[derive(Debug, Clone, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum JsonStyle {
Pretty,
Stream,
Compact,
}
impl Default for JsonStyle {
fn default() -> Self {
JsonStyle::Stream
}
}
/// No-ignore types
#[derive(Debug, Clone, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum NoIgnoreType {
Hidden,
Dot,
Exclude,
Global,
Parent,
Vcs,
}
/// Severity levels for YAML rules
#[derive(Debug, Clone, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SeverityLevel {
Error,
Warning,
Info,
Hint,
Off,
}
/// Request structure for code search
#[derive(Debug, Deserialize)]
pub struct CodeSearchRequest {
pub searches: Vec<SearchSpec>,
pub max_concurrency: Option<usize>,
pub max_matches_per_search: Option<usize>,
}
/// Result of a single search
#[derive(Debug, Serialize)]
pub struct SearchResult {
pub name: String,
pub mode: String,
pub status: String,
pub cmd: Vec<String>,
pub match_count: Option<usize>,
pub truncated: Option<bool>,
pub matches: Option<Vec<Value>>,
pub stderr: Option<String>,
pub exit_code: Option<i32>,
pub duration_ms: u64,
}
/// Summary of all searches
#[derive(Debug, Serialize)]
pub struct SearchSummary {
pub completed: usize,
pub total: usize,
pub total_matches: usize,
pub duration_ms: u64,
}
/// Complete response structure
#[derive(Debug, Serialize)]
pub struct CodeSearchResponse {
pub summary: SearchSummary,
pub searches: Vec<SearchResult>,
}
/// YAML rule structure for validation
#[derive(Debug, Deserialize)]
struct YamlRule {
pub id: String,
pub language: String,
pub rule: Value,
}
/// Execute a batch of code searches using ast-grep
pub async fn execute_code_search(request: CodeSearchRequest) -> Result<CodeSearchResponse> {
let start_time = Instant::now();
// Validate request
if request.searches.is_empty() {
return Err(anyhow!("No searches specified"));
}
if request.searches.len() > MAX_SEARCHES {
return Err(anyhow!(
"Too many searches: {} (max: {})",
request.searches.len(),
MAX_SEARCHES
));
}
// Check if ast-grep is available
check_ast_grep_available().await?;
let max_concurrency = request.max_concurrency.unwrap_or(DEFAULT_MAX_CONCURRENCY);
let max_matches = request.max_matches_per_search.unwrap_or(DEFAULT_MAX_MATCHES);
// Create semaphore for concurrency control
let semaphore = std::sync::Arc::new(Semaphore::new(max_concurrency));
// Execute searches concurrently
let mut tasks = Vec::new();
for search in request.searches {
let sem = semaphore.clone();
let task = tokio::spawn(async move {
let _permit = sem.acquire().await.unwrap();
execute_single_search(search, max_matches).await
});
tasks.push(task);
}
// Wait for all searches to complete
let mut results = Vec::new();
let mut total_matches = 0;
let mut completed = 0;
for task in tasks {
match task.await {
Ok(result) => {
if result.status == "ok" {
completed += 1;
if let Some(count) = result.match_count {
total_matches += count;
}
}
results.push(result);
}
Err(e) => {
error!("Task join error: {}", e);
// Create an error result
results.push(SearchResult {
name: "unknown".to_string(),
mode: "unknown".to_string(),
status: "error".to_string(),
cmd: vec![],
match_count: None,
truncated: None,
matches: None,
stderr: Some(format!("Task execution error: {}", e)),
exit_code: None,
duration_ms: 0,
});
}
}
}
let total_duration = start_time.elapsed();
Ok(CodeSearchResponse {
summary: SearchSummary {
completed,
total: results.len(),
total_matches,
duration_ms: total_duration.as_millis() as u64,
},
searches: results,
})
}
/// Execute a single search
async fn execute_single_search(search: SearchSpec, max_matches: usize) -> SearchResult {
let start_time = Instant::now();
let timeout_secs = search.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
// Validate the search specification
if let Err(e) = validate_search_spec(&search) {
return SearchResult {
name: search.name,
mode: format!("{:?}", search.mode).to_lowercase(),
status: "error".to_string(),
cmd: vec![],
match_count: None,
truncated: None,
matches: None,
stderr: Some(format!("Validation error: {}", e)),
exit_code: None,
duration_ms: start_time.elapsed().as_millis() as u64,
};
}
// Build command
let cmd_args = match build_ast_grep_command(&search) {
Ok(args) => args,
Err(e) => {
return SearchResult {
name: search.name,
mode: format!("{:?}", search.mode).to_lowercase(),
status: "error".to_string(),
cmd: vec![],
match_count: None,
truncated: None,
matches: None,
stderr: Some(format!("Command build error: {}", e)),
exit_code: None,
duration_ms: start_time.elapsed().as_millis() as u64,
};
}
};
debug!("Executing ast-grep command: {:?}", cmd_args);
// Execute with timeout
let timeout_duration = Duration::from_secs(timeout_secs);
match tokio::time::timeout(timeout_duration, run_ast_grep_command(&cmd_args)).await {
Ok(Ok((stdout, stderr, exit_code))) => {
let duration_ms = start_time.elapsed().as_millis() as u64;
if exit_code == 0 {
// Parse JSON output
match parse_ast_grep_output(&stdout, max_matches) {
Ok((matches, truncated)) => {
SearchResult {
name: search.name,
mode: format!("{:?}", search.mode).to_lowercase(),
status: "ok".to_string(),
cmd: cmd_args,
match_count: Some(matches.len()),
truncated: Some(truncated),
matches: Some(matches),
stderr: if stderr.is_empty() { None } else { Some(stderr) },
exit_code: None,
duration_ms,
}
}
Err(e) => {
SearchResult {
name: search.name,
mode: format!("{:?}", search.mode).to_lowercase(),
status: "error".to_string(),
cmd: cmd_args,
match_count: None,
truncated: None,
matches: None,
stderr: Some(format!("JSON parse error: {}\nRaw output: {}", e, stdout)),
exit_code: Some(exit_code),
duration_ms,
}
}
}
} else {
SearchResult {
name: search.name,
mode: format!("{:?}", search.mode).to_lowercase(),
status: "error".to_string(),
cmd: cmd_args,
match_count: None,
truncated: None,
matches: None,
stderr: Some(stderr),
exit_code: Some(exit_code),
duration_ms,
}
}
}
Ok(Err(e)) => {
SearchResult {
name: search.name,
mode: format!("{:?}", search.mode).to_lowercase(),
status: "error".to_string(),
cmd: cmd_args,
match_count: None,
truncated: None,
matches: None,
stderr: Some(format!("Execution error: {}", e)),
exit_code: None,
duration_ms: start_time.elapsed().as_millis() as u64,
}
}
Err(_) => {
SearchResult {
name: search.name,
mode: format!("{:?}", search.mode).to_lowercase(),
status: "timeout".to_string(),
cmd: cmd_args,
match_count: None,
truncated: None,
matches: None,
stderr: Some(format!("Search timed out after {} seconds", timeout_secs)),
exit_code: None,
duration_ms: start_time.elapsed().as_millis() as u64,
}
}
}
}
/// Validate a search specification
fn validate_search_spec(search: &SearchSpec) -> Result<()> {
match search.mode {
SearchMode::Pattern => {
if search.pattern.is_none() || search.pattern.as_ref().unwrap().is_empty() {
return Err(anyhow!("Pattern mode requires non-empty 'pattern' field"));
}
}
SearchMode::Yaml => {
let rule_yaml = search.rule_yaml.as_ref()
.ok_or_else(|| anyhow!("YAML mode requires 'rule_yaml' field"))?;
if rule_yaml.is_empty() {
return Err(anyhow!("YAML mode requires non-empty 'rule_yaml' field"));
}
// Parse and validate YAML structure
let parsed: YamlRule = serde_yaml::from_str(rule_yaml)
.map_err(|e| anyhow!("Invalid YAML rule: {}", e))?;
if parsed.id.is_empty() {
return Err(anyhow!("YAML rule must have non-empty 'id' field"));
}
if parsed.language.is_empty() {
return Err(anyhow!("YAML rule must have non-empty 'language' field"));
}
// Validate language is supported (basic check)
validate_language(&parsed.language)?;
}
}
// Validate context range
if let Some(context) = search.context {
if context > 20 {
return Err(anyhow!("Context lines cannot exceed 20"));
}
}
Ok(())
}
/// Validate that a language is supported by ast-grep
fn validate_language(language: &str) -> Result<()> {
let supported_languages = [
"rust", "javascript", "typescript", "python", "java", "c", "cpp", "csharp",
"go", "html", "css", "json", "yaml", "xml", "bash", "kotlin", "swift",
"php", "ruby", "scala", "dart", "lua", "r", "sql", "dockerfile",
"Rust", "JavaScript", "TypeScript", "Python", "Java", "C", "Cpp", "CSharp",
"Go", "Html", "Css", "Json", "Yaml", "Xml", "Bash", "Kotlin", "Swift",
"Php", "Ruby", "Scala", "Dart", "Lua", "R", "Sql", "Dockerfile"
];
if !supported_languages.contains(&language) {
warn!("Language '{}' may not be supported by ast-grep", language);
}
Ok(())
}
/// Build ast-grep command arguments
fn build_ast_grep_command(search: &SearchSpec) -> Result<Vec<String>> {
let mut args = vec!["ast-grep".to_string()];
match search.mode {
SearchMode::Pattern => {
args.push("run".to_string());
// Add pattern
args.push("-p".to_string());
args.push(search.pattern.as_ref().unwrap().clone());
// Add language if specified
if let Some(ref lang) = search.language {
args.push("-l".to_string());
args.push(lang.clone());
}
}
SearchMode::Yaml => {
args.push("scan".to_string());
// Add inline rules
args.push("--inline-rules".to_string());
args.push(search.rule_yaml.as_ref().unwrap().clone());
// Add include-metadata if requested
if search.include_metadata.unwrap_or(false) {
args.push("--include-metadata".to_string());
}
// Add severity overrides
if let Some(ref severity_map) = search.severity {
for (rule_id, severity) in severity_map {
match severity {
SeverityLevel::Error => {
args.push("--error".to_string());
args.push(rule_id.clone());
}
SeverityLevel::Warning => {
args.push("--warning".to_string());
args.push(rule_id.clone());
}
SeverityLevel::Info => {
args.push("--info".to_string());
args.push(rule_id.clone());
}
SeverityLevel::Hint => {
args.push("--hint".to_string());
args.push(rule_id.clone());
}
SeverityLevel::Off => {
args.push("--off".to_string());
args.push(rule_id.clone());
}
}
}
}
}
}
// Add common arguments
// Add globs if specified
if let Some(ref globs) = search.globs {
if !globs.is_empty() {
args.push("--globs".to_string());
args.push(globs.join(","));
}
}
// Add context
if let Some(context) = search.context {
args.push("-C".to_string());
args.push(context.to_string());
}
// Add threads
if let Some(threads) = search.threads {
args.push("-j".to_string());
args.push(threads.to_string());
}
// Add JSON output style
let json_style = search.json_style.as_ref().unwrap_or(&JsonStyle::Stream);
let json_arg = match json_style {
JsonStyle::Pretty => "--json=pretty",
JsonStyle::Stream => "--json=stream",
JsonStyle::Compact => "--json=compact",
};
args.push(json_arg.to_string());
// Add no-ignore options
if let Some(ref no_ignore_list) = search.no_ignore {
for no_ignore_type in no_ignore_list {
let flag = match no_ignore_type {
NoIgnoreType::Hidden => "--no-ignore=hidden",
NoIgnoreType::Dot => "--no-ignore=dot",
NoIgnoreType::Exclude => "--no-ignore=exclude",
NoIgnoreType::Global => "--no-ignore=global",
NoIgnoreType::Parent => "--no-ignore=parent",
NoIgnoreType::Vcs => "--no-ignore=vcs",
};
args.push(flag.to_string());
}
}
// Add paths (default to current directory if none specified)
if let Some(ref paths) = search.paths {
if !paths.is_empty() {
args.extend(paths.clone());
} else {
args.push(".".to_string());
}
} else {
args.push(".".to_string());
}
Ok(args)
}
/// Run ast-grep command and capture output
async fn run_ast_grep_command(args: &[String]) -> Result<(String, String, i32)> {
let mut cmd = Command::new(&args[0]);
cmd.args(&args[1..]);
cmd.stdout(Stdio::piped());
cmd.stderr(Stdio::piped());
debug!("Running command: {:?}", args);
let mut child = cmd.spawn()
.map_err(|e| anyhow!("Failed to spawn ast-grep process: {}", e))?;
let stdout = child.stdout.take().unwrap();
let stderr = child.stderr.take().unwrap();
let stdout_reader = BufReader::new(stdout);
let stderr_reader = BufReader::new(stderr);
let stdout_task = tokio::spawn(async move {
let mut lines = stdout_reader.lines();
let mut output = String::new();
while let Ok(Some(line)) = lines.next_line().await {
if !output.is_empty() {
output.push('\n');
}
output.push_str(&line);
}
output
});
let stderr_task = tokio::spawn(async move {
let mut lines = stderr_reader.lines();
let mut output = String::new();
while let Ok(Some(line)) = lines.next_line().await {
if !output.is_empty() {
output.push('\n');
}
output.push_str(&line);
}
output
});
let status = child.wait().await
.map_err(|e| anyhow!("Failed to wait for ast-grep process: {}", e))?;
let stdout_output = stdout_task.await
.map_err(|e| anyhow!("Failed to read stdout: {}", e))?;
let stderr_output = stderr_task.await
.map_err(|e| anyhow!("Failed to read stderr: {}", e))?;
let exit_code = status.code().unwrap_or(-1);
Ok((stdout_output, stderr_output, exit_code))
}
/// Parse ast-grep JSON output
fn parse_ast_grep_output(output: &str, max_matches: usize) -> Result<(Vec<Value>, bool)> {
if output.trim().is_empty() {
return Ok((vec![], false));
}
let mut matches = Vec::new();
let mut truncated = false;
// Handle stream format (line-delimited JSON)
for line in output.lines() {
let line = line.trim();
if line.is_empty() {
continue;
}
match serde_json::from_str::<Value>(line) {
Ok(match_obj) => {
if matches.len() >= max_matches {
truncated = true;
break;
}
matches.push(match_obj);
}
Err(e) => {
debug!("Failed to parse JSON line '{}': {}", line, e);
// Try to parse the entire output as a single JSON array
match serde_json::from_str::<Vec<Value>>(output) {
Ok(array_matches) => {
let take_count = array_matches.len().min(max_matches);
let total_count = array_matches.len();
matches = array_matches.into_iter().take(take_count).collect();
truncated = take_count < total_count;
break;
}
Err(e2) => {
return Err(anyhow!(
"Failed to parse ast-grep output as line-delimited JSON or JSON array. Line error: {}, Array error: {}",
e, e2
));
}
}
}
}
}
Ok((matches, truncated))
}
/// Check if ast-grep is available and provide installation hints if not
async fn check_ast_grep_available() -> Result<()> {
match Command::new("ast-grep")
.arg("--version")
.output()
.await
{
Ok(output) => {
if output.status.success() {
let version = String::from_utf8_lossy(&output.stdout);
info!("Found ast-grep: {}", version.trim());
Ok(())
} else {
Err(anyhow!("ast-grep command failed: {}", String::from_utf8_lossy(&output.stderr)))
}
}
Err(_) => {
Err(anyhow!(
"ast-grep not found. Please install it using one of these methods:\n\n\
• Homebrew (macOS): brew install ast-grep\n\
• MacPorts (macOS): sudo port install ast-grep\n\
• Nix: nix-env -iA nixpkgs.ast-grep\n\
• Cargo: cargo install ast-grep\n\
• npm: npm install -g @ast-grep/cli\n\
• pip: pip install ast-grep\n\n\
For more installation options, visit: https://ast-grep.github.io/guide/quick-start.html"
))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_validate_pattern_search() {
let search = SearchSpec {
name: "test".to_string(),
mode: SearchMode::Pattern,
pattern: Some("fn $NAME() {}".to_string()),
language: Some("rust".to_string()),
rule_yaml: None,
paths: None,
globs: None,
json_style: None,
context: None,
threads: None,
include_metadata: None,
no_ignore: None,
severity: None,
timeout_secs: None,
};
assert!(validate_search_spec(&search).is_ok());
}
#[test]
fn test_validate_yaml_search() {
let yaml_rule = r#"
id: test-rule
language: Rust
rule:
pattern: "fn $NAME() {}"
"#;
let search = SearchSpec {
name: "test".to_string(),
mode: SearchMode::Yaml,
pattern: None,
language: None,
rule_yaml: Some(yaml_rule.to_string()),
paths: None,
globs: None,
json_style: None,
context: None,
threads: None,
include_metadata: None,
no_ignore: None,
severity: None,
timeout_secs: None,
};
assert!(validate_search_spec(&search).is_ok());
}
#[test]
fn test_build_pattern_command() {
let search = SearchSpec {
name: "test".to_string(),
mode: SearchMode::Pattern,
pattern: Some("fn $NAME() {}".to_string()),
language: Some("rust".to_string()),
rule_yaml: None,
paths: Some(vec!["src/".to_string()]),
globs: None,
json_style: Some(JsonStyle::Stream),
context: Some(2),
threads: Some(4),
include_metadata: None,
no_ignore: None,
severity: None,
timeout_secs: None,
};
let cmd = build_ast_grep_command(&search).unwrap();
assert_eq!(cmd[0], "ast-grep");
assert_eq!(cmd[1], "run");
assert!(cmd.contains(&"-p".to_string()));
assert!(cmd.contains(&"fn $NAME() {}".to_string()));
assert!(cmd.contains(&"-l".to_string()));
assert!(cmd.contains(&"rust".to_string()));
assert!(cmd.contains(&"--json=stream".to_string()));
assert!(cmd.contains(&"-C".to_string()));
assert!(cmd.contains(&"2".to_string()));
assert!(cmd.contains(&"-j".to_string()));
assert!(cmd.contains(&"4".to_string()));
assert!(cmd.contains(&"src/".to_string()));
}
#[test]
fn test_parse_stream_json() {
let output = r#"{"file":"test.rs","text":"fn hello() {}"}
{"file":"test2.rs","text":"fn world() {}"}"#;
let (matches, truncated) = parse_ast_grep_output(output, 10).unwrap();
assert_eq!(matches.len(), 2);
assert!(!truncated);
assert_eq!(matches[0]["file"], "test.rs");
assert_eq!(matches[1]["file"], "test2.rs");
}
#[test]
fn test_parse_truncated_output() {
let output = r#"{"file":"test1.rs","text":"fn a() {}"}
{"file":"test2.rs","text":"fn b() {}"}
{"file":"test3.rs","text":"fn c() {}"}"#;
let (matches, truncated) = parse_ast_grep_output(output, 2).unwrap();
assert_eq!(matches.len(), 2);
assert!(truncated);
}
}

View File

@@ -4,6 +4,11 @@
// 3. Only elide JSON content between first '{' and last '}' (inclusive)
// 4. Return everything else as the final filtered string
//! JSON tool call filtering for streaming LLM responses.
//!
//! This module filters out JSON tool calls from LLM output streams while preserving
//! regular text content. It uses a state machine to handle streaming chunks.
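//!
//! Rough illustration (hypothetical tool-call payload): the text around the JSON
//! object is passed through while the object itself is suppressed. Which chunk the
//! surrounding text is returned on depends on internal buffering.
//!
//! ```text
//! input:  "Reading the file now\n{\"tool\": \"read_file\", \"path\": \"src/main.rs\"}\nDone."
//! output: "Reading the file now\n" ... "\nDone."
//! ```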
use regex::Regex;
use std::cell::RefCell;
use tracing::debug;
@@ -13,37 +18,51 @@ thread_local! {
static FIXED_JSON_TOOL_STATE: RefCell<FixedJsonToolState> = RefCell::new(FixedJsonToolState::new());
}
/// Internal state for tracking JSON tool call filtering across streaming chunks.
#[derive(Debug, Clone)]
struct FixedJsonToolState {
/// True when actively suppressing a confirmed tool call
suppression_mode: bool,
/// True when buffering potential JSON (saw { but not yet confirmed as tool call)
potential_json_mode: bool,
/// Tracks nesting depth of braces within JSON
brace_depth: i32,
buffer: String,
json_start_in_buffer: Option<usize>,
json_start_in_buffer: Option<usize>, // Position where confirmed JSON tool call starts
content_returned_up_to: usize, // Track how much content we've already returned
potential_json_start: Option<usize>, // Where the potential JSON started
}
impl FixedJsonToolState {
fn new() -> Self {
Self {
suppression_mode: false,
potential_json_mode: false,
brace_depth: 0,
buffer: String::new(),
json_start_in_buffer: None,
content_returned_up_to: 0,
potential_json_start: None,
}
}
fn reset(&mut self) {
self.suppression_mode = false;
self.potential_json_mode = false;
self.brace_depth = 0;
self.buffer.clear();
self.json_start_in_buffer = None;
self.content_returned_up_to = 0;
self.potential_json_start = None;
}
}
// FINAL CORRECTED implementation according to specification
/// Filters JSON tool calls from streaming LLM content.
///
/// Processes content chunks and removes JSON tool calls while preserving regular text.
/// Maintains state across calls to handle tool calls spanning multiple chunks.
pub fn fixed_filter_json_tool_calls(content: &str) -> String {
if content.is_empty() {
return String::new();
@@ -87,13 +106,225 @@ pub fn fixed_filter_json_tool_calls(content: &str) -> String {
_ => {}
}
}
// CRITICAL FIX: After counting braces, if still in suppression mode,
// check if a new tool call pattern appears. This handles truncated JSON
// followed by complete JSON.
if state.suppression_mode {
let current_json_start = state.json_start_in_buffer.unwrap();
// Don't require newline - the new JSON might be concatenated directly
let tool_call_regex = Regex::new(r#"\{\s*"tool"\s*:\s*""#).unwrap();
// Look for new tool call patterns after the current one
if let Some(captures) = tool_call_regex.find(&state.buffer[current_json_start + 1..]) {
let new_json_start = current_json_start + 1 + captures.start() + captures.as_str().find('{').unwrap();
debug!("Detected new tool call at position {} while processing incomplete one at {} - discarding old", new_json_start, current_json_start);
// The previous JSON was incomplete/malformed
// Return content before the old JSON (if any)
let content_before_old_json = if current_json_start > state.content_returned_up_to {
state.buffer[state.content_returned_up_to..current_json_start].to_string()
} else {
String::new()
};
// Update state to skip the incomplete JSON and position at the new one
// We'll process the new JSON on the next call
state.content_returned_up_to = new_json_start;
state.suppression_mode = false;
state.json_start_in_buffer = None;
state.brace_depth = 0;
return content_before_old_json;
}
}
// Still in suppression mode, return empty string (content is being accumulated)
return String::new();
}
// Check if we're in potential JSON mode (saw { but waiting to confirm it's a tool call)
if state.potential_json_mode {
// Check if the buffer contains a confirmed tool call pattern
let tool_call_regex = Regex::new(r#"(?m)^\s*\{\s*"tool"\s*:\s*""#).unwrap();
if let Some(captures) = tool_call_regex.find(&state.buffer) {
// Confirmed! This is a tool call - enter suppression mode
let match_text = captures.as_str();
if let Some(brace_offset) = match_text.find('{') {
let json_start = captures.start() + brace_offset;
debug!("Confirmed JSON tool call at position {} - entering suppression mode", json_start);
state.potential_json_mode = false;
state.suppression_mode = true;
state.brace_depth = 0;
state.json_start_in_buffer = Some(json_start);
// Count braces from json_start to see if JSON is complete
let buffer_slice = state.buffer[json_start..].to_string();
for ch in buffer_slice.chars() {
match ch {
'{' => state.brace_depth += 1,
'}' => {
state.brace_depth -= 1;
if state.brace_depth <= 0 {
debug!("JSON tool call completed immediately");
let result = extract_fixed_content(&state.buffer, json_start);
let new_content = if result.len() > state.content_returned_up_to {
result[state.content_returned_up_to..].to_string()
} else {
String::new()
};
state.reset();
return new_content;
}
}
_ => {}
}
}
// JSON incomplete, stay in suppression mode, return nothing
return String::new();
}
}
// Check if we can rule out this being a tool call
// If we have enough content after the { and it doesn't match the pattern, release it
if let Some(potential_start) = state.potential_json_start {
let content_after_brace = &state.buffer[potential_start..];
// Rule out as a tool call if:
// 1. Closing } appears before we see the full pattern
// 2. Content clearly doesn't match the tool call pattern
// 3. Newline appears after the opening brace (tool calls should be compact)
let has_closing_brace = content_after_brace.contains('}');
let has_newline = content_after_brace[1..].contains('\n'); // Skip first char which is {
let long_enough = content_after_brace.len() >= 10;
// Detect non-tool JSON patterns:
// - { followed by " and a key that doesn't start with "tool"
// - { followed by "t" but not "to"
// - { followed by "to" but not "too", etc.
let not_tool_pattern = Regex::new(r#"^\{\s*"(?:[^t]|t(?:[^o]|o(?:[^o]|o(?:[^l]|l[^"\s:]))))"#).unwrap();
let definitely_not_tool = not_tool_pattern.is_match(content_after_brace);
if has_closing_brace || has_newline || (long_enough && definitely_not_tool) {
debug!("Potential JSON ruled out - not a tool call");
state.potential_json_mode = false;
state.potential_json_start = None;
// Return the buffered content we've been holding
let new_content = if state.buffer.len() > state.content_returned_up_to {
state.buffer[state.content_returned_up_to..].to_string()
} else {
String::new()
};
state.content_returned_up_to = state.buffer.len();
return new_content;
}
}
// Still in potential mode, keep buffering
return String::new();
}
// Detect potential JSON start: { at the beginning of a line
let potential_json_regex = Regex::new(r"(?m)^\s*\{\s*").unwrap();
if let Some(captures) = potential_json_regex.find(&state.buffer[state.content_returned_up_to..]) {
let match_start = state.content_returned_up_to + captures.start();
let brace_pos = match_start + captures.as_str().find('{').unwrap();
debug!("Potential JSON detected at position {} - entering buffering mode", brace_pos);
// Fast path: check if this is already a confirmed tool call
let tool_call_regex = Regex::new(r#"(?m)^\s*\{\s*"tool"\s*:\s*""#).unwrap();
if tool_call_regex.is_match(&state.buffer[brace_pos..]) {
// This is a confirmed tool call! Process it immediately
let json_start = brace_pos;
debug!("Immediately confirmed tool call at position {}", json_start);
// Return content before JSON
let content_before = if json_start > state.content_returned_up_to {
state.buffer[state.content_returned_up_to..json_start].to_string()
} else {
String::new()
};
state.content_returned_up_to = json_start;
state.suppression_mode = true;
state.brace_depth = 0;
state.json_start_in_buffer = Some(json_start);
// Count braces to see if JSON is complete
let buffer_slice = state.buffer[json_start..].to_string();
for ch in buffer_slice.chars() {
match ch {
'{' => state.brace_depth += 1,
'}' => {
state.brace_depth -= 1;
if state.brace_depth <= 0 {
debug!("JSON tool call completed in same chunk");
let result = extract_fixed_content(&state.buffer, json_start);
let content_after = if result.len() > json_start {
&result[json_start..]
} else {
""
};
let final_result = format!("{}{}", content_before, content_after);
state.reset();
return final_result;
}
}
_ => {}
}
}
// JSON incomplete, return content before and stay in suppression mode
return content_before;
}
// Return content before the potential JSON
let content_before = if brace_pos > state.content_returned_up_to {
state.buffer[state.content_returned_up_to..brace_pos].to_string()
} else {
String::new()
};
state.content_returned_up_to = brace_pos;
state.potential_json_mode = true;
state.potential_json_start = Some(brace_pos);
// Optimization: immediately check if we can rule this out for single-chunk processing
let content_after_brace = &state.buffer[brace_pos..];
let has_closing_brace = content_after_brace.contains('}');
let has_newline = content_after_brace.len() > 1 && content_after_brace[1..].contains('\n');
let long_enough = content_after_brace.len() >= 10;
let not_tool_pattern = Regex::new(r#"^\{\s*"(?:[^t]|t(?:[^o]|o(?:[^o]|o(?:[^l]|l[^"\s:]))))"#).unwrap();
let definitely_not_tool = not_tool_pattern.is_match(content_after_brace);
if has_closing_brace || has_newline || (long_enough && definitely_not_tool) {
debug!("Immediately ruled out as not a tool call");
state.potential_json_mode = false;
state.potential_json_start = None;
// Return all the buffered content
let new_content = if state.buffer.len() > state.content_returned_up_to {
state.buffer[state.content_returned_up_to..].to_string()
} else {
String::new()
};
state.content_returned_up_to = state.buffer.len();
return format!("{}{}", content_before, new_content);
}
return content_before;
}
// Check for tool call pattern using corrected regex
// More flexible than the strict specification to handle real-world JSON
let tool_call_regex = Regex::new(r#"(?m)^\s*\{\s*"tool"\s*:\s*""#).unwrap();
let tool_call_regex = Regex::new(r#"(?m)^\s*\{\s*"tool"\s*:\s*"[^"]*""#).unwrap();
if let Some(captures) = tool_call_regex.find(&state.buffer) {
let match_text = captures.as_str();
@@ -168,9 +399,17 @@ pub fn fixed_filter_json_tool_calls(content: &str) -> String {
})
}
// Helper function to extract content with JSON tool call filtered out
// Returns everything except the JSON between the first '{' and last '}' (inclusive)
/// Extracts content from buffer, removing the JSON tool call.
///
/// Given a buffer and the start position of a JSON tool call, this function:
/// 1. Extracts all content before the JSON
/// 2. Finds the end of the JSON (matching closing brace)
/// 3. Extracts all content after the JSON
/// 4. Returns the concatenation of before + after (JSON removed)
///
/// # Arguments
/// * `full_content` - The full content buffer
/// * `json_start` - Position where the JSON tool call begins
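///
/// # Example (illustrative sketch; the values are assumed, not taken from the tests)
///
/// ```ignore
/// let buf = "Before\n{\"tool\": \"shell\", \"args\": {}}\nAfter";
/// // json_start points at the '{' that begins the tool call (byte 7 here).
/// assert_eq!(extract_fixed_content(buf, 7), "Before\n\nAfter");
/// ```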
fn extract_fixed_content(full_content: &str, json_start: usize) -> String {
// Find the end of the JSON using proper brace counting with string handling
let mut brace_depth = 0;
@@ -212,8 +451,10 @@ fn extract_fixed_content(full_content: &str, json_start: usize) -> String {
format!("{}{}", before, after)
}
// Reset function for testing
/// Resets the global JSON filtering state.
///
/// Call this between independent filtering sessions to ensure clean state.
/// This is particularly important in tests and when starting new conversations.
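///
/// # Example (illustrative sketch; chunk values are assumptions)
///
/// ```ignore
/// // Stream chunks through the filter, collecting the cleaned output,
/// // then reset the thread-local state before the next independent session.
/// let chunks = ["Text\n", "{\"tool\": \"shell\", \"args\": {\"command\": \"ls\"}}", "\nAfter"];
/// let cleaned: String = chunks
///     .iter()
///     .map(|chunk| fixed_filter_json_tool_calls(chunk))
///     .collect();
/// // `cleaned` should keep the surrounding text with the tool-call JSON removed.
/// reset_fixed_json_tool_state();
/// ```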
pub fn reset_fixed_json_tool_state() {
FIXED_JSON_TOOL_STATE.with(|state| {
let mut state = state.borrow_mut();

View File

@@ -1,8 +1,14 @@
//! Tests for JSON tool call filtering.
//!
//! These tests verify that the filter correctly identifies and removes JSON tool calls
//! from LLM output streams while preserving all other content.
#[cfg(test)]
mod fixed_filter_tests {
use crate::fixed_filter_json::{fixed_filter_json_tool_calls, reset_fixed_json_tool_state};
use regex::Regex;
/// Test that regular text without tool calls passes through unchanged.
#[test]
fn test_no_tool_call_passthrough() {
reset_fixed_json_tool_state();
@@ -11,6 +17,7 @@ mod fixed_filter_tests {
assert_eq!(result, input);
}
/// Test detection and removal of a complete tool call in a single chunk.
#[test]
fn test_simple_tool_call_detection() {
reset_fixed_json_tool_state();
@@ -23,6 +30,7 @@ Some text after"#;
assert_eq!(result, expected);
}
/// Test handling of tool calls that arrive across multiple streaming chunks.
#[test]
fn test_streaming_chunks() {
reset_fixed_json_tool_state();
@@ -48,6 +56,7 @@ Some text after"#;
assert_eq!(final_result, expected);
}
/// Test correct handling of nested braces within JSON strings.
#[test]
fn test_nested_braces_in_tool_call() {
reset_fixed_json_tool_state();
@@ -61,6 +70,7 @@ Text after"#;
assert_eq!(result, expected);
}
/// Verify the regex pattern matches the specification with flexible whitespace.
#[test]
fn test_regex_pattern_specification() {
// Test the corrected regex pattern that's more flexible with whitespace
@@ -84,11 +94,6 @@ Text after"#;
), // Space after { DOES match with \s*
(
r#"line
abc{"tool":"#,
true,
),
(
r#"line
{"tool123":"#,
false,
), // "tool123" is not exactly "tool"
@@ -109,6 +114,7 @@ abc{"tool":"#,
}
}
/// Test that tool calls must appear at the start of a line (after newline).
#[test]
fn test_newline_requirement() {
reset_fixed_json_tool_state();
@@ -122,13 +128,14 @@ abc{"tool":"#,
reset_fixed_json_tool_state();
let result2 = fixed_filter_json_tool_calls(input_without_newline);
// Both cases currently trigger suppression due to regex pattern
// TODO: Fix regex to only match after actual newlines
// With the new aggressive filtering, only the newline case should trigger suppression
// The pattern requires { to be at the start of a line (after ^)
assert_eq!(result1, "Text\n");
// This currently fails because our regex matches both cases
assert_eq!(result2, "Text ");
// Without newline before {, it should pass through unchanged
assert_eq!(result2, input_without_newline);
}
/// Test handling of escaped quotes within JSON strings.
#[test]
fn test_json_with_escaped_quotes() {
reset_fixed_json_tool_state();
@@ -142,6 +149,7 @@ More text"#;
assert_eq!(result, expected);
}
/// Test graceful handling of incomplete/malformed JSON.
#[test]
fn test_edge_case_malformed_json() {
reset_fixed_json_tool_state();
@@ -157,6 +165,7 @@ More text"#;
assert_eq!(result, expected);
}
/// Test processing multiple independent tool calls sequentially.
#[test]
fn test_multiple_tool_calls_sequential() {
reset_fixed_json_tool_state();
@@ -179,6 +188,7 @@ Final text"#;
assert_eq!(result2, expected2);
}
/// Test tool calls with complex multi-line arguments.
#[test]
fn test_tool_call_with_complex_args() {
reset_fixed_json_tool_state();
@@ -192,6 +202,7 @@ After"#;
assert_eq!(result, expected);
}
/// Test input containing only a tool call with no surrounding text.
#[test]
fn test_tool_call_only() {
reset_fixed_json_tool_state();
@@ -204,6 +215,7 @@ After"#;
assert_eq!(result, expected);
}
/// Test accurate brace counting with deeply nested structures.
#[test]
fn test_brace_counting_accuracy() {
reset_fixed_json_tool_state();
@@ -218,6 +230,7 @@ End"#;
assert_eq!(result, expected);
}
/// Test that braces within strings don't affect brace counting.
#[test]
fn test_string_escaping_in_json() {
reset_fixed_json_tool_state();
@@ -232,6 +245,7 @@ More"#;
assert_eq!(result, expected);
}
/// Verify compliance with the exact specification requirements.
#[test]
fn test_specification_compliance() {
reset_fixed_json_tool_state();
@@ -248,6 +262,7 @@ More"#;
assert_eq!(result, expected);
}
/// Test that non-tool JSON objects are not filtered.
#[test]
fn test_no_false_positives() {
reset_fixed_json_tool_state();
@@ -261,6 +276,7 @@ More text"#;
assert_eq!(result, input);
}
/// Test patterns that look similar to tool calls but aren't exact matches.
#[test]
fn test_partial_tool_patterns() {
reset_fixed_json_tool_state();
@@ -280,6 +296,7 @@ More text"#;
}
}
/// Test streaming with very small chunks (character-by-character).
#[test]
fn test_streaming_edge_cases() {
reset_fixed_json_tool_state();
@@ -296,12 +313,13 @@ More text"#;
}
let final_result: String = results.join("");
// This test currently fails because the JSON is incomplete across chunks
// The function doesn't handle this edge case properly yet
let expected = "Text\n{\"tool\": \nAfter";
// With the new aggressive filtering, the JSON should be completely filtered out
// even when it arrives in very small chunks
let expected = "Text\n\nAfter";
assert_eq!(final_result, expected);
}
/// Debug test with detailed logging for streaming behavior.
#[test]
fn test_streaming_debug() {
reset_fixed_json_tool_state();
@@ -329,4 +347,38 @@ More text"#;
let expected = "Some text before\n\nText after";
assert_eq!(final_result, expected);
}
/// Test handling of truncated JSON followed by complete JSON (the json_err pattern)
#[test]
fn test_truncated_then_complete_json() {
reset_fixed_json_tool_state();
// Simulate the pattern from json_err trace:
// 1. Incomplete/truncated JSON appears
// 2. Then the same complete JSON appears
let chunks = vec![
"Some text\n",
r#"{"tool": "str_replace", "args": {"diff":"...","file_path":"./crates/g3-cli"#, // Truncated
r#"{"tool": "str_replace", "args": {"diff":"...","file_path":"./crates/g3-cli/src/lib.rs"}}"#, // Complete
"\nMore text",
];
let mut results = Vec::new();
for (i, chunk) in chunks.iter().enumerate() {
let result = fixed_filter_json_tool_calls(chunk);
println!("Chunk {}: {:?} -> {:?}", i, chunk, result);
results.push(result);
}
let final_result: String = results.join("");
println!("Final result: {:?}", final_result);
// The truncated JSON should be discarded when the complete one appears
// Both JSONs should be filtered out, leaving only the text
let expected = "Some text\n\nMore text";
assert_eq!(
final_result, expected,
"Failed to handle truncated JSON followed by complete JSON"
);
}
}

View File

@@ -2,6 +2,7 @@ pub mod error_handling;
pub mod project;
pub mod task_result;
pub mod ui_writer;
pub mod code_search;
pub use task_result::TaskResult;
#[cfg(test)]
@@ -680,6 +681,8 @@ pub struct Agent<W: UiWriter> {
providers: ProviderRegistry,
context_window: ContextWindow,
thinning_events: Vec<usize>, // chars saved per thinning event
pending_90_summarization: bool, // flag to trigger summarization at 90%
auto_compact: bool, // whether to auto-compact at 90% before tool calls
summarization_events: Vec<usize>, // chars saved per summarization event
first_token_times: Vec<Duration>, // time to first token for each completion
config: Config,
@@ -785,7 +788,6 @@ impl<W: UiWriter> Agent<W> {
// Register embedded provider if configured AND it's the default provider
if let Some(embedded_config) = &config.providers.embedded {
if providers_to_register.contains(&"embedded".to_string()) {
info!("Initializing embedded provider");
let embedded_provider = g3_providers::EmbeddedProvider::new(
embedded_config.model_path.clone(),
embedded_config.model_type.clone(),
@@ -796,15 +798,12 @@ impl<W: UiWriter> Agent<W> {
embedded_config.threads,
)?;
providers.register(embedded_provider);
} else {
info!("Embedded provider configured but not needed, skipping initialization");
}
}
// Register OpenAI provider if configured AND it's the default provider
if let Some(openai_config) = &config.providers.openai {
if providers_to_register.contains(&"openai".to_string()) {
info!("Initializing OpenAI provider");
let openai_provider = g3_providers::OpenAIProvider::new(
openai_config.api_key.clone(),
Some(openai_config.model.clone()),
@@ -813,15 +812,12 @@ impl<W: UiWriter> Agent<W> {
openai_config.temperature,
)?;
providers.register(openai_provider);
} else {
info!("OpenAI provider configured but not needed, skipping initialization");
}
}
// Register Anthropic provider if configured AND it's the default provider
if let Some(anthropic_config) = &config.providers.anthropic {
if providers_to_register.contains(&"anthropic".to_string()) {
info!("Initializing Anthropic provider");
let anthropic_provider = g3_providers::AnthropicProvider::new(
anthropic_config.api_key.clone(),
Some(anthropic_config.model.clone()),
@@ -829,15 +825,12 @@ impl<W: UiWriter> Agent<W> {
anthropic_config.temperature,
)?;
providers.register(anthropic_provider);
} else {
info!("Anthropic provider configured but not needed, skipping initialization");
}
}
// Register Databricks provider if configured AND it's the default provider
if let Some(databricks_config) = &config.providers.databricks {
if providers_to_register.contains(&"databricks".to_string()) {
info!("Initializing Databricks provider");
let databricks_provider = if let Some(token) = &databricks_config.token {
// Use token-based authentication
@@ -860,8 +853,19 @@ impl<W: UiWriter> Agent<W> {
};
providers.register(databricks_provider);
} else {
info!("Databricks provider configured but not needed, skipping initialization");
}
}
// Register Ollama provider if configured AND it's the default provider
if let Some(ollama_config) = &config.providers.ollama {
if providers_to_register.contains(&"ollama".to_string()) {
let ollama_provider = g3_providers::OllamaProvider::new(
ollama_config.model.clone(),
ollama_config.base_url.clone(),
ollama_config.max_tokens,
ollama_config.temperature,
)?;
providers.register(ollama_provider);
}
}
@@ -884,16 +888,12 @@ impl<W: UiWriter> Agent<W> {
content: readme,
};
context_window.add_message(readme_message);
info!("Added project README to context window");
}
// Initialize computer controller if enabled
let computer_controller = if config.computer_control.enabled {
match g3_computer_control::create_controller() {
Ok(controller) => {
info!("Computer control enabled");
Some(controller)
}
Ok(controller) => Some(controller),
Err(e) => {
warn!("Failed to initialize computer control: {}", e);
None
@@ -909,6 +909,8 @@ impl<W: UiWriter> Agent<W> {
Ok(Self {
providers,
context_window,
auto_compact: config.agent.auto_compact,
pending_90_summarization: false,
thinning_events: Vec::new(),
summarization_events: Vec::new(),
first_token_times: Vec::new(),
@@ -973,6 +975,30 @@ impl<W: UiWriter> Agent<W> {
16384 // Conservative default for other Databricks models
}
}
"ollama" => {
// Ollama model context windows based on model name
if model_name.contains("qwen3-coder") {
262144 // Qwen3-coder supports 256k context
} else if model_name.contains("qwen") {
32768 // Qwen2.5 supports 32k context
} else if model_name.contains("gpt-oss") {
131072 // GPT-OSS supports 128k context
} else if model_name.contains("llama3") || model_name.contains("llama-3") {
if model_name.contains("3.2") || model_name.contains("3.1") {
128000 // Llama 3.1/3.2 support 128k context
} else {
8192 // Llama 3.0
}
} else if model_name.contains("mistral") || model_name.contains("mixtral") {
32768 // Mistral/Mixtral support 32k
} else if model_name.contains("gemma") {
8192 // Gemma 2
} else if model_name.contains("phi") {
4096 // Phi-3
} else {
8192 // Conservative default for Ollama models
}
}
_ => config.agent.max_context_length as u32,
};
@@ -1104,6 +1130,18 @@ IMPORTANT: You must call tools to achieve goals. When you receive a request:
For shell commands: Use the shell tool with the exact command needed. Avoid commands that produce a large amount of output, and consider piping those outputs to files. Example: If asked to list files, immediately call the shell tool with command parameter \"ls\".
If you create temporary files for verification, place these in a subdir named 'tmp'. Do NOT pollute the current dir.
For reading files, prefer the code_search tool with multiple search requests per call over read_file when it makes sense.
Additional examples for the 'code_search' tool:
- Example for pattern mode: {\"tool\": \"code_search\", \"args\": {\"searches\": [{\"name\": \"find_functions\", \"mode\": \"pattern\", \"pattern\": \"fn $NAME($$$ARGS) { $$$ }\", \"language\": \"rust\", \"paths\": [\"src/\"]}]}}
- Example for YAML mode: {\"tool\": \"code_search\", \"args\": {\"searches\": [{\"name\": \"find_async\", \"mode\": \"yaml\", \"rule_yaml\": \"id: async-fn\nlanguage: Rust\nrule:\n pattern: async fn $NAME($$$) { $$$ }\"}]}}
- Example for multiple searches: {\"tool\": \"code_search\", \"args\": {\"searches\": [{\"name\": \"funcs\", \"mode\": \"pattern\", \"pattern\": \"fn $NAME\", \"language\": \"rust\"}, {\"name\": \"structs\", \"mode\": \"pattern\", \"pattern\": \"struct $NAME\", \"language\": \"rust\"}]}}
- Example for passing optional args like \"context\": {\"tool\": \"code_search\", \"args\": {\"searches\": [{\"name\": \"funcs\", \"mode\": \"pattern\", \"context\": 3, \"pattern\": \"fn $NAME\", \"language\": \"rust\"}]}}
- Common optional args for searches:
- \"context\": 3 (show surrounding lines),
- \"json_style\": \"stream\" (for large results)
IMPORTANT: If the user asks you to just respond with text (like \"just say hello\" or \"tell me about X\"), do NOT use tools. Simply respond with the requested text directly. Only use tools when you need to execute commands or complete tasks that require action.
When taking screenshots of specific windows (like \"my Safari window\" or \"my terminal\"), ALWAYS use list_windows first to identify the correct window ID, then use take_screenshot with the window_id parameter.
@@ -1153,6 +1191,8 @@ The tool will execute immediately and you'll receive the result (success or erro
# Available Tools
Short tool descriptions for providers without native tool-calling support:
- **shell**: Execute shell commands
- Format: {\"tool\": \"shell\", \"args\": {\"command\": \"your_command_here\"}
- Example: {\"tool\": \"shell\", \"args\": {\"command\": \"ls ~/Downloads\"}
@@ -1181,13 +1221,41 @@ The tool will execute immediately and you'll receive the result (success or erro
- Format: {\"tool\": \"todo_write\", \"args\": {\"content\": \"- [ ] Task 1\\n- [ ] Task 2\"}}
- Example: {\"tool\": \"todo_write\", \"args\": {\"content\": \"- [ ] Implement feature\\n - [ ] Write tests\\n - [ ] Run tests\"}}
- **code_search**: Batch syntax-aware searches via ast-grep. Supports up to 20 pattern or YAML-rule searches in parallel.
- Format: {\"tool\": \"code_search\", \"args\": {\"searches\": [{\"name\": \"search_label\", \"mode\": \"pattern|yaml\", ...}], \"max_concurrency\": 4, \"max_matches_per_search\": 500}}
- Example for pattern mode: {\"tool\": \"code_search\", \"args\": {\"searches\": [{\"name\": \"find_functions\", \"mode\": \"pattern\", \"pattern\": \"fn $NAME($$$ARGS) { $$$ }\", \"language\": \"rust\", \"paths\": [\"src/\"]}]}}
- Example for YAML mode: {\"tool\": \"code_search\", \"args\": {\"searches\": [{\"name\": \"find_async\", \"mode\": \"yaml\", \"rule_yaml\": \"id: async-fn\nlanguage: Rust\nrule:\n pattern: async fn $NAME($$$) { $$$ }\"}]}}
- Example for multiple searches: {\"tool\": \"code_search\", \"args\": {\"searches\": [{\"name\": \"funcs\", \"mode\": \"pattern\", \"pattern\": \"fn $NAME\", \"language\": \"rust\"}, {\"name\": \"structs\", \"mode\": \"pattern\", \"pattern\": \"struct $NAME\", \"language\": \"rust\"}]}}
- Example for passing optional args like \"context\": {\"tool\": \"code_search\", \"args\": {\"searches\": [{\"name\": \"funcs\", \"mode\": \"pattern\", \"context\": 3, \"pattern\": \"fn $NAME\", \"language\": \"rust\"}]}}
- Common optional args for searches:
- \"context\": 3 (show surrounding lines),
- \"json_style\": \"stream\" (for large results)
# Instructions
1. Analyze the request and break down into smaller tasks if appropriate
2. Execute ONE tool at a time
2. Execute ONE tool at a time. An exception exists for when you're writing files. See below.
3. STOP when the original request was satisfied
4. Call the final_output tool when done
For reading files, prefer the code_search tool with multiple search requests per call over read_file when it makes sense.
Exception to using ONE tool at a time:
If all you're doing is WRITING files and you don't need to do anything else between steps,
you can issue MULTIPLE write_file tool calls in a single request; however, you may make ONLY a SINGLE write_file call for any given file in that request.
For example you may call:
[START OF REQUEST]
write_file(\"helper.rs\", \"...\")
write_file(\"file2.txt\", \"...\")
[DONE]
But NOT:
[START OF REQUEST]
write_file(\"helper.rs\", \"...\")
write_file(\"file2.txt\", \"...\")
write_file(\"helper.rs\", \"...\")
[DONE]
# Task Management
Use todo_read and todo_write for tasks with 3+ steps, multiple files/components, or uncertain scope.
@@ -1312,6 +1380,19 @@ Template:
// Save context window at the end of successful interaction
self.save_context_window("completed");
// Check if we need to do 90% auto-compaction
if self.pending_90_summarization {
self.ui_writer.print_context_status(
"\n⚡ Context window reached 90% - auto-compacting...\n"
);
if let Err(e) = self.force_summarize().await {
warn!("Failed to auto-compact at 90%: {}", e);
} else {
self.ui_writer.println("");
}
self.pending_90_summarization = false;
}
// Return the task result which already includes timing if needed
Ok(task_result)
}
@@ -1860,6 +1941,64 @@ Template:
}),
},
];
// Add code_search tool
tools.push(Tool {
name: "code_search".to_string(),
description: "Batch syntax-aware searches via ast-grep. Supports up to 20 pattern or YAML-rule searches in parallel; returns JSON matches (stream-collated).".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"searches": {
"type": "array",
"maxItems": 20,
"items": {
"type": "object",
"properties": {
"name": { "type": "string", "description": "Label for this search." },
"mode": {
"type": "string",
"enum": ["pattern", "yaml"],
"description": "`pattern` uses `ast-grep run`; `yaml` uses `ast-grep scan --inline-rules`."
},
// pattern mode (fast one-off)
"pattern": { "type": "string", "description": "ast-grep pattern code (e.g., \"async fn $NAME($$$ARGS) { $$$ }\")"},
"language": { "type": "string", "description": "Optional language for pattern mode; ast-grep may infer from file extensions if omitted." },
// yaml mode (full rule object)
"rule_yaml": { "type": "string", "description": "A full YAML rule object text. Must include `id`, `language`, and `rule`." },
// targeting
"paths": { "type": "array", "items": { "type": "string" }, "description": "Paths/dirs to search. Defaults to current dir if empty." },
"globs": { "type": "array", "items": { "type": "string" }, "description": "Optional include/exclude globs for CLI --globs." },
// result formatting & performance knobs
"json_style": { "type": "string", "enum": ["pretty","stream","compact"], "default": "stream", "description": "Use stream for large codebases." },
"context": { "type": "integer", "minimum": 0, "maximum": 20, "default": 0, "description": "CLI -C context lines in text output; also affects JSON `lines` field." },
"threads": { "type": "integer", "minimum": 1, "description": "Optional override for ast-grep -j (per process)." },
"include_metadata": { "type": "boolean", "default": false, "description": "If yaml mode and rule has metadata, add --include-metadata." },
// robustness
"no_ignore": {
"type": "array",
"items": { "type": "string", "enum": ["hidden","dot","exclude","global","parent","vcs"] },
"description": "Forwarded to --no-ignore to bypass ignore files/hidden."
},
// severity overrides for yaml mode
"severity": {
"type": "object",
"additionalProperties": { "type": "string", "enum": ["error","warning","info","hint","off"] },
"description": "Optional map<ruleId, severity> -> passed via --error/--warning/--info/--hint/--off."
},
// per-search timeout seconds (default 60)
"timeout_secs": { "type": "integer", "minimum": 1, "default": 60 }
},
"required": ["name","mode"]
}
},
// global concurrency & truncation
"max_concurrency": { "type": "integer", "minimum": 1, "default": 4 },
"max_matches_per_search": { "type": "integer", "minimum": 1, "default": 500 }
},
"required": ["searches"]
}),
});
// Add WebDriver tools if enabled
if enable_webdriver {
@@ -2550,6 +2689,14 @@ Template:
if let Some(tool_call) = completed_tools.into_iter().next() {
debug!("Processing completed tool call: {:?}", tool_call);
// Check if we should auto-compact at 90% BEFORE executing the tool
// We need to do this before any borrows of self
if self.auto_compact && self.context_window.percentage_used() >= 90.0 {
// Set flag to trigger summarization after this turn completes
// We can't do it now due to borrow checker constraints
self.pending_90_summarization = true;
}
// Check if we should thin the context BEFORE executing the tool
if self.context_window.should_thin() {
let (thin_summary, chars_saved) = self.context_window.thin_context();
@@ -2558,6 +2705,7 @@ Template:
self.ui_writer.print_context_thinning(&thin_summary);
}
// Track what we've already displayed before getting new text
// This prevents re-displaying old content after tool execution
let already_displayed_chars = current_response.chars().count();
@@ -2596,7 +2744,8 @@ Template:
String::new()
};
if !new_content.trim().is_empty() {
// Don't display text before final_output - it will be in the summary
if !new_content.trim().is_empty() && tool_call.tool != "final_output" {
#[allow(unused_assignments)]
if !response_started {
self.ui_writer.print_agent_prompt();
@@ -2674,7 +2823,13 @@ Template:
));
// Display tool execution result with proper indentation
if tool_call.tool != "final_output" {
if tool_call.tool == "final_output" {
// For final_output, display the summary without truncation
for line in tool_result.lines() {
self.ui_writer.update_tool_output_line(line);
}
self.ui_writer.println("");
} else {
let output_lines: Vec<&str> = tool_result.lines().collect();
// Check if UI wants full output (machine mode) or truncated (human mode)
@@ -2722,13 +2877,9 @@ Template:
// Check if this was a final_output tool call
if tool_call.tool == "final_output" {
// Don't add final_display_content here - it was already added before tool execution
// Adding it again would duplicate the output
if let Some(summary) = tool_call.args.get("summary") {
if let Some(summary_str) = summary.as_str() {
full_response.push_str(&format!("\n\n{}", summary_str));
}
}
// The summary was displayed above when we printed the tool result
// Add it to full_response so it's included in the TaskResult
full_response.push_str(&tool_result);
self.ui_writer.println("");
let _ttft =
first_token_time.unwrap_or_else(|| stream_start.elapsed());
@@ -2760,7 +2911,7 @@ Template:
// Add the tool call and result to the context window using RAW unfiltered content
// This ensures the log file contains the true raw content including JSON tool calls
let tool_message = if !full_response.contains(final_display_content) && !raw_content_for_log.trim().is_empty() {
let tool_message = if !raw_content_for_log.trim().is_empty() {
Message {
role: MessageRole::Assistant,
content: format!(
@@ -2771,7 +2922,7 @@ Template:
),
}
} else {
// If we've already added the text or there's no text, just include the tool call
// No text content before tool call, just include the tool call
Message {
role: MessageRole::Assistant,
content: format!(
@@ -2796,18 +2947,22 @@ Template:
request.tools = Some(Self::create_tool_definitions(self.config.webdriver.enabled, self.config.macax.enabled, self.config.computer_control.enabled));
}
// Only add to full_response if we haven't already added it
if !full_response.contains(final_display_content) {
full_response.push_str(final_display_content);
}
// DO NOT add final_display_content to full_response here!
// The content was already displayed during streaming and added to current_response.
// Adding it again would cause duplication when the agent message is printed.
// The only time we should add to full_response is:
// 1. For final_output tool (handled separately)
// 2. At the end when no tools were executed (handled in the "no tool executed" branch)
tool_executed = true;
// Reset the JSON tool call filter state after each tool execution
// This ensures the filter doesn't stay in suppression mode for subsequent streaming content
fixed_filter_json::reset_fixed_json_tool_state();
// Reset parser for next iteration
// Reset parser for next iteration - this clears the text buffer
parser.reset();
// Clear current_response for next iteration to prevent buffered text
// from being incorrectly displayed after tool execution
current_response.clear();
@@ -2883,7 +3038,8 @@ Template:
"Using filtered parser text as last resort: {} chars",
filtered_text.len()
);
current_response = filtered_text;
// Note: This assignment is currently unused but kept for potential future use
let _ = filtered_text;
}
}
@@ -2985,11 +3141,10 @@ Template:
// Set full_response to current_response (don't append)
// current_response already contains everything that was displayed
// Appending would duplicate the output
if !current_response.is_empty() && full_response.is_empty() {
full_response = current_response.clone();
debug!("Set full_response from current_response (no tool): {} chars", full_response.len());
}
// Don't set full_response here - it would duplicate the output
// The text was already displayed during streaming
// Return empty string to avoid duplication
full_response = String::new();
self.ui_writer.println("");
let _ttft =
@@ -3017,17 +3172,33 @@ Template:
}
Err(e) => {
// Capture detailed streaming error information
let error_details =
format!("Streaming error at chunk {}: {}", chunks_received + 1, e);
error!("{}", error_details);
let error_msg = e.to_string();
let error_details = format!("Streaming error at chunk {}: {}", chunks_received + 1, error_msg);
error!("Error type: {}", std::any::type_name_of_val(&e));
error!("Parser state at error: text_buffer_len={}, native_tool_calls={}, message_stopped={}",
parser.text_buffer_len(), parser.native_tool_calls.len(), parser.is_message_stopped());
// Store the error for potential logging later
_last_error = Some(error_details);
_last_error = Some(error_details.clone());
// Check if this is a recoverable connection error
let is_connection_error = error_msg.contains("unexpected EOF")
|| error_msg.contains("connection")
|| error_msg.contains("chunk size line")
|| error_msg.contains("body error");
if is_connection_error {
warn!("Connection error at chunk {}, treating as end of stream", chunks_received + 1);
// If we have any content or tool calls, treat this as a graceful end
if chunks_received > 0 && (!parser.get_text_content().is_empty() || parser.native_tool_calls.len() > 0) {
warn!("Stream terminated unexpectedly but we have content, continuing");
break; // Break to process what we have
}
}
if tool_executed {
error!("{}", error_details);
warn!("Stream error after tool execution, attempting to continue");
break; // Break to outer loop to start new stream
} else {
@@ -4431,6 +4602,41 @@ Template:
Ok("❌ Computer control not enabled. Set computer_control.enabled = true in config.".to_string())
}
}
"code_search" => {
debug!("Processing code_search tool call");
// Parse the request
let request: crate::code_search::CodeSearchRequest = match serde_json::from_value(tool_call.args.clone()) {
Ok(req) => req,
Err(e) => {
return Ok(format!("❌ Invalid code_search arguments: {}", e));
}
};
// Execute the code search
match crate::code_search::execute_code_search(request).await {
Ok(response) => {
// Serialize the response to JSON
match serde_json::to_string_pretty(&response) {
Ok(json_output) => {
Ok(format!("✅ Code search completed\n{}", json_output))
}
Err(e) => {
Ok(format!("❌ Failed to serialize response: {}", e))
}
}
}
Err(e) => {
// Check if it's an ast-grep not found error and provide helpful message
let error_msg = e.to_string();
if error_msg.contains("ast-grep not found") {
Ok(format!("{}", error_msg))
} else {
Ok(format!("❌ Code search failed: {}", error_msg))
}
}
}
}
_ => {
warn!("Unknown tool: {}", tool_call.tool);
Ok(format!("❓ Unknown tool: {}", tool_call.tool))

View File

@@ -276,6 +276,7 @@ impl AnthropicProvider {
let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
let mut accumulated_usage: Option<Usage> = None;
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
let mut message_stopped = false; // Track if we've received message_stop
while let Some(chunk_result) = stream.next().await {
match chunk_result {
@@ -316,6 +317,12 @@ impl AnthropicProvider {
continue;
}
// If we've already sent the final chunk, skip processing more events
if message_stopped {
debug!("Skipping event after message_stop: {}", line);
continue;
}
// Parse Server-Sent Events format
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
@@ -451,6 +458,7 @@ impl AnthropicProvider {
}
"message_stop" => {
debug!("Received message stop event");
message_stopped = true;
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
@@ -460,7 +468,8 @@ impl AnthropicProvider {
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
}
return accumulated_usage;
// Don't return here - let the stream naturally exhaust
// This prevents dropping the sender prematurely
}
"error" => {
if let Some(error) = event.error {
@@ -468,7 +477,7 @@ impl AnthropicProvider {
let _ = tx
.send(Err(anyhow!("Anthropic API error: {:?}", error)))
.await;
return accumulated_usage;
break; // Break to let stream exhaust naturally
}
}
_ => {
@@ -487,7 +496,10 @@ impl AnthropicProvider {
Err(e) => {
error!("Stream error: {}", e);
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
return accumulated_usage;
// Don't return here either - let the stream exhaust naturally
// The error has been sent to the receiver, so it will handle it
// Breaking here ensures we clean up properly
break;
}
}
}

View File

@@ -298,6 +298,7 @@ impl DatabricksProvider {
let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> =
std::collections::HashMap::new(); // index -> (id, name, args)
let mut incomplete_data_line = String::new(); // Buffer for incomplete data: lines
let mut chunk_count = 0;
let accumulated_usage: Option<Usage> = None;
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
@@ -305,6 +306,8 @@ impl DatabricksProvider {
match chunk_result {
Ok(chunk) => {
// Debug: Log raw bytes received
chunk_count += 1;
debug!("Processing chunk #{}", chunk_count);
debug!("Raw SSE bytes received: {} bytes", chunk.len());
// Append new bytes to our buffer
@@ -589,13 +592,39 @@ impl DatabricksProvider {
}
}
Err(e) => {
error!("Stream error: {}", e);
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
error!("Stream error at chunk {}: {}", chunk_count, e);
// Check if this is a connection error that might be recoverable
let error_msg = e.to_string();
if error_msg.contains("unexpected EOF") || error_msg.contains("connection") {
warn!("Connection terminated unexpectedly at chunk {}, treating as end of stream", chunk_count);
// Don't send error, just break and finalize
break;
} else {
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
}
return accumulated_usage;
}
}
}
// Log final state
debug!("Stream ended after {} chunks", chunk_count);
debug!("Final state: buffer_len={}, incomplete_data_line_len={}, byte_buffer_len={}",
buffer.len(), incomplete_data_line.len(), byte_buffer.len());
debug!("Accumulated tool calls: {}", current_tool_calls.len());
// If we have any remaining data in buffers, log it for debugging
if !buffer.is_empty() {
debug!("Remaining buffer content: {:?}", buffer);
}
if !byte_buffer.is_empty() {
debug!("Remaining byte buffer: {} bytes", byte_buffer.len());
}
if !incomplete_data_line.is_empty() {
debug!("Remaining incomplete data line: {:?}", incomplete_data_line);
}
// If we have any incomplete data line at the end, try to process it
if !incomplete_data_line.is_empty() {
debug!(

View File

@@ -88,11 +88,13 @@ pub mod anthropic;
pub mod databricks;
pub mod embedded;
pub mod oauth;
pub mod ollama;
pub mod openai;
pub use anthropic::AnthropicProvider;
pub use databricks::DatabricksProvider;
pub use embedded::EmbeddedProvider;
pub use ollama::OllamaProvider;
pub use openai::OpenAIProvider;
/// Provider registry for managing multiple LLM providers

View File

@@ -0,0 +1,751 @@
//! Ollama LLM provider implementation for the g3-providers crate.
//!
//! This module provides an implementation of the `LLMProvider` trait for Ollama,
//! supporting both completion and streaming modes with native tool calling.
//!
//! # Features
//!
//! - Support for any Ollama model (llama3.2, mistral, qwen, etc.)
//! - Both completion and streaming response modes
//! - Native tool calling support for compatible models
//! - Configurable base URL (defaults to http://localhost:11434)
//! - Simple configuration with no authentication required
//!
//! # Usage
//!
//! ```rust,no_run
//! use g3_providers::{OllamaProvider, LLMProvider, CompletionRequest, Message, MessageRole};
//!
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//! // Create the provider with default settings (localhost:11434)
//! let provider = OllamaProvider::new(
//! "llama3.2".to_string(),
//! None, // Optional: base_url
//! None, // Optional: max tokens
//! None, // Optional: temperature
//! )?;
//!
//! // Create a completion request
//! let request = CompletionRequest {
//! messages: vec![
//! Message {
//! role: MessageRole::User,
//! content: "Hello! How are you?".to_string(),
//! },
//! ],
//! max_tokens: Some(1000),
//! temperature: Some(0.7),
//! stream: false,
//! tools: None,
//! };
//!
//! // Get a completion
//! let response = provider.complete(request).await?;
//! println!("Response: {}", response.content);
//!
//! Ok(())
//! }
//! ```
use anyhow::{anyhow, Result};
use bytes::Bytes;
use futures_util::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error, info, warn};
use crate::{
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
MessageRole, Tool, ToolCall, Usage,
};
const DEFAULT_BASE_URL: &str = "http://localhost:11434";
const DEFAULT_TIMEOUT_SECS: u64 = 600;
pub const OLLAMA_DEFAULT_MODEL: &str = "llama3.2";
pub const OLLAMA_KNOWN_MODELS: &[&str] = &[
"llama3.2",
"llama3.2:1b",
"llama3.2:3b",
"llama3.1",
"llama3.1:8b",
"llama3.1:70b",
"mistral",
"mistral-nemo",
"mixtral",
"qwen2.5",
"qwen2.5:7b",
"qwen2.5:14b",
"qwen2.5:32b",
"qwen2.5-coder",
"qwen2.5-coder:7b",
"qwen3-coder",
"phi3",
"gemma2",
];
#[derive(Debug, Clone)]
pub struct OllamaProvider {
client: Client,
base_url: String,
model: String,
max_tokens: Option<u32>,
temperature: f32,
}
impl OllamaProvider {
pub fn new(
model: String,
base_url: Option<String>,
max_tokens: Option<u32>,
temperature: Option<f32>,
) -> Result<Self> {
let client = Client::builder()
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
.build()
.map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
let base_url = base_url
.unwrap_or_else(|| DEFAULT_BASE_URL.to_string())
.trim_end_matches('/')
.to_string();
info!(
"Initialized Ollama provider with model: {} at {}",
model, base_url
);
Ok(Self {
client,
base_url,
model,
max_tokens,
temperature: temperature.unwrap_or(0.7),
})
}
fn convert_tools(&self, tools: &[Tool]) -> Vec<OllamaTool> {
tools
.iter()
.map(|tool| OllamaTool {
r#type: "function".to_string(),
function: OllamaFunction {
name: tool.name.clone(),
description: tool.description.clone(),
parameters: tool.input_schema.clone(),
},
})
.collect()
}
fn convert_messages(&self, messages: &[Message]) -> Result<Vec<OllamaMessage>> {
let mut ollama_messages = Vec::new();
for message in messages {
let role = match message.role {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
};
ollama_messages.push(OllamaMessage {
role: role.to_string(),
content: message.content.clone(),
tool_calls: None, // Only used in responses
});
}
if ollama_messages.is_empty() {
return Err(anyhow!("At least one message is required"));
}
Ok(ollama_messages)
}
fn create_request_body(
&self,
messages: &[Message],
tools: Option<&[Tool]>,
streaming: bool,
max_tokens: Option<u32>,
temperature: f32,
) -> Result<OllamaRequest> {
let ollama_messages = self.convert_messages(messages)?;
let ollama_tools = tools.map(|t| self.convert_tools(t));
let mut options = OllamaOptions {
temperature,
num_predict: max_tokens,
};
// If max_tokens is provided, use it; otherwise use the instance default
if max_tokens.is_none() {
options.num_predict = self.max_tokens;
}
let request = OllamaRequest {
model: self.model.clone(),
messages: ollama_messages,
tools: ollama_tools,
stream: streaming,
options,
};
Ok(request)
}
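// For reference, a request built this way serializes to roughly the following
// (an illustrative assumption based on the Serialize derives below, not a captured payload):
//   {"model":"llama3.2","messages":[{"role":"user","content":"Hi"}],
//    "stream":false,"options":{"temperature":0.7,"num_predict":1000}}
// `tools`, `num_predict`, and per-message `tool_calls` are omitted when unset
// via skip_serializing_if.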
async fn parse_streaming_response(
&self,
mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
tx: mpsc::Sender<Result<CompletionChunk>>,
) -> Option<Usage> {
let mut buffer = String::new();
let mut accumulated_usage: Option<Usage> = None;
let mut current_tool_calls: Vec<OllamaToolCall> = Vec::new();
let mut byte_buffer = Vec::new();
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
// Append new bytes to our buffer
byte_buffer.extend_from_slice(&chunk);
// Try to convert the entire buffer to UTF-8
let chunk_str = match std::str::from_utf8(&byte_buffer) {
Ok(s) => {
let result = s.to_string();
byte_buffer.clear();
result
}
Err(e) => {
let valid_up_to = e.valid_up_to();
if valid_up_to > 0 {
let valid_bytes =
byte_buffer.drain(..valid_up_to).collect::<Vec<_>>();
std::str::from_utf8(&valid_bytes).unwrap().to_string()
} else {
continue;
}
}
};
buffer.push_str(&chunk_str);
// Process complete lines
while let Some(line_end) = buffer.find('\n') {
let line = buffer[..line_end].trim().to_string();
buffer.drain(..line_end + 1);
if line.is_empty() {
continue;
}
// Ollama streaming sends JSON objects per line
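// For reference, each line is expected to look roughly like the following
// (illustrative, mirroring the OllamaStreamChunk fields rather than a captured response):
//   {"message":{"role":"assistant","content":"Hi"},"done":false}
// with a final line carrying "done":true plus prompt_eval_count / eval_count.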
match serde_json::from_str::<OllamaStreamChunk>(&line) {
Ok(chunk) => {
// Handle text content
if let Some(message) = &chunk.message {
let content = &message.content;
if !content.is_empty() {
debug!("Sending text chunk: '{}'", content);
let chunk = CompletionChunk {
content: content.clone(),
finished: false,
usage: None,
tool_calls: None,
};
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return accumulated_usage;
}
}
// Handle tool calls
if let Some(tool_calls) = &message.tool_calls {
current_tool_calls.extend(tool_calls.clone());
}
}
// Check if stream is done
if chunk.done.unwrap_or(false) {
debug!("Stream completed");
// Update usage if available
if let Some(eval_count) = chunk.eval_count {
accumulated_usage = Some(Usage {
prompt_tokens: chunk.prompt_eval_count.unwrap_or(0),
completion_tokens: eval_count,
total_tokens: chunk.prompt_eval_count.unwrap_or(0)
+ eval_count,
});
}
// Send final chunk with tool calls if any
let final_tool_calls: Vec<ToolCall> = current_tool_calls
.iter()
.map(|tc| ToolCall {
id: tc.function.name.clone(), // Ollama doesn't provide IDs
tool: tc.function.name.clone(),
args: tc.function.arguments.clone(),
})
.collect();
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
tool_calls: if final_tool_calls.is_empty() {
None
} else {
Some(final_tool_calls)
},
};
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
}
return accumulated_usage;
}
}
Err(e) => {
debug!("Failed to parse Ollama stream chunk: {} - Line: {}", e, line);
// Don't error out, just continue
}
}
}
}
Err(e) => {
error!("Stream error: {}", e);
let error_msg = e.to_string();
if error_msg.contains("unexpected EOF") || error_msg.contains("connection") {
warn!("Connection terminated unexpectedly, treating as end of stream");
break;
} else {
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
}
return accumulated_usage;
}
}
}
// Send final chunk if we haven't already
let final_tool_calls: Vec<ToolCall> = current_tool_calls
.iter()
.map(|tc| ToolCall {
id: tc.function.name.clone(),
tool: tc.function.name.clone(),
args: tc.function.arguments.clone(),
})
.collect();
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
tool_calls: if final_tool_calls.is_empty() {
None
} else {
Some(final_tool_calls)
},
};
let _ = tx.send(Ok(final_chunk)).await;
accumulated_usage
}
/// Fetch available models from the Ollama instance
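///
/// # Example (illustrative sketch; model name and defaults are assumptions)
///
/// ```ignore
/// let provider = OllamaProvider::new("llama3.2".to_string(), None, None, None)?;
/// for name in provider.fetch_available_models().await? {
///     println!("available model: {}", name);
/// }
/// ```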
pub async fn fetch_available_models(&self) -> Result<Vec<String>> {
let response = self
.client
.get(format!("{}/api/tags", self.base_url))
.send()
.await
.map_err(|e| anyhow!("Failed to fetch Ollama models: {}", e))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!(
"Failed to fetch Ollama models: {} - {}",
status,
error_text
));
}
let json: serde_json::Value = response.json().await?;
let models = json
.get("models")
.and_then(|v| v.as_array())
.ok_or_else(|| anyhow!("Unexpected response format: missing 'models' array"))?;
let model_names: Vec<String> = models
.iter()
.filter_map(|model| model.get("name").and_then(|n| n.as_str()).map(String::from))
.collect();
debug!("Found {} models in Ollama", model_names.len());
Ok(model_names)
}
}
#[async_trait::async_trait]
impl LLMProvider for OllamaProvider {
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
debug!(
"Processing Ollama completion request with {} messages",
request.messages.len()
);
let max_tokens = request.max_tokens.or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);
let request_body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
false,
max_tokens,
temperature,
)?;
debug!(
"Sending request to Ollama API: model={}, temperature={}",
self.model, request_body.options.temperature
);
let response = self
.client
.post(format!("{}/api/chat", self.base_url))
.json(&request_body)
.send()
.await
.map_err(|e| anyhow!("Failed to send request to Ollama API: {}", e))?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!("Ollama API error {}: {}", status, error_text));
}
let response_text = response.text().await?;
debug!("Raw Ollama API response: {}", response_text);
let ollama_response: OllamaResponse =
serde_json::from_str(&response_text).map_err(|e| {
anyhow!(
"Failed to parse Ollama response: {} - Response: {}",
e,
response_text
)
})?;
let content = ollama_response.message.content.clone();
let usage = Usage {
prompt_tokens: ollama_response.prompt_eval_count.unwrap_or(0),
completion_tokens: ollama_response.eval_count.unwrap_or(0),
total_tokens: ollama_response.prompt_eval_count.unwrap_or(0)
+ ollama_response.eval_count.unwrap_or(0),
};
debug!(
"Ollama completion successful: {} tokens generated",
usage.completion_tokens
);
Ok(CompletionResponse {
content,
usage,
model: self.model.clone(),
})
}
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
debug!(
"Processing Ollama request (non-streaming) with {} messages",
request.messages.len()
);
if let Some(ref tools) = request.tools {
debug!("Request has {} tools", tools.len());
for tool in tools.iter().take(5) {
debug!(" Tool: {}", tool.name);
}
}
let max_tokens = request.max_tokens.or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);
let request_body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
false, // Use non-streaming mode to avoid streaming bugs
max_tokens,
temperature,
)?;
debug!(
"Sending request to Ollama API (stream=false): model={}, temperature={}",
self.model, request_body.options.temperature
);
let response = self
.client
.post(format!("{}/api/chat", self.base_url))
.json(&request_body)
.send()
.await
.map_err(|e| anyhow!("Failed to send request to Ollama API: {}", e))?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!("Ollama API error {}: {}", status, error_text));
}
// For non-streaming, parse the complete JSON response
let response_text = response.text().await?;
debug!("Raw Ollama API response: {}", response_text);
let ollama_response: OllamaResponse =
serde_json::from_str(&response_text).map_err(|e| {
anyhow!(
"Failed to parse Ollama response: {} - Response: {}",
e,
response_text
)
})?;
let (tx, rx) = mpsc::channel(100);
tokio::spawn(async move {
let content = ollama_response.message.content;
let usage = Usage {
prompt_tokens: ollama_response.prompt_eval_count.unwrap_or(0),
completion_tokens: ollama_response.eval_count.unwrap_or(0),
total_tokens: ollama_response.prompt_eval_count.unwrap_or(0)
+ ollama_response.eval_count.unwrap_or(0),
};
// Extract tool calls if present
let tool_calls: Option<Vec<ToolCall>> = ollama_response.message.tool_calls.map(|tcs| {
tcs.iter()
.map(|tc| ToolCall {
id: tc.function.name.clone(),
tool: tc.function.name.clone(),
args: tc.function.arguments.clone(),
})
.collect()
});
// Send content if any
if !content.is_empty() {
let _ = tx.send(Ok(CompletionChunk {
content,
finished: false,
usage: None,
tool_calls: None,
})).await;
}
// Send final chunk with usage and tool calls
let _ = tx.send(Ok(CompletionChunk {
content: String::new(),
finished: true,
usage: Some(usage),
tool_calls,
})).await;
});
Ok(ReceiverStream::new(rx))
}
fn name(&self) -> &str {
"ollama"
}
fn model(&self) -> &str {
&self.model
}
fn has_native_tool_calling(&self) -> bool {
// Most modern Ollama models support tool calling
// Models like llama3.2, qwen2.5, mistral, etc. have good tool support
true
}
}
// Ollama API request/response structures
#[derive(Debug, Serialize)]
struct OllamaRequest {
model: String,
messages: Vec<OllamaMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<OllamaTool>>,
stream: bool,
options: OllamaOptions,
}
#[derive(Debug, Serialize)]
struct OllamaOptions {
temperature: f32,
#[serde(skip_serializing_if = "Option::is_none")]
num_predict: Option<u32>, // Ollama's equivalent of max_tokens
}
#[derive(Debug, Serialize)]
struct OllamaTool {
r#type: String,
function: OllamaFunction,
}
#[derive(Debug, Serialize)]
struct OllamaFunction {
name: String,
description: String,
parameters: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaMessage {
role: String,
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<OllamaToolCall>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaToolCall {
function: OllamaToolCallFunction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaToolCallFunction {
name: String,
arguments: serde_json::Value,
}
#[derive(Debug, Deserialize)]
struct OllamaResponse {
message: OllamaMessage,
#[allow(dead_code)]
done: bool,
#[allow(dead_code)]
total_duration: Option<u64>,
#[allow(dead_code)]
load_duration: Option<u64>,
prompt_eval_count: Option<u32>,
eval_count: Option<u32>,
}
#[derive(Debug, Deserialize)]
struct OllamaStreamChunk {
message: Option<OllamaMessage>,
done: Option<bool>,
prompt_eval_count: Option<u32>,
eval_count: Option<u32>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_provider_creation() {
let provider = OllamaProvider::new(
"llama3.2".to_string(),
None,
Some(1000),
Some(0.7),
)
.unwrap();
assert_eq!(provider.model(), "llama3.2");
assert_eq!(provider.name(), "ollama");
assert!(provider.has_native_tool_calling());
}
#[test]
fn test_message_conversion() {
let provider = OllamaProvider::new(
"llama3.2".to_string(),
None,
None,
None,
)
.unwrap();
let messages = vec![
Message {
role: MessageRole::System,
content: "You are a helpful assistant.".to_string(),
},
Message {
role: MessageRole::User,
content: "Hello!".to_string(),
},
];
let ollama_messages = provider.convert_messages(&messages).unwrap();
assert_eq!(ollama_messages.len(), 2);
assert_eq!(ollama_messages[0].role, "system");
assert_eq!(ollama_messages[1].role, "user");
}
#[test]
fn test_tool_conversion() {
let provider = OllamaProvider::new(
"llama3.2".to_string(),
None,
None,
None,
)
.unwrap();
let tools = vec![Tool {
name: "get_weather".to_string(),
description: "Get the current weather".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}),
}];
let ollama_tools = provider.convert_tools(&tools);
assert_eq!(ollama_tools.len(), 1);
assert_eq!(ollama_tools[0].r#type, "function");
assert_eq!(ollama_tools[0].function.name, "get_weather");
}
#[test]
fn test_custom_base_url() {
let provider = OllamaProvider::new(
"llama3.2".to_string(),
Some("http://custom:11434".to_string()),
None,
None,
)
.unwrap();
assert_eq!(provider.base_url, "http://custom:11434");
}
}