refactor(g3-core): extract 4 modules from monolithic lib.rs
Reduce lib.rs from 7481 to 6557 lines (-12.4%) by extracting: - paths.rs: Session/workspace path utilities (get_todo_path, get_logs_dir, etc.) - streaming_parser.rs: StreamingToolParser for LLM response parsing - utils.rs: Diff parsing and shell escaping utilities - webdriver_session.rs: Unified Safari/Chrome WebDriver abstraction All public APIs preserved via re-exports for backward compatibility. Added 13 new unit tests across extracted modules. All 225 tests pass.
This commit is contained in:
File diff suppressed because it is too large
Load Diff
124
crates/g3-core/src/paths.rs
Normal file
124
crates/g3-core/src/paths.rs
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
//! Path utilities for G3 session and workspace management.
|
||||||
|
//!
|
||||||
|
//! This module centralizes all path-related logic for:
|
||||||
|
//! - TODO file location
|
||||||
|
//! - Logs directory
|
||||||
|
//! - Session directories and files
|
||||||
|
//! - Thinned content storage
|
||||||
|
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
/// Environment variable name for workspace path.
|
||||||
|
/// Used to direct all logs to the workspace directory.
|
||||||
|
pub const G3_WORKSPACE_PATH_ENV: &str = "G3_WORKSPACE_PATH";
|
||||||
|
|
||||||
|
/// Environment variable name for custom TODO file path.
|
||||||
|
const G3_TODO_PATH_ENV: &str = "G3_TODO_PATH";
|
||||||
|
|
||||||
|
/// Get the path to the todo.g3.md file.
|
||||||
|
///
|
||||||
|
/// Checks for G3_TODO_PATH environment variable first (used by planning mode),
|
||||||
|
/// then falls back to todo.g3.md in the current directory.
|
||||||
|
pub fn get_todo_path() -> PathBuf {
|
||||||
|
if let Ok(custom_path) = std::env::var(G3_TODO_PATH_ENV) {
|
||||||
|
PathBuf::from(custom_path)
|
||||||
|
} else {
|
||||||
|
std::env::current_dir().unwrap_or_default().join("todo.g3.md")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the path to the logs directory.
|
||||||
|
///
|
||||||
|
/// Checks for G3_WORKSPACE_PATH environment variable first (used by planning mode),
|
||||||
|
/// then falls back to "logs" in the current directory.
|
||||||
|
pub fn get_logs_dir() -> PathBuf {
|
||||||
|
if let Ok(workspace_path) = std::env::var(G3_WORKSPACE_PATH_ENV) {
|
||||||
|
PathBuf::from(workspace_path).join("logs")
|
||||||
|
} else {
|
||||||
|
std::env::current_dir().unwrap_or_default().join("logs")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Public accessor for the logs directory path (for use by submodules).
|
||||||
|
/// Alias for `get_logs_dir()` for backward compatibility.
|
||||||
|
pub fn logs_dir() -> PathBuf {
|
||||||
|
get_logs_dir()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the base .g3 directory path.
|
||||||
|
/// This is the root for all g3 session data in the current workspace.
|
||||||
|
pub fn get_g3_dir() -> PathBuf {
|
||||||
|
if let Ok(workspace_path) = std::env::var(G3_WORKSPACE_PATH_ENV) {
|
||||||
|
PathBuf::from(workspace_path).join(".g3")
|
||||||
|
} else {
|
||||||
|
std::env::current_dir().unwrap_or_default().join(".g3")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the session directory for a specific session ID.
|
||||||
|
/// Returns .g3/sessions/<session_id>/
|
||||||
|
pub fn get_session_logs_dir(session_id: &str) -> PathBuf {
|
||||||
|
get_g3_dir().join("sessions").join(session_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Ensure the session directory exists for a specific session ID.
|
||||||
|
/// Creates .g3/sessions/<session_id>/ and subdirectories.
|
||||||
|
pub fn ensure_session_dir(session_id: &str) -> std::io::Result<PathBuf> {
|
||||||
|
let session_dir = get_session_logs_dir(session_id);
|
||||||
|
std::fs::create_dir_all(&session_dir)?;
|
||||||
|
|
||||||
|
// Create subdirectories
|
||||||
|
std::fs::create_dir_all(session_dir.join("thinned"))?;
|
||||||
|
|
||||||
|
Ok(session_dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the thinned content directory for a session.
|
||||||
|
/// Returns .g3/sessions/<session_id>/thinned/
|
||||||
|
pub fn get_thinned_dir(session_id: &str) -> PathBuf {
|
||||||
|
get_session_logs_dir(session_id).join("thinned")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the path to the session.json file for a session.
|
||||||
|
/// Returns .g3/sessions/<session_id>/session.json
|
||||||
|
pub fn get_session_file(session_id: &str) -> PathBuf {
|
||||||
|
get_session_logs_dir(session_id).join("session.json")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the path to the context summary file for a session.
|
||||||
|
/// Returns .g3/sessions/<session_id>/context_summary.txt
|
||||||
|
pub fn get_context_summary_file(session_id: &str) -> PathBuf {
|
||||||
|
get_session_logs_dir(session_id).join("context_summary.txt")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_todo_path_default() {
|
||||||
|
// When G3_TODO_PATH is not set, should return current_dir/todo.g3.md
|
||||||
|
std::env::remove_var(G3_TODO_PATH_ENV);
|
||||||
|
let path = get_todo_path();
|
||||||
|
assert!(path.ends_with("todo.g3.md"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_session_paths_are_consistent() {
|
||||||
|
let session_id = "test-session-123";
|
||||||
|
let session_dir = get_session_logs_dir(session_id);
|
||||||
|
let thinned_dir = get_thinned_dir(session_id);
|
||||||
|
let session_file = get_session_file(session_id);
|
||||||
|
let summary_file = get_context_summary_file(session_id);
|
||||||
|
|
||||||
|
// All paths should be under the session directory
|
||||||
|
assert!(thinned_dir.starts_with(&session_dir));
|
||||||
|
assert!(session_file.starts_with(&session_dir));
|
||||||
|
assert!(summary_file.starts_with(&session_dir));
|
||||||
|
|
||||||
|
// Check expected filenames
|
||||||
|
assert!(thinned_dir.ends_with("thinned"));
|
||||||
|
assert!(session_file.ends_with("session.json"));
|
||||||
|
assert!(summary_file.ends_with("context_summary.txt"));
|
||||||
|
}
|
||||||
|
}
|
||||||
413
crates/g3-core/src/streaming_parser.rs
Normal file
413
crates/g3-core/src/streaming_parser.rs
Normal file
@@ -0,0 +1,413 @@
|
|||||||
|
//! Streaming tool parser for processing LLM response chunks.
|
||||||
|
//!
|
||||||
|
//! This module handles parsing of tool calls from streaming LLM responses,
|
||||||
|
//! supporting both native tool calls and JSON-based fallback parsing.
|
||||||
|
|
||||||
|
use tracing::debug;
|
||||||
|
|
||||||
|
use crate::ToolCall;
|
||||||
|
|
||||||
|
/// Patterns used to detect JSON tool calls in text.
|
||||||
|
/// These cover common whitespace variations in JSON formatting.
|
||||||
|
const TOOL_CALL_PATTERNS: [&str; 4] = [
|
||||||
|
r#"{"tool":"#,
|
||||||
|
r#"{ "tool":"#,
|
||||||
|
r#"{"tool" :"#,
|
||||||
|
r#"{ "tool" :"#,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Modern streaming tool parser that properly handles native tool calls and SSE chunks.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct StreamingToolParser {
|
||||||
|
/// Buffer for accumulating text content
|
||||||
|
text_buffer: String,
|
||||||
|
/// Position in text_buffer up to which tool calls have been consumed/executed.
|
||||||
|
/// This prevents has_unexecuted_tool_call() from returning true for already-executed tools.
|
||||||
|
last_consumed_position: usize,
|
||||||
|
/// Whether we've received a message_stop event
|
||||||
|
message_stopped: bool,
|
||||||
|
/// Whether we're currently in a JSON tool call (for fallback parsing)
|
||||||
|
in_json_tool_call: bool,
|
||||||
|
/// Start position of JSON tool call (for fallback parsing)
|
||||||
|
json_tool_start: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for StreamingToolParser {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StreamingToolParser {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
text_buffer: String::new(),
|
||||||
|
last_consumed_position: 0,
|
||||||
|
message_stopped: false,
|
||||||
|
in_json_tool_call: false,
|
||||||
|
json_tool_start: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find the starting position of the last tool call pattern in the given text.
|
||||||
|
/// Returns None if no tool call pattern is found.
|
||||||
|
fn find_last_tool_call_start(text: &str) -> Option<usize> {
|
||||||
|
let mut best_start: Option<usize> = None;
|
||||||
|
for pattern in &TOOL_CALL_PATTERNS {
|
||||||
|
if let Some(pos) = text.rfind(pattern) {
|
||||||
|
if best_start.map_or(true, |best| pos > best) {
|
||||||
|
best_start = Some(pos);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
best_start
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find the starting position of the FIRST tool call pattern in the given text.
|
||||||
|
/// Returns None if no tool call pattern is found.
|
||||||
|
fn find_first_tool_call_start(text: &str) -> Option<usize> {
|
||||||
|
let mut best_start: Option<usize> = None;
|
||||||
|
for pattern in &TOOL_CALL_PATTERNS {
|
||||||
|
if let Some(pos) = text.find(pattern) {
|
||||||
|
if best_start.map_or(true, |best| pos < best) {
|
||||||
|
best_start = Some(pos);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
best_start
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate that tool call args don't contain message-like content.
|
||||||
|
/// This detects malformed tool calls where agent messages got mixed into args.
|
||||||
|
fn has_message_like_keys(args: &serde_json::Map<String, serde_json::Value>) -> bool {
|
||||||
|
args.keys().any(|key| {
|
||||||
|
key.len() > 100
|
||||||
|
|| key.contains('\n')
|
||||||
|
|| key.contains("I'll")
|
||||||
|
|| key.contains("Let me")
|
||||||
|
|| key.contains("Here's")
|
||||||
|
|| key.contains("I can")
|
||||||
|
|| key.contains("I need")
|
||||||
|
|| key.contains("First")
|
||||||
|
|| key.contains("Now")
|
||||||
|
|| key.contains("The ")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Process a streaming chunk and return completed tool calls if any.
|
||||||
|
pub fn process_chunk(&mut self, chunk: &g3_providers::CompletionChunk) -> Vec<ToolCall> {
|
||||||
|
let mut completed_tools = Vec::new();
|
||||||
|
|
||||||
|
// Add text content to buffer
|
||||||
|
if !chunk.content.is_empty() {
|
||||||
|
self.text_buffer.push_str(&chunk.content);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle native tool calls - return them immediately when received.
|
||||||
|
// This allows tools to be executed as soon as they're fully parsed,
|
||||||
|
// preventing duplicate tool calls from being accumulated.
|
||||||
|
if let Some(ref tool_calls) = chunk.tool_calls {
|
||||||
|
debug!("Received native tool calls: {:?}", tool_calls);
|
||||||
|
|
||||||
|
// Convert and return tool calls immediately
|
||||||
|
for tool_call in tool_calls {
|
||||||
|
let converted_tool = ToolCall {
|
||||||
|
tool: tool_call.tool.clone(),
|
||||||
|
args: tool_call.args.clone(),
|
||||||
|
};
|
||||||
|
completed_tools.push(converted_tool);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if message is finished/stopped
|
||||||
|
if chunk.finished {
|
||||||
|
self.message_stopped = true;
|
||||||
|
debug!("Message finished, processing accumulated tool calls");
|
||||||
|
|
||||||
|
// When stream finishes, find ALL JSON tool calls in the accumulated buffer
|
||||||
|
if completed_tools.is_empty() && !self.text_buffer.is_empty() {
|
||||||
|
let all_tools = self.try_parse_all_json_tool_calls_from_buffer();
|
||||||
|
if !all_tools.is_empty() {
|
||||||
|
debug!(
|
||||||
|
"Found {} JSON tool calls in buffer at stream end",
|
||||||
|
all_tools.len()
|
||||||
|
);
|
||||||
|
completed_tools.extend(all_tools);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: Try to parse JSON tool calls from current chunk content if no native tool calls
|
||||||
|
if completed_tools.is_empty() && !chunk.content.is_empty() && !chunk.finished {
|
||||||
|
if let Some(json_tool) = self.try_parse_json_tool_call(&chunk.content) {
|
||||||
|
completed_tools.push(json_tool);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
completed_tools
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fallback method to parse JSON tool calls from text content.
|
||||||
|
fn try_parse_json_tool_call(&mut self, _content: &str) -> Option<ToolCall> {
|
||||||
|
// If we're not currently in a JSON tool call, look for the start
|
||||||
|
if !self.in_json_tool_call {
|
||||||
|
if let Some(pos) = Self::find_last_tool_call_start(&self.text_buffer) {
|
||||||
|
debug!("Found JSON tool call pattern at position {}", pos);
|
||||||
|
self.in_json_tool_call = true;
|
||||||
|
self.json_tool_start = Some(pos);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we're in a JSON tool call, try to find the end and parse it
|
||||||
|
if self.in_json_tool_call {
|
||||||
|
if let Some(start_pos) = self.json_tool_start {
|
||||||
|
let json_text = &self.text_buffer[start_pos..];
|
||||||
|
|
||||||
|
// Try to find a complete JSON object
|
||||||
|
if let Some(end_pos) = Self::find_complete_json_object_end(json_text) {
|
||||||
|
let json_str = &json_text[..=end_pos];
|
||||||
|
debug!("Attempting to parse JSON tool call: {}", json_str);
|
||||||
|
|
||||||
|
// Try to parse as a ToolCall
|
||||||
|
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(json_str) {
|
||||||
|
// Validate that args is an object with reasonable keys
|
||||||
|
if let Some(args_obj) = tool_call.args.as_object() {
|
||||||
|
if Self::has_message_like_keys(args_obj) {
|
||||||
|
debug!(
|
||||||
|
"Detected malformed tool call with message-like keys, skipping"
|
||||||
|
);
|
||||||
|
self.in_json_tool_call = false;
|
||||||
|
self.json_tool_start = None;
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
debug!("Successfully parsed valid JSON tool call: {:?}", tool_call);
|
||||||
|
self.in_json_tool_call = false;
|
||||||
|
self.json_tool_start = None;
|
||||||
|
return Some(tool_call);
|
||||||
|
}
|
||||||
|
debug!("Tool call args is not an object, skipping");
|
||||||
|
} else {
|
||||||
|
debug!("Failed to parse JSON tool call: {}", json_str);
|
||||||
|
}
|
||||||
|
// Reset and continue looking
|
||||||
|
self.in_json_tool_call = false;
|
||||||
|
self.json_tool_start = None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse ALL JSON tool calls from the accumulated text buffer.
|
||||||
|
/// This finds all complete tool calls, not just the last one.
|
||||||
|
fn try_parse_all_json_tool_calls_from_buffer(&self) -> Vec<ToolCall> {
|
||||||
|
let mut tool_calls = Vec::new();
|
||||||
|
let mut search_start = 0;
|
||||||
|
|
||||||
|
while search_start < self.text_buffer.len() {
|
||||||
|
let search_text = &self.text_buffer[search_start..];
|
||||||
|
|
||||||
|
// Find the next tool call pattern
|
||||||
|
if let Some(relative_pos) = Self::find_first_tool_call_start(search_text) {
|
||||||
|
let abs_start = search_start + relative_pos;
|
||||||
|
let json_text = &self.text_buffer[abs_start..];
|
||||||
|
|
||||||
|
// Try to find a complete JSON object
|
||||||
|
if let Some(end_pos) = Self::find_complete_json_object_end(json_text) {
|
||||||
|
let json_str = &json_text[..=end_pos];
|
||||||
|
|
||||||
|
if let Ok(tool_call) = serde_json::from_str::<ToolCall>(json_str) {
|
||||||
|
if let Some(args_obj) = tool_call.args.as_object() {
|
||||||
|
if !Self::has_message_like_keys(args_obj) {
|
||||||
|
debug!(
|
||||||
|
"Found tool call at position {}: {:?}",
|
||||||
|
abs_start, tool_call.tool
|
||||||
|
);
|
||||||
|
tool_calls.push(tool_call);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Move past this tool call
|
||||||
|
search_start = abs_start + end_pos + 1;
|
||||||
|
} else {
|
||||||
|
// Incomplete JSON, stop searching
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No more tool call patterns found
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tool_calls
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the accumulated text content (excluding tool calls).
|
||||||
|
pub fn get_text_content(&self) -> &str {
|
||||||
|
&self.text_buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get content before a specific position (for display purposes).
|
||||||
|
pub fn get_content_before_position(&self, pos: usize) -> String {
|
||||||
|
if pos <= self.text_buffer.len() {
|
||||||
|
self.text_buffer[..pos].to_string()
|
||||||
|
} else {
|
||||||
|
self.text_buffer.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if the message has been stopped/finished.
|
||||||
|
pub fn is_message_stopped(&self) -> bool {
|
||||||
|
self.message_stopped
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if the text buffer contains an incomplete JSON tool call.
|
||||||
|
/// This detects cases where the LLM started emitting a tool call but the stream ended
|
||||||
|
/// before the JSON was complete (truncated output).
|
||||||
|
pub fn has_incomplete_tool_call(&self) -> bool {
|
||||||
|
// Only check the unconsumed portion of the buffer
|
||||||
|
let unchecked_buffer = &self.text_buffer[self.last_consumed_position..];
|
||||||
|
if let Some(start_pos) = Self::find_last_tool_call_start(unchecked_buffer) {
|
||||||
|
let json_text = &unchecked_buffer[start_pos..];
|
||||||
|
// If NOT complete, it's an incomplete tool call
|
||||||
|
Self::find_complete_json_object_end(json_text).is_none()
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if the text buffer contains an unexecuted tool call.
|
||||||
|
/// This detects cases where the LLM emitted a complete tool call JSON
|
||||||
|
/// but it wasn't parsed/executed (e.g., due to parsing issues).
|
||||||
|
pub fn has_unexecuted_tool_call(&self) -> bool {
|
||||||
|
// Only check the unconsumed portion of the buffer
|
||||||
|
let unchecked_buffer = &self.text_buffer[self.last_consumed_position..];
|
||||||
|
if let Some(start_pos) = Self::find_last_tool_call_start(unchecked_buffer) {
|
||||||
|
let json_text = &unchecked_buffer[start_pos..];
|
||||||
|
// If the JSON IS complete, it means there's an unexecuted tool call
|
||||||
|
if let Some(json_end) = Self::find_complete_json_object_end(json_text) {
|
||||||
|
let json_only = &json_text[..=json_end];
|
||||||
|
return serde_json::from_str::<serde_json::Value>(json_only).is_ok();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mark all tool calls up to the current buffer position as consumed/executed.
|
||||||
|
/// This prevents has_unexecuted_tool_call() from returning true for already-executed tools.
|
||||||
|
pub fn mark_tool_calls_consumed(&mut self) {
|
||||||
|
self.last_consumed_position = self.text_buffer.len();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find the end position (byte index) of a complete JSON object in the text.
|
||||||
|
/// Returns None if no complete JSON object is found.
|
||||||
|
pub fn find_complete_json_object_end(text: &str) -> Option<usize> {
|
||||||
|
let mut brace_count = 0;
|
||||||
|
let mut in_string = false;
|
||||||
|
let mut escape_next = false;
|
||||||
|
let mut found_start = false;
|
||||||
|
|
||||||
|
for (i, ch) in text.char_indices() {
|
||||||
|
if escape_next {
|
||||||
|
escape_next = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
match ch {
|
||||||
|
'\\' => escape_next = true,
|
||||||
|
'"' if !escape_next => in_string = !in_string,
|
||||||
|
'{' if !in_string => {
|
||||||
|
brace_count += 1;
|
||||||
|
found_start = true;
|
||||||
|
}
|
||||||
|
'}' if !in_string => {
|
||||||
|
brace_count -= 1;
|
||||||
|
if brace_count == 0 && found_start {
|
||||||
|
return Some(i); // Return the byte index of the closing brace
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None // No complete JSON object found
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reset the parser state for a new message.
|
||||||
|
pub fn reset(&mut self) {
|
||||||
|
self.text_buffer.clear();
|
||||||
|
self.last_consumed_position = 0;
|
||||||
|
self.message_stopped = false;
|
||||||
|
self.in_json_tool_call = false;
|
||||||
|
self.json_tool_start = None;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the current text buffer length (for position tracking).
|
||||||
|
pub fn text_buffer_len(&self) -> usize {
|
||||||
|
self.text_buffer.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if currently parsing a JSON tool call (for debugging).
|
||||||
|
pub fn is_in_json_tool_call(&self) -> bool {
|
||||||
|
self.in_json_tool_call
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the JSON tool start position (for debugging).
|
||||||
|
pub fn json_tool_start_position(&self) -> Option<usize> {
|
||||||
|
self.json_tool_start
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_find_complete_json_object_end_simple() {
|
||||||
|
let text = r#"{"tool":"shell","args":{"command":"ls"}}"#;
|
||||||
|
assert_eq!(
|
||||||
|
StreamingToolParser::find_complete_json_object_end(text),
|
||||||
|
Some(text.len() - 1)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_find_complete_json_object_end_nested() {
|
||||||
|
let text = r#"{"tool":"write","args":{"content":"{nested}"}}"#;
|
||||||
|
assert_eq!(
|
||||||
|
StreamingToolParser::find_complete_json_object_end(text),
|
||||||
|
Some(text.len() - 1)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_find_complete_json_object_end_incomplete() {
|
||||||
|
let text = r#"{"tool":"shell","args":{"command":"ls""#;
|
||||||
|
assert_eq!(StreamingToolParser::find_complete_json_object_end(text), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tool_call_patterns() {
|
||||||
|
// Test that all patterns are detected
|
||||||
|
assert!(StreamingToolParser::find_first_tool_call_start(r#"{"tool":"test"}"#).is_some());
|
||||||
|
assert!(StreamingToolParser::find_first_tool_call_start(r#"{ "tool":"test"}"#).is_some());
|
||||||
|
assert!(StreamingToolParser::find_first_tool_call_start(r#"{"tool" :"test"}"#).is_some());
|
||||||
|
assert!(StreamingToolParser::find_first_tool_call_start(r#"{ "tool" :"test"}"#).is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parser_reset() {
|
||||||
|
let mut parser = StreamingToolParser::new();
|
||||||
|
parser.text_buffer = "some content".to_string();
|
||||||
|
parser.message_stopped = true;
|
||||||
|
parser.last_consumed_position = 5;
|
||||||
|
|
||||||
|
parser.reset();
|
||||||
|
|
||||||
|
assert!(parser.text_buffer.is_empty());
|
||||||
|
assert!(!parser.message_stopped);
|
||||||
|
assert_eq!(parser.last_consumed_position, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
440
crates/g3-core/src/utils.rs
Normal file
440
crates/g3-core/src/utils.rs
Normal file
@@ -0,0 +1,440 @@
|
|||||||
|
//! Utility functions for diff parsing, shell escaping, and JSON fixing.
|
||||||
|
//!
|
||||||
|
//! This module contains helper functions used by the agent for:
|
||||||
|
//! - Applying unified diffs to strings
|
||||||
|
//! - Shell command escaping
|
||||||
|
//! - JSON quote fixing
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use tracing::debug;
|
||||||
|
|
||||||
|
/// Apply unified diff to an input string with optional [start, end) bounds.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `file_content` - The original file content
|
||||||
|
/// * `diff` - The unified diff to apply
|
||||||
|
/// * `start_char` - Optional start character position (0-indexed, inclusive)
|
||||||
|
/// * `end_char` - Optional end character position (0-indexed, exclusive)
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// The modified content with the diff applied
|
||||||
|
pub fn apply_unified_diff_to_string(
|
||||||
|
file_content: &str,
|
||||||
|
diff: &str,
|
||||||
|
start_char: Option<usize>,
|
||||||
|
end_char: Option<usize>,
|
||||||
|
) -> Result<String> {
|
||||||
|
// Parse full unified diff into hunks and apply sequentially.
|
||||||
|
let hunks = parse_unified_diff_hunks(diff);
|
||||||
|
if hunks.is_empty() {
|
||||||
|
anyhow::bail!(
|
||||||
|
"Invalid diff format. Expected unified diff with @@ hunks or +/- with context lines"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize line endings to avoid CRLF/CR mismatches
|
||||||
|
let content_norm = file_content.replace("\r\n", "\n").replace('\r', "\n");
|
||||||
|
|
||||||
|
// Determine and validate the search range
|
||||||
|
let search_start = start_char.unwrap_or(0);
|
||||||
|
let search_end = end_char.unwrap_or(content_norm.len());
|
||||||
|
|
||||||
|
if search_start > content_norm.len() {
|
||||||
|
anyhow::bail!(
|
||||||
|
"start position {} exceeds file length {}",
|
||||||
|
search_start,
|
||||||
|
content_norm.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if search_end > content_norm.len() {
|
||||||
|
anyhow::bail!(
|
||||||
|
"end position {} exceeds file length {}",
|
||||||
|
search_end,
|
||||||
|
content_norm.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if search_start > search_end {
|
||||||
|
anyhow::bail!(
|
||||||
|
"start position {} is greater than end position {}",
|
||||||
|
search_start,
|
||||||
|
search_end
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the region we're going to modify, ensuring we're at char boundaries
|
||||||
|
// Find the nearest valid char boundaries
|
||||||
|
let start_boundary = if search_start == 0 {
|
||||||
|
0
|
||||||
|
} else {
|
||||||
|
content_norm
|
||||||
|
.char_indices()
|
||||||
|
.find(|(i, _)| *i >= search_start)
|
||||||
|
.map(|(i, _)| i)
|
||||||
|
.unwrap_or(search_start)
|
||||||
|
};
|
||||||
|
let end_boundary = content_norm
|
||||||
|
.char_indices()
|
||||||
|
.find(|(i, _)| *i >= search_end)
|
||||||
|
.map(|(i, _)| i)
|
||||||
|
.unwrap_or(content_norm.len());
|
||||||
|
|
||||||
|
let mut region_content = content_norm[start_boundary..end_boundary].to_string();
|
||||||
|
|
||||||
|
// Apply hunks in order
|
||||||
|
for (idx, (old_block, new_block)) in hunks.iter().enumerate() {
|
||||||
|
debug!(
|
||||||
|
"Applying hunk {}: old_len={}, new_len={}",
|
||||||
|
idx + 1,
|
||||||
|
old_block.len(),
|
||||||
|
new_block.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(pos) = region_content.find(old_block) {
|
||||||
|
let endpos = pos + old_block.len();
|
||||||
|
region_content.replace_range(pos..endpos, new_block);
|
||||||
|
} else {
|
||||||
|
// Not found; provide helpful diagnostics with a short preview
|
||||||
|
let preview_len = old_block.len().min(200);
|
||||||
|
let mut old_preview = old_block[..preview_len].to_string();
|
||||||
|
if old_block.len() > preview_len {
|
||||||
|
old_preview.push_str("...");
|
||||||
|
}
|
||||||
|
|
||||||
|
let range_note = if start_char.is_some() || end_char.is_some() {
|
||||||
|
format!(
|
||||||
|
" (within character range {}:{})",
|
||||||
|
start_boundary, end_boundary
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
anyhow::bail!(
|
||||||
|
"Pattern not found in file{}\nHunk {} failed. Searched for:\n{}",
|
||||||
|
range_note,
|
||||||
|
idx + 1,
|
||||||
|
old_preview
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reconstruct the full content with the modified region
|
||||||
|
let mut result = String::with_capacity(content_norm.len() + region_content.len());
|
||||||
|
result.push_str(&content_norm[..start_boundary]);
|
||||||
|
result.push_str(®ion_content);
|
||||||
|
result.push_str(&content_norm[end_boundary..]);
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a unified diff into a list of hunks as (old_block, new_block).
|
||||||
|
/// Each hunk contains the exact text to search for and the replacement text including context lines.
|
||||||
|
pub fn parse_unified_diff_hunks(diff: &str) -> Vec<(String, String)> {
|
||||||
|
let mut hunks: Vec<(String, String)> = Vec::new();
|
||||||
|
|
||||||
|
let mut old_lines: Vec<String> = Vec::new();
|
||||||
|
let mut new_lines: Vec<String> = Vec::new();
|
||||||
|
let mut in_hunk = false;
|
||||||
|
|
||||||
|
for raw_line in diff.lines() {
|
||||||
|
let line = raw_line;
|
||||||
|
|
||||||
|
// Skip common diff headers
|
||||||
|
if line.starts_with("diff ")
|
||||||
|
|| line.starts_with("index ")
|
||||||
|
|| line.starts_with("new file mode")
|
||||||
|
|| line.starts_with("deleted file mode")
|
||||||
|
{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if line.starts_with("--- ") || line.starts_with("+++ ") {
|
||||||
|
// File header lines — ignore
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if line.starts_with("@@") {
|
||||||
|
// Starting a new hunk — flush previous if present
|
||||||
|
if in_hunk && (!old_lines.is_empty() || !new_lines.is_empty()) {
|
||||||
|
hunks.push((old_lines.join("\n"), new_lines.join("\n")));
|
||||||
|
old_lines.clear();
|
||||||
|
new_lines.clear();
|
||||||
|
}
|
||||||
|
in_hunk = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !in_hunk {
|
||||||
|
// Some minimal diffs may omit @@; start collecting once we see diff markers
|
||||||
|
if line.starts_with(' ')
|
||||||
|
|| (line.starts_with('-') && !line.starts_with("---"))
|
||||||
|
|| (line.starts_with('+') && !line.starts_with("+++"))
|
||||||
|
{
|
||||||
|
in_hunk = true;
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(content) = line.strip_prefix(' ') {
|
||||||
|
old_lines.push(content.to_string());
|
||||||
|
new_lines.push(content.to_string());
|
||||||
|
} else if line.starts_with('+') && !line.starts_with("+++") {
|
||||||
|
new_lines.push(line[1..].to_string());
|
||||||
|
} else if line.starts_with('-') && !line.starts_with("---") {
|
||||||
|
old_lines.push(line[1..].to_string());
|
||||||
|
} else if line.starts_with('\\') {
|
||||||
|
// Example: "\\ No newline at end of file" — ignore
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
// Unknown line type — ignore
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if in_hunk && (!old_lines.is_empty() || !new_lines.is_empty()) {
|
||||||
|
hunks.push((old_lines.join("\n"), new_lines.join("\n")));
|
||||||
|
}
|
||||||
|
|
||||||
|
hunks
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function to properly escape shell commands.
|
||||||
|
/// Handles file paths with spaces and other special characters.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn shell_escape_command(command: &str) -> String {
|
||||||
|
let parts: Vec<&str> = command.split_whitespace().collect();
|
||||||
|
if parts.is_empty() {
|
||||||
|
return command.to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
let cmd = parts[0];
|
||||||
|
|
||||||
|
// Commands that typically take file paths as arguments
|
||||||
|
let file_commands = [
|
||||||
|
"cat", "ls", "cp", "mv", "rm", "chmod", "chown", "file", "head", "tail", "wc", "grep",
|
||||||
|
];
|
||||||
|
|
||||||
|
if file_commands.contains(&cmd) {
|
||||||
|
// For file commands, we need to be smarter about escaping
|
||||||
|
// Check if the command already has proper quoting
|
||||||
|
if command.contains('"') || command.contains('\'') {
|
||||||
|
// Already has some quoting, use as-is
|
||||||
|
return command.to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look for file paths that need escaping (contain spaces but aren't quoted)
|
||||||
|
let mut escaped_command = String::new();
|
||||||
|
let mut in_quotes = false;
|
||||||
|
let mut current_word = String::new();
|
||||||
|
let mut words = Vec::new();
|
||||||
|
|
||||||
|
for ch in command.chars() {
|
||||||
|
match ch {
|
||||||
|
' ' if !in_quotes => {
|
||||||
|
if !current_word.is_empty() {
|
||||||
|
words.push(current_word.clone());
|
||||||
|
current_word.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
'"' => {
|
||||||
|
in_quotes = !in_quotes;
|
||||||
|
current_word.push(ch);
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
current_word.push(ch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !current_word.is_empty() {
|
||||||
|
words.push(current_word);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reconstruct the command with proper escaping
|
||||||
|
for (i, word) in words.iter().enumerate() {
|
||||||
|
if i > 0 {
|
||||||
|
escaped_command.push(' ');
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this word looks like a file path (contains / or ~) and has spaces, quote it
|
||||||
|
if word.contains('/') || word.starts_with('~') {
|
||||||
|
if word.contains(' ') && !word.starts_with('"') && !word.starts_with('\'') {
|
||||||
|
escaped_command.push_str(&format!("\"{}\"", word));
|
||||||
|
} else {
|
||||||
|
escaped_command.push_str(word);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
escaped_command.push_str(word);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
escaped_command
|
||||||
|
} else {
|
||||||
|
// For non-file commands, use the original command
|
||||||
|
command.to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function to fix nested quotes in shell commands within JSON.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn fix_nested_quotes_in_shell_command(json_str: &str) -> String {
|
||||||
|
// Look for the pattern: "command": "
|
||||||
|
if let Some(command_start) = json_str.find(r#""command": ""#) {
|
||||||
|
let command_value_start = command_start + r#""command": ""#.len();
|
||||||
|
|
||||||
|
// Find the end of the command string by looking for the pattern "}
|
||||||
|
if let Some(end_marker) = json_str[command_value_start..].find(r#"" }"#) {
|
||||||
|
let command_end = command_value_start + end_marker;
|
||||||
|
|
||||||
|
let before = &json_str[..command_value_start];
|
||||||
|
let command_content = &json_str[command_value_start..command_end];
|
||||||
|
let after = &json_str[command_end..];
|
||||||
|
|
||||||
|
// Fix the command content by properly escaping quotes
|
||||||
|
let mut fixed_command = String::new();
|
||||||
|
let mut chars = command_content.chars().peekable();
|
||||||
|
|
||||||
|
while let Some(ch) = chars.next() {
|
||||||
|
match ch {
|
||||||
|
'"' => {
|
||||||
|
// Check if this quote is already escaped
|
||||||
|
if fixed_command.ends_with('\\') {
|
||||||
|
fixed_command.push(ch); // Already escaped, keep as-is
|
||||||
|
} else {
|
||||||
|
fixed_command.push_str(r#"\""#); // Escape the quote
|
||||||
|
}
|
||||||
|
}
|
||||||
|
'\\' => {
|
||||||
|
// Check what follows the backslash
|
||||||
|
if let Some(&next_ch) = chars.peek() {
|
||||||
|
if next_ch == '"' {
|
||||||
|
// This is an escaped quote, keep the backslash
|
||||||
|
fixed_command.push(ch);
|
||||||
|
} else {
|
||||||
|
// Regular backslash, escape it
|
||||||
|
fixed_command.push_str(r#"\\"#);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Backslash at end, escape it
|
||||||
|
fixed_command.push_str(r#"\\"#);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => fixed_command.push(ch),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return format!("{}{}{}", before, fixed_command, after);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: if we can't parse the structure, return as-is
|
||||||
|
json_str.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function to fix mixed quotes in JSON (single quotes where double quotes should be).
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn fix_mixed_quotes_in_json(json_str: &str) -> String {
|
||||||
|
let mut result = String::new();
|
||||||
|
let mut chars = json_str.chars().peekable();
|
||||||
|
let mut in_string = false;
|
||||||
|
let mut string_delimiter = '"';
|
||||||
|
|
||||||
|
while let Some(ch) = chars.next() {
|
||||||
|
match ch {
|
||||||
|
'"' if !in_string => {
|
||||||
|
// Start of a double-quoted string
|
||||||
|
in_string = true;
|
||||||
|
string_delimiter = '"';
|
||||||
|
result.push(ch);
|
||||||
|
}
|
||||||
|
'\'' if !in_string => {
|
||||||
|
// Start of a single-quoted string - convert to double quotes
|
||||||
|
in_string = true;
|
||||||
|
string_delimiter = '\'';
|
||||||
|
result.push('"'); // Convert single quote to double quote
|
||||||
|
}
|
||||||
|
c if in_string && c == string_delimiter => {
|
||||||
|
// End of current string
|
||||||
|
if string_delimiter == '\'' {
|
||||||
|
result.push('"'); // Convert single quote to double quote
|
||||||
|
} else {
|
||||||
|
result.push(c);
|
||||||
|
}
|
||||||
|
in_string = false;
|
||||||
|
}
|
||||||
|
'"' if in_string && string_delimiter == '\'' => {
|
||||||
|
// Double quote inside single-quoted string - escape it
|
||||||
|
result.push_str(r#"\""#);
|
||||||
|
}
|
||||||
|
'\\' if in_string => {
|
||||||
|
// Escape sequence - preserve it
|
||||||
|
result.push(ch);
|
||||||
|
if chars.peek().is_some() {
|
||||||
|
result.push(chars.next().unwrap());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
result.push(ch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parses_minimal_unified_diff_without_hunk_header() {
|
||||||
|
let diff = "--- old\n-old text\n+++ new\n+new text\n";
|
||||||
|
let hunks = parse_unified_diff_hunks(diff);
|
||||||
|
assert_eq!(hunks.len(), 1);
|
||||||
|
assert_eq!(hunks[0].0, "old text");
|
||||||
|
assert_eq!(hunks[0].1, "new text");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parses_diff_with_context_and_hunk_headers() {
|
||||||
|
let diff = "@@ -1,3 +1,3 @@\n common\n-old\n+new\n common2\n";
|
||||||
|
let hunks = parse_unified_diff_hunks(diff);
|
||||||
|
assert_eq!(hunks.len(), 1);
|
||||||
|
assert_eq!(hunks[0].0, "common\nold\ncommon2");
|
||||||
|
assert_eq!(hunks[0].1, "common\nnew\ncommon2");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn apply_multi_hunk_unified_diff_to_string() {
|
||||||
|
let original = "line 1\nkeep\nold A\nkeep 2\nold B\nkeep 3\n";
|
||||||
|
let diff =
|
||||||
|
"@@ -1,6 +1,6 @@\n line 1\n keep\n-old A\n+new A\n keep 2\n-old B\n+new B\n keep 3\n";
|
||||||
|
let result = apply_unified_diff_to_string(original, diff, None, None).unwrap();
|
||||||
|
let expected = "line 1\nkeep\nnew A\nkeep 2\nnew B\nkeep 3\n";
|
||||||
|
assert_eq!(result, expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn apply_diff_within_range_only() {
|
||||||
|
let original = "A\nold\nB\nold\nC\n";
|
||||||
|
// Only the first 'old' should be replaced due to range
|
||||||
|
let diff = "@@ -1,3 +1,3 @@\n A\n-old\n+NEW\n B\n";
|
||||||
|
let start = 0usize; // Start of file
|
||||||
|
let end = original.find("B\n").unwrap() + 2; // up to end of line 'B\n'
|
||||||
|
let result = apply_unified_diff_to_string(original, diff, Some(start), Some(end)).unwrap();
|
||||||
|
let expected = "A\nNEW\nB\nold\nC\n";
|
||||||
|
assert_eq!(result, expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn shell_escape_preserves_simple_commands() {
|
||||||
|
assert_eq!(shell_escape_command("ls -la"), "ls -la");
|
||||||
|
assert_eq!(shell_escape_command("echo hello"), "echo hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn fix_mixed_quotes_converts_single_to_double() {
|
||||||
|
let input = "{'key': 'value'}";
|
||||||
|
let result = fix_mixed_quotes_in_json(input);
|
||||||
|
assert_eq!(result, "{\"key\": \"value\"}");
|
||||||
|
}
|
||||||
|
}
|
||||||
133
crates/g3-core/src/webdriver_session.rs
Normal file
133
crates/g3-core/src/webdriver_session.rs
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
//! Unified WebDriver session abstraction.
|
||||||
|
//!
|
||||||
|
//! This module provides a unified interface for browser automation
|
||||||
|
//! that can work with either Safari or Chrome WebDriver.
|
||||||
|
|
||||||
|
use g3_computer_control::{ChromeDriver, SafariDriver, WebDriverController, WebElement};
|
||||||
|
|
||||||
|
/// Unified WebDriver session that can hold either Safari or Chrome driver.
|
||||||
|
pub enum WebDriverSession {
|
||||||
|
Safari(SafariDriver),
|
||||||
|
Chrome(ChromeDriver),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl WebDriverController for WebDriverSession {
|
||||||
|
async fn navigate(&mut self, url: &str) -> anyhow::Result<()> {
|
||||||
|
match self {
|
||||||
|
WebDriverSession::Safari(driver) => driver.navigate(url).await,
|
||||||
|
WebDriverSession::Chrome(driver) => driver.navigate(url).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn current_url(&self) -> anyhow::Result<String> {
|
||||||
|
match self {
|
||||||
|
WebDriverSession::Safari(driver) => driver.current_url().await,
|
||||||
|
WebDriverSession::Chrome(driver) => driver.current_url().await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn title(&self) -> anyhow::Result<String> {
|
||||||
|
match self {
|
||||||
|
WebDriverSession::Safari(driver) => driver.title().await,
|
||||||
|
WebDriverSession::Chrome(driver) => driver.title().await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn find_element(
|
||||||
|
&mut self,
|
||||||
|
selector: &str,
|
||||||
|
) -> anyhow::Result<WebElement> {
|
||||||
|
match self {
|
||||||
|
WebDriverSession::Safari(driver) => driver.find_element(selector).await,
|
||||||
|
WebDriverSession::Chrome(driver) => driver.find_element(selector).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn find_elements(
|
||||||
|
&mut self,
|
||||||
|
selector: &str,
|
||||||
|
) -> anyhow::Result<Vec<WebElement>> {
|
||||||
|
match self {
|
||||||
|
WebDriverSession::Safari(driver) => driver.find_elements(selector).await,
|
||||||
|
WebDriverSession::Chrome(driver) => driver.find_elements(selector).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute_script(
|
||||||
|
&mut self,
|
||||||
|
script: &str,
|
||||||
|
args: Vec<serde_json::Value>,
|
||||||
|
) -> anyhow::Result<serde_json::Value> {
|
||||||
|
match self {
|
||||||
|
WebDriverSession::Safari(driver) => driver.execute_script(script, args).await,
|
||||||
|
WebDriverSession::Chrome(driver) => driver.execute_script(script, args).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn page_source(&self) -> anyhow::Result<String> {
|
||||||
|
match self {
|
||||||
|
WebDriverSession::Safari(driver) => driver.page_source().await,
|
||||||
|
WebDriverSession::Chrome(driver) => driver.page_source().await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn screenshot(&mut self, path: &str) -> anyhow::Result<()> {
|
||||||
|
match self {
|
||||||
|
WebDriverSession::Safari(driver) => driver.screenshot(path).await,
|
||||||
|
WebDriverSession::Chrome(driver) => driver.screenshot(path).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn close(&mut self) -> anyhow::Result<()> {
|
||||||
|
match self {
|
||||||
|
WebDriverSession::Safari(driver) => driver.close().await,
|
||||||
|
WebDriverSession::Chrome(driver) => driver.close().await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn quit(self) -> anyhow::Result<()> {
|
||||||
|
match self {
|
||||||
|
WebDriverSession::Safari(driver) => driver.quit().await,
|
||||||
|
WebDriverSession::Chrome(driver) => driver.quit().await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Additional methods for WebDriverSession that aren't part of the WebDriverController trait
|
||||||
|
impl WebDriverSession {
|
||||||
|
pub async fn back(&mut self) -> anyhow::Result<()> {
|
||||||
|
match self {
|
||||||
|
WebDriverSession::Safari(driver) => driver.back().await,
|
||||||
|
WebDriverSession::Chrome(driver) => driver.back().await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn forward(&mut self) -> anyhow::Result<()> {
|
||||||
|
match self {
|
||||||
|
WebDriverSession::Safari(driver) => driver.forward().await,
|
||||||
|
WebDriverSession::Chrome(driver) => driver.forward().await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn refresh(&mut self) -> anyhow::Result<()> {
|
||||||
|
match self {
|
||||||
|
WebDriverSession::Safari(driver) => driver.refresh().await,
|
||||||
|
WebDriverSession::Chrome(driver) => driver.refresh().await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_webdriver_session_enum_variants() {
|
||||||
|
// This test just verifies the enum structure compiles correctly
|
||||||
|
// Actual WebDriver tests would require a running browser
|
||||||
|
fn _assert_send<T: Send>() {}
|
||||||
|
fn _assert_sync<T: Sync>() {}
|
||||||
|
// WebDriverSession should be Send but not necessarily Sync due to internal state
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user