Allow multiple expected patches, remove line-based patch scoring

Max Brunsfeld created

Change summary

crates/edit_prediction/src/capture_example.rs   |   4 
crates/edit_prediction/src/example_spec.rs      |  20 
crates/edit_prediction_cli/src/distill.rs       |  17 
crates/edit_prediction_cli/src/example.rs       |   3 
crates/edit_prediction_cli/src/format_prompt.rs |  19 +
crates/edit_prediction_cli/src/metrics.rs       | 217 ++++--------------
crates/edit_prediction_cli/src/score.rs         |  75 ++---
7 files changed, 115 insertions(+), 240 deletions(-)

Detailed changes

crates/edit_prediction/src/capture_example.rs 🔗

@@ -74,7 +74,7 @@ pub fn capture_example(
             cursor_path: cursor_path.as_std_path().into(),
             cursor_position: String::new(),
             edit_history,
-            expected_patch: String::new(),
+            expected_patches: Vec::new(),
         };
         spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
         Ok(spec)
@@ -350,7 +350,7 @@ mod tests {
                          seven();
                 "}
                 .to_string(),
-                expected_patch: "".to_string(),
+                expected_patches: Vec::new()
             }
         );
     }

crates/edit_prediction/src/example_spec.rs 🔗

@@ -15,7 +15,7 @@ pub struct ExampleSpec {
     pub cursor_path: Arc<Path>,
     pub cursor_position: String,
     pub edit_history: String,
-    pub expected_patch: String,
+    pub expected_patches: Vec<String>,
 }
 
 const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
@@ -95,13 +95,15 @@ impl ExampleSpec {
 
         _ = writeln!(markdown, "## {}", EXPECTED_PATCH_HEADING);
         markdown.push('\n');
-        _ = writeln!(markdown, "```diff");
-        markdown.push_str(&self.expected_patch);
-        if !markdown.ends_with('\n') {
+        for patch in &self.expected_patches {
+            _ = writeln!(markdown, "```diff");
+            markdown.push_str(patch);
+            if !markdown.ends_with('\n') {
+                markdown.push('\n');
+            }
+            _ = writeln!(markdown, "```");
             markdown.push('\n');
         }
-        _ = writeln!(markdown, "```");
-        markdown.push('\n');
 
         markdown
     }
@@ -118,7 +120,7 @@ impl ExampleSpec {
             cursor_path: Path::new("").into(),
             cursor_position: String::new(),
             edit_history: String::new(),
-            expected_patch: String::new(),
+            expected_patches: Vec::new(),
         };
 
         if let Some(rest) = input.strip_prefix("+++\n")
@@ -212,7 +214,7 @@ impl ExampleSpec {
                             mem::take(&mut text);
                         }
                         Section::ExpectedPatch => {
-                            spec.expected_patch = mem::take(&mut text);
+                            spec.expected_patches.push(mem::take(&mut text));
                         }
                         Section::Start | Section::Other => {}
                     }
@@ -353,7 +355,7 @@ mod tests {
             cursor_path: Path::new("test.rs").into(),
             cursor_position: String::new(),
             edit_history: String::new(),
-            expected_patch: String::new(),
+            expected_patches: Vec::new(),
         };
 
         // Cursor before `42`

crates/edit_prediction_cli/src/distill.rs 🔗

@@ -1,20 +1,15 @@
-use anyhow::{Result, anyhow};
+use anyhow::Result;
 use std::mem;
 
 use crate::example::Example;
 
 pub async fn run_distill(example: &mut Example) -> Result<()> {
-    let [prediction]: [_; 1] =
-        mem::take(&mut example.predictions)
-            .try_into()
-            .map_err(|preds: Vec<_>| {
-                anyhow!(
-                    "Example has {} predictions, but it should have exactly one",
-                    preds.len()
-                )
-            })?;
+    let predictions = mem::take(&mut example.predictions)
+        .into_iter()
+        .map(|p| p.actual_patch)
+        .collect();
 
-    example.spec.expected_patch = prediction.actual_patch;
+    example.spec.expected_patches = predictions;
     example.prompt = None;
     example.predictions = Vec::new();
     example.score = Vec::new();

crates/edit_prediction_cli/src/example.rs 🔗

@@ -1,4 +1,4 @@
-use crate::{PredictionProvider, PromptFormat, metrics::ClassificationMetrics};
+use crate::{PredictionProvider, PromptFormat};
 use anyhow::{Context as _, Result};
 use collections::HashMap;
 use edit_prediction::example_spec::ExampleSpec;
@@ -87,7 +87,6 @@ pub struct ExamplePrediction {
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct ExampleScore {
     pub delta_chr_f: f32,
-    pub line_match: ClassificationMetrics,
 }
 
 impl Example {

crates/edit_prediction_cli/src/format_prompt.rs 🔗

@@ -30,7 +30,13 @@ pub async fn run_format_prompt(
             let prompt = TeacherPrompt::format_prompt(example);
             example.prompt = Some(ExamplePrompt {
                 input: prompt,
-                expected_output: example.spec.expected_patch.clone(), // TODO
+                // TODO
+                expected_output: example
+                    .spec
+                    .expected_patches
+                    .first()
+                    .context("no expected patches")?
+                    .clone(),
                 format: prompt_format,
             });
         }
@@ -68,8 +74,15 @@ pub async fn run_format_prompt(
                 ))
             })??;
             let prompt = format_zeta_prompt(&input);
-            let expected_output =
-                zeta2_output_for_patch(&input, &example.spec.expected_patch.clone())?;
+            let expected_output = zeta2_output_for_patch(
+                &input,
+                &example
+                    .spec
+                    .expected_patches
+                    .first()
+                    .context("expected patches is empty")?
+                    .clone(),
+            )?;
             example.prompt = Some(ExamplePrompt {
                 input: prompt,
                 expected_output,

crates/edit_prediction_cli/src/metrics.rs 🔗

@@ -1,34 +1,17 @@
-use collections::{HashMap, HashSet};
-use edit_prediction::udiff::DiffLine;
-use serde::{Deserialize, Serialize};
+use collections::HashMap;
 
 type Counts = HashMap<String, usize>;
 type CountsDelta = HashMap<String, isize>;
 
-#[derive(Default, Debug, Clone, Serialize, Deserialize)]
-pub struct ClassificationMetrics {
-    pub true_positives: usize,
-    pub false_positives: usize,
-    pub false_negatives: usize,
+#[derive(Default, Debug, Clone)]
+struct ClassificationMetrics {
+    true_positives: usize,
+    false_positives: usize,
+    false_negatives: usize,
 }
 
 impl ClassificationMetrics {
-    pub fn from_sets(
-        expected: &HashSet<String>,
-        actual: &HashSet<String>,
-    ) -> ClassificationMetrics {
-        let true_positives = expected.intersection(actual).count();
-        let false_positives = actual.difference(expected).count();
-        let false_negatives = expected.difference(actual).count();
-
-        ClassificationMetrics {
-            true_positives,
-            false_positives,
-            false_negatives,
-        }
-    }
-
-    pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
+    fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
         let mut true_positives = 0;
         let mut false_positives = 0;
         let mut false_negatives = 0;
@@ -56,27 +39,7 @@ impl ClassificationMetrics {
         }
     }
 
-    pub fn aggregate<'a>(
-        scores: impl Iterator<Item = &'a ClassificationMetrics>,
-    ) -> ClassificationMetrics {
-        let mut true_positives = 0;
-        let mut false_positives = 0;
-        let mut false_negatives = 0;
-
-        for score in scores {
-            true_positives += score.true_positives;
-            false_positives += score.false_positives;
-            false_negatives += score.false_negatives;
-        }
-
-        ClassificationMetrics {
-            true_positives,
-            false_positives,
-            false_negatives,
-        }
-    }
-
-    pub fn precision(&self) -> f64 {
+    fn precision(&self) -> f64 {
         if self.true_positives + self.false_positives == 0 {
             0.0
         } else {
@@ -84,42 +47,13 @@ impl ClassificationMetrics {
         }
     }
 
-    pub fn recall(&self) -> f64 {
+    fn recall(&self) -> f64 {
         if self.true_positives + self.false_negatives == 0 {
             0.0
         } else {
             self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
         }
     }
-
-    pub fn f1_score(&self) -> f64 {
-        let recall = self.recall();
-        let precision = self.precision();
-        if precision + recall == 0.0 {
-            0.0
-        } else {
-            2.0 * precision * recall / (precision + recall)
-        }
-    }
-}
-
-pub fn line_match_score(
-    expected_patch: &[DiffLine],
-    actual_patch: &[DiffLine],
-) -> ClassificationMetrics {
-    let expected_change_lines = expected_patch
-        .iter()
-        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
-        .map(|line| line.to_string())
-        .collect();
-
-    let actual_change_lines = actual_patch
-        .iter()
-        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
-        .map(|line| line.to_string())
-        .collect();
-
-    ClassificationMetrics::from_sets(&expected_change_lines, &actual_change_lines)
 }
 
 enum ChrfWhitespace {
@@ -135,55 +69,26 @@ const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Ignore;
 /// Computes a delta-chrF score that compares two sets of edits.
 ///
 /// This metric works by:
-/// 1. Reconstructing original, golden (expected result), and actual texts from diffs
-/// 2. Computing n-gram count differences (deltas) between original→golden and original→actual
-/// 3. Comparing these deltas to measure how well actual edits match expected edits
-pub fn delta_chr_f(expected: &[DiffLine], actual: &[DiffLine]) -> f64 {
-    // Reconstruct texts from diffs
-    let mut original_text = String::new(); // state of the text before any edits
-    let mut golden_text = String::new(); // text after applying golden edits
-    let mut actual_text = String::new(); // text after applying actual edits
-
-    for line in expected {
-        match line {
-            DiffLine::Context(s) => {
-                original_text.push_str(s);
-                golden_text.push_str(s);
-            }
-            DiffLine::Deletion(s) => {
-                original_text.push_str(s);
-            }
-            DiffLine::Addition(s) => {
-                golden_text.push_str(s);
-            }
-            _ => {}
-        }
-    }
-
-    for line in actual {
-        match line {
-            DiffLine::Context(s) | DiffLine::Addition(s) => {
-                actual_text.push_str(s);
-            }
-            _ => {}
-        }
-    }
-
-    // Edge case
-    if original_text == golden_text && golden_text == actual_text {
+/// 1. Computing n-gram count differences (deltas) between original→expected and original→actual
+/// 2. Comparing these deltas to measure how well actual edits match expected edits
+///
+/// Returns a score from 0.0 to 100.0, where 100.0 means the actual edits perfectly match
+/// the expected edits.
+pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> f64 {
+    // Edge case: if all texts are identical, the edits match perfectly
+    if original == expected && expected == actual {
         return 100.0;
     }
 
-    // Compute the metric
-    let original_ngrams = chr_f_ngram_counts(&original_text);
-    let golden_ngrams = chr_f_ngram_counts(&golden_text);
-    let actual_ngrams = chr_f_ngram_counts(&actual_text);
+    let original_ngrams = chr_f_ngram_counts(original);
+    let expected_ngrams = chr_f_ngram_counts(expected);
+    let actual_ngrams = chr_f_ngram_counts(actual);
 
     let mut total_precision = 0.0;
     let mut total_recall = 0.0;
 
     for order in 0..CHR_F_CHAR_ORDER {
-        let expected_delta = compute_ngram_delta(&golden_ngrams[order], &original_ngrams[order]);
+        let expected_delta = compute_ngram_delta(&expected_ngrams[order], &original_ngrams[order]);
         let actual_delta = compute_ngram_delta(&actual_ngrams[order], &original_ngrams[order]);
 
         if expected_delta.is_empty() && actual_delta.is_empty() {
@@ -278,94 +183,68 @@ fn count_ngrams(text: &str, n: usize) -> Counts {
 #[cfg(test)]
 mod test {
     use super::*;
-    use edit_prediction::udiff::DiffLine;
 
     #[test]
     fn test_delta_chr_f_perfect_match() {
-        let diff = vec![
-            DiffLine::Context("fn main() {"),
-            DiffLine::Deletion("    println!(\"Hello\");"),
-            DiffLine::Addition("    println!(\"Hello, World!\");"),
-            DiffLine::Context("}"),
-        ];
-
-        let score = delta_chr_f(&diff, &diff);
+        let original = "fn main() {    println!(\"Hello\");}";
+        let expected = "fn main() {    println!(\"Hello, World!\");}";
+
+        let score = delta_chr_f(original, expected, expected);
         assert!((score - 100.0).abs() < 1e-2);
     }
 
     #[test]
     fn test_delta_chr_f_wrong_edit() {
         // When the edit is wrong
-        let expected = vec![
-            DiffLine::Context("one "),
-            DiffLine::Deletion("two "),
-            DiffLine::Context("three"),
-        ];
-
-        let actual = vec![
-            DiffLine::Context("one "),
-            DiffLine::Context("two "),
-            DiffLine::Deletion("three"),
-            DiffLine::Addition("four"),
-        ];
+        let original = "one two three";
+        let expected = "one three"; // deleted "two "
+        let actual = "one two four"; // deleted "three", added "four"
 
         // Then the score should be low
-        let score = delta_chr_f(&expected, &actual);
+        let score = delta_chr_f(original, expected, actual);
         assert!(score > 20.0 && score < 40.0);
     }
 
     #[test]
     fn test_delta_chr_f_partial_match() {
-        let expected = vec![
-            DiffLine::Deletion("let x = 42;"),
-            DiffLine::Addition("let x = 100;"),
-        ];
-
-        let actual = vec![
-            DiffLine::Deletion("let x = 42;"),
-            DiffLine::Addition("let x = 99;"),
-        ];
+        let original = "let x = 42;";
+        let expected = "let x = 100;";
+        let actual = "let x = 99;";
 
         // We got the edit location right, but the replacement text is wrong.
         // Deleted ngrams will match, bringing the score somewhere in the middle.
-        let score = delta_chr_f(&expected, &actual);
+        let score = delta_chr_f(original, expected, actual);
         assert!(score > 40.0 && score < 60.0);
     }
 
     #[test]
     fn test_delta_chr_f_missed_edit() {
         // When predictions makes no changes
-        let expected = vec![
-            DiffLine::Context("prefix "),
-            DiffLine::Deletion("old"),
-            DiffLine::Addition("new"),
-            DiffLine::Context(" suffix"),
-        ];
-
-        let actual = vec![
-            DiffLine::Context("prefix "),
-            DiffLine::Context("old"),
-            DiffLine::Context(" suffix"),
-        ];
+        let original = "prefix old suffix";
+        let expected = "prefix new suffix";
+        let actual = "prefix old suffix"; // no change
 
         // Then the score should be low (all expected changes are false negatives)
-        let score = delta_chr_f(&expected, &actual);
+        let score = delta_chr_f(original, expected, actual);
         assert!(score < 20.0);
     }
 
     #[test]
     fn test_delta_chr_f_extra_edit() {
         // When adding unexpected content
-        let expected = vec![DiffLine::Context("hello"), DiffLine::Context("world")];
-
-        let actual = vec![
-            DiffLine::Context("hello"),
-            DiffLine::Addition("extra"),
-            DiffLine::Context("world"),
-        ];
+        let original = "helloworld";
+        let expected = "helloworld"; // no change expected
+        let actual = "helloextraworld"; // added "extra"
 
         // Then the score should be low (all actual changes are false positives)
-        let score = delta_chr_f(&expected, &actual);
+        let score = delta_chr_f(original, expected, actual);
         assert!(score < 20.0);
     }
+
+    #[test]
+    fn test_delta_chr_f_no_changes() {
+        let text = "unchanged text";
+        let score = delta_chr_f(text, text, text);
+        assert!((score - 100.0).abs() < 1e-2);
+    }
 }

crates/edit_prediction_cli/src/score.rs 🔗

@@ -2,11 +2,12 @@ use crate::{
     PredictArgs,
     example::{Example, ExampleScore},
     headless::EpAppState,
-    metrics::{self, ClassificationMetrics},
+    metrics,
     predict::run_prediction,
     progress::{Progress, Step},
 };
-use edit_prediction::udiff::DiffLine;
+use anyhow::Context as _;
+use edit_prediction::udiff::apply_diff_to_string;
 use gpui::AsyncApp;
 use std::sync::Arc;
 
@@ -27,18 +28,32 @@ pub async fn run_scoring(
 
     let _progress = Progress::global().start(Step::Score, &example.spec.name);
 
-    let expected_patch = parse_patch(&example.spec.expected_patch);
+    let original_text = &example.buffer.as_ref().unwrap().content;
+    let expected_texts: Vec<String> = example
+        .spec
+        .expected_patches
+        .iter()
+        .map(|patch| {
+            apply_diff_to_string(original_text, patch)
+                .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
+        })
+        .collect::<Result<Vec<_>, _>>()?;
 
     let mut scores = vec![];
-
-    for pred in &example.predictions {
-        let actual_patch = parse_patch(&pred.actual_patch);
-        let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
-        let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
-
+    for prediction in &example.predictions {
+        let actual_text = match apply_diff_to_string(original_text, &prediction.actual_patch) {
+            Ok(text) => text,
+            Err(_) => {
+                scores.push(ExampleScore { delta_chr_f: 0.0 });
+                continue;
+            }
+        };
+        let best_delta_chr_f = expected_texts
+            .iter()
+            .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
+            .fold(0.0, f32::max);
         scores.push(ExampleScore {
-            delta_chr_f,
-            line_match,
+            delta_chr_f: best_delta_chr_f,
         });
     }
 
@@ -46,42 +61,25 @@ pub async fn run_scoring(
     Ok(())
 }
 
-fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
-    patch.lines().map(DiffLine::parse).collect()
-}
-
 pub fn print_report(examples: &[Example]) {
     eprintln!(
         "──────────────────────────────────────────────────────────────────────────────────────"
     );
-    eprintln!(
-        "{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
-        "Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
-    );
+    eprintln!("{:<50} {:>10}", "Example name", "DeltaChrF");
     eprintln!(
         "──────────────────────────────────────────────────────────────────────────────────────"
     );
 
-    let mut all_line_match_scores = Vec::new();
     let mut all_delta_chr_f_scores = Vec::new();
 
     for example in examples {
         for score in example.score.iter() {
-            let line_match = &score.line_match;
-
             eprintln!(
-                "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
-                truncate_name(&example.spec.name, 30),
-                line_match.true_positives,
-                line_match.false_positives,
-                line_match.false_negatives,
-                line_match.precision() * 100.0,
-                line_match.recall() * 100.0,
-                line_match.f1_score() * 100.0,
+                "{:<50} {:>9.2}",
+                truncate_name(&example.spec.name, 50),
                 score.delta_chr_f
             );
 
-            all_line_match_scores.push(line_match.clone());
             all_delta_chr_f_scores.push(score.delta_chr_f);
         }
     }
@@ -90,22 +88,11 @@ pub fn print_report(examples: &[Example]) {
         "──────────────────────────────────────────────────────────────────────────────────────"
     );
 
-    if !all_line_match_scores.is_empty() {
-        let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
+    if !all_delta_chr_f_scores.is_empty() {
         let avg_delta_chr_f: f32 =
             all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
 
-        eprintln!(
-            "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
-            "TOTAL",
-            total_line_match.true_positives,
-            total_line_match.false_positives,
-            total_line_match.false_negatives,
-            total_line_match.precision() * 100.0,
-            total_line_match.recall() * 100.0,
-            total_line_match.f1_score() * 100.0,
-            avg_delta_chr_f
-        );
+        eprintln!("{:<50} {:>9.2}", "AVERAGE", avg_delta_chr_f);
         eprintln!(
             "──────────────────────────────────────────────────────────────────────────────────────"
         );