Evaluate zeta2 context retrieval and edit predictions (#41921)

Oleksiy Syvokon , Piotr Osiewicz , and Agus Zubiaga created

This PR implements the `zeta-cli eval` command. It will:

- Run the edit prediction model if there are no cached results
- Compute precision/recall/F1 for context retrieval at the line level:
every retrieved line of context is counted as a true positive (correct
retrieval), false positive (retrieved something that was not expected),
or false negative (didn't retrieve an expected line)
- Compute similar metrics for edit predictions
- Pretty-print results, highlighting the difference between actual and
expected when printing to tty

Other changes:
- `zeta-cli predict` accepts a `--format` argument with options `md`,
`json`, `diff`
- Code restructure

Release Notes:

- N/A

---------

Co-authored-by: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com>
Co-authored-by: Agus Zubiaga <agus@zed.dev>

Change summary

crates/cloud_llm_client/src/udiff.rs |  36 ++
crates/zeta_cli/src/evaluate.rs      | 338 +++++++++++++++++++++++++++++
crates/zeta_cli/src/example.rs       |  48 +++
crates/zeta_cli/src/main.rs          | 297 ++++----------------------
crates/zeta_cli/src/predict.rs       | 287 +++++++++++++++++++++++++
5 files changed, 742 insertions(+), 264 deletions(-)

Detailed changes

crates/cloud_llm_client/src/udiff.rs 🔗

@@ -1,4 +1,4 @@
-use std::borrow::Cow;
+use std::{borrow::Cow, fmt::Display};
 
 #[derive(Debug, PartialEq)]
 pub enum DiffLine<'a> {
@@ -8,7 +8,7 @@ pub enum DiffLine<'a> {
     Context(&'a str),
     Deletion(&'a str),
     Addition(&'a str),
-    Garbage,
+    Garbage(&'a str),
 }
 
 #[derive(Debug, PartialEq)]
@@ -21,7 +21,7 @@ pub struct HunkLocation {
 
 impl<'a> DiffLine<'a> {
     pub fn parse(line: &'a str) -> Self {
-        Self::try_parse(line).unwrap_or(Self::Garbage)
+        Self::try_parse(line).unwrap_or(Self::Garbage(line))
     }
 
     fn try_parse(line: &'a str) -> Option<Self> {
@@ -60,6 +60,30 @@ impl<'a> DiffLine<'a> {
     }
 }
 
+impl<'a> Display for DiffLine<'a> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            DiffLine::OldPath { path } => write!(f, "--- {path}"),
+            DiffLine::NewPath { path } => write!(f, "+++ {path}"),
+            DiffLine::HunkHeader(Some(hunk_location)) => {
+                write!(
+                    f,
+                    "@@ -{},{} +{},{} @@",
+                    hunk_location.start_line_old + 1,
+                    hunk_location.count_old,
+                    hunk_location.start_line_new + 1,
+                    hunk_location.count_new
+                )
+            }
+            DiffLine::HunkHeader(None) => write!(f, "@@ ... @@"),
+            DiffLine::Context(content) => write!(f, " {content}"),
+            DiffLine::Deletion(content) => write!(f, "-{content}"),
+            DiffLine::Addition(content) => write!(f, "+{content}"),
+            DiffLine::Garbage(line) => write!(f, "{line}"),
+        }
+    }
+}
+
 fn parse_header_path<'a>(strip_prefix: &'static str, header: &'a str) -> Cow<'a, str> {
     if !header.contains(['"', '\\']) {
         let path = header.split_ascii_whitespace().next().unwrap_or(header);
@@ -134,8 +158,8 @@ mod tests {
         pretty_assertions::assert_eq!(
             lines,
             &[
-                DiffLine::Garbage,
-                DiffLine::Garbage,
+                DiffLine::Garbage("diff --git a/text.txt b/text.txt"),
+                DiffLine::Garbage("index 86c770d..a1fd855 100644"),
                 DiffLine::OldPath {
                     path: "file.txt".into()
                 },
@@ -151,7 +175,7 @@ mod tests {
                 DiffLine::Context("context"),
                 DiffLine::Deletion("deleted"),
                 DiffLine::Addition("inserted"),
-                DiffLine::Garbage,
+                DiffLine::Garbage("garbage"),
                 DiffLine::Context(""),
                 DiffLine::OldPath {
                     path: "b/file.txt".into()

crates/zeta_cli/src/evaluate.rs 🔗

@@ -0,0 +1,338 @@
+use std::{
+    fs,
+    io::IsTerminal,
+    path::{Path, PathBuf},
+    sync::Arc,
+};
+
+use anyhow::Result;
+use clap::Args;
+use cloud_llm_client::udiff::DiffLine;
+use collections::HashSet;
+use gpui::AsyncApp;
+
+use crate::{
+    example::{Example, NamedExample},
+    headless::ZetaCliAppState,
+    predict::{PredictionDetails, zeta2_predict},
+};
+
+#[derive(Debug, Args)]
+pub struct EvaluateArguments {
+    example_paths: Vec<PathBuf>,
+    #[clap(long)]
+    re_run: bool,
+}
+
+pub async fn run_evaluate(
+    args: EvaluateArguments,
+    app_state: &Arc<ZetaCliAppState>,
+    cx: &mut AsyncApp,
+) {
+    let example_len = args.example_paths.len();
+    let all_tasks = args.example_paths.into_iter().map(|path| {
+        let app_state = app_state.clone();
+        cx.spawn(async move |cx| run_evaluate_one(&path, args.re_run, app_state.clone(), cx).await)
+    });
+    let all_results = futures::future::try_join_all(all_tasks).await.unwrap();
+
+    let aggregated_result = EvaluationResult {
+        context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
+        edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
+    };
+
+    if example_len > 1 {
+        println!("\n{}", "-".repeat(80));
+        println!("# TOTAL SCORES:");
+        println!("{}", aggregated_result.to_markdown());
+    }
+}
+
+pub async fn run_evaluate_one(
+    example_path: &Path,
+    re_run: bool,
+    app_state: Arc<ZetaCliAppState>,
+    cx: &mut AsyncApp,
+) -> Result<EvaluationResult> {
+    let cache_dir = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap_or_default())
+        .join("../../target/zeta-prediction-cache");
+    let example = NamedExample::load(&example_path).unwrap();
+    let example_cache_path = cache_dir.join(&example_path.file_name().unwrap());
+
+    let predictions = if !re_run && example_cache_path.exists() {
+        let file_contents = fs::read_to_string(&example_cache_path)?;
+        let as_json = serde_json::from_str::<PredictionDetails>(&file_contents)?;
+        log::debug!(
+            "Loaded predictions from cache: {}",
+            example_cache_path.display()
+        );
+        as_json
+    } else {
+        zeta2_predict(example.clone(), &app_state, cx)
+            .await
+            .unwrap()
+    };
+
+    if !example_cache_path.exists() {
+        fs::create_dir_all(&cache_dir).unwrap();
+        fs::write(
+            example_cache_path,
+            serde_json::to_string(&predictions).unwrap(),
+        )
+        .unwrap();
+    }
+
+    let evaluation_result = evaluate(&example.example, &predictions);
+
+    println!("# {}\n", example.name);
+    println!(
+        "## Expected Context: \n\n```\n{}\n```\n\n",
+        compare_context(&example.example, &predictions)
+    );
+    println!(
+        "## Expected edit prediction:\n\n```diff\n{}\n```\n",
+        compare_diffs(&example.example.expected_patch, &predictions.diff)
+    );
+    println!(
+        "## Actual edit prediction:\n\n```diff\n{}\n```\n",
+        compare_diffs(&predictions.diff, &example.example.expected_patch)
+    );
+
+    println!("{}", evaluation_result.to_markdown());
+
+    anyhow::Ok(evaluation_result)
+}
+
+#[derive(Debug, Default)]
+pub struct EvaluationResult {
+    pub context: Scores,
+    pub edit_prediction: Scores,
+}
+
+#[derive(Default, Debug)]
+pub struct Scores {
+    pub precision: f64,
+    pub recall: f64,
+    pub f1_score: f64,
+    pub true_positives: usize,
+    pub false_positives: usize,
+    pub false_negatives: usize,
+}
+
+impl Scores {
+    pub fn to_markdown(&self) -> String {
+        format!(
+            "
+Precision       : {:.4}
+Recall          : {:.4}
+F1 Score        : {:.4}
+True Positives  : {}
+False Positives : {}
+False Negatives : {}",
+            self.precision,
+            self.recall,
+            self.f1_score,
+            self.true_positives,
+            self.false_positives,
+            self.false_negatives
+        )
+    }
+}
+
+impl Scores {
+    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
+        let mut true_positives = 0;
+        let mut false_positives = 0;
+        let mut false_negatives = 0;
+
+        for score in scores {
+            true_positives += score.true_positives;
+            false_positives += score.false_positives;
+            false_negatives += score.false_negatives;
+        }
+
+        let precision = true_positives as f64 / (true_positives + false_positives) as f64;
+        let recall = true_positives as f64 / (true_positives + false_negatives) as f64;
+        let mut f1_score = 2.0 * precision * recall / (precision + recall);
+        if f1_score.is_nan() {
+            f1_score = 0.0;
+        }
+
+        Scores {
+            precision,
+            recall,
+            f1_score,
+            true_positives,
+            false_positives,
+            false_negatives,
+        }
+    }
+}
+
+impl EvaluationResult {
+    pub fn to_markdown(&self) -> String {
+        format!(
+            r#"
+### Context Scores
+{}
+
+### Edit Prediction Scores
+{}
+"#,
+            self.context.to_markdown(),
+            self.edit_prediction.to_markdown()
+        )
+    }
+}
+
+pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
+    let mut result = EvaluationResult::default();
+
+    let expected_context_lines = example
+        .expected_excerpts
+        .iter()
+        .flat_map(|excerpt| {
+            excerpt
+                .text
+                .lines()
+                .map(|line| format!("{}: {line}", excerpt.path.display()))
+        })
+        .collect();
+    let actual_context_lines = preds
+        .excerpts
+        .iter()
+        .flat_map(|excerpt| {
+            excerpt
+                .text
+                .lines()
+                .map(|line| format!("{}: {line}", excerpt.path.display()))
+        })
+        .collect();
+
+    result.context = precision_recall(&expected_context_lines, &actual_context_lines);
+
+    let expected_patch_lines = example
+        .expected_patch
+        .lines()
+        .map(DiffLine::parse)
+        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
+        .map(|line| line.to_string())
+        .collect();
+
+    let actual_patch_lines = preds
+        .diff
+        .lines()
+        .map(DiffLine::parse)
+        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
+        .map(|line| line.to_string())
+        .collect();
+
+    result.edit_prediction = precision_recall(&expected_patch_lines, &actual_patch_lines);
+
+    result
+}
+
+fn precision_recall(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
+    let true_positives = expected.intersection(actual).count();
+    let false_positives = actual.difference(expected).count();
+    let false_negatives = expected.difference(actual).count();
+
+    let precision = if true_positives + false_positives == 0 {
+        0.0
+    } else {
+        true_positives as f64 / (true_positives + false_positives) as f64
+    };
+    let recall = if true_positives + false_negatives == 0 {
+        0.0
+    } else {
+        true_positives as f64 / (true_positives + false_negatives) as f64
+    };
+    let f1_score = if precision + recall == 0.0 {
+        0.0
+    } else {
+        2.0 * precision * recall / (precision + recall)
+    };
+
+    Scores {
+        precision,
+        recall,
+        f1_score,
+        true_positives,
+        false_positives,
+        false_negatives,
+    }
+}
+
+/// Compare actual and expected context.
+///
+/// Return expected context annotated with these markers:
+///
+/// `✓ context line`  -- line was correctly predicted
+/// `✗ context line`  -- line is missing from predictions
+pub fn compare_context(example: &Example, preds: &PredictionDetails) -> String {
+    let use_color = std::io::stdout().is_terminal();
+    let green = if use_color { "\x1b[32m" } else { "" };
+    let red = if use_color { "\x1b[31m" } else { "" };
+    let reset = if use_color { "\x1b[0m" } else { "" };
+    let expected: Vec<_> = example
+        .expected_excerpts
+        .iter()
+        .flat_map(|excerpt| {
+            excerpt
+                .text
+                .lines()
+                .map(|line| (excerpt.path.clone(), line))
+        })
+        .collect();
+    let actual: HashSet<_> = preds
+        .excerpts
+        .iter()
+        .flat_map(|excerpt| {
+            excerpt
+                .text
+                .lines()
+                .map(|line| (excerpt.path.clone(), line))
+        })
+        .collect();
+
+    let annotated = expected
+        .iter()
+        .map(|(path, line)| {
+            if actual.contains(&(path.to_path_buf(), line)) {
+                format!("{green}✓ {line}{reset}")
+            } else {
+                format!("{red}✗ {line}{reset}")
+            }
+        })
+        .collect::<Vec<String>>();
+
+    annotated.join("\n")
+}
+
+/// Return annotated `patch_a` so that:
+/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
+/// Additions and deletions that are present in `patch_b` will be highlighted in green.
+pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
+    let use_color = std::io::stdout().is_terminal();
+    let green = if use_color { "\x1b[32m✓ " } else { "" };
+    let red = if use_color { "\x1b[31m✗ " } else { "" };
+    let neutral = if use_color { "  " } else { "" };
+    let reset = if use_color { "\x1b[0m" } else { "" };
+    let lines_a = patch_a.lines().map(DiffLine::parse);
+    let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
+
+    let annotated = lines_a
+        .map(|line| match line {
+            DiffLine::Addition(_) | DiffLine::Deletion(_) => {
+                if lines_b.contains(&line) {
+                    format!("{green}{line}{reset}")
+                } else {
+                    format!("{red}{line}{reset}")
+                }
+            }
+            _ => format!("{neutral}{line}{reset}"),
+        })
+        .collect::<Vec<String>>();
+
+    annotated.join("\n")
+}

crates/zeta_cli/src/example.rs 🔗

@@ -1,5 +1,6 @@
 use std::{
     borrow::Cow,
+    cell::RefCell,
     env,
     fmt::{self, Display},
     fs,
@@ -7,12 +8,16 @@ use std::{
     mem,
     ops::Range,
     path::{Path, PathBuf},
+    sync::Arc,
 };
 
 use anyhow::{Context as _, Result};
 use clap::ValueEnum;
-use collections::HashSet;
-use futures::AsyncWriteExt as _;
+use collections::{HashMap, HashSet};
+use futures::{
+    AsyncWriteExt as _,
+    lock::{Mutex, OwnedMutexGuard},
+};
 use gpui::{AsyncApp, Entity, http_client::Url};
 use language::Buffer;
 use project::{Project, ProjectPath};
@@ -27,13 +32,13 @@ const EXPECTED_EXCERPTS_HEADING: &str = "Expected Excerpts";
 const REPOSITORY_URL_FIELD: &str = "repository_url";
 const REVISION_FIELD: &str = "revision";
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct NamedExample {
     pub name: String,
     pub example: Example,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct Example {
     pub repository_url: String,
     pub revision: String,
@@ -45,10 +50,13 @@ pub struct Example {
     pub expected_excerpts: Vec<ExpectedExcerpt>,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
-pub struct ExpectedExcerpt {
-    path: PathBuf,
-    text: String,
+pub type ExpectedExcerpt = Excerpt;
+pub type ActualExcerpt = Excerpt;
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct Excerpt {
+    pub path: PathBuf,
+    pub text: String,
 }
 
 #[derive(ValueEnum, Debug, Clone)]
@@ -171,6 +179,7 @@ impl NamedExample {
                     } else if current_section.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
                         named.example.expected_patch = mem::take(&mut text);
                     } else if current_section.eq_ignore_ascii_case(EXPECTED_EXCERPTS_HEADING) {
+                        // TODO: "…" should not be a part of the excerpt
                         named.example.expected_excerpts.push(ExpectedExcerpt {
                             path: block_info.into(),
                             text: mem::take(&mut text),
@@ -202,7 +211,6 @@ impl NamedExample {
         }
     }
 
-    #[allow(unused)]
     pub async fn setup_worktree(&self) -> Result<PathBuf> {
         let (repo_owner, repo_name) = self.repo_name()?;
         let file_name = self.file_name();
@@ -213,6 +221,8 @@ impl NamedExample {
         fs::create_dir_all(&worktrees_dir)?;
 
         let repo_dir = repos_dir.join(repo_owner.as_ref()).join(repo_name.as_ref());
+        let repo_lock = lock_repo(&repo_dir).await;
+
         if !repo_dir.is_dir() {
             fs::create_dir_all(&repo_dir)?;
             run_git(&repo_dir, &["init"]).await?;
@@ -255,6 +265,7 @@ impl NamedExample {
             )
             .await?;
         }
+        drop(repo_lock);
 
         // Apply the uncommitted diff for this example.
         if !self.example.uncommitted_diff.is_empty() {
@@ -419,6 +430,23 @@ impl Display for NamedExample {
     }
 }
 
+thread_local! {
+    static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
+}
+
+#[must_use]
+pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
+    REPO_LOCKS
+        .with(|cell| {
+            cell.borrow_mut()
+                .entry(path.as_ref().to_path_buf())
+                .or_default()
+                .clone()
+        })
+        .lock_owned()
+        .await
+}
+
 #[must_use]
 pub async fn apply_diff(
     diff: &str,
@@ -487,7 +515,7 @@ pub async fn apply_diff(
                     });
                 }
             }
-            DiffLine::HunkHeader(_) | DiffLine::Garbage => {}
+            DiffLine::HunkHeader(_) | DiffLine::Garbage(_) => {}
         }
 
         let at_hunk_end = match diff_lines.peek() {

crates/zeta_cli/src/main.rs 🔗

@@ -1,14 +1,17 @@
+mod evaluate;
 mod example;
 mod headless;
+mod predict;
 mod source_location;
 mod syntax_retrieval_stats;
 mod util;
 
+use crate::evaluate::{EvaluateArguments, run_evaluate};
 use crate::example::{ExampleFormat, NamedExample};
+use crate::predict::{PredictArguments, run_zeta2_predict};
 use crate::syntax_retrieval_stats::retrieval_stats;
 use ::serde::Serialize;
 use ::util::paths::PathStyle;
-use ::util::rel_path::RelPath;
 use anyhow::{Context as _, Result, anyhow};
 use clap::{Args, Parser, Subcommand};
 use cloud_llm_client::predict_edits_v3::{self, Excerpt};
@@ -22,12 +25,12 @@ use futures::channel::mpsc;
 use gpui::{Application, AsyncApp, Entity, prelude::*};
 use language::{Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point};
 use language_model::LanguageModelRegistry;
-use project::{Project, ProjectPath, Worktree};
+use project::{Project, Worktree};
 use reqwest_client::ReqwestClient;
 use serde_json::json;
-use std::io;
-use std::time::{Duration, Instant};
-use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc};
+use std::io::{self};
+use std::time::Duration;
+use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
 use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery};
 
 use crate::headless::ZetaCliAppState;
@@ -82,9 +85,8 @@ enum Zeta2Command {
         #[command(subcommand)]
         command: Zeta2LlmCommand,
     },
-    Predict {
-        example_path: PathBuf,
-    },
+    Predict(PredictArguments),
+    Eval(EvaluateArguments),
 }
 
 #[derive(Subcommand, Debug)]
@@ -327,208 +329,6 @@ async fn load_context(
     })
 }
 
-async fn zeta2_predict(
-    example: NamedExample,
-    app_state: &Arc<ZetaCliAppState>,
-    cx: &mut AsyncApp,
-) -> Result<()> {
-    let worktree_path = example.setup_worktree().await?;
-
-    cx.update(|cx| {
-        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
-            registry
-                .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
-                .unwrap()
-                .authenticate(cx)
-        })
-    })?
-    .await?;
-
-    app_state
-        .client
-        .sign_in_with_optional_connect(true, cx)
-        .await?;
-
-    let project = cx.update(|cx| {
-        Project::local(
-            app_state.client.clone(),
-            app_state.node_runtime.clone(),
-            app_state.user_store.clone(),
-            app_state.languages.clone(),
-            app_state.fs.clone(),
-            None,
-            cx,
-        )
-    })?;
-
-    let worktree = project
-        .update(cx, |project, cx| {
-            project.create_worktree(&worktree_path, true, cx)
-        })?
-        .await?;
-    worktree
-        .read_with(cx, |worktree, _cx| {
-            worktree.as_local().unwrap().scan_complete()
-        })?
-        .await;
-
-    let _edited_buffers = example.apply_edit_history(&project, cx).await?;
-
-    let cursor_path = RelPath::new(&example.example.cursor_path, PathStyle::Posix)?.into_arc();
-
-    let cursor_buffer = project
-        .update(cx, |project, cx| {
-            project.open_buffer(
-                ProjectPath {
-                    worktree_id: worktree.read(cx).id(),
-                    path: cursor_path,
-                },
-                cx,
-            )
-        })?
-        .await?;
-
-    let cursor_offset_within_excerpt = example
-        .example
-        .cursor_position
-        .find(CURSOR_MARKER)
-        .ok_or_else(|| anyhow!("missing cursor marker"))?;
-    let mut cursor_excerpt = example.example.cursor_position.clone();
-    cursor_excerpt.replace_range(
-        cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
-        "",
-    );
-    let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
-        let text = buffer.text();
-
-        let mut matches = text.match_indices(&cursor_excerpt);
-        let Some((excerpt_offset, _)) = matches.next() else {
-            anyhow::bail!(
-                "Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n"
-            );
-        };
-        assert!(matches.next().is_none());
-
-        Ok(excerpt_offset)
-    })??;
-
-    let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
-    let cursor_anchor =
-        cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
-
-    let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
-
-    let refresh_task = zeta.update(cx, |zeta, cx| {
-        zeta.register_buffer(&cursor_buffer, &project, cx);
-        zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
-    })?;
-
-    let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
-    let mut context_retrieval_started_at = None;
-    let mut context_retrieval_finished_at = None;
-    let mut search_queries_generated_at = None;
-    let mut search_queries_executed_at = None;
-    let mut prediction_started_at = None;
-    let mut prediction_finished_at = None;
-    let mut excerpts_text = String::new();
-    let mut prediction_task = None;
-    while let Some(event) = debug_rx.next().await {
-        match event {
-            zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
-                context_retrieval_started_at = Some(info.timestamp);
-            }
-            zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
-                search_queries_generated_at = Some(info.timestamp);
-            }
-            zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
-                search_queries_executed_at = Some(info.timestamp);
-            }
-            zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
-                context_retrieval_finished_at = Some(info.timestamp);
-
-                prediction_task = Some(zeta.update(cx, |zeta, cx| {
-                    zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
-                })?);
-            }
-            zeta2::ZetaDebugInfo::EditPredicted(request) => {
-                prediction_started_at = Some(Instant::now());
-                request.response_rx.await?.map_err(|err| anyhow!(err))?;
-                prediction_finished_at = Some(Instant::now());
-
-                for included_file in request.request.included_files {
-                    let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)];
-                    write_codeblock(
-                        &included_file.path,
-                        included_file.excerpts.iter(),
-                        if included_file.path == request.request.excerpt_path {
-                            &insertions
-                        } else {
-                            &[]
-                        },
-                        included_file.max_row,
-                        false,
-                        &mut excerpts_text,
-                    );
-                }
-                break;
-            }
-            _ => {}
-        }
-    }
-
-    refresh_task.await.context("context retrieval failed")?;
-    let prediction = prediction_task.unwrap().await?.context("No prediction")?;
-
-    println!("## Excerpts\n");
-    println!("{excerpts_text}");
-
-    let old_text = prediction.snapshot.text();
-    let new_text = prediction.buffer.update(cx, |buffer, cx| {
-        buffer.edit(prediction.edits.iter().cloned(), None, cx);
-        buffer.text()
-    })?;
-    let diff = language::unified_diff(&old_text, &new_text);
-
-    println!("## Prediction\n");
-    println!("{diff}");
-
-    println!("## Time\n");
-
-    let planning_search_time =
-        search_queries_generated_at.unwrap() - context_retrieval_started_at.unwrap();
-
-    println!("Planning searches: {}ms", planning_search_time.as_millis());
-    println!(
-        "Running searches: {}ms",
-        (search_queries_executed_at.unwrap() - search_queries_generated_at.unwrap()).as_millis()
-    );
-
-    let filtering_search_time =
-        context_retrieval_finished_at.unwrap() - search_queries_executed_at.unwrap();
-    println!(
-        "Filtering context results: {}ms",
-        filtering_search_time.as_millis()
-    );
-
-    let prediction_time = prediction_finished_at.unwrap() - prediction_started_at.unwrap();
-    println!("Making Prediction: {}ms", prediction_time.as_millis());
-
-    println!("-------------------");
-    let total_time =
-        (prediction_finished_at.unwrap() - context_retrieval_started_at.unwrap()).as_millis();
-    println!("Total: {}ms", total_time);
-
-    let inference_time =
-        (planning_search_time + filtering_search_time + prediction_time).as_millis();
-    println!(
-        "Inference: {}ms ({:.2}%)",
-        inference_time,
-        (inference_time as f64 / total_time as f64) * 100.
-    );
-
-    anyhow::Ok(())
-}
-
 async fn zeta2_syntax_context(
     zeta2_args: Zeta2Args,
     syntax_args: Zeta2SyntaxArgs,
@@ -819,50 +619,62 @@ fn main() {
     app.run(move |cx| {
         let app_state = Arc::new(headless::init(cx));
         cx.spawn(async move |cx| {
-            let result = match args.command {
+            match args.command {
                 Command::Zeta1 {
                     command: Zeta1Command::Context { context_args },
                 } => {
                     let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
-                    serde_json::to_string_pretty(&context.body).map_err(|err| anyhow::anyhow!(err))
+                    let result = serde_json::to_string_pretty(&context.body).unwrap();
+                    println!("{}", result);
                 }
                 Command::Zeta2 { command } => match command {
-                    Zeta2Command::Predict { example_path } => {
-                        let example = NamedExample::load(example_path).unwrap();
-                        zeta2_predict(example, &app_state, cx).await.unwrap();
-                        let _ = cx.update(|cx| cx.quit());
-                        return;
+                    Zeta2Command::Predict(arguments) => {
+                        run_zeta2_predict(arguments, &app_state, cx).await;
+                    }
+                    Zeta2Command::Eval(arguments) => {
+                        run_evaluate(arguments, &app_state, cx).await;
                     }
                     Zeta2Command::Syntax {
                         args,
                         syntax_args,
                         command,
-                    } => match command {
-                        Zeta2SyntaxCommand::Context { context_args } => {
-                            zeta2_syntax_context(args, syntax_args, context_args, &app_state, cx)
+                    } => {
+                        let result = match command {
+                            Zeta2SyntaxCommand::Context { context_args } => {
+                                zeta2_syntax_context(
+                                    args,
+                                    syntax_args,
+                                    context_args,
+                                    &app_state,
+                                    cx,
+                                )
                                 .await
-                        }
-                        Zeta2SyntaxCommand::Stats {
-                            worktree,
-                            extension,
-                            limit,
-                            skip,
-                        } => {
-                            retrieval_stats(
+                            }
+                            Zeta2SyntaxCommand::Stats {
                                 worktree,
-                                app_state,
                                 extension,
                                 limit,
                                 skip,
-                                syntax_args_to_options(&args, &syntax_args, false),
-                                cx,
-                            )
-                            .await
-                        }
-                    },
+                            } => {
+                                retrieval_stats(
+                                    worktree,
+                                    app_state,
+                                    extension,
+                                    limit,
+                                    skip,
+                                    syntax_args_to_options(&args, &syntax_args, false),
+                                    cx,
+                                )
+                                .await
+                            }
+                        };
+                        println!("{}", result.unwrap());
+                    }
                     Zeta2Command::Llm { args, command } => match command {
                         Zeta2LlmCommand::Context { context_args } => {
-                            zeta2_llm_context(args, context_args, &app_state, cx).await
+                            let result =
+                                zeta2_llm_context(args, context_args, &app_state, cx).await;
+                            println!("{}", result.unwrap());
                         }
                     },
                 },
@@ -872,21 +684,10 @@ fn main() {
                 } => {
                     let example = NamedExample::load(path).unwrap();
                     example.write(output_format, io::stdout()).unwrap();
-                    let _ = cx.update(|cx| cx.quit());
-                    return;
                 }
             };
 
-            match result {
-                Ok(output) => {
-                    println!("{}", output);
-                    let _ = cx.update(|cx| cx.quit());
-                }
-                Err(e) => {
-                    eprintln!("Failed: {:?}", e);
-                    exit(1);
-                }
-            }
+            let _ = cx.update(|cx| cx.quit());
         })
         .detach();
     });

crates/zeta_cli/src/predict.rs 🔗

@@ -0,0 +1,287 @@
+use crate::example::{ActualExcerpt, NamedExample};
+
+use crate::headless::ZetaCliAppState;
+use ::serde::Serialize;
+use ::util::paths::PathStyle;
+use anyhow::{Context as _, Result, anyhow};
+use clap::Args;
+use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
+use futures::StreamExt as _;
+use gpui::AsyncApp;
+use language_model::LanguageModelRegistry;
+use project::{Project, ProjectPath};
+use serde::Deserialize;
+use std::cell::Cell;
+use std::io::Write;
+use std::path::PathBuf;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use util::rel_path::RelPath;
+
+#[derive(Debug, Args)]
+pub struct PredictArguments {
+    example_path: PathBuf,
+    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
+    format: PredictionsOutputFormat,
+}
+
+#[derive(clap::ValueEnum, Debug, Clone)]
+pub enum PredictionsOutputFormat {
+    Json,
+    Md,
+    Diff,
+}
+pub async fn run_zeta2_predict(
+    args: PredictArguments,
+    app_state: &Arc<ZetaCliAppState>,
+    cx: &mut AsyncApp,
+) {
+    let example = NamedExample::load(args.example_path).unwrap();
+    let result = zeta2_predict(example, &app_state, cx).await.unwrap();
+    result.write(args.format, std::io::stdout()).unwrap();
+}
+
+thread_local! {
+    static AUTHENTICATED: Cell<bool> = const { Cell::new(false) };
+}
+
+pub async fn zeta2_predict(
+    example: NamedExample,
+    app_state: &Arc<ZetaCliAppState>,
+    cx: &mut AsyncApp,
+) -> Result<PredictionDetails> {
+    let worktree_path = example.setup_worktree().await?;
+
+    if !AUTHENTICATED.get() {
+        AUTHENTICATED.set(true);
+
+        cx.update(|cx| {
+            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+                registry
+                    .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
+                    .unwrap()
+                    .authenticate(cx)
+            })
+        })?
+        .await?;
+
+        app_state
+            .client
+            .sign_in_with_optional_connect(true, cx)
+            .await?;
+    }
+
+    let project = cx.update(|cx| {
+        Project::local(
+            app_state.client.clone(),
+            app_state.node_runtime.clone(),
+            app_state.user_store.clone(),
+            app_state.languages.clone(),
+            app_state.fs.clone(),
+            None,
+            cx,
+        )
+    })?;
+
+    let worktree = project
+        .update(cx, |project, cx| {
+            project.create_worktree(&worktree_path, true, cx)
+        })?
+        .await?;
+    worktree
+        .read_with(cx, |worktree, _cx| {
+            worktree.as_local().unwrap().scan_complete()
+        })?
+        .await;
+
+    let _edited_buffers = example.apply_edit_history(&project, cx).await?;
+
+    let cursor_path = RelPath::new(&example.example.cursor_path, PathStyle::Posix)?.into_arc();
+
+    let cursor_buffer = project
+        .update(cx, |project, cx| {
+            project.open_buffer(
+                ProjectPath {
+                    worktree_id: worktree.read(cx).id(),
+                    path: cursor_path,
+                },
+                cx,
+            )
+        })?
+        .await?;
+
+    let cursor_offset_within_excerpt = example
+        .example
+        .cursor_position
+        .find(CURSOR_MARKER)
+        .ok_or_else(|| anyhow!("missing cursor marker"))?;
+    let mut cursor_excerpt = example.example.cursor_position.clone();
+    cursor_excerpt.replace_range(
+        cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
+        "",
+    );
+    let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
+        let text = buffer.text();
+
+        let mut matches = text.match_indices(&cursor_excerpt);
+        let Some((excerpt_offset, _)) = matches.next() else {
+            anyhow::bail!(
+                "Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n"
+            );
+        };
+        assert!(matches.next().is_none());
+
+        Ok(excerpt_offset)
+    })??;
+
+    let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
+    let cursor_anchor =
+        cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
+
+    let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
+
+    let refresh_task = zeta.update(cx, |zeta, cx| {
+        zeta.register_buffer(&cursor_buffer, &project, cx);
+        zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
+    })?;
+
+    let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
+    let mut context_retrieval_started_at = None;
+    let mut context_retrieval_finished_at = None;
+    let mut search_queries_generated_at = None;
+    let mut search_queries_executed_at = None;
+    let mut prediction_started_at = None;
+    let mut prediction_finished_at = None;
+    let mut excerpts_text = String::new();
+    let mut prediction_task = None;
+    let mut result = PredictionDetails::default();
+    while let Some(event) = debug_rx.next().await {
+        match event {
+            zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
+                context_retrieval_started_at = Some(info.timestamp);
+            }
+            zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
+                search_queries_generated_at = Some(info.timestamp);
+            }
+            zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
+                search_queries_executed_at = Some(info.timestamp);
+            }
+            zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
+                context_retrieval_finished_at = Some(info.timestamp);
+
+                prediction_task = Some(zeta.update(cx, |zeta, cx| {
+                    zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
+                })?);
+            }
+            zeta2::ZetaDebugInfo::EditPredicted(request) => {
+                prediction_started_at = Some(Instant::now());
+                request.response_rx.await?.map_err(|err| anyhow!(err))?;
+                prediction_finished_at = Some(Instant::now());
+
+                for included_file in request.request.included_files {
+                    let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)];
+                    result
+                        .excerpts
+                        .extend(included_file.excerpts.iter().map(|excerpt| ActualExcerpt {
+                            path: included_file.path.components().skip(1).collect(),
+                            text: String::from(excerpt.text.as_ref()),
+                        }));
+                    write_codeblock(
+                        &included_file.path,
+                        included_file.excerpts.iter(),
+                        if included_file.path == request.request.excerpt_path {
+                            &insertions
+                        } else {
+                            &[]
+                        },
+                        included_file.max_row,
+                        false,
+                        &mut excerpts_text,
+                    );
+                }
+                break;
+            }
+            _ => {}
+        }
+    }
+
+    refresh_task.await.context("context retrieval failed")?;
+    let prediction = prediction_task.unwrap().await?;
+
+    result.diff = prediction
+        .map(|prediction| {
+            let old_text = prediction.snapshot.text();
+            let new_text = prediction.buffer.update(cx, |buffer, cx| {
+                buffer.edit(prediction.edits.iter().cloned(), None, cx);
+                buffer.text()
+            })?;
+            anyhow::Ok(language::unified_diff(&old_text, &new_text))
+        })
+        .transpose()?
+        .unwrap_or_default();
+    result.excerpts_text = excerpts_text;
+
+    result.planning_search_time =
+        search_queries_generated_at.unwrap() - context_retrieval_started_at.unwrap();
+    result.running_search_time =
+        search_queries_executed_at.unwrap() - search_queries_generated_at.unwrap();
+    result.filtering_search_time =
+        context_retrieval_finished_at.unwrap() - search_queries_executed_at.unwrap();
+    result.prediction_time = prediction_finished_at.unwrap() - prediction_started_at.unwrap();
+    result.total_time = prediction_finished_at.unwrap() - context_retrieval_started_at.unwrap();
+
+    anyhow::Ok(result)
+}
+
+#[derive(Debug, Default, Serialize, Deserialize)]
+pub struct PredictionDetails {
+    pub diff: String,
+    pub excerpts: Vec<ActualExcerpt>,
+    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
+    pub planning_search_time: Duration,
+    pub filtering_search_time: Duration,
+    pub running_search_time: Duration,
+    pub prediction_time: Duration,
+    pub total_time: Duration,
+}
+
+impl PredictionDetails {
+    pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
+        let formatted = match format {
+            PredictionsOutputFormat::Md => self.to_markdown(),
+            PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
+            PredictionsOutputFormat::Diff => self.diff.clone(),
+        };
+
+        Ok(out.write_all(formatted.as_bytes())?)
+    }
+
+    pub fn to_markdown(&self) -> String {
+        let inference_time =
+            self.planning_search_time + self.filtering_search_time + self.prediction_time;
+
+        format!(
+            "## Excerpts\n\n\
+            {}\n\n\
+            ## Prediction\n\n\
+            {}\n\n\
+            ## Time\n\n\
+            Planning searches: {}ms\n\
+            Running searches: {}ms\n\
+            Filtering context results: {}ms\n\
+            Making Prediction: {}ms\n\n\
+            -------------------\n\n\
+            Total: {}ms\n\
+            Inference: {}ms ({:.2}%)\n",
+            self.excerpts_text,
+            self.diff,
+            self.planning_search_time.as_millis(),
+            self.running_search_time.as_millis(),
+            self.filtering_search_time.as_millis(),
+            self.prediction_time.as_millis(),
+            self.total_time.as_millis(),
+            inference_time.as_millis(),
+            (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
+        )
+    }
+}