diff --git a/crates/zeta_cli/src/evaluate.rs b/crates/zeta_cli/src/evaluate.rs index 9f087188b7f7a615398eaab19ae934cdcd5c64ff..14dc0f6c0c105919b822b9077211a1e1d9686d04 100644 --- a/crates/zeta_cli/src/evaluate.rs +++ b/crates/zeta_cli/src/evaluate.rs @@ -1,5 +1,5 @@ use std::{ - collections::HashMap, + collections::{BTreeSet, HashMap}, io::{IsTerminal, Write}, path::PathBuf, sync::Arc, @@ -140,6 +140,16 @@ fn write_aggregated_scores( prompt_len: successful.iter().map(|r| r.prompt_len).sum::() / successful.len(), generated_len: successful.iter().map(|r| r.generated_len).sum::() / successful.len(), + context_lines_found_in_context: successful + .iter() + .map(|r| r.context_lines_found_in_context) + .sum::() + / successful.len(), + context_lines_in_expected_patch: successful + .iter() + .map(|r| r.context_lines_in_expected_patch) + .sum::() + / successful.len(), }; writeln!(w, "\n{}", "-".repeat(80))?; @@ -268,6 +278,8 @@ pub struct EvaluationResult { pub context: Scores, pub prompt_len: usize, pub generated_len: usize, + pub context_lines_in_expected_patch: usize, + pub context_lines_found_in_context: usize, } #[derive(Default, Debug)] @@ -389,15 +401,17 @@ impl EvaluationResult { writeln!(f, "### Scores\n")?; writeln!( f, - " Prompt Generated TP FP FN Precision Recall F1" + " Prompt Generated RetrievedContext PatchContext TP FP FN Precision Recall F1" )?; writeln!( f, - "────────────────────────────────────────────────────────────────────────────────────" + "─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────" )?; writeln!( f, - "Context Retrieval {:<7} {:<10} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}", + "Context Retrieval {:<7} {:<9} {:<16} {:<16} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}", + "", + "", "", "", self.context.true_positives, @@ -410,9 +424,11 @@ impl EvaluationResult { if let Some(edit_prediction) = &self.edit_prediction { writeln!( f, - "Edit Prediction {:<7} {:<10} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}", + "Edit Prediction {:<7} {:<9} {:<16} {:<16} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}", self.prompt_len, self.generated_len, + self.context_lines_found_in_context, + self.context_lines_in_expected_patch, edit_prediction.true_positives, edit_prediction.false_positives, edit_prediction.false_negatives, @@ -425,7 +441,7 @@ impl EvaluationResult { } } -pub fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult { +fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult { let mut eval_result = EvaluationResult { prompt_len: preds.prompt_len, generated_len: preds.generated_len, @@ -481,13 +497,35 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> if predict { // todo: alternatives for patches - let expected_patch_lines = example + let expected_patch = example .expected_patch .lines() .map(DiffLine::parse) + .collect::>(); + let expected_patch_lines = expected_patch + .iter() .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_))) .map(|line| line.to_string()) .collect(); + let expected_context_lines = expected_patch + .iter() + .filter_map(|line| { + if let DiffLine::Context(str) = line { + Some(String::from(*str)) + } else { + None + } + }) + .collect::>(); + let actual_context_lines = preds + .excerpts + .iter() + .flat_map(|excerpt| excerpt.text.lines().map(ToOwned::to_owned)) + .collect::>(); + + let matched = expected_context_lines + .intersection(&actual_context_lines) + .count(); let actual_patch_lines = preds .diff @@ -498,6 +536,8 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> .collect(); eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines)); + eval_result.context_lines_in_expected_patch = expected_context_lines.len(); + eval_result.context_lines_found_in_context = matched; } eval_result