Merge branch 'main' into micn/fix-anthropic-1p
* main:
  fix panic in CLI parser
  coach/player provider split + add OpenAI
@@ -103,11 +103,14 @@ fn extract_coach_feedback_from_logs(
    coach_result: &g3_core::TaskResult,
    coach_agent: &g3_core::Agent<ConsoleUiWriter>,
    output: &SimpleOutput,
-) -> String {
+) -> Result<String> {
    // CORRECT APPROACH: Get the session ID from the current coach agent
    // and read its specific log file directly

    // Get the coach agent's session ID
    let session_id = coach_agent
        .get_session_id()
-        .expect("Coach agent has no session ID");
+        .ok_or_else(|| anyhow::anyhow!("Coach agent has no session ID"))?;

    // Construct the log file path for this specific coach session
    let logs_dir = std::path::Path::new("logs");
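Aside: the `.expect(...)` → `.ok_or_else(...)?` change swaps a hard panic for a recoverable error. A minimal standalone sketch of the pattern (the function and names here are illustrative, not from the codebase):

use anyhow::{anyhow, Result};

fn require_session_id(maybe_id: Option<String>) -> Result<String> {
    // Before: maybe_id.expect("...") would abort the whole process on None.
    // After: None becomes an anyhow::Error the caller can propagate with `?`.
    maybe_id.ok_or_else(|| anyhow!("Coach agent has no session ID"))
}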
@@ -120,75 +123,15 @@ fn extract_coach_feedback_from_logs(
    if let Some(context_window) = log_json.get("context_window") {
        if let Some(conversation_history) = context_window.get("conversation_history") {
            if let Some(messages) = conversation_history.as_array() {
-                // Look for the last assistant message (regardless of tool used)
-                for message in messages.iter().rev() {
-                    if let Some(role) = message.get("role") {
-                        if role.as_str() == Some("assistant") {
-                            if let Some(content) = message.get("content") {
-                                if let Some(content_str) = content.as_str() {
-                                    // First, check if this is plain text feedback (no tool call)
-                                    // This happens when the coach returns final feedback directly
-                                    if !content_str.contains("{\"tool\"") {
-                                        let trimmed = content_str.trim();
-                                        if !trimmed.is_empty() {
-                                            output.print(&format!(
-                                                "✅ Extracted coach feedback from session: {} ({} chars) [plain text]",
-                                                session_id,
-                                                trimmed.len()
-                                            ));
-                                            return trimmed.to_string();
-                                        }
-                                    }
-
-                                    // Look for ANY tool call in the message
-                                    // Pattern: {"tool": "...", "args": {...}}
-                                    if let Some(tool_start) = content_str.find("{\"tool\"") {
-                                        let json_part = &content_str[tool_start..];
-
-                                        // Find the end of the JSON object
-                                        if let Some(json_end) = find_json_end(json_part) {
-                                            let json_str = &json_part[..json_end];
-
-                                            if let Ok(tool_call) = serde_json::from_str::<serde_json::Value>(json_str) {
-                                                if let Some(args) = tool_call.get("args") {
-                                                    // Try to extract feedback from different possible fields
-                                                    let feedback = if let Some(summary) = args.get("summary") {
-                                                        // final_output tool uses "summary"
-                                                        summary.as_str().map(|s| s.to_string())
-                                                    } else if let Some(content) = args.get("content") {
-                                                        // todo_write and other tools might use "content"
-                                                        content.as_str().map(|s| s.to_string())
-                                                    } else {
-                                                        // Fallback: use the entire args as JSON string
-                                                        Some(serde_json::to_string_pretty(args).unwrap_or_default())
-                                                    };
-
-                                                    if let Some(feedback_str) = feedback {
-                                                        if !feedback_str.trim().is_empty() {
-                                                            output.print(&format!(
-                                                                "✅ Extracted coach feedback from session: {} ({} chars)",
-                                                                session_id,
-                                                                feedback_str.len()
-                                                            ));
-
-                                                            // Validate feedback length
-                                                            if feedback_str.len() < 80 && !feedback_str.contains("IMPLEMENTATION_APPROVED") {
-                                                                panic!(
-                                                                    "Coach feedback is too short ({} chars): '{}'",
-                                                                    feedback_str.len(),
-                                                                    feedback_str
-                                                                );
-                                                            }
-
-                                                            return feedback_str;
-                                                        }
-                                                    }
-                                                }
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
+                // Simply get the last message content - this is the coach's final feedback
+                if let Some(last_message) = messages.last() {
+                    if let Some(content) = last_message.get("content") {
+                        if let Some(content_str) = content.as_str() {
+                            output.print(&format!(
+                                "✅ Extracted coach feedback from session: {}",
+                                session_id
+                            ));
+                            return Ok(content_str.to_string());
+                        }
+                    }
+                }
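Aside: the log structure this function walks can be read off the field accesses above. A minimal sketch of the simplified extraction path (the surrounding log schema beyond these fields is an assumption):

use serde_json::json;

fn main() {
    // Hypothetical session log carrying only the fields the extractor reads.
    let log_json = json!({
        "context_window": {
            "conversation_history": [
                { "role": "user", "content": "Review the implementation." },
                { "role": "assistant", "content": "IMPLEMENTATION_APPROVED" }
            ]
        }
    });

    // Mirrors the new code path: take the last message's content as the feedback.
    let feedback = log_json["context_window"]["conversation_history"]
        .as_array()
        .and_then(|msgs| msgs.last())
        .and_then(|msg| msg["content"].as_str());
    assert_eq!(feedback, Some("IMPLEMENTATION_APPROVED"));
}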
@@ -213,35 +156,6 @@ fn extract_coach_feedback_from_logs(
    );
}

/// Helper function to find the end of a JSON object using brace counting
fn find_json_end(json_str: &str) -> Option<usize> {
    let mut depth = 0;
    let mut in_string = false;
    let mut escape_next = false;

    for (i, ch) in json_str.char_indices() {
        if escape_next {
            escape_next = false;
            continue;
        }

        match ch {
            '\\' if in_string => escape_next = true,
            '"' => in_string = !in_string,
            '{' if !in_string => depth += 1,
            '}' if !in_string => {
                depth -= 1;
                if depth == 0 {
                    return Some(i + 1);
                }
            }
            _ => {}
        }
    }

    None
}

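Aside: a quick usage sketch of the helper on mixed model output (assumes `find_json_end` from the diff above is in scope):

fn main() {
    let content = r#"Some prose, then {"tool": "final_output", "args": {"summary": "ok"}} trailing text"#;
    if let Some(start) = content.find("{\"tool\"") {
        let rest = &content[start..];
        // Brace counting skips quotes, so braces inside string values
        // do not end the scan early.
        if let Some(end) = find_json_end(rest) {
            assert_eq!(&rest[..end], r#"{"tool": "final_output", "args": {"summary": "ok"}}"#);
        }
    }
}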
use clap::Parser;
use g3_config::Config;
use g3_core::{project::Project, ui_writer::UiWriter, Agent};
@@ -321,10 +235,6 @@ pub struct Cli {
    /// Disable log file creation (no logs/ directory or session logs)
    #[arg(long)]
    pub quiet: bool,

    /// Enable WebDriver tools for browser automation (Safari)
    #[arg(long)]
    pub webdriver: bool,
}

@@ -413,17 +323,12 @@ pub async fn run() -> Result<()> {
    }

    // Load configuration with CLI overrides
-    let mut config = Config::load_with_overrides(
+    let config = Config::load_with_overrides(
        cli.config.as_deref(),
        cli.provider.clone(),
        cli.model.clone(),
    )?;

-    // Override webdriver setting from CLI flag
-    if cli.webdriver {
-        config.webdriver.enabled = true;
-    }
-
    // Validate provider if specified
    if let Some(ref provider) = cli.provider {
        let valid_providers = ["anthropic", "databricks", "embedded", "openai"];
@@ -1358,10 +1263,6 @@ async fn run_autonomous(
    loop {
        let turn_start_time = Instant::now();
        let turn_start_tokens = agent.get_context_window().used_tokens;

        // Reset filter suppression state at the start of each turn
        g3_core::fixed_filter_json::reset_fixed_json_tool_state();

        // Skip player turn if it's the first turn and implementation files exist
        if !(turn == 1 && skip_first_player) {
            output.print(&format!(
@@ -1522,14 +1423,15 @@ async fn run_autonomous(

        // Create a new agent instance for coach mode to ensure fresh context
        // Use the same config with overrides that was passed to the player agent
-        let config = agent.get_config().clone();
+        let base_config = agent.get_config().clone();
+        let coach_config = base_config.for_coach()?;

        // Reset filter suppression state before creating coach agent
        g3_core::fixed_filter_json::reset_fixed_json_tool_state();

        let ui_writer = ConsoleUiWriter::new();
        let mut coach_agent =
-            Agent::new_autonomous_with_readme_and_quiet(config, ui_writer, None, quiet).await?;
+            Agent::new_autonomous_with_readme_and_quiet(coach_config, ui_writer, None, quiet).await?;

        // Ensure coach agent is also in the workspace directory
        project.enter_workspace()?;
@@ -1677,7 +1579,7 @@ Remember: Be clear in your review and concise in your feedback. APPROVE if the i

        // Extract the complete coach feedback from final_output
        let coach_feedback_text =
-            extract_coach_feedback_from_logs(&coach_result, &coach_agent, &output);
+            extract_coach_feedback_from_logs(&coach_result, &coach_agent, &output)?;

        // Log the size of the feedback for debugging
        info!(
@@ -1704,15 +1606,6 @@ Remember: Be clear in your review and concise in your feedback. APPROVE if the i

        output.print_smart(&format!("Coach feedback:\n{}", coach_feedback_text));

-        // Record turn metrics before checking for approval or max turns
-        let turn_duration = turn_start_time.elapsed();
-        let turn_tokens = agent.get_context_window().used_tokens.saturating_sub(turn_start_tokens);
-        turn_metrics.push(TurnMetrics {
-            turn_number: turn,
-            tokens_used: turn_tokens,
-            wall_clock_time: turn_duration,
-        });
-
        // Check if coach approved the implementation
        if coach_result.is_approved() || coach_feedback_text.contains("IMPLEMENTATION_APPROVED") {
            output.print("\n=== SESSION COMPLETED - IMPLEMENTATION APPROVED ===");
@@ -1721,7 +1614,6 @@ Remember: Be clear in your review and concise in your feedback. APPROVE if the i
            break;
        }

-        // Increment turn counter after recording metrics but before checking max turns
        // Check if we've reached max turns
        if turn >= max_turns {
            output.print("\n=== SESSION COMPLETED - MAX TURNS REACHED ===");
@@ -1731,7 +1623,14 @@ Remember: Be clear in your review and concise in your feedback. APPROVE if the i

        // Store coach feedback for next iteration
        coach_feedback = coach_feedback_text;

+        // Record turn metrics before incrementing
+        let turn_duration = turn_start_time.elapsed();
+        let turn_tokens = agent.get_context_window().used_tokens.saturating_sub(turn_start_tokens);
+        turn_metrics.push(TurnMetrics {
+            turn_number: turn,
+            tokens_used: turn_tokens,
+            wall_clock_time: turn_duration,
+        });
        turn += 1;

        output.print("🔄 Coach provided feedback for next iteration");

@@ -12,3 +12,6 @@ thiserror = { workspace = true }
toml = "0.8"
shellexpand = "3.0"
dirs = "5.0"
+
+[dev-dependencies]
+tempfile = "3.8"

@@ -17,6 +17,8 @@ pub struct ProvidersConfig {
    pub databricks: Option<DatabricksConfig>,
    pub embedded: Option<EmbeddedConfig>,
    pub default_provider: String,
+    pub coach: Option<String>,  // Provider to use for coach in autonomous mode
+    pub player: Option<String>, // Provider to use for player in autonomous mode
}

#[derive(Debug, Clone, Serialize, Deserialize)]

@@ -112,6 +114,8 @@ impl Default for Config {
            }),
            embedded: None,
            default_provider: "databricks".to_string(),
+            coach: None,  // Will use default_provider if not specified
+            player: None, // Will use default_provider if not specified
        },
        agent: AgentConfig {
            max_context_length: 8192,
@@ -224,6 +228,8 @@ impl Config {
            threads: Some(8),
        }),
        default_provider: "embedded".to_string(),
+        coach: None,  // Will use default_provider if not specified
+        player: None, // Will use default_provider if not specified
    },
    agent: AgentConfig {
        max_context_length: 8192,
@@ -300,4 +306,67 @@ impl Config {

        Ok(config)
    }

+    /// Get the provider to use for coach mode in autonomous execution
+    pub fn get_coach_provider(&self) -> &str {
+        self.providers.coach
+            .as_deref()
+            .unwrap_or(&self.providers.default_provider)
+    }
+
+    /// Get the provider to use for player mode in autonomous execution
+    pub fn get_player_provider(&self) -> &str {
+        self.providers.player
+            .as_deref()
+            .unwrap_or(&self.providers.default_provider)
+    }
+
+    /// Create a copy of the config with a different default provider
+    pub fn with_provider_override(&self, provider: &str) -> Result<Self> {
+        // Validate that the provider is configured
+        match provider {
+            "anthropic" if self.providers.anthropic.is_none() => {
+                return Err(anyhow::anyhow!(
+                    "Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
+                    provider, provider
+                ));
+            }
+            "databricks" if self.providers.databricks.is_none() => {
+                return Err(anyhow::anyhow!(
+                    "Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
+                    provider, provider
+                ));
+            }
+            "embedded" if self.providers.embedded.is_none() => {
+                return Err(anyhow::anyhow!(
+                    "Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
+                    provider, provider
+                ));
+            }
+            "openai" if self.providers.openai.is_none() => {
+                return Err(anyhow::anyhow!(
+                    "Provider '{}' is specified but not configured. Please add {} configuration to your config file.",
+                    provider, provider
+                ));
+            }
+            _ => {} // Provider is configured or unknown (will be caught later)
+        }
+
+        let mut config = self.clone();
+        config.providers.default_provider = provider.to_string();
+        Ok(config)
+    }
+
+    /// Create a copy of the config for coach mode in autonomous execution
+    pub fn for_coach(&self) -> Result<Self> {
+        self.with_provider_override(self.get_coach_provider())
+    }
+
+    /// Create a copy of the config for player mode in autonomous execution
+    pub fn for_player(&self) -> Result<Self> {
+        self.with_provider_override(self.get_player_provider())
+    }
}

+#[cfg(test)]
+mod tests;

crates/g3-config/src/tests.rs (new file, 131 lines)
@@ -0,0 +1,131 @@
#[cfg(test)]
mod tests {
    use crate::Config;
    use std::fs;
    use tempfile::TempDir;

    #[test]
    fn test_coach_player_providers() {
        // Create a temporary directory for the test config
        let temp_dir = TempDir::new().unwrap();
        let config_path = temp_dir.path().join("test_config.toml");

        // Write a test configuration with coach and player providers
        let config_content = r#"
[providers]
default_provider = "databricks"
coach = "anthropic"
player = "embedded"

[providers.databricks]
host = "https://test.databricks.com"
token = "test-token"
model = "test-model"

[providers.anthropic]
api_key = "test-key"
model = "claude-3"

[providers.embedded]
model_path = "test.gguf"
model_type = "llama"

[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
"#;

        fs::write(&config_path, config_content).unwrap();

        // Load the configuration
        let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();

        // Test that the providers are correctly identified
        assert_eq!(config.providers.default_provider, "databricks");
        assert_eq!(config.get_coach_provider(), "anthropic");
        assert_eq!(config.get_player_provider(), "embedded");

        // Test creating coach config
        let coach_config = config.for_coach().unwrap();
        assert_eq!(coach_config.providers.default_provider, "anthropic");

        // Test creating player config
        let player_config = config.for_player().unwrap();
        assert_eq!(player_config.providers.default_provider, "embedded");
    }

    #[test]
    fn test_coach_player_fallback_to_default() {
        // Create a temporary directory for the test config
        let temp_dir = TempDir::new().unwrap();
        let config_path = temp_dir.path().join("test_config.toml");

        // Write a test configuration WITHOUT coach and player providers
        let config_content = r#"
[providers]
default_provider = "databricks"

[providers.databricks]
host = "https://test.databricks.com"
token = "test-token"
model = "test-model"

[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
"#;

        fs::write(&config_path, config_content).unwrap();

        // Load the configuration
        let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();

        // Test that coach and player fall back to default provider
        assert_eq!(config.get_coach_provider(), "databricks");
        assert_eq!(config.get_player_provider(), "databricks");

        // Test creating coach config (should use default)
        let coach_config = config.for_coach().unwrap();
        assert_eq!(coach_config.providers.default_provider, "databricks");

        // Test creating player config (should use default)
        let player_config = config.for_player().unwrap();
        assert_eq!(player_config.providers.default_provider, "databricks");
    }

    #[test]
    fn test_invalid_provider_error() {
        // Create a temporary directory for the test config
        let temp_dir = TempDir::new().unwrap();
        let config_path = temp_dir.path().join("test_config.toml");

        // Write a test configuration with an unconfigured provider
        let config_content = r#"
[providers]
default_provider = "databricks"
coach = "openai" # OpenAI is not configured

[providers.databricks]
host = "https://test.databricks.com"
token = "test-token"
model = "test-model"

[agent]
max_context_length = 8192
enable_streaming = true
timeout_seconds = 60
"#;

        fs::write(&config_path, config_content).unwrap();

        // Load the configuration
        let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();

        // Test that trying to create a coach config with unconfigured provider fails
        let result = config.for_coach();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not configured"));
    }
}
@@ -625,13 +625,32 @@ impl<W: UiWriter> Agent<W> {
    ) -> Result<Self> {
        let mut providers = ProviderRegistry::new();

+        // In autonomous mode, we need to register both coach and player providers
+        // Otherwise, only register the default provider
+        let providers_to_register: Vec<String> = if is_autonomous {
+            let mut providers = vec![config.providers.default_provider.clone()];
+            if let Some(coach) = &config.providers.coach {
+                if !providers.contains(coach) {
+                    providers.push(coach.clone());
+                }
+            }
+            if let Some(player) = &config.providers.player {
+                if !providers.contains(player) {
+                    providers.push(player.clone());
+                }
+            }
+            providers
+        } else {
+            vec![config.providers.default_provider.clone()]
+        };
+
        // Only register providers that are configured AND selected as the default provider
        // This prevents unnecessary initialization of heavy providers like embedded models

        // Register embedded provider if configured AND it's the default provider
        if let Some(embedded_config) = &config.providers.embedded {
-            if config.providers.default_provider == "embedded" {
-                info!("Initializing embedded provider (selected as default)");
+            if providers_to_register.contains(&"embedded".to_string()) {
+                info!("Initializing embedded provider");
                let embedded_provider = g3_providers::EmbeddedProvider::new(
                    embedded_config.model_path.clone(),
                    embedded_config.model_type.clone(),
@@ -643,14 +662,31 @@ impl<W: UiWriter> Agent<W> {
                )?;
                providers.register(embedded_provider);
            } else {
-                info!("Embedded provider configured but not selected as default, skipping initialization");
+                info!("Embedded provider configured but not needed, skipping initialization");
            }
        }

+        // Register OpenAI provider if configured AND it's the default provider
+        if let Some(openai_config) = &config.providers.openai {
+            if providers_to_register.contains(&"openai".to_string()) {
+                info!("Initializing OpenAI provider");
+                let openai_provider = g3_providers::OpenAIProvider::new(
+                    openai_config.api_key.clone(),
+                    Some(openai_config.model.clone()),
+                    openai_config.base_url.clone(),
+                    openai_config.max_tokens,
+                    openai_config.temperature,
+                )?;
+                providers.register(openai_provider);
+            } else {
+                info!("OpenAI provider configured but not needed, skipping initialization");
+            }
+        }
+
        // Register Anthropic provider if configured AND it's the default provider
        if let Some(anthropic_config) = &config.providers.anthropic {
-            if config.providers.default_provider == "anthropic" {
-                info!("Initializing Anthropic provider (selected as default)");
+            if providers_to_register.contains(&"anthropic".to_string()) {
+                info!("Initializing Anthropic provider");
                let anthropic_provider = g3_providers::AnthropicProvider::new(
                    anthropic_config.api_key.clone(),
                    Some(anthropic_config.model.clone()),
@@ -659,14 +695,14 @@ impl<W: UiWriter> Agent<W> {
                )?;
                providers.register(anthropic_provider);
            } else {
-                info!("Anthropic provider configured but not selected as default, skipping initialization");
+                info!("Anthropic provider configured but not needed, skipping initialization");
            }
        }

        // Register Databricks provider if configured AND it's the default provider
        if let Some(databricks_config) = &config.providers.databricks {
-            if config.providers.default_provider == "databricks" {
-                info!("Initializing Databricks provider (selected as default)");
+            if providers_to_register.contains(&"databricks".to_string()) {
+                info!("Initializing Databricks provider");

                let databricks_provider = if let Some(token) = &databricks_config.token {
                    // Use token-based authentication
@@ -690,7 +726,7 @@ impl<W: UiWriter> Agent<W> {

                providers.register(databricks_provider);
            } else {
-                info!("Databricks provider configured but not selected as default, skipping initialization");
+                info!("Databricks provider configured but not needed, skipping initialization");
            }
        }

@@ -773,6 +809,9 @@ impl<W: UiWriter> Agent<W> {
                    config.agent.max_context_length as u32
                }
            }
+            "openai" => {
+                192000
+            }
            "anthropic" => {
                // Claude models have large context windows
                200000 // Default for Claude models
@@ -1060,7 +1099,6 @@ Template:
        };

        // Get max_tokens from provider configuration
        // For Databricks, this should be much higher to support large file generation
        let max_tokens = match provider.name() {
            "databricks" => {
                // Use the model's maximum limit for Databricks to allow large file generation
@@ -156,8 +156,9 @@ impl AnthropicProvider {
            .post(ANTHROPIC_API_URL)
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", ANTHROPIC_VERSION)
            // Anthropic beta 1m context window. Enable if needed. It costs extra, so check first.
            // .header("anthropic-beta", "context-1m-2025-08-07")
            .header("content-type", "application/json");

        if streaming {
            builder = builder.header("accept", "text/event-stream");
        }

@@ -88,10 +88,12 @@ pub mod anthropic;
pub mod databricks;
pub mod embedded;
pub mod oauth;
+pub mod openai;

pub use anthropic::AnthropicProvider;
pub use databricks::DatabricksProvider;
pub use embedded::EmbeddedProvider;
+pub use openai::OpenAIProvider;

/// Provider registry for managing multiple LLM providers
pub struct ProviderRegistry {
crates/g3-providers/src/openai.rs (new file, 495 lines)
@@ -0,0 +1,495 @@
use anyhow::Result;
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::stream::StreamExt;
use reqwest::Client;
use serde::Deserialize;
use serde_json::json;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error};

use crate::{
    CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider,
    Message, MessageRole, Tool, ToolCall, Usage,
};

#[derive(Clone)]
pub struct OpenAIProvider {
    client: Client,
    api_key: String,
    model: String,
    base_url: String,
    max_tokens: Option<u32>,
    _temperature: Option<f32>,
}

impl OpenAIProvider {
    pub fn new(
        api_key: String,
        model: Option<String>,
        base_url: Option<String>,
        max_tokens: Option<u32>,
        temperature: Option<f32>,
    ) -> Result<Self> {
        Ok(Self {
            client: Client::new(),
            api_key,
            model: model.unwrap_or_else(|| "gpt-4o".to_string()),
            base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
            max_tokens,
            _temperature: temperature,
        })
    }
    fn create_request_body(
        &self,
        messages: &[Message],
        tools: Option<&[Tool]>,
        stream: bool,
        max_tokens: Option<u32>,
        _temperature: Option<f32>,
    ) -> serde_json::Value {
        let mut body = json!({
            "model": self.model,
            "messages": convert_messages(messages),
            "stream": stream,
        });

        if let Some(max_tokens) = max_tokens.or(self.max_tokens) {
            body["max_completion_tokens"] = json!(max_tokens);
        }

        // OpenAI calls with temp setting seem to fail, so don't send one.
        // if let Some(temperature) = temperature.or(self.temperature) {
        //     body["temperature"] = json!(temperature);
        // }

        if let Some(tools) = tools {
            if !tools.is_empty() {
                body["tools"] = json!(convert_tools(tools));
            }
        }

        if stream {
            body["stream_options"] = json!({
                "include_usage": true,
            });
        }

        body
    }
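Aside: for illustration, the body assembled for a streaming call looks roughly like this. A sketch based on the fields set above, not captured wire output; the message and token values are made up:

use serde_json::json;

fn main() {
    // Shape produced for a streaming request with a max-token cap;
    // a "tools" array would be appended only when tools are present.
    let body = json!({
        "model": "gpt-4o",
        "messages": [{ "role": "user", "content": "hello" }],
        "stream": true,
        "max_completion_tokens": 1024,
        "stream_options": { "include_usage": true },
    });
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}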
    async fn parse_streaming_response(
        &self,
        mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
        tx: mpsc::Sender<Result<CompletionChunk>>,
    ) -> Option<Usage> {
        let mut buffer = String::new();
        let mut accumulated_content = String::new();
        let mut accumulated_usage: Option<Usage> = None;
        let mut current_tool_calls: Vec<OpenAIStreamingToolCall> = Vec::new();

        while let Some(chunk_result) = stream.next().await {
            match chunk_result {
                Ok(chunk) => {
                    let chunk_str = match std::str::from_utf8(&chunk) {
                        Ok(s) => s,
                        Err(e) => {
                            error!("Failed to parse chunk as UTF-8: {}", e);
                            continue;
                        }
                    };

                    buffer.push_str(chunk_str);

                    // Process complete lines
                    while let Some(line_end) = buffer.find('\n') {
                        let line = buffer[..line_end].trim().to_string();
                        buffer.drain(..line_end + 1);

                        if line.is_empty() {
                            continue;
                        }

                        // Parse Server-Sent Events format
                        if let Some(data) = line.strip_prefix("data: ") {
                            if data == "[DONE]" {
                                debug!("Received stream completion marker");

                                // Send final chunk with accumulated content and tool calls
                                if !accumulated_content.is_empty() || !current_tool_calls.is_empty() {
                                    let tool_calls = if current_tool_calls.is_empty() {
                                        None
                                    } else {
                                        Some(
                                            current_tool_calls
                                                .iter()
                                                .filter_map(|tc| tc.to_tool_call())
                                                .collect(),
                                        )
                                    };

                                    let final_chunk = CompletionChunk {
                                        content: accumulated_content.clone(),
                                        finished: true,
                                        tool_calls,
                                        usage: accumulated_usage.clone(),
                                    };
                                    let _ = tx.send(Ok(final_chunk)).await;
                                }

                                return accumulated_usage;
                            }

                            // Parse the JSON data
                            match serde_json::from_str::<OpenAIStreamChunk>(data) {
                                Ok(chunk_data) => {
                                    // Handle content
                                    for choice in &chunk_data.choices {
                                        if let Some(content) = &choice.delta.content {
                                            accumulated_content.push_str(content);

                                            let chunk = CompletionChunk {
                                                content: content.clone(),
                                                finished: false,
                                                tool_calls: None,
                                                usage: None,
                                            };
                                            if tx.send(Ok(chunk)).await.is_err() {
                                                debug!("Receiver dropped, stopping stream");
                                                return accumulated_usage;
                                            }
                                        }

                                        // Handle tool calls
                                        if let Some(delta_tool_calls) = &choice.delta.tool_calls {
                                            for delta_tool_call in delta_tool_calls {
                                                if let Some(index) = delta_tool_call.index {
                                                    // Ensure we have enough tool calls in our vector
                                                    while current_tool_calls.len() <= index {
                                                        current_tool_calls
                                                            .push(OpenAIStreamingToolCall::default());
                                                    }

                                                    let tool_call = &mut current_tool_calls[index];

                                                    if let Some(id) = &delta_tool_call.id {
                                                        tool_call.id = Some(id.clone());
                                                    }

                                                    if let Some(function) = &delta_tool_call.function {
                                                        if let Some(name) = &function.name {
                                                            tool_call.name = Some(name.clone());
                                                        }
                                                        if let Some(arguments) = &function.arguments {
                                                            tool_call.arguments.push_str(arguments);
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }

                                    // Handle usage
                                    if let Some(usage) = chunk_data.usage {
                                        accumulated_usage = Some(Usage {
                                            prompt_tokens: usage.prompt_tokens,
                                            completion_tokens: usage.completion_tokens,
                                            total_tokens: usage.total_tokens,
                                        });
                                    }
                                }
                                Err(e) => {
                                    debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
                                }
                            }
                        }
                    }
                }
                Err(e) => {
                    error!("Stream error: {}", e);
                    let _ = tx.send(Err(anyhow::anyhow!("Stream error: {}", e))).await;
                    return accumulated_usage;
                }
            }
        }

        // Send final chunk if we haven't already
        let tool_calls = if current_tool_calls.is_empty() {
            None
        } else {
            Some(
                current_tool_calls
                    .iter()
                    .filter_map(|tc| tc.to_tool_call())
                    .collect(),
            )
        };

        let final_chunk = CompletionChunk {
            content: String::new(),
            finished: true,
            tool_calls,
            usage: accumulated_usage.clone(),
        };
        let _ = tx.send(Ok(final_chunk)).await;

        accumulated_usage
    }
}
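Aside: the parser above consumes OpenAI's server-sent-events framing, where each line carries a JSON chunk after a "data: " prefix and "data: [DONE]" ends the stream. A minimal sketch of the line handling on sample frames (the wire content here is illustrative):

fn main() {
    let frames = [
        r#"data: {"choices":[{"delta":{"content":"Hel"}}]}"#,
        r#"data: {"choices":[{"delta":{"content":"lo"}}]}"#,
        "data: [DONE]",
    ];
    let mut accumulated = String::new();
    for line in frames {
        // Non-SSE lines are skipped; "[DONE]" terminates the loop.
        let Some(data) = line.strip_prefix("data: ") else { continue };
        if data == "[DONE]" {
            break;
        }
        let chunk: serde_json::Value = serde_json::from_str(data).unwrap();
        if let Some(content) = chunk["choices"][0]["delta"]["content"].as_str() {
            accumulated.push_str(content);
        }
    }
    assert_eq!(accumulated, "Hello");
}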
#[async_trait]
impl LLMProvider for OpenAIProvider {
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        debug!(
            "Processing OpenAI completion request with {} messages",
            request.messages.len()
        );

        let body = self.create_request_body(
            &request.messages,
            request.tools.as_deref(),
            false,
            request.max_tokens,
            request.temperature,
        );

        debug!("Sending request to OpenAI API: model={}", self.model);

        let response = self
            .client
            .post(&format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&body)
            .send()
            .await?;

        let status = response.status();
        if !status.is_success() {
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text));
        }

        let openai_response: OpenAIResponse = response.json().await?;

        let content = openai_response
            .choices
            .first()
            .and_then(|choice| choice.message.content.clone())
            .unwrap_or_default();

        let usage = Usage {
            prompt_tokens: openai_response.usage.prompt_tokens,
            completion_tokens: openai_response.usage.completion_tokens,
            total_tokens: openai_response.usage.total_tokens,
        };

        debug!(
            "OpenAI completion successful: {} tokens generated",
            usage.completion_tokens
        );

        Ok(CompletionResponse {
            content,
            usage,
            model: self.model.clone(),
        })
    }

    async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
        debug!(
            "Processing OpenAI streaming request with {} messages",
            request.messages.len()
        );

        let body = self.create_request_body(
            &request.messages,
            request.tools.as_deref(),
            true,
            request.max_tokens,
            request.temperature,
        );

        debug!("Sending streaming request to OpenAI API: model={}", self.model);

        let response = self
            .client
            .post(&format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&body)
            .send()
            .await?;

        let status = response.status();
        if !status.is_success() {
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text));
        }

        let stream = response.bytes_stream();
        let (tx, rx) = mpsc::channel(100);

        // Spawn task to process the stream
        let provider = self.clone();
        tokio::spawn(async move {
            let usage = provider.parse_streaming_response(stream, tx).await;
            // Log the final usage if available
            if let Some(usage) = usage {
                debug!(
                    "Stream completed with usage - prompt: {}, completion: {}, total: {}",
                    usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
                );
            }
        });

        Ok(ReceiverStream::new(rx))
    }

    fn name(&self) -> &str {
        "openai"
    }

    fn model(&self) -> &str {
        &self.model
    }

    fn has_native_tool_calling(&self) -> bool {
        // OpenAI models support native tool calling
        true
    }
}
fn convert_messages(messages: &[Message]) -> Vec<serde_json::Value> {
    messages
        .iter()
        .map(|msg| {
            json!({
                "role": match msg.role {
                    MessageRole::System => "system",
                    MessageRole::User => "user",
                    MessageRole::Assistant => "assistant",
                },
                "content": msg.content,
            })
        })
        .collect()
}

fn convert_tools(tools: &[Tool]) -> Vec<serde_json::Value> {
    tools
        .iter()
        .map(|tool| {
            json!({
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.input_schema,
                }
            })
        })
        .collect()
}
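Aside: convert_tools emits OpenAI's function-calling wrapper around each tool. One converted entry looks like the following sketch; the "read_file" tool and its schema are hypothetical, not from the codebase:

use serde_json::json;

fn main() {
    // The wrapped form convert_tools produces for a single tool.
    let converted = json!({
        "type": "function",
        "function": {
            "name": "read_file",
            "description": "Read a file from the workspace",
            "parameters": {
                "type": "object",
                "properties": { "path": { "type": "string" } },
                "required": ["path"]
            }
        }
    });
    println!("{}", converted);
}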
// OpenAI API response structures
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
    usage: OpenAIUsage,
}

#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessage,
}

#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct OpenAIMessage {
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAIToolCall>>,
}

#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct OpenAIToolCall {
    id: String,
    function: OpenAIFunction,
}

#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct OpenAIFunction {
    name: String,
    arguments: String,
}

// Streaming tool call accumulator
#[derive(Debug, Default)]
struct OpenAIStreamingToolCall {
    id: Option<String>,
    name: Option<String>,
    arguments: String,
}

impl OpenAIStreamingToolCall {
    fn to_tool_call(&self) -> Option<ToolCall> {
        let id = self.id.as_ref()?;
        let name = self.name.as_ref()?;

        let args = serde_json::from_str(&self.arguments).unwrap_or(serde_json::Value::Null);

        Some(ToolCall {
            id: id.clone(),
            tool: name.clone(),
            args,
        })
    }
}

#[derive(Debug, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

// Streaming response structures
#[derive(Debug, Deserialize)]
struct OpenAIStreamChunk {
    choices: Vec<OpenAIStreamChoice>,
    usage: Option<OpenAIUsage>,
}

#[derive(Debug, Deserialize)]
struct OpenAIStreamChoice {
    delta: OpenAIDelta,
}

#[derive(Debug, Deserialize)]
struct OpenAIDelta {
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAIDeltaToolCall>>,
}

#[derive(Debug, Deserialize)]
struct OpenAIDeltaToolCall {
    index: Option<usize>,
    id: Option<String>,
    function: Option<OpenAIDeltaFunction>,
}

#[derive(Debug, Deserialize)]
struct OpenAIDeltaFunction {
    name: Option<String>,
    arguments: Option<String>,
}