Readability refactor: extract mega-functions into focused helpers
Agent: carmack 4 files refactored, net -250 lines, all tests passing (417 + 71). datalog.rs: - Extract 7 predicate evaluation helpers from evaluate_predicate_datalog() (~200-line match → 12-line dispatch table) - Extract rule_body_for_predicate() from format_datalog_program() (~75-line match → 2-line call) invariants.rs: - Extract 7 per-rule helpers from evaluate_predicate() (~230-line match → 12-line dispatch table) envelope.rs: - Simplify summary construction in verify_envelope() - Eliminate redundant clone in stamp_envelope() anthropic.rs: - Introduce StreamState struct with 6 handler methods - parse_streaming_response: ~290 lines → ~90 lines - Max nesting depth reduced from 8 to 4 levels
This commit is contained in:
@@ -425,7 +425,161 @@ pub fn execute_rules(
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate a single predicate using the fact lookup.
|
||||
// ── Predicate evaluation helpers ────────────────────────────────────────
|
||||
// Each returns (passed: bool, reason: String) for a specific rule type.
|
||||
|
||||
/// Exists / NotExists: check whether the claim has any values.
|
||||
/// `expect_present = true` → Exists; `false` → NotExists.
|
||||
fn eval_existence(claim_values: Option<&HashSet<&str>>, expect_present: bool) -> (bool, String) {
|
||||
let has_values = claim_values.map_or(false, |v| !v.is_empty());
|
||||
if expect_present == has_values {
|
||||
if has_values {
|
||||
(true, "Value exists".into())
|
||||
} else {
|
||||
(true, "Value does not exist as expected".into())
|
||||
}
|
||||
} else if expect_present {
|
||||
(false, "Value does not exist".into())
|
||||
} else {
|
||||
(false, "Value exists but should not".into())
|
||||
}
|
||||
}
|
||||
|
||||
/// Contains / NotContains: check whether a specific value is among the claim's facts.
|
||||
/// `expect_present = true` → Contains; `false` → NotContains.
|
||||
fn eval_membership(
|
||||
claim_values: Option<&HashSet<&str>>,
|
||||
pred: &CompiledPredicate,
|
||||
expect_present: bool,
|
||||
) -> (bool, String) {
|
||||
let expected = pred.expected_value.as_deref().unwrap_or("");
|
||||
match claim_values {
|
||||
Some(values) => {
|
||||
let found = values.contains(expected);
|
||||
if expect_present {
|
||||
if found {
|
||||
(true, format!("Contains '{}'", expected))
|
||||
} else {
|
||||
(false, format!("Does not contain '{}'", expected))
|
||||
}
|
||||
} else if found {
|
||||
(false, format!("Contains '{}' but should not", expected))
|
||||
} else {
|
||||
(true, format!("Does not contain '{}'", expected))
|
||||
}
|
||||
}
|
||||
None if expect_present => (false, format!("Claim '{}' has no values", pred.claim_name)),
|
||||
None => (true, format!("Claim '{}' has no values (not_contains passes vacuously)", pred.claim_name)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Equals: exactly one value must match the expected string.
|
||||
fn eval_equals(claim_values: Option<&HashSet<&str>>, pred: &CompiledPredicate) -> (bool, String) {
|
||||
let expected = pred.expected_value.as_deref().unwrap_or("");
|
||||
let Some(values) = claim_values else {
|
||||
return (false, format!("Claim '{}' has no values", pred.claim_name));
|
||||
};
|
||||
if values.len() == 1 && values.contains(expected) {
|
||||
(true, format!("Equals '{}'", expected))
|
||||
} else if values.len() > 1 {
|
||||
(false, format!("Multiple values found, expected single value '{}'", expected))
|
||||
} else {
|
||||
let actual = values.iter().next().copied().unwrap_or("<none>");
|
||||
(false, format!("Expected '{}', got '{}'", expected, actual))
|
||||
}
|
||||
}
|
||||
|
||||
/// MinLength / MaxLength: compare the `__length` fact against a threshold.
|
||||
fn eval_length(
|
||||
fact_lookup: &HashMap<&str, HashSet<&str>>,
|
||||
pred: &CompiledPredicate,
|
||||
cmp: impl Fn(usize, usize) -> bool,
|
||||
pass_op: &str,
|
||||
fail_op: &str,
|
||||
label: &str,
|
||||
) -> (bool, String) {
|
||||
let expected: usize = pred.expected_value.as_deref()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(0);
|
||||
let length_key = format!("{}.__length", pred.claim_name);
|
||||
let length: usize = fact_lookup
|
||||
.get(length_key.as_str())
|
||||
.and_then(|v| v.iter().next())
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(0);
|
||||
if cmp(length, expected) {
|
||||
(true, format!("Length {} {} {}", length, pass_op, expected))
|
||||
} else {
|
||||
(false, format!("Length {} {} {} ({})", length, fail_op, expected, label))
|
||||
}
|
||||
}
|
||||
|
||||
/// GreaterThan / LessThan: parse the first claim value as f64 and compare.
|
||||
fn eval_numeric_cmp(
|
||||
claim_values: Option<&HashSet<&str>>,
|
||||
pred: &CompiledPredicate,
|
||||
cmp: impl Fn(f64, f64) -> bool,
|
||||
op: &str,
|
||||
) -> (bool, String) {
|
||||
let expected: f64 = pred.expected_value.as_deref()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(0.0);
|
||||
let Some(values) = claim_values else {
|
||||
return (false, format!("Claim '{}' has no values", pred.claim_name));
|
||||
};
|
||||
match values.iter().next().and_then(|s| s.parse::<f64>().ok()) {
|
||||
Some(actual) if cmp(actual, expected) => (true, format!("{} {} {}", actual, op, expected)),
|
||||
Some(actual) => (false, format!("{} is not {} {}", actual, op, expected)),
|
||||
None => (false, "Value is not a number".into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Matches: check if any claim value matches a regex pattern.
|
||||
fn eval_matches(claim_values: Option<&HashSet<&str>>, pred: &CompiledPredicate) -> (bool, String) {
|
||||
let pattern = pred.expected_value.as_deref().unwrap_or("");
|
||||
let regex = match regex::Regex::new(pattern) {
|
||||
Ok(r) => r,
|
||||
Err(e) => return (false, format!("Invalid regex: {}", e)),
|
||||
};
|
||||
let Some(values) = claim_values else {
|
||||
return (false, format!("Claim '{}' has no values", pred.claim_name));
|
||||
};
|
||||
if values.iter().any(|v| regex.is_match(v)) {
|
||||
(true, format!("Matches pattern '{}'", pattern))
|
||||
} else {
|
||||
(false, format!("No value matches pattern '{}'", pattern))
|
||||
}
|
||||
}
|
||||
|
||||
/// AnyOf / NoneOf: parse the expected value as a bracketed set and check membership.
|
||||
/// `expect_in_set = true` → AnyOf (pass if value in set); `false` → NoneOf.
|
||||
fn eval_set_membership(
|
||||
claim_values: Option<&HashSet<&str>>,
|
||||
pred: &CompiledPredicate,
|
||||
expect_in_set: bool,
|
||||
) -> (bool, String) {
|
||||
let set: Vec<&str> = pred.expected_value.as_deref()
|
||||
.map(|v| v.trim_matches(|c| c == '[' || c == ']').split(", ").collect())
|
||||
.unwrap_or_default();
|
||||
let Some(values) = claim_values else {
|
||||
return if expect_in_set {
|
||||
(false, format!("Claim '{}' has no values", pred.claim_name))
|
||||
} else {
|
||||
(true, format!("Claim '{}' has no values (none_of passes vacuously)", pred.claim_name))
|
||||
};
|
||||
};
|
||||
let found = values.iter().any(|v| set.contains(v));
|
||||
if expect_in_set {
|
||||
if found { (true, "Value is in allowed set".into()) }
|
||||
else { (false, "Value is not in allowed set".into()) }
|
||||
} else if found {
|
||||
(false, "Value is in forbidden set".into())
|
||||
} else {
|
||||
(true, "Value is not in forbidden set".into())
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate a single predicate against the fact lookup table.
|
||||
fn evaluate_predicate_datalog(
|
||||
pred: &CompiledPredicate,
|
||||
fact_lookup: &HashMap<&str, HashSet<&str>>,
|
||||
@@ -433,202 +587,18 @@ fn evaluate_predicate_datalog(
|
||||
let claim_values = fact_lookup.get(pred.claim_name.as_str());
|
||||
|
||||
let (passed, reason) = match pred.rule {
|
||||
PredicateRule::Exists => {
|
||||
if claim_values.is_some() && !claim_values.unwrap().is_empty() {
|
||||
(true, "Value exists".to_string())
|
||||
} else {
|
||||
(false, "Value does not exist".to_string())
|
||||
}
|
||||
}
|
||||
PredicateRule::NotExists => {
|
||||
if claim_values.is_none() || claim_values.unwrap().is_empty() {
|
||||
(true, "Value does not exist as expected".to_string())
|
||||
} else {
|
||||
(false, "Value exists but should not".to_string())
|
||||
}
|
||||
}
|
||||
PredicateRule::Contains => {
|
||||
let expected = pred.expected_value.as_deref().unwrap_or("");
|
||||
if let Some(values) = claim_values {
|
||||
if values.contains(expected) {
|
||||
(true, format!("Contains '{}'", expected))
|
||||
} else {
|
||||
(false, format!("Does not contain '{}'", expected))
|
||||
}
|
||||
} else {
|
||||
(false, format!("Claim '{}' has no values", pred.claim_name))
|
||||
}
|
||||
}
|
||||
PredicateRule::Equals => {
|
||||
let expected = pred.expected_value.as_deref().unwrap_or("");
|
||||
if let Some(values) = claim_values {
|
||||
if values.len() == 1 && values.contains(expected) {
|
||||
(true, format!("Equals '{}'", expected))
|
||||
} else if values.len() > 1 {
|
||||
(false, format!("Multiple values found, expected single value '{}'", expected))
|
||||
} else {
|
||||
let actual = values.iter().next().map(|s| *s).unwrap_or("<none>");
|
||||
(false, format!("Expected '{}', got '{}'", expected, actual))
|
||||
}
|
||||
} else {
|
||||
(false, format!("Claim '{}' has no values", pred.claim_name))
|
||||
}
|
||||
}
|
||||
PredicateRule::MinLength => {
|
||||
let expected: usize = pred
|
||||
.expected_value
|
||||
.as_deref()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(0);
|
||||
|
||||
// Check the __length fact
|
||||
let length_claim = format!("{}.__length", pred.claim_name);
|
||||
let length = fact_lookup
|
||||
.get(length_claim.as_str())
|
||||
.and_then(|v| v.iter().next())
|
||||
.and_then(|s| s.parse::<usize>().ok())
|
||||
.unwrap_or(0);
|
||||
|
||||
if length >= expected {
|
||||
(true, format!("Length {} >= {}", length, expected))
|
||||
} else {
|
||||
(false, format!("Length {} < {} (minimum)", length, expected))
|
||||
}
|
||||
}
|
||||
PredicateRule::MaxLength => {
|
||||
let expected: usize = pred
|
||||
.expected_value
|
||||
.as_deref()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(usize::MAX);
|
||||
|
||||
let length_claim = format!("{}.__length", pred.claim_name);
|
||||
let length = fact_lookup
|
||||
.get(length_claim.as_str())
|
||||
.and_then(|v| v.iter().next())
|
||||
.and_then(|s| s.parse::<usize>().ok())
|
||||
.unwrap_or(0);
|
||||
|
||||
if length <= expected {
|
||||
(true, format!("Length {} <= {}", length, expected))
|
||||
} else {
|
||||
(false, format!("Length {} > {} (maximum)", length, expected))
|
||||
}
|
||||
}
|
||||
PredicateRule::GreaterThan => {
|
||||
let expected: f64 = pred
|
||||
.expected_value
|
||||
.as_deref()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(0.0);
|
||||
|
||||
if let Some(values) = claim_values {
|
||||
if let Some(actual) = values.iter().next().and_then(|s| s.parse::<f64>().ok()) {
|
||||
if actual > expected {
|
||||
(true, format!("{} > {}", actual, expected))
|
||||
} else {
|
||||
(false, format!("{} is not > {}", actual, expected))
|
||||
}
|
||||
} else {
|
||||
(false, "Value is not a number".to_string())
|
||||
}
|
||||
} else {
|
||||
(false, format!("Claim '{}' has no values", pred.claim_name))
|
||||
}
|
||||
}
|
||||
PredicateRule::LessThan => {
|
||||
let expected: f64 = pred
|
||||
.expected_value
|
||||
.as_deref()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(0.0);
|
||||
|
||||
if let Some(values) = claim_values {
|
||||
if let Some(actual) = values.iter().next().and_then(|s| s.parse::<f64>().ok()) {
|
||||
if actual < expected {
|
||||
(true, format!("{} < {}", actual, expected))
|
||||
} else {
|
||||
(false, format!("{} is not < {}", actual, expected))
|
||||
}
|
||||
} else {
|
||||
(false, "Value is not a number".to_string())
|
||||
}
|
||||
} else {
|
||||
(false, format!("Claim '{}' has no values", pred.claim_name))
|
||||
}
|
||||
}
|
||||
PredicateRule::Matches => {
|
||||
let pattern = pred.expected_value.as_deref().unwrap_or("");
|
||||
let regex = match regex::Regex::new(pattern) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
return DatalogPredicateResult {
|
||||
id: pred.id,
|
||||
claim_name: pred.claim_name.clone(),
|
||||
rule: pred.rule.clone(),
|
||||
expected_value: pred.expected_value.clone(),
|
||||
passed: false,
|
||||
reason: format!("Invalid regex: {}", e),
|
||||
source: pred.source,
|
||||
notes: pred.notes.clone(),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(values) = claim_values {
|
||||
if values.iter().any(|v| regex.is_match(v)) {
|
||||
(true, format!("Matches pattern '{}'", pattern))
|
||||
} else {
|
||||
(false, format!("No value matches pattern '{}'", pattern))
|
||||
}
|
||||
} else {
|
||||
(false, format!("Claim '{}' has no values", pred.claim_name))
|
||||
}
|
||||
}
|
||||
PredicateRule::NotContains => {
|
||||
let expected = pred.expected_value.as_deref().unwrap_or("");
|
||||
if let Some(values) = claim_values {
|
||||
if values.contains(expected) {
|
||||
(false, format!("Contains '{}' but should not", expected))
|
||||
} else {
|
||||
(true, format!("Does not contain '{}'", expected))
|
||||
}
|
||||
} else {
|
||||
(true, format!("Claim '{}' has no values (not_contains passes vacuously)", pred.claim_name))
|
||||
}
|
||||
}
|
||||
PredicateRule::AnyOf => {
|
||||
let expected_set: Vec<&str> = pred.expected_value.as_deref()
|
||||
.map(|v| v.trim_matches(|c| c == '[' || c == ']')
|
||||
.split(", ")
|
||||
.collect())
|
||||
.unwrap_or_default();
|
||||
if let Some(values) = claim_values {
|
||||
if values.iter().any(|v| expected_set.contains(v)) {
|
||||
(true, format!("Value is in allowed set"))
|
||||
} else {
|
||||
(false, format!("Value is not in allowed set"))
|
||||
}
|
||||
} else {
|
||||
(false, format!("Claim '{}' has no values", pred.claim_name))
|
||||
}
|
||||
}
|
||||
PredicateRule::NoneOf => {
|
||||
let forbidden_set: Vec<&str> = pred.expected_value.as_deref()
|
||||
.map(|v| v.trim_matches(|c| c == '[' || c == ']')
|
||||
.split(", ")
|
||||
.collect())
|
||||
.unwrap_or_default();
|
||||
if let Some(values) = claim_values {
|
||||
if values.iter().any(|v| forbidden_set.contains(v)) {
|
||||
(false, format!("Value is in forbidden set"))
|
||||
} else {
|
||||
(true, format!("Value is not in forbidden set"))
|
||||
}
|
||||
} else {
|
||||
(true, format!("Claim '{}' has no values (none_of passes vacuously)", pred.claim_name))
|
||||
}
|
||||
}
|
||||
PredicateRule::Exists => eval_existence(claim_values, true),
|
||||
PredicateRule::NotExists => eval_existence(claim_values, false),
|
||||
PredicateRule::Contains => eval_membership(claim_values, pred, true),
|
||||
PredicateRule::NotContains => eval_membership(claim_values, pred, false),
|
||||
PredicateRule::Equals => eval_equals(claim_values, pred),
|
||||
PredicateRule::MinLength => eval_length(fact_lookup, pred, |len, exp| len >= exp, ">=", "<", "minimum"),
|
||||
PredicateRule::MaxLength => eval_length(fact_lookup, pred, |len, exp| len <= exp, "<=", ">", "maximum"),
|
||||
PredicateRule::GreaterThan => eval_numeric_cmp(claim_values, pred, |a, e| a > e, ">"),
|
||||
PredicateRule::LessThan => eval_numeric_cmp(claim_values, pred, |a, e| a < e, "<"),
|
||||
PredicateRule::Matches => eval_matches(claim_values, pred),
|
||||
PredicateRule::AnyOf => eval_set_membership(claim_values, pred, true),
|
||||
PredicateRule::NoneOf => eval_set_membership(claim_values, pred, false),
|
||||
};
|
||||
|
||||
DatalogPredicateResult {
|
||||
@@ -658,6 +628,36 @@ fn escape_datalog_string(s: &str) -> String {
|
||||
.replace('\t', "\\t")
|
||||
}
|
||||
|
||||
/// Generate the Soufflé rule body for a single predicate.
|
||||
///
|
||||
/// Returns the clause after `predicate_pass(id) :- `.
|
||||
fn rule_body_for_predicate(rule: &PredicateRule, claim: &str, expected: &str) -> String {
|
||||
match rule {
|
||||
PredicateRule::Exists =>
|
||||
format!("claim_value(\"{}\", _)", claim),
|
||||
PredicateRule::NotExists =>
|
||||
format!("!claim_value(\"{}\", _)", claim),
|
||||
PredicateRule::Equals | PredicateRule::Contains =>
|
||||
format!("claim_value(\"{}\", \"{}\")", claim, expected),
|
||||
PredicateRule::NotContains =>
|
||||
format!("!claim_value(\"{}\", \"{}\")", claim, expected),
|
||||
PredicateRule::GreaterThan =>
|
||||
format!("claim_value(\"{}\", V), to_number(V, N), N > {}", claim, expected),
|
||||
PredicateRule::LessThan =>
|
||||
format!("claim_value(\"{}\", V), to_number(V, N), N < {}", claim, expected),
|
||||
PredicateRule::MinLength =>
|
||||
format!("claim_length(\"{}\", N), N >= {}", claim, expected),
|
||||
PredicateRule::MaxLength =>
|
||||
format!("claim_length(\"{}\", N), N <= {}", claim, expected),
|
||||
PredicateRule::Matches =>
|
||||
format!("claim_value(\"{}\", V), match(\"{}\", V)", claim, expected),
|
||||
PredicateRule::AnyOf =>
|
||||
format!("claim_value(\"{}\", V), any_of(\"{}\", V)", claim, expected),
|
||||
PredicateRule::NoneOf =>
|
||||
format!("!claim_value(\"{}\", V), none_of(\"{}\", V)", claim, expected),
|
||||
}
|
||||
}
|
||||
|
||||
/// Format a compiled rulespec and extracted facts as a datalog program.
|
||||
///
|
||||
/// Produces a textual `.dl` file with:
|
||||
@@ -739,83 +739,8 @@ pub fn format_datalog_program(
|
||||
pred.notes.as_deref().map(|n| format!(" -- {}", n)).unwrap_or_default(),
|
||||
));
|
||||
|
||||
match pred.rule {
|
||||
PredicateRule::Exists => {
|
||||
out.push_str(&format!(
|
||||
"predicate_pass({}) :- claim_value(\"{}\", _).\n",
|
||||
id, claim,
|
||||
));
|
||||
}
|
||||
PredicateRule::NotExists => {
|
||||
// Pass when no matching fact exists
|
||||
out.push_str(&format!(
|
||||
"predicate_pass({}) :- !claim_value(\"{}\", _).\n",
|
||||
id, claim,
|
||||
));
|
||||
}
|
||||
PredicateRule::Equals => {
|
||||
out.push_str(&format!(
|
||||
"predicate_pass({}) :- claim_value(\"{}\", \"{}\").\n",
|
||||
id, claim, expected,
|
||||
));
|
||||
}
|
||||
PredicateRule::Contains => {
|
||||
out.push_str(&format!(
|
||||
"predicate_pass({}) :- claim_value(\"{}\", \"{}\").\n",
|
||||
id, claim, expected,
|
||||
));
|
||||
}
|
||||
PredicateRule::GreaterThan => {
|
||||
out.push_str(&format!(
|
||||
"predicate_pass({}) :- claim_value(\"{}\", V), to_number(V, N), N > {}.\n",
|
||||
id, claim, expected,
|
||||
));
|
||||
}
|
||||
PredicateRule::LessThan => {
|
||||
out.push_str(&format!(
|
||||
"predicate_pass({}) :- claim_value(\"{}\", V), to_number(V, N), N < {}.\n",
|
||||
id, claim, expected,
|
||||
));
|
||||
}
|
||||
PredicateRule::MinLength => {
|
||||
out.push_str(&format!(
|
||||
"predicate_pass({}) :- claim_length(\"{}\", N), N >= {}.\n",
|
||||
id, claim, expected,
|
||||
));
|
||||
}
|
||||
PredicateRule::MaxLength => {
|
||||
out.push_str(&format!(
|
||||
"predicate_pass({}) :- claim_length(\"{}\", N), N <= {}.\n",
|
||||
id, claim, expected,
|
||||
));
|
||||
}
|
||||
PredicateRule::Matches => {
|
||||
// Regex matching expressed as a match functor
|
||||
out.push_str(&format!(
|
||||
"predicate_pass({}) :- claim_value(\"{}\", V), match(\"{}\", V).\n",
|
||||
id, claim, expected,
|
||||
));
|
||||
}
|
||||
PredicateRule::NotContains => {
|
||||
out.push_str(&format!(
|
||||
"predicate_pass({}) :- !claim_value(\"{}\", \"{}\").\n",
|
||||
id, claim, expected,
|
||||
));
|
||||
}
|
||||
PredicateRule::AnyOf => {
|
||||
// any_of: pass if claim value matches any element in the set
|
||||
out.push_str(&format!(
|
||||
"predicate_pass({}) :- claim_value(\"{}\", V), any_of(\"{}\", V).\n",
|
||||
id, claim, expected,
|
||||
));
|
||||
}
|
||||
PredicateRule::NoneOf => {
|
||||
out.push_str(&format!(
|
||||
"predicate_pass({}) :- !claim_value(\"{}\", V), none_of(\"{}\", V).\n",
|
||||
id, claim, expected,
|
||||
));
|
||||
}
|
||||
}
|
||||
let body = rule_body_for_predicate(&pred.rule, &claim, &expected);
|
||||
out.push_str(&format!("predicate_pass({}) :- {}.\n", id, body));
|
||||
|
||||
// Derive failure as the negation of pass
|
||||
out.push_str(&format!(
|
||||
@@ -828,7 +753,6 @@ pub fn format_datalog_program(
|
||||
out
|
||||
}
|
||||
// ============================================================================
|
||||
// ============================================================================
|
||||
// Formatting
|
||||
// ============================================================================
|
||||
|
||||
|
||||
@@ -397,25 +397,17 @@ pub fn verify_envelope(session_id: &str, working_dir: &Path) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
// Return a summary for the tool output
|
||||
// NOTE: The token value is intentionally NOT included in the output
|
||||
// returned to the LLM. It exists only in the envelope.yaml file.
|
||||
let summary = if result.failed_count == 0 {
|
||||
// Summary for tool output (token value intentionally omitted — LLM must not see it)
|
||||
let total = result.passed_count + result.failed_count;
|
||||
if result.failed_count == 0 {
|
||||
format!(
|
||||
"\n✅ Invariant verification: {}/{} passed\n",
|
||||
result.passed_count,
|
||||
result.passed_count + result.failed_count,
|
||||
"\n✅ Invariant verification: {}/{} passed\n", result.passed_count, total,
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"\n⚠️ Invariant verification: {}/{} passed, {} failed\n",
|
||||
result.passed_count,
|
||||
result.passed_count + result.failed_count,
|
||||
result.failed_count,
|
||||
"\n⚠️ Invariant verification: {}/{} passed, {} failed\n", result.passed_count, total, result.failed_count,
|
||||
)
|
||||
};
|
||||
|
||||
summary
|
||||
}
|
||||
}
|
||||
|
||||
/// Stamp an envelope with a verification token and re-write it to disk.
|
||||
@@ -432,17 +424,13 @@ fn stamp_envelope(
|
||||
) -> Result<()> {
|
||||
let key = get_or_create_verification_key()?;
|
||||
|
||||
// Compute token over the original envelope (without any previous verified field)
|
||||
// Compute token over the envelope without any previous verified field
|
||||
let mut clean_envelope = envelope.clone();
|
||||
clean_envelope.verified = None;
|
||||
let token = mint_token(&key, &clean_envelope, rulespec);
|
||||
|
||||
// Set the verified field and re-write
|
||||
let mut stamped = envelope.clone();
|
||||
stamped.verified = Some(token);
|
||||
write_envelope(session_id, &stamped)?;
|
||||
|
||||
Ok(())
|
||||
clean_envelope.verified = Some(token);
|
||||
write_envelope(session_id, &clean_envelope)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
|
||||
@@ -559,234 +559,139 @@ pub fn evaluate_predicate(
|
||||
selected_values: &[YamlValue],
|
||||
) -> PredicateResult {
|
||||
match predicate.rule {
|
||||
PredicateRule::Exists => {
|
||||
// Filter out null values — null means "absent"
|
||||
let non_null: Vec<_> = selected_values.iter()
|
||||
.filter(|v| !v.is_null())
|
||||
.collect();
|
||||
if non_null.is_empty() {
|
||||
PredicateResult::fail("Value does not exist")
|
||||
PredicateRule::Exists => eval_yaml_existence(selected_values, true),
|
||||
PredicateRule::NotExists => eval_yaml_existence(selected_values, false),
|
||||
PredicateRule::Contains => eval_yaml_containment(predicate, selected_values, true),
|
||||
PredicateRule::NotContains => eval_yaml_containment(predicate, selected_values, false),
|
||||
PredicateRule::Equals => eval_yaml_equals(predicate, selected_values),
|
||||
PredicateRule::MinLength => eval_yaml_length(predicate, selected_values, |len, n| len >= n, "min"),
|
||||
PredicateRule::MaxLength => eval_yaml_length(predicate, selected_values, |len, n| len <= n, "max"),
|
||||
PredicateRule::GreaterThan => eval_yaml_numeric(predicate, selected_values, |v, t| v > t, ">"),
|
||||
PredicateRule::LessThan => eval_yaml_numeric(predicate, selected_values, |v, t| v < t, "<"),
|
||||
PredicateRule::Matches => eval_yaml_matches(predicate, selected_values),
|
||||
PredicateRule::AnyOf => eval_yaml_set(predicate, selected_values, true),
|
||||
PredicateRule::NoneOf => eval_yaml_set(predicate, selected_values, false),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Per-rule evaluation helpers ─────────────────────────────────────────
|
||||
|
||||
/// Exists / NotExists: any non-null value counts as "present".
|
||||
fn eval_yaml_existence(values: &[YamlValue], expect_present: bool) -> PredicateResult {
|
||||
let has_non_null = values.iter().any(|v| !v.is_null());
|
||||
match (expect_present, has_non_null) {
|
||||
(true, true) => PredicateResult::pass("Value exists"),
|
||||
(true, false) => PredicateResult::fail("Value does not exist"),
|
||||
(false, false) => PredicateResult::pass("Value does not exist as expected"),
|
||||
(false, true) => PredicateResult::fail("Value exists but should not"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Contains / NotContains: search selected values using `value_contains`.
|
||||
fn eval_yaml_containment(pred: &Predicate, values: &[YamlValue], expect_present: bool) -> PredicateResult {
|
||||
let Some(target) = &pred.value else {
|
||||
let rule_name = if expect_present { "contains" } else { "not_contains" };
|
||||
return PredicateResult::fail(format!("No value specified for {}", rule_name));
|
||||
};
|
||||
let found = values.iter().any(|v| value_contains(v, target));
|
||||
let display = yaml_to_display(target);
|
||||
match (expect_present, found) {
|
||||
(true, true) => PredicateResult::pass(format!("Value contains {:?}", display)),
|
||||
(true, false) => PredicateResult::fail(format!("Value does not contain {:?}", display)),
|
||||
(false, false) => PredicateResult::pass(format!("Value does not contain {:?}", display)),
|
||||
(false, true) => PredicateResult::fail(format!("Value contains {:?} but should not", display)),
|
||||
}
|
||||
}
|
||||
|
||||
fn eval_yaml_equals(pred: &Predicate, values: &[YamlValue]) -> PredicateResult {
|
||||
let Some(target) = &pred.value else {
|
||||
return PredicateResult::fail("No value specified for equals");
|
||||
};
|
||||
if values.len() != 1 {
|
||||
return PredicateResult::fail(format!("Expected single value for equals, got {}", values.len()));
|
||||
}
|
||||
if &values[0] == target {
|
||||
PredicateResult::pass("Values are equal")
|
||||
} else {
|
||||
PredicateResult::fail(format!("Values not equal: {:?} != {:?}", yaml_to_display(&values[0]), yaml_to_display(target)))
|
||||
}
|
||||
}
|
||||
|
||||
/// MinLength / MaxLength: find the first Sequence and compare its length.
|
||||
fn eval_yaml_length(pred: &Predicate, values: &[YamlValue], cmp: fn(usize, usize) -> bool, label: &str) -> PredicateResult {
|
||||
let threshold = match &pred.value {
|
||||
Some(YamlValue::Number(n)) => n.as_u64().unwrap_or(0) as usize,
|
||||
_ => return PredicateResult::fail(format!("{}_length requires a numeric value", label)),
|
||||
};
|
||||
for value in values {
|
||||
if let YamlValue::Sequence(seq) = value {
|
||||
return if cmp(seq.len(), threshold) {
|
||||
PredicateResult::pass(format!("Array has {} elements ({}: {})", seq.len(), label, threshold))
|
||||
} else {
|
||||
PredicateResult::pass("Value exists")
|
||||
}
|
||||
PredicateResult::fail(format!("Array has {} elements ({}: {})", seq.len(), label, threshold))
|
||||
};
|
||||
}
|
||||
PredicateRule::NotExists => {
|
||||
// Filter out null values — null means "absent"
|
||||
let non_null: Vec<_> = selected_values.iter()
|
||||
.filter(|v| !v.is_null())
|
||||
.collect();
|
||||
if non_null.is_empty() {
|
||||
PredicateResult::pass("Value does not exist as expected")
|
||||
}
|
||||
PredicateResult::fail("Value is not an array")
|
||||
}
|
||||
|
||||
/// GreaterThan / LessThan: find the first Number and compare.
|
||||
fn eval_yaml_numeric(pred: &Predicate, values: &[YamlValue], cmp: fn(f64, f64) -> bool, op: &str) -> PredicateResult {
|
||||
let target = match &pred.value {
|
||||
Some(YamlValue::Number(n)) => n.as_f64().unwrap_or(0.0),
|
||||
_ => return PredicateResult::fail(format!("{} requires a numeric value", pred.rule)),
|
||||
};
|
||||
for value in values {
|
||||
if let YamlValue::Number(n) = value {
|
||||
let v = n.as_f64().unwrap_or(0.0);
|
||||
return if cmp(v, target) {
|
||||
PredicateResult::pass(format!("{} {} {}", v, op, target))
|
||||
} else {
|
||||
PredicateResult::fail("Value exists but should not")
|
||||
PredicateResult::fail(format!("{} is not {} {}", v, op, target))
|
||||
};
|
||||
}
|
||||
}
|
||||
PredicateResult::fail("Value is not a number")
|
||||
}
|
||||
|
||||
fn eval_yaml_matches(pred: &Predicate, values: &[YamlValue]) -> PredicateResult {
|
||||
let Some(YamlValue::String(pattern)) = &pred.value else {
|
||||
return PredicateResult::fail("matches requires a string pattern");
|
||||
};
|
||||
let regex = match regex::Regex::new(pattern) {
|
||||
Ok(r) => r,
|
||||
Err(e) => return PredicateResult::fail(format!("Invalid regex: {}", e)),
|
||||
};
|
||||
for value in values {
|
||||
if let YamlValue::String(s) = value {
|
||||
if regex.is_match(s) {
|
||||
return PredicateResult::pass(format!("'{}' matches pattern", s));
|
||||
}
|
||||
}
|
||||
PredicateRule::Contains => {
|
||||
let target = match &predicate.value {
|
||||
Some(v) => v,
|
||||
None => return PredicateResult::fail("No value specified for contains"),
|
||||
};
|
||||
|
||||
for value in selected_values {
|
||||
if value_contains(value, target) {
|
||||
return PredicateResult::pass(format!(
|
||||
"Value contains {:?}",
|
||||
yaml_to_display(target)
|
||||
));
|
||||
}
|
||||
}
|
||||
PredicateResult::fail(format!(
|
||||
"Value does not contain {:?}",
|
||||
yaml_to_display(target)
|
||||
))
|
||||
}
|
||||
PredicateResult::fail(format!("No value matches pattern '{}'", pattern))
|
||||
}
|
||||
|
||||
/// AnyOf / NoneOf: check if selected values are in (or not in) a set.
|
||||
fn eval_yaml_set(pred: &Predicate, values: &[YamlValue], expect_in_set: bool) -> PredicateResult {
|
||||
let label = if expect_in_set { "any_of" } else { "none_of" };
|
||||
let set = match &pred.value {
|
||||
Some(YamlValue::Sequence(seq)) => seq,
|
||||
Some(_) => return PredicateResult::fail(format!("{} requires an array value", label)),
|
||||
None => return PredicateResult::fail(format!("No value specified for {}", label)),
|
||||
};
|
||||
let found = values.iter().any(|v| set.contains(v));
|
||||
let set_display = set.iter().map(yaml_to_display).collect::<Vec<_>>().join(", ");
|
||||
match (expect_in_set, found) {
|
||||
(true, true) => {
|
||||
let matched = values.iter().find(|v| set.contains(v)).unwrap();
|
||||
PredicateResult::pass(format!("Value {:?} is in allowed set", yaml_to_display(matched)))
|
||||
}
|
||||
PredicateRule::Equals => {
|
||||
let target = match &predicate.value {
|
||||
Some(v) => v,
|
||||
None => return PredicateResult::fail("No value specified for equals"),
|
||||
};
|
||||
|
||||
if selected_values.len() != 1 {
|
||||
return PredicateResult::fail(format!(
|
||||
"Expected single value for equals, got {}",
|
||||
selected_values.len()
|
||||
));
|
||||
}
|
||||
|
||||
if &selected_values[0] == target {
|
||||
PredicateResult::pass("Values are equal")
|
||||
} else {
|
||||
PredicateResult::fail(format!(
|
||||
"Values not equal: {:?} != {:?}",
|
||||
yaml_to_display(&selected_values[0]),
|
||||
yaml_to_display(target)
|
||||
))
|
||||
}
|
||||
}
|
||||
PredicateRule::MinLength => {
|
||||
let min = match &predicate.value {
|
||||
Some(YamlValue::Number(n)) => n.as_u64().unwrap_or(0) as usize,
|
||||
_ => return PredicateResult::fail("min_length requires a numeric value"),
|
||||
};
|
||||
|
||||
for value in selected_values {
|
||||
if let YamlValue::Sequence(seq) = value {
|
||||
if seq.len() >= min {
|
||||
return PredicateResult::pass(format!(
|
||||
"Array has {} elements (min: {})",
|
||||
seq.len(),
|
||||
min
|
||||
));
|
||||
} else {
|
||||
return PredicateResult::fail(format!(
|
||||
"Array has {} elements (min: {})",
|
||||
seq.len(),
|
||||
min
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
PredicateResult::fail("Value is not an array")
|
||||
}
|
||||
PredicateRule::MaxLength => {
|
||||
let max = match &predicate.value {
|
||||
Some(YamlValue::Number(n)) => n.as_u64().unwrap_or(0) as usize,
|
||||
_ => return PredicateResult::fail("max_length requires a numeric value"),
|
||||
};
|
||||
|
||||
for value in selected_values {
|
||||
if let YamlValue::Sequence(seq) = value {
|
||||
if seq.len() <= max {
|
||||
return PredicateResult::pass(format!(
|
||||
"Array has {} elements (max: {})",
|
||||
seq.len(),
|
||||
max
|
||||
));
|
||||
} else {
|
||||
return PredicateResult::fail(format!(
|
||||
"Array has {} elements (max: {})",
|
||||
seq.len(),
|
||||
max
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
PredicateResult::fail("Value is not an array")
|
||||
}
|
||||
PredicateRule::GreaterThan => {
|
||||
let target = match &predicate.value {
|
||||
Some(YamlValue::Number(n)) => n.as_f64().unwrap_or(0.0),
|
||||
_ => return PredicateResult::fail("greater_than requires a numeric value"),
|
||||
};
|
||||
|
||||
for value in selected_values {
|
||||
if let YamlValue::Number(n) = value {
|
||||
let v = n.as_f64().unwrap_or(0.0);
|
||||
if v > target {
|
||||
return PredicateResult::pass(format!("{} > {}", v, target));
|
||||
} else {
|
||||
return PredicateResult::fail(format!("{} is not > {}", v, target));
|
||||
}
|
||||
}
|
||||
}
|
||||
PredicateResult::fail("Value is not a number")
|
||||
}
|
||||
PredicateRule::LessThan => {
|
||||
let target = match &predicate.value {
|
||||
Some(YamlValue::Number(n)) => n.as_f64().unwrap_or(0.0),
|
||||
_ => return PredicateResult::fail("less_than requires a numeric value"),
|
||||
};
|
||||
|
||||
for value in selected_values {
|
||||
if let YamlValue::Number(n) = value {
|
||||
let v = n.as_f64().unwrap_or(0.0);
|
||||
if v < target {
|
||||
return PredicateResult::pass(format!("{} < {}", v, target));
|
||||
} else {
|
||||
return PredicateResult::fail(format!("{} is not < {}", v, target));
|
||||
}
|
||||
}
|
||||
}
|
||||
PredicateResult::fail("Value is not a number")
|
||||
}
|
||||
PredicateRule::Matches => {
|
||||
let pattern = match &predicate.value {
|
||||
Some(YamlValue::String(s)) => s,
|
||||
_ => return PredicateResult::fail("matches requires a string pattern"),
|
||||
};
|
||||
|
||||
let regex = match regex::Regex::new(pattern) {
|
||||
Ok(r) => r,
|
||||
Err(e) => return PredicateResult::fail(format!("Invalid regex: {}", e)),
|
||||
};
|
||||
|
||||
for value in selected_values {
|
||||
if let YamlValue::String(s) = value {
|
||||
if regex.is_match(s) {
|
||||
return PredicateResult::pass(format!("'{}' matches pattern", s));
|
||||
}
|
||||
}
|
||||
}
|
||||
PredicateResult::fail(format!("No value matches pattern '{}'", pattern))
|
||||
}
|
||||
PredicateRule::NotContains => {
|
||||
let target = match &predicate.value {
|
||||
Some(v) => v,
|
||||
None => return PredicateResult::fail("No value specified for not_contains"),
|
||||
};
|
||||
|
||||
for value in selected_values {
|
||||
if value_contains(value, target) {
|
||||
return PredicateResult::fail(format!(
|
||||
"Value contains {:?} but should not",
|
||||
yaml_to_display(target)
|
||||
));
|
||||
}
|
||||
}
|
||||
PredicateResult::pass(format!(
|
||||
"Value does not contain {:?}",
|
||||
yaml_to_display(target)
|
||||
))
|
||||
}
|
||||
PredicateRule::AnyOf => {
|
||||
let allowed = match &predicate.value {
|
||||
Some(YamlValue::Sequence(seq)) => seq,
|
||||
Some(_) => return PredicateResult::fail("any_of requires an array value"),
|
||||
None => return PredicateResult::fail("No value specified for any_of"),
|
||||
};
|
||||
|
||||
for value in selected_values {
|
||||
if allowed.contains(value) {
|
||||
return PredicateResult::pass(format!(
|
||||
"Value {:?} is in allowed set",
|
||||
yaml_to_display(value)
|
||||
));
|
||||
}
|
||||
}
|
||||
PredicateResult::fail(format!(
|
||||
"Value is not in allowed set [{}]",
|
||||
allowed.iter().map(yaml_to_display).collect::<Vec<_>>().join(", ")
|
||||
))
|
||||
}
|
||||
PredicateRule::NoneOf => {
|
||||
let forbidden = match &predicate.value {
|
||||
Some(YamlValue::Sequence(seq)) => seq,
|
||||
Some(_) => return PredicateResult::fail("none_of requires an array value"),
|
||||
None => return PredicateResult::fail("No value specified for none_of"),
|
||||
};
|
||||
|
||||
for value in selected_values {
|
||||
if forbidden.contains(value) {
|
||||
return PredicateResult::fail(format!(
|
||||
"Value {:?} is in forbidden set",
|
||||
yaml_to_display(value)
|
||||
));
|
||||
}
|
||||
}
|
||||
PredicateResult::pass(format!(
|
||||
"Value is not in forbidden set [{}]",
|
||||
forbidden.iter().map(yaml_to_display).collect::<Vec<_>>().join(", ")
|
||||
))
|
||||
(true, false) => PredicateResult::fail(format!("Value is not in allowed set [{}]", set_display)),
|
||||
(false, false) => PredicateResult::pass(format!("Value is not in forbidden set [{}]", set_display)),
|
||||
(false, true) => {
|
||||
let matched = values.iter().find(|v| set.contains(v)).unwrap();
|
||||
PredicateResult::fail(format!("Value {:?} is in forbidden set", yaml_to_display(matched)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,6 +135,134 @@ pub struct AnthropicProvider {
|
||||
thinking_budget_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
// ── SSE Stream State ────────────────────────────────────────────────────
|
||||
// Mutable state threaded through Anthropic's SSE stream parser.
|
||||
// Each `handle_*` method processes one event type and returns chunks to send.
|
||||
|
||||
struct StreamState {
|
||||
tool_calls: Vec<ToolCall>,
|
||||
partial_tool_json: String,
|
||||
usage: Option<Usage>,
|
||||
message_stopped: bool,
|
||||
stop_reason: Option<String>,
|
||||
}
|
||||
|
||||
impl StreamState {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
tool_calls: Vec::new(),
|
||||
partial_tool_json: String::new(),
|
||||
usage: None,
|
||||
message_stopped: false,
|
||||
stop_reason: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_message_start(&mut self, event: &AnthropicStreamEvent) {
|
||||
if let Some(message) = &event.message {
|
||||
if let Some(u) = &message.usage {
|
||||
self.usage = Some(Usage {
|
||||
prompt_tokens: u.input_tokens,
|
||||
completion_tokens: u.output_tokens,
|
||||
total_tokens: u.input_tokens + u.output_tokens,
|
||||
cache_creation_tokens: u.cache_creation_input_tokens,
|
||||
cache_read_tokens: u.cache_read_input_tokens,
|
||||
});
|
||||
debug!("Captured usage from message_start: {:?}", self.usage);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns chunks to send for a content_block_start event.
|
||||
fn handle_block_start(&mut self, event: AnthropicStreamEvent) -> Vec<Result<CompletionChunk>> {
|
||||
let Some(content_block) = event.content_block else { return vec![] };
|
||||
match content_block {
|
||||
AnthropicContent::ToolUse { id, name, input } => {
|
||||
debug!("Tool use block: id={}, name={}, input={:?}", id, name, input);
|
||||
let tool_call = ToolCall { id: id.clone(), tool: name.clone(), args: input.clone() };
|
||||
|
||||
let has_complete_args = !input.is_null()
|
||||
&& input != serde_json::Value::Object(serde_json::Map::new());
|
||||
|
||||
if has_complete_args {
|
||||
debug!("Tool call has complete args, sending immediately");
|
||||
vec![Ok(make_tool_chunk(vec![tool_call]))]
|
||||
} else {
|
||||
debug!("Tool call has empty args, will accumulate from partial_json");
|
||||
let hint = make_tool_streaming_hint(name);
|
||||
self.tool_calls.push(tool_call);
|
||||
self.partial_tool_json.clear();
|
||||
vec![Ok(hint)]
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
debug!("Non-tool content block: {:?}", content_block);
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns chunks to send for a content_block_delta event.
|
||||
fn handle_block_delta(&mut self, event: AnthropicStreamEvent) -> Vec<Result<CompletionChunk>> {
|
||||
let Some(delta) = event.delta else { return vec![] };
|
||||
let mut chunks = Vec::new();
|
||||
if let Some(text) = delta.text {
|
||||
debug!("Text chunk (len {})", text.len());
|
||||
chunks.push(Ok(make_text_chunk(text)));
|
||||
}
|
||||
if let Some(json_fragment) = delta.partial_json {
|
||||
debug!("Partial JSON: {}", json_fragment);
|
||||
self.partial_tool_json.push_str(&json_fragment);
|
||||
chunks.push(Ok(make_tool_streaming_active()));
|
||||
}
|
||||
chunks
|
||||
}
|
||||
|
||||
/// Returns chunks to send when a content block finishes.
|
||||
fn handle_block_stop(&mut self) -> Vec<Result<CompletionChunk>> {
|
||||
// Finalize accumulated partial JSON into the last tool call's args
|
||||
if !self.tool_calls.is_empty() && !self.partial_tool_json.is_empty() {
|
||||
debug!("Parsing complete tool JSON: {}", self.partial_tool_json);
|
||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&self.partial_tool_json) {
|
||||
if let Some(last) = self.tool_calls.last_mut() {
|
||||
last.args = parsed;
|
||||
debug!("Updated tool call with complete args: {:?}", last);
|
||||
}
|
||||
} else {
|
||||
debug!("Failed to parse accumulated JSON: {}", self.partial_tool_json);
|
||||
}
|
||||
self.partial_tool_json.clear();
|
||||
}
|
||||
|
||||
if self.tool_calls.is_empty() {
|
||||
return vec![];
|
||||
}
|
||||
let chunk = make_tool_chunk(self.tool_calls.clone());
|
||||
self.tool_calls.clear();
|
||||
vec![Ok(chunk)]
|
||||
}
|
||||
|
||||
fn handle_message_delta(&mut self, event: &AnthropicStreamEvent) {
|
||||
if let Some(delta) = &event.delta {
|
||||
if let Some(reason) = &delta.stop_reason {
|
||||
debug!("Received stop_reason: {}", reason);
|
||||
self.stop_reason = Some(reason.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_message_stop(&mut self) -> Vec<Result<CompletionChunk>> {
|
||||
debug!("Received message stop event");
|
||||
self.message_stopped = true;
|
||||
let chunk = make_final_chunk_with_reason(
|
||||
self.tool_calls.clone(),
|
||||
self.usage.clone(),
|
||||
self.stop_reason.clone(),
|
||||
);
|
||||
vec![Ok(chunk)]
|
||||
}
|
||||
}
|
||||
|
||||
impl AnthropicProvider {
|
||||
pub fn new(
|
||||
api_key: String,
|
||||
@@ -501,267 +629,76 @@ impl AnthropicProvider {
|
||||
mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
|
||||
tx: mpsc::Sender<Result<CompletionChunk>>,
|
||||
) -> Option<Usage> {
|
||||
let mut buffer = String::new();
|
||||
let mut current_tool_calls: Vec<ToolCall> = Vec::new();
|
||||
let mut partial_tool_json = String::new(); // Accumulate partial JSON for tool calls
|
||||
let mut accumulated_usage: Option<Usage> = None;
|
||||
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
|
||||
let mut message_stopped = false; // Track if we've received message_stop
|
||||
let mut stop_reason: Option<String> = None; // Track why the message stopped
|
||||
let mut state = StreamState::new();
|
||||
let mut line_buffer = String::new();
|
||||
let mut byte_buffer: Vec<u8> = Vec::new();
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
byte_buffer.extend_from_slice(&chunk);
|
||||
|
||||
let Some(chunk_str) = decode_utf8_streaming(&mut byte_buffer) else {
|
||||
continue;
|
||||
};
|
||||
line_buffer.push_str(&chunk_str);
|
||||
|
||||
buffer.push_str(&chunk_str);
|
||||
while let Some(line_end) = line_buffer.find('\n') {
|
||||
let line = line_buffer[..line_end].trim().to_string();
|
||||
line_buffer.drain(..line_end + 1);
|
||||
|
||||
// Process complete lines
|
||||
while let Some(line_end) = buffer.find('\n') {
|
||||
let line = buffer[..line_end].trim().to_string();
|
||||
buffer.drain(..line_end + 1);
|
||||
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If we've already sent the final chunk, skip processing more events
|
||||
if message_stopped {
|
||||
debug!("Skipping event after message_stop: {}", line);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Parse Server-Sent Events format
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
if data == "[DONE]" {
|
||||
debug!("Received stream completion marker");
|
||||
let final_chunk = make_final_chunk(
|
||||
current_tool_calls.clone(),
|
||||
accumulated_usage.clone(),
|
||||
);
|
||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
}
|
||||
return accumulated_usage;
|
||||
if line.is_empty() || state.message_stopped {
|
||||
if state.message_stopped && !line.is_empty() {
|
||||
debug!("Skipping event after message_stop: {}", line);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
debug!("Raw Claude API JSON: {}", data);
|
||||
let Some(data) = line.strip_prefix("data: ") else { continue };
|
||||
|
||||
match serde_json::from_str::<AnthropicStreamEvent>(data) {
|
||||
Ok(event) => {
|
||||
debug!(
|
||||
"Parsed event type: {}, event: {:?}",
|
||||
event.event_type, event
|
||||
);
|
||||
match event.event_type.as_str() {
|
||||
"message_start" => {
|
||||
// Extract usage data from message_start event
|
||||
if let Some(message) = event.message {
|
||||
if let Some(usage) = message.usage {
|
||||
accumulated_usage = Some(Usage {
|
||||
prompt_tokens: usage.input_tokens,
|
||||
completion_tokens: usage.output_tokens,
|
||||
total_tokens: usage.input_tokens
|
||||
+ usage.output_tokens,
|
||||
cache_creation_tokens: usage
|
||||
.cache_creation_input_tokens,
|
||||
cache_read_tokens: usage
|
||||
.cache_read_input_tokens,
|
||||
});
|
||||
debug!(
|
||||
"Captured usage from message_start: {:?}",
|
||||
accumulated_usage
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
"content_block_start" => {
|
||||
debug!(
|
||||
"Received content_block_start event: {:?}",
|
||||
event
|
||||
);
|
||||
if let Some(content_block) = event.content_block {
|
||||
match content_block {
|
||||
AnthropicContent::ToolUse {
|
||||
id,
|
||||
name,
|
||||
input,
|
||||
} => {
|
||||
debug!("Found tool use in content_block_start: id={}, name={}, input={:?}", id, name, input);
|
||||
// Stream completion marker
|
||||
if data == "[DONE]" {
|
||||
debug!("Received stream completion marker");
|
||||
let final_chunk = make_final_chunk(state.tool_calls.clone(), state.usage.clone());
|
||||
let _ = tx.send(Ok(final_chunk)).await;
|
||||
return state.usage;
|
||||
}
|
||||
|
||||
// For native tool calls, create the tool call immediately if we have complete args
|
||||
// If args are empty, we'll wait for partial_json to accumulate them
|
||||
let tool_call = ToolCall {
|
||||
id: id.clone(),
|
||||
tool: name.clone(),
|
||||
args: input.clone(),
|
||||
};
|
||||
debug!("Raw Claude API JSON: {}", data);
|
||||
|
||||
// Check if we already have complete arguments
|
||||
if !input.is_null()
|
||||
&& input
|
||||
!= serde_json::Value::Object(
|
||||
serde_json::Map::new(),
|
||||
)
|
||||
{
|
||||
// We have complete arguments, send the tool call immediately
|
||||
debug!("Tool call has complete args, sending immediately: {:?}", tool_call);
|
||||
let chunk =
|
||||
make_tool_chunk(vec![tool_call]);
|
||||
if tx.send(Ok(chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
return accumulated_usage;
|
||||
}
|
||||
} else {
|
||||
// Arguments are empty, we'll accumulate them from partial_json
|
||||
debug!("Tool call has empty args, will accumulate from partial_json");
|
||||
// Send a streaming hint so the UI can show the tool name immediately
|
||||
let hint_chunk = make_tool_streaming_hint(name.clone());
|
||||
if tx.send(Ok(hint_chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
return accumulated_usage;
|
||||
}
|
||||
current_tool_calls.push(tool_call);
|
||||
partial_tool_json.clear();
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
debug!(
|
||||
"Non-tool content block: {:?}",
|
||||
content_block
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"content_block_delta" => {
|
||||
if let Some(delta) = event.delta {
|
||||
if let Some(text) = delta.text {
|
||||
debug!(
|
||||
"Sending text chunk of length {}: '{}'",
|
||||
text.len(),
|
||||
text
|
||||
);
|
||||
let chunk = make_text_chunk(text);
|
||||
if tx.send(Ok(chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
return accumulated_usage;
|
||||
}
|
||||
}
|
||||
// Handle partial JSON for tool calls
|
||||
if let Some(partial_json) = delta.partial_json {
|
||||
debug!(
|
||||
"Received partial JSON: {}",
|
||||
partial_json
|
||||
);
|
||||
partial_tool_json.push_str(&partial_json);
|
||||
debug!(
|
||||
"Accumulated tool JSON: {}",
|
||||
partial_tool_json
|
||||
);
|
||||
// Send an active hint to trigger UI blink
|
||||
let active_chunk = make_tool_streaming_active();
|
||||
if tx.send(Ok(active_chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
return accumulated_usage;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"content_block_stop" => {
|
||||
// Tool call block is complete - now parse the accumulated JSON
|
||||
if !current_tool_calls.is_empty()
|
||||
&& !partial_tool_json.is_empty()
|
||||
{
|
||||
debug!(
|
||||
"Parsing complete tool JSON: {}",
|
||||
partial_tool_json
|
||||
);
|
||||
let event = match serde_json::from_str::<AnthropicStreamEvent>(data) {
|
||||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
debug!("Failed to parse stream event: {} - Data: {}", e, data);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Parse the accumulated JSON and update the last tool call
|
||||
if let Ok(parsed_args) =
|
||||
serde_json::from_str::<serde_json::Value>(
|
||||
&partial_tool_json,
|
||||
)
|
||||
{
|
||||
if let Some(last_tool) =
|
||||
current_tool_calls.last_mut()
|
||||
{
|
||||
last_tool.args = parsed_args;
|
||||
debug!("Updated tool call with complete args: {:?}", last_tool);
|
||||
}
|
||||
} else {
|
||||
debug!(
|
||||
"Failed to parse accumulated JSON: {}",
|
||||
partial_tool_json
|
||||
);
|
||||
}
|
||||
debug!("Parsed event type: {}", event.event_type);
|
||||
|
||||
// Clear the accumulator
|
||||
partial_tool_json.clear();
|
||||
}
|
||||
|
||||
// Send the complete tool call
|
||||
if !current_tool_calls.is_empty() {
|
||||
let chunk =
|
||||
make_tool_chunk(current_tool_calls.clone());
|
||||
if tx.send(Ok(chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
return accumulated_usage;
|
||||
}
|
||||
// Clear tool calls after sending to prevent duplicates at message_stop
|
||||
current_tool_calls.clear();
|
||||
}
|
||||
}
|
||||
"message_delta" => {
|
||||
// message_delta contains the stop_reason and final usage
|
||||
if let Some(delta) = &event.delta {
|
||||
if let Some(reason) = &delta.stop_reason {
|
||||
debug!("Received stop_reason: {}", reason);
|
||||
stop_reason = Some(reason.clone());
|
||||
}
|
||||
}
|
||||
// Usage is also in message_delta but we get it from message_start
|
||||
}
|
||||
"message_stop" => {
|
||||
debug!("Received message stop event");
|
||||
message_stopped = true;
|
||||
let final_chunk = make_final_chunk_with_reason(
|
||||
current_tool_calls.clone(),
|
||||
accumulated_usage.clone(),
|
||||
stop_reason.clone(),
|
||||
);
|
||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
}
|
||||
// Don't return here - let the stream naturally exhaust
|
||||
// This prevents dropping the sender prematurely
|
||||
}
|
||||
"error" => {
|
||||
if let Some(error) = event.error {
|
||||
error!("Anthropic API error: {:?}", error);
|
||||
let _ = tx
|
||||
.send(Err(anyhow!(
|
||||
"Anthropic API error: {:?}",
|
||||
error
|
||||
)))
|
||||
.await;
|
||||
break; // Break to let stream exhaust naturally
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
debug!("Ignoring event type: {}", event.event_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Failed to parse stream event: {} - Data: {}", e, data);
|
||||
// Don't error out on parse failures, just continue
|
||||
// Dispatch to per-event handlers; collect chunks to send
|
||||
let chunks: Vec<Result<CompletionChunk>> = match event.event_type.as_str() {
|
||||
"message_start" => { state.handle_message_start(&event); vec![] }
|
||||
"content_block_start" => state.handle_block_start(event),
|
||||
"content_block_delta" => state.handle_block_delta(event),
|
||||
"content_block_stop" => state.handle_block_stop(),
|
||||
"message_delta" => { state.handle_message_delta(&event); vec![] }
|
||||
"message_stop" => state.handle_message_stop(),
|
||||
"error" => {
|
||||
if let Some(error) = event.error {
|
||||
error!("Anthropic API error: {:?}", error);
|
||||
let _ = tx.send(Err(anyhow!("Anthropic API error: {:?}", error))).await;
|
||||
break;
|
||||
}
|
||||
vec![]
|
||||
}
|
||||
_ => { debug!("Ignoring event type: {}", event.event_type); vec![] }
|
||||
};
|
||||
|
||||
// Send all chunks produced by the handler
|
||||
for chunk in chunks {
|
||||
if tx.send(chunk).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
return state.usage;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -769,18 +706,14 @@ impl AnthropicProvider {
|
||||
Err(e) => {
|
||||
error!("Stream error: {}", e);
|
||||
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
|
||||
// Don't return here either - let the stream exhaust naturally
|
||||
// The error has been sent to the receiver, so it will handle it
|
||||
// Breaking here ensures we clean up properly
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send final chunk if we haven't already
|
||||
let final_chunk = make_final_chunk(current_tool_calls, accumulated_usage.clone());
|
||||
let final_chunk = make_final_chunk(state.tool_calls, state.usage.clone());
|
||||
let _ = tx.send(Ok(final_chunk)).await;
|
||||
accumulated_usage
|
||||
state.usage
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user