zeta2: Remove expected context from evals (#43430)

Ben Kunkle created

Closes #ISSUE

Release Notes:

- N/A *or* Added/Fixed/Improved ...

Change summary

crates/zeta_cli/src/evaluate.rs | 114 --------------------
crates/zeta_cli/src/example.rs  | 189 ----------------------------------
crates/zeta_cli/src/main.rs     |   2 
crates/zeta_cli/src/predict.rs  |  87 ++--------------
4 files changed, 18 insertions(+), 374 deletions(-)

Detailed changes

crates/zeta_cli/src/evaluate.rs šŸ”—

@@ -1,5 +1,5 @@
 use std::{
-    collections::{BTreeSet, HashMap},
+    collections::HashMap,
     io::{IsTerminal, Write},
     sync::Arc,
 };
@@ -125,21 +125,10 @@ fn write_aggregated_scores(
             .peekable();
         let has_edit_predictions = edit_predictions.peek().is_some();
         let aggregated_result = EvaluationResult {
-            context: Scores::aggregate(successful.iter().map(|r| &r.context)),
             edit_prediction: has_edit_predictions.then(|| Scores::aggregate(edit_predictions)),
             prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
             generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
                 / successful.len(),
-            context_lines_found_in_context: successful
-                .iter()
-                .map(|r| r.context_lines_found_in_context)
-                .sum::<usize>()
-                / successful.len(),
-            context_lines_in_expected_patch: successful
-                .iter()
-                .map(|r| r.context_lines_in_expected_patch)
-                .sum::<usize>()
-                / successful.len(),
         };
 
         writeln!(w, "\n{}", "-".repeat(80))?;
@@ -261,11 +250,8 @@ fn write_eval_result(
 #[derive(Debug, Default)]
 pub struct EvaluationResult {
     pub edit_prediction: Option<Scores>,
-    pub context: Scores,
     pub prompt_len: usize,
     pub generated_len: usize,
-    pub context_lines_in_expected_patch: usize,
-    pub context_lines_found_in_context: usize,
 }
 
 #[derive(Default, Debug)]
@@ -363,14 +349,6 @@ impl std::fmt::Display for EvaluationResult {
 
 impl EvaluationResult {
     fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(
-            f,
-            r#"
-### Context Scores
-{}
-"#,
-            self.context.to_markdown(),
-        )?;
         if let Some(prediction) = &self.edit_prediction {
             write!(
                 f,
@@ -387,34 +365,18 @@ impl EvaluationResult {
         writeln!(f, "### Scores\n")?;
         writeln!(
             f,
-            "                   Prompt  Generated RetrievedContext PatchContext     TP     FP     FN     Precision   Recall     F1"
+            "                   Prompt  Generated  TP     FP     FN     Precision   Recall      F1"
         )?;
         writeln!(
             f,
-            "─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────"
-        )?;
-        writeln!(
-            f,
-            "Context Retrieval  {:<7} {:<9} {:<16} {:<16} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
-            "",
-            "",
-            "",
-            "",
-            self.context.true_positives,
-            self.context.false_positives,
-            self.context.false_negatives,
-            self.context.precision() * 100.0,
-            self.context.recall() * 100.0,
-            self.context.f1_score() * 100.0
+            "───────────────────────────────────────────────────────────────────────────────────────────────"
         )?;
         if let Some(edit_prediction) = &self.edit_prediction {
             writeln!(
                 f,
-                "Edit Prediction    {:<7} {:<9} {:<16} {:<16} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
+                "Edit Prediction    {:<7} {:<9}  {:<6} {:<6} {:<6} {:>9.2} {:>8.2} {:>7.2}",
                 self.prompt_len,
                 self.generated_len,
-                self.context_lines_found_in_context,
-                self.context_lines_in_expected_patch,
                 edit_prediction.true_positives,
                 edit_prediction.false_positives,
                 edit_prediction.false_negatives,
@@ -434,53 +396,6 @@ fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> Eval
         ..Default::default()
     };
 
-    let actual_context_lines: HashSet<_> = preds
-        .excerpts
-        .iter()
-        .flat_map(|excerpt| {
-            excerpt
-                .text
-                .lines()
-                .map(|line| format!("{}: {line}", excerpt.path.display()))
-        })
-        .collect();
-
-    let mut false_positive_lines = actual_context_lines.clone();
-
-    for entry in &example.expected_context {
-        let mut best_alternative_score: Option<Scores> = None;
-
-        for alternative in &entry.alternatives {
-            let expected: HashSet<_> = alternative
-                .excerpts
-                .iter()
-                .flat_map(|excerpt| {
-                    excerpt
-                        .text
-                        .lines()
-                        .map(|line| format!("{}: {line}", excerpt.path.display()))
-                })
-                .collect();
-
-            let scores = Scores::new(&expected, &actual_context_lines);
-
-            false_positive_lines.retain(|line| !expected.contains(line));
-
-            if best_alternative_score
-                .as_ref()
-                .is_none_or(|best| scores.recall() > best.recall())
-            {
-                best_alternative_score = Some(scores);
-            }
-        }
-
-        let best_alternative = best_alternative_score.unwrap_or_default();
-        eval_result.context.false_negatives += best_alternative.false_negatives;
-        eval_result.context.true_positives += best_alternative.true_positives;
-    }
-
-    eval_result.context.false_positives = false_positive_lines.len();
-
     if predict {
         // todo: alternatives for patches
         let expected_patch = example
@@ -493,25 +408,6 @@ fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> Eval
             .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
             .map(|line| line.to_string())
             .collect();
-        let expected_context_lines = expected_patch
-            .iter()
-            .filter_map(|line| {
-                if let DiffLine::Context(str) = line {
-                    Some(String::from(*str))
-                } else {
-                    None
-                }
-            })
-            .collect::<BTreeSet<_>>();
-        let actual_context_lines = preds
-            .excerpts
-            .iter()
-            .flat_map(|excerpt| excerpt.text.lines().map(ToOwned::to_owned))
-            .collect::<BTreeSet<_>>();
-
-        let matched = expected_context_lines
-            .intersection(&actual_context_lines)
-            .count();
 
         let actual_patch_lines = preds
             .diff
@@ -522,8 +418,6 @@ fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> Eval
             .collect();
 
         eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines));
-        eval_result.context_lines_in_expected_patch = expected_context_lines.len();
-        eval_result.context_lines_found_in_context = matched;
     }
 
     eval_result

crates/zeta_cli/src/example.rs šŸ”—

@@ -14,7 +14,6 @@ use anyhow::{Context as _, Result, anyhow};
 use clap::ValueEnum;
 use cloud_zeta2_prompt::CURSOR_MARKER;
 use collections::HashMap;
-use edit_prediction_context::Line;
 use futures::{
     AsyncWriteExt as _,
     lock::{Mutex, OwnedMutexGuard},
@@ -53,7 +52,6 @@ pub struct Example {
     pub cursor_position: String,
     pub edit_history: String,
     pub expected_patch: String,
-    pub expected_context: Vec<ExpectedContextEntry>,
 }
 
 pub type ActualExcerpt = Excerpt;
@@ -64,25 +62,6 @@ pub struct Excerpt {
     pub text: String,
 }
 
-#[derive(Default, Clone, Debug, Serialize, Deserialize)]
-pub struct ExpectedContextEntry {
-    pub heading: String,
-    pub alternatives: Vec<ExpectedExcerptSet>,
-}
-
-#[derive(Default, Clone, Debug, Serialize, Deserialize)]
-pub struct ExpectedExcerptSet {
-    pub heading: String,
-    pub excerpts: Vec<ExpectedExcerpt>,
-}
-
-#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct ExpectedExcerpt {
-    pub path: PathBuf,
-    pub text: String,
-    pub required_lines: Vec<Line>,
-}
-
 #[derive(ValueEnum, Debug, Clone)]
 pub enum ExampleFormat {
     Json,
@@ -132,7 +111,6 @@ impl NamedExample {
                 cursor_position: String::new(),
                 edit_history: String::new(),
                 expected_patch: String::new(),
-                expected_context: Vec::new(),
             },
         };
 
@@ -197,30 +175,10 @@ impl NamedExample {
                     };
                 }
                 Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
-                    let heading = mem::take(&mut text);
-                    match current_section {
-                        Section::ExpectedExcerpts => {
-                            named.example.expected_context.push(ExpectedContextEntry {
-                                heading,
-                                alternatives: Vec::new(),
-                            });
-                        }
-                        _ => {}
-                    }
+                    mem::take(&mut text);
                 }
                 Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
-                    let heading = mem::take(&mut text);
-                    match current_section {
-                        Section::ExpectedExcerpts => {
-                            let expected_context = &mut named.example.expected_context;
-                            let last_entry = expected_context.last_mut().unwrap();
-                            last_entry.alternatives.push(ExpectedExcerptSet {
-                                heading,
-                                excerpts: Vec::new(),
-                            })
-                        }
-                        _ => {}
-                    }
+                    mem::take(&mut text);
                 }
                 Event::End(TagEnd::Heading(level)) => {
                     anyhow::bail!("Unexpected heading level: {level}");
@@ -253,41 +211,7 @@ impl NamedExample {
                             named.example.cursor_position = mem::take(&mut text);
                         }
                         Section::ExpectedExcerpts => {
-                            let text = mem::take(&mut text);
-                            for excerpt in text.split("\n…\n") {
-                                let (mut text, required_lines) = extract_required_lines(&excerpt);
-                                if !text.ends_with('\n') {
-                                    text.push('\n');
-                                }
-
-                                if named.example.expected_context.is_empty() {
-                                    named.example.expected_context.push(Default::default());
-                                }
-
-                                let alternatives = &mut named
-                                    .example
-                                    .expected_context
-                                    .last_mut()
-                                    .unwrap()
-                                    .alternatives;
-
-                                if alternatives.is_empty() {
-                                    alternatives.push(ExpectedExcerptSet {
-                                        heading: String::new(),
-                                        excerpts: vec![],
-                                    });
-                                }
-
-                                alternatives
-                                    .last_mut()
-                                    .unwrap()
-                                    .excerpts
-                                    .push(ExpectedExcerpt {
-                                        path: block_info.into(),
-                                        text,
-                                        required_lines,
-                                    });
-                            }
+                            mem::take(&mut text);
                         }
                         Section::ExpectedPatch => {
                             named.example.expected_patch = mem::take(&mut text);
@@ -561,47 +485,6 @@ impl NamedExample {
     }
 }
 
-fn extract_required_lines(text: &str) -> (String, Vec<Line>) {
-    const MARKER: &str = "[ZETA]";
-    let mut new_text = String::new();
-    let mut required_lines = Vec::new();
-    let mut skipped_lines = 0_u32;
-
-    for (row, mut line) in text.split('\n').enumerate() {
-        if let Some(marker_column) = line.find(MARKER) {
-            let mut strip_column = marker_column;
-
-            while strip_column > 0 {
-                let prev_char = line[strip_column - 1..].chars().next().unwrap();
-                if prev_char.is_whitespace() || ['/', '#'].contains(&prev_char) {
-                    strip_column -= 1;
-                } else {
-                    break;
-                }
-            }
-
-            let metadata = &line[marker_column + MARKER.len()..];
-            if metadata.contains("required") {
-                required_lines.push(Line(row as u32 - skipped_lines));
-            }
-
-            if strip_column == 0 {
-                skipped_lines += 1;
-                continue;
-            }
-
-            line = &line[..strip_column];
-        }
-
-        new_text.push_str(line);
-        new_text.push('\n');
-    }
-
-    new_text.pop();
-
-    (new_text, required_lines)
-}
-
 async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
     let output = smol::process::Command::new("git")
         .current_dir(repo_path)
@@ -656,37 +539,6 @@ impl Display for NamedExample {
             )?;
         }
 
-        if !self.example.expected_context.is_empty() {
-            write!(f, "\n## {EXPECTED_CONTEXT_HEADING}\n\n")?;
-
-            for entry in &self.example.expected_context {
-                write!(f, "\n### {}\n\n", entry.heading)?;
-
-                let skip_h4 =
-                    entry.alternatives.len() == 1 && entry.alternatives[0].heading.is_empty();
-
-                for excerpt_set in &entry.alternatives {
-                    if !skip_h4 {
-                        write!(f, "\n#### {}\n\n", excerpt_set.heading)?;
-                    }
-
-                    for excerpt in &excerpt_set.excerpts {
-                        write!(
-                            f,
-                            "`````{}{}\n{}`````\n\n",
-                            excerpt
-                                .path
-                                .extension()
-                                .map(|ext| format!("{} ", ext.to_string_lossy()))
-                                .unwrap_or_default(),
-                            excerpt.path.display(),
-                            excerpt.text
-                        )?;
-                    }
-                }
-            }
-        }
-
         Ok(())
     }
 }
@@ -707,38 +559,3 @@ pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
         .lock_owned()
         .await
 }
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use indoc::indoc;
-    use pretty_assertions::assert_eq;
-
-    #[test]
-    fn test_extract_required_lines() {
-        let input = indoc! {"
-            zero
-            one // [ZETA] required
-            two
-            // [ZETA] something
-            three
-            four # [ZETA] required
-            five
-        "};
-
-        let expected_updated_input = indoc! {"
-            zero
-            one
-            two
-            three
-            four
-            five
-        "};
-
-        let expected_required_lines = vec![Line(1), Line(4)];
-
-        let (updated_input, required_lines) = extract_required_lines(input);
-        assert_eq!(updated_input, expected_updated_input);
-        assert_eq!(required_lines, expected_required_lines);
-    }
-}

crates/zeta_cli/src/main.rs šŸ”—

@@ -128,8 +128,6 @@ pub struct PredictArguments {
 
 #[derive(Clone, Debug, Args)]
 pub struct PredictionOptions {
-    #[arg(long)]
-    use_expected_context: bool,
     #[clap(flatten)]
     zeta2: Zeta2Args,
     #[clap(long)]

crates/zeta_cli/src/predict.rs šŸ”—

@@ -1,4 +1,4 @@
-use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
+use crate::example::{ActualExcerpt, NamedExample};
 use crate::headless::ZetaCliAppState;
 use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
 use crate::{
@@ -7,16 +7,13 @@ use crate::{
 use ::serde::Serialize;
 use anyhow::{Context, Result, anyhow};
 use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
-use collections::HashMap;
 use futures::StreamExt as _;
 use gpui::{AppContext, AsyncApp, Entity};
-use language::{Anchor, Buffer, Point};
 use project::Project;
 use project::buffer_store::BufferStoreEvent;
 use serde::Deserialize;
 use std::fs;
 use std::io::{IsTerminal, Write};
-use std::ops::Range;
 use std::path::PathBuf;
 use std::sync::Arc;
 use std::sync::Mutex;
@@ -204,15 +201,12 @@ pub async fn perform_predict(
                             let mut result = result.lock().unwrap();
                             result.generated_len = response.chars().count();
 
-                            if !options.use_expected_context {
-                                result.planning_search_time = Some(
-                                    search_queries_generated_at.unwrap() - start_time.unwrap(),
-                                );
-                                result.running_search_time = Some(
-                                    search_queries_executed_at.unwrap()
-                                        - search_queries_generated_at.unwrap(),
-                                );
-                            }
+                            result.planning_search_time =
+                                Some(search_queries_generated_at.unwrap() - start_time.unwrap());
+                            result.running_search_time = Some(
+                                search_queries_executed_at.unwrap()
+                                    - search_queries_generated_at.unwrap(),
+                            );
                             result.prediction_time = prediction_finished_at - prediction_started_at;
                             result.total_time = prediction_finished_at - start_time.unwrap();
 
@@ -224,37 +218,10 @@ pub async fn perform_predict(
             }
         });
 
-        if options.use_expected_context {
-            let context_excerpts_tasks = example
-                .example
-                .expected_context
-                .iter()
-                .flat_map(|section| {
-                    section.alternatives[0].excerpts.iter().map(|excerpt| {
-                        resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
-                    })
-                })
-                .collect::<Vec<_>>();
-            let context_excerpts_vec =
-                futures::future::try_join_all(context_excerpts_tasks).await?;
-
-            let mut context_excerpts = HashMap::default();
-            for (buffer, mut excerpts) in context_excerpts_vec {
-                context_excerpts
-                    .entry(buffer)
-                    .or_insert(Vec::new())
-                    .append(&mut excerpts);
-            }
-
-            zeta.update(cx, |zeta, _cx| {
-                zeta.set_context(project.clone(), context_excerpts)
-            })?;
-        } else {
-            zeta.update(cx, |zeta, cx| {
-                zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
-            })?
-            .await?;
-        }
+        zeta.update(cx, |zeta, cx| {
+            zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
+        })?
+        .await?;
     }
 
     let prediction = zeta
@@ -274,38 +241,6 @@ pub async fn perform_predict(
     anyhow::Ok(result)
 }
 
-async fn resolve_context_entry(
-    project: Entity<Project>,
-    excerpt: ExpectedExcerpt,
-    mut cx: AsyncApp,
-) -> Result<(Entity<Buffer>, Vec<Range<Anchor>>)> {
-    let buffer = project
-        .update(&mut cx, |project, cx| {
-            let project_path = project.find_project_path(&excerpt.path, cx).unwrap();
-            project.open_buffer(project_path, cx)
-        })?
-        .await?;
-
-    let ranges = buffer.read_with(&mut cx, |buffer, _| {
-        let full_text = buffer.text();
-        let offset = full_text
-            .find(&excerpt.text)
-            .expect("Expected context not found");
-        let point = buffer.offset_to_point(offset);
-        excerpt
-            .required_lines
-            .iter()
-            .map(|line| {
-                let row = point.row + line.0;
-                let range = Point::new(row, 0)..Point::new(row + 1, 0);
-                buffer.anchor_after(range.start)..buffer.anchor_before(range.end)
-            })
-            .collect()
-    })?;
-
-    Ok((buffer, ranges))
-}
-
 struct RunCache {
     cache_mode: CacheMode,
     example_run_dir: PathBuf,