1use std::ops::Range;
2use std::path::Path;
3use std::sync::Arc;
4
5use edit_prediction::udiff::apply_diff_to_string;
6use language::text_diff;
7
8use crate::example::ExamplePromptInputs;
9
/// Extracts the line count from one side of a hunk header range.
///
/// A range is either `start,count` or just `start`; in the unified diff format
/// a missing count means the range covers exactly one line. Returns `None`
/// when either number fails to parse as an unsigned integer.
fn hunk_range_line_count(range: &str) -> Option<usize> {
    match range.split_once(',') {
        Some((start, count)) => {
            start.parse::<usize>().ok()?;
            count.parse().ok()
        }
        None => range.parse::<usize>().ok().map(|_| 1),
    }
}

/// Parses a `@@ -old +new @@ …` hunk header and returns it with the two ranges
/// swapped, together with the old-side and new-side line counts of the
/// *input* hunk.
///
/// Returns `None` when `line` is not a hunk header with numeric ranges.
fn reverse_hunk_header(line: &str) -> Option<(String, usize, usize)> {
    let rest = line.strip_prefix("@@ -")?;
    let (old_range, rest) = rest.split_once(" +")?;
    let (new_range, suffix) = rest.split_once(" @@")?;
    let old_count = hunk_range_line_count(old_range)?;
    let new_count = hunk_range_line_count(new_range)?;
    let reversed = format!("@@ -{} +{} @@{}", new_range, old_range, suffix);
    Some((reversed, old_count, new_count))
}

/// Reverses a unified diff so that applying the result undoes the original.
///
/// * `--- ` / `+++ ` file header lines have their prefixes exchanged in place
///   (the lines keep their positions).
/// * Hunk headers have their old and new ranges swapped.
/// * Inside a hunk, `+` lines become `-` lines and vice versa; context lines
///   and `\ No newline at end of file` markers are left untouched.
///
/// Hunk bodies are tracked with the line counts from the hunk header, so a
/// removed line whose content begins with `-- ` (rendered as `--- …`) or an
/// added line whose content begins with `++ ` (rendered as `+++ …`) is
/// treated as a body line rather than a file header. Lines outside any
/// counted hunk — including the bodies of hunks whose header carries no
/// parsable counts — fall back to prefix-based classification.
///
/// Lines are joined with `\n`; a trailing newline is emitted only if `diff`
/// ended with one.
pub fn reverse_diff(diff: &str) -> String {
    // Lines still expected on each side of the hunk currently being read.
    let mut old_remaining = 0usize;
    let mut new_remaining = 0usize;
    let mut reversed_lines: Vec<String> = Vec::new();

    for line in diff.lines() {
        // A valid hunk body line never starts with "@@ -", so a header always
        // starts a new hunk even if the previous counts were not used up.
        if let Some((header, old_count, new_count)) = reverse_hunk_header(line) {
            old_remaining = old_count;
            new_remaining = new_count;
            reversed_lines.push(header);
            continue;
        }

        let in_hunk = old_remaining > 0 || new_remaining > 0;
        let reversed = if in_hunk {
            match line.chars().next() {
                Some('+') => {
                    new_remaining = new_remaining.saturating_sub(1);
                    format!("-{}", &line[1..])
                }
                Some('-') => {
                    old_remaining = old_remaining.saturating_sub(1);
                    format!("+{}", &line[1..])
                }
                // "\ No newline at end of file" belongs to neither side.
                Some('\\') => line.to_string(),
                // Context line; some tools strip the lone space of a blank one.
                Some(' ') | None => {
                    old_remaining = old_remaining.saturating_sub(1);
                    new_remaining = new_remaining.saturating_sub(1);
                    line.to_string()
                }
                // Not a hunk body line: the header over-reported its counts.
                Some(_) => {
                    old_remaining = 0;
                    new_remaining = 0;
                    line.to_string()
                }
            }
        } else if line.starts_with("--- ") {
            line.replacen("--- ", "+++ ", 1)
        } else if line.starts_with("+++ ") {
            line.replacen("+++ ", "--- ", 1)
        } else if line.starts_with('+') && !line.starts_with("+++") {
            format!("-{}", &line[1..])
        } else if line.starts_with('-') && !line.starts_with("---") {
            format!("+{}", &line[1..])
        } else {
            line.to_string()
        };
        reversed_lines.push(reversed);
    }

    let mut result = reversed_lines.join("\n");
    if diff.ends_with('\n') {
        result.push('\n');
    }
    result
}
33
/// A single replacement obtained by diffing two versions of a text
/// (see `compute_granular_edits`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GranularEdit {
    /// Byte range in the old text that this edit replaces.
    pub range: Range<usize>,
    /// The slice of the old text covered by `range`; empty for a pure insertion.
    pub old_text: String,
    /// The text that replaces `old_text`; empty for a pure deletion.
    pub new_text: String,
}
40
41pub fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec<GranularEdit> {
42 text_diff(old_text, new_text)
43 .into_iter()
44 .map(|(range, new_text)| GranularEdit {
45 old_text: old_text[range.clone()].to_string(),
46 range,
47 new_text: new_text.to_string(),
48 })
49 .collect()
50}
51
/// Location of text that a history edit inserted, expressed in the content
/// that results from applying the history edits.
#[derive(Debug, Clone)]
pub struct HistoryAdditionRange {
    /// Range covering the inserted text: the edit's start shifted by the net
    /// length change of the edits before it, extended by the inserted text's
    /// byte length.
    pub range_in_current: Range<usize>,
}
56
/// Text that a history edit removed.
#[derive(Debug, Clone)]
pub struct HistoryDeletionRange {
    /// The removed text (the `old_text` of the originating edit).
    pub deleted_text: String,
}
61
62pub fn compute_history_addition_ranges(
63 history_edits: &[GranularEdit],
64) -> Vec<HistoryAdditionRange> {
65 let mut result = Vec::new();
66 let mut offset_delta: isize = 0;
67
68 for edit in history_edits {
69 if !edit.new_text.is_empty() {
70 let new_start = (edit.range.start as isize + offset_delta) as usize;
71 let new_end = new_start + edit.new_text.len();
72 result.push(HistoryAdditionRange {
73 range_in_current: new_start..new_end,
74 });
75 }
76
77 offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize;
78 }
79
80 result
81}
82
83pub fn compute_history_deletion_ranges(
84 history_edits: &[GranularEdit],
85) -> Vec<HistoryDeletionRange> {
86 history_edits
87 .iter()
88 .filter(|edit| !edit.old_text.is_empty())
89 .map(|edit| HistoryDeletionRange {
90 deleted_text: edit.old_text.clone(),
91 })
92 .collect()
93}
94
/// Summary of how much of a prediction undoes earlier user edits, as produced
/// by `compute_reversal_overlap`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ReversalOverlap {
    /// Amount of the prediction that reverses user edits.
    pub chars_reversing_user_edits: usize,
    /// Total amount of text the prediction adds plus the amount it removes.
    pub total_chars_in_prediction: usize,
}

impl ReversalOverlap {
    /// Fraction of the prediction that reverses user edits.
    ///
    /// Returns `0.0` when the prediction changes nothing, avoiding a division
    /// by zero.
    pub fn ratio(&self) -> f32 {
        match self.total_chars_in_prediction {
            0 => 0.0,
            total => self.chars_reversing_user_edits as f32 / total as f32,
        }
    }
}
110
111/// Compute how much of a prediction reverses recent user edits.
112pub fn compute_reversal_overlap(
113 original_content: &str,
114 current_content: &str,
115 predicted_content: &str,
116) -> ReversalOverlap {
117 let history_edits = compute_granular_edits(original_content, current_content);
118 let prediction_edits = compute_granular_edits(current_content, predicted_content);
119
120 let history_addition_ranges = compute_history_addition_ranges(&history_edits);
121 let history_deletion_ranges = compute_history_deletion_ranges(&history_edits);
122
123 let reversed_additions =
124 compute_reversed_additions(&history_addition_ranges, &prediction_edits);
125 let restored_deletions =
126 compute_restored_deletions(&history_deletion_ranges, &prediction_edits);
127
128 let prediction_added_chars: usize = prediction_edits.iter().map(|e| e.new_text.len()).sum();
129 let prediction_deleted_chars: usize = prediction_edits.iter().map(|e| e.old_text.len()).sum();
130
131 ReversalOverlap {
132 chars_reversing_user_edits: reversed_additions + restored_deletions,
133 total_chars_in_prediction: prediction_added_chars + prediction_deleted_chars,
134 }
135}
136
137pub fn compute_reversed_additions(
138 history_addition_ranges: &[HistoryAdditionRange],
139 prediction_edits: &[GranularEdit],
140) -> usize {
141 let mut reversed_chars = 0;
142
143 for pred_edit in prediction_edits {
144 for history_addition in history_addition_ranges {
145 let overlap_start = pred_edit
146 .range
147 .start
148 .max(history_addition.range_in_current.start);
149 let overlap_end = pred_edit
150 .range
151 .end
152 .min(history_addition.range_in_current.end);
153
154 if overlap_start < overlap_end {
155 reversed_chars += overlap_end - overlap_start;
156 }
157 }
158 }
159
160 reversed_chars
161}
162
163pub fn compute_restored_deletions(
164 history_deletion_ranges: &[HistoryDeletionRange],
165 prediction_edits: &[GranularEdit],
166) -> usize {
167 let history_deleted_text: String = history_deletion_ranges
168 .iter()
169 .map(|r| r.deleted_text.as_str())
170 .collect();
171
172 let prediction_added_text: String = prediction_edits
173 .iter()
174 .map(|e| e.new_text.as_str())
175 .collect();
176
177 compute_lcs_length(&history_deleted_text, &prediction_added_text)
178}
179
/// Returns the size, in UTF-8 bytes, of the largest common subsequence of
/// characters shared by `a` and `b`.
///
/// Comparison is done per `char`, so a multi-byte character is never matched
/// partially, but every matched character contributes its encoded byte
/// length. This keeps the result in the same unit as `str::len`, which is
/// what the rest of the reversal accounting uses. For ASCII input the result
/// equals the classic longest-common-subsequence length.
///
/// Runs in `O(|a| * |b|)` time and `O(|b|)` space using two rolling rows.
fn compute_lcs_length(a: &str, b: &str) -> usize {
    let a_chars: Vec<char> = a.chars().collect();
    let b_chars: Vec<char> = b.chars().collect();

    if a_chars.is_empty() || b_chars.is_empty() {
        return 0;
    }

    let width = b_chars.len();
    // `prev[j]` / `curr[j]` hold the best weight for the first `j` chars of
    // `b` against the previous / current prefix of `a`. Index 0 stays 0.
    let mut prev = vec![0usize; width + 1];
    let mut curr = vec![0usize; width + 1];

    for &a_char in &a_chars {
        for (j, &b_char) in b_chars.iter().enumerate() {
            curr[j + 1] = if a_char == b_char {
                // Extending the diagonal is always optimal on a match, because
                // any alternative would match this same character elsewhere.
                prev[j] + a_char.len_utf8()
            } else {
                prev[j + 1].max(curr[j])
            };
        }
        // Every cell of the new `curr` (bar index 0) is overwritten next
        // round, so no reset is needed.
        std::mem::swap(&mut prev, &mut curr);
    }

    prev[width]
}
207
208pub fn filter_edit_history_by_path<'a>(
209 edit_history: &'a [Arc<zeta_prompt::Event>],
210 cursor_path: &std::path::Path,
211) -> Vec<&'a zeta_prompt::Event> {
212 edit_history
213 .iter()
214 .filter(|event| match event.as_ref() {
215 zeta_prompt::Event::BufferChange { path, .. } => {
216 let event_path = path.as_ref();
217 if event_path == cursor_path {
218 return true;
219 }
220 let stripped = event_path
221 .components()
222 .skip(1)
223 .collect::<std::path::PathBuf>();
224 stripped == cursor_path
225 }
226 })
227 .map(|arc| arc.as_ref())
228 .collect()
229}
230
231pub fn extract_diff_from_event(event: &zeta_prompt::Event) -> &str {
232 match event {
233 zeta_prompt::Event::BufferChange { diff, .. } => diff.as_str(),
234 }
235}
236
237pub fn compute_prediction_reversal_ratio(
238 prompt_inputs: &ExamplePromptInputs,
239 predicted_content: &str,
240 cursor_path: &Path,
241) -> f32 {
242 let current_content = &prompt_inputs.content;
243
244 let edit_history: &[Arc<zeta_prompt::Event>] = &prompt_inputs.edit_history;
245 let relevant_events = filter_edit_history_by_path(edit_history, cursor_path);
246
247 let mut original_content = current_content.to_string();
248 for event in relevant_events.into_iter().rev() {
249 let diff = extract_diff_from_event(event);
250 if diff.is_empty() {
251 continue;
252 }
253 let reversed = reverse_diff(diff);
254 let with_headers = format!("--- a/file\n+++ b/file\n{}", reversed);
255 match apply_diff_to_string(&with_headers, &original_content) {
256 Ok(updated_content) => original_content = updated_content,
257 Err(err) => {
258 log::warn!(
259 "Failed to reconstruct original content for reversal tracking: Failed to apply reversed diff: {:#}",
260 err
261 );
262 return 0.0;
263 }
264 }
265 }
266
267 let overlap = compute_reversal_overlap(&original_content, current_content, predicted_content);
268 overlap.ratio()
269}
270
/// Unit tests for diff reversal, edit-history filtering and reversal-overlap
/// accounting.
#[cfg(test)]
mod tests {
    use super::*;
    use edit_prediction::udiff::apply_diff_to_string;

    /// Table-driven check of `compute_reversal_overlap`: each case supplies the
    /// original, current and predicted content together with the expected
    /// reversal and total counts.
    #[test]
    fn test_reversal_overlap() {
        struct Case {
            name: &'static str,
            original: &'static str,
            current: &'static str,
            predicted: &'static str,
            expected_reversal_chars: usize,
            expected_total_chars: usize,
        }

        let cases = [
            // Prediction deletes exactly what the user inserted.
            Case {
                name: "user_adds_line_prediction_removes_it",
                original: "a\nb\nc",
                current: "a\nnew line\nb\nc",
                predicted: "a\nb\nc",
                expected_reversal_chars: 9,
                expected_total_chars: 9,
            },
            // Prediction re-inserts exactly what the user deleted.
            Case {
                name: "user_deletes_line_prediction_restores_it",
                original: "a\ndeleted\nb",
                current: "a\nb",
                predicted: "a\ndeleted\nb",
                expected_reversal_chars: 8,
                expected_total_chars: 8,
            },
            // NOTE(review): despite the name, `predicted` equals `original`
            // here, so the deleted text is restored in full.
            Case {
                name: "user_deletes_text_prediction_restores_partial",
                original: "hello beautiful world",
                current: "hello world",
                predicted: "hello beautiful world",
                expected_reversal_chars: 10,
                expected_total_chars: 10,
            },
            // Inserted text shares nothing with the deleted text.
            Case {
                name: "user_deletes_foo_prediction_adds_bar",
                original: "foo",
                current: "",
                predicted: "bar",
                expected_reversal_chars: 0,
                expected_total_chars: 3,
            },
            // History and prediction touch different lines.
            Case {
                name: "independent_edits_different_locations",
                original: "line1\nline2\nline3",
                current: "LINE1\nline2\nline3",
                predicted: "LINE1\nline2\nLINE3",
                expected_reversal_chars: 0,
                expected_total_chars: 10,
            },
            // With no history there is nothing to reverse.
            Case {
                name: "no_history_edits",
                original: "same",
                current: "same",
                predicted: "different",
                expected_reversal_chars: 0,
                expected_total_chars: 13,
            },
            // Prediction both removes the user's insertion and restores the
            // user's deletion.
            Case {
                name: "user_replaces_text_prediction_reverses",
                original: "keep\ndelete_me\nkeep2",
                current: "keep\nadded\nkeep2",
                predicted: "keep\ndelete_me\nkeep2",
                expected_reversal_chars: 14,
                expected_total_chars: 14,
            },
            // Prediction overwrites the user's word with a different one.
            Case {
                name: "user_modifies_word_prediction_modifies_differently",
                original: "the quick brown fox",
                current: "the slow brown fox",
                predicted: "the fast brown fox",
                expected_reversal_chars: 4,
                expected_total_chars: 8,
            },
        ];

        for case in &cases {
            let overlap = compute_reversal_overlap(case.original, case.current, case.predicted);
            assert_eq!(
                overlap.chars_reversing_user_edits, case.expected_reversal_chars,
                "Test '{}': expected {} reversal chars, got {}",
                case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits
            );
            assert_eq!(
                overlap.total_chars_in_prediction, case.expected_total_chars,
                "Test '{}': expected {} total chars, got {}",
                case.name, case.expected_total_chars, overlap.total_chars_in_prediction
            );
        }
    }

    /// Checks the per-line rewrites of `reverse_diff`: file header prefixes are
    /// exchanged, an added line becomes a removed line, and context lines are
    /// kept.
    #[test]
    fn test_reverse_diff() {
        let forward_diff = "\
--- a/file.rs
+++ b/file.rs
@@ -1,3 +1,4 @@
 fn main() {
+ let x = 42;
 println!(\"hello\");
}";

        let reversed = reverse_diff(forward_diff);

        assert!(
            reversed.contains("+++ a/file.rs"),
            "Should have +++ for old path"
        );
        assert!(
            reversed.contains("--- b/file.rs"),
            "Should have --- for new path"
        );
        assert!(
            reversed.contains("- let x = 42;"),
            "Added line should become deletion"
        );
        assert!(
            reversed.contains(" fn main()"),
            "Context lines should be unchanged"
        );
    }

    /// Applies a diff and then its reversal, expecting to arrive back at the
    /// starting text.
    #[test]
    fn test_reverse_diff_roundtrip() {
        // Applying a diff and then its reverse should get back to original
        let original = "first line\nhello world\nlast line\n";
        let modified = "first line\nhello beautiful world\nlast line\n";

        // unified_diff doesn't include file headers, but apply_diff_to_string needs them
        let diff_body = language::unified_diff(original, modified);
        let forward_diff = format!("--- a/file\n+++ b/file\n{}", diff_body);
        let reversed_diff = reverse_diff(&forward_diff);

        // Apply forward diff to original
        let after_forward = apply_diff_to_string(&forward_diff, original).unwrap();
        assert_eq!(after_forward, modified);

        // Apply reversed diff to modified
        let after_reverse = apply_diff_to_string(&reversed_diff, &after_forward).unwrap();
        assert_eq!(after_reverse, original);
    }

    /// Exercises both matching rules of `filter_edit_history_by_path`: exact
    /// path equality and equality after dropping the first path component.
    #[test]
    fn test_filter_edit_history_by_path() {
        // Test that filter_edit_history_by_path correctly matches paths when
        // the edit history has paths with a repo prefix (e.g., "repo/src/file.rs")
        // but the cursor_path doesn't have the repo prefix (e.g., "src/file.rs")
        let events = vec![
            Arc::new(zeta_prompt::Event::BufferChange {
                path: Arc::from(Path::new("myrepo/src/file.rs")),
                old_path: Arc::from(Path::new("myrepo/src/file.rs")),
                diff: "@@ -1 +1 @@\n-old\n+new".into(),
                predicted: false,
                in_open_source_repo: true,
            }),
            Arc::new(zeta_prompt::Event::BufferChange {
                path: Arc::from(Path::new("myrepo/other.rs")),
                old_path: Arc::from(Path::new("myrepo/other.rs")),
                diff: "@@ -1 +1 @@\n-a\n+b".into(),
                predicted: false,
                in_open_source_repo: true,
            }),
            Arc::new(zeta_prompt::Event::BufferChange {
                path: Arc::from(Path::new("src/file.rs")),
                old_path: Arc::from(Path::new("src/file.rs")),
                diff: "@@ -1 +1 @@\n-x\n+y".into(),
                predicted: false,
                in_open_source_repo: true,
            }),
        ];

        // "myrepo/src/file.rs" stripped -> "src/file.rs" matches cursor_path
        // "src/file.rs" exact match
        let cursor_path = Path::new("src/file.rs");
        let filtered = filter_edit_history_by_path(&events, cursor_path);
        assert_eq!(
            filtered.len(),
            2,
            "Should match myrepo/src/file.rs (stripped) and src/file.rs (exact)"
        );

        // "myrepo/src/file.rs" stripped -> "src/file.rs" != "file.rs"
        // "src/file.rs" stripped -> "file.rs" == "file.rs"
        let cursor_path = Path::new("file.rs");
        let filtered = filter_edit_history_by_path(&events, cursor_path);
        assert_eq!(
            filtered.len(),
            1,
            "Should only match src/file.rs (stripped to file.rs)"
        );

        // "myrepo/other.rs" stripped -> "other.rs" == "other.rs"
        let cursor_path = Path::new("other.rs");
        let filtered = filter_edit_history_by_path(&events, cursor_path);
        assert_eq!(filtered.len(), 1, "Should match only myrepo/other.rs");
    }

    /// `reverse_diff` must end its output with a newline exactly when its input
    /// does.
    #[test]
    fn test_reverse_diff_preserves_trailing_newline() {
        let diff_with_trailing_newline = "--- a/file\n+++ b/file\n@@ -1 +1 @@\n-old\n+new\n";
        let reversed = reverse_diff(diff_with_trailing_newline);
        assert!(
            reversed.ends_with('\n'),
            "Reversed diff should preserve trailing newline"
        );

        let diff_without_trailing_newline = "--- a/file\n+++ b/file\n@@ -1 +1 @@\n-old\n+new";
        let reversed = reverse_diff(diff_without_trailing_newline);
        assert!(
            !reversed.ends_with('\n'),
            "Reversed diff should not add trailing newline if original didn't have one"
        );
    }
}