Compare commits

...

4 Commits

Author SHA1 Message Date
Michael Neale
af6d37a8e2 Add --interactive-requirements flag for AI-enhanced requirements mode
- Adds new --interactive-requirements CLI flag for autonomous mode
- Prompts user for brief requirements input
- Uses AI to enhance and structure requirements into proper markdown
- Shows enhanced requirements and allows user to approve/edit/cancel
- Saves to requirements.md and proceeds with autonomous mode if approved
- Includes test script for manual verification
2025-10-22 14:58:35 +11:00
Dhanji R. Prasanna
c1c6680e03 Merge pull request #7 from jochenx/jochen-add-openai-and-multi-providers
coach/player provider split + add OpenAI
2025-10-22 13:46:16 +11:00
Jochen
f2d8e744bb fix panic in CLI parser
2025-10-22 13:20:45 +11:00
Jochen
010a43d203 coach/player provider split + add OpenAI
Allows coach and player LLM providers to be separately specified.
Also adds an OpenAI provider.
2025-10-21 16:59:13 +11:00
13 changed files with 1033 additions and 141 deletions

Cargo.lock generated
View File

@@ -1316,6 +1316,7 @@ dependencies = [
"dirs 5.0.1", "dirs 5.0.1",
"serde", "serde",
"shellexpand", "shellexpand",
"tempfile",
"thiserror 1.0.69", "thiserror 1.0.69",
"toml", "toml",
] ]

View File

@@ -0,0 +1,24 @@
[providers]
default_provider = "databricks"
# Specify different providers for coach and player in autonomous mode
coach = "databricks" # Provider for coach (code reviewer) - can be more powerful/expensive
player = "anthropic" # Provider for player (code implementer) - can be faster/cheaper
[providers.databricks]
host = "https://your-workspace.cloud.databricks.com"
# token = "your-databricks-token" # Optional - will use OAuth if not provided
model = "databricks-claude-sonnet-4"
max_tokens = 4096
temperature = 0.1
use_oauth = true
[providers.anthropic]
api_key = "your-anthropic-api-key"
model = "claude-3-haiku-20240307" # Using a faster model for player
max_tokens = 4096
temperature = 0.3 # Slightly higher temperature for more creative implementations
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60

View File

@@ -1,5 +1,10 @@
[providers] [providers]
default_provider = "databricks" default_provider = "databricks"
# Optional: Specify different providers for coach and player in autonomous mode
# If not specified, will use default_provider for both
# coach = "databricks" # Provider for coach (code reviewer)
# player = "anthropic" # Provider for player (code implementer)
# Note: Make sure the specified providers are configured below
[providers.databricks] [providers.databricks]
host = "https://your-workspace.cloud.databricks.com" host = "https://your-workspace.cloud.databricks.com"

View File

@@ -103,11 +103,14 @@ fn extract_coach_feedback_from_logs(
coach_result: &g3_core::TaskResult, coach_result: &g3_core::TaskResult,
coach_agent: &g3_core::Agent<ConsoleUiWriter>, coach_agent: &g3_core::Agent<ConsoleUiWriter>,
output: &SimpleOutput, output: &SimpleOutput,
) -> String { ) -> Result<String> {
// CORRECT APPROACH: Get the session ID from the current coach agent
// and read its specific log file directly
// Get the coach agent's session ID // Get the coach agent's session ID
let session_id = coach_agent let session_id = coach_agent
.get_session_id() .get_session_id()
.expect("Coach agent has no session ID"); .ok_or_else(|| anyhow::anyhow!("Coach agent has no session ID"))?;
// Construct the log file path for this specific coach session // Construct the log file path for this specific coach session
let logs_dir = std::path::Path::new("logs"); let logs_dir = std::path::Path::new("logs");
@@ -120,75 +123,15 @@ fn extract_coach_feedback_from_logs(
if let Some(context_window) = log_json.get("context_window") { if let Some(context_window) = log_json.get("context_window") {
if let Some(conversation_history) = context_window.get("conversation_history") { if let Some(conversation_history) = context_window.get("conversation_history") {
if let Some(messages) = conversation_history.as_array() { if let Some(messages) = conversation_history.as_array() {
// Look for the last assistant message (regardless of tool used) // Simply get the last message content - this is the coach's final feedback
for message in messages.iter().rev() { if let Some(last_message) = messages.last() {
if let Some(role) = message.get("role") { if let Some(content) = last_message.get("content") {
if role.as_str() == Some("assistant") {
if let Some(content) = message.get("content") {
if let Some(content_str) = content.as_str() { if let Some(content_str) = content.as_str() {
// First, check if this is plain text feedback (no tool call)
// This happens when the coach returns final feedback directly
if !content_str.contains("{\"tool\"") {
let trimmed = content_str.trim();
if !trimmed.is_empty() {
output.print(&format!( output.print(&format!(
"✅ Extracted coach feedback from session: {} ({} chars) [plain text]", "✅ Extracted coach feedback from session: {}",
session_id, session_id
trimmed.len()
)); ));
return trimmed.to_string(); return Ok(content_str.to_string());
}
}
// Look for ANY tool call in the message
// Pattern: {"tool": "...", "args": {...}}
if let Some(tool_start) = content_str.find("{\"tool\"") {
let json_part = &content_str[tool_start..];
// Find the end of the JSON object
if let Some(json_end) = find_json_end(json_part) {
let json_str = &json_part[..json_end];
if let Ok(tool_call) = serde_json::from_str::<serde_json::Value>(json_str) {
if let Some(args) = tool_call.get("args") {
// Try to extract feedback from different possible fields
let feedback = if let Some(summary) = args.get("summary") {
// final_output tool uses "summary"
summary.as_str().map(|s| s.to_string())
} else if let Some(content) = args.get("content") {
// todo_write and other tools might use "content"
content.as_str().map(|s| s.to_string())
} else {
// Fallback: use the entire args as JSON string
Some(serde_json::to_string_pretty(args).unwrap_or_default())
};
if let Some(feedback_str) = feedback {
if !feedback_str.trim().is_empty() {
output.print(&format!(
"✅ Extracted coach feedback from session: {} ({} chars)",
session_id,
feedback_str.len()
));
// Validate feedback length
if feedback_str.len() < 80 && !feedback_str.contains("IMPLEMENTATION_APPROVED") {
panic!(
"Coach feedback is too short ({} chars): '{}'",
feedback_str.len(),
feedback_str
);
}
return feedback_str;
}
}
}
}
}
}
}
}
} }
} }
} }
@@ -213,35 +156,6 @@ fn extract_coach_feedback_from_logs(
); );
} }
/// Helper function to find the end of a JSON object using brace counting
fn find_json_end(json_str: &str) -> Option<usize> {
let mut depth = 0;
let mut in_string = false;
let mut escape_next = false;
for (i, ch) in json_str.char_indices() {
if escape_next {
escape_next = false;
continue;
}
match ch {
'\\' if in_string => escape_next = true,
'"' => in_string = !in_string,
'{' if !in_string => depth += 1,
'}' if !in_string => {
depth -= 1;
if depth == 0 {
return Some(i + 1);
}
}
_ => {}
}
}
None
}
use clap::Parser; use clap::Parser;
use g3_config::Config; use g3_config::Config;
use g3_core::{project::Project, ui_writer::UiWriter, Agent}; use g3_core::{project::Project, ui_writer::UiWriter, Agent};
@@ -302,6 +216,10 @@ pub struct Cli {
#[arg(long, value_name = "TEXT")] #[arg(long, value_name = "TEXT")]
pub requirements: Option<String>, pub requirements: Option<String>,
/// Interactive mode: prompt for requirements and save to requirements.md before starting autonomous mode
#[arg(long)]
pub interactive_requirements: bool,
/// Use retro terminal UI (inspired by 80s sci-fi) /// Use retro terminal UI (inspired by 80s sci-fi)
#[arg(long)] #[arg(long)]
pub retro: bool, pub retro: bool,
@@ -321,10 +239,6 @@ pub struct Cli {
/// Disable log file creation (no logs/ directory or session logs) /// Disable log file creation (no logs/ directory or session logs)
#[arg(long)] #[arg(long)]
pub quiet: bool, pub quiet: bool,
/// Enable WebDriver tools for browser automation (Safari)
#[arg(long)]
pub webdriver: bool,
} }
pub async fn run() -> Result<()> { pub async fn run() -> Result<()> {
@@ -393,6 +307,112 @@ pub async fn run() -> Result<()> {
// Create project model // Create project model
let project = if cli.autonomous { let project = if cli.autonomous {
// Handle interactive requirements mode with AI enhancement
if cli.interactive_requirements {
println!("\n📝 Interactive Requirements Mode");
println!("================================\n");
println!("Describe what you want to build (can be brief):");
println!("Press Ctrl+D (Unix) or Ctrl+Z (Windows) when done.\n");
use std::io::{self, Read, Write};
let mut requirements_input = String::new();
io::stdin().read_to_string(&mut requirements_input)?;
if requirements_input.trim().is_empty() {
anyhow::bail!("No requirements provided. Exiting.");
}
println!("\n🤖 Enhancing your requirements with AI...\n");
// Create a temporary agent to enhance the requirements
let temp_config = Config::load_with_overrides(
cli.config.as_deref(),
cli.provider.clone(),
cli.model.clone(),
)?;
let ui_writer = ConsoleUiWriter::new();
let mut temp_agent = Agent::new_with_readme_and_quiet(
temp_config,
ui_writer,
None,
true, // quiet mode
).await?;
// Craft the enhancement prompt
let enhancement_prompt = format!(
r#"You are a requirements analyst. Take this brief user input and expand it into a structured requirements document.
USER INPUT:
{}
Create a professional requirements document with:
1. A clear project title (# heading)
2. An overview section explaining what will be built
3. Organized requirements (functional, technical, quality)
4. Acceptance criteria
5. Any technical constraints or preferences mentioned
Format as proper markdown. Be specific and actionable. If the user's input is vague, make reasonable assumptions but keep it focused on what they described.
Output ONLY the markdown content, no explanations or meta-commentary."#,
requirements_input.trim()
);
// Execute enhancement task
let result = temp_agent
.execute_task_with_timing(&enhancement_prompt, None, false, false, false, false)
.await?;
let enhanced_requirements = result.response.trim().to_string();
// Show the enhanced requirements
println!("\n📋 Enhanced Requirements Document:");
println!("{}\n", "=".repeat(60));
println!("{}", enhanced_requirements);
println!("{}\n", "=".repeat(60));
// Ask for confirmation
println!("\n❓ Is this requirements document acceptable?");
println!(" [y] Yes, proceed with autonomous mode");
println!(" [e] Edit and save manually");
println!(" [n] No, cancel\n");
print!("Your choice (y/e/n): ");
io::stdout().flush()?;
let mut choice = String::new();
io::stdin().read_line(&mut choice)?;
let choice = choice.trim().to_lowercase();
let requirements_path = workspace_dir.join("requirements.md");
match choice.as_str() {
"y" | "yes" => {
// Save enhanced requirements
std::fs::write(&requirements_path, &enhanced_requirements)?;
println!("\n✅ Requirements saved to: {}", requirements_path.display());
println!("🚀 Starting autonomous mode...\n");
}
"e" | "edit" => {
// Save enhanced requirements for manual editing
std::fs::write(&requirements_path, &enhanced_requirements)?;
println!("\n✅ Requirements saved to: {}", requirements_path.display());
println!("📝 Please edit the file and run: g3 --autonomous");
println!(" Exiting for now.\n");
return Ok(());
}
"n" | "no" => {
println!("\n❌ Cancelled. No files were saved.\n");
return Ok(());
}
_ => {
println!("\n❌ Invalid choice. Cancelled.\n");
return Ok(());
}
}
}
if let Some(requirements_text) = cli.requirements { if let Some(requirements_text) = cli.requirements {
// Use requirements text override // Use requirements text override
Project::new_autonomous_with_requirements(workspace_dir.clone(), requirements_text)? Project::new_autonomous_with_requirements(workspace_dir.clone(), requirements_text)?
@@ -413,17 +433,12 @@ pub async fn run() -> Result<()> {
} }
// Load configuration with CLI overrides // Load configuration with CLI overrides
let mut config = Config::load_with_overrides( let config = Config::load_with_overrides(
cli.config.as_deref(), cli.config.as_deref(),
cli.provider.clone(), cli.provider.clone(),
cli.model.clone(), cli.model.clone(),
)?; )?;
// Override webdriver setting from CLI flag
if cli.webdriver {
config.webdriver.enabled = true;
}
// Validate provider if specified // Validate provider if specified
if let Some(ref provider) = cli.provider { if let Some(ref provider) = cli.provider {
let valid_providers = ["anthropic", "databricks", "embedded", "openai"]; let valid_providers = ["anthropic", "databricks", "embedded", "openai"];
@@ -1358,10 +1373,6 @@ async fn run_autonomous(
loop { loop {
let turn_start_time = Instant::now(); let turn_start_time = Instant::now();
let turn_start_tokens = agent.get_context_window().used_tokens; let turn_start_tokens = agent.get_context_window().used_tokens;
// Reset filter suppression state at the start of each turn
g3_core::fixed_filter_json::reset_fixed_json_tool_state();
// Skip player turn if it's the first turn and implementation files exist // Skip player turn if it's the first turn and implementation files exist
if !(turn == 1 && skip_first_player) { if !(turn == 1 && skip_first_player) {
output.print(&format!( output.print(&format!(
@@ -1522,14 +1533,15 @@ async fn run_autonomous(
// Create a new agent instance for coach mode to ensure fresh context // Create a new agent instance for coach mode to ensure fresh context
// Use the same config with overrides that was passed to the player agent // Use the same config with overrides that was passed to the player agent
let config = agent.get_config().clone(); let base_config = agent.get_config().clone();
let coach_config = base_config.for_coach()?;
// Reset filter suppression state before creating coach agent // Reset filter suppression state before creating coach agent
g3_core::fixed_filter_json::reset_fixed_json_tool_state(); g3_core::fixed_filter_json::reset_fixed_json_tool_state();
let ui_writer = ConsoleUiWriter::new(); let ui_writer = ConsoleUiWriter::new();
let mut coach_agent = let mut coach_agent =
Agent::new_autonomous_with_readme_and_quiet(config, ui_writer, None, quiet).await?; Agent::new_autonomous_with_readme_and_quiet(coach_config, ui_writer, None, quiet).await?;
// Ensure coach agent is also in the workspace directory // Ensure coach agent is also in the workspace directory
project.enter_workspace()?; project.enter_workspace()?;
@@ -1677,7 +1689,7 @@ Remember: Be clear in your review and concise in your feedback. APPROVE if the i
// Extract the complete coach feedback from final_output // Extract the complete coach feedback from final_output
let coach_feedback_text = let coach_feedback_text =
extract_coach_feedback_from_logs(&coach_result, &coach_agent, &output); extract_coach_feedback_from_logs(&coach_result, &coach_agent, &output)?;
// Log the size of the feedback for debugging // Log the size of the feedback for debugging
info!( info!(
@@ -1704,15 +1716,6 @@ Remember: Be clear in your review and concise in your feedback. APPROVE if the i
output.print_smart(&format!("Coach feedback:\n{}", coach_feedback_text)); output.print_smart(&format!("Coach feedback:\n{}", coach_feedback_text));
// Record turn metrics before checking for approval or max turns
let turn_duration = turn_start_time.elapsed();
let turn_tokens = agent.get_context_window().used_tokens.saturating_sub(turn_start_tokens);
turn_metrics.push(TurnMetrics {
turn_number: turn,
tokens_used: turn_tokens,
wall_clock_time: turn_duration,
});
// Check if coach approved the implementation // Check if coach approved the implementation
if coach_result.is_approved() || coach_feedback_text.contains("IMPLEMENTATION_APPROVED") { if coach_result.is_approved() || coach_feedback_text.contains("IMPLEMENTATION_APPROVED") {
output.print("\n=== SESSION COMPLETED - IMPLEMENTATION APPROVED ==="); output.print("\n=== SESSION COMPLETED - IMPLEMENTATION APPROVED ===");
@@ -1721,7 +1724,6 @@ Remember: Be clear in your review and concise in your feedback. APPROVE if the i
break; break;
} }
// Increment turn counter after recording metrics but before checking max turns
// Check if we've reached max turns // Check if we've reached max turns
if turn >= max_turns { if turn >= max_turns {
output.print("\n=== SESSION COMPLETED - MAX TURNS REACHED ==="); output.print("\n=== SESSION COMPLETED - MAX TURNS REACHED ===");
@@ -1731,7 +1733,14 @@ Remember: Be clear in your review and concise in your feedback. APPROVE if the i
// Store coach feedback for next iteration // Store coach feedback for next iteration
coach_feedback = coach_feedback_text; coach_feedback = coach_feedback_text;
// Record turn metrics before incrementing
let turn_duration = turn_start_time.elapsed();
let turn_tokens = agent.get_context_window().used_tokens.saturating_sub(turn_start_tokens);
turn_metrics.push(TurnMetrics {
turn_number: turn,
tokens_used: turn_tokens,
wall_clock_time: turn_duration,
});
turn += 1; turn += 1;
output.print("🔄 Coach provided feedback for next iteration"); output.print("🔄 Coach provided feedback for next iteration");

View File

@@ -12,3 +12,6 @@ thiserror = { workspace = true }
toml = "0.8" toml = "0.8"
shellexpand = "3.0" shellexpand = "3.0"
dirs = "5.0" dirs = "5.0"
[dev-dependencies]
tempfile = "3.8"

View File

@@ -17,6 +17,8 @@ pub struct ProvidersConfig {
pub databricks: Option<DatabricksConfig>, pub databricks: Option<DatabricksConfig>,
pub embedded: Option<EmbeddedConfig>, pub embedded: Option<EmbeddedConfig>,
pub default_provider: String, pub default_provider: String,
pub coach: Option<String>, // Provider to use for coach in autonomous mode
pub player: Option<String>, // Provider to use for player in autonomous mode
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -112,6 +114,8 @@ impl Default for Config {
}), }),
embedded: None, embedded: None,
default_provider: "databricks".to_string(), default_provider: "databricks".to_string(),
coach: None, // Will use default_provider if not specified
player: None, // Will use default_provider if not specified
}, },
agent: AgentConfig { agent: AgentConfig {
max_context_length: 8192, max_context_length: 8192,
@@ -224,6 +228,8 @@ impl Config {
threads: Some(8), threads: Some(8),
}), }),
default_provider: "embedded".to_string(), default_provider: "embedded".to_string(),
coach: None, // Will use default_provider if not specified
player: None, // Will use default_provider if not specified
}, },
agent: AgentConfig { agent: AgentConfig {
max_context_length: 8192, max_context_length: 8192,
@@ -300,4 +306,67 @@ impl Config {
Ok(config) Ok(config)
} }
/// Get the provider to use for coach mode in autonomous execution
pub fn get_coach_provider(&self) -> &str {
self.providers.coach
.as_deref()
.unwrap_or(&self.providers.default_provider)
}
/// Get the provider to use for player mode in autonomous execution
pub fn get_player_provider(&self) -> &str {
self.providers.player
.as_deref()
.unwrap_or(&self.providers.default_provider)
}
/// Create a copy of the config with a different default provider
pub fn with_provider_override(&self, provider: &str) -> Result<Self> {
// Validate that the provider is configured
match provider {
"anthropic" if self.providers.anthropic.is_none() => {
return Err(anyhow::anyhow!(
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
provider, provider
));
}
"databricks" if self.providers.databricks.is_none() => {
return Err(anyhow::anyhow!(
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
provider, provider
));
}
"embedded" if self.providers.embedded.is_none() => {
return Err(anyhow::anyhow!(
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
provider, provider
));
}
"openai" if self.providers.openai.is_none() => {
return Err(anyhow::anyhow!(
"Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
provider, provider
));
}
_ => {} // Provider is configured or unknown (will be caught later)
}
let mut config = self.clone();
config.providers.default_provider = provider.to_string();
Ok(config)
}
/// Create a copy of the config for coach mode in autonomous execution
pub fn for_coach(&self) -> Result<Self> {
self.with_provider_override(self.get_coach_provider())
}
/// Create a copy of the config for player mode in autonomous execution
pub fn for_player(&self) -> Result<Self> {
self.with_provider_override(self.get_player_provider())
}
} }
#[cfg(test)]
mod tests;

View File

@@ -0,0 +1,131 @@
#[cfg(test)]
mod tests {
use crate::Config;
use std::fs;
use tempfile::TempDir;
#[test]
fn test_coach_player_providers() {
// Create a temporary directory for the test config
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("test_config.toml");
// Write a test configuration with coach and player providers
let config_content = r#"
[providers]
default_provider = "databricks"
coach = "anthropic"
player = "embedded"
[providers.databricks]
host = "https://test.databricks.com"
token = "test-token"
model = "test-model"
[providers.anthropic]
api_key = "test-key"
model = "claude-3"
[providers.embedded]
model_path = "test.gguf"
model_type = "llama"
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
"#;
fs::write(&config_path, config_content).unwrap();
// Load the configuration
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
// Test that the providers are correctly identified
assert_eq!(config.providers.default_provider, "databricks");
assert_eq!(config.get_coach_provider(), "anthropic");
assert_eq!(config.get_player_provider(), "embedded");
// Test creating coach config
let coach_config = config.for_coach().unwrap();
assert_eq!(coach_config.providers.default_provider, "anthropic");
// Test creating player config
let player_config = config.for_player().unwrap();
assert_eq!(player_config.providers.default_provider, "embedded");
}
#[test]
fn test_coach_player_fallback_to_default() {
// Create a temporary directory for the test config
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("test_config.toml");
// Write a test configuration WITHOUT coach and player providers
let config_content = r#"
[providers]
default_provider = "databricks"
[providers.databricks]
host = "https://test.databricks.com"
token = "test-token"
model = "test-model"
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
"#;
fs::write(&config_path, config_content).unwrap();
// Load the configuration
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
// Test that coach and player fall back to default provider
assert_eq!(config.get_coach_provider(), "databricks");
assert_eq!(config.get_player_provider(), "databricks");
// Test creating coach config (should use default)
let coach_config = config.for_coach().unwrap();
assert_eq!(coach_config.providers.default_provider, "databricks");
// Test creating player config (should use default)
let player_config = config.for_player().unwrap();
assert_eq!(player_config.providers.default_provider, "databricks");
}
#[test]
fn test_invalid_provider_error() {
// Create a temporary directory for the test config
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("test_config.toml");
// Write a test configuration with an unconfigured provider
let config_content = r#"
[providers]
default_provider = "databricks"
coach = "openai" # OpenAI is not configured
[providers.databricks]
host = "https://test.databricks.com"
token = "test-token"
model = "test-model"
[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
"#;
fs::write(&config_path, config_content).unwrap();
// Load the configuration
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
// Test that trying to create a coach config with unconfigured provider fails
let result = config.for_coach();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("not configured"));
}
}

View File

@@ -599,13 +599,32 @@ impl<W: UiWriter> Agent<W> {
) -> Result<Self> { ) -> Result<Self> {
let mut providers = ProviderRegistry::new(); let mut providers = ProviderRegistry::new();
// In autonomous mode, we need to register both coach and player providers
// Otherwise, only register the default provider
let providers_to_register: Vec<String> = if is_autonomous {
let mut providers = vec![config.providers.default_provider.clone()];
if let Some(coach) = &config.providers.coach {
if !providers.contains(coach) {
providers.push(coach.clone());
}
}
if let Some(player) = &config.providers.player {
if !providers.contains(player) {
providers.push(player.clone());
}
}
providers
} else {
vec![config.providers.default_provider.clone()]
};
// Only register providers that are configured AND selected as the default provider // Only register providers that are configured AND selected as the default provider
// This prevents unnecessary initialization of heavy providers like embedded models // This prevents unnecessary initialization of heavy providers like embedded models
// Register embedded provider if configured AND it's the default provider // Register embedded provider if configured AND it's the default provider
if let Some(embedded_config) = &config.providers.embedded { if let Some(embedded_config) = &config.providers.embedded {
if config.providers.default_provider == "embedded" { if providers_to_register.contains(&"embedded".to_string()) {
info!("Initializing embedded provider (selected as default)"); info!("Initializing embedded provider");
let embedded_provider = g3_providers::EmbeddedProvider::new( let embedded_provider = g3_providers::EmbeddedProvider::new(
embedded_config.model_path.clone(), embedded_config.model_path.clone(),
embedded_config.model_type.clone(), embedded_config.model_type.clone(),
@@ -617,14 +636,31 @@ impl<W: UiWriter> Agent<W> {
)?; )?;
providers.register(embedded_provider); providers.register(embedded_provider);
} else { } else {
info!("Embedded provider configured but not selected as default, skipping initialization"); info!("Embedded provider configured but not needed, skipping initialization");
}
}
// Register OpenAI provider if configured AND it's the default provider
if let Some(openai_config) = &config.providers.openai {
if providers_to_register.contains(&"openai".to_string()) {
info!("Initializing OpenAI provider");
let openai_provider = g3_providers::OpenAIProvider::new(
openai_config.api_key.clone(),
Some(openai_config.model.clone()),
openai_config.base_url.clone(),
openai_config.max_tokens,
openai_config.temperature,
)?;
providers.register(openai_provider);
} else {
info!("OpenAI provider configured but not needed, skipping initialization");
} }
} }
// Register Anthropic provider if configured AND it's the default provider // Register Anthropic provider if configured AND it's the default provider
if let Some(anthropic_config) = &config.providers.anthropic { if let Some(anthropic_config) = &config.providers.anthropic {
if config.providers.default_provider == "anthropic" { if providers_to_register.contains(&"anthropic".to_string()) {
info!("Initializing Anthropic provider (selected as default)"); info!("Initializing Anthropic provider");
let anthropic_provider = g3_providers::AnthropicProvider::new( let anthropic_provider = g3_providers::AnthropicProvider::new(
anthropic_config.api_key.clone(), anthropic_config.api_key.clone(),
Some(anthropic_config.model.clone()), Some(anthropic_config.model.clone()),
@@ -633,14 +669,14 @@ impl<W: UiWriter> Agent<W> {
)?; )?;
providers.register(anthropic_provider); providers.register(anthropic_provider);
} else { } else {
info!("Anthropic provider configured but not selected as default, skipping initialization"); info!("Anthropic provider configured but not needed, skipping initialization");
} }
} }
// Register Databricks provider if configured AND it's the default provider // Register Databricks provider if configured AND it's the default provider
if let Some(databricks_config) = &config.providers.databricks { if let Some(databricks_config) = &config.providers.databricks {
if config.providers.default_provider == "databricks" { if providers_to_register.contains(&"databricks".to_string()) {
info!("Initializing Databricks provider (selected as default)"); info!("Initializing Databricks provider");
let databricks_provider = if let Some(token) = &databricks_config.token { let databricks_provider = if let Some(token) = &databricks_config.token {
// Use token-based authentication // Use token-based authentication
@@ -664,7 +700,7 @@ impl<W: UiWriter> Agent<W> {
providers.register(databricks_provider); providers.register(databricks_provider);
} else { } else {
info!("Databricks provider configured but not selected as default, skipping initialization"); info!("Databricks provider configured but not needed, skipping initialization");
} }
} }
@@ -747,6 +783,9 @@ impl<W: UiWriter> Agent<W> {
config.agent.max_context_length as u32 config.agent.max_context_length as u32
} }
} }
"openai" => {
192000
}
"anthropic" => { "anthropic" => {
// Claude models have large context windows // Claude models have large context windows
200000 // Default for Claude models 200000 // Default for Claude models
@@ -1034,7 +1073,6 @@ Template:
}; };
// Get max_tokens from provider configuration // Get max_tokens from provider configuration
// For Databricks, this should be much higher to support large file generation
let max_tokens = match provider.name() { let max_tokens = match provider.name() {
"databricks" => { "databricks" => {
// Use the model's maximum limit for Databricks to allow large file generation // Use the model's maximum limit for Databricks to allow large file generation

View File

@@ -156,8 +156,9 @@ impl AnthropicProvider {
.post(ANTHROPIC_API_URL) .post(ANTHROPIC_API_URL)
.header("x-api-key", &self.api_key) .header("x-api-key", &self.api_key)
.header("anthropic-version", ANTHROPIC_VERSION) .header("anthropic-version", ANTHROPIC_VERSION)
// Anthropic beta 1m context window. Enable if needed. It costs extra, so check first.
// .header("anthropic-beta", "context-1m-2025-08-07")
.header("content-type", "application/json"); .header("content-type", "application/json");
if streaming { if streaming {
builder = builder.header("accept", "text/event-stream"); builder = builder.header("accept", "text/event-stream");
} }

View File

@@ -88,10 +88,12 @@ pub mod anthropic;
pub mod databricks; pub mod databricks;
pub mod embedded; pub mod embedded;
pub mod oauth; pub mod oauth;
pub mod openai;
pub use anthropic::AnthropicProvider; pub use anthropic::AnthropicProvider;
pub use databricks::DatabricksProvider; pub use databricks::DatabricksProvider;
pub use embedded::EmbeddedProvider; pub use embedded::EmbeddedProvider;
pub use openai::OpenAIProvider;
/// Provider registry for managing multiple LLM providers /// Provider registry for managing multiple LLM providers
pub struct ProviderRegistry { pub struct ProviderRegistry {

View File

@@ -0,0 +1,495 @@
use anyhow::Result;
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::stream::StreamExt;
use reqwest::Client;
use serde::Deserialize;
use serde_json::json;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error};
use crate::{
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider,
Message, MessageRole, Tool, ToolCall, Usage,
};
#[derive(Clone)]
pub struct OpenAIProvider {
client: Client,
api_key: String,
model: String,
base_url: String,
max_tokens: Option<u32>,
_temperature: Option<f32>,
}
impl OpenAIProvider {
pub fn new(
api_key: String,
model: Option<String>,
base_url: Option<String>,
max_tokens: Option<u32>,
temperature: Option<f32>,
) -> Result<Self> {
Ok(Self {
client: Client::new(),
api_key,
model: model.unwrap_or_else(|| "gpt-4o".to_string()),
base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
max_tokens,
_temperature: temperature,
})
}
fn create_request_body(
&self,
messages: &[Message],
tools: Option<&[Tool]>,
stream: bool,
max_tokens: Option<u32>,
_temperature: Option<f32>,
) -> serde_json::Value {
let mut body = json!({
"model": self.model,
"messages": convert_messages(messages),
"stream": stream,
});
if let Some(max_tokens) = max_tokens.or(self.max_tokens) {
body["max_completion_tokens"] = json!(max_tokens);
}
// OpenAI calls with temp setting seem to fail, so don't send one.
// if let Some(temperature) = temperature.or(self.temperature) {
// body["temperature"] = json!(temperature);
// }
if let Some(tools) = tools {
if !tools.is_empty() {
body["tools"] = json!(convert_tools(tools));
}
}
if stream {
body["stream_options"] = json!({
"include_usage": true,
});
}
body
}
async fn parse_streaming_response(
&self,
mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
tx: mpsc::Sender<Result<CompletionChunk>>,
) -> Option<Usage> {
let mut buffer = String::new();
let mut accumulated_content = String::new();
let mut accumulated_usage: Option<Usage> = None;
let mut current_tool_calls: Vec<OpenAIStreamingToolCall> = Vec::new();
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
let chunk_str = match std::str::from_utf8(&chunk) {
Ok(s) => s,
Err(e) => {
error!("Failed to parse chunk as UTF-8: {}", e);
continue;
}
};
buffer.push_str(chunk_str);
// Process complete lines
while let Some(line_end) = buffer.find('\n') {
let line = buffer[..line_end].trim().to_string();
buffer.drain(..line_end + 1);
if line.is_empty() {
continue;
}
// Parse Server-Sent Events format
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
debug!("Received stream completion marker");
// Send final chunk with accumulated content and tool calls
if !accumulated_content.is_empty() || !current_tool_calls.is_empty() {
let tool_calls = if current_tool_calls.is_empty() {
None
} else {
Some(
current_tool_calls
.iter()
.filter_map(|tc| tc.to_tool_call())
.collect(),
)
};
let final_chunk = CompletionChunk {
content: accumulated_content.clone(),
finished: true,
tool_calls,
usage: accumulated_usage.clone(),
};
let _ = tx.send(Ok(final_chunk)).await;
}
return accumulated_usage;
}
// Parse the JSON data
match serde_json::from_str::<OpenAIStreamChunk>(data) {
Ok(chunk_data) => {
// Handle content
for choice in &chunk_data.choices {
if let Some(content) = &choice.delta.content {
accumulated_content.push_str(content);
let chunk = CompletionChunk {
content: content.clone(),
finished: false,
tool_calls: None,
usage: None,
};
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return accumulated_usage;
}
}
// Handle tool calls
if let Some(delta_tool_calls) = &choice.delta.tool_calls {
for delta_tool_call in delta_tool_calls {
if let Some(index) = delta_tool_call.index {
// Ensure we have enough tool calls in our vector
while current_tool_calls.len() <= index {
current_tool_calls
.push(OpenAIStreamingToolCall::default());
}
let tool_call = &mut current_tool_calls[index];
if let Some(id) = &delta_tool_call.id {
tool_call.id = Some(id.clone());
}
if let Some(function) = &delta_tool_call.function {
if let Some(name) = &function.name {
tool_call.name = Some(name.clone());
}
if let Some(arguments) = &function.arguments {
tool_call.arguments.push_str(arguments);
}
}
}
}
}
}
// Handle usage
if let Some(usage) = chunk_data.usage {
accumulated_usage = Some(Usage {
prompt_tokens: usage.prompt_tokens,
completion_tokens: usage.completion_tokens,
total_tokens: usage.total_tokens,
});
}
}
Err(e) => {
debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
}
}
}
}
}
Err(e) => {
error!("Stream error: {}", e);
let _ = tx.send(Err(anyhow::anyhow!("Stream error: {}", e))).await;
return accumulated_usage;
}
}
}
// Send final chunk if we haven't already
let tool_calls = if current_tool_calls.is_empty() {
None
} else {
Some(
current_tool_calls
.iter()
.filter_map(|tc| tc.to_tool_call())
.collect(),
)
};
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
tool_calls,
usage: accumulated_usage.clone(),
};
let _ = tx.send(Ok(final_chunk)).await;
accumulated_usage
}
}
#[async_trait]
impl LLMProvider for OpenAIProvider {
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
debug!(
"Processing OpenAI completion request with {} messages",
request.messages.len()
);
let body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
false,
request.max_tokens,
request.temperature,
);
debug!("Sending request to OpenAI API: model={}", self.model);
let response = self
.client
.post(&format!("{}/chat/completions", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text));
}
let openai_response: OpenAIResponse = response.json().await?;
let content = openai_response
.choices
.first()
.and_then(|choice| choice.message.content.clone())
.unwrap_or_default();
let usage = Usage {
prompt_tokens: openai_response.usage.prompt_tokens,
completion_tokens: openai_response.usage.completion_tokens,
total_tokens: openai_response.usage.total_tokens,
};
debug!(
"OpenAI completion successful: {} tokens generated",
usage.completion_tokens
);
Ok(CompletionResponse {
content,
usage,
model: self.model.clone(),
})
}
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
debug!(
"Processing OpenAI streaming request with {} messages",
request.messages.len()
);
let body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
true,
request.max_tokens,
request.temperature,
);
debug!("Sending streaming request to OpenAI API: model={}", self.model);
let response = self
.client
.post(&format!("{}/chat/completions", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text));
}
let stream = response.bytes_stream();
let (tx, rx) = mpsc::channel(100);
// Spawn task to process the stream
let provider = self.clone();
tokio::spawn(async move {
let usage = provider.parse_streaming_response(stream, tx).await;
// Log the final usage if available
if let Some(usage) = usage {
debug!(
"Stream completed with usage - prompt: {}, completion: {}, total: {}",
usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
);
}
});
Ok(ReceiverStream::new(rx))
}
fn name(&self) -> &str {
"openai"
}
fn model(&self) -> &str {
&self.model
}
fn has_native_tool_calling(&self) -> bool {
// OpenAI models support native tool calling
true
}
}
fn convert_messages(messages: &[Message]) -> Vec<serde_json::Value> {
messages
.iter()
.map(|msg| {
json!({
"role": match msg.role {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
},
"content": msg.content,
})
})
.collect()
}
fn convert_tools(tools: &[Tool]) -> Vec<serde_json::Value> {
tools
.iter()
.map(|tool| {
json!({
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.input_schema,
}
})
})
.collect()
}
// OpenAI API response structures
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
choices: Vec<OpenAIChoice>,
usage: OpenAIUsage,
}
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
message: OpenAIMessage,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct OpenAIMessage {
content: Option<String>,
#[serde(default)]
tool_calls: Option<Vec<OpenAIToolCall>>,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct OpenAIToolCall {
id: String,
function: OpenAIFunction,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct OpenAIFunction {
name: String,
arguments: String,
}
// Streaming tool call accumulator
#[derive(Debug, Default)]
struct OpenAIStreamingToolCall {
id: Option<String>,
name: Option<String>,
arguments: String,
}
impl OpenAIStreamingToolCall {
fn to_tool_call(&self) -> Option<ToolCall> {
let id = self.id.as_ref()?;
let name = self.name.as_ref()?;
let args = serde_json::from_str(&self.arguments).unwrap_or(serde_json::Value::Null);
Some(ToolCall {
id: id.clone(),
tool: name.clone(),
args,
})
}
}
#[derive(Debug, Deserialize)]
struct OpenAIUsage {
prompt_tokens: u32,
completion_tokens: u32,
total_tokens: u32,
}
// Streaming response structures
#[derive(Debug, Deserialize)]
struct OpenAIStreamChunk {
choices: Vec<OpenAIStreamChoice>,
usage: Option<OpenAIUsage>,
}
#[derive(Debug, Deserialize)]
struct OpenAIStreamChoice {
delta: OpenAIDelta,
}
#[derive(Debug, Deserialize)]
struct OpenAIDelta {
content: Option<String>,
#[serde(default)]
tool_calls: Option<Vec<OpenAIDeltaToolCall>>,
}
#[derive(Debug, Deserialize)]
struct OpenAIDeltaToolCall {
index: Option<usize>,
id: Option<String>,
function: Option<OpenAIDeltaFunction>,
}
#[derive(Debug, Deserialize)]
struct OpenAIDeltaFunction {
name: Option<String>,
arguments: Option<String>,
}

View File

@@ -0,0 +1,75 @@
# Coach-Player Provider Configuration
G3 now supports specifying different LLM providers for the coach and player agents when running in autonomous mode. This allows you to optimize for different requirements:
- **Player**: The agent that implements code - might benefit from a faster, more cost-effective model
- **Coach**: The agent that reviews code - might benefit from a more powerful, analytical model
## Configuration
In your `config.toml` file, under the `[providers]` section, you can specify:
```toml
[providers]
default_provider = "databricks" # Used for normal operations
coach = "databricks" # Provider for coach (code reviewer)
player = "anthropic" # Provider for player (code implementer)
```
If `coach` or `player` is not specified, that role falls back to the `default_provider`.
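
Under the hood this fallback is a pair of small accessors on the config (added in this change as `Config::get_coach_provider()` and `Config::get_player_provider()`). Below is a minimal standalone sketch of the behaviour, using simplified stand-in types rather than the exact g3 structs:

```rust
// Simplified sketch of the coach/player provider fallback. Field names mirror
// the [providers] table; the real accessors live on g3_config::Config.
struct ProvidersConfig {
    default_provider: String,
    coach: Option<String>,  // provider for the coach (code reviewer)
    player: Option<String>, // provider for the player (code implementer)
}

impl ProvidersConfig {
    fn coach_provider(&self) -> &str {
        self.coach.as_deref().unwrap_or(&self.default_provider)
    }

    fn player_provider(&self) -> &str {
        self.player.as_deref().unwrap_or(&self.default_provider)
    }
}

fn main() {
    let providers = ProvidersConfig {
        default_provider: "databricks".to_string(),
        coach: None,                           // not set: falls back to default
        player: Some("anthropic".to_string()), // explicitly set
    };
    assert_eq!(providers.coach_provider(), "databricks");
    assert_eq!(providers.player_provider(), "anthropic");
}
```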
## Example Use Cases
### Cost Optimization
Use a cheaper, faster model for initial implementations (player) and a more powerful model for review (coach):
```toml
coach = "anthropic" # Claude Sonnet for thorough review
player = "anthropic" # Claude Haiku for quick implementation
```
### Speed vs Quality Trade-off
Use a local embedded model for fast iterations (player) and a cloud model for quality review (coach):
```toml
coach = "databricks" # Cloud model for quality review
player = "embedded" # Local model for fast implementation
```
### Specialized Models
Use different models optimized for different tasks:
```toml
coach = "databricks" # Model fine-tuned for code review
player = "openai" # Model optimized for code generation
```
## Requirements
- Both providers must be properly configured in your config file
- Each provider must have valid credentials
- The models specified for each provider must be accessible
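
If a role names a provider that has no matching `[providers.<name>]` section, deriving that role's config fails with an error. The following is a rough sketch of the check performed by `Config::with_provider_override` in the diff above, using trimmed-down stand-in types:

```rust
use anyhow::{anyhow, Result};

// Stand-ins for the real per-provider config sections (trimmed to presence only).
struct Providers {
    anthropic: Option<()>,
    databricks: Option<()>,
    openai: Option<()>,
}

// Mirrors the validation idea in Config::with_provider_override: the chosen
// provider must have a corresponding [providers.<name>] section in the config.
fn validate_provider(providers: &Providers, provider: &str) -> Result<()> {
    let configured = match provider {
        "anthropic" => providers.anthropic.is_some(),
        "databricks" => providers.databricks.is_some(),
        "openai" => providers.openai.is_some(),
        _ => true, // unknown names are caught later, as in the original code
    };
    if configured {
        Ok(())
    } else {
        Err(anyhow!(
            "Provider '{provider}' is specified but not configured. \
             Please add {provider} configuration to your config file."
        ))
    }
}

fn main() -> Result<()> {
    let providers = Providers { anthropic: None, databricks: Some(()), openai: None };
    validate_provider(&providers, "databricks")?;                 // configured, ok
    assert!(validate_provider(&providers, "anthropic").is_err()); // not configured
    Ok(())
}
```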
## How It Works
When running in autonomous mode (`g3 --autonomous`), the system will:
1. Use the `player` provider (or default) for the initial implementation
2. Switch to the `coach` provider (or default) for code review
3. Return to the `player` provider for implementing feedback
4. Continue this cycle for the specified number of turns (see the sketch below)
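
In code, the per-turn switch amounts to deriving a role-specific config before each agent is constructed; that is what `Config::for_player()` and `Config::for_coach()` (added in this change) are for. The sketch below is illustrative only: the `Config` and `Agent` stand-ins and the loop shape are hypothetical simplifications, not the real autonomous loop.

```rust
// Illustrative turn loop. for_player()/for_coach() mirror the real accessors;
// Config and Agent here are hypothetical stand-ins, not the g3 types.
struct Config { provider: String }

impl Config {
    fn for_player(&self) -> Config { Config { provider: "anthropic".into() } }  // player = "anthropic"
    fn for_coach(&self) -> Config { Config { provider: "databricks".into() } }  // coach = "databricks"
}

struct Agent { provider: String }
impl Agent {
    fn new(config: &Config) -> Agent { Agent { provider: config.provider.clone() } }
}

fn main() {
    let config = Config { provider: "databricks".into() };
    for turn in 1..=3 {
        let player = Agent::new(&config.for_player());
        println!("turn {turn}: player implements with {}", player.provider);
        let coach = Agent::new(&config.for_coach());
        println!("turn {turn}: coach reviews with {}", coach.provider);
        // A coach reply containing IMPLEMENTATION_APPROVED would break out here.
    }
}
```

In the real code each role also gets a fresh agent instance (and therefore fresh context) per turn, as the coach-agent construction in the diff above shows.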
The providers are logged at startup so you can verify which models are being used:
```
🎮 Player provider: anthropic
👨‍🏫 Coach provider: databricks
Using different providers for player and coach
```
## Benefits
- **Cost Efficiency**: Use expensive models only where they add the most value
- **Speed Optimization**: Use faster models for iterative development
- **Specialization**: Leverage models that excel at specific tasks
- **Flexibility**: Easy to experiment with different provider combinations

test-ai-requirements.sh Executable file
View File

@@ -0,0 +1,39 @@
#!/bin/bash
# Test script for AI-enhanced interactive requirements mode
echo "Testing AI-enhanced interactive requirements mode..."
echo ""
# Create a test workspace
TEST_WORKSPACE="/tmp/g3-test-interactive-$(date +%s)"
mkdir -p "$TEST_WORKSPACE"
echo "Test workspace: $TEST_WORKSPACE"
echo ""
# Create sample brief input
BRIEF_INPUT="build a calculator cli in rust with basic operations"
echo "Brief input:"
echo "---"
echo "$BRIEF_INPUT"
echo "---"
echo ""
echo "This will:"
echo "1. Send brief input to AI"
echo "2. AI generates structured requirements.md"
echo "3. Show enhanced requirements"
echo "4. Prompt for confirmation (y/e/n)"
echo ""
echo "To test manually, run:"
echo "cargo run -- --autonomous --interactive-requirements --workspace $TEST_WORKSPACE"
echo ""
echo "Then type: $BRIEF_INPUT"
echo "Press Ctrl+D"
echo "Review the AI-generated requirements"
echo "Choose 'y' to proceed, 'e' to edit, or 'n' to cancel"
echo ""
echo "Test workspace will be at: $TEST_WORKSPACE"