zeta: Add stats about context lines from patch that were retrieved during context retrieval (#43053)

Piotr Osiewicz created

Also known as: "Eval: expect the lines necessary to uniquely target every change in
the 'Expected Patch' to be included as context."

Release Notes:

- N/A

Change summary

crates/zeta_cli/src/evaluate.rs | 54 ++++++++++++++++++++++++++++++----
1 file changed, 47 insertions(+), 7 deletions(-)

Detailed changes

crates/zeta_cli/src/evaluate.rs 🔗

@@ -1,5 +1,5 @@
 use std::{
-    collections::HashMap,
+    collections::{BTreeSet, HashMap},
     io::{IsTerminal, Write},
     path::PathBuf,
     sync::Arc,
@@ -140,6 +140,16 @@ fn write_aggregated_scores(
             prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
             generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
                 / successful.len(),
+            context_lines_found_in_context: successful
+                .iter()
+                .map(|r| r.context_lines_found_in_context)
+                .sum::<usize>()
+                / successful.len(),
+            context_lines_in_expected_patch: successful
+                .iter()
+                .map(|r| r.context_lines_in_expected_patch)
+                .sum::<usize>()
+                / successful.len(),
         };
 
         writeln!(w, "\n{}", "-".repeat(80))?;
@@ -268,6 +278,8 @@ pub struct EvaluationResult {
     pub context: Scores,
     pub prompt_len: usize,
     pub generated_len: usize,
+    pub context_lines_in_expected_patch: usize,
+    pub context_lines_found_in_context: usize,
 }
 
 #[derive(Default, Debug)]
@@ -389,15 +401,17 @@ impl EvaluationResult {
         writeln!(f, "### Scores\n")?;
         writeln!(
             f,
-            "                   Prompt  Generated  TP     FP     FN     Precision   Recall     F1"
+            "                   Prompt  Generated RetrievedContext PatchContext     TP     FP     FN     Precision   Recall     F1"
         )?;
         writeln!(
             f,
-            "────────────────────────────────────────────────────────────────────────────────────"
+            "─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────"
         )?;
         writeln!(
             f,
-            "Context Retrieval  {:<7} {:<10} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
+            "Context Retrieval  {:<7} {:<9} {:<16} {:<16} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
+            "",
+            "",
             "",
             "",
             self.context.true_positives,
@@ -410,9 +424,11 @@ impl EvaluationResult {
         if let Some(edit_prediction) = &self.edit_prediction {
             writeln!(
                 f,
-                "Edit Prediction    {:<7} {:<10} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
+                "Edit Prediction    {:<7} {:<9} {:<16} {:<16} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
                 self.prompt_len,
                 self.generated_len,
+                self.context_lines_found_in_context,
+                self.context_lines_in_expected_patch,
                 edit_prediction.true_positives,
                 edit_prediction.false_positives,
                 edit_prediction.false_negatives,
@@ -425,7 +441,7 @@ impl EvaluationResult {
     }
 }
 
-pub fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
+fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
     let mut eval_result = EvaluationResult {
         prompt_len: preds.prompt_len,
         generated_len: preds.generated_len,
@@ -481,13 +497,35 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) ->
 
     if predict {
         // todo: alternatives for patches
-        let expected_patch_lines = example
+        let expected_patch = example
             .expected_patch
             .lines()
             .map(DiffLine::parse)
+            .collect::<Vec<_>>();
+        let expected_patch_lines = expected_patch
+            .iter()
             .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
             .map(|line| line.to_string())
             .collect();
+        let expected_context_lines = expected_patch
+            .iter()
+            .filter_map(|line| {
+                if let DiffLine::Context(str) = line {
+                    Some(String::from(*str))
+                } else {
+                    None
+                }
+            })
+            .collect::<BTreeSet<_>>();
+        let actual_context_lines = preds
+            .excerpts
+            .iter()
+            .flat_map(|excerpt| excerpt.text.lines().map(ToOwned::to_owned))
+            .collect::<BTreeSet<_>>();
+
+        let matched = expected_context_lines
+            .intersection(&actual_context_lines)
+            .count();
 
         let actual_patch_lines = preds
             .diff
@@ -498,6 +536,8 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) ->
             .collect();
 
         eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines));
+        eval_result.context_lines_in_expected_patch = expected_context_lines.len();
+        eval_result.context_lines_found_in_context = matched;
     }
 
     eval_result