WIP: zeta2 eval: Compute the edit sites coverage metric

Oleksiy Syvokon and Max Brunsfeld created

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>

Change summary

crates/zeta_cli/src/evaluate.rs | 134 +++++++++++++++++++++++++++++++++-
1 file changed, 127 insertions(+), 7 deletions(-)

Detailed changes

crates/zeta_cli/src/evaluate.rs 🔗

@@ -12,7 +12,7 @@ use collections::HashSet;
 use gpui::AsyncApp;
 
 use crate::{
-    example::{Example, NamedExample},
+    example::{Example, Excerpt, NamedExample},
     headless::ZetaCliAppState,
     predict::{PredictionDetails, zeta2_predict},
 };
@@ -39,6 +39,11 @@ pub async fn run_evaluate(
     let aggregated_result = EvaluationResult {
         context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
         edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
+        edit_sites_coverage: all_results
+            .iter()
+            .map(|r| r.edit_sites_coverage)
+            .sum::<f64>()
+            / all_results.len() as f64,
     };
 
     if example_len > 1 {
@@ -106,6 +111,12 @@ pub async fn run_evaluate_one(
 #[derive(Debug, Default)]
 pub struct EvaluationResult {
     pub context: Scores,
+
+    /// Ratio of edited lines that we expect to edit (as indicated in the
+    /// expected patch) AND were included into the context
+    /// num_correctly_retrieved_lines / num_expected_lines
+    pub edit_sites_coverage: f64,
+
     pub edit_prediction: Scores,
 }
 
@@ -123,12 +134,12 @@ impl Scores {
     pub fn to_markdown(&self) -> String {
         format!(
             "
-Precision       : {:.4}
-Recall          : {:.4}
-F1 Score        : {:.4}
-True Positives  : {}
-False Positives : {}
-False Negatives : {}",
+Precision          : {:.4}
+Recall             : {:.4}
+F1 Score           : {:.4}
+True Positives     : {}
+False Positives    : {}
+False Negatives    : {}",
             self.precision,
             self.recall,
             self.f1_score,
@@ -169,17 +180,25 @@ impl Scores {
     }
 }
 
+#[derive(Debug, Clone)]
+struct EditSitesScores {
+    num_edit_sites: u32,
+    num_correctly_retrieved: u32,
+}
+
 impl EvaluationResult {
     pub fn to_markdown(&self) -> String {
         format!(
             r#"
 ### Context Scores
 {}
+Edit sites coverage: {}
 
 ### Edit Prediction Scores
 {}
 "#,
             self.context.to_markdown(),
+            self.edit_sites_coverage,
             self.edit_prediction.to_markdown()
         )
     }
@@ -229,9 +248,54 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResul
 
     result.edit_prediction = precision_recall(&expected_patch_lines, &actual_patch_lines);
 
+    result.edit_sites_coverage =
+        calculate_edit_sites_coverage(&example.expected_patch, &preds.excerpts);
+
     result
 }
 
+/// Compute the ratio of lines that we expect to edit (are in the expected patch) that
+/// were included in the retrieved context
+/// `num_correctly_retrieved_lines / num_edited_lines_in_expected_patch`
+///
+/// In order to make an edit in some line, the model has to have an access to this line.
+/// If we don't include the line in the retrieved context, there's no chance to make an edit.
+///
+/// This metric reflects that, where 1.0 -- we retrieved all lines to be
+/// edited, and 0.0 -- we retrieved none of them.
+///
+/// Example:
+fn calculate_edit_sites_coverage(patch: &str, excerpts: &[Excerpt]) -> EditSitesScores {
+    // todo:
+    let expected_patch_lines = patch
+        .lines()
+        .map(DiffLine::parse)
+        .filter_map(|line| match line {
+            DiffLine::Deletion(text) => Some(text.trim().to_string()),
+            _ => None,
+        })
+        .collect::<Vec<_>>();
+
+    let correct_cases = expected_patch_lines
+        .iter()
+        .filter(|line| {
+            excerpts.iter().any(|excerpt| {
+                excerpt
+                    .text
+                    .lines()
+                    .any(|excerpt_line| excerpt_line == *line)
+            })
+        })
+        .count();
+    let total_cases = expected_patch_lines.len();
+
+    if total_cases == 0 {
+        0.0
+    } else {
+        correct_cases as f64 / total_cases as f64
+    }
+}
+
 fn precision_recall(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
     let true_positives = expected.intersection(actual).count();
     let false_positives = actual.difference(expected).count();
@@ -336,3 +400,59 @@ pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
 
     annotated.join("\n")
 }
+
+#[cfg(test)]
+mod tests {
+    use super::calculate_edit_sites_coverage;
+    use crate::example::Excerpt;
+
+    #[test]
+    fn test_evaluate_expected_edit_places() {
+        let patch = indoc::indoc! {"
+            --- a/test.txt
+            +++ b/test.txt
+            @@ -1,4 +1,4 @@
+             apple
+            -banana
+            +BANANA
+             cherry
+            -date
+            +DATE
+            "};
+
+        let one_correct_excerpt = vec![Excerpt {
+            path: "test.txt".into(),
+            text: "apple\nbanana\n".to_string(),
+        }];
+
+        assert_eq!(
+            calculate_edit_sites_coverage(&patch, &one_correct_excerpt),
+            0.5,
+        );
+
+        let both_correct_excerpts = vec![
+            Excerpt {
+                path: "test.txt".into(),
+                text: "apple\nbanana\n".to_string(),
+            },
+            Excerpt {
+                path: "test.txt".into(),
+                text: "cherry\ndate\n".to_string(),
+            },
+        ];
+
+        assert_eq!(
+            calculate_edit_sites_coverage(&patch, &both_correct_excerpts),
+            1.0,
+        );
+
+        let incorrect_excerpts = vec![Excerpt {
+            path: "test.txt".into(),
+            text: "apple\n".into(),
+        }];
+        assert_eq!(
+            calculate_edit_sites_coverage(&patch, &incorrect_excerpts),
+            0.0,
+        );
+    }
+}