error handling in autonomous mode
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -856,6 +856,7 @@ dependencies = [
|
|||||||
"g3-execution",
|
"g3-execution",
|
||||||
"g3-providers",
|
"g3-providers",
|
||||||
"llama_cpp",
|
"llama_cpp",
|
||||||
|
"rand",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
|||||||
@@ -47,6 +47,15 @@ Command-line interface:
|
|||||||
- Configuration management commands
|
- Configuration management commands
|
||||||
- Session management
|
- Session management
|
||||||
|
|
||||||
|
### Error Handling & Resilience
|
||||||
|
|
||||||
|
G3 includes robust error handling with automatic retry logic:
|
||||||
|
- **Recoverable Error Detection**: Automatically identifies recoverable errors (rate limits, network issues, server errors, timeouts)
|
||||||
|
- **Exponential Backoff with Jitter**: Implements intelligent retry delays to avoid overwhelming services
|
||||||
|
- **Detailed Error Logging**: Captures comprehensive error context including stack traces, request/response data, and session information
|
||||||
|
- **Error Persistence**: Saves detailed error logs to `logs/errors/` for post-mortem analysis
|
||||||
|
- **Graceful Degradation**: Non-recoverable errors are logged with full context before terminating
|
||||||
|
|
||||||
## Key Features
|
## Key Features
|
||||||
|
|
||||||
### Intelligent Context Management
|
### Intelligent Context Management
|
||||||
|
|||||||
@@ -23,3 +23,4 @@ shellexpand = "3.1"
|
|||||||
tokio-util = "0.7"
|
tokio-util = "0.7"
|
||||||
futures-util = "0.3"
|
futures-util = "0.3"
|
||||||
chrono = { version = "0.4", features = ["serde"] }
|
chrono = { version = "0.4", features = ["serde"] }
|
||||||
|
rand = "0.8"
|
||||||
|
|||||||
399
crates/g3-core/src/error_handling.rs
Normal file
399
crates/g3-core/src/error_handling.rs
Normal file
@@ -0,0 +1,399 @@
|
|||||||
|
//! Error handling module for G3 with retry logic and detailed logging
|
||||||
|
//!
|
||||||
|
//! This module provides:
|
||||||
|
//! - Classification of errors as recoverable or non-recoverable
|
||||||
|
//! - Retry logic with exponential backoff and jitter for recoverable errors
|
||||||
|
//! - Detailed error logging with context information
|
||||||
|
//! - Request/response capture for debugging
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::time::Duration;
|
||||||
|
use tracing::{error, info, warn};
|
||||||
|
|
||||||
|
/// Upper bound on how many times a recoverable operation is attempted in total.
const MAX_RETRY_ATTEMPTS: u32 = 3;

/// Delay used for the first retry, in milliseconds; later retries double it.
const BASE_RETRY_DELAY_MS: u64 = 1000;

/// Cap applied to the backoff delay before jitter is added, in milliseconds.
const MAX_RETRY_DELAY_MS: u64 = 10000;

/// Fraction of the capped delay (between 0.0 and 1.0) that jitter may add or remove.
const JITTER_FACTOR: f64 = 0.3;
|
||||||
|
|
||||||
|
/// Error context information for detailed logging
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ErrorContext {
|
||||||
|
/// The operation that was being performed
|
||||||
|
pub operation: String,
|
||||||
|
/// The provider being used
|
||||||
|
pub provider: String,
|
||||||
|
/// The model being used
|
||||||
|
pub model: String,
|
||||||
|
/// The last prompt sent (truncated for logging)
|
||||||
|
pub last_prompt: String,
|
||||||
|
/// Raw request data (if available)
|
||||||
|
pub raw_request: Option<String>,
|
||||||
|
/// Raw response data (if available)
|
||||||
|
pub raw_response: Option<String>,
|
||||||
|
/// Stack trace
|
||||||
|
pub stack_trace: String,
|
||||||
|
/// Timestamp
|
||||||
|
pub timestamp: u64,
|
||||||
|
/// Number of tokens in context
|
||||||
|
pub context_tokens: u32,
|
||||||
|
/// Session ID if available
|
||||||
|
pub session_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ErrorContext {
|
||||||
|
pub fn new(
|
||||||
|
operation: String,
|
||||||
|
provider: String,
|
||||||
|
model: String,
|
||||||
|
last_prompt: String,
|
||||||
|
session_id: Option<String>,
|
||||||
|
context_tokens: u32,
|
||||||
|
) -> Self {
|
||||||
|
let timestamp = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs();
|
||||||
|
|
||||||
|
// Capture stack trace
|
||||||
|
let stack_trace = std::backtrace::Backtrace::force_capture().to_string();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
operation,
|
||||||
|
provider,
|
||||||
|
model,
|
||||||
|
last_prompt: truncate_for_logging(&last_prompt, 1000),
|
||||||
|
raw_request: None,
|
||||||
|
raw_response: None,
|
||||||
|
stack_trace,
|
||||||
|
timestamp,
|
||||||
|
context_tokens,
|
||||||
|
session_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_request(mut self, request: String) -> Self {
|
||||||
|
self.raw_request = Some(truncate_for_logging(&request, 5000));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_response(mut self, response: String) -> Self {
|
||||||
|
self.raw_response = Some(truncate_for_logging(&response, 5000));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Log the error context with ERROR level
|
||||||
|
pub fn log_error(&self, error: &anyhow::Error) {
|
||||||
|
error!("=== G3 ERROR DETAILS ===");
|
||||||
|
error!("Operation: {}", self.operation);
|
||||||
|
error!("Provider: {} | Model: {}", self.provider, self.model);
|
||||||
|
error!("Error: {}", error);
|
||||||
|
error!("Timestamp: {}", self.timestamp);
|
||||||
|
error!("Session ID: {:?}", self.session_id);
|
||||||
|
error!("Context Tokens: {}", self.context_tokens);
|
||||||
|
error!("Last Prompt: {}", self.last_prompt);
|
||||||
|
|
||||||
|
if let Some(ref req) = self.raw_request {
|
||||||
|
error!("Raw Request: {}", req);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref resp) = self.raw_response {
|
||||||
|
error!("Raw Response: {}", resp);
|
||||||
|
}
|
||||||
|
|
||||||
|
error!("Stack Trace:\n{}", self.stack_trace);
|
||||||
|
error!("=== END ERROR DETAILS ===");
|
||||||
|
|
||||||
|
// Also save to error log file
|
||||||
|
self.save_to_file();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Save error context to a file for later analysis
|
||||||
|
fn save_to_file(&self) {
|
||||||
|
let logs_dir = std::path::Path::new("logs/errors");
|
||||||
|
if !logs_dir.exists() {
|
||||||
|
if let Err(e) = std::fs::create_dir_all(logs_dir) {
|
||||||
|
error!("Failed to create error logs directory: {}", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let filename = format!(
|
||||||
|
"logs/errors/error_{}_{}.json",
|
||||||
|
self.timestamp,
|
||||||
|
self.session_id.as_deref().unwrap_or("unknown")
|
||||||
|
);
|
||||||
|
|
||||||
|
match serde_json::to_string_pretty(self) {
|
||||||
|
Ok(json_content) => {
|
||||||
|
if let Err(e) = std::fs::write(&filename, json_content) {
|
||||||
|
error!("Failed to save error context to {}: {}", filename, e);
|
||||||
|
} else {
|
||||||
|
info!("Error details saved to: {}", filename);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to serialize error context: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Classification of error types
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub enum ErrorType {
|
||||||
|
/// Recoverable errors that should be retried
|
||||||
|
Recoverable(RecoverableError),
|
||||||
|
/// Non-recoverable errors that should terminate execution
|
||||||
|
NonRecoverable,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Categories of failure that are worth retrying.
#[derive(Debug, Clone, PartialEq)]
pub enum RecoverableError {
    /// The service reported that a rate limit was exceeded.
    RateLimit,
    /// A temporary network-level failure.
    NetworkError,
    /// The server answered with a 5xx-style error.
    ServerError,
    /// The model is busy or overloaded.
    ModelBusy,
    /// The request timed out.
    Timeout,
    /// A token limit was exceeded (may be recoverable after summarization).
    TokenLimit,
}
|
||||||
|
|
||||||
|
/// Classify an error as recoverable or non-recoverable
|
||||||
|
pub fn classify_error(error: &anyhow::Error) -> ErrorType {
|
||||||
|
let error_str = error.to_string().to_lowercase();
|
||||||
|
|
||||||
|
// Check for recoverable error patterns
|
||||||
|
if error_str.contains("rate limit") || error_str.contains("rate_limit") || error_str.contains("429") {
|
||||||
|
return ErrorType::Recoverable(RecoverableError::RateLimit);
|
||||||
|
}
|
||||||
|
|
||||||
|
if error_str.contains("network") || error_str.contains("connection") ||
|
||||||
|
error_str.contains("dns") || error_str.contains("refused") {
|
||||||
|
return ErrorType::Recoverable(RecoverableError::NetworkError);
|
||||||
|
}
|
||||||
|
|
||||||
|
if error_str.contains("500") || error_str.contains("502") ||
|
||||||
|
error_str.contains("503") || error_str.contains("504") ||
|
||||||
|
error_str.contains("server error") || error_str.contains("internal error") {
|
||||||
|
return ErrorType::Recoverable(RecoverableError::ServerError);
|
||||||
|
}
|
||||||
|
|
||||||
|
if error_str.contains("busy") || error_str.contains("overloaded") ||
|
||||||
|
error_str.contains("capacity") || error_str.contains("unavailable") {
|
||||||
|
return ErrorType::Recoverable(RecoverableError::ModelBusy);
|
||||||
|
}
|
||||||
|
|
||||||
|
if error_str.contains("timeout") || error_str.contains("timed out") {
|
||||||
|
return ErrorType::Recoverable(RecoverableError::Timeout);
|
||||||
|
}
|
||||||
|
|
||||||
|
if error_str.contains("token") && (error_str.contains("limit") || error_str.contains("exceeded")) {
|
||||||
|
return ErrorType::Recoverable(RecoverableError::TokenLimit);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to non-recoverable for unknown errors
|
||||||
|
ErrorType::NonRecoverable
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calculate retry delay with exponential backoff and jitter
|
||||||
|
pub fn calculate_retry_delay(attempt: u32) -> Duration {
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
|
// Exponential backoff: delay = base * 2^attempt
|
||||||
|
let base_delay = BASE_RETRY_DELAY_MS * (2_u64.pow(attempt.saturating_sub(1)));
|
||||||
|
let capped_delay = base_delay.min(MAX_RETRY_DELAY_MS);
|
||||||
|
|
||||||
|
// Add jitter to prevent thundering herd
|
||||||
|
let mut rng = rand::thread_rng();
|
||||||
|
let jitter = (capped_delay as f64 * JITTER_FACTOR * rng.gen::<f64>()) as u64;
|
||||||
|
let final_delay = if rng.gen_bool(0.5) {
|
||||||
|
capped_delay + jitter
|
||||||
|
} else {
|
||||||
|
capped_delay.saturating_sub(jitter)
|
||||||
|
};
|
||||||
|
|
||||||
|
Duration::from_millis(final_delay)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retry logic for async operations
|
||||||
|
pub async fn retry_with_backoff<F, Fut, T>(
|
||||||
|
operation_name: &str,
|
||||||
|
mut operation: F,
|
||||||
|
context: &ErrorContext,
|
||||||
|
) -> Result<T>
|
||||||
|
where
|
||||||
|
F: FnMut() -> Fut,
|
||||||
|
Fut: std::future::Future<Output = Result<T>>,
|
||||||
|
{
|
||||||
|
let mut attempt = 0;
|
||||||
|
let mut _last_error = None;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
attempt += 1;
|
||||||
|
|
||||||
|
match operation().await {
|
||||||
|
Ok(result) => {
|
||||||
|
if attempt > 1 {
|
||||||
|
info!(
|
||||||
|
"Operation '{}' succeeded after {} attempts",
|
||||||
|
operation_name, attempt
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return Ok(result);
|
||||||
|
}
|
||||||
|
Err(error) => {
|
||||||
|
let error_type = classify_error(&error);
|
||||||
|
|
||||||
|
match error_type {
|
||||||
|
ErrorType::Recoverable(recoverable_type) => {
|
||||||
|
if attempt >= MAX_RETRY_ATTEMPTS {
|
||||||
|
error!(
|
||||||
|
"Operation '{}' failed after {} attempts. Giving up.",
|
||||||
|
operation_name, attempt
|
||||||
|
);
|
||||||
|
context.clone().log_error(&error);
|
||||||
|
return Err(error);
|
||||||
|
}
|
||||||
|
|
||||||
|
let delay = calculate_retry_delay(attempt);
|
||||||
|
warn!(
|
||||||
|
"Recoverable error ({:?}) in '{}' (attempt {}/{}). Retrying in {:?}...",
|
||||||
|
recoverable_type, operation_name, attempt, MAX_RETRY_ATTEMPTS, delay
|
||||||
|
);
|
||||||
|
warn!("Error details: {}", error);
|
||||||
|
|
||||||
|
// Special handling for token limit errors
|
||||||
|
if matches!(recoverable_type, RecoverableError::TokenLimit) {
|
||||||
|
info!("Token limit error detected. Consider triggering summarization.");
|
||||||
|
}
|
||||||
|
|
||||||
|
tokio::time::sleep(delay).await;
|
||||||
|
_last_error = Some(error);
|
||||||
|
}
|
||||||
|
ErrorType::NonRecoverable => {
|
||||||
|
error!(
|
||||||
|
"Non-recoverable error in '{}' (attempt {}). Terminating.",
|
||||||
|
operation_name, attempt
|
||||||
|
);
|
||||||
|
context.clone().log_error(&error);
|
||||||
|
return Err(error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shortens `s` for logging so that at most `max_len` bytes of it are kept.
///
/// Strings no longer than `max_len` bytes are returned unchanged. Longer
/// strings are cut and suffixed with a note giving the original length in
/// bytes (the note's wording says "chars" for historical reasons).
///
/// The cut point is moved back to the nearest UTF-8 character boundary, so a
/// multi-byte character straddling `max_len` is dropped rather than causing a
/// slicing panic.
fn truncate_for_logging(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }

    // Index 0 is always a boundary, so this loop terminates.
    let mut cut = max_len;
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }

    format!("{}... (truncated, {} total chars)", &s[..cut], s.len())
}
|
||||||
|
|
||||||
|
/// Shorthand for `ErrorContext::new`.
///
/// The operation, provider, model and prompt arguments are converted with
/// `.to_string()`; the session id and token count are passed through as given.
#[macro_export]
macro_rules! error_context {
    ($op:expr, $prov:expr, $mdl:expr, $last_prompt:expr, $session:expr, $token_count:expr) => {
        $crate::error_handling::ErrorContext::new(
            $op.to_string(),
            $prov.to_string(),
            $mdl.to_string(),
            $last_prompt.to_string(),
            $session,
            $token_count,
        )
    };
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Representative messages must map to the expected category.
    #[test]
    fn test_error_classification() {
        let cases = [
            // Rate limit errors
            ("Rate limit exceeded", ErrorType::Recoverable(RecoverableError::RateLimit)),
            ("HTTP 429 Too Many Requests", ErrorType::Recoverable(RecoverableError::RateLimit)),
            // Network errors
            ("Network connection failed", ErrorType::Recoverable(RecoverableError::NetworkError)),
            // Server errors
            ("HTTP 503 Service Unavailable", ErrorType::Recoverable(RecoverableError::ServerError)),
            // Model busy
            ("Model is busy, please try again", ErrorType::Recoverable(RecoverableError::ModelBusy)),
            // Timeout
            ("Request timed out", ErrorType::Recoverable(RecoverableError::Timeout)),
            // Token limit
            ("Token limit exceeded", ErrorType::Recoverable(RecoverableError::TokenLimit)),
            // Non-recoverable
            ("Invalid API key", ErrorType::NonRecoverable),
            ("Malformed request", ErrorType::NonRecoverable),
        ];

        for (message, expected) in cases.iter() {
            let error = anyhow!("{}", message);
            assert_eq!(classify_error(&error), *expected);
        }
    }

    /// Delays grow with the attempt number and respect the jittered cap.
    #[test]
    fn test_retry_delay_calculation() {
        let first = calculate_retry_delay(1).as_millis();
        let second = calculate_retry_delay(2).as_millis();
        let third = calculate_retry_delay(3).as_millis();

        // Jitter rules out exact values, so only check the allowed band.
        let lowest = (BASE_RETRY_DELAY_MS as f64 * 0.7) as u128;
        let highest = (BASE_RETRY_DELAY_MS as f64 * 1.3) as u128;
        assert!(first >= lowest);
        assert!(first <= highest);

        // Each step doubles the base, which outweighs the jitter.
        assert!(second >= first);
        assert!(third >= second);

        // Far beyond the cap, only the jitter may push past the maximum.
        let capped = calculate_retry_delay(10).as_millis();
        assert!(capped <= (MAX_RETRY_DELAY_MS as f64 * 1.3) as u128);
    }

    /// Short strings pass through; long ones are cut and annotated.
    #[test]
    fn test_truncate_for_logging() {
        assert_eq!(truncate_for_logging("Hello, world!", 20), "Hello, world!");

        let shortened = truncate_for_logging(
            "This is a very long text that should be truncated for logging purposes",
            20,
        );
        assert!(shortened.starts_with("This is a very long "));
        assert!(shortened.contains("truncated"));
        assert!(shortened.contains("total chars"));
    }
}
|
||||||
148
crates/g3-core/src/error_handling_test.rs
Normal file
148
crates/g3-core/src/error_handling_test.rs
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
//! Integration tests for error handling with retry logic
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::super::error_handling::*;
    use anyhow::anyhow;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Placeholder context shared by the retry tests.
    fn retry_test_context() -> ErrorContext {
        ErrorContext::new(
            "test_operation".to_string(),
            "test_provider".to_string(),
            "test_model".to_string(),
            "test prompt".to_string(),
            None,
            100,
        )
    }

    /// Two recoverable failures followed by a success yield `Ok` after three calls.
    #[tokio::test]
    async fn test_retry_with_recoverable_error() {
        let attempt_count = Arc::new(AtomicU32::new(0));
        let context = retry_test_context();

        let result = retry_with_backoff(
            "test_operation",
            || {
                let counter = Arc::clone(&attempt_count);
                async move {
                    // The first two calls fail with a recoverable error; the third succeeds.
                    match counter.fetch_add(1, Ordering::SeqCst) {
                        0 | 1 => Err(anyhow!("Rate limit exceeded")),
                        _ => Ok("Success"),
                    }
                }
            },
            &context,
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Success");
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
    }

    /// A non-recoverable error is returned after a single call.
    #[tokio::test]
    async fn test_retry_with_non_recoverable_error() {
        let attempt_count = Arc::new(AtomicU32::new(0));
        let context = retry_test_context();

        let result: Result<&str, _> = retry_with_backoff(
            "test_operation",
            || {
                let counter = Arc::clone(&attempt_count);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    // Always fail with non-recoverable error
                    Err(anyhow!("Invalid API key"))
                }
            },
            &context,
        )
        .await;

        assert!(result.is_err());
        assert_eq!(attempt_count.load(Ordering::SeqCst), 1); // Should only try once
    }

    /// A persistent recoverable error uses up the whole attempt budget.
    #[tokio::test]
    async fn test_retry_exhaustion() {
        let attempt_count = Arc::new(AtomicU32::new(0));
        let context = retry_test_context();

        let result: Result<&str, _> = retry_with_backoff(
            "test_operation",
            || {
                let counter = Arc::clone(&attempt_count);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    // Always fail with recoverable error
                    Err(anyhow!("Network connection failed"))
                }
            },
            &context,
        )
        .await;

        assert!(result.is_err());
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3); // Should try MAX_RETRY_ATTEMPTS times
    }

    /// Over-long prompts are shortened when the context is built.
    #[test]
    fn test_error_context_truncation() {
        let context = ErrorContext::new(
            "test_op".to_string(),
            "provider".to_string(),
            "model".to_string(),
            "a".repeat(2000),
            None,
            100,
        );

        // 1000 bytes of prompt plus some room for the truncation note.
        assert!(context.last_prompt.len() < 1100);
        assert!(context.last_prompt.contains("truncated"));
    }

    /// Each attempt's delay stays within the jitter band around its base delay.
    #[test]
    fn test_retry_delay_increases() {
        // (attempt, un-jittered base delay in ms); the base doubles per attempt
        // starting from BASE_RETRY_DELAY_MS.
        let expectations = [(1u32, 1000u64), (2, 1000 * 2), (3, 1000 * 4)];

        for (attempt, base) in expectations.iter() {
            let millis = calculate_retry_delay(*attempt).as_millis();
            assert!(millis >= (*base as f64 * 0.7) as u128);
            assert!(millis <= (*base as f64 * 1.3) as u128);
        }
    }
}
|
||||||
@@ -1,4 +1,8 @@
|
|||||||
pub mod project;
|
pub mod project;
|
||||||
|
pub mod error_handling;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod error_handling_test;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use g3_config::Config;
|
use g3_config::Config;
|
||||||
use g3_execution::CodeExecutor;
|
use g3_execution::CodeExecutor;
|
||||||
@@ -965,10 +969,51 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Helper method to stream with retry logic
|
||||||
|
async fn stream_with_retry(
|
||||||
|
&self,
|
||||||
|
request: &CompletionRequest,
|
||||||
|
error_context: &error_handling::ErrorContext,
|
||||||
|
) -> Result<g3_providers::CompletionStream> {
|
||||||
|
use crate::error_handling::{classify_error, calculate_retry_delay, ErrorType};
|
||||||
|
|
||||||
|
let mut attempt = 0;
|
||||||
|
const MAX_ATTEMPTS: u32 = 3;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
attempt += 1;
|
||||||
|
let provider = self.providers.get(None)?;
|
||||||
|
|
||||||
|
match provider.stream(request.clone()).await {
|
||||||
|
Ok(stream) => {
|
||||||
|
if attempt > 1 {
|
||||||
|
info!("Stream started successfully after {} attempts", attempt);
|
||||||
|
}
|
||||||
|
return Ok(stream);
|
||||||
|
}
|
||||||
|
Err(e) if attempt < MAX_ATTEMPTS => {
|
||||||
|
if matches!(classify_error(&e), ErrorType::Recoverable(_)) {
|
||||||
|
let delay = calculate_retry_delay(attempt);
|
||||||
|
warn!("Recoverable error on attempt {}/{}: {}. Retrying in {:?}...", attempt, MAX_ATTEMPTS, e, delay);
|
||||||
|
tokio::time::sleep(delay).await;
|
||||||
|
} else {
|
||||||
|
error_context.clone().log_error(&e);
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error_context.clone().log_error(&e);
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn stream_completion_with_tools(
|
async fn stream_completion_with_tools(
|
||||||
&mut self,
|
&mut self,
|
||||||
mut request: CompletionRequest,
|
mut request: CompletionRequest,
|
||||||
) -> Result<(String, Duration)> {
|
) -> Result<(String, Duration)> {
|
||||||
|
use crate::error_handling::ErrorContext;
|
||||||
use std::io::{self, Write};
|
use std::io::{self, Write};
|
||||||
use tokio_stream::StreamExt;
|
use tokio_stream::StreamExt;
|
||||||
|
|
||||||
@@ -1119,16 +1164,39 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
|
|
||||||
let provider = self.providers.get(None)?;
|
let provider = self.providers.get(None)?;
|
||||||
debug!("Got provider: {}", provider.name());
|
debug!("Got provider: {}", provider.name());
|
||||||
let mut stream = match provider.stream(request.clone()).await {
|
|
||||||
|
// Create error context for detailed logging
|
||||||
|
let last_prompt = request.messages
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.find(|m| matches!(m.role, MessageRole::User))
|
||||||
|
.map(|m| m.content.clone())
|
||||||
|
.unwrap_or_else(|| "No user message found".to_string());
|
||||||
|
|
||||||
|
let error_context = ErrorContext::new(
|
||||||
|
"stream_completion".to_string(),
|
||||||
|
provider.name().to_string(),
|
||||||
|
provider.model().to_string(),
|
||||||
|
last_prompt,
|
||||||
|
self.session_id.clone(),
|
||||||
|
self.context_window.used_tokens,
|
||||||
|
).with_request(
|
||||||
|
serde_json::to_string(&request).unwrap_or_else(|_| "Failed to serialize request".to_string())
|
||||||
|
);
|
||||||
|
|
||||||
|
// Try to get stream with retry logic
|
||||||
|
let mut stream = match self.stream_with_retry(&request, &error_context).await {
|
||||||
Ok(s) => s,
|
Ok(s) => s,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
// Additional retry for "busy" errors on subsequent iterations
|
||||||
if iteration_count > 1 && e.to_string().contains("busy") {
|
if iteration_count > 1 && e.to_string().contains("busy") {
|
||||||
warn!(
|
warn!(
|
||||||
"Model busy on iteration {}, retrying in 500ms",
|
"Model busy on iteration {}, attempting one more retry in 500ms",
|
||||||
iteration_count
|
iteration_count
|
||||||
);
|
);
|
||||||
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
||||||
match provider.stream(request.clone()).await {
|
|
||||||
|
match self.stream_with_retry(&request, &error_context).await {
|
||||||
Ok(s) => s,
|
Ok(s) => s,
|
||||||
Err(e2) => {
|
Err(e2) => {
|
||||||
error!("Failed to start stream after retry: {}", e2);
|
error!("Failed to start stream after retry: {}", e2);
|
||||||
|
|||||||
Reference in New Issue
Block a user