feat: add --auto-memory flag to prompt LLM to save discoveries

Adds a new --auto-memory CLI flag that automatically sends a reminder
to the LLM after each turn where tools were called, prompting it to
call the remember tool if it discovered any key code locations.

Changes:
- Add auto_memory field and set_auto_memory() method to Agent
- Add tool_calls_this_turn tracking in execute_tool_in_dir()
- Add send_auto_memory_reminder() that sends reminder after tool use
- Add --auto-memory CLI flag and wire it up in console/machine modes
- Call send_auto_memory_reminder() in single-shot and interactive modes
- Add visible status messages for auto-memory actions

Fixes bug where tool calls were not being tracked when execute_tool_in_dir
was called directly with working_dir=None.
This commit is contained in:
Dhanji R. Prasanna
2026-01-11 08:00:51 +08:00
parent 39918cf281
commit 280ae1fcbb
2 changed files with 129 additions and 8 deletions

View File

@@ -112,6 +112,8 @@ pub struct Agent<W: UiWriter> {
>,
webdriver_process: std::sync::Arc<tokio::sync::RwLock<Option<tokio::process::Child>>>,
tool_call_count: usize,
/// Tool calls made in the current turn (reset after each turn)
tool_calls_this_turn: Vec<String>,
requirements_sha: Option<String>,
/// Working directory for tool execution (set by --codebase-fast-start)
working_dir: Option<String>,
@@ -122,6 +124,8 @@ pub struct Agent<W: UiWriter> {
is_agent_mode: bool,
/// Name of the agent if running in agent mode (e.g., "fowler", "pike")
agent_name: Option<String>,
/// Whether auto-memory reminders are enabled (--auto-memory flag)
auto_memory: bool,
}
impl<W: UiWriter> Agent<W> {
@@ -284,6 +288,7 @@ impl<W: UiWriter> Agent<W> {
webdriver_session: std::sync::Arc::new(tokio::sync::RwLock::new(None)),
webdriver_process: std::sync::Arc::new(tokio::sync::RwLock::new(None)),
tool_call_count: 0,
tool_calls_this_turn: Vec::new(),
requirements_sha: None,
working_dir: None,
background_process_manager: std::sync::Arc::new(
@@ -293,6 +298,7 @@ impl<W: UiWriter> Agent<W> {
pending_images: Vec::new(),
is_agent_mode: false,
agent_name: None,
auto_memory: false,
})
}
@@ -1441,6 +1447,83 @@ impl<W: UiWriter> Agent<W> {
debug!("Agent mode enabled for agent: {}", agent_name);
}
/// Enable auto-memory reminders after turns with tool calls
pub fn set_auto_memory(&mut self, enabled: bool) {
self.auto_memory = enabled;
debug!("Auto-memory reminders: {}", if enabled { "enabled" } else { "disabled" });
}
/// Send an auto-memory reminder to the LLM if tools were called during the turn.
/// This prompts the LLM to call the `remember` tool if it discovered any key code locations.
/// Returns true if a reminder was sent and processed.
pub async fn send_auto_memory_reminder(&mut self) -> Result<bool> {
if !self.auto_memory {
return Ok(false);
}
// Check if any tools were called this turn
if self.tool_calls_this_turn.is_empty() {
debug!("Auto-memory: No tools called, skipping reminder");
self.ui_writer.print_context_status("📝 Auto-memory: No tools called this turn, skipping reminder.\n");
return Ok(false);
}
// Check if remember was already called this turn - no need to remind
if self.tool_calls_this_turn.iter().any(|t| t == "remember") {
debug!("Auto-memory: 'remember' was already called this turn, skipping reminder");
self.ui_writer.print_context_status("📝 Auto-memory: 'remember' already called, skipping reminder.\n");
self.tool_calls_this_turn.clear();
return Ok(false);
}
// Take the tools list and reset for next turn
let tools_called = std::mem::take(&mut self.tool_calls_this_turn);
debug!("Auto-memory: Sending reminder to LLM ({} tools called this turn: {:?})", tools_called.len(), tools_called);
self.ui_writer.print_context_status(&format!("\n📝 Auto-memory: Checking if discoveries should be saved ({} tools used)...\n", tools_called.len()));
let reminder = "SYSTEM REMINDER: You used tools during this turn. If you discovered any key code locations, patterns, or entry points that aren't already in Project Memory, please call the `remember` tool now to save them. If you didn't discover anything new worth remembering, you can skip this. Respond briefly after deciding.";
// Add the reminder as a user message and get a response
self.context_window.add_message(Message::new(
MessageRole::User,
reminder.to_string(),
));
// Build the completion request
let messages = self.context_window.conversation_history.clone();
// Get provider and tools
let provider = self.providers.get(None)?;
let provider_name = provider.name().to_string();
let tools = if provider.has_native_tool_calling() {
let tool_config = tool_definitions::ToolConfig::new(
self.config.webdriver.enabled,
self.config.computer_control.enabled,
);
Some(tool_definitions::create_tool_definitions(tool_config))
} else {
None
};
let _ = provider; // Drop the provider reference
let max_tokens = Some(self.resolve_max_tokens(&provider_name));
let request = CompletionRequest {
messages,
max_tokens,
temperature: Some(self.resolve_temperature(&provider_name)),
stream: true,
tools,
disable_thinking: true, // Keep it brief
};
// Execute the reminder turn (show_timing = false to keep it quiet)
self.stream_completion_with_tools(request, false).await?;
Ok(true)
}
/// Initialize session ID manually (primarily for testing).
/// This allows tests to verify session ID generation without calling execute_task,
/// which would require an LLM provider.
@@ -2753,8 +2836,7 @@ impl<W: UiWriter> Agent<W> {
}
pub async fn execute_tool(&mut self, tool_call: &ToolCall) -> Result<String> {
// Increment tool call count
self.tool_call_count += 1;
// Tool tracking is handled by execute_tool_in_dir
self.execute_tool_in_dir(tool_call, None).await
}
@@ -2764,10 +2846,9 @@ impl<W: UiWriter> Agent<W> {
tool_call: &ToolCall,
working_dir: Option<&str>,
) -> Result<String> {
// Only increment tool call count if not already incremented by execute_tool
if working_dir.is_some() {
self.tool_call_count += 1;
}
// Always track tool calls for auto-memory feature
self.tool_call_count += 1;
self.tool_calls_this_turn.push(tool_call.tool.clone());
let result = self.execute_tool_inner_in_dir(tool_call, working_dir).await;
let log_str = match &result {