Readability refactor: extract mega-functions into focused helpers

Agent: carmack

4 files refactored, net -250 lines, all tests passing (417 + 71).

datalog.rs:
- Extract 7 predicate evaluation helpers from evaluate_predicate_datalog()
  (~200-line match → 12-line dispatch table)
- Extract rule_body_for_predicate() from format_datalog_program()
  (~75-line match → 2-line call)

invariants.rs:
- Extract 7 per-rule helpers from evaluate_predicate()
  (~230-line match → 12-line dispatch table)

envelope.rs:
- Simplify summary construction in verify_envelope()
- Eliminate redundant clone in stamp_envelope()

anthropic.rs:
- Introduce StreamState struct with 6 handler methods
- parse_streaming_response: ~290 lines → ~90 lines
- Max nesting depth reduced from 8 to 4 levels
This commit is contained in:
Dhanji R. Prasanna
2026-02-13 16:21:38 +11:00
parent 0410efd41b
commit 1ad74baaa5
4 changed files with 517 additions and 767 deletions

View File

@@ -425,7 +425,161 @@ pub fn execute_rules(
}
}
/// Evaluate a single predicate using the fact lookup.
// ── Predicate evaluation helpers ────────────────────────────────────────
// Each returns (passed: bool, reason: String) for a specific rule type.
/// Exists / NotExists: check whether the claim has any values.
/// `expect_present = true` → Exists; `false` → NotExists.
fn eval_existence(claim_values: Option<&HashSet<&str>>, expect_present: bool) -> (bool, String) {
let has_values = claim_values.map_or(false, |v| !v.is_empty());
if expect_present == has_values {
if has_values {
(true, "Value exists".into())
} else {
(true, "Value does not exist as expected".into())
}
} else if expect_present {
(false, "Value does not exist".into())
} else {
(false, "Value exists but should not".into())
}
}
/// Contains / NotContains: check whether a specific value is among the claim's facts.
/// `expect_present = true` → Contains; `false` → NotContains.
fn eval_membership(
claim_values: Option<&HashSet<&str>>,
pred: &CompiledPredicate,
expect_present: bool,
) -> (bool, String) {
let expected = pred.expected_value.as_deref().unwrap_or("");
match claim_values {
Some(values) => {
let found = values.contains(expected);
if expect_present {
if found {
(true, format!("Contains '{}'", expected))
} else {
(false, format!("Does not contain '{}'", expected))
}
} else if found {
(false, format!("Contains '{}' but should not", expected))
} else {
(true, format!("Does not contain '{}'", expected))
}
}
None if expect_present => (false, format!("Claim '{}' has no values", pred.claim_name)),
None => (true, format!("Claim '{}' has no values (not_contains passes vacuously)", pred.claim_name)),
}
}
/// Equals: exactly one value must match the expected string.
fn eval_equals(claim_values: Option<&HashSet<&str>>, pred: &CompiledPredicate) -> (bool, String) {
let expected = pred.expected_value.as_deref().unwrap_or("");
let Some(values) = claim_values else {
return (false, format!("Claim '{}' has no values", pred.claim_name));
};
if values.len() == 1 && values.contains(expected) {
(true, format!("Equals '{}'", expected))
} else if values.len() > 1 {
(false, format!("Multiple values found, expected single value '{}'", expected))
} else {
let actual = values.iter().next().copied().unwrap_or("<none>");
(false, format!("Expected '{}', got '{}'", expected, actual))
}
}
/// MinLength / MaxLength: compare the `__length` fact against a threshold.
fn eval_length(
fact_lookup: &HashMap<&str, HashSet<&str>>,
pred: &CompiledPredicate,
cmp: impl Fn(usize, usize) -> bool,
pass_op: &str,
fail_op: &str,
label: &str,
) -> (bool, String) {
let expected: usize = pred.expected_value.as_deref()
.and_then(|s| s.parse().ok())
.unwrap_or(0);
let length_key = format!("{}.__length", pred.claim_name);
let length: usize = fact_lookup
.get(length_key.as_str())
.and_then(|v| v.iter().next())
.and_then(|s| s.parse().ok())
.unwrap_or(0);
if cmp(length, expected) {
(true, format!("Length {} {} {}", length, pass_op, expected))
} else {
(false, format!("Length {} {} {} ({})", length, fail_op, expected, label))
}
}
/// GreaterThan / LessThan: parse the first claim value as f64 and compare.
fn eval_numeric_cmp(
claim_values: Option<&HashSet<&str>>,
pred: &CompiledPredicate,
cmp: impl Fn(f64, f64) -> bool,
op: &str,
) -> (bool, String) {
let expected: f64 = pred.expected_value.as_deref()
.and_then(|s| s.parse().ok())
.unwrap_or(0.0);
let Some(values) = claim_values else {
return (false, format!("Claim '{}' has no values", pred.claim_name));
};
match values.iter().next().and_then(|s| s.parse::<f64>().ok()) {
Some(actual) if cmp(actual, expected) => (true, format!("{} {} {}", actual, op, expected)),
Some(actual) => (false, format!("{} is not {} {}", actual, op, expected)),
None => (false, "Value is not a number".into()),
}
}
/// Matches: check if any claim value matches a regex pattern.
fn eval_matches(claim_values: Option<&HashSet<&str>>, pred: &CompiledPredicate) -> (bool, String) {
let pattern = pred.expected_value.as_deref().unwrap_or("");
let regex = match regex::Regex::new(pattern) {
Ok(r) => r,
Err(e) => return (false, format!("Invalid regex: {}", e)),
};
let Some(values) = claim_values else {
return (false, format!("Claim '{}' has no values", pred.claim_name));
};
if values.iter().any(|v| regex.is_match(v)) {
(true, format!("Matches pattern '{}'", pattern))
} else {
(false, format!("No value matches pattern '{}'", pattern))
}
}
/// AnyOf / NoneOf: parse the expected value as a bracketed set and check membership.
/// `expect_in_set = true` → AnyOf (pass if value in set); `false` → NoneOf.
fn eval_set_membership(
claim_values: Option<&HashSet<&str>>,
pred: &CompiledPredicate,
expect_in_set: bool,
) -> (bool, String) {
let set: Vec<&str> = pred.expected_value.as_deref()
.map(|v| v.trim_matches(|c| c == '[' || c == ']').split(", ").collect())
.unwrap_or_default();
let Some(values) = claim_values else {
return if expect_in_set {
(false, format!("Claim '{}' has no values", pred.claim_name))
} else {
(true, format!("Claim '{}' has no values (none_of passes vacuously)", pred.claim_name))
};
};
let found = values.iter().any(|v| set.contains(v));
if expect_in_set {
if found { (true, "Value is in allowed set".into()) }
else { (false, "Value is not in allowed set".into()) }
} else if found {
(false, "Value is in forbidden set".into())
} else {
(true, "Value is not in forbidden set".into())
}
}
/// Evaluate a single predicate against the fact lookup table.
fn evaluate_predicate_datalog(
pred: &CompiledPredicate,
fact_lookup: &HashMap<&str, HashSet<&str>>,
@@ -433,202 +587,18 @@ fn evaluate_predicate_datalog(
let claim_values = fact_lookup.get(pred.claim_name.as_str());
let (passed, reason) = match pred.rule {
PredicateRule::Exists => {
if claim_values.is_some() && !claim_values.unwrap().is_empty() {
(true, "Value exists".to_string())
} else {
(false, "Value does not exist".to_string())
}
}
PredicateRule::NotExists => {
if claim_values.is_none() || claim_values.unwrap().is_empty() {
(true, "Value does not exist as expected".to_string())
} else {
(false, "Value exists but should not".to_string())
}
}
PredicateRule::Contains => {
let expected = pred.expected_value.as_deref().unwrap_or("");
if let Some(values) = claim_values {
if values.contains(expected) {
(true, format!("Contains '{}'", expected))
} else {
(false, format!("Does not contain '{}'", expected))
}
} else {
(false, format!("Claim '{}' has no values", pred.claim_name))
}
}
PredicateRule::Equals => {
let expected = pred.expected_value.as_deref().unwrap_or("");
if let Some(values) = claim_values {
if values.len() == 1 && values.contains(expected) {
(true, format!("Equals '{}'", expected))
} else if values.len() > 1 {
(false, format!("Multiple values found, expected single value '{}'", expected))
} else {
let actual = values.iter().next().map(|s| *s).unwrap_or("<none>");
(false, format!("Expected '{}', got '{}'", expected, actual))
}
} else {
(false, format!("Claim '{}' has no values", pred.claim_name))
}
}
PredicateRule::MinLength => {
let expected: usize = pred
.expected_value
.as_deref()
.and_then(|s| s.parse().ok())
.unwrap_or(0);
// Check the __length fact
let length_claim = format!("{}.__length", pred.claim_name);
let length = fact_lookup
.get(length_claim.as_str())
.and_then(|v| v.iter().next())
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(0);
if length >= expected {
(true, format!("Length {} >= {}", length, expected))
} else {
(false, format!("Length {} < {} (minimum)", length, expected))
}
}
PredicateRule::MaxLength => {
let expected: usize = pred
.expected_value
.as_deref()
.and_then(|s| s.parse().ok())
.unwrap_or(usize::MAX);
let length_claim = format!("{}.__length", pred.claim_name);
let length = fact_lookup
.get(length_claim.as_str())
.and_then(|v| v.iter().next())
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(0);
if length <= expected {
(true, format!("Length {} <= {}", length, expected))
} else {
(false, format!("Length {} > {} (maximum)", length, expected))
}
}
PredicateRule::GreaterThan => {
let expected: f64 = pred
.expected_value
.as_deref()
.and_then(|s| s.parse().ok())
.unwrap_or(0.0);
if let Some(values) = claim_values {
if let Some(actual) = values.iter().next().and_then(|s| s.parse::<f64>().ok()) {
if actual > expected {
(true, format!("{} > {}", actual, expected))
} else {
(false, format!("{} is not > {}", actual, expected))
}
} else {
(false, "Value is not a number".to_string())
}
} else {
(false, format!("Claim '{}' has no values", pred.claim_name))
}
}
PredicateRule::LessThan => {
let expected: f64 = pred
.expected_value
.as_deref()
.and_then(|s| s.parse().ok())
.unwrap_or(0.0);
if let Some(values) = claim_values {
if let Some(actual) = values.iter().next().and_then(|s| s.parse::<f64>().ok()) {
if actual < expected {
(true, format!("{} < {}", actual, expected))
} else {
(false, format!("{} is not < {}", actual, expected))
}
} else {
(false, "Value is not a number".to_string())
}
} else {
(false, format!("Claim '{}' has no values", pred.claim_name))
}
}
PredicateRule::Matches => {
let pattern = pred.expected_value.as_deref().unwrap_or("");
let regex = match regex::Regex::new(pattern) {
Ok(r) => r,
Err(e) => {
return DatalogPredicateResult {
id: pred.id,
claim_name: pred.claim_name.clone(),
rule: pred.rule.clone(),
expected_value: pred.expected_value.clone(),
passed: false,
reason: format!("Invalid regex: {}", e),
source: pred.source,
notes: pred.notes.clone(),
};
}
};
if let Some(values) = claim_values {
if values.iter().any(|v| regex.is_match(v)) {
(true, format!("Matches pattern '{}'", pattern))
} else {
(false, format!("No value matches pattern '{}'", pattern))
}
} else {
(false, format!("Claim '{}' has no values", pred.claim_name))
}
}
PredicateRule::NotContains => {
let expected = pred.expected_value.as_deref().unwrap_or("");
if let Some(values) = claim_values {
if values.contains(expected) {
(false, format!("Contains '{}' but should not", expected))
} else {
(true, format!("Does not contain '{}'", expected))
}
} else {
(true, format!("Claim '{}' has no values (not_contains passes vacuously)", pred.claim_name))
}
}
PredicateRule::AnyOf => {
let expected_set: Vec<&str> = pred.expected_value.as_deref()
.map(|v| v.trim_matches(|c| c == '[' || c == ']')
.split(", ")
.collect())
.unwrap_or_default();
if let Some(values) = claim_values {
if values.iter().any(|v| expected_set.contains(v)) {
(true, format!("Value is in allowed set"))
} else {
(false, format!("Value is not in allowed set"))
}
} else {
(false, format!("Claim '{}' has no values", pred.claim_name))
}
}
PredicateRule::NoneOf => {
let forbidden_set: Vec<&str> = pred.expected_value.as_deref()
.map(|v| v.trim_matches(|c| c == '[' || c == ']')
.split(", ")
.collect())
.unwrap_or_default();
if let Some(values) = claim_values {
if values.iter().any(|v| forbidden_set.contains(v)) {
(false, format!("Value is in forbidden set"))
} else {
(true, format!("Value is not in forbidden set"))
}
} else {
(true, format!("Claim '{}' has no values (none_of passes vacuously)", pred.claim_name))
}
}
PredicateRule::Exists => eval_existence(claim_values, true),
PredicateRule::NotExists => eval_existence(claim_values, false),
PredicateRule::Contains => eval_membership(claim_values, pred, true),
PredicateRule::NotContains => eval_membership(claim_values, pred, false),
PredicateRule::Equals => eval_equals(claim_values, pred),
PredicateRule::MinLength => eval_length(fact_lookup, pred, |len, exp| len >= exp, ">=", "<", "minimum"),
PredicateRule::MaxLength => eval_length(fact_lookup, pred, |len, exp| len <= exp, "<=", ">", "maximum"),
PredicateRule::GreaterThan => eval_numeric_cmp(claim_values, pred, |a, e| a > e, ">"),
PredicateRule::LessThan => eval_numeric_cmp(claim_values, pred, |a, e| a < e, "<"),
PredicateRule::Matches => eval_matches(claim_values, pred),
PredicateRule::AnyOf => eval_set_membership(claim_values, pred, true),
PredicateRule::NoneOf => eval_set_membership(claim_values, pred, false),
};
DatalogPredicateResult {
@@ -658,6 +628,36 @@ fn escape_datalog_string(s: &str) -> String {
.replace('\t', "\\t")
}
/// Generate the Soufflé rule body for a single predicate.
///
/// Returns the clause after `predicate_pass(id) :- `.
fn rule_body_for_predicate(rule: &PredicateRule, claim: &str, expected: &str) -> String {
match rule {
PredicateRule::Exists =>
format!("claim_value(\"{}\", _)", claim),
PredicateRule::NotExists =>
format!("!claim_value(\"{}\", _)", claim),
PredicateRule::Equals | PredicateRule::Contains =>
format!("claim_value(\"{}\", \"{}\")", claim, expected),
PredicateRule::NotContains =>
format!("!claim_value(\"{}\", \"{}\")", claim, expected),
PredicateRule::GreaterThan =>
format!("claim_value(\"{}\", V), to_number(V, N), N > {}", claim, expected),
PredicateRule::LessThan =>
format!("claim_value(\"{}\", V), to_number(V, N), N < {}", claim, expected),
PredicateRule::MinLength =>
format!("claim_length(\"{}\", N), N >= {}", claim, expected),
PredicateRule::MaxLength =>
format!("claim_length(\"{}\", N), N <= {}", claim, expected),
PredicateRule::Matches =>
format!("claim_value(\"{}\", V), match(\"{}\", V)", claim, expected),
PredicateRule::AnyOf =>
format!("claim_value(\"{}\", V), any_of(\"{}\", V)", claim, expected),
PredicateRule::NoneOf =>
format!("!claim_value(\"{}\", V), none_of(\"{}\", V)", claim, expected),
}
}
/// Format a compiled rulespec and extracted facts as a datalog program.
///
/// Produces a textual `.dl` file with:
@@ -739,83 +739,8 @@ pub fn format_datalog_program(
pred.notes.as_deref().map(|n| format!(" -- {}", n)).unwrap_or_default(),
));
match pred.rule {
PredicateRule::Exists => {
out.push_str(&format!(
"predicate_pass({}) :- claim_value(\"{}\", _).\n",
id, claim,
));
}
PredicateRule::NotExists => {
// Pass when no matching fact exists
out.push_str(&format!(
"predicate_pass({}) :- !claim_value(\"{}\", _).\n",
id, claim,
));
}
PredicateRule::Equals => {
out.push_str(&format!(
"predicate_pass({}) :- claim_value(\"{}\", \"{}\").\n",
id, claim, expected,
));
}
PredicateRule::Contains => {
out.push_str(&format!(
"predicate_pass({}) :- claim_value(\"{}\", \"{}\").\n",
id, claim, expected,
));
}
PredicateRule::GreaterThan => {
out.push_str(&format!(
"predicate_pass({}) :- claim_value(\"{}\", V), to_number(V, N), N > {}.\n",
id, claim, expected,
));
}
PredicateRule::LessThan => {
out.push_str(&format!(
"predicate_pass({}) :- claim_value(\"{}\", V), to_number(V, N), N < {}.\n",
id, claim, expected,
));
}
PredicateRule::MinLength => {
out.push_str(&format!(
"predicate_pass({}) :- claim_length(\"{}\", N), N >= {}.\n",
id, claim, expected,
));
}
PredicateRule::MaxLength => {
out.push_str(&format!(
"predicate_pass({}) :- claim_length(\"{}\", N), N <= {}.\n",
id, claim, expected,
));
}
PredicateRule::Matches => {
// Regex matching expressed as a match functor
out.push_str(&format!(
"predicate_pass({}) :- claim_value(\"{}\", V), match(\"{}\", V).\n",
id, claim, expected,
));
}
PredicateRule::NotContains => {
out.push_str(&format!(
"predicate_pass({}) :- !claim_value(\"{}\", \"{}\").\n",
id, claim, expected,
));
}
PredicateRule::AnyOf => {
// any_of: pass if claim value matches any element in the set
out.push_str(&format!(
"predicate_pass({}) :- claim_value(\"{}\", V), any_of(\"{}\", V).\n",
id, claim, expected,
));
}
PredicateRule::NoneOf => {
out.push_str(&format!(
"predicate_pass({}) :- !claim_value(\"{}\", V), none_of(\"{}\", V).\n",
id, claim, expected,
));
}
}
let body = rule_body_for_predicate(&pred.rule, &claim, &expected);
out.push_str(&format!("predicate_pass({}) :- {}.\n", id, body));
// Derive failure as the negation of pass
out.push_str(&format!(
@@ -828,7 +753,6 @@ pub fn format_datalog_program(
out
}
// ============================================================================
// ============================================================================
// Formatting
// ============================================================================

View File

@@ -397,25 +397,17 @@ pub fn verify_envelope(session_id: &str, working_dir: &Path) -> String {
}
}
// Return a summary for the tool output
// NOTE: The token value is intentionally NOT included in the output
// returned to the LLM. It exists only in the envelope.yaml file.
let summary = if result.failed_count == 0 {
// Summary for tool output (token value intentionally omitted — LLM must not see it)
let total = result.passed_count + result.failed_count;
if result.failed_count == 0 {
format!(
"\n✅ Invariant verification: {}/{} passed\n",
result.passed_count,
result.passed_count + result.failed_count,
"\n✅ Invariant verification: {}/{} passed\n", result.passed_count, total,
)
} else {
format!(
"\n⚠️ Invariant verification: {}/{} passed, {} failed\n",
result.passed_count,
result.passed_count + result.failed_count,
result.failed_count,
"\n⚠️ Invariant verification: {}/{} passed, {} failed\n", result.passed_count, total, result.failed_count,
)
};
summary
}
}
/// Stamp an envelope with a verification token and re-write it to disk.
@@ -432,17 +424,13 @@ fn stamp_envelope(
) -> Result<()> {
let key = get_or_create_verification_key()?;
// Compute token over the original envelope (without any previous verified field)
// Compute token over the envelope without any previous verified field
let mut clean_envelope = envelope.clone();
clean_envelope.verified = None;
let token = mint_token(&key, &clean_envelope, rulespec);
// Set the verified field and re-write
let mut stamped = envelope.clone();
stamped.verified = Some(token);
write_envelope(session_id, &stamped)?;
Ok(())
clean_envelope.verified = Some(token);
write_envelope(session_id, &clean_envelope)
}
// ============================================================================

View File

@@ -559,234 +559,139 @@ pub fn evaluate_predicate(
selected_values: &[YamlValue],
) -> PredicateResult {
match predicate.rule {
PredicateRule::Exists => {
// Filter out null values — null means "absent"
let non_null: Vec<_> = selected_values.iter()
.filter(|v| !v.is_null())
.collect();
if non_null.is_empty() {
PredicateResult::fail("Value does not exist")
PredicateRule::Exists => eval_yaml_existence(selected_values, true),
PredicateRule::NotExists => eval_yaml_existence(selected_values, false),
PredicateRule::Contains => eval_yaml_containment(predicate, selected_values, true),
PredicateRule::NotContains => eval_yaml_containment(predicate, selected_values, false),
PredicateRule::Equals => eval_yaml_equals(predicate, selected_values),
PredicateRule::MinLength => eval_yaml_length(predicate, selected_values, |len, n| len >= n, "min"),
PredicateRule::MaxLength => eval_yaml_length(predicate, selected_values, |len, n| len <= n, "max"),
PredicateRule::GreaterThan => eval_yaml_numeric(predicate, selected_values, |v, t| v > t, ">"),
PredicateRule::LessThan => eval_yaml_numeric(predicate, selected_values, |v, t| v < t, "<"),
PredicateRule::Matches => eval_yaml_matches(predicate, selected_values),
PredicateRule::AnyOf => eval_yaml_set(predicate, selected_values, true),
PredicateRule::NoneOf => eval_yaml_set(predicate, selected_values, false),
}
}
// ── Per-rule evaluation helpers ─────────────────────────────────────────
/// Exists / NotExists: any non-null value counts as "present".
fn eval_yaml_existence(values: &[YamlValue], expect_present: bool) -> PredicateResult {
let has_non_null = values.iter().any(|v| !v.is_null());
match (expect_present, has_non_null) {
(true, true) => PredicateResult::pass("Value exists"),
(true, false) => PredicateResult::fail("Value does not exist"),
(false, false) => PredicateResult::pass("Value does not exist as expected"),
(false, true) => PredicateResult::fail("Value exists but should not"),
}
}
/// Contains / NotContains: search selected values using `value_contains`.
fn eval_yaml_containment(pred: &Predicate, values: &[YamlValue], expect_present: bool) -> PredicateResult {
let Some(target) = &pred.value else {
let rule_name = if expect_present { "contains" } else { "not_contains" };
return PredicateResult::fail(format!("No value specified for {}", rule_name));
};
let found = values.iter().any(|v| value_contains(v, target));
let display = yaml_to_display(target);
match (expect_present, found) {
(true, true) => PredicateResult::pass(format!("Value contains {:?}", display)),
(true, false) => PredicateResult::fail(format!("Value does not contain {:?}", display)),
(false, false) => PredicateResult::pass(format!("Value does not contain {:?}", display)),
(false, true) => PredicateResult::fail(format!("Value contains {:?} but should not", display)),
}
}
fn eval_yaml_equals(pred: &Predicate, values: &[YamlValue]) -> PredicateResult {
let Some(target) = &pred.value else {
return PredicateResult::fail("No value specified for equals");
};
if values.len() != 1 {
return PredicateResult::fail(format!("Expected single value for equals, got {}", values.len()));
}
if &values[0] == target {
PredicateResult::pass("Values are equal")
} else {
PredicateResult::fail(format!("Values not equal: {:?} != {:?}", yaml_to_display(&values[0]), yaml_to_display(target)))
}
}
/// MinLength / MaxLength: find the first Sequence and compare its length.
fn eval_yaml_length(pred: &Predicate, values: &[YamlValue], cmp: fn(usize, usize) -> bool, label: &str) -> PredicateResult {
let threshold = match &pred.value {
Some(YamlValue::Number(n)) => n.as_u64().unwrap_or(0) as usize,
_ => return PredicateResult::fail(format!("{}_length requires a numeric value", label)),
};
for value in values {
if let YamlValue::Sequence(seq) = value {
return if cmp(seq.len(), threshold) {
PredicateResult::pass(format!("Array has {} elements ({}: {})", seq.len(), label, threshold))
} else {
PredicateResult::pass("Value exists")
}
PredicateResult::fail(format!("Array has {} elements ({}: {})", seq.len(), label, threshold))
};
}
PredicateRule::NotExists => {
// Filter out null values — null means "absent"
let non_null: Vec<_> = selected_values.iter()
.filter(|v| !v.is_null())
.collect();
if non_null.is_empty() {
PredicateResult::pass("Value does not exist as expected")
}
PredicateResult::fail("Value is not an array")
}
/// GreaterThan / LessThan: find the first Number and compare.
fn eval_yaml_numeric(pred: &Predicate, values: &[YamlValue], cmp: fn(f64, f64) -> bool, op: &str) -> PredicateResult {
let target = match &pred.value {
Some(YamlValue::Number(n)) => n.as_f64().unwrap_or(0.0),
_ => return PredicateResult::fail(format!("{} requires a numeric value", pred.rule)),
};
for value in values {
if let YamlValue::Number(n) = value {
let v = n.as_f64().unwrap_or(0.0);
return if cmp(v, target) {
PredicateResult::pass(format!("{} {} {}", v, op, target))
} else {
PredicateResult::fail("Value exists but should not")
PredicateResult::fail(format!("{} is not {} {}", v, op, target))
};
}
}
PredicateResult::fail("Value is not a number")
}
fn eval_yaml_matches(pred: &Predicate, values: &[YamlValue]) -> PredicateResult {
let Some(YamlValue::String(pattern)) = &pred.value else {
return PredicateResult::fail("matches requires a string pattern");
};
let regex = match regex::Regex::new(pattern) {
Ok(r) => r,
Err(e) => return PredicateResult::fail(format!("Invalid regex: {}", e)),
};
for value in values {
if let YamlValue::String(s) = value {
if regex.is_match(s) {
return PredicateResult::pass(format!("'{}' matches pattern", s));
}
}
PredicateRule::Contains => {
let target = match &predicate.value {
Some(v) => v,
None => return PredicateResult::fail("No value specified for contains"),
};
for value in selected_values {
if value_contains(value, target) {
return PredicateResult::pass(format!(
"Value contains {:?}",
yaml_to_display(target)
));
}
}
PredicateResult::fail(format!(
"Value does not contain {:?}",
yaml_to_display(target)
))
}
PredicateResult::fail(format!("No value matches pattern '{}'", pattern))
}
/// AnyOf / NoneOf: check if selected values are in (or not in) a set.
fn eval_yaml_set(pred: &Predicate, values: &[YamlValue], expect_in_set: bool) -> PredicateResult {
let label = if expect_in_set { "any_of" } else { "none_of" };
let set = match &pred.value {
Some(YamlValue::Sequence(seq)) => seq,
Some(_) => return PredicateResult::fail(format!("{} requires an array value", label)),
None => return PredicateResult::fail(format!("No value specified for {}", label)),
};
let found = values.iter().any(|v| set.contains(v));
let set_display = set.iter().map(yaml_to_display).collect::<Vec<_>>().join(", ");
match (expect_in_set, found) {
(true, true) => {
let matched = values.iter().find(|v| set.contains(v)).unwrap();
PredicateResult::pass(format!("Value {:?} is in allowed set", yaml_to_display(matched)))
}
PredicateRule::Equals => {
let target = match &predicate.value {
Some(v) => v,
None => return PredicateResult::fail("No value specified for equals"),
};
if selected_values.len() != 1 {
return PredicateResult::fail(format!(
"Expected single value for equals, got {}",
selected_values.len()
));
}
if &selected_values[0] == target {
PredicateResult::pass("Values are equal")
} else {
PredicateResult::fail(format!(
"Values not equal: {:?} != {:?}",
yaml_to_display(&selected_values[0]),
yaml_to_display(target)
))
}
}
PredicateRule::MinLength => {
let min = match &predicate.value {
Some(YamlValue::Number(n)) => n.as_u64().unwrap_or(0) as usize,
_ => return PredicateResult::fail("min_length requires a numeric value"),
};
for value in selected_values {
if let YamlValue::Sequence(seq) = value {
if seq.len() >= min {
return PredicateResult::pass(format!(
"Array has {} elements (min: {})",
seq.len(),
min
));
} else {
return PredicateResult::fail(format!(
"Array has {} elements (min: {})",
seq.len(),
min
));
}
}
}
PredicateResult::fail("Value is not an array")
}
PredicateRule::MaxLength => {
let max = match &predicate.value {
Some(YamlValue::Number(n)) => n.as_u64().unwrap_or(0) as usize,
_ => return PredicateResult::fail("max_length requires a numeric value"),
};
for value in selected_values {
if let YamlValue::Sequence(seq) = value {
if seq.len() <= max {
return PredicateResult::pass(format!(
"Array has {} elements (max: {})",
seq.len(),
max
));
} else {
return PredicateResult::fail(format!(
"Array has {} elements (max: {})",
seq.len(),
max
));
}
}
}
PredicateResult::fail("Value is not an array")
}
PredicateRule::GreaterThan => {
let target = match &predicate.value {
Some(YamlValue::Number(n)) => n.as_f64().unwrap_or(0.0),
_ => return PredicateResult::fail("greater_than requires a numeric value"),
};
for value in selected_values {
if let YamlValue::Number(n) = value {
let v = n.as_f64().unwrap_or(0.0);
if v > target {
return PredicateResult::pass(format!("{} > {}", v, target));
} else {
return PredicateResult::fail(format!("{} is not > {}", v, target));
}
}
}
PredicateResult::fail("Value is not a number")
}
PredicateRule::LessThan => {
let target = match &predicate.value {
Some(YamlValue::Number(n)) => n.as_f64().unwrap_or(0.0),
_ => return PredicateResult::fail("less_than requires a numeric value"),
};
for value in selected_values {
if let YamlValue::Number(n) = value {
let v = n.as_f64().unwrap_or(0.0);
if v < target {
return PredicateResult::pass(format!("{} < {}", v, target));
} else {
return PredicateResult::fail(format!("{} is not < {}", v, target));
}
}
}
PredicateResult::fail("Value is not a number")
}
PredicateRule::Matches => {
let pattern = match &predicate.value {
Some(YamlValue::String(s)) => s,
_ => return PredicateResult::fail("matches requires a string pattern"),
};
let regex = match regex::Regex::new(pattern) {
Ok(r) => r,
Err(e) => return PredicateResult::fail(format!("Invalid regex: {}", e)),
};
for value in selected_values {
if let YamlValue::String(s) = value {
if regex.is_match(s) {
return PredicateResult::pass(format!("'{}' matches pattern", s));
}
}
}
PredicateResult::fail(format!("No value matches pattern '{}'", pattern))
}
PredicateRule::NotContains => {
let target = match &predicate.value {
Some(v) => v,
None => return PredicateResult::fail("No value specified for not_contains"),
};
for value in selected_values {
if value_contains(value, target) {
return PredicateResult::fail(format!(
"Value contains {:?} but should not",
yaml_to_display(target)
));
}
}
PredicateResult::pass(format!(
"Value does not contain {:?}",
yaml_to_display(target)
))
}
PredicateRule::AnyOf => {
let allowed = match &predicate.value {
Some(YamlValue::Sequence(seq)) => seq,
Some(_) => return PredicateResult::fail("any_of requires an array value"),
None => return PredicateResult::fail("No value specified for any_of"),
};
for value in selected_values {
if allowed.contains(value) {
return PredicateResult::pass(format!(
"Value {:?} is in allowed set",
yaml_to_display(value)
));
}
}
PredicateResult::fail(format!(
"Value is not in allowed set [{}]",
allowed.iter().map(yaml_to_display).collect::<Vec<_>>().join(", ")
))
}
PredicateRule::NoneOf => {
let forbidden = match &predicate.value {
Some(YamlValue::Sequence(seq)) => seq,
Some(_) => return PredicateResult::fail("none_of requires an array value"),
None => return PredicateResult::fail("No value specified for none_of"),
};
for value in selected_values {
if forbidden.contains(value) {
return PredicateResult::fail(format!(
"Value {:?} is in forbidden set",
yaml_to_display(value)
));
}
}
PredicateResult::pass(format!(
"Value is not in forbidden set [{}]",
forbidden.iter().map(yaml_to_display).collect::<Vec<_>>().join(", ")
))
(true, false) => PredicateResult::fail(format!("Value is not in allowed set [{}]", set_display)),
(false, false) => PredicateResult::pass(format!("Value is not in forbidden set [{}]", set_display)),
(false, true) => {
let matched = values.iter().find(|v| set.contains(v)).unwrap();
PredicateResult::fail(format!("Value {:?} is in forbidden set", yaml_to_display(matched)))
}
}
}

View File

@@ -135,6 +135,134 @@ pub struct AnthropicProvider {
thinking_budget_tokens: Option<u32>,
}
// ── SSE Stream State ────────────────────────────────────────────────────
// Mutable state threaded through Anthropic's SSE stream parser.
// Each `handle_*` method processes one event type and returns chunks to send.
struct StreamState {
tool_calls: Vec<ToolCall>,
partial_tool_json: String,
usage: Option<Usage>,
message_stopped: bool,
stop_reason: Option<String>,
}
impl StreamState {
fn new() -> Self {
Self {
tool_calls: Vec::new(),
partial_tool_json: String::new(),
usage: None,
message_stopped: false,
stop_reason: None,
}
}
fn handle_message_start(&mut self, event: &AnthropicStreamEvent) {
if let Some(message) = &event.message {
if let Some(u) = &message.usage {
self.usage = Some(Usage {
prompt_tokens: u.input_tokens,
completion_tokens: u.output_tokens,
total_tokens: u.input_tokens + u.output_tokens,
cache_creation_tokens: u.cache_creation_input_tokens,
cache_read_tokens: u.cache_read_input_tokens,
});
debug!("Captured usage from message_start: {:?}", self.usage);
}
}
}
/// Returns chunks to send for a content_block_start event.
fn handle_block_start(&mut self, event: AnthropicStreamEvent) -> Vec<Result<CompletionChunk>> {
let Some(content_block) = event.content_block else { return vec![] };
match content_block {
AnthropicContent::ToolUse { id, name, input } => {
debug!("Tool use block: id={}, name={}, input={:?}", id, name, input);
let tool_call = ToolCall { id: id.clone(), tool: name.clone(), args: input.clone() };
let has_complete_args = !input.is_null()
&& input != serde_json::Value::Object(serde_json::Map::new());
if has_complete_args {
debug!("Tool call has complete args, sending immediately");
vec![Ok(make_tool_chunk(vec![tool_call]))]
} else {
debug!("Tool call has empty args, will accumulate from partial_json");
let hint = make_tool_streaming_hint(name);
self.tool_calls.push(tool_call);
self.partial_tool_json.clear();
vec![Ok(hint)]
}
}
_ => {
debug!("Non-tool content block: {:?}", content_block);
vec![]
}
}
}
/// Returns chunks to send for a content_block_delta event.
fn handle_block_delta(&mut self, event: AnthropicStreamEvent) -> Vec<Result<CompletionChunk>> {
let Some(delta) = event.delta else { return vec![] };
let mut chunks = Vec::new();
if let Some(text) = delta.text {
debug!("Text chunk (len {})", text.len());
chunks.push(Ok(make_text_chunk(text)));
}
if let Some(json_fragment) = delta.partial_json {
debug!("Partial JSON: {}", json_fragment);
self.partial_tool_json.push_str(&json_fragment);
chunks.push(Ok(make_tool_streaming_active()));
}
chunks
}
/// Returns chunks to send when a content block finishes.
fn handle_block_stop(&mut self) -> Vec<Result<CompletionChunk>> {
// Finalize accumulated partial JSON into the last tool call's args
if !self.tool_calls.is_empty() && !self.partial_tool_json.is_empty() {
debug!("Parsing complete tool JSON: {}", self.partial_tool_json);
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&self.partial_tool_json) {
if let Some(last) = self.tool_calls.last_mut() {
last.args = parsed;
debug!("Updated tool call with complete args: {:?}", last);
}
} else {
debug!("Failed to parse accumulated JSON: {}", self.partial_tool_json);
}
self.partial_tool_json.clear();
}
if self.tool_calls.is_empty() {
return vec![];
}
let chunk = make_tool_chunk(self.tool_calls.clone());
self.tool_calls.clear();
vec![Ok(chunk)]
}
fn handle_message_delta(&mut self, event: &AnthropicStreamEvent) {
if let Some(delta) = &event.delta {
if let Some(reason) = &delta.stop_reason {
debug!("Received stop_reason: {}", reason);
self.stop_reason = Some(reason.clone());
}
}
}
fn handle_message_stop(&mut self) -> Vec<Result<CompletionChunk>> {
debug!("Received message stop event");
self.message_stopped = true;
let chunk = make_final_chunk_with_reason(
self.tool_calls.clone(),
self.usage.clone(),
self.stop_reason.clone(),
);
vec![Ok(chunk)]
}
}
impl AnthropicProvider {
pub fn new(
api_key: String,
@@ -501,267 +629,76 @@ impl AnthropicProvider {
mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
tx: mpsc::Sender<Result<CompletionChunk>>,
) -> Option<Usage> {
let mut buffer = String::new();
let mut current_tool_calls: Vec<ToolCall> = Vec::new();
let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
let mut accumulated_usage: Option<Usage> = None;
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
let mut message_stopped = false; // Track if we've received message_stop
let mut stop_reason: Option<String> = None; // Track why the message stopped
let mut state = StreamState::new();
let mut line_buffer = String::new();
let mut byte_buffer: Vec<u8> = Vec::new();
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
byte_buffer.extend_from_slice(&chunk);
let Some(chunk_str) = decode_utf8_streaming(&mut byte_buffer) else {
continue;
};
line_buffer.push_str(&chunk_str);
buffer.push_str(&chunk_str);
while let Some(line_end) = line_buffer.find('\n') {
let line = line_buffer[..line_end].trim().to_string();
line_buffer.drain(..line_end + 1);
// Process complete lines
while let Some(line_end) = buffer.find('\n') {
let line = buffer[..line_end].trim().to_string();
buffer.drain(..line_end + 1);
if line.is_empty() {
continue;
}
// If we've already sent the final chunk, skip processing more events
if message_stopped {
debug!("Skipping event after message_stop: {}", line);
continue;
}
// Parse Server-Sent Events format
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
debug!("Received stream completion marker");
let final_chunk = make_final_chunk(
current_tool_calls.clone(),
accumulated_usage.clone(),
);
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
}
return accumulated_usage;
if line.is_empty() || state.message_stopped {
if state.message_stopped && !line.is_empty() {
debug!("Skipping event after message_stop: {}", line);
}
continue;
}
debug!("Raw Claude API JSON: {}", data);
let Some(data) = line.strip_prefix("data: ") else { continue };
match serde_json::from_str::<AnthropicStreamEvent>(data) {
Ok(event) => {
debug!(
"Parsed event type: {}, event: {:?}",
event.event_type, event
);
match event.event_type.as_str() {
"message_start" => {
// Extract usage data from message_start event
if let Some(message) = event.message {
if let Some(usage) = message.usage {
accumulated_usage = Some(Usage {
prompt_tokens: usage.input_tokens,
completion_tokens: usage.output_tokens,
total_tokens: usage.input_tokens
+ usage.output_tokens,
cache_creation_tokens: usage
.cache_creation_input_tokens,
cache_read_tokens: usage
.cache_read_input_tokens,
});
debug!(
"Captured usage from message_start: {:?}",
accumulated_usage
);
}
}
}
"content_block_start" => {
debug!(
"Received content_block_start event: {:?}",
event
);
if let Some(content_block) = event.content_block {
match content_block {
AnthropicContent::ToolUse {
id,
name,
input,
} => {
debug!("Found tool use in content_block_start: id={}, name={}, input={:?}", id, name, input);
// Stream completion marker
if data == "[DONE]" {
debug!("Received stream completion marker");
let final_chunk = make_final_chunk(state.tool_calls.clone(), state.usage.clone());
let _ = tx.send(Ok(final_chunk)).await;
return state.usage;
}
// For native tool calls, create the tool call immediately if we have complete args
// If args are empty, we'll wait for partial_json to accumulate them
let tool_call = ToolCall {
id: id.clone(),
tool: name.clone(),
args: input.clone(),
};
debug!("Raw Claude API JSON: {}", data);
// Check if we already have complete arguments
if !input.is_null()
&& input
!= serde_json::Value::Object(
serde_json::Map::new(),
)
{
// We have complete arguments, send the tool call immediately
debug!("Tool call has complete args, sending immediately: {:?}", tool_call);
let chunk =
make_tool_chunk(vec![tool_call]);
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return accumulated_usage;
}
} else {
// Arguments are empty, we'll accumulate them from partial_json
debug!("Tool call has empty args, will accumulate from partial_json");
// Send a streaming hint so the UI can show the tool name immediately
let hint_chunk = make_tool_streaming_hint(name.clone());
if tx.send(Ok(hint_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return accumulated_usage;
}
current_tool_calls.push(tool_call);
partial_tool_json.clear();
}
}
_ => {
debug!(
"Non-tool content block: {:?}",
content_block
);
}
}
}
}
"content_block_delta" => {
if let Some(delta) = event.delta {
if let Some(text) = delta.text {
debug!(
"Sending text chunk of length {}: '{}'",
text.len(),
text
);
let chunk = make_text_chunk(text);
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return accumulated_usage;
}
}
// Handle partial JSON for tool calls
if let Some(partial_json) = delta.partial_json {
debug!(
"Received partial JSON: {}",
partial_json
);
partial_tool_json.push_str(&partial_json);
debug!(
"Accumulated tool JSON: {}",
partial_tool_json
);
// Send an active hint to trigger UI blink
let active_chunk = make_tool_streaming_active();
if tx.send(Ok(active_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return accumulated_usage;
}
}
}
}
"content_block_stop" => {
// Tool call block is complete - now parse the accumulated JSON
if !current_tool_calls.is_empty()
&& !partial_tool_json.is_empty()
{
debug!(
"Parsing complete tool JSON: {}",
partial_tool_json
);
let event = match serde_json::from_str::<AnthropicStreamEvent>(data) {
Ok(e) => e,
Err(e) => {
debug!("Failed to parse stream event: {} - Data: {}", e, data);
continue;
}
};
// Parse the accumulated JSON and update the last tool call
if let Ok(parsed_args) =
serde_json::from_str::<serde_json::Value>(
&partial_tool_json,
)
{
if let Some(last_tool) =
current_tool_calls.last_mut()
{
last_tool.args = parsed_args;
debug!("Updated tool call with complete args: {:?}", last_tool);
}
} else {
debug!(
"Failed to parse accumulated JSON: {}",
partial_tool_json
);
}
debug!("Parsed event type: {}", event.event_type);
// Clear the accumulator
partial_tool_json.clear();
}
// Send the complete tool call
if !current_tool_calls.is_empty() {
let chunk =
make_tool_chunk(current_tool_calls.clone());
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return accumulated_usage;
}
// Clear tool calls after sending to prevent duplicates at message_stop
current_tool_calls.clear();
}
}
"message_delta" => {
// message_delta contains the stop_reason and final usage
if let Some(delta) = &event.delta {
if let Some(reason) = &delta.stop_reason {
debug!("Received stop_reason: {}", reason);
stop_reason = Some(reason.clone());
}
}
// Usage is also in message_delta but we get it from message_start
}
"message_stop" => {
debug!("Received message stop event");
message_stopped = true;
let final_chunk = make_final_chunk_with_reason(
current_tool_calls.clone(),
accumulated_usage.clone(),
stop_reason.clone(),
);
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
}
// Don't return here - let the stream naturally exhaust
// This prevents dropping the sender prematurely
}
"error" => {
if let Some(error) = event.error {
error!("Anthropic API error: {:?}", error);
let _ = tx
.send(Err(anyhow!(
"Anthropic API error: {:?}",
error
)))
.await;
break; // Break to let stream exhaust naturally
}
}
_ => {
debug!("Ignoring event type: {}", event.event_type);
}
}
}
Err(e) => {
debug!("Failed to parse stream event: {} - Data: {}", e, data);
// Don't error out on parse failures, just continue
// Dispatch to per-event handlers; collect chunks to send
let chunks: Vec<Result<CompletionChunk>> = match event.event_type.as_str() {
"message_start" => { state.handle_message_start(&event); vec![] }
"content_block_start" => state.handle_block_start(event),
"content_block_delta" => state.handle_block_delta(event),
"content_block_stop" => state.handle_block_stop(),
"message_delta" => { state.handle_message_delta(&event); vec![] }
"message_stop" => state.handle_message_stop(),
"error" => {
if let Some(error) = event.error {
error!("Anthropic API error: {:?}", error);
let _ = tx.send(Err(anyhow!("Anthropic API error: {:?}", error))).await;
break;
}
vec![]
}
_ => { debug!("Ignoring event type: {}", event.event_type); vec![] }
};
// Send all chunks produced by the handler
for chunk in chunks {
if tx.send(chunk).await.is_err() {
debug!("Receiver dropped, stopping stream");
return state.usage;
}
}
}
@@ -769,18 +706,14 @@ impl AnthropicProvider {
Err(e) => {
error!("Stream error: {}", e);
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
// Don't return here either - let the stream exhaust naturally
// The error has been sent to the receiver, so it will handle it
// Breaking here ensures we clean up properly
break;
}
}
}
// Send final chunk if we haven't already
let final_chunk = make_final_chunk(current_tool_calls, accumulated_usage.clone());
let final_chunk = make_final_chunk(state.tool_calls, state.usage.clone());
let _ = tx.send(Ok(final_chunk)).await;
accumulated_usage
state.usage
}
}