Merge branch 'main' into micn/fix-anthropic-1p

* main:
  fix panic in CLI parser
  coach/player provider split + add OpenAI
Michael Neale
2025-10-22 15:01:18 +11:00
12 changed files with 884 additions and 141 deletions

Cargo.lock (generated)
View File

@@ -1316,6 +1316,7 @@ dependencies = [
"dirs 5.0.1",
"serde",
"shellexpand",
"tempfile",
"thiserror 1.0.69",
"toml",
]

View File

@@ -0,0 +1,24 @@
[providers]
default_provider = "databricks"
# Specify different providers for coach and player in autonomous mode
coach = "databricks" # Provider for coach (code reviewer) - can be more powerful/expensive
player = "anthropic" # Provider for player (code implementer) - can be faster/cheaper
[providers.databricks]
host = "https://your-workspace.cloud.databricks.com"
# token = "your-databricks-token" # Optional - will use OAuth if not provided
model = "databricks-claude-sonnet-4"
max_tokens = 4096
temperature = 0.1
use_oauth = true
[providers.anthropic]
api_key = "your-anthropic-api-key"
model = "claude-3-haiku-20240307" # Using a faster model for player
max_tokens = 4096
temperature = 0.3 # Slightly higher temperature for more creative implementations
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60

View File

@@ -1,5 +1,10 @@
[providers]
default_provider = "databricks"
# Optional: Specify different providers for coach and player in autonomous mode
# If not specified, will use default_provider for both
# coach = "databricks" # Provider for coach (code reviewer)
# player = "anthropic" # Provider for player (code implementer)
# Note: Make sure the specified providers are configured below
[providers.databricks]
host = "https://your-workspace.cloud.databricks.com"

View File

@@ -103,11 +103,14 @@ fn extract_coach_feedback_from_logs(
coach_result: &g3_core::TaskResult,
coach_agent: &g3_core::Agent<ConsoleUiWriter>,
output: &SimpleOutput,
-) -> String {
+) -> Result<String> {
// CORRECT APPROACH: Get the session ID from the current coach agent
// and read its specific log file directly
// Get the coach agent's session ID
let session_id = coach_agent
.get_session_id()
-.expect("Coach agent has no session ID");
+.ok_or_else(|| anyhow::anyhow!("Coach agent has no session ID"))?;
// Construct the log file path for this specific coach session
let logs_dir = std::path::Path::new("logs");
@@ -120,75 +123,15 @@ fn extract_coach_feedback_from_logs(
if let Some(context_window) = log_json.get("context_window") {
if let Some(conversation_history) = context_window.get("conversation_history") {
if let Some(messages) = conversation_history.as_array() {
-// Look for the last assistant message (regardless of tool used)
-for message in messages.iter().rev() {
-if let Some(role) = message.get("role") {
-if role.as_str() == Some("assistant") {
-if let Some(content) = message.get("content") {
-if let Some(content_str) = content.as_str() {
-// First, check if this is plain text feedback (no tool call)
-// This happens when the coach returns final feedback directly
-if !content_str.contains("{\"tool\"") {
-let trimmed = content_str.trim();
-if !trimmed.is_empty() {
-output.print(&format!(
-"✅ Extracted coach feedback from session: {} ({} chars) [plain text]",
-session_id,
-trimmed.len()
-));
-return trimmed.to_string();
-}
-}
-// Look for ANY tool call in the message
-// Pattern: {"tool": "...", "args": {...}}
-if let Some(tool_start) = content_str.find("{\"tool\"") {
-let json_part = &content_str[tool_start..];
-// Find the end of the JSON object
-if let Some(json_end) = find_json_end(json_part) {
-let json_str = &json_part[..json_end];
-if let Ok(tool_call) = serde_json::from_str::<serde_json::Value>(json_str) {
-if let Some(args) = tool_call.get("args") {
-// Try to extract feedback from different possible fields
-let feedback = if let Some(summary) = args.get("summary") {
-// final_output tool uses "summary"
-summary.as_str().map(|s| s.to_string())
-} else if let Some(content) = args.get("content") {
-// todo_write and other tools might use "content"
-content.as_str().map(|s| s.to_string())
-} else {
-// Fallback: use the entire args as JSON string
-Some(serde_json::to_string_pretty(args).unwrap_or_default())
-};
-if let Some(feedback_str) = feedback {
-if !feedback_str.trim().is_empty() {
-output.print(&format!(
-"✅ Extracted coach feedback from session: {} ({} chars)",
-session_id,
-feedback_str.len()
-));
-// Validate feedback length
-if feedback_str.len() < 80 && !feedback_str.contains("IMPLEMENTATION_APPROVED") {
-panic!(
-"Coach feedback is too short ({} chars): '{}'",
-feedback_str.len(),
-feedback_str
-);
-}
-return feedback_str;
-}
-}
-}
-}
-}
-}
-}
-}
+// Simply get the last message content - this is the coach's final feedback
+if let Some(last_message) = messages.last() {
+if let Some(content) = last_message.get("content") {
+if let Some(content_str) = content.as_str() {
+output.print(&format!(
+"✅ Extracted coach feedback from session: {}",
+session_id
+));
+return Ok(content_str.to_string());
}
}
}
@@ -213,35 +156,6 @@ fn extract_coach_feedback_from_logs(
);
}
-/// Helper function to find the end of a JSON object using brace counting
-fn find_json_end(json_str: &str) -> Option<usize> {
-let mut depth = 0;
-let mut in_string = false;
-let mut escape_next = false;
-for (i, ch) in json_str.char_indices() {
-if escape_next {
-escape_next = false;
-continue;
-}
-match ch {
-'\\' if in_string => escape_next = true,
-'"' => in_string = !in_string,
-'{' if !in_string => depth += 1,
-'}' if !in_string => {
-depth -= 1;
-if depth == 0 {
-return Some(i + 1);
-}
-}
-_ => {}
-}
-}
-None
-}
use clap::Parser;
use g3_config::Config;
use g3_core::{project::Project, ui_writer::UiWriter, Agent};
@@ -321,10 +235,6 @@ pub struct Cli {
/// Disable log file creation (no logs/ directory or session logs)
#[arg(long)]
pub quiet: bool,
-/// Enable WebDriver tools for browser automation (Safari)
-#[arg(long)]
-pub webdriver: bool,
}
pub async fn run() -> Result<()> {
@@ -413,17 +323,12 @@ pub async fn run() -> Result<()> {
}
// Load configuration with CLI overrides
-let mut config = Config::load_with_overrides(
+let config = Config::load_with_overrides(
cli.config.as_deref(),
cli.provider.clone(),
cli.model.clone(),
)?;
-// Override webdriver setting from CLI flag
-if cli.webdriver {
-config.webdriver.enabled = true;
-}
// Validate provider if specified
if let Some(ref provider) = cli.provider {
let valid_providers = ["anthropic", "databricks", "embedded", "openai"];
@@ -1358,10 +1263,6 @@ async fn run_autonomous(
loop {
let turn_start_time = Instant::now();
let turn_start_tokens = agent.get_context_window().used_tokens;
-// Reset filter suppression state at the start of each turn
-g3_core::fixed_filter_json::reset_fixed_json_tool_state();
// Skip player turn if it's the first turn and implementation files exist
if !(turn == 1 && skip_first_player) {
output.print(&format!(
@@ -1522,14 +1423,15 @@ async fn run_autonomous(
// Create a new agent instance for coach mode to ensure fresh context
// Use the same config with overrides that was passed to the player agent
-let config = agent.get_config().clone();
+let base_config = agent.get_config().clone();
+let coach_config = base_config.for_coach()?;
// Reset filter suppression state before creating coach agent
g3_core::fixed_filter_json::reset_fixed_json_tool_state();
let ui_writer = ConsoleUiWriter::new();
let mut coach_agent =
-Agent::new_autonomous_with_readme_and_quiet(config, ui_writer, None, quiet).await?;
+Agent::new_autonomous_with_readme_and_quiet(coach_config, ui_writer, None, quiet).await?;
// Ensure coach agent is also in the workspace directory
project.enter_workspace()?;
@@ -1677,7 +1579,7 @@ Remember: Be clear in your review and concise in your feedback. APPROVE if the i
// Extract the complete coach feedback from final_output
let coach_feedback_text =
-extract_coach_feedback_from_logs(&coach_result, &coach_agent, &output);
+extract_coach_feedback_from_logs(&coach_result, &coach_agent, &output)?;
// Log the size of the feedback for debugging
info!(
@@ -1704,15 +1606,6 @@ Remember: Be clear in your review and concise in your feedback. APPROVE if the i
output.print_smart(&format!("Coach feedback:\n{}", coach_feedback_text));
-// Record turn metrics before checking for approval or max turns
-let turn_duration = turn_start_time.elapsed();
-let turn_tokens = agent.get_context_window().used_tokens.saturating_sub(turn_start_tokens);
-turn_metrics.push(TurnMetrics {
-turn_number: turn,
-tokens_used: turn_tokens,
-wall_clock_time: turn_duration,
-});
// Check if coach approved the implementation
if coach_result.is_approved() || coach_feedback_text.contains("IMPLEMENTATION_APPROVED") {
output.print("\n=== SESSION COMPLETED - IMPLEMENTATION APPROVED ===");
@@ -1721,7 +1614,6 @@ Remember: Be clear in your review and concise in your feedback. APPROVE if the i
break;
}
-// Increment turn counter after recording metrics but before checking max turns
// Check if we've reached max turns
if turn >= max_turns {
output.print("\n=== SESSION COMPLETED - MAX TURNS REACHED ===");
@@ -1731,7 +1623,14 @@ Remember: Be clear in your review and concise in your feedback. APPROVE if the i
// Store coach feedback for next iteration
coach_feedback = coach_feedback_text;
// Record turn metrics before incrementing
let turn_duration = turn_start_time.elapsed();
let turn_tokens = agent.get_context_window().used_tokens.saturating_sub(turn_start_tokens);
turn_metrics.push(TurnMetrics {
turn_number: turn,
tokens_used: turn_tokens,
wall_clock_time: turn_duration,
});
turn += 1;
output.print("🔄 Coach provided feedback for next iteration");

View File

@@ -12,3 +12,6 @@ thiserror = { workspace = true }
toml = "0.8"
shellexpand = "3.0"
dirs = "5.0"
[dev-dependencies]
tempfile = "3.8"

View File

@@ -17,6 +17,8 @@ pub struct ProvidersConfig {
pub databricks: Option<DatabricksConfig>,
pub embedded: Option<EmbeddedConfig>,
pub default_provider: String,
pub coach: Option<String>, // Provider to use for coach in autonomous mode
pub player: Option<String>, // Provider to use for player in autonomous mode
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -112,6 +114,8 @@ impl Default for Config {
}),
embedded: None,
default_provider: "databricks".to_string(),
coach: None, // Will use default_provider if not specified
player: None, // Will use default_provider if not specified
},
agent: AgentConfig {
max_context_length: 8192,
@@ -224,6 +228,8 @@ impl Config {
threads: Some(8),
}),
default_provider: "embedded".to_string(),
coach: None, // Will use default_provider if not specified
player: None, // Will use default_provider if not specified
},
agent: AgentConfig {
max_context_length: 8192,
@@ -300,4 +306,67 @@ impl Config {
Ok(config)
}
/// Get the provider to use for coach mode in autonomous execution
pub fn get_coach_provider(&self) -> &str {
self.providers.coach
.as_deref()
.unwrap_or(&self.providers.default_provider)
}
/// Get the provider to use for player mode in autonomous execution
pub fn get_player_provider(&self) -> &str {
self.providers.player
.as_deref()
.unwrap_or(&self.providers.default_provider)
}
/// Create a copy of the config with a different default provider
pub fn with_provider_override(&self, provider: &str) -> Result<Self> {
// Validate that the provider is configured
match provider {
"anthropic" if self.providers.anthropic.is_none() => {
return Err(anyhow::anyhow!(
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
provider, provider
));
}
"databricks" if self.providers.databricks.is_none() => {
return Err(anyhow::anyhow!(
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
provider, provider
));
}
"embedded" if self.providers.embedded.is_none() => {
return Err(anyhow::anyhow!(
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
provider, provider
));
}
"openai" if self.providers.openai.is_none() => {
return Err(anyhow::anyhow!(
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
provider, provider
));
}
_ => {} // Provider is configured or unknown (will be caught later)
}
let mut config = self.clone();
config.providers.default_provider = provider.to_string();
Ok(config)
}
/// Create a copy of the config for coach mode in autonomous execution
pub fn for_coach(&self) -> Result<Self> {
self.with_provider_override(self.get_coach_provider())
}
/// Create a copy of the config for player mode in autonomous execution
pub fn for_player(&self) -> Result<Self> {
self.with_provider_override(self.get_player_provider())
}
}
#[cfg(test)]
mod tests;

View File

@@ -0,0 +1,131 @@
#[cfg(test)]
mod tests {
use crate::Config;
use std::fs;
use tempfile::TempDir;
#[test]
fn test_coach_player_providers() {
// Create a temporary directory for the test config
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("test_config.toml");
// Write a test configuration with coach and player providers
let config_content = r#"
[providers]
default_provider = "databricks"
coach = "anthropic"
player = "embedded"
[providers.databricks]
host = "https://test.databricks.com"
token = "test-token"
model = "test-model"
[providers.anthropic]
api_key = "test-key"
model = "claude-3"
[providers.embedded]
model_path = "test.gguf"
model_type = "llama"
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
"#;
fs::write(&config_path, config_content).unwrap();
// Load the configuration
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
// Test that the providers are correctly identified
assert_eq!(config.providers.default_provider, "databricks");
assert_eq!(config.get_coach_provider(), "anthropic");
assert_eq!(config.get_player_provider(), "embedded");
// Test creating coach config
let coach_config = config.for_coach().unwrap();
assert_eq!(coach_config.providers.default_provider, "anthropic");
// Test creating player config
let player_config = config.for_player().unwrap();
assert_eq!(player_config.providers.default_provider, "embedded");
}
#[test]
fn test_coach_player_fallback_to_default() {
// Create a temporary directory for the test config
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("test_config.toml");
// Write a test configuration WITHOUT coach and player providers
let config_content = r#"
[providers]
default_provider = "databricks"
[providers.databricks]
host = "https://test.databricks.com"
token = "test-token"
model = "test-model"
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
"#;
fs::write(&config_path, config_content).unwrap();
// Load the configuration
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
// Test that coach and player fall back to default provider
assert_eq!(config.get_coach_provider(), "databricks");
assert_eq!(config.get_player_provider(), "databricks");
// Test creating coach config (should use default)
let coach_config = config.for_coach().unwrap();
assert_eq!(coach_config.providers.default_provider, "databricks");
// Test creating player config (should use default)
let player_config = config.for_player().unwrap();
assert_eq!(player_config.providers.default_provider, "databricks");
}
#[test]
fn test_invalid_provider_error() {
// Create a temporary directory for the test config
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("test_config.toml");
// Write a test configuration with an unconfigured provider
let config_content = r#"
[providers]
default_provider = "databricks"
coach = "openai" # OpenAI is not configured
[providers.databricks]
host = "https://test.databricks.com"
token = "test-token"
model = "test-model"
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
"#;
fs::write(&config_path, config_content).unwrap();
// Load the configuration
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
// Test that trying to create a coach config with unconfigured provider fails
let result = config.for_coach();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("not configured"));
}
}

View File

@@ -625,13 +625,32 @@ impl<W: UiWriter> Agent<W> {
) -> Result<Self> {
let mut providers = ProviderRegistry::new();
// In autonomous mode, we need to register both coach and player providers
// Otherwise, only register the default provider
let providers_to_register: Vec<String> = if is_autonomous {
let mut providers = vec![config.providers.default_provider.clone()];
if let Some(coach) = &config.providers.coach {
if !providers.contains(coach) {
providers.push(coach.clone());
}
}
if let Some(player) = &config.providers.player {
if !providers.contains(player) {
providers.push(player.clone());
}
}
providers
} else {
vec![config.providers.default_provider.clone()]
};
// Only register providers that are configured AND selected as the default provider
// This prevents unnecessary initialization of heavy providers like embedded models
// Register embedded provider if configured AND it's the default provider
if let Some(embedded_config) = &config.providers.embedded {
-if config.providers.default_provider == "embedded" {
-info!("Initializing embedded provider (selected as default)");
+if providers_to_register.contains(&"embedded".to_string()) {
+info!("Initializing embedded provider");
let embedded_provider = g3_providers::EmbeddedProvider::new(
embedded_config.model_path.clone(),
embedded_config.model_type.clone(),
@@ -643,14 +662,31 @@ impl<W: UiWriter> Agent<W> {
)?;
providers.register(embedded_provider);
} else {
-info!("Embedded provider configured but not selected as default, skipping initialization");
+info!("Embedded provider configured but not needed, skipping initialization");
}
}
// Register OpenAI provider if configured AND it's the default provider
if let Some(openai_config) = &config.providers.openai {
if providers_to_register.contains(&"openai".to_string()) {
info!("Initializing OpenAI provider");
let openai_provider = g3_providers::OpenAIProvider::new(
openai_config.api_key.clone(),
Some(openai_config.model.clone()),
openai_config.base_url.clone(),
openai_config.max_tokens,
openai_config.temperature,
)?;
providers.register(openai_provider);
} else {
info!("OpenAI provider configured but not needed, skipping initialization");
}
}
// Register Anthropic provider if configured AND it's the default provider
if let Some(anthropic_config) = &config.providers.anthropic {
-if config.providers.default_provider == "anthropic" {
-info!("Initializing Anthropic provider (selected as default)");
+if providers_to_register.contains(&"anthropic".to_string()) {
+info!("Initializing Anthropic provider");
let anthropic_provider = g3_providers::AnthropicProvider::new(
anthropic_config.api_key.clone(),
Some(anthropic_config.model.clone()),
@@ -659,14 +695,14 @@ impl<W: UiWriter> Agent<W> {
)?;
providers.register(anthropic_provider);
} else {
-info!("Anthropic provider configured but not selected as default, skipping initialization");
+info!("Anthropic provider configured but not needed, skipping initialization");
}
}
// Register Databricks provider if configured AND it's the default provider
if let Some(databricks_config) = &config.providers.databricks {
-if config.providers.default_provider == "databricks" {
-info!("Initializing Databricks provider (selected as default)");
+if providers_to_register.contains(&"databricks".to_string()) {
+info!("Initializing Databricks provider");
let databricks_provider = if let Some(token) = &databricks_config.token {
// Use token-based authentication
@@ -690,7 +726,7 @@ impl<W: UiWriter> Agent<W> {
providers.register(databricks_provider);
} else {
-info!("Databricks provider configured but not selected as default, skipping initialization");
+info!("Databricks provider configured but not needed, skipping initialization");
}
}
@@ -773,6 +809,9 @@ impl<W: UiWriter> Agent<W> {
config.agent.max_context_length as u32
}
}
"openai" => {
192000
}
"anthropic" => {
// Claude models have large context windows
200000 // Default for Claude models
@@ -1060,7 +1099,6 @@ Template:
};
// Get max_tokens from provider configuration
-// For Databricks, this should be much higher to support large file generation
let max_tokens = match provider.name() {
"databricks" => {
// Use the model's maximum limit for Databricks to allow large file generation

View File

@@ -156,8 +156,9 @@ impl AnthropicProvider {
.post(ANTHROPIC_API_URL)
.header("x-api-key", &self.api_key)
.header("anthropic-version", ANTHROPIC_VERSION)
// Anthropic beta 1m context window. Enable if needed. It costs extra, so check first.
// .header("anthropic-beta", "context-1m-2025-08-07")
.header("content-type", "application/json");
if streaming {
builder = builder.header("accept", "text/event-stream");
}

View File

@@ -88,10 +88,12 @@ pub mod anthropic;
pub mod databricks;
pub mod embedded;
pub mod oauth;
pub mod openai;
pub use anthropic::AnthropicProvider;
pub use databricks::DatabricksProvider;
pub use embedded::EmbeddedProvider;
pub use openai::OpenAIProvider;
/// Provider registry for managing multiple LLM providers
pub struct ProviderRegistry {

View File

@@ -0,0 +1,495 @@
use anyhow::Result;
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::stream::StreamExt;
use reqwest::Client;
use serde::Deserialize;
use serde_json::json;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error};
use crate::{
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider,
Message, MessageRole, Tool, ToolCall, Usage,
};
#[derive(Clone)]
pub struct OpenAIProvider {
client: Client,
api_key: String,
model: String,
base_url: String,
max_tokens: Option<u32>,
_temperature: Option<f32>,
}
impl OpenAIProvider {
pub fn new(
api_key: String,
model: Option<String>,
base_url: Option<String>,
max_tokens: Option<u32>,
temperature: Option<f32>,
) -> Result<Self> {
Ok(Self {
client: Client::new(),
api_key,
model: model.unwrap_or_else(|| "gpt-4o".to_string()),
base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
max_tokens,
_temperature: temperature,
})
}
fn create_request_body(
&self,
messages: &[Message],
tools: Option<&[Tool]>,
stream: bool,
max_tokens: Option<u32>,
_temperature: Option<f32>,
) -> serde_json::Value {
let mut body = json!({
"model": self.model,
"messages": convert_messages(messages),
"stream": stream,
});
if let Some(max_tokens) = max_tokens.or(self.max_tokens) {
body["max_completion_tokens"] = json!(max_tokens);
}
// OpenAI calls with temp setting seem to fail, so don't send one.
// if let Some(temperature) = temperature.or(self.temperature) {
// body["temperature"] = json!(temperature);
// }
if let Some(tools) = tools {
if !tools.is_empty() {
body["tools"] = json!(convert_tools(tools));
}
}
if stream {
body["stream_options"] = json!({
"include_usage": true,
});
}
body
}
async fn parse_streaming_response(
&self,
mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
tx: mpsc::Sender<Result<CompletionChunk>>,
) -> Option<Usage> {
let mut buffer = String::new();
let mut accumulated_content = String::new();
let mut accumulated_usage: Option<Usage> = None;
let mut current_tool_calls: Vec<OpenAIStreamingToolCall> = Vec::new();
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
let chunk_str = match std::str::from_utf8(&chunk) {
Ok(s) => s,
Err(e) => {
error!("Failed to parse chunk as UTF-8: {}", e);
continue;
}
};
buffer.push_str(chunk_str);
// Process complete lines
while let Some(line_end) = buffer.find('\n') {
let line = buffer[..line_end].trim().to_string();
buffer.drain(..line_end + 1);
if line.is_empty() {
continue;
}
// Parse Server-Sent Events format
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
debug!("Received stream completion marker");
// Send final chunk with accumulated content and tool calls
if !accumulated_content.is_empty() || !current_tool_calls.is_empty() {
let tool_calls = if current_tool_calls.is_empty() {
None
} else {
Some(
current_tool_calls
.iter()
.filter_map(|tc| tc.to_tool_call())
.collect(),
)
};
let final_chunk = CompletionChunk {
content: accumulated_content.clone(),
finished: true,
tool_calls,
usage: accumulated_usage.clone(),
};
let _ = tx.send(Ok(final_chunk)).await;
}
return accumulated_usage;
}
// Parse the JSON data
match serde_json::from_str::<OpenAIStreamChunk>(data) {
Ok(chunk_data) => {
// Handle content
for choice in &chunk_data.choices {
if let Some(content) = &choice.delta.content {
accumulated_content.push_str(content);
let chunk = CompletionChunk {
content: content.clone(),
finished: false,
tool_calls: None,
usage: None,
};
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return accumulated_usage;
}
}
// Handle tool calls
if let Some(delta_tool_calls) = &choice.delta.tool_calls {
for delta_tool_call in delta_tool_calls {
if let Some(index) = delta_tool_call.index {
// Ensure we have enough tool calls in our vector
while current_tool_calls.len() <= index {
current_tool_calls
.push(OpenAIStreamingToolCall::default());
}
let tool_call = &mut current_tool_calls[index];
if let Some(id) = &delta_tool_call.id {
tool_call.id = Some(id.clone());
}
if let Some(function) = &delta_tool_call.function {
if let Some(name) = &function.name {
tool_call.name = Some(name.clone());
}
if let Some(arguments) = &function.arguments {
tool_call.arguments.push_str(arguments);
}
}
}
}
}
}
// Handle usage
if let Some(usage) = chunk_data.usage {
accumulated_usage = Some(Usage {
prompt_tokens: usage.prompt_tokens,
completion_tokens: usage.completion_tokens,
total_tokens: usage.total_tokens,
});
}
}
Err(e) => {
debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
}
}
}
}
}
Err(e) => {
error!("Stream error: {}", e);
let _ = tx.send(Err(anyhow::anyhow!("Stream error: {}", e))).await;
return accumulated_usage;
}
}
}
// Send final chunk if we haven't already
let tool_calls = if current_tool_calls.is_empty() {
None
} else {
Some(
current_tool_calls
.iter()
.filter_map(|tc| tc.to_tool_call())
.collect(),
)
};
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
tool_calls,
usage: accumulated_usage.clone(),
};
let _ = tx.send(Ok(final_chunk)).await;
accumulated_usage
}
}
#[async_trait]
impl LLMProvider for OpenAIProvider {
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
debug!(
"Processing OpenAI completion request with {} messages",
request.messages.len()
);
let body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
false,
request.max_tokens,
request.temperature,
);
debug!("Sending request to OpenAI API: model={}", self.model);
let response = self
.client
.post(&format!("{}/chat/completions", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text));
}
let openai_response: OpenAIResponse = response.json().await?;
let content = openai_response
.choices
.first()
.and_then(|choice| choice.message.content.clone())
.unwrap_or_default();
let usage = Usage {
prompt_tokens: openai_response.usage.prompt_tokens,
completion_tokens: openai_response.usage.completion_tokens,
total_tokens: openai_response.usage.total_tokens,
};
debug!(
"OpenAI completion successful: {} tokens generated",
usage.completion_tokens
);
Ok(CompletionResponse {
content,
usage,
model: self.model.clone(),
})
}
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
debug!(
"Processing OpenAI streaming request with {} messages",
request.messages.len()
);
let body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
true,
request.max_tokens,
request.temperature,
);
debug!("Sending streaming request to OpenAI API: model={}", self.model);
let response = self
.client
.post(&format!("{}/chat/completions", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text));
}
let stream = response.bytes_stream();
let (tx, rx) = mpsc::channel(100);
// Spawn task to process the stream
let provider = self.clone();
tokio::spawn(async move {
let usage = provider.parse_streaming_response(stream, tx).await;
// Log the final usage if available
if let Some(usage) = usage {
debug!(
"Stream completed with usage - prompt: {}, completion: {}, total: {}",
usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
);
}
});
Ok(ReceiverStream::new(rx))
}
fn name(&self) -> &str {
"openai"
}
fn model(&self) -> &str {
&self.model
}
fn has_native_tool_calling(&self) -> bool {
// OpenAI models support native tool calling
true
}
}
fn convert_messages(messages: &[Message]) -> Vec<serde_json::Value> {
messages
.iter()
.map(|msg| {
json!({
"role": match msg.role {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
},
"content": msg.content,
})
})
.collect()
}
fn convert_tools(tools: &[Tool]) -> Vec<serde_json::Value> {
tools
.iter()
.map(|tool| {
json!({
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.input_schema,
}
})
})
.collect()
}
// OpenAI API response structures
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
choices: Vec<OpenAIChoice>,
usage: OpenAIUsage,
}
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
message: OpenAIMessage,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct OpenAIMessage {
content: Option<String>,
#[serde(default)]
tool_calls: Option<Vec<OpenAIToolCall>>,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct OpenAIToolCall {
id: String,
function: OpenAIFunction,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct OpenAIFunction {
name: String,
arguments: String,
}
// Streaming tool call accumulator
#[derive(Debug, Default)]
struct OpenAIStreamingToolCall {
id: Option<String>,
name: Option<String>,
arguments: String,
}
impl OpenAIStreamingToolCall {
fn to_tool_call(&self) -> Option<ToolCall> {
let id = self.id.as_ref()?;
let name = self.name.as_ref()?;
let args = serde_json::from_str(&self.arguments).unwrap_or(serde_json::Value::Null);
Some(ToolCall {
id: id.clone(),
tool: name.clone(),
args,
})
}
}
#[derive(Debug, Deserialize)]
struct OpenAIUsage {
prompt_tokens: u32,
completion_tokens: u32,
total_tokens: u32,
}
// Streaming response structures
#[derive(Debug, Deserialize)]
struct OpenAIStreamChunk {
choices: Vec<OpenAIStreamChoice>,
usage: Option<OpenAIUsage>,
}
#[derive(Debug, Deserialize)]
struct OpenAIStreamChoice {
delta: OpenAIDelta,
}
#[derive(Debug, Deserialize)]
struct OpenAIDelta {
content: Option<String>,
#[serde(default)]
tool_calls: Option<Vec<OpenAIDeltaToolCall>>,
}
#[derive(Debug, Deserialize)]
struct OpenAIDeltaToolCall {
index: Option<usize>,
id: Option<String>,
function: Option<OpenAIDeltaFunction>,
}
#[derive(Debug, Deserialize)]
struct OpenAIDeltaFunction {
name: Option<String>,
arguments: Option<String>,
}

View File

@@ -0,0 +1,75 @@
# Coach-Player Provider Configuration
G3 now supports specifying different LLM providers for the coach and player agents when running in autonomous mode. This allows you to optimize for different requirements:
- **Player**: The agent that implements code - might benefit from a faster, more cost-effective model
- **Coach**: The agent that reviews code - might benefit from a more powerful, analytical model
## Configuration
In your `config.toml` file, under the `[providers]` section, you can specify:
```toml
[providers]
default_provider = "databricks" # Used for normal operations
coach = "databricks" # Provider for coach (code reviewer)
player = "anthropic" # Provider for player (code implementer)
```
If `coach` or `player` are not specified, they will default to using the `default_provider`.
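Under the hood this fallback is just an optional field lookup; the accessors this commit adds to `g3-config` (see the `Config` diff above) boil down to:
```rust
// Accessors added in this commit (g3-config): coach/player are optional
// and fall back to default_provider when not set in config.toml.
pub fn get_coach_provider(&self) -> &str {
    self.providers.coach
        .as_deref()
        .unwrap_or(&self.providers.default_provider)
}

pub fn get_player_provider(&self) -> &str {
    self.providers.player
        .as_deref()
        .unwrap_or(&self.providers.default_provider)
}
```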
## Example Use Cases
### Cost Optimization
Use a cheaper, faster model for initial implementations (player) and a more powerful model for review (coach):
```toml
coach = "anthropic" # Claude Sonnet for thorough review
player = "anthropic" # Claude Haiku for quick implementation
```
### Speed vs Quality Trade-off
Use a local embedded model for fast iterations (player) and a cloud model for quality review (coach):
```toml
coach = "databricks" # Cloud model for quality review
player = "embedded" # Local model for fast implementation
```
### Specialized Models
Use different models optimized for different tasks:
```toml
coach = "databricks" # Model fine-tuned for code review
player = "openai" # Model optimized for code generation
```
## Requirements
- Both providers must be properly configured in your config file
- Each provider must have valid credentials
- The models specified for each provider must be accessible
## How It Works
When running in autonomous mode (`g3 --autonomous`), the system will:
1. Use the `player` provider (or default) for the initial implementation
2. Switch to the `coach` provider (or default) for code review
3. Return to the `player` provider for implementing feedback
4. Continue this cycle for the specified number of turns
The providers are logged at startup so you can verify which models are being used:
```
🎮 Player provider: anthropic
👨‍🏫 Coach provider: databricks
Using different providers for player and coach
```
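In code, the provider switch amounts to deriving a per-role config before each coach review; a trimmed sketch based on the `run_autonomous` change in this commit (error handling and the surrounding turn loop omitted):
```rust
// Sketch of the coach hand-off added in this commit (run_autonomous, simplified).
// for_coach() clones the config and swaps default_provider to the coach provider,
// returning an error if that provider is not configured.
let base_config = agent.get_config().clone();
let coach_config = base_config.for_coach()?;

let ui_writer = ConsoleUiWriter::new();
let mut coach_agent =
    Agent::new_autonomous_with_readme_and_quiet(coach_config, ui_writer, None, quiet).await?;
```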
## Benefits
- **Cost Efficiency**: Use expensive models only where they add the most value
- **Speed Optimization**: Use faster models for iterative development
- **Specialization**: Leverage models that excel at specific tasks
- **Flexibility**: Easy to experiment with different provider combinations