Compare commits

...

9 Commits

| Author | SHA1 | Message | Date |
|--------|------|---------|------|
| Michael Neale | 79b375519b | use proper context for qwen3-coder | 2025-11-05 13:55:45 +11:00 |
| Michael Neale | 88c3cc23fe | works better without streaming | 2025-11-05 12:55:21 +11:00 |
| Michael Neale | 4622507f37 | gpt context aware | 2025-11-05 12:25:02 +11:00 |
| Michael Neale | 217df2f2af | ollama support | 2025-11-05 12:17:01 +11:00 |
| Dhanji Prasanna | 22a0090cdc | fix unexpected EOF on streams | 2025-11-04 16:28:41 +11:00 |
| Dhanji Prasanna | 631f3c16ca | compact on tool call if > 90% | 2025-11-04 14:35:11 +11:00 |
| Dhanji Prasanna | 1f9fef5f18 | more json filtering | 2025-11-03 11:56:16 +11:00 |
| Dhanji Prasanna | 57d473c19d | mild json filtering improvement | 2025-11-03 11:54:27 +11:00 |
| Jochen | e59ce2f93f | Merge pull request #16 from dhanji/jochen-ast-tool: adds ast-grep tool for faster code exploration | 2025-11-02 21:04:11 +11:00 |
12 changed files with 2014 additions and 66 deletions

1
.gitignore vendored
View File

@@ -3,6 +3,7 @@
debug
target
.build
appy/
# These are backup files generated by rustfmt
**/*.rs.bk

456
OLLAMA_CONFIG.md Normal file
View File

@@ -0,0 +1,456 @@
# Configuring Ollama Provider in G3
This guide shows you how to configure G3 to use Ollama as your LLM provider.
## Quick Start
### 1. Install Ollama
```bash
# Visit https://ollama.ai to download and install
# Or use curl:
curl https://ollama.ai/install.sh | sh
```
### 2. Pull a Model
```bash
ollama pull llama3.2
# or any other model you prefer
```
### 3. Create Configuration File
Copy the example configuration:
```bash
cp config.ollama.example.toml ~/.config/g3/config.toml
```
Or create it manually:
```toml
[providers]
default_provider = "ollama"
[providers.ollama]
model = "llama3.2"
```
### 4. Run G3
```bash
g3
# G3 will now use Ollama with llama3.2!
```
## Configuration Options
### Basic Configuration
```toml
[providers]
default_provider = "ollama"
[providers.ollama]
model = "llama3.2"
```
This is the minimal configuration needed. It uses all defaults:
- Base URL: `http://localhost:11434`
- Temperature: `0.7`
- Max tokens: Not limited (uses model default)
### Full Configuration
```toml
[providers]
default_provider = "ollama"
[providers.ollama]
model = "llama3.2"
base_url = "http://localhost:11434"
max_tokens = 2048
temperature = 0.7
```
### Custom Ollama Host
If you're running Ollama on a different machine or port:
```toml
[providers.ollama]
model = "llama3.2"
base_url = "http://192.168.1.100:11434"
```
### Different Models
You can use any Ollama model:
```toml
[providers.ollama]
model = "qwen2.5:7b" # Alibaba's Qwen model
```
```toml
[providers.ollama]
model = "mistral" # Mistral AI
```
```toml
[providers.ollama]
model = "llama3.1:70b" # Larger Llama model
```
## Multiple Provider Configuration
You can configure multiple providers and switch between them:
```toml
[providers]
default_provider = "ollama" # Default for most operations
# Ollama for local, fast responses
[providers.ollama]
model = "llama3.2:3b"
temperature = 0.7
# Databricks for more complex tasks
[providers.databricks]
host = "https://your-workspace.cloud.databricks.com"
model = "databricks-claude-sonnet-4"
max_tokens = 4096
temperature = 0.1
use_oauth = true
```
Then switch providers with:
```bash
g3 --provider databricks
```
## Autonomous Mode (Coach-Player)
Use different providers for code review (coach) and implementation (player):
```toml
[providers]
default_provider = "ollama"
coach = "databricks" # Use powerful cloud model for review
player = "ollama" # Use local model for implementation
[providers.ollama]
model = "qwen2.5:14b" # Larger local model for coding
[providers.databricks]
host = "https://your-workspace.cloud.databricks.com"
model = "databricks-claude-sonnet-4"
use_oauth = true
```
This gives you the best of both worlds:
- Fast local execution for coding tasks
- Powerful cloud review for quality assurance
## Recommended Models
### For Coding Tasks
| Model | Size | Speed | Quality | Notes |
|-------|------|-------|---------|-------|
| **qwen2.5:7b** | 7B | Fast | Excellent | Best balance for coding |
| **llama3.2:3b** | 3B | Very Fast | Good | Great for quick tasks |
| **llama3.1:8b** | 8B | Medium | Very Good | Solid all-rounder |
| **mistral** | 7B | Fast | Good | Good for general use |
### For Complex Tasks
| Model | Size | Speed | Quality | Notes |
|-------|------|-------|---------|-------|
| **qwen2.5:14b** | 14B | Medium | Excellent | Best local model for coding |
| **qwen2.5:32b** | 32B | Slow | Outstanding | If you have the resources |
| **llama3.1:70b** | 70B | Very Slow | Outstanding | Requires significant RAM/GPU |
## Temperature Settings
Temperature controls randomness in responses:
- **0.1-0.3**: Deterministic, good for code generation
- **0.5-0.7**: Balanced, good for most tasks
- **0.8-1.0**: Creative, good for brainstorming
```toml
[providers.ollama]
model = "qwen2.5:7b"
temperature = 0.2 # Focused code generation
```
## Max Tokens
Control response length:
```toml
[providers.ollama]
model = "llama3.2"
max_tokens = 1024 # Shorter responses
```
```toml
[providers.ollama]
model = "qwen2.5:7b"
max_tokens = 4096 # Longer, detailed responses
```
Leave it unset for model defaults (recommended).
## Performance Tuning
### GPU Acceleration
Ollama automatically uses GPU if available. To check:
```bash
ollama ps
```
### Quantized Models
For faster responses with less RAM:
```toml
[providers.ollama]
model = "llama3.2:3b-q4_0" # 4-bit quantization
```
Quantization options:
- `q4_0`: 4-bit, fastest, lowest quality
- `q5_0`: 5-bit, balanced
- `q8_0`: 8-bit, slower, better quality
### Multiple Models
You can pull multiple models and switch easily:
```bash
ollama pull llama3.2:3b # Fast for chat
ollama pull qwen2.5:7b # Better for code
ollama pull mistral # General purpose
```
Then change your config:
```toml
[providers.ollama]
model = "qwen2.5:7b" # Just change this line
```
## Troubleshooting
### Ollama Not Running
```bash
# Check if Ollama is running
curl http://localhost:11434/api/version
# Start Ollama (macOS/Linux)
ollama serve
# Or just run a model (auto-starts)
ollama run llama3.2
```
### Model Not Found
```bash
# List available models
ollama list
# Pull the model
ollama pull llama3.2
```
### Slow Responses
1. Use a smaller model:
```toml
model = "llama3.2:1b" # Smallest, fastest
```
2. Use a quantized version:
```toml
model = "llama3.2:3b-q4_0"
```
3. Reduce max_tokens:
```toml
max_tokens = 512
```
### Out of Memory
1. Switch to a smaller model
2. Use a quantized version
3. Close other applications
4. Check GPU memory: `ollama ps`
### Connection Refused
Check that `base_url` is correct:
```toml
[providers.ollama]
model = "llama3.2"
base_url = "http://localhost:11434" # Default
```
For remote Ollama:
```toml
base_url = "http://your-server:11434"
```
## Complete Example Configs
### Minimal Local Setup
```toml
[providers]
default_provider = "ollama"
[providers.ollama]
model = "llama3.2"
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
```
### Optimized for Coding
```toml
[providers]
default_provider = "ollama"
[providers.ollama]
model = "qwen2.5:7b"
temperature = 0.2
max_tokens = 2048
[agent]
max_context_length = 16384
enable_streaming = true
timeout_seconds = 120
```
### Fast Responses
```toml
[providers]
default_provider = "ollama"
[providers.ollama]
model = "llama3.2:3b-q4_0"
temperature = 0.7
max_tokens = 1024
[agent]
max_context_length = 4096
enable_streaming = true
timeout_seconds = 30
```
### High Quality (Requires Good Hardware)
```toml
[providers]
default_provider = "ollama"
[providers.ollama]
model = "qwen2.5:32b"
temperature = 0.3
max_tokens = 4096
[agent]
max_context_length = 32768
enable_streaming = true
timeout_seconds = 300
```
### Hybrid (Local + Cloud)
```toml
[providers]
default_provider = "ollama"
coach = "databricks"
player = "ollama"
[providers.ollama]
model = "qwen2.5:14b"
temperature = 0.2
[providers.databricks]
host = "https://your-workspace.cloud.databricks.com"
model = "databricks-claude-sonnet-4"
use_oauth = true
[agent]
max_context_length = 16384
enable_streaming = true
timeout_seconds = 120
```
## Environment Variables
You can override config with environment variables:
```bash
# Override model
G3_PROVIDERS_OLLAMA_MODEL=qwen2.5:7b g3
# Override base URL
G3_PROVIDERS_OLLAMA_BASE_URL=http://192.168.1.100:11434 g3
# Override default provider
G3_PROVIDERS_DEFAULT_PROVIDER=ollama g3
```
## Best Practices
1. **Start Small**: Begin with llama3.2:3b, scale up if needed
2. **Use Quantization**: q4_0 or q5_0 for best speed/quality balance
3. **Match Task to Model**:
- Quick edits: 1B-3B models
- Code generation: 7B-14B models
- Complex refactoring: 14B-32B models
4. **Temperature for Code**: Use 0.1-0.3 for deterministic output
5. **Enable Streaming**: Always enable for better UX
6. **Local First**: Use Ollama by default, cloud for special cases
## Comparison with Other Providers
| Feature | Ollama | Databricks | OpenAI | Anthropic |
|---------|--------|------------|--------|-----------|
| Cost | Free | Paid | Paid | Paid |
| Privacy | Full | Medium | Low | Low |
| Speed (small models) | Fast | Fast | Medium | Medium |
| Speed (large models) | Slow | Fast | Fast | Fast |
| Setup Complexity | Low | Medium | Low | Low |
| Authentication | None | OAuth/Token | API Key | API Key |
| Offline Support | Yes | No | No | No |
| Tool Calling | Yes | Yes | Yes | Yes |
## Next Steps
1. Try different models: `ollama pull mistral`, `ollama pull qwen2.5`
2. Experiment with temperature settings
3. Set up hybrid config with cloud provider for complex tasks
4. Share your config in the community!
## Getting Help
- Ollama docs: https://ollama.ai/docs
- G3 issues: https://github.com/your-repo/issues
- Test your config: `g3 --help`

315
OLLAMA_EXAMPLE.md Normal file
View File

@@ -0,0 +1,315 @@
# Ollama Provider for g3
A simple, local LLM provider implementation for g3 that connects to Ollama.
## Features
- **Simple Setup**: No API keys or authentication required
- **Local Execution**: Runs entirely on your machine
- **Tool Calling Support**: Native tool calling for compatible models
- **Streaming**: Full streaming support with real-time responses
- **Flexible Configuration**: Custom base URL, temperature, and max tokens
- **Model Discovery**: Automatic detection of available models
## Quick Start
### Prerequisites
1. Install and start Ollama: https://ollama.ai
2. Pull a model: `ollama pull llama3.2`
### Basic Usage
```rust
use g3_providers::{OllamaProvider, LLMProvider, CompletionRequest, Message, MessageRole};
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Create provider with default settings (localhost:11434)
let provider = OllamaProvider::new(
"llama3.2".to_string(),
None, // base_url: defaults to http://localhost:11434
None, // max_tokens: optional
None, // temperature: defaults to 0.7
)?;
// Create a simple request
let request = CompletionRequest {
messages: vec![
Message {
role: MessageRole::User,
content: "What is the capital of France?".to_string(),
},
],
max_tokens: Some(1000),
temperature: Some(0.7),
stream: false,
tools: None,
};
// Get completion
let response = provider.complete(request).await?;
println!("Response: {}", response.content);
println!("Tokens: {}", response.usage.total_tokens);
Ok(())
}
```
### Streaming Example
```rust
use futures_util::StreamExt;
let request = CompletionRequest {
messages: vec![
Message {
role: MessageRole::User,
content: "Write a short poem about coding".to_string(),
},
],
max_tokens: Some(500),
temperature: Some(0.8),
stream: true,
tools: None,
};
let mut stream = provider.stream(request).await?;
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
print!("{}", chunk.content);
if chunk.finished {
println!("\n\nDone!");
if let Some(usage) = chunk.usage {
println!("Total tokens: {}", usage.total_tokens);
}
}
}
Err(e) => eprintln!("Error: {}", e),
}
}
```
### Tool Calling Example
```rust
use serde_json::json;
let tools = vec![Tool {
name: "get_weather".to_string(),
description: "Get current weather for a location".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "City name"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "Temperature unit"
}
},
"required": ["location"]
}),
}];
let request = CompletionRequest {
messages: vec![
Message {
role: MessageRole::User,
content: "What's the weather in Paris?".to_string(),
},
],
max_tokens: Some(500),
temperature: Some(0.5),
stream: false,
tools: Some(tools),
};
let response = provider.complete(request).await?;
println!("Response: {}", response.content);
```
### Custom Ollama Host
```rust
// Connect to remote Ollama instance
let provider = OllamaProvider::new(
"llama3.2".to_string(),
Some("http://192.168.1.100:11434".to_string()),
None,
None,
)?;
```
### Fetch Available Models
```rust
// Discover what models are available
let models = provider.fetch_available_models().await?;
println!("Available models:");
for model in models {
println!(" - {}", model);
}
```
## Supported Models
The provider works with any Ollama model, including:
- **llama3.2** (1B, 3B) - Meta's latest Llama models
- **llama3.1** (8B, 70B, 405B) - Previous generation
- **qwen2.5** (7B, 14B, 32B) - Alibaba's Qwen models
- **mistral** - Mistral AI models
- **mixtral** - Mixture of experts model
- **phi3** - Microsoft's Phi-3
- **gemma2** - Google's Gemma 2
## Configuration
### Constructor Parameters
```rust
OllamaProvider::new(
model: String, // Model name (e.g., "llama3.2")
base_url: Option<String>, // Ollama API URL (default: http://localhost:11434)
max_tokens: Option<u32>, // Maximum tokens to generate (optional)
temperature: Option<f32>, // Sampling temperature (default: 0.7)
)
```
### Request Options
```rust
CompletionRequest {
messages: Vec<Message>, // Conversation history
max_tokens: Option<u32>, // Override provider's max_tokens
temperature: Option<f32>, // Override provider's temperature
stream: bool, // Enable streaming responses
tools: Option<Vec<Tool>>, // Tools for function calling
}
```
## Comparison with Other Providers
| Feature | Ollama | OpenAI | Anthropic | Databricks |
|---------|--------|--------|-----------|------------|
| Local Execution | ✅ | ❌ | ❌ | ❌ |
| Authentication | None | API Key | API Key | OAuth/Token |
| Tool Calling | ✅ | ✅ | ✅ | ✅ |
| Streaming | ✅ | ✅ | ✅ | ✅ |
| Cost | Free | Paid | Paid | Paid |
| Privacy | High | Low | Low | Medium |
## Implementation Details
### API Endpoints
- **Chat Completion**: `POST /api/chat`
- **Model List**: `GET /api/tags`
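As a rough illustration of the model-list endpoint, the sketch below calls `GET /api/tags` directly with `reqwest` and extracts the model names. The provider's `fetch_available_models` wraps the same endpoint; the `models`/`name` field names used here are an assumption about Ollama's response shape.
```rust
use serde_json::Value;

// Sketch: list installed models by calling GET /api/tags directly.
// The "models" and "name" fields are assumed from Ollama's response format.
async fn list_models(base_url: &str) -> anyhow::Result<Vec<String>> {
    let json: Value = reqwest::Client::new()
        .get(format!("{}/api/tags", base_url))
        .send()
        .await?
        .json()
        .await?;
    let names = json["models"]
        .as_array()
        .map(|models| {
            models
                .iter()
                .filter_map(|m| m["name"].as_str().map(str::to_string))
                .collect()
        })
        .unwrap_or_default();
    Ok(names)
}
```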
### Response Format
Ollama uses a simple JSON-per-line streaming format:
```json
{"message":{"role":"assistant","content":"Hello"},"done":false}
{"message":{"role":"assistant","content":" there"},"done":false}
{"done":true,"prompt_eval_count":10,"eval_count":20}
```
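As a concrete sketch of handling this format, the function below parses a single stream line with `serde_json` using untyped `Value`s; the actual provider uses dedicated structs, so treat the field access as illustrative only.
```rust
use serde_json::Value;

// Sketch: process one JSON line from the Ollama stream.
fn handle_stream_line(line: &str) -> anyhow::Result<()> {
    let chunk: Value = serde_json::from_str(line)?;
    // Intermediate lines carry incremental assistant text.
    if let Some(content) = chunk.pointer("/message/content").and_then(Value::as_str) {
        print!("{}", content);
    }
    // The final line has "done": true plus token counts.
    if chunk.get("done").and_then(Value::as_bool).unwrap_or(false) {
        let prompt = chunk.get("prompt_eval_count").and_then(Value::as_u64).unwrap_or(0);
        let completion = chunk.get("eval_count").and_then(Value::as_u64).unwrap_or(0);
        println!("\nTotal tokens: {}", prompt + completion);
    }
    Ok(())
}
```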
### Tool Call Format
Tool calls are returned in the message structure:
```json
{
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": "get_weather",
"arguments": {"location": "Paris", "unit": "celsius"}
}
}
]
},
"done": true
}
```
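For illustration, the response above could be deserialized with `serde` types shaped like the following; the names are hypothetical for this example and are not the provider's internal structs.
```rust
use serde::Deserialize;
use serde_json::Value;

// Hypothetical types mirroring the tool-call response above (illustration only).
#[derive(Debug, Deserialize)]
struct ChatResponse {
    message: ChatMessage,
    done: bool,
}
#[derive(Debug, Deserialize)]
struct ChatMessage {
    role: String,
    content: String,
    #[serde(default)]
    tool_calls: Vec<ToolCallEntry>,
}
#[derive(Debug, Deserialize)]
struct ToolCallEntry {
    function: FunctionCall,
}
#[derive(Debug, Deserialize)]
struct FunctionCall {
    name: String,
    arguments: Value, // e.g. {"location": "Paris", "unit": "celsius"}
}
```
With these in scope, `serde_json::from_str::<ChatResponse>(line)` would expose the tool name and raw arguments for dispatch.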
## Troubleshooting
### Connection Errors
If you see connection errors, ensure Ollama is running:
```bash
# Check if Ollama is running
curl http://localhost:11434/api/version
# Start Ollama (if needed)
ollama serve
```
### Model Not Found
Pull the model first:
```bash
ollama pull llama3.2
ollama list # Check available models
```
### Performance Issues
- Use smaller models (1B, 3B) for faster responses
- Reduce `max_tokens` to limit generation length
- Enable GPU acceleration if available
- Consider quantized models (e.g., `llama3.2:3b-q4_0`)
## Testing
Run the included tests:
```bash
cargo test --package g3-providers ollama
```
All tests should pass:
```
running 4 tests
test ollama::tests::test_custom_base_url ... ok
test ollama::tests::test_message_conversion ... ok
test ollama::tests::test_provider_creation ... ok
test ollama::tests::test_tool_conversion ... ok
```
## Architecture
The provider follows the same architecture as other g3 providers:
1. **OllamaProvider**: Main struct implementing `LLMProvider` trait
2. **Request/Response Structures**: Internal types for Ollama API
3. **Streaming Parser**: Handles line-by-line JSON parsing
4. **Tool Call Handling**: Accumulates and converts tool calls
5. **Error Handling**: Robust error handling with retries
## Contributing
The provider is part of the g3-providers crate. To contribute:
1. Add features to `ollama.rs`
2. Update tests
3. Run `cargo test --package g3-providers`
4. Update this documentation
## License
Same as the g3 project.

View File

@@ -0,0 +1,26 @@
# Example G3 configuration using Ollama provider
# Copy this to ~/.config/g3/config.toml or ./g3.toml to use it
[providers]
default_provider = "ollama"
# Ollama configuration (local LLM)
[providers.ollama]
model = "llama3.2" # or qwen2.5, mistral, etc.
# base_url = "http://localhost:11434" # Optional, defaults to localhost
# max_tokens = 2048 # Optional
# temperature = 0.7 # Optional
# Optional: Specify different providers for coach and player in autonomous mode
# coach = "ollama" # Provider for coach (code reviewer)
# player = "ollama" # Provider for player (code implementer)
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
[computer_control]
enabled = false # Set to true to enable computer control (requires OS permissions)
require_confirmation = true
max_actions_per_second = 5

View File

@@ -184,6 +184,10 @@ pub struct Cli {
#[arg(short, long)]
pub verbose: bool,
/// Enable manual control of context compaction (disables auto-compact at 90%)
#[arg(long = "manual-compact")]
pub manual_compact: bool,
/// Show the system prompt being sent to the LLM
#[arg(long)]
pub show_prompt: bool,
@@ -286,10 +290,6 @@ pub async fn run() -> Result<()> {
tracing_subscriber::registry().with(filter).init();
}
if !cli.machine {
info!("Starting G3 AI Coding Agent");
}
// Set up workspace directory
let workspace_dir = if let Some(ws) = &cli.workspace {
ws.clone()
@@ -325,10 +325,6 @@ pub async fn run() -> Result<()> {
project.ensure_workspace_exists()?;
project.enter_workspace()?;
if !cli.machine {
info!("Using workspace: {}", project.workspace().display());
}
// Load configuration with CLI overrides
let mut config = Config::load_with_overrides(
cli.config.as_deref(),
@@ -339,9 +335,6 @@ pub async fn run() -> Result<()> {
// Apply macax flag override
if cli.macax {
config.macax.enabled = true;
if !cli.machine {
info!("macOS Accessibility API tools enabled");
}
}
// Apply webdriver flag override
@@ -349,6 +342,11 @@ pub async fn run() -> Result<()> {
config.webdriver.enabled = true;
}
// Apply no-auto-compact flag override
if cli.manual_compact {
config.agent.auto_compact = false;
}
// Validate provider if specified
if let Some(ref provider) = cli.provider {
let valid_providers = ["anthropic", "databricks", "embedded", "openai"];
@@ -569,6 +567,11 @@ async fn run_accumulative_mode(
config.webdriver.enabled = true;
}
// Apply no-auto-compact flag override
if cli.manual_compact {
config.agent.auto_compact = false;
}
// Create agent for interactive mode with requirements context
let ui_writer = ConsoleUiWriter::new();
let agent = Agent::new_with_readme_and_quiet(
@@ -646,6 +649,11 @@ async fn run_accumulative_mode(
config.webdriver.enabled = true;
}
// Apply no-auto-compact flag override
if cli.manual_compact {
config.agent.auto_compact = false;
}
// Create agent for this autonomous run
let ui_writer = ConsoleUiWriter::new();
let agent = Agent::new_autonomous_with_readme_and_quiet(
@@ -766,9 +774,6 @@ async fn run_with_console_mode(
// Execute task, autonomous mode, or start interactive mode
if cli.autonomous {
// Autonomous mode with coach-player feedback loop
if !cli.machine {
info!("Starting autonomous mode");
}
run_autonomous(
agent,
project,
@@ -780,9 +785,6 @@ async fn run_with_console_mode(
.await?;
} else if let Some(task) = cli.task {
// Single-shot mode
if !cli.machine {
info!("Executing task: {}", task);
}
let output = SimpleOutput::new();
let result = agent
.execute_task_with_timing(&task, None, false, cli.show_prompt, cli.show_code, true)
@@ -790,9 +792,6 @@ async fn run_with_console_mode(
output.print_smart(&result.response);
} else {
// Interactive mode (default)
if !cli.machine {
info!("Starting interactive mode");
}
println!("📁 Workspace: {}", project.workspace().display());
run_interactive(agent, cli.show_prompt, cli.show_code, combined_content).await?;
}
@@ -841,7 +840,6 @@ fn read_agents_config(workspace_dir: &Path) -> Option<String> {
match std::fs::read_to_string(&agents_path) {
Ok(content) => {
// Return the content with a note about which file was read
info!("Loaded AGENTS.md from {}", agents_path.display());
Some(format!(
"🤖 Agent Configuration (from AGENTS.md):\n\n{}",
content
@@ -859,7 +857,6 @@ fn read_agents_config(workspace_dir: &Path) -> Option<String> {
if alt_path.exists() {
match std::fs::read_to_string(&alt_path) {
Ok(content) => {
info!("Loaded agents.md from {}", alt_path.display());
Some(format!("🤖 Agent Configuration (from agents.md):\n\n{}", content))
}
Err(e) => {

View File

@@ -17,6 +17,7 @@ pub struct ProvidersConfig {
pub anthropic: Option<AnthropicConfig>,
pub databricks: Option<DatabricksConfig>,
pub embedded: Option<EmbeddedConfig>,
pub ollama: Option<OllamaConfig>,
pub default_provider: String,
pub coach: Option<String>, // Provider to use for coach in autonomous mode
pub player: Option<String>, // Provider to use for player in autonomous mode
@@ -60,11 +61,20 @@ pub struct EmbeddedConfig {
pub threads: Option<u32>, // Number of CPU threads to use
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaConfig {
pub model: String,
pub base_url: Option<String>, // Default: http://localhost:11434
pub max_tokens: Option<u32>,
pub temperature: Option<f32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
pub max_context_length: usize,
pub enable_streaming: bool,
pub timeout_seconds: u64,
pub auto_compact: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -127,6 +137,7 @@ impl Default for Config {
use_oauth: Some(true),
}),
embedded: None,
ollama: None,
default_provider: "databricks".to_string(),
coach: None, // Will use default_provider if not specified
player: None, // Will use default_provider if not specified
@@ -135,6 +146,7 @@ impl Default for Config {
max_context_length: 8192,
enable_streaming: true,
timeout_seconds: 60,
auto_compact: true,
},
computer_control: ComputerControlConfig::default(),
webdriver: WebDriverConfig::default(),
@@ -242,6 +254,7 @@ impl Config {
gpu_layers: Some(32),
threads: Some(8),
}),
ollama: None,
default_provider: "embedded".to_string(),
coach: None, // Will use default_provider if not specified
player: None, // Will use default_provider if not specified
@@ -250,6 +263,7 @@ impl Config {
max_context_length: 8192,
enable_streaming: true,
timeout_seconds: 60,
auto_compact: true,
},
computer_control: ComputerControlConfig::default(),
webdriver: WebDriverConfig::default(),

View File

@@ -4,6 +4,11 @@
// 3. Only elide JSON content between first '{' and last '}' (inclusive)
// 4. Return everything else as the final filtered string
//! JSON tool call filtering for streaming LLM responses.
//!
//! This module filters out JSON tool calls from LLM output streams while preserving
//! regular text content. It uses a state machine to handle streaming chunks.
use regex::Regex;
use std::cell::RefCell;
use tracing::debug;
@@ -13,37 +18,51 @@ thread_local! {
static FIXED_JSON_TOOL_STATE: RefCell<FixedJsonToolState> = RefCell::new(FixedJsonToolState::new());
}
/// Internal state for tracking JSON tool call filtering across streaming chunks.
#[derive(Debug, Clone)]
struct FixedJsonToolState {
/// True when actively suppressing a confirmed tool call
suppression_mode: bool,
/// True when buffering potential JSON (saw { but not yet confirmed as tool call)
potential_json_mode: bool,
/// Tracks nesting depth of braces within JSON
brace_depth: i32,
buffer: String,
json_start_in_buffer: Option<usize>,
json_start_in_buffer: Option<usize>, // Position where confirmed JSON tool call starts
content_returned_up_to: usize, // Track how much content we've already returned
potential_json_start: Option<usize>, // Where the potential JSON started
}
impl FixedJsonToolState {
fn new() -> Self {
Self {
suppression_mode: false,
potential_json_mode: false,
brace_depth: 0,
buffer: String::new(),
json_start_in_buffer: None,
content_returned_up_to: 0,
potential_json_start: None,
}
}
fn reset(&mut self) {
self.suppression_mode = false;
self.potential_json_mode = false;
self.brace_depth = 0;
self.buffer.clear();
self.json_start_in_buffer = None;
self.content_returned_up_to = 0;
self.potential_json_start = None;
}
}
// FINAL CORRECTED implementation according to specification
/// Filters JSON tool calls from streaming LLM content.
///
/// Processes content chunks and removes JSON tool calls while preserving regular text.
/// Maintains state across calls to handle tool calls spanning multiple chunks.
pub fn fixed_filter_json_tool_calls(content: &str) -> String {
if content.is_empty() {
return String::new();
@@ -87,13 +106,225 @@ pub fn fixed_filter_json_tool_calls(content: &str) -> String {
_ => {}
}
}
// CRITICAL FIX: After counting braces, if still in suppression mode,
// check if a new tool call pattern appears. This handles truncated JSON
// followed by complete JSON.
if state.suppression_mode {
let current_json_start = state.json_start_in_buffer.unwrap();
// Don't require newline - the new JSON might be concatenated directly
let tool_call_regex = Regex::new(r#"\{\s*"tool"\s*:\s*""#).unwrap();
// Look for new tool call patterns after the current one
if let Some(captures) = tool_call_regex.find(&state.buffer[current_json_start + 1..]) {
let new_json_start = current_json_start + 1 + captures.start() + captures.as_str().find('{').unwrap();
debug!("Detected new tool call at position {} while processing incomplete one at {} - discarding old", new_json_start, current_json_start);
// The previous JSON was incomplete/malformed
// Return content before the old JSON (if any)
let content_before_old_json = if current_json_start > state.content_returned_up_to {
state.buffer[state.content_returned_up_to..current_json_start].to_string()
} else {
String::new()
};
// Update state to skip the incomplete JSON and position at the new one
// We'll process the new JSON on the next call
state.content_returned_up_to = new_json_start;
state.suppression_mode = false;
state.json_start_in_buffer = None;
state.brace_depth = 0;
return content_before_old_json;
}
}
// Still in suppression mode, return empty string (content is being accumulated)
return String::new();
}
// Check if we're in potential JSON mode (saw { but waiting to confirm it's a tool call)
if state.potential_json_mode {
// Check if the buffer contains a confirmed tool call pattern
let tool_call_regex = Regex::new(r#"(?m)^\s*\{\s*"tool"\s*:\s*""#).unwrap();
if let Some(captures) = tool_call_regex.find(&state.buffer) {
// Confirmed! This is a tool call - enter suppression mode
let match_text = captures.as_str();
if let Some(brace_offset) = match_text.find('{') {
let json_start = captures.start() + brace_offset;
debug!("Confirmed JSON tool call at position {} - entering suppression mode", json_start);
state.potential_json_mode = false;
state.suppression_mode = true;
state.brace_depth = 0;
state.json_start_in_buffer = Some(json_start);
// Count braces from json_start to see if JSON is complete
let buffer_slice = state.buffer[json_start..].to_string();
for ch in buffer_slice.chars() {
match ch {
'{' => state.brace_depth += 1,
'}' => {
state.brace_depth -= 1;
if state.brace_depth <= 0 {
debug!("JSON tool call completed immediately");
let result = extract_fixed_content(&state.buffer, json_start);
let new_content = if result.len() > state.content_returned_up_to {
result[state.content_returned_up_to..].to_string()
} else {
String::new()
};
state.reset();
return new_content;
}
}
_ => {}
}
}
// JSON incomplete, stay in suppression mode, return nothing
return String::new();
}
}
// Check if we can rule out this being a tool call
// If we have enough content after the { and it doesn't match the pattern, release it
if let Some(potential_start) = state.potential_json_start {
let content_after_brace = &state.buffer[potential_start..];
// Rule out as a tool call if:
// 1. Closing } appears before we see the full pattern
// 2. Content clearly doesn't match the tool call pattern
// 3. Newline appears after the opening brace (tool calls should be compact)
let has_closing_brace = content_after_brace.contains('}');
let has_newline = content_after_brace[1..].contains('\n'); // Skip first char which is {
let long_enough = content_after_brace.len() >= 10;
// Detect non-tool JSON patterns:
// - { followed by " and a key that doesn't start with "tool"
// - { followed by "t" but not "to"
// - { followed by "to" but not "too", etc.
let not_tool_pattern = Regex::new(r#"^\{\s*"(?:[^t]|t(?:[^o]|o(?:[^o]|o(?:[^l]|l[^"\s:]))))"#).unwrap();
let definitely_not_tool = not_tool_pattern.is_match(content_after_brace);
if has_closing_brace || has_newline || (long_enough && definitely_not_tool) {
debug!("Potential JSON ruled out - not a tool call");
state.potential_json_mode = false;
state.potential_json_start = None;
// Return the buffered content we've been holding
let new_content = if state.buffer.len() > state.content_returned_up_to {
state.buffer[state.content_returned_up_to..].to_string()
} else {
String::new()
};
state.content_returned_up_to = state.buffer.len();
return new_content;
}
}
// Still in potential mode, keep buffering
return String::new();
}
// Detect potential JSON start: { at the beginning of a line
let potential_json_regex = Regex::new(r"(?m)^\s*\{\s*").unwrap();
if let Some(captures) = potential_json_regex.find(&state.buffer[state.content_returned_up_to..]) {
let match_start = state.content_returned_up_to + captures.start();
let brace_pos = match_start + captures.as_str().find('{').unwrap();
debug!("Potential JSON detected at position {} - entering buffering mode", brace_pos);
// Fast path: check if this is already a confirmed tool call
let tool_call_regex = Regex::new(r#"(?m)^\s*\{\s*"tool"\s*:\s*""#).unwrap();
if tool_call_regex.is_match(&state.buffer[brace_pos..]) {
// This is a confirmed tool call! Process it immediately
let json_start = brace_pos;
debug!("Immediately confirmed tool call at position {}", json_start);
// Return content before JSON
let content_before = if json_start > state.content_returned_up_to {
state.buffer[state.content_returned_up_to..json_start].to_string()
} else {
String::new()
};
state.content_returned_up_to = json_start;
state.suppression_mode = true;
state.brace_depth = 0;
state.json_start_in_buffer = Some(json_start);
// Count braces to see if JSON is complete
let buffer_slice = state.buffer[json_start..].to_string();
for ch in buffer_slice.chars() {
match ch {
'{' => state.brace_depth += 1,
'}' => {
state.brace_depth -= 1;
if state.brace_depth <= 0 {
debug!("JSON tool call completed in same chunk");
let result = extract_fixed_content(&state.buffer, json_start);
let content_after = if result.len() > json_start {
&result[json_start..]
} else {
""
};
let final_result = format!("{}{}", content_before, content_after);
state.reset();
return final_result;
}
}
_ => {}
}
}
// JSON incomplete, return content before and stay in suppression mode
return content_before;
}
// Return content before the potential JSON
let content_before = if brace_pos > state.content_returned_up_to {
state.buffer[state.content_returned_up_to..brace_pos].to_string()
} else {
String::new()
};
state.content_returned_up_to = brace_pos;
state.potential_json_mode = true;
state.potential_json_start = Some(brace_pos);
// Optimization: immediately check if we can rule this out for single-chunk processing
let content_after_brace = &state.buffer[brace_pos..];
let has_closing_brace = content_after_brace.contains('}');
let has_newline = content_after_brace.len() > 1 && content_after_brace[1..].contains('\n');
let long_enough = content_after_brace.len() >= 10;
let not_tool_pattern = Regex::new(r#"^\{\s*"(?:[^t]|t(?:[^o]|o(?:[^o]|o(?:[^l]|l[^"\s:]))))"#).unwrap();
let definitely_not_tool = not_tool_pattern.is_match(content_after_brace);
if has_closing_brace || has_newline || (long_enough && definitely_not_tool) {
debug!("Immediately ruled out as not a tool call");
state.potential_json_mode = false;
state.potential_json_start = None;
// Return all the buffered content
let new_content = if state.buffer.len() > state.content_returned_up_to {
state.buffer[state.content_returned_up_to..].to_string()
} else {
String::new()
};
state.content_returned_up_to = state.buffer.len();
return format!("{}{}", content_before, new_content);
}
return content_before;
}
// Check for tool call pattern using corrected regex
// More flexible than the strict specification to handle real-world JSON
let tool_call_regex = Regex::new(r#"(?m)^\s*\{\s*"tool"\s*:\s*""#).unwrap();
let tool_call_regex = Regex::new(r#"(?m)^\s*\{\s*"tool"\s*:\s*"[^"]*""#).unwrap();
if let Some(captures) = tool_call_regex.find(&state.buffer) {
let match_text = captures.as_str();
@@ -168,9 +399,17 @@ pub fn fixed_filter_json_tool_calls(content: &str) -> String {
})
}
// Helper function to extract content with JSON tool call filtered out
// Returns everything except the JSON between the first '{' and last '}' (inclusive)
/// Extracts content from buffer, removing the JSON tool call.
///
/// Given a buffer and the start position of a JSON tool call, this function:
/// 1. Extracts all content before the JSON
/// 2. Finds the end of the JSON (matching closing brace)
/// 3. Extracts all content after the JSON
/// 4. Returns the concatenation of before + after (JSON removed)
///
/// # Arguments
/// * `full_content` - The full content buffer
/// * `json_start` - Position where the JSON tool call begins
fn extract_fixed_content(full_content: &str, json_start: usize) -> String {
// Find the end of the JSON using proper brace counting with string handling
let mut brace_depth = 0;
@@ -212,8 +451,10 @@ fn extract_fixed_content(full_content: &str, json_start: usize) -> String {
format!("{}{}", before, after)
}
// Reset function for testing
/// Resets the global JSON filtering state.
///
/// Call this between independent filtering sessions to ensure clean state.
/// This is particularly important in tests and when starting new conversations.
pub fn reset_fixed_json_tool_state() {
FIXED_JSON_TOOL_STATE.with(|state| {
let mut state = state.borrow_mut();

View File

@@ -1,8 +1,14 @@
//! Tests for JSON tool call filtering.
//!
//! These tests verify that the filter correctly identifies and removes JSON tool calls
//! from LLM output streams while preserving all other content.
#[cfg(test)]
mod fixed_filter_tests {
use crate::fixed_filter_json::{fixed_filter_json_tool_calls, reset_fixed_json_tool_state};
use regex::Regex;
/// Test that regular text without tool calls passes through unchanged.
#[test]
fn test_no_tool_call_passthrough() {
reset_fixed_json_tool_state();
@@ -11,6 +17,7 @@ mod fixed_filter_tests {
assert_eq!(result, input);
}
/// Test detection and removal of a complete tool call in a single chunk.
#[test]
fn test_simple_tool_call_detection() {
reset_fixed_json_tool_state();
@@ -23,6 +30,7 @@ Some text after"#;
assert_eq!(result, expected);
}
/// Test handling of tool calls that arrive across multiple streaming chunks.
#[test]
fn test_streaming_chunks() {
reset_fixed_json_tool_state();
@@ -48,6 +56,7 @@ Some text after"#;
assert_eq!(final_result, expected);
}
/// Test correct handling of nested braces within JSON strings.
#[test]
fn test_nested_braces_in_tool_call() {
reset_fixed_json_tool_state();
@@ -61,6 +70,7 @@ Text after"#;
assert_eq!(result, expected);
}
/// Verify the regex pattern matches the specification with flexible whitespace.
#[test]
fn test_regex_pattern_specification() {
// Test the corrected regex pattern that's more flexible with whitespace
@@ -84,11 +94,6 @@ Text after"#;
), // Space after { DOES match with \s*
(
r#"line
abc{"tool":"#,
true,
),
(
r#"line
{"tool123":"#,
false,
), // "tool123" is not exactly "tool"
@@ -109,6 +114,7 @@ abc{"tool":"#,
}
}
/// Test that tool calls must appear at the start of a line (after newline).
#[test]
fn test_newline_requirement() {
reset_fixed_json_tool_state();
@@ -122,13 +128,14 @@ abc{"tool":"#,
reset_fixed_json_tool_state();
let result2 = fixed_filter_json_tool_calls(input_without_newline);
// Both cases currently trigger suppression due to regex pattern
// TODO: Fix regex to only match after actual newlines
// With the new aggressive filtering, only the newline case should trigger suppression
// The pattern requires { to be at the start of a line (after ^)
assert_eq!(result1, "Text\n");
// This currently fails because our regex matches both cases
assert_eq!(result2, "Text ");
// Without newline before {, it should pass through unchanged
assert_eq!(result2, input_without_newline);
}
/// Test handling of escaped quotes within JSON strings.
#[test]
fn test_json_with_escaped_quotes() {
reset_fixed_json_tool_state();
@@ -142,6 +149,7 @@ More text"#;
assert_eq!(result, expected);
}
/// Test graceful handling of incomplete/malformed JSON.
#[test]
fn test_edge_case_malformed_json() {
reset_fixed_json_tool_state();
@@ -157,6 +165,7 @@ More text"#;
assert_eq!(result, expected);
}
/// Test processing multiple independent tool calls sequentially.
#[test]
fn test_multiple_tool_calls_sequential() {
reset_fixed_json_tool_state();
@@ -179,6 +188,7 @@ Final text"#;
assert_eq!(result2, expected2);
}
/// Test tool calls with complex multi-line arguments.
#[test]
fn test_tool_call_with_complex_args() {
reset_fixed_json_tool_state();
@@ -192,6 +202,7 @@ After"#;
assert_eq!(result, expected);
}
/// Test input containing only a tool call with no surrounding text.
#[test]
fn test_tool_call_only() {
reset_fixed_json_tool_state();
@@ -204,6 +215,7 @@ After"#;
assert_eq!(result, expected);
}
/// Test accurate brace counting with deeply nested structures.
#[test]
fn test_brace_counting_accuracy() {
reset_fixed_json_tool_state();
@@ -218,6 +230,7 @@ End"#;
assert_eq!(result, expected);
}
/// Test that braces within strings don't affect brace counting.
#[test]
fn test_string_escaping_in_json() {
reset_fixed_json_tool_state();
@@ -232,6 +245,7 @@ More"#;
assert_eq!(result, expected);
}
/// Verify compliance with the exact specification requirements.
#[test]
fn test_specification_compliance() {
reset_fixed_json_tool_state();
@@ -248,6 +262,7 @@ More"#;
assert_eq!(result, expected);
}
/// Test that non-tool JSON objects are not filtered.
#[test]
fn test_no_false_positives() {
reset_fixed_json_tool_state();
@@ -261,6 +276,7 @@ More text"#;
assert_eq!(result, input);
}
/// Test patterns that look similar to tool calls but aren't exact matches.
#[test]
fn test_partial_tool_patterns() {
reset_fixed_json_tool_state();
@@ -280,6 +296,7 @@ More text"#;
}
}
/// Test streaming with very small chunks (character-by-character).
#[test]
fn test_streaming_edge_cases() {
reset_fixed_json_tool_state();
@@ -296,12 +313,13 @@ More text"#;
}
let final_result: String = results.join("");
// This test currently fails because the JSON is incomplete across chunks
// The function doesn't handle this edge case properly yet
let expected = "Text\n{\"tool\": \nAfter";
// With the new aggressive filtering, the JSON should be completely filtered out
// even when it arrives in very small chunks
let expected = "Text\n\nAfter";
assert_eq!(final_result, expected);
}
/// Debug test with detailed logging for streaming behavior.
#[test]
fn test_streaming_debug() {
reset_fixed_json_tool_state();
@@ -329,4 +347,38 @@ More text"#;
let expected = "Some text before\n\nText after";
assert_eq!(final_result, expected);
}
/// Test handling of truncated JSON followed by complete JSON (the json_err pattern)
#[test]
fn test_truncated_then_complete_json() {
reset_fixed_json_tool_state();
// Simulate the pattern from json_err trace:
// 1. Incomplete/truncated JSON appears
// 2. Then the same complete JSON appears
let chunks = vec![
"Some text\n",
r#"{"tool": "str_replace", "args": {"diff":"...","file_path":"./crates/g3-cli"#, // Truncated
r#"{"tool": "str_replace", "args": {"diff":"...","file_path":"./crates/g3-cli/src/lib.rs"}}"#, // Complete
"\nMore text",
];
let mut results = Vec::new();
for (i, chunk) in chunks.iter().enumerate() {
let result = fixed_filter_json_tool_calls(chunk);
println!("Chunk {}: {:?} -> {:?}", i, chunk, result);
results.push(result);
}
let final_result: String = results.join("");
println!("Final result: {:?}", final_result);
// The truncated JSON should be discarded when the complete one appears
// Both JSONs should be filtered out, leaving only the text
let expected = "Some text\n\nMore text";
assert_eq!(
final_result, expected,
"Failed to handle truncated JSON followed by complete JSON"
);
}
}

View File

@@ -681,6 +681,8 @@ pub struct Agent<W: UiWriter> {
providers: ProviderRegistry,
context_window: ContextWindow,
thinning_events: Vec<usize>, // chars saved per thinning event
pending_90_summarization: bool, // flag to trigger summarization at 90%
auto_compact: bool, // whether to auto-compact at 90% before tool calls
summarization_events: Vec<usize>, // chars saved per summarization event
first_token_times: Vec<Duration>, // time to first token for each completion
config: Config,
@@ -786,7 +788,6 @@ impl<W: UiWriter> Agent<W> {
// Register embedded provider if configured AND it's the default provider
if let Some(embedded_config) = &config.providers.embedded {
if providers_to_register.contains(&"embedded".to_string()) {
info!("Initializing embedded provider");
let embedded_provider = g3_providers::EmbeddedProvider::new(
embedded_config.model_path.clone(),
embedded_config.model_type.clone(),
@@ -797,15 +798,12 @@ impl<W: UiWriter> Agent<W> {
embedded_config.threads,
)?;
providers.register(embedded_provider);
} else {
info!("Embedded provider configured but not needed, skipping initialization");
}
}
// Register OpenAI provider if configured AND it's the default provider
if let Some(openai_config) = &config.providers.openai {
if providers_to_register.contains(&"openai".to_string()) {
info!("Initializing OpenAI provider");
let openai_provider = g3_providers::OpenAIProvider::new(
openai_config.api_key.clone(),
Some(openai_config.model.clone()),
@@ -814,15 +812,12 @@ impl<W: UiWriter> Agent<W> {
openai_config.temperature,
)?;
providers.register(openai_provider);
} else {
info!("OpenAI provider configured but not needed, skipping initialization");
}
}
// Register Anthropic provider if configured AND it's the default provider
if let Some(anthropic_config) = &config.providers.anthropic {
if providers_to_register.contains(&"anthropic".to_string()) {
info!("Initializing Anthropic provider");
let anthropic_provider = g3_providers::AnthropicProvider::new(
anthropic_config.api_key.clone(),
Some(anthropic_config.model.clone()),
@@ -830,15 +825,12 @@ impl<W: UiWriter> Agent<W> {
anthropic_config.temperature,
)?;
providers.register(anthropic_provider);
} else {
info!("Anthropic provider configured but not needed, skipping initialization");
}
}
// Register Databricks provider if configured AND it's the default provider
if let Some(databricks_config) = &config.providers.databricks {
if providers_to_register.contains(&"databricks".to_string()) {
info!("Initializing Databricks provider");
let databricks_provider = if let Some(token) = &databricks_config.token {
// Use token-based authentication
@@ -861,8 +853,19 @@ impl<W: UiWriter> Agent<W> {
};
providers.register(databricks_provider);
} else {
info!("Databricks provider configured but not needed, skipping initialization");
}
}
// Register Ollama provider if configured AND it's the default provider
if let Some(ollama_config) = &config.providers.ollama {
if providers_to_register.contains(&"ollama".to_string()) {
let ollama_provider = g3_providers::OllamaProvider::new(
ollama_config.model.clone(),
ollama_config.base_url.clone(),
ollama_config.max_tokens,
ollama_config.temperature,
)?;
providers.register(ollama_provider);
}
}
@@ -885,16 +888,12 @@ impl<W: UiWriter> Agent<W> {
content: readme,
};
context_window.add_message(readme_message);
info!("Added project README to context window");
}
// Initialize computer controller if enabled
let computer_controller = if config.computer_control.enabled {
match g3_computer_control::create_controller() {
Ok(controller) => {
info!("Computer control enabled");
Some(controller)
}
Ok(controller) => Some(controller),
Err(e) => {
warn!("Failed to initialize computer control: {}", e);
None
@@ -910,6 +909,8 @@ impl<W: UiWriter> Agent<W> {
Ok(Self {
providers,
context_window,
auto_compact: config.agent.auto_compact,
pending_90_summarization: false,
thinning_events: Vec::new(),
summarization_events: Vec::new(),
first_token_times: Vec::new(),
@@ -974,6 +975,30 @@ impl<W: UiWriter> Agent<W> {
16384 // Conservative default for other Databricks models
}
}
"ollama" => {
// Ollama model context windows based on model name
if model_name.contains("qwen3-coder") {
262144 // Qwen3-coder supports 256k context
} else if model_name.contains("qwen") {
32768 // Qwen2.5 supports 32k context
} else if model_name.contains("gpt-oss") {
131072 // GPT-OSS supports 128k context
} else if model_name.contains("llama3") || model_name.contains("llama-3") {
if model_name.contains("3.2") || model_name.contains("3.1") {
128000 // Llama 3.1/3.2 support 128k context
} else {
8192 // Llama 3.0
}
} else if model_name.contains("mistral") || model_name.contains("mixtral") {
32768 // Mistral/Mixtral support 32k
} else if model_name.contains("gemma") {
8192 // Gemma 2
} else if model_name.contains("phi") {
4096 // Phi-3
} else {
8192 // Conservative default for Ollama models
}
}
_ => config.agent.max_context_length as u32,
};
@@ -1355,6 +1380,19 @@ Template:
// Save context window at the end of successful interaction
self.save_context_window("completed");
// Check if we need to do 90% auto-compaction
if self.pending_90_summarization {
self.ui_writer.print_context_status(
"\n⚡ Context window reached 90% - auto-compacting...\n"
);
if let Err(e) = self.force_summarize().await {
warn!("Failed to auto-compact at 90%: {}", e);
} else {
self.ui_writer.println("");
}
self.pending_90_summarization = false;
}
// Return the task result which already includes timing if needed
Ok(task_result)
}
@@ -2651,6 +2689,14 @@ Template:
if let Some(tool_call) = completed_tools.into_iter().next() {
debug!("Processing completed tool call: {:?}", tool_call);
// Check if we should auto-compact at 90% BEFORE executing the tool
// We need to do this before any borrows of self
if self.auto_compact && self.context_window.percentage_used() >= 90.0 {
// Set flag to trigger summarization after this turn completes
// We can't do it now due to borrow checker constraints
self.pending_90_summarization = true;
}
// Check if we should thin the context BEFORE executing the tool
if self.context_window.should_thin() {
let (thin_summary, chars_saved) = self.context_window.thin_context();
@@ -2659,6 +2705,7 @@ Template:
self.ui_writer.print_context_thinning(&thin_summary);
}
// Track what we've already displayed before getting new text
// This prevents re-displaying old content after tool execution
let already_displayed_chars = current_response.chars().count();
@@ -2991,7 +3038,8 @@ Template:
"Using filtered parser text as last resort: {} chars",
filtered_text.len()
);
current_response = filtered_text;
// Note: This assignment is currently unused but kept for potential future use
let _ = filtered_text;
}
}
@@ -3124,17 +3172,33 @@ Template:
}
Err(e) => {
// Capture detailed streaming error information
let error_details =
format!("Streaming error at chunk {}: {}", chunks_received + 1, e);
error!("{}", error_details);
let error_msg = e.to_string();
let error_details = format!("Streaming error at chunk {}: {}", chunks_received + 1, error_msg);
error!("Error type: {}", std::any::type_name_of_val(&e));
error!("Parser state at error: text_buffer_len={}, native_tool_calls={}, message_stopped={}",
parser.text_buffer_len(), parser.native_tool_calls.len(), parser.is_message_stopped());
// Store the error for potential logging later
_last_error = Some(error_details);
_last_error = Some(error_details.clone());
// Check if this is a recoverable connection error
let is_connection_error = error_msg.contains("unexpected EOF")
|| error_msg.contains("connection")
|| error_msg.contains("chunk size line")
|| error_msg.contains("body error");
if is_connection_error {
warn!("Connection error at chunk {}, treating as end of stream", chunks_received + 1);
// If we have any content or tool calls, treat this as a graceful end
if chunks_received > 0 && (!parser.get_text_content().is_empty() || parser.native_tool_calls.len() > 0) {
warn!("Stream terminated unexpectedly but we have content, continuing");
break; // Break to process what we have
}
}
if tool_executed {
error!("{}", error_details);
warn!("Stream error after tool execution, attempting to continue");
break; // Break to outer loop to start new stream
} else {

View File

@@ -298,6 +298,7 @@ impl DatabricksProvider {
let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> =
std::collections::HashMap::new(); // index -> (id, name, args)
let mut incomplete_data_line = String::new(); // Buffer for incomplete data: lines
let mut chunk_count = 0;
let accumulated_usage: Option<Usage> = None;
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
@@ -305,6 +306,8 @@ impl DatabricksProvider {
match chunk_result {
Ok(chunk) => {
// Debug: Log raw bytes received
chunk_count += 1;
debug!("Processing chunk #{}", chunk_count);
debug!("Raw SSE bytes received: {} bytes", chunk.len());
// Append new bytes to our buffer
@@ -589,13 +592,39 @@ impl DatabricksProvider {
}
}
Err(e) => {
error!("Stream error: {}", e);
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
error!("Stream error at chunk {}: {}", chunk_count, e);
// Check if this is a connection error that might be recoverable
let error_msg = e.to_string();
if error_msg.contains("unexpected EOF") || error_msg.contains("connection") {
warn!("Connection terminated unexpectedly at chunk {}, treating as end of stream", chunk_count);
// Don't send error, just break and finalize
break;
} else {
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
}
return accumulated_usage;
}
}
}
// Log final state
debug!("Stream ended after {} chunks", chunk_count);
debug!("Final state: buffer_len={}, incomplete_data_line_len={}, byte_buffer_len={}",
buffer.len(), incomplete_data_line.len(), byte_buffer.len());
debug!("Accumulated tool calls: {}", current_tool_calls.len());
// If we have any remaining data in buffers, log it for debugging
if !buffer.is_empty() {
debug!("Remaining buffer content: {:?}", buffer);
}
if !byte_buffer.is_empty() {
debug!("Remaining byte buffer: {} bytes", byte_buffer.len());
}
if !incomplete_data_line.is_empty() {
debug!("Remaining incomplete data line: {:?}", incomplete_data_line);
}
// If we have any incomplete data line at the end, try to process it
if !incomplete_data_line.is_empty() {
debug!(

View File

@@ -88,11 +88,13 @@ pub mod anthropic;
pub mod databricks;
pub mod embedded;
pub mod oauth;
pub mod ollama;
pub mod openai;
pub use anthropic::AnthropicProvider;
pub use databricks::DatabricksProvider;
pub use embedded::EmbeddedProvider;
pub use ollama::OllamaProvider;
pub use openai::OpenAIProvider;
/// Provider registry for managing multiple LLM providers

View File

@@ -0,0 +1,751 @@
//! Ollama LLM provider implementation for the g3-providers crate.
//!
//! This module provides an implementation of the `LLMProvider` trait for Ollama,
//! supporting both completion and streaming modes with native tool calling.
//!
//! # Features
//!
//! - Support for any Ollama model (llama3.2, mistral, qwen, etc.)
//! - Both completion and streaming response modes
//! - Native tool calling support for compatible models
//! - Configurable base URL (defaults to http://localhost:11434)
//! - Simple configuration with no authentication required
//!
//! # Usage
//!
//! ```rust,no_run
//! use g3_providers::{OllamaProvider, LLMProvider, CompletionRequest, Message, MessageRole};
//!
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//! // Create the provider with default settings (localhost:11434)
//! let provider = OllamaProvider::new(
//! "llama3.2".to_string(),
//! None, // Optional: base_url
//! None, // Optional: max tokens
//! None, // Optional: temperature
//! )?;
//!
//! // Create a completion request
//! let request = CompletionRequest {
//! messages: vec![
//! Message {
//! role: MessageRole::User,
//! content: "Hello! How are you?".to_string(),
//! },
//! ],
//! max_tokens: Some(1000),
//! temperature: Some(0.7),
//! stream: false,
//! tools: None,
//! };
//!
//! // Get a completion
//! let response = provider.complete(request).await?;
//! println!("Response: {}", response.content);
//!
//! Ok(())
//! }
//! ```
use anyhow::{anyhow, Result};
use bytes::Bytes;
use futures_util::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error, info, warn};
use crate::{
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
MessageRole, Tool, ToolCall, Usage,
};
const DEFAULT_BASE_URL: &str = "http://localhost:11434";
const DEFAULT_TIMEOUT_SECS: u64 = 600;
pub const OLLAMA_DEFAULT_MODEL: &str = "llama3.2";
pub const OLLAMA_KNOWN_MODELS: &[&str] = &[
"llama3.2",
"llama3.2:1b",
"llama3.2:3b",
"llama3.1",
"llama3.1:8b",
"llama3.1:70b",
"mistral",
"mistral-nemo",
"mixtral",
"qwen2.5",
"qwen2.5:7b",
"qwen2.5:14b",
"qwen2.5:32b",
"qwen2.5-coder",
"qwen2.5-coder:7b",
"qwen3-coder",
"phi3",
"gemma2",
];
#[derive(Debug, Clone)]
pub struct OllamaProvider {
client: Client,
base_url: String,
model: String,
max_tokens: Option<u32>,
temperature: f32,
}
impl OllamaProvider {
pub fn new(
model: String,
base_url: Option<String>,
max_tokens: Option<u32>,
temperature: Option<f32>,
) -> Result<Self> {
let client = Client::builder()
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
.build()
.map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
let base_url = base_url
.unwrap_or_else(|| DEFAULT_BASE_URL.to_string())
.trim_end_matches('/')
.to_string();
info!(
"Initialized Ollama provider with model: {} at {}",
model, base_url
);
Ok(Self {
client,
base_url,
model,
max_tokens,
temperature: temperature.unwrap_or(0.7),
})
}
fn convert_tools(&self, tools: &[Tool]) -> Vec<OllamaTool> {
tools
.iter()
.map(|tool| OllamaTool {
r#type: "function".to_string(),
function: OllamaFunction {
name: tool.name.clone(),
description: tool.description.clone(),
parameters: tool.input_schema.clone(),
},
})
.collect()
}
fn convert_messages(&self, messages: &[Message]) -> Result<Vec<OllamaMessage>> {
let mut ollama_messages = Vec::new();
for message in messages {
let role = match message.role {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
};
ollama_messages.push(OllamaMessage {
role: role.to_string(),
content: message.content.clone(),
tool_calls: None, // Only used in responses
});
}
if ollama_messages.is_empty() {
return Err(anyhow!("At least one message is required"));
}
Ok(ollama_messages)
}
fn create_request_body(
&self,
messages: &[Message],
tools: Option<&[Tool]>,
streaming: bool,
max_tokens: Option<u32>,
temperature: f32,
) -> Result<OllamaRequest> {
let ollama_messages = self.convert_messages(messages)?;
let ollama_tools = tools.map(|t| self.convert_tools(t));
let mut options = OllamaOptions {
temperature,
num_predict: max_tokens,
};
// If max_tokens is provided, use it; otherwise use the instance default
if max_tokens.is_none() {
options.num_predict = self.max_tokens;
}
let request = OllamaRequest {
model: self.model.clone(),
messages: ollama_messages,
tools: ollama_tools,
stream: streaming,
options,
};
Ok(request)
}
async fn parse_streaming_response(
&self,
mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
tx: mpsc::Sender<Result<CompletionChunk>>,
) -> Option<Usage> {
let mut buffer = String::new();
let mut accumulated_usage: Option<Usage> = None;
let mut current_tool_calls: Vec<OllamaToolCall> = Vec::new();
let mut byte_buffer = Vec::new();
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
// Append new bytes to our buffer
byte_buffer.extend_from_slice(&chunk);
// Try to convert the entire buffer to UTF-8
let chunk_str = match std::str::from_utf8(&byte_buffer) {
Ok(s) => {
let result = s.to_string();
byte_buffer.clear();
result
}
Err(e) => {
let valid_up_to = e.valid_up_to();
if valid_up_to > 0 {
let valid_bytes =
byte_buffer.drain(..valid_up_to).collect::<Vec<_>>();
std::str::from_utf8(&valid_bytes).unwrap().to_string()
} else {
continue;
}
}
};
buffer.push_str(&chunk_str);
// Process complete lines
while let Some(line_end) = buffer.find('\n') {
let line = buffer[..line_end].trim().to_string();
buffer.drain(..line_end + 1);
if line.is_empty() {
continue;
}
// Ollama streaming sends JSON objects per line
match serde_json::from_str::<OllamaStreamChunk>(&line) {
Ok(chunk) => {
// Handle text content
if let Some(message) = &chunk.message {
let content = &message.content;
if !content.is_empty() {
debug!("Sending text chunk: '{}'", content);
let chunk = CompletionChunk {
content: content.clone(),
finished: false,
usage: None,
tool_calls: None,
};
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return accumulated_usage;
}
}
// Handle tool calls
if let Some(tool_calls) = &message.tool_calls {
current_tool_calls.extend(tool_calls.clone());
}
}
// Check if stream is done
if chunk.done.unwrap_or(false) {
debug!("Stream completed");
// Update usage if available
if let Some(eval_count) = chunk.eval_count {
accumulated_usage = Some(Usage {
prompt_tokens: chunk.prompt_eval_count.unwrap_or(0),
completion_tokens: eval_count,
total_tokens: chunk.prompt_eval_count.unwrap_or(0)
+ eval_count,
});
}
// Send final chunk with tool calls if any
let final_tool_calls: Vec<ToolCall> = current_tool_calls
.iter()
.map(|tc| ToolCall {
id: tc.function.name.clone(), // Ollama doesn't provide IDs
tool: tc.function.name.clone(),
args: tc.function.arguments.clone(),
})
.collect();
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
tool_calls: if final_tool_calls.is_empty() {
None
} else {
Some(final_tool_calls)
},
};
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
}
return accumulated_usage;
}
}
Err(e) => {
debug!("Failed to parse Ollama stream chunk: {} - Line: {}", e, line);
// Don't error out, just continue
}
}
}
}
                Err(e) => {
                    error!("Stream error: {}", e);
                    let error_msg = e.to_string();
                    if error_msg.contains("unexpected EOF") || error_msg.contains("connection") {
                        // Some backends close the connection without sending a final
                        // `done` chunk; treat that as end of stream and fall through
                        // to the final-chunk handling below the loop.
                        warn!("Connection terminated unexpectedly, treating as end of stream");
                        break;
                    }
                    let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
                    return accumulated_usage;
                }
}
}
// Send final chunk if we haven't already
let final_tool_calls: Vec<ToolCall> = current_tool_calls
.iter()
.map(|tc| ToolCall {
id: tc.function.name.clone(),
tool: tc.function.name.clone(),
args: tc.function.arguments.clone(),
})
.collect();
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
tool_calls: if final_tool_calls.is_empty() {
None
} else {
Some(final_tool_calls)
},
};
let _ = tx.send(Ok(final_chunk)).await;
accumulated_usage
}
/// Fetch available models from the Ollama instance
pub async fn fetch_available_models(&self) -> Result<Vec<String>> {
let response = self
.client
.get(format!("{}/api/tags", self.base_url))
.send()
.await
.map_err(|e| anyhow!("Failed to fetch Ollama models: {}", e))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!(
"Failed to fetch Ollama models: {} - {}",
status,
error_text
));
}
let json: serde_json::Value = response.json().await?;
let models = json
.get("models")
.and_then(|v| v.as_array())
.ok_or_else(|| anyhow!("Unexpected response format: missing 'models' array"))?;
let model_names: Vec<String> = models
.iter()
.filter_map(|model| model.get("name").and_then(|n| n.as_str()).map(String::from))
.collect();
debug!("Found {} models in Ollama", model_names.len());
Ok(model_names)
}
}
#[async_trait::async_trait]
impl LLMProvider for OllamaProvider {
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
debug!(
"Processing Ollama completion request with {} messages",
request.messages.len()
);
let max_tokens = request.max_tokens.or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);
let request_body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
false,
max_tokens,
temperature,
)?;
debug!(
"Sending request to Ollama API: model={}, temperature={}",
self.model, request_body.options.temperature
);
let response = self
.client
.post(format!("{}/api/chat", self.base_url))
.json(&request_body)
.send()
.await
.map_err(|e| anyhow!("Failed to send request to Ollama API: {}", e))?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!("Ollama API error {}: {}", status, error_text));
}
let response_text = response.text().await?;
debug!("Raw Ollama API response: {}", response_text);
let ollama_response: OllamaResponse =
serde_json::from_str(&response_text).map_err(|e| {
anyhow!(
"Failed to parse Ollama response: {} - Response: {}",
e,
response_text
)
})?;
let content = ollama_response.message.content.clone();
let usage = Usage {
prompt_tokens: ollama_response.prompt_eval_count.unwrap_or(0),
completion_tokens: ollama_response.eval_count.unwrap_or(0),
total_tokens: ollama_response.prompt_eval_count.unwrap_or(0)
+ ollama_response.eval_count.unwrap_or(0),
};
debug!(
"Ollama completion successful: {} tokens generated",
usage.completion_tokens
);
Ok(CompletionResponse {
content,
usage,
model: self.model.clone(),
})
}
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
debug!(
"Processing Ollama request (non-streaming) with {} messages",
request.messages.len()
);
if let Some(ref tools) = request.tools {
debug!("Request has {} tools", tools.len());
for tool in tools.iter().take(5) {
debug!(" Tool: {}", tool.name);
}
}
let max_tokens = request.max_tokens.or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);
let request_body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
false, // Use non-streaming mode to avoid streaming bugs
max_tokens,
temperature,
)?;
debug!(
"Sending request to Ollama API (stream=false): model={}, temperature={}",
self.model, request_body.options.temperature
);
let response = self
.client
.post(format!("{}/api/chat", self.base_url))
.json(&request_body)
.send()
.await
.map_err(|e| anyhow!("Failed to send request to Ollama API: {}", e))?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!("Ollama API error {}: {}", status, error_text));
}
// For non-streaming, parse the complete JSON response
let response_text = response.text().await?;
debug!("Raw Ollama API response: {}", response_text);
let ollama_response: OllamaResponse =
serde_json::from_str(&response_text).map_err(|e| {
anyhow!(
"Failed to parse Ollama response: {} - Response: {}",
e,
response_text
)
})?;
let (tx, rx) = mpsc::channel(100);
tokio::spawn(async move {
let content = ollama_response.message.content;
let usage = Usage {
prompt_tokens: ollama_response.prompt_eval_count.unwrap_or(0),
completion_tokens: ollama_response.eval_count.unwrap_or(0),
total_tokens: ollama_response.prompt_eval_count.unwrap_or(0)
+ ollama_response.eval_count.unwrap_or(0),
};
// Extract tool calls if present
let tool_calls: Option<Vec<ToolCall>> = ollama_response.message.tool_calls.map(|tcs| {
tcs.iter()
.map(|tc| ToolCall {
id: tc.function.name.clone(),
tool: tc.function.name.clone(),
args: tc.function.arguments.clone(),
})
.collect()
});
// Send content if any
if !content.is_empty() {
let _ = tx.send(Ok(CompletionChunk {
content,
finished: false,
usage: None,
tool_calls: None,
})).await;
}
// Send final chunk with usage and tool calls
let _ = tx.send(Ok(CompletionChunk {
content: String::new(),
finished: true,
usage: Some(usage),
tool_calls,
})).await;
});
Ok(ReceiverStream::new(rx))
}
fn name(&self) -> &str {
"ollama"
}
fn model(&self) -> &str {
&self.model
}
fn has_native_tool_calling(&self) -> bool {
// Most modern Ollama models support tool calling
// Models like llama3.2, qwen2.5, mistral, etc. have good tool support
true
}
}
// Ollama API request/response structures
#[derive(Debug, Serialize)]
struct OllamaRequest {
model: String,
messages: Vec<OllamaMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<OllamaTool>>,
stream: bool,
options: OllamaOptions,
}
#[derive(Debug, Serialize)]
struct OllamaOptions {
temperature: f32,
#[serde(skip_serializing_if = "Option::is_none")]
num_predict: Option<u32>, // Ollama's equivalent of max_tokens
}
#[derive(Debug, Serialize)]
struct OllamaTool {
r#type: String,
function: OllamaFunction,
}
#[derive(Debug, Serialize)]
struct OllamaFunction {
name: String,
description: String,
parameters: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaMessage {
role: String,
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<OllamaToolCall>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaToolCall {
function: OllamaToolCallFunction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaToolCallFunction {
name: String,
arguments: serde_json::Value,
}
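/// Complete (non-streaming) response body from `/api/chat`.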
#[derive(Debug, Deserialize)]
struct OllamaResponse {
message: OllamaMessage,
#[allow(dead_code)]
done: bool,
#[allow(dead_code)]
total_duration: Option<u64>,
#[allow(dead_code)]
load_duration: Option<u64>,
prompt_eval_count: Option<u32>,
eval_count: Option<u32>,
}
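/// One newline-delimited JSON object from a streaming `/api/chat` response.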
#[derive(Debug, Deserialize)]
struct OllamaStreamChunk {
message: Option<OllamaMessage>,
done: Option<bool>,
prompt_eval_count: Option<u32>,
eval_count: Option<u32>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_provider_creation() {
let provider = OllamaProvider::new(
"llama3.2".to_string(),
None,
Some(1000),
Some(0.7),
)
.unwrap();
assert_eq!(provider.model(), "llama3.2");
assert_eq!(provider.name(), "ollama");
assert!(provider.has_native_tool_calling());
}
#[test]
fn test_message_conversion() {
let provider = OllamaProvider::new(
"llama3.2".to_string(),
None,
None,
None,
)
.unwrap();
let messages = vec![
Message {
role: MessageRole::System,
content: "You are a helpful assistant.".to_string(),
},
Message {
role: MessageRole::User,
content: "Hello!".to_string(),
},
];
let ollama_messages = provider.convert_messages(&messages).unwrap();
assert_eq!(ollama_messages.len(), 2);
assert_eq!(ollama_messages[0].role, "system");
assert_eq!(ollama_messages[1].role, "user");
}
#[test]
fn test_tool_conversion() {
let provider = OllamaProvider::new(
"llama3.2".to_string(),
None,
None,
None,
)
.unwrap();
let tools = vec![Tool {
name: "get_weather".to_string(),
description: "Get the current weather".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}),
}];
let ollama_tools = provider.convert_tools(&tools);
assert_eq!(ollama_tools.len(), 1);
assert_eq!(ollama_tools[0].r#type, "function");
assert_eq!(ollama_tools[0].function.name, "get_weather");
}
#[test]
fn test_custom_base_url() {
let provider = OllamaProvider::new(
"llama3.2".to_string(),
Some("http://custom:11434".to_string()),
None,
None,
)
.unwrap();
assert_eq!(provider.base_url, "http://custom:11434");
}
}