Compare commits
18 Commits
jochen-add
...
micn/testi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
95e5a59720 | ||
|
|
04ceefd5e2 | ||
|
|
40f9ea5eb3 | ||
|
|
69ae894de8 | ||
|
|
3643dad354 | ||
|
|
14c8d066c9 | ||
|
|
b6e226df67 | ||
|
|
5b46922047 | ||
|
|
1069664e16 | ||
|
|
725f54b99b | ||
|
|
325aab6b0e | ||
|
|
3f21bdc7b2 | ||
|
|
9bffd8b1bf | ||
|
|
bfee8040e9 | ||
|
|
a150ba6a55 | ||
|
|
296bf5a449 | ||
|
|
8d8ddbe4b9 | ||
|
|
0466405d87 |
73
.github/workflows/ci.yml
vendored
Normal file
73
.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
pull_request:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- os: ubuntu-latest
|
||||||
|
arch: x86_64
|
||||||
|
- os: ubuntu-latest
|
||||||
|
arch: aarch64
|
||||||
|
- os: macos-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install Rust
|
||||||
|
uses: dtolnay/rust-toolchain@stable
|
||||||
|
|
||||||
|
- name: Set up QEMU (for aarch64 on Linux)
|
||||||
|
if: matrix.arch == 'aarch64' && runner.os == 'Linux'
|
||||||
|
uses: docker/setup-qemu-action@v3
|
||||||
|
|
||||||
|
- name: Cache cargo
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/.cargo/registry
|
||||||
|
~/.cargo/git
|
||||||
|
target
|
||||||
|
key: ${{ runner.os }}-${{ matrix.arch || 'x86_64' }}-cargo-${{ hashFiles('**/Cargo.lock') }}
|
||||||
|
|
||||||
|
- name: Install system dependencies (Ubuntu)
|
||||||
|
if: runner.os == 'Linux' && matrix.arch != 'aarch64'
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y libx11-dev libxdo-dev libxcb-shape0-dev libxcb-xfixes0-dev libxtst-dev
|
||||||
|
|
||||||
|
- name: Build and test (Linux aarch64)
|
||||||
|
if: matrix.arch == 'aarch64' && runner.os == 'Linux'
|
||||||
|
uses: uraimo/run-on-arch-action@v2
|
||||||
|
with:
|
||||||
|
arch: aarch64
|
||||||
|
distro: ubuntu22.04
|
||||||
|
install: |
|
||||||
|
apt-get update
|
||||||
|
apt-get install -y curl build-essential libx11-dev libxdo-dev libxcb-shape0-dev libxcb-xfixes0-dev libxtst-dev
|
||||||
|
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||||
|
run: |
|
||||||
|
. $HOME/.cargo/env
|
||||||
|
cargo build --workspace --exclude g3-computer-control
|
||||||
|
cargo test --workspace --exclude g3-computer-control --lib --tests
|
||||||
|
|
||||||
|
- name: Build (Linux x86_64)
|
||||||
|
if: matrix.arch != 'aarch64' && runner.os == 'Linux'
|
||||||
|
run: cargo build --workspace --exclude g3-computer-control
|
||||||
|
|
||||||
|
- name: Run tests (Linux x86_64)
|
||||||
|
if: matrix.arch != 'aarch64' && runner.os == 'Linux'
|
||||||
|
run: cargo test --workspace --exclude g3-computer-control --lib --tests
|
||||||
|
|
||||||
|
- name: Build (macOS)
|
||||||
|
if: runner.os == 'macOS'
|
||||||
|
run: cargo build --workspace
|
||||||
|
|
||||||
|
- name: Run tests (macOS)
|
||||||
|
if: runner.os == 'macOS'
|
||||||
|
run: cargo test --workspace --lib --tests
|
||||||
@@ -11,12 +11,24 @@ model = "databricks-claude-sonnet-4"
|
|||||||
max_tokens = 4096
|
max_tokens = 4096
|
||||||
temperature = 0.1
|
temperature = 0.1
|
||||||
use_oauth = true
|
use_oauth = true
|
||||||
|
# cache_config = "ephemeral" # Optional: Enable prompt caching for Claude models
|
||||||
|
# Options: "ephemeral", "5minute", "1hour"
|
||||||
|
# Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs.
|
||||||
|
# The cache control will be automatically applied to:
|
||||||
|
# - The system prompt at the start of each session
|
||||||
|
# - Assistant responses after every 10 tool calls
|
||||||
|
# - 5minute costs $3/mtok, more details below
|
||||||
|
# https://docs.claude.com/en/docs/build-with-claude/prompt-caching#pricing
|
||||||
|
|
||||||
[providers.anthropic]
|
[providers.anthropic]
|
||||||
api_key = "your-anthropic-api-key"
|
api_key = "your-anthropic-api-key"
|
||||||
model = "claude-3-haiku-20240307" # Using a faster model for player
|
model = "claude-sonnet-4-5"
|
||||||
max_tokens = 4096
|
max_tokens = 4096
|
||||||
temperature = 0.3 # Slightly higher temperature for more creative implementations
|
temperature = 0.3 # Slightly higher temperature for more creative implementations
|
||||||
|
# cache_config = "ephemeral" # Optional: Enable prompt caching
|
||||||
|
# Options: "ephemeral", "5minute", "1hour"
|
||||||
|
# Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs.
|
||||||
|
# enable_1m_context = true # optional, more expensive
|
||||||
|
|
||||||
[agent]
|
[agent]
|
||||||
fallback_default_max_tokens = 8192
|
fallback_default_max_tokens = 8192
|
||||||
|
|||||||
@@ -15,6 +15,17 @@ max_tokens = 4096 # Per-request output limit (how many tokens the model can gen
|
|||||||
temperature = 0.1
|
temperature = 0.1
|
||||||
use_oauth = true
|
use_oauth = true
|
||||||
|
|
||||||
|
[providers.anthropic]
|
||||||
|
api_key = "your-anthropic-api-key"
|
||||||
|
model = "claude-sonnet-4-5"
|
||||||
|
max_tokens = 4096
|
||||||
|
temperature = 0.3 # Slightly higher temperature for more creative implementations
|
||||||
|
# cache_config = "ephemeral" # Optional: Enable prompt caching
|
||||||
|
# Options: "ephemeral", "5minute", "1hour"
|
||||||
|
# Reduces costs and latency for repeated prompts. Uses Anthropic's prompt caching with different TTLs.
|
||||||
|
# enable_1m_context = true # optional, more expensive
|
||||||
|
|
||||||
|
|
||||||
# Multiple OpenAI-compatible providers can be configured with custom names
|
# Multiple OpenAI-compatible providers can be configured with custom names
|
||||||
# Each provider gets its own section under [providers.openai_compatible.<name>]
|
# Each provider gets its own section under [providers.openai_compatible.<name>]
|
||||||
# [providers.openai_compatible.openrouter]
|
# [providers.openai_compatible.openrouter]
|
||||||
|
|||||||
@@ -1686,6 +1686,9 @@ async fn run_autonomous(
|
|||||||
turn, max_turns
|
turn, max_turns
|
||||||
));
|
));
|
||||||
|
|
||||||
|
// Surface provider info for player agent
|
||||||
|
agent.print_provider_banner("Player");
|
||||||
|
|
||||||
// Player mode: implement requirements (with coach feedback if available)
|
// Player mode: implement requirements (with coach feedback if available)
|
||||||
let player_prompt = if coach_feedback.is_empty() {
|
let player_prompt = if coach_feedback.is_empty() {
|
||||||
format!(
|
format!(
|
||||||
@@ -1879,6 +1882,9 @@ async fn run_autonomous(
|
|||||||
let mut coach_agent =
|
let mut coach_agent =
|
||||||
Agent::new_autonomous_with_readme_and_quiet(coach_config, ui_writer, None, quiet).await?;
|
Agent::new_autonomous_with_readme_and_quiet(coach_config, ui_writer, None, quiet).await?;
|
||||||
|
|
||||||
|
// Surface provider info for coach agent
|
||||||
|
coach_agent.print_provider_banner("Coach");
|
||||||
|
|
||||||
// Ensure coach agent is also in the workspace directory
|
// Ensure coach agent is also in the workspace directory
|
||||||
project.enter_workspace()?;
|
project.enter_workspace()?;
|
||||||
|
|
||||||
|
|||||||
@@ -40,6 +40,8 @@ pub struct AnthropicConfig {
|
|||||||
pub model: String,
|
pub model: String,
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
pub temperature: Option<f32>,
|
pub temperature: Option<f32>,
|
||||||
|
pub cache_config: Option<String>, // "ephemeral", "5minute", "1hour", or None to disable
|
||||||
|
pub enable_1m_context: Option<bool>, // Enable 1m context window (costs extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|||||||
256
crates/g3-console/src/logs.rs
Normal file
256
crates/g3-console/src/logs.rs
Normal file
@@ -0,0 +1,256 @@
|
|||||||
|
use crate::models::{InstanceStats, TurnInfo};
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::fs;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct LogEntry {
|
||||||
|
pub timestamp: Option<DateTime<Utc>>,
|
||||||
|
pub role: Option<String>,
|
||||||
|
pub content: Option<String>,
|
||||||
|
pub tool_calls: Option<Vec<Value>>,
|
||||||
|
pub raw: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ChatMessage {
|
||||||
|
pub role: String,
|
||||||
|
pub content: String,
|
||||||
|
pub timestamp: Option<DateTime<Utc>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ToolCall {
|
||||||
|
pub name: String,
|
||||||
|
pub parameters: Value,
|
||||||
|
pub result: Option<String>,
|
||||||
|
pub timestamp: Option<DateTime<Utc>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stateless namespace for reading session logs from a workspace and
/// extracting chat messages and tool calls from the parsed entries.
pub struct LogParser;
|
||||||
|
|
||||||
|
impl LogParser {
|
||||||
|
/// Parse logs from a workspace directory
|
||||||
|
pub fn parse_logs(workspace: &Path) -> Result<Vec<LogEntry>> {
|
||||||
|
let logs_dir = workspace.join("logs");
|
||||||
|
|
||||||
|
if !logs_dir.exists() {
|
||||||
|
return Ok(Vec::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut entries = Vec::new();
|
||||||
|
|
||||||
|
// Read all JSON log files
|
||||||
|
for entry in fs::read_dir(&logs_dir).context("Failed to read logs directory")? {
|
||||||
|
let entry = entry?;
|
||||||
|
let path = entry.path();
|
||||||
|
|
||||||
|
if path.extension().and_then(|s| s.to_str()) == Some("json") {
|
||||||
|
if let Ok(content) = fs::read_to_string(&path) {
|
||||||
|
if let Ok(json) = serde_json::from_str::<Value>(&content) {
|
||||||
|
// Try to parse as a log session
|
||||||
|
if let Some(messages) = json.get("messages").and_then(|m| m.as_array()) {
|
||||||
|
for msg in messages {
|
||||||
|
entries.push(LogEntry {
|
||||||
|
timestamp: msg.get("timestamp")
|
||||||
|
.and_then(|t| t.as_str())
|
||||||
|
.and_then(|s| DateTime::parse_from_rfc3339(s).ok())
|
||||||
|
.map(|dt| dt.with_timezone(&Utc)),
|
||||||
|
role: msg.get("role")
|
||||||
|
.and_then(|r| r.as_str())
|
||||||
|
.map(String::from),
|
||||||
|
content: msg.get("content")
|
||||||
|
.and_then(|c| c.as_str())
|
||||||
|
.map(String::from),
|
||||||
|
tool_calls: msg.get("tool_calls")
|
||||||
|
.and_then(|tc| tc.as_array())
|
||||||
|
.map(|arr| arr.clone()),
|
||||||
|
raw: msg.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by timestamp
|
||||||
|
entries.sort_by(|a, b| {
|
||||||
|
match (&a.timestamp, &b.timestamp) {
|
||||||
|
(Some(t1), Some(t2)) => t1.cmp(t2),
|
||||||
|
(Some(_), None) => std::cmp::Ordering::Less,
|
||||||
|
(None, Some(_)) => std::cmp::Ordering::Greater,
|
||||||
|
(None, None) => std::cmp::Ordering::Equal,
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(entries)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract chat messages from log entries
|
||||||
|
pub fn extract_chat_messages(entries: &[LogEntry]) -> Vec<ChatMessage> {
|
||||||
|
entries
|
||||||
|
.iter()
|
||||||
|
.filter_map(|entry| {
|
||||||
|
let role = entry.role.clone()?;
|
||||||
|
let content = entry.content.clone()?;
|
||||||
|
|
||||||
|
Some(ChatMessage {
|
||||||
|
role,
|
||||||
|
content,
|
||||||
|
timestamp: entry.timestamp,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract tool calls from log entries
|
||||||
|
pub fn extract_tool_calls(entries: &[LogEntry]) -> Vec<ToolCall> {
|
||||||
|
let mut tool_calls = Vec::new();
|
||||||
|
|
||||||
|
for entry in entries {
|
||||||
|
if let Some(calls) = &entry.tool_calls {
|
||||||
|
for call in calls {
|
||||||
|
if let Some(name) = call.get("name").and_then(|n| n.as_str()) {
|
||||||
|
tool_calls.push(ToolCall {
|
||||||
|
name: name.to_string(),
|
||||||
|
parameters: call.get("parameters")
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or(Value::Object(serde_json::Map::new())),
|
||||||
|
result: call.get("result")
|
||||||
|
.and_then(|r| r.as_str())
|
||||||
|
.map(String::from),
|
||||||
|
timestamp: entry.timestamp,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tool_calls
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stateless namespace for deriving summary statistics (token totals, tool
/// call and error counts, duration, turns) from parsed log entries.
pub struct StatsAggregator;
|
||||||
|
|
||||||
|
impl StatsAggregator {
|
||||||
|
/// Aggregate statistics from log entries
|
||||||
|
pub fn aggregate_stats(
|
||||||
|
entries: &[LogEntry],
|
||||||
|
start_time: DateTime<Utc>,
|
||||||
|
is_ensemble: bool,
|
||||||
|
) -> InstanceStats {
|
||||||
|
let total_tokens = Self::count_tokens(entries);
|
||||||
|
let tool_calls = Self::count_tool_calls(entries);
|
||||||
|
let errors = Self::count_errors(entries);
|
||||||
|
|
||||||
|
let duration_secs = if let Some(last_entry) = entries.last() {
|
||||||
|
if let Some(last_time) = last_entry.timestamp {
|
||||||
|
(last_time - start_time).num_seconds().max(0) as u64
|
||||||
|
} else {
|
||||||
|
(Utc::now() - start_time).num_seconds().max(0) as u64
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
(Utc::now() - start_time).num_seconds().max(0) as u64
|
||||||
|
};
|
||||||
|
|
||||||
|
let turns = if is_ensemble {
|
||||||
|
Some(Self::extract_turns(entries))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
InstanceStats {
|
||||||
|
total_tokens,
|
||||||
|
tool_calls,
|
||||||
|
errors,
|
||||||
|
duration_secs,
|
||||||
|
turns,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the latest message content from log entries
|
||||||
|
pub fn get_latest_message(entries: &[LogEntry]) -> Option<String> {
|
||||||
|
entries
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.find(|entry| entry.role.as_deref() == Some("assistant"))
|
||||||
|
.and_then(|entry| entry.content.clone())
|
||||||
|
.or_else(|| {
|
||||||
|
entries
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.find(|entry| entry.content.is_some())
|
||||||
|
.and_then(|entry| entry.content.clone())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn count_tokens(entries: &[LogEntry]) -> u64 {
|
||||||
|
// Try to extract token counts from metadata
|
||||||
|
entries
|
||||||
|
.iter()
|
||||||
|
.filter_map(|entry| {
|
||||||
|
entry.raw.get("usage")
|
||||||
|
.and_then(|u| u.get("total_tokens"))
|
||||||
|
.and_then(|t| t.as_u64())
|
||||||
|
})
|
||||||
|
.sum()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn count_tool_calls(entries: &[LogEntry]) -> u64 {
|
||||||
|
entries
|
||||||
|
.iter()
|
||||||
|
.filter_map(|entry| entry.tool_calls.as_ref())
|
||||||
|
.map(|calls| calls.len() as u64)
|
||||||
|
.sum()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn count_errors(entries: &[LogEntry]) -> u64 {
|
||||||
|
entries
|
||||||
|
.iter()
|
||||||
|
.filter(|entry| {
|
||||||
|
entry.raw.get("error").is_some()
|
||||||
|
|| entry.content.as_ref().map(|c| c.to_lowercase().contains("error")).unwrap_or(false)
|
||||||
|
})
|
||||||
|
.count() as u64
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_turns(entries: &[LogEntry]) -> Vec<TurnInfo> {
|
||||||
|
// Simple implementation: group consecutive assistant messages as turns
|
||||||
|
let mut turns = Vec::new();
|
||||||
|
let mut current_turn_start: Option<DateTime<Utc>> = None;
|
||||||
|
let mut turn_count = 0;
|
||||||
|
|
||||||
|
for entry in entries {
|
||||||
|
if entry.role.as_deref() == Some("assistant") {
|
||||||
|
if current_turn_start.is_none() {
|
||||||
|
current_turn_start = entry.timestamp;
|
||||||
|
turn_count += 1;
|
||||||
|
}
|
||||||
|
} else if entry.role.as_deref() == Some("user") {
|
||||||
|
if let Some(start) = current_turn_start {
|
||||||
|
if let Some(end) = entry.timestamp {
|
||||||
|
let duration = (end - start).num_seconds().max(0) as u64;
|
||||||
|
turns.push(TurnInfo {
|
||||||
|
agent: format!("agent-{}", turn_count),
|
||||||
|
duration_secs: duration,
|
||||||
|
status: "completed".to_string(),
|
||||||
|
color: Self::get_turn_color(turn_count),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
current_turn_start = None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
turns
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Pick a display colour name for a turn, cycling through a fixed palette.
///
/// The palette has six entries, so turn numbers that differ by a multiple of
/// six share a colour.
fn get_turn_color(turn_number: usize) -> String {
    // A constant array avoids allocating a fresh `Vec` on every call.
    const COLORS: [&str; 6] = ["blue", "green", "purple", "orange", "pink", "teal"];
    COLORS[turn_number % COLORS.len()].to_string()
}
|
||||||
|
}
|
||||||
@@ -3,7 +3,7 @@ use anyhow::Result;
|
|||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use sysinfo::{System, Pid, Process};
|
use sysinfo::{System, Pid, Process};
|
||||||
use tracing::{debug, warn};
|
use tracing::{debug, info, warn};
|
||||||
|
|
||||||
pub struct ProcessDetector {
|
pub struct ProcessDetector {
|
||||||
system: System,
|
system: System,
|
||||||
@@ -17,7 +17,11 @@ impl ProcessDetector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn detect_instances(&mut self) -> Result<Vec<Instance>> {
|
pub fn detect_instances(&mut self) -> Result<Vec<Instance>> {
|
||||||
self.system.refresh_processes();
|
info!("Scanning for g3 processes...");
|
||||||
|
// Refresh all processes to ensure we catch newly started ones
|
||||||
|
// Using refresh_all() instead of just refresh_processes() to ensure
|
||||||
|
// we get complete information about new processes
|
||||||
|
self.system.refresh_all();
|
||||||
let mut instances = Vec::new();
|
let mut instances = Vec::new();
|
||||||
|
|
||||||
// Find all g3 processes
|
// Find all g3 processes
|
||||||
@@ -33,7 +37,7 @@ impl ProcessDetector {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
debug!("Detected {} g3 instances", instances.len());
|
info!("Detected {} g3 instances", instances.len());
|
||||||
Ok(instances)
|
Ok(instances)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -45,24 +49,27 @@ impl ProcessDetector {
|
|||||||
) -> Option<Instance> {
|
) -> Option<Instance> {
|
||||||
let cmd_str = cmd.join(" ");
|
let cmd_str = cmd.join(" ");
|
||||||
|
|
||||||
|
// Exclude g3-console itself
|
||||||
|
if cmd_str.contains("g3-console") {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
// Check if this is a g3 binary (more comprehensive check)
|
// Check if this is a g3 binary (more comprehensive check)
|
||||||
let is_g3_binary = cmd.get(0).map(|s| {
|
let is_g3_binary = cmd.get(0).map(|s| {
|
||||||
s.ends_with("g3") || s.ends_with("/g3") || s.contains("/target/release/g3") || s.contains("/target/debug/g3")
|
(s.ends_with("g3") || s.ends_with("/g3") || s.contains("/target/release/g3") || s.contains("/target/debug/g3"))
|
||||||
|
&& !s.contains("g3-") // Exclude other g3-* binaries
|
||||||
}).unwrap_or(false);
|
}).unwrap_or(false);
|
||||||
|
|
||||||
// Check if this is cargo run with g3
|
// Check if this is cargo run with g3 (not g3-console or other variants)
|
||||||
let is_cargo_run = cmd.get(0).map(|s| s.contains("cargo")).unwrap_or(false) && cmd.iter().any(|s| s == "run");
|
let is_cargo_run = cmd.get(0).map(|s| s.contains("cargo")).unwrap_or(false)
|
||||||
|
&& cmd.iter().any(|s| s == "run")
|
||||||
|
&& !cmd_str.contains("g3-console");
|
||||||
|
|
||||||
// Also check if any part of the command line contains g3-related patterns
|
// Also check if command line has g3-specific flags
|
||||||
let has_g3_pattern = cmd_str.contains("g3 ")
|
let has_g3_flags = cmd_str.contains("--workspace") || cmd_str.contains("--autonomous");
|
||||||
|| cmd_str.contains("/g3 ")
|
|
||||||
|| cmd_str.contains("g3-")
|
|
||||||
|| cmd_str.ends_with("g3")
|
|
||||||
|| cmd_str.contains("--workspace") // g3-specific flag
|
|
||||||
|| cmd_str.contains("--autonomous"); // g3-specific flag
|
|
||||||
|
|
||||||
// Accept if it's a g3 binary, cargo run with g3 patterns, or has g3-specific flags
|
// Accept if it's a g3 binary or cargo run with g3, and has typical g3 patterns
|
||||||
let is_g3_process = is_g3_binary || (is_cargo_run && has_g3_pattern) || has_g3_pattern;
|
let is_g3_process = is_g3_binary || (is_cargo_run && has_g3_flags);
|
||||||
|
|
||||||
if !is_g3_process {
|
if !is_g3_process {
|
||||||
return None;
|
return None;
|
||||||
@@ -165,7 +172,7 @@ impl ProcessDetector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_process_status(&mut self, pid: u32) -> Option<InstanceStatus> {
|
pub fn get_process_status(&mut self, pid: u32) -> Option<InstanceStatus> {
|
||||||
self.system.refresh_processes();
|
self.system.refresh_all();
|
||||||
|
|
||||||
let sysinfo_pid = Pid::from_u32(pid);
|
let sysinfo_pid = Pid::from_u32(pid);
|
||||||
if self.system.process(sysinfo_pid).is_some() {
|
if self.system.process(sysinfo_pid).is_some() {
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
<div id="app">
|
<div id="app">
|
||||||
<header class="header">
|
<header class="header">
|
||||||
<div class="header-content">
|
<div class="header-content">
|
||||||
<h1 class="header-title">G3 Console</h1>
|
<h1 class="header-title">G3 Console <span id="live-indicator" class="live-indicator" title="Scanning for processes every 3 seconds">● LIVE</span></h1>
|
||||||
<div class="header-actions">
|
<div class="header-actions">
|
||||||
<button id="new-run-btn" class="btn btn-primary">+ New Run</button>
|
<button id="new-run-btn" class="btn btn-primary">+ New Run</button>
|
||||||
<button id="theme-toggle" class="btn btn-secondary">🌙</button>
|
<button id="theme-toggle" class="btn btn-secondary">🌙</button>
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ const router = {
|
|||||||
currentInstanceId: null,
|
currentInstanceId: null,
|
||||||
initialized: false,
|
initialized: false,
|
||||||
renderInProgress: false,
|
renderInProgress: false,
|
||||||
|
REFRESH_INTERVAL_MS: 3000, // Refresh every 3 seconds for live updates
|
||||||
|
|
||||||
init() {
|
init() {
|
||||||
console.log('[Router] init() called');
|
console.log('[Router] init() called');
|
||||||
@@ -84,6 +85,9 @@ const router = {
|
|||||||
this.renderInProgress = true;
|
this.renderInProgress = true;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
// Flash live indicator
|
||||||
|
this.flashLiveIndicator();
|
||||||
|
|
||||||
// Check if we already have a container for instances
|
// Check if we already have a container for instances
|
||||||
let instancesList = container.querySelector('.instances-list');
|
let instancesList = container.querySelector('.instances-list');
|
||||||
const isInitialLoad = !instancesList;
|
const isInitialLoad = !instancesList;
|
||||||
@@ -167,11 +171,11 @@ const router = {
|
|||||||
|
|
||||||
// Schedule next refresh only if still on home route
|
// Schedule next refresh only if still on home route
|
||||||
if (this.currentRoute === '/' || this.currentRoute === '') {
|
if (this.currentRoute === '/' || this.currentRoute === '') {
|
||||||
console.log('[Router] Scheduling auto-refresh in 5 seconds');
|
console.log(`[Router] Scheduling auto-refresh in ${this.REFRESH_INTERVAL_MS}ms`);
|
||||||
this.refreshTimeout = setTimeout(() => {
|
this.refreshTimeout = setTimeout(() => {
|
||||||
console.log('[Router] Auto-refresh triggered');
|
console.log('[Router] Auto-refresh triggered');
|
||||||
this.renderHome(container);
|
this.renderHome(container);
|
||||||
}, 5000);
|
}, this.REFRESH_INTERVAL_MS);
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('[Router] Error in renderHome:', error);
|
console.error('[Router] Error in renderHome:', error);
|
||||||
@@ -187,12 +191,26 @@ const router = {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
flashLiveIndicator() {
|
||||||
|
const indicator = document.getElementById('live-indicator');
|
||||||
|
if (indicator) {
|
||||||
|
indicator.style.animation = 'none';
|
||||||
|
// Force reflow
|
||||||
|
void indicator.offsetWidth;
|
||||||
|
indicator.style.animation = null;
|
||||||
|
indicator.style.opacity = '1';
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
async renderDetail(container, id) {
|
async renderDetail(container, id) {
|
||||||
console.log('[Router] renderDetail called for', id);
|
console.log('[Router] renderDetail called for', id);
|
||||||
|
|
||||||
this.currentInstanceId = id;
|
this.currentInstanceId = id;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
// Flash live indicator
|
||||||
|
this.flashLiveIndicator();
|
||||||
|
|
||||||
// Check if we already have a detail view for this instance
|
// Check if we already have a detail view for this instance
|
||||||
let detailView = container.querySelector('.detail-view');
|
let detailView = container.querySelector('.detail-view');
|
||||||
const isInitialLoad = !detailView || detailView.getAttribute('data-instance-id') !== id;
|
const isInitialLoad = !detailView || detailView.getAttribute('data-instance-id') !== id;
|
||||||
|
|||||||
@@ -64,6 +64,22 @@ body {
|
|||||||
color: var(--text-primary);
|
color: var(--text-primary);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.live-indicator {
|
||||||
|
font-size: 0.625rem; /* 75% of 0.833rem */
|
||||||
|
font-weight: 600;
|
||||||
|
color: var(--success);
|
||||||
|
margin-left: 0.75rem;
|
||||||
|
display: inline-flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.25rem;
|
||||||
|
animation: pulse 2s ease-in-out infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes pulse {
|
||||||
|
0%, 100% { opacity: 1; }
|
||||||
|
50% { opacity: 0.5; }
|
||||||
|
}
|
||||||
|
|
||||||
.header-actions {
|
.header-actions {
|
||||||
display: flex;
|
display: flex;
|
||||||
gap: 1rem;
|
gap: 1rem;
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ pub mod error_handling;
|
|||||||
pub mod project;
|
pub mod project;
|
||||||
pub mod task_result;
|
pub mod task_result;
|
||||||
pub mod ui_writer;
|
pub mod ui_writer;
|
||||||
|
|
||||||
pub use task_result::TaskResult;
|
pub use task_result::TaskResult;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -25,7 +26,7 @@ use anyhow::Result;
|
|||||||
use g3_computer_control::WebDriverController;
|
use g3_computer_control::WebDriverController;
|
||||||
use g3_config::Config;
|
use g3_config::Config;
|
||||||
use g3_execution::CodeExecutor;
|
use g3_execution::CodeExecutor;
|
||||||
use g3_providers::{CompletionRequest, Message, MessageRole, ProviderRegistry, Tool};
|
use g3_providers::{CacheControl, CompletionRequest, Message, MessageRole, ProviderRegistry, Tool};
|
||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
@@ -426,18 +427,12 @@ Format this as a detailed but concise summary that can be used to resume the con
|
|||||||
self.used_tokens = 0;
|
self.used_tokens = 0;
|
||||||
|
|
||||||
// Add the summary as a system message
|
// Add the summary as a system message
|
||||||
let summary_message = Message {
|
let summary_message = Message::new(MessageRole::System, format!("Previous conversation summary:\n\n{}", summary));
|
||||||
role: MessageRole::System,
|
|
||||||
content: format!("Previous conversation summary:\n\n{}", summary),
|
|
||||||
};
|
|
||||||
self.add_message(summary_message);
|
self.add_message(summary_message);
|
||||||
|
|
||||||
// Add the latest user message if provided
|
// Add the latest user message if provided
|
||||||
if let Some(user_msg) = latest_user_message {
|
if let Some(user_msg) = latest_user_message {
|
||||||
self.add_message(Message {
|
self.add_message(Message::new(MessageRole::User, user_msg));
|
||||||
role: MessageRole::User,
|
|
||||||
content: user_msg,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let new_chars: usize = self
|
let new_chars: usize = self
|
||||||
@@ -759,6 +754,7 @@ pub struct Agent<W: UiWriter> {
|
|||||||
safaridriver_process: std::sync::Arc<tokio::sync::RwLock<Option<tokio::process::Child>>>,
|
safaridriver_process: std::sync::Arc<tokio::sync::RwLock<Option<tokio::process::Child>>>,
|
||||||
macax_controller:
|
macax_controller:
|
||||||
std::sync::Arc<tokio::sync::RwLock<Option<g3_computer_control::MacAxController>>>,
|
std::sync::Arc<tokio::sync::RwLock<Option<g3_computer_control::MacAxController>>>,
|
||||||
|
tool_call_count: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<W: UiWriter> Agent<W> {
|
impl<W: UiWriter> Agent<W> {
|
||||||
@@ -901,6 +897,8 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
Some(anthropic_config.model.clone()),
|
Some(anthropic_config.model.clone()),
|
||||||
anthropic_config.max_tokens,
|
anthropic_config.max_tokens,
|
||||||
anthropic_config.temperature,
|
anthropic_config.temperature,
|
||||||
|
anthropic_config.cache_config.clone(),
|
||||||
|
anthropic_config.enable_1m_context,
|
||||||
)?;
|
)?;
|
||||||
providers.register(anthropic_provider);
|
providers.register(anthropic_provider);
|
||||||
}
|
}
|
||||||
@@ -942,15 +940,36 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
debug!("Default provider set successfully");
|
debug!("Default provider set successfully");
|
||||||
|
|
||||||
// Determine context window size based on active provider
|
// Determine context window size based on active provider
|
||||||
let context_length = Self::get_configured_context_length(&config, &providers)?;
|
let mut context_warnings = Vec::new();
|
||||||
|
let context_length =
|
||||||
|
Self::get_configured_context_length(&config, &providers, &mut context_warnings)?;
|
||||||
let mut context_window = ContextWindow::new(context_length);
|
let mut context_window = ContextWindow::new(context_length);
|
||||||
|
|
||||||
// If README content is provided, add it as the first system message
|
// Surface any context warnings to the user via UI
|
||||||
if let Some(readme) = readme_content {
|
for warning in context_warnings {
|
||||||
let readme_message = Message {
|
ui_writer.print_context_status(&format!("⚠️ {}", warning));
|
||||||
role: MessageRole::System,
|
}
|
||||||
content: readme,
|
|
||||||
|
// Add system prompt as the FIRST message (before README)
|
||||||
|
// This ensures the agent always has proper tool usage instructions
|
||||||
|
let provider = providers.get(None)?;
|
||||||
|
let provider_has_native_tool_calling = provider.has_native_tool_calling();
|
||||||
|
let _ = provider; // Drop provider reference to avoid borrowing issues
|
||||||
|
|
||||||
|
let system_prompt = if provider_has_native_tool_calling {
|
||||||
|
// For native tool calling providers, use a more explicit system prompt
|
||||||
|
SYSTEM_PROMPT_FOR_NATIVE_TOOL_USE.to_string()
|
||||||
|
} else {
|
||||||
|
// For non-native providers (embedded models), use JSON format instructions
|
||||||
|
SYSTEM_PROMPT_FOR_NON_NATIVE_TOOL_USE.to_string()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let system_message = Message::new(MessageRole::System, system_prompt);
|
||||||
|
context_window.add_message(system_message);
|
||||||
|
|
||||||
|
// If README content is provided, add it as a second system message (after the main system prompt)
|
||||||
|
if let Some(readme) = readme_content {
|
||||||
|
let readme_message = Message::new(MessageRole::System, readme);
|
||||||
context_window.add_message(readme_message);
|
context_window.add_message(readme_message);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1006,18 +1025,54 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
None
|
None
|
||||||
}))
|
}))
|
||||||
},
|
},
|
||||||
|
tool_call_count: 0,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_configured_context_length(config: &Config, providers: &ProviderRegistry) -> Result<u32> {
|
/// Validate that the system prompt is the first message in the conversation history.
|
||||||
// First, check if there's a global max_context_length override in agent config
|
/// This is a critical invariant that must be maintained for proper agent operation.
|
||||||
if let Some(max_context_length) = config.agent.max_context_length {
|
///
|
||||||
debug!("Using configured agent.max_context_length: {}", max_context_length);
|
/// # Panics
|
||||||
return Ok(max_context_length);
|
/// Panics if:
|
||||||
|
/// - The conversation history is empty
|
||||||
|
/// - The first message is not a System message
|
||||||
|
/// - The first message doesn't contain the system prompt markers
|
||||||
|
fn validate_system_prompt_is_first(&self) {
|
||||||
|
if self.context_window.conversation_history.is_empty() {
|
||||||
|
panic!(
|
||||||
|
"FATAL: Conversation history is empty. System prompt must be the first message."
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the configured max_tokens for the current provider
|
let first_message = &self.context_window.conversation_history[0];
|
||||||
fn get_provider_max_tokens(config: &Config, provider_name: &str) -> Option<u32> {
|
|
||||||
|
if !matches!(first_message.role, MessageRole::System) {
|
||||||
|
panic!(
|
||||||
|
"FATAL: First message is not a System message. Found: {:?}",
|
||||||
|
first_message.role
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if !first_message.content.contains("You are G3") {
|
||||||
|
panic!("FATAL: First system message does not contain the system prompt. This likely means the README was added before the system prompt.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert cache config string to CacheControl enum
|
||||||
|
fn parse_cache_control(cache_config: &str) -> Option<CacheControl> {
|
||||||
|
match cache_config {
|
||||||
|
"ephemeral" => Some(CacheControl::ephemeral()),
|
||||||
|
"5minute" => Some(CacheControl::five_minute()),
|
||||||
|
"1hour" => Some(CacheControl::one_hour()),
|
||||||
|
_ => {
|
||||||
|
warn!("Invalid cache_config value: '{}'. Valid values are: ephemeral, 5minute, 1hour", cache_config);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the configured max_tokens for a provider from top-level config
|
||||||
|
fn provider_max_tokens(config: &Config, provider_name: &str) -> Option<u32> {
|
||||||
match provider_name {
|
match provider_name {
|
||||||
"anthropic" => config.providers.anthropic.as_ref()?.max_tokens,
|
"anthropic" => config.providers.anthropic.as_ref()?.max_tokens,
|
||||||
"openai" => config.providers.openai.as_ref()?.max_tokens,
|
"openai" => config.providers.openai.as_ref()?.max_tokens,
|
||||||
@@ -1027,6 +1082,61 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Resolve the max_tokens to use for a given provider, applying fallbacks
|
||||||
|
fn resolve_max_tokens(&self, provider_name: &str) -> u32 {
|
||||||
|
match provider_name {
|
||||||
|
"databricks" => Self::provider_max_tokens(&self.config, "databricks")
|
||||||
|
.or(Some(self.config.agent.fallback_default_max_tokens as u32))
|
||||||
|
.unwrap_or(32000),
|
||||||
|
other => Self::provider_max_tokens(&self.config, other)
|
||||||
|
.or(Some(self.config.agent.fallback_default_max_tokens as u32))
|
||||||
|
.unwrap_or(16000),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Print provider diagnostics through the UiWriter for visibility
|
||||||
|
pub fn print_provider_banner(&self, role_label: &str) {
|
||||||
|
if let Ok((provider_name, model)) = self.get_provider_info() {
|
||||||
|
let max_tokens = self.resolve_max_tokens(&provider_name);
|
||||||
|
let context_len = self.context_window.total_tokens;
|
||||||
|
|
||||||
|
let mut details = vec![
|
||||||
|
format!("provider={}", provider_name),
|
||||||
|
format!("model={}", model),
|
||||||
|
format!("max_tokens={}", max_tokens),
|
||||||
|
format!("context_window_length={}", context_len),
|
||||||
|
];
|
||||||
|
|
||||||
|
if let Ok(provider) = self.providers.get(None) {
|
||||||
|
details.push(format!(
|
||||||
|
"native_tools={}",
|
||||||
|
if provider.has_native_tool_calling() {
|
||||||
|
"yes"
|
||||||
|
} else {
|
||||||
|
"no"
|
||||||
|
}
|
||||||
|
));
|
||||||
|
if provider.supports_cache_control() {
|
||||||
|
details.push("cache_control=yes".to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.ui_writer
|
||||||
|
.print_context_status(&format!("{}: {}", role_label, details.join(", ")));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_configured_context_length(
|
||||||
|
config: &Config,
|
||||||
|
providers: &ProviderRegistry,
|
||||||
|
warnings: &mut Vec<String>,
|
||||||
|
) -> Result<u32> {
|
||||||
|
// First, check if there's a global max_context_length override in agent config
|
||||||
|
if let Some(max_context_length) = config.agent.max_context_length {
|
||||||
|
debug!("Using configured agent.max_context_length: {}", max_context_length);
|
||||||
|
return Ok(max_context_length);
|
||||||
|
}
|
||||||
|
|
||||||
// Get the active provider to determine context length
|
// Get the active provider to determine context length
|
||||||
let provider = providers.get(None)?;
|
let provider = providers.get(None)?;
|
||||||
let provider_name = provider.name();
|
let provider_name = provider.name();
|
||||||
@@ -1053,25 +1163,45 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
}
|
}
|
||||||
"openai" => {
|
"openai" => {
|
||||||
// gpt-5 has 400k window
|
// gpt-5 has 400k window
|
||||||
get_provider_max_tokens(config, "openai").unwrap_or(400000)
|
if let Some(max_tokens) = Self::provider_max_tokens(config, "openai") {
|
||||||
|
warnings.push(format!(
|
||||||
|
"Context length falling back to max_tokens ({}) for provider=openai",
|
||||||
|
max_tokens
|
||||||
|
));
|
||||||
|
max_tokens
|
||||||
|
} else {
|
||||||
|
400000
|
||||||
|
}
|
||||||
}
|
}
|
||||||
"anthropic" => {
|
"anthropic" => {
|
||||||
// Claude models have large context windows
|
// Claude models have large context windows
|
||||||
// Use configured max_tokens or fall back to default
|
// Use configured max_tokens or fall back to default
|
||||||
get_provider_max_tokens(config, "anthropic").unwrap_or(200000)
|
if let Some(max_tokens) = Self::provider_max_tokens(config, "anthropic") {
|
||||||
|
warnings.push(format!(
|
||||||
|
"Context length falling back to max_tokens ({}) for provider=anthropic",
|
||||||
|
max_tokens
|
||||||
|
));
|
||||||
|
max_tokens
|
||||||
|
} else {
|
||||||
|
200000
|
||||||
|
}
|
||||||
}
|
}
|
||||||
"databricks" => {
|
"databricks" => {
|
||||||
// Databricks models have varying context windows depending on the model
|
// Databricks models have varying context windows depending on the model
|
||||||
// Use configured max_tokens or fall back to model-specific defaults
|
// Use configured max_tokens or fall back to model-specific defaults
|
||||||
get_provider_max_tokens(config, "databricks").unwrap_or_else(|| {
|
if let Some(max_tokens) = Self::provider_max_tokens(config, "databricks") {
|
||||||
if model_name.contains("claude") {
|
warnings.push(format!(
|
||||||
|
"Context length falling back to max_tokens ({}) for provider=databricks",
|
||||||
|
max_tokens
|
||||||
|
));
|
||||||
|
max_tokens
|
||||||
|
} else if model_name.contains("claude") {
|
||||||
200000 // Claude models on Databricks have large context windows
|
200000 // Claude models on Databricks have large context windows
|
||||||
} else if model_name.contains("llama") || model_name.contains("dbrx") {
|
} else if model_name.contains("llama") || model_name.contains("dbrx") {
|
||||||
32768 // DBRX supports 32k context
|
32768 // DBRX supports 32k context
|
||||||
} else {
|
} else {
|
||||||
16384 // Conservative default for other Databricks models
|
16384 // Conservative default for other Databricks models
|
||||||
}
|
}
|
||||||
})
|
|
||||||
}
|
}
|
||||||
_ => config.agent.fallback_default_max_tokens as u32,
|
_ => config.agent.fallback_default_max_tokens as u32,
|
||||||
};
|
};
|
||||||
@@ -1171,7 +1301,7 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
async fn execute_single_task(
|
async fn execute_single_task(
|
||||||
&mut self,
|
&mut self,
|
||||||
description: &str,
|
description: &str,
|
||||||
show_prompt: bool,
|
_show_prompt: bool,
|
||||||
_show_code: bool,
|
_show_code: bool,
|
||||||
show_timing: bool,
|
show_timing: bool,
|
||||||
cancellation_token: CancellationToken,
|
cancellation_token: CancellationToken,
|
||||||
@@ -1180,39 +1310,16 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
// This prevents the filter from staying in suppression mode between user interactions
|
// This prevents the filter from staying in suppression mode between user interactions
|
||||||
fixed_filter_json::reset_fixed_json_tool_state();
|
fixed_filter_json::reset_fixed_json_tool_state();
|
||||||
|
|
||||||
|
// Validate that the system prompt is the first message (critical invariant)
|
||||||
|
self.validate_system_prompt_is_first();
|
||||||
|
|
||||||
// Generate session ID based on the initial prompt if this is a new session
|
// Generate session ID based on the initial prompt if this is a new session
|
||||||
if self.session_id.is_none() {
|
if self.session_id.is_none() {
|
||||||
self.session_id = Some(self.generate_session_id(description));
|
self.session_id = Some(self.generate_session_id(description));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only add system message if this is the first interaction (empty conversation history)
|
|
||||||
if self.context_window.conversation_history.is_empty() {
|
|
||||||
let provider = self.providers.get(None)?;
|
|
||||||
let system_prompt = if provider.has_native_tool_calling() {
|
|
||||||
// For native tool calling providers, use a more explicit system prompt
|
|
||||||
SYSTEM_PROMPT_FOR_NATIVE_TOOL_USE.to_string()
|
|
||||||
} else {
|
|
||||||
// For non-native providers (embedded models), use JSON format instructions
|
|
||||||
SYSTEM_PROMPT_FOR_NON_NATIVE_TOOL_USE.to_string()
|
|
||||||
};
|
|
||||||
|
|
||||||
if show_prompt {
|
|
||||||
self.ui_writer.print_system_prompt(&system_prompt);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add system message to context window
|
|
||||||
let system_message = Message {
|
|
||||||
role: MessageRole::System,
|
|
||||||
content: system_prompt,
|
|
||||||
};
|
|
||||||
self.context_window.add_message(system_message);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add user message to context window
|
// Add user message to context window
|
||||||
let user_message = Message {
|
let user_message = Message::new(MessageRole::User, format!("Task: {}", description));
|
||||||
role: MessageRole::User,
|
|
||||||
content: format!("Task: {}", description),
|
|
||||||
};
|
|
||||||
self.context_window.add_message(user_message);
|
self.context_window.add_message(user_message);
|
||||||
|
|
||||||
// Use the complete conversation history for the request
|
// Use the complete conversation history for the request
|
||||||
@@ -1220,6 +1327,9 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
|
|
||||||
// Check if provider supports native tool calling and add tools if so
|
// Check if provider supports native tool calling and add tools if so
|
||||||
let provider = self.providers.get(None)?;
|
let provider = self.providers.get(None)?;
|
||||||
|
let provider_name = provider.name().to_string();
|
||||||
|
let _has_native_tool_calling = provider.has_native_tool_calling();
|
||||||
|
let _supports_cache_control = provider.supports_cache_control();
|
||||||
let tools = if provider.has_native_tool_calling() {
|
let tools = if provider.has_native_tool_calling() {
|
||||||
Some(Self::create_tool_definitions(
|
Some(Self::create_tool_definitions(
|
||||||
self.config.webdriver.enabled,
|
self.config.webdriver.enabled,
|
||||||
@@ -1229,18 +1339,10 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
let _ = provider; // Drop the provider reference to avoid borrowing issues
|
||||||
|
|
||||||
// Get max_tokens from provider configuration
|
// Get max_tokens from provider configuration, falling back to sensible defaults
|
||||||
let max_tokens = match provider.name() {
|
let max_tokens = Some(self.resolve_max_tokens(&provider_name));
|
||||||
"databricks" => {
|
|
||||||
// Use the model's maximum limit for Databricks to allow large file generation
|
|
||||||
Some(32000)
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
// Default for other providers
|
|
||||||
Some(16000)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let request = CompletionRequest {
|
let request = CompletionRequest {
|
||||||
messages,
|
messages,
|
||||||
@@ -1286,9 +1388,23 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
// Add assistant response to context window only if not empty
|
// Add assistant response to context window only if not empty
|
||||||
// This prevents the "Skipping empty message" warning when only tools were executed
|
// This prevents the "Skipping empty message" warning when only tools were executed
|
||||||
if !response_content.trim().is_empty() {
|
if !response_content.trim().is_empty() {
|
||||||
let assistant_message = Message {
|
let assistant_message = {
|
||||||
role: MessageRole::Assistant,
|
// Check if we should use cache control (every 10 tool calls)
|
||||||
content: response_content.clone(),
|
if self.tool_call_count > 0 && self.tool_call_count % 10 == 0 {
|
||||||
|
let provider = self.providers.get(None)?;
|
||||||
|
if let Some(cache_config) = match provider.name() {
|
||||||
|
"anthropic" => self.config.providers.anthropic.as_ref()
|
||||||
|
.and_then(|c| c.cache_config.as_ref())
|
||||||
|
.and_then(|config| Self::parse_cache_control(config)),
|
||||||
|
_ => None,
|
||||||
|
} {
|
||||||
|
Message::with_cache_control_validated(MessageRole::Assistant, response_content.clone(), cache_config, provider)
|
||||||
|
} else {
|
||||||
|
Message::new(MessageRole::Assistant, response_content.clone())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Message::new(MessageRole::Assistant, response_content.clone())
|
||||||
|
}
|
||||||
};
|
};
|
||||||
self.context_window.add_message(assistant_message);
|
self.context_window.add_message(assistant_message);
|
||||||
} else {
|
} else {
|
||||||
@@ -1491,17 +1607,11 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
.join("\n\n");
|
.join("\n\n");
|
||||||
|
|
||||||
let summary_messages = vec![
|
let summary_messages = vec![
|
||||||
Message {
|
Message::new(MessageRole::System, "You are a helpful assistant that creates concise summaries.".to_string()),
|
||||||
role: MessageRole::System,
|
Message::new(MessageRole::User, format!(
|
||||||
content: "You are a helpful assistant that creates concise summaries.".to_string(),
|
|
||||||
},
|
|
||||||
Message {
|
|
||||||
role: MessageRole::User,
|
|
||||||
content: format!(
|
|
||||||
"Based on this conversation history, {}\n\nConversation:\n{}",
|
"Based on this conversation history, {}\n\nConversation:\n{}",
|
||||||
summary_prompt, conversation_text
|
summary_prompt, conversation_text
|
||||||
),
|
)),
|
||||||
},
|
|
||||||
];
|
];
|
||||||
|
|
||||||
let provider = self.providers.get(None)?;
|
let provider = self.providers.get(None)?;
|
||||||
@@ -1589,11 +1699,12 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
pub fn reload_readme(&mut self) -> Result<bool> {
|
pub fn reload_readme(&mut self) -> Result<bool> {
|
||||||
info!("Manual README reload triggered");
|
info!("Manual README reload triggered");
|
||||||
|
|
||||||
// Check if the first message in conversation history is a system message with README content
|
// Check if the second message in conversation history is a system message with README content
|
||||||
|
// (The first message should always be the system prompt)
|
||||||
let has_readme = self
|
let has_readme = self
|
||||||
.context_window
|
.context_window
|
||||||
.conversation_history
|
.conversation_history
|
||||||
.first()
|
.get(1) // Check the SECOND message (index 1)
|
||||||
.map(|m| {
|
.map(|m| {
|
||||||
matches!(m.role, MessageRole::System)
|
matches!(m.role, MessageRole::System)
|
||||||
&& (m.content.contains("Project README")
|
&& (m.content.contains("Project README")
|
||||||
@@ -1601,6 +1712,9 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
})
|
})
|
||||||
.unwrap_or(false);
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
// Validate that the system prompt is still first
|
||||||
|
self.validate_system_prompt_is_first();
|
||||||
|
|
||||||
if !has_readme {
|
if !has_readme {
|
||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
@@ -1623,8 +1737,8 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if found_any {
|
if found_any {
|
||||||
// Replace the first message with the new content
|
// Replace the second message (README) with the new content
|
||||||
if let Some(first_msg) = self.context_window.conversation_history.first_mut() {
|
if let Some(first_msg) = self.context_window.conversation_history.get_mut(1) {
|
||||||
first_msg.content = combined_content;
|
first_msg.content = combined_content;
|
||||||
info!("README content reloaded successfully");
|
info!("README content reloaded successfully");
|
||||||
Ok(true)
|
Ok(true)
|
||||||
@@ -2484,18 +2598,11 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
.join("\n\n");
|
.join("\n\n");
|
||||||
|
|
||||||
let summary_messages = vec![
|
let summary_messages = vec![
|
||||||
Message {
|
Message::new(MessageRole::System, "You are a helpful assistant that creates concise summaries.".to_string()),
|
||||||
role: MessageRole::System,
|
Message::new(MessageRole::User, format!(
|
||||||
content: "You are a helpful assistant that creates concise summaries."
|
|
||||||
.to_string(),
|
|
||||||
},
|
|
||||||
Message {
|
|
||||||
role: MessageRole::User,
|
|
||||||
content: format!(
|
|
||||||
"Based on this conversation history, {}\n\nConversation:\n{}",
|
"Based on this conversation history, {}\n\nConversation:\n{}",
|
||||||
summary_prompt, conversation_text
|
summary_prompt, conversation_text
|
||||||
),
|
)),
|
||||||
},
|
|
||||||
];
|
];
|
||||||
|
|
||||||
let provider = self.providers.get(None)?;
|
let provider = self.providers.get(None)?;
|
||||||
@@ -2981,29 +3088,20 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
// Add the tool call and result to the context window using RAW unfiltered content
|
// Add the tool call and result to the context window using RAW unfiltered content
|
||||||
// This ensures the log file contains the true raw content including JSON tool calls
|
// This ensures the log file contains the true raw content including JSON tool calls
|
||||||
let tool_message = if !raw_content_for_log.trim().is_empty() {
|
let tool_message = if !raw_content_for_log.trim().is_empty() {
|
||||||
Message {
|
Message::new(MessageRole::Assistant, format!(
|
||||||
role: MessageRole::Assistant,
|
|
||||||
content: format!(
|
|
||||||
"{}\n\n{{\"tool\": \"{}\", \"args\": {}}}",
|
"{}\n\n{{\"tool\": \"{}\", \"args\": {}}}",
|
||||||
raw_content_for_log.trim(),
|
raw_content_for_log.trim(),
|
||||||
tool_call.tool,
|
tool_call.tool,
|
||||||
tool_call.args
|
tool_call.args
|
||||||
),
|
))
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// No text content before tool call, just include the tool call
|
// No text content before tool call, just include the tool call
|
||||||
Message {
|
Message::new(MessageRole::Assistant, format!(
|
||||||
role: MessageRole::Assistant,
|
|
||||||
content: format!(
|
|
||||||
"{{\"tool\": \"{}\", \"args\": {}}}",
|
"{{\"tool\": \"{}\", \"args\": {}}}",
|
||||||
tool_call.tool, tool_call.args
|
tool_call.tool, tool_call.args
|
||||||
),
|
))
|
||||||
}
|
|
||||||
};
|
|
||||||
let result_message = Message {
|
|
||||||
role: MessageRole::User,
|
|
||||||
content: format!("Tool result: {}", tool_result),
|
|
||||||
};
|
};
|
||||||
|
let result_message = Message::new(MessageRole::User, format!("Tool result: {}", tool_result));
|
||||||
|
|
||||||
self.context_window.add_message(tool_message);
|
self.context_window.add_message(tool_message);
|
||||||
self.context_window.add_message(result_message);
|
self.context_window.add_message(result_message);
|
||||||
@@ -3012,7 +3110,8 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
request.messages = self.context_window.conversation_history.clone();
|
request.messages = self.context_window.conversation_history.clone();
|
||||||
|
|
||||||
// Ensure tools are included for native providers in subsequent iterations
|
// Ensure tools are included for native providers in subsequent iterations
|
||||||
if provider.has_native_tool_calling() {
|
let provider_for_tools = self.providers.get(None)?;
|
||||||
|
if provider_for_tools.has_native_tool_calling() {
|
||||||
request.tools = Some(Self::create_tool_definitions(
|
request.tools = Some(Self::create_tool_definitions(
|
||||||
self.config.webdriver.enabled,
|
self.config.webdriver.enabled,
|
||||||
self.config.macax.enabled,
|
self.config.macax.enabled,
|
||||||
@@ -3343,9 +3442,23 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
.replace("<</SYS>>", "");
|
.replace("<</SYS>>", "");
|
||||||
|
|
||||||
if !raw_clean.trim().is_empty() {
|
if !raw_clean.trim().is_empty() {
|
||||||
let assistant_message = Message {
|
let assistant_message = {
|
||||||
role: MessageRole::Assistant,
|
// Check if we should use cache control (every 10 tool calls)
|
||||||
content: raw_clean,
|
if self.tool_call_count > 0 && self.tool_call_count % 10 == 0 {
|
||||||
|
let provider = self.providers.get(None)?;
|
||||||
|
if let Some(cache_config) = match provider.name() {
|
||||||
|
"anthropic" => self.config.providers.anthropic.as_ref()
|
||||||
|
.and_then(|c| c.cache_config.as_ref())
|
||||||
|
.and_then(|config| Self::parse_cache_control(config)),
|
||||||
|
_ => None,
|
||||||
|
} {
|
||||||
|
Message::with_cache_control_validated(MessageRole::Assistant, raw_clean, cache_config, provider)
|
||||||
|
} else {
|
||||||
|
Message::new(MessageRole::Assistant, raw_clean)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Message::new(MessageRole::Assistant, raw_clean)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
self.context_window.add_message(assistant_message);
|
self.context_window.add_message(assistant_message);
|
||||||
}
|
}
|
||||||
@@ -3387,7 +3500,10 @@ impl<W: UiWriter> Agent<W> {
|
|||||||
Ok(TaskResult::new(final_response, self.context_window.clone()))
|
Ok(TaskResult::new(final_response, self.context_window.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn execute_tool(&self, tool_call: &ToolCall) -> Result<String> {
|
pub async fn execute_tool(&mut self, tool_call: &ToolCall) -> Result<String> {
|
||||||
|
// Increment tool call count
|
||||||
|
self.tool_call_count += 1;
|
||||||
|
|
||||||
debug!("=== EXECUTING TOOL ===");
|
debug!("=== EXECUTING TOOL ===");
|
||||||
debug!("Tool name: {}", tool_call.tool);
|
debug!("Tool name: {}", tool_call.tool);
|
||||||
debug!("Tool args (raw): {:?}", tool_call.args);
|
debug!("Tool args (raw): {:?}", tool_call.args);
|
||||||
@@ -5371,6 +5487,16 @@ mod integration_tests {
|
|||||||
// Implement Drop to clean up safaridriver process
|
// Implement Drop to clean up safaridriver process
|
||||||
impl<W: UiWriter> Drop for Agent<W> {
|
impl<W: UiWriter> Drop for Agent<W> {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
|
// Validate system prompt invariant on drop (agent exit)
|
||||||
|
// This catches any bugs where the conversation history was corrupted during execution
|
||||||
|
if !self.context_window.conversation_history.is_empty() {
|
||||||
|
if let Err(e) = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||||
|
self.validate_system_prompt_is_first();
|
||||||
|
})) {
|
||||||
|
eprintln!("\n⚠️ FATAL ERROR ON EXIT: System prompt validation failed: {:?}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Try to kill safaridriver process if it's still running
|
// Try to kill safaridriver process if it's still running
|
||||||
// We need to use try_lock since we can't await in Drop
|
// We need to use try_lock since we can't await in Drop
|
||||||
if let Ok(mut process_guard) = self.safaridriver_process.try_write() {
|
if let Ok(mut process_guard) = self.safaridriver_process.try_write() {
|
||||||
|
|||||||
@@ -6,14 +6,10 @@ use std::sync::Arc;
|
|||||||
fn test_task_result_basic_functionality() {
|
fn test_task_result_basic_functionality() {
|
||||||
// Create a context window with some messages
|
// Create a context window with some messages
|
||||||
let mut context = ContextWindow::new(10000);
|
let mut context = ContextWindow::new(10000);
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(MessageRole::User, "Test message 1".to_string())
|
||||||
role: MessageRole::User,
|
);
|
||||||
content: "Test message 1".to_string(),
|
context.add_message(Message::new(MessageRole::Assistant, "Response 1".to_string())
|
||||||
});
|
);
|
||||||
context.add_message(Message {
|
|
||||||
role: MessageRole::Assistant,
|
|
||||||
content: "Response 1".to_string(),
|
|
||||||
});
|
|
||||||
|
|
||||||
// Create a TaskResult
|
// Create a TaskResult
|
||||||
let response = "This is the response\n\nFinal output block".to_string();
|
let response = "This is the response\n\nFinal output block".to_string();
|
||||||
@@ -100,10 +96,7 @@ fn test_context_window_preservation() {
|
|||||||
|
|
||||||
// Add some messages
|
// Add some messages
|
||||||
for i in 0..5 {
|
for i in 0..5 {
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(if i % 2 == 0 { MessageRole::User } else { MessageRole::Assistant }, format!("Message {}", i)));
|
||||||
role: if i % 2 == 0 { MessageRole::User } else { MessageRole::Assistant },
|
|
||||||
content: format!("Message {}", i),
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create TaskResult
|
// Create TaskResult
|
||||||
|
|||||||
@@ -46,10 +46,10 @@ fn test_thin_context_basic() {
|
|||||||
// Add some messages to the first third
|
// Add some messages to the first third
|
||||||
for i in 0..9 {
|
for i in 0..9 {
|
||||||
if i % 2 == 0 {
|
if i % 2 == 0 {
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(
|
||||||
role: MessageRole::Assistant,
|
MessageRole::Assistant,
|
||||||
content: format!("Assistant message {}", i),
|
format!("Assistant message {}", i),
|
||||||
});
|
));
|
||||||
} else {
|
} else {
|
||||||
// Add tool results with varying sizes
|
// Add tool results with varying sizes
|
||||||
let content = if i == 1 {
|
let content = if i == 1 {
|
||||||
@@ -63,10 +63,10 @@ fn test_thin_context_basic() {
|
|||||||
format!("Tool result: small result {}", i)
|
format!("Tool result: small result {}", i)
|
||||||
};
|
};
|
||||||
|
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(
|
||||||
role: MessageRole::User,
|
MessageRole::User,
|
||||||
content,
|
content,
|
||||||
});
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -98,10 +98,10 @@ fn test_thin_write_file_tool_calls() {
|
|||||||
let mut context = ContextWindow::new(10000);
|
let mut context = ContextWindow::new(10000);
|
||||||
|
|
||||||
// Add some messages including a write_file tool call with large content
|
// Add some messages including a write_file tool call with large content
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(
|
||||||
role: MessageRole::User,
|
MessageRole::User,
|
||||||
content: "Please create a large file".to_string(),
|
"Please create a large file".to_string(),
|
||||||
});
|
));
|
||||||
|
|
||||||
// Add an assistant message with a write_file tool call containing large content
|
// Add an assistant message with a write_file tool call containing large content
|
||||||
let large_content = "x".repeat(1500);
|
let large_content = "x".repeat(1500);
|
||||||
@@ -109,22 +109,22 @@ fn test_thin_write_file_tool_calls() {
|
|||||||
r#"{{"tool": "write_file", "args": {{"file_path": "test.txt", "content": "{}"}}}}"#,
|
r#"{{"tool": "write_file", "args": {{"file_path": "test.txt", "content": "{}"}}}}"#,
|
||||||
large_content
|
large_content
|
||||||
);
|
);
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(
|
||||||
role: MessageRole::Assistant,
|
MessageRole::Assistant,
|
||||||
content: format!("I'll create that file.\n\n{}", tool_call_json),
|
format!("I'll create that file.\n\n{}", tool_call_json),
|
||||||
});
|
));
|
||||||
|
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(
|
||||||
role: MessageRole::User,
|
MessageRole::User,
|
||||||
content: "Tool result: ✅ Successfully wrote 1500 lines".to_string(),
|
"Tool result: ✅ Successfully wrote 1500 lines".to_string(),
|
||||||
});
|
));
|
||||||
|
|
||||||
// Add more messages to ensure we have enough for "first third" logic
|
// Add more messages to ensure we have enough for "first third" logic
|
||||||
for i in 0..6 {
|
for i in 0..6 {
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(
|
||||||
role: MessageRole::Assistant,
|
MessageRole::Assistant,
|
||||||
content: format!("Response {}", i),
|
format!("Response {}", i),
|
||||||
});
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trigger thinning at 50%
|
// Trigger thinning at 50%
|
||||||
@@ -154,10 +154,10 @@ fn test_thin_str_replace_tool_calls() {
|
|||||||
let mut context = ContextWindow::new(10000);
|
let mut context = ContextWindow::new(10000);
|
||||||
|
|
||||||
// Add some messages including a str_replace tool call with large diff
|
// Add some messages including a str_replace tool call with large diff
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(
|
||||||
role: MessageRole::User,
|
MessageRole::User,
|
||||||
content: "Please update the file".to_string(),
|
"Please update the file".to_string(),
|
||||||
});
|
));
|
||||||
|
|
||||||
// Add an assistant message with a str_replace tool call containing large diff
|
// Add an assistant message with a str_replace tool call containing large diff
|
||||||
let large_diff = format!("--- old\n{}\n+++ new\n{}", "-old line\n".repeat(100), "+new line\n".repeat(100));
|
let large_diff = format!("--- old\n{}\n+++ new\n{}", "-old line\n".repeat(100), "+new line\n".repeat(100));
|
||||||
@@ -165,22 +165,22 @@ fn test_thin_str_replace_tool_calls() {
|
|||||||
r#"{{"tool": "str_replace", "args": {{"file_path": "test.txt", "diff": "{}"}}}}"#,
|
r#"{{"tool": "str_replace", "args": {{"file_path": "test.txt", "diff": "{}"}}}}"#,
|
||||||
large_diff.replace('\n', "\\n")
|
large_diff.replace('\n', "\\n")
|
||||||
);
|
);
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(
|
||||||
role: MessageRole::Assistant,
|
MessageRole::Assistant,
|
||||||
content: format!("I'll update that file.\n\n{}", tool_call_json),
|
format!("I'll update that file.\n\n{}", tool_call_json),
|
||||||
});
|
));
|
||||||
|
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(
|
||||||
role: MessageRole::User,
|
MessageRole::User,
|
||||||
content: "Tool result: ✅ applied unified diff".to_string(),
|
"Tool result: ✅ applied unified diff".to_string(),
|
||||||
});
|
));
|
||||||
|
|
||||||
// Add more messages to ensure we have enough for "first third" logic
|
// Add more messages to ensure we have enough for "first third" logic
|
||||||
for i in 0..6 {
|
for i in 0..6 {
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(
|
||||||
role: MessageRole::Assistant,
|
MessageRole::Assistant,
|
||||||
content: format!("Response {}", i),
|
format!("Response {}", i),
|
||||||
});
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trigger thinning at 50%
|
// Trigger thinning at 50%
|
||||||
@@ -212,10 +212,10 @@ fn test_thin_context_no_large_results() {
|
|||||||
|
|
||||||
// Add only small messages
|
// Add only small messages
|
||||||
for i in 0..9 {
|
for i in 0..9 {
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(
|
||||||
role: MessageRole::User,
|
MessageRole::User,
|
||||||
content: format!("Tool result: small {}", i),
|
format!("Tool result: small {}", i),
|
||||||
});
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
context.used_tokens = 5000;
|
context.used_tokens = 5000;
|
||||||
@@ -244,7 +244,7 @@ fn test_thin_context_only_affects_first_third() {
|
|||||||
MessageRole::Assistant
|
MessageRole::Assistant
|
||||||
};
|
};
|
||||||
|
|
||||||
context.add_message(Message { role, content });
|
context.add_message(Message::new(role, content));
|
||||||
}
|
}
|
||||||
|
|
||||||
context.used_tokens = 5000;
|
context.used_tokens = 5000;
|
||||||
|
|||||||
@@ -8,27 +8,18 @@ fn test_todo_read_results_not_thinned() {
|
|||||||
let mut context = ContextWindow::new(10000);
|
let mut context = ContextWindow::new(10000);
|
||||||
|
|
||||||
// Add a todo_read tool call
|
// Add a todo_read tool call
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(MessageRole::Assistant, r#"{"tool": "todo_read", "args": {}}"#.to_string()));
|
||||||
role: MessageRole::Assistant,
|
|
||||||
content: r#"{"tool": "todo_read", "args": {}}"#.to_string(),
|
|
||||||
});
|
|
||||||
|
|
||||||
// Add a large TODO result (> 500 chars)
|
// Add a large TODO result (> 500 chars)
|
||||||
let large_todo_result = format!(
|
let large_todo_result = format!(
|
||||||
"Tool result: 📝 TODO list:\n{}",
|
"Tool result: 📝 TODO list:\n{}",
|
||||||
"- [ ] Task with long description\n".repeat(50)
|
"- [ ] Task with long description\n".repeat(50)
|
||||||
);
|
);
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(MessageRole::User, large_todo_result.clone()));
|
||||||
role: MessageRole::User,
|
|
||||||
content: large_todo_result.clone(),
|
|
||||||
});
|
|
||||||
|
|
||||||
// Add more messages to ensure we have enough for "first third" logic
|
// Add more messages to ensure we have enough for "first third" logic
|
||||||
for i in 0..6 {
|
for i in 0..6 {
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(MessageRole::Assistant, format!("Response {}", i)))
|
||||||
role: MessageRole::Assistant,
|
|
||||||
content: format!("Response {}", i),
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trigger thinning at 50%
|
// Trigger thinning at 50%
|
||||||
@@ -65,27 +56,18 @@ fn test_todo_write_results_not_thinned() {
|
|||||||
|
|
||||||
// Add a todo_write tool call
|
// Add a todo_write tool call
|
||||||
let large_content = "- [ ] Task\n".repeat(100);
|
let large_content = "- [ ] Task\n".repeat(100);
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(MessageRole::Assistant, format!(r#"{{"tool": "todo_write", "args": {{"content": "{}"}}}}"#, large_content)));
|
||||||
role: MessageRole::Assistant,
|
|
||||||
content: format!(r#"{{"tool": "todo_write", "args": {{"content": "{}"}}}}"#, large_content),
|
|
||||||
});
|
|
||||||
|
|
||||||
// Add a large TODO write result
|
// Add a large TODO write result
|
||||||
let large_todo_result = format!(
|
let large_todo_result = format!(
|
||||||
"Tool result: ✅ TODO list updated ({} chars) and saved to todo.g3.md",
|
"Tool result: ✅ TODO list updated ({} chars) and saved to todo.g3.md",
|
||||||
large_content.len()
|
large_content.len()
|
||||||
);
|
);
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(MessageRole::User, large_todo_result.clone()));
|
||||||
role: MessageRole::User,
|
|
||||||
content: large_todo_result.clone(),
|
|
||||||
});
|
|
||||||
|
|
||||||
// Add more messages
|
// Add more messages
|
||||||
for i in 0..6 {
|
for i in 0..6 {
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(MessageRole::Assistant, format!("Response {}", i)))
|
||||||
role: MessageRole::Assistant,
|
|
||||||
content: format!("Response {}", i),
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trigger thinning at 50%
|
// Trigger thinning at 50%
|
||||||
@@ -119,24 +101,15 @@ fn test_non_todo_results_still_thinned() {
|
|||||||
let mut context = ContextWindow::new(10000);
|
let mut context = ContextWindow::new(10000);
|
||||||
|
|
||||||
// Add a non-TODO tool call (e.g., read_file)
|
// Add a non-TODO tool call (e.g., read_file)
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(MessageRole::Assistant, r#"{"tool": "read_file", "args": {"file_path": "test.txt"}}"#.to_string()));
|
||||||
role: MessageRole::Assistant,
|
|
||||||
content: r#"{"tool": "read_file", "args": {"file_path": "test.txt"}}"#.to_string(),
|
|
||||||
});
|
|
||||||
|
|
||||||
// Add a large read_file result (> 500 chars)
|
// Add a large read_file result (> 500 chars)
|
||||||
let large_result = format!("Tool result: {}", "x".repeat(1500));
|
let large_result = format!("Tool result: {}", "x".repeat(1500));
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(MessageRole::User, large_result));
|
||||||
role: MessageRole::User,
|
|
||||||
content: large_result,
|
|
||||||
});
|
|
||||||
|
|
||||||
// Add more messages
|
// Add more messages
|
||||||
for i in 0..6 {
|
for i in 0..6 {
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(MessageRole::Assistant, format!("Response {}", i)))
|
||||||
role: MessageRole::Assistant,
|
|
||||||
content: format!("Response {}", i),
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trigger thinning at 50%
|
// Trigger thinning at 50%
|
||||||
@@ -172,27 +145,18 @@ fn test_todo_read_with_spaces_in_tool_name() {
|
|||||||
let mut context = ContextWindow::new(10000);
|
let mut context = ContextWindow::new(10000);
|
||||||
|
|
||||||
// Add a todo_read tool call with spaces (JSON formatting variation)
|
// Add a todo_read tool call with spaces (JSON formatting variation)
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(MessageRole::Assistant, r#"{"tool": "todo_read", "args": {}}"#.to_string()));
|
||||||
role: MessageRole::Assistant,
|
|
||||||
content: r#"{"tool": "todo_read", "args": {}}"#.to_string(),
|
|
||||||
});
|
|
||||||
|
|
||||||
// Add a large TODO result
|
// Add a large TODO result
|
||||||
let large_todo_result = format!(
|
let large_todo_result = format!(
|
||||||
"Tool result: 📝 TODO list:\n{}",
|
"Tool result: 📝 TODO list:\n{}",
|
||||||
"- [ ] Task\n".repeat(50)
|
"- [ ] Task\n".repeat(50)
|
||||||
);
|
);
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(MessageRole::User, large_todo_result.clone()));
|
||||||
role: MessageRole::User,
|
|
||||||
content: large_todo_result.clone(),
|
|
||||||
});
|
|
||||||
|
|
||||||
// Add more messages
|
// Add more messages
|
||||||
for i in 0..6 {
|
for i in 0..6 {
|
||||||
context.add_message(Message {
|
context.add_message(Message::new(MessageRole::Assistant, format!("Response {}", i)))
|
||||||
role: MessageRole::Assistant,
|
|
||||||
content: format!("Response {}", i),
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trigger thinning
|
// Trigger thinning
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ fn get_todo_path(temp_dir: &TempDir) -> PathBuf {
|
|||||||
#[serial]
|
#[serial]
|
||||||
async fn test_todo_write_creates_file() {
|
async fn test_todo_write_creates_file() {
|
||||||
let temp_dir = TempDir::new().unwrap();
|
let temp_dir = TempDir::new().unwrap();
|
||||||
let agent = create_test_agent_in_dir(&temp_dir).await;
|
let mut agent = create_test_agent_in_dir(&temp_dir).await;
|
||||||
let todo_path = get_todo_path(&temp_dir);
|
let todo_path = get_todo_path(&temp_dir);
|
||||||
|
|
||||||
// Initially, todo.g3.md should not exist
|
// Initially, todo.g3.md should not exist
|
||||||
@@ -67,7 +67,7 @@ async fn test_todo_read_from_file() {
|
|||||||
fs::write(&todo_path, test_content).unwrap();
|
fs::write(&todo_path, test_content).unwrap();
|
||||||
|
|
||||||
// Create agent (should load from file)
|
// Create agent (should load from file)
|
||||||
let agent = create_test_agent_in_dir(&temp_dir).await;
|
let mut agent = create_test_agent_in_dir(&temp_dir).await;
|
||||||
|
|
||||||
// Create a tool call to read TODO
|
// Create a tool call to read TODO
|
||||||
let tool_call = g3_core::ToolCall {
|
let tool_call = g3_core::ToolCall {
|
||||||
@@ -88,7 +88,7 @@ async fn test_todo_read_from_file() {
|
|||||||
#[serial]
|
#[serial]
|
||||||
async fn test_todo_read_empty_file() {
|
async fn test_todo_read_empty_file() {
|
||||||
let temp_dir = TempDir::new().unwrap();
|
let temp_dir = TempDir::new().unwrap();
|
||||||
let agent = create_test_agent_in_dir(&temp_dir).await;
|
let mut agent = create_test_agent_in_dir(&temp_dir).await;
|
||||||
|
|
||||||
// Create a tool call to read TODO (file doesn't exist)
|
// Create a tool call to read TODO (file doesn't exist)
|
||||||
let tool_call = g3_core::ToolCall {
|
let tool_call = g3_core::ToolCall {
|
||||||
@@ -111,7 +111,7 @@ async fn test_todo_persistence_across_agents() {
|
|||||||
|
|
||||||
// Agent 1: Write TODO
|
// Agent 1: Write TODO
|
||||||
{
|
{
|
||||||
let agent = create_test_agent_in_dir(&temp_dir).await;
|
let mut agent = create_test_agent_in_dir(&temp_dir).await;
|
||||||
let tool_call = g3_core::ToolCall {
|
let tool_call = g3_core::ToolCall {
|
||||||
tool: "todo_write".to_string(),
|
tool: "todo_write".to_string(),
|
||||||
args: serde_json::json!({
|
args: serde_json::json!({
|
||||||
@@ -126,7 +126,7 @@ async fn test_todo_persistence_across_agents() {
|
|||||||
|
|
||||||
// Agent 2: Read TODO (new agent instance)
|
// Agent 2: Read TODO (new agent instance)
|
||||||
{
|
{
|
||||||
let agent = create_test_agent_in_dir(&temp_dir).await;
|
let mut agent = create_test_agent_in_dir(&temp_dir).await;
|
||||||
let tool_call = g3_core::ToolCall {
|
let tool_call = g3_core::ToolCall {
|
||||||
tool: "todo_read".to_string(),
|
tool: "todo_read".to_string(),
|
||||||
args: serde_json::json!({}),
|
args: serde_json::json!({}),
|
||||||
@@ -143,7 +143,7 @@ async fn test_todo_persistence_across_agents() {
|
|||||||
#[serial]
|
#[serial]
|
||||||
async fn test_todo_update_preserves_file() {
|
async fn test_todo_update_preserves_file() {
|
||||||
let temp_dir = TempDir::new().unwrap();
|
let temp_dir = TempDir::new().unwrap();
|
||||||
let agent = create_test_agent_in_dir(&temp_dir).await;
|
let mut agent = create_test_agent_in_dir(&temp_dir).await;
|
||||||
let todo_path = get_todo_path(&temp_dir);
|
let todo_path = get_todo_path(&temp_dir);
|
||||||
|
|
||||||
// Write initial TODO
|
// Write initial TODO
|
||||||
@@ -173,7 +173,7 @@ async fn test_todo_update_preserves_file() {
|
|||||||
#[serial]
|
#[serial]
|
||||||
async fn test_todo_handles_large_content() {
|
async fn test_todo_handles_large_content() {
|
||||||
let temp_dir = TempDir::new().unwrap();
|
let temp_dir = TempDir::new().unwrap();
|
||||||
let agent = create_test_agent_in_dir(&temp_dir).await;
|
let mut agent = create_test_agent_in_dir(&temp_dir).await;
|
||||||
let todo_path = get_todo_path(&temp_dir);
|
let todo_path = get_todo_path(&temp_dir);
|
||||||
|
|
||||||
// Create a large TODO (but under the 50k limit)
|
// Create a large TODO (but under the 50k limit)
|
||||||
@@ -202,7 +202,7 @@ async fn test_todo_handles_large_content() {
|
|||||||
#[serial]
|
#[serial]
|
||||||
async fn test_todo_respects_size_limit() {
|
async fn test_todo_respects_size_limit() {
|
||||||
let temp_dir = TempDir::new().unwrap();
|
let temp_dir = TempDir::new().unwrap();
|
||||||
let agent = create_test_agent_in_dir(&temp_dir).await;
|
let mut agent = create_test_agent_in_dir(&temp_dir).await;
|
||||||
|
|
||||||
// Create content that exceeds the default 50k limit
|
// Create content that exceeds the default 50k limit
|
||||||
let huge_content = "x".repeat(60_000);
|
let huge_content = "x".repeat(60_000);
|
||||||
@@ -232,7 +232,7 @@ async fn test_todo_agent_initialization_loads_file() {
|
|||||||
fs::write(&todo_path, initial_content).unwrap();
|
fs::write(&todo_path, initial_content).unwrap();
|
||||||
|
|
||||||
// Create agent - should load the file during initialization
|
// Create agent - should load the file during initialization
|
||||||
let agent = create_test_agent_in_dir(&temp_dir).await;
|
let mut agent = create_test_agent_in_dir(&temp_dir).await;
|
||||||
|
|
||||||
// Read TODO - should return the pre-existing content
|
// Read TODO - should return the pre-existing content
|
||||||
let tool_call = g3_core::ToolCall {
|
let tool_call = g3_core::ToolCall {
|
||||||
@@ -248,7 +248,7 @@ async fn test_todo_agent_initialization_loads_file() {
|
|||||||
#[serial]
|
#[serial]
|
||||||
async fn test_todo_handles_unicode_content() {
|
async fn test_todo_handles_unicode_content() {
|
||||||
let temp_dir = TempDir::new().unwrap();
|
let temp_dir = TempDir::new().unwrap();
|
||||||
let agent = create_test_agent_in_dir(&temp_dir).await;
|
let mut agent = create_test_agent_in_dir(&temp_dir).await;
|
||||||
let todo_path = get_todo_path(&temp_dir);
|
let todo_path = get_todo_path(&temp_dir);
|
||||||
|
|
||||||
// Create TODO with unicode characters
|
// Create TODO with unicode characters
|
||||||
@@ -283,7 +283,7 @@ async fn test_todo_handles_unicode_content() {
|
|||||||
#[serial]
|
#[serial]
|
||||||
async fn test_todo_empty_content_creates_empty_file() {
|
async fn test_todo_empty_content_creates_empty_file() {
|
||||||
let temp_dir = TempDir::new().unwrap();
|
let temp_dir = TempDir::new().unwrap();
|
||||||
let agent = create_test_agent_in_dir(&temp_dir).await;
|
let mut agent = create_test_agent_in_dir(&temp_dir).await;
|
||||||
let todo_path = get_todo_path(&temp_dir);
|
let todo_path = get_todo_path(&temp_dir);
|
||||||
|
|
||||||
// Write empty TODO
|
// Write empty TODO
|
||||||
@@ -306,7 +306,7 @@ async fn test_todo_empty_content_creates_empty_file() {
|
|||||||
#[serial]
|
#[serial]
|
||||||
async fn test_todo_whitespace_only_content() {
|
async fn test_todo_whitespace_only_content() {
|
||||||
let temp_dir = TempDir::new().unwrap();
|
let temp_dir = TempDir::new().unwrap();
|
||||||
let agent = create_test_agent_in_dir(&temp_dir).await;
|
let mut agent = create_test_agent_in_dir(&temp_dir).await;
|
||||||
|
|
||||||
// Write whitespace-only TODO
|
// Write whitespace-only TODO
|
||||||
let tool_call = g3_core::ToolCall {
|
let tool_call = g3_core::ToolCall {
|
||||||
|
|||||||
@@ -21,22 +21,18 @@
|
|||||||
//! // Create the provider with your API key
|
//! // Create the provider with your API key
|
||||||
//! let provider = AnthropicProvider::new(
|
//! let provider = AnthropicProvider::new(
|
||||||
//! "your-api-key".to_string(),
|
//! "your-api-key".to_string(),
|
||||||
//! Some("claude-3-5-sonnet-20241022".to_string()), // Optional: defaults to claude-3-5-sonnet-20241022
|
//! Some("claude-3-5-sonnet-20241022".to_string()),
|
||||||
//! Some(4096), // Optional: max tokens
|
//! Some(4096),
|
||||||
//! Some(0.1), // Optional: temperature
|
//! Some(0.1),
|
||||||
|
//! None, // cache_config
|
||||||
|
//! None, // enable_1m_context
|
||||||
//! )?;
|
//! )?;
|
||||||
//!
|
//!
|
||||||
//! // Create a completion request
|
//! // Create a completion request
|
||||||
//! let request = CompletionRequest {
|
//! let request = CompletionRequest {
|
||||||
//! messages: vec![
|
//! messages: vec![
|
||||||
//! Message {
|
//! Message::new(MessageRole::System, "You are a helpful assistant.".to_string()),
|
||||||
//! role: MessageRole::System,
|
//! Message::new(MessageRole::User, "Hello! How are you?".to_string()),
|
||||||
//! content: "You are a helpful assistant.".to_string(),
|
|
||||||
//! },
|
|
||||||
//! Message {
|
|
||||||
//! role: MessageRole::User,
|
|
||||||
//! content: "Hello! How are you?".to_string(),
|
|
||||||
//! },
|
|
||||||
//! ],
|
//! ],
|
||||||
//! max_tokens: Some(1000),
|
//! max_tokens: Some(1000),
|
||||||
//! temperature: Some(0.7),
|
//! temperature: Some(0.7),
|
||||||
@@ -62,15 +58,16 @@
|
|||||||
//! async fn main() -> anyhow::Result<()> {
|
//! async fn main() -> anyhow::Result<()> {
|
||||||
//! let provider = AnthropicProvider::new(
|
//! let provider = AnthropicProvider::new(
|
||||||
//! "your-api-key".to_string(),
|
//! "your-api-key".to_string(),
|
||||||
//! None, None, None,
|
//! None,
|
||||||
|
//! None,
|
||||||
|
//! None,
|
||||||
|
//! None, // cache_config
|
||||||
|
//! None, // enable_1m_context
|
||||||
//! )?;
|
//! )?;
|
||||||
//!
|
//!
|
||||||
//! let request = CompletionRequest {
|
//! let request = CompletionRequest {
|
||||||
//! messages: vec![
|
//! messages: vec![
|
||||||
//! Message {
|
//! Message::new(MessageRole::User, "Write a short story about a robot.".to_string()),
|
||||||
//! role: MessageRole::User,
|
|
||||||
//! content: "Write a short story about a robot.".to_string(),
|
|
||||||
//! },
|
|
||||||
//! ],
|
//! ],
|
||||||
//! max_tokens: Some(1000),
|
//! max_tokens: Some(1000),
|
||||||
//! temperature: Some(0.7),
|
//! temperature: Some(0.7),
|
||||||
@@ -123,6 +120,8 @@ pub struct AnthropicProvider {
|
|||||||
model: String,
|
model: String,
|
||||||
max_tokens: u32,
|
max_tokens: u32,
|
||||||
temperature: f32,
|
temperature: f32,
|
||||||
|
cache_config: Option<String>,
|
||||||
|
enable_1m_context: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AnthropicProvider {
|
impl AnthropicProvider {
|
||||||
@@ -131,6 +130,8 @@ impl AnthropicProvider {
|
|||||||
model: Option<String>,
|
model: Option<String>,
|
||||||
max_tokens: Option<u32>,
|
max_tokens: Option<u32>,
|
||||||
temperature: Option<f32>,
|
temperature: Option<f32>,
|
||||||
|
cache_config: Option<String>,
|
||||||
|
enable_1m_context: Option<bool>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let client = Client::builder()
|
let client = Client::builder()
|
||||||
.timeout(Duration::from_secs(300))
|
.timeout(Duration::from_secs(300))
|
||||||
@@ -147,6 +148,8 @@ impl AnthropicProvider {
|
|||||||
model,
|
model,
|
||||||
max_tokens: max_tokens.unwrap_or(4096),
|
max_tokens: max_tokens.unwrap_or(4096),
|
||||||
temperature: temperature.unwrap_or(0.1),
|
temperature: temperature.unwrap_or(0.1),
|
||||||
|
cache_config,
|
||||||
|
enable_1m_context: enable_1m_context.unwrap_or(false),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,9 +159,12 @@ impl AnthropicProvider {
|
|||||||
.post(ANTHROPIC_API_URL)
|
.post(ANTHROPIC_API_URL)
|
||||||
.header("x-api-key", &self.api_key)
|
.header("x-api-key", &self.api_key)
|
||||||
.header("anthropic-version", ANTHROPIC_VERSION)
|
.header("anthropic-version", ANTHROPIC_VERSION)
|
||||||
// Anthropic beta 1m context window. Enable if needed. It costs extra, so check first.
|
|
||||||
// .header("anthropic-beta", "context-1m-2025-08-07")
|
|
||||||
.header("content-type", "application/json");
|
.header("content-type", "application/json");
|
||||||
|
|
||||||
|
if self.enable_1m_context {
|
||||||
|
builder = builder.header("anthropic-beta", "context-1m-2025-08-07");
|
||||||
|
}
|
||||||
|
|
||||||
if streaming {
|
if streaming {
|
||||||
builder = builder.header("accept", "text/event-stream");
|
builder = builder.header("accept", "text/event-stream");
|
||||||
}
|
}
|
||||||
@@ -166,6 +172,11 @@ impl AnthropicProvider {
|
|||||||
builder
|
builder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn convert_cache_control(cache_control: &crate::CacheControl) -> crate::CacheControl {
|
||||||
|
// Anthropic uses the same format, so just clone it
|
||||||
|
cache_control.clone()
|
||||||
|
}
|
||||||
|
|
||||||
fn convert_tools(&self, tools: &[Tool]) -> Vec<AnthropicTool> {
|
fn convert_tools(&self, tools: &[Tool]) -> Vec<AnthropicTool> {
|
||||||
tools
|
tools
|
||||||
.iter()
|
.iter()
|
||||||
@@ -214,6 +225,8 @@ impl AnthropicProvider {
|
|||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: vec![AnthropicContent::Text {
|
content: vec![AnthropicContent::Text {
|
||||||
text: message.content.clone(),
|
text: message.content.clone(),
|
||||||
|
cache_control: message.cache_control.as_ref()
|
||||||
|
.map(Self::convert_cache_control),
|
||||||
}],
|
}],
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -222,6 +235,8 @@ impl AnthropicProvider {
|
|||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: vec![AnthropicContent::Text {
|
content: vec![AnthropicContent::Text {
|
||||||
text: message.content.clone(),
|
text: message.content.clone(),
|
||||||
|
cache_control: message.cache_control.as_ref()
|
||||||
|
.map(Self::convert_cache_control),
|
||||||
}],
|
}],
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -564,7 +579,7 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
.content
|
.content
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|c| match c {
|
.filter_map(|c| match c {
|
||||||
AnthropicContent::Text { text } => Some(text.as_str()),
|
AnthropicContent::Text { text, .. } => Some(text.as_str()),
|
||||||
_ => None,
|
_ => None,
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
@@ -658,6 +673,11 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
// Claude models support native tool calling
|
// Claude models support native tool calling
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn supports_cache_control(&self) -> bool {
|
||||||
|
// Anthropic supports cache control
|
||||||
|
true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Anthropic API request/response structures
|
// Anthropic API request/response structures
|
||||||
@@ -701,7 +721,11 @@ struct AnthropicMessage {
|
|||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
enum AnthropicContent {
|
enum AnthropicContent {
|
||||||
#[serde(rename = "text")]
|
#[serde(rename = "text")]
|
||||||
Text { text: String },
|
Text {
|
||||||
|
text: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
cache_control: Option<crate::CacheControl>,
|
||||||
|
},
|
||||||
#[serde(rename = "tool_use")]
|
#[serde(rename = "tool_use")]
|
||||||
ToolUse {
|
ToolUse {
|
||||||
id: String,
|
id: String,
|
||||||
@@ -771,21 +795,14 @@ mod tests {
|
|||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
).unwrap();
|
).unwrap();
|
||||||
|
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
Message {
|
Message::new(MessageRole::System, "You are a helpful assistant.".to_string()),
|
||||||
role: MessageRole::System,
|
Message::new(MessageRole::User, "Hello!".to_string()),
|
||||||
content: "You are a helpful assistant.".to_string(),
|
Message::new(MessageRole::Assistant, "Hi there!".to_string()),
|
||||||
},
|
|
||||||
Message {
|
|
||||||
role: MessageRole::User,
|
|
||||||
content: "Hello!".to_string(),
|
|
||||||
},
|
|
||||||
Message {
|
|
||||||
role: MessageRole::Assistant,
|
|
||||||
content: "Hi there!".to_string(),
|
|
||||||
},
|
|
||||||
];
|
];
|
||||||
|
|
||||||
let (system, anthropic_messages) = provider.convert_messages(&messages).unwrap();
|
let (system, anthropic_messages) = provider.convert_messages(&messages).unwrap();
|
||||||
@@ -803,14 +820,11 @@ mod tests {
|
|||||||
Some("claude-3-haiku-20240307".to_string()),
|
Some("claude-3-haiku-20240307".to_string()),
|
||||||
Some(1000),
|
Some(1000),
|
||||||
Some(0.5),
|
Some(0.5),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
).unwrap();
|
).unwrap();
|
||||||
|
|
||||||
let messages = vec![
|
let messages = vec![Message::new(MessageRole::User, "Test message".to_string())];
|
||||||
Message {
|
|
||||||
role: MessageRole::User,
|
|
||||||
content: "Test message".to_string(),
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
let request_body = provider
|
let request_body = provider
|
||||||
.create_request_body(&messages, None, false, 1000, 0.5)
|
.create_request_body(&messages, None, false, 1000, 0.5)
|
||||||
@@ -831,6 +845,8 @@ mod tests {
|
|||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
).unwrap();
|
).unwrap();
|
||||||
|
|
||||||
let tools = vec![
|
let tools = vec![
|
||||||
@@ -859,4 +875,48 @@ mod tests {
|
|||||||
assert!(anthropic_tools[0].input_schema.required.is_some());
|
assert!(anthropic_tools[0].input_schema.required.is_some());
|
||||||
assert_eq!(anthropic_tools[0].input_schema.required.as_ref().unwrap()[0], "location");
|
assert_eq!(anthropic_tools[0].input_schema.required.as_ref().unwrap()[0], "location");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cache_control_serialization() {
|
||||||
|
let provider = AnthropicProvider::new(
|
||||||
|
"test-key".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
).unwrap();
|
||||||
|
|
||||||
|
// Test message WITHOUT cache_control
|
||||||
|
let messages_without = vec![Message::new(MessageRole::User, "Hello".to_string())];
|
||||||
|
let (_, anthropic_messages_without) = provider.convert_messages(&messages_without).unwrap();
|
||||||
|
let json_without = serde_json::to_string(&anthropic_messages_without).unwrap();
|
||||||
|
|
||||||
|
println!("Anthropic JSON without cache_control: {}", json_without);
|
||||||
|
// Check if cache_control appears in the JSON
|
||||||
|
if json_without.contains("cache_control") {
|
||||||
|
println!("WARNING: JSON contains 'cache_control' field when not configured!");
|
||||||
|
assert!(!json_without.contains("\"cache_control\":null"),
|
||||||
|
"JSON should not contain 'cache_control: null'");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test message WITH cache_control
|
||||||
|
let messages_with = vec![Message::with_cache_control(
|
||||||
|
MessageRole::User,
|
||||||
|
"Hello".to_string(),
|
||||||
|
crate::CacheControl::ephemeral(),
|
||||||
|
)];
|
||||||
|
let (_, anthropic_messages_with) = provider.convert_messages(&messages_with).unwrap();
|
||||||
|
let json_with = serde_json::to_string(&anthropic_messages_with).unwrap();
|
||||||
|
|
||||||
|
println!("Anthropic JSON with cache_control: {}", json_with);
|
||||||
|
assert!(json_with.contains("cache_control"),
|
||||||
|
"JSON should contain 'cache_control' field when configured");
|
||||||
|
assert!(json_with.contains("ephemeral"),
|
||||||
|
"JSON should contain 'ephemeral' type");
|
||||||
|
|
||||||
|
// The key assertion: when cache_control is None, it should not appear in JSON
|
||||||
|
assert!(!json_without.contains("cache_control") || !json_without.contains("null"),
|
||||||
|
"JSON should not contain 'cache_control' field or null values when not configured");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,10 +39,7 @@
|
|||||||
//! // Create a completion request
|
//! // Create a completion request
|
||||||
//! let request = CompletionRequest {
|
//! let request = CompletionRequest {
|
||||||
//! messages: vec![
|
//! messages: vec![
|
||||||
//! Message {
|
//! Message::new(MessageRole::User, "Hello! How are you?".to_string()),
|
||||||
//! role: MessageRole::User,
|
|
||||||
//! content: "Hello! How are you?".to_string(),
|
|
||||||
//! },
|
|
||||||
//! ],
|
//! ],
|
||||||
//! max_tokens: Some(1000),
|
//! max_tokens: Some(1000),
|
||||||
//! temperature: Some(0.7),
|
//! temperature: Some(0.7),
|
||||||
@@ -251,9 +248,12 @@ impl DatabricksProvider {
|
|||||||
MessageRole::Assistant => "assistant",
|
MessageRole::Assistant => "assistant",
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Always use simple string format (Databricks doesn't support cache_control)
|
||||||
|
let content = serde_json::Value::String(message.content.clone());
|
||||||
|
|
||||||
databricks_messages.push(DatabricksMessage {
|
databricks_messages.push(DatabricksMessage {
|
||||||
role: role.to_string(),
|
role: role.to_string(),
|
||||||
content: Some(message.content.clone()),
|
content: Some(content),
|
||||||
tool_calls: None, // Only used in responses, not requests
|
tool_calls: None, // Only used in responses, not requests
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -864,8 +864,22 @@ impl LLMProvider for DatabricksProvider {
|
|||||||
let content = databricks_response
|
let content = databricks_response
|
||||||
.choices
|
.choices
|
||||||
.first()
|
.first()
|
||||||
.and_then(|choice| choice.message.content.as_ref())
|
.and_then(|choice| {
|
||||||
.cloned()
|
choice.message.content.as_ref().map(|c| {
|
||||||
|
// Handle both string and array formats
|
||||||
|
if let Some(s) = c.as_str() {
|
||||||
|
s.to_string()
|
||||||
|
} else if let Some(arr) = c.as_array() {
|
||||||
|
// Extract text from content blocks
|
||||||
|
arr.iter()
|
||||||
|
.filter_map(|block| block.get("text").and_then(|t| t.as_str()))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("")
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
// Check if there are tool calls in the response
|
// Check if there are tool calls in the response
|
||||||
@@ -1037,6 +1051,10 @@ impl LLMProvider for DatabricksProvider {
|
|||||||
// This includes Claude, Llama, DBRX, and most other models on the platform
|
// This includes Claude, Llama, DBRX, and most other models on the platform
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn supports_cache_control(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Databricks API request/response structures
|
// Databricks API request/response structures
|
||||||
@@ -1067,7 +1085,8 @@ struct DatabricksFunction {
|
|||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct DatabricksMessage {
|
struct DatabricksMessage {
|
||||||
role: String,
|
role: String,
|
||||||
content: Option<String>, // Make content optional since tool calls might not have content
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
content: Option<serde_json::Value>, // Can be string or array of content blocks
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
tool_calls: Option<Vec<DatabricksToolCall>>, // Add tool_calls field for responses
|
tool_calls: Option<Vec<DatabricksToolCall>>, // Add tool_calls field for responses
|
||||||
}
|
}
|
||||||
@@ -1154,18 +1173,9 @@ mod tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
Message {
|
Message::new(MessageRole::System, "You are a helpful assistant.".to_string()),
|
||||||
role: MessageRole::System,
|
Message::new(MessageRole::User, "Hello!".to_string()),
|
||||||
content: "You are a helpful assistant.".to_string(),
|
Message::new(MessageRole::Assistant, "Hi there!".to_string()),
|
||||||
},
|
|
||||||
Message {
|
|
||||||
role: MessageRole::User,
|
|
||||||
content: "Hello!".to_string(),
|
|
||||||
},
|
|
||||||
Message {
|
|
||||||
role: MessageRole::Assistant,
|
|
||||||
content: "Hi there!".to_string(),
|
|
||||||
},
|
|
||||||
];
|
];
|
||||||
|
|
||||||
let databricks_messages = provider.convert_messages(&messages).unwrap();
|
let databricks_messages = provider.convert_messages(&messages).unwrap();
|
||||||
@@ -1187,10 +1197,7 @@ mod tests {
|
|||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let messages = vec![Message {
|
let messages = vec![Message::new(MessageRole::User, "Test message".to_string())];
|
||||||
role: MessageRole::User,
|
|
||||||
content: "Test message".to_string(),
|
|
||||||
}];
|
|
||||||
|
|
||||||
let request_body = provider
|
let request_body = provider
|
||||||
.create_request_body(&messages, None, false, 1000, 0.5)
|
.create_request_body(&messages, None, false, 1000, 0.5)
|
||||||
@@ -1273,4 +1280,62 @@ mod tests {
|
|||||||
assert!(llama_provider.has_native_tool_calling());
|
assert!(llama_provider.has_native_tool_calling());
|
||||||
assert!(dbrx_provider.has_native_tool_calling());
|
assert!(dbrx_provider.has_native_tool_calling());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cache_control_serialization() {
|
||||||
|
let provider = DatabricksProvider::from_token(
|
||||||
|
"https://test.databricks.com".to_string(),
|
||||||
|
"test-token".to_string(),
|
||||||
|
"databricks-claude-sonnet-4".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Test message WITHOUT cache_control
|
||||||
|
let messages_without = vec![Message::new(MessageRole::User, "Hello".to_string())];
|
||||||
|
let databricks_messages_without = provider.convert_messages(&messages_without).unwrap();
|
||||||
|
let json_without = serde_json::to_string(&databricks_messages_without).unwrap();
|
||||||
|
|
||||||
|
println!("JSON without cache_control: {}", json_without);
|
||||||
|
assert!(!json_without.contains("cache_control"),
|
||||||
|
"JSON should not contain 'cache_control' field when not configured");
|
||||||
|
|
||||||
|
// Test message WITH cache_control - should still NOT include it (Databricks doesn't support it)
|
||||||
|
let messages_with = vec![Message::with_cache_control(
|
||||||
|
MessageRole::User,
|
||||||
|
"Hello".to_string(),
|
||||||
|
crate::CacheControl::ephemeral(),
|
||||||
|
)];
|
||||||
|
let databricks_messages_with = provider.convert_messages(&messages_with).unwrap();
|
||||||
|
let json_with = serde_json::to_string(&databricks_messages_with).unwrap();
|
||||||
|
|
||||||
|
println!("JSON with cache_control: {}", json_with);
|
||||||
|
assert!(!json_with.contains("cache_control"),
|
||||||
|
"JSON should NOT contain 'cache_control' field - Databricks doesn't support it");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_databricks_does_not_support_cache_control() {
|
||||||
|
let claude_provider = DatabricksProvider::from_token(
|
||||||
|
"https://test.databricks.com".to_string(),
|
||||||
|
"test-token".to_string(),
|
||||||
|
"databricks-claude-sonnet-4".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let llama_provider = DatabricksProvider::from_token(
|
||||||
|
"https://test.databricks.com".to_string(),
|
||||||
|
"test-token".to_string(),
|
||||||
|
"databricks-meta-llama-3-3-70b-instruct".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(!claude_provider.supports_cache_control(), "Databricks should not support cache_control even for Claude models");
|
||||||
|
assert!(!llama_provider.supports_cache_control(), "Databricks should not support cache_control for Llama models");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,6 +21,11 @@ pub trait LLMProvider: Send + Sync {
|
|||||||
fn has_native_tool_calling(&self) -> bool {
|
fn has_native_tool_calling(&self) -> bool {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Check if the provider supports cache control
|
||||||
|
fn supports_cache_control(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -32,10 +37,40 @@ pub struct CompletionRequest {
|
|||||||
pub tools: Option<Vec<Tool>>,
|
pub tools: Option<Vec<Tool>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct CacheControl {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub cache_type: CacheType,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub ttl: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum CacheType {
|
||||||
|
Ephemeral,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CacheControl {
|
||||||
|
pub fn ephemeral() -> Self {
|
||||||
|
Self { cache_type: CacheType::Ephemeral, ttl: None }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn five_minute() -> Self {
|
||||||
|
Self { cache_type: CacheType::Ephemeral, ttl: Some("5m".to_string()) }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn one_hour() -> Self {
|
||||||
|
Self { cache_type: CacheType::Ephemeral, ttl: Some("1h".to_string()) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct Message {
|
pub struct Message {
|
||||||
pub role: MessageRole,
|
pub role: MessageRole,
|
||||||
pub content: String,
|
pub content: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub cache_control: Option<CacheControl>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -95,6 +130,45 @@ pub use databricks::DatabricksProvider;
|
|||||||
pub use embedded::EmbeddedProvider;
|
pub use embedded::EmbeddedProvider;
|
||||||
pub use openai::OpenAIProvider;
|
pub use openai::OpenAIProvider;
|
||||||
|
|
||||||
|
impl Message {
|
||||||
|
/// Create a new message without cache control (`cache_control` is `None`)
|
||||||
|
pub fn new(role: MessageRole, content: String) -> Self {
|
||||||
|
Self {
|
||||||
|
role,
|
||||||
|
content,
|
||||||
|
cache_control: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new message with cache control
|
||||||
|
pub fn with_cache_control(role: MessageRole, content: String, cache_control: CacheControl) -> Self {
|
||||||
|
Self {
|
||||||
|
role,
|
||||||
|
content,
|
||||||
|
cache_control: Some(cache_control),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a message with cache control if the provider supports it; otherwise log a warning and return a plain message without cache control
|
||||||
|
pub fn with_cache_control_validated(
|
||||||
|
role: MessageRole,
|
||||||
|
content: String,
|
||||||
|
cache_control: CacheControl,
|
||||||
|
provider: &dyn LLMProvider
|
||||||
|
) -> Self {
|
||||||
|
if !provider.supports_cache_control() {
|
||||||
|
tracing::warn!(
|
||||||
|
"Cache control requested for provider '{}' which does not support it. \
|
||||||
|
Cache control is only supported by Anthropic and Anthropic via Databricks.",
|
||||||
|
provider.name()
|
||||||
|
);
|
||||||
|
return Self::new(role, content);
|
||||||
|
}
|
||||||
|
|
||||||
|
Self::with_cache_control(role, content, cache_control)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Provider registry for managing multiple LLM providers
|
/// Provider registry for managing multiple LLM providers
|
||||||
pub struct ProviderRegistry {
|
pub struct ProviderRegistry {
|
||||||
providers: HashMap<String, Box<dyn LLMProvider>>,
|
providers: HashMap<String, Box<dyn LLMProvider>>,
|
||||||
@@ -144,3 +218,68 @@ impl Default for ProviderRegistry {
|
|||||||
Self::new()
|
Self::new()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_message_serialization_without_cache_control() {
|
||||||
|
let msg = Message::new(MessageRole::User, "Hello".to_string());
|
||||||
|
let json = serde_json::to_string(&msg).unwrap();
|
||||||
|
|
||||||
|
println!("Message JSON without cache_control: {}", json);
|
||||||
|
assert!(!json.contains("cache_control"),
|
||||||
|
"JSON should not contain 'cache_control' field when not configured");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_message_serialization_with_cache_control() {
|
||||||
|
let msg = Message::with_cache_control(
|
||||||
|
MessageRole::User,
|
||||||
|
"Hello".to_string(),
|
||||||
|
CacheControl::ephemeral(),
|
||||||
|
);
|
||||||
|
let json = serde_json::to_string(&msg).unwrap();
|
||||||
|
|
||||||
|
println!("Message JSON with cache_control: {}", json);
|
||||||
|
assert!(json.contains("cache_control"),
|
||||||
|
"JSON should contain 'cache_control' field when configured");
|
||||||
|
assert!(json.contains("ephemeral"),
|
||||||
|
"JSON should contain 'ephemeral' value");
|
||||||
|
assert!(json.contains("\"type\":"),
|
||||||
|
"JSON should contain 'type' field in cache_control");
|
||||||
|
assert!(!json.contains("null"),
|
||||||
|
"JSON should not contain null values");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cache_control_five_minute_serialization() {
|
||||||
|
let msg = Message::with_cache_control(
|
||||||
|
MessageRole::User,
|
||||||
|
"Hello".to_string(),
|
||||||
|
CacheControl::five_minute(),
|
||||||
|
);
|
||||||
|
let json = serde_json::to_string(&msg).unwrap();
|
||||||
|
|
||||||
|
println!("Message JSON with 5-minute cache_control: {}", json);
|
||||||
|
assert!(json.contains("cache_control"), "JSON should contain 'cache_control' field");
|
||||||
|
assert!(json.contains("ephemeral"), "JSON should contain 'ephemeral' type");
|
||||||
|
assert!(json.contains("\"ttl\":\"5m\""), "JSON should contain ttl field with 5m value");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cache_control_one_hour_serialization() {
|
||||||
|
let msg = Message::with_cache_control(
|
||||||
|
MessageRole::User,
|
||||||
|
"Hello".to_string(),
|
||||||
|
CacheControl::one_hour(),
|
||||||
|
);
|
||||||
|
let json = serde_json::to_string(&msg).unwrap();
|
||||||
|
|
||||||
|
println!("Message JSON with 1-hour cache_control: {}", json);
|
||||||
|
assert!(json.contains("cache_control"), "JSON should contain 'cache_control' field");
|
||||||
|
assert!(json.contains("ephemeral"), "JSON should contain 'ephemeral' type");
|
||||||
|
assert!(json.contains("\"ttl\":\"1h\""), "JSON should contain ttl field with 1h value");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
131
crates/g3-providers/tests/cache_control_error_regression_test.rs
Normal file
131
crates/g3-providers/tests/cache_control_error_regression_test.rs
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
//! Regression test for cache_control serialization bug
|
||||||
|
//!
|
||||||
|
//! This test verifies that cache_control is NOT serialized in the wrong format.
|
||||||
|
//! The bug was that it serialized as:
|
||||||
|
//! - `system.0.cache_control.ephemeral.ttl` (WRONG)
|
||||||
|
//!
|
||||||
|
//! It should serialize as:
|
||||||
|
//! - `"cache_control": {"type": "ephemeral"}` for ephemeral
|
||||||
|
//! - `"cache_control": {"type": "ephemeral", "ttl": "5m"}` for 5minute
|
||||||
|
//! - `"cache_control": {"type": "ephemeral", "ttl": "1h"}` for 1hour
|
||||||
|
|
||||||
|
use g3_providers::{CacheControl, Message, MessageRole};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_no_wrong_serialization_format() {
|
||||||
|
// Test ephemeral
|
||||||
|
let msg = Message::with_cache_control(
|
||||||
|
MessageRole::System,
|
||||||
|
"Test".to_string(),
|
||||||
|
CacheControl::ephemeral(),
|
||||||
|
);
|
||||||
|
let json = serde_json::to_string(&msg).unwrap();
|
||||||
|
|
||||||
|
println!("Ephemeral message JSON: {}", json);
|
||||||
|
|
||||||
|
// Should NOT contain the wrong format
|
||||||
|
assert!(!json.contains("system.0.cache_control"),
|
||||||
|
"JSON should not contain 'system.0.cache_control' path");
|
||||||
|
assert!(!json.contains("cache_control.ephemeral"),
|
||||||
|
"JSON should not contain 'cache_control.ephemeral' path");
|
||||||
|
|
||||||
|
// Should contain the correct format
|
||||||
|
assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#),
|
||||||
|
"JSON should contain correct cache_control format");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_five_minute_no_wrong_format() {
|
||||||
|
let msg = Message::with_cache_control(
|
||||||
|
MessageRole::System,
|
||||||
|
"Test".to_string(),
|
||||||
|
CacheControl::five_minute(),
|
||||||
|
);
|
||||||
|
let json = serde_json::to_string(&msg).unwrap();
|
||||||
|
|
||||||
|
println!("5-minute message JSON: {}", json);
|
||||||
|
|
||||||
|
// Should NOT contain the wrong format
|
||||||
|
assert!(!json.contains("system.0.cache_control"),
|
||||||
|
"JSON should not contain 'system.0.cache_control' path");
|
||||||
|
assert!(!json.contains("cache_control.ephemeral.ttl"),
|
||||||
|
"JSON should not contain 'cache_control.ephemeral.ttl' path");
|
||||||
|
|
||||||
|
// Should contain the correct format with ttl as a direct field
|
||||||
|
assert!(json.contains(r#""type":"ephemeral""#),
|
||||||
|
"JSON should contain type field");
|
||||||
|
assert!(json.contains(r#""ttl":"5m""#),
|
||||||
|
"JSON should contain ttl field with value 5m");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_one_hour_no_wrong_format() {
|
||||||
|
let msg = Message::with_cache_control(
|
||||||
|
MessageRole::System,
|
||||||
|
"Test".to_string(),
|
||||||
|
CacheControl::one_hour(),
|
||||||
|
);
|
||||||
|
let json = serde_json::to_string(&msg).unwrap();
|
||||||
|
|
||||||
|
println!("1-hour message JSON: {}", json);
|
||||||
|
|
||||||
|
// Should NOT contain the wrong format
|
||||||
|
assert!(!json.contains("system.0.cache_control"),
|
||||||
|
"JSON should not contain 'system.0.cache_control' path");
|
||||||
|
assert!(!json.contains("cache_control.ephemeral.ttl"),
|
||||||
|
"JSON should not contain 'cache_control.ephemeral.ttl' path");
|
||||||
|
|
||||||
|
// Should contain the correct format with ttl as a direct field
|
||||||
|
assert!(json.contains(r#""type":"ephemeral""#),
|
||||||
|
"JSON should contain type field");
|
||||||
|
assert!(json.contains(r#""ttl":"1h""#),
|
||||||
|
"JSON should contain ttl field with value 1h");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cache_control_structure_is_flat() {
|
||||||
|
// Verify that the cache_control object has a flat structure
|
||||||
|
// with 'type' and optional 'ttl' at the same level
|
||||||
|
|
||||||
|
let cache_control = CacheControl::five_minute();
|
||||||
|
let json_value = serde_json::to_value(&cache_control).unwrap();
|
||||||
|
|
||||||
|
println!("Cache control as JSON value: {}", serde_json::to_string_pretty(&json_value).unwrap());
|
||||||
|
|
||||||
|
let obj = json_value.as_object().expect("Should be an object");
|
||||||
|
|
||||||
|
// Should have exactly 2 keys at the top level
|
||||||
|
assert_eq!(obj.len(), 2, "Cache control should have exactly 2 top-level fields");
|
||||||
|
|
||||||
|
// Both 'type' and 'ttl' should be at the same level
|
||||||
|
assert!(obj.contains_key("type"), "Should have 'type' field");
|
||||||
|
assert!(obj.contains_key("ttl"), "Should have 'ttl' field");
|
||||||
|
|
||||||
|
// 'type' should be a string, not an object
|
||||||
|
assert!(obj["type"].is_string(), "'type' should be a string value");
|
||||||
|
|
||||||
|
// 'ttl' should be a string, not nested
|
||||||
|
assert!(obj["ttl"].is_string(), "'ttl' should be a string value");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ephemeral_cache_control_structure() {
|
||||||
|
let cache_control = CacheControl::ephemeral();
|
||||||
|
let json_value = serde_json::to_value(&cache_control).unwrap();
|
||||||
|
|
||||||
|
println!("Ephemeral cache control as JSON value: {}", serde_json::to_string_pretty(&json_value).unwrap());
|
||||||
|
|
||||||
|
let obj = json_value.as_object().expect("Should be an object");
|
||||||
|
|
||||||
|
// Should have exactly 1 key (only 'type', no 'ttl')
|
||||||
|
assert_eq!(obj.len(), 1, "Ephemeral cache control should have exactly 1 top-level field");
|
||||||
|
|
||||||
|
// Should have 'type' field
|
||||||
|
assert!(obj.contains_key("type"), "Should have 'type' field");
|
||||||
|
|
||||||
|
// Should NOT have 'ttl' field
|
||||||
|
assert!(!obj.contains_key("ttl"), "Ephemeral should not have 'ttl' field");
|
||||||
|
|
||||||
|
// 'type' should be a string with value "ephemeral"
|
||||||
|
assert_eq!(obj["type"].as_str().unwrap(), "ephemeral");
|
||||||
|
}
|
||||||
164
crates/g3-providers/tests/cache_control_integration_test.rs
Normal file
164
crates/g3-providers/tests/cache_control_integration_test.rs
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
//! Integration tests for cache_control feature
|
||||||
|
//!
|
||||||
|
//! These tests verify that cache_control is correctly serialized in messages
|
||||||
|
//! for both Anthropic and Databricks providers.
|
||||||
|
|
||||||
|
use g3_providers::{CacheControl, Message, MessageRole};
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ephemeral_cache_control_serialization() {
|
||||||
|
let cache_control = CacheControl::ephemeral();
|
||||||
|
let json = serde_json::to_value(&cache_control).unwrap();
|
||||||
|
|
||||||
|
println!("Ephemeral cache_control JSON: {}", serde_json::to_string(&json).unwrap());
|
||||||
|
|
||||||
|
assert_eq!(json, json!({
|
||||||
|
"type": "ephemeral"
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Verify no ttl field is present
|
||||||
|
assert!(!json.as_object().unwrap().contains_key("ttl"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_five_minute_cache_control_serialization() {
|
||||||
|
let cache_control = CacheControl::five_minute();
|
||||||
|
let json = serde_json::to_value(&cache_control).unwrap();
|
||||||
|
|
||||||
|
println!("5-minute cache_control JSON: {}", serde_json::to_string(&json).unwrap());
|
||||||
|
|
||||||
|
assert_eq!(json, json!({
|
||||||
|
"type": "ephemeral",
|
||||||
|
"ttl": "5m"
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_one_hour_cache_control_serialization() {
|
||||||
|
let cache_control = CacheControl::one_hour();
|
||||||
|
let json = serde_json::to_value(&cache_control).unwrap();
|
||||||
|
|
||||||
|
println!("1-hour cache_control JSON: {}", serde_json::to_string(&json).unwrap());
|
||||||
|
|
||||||
|
assert_eq!(json, json!({
|
||||||
|
"type": "ephemeral",
|
||||||
|
"ttl": "1h"
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_message_with_ephemeral_cache_control() {
|
||||||
|
let msg = Message::with_cache_control(
|
||||||
|
MessageRole::System,
|
||||||
|
"System prompt".to_string(),
|
||||||
|
CacheControl::ephemeral(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let json = serde_json::to_value(&msg).unwrap();
|
||||||
|
println!("Message with ephemeral cache_control: {}", serde_json::to_string(&json).unwrap());
|
||||||
|
|
||||||
|
let cache_control = json.get("cache_control").expect("cache_control field should exist");
|
||||||
|
assert_eq!(cache_control.get("type").unwrap(), "ephemeral");
|
||||||
|
assert!(!cache_control.as_object().unwrap().contains_key("ttl"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_message_with_five_minute_cache_control() {
|
||||||
|
let msg = Message::with_cache_control(
|
||||||
|
MessageRole::System,
|
||||||
|
"System prompt".to_string(),
|
||||||
|
CacheControl::five_minute(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let json = serde_json::to_value(&msg).unwrap();
|
||||||
|
println!("Message with 5-minute cache_control: {}", serde_json::to_string(&json).unwrap());
|
||||||
|
|
||||||
|
let cache_control = json.get("cache_control").expect("cache_control field should exist");
|
||||||
|
assert_eq!(cache_control.get("type").unwrap(), "ephemeral");
|
||||||
|
assert_eq!(cache_control.get("ttl").unwrap(), "5m");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_message_with_one_hour_cache_control() {
|
||||||
|
let msg = Message::with_cache_control(
|
||||||
|
MessageRole::System,
|
||||||
|
"System prompt".to_string(),
|
||||||
|
CacheControl::one_hour(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let json = serde_json::to_value(&msg).unwrap();
|
||||||
|
println!("Message with 1-hour cache_control: {}", serde_json::to_string(&json).unwrap());
|
||||||
|
|
||||||
|
let cache_control = json.get("cache_control").expect("cache_control field should exist");
|
||||||
|
assert_eq!(cache_control.get("type").unwrap(), "ephemeral");
|
||||||
|
assert_eq!(cache_control.get("ttl").unwrap(), "1h");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_message_without_cache_control() {
|
||||||
|
let msg = Message::new(MessageRole::User, "Hello".to_string());
|
||||||
|
|
||||||
|
let json = serde_json::to_value(&msg).unwrap();
|
||||||
|
println!("Message without cache_control: {}", serde_json::to_string(&json).unwrap());
|
||||||
|
|
||||||
|
// cache_control field should not be present when not set
|
||||||
|
assert!(!json.as_object().unwrap().contains_key("cache_control"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cache_control_json_format_ephemeral() {
|
||||||
|
let cache_control = CacheControl::ephemeral();
|
||||||
|
let json_str = serde_json::to_string(&cache_control).unwrap();
|
||||||
|
|
||||||
|
println!("Ephemeral JSON string: {}", json_str);
|
||||||
|
|
||||||
|
// Verify exact JSON format
|
||||||
|
assert_eq!(json_str, r#"{"type":"ephemeral"}"#);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cache_control_json_format_five_minute() {
|
||||||
|
let cache_control = CacheControl::five_minute();
|
||||||
|
let json_str = serde_json::to_string(&cache_control).unwrap();
|
||||||
|
|
||||||
|
println!("5-minute JSON string: {}", json_str);
|
||||||
|
|
||||||
|
// Verify exact JSON format
|
||||||
|
assert_eq!(json_str, r#"{"type":"ephemeral","ttl":"5m"}"#);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cache_control_json_format_one_hour() {
|
||||||
|
let cache_control = CacheControl::one_hour();
|
||||||
|
let json_str = serde_json::to_string(&cache_control).unwrap();
|
||||||
|
|
||||||
|
println!("1-hour JSON string: {}", json_str);
|
||||||
|
|
||||||
|
// Verify exact JSON format
|
||||||
|
assert_eq!(json_str, r#"{"type":"ephemeral","ttl":"1h"}"#);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_deserialization_ephemeral() {
|
||||||
|
let json_str = r#"{"type":"ephemeral"}"#;
|
||||||
|
let cache_control: CacheControl = serde_json::from_str(json_str).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(cache_control.ttl, None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_deserialization_five_minute() {
|
||||||
|
let json_str = r#"{"type":"ephemeral","ttl":"5m"}"#;
|
||||||
|
let cache_control: CacheControl = serde_json::from_str(json_str).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(cache_control.ttl, Some("5m".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_deserialization_one_hour() {
|
||||||
|
let json_str = r#"{"type":"ephemeral","ttl":"1h"}"#;
|
||||||
|
let cache_control: CacheControl = serde_json::from_str(json_str).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(cache_control.ttl, Some("1h".to_string()));
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user