Add zeta-cli subcommand for running zeta2 predictions (#41722)

Max Brunsfeld , Agus Zubiaga , Piotr Osiewicz , and Ben Kunkle created

This PR adds a `zeta zeta2 predict` subcommand that takes an edit
prediction example markdown file as an argument, and performs zeta2's
prediction, showing the retrieved context and the predicted edit.

* [x] Apply uncommitted diff to get repo into the right state.
* [x] Apply edits in edit history
* [x] Display predicted edits as unified diff, regardless of model
output format

Release Notes:

- N/A

---------

Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com>
Co-authored-by: Ben Kunkle <ben.kunkle@gmail.com>

Change summary

Cargo.lock                                      |   2 
crates/cloud_llm_client/src/cloud_llm_client.rs |   1 
crates/cloud_llm_client/src/udiff.rs            | 270 +++++++++
crates/zeta2/src/related_excerpts.rs            |   3 
crates/zeta2/src/zeta2.rs                       | 149 ++--
crates/zeta2_tools/src/zeta2_tools.rs           |   2 
crates/zeta_cli/Cargo.toml                      |   9 
crates/zeta_cli/src/example.rs                  | 540 +++++++++++++++++-
crates/zeta_cli/src/main.rs                     | 227 +++++++
9 files changed, 1,091 insertions(+), 112 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -21741,6 +21741,7 @@ dependencies = [
  "futures 0.3.31",
  "gpui",
  "gpui_tokio",
+ "indoc",
  "language",
  "language_extension",
  "language_model",
@@ -21751,6 +21752,7 @@ dependencies = [
  "ordered-float 2.10.1",
  "paths",
  "polars",
+ "pretty_assertions",
  "project",
  "prompt_store",
  "pulldown-cmark 0.12.2",

crates/cloud_llm_client/src/udiff.rs 🔗

@@ -0,0 +1,270 @@
+use std::borrow::Cow;
+
+#[derive(Debug, PartialEq)]
+pub enum DiffLine<'a> {
+    OldPath { path: Cow<'a, str> },
+    NewPath { path: Cow<'a, str> },
+    HunkHeader(Option<HunkLocation>),
+    Context(&'a str),
+    Deletion(&'a str),
+    Addition(&'a str),
+    Garbage,
+}
+
+#[derive(Debug, PartialEq)]
+pub struct HunkLocation {
+    start_line_old: u32,
+    count_old: u32,
+    start_line_new: u32,
+    count_new: u32,
+}
+
+impl<'a> DiffLine<'a> {
+    pub fn parse(line: &'a str) -> Self {
+        Self::try_parse(line).unwrap_or(Self::Garbage)
+    }
+
+    fn try_parse(line: &'a str) -> Option<Self> {
+        if let Some(header) = line.strip_prefix("---").and_then(eat_required_whitespace) {
+            let path = parse_header_path("a/", header);
+            Some(Self::OldPath { path })
+        } else if let Some(header) = line.strip_prefix("+++").and_then(eat_required_whitespace) {
+            Some(Self::NewPath {
+                path: parse_header_path("b/", header),
+            })
+        } else if let Some(header) = line.strip_prefix("@@").and_then(eat_required_whitespace) {
+            if header.starts_with("...") {
+                return Some(Self::HunkHeader(None));
+            }
+
+            let (start_line_old, header) = header.strip_prefix('-')?.split_once(',')?;
+            let mut parts = header.split_ascii_whitespace();
+            let count_old = parts.next()?;
+            let (start_line_new, count_new) = parts.next()?.strip_prefix('+')?.split_once(',')?;
+
+            Some(Self::HunkHeader(Some(HunkLocation {
+                start_line_old: start_line_old.parse::<u32>().ok()?.saturating_sub(1),
+                count_old: count_old.parse().ok()?,
+                start_line_new: start_line_new.parse::<u32>().ok()?.saturating_sub(1),
+                count_new: count_new.parse().ok()?,
+            })))
+        } else if let Some(deleted_header) = line.strip_prefix("-") {
+            Some(Self::Deletion(deleted_header))
+        } else if line.is_empty() {
+            Some(Self::Context(""))
+        } else if let Some(context) = line.strip_prefix(" ") {
+            Some(Self::Context(context))
+        } else {
+            Some(Self::Addition(line.strip_prefix("+")?))
+        }
+    }
+}
+
+fn parse_header_path<'a>(strip_prefix: &'static str, header: &'a str) -> Cow<'a, str> {
+    if !header.contains(['"', '\\']) {
+        let path = header.split_ascii_whitespace().next().unwrap_or(header);
+        return Cow::Borrowed(path.strip_prefix(strip_prefix).unwrap_or(path));
+    }
+
+    let mut path = String::with_capacity(header.len());
+    let mut in_quote = false;
+    let mut chars = header.chars().peekable();
+    let mut strip_prefix = Some(strip_prefix);
+
+    while let Some(char) = chars.next() {
+        if char == '"' {
+            in_quote = !in_quote;
+        } else if char == '\\' {
+            let Some(&next_char) = chars.peek() else {
+                break;
+            };
+            chars.next();
+            path.push(next_char);
+        } else if char.is_ascii_whitespace() && !in_quote {
+            break;
+        } else {
+            path.push(char);
+        }
+
+        if let Some(prefix) = strip_prefix
+            && path == prefix
+        {
+            strip_prefix.take();
+            path.clear();
+        }
+    }
+
+    Cow::Owned(path)
+}
+
+fn eat_required_whitespace(header: &str) -> Option<&str> {
+    let trimmed = header.trim_ascii_start();
+
+    if trimmed.len() == header.len() {
+        None
+    } else {
+        Some(trimmed)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use indoc::indoc;
+
+    #[test]
+    fn parse_lines_simple() {
+        let input = indoc! {"
+            diff --git a/text.txt b/text.txt
+            index 86c770d..a1fd855 100644
+            --- a/file.txt
+            +++ b/file.txt
+            @@ -1,2 +1,3 @@
+             context
+            -deleted
+            +inserted
+            garbage
+
+            --- b/file.txt
+            +++ a/file.txt
+        "};
+
+        let lines = input.lines().map(DiffLine::parse).collect::<Vec<_>>();
+
+        pretty_assertions::assert_eq!(
+            lines,
+            &[
+                DiffLine::Garbage,
+                DiffLine::Garbage,
+                DiffLine::OldPath {
+                    path: "file.txt".into()
+                },
+                DiffLine::NewPath {
+                    path: "file.txt".into()
+                },
+                DiffLine::HunkHeader(Some(HunkLocation {
+                    start_line_old: 0,
+                    count_old: 2,
+                    start_line_new: 0,
+                    count_new: 3
+                })),
+                DiffLine::Context("context"),
+                DiffLine::Deletion("deleted"),
+                DiffLine::Addition("inserted"),
+                DiffLine::Garbage,
+                DiffLine::Context(""),
+                DiffLine::OldPath {
+                    path: "b/file.txt".into()
+                },
+                DiffLine::NewPath {
+                    path: "a/file.txt".into()
+                },
+            ]
+        );
+    }
+
+    #[test]
+    fn file_header_extra_space() {
+        let options = ["--- file", "---   file", "---\tfile"];
+
+        for option in options {
+            pretty_assertions::assert_eq!(
+                DiffLine::parse(option),
+                DiffLine::OldPath {
+                    path: "file".into()
+                },
+                "{option}",
+            );
+        }
+    }
+
+    #[test]
+    fn hunk_header_extra_space() {
+        let options = [
+            "@@ -1,2 +1,3 @@",
+            "@@  -1,2  +1,3 @@",
+            "@@\t-1,2\t+1,3\t@@",
+            "@@ -1,2  +1,3 @@",
+            "@@ -1,2   +1,3 @@",
+            "@@ -1,2 +1,3   @@",
+            "@@ -1,2 +1,3 @@ garbage",
+        ];
+
+        for option in options {
+            pretty_assertions::assert_eq!(
+                DiffLine::parse(option),
+                DiffLine::HunkHeader(Some(HunkLocation {
+                    start_line_old: 0,
+                    count_old: 2,
+                    start_line_new: 0,
+                    count_new: 3
+                })),
+                "{option}",
+            );
+        }
+    }
+
+    #[test]
+    fn hunk_header_without_location() {
+        pretty_assertions::assert_eq!(DiffLine::parse("@@ ... @@"), DiffLine::HunkHeader(None));
+    }
+
+    #[test]
+    fn test_parse_path() {
+        assert_eq!(parse_header_path("a/", "foo.txt"), "foo.txt");
+        assert_eq!(
+            parse_header_path("a/", "foo/bar/baz.txt"),
+            "foo/bar/baz.txt"
+        );
+        assert_eq!(parse_header_path("a/", "a/foo.txt"), "foo.txt");
+        assert_eq!(
+            parse_header_path("a/", "a/foo/bar/baz.txt"),
+            "foo/bar/baz.txt"
+        );
+
+        // Extra
+        assert_eq!(
+            parse_header_path("a/", "a/foo/bar/baz.txt  2025"),
+            "foo/bar/baz.txt"
+        );
+        assert_eq!(
+            parse_header_path("a/", "a/foo/bar/baz.txt\t2025"),
+            "foo/bar/baz.txt"
+        );
+        assert_eq!(
+            parse_header_path("a/", "a/foo/bar/baz.txt \""),
+            "foo/bar/baz.txt"
+        );
+
+        // Quoted
+        assert_eq!(
+            parse_header_path("a/", "a/foo/bar/\"baz quox.txt\""),
+            "foo/bar/baz quox.txt"
+        );
+        assert_eq!(
+            parse_header_path("a/", "\"a/foo/bar/baz quox.txt\""),
+            "foo/bar/baz quox.txt"
+        );
+        assert_eq!(
+            parse_header_path("a/", "\"foo/bar/baz quox.txt\""),
+            "foo/bar/baz quox.txt"
+        );
+        assert_eq!(parse_header_path("a/", "\"whatever 🤷\""), "whatever 🤷");
+        assert_eq!(
+            parse_header_path("a/", "\"foo/bar/baz quox.txt\"  2025"),
+            "foo/bar/baz quox.txt"
+        );
+        // unescaped quotes are dropped
+        assert_eq!(parse_header_path("a/", "foo/\"bar\""), "foo/bar");
+
+        // Escaped
+        assert_eq!(
+            parse_header_path("a/", "\"foo/\\\"bar\\\"/baz.txt\""),
+            "foo/\"bar\"/baz.txt"
+        );
+        assert_eq!(
+            parse_header_path("a/", "\"C:\\\\Projects\\\\My App\\\\old file.txt\""),
+            "C:\\Projects\\My App\\old file.txt"
+        );
+    }
+}

crates/zeta2/src/related_excerpts.rs 🔗

@@ -149,6 +149,9 @@ pub fn find_related_excerpts(
         .find(|model| {
             model.provider_id() == MODEL_PROVIDER_ID
                 && model.id() == LanguageModelId("claude-haiku-4-5-latest".into())
+            // model.provider_id() == LanguageModelProviderId::new("zeta-ctx-qwen-30b")
+            // model.provider_id() == LanguageModelProviderId::new("ollama")
+            //     && model.id() == LanguageModelId("gpt-oss:20b".into())
         })
     else {
         return Task::ready(Err(anyhow!("could not find context model")));

crates/zeta2/src/zeta2.rs 🔗

@@ -35,8 +35,8 @@ use std::str::FromStr as _;
 use std::sync::Arc;
 use std::time::{Duration, Instant};
 use thiserror::Error;
-use util::ResultExt as _;
 use util::rel_path::RelPathBuf;
+use util::{LogErrorFuture, TryFutureExt};
 use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 
 pub mod merge_excerpts;
@@ -50,8 +50,6 @@ use crate::related_excerpts::find_related_excerpts;
 pub use crate::related_excerpts::{LlmContextOptions, SearchToolQuery};
 pub use provider::ZetaEditPredictionProvider;
 
-const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
-
 /// Maximum number of events to track.
 const MAX_EVENT_COUNT: usize = 16;
 
@@ -83,6 +81,7 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
     max_diagnostic_bytes: 2048,
     prompt_format: PromptFormat::DEFAULT,
     file_indexing_parallelism: 1,
+    buffer_change_grouping_interval: Duration::from_secs(1),
 };
 
 pub struct Zeta2FeatureFlag;
@@ -118,6 +117,7 @@ pub struct ZetaOptions {
     pub max_diagnostic_bytes: usize,
     pub prompt_format: predict_edits_v3::PromptFormat,
     pub file_indexing_parallelism: usize,
+    pub buffer_change_grouping_interval: Duration,
 }
 
 #[derive(Debug, Clone, PartialEq)]
@@ -135,6 +135,7 @@ impl ContextMode {
     }
 }
 
+#[derive(Debug)]
 pub enum ZetaDebugInfo {
     ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
     SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
@@ -144,17 +145,20 @@ pub enum ZetaDebugInfo {
     EditPredicted(ZetaEditPredictionDebugInfo),
 }
 
+#[derive(Debug)]
 pub struct ZetaContextRetrievalStartedDebugInfo {
     pub project: Entity<Project>,
     pub timestamp: Instant,
     pub search_prompt: String,
 }
 
+#[derive(Debug)]
 pub struct ZetaContextRetrievalDebugInfo {
     pub project: Entity<Project>,
     pub timestamp: Instant,
 }
 
+#[derive(Debug)]
 pub struct ZetaEditPredictionDebugInfo {
     pub request: predict_edits_v3::PredictEditsRequest,
     pub retrieval_time: TimeDelta,
@@ -164,6 +168,7 @@ pub struct ZetaEditPredictionDebugInfo {
     pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
 }
 
+#[derive(Debug)]
 pub struct ZetaSearchQueryDebugInfo {
     pub project: Entity<Project>,
     pub timestamp: Instant,
@@ -178,7 +183,7 @@ struct ZetaProject {
     registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
     current_prediction: Option<CurrentEditPrediction>,
     context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
-    refresh_context_task: Option<Task<Option<()>>>,
+    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
     refresh_context_debounce_task: Option<Task<Option<()>>>,
     refresh_context_timestamp: Option<Instant>,
 }
@@ -460,6 +465,7 @@ impl Zeta {
         project: &Entity<Project>,
         cx: &mut Context<Self>,
     ) -> BufferSnapshot {
+        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
         let zeta_project = self.get_or_init_zeta_project(project, cx);
         let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 
@@ -469,6 +475,7 @@ impl Zeta {
                 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
             Self::push_event(
                 zeta_project,
+                buffer_change_grouping_interval,
                 Event::BufferChange {
                     old_snapshot,
                     new_snapshot: new_snapshot.clone(),
@@ -480,14 +487,19 @@ impl Zeta {
         new_snapshot
     }
 
-    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
+    fn push_event(
+        zeta_project: &mut ZetaProject,
+        buffer_change_grouping_interval: Duration,
+        event: Event,
+    ) {
         let events = &mut zeta_project.events;
 
-        if let Some(Event::BufferChange {
-            new_snapshot: last_new_snapshot,
-            timestamp: last_timestamp,
-            ..
-        }) = events.back_mut()
+        if buffer_change_grouping_interval > Duration::ZERO
+            && let Some(Event::BufferChange {
+                new_snapshot: last_new_snapshot,
+                timestamp: last_timestamp,
+                ..
+            }) = events.back_mut()
         {
             // Coalesce edits for the same buffer when they happen one after the other.
             let Event::BufferChange {
@@ -496,7 +508,7 @@ impl Zeta {
                 timestamp,
             } = &event;
 
-            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
+            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
                 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                 && old_snapshot.version == last_new_snapshot.version
             {
@@ -624,7 +636,7 @@ impl Zeta {
         })
     }
 
-    fn request_prediction(
+    pub fn request_prediction(
         &mut self,
         project: &Entity<Project>,
         buffer: &Entity<Buffer>,
@@ -1068,7 +1080,11 @@ impl Zeta {
                     log::debug!("refetching edit prediction context after pause");
                 }
                 this.update(cx, |this, cx| {
-                    this.refresh_context(project, buffer, cursor_position, cx);
+                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);
+
+                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
+                        zeta_project.refresh_context_task = Some(task.log_err());
+                    };
                 })
                 .ok()
             }
@@ -1077,73 +1093,68 @@ impl Zeta {
 
     // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
     // and avoid spawning more than one concurrent task.
-    fn refresh_context(
+    pub fn refresh_context(
         &mut self,
         project: Entity<Project>,
         buffer: Entity<language::Buffer>,
         cursor_position: language::Anchor,
         cx: &mut Context<Self>,
-    ) {
-        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
-            return;
-        };
-
-        let debug_tx = self.debug_tx.clone();
-
-        zeta_project
-            .refresh_context_task
-            .get_or_insert(cx.spawn(async move |this, cx| {
-                let related_excerpts = this
-                    .update(cx, |this, cx| {
-                        let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
-                            return Task::ready(anyhow::Ok(HashMap::default()));
-                        };
+    ) -> Task<Result<()>> {
+        cx.spawn(async move |this, cx| {
+            let related_excerpts_result = this
+                .update(cx, |this, cx| {
+                    let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
+                        return Task::ready(anyhow::Ok(HashMap::default()));
+                    };
 
-                        let ContextMode::Llm(options) = &this.options().context else {
-                            return Task::ready(anyhow::Ok(HashMap::default()));
-                        };
+                    let ContextMode::Llm(options) = &this.options().context else {
+                        return Task::ready(anyhow::Ok(HashMap::default()));
+                    };
 
-                        let mut edit_history_unified_diff = String::new();
+                    let mut edit_history_unified_diff = String::new();
 
-                        for event in zeta_project.events.iter() {
-                            if let Some(event) = event.to_request_event(cx) {
-                                writeln!(&mut edit_history_unified_diff, "{event}").ok();
-                            }
+                    for event in zeta_project.events.iter() {
+                        if let Some(event) = event.to_request_event(cx) {
+                            writeln!(&mut edit_history_unified_diff, "{event}").ok();
                         }
+                    }
 
-                        find_related_excerpts(
-                            buffer.clone(),
-                            cursor_position,
-                            &project,
-                            edit_history_unified_diff,
-                            options,
-                            debug_tx,
-                            cx,
-                        )
-                    })
-                    .ok()?
-                    .await
-                    .log_err()
-                    .unwrap_or_default();
-                this.update(cx, |this, _cx| {
-                    let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
-                        return;
-                    };
-                    zeta_project.context = Some(related_excerpts);
-                    zeta_project.refresh_context_task.take();
-                    if let Some(debug_tx) = &this.debug_tx {
-                        debug_tx
-                            .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
-                                ZetaContextRetrievalDebugInfo {
-                                    project,
-                                    timestamp: Instant::now(),
-                                },
-                            ))
-                            .ok();
+                    find_related_excerpts(
+                        buffer.clone(),
+                        cursor_position,
+                        &project,
+                        edit_history_unified_diff,
+                        options,
+                        this.debug_tx.clone(),
+                        cx,
+                    )
+                })?
+                .await;
+
+            this.update(cx, |this, _cx| {
+                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
+                    return Ok(());
+                };
+                zeta_project.refresh_context_task.take();
+                if let Some(debug_tx) = &this.debug_tx {
+                    debug_tx
+                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
+                            ZetaContextRetrievalDebugInfo {
+                                project,
+                                timestamp: Instant::now(),
+                            },
+                        ))
+                        .ok();
+                }
+                match related_excerpts_result {
+                    Ok(excerpts) => {
+                        zeta_project.context = Some(excerpts);
+                        Ok(())
                     }
-                })
-                .ok()
-            }));
+                    Err(error) => Err(error),
+                }
+            })?
+        })
     }
 
     fn gather_nearby_diagnostics(

crates/zeta2_tools/src/zeta2_tools.rs 🔗

@@ -335,6 +335,8 @@ impl Zeta2Inspector {
                         max_diagnostic_bytes: zeta_options.max_diagnostic_bytes,
                         prompt_format: zeta_options.prompt_format,
                         file_indexing_parallelism: zeta_options.file_indexing_parallelism,
+                        buffer_change_grouping_interval: zeta_options
+                            .buffer_change_grouping_interval,
                     },
                     cx,
                 );

crates/zeta_cli/Cargo.toml 🔗

@@ -13,6 +13,7 @@ name = "zeta"
 path = "src/main.rs"
 
 [dependencies]
+
 anyhow.workspace = true
 chrono.workspace = true
 clap.workspace = true
@@ -42,7 +43,6 @@ prompt_store.workspace = true
 pulldown-cmark.workspace = true
 release_channel.workspace = true
 reqwest_client.workspace = true
-toml.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 settings.workspace = true
@@ -50,8 +50,15 @@ shellexpand.workspace = true
 smol.workspace = true
 soa-rs = "0.8.1"
 terminal_view.workspace = true
+toml.workspace = true
 util.workspace = true
 watch.workspace = true
 zeta.workspace = true
 zeta2.workspace = true
 zlog.workspace = true
+
+[dev-dependencies]
+indoc.workspace = true
+gpui = { workspace = true, features = ["test-support"] }
+project = { workspace = true, features = ["test-support"] }
+pretty_assertions.workspace = true

crates/zeta_cli/src/example.rs 🔗

@@ -5,17 +5,23 @@ use std::{
     fs,
     io::Write,
     mem,
+    ops::Range,
     path::{Path, PathBuf},
 };
 
 use anyhow::{Context as _, Result};
 use clap::ValueEnum;
-use gpui::http_client::Url;
+use collections::HashSet;
+use futures::AsyncWriteExt as _;
+use gpui::{AsyncApp, Entity, http_client::Url};
+use language::Buffer;
+use project::{Project, ProjectPath};
 use pulldown_cmark::CowStr;
 use serde::{Deserialize, Serialize};
 
-const CURSOR_POSITION_HEADING: &str = "Cursor Position";
+const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
 const EDIT_HISTORY_HEADING: &str = "Edit History";
+const CURSOR_POSITION_HEADING: &str = "Cursor Position";
 const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
 const EXPECTED_EXCERPTS_HEADING: &str = "Expected Excerpts";
 const REPOSITORY_URL_FIELD: &str = "repository_url";
@@ -31,9 +37,10 @@ pub struct NamedExample {
 pub struct Example {
     pub repository_url: String,
     pub revision: String,
+    pub uncommitted_diff: String,
     pub cursor_path: PathBuf,
     pub cursor_position: String,
-    pub edit_history: Vec<String>,
+    pub edit_history: String,
     pub expected_patch: String,
     pub expected_excerpts: Vec<ExpectedExcerpt>,
 }
@@ -59,11 +66,11 @@ impl NamedExample {
 
         match ext.and_then(|s| s.to_str()) {
             Some("json") => Ok(Self {
-                name: path.file_name().unwrap_or_default().display().to_string(),
+                name: path.file_stem().unwrap_or_default().display().to_string(),
                 example: serde_json::from_str(&content)?,
             }),
             Some("toml") => Ok(Self {
-                name: path.file_name().unwrap_or_default().display().to_string(),
+                name: path.file_stem().unwrap_or_default().display().to_string(),
                 example: toml::from_str(&content)?,
             }),
             Some("md") => Self::parse_md(&content),
@@ -88,9 +95,10 @@ impl NamedExample {
             example: Example {
                 repository_url: String::new(),
                 revision: String::new(),
+                uncommitted_diff: String::new(),
                 cursor_path: PathBuf::new(),
                 cursor_position: String::new(),
-                edit_history: Vec::new(),
+                edit_history: String::new(),
                 expected_patch: String::new(),
                 expected_excerpts: Vec::new(),
             },
@@ -152,18 +160,19 @@ impl NamedExample {
                     block_info = "".into();
                 }
                 Event::End(TagEnd::CodeBlock) => {
-                    if current_section.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
-                        named.example.edit_history.push(mem::take(&mut text));
+                    let block_info = block_info.trim();
+                    if current_section.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
+                        named.example.uncommitted_diff = mem::take(&mut text);
+                    } else if current_section.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
+                        named.example.edit_history.push_str(&mem::take(&mut text));
                     } else if current_section.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
-                        let path = PathBuf::from(block_info.trim());
-                        named.example.cursor_path = path;
+                        named.example.cursor_path = block_info.into();
                         named.example.cursor_position = mem::take(&mut text);
                     } else if current_section.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
                         named.example.expected_patch = mem::take(&mut text);
                     } else if current_section.eq_ignore_ascii_case(EXPECTED_EXCERPTS_HEADING) {
-                        let path = PathBuf::from(block_info.trim());
                         named.example.expected_excerpts.push(ExpectedExcerpt {
-                            path,
+                            path: block_info.into(),
                             text: mem::take(&mut text),
                         });
                     } else {
@@ -195,13 +204,14 @@ impl NamedExample {
 
     #[allow(unused)]
     pub async fn setup_worktree(&self) -> Result<PathBuf> {
+        let (repo_owner, repo_name) = self.repo_name()?;
+        let file_name = self.file_name();
+
         let worktrees_dir = env::current_dir()?.join("target").join("zeta-worktrees");
         let repos_dir = env::current_dir()?.join("target").join("zeta-repos");
         fs::create_dir_all(&repos_dir)?;
         fs::create_dir_all(&worktrees_dir)?;
 
-        let (repo_owner, repo_name) = self.repo_name()?;
-
         let repo_dir = repos_dir.join(repo_owner.as_ref()).join(repo_name.as_ref());
         if !repo_dir.is_dir() {
             fs::create_dir_all(&repo_dir)?;
@@ -213,36 +223,81 @@ impl NamedExample {
             .await?;
         }
 
-        run_git(
-            &repo_dir,
-            &["fetch", "--depth", "1", "origin", &self.example.revision],
-        )
-        .await?;
-
-        let worktree_path = worktrees_dir.join(&self.name);
+        // Resolve the example to a revision, fetching it if needed.
+        let revision = run_git(&repo_dir, &["rev-parse", &self.example.revision]).await;
+        let revision = if let Ok(revision) = revision {
+            revision
+        } else {
+            run_git(
+                &repo_dir,
+                &["fetch", "--depth", "1", "origin", &self.example.revision],
+            )
+            .await?;
+            let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
+            if revision != self.example.revision {
+                run_git(&repo_dir, &["tag", &self.example.revision, &revision]).await?;
+            }
+            revision
+        };
 
+        // Create the worktree for this example if needed.
+        let worktree_path = worktrees_dir.join(&file_name);
         if worktree_path.is_dir() {
             run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
             run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
-            run_git(&worktree_path, &["checkout", &self.example.revision]).await?;
+            run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
         } else {
             let worktree_path_string = worktree_path.to_string_lossy();
+            run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
             run_git(
                 &repo_dir,
-                &[
-                    "worktree",
-                    "add",
-                    "-f",
-                    &worktree_path_string,
-                    &self.example.revision,
-                ],
+                &["worktree", "add", "-f", &worktree_path_string, &file_name],
             )
             .await?;
         }
 
+        // Apply the uncommitted diff for this example.
+        if !self.example.uncommitted_diff.is_empty() {
+            let mut apply_process = smol::process::Command::new("git")
+                .current_dir(&worktree_path)
+                .args(&["apply", "-"])
+                .stdin(std::process::Stdio::piped())
+                .spawn()?;
+
+            let mut stdin = apply_process.stdin.take().unwrap();
+            stdin
+                .write_all(self.example.uncommitted_diff.as_bytes())
+                .await?;
+            stdin.close().await?;
+            drop(stdin);
+
+            let apply_result = apply_process.output().await?;
+            if !apply_result.status.success() {
+                anyhow::bail!(
+                    "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
+                    apply_result.status,
+                    String::from_utf8_lossy(&apply_result.stderr),
+                    String::from_utf8_lossy(&apply_result.stdout),
+                );
+            }
+        }
+
         Ok(worktree_path)
     }
 
+    fn file_name(&self) -> String {
+        self.name
+            .chars()
+            .map(|c| {
+                if c.is_whitespace() {
+                    '-'
+                } else {
+                    c.to_ascii_lowercase()
+                }
+            })
+            .collect()
+    }
+
     #[allow(unused)]
     fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
         // git@github.com:owner/repo.git
@@ -277,6 +332,15 @@ impl NamedExample {
             Ok((owner.into(), repo.into()))
         }
     }
+
+    #[must_use]
+    pub async fn apply_edit_history(
+        &self,
+        project: &Entity<Project>,
+        cx: &mut AsyncApp,
+    ) -> Result<HashSet<Entity<Buffer>>> {
+        apply_diff(&self.example.edit_history, project, cx).await
+    }
 }
 
 async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
@@ -308,6 +372,15 @@ impl Display for NamedExample {
         )?;
         write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
 
+        write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
+        write!(f, "`````diff\n")?;
+        write!(f, "{}", self.example.uncommitted_diff)?;
+        write!(f, "`````\n")?;
+
+        if !self.example.edit_history.is_empty() {
+            write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
+        }
+
         write!(
             f,
             "## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
@@ -316,14 +389,6 @@ impl Display for NamedExample {
         )?;
         write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
 
-        if !self.example.edit_history.is_empty() {
-            write!(f, "`````diff\n")?;
-            for item in &self.example.edit_history {
-                write!(f, "{item}")?;
-            }
-            write!(f, "`````\n")?;
-        }
-
         if !self.example.expected_patch.is_empty() {
             write!(
                 f,
@@ -353,3 +418,404 @@ impl Display for NamedExample {
         Ok(())
     }
 }
+
+#[must_use]
+pub async fn apply_diff(
+    diff: &str,
+    project: &Entity<Project>,
+    cx: &mut AsyncApp,
+) -> Result<HashSet<Entity<Buffer>>> {
+    use cloud_llm_client::udiff::DiffLine;
+    use std::fmt::Write;
+
+    #[derive(Debug, Default)]
+    struct HunkState {
+        context: String,
+        edits: Vec<Edit>,
+    }
+
+    #[derive(Debug)]
+    struct Edit {
+        range: Range<usize>,
+        text: String,
+    }
+
+    let mut old_path = None;
+    let mut new_path = None;
+    let mut hunk = HunkState::default();
+    let mut diff_lines = diff.lines().map(DiffLine::parse).peekable();
+    let mut open_buffers = HashSet::default();
+
+    while let Some(diff_line) = diff_lines.next() {
+        match diff_line {
+            DiffLine::OldPath { path } => old_path = Some(path),
+            DiffLine::NewPath { path } => {
+                if old_path.is_none() {
+                    anyhow::bail!(
+                        "Found a new path header (`+++`) before an (`---`) old path header"
+                    );
+                }
+                new_path = Some(path)
+            }
+            DiffLine::Context(ctx) => {
+                writeln!(&mut hunk.context, "{ctx}")?;
+            }
+            DiffLine::Deletion(del) => {
+                let range = hunk.context.len()..hunk.context.len() + del.len() + '\n'.len_utf8();
+                if let Some(last_edit) = hunk.edits.last_mut()
+                    && last_edit.range.end == range.start
+                {
+                    last_edit.range.end = range.end;
+                } else {
+                    hunk.edits.push(Edit {
+                        range,
+                        text: String::new(),
+                    });
+                }
+                writeln!(&mut hunk.context, "{del}")?;
+            }
+            DiffLine::Addition(add) => {
+                let range = hunk.context.len()..hunk.context.len();
+                if let Some(last_edit) = hunk.edits.last_mut()
+                    && last_edit.range.end == range.start
+                {
+                    writeln!(&mut last_edit.text, "{add}").unwrap();
+                } else {
+                    hunk.edits.push(Edit {
+                        range,
+                        text: format!("{add}\n"),
+                    });
+                }
+            }
+            DiffLine::HunkHeader(_) | DiffLine::Garbage => {}
+        }
+
+        let at_hunk_end = match diff_lines.peek() {
+            Some(DiffLine::OldPath { .. }) | Some(DiffLine::HunkHeader(_)) | None => true,
+            _ => false,
+        };
+
+        if at_hunk_end {
+            let hunk = mem::take(&mut hunk);
+
+            let Some(old_path) = old_path.as_deref() else {
+                anyhow::bail!("Missing old path (`---`) header")
+            };
+
+            let Some(new_path) = new_path.as_deref() else {
+                anyhow::bail!("Missing new path (`+++`) header")
+            };
+
+            let buffer = project
+                .update(cx, |project, cx| {
+                    let project_path = project
+                        .find_project_path(old_path, cx)
+                        .context("Failed to find old_path in project")?;
+
+                    anyhow::Ok(project.open_buffer(project_path, cx))
+                })??
+                .await?;
+            open_buffers.insert(buffer.clone());
+
+            if old_path != new_path {
+                project
+                    .update(cx, |project, cx| {
+                        let project_file = project::File::from_dyn(buffer.read(cx).file()).unwrap();
+                        let new_path = ProjectPath {
+                            worktree_id: project_file.worktree_id(cx),
+                            path: project_file.path.clone(),
+                        };
+                        project.rename_entry(project_file.entry_id.unwrap(), new_path, cx)
+                    })?
+                    .await?;
+            }
+
+            // TODO is it worth using project search?
+            buffer.update(cx, |buffer, cx| {
+                let context_offset = if hunk.context.is_empty() {
+                    0
+                } else {
+                    let text = buffer.text();
+                    if let Some(offset) = text.find(&hunk.context) {
+                        if text[offset + 1..].contains(&hunk.context) {
+                            anyhow::bail!("Context is not unique enough:\n{}", hunk.context);
+                        }
+                        offset
+                    } else {
+                        anyhow::bail!(
+                            "Failed to match context:\n{}\n\nBuffer:\n{}",
+                            hunk.context,
+                            text
+                        );
+                    }
+                };
+
+                buffer.edit(
+                    hunk.edits.into_iter().map(|edit| {
+                        (
+                            context_offset + edit.range.start..context_offset + edit.range.end,
+                            edit.text,
+                        )
+                    }),
+                    None,
+                    cx,
+                );
+
+                anyhow::Ok(())
+            })??;
+        }
+    }
+
+    anyhow::Ok(open_buffers)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use ::fs::FakeFs;
+    use gpui::TestAppContext;
+    use indoc::indoc;
+    use pretty_assertions::assert_eq;
+    use project::Project;
+    use serde_json::json;
+    use settings::SettingsStore;
+    use util::path;
+
+    #[gpui::test]
+    async fn test_apply_diff_successful(cx: &mut TestAppContext) {
+        let buffer_1_text = indoc! {r#"
+            one
+            two
+            three
+            four
+            five
+        "# };
+
+        let buffer_1_text_final = indoc! {r#"
+            3
+            4
+            5
+        "# };
+
+        let buffer_2_text = indoc! {r#"
+            six
+            seven
+            eight
+            nine
+            ten
+        "# };
+
+        let buffer_2_text_final = indoc! {r#"
+            5
+            six
+            seven
+            7.5
+            eight
+            nine
+            ten
+            11
+        "# };
+
+        cx.update(|cx| {
+            let settings_store = SettingsStore::test(cx);
+            cx.set_global(settings_store);
+            Project::init_settings(cx);
+            language::init(cx);
+        });
+
+        let fs = FakeFs::new(cx.background_executor.clone());
+        fs.insert_tree(
+            path!("/root"),
+            json!({
+                "file1": buffer_1_text,
+                "file2": buffer_2_text,
+            }),
+        )
+        .await;
+
+        let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
+
+        let diff = indoc! {r#"
+            --- a/root/file1
+            +++ b/root/file1
+             one
+             two
+            -three
+            +3
+             four
+             five
+            --- a/root/file1
+            +++ b/root/file1
+             3
+            -four
+            -five
+            +4
+            +5
+            --- a/root/file1
+            +++ b/root/file1
+            -one
+            -two
+             3
+             4
+            --- a/root/file2
+            +++ b/root/file2
+            +5
+             six
+            --- a/root/file2
+            +++ b/root/file2
+             seven
+            +7.5
+             eight
+            --- a/root/file2
+            +++ b/root/file2
+             ten
+            +11
+        "#};
+
+        let _buffers = apply_diff(diff, &project, &mut cx.to_async())
+            .await
+            .unwrap();
+        let buffer_1 = project
+            .update(cx, |project, cx| {
+                let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap();
+                project.open_buffer(project_path, cx)
+            })
+            .await
+            .unwrap();
+
+        buffer_1.read_with(cx, |buffer, _cx| {
+            assert_eq!(buffer.text(), buffer_1_text_final);
+        });
+        let buffer_2 = project
+            .update(cx, |project, cx| {
+                let project_path = project.find_project_path(path!("/root/file2"), cx).unwrap();
+                project.open_buffer(project_path, cx)
+            })
+            .await
+            .unwrap();
+
+        buffer_2.read_with(cx, |buffer, _cx| {
+            assert_eq!(buffer.text(), buffer_2_text_final);
+        });
+    }
+
+    #[gpui::test]
+    async fn test_apply_diff_non_unique(cx: &mut TestAppContext) {
+        let buffer_1_text = indoc! {r#"
+            one
+            two
+            three
+            four
+            five
+            one
+            two
+            three
+            four
+            five
+        "# };
+
+        cx.update(|cx| {
+            let settings_store = SettingsStore::test(cx);
+            cx.set_global(settings_store);
+            Project::init_settings(cx);
+            language::init(cx);
+        });
+
+        let fs = FakeFs::new(cx.background_executor.clone());
+        fs.insert_tree(
+            path!("/root"),
+            json!({
+                "file1": buffer_1_text,
+            }),
+        )
+        .await;
+
+        let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
+
+        let diff = indoc! {r#"
+            --- a/root/file1
+            +++ b/root/file1
+             one
+             two
+            -three
+            +3
+             four
+             five
+        "#};
+
+        apply_diff(diff, &project, &mut cx.to_async())
+            .await
+            .expect_err("Non-unique edits should fail");
+    }
+
+    #[gpui::test]
+    async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) {
+        let start = indoc! {r#"
+            one
+            two
+            three
+            four
+            five
+
+            four
+            five
+        "# };
+
+        let end = indoc! {r#"
+            one
+            two
+            3
+            four
+            5
+
+            four
+            five
+        "# };
+
+        cx.update(|cx| {
+            let settings_store = SettingsStore::test(cx);
+            cx.set_global(settings_store);
+            Project::init_settings(cx);
+            language::init(cx);
+        });
+
+        let fs = FakeFs::new(cx.background_executor.clone());
+        fs.insert_tree(
+            path!("/root"),
+            json!({
+                "file1": start,
+            }),
+        )
+        .await;
+
+        let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
+
+        let diff = indoc! {r#"
+            --- a/root/file1
+            +++ b/root/file1
+             one
+             two
+            -three
+            +3
+             four
+            -five
+            +5
+        "#};
+
+        let _buffers = apply_diff(diff, &project, &mut cx.to_async())
+            .await
+            .unwrap();
+
+        let buffer_1 = project
+            .update(cx, |project, cx| {
+                let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap();
+                project.open_buffer(project_path, cx)
+            })
+            .await
+            .unwrap();
+
+        buffer_1.read_with(cx, |buffer, _cx| {
+            assert_eq!(buffer.text(), end);
+        });
+    }
+}

crates/zeta_cli/src/main.rs 🔗

@@ -8,6 +8,7 @@ use crate::example::{ExampleFormat, NamedExample};
 use crate::syntax_retrieval_stats::retrieval_stats;
 use ::serde::Serialize;
 use ::util::paths::PathStyle;
+use ::util::rel_path::RelPath;
 use anyhow::{Context as _, Result, anyhow};
 use clap::{Args, Parser, Subcommand};
 use cloud_llm_client::predict_edits_v3::{self, Excerpt};
@@ -21,10 +22,11 @@ use futures::channel::mpsc;
 use gpui::{Application, AsyncApp, Entity, prelude::*};
 use language::{Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point};
 use language_model::LanguageModelRegistry;
-use project::{Project, Worktree};
+use project::{Project, ProjectPath, Worktree};
 use reqwest_client::ReqwestClient;
 use serde_json::json;
 use std::io;
+use std::time::{Duration, Instant};
 use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc};
 use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery};
 
@@ -46,8 +48,6 @@ enum Command {
         command: Zeta1Command,
     },
     Zeta2 {
-        #[clap(flatten)]
-        args: Zeta2Args,
         #[command(subcommand)]
         command: Zeta2Command,
     },
@@ -69,15 +69,22 @@ enum Zeta1Command {
 #[derive(Subcommand, Debug)]
 enum Zeta2Command {
     Syntax {
+        #[clap(flatten)]
+        args: Zeta2Args,
         #[clap(flatten)]
         syntax_args: Zeta2SyntaxArgs,
         #[command(subcommand)]
         command: Zeta2SyntaxCommand,
     },
     Llm {
+        #[clap(flatten)]
+        args: Zeta2Args,
         #[command(subcommand)]
         command: Zeta2LlmCommand,
     },
+    Predict {
+        example_path: PathBuf,
+    },
 }
 
 #[derive(Subcommand, Debug)]
@@ -170,6 +177,7 @@ fn syntax_args_to_options(
         max_prompt_bytes: zeta2_args.max_prompt_bytes,
         prompt_format: zeta2_args.prompt_format.clone().into(),
         file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
+        buffer_change_grouping_interval: Duration::ZERO,
     }
 }
 
@@ -319,6 +327,208 @@ async fn load_context(
     })
 }
 
+async fn zeta2_predict(
+    example: NamedExample,
+    app_state: &Arc<ZetaCliAppState>,
+    cx: &mut AsyncApp,
+) -> Result<()> {
+    let worktree_path = example.setup_worktree().await?;
+
+    cx.update(|cx| {
+        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+            registry
+                .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
+                .unwrap()
+                .authenticate(cx)
+        })
+    })?
+    .await?;
+
+    app_state
+        .client
+        .sign_in_with_optional_connect(true, cx)
+        .await?;
+
+    let project = cx.update(|cx| {
+        Project::local(
+            app_state.client.clone(),
+            app_state.node_runtime.clone(),
+            app_state.user_store.clone(),
+            app_state.languages.clone(),
+            app_state.fs.clone(),
+            None,
+            cx,
+        )
+    })?;
+
+    let worktree = project
+        .update(cx, |project, cx| {
+            project.create_worktree(&worktree_path, true, cx)
+        })?
+        .await?;
+    worktree
+        .read_with(cx, |worktree, _cx| {
+            worktree.as_local().unwrap().scan_complete()
+        })?
+        .await;
+
+    let _edited_buffers = example.apply_edit_history(&project, cx).await?;
+
+    let cursor_path = RelPath::new(&example.example.cursor_path, PathStyle::Posix)?.into_arc();
+
+    let cursor_buffer = project
+        .update(cx, |project, cx| {
+            project.open_buffer(
+                ProjectPath {
+                    worktree_id: worktree.read(cx).id(),
+                    path: cursor_path,
+                },
+                cx,
+            )
+        })?
+        .await?;
+
+    let cursor_offset_within_excerpt = example
+        .example
+        .cursor_position
+        .find(CURSOR_MARKER)
+        .ok_or_else(|| anyhow!("missing cursor marker"))?;
+    let mut cursor_excerpt = example.example.cursor_position.clone();
+    cursor_excerpt.replace_range(
+        cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
+        "",
+    );
+    let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
+        let text = buffer.text();
+
+        let mut matches = text.match_indices(&cursor_excerpt);
+        let Some((excerpt_offset, _)) = matches.next() else {
+            anyhow::bail!(
+                "Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n"
+            );
+        };
+        assert!(matches.next().is_none());
+
+        Ok(excerpt_offset)
+    })??;
+
+    let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
+    let cursor_anchor =
+        cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
+
+    let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
+
+    let refresh_task = zeta.update(cx, |zeta, cx| {
+        zeta.register_buffer(&cursor_buffer, &project, cx);
+        zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
+    })?;
+
+    let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
+    let mut context_retrieval_started_at = None;
+    let mut context_retrieval_finished_at = None;
+    let mut search_queries_generated_at = None;
+    let mut search_queries_executed_at = None;
+    let mut prediction_started_at = None;
+    let mut prediction_finished_at = None;
+    let mut excerpts_text = String::new();
+    let mut prediction_task = None;
+    while let Some(event) = debug_rx.next().await {
+        match event {
+            zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
+                context_retrieval_started_at = Some(info.timestamp);
+            }
+            zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
+                search_queries_generated_at = Some(info.timestamp);
+            }
+            zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
+                search_queries_executed_at = Some(info.timestamp);
+            }
+            zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
+                context_retrieval_finished_at = Some(info.timestamp);
+
+                prediction_task = Some(zeta.update(cx, |zeta, cx| {
+                    zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
+                })?);
+            }
+            zeta2::ZetaDebugInfo::EditPredicted(request) => {
+                prediction_started_at = Some(Instant::now());
+                request.response_rx.await?.map_err(|err| anyhow!(err))?;
+                prediction_finished_at = Some(Instant::now());
+
+                for included_file in request.request.included_files {
+                    let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)];
+                    write_codeblock(
+                        &included_file.path,
+                        included_file.excerpts.iter(),
+                        if included_file.path == request.request.excerpt_path {
+                            &insertions
+                        } else {
+                            &[]
+                        },
+                        included_file.max_row,
+                        false,
+                        &mut excerpts_text,
+                    );
+                }
+                break;
+            }
+            _ => {}
+        }
+    }
+
+    refresh_task.await.context("context retrieval failed")?;
+    let prediction = prediction_task.unwrap().await?.context("No prediction")?;
+
+    println!("## Excerpts\n");
+    println!("{excerpts_text}");
+
+    let old_text = prediction.snapshot.text();
+    let new_text = prediction.buffer.update(cx, |buffer, cx| {
+        buffer.edit(prediction.edits.iter().cloned(), None, cx);
+        buffer.text()
+    })?;
+    let diff = language::unified_diff(&old_text, &new_text);
+
+    println!("## Prediction\n");
+    println!("{diff}");
+
+    println!("## Time\n");
+
+    let planning_search_time =
+        search_queries_generated_at.unwrap() - context_retrieval_started_at.unwrap();
+
+    println!("Planning searches: {}ms", planning_search_time.as_millis());
+    println!(
+        "Running searches: {}ms",
+        (search_queries_executed_at.unwrap() - search_queries_generated_at.unwrap()).as_millis()
+    );
+
+    let filtering_search_time =
+        context_retrieval_finished_at.unwrap() - search_queries_executed_at.unwrap();
+    println!(
+        "Filtering context results: {}ms",
+        filtering_search_time.as_millis()
+    );
+
+    let prediction_time = prediction_finished_at.unwrap() - prediction_started_at.unwrap();
+    println!("Making Prediction: {}ms", prediction_time.as_millis());
+
+    println!("-------------------");
+    let total_time =
+        (prediction_finished_at.unwrap() - context_retrieval_started_at.unwrap()).as_millis();
+    println!("Total: {}ms", total_time);
+
+    let inference_time =
+        (planning_search_time + filtering_search_time + prediction_time).as_millis();
+    println!(
+        "Inference: {}ms ({:.2}%)",
+        inference_time,
+        (inference_time as f64 / total_time as f64) * 100.
+    );
+
+    anyhow::Ok(())
+}
+
 async fn zeta2_syntax_context(
     zeta2_args: Zeta2Args,
     syntax_args: Zeta2SyntaxArgs,
@@ -616,8 +826,15 @@ fn main() {
                     let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
                     serde_json::to_string_pretty(&context.body).map_err(|err| anyhow::anyhow!(err))
                 }
-                Command::Zeta2 { args, command } => match command {
+                Command::Zeta2 { command } => match command {
+                    Zeta2Command::Predict { example_path } => {
+                        let example = NamedExample::load(example_path).unwrap();
+                        zeta2_predict(example, &app_state, cx).await.unwrap();
+                        let _ = cx.update(|cx| cx.quit());
+                        return;
+                    }
                     Zeta2Command::Syntax {
+                        args,
                         syntax_args,
                         command,
                     } => match command {
@@ -643,7 +860,7 @@ fn main() {
                             .await
                         }
                     },
-                    Zeta2Command::Llm { command } => match command {
+                    Zeta2Command::Llm { args, command } => match command {
                         Zeta2LlmCommand::Context { context_args } => {
                             zeta2_llm_context(args, context_args, &app_state, cx).await
                         }