diff --git a/Cargo.lock b/Cargo.lock index d88eff40b621a72a3216f1da56e5917706655d75..8b428dbcd537e33088f40fdde5e3251a6148672a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5173,6 +5173,7 @@ dependencies = [ "client", "cloud_llm_client", "collections", + "criterion", "db", "debug_adapter_extension", "dirs 4.0.0", diff --git a/crates/edit_prediction_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml index 83a78641bc2b14a9ea92cc0eae674135444ac691..323ee3de41902b2140f95da22b0e37fb98d31fd5 100644 --- a/crates/edit_prediction_cli/Cargo.toml +++ b/crates/edit_prediction_cli/Cargo.toml @@ -8,6 +8,9 @@ license = "GPL-3.0-or-later" [lints] workspace = true +[lib] +path = "src/lib.rs" + [[bin]] name = "ep" path = "src/main.rs" @@ -80,9 +83,14 @@ dynamic_prompts = [] ignored = ["wasmtime"] [dev-dependencies] +criterion.workspace = true gpui = { workspace = true, features = ["test-support"] } indoc.workspace = true pretty_assertions.workspace = true project = { workspace = true, features = ["test-support"] } tempfile.workspace = true workspace = { workspace = true, features = ["test-support"] } + +[[bench]] +name = "kept_rate" +harness = false diff --git a/crates/edit_prediction_cli/benches/kept_rate.rs b/crates/edit_prediction_cli/benches/kept_rate.rs new file mode 100644 index 0000000000000000000000000000000000000000..eccbb42dc0591ee15a0b942a4c326d0e4f2123ee --- /dev/null +++ b/crates/edit_prediction_cli/benches/kept_rate.rs @@ -0,0 +1,128 @@ +use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; +use edit_prediction_cli::kept_rate::compute_kept_rate; + +fn repeated_function_lines(line_count: usize) -> String { + let mut text = String::with_capacity(line_count * 32); + for index in 0..line_count { + text.push_str("fn helper_"); + text.push_str(&(index % 16).to_string()); + text.push_str("() { value += old_name + 1; }\n"); + } + text +} + +fn localized_rename_inputs(line_count: usize) -> (String, String, String) { + let base = repeated_function_lines(line_count); + let mut predicted = base.clone(); + let mut final_text = base.clone(); + + let needle = "value += old_name + 1;"; + let prediction = "value += very_long_predicted_name + 1;"; + let accepted = "value += new_name + 1;"; + + let offset = base + .rfind(needle) + .expect("expected needle in synthetic input"); + let end = offset + needle.len(); + + predicted.replace_range(offset..end, prediction); + final_text.replace_range(offset..end, accepted); + + (base, predicted, final_text) +} + +fn identical_new_content_inputs(line_count: usize) -> (String, String, String) { + let predicted = repeated_function_lines(line_count); + (String::new(), predicted.clone(), predicted) +} + +fn repetitive_token_inputs(token_repetitions: usize) -> (String, String, String) { + let repeated_old = "foo + foo + foo + foo + foo\n".repeat(token_repetitions); + let repeated_predicted = "foo + foo + prediction_token + foo + foo\n".repeat(token_repetitions); + let repeated_final = "foo + foo + kept_token + foo + foo\n".repeat(token_repetitions); + (repeated_old, repeated_predicted, repeated_final) +} + +fn kept_rate_benchmark(c: &mut Criterion) { + let mut no_change_group = c.benchmark_group("kept_rate/no_change"); + for line_count in [128usize, 512, 2048] { + let text = repeated_function_lines(line_count); + no_change_group.bench_with_input( + BenchmarkId::new("lines", line_count), + &text, + |bench, text| { + bench.iter(|| { + black_box(compute_kept_rate( + black_box(text), + black_box(text), + black_box(text), + )); + }); + }, + ); + } + no_change_group.finish(); + + let mut localized_group = c.benchmark_group("kept_rate/localized_rename"); + for line_count in [128usize, 512, 2048] { + let inputs = localized_rename_inputs(line_count); + localized_group.bench_with_input( + BenchmarkId::new("lines", line_count), + &inputs, + |bench, inputs| { + let (base, predicted, final_text) = inputs; + bench.iter(|| { + black_box(compute_kept_rate( + black_box(base), + black_box(predicted), + black_box(final_text), + )); + }); + }, + ); + } + localized_group.finish(); + + let mut addition_group = c.benchmark_group("kept_rate/identical_addition"); + for line_count in [128usize, 512, 2048] { + let inputs = identical_new_content_inputs(line_count); + addition_group.bench_with_input( + BenchmarkId::new("lines", line_count), + &inputs, + |bench, inputs| { + let (base, predicted, final_text) = inputs; + bench.iter(|| { + black_box(compute_kept_rate( + black_box(base), + black_box(predicted), + black_box(final_text), + )); + }); + }, + ); + } + addition_group.finish(); + + let mut repetitive_group = c.benchmark_group("kept_rate/repetitive_tokens"); + for token_repetitions in [64usize, 256, 1024] { + let inputs = repetitive_token_inputs(token_repetitions); + repetitive_group.bench_with_input( + BenchmarkId::new("repetitions", token_repetitions), + &inputs, + |bench, inputs| { + let (base, predicted, final_text) = inputs; + bench.iter(|| { + black_box(compute_kept_rate( + black_box(base), + black_box(predicted), + black_box(final_text), + )); + }); + }, + ); + } + repetitive_group.finish(); +} + +criterion_group!(benches, kept_rate_benchmark); +criterion_main!(benches); diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index 4827337d37a211056d04cf9ca13f8d49fb91c392..682671141d050836d25705b2732f11500f159209 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -184,6 +184,8 @@ pub struct ExampleScore { #[serde(default)] pub deleted_tokens: usize, #[serde(default, skip_serializing_if = "Option::is_none")] + pub kept_rate: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] pub cumulative_logprob: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub avg_logprob: Option, diff --git a/crates/edit_prediction_cli/src/kept_rate.rs b/crates/edit_prediction_cli/src/kept_rate.rs new file mode 100644 index 0000000000000000000000000000000000000000..565597fd12b567e7f7f23be233b87ba2284a176f --- /dev/null +++ b/crates/edit_prediction_cli/src/kept_rate.rs @@ -0,0 +1,427 @@ +use crate::word_diff::tokenize; + +#[cfg(test)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TokenAnnotation { + Context, + Kept, + Discarded, +} + +#[allow(dead_code)] +#[derive(Debug, Clone)] +pub struct KeptRateResult { + pub predicted_new_chars: usize, + pub final_new_chars: usize, + pub kept_chars: usize, + pub discarded_chars: usize, + pub context_chars: usize, + pub kept_rate: f64, + #[cfg(test)] + pub token_annotations: Vec, +} + +fn dp_index(width: usize, row: usize, column: usize) -> usize { + row * width + column +} + +/// Return masks over `a` and `b` using one-sided LCS tie-breaking for each +/// side while sharing a single DP table construction. +fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec, Vec) { + if a.is_empty() || b.is_empty() { + return (vec![false; a.len()], vec![false; b.len()]); + } + + if a == b { + return (vec![true; a.len()], vec![true; b.len()]); + } + + let mut keep_a = vec![false; a.len()]; + let mut keep_b = vec![false; b.len()]; + + let prefix_len = a + .iter() + .zip(b.iter()) + .take_while(|(left, right)| left == right) + .count(); + let suffix_len = { + let max_suffix = (a.len() - prefix_len).min(b.len() - prefix_len); + let mut suffix_len = 0; + + while suffix_len < max_suffix { + let a_index = a.len() - 1 - suffix_len; + let b_index = b.len() - 1 - suffix_len; + if a[a_index] != b[b_index] { + break; + } + suffix_len += 1; + } + + suffix_len + }; + + for index in 0..prefix_len { + keep_a[index] = true; + keep_b[index] = true; + } + + for offset in 0..suffix_len { + let a_index = a.len() - suffix_len + offset; + let b_index = b.len() - suffix_len + offset; + keep_a[a_index] = true; + keep_b[b_index] = true; + } + + let a_mid = &a[prefix_len..a.len() - suffix_len]; + let b_mid = &b[prefix_len..b.len() - suffix_len]; + + if a_mid.is_empty() || b_mid.is_empty() { + return (keep_a, keep_b); + } + + let row_count = a_mid.len() + 1; + let column_count = b_mid.len() + 1; + let mut dp = vec![0u32; row_count * column_count]; + + for i in 1..row_count { + let token_a = a_mid[i - 1]; + for j in 1..column_count { + let index = dp_index(column_count, i, j); + if token_a == b_mid[j - 1] { + dp[index] = dp[dp_index(column_count, i - 1, j - 1)] + 1; + } else { + let up = dp[dp_index(column_count, i - 1, j)]; + let left = dp[dp_index(column_count, i, j - 1)]; + dp[index] = up.max(left); + } + } + } + + let mut i = a_mid.len(); + let mut j = b_mid.len(); + + while i > 0 && j > 0 { + if a_mid[i - 1] == b_mid[j - 1] { + keep_a[prefix_len + i - 1] = true; + i -= 1; + j -= 1; + } else { + let up = dp[dp_index(column_count, i - 1, j)]; + let left = dp[dp_index(column_count, i, j - 1)]; + if up >= left { + i -= 1; + } else { + j -= 1; + } + } + } + + let mut i = a_mid.len(); + let mut j = b_mid.len(); + + while i > 0 && j > 0 { + if a_mid[i - 1] == b_mid[j - 1] { + keep_b[prefix_len + j - 1] = true; + i -= 1; + j -= 1; + } else { + let up = dp[dp_index(column_count, i - 1, j)]; + let left = dp[dp_index(column_count, i, j - 1)]; + if left >= up { + j -= 1; + } else { + i -= 1; + } + } + } + + (keep_a, keep_b) +} + +fn analyze_masked_tokens<'a>(tokens: &[&'a str], mask: &[bool]) -> (Vec<&'a str>, usize, usize) { + let mut unmasked_tokens = Vec::with_capacity(tokens.len()); + let mut unmasked_chars = 0; + let mut masked_chars = 0; + + for (&token, &is_masked) in tokens.iter().zip(mask.iter()) { + if is_masked { + masked_chars += token.len(); + } else { + unmasked_tokens.push(token); + unmasked_chars += token.len(); + } + } + + (unmasked_tokens, unmasked_chars, masked_chars) +} + +pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptRateResult { + if base == predicted && predicted == final_text { + let predicted_tokens = tokenize(predicted); + let context_chars = predicted_tokens.iter().map(|token| token.len()).sum(); + return KeptRateResult { + predicted_new_chars: 0, + final_new_chars: 0, + kept_chars: 0, + discarded_chars: 0, + context_chars, + kept_rate: 1.0, + #[cfg(test)] + token_annotations: vec![TokenAnnotation::Context; predicted_tokens.len()], + }; + } + + let base_tokens = tokenize(base); + let predicted_tokens = tokenize(predicted); + let final_tokens = tokenize(final_text); + + let (pred_base_mask, _) = lcs_keep_masks(&predicted_tokens, &base_tokens); + let (pred_final_mask, final_pred_mask) = lcs_keep_masks(&predicted_tokens, &final_tokens); + let context_mask: Vec = pred_base_mask + .iter() + .zip(pred_final_mask.iter()) + .map(|(&in_base, &in_final)| in_base && in_final) + .collect(); + + let (stripped_predicted, predicted_new_chars, context_chars) = + analyze_masked_tokens(&predicted_tokens, &context_mask); + + let (final_base_mask, _) = lcs_keep_masks(&final_tokens, &base_tokens); + let final_context_mask: Vec = final_base_mask + .iter() + .zip(final_pred_mask.iter()) + .map(|(&in_base, &in_predicted)| in_base && in_predicted) + .collect(); + + let (stripped_final, final_new_chars, _) = + analyze_masked_tokens(&final_tokens, &final_context_mask); + + let keep_mask = lcs_keep_masks(&stripped_predicted, &stripped_final).0; + + let kept_chars: usize = stripped_predicted + .iter() + .zip(keep_mask.iter()) + .filter_map(|(&token, &is_kept)| is_kept.then_some(token.len())) + .sum(); + + let discarded_chars = predicted_new_chars - kept_chars; + + let kept_rate = if predicted_new_chars == 0 { + if final_new_chars == 0 { 1.0 } else { 0.0 } + } else { + kept_chars as f64 / predicted_new_chars as f64 + }; + + #[cfg(test)] + let token_annotations = { + let mut token_annotations = Vec::with_capacity(predicted_tokens.len()); + let mut new_index = 0; + for (token_index, _token) in predicted_tokens.iter().enumerate() { + if context_mask[token_index] { + token_annotations.push(TokenAnnotation::Context); + } else { + let annotation = if keep_mask[new_index] { + TokenAnnotation::Kept + } else { + TokenAnnotation::Discarded + }; + #[cfg(test)] + token_annotations.push(annotation); + new_index += 1; + } + } + token_annotations + }; + + KeptRateResult { + predicted_new_chars, + final_new_chars, + kept_chars, + discarded_chars, + context_chars, + kept_rate, + #[cfg(test)] + token_annotations, + } +} + +#[cfg(test)] +mod test_kept_rate { + use super::*; + + #[test] + fn test_lcs_keep_masks() { + let (a_mask, b_mask) = lcs_keep_masks(&["a", "b", "c", "d", "e"], &["a", "c", "e"]); + assert_eq!(a_mask, vec![true, false, true, false, true]); + assert_eq!(b_mask, vec![true, true, true]); + + let (a_mask, b_mask) = lcs_keep_masks(&[], &["x"]); + assert!(a_mask.is_empty()); + assert_eq!(b_mask, vec![false]); + } + + #[test] + fn test_lcs_keep_masks_matches_historical_one_sided_masks() { + let a = ["x", "a", "x", "b"]; + let b = ["a", "x", "b", "x"]; + let (a_mask, b_mask) = lcs_keep_masks(&a, &b); + assert_eq!(a_mask, lcs_keep_masks(&a, &b).0); + assert_eq!(b_mask, lcs_keep_masks(&b, &a).0); + } + + #[test] + fn test_rate_extremes() { + let no_change = compute_kept_rate("foo bar", "foo bar", "foo bar"); + assert!((no_change.kept_rate - 1.0).abs() < 1e-6); + assert_eq!(no_change.predicted_new_chars, 0); + assert!( + no_change + .token_annotations + .iter() + .all(|&annotation| annotation == TokenAnnotation::Context) + ); + + let accepted = compute_kept_rate("old", "new", "new"); + assert!((accepted.kept_rate - 1.0).abs() < 1e-6); + + let discarded = compute_kept_rate("old", "old", "new"); + assert!((discarded.kept_rate - 0.0).abs() < 1e-6); + } + + #[test] + fn test_pure_addition() { + let kept = compute_kept_rate("", "brand new line\n", "brand new line\n"); + assert_eq!(kept.kept_chars, kept.predicted_new_chars); + assert!( + kept.token_annotations + .iter() + .all(|&annotation| annotation == TokenAnnotation::Kept) + ); + + let discarded = + compute_kept_rate("", "brand new line\n", "something completely different\n"); + assert!(discarded.kept_chars < discarded.predicted_new_chars); + } + + #[test] + fn test_decoy_when_base_excluded() { + let base = " decoy.when(mock_sync_hardware_api.sp()).then_return(SpeedStatus.IDLE)\n"; + let predicted = " decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n"; + let final_text = " decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n"; + let result = compute_kept_rate(base, predicted, final_text); + let expected_new = "mock_sync_module_hardware".len() + "speed_status".len(); + assert_eq!(result.predicted_new_chars, expected_new); + assert!((result.kept_rate - 1.0).abs() < 1e-6); + } + + #[test] + fn test_missing_deletion() { + let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n epr\n"; + let predicted = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n epr\neprintln!(\"\");\n"; + let final_text = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n eprintln!(\"\");\n"; + let result = compute_kept_rate(base, predicted, final_text); + assert!( + result.kept_rate < 0.85, + "expected kept_rate < 0.85, got {}", + result.kept_rate + ); + assert!(result.discarded_chars > 0); + } + + #[test] + fn test_empty_prediction() { + let result = compute_kept_rate("old line\n", "", "new line\n"); + assert!((result.kept_rate - 0.0).abs() < 1e-6); + } + + #[test] + fn test_partial_kept() { + let result = compute_kept_rate("old\n", "alpha\nbeta\ngamma\n", "alpha\ngamma\n"); + assert!(result.kept_chars > 0); + assert!(result.discarded_chars > 0); + assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0); + } + + #[test] + fn test_eprintln_token_alignment() { + let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n epr\n"; + let predicted = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n eprintln!(\"hello world!\");\n"; + let final_text = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n eprintln!(\"\");\n"; + let result = compute_kept_rate(base, predicted, final_text); + assert!(result.discarded_chars > 0); + assert!(result.kept_chars > 0); + assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0); + assert_eq!(result.kept_chars, 14); + assert_eq!(result.discarded_chars, 12); + } + + #[test] + fn test_annotations_rename() { + let base = " foo(old_name)\n"; + let predicted = " foo(new_name)\n"; + let final_text = " foo(new_name)\n"; + let result = compute_kept_rate(base, predicted, final_text); + + assert_eq!(result.predicted_new_chars, "new_name".len()); + assert_eq!(result.token_annotations.len(), tokenize(predicted).len()); + + for (&token, &annotation) in tokenize(predicted).iter().zip(&result.token_annotations) { + if token == "new_name" { + assert_eq!(annotation, TokenAnnotation::Kept); + } else { + assert_eq!(annotation, TokenAnnotation::Context); + } + } + } + + #[test] + fn test_annotations_eprintln_coloring() { + let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n epr\n"; + let predicted = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n eprintln!(\"hello world!\");\n"; + let final_text = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n eprintln!(\"\");\n"; + let result = compute_kept_rate(base, predicted, final_text); + let predicted_tokens = tokenize(predicted); + + let eprintln_index = predicted_tokens + .iter() + .position(|&token| token == "eprintln") + .expect("eprintln token not found"); + + for annotation in &result.token_annotations[..eprintln_index] { + assert_eq!(*annotation, TokenAnnotation::Context); + } + + assert_eq!( + &result.token_annotations[eprintln_index..=eprintln_index + 10], + &[ + TokenAnnotation::Kept, + TokenAnnotation::Kept, + TokenAnnotation::Kept, + TokenAnnotation::Kept, + TokenAnnotation::Discarded, + TokenAnnotation::Discarded, + TokenAnnotation::Discarded, + TokenAnnotation::Discarded, + TokenAnnotation::Kept, + TokenAnnotation::Kept, + TokenAnnotation::Kept, + ] + ); + assert_eq!( + result.token_annotations.last(), + Some(&TokenAnnotation::Context) + ); + } + + #[test] + fn test_repetitive_tokens_remain_discarded() { + let base = "foo + foo + foo + foo + foo\n".repeat(16); + let predicted = "foo + foo + prediction_token + foo + foo\n".repeat(16); + let final_text = "foo + foo + kept_token + foo + foo\n".repeat(16); + let result = compute_kept_rate(&base, &predicted, &final_text); + + assert_eq!(result.kept_chars, 0); + assert_eq!(result.discarded_chars, result.predicted_new_chars); + assert_eq!(result.predicted_new_chars, "prediction_token".len() * 16); + } +} diff --git a/crates/edit_prediction_cli/src/lib.rs b/crates/edit_prediction_cli/src/lib.rs new file mode 100644 index 0000000000000000000000000000000000000000..920bd942675b460c1a292cda7024ad914ba8167c --- /dev/null +++ b/crates/edit_prediction_cli/src/lib.rs @@ -0,0 +1,4 @@ +#[allow(dead_code)] +mod word_diff; + +pub mod kept_rate; diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index cf9232a04a40df507c187d53becfedcd8db03188..0f29d33947612d64b74f4fd847957ced5ad359a4 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -5,6 +5,7 @@ mod filter_languages; mod format_prompt; mod git; mod headless; +mod kept_rate; mod load_project; mod metrics; mod openai_client; diff --git a/crates/edit_prediction_cli/src/metrics.rs b/crates/edit_prediction_cli/src/metrics.rs index 8037699f4bb6f851fdadb05b435b090b911b010a..ffa26beea6eeb52a9dfdfe823ad474f9e63627a8 100644 --- a/crates/edit_prediction_cli/src/metrics.rs +++ b/crates/edit_prediction_cli/src/metrics.rs @@ -1297,3 +1297,5 @@ index abc123..def456 100644 ); } } + +pub use crate::kept_rate::compute_kept_rate; diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs index be9b185809e6e0cd49e0befbeecec0f317339342..cb1bd472c3e4268fe0e1037e331ed8cbd0b51cfb 100644 --- a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -84,6 +84,7 @@ pub async fn run_scoring( has_isolated_whitespace_changes: false, inserted_tokens: 0, deleted_tokens: 0, + kept_rate: None, cumulative_logprob: None, avg_logprob: None, }; @@ -120,12 +121,14 @@ pub async fn run_scoring( let mut best_delta_chr_f_metrics = metrics::DeltaChrFMetrics::default(); let mut best_expected_cursor: Option = None; let mut best_patch_idx: Option = None; + let mut best_expected_text: Option<&str> = None; for (idx, expected) in expected_texts.iter().enumerate() { let delta_chr_f_metrics = metrics::delta_chr_f(original_text, expected, &actual_text); if delta_chr_f_metrics.score > best_delta_chr_f_metrics.score { best_delta_chr_f_metrics = delta_chr_f_metrics; best_patch_idx = Some(idx); + best_expected_text = Some(expected); } } @@ -184,6 +187,10 @@ pub async fn run_scoring( prediction.actual_cursor.as_ref(), ); + let kept_rate = best_expected_text.map(|final_text| { + metrics::compute_kept_rate(original_text, &actual_text, final_text).kept_rate + }); + scores.push(ExampleScore { delta_chr_f: best_delta_chr_f_metrics.score as f32, delta_chr_f_true_positives: best_delta_chr_f_metrics.counts.true_positives, @@ -203,6 +210,7 @@ pub async fn run_scoring( has_isolated_whitespace_changes, inserted_tokens: token_changes.inserted_tokens, deleted_tokens: token_changes.deleted_tokens, + kept_rate, cumulative_logprob: prediction.cumulative_logprob, avg_logprob: prediction.avg_logprob, }); @@ -267,6 +275,8 @@ pub fn print_report(examples: &[Example], verbose: bool) { let mut wrong_editable_region_count: usize = 0; let mut wrong_editable_region_total: usize = 0; let mut isolated_whitespace_count: usize = 0; + let mut kept_rate_sum: f64 = 0.0; + let mut kept_rate_count: usize = 0; let mut patch_inserted_tokens: Vec = Vec::new(); let mut patch_deleted_tokens: Vec = Vec::new(); let mut predictions_with_patch: usize = 0; @@ -359,6 +369,12 @@ pub fn print_report(examples: &[Example], verbose: bool) { isolated_whitespace_count += 1; } + // Accumulate kept rate metrics + if let Some(kr) = score.kept_rate { + kept_rate_sum += kr; + kept_rate_count += 1; + } + // Accumulate token change metrics (only for predictions that produced a patch) let has_patch = example .predictions @@ -488,6 +504,16 @@ pub fn print_report(examples: &[Example], verbose: bool) { println!("Isolated whitespace changes: {}", isolated_ws_str); } + // Print kept rate metrics + if kept_rate_count > 0 { + let avg_kept_rate = kept_rate_sum / kept_rate_count as f64; + println!( + "Kept rate: {:.1}% avg ({} evaluated)", + avg_kept_rate * 100.0, + kept_rate_count + ); + } + // Print token change percentile summary (only for predictions with a patch) if !patch_inserted_tokens.is_empty() { patch_inserted_tokens.sort_unstable(); @@ -590,6 +616,8 @@ pub struct SummaryJson { #[serde(skip_serializing_if = "Option::is_none")] pub wrong_editable_region_rate: Option, pub isolated_whitespace_rate: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub avg_kept_rate: Option, } pub fn compute_summary(examples: &[Example]) -> SummaryJson { @@ -615,6 +643,8 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { let mut wrong_editable_region_count: usize = 0; let mut wrong_editable_region_total: usize = 0; let mut isolated_whitespace_count: usize = 0; + let mut kept_rate_sum: f64 = 0.0; + let mut kept_rate_count: usize = 0; for example in examples { for (score_idx, score) in example.score.iter().enumerate() { @@ -655,6 +685,12 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { isolated_whitespace_count += 1; } + // Accumulate kept rate metrics + if let Some(kr) = score.kept_rate { + kept_rate_sum += kr; + kept_rate_count += 1; + } + // Accumulate cursor metrics if let Some(exact_match) = score.cursor_exact_match { cursor_total += 1; @@ -729,6 +765,12 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { None }; + let avg_kept_rate = if kept_rate_count > 0 { + Some(kept_rate_sum / kept_rate_count as f64) + } else { + None + }; + SummaryJson { total_examples: total_scores, avg_delta_chr_f, @@ -761,6 +803,7 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { cursor_total_evaluated, wrong_editable_region_rate, isolated_whitespace_rate, + avg_kept_rate, } }