From 5044e6ac1dea07ac70bd13c64636feb82af3f5d9 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 6 Nov 2025 18:05:18 -0800 Subject: [PATCH] zeta2: Make eval example file format more expressive (#42156) * Allow expressing alternative possible context fetches in `Expected Context` section * Allow marking a subset of lines as "required" in `Expected Context`. We still need to improve how we display the results. I've removed the context pass/fail pretty printing for now, because it would need to be rethought to work with the new structure, but for now I think we should focus on getting basic predictions to run. But this is progress toward a better structure for eval examples. Release Notes: - N/A --------- Co-authored-by: Oleksiy Syvokon Co-authored-by: Ben Kunkle Co-authored-by: Agus Zubiaga --- crates/zeta_cli/src/evaluate.rs | 195 ++++++++++------------- crates/zeta_cli/src/example.rs | 267 +++++++++++++++++++++++++++----- 2 files changed, 306 insertions(+), 156 deletions(-) diff --git a/crates/zeta_cli/src/evaluate.rs b/crates/zeta_cli/src/evaluate.rs index 5ffdd8ccff6601cf99b2bb3237f46cab224b0daf..f99747e676b777e5d7a086c61db2f9e8d152c20b 100644 --- a/crates/zeta_cli/src/evaluate.rs +++ b/crates/zeta_cli/src/evaluate.rs @@ -83,11 +83,6 @@ pub async fn run_evaluate_one( let evaluation_result = evaluate(&example.example, &predictions); - println!("# {}\n", example.name); - println!( - "## Expected Context: \n\n```\n{}\n```\n\n", - compare_context(&example.example, &predictions) - ); println!( "## Expected edit prediction:\n\n```diff\n{}\n```\n", compare_diffs(&example.example.expected_patch, &predictions.diff) @@ -104,21 +99,30 @@ pub async fn run_evaluate_one( #[derive(Debug, Default)] pub struct EvaluationResult { - pub context: Scores, pub edit_prediction: Scores, + pub context: Scores, } #[derive(Default, Debug)] pub struct Scores { - pub precision: f64, - pub recall: f64, - pub f1_score: f64, pub true_positives: usize, pub false_positives: usize, pub false_negatives: usize, } impl Scores { + pub fn new(expected: &HashSet, actual: &HashSet) -> Scores { + let true_positives = expected.intersection(actual).count(); + let false_positives = actual.difference(expected).count(); + let false_negatives = expected.difference(actual).count(); + + Scores { + true_positives, + false_positives, + false_negatives, + } + } + pub fn to_markdown(&self) -> String { format!( " @@ -128,17 +132,15 @@ F1 Score : {:.4} True Positives : {} False Positives : {} False Negatives : {}", - self.precision, - self.recall, - self.f1_score, + self.precision(), + self.recall(), + self.f1_score(), self.true_positives, self.false_positives, self.false_negatives ) } -} -impl Scores { pub fn aggregate<'a>(scores: impl Iterator) -> Scores { let mut true_positives = 0; let mut false_positives = 0; @@ -150,22 +152,38 @@ impl Scores { false_negatives += score.false_negatives; } - let precision = true_positives as f64 / (true_positives + false_positives) as f64; - let recall = true_positives as f64 / (true_positives + false_negatives) as f64; - let mut f1_score = 2.0 * precision * recall / (precision + recall); - if f1_score.is_nan() { - f1_score = 0.0; - } - Scores { - precision, - recall, - f1_score, true_positives, false_positives, false_negatives, } } + + pub fn precision(&self) -> f64 { + if self.true_positives + self.false_positives == 0 { + 0.0 + } else { + self.true_positives as f64 / (self.true_positives + self.false_positives) as f64 + } + } + + pub fn recall(&self) -> f64 { + if self.true_positives + self.false_negatives == 0 { + 0.0 + } else { + self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64 + } + } + + pub fn f1_score(&self) -> f64 { + let recall = self.recall(); + let precision = self.precision(); + if precision + recall == 0.0 { + 0.0 + } else { + 2.0 * precision * recall / (precision + recall) + } + } } impl EvaluationResult { @@ -185,19 +203,9 @@ impl EvaluationResult { } pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult { - let mut result = EvaluationResult::default(); + let mut eval_result = EvaluationResult::default(); - let expected_context_lines = example - .expected_excerpts - .iter() - .flat_map(|excerpt| { - excerpt - .text - .lines() - .map(|line| format!("{}: {line}", excerpt.path.display())) - }) - .collect(); - let actual_context_lines = preds + let actual_context_lines: HashSet<_> = preds .excerpts .iter() .flat_map(|excerpt| { @@ -208,8 +216,39 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResul }) .collect(); - result.context = precision_recall(&expected_context_lines, &actual_context_lines); + let mut false_positive_lines = actual_context_lines.clone(); + + for entry in &example.expected_context { + let mut best_alternative_score = Scores::default(); + + for alternative in &entry.alternatives { + let expected: HashSet<_> = alternative + .excerpts + .iter() + .flat_map(|excerpt| { + excerpt + .text + .lines() + .map(|line| format!("{}: {line}", excerpt.path.display())) + }) + .collect(); + + let scores = Scores::new(&expected, &actual_context_lines); + false_positive_lines.retain(|line| !actual_context_lines.contains(line)); + + if scores.recall() > best_alternative_score.recall() { + best_alternative_score = scores; + } + } + + eval_result.context.false_negatives += best_alternative_score.false_negatives; + eval_result.context.true_positives += best_alternative_score.true_positives; + } + + eval_result.context.false_positives = false_positive_lines.len(); + + // todo: alternatives for patches let expected_patch_lines = example .expected_patch .lines() @@ -226,86 +265,8 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResul .map(|line| line.to_string()) .collect(); - result.edit_prediction = precision_recall(&expected_patch_lines, &actual_patch_lines); - - result -} - -fn precision_recall(expected: &HashSet, actual: &HashSet) -> Scores { - let true_positives = expected.intersection(actual).count(); - let false_positives = actual.difference(expected).count(); - let false_negatives = expected.difference(actual).count(); - - let precision = if true_positives + false_positives == 0 { - 0.0 - } else { - true_positives as f64 / (true_positives + false_positives) as f64 - }; - let recall = if true_positives + false_negatives == 0 { - 0.0 - } else { - true_positives as f64 / (true_positives + false_negatives) as f64 - }; - let f1_score = if precision + recall == 0.0 { - 0.0 - } else { - 2.0 * precision * recall / (precision + recall) - }; - - Scores { - precision, - recall, - f1_score, - true_positives, - false_positives, - false_negatives, - } -} - -/// Compare actual and expected context. -/// -/// Return expected context annotated with these markers: -/// -/// `✓ context line` -- line was correctly predicted -/// `✗ context line` -- line is missing from predictions -pub fn compare_context(example: &Example, preds: &PredictionDetails) -> String { - let use_color = std::io::stdout().is_terminal(); - let green = if use_color { "\x1b[32m" } else { "" }; - let red = if use_color { "\x1b[31m" } else { "" }; - let reset = if use_color { "\x1b[0m" } else { "" }; - let expected: Vec<_> = example - .expected_excerpts - .iter() - .flat_map(|excerpt| { - excerpt - .text - .lines() - .map(|line| (excerpt.path.clone(), line)) - }) - .collect(); - let actual: HashSet<_> = preds - .excerpts - .iter() - .flat_map(|excerpt| { - excerpt - .text - .lines() - .map(|line| (excerpt.path.clone(), line)) - }) - .collect(); - - let annotated = expected - .iter() - .map(|(path, line)| { - if actual.contains(&(path.to_path_buf(), line)) { - format!("{green}✓ {line}{reset}") - } else { - format!("{red}✗ {line}{reset}") - } - }) - .collect::>(); - - annotated.join("\n") + eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines); + eval_result } /// Return annotated `patch_a` so that: diff --git a/crates/zeta_cli/src/example.rs b/crates/zeta_cli/src/example.rs index ab62d690887aa42b2fb3de0c7f05cfc0975de177..e3c988dd4b0046f117999ad61050c3d710f618a1 100644 --- a/crates/zeta_cli/src/example.rs +++ b/crates/zeta_cli/src/example.rs @@ -13,6 +13,7 @@ use anyhow::{Context as _, Result, anyhow}; use clap::ValueEnum; use cloud_zeta2_prompt::CURSOR_MARKER; use collections::HashMap; +use edit_prediction_context::Line; use futures::{ AsyncWriteExt as _, lock::{Mutex, OwnedMutexGuard}, @@ -31,7 +32,7 @@ const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff"; const EDIT_HISTORY_HEADING: &str = "Edit History"; const CURSOR_POSITION_HEADING: &str = "Cursor Position"; const EXPECTED_PATCH_HEADING: &str = "Expected Patch"; -const EXPECTED_EXCERPTS_HEADING: &str = "Expected Excerpts"; +const EXPECTED_CONTEXT_HEADING: &str = "Expected Context"; const REPOSITORY_URL_FIELD: &str = "repository_url"; const REVISION_FIELD: &str = "revision"; @@ -50,10 +51,9 @@ pub struct Example { pub cursor_position: String, pub edit_history: String, pub expected_patch: String, - pub expected_excerpts: Vec, + pub expected_context: Vec, } -pub type ExpectedExcerpt = Excerpt; pub type ActualExcerpt = Excerpt; #[derive(Clone, Debug, Serialize, Deserialize)] @@ -62,6 +62,25 @@ pub struct Excerpt { pub text: String, } +#[derive(Default, Clone, Debug, Serialize, Deserialize)] +pub struct ExpectedContextEntry { + pub heading: String, + pub alternatives: Vec, +} + +#[derive(Default, Clone, Debug, Serialize, Deserialize)] +pub struct ExpectedExcerptSet { + pub heading: String, + pub excerpts: Vec, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ExpectedExcerpt { + pub path: PathBuf, + pub text: String, + pub required_lines: Vec, +} + #[derive(ValueEnum, Debug, Clone)] pub enum ExampleFormat { Json, @@ -111,21 +130,32 @@ impl NamedExample { cursor_position: String::new(), edit_history: String::new(), expected_patch: String::new(), - expected_excerpts: Vec::new(), + expected_context: Vec::new(), }, }; let mut text = String::new(); - let mut current_section = String::new(); let mut block_info: CowStr = "".into(); + #[derive(PartialEq)] + enum Section { + UncommittedDiff, + EditHistory, + CursorPosition, + ExpectedExcerpts, + ExpectedPatch, + Other, + } + + let mut current_section = Section::Other; + for event in parser { match event { Event::Text(line) => { text.push_str(&line); if !named.name.is_empty() - && current_section.is_empty() + && current_section == Section::Other // in h1 section && let Some((field, value)) = line.split_once('=') { @@ -151,7 +181,47 @@ impl NamedExample { named.name = mem::take(&mut text); } Event::End(TagEnd::Heading(HeadingLevel::H2)) => { - current_section = mem::take(&mut text); + let title = mem::take(&mut text); + current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) { + Section::UncommittedDiff + } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) { + Section::EditHistory + } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) { + Section::CursorPosition + } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) { + Section::ExpectedPatch + } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) { + Section::ExpectedExcerpts + } else { + eprintln!("Warning: Unrecognized section `{title:?}`"); + Section::Other + }; + } + Event::End(TagEnd::Heading(HeadingLevel::H3)) => { + let heading = mem::take(&mut text); + match current_section { + Section::ExpectedExcerpts => { + named.example.expected_context.push(ExpectedContextEntry { + heading, + alternatives: Vec::new(), + }); + } + _ => {} + } + } + Event::End(TagEnd::Heading(HeadingLevel::H4)) => { + let heading = mem::take(&mut text); + match current_section { + Section::ExpectedExcerpts => { + let expected_context = &mut named.example.expected_context; + let last_entry = expected_context.last_mut().unwrap(); + last_entry.alternatives.push(ExpectedExcerptSet { + heading, + excerpts: Vec::new(), + }) + } + _ => {} + } } Event::End(TagEnd::Heading(level)) => { anyhow::bail!("Unexpected heading level: {level}"); @@ -172,23 +242,53 @@ impl NamedExample { } Event::End(TagEnd::CodeBlock) => { let block_info = block_info.trim(); - if current_section.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) { - named.example.uncommitted_diff = mem::take(&mut text); - } else if current_section.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) { - named.example.edit_history.push_str(&mem::take(&mut text)); - } else if current_section.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) { - named.example.cursor_path = block_info.into(); - named.example.cursor_position = mem::take(&mut text); - } else if current_section.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) { - named.example.expected_patch = mem::take(&mut text); - } else if current_section.eq_ignore_ascii_case(EXPECTED_EXCERPTS_HEADING) { - // TODO: "…" should not be a part of the excerpt - named.example.expected_excerpts.push(ExpectedExcerpt { - path: block_info.into(), - text: mem::take(&mut text), - }); - } else { - eprintln!("Warning: Unrecognized section `{current_section:?}`") + match current_section { + Section::UncommittedDiff => { + named.example.uncommitted_diff = mem::take(&mut text); + } + Section::EditHistory => { + named.example.edit_history.push_str(&mem::take(&mut text)); + } + Section::CursorPosition => { + named.example.cursor_path = block_info.into(); + named.example.cursor_position = mem::take(&mut text); + } + Section::ExpectedExcerpts => { + let text = mem::take(&mut text); + for excerpt in text.split("\n…\n") { + let (mut text, required_lines) = extract_required_lines(&excerpt); + if !text.ends_with('\n') { + text.push('\n'); + } + let alternatives = &mut named + .example + .expected_context + .last_mut() + .unwrap() + .alternatives; + + if alternatives.is_empty() { + alternatives.push(ExpectedExcerptSet { + heading: String::new(), + excerpts: vec![], + }); + } + + alternatives + .last_mut() + .unwrap() + .excerpts + .push(ExpectedExcerpt { + path: block_info.into(), + text, + required_lines, + }); + } + } + Section::ExpectedPatch => { + named.example.expected_patch = mem::take(&mut text); + } + Section::Other => {} } } _ => {} @@ -404,6 +504,47 @@ impl NamedExample { } } +fn extract_required_lines(text: &str) -> (String, Vec) { + const MARKER: &str = "[ZETA]"; + let mut new_text = String::new(); + let mut required_lines = Vec::new(); + let mut skipped_lines = 0_u32; + + for (row, mut line) in text.split('\n').enumerate() { + if let Some(marker_column) = line.find(MARKER) { + let mut strip_column = marker_column; + + while strip_column > 0 { + let prev_char = line[strip_column - 1..].chars().next().unwrap(); + if prev_char.is_whitespace() || ['/', '#'].contains(&prev_char) { + strip_column -= 1; + } else { + break; + } + } + + let metadata = &line[marker_column + MARKER.len()..]; + if metadata.contains("required") { + required_lines.push(Line(row as u32 - skipped_lines)); + } + + if strip_column == 0 { + skipped_lines += 1; + continue; + } + + line = &line[..strip_column]; + } + + new_text.push_str(line); + new_text.push('\n'); + } + + new_text.pop(); + + (new_text, required_lines) +} + async fn run_git(repo_path: &Path, args: &[&str]) -> Result { let output = smol::process::Command::new("git") .current_dir(repo_path) @@ -458,21 +599,34 @@ impl Display for NamedExample { )?; } - if !self.example.expected_excerpts.is_empty() { - write!(f, "\n## {EXPECTED_EXCERPTS_HEADING}\n\n")?; - - for excerpt in &self.example.expected_excerpts { - write!( - f, - "`````{}{}\n{}`````\n\n", - excerpt - .path - .extension() - .map(|ext| format!("{} ", ext.to_string_lossy())) - .unwrap_or_default(), - excerpt.path.display(), - excerpt.text - )?; + if !self.example.expected_context.is_empty() { + write!(f, "\n## {EXPECTED_CONTEXT_HEADING}\n\n")?; + + for entry in &self.example.expected_context { + write!(f, "\n### {}\n\n", entry.heading)?; + + let skip_h4 = + entry.alternatives.len() == 1 && entry.alternatives[0].heading.is_empty(); + + for excerpt_set in &entry.alternatives { + if !skip_h4 { + write!(f, "\n#### {}\n\n", excerpt_set.heading)?; + } + + for excerpt in &excerpt_set.excerpts { + write!( + f, + "`````{}{}\n{}`````\n\n", + excerpt + .path + .extension() + .map(|ext| format!("{} ", ext.to_string_lossy())) + .unwrap_or_default(), + excerpt.path.display(), + excerpt.text + )?; + } + } } } @@ -496,3 +650,38 @@ pub async fn lock_repo(path: impl AsRef) -> OwnedMutexGuard<()> { .lock_owned() .await } + +#[cfg(test)] +mod tests { + use super::*; + use indoc::indoc; + use pretty_assertions::assert_eq; + + #[test] + fn test_extract_required_lines() { + let input = indoc! {" + zero + one // [ZETA] required + two + // [ZETA] something + three + four # [ZETA] required + five + "}; + + let expected_updated_input = indoc! {" + zero + one + two + three + four + five + "}; + + let expected_required_lines = vec![Line(1), Line(4)]; + + let (updated_input, required_lines) = extract_required_lines(input); + assert_eq!(updated_input, expected_updated_input); + assert_eq!(required_lines, expected_required_lines); + } +}