diff --git a/OLLAMA_CONFIG.md b/OLLAMA_CONFIG.md new file mode 100644 index 0000000..888bba9 --- /dev/null +++ b/OLLAMA_CONFIG.md @@ -0,0 +1,456 @@ +# Configuring Ollama Provider in G3 + +This guide shows you how to configure G3 to use Ollama as your LLM provider. + +## Quick Start + +### 1. Install Ollama + +```bash +# Visit https://ollama.ai to download and install +# Or use curl: +curl https://ollama.ai/install.sh | sh +``` + +### 2. Pull a Model + +```bash +ollama pull llama3.2 +# or any other model you prefer +``` + +### 3. Create Configuration File + +Copy the example configuration: + +```bash +cp config.ollama.example.toml ~/.config/g3/config.toml +``` + +Or create it manually: + +```toml +[providers] +default_provider = "ollama" + +[providers.ollama] +model = "llama3.2" +``` + +### 4. Run G3 + +```bash +g3 +# G3 will now use Ollama with llama3.2! +``` + +## Configuration Options + +### Basic Configuration + +```toml +[providers] +default_provider = "ollama" + +[providers.ollama] +model = "llama3.2" +``` + +This is the minimal configuration needed. It uses all defaults: +- Base URL: `http://localhost:11434` +- Temperature: `0.7` +- Max tokens: Not limited (uses model default) + +### Full Configuration + +```toml +[providers] +default_provider = "ollama" + +[providers.ollama] +model = "llama3.2" +base_url = "http://localhost:11434" +max_tokens = 2048 +temperature = 0.7 +``` + +### Custom Ollama Host + +If you're running Ollama on a different machine or port: + +```toml +[providers.ollama] +model = "llama3.2" +base_url = "http://192.168.1.100:11434" +``` + +### Different Models + +You can use any Ollama model: + +```toml +[providers.ollama] +model = "qwen2.5:7b" # Alibaba's Qwen model +``` + +```toml +[providers.ollama] +model = "mistral" # Mistral AI +``` + +```toml +[providers.ollama] +model = "llama3.1:70b" # Larger Llama model +``` + +## Multiple Provider Configuration + +You can configure multiple providers and switch between them: + +```toml +[providers] +default_provider = "ollama" # Default for most operations + +# Ollama for local, fast responses +[providers.ollama] +model = "llama3.2:3b" +temperature = 0.7 + +# Databricks for more complex tasks +[providers.databricks] +host = "https://your-workspace.cloud.databricks.com" +model = "databricks-claude-sonnet-4" +max_tokens = 4096 +temperature = 0.1 +use_oauth = true +``` + +Then switch providers with: + +```bash +g3 --provider databricks +``` + +## Autonomous Mode (Coach-Player) + +Use different providers for code review (coach) and implementation (player): + +```toml +[providers] +default_provider = "ollama" +coach = "databricks" # Use powerful cloud model for review +player = "ollama" # Use local model for implementation + +[providers.ollama] +model = "qwen2.5:14b" # Larger local model for coding + +[providers.databricks] +host = "https://your-workspace.cloud.databricks.com" +model = "databricks-claude-sonnet-4" +use_oauth = true +``` + +This gives you the best of both worlds: +- Fast local execution for coding tasks +- Powerful cloud review for quality assurance + +## Recommended Models + +### For Coding Tasks + +| Model | Size | Speed | Quality | Notes | +|-------|------|-------|---------|-------| +| **qwen2.5:7b** | 7B | Fast | Excellent | Best balance for coding | +| **llama3.2:3b** | 3B | Very Fast | Good | Great for quick tasks | +| **llama3.1:8b** | 8B | Medium | Very Good | Solid all-rounder | +| **mistral** | 7B | Fast | Good | Good for general use | + +### For Complex Tasks + +| Model | Size | Speed | Quality | 
Notes | +|-------|------|-------|---------|-------| +| **qwen2.5:14b** | 14B | Medium | Excellent | Best local model for coding | +| **qwen2.5:32b** | 32B | Slow | Outstanding | If you have the resources | +| **llama3.1:70b** | 70B | Very Slow | Outstanding | Requires significant RAM/GPU | + +## Temperature Settings + +Temperature controls randomness in responses: + +- **0.1-0.3**: Deterministic, good for code generation +- **0.5-0.7**: Balanced, good for most tasks +- **0.8-1.0**: Creative, good for brainstorming + +```toml +[providers.ollama] +model = "qwen2.5:7b" +temperature = 0.2 # Focused code generation +``` + +## Max Tokens + +Control response length: + +```toml +[providers.ollama] +model = "llama3.2" +max_tokens = 1024 # Shorter responses +``` + +```toml +[providers.ollama] +model = "qwen2.5:7b" +max_tokens = 4096 # Longer, detailed responses +``` + +Leave it unset for model defaults (recommended). + +## Performance Tuning + +### GPU Acceleration + +Ollama automatically uses GPU if available. To check: + +```bash +ollama ps +``` + +### Quantized Models + +For faster responses with less RAM: + +```toml +[providers.ollama] +model = "llama3.2:3b-q4_0" # 4-bit quantization +``` + +Quantization options: +- `q4_0`: 4-bit, fastest, lowest quality +- `q5_0`: 5-bit, balanced +- `q8_0`: 8-bit, slower, better quality + +### Multiple Models + +You can pull multiple models and switch easily: + +```bash +ollama pull llama3.2:3b # Fast for chat +ollama pull qwen2.5:7b # Better for code +ollama pull mistral # General purpose +``` + +Then change your config: + +```toml +[providers.ollama] +model = "qwen2.5:7b" # Just change this line +``` + +## Troubleshooting + +### Ollama Not Running + +```bash +# Check if Ollama is running +curl http://localhost:11434/api/version + +# Start Ollama (macOS/Linux) +ollama serve + +# Or just run a model (auto-starts) +ollama run llama3.2 +``` + +### Model Not Found + +```bash +# List available models +ollama list + +# Pull the model +ollama pull llama3.2 +``` + +### Slow Responses + +1. Use a smaller model: + ```toml + model = "llama3.2:1b" # Smallest, fastest + ``` + +2. Use quantized version: + ```toml + model = "llama3.2:3b-q4_0" + ``` + +3. Reduce max_tokens: + ```toml + max_tokens = 512 + ``` + +### Out of Memory + +1. Switch to smaller model +2. Use quantized version +3. Close other applications +4. 
Check GPU memory: `ollama ps` + +### Connection Refused + +Check base_url is correct: + +```toml +[providers.ollama] +model = "llama3.2" +base_url = "http://localhost:11434" # Default +``` + +For remote Ollama: + +```toml +base_url = "http://your-server:11434" +``` + +## Complete Example Configs + +### Minimal Local Setup + +```toml +[providers] +default_provider = "ollama" + +[providers.ollama] +model = "llama3.2" + +[agent] +max_context_length = 8192 +enable_streaming = true +timeout_seconds = 60 +``` + +### Optimized for Coding + +```toml +[providers] +default_provider = "ollama" + +[providers.ollama] +model = "qwen2.5:7b" +temperature = 0.2 +max_tokens = 2048 + +[agent] +max_context_length = 16384 +enable_streaming = true +timeout_seconds = 120 +``` + +### Fast Responses + +```toml +[providers] +default_provider = "ollama" + +[providers.ollama] +model = "llama3.2:3b-q4_0" +temperature = 0.7 +max_tokens = 1024 + +[agent] +max_context_length = 4096 +enable_streaming = true +timeout_seconds = 30 +``` + +### High Quality (Requires Good Hardware) + +```toml +[providers] +default_provider = "ollama" + +[providers.ollama] +model = "qwen2.5:32b" +temperature = 0.3 +max_tokens = 4096 + +[agent] +max_context_length = 32768 +enable_streaming = true +timeout_seconds = 300 +``` + +### Hybrid (Local + Cloud) + +```toml +[providers] +default_provider = "ollama" +coach = "databricks" +player = "ollama" + +[providers.ollama] +model = "qwen2.5:14b" +temperature = 0.2 + +[providers.databricks] +host = "https://your-workspace.cloud.databricks.com" +model = "databricks-claude-sonnet-4" +use_oauth = true + +[agent] +max_context_length = 16384 +enable_streaming = true +timeout_seconds = 120 +``` + +## Environment Variables + +You can override config with environment variables: + +```bash +# Override model +G3_PROVIDERS_OLLAMA_MODEL=qwen2.5:7b g3 + +# Override base URL +G3_PROVIDERS_OLLAMA_BASE_URL=http://192.168.1.100:11434 g3 + +# Override default provider +G3_PROVIDERS_DEFAULT_PROVIDER=ollama g3 +``` + +## Best Practices + +1. **Start Small**: Begin with llama3.2:3b, scale up if needed +2. **Use Quantization**: q4_0 or q5_0 for best speed/quality balance +3. **Match Task to Model**: + - Quick edits: 1B-3B models + - Code generation: 7B-14B models + - Complex refactoring: 14B-32B models +4. **Temperature for Code**: Use 0.1-0.3 for deterministic output +5. **Enable Streaming**: Always enable for better UX +6. **Local First**: Use Ollama by default, cloud for special cases + +## Comparison with Other Providers + +| Feature | Ollama | Databricks | OpenAI | Anthropic | +|---------|--------|------------|--------|-----------| +| Cost | Free | Paid | Paid | Paid | +| Privacy | Full | Medium | Low | Low | +| Speed (small models) | Fast | Fast | Medium | Medium | +| Speed (large models) | Slow | Fast | Fast | Fast | +| Setup Complexity | Low | Medium | Low | Low | +| Authentication | None | OAuth/Token | API Key | API Key | +| Offline Support | Yes | No | No | No | +| Tool Calling | Yes | Yes | Yes | Yes | + +## Next Steps + +1. Try different models: `ollama pull mistral`, `ollama pull qwen2.5` +2. Experiment with temperature settings +3. Set up hybrid config with cloud provider for complex tasks +4. Share your config in the community! 
+ +## Getting Help + +- Ollama docs: https://ollama.ai/docs +- G3 issues: https://github.com/your-repo/issues +- Test your config: `g3 --help` diff --git a/OLLAMA_EXAMPLE.md b/OLLAMA_EXAMPLE.md new file mode 100644 index 0000000..65d4528 --- /dev/null +++ b/OLLAMA_EXAMPLE.md @@ -0,0 +1,315 @@ +# Ollama Provider for g3 + +A simple, local LLM provider implementation for g3 that connects to Ollama. + +## Features + +- ✅ **Simple Setup**: No API keys or authentication required +- ✅ **Local Execution**: Runs entirely on your machine +- ✅ **Tool Calling Support**: Native tool calling for compatible models +- ✅ **Streaming**: Full streaming support with real-time responses +- ✅ **Flexible Configuration**: Custom base URL, temperature, and max tokens +- ✅ **Model Discovery**: Automatic detection of available models + +## Quick Start + +### Prerequisites + +1. Install and start Ollama: https://ollama.ai +2. Pull a model: `ollama pull llama3.2` + +### Basic Usage + +```rust +use g3_providers::{OllamaProvider, LLMProvider, CompletionRequest, Message, MessageRole}; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + // Create provider with default settings (localhost:11434) + let provider = OllamaProvider::new( + "llama3.2".to_string(), + None, // base_url: defaults to http://localhost:11434 + None, // max_tokens: optional + None, // temperature: defaults to 0.7 + )?; + + // Create a simple request + let request = CompletionRequest { + messages: vec![ + Message { + role: MessageRole::User, + content: "What is the capital of France?".to_string(), + }, + ], + max_tokens: Some(1000), + temperature: Some(0.7), + stream: false, + tools: None, + }; + + // Get completion + let response = provider.complete(request).await?; + println!("Response: {}", response.content); + println!("Tokens: {}", response.usage.total_tokens); + + Ok(()) +} +``` + +### Streaming Example + +```rust +use futures_util::StreamExt; + +let request = CompletionRequest { + messages: vec![ + Message { + role: MessageRole::User, + content: "Write a short poem about coding".to_string(), + }, + ], + max_tokens: Some(500), + temperature: Some(0.8), + stream: true, + tools: None, +}; + +let mut stream = provider.stream(request).await?; + +while let Some(chunk_result) = stream.next().await { + match chunk_result { + Ok(chunk) => { + print!("{}", chunk.content); + if chunk.finished { + println!("\n\nDone!"); + if let Some(usage) = chunk.usage { + println!("Total tokens: {}", usage.total_tokens); + } + } + } + Err(e) => eprintln!("Error: {}", e), + } +} +``` + +### Tool Calling Example + +```rust +use serde_json::json; + +let tools = vec![Tool { + name: "get_weather".to_string(), + description: "Get current weather for a location".to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City name" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature unit" + } + }, + "required": ["location"] + }), +}]; + +let request = CompletionRequest { + messages: vec![ + Message { + role: MessageRole::User, + content: "What's the weather in Paris?".to_string(), + }, + ], + max_tokens: Some(500), + temperature: Some(0.5), + stream: false, + tools: Some(tools), +}; + +let response = provider.complete(request).await?; +println!("Response: {}", response.content); +``` + +### Custom Ollama Host + +```rust +// Connect to remote Ollama instance +let provider = OllamaProvider::new( + "llama3.2".to_string(), + 
Some("http://192.168.1.100:11434".to_string()), + None, + None, +)?; +``` + +### Fetch Available Models + +```rust +// Discover what models are available +let models = provider.fetch_available_models().await?; +println!("Available models:"); +for model in models { + println!(" - {}", model); +} +``` + +## Supported Models + +The provider works with any Ollama model, including: + +- **llama3.2** (1B, 3B) - Meta's latest Llama models +- **llama3.1** (8B, 70B, 405B) - Previous generation +- **qwen2.5** (7B, 14B, 32B) - Alibaba's Qwen models +- **mistral** - Mistral AI models +- **mixtral** - Mixture of experts model +- **phi3** - Microsoft's Phi-3 +- **gemma2** - Google's Gemma 2 + +## Configuration + +### Constructor Parameters + +```rust +OllamaProvider::new( + model: String, // Model name (e.g., "llama3.2") + base_url: Option, // Ollama API URL (default: http://localhost:11434) + max_tokens: Option, // Maximum tokens to generate (optional) + temperature: Option, // Sampling temperature (default: 0.7) +) +``` + +### Request Options + +```rust +CompletionRequest { + messages: Vec, // Conversation history + max_tokens: Option, // Override provider's max_tokens + temperature: Option, // Override provider's temperature + stream: bool, // Enable streaming responses + tools: Option>, // Tools for function calling +} +``` + +## Comparison with Other Providers + +| Feature | Ollama | OpenAI | Anthropic | Databricks | +|---------|--------|--------|-----------|------------| +| Local Execution | ✅ | ❌ | ❌ | ❌ | +| Authentication | None | API Key | API Key | OAuth/Token | +| Tool Calling | ✅ | ✅ | ✅ | ✅ | +| Streaming | ✅ | ✅ | ✅ | ✅ | +| Cost | Free | Paid | Paid | Paid | +| Privacy | High | Low | Low | Medium | + +## Implementation Details + +### API Endpoints + +- **Chat Completion**: `POST /api/chat` +- **Model List**: `GET /api/tags` + +### Response Format + +Ollama uses a simple JSON-per-line streaming format: + +```json +{"message":{"role":"assistant","content":"Hello"},"done":false} +{"message":{"role":"assistant","content":" there"},"done":false} +{"done":true,"prompt_eval_count":10,"eval_count":20} +``` + +### Tool Call Format + +Tool calls are returned in the message structure: + +```json +{ + "message": { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "function": { + "name": "get_weather", + "arguments": {"location": "Paris", "unit": "celsius"} + } + } + ] + }, + "done": true +} +``` + +## Troubleshooting + +### Connection Errors + +If you see connection errors, ensure Ollama is running: + +```bash +# Check if Ollama is running +curl http://localhost:11434/api/version + +# Start Ollama (if needed) +ollama serve +``` + +### Model Not Found + +Pull the model first: + +```bash +ollama pull llama3.2 +ollama list # Check available models +``` + +### Performance Issues + +- Use smaller models (1B, 3B) for faster responses +- Reduce `max_tokens` to limit generation length +- Enable GPU acceleration if available +- Consider quantized models (e.g., `llama3.2:3b-q4_0`) + +## Testing + +Run the included tests: + +```bash +cargo test --package g3-providers ollama +``` + +All tests should pass: +``` +running 4 tests +test ollama::tests::test_custom_base_url ... ok +test ollama::tests::test_message_conversion ... ok +test ollama::tests::test_provider_creation ... ok +test ollama::tests::test_tool_conversion ... ok +``` + +## Architecture + +The provider follows the same architecture as other g3 providers: + +1. **OllamaProvider**: Main struct implementing `LLMProvider` trait +2. 
+
+## Contributing
+
+The provider is part of the g3-providers crate. To contribute:
+
+1. Add features to `ollama.rs`
+2. Update the tests
+3. Run `cargo test --package g3-providers`
+4. Update this documentation
+
+## License
+
+Same as the g3 project.
diff --git a/config.ollama.example.toml b/config.ollama.example.toml
new file mode 100644
index 0000000..b1b5170
--- /dev/null
+++ b/config.ollama.example.toml
@@ -0,0 +1,26 @@
+# Example G3 configuration using the Ollama provider
+# Copy this to ~/.config/g3/config.toml or ./g3.toml to use it
+
+[providers]
+default_provider = "ollama"
+
+# Ollama configuration (local LLM)
+[providers.ollama]
+model = "llama3.2"                    # or qwen2.5, mistral, etc.
+# base_url = "http://localhost:11434" # Optional, defaults to localhost
+# max_tokens = 2048                   # Optional
+# temperature = 0.7                   # Optional
+
+# Optional: specify different providers for coach and player in autonomous mode
+# coach = "ollama"  # Provider for the coach (code reviewer)
+# player = "ollama" # Provider for the player (code implementer)
+
+[agent]
+max_context_length = 8192
+enable_streaming = true
+timeout_seconds = 60
+
+[computer_control]
+enabled = false # Set to true to enable computer control (requires OS permissions)
+require_confirmation = true
+max_actions_per_second = 5
diff --git a/crates/g3-config/src/lib.rs b/crates/g3-config/src/lib.rs
index d9f0602..befd3c0 100644
--- a/crates/g3-config/src/lib.rs
+++ b/crates/g3-config/src/lib.rs
@@ -17,6 +17,7 @@ pub struct ProvidersConfig {
     pub anthropic: Option<AnthropicConfig>,
     pub databricks: Option<DatabricksConfig>,
     pub embedded: Option<EmbeddedConfig>,
+    pub ollama: Option<OllamaConfig>,
     pub default_provider: String,
     pub coach: Option<String>,  // Provider to use for coach in autonomous mode
     pub player: Option<String>, // Provider to use for player in autonomous mode
@@ -60,6 +61,14 @@ pub struct EmbeddedConfig {
     pub threads: Option<usize>, // Number of CPU threads to use
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct OllamaConfig {
+    pub model: String,
+    pub base_url: Option<String>, // Default: http://localhost:11434
+    pub max_tokens: Option<u32>,
+    pub temperature: Option<f32>,
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct AgentConfig {
     pub max_context_length: usize,
@@ -128,6 +137,7 @@ impl Default for Config {
                 use_oauth: Some(true),
             }),
             embedded: None,
+            ollama: None,
             default_provider: "databricks".to_string(),
             coach: None,  // Will use default_provider if not specified
             player: None, // Will use default_provider if not specified
@@ -244,6 +254,7 @@ impl Config {
                 gpu_layers: Some(32),
                 threads: Some(8),
             }),
+            ollama: None,
             default_provider: "embedded".to_string(),
             coach: None,  // Will use default_provider if not specified
             player: None, // Will use default_provider if not specified
diff --git a/crates/g3-core/src/lib.rs b/crates/g3-core/src/lib.rs
index 43c4548..d6d4762 100644
--- a/crates/g3-core/src/lib.rs
+++ b/crates/g3-core/src/lib.rs
@@ -856,6 +856,19 @@ impl Agent {
             }
         }
 
+        // Register the Ollama provider if it is configured and selected for use
+        if let Some(ollama_config) = &config.providers.ollama {
+            if providers_to_register.contains(&"ollama".to_string()) {
+                let ollama_provider = g3_providers::OllamaProvider::new(
+                    ollama_config.model.clone(),
+                    ollama_config.base_url.clone(),
+                    ollama_config.max_tokens,
+                    ollama_config.temperature,
+                )?;
+
providers.register(ollama_provider); + } + } + // Set default provider debug!( "Setting default provider to: {}", @@ -962,6 +975,26 @@ impl Agent { 16384 // Conservative default for other Databricks models } } + "ollama" => { + // Ollama model context windows based on model name + if model_name.contains("qwen") { + 32768 // Qwen2.5 supports 32k context + } else if model_name.contains("llama3") || model_name.contains("llama-3") { + if model_name.contains("3.2") || model_name.contains("3.1") { + 128000 // Llama 3.1/3.2 support 128k context + } else { + 8192 // Llama 3.0 + } + } else if model_name.contains("mistral") || model_name.contains("mixtral") { + 32768 // Mistral/Mixtral support 32k + } else if model_name.contains("gemma") { + 8192 // Gemma 2 + } else if model_name.contains("phi") { + 4096 // Phi-3 + } else { + 8192 // Conservative default for Ollama models + } + } _ => config.agent.max_context_length as u32, }; diff --git a/crates/g3-providers/src/lib.rs b/crates/g3-providers/src/lib.rs index 51ea55a..d68a09f 100644 --- a/crates/g3-providers/src/lib.rs +++ b/crates/g3-providers/src/lib.rs @@ -88,11 +88,13 @@ pub mod anthropic; pub mod databricks; pub mod embedded; pub mod oauth; +pub mod ollama; pub mod openai; pub use anthropic::AnthropicProvider; pub use databricks::DatabricksProvider; pub use embedded::EmbeddedProvider; +pub use ollama::OllamaProvider; pub use openai::OpenAIProvider; /// Provider registry for managing multiple LLM providers diff --git a/crates/g3-providers/src/ollama.rs b/crates/g3-providers/src/ollama.rs new file mode 100644 index 0000000..1807528 --- /dev/null +++ b/crates/g3-providers/src/ollama.rs @@ -0,0 +1,702 @@ +//! Ollama LLM provider implementation for the g3-providers crate. +//! +//! This module provides an implementation of the `LLMProvider` trait for Ollama, +//! supporting both completion and streaming modes with native tool calling. +//! +//! # Features +//! +//! - Support for any Ollama model (llama3.2, mistral, qwen, etc.) +//! - Both completion and streaming response modes +//! - Native tool calling support for compatible models +//! - Configurable base URL (defaults to http://localhost:11434) +//! - Simple configuration with no authentication required +//! +//! # Usage +//! +//! ```rust,no_run +//! use g3_providers::{OllamaProvider, LLMProvider, CompletionRequest, Message, MessageRole}; +//! +//! #[tokio::main] +//! async fn main() -> anyhow::Result<()> { +//! // Create the provider with default settings (localhost:11434) +//! let provider = OllamaProvider::new( +//! "llama3.2".to_string(), +//! None, // Optional: base_url +//! None, // Optional: max tokens +//! None, // Optional: temperature +//! )?; +//! +//! // Create a completion request +//! let request = CompletionRequest { +//! messages: vec![ +//! Message { +//! role: MessageRole::User, +//! content: "Hello! How are you?".to_string(), +//! }, +//! ], +//! max_tokens: Some(1000), +//! temperature: Some(0.7), +//! stream: false, +//! tools: None, +//! }; +//! +//! // Get a completion +//! let response = provider.complete(request).await?; +//! println!("Response: {}", response.content); +//! +//! Ok(()) +//! } +//! 
```
+
+use anyhow::{anyhow, Result};
+use bytes::Bytes;
+use futures_util::stream::StreamExt;
+use reqwest::Client;
+use serde::{Deserialize, Serialize};
+use std::time::Duration;
+use tokio::sync::mpsc;
+use tokio_stream::wrappers::ReceiverStream;
+use tracing::{debug, error, info, warn};
+
+use crate::{
+    CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
+    MessageRole, Tool, ToolCall, Usage,
+};
+
+const DEFAULT_BASE_URL: &str = "http://localhost:11434";
+const DEFAULT_TIMEOUT_SECS: u64 = 600;
+
+pub const OLLAMA_DEFAULT_MODEL: &str = "llama3.2";
+pub const OLLAMA_KNOWN_MODELS: &[&str] = &[
+    "llama3.2",
+    "llama3.2:1b",
+    "llama3.2:3b",
+    "llama3.1",
+    "llama3.1:8b",
+    "llama3.1:70b",
+    "mistral",
+    "mistral-nemo",
+    "mixtral",
+    "qwen2.5",
+    "qwen2.5:7b",
+    "qwen2.5:14b",
+    "qwen2.5:32b",
+    "phi3",
+    "gemma2",
+];
+
+#[derive(Debug, Clone)]
+pub struct OllamaProvider {
+    client: Client,
+    base_url: String,
+    model: String,
+    max_tokens: Option<u32>,
+    temperature: f32,
+}
+
+impl OllamaProvider {
+    pub fn new(
+        model: String,
+        base_url: Option<String>,
+        max_tokens: Option<u32>,
+        temperature: Option<f32>,
+    ) -> Result<Self> {
+        let client = Client::builder()
+            .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
+            .build()
+            .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
+
+        let base_url = base_url
+            .unwrap_or_else(|| DEFAULT_BASE_URL.to_string())
+            .trim_end_matches('/')
+            .to_string();
+
+        info!(
+            "Initialized Ollama provider with model: {} at {}",
+            model, base_url
+        );
+
+        Ok(Self {
+            client,
+            base_url,
+            model,
+            max_tokens,
+            temperature: temperature.unwrap_or(0.7),
+        })
+    }
+
+    fn convert_tools(&self, tools: &[Tool]) -> Vec<OllamaTool> {
+        tools
+            .iter()
+            .map(|tool| OllamaTool {
+                r#type: "function".to_string(),
+                function: OllamaFunction {
+                    name: tool.name.clone(),
+                    description: tool.description.clone(),
+                    parameters: tool.input_schema.clone(),
+                },
+            })
+            .collect()
+    }
+
+    fn convert_messages(&self, messages: &[Message]) -> Result<Vec<OllamaMessage>> {
+        let mut ollama_messages = Vec::new();
+
+        for message in messages {
+            let role = match message.role {
+                MessageRole::System => "system",
+                MessageRole::User => "user",
+                MessageRole::Assistant => "assistant",
+            };
+
+            ollama_messages.push(OllamaMessage {
+                role: role.to_string(),
+                content: message.content.clone(),
+                tool_calls: None, // Only used in responses
+            });
+        }
+
+        if ollama_messages.is_empty() {
+            return Err(anyhow!("At least one message is required"));
+        }
+
+        Ok(ollama_messages)
+    }
+
+    fn create_request_body(
+        &self,
+        messages: &[Message],
+        tools: Option<&[Tool]>,
+        streaming: bool,
+        max_tokens: Option<u32>,
+        temperature: f32,
+    ) -> Result<OllamaRequest> {
+        let ollama_messages = self.convert_messages(messages)?;
+        let ollama_tools = tools.map(|t| self.convert_tools(t));
+
+        let mut options = OllamaOptions {
+            temperature,
+            num_predict: max_tokens,
+        };
+
+        // If max_tokens is provided, use it; otherwise use the instance default
+        if max_tokens.is_none() {
+            options.num_predict = self.max_tokens;
+        }
+
+        let request = OllamaRequest {
+            model: self.model.clone(),
+            messages: ollama_messages,
+            tools: ollama_tools,
+            stream: streaming,
+            options,
+        };
+
+        Ok(request)
+    }
+
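+    /// Parse Ollama's JSON-per-line streaming body and forward chunks over `tx`.
+    ///
+    /// Bytes are buffered until they form valid UTF-8 (a multi-byte character
+    /// can be split across network chunks), then split on newlines; each
+    /// complete line is parsed as one JSON object. Returns the final token
+    /// usage, if Ollama reported one.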
+    async fn parse_streaming_response(
+        &self,
+        mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
+        tx: mpsc::Sender<Result<CompletionChunk>>,
+    ) -> Option<Usage> {
+        let mut buffer = String::new();
+        let mut accumulated_usage: Option<Usage> = None;
+        let mut current_tool_calls: Vec<OllamaToolCall> = Vec::new();
+        let mut byte_buffer = Vec::new();
+
+        while let Some(chunk_result) = stream.next().await {
+            match chunk_result {
+                Ok(chunk) => {
+                    // Append new bytes to our buffer
+                    byte_buffer.extend_from_slice(&chunk);
+
+                    // Try to convert the entire buffer to UTF-8
+                    let chunk_str = match std::str::from_utf8(&byte_buffer) {
+                        Ok(s) => {
+                            let result = s.to_string();
+                            byte_buffer.clear();
+                            result
+                        }
+                        Err(e) => {
+                            let valid_up_to = e.valid_up_to();
+                            if valid_up_to > 0 {
+                                let valid_bytes =
+                                    byte_buffer.drain(..valid_up_to).collect::<Vec<_>>();
+                                std::str::from_utf8(&valid_bytes).unwrap().to_string()
+                            } else {
+                                continue;
+                            }
+                        }
+                    };
+
+                    buffer.push_str(&chunk_str);
+
+                    // Process complete lines
+                    while let Some(line_end) = buffer.find('\n') {
+                        let line = buffer[..line_end].trim().to_string();
+                        buffer.drain(..line_end + 1);
+
+                        if line.is_empty() {
+                            continue;
+                        }
+
+                        // Ollama streaming sends JSON objects per line
+                        match serde_json::from_str::<OllamaStreamChunk>(&line) {
+                            Ok(chunk) => {
+                                // Handle text content
+                                if let Some(message) = &chunk.message {
+                                    let content = &message.content;
+                                    if !content.is_empty() {
+                                        debug!("Sending text chunk: '{}'", content);
+                                        let chunk = CompletionChunk {
+                                            content: content.clone(),
+                                            finished: false,
+                                            usage: None,
+                                            tool_calls: None,
+                                        };
+                                        if tx.send(Ok(chunk)).await.is_err() {
+                                            debug!("Receiver dropped, stopping stream");
+                                            return accumulated_usage;
+                                        }
+                                    }
+
+                                    // Handle tool calls
+                                    if let Some(tool_calls) = &message.tool_calls {
+                                        current_tool_calls.extend(tool_calls.clone());
+                                    }
+                                }
+
+                                // Check if stream is done
+                                if chunk.done.unwrap_or(false) {
+                                    debug!("Stream completed");
+
+                                    // Update usage if available
+                                    if let Some(eval_count) = chunk.eval_count {
+                                        accumulated_usage = Some(Usage {
+                                            prompt_tokens: chunk.prompt_eval_count.unwrap_or(0),
+                                            completion_tokens: eval_count,
+                                            total_tokens: chunk.prompt_eval_count.unwrap_or(0)
+                                                + eval_count,
+                                        });
+                                    }
+
+                                    // Send final chunk with tool calls if any
+                                    let final_tool_calls: Vec<ToolCall> = current_tool_calls
+                                        .iter()
+                                        .map(|tc| ToolCall {
+                                            id: tc.function.name.clone(), // Ollama doesn't provide IDs
+                                            tool: tc.function.name.clone(),
+                                            args: tc.function.arguments.clone(),
+                                        })
+                                        .collect();
+
+                                    let final_chunk = CompletionChunk {
+                                        content: String::new(),
+                                        finished: true,
+                                        usage: accumulated_usage.clone(),
+                                        tool_calls: if final_tool_calls.is_empty() {
+                                            None
+                                        } else {
+                                            Some(final_tool_calls)
+                                        },
+                                    };
+                                    if tx.send(Ok(final_chunk)).await.is_err() {
+                                        debug!("Receiver dropped, stopping stream");
+                                    }
+                                    return accumulated_usage;
+                                }
+                            }
+                            Err(e) => {
+                                debug!("Failed to parse Ollama stream chunk: {} - Line: {}", e, line);
+                                // Don't error out, just continue
+                            }
+                        }
+                    }
+                }
+                Err(e) => {
+                    error!("Stream error: {}", e);
+                    let error_msg = e.to_string();
+                    if error_msg.contains("unexpected EOF") || error_msg.contains("connection") {
+                        warn!("Connection terminated unexpectedly, treating as end of stream");
+                        break;
+                    } else {
+                        let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
+                    }
+                    return accumulated_usage;
+                }
+            }
+        }
+
+        // Send final chunk if we haven't already
+        let final_tool_calls: Vec<ToolCall> = current_tool_calls
+            .iter()
+            .map(|tc| ToolCall {
+                id: tc.function.name.clone(),
+                tool: tc.function.name.clone(),
+                args: tc.function.arguments.clone(),
+            })
+            .collect();
+
+        let final_chunk = CompletionChunk {
+            content: String::new(),
+            finished: true,
+            usage: accumulated_usage.clone(),
+            tool_calls: if final_tool_calls.is_empty() {
+                None
+            } else {
+                Some(final_tool_calls)
+            },
+        };
+        let _ = tx.send(Ok(final_chunk)).await;
+        accumulated_usage
+    }
+
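+    // Stream contract: consumers normally receive zero or more content chunks
+    // with `finished: false`, then one final chunk with `finished: true`
+    // carrying usage statistics and any accumulated tool calls; error paths
+    // may instead surface an `Err` on the channel.
+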
+    /// Fetch available models from the Ollama instance
+    pub async fn fetch_available_models(&self) -> Result<Vec<String>> {
+        let response = self
+            .client
+            .get(format!("{}/api/tags", self.base_url))
+            .send()
+            .await
+            .map_err(|e| anyhow!("Failed to fetch Ollama models: {}", e))?;
+
+        if !response.status().is_success() {
+            let status = response.status();
+            let error_text = response
+                .text()
+                .await
+                .unwrap_or_else(|_| "Unknown error".to_string());
+            return Err(anyhow!(
+                "Failed to fetch Ollama models: {} - {}",
+                status,
+                error_text
+            ));
+        }
+
+        let json: serde_json::Value = response.json().await?;
+        let models = json
+            .get("models")
+            .and_then(|v| v.as_array())
+            .ok_or_else(|| anyhow!("Unexpected response format: missing 'models' array"))?;
+
+        let model_names: Vec<String> = models
+            .iter()
+            .filter_map(|model| model.get("name").and_then(|n| n.as_str()).map(String::from))
+            .collect();
+
+        debug!("Found {} models in Ollama", model_names.len());
+        Ok(model_names)
+    }
+}
+
+#[async_trait::async_trait]
+impl LLMProvider for OllamaProvider {
+    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
+        debug!(
+            "Processing Ollama completion request with {} messages",
+            request.messages.len()
+        );
+
+        let max_tokens = request.max_tokens.or(self.max_tokens);
+        let temperature = request.temperature.unwrap_or(self.temperature);
+
+        let request_body = self.create_request_body(
+            &request.messages,
+            request.tools.as_deref(),
+            false,
+            max_tokens,
+            temperature,
+        )?;
+
+        debug!(
+            "Sending request to Ollama API: model={}, temperature={}",
+            self.model, request_body.options.temperature
+        );
+
+        let response = self
+            .client
+            .post(format!("{}/api/chat", self.base_url))
+            .json(&request_body)
+            .send()
+            .await
+            .map_err(|e| anyhow!("Failed to send request to Ollama API: {}", e))?;
+
+        let status = response.status();
+        if !status.is_success() {
+            let error_text = response
+                .text()
+                .await
+                .unwrap_or_else(|_| "Unknown error".to_string());
+            return Err(anyhow!("Ollama API error {}: {}", status, error_text));
+        }
+
+        let response_text = response.text().await?;
+        debug!("Raw Ollama API response: {}", response_text);
+
+        let ollama_response: OllamaResponse =
+            serde_json::from_str(&response_text).map_err(|e| {
+                anyhow!(
+                    "Failed to parse Ollama response: {} - Response: {}",
+                    e,
+                    response_text
+                )
+            })?;
+
+        let content = ollama_response.message.content.clone();
+
+        let usage = Usage {
+            prompt_tokens: ollama_response.prompt_eval_count.unwrap_or(0),
+            completion_tokens: ollama_response.eval_count.unwrap_or(0),
+            total_tokens: ollama_response.prompt_eval_count.unwrap_or(0)
+                + ollama_response.eval_count.unwrap_or(0),
+        };
+
+        debug!(
+            "Ollama completion successful: {} tokens generated",
+            usage.completion_tokens
+        );
+
+        Ok(CompletionResponse {
+            content,
+            usage,
+            model: self.model.clone(),
+        })
+    }
+
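+    /// Stream a completion from Ollama.
+    ///
+    /// The response body is handed to `parse_streaming_response` on a spawned
+    /// task; parsed chunks arrive through a bounded mpsc channel (capacity
+    /// 100), which applies backpressure if the consumer falls behind.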
+    async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
+        debug!(
+            "Processing Ollama streaming request with {} messages",
+            request.messages.len()
+        );
+
+        if let Some(ref tools) = request.tools {
+            debug!("Request has {} tools", tools.len());
+            for tool in tools.iter().take(5) {
+                debug!("  Tool: {}", tool.name);
+            }
+        }
+
+        let max_tokens = request.max_tokens.or(self.max_tokens);
+        let temperature = request.temperature.unwrap_or(self.temperature);
+
+        let request_body = self.create_request_body(
+            &request.messages,
+            request.tools.as_deref(),
+            true,
+            max_tokens,
+            temperature,
+        )?;
+
+        debug!(
+            "Sending streaming request to Ollama API: model={}, temperature={}",
+            self.model, request_body.options.temperature
+        );
+
+        let response = self
+            .client
+            .post(format!("{}/api/chat", self.base_url))
+            .json(&request_body)
+            .send()
+            .await
+            .map_err(|e| anyhow!("Failed to send streaming request to Ollama API: {}", e))?;
+
+        let status = response.status();
+        if !status.is_success() {
+            let error_text = response
+                .text()
+                .await
+                .unwrap_or_else(|_| "Unknown error".to_string());
+            return Err(anyhow!("Ollama API error {}: {}", status, error_text));
+        }
+
+        let stream = response.bytes_stream();
+        let (tx, rx) = mpsc::channel(100);
+
+        let provider = self.clone();
+        tokio::spawn(async move {
+            provider.parse_streaming_response(stream, tx).await;
+        });
+
+        Ok(ReceiverStream::new(rx))
+    }
+
+    fn name(&self) -> &str {
+        "ollama"
+    }
+
+    fn model(&self) -> &str {
+        &self.model
+    }
+
+    fn has_native_tool_calling(&self) -> bool {
+        // Most modern Ollama models support tool calling
+        // Models like llama3.2, qwen2.5, mistral, etc. have good tool support
+        true
+    }
+}
+
+// Ollama API request/response structures
+
+#[derive(Debug, Serialize)]
+struct OllamaRequest {
+    model: String,
+    messages: Vec<OllamaMessage>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    tools: Option<Vec<OllamaTool>>,
+    stream: bool,
+    options: OllamaOptions,
+}
+
+#[derive(Debug, Serialize)]
+struct OllamaOptions {
+    temperature: f32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    num_predict: Option<u32>, // Ollama's equivalent of max_tokens
+}
+
+#[derive(Debug, Serialize)]
+struct OllamaTool {
+    r#type: String,
+    function: OllamaFunction,
+}
+
+#[derive(Debug, Serialize)]
+struct OllamaFunction {
+    name: String,
+    description: String,
+    parameters: serde_json::Value,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+struct OllamaMessage {
+    role: String,
+    content: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    tool_calls: Option<Vec<OllamaToolCall>>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+struct OllamaToolCall {
+    function: OllamaToolCallFunction,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+struct OllamaToolCallFunction {
+    name: String,
+    arguments: serde_json::Value,
+}
+
+#[derive(Debug, Deserialize)]
+struct OllamaResponse {
+    message: OllamaMessage,
+    #[allow(dead_code)]
+    done: bool,
+    #[allow(dead_code)]
+    total_duration: Option<u64>,
+    #[allow(dead_code)]
+    load_duration: Option<u64>,
+    prompt_eval_count: Option<u32>,
+    eval_count: Option<u32>,
+}
+
+#[derive(Debug, Deserialize)]
+struct OllamaStreamChunk {
+    message: Option<OllamaMessage>,
+    done: Option<bool>,
+    prompt_eval_count: Option<u32>,
+    eval_count: Option<u32>,
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_provider_creation() {
+        let provider = OllamaProvider::new(
+            "llama3.2".to_string(),
+            None,
+            Some(1000),
+            Some(0.7),
+        )
+        .unwrap();
+
+        assert_eq!(provider.model(), "llama3.2");
+        assert_eq!(provider.name(), "ollama");
+        assert!(provider.has_native_tool_calling());
+    }
+
+    #[test]
+    fn test_message_conversion() {
+        let provider = OllamaProvider::new(
+            "llama3.2".to_string(),
+            None,
+            None,
+            None,
+        )
+        .unwrap();
+
+        let messages = vec![
+            Message {
+                role: MessageRole::System,
+                content: "You are a helpful assistant.".to_string(),
+            },
+            Message {
+                role: MessageRole::User,
+                content: "Hello!".to_string(),
+            },
+        ];
+
+        let ollama_messages = provider.convert_messages(&messages).unwrap();
+
+        assert_eq!(ollama_messages.len(), 2);
+        assert_eq!(ollama_messages[0].role, "system");
+        assert_eq!(ollama_messages[1].role, "user");
+    }
+
+    #[test]
+    fn test_tool_conversion() {
+        let provider = OllamaProvider::new(
+            "llama3.2".to_string(),
+            None,
+            None,
+            None,
+        )
+        .unwrap();
+
+ let tools = vec![Tool { + name: "get_weather".to_string(), + description: "Get the current weather".to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state" + } + }, + "required": ["location"] + }), + }]; + + let ollama_tools = provider.convert_tools(&tools); + + assert_eq!(ollama_tools.len(), 1); + assert_eq!(ollama_tools[0].r#type, "function"); + assert_eq!(ollama_tools[0].function.name, "get_weather"); + } + + #[test] + fn test_custom_base_url() { + let provider = OllamaProvider::new( + "llama3.2".to_string(), + Some("http://custom:11434".to_string()), + None, + None, + ) + .unwrap(); + + assert_eq!(provider.base_url, "http://custom:11434"); + } +}