@@ -1,5 +1,5 @@
use std::{
- collections::HashMap,
+ collections::{BTreeSet, HashMap},
io::{IsTerminal, Write},
path::PathBuf,
sync::Arc,
@@ -140,6 +140,16 @@ fn write_aggregated_scores(
prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
/ successful.len(),
+ context_lines_found_in_context: successful
+ .iter()
+ .map(|r| r.context_lines_found_in_context)
+ .sum::<usize>()
+ / successful.len(),
+ context_lines_in_expected_patch: successful
+ .iter()
+ .map(|r| r.context_lines_in_expected_patch)
+ .sum::<usize>()
+ / successful.len(),
};
writeln!(w, "\n{}", "-".repeat(80))?;
@@ -268,6 +278,8 @@ pub struct EvaluationResult {
pub context: Scores,
pub prompt_len: usize,
pub generated_len: usize,
+ pub context_lines_in_expected_patch: usize,
+ pub context_lines_found_in_context: usize,
}
#[derive(Default, Debug)]
@@ -389,15 +401,17 @@ impl EvaluationResult {
writeln!(f, "### Scores\n")?;
writeln!(
f,
- " Prompt Generated TP FP FN Precision Recall F1"
+ " Prompt Generated RetrievedContext PatchContext TP FP FN Precision Recall F1"
)?;
writeln!(
f,
- "────────────────────────────────────────────────────────────────────────────────────"
+ "─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────"
)?;
writeln!(
f,
- "Context Retrieval {:<7} {:<10} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
+ "Context Retrieval {:<7} {:<9} {:<16} {:<16} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
+ "",
+ "",
"",
"",
self.context.true_positives,
@@ -410,9 +424,11 @@ impl EvaluationResult {
if let Some(edit_prediction) = &self.edit_prediction {
writeln!(
f,
- "Edit Prediction {:<7} {:<10} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
+ "Edit Prediction {:<7} {:<9} {:<16} {:<16} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
self.prompt_len,
self.generated_len,
+ self.context_lines_found_in_context,
+ self.context_lines_in_expected_patch,
edit_prediction.true_positives,
edit_prediction.false_positives,
edit_prediction.false_negatives,
@@ -425,7 +441,7 @@ impl EvaluationResult {
}
}
-pub fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
+fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
let mut eval_result = EvaluationResult {
prompt_len: preds.prompt_len,
generated_len: preds.generated_len,
@@ -481,13 +497,35 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) ->
if predict {
// todo: alternatives for patches
- let expected_patch_lines = example
+ let expected_patch = example
.expected_patch
.lines()
.map(DiffLine::parse)
+ .collect::<Vec<_>>();
+ let expected_patch_lines = expected_patch
+ .iter()
.filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
.map(|line| line.to_string())
.collect();
+ let expected_context_lines = expected_patch
+ .iter()
+ .filter_map(|line| {
+ if let DiffLine::Context(str) = line {
+ Some(String::from(*str))
+ } else {
+ None
+ }
+ })
+ .collect::<BTreeSet<_>>();
+ let actual_context_lines = preds
+ .excerpts
+ .iter()
+ .flat_map(|excerpt| excerpt.text.lines().map(ToOwned::to_owned))
+ .collect::<BTreeSet<_>>();
+
+ let matched = expected_context_lines
+ .intersection(&actual_context_lines)
+ .count();
let actual_patch_lines = preds
.diff
@@ -498,6 +536,8 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) ->
.collect();
eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines));
+ eval_result.context_lines_in_expected_patch = expected_context_lines.len();
+ eval_result.context_lines_found_in_context = matched;
}
eval_result