From 7ecbf8cf60feff43da961ec3e3d99e7570f75454 Mon Sep 17 00:00:00 2001 From: Ben Kunkle Date: Tue, 25 Nov 2025 10:44:04 -0800 Subject: [PATCH] zeta2: Remove expected context from evals (#43430) Closes #ISSUE Release Notes: - N/A *or* Added/Fixed/Improved ... --- crates/zeta_cli/src/evaluate.rs | 114 +------------------ crates/zeta_cli/src/example.rs | 189 +------------------------------- crates/zeta_cli/src/main.rs | 2 - crates/zeta_cli/src/predict.rs | 87 ++------------- 4 files changed, 18 insertions(+), 374 deletions(-) diff --git a/crates/zeta_cli/src/evaluate.rs b/crates/zeta_cli/src/evaluate.rs index a0ebdf998595ccacec2dafecf51b6094e5e401b5..6726dcb3aafdeff7fe41cbbbc49850c1e7465cf4 100644 --- a/crates/zeta_cli/src/evaluate.rs +++ b/crates/zeta_cli/src/evaluate.rs @@ -1,5 +1,5 @@ use std::{ - collections::{BTreeSet, HashMap}, + collections::HashMap, io::{IsTerminal, Write}, sync::Arc, }; @@ -125,21 +125,10 @@ fn write_aggregated_scores( .peekable(); let has_edit_predictions = edit_predictions.peek().is_some(); let aggregated_result = EvaluationResult { - context: Scores::aggregate(successful.iter().map(|r| &r.context)), edit_prediction: has_edit_predictions.then(|| Scores::aggregate(edit_predictions)), prompt_len: successful.iter().map(|r| r.prompt_len).sum::() / successful.len(), generated_len: successful.iter().map(|r| r.generated_len).sum::() / successful.len(), - context_lines_found_in_context: successful - .iter() - .map(|r| r.context_lines_found_in_context) - .sum::() - / successful.len(), - context_lines_in_expected_patch: successful - .iter() - .map(|r| r.context_lines_in_expected_patch) - .sum::() - / successful.len(), }; writeln!(w, "\n{}", "-".repeat(80))?; @@ -261,11 +250,8 @@ fn write_eval_result( #[derive(Debug, Default)] pub struct EvaluationResult { pub edit_prediction: Option, - pub context: Scores, pub prompt_len: usize, pub generated_len: usize, - pub context_lines_in_expected_patch: usize, - pub context_lines_found_in_context: usize, } #[derive(Default, Debug)] @@ -363,14 +349,6 @@ impl std::fmt::Display for EvaluationResult { impl EvaluationResult { fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - r#" -### Context Scores -{} -"#, - self.context.to_markdown(), - )?; if let Some(prediction) = &self.edit_prediction { write!( f, @@ -387,34 +365,18 @@ impl EvaluationResult { writeln!(f, "### Scores\n")?; writeln!( f, - " Prompt Generated RetrievedContext PatchContext TP FP FN Precision Recall F1" + " Prompt Generated TP FP FN Precision Recall F1" )?; writeln!( f, - "─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────" - )?; - writeln!( - f, - "Context Retrieval {:<7} {:<9} {:<16} {:<16} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}", - "", - "", - "", - "", - self.context.true_positives, - self.context.false_positives, - self.context.false_negatives, - self.context.precision() * 100.0, - self.context.recall() * 100.0, - self.context.f1_score() * 100.0 + "───────────────────────────────────────────────────────────────────────────────────────────────" )?; if let Some(edit_prediction) = &self.edit_prediction { writeln!( f, - "Edit Prediction {:<7} {:<9} {:<16} {:<16} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}", + "Edit Prediction {:<7} {:<9} {:<6} {:<6} {:<6} {:>9.2} {:>8.2} {:>7.2}", self.prompt_len, self.generated_len, - self.context_lines_found_in_context, - self.context_lines_in_expected_patch, edit_prediction.true_positives, edit_prediction.false_positives, edit_prediction.false_negatives, @@ -434,53 +396,6 @@ fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> Eval ..Default::default() }; - let actual_context_lines: HashSet<_> = preds - .excerpts - .iter() - .flat_map(|excerpt| { - excerpt - .text - .lines() - .map(|line| format!("{}: {line}", excerpt.path.display())) - }) - .collect(); - - let mut false_positive_lines = actual_context_lines.clone(); - - for entry in &example.expected_context { - let mut best_alternative_score: Option = None; - - for alternative in &entry.alternatives { - let expected: HashSet<_> = alternative - .excerpts - .iter() - .flat_map(|excerpt| { - excerpt - .text - .lines() - .map(|line| format!("{}: {line}", excerpt.path.display())) - }) - .collect(); - - let scores = Scores::new(&expected, &actual_context_lines); - - false_positive_lines.retain(|line| !expected.contains(line)); - - if best_alternative_score - .as_ref() - .is_none_or(|best| scores.recall() > best.recall()) - { - best_alternative_score = Some(scores); - } - } - - let best_alternative = best_alternative_score.unwrap_or_default(); - eval_result.context.false_negatives += best_alternative.false_negatives; - eval_result.context.true_positives += best_alternative.true_positives; - } - - eval_result.context.false_positives = false_positive_lines.len(); - if predict { // todo: alternatives for patches let expected_patch = example @@ -493,25 +408,6 @@ fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> Eval .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_))) .map(|line| line.to_string()) .collect(); - let expected_context_lines = expected_patch - .iter() - .filter_map(|line| { - if let DiffLine::Context(str) = line { - Some(String::from(*str)) - } else { - None - } - }) - .collect::>(); - let actual_context_lines = preds - .excerpts - .iter() - .flat_map(|excerpt| excerpt.text.lines().map(ToOwned::to_owned)) - .collect::>(); - - let matched = expected_context_lines - .intersection(&actual_context_lines) - .count(); let actual_patch_lines = preds .diff @@ -522,8 +418,6 @@ fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> Eval .collect(); eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines)); - eval_result.context_lines_in_expected_patch = expected_context_lines.len(); - eval_result.context_lines_found_in_context = matched; } eval_result diff --git a/crates/zeta_cli/src/example.rs b/crates/zeta_cli/src/example.rs index 7dbe304a88b9ea024adab793fa782fd2f4bdf1c0..a9d4c4f47c5a05d4198b1cffaee51e14a122e88d 100644 --- a/crates/zeta_cli/src/example.rs +++ b/crates/zeta_cli/src/example.rs @@ -14,7 +14,6 @@ use anyhow::{Context as _, Result, anyhow}; use clap::ValueEnum; use cloud_zeta2_prompt::CURSOR_MARKER; use collections::HashMap; -use edit_prediction_context::Line; use futures::{ AsyncWriteExt as _, lock::{Mutex, OwnedMutexGuard}, @@ -53,7 +52,6 @@ pub struct Example { pub cursor_position: String, pub edit_history: String, pub expected_patch: String, - pub expected_context: Vec, } pub type ActualExcerpt = Excerpt; @@ -64,25 +62,6 @@ pub struct Excerpt { pub text: String, } -#[derive(Default, Clone, Debug, Serialize, Deserialize)] -pub struct ExpectedContextEntry { - pub heading: String, - pub alternatives: Vec, -} - -#[derive(Default, Clone, Debug, Serialize, Deserialize)] -pub struct ExpectedExcerptSet { - pub heading: String, - pub excerpts: Vec, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct ExpectedExcerpt { - pub path: PathBuf, - pub text: String, - pub required_lines: Vec, -} - #[derive(ValueEnum, Debug, Clone)] pub enum ExampleFormat { Json, @@ -132,7 +111,6 @@ impl NamedExample { cursor_position: String::new(), edit_history: String::new(), expected_patch: String::new(), - expected_context: Vec::new(), }, }; @@ -197,30 +175,10 @@ impl NamedExample { }; } Event::End(TagEnd::Heading(HeadingLevel::H3)) => { - let heading = mem::take(&mut text); - match current_section { - Section::ExpectedExcerpts => { - named.example.expected_context.push(ExpectedContextEntry { - heading, - alternatives: Vec::new(), - }); - } - _ => {} - } + mem::take(&mut text); } Event::End(TagEnd::Heading(HeadingLevel::H4)) => { - let heading = mem::take(&mut text); - match current_section { - Section::ExpectedExcerpts => { - let expected_context = &mut named.example.expected_context; - let last_entry = expected_context.last_mut().unwrap(); - last_entry.alternatives.push(ExpectedExcerptSet { - heading, - excerpts: Vec::new(), - }) - } - _ => {} - } + mem::take(&mut text); } Event::End(TagEnd::Heading(level)) => { anyhow::bail!("Unexpected heading level: {level}"); @@ -253,41 +211,7 @@ impl NamedExample { named.example.cursor_position = mem::take(&mut text); } Section::ExpectedExcerpts => { - let text = mem::take(&mut text); - for excerpt in text.split("\n…\n") { - let (mut text, required_lines) = extract_required_lines(&excerpt); - if !text.ends_with('\n') { - text.push('\n'); - } - - if named.example.expected_context.is_empty() { - named.example.expected_context.push(Default::default()); - } - - let alternatives = &mut named - .example - .expected_context - .last_mut() - .unwrap() - .alternatives; - - if alternatives.is_empty() { - alternatives.push(ExpectedExcerptSet { - heading: String::new(), - excerpts: vec![], - }); - } - - alternatives - .last_mut() - .unwrap() - .excerpts - .push(ExpectedExcerpt { - path: block_info.into(), - text, - required_lines, - }); - } + mem::take(&mut text); } Section::ExpectedPatch => { named.example.expected_patch = mem::take(&mut text); @@ -561,47 +485,6 @@ impl NamedExample { } } -fn extract_required_lines(text: &str) -> (String, Vec) { - const MARKER: &str = "[ZETA]"; - let mut new_text = String::new(); - let mut required_lines = Vec::new(); - let mut skipped_lines = 0_u32; - - for (row, mut line) in text.split('\n').enumerate() { - if let Some(marker_column) = line.find(MARKER) { - let mut strip_column = marker_column; - - while strip_column > 0 { - let prev_char = line[strip_column - 1..].chars().next().unwrap(); - if prev_char.is_whitespace() || ['/', '#'].contains(&prev_char) { - strip_column -= 1; - } else { - break; - } - } - - let metadata = &line[marker_column + MARKER.len()..]; - if metadata.contains("required") { - required_lines.push(Line(row as u32 - skipped_lines)); - } - - if strip_column == 0 { - skipped_lines += 1; - continue; - } - - line = &line[..strip_column]; - } - - new_text.push_str(line); - new_text.push('\n'); - } - - new_text.pop(); - - (new_text, required_lines) -} - async fn run_git(repo_path: &Path, args: &[&str]) -> Result { let output = smol::process::Command::new("git") .current_dir(repo_path) @@ -656,37 +539,6 @@ impl Display for NamedExample { )?; } - if !self.example.expected_context.is_empty() { - write!(f, "\n## {EXPECTED_CONTEXT_HEADING}\n\n")?; - - for entry in &self.example.expected_context { - write!(f, "\n### {}\n\n", entry.heading)?; - - let skip_h4 = - entry.alternatives.len() == 1 && entry.alternatives[0].heading.is_empty(); - - for excerpt_set in &entry.alternatives { - if !skip_h4 { - write!(f, "\n#### {}\n\n", excerpt_set.heading)?; - } - - for excerpt in &excerpt_set.excerpts { - write!( - f, - "`````{}{}\n{}`````\n\n", - excerpt - .path - .extension() - .map(|ext| format!("{} ", ext.to_string_lossy())) - .unwrap_or_default(), - excerpt.path.display(), - excerpt.text - )?; - } - } - } - } - Ok(()) } } @@ -707,38 +559,3 @@ pub async fn lock_repo(path: impl AsRef) -> OwnedMutexGuard<()> { .lock_owned() .await } - -#[cfg(test)] -mod tests { - use super::*; - use indoc::indoc; - use pretty_assertions::assert_eq; - - #[test] - fn test_extract_required_lines() { - let input = indoc! {" - zero - one // [ZETA] required - two - // [ZETA] something - three - four # [ZETA] required - five - "}; - - let expected_updated_input = indoc! {" - zero - one - two - three - four - five - "}; - - let expected_required_lines = vec![Line(1), Line(4)]; - - let (updated_input, required_lines) = extract_required_lines(input); - assert_eq!(updated_input, expected_updated_input); - assert_eq!(required_lines, expected_required_lines); - } -} diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index f87563cc34ca7631baf8195e42e4e3473f522659..d13f0710cdc4d16666594d25dc639d337fb6bdfc 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -128,8 +128,6 @@ pub struct PredictArguments { #[derive(Clone, Debug, Args)] pub struct PredictionOptions { - #[arg(long)] - use_expected_context: bool, #[clap(flatten)] zeta2: Zeta2Args, #[clap(long)] diff --git a/crates/zeta_cli/src/predict.rs b/crates/zeta_cli/src/predict.rs index a757a5faa0dbae95c4dcab58c76d50450b1d2e9f..8a1a4131fb684a5186b2111f9d922fa34d6972e1 100644 --- a/crates/zeta_cli/src/predict.rs +++ b/crates/zeta_cli/src/predict.rs @@ -1,4 +1,4 @@ -use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample}; +use crate::example::{ActualExcerpt, NamedExample}; use crate::headless::ZetaCliAppState; use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir}; use crate::{ @@ -7,16 +7,13 @@ use crate::{ use ::serde::Serialize; use anyhow::{Context, Result, anyhow}; use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock}; -use collections::HashMap; use futures::StreamExt as _; use gpui::{AppContext, AsyncApp, Entity}; -use language::{Anchor, Buffer, Point}; use project::Project; use project::buffer_store::BufferStoreEvent; use serde::Deserialize; use std::fs; use std::io::{IsTerminal, Write}; -use std::ops::Range; use std::path::PathBuf; use std::sync::Arc; use std::sync::Mutex; @@ -204,15 +201,12 @@ pub async fn perform_predict( let mut result = result.lock().unwrap(); result.generated_len = response.chars().count(); - if !options.use_expected_context { - result.planning_search_time = Some( - search_queries_generated_at.unwrap() - start_time.unwrap(), - ); - result.running_search_time = Some( - search_queries_executed_at.unwrap() - - search_queries_generated_at.unwrap(), - ); - } + result.planning_search_time = + Some(search_queries_generated_at.unwrap() - start_time.unwrap()); + result.running_search_time = Some( + search_queries_executed_at.unwrap() + - search_queries_generated_at.unwrap(), + ); result.prediction_time = prediction_finished_at - prediction_started_at; result.total_time = prediction_finished_at - start_time.unwrap(); @@ -224,37 +218,10 @@ pub async fn perform_predict( } }); - if options.use_expected_context { - let context_excerpts_tasks = example - .example - .expected_context - .iter() - .flat_map(|section| { - section.alternatives[0].excerpts.iter().map(|excerpt| { - resolve_context_entry(project.clone(), excerpt.clone(), cx.clone()) - }) - }) - .collect::>(); - let context_excerpts_vec = - futures::future::try_join_all(context_excerpts_tasks).await?; - - let mut context_excerpts = HashMap::default(); - for (buffer, mut excerpts) in context_excerpts_vec { - context_excerpts - .entry(buffer) - .or_insert(Vec::new()) - .append(&mut excerpts); - } - - zeta.update(cx, |zeta, _cx| { - zeta.set_context(project.clone(), context_excerpts) - })?; - } else { - zeta.update(cx, |zeta, cx| { - zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx) - })? - .await?; - } + zeta.update(cx, |zeta, cx| { + zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx) + })? + .await?; } let prediction = zeta @@ -274,38 +241,6 @@ pub async fn perform_predict( anyhow::Ok(result) } -async fn resolve_context_entry( - project: Entity, - excerpt: ExpectedExcerpt, - mut cx: AsyncApp, -) -> Result<(Entity, Vec>)> { - let buffer = project - .update(&mut cx, |project, cx| { - let project_path = project.find_project_path(&excerpt.path, cx).unwrap(); - project.open_buffer(project_path, cx) - })? - .await?; - - let ranges = buffer.read_with(&mut cx, |buffer, _| { - let full_text = buffer.text(); - let offset = full_text - .find(&excerpt.text) - .expect("Expected context not found"); - let point = buffer.offset_to_point(offset); - excerpt - .required_lines - .iter() - .map(|line| { - let row = point.row + line.0; - let range = Point::new(row, 0)..Point::new(row + 1, 0); - buffer.anchor_after(range.start)..buffer.anchor_before(range.end) - }) - .collect() - })?; - - Ok((buffer, ranges)) -} - struct RunCache { cache_mode: CacheMode, example_run_dir: PathBuf,