zeta2: Make eval example file format more expressive (#42156)

Max Brunsfeld , Oleksiy Syvokon , Ben Kunkle , and Agus Zubiaga created

* Allow expressing alternative possible context fetches in `Expected
Context` section
* Allow marking a subset of lines as "required" in `Expected Context`.

We still need to improve how we display the results. I've removed the
context pass/fail pretty printing for now, because it would need to be
rethought to work with the new structure, but for now I think we should
focus on getting basic predictions to run. But this is progress toward a
better structure for eval examples.

Release Notes:

- N/A

---------

Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
Co-authored-by: Ben Kunkle <ben@zed.dev>
Co-authored-by: Agus Zubiaga <agus@zed.dev>

Change summary

crates/zeta_cli/src/evaluate.rs | 195 ++++++++++---------------
crates/zeta_cli/src/example.rs  | 267 +++++++++++++++++++++++++++++-----
2 files changed, 306 insertions(+), 156 deletions(-)

Detailed changes

crates/zeta_cli/src/evaluate.rs šŸ”—

@@ -83,11 +83,6 @@ pub async fn run_evaluate_one(
 
     let evaluation_result = evaluate(&example.example, &predictions);
 
-    println!("# {}\n", example.name);
-    println!(
-        "## Expected Context: \n\n```\n{}\n```\n\n",
-        compare_context(&example.example, &predictions)
-    );
     println!(
         "## Expected edit prediction:\n\n```diff\n{}\n```\n",
         compare_diffs(&example.example.expected_patch, &predictions.diff)
@@ -104,21 +99,30 @@ pub async fn run_evaluate_one(
 
 #[derive(Debug, Default)]
 pub struct EvaluationResult {
-    pub context: Scores,
     pub edit_prediction: Scores,
+    pub context: Scores,
 }
 
 #[derive(Default, Debug)]
 pub struct Scores {
-    pub precision: f64,
-    pub recall: f64,
-    pub f1_score: f64,
     pub true_positives: usize,
     pub false_positives: usize,
     pub false_negatives: usize,
 }
 
 impl Scores {
+    pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
+        let true_positives = expected.intersection(actual).count();
+        let false_positives = actual.difference(expected).count();
+        let false_negatives = expected.difference(actual).count();
+
+        Scores {
+            true_positives,
+            false_positives,
+            false_negatives,
+        }
+    }
+
     pub fn to_markdown(&self) -> String {
         format!(
             "
@@ -128,17 +132,15 @@ F1 Score        : {:.4}
 True Positives  : {}
 False Positives : {}
 False Negatives : {}",
-            self.precision,
-            self.recall,
-            self.f1_score,
+            self.precision(),
+            self.recall(),
+            self.f1_score(),
             self.true_positives,
             self.false_positives,
             self.false_negatives
         )
     }
-}
 
-impl Scores {
     pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
         let mut true_positives = 0;
         let mut false_positives = 0;
@@ -150,22 +152,38 @@ impl Scores {
             false_negatives += score.false_negatives;
         }
 
-        let precision = true_positives as f64 / (true_positives + false_positives) as f64;
-        let recall = true_positives as f64 / (true_positives + false_negatives) as f64;
-        let mut f1_score = 2.0 * precision * recall / (precision + recall);
-        if f1_score.is_nan() {
-            f1_score = 0.0;
-        }
-
         Scores {
-            precision,
-            recall,
-            f1_score,
             true_positives,
             false_positives,
             false_negatives,
         }
     }
+
+    pub fn precision(&self) -> f64 {
+        if self.true_positives + self.false_positives == 0 {
+            0.0
+        } else {
+            self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
+        }
+    }
+
+    pub fn recall(&self) -> f64 {
+        if self.true_positives + self.false_negatives == 0 {
+            0.0
+        } else {
+            self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
+        }
+    }
+
+    pub fn f1_score(&self) -> f64 {
+        let recall = self.recall();
+        let precision = self.precision();
+        if precision + recall == 0.0 {
+            0.0
+        } else {
+            2.0 * precision * recall / (precision + recall)
+        }
+    }
 }
 
 impl EvaluationResult {
@@ -185,19 +203,9 @@ impl EvaluationResult {
 }
 
 pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
-    let mut result = EvaluationResult::default();
+    let mut eval_result = EvaluationResult::default();
 
-    let expected_context_lines = example
-        .expected_excerpts
-        .iter()
-        .flat_map(|excerpt| {
-            excerpt
-                .text
-                .lines()
-                .map(|line| format!("{}: {line}", excerpt.path.display()))
-        })
-        .collect();
-    let actual_context_lines = preds
+    let actual_context_lines: HashSet<_> = preds
         .excerpts
         .iter()
         .flat_map(|excerpt| {
@@ -208,8 +216,39 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResul
         })
         .collect();
 
-    result.context = precision_recall(&expected_context_lines, &actual_context_lines);
+    let mut false_positive_lines = actual_context_lines.clone();
+
+    for entry in &example.expected_context {
+        let mut best_alternative_score = Scores::default();
+
+        for alternative in &entry.alternatives {
+            let expected: HashSet<_> = alternative
+                .excerpts
+                .iter()
+                .flat_map(|excerpt| {
+                    excerpt
+                        .text
+                        .lines()
+                        .map(|line| format!("{}: {line}", excerpt.path.display()))
+                })
+                .collect();
+
+            let scores = Scores::new(&expected, &actual_context_lines);
 
+            false_positive_lines.retain(|line| !actual_context_lines.contains(line));
+
+            if scores.recall() > best_alternative_score.recall() {
+                best_alternative_score = scores;
+            }
+        }
+
+        eval_result.context.false_negatives += best_alternative_score.false_negatives;
+        eval_result.context.true_positives += best_alternative_score.true_positives;
+    }
+
+    eval_result.context.false_positives = false_positive_lines.len();
+
+    // todo: alternatives for patches
     let expected_patch_lines = example
         .expected_patch
         .lines()
@@ -226,86 +265,8 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResul
         .map(|line| line.to_string())
         .collect();
 
-    result.edit_prediction = precision_recall(&expected_patch_lines, &actual_patch_lines);
-
-    result
-}
-
-fn precision_recall(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
-    let true_positives = expected.intersection(actual).count();
-    let false_positives = actual.difference(expected).count();
-    let false_negatives = expected.difference(actual).count();
-
-    let precision = if true_positives + false_positives == 0 {
-        0.0
-    } else {
-        true_positives as f64 / (true_positives + false_positives) as f64
-    };
-    let recall = if true_positives + false_negatives == 0 {
-        0.0
-    } else {
-        true_positives as f64 / (true_positives + false_negatives) as f64
-    };
-    let f1_score = if precision + recall == 0.0 {
-        0.0
-    } else {
-        2.0 * precision * recall / (precision + recall)
-    };
-
-    Scores {
-        precision,
-        recall,
-        f1_score,
-        true_positives,
-        false_positives,
-        false_negatives,
-    }
-}
-
-/// Compare actual and expected context.
-///
-/// Return expected context annotated with these markers:
-///
-/// `āœ“ context line`  -- line was correctly predicted
-/// `āœ— context line`  -- line is missing from predictions
-pub fn compare_context(example: &Example, preds: &PredictionDetails) -> String {
-    let use_color = std::io::stdout().is_terminal();
-    let green = if use_color { "\x1b[32m" } else { "" };
-    let red = if use_color { "\x1b[31m" } else { "" };
-    let reset = if use_color { "\x1b[0m" } else { "" };
-    let expected: Vec<_> = example
-        .expected_excerpts
-        .iter()
-        .flat_map(|excerpt| {
-            excerpt
-                .text
-                .lines()
-                .map(|line| (excerpt.path.clone(), line))
-        })
-        .collect();
-    let actual: HashSet<_> = preds
-        .excerpts
-        .iter()
-        .flat_map(|excerpt| {
-            excerpt
-                .text
-                .lines()
-                .map(|line| (excerpt.path.clone(), line))
-        })
-        .collect();
-
-    let annotated = expected
-        .iter()
-        .map(|(path, line)| {
-            if actual.contains(&(path.to_path_buf(), line)) {
-                format!("{green}āœ“ {line}{reset}")
-            } else {
-                format!("{red}āœ— {line}{reset}")
-            }
-        })
-        .collect::<Vec<String>>();
-
-    annotated.join("\n")
+    eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines);
+    eval_result
 }
 
 /// Return annotated `patch_a` so that:

crates/zeta_cli/src/example.rs šŸ”—

@@ -13,6 +13,7 @@ use anyhow::{Context as _, Result, anyhow};
 use clap::ValueEnum;
 use cloud_zeta2_prompt::CURSOR_MARKER;
 use collections::HashMap;
+use edit_prediction_context::Line;
 use futures::{
     AsyncWriteExt as _,
     lock::{Mutex, OwnedMutexGuard},
@@ -31,7 +32,7 @@ const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
 const EDIT_HISTORY_HEADING: &str = "Edit History";
 const CURSOR_POSITION_HEADING: &str = "Cursor Position";
 const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
-const EXPECTED_EXCERPTS_HEADING: &str = "Expected Excerpts";
+const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
 const REPOSITORY_URL_FIELD: &str = "repository_url";
 const REVISION_FIELD: &str = "revision";
 
@@ -50,10 +51,9 @@ pub struct Example {
     pub cursor_position: String,
     pub edit_history: String,
     pub expected_patch: String,
-    pub expected_excerpts: Vec<ExpectedExcerpt>,
+    pub expected_context: Vec<ExpectedContextEntry>,
 }
 
-pub type ExpectedExcerpt = Excerpt;
 pub type ActualExcerpt = Excerpt;
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
@@ -62,6 +62,25 @@ pub struct Excerpt {
     pub text: String,
 }
 
+#[derive(Default, Clone, Debug, Serialize, Deserialize)]
+pub struct ExpectedContextEntry {
+    pub heading: String,
+    pub alternatives: Vec<ExpectedExcerptSet>,
+}
+
+#[derive(Default, Clone, Debug, Serialize, Deserialize)]
+pub struct ExpectedExcerptSet {
+    pub heading: String,
+    pub excerpts: Vec<ExpectedExcerpt>,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ExpectedExcerpt {
+    pub path: PathBuf,
+    pub text: String,
+    pub required_lines: Vec<Line>,
+}
+
 #[derive(ValueEnum, Debug, Clone)]
 pub enum ExampleFormat {
     Json,
@@ -111,21 +130,32 @@ impl NamedExample {
                 cursor_position: String::new(),
                 edit_history: String::new(),
                 expected_patch: String::new(),
-                expected_excerpts: Vec::new(),
+                expected_context: Vec::new(),
             },
         };
 
         let mut text = String::new();
-        let mut current_section = String::new();
         let mut block_info: CowStr = "".into();
 
+        #[derive(PartialEq)]
+        enum Section {
+            UncommittedDiff,
+            EditHistory,
+            CursorPosition,
+            ExpectedExcerpts,
+            ExpectedPatch,
+            Other,
+        }
+
+        let mut current_section = Section::Other;
+
         for event in parser {
             match event {
                 Event::Text(line) => {
                     text.push_str(&line);
 
                     if !named.name.is_empty()
-                        && current_section.is_empty()
+                        && current_section == Section::Other
                         // in h1 section
                         && let Some((field, value)) = line.split_once('=')
                     {
@@ -151,7 +181,47 @@ impl NamedExample {
                     named.name = mem::take(&mut text);
                 }
                 Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
-                    current_section = mem::take(&mut text);
+                    let title = mem::take(&mut text);
+                    current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
+                        Section::UncommittedDiff
+                    } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
+                        Section::EditHistory
+                    } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
+                        Section::CursorPosition
+                    } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
+                        Section::ExpectedPatch
+                    } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
+                        Section::ExpectedExcerpts
+                    } else {
+                        eprintln!("Warning: Unrecognized section `{title:?}`");
+                        Section::Other
+                    };
+                }
+                Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
+                    let heading = mem::take(&mut text);
+                    match current_section {
+                        Section::ExpectedExcerpts => {
+                            named.example.expected_context.push(ExpectedContextEntry {
+                                heading,
+                                alternatives: Vec::new(),
+                            });
+                        }
+                        _ => {}
+                    }
+                }
+                Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
+                    let heading = mem::take(&mut text);
+                    match current_section {
+                        Section::ExpectedExcerpts => {
+                            let expected_context = &mut named.example.expected_context;
+                            let last_entry = expected_context.last_mut().unwrap();
+                            last_entry.alternatives.push(ExpectedExcerptSet {
+                                heading,
+                                excerpts: Vec::new(),
+                            })
+                        }
+                        _ => {}
+                    }
                 }
                 Event::End(TagEnd::Heading(level)) => {
                     anyhow::bail!("Unexpected heading level: {level}");
@@ -172,23 +242,53 @@ impl NamedExample {
                 }
                 Event::End(TagEnd::CodeBlock) => {
                     let block_info = block_info.trim();
-                    if current_section.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
-                        named.example.uncommitted_diff = mem::take(&mut text);
-                    } else if current_section.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
-                        named.example.edit_history.push_str(&mem::take(&mut text));
-                    } else if current_section.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
-                        named.example.cursor_path = block_info.into();
-                        named.example.cursor_position = mem::take(&mut text);
-                    } else if current_section.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
-                        named.example.expected_patch = mem::take(&mut text);
-                    } else if current_section.eq_ignore_ascii_case(EXPECTED_EXCERPTS_HEADING) {
-                        // TODO: "…" should not be a part of the excerpt
-                        named.example.expected_excerpts.push(ExpectedExcerpt {
-                            path: block_info.into(),
-                            text: mem::take(&mut text),
-                        });
-                    } else {
-                        eprintln!("Warning: Unrecognized section `{current_section:?}`")
+                    match current_section {
+                        Section::UncommittedDiff => {
+                            named.example.uncommitted_diff = mem::take(&mut text);
+                        }
+                        Section::EditHistory => {
+                            named.example.edit_history.push_str(&mem::take(&mut text));
+                        }
+                        Section::CursorPosition => {
+                            named.example.cursor_path = block_info.into();
+                            named.example.cursor_position = mem::take(&mut text);
+                        }
+                        Section::ExpectedExcerpts => {
+                            let text = mem::take(&mut text);
+                            for excerpt in text.split("\n…\n") {
+                                let (mut text, required_lines) = extract_required_lines(&excerpt);
+                                if !text.ends_with('\n') {
+                                    text.push('\n');
+                                }
+                                let alternatives = &mut named
+                                    .example
+                                    .expected_context
+                                    .last_mut()
+                                    .unwrap()
+                                    .alternatives;
+
+                                if alternatives.is_empty() {
+                                    alternatives.push(ExpectedExcerptSet {
+                                        heading: String::new(),
+                                        excerpts: vec![],
+                                    });
+                                }
+
+                                alternatives
+                                    .last_mut()
+                                    .unwrap()
+                                    .excerpts
+                                    .push(ExpectedExcerpt {
+                                        path: block_info.into(),
+                                        text,
+                                        required_lines,
+                                    });
+                            }
+                        }
+                        Section::ExpectedPatch => {
+                            named.example.expected_patch = mem::take(&mut text);
+                        }
+                        Section::Other => {}
                     }
                 }
                 _ => {}
@@ -404,6 +504,47 @@ impl NamedExample {
     }
 }
 
+fn extract_required_lines(text: &str) -> (String, Vec<Line>) {
+    const MARKER: &str = "[ZETA]";
+    let mut new_text = String::new();
+    let mut required_lines = Vec::new();
+    let mut skipped_lines = 0_u32;
+
+    for (row, mut line) in text.split('\n').enumerate() {
+        if let Some(marker_column) = line.find(MARKER) {
+            let mut strip_column = marker_column;
+
+            while strip_column > 0 {
+                let prev_char = line[strip_column - 1..].chars().next().unwrap();
+                if prev_char.is_whitespace() || ['/', '#'].contains(&prev_char) {
+                    strip_column -= 1;
+                } else {
+                    break;
+                }
+            }
+
+            let metadata = &line[marker_column + MARKER.len()..];
+            if metadata.contains("required") {
+                required_lines.push(Line(row as u32 - skipped_lines));
+            }
+
+            if strip_column == 0 {
+                skipped_lines += 1;
+                continue;
+            }
+
+            line = &line[..strip_column];
+        }
+
+        new_text.push_str(line);
+        new_text.push('\n');
+    }
+
+    new_text.pop();
+
+    (new_text, required_lines)
+}
+
 async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
     let output = smol::process::Command::new("git")
         .current_dir(repo_path)
@@ -458,21 +599,34 @@ impl Display for NamedExample {
             )?;
         }
 
-        if !self.example.expected_excerpts.is_empty() {
-            write!(f, "\n## {EXPECTED_EXCERPTS_HEADING}\n\n")?;
-
-            for excerpt in &self.example.expected_excerpts {
-                write!(
-                    f,
-                    "`````{}{}\n{}`````\n\n",
-                    excerpt
-                        .path
-                        .extension()
-                        .map(|ext| format!("{} ", ext.to_string_lossy()))
-                        .unwrap_or_default(),
-                    excerpt.path.display(),
-                    excerpt.text
-                )?;
+        if !self.example.expected_context.is_empty() {
+            write!(f, "\n## {EXPECTED_CONTEXT_HEADING}\n\n")?;
+
+            for entry in &self.example.expected_context {
+                write!(f, "\n### {}\n\n", entry.heading)?;
+
+                let skip_h4 =
+                    entry.alternatives.len() == 1 && entry.alternatives[0].heading.is_empty();
+
+                for excerpt_set in &entry.alternatives {
+                    if !skip_h4 {
+                        write!(f, "\n#### {}\n\n", excerpt_set.heading)?;
+                    }
+
+                    for excerpt in &excerpt_set.excerpts {
+                        write!(
+                            f,
+                            "`````{}{}\n{}`````\n\n",
+                            excerpt
+                                .path
+                                .extension()
+                                .map(|ext| format!("{} ", ext.to_string_lossy()))
+                                .unwrap_or_default(),
+                            excerpt.path.display(),
+                            excerpt.text
+                        )?;
+                    }
+                }
             }
         }
 
@@ -496,3 +650,38 @@ pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
         .lock_owned()
         .await
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use indoc::indoc;
+    use pretty_assertions::assert_eq;
+
+    #[test]
+    fn test_extract_required_lines() {
+        let input = indoc! {"
+            zero
+            one // [ZETA] required
+            two
+            // [ZETA] something
+            three
+            four # [ZETA] required
+            five
+        "};
+
+        let expected_updated_input = indoc! {"
+            zero
+            one
+            two
+            three
+            four
+            five
+        "};
+
+        let expected_required_lines = vec![Line(1), Line(4)];
+
+        let (updated_input, required_lines) = extract_required_lines(input);
+        assert_eq!(updated_input, expected_updated_input);
+        assert_eq!(required_lines, expected_required_lines);
+    }
+}