diff --git a/crates/edit_prediction/src/capture_example.rs b/crates/edit_prediction/src/capture_example.rs index 54bf027f2afc0ed1282c73ecd73884c518b3d66b..2565f9844c779f31ed7d848cba3b9f431a75ee56 100644 --- a/crates/edit_prediction/src/capture_example.rs +++ b/crates/edit_prediction/src/capture_example.rs @@ -74,7 +74,7 @@ pub fn capture_example( cursor_path: cursor_path.as_std_path().into(), cursor_position: String::new(), edit_history, - expected_patch: String::new(), + expected_patches: Vec::new(), }; spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix); Ok(spec) @@ -350,7 +350,7 @@ mod tests { seven(); "} .to_string(), - expected_patch: "".to_string(), + expected_patches: Vec::new() } ); } diff --git a/crates/edit_prediction/src/example_spec.rs b/crates/edit_prediction/src/example_spec.rs index 1caa918d406a02b2b0a2f33e86dd2825d16da3e4..f6eac8d44039ec931e9b94fba31ca7a5fbdaa217 100644 --- a/crates/edit_prediction/src/example_spec.rs +++ b/crates/edit_prediction/src/example_spec.rs @@ -15,7 +15,7 @@ pub struct ExampleSpec { pub cursor_path: Arc, pub cursor_position: String, pub edit_history: String, - pub expected_patch: String, + pub expected_patches: Vec, } const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff"; @@ -95,13 +95,15 @@ impl ExampleSpec { _ = writeln!(markdown, "## {}", EXPECTED_PATCH_HEADING); markdown.push('\n'); - _ = writeln!(markdown, "```diff"); - markdown.push_str(&self.expected_patch); - if !markdown.ends_with('\n') { + for patch in &self.expected_patches { + _ = writeln!(markdown, "```diff"); + markdown.push_str(patch); + if !markdown.ends_with('\n') { + markdown.push('\n'); + } + _ = writeln!(markdown, "```"); markdown.push('\n'); } - _ = writeln!(markdown, "```"); - markdown.push('\n'); markdown } @@ -118,7 +120,7 @@ impl ExampleSpec { cursor_path: Path::new("").into(), cursor_position: String::new(), edit_history: String::new(), - expected_patch: String::new(), + expected_patches: Vec::new(), }; if let Some(rest) = input.strip_prefix("+++\n") @@ -212,7 +214,7 @@ impl ExampleSpec { mem::take(&mut text); } Section::ExpectedPatch => { - spec.expected_patch = mem::take(&mut text); + spec.expected_patches.push(mem::take(&mut text)); } Section::Start | Section::Other => {} } @@ -353,7 +355,7 @@ mod tests { cursor_path: Path::new("test.rs").into(), cursor_position: String::new(), edit_history: String::new(), - expected_patch: String::new(), + expected_patches: Vec::new(), }; // Cursor before `42` diff --git a/crates/edit_prediction_cli/src/distill.rs b/crates/edit_prediction_cli/src/distill.rs index abfe178ae61b6da522f43c93d40b6000800d0e4d..d6343871e8054fc54062f3d3f7f5210374b36812 100644 --- a/crates/edit_prediction_cli/src/distill.rs +++ b/crates/edit_prediction_cli/src/distill.rs @@ -1,20 +1,15 @@ -use anyhow::{Result, anyhow}; +use anyhow::Result; use std::mem; use crate::example::Example; pub async fn run_distill(example: &mut Example) -> Result<()> { - let [prediction]: [_; 1] = - mem::take(&mut example.predictions) - .try_into() - .map_err(|preds: Vec<_>| { - anyhow!( - "Example has {} predictions, but it should have exactly one", - preds.len() - ) - })?; + let predictions = mem::take(&mut example.predictions) + .into_iter() + .map(|p| p.actual_patch) + .collect(); - example.spec.expected_patch = prediction.actual_patch; + example.spec.expected_patches = predictions; example.prompt = None; example.predictions = Vec::new(); example.score = Vec::new(); diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index ef3bbd95508da0fef7e5a56c431805d93ab1a8cc..63a53b0d7dc667b05171d486e078617187f24fe6 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -1,4 +1,4 @@ -use crate::{PredictionProvider, PromptFormat, metrics::ClassificationMetrics}; +use crate::{PredictionProvider, PromptFormat}; use anyhow::{Context as _, Result}; use collections::HashMap; use edit_prediction::example_spec::ExampleSpec; @@ -87,7 +87,6 @@ pub struct ExamplePrediction { #[derive(Clone, Debug, Serialize, Deserialize)] pub struct ExampleScore { pub delta_chr_f: f32, - pub line_match: ClassificationMetrics, } impl Example { diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index c21cec01856f59cf2a95a526e7a0deb0d896e6d7..7fbf86e51e7a28d4190195ac8424c538d1673c48 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -30,7 +30,13 @@ pub async fn run_format_prompt( let prompt = TeacherPrompt::format_prompt(example); example.prompt = Some(ExamplePrompt { input: prompt, - expected_output: example.spec.expected_patch.clone(), // TODO + // TODO + expected_output: example + .spec + .expected_patches + .first() + .context("no expected patches")? + .clone(), format: prompt_format, }); } @@ -68,8 +74,15 @@ pub async fn run_format_prompt( )) })??; let prompt = format_zeta_prompt(&input); - let expected_output = - zeta2_output_for_patch(&input, &example.spec.expected_patch.clone())?; + let expected_output = zeta2_output_for_patch( + &input, + &example + .spec + .expected_patches + .first() + .context("expected patches is empty")? + .clone(), + )?; example.prompt = Some(ExamplePrompt { input: prompt, expected_output, diff --git a/crates/edit_prediction_cli/src/metrics.rs b/crates/edit_prediction_cli/src/metrics.rs index b3e5eb8688724c821953a56c4fe82e67c75e13b6..0d6298f6fd4216c06e3c624c1ad4e6ab1a25d375 100644 --- a/crates/edit_prediction_cli/src/metrics.rs +++ b/crates/edit_prediction_cli/src/metrics.rs @@ -1,34 +1,17 @@ -use collections::{HashMap, HashSet}; -use edit_prediction::udiff::DiffLine; -use serde::{Deserialize, Serialize}; +use collections::HashMap; type Counts = HashMap; type CountsDelta = HashMap; -#[derive(Default, Debug, Clone, Serialize, Deserialize)] -pub struct ClassificationMetrics { - pub true_positives: usize, - pub false_positives: usize, - pub false_negatives: usize, +#[derive(Default, Debug, Clone)] +struct ClassificationMetrics { + true_positives: usize, + false_positives: usize, + false_negatives: usize, } impl ClassificationMetrics { - pub fn from_sets( - expected: &HashSet, - actual: &HashSet, - ) -> ClassificationMetrics { - let true_positives = expected.intersection(actual).count(); - let false_positives = actual.difference(expected).count(); - let false_negatives = expected.difference(actual).count(); - - ClassificationMetrics { - true_positives, - false_positives, - false_negatives, - } - } - - pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics { + fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics { let mut true_positives = 0; let mut false_positives = 0; let mut false_negatives = 0; @@ -56,27 +39,7 @@ impl ClassificationMetrics { } } - pub fn aggregate<'a>( - scores: impl Iterator, - ) -> ClassificationMetrics { - let mut true_positives = 0; - let mut false_positives = 0; - let mut false_negatives = 0; - - for score in scores { - true_positives += score.true_positives; - false_positives += score.false_positives; - false_negatives += score.false_negatives; - } - - ClassificationMetrics { - true_positives, - false_positives, - false_negatives, - } - } - - pub fn precision(&self) -> f64 { + fn precision(&self) -> f64 { if self.true_positives + self.false_positives == 0 { 0.0 } else { @@ -84,42 +47,13 @@ impl ClassificationMetrics { } } - pub fn recall(&self) -> f64 { + fn recall(&self) -> f64 { if self.true_positives + self.false_negatives == 0 { 0.0 } else { self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64 } } - - pub fn f1_score(&self) -> f64 { - let recall = self.recall(); - let precision = self.precision(); - if precision + recall == 0.0 { - 0.0 - } else { - 2.0 * precision * recall / (precision + recall) - } - } -} - -pub fn line_match_score( - expected_patch: &[DiffLine], - actual_patch: &[DiffLine], -) -> ClassificationMetrics { - let expected_change_lines = expected_patch - .iter() - .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_))) - .map(|line| line.to_string()) - .collect(); - - let actual_change_lines = actual_patch - .iter() - .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_))) - .map(|line| line.to_string()) - .collect(); - - ClassificationMetrics::from_sets(&expected_change_lines, &actual_change_lines) } enum ChrfWhitespace { @@ -135,55 +69,26 @@ const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Ignore; /// Computes a delta-chrF score that compares two sets of edits. /// /// This metric works by: -/// 1. Reconstructing original, golden (expected result), and actual texts from diffs -/// 2. Computing n-gram count differences (deltas) between original→golden and original→actual -/// 3. Comparing these deltas to measure how well actual edits match expected edits -pub fn delta_chr_f(expected: &[DiffLine], actual: &[DiffLine]) -> f64 { - // Reconstruct texts from diffs - let mut original_text = String::new(); // state of the text before any edits - let mut golden_text = String::new(); // text after applying golden edits - let mut actual_text = String::new(); // text after applying actual edits - - for line in expected { - match line { - DiffLine::Context(s) => { - original_text.push_str(s); - golden_text.push_str(s); - } - DiffLine::Deletion(s) => { - original_text.push_str(s); - } - DiffLine::Addition(s) => { - golden_text.push_str(s); - } - _ => {} - } - } - - for line in actual { - match line { - DiffLine::Context(s) | DiffLine::Addition(s) => { - actual_text.push_str(s); - } - _ => {} - } - } - - // Edge case - if original_text == golden_text && golden_text == actual_text { +/// 1. Computing n-gram count differences (deltas) between original→expected and original→actual +/// 2. Comparing these deltas to measure how well actual edits match expected edits +/// +/// Returns a score from 0.0 to 100.0, where 100.0 means the actual edits perfectly match +/// the expected edits. +pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> f64 { + // Edge case: if all texts are identical, the edits match perfectly + if original == expected && expected == actual { return 100.0; } - // Compute the metric - let original_ngrams = chr_f_ngram_counts(&original_text); - let golden_ngrams = chr_f_ngram_counts(&golden_text); - let actual_ngrams = chr_f_ngram_counts(&actual_text); + let original_ngrams = chr_f_ngram_counts(original); + let expected_ngrams = chr_f_ngram_counts(expected); + let actual_ngrams = chr_f_ngram_counts(actual); let mut total_precision = 0.0; let mut total_recall = 0.0; for order in 0..CHR_F_CHAR_ORDER { - let expected_delta = compute_ngram_delta(&golden_ngrams[order], &original_ngrams[order]); + let expected_delta = compute_ngram_delta(&expected_ngrams[order], &original_ngrams[order]); let actual_delta = compute_ngram_delta(&actual_ngrams[order], &original_ngrams[order]); if expected_delta.is_empty() && actual_delta.is_empty() { @@ -278,94 +183,68 @@ fn count_ngrams(text: &str, n: usize) -> Counts { #[cfg(test)] mod test { use super::*; - use edit_prediction::udiff::DiffLine; #[test] fn test_delta_chr_f_perfect_match() { - let diff = vec![ - DiffLine::Context("fn main() {"), - DiffLine::Deletion(" println!(\"Hello\");"), - DiffLine::Addition(" println!(\"Hello, World!\");"), - DiffLine::Context("}"), - ]; - - let score = delta_chr_f(&diff, &diff); + let original = "fn main() { println!(\"Hello\");}"; + let expected = "fn main() { println!(\"Hello, World!\");}"; + + let score = delta_chr_f(original, expected, expected); assert!((score - 100.0).abs() < 1e-2); } #[test] fn test_delta_chr_f_wrong_edit() { // When the edit is wrong - let expected = vec![ - DiffLine::Context("one "), - DiffLine::Deletion("two "), - DiffLine::Context("three"), - ]; - - let actual = vec![ - DiffLine::Context("one "), - DiffLine::Context("two "), - DiffLine::Deletion("three"), - DiffLine::Addition("four"), - ]; + let original = "one two three"; + let expected = "one three"; // deleted "two " + let actual = "one two four"; // deleted "three", added "four" // Then the score should be low - let score = delta_chr_f(&expected, &actual); + let score = delta_chr_f(original, expected, actual); assert!(score > 20.0 && score < 40.0); } #[test] fn test_delta_chr_f_partial_match() { - let expected = vec![ - DiffLine::Deletion("let x = 42;"), - DiffLine::Addition("let x = 100;"), - ]; - - let actual = vec![ - DiffLine::Deletion("let x = 42;"), - DiffLine::Addition("let x = 99;"), - ]; + let original = "let x = 42;"; + let expected = "let x = 100;"; + let actual = "let x = 99;"; // We got the edit location right, but the replacement text is wrong. // Deleted ngrams will match, bringing the score somewhere in the middle. - let score = delta_chr_f(&expected, &actual); + let score = delta_chr_f(original, expected, actual); assert!(score > 40.0 && score < 60.0); } #[test] fn test_delta_chr_f_missed_edit() { // When predictions makes no changes - let expected = vec![ - DiffLine::Context("prefix "), - DiffLine::Deletion("old"), - DiffLine::Addition("new"), - DiffLine::Context(" suffix"), - ]; - - let actual = vec![ - DiffLine::Context("prefix "), - DiffLine::Context("old"), - DiffLine::Context(" suffix"), - ]; + let original = "prefix old suffix"; + let expected = "prefix new suffix"; + let actual = "prefix old suffix"; // no change // Then the score should be low (all expected changes are false negatives) - let score = delta_chr_f(&expected, &actual); + let score = delta_chr_f(original, expected, actual); assert!(score < 20.0); } #[test] fn test_delta_chr_f_extra_edit() { // When adding unexpected content - let expected = vec![DiffLine::Context("hello"), DiffLine::Context("world")]; - - let actual = vec![ - DiffLine::Context("hello"), - DiffLine::Addition("extra"), - DiffLine::Context("world"), - ]; + let original = "helloworld"; + let expected = "helloworld"; // no change expected + let actual = "helloextraworld"; // added "extra" // Then the score should be low (all actual changes are false positives) - let score = delta_chr_f(&expected, &actual); + let score = delta_chr_f(original, expected, actual); assert!(score < 20.0); } + + #[test] + fn test_delta_chr_f_no_changes() { + let text = "unchanged text"; + let score = delta_chr_f(text, text, text); + assert!((score - 100.0).abs() < 1e-2); + } } diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs index 7b507e6d19c943de92eb0b22c7d24d4026789fed..4ea5a5b8792a6454a7dea3eeeb58ae401cec795a 100644 --- a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -2,11 +2,12 @@ use crate::{ PredictArgs, example::{Example, ExampleScore}, headless::EpAppState, - metrics::{self, ClassificationMetrics}, + metrics, predict::run_prediction, progress::{Progress, Step}, }; -use edit_prediction::udiff::DiffLine; +use anyhow::Context as _; +use edit_prediction::udiff::apply_diff_to_string; use gpui::AsyncApp; use std::sync::Arc; @@ -27,18 +28,32 @@ pub async fn run_scoring( let _progress = Progress::global().start(Step::Score, &example.spec.name); - let expected_patch = parse_patch(&example.spec.expected_patch); + let original_text = &example.buffer.as_ref().unwrap().content; + let expected_texts: Vec = example + .spec + .expected_patches + .iter() + .map(|patch| { + apply_diff_to_string(original_text, patch) + .with_context(|| format!("Expected patch did not apply for {}", example.spec.name)) + }) + .collect::, _>>()?; let mut scores = vec![]; - - for pred in &example.predictions { - let actual_patch = parse_patch(&pred.actual_patch); - let line_match = metrics::line_match_score(&expected_patch, &actual_patch); - let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32; - + for prediction in &example.predictions { + let actual_text = match apply_diff_to_string(original_text, &prediction.actual_patch) { + Ok(text) => text, + Err(_) => { + scores.push(ExampleScore { delta_chr_f: 0.0 }); + continue; + } + }; + let best_delta_chr_f = expected_texts + .iter() + .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32) + .fold(0.0, f32::max); scores.push(ExampleScore { - delta_chr_f, - line_match, + delta_chr_f: best_delta_chr_f, }); } @@ -46,42 +61,25 @@ pub async fn run_scoring( Ok(()) } -fn parse_patch(patch: &str) -> Vec> { - patch.lines().map(DiffLine::parse).collect() -} - pub fn print_report(examples: &[Example]) { eprintln!( "──────────────────────────────────────────────────────────────────────────────────────" ); - eprintln!( - "{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}", - "Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF" - ); + eprintln!("{:<50} {:>10}", "Example name", "DeltaChrF"); eprintln!( "──────────────────────────────────────────────────────────────────────────────────────" ); - let mut all_line_match_scores = Vec::new(); let mut all_delta_chr_f_scores = Vec::new(); for example in examples { for score in example.score.iter() { - let line_match = &score.line_match; - eprintln!( - "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}", - truncate_name(&example.spec.name, 30), - line_match.true_positives, - line_match.false_positives, - line_match.false_negatives, - line_match.precision() * 100.0, - line_match.recall() * 100.0, - line_match.f1_score() * 100.0, + "{:<50} {:>9.2}", + truncate_name(&example.spec.name, 50), score.delta_chr_f ); - all_line_match_scores.push(line_match.clone()); all_delta_chr_f_scores.push(score.delta_chr_f); } } @@ -90,22 +88,11 @@ pub fn print_report(examples: &[Example]) { "──────────────────────────────────────────────────────────────────────────────────────" ); - if !all_line_match_scores.is_empty() { - let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter()); + if !all_delta_chr_f_scores.is_empty() { let avg_delta_chr_f: f32 = all_delta_chr_f_scores.iter().sum::() / all_delta_chr_f_scores.len() as f32; - eprintln!( - "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}", - "TOTAL", - total_line_match.true_positives, - total_line_match.false_positives, - total_line_match.false_negatives, - total_line_match.precision() * 100.0, - total_line_match.recall() * 100.0, - total_line_match.f1_score() * 100.0, - avg_delta_chr_f - ); + eprintln!("{:<50} {:>9.2}", "AVERAGE", avg_delta_chr_f); eprintln!( "──────────────────────────────────────────────────────────────────────────────────────" );