Detailed changes
@@ -74,7 +74,7 @@ pub fn capture_example(
cursor_path: cursor_path.as_std_path().into(),
cursor_position: String::new(),
edit_history,
- expected_patch: String::new(),
+ expected_patches: Vec::new(),
};
spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
Ok(spec)
@@ -350,7 +350,7 @@ mod tests {
seven();
"}
.to_string(),
- expected_patch: "".to_string(),
+ expected_patches: Vec::new(),
}
);
}
@@ -15,7 +15,7 @@ pub struct ExampleSpec {
pub cursor_path: Arc<Path>,
pub cursor_position: String,
pub edit_history: String,
- pub expected_patch: String,
+ pub expected_patches: Vec<String>,
}
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
@@ -95,13 +95,15 @@ impl ExampleSpec {
_ = writeln!(markdown, "## {}", EXPECTED_PATCH_HEADING);
markdown.push('\n');
- _ = writeln!(markdown, "```diff");
- markdown.push_str(&self.expected_patch);
- if !markdown.ends_with('\n') {
+ for patch in &self.expected_patches {
+ _ = writeln!(markdown, "```diff");
+ markdown.push_str(patch);
+ if !markdown.ends_with('\n') {
+ markdown.push('\n');
+ }
+ _ = writeln!(markdown, "```");
markdown.push('\n');
}
- _ = writeln!(markdown, "```");
- markdown.push('\n');
markdown
}
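With this change, each expected patch is serialized into its own fenced block under a single heading. Assuming `EXPECTED_PATCH_HEADING` renders as `Expected Patch` (the constant's value isn't shown in this diff), a spec with two patches would serialize roughly like:

````text
## Expected Patch

```diff
<first expected patch>
```

```diff
<second expected patch>
```
````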
@@ -118,7 +120,7 @@ impl ExampleSpec {
cursor_path: Path::new("").into(),
cursor_position: String::new(),
edit_history: String::new(),
- expected_patch: String::new(),
+ expected_patches: Vec::new(),
};
if let Some(rest) = input.strip_prefix("+++\n")
@@ -212,7 +214,7 @@ impl ExampleSpec {
mem::take(&mut text);
}
Section::ExpectedPatch => {
- spec.expected_patch = mem::take(&mut text);
+ spec.expected_patches.push(mem::take(&mut text));
}
Section::Start | Section::Other => {}
}
@@ -353,7 +355,7 @@ mod tests {
cursor_path: Path::new("test.rs").into(),
cursor_position: String::new(),
edit_history: String::new(),
- expected_patch: String::new(),
+ expected_patches: Vec::new(),
};
// Cursor before `42`
@@ -1,20 +1,15 @@
-use anyhow::{Result, anyhow};
+use anyhow::Result;
use std::mem;
use crate::example::Example;
pub async fn run_distill(example: &mut Example) -> Result<()> {
- let [prediction]: [_; 1] =
- mem::take(&mut example.predictions)
- .try_into()
- .map_err(|preds: Vec<_>| {
- anyhow!(
- "Example has {} predictions, but it should have exactly one",
- preds.len()
- )
- })?;
+ let predictions = mem::take(&mut example.predictions)
+ .into_iter()
+ .map(|p| p.actual_patch)
+ .collect();
- example.spec.expected_patch = prediction.actual_patch;
+ example.spec.expected_patches = predictions;
example.prompt = None;
example.predictions = Vec::new();
example.score = Vec::new();
@@ -1,4 +1,4 @@
-use crate::{PredictionProvider, PromptFormat, metrics::ClassificationMetrics};
+use crate::{PredictionProvider, PromptFormat};
use anyhow::{Context as _, Result};
use collections::HashMap;
use edit_prediction::example_spec::ExampleSpec;
@@ -87,7 +87,6 @@ pub struct ExamplePrediction {
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
pub delta_chr_f: f32,
- pub line_match: ClassificationMetrics,
}
impl Example {
@@ -30,7 +30,13 @@ pub async fn run_format_prompt(
let prompt = TeacherPrompt::format_prompt(example);
example.prompt = Some(ExamplePrompt {
input: prompt,
- expected_output: example.spec.expected_patch.clone(), // TODO
+ // TODO
+ expected_output: example
+ .spec
+ .expected_patches
+ .first()
+ .context("no expected patches")?
+ .clone(),
format: prompt_format,
});
}
@@ -68,8 +74,15 @@ pub async fn run_format_prompt(
))
})??;
let prompt = format_zeta_prompt(&input);
- let expected_output =
- zeta2_output_for_patch(&input, &example.spec.expected_patch.clone())?;
+ let expected_output = zeta2_output_for_patch(
+ &input,
+ &example
+ .spec
+ .expected_patches
+ .first()
+ .context("expected patches is empty")?
+ .clone(),
+ )?;
example.prompt = Some(ExamplePrompt {
input: prompt,
expected_output,
@@ -1,34 +1,17 @@
-use collections::{HashMap, HashSet};
-use edit_prediction::udiff::DiffLine;
-use serde::{Deserialize, Serialize};
+use collections::HashMap;
type Counts = HashMap<String, usize>;
type CountsDelta = HashMap<String, isize>;
-#[derive(Default, Debug, Clone, Serialize, Deserialize)]
-pub struct ClassificationMetrics {
- pub true_positives: usize,
- pub false_positives: usize,
- pub false_negatives: usize,
+#[derive(Default, Debug, Clone)]
+struct ClassificationMetrics {
+ true_positives: usize,
+ false_positives: usize,
+ false_negatives: usize,
}
impl ClassificationMetrics {
- pub fn from_sets(
- expected: &HashSet<String>,
- actual: &HashSet<String>,
- ) -> ClassificationMetrics {
- let true_positives = expected.intersection(actual).count();
- let false_positives = actual.difference(expected).count();
- let false_negatives = expected.difference(actual).count();
-
- ClassificationMetrics {
- true_positives,
- false_positives,
- false_negatives,
- }
- }
-
- pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
+ fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
let mut true_positives = 0;
let mut false_positives = 0;
let mut false_negatives = 0;
@@ -56,27 +39,7 @@ impl ClassificationMetrics {
}
}
- pub fn aggregate<'a>(
- scores: impl Iterator<Item = &'a ClassificationMetrics>,
- ) -> ClassificationMetrics {
- let mut true_positives = 0;
- let mut false_positives = 0;
- let mut false_negatives = 0;
-
- for score in scores {
- true_positives += score.true_positives;
- false_positives += score.false_positives;
- false_negatives += score.false_negatives;
- }
-
- ClassificationMetrics {
- true_positives,
- false_positives,
- false_negatives,
- }
- }
-
- pub fn precision(&self) -> f64 {
+ fn precision(&self) -> f64 {
if self.true_positives + self.false_positives == 0 {
0.0
} else {
@@ -84,42 +47,13 @@ impl ClassificationMetrics {
}
}
- pub fn recall(&self) -> f64 {
+ fn recall(&self) -> f64 {
if self.true_positives + self.false_negatives == 0 {
0.0
} else {
self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
}
}
-
- pub fn f1_score(&self) -> f64 {
- let recall = self.recall();
- let precision = self.precision();
- if precision + recall == 0.0 {
- 0.0
- } else {
- 2.0 * precision * recall / (precision + recall)
- }
- }
-}
-
-pub fn line_match_score(
- expected_patch: &[DiffLine],
- actual_patch: &[DiffLine],
-) -> ClassificationMetrics {
- let expected_change_lines = expected_patch
- .iter()
- .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
- .map(|line| line.to_string())
- .collect();
-
- let actual_change_lines = actual_patch
- .iter()
- .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
- .map(|line| line.to_string())
- .collect();
-
- ClassificationMetrics::from_sets(&expected_change_lines, &actual_change_lines)
}
enum ChrfWhitespace {
@@ -135,55 +69,26 @@ const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Ignore;
/// Computes a delta-chrF score that compares two sets of edits.
///
/// This metric works by:
-/// 1. Reconstructing original, golden (expected result), and actual texts from diffs
-/// 2. Computing n-gram count differences (deltas) between original→golden and original→actual
-/// 3. Comparing these deltas to measure how well actual edits match expected edits
-pub fn delta_chr_f(expected: &[DiffLine], actual: &[DiffLine]) -> f64 {
- // Reconstruct texts from diffs
- let mut original_text = String::new(); // state of the text before any edits
- let mut golden_text = String::new(); // text after applying golden edits
- let mut actual_text = String::new(); // text after applying actual edits
-
- for line in expected {
- match line {
- DiffLine::Context(s) => {
- original_text.push_str(s);
- golden_text.push_str(s);
- }
- DiffLine::Deletion(s) => {
- original_text.push_str(s);
- }
- DiffLine::Addition(s) => {
- golden_text.push_str(s);
- }
- _ => {}
- }
- }
-
- for line in actual {
- match line {
- DiffLine::Context(s) | DiffLine::Addition(s) => {
- actual_text.push_str(s);
- }
- _ => {}
- }
- }
-
- // Edge case
- if original_text == golden_text && golden_text == actual_text {
+/// 1. Computing n-gram count differences (deltas) between original→expected and original→actual
+/// 2. Comparing these deltas to measure how well actual edits match expected edits
+///
+/// Returns a score from 0.0 to 100.0, where 100.0 means the actual edits perfectly match
+/// the expected edits.
+pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> f64 {
+ // Edge case: if all texts are identical, the edits match perfectly
+ if original == expected && expected == actual {
return 100.0;
}
- // Compute the metric
- let original_ngrams = chr_f_ngram_counts(&original_text);
- let golden_ngrams = chr_f_ngram_counts(&golden_text);
- let actual_ngrams = chr_f_ngram_counts(&actual_text);
+ let original_ngrams = chr_f_ngram_counts(original);
+ let expected_ngrams = chr_f_ngram_counts(expected);
+ let actual_ngrams = chr_f_ngram_counts(actual);
let mut total_precision = 0.0;
let mut total_recall = 0.0;
for order in 0..CHR_F_CHAR_ORDER {
- let expected_delta = compute_ngram_delta(&golden_ngrams[order], &original_ngrams[order]);
+ let expected_delta = compute_ngram_delta(&expected_ngrams[order], &original_ngrams[order]);
let actual_delta = compute_ngram_delta(&actual_ngrams[order], &original_ngrams[order]);
if expected_delta.is_empty() && actual_delta.is_empty() {
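For context, `compute_ngram_delta` (called above) produces the per-order difference between two n-gram count maps. A minimal sketch of that operation, using only the `Counts`/`CountsDelta` aliases shown earlier — this is an illustration, not necessarily the crate's exact implementation:

```rust
// Sketch: how many more (or fewer) times each n-gram occurs in `edited`
// than in `original`. N-grams with unchanged counts are omitted.
fn compute_ngram_delta(edited: &Counts, original: &Counts) -> CountsDelta {
    let mut delta = CountsDelta::default();
    for (ngram, &count) in edited {
        let diff = count as isize - original.get(ngram).copied().unwrap_or(0) as isize;
        if diff != 0 {
            delta.insert(ngram.clone(), diff);
        }
    }
    for (ngram, &count) in original {
        // n-grams that disappeared entirely from the edited text
        if !edited.contains_key(ngram) {
            delta.insert(ngram.clone(), -(count as isize));
        }
    }
    delta
}
```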
@@ -278,94 +183,68 @@ fn count_ngrams(text: &str, n: usize) -> Counts {
#[cfg(test)]
mod test {
use super::*;
- use edit_prediction::udiff::DiffLine;
#[test]
fn test_delta_chr_f_perfect_match() {
- let diff = vec![
- DiffLine::Context("fn main() {"),
- DiffLine::Deletion(" println!(\"Hello\");"),
- DiffLine::Addition(" println!(\"Hello, World!\");"),
- DiffLine::Context("}"),
- ];
-
- let score = delta_chr_f(&diff, &diff);
+ let original = "fn main() { println!(\"Hello\");}";
+ let expected = "fn main() { println!(\"Hello, World!\");}";
+
+ let score = delta_chr_f(original, expected, expected);
assert!((score - 100.0).abs() < 1e-2);
}
#[test]
fn test_delta_chr_f_wrong_edit() {
// When the edit is wrong
- let expected = vec![
- DiffLine::Context("one "),
- DiffLine::Deletion("two "),
- DiffLine::Context("three"),
- ];
-
- let actual = vec![
- DiffLine::Context("one "),
- DiffLine::Context("two "),
- DiffLine::Deletion("three"),
- DiffLine::Addition("four"),
- ];
+ let original = "one two three";
+ let expected = "one three"; // deleted "two "
+ let actual = "one two four"; // deleted "three", added "four"
// Then the score should be low
- let score = delta_chr_f(&expected, &actual);
+ let score = delta_chr_f(original, expected, actual);
assert!(score > 20.0 && score < 40.0);
}
#[test]
fn test_delta_chr_f_partial_match() {
- let expected = vec![
- DiffLine::Deletion("let x = 42;"),
- DiffLine::Addition("let x = 100;"),
- ];
-
- let actual = vec![
- DiffLine::Deletion("let x = 42;"),
- DiffLine::Addition("let x = 99;"),
- ];
+ let original = "let x = 42;";
+ let expected = "let x = 100;";
+ let actual = "let x = 99;";
// We got the edit location right, but the replacement text is wrong.
// Deleted ngrams will match, bringing the score somewhere in the middle.
- let score = delta_chr_f(&expected, &actual);
+ let score = delta_chr_f(original, expected, actual);
assert!(score > 40.0 && score < 60.0);
}
#[test]
fn test_delta_chr_f_missed_edit() {
// When the prediction makes no changes
- let expected = vec![
- DiffLine::Context("prefix "),
- DiffLine::Deletion("old"),
- DiffLine::Addition("new"),
- DiffLine::Context(" suffix"),
- ];
-
- let actual = vec![
- DiffLine::Context("prefix "),
- DiffLine::Context("old"),
- DiffLine::Context(" suffix"),
- ];
+ let original = "prefix old suffix";
+ let expected = "prefix new suffix";
+ let actual = "prefix old suffix"; // no change
// Then the score should be low (all expected changes are false negatives)
- let score = delta_chr_f(&expected, &actual);
+ let score = delta_chr_f(original, expected, actual);
assert!(score < 20.0);
}
#[test]
fn test_delta_chr_f_extra_edit() {
// When adding unexpected content
- let expected = vec![DiffLine::Context("hello"), DiffLine::Context("world")];
-
- let actual = vec![
- DiffLine::Context("hello"),
- DiffLine::Addition("extra"),
- DiffLine::Context("world"),
- ];
+ let original = "helloworld";
+ let expected = "helloworld"; // no change expected
+ let actual = "helloextraworld"; // added "extra"
// Then the score should be low (all actual changes are false positives)
- let score = delta_chr_f(&expected, &actual);
+ let score = delta_chr_f(original, expected, actual);
assert!(score < 20.0);
}
+
+ #[test]
+ fn test_delta_chr_f_no_changes() {
+ let text = "unchanged text";
+ let score = delta_chr_f(text, text, text);
+ assert!((score - 100.0).abs() < 1e-2);
+ }
}
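After this refactor, callers pass reconstructed texts instead of parsed diff lines. A minimal usage sketch, with values mirroring the perfect-match test above:

```rust
// Identical expected and actual edits score ~100.0 on the 0.0–100.0 scale.
let score = delta_chr_f("let x = 42;", "let x = 100;", "let x = 100;");
assert!((score - 100.0).abs() < 1e-2);
```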
@@ -2,11 +2,12 @@ use crate::{
PredictArgs,
example::{Example, ExampleScore},
headless::EpAppState,
- metrics::{self, ClassificationMetrics},
+ metrics,
predict::run_prediction,
progress::{Progress, Step},
};
-use edit_prediction::udiff::DiffLine;
+use anyhow::Context as _;
+use edit_prediction::udiff::apply_diff_to_string;
use gpui::AsyncApp;
use std::sync::Arc;
@@ -27,18 +28,32 @@ pub async fn run_scoring(
let _progress = Progress::global().start(Step::Score, &example.spec.name);
- let expected_patch = parse_patch(&example.spec.expected_patch);
+ let original_text = &example.buffer.as_ref().unwrap().content;
+ let expected_texts: Vec<String> = example
+ .spec
+ .expected_patches
+ .iter()
+ .map(|patch| {
+ apply_diff_to_string(original_text, patch)
+ .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
+ })
+ .collect::<Result<Vec<_>, _>>()?;
let mut scores = vec![];
-
- for pred in &example.predictions {
- let actual_patch = parse_patch(&pred.actual_patch);
- let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
- let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
-
+ for prediction in &example.predictions {
+ let actual_text = match apply_diff_to_string(original_text, &prediction.actual_patch) {
+ Ok(text) => text,
+ Err(_) => {
+ scores.push(ExampleScore { delta_chr_f: 0.0 });
+ continue;
+ }
+ };
+ let best_delta_chr_f = expected_texts
+ .iter()
+ .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
+ .fold(0.0, f32::max);
scores.push(ExampleScore {
- delta_chr_f,
- line_match,
+ delta_chr_f: best_delta_chr_f,
});
}
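One note on the fold above: seeding with `0.0` means a prediction whose `expected_patches` list is empty bottoms out at a zero score rather than panicking. A hypothetical standalone helper (not in the crate) showing the same best-of-N selection:

```rust
// Hypothetical illustration: score an actual text against every expected
// text and keep the best match.
fn best_delta_chr_f(original: &str, expected_texts: &[String], actual: &str) -> f32 {
    expected_texts
        .iter()
        .map(|expected| metrics::delta_chr_f(original, expected, actual) as f32)
        .fold(0.0, f32::max)
}
```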
@@ -46,42 +61,25 @@ pub async fn run_scoring(
Ok(())
}
-fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
- patch.lines().map(DiffLine::parse).collect()
-}
-
pub fn print_report(examples: &[Example]) {
eprintln!(
"ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ"
);
- eprintln!(
- "{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
- "Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
- );
+ eprintln!("{:<50} {:>10}", "Example name", "DeltaChrF");
eprintln!(
"ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ"
);
- let mut all_line_match_scores = Vec::new();
let mut all_delta_chr_f_scores = Vec::new();
for example in examples {
for score in example.score.iter() {
- let line_match = &score.line_match;
-
eprintln!(
- "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
- truncate_name(&example.spec.name, 30),
- line_match.true_positives,
- line_match.false_positives,
- line_match.false_negatives,
- line_match.precision() * 100.0,
- line_match.recall() * 100.0,
- line_match.f1_score() * 100.0,
+ "{:<50} {:>9.2}",
+ truncate_name(&example.spec.name, 50),
score.delta_chr_f
);
- all_line_match_scores.push(line_match.clone());
all_delta_chr_f_scores.push(score.delta_chr_f);
}
}
@@ -90,22 +88,11 @@ pub fn print_report(examples: &[Example]) {
"ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ"
);
- if !all_line_match_scores.is_empty() {
- let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
+ if !all_delta_chr_f_scores.is_empty() {
let avg_delta_chr_f: f32 =
all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
- eprintln!(
- "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
- "TOTAL",
- total_line_match.true_positives,
- total_line_match.false_positives,
- total_line_match.false_negatives,
- total_line_match.precision() * 100.0,
- total_line_match.recall() * 100.0,
- total_line_match.f1_score() * 100.0,
- avg_delta_chr_f
- );
+ eprintln!("{:<50} {:>9.2}", "AVERAGE", avg_delta_chr_f);
eprintln!(
"ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ"
);