From 010a43d203557f9dfd61780fe4c335d2fcbe7426 Mon Sep 17 00:00:00 2001
From: Jochen
Date: Sun, 19 Oct 2025 18:13:42 +1100
Subject: [PATCH] coach/player provider split + add OpenAI

Allows the coach and player LLM providers to be specified separately. Also
adds an OpenAI provider.
---
 Cargo.lock                           |   1 +
 config.coach-player.example.toml     |  24 ++
 config.example.toml                  |   5 +
 crates/g3-cli/src/lib.rs             | 177 ++--------
 crates/g3-config/Cargo.toml          |   3 +
 crates/g3-config/src/lib.rs          |  69 ++++
 crates/g3-config/src/tests.rs        | 131 +++++++
 crates/g3-core/src/lib.rs            |  58 +++-
 crates/g3-providers/src/anthropic.rs |   3 +-
 crates/g3-providers/src/lib.rs       |   2 +
 crates/g3-providers/src/openai.rs    | 495 +++++++++++++++++++++++++++
 docs/coach-player-providers.md       |  75 ++++
 12 files changed, 889 insertions(+), 154 deletions(-)
 create mode 100644 config.coach-player.example.toml
 create mode 100644 crates/g3-config/src/tests.rs
 create mode 100644 crates/g3-providers/src/openai.rs
 create mode 100644 docs/coach-player-providers.md

diff --git a/Cargo.lock b/Cargo.lock
index 518cd2e..b3cf969 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1316,6 +1316,7 @@ dependencies = [
 "dirs 5.0.1",
 "serde",
 "shellexpand",
+ "tempfile",
 "thiserror 1.0.69",
 "toml",
]
diff --git a/config.coach-player.example.toml b/config.coach-player.example.toml
new file mode 100644
index 0000000..2101564
--- /dev/null
+++ b/config.coach-player.example.toml
@@ -0,0 +1,24 @@
+[providers]
+default_provider = "databricks"
+# Specify different providers for coach and player in autonomous mode
+coach = "databricks"  # Provider for the coach (code reviewer) - can be more powerful/expensive
+player = "anthropic"  # Provider for the player (code implementer) - can be faster/cheaper
+
+[providers.databricks]
+host = "https://your-workspace.cloud.databricks.com"
+# token = "your-databricks-token" # Optional - will use OAuth if not provided
+model = "databricks-claude-sonnet-4"
+max_tokens = 4096
+temperature = 0.1
+use_oauth = true
+
+[providers.anthropic]
+api_key = "your-anthropic-api-key"
+model = "claude-3-haiku-20240307" # Using a faster model for the player
+max_tokens = 4096
+temperature = 0.3 # Slightly higher temperature for more creative implementations
+
+[agent]
+max_context_length = 8192
+enable_streaming = true
+timeout_seconds = 60
\ No newline at end of file
diff --git a/config.example.toml b/config.example.toml
index 1d4764f..b58ae3f 100644
--- a/config.example.toml
+++ b/config.example.toml
@@ -1,5 +1,10 @@
 [providers]
 default_provider = "databricks"
+# Optional: Specify different providers for coach and player in autonomous mode.
+# If not specified, default_provider is used for both.
+# coach = "databricks"  # Provider for the coach (code reviewer)
+# player = "anthropic"  # Provider for the player (code implementer)
+# Note: Make sure the specified providers are configured below.
 
 [providers.databricks]
 host = "https://your-workspace.cloud.databricks.com"
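For readers skimming the config changes: the two new keys deserialize onto the provider table as plain optional strings. A minimal sketch, assuming the `serde` and `toml` crates as used by g3-config; the standalone `Providers` struct here is an illustrative stand-in for the relevant slice of `ProvidersConfig`:

```rust
use serde::Deserialize;

// Illustrative stand-in for the relevant slice of ProvidersConfig.
#[derive(Deserialize)]
struct Providers {
    default_provider: String,
    coach: Option<String>,  // absent => fall back to default_provider
    player: Option<String>, // absent => fall back to default_provider
}

fn main() {
    let toml_src = r#"
        default_provider = "databricks"
        coach = "databricks"
        player = "anthropic"
    "#;
    let cfg: Providers = toml::from_str(toml_src).expect("valid TOML");
    assert_eq!(cfg.default_provider, "databricks");
    assert_eq!(cfg.coach.as_deref(), Some("databricks"));
    assert_eq!(cfg.player.as_deref(), Some("anthropic"));
}
```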
diff --git a/crates/g3-cli/src/lib.rs b/crates/g3-cli/src/lib.rs
index 29d2f2a..b05fc66 100644
--- a/crates/g3-cli/src/lib.rs
+++ b/crates/g3-cli/src/lib.rs
@@ -100,14 +100,17 @@ fn generate_turn_histogram(turn_metrics: &[TurnMetrics]) -> String {
 /// Extract coach feedback by reading from the coach agent's specific log file
 /// Uses the coach agent's session ID to find the exact log file
 fn extract_coach_feedback_from_logs(
-    coach_result: &g3_core::TaskResult,
+    _coach_result: &g3_core::TaskResult,
     coach_agent: &g3_core::Agent,
     output: &SimpleOutput,
-) -> String {
+) -> Result<String> {
+    // Read the current coach agent's own log file directly, keyed by its session ID
+
     // Get the coach agent's session ID
     let session_id = coach_agent
        .get_session_id()
-        .expect("Coach agent has no session ID");
+        .ok_or_else(|| anyhow::anyhow!("Coach agent has no session ID"))?;
 
     // Construct the log file path for this specific coach session
     let logs_dir = std::path::Path::new("logs");
@@ -120,75 +123,15 @@ fn extract_coach_feedback_from_logs(
     if let Some(context_window) = log_json.get("context_window") {
         if let Some(conversation_history) = context_window.get("conversation_history") {
             if let Some(messages) = conversation_history.as_array() {
-                // Look for the last assistant message (regardless of tool used)
-                for message in messages.iter().rev() {
-                    if let Some(role) = message.get("role") {
-                        if role.as_str() == Some("assistant") {
-                            if let Some(content) = message.get("content") {
-                                if let Some(content_str) = content.as_str() {
-                                    // First, check if this is plain text feedback (no tool call)
-                                    // This happens when the coach returns final feedback directly
-                                    if !content_str.contains("{\"tool\"") {
-                                        let trimmed = content_str.trim();
-                                        if !trimmed.is_empty() {
-                                            output.print(&format!(
-                                                "✅ Extracted coach feedback from session: {} ({} chars) [plain text]",
-                                                session_id,
-                                                trimmed.len()
-                                            ));
-                                            return trimmed.to_string();
-                                        }
-                                    }
-
-                                    // Look for ANY tool call in the message
-                                    // Pattern: {"tool": "...", "args": {...}}
-                                    if let Some(tool_start) = content_str.find("{\"tool\"") {
-                                        let json_part = &content_str[tool_start..];
-
-                                        // Find the end of the JSON object
-                                        if let Some(json_end) = find_json_end(json_part) {
-                                            let json_str = &json_part[..json_end];
-
-                                            if let Ok(tool_call) = serde_json::from_str::<serde_json::Value>(json_str) {
-                                                if let Some(args) = tool_call.get("args") {
-                                                    // Try to extract feedback from different possible fields
-                                                    let feedback = if let Some(summary) = args.get("summary") {
-                                                        // final_output tool uses "summary"
-                                                        summary.as_str().map(|s| s.to_string())
-                                                    } else if let Some(content) = args.get("content") {
-                                                        // todo_write and other tools might use "content"
-                                                        content.as_str().map(|s| s.to_string())
-                                                    } else {
-                                                        // Fallback: use the entire args as JSON string
-                                                        Some(serde_json::to_string_pretty(args).unwrap_or_default())
-                                                    };
-
-                                                    if let Some(feedback_str) = feedback {
-                                                        if !feedback_str.trim().is_empty() {
-                                                            output.print(&format!(
-                                                                "✅ Extracted coach feedback from session: {} ({} chars)",
-                                                                session_id,
-                                                                feedback_str.len()
-                                                            ));
-
-                                                            // Validate feedback length
-                                                            if feedback_str.len() < 80 && !feedback_str.contains("IMPLEMENTATION_APPROVED") {
-                                                                panic!(
-                                                                    "Coach feedback is too short ({} chars): '{}'",
-                                                                    feedback_str.len(),
-                                                                    feedback_str
-                                                                );
-                                                            }
-
-                                                            return feedback_str;
-                                                        }
-                                                    }
-                                                }
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                        }
-                    }
+                // Simply take the last message's content - this is the coach's final feedback
+                if let Some(last_message) = messages.last() {
+                    if let Some(content) = last_message.get("content") {
+                        if let Some(content_str) = content.as_str() {
+                            output.print(&format!(
+                                "✅ Extracted coach feedback from session: {}",
+                                session_id
+                            ));
+                            return Ok(content_str.to_string());
                         }
                     }
                 }
@@ -199,47 +142,10 @@ fn extract_coach_feedback_from_logs(
         }
     }
 
-    // If we couldn't extract from logs, panic with detailed error
-    panic!(
-        "CRITICAL: Could not extract coach feedback from session: {}\n\
-         Log file path: {:?}\n\
-         Log file exists: {}\n\
-         This indicates the coach did not call any tool or the log is corrupted.\n\
-         Coach result response length: {} chars",
-        session_id,
-        log_file_path,
-        log_file_path.exists(),
-        coach_result.response.len()
-    );
-}
-
-/// Helper function to find the end of a JSON object using brace counting
-fn find_json_end(json_str: &str) -> Option<usize> {
-    let mut depth = 0;
-    let mut in_string = false;
-    let mut escape_next = false;
-
-    for (i, ch) in json_str.char_indices() {
-        if escape_next {
-            escape_next = false;
-            continue;
-        }
-
-        match ch {
-            '\\' if in_string => escape_next = true,
-            '"' => in_string = !in_string,
-            '{' if !in_string => depth += 1,
-            '}' if !in_string => {
-                depth -= 1;
-                if depth == 0 {
-                    return Some(i + 1);
-                }
-            }
-            _ => {}
-        }
-    }
-
-    None
+    Err(anyhow::anyhow!(
+        "Could not extract feedback from coach session: {}",
+        session_id
+    ))
 }
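The rewritten extractor just walks `context_window.conversation_history` and takes the last message. A self-contained sketch of the shape it assumes, using `serde_json` as the function above does; only the two nested keys and the `content` field come from the code, the sample values are illustrative:

```rust
use serde_json::json;

fn main() {
    // Minimal stand-in for a coach session log entry.
    let log_json = json!({
        "context_window": {
            "conversation_history": [
                { "role": "user", "content": "Review the implementation." },
                { "role": "assistant", "content": "IMPLEMENTATION_APPROVED" }
            ]
        }
    });

    // Same walk as extract_coach_feedback_from_logs: the last message wins.
    let feedback = log_json["context_window"]["conversation_history"]
        .as_array()
        .and_then(|messages| messages.last())
        .and_then(|message| message["content"].as_str())
        .unwrap_or_default();

    assert_eq!(feedback, "IMPLEMENTATION_APPROVED");
}
```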
 
 use clap::Parser;
@@ -321,10 +227,6 @@ pub struct Cli {
     /// Disable log file creation (no logs/ directory or session logs)
     #[arg(long)]
     pub quiet: bool,
-
-    /// Enable WebDriver tools for browser automation (Safari)
-    #[arg(long)]
-    pub webdriver: bool,
 }
 
 pub async fn run() -> Result<()> {
@@ -413,17 +315,12 @@ pub async fn run() -> Result<()> {
     }
 
     // Load configuration with CLI overrides
-    let mut config = Config::load_with_overrides(
+    let config = Config::load_with_overrides(
         cli.config.as_deref(),
         cli.provider.clone(),
         cli.model.clone(),
     )?;
 
-    // Override webdriver setting from CLI flag
-    if cli.webdriver {
-        config.webdriver.enabled = true;
-    }
-
     // Validate provider if specified
     if let Some(ref provider) = cli.provider {
         let valid_providers = ["anthropic", "databricks", "embedded", "openai"];
@@ -1358,10 +1255,6 @@
     loop {
         let turn_start_time = Instant::now();
         let turn_start_tokens = agent.get_context_window().used_tokens;
-
-        // Reset filter suppression state at the start of each turn
-        g3_core::fixed_filter_json::reset_fixed_json_tool_state();
-
         // Skip player turn if it's the first turn and implementation files exist
         if !(turn == 1 && skip_first_player) {
             output.print(&format!(
@@ -1522,14 +1415,15 @@
         // Create a new agent instance for coach mode to ensure fresh context
         // Use the same config with overrides that was passed to the player agent
-        let config = agent.get_config().clone();
-
+        let base_config = agent.get_config().clone();
+        let coach_config = base_config.for_coach()?;
+
         // Reset filter suppression state before creating coach agent
         g3_core::fixed_filter_json::reset_fixed_json_tool_state();
-
+
         let ui_writer = ConsoleUiWriter::new();
         let mut coach_agent =
-            Agent::new_autonomous_with_readme_and_quiet(config, ui_writer, None, quiet).await?;
+            Agent::new_autonomous_with_readme_and_quiet(coach_config, ui_writer, None, quiet).await?;
 
         // Ensure coach agent is also in the workspace directory
         project.enter_workspace()?;
@@ -1677,7 +1571,7 @@ Remember: Be clear in your review and concise in your feedback. APPROVE if the i
 
         // Extract the complete coach feedback from final_output
         let coach_feedback_text =
-            extract_coach_feedback_from_logs(&coach_result, &coach_agent, &output);
+            extract_coach_feedback_from_logs(&coach_result, &coach_agent, &output)?;
 
         // Log the size of the feedback for debugging
         info!(
@@ -1704,15 +1598,6 @@ Remember: Be clear in your review and concise in your feedback. APPROVE if the i
         output.print_smart(&format!("Coach feedback:\n{}", coach_feedback_text));
 
-        // Record turn metrics before checking for approval or max turns
-        let turn_duration = turn_start_time.elapsed();
-        let turn_tokens = agent.get_context_window().used_tokens.saturating_sub(turn_start_tokens);
-        turn_metrics.push(TurnMetrics {
-            turn_number: turn,
-            tokens_used: turn_tokens,
-            wall_clock_time: turn_duration,
-        });
-
         // Check if coach approved the implementation
         if coach_result.is_approved() || coach_feedback_text.contains("IMPLEMENTATION_APPROVED") {
             output.print("\n=== SESSION COMPLETED - IMPLEMENTATION APPROVED ===");
 
             break;
         }
 
-        // Increment turn counter after recording metrics but before checking max turns
         // Check if we've reached max turns
         if turn >= max_turns {
             output.print("\n=== SESSION COMPLETED - MAX TURNS REACHED ===");
 
             break;
         }
 
         // Store coach feedback for next iteration
         coach_feedback = coach_feedback_text;
-
+        // Record turn metrics before incrementing the turn counter
+        let turn_duration = turn_start_time.elapsed();
+        let turn_tokens = agent.get_context_window().used_tokens.saturating_sub(turn_start_tokens);
+        turn_metrics.push(TurnMetrics {
+            turn_number: turn,
+            tokens_used: turn_tokens,
+            wall_clock_time: turn_duration,
+        });
         turn += 1;
 
         output.print("🔄 Coach provided feedback for next iteration");
diff --git a/crates/g3-config/Cargo.toml b/crates/g3-config/Cargo.toml
index d818d23..92e0b89 100644
--- a/crates/g3-config/Cargo.toml
+++ b/crates/g3-config/Cargo.toml
@@ -12,3 +12,6 @@ thiserror = { workspace = true }
 toml = "0.8"
 shellexpand = "3.0"
 dirs = "5.0"
+
+[dev-dependencies]
+tempfile = "3.8"
diff --git a/crates/g3-config/src/lib.rs b/crates/g3-config/src/lib.rs
index 99a0b87..4b6dc9d 100644
--- a/crates/g3-config/src/lib.rs
+++ b/crates/g3-config/src/lib.rs
@@ -17,6 +17,8 @@ pub struct ProvidersConfig {
     pub databricks: Option<DatabricksConfig>,
     pub embedded: Option<EmbeddedConfig>,
     pub default_provider: String,
+    pub coach: Option<String>,  // Provider to use for coach in autonomous mode
+    pub player: Option<String>, // Provider to use for player in autonomous mode
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -112,6 +114,8 @@ impl Default for Config {
                 }),
                 embedded: None,
                 default_provider: "databricks".to_string(),
+                coach: None,  // Will use default_provider if not specified
+                player: None, // Will use default_provider if not specified
             },
             agent: AgentConfig {
                 max_context_length: 8192,
@@ -224,6 +228,8 @@ impl Config {
                     threads: Some(8),
                 }),
                 default_provider: "embedded".to_string(),
+                coach: None,  // Will use default_provider if not specified
+                player: None, // Will use default_provider if not specified
             },
             agent: AgentConfig {
                 max_context_length: 8192,
@@ -300,4 +306,67 @@ impl Config {
 
         Ok(config)
     }
+
+    /// Get the provider to use for coach mode in autonomous execution
+    pub fn get_coach_provider(&self) -> &str {
+        self.providers.coach
+            .as_deref()
+            .unwrap_or(&self.providers.default_provider)
+    }
+
+    /// Get the provider to use for player mode in autonomous execution
+    pub fn get_player_provider(&self) -> &str {
+        self.providers.player
+            .as_deref()
+            .unwrap_or(&self.providers.default_provider)
+    }
+
+    /// Create a copy of the config with a different default provider
+    pub fn with_provider_override(&self, provider: &str) -> Result<Config> {
+        // Validate that the provider is configured
+        match provider {
+            "anthropic" if self.providers.anthropic.is_none() => {
+                return Err(anyhow::anyhow!(
+                    "Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
+                    provider, provider
+                ));
+            }
+            "databricks" if self.providers.databricks.is_none() => {
+                return Err(anyhow::anyhow!(
+                    "Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
+                    provider, provider
+                ));
+            }
+            "embedded" if self.providers.embedded.is_none() => {
+                return Err(anyhow::anyhow!(
+                    "Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
+                    provider, provider
+                ));
+            }
+            "openai" if self.providers.openai.is_none() => {
+                return Err(anyhow::anyhow!(
+                    "Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
+                    provider, provider
+                ));
+            }
+            _ => {} // Provider is configured or unknown (will be caught later)
+        }
+
+        let mut config = self.clone();
+        config.providers.default_provider = provider.to_string();
+        Ok(config)
+    }
+
+    /// Create a copy of the config for coach mode in autonomous execution
+    pub fn for_coach(&self) -> Result<Config> {
+        self.with_provider_override(self.get_coach_provider())
+    }
+
+    /// Create a copy of the config for player mode in autonomous execution
+    pub fn for_player(&self) -> Result<Config> {
+        self.with_provider_override(self.get_player_provider())
+    }
 }
+
+#[cfg(test)]
+mod tests;
diff --git a/crates/g3-config/src/tests.rs b/crates/g3-config/src/tests.rs
new file mode 100644
index 0000000..a1e1e9f
--- /dev/null
+++ b/crates/g3-config/src/tests.rs
@@ -0,0 +1,131 @@
+#[cfg(test)]
+mod tests {
+    use crate::Config;
+    use std::fs;
+    use tempfile::TempDir;
+
+    #[test]
+    fn test_coach_player_providers() {
+        // Create a temporary directory for the test config
+        let temp_dir = TempDir::new().unwrap();
+        let config_path = temp_dir.path().join("test_config.toml");
+
+        // Write a test configuration with coach and player providers
+        let config_content = r#"
+[providers]
+default_provider = "databricks"
+coach = "anthropic"
+player = "embedded"
+
+[providers.databricks]
+host = "https://test.databricks.com"
+token = "test-token"
+model = "test-model"
+
+[providers.anthropic]
+api_key = "test-key"
+model = "claude-3"
+
+[providers.embedded]
+model_path = "test.gguf"
+model_type = "llama"
+
+[agent]
+max_context_length = 8192
+enable_streaming = true
+timeout_seconds = 60
+"#;
+
+        fs::write(&config_path, config_content).unwrap();
+
+        // Load the configuration
+        let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
+
+        // Test that the providers are correctly identified
+        assert_eq!(config.providers.default_provider, "databricks");
+        assert_eq!(config.get_coach_provider(), "anthropic");
+        assert_eq!(config.get_player_provider(), "embedded");
+
+        // Test creating coach config
+        let coach_config = config.for_coach().unwrap();
+        assert_eq!(coach_config.providers.default_provider, "anthropic");
+
+        // Test creating player config
+        let player_config = config.for_player().unwrap();
+        assert_eq!(player_config.providers.default_provider, "embedded");
+    }
+
+    #[test]
+    fn test_coach_player_fallback_to_default() {
+        // Create a temporary directory for the test config
+        let temp_dir = TempDir::new().unwrap();
+        let config_path = temp_dir.path().join("test_config.toml");
+
+        // Write a test configuration WITHOUT coach and player providers
+        let config_content = r#"
+[providers]
+default_provider = "databricks"
+
+[providers.databricks]
+host = "https://test.databricks.com"
+token = "test-token"
"test-token" +model = "test-model" + +[agent] +max_context_length = 8192 +enable_streaming = true +timeout_seconds = 60 +"#; + + fs::write(&config_path, config_content).unwrap(); + + // Load the configuration + let config = Config::load(Some(config_path.to_str().unwrap())).unwrap(); + + // Test that coach and player fall back to default provider + assert_eq!(config.get_coach_provider(), "databricks"); + assert_eq!(config.get_player_provider(), "databricks"); + + // Test creating coach config (should use default) + let coach_config = config.for_coach().unwrap(); + assert_eq!(coach_config.providers.default_provider, "databricks"); + + // Test creating player config (should use default) + let player_config = config.for_player().unwrap(); + assert_eq!(player_config.providers.default_provider, "databricks"); + } + + #[test] + fn test_invalid_provider_error() { + // Create a temporary directory for the test config + let temp_dir = TempDir::new().unwrap(); + let config_path = temp_dir.path().join("test_config.toml"); + + // Write a test configuration with an unconfigured provider + let config_content = r#" +[providers] +default_provider = "databricks" +coach = "openai" # OpenAI is not configured + +[providers.databricks] +host = "https://test.databricks.com" +token = "test-token" +model = "test-model" + +[agent] +max_context_length = 8192 +enable_streaming = true +timeout_seconds = 60 +"#; + + fs::write(&config_path, config_content).unwrap(); + + // Load the configuration + let config = Config::load(Some(config_path.to_str().unwrap())).unwrap(); + + // Test that trying to create a coach config with unconfigured provider fails + let result = config.for_coach(); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("not configured")); + } +} \ No newline at end of file diff --git a/crates/g3-core/src/lib.rs b/crates/g3-core/src/lib.rs index 93098bf..76efe31 100644 --- a/crates/g3-core/src/lib.rs +++ b/crates/g3-core/src/lib.rs @@ -599,13 +599,32 @@ impl Agent { ) -> Result { let mut providers = ProviderRegistry::new(); + // In autonomous mode, we need to register both coach and player providers + // Otherwise, only register the default provider + let providers_to_register: Vec = if is_autonomous { + let mut providers = vec![config.providers.default_provider.clone()]; + if let Some(coach) = &config.providers.coach { + if !providers.contains(coach) { + providers.push(coach.clone()); + } + } + if let Some(player) = &config.providers.player { + if !providers.contains(player) { + providers.push(player.clone()); + } + } + providers + } else { + vec![config.providers.default_provider.clone()] + }; + // Only register providers that are configured AND selected as the default provider // This prevents unnecessary initialization of heavy providers like embedded models // Register embedded provider if configured AND it's the default provider if let Some(embedded_config) = &config.providers.embedded { - if config.providers.default_provider == "embedded" { - info!("Initializing embedded provider (selected as default)"); + if providers_to_register.contains(&"embedded".to_string()) { + info!("Initializing embedded provider"); let embedded_provider = g3_providers::EmbeddedProvider::new( embedded_config.model_path.clone(), embedded_config.model_type.clone(), @@ -617,14 +636,31 @@ impl Agent { )?; providers.register(embedded_provider); } else { - info!("Embedded provider configured but not selected as default, skipping initialization"); + info!("Embedded provider configured but not needed, 
skipping initialization"); + } + } + + // Register OpenAI provider if configured AND it's the default provider + if let Some(openai_config) = &config.providers.openai { + if providers_to_register.contains(&"openai".to_string()) { + info!("Initializing OpenAI provider"); + let openai_provider = g3_providers::OpenAIProvider::new( + openai_config.api_key.clone(), + Some(openai_config.model.clone()), + openai_config.base_url.clone(), + openai_config.max_tokens, + openai_config.temperature, + )?; + providers.register(openai_provider); + } else { + info!("OpenAI provider configured but not needed, skipping initialization"); } } // Register Anthropic provider if configured AND it's the default provider if let Some(anthropic_config) = &config.providers.anthropic { - if config.providers.default_provider == "anthropic" { - info!("Initializing Anthropic provider (selected as default)"); + if providers_to_register.contains(&"anthropic".to_string()) { + info!("Initializing Anthropic provider"); let anthropic_provider = g3_providers::AnthropicProvider::new( anthropic_config.api_key.clone(), Some(anthropic_config.model.clone()), @@ -633,14 +669,14 @@ impl Agent { )?; providers.register(anthropic_provider); } else { - info!("Anthropic provider configured but not selected as default, skipping initialization"); + info!("Anthropic provider configured but not needed, skipping initialization"); } } // Register Databricks provider if configured AND it's the default provider if let Some(databricks_config) = &config.providers.databricks { - if config.providers.default_provider == "databricks" { - info!("Initializing Databricks provider (selected as default)"); + if providers_to_register.contains(&"databricks".to_string()) { + info!("Initializing Databricks provider"); let databricks_provider = if let Some(token) = &databricks_config.token { // Use token-based authentication @@ -664,7 +700,7 @@ impl Agent { providers.register(databricks_provider); } else { - info!("Databricks provider configured but not selected as default, skipping initialization"); + info!("Databricks provider configured but not needed, skipping initialization"); } } @@ -747,6 +783,9 @@ impl Agent { config.agent.max_context_length as u32 } } + "openai" => { + 192000 + } "anthropic" => { // Claude models have large context windows 200000 // Default for Claude models @@ -1034,7 +1073,6 @@ Template: }; // Get max_tokens from provider configuration - // For Databricks, this should be much higher to support large file generation let max_tokens = match provider.name() { "databricks" => { // Use the model's maximum limit for Databricks to allow large file generation diff --git a/crates/g3-providers/src/anthropic.rs b/crates/g3-providers/src/anthropic.rs index ae140f4..d3dfc52 100644 --- a/crates/g3-providers/src/anthropic.rs +++ b/crates/g3-providers/src/anthropic.rs @@ -156,8 +156,9 @@ impl AnthropicProvider { .post(ANTHROPIC_API_URL) .header("x-api-key", &self.api_key) .header("anthropic-version", ANTHROPIC_VERSION) + // Anthropic beta 1m context window. Enable if needed. It costs extra, so check first. 
diff --git a/crates/g3-providers/src/anthropic.rs b/crates/g3-providers/src/anthropic.rs
index ae140f4..d3dfc52 100644
--- a/crates/g3-providers/src/anthropic.rs
+++ b/crates/g3-providers/src/anthropic.rs
@@ -156,8 +156,9 @@
             .post(ANTHROPIC_API_URL)
             .header("x-api-key", &self.api_key)
             .header("anthropic-version", ANTHROPIC_VERSION)
+            // Anthropic's beta 1M-token context window. Enable if needed; it costs extra, so check first.
+            // .header("anthropic-beta", "context-1m-2025-08-07")
             .header("content-type", "application/json");
-
         if streaming {
             builder = builder.header("accept", "text/event-stream");
         }
diff --git a/crates/g3-providers/src/lib.rs b/crates/g3-providers/src/lib.rs
index df3cd6e..51ea55a 100644
--- a/crates/g3-providers/src/lib.rs
+++ b/crates/g3-providers/src/lib.rs
@@ -88,10 +88,12 @@ pub mod anthropic;
 pub mod databricks;
 pub mod embedded;
 pub mod oauth;
+pub mod openai;
 
 pub use anthropic::AnthropicProvider;
 pub use databricks::DatabricksProvider;
 pub use embedded::EmbeddedProvider;
+pub use openai::OpenAIProvider;
 
 /// Provider registry for managing multiple LLM providers
 pub struct ProviderRegistry {
diff --git a/crates/g3-providers/src/openai.rs b/crates/g3-providers/src/openai.rs
new file mode 100644
index 0000000..e8b4dab
--- /dev/null
+++ b/crates/g3-providers/src/openai.rs
@@ -0,0 +1,495 @@
+use anyhow::Result;
+use async_trait::async_trait;
+use bytes::Bytes;
+use futures_util::stream::StreamExt;
+use reqwest::Client;
+use serde::Deserialize;
+use serde_json::json;
+use tokio::sync::mpsc;
+use tokio_stream::wrappers::ReceiverStream;
+use tracing::{debug, error};
+
+use crate::{
+    CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider,
+    Message, MessageRole, Tool, ToolCall, Usage,
+};
+
+#[derive(Clone)]
+pub struct OpenAIProvider {
+    client: Client,
+    api_key: String,
+    model: String,
+    base_url: String,
+    max_tokens: Option<u32>,
+    _temperature: Option<f32>,
+}
+
+impl OpenAIProvider {
+    pub fn new(
+        api_key: String,
+        model: Option<String>,
+        base_url: Option<String>,
+        max_tokens: Option<u32>,
+        temperature: Option<f32>,
+    ) -> Result<Self> {
+        Ok(Self {
+            client: Client::new(),
+            api_key,
+            model: model.unwrap_or_else(|| "gpt-4o".to_string()),
+            base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
+            max_tokens,
+            _temperature: temperature,
+        })
+    }
+
+    fn create_request_body(
+        &self,
+        messages: &[Message],
+        tools: Option<&[Tool]>,
+        stream: bool,
+        max_tokens: Option<u32>,
+        _temperature: Option<f32>,
+    ) -> serde_json::Value {
+        let mut body = json!({
+            "model": self.model,
+            "messages": convert_messages(messages),
+            "stream": stream,
+        });
+
+        if let Some(max_tokens) = max_tokens.or(self.max_tokens) {
+            body["max_completion_tokens"] = json!(max_tokens);
+        }
+
+        // OpenAI calls with a temperature setting seem to fail, so don't send one.
+        // if let Some(temperature) = temperature.or(self.temperature) {
+        //     body["temperature"] = json!(temperature);
+        // }
+
+        if let Some(tools) = tools {
+            if !tools.is_empty() {
+                body["tools"] = json!(convert_tools(tools));
+            }
+        }
+
+        if stream {
+            body["stream_options"] = json!({
+                "include_usage": true,
+            });
+        }
+
+        body
+    }
+
+    async fn parse_streaming_response(
+        &self,
+        mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
+        tx: mpsc::Sender<Result<CompletionChunk>>,
+    ) -> Option<Usage> {
+        let mut buffer = String::new();
+        let mut accumulated_content = String::new();
+        let mut accumulated_usage: Option<Usage> = None;
+        let mut current_tool_calls: Vec<OpenAIStreamingToolCall> = Vec::new();
+
+        while let Some(chunk_result) = stream.next().await {
+            match chunk_result {
+                Ok(chunk) => {
+                    let chunk_str = match std::str::from_utf8(&chunk) {
+                        Ok(s) => s,
+                        Err(e) => {
+                            error!("Failed to parse chunk as UTF-8: {}", e);
+                            continue;
+                        }
+                    };
+
+                    buffer.push_str(chunk_str);
+
+                    // Process complete lines
+                    while let Some(line_end) = buffer.find('\n') {
+                        let line = buffer[..line_end].trim().to_string();
+                        buffer.drain(..line_end + 1);
+
+                        if line.is_empty() {
+                            continue;
+                        }
+
+                        // Parse Server-Sent Events format
+                        if let Some(data) = line.strip_prefix("data: ") {
+                            if data == "[DONE]" {
+                                debug!("Received stream completion marker");
+
+                                // Send the final chunk with accumulated content and tool calls
+                                if !accumulated_content.is_empty() || !current_tool_calls.is_empty() {
+                                    let tool_calls = if current_tool_calls.is_empty() {
+                                        None
+                                    } else {
+                                        Some(
+                                            current_tool_calls
+                                                .iter()
+                                                .filter_map(|tc| tc.to_tool_call())
+                                                .collect(),
+                                        )
+                                    };
+
+                                    let final_chunk = CompletionChunk {
+                                        content: accumulated_content.clone(),
+                                        finished: true,
+                                        tool_calls,
+                                        usage: accumulated_usage.clone(),
+                                    };
+                                    let _ = tx.send(Ok(final_chunk)).await;
+                                }
+
+                                return accumulated_usage;
+                            }
+
+                            // Parse the JSON data
+                            match serde_json::from_str::<OpenAIStreamChunk>(data) {
+                                Ok(chunk_data) => {
+                                    // Handle content
+                                    for choice in &chunk_data.choices {
+                                        if let Some(content) = &choice.delta.content {
+                                            accumulated_content.push_str(content);
+
+                                            let chunk = CompletionChunk {
+                                                content: content.clone(),
+                                                finished: false,
+                                                tool_calls: None,
+                                                usage: None,
+                                            };
+                                            if tx.send(Ok(chunk)).await.is_err() {
+                                                debug!("Receiver dropped, stopping stream");
+                                                return accumulated_usage;
+                                            }
+                                        }
+
+                                        // Handle tool calls
+                                        if let Some(delta_tool_calls) = &choice.delta.tool_calls {
+                                            for delta_tool_call in delta_tool_calls {
+                                                if let Some(index) = delta_tool_call.index {
+                                                    // Ensure we have enough tool calls in our vector
+                                                    while current_tool_calls.len() <= index {
+                                                        current_tool_calls
+                                                            .push(OpenAIStreamingToolCall::default());
+                                                    }
+
+                                                    let tool_call = &mut current_tool_calls[index];
+
+                                                    if let Some(id) = &delta_tool_call.id {
+                                                        tool_call.id = Some(id.clone());
+                                                    }
+
+                                                    if let Some(function) = &delta_tool_call.function {
+                                                        if let Some(name) = &function.name {
+                                                            tool_call.name = Some(name.clone());
+                                                        }
+                                                        if let Some(arguments) = &function.arguments {
+                                                            tool_call.arguments.push_str(arguments);
+                                                        }
+                                                    }
+                                                }
+                                            }
+                                        }
+                                    }
+
+                                    // Handle usage
+                                    if let Some(usage) = chunk_data.usage {
+                                        accumulated_usage = Some(Usage {
+                                            prompt_tokens: usage.prompt_tokens,
+                                            completion_tokens: usage.completion_tokens,
+                                            total_tokens: usage.total_tokens,
+                                        });
+                                    }
+                                }
+                                Err(e) => {
+                                    debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
+                                }
+                            }
+                        }
+                    }
+                }
+                Err(e) => {
+                    error!("Stream error: {}", e);
+                    let _ = tx.send(Err(anyhow::anyhow!("Stream error: {}", e))).await;
+                    return accumulated_usage;
+                }
+            }
+        }
+
+        // Send the final chunk if we haven't already
+        let tool_calls = if current_tool_calls.is_empty() {
+            None
+        } else {
+            Some(
+                current_tool_calls
+                    .iter()
+                    .filter_map(|tc| tc.to_tool_call())
+                    .collect(),
+            )
+        };
+
+        let final_chunk = CompletionChunk {
+            content: String::new(),
+            finished: true,
+            tool_calls,
+            usage: accumulated_usage.clone(),
+        };
+        let _ = tx.send(Ok(final_chunk)).await;
+
+        accumulated_usage
+    }
+}
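For reference, each `data:` line that `parse_streaming_response` consumes carries a JSON chunk shaped like the one below. A minimal, self-contained sketch of the decode step, assuming `serde` with the derive feature and `serde_json`; the struct names shadow a subset of the ones defined later in this file, and the sample payload is illustrative:

```rust
use serde::Deserialize;

#[derive(Deserialize)]
struct Chunk {
    choices: Vec<Choice>,
}
#[derive(Deserialize)]
struct Choice {
    delta: Delta,
}
#[derive(Deserialize)]
struct Delta {
    content: Option<String>,
}

fn main() {
    // One line of an OpenAI-style SSE stream, as handled above.
    let line = r#"data: {"choices":[{"delta":{"content":"Hel"}}]}"#;
    if let Some(data) = line.strip_prefix("data: ") {
        let chunk: Chunk = serde_json::from_str(data).expect("valid chunk JSON");
        assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hel"));
    }
    // The stream terminates with a literal "data: [DONE]" marker.
}
```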
+
+#[async_trait]
+impl LLMProvider for OpenAIProvider {
+    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
+        debug!(
+            "Processing OpenAI completion request with {} messages",
+            request.messages.len()
+        );
+
+        let body = self.create_request_body(
+            &request.messages,
+            request.tools.as_deref(),
+            false,
+            request.max_tokens,
+            request.temperature,
+        );
+
+        debug!("Sending request to OpenAI API: model={}", self.model);
+
+        let response = self
+            .client
+            .post(&format!("{}/chat/completions", self.base_url))
+            .header("Authorization", format!("Bearer {}", self.api_key))
+            .json(&body)
+            .send()
+            .await?;
+
+        let status = response.status();
+        if !status.is_success() {
+            let error_text = response
+                .text()
+                .await
+                .unwrap_or_else(|_| "Unknown error".to_string());
+            return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text));
+        }
+
+        let openai_response: OpenAIResponse = response.json().await?;
+
+        let content = openai_response
+            .choices
+            .first()
+            .and_then(|choice| choice.message.content.clone())
+            .unwrap_or_default();
+
+        let usage = Usage {
+            prompt_tokens: openai_response.usage.prompt_tokens,
+            completion_tokens: openai_response.usage.completion_tokens,
+            total_tokens: openai_response.usage.total_tokens,
+        };
+
+        debug!(
+            "OpenAI completion successful: {} tokens generated",
+            usage.completion_tokens
+        );
+
+        Ok(CompletionResponse {
+            content,
+            usage,
+            model: self.model.clone(),
+        })
+    }
+
+    async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
+        debug!(
+            "Processing OpenAI streaming request with {} messages",
+            request.messages.len()
+        );
+
+        let body = self.create_request_body(
+            &request.messages,
+            request.tools.as_deref(),
+            true,
+            request.max_tokens,
+            request.temperature,
+        );
+
+        debug!("Sending streaming request to OpenAI API: model={}", self.model);
+
+        let response = self
+            .client
+            .post(&format!("{}/chat/completions", self.base_url))
+            .header("Authorization", format!("Bearer {}", self.api_key))
+            .json(&body)
+            .send()
+            .await?;
+
+        let status = response.status();
+        if !status.is_success() {
+            let error_text = response
+                .text()
+                .await
+                .unwrap_or_else(|_| "Unknown error".to_string());
+            return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text));
+        }
+
+        let stream = response.bytes_stream();
+        let (tx, rx) = mpsc::channel(100);
+
+        // Spawn a task to process the stream
+        let provider = self.clone();
+        tokio::spawn(async move {
+            let usage = provider.parse_streaming_response(stream, tx).await;
+            // Log the final usage if available
+            if let Some(usage) = usage {
+                debug!(
+                    "Stream completed with usage - prompt: {}, completion: {}, total: {}",
+                    usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
+                );
+            }
+        });
+
+        Ok(ReceiverStream::new(rx))
+    }
+
+    fn name(&self) -> &str {
+        "openai"
+    }
+
+    fn model(&self) -> &str {
+        &self.model
+    }
+
+    fn has_native_tool_calling(&self) -> bool {
+        // OpenAI models support native tool calling
+        true
+    }
+}
+
+fn convert_messages(messages: &[Message]) -> Vec<serde_json::Value> {
+    messages
+        .iter()
+        .map(|msg| {
+            json!({
+                "role": match msg.role {
+                    MessageRole::System => "system",
+                    MessageRole::User => "user",
+                    MessageRole::Assistant => "assistant",
+                },
+                "content": msg.content,
+            })
+        })
+        .collect()
+}
+
+fn convert_tools(tools: &[Tool]) -> Vec<serde_json::Value> {
+    tools
+        .iter()
+        .map(|tool| {
+            json!({
+                "type": "function",
+                "function": {
+                    "name": tool.name,
+                    "description": tool.description,
+                    "parameters": tool.input_schema,
+                }
+            })
+        })
+        .collect()
+}
+
+// OpenAI API response structures
+#[derive(Debug, Deserialize)]
+struct OpenAIResponse {
+    choices: Vec<OpenAIChoice>,
+    usage: OpenAIUsage,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIChoice {
+    message: OpenAIMessage,
+}
+
+#[allow(dead_code)]
+#[derive(Debug, Deserialize)]
+struct OpenAIMessage {
+    content: Option<String>,
+    #[serde(default)]
+    tool_calls: Option<Vec<OpenAIToolCall>>,
+}
+
+#[allow(dead_code)]
+#[derive(Debug, Deserialize)]
+struct OpenAIToolCall {
+    id: String,
+    function: OpenAIFunction,
+}
+
+#[allow(dead_code)]
+#[derive(Debug, Deserialize)]
+struct OpenAIFunction {
+    name: String,
+    arguments: String,
+}
+
+// Streaming tool call accumulator
+#[derive(Debug, Default)]
+struct OpenAIStreamingToolCall {
+    id: Option<String>,
+    name: Option<String>,
+    arguments: String,
+}
+
+impl OpenAIStreamingToolCall {
+    fn to_tool_call(&self) -> Option<ToolCall> {
+        let id = self.id.as_ref()?;
+        let name = self.name.as_ref()?;
+
+        let args = serde_json::from_str(&self.arguments).unwrap_or(serde_json::Value::Null);
+
+        Some(ToolCall {
+            id: id.clone(),
+            tool: name.clone(),
+            args,
+        })
+    }
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIUsage {
+    prompt_tokens: u32,
+    completion_tokens: u32,
+    total_tokens: u32,
+}
+
+// Streaming response structures
+#[derive(Debug, Deserialize)]
+struct OpenAIStreamChunk {
+    choices: Vec<OpenAIStreamChoice>,
+    usage: Option<OpenAIUsage>,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIStreamChoice {
+    delta: OpenAIDelta,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIDelta {
+    content: Option<String>,
+    #[serde(default)]
+    tool_calls: Option<Vec<OpenAIDeltaToolCall>>,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIDeltaToolCall {
+    index: Option<usize>,
+    id: Option<String>,
+    function: Option<OpenAIDeltaFunction>,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIDeltaFunction {
+    name: Option<String>,
+    arguments: Option<String>,
+}
\ No newline at end of file
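Wiring the new provider into the registry is a two-step affair. A minimal sketch using only calls that appear in this patch (`OpenAIProvider::new` with its five arguments and `ProviderRegistry::register`); the key and model values are samples:

```rust
use g3_providers::{OpenAIProvider, ProviderRegistry};

fn main() {
    let mut providers = ProviderRegistry::new();
    let openai = OpenAIProvider::new(
        "sk-your-key".to_string(),  // api_key (placeholder)
        Some("gpt-4o".to_string()), // model; None defaults to "gpt-4o"
        None,                       // base_url; None defaults to https://api.openai.com/v1
        Some(4096),                 // max_tokens
        None,                       // temperature (currently not sent to the API)
    )
    .expect("failed to construct OpenAI provider");
    providers.register(openai);
}
```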
"system", + MessageRole::User => "user", + MessageRole::Assistant => "assistant", + }, + "content": msg.content, + }) + }) + .collect() +} + +fn convert_tools(tools: &[Tool]) -> Vec { + tools + .iter() + .map(|tool| { + json!({ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.input_schema, + } + }) + }) + .collect() +} + +// OpenAI API response structures +#[derive(Debug, Deserialize)] +struct OpenAIResponse { + choices: Vec, + usage: OpenAIUsage, +} + +#[derive(Debug, Deserialize)] +struct OpenAIChoice { + message: OpenAIMessage, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct OpenAIMessage { + content: Option, + #[serde(default)] + tool_calls: Option>, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct OpenAIToolCall { + id: String, + function: OpenAIFunction, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct OpenAIFunction { + name: String, + arguments: String, +} + +// Streaming tool call accumulator +#[derive(Debug, Default)] +struct OpenAIStreamingToolCall { + id: Option, + name: Option, + arguments: String, +} + +impl OpenAIStreamingToolCall { + fn to_tool_call(&self) -> Option { + let id = self.id.as_ref()?; + let name = self.name.as_ref()?; + + let args = serde_json::from_str(&self.arguments).unwrap_or(serde_json::Value::Null); + + Some(ToolCall { + id: id.clone(), + tool: name.clone(), + args, + }) + } +} + +#[derive(Debug, Deserialize)] +struct OpenAIUsage { + prompt_tokens: u32, + completion_tokens: u32, + total_tokens: u32, +} + +// Streaming response structures +#[derive(Debug, Deserialize)] +struct OpenAIStreamChunk { + choices: Vec, + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct OpenAIStreamChoice { + delta: OpenAIDelta, +} + +#[derive(Debug, Deserialize)] +struct OpenAIDelta { + content: Option, + #[serde(default)] + tool_calls: Option>, +} + +#[derive(Debug, Deserialize)] +struct OpenAIDeltaToolCall { + index: Option, + id: Option, + function: Option, +} + +#[derive(Debug, Deserialize)] +struct OpenAIDeltaFunction { + name: Option, + arguments: Option, +} \ No newline at end of file diff --git a/docs/coach-player-providers.md b/docs/coach-player-providers.md new file mode 100644 index 0000000..d1e05e4 --- /dev/null +++ b/docs/coach-player-providers.md @@ -0,0 +1,75 @@ +# Coach-Player Provider Configuration + +G3 now supports specifying different LLM providers for the coach and player agents when running in autonomous mode. This allows you to optimize for different requirements: + +- **Player**: The agent that implements code - might benefit from a faster, more cost-effective model +- **Coach**: The agent that reviews code - might benefit from a more powerful, analytical model + +## Configuration + +In your `config.toml` file, under the `[providers]` section, you can specify: + +```toml +[providers] +default_provider = "databricks" # Used for normal operations +coach = "databricks" # Provider for coach (code reviewer) +player = "anthropic" # Provider for player (code implementer) +``` + +If `coach` or `player` are not specified, they will default to using the `default_provider`. 
+
+## Example Use Cases
+
+### Cost Optimization
+Use a cheaper, faster model for initial implementations (player) and a more powerful model for review (coach). Note that each provider entry configures a single model, so giving the two roles different models means pointing them at two differently configured providers, as in `config.coach-player.example.toml`:
+
+```toml
+coach = "databricks"   # More powerful model for thorough review
+player = "anthropic"   # Faster, cheaper model (e.g. Claude Haiku) for quick implementation
+```
+
+### Speed vs Quality Trade-off
+Use a local embedded model for fast iterations (player) and a cloud model for quality review (coach):
+
+```toml
+coach = "databricks"   # Cloud model for quality review
+player = "embedded"    # Local model for fast implementation
+```
+
+### Specialized Models
+Use different models optimized for different tasks:
+
+```toml
+coach = "databricks"   # Model fine-tuned for code review
+player = "openai"      # Model optimized for code generation
+```
+
+## Requirements
+
+- Both providers must be properly configured in your config file
+- Each provider must have valid credentials
+- The models specified for each provider must be accessible
+
+## How It Works
+
+When running in autonomous mode (`g3 --autonomous`), the system will:
+
+1. Use the `player` provider (or the default) for the initial implementation
+2. Switch to the `coach` provider (or the default) for code review
+3. Return to the `player` provider for implementing feedback
+4. Continue this cycle for the specified number of turns
+
+The providers are logged at startup so you can verify which models are being used:
+
+```
+🎮 Player provider: anthropic
+👨‍đŸĢ Coach provider: databricks
+â„šī¸ Using different providers for player and coach
+```
+
+## Benefits
+
+- **Cost Efficiency**: Use expensive models only where they add the most value
+- **Speed Optimization**: Use faster models for iterative development
+- **Specialization**: Leverage models that excel at specific tasks
+- **Flexibility**: Easy to experiment with different provider combinations
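To close the loop, a minimal end-to-end sketch of the split using only APIs added in this patch (`Config::load` takes an optional path string, as exercised by the tests above; the example config path is a placeholder):

```rust
use g3_config::Config;

fn main() {
    // Load a config that sets coach = "databricks", player = "anthropic".
    let config = Config::load(Some("config.coach-player.example.toml"))
        .expect("config should parse");

    // Each accessor falls back to default_provider when the key is unset.
    println!("coach provider:  {}", config.get_coach_provider());
    println!("player provider: {}", config.get_player_provider());

    // Derive per-role configs; this fails early if a role names an
    // unconfigured provider.
    let coach_config = config.for_coach().expect("coach provider configured");
    let player_config = config.for_player().expect("player provider configured");
    assert_eq!(coach_config.providers.default_provider, config.get_coach_provider());
    assert_eq!(player_config.providers.default_provider, config.get_player_provider());
}
```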