From cfd6144af817cc71242d7f11c5e611d8bac7ae0c Mon Sep 17 00:00:00 2001 From: Ben Kunkle Date: Fri, 30 Jan 2026 15:01:16 -0600 Subject: [PATCH] ep_cli: Only check cursor excerpt for reversals (#48044) Closes #ISSUE Release Notes: - N/A *or* Added/Fixed/Improved ... --- crates/edit_prediction/src/capture_example.rs | 1 + .../src/edit_prediction_tests.rs | 1 + crates/edit_prediction/src/example_spec.rs | 2 + crates/edit_prediction/src/mercury.rs | 4 +- crates/edit_prediction/src/prediction.rs | 1 + crates/edit_prediction/src/sweep_ai.rs | 1 + crates/edit_prediction/src/zeta1.rs | 2 + crates/edit_prediction/src/zeta2.rs | 2 + crates/edit_prediction_cli/src/example.rs | 2 + .../edit_prediction_cli/src/format_prompt.rs | 1 + .../edit_prediction_cli/src/load_project.rs | 1 + crates/edit_prediction_cli/src/main.rs | 2 + .../edit_prediction_cli/src/pull_examples.rs | 1 + .../src/retrieve_context.rs | 1 + .../src/reversal_tracking.rs | 1437 ++++++++++++++++- crates/zeta_prompt/src/zeta_prompt.rs | 3 + 16 files changed, 1413 insertions(+), 49 deletions(-) diff --git a/crates/edit_prediction/src/capture_example.rs b/crates/edit_prediction/src/capture_example.rs index 177df98f135cc97a70de6ba73ca4c711bf6e40b8..a319c95ec6a501bbca6213a44aa205aba4156f73 100644 --- a/crates/edit_prediction/src/capture_example.rs +++ b/crates/edit_prediction/src/capture_example.rs @@ -156,6 +156,7 @@ pub fn capture_example( cursor_offset: full_cursor_offset, cursor_row: cursor_point.row, cursor_column: cursor_point.column, + excerpt_start_row: Some(0), events: captured_events, related_files: captured_related_files, } diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index babb9f857edd6358c8f0c6df1477ff7702bb24a6..14e6efa903d0cf074477564ebd6d1c6cfd8bdf30 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -1444,6 +1444,7 @@ async fn 
test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { cursor_excerpt: "".into(), editable_range_in_excerpt: 0..0, cursor_offset_in_excerpt: 0, + excerpt_start_row: None, }, buffer_snapshotted_at: Instant::now(), response_received_at: Instant::now(), diff --git a/crates/edit_prediction/src/example_spec.rs b/crates/edit_prediction/src/example_spec.rs index 530a6216216cdf05773c0b47bca07a0ef1e320af..d3ba165e51048c14d6e6d9bcc857557bf352c81a 100644 --- a/crates/edit_prediction/src/example_spec.rs +++ b/crates/edit_prediction/src/example_spec.rs @@ -61,6 +61,8 @@ pub struct CapturedPromptInput { pub cursor_offset: usize, pub cursor_row: u32, pub cursor_column: u32, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub excerpt_start_row: Option, pub events: Vec, pub related_files: Vec, } diff --git a/crates/edit_prediction/src/mercury.rs b/crates/edit_prediction/src/mercury.rs index 09d301c5fa2e7c0edf964c626d3d40d9764c33cc..3905939f05358fba18379ea17f8407d26f3b3a2a 100644 --- a/crates/edit_prediction/src/mercury.rs +++ b/crates/edit_prediction/src/mercury.rs @@ -75,6 +75,7 @@ impl Mercury { ); let context_offset_range = context_range.to_offset(&snapshot); + let context_start_row = context_range.start.row; let editable_offset_range = editable_range.to_offset(&snapshot); @@ -82,7 +83,7 @@ impl Mercury { events, related_files, cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - - context_range.start.to_offset(&snapshot), + - context_offset_range.start, cursor_path: full_path.clone(), cursor_excerpt: snapshot .text_for_range(context_range) @@ -91,6 +92,7 @@ impl Mercury { editable_range_in_excerpt: (editable_offset_range.start - context_offset_range.start) ..(editable_offset_range.end - context_offset_range.start), + excerpt_start_row: Some(context_start_row), }; let prompt = build_prompt(&inputs); diff --git a/crates/edit_prediction/src/prediction.rs b/crates/edit_prediction/src/prediction.rs index 
48b7a85756d75f1d10345dee2fb6b86f6d3c0b53..9035d326c2295671ee382e3b905508b577929b9c 100644 --- a/crates/edit_prediction/src/prediction.rs +++ b/crates/edit_prediction/src/prediction.rs @@ -153,6 +153,7 @@ mod tests { cursor_offset_in_excerpt: 0, cursor_excerpt: "".into(), editable_range_in_excerpt: 0..0, + excerpt_start_row: None, }, buffer_snapshotted_at: Instant::now(), response_received_at: Instant::now(), diff --git a/crates/edit_prediction/src/sweep_ai.rs b/crates/edit_prediction/src/sweep_ai.rs index cb1f89c18778cf90b6099dd4cccb4119fc1ddfd8..15dccee48bea85da2d8e72c0860f4e9f5f4de77f 100644 --- a/crates/edit_prediction/src/sweep_ai.rs +++ b/crates/edit_prediction/src/sweep_ai.rs @@ -214,6 +214,7 @@ impl SweepAi { // we actually don't know editable_range_in_excerpt: 0..inputs.snapshot.len(), cursor_offset_in_excerpt: request_body.cursor_position, + excerpt_start_row: Some(0), }; send_started_event( diff --git a/crates/edit_prediction/src/zeta1.rs b/crates/edit_prediction/src/zeta1.rs index 6ce1a94228fc24cafcf95321017c9bb30b045ae1..c82a949d669f03f4e811c6ef8a3479a8bc85c4c3 100644 --- a/crates/edit_prediction/src/zeta1.rs +++ b/crates/edit_prediction/src/zeta1.rs @@ -129,6 +129,7 @@ pub(crate) fn request_prediction_with_zeta1( .await; let context_start_offset = context_range.start.to_offset(&snapshot); + let context_start_row = context_range.start.row; let editable_offset_range = editable_range.to_offset(&snapshot); let inputs = ZetaPromptInput { @@ -142,6 +143,7 @@ pub(crate) fn request_prediction_with_zeta1( editable_range_in_excerpt: (editable_range.start - context_start_offset) ..(editable_offset_range.end - context_start_offset), cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset, + excerpt_start_row: Some(context_start_row), }; if let Some(debug_tx) = &debug_tx { diff --git a/crates/edit_prediction/src/zeta2.rs b/crates/edit_prediction/src/zeta2.rs index 
3cdd0ff8a5751bd268e5f01053efc86c6a8fab9a..0e776f1ecc5f862a5676505900de2b3eef5029fc 100644 --- a/crates/edit_prediction/src/zeta2.rs +++ b/crates/edit_prediction/src/zeta2.rs @@ -249,6 +249,7 @@ pub fn zeta2_prompt_input( ); let context_start_offset = context_range.start.to_offset(snapshot); + let context_start_row = context_range.start.row; let editable_offset_range = editable_range.to_offset(snapshot); let cursor_offset_in_excerpt = cursor_offset - context_start_offset; let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset) @@ -262,6 +263,7 @@ pub fn zeta2_prompt_input( .into(), editable_range_in_excerpt, cursor_offset_in_excerpt, + excerpt_start_row: Some(context_start_row), events, related_files, }; diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index 871175508e105650c78377c8584554b275d54ff6..0c8905b6291b15eaf2dcf3b80573fcc8966e79a5 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -65,6 +65,8 @@ pub struct ExamplePromptInputs { pub cursor_row: u32, pub cursor_column: u32, pub cursor_offset: usize, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub excerpt_start_row: Option, pub edit_history: Vec>, pub related_files: Option>, } diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index 5588200f745b90d4f92b0a87f45571753f3b0d6f..ed248070ccc960513324ebdec1dc68c7cc2042be 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -97,6 +97,7 @@ pub async fn run_format_prompt( cursor_excerpt: prompt_inputs.content[context_range].to_string().into(), editable_range_in_excerpt, cursor_offset_in_excerpt, + excerpt_start_row: prompt_inputs.excerpt_start_row, events: prompt_inputs.edit_history.clone(), related_files: prompt_inputs.related_files.clone().unwrap_or_default(), }; diff --git 
a/crates/edit_prediction_cli/src/load_project.rs b/crates/edit_prediction_cli/src/load_project.rs index 2635c11f61172960bd6a5ce231c280a64f21c35f..1458a5e8cc46b8c0424812ad817a5d5257f5759d 100644 --- a/crates/edit_prediction_cli/src/load_project.rs +++ b/crates/edit_prediction_cli/src/load_project.rs @@ -78,6 +78,7 @@ pub async fn run_load_project( cursor_row: cursor_point.row, cursor_column: cursor_point.column, cursor_offset: cursor_position.to_offset(&buffer), + excerpt_start_row: Some(0), edit_history, related_files: example .prompt_inputs diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 96662afa13a86ffb1b99d37aa50232eeafed9928..3124716791b7fa0468be44be2a78f710ecddf554 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -62,6 +62,8 @@ struct EpArgs { printenv: bool, #[clap(long, default_value_t = 10, global = true)] max_parallelism: usize, + /// The limit for the number of examples to process + /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake #[clap(long, global = true)] limit: Option, #[clap(long, global = true)] diff --git a/crates/edit_prediction_cli/src/pull_examples.rs b/crates/edit_prediction_cli/src/pull_examples.rs index 953f0edacebf8ccf20450b7b4de2c5ba75baf2bb..198848f1d6c088b1103b516f71249ac9f81516d7 100644 --- a/crates/edit_prediction_cli/src/pull_examples.rs +++ b/crates/edit_prediction_cli/src/pull_examples.rs @@ -1269,6 +1269,7 @@ fn build_example_from_snowflake( cursor_offset, cursor_row, cursor_column, + excerpt_start_row: None, events, related_files, }), diff --git a/crates/edit_prediction_cli/src/retrieve_context.rs b/crates/edit_prediction_cli/src/retrieve_context.rs index 76ebb96509b851174a6bf29dc654e008e4532192..a000f69c768b3ac370e4f6a202e8f1250a28d6da 100644 --- a/crates/edit_prediction_cli/src/retrieve_context.rs +++ b/crates/edit_prediction_cli/src/retrieve_context.rs @@ -49,6 +49,7 @@ pub async fn 
run_context_retrieval( cursor_row: captured.cursor_row, cursor_column: captured.cursor_column, cursor_offset: captured.cursor_offset, + excerpt_start_row: captured.excerpt_start_row, edit_history, related_files: Some(related_files), }); diff --git a/crates/edit_prediction_cli/src/reversal_tracking.rs b/crates/edit_prediction_cli/src/reversal_tracking.rs index 05d2e3b9d386ac20d2ad99b42707a76dd2fa845d..139730bbe3a8f5788e7ef0aa8aaf32344802c685 100644 --- a/crates/edit_prediction_cli/src/reversal_tracking.rs +++ b/crates/edit_prediction_cli/src/reversal_tracking.rs @@ -7,7 +7,310 @@ use language::text_diff; use crate::example::ExamplePromptInputs; -pub fn reverse_diff(diff: &str) -> String { +fn apply_diff_to_string_lenient(diff_str: &str, text: &str) -> String { + let hunks = parse_diff_hunks(diff_str); + let mut result = text.to_string(); + + for hunk in hunks { + let hunk_diff = format!("--- a/file\n+++ b/file\n{}", format_hunk(&hunk)); + if let Ok(updated) = apply_diff_to_string(&hunk_diff, &result) { + result = updated; + } + } + + result +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct ParsedHunk { + old_start: u32, + old_count: u32, + new_start: u32, + new_count: u32, + lines: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum HunkLine { + Context(String), + Addition(String), + Deletion(String), +} + +fn parse_hunk_header(line: &str) -> Option<(u32, u32, u32, u32)> { + let line = line.strip_prefix("@@ -")?; + let (old_part, rest) = line.split_once(' ')?; + let rest = rest.strip_prefix('+')?; + let (new_part, _) = rest.split_once(" @@")?; + + let (old_start, old_count) = if let Some((start, count)) = old_part.split_once(',') { + (start.parse().ok()?, count.parse().ok()?) + } else { + (old_part.parse().ok()?, 1) + }; + + let (new_start, new_count) = if let Some((start, count)) = new_part.split_once(',') { + (start.parse().ok()?, count.parse().ok()?) 
+ } else { + (new_part.parse().ok()?, 1) + }; + + Some((old_start, old_count, new_start, new_count)) +} + +fn parse_diff_hunks(diff: &str) -> Vec { + let mut hunks = Vec::new(); + let mut current_hunk: Option = None; + + for line in diff.lines() { + if let Some((old_start, old_count, new_start, new_count)) = parse_hunk_header(line) { + if let Some(hunk) = current_hunk.take() { + hunks.push(hunk); + } + current_hunk = Some(ParsedHunk { + old_start, + old_count, + new_start, + new_count, + lines: Vec::new(), + }); + } else if let Some(ref mut hunk) = current_hunk { + if let Some(stripped) = line.strip_prefix('+') { + hunk.lines.push(HunkLine::Addition(stripped.to_string())); + } else if let Some(stripped) = line.strip_prefix('-') { + hunk.lines.push(HunkLine::Deletion(stripped.to_string())); + } else if let Some(stripped) = line.strip_prefix(' ') { + hunk.lines.push(HunkLine::Context(stripped.to_string())); + } else if line.is_empty() { + hunk.lines.push(HunkLine::Context(String::new())); + } + } + } + + if let Some(hunk) = current_hunk { + hunks.push(hunk); + } + + hunks +} + +fn format_hunk(hunk: &ParsedHunk) -> String { + let mut result = format!( + "@@ -{},{} +{},{} @@\n", + hunk.old_start, hunk.old_count, hunk.new_start, hunk.new_count + ); + for line in &hunk.lines { + match line { + HunkLine::Context(text) => { + result.push(' '); + result.push_str(text); + result.push('\n'); + } + HunkLine::Addition(text) => { + result.push('+'); + result.push_str(text); + result.push('\n'); + } + HunkLine::Deletion(text) => { + result.push('-'); + result.push_str(text); + result.push('\n'); + } + } + } + result +} + +fn filter_diff_hunks_by_excerpt( + diff: &str, + excerpt_start_row: u32, + excerpt_row_count: u32, +) -> (String, i32) { + let hunks = parse_diff_hunks(diff); + let excerpt_start_0based = excerpt_start_row; + let excerpt_end_0based = excerpt_start_row + excerpt_row_count; + + let mut filtered_hunks = Vec::new(); + let mut cumulative_line_offset: i32 = 0; + + for 
hunk in hunks { + let hunk_start_0based = hunk.new_start.saturating_sub(1); + let hunk_end_0based = hunk_start_0based + hunk.new_count; + + let additions: i32 = hunk + .lines + .iter() + .filter(|l| matches!(l, HunkLine::Addition(_))) + .count() as i32; + let deletions: i32 = hunk + .lines + .iter() + .filter(|l| matches!(l, HunkLine::Deletion(_))) + .count() as i32; + let hunk_line_delta = additions - deletions; + + if hunk_end_0based <= excerpt_start_0based { + cumulative_line_offset += hunk_line_delta; + continue; + } + + if hunk_start_0based >= excerpt_end_0based { + continue; + } + + let mut filtered_lines = Vec::new(); + let mut current_row_0based = hunk_start_0based; + let mut filtered_old_count = 0u32; + let mut filtered_new_count = 0u32; + let mut first_included_row: Option = None; + + for line in &hunk.lines { + match line { + HunkLine::Context(text) => { + if current_row_0based >= excerpt_start_0based + && current_row_0based < excerpt_end_0based + { + if first_included_row.is_none() { + first_included_row = Some(current_row_0based); + } + filtered_lines.push(HunkLine::Context(text.clone())); + filtered_old_count += 1; + filtered_new_count += 1; + } + current_row_0based += 1; + } + HunkLine::Addition(text) => { + if current_row_0based >= excerpt_start_0based + && current_row_0based < excerpt_end_0based + { + if first_included_row.is_none() { + first_included_row = Some(current_row_0based); + } + filtered_lines.push(HunkLine::Addition(text.clone())); + filtered_new_count += 1; + } + current_row_0based += 1; + } + HunkLine::Deletion(text) => { + if current_row_0based >= excerpt_start_0based + && current_row_0based < excerpt_end_0based + { + if first_included_row.is_none() { + first_included_row = Some(current_row_0based); + } + filtered_lines.push(HunkLine::Deletion(text.clone())); + filtered_old_count += 1; + } + } + } + } + + if !filtered_lines.is_empty() { + let first_row = first_included_row.unwrap_or(excerpt_start_0based); + let new_start_1based = 
(first_row - excerpt_start_0based) + 1; + + filtered_hunks.push(ParsedHunk { + old_start: new_start_1based, + old_count: filtered_old_count, + new_start: new_start_1based, + new_count: filtered_new_count, + lines: filtered_lines, + }); + } + + cumulative_line_offset += hunk_line_delta; + } + + let mut result = String::new(); + for hunk in &filtered_hunks { + result.push_str(&format_hunk(hunk)); + } + + (result, cumulative_line_offset) +} + +fn compute_excerpt_aware_reversal_overlap( + edit_history_diffs: &[&str], + excerpt_content: &str, + excerpt_start_row: u32, + predicted_content: &str, +) -> ReversalOverlap { + let mut current_content = excerpt_content.to_string(); + let mut current_excerpt_start_row = excerpt_start_row; + + for diff in edit_history_diffs.iter().rev() { + if diff.is_empty() { + continue; + } + + let current_row_count = current_content.lines().count() as u32; + let (filtered_diff, _line_offset) = + filter_diff_hunks_by_excerpt(diff, current_excerpt_start_row, current_row_count.max(1)); + + if filtered_diff.is_empty() { + let hunks = parse_diff_hunks(diff); + for hunk in hunks { + let hunk_end = hunk.new_start.saturating_sub(1) + hunk.new_count; + if hunk_end <= current_excerpt_start_row { + let additions: u32 = hunk + .lines + .iter() + .filter(|l| matches!(l, HunkLine::Addition(_))) + .count() as u32; + let deletions: u32 = hunk + .lines + .iter() + .filter(|l| matches!(l, HunkLine::Deletion(_))) + .count() as u32; + if additions >= deletions { + current_excerpt_start_row = + current_excerpt_start_row.saturating_sub(additions - deletions); + } else { + current_excerpt_start_row += deletions - additions; + } + } + } + continue; + } + + let reversed = reverse_diff(&format!("--- a/file\n+++ b/file\n{}", filtered_diff)); + match apply_diff_to_string(&reversed, ¤t_content) { + Ok(updated) => { + current_content = updated; + } + Err(_) => { + continue; + } + } + + let hunks = parse_diff_hunks(diff); + for hunk in hunks { + let hunk_end = 
hunk.new_start.saturating_sub(1) + hunk.new_count; + if hunk_end <= current_excerpt_start_row { + let additions: u32 = hunk + .lines + .iter() + .filter(|l| matches!(l, HunkLine::Addition(_))) + .count() as u32; + let deletions: u32 = hunk + .lines + .iter() + .filter(|l| matches!(l, HunkLine::Deletion(_))) + .count() as u32; + if additions >= deletions { + current_excerpt_start_row = + current_excerpt_start_row.saturating_sub(additions - deletions); + } else { + current_excerpt_start_row += deletions - additions; + } + } + } + } + + compute_reversal_overlap(¤t_content, excerpt_content, predicted_content) +} + +fn reverse_diff(diff: &str) -> String { let mut result: String = diff .lines() .map(|line| { @@ -145,17 +448,20 @@ fn normalize_extension_edits(edits: Vec) -> Vec { } if is_subsequence(&edit.old_text, &edit.new_text) { - let inserted_len = edit.new_text.len() - edit.old_text.len(); + let inserted_char_count = + edit.new_text.chars().count() - edit.old_text.chars().count(); GranularEdit { range: edit.range.start..edit.range.start, old_text: String::new(), - new_text: edit.new_text.chars().take(inserted_len).collect(), + new_text: edit.new_text.chars().take(inserted_char_count).collect(), } } else if is_subsequence(&edit.new_text, &edit.old_text) { - let deleted_len = edit.old_text.len() - edit.new_text.len(); + let deleted_char_count = + edit.old_text.chars().count() - edit.new_text.chars().count(); + let deleted_text: String = edit.old_text.chars().take(deleted_char_count).collect(); GranularEdit { - range: edit.range.start..edit.range.start + deleted_len, - old_text: edit.old_text.chars().take(deleted_len).collect(), + range: edit.range.start..edit.range.start + deleted_text.len(), + old_text: deleted_text, new_text: String::new(), } } else { @@ -185,7 +491,7 @@ fn compute_reversal_overlap( let total_chars_in_prediction: usize = prediction_edits .iter() - .map(|e| e.new_text.len() + e.old_text.len()) + .map(|e| e.new_text.chars().count() + 
e.old_text.chars().count()) .sum(); ReversalOverlap { @@ -212,7 +518,10 @@ fn compute_reversed_additions( .min(history_addition.range_in_current.end); if overlap_start < overlap_end { - reversed_chars += overlap_end - overlap_start; + let relative_start = overlap_start - pred_edit.range.start; + let relative_end = overlap_end - pred_edit.range.start; + let overlap_text = &pred_edit.old_text[relative_start..relative_end]; + reversed_chars += overlap_text.chars().count(); } } } @@ -271,7 +580,7 @@ fn compute_lcs_length(a: &str, b: &str) -> usize { prev[n] } -pub fn filter_edit_history_by_path<'a>( +fn filter_edit_history_by_path<'a>( edit_history: &'a [Arc], cursor_path: &std::path::Path, ) -> Vec<&'a zeta_prompt::Event> { @@ -294,7 +603,7 @@ pub fn filter_edit_history_by_path<'a>( .collect() } -pub fn extract_diff_from_event(event: &zeta_prompt::Event) -> &str { +fn extract_diff_from_event(event: &zeta_prompt::Event) -> &str { match event { zeta_prompt::Event::BufferChange { diff, .. } => diff.as_str(), } @@ -310,6 +619,20 @@ pub fn compute_prediction_reversal_ratio( let edit_history: &[Arc] = &prompt_inputs.edit_history; let relevant_events = filter_edit_history_by_path(edit_history, cursor_path); + if let Some(excerpt_start_row) = prompt_inputs.excerpt_start_row { + let diffs: Vec<&str> = relevant_events + .iter() + .map(|e| extract_diff_from_event(e)) + .collect(); + let overlap = compute_excerpt_aware_reversal_overlap( + &diffs, + current_content, + excerpt_start_row, + predicted_content, + ); + return overlap.ratio(); + } + let mut original_content = current_content.to_string(); for event in relevant_events.into_iter().rev() { let diff = extract_diff_from_event(event); @@ -320,12 +643,8 @@ pub fn compute_prediction_reversal_ratio( let with_headers = format!("--- a/file\n+++ b/file\n{}", reversed); match apply_diff_to_string(&with_headers, &original_content) { Ok(updated_content) => original_content = updated_content, - Err(err) => { - log::warn!( - "Failed to 
reconstruct original content for reversal tracking: Failed to apply reversed diff: {:#}", - err - ); - return 0.0; + Err(_) => { + original_content = apply_diff_to_string_lenient(&reversed, &original_content); } } } @@ -338,6 +657,7 @@ pub fn compute_prediction_reversal_ratio( mod tests { use super::*; use edit_prediction::udiff::apply_diff_to_string; + use indoc::indoc; #[test] fn test_reversal_overlap() { @@ -353,17 +673,35 @@ mod tests { let cases = [ Case { name: "user_adds_line_prediction_removes_it", - original: "a\nb\nc", - current: "a\nnew line\nb\nc", - predicted: "a\nb\nc", + original: indoc! {" + a + b + c"}, + current: indoc! {" + a + new line + b + c"}, + predicted: indoc! {" + a + b + c"}, expected_reversal_chars: 9, expected_total_chars: 9, }, Case { name: "user_deletes_line_prediction_restores_it", - original: "a\ndeleted\nb", - current: "a\nb", - predicted: "a\ndeleted\nb", + original: indoc! {" + a + deleted + b"}, + current: indoc! {" + a + b"}, + predicted: indoc! {" + a + deleted + b"}, expected_reversal_chars: 8, expected_total_chars: 8, }, @@ -385,9 +723,18 @@ mod tests { }, Case { name: "independent_edits_different_locations", - original: "line1\nline2\nline3", - current: "LINE1\nline2\nline3", - predicted: "LINE1\nline2\nLINE3", + original: indoc! {" + line1 + line2 + line3"}, + current: indoc! {" + LINE1 + line2 + line3"}, + predicted: indoc! {" + LINE1 + line2 + LINE3"}, expected_reversal_chars: 0, expected_total_chars: 10, }, @@ -401,9 +748,18 @@ mod tests { }, Case { name: "user_replaces_text_prediction_reverses", - original: "keep\ndelete_me\nkeep2", - current: "keep\nadded\nkeep2", - predicted: "keep\ndelete_me\nkeep2", + original: indoc! {" + keep + delete_me + keep2"}, + current: indoc! {" + keep + added + keep2"}, + predicted: indoc! 
{" + keep + delete_me + keep2"}, expected_reversal_chars: 14, expected_total_chars: 14, }, @@ -523,9 +879,13 @@ mod tests { }, Case { name: "infix insertion not reversal", - original: "from my_project import Foo\n", - current: "ifrom my_project import Foo\n", - predicted: indoc::indoc! {" + original: indoc! {" + from my_project import Foo + "}, + current: indoc! {" + ifrom my_project import Foo + "}, + predicted: indoc! {" import from my_project import Foo "}, @@ -544,9 +904,9 @@ mod tests { name: "multiple insertions no reversal", original: "print(\"Hello, World!\")", current: "sys.(\"Hello, World!\")", - predicted: "sys.stdout.write(\"Hello, World!\n\")", + predicted: "sys.stdout.write(\"Hello, World!\\n\")", expected_reversal_chars: 0, - expected_total_chars: 13, + expected_total_chars: 14, }, ]; @@ -567,14 +927,14 @@ mod tests { #[test] fn test_reverse_diff() { - let forward_diff = "\ ---- a/file.rs -+++ b/file.rs -@@ -1,3 +1,4 @@ - fn main() { -+ let x = 42; - println!(\"hello\"); -}"; + let forward_diff = indoc! {" + --- a/file.rs + +++ b/file.rs + @@ -1,3 +1,4 @@ + fn main() { + + let x = 42; + println!(\"hello\"); + }"}; let reversed = reverse_diff(forward_diff); @@ -599,8 +959,16 @@ mod tests { #[test] fn test_reverse_diff_roundtrip() { // Applying a diff and then its reverse should get back to original - let original = "first line\nhello world\nlast line\n"; - let modified = "first line\nhello beautiful world\nlast line\n"; + let original = indoc! {" + first line + hello world + last line + "}; + let modified = indoc! {" + first line + hello beautiful world + last line + "}; // unified_diff doesn't include file headers, but apply_diff_to_string needs them let diff_body = language::unified_diff(original, modified); @@ -625,21 +993,33 @@ mod tests { Arc::new(zeta_prompt::Event::BufferChange { path: Arc::from(Path::new("myrepo/src/file.rs")), old_path: Arc::from(Path::new("myrepo/src/file.rs")), - diff: "@@ -1 +1 @@\n-old\n+new".into(), + diff: indoc! 
{" + @@ -1 +1 @@ + -old + +new"} + .into(), predicted: false, in_open_source_repo: true, }), Arc::new(zeta_prompt::Event::BufferChange { path: Arc::from(Path::new("myrepo/other.rs")), old_path: Arc::from(Path::new("myrepo/other.rs")), - diff: "@@ -1 +1 @@\n-a\n+b".into(), + diff: indoc! {" + @@ -1 +1 @@ + -a + +b"} + .into(), predicted: false, in_open_source_repo: true, }), Arc::new(zeta_prompt::Event::BufferChange { path: Arc::from(Path::new("src/file.rs")), old_path: Arc::from(Path::new("src/file.rs")), - diff: "@@ -1 +1 @@\n-x\n+y".into(), + diff: indoc! {" + @@ -1 +1 @@ + -x + +y"} + .into(), predicted: false, in_open_source_repo: true, }), @@ -673,18 +1053,979 @@ mod tests { #[test] fn test_reverse_diff_preserves_trailing_newline() { - let diff_with_trailing_newline = "--- a/file\n+++ b/file\n@@ -1 +1 @@\n-old\n+new\n"; + let diff_with_trailing_newline = indoc! {" + --- a/file + +++ b/file + @@ -1 +1 @@ + -old + +new + "}; let reversed = reverse_diff(diff_with_trailing_newline); assert!( reversed.ends_with('\n'), "Reversed diff should preserve trailing newline" ); - let diff_without_trailing_newline = "--- a/file\n+++ b/file\n@@ -1 +1 @@\n-old\n+new"; + let diff_without_trailing_newline = indoc! {" + --- a/file + +++ b/file + @@ -1 +1 @@ + -old + +new"}; let reversed = reverse_diff(diff_without_trailing_newline); assert!( !reversed.ends_with('\n'), "Reversed diff should not add trailing newline if original didn't have one" ); } + + #[test] + fn test_filter_hunks_by_excerpt_region() { + struct Case { + name: &'static str, + diff: &'static str, + excerpt_start_row: u32, + excerpt_row_count: u32, + expected_filtered_diff: &'static str, + expected_line_offset: i32, + } + + let cases = [ + Case { + name: "hunk_entirely_before_excerpt", + diff: indoc! 
{" + @@ -1,3 +1,4 @@ + line1 + +inserted + line2 + line3 + "}, + excerpt_start_row: 10, + excerpt_row_count: 5, + expected_filtered_diff: "", + expected_line_offset: 1, + }, + Case { + name: "hunk_entirely_inside_excerpt", + diff: indoc! {" + @@ -12,3 +12,4 @@ + line12 + +inserted + line13 + line14 + "}, + excerpt_start_row: 10, + excerpt_row_count: 10, + expected_filtered_diff: indoc! {" + @@ -2,3 +2,4 @@ + line12 + +inserted + line13 + line14 + "}, + expected_line_offset: 1, + }, + Case { + name: "hunk_entirely_after_excerpt", + diff: indoc! {" + @@ -50,3 +50,4 @@ + line50 + +inserted + line51 + line52 + "}, + excerpt_start_row: 10, + excerpt_row_count: 5, + expected_filtered_diff: "", + expected_line_offset: 0, + }, + Case { + name: "hunk_straddles_excerpt_start", + diff: indoc! {" + @@ -8,5 +8,6 @@ + line8 + line9 + +inserted + line10 + line11 + line12 + "}, + excerpt_start_row: 10, + excerpt_row_count: 10, + expected_filtered_diff: indoc! {" + @@ -1,3 +1,3 @@ + line10 + line11 + line12 + "}, + expected_line_offset: 1, + }, + Case { + name: "hunk_straddles_excerpt_end", + diff: indoc! {" + @@ -18,5 +18,6 @@ + line18 + line19 + +inserted + line20 + line21 + line22 + "}, + excerpt_start_row: 10, + excerpt_row_count: 10, + expected_filtered_diff: indoc! {" + @@ -8,2 +8,3 @@ + line18 + line19 + +inserted + "}, + expected_line_offset: 1, + }, + Case { + name: "multiple_hunks_mixed", + diff: indoc! {" + @@ -1,2 +1,3 @@ + line1 + +before_excerpt + line2 + @@ -12,2 +13,3 @@ + line12 + +inside_excerpt + line13 + @@ -50,2 +52,3 @@ + line50 + +after_excerpt + line51 + "}, + excerpt_start_row: 10, + excerpt_row_count: 10, + expected_filtered_diff: indoc! {" + @@ -3,2 +3,3 @@ + line12 + +inside_excerpt + line13 + "}, + expected_line_offset: 2, + }, + Case { + name: "deletion_before_excerpt", + diff: indoc! 
{" + @@ -1,4 +1,3 @@ + line1 + -deleted + line2 + line3 + "}, + excerpt_start_row: 10, + excerpt_row_count: 5, + expected_filtered_diff: "", + expected_line_offset: -1, + }, + Case { + name: "deletion_inside_excerpt", + diff: indoc! {" + @@ -12,4 +12,3 @@ + line12 + -deleted + line13 + line14 + "}, + excerpt_start_row: 10, + excerpt_row_count: 10, + expected_filtered_diff: indoc! {" + @@ -2,4 +2,3 @@ + line12 + -deleted + line13 + line14 + "}, + expected_line_offset: -1, + }, + Case { + name: "empty_diff", + diff: "", + excerpt_start_row: 10, + excerpt_row_count: 5, + expected_filtered_diff: "", + expected_line_offset: 0, + }, + Case { + name: "hunk_spans_entire_excerpt", + diff: indoc! {" + @@ -8,10 +8,12 @@ + line8 + line9 + line10 + line11 + +inserted1 + line12 + line13 + +inserted2 + line14 + line15 + line16 + line17 + "}, + excerpt_start_row: 10, + excerpt_row_count: 5, + expected_filtered_diff: indoc! {" + @@ -1,3 +1,5 @@ + line11 + +inserted1 + line12 + line13 + +inserted2 + "}, + expected_line_offset: 2, + }, + Case { + name: "replacement_inside_excerpt", + diff: indoc! {" + @@ -12,3 +12,3 @@ + line12 + -old_text + +new_text + line14 + "}, + excerpt_start_row: 10, + excerpt_row_count: 10, + expected_filtered_diff: indoc! {" + @@ -2,3 +2,3 @@ + line12 + -old_text + +new_text + line14 + "}, + expected_line_offset: 0, + }, + ]; + + for case in &cases { + let (filtered, line_offset) = filter_diff_hunks_by_excerpt( + case.diff, + case.excerpt_start_row, + case.excerpt_row_count, + ); + assert_eq!( + filtered, case.expected_filtered_diff, + "Test '{}': filtered diff mismatch.\nExpected:\n{}\nGot:\n{}", + case.name, case.expected_filtered_diff, filtered + ); + assert_eq!( + line_offset, case.expected_line_offset, + "Test '{}': line offset mismatch. 
Expected {}, got {}", + case.name, case.expected_line_offset, line_offset + ); + } + } + + #[test] + fn test_excerpt_aware_reversal_tracking() { + struct Case { + name: &'static str, + edit_history_diffs: Vec<&'static str>, + excerpt_content: &'static str, + excerpt_start_row: u32, + predicted_content: &'static str, + expected_reversal_chars: usize, + expected_total_chars: usize, + } + + let cases = [ + Case { + name: "edit_outside_excerpt_no_reversal", + edit_history_diffs: vec![indoc! {" + @@ -1,2 +1,3 @@ + line1 + +added_outside + line2 + "}], + excerpt_content: indoc! {" + line10 + line11 + line12 + "}, + excerpt_start_row: 10, + predicted_content: indoc! {" + line10 + modified + line12 + "}, + expected_reversal_chars: 0, + expected_total_chars: 14, + }, + Case { + name: "edit_inside_excerpt_with_reversal", + edit_history_diffs: vec![indoc! {" + @@ -10,3 +10,4 @@ + line10 + +user_added + line11 + line12 + "}], + excerpt_content: indoc! {" + line10 + user_added + line11 + line12 + "}, + excerpt_start_row: 10, + predicted_content: indoc! {" + line10 + line11 + line12 + "}, + expected_reversal_chars: 11, + expected_total_chars: 11, + }, + Case { + name: "straddling_edit_partial_reversal", + edit_history_diffs: vec![indoc! {" + @@ -8,6 +8,8 @@ + line8 + line9 + +before_excerpt + line10 + +inside_excerpt + line11 + line12 + line13 + "}], + excerpt_content: indoc! {" + line10 + inside_excerpt + line11 + line12 + line13 + "}, + excerpt_start_row: 10, + predicted_content: indoc! {" + line10 + line11 + line12 + line13 + "}, + expected_reversal_chars: 15, + expected_total_chars: 15, + }, + Case { + name: "multiple_edits_mixed_locations", + edit_history_diffs: vec![ + indoc! {" + @@ -1,2 +1,3 @@ + line1 + +outside1 + line2 + "}, + indoc! {" + @@ -11,2 +12,3 @@ + line11 + +inside1 + line12 + "}, + ], + excerpt_content: indoc! {" + line10 + line11 + inside1 + line12 + line13 + "}, + excerpt_start_row: 10, + predicted_content: indoc! 
{" + line10 + line11 + line12 + line13 + "}, + expected_reversal_chars: 8, + expected_total_chars: 8, + }, + Case { + name: "no_edit_history", + edit_history_diffs: vec![], + excerpt_content: indoc! {" + line10 + line11 + line12 + "}, + excerpt_start_row: 10, + predicted_content: indoc! {" + line10 + modified + line12 + "}, + expected_reversal_chars: 0, + expected_total_chars: 14, + }, + Case { + name: "edit_after_excerpt_no_effect", + edit_history_diffs: vec![indoc! {" + @@ -50,2 +50,3 @@ + line50 + +added_after + line51 + "}], + excerpt_content: indoc! {" + line10 + line11 + line12 + "}, + excerpt_start_row: 10, + predicted_content: indoc! {" + line10 + changed + line12 + "}, + expected_reversal_chars: 0, + expected_total_chars: 13, + }, + Case { + name: "line_offset_tracking_across_hunks", + edit_history_diffs: vec![ + indoc! {" + @@ -1,2 +1,4 @@ + line1 + +added1 + +added2 + line2 + "}, + indoc! {" + @@ -12,2 +14,3 @@ + line12 + +inside_after_offset + line13 + "}, + ], + excerpt_content: indoc! {" + line10 + line11 + line12 + inside_after_offset + line13 + "}, + excerpt_start_row: 10, + predicted_content: indoc! 
{" + line10 + line11 + line12 + line13 + "}, + expected_reversal_chars: 20, + expected_total_chars: 20, + }, + ]; + + for case in &cases { + let overlap = compute_excerpt_aware_reversal_overlap( + &case.edit_history_diffs, + case.excerpt_content, + case.excerpt_start_row, + case.predicted_content, + ); + assert_eq!( + overlap.chars_reversing_user_edits, case.expected_reversal_chars, + "Test '{}': expected {} reversal chars, got {}", + case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits + ); + assert_eq!( + overlap.total_chars_in_prediction, case.expected_total_chars, + "Test '{}': expected {} total chars, got {}", + case.name, case.expected_total_chars, overlap.total_chars_in_prediction + ); + } + } + + #[test] + fn test_lenient_diff_application() { + struct Case { + name: &'static str, + diff: &'static str, + content: &'static str, + expected_result: &'static str, + } + + let cases = [ + Case { + name: "hunk_context_not_found_skipped", + diff: indoc! {" + @@ -1,3 +1,4 @@ + context_not_in_content + +added_line + more_context + final_context + "}, + content: indoc! {" + completely + different + content + "}, + expected_result: indoc! {" + completely + different + content + "}, + }, + Case { + name: "hunk_context_found_applied", + diff: indoc! {" + @@ -1,3 +1,4 @@ + line1 + +inserted + line2 + line3 + "}, + content: indoc! {" + line1 + line2 + line3 + "}, + expected_result: indoc! {" + line1 + inserted + line2 + line3 + "}, + }, + Case { + name: "multiple_hunks_partial_match", + diff: indoc! {" + @@ -1,2 +1,3 @@ + not_found + +skipped + also_not_found + @@ -5,2 +6,3 @@ + line5 + +applied + line6 + "}, + content: indoc! {" + line1 + line2 + line3 + line4 + line5 + line6 + "}, + expected_result: indoc! {" + line1 + line2 + line3 + line4 + line5 + applied + line6 + "}, + }, + Case { + name: "empty_diff", + diff: "", + content: indoc! {" + unchanged + content + "}, + expected_result: indoc! 
{" + unchanged + content + "}, + }, + ]; + + for case in &cases { + let result = apply_diff_to_string_lenient(case.diff, case.content); + assert_eq!( + result, case.expected_result, + "Test '{}': expected:\n{}\ngot:\n{}", + case.name, case.expected_result, result + ); + } + } + + #[test] + fn test_unicode_reversal_overlap() { + struct Case { + name: &'static str, + original: &'static str, + current: &'static str, + predicted: &'static str, + expected_reversal_chars: usize, + expected_total_chars: usize, + } + + let cases = [ + Case { + name: "unicode_extension_cjk", + original: "", + current: "日", // 1 char + predicted: "日本語", // 3 chars, adds 2 chars + expected_reversal_chars: 0, + expected_total_chars: 2, // "本語" = 2 chars added + }, + Case { + name: "unicode_extension_emoji", + original: "", + current: "🎉", // 1 char + predicted: "🎉🎊🎈", // 3 chars, adds 2 chars + expected_reversal_chars: 0, + expected_total_chars: 2, // "🎊🎈" = 2 chars added + }, + Case { + name: "unicode_deletion_restored", + original: "héllo wörld", // 11 chars + current: "héllo", // 5 chars + predicted: "héllo wörld", // restores " wörld" = 6 chars + expected_reversal_chars: 6, // LCS(" wörld", " wörld") = 6 chars + expected_total_chars: 6, + }, + Case { + name: "unicode_addition_reversed", + original: "café", // 4 chars + current: "café latté", // 10 chars, added " latté" = 6 chars + predicted: "café", // removes " latté" + expected_reversal_chars: 6, // 6 chars removed + expected_total_chars: 6, + }, + Case { + name: "mixed_ascii_unicode", + original: "", + current: "test日本", // 6 chars + predicted: "test日本語です", // 9 chars + expected_reversal_chars: 0, + expected_total_chars: 3, // 3 new chars after subsequence normalization + }, + Case { + name: "unicode_replacement_not_subsequence", + original: "", + current: "日本", // 2 chars + predicted: "中国", // 2 chars, different + expected_reversal_chars: 2, // removes "日本" = 2 chars + expected_total_chars: 4, // 2 removed + 2 added + }, + ]; + + for 
case in &cases { + let overlap = compute_reversal_overlap(case.original, case.current, case.predicted); + assert_eq!( + overlap.chars_reversing_user_edits, case.expected_reversal_chars, + "Test '{}': expected {} reversal chars, got {}", + case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits + ); + assert_eq!( + overlap.total_chars_in_prediction, case.expected_total_chars, + "Test '{}': expected {} total chars, got {}", + case.name, case.expected_total_chars, overlap.total_chars_in_prediction + ); + } + } + + #[test] + fn test_is_subsequence() { + assert!(is_subsequence("", "anything")); + assert!(is_subsequence("", "")); + assert!(is_subsequence("abc", "abc")); + assert!(is_subsequence("abc", "aXbXc")); + assert!(is_subsequence("ac", "abc")); + assert!(!is_subsequence("abc", "ab")); + assert!(!is_subsequence("abc", "cba")); + assert!(!is_subsequence("abc", "")); + assert!(is_subsequence("日本", "日X本Y語")); + assert!(!is_subsequence("日本語", "日本")); + } + + #[test] + fn test_compute_lcs_length() { + assert_eq!(compute_lcs_length("", ""), 0); + assert_eq!(compute_lcs_length("abc", ""), 0); + assert_eq!(compute_lcs_length("", "abc"), 0); + assert_eq!(compute_lcs_length("abc", "abc"), 3); + assert_eq!(compute_lcs_length("abc", "def"), 0); + assert_eq!(compute_lcs_length("abcdef", "ace"), 3); + assert_eq!(compute_lcs_length("AGGTAB", "GXTXAYB"), 4); + assert_eq!(compute_lcs_length("日本語", "日語"), 2); + } + + #[test] + fn test_compute_prediction_reversal_ratio_full_file() { + let prompt_inputs = ExamplePromptInputs { + content: indoc! {" + line1 + user_added + line2 + "} + .to_string(), + cursor_row: 0, + cursor_column: 0, + cursor_offset: 0, + edit_history: vec![Arc::new(zeta_prompt::Event::BufferChange { + path: Arc::from(Path::new("src/test.rs")), + old_path: Arc::from(Path::new("src/test.rs")), + diff: indoc! 
{" + @@ -1,2 +1,3 @@ + line1 + +user_added + line2 + "} + .into(), + predicted: false, + in_open_source_repo: false, + })], + excerpt_start_row: None, + related_files: None, + }; + + let predicted = indoc! {" + line1 + line2 + "}; + let ratio = + compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs")); + + assert!( + ratio > 0.9, + "Expected high reversal ratio when prediction removes user addition, got {}", + ratio + ); + } + + #[test] + fn test_compute_prediction_reversal_ratio_with_excerpt() { + let prompt_inputs = ExamplePromptInputs { + content: indoc! {" + line10 + user_added + line11 + "} + .to_string(), + cursor_row: 0, + cursor_column: 0, + cursor_offset: 0, + edit_history: vec![Arc::new(zeta_prompt::Event::BufferChange { + path: Arc::from(Path::new("src/test.rs")), + old_path: Arc::from(Path::new("src/test.rs")), + diff: indoc! {" + @@ -10,2 +10,3 @@ + line10 + +user_added + line11 + "} + .into(), + predicted: false, + in_open_source_repo: false, + })], + excerpt_start_row: Some(10), + related_files: None, + }; + + let predicted = indoc! {" + line10 + line11 + "}; + let ratio = + compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs")); + + assert!( + ratio > 0.9, + "Expected high reversal ratio for excerpt-aware computation, got {}", + ratio + ); + } + + #[test] + fn test_compute_prediction_reversal_ratio_no_history() { + let prompt_inputs = ExamplePromptInputs { + content: indoc! {" + original content + "} + .to_string(), + cursor_row: 0, + cursor_column: 0, + cursor_offset: 0, + edit_history: vec![], + excerpt_start_row: None, + related_files: None, + }; + + let predicted = indoc! 
{" + completely different + "}; + let ratio = + compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs")); + + assert_eq!( + ratio, 0.0, + "Expected zero reversal ratio with no edit history" + ); + } + + #[test] + fn test_compute_prediction_reversal_ratio_path_filtering() { + let prompt_inputs = ExamplePromptInputs { + content: indoc! {" + line1 + user_added + line2 + "} + .to_string(), + cursor_row: 0, + cursor_column: 0, + cursor_offset: 0, + edit_history: vec![Arc::new(zeta_prompt::Event::BufferChange { + path: Arc::from(Path::new("src/other.rs")), + old_path: Arc::from(Path::new("src/other.rs")), + diff: indoc! {" + @@ -1,2 +1,3 @@ + line1 + +user_added + line2 + "} + .into(), + predicted: false, + in_open_source_repo: false, + })], + excerpt_start_row: None, + related_files: None, + }; + + let predicted = indoc! {" + line1 + line2 + "}; + let ratio = + compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs")); + + assert_eq!( + ratio, 0.0, + "Expected zero reversal when edit history is for different file" + ); + } + + #[test] + fn test_compute_prediction_reversal_ratio_lenient_fallback() { + let prompt_inputs = ExamplePromptInputs { + content: indoc! {" + actual_line1 + user_added + actual_line2 + "} + .to_string(), + cursor_row: 0, + cursor_column: 0, + cursor_offset: 0, + edit_history: vec![Arc::new(zeta_prompt::Event::BufferChange { + path: Arc::from(Path::new("src/test.rs")), + old_path: Arc::from(Path::new("src/test.rs")), + diff: indoc! {" + @@ -1,2 +1,3 @@ + wrong_context + +user_added + more_wrong + "} + .into(), + predicted: false, + in_open_source_repo: false, + })], + excerpt_start_row: None, + related_files: None, + }; + + let predicted = indoc! 
{" + actual_line1 + actual_line2 + "}; + let ratio = + compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs")); + + assert!( + ratio >= 0.0 && ratio <= 1.0, + "Ratio should be valid even with lenient fallback, got {}", + ratio + ); + } + + #[test] + fn test_excerpt_aware_reversal_error_recovery() { + let diffs = vec![indoc! {" + @@ -1,2 +1,3 @@ + nonexistent_context + +added + more_nonexistent + "}]; + let excerpt_content = indoc! {" + completely + different + content + "}; + let predicted_content = indoc! {" + completely + modified + content + "}; + + let overlap = + compute_excerpt_aware_reversal_overlap(&diffs, excerpt_content, 0, predicted_content); + + assert!( + overlap.ratio() >= 0.0 && overlap.ratio() <= 1.0, + "Should handle failed diff application gracefully" + ); + } + + #[test] + fn test_multiple_sequential_diffs() { + let prompt_inputs = ExamplePromptInputs { + content: indoc! {" + line1 + first_add + second_add + line2 + "} + .to_string(), + cursor_row: 0, + cursor_column: 0, + cursor_offset: 0, + edit_history: vec![ + Arc::new(zeta_prompt::Event::BufferChange { + path: Arc::from(Path::new("src/test.rs")), + old_path: Arc::from(Path::new("src/test.rs")), + diff: indoc! {" + @@ -1,2 +1,3 @@ + line1 + +first_add + line2 + "} + .into(), + predicted: false, + in_open_source_repo: false, + }), + Arc::new(zeta_prompt::Event::BufferChange { + path: Arc::from(Path::new("src/test.rs")), + old_path: Arc::from(Path::new("src/test.rs")), + diff: indoc! {" + @@ -2,2 +2,3 @@ + first_add + +second_add + line2 + "} + .into(), + predicted: false, + in_open_source_repo: false, + }), + ], + excerpt_start_row: None, + related_files: None, + }; + + let predicted = indoc! 
{" + line1 + line2 + "}; + let ratio = + compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs")); + + assert!( + ratio > 0.9, + "Expected high reversal ratio when reversing multiple sequential edits, got {}", + ratio + ); + } } diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index 61104270c3d6ca5c35a13e778ffc999528a04e20..cb9c839d36b056b57e041f1d530fc0d0cdd239a4 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -19,6 +19,8 @@ pub struct ZetaPromptInput { pub cursor_excerpt: Arc<str>, pub editable_range_in_excerpt: Range<usize>, pub cursor_offset_in_excerpt: usize, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub excerpt_start_row: Option<u32>, pub events: Vec<Arc<Event>>, pub related_files: Vec<RelatedFile>, } @@ -433,6 +435,7 @@ mod tests { cursor_excerpt: cursor_excerpt.into(), editable_range_in_excerpt: editable_range, cursor_offset_in_excerpt: cursor_offset, + excerpt_start_row: None, events: events.into_iter().map(Arc::new).collect(), related_files, }