1use collections::HashMap;
2
3use crate::{
4 example::ActualCursor,
5 reorder_patch::{Patch, PatchLine},
6 word_diff::{DiffOp, diff_tokens, tokenize},
7};
8
/// Bag (multiset) of string items: item → number of occurrences.
pub type Counts = HashMap<String, usize>;
/// Signed per-item count difference between two `Counts` bags.
type CountsDelta = HashMap<String, isize>;
11
/// Context characters needed on each side of a change to capture all affected n-grams
/// (an n-gram that crosses a change boundary can begin at most n-1 characters before it).
const CONTEXT_CHARS: usize = CHR_F_CHAR_ORDER - 1;
14
/// Raw confusion counts produced by comparing two bags of items
/// (see `ClassificationMetrics::from_counts`).
#[derive(Default, Debug, Clone)]
pub struct ClassificationMetrics {
    /// Occurrences present in both expected and actual (up to the smaller multiplicity).
    pub true_positives: usize,
    /// Occurrences present in actual but not in expected (surplus counts included).
    pub false_positives: usize,
    /// Occurrences expected but missing from actual.
    pub false_negatives: usize,
}
21
22impl ClassificationMetrics {
23 pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
24 let mut true_positives = 0;
25 let mut false_positives = 0;
26 let mut false_negatives = 0;
27
28 for (ngram, &expected_count) in expected {
29 let actual_count = *actual.get(ngram).unwrap_or(&0);
30 if actual_count > expected_count {
31 false_positives += actual_count - expected_count;
32 } else {
33 false_negatives += expected_count - actual_count;
34 }
35 true_positives += expected_count.min(actual_count);
36 }
37
38 for (ngram, &actual_count) in actual {
39 if !expected.contains_key(ngram) {
40 false_positives += actual_count;
41 }
42 }
43
44 ClassificationMetrics {
45 true_positives,
46 false_positives,
47 false_negatives,
48 }
49 }
50
51 pub fn accumulate(&mut self, other: &ClassificationMetrics) {
52 self.true_positives += other.true_positives;
53 self.false_positives += other.false_positives;
54 self.false_negatives += other.false_negatives;
55 }
56
57 pub fn precision(&self) -> f64 {
58 if self.true_positives + self.false_positives == 0 {
59 0.0
60 } else {
61 self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
62 }
63 }
64
65 pub fn recall(&self) -> f64 {
66 if self.true_positives + self.false_negatives == 0 {
67 0.0
68 } else {
69 self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
70 }
71 }
72
73 pub fn f1(&self) -> f64 {
74 let precision = self.precision();
75 let recall = self.recall();
76 if precision + recall == 0.0 {
77 0.0
78 } else {
79 2.0 * precision * recall / (precision + recall)
80 }
81 }
82}
83
/// Whitespace normalization strategy applied to texts before n-gram counting.
enum ChrfWhitespace {
    /// Preserve whitespace as-is
    #[allow(unused)]
    Unchanged,

    /// Ignore all whitespace differences
    #[allow(unused)]
    Ignore,

    /// Collapse runs of whitespace into single characters — spaces and
    /// newlines collapse separately (see `collapse_whitespace`)
    Collapse,
}
96
/// Maximum character n-gram order; orders 1..=CHR_F_CHAR_ORDER are averaged.
const CHR_F_CHAR_ORDER: usize = 6;
/// F-beta weighting between precision and recall; beta < 1 favors precision.
const CHR_F_BETA: f64 = 0.5;
/// Whitespace normalization used by all chrF computations in this module.
const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Collapse;
100
/// Returns the beta used by `delta_chr_f`, for reporting alongside scores.
pub fn delta_chr_f_beta() -> f64 {
    CHR_F_BETA
}
104
/// Result bundle for `delta_chr_f`: the final score plus the intermediate
/// values it was derived from.
#[derive(Default, Debug, Clone)]
pub struct DeltaChrFMetrics {
    /// Final F-beta score scaled to 0.0..=100.0.
    pub score: f64,
    /// The beta that combined precision and recall (copy of `CHR_F_BETA`).
    pub beta: f64,
    /// Confusion counts accumulated across all n-gram orders.
    pub counts: ClassificationMetrics,
    /// Precision averaged over all n-gram orders.
    pub precision: f64,
    /// Recall averaged over all n-gram orders.
    pub recall: f64,
}
113
114/// Computes delta-chrF metrics that compare two sets of edits.
115///
116/// This metric works by:
117/// 1. Computing n-gram count differences (deltas) between original→expected and original→actual
118/// 2. Comparing these deltas to measure how well actual edits match expected edits
119///
120/// Returns a score from 0.0 to 100.0, where 100.0 means the actual edits perfectly match
121/// the expected edits.
122pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> DeltaChrFMetrics {
123 if original == expected && expected == actual {
124 return DeltaChrFMetrics {
125 score: 100.0,
126 beta: CHR_F_BETA,
127 precision: 1.0,
128 recall: 1.0,
129 ..DeltaChrFMetrics::default()
130 };
131 }
132
133 let orig_chars: Vec<char> = filter_whitespace_chars(original);
134 let exp_chars: Vec<char> = filter_whitespace_chars(expected);
135 let act_chars: Vec<char> = filter_whitespace_chars(actual);
136
137 // Find the changed regions between original→expected and original→actual
138 // We only need to compute n-grams on these regions (plus context for boundary n-grams)
139 let (orig_for_exp, exp_region) = extract_changed_regions(&orig_chars, &exp_chars);
140 let (orig_for_act, act_region) = extract_changed_regions(&orig_chars, &act_chars);
141
142 let mut total_precision = 0.0;
143 let mut total_recall = 0.0;
144 let mut total_counts = ClassificationMetrics::default();
145
146 for order in 1..=CHR_F_CHAR_ORDER {
147 let orig_ngrams_for_exp = count_ngrams_from_chars(&orig_for_exp, order);
148 let exp_ngrams = count_ngrams_from_chars(&exp_region, order);
149 let expected_delta = compute_ngram_delta(&exp_ngrams, &orig_ngrams_for_exp);
150
151 let orig_ngrams_for_act = count_ngrams_from_chars(&orig_for_act, order);
152 let act_ngrams = count_ngrams_from_chars(&act_region, order);
153 let actual_delta = compute_ngram_delta(&act_ngrams, &orig_ngrams_for_act);
154
155 if expected_delta.is_empty() && actual_delta.is_empty() {
156 total_precision += 1.0;
157 total_recall += 1.0;
158 continue;
159 }
160
161 let expected_counts = ngram_delta_to_counts(&expected_delta);
162 let actual_counts = ngram_delta_to_counts(&actual_delta);
163
164 let counts = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
165 total_precision += counts.precision();
166 total_recall += counts.recall();
167 total_counts.accumulate(&counts);
168 }
169
170 let average_precision = total_precision / CHR_F_CHAR_ORDER as f64;
171 let average_recall = total_recall / CHR_F_CHAR_ORDER as f64;
172 let score = if average_precision + average_recall == 0.0 {
173 0.0
174 } else {
175 (1.0 + CHR_F_BETA * CHR_F_BETA) * average_precision * average_recall
176 / (CHR_F_BETA * CHR_F_BETA * average_precision + average_recall)
177 * 100.0
178 };
179
180 DeltaChrFMetrics {
181 score,
182 beta: CHR_F_BETA,
183 counts: total_counts,
184 precision: average_precision,
185 recall: average_recall,
186 }
187}
188
/// Reference implementation of delta-chrF metrics (original, non-optimized version).
/// Used for testing that the optimized version produces identical results.
///
/// Unlike `delta_chr_f`, this counts n-grams over the entire texts rather than
/// only the changed regions; `test_optimized_equals_reference` pins both paths
/// to the same results.
#[cfg(test)]
fn delta_chr_f_reference(original: &str, expected: &str, actual: &str) -> DeltaChrFMetrics {
    // Fast path: identical inputs mean a perfect score by definition.
    if original == expected && expected == actual {
        return DeltaChrFMetrics {
            score: 100.0,
            beta: CHR_F_BETA,
            precision: 1.0,
            recall: 1.0,
            ..DeltaChrFMetrics::default()
        };
    }

    // Per-order n-gram counts over each full text (index 0 holds unigrams).
    let original_ngrams = chr_f_ngram_counts(original);
    let expected_ngrams = chr_f_ngram_counts(expected);
    let actual_ngrams = chr_f_ngram_counts(actual);

    let mut total_precision = 0.0;
    let mut total_recall = 0.0;
    let mut total_counts = ClassificationMetrics::default();

    for order in 0..CHR_F_CHAR_ORDER {
        let expected_delta = compute_ngram_delta(&expected_ngrams[order], &original_ngrams[order]);
        let actual_delta = compute_ngram_delta(&actual_ngrams[order], &original_ngrams[order]);

        // No change at this order on either side counts as perfect agreement.
        if expected_delta.is_empty() && actual_delta.is_empty() {
            total_precision += 1.0;
            total_recall += 1.0;
            continue;
        }

        let expected_counts = ngram_delta_to_counts(&expected_delta);
        let actual_counts = ngram_delta_to_counts(&actual_delta);

        let counts = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
        total_precision += counts.precision();
        total_recall += counts.recall();
        total_counts.accumulate(&counts);
    }

    let average_precision = total_precision / CHR_F_CHAR_ORDER as f64;
    let average_recall = total_recall / CHR_F_CHAR_ORDER as f64;
    // Standard F-beta combination, scaled to 0..=100.
    let score = if average_precision + average_recall == 0.0 {
        0.0
    } else {
        (1.0 + CHR_F_BETA * CHR_F_BETA) * average_precision * average_recall
            / (CHR_F_BETA * CHR_F_BETA * average_precision + average_recall)
            * 100.0
    };

    DeltaChrFMetrics {
        score,
        beta: CHR_F_BETA,
        counts: total_counts,
        precision: average_precision,
        recall: average_recall,
    }
}
248
249/// Filter whitespace from a string and return as Vec<char>
250fn filter_whitespace_chars(text: &str) -> Vec<char> {
251 match CHR_F_WHITESPACE {
252 ChrfWhitespace::Unchanged => text.chars().collect(),
253 ChrfWhitespace::Ignore => text.chars().filter(|c| !c.is_whitespace()).collect(),
254 ChrfWhitespace::Collapse => collapse_whitespace(text.chars()),
255 }
256}
257
/// Collapse whitespace into single spaces.
/// Newlines and spaces are collapsed separately: a run of non-newline
/// whitespace becomes one `' '`, a run of newlines becomes one `'\n'`.
fn collapse_whitespace(chars: impl Iterator<Item = char>) -> Vec<char> {
    let mut collapsed = Vec::new();
    // The whitespace character most recently emitted, or None after a
    // non-whitespace character; used to suppress repeats within a run.
    let mut pending_ws: Option<char> = None;

    for ch in chars {
        // Map the character to its collapsed whitespace representative, if any.
        let representative = if ch == '\n' {
            Some('\n')
        } else if ch.is_whitespace() {
            Some(' ')
        } else {
            None
        };

        match representative {
            Some(ws) => {
                if pending_ws != Some(ws) {
                    collapsed.push(ws);
                    pending_ws = Some(ws);
                }
            }
            None => {
                collapsed.push(ch);
                pending_ws = None;
            }
        }
    }

    collapsed
}
281
/// Extract only the changed regions between two texts, with context for n-gram boundaries.
///
/// Returns (original_affected_region, modified_affected_region) as Vec<char>.
///
/// The key insight: when computing n-gram delta between two nearly-identical texts,
/// n-grams from unchanged regions cancel out. We only need to process:
/// 1. The changed content itself
/// 2. CONTEXT_CHARS (n-1) characters before and after, to capture boundary-crossing n-grams
fn extract_changed_regions(original: &[char], modified: &[char]) -> (Vec<char>, Vec<char>) {
    // Find longest common prefix
    let prefix_len = original
        .iter()
        .zip(modified.iter())
        .take_while(|(a, b)| a == b)
        .count();

    // Find longest common suffix (that doesn't overlap with prefix)
    let orig_remaining = original.len().saturating_sub(prefix_len);
    let mod_remaining = modified.len().saturating_sub(prefix_len);
    let max_suffix = orig_remaining.min(mod_remaining);

    // The `take(max_suffix)` cap keeps the suffix scan from re-claiming
    // characters already counted as part of the common prefix.
    let suffix_len = original
        .iter()
        .rev()
        .zip(modified.iter().rev())
        .take(max_suffix)
        .take_while(|(a, b)| a == b)
        .count();

    // Calculate the changed region boundaries
    let orig_change_start = prefix_len;
    let orig_change_end = original.len().saturating_sub(suffix_len);
    let mod_change_start = prefix_len;
    let mod_change_end = modified.len().saturating_sub(suffix_len);

    // If there's no actual change, return empty regions
    if orig_change_start >= orig_change_end && mod_change_start >= mod_change_end {
        return (Vec::new(), Vec::new());
    }

    // Expand to include context for n-gram boundaries (clamped to text bounds)
    let orig_context_start = orig_change_start.saturating_sub(CONTEXT_CHARS);
    let orig_context_end = (orig_change_end + CONTEXT_CHARS).min(original.len());
    let mod_context_start = mod_change_start.saturating_sub(CONTEXT_CHARS);
    let mod_context_end = (mod_change_end + CONTEXT_CHARS).min(modified.len());

    let orig_region: Vec<char> = original[orig_context_start..orig_context_end].to_vec();
    let mod_region: Vec<char> = modified[mod_context_start..mod_context_end].to_vec();

    (orig_region, mod_region)
}
333
334/// Count n-grams directly from a char slice (avoids String allocation for the full text)
335fn count_ngrams_from_chars(chars: &[char], n: usize) -> Counts {
336 let mut counts = Counts::default();
337
338 if chars.len() < n {
339 return counts;
340 }
341
342 for window in chars.windows(n) {
343 let ngram: String = window.iter().collect();
344 *counts.entry(ngram).or_insert(0) += 1;
345 }
346
347 counts
348}
349
350#[allow(dead_code)]
351fn chr_f_ngram_counts(text: &str) -> Vec<Counts> {
352 let text = match CHR_F_WHITESPACE {
353 ChrfWhitespace::Unchanged => text.to_string(),
354 ChrfWhitespace::Ignore => text
355 .chars()
356 .filter(|c| !c.is_whitespace())
357 .collect::<String>(),
358 ChrfWhitespace::Collapse => collapse_whitespace(text.chars())
359 .into_iter()
360 .collect::<String>(),
361 };
362
363 (1..=CHR_F_CHAR_ORDER)
364 .map(|order| count_ngrams(&text, order))
365 .collect()
366}
367
368fn compute_ngram_delta(after: &Counts, before: &Counts) -> CountsDelta {
369 let mut delta = CountsDelta::default();
370
371 for (ngram, &before_count) in before {
372 let after_count = *after.get(ngram).unwrap_or(&0);
373 delta.insert(ngram.clone(), after_count as isize - before_count as isize);
374 }
375
376 for (ngram, &after_count) in after {
377 if !before.contains_key(ngram) {
378 delta.insert(ngram.clone(), after_count as isize);
379 }
380 }
381
382 delta
383}
384
385/// Convert negative counts to special deletion tokens.
386/// For example, if expected delta is {"foo": -1} and actual delta is {"bar": -1},
387/// we convert it to {"¬foo": +1} and {"¬bar": +1}. This way _not_ deleting "foo"
388/// will result in a false negative, and mistakenly deleting "bar" will result in a false positive.
389fn ngram_delta_to_counts(delta: &CountsDelta) -> Counts {
390 let mut counts = Counts::default();
391
392 for (ngram, &delta) in delta {
393 if delta > 0 {
394 counts.insert(ngram.clone(), delta as usize);
395 } else if delta < 0 {
396 counts.insert(format!("¬{ngram}"), delta.unsigned_abs());
397 }
398 }
399
400 counts
401}
402
403#[allow(dead_code)]
404fn count_ngrams(text: &str, n: usize) -> Counts {
405 let chars: Vec<char> = text.chars().collect();
406 let mut counts = Counts::default();
407
408 for window in chars.windows(n) {
409 let ngram: String = window.iter().collect();
410 *counts.entry(ngram).or_insert(0) += 1;
411 }
412
413 counts
414}
415
/// Measures how unbalanced the `{}`, `()`, and `[]` bracket pairs are in `text`.
///
/// For each bracket kind, takes the absolute difference between the number of
/// opening and closing brackets; returns the sum of those differences.
/// Note this is a pure count — ordering/nesting is ignored, so `"}{"` counts
/// as balanced.
pub fn braces_disbalance(text: &str) -> usize {
    // Signed open-minus-close counter per bracket kind; a single pass over the
    // text replaces the original six full `chars()` scans.
    let (mut curly, mut round, mut square) = (0isize, 0isize, 0isize);

    for c in text.chars() {
        match c {
            '{' => curly += 1,
            '}' => curly -= 1,
            '(' => round += 1,
            ')' => round -= 1,
            '[' => square += 1,
            ']' => square -= 1,
            _ => {}
        }
    }

    (curly.abs() + round.abs() + square.abs()) as usize
}
433
434/// Extracts changed lines from a unified diff string.
435/// Returns a bag (multiset) of lines that were added (+) or removed (-).
436/// The +/- prefix is included in the line to distinguish additions from deletions.
437pub fn extract_changed_lines_from_diff(diff: &str) -> Counts {
438 let mut counts = Counts::default();
439
440 for line in diff.lines() {
441 // Skip file headers (--- and +++)
442 if line.starts_with("---") || line.starts_with("+++") {
443 continue;
444 }
445 // Skip hunk headers (@@)
446 if line.starts_with("@@") {
447 continue;
448 }
449 // Skip diff header lines (diff --git, index, etc.)
450 if line.starts_with("diff ") || line.starts_with("index ") {
451 continue;
452 }
453 // Include added and removed lines (with their prefix)
454 if line.starts_with('+') || line.starts_with('-') {
455 *counts.entry(line.to_string()).or_insert(0) += 1;
456 }
457 }
458
459 counts
460}
461
462/// Computes exact lines match metrics between expected and actual patches.
463/// Treats changed lines as a bag (multiset) - order is discarded but count matters.
464/// Returns ClassificationMetrics with TP/FP/FN counts.
465pub fn exact_lines_match(expected_patch: &str, actual_patch: &str) -> ClassificationMetrics {
466 let expected_lines = extract_changed_lines_from_diff(expected_patch);
467 let actual_lines = extract_changed_lines_from_diff(actual_patch);
468 ClassificationMetrics::from_counts(&expected_lines, &actual_lines)
469}
470
/// Returns whether the patch contains any isolated whitespace-only changes.
///
/// A whitespace-only change is an added or deleted line whose content is empty or
/// contains only whitespace. It is "isolated" when it is not adjacent to any
/// substantive (non-whitespace) change within the same contiguous change group.
///
/// When `cursor` is provided, a whitespace-only *addition* on the cursor's line
/// in the new file is exempted (presumably the user is mid-edit there).
pub fn has_isolated_whitespace_changes(patch_str: &str, cursor: Option<&ActualCursor>) -> bool {
    let patch = Patch::parse_unified_diff(patch_str);

    // Cursor row is 0-based; unified-diff line numbers are 1-based.
    let cursor_new_file_line = cursor.as_ref().map(|c| (c.row + 1) as usize);

    for hunk in &patch.hunks {
        let lines = &hunk.lines;
        // Tracks the current line number in the *new* file while walking the hunk.
        let mut new_text_line = hunk.new_start as usize;

        for (i, line) in lines.iter().enumerate() {
            let content = match line {
                PatchLine::Addition(s) => {
                    let addition_line = new_text_line;
                    new_text_line += 1;
                    // Exempt a whitespace-only addition at the cursor position.
                    if s.trim().is_empty() && cursor_new_file_line == Some(addition_line) {
                        continue;
                    }
                    s.as_str()
                }
                // Deletions don't exist in the new file, so the counter stays put.
                PatchLine::Deletion(s) => s.as_str(),
                PatchLine::Context(_) => {
                    new_text_line += 1;
                    continue;
                }
                _ => continue,
            };

            // Substantive changes are never "whitespace-only".
            if !content.trim().is_empty() {
                continue;
            }

            if is_whitespace_change_isolated(lines, i) {
                return true;
            }
        }
    }

    false
}
515
516fn is_whitespace_change_isolated(lines: &[PatchLine], index: usize) -> bool {
517 // Look backward for a non-whitespace change before hitting a context line
518 for line in lines[..index].iter().rev() {
519 match line {
520 PatchLine::Addition(s) | PatchLine::Deletion(s) => {
521 if !s.trim().is_empty() {
522 return false;
523 }
524 }
525 _ => break,
526 }
527 }
528
529 // Look forward for a non-whitespace change before hitting a context line
530 for line in &lines[index + 1..] {
531 match line {
532 PatchLine::Addition(s) | PatchLine::Deletion(s) => {
533 if !s.trim().is_empty() {
534 return false;
535 }
536 }
537 _ => break,
538 }
539 }
540
541 true
542}
543
544/// A simple proxy for whether the prediction respects editable region.
545pub fn is_editable_region_correct(actual_patch: &str) -> bool {
546 // A typical sign of a wrong editable region: a bunch of lines deletion
547 // at the beginning or end of the patch.
548 let patch = Patch::parse_unified_diff(actual_patch);
549 if patch.hunks.is_empty() {
550 return true;
551 }
552
553 let hunk = &patch.hunks[0];
554 let mut deletions_at_start = 0;
555
556 for line in hunk.lines.iter() {
557 match line {
558 PatchLine::Deletion(_) => deletions_at_start += 1,
559 _ => break,
560 }
561 }
562
563 if deletions_at_start >= 3 {
564 return false;
565 }
566
567 true
568}
569
/// Totals produced by `count_patch_token_changes`.
#[derive(Debug, Default, Clone)]
pub struct TokenChangeCounts {
    /// Number of tokens added by the patch.
    pub inserted_tokens: usize,
    /// Number of tokens removed by the patch.
    pub deleted_tokens: usize,
}
575
/// Counts the number of inserted and deleted tokens in a unified diff patch.
///
/// Tokens are words and whitespace sequences (as defined by `word_diff::tokenize`).
/// Within each hunk, the old (`-`) and new (`+`) lines are compared at the token level
/// using an LCS-based diff, so modified lines only count the actually changed tokens
/// rather than the entire line.
pub fn count_patch_token_changes(patch: &str) -> TokenChangeCounts {
    let mut counts = TokenChangeCounts::default();
    // Buffers for the current contiguous run of old (`-`) and new (`+`) lines.
    let mut old_lines: Vec<&str> = Vec::new();
    let mut new_lines: Vec<&str> = Vec::new();

    // Token-diffs the buffered run and folds insert/delete totals into `counts`,
    // then resets the buffers for the next run.
    let flush =
        |old_lines: &mut Vec<&str>, new_lines: &mut Vec<&str>, counts: &mut TokenChangeCounts| {
            if old_lines.is_empty() && new_lines.is_empty() {
                return;
            }

            // Strip the leading '-' marker; byte-slicing is safe because the first
            // byte is the ASCII marker. A bare marker line contributes "".
            let old_text: String = old_lines
                .iter()
                .map(|line| if line.len() > 1 { &line[1..] } else { "" })
                .collect::<Vec<_>>()
                .join("\n");

            // Same for the '+' lines.
            let new_text: String = new_lines
                .iter()
                .map(|line| if line.len() > 1 { &line[1..] } else { "" })
                .collect::<Vec<_>>()
                .join("\n");

            let old_tokens = tokenize(&old_text);
            let new_tokens = tokenize(&new_text);
            let ops = diff_tokens(&old_tokens, &new_tokens);

            for op in ops {
                match op {
                    // Unchanged tokens contribute nothing.
                    DiffOp::Equal(..) => {}
                    DiffOp::Delete(start, end) => {
                        counts.deleted_tokens += end - start;
                    }
                    DiffOp::Insert(start, end) => {
                        counts.inserted_tokens += end - start;
                    }
                    DiffOp::Replace {
                        old_start,
                        old_end,
                        new_start,
                        new_end,
                    } => {
                        // A replacement both deletes old tokens and inserts new ones.
                        counts.deleted_tokens += old_end - old_start;
                        counts.inserted_tokens += new_end - new_start;
                    }
                }
            }

            old_lines.clear();
            new_lines.clear();
        };

    for line in patch.lines() {
        // Diff metadata lines terminate the current change run.
        if line.starts_with("---")
            || line.starts_with("+++")
            || line.starts_with("@@")
            || line.starts_with("diff ")
            || line.starts_with("index ")
        {
            flush(&mut old_lines, &mut new_lines, &mut counts);
        } else if line.starts_with('-') {
            old_lines.push(line);
        } else if line.starts_with('+') {
            new_lines.push(line);
        } else {
            // Context lines also terminate the current change run.
            flush(&mut old_lines, &mut new_lines, &mut counts);
        }
    }

    // Flush a change run that extends to the end of the patch.
    flush(&mut old_lines, &mut new_lines, &mut counts);
    counts
}
654
/// Tests that the changed-region optimization in `delta_chr_f` behaves
/// identically to the straightforward reference implementation.
#[cfg(test)]
mod test_optimization {
    use super::*;

    #[test]
    fn test_extract_changed_regions_simple() {
        let original: Vec<char> = "hello world".chars().collect();
        let modified: Vec<char> = "hello there".chars().collect();

        let (orig_region, mod_region) = extract_changed_regions(&original, &modified);

        // "world" vs "there" - with 5 chars context, we get "ello world" vs "ello there"
        // (or less if not enough chars available)
        assert!(orig_region.len() < original.len());
        assert!(mod_region.len() < modified.len());
    }

    #[test]
    fn test_extract_changed_regions_insertion() {
        let original: Vec<char> = "abcdef".chars().collect();
        let modified: Vec<char> = "abcXYZdef".chars().collect();

        let (orig_region, mod_region) = extract_changed_regions(&original, &modified);

        // The insertion is between c and d, so we need context around that point
        assert!(orig_region.len() <= original.len());
        assert!(mod_region.iter().collect::<String>().contains("XYZ"));
    }

    #[test]
    fn test_extract_changed_regions_identical() {
        let text: Vec<char> = "identical text".chars().collect();

        let (orig_region, mod_region) = extract_changed_regions(&text, &text);

        // When texts are identical, regions should be empty
        assert!(orig_region.is_empty());
        assert!(mod_region.is_empty());
    }

    #[test]
    fn test_optimized_matches_original_score() {
        // Test that our optimized version produces the same results
        let test_cases = vec![
            ("hello world", "hello there", "hello world"),
            (
                "fn main() {}",
                "fn main() { println!(); }",
                "fn main() { print!(); }",
            ),
            ("abcdefghij", "abcXXXghij", "abcYYghij"),
            ("unchanged", "unchanged", "unchanged"),
            (
                "prefix middle suffix",
                "prefix CHANGED suffix",
                "prefix middle suffix",
            ),
        ];

        for (original, expected, actual) in test_cases {
            let score = delta_chr_f(original, expected, actual).score;
            // Just verify it produces a reasonable score (0-100)
            assert!(
                (0.0..=100.0).contains(&score),
                "Score {} out of range for ({}, {}, {})",
                score,
                original,
                expected,
                actual
            );
        }
    }

    #[test]
    fn test_optimized_equals_reference() {
        // Comprehensive test that optimized version matches reference implementation exactly
        let test_cases = vec![
            // Basic cases
            ("hello world", "hello there", "hello world"),
            ("hello world", "hello there", "hello there"),
            ("unchanged", "unchanged", "unchanged"),
            // Code-like cases
            (
                "fn main() { println!(\"Hello\"); }",
                "fn main() { println!(\"Hello, World!\"); }",
                "fn main() { println!(\"Hello, World!\"); }",
            ),
            (
                "fn main() { println!(\"Hello\"); }",
                "fn main() { println!(\"Hello, World!\"); }",
                "fn main() { println!(\"Goodbye\"); }",
            ),
            // Insertion
            ("abcdef", "abcXYZdef", "abcdef"),
            ("abcdef", "abcXYZdef", "abcXYZdef"),
            ("abcdef", "abcXYZdef", "abcABCdef"),
            // Deletion
            ("abcXYZdef", "abcdef", "abcXYZdef"),
            ("abcXYZdef", "abcdef", "abcdef"),
            // Multiple changes (simulated by different expected/actual)
            ("one two three four", "one THREE four", "one two FOUR"),
            // Edge cases
            ("a", "b", "c"),
            ("", "abc", ""),
            ("abc", "", "abc"),
            // Longer text with small change
            (
                "This is a longer piece of text that contains many words and characters to process",
                "This is a longer piece of TEXT that contains many words and characters to process",
                "This is a longer piece of text that contains many words and characters to process",
            ),
            // Change at the beginning
            (
                "ORIGINAL start of text",
                "NEW start of text",
                "DIFFERENT start of text",
            ),
            // Change at the end
            (
                "text ending ORIGINAL",
                "text ending NEW",
                "text ending DIFFERENT",
            ),
            // Whitespace (should be ignored)
            ("hello world", "hello there", "hello world"),
            ("a b c d", "a X c d", "a Y c d"),
        ];

        for (original, expected, actual) in test_cases {
            let optimized_metrics = delta_chr_f(original, expected, actual);
            let reference_metrics = delta_chr_f_reference(original, expected, actual);

            assert!(
                (optimized_metrics.score - reference_metrics.score).abs() < 1e-10,
                "Score mismatch for ({:?}, {:?}, {:?}):\n  optimized: {}\n  reference: {}",
                original,
                expected,
                actual,
                optimized_metrics.score,
                reference_metrics.score
            );
            assert_eq!(
                optimized_metrics.counts.true_positives,
                reference_metrics.counts.true_positives
            );
            assert_eq!(
                optimized_metrics.counts.false_positives,
                reference_metrics.counts.false_positives
            );
            assert_eq!(
                optimized_metrics.counts.false_negatives,
                reference_metrics.counts.false_negatives
            );
            assert!((optimized_metrics.precision - reference_metrics.precision).abs() < 1e-10);
            assert!((optimized_metrics.recall - reference_metrics.recall).abs() < 1e-10);
        }
    }

    #[test]
    fn test_delta_chr_f_metrics_include_counts_and_rates() {
        let original = "one two three";
        let expected = "one three";
        let actual = "one two four";

        let metrics = delta_chr_f(original, expected, actual);

        assert!(metrics.score > 20.0 && metrics.score < 40.0);
        assert!(metrics.counts.true_positives > 0);
        assert!(metrics.counts.false_positives > 0);
        assert!(metrics.counts.false_negatives > 0);
        assert!(metrics.precision > 0.0 && metrics.precision < 1.0);
        assert!(metrics.recall > 0.0 && metrics.recall < 1.0);
        assert_eq!(metrics.beta, CHR_F_BETA);
    }
}
830
831#[cfg(test)]
832mod test {
833 use super::*;
834 use crate::example::ActualCursor;
835 use indoc::indoc;
836
837 fn cursor_on_line(one_based_line: u32) -> ActualCursor {
838 ActualCursor {
839 path: String::new(),
840 row: one_based_line - 1,
841 column: 0,
842 offset: 0,
843 editable_region_offset: None,
844 }
845 }
846
847 #[test]
848 fn test_delta_chr_f_perfect_match() {
849 let original = "fn main() { println!(\"Hello\");}";
850 let expected = "fn main() { println!(\"Hello, World!\");}";
851
852 let score = delta_chr_f(original, expected, expected).score;
853 assert!((score - 100.0).abs() < 1e-2);
854 }
855
856 #[test]
857 fn test_delta_chr_f_wrong_edit() {
858 // When the edit is wrong
859 let original = "one two three";
860 let expected = "one three"; // deleted "two "
861 let actual = "one two four"; // deleted "three", added "four"
862
863 // Then the score should be low
864 let score = delta_chr_f(original, expected, actual).score;
865 assert!(score > 20.0 && score < 40.0);
866 }
867
868 #[test]
869 fn test_delta_chr_f_partial_match() {
870 let original = "let x = 42;";
871 let expected = "let x = 100;";
872 let actual = "let x = 99;";
873
874 // We got the edit location right, but the replacement text is wrong.
875 // Deleted ngrams will match, bringing the score somewhere in the middle.
876 let score = delta_chr_f(original, expected, actual).score;
877 assert!(score > 40.0 && score < 60.0);
878 }
879
880 #[test]
881 fn test_delta_chr_f_missed_edit() {
882 // When predictions makes no changes
883 let original = "prefix old suffix";
884 let expected = "prefix new suffix";
885 let actual = "prefix old suffix"; // no change
886
887 // Then the score should be low (all expected changes are false negatives)
888 let score = delta_chr_f(original, expected, actual).score;
889 assert!(score < 20.0);
890 }
891
892 #[test]
893 fn test_delta_chr_f_extra_edit() {
894 // When adding unexpected content
895 let original = "helloworld";
896 let expected = "helloworld"; // no change expected
897 let actual = "helloextraworld"; // added "extra"
898
899 // Then the score should be low (all actual changes are false positives)
900 let score = delta_chr_f(original, expected, actual).score;
901 assert!(score < 20.0);
902 }
903
904 #[test]
905 fn test_delta_chr_f_no_changes() {
906 let text = "unchanged text";
907 let score = delta_chr_f(text, text, text).score;
908 assert!((score - 100.0).abs() < 1e-2);
909 }
910
911 #[test]
912 fn test_braces_disbalance() {
913 let text = "let x = { 1 + 2 };";
914 assert_eq!(braces_disbalance(text), 0);
915
916 let text = "let x = { 1 + 2";
917 assert_eq!(braces_disbalance(text), 1);
918
919 let text = "let x = { 1 + 2 )";
920 assert_eq!(braces_disbalance(text), 2);
921 }
922
923 #[test]
924 fn test_extract_changed_lines_from_diff() {
925 let diff = r#"--- a/file.rs
926+++ b/file.rs
927@@ -1,3 +1,3 @@
928 fn main() {
929- println!("hello");
930+ println!("world");
931 }"#;
932
933 let counts = extract_changed_lines_from_diff(diff);
934 assert_eq!(counts.get("- println!(\"hello\");"), Some(&1));
935 assert_eq!(counts.get("+ println!(\"world\");"), Some(&1));
936 assert_eq!(counts.len(), 2);
937 }
938
939 #[test]
940 fn test_extract_changed_lines_skips_headers() {
941 let diff = r#"diff --git a/file.rs b/file.rs
942index abc123..def456 100644
943--- a/file.rs
944+++ b/file.rs
945@@ -1,2 +1,2 @@
946-old line
947+new line"#;
948
949 let counts = extract_changed_lines_from_diff(diff);
950 assert_eq!(counts.get("-old line"), Some(&1));
951 assert_eq!(counts.get("+new line"), Some(&1));
952 assert_eq!(counts.len(), 2);
953 }
954
955 #[test]
956 fn test_exact_lines_match_perfect() {
957 let expected = r#"--- a/file.rs
958+++ b/file.rs
959@@ -1,3 +1,3 @@
960-old line 1
961-old line 2
962+new line 1
963+new line 2"#;
964
965 let actual = r#"--- a/file.rs
966+++ b/file.rs
967@@ -1,3 +1,3 @@
968-old line 1
969-old line 2
970+new line 1
971+new line 2"#;
972
973 let metrics = exact_lines_match(expected, actual);
974 assert_eq!(metrics.true_positives, 4);
975 assert_eq!(metrics.false_positives, 0);
976 assert_eq!(metrics.false_negatives, 0);
977 assert!((metrics.precision() - 1.0).abs() < 1e-6);
978 assert!((metrics.recall() - 1.0).abs() < 1e-6);
979 assert!((metrics.f1() - 1.0).abs() < 1e-6);
980 }
981
982 #[test]
983 fn test_exact_lines_match_partial() {
984 let expected = r#"-old line 1
985-old line 2
986+new line 1
987+new line 2"#;
988
989 let actual = r#"-old line 1
990+new line 1
991+extra line"#;
992
993 let metrics = exact_lines_match(expected, actual);
994 // TP: "-old line 1" and "+new line 1" (2)
995 // FP: "+extra line" (1)
996 // FN: "-old line 2" and "+new line 2" (2)
997 assert_eq!(metrics.true_positives, 2);
998 assert_eq!(metrics.false_positives, 1);
999 assert_eq!(metrics.false_negatives, 2);
1000 }
1001
1002 #[test]
1003 fn test_exact_lines_match_no_overlap() {
1004 let expected = r#"-line a
1005+line b"#;
1006
1007 let actual = r#"-line x
1008+line y"#;
1009
1010 let metrics = exact_lines_match(expected, actual);
1011 assert_eq!(metrics.true_positives, 0);
1012 assert_eq!(metrics.false_positives, 2);
1013 assert_eq!(metrics.false_negatives, 2);
1014 assert!((metrics.precision()).abs() < 1e-6);
1015 assert!((metrics.recall()).abs() < 1e-6);
1016 }
1017
1018 #[test]
1019 fn test_exact_lines_match_duplicate_lines() {
1020 let expected = r#"+line a
1021+line a
1022+line a"#;
1023
1024 let actual = r#"+line a
1025+line a"#;
1026
1027 let metrics = exact_lines_match(expected, actual);
1028 // Expected has 3 "+line a", actual has 2
1029 // TP: 2, FN: 1, FP: 0
1030 assert_eq!(metrics.true_positives, 2);
1031 assert_eq!(metrics.false_positives, 0);
1032 assert_eq!(metrics.false_negatives, 1);
1033 }
1034
1035 #[test]
1036 fn test_exact_lines_match_empty_patches() {
1037 let metrics = exact_lines_match("", "");
1038 assert_eq!(metrics.true_positives, 0);
1039 assert_eq!(metrics.false_positives, 0);
1040 assert_eq!(metrics.false_negatives, 0);
1041 }
1042
1043 #[test]
1044 fn test_is_editable_region_correct() {
1045 let patch = indoc! {"
1046 @@ -1,1 +1,1 @@
1047 -context
1048 -removed
1049 -from the beginning of the file
1050 import sys
1051 +sys.exit(0)
1052
1053 "};
1054 assert!(!is_editable_region_correct(patch));
1055
1056 let patch = indoc! {"
1057 @@ -1,1 +1,1 @@
1058 "};
1059 assert!(is_editable_region_correct(patch));
1060 }
1061
1062 #[test]
1063 fn test_isolated_whitespace_purely_whitespace_patch() {
1064 let patch = indoc! {"
1065 @@ -1,3 +1,4 @@
1066 fn main() {
1067 +
1068 println!(\"hello\");
1069 }
1070 "};
1071 assert!(has_isolated_whitespace_changes(patch, None));
1072 }
1073
1074 #[test]
1075 fn test_isolated_whitespace_adjacent_to_real_change() {
1076 let patch = indoc! {"
1077 @@ -1,3 +1,4 @@
1078 fn main() {
1079 +
1080 + let x = 1;
1081 println!(\"hello\");
1082 }
1083 "};
1084 assert!(!has_isolated_whitespace_changes(patch, None));
1085 }
1086
1087 #[test]
1088 fn test_isolated_whitespace_no_whitespace_changes() {
1089 let patch = indoc! {"
1090 @@ -1,3 +1,3 @@
1091 fn main() {
1092 - println!(\"hello\");
1093 + println!(\"world\");
1094 }
1095 "};
1096 assert!(!has_isolated_whitespace_changes(patch, None));
1097 }
1098
1099 #[test]
1100 fn test_isolated_whitespace_deletion() {
1101 let patch = indoc! {"
1102 @@ -1,4 +1,3 @@
1103 fn main() {
1104 -
1105 println!(\"hello\");
1106 }
1107 "};
1108 assert!(has_isolated_whitespace_changes(patch, None));
1109 }
1110
1111 #[test]
1112 fn test_isolated_whitespace_mixed_groups() {
1113 let patch = indoc! {"
1114 @@ -1,7 +1,8 @@
1115 fn main() {
1116 +
1117 let x = 1;
1118 - let y = 2;
1119 + let y = 3;
1120
1121 +
1122 println!(\"hello\");
1123 }
1124 "};
1125 assert!(has_isolated_whitespace_changes(patch, None));
1126 }
1127
1128 #[test]
1129 fn test_isolated_whitespace_empty_patch() {
1130 let patch = "";
1131 assert!(!has_isolated_whitespace_changes(patch, None));
1132 }
1133
1134 #[test]
1135 fn test_isolated_whitespace_skipped_on_cursor_line() {
1136 // The addition of a blank line at new-file line 2 should be skipped
1137 // because the cursor is on that line.
1138 let patch = indoc! {"
1139 @@ -1,3 +1,4 @@
1140 fn main() {
1141 +
1142 println!(\"hello\");
1143 }
1144 "};
1145 // New-file line 2 is the added blank line
1146 let cursor = cursor_on_line(2);
1147 assert!(!has_isolated_whitespace_changes(patch, Some(&cursor)));
1148 }
1149
1150 #[test]
1151 fn test_isolated_whitespace_not_skipped_when_cursor_on_different_line() {
1152 // The blank line is at new-file line 2, but the cursor is on line 1.
1153 let patch = indoc! {"
1154 @@ -1,3 +1,4 @@
1155 fn main() {
1156 +
1157 println!(\"hello\");
1158 }
1159 "};
1160 let cursor = cursor_on_line(1);
1161 assert!(has_isolated_whitespace_changes(patch, Some(&cursor)));
1162 }
1163
1164 #[test]
1165 fn test_isolated_whitespace_deletion_not_skipped_by_cursor() {
1166 // Deletions don't have a new-file line, so cursor can't suppress them.
1167 let patch = indoc! {"
1168 @@ -1,4 +1,3 @@
1169 fn main() {
1170 -
1171 println!(\"hello\");
1172 }
1173 "};
1174 let cursor = cursor_on_line(2);
1175 assert!(has_isolated_whitespace_changes(patch, Some(&cursor)));
1176 }
1177
1178 #[test]
1179 fn test_count_patch_token_changes_real_world_rename() {
1180 // Real-world patch that was reported as returning 0 tokens
1181 let patch = "--- a/sip_call\\README.md\n+++ b/sip_call\\README.md\n@@ -1,1 +1,1 @@\n-# \n+# SIP Call\n";
1182 let counts = count_patch_token_changes(patch);
1183 // "# " vs "# SIP Call" — the "SIP" and "Call" tokens (and a whitespace token) are inserted
1184 assert!(
1185 counts.inserted_tokens > 0,
1186 "expected inserted tokens > 0, got {}",
1187 counts.inserted_tokens
1188 );
1189 assert_eq!(counts.deleted_tokens, 0);
1190 }
1191
1192 #[test]
1193 fn test_count_patch_token_changes_real_world_expansion() {
1194 // Real-world patch: single token expanded to multiple lines
1195 let patch = "--- a/task1/src/app/app.html\n+++ b/task1/src/app/app.html\n@@ -1,7 +1,9 @@\n <style>\n- m\n+ main {\n+ \n+ }\n </style>\n \n <main>\n \n </main>\n";
1196 let counts = count_patch_token_changes(patch);
1197 assert!(
1198 counts.inserted_tokens > 0,
1199 "expected inserted tokens > 0, got {}",
1200 counts.inserted_tokens
1201 );
1202 assert!(
1203 counts.deleted_tokens > 0,
1204 "expected deleted tokens > 0, got {}",
1205 counts.deleted_tokens
1206 );
1207 }
1208
1209 #[test]
1210 fn test_count_patch_token_changes_simple_replacement() {
1211 let patch = indoc! {"
1212 @@ -1,3 +1,3 @@
1213 fn main() {
1214 - println!(\"hello\");
1215 + println!(\"world\");
1216 }
1217 "};
1218 let counts = count_patch_token_changes(patch);
1219 assert_eq!(counts.deleted_tokens, 1, "deleted: \"hello\"");
1220 assert_eq!(counts.inserted_tokens, 1, "inserted: \"world\"");
1221 }
1222
1223 #[test]
1224 fn test_count_patch_token_changes_insertion_only() {
1225 let patch = indoc! {"
1226 @@ -1,2 +1,3 @@
1227 fn main() {
1228 + println!(\"hello\");
1229 }
1230 "};
1231 let counts = count_patch_token_changes(patch);
1232 assert_eq!(counts.deleted_tokens, 0);
1233 assert!(counts.inserted_tokens > 0);
1234 }
1235
1236 #[test]
1237 fn test_count_patch_token_changes_deletion_only() {
1238 let patch = indoc! {"
1239 @@ -1,3 +1,2 @@
1240 fn main() {
1241 - println!(\"hello\");
1242 }
1243 "};
1244 let counts = count_patch_token_changes(patch);
1245 assert!(counts.deleted_tokens > 0);
1246 assert_eq!(counts.inserted_tokens, 0);
1247 }
1248
1249 #[test]
1250 fn test_count_patch_token_changes_empty_patch() {
1251 let patch = "";
1252 let counts = count_patch_token_changes(patch);
1253 assert_eq!(counts.deleted_tokens, 0);
1254 assert_eq!(counts.inserted_tokens, 0);
1255 }
1256
1257 #[test]
1258 fn test_count_patch_token_changes_multiple_hunks() {
1259 let patch = indoc! {"
1260 @@ -1,3 +1,3 @@
1261 fn main() {
1262 - let x = 1;
1263 + let x = 2;
1264 }
1265 @@ -10,3 +10,3 @@
1266 fn other() {
1267 - let y = 3;
1268 + let y = 4;
1269 }
1270 "};
1271 let counts = count_patch_token_changes(patch);
1272 assert_eq!(counts.deleted_tokens, 2, "deleted: \"1\" and \"3\"");
1273 assert_eq!(counts.inserted_tokens, 2, "inserted: \"2\" and \"4\"");
1274 }
1275
1276 #[test]
1277 fn test_count_patch_token_changes_multiword_change() {
1278 let patch = indoc! {"
1279 @@ -1,1 +1,1 @@
1280 -hello world foo
1281 +hello bar baz
1282 "};
1283 let counts = count_patch_token_changes(patch);
1284 // "world" and "foo" deleted, "bar" and "baz" inserted
1285 // (whitespace tokens between them may also count)
1286 assert!(counts.deleted_tokens >= 2);
1287 assert!(counts.inserted_tokens >= 2);
1288 }
1289
1290 #[test]
1291 fn test_whitespace_collapse() {
1292 let text = "abc \n\n\n 123";
1293 let collapsed = collapse_whitespace(text.chars());
1294 assert_eq!(
1295 collapsed,
1296 vec!['a', 'b', 'c', ' ', '\n', ' ', '1', '2', '3']
1297 );
1298 }
1299}
1300
1301pub use edit_prediction::metrics::compute_kept_rate;