Enhance read_image tool with magic byte detection and multi-image support
- Fix media type detection using magic bytes instead of file extension - Correctly identifies JPEG files with .png extension (and vice versa) - Supports PNG, JPEG, GIF, and WebP formats - Add multi-image support with file_paths array parameter - Load multiple images in a single tool call - All images queued for LLM analysis - Enhanced CLI output: - Inline image preview via iTerm2 imgcat protocol (height=5) - Dimmed info line showing: path | dimensions | media type | file size - Proper │ prefix alignment with tool output boxing - Human-readable file sizes (bytes, KB, MB) - Add image dimension extraction from file headers - PNG, JPEG, GIF, WebP dimension parsing - Add comprehensive tests for magic byte detection and dimensions
This commit is contained in:
@@ -44,6 +44,7 @@ streaming-iterator = "0.1"
|
||||
walkdir = "2.4"
|
||||
|
||||
const_format = "0.2"
|
||||
base64 = "0.22.1"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.8"
|
||||
|
||||
@@ -867,6 +867,8 @@ pub struct Agent<W: UiWriter> {
|
||||
/// Working directory for tool execution (set by --codebase-fast-start)
|
||||
working_dir: Option<String>,
|
||||
background_process_manager: std::sync::Arc<background_process::BackgroundProcessManager>,
|
||||
/// Pending images to attach to the next user message
|
||||
pending_images: Vec<g3_providers::ImageContent>,
|
||||
}
|
||||
|
||||
impl<W: UiWriter> Agent<W> {
|
||||
@@ -1167,6 +1169,7 @@ impl<W: UiWriter> Agent<W> {
|
||||
background_process::BackgroundProcessManager::new(
|
||||
paths::get_logs_dir().join("background_processes")
|
||||
)),
|
||||
pending_images: Vec::new(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1774,7 +1777,7 @@ impl<W: UiWriter> Agent<W> {
|
||||
}
|
||||
|
||||
// Add user message to context window
|
||||
let user_message = {
|
||||
let mut user_message = {
|
||||
// Check if we should use cache control (every 10 tool calls)
|
||||
// But only if we haven't already added 4 cache_control annotations
|
||||
let provider = self.providers.get(None)?;
|
||||
@@ -1802,6 +1805,12 @@ impl<W: UiWriter> Agent<W> {
|
||||
Message::new(MessageRole::User, format!("Task: {}", description))
|
||||
}
|
||||
};
|
||||
|
||||
// Attach any pending images to this user message
|
||||
if !self.pending_images.is_empty() {
|
||||
user_message.images = std::mem::take(&mut self.pending_images);
|
||||
}
|
||||
|
||||
self.context_window.add_message(user_message);
|
||||
|
||||
// Execute fast-discovery tool calls if provided (immediately after user message)
|
||||
@@ -2721,6 +2730,7 @@ impl<W: UiWriter> Agent<W> {
|
||||
self.context_window.add_message(Message {
|
||||
role,
|
||||
id: String::new(),
|
||||
images: Vec::new(),
|
||||
content: content.to_string(),
|
||||
cache_control: None,
|
||||
});
|
||||
@@ -2746,6 +2756,7 @@ impl<W: UiWriter> Agent<W> {
|
||||
self.context_window.add_message(Message {
|
||||
role: MessageRole::User,
|
||||
id: String::new(),
|
||||
images: Vec::new(),
|
||||
content: format!("[Session Resumed]\n\n{}", context_msg),
|
||||
cache_control: None,
|
||||
});
|
||||
@@ -2829,6 +2840,21 @@ impl<W: UiWriter> Agent<W> {
|
||||
"required": ["file_path"]
|
||||
}),
|
||||
},
|
||||
Tool {
|
||||
name: "read_image".to_string(),
|
||||
description: "Read one or more image files and send them to the LLM for visual analysis. Supports PNG, JPEG, GIF, and WebP formats. Use this when you need to visually inspect images (e.g., find sprites, analyze UI, read diagrams). The images will be included in your next response for analysis.".to_string(),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_paths": {
|
||||
"type": "array",
|
||||
"items": { "type": "string" },
|
||||
"description": "Array of paths to image files to read"
|
||||
}
|
||||
},
|
||||
"required": ["file_paths"]
|
||||
}),
|
||||
},
|
||||
Tool {
|
||||
name: "write_file".to_string(),
|
||||
description: "Write content to a file (creates or overwrites). You MUST provide all arguments".to_string(),
|
||||
@@ -4041,7 +4067,7 @@ impl<W: UiWriter> Agent<W> {
|
||||
),
|
||||
)
|
||||
};
|
||||
let result_message = {
|
||||
let mut result_message = {
|
||||
// Check if we should use cache control (every 10 tool calls)
|
||||
// But only if we haven't already added 4 cache_control annotations
|
||||
if self.tool_call_count > 0
|
||||
@@ -4083,6 +4109,12 @@ impl<W: UiWriter> Agent<W> {
|
||||
}
|
||||
};
|
||||
|
||||
// Attach any pending images to the result message
|
||||
// (images loaded via read_image tool)
|
||||
if !self.pending_images.is_empty() {
|
||||
result_message.images = std::mem::take(&mut self.pending_images);
|
||||
}
|
||||
|
||||
// Track tokens before adding messages
|
||||
let tokens_before = self.context_window.used_tokens;
|
||||
|
||||
@@ -4979,6 +5011,108 @@ impl<W: UiWriter> Agent<W> {
|
||||
Ok("❌ Missing file_path argument".to_string())
|
||||
}
|
||||
}
|
||||
"read_image" => {
|
||||
debug!("Processing read_image tool call");
|
||||
|
||||
// Get paths from file_paths array
|
||||
let mut paths: Vec<String> = Vec::new();
|
||||
|
||||
if let Some(file_paths) = tool_call.args.get("file_paths") {
|
||||
if let Some(arr) = file_paths.as_array() {
|
||||
for p in arr {
|
||||
if let Some(s) = p.as_str() {
|
||||
paths.push(s.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if paths.is_empty() {
|
||||
return Ok("❌ Missing or empty file_paths argument".to_string());
|
||||
}
|
||||
|
||||
let mut results: Vec<String> = Vec::new();
|
||||
let mut success_count = 0;
|
||||
|
||||
for path_str in &paths {
|
||||
// Expand tilde (~) to home directory
|
||||
let expanded_path = shellexpand::tilde(path_str);
|
||||
let path = std::path::Path::new(expanded_path.as_ref());
|
||||
|
||||
// Check file exists
|
||||
if !path.exists() {
|
||||
results.push(format!("❌ Image file not found: {}", path_str));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Read the file first, then detect format from magic bytes
|
||||
match std::fs::read(path) {
|
||||
Ok(bytes) => {
|
||||
// Detect media type from magic bytes (file signature)
|
||||
let media_type = match g3_providers::ImageContent::media_type_from_bytes(&bytes) {
|
||||
Some(mt) => mt,
|
||||
None => {
|
||||
// Fall back to extension-based detection
|
||||
let ext = path.extension()
|
||||
.and_then(|e| e.to_str())
|
||||
.unwrap_or("");
|
||||
match g3_providers::ImageContent::media_type_from_extension(ext) {
|
||||
Some(mt) => mt,
|
||||
None => {
|
||||
results.push(format!(
|
||||
"❌ {}: Unsupported or unrecognized image format",
|
||||
path_str
|
||||
));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let file_size = bytes.len();
|
||||
|
||||
// Try to get image dimensions
|
||||
let dimensions = Self::get_image_dimensions(&bytes, media_type);
|
||||
|
||||
// Build info string
|
||||
let dim_str = dimensions
|
||||
.map(|(w, h)| format!("{}x{}", w, h))
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
let size_str = if file_size >= 1024 * 1024 {
|
||||
format!("{:.1} MB", file_size as f64 / (1024.0 * 1024.0))
|
||||
} else if file_size >= 1024 {
|
||||
format!("{:.1} KB", file_size as f64 / 1024.0)
|
||||
} else {
|
||||
format!("{} bytes", file_size)
|
||||
};
|
||||
|
||||
// Output imgcat inline image to terminal (height constrained)
|
||||
// followed by info line
|
||||
Self::print_imgcat(&bytes, path_str, &dim_str, media_type, &size_str, 5);
|
||||
|
||||
// Store the image to be attached to the next user message
|
||||
use base64::Engine;
|
||||
let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes);
|
||||
let image = g3_providers::ImageContent::new(media_type, encoded);
|
||||
self.pending_images.push(image);
|
||||
|
||||
success_count += 1;
|
||||
}
|
||||
Err(e) => {
|
||||
results.push(format!("❌ Failed to read '{}': {}", path_str, e));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let summary = if success_count == paths.len() {
|
||||
format!("\n{} image(s) read.", success_count)
|
||||
} else {
|
||||
format!("\n{}/{} image(s) read.", success_count, paths.len())
|
||||
};
|
||||
|
||||
Ok(format!("{}\n{}", results.join("\n"), summary))
|
||||
}
|
||||
"write_file" => {
|
||||
debug!("Processing write_file tool call");
|
||||
debug!("Raw tool_call.args: {:?}", tool_call.args);
|
||||
@@ -6556,6 +6690,87 @@ impl<W: UiWriter> Agent<W> {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Extract pixel dimensions as `(width, height)` from an image's header bytes.
///
/// `media_type` selects the parser: "image/png", "image/jpeg", "image/gif" or
/// "image/webp". Returns `None` for any other media type, or when the data is
/// too short or lacks the expected header structure. Every read is
/// bounds-checked, so truncated or corrupt input never panics.
fn get_image_dimensions(bytes: &[u8], media_type: &str) -> Option<(u32, u32)> {
    // Bounds-checked fixed-width readers; `None` when the read would pass the end.
    let be16 = |at: usize| -> Option<u32> {
        bytes
            .get(at..at + 2)
            .map(|b| u32::from(u16::from_be_bytes([b[0], b[1]])))
    };
    let le16 = |at: usize| -> Option<u32> {
        bytes
            .get(at..at + 2)
            .map(|b| u32::from(u16::from_le_bytes([b[0], b[1]])))
    };
    let be32 = |at: usize| -> Option<u32> {
        bytes
            .get(at..at + 4)
            .map(|b| u32::from_be_bytes([b[0], b[1], b[2], b[3]]))
    };

    match media_type {
        "image/png" => {
            // Layout: signature (8) + chunk length (4) + chunk type (4), then
            // big-endian width and height. Only trust the fields when the
            // first chunk really is IHDR (the media type may have come from
            // the file extension rather than the magic bytes).
            if bytes.get(12..16)? != b"IHDR" {
                return None;
            }
            Some((be32(16)?, be32(20)?))
        }
        "image/jpeg" => {
            // Walk the marker segments after SOI (FF D8) until a frame header.
            let mut i = 2;
            while i + 4 <= bytes.len() {
                if bytes[i] != 0xFF {
                    i += 1;
                    continue;
                }
                let marker = bytes[i + 1];
                match marker {
                    // Fill byte: the real marker code follows the last FF.
                    0xFF => i += 1,
                    // Standalone markers (TEM, RSTn, SOI, EOI) carry no length.
                    0x01 | 0xD0..=0xD9 => i += 2,
                    // Any SOFn frame header. C4 (DHT), C8 (JPG) and CC (DAC)
                    // share the numeric range but are not frame headers.
                    0xC0..=0xCF if !matches!(marker, 0xC4 | 0xC8 | 0xCC) => {
                        // Segment: length (2), precision (1), height (2), width (2).
                        let height = be16(i + 5)?;
                        let width = be16(i + 7)?;
                        return Some((width, height));
                    }
                    // Every other segment starts with a 2-byte length that
                    // includes the length field itself.
                    _ => {
                        let len = be16(i + 2)? as usize;
                        i += 2 + len;
                    }
                }
            }
            None
        }
        "image/gif" => {
            // Logical screen descriptor: little-endian width at 6, height at 8.
            Some((le16(6)?, le16(8)?))
        }
        "image/webp" => {
            // Layout: "RIFF" (4) + size (4) + "WEBP" (4) + chunk FourCC (4) +
            // chunk size (4), then the chunk payload at offset 20.
            let fourcc = bytes.get(12..16)?;
            if fourcc == b"VP8 " {
                // Lossy: 3-byte frame tag, 3-byte start code, then 14-bit
                // width and height (upper 2 bits of each are scale factors).
                Some((le16(26)? & 0x3FFF, le16(28)? & 0x3FFF))
            } else if fourcc == b"VP8L" {
                // Lossless: signature byte 0x2F, then width-1 and height-1
                // packed as consecutive 14-bit little-endian fields.
                let payload = bytes.get(20..25)?;
                if payload[0] != 0x2F {
                    return None;
                }
                let bits = u32::from_le_bytes([payload[1], payload[2], payload[3], payload[4]]);
                Some(((bits & 0x3FFF) + 1, ((bits >> 14) & 0x3FFF) + 1))
            } else if fourcc == b"VP8X" {
                // Extended: flags (1) + reserved (3), then 24-bit canvas
                // width-1 and height-1, both little-endian.
                let canvas = bytes.get(24..30)?;
                let width = u32::from_le_bytes([canvas[0], canvas[1], canvas[2], 0]) + 1;
                let height = u32::from_le_bytes([canvas[3], canvas[4], canvas[5], 0]) + 1;
                Some((width, height))
            } else {
                None
            }
        }
        _ => None,
    }
}
|
||||
|
||||
/// Print image using iTerm2 imgcat protocol
|
||||
/// Print image using iTerm2 imgcat protocol with info line
|
||||
fn print_imgcat(bytes: &[u8], name: &str, dimensions: &str, media_type: &str, size: &str, max_height: u32) {
|
||||
use base64::Engine;
|
||||
let encoded = base64::engine::general_purpose::STANDARD.encode(bytes);
|
||||
// Print 3 lines of │ prefix before image for visual alignment
|
||||
println!("│");
|
||||
println!("│");
|
||||
// iTerm2 inline image protocol
|
||||
print!("│ \x1b]1337;File=inline=1;height={};name={}:{}\x07", max_height, name, encoded);
|
||||
// Print dimmed info line (no checkmark)
|
||||
println!("│ \x1b[2m{} | {} | {} | {}\x1b[0m", name, dimensions, media_type, size);
|
||||
// Blank line before next image
|
||||
println!("│");
|
||||
}
|
||||
|
||||
fn format_duration(duration: Duration) -> String {
|
||||
let total_ms = duration.as_millis();
|
||||
|
||||
|
||||
@@ -259,6 +259,10 @@ Short description for providers without native calling specs:
|
||||
- Example: {\"tool\": \"read_file\", \"args\": {\"file_path\": \"src/main.rs\"}
|
||||
- Example (partial): {\"tool\": \"read_file\", \"args\": {\"file_path\": \"large.log\", \"start\": 0, \"end\": 1000}
|
||||
|
||||
- **read_image**: Read one or more image files for visual analysis (PNG, JPEG, GIF, WebP)
|
||||
- Format: {\"tool\": \"read_image\", \"args\": {\"file_paths\": [\"path/to/image.png\"]}}
|
||||
- Example: {\"tool\": \"read_image\", \"args\": {\"file_paths\": [\"sprites/fairy.png\", \"sprites/elf.png\"]}}
|
||||
|
||||
- **write_file**: Write content to a file (creates or overwrites)
|
||||
- Format: {\"tool\": \"write_file\", \"args\": {\"file_path\": \"path/to/file\", \"content\": \"file content\"}
|
||||
- Example: {\"tool\": \"write_file\", \"args\": {\"file_path\": \"src/lib.rs\", \"content\": \"pub fn hello() {}\"}
|
||||
|
||||
201
crates/g3-core/tests/read_image_test.rs
Normal file
201
crates/g3-core/tests/read_image_test.rs
Normal file
@@ -0,0 +1,201 @@
|
||||
use g3_providers::ImageContent;
|
||||
use std::fs;
|
||||
|
||||
#[test]
fn test_image_content_media_type_detection() {
    // (extension, expected media type); matching is case-insensitive and
    // anything outside PNG/JPEG/GIF/WebP is unsupported.
    let cases: [(&str, Option<&str>); 9] = [
        ("png", Some("image/png")),
        ("PNG", Some("image/png")),
        ("jpg", Some("image/jpeg")),
        ("jpeg", Some("image/jpeg")),
        ("JPEG", Some("image/jpeg")),
        ("gif", Some("image/gif")),
        ("webp", Some("image/webp")),
        ("bmp", None),
        ("txt", None),
    ];

    for &(extension, expected) in cases.iter() {
        assert_eq!(ImageContent::media_type_from_extension(extension), expected);
    }
}
|
||||
|
||||
#[test]
fn test_image_content_creation() {
    // The constructor must store both arguments verbatim.
    let (media_type, payload) = ("image/png", "base64data");
    let image = ImageContent::new(media_type, payload.to_string());

    assert_eq!(image.media_type, media_type);
    assert_eq!(image.data, payload);
}
|
||||
|
||||
#[test]
fn test_read_and_encode_image() {
    // Round-trip a PNG through disk, base64, and `ImageContent`.
    //
    // The scratch directory name includes the process id so that concurrent
    // test processes sharing the system temp dir cannot delete each other's
    // fixture with the `remove_dir_all` below.
    let test_dir = std::env::temp_dir().join(format!("g3_read_image_test_{}", std::process::id()));
    let _ = fs::remove_dir_all(&test_dir);
    fs::create_dir_all(&test_dir).unwrap();

    // Minimal 1x1 red PNG (hand-crafted)
    let png_bytes: Vec<u8> = vec![
        0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
        0x00, 0x00, 0x00, 0x0D, // IHDR length
        0x49, 0x48, 0x44, 0x52, // IHDR
        0x00, 0x00, 0x00, 0x01, // width = 1
        0x00, 0x00, 0x00, 0x01, // height = 1
        0x08, 0x02, 0x00, 0x00, 0x00, // bit depth, color type, etc.
        0x90, 0x77, 0x53, 0xDE, // CRC
        0x00, 0x00, 0x00, 0x0C, // IDAT length
        0x49, 0x44, 0x41, 0x54, // IDAT
        0x08, 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00, // compressed data
        0x01, 0x01, 0x01, 0x00, // CRC (approximate)
        0x00, 0x00, 0x00, 0x00, // IEND length
        0x49, 0x45, 0x4E, 0x44, // IEND
        0xAE, 0x42, 0x60, 0x82, // CRC
    ];

    let image_path = test_dir.join("test.png");
    fs::write(&image_path, &png_bytes).unwrap();

    // Read and encode
    let bytes = fs::read(&image_path).unwrap();
    use base64::Engine;
    let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes);

    // Verify it's valid base64
    assert!(!encoded.is_empty());
    assert!(encoded.len() > 10);

    // Verify we can decode it back
    let decoded = base64::engine::general_purpose::STANDARD.decode(&encoded).unwrap();
    assert_eq!(decoded, bytes);

    // Create ImageContent
    let ext = image_path.extension().unwrap().to_str().unwrap();
    let media_type = ImageContent::media_type_from_extension(ext).unwrap();
    let image = ImageContent::new(media_type, encoded);

    assert_eq!(image.media_type, "image/png");
    assert!(!image.data.is_empty());

    // Cleanup
    let _ = fs::remove_dir_all(&test_dir);
}
|
||||
|
||||
#[test]
fn test_media_type_from_bytes_png() {
    // Eight-byte PNG signature followed by the start of the IHDR chunk.
    let mut header = b"\x89PNG\r\n\x1a\n".to_vec();
    header.extend_from_slice(&[0x00, 0x00, 0x00, 0x0D]); // IHDR length
    header.extend_from_slice(b"IHDR"); // chunk type

    let detected = ImageContent::media_type_from_bytes(&header);
    assert_eq!(detected, Some("image/png"));
}
|
||||
|
||||
#[test]
fn test_media_type_from_bytes_jpeg() {
    // SOI + APP0 marker + segment length, then the JFIF identifier and the
    // first few header fields. Detection only needs the leading FF D8 FF.
    let mut header: Vec<u8> = vec![0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10];
    header.extend_from_slice(b"JFIF\0");
    header.extend_from_slice(&[0x01, 0x01, 0x00, 0x00, 0x01]);

    let detected = ImageContent::media_type_from_bytes(&header);
    assert_eq!(detected, Some("image/jpeg"));
}
|
||||
|
||||
#[test]
fn test_media_type_from_bytes_gif() {
    // Both GIF versions share the "GIF8" prefix and must be detected.
    for signature in ["GIF89a", "GIF87a"].iter() {
        let mut header = signature.as_bytes().to_vec();
        // 1x1 logical screen, then zero padding up to 16 bytes.
        header.extend_from_slice(&[0x01, 0x00, 0x01, 0x00]);
        header.extend_from_slice(&[0x00; 6]);

        assert_eq!(ImageContent::media_type_from_bytes(&header), Some("image/gif"));
    }
}
|
||||
|
||||
#[test]
fn test_media_type_from_bytes_webp() {
    // RIFF container: "RIFF", 4-byte size (placeholder), "WEBP", then the
    // first chunk's FourCC.
    let mut header = b"RIFF".to_vec();
    header.extend_from_slice(&[0x00; 4]);
    header.extend_from_slice(b"WEBP");
    header.extend_from_slice(b"VP8 ");

    let detected = ImageContent::media_type_from_bytes(&header);
    assert_eq!(detected, Some("image/webp"));
}
|
||||
|
||||
#[test]
fn test_media_type_from_bytes_unknown() {
    // Sixteen ascending bytes (0x00..=0x0F) match no supported signature.
    let unrecognised: Vec<u8> = (0u8..16).collect();

    assert_eq!(ImageContent::media_type_from_bytes(&unrecognised), None);
}
|
||||
|
||||
#[test]
fn test_media_type_from_bytes_too_short() {
    // A truncated PNG signature is not enough to identify the format.
    let truncated: &[u8] = &[0x89, 0x50, 0x4E];
    assert_eq!(ImageContent::media_type_from_bytes(truncated), None);

    // Neither is an empty buffer.
    let nothing: &[u8] = &[];
    assert_eq!(ImageContent::media_type_from_bytes(nothing), None);
}
|
||||
|
||||
#[test]
fn test_read_image_multiple_paths_schema() {
    // Checks that a `file_paths` argument parses as a JSON array of the
    // expected length, for both the single- and multi-image forms.
    let path_count = |args: &serde_json::Value| -> usize {
        args.get("file_paths").unwrap().as_array().unwrap().len()
    };

    let single_args = serde_json::json!({
        "file_paths": ["/path/to/image.png"]
    });
    assert_eq!(path_count(&single_args), 1);

    let multi_args = serde_json::json!({
        "file_paths": ["/path/to/image1.png", "/path/to/image2.jpg"]
    });
    assert_eq!(path_count(&multi_args), 2);
}
|
||||
|
||||
#[test]
fn test_image_dimensions_png() {
    // Minimal PNG header with known dimensions (1x1)
    let png_bytes: Vec<u8> = vec![
        0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
        0x00, 0x00, 0x00, 0x0D, // IHDR length
        0x49, 0x48, 0x44, 0x52, // IHDR
        0x00, 0x00, 0x00, 0x01, // width = 1
        0x00, 0x00, 0x00, 0x01, // height = 1
        0x08, 0x02, 0x00, 0x00, 0x00, // bit depth, color type, etc.
    ];

    // Assert the precondition instead of guarding with `if`: a shortened
    // fixture must fail the test rather than silently skip the checks.
    assert!(
        png_bytes.len() >= 24,
        "fixture must contain the full IHDR width/height fields"
    );

    // PNG dimensions are at bytes 16-19 (width) and 20-23 (height), big-endian
    let width = u32::from_be_bytes([png_bytes[16], png_bytes[17], png_bytes[18], png_bytes[19]]);
    let height = u32::from_be_bytes([png_bytes[20], png_bytes[21], png_bytes[22], png_bytes[23]]);
    assert_eq!(width, 1);
    assert_eq!(height, 1);
}
|
||||
|
||||
#[test]
fn test_image_dimensions_gif() {
    // GIF89a signature followed by a 100x200 logical screen size,
    // each stored as a little-endian u16.
    let mut header = b"GIF89a".to_vec();
    header.extend_from_slice(&100u16.to_le_bytes());
    header.extend_from_slice(&200u16.to_le_bytes());

    let read_u16 = |at: usize| u32::from(u16::from_le_bytes([header[at], header[at + 1]]));

    assert_eq!(read_u16(6), 100); // width
    assert_eq!(read_u16(8), 200); // height
}
|
||||
@@ -274,15 +274,29 @@ impl AnthropicProvider {
|
||||
}
|
||||
}
|
||||
MessageRole::User => {
|
||||
// Build content blocks - images first, then text
|
||||
let mut content_blocks: Vec<AnthropicContent> = Vec::new();
|
||||
|
||||
// Add any images attached to this message
|
||||
for image in &message.images {
|
||||
content_blocks.push(AnthropicContent::Image {
|
||||
source: AnthropicImageSource {
|
||||
source_type: "base64".to_string(),
|
||||
media_type: image.media_type.clone(),
|
||||
data: image.data.clone(),
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Add text content
|
||||
content_blocks.push(AnthropicContent::Text {
|
||||
text: message.content.clone(),
|
||||
cache_control: message.cache_control.as_ref().map(Self::convert_cache_control),
|
||||
});
|
||||
|
||||
anthropic_messages.push(AnthropicMessage {
|
||||
role: "user".to_string(),
|
||||
content: vec![AnthropicContent::Text {
|
||||
text: message.content.clone(),
|
||||
cache_control: message
|
||||
.cache_control
|
||||
.as_ref()
|
||||
.map(Self::convert_cache_control),
|
||||
}],
|
||||
content: content_blocks,
|
||||
});
|
||||
}
|
||||
MessageRole::Assistant => {
|
||||
@@ -924,6 +938,19 @@ enum AnthropicContent {
|
||||
name: String,
|
||||
input: serde_json::Value,
|
||||
},
|
||||
#[serde(rename = "image")]
|
||||
Image {
|
||||
source: AnthropicImageSource,
|
||||
},
|
||||
}
|
||||
|
||||
/// Image source payload carried by an `AnthropicContent::Image` block.
///
/// Serialized with serde; `source_type` is emitted under the JSON key `type`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct AnthropicImageSource {
    /// Encoding of `data`. Always "base64".
    #[serde(rename = "type")]
    source_type: String,
    /// MIME type of the image, e.g. "image/png" or "image/jpeg".
    media_type: String,
    /// Base64-encoded image data.
    data: String,
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
||||
@@ -88,6 +88,8 @@ pub struct Message {
|
||||
pub role: MessageRole,
|
||||
pub content: String,
|
||||
#[serde(skip)]
|
||||
pub images: Vec<ImageContent>,
|
||||
#[serde(skip)]
|
||||
pub id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cache_control: Option<CacheControl>,
|
||||
@@ -101,6 +103,65 @@ pub enum MessageRole {
|
||||
Assistant,
|
||||
}
|
||||
|
||||
/// Image content for multimodal messages.
///
/// Pairs a media type with the image's base64-encoded bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageContent {
    /// Media type (e.g., "image/png", "image/jpeg", "image/gif", "image/webp")
    pub media_type: String,
    /// Base64-encoded image data
    pub data: String,
}
|
||||
|
||||
impl ImageContent {
|
||||
pub fn new(media_type: &str, data: String) -> Self {
|
||||
Self {
|
||||
media_type: media_type.to_string(),
|
||||
data,
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect media type from file extension
|
||||
pub fn media_type_from_extension(ext: &str) -> Option<&'static str> {
|
||||
match ext.to_lowercase().as_str() {
|
||||
"png" => Some("image/png"),
|
||||
"jpg" | "jpeg" => Some("image/jpeg"),
|
||||
"gif" => Some("image/gif"),
|
||||
"webp" => Some("image/webp"),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect media type from image data magic bytes (file signature)
|
||||
/// This is more reliable than file extension as it checks actual content
|
||||
pub fn media_type_from_bytes(bytes: &[u8]) -> Option<&'static str> {
|
||||
if bytes.len() < 12 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// PNG: 89 50 4E 47 0D 0A 1A 0A
|
||||
if bytes.starts_with(&[0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]) {
|
||||
return Some("image/png");
|
||||
}
|
||||
|
||||
// JPEG: FF D8 FF
|
||||
if bytes.starts_with(&[0xFF, 0xD8, 0xFF]) {
|
||||
return Some("image/jpeg");
|
||||
}
|
||||
|
||||
// GIF: 47 49 46 38 (GIF8)
|
||||
if bytes.starts_with(&[0x47, 0x49, 0x46, 0x38]) {
|
||||
return Some("image/gif");
|
||||
}
|
||||
|
||||
// WebP: 52 49 46 46 ... 57 45 42 50 (RIFF....WEBP)
|
||||
if bytes.starts_with(&[0x52, 0x49, 0x46, 0x46]) && bytes.len() >= 12 && &bytes[8..12] == b"WEBP" {
|
||||
return Some("image/webp");
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CompletionResponse {
|
||||
pub content: String,
|
||||
@@ -174,6 +235,7 @@ impl Message {
|
||||
Self {
|
||||
role,
|
||||
content,
|
||||
images: Vec::new(),
|
||||
id: Self::generate_id(),
|
||||
cache_control: None,
|
||||
}
|
||||
@@ -188,6 +250,7 @@ impl Message {
|
||||
Self {
|
||||
role,
|
||||
content,
|
||||
images: Vec::new(),
|
||||
id: Self::generate_id(),
|
||||
cache_control: Some(cache_control),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user