From 1ad74baaa5cbb9f8c26c802bb4b884c1c700af6f Mon Sep 17 00:00:00 2001 From: "Dhanji R. Prasanna" Date: Fri, 13 Feb 2026 16:21:38 +1100 Subject: [PATCH] Readability refactor: extract mega-functions into focused helpers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Agent: carmack 4 files refactored, net -250 lines, all tests passing (417 + 71). datalog.rs: - Extract 7 predicate evaluation helpers from evaluate_predicate_datalog() (~200-line match → 12-line dispatch table) - Extract rule_body_for_predicate() from format_datalog_program() (~75-line match → 2-line call) invariants.rs: - Extract 7 per-rule helpers from evaluate_predicate() (~230-line match → 12-line dispatch table) envelope.rs: - Simplify summary construction in verify_envelope() - Eliminate redundant clone in stamp_envelope() anthropic.rs: - Introduce StreamState struct with 6 handler methods - parse_streaming_response: ~290 lines → ~90 lines - Max nesting depth reduced from 8 to 4 levels --- crates/g3-core/src/tools/datalog.rs | 474 +++++++++++-------------- crates/g3-core/src/tools/envelope.rs | 30 +- crates/g3-core/src/tools/invariants.rs | 349 +++++++----------- crates/g3-providers/src/anthropic.rs | 431 ++++++++++------------ 4 files changed, 517 insertions(+), 767 deletions(-) diff --git a/crates/g3-core/src/tools/datalog.rs b/crates/g3-core/src/tools/datalog.rs index 39dad1e..7bbf028 100644 --- a/crates/g3-core/src/tools/datalog.rs +++ b/crates/g3-core/src/tools/datalog.rs @@ -425,7 +425,161 @@ pub fn execute_rules( } } -/// Evaluate a single predicate using the fact lookup. +// ── Predicate evaluation helpers ──────────────────────────────────────── +// Each returns (passed: bool, reason: String) for a specific rule type. + +/// Exists / NotExists: check whether the claim has any values. +/// `expect_present = true` → Exists; `false` → NotExists. +fn eval_existence(claim_values: Option<&HashSet<&str>>, expect_present: bool) -> (bool, String) { + let has_values = claim_values.map_or(false, |v| !v.is_empty()); + if expect_present == has_values { + if has_values { + (true, "Value exists".into()) + } else { + (true, "Value does not exist as expected".into()) + } + } else if expect_present { + (false, "Value does not exist".into()) + } else { + (false, "Value exists but should not".into()) + } +} + +/// Contains / NotContains: check whether a specific value is among the claim's facts. +/// `expect_present = true` → Contains; `false` → NotContains. +fn eval_membership( + claim_values: Option<&HashSet<&str>>, + pred: &CompiledPredicate, + expect_present: bool, +) -> (bool, String) { + let expected = pred.expected_value.as_deref().unwrap_or(""); + match claim_values { + Some(values) => { + let found = values.contains(expected); + if expect_present { + if found { + (true, format!("Contains '{}'", expected)) + } else { + (false, format!("Does not contain '{}'", expected)) + } + } else if found { + (false, format!("Contains '{}' but should not", expected)) + } else { + (true, format!("Does not contain '{}'", expected)) + } + } + None if expect_present => (false, format!("Claim '{}' has no values", pred.claim_name)), + None => (true, format!("Claim '{}' has no values (not_contains passes vacuously)", pred.claim_name)), + } +} + +/// Equals: exactly one value must match the expected string. +fn eval_equals(claim_values: Option<&HashSet<&str>>, pred: &CompiledPredicate) -> (bool, String) { + let expected = pred.expected_value.as_deref().unwrap_or(""); + let Some(values) = claim_values else { + return (false, format!("Claim '{}' has no values", pred.claim_name)); + }; + if values.len() == 1 && values.contains(expected) { + (true, format!("Equals '{}'", expected)) + } else if values.len() > 1 { + (false, format!("Multiple values found, expected single value '{}'", expected)) + } else { + let actual = values.iter().next().copied().unwrap_or(""); + (false, format!("Expected '{}', got '{}'", expected, actual)) + } +} + +/// MinLength / MaxLength: compare the `__length` fact against a threshold. +fn eval_length( + fact_lookup: &HashMap<&str, HashSet<&str>>, + pred: &CompiledPredicate, + cmp: impl Fn(usize, usize) -> bool, + pass_op: &str, + fail_op: &str, + label: &str, +) -> (bool, String) { + let expected: usize = pred.expected_value.as_deref() + .and_then(|s| s.parse().ok()) + .unwrap_or(0); + let length_key = format!("{}.__length", pred.claim_name); + let length: usize = fact_lookup + .get(length_key.as_str()) + .and_then(|v| v.iter().next()) + .and_then(|s| s.parse().ok()) + .unwrap_or(0); + if cmp(length, expected) { + (true, format!("Length {} {} {}", length, pass_op, expected)) + } else { + (false, format!("Length {} {} {} ({})", length, fail_op, expected, label)) + } +} + +/// GreaterThan / LessThan: parse the first claim value as f64 and compare. +fn eval_numeric_cmp( + claim_values: Option<&HashSet<&str>>, + pred: &CompiledPredicate, + cmp: impl Fn(f64, f64) -> bool, + op: &str, +) -> (bool, String) { + let expected: f64 = pred.expected_value.as_deref() + .and_then(|s| s.parse().ok()) + .unwrap_or(0.0); + let Some(values) = claim_values else { + return (false, format!("Claim '{}' has no values", pred.claim_name)); + }; + match values.iter().next().and_then(|s| s.parse::().ok()) { + Some(actual) if cmp(actual, expected) => (true, format!("{} {} {}", actual, op, expected)), + Some(actual) => (false, format!("{} is not {} {}", actual, op, expected)), + None => (false, "Value is not a number".into()), + } +} + +/// Matches: check if any claim value matches a regex pattern. +fn eval_matches(claim_values: Option<&HashSet<&str>>, pred: &CompiledPredicate) -> (bool, String) { + let pattern = pred.expected_value.as_deref().unwrap_or(""); + let regex = match regex::Regex::new(pattern) { + Ok(r) => r, + Err(e) => return (false, format!("Invalid regex: {}", e)), + }; + let Some(values) = claim_values else { + return (false, format!("Claim '{}' has no values", pred.claim_name)); + }; + if values.iter().any(|v| regex.is_match(v)) { + (true, format!("Matches pattern '{}'", pattern)) + } else { + (false, format!("No value matches pattern '{}'", pattern)) + } +} + +/// AnyOf / NoneOf: parse the expected value as a bracketed set and check membership. +/// `expect_in_set = true` → AnyOf (pass if value in set); `false` → NoneOf. +fn eval_set_membership( + claim_values: Option<&HashSet<&str>>, + pred: &CompiledPredicate, + expect_in_set: bool, +) -> (bool, String) { + let set: Vec<&str> = pred.expected_value.as_deref() + .map(|v| v.trim_matches(|c| c == '[' || c == ']').split(", ").collect()) + .unwrap_or_default(); + let Some(values) = claim_values else { + return if expect_in_set { + (false, format!("Claim '{}' has no values", pred.claim_name)) + } else { + (true, format!("Claim '{}' has no values (none_of passes vacuously)", pred.claim_name)) + }; + }; + let found = values.iter().any(|v| set.contains(v)); + if expect_in_set { + if found { (true, "Value is in allowed set".into()) } + else { (false, "Value is not in allowed set".into()) } + } else if found { + (false, "Value is in forbidden set".into()) + } else { + (true, "Value is not in forbidden set".into()) + } +} + +/// Evaluate a single predicate against the fact lookup table. fn evaluate_predicate_datalog( pred: &CompiledPredicate, fact_lookup: &HashMap<&str, HashSet<&str>>, @@ -433,202 +587,18 @@ fn evaluate_predicate_datalog( let claim_values = fact_lookup.get(pred.claim_name.as_str()); let (passed, reason) = match pred.rule { - PredicateRule::Exists => { - if claim_values.is_some() && !claim_values.unwrap().is_empty() { - (true, "Value exists".to_string()) - } else { - (false, "Value does not exist".to_string()) - } - } - PredicateRule::NotExists => { - if claim_values.is_none() || claim_values.unwrap().is_empty() { - (true, "Value does not exist as expected".to_string()) - } else { - (false, "Value exists but should not".to_string()) - } - } - PredicateRule::Contains => { - let expected = pred.expected_value.as_deref().unwrap_or(""); - if let Some(values) = claim_values { - if values.contains(expected) { - (true, format!("Contains '{}'", expected)) - } else { - (false, format!("Does not contain '{}'", expected)) - } - } else { - (false, format!("Claim '{}' has no values", pred.claim_name)) - } - } - PredicateRule::Equals => { - let expected = pred.expected_value.as_deref().unwrap_or(""); - if let Some(values) = claim_values { - if values.len() == 1 && values.contains(expected) { - (true, format!("Equals '{}'", expected)) - } else if values.len() > 1 { - (false, format!("Multiple values found, expected single value '{}'", expected)) - } else { - let actual = values.iter().next().map(|s| *s).unwrap_or(""); - (false, format!("Expected '{}', got '{}'", expected, actual)) - } - } else { - (false, format!("Claim '{}' has no values", pred.claim_name)) - } - } - PredicateRule::MinLength => { - let expected: usize = pred - .expected_value - .as_deref() - .and_then(|s| s.parse().ok()) - .unwrap_or(0); - - // Check the __length fact - let length_claim = format!("{}.__length", pred.claim_name); - let length = fact_lookup - .get(length_claim.as_str()) - .and_then(|v| v.iter().next()) - .and_then(|s| s.parse::().ok()) - .unwrap_or(0); - - if length >= expected { - (true, format!("Length {} >= {}", length, expected)) - } else { - (false, format!("Length {} < {} (minimum)", length, expected)) - } - } - PredicateRule::MaxLength => { - let expected: usize = pred - .expected_value - .as_deref() - .and_then(|s| s.parse().ok()) - .unwrap_or(usize::MAX); - - let length_claim = format!("{}.__length", pred.claim_name); - let length = fact_lookup - .get(length_claim.as_str()) - .and_then(|v| v.iter().next()) - .and_then(|s| s.parse::().ok()) - .unwrap_or(0); - - if length <= expected { - (true, format!("Length {} <= {}", length, expected)) - } else { - (false, format!("Length {} > {} (maximum)", length, expected)) - } - } - PredicateRule::GreaterThan => { - let expected: f64 = pred - .expected_value - .as_deref() - .and_then(|s| s.parse().ok()) - .unwrap_or(0.0); - - if let Some(values) = claim_values { - if let Some(actual) = values.iter().next().and_then(|s| s.parse::().ok()) { - if actual > expected { - (true, format!("{} > {}", actual, expected)) - } else { - (false, format!("{} is not > {}", actual, expected)) - } - } else { - (false, "Value is not a number".to_string()) - } - } else { - (false, format!("Claim '{}' has no values", pred.claim_name)) - } - } - PredicateRule::LessThan => { - let expected: f64 = pred - .expected_value - .as_deref() - .and_then(|s| s.parse().ok()) - .unwrap_or(0.0); - - if let Some(values) = claim_values { - if let Some(actual) = values.iter().next().and_then(|s| s.parse::().ok()) { - if actual < expected { - (true, format!("{} < {}", actual, expected)) - } else { - (false, format!("{} is not < {}", actual, expected)) - } - } else { - (false, "Value is not a number".to_string()) - } - } else { - (false, format!("Claim '{}' has no values", pred.claim_name)) - } - } - PredicateRule::Matches => { - let pattern = pred.expected_value.as_deref().unwrap_or(""); - let regex = match regex::Regex::new(pattern) { - Ok(r) => r, - Err(e) => { - return DatalogPredicateResult { - id: pred.id, - claim_name: pred.claim_name.clone(), - rule: pred.rule.clone(), - expected_value: pred.expected_value.clone(), - passed: false, - reason: format!("Invalid regex: {}", e), - source: pred.source, - notes: pred.notes.clone(), - }; - } - }; - - if let Some(values) = claim_values { - if values.iter().any(|v| regex.is_match(v)) { - (true, format!("Matches pattern '{}'", pattern)) - } else { - (false, format!("No value matches pattern '{}'", pattern)) - } - } else { - (false, format!("Claim '{}' has no values", pred.claim_name)) - } - } - PredicateRule::NotContains => { - let expected = pred.expected_value.as_deref().unwrap_or(""); - if let Some(values) = claim_values { - if values.contains(expected) { - (false, format!("Contains '{}' but should not", expected)) - } else { - (true, format!("Does not contain '{}'", expected)) - } - } else { - (true, format!("Claim '{}' has no values (not_contains passes vacuously)", pred.claim_name)) - } - } - PredicateRule::AnyOf => { - let expected_set: Vec<&str> = pred.expected_value.as_deref() - .map(|v| v.trim_matches(|c| c == '[' || c == ']') - .split(", ") - .collect()) - .unwrap_or_default(); - if let Some(values) = claim_values { - if values.iter().any(|v| expected_set.contains(v)) { - (true, format!("Value is in allowed set")) - } else { - (false, format!("Value is not in allowed set")) - } - } else { - (false, format!("Claim '{}' has no values", pred.claim_name)) - } - } - PredicateRule::NoneOf => { - let forbidden_set: Vec<&str> = pred.expected_value.as_deref() - .map(|v| v.trim_matches(|c| c == '[' || c == ']') - .split(", ") - .collect()) - .unwrap_or_default(); - if let Some(values) = claim_values { - if values.iter().any(|v| forbidden_set.contains(v)) { - (false, format!("Value is in forbidden set")) - } else { - (true, format!("Value is not in forbidden set")) - } - } else { - (true, format!("Claim '{}' has no values (none_of passes vacuously)", pred.claim_name)) - } - } + PredicateRule::Exists => eval_existence(claim_values, true), + PredicateRule::NotExists => eval_existence(claim_values, false), + PredicateRule::Contains => eval_membership(claim_values, pred, true), + PredicateRule::NotContains => eval_membership(claim_values, pred, false), + PredicateRule::Equals => eval_equals(claim_values, pred), + PredicateRule::MinLength => eval_length(fact_lookup, pred, |len, exp| len >= exp, ">=", "<", "minimum"), + PredicateRule::MaxLength => eval_length(fact_lookup, pred, |len, exp| len <= exp, "<=", ">", "maximum"), + PredicateRule::GreaterThan => eval_numeric_cmp(claim_values, pred, |a, e| a > e, ">"), + PredicateRule::LessThan => eval_numeric_cmp(claim_values, pred, |a, e| a < e, "<"), + PredicateRule::Matches => eval_matches(claim_values, pred), + PredicateRule::AnyOf => eval_set_membership(claim_values, pred, true), + PredicateRule::NoneOf => eval_set_membership(claim_values, pred, false), }; DatalogPredicateResult { @@ -658,6 +628,36 @@ fn escape_datalog_string(s: &str) -> String { .replace('\t', "\\t") } +/// Generate the Soufflé rule body for a single predicate. +/// +/// Returns the clause after `predicate_pass(id) :- `. +fn rule_body_for_predicate(rule: &PredicateRule, claim: &str, expected: &str) -> String { + match rule { + PredicateRule::Exists => + format!("claim_value(\"{}\", _)", claim), + PredicateRule::NotExists => + format!("!claim_value(\"{}\", _)", claim), + PredicateRule::Equals | PredicateRule::Contains => + format!("claim_value(\"{}\", \"{}\")", claim, expected), + PredicateRule::NotContains => + format!("!claim_value(\"{}\", \"{}\")", claim, expected), + PredicateRule::GreaterThan => + format!("claim_value(\"{}\", V), to_number(V, N), N > {}", claim, expected), + PredicateRule::LessThan => + format!("claim_value(\"{}\", V), to_number(V, N), N < {}", claim, expected), + PredicateRule::MinLength => + format!("claim_length(\"{}\", N), N >= {}", claim, expected), + PredicateRule::MaxLength => + format!("claim_length(\"{}\", N), N <= {}", claim, expected), + PredicateRule::Matches => + format!("claim_value(\"{}\", V), match(\"{}\", V)", claim, expected), + PredicateRule::AnyOf => + format!("claim_value(\"{}\", V), any_of(\"{}\", V)", claim, expected), + PredicateRule::NoneOf => + format!("!claim_value(\"{}\", V), none_of(\"{}\", V)", claim, expected), + } +} + /// Format a compiled rulespec and extracted facts as a datalog program. /// /// Produces a textual `.dl` file with: @@ -739,83 +739,8 @@ pub fn format_datalog_program( pred.notes.as_deref().map(|n| format!(" -- {}", n)).unwrap_or_default(), )); - match pred.rule { - PredicateRule::Exists => { - out.push_str(&format!( - "predicate_pass({}) :- claim_value(\"{}\", _).\n", - id, claim, - )); - } - PredicateRule::NotExists => { - // Pass when no matching fact exists - out.push_str(&format!( - "predicate_pass({}) :- !claim_value(\"{}\", _).\n", - id, claim, - )); - } - PredicateRule::Equals => { - out.push_str(&format!( - "predicate_pass({}) :- claim_value(\"{}\", \"{}\").\n", - id, claim, expected, - )); - } - PredicateRule::Contains => { - out.push_str(&format!( - "predicate_pass({}) :- claim_value(\"{}\", \"{}\").\n", - id, claim, expected, - )); - } - PredicateRule::GreaterThan => { - out.push_str(&format!( - "predicate_pass({}) :- claim_value(\"{}\", V), to_number(V, N), N > {}.\n", - id, claim, expected, - )); - } - PredicateRule::LessThan => { - out.push_str(&format!( - "predicate_pass({}) :- claim_value(\"{}\", V), to_number(V, N), N < {}.\n", - id, claim, expected, - )); - } - PredicateRule::MinLength => { - out.push_str(&format!( - "predicate_pass({}) :- claim_length(\"{}\", N), N >= {}.\n", - id, claim, expected, - )); - } - PredicateRule::MaxLength => { - out.push_str(&format!( - "predicate_pass({}) :- claim_length(\"{}\", N), N <= {}.\n", - id, claim, expected, - )); - } - PredicateRule::Matches => { - // Regex matching expressed as a match functor - out.push_str(&format!( - "predicate_pass({}) :- claim_value(\"{}\", V), match(\"{}\", V).\n", - id, claim, expected, - )); - } - PredicateRule::NotContains => { - out.push_str(&format!( - "predicate_pass({}) :- !claim_value(\"{}\", \"{}\").\n", - id, claim, expected, - )); - } - PredicateRule::AnyOf => { - // any_of: pass if claim value matches any element in the set - out.push_str(&format!( - "predicate_pass({}) :- claim_value(\"{}\", V), any_of(\"{}\", V).\n", - id, claim, expected, - )); - } - PredicateRule::NoneOf => { - out.push_str(&format!( - "predicate_pass({}) :- !claim_value(\"{}\", V), none_of(\"{}\", V).\n", - id, claim, expected, - )); - } - } + let body = rule_body_for_predicate(&pred.rule, &claim, &expected); + out.push_str(&format!("predicate_pass({}) :- {}.\n", id, body)); // Derive failure as the negation of pass out.push_str(&format!( @@ -828,7 +753,6 @@ pub fn format_datalog_program( out } // ============================================================================ -// ============================================================================ // Formatting // ============================================================================ diff --git a/crates/g3-core/src/tools/envelope.rs b/crates/g3-core/src/tools/envelope.rs index 27eed3f..d2ce87e 100644 --- a/crates/g3-core/src/tools/envelope.rs +++ b/crates/g3-core/src/tools/envelope.rs @@ -397,25 +397,17 @@ pub fn verify_envelope(session_id: &str, working_dir: &Path) -> String { } } - // Return a summary for the tool output - // NOTE: The token value is intentionally NOT included in the output - // returned to the LLM. It exists only in the envelope.yaml file. - let summary = if result.failed_count == 0 { + // Summary for tool output (token value intentionally omitted — LLM must not see it) + let total = result.passed_count + result.failed_count; + if result.failed_count == 0 { format!( - "\n✅ Invariant verification: {}/{} passed\n", - result.passed_count, - result.passed_count + result.failed_count, + "\n✅ Invariant verification: {}/{} passed\n", result.passed_count, total, ) } else { format!( - "\n⚠️ Invariant verification: {}/{} passed, {} failed\n", - result.passed_count, - result.passed_count + result.failed_count, - result.failed_count, + "\n⚠️ Invariant verification: {}/{} passed, {} failed\n", result.passed_count, total, result.failed_count, ) - }; - - summary + } } /// Stamp an envelope with a verification token and re-write it to disk. @@ -432,17 +424,13 @@ fn stamp_envelope( ) -> Result<()> { let key = get_or_create_verification_key()?; - // Compute token over the original envelope (without any previous verified field) + // Compute token over the envelope without any previous verified field let mut clean_envelope = envelope.clone(); clean_envelope.verified = None; let token = mint_token(&key, &clean_envelope, rulespec); - // Set the verified field and re-write - let mut stamped = envelope.clone(); - stamped.verified = Some(token); - write_envelope(session_id, &stamped)?; - - Ok(()) + clean_envelope.verified = Some(token); + write_envelope(session_id, &clean_envelope) } // ============================================================================ diff --git a/crates/g3-core/src/tools/invariants.rs b/crates/g3-core/src/tools/invariants.rs index ab4eb21..d22c52a 100644 --- a/crates/g3-core/src/tools/invariants.rs +++ b/crates/g3-core/src/tools/invariants.rs @@ -559,234 +559,139 @@ pub fn evaluate_predicate( selected_values: &[YamlValue], ) -> PredicateResult { match predicate.rule { - PredicateRule::Exists => { - // Filter out null values — null means "absent" - let non_null: Vec<_> = selected_values.iter() - .filter(|v| !v.is_null()) - .collect(); - if non_null.is_empty() { - PredicateResult::fail("Value does not exist") + PredicateRule::Exists => eval_yaml_existence(selected_values, true), + PredicateRule::NotExists => eval_yaml_existence(selected_values, false), + PredicateRule::Contains => eval_yaml_containment(predicate, selected_values, true), + PredicateRule::NotContains => eval_yaml_containment(predicate, selected_values, false), + PredicateRule::Equals => eval_yaml_equals(predicate, selected_values), + PredicateRule::MinLength => eval_yaml_length(predicate, selected_values, |len, n| len >= n, "min"), + PredicateRule::MaxLength => eval_yaml_length(predicate, selected_values, |len, n| len <= n, "max"), + PredicateRule::GreaterThan => eval_yaml_numeric(predicate, selected_values, |v, t| v > t, ">"), + PredicateRule::LessThan => eval_yaml_numeric(predicate, selected_values, |v, t| v < t, "<"), + PredicateRule::Matches => eval_yaml_matches(predicate, selected_values), + PredicateRule::AnyOf => eval_yaml_set(predicate, selected_values, true), + PredicateRule::NoneOf => eval_yaml_set(predicate, selected_values, false), + } +} + +// ── Per-rule evaluation helpers ───────────────────────────────────────── + +/// Exists / NotExists: any non-null value counts as "present". +fn eval_yaml_existence(values: &[YamlValue], expect_present: bool) -> PredicateResult { + let has_non_null = values.iter().any(|v| !v.is_null()); + match (expect_present, has_non_null) { + (true, true) => PredicateResult::pass("Value exists"), + (true, false) => PredicateResult::fail("Value does not exist"), + (false, false) => PredicateResult::pass("Value does not exist as expected"), + (false, true) => PredicateResult::fail("Value exists but should not"), + } +} + +/// Contains / NotContains: search selected values using `value_contains`. +fn eval_yaml_containment(pred: &Predicate, values: &[YamlValue], expect_present: bool) -> PredicateResult { + let Some(target) = &pred.value else { + let rule_name = if expect_present { "contains" } else { "not_contains" }; + return PredicateResult::fail(format!("No value specified for {}", rule_name)); + }; + let found = values.iter().any(|v| value_contains(v, target)); + let display = yaml_to_display(target); + match (expect_present, found) { + (true, true) => PredicateResult::pass(format!("Value contains {:?}", display)), + (true, false) => PredicateResult::fail(format!("Value does not contain {:?}", display)), + (false, false) => PredicateResult::pass(format!("Value does not contain {:?}", display)), + (false, true) => PredicateResult::fail(format!("Value contains {:?} but should not", display)), + } +} + +fn eval_yaml_equals(pred: &Predicate, values: &[YamlValue]) -> PredicateResult { + let Some(target) = &pred.value else { + return PredicateResult::fail("No value specified for equals"); + }; + if values.len() != 1 { + return PredicateResult::fail(format!("Expected single value for equals, got {}", values.len())); + } + if &values[0] == target { + PredicateResult::pass("Values are equal") + } else { + PredicateResult::fail(format!("Values not equal: {:?} != {:?}", yaml_to_display(&values[0]), yaml_to_display(target))) + } +} + +/// MinLength / MaxLength: find the first Sequence and compare its length. +fn eval_yaml_length(pred: &Predicate, values: &[YamlValue], cmp: fn(usize, usize) -> bool, label: &str) -> PredicateResult { + let threshold = match &pred.value { + Some(YamlValue::Number(n)) => n.as_u64().unwrap_or(0) as usize, + _ => return PredicateResult::fail(format!("{}_length requires a numeric value", label)), + }; + for value in values { + if let YamlValue::Sequence(seq) = value { + return if cmp(seq.len(), threshold) { + PredicateResult::pass(format!("Array has {} elements ({}: {})", seq.len(), label, threshold)) } else { - PredicateResult::pass("Value exists") - } + PredicateResult::fail(format!("Array has {} elements ({}: {})", seq.len(), label, threshold)) + }; } - PredicateRule::NotExists => { - // Filter out null values — null means "absent" - let non_null: Vec<_> = selected_values.iter() - .filter(|v| !v.is_null()) - .collect(); - if non_null.is_empty() { - PredicateResult::pass("Value does not exist as expected") + } + PredicateResult::fail("Value is not an array") +} + +/// GreaterThan / LessThan: find the first Number and compare. +fn eval_yaml_numeric(pred: &Predicate, values: &[YamlValue], cmp: fn(f64, f64) -> bool, op: &str) -> PredicateResult { + let target = match &pred.value { + Some(YamlValue::Number(n)) => n.as_f64().unwrap_or(0.0), + _ => return PredicateResult::fail(format!("{} requires a numeric value", pred.rule)), + }; + for value in values { + if let YamlValue::Number(n) = value { + let v = n.as_f64().unwrap_or(0.0); + return if cmp(v, target) { + PredicateResult::pass(format!("{} {} {}", v, op, target)) } else { - PredicateResult::fail("Value exists but should not") + PredicateResult::fail(format!("{} is not {} {}", v, op, target)) + }; + } + } + PredicateResult::fail("Value is not a number") +} + +fn eval_yaml_matches(pred: &Predicate, values: &[YamlValue]) -> PredicateResult { + let Some(YamlValue::String(pattern)) = &pred.value else { + return PredicateResult::fail("matches requires a string pattern"); + }; + let regex = match regex::Regex::new(pattern) { + Ok(r) => r, + Err(e) => return PredicateResult::fail(format!("Invalid regex: {}", e)), + }; + for value in values { + if let YamlValue::String(s) = value { + if regex.is_match(s) { + return PredicateResult::pass(format!("'{}' matches pattern", s)); } } - PredicateRule::Contains => { - let target = match &predicate.value { - Some(v) => v, - None => return PredicateResult::fail("No value specified for contains"), - }; - - for value in selected_values { - if value_contains(value, target) { - return PredicateResult::pass(format!( - "Value contains {:?}", - yaml_to_display(target) - )); - } - } - PredicateResult::fail(format!( - "Value does not contain {:?}", - yaml_to_display(target) - )) + } + PredicateResult::fail(format!("No value matches pattern '{}'", pattern)) +} + +/// AnyOf / NoneOf: check if selected values are in (or not in) a set. +fn eval_yaml_set(pred: &Predicate, values: &[YamlValue], expect_in_set: bool) -> PredicateResult { + let label = if expect_in_set { "any_of" } else { "none_of" }; + let set = match &pred.value { + Some(YamlValue::Sequence(seq)) => seq, + Some(_) => return PredicateResult::fail(format!("{} requires an array value", label)), + None => return PredicateResult::fail(format!("No value specified for {}", label)), + }; + let found = values.iter().any(|v| set.contains(v)); + let set_display = set.iter().map(yaml_to_display).collect::>().join(", "); + match (expect_in_set, found) { + (true, true) => { + let matched = values.iter().find(|v| set.contains(v)).unwrap(); + PredicateResult::pass(format!("Value {:?} is in allowed set", yaml_to_display(matched))) } - PredicateRule::Equals => { - let target = match &predicate.value { - Some(v) => v, - None => return PredicateResult::fail("No value specified for equals"), - }; - - if selected_values.len() != 1 { - return PredicateResult::fail(format!( - "Expected single value for equals, got {}", - selected_values.len() - )); - } - - if &selected_values[0] == target { - PredicateResult::pass("Values are equal") - } else { - PredicateResult::fail(format!( - "Values not equal: {:?} != {:?}", - yaml_to_display(&selected_values[0]), - yaml_to_display(target) - )) - } - } - PredicateRule::MinLength => { - let min = match &predicate.value { - Some(YamlValue::Number(n)) => n.as_u64().unwrap_or(0) as usize, - _ => return PredicateResult::fail("min_length requires a numeric value"), - }; - - for value in selected_values { - if let YamlValue::Sequence(seq) = value { - if seq.len() >= min { - return PredicateResult::pass(format!( - "Array has {} elements (min: {})", - seq.len(), - min - )); - } else { - return PredicateResult::fail(format!( - "Array has {} elements (min: {})", - seq.len(), - min - )); - } - } - } - PredicateResult::fail("Value is not an array") - } - PredicateRule::MaxLength => { - let max = match &predicate.value { - Some(YamlValue::Number(n)) => n.as_u64().unwrap_or(0) as usize, - _ => return PredicateResult::fail("max_length requires a numeric value"), - }; - - for value in selected_values { - if let YamlValue::Sequence(seq) = value { - if seq.len() <= max { - return PredicateResult::pass(format!( - "Array has {} elements (max: {})", - seq.len(), - max - )); - } else { - return PredicateResult::fail(format!( - "Array has {} elements (max: {})", - seq.len(), - max - )); - } - } - } - PredicateResult::fail("Value is not an array") - } - PredicateRule::GreaterThan => { - let target = match &predicate.value { - Some(YamlValue::Number(n)) => n.as_f64().unwrap_or(0.0), - _ => return PredicateResult::fail("greater_than requires a numeric value"), - }; - - for value in selected_values { - if let YamlValue::Number(n) = value { - let v = n.as_f64().unwrap_or(0.0); - if v > target { - return PredicateResult::pass(format!("{} > {}", v, target)); - } else { - return PredicateResult::fail(format!("{} is not > {}", v, target)); - } - } - } - PredicateResult::fail("Value is not a number") - } - PredicateRule::LessThan => { - let target = match &predicate.value { - Some(YamlValue::Number(n)) => n.as_f64().unwrap_or(0.0), - _ => return PredicateResult::fail("less_than requires a numeric value"), - }; - - for value in selected_values { - if let YamlValue::Number(n) = value { - let v = n.as_f64().unwrap_or(0.0); - if v < target { - return PredicateResult::pass(format!("{} < {}", v, target)); - } else { - return PredicateResult::fail(format!("{} is not < {}", v, target)); - } - } - } - PredicateResult::fail("Value is not a number") - } - PredicateRule::Matches => { - let pattern = match &predicate.value { - Some(YamlValue::String(s)) => s, - _ => return PredicateResult::fail("matches requires a string pattern"), - }; - - let regex = match regex::Regex::new(pattern) { - Ok(r) => r, - Err(e) => return PredicateResult::fail(format!("Invalid regex: {}", e)), - }; - - for value in selected_values { - if let YamlValue::String(s) = value { - if regex.is_match(s) { - return PredicateResult::pass(format!("'{}' matches pattern", s)); - } - } - } - PredicateResult::fail(format!("No value matches pattern '{}'", pattern)) - } - PredicateRule::NotContains => { - let target = match &predicate.value { - Some(v) => v, - None => return PredicateResult::fail("No value specified for not_contains"), - }; - - for value in selected_values { - if value_contains(value, target) { - return PredicateResult::fail(format!( - "Value contains {:?} but should not", - yaml_to_display(target) - )); - } - } - PredicateResult::pass(format!( - "Value does not contain {:?}", - yaml_to_display(target) - )) - } - PredicateRule::AnyOf => { - let allowed = match &predicate.value { - Some(YamlValue::Sequence(seq)) => seq, - Some(_) => return PredicateResult::fail("any_of requires an array value"), - None => return PredicateResult::fail("No value specified for any_of"), - }; - - for value in selected_values { - if allowed.contains(value) { - return PredicateResult::pass(format!( - "Value {:?} is in allowed set", - yaml_to_display(value) - )); - } - } - PredicateResult::fail(format!( - "Value is not in allowed set [{}]", - allowed.iter().map(yaml_to_display).collect::>().join(", ") - )) - } - PredicateRule::NoneOf => { - let forbidden = match &predicate.value { - Some(YamlValue::Sequence(seq)) => seq, - Some(_) => return PredicateResult::fail("none_of requires an array value"), - None => return PredicateResult::fail("No value specified for none_of"), - }; - - for value in selected_values { - if forbidden.contains(value) { - return PredicateResult::fail(format!( - "Value {:?} is in forbidden set", - yaml_to_display(value) - )); - } - } - PredicateResult::pass(format!( - "Value is not in forbidden set [{}]", - forbidden.iter().map(yaml_to_display).collect::>().join(", ") - )) + (true, false) => PredicateResult::fail(format!("Value is not in allowed set [{}]", set_display)), + (false, false) => PredicateResult::pass(format!("Value is not in forbidden set [{}]", set_display)), + (false, true) => { + let matched = values.iter().find(|v| set.contains(v)).unwrap(); + PredicateResult::fail(format!("Value {:?} is in forbidden set", yaml_to_display(matched))) } } } diff --git a/crates/g3-providers/src/anthropic.rs b/crates/g3-providers/src/anthropic.rs index 2e6ea76..2f8485a 100644 --- a/crates/g3-providers/src/anthropic.rs +++ b/crates/g3-providers/src/anthropic.rs @@ -135,6 +135,134 @@ pub struct AnthropicProvider { thinking_budget_tokens: Option, } +// ── SSE Stream State ──────────────────────────────────────────────────── +// Mutable state threaded through Anthropic's SSE stream parser. +// Each `handle_*` method processes one event type and returns chunks to send. + +struct StreamState { + tool_calls: Vec, + partial_tool_json: String, + usage: Option, + message_stopped: bool, + stop_reason: Option, +} + +impl StreamState { + fn new() -> Self { + Self { + tool_calls: Vec::new(), + partial_tool_json: String::new(), + usage: None, + message_stopped: false, + stop_reason: None, + } + } + + fn handle_message_start(&mut self, event: &AnthropicStreamEvent) { + if let Some(message) = &event.message { + if let Some(u) = &message.usage { + self.usage = Some(Usage { + prompt_tokens: u.input_tokens, + completion_tokens: u.output_tokens, + total_tokens: u.input_tokens + u.output_tokens, + cache_creation_tokens: u.cache_creation_input_tokens, + cache_read_tokens: u.cache_read_input_tokens, + }); + debug!("Captured usage from message_start: {:?}", self.usage); + } + } + } + + /// Returns chunks to send for a content_block_start event. + fn handle_block_start(&mut self, event: AnthropicStreamEvent) -> Vec> { + let Some(content_block) = event.content_block else { return vec![] }; + match content_block { + AnthropicContent::ToolUse { id, name, input } => { + debug!("Tool use block: id={}, name={}, input={:?}", id, name, input); + let tool_call = ToolCall { id: id.clone(), tool: name.clone(), args: input.clone() }; + + let has_complete_args = !input.is_null() + && input != serde_json::Value::Object(serde_json::Map::new()); + + if has_complete_args { + debug!("Tool call has complete args, sending immediately"); + vec![Ok(make_tool_chunk(vec![tool_call]))] + } else { + debug!("Tool call has empty args, will accumulate from partial_json"); + let hint = make_tool_streaming_hint(name); + self.tool_calls.push(tool_call); + self.partial_tool_json.clear(); + vec![Ok(hint)] + } + } + _ => { + debug!("Non-tool content block: {:?}", content_block); + vec![] + } + } + } + + /// Returns chunks to send for a content_block_delta event. + fn handle_block_delta(&mut self, event: AnthropicStreamEvent) -> Vec> { + let Some(delta) = event.delta else { return vec![] }; + let mut chunks = Vec::new(); + if let Some(text) = delta.text { + debug!("Text chunk (len {})", text.len()); + chunks.push(Ok(make_text_chunk(text))); + } + if let Some(json_fragment) = delta.partial_json { + debug!("Partial JSON: {}", json_fragment); + self.partial_tool_json.push_str(&json_fragment); + chunks.push(Ok(make_tool_streaming_active())); + } + chunks + } + + /// Returns chunks to send when a content block finishes. + fn handle_block_stop(&mut self) -> Vec> { + // Finalize accumulated partial JSON into the last tool call's args + if !self.tool_calls.is_empty() && !self.partial_tool_json.is_empty() { + debug!("Parsing complete tool JSON: {}", self.partial_tool_json); + if let Ok(parsed) = serde_json::from_str::(&self.partial_tool_json) { + if let Some(last) = self.tool_calls.last_mut() { + last.args = parsed; + debug!("Updated tool call with complete args: {:?}", last); + } + } else { + debug!("Failed to parse accumulated JSON: {}", self.partial_tool_json); + } + self.partial_tool_json.clear(); + } + + if self.tool_calls.is_empty() { + return vec![]; + } + let chunk = make_tool_chunk(self.tool_calls.clone()); + self.tool_calls.clear(); + vec![Ok(chunk)] + } + + fn handle_message_delta(&mut self, event: &AnthropicStreamEvent) { + if let Some(delta) = &event.delta { + if let Some(reason) = &delta.stop_reason { + debug!("Received stop_reason: {}", reason); + self.stop_reason = Some(reason.clone()); + } + } + } + + fn handle_message_stop(&mut self) -> Vec> { + debug!("Received message stop event"); + self.message_stopped = true; + let chunk = make_final_chunk_with_reason( + self.tool_calls.clone(), + self.usage.clone(), + self.stop_reason.clone(), + ); + vec![Ok(chunk)] + } +} + impl AnthropicProvider { pub fn new( api_key: String, @@ -501,267 +629,76 @@ impl AnthropicProvider { mut stream: impl futures_util::Stream> + Unpin, tx: mpsc::Sender>, ) -> Option { - let mut buffer = String::new(); - let mut current_tool_calls: Vec = Vec::new(); - let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls - let mut accumulated_usage: Option = None; - let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences - let mut message_stopped = false; // Track if we've received message_stop - let mut stop_reason: Option = None; // Track why the message stopped + let mut state = StreamState::new(); + let mut line_buffer = String::new(); + let mut byte_buffer: Vec = Vec::new(); while let Some(chunk_result) = stream.next().await { match chunk_result { Ok(chunk) => { byte_buffer.extend_from_slice(&chunk); - let Some(chunk_str) = decode_utf8_streaming(&mut byte_buffer) else { continue; }; + line_buffer.push_str(&chunk_str); - buffer.push_str(&chunk_str); + while let Some(line_end) = line_buffer.find('\n') { + let line = line_buffer[..line_end].trim().to_string(); + line_buffer.drain(..line_end + 1); - // Process complete lines - while let Some(line_end) = buffer.find('\n') { - let line = buffer[..line_end].trim().to_string(); - buffer.drain(..line_end + 1); - - if line.is_empty() { - continue; - } - - // If we've already sent the final chunk, skip processing more events - if message_stopped { - debug!("Skipping event after message_stop: {}", line); - continue; - } - - // Parse Server-Sent Events format - if let Some(data) = line.strip_prefix("data: ") { - if data == "[DONE]" { - debug!("Received stream completion marker"); - let final_chunk = make_final_chunk( - current_tool_calls.clone(), - accumulated_usage.clone(), - ); - if tx.send(Ok(final_chunk)).await.is_err() { - debug!("Receiver dropped, stopping stream"); - } - return accumulated_usage; + if line.is_empty() || state.message_stopped { + if state.message_stopped && !line.is_empty() { + debug!("Skipping event after message_stop: {}", line); } + continue; + } - debug!("Raw Claude API JSON: {}", data); + let Some(data) = line.strip_prefix("data: ") else { continue }; - match serde_json::from_str::(data) { - Ok(event) => { - debug!( - "Parsed event type: {}, event: {:?}", - event.event_type, event - ); - match event.event_type.as_str() { - "message_start" => { - // Extract usage data from message_start event - if let Some(message) = event.message { - if let Some(usage) = message.usage { - accumulated_usage = Some(Usage { - prompt_tokens: usage.input_tokens, - completion_tokens: usage.output_tokens, - total_tokens: usage.input_tokens - + usage.output_tokens, - cache_creation_tokens: usage - .cache_creation_input_tokens, - cache_read_tokens: usage - .cache_read_input_tokens, - }); - debug!( - "Captured usage from message_start: {:?}", - accumulated_usage - ); - } - } - } - "content_block_start" => { - debug!( - "Received content_block_start event: {:?}", - event - ); - if let Some(content_block) = event.content_block { - match content_block { - AnthropicContent::ToolUse { - id, - name, - input, - } => { - debug!("Found tool use in content_block_start: id={}, name={}, input={:?}", id, name, input); + // Stream completion marker + if data == "[DONE]" { + debug!("Received stream completion marker"); + let final_chunk = make_final_chunk(state.tool_calls.clone(), state.usage.clone()); + let _ = tx.send(Ok(final_chunk)).await; + return state.usage; + } - // For native tool calls, create the tool call immediately if we have complete args - // If args are empty, we'll wait for partial_json to accumulate them - let tool_call = ToolCall { - id: id.clone(), - tool: name.clone(), - args: input.clone(), - }; + debug!("Raw Claude API JSON: {}", data); - // Check if we already have complete arguments - if !input.is_null() - && input - != serde_json::Value::Object( - serde_json::Map::new(), - ) - { - // We have complete arguments, send the tool call immediately - debug!("Tool call has complete args, sending immediately: {:?}", tool_call); - let chunk = - make_tool_chunk(vec![tool_call]); - if tx.send(Ok(chunk)).await.is_err() { - debug!("Receiver dropped, stopping stream"); - return accumulated_usage; - } - } else { - // Arguments are empty, we'll accumulate them from partial_json - debug!("Tool call has empty args, will accumulate from partial_json"); - // Send a streaming hint so the UI can show the tool name immediately - let hint_chunk = make_tool_streaming_hint(name.clone()); - if tx.send(Ok(hint_chunk)).await.is_err() { - debug!("Receiver dropped, stopping stream"); - return accumulated_usage; - } - current_tool_calls.push(tool_call); - partial_tool_json.clear(); - } - } - _ => { - debug!( - "Non-tool content block: {:?}", - content_block - ); - } - } - } - } - "content_block_delta" => { - if let Some(delta) = event.delta { - if let Some(text) = delta.text { - debug!( - "Sending text chunk of length {}: '{}'", - text.len(), - text - ); - let chunk = make_text_chunk(text); - if tx.send(Ok(chunk)).await.is_err() { - debug!("Receiver dropped, stopping stream"); - return accumulated_usage; - } - } - // Handle partial JSON for tool calls - if let Some(partial_json) = delta.partial_json { - debug!( - "Received partial JSON: {}", - partial_json - ); - partial_tool_json.push_str(&partial_json); - debug!( - "Accumulated tool JSON: {}", - partial_tool_json - ); - // Send an active hint to trigger UI blink - let active_chunk = make_tool_streaming_active(); - if tx.send(Ok(active_chunk)).await.is_err() { - debug!("Receiver dropped, stopping stream"); - return accumulated_usage; - } - } - } - } - "content_block_stop" => { - // Tool call block is complete - now parse the accumulated JSON - if !current_tool_calls.is_empty() - && !partial_tool_json.is_empty() - { - debug!( - "Parsing complete tool JSON: {}", - partial_tool_json - ); + let event = match serde_json::from_str::(data) { + Ok(e) => e, + Err(e) => { + debug!("Failed to parse stream event: {} - Data: {}", e, data); + continue; + } + }; - // Parse the accumulated JSON and update the last tool call - if let Ok(parsed_args) = - serde_json::from_str::( - &partial_tool_json, - ) - { - if let Some(last_tool) = - current_tool_calls.last_mut() - { - last_tool.args = parsed_args; - debug!("Updated tool call with complete args: {:?}", last_tool); - } - } else { - debug!( - "Failed to parse accumulated JSON: {}", - partial_tool_json - ); - } + debug!("Parsed event type: {}", event.event_type); - // Clear the accumulator - partial_tool_json.clear(); - } - - // Send the complete tool call - if !current_tool_calls.is_empty() { - let chunk = - make_tool_chunk(current_tool_calls.clone()); - if tx.send(Ok(chunk)).await.is_err() { - debug!("Receiver dropped, stopping stream"); - return accumulated_usage; - } - // Clear tool calls after sending to prevent duplicates at message_stop - current_tool_calls.clear(); - } - } - "message_delta" => { - // message_delta contains the stop_reason and final usage - if let Some(delta) = &event.delta { - if let Some(reason) = &delta.stop_reason { - debug!("Received stop_reason: {}", reason); - stop_reason = Some(reason.clone()); - } - } - // Usage is also in message_delta but we get it from message_start - } - "message_stop" => { - debug!("Received message stop event"); - message_stopped = true; - let final_chunk = make_final_chunk_with_reason( - current_tool_calls.clone(), - accumulated_usage.clone(), - stop_reason.clone(), - ); - if tx.send(Ok(final_chunk)).await.is_err() { - debug!("Receiver dropped, stopping stream"); - } - // Don't return here - let the stream naturally exhaust - // This prevents dropping the sender prematurely - } - "error" => { - if let Some(error) = event.error { - error!("Anthropic API error: {:?}", error); - let _ = tx - .send(Err(anyhow!( - "Anthropic API error: {:?}", - error - ))) - .await; - break; // Break to let stream exhaust naturally - } - } - _ => { - debug!("Ignoring event type: {}", event.event_type); - } - } - } - Err(e) => { - debug!("Failed to parse stream event: {} - Data: {}", e, data); - // Don't error out on parse failures, just continue + // Dispatch to per-event handlers; collect chunks to send + let chunks: Vec> = match event.event_type.as_str() { + "message_start" => { state.handle_message_start(&event); vec![] } + "content_block_start" => state.handle_block_start(event), + "content_block_delta" => state.handle_block_delta(event), + "content_block_stop" => state.handle_block_stop(), + "message_delta" => { state.handle_message_delta(&event); vec![] } + "message_stop" => state.handle_message_stop(), + "error" => { + if let Some(error) = event.error { + error!("Anthropic API error: {:?}", error); + let _ = tx.send(Err(anyhow!("Anthropic API error: {:?}", error))).await; + break; } + vec![] + } + _ => { debug!("Ignoring event type: {}", event.event_type); vec![] } + }; + + // Send all chunks produced by the handler + for chunk in chunks { + if tx.send(chunk).await.is_err() { + debug!("Receiver dropped, stopping stream"); + return state.usage; } } } @@ -769,18 +706,14 @@ impl AnthropicProvider { Err(e) => { error!("Stream error: {}", e); let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await; - // Don't return here either - let the stream exhaust naturally - // The error has been sent to the receiver, so it will handle it - // Breaking here ensures we clean up properly break; } } } - // Send final chunk if we haven't already - let final_chunk = make_final_chunk(current_tool_calls, accumulated_usage.clone()); + let final_chunk = make_final_chunk(state.tool_calls, state.usage.clone()); let _ = tx.send(Ok(final_chunk)).await; - accumulated_usage + state.usage } }