From f60f6adb5ba458a8fe0e6dbda9de7ffc4dc56a9b Mon Sep 17 00:00:00 2001 From: Oleksiy Syvokon Date: Wed, 8 Apr 2026 14:37:31 +0300 Subject: [PATCH] ep: Add LCS-based recall This PR adds `correctly_deleted_chars` field and updates `kept_rate` to account for it, not just inserted chars. It also adds `recall_rate` to measure coverage of reference insertions/deletions. Finally, it renames "final" to "reference" and "prediction" to "candidate". --- crates/edit_prediction_cli/src/example.rs | 2 + crates/edit_prediction_cli/src/kept_rate.rs | 208 +++++++++++++------- crates/edit_prediction_cli/src/score.rs | 47 ++++- 3 files changed, 185 insertions(+), 72 deletions(-) diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index 682671141d050836d25705b2732f11500f159209..1e044b0dae353498b67ffa917d89e2945f4f7787 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -186,6 +186,8 @@ pub struct ExampleScore { #[serde(default, skip_serializing_if = "Option::is_none")] pub kept_rate: Option, #[serde(default, skip_serializing_if = "Option::is_none")] + pub recall_rate: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] pub cumulative_logprob: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub avg_logprob: Option, diff --git a/crates/edit_prediction_cli/src/kept_rate.rs b/crates/edit_prediction_cli/src/kept_rate.rs index 565597fd12b567e7f7f23be233b87ba2284a176f..632e02fad5aefe85dc8368830a73252dd3c6a6b5 100644 --- a/crates/edit_prediction_cli/src/kept_rate.rs +++ b/crates/edit_prediction_cli/src/kept_rate.rs @@ -11,12 +11,33 @@ pub enum TokenAnnotation { #[allow(dead_code)] #[derive(Debug, Clone)] pub struct KeptRateResult { - pub predicted_new_chars: usize, - pub final_new_chars: usize, + /// Characters newly introduced by the candidate + pub candidate_new_chars: usize, + /// Characters newly introduced by the reference + pub reference_new_chars: usize, + /// Characters from `base` that are deleted by the candidate. + pub candidate_deleted_chars: usize, + /// Characters from `base` that are deleted by the reference. + pub reference_deleted_chars: usize, + /// Candidate new characters that are also present in the reference. pub kept_chars: usize, + /// Base characters deleted by both the candidate and the reference. + pub correctly_deleted_chars: usize, + /// Candidate new characters that are not kept in the reference. pub discarded_chars: usize, + /// Candidate characters treated as unchanged context pub context_chars: usize, + /// Fraction of candidate edit characters that match the reference edit. + /// + /// This includes both kept newly introduced characters and correctly + /// deleted base characters. pub kept_rate: f64, + /// Fraction of reference edit characters covered by the candidate edit. + /// + /// This includes both kept newly introduced characters and correctly + /// deleted base characters. + pub recall_rate: f64, + /// Per-token classification for candidate tokens used by tests. #[cfg(test)] pub token_annotations: Vec, } @@ -155,68 +176,102 @@ fn analyze_masked_tokens<'a>(tokens: &[&'a str], mask: &[bool]) -> (Vec<&'a str> (unmasked_tokens, unmasked_chars, masked_chars) } -pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptRateResult { - if base == predicted && predicted == final_text { - let predicted_tokens = tokenize(predicted); - let context_chars = predicted_tokens.iter().map(|token| token.len()).sum(); +fn count_unmasked_chars(tokens: &[&str], mask: &[bool]) -> usize { + tokens + .iter() + .zip(mask.iter()) + .filter_map(|(&token, &is_masked)| (!is_masked).then_some(token.len())) + .sum() +} + +pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRateResult { + if base == candidate && candidate == reference { + let candidate_tokens = tokenize(candidate); + let context_chars = candidate_tokens.iter().map(|token| token.len()).sum(); return KeptRateResult { - predicted_new_chars: 0, - final_new_chars: 0, + candidate_new_chars: 0, + reference_new_chars: 0, + candidate_deleted_chars: 0, + reference_deleted_chars: 0, kept_chars: 0, + correctly_deleted_chars: 0, discarded_chars: 0, context_chars, kept_rate: 1.0, + recall_rate: 1.0, #[cfg(test)] - token_annotations: vec![TokenAnnotation::Context; predicted_tokens.len()], + token_annotations: vec![TokenAnnotation::Context; candidate_tokens.len()], }; } let base_tokens = tokenize(base); - let predicted_tokens = tokenize(predicted); - let final_tokens = tokenize(final_text); - - let (pred_base_mask, _) = lcs_keep_masks(&predicted_tokens, &base_tokens); - let (pred_final_mask, final_pred_mask) = lcs_keep_masks(&predicted_tokens, &final_tokens); - let context_mask: Vec = pred_base_mask + let candidate_tokens = tokenize(candidate); + let reference_tokens = tokenize(reference); + + let (candidate_base_mask, base_candidate_mask) = + lcs_keep_masks(&candidate_tokens, &base_tokens); + let (candidate_reference_mask, reference_candidate_mask) = + lcs_keep_masks(&candidate_tokens, &reference_tokens); + let context_mask: Vec = candidate_base_mask .iter() - .zip(pred_final_mask.iter()) + .zip(candidate_reference_mask.iter()) .map(|(&in_base, &in_final)| in_base && in_final) .collect(); - let (stripped_predicted, predicted_new_chars, context_chars) = - analyze_masked_tokens(&predicted_tokens, &context_mask); + let (stripped_candidate, candidate_new_chars, context_chars) = + analyze_masked_tokens(&candidate_tokens, &context_mask); - let (final_base_mask, _) = lcs_keep_masks(&final_tokens, &base_tokens); - let final_context_mask: Vec = final_base_mask + let (reference_base_mask, base_reference_mask) = + lcs_keep_masks(&reference_tokens, &base_tokens); + let reference_context_mask: Vec = reference_base_mask .iter() - .zip(final_pred_mask.iter()) - .map(|(&in_base, &in_predicted)| in_base && in_predicted) + .zip(reference_candidate_mask.iter()) + .map(|(&in_base, &in_candidate)| in_base && in_candidate) .collect(); - let (stripped_final, final_new_chars, _) = - analyze_masked_tokens(&final_tokens, &final_context_mask); + let (stripped_reference, reference_new_chars, _) = + analyze_masked_tokens(&reference_tokens, &reference_context_mask); - let keep_mask = lcs_keep_masks(&stripped_predicted, &stripped_final).0; + let keep_mask = lcs_keep_masks(&stripped_candidate, &stripped_reference).0; - let kept_chars: usize = stripped_predicted + let kept_chars: usize = stripped_candidate .iter() .zip(keep_mask.iter()) .filter_map(|(&token, &is_kept)| is_kept.then_some(token.len())) .sum(); - let discarded_chars = predicted_new_chars - kept_chars; + let candidate_deleted_chars = count_unmasked_chars(&base_tokens, &base_candidate_mask); + let reference_deleted_chars = count_unmasked_chars(&base_tokens, &base_reference_mask); + let correctly_deleted_chars: usize = base_tokens + .iter() + .zip(base_candidate_mask.iter().zip(base_reference_mask.iter())) + .filter_map(|(&token, (&in_candidate, &in_reference))| { + (!in_candidate && !in_reference).then_some(token.len()) + }) + .sum(); + + let discarded_chars = candidate_new_chars - kept_chars; + let matched_edit_chars = kept_chars + correctly_deleted_chars; + let candidate_edit_chars = candidate_new_chars + candidate_deleted_chars; + let reference_edit_chars = reference_new_chars + reference_deleted_chars; - let kept_rate = if predicted_new_chars == 0 { - if final_new_chars == 0 { 1.0 } else { 0.0 } + let kept_rate = if candidate_edit_chars == 0 { + if reference_edit_chars == 0 { 1.0 } else { 0.0 } } else { - kept_chars as f64 / predicted_new_chars as f64 + matched_edit_chars as f64 / candidate_edit_chars as f64 + }; + + let recall_rate = if reference_edit_chars == 0 { + if candidate_edit_chars == 0 { 1.0 } else { 0.0 } + } else { + matched_edit_chars as f64 / reference_edit_chars as f64 }; #[cfg(test)] let token_annotations = { - let mut token_annotations = Vec::with_capacity(predicted_tokens.len()); + let mut token_annotations = Vec::with_capacity(candidate_tokens.len()); let mut new_index = 0; - for (token_index, _token) in predicted_tokens.iter().enumerate() { + for (token_index, _token) in candidate_tokens.iter().enumerate() { if context_mask[token_index] { token_annotations.push(TokenAnnotation::Context); } else { @@ -234,12 +289,16 @@ pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptR }; KeptRateResult { - predicted_new_chars, - final_new_chars, + candidate_new_chars, + reference_new_chars, + candidate_deleted_chars, + reference_deleted_chars, kept_chars, + correctly_deleted_chars, discarded_chars, context_chars, kept_rate, + recall_rate, #[cfg(test)] token_annotations, } @@ -273,7 +332,8 @@ mod test_kept_rate { fn test_rate_extremes() { let no_change = compute_kept_rate("foo bar", "foo bar", "foo bar"); assert!((no_change.kept_rate - 1.0).abs() < 1e-6); - assert_eq!(no_change.predicted_new_chars, 0); + assert!((no_change.recall_rate - 1.0).abs() < 1e-6); + assert_eq!(no_change.candidate_new_chars, 0); assert!( no_change .token_annotations @@ -283,15 +343,17 @@ mod test_kept_rate { let accepted = compute_kept_rate("old", "new", "new"); assert!((accepted.kept_rate - 1.0).abs() < 1e-6); + assert!((accepted.recall_rate - 1.0).abs() < 1e-6); let discarded = compute_kept_rate("old", "old", "new"); assert!((discarded.kept_rate - 0.0).abs() < 1e-6); + assert!((discarded.recall_rate - 0.0).abs() < 1e-6); } #[test] fn test_pure_addition() { let kept = compute_kept_rate("", "brand new line\n", "brand new line\n"); - assert_eq!(kept.kept_chars, kept.predicted_new_chars); + assert_eq!(kept.kept_chars, kept.candidate_new_chars); assert!( kept.token_annotations .iter() @@ -300,26 +362,28 @@ mod test_kept_rate { let discarded = compute_kept_rate("", "brand new line\n", "something completely different\n"); - assert!(discarded.kept_chars < discarded.predicted_new_chars); + assert!(discarded.kept_chars < discarded.candidate_new_chars); } #[test] fn test_decoy_when_base_excluded() { let base = " decoy.when(mock_sync_hardware_api.sp()).then_return(SpeedStatus.IDLE)\n"; - let predicted = " decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n"; - let final_text = " decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n"; - let result = compute_kept_rate(base, predicted, final_text); + let candidate = " decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n"; + let reference = " decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n"; + let result = compute_kept_rate(base, candidate, reference); let expected_new = "mock_sync_module_hardware".len() + "speed_status".len(); - assert_eq!(result.predicted_new_chars, expected_new); + assert_eq!(result.candidate_new_chars, expected_new); + assert!(result.correctly_deleted_chars > 0); assert!((result.kept_rate - 1.0).abs() < 1e-6); + assert!((result.recall_rate - 1.0).abs() < 1e-6); } #[test] fn test_missing_deletion() { let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n epr\n"; - let predicted = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n epr\neprintln!(\"\");\n"; - let final_text = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n eprintln!(\"\");\n"; - let result = compute_kept_rate(base, predicted, final_text); + let candidate = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n epr\neprintln!(\"\");\n"; + let reference = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n eprintln!(\"\");\n"; + let result = compute_kept_rate(base, candidate, reference); assert!( result.kept_rate < 0.85, "expected kept_rate < 0.85, got {}", @@ -331,7 +395,12 @@ mod test_kept_rate { #[test] fn test_empty_prediction() { let result = compute_kept_rate("old line\n", "", "new line\n"); - assert!((result.kept_rate - 0.0).abs() < 1e-6); + assert_eq!(result.candidate_new_chars, 0); + assert!(result.candidate_deleted_chars > 0); + assert!(result.correctly_deleted_chars > 0); + assert!(result.correctly_deleted_chars < result.candidate_deleted_chars); + assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0); + assert!(result.recall_rate > 0.0 && result.recall_rate < 1.0); } #[test] @@ -345,9 +414,9 @@ mod test_kept_rate { #[test] fn test_eprintln_token_alignment() { let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n epr\n"; - let predicted = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n eprintln!(\"hello world!\");\n"; - let final_text = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n eprintln!(\"\");\n"; - let result = compute_kept_rate(base, predicted, final_text); + let candidate = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n eprintln!(\"hello world!\");\n"; + let reference = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n eprintln!(\"\");\n"; + let result = compute_kept_rate(base, candidate, reference); assert!(result.discarded_chars > 0); assert!(result.kept_chars > 0); assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0); @@ -358,14 +427,18 @@ mod test_kept_rate { #[test] fn test_annotations_rename() { let base = " foo(old_name)\n"; - let predicted = " foo(new_name)\n"; - let final_text = " foo(new_name)\n"; - let result = compute_kept_rate(base, predicted, final_text); - - assert_eq!(result.predicted_new_chars, "new_name".len()); - assert_eq!(result.token_annotations.len(), tokenize(predicted).len()); - - for (&token, &annotation) in tokenize(predicted).iter().zip(&result.token_annotations) { + let candidate = " foo(new_name)\n"; + let reference = " foo(new_name)\n"; + let result = compute_kept_rate(base, candidate, reference); + + assert_eq!(result.candidate_new_chars, "new_name".len()); + assert_eq!(result.candidate_deleted_chars, "old_name".len()); + assert_eq!(result.reference_deleted_chars, "old_name".len()); + assert_eq!(result.correctly_deleted_chars, "old_name".len()); + assert!((result.recall_rate - 1.0).abs() < 1e-6); + assert_eq!(result.token_annotations.len(), tokenize(candidate).len()); + + for (&token, &annotation) in tokenize(candidate).iter().zip(&result.token_annotations) { if token == "new_name" { assert_eq!(annotation, TokenAnnotation::Kept); } else { @@ -377,12 +450,12 @@ mod test_kept_rate { #[test] fn test_annotations_eprintln_coloring() { let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n epr\n"; - let predicted = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n eprintln!(\"hello world!\");\n"; - let final_text = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n eprintln!(\"\");\n"; - let result = compute_kept_rate(base, predicted, final_text); - let predicted_tokens = tokenize(predicted); + let candidate = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n eprintln!(\"hello world!\");\n"; + let reference = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n eprintln!(\"\");\n"; + let result = compute_kept_rate(base, candidate, reference); + let candidate_tokens = tokenize(candidate); - let eprintln_index = predicted_tokens + let eprintln_index = candidate_tokens .iter() .position(|&token| token == "eprintln") .expect("eprintln token not found"); @@ -416,12 +489,15 @@ mod test_kept_rate { #[test] fn test_repetitive_tokens_remain_discarded() { let base = "foo + foo + foo + foo + foo\n".repeat(16); - let predicted = "foo + foo + prediction_token + foo + foo\n".repeat(16); - let final_text = "foo + foo + kept_token + foo + foo\n".repeat(16); - let result = compute_kept_rate(&base, &predicted, &final_text); + let candidate = "foo + foo + prediction_token + foo + foo\n".repeat(16); + let reference = "foo + foo + kept_token + foo + foo\n".repeat(16); + let result = compute_kept_rate(&base, &candidate, &reference); assert_eq!(result.kept_chars, 0); - assert_eq!(result.discarded_chars, result.predicted_new_chars); - assert_eq!(result.predicted_new_chars, "prediction_token".len() * 16); + assert_eq!(result.correctly_deleted_chars, "foo".len() * 16); + assert_eq!(result.discarded_chars, result.candidate_new_chars); + assert_eq!(result.candidate_new_chars, "prediction_token".len() * 16); + assert!(result.kept_rate > 0.0); + assert!(result.recall_rate > 0.0); } } diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs index 1dace832d4998362610e860b386f4db49f965144..f30cf7d106f737f1e479fdac38adc10e4effcea2 100644 --- a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -85,6 +85,7 @@ pub async fn run_scoring( inserted_tokens: 0, deleted_tokens: 0, kept_rate: None, + recall_rate: None, cumulative_logprob: None, avg_logprob: None, }; @@ -187,9 +188,13 @@ pub async fn run_scoring( prediction.actual_cursor.as_ref(), ); - let kept_rate = best_expected_text.map(|final_text| { - metrics::compute_kept_rate(original_text, &actual_text, final_text).kept_rate - }); + let (kept_rate, recall_rate) = best_expected_text + .map(|reference_text| { + let result = + metrics::compute_kept_rate(original_text, &actual_text, reference_text); + (Some(result.kept_rate), Some(result.recall_rate)) + }) + .unwrap_or((None, None)); scores.push(ExampleScore { delta_chr_f: best_delta_chr_f_metrics.score as f32, @@ -211,6 +216,7 @@ pub async fn run_scoring( inserted_tokens: token_changes.inserted_tokens, deleted_tokens: token_changes.deleted_tokens, kept_rate, + recall_rate, cumulative_logprob: prediction.cumulative_logprob, avg_logprob: prediction.avg_logprob, }); @@ -277,6 +283,8 @@ pub fn print_report(examples: &[Example], verbose: bool) { let mut isolated_whitespace_count: usize = 0; let mut kept_rate_sum: f64 = 0.0; let mut kept_rate_count: usize = 0; + let mut recall_rate_sum: f64 = 0.0; + let mut recall_rate_count: usize = 0; let mut patch_inserted_tokens: Vec = Vec::new(); let mut patch_deleted_tokens: Vec = Vec::new(); let mut predictions_with_patch: usize = 0; @@ -369,11 +377,15 @@ pub fn print_report(examples: &[Example], verbose: bool) { isolated_whitespace_count += 1; } - // Accumulate kept rate metrics + // Accumulate kept and recall rate metrics if let Some(kr) = score.kept_rate { kept_rate_sum += kr; kept_rate_count += 1; } + if let Some(rr) = score.recall_rate { + recall_rate_sum += rr; + recall_rate_count += 1; + } // Accumulate token change metrics (only for predictions that produced a patch) let has_patch = example @@ -504,7 +516,7 @@ pub fn print_report(examples: &[Example], verbose: bool) { println!("Isolated whitespace changes: {}", isolated_ws_str); } - // Print kept rate metrics + // Print kept and recall rate metrics if kept_rate_count > 0 { let avg_kept_rate = kept_rate_sum / kept_rate_count as f64; println!( @@ -513,6 +525,14 @@ pub fn print_report(examples: &[Example], verbose: bool) { kept_rate_count ); } + if recall_rate_count > 0 { + let avg_recall_rate = recall_rate_sum / recall_rate_count as f64; + println!( + "Recall rate: {:.1}% avg ({} evaluated)", + avg_recall_rate * 100.0, + recall_rate_count + ); + } // Print token change percentile summary (only for predictions with a patch) if !patch_inserted_tokens.is_empty() { @@ -618,6 +638,8 @@ pub struct SummaryJson { pub isolated_whitespace_rate: Option, #[serde(skip_serializing_if = "Option::is_none")] pub avg_kept_rate: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub avg_recall_rate: Option, } pub fn compute_summary(examples: &[Example]) -> SummaryJson { @@ -645,6 +667,8 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { let mut isolated_whitespace_count: usize = 0; let mut kept_rate_sum: f64 = 0.0; let mut kept_rate_count: usize = 0; + let mut recall_rate_sum: f64 = 0.0; + let mut recall_rate_count: usize = 0; for example in examples { for (score_idx, score) in example.score.iter().enumerate() { @@ -685,11 +709,15 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { isolated_whitespace_count += 1; } - // Accumulate kept rate metrics + // Accumulate kept and recall rate metrics if let Some(kr) = score.kept_rate { kept_rate_sum += kr; kept_rate_count += 1; } + if let Some(rr) = score.recall_rate { + recall_rate_sum += rr; + recall_rate_count += 1; + } // Accumulate cursor metrics if let Some(exact_match) = score.cursor_exact_match { @@ -771,6 +799,12 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { None }; + let avg_recall_rate = if recall_rate_count > 0 { + Some(recall_rate_sum / recall_rate_count as f64) + } else { + None + }; + SummaryJson { total_examples: total_scores, avg_delta_chr_f, @@ -804,6 +838,7 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { wrong_editable_region_rate, isolated_whitespace_rate, avg_kept_rate, + avg_recall_rate, } }