Cargo.lock 🔗
@@ -5173,6 +5173,7 @@ dependencies = [
"client",
"cloud_llm_client",
"collections",
+ "criterion",
"db",
"debug_adapter_extension",
"dirs 4.0.0",
Created by Oleksiy Syvokon and Ben Kunkle.
Release Notes:
- N/A
---------
Co-authored-by: Ben Kunkle <ben@zed.dev>
Cargo.lock | 1
crates/edit_prediction_cli/Cargo.toml | 8
crates/edit_prediction_cli/benches/kept_rate.rs | 128 +++++
crates/edit_prediction_cli/src/example.rs | 2
crates/edit_prediction_cli/src/kept_rate.rs | 427 +++++++++++++++++++
crates/edit_prediction_cli/src/lib.rs | 4
crates/edit_prediction_cli/src/main.rs | 1
crates/edit_prediction_cli/src/metrics.rs | 2
crates/edit_prediction_cli/src/score.rs | 43 +
9 files changed, 616 insertions(+)
@@ -5173,6 +5173,7 @@ dependencies = [
"client",
"cloud_llm_client",
"collections",
+ "criterion",
"db",
"debug_adapter_extension",
"dirs 4.0.0",
@@ -8,6 +8,9 @@ license = "GPL-3.0-or-later"
[lints]
workspace = true
+[lib]
+path = "src/lib.rs"
+
[[bin]]
name = "ep"
path = "src/main.rs"
@@ -80,9 +83,14 @@ dynamic_prompts = []
ignored = ["wasmtime"]
[dev-dependencies]
+criterion.workspace = true
gpui = { workspace = true, features = ["test-support"] }
indoc.workspace = true
pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] }
tempfile.workspace = true
workspace = { workspace = true, features = ["test-support"] }
+
+[[bench]]
+name = "kept_rate"
+harness = false
@@ -0,0 +1,128 @@
+use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
+use edit_prediction_cli::kept_rate::compute_kept_rate;
+
+fn repeated_function_lines(line_count: usize) -> String {
+ let mut text = String::with_capacity(line_count * 32);
+ for index in 0..line_count {
+ text.push_str("fn helper_");
+ text.push_str(&(index % 16).to_string());
+ text.push_str("() { value += old_name + 1; }\n");
+ }
+ text
+}
+
+fn localized_rename_inputs(line_count: usize) -> (String, String, String) {
+ let base = repeated_function_lines(line_count);
+ let mut predicted = base.clone();
+ let mut final_text = base.clone();
+
+ let needle = "value += old_name + 1;";
+ let prediction = "value += very_long_predicted_name + 1;";
+ let accepted = "value += new_name + 1;";
+
+ let offset = base
+ .rfind(needle)
+ .expect("expected needle in synthetic input");
+ let end = offset + needle.len();
+
+ predicted.replace_range(offset..end, prediction);
+ final_text.replace_range(offset..end, accepted);
+
+ (base, predicted, final_text)
+}
+
+fn identical_new_content_inputs(line_count: usize) -> (String, String, String) {
+ let predicted = repeated_function_lines(line_count);
+ (String::new(), predicted.clone(), predicted)
+}
+
/// Build inputs full of repeated `foo` tokens, differing only in one middle
/// token per line, to stress LCS alignment on highly repetitive text.
fn repetitive_token_inputs(token_repetitions: usize) -> (String, String, String) {
    let repeated_line = |middle: &str| {
        format!("foo + foo + {middle} + foo + foo\n").repeat(token_repetitions)
    };
    (
        repeated_line("foo"),
        repeated_line("prediction_token"),
        repeated_line("kept_token"),
    )
}
+
+fn kept_rate_benchmark(c: &mut Criterion) {
+ let mut no_change_group = c.benchmark_group("kept_rate/no_change");
+ for line_count in [128usize, 512, 2048] {
+ let text = repeated_function_lines(line_count);
+ no_change_group.bench_with_input(
+ BenchmarkId::new("lines", line_count),
+ &text,
+ |bench, text| {
+ bench.iter(|| {
+ black_box(compute_kept_rate(
+ black_box(text),
+ black_box(text),
+ black_box(text),
+ ));
+ });
+ },
+ );
+ }
+ no_change_group.finish();
+
+ let mut localized_group = c.benchmark_group("kept_rate/localized_rename");
+ for line_count in [128usize, 512, 2048] {
+ let inputs = localized_rename_inputs(line_count);
+ localized_group.bench_with_input(
+ BenchmarkId::new("lines", line_count),
+ &inputs,
+ |bench, inputs| {
+ let (base, predicted, final_text) = inputs;
+ bench.iter(|| {
+ black_box(compute_kept_rate(
+ black_box(base),
+ black_box(predicted),
+ black_box(final_text),
+ ));
+ });
+ },
+ );
+ }
+ localized_group.finish();
+
+ let mut addition_group = c.benchmark_group("kept_rate/identical_addition");
+ for line_count in [128usize, 512, 2048] {
+ let inputs = identical_new_content_inputs(line_count);
+ addition_group.bench_with_input(
+ BenchmarkId::new("lines", line_count),
+ &inputs,
+ |bench, inputs| {
+ let (base, predicted, final_text) = inputs;
+ bench.iter(|| {
+ black_box(compute_kept_rate(
+ black_box(base),
+ black_box(predicted),
+ black_box(final_text),
+ ));
+ });
+ },
+ );
+ }
+ addition_group.finish();
+
+ let mut repetitive_group = c.benchmark_group("kept_rate/repetitive_tokens");
+ for token_repetitions in [64usize, 256, 1024] {
+ let inputs = repetitive_token_inputs(token_repetitions);
+ repetitive_group.bench_with_input(
+ BenchmarkId::new("repetitions", token_repetitions),
+ &inputs,
+ |bench, inputs| {
+ let (base, predicted, final_text) = inputs;
+ bench.iter(|| {
+ black_box(compute_kept_rate(
+ black_box(base),
+ black_box(predicted),
+ black_box(final_text),
+ ));
+ });
+ },
+ );
+ }
+ repetitive_group.finish();
+}
+
// Register the benchmark entry point with Criterion's custom harness
// (`harness = false` in Cargo.toml routes `cargo bench` here).
criterion_group!(benches, kept_rate_benchmark);
criterion_main!(benches);
@@ -184,6 +184,8 @@ pub struct ExampleScore {
#[serde(default)]
pub deleted_tokens: usize,
#[serde(default, skip_serializing_if = "Option::is_none")]
+ pub kept_rate: Option<f64>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
pub cumulative_logprob: Option<f64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub avg_logprob: Option<f64>,
@@ -0,0 +1,427 @@
+use crate::word_diff::tokenize;
+
/// Per-token classification of a predicted token, used by tests to verify how
/// `compute_kept_rate` labels each token. Only compiled under `cfg(test)`.
#[cfg(test)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenAnnotation {
    /// Token aligned (via LCS) with both the base text and the final text.
    Context,
    /// Newly predicted token that survived into the final text.
    Kept,
    /// Newly predicted token that did not survive into the final text.
    Discarded,
}
+
/// Aggregate result of `compute_kept_rate`: character counts over the
/// predicted text's tokens plus the derived kept rate.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct KeptRateResult {
    /// Chars of predicted tokens that are new (not context shared with both
    /// the base and the final text).
    pub predicted_new_chars: usize,
    /// Chars of final-text tokens that are new relative to the base text.
    pub final_new_chars: usize,
    /// Chars of predicted new tokens that also appear among the final new tokens.
    pub kept_chars: usize,
    /// `predicted_new_chars - kept_chars`.
    pub discarded_chars: usize,
    /// Chars of predicted tokens classified as unchanged context.
    pub context_chars: usize,
    /// `kept_chars / predicted_new_chars` (degenerate cases handled by the
    /// computation: 1.0 when nothing new was predicted or needed, 0.0 when
    /// nothing was predicted but the final text added content).
    pub kept_rate: f64,
    /// Test-only per-token labels over the predicted tokens.
    #[cfg(test)]
    pub token_annotations: Vec<TokenAnnotation>,
}
+
/// Index into the row-major flattened DP table.
fn dp_index(width: usize, row: usize, column: usize) -> usize {
    row * width + column
}

/// Walk an LCS length table from the bottom-right corner, calling `on_match`
/// with the (`a_mid`, `b_mid`) indices of every matched token pair.
///
/// `prefer_row_on_tie` selects which one-sided tie-breaking rule is
/// reproduced: `true` steps toward a smaller row index when both neighbors
/// score equally (the `a`-side mask), `false` steps toward a smaller column
/// index (the `b`-side mask). Factoring this out removes the two previously
/// duplicated backtrace loops.
fn backtrace_lcs(
    dp: &[u32],
    column_count: usize,
    a_mid: &[&str],
    b_mid: &[&str],
    prefer_row_on_tie: bool,
    mut on_match: impl FnMut(usize, usize),
) {
    let mut i = a_mid.len();
    let mut j = b_mid.len();

    while i > 0 && j > 0 {
        if a_mid[i - 1] == b_mid[j - 1] {
            on_match(i - 1, j - 1);
            i -= 1;
            j -= 1;
        } else {
            let up = dp[dp_index(column_count, i - 1, j)];
            let left = dp[dp_index(column_count, i, j - 1)];
            // `up >= left` vs `up > left` is exactly the historical one-sided
            // tie-breaking difference between the two masks.
            let step_up = if prefer_row_on_tie { up >= left } else { up > left };
            if step_up {
                i -= 1;
            } else {
                j -= 1;
            }
        }
    }
}

/// Return masks over `a` and `b` using one-sided LCS tie-breaking for each
/// side while sharing a single DP table construction.
///
/// A `true` entry marks a token that participates in (one tie-breaking flavor
/// of) the longest common subsequence. Identical prefixes and suffixes are
/// matched directly so the quadratic DP only covers the differing middle.
fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
    if a.is_empty() || b.is_empty() {
        return (vec![false; a.len()], vec![false; b.len()]);
    }

    if a == b {
        return (vec![true; a.len()], vec![true; b.len()]);
    }

    let mut keep_a = vec![false; a.len()];
    let mut keep_b = vec![false; b.len()];

    // Trim the shared prefix and suffix before running the DP.
    let prefix_len = a
        .iter()
        .zip(b.iter())
        .take_while(|(left, right)| left == right)
        .count();
    let suffix_len = {
        let max_suffix = (a.len() - prefix_len).min(b.len() - prefix_len);
        let mut suffix_len = 0;

        while suffix_len < max_suffix {
            let a_index = a.len() - 1 - suffix_len;
            let b_index = b.len() - 1 - suffix_len;
            if a[a_index] != b[b_index] {
                break;
            }
            suffix_len += 1;
        }

        suffix_len
    };

    for index in 0..prefix_len {
        keep_a[index] = true;
        keep_b[index] = true;
    }

    for offset in 0..suffix_len {
        keep_a[a.len() - suffix_len + offset] = true;
        keep_b[b.len() - suffix_len + offset] = true;
    }

    let a_mid = &a[prefix_len..a.len() - suffix_len];
    let b_mid = &b[prefix_len..b.len() - suffix_len];

    if a_mid.is_empty() || b_mid.is_empty() {
        return (keep_a, keep_b);
    }

    // Standard LCS length table over the middle sections.
    let row_count = a_mid.len() + 1;
    let column_count = b_mid.len() + 1;
    let mut dp = vec![0u32; row_count * column_count];

    for i in 1..row_count {
        let token_a = a_mid[i - 1];
        for j in 1..column_count {
            let index = dp_index(column_count, i, j);
            if token_a == b_mid[j - 1] {
                dp[index] = dp[dp_index(column_count, i - 1, j - 1)] + 1;
            } else {
                let up = dp[dp_index(column_count, i - 1, j)];
                let left = dp[dp_index(column_count, i, j - 1)];
                dp[index] = up.max(left);
            }
        }
    }

    // One backtrace per side; only the tie-breaking rule differs.
    backtrace_lcs(&dp, column_count, a_mid, b_mid, true, |i, _| {
        keep_a[prefix_len + i] = true;
    });
    backtrace_lcs(&dp, column_count, a_mid, b_mid, false, |_, j| {
        keep_b[prefix_len + j] = true;
    });

    (keep_a, keep_b)
}
+
/// Split `tokens` by `mask`: masked tokens contribute only their character
/// count, unmasked tokens are collected and their characters counted.
/// Returns (unmasked tokens, unmasked chars, masked chars).
fn analyze_masked_tokens<'a>(tokens: &[&'a str], mask: &[bool]) -> (Vec<&'a str>, usize, usize) {
    let mut retained = Vec::with_capacity(tokens.len());
    let mut retained_chars = 0;
    let mut hidden_chars = 0;

    for (&token, &is_masked) in tokens.iter().zip(mask) {
        match is_masked {
            true => hidden_chars += token.len(),
            false => {
                retained.push(token);
                retained_chars += token.len();
            }
        }
    }

    (retained, retained_chars, hidden_chars)
}
+
+pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptRateResult {
+ if base == predicted && predicted == final_text {
+ let predicted_tokens = tokenize(predicted);
+ let context_chars = predicted_tokens.iter().map(|token| token.len()).sum();
+ return KeptRateResult {
+ predicted_new_chars: 0,
+ final_new_chars: 0,
+ kept_chars: 0,
+ discarded_chars: 0,
+ context_chars,
+ kept_rate: 1.0,
+ #[cfg(test)]
+ token_annotations: vec![TokenAnnotation::Context; predicted_tokens.len()],
+ };
+ }
+
+ let base_tokens = tokenize(base);
+ let predicted_tokens = tokenize(predicted);
+ let final_tokens = tokenize(final_text);
+
+ let (pred_base_mask, _) = lcs_keep_masks(&predicted_tokens, &base_tokens);
+ let (pred_final_mask, final_pred_mask) = lcs_keep_masks(&predicted_tokens, &final_tokens);
+ let context_mask: Vec<bool> = pred_base_mask
+ .iter()
+ .zip(pred_final_mask.iter())
+ .map(|(&in_base, &in_final)| in_base && in_final)
+ .collect();
+
+ let (stripped_predicted, predicted_new_chars, context_chars) =
+ analyze_masked_tokens(&predicted_tokens, &context_mask);
+
+ let (final_base_mask, _) = lcs_keep_masks(&final_tokens, &base_tokens);
+ let final_context_mask: Vec<bool> = final_base_mask
+ .iter()
+ .zip(final_pred_mask.iter())
+ .map(|(&in_base, &in_predicted)| in_base && in_predicted)
+ .collect();
+
+ let (stripped_final, final_new_chars, _) =
+ analyze_masked_tokens(&final_tokens, &final_context_mask);
+
+ let keep_mask = lcs_keep_masks(&stripped_predicted, &stripped_final).0;
+
+ let kept_chars: usize = stripped_predicted
+ .iter()
+ .zip(keep_mask.iter())
+ .filter_map(|(&token, &is_kept)| is_kept.then_some(token.len()))
+ .sum();
+
+ let discarded_chars = predicted_new_chars - kept_chars;
+
+ let kept_rate = if predicted_new_chars == 0 {
+ if final_new_chars == 0 { 1.0 } else { 0.0 }
+ } else {
+ kept_chars as f64 / predicted_new_chars as f64
+ };
+
+ #[cfg(test)]
+ let token_annotations = {
+ let mut token_annotations = Vec::with_capacity(predicted_tokens.len());
+ let mut new_index = 0;
+ for (token_index, _token) in predicted_tokens.iter().enumerate() {
+ if context_mask[token_index] {
+ token_annotations.push(TokenAnnotation::Context);
+ } else {
+ let annotation = if keep_mask[new_index] {
+ TokenAnnotation::Kept
+ } else {
+ TokenAnnotation::Discarded
+ };
+ #[cfg(test)]
+ token_annotations.push(annotation);
+ new_index += 1;
+ }
+ }
+ token_annotations
+ };
+
+ KeptRateResult {
+ predicted_new_chars,
+ final_new_chars,
+ kept_chars,
+ discarded_chars,
+ context_chars,
+ kept_rate,
+ #[cfg(test)]
+ token_annotations,
+ }
+}
+
#[cfg(test)]
mod test_kept_rate {
    use super::*;

    #[test]
    fn test_lcs_keep_masks() {
        let (a_mask, b_mask) = lcs_keep_masks(&["a", "b", "c", "d", "e"], &["a", "c", "e"]);
        assert_eq!(a_mask, vec![true, false, true, false, true]);
        assert_eq!(b_mask, vec![true, true, true]);

        let (a_mask, b_mask) = lcs_keep_masks(&[], &["x"]);
        assert!(a_mask.is_empty());
        assert_eq!(b_mask, vec![false]);
    }

    #[test]
    fn test_lcs_keep_masks_matches_historical_one_sided_masks() {
        // The two one-sided masks must be mirror images: the mask over `a`
        // from (a, b) equals the second mask from (b, a), and vice versa.
        // (The previous version compared a call against itself, which was
        // vacuously true.)
        let a = ["x", "a", "x", "b"];
        let b = ["a", "x", "b", "x"];
        let (a_mask, b_mask) = lcs_keep_masks(&a, &b);
        let (swapped_a_mask, swapped_b_mask) = lcs_keep_masks(&b, &a);
        assert_eq!(a_mask, swapped_b_mask);
        assert_eq!(b_mask, swapped_a_mask);
    }

    #[test]
    fn test_rate_extremes() {
        let no_change = compute_kept_rate("foo bar", "foo bar", "foo bar");
        assert!((no_change.kept_rate - 1.0).abs() < 1e-6);
        assert_eq!(no_change.predicted_new_chars, 0);
        assert!(
            no_change
                .token_annotations
                .iter()
                .all(|&annotation| annotation == TokenAnnotation::Context)
        );

        let accepted = compute_kept_rate("old", "new", "new");
        assert!((accepted.kept_rate - 1.0).abs() < 1e-6);

        let discarded = compute_kept_rate("old", "old", "new");
        assert!((discarded.kept_rate - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_pure_addition() {
        let kept = compute_kept_rate("", "brand new line\n", "brand new line\n");
        assert_eq!(kept.kept_chars, kept.predicted_new_chars);
        assert!(
            kept.token_annotations
                .iter()
                .all(|&annotation| annotation == TokenAnnotation::Kept)
        );

        let discarded =
            compute_kept_rate("", "brand new line\n", "something completely different\n");
        assert!(discarded.kept_chars < discarded.predicted_new_chars);
    }

    #[test]
    fn test_decoy_when_base_excluded() {
        let base = " decoy.when(mock_sync_hardware_api.sp()).then_return(SpeedStatus.IDLE)\n";
        let predicted = " decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n";
        let final_text = " decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n";
        let result = compute_kept_rate(base, predicted, final_text);
        let expected_new = "mock_sync_module_hardware".len() + "speed_status".len();
        assert_eq!(result.predicted_new_chars, expected_new);
        assert!((result.kept_rate - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_missing_deletion() {
        let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
        let predicted = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\neprintln!(\"\");\n";
        let final_text = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
        let result = compute_kept_rate(base, predicted, final_text);
        assert!(
            result.kept_rate < 0.85,
            "expected kept_rate < 0.85, got {}",
            result.kept_rate
        );
        assert!(result.discarded_chars > 0);
    }

    #[test]
    fn test_empty_prediction() {
        let result = compute_kept_rate("old line\n", "", "new line\n");
        assert!((result.kept_rate - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_partial_kept() {
        let result = compute_kept_rate("old\n", "alpha\nbeta\ngamma\n", "alpha\ngamma\n");
        assert!(result.kept_chars > 0);
        assert!(result.discarded_chars > 0);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
    }

    #[test]
    fn test_eprintln_token_alignment() {
        let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
        let predicted = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"hello world!\");\n";
        let final_text = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
        let result = compute_kept_rate(base, predicted, final_text);
        assert!(result.discarded_chars > 0);
        assert!(result.kept_chars > 0);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
        assert_eq!(result.kept_chars, 14);
        assert_eq!(result.discarded_chars, 12);
    }

    #[test]
    fn test_annotations_rename() {
        let base = " foo(old_name)\n";
        let predicted = " foo(new_name)\n";
        let final_text = " foo(new_name)\n";
        let result = compute_kept_rate(base, predicted, final_text);

        assert_eq!(result.predicted_new_chars, "new_name".len());
        assert_eq!(result.token_annotations.len(), tokenize(predicted).len());

        for (&token, &annotation) in tokenize(predicted).iter().zip(&result.token_annotations) {
            if token == "new_name" {
                assert_eq!(annotation, TokenAnnotation::Kept);
            } else {
                assert_eq!(annotation, TokenAnnotation::Context);
            }
        }
    }

    #[test]
    fn test_annotations_eprintln_coloring() {
        let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
        let predicted = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"hello world!\");\n";
        let final_text = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
        let result = compute_kept_rate(base, predicted, final_text);
        let predicted_tokens = tokenize(predicted);

        let eprintln_index = predicted_tokens
            .iter()
            .position(|&token| token == "eprintln")
            .expect("eprintln token not found");

        for annotation in &result.token_annotations[..eprintln_index] {
            assert_eq!(*annotation, TokenAnnotation::Context);
        }

        assert_eq!(
            &result.token_annotations[eprintln_index..=eprintln_index + 10],
            &[
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
            ]
        );
        assert_eq!(
            result.token_annotations.last(),
            Some(&TokenAnnotation::Context)
        );
    }

    #[test]
    fn test_repetitive_tokens_remain_discarded() {
        let base = "foo + foo + foo + foo + foo\n".repeat(16);
        let predicted = "foo + foo + prediction_token + foo + foo\n".repeat(16);
        let final_text = "foo + foo + kept_token + foo + foo\n".repeat(16);
        let result = compute_kept_rate(&base, &predicted, &final_text);

        assert_eq!(result.kept_chars, 0);
        assert_eq!(result.discarded_chars, result.predicted_new_chars);
        assert_eq!(result.predicted_new_chars, "prediction_token".len() * 16);
    }
}
@@ -0,0 +1,4 @@
// Provides `tokenize` used by `kept_rate`; dead-code lint silenced because
// presumably not all of the module's helpers are used from the library crate —
// TODO confirm once more callers exist.
#[allow(dead_code)]
mod word_diff;

pub mod kept_rate;
@@ -5,6 +5,7 @@ mod filter_languages;
mod format_prompt;
mod git;
mod headless;
+mod kept_rate;
mod load_project;
mod metrics;
mod openai_client;
@@ -1297,3 +1297,5 @@ index abc123..def456 100644
);
}
}
+
+pub use crate::kept_rate::compute_kept_rate;
@@ -84,6 +84,7 @@ pub async fn run_scoring(
has_isolated_whitespace_changes: false,
inserted_tokens: 0,
deleted_tokens: 0,
+ kept_rate: None,
cumulative_logprob: None,
avg_logprob: None,
};
@@ -120,12 +121,14 @@ pub async fn run_scoring(
let mut best_delta_chr_f_metrics = metrics::DeltaChrFMetrics::default();
let mut best_expected_cursor: Option<usize> = None;
let mut best_patch_idx: Option<usize> = None;
+ let mut best_expected_text: Option<&str> = None;
for (idx, expected) in expected_texts.iter().enumerate() {
let delta_chr_f_metrics = metrics::delta_chr_f(original_text, expected, &actual_text);
if delta_chr_f_metrics.score > best_delta_chr_f_metrics.score {
best_delta_chr_f_metrics = delta_chr_f_metrics;
best_patch_idx = Some(idx);
+ best_expected_text = Some(expected);
}
}
@@ -184,6 +187,10 @@ pub async fn run_scoring(
prediction.actual_cursor.as_ref(),
);
+ let kept_rate = best_expected_text.map(|final_text| {
+ metrics::compute_kept_rate(original_text, &actual_text, final_text).kept_rate
+ });
+
scores.push(ExampleScore {
delta_chr_f: best_delta_chr_f_metrics.score as f32,
delta_chr_f_true_positives: best_delta_chr_f_metrics.counts.true_positives,
@@ -203,6 +210,7 @@ pub async fn run_scoring(
has_isolated_whitespace_changes,
inserted_tokens: token_changes.inserted_tokens,
deleted_tokens: token_changes.deleted_tokens,
+ kept_rate,
cumulative_logprob: prediction.cumulative_logprob,
avg_logprob: prediction.avg_logprob,
});
@@ -267,6 +275,8 @@ pub fn print_report(examples: &[Example], verbose: bool) {
let mut wrong_editable_region_count: usize = 0;
let mut wrong_editable_region_total: usize = 0;
let mut isolated_whitespace_count: usize = 0;
+ let mut kept_rate_sum: f64 = 0.0;
+ let mut kept_rate_count: usize = 0;
let mut patch_inserted_tokens: Vec<usize> = Vec::new();
let mut patch_deleted_tokens: Vec<usize> = Vec::new();
let mut predictions_with_patch: usize = 0;
@@ -359,6 +369,12 @@ pub fn print_report(examples: &[Example], verbose: bool) {
isolated_whitespace_count += 1;
}
+ // Accumulate kept rate metrics
+ if let Some(kr) = score.kept_rate {
+ kept_rate_sum += kr;
+ kept_rate_count += 1;
+ }
+
// Accumulate token change metrics (only for predictions that produced a patch)
let has_patch = example
.predictions
@@ -488,6 +504,16 @@ pub fn print_report(examples: &[Example], verbose: bool) {
println!("Isolated whitespace changes: {}", isolated_ws_str);
}
+ // Print kept rate metrics
+ if kept_rate_count > 0 {
+ let avg_kept_rate = kept_rate_sum / kept_rate_count as f64;
+ println!(
+ "Kept rate: {:.1}% avg ({} evaluated)",
+ avg_kept_rate * 100.0,
+ kept_rate_count
+ );
+ }
+
// Print token change percentile summary (only for predictions with a patch)
if !patch_inserted_tokens.is_empty() {
patch_inserted_tokens.sort_unstable();
@@ -590,6 +616,8 @@ pub struct SummaryJson {
#[serde(skip_serializing_if = "Option::is_none")]
pub wrong_editable_region_rate: Option<f32>,
pub isolated_whitespace_rate: Option<f32>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub avg_kept_rate: Option<f64>,
}
pub fn compute_summary(examples: &[Example]) -> SummaryJson {
@@ -615,6 +643,8 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson {
let mut wrong_editable_region_count: usize = 0;
let mut wrong_editable_region_total: usize = 0;
let mut isolated_whitespace_count: usize = 0;
+ let mut kept_rate_sum: f64 = 0.0;
+ let mut kept_rate_count: usize = 0;
for example in examples {
for (score_idx, score) in example.score.iter().enumerate() {
@@ -655,6 +685,12 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson {
isolated_whitespace_count += 1;
}
+ // Accumulate kept rate metrics
+ if let Some(kr) = score.kept_rate {
+ kept_rate_sum += kr;
+ kept_rate_count += 1;
+ }
+
// Accumulate cursor metrics
if let Some(exact_match) = score.cursor_exact_match {
cursor_total += 1;
@@ -729,6 +765,12 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson {
None
};
+ let avg_kept_rate = if kept_rate_count > 0 {
+ Some(kept_rate_sum / kept_rate_count as f64)
+ } else {
+ None
+ };
+
SummaryJson {
total_examples: total_scores,
avg_delta_chr_f,
@@ -761,6 +803,7 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson {
cursor_total_evaluated,
wrong_editable_region_rate,
isolated_whitespace_rate,
+ avg_kept_rate,
}
}