diff --git a/crates/edit_prediction/src/udiff.rs b/crates/edit_prediction/src/udiff.rs index f107db223137ef62d8f1b0e0327c8dd75616a2ae..e914d4c95f349aee07f32e21caf3c04c318af4d2 100644 --- a/crates/edit_prediction/src/udiff.rs +++ b/crates/edit_prediction/src/udiff.rs @@ -323,7 +323,7 @@ pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result { let mut text = text.to_string(); - while let Some(event) = diff.next()? { + while let Some(event) = diff.next().context("Failed to parse diff")? { match event { DiffEvent::Hunk { hunk, @@ -340,7 +340,7 @@ pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result { disambiguate_by_line_number(&candidates, hunk.start_line, |offset| { text[..offset].matches('\n').count() as u32 }) - .ok_or_else(|| anyhow!("couldn't resolve hunk: {}", hunk.context))?; + .ok_or_else(|| anyhow!("couldn't resolve hunk"))?; for edit in hunk.edits.iter().rev() { let range = (hunk_offset + edit.range.start)..(hunk_offset + edit.range.end); diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index fd165a75d8233a1425afd54d2d16b814db9b5e15..381e914969db312df5dafcf1df1ab6e6d7ba0cc8 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -105,6 +105,8 @@ pub struct ExampleScore { pub exact_lines_fp: usize, #[serde(default)] pub exact_lines_fn: usize, + #[serde(default)] + pub reversal_ratio: f32, } impl Example { diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 79a334a5e874eaa783581726afdc699768d360a7..821dc6d86be489ddd2eb086e074f7ce92af8ef6f 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -16,6 +16,7 @@ mod qa; mod reorder_patch; mod repair; mod retrieve_context; +mod reversal_tracking; mod score; mod split_commit; mod split_dataset; diff --git a/crates/edit_prediction_cli/src/reversal_tracking.rs b/crates/edit_prediction_cli/src/reversal_tracking.rs new file mode 100644 index 0000000000000000000000000000000000000000..a23343dea3dec4d18b6b24833c50efe85014e247 --- /dev/null +++ b/crates/edit_prediction_cli/src/reversal_tracking.rs @@ -0,0 +1,491 @@ +use std::ops::Range; +use std::path::Path; +use std::sync::Arc; + +use edit_prediction::udiff::apply_diff_to_string; +use language::text_diff; + +use crate::example::ExamplePromptInputs; + +pub fn reverse_diff(diff: &str) -> String { + let mut result: String = diff + .lines() + .map(|line| { + if line.starts_with("--- ") { + line.replacen("--- ", "+++ ", 1) + } else if line.starts_with("+++ ") { + line.replacen("+++ ", "--- ", 1) + } else if line.starts_with('+') && !line.starts_with("+++") { + format!("-{}", &line[1..]) + } else if line.starts_with('-') && !line.starts_with("---") { + format!("+{}", &line[1..]) + } else { + line.to_string() + } + }) + .collect::>() + .join("\n"); + if diff.ends_with('\n') { + result.push('\n'); + } + result +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct GranularEdit { + pub range: Range, + pub old_text: String, + pub new_text: String, +} + +pub fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec { + text_diff(old_text, new_text) + .into_iter() + .map(|(range, new_text)| GranularEdit { + old_text: old_text[range.clone()].to_string(), + range, + new_text: new_text.to_string(), + }) + .collect() +} + +#[derive(Debug, Clone)] +pub struct HistoryAdditionRange { + pub range_in_current: Range, +} + +#[derive(Debug, Clone)] +pub struct HistoryDeletionRange { + pub deleted_text: String, +} + +pub fn compute_history_addition_ranges( + history_edits: &[GranularEdit], +) -> Vec { + let mut result = Vec::new(); + let mut offset_delta: isize = 0; + + for edit in history_edits { + if !edit.new_text.is_empty() { + let new_start = (edit.range.start as isize + offset_delta) as usize; + let new_end = new_start + edit.new_text.len(); + result.push(HistoryAdditionRange { + range_in_current: new_start..new_end, + }); + } + + offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize; + } + + result +} + +pub fn compute_history_deletion_ranges( + history_edits: &[GranularEdit], +) -> Vec { + history_edits + .iter() + .filter(|edit| !edit.old_text.is_empty()) + .map(|edit| HistoryDeletionRange { + deleted_text: edit.old_text.clone(), + }) + .collect() +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct ReversalOverlap { + pub chars_reversing_user_edits: usize, + pub total_chars_in_prediction: usize, +} + +impl ReversalOverlap { + pub fn ratio(&self) -> f32 { + if self.total_chars_in_prediction == 0 { + 0.0 + } else { + self.chars_reversing_user_edits as f32 / self.total_chars_in_prediction as f32 + } + } +} + +/// Compute how much of a prediction reverses recent user edits. +pub fn compute_reversal_overlap( + original_content: &str, + current_content: &str, + predicted_content: &str, +) -> ReversalOverlap { + let history_edits = compute_granular_edits(original_content, current_content); + let prediction_edits = compute_granular_edits(current_content, predicted_content); + + let history_addition_ranges = compute_history_addition_ranges(&history_edits); + let history_deletion_ranges = compute_history_deletion_ranges(&history_edits); + + let reversed_additions = + compute_reversed_additions(&history_addition_ranges, &prediction_edits); + let restored_deletions = + compute_restored_deletions(&history_deletion_ranges, &prediction_edits); + + let prediction_added_chars: usize = prediction_edits.iter().map(|e| e.new_text.len()).sum(); + let prediction_deleted_chars: usize = prediction_edits.iter().map(|e| e.old_text.len()).sum(); + + ReversalOverlap { + chars_reversing_user_edits: reversed_additions + restored_deletions, + total_chars_in_prediction: prediction_added_chars + prediction_deleted_chars, + } +} + +pub fn compute_reversed_additions( + history_addition_ranges: &[HistoryAdditionRange], + prediction_edits: &[GranularEdit], +) -> usize { + let mut reversed_chars = 0; + + for pred_edit in prediction_edits { + for history_addition in history_addition_ranges { + let overlap_start = pred_edit + .range + .start + .max(history_addition.range_in_current.start); + let overlap_end = pred_edit + .range + .end + .min(history_addition.range_in_current.end); + + if overlap_start < overlap_end { + reversed_chars += overlap_end - overlap_start; + } + } + } + + reversed_chars +} + +pub fn compute_restored_deletions( + history_deletion_ranges: &[HistoryDeletionRange], + prediction_edits: &[GranularEdit], +) -> usize { + let history_deleted_text: String = history_deletion_ranges + .iter() + .map(|r| r.deleted_text.as_str()) + .collect(); + + let prediction_added_text: String = prediction_edits + .iter() + .map(|e| e.new_text.as_str()) + .collect(); + + compute_lcs_length(&history_deleted_text, &prediction_added_text) +} + +fn compute_lcs_length(a: &str, b: &str) -> usize { + let a_chars: Vec = a.chars().collect(); + let b_chars: Vec = b.chars().collect(); + let m = a_chars.len(); + let n = b_chars.len(); + + if m == 0 || n == 0 { + return 0; + } + + let mut prev = vec![0; n + 1]; + let mut curr = vec![0; n + 1]; + + for i in 1..=m { + for j in 1..=n { + if a_chars[i - 1] == b_chars[j - 1] { + curr[j] = prev[j - 1] + 1; + } else { + curr[j] = prev[j].max(curr[j - 1]); + } + } + std::mem::swap(&mut prev, &mut curr); + curr.fill(0); + } + + prev[n] +} + +pub fn filter_edit_history_by_path<'a>( + edit_history: &'a [Arc], + cursor_path: &std::path::Path, +) -> Vec<&'a zeta_prompt::Event> { + edit_history + .iter() + .filter(|event| match event.as_ref() { + zeta_prompt::Event::BufferChange { path, .. } => { + let event_path = path.as_ref(); + if event_path == cursor_path { + return true; + } + let stripped = event_path + .components() + .skip(1) + .collect::(); + stripped == cursor_path + } + }) + .map(|arc| arc.as_ref()) + .collect() +} + +pub fn extract_diff_from_event(event: &zeta_prompt::Event) -> &str { + match event { + zeta_prompt::Event::BufferChange { diff, .. } => diff.as_str(), + } +} + +pub fn compute_prediction_reversal_ratio( + prompt_inputs: &ExamplePromptInputs, + predicted_content: &str, + cursor_path: &Path, +) -> f32 { + let current_content = &prompt_inputs.content; + + let edit_history: &[Arc] = &prompt_inputs.edit_history; + let relevant_events = filter_edit_history_by_path(edit_history, cursor_path); + + let mut original_content = current_content.to_string(); + for event in relevant_events.into_iter().rev() { + let diff = extract_diff_from_event(event); + if diff.is_empty() { + continue; + } + let reversed = reverse_diff(diff); + let with_headers = format!("--- a/file\n+++ b/file\n{}", reversed); + match apply_diff_to_string(&with_headers, &original_content) { + Ok(updated_content) => original_content = updated_content, + Err(err) => { + log::warn!( + "Failed to reconstruct original content for reversal tracking: Failed to apply reversed diff: {:#}", + err + ); + return 0.0; + } + } + } + + let overlap = compute_reversal_overlap(&original_content, current_content, predicted_content); + overlap.ratio() +} + +#[cfg(test)] +mod tests { + use super::*; + use edit_prediction::udiff::apply_diff_to_string; + + #[test] + fn test_reversal_overlap() { + struct Case { + name: &'static str, + original: &'static str, + current: &'static str, + predicted: &'static str, + expected_reversal_chars: usize, + expected_total_chars: usize, + } + + let cases = [ + Case { + name: "user_adds_line_prediction_removes_it", + original: "a\nb\nc", + current: "a\nnew line\nb\nc", + predicted: "a\nb\nc", + expected_reversal_chars: 9, + expected_total_chars: 9, + }, + Case { + name: "user_deletes_line_prediction_restores_it", + original: "a\ndeleted\nb", + current: "a\nb", + predicted: "a\ndeleted\nb", + expected_reversal_chars: 8, + expected_total_chars: 8, + }, + Case { + name: "user_deletes_text_prediction_restores_partial", + original: "hello beautiful world", + current: "hello world", + predicted: "hello beautiful world", + expected_reversal_chars: 10, + expected_total_chars: 10, + }, + Case { + name: "user_deletes_foo_prediction_adds_bar", + original: "foo", + current: "", + predicted: "bar", + expected_reversal_chars: 0, + expected_total_chars: 3, + }, + Case { + name: "independent_edits_different_locations", + original: "line1\nline2\nline3", + current: "LINE1\nline2\nline3", + predicted: "LINE1\nline2\nLINE3", + expected_reversal_chars: 0, + expected_total_chars: 10, + }, + Case { + name: "no_history_edits", + original: "same", + current: "same", + predicted: "different", + expected_reversal_chars: 0, + expected_total_chars: 13, + }, + Case { + name: "user_replaces_text_prediction_reverses", + original: "keep\ndelete_me\nkeep2", + current: "keep\nadded\nkeep2", + predicted: "keep\ndelete_me\nkeep2", + expected_reversal_chars: 14, + expected_total_chars: 14, + }, + Case { + name: "user_modifies_word_prediction_modifies_differently", + original: "the quick brown fox", + current: "the slow brown fox", + predicted: "the fast brown fox", + expected_reversal_chars: 4, + expected_total_chars: 8, + }, + ]; + + for case in &cases { + let overlap = compute_reversal_overlap(case.original, case.current, case.predicted); + assert_eq!( + overlap.chars_reversing_user_edits, case.expected_reversal_chars, + "Test '{}': expected {} reversal chars, got {}", + case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits + ); + assert_eq!( + overlap.total_chars_in_prediction, case.expected_total_chars, + "Test '{}': expected {} total chars, got {}", + case.name, case.expected_total_chars, overlap.total_chars_in_prediction + ); + } + } + + #[test] + fn test_reverse_diff() { + let forward_diff = "\ +--- a/file.rs ++++ b/file.rs +@@ -1,3 +1,4 @@ + fn main() { ++ let x = 42; + println!(\"hello\"); +}"; + + let reversed = reverse_diff(forward_diff); + + assert!( + reversed.contains("+++ a/file.rs"), + "Should have +++ for old path" + ); + assert!( + reversed.contains("--- b/file.rs"), + "Should have --- for new path" + ); + assert!( + reversed.contains("- let x = 42;"), + "Added line should become deletion" + ); + assert!( + reversed.contains(" fn main()"), + "Context lines should be unchanged" + ); + } + + #[test] + fn test_reverse_diff_roundtrip() { + // Applying a diff and then its reverse should get back to original + let original = "first line\nhello world\nlast line\n"; + let modified = "first line\nhello beautiful world\nlast line\n"; + + // unified_diff doesn't include file headers, but apply_diff_to_string needs them + let diff_body = language::unified_diff(original, modified); + let forward_diff = format!("--- a/file\n+++ b/file\n{}", diff_body); + let reversed_diff = reverse_diff(&forward_diff); + + // Apply forward diff to original + let after_forward = apply_diff_to_string(&forward_diff, original).unwrap(); + assert_eq!(after_forward, modified); + + // Apply reversed diff to modified + let after_reverse = apply_diff_to_string(&reversed_diff, &after_forward).unwrap(); + assert_eq!(after_reverse, original); + } + + #[test] + fn test_filter_edit_history_by_path() { + // Test that filter_edit_history_by_path correctly matches paths when + // the edit history has paths with a repo prefix (e.g., "repo/src/file.rs") + // but the cursor_path doesn't have the repo prefix (e.g., "src/file.rs") + let events = vec![ + Arc::new(zeta_prompt::Event::BufferChange { + path: Arc::from(Path::new("myrepo/src/file.rs")), + old_path: Arc::from(Path::new("myrepo/src/file.rs")), + diff: "@@ -1 +1 @@\n-old\n+new".into(), + predicted: false, + in_open_source_repo: true, + }), + Arc::new(zeta_prompt::Event::BufferChange { + path: Arc::from(Path::new("myrepo/other.rs")), + old_path: Arc::from(Path::new("myrepo/other.rs")), + diff: "@@ -1 +1 @@\n-a\n+b".into(), + predicted: false, + in_open_source_repo: true, + }), + Arc::new(zeta_prompt::Event::BufferChange { + path: Arc::from(Path::new("src/file.rs")), + old_path: Arc::from(Path::new("src/file.rs")), + diff: "@@ -1 +1 @@\n-x\n+y".into(), + predicted: false, + in_open_source_repo: true, + }), + ]; + + // "myrepo/src/file.rs" stripped -> "src/file.rs" matches cursor_path + // "src/file.rs" exact match + let cursor_path = Path::new("src/file.rs"); + let filtered = filter_edit_history_by_path(&events, cursor_path); + assert_eq!( + filtered.len(), + 2, + "Should match myrepo/src/file.rs (stripped) and src/file.rs (exact)" + ); + + // "myrepo/src/file.rs" stripped -> "src/file.rs" != "file.rs" + // "src/file.rs" stripped -> "file.rs" == "file.rs" + let cursor_path = Path::new("file.rs"); + let filtered = filter_edit_history_by_path(&events, cursor_path); + assert_eq!( + filtered.len(), + 1, + "Should only match src/file.rs (stripped to file.rs)" + ); + + // "myrepo/other.rs" stripped -> "other.rs" == "other.rs" + let cursor_path = Path::new("other.rs"); + let filtered = filter_edit_history_by_path(&events, cursor_path); + assert_eq!(filtered.len(), 1, "Should match only myrepo/other.rs"); + } + + #[test] + fn test_reverse_diff_preserves_trailing_newline() { + let diff_with_trailing_newline = "--- a/file\n+++ b/file\n@@ -1 +1 @@\n-old\n+new\n"; + let reversed = reverse_diff(diff_with_trailing_newline); + assert!( + reversed.ends_with('\n'), + "Reversed diff should preserve trailing newline" + ); + + let diff_without_trailing_newline = "--- a/file\n+++ b/file\n@@ -1 +1 @@\n-old\n+new"; + let reversed = reverse_diff(diff_without_trailing_newline); + assert!( + !reversed.ends_with('\n'), + "Reversed diff should not add trailing newline if original didn't have one" + ); + } +} diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs index 1b403f50a2590f3e5a5fd2c52bf3c31897745621..010763e507475088cbe686fc7fbfc6a0e1427ad1 100644 --- a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -6,6 +6,7 @@ use crate::{ parse_output::parse_prediction_output, predict::run_prediction, progress::{ExampleProgress, Step}, + reversal_tracking, }; use anyhow::Context as _; use edit_prediction::udiff::apply_diff_to_string; @@ -49,8 +50,12 @@ pub async fn run_scoring( exact_lines_tp: 0, exact_lines_fp: 0, exact_lines_fn: 0, + reversal_ratio: 0.0, }; + let prompt_inputs = example.prompt_inputs.as_ref().unwrap(); + let cursor_path = example.spec.cursor_path.as_ref(); + progress.set_substatus("computing metrics"); let mut scores = vec![]; for prediction in &example.predictions { @@ -98,12 +103,20 @@ pub async fn run_scoring( .max_by_key(|m| m.true_positives) .unwrap_or_default(); + // Compute reversal ratio + let reversal_ratio = reversal_tracking::compute_prediction_reversal_ratio( + prompt_inputs, + &actual_text, + cursor_path, + ); + scores.push(ExampleScore { delta_chr_f: best_delta_chr_f, braces_disbalance, exact_lines_tp: best_exact_lines.true_positives, exact_lines_fp: best_exact_lines.false_positives, exact_lines_fn: best_exact_lines.false_negatives, + reversal_ratio, }); } @@ -114,17 +127,18 @@ pub async fn run_scoring( pub fn print_report(examples: &[Example]) { use crate::metrics::ClassificationMetrics; - const LINE_WIDTH: usize = 100; + const LINE_WIDTH: usize = 110; let separator = "─".repeat(LINE_WIDTH); println!("{}", separator); println!( - "{:<40} {:>8} {:>5} {:>4} {:>4} {:>4} {:>7} {:>7} {:>7}", - "Example", "DeltaChrF", "Brace", "TP", "FP", "FN", "Prec", "Rec", "F1" + "{:<40} {:>8} {:>5} {:>4} {:>4} {:>4} {:>7} {:>7} {:>7} {:>7}", + "Example", "DeltaChrF", "Brace", "TP", "FP", "FN", "Prec", "Rec", "F1", "Revert" ); println!("{}", separator); let mut all_delta_chr_f_scores = Vec::new(); + let mut all_reversal_ratios = Vec::new(); let mut braces_disbalance_sum: usize = 0; let mut total_exact_lines = ClassificationMetrics::default(); let mut total_scores: usize = 0; @@ -138,7 +152,7 @@ pub fn print_report(examples: &[Example]) { }; println!( - "{:<40} {:>8.2} {:>5} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}%", + "{:<40} {:>8.2} {:>5} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}% {:>6.1}%", truncate_name(&example.spec.name, 40), score.delta_chr_f, score.braces_disbalance, @@ -147,10 +161,12 @@ pub fn print_report(examples: &[Example]) { score.exact_lines_fn, exact_lines.precision() * 100.0, exact_lines.recall() * 100.0, - exact_lines.f1() * 100.0 + exact_lines.f1() * 100.0, + score.reversal_ratio * 100.0 ); all_delta_chr_f_scores.push(score.delta_chr_f); + all_reversal_ratios.push(score.reversal_ratio); total_scores += 1; braces_disbalance_sum += score.braces_disbalance; total_exact_lines.true_positives += score.exact_lines_tp; @@ -164,10 +180,12 @@ pub fn print_report(examples: &[Example]) { if !all_delta_chr_f_scores.is_empty() { let avg_delta_chr_f: f32 = all_delta_chr_f_scores.iter().sum::() / all_delta_chr_f_scores.len() as f32; + let avg_reversal_ratio: f32 = + all_reversal_ratios.iter().sum::() / all_reversal_ratios.len() as f32; let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32; println!( - "{:<40} {:>8.2} {:>5.1} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}%", + "{:<40} {:>8.2} {:>5.1} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}% {:>6.1}%", "TOTAL / AVERAGE", avg_delta_chr_f, braces_disbalance_avg, @@ -176,7 +194,8 @@ pub fn print_report(examples: &[Example]) { total_exact_lines.false_negatives, total_exact_lines.precision() * 100.0, total_exact_lines.recall() * 100.0, - total_exact_lines.f1() * 100.0 + total_exact_lines.f1() * 100.0, + avg_reversal_ratio * 100.0 ); println!("{}", separator); } @@ -203,12 +222,14 @@ pub struct SummaryJson { pub exact_lines_precision: f64, pub exact_lines_recall: f64, pub exact_lines_f1: f64, + pub avg_reversal_ratio: f32, } pub fn compute_summary(examples: &[Example]) -> SummaryJson { use crate::metrics::ClassificationMetrics; let mut all_delta_chr_f_scores = Vec::new(); + let mut all_reversal_ratios = Vec::new(); let mut braces_disbalance_sum: usize = 0; let mut total_exact_lines = ClassificationMetrics::default(); let mut total_scores: usize = 0; @@ -216,6 +237,7 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { for example in examples { for score in example.score.iter() { all_delta_chr_f_scores.push(score.delta_chr_f); + all_reversal_ratios.push(score.reversal_ratio); total_scores += 1; braces_disbalance_sum += score.braces_disbalance; total_exact_lines.true_positives += score.exact_lines_tp; @@ -230,6 +252,12 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { all_delta_chr_f_scores.iter().sum::() / all_delta_chr_f_scores.len() as f32 }; + let avg_reversal_ratio = if all_reversal_ratios.is_empty() { + 0.0 + } else { + all_reversal_ratios.iter().sum::() / all_reversal_ratios.len() as f32 + }; + let avg_braces_disbalance = if total_scores == 0 { 0.0 } else { @@ -246,6 +274,7 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { exact_lines_precision: total_exact_lines.precision(), exact_lines_recall: total_exact_lines.recall(), exact_lines_f1: total_exact_lines.f1(), + avg_reversal_ratio, } }