@@ -83,11 +83,6 @@ pub async fn run_evaluate_one(
let evaluation_result = evaluate(&example.example, &predictions);
- println!("# {}\n", example.name);
- println!(
- "## Expected Context: \n\n```\n{}\n```\n\n",
- compare_context(&example.example, &predictions)
- );
println!(
"## Expected edit prediction:\n\n```diff\n{}\n```\n",
compare_diffs(&example.example.expected_patch, &predictions.diff)
@@ -104,21 +99,30 @@ pub async fn run_evaluate_one(
#[derive(Debug, Default)]
pub struct EvaluationResult {
- pub context: Scores,
pub edit_prediction: Scores,
+ pub context: Scores,
}
#[derive(Default, Debug)]
pub struct Scores {
- pub precision: f64,
- pub recall: f64,
- pub f1_score: f64,
pub true_positives: usize,
pub false_positives: usize,
pub false_negatives: usize,
}
impl Scores {
+ pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
+ let true_positives = expected.intersection(actual).count();
+ let false_positives = actual.difference(expected).count();
+ let false_negatives = expected.difference(actual).count();
+
+ Scores {
+ true_positives,
+ false_positives,
+ false_negatives,
+ }
+ }
+
pub fn to_markdown(&self) -> String {
format!(
"
@@ -128,17 +132,15 @@ F1 Score : {:.4}
True Positives : {}
False Positives : {}
False Negatives : {}",
- self.precision,
- self.recall,
- self.f1_score,
+ self.precision(),
+ self.recall(),
+ self.f1_score(),
self.true_positives,
self.false_positives,
self.false_negatives
)
}
-}
-impl Scores {
pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
let mut true_positives = 0;
let mut false_positives = 0;
@@ -150,22 +152,38 @@ impl Scores {
false_negatives += score.false_negatives;
}
- let precision = true_positives as f64 / (true_positives + false_positives) as f64;
- let recall = true_positives as f64 / (true_positives + false_negatives) as f64;
- let mut f1_score = 2.0 * precision * recall / (precision + recall);
- if f1_score.is_nan() {
- f1_score = 0.0;
- }
-
Scores {
- precision,
- recall,
- f1_score,
true_positives,
false_positives,
false_negatives,
}
}
+
+ pub fn precision(&self) -> f64 {
+ if self.true_positives + self.false_positives == 0 {
+ 0.0
+ } else {
+ self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
+ }
+ }
+
+ pub fn recall(&self) -> f64 {
+ if self.true_positives + self.false_negatives == 0 {
+ 0.0
+ } else {
+ self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
+ }
+ }
+
+ pub fn f1_score(&self) -> f64 {
+ let recall = self.recall();
+ let precision = self.precision();
+ if precision + recall == 0.0 {
+ 0.0
+ } else {
+ 2.0 * precision * recall / (precision + recall)
+ }
+ }
}
impl EvaluationResult {
@@ -185,19 +203,9 @@ impl EvaluationResult {
}
pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
- let mut result = EvaluationResult::default();
+ let mut eval_result = EvaluationResult::default();
- let expected_context_lines = example
- .expected_excerpts
- .iter()
- .flat_map(|excerpt| {
- excerpt
- .text
- .lines()
- .map(|line| format!("{}: {line}", excerpt.path.display()))
- })
- .collect();
- let actual_context_lines = preds
+ let actual_context_lines: HashSet<_> = preds
.excerpts
.iter()
.flat_map(|excerpt| {
@@ -208,8 +216,39 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResul
})
.collect();
- result.context = precision_recall(&expected_context_lines, &actual_context_lines);
+ let mut false_positive_lines = actual_context_lines.clone();
+
+ for entry in &example.expected_context {
+ let mut best_alternative_score = Scores::default();
+
+ for alternative in &entry.alternatives {
+ let expected: HashSet<_> = alternative
+ .excerpts
+ .iter()
+ .flat_map(|excerpt| {
+ excerpt
+ .text
+ .lines()
+ .map(|line| format!("{}: {line}", excerpt.path.display()))
+ })
+ .collect();
+
+ let scores = Scores::new(&expected, &actual_context_lines);
+ false_positive_lines.retain(|line| !actual_context_lines.contains(line));
+
+ if scores.recall() > best_alternative_score.recall() {
+ best_alternative_score = scores;
+ }
+ }
+
+ eval_result.context.false_negatives += best_alternative_score.false_negatives;
+ eval_result.context.true_positives += best_alternative_score.true_positives;
+ }
+
+ eval_result.context.false_positives = false_positive_lines.len();
+
+ // todo: alternatives for patches
let expected_patch_lines = example
.expected_patch
.lines()
@@ -226,86 +265,8 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResul
.map(|line| line.to_string())
.collect();
- result.edit_prediction = precision_recall(&expected_patch_lines, &actual_patch_lines);
-
- result
-}
-
-fn precision_recall(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
- let true_positives = expected.intersection(actual).count();
- let false_positives = actual.difference(expected).count();
- let false_negatives = expected.difference(actual).count();
-
- let precision = if true_positives + false_positives == 0 {
- 0.0
- } else {
- true_positives as f64 / (true_positives + false_positives) as f64
- };
- let recall = if true_positives + false_negatives == 0 {
- 0.0
- } else {
- true_positives as f64 / (true_positives + false_negatives) as f64
- };
- let f1_score = if precision + recall == 0.0 {
- 0.0
- } else {
- 2.0 * precision * recall / (precision + recall)
- };
-
- Scores {
- precision,
- recall,
- f1_score,
- true_positives,
- false_positives,
- false_negatives,
- }
-}
-
-/// Compare actual and expected context.
-///
-/// Return expected context annotated with these markers:
-///
-/// `ā context line` -- line was correctly predicted
-/// `ā context line` -- line is missing from predictions
-pub fn compare_context(example: &Example, preds: &PredictionDetails) -> String {
- let use_color = std::io::stdout().is_terminal();
- let green = if use_color { "\x1b[32m" } else { "" };
- let red = if use_color { "\x1b[31m" } else { "" };
- let reset = if use_color { "\x1b[0m" } else { "" };
- let expected: Vec<_> = example
- .expected_excerpts
- .iter()
- .flat_map(|excerpt| {
- excerpt
- .text
- .lines()
- .map(|line| (excerpt.path.clone(), line))
- })
- .collect();
- let actual: HashSet<_> = preds
- .excerpts
- .iter()
- .flat_map(|excerpt| {
- excerpt
- .text
- .lines()
- .map(|line| (excerpt.path.clone(), line))
- })
- .collect();
-
- let annotated = expected
- .iter()
- .map(|(path, line)| {
- if actual.contains(&(path.to_path_buf(), line)) {
- format!("{green}ā {line}{reset}")
- } else {
- format!("{red}ā {line}{reset}")
- }
- })
- .collect::<Vec<String>>();
-
- annotated.join("\n")
+ eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines);
+ eval_result
}
/// Return annotated `patch_a` so that:
@@ -13,6 +13,7 @@ use anyhow::{Context as _, Result, anyhow};
use clap::ValueEnum;
use cloud_zeta2_prompt::CURSOR_MARKER;
use collections::HashMap;
+use edit_prediction_context::Line;
use futures::{
AsyncWriteExt as _,
lock::{Mutex, OwnedMutexGuard},
@@ -31,7 +32,7 @@ const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
-const EXPECTED_EXCERPTS_HEADING: &str = "Expected Excerpts";
+const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
const REPOSITORY_URL_FIELD: &str = "repository_url";
const REVISION_FIELD: &str = "revision";
@@ -50,10 +51,9 @@ pub struct Example {
pub cursor_position: String,
pub edit_history: String,
pub expected_patch: String,
- pub expected_excerpts: Vec<ExpectedExcerpt>,
+ pub expected_context: Vec<ExpectedContextEntry>,
}
-pub type ExpectedExcerpt = Excerpt;
pub type ActualExcerpt = Excerpt;
#[derive(Clone, Debug, Serialize, Deserialize)]
@@ -62,6 +62,25 @@ pub struct Excerpt {
pub text: String,
}
+#[derive(Default, Clone, Debug, Serialize, Deserialize)]
+pub struct ExpectedContextEntry {
+ pub heading: String,
+ pub alternatives: Vec<ExpectedExcerptSet>,
+}
+
+#[derive(Default, Clone, Debug, Serialize, Deserialize)]
+pub struct ExpectedExcerptSet {
+ pub heading: String,
+ pub excerpts: Vec<ExpectedExcerpt>,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ExpectedExcerpt {
+ pub path: PathBuf,
+ pub text: String,
+ pub required_lines: Vec<Line>,
+}
+
#[derive(ValueEnum, Debug, Clone)]
pub enum ExampleFormat {
Json,
@@ -111,21 +130,32 @@ impl NamedExample {
cursor_position: String::new(),
edit_history: String::new(),
expected_patch: String::new(),
- expected_excerpts: Vec::new(),
+ expected_context: Vec::new(),
},
};
let mut text = String::new();
- let mut current_section = String::new();
let mut block_info: CowStr = "".into();
+ #[derive(PartialEq)]
+ enum Section {
+ UncommittedDiff,
+ EditHistory,
+ CursorPosition,
+ ExpectedExcerpts,
+ ExpectedPatch,
+ Other,
+ }
+
+ let mut current_section = Section::Other;
+
for event in parser {
match event {
Event::Text(line) => {
text.push_str(&line);
if !named.name.is_empty()
- && current_section.is_empty()
+ && current_section == Section::Other
// in h1 section
&& let Some((field, value)) = line.split_once('=')
{
@@ -151,7 +181,47 @@ impl NamedExample {
named.name = mem::take(&mut text);
}
Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
- current_section = mem::take(&mut text);
+ let title = mem::take(&mut text);
+ current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
+ Section::UncommittedDiff
+ } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
+ Section::EditHistory
+ } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
+ Section::CursorPosition
+ } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
+ Section::ExpectedPatch
+ } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
+ Section::ExpectedExcerpts
+ } else {
+ eprintln!("Warning: Unrecognized section `{title:?}`");
+ Section::Other
+ };
+ }
+ Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
+ let heading = mem::take(&mut text);
+ match current_section {
+ Section::ExpectedExcerpts => {
+ named.example.expected_context.push(ExpectedContextEntry {
+ heading,
+ alternatives: Vec::new(),
+ });
+ }
+ _ => {}
+ }
+ }
+ Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
+ let heading = mem::take(&mut text);
+ match current_section {
+ Section::ExpectedExcerpts => {
+ let expected_context = &mut named.example.expected_context;
+ let last_entry = expected_context.last_mut().unwrap();
+ last_entry.alternatives.push(ExpectedExcerptSet {
+ heading,
+ excerpts: Vec::new(),
+ })
+ }
+ _ => {}
+ }
}
Event::End(TagEnd::Heading(level)) => {
anyhow::bail!("Unexpected heading level: {level}");
@@ -172,23 +242,53 @@ impl NamedExample {
}
Event::End(TagEnd::CodeBlock) => {
let block_info = block_info.trim();
- if current_section.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
- named.example.uncommitted_diff = mem::take(&mut text);
- } else if current_section.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
- named.example.edit_history.push_str(&mem::take(&mut text));
- } else if current_section.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
- named.example.cursor_path = block_info.into();
- named.example.cursor_position = mem::take(&mut text);
- } else if current_section.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
- named.example.expected_patch = mem::take(&mut text);
- } else if current_section.eq_ignore_ascii_case(EXPECTED_EXCERPTS_HEADING) {
- // TODO: "ā¦" should not be a part of the excerpt
- named.example.expected_excerpts.push(ExpectedExcerpt {
- path: block_info.into(),
- text: mem::take(&mut text),
- });
- } else {
- eprintln!("Warning: Unrecognized section `{current_section:?}`")
+ match current_section {
+ Section::UncommittedDiff => {
+ named.example.uncommitted_diff = mem::take(&mut text);
+ }
+ Section::EditHistory => {
+ named.example.edit_history.push_str(&mem::take(&mut text));
+ }
+ Section::CursorPosition => {
+ named.example.cursor_path = block_info.into();
+ named.example.cursor_position = mem::take(&mut text);
+ }
+ Section::ExpectedExcerpts => {
+ let text = mem::take(&mut text);
+ for excerpt in text.split("\nā¦\n") {
+ let (mut text, required_lines) = extract_required_lines(&excerpt);
+ if !text.ends_with('\n') {
+ text.push('\n');
+ }
+ let alternatives = &mut named
+ .example
+ .expected_context
+ .last_mut()
+ .unwrap()
+ .alternatives;
+
+ if alternatives.is_empty() {
+ alternatives.push(ExpectedExcerptSet {
+ heading: String::new(),
+ excerpts: vec![],
+ });
+ }
+
+ alternatives
+ .last_mut()
+ .unwrap()
+ .excerpts
+ .push(ExpectedExcerpt {
+ path: block_info.into(),
+ text,
+ required_lines,
+ });
+ }
+ }
+ Section::ExpectedPatch => {
+ named.example.expected_patch = mem::take(&mut text);
+ }
+ Section::Other => {}
}
}
_ => {}
@@ -404,6 +504,47 @@ impl NamedExample {
}
}
+fn extract_required_lines(text: &str) -> (String, Vec<Line>) {
+ const MARKER: &str = "[ZETA]";
+ let mut new_text = String::new();
+ let mut required_lines = Vec::new();
+ let mut skipped_lines = 0_u32;
+
+ for (row, mut line) in text.split('\n').enumerate() {
+ if let Some(marker_column) = line.find(MARKER) {
+ let mut strip_column = marker_column;
+
+ while strip_column > 0 {
+ let prev_char = line[strip_column - 1..].chars().next().unwrap();
+ if prev_char.is_whitespace() || ['/', '#'].contains(&prev_char) {
+ strip_column -= 1;
+ } else {
+ break;
+ }
+ }
+
+ let metadata = &line[marker_column + MARKER.len()..];
+ if metadata.contains("required") {
+ required_lines.push(Line(row as u32 - skipped_lines));
+ }
+
+ if strip_column == 0 {
+ skipped_lines += 1;
+ continue;
+ }
+
+ line = &line[..strip_column];
+ }
+
+ new_text.push_str(line);
+ new_text.push('\n');
+ }
+
+ new_text.pop();
+
+ (new_text, required_lines)
+}
+
async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
let output = smol::process::Command::new("git")
.current_dir(repo_path)
@@ -458,21 +599,34 @@ impl Display for NamedExample {
)?;
}
- if !self.example.expected_excerpts.is_empty() {
- write!(f, "\n## {EXPECTED_EXCERPTS_HEADING}\n\n")?;
-
- for excerpt in &self.example.expected_excerpts {
- write!(
- f,
- "`````{}{}\n{}`````\n\n",
- excerpt
- .path
- .extension()
- .map(|ext| format!("{} ", ext.to_string_lossy()))
- .unwrap_or_default(),
- excerpt.path.display(),
- excerpt.text
- )?;
+ if !self.example.expected_context.is_empty() {
+ write!(f, "\n## {EXPECTED_CONTEXT_HEADING}\n\n")?;
+
+ for entry in &self.example.expected_context {
+ write!(f, "\n### {}\n\n", entry.heading)?;
+
+ let skip_h4 =
+ entry.alternatives.len() == 1 && entry.alternatives[0].heading.is_empty();
+
+ for excerpt_set in &entry.alternatives {
+ if !skip_h4 {
+ write!(f, "\n#### {}\n\n", excerpt_set.heading)?;
+ }
+
+ for excerpt in &excerpt_set.excerpts {
+ write!(
+ f,
+ "`````{}{}\n{}`````\n\n",
+ excerpt
+ .path
+ .extension()
+ .map(|ext| format!("{} ", ext.to_string_lossy()))
+ .unwrap_or_default(),
+ excerpt.path.display(),
+ excerpt.text
+ )?;
+ }
+ }
}
}
@@ -496,3 +650,38 @@ pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
.lock_owned()
.await
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use indoc::indoc;
+ use pretty_assertions::assert_eq;
+
+ #[test]
+ fn test_extract_required_lines() {
+ let input = indoc! {"
+ zero
+ one // [ZETA] required
+ two
+ // [ZETA] something
+ three
+ four # [ZETA] required
+ five
+ "};
+
+ let expected_updated_input = indoc! {"
+ zero
+ one
+ two
+ three
+ four
+ five
+ "};
+
+ let expected_required_lines = vec![Line(1), Line(4)];
+
+ let (updated_input, required_lines) = extract_required_lines(input);
+ assert_eq!(updated_input, expected_updated_input);
+ assert_eq!(required_lines, expected_required_lines);
+ }
+}