replace tesseract with apple vision

This commit is contained in:
Dhanji Prasanna
2025-10-24 15:35:47 +11:00
parent d0ac222e2e
commit 61d748034d
16 changed files with 785 additions and 318 deletions

View File

@@ -3,6 +3,9 @@ name = "g3-computer-control"
version = "0.1.0"
edition = "2021"
[build-dependencies]
# Only needed for building Swift bridge on macOS
[dependencies]
# Workspace dependencies
tokio = { workspace = true }
@@ -20,9 +23,6 @@ async-trait = "0.1"
# WebDriver support
fantoccini = "0.21"
# OCR dependencies
tesseract = "0.14"
# macOS dependencies
[target.'cfg(target_os = "macos")'.dependencies]
core-graphics = "0.23"

View File

@@ -0,0 +1,63 @@
use std::env;
use std::path::PathBuf;
use std::process::Command;
/// Locates cargo's per-profile artifact directory from a build script's `OUT_DIR`.
///
/// Cargo places `OUT_DIR` at `<target-dir>/[<triple>/]<profile>/build/<pkg>-<hash>/out`,
/// so the directory that holds the final binaries is three levels above it.
/// Returns `None` when `out_dir` has fewer than three parent components.
///
/// NOTE(review): this relies on cargo's current on-disk layout, which is a
/// de-facto convention rather than a documented guarantee.
fn profile_dir_from_out_dir(out_dir: &std::path::Path) -> Option<PathBuf> {
    out_dir.ancestors().nth(3).map(|dir| dir.to_path_buf())
}

/// Build script entry point: compiles the Swift `VisionBridge` package and
/// emits the linker directives needed to use it.
///
/// Does nothing unless the *target* OS is macOS. On macOS it:
/// 1. runs `swift build -c release` inside `vision-bridge/`,
/// 2. copies `libVisionBridge.dylib` next to the final binaries so the
///    `@executable_path` / `@loader_path` rpaths can resolve it at runtime,
/// 3. tells rustc to link the dylib and the Apple frameworks it depends on.
///
/// # Panics
/// Panics (failing the build) if cargo's environment variables are missing,
/// the Swift build fails, or the dylib cannot be copied.
fn main() {
    // CARGO_CFG_TARGET_OS describes the target, not the host, so this check is
    // also correct when cross-compiling.
    let target_os = env::var("CARGO_CFG_TARGET_OS")
        .expect("CARGO_CFG_TARGET_OS is set by cargo for build scripts");
    if target_os != "macos" {
        return;
    }

    println!("cargo:rerun-if-changed=vision-bridge/Sources/VisionBridge/VisionOCR.swift");
    println!("cargo:rerun-if-changed=vision-bridge/Sources/VisionBridge/VisionBridge.h");
    println!("cargo:rerun-if-changed=vision-bridge/Package.swift");

    let manifest_dir = PathBuf::from(
        env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR is set by cargo"),
    );
    let vision_bridge_dir = manifest_dir.join("vision-bridge");

    // Build Swift package
    println!("cargo:warning=Building VisionBridge Swift package...");
    let build_status = Command::new("swift")
        .args(["build", "-c", "release"])
        .current_dir(&vision_bridge_dir)
        .status()
        .expect("Failed to build Swift package");
    if !build_status.success() {
        panic!("Swift build failed");
    }

    // Find the built library
    let lib_path = vision_bridge_dir
        .join(".build/release")
        .canonicalize()
        .expect("Failed to find .build/release directory");

    // Work out where cargo puts the final binaries. Deriving this from OUT_DIR
    // honours CARGO_TARGET_DIR, `--target <triple>` and custom profiles; the
    // previous `<workspace>/target/<PROFILE>` guess is kept only as a fallback.
    let output_dir = env::var_os("OUT_DIR")
        .map(PathBuf::from)
        .and_then(|out_dir| profile_dir_from_out_dir(&out_dir))
        .unwrap_or_else(|| {
            let profile = env::var("PROFILE").unwrap_or_else(|_| "debug".to_string());
            manifest_dir
                .parent()
                .and_then(|crates_dir| crates_dir.parent())
                .unwrap_or(manifest_dir.as_path())
                .join("target")
                .join(profile)
        });
    std::fs::create_dir_all(&output_dir).unwrap_or_else(|err| {
        panic!("Failed to create output directory {}: {}", output_dir.display(), err)
    });

    // Copy the dylib to the output directory so it can be found at runtime
    let dylib_src = lib_path.join("libVisionBridge.dylib");
    let dylib_dst = output_dir.join("libVisionBridge.dylib");
    std::fs::copy(&dylib_src, &dylib_dst).unwrap_or_else(|err| {
        panic!(
            "Failed to copy dylib from {} to {}: {}",
            dylib_src.display(),
            dylib_dst.display(),
            err
        )
    });
    println!("cargo:warning=Copied libVisionBridge.dylib to {}", dylib_dst.display());

    // Add rpath so the dylib can be found at runtime
    println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path");
    println!("cargo:rustc-link-arg=-Wl,-rpath,@loader_path");
    println!("cargo:rustc-link-search=native={}", lib_path.display());
    println!("cargo:rustc-link-lib=dylib=VisionBridge");

    // Link required frameworks
    println!("cargo:rustc-link-lib=framework=Vision");
    println!("cargo:rustc-link-lib=framework=AppKit");
    println!("cargo:rustc-link-lib=framework=Foundation");
    println!("cargo:rustc-link-lib=framework=CoreGraphics");
    println!("cargo:rustc-link-lib=framework=CoreImage");

    println!("cargo:warning=VisionBridge built successfully at {}", lib_path.display());
}

View File

@@ -0,0 +1,85 @@
use g3_computer_control::ocr::{OCREngine, DefaultOCR};
use anyhow::Result;
/// Manual smoke test for the platform's default OCR engine.
///
/// Uses `/tmp/safari_test.png` as input, capturing a 1200x800 screenshot with
/// `screencapture` first if the file does not exist. Prints up to 20 recognized
/// text elements with their bounding boxes and confidence, followed by timing.
///
/// # Errors
/// Returns an error if the OCR engine cannot be created, the screenshot
/// command cannot be run or fails, or OCR itself fails.
#[tokio::main]
async fn main() -> Result<()> {
    // Longest text, in characters, shown per table row before it is ellipsized.
    const MAX_TEXT_CHARS: usize = 37;

    println!("🧪 Testing Apple Vision OCR");
    println!("===========================\n");

    // Initialize OCR engine
    println!("📦 Initializing OCR engine...");
    let ocr = DefaultOCR::new()?;
    println!("✅ OCR engine: {}\n", ocr.name());

    // Check if test image exists
    let test_image = "/tmp/safari_test.png";
    if !std::path::Path::new(test_image).exists() {
        println!("⚠️ Test image not found: {}", test_image);
        println!(" Creating a screenshot...");
        let status = std::process::Command::new("screencapture")
            .arg("-x")
            .arg("-R")
            .arg("0,0,1200,800")
            .arg(test_image)
            .status()?;
        if !status.success() {
            anyhow::bail!("Failed to create screenshot");
        }
        println!("✅ Screenshot created\n");
    }

    // Run OCR
    println!("🔍 Running Apple Vision OCR on {}...", test_image);
    let start = std::time::Instant::now();
    let locations = ocr.extract_text_with_locations(test_image).await?;
    let duration = start.elapsed();
    println!("✅ OCR completed in {:.3}s\n", duration.as_secs_f64());

    // Display results
    println!("📊 Results:");
    println!(" Found {} text elements\n", locations.len());
    if locations.is_empty() {
        println!("⚠️ No text found in image");
    } else {
        println!(" Top 20 results:");
        println!(" {:<4} {:<40} {:<15} {:<12} {:<8}", "#", "Text", "Position", "Size", "Conf");
        println!(" {}", "-".repeat(85));
        for (i, loc) in locations.iter().take(20).enumerate() {
            // Truncate on a character boundary: slicing by byte offset panics
            // when the cut falls inside a multi-byte UTF-8 sequence, which OCR
            // output (accents, symbols, CJK) can easily contain.
            let text = match loc.text.char_indices().nth(MAX_TEXT_CHARS) {
                Some((cut, _)) => format!("{}...", &loc.text[..cut]),
                None => loc.text.clone(),
            };
            println!(" {:<4} {:<40} ({:>4},{:>4}) {:>4}x{:<4} {:.2}",
                i + 1,
                text,
                loc.x,
                loc.y,
                loc.width,
                loc.height,
                loc.confidence
            );
        }
        if locations.len() > 20 {
            println!("\n ... and {} more", locations.len() - 20);
        }

        // Performance comparison
        println!("\n📈 Performance:");
        println!(" OCR Speed: {:.3}s", duration.as_secs_f64());
        println!(" Text elements: {}", locations.len());
        // `locations` is non-empty in this branch, so the division is well defined.
        println!(" Avg per element: {:.1}ms", duration.as_millis() as f64 / locations.len() as f64);
    }

    println!("\n✅ Test complete!");
    Ok(())
}

View File

@@ -3,6 +3,7 @@
pub mod types;
pub mod platform;
pub mod ocr;
pub mod webdriver;
pub mod macax;
@@ -25,11 +26,11 @@ pub trait ComputerController: Send + Sync {
async fn extract_text_from_screen(&self, region: Rect) -> Result<String>;
async fn extract_text_from_image(&self, path: &str) -> Result<String>;
async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>>;
async fn find_text_on_screen(&self, search_text: &str) -> Result<Option<TextLocation>>;
async fn find_text_in_app(&self, app_name: &str, search_text: &str) -> Result<Option<TextLocation>>;
// Mouse operations
fn move_mouse(&self, x: i32, y: i32) -> Result<()>;
fn click_at(&self, x: i32, y: i32) -> Result<()>;
fn click_at(&self, x: i32, y: i32, app_name: Option<&str>) -> Result<()>;
}
// Platform-specific constructor

View File

@@ -0,0 +1,26 @@
use crate::types::TextLocation;
use anyhow::Result;
use async_trait::async_trait;
/// Abstraction over a text-recognition backend that reports where each piece
/// of text was found.
///
/// Implementors must be `Send + Sync` so an engine can be shared across tasks.
#[async_trait]
pub trait OCREngine: Send + Sync {
    /// Human-readable name of this engine, suitable for logs and diagnostics.
    fn name(&self) -> &str;

    /// Runs OCR on the image file at `path` and returns every recognized text
    /// element together with its bounding box and confidence.
    async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>>;
}
// Platform-specific modules
#[cfg(target_os = "macos")]
pub mod vision;
pub mod tesseract;
// Re-export the default OCR engine for the platform
#[cfg(target_os = "macos")]
pub use vision::AppleVisionOCR as DefaultOCR;
#[cfg(not(target_os = "macos"))]
pub use tesseract::TesseractOCR as DefaultOCR;

View File

@@ -0,0 +1,84 @@
use super::OCREngine;
use crate::types::TextLocation;
use anyhow::Result;
use async_trait::async_trait;
/// OCR backend that shells out to the `tesseract` command-line tool
/// (fallback / cross-platform engine).
pub struct TesseractOCR;

impl TesseractOCR {
    /// Creates the engine after confirming that `which tesseract` succeeds.
    ///
    /// # Errors
    /// Returns an error containing installation instructions when `which`
    /// cannot be executed or exits with a non-zero status.
    pub fn new() -> Result<Self> {
        let tesseract_found = matches!(
            std::process::Command::new("which").arg("tesseract").output(),
            Ok(probe) if probe.status.success()
        );

        if !tesseract_found {
            anyhow::bail!("Tesseract OCR is not installed on your system.\n\n\
                To install tesseract:\n macOS: brew install tesseract\n \
                Linux: sudo apt-get install tesseract-ocr (Ubuntu/Debian)\n \
                sudo yum install tesseract (RHEL/CentOS)\n \
                Windows: Download from https://github.com/UB-Mannheim/tesseract/wiki\n\n\
                After installation, restart your terminal and try again.");
        }

        Ok(Self)
    }
}
#[async_trait]
impl OCREngine for TesseractOCR {
    /// Runs `tesseract <path> stdout tsv` and converts each usable TSV row
    /// into a [`TextLocation`].
    ///
    /// Rows are kept only when they have at least 12 columns, the geometry and
    /// confidence columns parse, the trimmed text is non-empty and the
    /// confidence is greater than zero. Confidence is rescaled from 0-100 to 0-1.
    ///
    /// # Errors
    /// Fails when the `tesseract` process cannot be started or exits unsuccessfully.
    async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>> {
        let run = std::process::Command::new("tesseract")
            .args(&[path, "stdout", "tsv"])
            .output()
            .map_err(|e| anyhow::anyhow!("Failed to run tesseract: {}", e))?;

        if !run.status.success() {
            anyhow::bail!("Tesseract failed: {}", String::from_utf8_lossy(&run.stderr));
        }

        let tsv = String::from_utf8_lossy(&run.stdout);

        // Column layout: level, page_num, block_num, par_num, line_num, word_num,
        //                left, top, width, height, conf, text
        let locations = tsv
            .lines()
            .skip(1) // header row
            .filter_map(|row| {
                let cols: Vec<&str> = row.split('\t').collect();
                if cols.len() < 12 {
                    return None;
                }

                let left = cols[6].parse::<i32>().ok()?;
                let top = cols[7].parse::<i32>().ok()?;
                let box_width = cols[8].parse::<i32>().ok()?;
                let box_height = cols[9].parse::<i32>().ok()?;
                let conf = cols[10].parse::<f32>().ok()?;
                let word = cols[11].trim();

                // Written as `conf > 0.0` (not `!(conf <= 0.0)`) so a NaN
                // confidence is rejected as well.
                (!word.is_empty() && conf > 0.0).then(|| TextLocation {
                    text: word.to_string(),
                    x: left,
                    y: top,
                    width: box_width,
                    height: box_height,
                    confidence: conf / 100.0, // 0-100 scale -> 0-1 scale
                })
            })
            .collect();

        Ok(locations)
    }

    /// Display name of this backend.
    fn name(&self) -> &str {
        "Tesseract OCR"
    }
}

View File

@@ -0,0 +1,103 @@
use super::OCREngine;
use crate::types::TextLocation;
use anyhow::{Result, Context};
use async_trait::async_trait;
use std::ffi::{CStr, CString};
use std::os::raw::{c_char, c_float, c_uint};
// FFI bindings to Swift VisionBridge

/// One recognized text region as produced by the Swift bridge.
///
/// Mirrors the `VisionTextBox` struct declared in `VisionBridge.h`. Because of
/// `#[repr(C)]`, field order and types are part of the ABI and must stay in
/// sync with the header and with the Swift-side `CTextBox`.
#[repr(C)]
struct VisionTextBox {
/// NUL-terminated string allocated by the bridge (via `strdup`); may be null.
/// Released by `vision_free_boxes`, never by Rust.
text: *const c_char,
/// UTF-8 byte length of `text` as reported by the bridge.
text_len: c_uint,
/// Left edge of the bounding box, in image pixels.
x: i32,
/// Top edge of the bounding box, in image pixels (top-left origin).
y: i32,
/// Bounding box width, in image pixels.
width: i32,
/// Bounding box height, in image pixels.
height: i32,
/// Confidence value reported by Vision for this observation.
confidence: c_float,
}
// Functions exported by the Swift `VisionBridge` dynamic library
// (declared for C in `VisionBridge.h`).
extern "C" {
/// Runs text recognition on the image at `image_path`.
///
/// `image_path_len` is the number of bytes the bridge reads from
/// `image_path`. On success the function returns `true` and writes a
/// bridge-allocated array of `VisionTextBox` to `*out_boxes` and its element
/// count to `*out_count`; the caller must release it with
/// `vision_free_boxes`. On failure it returns `false` and leaves the out
/// parameters untouched.
fn vision_recognize_text(
image_path: *const c_char,
image_path_len: c_uint,
out_boxes: *mut *mut std::ffi::c_void,
out_count: *mut c_uint,
) -> bool;
/// Frees an array returned through `out_boxes`, including each box's text.
/// `count` must be the value that was written to `out_count`.
fn vision_free_boxes(boxes: *mut std::ffi::c_void, count: c_uint);
}
/// OCR backend that delegates to Apple's Vision framework through the Swift
/// `VisionBridge` library.
pub struct AppleVisionOCR;

impl AppleVisionOCR {
    /// Creates the engine. No setup work is performed, so this always
    /// returns `Ok`.
    pub fn new() -> Result<Self> {
        Ok(AppleVisionOCR)
    }
}
#[async_trait]
impl OCREngine for AppleVisionOCR {
    /// Recognizes text in the image at `path` using the Swift Vision bridge.
    ///
    /// Boxes whose text pointer is null or whose text is empty are dropped.
    /// The bridge-owned array is copied into owned Rust values and then
    /// released before returning.
    ///
    /// # Errors
    /// Fails when `path` contains an interior NUL byte, or when the bridge
    /// reports failure or hands back a null array.
    async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>> {
        let c_path = CString::new(path).context("Failed to convert path to C string")?;

        let mut raw_boxes: *mut std::ffi::c_void = std::ptr::null_mut();
        let mut box_count: c_uint = 0;

        // SAFETY: `c_path` is a valid buffer of at least `path.len()` bytes that
        // outlives the call, and both out-pointers refer to live locals.
        let recognized = unsafe {
            vision_recognize_text(
                c_path.as_ptr(),
                path.len() as c_uint,
                &mut raw_boxes,
                &mut box_count,
            )
        };

        if raw_boxes.is_null() || !recognized {
            anyhow::bail!("Apple Vision OCR failed");
        }

        // SAFETY: on success the bridge returned a non-null array of
        // `box_count` `VisionTextBox` values whose non-null `text` pointers are
        // NUL-terminated (allocated with `strdup` on the Swift side). Everything
        // is copied into owned values before the array is freed, and the array
        // is freed exactly once.
        let locations = unsafe {
            let boxes =
                std::slice::from_raw_parts(raw_boxes as *const VisionTextBox, box_count as usize);

            let converted: Vec<TextLocation> = boxes
                .iter()
                .filter_map(|entry| {
                    if entry.text.is_null() {
                        return None;
                    }
                    let text = CStr::from_ptr(entry.text).to_string_lossy().into_owned();
                    if text.is_empty() {
                        return None;
                    }
                    Some(TextLocation {
                        text,
                        x: entry.x,
                        y: entry.y,
                        width: entry.width,
                        height: entry.height,
                        confidence: entry.confidence,
                    })
                })
                .collect();

            vision_free_boxes(raw_boxes, box_count);
            converted
        };

        Ok(locations)
    }

    /// Display name of this backend.
    fn name(&self) -> &str {
        "Apple Vision Framework"
    }
}

View File

@@ -1,16 +1,21 @@
use crate::{ComputerController, types::{Rect, TextLocation}};
use crate::ocr::{OCREngine, DefaultOCR};
use anyhow::{Result, Context};
use async_trait::async_trait;
use std::path::Path;
use tesseract::Tesseract;
pub struct MacOSController {
// Empty struct for now
ocr_engine: Box<dyn OCREngine>,
#[allow(dead_code)]
ocr_name: String,
}
impl MacOSController {
pub fn new() -> Result<Self> {
Ok(Self {})
let ocr = Box::new(DefaultOCR::new()?);
let ocr_name = ocr.name().to_string();
tracing::info!("Initialized macOS controller with OCR engine: {}", ocr_name);
Ok(Self { ocr_engine: ocr, ocr_name })
}
}
@@ -90,95 +95,21 @@ impl ComputerController for MacOSController {
}
async fn extract_text_from_image(&self, path: &str) -> Result<String> {
// Check if tesseract is available on the system
let tesseract_check = std::process::Command::new("which")
.arg("tesseract")
.output();
if tesseract_check.is_err() || !tesseract_check.as_ref().unwrap().status.success() {
anyhow::bail!("Tesseract OCR is not installed on your system.\n\n\
To install tesseract:\n macOS: brew install tesseract\n \
Linux: sudo apt-get install tesseract-ocr (Ubuntu/Debian)\n \
sudo yum install tesseract (RHEL/CentOS)\n \
Windows: Download from https://github.com/UB-Mannheim/tesseract/wiki\n\n\
After installation, restart your terminal and try again.");
}
// Initialize Tesseract
let tess = Tesseract::new(None, Some("eng"))
.map_err(|e| {
anyhow::anyhow!("Failed to initialize Tesseract: {}\n\n\
This usually means:\n1. Tesseract is not properly installed\n\
2. Language data files are missing\n\nTo fix:\n \
macOS: brew reinstall tesseract\n \
Linux: sudo apt-get install tesseract-ocr-eng\n \
Windows: Reinstall tesseract and ensure language files are included", e)
})?;
let text = tess.set_image(path)
.map_err(|e| anyhow::anyhow!("Failed to load image '{}': {}", path, e))?
.get_text()
.map_err(|e| anyhow::anyhow!("Failed to extract text from image: {}", e))?;
Ok(text)
// Extract all text and concatenate
let locations = self.ocr_engine.extract_text_with_locations(path).await?;
Ok(locations.iter().map(|loc| loc.text.as_str()).collect::<Vec<_>>().join(" "))
}
async fn extract_text_with_locations(&self, path: &str) -> Result<Vec<TextLocation>> {
// For now, use tesseract CLI with TSV output to get bounding boxes
// This is a workaround since the Rust tesseract crate doesn't expose get_component_boxes
let output = std::process::Command::new("tesseract")
.arg(path)
.arg("stdout")
.arg("tsv")
.output()
.map_err(|e| anyhow::anyhow!("Failed to run tesseract: {}", e))?;
if !output.status.success() {
anyhow::bail!("Tesseract failed: {}", String::from_utf8_lossy(&output.stderr));
}
let tsv_text = String::from_utf8_lossy(&output.stdout);
let mut locations = Vec::new();
// Parse TSV output (skip header line)
for (i, line) in tsv_text.lines().enumerate() {
if i == 0 { continue; } // Skip header
let parts: Vec<&str> = line.split('\t').collect();
if parts.len() >= 12 {
// TSV format: level, page_num, block_num, par_num, line_num, word_num,
// left, top, width, height, conf, text
if let (Ok(x), Ok(y), Ok(w), Ok(h), Ok(conf), text) = (
parts[6].parse::<i32>(),
parts[7].parse::<i32>(),
parts[8].parse::<i32>(),
parts[9].parse::<i32>(),
parts[10].parse::<f32>(),
parts[11],
) {
let trimmed = text.trim();
if !trimmed.is_empty() && conf > 0.0 {
locations.push(TextLocation {
text: trimmed.to_string(),
x,
y,
width: w,
height: h,
confidence: conf / 100.0, // Convert from 0-100 to 0-1
});
}
}
}
}
Ok(locations)
// Use the OCR engine
self.ocr_engine.extract_text_with_locations(path).await
}
async fn find_text_on_screen(&self, search_text: &str) -> Result<Option<TextLocation>> {
// Take full screenshot
async fn find_text_in_app(&self, app_name: &str, search_text: &str) -> Result<Option<TextLocation>> {
// Take screenshot of specific app window
let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
let temp_path = format!("{}/Desktop/g3_find_text_{}.png", home, uuid::Uuid::new_v4());
self.take_screenshot(&temp_path, None, None).await?;
let temp_path = format!("{}/Desktop/g3_find_text_{}_{}.png", home, app_name, uuid::Uuid::new_v4());
self.take_screenshot(&temp_path, None, Some(app_name)).await?;
// Extract all text with locations
let locations = self.extract_text_with_locations(&temp_path).await?;
@@ -221,7 +152,44 @@ impl ComputerController for MacOSController {
Ok(())
}
fn click_at(&self, x: i32, y: i32) -> Result<()> {
fn click_at(&self, x: i32, y: i32, app_name: Option<&str>) -> Result<()> {
// If app_name is provided, get window position and offset coordinates
let (global_x, global_y) = if let Some(app) = app_name {
// Get window position using AppleScript
let script = format!(
r#"tell application "{}" to get bounds of window 1"#,
app
);
let output = std::process::Command::new("osascript")
.arg("-e")
.arg(&script)
.output()?;
if output.status.success() {
let bounds_str = String::from_utf8_lossy(&output.stdout);
// Parse bounds: "x1, y1, x2, y2"
let parts: Vec<&str> = bounds_str.trim().split(", ").collect();
if parts.len() >= 2 {
if let (Ok(window_x), Ok(window_y)) = (
parts[0].trim().parse::<i32>(),
parts[1].trim().parse::<i32>(),
) {
// Offset relative coordinates by window position
(x + window_x, y + window_y)
} else {
(x, y) // Fallback to absolute coordinates
}
} else {
(x, y) // Fallback to absolute coordinates
}
} else {
(x, y) // Fallback to absolute coordinates
}
} else {
(x, y) // No app name, use absolute coordinates
};
use core_graphics::event::{
CGEvent, CGEventTapLocation, CGEventType, CGMouseButton,
};
@@ -233,7 +201,7 @@ impl ComputerController for MacOSController {
let source = CGEventSource::new(CGEventSourceStateID::HIDSystemState)
.ok().context("Failed to create event source")?;
let point = CGPoint::new(x as f64, y as f64);
let point = CGPoint::new(global_x as f64, global_y as f64);
// Move mouse to position first
let move_event = CGEvent::new_mouse_event(

View File

@@ -0,0 +1,24 @@
// swift-tools-version:5.9
import PackageDescription
// SwiftPM manifest for the VisionBridge library that the Rust build script
// compiles and links against.
let package = Package(
name: "VisionBridge",
// Minimum deployment target for the bridge.
platforms: [
.macOS(.v11)
],
products: [
// Built as a dynamic library: the Rust build script copies the resulting
// libVisionBridge.dylib and links it with `dylib=VisionBridge`.
.library(
name: "VisionBridge",
type: .dynamic,
targets: ["VisionBridge"]
),
],
targets: [
.target(
name: "VisionBridge",
dependencies: [],
path: "Sources/VisionBridge",
// Public headers live alongside the Swift sources in the target directory.
publicHeadersPath: "."
),
]
)

View File

@@ -0,0 +1,39 @@
#ifndef VisionBridge_h
#define VisionBridge_h
#include <stdint.h>
#include <stdbool.h>
#ifdef __cplusplus
extern "C" {
#endif
// Text box structure for FFI.
// One recognized text region. The field order and types must match the
// Swift-side CTextBox and the Rust-side #[repr(C)] VisionTextBox.
typedef struct {
const char* text;      // NUL-terminated text allocated by the bridge; may be NULL
uint32_t text_len;     // UTF-8 byte length of text
int32_t x;             // left edge in image pixels
int32_t y;             // top edge in image pixels (top-left origin)
int32_t width;         // box width in image pixels
int32_t height;        // box height in image pixels
float confidence;      // confidence reported by Vision for this observation
} VisionTextBox;
// Recognize text in an image and return bounding boxes.
//
// image_path     - UTF-8 path bytes; the implementation reads exactly
//                  image_path_len bytes from this pointer
// image_path_len - number of bytes in image_path
// out_boxes      - receives a newly allocated array of VisionTextBox on success
// out_count      - receives the number of elements in *out_boxes on success
//
// Returns true on success, false on failure. On failure the out parameters
// are not written.
// Caller must free the returned boxes using vision_free_boxes
bool vision_recognize_text(
const char* image_path,
uint32_t image_path_len,
VisionTextBox** out_boxes,
uint32_t* out_count
);
// Free memory allocated by vision_recognize_text: every non-NULL text string
// in the first `count` boxes, then the array itself. `count` must be the value
// returned through out_count.
// NOTE(review): the Swift implementation declares `boxes` as non-optional, so
// callers should presumably never pass NULL - confirm.
void vision_free_boxes(VisionTextBox* boxes, uint32_t count);
#ifdef __cplusplus
}
#endif
#endif /* VisionBridge_h */

View File

@@ -0,0 +1,145 @@
import Foundation
import Vision
import AppKit
import CoreGraphics
// MARK: - C Bridge Functions
/// C entry point: runs Vision text recognition on the image at `imagePath`.
///
/// - Parameters:
///   - imagePath: Pointer to the UTF-8 bytes of the image path.
///   - imagePathLen: Number of bytes to read from `imagePath`.
///   - outBoxes: On success, receives a newly allocated array of `CTextBox`.
///   - outCount: On success, receives the number of elements in that array.
/// - Returns: `true` on success. Returns `false` if the path is not valid
///   UTF-8, the image cannot be loaded, or the Vision request fails; in that
///   case `outBoxes` and `outCount` are not written.
///
/// The returned array and the `strdup`-allocated strings inside it must be
/// released with `vision_free_boxes`.
@_cdecl("vision_recognize_text")
public func vision_recognize_text(
_ imagePath: UnsafePointer<CChar>,
_ imagePathLen: UInt32,
_ outBoxes: UnsafeMutablePointer<UnsafeMutableRawPointer?>,
_ outCount: UnsafeMutablePointer<UInt32>
) -> Bool {
// Convert C string to Swift String
// Exactly imagePathLen bytes are decoded; no NUL terminator is required.
guard let pathData = Data(bytes: imagePath, count: Int(imagePathLen)).withUnsafeBytes({
String(bytes: $0, encoding: .utf8)
}) else {
return false
}
// Leading/trailing whitespace is stripped from the path before use.
let path = pathData.trimmingCharacters(in: .whitespaces)
// Load image
guard let image = NSImage(contentsOfFile: path),
let cgImage = image.cgImage(forProposedRect: nil, context: nil, hints: nil) else {
return false
}
// Perform OCR
// State shared with the request's completion handler below.
var textBoxes: [CTextBox] = []
let semaphore = DispatchSemaphore(value: 0)
var success = false
let request = VNRecognizeTextRequest { request, error in
// Signal on every exit path so the wait() below cannot block forever
// once the handler has run.
defer { semaphore.signal() }
if let error = error {
print("Vision OCR error: \(error.localizedDescription)")
return
}
guard let observations = request.results as? [VNRecognizedTextObservation] else {
return
}
let imageSize = CGSize(width: cgImage.width, height: cgImage.height)
for observation in observations {
// Only the single best candidate string is kept per observation.
guard let candidate = observation.topCandidates(1).first else { continue }
let text = candidate.string
let boundingBox = observation.boundingBox
// Convert normalized coordinates (bottom-left origin) to pixel coordinates (top-left origin)
let x = Int32(boundingBox.origin.x * imageSize.width)
let y = Int32((1.0 - boundingBox.origin.y - boundingBox.height) * imageSize.height)
let width = Int32(boundingBox.width * imageSize.width)
let height = Int32(boundingBox.height * imageSize.height)
// Allocate C string for text
// Ownership passes to the caller; freed with free() in vision_free_boxes.
let cString = strdup(text)
textBoxes.append(CTextBox(
text: cString,
text_len: UInt32(text.utf8.count),
x: x,
y: y,
width: width,
height: height,
confidence: observation.confidence
))
}
success = true
}
// Configure request for best accuracy
request.recognitionLevel = .accurate
request.usesLanguageCorrection = true
// Recognition is restricted to US English.
request.recognitionLanguages = ["en-US"]
// Perform request
let handler = VNImageRequestHandler(cgImage: cgImage, options: [:])
do {
try handler.perform([request])
} catch {
print("Vision request failed: \(error.localizedDescription)")
return false
}
// Wait for completion
// NOTE(review): perform(_:) looks synchronous, in which case the completion
// handler has already run and this wait returns immediately - confirm.
semaphore.wait()
if !success {
return false
}
// Allocate array for results
// Deallocated by vision_free_boxes.
let boxesPtr = UnsafeMutablePointer<CTextBox>.allocate(capacity: textBoxes.count)
for (index, box) in textBoxes.enumerated() {
boxesPtr[index] = box
}
outBoxes.pointee = UnsafeMutableRawPointer(boxesPtr)
outCount.pointee = UInt32(textBoxes.count)
return true
}
/// C entry point: releases an array produced by `vision_recognize_text`.
///
/// Frees the `strdup`-allocated text of each of the first `count` boxes
/// (skipping null pointers) and then deallocates the array itself.
///
/// - Parameters:
///   - boxes: Array pointer previously returned through `outBoxes`.
///   - count: Element count previously returned through `outCount`.
@_cdecl("vision_free_boxes")
public func vision_free_boxes(
    _ boxes: UnsafeMutableRawPointer,
    _ count: UInt32
) {
    let boxArray = boxes.assumingMemoryBound(to: CTextBox.self)
    let entries = UnsafeBufferPointer(start: boxArray, count: Int(count))

    for entry in entries {
        guard let text = entry.text else { continue }
        free(UnsafeMutableRawPointer(mutating: text))
    }

    boxArray.deallocate()
}
// MARK: - C-Compatible Structure
/// Swift-side counterpart of the C `VisionTextBox` struct in `VisionBridge.h`.
///
/// Instances are written into a raw array that C/Rust callers read as
/// `VisionTextBox`, so the stored properties must keep the same order and
/// types as the header.
/// NOTE(review): this assumes Swift lays the struct out exactly like the C
/// definition; importing the C struct from the header would make that
/// explicit - confirm.
public struct CTextBox {
/// `strdup`-allocated, NUL-terminated text; nil if allocation failed.
public let text: UnsafePointer<CChar>?
/// UTF-8 byte length of the text.
public let text_len: UInt32
/// Left edge in image pixels.
public let x: Int32
/// Top edge in image pixels (top-left origin).
public let y: Int32
/// Box width in image pixels.
public let width: Int32
/// Box height in image pixels.
public let height: Int32
/// Confidence of the Vision observation this box came from.
public let confidence: Float
/// Memberwise initializer; stores every argument unchanged.
public init(text: UnsafePointer<CChar>?, text_len: UInt32, x: Int32, y: Int32, width: Int32, height: Int32, confidence: Float) {
self.text = text
self.text_len = text_len
self.x = x
self.y = y
self.width = width
self.height = height
self.confidence = confidence
}
}