main.rs

  1use std::env;
  2use std::fmt::Write as _;
  3use std::fs;
  4use std::path::Path;
  5use std::process;
  6
  7use edit_prediction_metrics::{
  8    ClassificationMetrics, DeltaChrFMetrics, KeptRateResult, TokenAnnotation,
  9    annotate_kept_rate_tokens, braces_disbalance, compute_kept_rate, count_patch_token_changes,
 10    delta_chr_f, exact_lines_match, extract_changed_lines_from_diff,
 11    has_isolated_whitespace_changes, is_editable_region_correct,
 12};
 13use serde::Deserialize;
 14
 15fn main() {
 16    if let Err(error) = run() {
 17        eprintln!("error: {error}");
 18        process::exit(1);
 19    }
 20}
 21
 22fn run() -> Result<(), String> {
 23    let args: Vec<String> = env::args().skip(1).collect();
 24    if args.is_empty() {
 25        print_usage();
 26        return Err("missing arguments".to_string());
 27    }
 28
 29    let input = CliInput::parse(&args)?;
 30    let report = match input {
 31        CliInput::Files {
 32            base_path,
 33            expected_patch_path,
 34            actual_patch_path,
 35        } => {
 36            let base = fs::read_to_string(&base_path)
 37                .map_err(|err| format!("failed to read {}: {err}", base_path.display()))?;
 38            let expected_patch = fs::read_to_string(&expected_patch_path).map_err(|err| {
 39                format!("failed to read {}: {err}", expected_patch_path.display())
 40            })?;
 41            let actual_patch = fs::read_to_string(&actual_patch_path)
 42                .map_err(|err| format!("failed to read {}: {err}", actual_patch_path.display()))?;
 43
 44            let expected = apply_patch_to_excerpt(&base, &expected_patch, 0)?;
 45            let actual = apply_patch_to_excerpt(&base, &actual_patch, 0)?;
 46
 47            EvaluationReport::new(base, expected_patch, actual_patch, expected, actual)
 48        }
 49        CliInput::Json {
 50            json_path,
 51            prediction_index,
 52        } => {
 53            let json = fs::read_to_string(&json_path)
 54                .map_err(|err| format!("failed to read {}: {err}", json_path.display()))?;
 55            let example: JsonExample = serde_json::from_str(&json)
 56                .map_err(|err| format!("failed to parse {}: {err}", json_path.display()))?;
 57
 58            let base = example.prompt_inputs.cursor_excerpt;
 59            let excerpt_start_row = example.prompt_inputs.excerpt_start_row;
 60            let expected_patch = example
 61                .expected_patches
 62                .into_iter()
 63                .next()
 64                .ok_or_else(|| "JSON input is missing expected_patches[0]".to_string())?;
 65            let actual_patch = example
 66                .predictions
 67                .into_iter()
 68                .nth(prediction_index)
 69                .ok_or_else(|| {
 70                    format!("JSON input does not contain predictions[{prediction_index}]")
 71                })?
 72                .actual_patch;
 73
 74            let expected = apply_patch_to_excerpt(&base, &expected_patch, excerpt_start_row)?;
 75            let actual = apply_patch_to_excerpt(&base, &actual_patch, excerpt_start_row)?;
 76
 77            EvaluationReport::new(base, expected_patch, actual_patch, expected, actual)
 78        }
 79    };
 80
 81    print_report(&report);
 82    Ok(())
 83}
 84
 85fn print_usage() {
 86    eprintln!(
 87        "Usage:\n  edit_prediction_metrics --base <base.txt> --expected-patch <expected.diff> --actual-patch <actual.diff>\n  edit_prediction_metrics --json <example.json> [--prediction-index <n>]"
 88    );
 89}
 90
 91enum CliInput {
 92    Files {
 93        base_path: std::path::PathBuf,
 94        expected_patch_path: std::path::PathBuf,
 95        actual_patch_path: std::path::PathBuf,
 96    },
 97    Json {
 98        json_path: std::path::PathBuf,
 99        prediction_index: usize,
100    },
101}
102
103impl CliInput {
104    fn parse(args: &[String]) -> Result<Self, String> {
105        let mut base_path = None;
106        let mut expected_patch_path = None;
107        let mut actual_patch_path = None;
108        let mut json_path = None;
109        let mut prediction_index = 0usize;
110
111        let mut index = 0;
112        while index < args.len() {
113            match args[index].as_str() {
114                "--base" => {
115                    index += 1;
116                    base_path = Some(path_arg(args, index, "--base")?);
117                }
118                "--expected-patch" => {
119                    index += 1;
120                    expected_patch_path = Some(path_arg(args, index, "--expected-patch")?);
121                }
122                "--actual-patch" => {
123                    index += 1;
124                    actual_patch_path = Some(path_arg(args, index, "--actual-patch")?);
125                }
126                "--json" => {
127                    index += 1;
128                    json_path = Some(path_arg(args, index, "--json")?);
129                }
130                "--prediction-index" => {
131                    index += 1;
132                    let raw = string_arg(args, index, "--prediction-index")?;
133                    prediction_index = raw.parse::<usize>().map_err(|err| {
134                        format!("invalid value for --prediction-index ({raw}): {err}")
135                    })?;
136                }
137                "--help" | "-h" => {
138                    print_usage();
139                    process::exit(0);
140                }
141                unknown => {
142                    return Err(format!("unrecognized argument: {unknown}"));
143                }
144            }
145            index += 1;
146        }
147
148        if let Some(json_path) = json_path {
149            if base_path.is_some() || expected_patch_path.is_some() || actual_patch_path.is_some() {
150                return Err(
151                    "--json cannot be combined with --base/--expected-patch/--actual-patch"
152                        .to_string(),
153                );
154            }
155            return Ok(CliInput::Json {
156                json_path,
157                prediction_index,
158            });
159        }
160
161        match (base_path, expected_patch_path, actual_patch_path) {
162            (Some(base_path), Some(expected_patch_path), Some(actual_patch_path)) => {
163                Ok(CliInput::Files {
164                    base_path,
165                    expected_patch_path,
166                    actual_patch_path,
167                })
168            }
169            _ => Err(
170                "expected either --json <file> or all of --base, --expected-patch, and --actual-patch"
171                    .to_string(),
172            ),
173        }
174    }
175}
176
/// Returns `args[index]` as a path value for the option named `flag`.
///
/// # Errors
/// Returns `missing value for <flag>` when `index` is past the end of `args`.
fn path_arg(args: &[String], index: usize, flag: &str) -> Result<std::path::PathBuf, String> {
    string_arg(args, index, flag).map(|value| Path::new(value).to_path_buf())
}

/// Returns `args[index]` as the value of the option named `flag`.
///
/// # Errors
/// Returns `missing value for <flag>` when `index` is past the end of `args`.
fn string_arg<'a>(args: &'a [String], index: usize, flag: &str) -> Result<&'a str, String> {
    match args.get(index) {
        Some(value) => Ok(value.as_str()),
        None => Err(format!("missing value for {flag}")),
    }
}
186
187#[derive(Debug)]
188struct EvaluationReport {
189    base: String,
190    expected: String,
191    actual: String,
192    kept_rate: KeptRateResult,
193    exact_lines: ClassificationMetrics,
194    delta_chr_f: DeltaChrFMetrics,
195    expected_changed_lines: usize,
196    actual_changed_lines: usize,
197    token_changes: edit_prediction_metrics::TokenChangeCounts,
198    isolated_whitespace_changes: bool,
199    editable_region_correct: bool,
200    expected_braces_disbalance: usize,
201    actual_braces_disbalance: usize,
202}
203
204impl EvaluationReport {
205    fn new(
206        base: String,
207        expected_patch: String,
208        actual_patch: String,
209        expected: String,
210        actual: String,
211    ) -> Self {
212        let kept_rate = compute_kept_rate(&base, &actual, &expected);
213        let exact_lines = exact_lines_match(&expected_patch, &actual_patch);
214        let delta_chr_f = delta_chr_f(&base, &expected, &actual);
215        let expected_changed_lines = extract_changed_lines_from_diff(&expected_patch)
216            .values()
217            .sum();
218        let actual_changed_lines = extract_changed_lines_from_diff(&actual_patch)
219            .values()
220            .sum();
221        let token_changes = count_patch_token_changes(&actual_patch);
222        let isolated_whitespace_changes = has_isolated_whitespace_changes(&actual_patch, None);
223        let editable_region_correct = is_editable_region_correct(&actual_patch);
224        let expected_braces_disbalance = braces_disbalance(&expected);
225        let actual_braces_disbalance = braces_disbalance(&actual);
226
227        Self {
228            base,
229            expected,
230            actual,
231            kept_rate,
232            exact_lines,
233            delta_chr_f,
234            expected_changed_lines,
235            actual_changed_lines,
236            token_changes,
237            isolated_whitespace_changes,
238            editable_region_correct,
239            expected_braces_disbalance,
240            actual_braces_disbalance,
241        }
242    }
243}
244
245fn print_report(report: &EvaluationReport) {
246    println!("Metrics");
247    println!("=======");
248    println!("kept_rate: {:.6}", report.kept_rate.kept_rate);
249    println!("kept_rate_recall: {:.6}", report.kept_rate.recall_rate);
250    println!("delta_chr_f: {:.6}", report.delta_chr_f.score);
251    println!("delta_chr_f_precision: {:.6}", report.delta_chr_f.precision);
252    println!("delta_chr_f_recall: {:.6}", report.delta_chr_f.recall);
253    println!("delta_chr_f_beta: {:.6}", report.delta_chr_f.beta);
254    println!();
255
256    println!("Exact line match");
257    println!("----------------");
258    println!("true_positives: {}", report.exact_lines.true_positives);
259    println!("false_positives: {}", report.exact_lines.false_positives);
260    println!("false_negatives: {}", report.exact_lines.false_negatives);
261    println!("precision: {:.6}", report.exact_lines.precision());
262    println!("recall: {:.6}", report.exact_lines.recall());
263    println!("f1: {:.6}", report.exact_lines.f1());
264    println!("expected_changed_lines: {}", report.expected_changed_lines);
265    println!("actual_changed_lines: {}", report.actual_changed_lines);
266    println!();
267
268    println!("Patch structure");
269    println!("---------------");
270    println!("inserted_tokens: {}", report.token_changes.inserted_tokens);
271    println!("deleted_tokens: {}", report.token_changes.deleted_tokens);
272    println!(
273        "isolated_whitespace_changes: {}",
274        report.isolated_whitespace_changes
275    );
276    println!(
277        "editable_region_correct: {}",
278        report.editable_region_correct
279    );
280    println!();
281
282    println!("Final text checks");
283    println!("-----------------");
284    println!(
285        "expected_braces_disbalance: {}",
286        report.expected_braces_disbalance
287    );
288    println!(
289        "actual_braces_disbalance: {}",
290        report.actual_braces_disbalance
291    );
292    println!();
293
294    println!("Kept-rate breakdown");
295    println!("-------------------");
296    println!(
297        "candidate_new_chars: {}",
298        report.kept_rate.candidate_new_chars
299    );
300    println!(
301        "reference_new_chars: {}",
302        report.kept_rate.reference_new_chars
303    );
304    println!(
305        "candidate_deleted_chars: {}",
306        report.kept_rate.candidate_deleted_chars
307    );
308    println!(
309        "reference_deleted_chars: {}",
310        report.kept_rate.reference_deleted_chars
311    );
312    println!("kept_chars: {}", report.kept_rate.kept_chars);
313    println!(
314        "correctly_deleted_chars: {}",
315        report.kept_rate.correctly_deleted_chars
316    );
317    println!("discarded_chars: {}", report.kept_rate.discarded_chars);
318    println!("context_chars: {}", report.kept_rate.context_chars);
319    println!();
320
321    print_kept_rate_explanation(&report.base, &report.actual, &report.expected);
322}
323
324fn print_kept_rate_explanation(base: &str, actual: &str, expected: &str) {
325    println!("Kept-rate explanation");
326    println!("---------------------");
327    println!("Legend: context = default, kept = green background, discarded = red background");
328    println!();
329
330    let annotated = annotate_kept_rate_tokens(base, actual, expected);
331    println!("Actual final text with token annotations:");
332    println!("{}", render_annotated_tokens(&annotated));
333    println!();
334}
335
336fn render_annotated_tokens(tokens: &[edit_prediction_metrics::AnnotatedToken]) -> String {
337    const RESET: &str = "\x1b[0m";
338    const KEPT_STYLE: &str = "\x1b[30;42m";
339    const DISCARDED_STYLE: &str = "\x1b[30;41m";
340
341    let mut rendered = String::new();
342    for token in tokens {
343        let style = match token.annotation {
344            TokenAnnotation::Context => "",
345            TokenAnnotation::Kept => KEPT_STYLE,
346            TokenAnnotation::Discarded => DISCARDED_STYLE,
347        };
348
349        if style.is_empty() {
350            rendered.push_str(&visualize_whitespace(&token.token));
351        } else {
352            rendered.push_str(style);
353            rendered.push_str(&visualize_whitespace(&token.token));
354            rendered.push_str(RESET);
355        }
356    }
357    rendered
358}
359
/// Returns `token` with spaces shown as `·` and tabs as `⇥`.
///
/// All other characters, including newlines, are copied unchanged.
fn visualize_whitespace(token: &str) -> String {
    token
        .chars()
        .map(|ch| match ch {
            ' ' => '·',
            '\t' => '⇥',
            other => other,
        })
        .collect()
}
372
/// One evaluation example read from the `--json` input file.
///
/// `run` evaluates `expected_patches[0]` against
/// `predictions[prediction_index]`.
#[derive(Debug, Deserialize)]
struct JsonExample {
    /// Excerpt text and its position in the original file.
    prompt_inputs: PromptInputs,
    /// Reference patches; only the first one is used.
    expected_patches: Vec<String>,
    /// Candidate predictions; one is selected with `--prediction-index`.
    predictions: Vec<Prediction>,
}

/// Excerpt inputs of a JSON example.
#[derive(Debug, Deserialize)]
struct PromptInputs {
    /// Excerpt text; used as the base text that both patches are applied to.
    cursor_excerpt: String,
    /// File row at which the excerpt begins. It is compared against 0-based
    /// rows (header line minus one) when clipping hunks to the excerpt.
    excerpt_start_row: u32,
}

/// A single prediction of a JSON example.
#[derive(Debug, Deserialize)]
struct Prediction {
    /// Patch text for this prediction, applied to the excerpt as the "actual" patch.
    actual_patch: String,
}
390
/// One `@@` hunk of a unified diff.
#[derive(Debug, Clone)]
struct ParsedHunk {
    /// Old-side start of the hunk. For hunks built by `parse_diff_hunks` this
    /// is a 1-based line number derived from the `@@ -start,count` header;
    /// hunks returned by `filter_hunk_to_excerpt` store a 0-based file row
    /// here instead.
    old_start: u32,
    /// Body lines of the hunk, in diff order.
    lines: Vec<HunkLine>,
}

/// One body line of a hunk, with its one-character diff prefix removed.
#[derive(Debug, Clone)]
enum HunkLine {
    /// Line present on both sides (a ` ` line, or a completely empty diff line).
    Context(String),
    /// `+` line: present only on the new side.
    Addition(String),
    /// `-` line: present only on the old side.
    Deletion(String),
}
403
404fn apply_patch_to_excerpt(
405    base: &str,
406    patch: &str,
407    excerpt_start_row: u32,
408) -> Result<String, String> {
409    let hunks = parse_diff_hunks(patch);
410
411    let result = try_apply_hunks(base, &hunks, excerpt_start_row);
412
413    // Predicted patches may use excerpt-relative line numbers instead of
414    // file-global ones. When all hunks fall outside the excerpt window the
415    // result is identical to the base text. Retry with a zero offset so the
416    // line numbers are interpreted relative to the excerpt.
417    if excerpt_start_row > 0 && !hunks.is_empty() {
418        let should_retry = match &result {
419            Ok(text) => text == base,
420            Err(_) => true,
421        };
422
423        if should_retry {
424            let fallback = try_apply_hunks(base, &hunks, 0);
425            if matches!(&fallback, Ok(text) if text != base) {
426                return fallback;
427            }
428        }
429    }
430
431    result
432}
433
434fn try_apply_hunks(
435    base: &str,
436    hunks: &[ParsedHunk],
437    excerpt_start_row: u32,
438) -> Result<String, String> {
439    let base_has_trailing_newline = base.ends_with('\n');
440    let mut lines = split_preserving_final_empty_line(base);
441    let original_line_count = lines.len() as u32;
442
443    let excerpt_end_row = excerpt_start_row + original_line_count;
444    let mut line_delta: i64 = 0;
445
446    for hunk in hunks {
447        let filtered = match filter_hunk_to_excerpt(hunk, excerpt_start_row, excerpt_end_row) {
448            Some(filtered) => filtered,
449            None => continue,
450        };
451
452        let local_start = filtered.old_start.saturating_sub(excerpt_start_row) as i64 + line_delta;
453        if local_start < 0 {
454            return Err(format!(
455                "patch application moved before excerpt start at source row {}",
456                filtered.old_start
457            ));
458        }
459        let local_start = local_start as usize;
460
461        if local_start > lines.len() {
462            return Err(format!(
463                "patch application starts past excerpt end at local line {}",
464                local_start + 1
465            ));
466        }
467
468        let old_len = filtered
469            .lines
470            .iter()
471            .filter(|line| !matches!(line, HunkLine::Addition(_)))
472            .count();
473        let new_len = filtered
474            .lines
475            .iter()
476            .filter(|line| !matches!(line, HunkLine::Deletion(_)))
477            .count();
478
479        let old_segment: Vec<&str> = filtered
480            .lines
481            .iter()
482            .filter_map(|line| match line {
483                HunkLine::Context(text) | HunkLine::Deletion(text) => Some(text.as_str()),
484                HunkLine::Addition(_) => None,
485            })
486            .collect();
487
488        let new_segment: Vec<String> = filtered
489            .lines
490            .iter()
491            .filter_map(|line| match line {
492                HunkLine::Context(text) | HunkLine::Addition(text) => Some(text.clone()),
493                HunkLine::Deletion(_) => None,
494            })
495            .collect();
496
497        if local_start + old_len > lines.len() {
498            return Err(format!(
499                "patch application exceeds excerpt bounds near source row {}",
500                filtered.old_start
501            ));
502        }
503
504        let current_segment: Vec<&str> = lines[local_start..local_start + old_len]
505            .iter()
506            .map(String::as_str)
507            .collect();
508
509        if current_segment != old_segment {
510            let mut details = String::new();
511            let _ = write!(
512                details,
513                "patch context mismatch near source row {}: expected {:?}, found {:?}",
514                filtered.old_start, old_segment, current_segment
515            );
516            return Err(details);
517        }
518
519        lines.splice(local_start..local_start + old_len, new_segment);
520        line_delta += new_len as i64 - old_len as i64;
521    }
522
523    Ok(join_lines(&lines, base_has_trailing_newline))
524}
525
/// Splits `text` into lines, appending one empty trailing element when `text`
/// ends with a newline.
///
/// The trailing element stands for the position after the final newline, so
/// `lines.join("\n")` reproduces the line structure of `text`. It is added
/// unconditionally, even when the last real line is itself empty (`"a\n\n"`).
/// That keeps the representation unambiguous: a lone `""` left after hunks
/// delete every real line always means "empty document". `str::lines` strips
/// the `\r` of CRLF endings, so rejoined text uses LF endings.
fn split_preserving_final_empty_line(text: &str) -> Vec<String> {
    let mut lines: Vec<String> = text.lines().map(ToString::to_string).collect();
    if text.ends_with('\n') {
        lines.push(String::new());
    }
    lines
}

/// Joins `lines` produced by `split_preserving_final_empty_line` (possibly
/// edited) back into text, restoring the original trailing-newline state.
///
/// An empty result (no lines, or only the trailing marker left) yields `""`
/// rather than a lone newline.
fn join_lines(lines: &[String], had_trailing_newline: bool) -> String {
    let mut joined = lines.join("\n");
    if joined.is_empty() {
        // Every line was removed; an empty document has no trailing newline.
        return joined;
    }
    if had_trailing_newline && !joined.ends_with('\n') {
        joined.push('\n');
    }
    if !had_trailing_newline && joined.ends_with('\n') {
        joined.pop();
    }
    joined
}
550
551fn filter_hunk_to_excerpt(
552    hunk: &ParsedHunk,
553    excerpt_start_row: u32,
554    excerpt_end_row: u32,
555) -> Option<ParsedHunk> {
556    let mut filtered_lines = Vec::new();
557    let mut current_old_row = hunk.old_start.saturating_sub(1);
558    let mut filtered_old_start = None;
559    let mut has_overlap = false;
560
561    for line in &hunk.lines {
562        match line {
563            HunkLine::Context(text) => {
564                let in_excerpt =
565                    current_old_row >= excerpt_start_row && current_old_row < excerpt_end_row;
566                if in_excerpt {
567                    filtered_old_start.get_or_insert(current_old_row);
568                    filtered_lines.push(HunkLine::Context(text.clone()));
569                    has_overlap = true;
570                }
571                current_old_row += 1;
572            }
573            HunkLine::Deletion(text) => {
574                let in_excerpt =
575                    current_old_row >= excerpt_start_row && current_old_row < excerpt_end_row;
576                if in_excerpt {
577                    filtered_old_start.get_or_insert(current_old_row);
578                    filtered_lines.push(HunkLine::Deletion(text.clone()));
579                    has_overlap = true;
580                }
581                current_old_row += 1;
582            }
583            HunkLine::Addition(text) => {
584                let insertion_in_excerpt =
585                    current_old_row >= excerpt_start_row && current_old_row <= excerpt_end_row;
586                if insertion_in_excerpt {
587                    filtered_old_start.get_or_insert(current_old_row);
588                    filtered_lines.push(HunkLine::Addition(text.clone()));
589                    has_overlap = true;
590                }
591            }
592        }
593    }
594
595    if !has_overlap {
596        return None;
597    }
598
599    Some(ParsedHunk {
600        old_start: filtered_old_start.unwrap_or(excerpt_start_row),
601        lines: filtered_lines,
602    })
603}
604
605fn parse_diff_hunks(diff: &str) -> Vec<ParsedHunk> {
606    let mut hunks = Vec::new();
607    let mut current_hunk: Option<ParsedHunk> = None;
608
609    for line in diff.lines() {
610        if let Some((old_start, old_count, _new_start, _new_count)) = parse_hunk_header(line) {
611            if let Some(hunk) = current_hunk.take() {
612                hunks.push(hunk);
613            }
614            let _ = old_count;
615            current_hunk = Some(ParsedHunk {
616                old_start,
617                lines: Vec::new(),
618            });
619            continue;
620        }
621
622        let Some(hunk) = current_hunk.as_mut() else {
623            continue;
624        };
625
626        if let Some(text) = line.strip_prefix('+') {
627            if !line.starts_with("+++") {
628                hunk.lines.push(HunkLine::Addition(text.to_string()));
629            }
630        } else if let Some(text) = line.strip_prefix('-') {
631            if !line.starts_with("---") {
632                hunk.lines.push(HunkLine::Deletion(text.to_string()));
633            }
634        } else if let Some(text) = line.strip_prefix(' ') {
635            hunk.lines.push(HunkLine::Context(text.to_string()));
636        } else if line.is_empty() {
637            hunk.lines.push(HunkLine::Context(String::new()));
638        }
639    }
640
641    if let Some(hunk) = current_hunk {
642        hunks.push(hunk);
643    }
644
645    hunks
646}
647
/// Parses a unified-diff hunk header such as `@@ -12,5 +12,6 @@ fn foo()`.
///
/// Returns `(old_start, old_count, new_start, new_count)`, or `None` when
/// `line` is not a well-formed header. Any text after the closing ` @@` is
/// ignored.
fn parse_hunk_header(line: &str) -> Option<(u32, u32, u32, u32)> {
    let body = line.strip_prefix("@@ -")?;
    let (old_range, after_old) = body.split_once(' ')?;
    let new_range = after_old.strip_prefix('+')?.split_once(" @@")?.0;

    let (old_start, old_count) = parse_hunk_range(old_range)?;
    let (new_start, new_count) = parse_hunk_range(new_range)?;
    Some((old_start, old_count, new_start, new_count))
}

/// Parses one side of a hunk range, written `start` or `start,count`.
///
/// A bare `start` implies a count of 1.
fn parse_hunk_range(part: &str) -> Option<(u32, u32)> {
    match part.split_once(',') {
        Some((start, count)) => Some((start.parse().ok()?, count.parse().ok()?)),
        None => part.parse().ok().map(|start| (start, 1)),
    }
}
666
#[cfg(test)]
mod tests {
    use super::*;

    // Applies `patch` to `base` and unwraps, so each test reads as input -> output.
    fn apply(base: &str, patch: &str, excerpt_start_row: u32) -> String {
        apply_patch_to_excerpt(base, patch, excerpt_start_row).unwrap()
    }

    #[test]
    fn applies_patch_in_file_mode() {
        let base = "fn main() {\n    println!(\"hello\");\n}\n";
        let patch = concat!(
            "@@ -1,3 +1,3 @@\n",
            " fn main() {\n",
            "-    println!(\"hello\");\n",
            "+    println!(\"world\");\n",
            " }\n",
        );

        assert_eq!(
            apply(base, patch, 0),
            "fn main() {\n    println!(\"world\");\n}\n"
        );
    }

    #[test]
    fn applies_patch_in_json_excerpt_mode() {
        // The excerpt starts at file row 1, so file line 2 is its first line.
        let base = "b\nc\nd\n";
        let patch = concat!("@@ -2,2 +2,2 @@\n", "-b\n", "-c\n", "+x\n", "+y\n");

        assert_eq!(apply(base, patch, 1), "x\ny\nd\n");
    }

    #[test]
    fn applies_patch_with_excerpt_relative_line_numbers() {
        // The patch numbers lines relative to the excerpt (line 2 of the
        // excerpt) even though the excerpt starts at file row 100.
        let base = "a\nb\nc\nd\n";
        let patch = concat!("@@ -2,2 +2,2 @@\n", "-b\n", "-c\n", "+x\n", "+y\n");

        assert_eq!(apply(base, patch, 100), "a\nx\ny\nd\n");
    }

    #[test]
    fn prefers_file_global_line_numbers_over_excerpt_relative() {
        // File-global numbering: the excerpt starts at row 5, so hunk line 6
        // (1-based) is row 5 (0-based), the first excerpt line.
        let base = "a\nb\nc\n";
        let patch = concat!("@@ -6,2 +6,2 @@\n", "-a\n", "-b\n", "+x\n", "+y\n");

        assert_eq!(apply(base, patch, 5), "x\ny\nc\n");
    }
}