1use collections::{HashMap, HashSet};
2use edit_prediction::udiff::DiffLine;
3use serde::{Deserialize, Serialize};
4
/// Occurrence count per string item (character n-grams in the chrF code below).
type Counts = HashMap<String, usize>;
/// Signed change in occurrence count per string item between two `Counts`.
type CountsDelta = HashMap<String, isize>;
7
/// Confusion-matrix tallies from comparing an expected collection of items with
/// an actual one; precision, recall and F1 are derived from these counts.
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct ClassificationMetrics {
    /// Items (or item occurrences) present in both expected and actual.
    pub true_positives: usize,
    /// Items (or item occurrences) present in actual but not in expected.
    pub false_positives: usize,
    /// Items (or item occurrences) present in expected but missing from actual.
    pub false_negatives: usize,
}
14
15impl ClassificationMetrics {
16 pub fn from_sets(
17 expected: &HashSet<String>,
18 actual: &HashSet<String>,
19 ) -> ClassificationMetrics {
20 let true_positives = expected.intersection(actual).count();
21 let false_positives = actual.difference(expected).count();
22 let false_negatives = expected.difference(actual).count();
23
24 ClassificationMetrics {
25 true_positives,
26 false_positives,
27 false_negatives,
28 }
29 }
30
31 pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
32 let mut true_positives = 0;
33 let mut false_positives = 0;
34 let mut false_negatives = 0;
35
36 for (ngram, &expected_count) in expected {
37 let actual_count = *actual.get(ngram).unwrap_or(&0);
38 if actual_count > expected_count {
39 false_positives += actual_count - expected_count;
40 } else {
41 false_negatives += expected_count - actual_count;
42 }
43 true_positives += expected_count.min(actual_count);
44 }
45
46 for (ngram, &actual_count) in actual {
47 if !expected.contains_key(ngram) {
48 false_positives += actual_count;
49 }
50 }
51
52 ClassificationMetrics {
53 true_positives,
54 false_positives,
55 false_negatives,
56 }
57 }
58
59 pub fn aggregate<'a>(
60 scores: impl Iterator<Item = &'a ClassificationMetrics>,
61 ) -> ClassificationMetrics {
62 let mut true_positives = 0;
63 let mut false_positives = 0;
64 let mut false_negatives = 0;
65
66 for score in scores {
67 true_positives += score.true_positives;
68 false_positives += score.false_positives;
69 false_negatives += score.false_negatives;
70 }
71
72 ClassificationMetrics {
73 true_positives,
74 false_positives,
75 false_negatives,
76 }
77 }
78
79 pub fn precision(&self) -> f64 {
80 if self.true_positives + self.false_positives == 0 {
81 0.0
82 } else {
83 self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
84 }
85 }
86
87 pub fn recall(&self) -> f64 {
88 if self.true_positives + self.false_negatives == 0 {
89 0.0
90 } else {
91 self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
92 }
93 }
94
95 pub fn f1_score(&self) -> f64 {
96 let recall = self.recall();
97 let precision = self.precision();
98 if precision + recall == 0.0 {
99 0.0
100 } else {
101 2.0 * precision * recall / (precision + recall)
102 }
103 }
104}
105
106pub fn line_match_score(
107 expected_patch: &[DiffLine],
108 actual_patch: &[DiffLine],
109) -> ClassificationMetrics {
110 let expected_change_lines = expected_patch
111 .iter()
112 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
113 .map(|line| line.to_string())
114 .collect();
115
116 let actual_change_lines = actual_patch
117 .iter()
118 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
119 .map(|line| line.to_string())
120 .collect();
121
122 ClassificationMetrics::from_sets(&expected_change_lines, &actual_change_lines)
123}
124
/// How whitespace is treated before character n-grams are extracted for chrF.
enum ChrfWhitespace {
    /// Keep the text as-is, whitespace included.
    // Not currently selected by `CHR_F_WHITESPACE`; kept so the alternative mode
    // remains available, hence the lint allowance.
    #[allow(unused)]
    Unchanged,
    /// Remove every whitespace character before counting n-grams.
    Ignore,
}
130
/// Highest character n-gram order; orders 1 through this value are scored.
const CHR_F_CHAR_ORDER: usize = 6;
/// β of the F-β score used by delta-chrF; β > 1 weights recall above precision.
const CHR_F_BETA: f64 = 2.0;
/// Whitespace handling applied to texts before n-gram extraction.
const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Ignore;
134
135/// Computes a delta-chrF score that compares two sets of edits.
136///
137/// This metric works by:
138/// 1. Reconstructing original, golden (expected result), and actual texts from diffs
139/// 2. Computing n-gram count differences (deltas) between original→golden and original→actual
140/// 3. Comparing these deltas to measure how well actual edits match expected edits
141pub fn delta_chr_f(expected: &[DiffLine], actual: &[DiffLine]) -> f64 {
142 // Reconstruct texts from diffs
143 let mut original_text = String::new(); // state of the text before any edits
144 let mut golden_text = String::new(); // text after applying golden edits
145 let mut actual_text = String::new(); // text after applying actual edits
146
147 for line in expected {
148 match line {
149 DiffLine::Context(s) => {
150 original_text.push_str(s);
151 golden_text.push_str(s);
152 }
153 DiffLine::Deletion(s) => {
154 original_text.push_str(s);
155 }
156 DiffLine::Addition(s) => {
157 golden_text.push_str(s);
158 }
159 _ => {}
160 }
161 }
162
163 for line in actual {
164 match line {
165 DiffLine::Context(s) | DiffLine::Addition(s) => {
166 actual_text.push_str(s);
167 }
168 _ => {}
169 }
170 }
171
172 // Edge case
173 if original_text == golden_text && golden_text == actual_text {
174 return 100.0;
175 }
176
177 // Compute the metric
178 let original_ngrams = chr_f_ngram_counts(&original_text);
179 let golden_ngrams = chr_f_ngram_counts(&golden_text);
180 let actual_ngrams = chr_f_ngram_counts(&actual_text);
181
182 let mut total_precision = 0.0;
183 let mut total_recall = 0.0;
184
185 for order in 0..CHR_F_CHAR_ORDER {
186 let expected_delta = compute_ngram_delta(&golden_ngrams[order], &original_ngrams[order]);
187 let actual_delta = compute_ngram_delta(&actual_ngrams[order], &original_ngrams[order]);
188
189 if expected_delta.is_empty() && actual_delta.is_empty() {
190 total_precision += 1.0;
191 total_recall += 1.0;
192 continue;
193 }
194
195 let expected_counts = ngram_delta_to_counts(&expected_delta);
196 let actual_counts = ngram_delta_to_counts(&actual_delta);
197
198 let score = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
199 total_precision += score.precision();
200 total_recall += score.recall();
201 }
202
203 let prec = total_precision / CHR_F_CHAR_ORDER as f64;
204 let recall = total_recall / CHR_F_CHAR_ORDER as f64;
205 let f_score = if prec + recall == 0.0 {
206 0.0
207 } else {
208 (1.0 + CHR_F_BETA * CHR_F_BETA) * prec * recall / (CHR_F_BETA * CHR_F_BETA * prec + recall)
209 };
210
211 f_score * 100.0
212}
213
214fn chr_f_ngram_counts(text: &str) -> Vec<Counts> {
215 // Ignore whitespace. The original chrF implementation skips all
216 // whitespace. We should consider compressing multiple consecutive
217 // spaces into one -- this may reflect our task more closely.
218 let text = match CHR_F_WHITESPACE {
219 ChrfWhitespace::Unchanged => text.to_string(),
220 ChrfWhitespace::Ignore => text
221 .chars()
222 .filter(|c| !c.is_whitespace())
223 .collect::<String>(),
224 };
225
226 (1..=CHR_F_CHAR_ORDER)
227 .map(|order| count_ngrams(&text, order))
228 .collect()
229}
230
231fn compute_ngram_delta(after: &Counts, before: &Counts) -> CountsDelta {
232 let mut delta = CountsDelta::default();
233
234 for (ngram, &before_count) in before {
235 let after_count = *after.get(ngram).unwrap_or(&0);
236 delta.insert(ngram.clone(), after_count as isize - before_count as isize);
237 }
238
239 for (ngram, &after_count) in after {
240 if !before.contains_key(ngram) {
241 delta.insert(ngram.clone(), after_count as isize);
242 }
243 }
244
245 delta
246}
247
248/// Convert negative counts to special deletion tokens.
249/// For example, if expected delta is {"foo": -1} and actual delta is {"bar": -1},
250/// we convert it to {"¬foo": +1} and {"¬bar": +1}. This way _not_ deleting "foo"
251/// will result in a false negative, and mistakenly deleting "bar" will result in a false positive.
252fn ngram_delta_to_counts(delta: &CountsDelta) -> Counts {
253 let mut counts = Counts::default();
254
255 for (ngram, &delta) in delta {
256 if delta > 0 {
257 counts.insert(ngram.clone(), delta as usize);
258 } else {
259 counts.insert(format!("¬{ngram}"), delta.unsigned_abs());
260 }
261 }
262
263 counts
264}
265
266fn count_ngrams(text: &str, n: usize) -> Counts {
267 let chars: Vec<char> = text.chars().collect();
268 let mut counts = Counts::default();
269
270 for window in chars.windows(n) {
271 let ngram: String = window.iter().collect();
272 *counts.entry(ngram).or_insert(0) += 1;
273 }
274
275 counts
276}
277
#[cfg(test)]
mod test {
    use super::*;
    use edit_prediction::udiff::DiffLine;

    #[test]
    fn test_delta_chr_f_perfect_match() {
        // Scoring a diff against itself must give a perfect score.
        let patch = [
            DiffLine::Context("fn main() {"),
            DiffLine::Deletion(" println!(\"Hello\");"),
            DiffLine::Addition(" println!(\"Hello, World!\");"),
            DiffLine::Context("}"),
        ];

        let score = delta_chr_f(&patch, &patch);
        assert!((score - 100.0).abs() < 1e-2, "score = {score}");
    }

    #[test]
    fn test_delta_chr_f_wrong_edit() {
        // The golden edit deletes "two ", but the prediction rewrites "three" instead.
        let golden = [
            DiffLine::Context("one "),
            DiffLine::Deletion("two "),
            DiffLine::Context("three"),
        ];
        let predicted = [
            DiffLine::Context("one "),
            DiffLine::Context("two "),
            DiffLine::Deletion("three"),
            DiffLine::Addition("four"),
        ];

        // A wrong edit should score low.
        let score = delta_chr_f(&golden, &predicted);
        assert!(score > 20.0 && score < 40.0, "score = {score}");
    }

    #[test]
    fn test_delta_chr_f_partial_match() {
        let golden = [
            DiffLine::Deletion("let x = 42;"),
            DiffLine::Addition("let x = 100;"),
        ];
        let predicted = [
            DiffLine::Deletion("let x = 42;"),
            DiffLine::Addition("let x = 99;"),
        ];

        // The edit location is right but the replacement text is wrong, so only
        // the deleted n-grams match and the score lands somewhere in the middle.
        let score = delta_chr_f(&golden, &predicted);
        assert!(score > 40.0 && score < 60.0, "score = {score}");
    }

    #[test]
    fn test_delta_chr_f_missed_edit() {
        // The prediction leaves the text untouched.
        let golden = [
            DiffLine::Context("prefix "),
            DiffLine::Deletion("old"),
            DiffLine::Addition("new"),
            DiffLine::Context(" suffix"),
        ];
        let predicted = [
            DiffLine::Context("prefix "),
            DiffLine::Context("old"),
            DiffLine::Context(" suffix"),
        ];

        // Every expected change is a false negative, so the score should be low.
        let score = delta_chr_f(&golden, &predicted);
        assert!(score < 20.0, "score = {score}");
    }

    #[test]
    fn test_delta_chr_f_extra_edit() {
        // The prediction inserts content that was never expected.
        let golden = [DiffLine::Context("hello"), DiffLine::Context("world")];
        let predicted = [
            DiffLine::Context("hello"),
            DiffLine::Addition("extra"),
            DiffLine::Context("world"),
        ];

        // Every actual change is a false positive, so the score should be low.
        let score = delta_chr_f(&golden, &predicted);
        assert!(score < 20.0, "score = {score}");
    }
}