@@ -1,20 +1,20 @@
use collections::HashMap;
-type Counts = HashMap<String, usize>;
+pub type Counts = HashMap<String, usize>;
type CountsDelta = HashMap<String, isize>;
/// Context characters needed on each side of a change to capture all affected n-grams
const CONTEXT_CHARS: usize = CHR_F_CHAR_ORDER - 1;
#[derive(Default, Debug, Clone)]
-struct ClassificationMetrics {
- true_positives: usize,
- false_positives: usize,
- false_negatives: usize,
+pub struct ClassificationMetrics {
+ pub true_positives: usize,
+ pub false_positives: usize,
+ pub false_negatives: usize,
}
impl ClassificationMetrics {
- fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
+ pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
let mut true_positives = 0;
let mut false_positives = 0;
let mut false_negatives = 0;
@@ -42,7 +42,7 @@ impl ClassificationMetrics {
}
}
- fn precision(&self) -> f64 {
+ pub fn precision(&self) -> f64 {
if self.true_positives + self.false_positives == 0 {
0.0
} else {
@@ -50,13 +50,23 @@ impl ClassificationMetrics {
}
}
- fn recall(&self) -> f64 {
+ pub fn recall(&self) -> f64 {
if self.true_positives + self.false_negatives == 0 {
0.0
} else {
self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
}
}
+
+ pub fn f1(&self) -> f64 {
+ let precision = self.precision();
+ let recall = self.recall();
+ if precision + recall == 0.0 {
+ 0.0
+ } else {
+ 2.0 * precision * recall / (precision + recall)
+ }
+ }
}
enum ChrfWhitespace {
@@ -335,6 +345,43 @@ pub fn braces_disbalance(text: &str) -> usize {
disbalance as usize
}
+/// Extracts changed lines from a unified diff string.
+/// Returns a bag (multiset) of lines that were added (+) or removed (-).
+/// The +/- prefix is included in the line to distinguish additions from deletions.
+pub fn extract_changed_lines_from_diff(diff: &str) -> Counts {
+ let mut counts = Counts::default();
+
+ for line in diff.lines() {
+ // Skip file headers (--- and +++)
+ if line.starts_with("---") || line.starts_with("+++") {
+ continue;
+ }
+ // Skip hunk headers (@@)
+ if line.starts_with("@@") {
+ continue;
+ }
+ // Skip diff header lines (diff --git, index, etc.)
+ if line.starts_with("diff ") || line.starts_with("index ") {
+ continue;
+ }
+ // Include added and removed lines (with their prefix)
+ if line.starts_with('+') || line.starts_with('-') {
+ *counts.entry(line.to_string()).or_insert(0) += 1;
+ }
+ }
+
+ counts
+}
+
+/// Computes exact lines match metrics between expected and actual patches.
+/// Treats changed lines as a bag (multiset) - order is discarded but count matters.
+/// Returns ClassificationMetrics with TP/FP/FN counts.
+pub fn exact_lines_match(expected_patch: &str, actual_patch: &str) -> ClassificationMetrics {
+ let expected_lines = extract_changed_lines_from_diff(expected_patch);
+ let actual_lines = extract_changed_lines_from_diff(actual_patch);
+ ClassificationMetrics::from_counts(&expected_lines, &actual_lines)
+}
+
#[cfg(test)]
mod test_optimization {
use super::*;
@@ -559,4 +606,124 @@ mod test {
let text = "let x = { 1 + 2 )";
assert_eq!(braces_disbalance(text), 2);
}
+
+ #[test]
+ fn test_extract_changed_lines_from_diff() {
+ let diff = r#"--- a/file.rs
++++ b/file.rs
+@@ -1,3 +1,3 @@
+ fn main() {
+- println!("hello");
++ println!("world");
+ }"#;
+
+ let counts = extract_changed_lines_from_diff(diff);
+ assert_eq!(counts.get("- println!(\"hello\");"), Some(&1));
+ assert_eq!(counts.get("+ println!(\"world\");"), Some(&1));
+ assert_eq!(counts.len(), 2);
+ }
+
+ #[test]
+ fn test_extract_changed_lines_skips_headers() {
+ let diff = r#"diff --git a/file.rs b/file.rs
+index abc123..def456 100644
+--- a/file.rs
++++ b/file.rs
+@@ -1,2 +1,2 @@
+-old line
++new line"#;
+
+ let counts = extract_changed_lines_from_diff(diff);
+ assert_eq!(counts.get("-old line"), Some(&1));
+ assert_eq!(counts.get("+new line"), Some(&1));
+ assert_eq!(counts.len(), 2);
+ }
+
+ #[test]
+ fn test_exact_lines_match_perfect() {
+ let expected = r#"--- a/file.rs
++++ b/file.rs
+@@ -1,3 +1,3 @@
+-old line 1
+-old line 2
++new line 1
++new line 2"#;
+
+ let actual = r#"--- a/file.rs
++++ b/file.rs
+@@ -1,3 +1,3 @@
+-old line 1
+-old line 2
++new line 1
++new line 2"#;
+
+ let metrics = exact_lines_match(expected, actual);
+ assert_eq!(metrics.true_positives, 4);
+ assert_eq!(metrics.false_positives, 0);
+ assert_eq!(metrics.false_negatives, 0);
+ assert!((metrics.precision() - 1.0).abs() < 1e-6);
+ assert!((metrics.recall() - 1.0).abs() < 1e-6);
+ assert!((metrics.f1() - 1.0).abs() < 1e-6);
+ }
+
+ #[test]
+ fn test_exact_lines_match_partial() {
+ let expected = r#"-old line 1
+-old line 2
++new line 1
++new line 2"#;
+
+ let actual = r#"-old line 1
++new line 1
++extra line"#;
+
+ let metrics = exact_lines_match(expected, actual);
+ // TP: "-old line 1" and "+new line 1" (2)
+ // FP: "+extra line" (1)
+ // FN: "-old line 2" and "+new line 2" (2)
+ assert_eq!(metrics.true_positives, 2);
+ assert_eq!(metrics.false_positives, 1);
+ assert_eq!(metrics.false_negatives, 2);
+ }
+
+ #[test]
+ fn test_exact_lines_match_no_overlap() {
+ let expected = r#"-line a
++line b"#;
+
+ let actual = r#"-line x
++line y"#;
+
+ let metrics = exact_lines_match(expected, actual);
+ assert_eq!(metrics.true_positives, 0);
+ assert_eq!(metrics.false_positives, 2);
+ assert_eq!(metrics.false_negatives, 2);
+ assert!((metrics.precision()).abs() < 1e-6);
+ assert!((metrics.recall()).abs() < 1e-6);
+ }
+
+ #[test]
+ fn test_exact_lines_match_duplicate_lines() {
+ let expected = r#"+line a
++line a
++line a"#;
+
+ let actual = r#"+line a
++line a"#;
+
+ let metrics = exact_lines_match(expected, actual);
+ // Expected has 3 "+line a", actual has 2
+ // TP: 2, FN: 1, FP: 0
+ assert_eq!(metrics.true_positives, 2);
+ assert_eq!(metrics.false_positives, 0);
+ assert_eq!(metrics.false_negatives, 1);
+ }
+
+ #[test]
+ fn test_exact_lines_match_empty_patches() {
+ let metrics = exact_lines_match("", "");
+ assert_eq!(metrics.true_positives, 0);
+ assert_eq!(metrics.false_positives, 0);
+ assert_eq!(metrics.false_negatives, 0);
+ }
}
@@ -42,6 +42,9 @@ pub async fn run_scoring(
let zero_scores = ExampleScore {
delta_chr_f: 0.0,
braces_disbalance: 0,
+ exact_lines_tp: 0,
+ exact_lines_fp: 0,
+ exact_lines_fn: 0,
};
progress.set_substatus("computing metrics");
@@ -82,9 +85,21 @@ pub async fn run_scoring(
std::fs::write("/tmp/unbalanced-text.after", &actual_text).ok();
}
+ // Compute exact lines match against best matching expected patch
+ let best_exact_lines = example
+ .spec
+ .expected_patches
+ .iter()
+ .map(|expected_patch| metrics::exact_lines_match(expected_patch, &actual_patch))
+ .max_by_key(|m| m.true_positives)
+ .unwrap_or_default();
+
scores.push(ExampleScore {
delta_chr_f: best_delta_chr_f,
braces_disbalance,
+ exact_lines_tp: best_exact_lines.true_positives,
+ exact_lines_fp: best_exact_lines.false_positives,
+ exact_lines_fn: best_exact_lines.false_negatives,
});
}
@@ -93,53 +108,73 @@ pub async fn run_scoring(
}
pub fn print_report(examples: &[Example]) {
+ use crate::metrics::ClassificationMetrics;
+
+ const LINE_WIDTH: usize = 100;
+ let separator = "─".repeat(LINE_WIDTH);
+
+ eprintln!("{}", separator);
eprintln!(
- "──────────────────────────────────────────────────────────────────────────────────────"
- );
- eprintln!(
- "{:<50} {:>14} {:>10}",
- "Example name", "BracesDisbalance", "DeltaChrF"
- );
- eprintln!(
- "──────────────────────────────────────────────────────────────────────────────────────"
+ "{:<40} {:>8} {:>5} {:>4} {:>4} {:>4} {:>7} {:>7} {:>7}",
+ "Example", "DeltaChrF", "Brace", "TP", "FP", "FN", "Prec", "Rec", "F1"
);
+ eprintln!("{}", separator);
let mut all_delta_chr_f_scores = Vec::new();
let mut braces_disbalance_sum: usize = 0;
+ let mut total_exact_lines = ClassificationMetrics::default();
let mut total_scores: usize = 0;
for example in examples {
for score in example.score.iter() {
+ let exact_lines = ClassificationMetrics {
+ true_positives: score.exact_lines_tp,
+ false_positives: score.exact_lines_fp,
+ false_negatives: score.exact_lines_fn,
+ };
+
eprintln!(
- "{:<50} {:>14} {:>9.2}",
- truncate_name(&example.spec.name, 50),
+ "{:<40} {:>8.2} {:>5} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}%",
+ truncate_name(&example.spec.name, 40),
+ score.delta_chr_f,
score.braces_disbalance,
- score.delta_chr_f
+ score.exact_lines_tp,
+ score.exact_lines_fp,
+ score.exact_lines_fn,
+ exact_lines.precision() * 100.0,
+ exact_lines.recall() * 100.0,
+ exact_lines.f1() * 100.0
);
all_delta_chr_f_scores.push(score.delta_chr_f);
total_scores += 1;
braces_disbalance_sum += score.braces_disbalance;
+ total_exact_lines.true_positives += score.exact_lines_tp;
+ total_exact_lines.false_positives += score.exact_lines_fp;
+ total_exact_lines.false_negatives += score.exact_lines_fn;
}
}
- eprintln!(
- "──────────────────────────────────────────────────────────────────────────────────────"
- );
+ eprintln!("{}", separator);
if !all_delta_chr_f_scores.is_empty() {
let avg_delta_chr_f: f32 =
all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;
- let braces_disbalance_display = format!("{:.2}", braces_disbalance_avg);
eprintln!(
- "{:<50} {:>14} {:>9.2}",
- "AVERAGE", braces_disbalance_display, avg_delta_chr_f
- );
- eprintln!(
- "──────────────────────────────────────────────────────────────────────────────────────"
+ "{:<40} {:>8.2} {:>5.1} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}%",
+ "TOTAL / AVERAGE",
+ avg_delta_chr_f,
+ braces_disbalance_avg,
+ total_exact_lines.true_positives,
+ total_exact_lines.false_positives,
+ total_exact_lines.false_negatives,
+ total_exact_lines.precision() * 100.0,
+ total_exact_lines.recall() * 100.0,
+ total_exact_lines.f1() * 100.0
);
+ eprintln!("{}", separator);
}
eprintln!("\n");