From e25065b85a2bed3745547a4c5cf68ee36352b1b5 Mon Sep 17 00:00:00 2001 From: Oleksiy Syvokon Date: Tue, 24 Mar 2026 17:08:37 +0200 Subject: [PATCH] Add token-match-debug command It writes HTML with explanation of how metric values are derived --- crates/edit_prediction_cli/src/example.rs | 10 + crates/edit_prediction_cli/src/main.rs | 13 + crates/edit_prediction_cli/src/metrics.rs | 476 +++++++++++++++--- crates/edit_prediction_cli/src/score.rs | 67 ++- .../src/token_match_debug.rs | 404 +++++++++++++++ crates/edit_prediction_cli/src/word_diff.rs | 20 +- 6 files changed, 923 insertions(+), 67 deletions(-) create mode 100644 crates/edit_prediction_cli/src/token_match_debug.rs diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index 196f4f96d99b64aed2ff3ae2d7a9897295a60b29..08049ca5cc0694f3a5cabb10f8a733cabc46284a 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -158,6 +158,16 @@ pub struct ExampleScore { #[serde(default)] pub exact_lines_fn: usize, #[serde(default)] + pub token_match_tp: usize, + #[serde(default)] + pub token_match_fp: usize, + #[serde(default)] + pub token_match_fn: usize, + #[serde(default)] + pub token_match_precision: f64, + #[serde(default)] + pub token_match_recall: f64, + #[serde(default)] pub reversal_ratio: f32, #[serde(default, skip_serializing_if = "Option::is_none")] pub cursor_distance: Option, diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 1dcd1d4aa3ad34df853e9d7b193c246f151a61b2..a283f23c302c835f73e5ea4244ac188d77cf6a92 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -22,6 +22,7 @@ mod reversal_tracking; mod score; mod split_commit; mod split_dataset; +mod token_match_debug; mod synthesize; mod truncate_expected_patch; @@ -61,6 +62,7 @@ use crate::score::run_scoring; use crate::split_commit::SplitCommitArgs; use 
crate::split_dataset::SplitArgs; use crate::synthesize::{SynthesizeConfig, run_synthesize}; +use crate::token_match_debug::{TokenMatchDebugArgs, run_token_match_debug}; use crate::truncate_expected_patch::TruncatePatchArgs; #[derive(Parser, Debug)] @@ -214,6 +216,8 @@ enum Command { SplitCommit(SplitCommitArgs), /// Truncate expected patch by the given criteria TruncatePatch(TruncatePatchArgs), + /// Generate token-match debug HTML for expected vs predicted patches + TokenMatchDebug(TokenMatchDebugArgs), /// Split a JSONL dataset into multiple files (stratified by repository_url if present) Split(SplitArgs), /// Filter a JSONL dataset by programming language (based on cursor_path extension) @@ -257,6 +261,7 @@ impl Display for Command { Command::Clean => write!(f, "clean"), Command::SplitCommit(_) => write!(f, "split-commit"), Command::TruncatePatch(_) => write!(f, "truncate-patch"), + Command::TokenMatchDebug(_) => write!(f, "token-match-debug"), Command::Split(_) => write!(f, "split"), Command::FilterLanguages(_) => write!(f, "filter-languages"), Command::ImportBatch(args) => { @@ -1056,6 +1061,13 @@ fn main() { } return; } + Command::TokenMatchDebug(debug_args) => { + if let Err(error) = run_token_match_debug(debug_args, &args.inputs) { + eprintln!("{error:#}"); + std::process::exit(1); + } + return; + } Command::Split(split_args) => { if let Err(error) = split_dataset::run_split(split_args, &args.inputs) { eprintln!("{error:#}"); @@ -1249,6 +1261,7 @@ fn main() { | Command::SplitCommit(_) | Command::Split(_) | Command::TruncatePatch(_) + | Command::TokenMatchDebug(_) | Command::FilterLanguages(_) | Command::ImportBatch(_) | Command::PrintZetaFormats => { diff --git a/crates/edit_prediction_cli/src/metrics.rs b/crates/edit_prediction_cli/src/metrics.rs index 1bfd8e542fa3d74b55f091d2ac13aa22883f6a2f..9c3432a95ffafeb1297ee9059b9824014eba127c 100644 --- a/crates/edit_prediction_cli/src/metrics.rs +++ b/crates/edit_prediction_cli/src/metrics.rs @@ -12,13 +12,32 @@ 
type CountsDelta = HashMap; /// Context characters needed on each side of a change to capture all affected n-grams const CONTEXT_CHARS: usize = CHR_F_CHAR_ORDER - 1; -#[derive(Default, Debug, Clone)] +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] pub struct ClassificationMetrics { pub true_positives: usize, pub false_positives: usize, pub false_negatives: usize, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TokenClass { + TruePositive, + FalsePositive, + FalseNegative, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ClassifiedToken { + pub token: String, + pub class: TokenClass, +} + +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct TokenClassificationDetail { + pub expected_tokens: Vec, + pub actual_tokens: Vec, +} + impl ClassificationMetrics { pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics { let mut true_positives = 0; @@ -48,6 +67,12 @@ impl ClassificationMetrics { } } + pub fn accumulate(&mut self, other: &ClassificationMetrics) { + self.true_positives += other.true_positives; + self.false_positives += other.false_positives; + self.false_negatives += other.false_negatives; + } + pub fn precision(&self) -> f64 { if self.true_positives + self.false_positives == 0 { 0.0 @@ -75,6 +100,19 @@ impl ClassificationMetrics { } } +pub fn compare_classification_metrics( + left: &ClassificationMetrics, + right: &ClassificationMetrics, +) -> std::cmp::Ordering { + left.f1() + .total_cmp(&right.f1()) + .then_with(|| left.precision().total_cmp(&right.precision())) + .then_with(|| left.recall().total_cmp(&right.recall())) + .then_with(|| left.true_positives.cmp(&right.true_positives)) + .then_with(|| right.false_positives.cmp(&left.false_positives)) + .then_with(|| right.false_negatives.cmp(&left.false_negatives)) +} + enum ChrfWhitespace { /// Preserve whitespace as-is #[allow(unused)] @@ -525,63 +563,137 @@ pub struct TokenChangeCounts { pub deleted_tokens: usize, } -/// Counts the number of inserted 
and deleted tokens in a unified diff patch. -/// -/// Tokens are words and whitespace sequences (as defined by `word_diff::tokenize`). -/// Within each hunk, the old (`-`) and new (`+`) lines are compared at the token level -/// using an LCS-based diff, so modified lines only count the actually changed tokens -/// rather than the entire line. -pub fn count_patch_token_changes(patch: &str) -> TokenChangeCounts { - let mut counts = TokenChangeCounts::default(); - let mut old_lines: Vec<&str> = Vec::new(); - let mut new_lines: Vec<&str> = Vec::new(); +fn classify_token_diff_ops( + expected_tokens: &[&str], + actual_tokens: &[&str], +) -> ClassificationMetrics { + classify_token_diff_ops_detailed(expected_tokens, actual_tokens).0 +} - let flush = - |old_lines: &mut Vec<&str>, new_lines: &mut Vec<&str>, counts: &mut TokenChangeCounts| { - if old_lines.is_empty() && new_lines.is_empty() { - return; +fn classify_token_diff_ops_detailed( + expected_tokens: &[&str], + actual_tokens: &[&str], +) -> (ClassificationMetrics, TokenClassificationDetail) { + let mut metrics = ClassificationMetrics::default(); + let mut detail = TokenClassificationDetail::default(); + + for operation in diff_tokens(expected_tokens, actual_tokens) { + match operation { + DiffOp::Equal { + old_start, + old_end, + new_start, + new_end, + } => { + metrics.true_positives += old_end - old_start; + for token in &expected_tokens[old_start..old_end] { + detail.expected_tokens.push(ClassifiedToken { + token: (*token).to_string(), + class: TokenClass::TruePositive, + }); + } + for token in &actual_tokens[new_start..new_end] { + detail.actual_tokens.push(ClassifiedToken { + token: (*token).to_string(), + class: TokenClass::TruePositive, + }); + } } - - let old_text: String = old_lines - .iter() - .map(|line| if line.len() > 1 { &line[1..] } else { "" }) - .collect::>() - .join("\n"); - - let new_text: String = new_lines - .iter() - .map(|line| if line.len() > 1 { &line[1..] 
} else { "" }) - .collect::>() - .join("\n"); - - let old_tokens = tokenize(&old_text); - let new_tokens = tokenize(&new_text); - let ops = diff_tokens(&old_tokens, &new_tokens); - - for op in ops { - match op { - DiffOp::Equal(..) => {} - DiffOp::Delete(start, end) => { - counts.deleted_tokens += end - start; - } - DiffOp::Insert(start, end) => { - counts.inserted_tokens += end - start; - } - DiffOp::Replace { - old_start, - old_end, - new_start, - new_end, - } => { - counts.deleted_tokens += old_end - old_start; - counts.inserted_tokens += new_end - new_start; - } + DiffOp::Delete(start, end) => { + metrics.false_negatives += end - start; + for token in &expected_tokens[start..end] { + detail.expected_tokens.push(ClassifiedToken { + token: (*token).to_string(), + class: TokenClass::FalseNegative, + }); + } + } + DiffOp::Insert(start, end) => { + metrics.false_positives += end - start; + for token in &actual_tokens[start..end] { + detail.actual_tokens.push(ClassifiedToken { + token: (*token).to_string(), + class: TokenClass::FalsePositive, + }); + } + } + DiffOp::Replace { + old_start, + old_end, + new_start, + new_end, + } => { + metrics.false_negatives += old_end - old_start; + metrics.false_positives += new_end - new_start; + + for token in &expected_tokens[old_start..old_end] { + detail.expected_tokens.push(ClassifiedToken { + token: (*token).to_string(), + class: TokenClass::FalseNegative, + }); + } + for token in &actual_tokens[new_start..new_end] { + detail.actual_tokens.push(ClassifiedToken { + token: (*token).to_string(), + class: TokenClass::FalsePositive, + }); } } + } + } + + (metrics, detail) +} + +fn classify_token_texts(expected_text: &str, actual_text: &str) -> ClassificationMetrics { + let expected_tokens = tokenize(expected_text); + let actual_tokens = tokenize(actual_text); + classify_token_diff_ops(&expected_tokens, &actual_tokens) +} + +fn classify_token_texts_detailed( + expected_text: &str, + actual_text: &str, +) -> (ClassificationMetrics, 
TokenClassificationDetail) { + let expected_tokens = tokenize(expected_text); + let actual_tokens = tokenize(actual_text); + classify_token_diff_ops_detailed(&expected_tokens, &actual_tokens) +} + +fn strip_patch_line_prefix(line: &str) -> &str { + line.strip_prefix('-') + .or_else(|| line.strip_prefix('+')) + .unwrap_or(line) +} - old_lines.clear(); - new_lines.clear(); - }; +fn extract_patch_change_blocks(patch: &str) -> Vec<(String, String)> { + let mut blocks = Vec::new(); + let mut old_lines: Vec<&str> = Vec::new(); + let mut new_lines: Vec<&str> = Vec::new(); + + let flush = |old_lines: &mut Vec<&str>, + new_lines: &mut Vec<&str>, + blocks: &mut Vec<(String, String)>| { + if old_lines.is_empty() && new_lines.is_empty() { + return; + } + + let old_text = old_lines + .iter() + .map(|line| strip_patch_line_prefix(line)) + .collect::>() + .join("\n"); + + let new_text = new_lines + .iter() + .map(|line| strip_patch_line_prefix(line)) + .collect::>() + .join("\n"); + + blocks.push((old_text, new_text)); + old_lines.clear(); + new_lines.clear(); + }; for line in patch.lines() { if line.starts_with("---") @@ -590,17 +702,102 @@ pub fn count_patch_token_changes(patch: &str) -> TokenChangeCounts { || line.starts_with("diff ") || line.starts_with("index ") { - flush(&mut old_lines, &mut new_lines, &mut counts); + flush(&mut old_lines, &mut new_lines, &mut blocks); } else if line.starts_with('-') { old_lines.push(line); } else if line.starts_with('+') { new_lines.push(line); } else { - flush(&mut old_lines, &mut new_lines, &mut counts); + flush(&mut old_lines, &mut new_lines, &mut blocks); } } - flush(&mut old_lines, &mut new_lines, &mut counts); + flush(&mut old_lines, &mut new_lines, &mut blocks); + blocks +} + +fn collect_patch_side_text(patch: &str, mut select_side: F) -> String +where + F: FnMut(&(String, String)) -> &str, +{ + let mut text = String::new(); + + for block in extract_patch_change_blocks(patch) { + let block_text = select_side(&block); + if 
block_text.is_empty() { + continue; + } + + if !text.is_empty() { + text.push('\n'); + } + text.push_str(block_text); + } + + text +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TokenMatchDebugReport { + pub expected_deleted_text: String, + pub actual_deleted_text: String, + pub expected_inserted_text: String, + pub actual_inserted_text: String, + pub deleted: TokenClassificationDetail, + pub inserted: TokenClassificationDetail, + pub metrics: ClassificationMetrics, +} + +/// Computes token-match precision/recall counts between expected and actual patches. +/// +/// Deletions and insertions are aligned independently, then their counts are summed. +/// Tokenization uses `word_diff::tokenize`, so identifiers, whitespace runs, and punctuation +/// are compared using the same token boundaries as the word-diff view. +pub fn token_match(expected_patch: &str, actual_patch: &str) -> ClassificationMetrics { + token_match_debug_report(expected_patch, actual_patch).metrics +} + +pub fn token_match_debug_report(expected_patch: &str, actual_patch: &str) -> TokenMatchDebugReport { + let expected_deleted = + collect_patch_side_text(expected_patch, |(old_text, _)| old_text.as_str()); + let actual_deleted = collect_patch_side_text(actual_patch, |(old_text, _)| old_text.as_str()); + let expected_inserted = + collect_patch_side_text(expected_patch, |(_, new_text)| new_text.as_str()); + let actual_inserted = collect_patch_side_text(actual_patch, |(_, new_text)| new_text.as_str()); + + let (mut metrics, deleted_detail) = + classify_token_texts_detailed(&expected_deleted, &actual_deleted); + let (inserted_metrics, inserted_detail) = + classify_token_texts_detailed(&expected_inserted, &actual_inserted); + metrics.accumulate(&inserted_metrics); + + TokenMatchDebugReport { + expected_deleted_text: expected_deleted, + actual_deleted_text: actual_deleted, + expected_inserted_text: expected_inserted, + actual_inserted_text: actual_inserted, + deleted: deleted_detail, + inserted: 
inserted_detail, + metrics, + } +} + +/// Counts the number of inserted and deleted tokens in a unified diff patch. +/// +/// Tokens are words and whitespace sequences (as defined by `word_diff::tokenize`). +/// Within each hunk, the old (`-`) and new (`+`) lines are compared at the token level +/// using an LCS-based diff, so modified lines only count the actually changed tokens +/// rather than the entire line. + +pub fn count_patch_token_changes(patch: &str) -> TokenChangeCounts { + let mut counts = TokenChangeCounts::default(); + + for (old_text, new_text) in extract_patch_change_blocks(patch) { + let metrics = classify_token_texts(&old_text, &new_text); + counts.deleted_tokens += metrics.false_negatives; + counts.inserted_tokens += metrics.false_positives; + } + counts } @@ -961,6 +1158,173 @@ index abc123..def456 100644 assert_eq!(metrics.false_negatives, 0); } + #[test] + fn test_token_match_perfect() { + let expected = indoc! {" + @@ -1,2 +1,4 @@ + -str + +struct LanguageEntry { + + path: PathBuf, + +} + "}; + + let actual = indoc! {" + @@ -1,2 +1,4 @@ + -str + +struct LanguageEntry { + + path: PathBuf, + +} + "}; + + let metrics = token_match(expected, actual); + assert_eq!(metrics.false_positives, 0); + assert_eq!(metrics.false_negatives, 0); + assert!(metrics.true_positives > 0); + assert!((metrics.precision() - 1.0).abs() < 1e-6); + assert!((metrics.recall() - 1.0).abs() < 1e-6); + assert!((metrics.f1() - 1.0).abs() < 1e-6); + } + + #[test] + fn test_token_match_partial_subset_keeps_high_precision() { + let expected = indoc! {" + @@ -1,2 +1,6 @@ + -str + +struct LanguageEntry { + + path: PathBuf, + + language: OnceCell, + + external_files: Option>, + +} + "}; + + let actual = indoc! 
{" + @@ -1,2 +1,3 @@ + -str + +struct LanguageEntry { + +} + "}; + + let metrics = token_match(expected, actual); + assert!(metrics.true_positives > 0); + assert_eq!(metrics.false_positives, 0); + assert!(metrics.false_negatives > 0); + assert!((metrics.precision() - 1.0).abs() < 1e-6); + assert!(metrics.recall() < 1.0); + } + + #[test] + fn test_token_match_counts_wrong_tokens_as_fp_and_fn() { + let expected = indoc! {" + @@ -1,1 +1,1 @@ + -old_name + +new_name + "}; + + let actual = indoc! {" + @@ -1,1 +1,1 @@ + -different_old + +different_new + "}; + + let metrics = token_match(expected, actual); + assert_eq!(metrics.true_positives, 0); + assert!(metrics.false_positives > 0); + assert!(metrics.false_negatives > 0); + } + + #[test] + fn test_token_match_debug_report_metrics_match_token_match() { + let expected = indoc! {" + @@ -1,2 +1,3 @@ + -str + +struct LanguageEntry { + +} + "}; + + let actual = indoc! {" + @@ -1,2 +1,4 @@ + -str + +struct LanguageEntry { + + path: PathBuf, + +} + "}; + + let metrics = token_match(expected, actual); + let report = token_match_debug_report(expected, actual); + + assert_eq!(report.metrics, metrics); + + let expected_tp = report + .deleted + .expected_tokens + .iter() + .chain(report.inserted.expected_tokens.iter()) + .filter(|token| token.class == TokenClass::TruePositive) + .count(); + let expected_fn = report + .deleted + .expected_tokens + .iter() + .chain(report.inserted.expected_tokens.iter()) + .filter(|token| token.class == TokenClass::FalseNegative) + .count(); + let actual_tp = report + .deleted + .actual_tokens + .iter() + .chain(report.inserted.actual_tokens.iter()) + .filter(|token| token.class == TokenClass::TruePositive) + .count(); + let actual_fp = report + .deleted + .actual_tokens + .iter() + .chain(report.inserted.actual_tokens.iter()) + .filter(|token| token.class == TokenClass::FalsePositive) + .count(); + + assert_eq!(expected_tp, report.metrics.true_positives); + assert_eq!(actual_tp, 
report.metrics.true_positives); + assert_eq!(expected_fn, report.metrics.false_negatives); + assert_eq!(actual_fp, report.metrics.false_positives); + } + + #[test] + fn test_token_match_debug_report_marks_inserted_extra_tokens_as_fp() { + let expected = indoc! {" + @@ -1,1 +1,1 @@ + -a + +value + "}; + + let actual = indoc! {" + @@ -1,1 +1,1 @@ + -a + +value_extra + "}; + + let report = token_match_debug_report(expected, actual); + + assert_eq!(report.metrics.false_positives, 1); + assert_eq!(report.metrics.false_negatives, 1); + + assert!( + report + .inserted + .actual_tokens + .iter() + .any(|token| token.token == "value_extra" + && token.class == TokenClass::FalsePositive) + ); + assert!( + report + .inserted + .expected_tokens + .iter() + .any(|token| token.token == "value" && token.class == TokenClass::FalseNegative) + ); + } + #[test] fn test_is_editable_region_correct() { let patch = indoc! {" diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs index d75cf55e85b198bc28469e83d8f9209a8a59a83f..978093482b87ce31a620c025c9dec0348e3f94ef 100644 --- a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -71,6 +71,11 @@ pub async fn run_scoring( exact_lines_tp: 0, exact_lines_fp: 0, exact_lines_fn: 0, + token_match_tp: 0, + token_match_fp: 0, + token_match_fn: 0, + token_match_precision: 0.0, + token_match_recall: 0.0, reversal_ratio: 0.0, cursor_distance: None, cursor_exact_match: None, @@ -100,10 +105,30 @@ pub async fn run_scoring( let token_changes = metrics::count_patch_token_changes(&actual_patch); + let best_exact_lines = expected_patches_with_cursors + .iter() + .map(|(expected_patch, _)| metrics::exact_lines_match(expected_patch, &actual_patch)) + .max_by_key(|m| m.true_positives) + .unwrap_or_default(); + + let best_token_match = expected_patches_with_cursors + .iter() + .map(|(expected_patch, _)| metrics::token_match(expected_patch, &actual_patch)) + 
.max_by(metrics::compare_classification_metrics) + .unwrap_or_default(); + let actual_text = match apply_diff_to_string(&actual_patch, original_text) { Ok(text) => text, Err(_) => { let mut s = zero_scores.clone(); + s.exact_lines_tp = best_exact_lines.true_positives; + s.exact_lines_fp = best_exact_lines.false_positives; + s.exact_lines_fn = best_exact_lines.false_negatives; + s.token_match_tp = best_token_match.true_positives; + s.token_match_fp = best_token_match.false_positives; + s.token_match_fn = best_token_match.false_negatives; + s.token_match_precision = best_token_match.precision(); + s.token_match_recall = best_token_match.recall(); s.inserted_tokens = token_changes.inserted_tokens; s.deleted_tokens = token_changes.deleted_tokens; scores.push(s); @@ -151,13 +176,6 @@ pub async fn run_scoring( let disbalance_after = metrics::braces_disbalance(&actual_text); let braces_disbalance = disbalance_after.saturating_sub(disbalance_before); - // Compute exact lines match against best matching expected patch - let best_exact_lines = expected_patches_with_cursors - .iter() - .map(|(expected_patch, _)| metrics::exact_lines_match(expected_patch, &actual_patch)) - .max_by_key(|m| m.true_positives) - .unwrap_or_default(); - // Compute reversal ratio let reversal_ratio = reversal_tracking::compute_prediction_reversal_ratio( prompt_inputs, @@ -184,6 +202,11 @@ pub async fn run_scoring( exact_lines_tp: best_exact_lines.true_positives, exact_lines_fp: best_exact_lines.false_positives, exact_lines_fn: best_exact_lines.false_negatives, + token_match_tp: best_token_match.true_positives, + token_match_fp: best_token_match.false_positives, + token_match_fn: best_token_match.false_negatives, + token_match_precision: best_token_match.precision(), + token_match_recall: best_token_match.recall(), reversal_ratio, cursor_distance, cursor_exact_match, @@ -253,6 +276,7 @@ pub fn print_report(examples: &[Example], verbose: bool) { let mut isolated_whitespace_count: usize = 0; let mut 
patch_inserted_tokens: Vec = Vec::new(); let mut patch_deleted_tokens: Vec = Vec::new(); + let mut total_token_match = ClassificationMetrics::default(); let mut predictions_with_patch: usize = 0; let mut printed_lines: usize = 0; @@ -317,6 +341,9 @@ pub fn print_report(examples: &[Example], verbose: bool) { total_exact_lines.true_positives += score.exact_lines_tp; total_exact_lines.false_positives += score.exact_lines_fp; total_exact_lines.false_negatives += score.exact_lines_fn; + total_token_match.true_positives += score.token_match_tp; + total_token_match.false_positives += score.token_match_fp; + total_token_match.false_negatives += score.token_match_fn; // Accumulate QA metrics if let Some(qa) = qa_result { @@ -465,6 +492,16 @@ pub fn print_report(examples: &[Example], verbose: bool) { println!("Isolated whitespace changes: {}", isolated_ws_str); } + println!( + "Token match: P={:.1}% R={:.1}% F1={:.1}% (TP={}, FP={}, FN={})", + total_token_match.precision() * 100.0, + total_token_match.recall() * 100.0, + total_token_match.f1() * 100.0, + total_token_match.true_positives, + total_token_match.false_positives, + total_token_match.false_negatives, + ); + // Print token change percentile summary (only for predictions with a patch) if !patch_inserted_tokens.is_empty() { patch_inserted_tokens.sort_unstable(); @@ -547,6 +584,12 @@ pub struct SummaryJson { pub exact_lines_precision: f64, pub exact_lines_recall: f64, pub exact_lines_f1: f64, + pub token_match_tp: usize, + pub token_match_fp: usize, + pub token_match_fn: usize, + pub token_match_precision: f64, + pub token_match_recall: f64, + pub token_match_f1: f64, pub avg_reversal_ratio: f32, #[serde(skip_serializing_if = "Option::is_none")] pub qa_avg_reverts_edits: Option, @@ -570,6 +613,7 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { let mut all_reversal_ratios = Vec::new(); let mut braces_disbalance_sum: usize = 0; let mut total_exact_lines = ClassificationMetrics::default(); + let mut 
total_token_match = ClassificationMetrics::default(); let mut total_scores: usize = 0; let mut qa_reverts_count: usize = 0; let mut qa_reverts_total: usize = 0; @@ -592,6 +636,9 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { total_exact_lines.true_positives += score.exact_lines_tp; total_exact_lines.false_positives += score.exact_lines_fp; total_exact_lines.false_negatives += score.exact_lines_fn; + total_token_match.true_positives += score.token_match_tp; + total_token_match.false_positives += score.token_match_fp; + total_token_match.false_negatives += score.token_match_fn; // Accumulate QA metrics if let Some(Some(qa)) = example.qa.get(score_idx) { @@ -704,6 +751,12 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { exact_lines_precision: total_exact_lines.precision(), exact_lines_recall: total_exact_lines.recall(), exact_lines_f1: total_exact_lines.f1(), + token_match_tp: total_token_match.true_positives, + token_match_fp: total_token_match.false_positives, + token_match_fn: total_token_match.false_negatives, + token_match_precision: total_token_match.precision(), + token_match_recall: total_token_match.recall(), + token_match_f1: total_token_match.f1(), avg_reversal_ratio, qa_avg_reverts_edits, qa_avg_confidence, diff --git a/crates/edit_prediction_cli/src/token_match_debug.rs b/crates/edit_prediction_cli/src/token_match_debug.rs new file mode 100644 index 0000000000000000000000000000000000000000..a6d1429685417738b566f0712a7685de1d596523 --- /dev/null +++ b/crates/edit_prediction_cli/src/token_match_debug.rs @@ -0,0 +1,404 @@ +use crate::{example::read_example_files, metrics}; +use anyhow::Context as _; +use clap::Args; +use std::fmt::Write as _; +use std::path::PathBuf; + +#[derive(Args, Debug, Clone)] +#[command( + about = "Generate token-match debug HTML for expected vs predicted patches", + after_help = r#"EXAMPLES: + # Debug all examples from a jsonl dataset + ep token-match-debug examples.jsonl + + # Write HTML files to a 
specific directory + ep token-match-debug examples.jsonl --output-dir out/token-debug + + # Keep only the best expected patch per prediction + ep token-match-debug examples.jsonl --best-only + + # Limit generated files + ep token-match-debug examples.jsonl --limit 50 +"# +)] +pub struct TokenMatchDebugArgs { + /// Directory where HTML reports are written. + #[arg(long, default_value = "token-match-debug")] + pub output_dir: PathBuf, + + /// Only emit one report per prediction (best matching expected patch). + #[arg(long, default_value_t = false)] + pub best_only: bool, + + /// Maximum number of reports to write. + #[arg(long)] + pub limit: Option, +} + +pub fn run_token_match_debug(args: &TokenMatchDebugArgs, inputs: &[PathBuf]) -> anyhow::Result<()> { + let stdin_path = PathBuf::from("-"); + let inputs = if inputs.is_empty() { + std::slice::from_ref(&stdin_path) + } else { + inputs + }; + + let examples = read_example_files(inputs); + std::fs::create_dir_all(&args.output_dir).with_context(|| { + format!( + "failed to create output directory '{}'", + args.output_dir.display() + ) + })?; + + let mut written = 0usize; + for example in &examples { + let expected_patches = example.spec.expected_patches_with_cursor_positions(); + if expected_patches.is_empty() || example.predictions.is_empty() { + continue; + } + + for (prediction_index, prediction) in example.predictions.iter().enumerate() { + let Some(actual_patch) = prediction.actual_patch.as_deref() else { + continue; + }; + if actual_patch.trim().is_empty() { + continue; + } + + if args.best_only { + if let Some((expected_index, report)) = + best_expected_patch_report(&expected_patches, actual_patch) + { + let html = render_report_html( + &example.spec.name, + prediction_index, + expected_index, + &expected_patches[expected_index].0, + actual_patch, + &report, + ); + + let path = args.output_dir.join(report_filename( + &example.spec.filename(), + prediction_index, + expected_index, + )); + std::fs::write(&path, 
html) + .with_context(|| format!("failed to write report '{}'", path.display()))?; + written += 1; + if args.limit.is_some_and(|limit| written >= limit) { + eprintln!( + "Wrote {} report(s) to {}", + written, + args.output_dir.display() + ); + return Ok(()); + } + } + continue; + } + + for (expected_index, (expected_patch, _)) in expected_patches.iter().enumerate() { + let report = metrics::token_match_debug_report(expected_patch, actual_patch); + let html = render_report_html( + &example.spec.name, + prediction_index, + expected_index, + expected_patch, + actual_patch, + &report, + ); + let path = args.output_dir.join(report_filename( + &example.spec.filename(), + prediction_index, + expected_index, + )); + + std::fs::write(&path, html) + .with_context(|| format!("failed to write report '{}'", path.display()))?; + written += 1; + + if args.limit.is_some_and(|limit| written >= limit) { + eprintln!( + "Wrote {} report(s) to {}", + written, + args.output_dir.display() + ); + return Ok(()); + } + } + } + } + + eprintln!( + "Wrote {} report(s) to {}", + written, + args.output_dir.display() + ); + Ok(()) +} + +fn best_expected_patch_report( + expected_patches: &[(String, Option)], + actual_patch: &str, +) -> Option<(usize, metrics::TokenMatchDebugReport)> { + let mut best: Option<(usize, metrics::TokenMatchDebugReport)> = None; + for (index, (expected_patch, _)) in expected_patches.iter().enumerate() { + let report = metrics::token_match_debug_report(expected_patch, actual_patch); + match &best { + Some((_, current)) => { + if metrics::compare_classification_metrics(&report.metrics, ¤t.metrics) + .is_gt() + { + best = Some((index, report)); + } + } + None => best = Some((index, report)), + } + } + best +} + +fn report_filename(example_name: &str, prediction_index: usize, expected_index: usize) -> String { + format!( + "{}__prediction-{}__expected-{}.html", + example_name, prediction_index, expected_index + ) +} + +fn render_report_html( + example_name: &str, + 
prediction_index: usize, + expected_index: usize, + expected_patch: &str, + actual_patch: &str, + report: &metrics::TokenMatchDebugReport, +) -> String { + let mut html = String::new(); + + let precision = report.metrics.precision() * 100.0; + let recall = report.metrics.recall() * 100.0; + let f1 = report.metrics.f1() * 100.0; + + let _ = write!( + html, + r#" + + + + +Token Match Debug + + + +
+

Token Match Debug

+

Example: {example_name} · Prediction #{prediction_index} · Expected Patch #{expected_index}

+ +
+
+
Precision
{precision:.1}%
+
Recall
{recall:.1}%
+
F1
{f1:.1}%
+
TP
{tp}
+
FP
{fp}
+
FN
{fn}
+
+
+ True Positive + False Positive + False Negative +
+
+ +
+
+
Expected patch
+
{expected_patch}
+
+
+
Actual patch
+
{actual_patch}
+
+
+ +
+
+

Deleted-side token alignment

+
Expected deleted text
+
{expected_deleted_text}
+
Actual deleted text
+
{actual_deleted_text}
+
Expected deleted tokens (FN highlighted)
+
{deleted_expected_tokens}
+
Actual deleted tokens (FP highlighted)
+
{deleted_actual_tokens}
+
+ +
+

Inserted-side token alignment

+
Expected inserted text
+
{expected_inserted_text}
+
Actual inserted text
+
{actual_inserted_text}
+
Expected inserted tokens (FN highlighted)
+
{inserted_expected_tokens}
+
Actual inserted tokens (FP highlighted)
+
{inserted_actual_tokens}
+
+
+
+ +"#, + example_name = escape_html(example_name), + prediction_index = prediction_index, + expected_index = expected_index, + precision = precision, + recall = recall, + f1 = f1, + tp = report.metrics.true_positives, + fp = report.metrics.false_positives, + fn = report.metrics.false_negatives, + expected_patch = escape_html(expected_patch), + actual_patch = escape_html(actual_patch), + expected_deleted_text = escape_html(&report.expected_deleted_text), + actual_deleted_text = escape_html(&report.actual_deleted_text), + expected_inserted_text = escape_html(&report.expected_inserted_text), + actual_inserted_text = escape_html(&report.actual_inserted_text), + deleted_expected_tokens = render_classified_tokens(&report.deleted.expected_tokens), + deleted_actual_tokens = render_classified_tokens(&report.deleted.actual_tokens), + inserted_expected_tokens = render_classified_tokens(&report.inserted.expected_tokens), + inserted_actual_tokens = render_classified_tokens(&report.inserted.actual_tokens), + ); + + html +} + +fn render_classified_tokens(tokens: &[metrics::ClassifiedToken]) -> String { + let mut result = String::new(); + for token in tokens { + let class = match token.class { + metrics::TokenClass::TruePositive => "tp", + metrics::TokenClass::FalsePositive => "fp", + metrics::TokenClass::FalseNegative => "fn", + }; + let escaped = escape_html(&token.token); + let _ = write!(result, r#"{escaped}"#); + } + result +} + +fn escape_html(input: &str) -> String { + let mut result = String::with_capacity(input.len()); + for character in input.chars() { + match character { + '&' => result.push_str("&"), + '<' => result.push_str("<"), + '>' => result.push_str(">"), + '"' => result.push_str("""), + '\'' => result.push_str("'"), + _ => result.push(character), + } + } + result +} diff --git a/crates/edit_prediction_cli/src/word_diff.rs b/crates/edit_prediction_cli/src/word_diff.rs index 72026d5715d312ded5702e6571a8cd52f02185c1..cd0458f2b133faaa4d84eb66a8f0f98de5794c30 100644 
--- a/crates/edit_prediction_cli/src/word_diff.rs +++ b/crates/edit_prediction_cli/src/word_diff.rs @@ -85,8 +85,10 @@ fn compute_word_diff(old_text: &str, new_text: &str) -> String { for op in ops { match op { - DiffOp::Equal(start, end) => { - for token in &old_words[start..end] { + DiffOp::Equal { + old_start, old_end, .. + } => { + for token in &old_words[old_start..old_end] { result.push_str(token); } } @@ -178,7 +180,12 @@ pub(crate) fn tokenize(text: &str) -> Vec<&str> { #[derive(Debug)] pub(crate) enum DiffOp { - Equal(usize, usize), + Equal { + old_start: usize, + old_end: usize, + new_start: usize, + new_end: usize, + }, Delete(usize, usize), Insert(usize, usize), Replace { @@ -199,7 +206,12 @@ pub(crate) fn diff_tokens<'a>(old: &[&'a str], new: &[&'a str]) -> Vec { let old_range = op.old_range(); let new_range = op.new_range(); match tag { - DiffTag::Equal => DiffOp::Equal(old_range.start, old_range.end), + DiffTag::Equal => DiffOp::Equal { + old_start: old_range.start, + old_end: old_range.end, + new_start: new_range.start, + new_end: new_range.end, + }, DiffTag::Delete => DiffOp::Delete(old_range.start, old_range.end), DiffTag::Insert => DiffOp::Insert(new_range.start, new_range.end), DiffTag::Replace => DiffOp::Replace {