add context window monitor

Writes the current context window to logs/current_context_window (uses a symlink to a session ID).

This PR was unfortunately generated by a different LLM and did a ton of superficial reformating, it's actually a fairly small and benign change, but I don't want to roll back everything. Hope that's ok.
This commit is contained in:
Jochen
2025-11-27 21:00:02 +11:00
parent 93dc4acf86
commit 52f78653b4
89 changed files with 4040 additions and 2576 deletions

View File

@@ -3,8 +3,8 @@ use anyhow::{anyhow, Result};
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use tree_sitter::{Language, Parser, Query, QueryCursor};
use streaming_iterator::StreamingIterator;
use tree_sitter::{Language, Parser, Query, QueryCursor};
use walkdir::WalkDir;
pub struct TreeSitterSearcher {
@@ -47,10 +47,11 @@ impl TreeSitterSearcher {
.set_language(&language)
.map_err(|e| anyhow!("Failed to set JavaScript language: {}", e))?;
parsers.insert("javascript".to_string(), parser);
// Create separate parser for "js" alias
let mut parser_js = Parser::new();
parser_js.set_language(&language)
parser_js
.set_language(&language)
.map_err(|e| anyhow!("Failed to set JavaScript language: {}", e))?;
parsers.insert("js".to_string(), parser_js);
languages.insert("javascript".to_string(), language.clone());
@@ -65,10 +66,11 @@ impl TreeSitterSearcher {
.set_language(&language)
.map_err(|e| anyhow!("Failed to set TypeScript language: {}", e))?;
parsers.insert("typescript".to_string(), parser);
// Create separate parser for "ts" alias
let mut parser_ts = Parser::new();
parser_ts.set_language(&language)
parser_ts
.set_language(&language)
.map_err(|e| anyhow!("Failed to set TypeScript language: {}", e))?;
parsers.insert("ts".to_string(), parser_ts);
languages.insert("typescript".to_string(), language.clone());
@@ -215,8 +217,8 @@ impl TreeSitterSearcher {
.ok_or_else(|| anyhow!("Language not found: {}", spec.language))?;
// Parse query
let query = Query::new(language, &spec.query)
.map_err(|e| anyhow!("Invalid query: {}", e))?;
let query =
Query::new(language, &spec.query).map_err(|e| anyhow!("Invalid query: {}", e))?;
let mut matches = Vec::new();
let mut files_searched = 0;
@@ -255,11 +257,8 @@ impl TreeSitterSearcher {
if let Ok(source_code) = fs::read_to_string(path) {
if let Some(tree) = parser.parse(&source_code, None) {
let mut cursor = QueryCursor::new();
let mut query_matches = cursor.matches(
&query,
tree.root_node(),
source_code.as_bytes(),
);
let mut query_matches =
cursor.matches(&query, tree.root_node(), source_code.as_bytes());
query_matches.advance();
while let Some(query_match) = query_matches.get() {
@@ -308,7 +307,7 @@ impl TreeSitterSearcher {
captures: captures_map,
context,
});
query_matches.advance();
}
}

View File

@@ -106,15 +106,15 @@ impl ErrorContext {
error!("Session ID: {:?}", self.session_id);
error!("Context Tokens: {}", self.context_tokens);
error!("Last Prompt: {}", self.last_prompt);
if let Some(ref req) = self.raw_request {
error!("Raw Request: {}", req);
}
if let Some(ref resp) = self.raw_response {
error!("Raw Response: {}", resp);
}
error!("Stack Trace:\n{}", self.stack_trace);
error!("=== END ERROR DETAILS ===");
@@ -191,23 +191,36 @@ pub fn classify_error(error: &anyhow::Error) -> ErrorType {
let error_str = error.to_string().to_lowercase();
// Check for recoverable error patterns
if error_str.contains("rate limit") || error_str.contains("rate_limit") || error_str.contains("429") {
if error_str.contains("rate limit")
|| error_str.contains("rate_limit")
|| error_str.contains("429")
{
return ErrorType::Recoverable(RecoverableError::RateLimit);
}
if error_str.contains("network") || error_str.contains("connection") ||
error_str.contains("dns") || error_str.contains("refused") {
if error_str.contains("network")
|| error_str.contains("connection")
|| error_str.contains("dns")
|| error_str.contains("refused")
{
return ErrorType::Recoverable(RecoverableError::NetworkError);
}
if error_str.contains("500") || error_str.contains("502") ||
error_str.contains("503") || error_str.contains("504") ||
error_str.contains("server error") || error_str.contains("internal error") {
if error_str.contains("500")
|| error_str.contains("502")
|| error_str.contains("503")
|| error_str.contains("504")
|| error_str.contains("server error")
|| error_str.contains("internal error")
{
return ErrorType::Recoverable(RecoverableError::ServerError);
}
if error_str.contains("busy") || error_str.contains("overloaded") ||
error_str.contains("capacity") || error_str.contains("unavailable") {
if error_str.contains("busy")
|| error_str.contains("overloaded")
|| error_str.contains("capacity")
|| error_str.contains("unavailable")
{
return ErrorType::Recoverable(RecoverableError::ModelBusy);
}
@@ -216,18 +229,24 @@ pub fn classify_error(error: &anyhow::Error) -> ErrorType {
error_str.contains("timed out") ||
error_str.contains("operation timed out") ||
error_str.contains("request or response body error") || // Common timeout pattern
error_str.contains("stream error") && error_str.contains("timed out") {
error_str.contains("stream error") && error_str.contains("timed out")
{
return ErrorType::Recoverable(RecoverableError::Timeout);
}
// Check for context length exceeded errors (HTTP 400 with specific messages)
if (error_str.contains("400") || error_str.contains("bad request")) &&
(error_str.contains("context length") || error_str.contains("prompt is too long") ||
error_str.contains("maximum context length") || error_str.contains("context_length_exceeded")) {
if (error_str.contains("400") || error_str.contains("bad request"))
&& (error_str.contains("context length")
|| error_str.contains("prompt is too long")
|| error_str.contains("maximum context length")
|| error_str.contains("context_length_exceeded"))
{
return ErrorType::Recoverable(RecoverableError::ContextLengthExceeded);
}
if error_str.contains("token") && (error_str.contains("limit") || error_str.contains("exceeded")) {
if error_str.contains("token")
&& (error_str.contains("limit") || error_str.contains("exceeded"))
{
return ErrorType::Recoverable(RecoverableError::TokenLimit);
}
@@ -239,12 +258,14 @@ pub fn classify_error(error: &anyhow::Error) -> ErrorType {
fn calculate_autonomous_retry_delay(attempt: u32) -> Duration {
use rand::Rng;
let mut rng = rand::thread_rng();
// Distribute 6 retries over 10 minutes (600 seconds)
// Base delays: 10s, 30s, 60s, 120s, 180s, 200s = 600s total
let base_delays_ms = [10000, 30000, 60000, 120000, 180000, 200000];
let base_delay = base_delays_ms.get(attempt.saturating_sub(1) as usize).unwrap_or(&200000);
let base_delay = base_delays_ms
.get(attempt.saturating_sub(1) as usize)
.unwrap_or(&200000);
// Add jitter of ±30% to prevent thundering herd
let jitter = (*base_delay as f64 * 0.3 * rng.gen::<f64>()) as u64;
let final_delay = if rng.gen_bool(0.5) {
@@ -252,7 +273,7 @@ fn calculate_autonomous_retry_delay(attempt: u32) -> Duration {
} else {
base_delay.saturating_sub(jitter)
};
Duration::from_millis(final_delay)
}
@@ -261,14 +282,18 @@ pub fn calculate_retry_delay(attempt: u32, is_autonomous: bool) -> Duration {
if is_autonomous {
return calculate_autonomous_retry_delay(attempt);
}
use rand::Rng;
let max_retry_delay_ms = if is_autonomous { AUTONOMOUS_MAX_RETRY_DELAY_MS } else { DEFAULT_MAX_RETRY_DELAY_MS };
let max_retry_delay_ms = if is_autonomous {
AUTONOMOUS_MAX_RETRY_DELAY_MS
} else {
DEFAULT_MAX_RETRY_DELAY_MS
};
// Exponential backoff: delay = base * 2^attempt
let base_delay = BASE_RETRY_DELAY_MS * (2_u64.pow(attempt.saturating_sub(1)));
let capped_delay = base_delay.min(max_retry_delay_ms);
// Add jitter to prevent thundering herd
let mut rng = rand::thread_rng();
let jitter = (capped_delay as f64 * JITTER_FACTOR * rng.gen::<f64>()) as u64;
@@ -277,7 +302,7 @@ pub fn calculate_retry_delay(attempt: u32, is_autonomous: bool) -> Duration {
} else {
capped_delay.saturating_sub(jitter)
};
Duration::from_millis(final_delay)
}
@@ -298,7 +323,7 @@ where
loop {
attempt += 1;
match operation().await {
Ok(result) => {
if attempt > 1 {
@@ -321,19 +346,19 @@ where
context.clone().log_error(&error);
return Err(error);
}
let delay = calculate_retry_delay(attempt, is_autonomous);
warn!(
"Recoverable error ({:?}) in '{}' (attempt {}/{}). Retrying in {:?}...",
recoverable_type, operation_name, attempt, max_attempts, delay
);
warn!("Error details: {}", error);
// Special handling for token limit errors
if matches!(recoverable_type, RecoverableError::TokenLimit) {
info!("Token limit error detected. Consider triggering summarization.");
}
tokio::time::sleep(delay).await;
_last_error = Some(error);
}
@@ -359,18 +384,22 @@ fn truncate_for_logging(s: &str, max_len: usize) -> String {
// Find a safe UTF-8 boundary to truncate at
// We need to ensure we don't cut in the middle of a multi-byte character
let mut truncate_at = max_len;
// Walk backwards from max_len to find a character boundary
while truncate_at > 0 && !s.is_char_boundary(truncate_at) {
truncate_at -= 1;
}
// If we couldn't find a boundary (shouldn't happen), use a safe default
if truncate_at == 0 {
truncate_at = max_len.min(s.len());
}
format!("{}... (truncated, {} total bytes)", &s[..truncate_at], s.len())
format!(
"{}... (truncated, {} total bytes)",
&s[..truncate_at],
s.len()
)
}
}
@@ -398,42 +427,69 @@ mod tests {
fn test_error_classification() {
// Rate limit errors
let error = anyhow!("Rate limit exceeded");
assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::RateLimit));
assert_eq!(
classify_error(&error),
ErrorType::Recoverable(RecoverableError::RateLimit)
);
let error = anyhow!("HTTP 429 Too Many Requests");
assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::RateLimit));
assert_eq!(
classify_error(&error),
ErrorType::Recoverable(RecoverableError::RateLimit)
);
// Network errors
let error = anyhow!("Network connection failed");
assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::NetworkError));
assert_eq!(
classify_error(&error),
ErrorType::Recoverable(RecoverableError::NetworkError)
);
// Server errors
let error = anyhow!("HTTP 503 Service Unavailable");
assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::ServerError));
assert_eq!(
classify_error(&error),
ErrorType::Recoverable(RecoverableError::ServerError)
);
// Model busy
let error = anyhow!("Model is busy, please try again");
assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::ModelBusy));
assert_eq!(
classify_error(&error),
ErrorType::Recoverable(RecoverableError::ModelBusy)
);
// Timeout
let error = anyhow!("Request timed out");
assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::Timeout));
assert_eq!(
classify_error(&error),
ErrorType::Recoverable(RecoverableError::Timeout)
);
// Token limit
let error = anyhow!("Token limit exceeded");
assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::TokenLimit));
assert_eq!(
classify_error(&error),
ErrorType::Recoverable(RecoverableError::TokenLimit)
);
// Context length exceeded
let error = anyhow!("HTTP 400 Bad Request: context length exceeded");
assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::ContextLengthExceeded));
assert_eq!(
classify_error(&error),
ErrorType::Recoverable(RecoverableError::ContextLengthExceeded)
);
let error = anyhow!("Error 400: prompt is too long");
assert_eq!(classify_error(&error), ErrorType::Recoverable(RecoverableError::ContextLengthExceeded));
assert_eq!(
classify_error(&error),
ErrorType::Recoverable(RecoverableError::ContextLengthExceeded)
);
// Non-recoverable
let error = anyhow!("Invalid API key");
assert_eq!(classify_error(&error), ErrorType::NonRecoverable);
let error = anyhow!("Malformed request");
assert_eq!(classify_error(&error), ErrorType::NonRecoverable);
}
@@ -444,17 +500,17 @@ mod tests {
let delay1 = calculate_retry_delay(1, false);
let delay2 = calculate_retry_delay(2, false);
let delay3 = calculate_retry_delay(3, false);
// Due to jitter, we can't test exact values, but the base should increase
assert!(delay1.as_millis() >= (BASE_RETRY_DELAY_MS as f64 * 0.7) as u128);
assert!(delay1.as_millis() <= (BASE_RETRY_DELAY_MS as f64 * 1.3) as u128);
// Delay 2 should be roughly 2x delay 1 (minus jitter)
assert!(delay2.as_millis() >= delay1.as_millis());
// Delay 3 should be roughly 2x delay 2 (minus jitter)
assert!(delay3.as_millis() >= delay2.as_millis());
// Test max cap
let delay_max = calculate_retry_delay(10, false);
assert!(delay_max.as_millis() <= (DEFAULT_MAX_RETRY_DELAY_MS as f64 * 1.3) as u128);
@@ -469,7 +525,7 @@ mod tests {
let delay4 = calculate_retry_delay(4, true);
let delay5 = calculate_retry_delay(5, true);
let delay6 = calculate_retry_delay(6, true);
// Base delays should be around: 10s, 30s, 60s, 120s, 180s, 200s
// With ±30% jitter
assert!(delay1.as_millis() >= 7000 && delay1.as_millis() <= 13000);
@@ -484,14 +540,14 @@ mod tests {
fn test_truncate_for_logging() {
let short_text = "Hello, world!";
assert_eq!(truncate_for_logging(short_text, 20), "Hello, world!");
let long_text = "This is a very long text that should be truncated for logging purposes";
let truncated = truncate_for_logging(long_text, 20);
assert!(truncated.starts_with("This is a very long "));
assert!(truncated.contains("truncated"));
assert!(truncated.contains("total bytes"));
}
#[test]
fn test_truncate_with_multibyte_chars() {
// Test with multi-byte UTF-8 characters
@@ -499,7 +555,7 @@ mod tests {
let truncated = truncate_for_logging(text_with_emoji, 10);
// Should truncate at a valid UTF-8 boundary
assert!(truncated.starts_with("Hello "));
// Test with box-drawing characters like the one causing the panic
let text_with_box = "Some text ┌─────┐ more text";
let truncated = truncate_for_logging(text_with_box, 12);

View File

@@ -17,7 +17,7 @@ mod tests {
"test prompt".to_string(),
None,
100,
false, // quiet parameter
false, // quiet parameter
);
let result = retry_with_backoff(
@@ -57,7 +57,7 @@ mod tests {
"test prompt".to_string(),
None,
100,
false, // quiet parameter
false, // quiet parameter
);
let result: Result<&str, _> = retry_with_backoff(
@@ -91,7 +91,7 @@ mod tests {
"test prompt".to_string(),
None,
100,
false, // quiet parameter
false, // quiet parameter
);
let result: Result<&str, _> = retry_with_backoff(
@@ -124,7 +124,7 @@ mod tests {
long_prompt,
None,
100,
false, // quiet parameter
false, // quiet parameter
);
// The prompt should be truncated to 1000 chars

View File

@@ -5,7 +5,7 @@
// 4. Return everything else as the final filtered string
//! JSON tool call filtering for streaming LLM responses.
//!
//!
//! This module filters out JSON tool calls from LLM output streams while preserving
//! regular text content. It uses a state machine to handle streaming chunks.
@@ -29,7 +29,7 @@ struct FixedJsonToolState {
brace_depth: i32,
buffer: String,
json_start_in_buffer: Option<usize>, // Position where confirmed JSON tool call starts
content_returned_up_to: usize, // Track how much content we've already returned
content_returned_up_to: usize, // Track how much content we've already returned
potential_json_start: Option<usize>, // Where the potential JSON started
}

View File

@@ -358,8 +358,8 @@ More text"#;
// 2. Then the same complete JSON appears
let chunks = vec![
"Some text\n",
r#"{"tool": "str_replace", "args": {"diff":"...","file_path":"./crates/g3-cli"#, // Truncated
r#"{"tool": "str_replace", "args": {"diff":"...","file_path":"./crates/g3-cli/src/lib.rs"}}"#, // Complete
r#"{"tool": "str_replace", "args": {"diff":"...","file_path":"./crates/g3-cli"#, // Truncated
r#"{"tool": "str_replace", "args": {"diff":"...","file_path":"./crates/g3-cli/src/lib.rs"}}"#, // Complete
"\nMore text",
];

File diff suppressed because it is too large Load Diff

View File

@@ -7,19 +7,19 @@ use std::path::{Path, PathBuf};
pub struct Project {
/// The workspace directory for the project
pub workspace_dir: PathBuf,
/// Path to the requirements document (for autonomous mode)
pub requirements_path: Option<PathBuf>,
/// Override requirements text (takes precedence over requirements_path)
pub requirements_text: Option<String>,
/// Whether the project is in autonomous mode
pub autonomous: bool,
/// Project name (derived from workspace directory name)
pub name: String,
/// Session ID for tracking
pub session_id: Option<String>,
}
@@ -32,7 +32,7 @@ impl Project {
.and_then(|n| n.to_str())
.unwrap_or("unnamed")
.to_string();
Self {
workspace_dir,
requirements_path: None,
@@ -42,33 +42,36 @@ impl Project {
session_id: None,
}
}
/// Create a project for autonomous mode
pub fn new_autonomous(workspace_dir: PathBuf) -> Result<Self> {
let mut project = Self::new(workspace_dir.clone());
project.autonomous = true;
// Look for requirements.md in the workspace directory
let requirements_path = workspace_dir.join("requirements.md");
if requirements_path.exists() {
project.requirements_path = Some(requirements_path);
}
Ok(project)
}
/// Create a project for autonomous mode with requirements text override
pub fn new_autonomous_with_requirements(workspace_dir: PathBuf, requirements_text: String) -> Result<Self> {
pub fn new_autonomous_with_requirements(
workspace_dir: PathBuf,
requirements_text: String,
) -> Result<Self> {
let mut project = Self::new(workspace_dir.clone());
project.autonomous = true;
project.requirements_text = Some(requirements_text);
// Don't look for requirements.md file when text is provided
// The text override takes precedence
Ok(project)
}
/// Set the workspace directory and update related paths
pub fn set_workspace(&mut self, workspace_dir: PathBuf) {
self.workspace_dir = workspace_dir.clone();
@@ -77,7 +80,7 @@ impl Project {
.and_then(|n| n.to_str())
.unwrap_or("unnamed")
.to_string();
// Update requirements path if in autonomous mode
if self.autonomous {
let requirements_path = workspace_dir.join("requirements.md");
@@ -86,18 +89,18 @@ impl Project {
}
}
}
/// Get the workspace directory
pub fn workspace(&self) -> &Path {
&self.workspace_dir
}
/// Check if requirements file exists
pub fn has_requirements(&self) -> bool {
// Has requirements if either text override is provided or requirements file exists
self.requirements_text.is_some() || self.requirements_path.is_some()
}
/// Read the requirements file content
pub fn read_requirements(&self) -> Result<Option<String>> {
// Prioritize requirements text override
@@ -110,7 +113,7 @@ impl Project {
Ok(None)
}
}
/// Create the workspace directory if it doesn't exist
pub fn ensure_workspace_exists(&self) -> Result<()> {
if !self.workspace_dir.exists() {
@@ -118,18 +121,18 @@ impl Project {
}
Ok(())
}
/// Change to the workspace directory
pub fn enter_workspace(&self) -> Result<()> {
std::env::set_current_dir(&self.workspace_dir)?;
Ok(())
}
/// Get the logs directory for the project
pub fn logs_dir(&self) -> PathBuf {
self.workspace_dir.join("logs")
}
/// Ensure the logs directory exists
pub fn ensure_logs_dir(&self) -> Result<()> {
let logs_dir = self.logs_dir();

View File

@@ -189,7 +189,7 @@ Do not explain what you're going to do - just do it by calling the tools.
";
pub const SYSTEM_PROMPT_FOR_NATIVE_TOOL_USE: &'static str =
concatcp!(SYSTEM_NATIVE_TOOL_CALLS, CODING_STYLE);
concatcp!(SYSTEM_NATIVE_TOOL_CALLS, CODING_STYLE);
/// Generate system prompt based on whether multiple tool calls are allowed
pub fn get_system_prompt_for_native(allow_multiple: bool) -> String {

View File

@@ -30,7 +30,7 @@ impl TaskResult {
// Look for the final_output marker pattern
// The final_output content typically appears after the tool is called
// and is the substantive content that follows
// First, try to find if there's a clear final_output section
// This would be the content after the last tool execution
if let Some(final_output_pos) = content_without_timing.rfind("final_output") {
@@ -39,7 +39,7 @@ impl TaskResult {
if let Some(content_start) = content_without_timing[final_output_pos..].find('\n') {
let start_pos = final_output_pos + content_start + 1;
let final_content = &content_without_timing[start_pos..];
// Trim and return the complete content
let trimmed = final_content.trim();
if !trimmed.is_empty() {
@@ -47,7 +47,7 @@ impl TaskResult {
}
}
}
// Fallback to the original extract_last_block behavior if we can't find final_output
// This maintains backward compatibility
self.extract_last_block()
@@ -62,12 +62,13 @@ impl TaskResult {
} else {
&self.response
};
// Split by double newlines to find the last substantial block
let blocks: Vec<&str> = content_without_timing.split("\n\n").collect();
// Find the last non-empty block that isn't just whitespace
blocks.iter()
blocks
.iter()
.rev()
.find(|block| !block.trim().is_empty())
.map(|block| block.trim().to_string())
@@ -79,7 +80,8 @@ impl TaskResult {
/// Check if the response contains an approval (for autonomous mode)
pub fn is_approved(&self) -> bool {
self.extract_final_output().contains("IMPLEMENTATION_APPROVED")
self.extract_final_output()
.contains("IMPLEMENTATION_APPROVED")
}
}
@@ -91,20 +93,21 @@ mod tests {
fn test_extract_last_block() {
// Test case 1: Response with timing info
let context_window = ContextWindow::new(1000);
let response_with_timing = "Some initial content\n\nFinal block content\n\n⏱️ 2.3s | 💭 1.2s".to_string();
let response_with_timing =
"Some initial content\n\nFinal block content\n\n⏱️ 2.3s | 💭 1.2s".to_string();
let result = TaskResult::new(response_with_timing, context_window.clone());
assert_eq!(result.extract_last_block(), "Final block content");
// Test case 2: Response without timing
let response_no_timing = "Some initial content\n\nFinal block content".to_string();
let result = TaskResult::new(response_no_timing, context_window.clone());
assert_eq!(result.extract_last_block(), "Final block content");
// Test case 3: Response with IMPLEMENTATION_APPROVED
let response_approved = "Some content\n\nIMPLEMENTATION_APPROVED".to_string();
let result = TaskResult::new(response_approved, context_window.clone());
assert!(result.is_approved());
// Test case 4: Response without approval
let response_not_approved = "Some content\n\nNeeds more work".to_string();
let result = TaskResult::new(response_not_approved, context_window);
@@ -114,17 +117,17 @@ mod tests {
#[test]
fn test_extract_last_block_edge_cases() {
let context_window = ContextWindow::new(1000);
// Test empty response
let empty_response = "".to_string();
let result = TaskResult::new(empty_response, context_window.clone());
assert_eq!(result.extract_last_block(), "");
// Test single block
let single_block = "Just one block".to_string();
let result = TaskResult::new(single_block, context_window.clone());
assert_eq!(result.extract_last_block(), "Just one block");
// Test multiple empty blocks
let multiple_empty = "\n\n\n\nSome content\n\n\n\n".to_string();
let result = TaskResult::new(multiple_empty, context_window);
@@ -134,18 +137,22 @@ mod tests {
#[test]
fn test_extract_final_output() {
let context_window = ContextWindow::new(1000);
// Test case 1: Response with final_output tool call
let response_with_final_output = "Analyzing files...\n\nCalling final_output\n\nThis is the complete feedback\nwith multiple lines\nand important details\n\n⏱️ 2.3s".to_string();
let result = TaskResult::new(response_with_final_output, context_window.clone());
assert_eq!(result.extract_final_output(), "This is the complete feedback\nwith multiple lines\nand important details");
assert_eq!(
result.extract_final_output(),
"This is the complete feedback\nwith multiple lines\nand important details"
);
// Test case 2: Response with IMPLEMENTATION_APPROVED in final_output
let response_approved = "Review complete\n\nfinal_output called\n\nIMPLEMENTATION_APPROVED".to_string();
let response_approved =
"Review complete\n\nfinal_output called\n\nIMPLEMENTATION_APPROVED".to_string();
let result = TaskResult::new(response_approved, context_window.clone());
assert_eq!(result.extract_final_output(), "IMPLEMENTATION_APPROVED");
assert!(result.is_approved());
// Test case 3: Response with detailed feedback in final_output
let response_feedback = "Checking implementation...\n\nfinal_output\n\nThe following issues need to be addressed:\n1. Missing error handling in main.rs\n2. Tests are not comprehensive\n3. Documentation needs improvement\n\nPlease fix these issues.".to_string();
let result = TaskResult::new(response_feedback, context_window.clone());
@@ -154,12 +161,12 @@ mod tests {
assert!(extracted.contains("1. Missing error handling"));
assert!(extracted.contains("Please fix these issues."));
assert!(!result.is_approved());
// Test case 4: Response without final_output (fallback to extract_last_block)
let response_no_final_output = "Some analysis\n\nFinal thoughts here".to_string();
let result = TaskResult::new(response_no_final_output, context_window.clone());
assert_eq!(result.extract_final_output(), "Final thoughts here");
// Test case 5: Empty response
let empty_response = "".to_string();
let result = TaskResult::new(empty_response, context_window);

View File

@@ -6,15 +6,19 @@ use std::sync::Arc;
fn test_task_result_basic_functionality() {
// Create a context window with some messages
let mut context = ContextWindow::new(10000);
context.add_message(Message::new(MessageRole::User, "Test message 1".to_string())
);
context.add_message(Message::new(MessageRole::Assistant, "Response 1".to_string())
);
context.add_message(Message::new(
MessageRole::User,
"Test message 1".to_string(),
));
context.add_message(Message::new(
MessageRole::Assistant,
"Response 1".to_string(),
));
// Create a TaskResult
let response = "This is the response\n\nFinal output block".to_string();
let result = TaskResult::new(response.clone(), context.clone());
// Test basic properties
assert_eq!(result.response, response);
assert_eq!(result.context_window.conversation_history.len(), 2);
@@ -24,32 +28,32 @@ fn test_task_result_basic_functionality() {
#[test]
fn test_extract_last_block_various_formats() {
let context = ContextWindow::new(1000);
// Test 1: Standard format with multiple blocks
let response1 = "First block\n\nSecond block\n\nThird block".to_string();
let result1 = TaskResult::new(response1, context.clone());
assert_eq!(result1.extract_last_block(), "Third block");
// Test 2: With timing information
let response2 = "Content\n\nFinal block\n\n⏱️ 2.3s | 💭 1.2s".to_string();
let result2 = TaskResult::new(response2, context.clone());
assert_eq!(result2.extract_last_block(), "Final block");
// Test 3: Single line response
let response3 = "Single line response".to_string();
let result3 = TaskResult::new(response3, context.clone());
assert_eq!(result3.extract_last_block(), "Single line response");
// Test 4: Empty response
let response4 = "".to_string();
let result4 = TaskResult::new(response4, context.clone());
assert_eq!(result4.extract_last_block(), "");
// Test 5: Only whitespace
let response5 = "\n\n\n \n\n".to_string();
let result5 = TaskResult::new(response5, context.clone());
assert_eq!(result5.extract_last_block(), "");
// Test 6: Multiple blocks with empty ones
let response6 = "First\n\n\n\n\n\nLast block here".to_string();
let result6 = TaskResult::new(response6, context.clone());
@@ -59,7 +63,7 @@ fn test_extract_last_block_various_formats() {
#[test]
fn test_is_approved_detection() {
let context = ContextWindow::new(1000);
// Test approved cases
let approved_responses = vec![
"Analysis complete\n\nIMPLEMENTATION_APPROVED",
@@ -67,12 +71,16 @@ fn test_is_approved_detection() {
"IMPLEMENTATION_APPROVED",
"Review done\n\n✅ IMPLEMENTATION_APPROVED - All tests pass",
];
for response in approved_responses {
let result = TaskResult::new(response.to_string(), context.clone());
assert!(result.is_approved(), "Failed to detect approval in: {}", response);
assert!(
result.is_approved(),
"Failed to detect approval in: {}",
response
);
}
// Test not approved cases
let not_approved_responses = vec![
"Needs more work",
@@ -81,10 +89,14 @@ fn test_is_approved_detection() {
"Almost there but not APPROVED",
"",
];
for response in not_approved_responses {
let result = TaskResult::new(response.to_string(), context.clone());
assert!(!result.is_approved(), "Incorrectly detected approval in: {}", response);
assert!(
!result.is_approved(),
"Incorrectly detected approval in: {}",
response
);
}
}
@@ -93,33 +105,46 @@ fn test_context_window_preservation() {
// Create a context window with specific state
let mut context = ContextWindow::new(5000);
context.used_tokens = 1234;
// Add some messages
for i in 0..5 {
context.add_message(Message::new(if i % 2 == 0 { MessageRole::User } else { MessageRole::Assistant }, format!("Message {}", i)));
context.add_message(Message::new(
if i % 2 == 0 {
MessageRole::User
} else {
MessageRole::Assistant
},
format!("Message {}", i),
));
}
// Create TaskResult
let result = TaskResult::new("Response".to_string(), context.clone());
// Verify context is preserved
assert_eq!(result.context_window.total_tokens, 5000);
assert!(result.context_window.used_tokens > 1234); // Should have increased
assert_eq!(result.context_window.conversation_history.len(), 5);
// Verify messages are preserved correctly
for i in 0..5 {
let is_user = matches!(result.context_window.conversation_history[i].role, MessageRole::User);
let is_user = matches!(
result.context_window.conversation_history[i].role,
MessageRole::User
);
let expected_is_user = i % 2 == 0;
assert_eq!(is_user, expected_is_user, "Message {} has wrong role", i);
assert_eq!(result.context_window.conversation_history[i].content, format!("Message {}", i));
assert_eq!(
result.context_window.conversation_history[i].content,
format!("Message {}", i)
);
}
}
#[test]
fn test_coach_feedback_extraction_scenarios() {
let context = ContextWindow::new(1000);
// Scenario 1: Coach feedback with file operations and analysis
let coach_response = r#"Reading file: src/main.rs
📄 File content (23 lines):
@@ -133,13 +158,13 @@ The implementation needs the following fixes:
1. Add error handling
2. Implement missing functions
3. Add tests"#;
let result = TaskResult::new(coach_response.to_string(), context.clone());
let feedback = result.extract_last_block();
assert!(feedback.contains("Add error handling"));
assert!(feedback.contains("Implement missing functions"));
assert!(feedback.contains("Add tests"));
// Scenario 2: Coach approval
let approval_response = r#"Checking compilation...
✅ Build successful
@@ -148,11 +173,11 @@ Running tests...
✅ All tests pass
IMPLEMENTATION_APPROVED"#;
let result = TaskResult::new(approval_response.to_string(), context.clone());
assert!(result.is_approved());
assert_eq!(result.extract_last_block(), "IMPLEMENTATION_APPROVED");
// Scenario 3: Complex feedback with timing
let complex_response = r#"Tool execution log...
@@ -163,7 +188,7 @@ The following issues were found:
- Missing input validation
⏱️ 5.2s | 💭 2.1s"#;
let result = TaskResult::new(complex_response.to_string(), context.clone());
let feedback = result.extract_last_block();
assert!(feedback.contains("Memory leak"));
@@ -174,17 +199,18 @@ The following issues were found:
#[test]
fn test_edge_cases_and_special_characters() {
let context = ContextWindow::new(1000);
// Test with special characters and emojis
let response_with_emojis = "First part 🚀\n\n✅ Final part with emojis 🎉".to_string();
let result = TaskResult::new(response_with_emojis, context.clone());
assert_eq!(result.extract_last_block(), "✅ Final part with emojis 🎉");
// Test with code blocks
let response_with_code = "Explanation\n\n```rust\nfn main() {}\n```\n\nFinal comment".to_string();
let response_with_code =
"Explanation\n\n```rust\nfn main() {}\n```\n\nFinal comment".to_string();
let result = TaskResult::new(response_with_code, context.clone());
assert_eq!(result.extract_last_block(), "Final comment");
// Test with mixed newlines
let mixed_newlines = "Part 1\r\n\r\nPart 2\n\nPart 3".to_string();
let result = TaskResult::new(mixed_newlines, context.clone());
@@ -194,30 +220,33 @@ fn test_edge_cases_and_special_characters() {
#[test]
fn test_large_response_handling() {
let context = ContextWindow::new(100000);
// Create a large response
let mut large_response = String::new();
for i in 0..100 {
large_response.push_str(&format!("Block {} with some content\n\n", i));
}
large_response.push_str("This is the final block after 100 other blocks");
let result = TaskResult::new(large_response, context);
assert_eq!(result.extract_last_block(), "This is the final block after 100 other blocks");
assert_eq!(
result.extract_last_block(),
"This is the final block after 100 other blocks"
);
}
#[test]
fn test_concurrent_access() {
use std::thread;
let context = ContextWindow::new(1000);
let result = Arc::new(TaskResult::new(
"Concurrent test\n\nFinal block".to_string(),
context,
));
let mut handles = vec![];
// Spawn multiple threads to access the TaskResult
for _ in 0..10 {
let result_clone = Arc::clone(&result);
@@ -225,16 +254,15 @@ fn test_concurrent_access() {
// Each thread extracts the last block
let block = result_clone.extract_last_block();
assert_eq!(block, "Final block");
// Check approval status
assert!(!result_clone.is_approved());
});
handles.push(handle);
}
// Wait for all threads to complete
for handle in handles {
handle.join().unwrap();
}
}

View File

@@ -7,10 +7,10 @@ mod tilde_expansion_tests {
// Test that shellexpand works
let path_with_tilde = "~/test.txt";
let expanded = shellexpand::tilde(path_with_tilde);
// Get the actual home directory
let home = env::var("HOME").expect("HOME environment variable not set");
// Verify expansion happened
assert_eq!(expanded.as_ref(), format!("{}/test.txt", home));
assert!(!expanded.contains("~"));
@@ -20,9 +20,9 @@ mod tilde_expansion_tests {
fn test_tilde_expansion_with_subdirs() {
let path_with_tilde = "~/Documents/test.txt";
let expanded = shellexpand::tilde(path_with_tilde);
let home = env::var("HOME").expect("HOME environment variable not set");
assert_eq!(expanded.as_ref(), format!("{}/Documents/test.txt", home));
}
@@ -30,7 +30,7 @@ mod tilde_expansion_tests {
fn test_no_tilde_unchanged() {
let path_without_tilde = "/absolute/path/test.txt";
let expanded = shellexpand::tilde(path_without_tilde);
assert_eq!(expanded.as_ref(), path_without_tilde);
}
}

View File

@@ -4,58 +4,60 @@
pub trait UiWriter: Send + Sync {
/// Print a simple message
fn print(&self, message: &str);
/// Print a message with a newline
fn println(&self, message: &str);
/// Print without newline (for progress indicators)
fn print_inline(&self, message: &str);
/// Print a system prompt section
fn print_system_prompt(&self, prompt: &str);
/// Print a context window status message
fn print_context_status(&self, message: &str);
/// Print a context thinning success message with highlight and animation
fn print_context_thinning(&self, message: &str);
/// Print a tool execution header
fn print_tool_header(&self, tool_name: &str);
/// Print a tool argument
fn print_tool_arg(&self, key: &str, value: &str);
/// Print tool output header
fn print_tool_output_header(&self);
/// Update the current tool output line (replaces previous line)
fn update_tool_output_line(&self, line: &str);
/// Print a tool output line
fn print_tool_output_line(&self, line: &str);
/// Print tool output summary (when output is truncated)
fn print_tool_output_summary(&self, hidden_count: usize);
/// Print tool execution timing
fn print_tool_timing(&self, duration_str: &str);
/// Print the agent prompt indicator
fn print_agent_prompt(&self);
/// Print agent response inline (for streaming)
fn print_agent_response(&self, content: &str);
/// Notify that an SSE event was received (including pings)
fn notify_sse_received(&self);
/// Flush any buffered output
fn flush(&self);
/// Returns true if this UI writer wants full, untruncated output
/// Default is false (truncate for human readability)
fn wants_full_output(&self) -> bool { false }
fn wants_full_output(&self) -> bool {
false
}
/// Prompt the user for a yes/no confirmation
fn prompt_user_yes_no(&self, message: &str) -> bool;
@@ -86,7 +88,13 @@ impl UiWriter for NullUiWriter {
fn print_agent_response(&self, _content: &str) {}
fn notify_sse_received(&self) {}
fn flush(&self) {}
fn wants_full_output(&self) -> bool { false }
fn prompt_user_yes_no(&self, _message: &str) -> bool { true }
fn prompt_user_choice(&self, _message: &str, _options: &[&str]) -> usize { 0 }
}
fn wants_full_output(&self) -> bool {
false
}
fn prompt_user_yes_no(&self, _message: &str) -> bool {
true
}
fn prompt_user_choice(&self, _message: &str, _options: &[&str]) -> usize {
0
}
}