1use std::env;
2use std::fmt::Write as _;
3use std::fs;
4use std::path::Path;
5use std::process;
6
7use edit_prediction_metrics::{
8 ClassificationMetrics, DeltaChrFMetrics, KeptRateResult, TokenAnnotation,
9 annotate_kept_rate_tokens, braces_disbalance, compute_kept_rate, count_patch_token_changes,
10 delta_chr_f, exact_lines_match, extract_changed_lines_from_diff,
11 has_isolated_whitespace_changes, is_editable_region_correct,
12};
13use serde::Deserialize;
14
15fn main() {
16 if let Err(error) = run() {
17 eprintln!("error: {error}");
18 process::exit(1);
19 }
20}
21
22fn run() -> Result<(), String> {
23 let args: Vec<String> = env::args().skip(1).collect();
24 if args.is_empty() {
25 print_usage();
26 return Err("missing arguments".to_string());
27 }
28
29 let input = CliInput::parse(&args)?;
30 let report = match input {
31 CliInput::Files {
32 base_path,
33 expected_patch_path,
34 actual_patch_path,
35 } => {
36 let base = fs::read_to_string(&base_path)
37 .map_err(|err| format!("failed to read {}: {err}", base_path.display()))?;
38 let expected_patch = fs::read_to_string(&expected_patch_path).map_err(|err| {
39 format!("failed to read {}: {err}", expected_patch_path.display())
40 })?;
41 let actual_patch = fs::read_to_string(&actual_patch_path)
42 .map_err(|err| format!("failed to read {}: {err}", actual_patch_path.display()))?;
43
44 let expected = apply_patch_to_excerpt(&base, &expected_patch, 0)?;
45 let actual = apply_patch_to_excerpt(&base, &actual_patch, 0)?;
46
47 EvaluationReport::new(base, expected_patch, actual_patch, expected, actual)
48 }
49 CliInput::Json {
50 json_path,
51 prediction_index,
52 } => {
53 let json = fs::read_to_string(&json_path)
54 .map_err(|err| format!("failed to read {}: {err}", json_path.display()))?;
55 let example: JsonExample = serde_json::from_str(&json)
56 .map_err(|err| format!("failed to parse {}: {err}", json_path.display()))?;
57
58 let base = example.prompt_inputs.cursor_excerpt;
59 let excerpt_start_row = example.prompt_inputs.excerpt_start_row;
60 let expected_patch = example
61 .expected_patches
62 .into_iter()
63 .next()
64 .ok_or_else(|| "JSON input is missing expected_patches[0]".to_string())?;
65 let actual_patch = example
66 .predictions
67 .into_iter()
68 .nth(prediction_index)
69 .ok_or_else(|| {
70 format!("JSON input does not contain predictions[{prediction_index}]")
71 })?
72 .actual_patch;
73
74 let expected = apply_patch_to_excerpt(&base, &expected_patch, excerpt_start_row)?;
75 let actual = apply_patch_to_excerpt(&base, &actual_patch, excerpt_start_row)?;
76
77 EvaluationReport::new(base, expected_patch, actual_patch, expected, actual)
78 }
79 };
80
81 print_report(&report);
82 Ok(())
83}
84
/// Writes the command-line synopsis for both invocation modes to stderr.
fn print_usage() {
    // One entry per output line; the final line carries no newline because
    // `eprintln!` appends it.
    const USAGE: &str = concat!(
        "Usage:\n",
        " edit_prediction_metrics --base <base.txt> --expected-patch <expected.diff> --actual-patch <actual.diff>\n",
        " edit_prediction_metrics --json <example.json> [--prediction-index <n>]",
    );
    eprintln!("{}", USAGE);
}
90
/// The input mode selected on the command line, as produced by
/// `CliInput::parse`.
enum CliInput {
    /// File mode: the base text and both patches live in separate files.
    Files {
        /// Value of `--base`: file holding the unpatched text.
        base_path: std::path::PathBuf,
        /// Value of `--expected-patch`: file holding the reference diff.
        expected_patch_path: std::path::PathBuf,
        /// Value of `--actual-patch`: file holding the diff under evaluation.
        actual_patch_path: std::path::PathBuf,
    },
    /// JSON mode: everything is read from a single example file.
    Json {
        /// Value of `--json`: the example file to load.
        json_path: std::path::PathBuf,
        /// Value of `--prediction-index` (0 when the flag is absent): which
        /// entry of the example's `predictions` array to evaluate.
        prediction_index: usize,
    },
}
102
103impl CliInput {
104 fn parse(args: &[String]) -> Result<Self, String> {
105 let mut base_path = None;
106 let mut expected_patch_path = None;
107 let mut actual_patch_path = None;
108 let mut json_path = None;
109 let mut prediction_index = 0usize;
110
111 let mut index = 0;
112 while index < args.len() {
113 match args[index].as_str() {
114 "--base" => {
115 index += 1;
116 base_path = Some(path_arg(args, index, "--base")?);
117 }
118 "--expected-patch" => {
119 index += 1;
120 expected_patch_path = Some(path_arg(args, index, "--expected-patch")?);
121 }
122 "--actual-patch" => {
123 index += 1;
124 actual_patch_path = Some(path_arg(args, index, "--actual-patch")?);
125 }
126 "--json" => {
127 index += 1;
128 json_path = Some(path_arg(args, index, "--json")?);
129 }
130 "--prediction-index" => {
131 index += 1;
132 let raw = string_arg(args, index, "--prediction-index")?;
133 prediction_index = raw.parse::<usize>().map_err(|err| {
134 format!("invalid value for --prediction-index ({raw}): {err}")
135 })?;
136 }
137 "--help" | "-h" => {
138 print_usage();
139 process::exit(0);
140 }
141 unknown => {
142 return Err(format!("unrecognized argument: {unknown}"));
143 }
144 }
145 index += 1;
146 }
147
148 if let Some(json_path) = json_path {
149 if base_path.is_some() || expected_patch_path.is_some() || actual_patch_path.is_some() {
150 return Err(
151 "--json cannot be combined with --base/--expected-patch/--actual-patch"
152 .to_string(),
153 );
154 }
155 return Ok(CliInput::Json {
156 json_path,
157 prediction_index,
158 });
159 }
160
161 match (base_path, expected_patch_path, actual_patch_path) {
162 (Some(base_path), Some(expected_patch_path), Some(actual_patch_path)) => {
163 Ok(CliInput::Files {
164 base_path,
165 expected_patch_path,
166 actual_patch_path,
167 })
168 }
169 _ => Err(
170 "expected either --json <file> or all of --base, --expected-patch, and --actual-patch"
171 .to_string(),
172 ),
173 }
174 }
175}
176
177fn path_arg(args: &[String], index: usize, flag: &str) -> Result<std::path::PathBuf, String> {
178 Ok(Path::new(string_arg(args, index, flag)?).to_path_buf())
179}
180
/// Borrows the argument at `index` as a string slice.
///
/// # Errors
///
/// Returns `missing value for <flag>` when `index` is past the end of `args`.
fn string_arg<'a>(args: &'a [String], index: usize, flag: &str) -> Result<&'a str, String> {
    match args.get(index) {
        Some(value) => Ok(value.as_str()),
        None => Err(format!("missing value for {flag}")),
    }
}
186
/// Everything `print_report` needs: the three texts plus every metric
/// computed from them by `EvaluationReport::new`.
#[derive(Debug)]
struct EvaluationReport {
    /// The unpatched text both patches were applied to.
    base: String,
    /// The text handed to `new` as the expected (reference) final text.
    expected: String,
    /// The text handed to `new` as the actual (evaluated) final text.
    actual: String,
    /// Result of `compute_kept_rate(base, actual, expected)`.
    kept_rate: KeptRateResult,
    /// Result of `exact_lines_match(expected_patch, actual_patch)`.
    exact_lines: ClassificationMetrics,
    /// Result of `delta_chr_f(base, expected, actual)`.
    delta_chr_f: DeltaChrFMetrics,
    /// Sum of the values returned by `extract_changed_lines_from_diff` for the
    /// expected patch.
    expected_changed_lines: usize,
    /// Sum of the values returned by `extract_changed_lines_from_diff` for the
    /// actual patch.
    actual_changed_lines: usize,
    /// Result of `count_patch_token_changes(actual_patch)`.
    token_changes: edit_prediction_metrics::TokenChangeCounts,
    /// Result of `has_isolated_whitespace_changes(actual_patch, None)`.
    isolated_whitespace_changes: bool,
    /// Result of `is_editable_region_correct(actual_patch)`.
    editable_region_correct: bool,
    /// Result of `braces_disbalance` on the expected final text.
    expected_braces_disbalance: usize,
    /// Result of `braces_disbalance` on the actual final text.
    actual_braces_disbalance: usize,
}
203
204impl EvaluationReport {
205 fn new(
206 base: String,
207 expected_patch: String,
208 actual_patch: String,
209 expected: String,
210 actual: String,
211 ) -> Self {
212 let kept_rate = compute_kept_rate(&base, &actual, &expected);
213 let exact_lines = exact_lines_match(&expected_patch, &actual_patch);
214 let delta_chr_f = delta_chr_f(&base, &expected, &actual);
215 let expected_changed_lines = extract_changed_lines_from_diff(&expected_patch)
216 .values()
217 .sum();
218 let actual_changed_lines = extract_changed_lines_from_diff(&actual_patch)
219 .values()
220 .sum();
221 let token_changes = count_patch_token_changes(&actual_patch);
222 let isolated_whitespace_changes = has_isolated_whitespace_changes(&actual_patch, None);
223 let editable_region_correct = is_editable_region_correct(&actual_patch);
224 let expected_braces_disbalance = braces_disbalance(&expected);
225 let actual_braces_disbalance = braces_disbalance(&actual);
226
227 Self {
228 base,
229 expected,
230 actual,
231 kept_rate,
232 exact_lines,
233 delta_chr_f,
234 expected_changed_lines,
235 actual_changed_lines,
236 token_changes,
237 isolated_whitespace_changes,
238 editable_region_correct,
239 expected_braces_disbalance,
240 actual_braces_disbalance,
241 }
242 }
243}
244
245fn print_report(report: &EvaluationReport) {
246 println!("Metrics");
247 println!("=======");
248 println!("kept_rate: {:.6}", report.kept_rate.kept_rate);
249 println!("kept_rate_recall: {:.6}", report.kept_rate.recall_rate);
250 println!("delta_chr_f: {:.6}", report.delta_chr_f.score);
251 println!("delta_chr_f_precision: {:.6}", report.delta_chr_f.precision);
252 println!("delta_chr_f_recall: {:.6}", report.delta_chr_f.recall);
253 println!("delta_chr_f_beta: {:.6}", report.delta_chr_f.beta);
254 println!();
255
256 println!("Exact line match");
257 println!("----------------");
258 println!("true_positives: {}", report.exact_lines.true_positives);
259 println!("false_positives: {}", report.exact_lines.false_positives);
260 println!("false_negatives: {}", report.exact_lines.false_negatives);
261 println!("precision: {:.6}", report.exact_lines.precision());
262 println!("recall: {:.6}", report.exact_lines.recall());
263 println!("f1: {:.6}", report.exact_lines.f1());
264 println!("expected_changed_lines: {}", report.expected_changed_lines);
265 println!("actual_changed_lines: {}", report.actual_changed_lines);
266 println!();
267
268 println!("Patch structure");
269 println!("---------------");
270 println!("inserted_tokens: {}", report.token_changes.inserted_tokens);
271 println!("deleted_tokens: {}", report.token_changes.deleted_tokens);
272 println!(
273 "isolated_whitespace_changes: {}",
274 report.isolated_whitespace_changes
275 );
276 println!(
277 "editable_region_correct: {}",
278 report.editable_region_correct
279 );
280 println!();
281
282 println!("Final text checks");
283 println!("-----------------");
284 println!(
285 "expected_braces_disbalance: {}",
286 report.expected_braces_disbalance
287 );
288 println!(
289 "actual_braces_disbalance: {}",
290 report.actual_braces_disbalance
291 );
292 println!();
293
294 println!("Kept-rate breakdown");
295 println!("-------------------");
296 println!(
297 "candidate_new_chars: {}",
298 report.kept_rate.candidate_new_chars
299 );
300 println!(
301 "reference_new_chars: {}",
302 report.kept_rate.reference_new_chars
303 );
304 println!(
305 "candidate_deleted_chars: {}",
306 report.kept_rate.candidate_deleted_chars
307 );
308 println!(
309 "reference_deleted_chars: {}",
310 report.kept_rate.reference_deleted_chars
311 );
312 println!("kept_chars: {}", report.kept_rate.kept_chars);
313 println!(
314 "correctly_deleted_chars: {}",
315 report.kept_rate.correctly_deleted_chars
316 );
317 println!("discarded_chars: {}", report.kept_rate.discarded_chars);
318 println!("context_chars: {}", report.kept_rate.context_chars);
319 println!();
320
321 print_kept_rate_explanation(&report.base, &report.actual, &report.expected);
322}
323
324fn print_kept_rate_explanation(base: &str, actual: &str, expected: &str) {
325 println!("Kept-rate explanation");
326 println!("---------------------");
327 println!("Legend: context = default, kept = green background, discarded = red background");
328 println!();
329
330 let annotated = annotate_kept_rate_tokens(base, actual, expected);
331 println!("Actual final text with token annotations:");
332 println!("{}", render_annotated_tokens(&annotated));
333 println!();
334}
335
336fn render_annotated_tokens(tokens: &[edit_prediction_metrics::AnnotatedToken]) -> String {
337 const RESET: &str = "\x1b[0m";
338 const KEPT_STYLE: &str = "\x1b[30;42m";
339 const DISCARDED_STYLE: &str = "\x1b[30;41m";
340
341 let mut rendered = String::new();
342 for token in tokens {
343 let style = match token.annotation {
344 TokenAnnotation::Context => "",
345 TokenAnnotation::Kept => KEPT_STYLE,
346 TokenAnnotation::Discarded => DISCARDED_STYLE,
347 };
348
349 if style.is_empty() {
350 rendered.push_str(&visualize_whitespace(&token.token));
351 } else {
352 rendered.push_str(style);
353 rendered.push_str(&visualize_whitespace(&token.token));
354 rendered.push_str(RESET);
355 }
356 }
357 rendered
358}
359
/// Returns `token` with its whitespace replaced by visible glyphs.
///
/// A space becomes `·`, a tab becomes `⇥`, and a newline becomes `↵` followed
/// by a real newline so line structure is kept. All other characters are
/// copied unchanged.
fn visualize_whitespace(token: &str) -> String {
    token.chars().fold(String::new(), |mut out, ch| {
        match ch {
            ' ' => out.push('·'),
            '\t' => out.push('⇥'),
            '\n' => out.push_str("↵\n"),
            other => out.push(other),
        }
        out
    })
}
372
/// Top-level shape of the file passed to `--json`. Only the fields listed
/// here are read.
#[derive(Debug, Deserialize)]
struct JsonExample {
    /// Supplies the base text and its row offset.
    prompt_inputs: PromptInputs,
    /// Reference diffs; `run` uses only the first entry.
    expected_patches: Vec<String>,
    /// Candidate predictions; `run` picks one by `--prediction-index`.
    predictions: Vec<Prediction>,
}
379
/// The `prompt_inputs` object of a JSON example.
#[derive(Debug, Deserialize)]
struct PromptInputs {
    /// Text used by `run` as the base that both patches are applied to.
    cursor_excerpt: String,
    /// Offset passed to `apply_patch_to_excerpt` as `excerpt_start_row`;
    /// presumably the 0-based row of the excerpt's first line in the full
    /// file (the tests in this file treat it that way).
    excerpt_start_row: u32,
}
385
/// One entry of a JSON example's `predictions` array.
#[derive(Debug, Deserialize)]
struct Prediction {
    /// The diff text evaluated as the actual patch.
    actual_patch: String,
}
390
/// One hunk of a unified diff, reduced to what patch application needs.
#[derive(Debug, Clone)]
struct ParsedHunk {
    /// Start of the hunk on the old side. As produced by `parse_diff_hunks`
    /// this is the start number from the `@@` header; hunks returned by
    /// `filter_hunk_to_excerpt` instead carry a row that has already had 1
    /// subtracted (with saturation), i.e. it is zero-based.
    old_start: u32,
    /// The hunk body in diff order.
    lines: Vec<HunkLine>,
}
396
/// A single body line of a hunk. The payload is the line text with the
/// leading diff marker already removed.
#[derive(Debug, Clone)]
enum HunkLine {
    /// Line present on both sides (diff marker: a space, or a fully empty line).
    Context(String),
    /// Line present only on the new side (diff marker `+`).
    Addition(String),
    /// Line present only on the old side (diff marker `-`).
    Deletion(String),
}
403
404fn apply_patch_to_excerpt(
405 base: &str,
406 patch: &str,
407 excerpt_start_row: u32,
408) -> Result<String, String> {
409 let hunks = parse_diff_hunks(patch);
410
411 let result = try_apply_hunks(base, &hunks, excerpt_start_row);
412
413 // Predicted patches may use excerpt-relative line numbers instead of
414 // file-global ones. When all hunks fall outside the excerpt window the
415 // result is identical to the base text. Retry with a zero offset so the
416 // line numbers are interpreted relative to the excerpt.
417 if excerpt_start_row > 0 && !hunks.is_empty() {
418 let should_retry = match &result {
419 Ok(text) => text == base,
420 Err(_) => true,
421 };
422
423 if should_retry {
424 let fallback = try_apply_hunks(base, &hunks, 0);
425 if matches!(&fallback, Ok(text) if text != base) {
426 return fallback;
427 }
428 }
429 }
430
431 result
432}
433
434fn try_apply_hunks(
435 base: &str,
436 hunks: &[ParsedHunk],
437 excerpt_start_row: u32,
438) -> Result<String, String> {
439 let base_has_trailing_newline = base.ends_with('\n');
440 let mut lines = split_preserving_final_empty_line(base);
441 let original_line_count = lines.len() as u32;
442
443 let excerpt_end_row = excerpt_start_row + original_line_count;
444 let mut line_delta: i64 = 0;
445
446 for hunk in hunks {
447 let filtered = match filter_hunk_to_excerpt(hunk, excerpt_start_row, excerpt_end_row) {
448 Some(filtered) => filtered,
449 None => continue,
450 };
451
452 let local_start = filtered.old_start.saturating_sub(excerpt_start_row) as i64 + line_delta;
453 if local_start < 0 {
454 return Err(format!(
455 "patch application moved before excerpt start at source row {}",
456 filtered.old_start
457 ));
458 }
459 let local_start = local_start as usize;
460
461 if local_start > lines.len() {
462 return Err(format!(
463 "patch application starts past excerpt end at local line {}",
464 local_start + 1
465 ));
466 }
467
468 let old_len = filtered
469 .lines
470 .iter()
471 .filter(|line| !matches!(line, HunkLine::Addition(_)))
472 .count();
473 let new_len = filtered
474 .lines
475 .iter()
476 .filter(|line| !matches!(line, HunkLine::Deletion(_)))
477 .count();
478
479 let old_segment: Vec<&str> = filtered
480 .lines
481 .iter()
482 .filter_map(|line| match line {
483 HunkLine::Context(text) | HunkLine::Deletion(text) => Some(text.as_str()),
484 HunkLine::Addition(_) => None,
485 })
486 .collect();
487
488 let new_segment: Vec<String> = filtered
489 .lines
490 .iter()
491 .filter_map(|line| match line {
492 HunkLine::Context(text) | HunkLine::Addition(text) => Some(text.clone()),
493 HunkLine::Deletion(_) => None,
494 })
495 .collect();
496
497 if local_start + old_len > lines.len() {
498 return Err(format!(
499 "patch application exceeds excerpt bounds near source row {}",
500 filtered.old_start
501 ));
502 }
503
504 let current_segment: Vec<&str> = lines[local_start..local_start + old_len]
505 .iter()
506 .map(String::as_str)
507 .collect();
508
509 if current_segment != old_segment {
510 let mut details = String::new();
511 let _ = write!(
512 details,
513 "patch context mismatch near source row {}: expected {:?}, found {:?}",
514 filtered.old_start, old_segment, current_segment
515 );
516 return Err(details);
517 }
518
519 lines.splice(local_start..local_start + old_len, new_segment);
520 line_delta += new_len as i64 - old_len as i64;
521 }
522
523 Ok(join_lines(&lines, base_has_trailing_newline))
524}
525
/// Splits `text` into lines, appending one empty sentinel entry when the text
/// ends with a newline.
///
/// The sentinel makes the split reversible with a plain `join("\n")`:
/// `"a\nb\n"` becomes `["a", "b", ""]` and `"a\n\n"` becomes
/// `["a", "", ""]`. Text without a trailing newline gets no sentinel, and the
/// empty string yields an empty vector.
///
/// `str::lines` already swallows the final newline, so the sentinel has to be
/// added unconditionally. Skipping it when the last real line is blank would
/// silently drop one trailing newline from texts ending in `"\n\n"`.
fn split_preserving_final_empty_line(text: &str) -> Vec<String> {
    let mut lines: Vec<String> = text.lines().map(ToString::to_string).collect();
    if text.ends_with('\n') {
        lines.push(String::new());
    }
    lines
}
535
/// Joins `lines` with `\n` and normalizes the very end of the result.
///
/// When `had_trailing_newline` is true the result is made to end with a
/// newline (one is appended if missing); when it is false a single trailing
/// newline, if present, is removed. An empty slice always yields an empty
/// string.
fn join_lines(lines: &[String], had_trailing_newline: bool) -> String {
    if lines.is_empty() {
        return String::new();
    }

    let mut text = lines.join("\n");
    match (had_trailing_newline, text.ends_with('\n')) {
        (true, false) => text.push('\n'),
        (false, true) => {
            text.pop();
        }
        // Already in the requested state.
        (true, true) | (false, false) => {}
    }
    text
}
550
551fn filter_hunk_to_excerpt(
552 hunk: &ParsedHunk,
553 excerpt_start_row: u32,
554 excerpt_end_row: u32,
555) -> Option<ParsedHunk> {
556 let mut filtered_lines = Vec::new();
557 let mut current_old_row = hunk.old_start.saturating_sub(1);
558 let mut filtered_old_start = None;
559 let mut has_overlap = false;
560
561 for line in &hunk.lines {
562 match line {
563 HunkLine::Context(text) => {
564 let in_excerpt =
565 current_old_row >= excerpt_start_row && current_old_row < excerpt_end_row;
566 if in_excerpt {
567 filtered_old_start.get_or_insert(current_old_row);
568 filtered_lines.push(HunkLine::Context(text.clone()));
569 has_overlap = true;
570 }
571 current_old_row += 1;
572 }
573 HunkLine::Deletion(text) => {
574 let in_excerpt =
575 current_old_row >= excerpt_start_row && current_old_row < excerpt_end_row;
576 if in_excerpt {
577 filtered_old_start.get_or_insert(current_old_row);
578 filtered_lines.push(HunkLine::Deletion(text.clone()));
579 has_overlap = true;
580 }
581 current_old_row += 1;
582 }
583 HunkLine::Addition(text) => {
584 let insertion_in_excerpt =
585 current_old_row >= excerpt_start_row && current_old_row <= excerpt_end_row;
586 if insertion_in_excerpt {
587 filtered_old_start.get_or_insert(current_old_row);
588 filtered_lines.push(HunkLine::Addition(text.clone()));
589 has_overlap = true;
590 }
591 }
592 }
593 }
594
595 if !has_overlap {
596 return None;
597 }
598
599 Some(ParsedHunk {
600 old_start: filtered_old_start.unwrap_or(excerpt_start_row),
601 lines: filtered_lines,
602 })
603}
604
605fn parse_diff_hunks(diff: &str) -> Vec<ParsedHunk> {
606 let mut hunks = Vec::new();
607 let mut current_hunk: Option<ParsedHunk> = None;
608
609 for line in diff.lines() {
610 if let Some((old_start, old_count, _new_start, _new_count)) = parse_hunk_header(line) {
611 if let Some(hunk) = current_hunk.take() {
612 hunks.push(hunk);
613 }
614 let _ = old_count;
615 current_hunk = Some(ParsedHunk {
616 old_start,
617 lines: Vec::new(),
618 });
619 continue;
620 }
621
622 let Some(hunk) = current_hunk.as_mut() else {
623 continue;
624 };
625
626 if let Some(text) = line.strip_prefix('+') {
627 if !line.starts_with("+++") {
628 hunk.lines.push(HunkLine::Addition(text.to_string()));
629 }
630 } else if let Some(text) = line.strip_prefix('-') {
631 if !line.starts_with("---") {
632 hunk.lines.push(HunkLine::Deletion(text.to_string()));
633 }
634 } else if let Some(text) = line.strip_prefix(' ') {
635 hunk.lines.push(HunkLine::Context(text.to_string()));
636 } else if line.is_empty() {
637 hunk.lines.push(HunkLine::Context(String::new()));
638 }
639 }
640
641 if let Some(hunk) = current_hunk {
642 hunks.push(hunk);
643 }
644
645 hunks
646}
647
648fn parse_hunk_header(line: &str) -> Option<(u32, u32, u32, u32)> {
649 let line = line.strip_prefix("@@ -")?;
650 let (old_part, rest) = line.split_once(' ')?;
651 let rest = rest.strip_prefix('+')?;
652 let (new_part, _) = rest.split_once(" @@")?;
653
654 let (old_start, old_count) = parse_hunk_range(old_part)?;
655 let (new_start, new_count) = parse_hunk_range(new_part)?;
656 Some((old_start, old_count, new_start, new_count))
657}
658
/// Parses one side of a hunk header: either `start,count` or a bare `start`.
///
/// A bare `start` implies a count of 1. Returns `None` when any number is not
/// a valid `u32`.
fn parse_hunk_range(part: &str) -> Option<(u32, u32)> {
    match part.split_once(',') {
        Some((start, count)) => {
            let start = start.parse::<u32>().ok()?;
            let count = count.parse::<u32>().ok()?;
            Some((start, count))
        }
        None => part.parse::<u32>().ok().map(|start| (start, 1)),
    }
}
666
// Unit tests for patch application via `apply_patch_to_excerpt`.
#[cfg(test)]
mod tests {
    use super::*;

    // Offset 0: hunk line numbers index directly into the base text.
    #[test]
    fn applies_patch_in_file_mode() {
        let base = "fn main() {\n println!(\"hello\");\n}\n";
        let patch = "@@ -1,3 +1,3 @@\n fn main() {\n- println!(\"hello\");\n+ println!(\"world\");\n }\n";

        let actual = apply_patch_to_excerpt(base, patch, 0).unwrap();
        assert_eq!(actual, "fn main() {\n println!(\"world\");\n}\n");
    }

    // Excerpt starts at row 1, so header line 2 maps to the excerpt's first line.
    #[test]
    fn applies_patch_in_json_excerpt_mode() {
        let base = "b\nc\nd\n";
        let patch = "@@ -2,2 +2,2 @@\n-b\n-c\n+x\n+y\n";

        let actual = apply_patch_to_excerpt(base, patch, 1).unwrap();
        assert_eq!(actual, "x\ny\nd\n");
    }

    // Exercises the zero-offset retry: with offset 100 the hunk misses the
    // excerpt entirely, so the line numbers are reinterpreted as excerpt-relative.
    #[test]
    fn applies_patch_with_excerpt_relative_line_numbers() {
        let base = "a\nb\nc\nd\n";
        // Patch uses excerpt-relative line numbers (line 2 of excerpt)
        // even though the excerpt starts at file row 100.
        let patch = "@@ -2,2 +2,2 @@\n-b\n-c\n+x\n+y\n";

        let actual = apply_patch_to_excerpt(base, patch, 100).unwrap();
        assert_eq!(actual, "a\nx\ny\nd\n");
    }

    // When the file-global interpretation applies cleanly and changes the
    // text, the retry must not be used.
    #[test]
    fn prefers_file_global_line_numbers_over_excerpt_relative() {
        let base = "a\nb\nc\n";
        // Patch uses file-global line numbers: excerpt starts at row 5,
        // hunk targets line 6 (1-based) = row 5 (0-based) = first line.
        let patch = "@@ -6,2 +6,2 @@\n-a\n-b\n+x\n+y\n";

        let actual = apply_patch_to_excerpt(base, patch, 5).unwrap();
        assert_eq!(actual, "x\ny\nc\n");
    }
}