Allow running zeta evals against sweep (#43039)

Max Brunsfeld , Ben Kunkle , and Agus created

This PR restructures the subcommands in `zeta-cli`, so that the
prediction engine (currently `zeta1` vs `zeta2`) is no longer the
highest order subcommand. Instead, there is just one layer of
subcommands: `eval`, `predict`, `context`, etc. Within these commands,
there are flags for using `zeta1`, `zeta2`, and now `sweep`.

Release Notes:

- N/A

---------

Co-authored-by: Ben Kunkle <ben@zed.dev>
Co-authored-by: Agus <agus@zed.dev>

Change summary

Cargo.lock                        |   1 
crates/sweep_ai/src/sweep_ai.rs   |  64 ++--
crates/workspace/src/workspace.rs |   2 
crates/zeta2/src/zeta2.rs         |  15 
crates/zeta_cli/Cargo.toml        |   1 
crates/zeta_cli/src/evaluate.rs   | 102 +++---
crates/zeta_cli/src/example.rs    |  39 --
crates/zeta_cli/src/main.rs       | 292 +++++++++++---------
crates/zeta_cli/src/predict.rs    | 449 ++++++++++++++++++--------------
9 files changed, 517 insertions(+), 448 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -21864,6 +21864,7 @@ dependencies = [
  "shellexpand 2.1.2",
  "smol",
  "soa-rs",
+ "sweep_ai",
  "terminal_view",
  "toml 0.8.23",
  "util",

crates/sweep_ai/src/sweep_ai.rs 🔗

@@ -11,7 +11,7 @@ use http_client::{AsyncBody, Method};
 use language::{
     Anchor, Buffer, BufferSnapshot, EditPreview, Point, ToOffset as _, ToPoint, text_diff,
 };
-use project::Project;
+use project::{Project, ProjectPath};
 use release_channel::{AppCommitSha, AppVersion};
 use std::collections::{VecDeque, hash_map};
 use std::fmt::{self, Display};
@@ -48,11 +48,11 @@ impl Global for SweepAiGlobal {}
 
 #[derive(Clone)]
 pub struct EditPrediction {
-    id: EditPredictionId,
-    path: Arc<Path>,
-    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
-    snapshot: BufferSnapshot,
-    edit_preview: EditPreview,
+    pub id: EditPredictionId,
+    pub path: Arc<Path>,
+    pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
+    pub snapshot: BufferSnapshot,
+    pub edit_preview: EditPreview,
 }
 
 impl EditPrediction {
@@ -110,7 +110,7 @@ impl SweepAi {
         }
     }
 
-    fn new(cx: &mut Context<Self>) -> Self {
+    pub fn new(cx: &mut Context<Self>) -> Self {
         Self {
             api_token: std::env::var("SWEEP_AI_TOKEN").ok(),
             projects: HashMap::default(),
@@ -195,8 +195,8 @@ impl SweepAi {
 
     pub fn request_completion(
         &mut self,
-        workspace: &WeakEntity<Workspace>,
         project: &Entity<Project>,
+        recent_buffers: impl Iterator<Item = ProjectPath>,
         active_buffer: &Entity<Buffer>,
         position: language::Anchor,
         cx: &mut Context<Self>,
@@ -223,26 +223,17 @@ impl SweepAi {
         let events = project_state.events.clone();
         let http_client = cx.http_client();
 
-        let Some(recent_buffers) = workspace
-            .read_with(cx, |workspace, cx| {
-                workspace
-                    .recent_navigation_history_iter(cx)
-                    .filter_map(|(project_path, _)| {
-                        let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
-
-                        if active_buffer == &buffer {
-                            None
-                        } else {
-                            Some(buffer.read(cx).snapshot())
-                        }
-                    })
-                    .take(3)
-                    .collect::<Vec<_>>()
+        let recent_buffer_snapshots = recent_buffers
+            .filter_map(|project_path| {
+                let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
+                if active_buffer == &buffer {
+                    None
+                } else {
+                    Some(buffer.read(cx).snapshot())
+                }
             })
-            .log_err()
-        else {
-            return Task::ready(Ok(None));
-        };
+            .take(3)
+            .collect::<Vec<_>>();
 
         let result = cx.background_spawn({
             let full_path = full_path.clone();
@@ -255,7 +246,7 @@ impl SweepAi {
                     writeln!(&mut recent_changes, "{event}")?;
                 }
 
-                let file_chunks = recent_buffers
+                let file_chunks = recent_buffer_snapshots
                     .into_iter()
                     .map(|snapshot| {
                         let end_point = language::Point::new(30, 0).min(snapshot.max_point());
@@ -623,8 +614,23 @@ impl edit_prediction::EditPredictionProvider for SweepAiEditPredictionProvider {
 
             let completion_request = this.update(cx, |this, cx| {
                 this.last_request_timestamp = Instant::now();
+
                 this.sweep_ai.update(cx, |sweep_ai, cx| {
-                    sweep_ai.request_completion(&workspace, &project, &buffer, position, cx)
+                    let Some(recent_buffers) = workspace
+                        .read_with(cx, |workspace, cx| {
+                            workspace.recent_navigation_history_iter(cx)
+                        })
+                        .log_err()
+                    else {
+                        return Task::ready(Ok(None));
+                    };
+                    sweep_ai.request_completion(
+                        &project,
+                        recent_buffers.map(move |(project_path, _)| project_path),
+                        &buffer,
+                        position,
+                        cx,
+                    )
                 })
             });
 

crates/workspace/src/workspace.rs 🔗

@@ -1845,7 +1845,7 @@ impl Workspace {
     pub fn recent_navigation_history_iter(
         &self,
         cx: &App,
-    ) -> impl Iterator<Item = (ProjectPath, Option<PathBuf>)> {
+    ) -> impl Iterator<Item = (ProjectPath, Option<PathBuf>)> + use<> {
         let mut abs_paths_opened: HashMap<PathBuf, HashSet<ProjectPath>> = HashMap::default();
         let mut history: HashMap<ProjectPath, (Option<PathBuf>, usize)> = HashMap::default();
 

crates/zeta2/src/zeta2.rs 🔗

@@ -50,7 +50,7 @@ pub mod udiff;
 mod xml_edits;
 
 use crate::assemble_excerpts::assemble_excerpts;
-use crate::prediction::EditPrediction;
+pub use crate::prediction::EditPrediction;
 pub use crate::prediction::EditPredictionId;
 pub use provider::ZetaEditPredictionProvider;
 
@@ -327,6 +327,14 @@ impl Event {
             }
         }
     }
+
+    pub fn project_path(&self, cx: &App) -> Option<project::ProjectPath> {
+        match self {
+            Event::BufferChange { new_snapshot, .. } => new_snapshot
+                .file()
+                .map(|f| project::ProjectPath::from_file(f.as_ref(), cx)),
+        }
+    }
 }
 
 impl Zeta {
@@ -401,7 +409,10 @@ impl Zeta {
         }
     }
 
-    pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
+    pub fn history_for_project(
+        &self,
+        project: &Entity<Project>,
+    ) -> impl DoubleEndedIterator<Item = &Event> {
         self.projects
             .get(&project.entity_id())
             .map(|project| project.events.iter())

crates/zeta_cli/Cargo.toml 🔗

@@ -49,6 +49,7 @@ settings.workspace = true
 shellexpand.workspace = true
 smol.workspace = true
 soa-rs = "0.8.1"
+sweep_ai.workspace = true
 terminal_view.workspace = true
 toml.workspace = true
 util.workspace = true

crates/zeta_cli/src/evaluate.rs 🔗

@@ -1,41 +1,25 @@
 use std::{
     collections::{BTreeSet, HashMap},
     io::{IsTerminal, Write},
-    path::PathBuf,
     sync::Arc,
 };
 
 use anyhow::Result;
-use clap::Args;
 use collections::HashSet;
 use gpui::{AsyncApp, Entity};
 use project::Project;
+use sweep_ai::SweepAi;
 use util::ResultExt as _;
 use zeta2::{Zeta, udiff::DiffLine};
 
 use crate::{
-    PromptFormat,
+    EvaluateArguments, PredictionOptions, PredictionProvider,
     example::{Example, NamedExample},
     headless::ZetaCliAppState,
     paths::print_run_data_dir,
-    predict::{CacheMode, PredictionDetails, zeta2_predict},
+    predict::{PredictionDetails, perform_predict, setup_sweep, setup_zeta},
 };
 
-#[derive(Debug, Args)]
-pub struct EvaluateArguments {
-    example_paths: Vec<PathBuf>,
-    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
-    prompt_format: PromptFormat,
-    #[arg(long)]
-    use_expected_context: bool,
-    #[clap(long, value_enum, default_value_t = CacheMode::default())]
-    cache: CacheMode,
-    #[clap(short, long, default_value_t = 1, alias = "repeat")]
-    repetitions: u16,
-    #[arg(long)]
-    skip_prediction: bool,
-}
-
 #[derive(Debug)]
 pub(crate) struct ExecutionData {
     execution_id: String,
@@ -52,38 +36,56 @@ pub async fn run_evaluate(
         eprintln!("No examples provided");
         return;
     }
+
     let all_tasks = args.example_paths.into_iter().map(|path| {
+        let options = args.options.clone();
         let app_state = app_state.clone();
         let example = NamedExample::load(&path).expect("Failed to load example");
 
         cx.spawn(async move |cx| {
-            let (project, zetas, _edited_buffers) = example
-                .setup_project(&app_state, args.repetitions, cx)
-                .await
-                .unwrap();
-
-            let tasks = zetas.into_iter().enumerate().map(|(repetition_ix, zeta)| {
-                let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
-                let example = example.clone();
-                let project = project.clone();
-
-                cx.spawn(async move |cx| {
-                    let name = example.name.clone();
-                    run_evaluate_one(
-                        example,
-                        repetition_ix,
-                        project,
-                        zeta,
-                        args.prompt_format,
-                        args.use_expected_context,
-                        !args.skip_prediction,
-                        args.cache,
-                        cx,
+            let project = example.setup_project(&app_state, cx).await.unwrap();
+
+            let providers = (0..args.repetitions)
+                .map(|_| {
+                    (
+                        setup_zeta(&project, &app_state, cx).unwrap(),
+                        if matches!(args.options.provider, PredictionProvider::Sweep) {
+                            Some(setup_sweep(&project, cx).unwrap())
+                        } else {
+                            None
+                        },
                     )
-                    .await
-                    .map_err(|err| (err, name, repetition_ix))
                 })
-            });
+                .collect::<Vec<_>>();
+
+            let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
+
+            let tasks =
+                providers
+                    .into_iter()
+                    .enumerate()
+                    .map(move |(repetition_ix, (zeta, sweep))| {
+                        let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
+                        let example = example.clone();
+                        let project = project.clone();
+                        let options = options.clone();
+
+                        cx.spawn(async move |cx| {
+                            let name = example.name.clone();
+                            run_evaluate_one(
+                                example,
+                                repetition_ix,
+                                project,
+                                zeta,
+                                sweep,
+                                options,
+                                !args.skip_prediction,
+                                cx,
+                            )
+                            .await
+                            .map_err(|err| (err, name, repetition_ix))
+                        })
+                    });
             futures::future::join_all(tasks).await
         })
     });
@@ -175,20 +177,18 @@ pub async fn run_evaluate_one(
     repetition_ix: Option<u16>,
     project: Entity<Project>,
     zeta: Entity<Zeta>,
-    prompt_format: PromptFormat,
-    use_expected_context: bool,
+    sweep: Option<Entity<SweepAi>>,
+    prediction_options: PredictionOptions,
     predict: bool,
-    cache_mode: CacheMode,
     cx: &mut AsyncApp,
 ) -> Result<(EvaluationResult, ExecutionData)> {
-    let predict_result = zeta2_predict(
+    let predict_result = perform_predict(
         example.clone(),
         project,
         zeta,
+        sweep,
         repetition_ix,
-        prompt_format,
-        use_expected_context,
-        cache_mode,
+        prediction_options,
         cx,
     )
     .await?;

crates/zeta_cli/src/example.rs 🔗

@@ -20,13 +20,13 @@ use futures::{
     lock::{Mutex, OwnedMutexGuard},
 };
 use futures::{FutureExt as _, future::Shared};
-use gpui::{AppContext as _, AsyncApp, Entity, Task, http_client::Url};
+use gpui::{AsyncApp, Entity, Task, http_client::Url};
 use language::{Anchor, Buffer};
 use project::{Project, ProjectPath};
 use pulldown_cmark::CowStr;
 use serde::{Deserialize, Serialize};
 use util::{paths::PathStyle, rel_path::RelPath};
-use zeta2::{Zeta, udiff::OpenedBuffers};
+use zeta2::udiff::OpenedBuffers;
 
 use crate::paths::{REPOS_DIR, WORKTREES_DIR};
 
@@ -318,12 +318,11 @@ impl NamedExample {
         }
     }
 
-    pub async fn setup_project<'a>(
-        &'a self,
+    pub async fn setup_project(
+        &self,
         app_state: &Arc<ZetaCliAppState>,
-        repetitions: u16,
         cx: &mut AsyncApp,
-    ) -> Result<(Entity<Project>, Vec<Entity<Zeta>>, OpenedBuffers<'a>)> {
+    ) -> Result<Entity<Project>> {
         let worktree_path = self.setup_worktree().await?;
 
         static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
@@ -365,33 +364,7 @@ impl NamedExample {
             })?
             .await;
 
-        let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
-
-        let zetas = (0..repetitions)
-            .map(|_| {
-                let zeta = cx.new(|cx| {
-                    zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
-                })?;
-
-                cx.subscribe(&buffer_store, {
-                    let project = project.clone();
-                    let zeta = zeta.clone();
-                    move |_, event, cx| match event {
-                        project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => {
-                            zeta.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
-                        }
-                        _ => {}
-                    }
-                })?
-                .detach();
-
-                anyhow::Ok(zeta)
-            })
-            .collect::<Result<Vec<_>>>()?;
-
-        let edited_buffers = self.apply_edit_history(&project, cx).await?;
-
-        anyhow::Ok((project, zetas, edited_buffers))
+        anyhow::Ok(project)
     }
 
     pub async fn setup_worktree(&self) -> Result<PathBuf> {

crates/zeta_cli/src/main.rs 🔗

@@ -7,13 +7,18 @@ mod source_location;
 mod syntax_retrieval_stats;
 mod util;
 
-use crate::evaluate::{EvaluateArguments, run_evaluate};
-use crate::example::{ExampleFormat, NamedExample};
-use crate::predict::{PredictArguments, run_zeta2_predict};
-use crate::syntax_retrieval_stats::retrieval_stats;
+use crate::{
+    evaluate::run_evaluate,
+    example::{ExampleFormat, NamedExample},
+    headless::ZetaCliAppState,
+    predict::run_predict,
+    source_location::SourceLocation,
+    syntax_retrieval_stats::retrieval_stats,
+    util::{open_buffer, open_buffer_with_language_server},
+};
 use ::util::paths::PathStyle;
 use anyhow::{Result, anyhow};
-use clap::{Args, Parser, Subcommand};
+use clap::{Args, Parser, Subcommand, ValueEnum};
 use cloud_llm_client::predict_edits_v3;
 use edit_prediction_context::{
     EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
@@ -28,10 +33,6 @@ use std::time::Duration;
 use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
 use zeta2::ContextMode;
 
-use crate::headless::ZetaCliAppState;
-use crate::source_location::SourceLocation;
-use crate::util::{open_buffer, open_buffer_with_language_server};
-
 #[derive(Parser, Debug)]
 #[command(name = "zeta")]
 struct ZetaCliArgs {
@@ -43,14 +44,10 @@ struct ZetaCliArgs {
 
 #[derive(Subcommand, Debug)]
 enum Command {
-    Zeta1 {
-        #[command(subcommand)]
-        command: Zeta1Command,
-    },
-    Zeta2 {
-        #[command(subcommand)]
-        command: Zeta2Command,
-    },
+    Context(ContextArgs),
+    ContextStats(ContextStatsArgs),
+    Predict(PredictArguments),
+    Eval(EvaluateArguments),
     ConvertExample {
         path: PathBuf,
         #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
@@ -59,49 +56,24 @@ enum Command {
     Clean,
 }
 
-#[derive(Subcommand, Debug)]
-enum Zeta1Command {
-    Context {
-        #[clap(flatten)]
-        context_args: ContextArgs,
-    },
-}
-
-#[derive(Subcommand, Debug)]
-enum Zeta2Command {
-    Syntax {
-        #[clap(flatten)]
-        args: Zeta2Args,
-        #[clap(flatten)]
-        syntax_args: Zeta2SyntaxArgs,
-        #[command(subcommand)]
-        command: Zeta2SyntaxCommand,
-    },
-    Predict(PredictArguments),
-    Eval(EvaluateArguments),
-}
-
-#[derive(Subcommand, Debug)]
-enum Zeta2SyntaxCommand {
-    Context {
-        #[clap(flatten)]
-        context_args: ContextArgs,
-    },
-    Stats {
-        #[arg(long)]
-        worktree: PathBuf,
-        #[arg(long)]
-        extension: Option<String>,
-        #[arg(long)]
-        limit: Option<usize>,
-        #[arg(long)]
-        skip: Option<usize>,
-    },
+#[derive(Debug, Args)]
+struct ContextStatsArgs {
+    #[arg(long)]
+    worktree: PathBuf,
+    #[arg(long)]
+    extension: Option<String>,
+    #[arg(long)]
+    limit: Option<usize>,
+    #[arg(long)]
+    skip: Option<usize>,
+    #[clap(flatten)]
+    zeta2_args: Zeta2Args,
 }
 
 #[derive(Debug, Args)]
-#[group(requires = "worktree")]
 struct ContextArgs {
+    #[arg(long)]
+    provider: ContextProvider,
     #[arg(long)]
     worktree: PathBuf,
     #[arg(long)]
@@ -110,9 +82,18 @@ struct ContextArgs {
     use_language_server: bool,
     #[arg(long)]
     edit_history: Option<FileOrStdin>,
+    #[clap(flatten)]
+    zeta2_args: Zeta2Args,
 }
 
-#[derive(Debug, Args)]
+#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
+enum ContextProvider {
+    Zeta1,
+    #[default]
+    Syntax,
+}
+
+#[derive(Clone, Debug, Args)]
 struct Zeta2Args {
     #[arg(long, default_value_t = 8192)]
     max_prompt_bytes: usize,
@@ -130,39 +111,111 @@ struct Zeta2Args {
     output_format: OutputFormat,
     #[arg(long, default_value_t = 42)]
     file_indexing_parallelism: usize,
-}
-
-#[derive(Debug, Args)]
-struct Zeta2SyntaxArgs {
     #[arg(long, default_value_t = false)]
     disable_imports_gathering: bool,
     #[arg(long, default_value_t = u8::MAX)]
     max_retrieved_definitions: u8,
 }
 
-fn syntax_args_to_options(
-    zeta2_args: &Zeta2Args,
-    syntax_args: &Zeta2SyntaxArgs,
-    omit_excerpt_overlaps: bool,
-) -> zeta2::ZetaOptions {
+#[derive(Debug, Args)]
+pub struct PredictArguments {
+    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
+    format: PredictionsOutputFormat,
+    example_path: PathBuf,
+    #[clap(flatten)]
+    options: PredictionOptions,
+}
+
+#[derive(Clone, Debug, Args)]
+pub struct PredictionOptions {
+    #[arg(long)]
+    use_expected_context: bool,
+    #[clap(flatten)]
+    zeta2: Zeta2Args,
+    #[clap(long)]
+    provider: PredictionProvider,
+    #[clap(long, value_enum, default_value_t = CacheMode::default())]
+    cache: CacheMode,
+}
+
+#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
+pub enum CacheMode {
+    /// Use cached LLM requests and responses, except when multiple repetitions are requested
+    #[default]
+    Auto,
+    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
+    #[value(alias = "request")]
+    Requests,
+    /// Ignore existing cache entries for both LLM and search.
+    Skip,
+    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
+    /// Useful for reproducing results and fixing bugs outside of search queries
+    Force,
+}
+
+impl CacheMode {
+    fn use_cached_llm_responses(&self) -> bool {
+        self.assert_not_auto();
+        matches!(self, CacheMode::Requests | CacheMode::Force)
+    }
+
+    fn use_cached_search_results(&self) -> bool {
+        self.assert_not_auto();
+        matches!(self, CacheMode::Force)
+    }
+
+    fn assert_not_auto(&self) {
+        assert_ne!(
+            *self,
+            CacheMode::Auto,
+            "Cache mode should not be auto at this point!"
+        );
+    }
+}
+
+#[derive(clap::ValueEnum, Debug, Clone)]
+pub enum PredictionsOutputFormat {
+    Json,
+    Md,
+    Diff,
+}
+
+#[derive(Debug, Args)]
+pub struct EvaluateArguments {
+    example_paths: Vec<PathBuf>,
+    #[clap(flatten)]
+    options: PredictionOptions,
+    #[clap(short, long, default_value_t = 1, alias = "repeat")]
+    repetitions: u16,
+    #[arg(long)]
+    skip_prediction: bool,
+}
+
+#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
+enum PredictionProvider {
+    #[default]
+    Zeta2,
+    Sweep,
+}
+
+fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions {
     zeta2::ZetaOptions {
         context: ContextMode::Syntax(EditPredictionContextOptions {
-            max_retrieved_declarations: syntax_args.max_retrieved_definitions,
-            use_imports: !syntax_args.disable_imports_gathering,
+            max_retrieved_declarations: args.max_retrieved_definitions,
+            use_imports: !args.disable_imports_gathering,
             excerpt: EditPredictionExcerptOptions {
-                max_bytes: zeta2_args.max_excerpt_bytes,
-                min_bytes: zeta2_args.min_excerpt_bytes,
-                target_before_cursor_over_total_bytes: zeta2_args
-                    .target_before_cursor_over_total_bytes,
+                max_bytes: args.max_excerpt_bytes,
+                min_bytes: args.min_excerpt_bytes,
+                target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
             },
             score: EditPredictionScoreOptions {
                 omit_excerpt_overlaps,
             },
         }),
-        max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
-        max_prompt_bytes: zeta2_args.max_prompt_bytes,
-        prompt_format: zeta2_args.prompt_format.into(),
-        file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
+        max_diagnostic_bytes: args.max_diagnostic_bytes,
+        max_prompt_bytes: args.max_prompt_bytes,
+        prompt_format: args.prompt_format.into(),
+        file_indexing_parallelism: args.file_indexing_parallelism,
         buffer_change_grouping_interval: Duration::ZERO,
     }
 }
@@ -320,8 +373,6 @@ async fn load_context(
 }
 
 async fn zeta2_syntax_context(
-    zeta2_args: Zeta2Args,
-    syntax_args: Zeta2SyntaxArgs,
     args: ContextArgs,
     app_state: &Arc<ZetaCliAppState>,
     cx: &mut AsyncApp,
@@ -347,7 +398,7 @@ async fn zeta2_syntax_context(
                 zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
             });
             let indexing_done_task = zeta.update(cx, |zeta, cx| {
-                zeta.set_options(syntax_args_to_options(&zeta2_args, &syntax_args, true));
+                zeta.set_options(zeta2_args_to_options(&args.zeta2_args, true));
                 zeta.register_buffer(&buffer, &project, cx);
                 zeta.wait_for_initial_indexing(&project, cx)
             });
@@ -362,7 +413,7 @@ async fn zeta2_syntax_context(
 
                 let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;
 
-                match zeta2_args.output_format {
+                match args.zeta2_args.output_format {
                     OutputFormat::Prompt => anyhow::Ok(prompt_string),
                     OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
                     OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
@@ -427,57 +478,40 @@ fn main() {
                         panic!("Expected a command");
                     }
                 }
-                Some(Command::Zeta1 {
-                    command: Zeta1Command::Context { context_args },
-                }) => {
-                    let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
-                    let result = serde_json::to_string_pretty(&context.body).unwrap();
-                    println!("{}", result);
+                Some(Command::ContextStats(arguments)) => {
+                    let result = retrieval_stats(
+                        arguments.worktree,
+                        app_state,
+                        arguments.extension,
+                        arguments.limit,
+                        arguments.skip,
+                        zeta2_args_to_options(&arguments.zeta2_args, false),
+                        cx,
+                    )
+                    .await;
+                    println!("{}", result.unwrap());
                 }
-                Some(Command::Zeta2 { command }) => match command {
-                    Zeta2Command::Predict(arguments) => {
-                        run_zeta2_predict(arguments, &app_state, cx).await;
-                    }
-                    Zeta2Command::Eval(arguments) => {
-                        run_evaluate(arguments, &app_state, cx).await;
-                    }
-                    Zeta2Command::Syntax {
-                        args,
-                        syntax_args,
-                        command,
-                    } => {
-                        let result = match command {
-                            Zeta2SyntaxCommand::Context { context_args } => {
-                                zeta2_syntax_context(
-                                    args,
-                                    syntax_args,
-                                    context_args,
-                                    &app_state,
-                                    cx,
-                                )
+                Some(Command::Context(context_args)) => {
+                    let result = match context_args.provider {
+                        ContextProvider::Zeta1 => {
+                            let context =
+                                zeta1_context(context_args, &app_state, cx).await.unwrap();
+                            serde_json::to_string_pretty(&context.body).unwrap()
+                        }
+                        ContextProvider::Syntax => {
+                            zeta2_syntax_context(context_args, &app_state, cx)
                                 .await
-                            }
-                            Zeta2SyntaxCommand::Stats {
-                                worktree,
-                                extension,
-                                limit,
-                                skip,
-                            } => {
-                                retrieval_stats(
-                                    worktree,
-                                    app_state,
-                                    extension,
-                                    limit,
-                                    skip,
-                                    syntax_args_to_options(&args, &syntax_args, false),
-                                    cx,
-                                )
-                                .await
-                            }
-                        };
-                        println!("{}", result.unwrap());
-                    }
-                },
+                                .unwrap()
+                        }
+                    };
+                    println!("{}", result);
+                }
+                Some(Command::Predict(arguments)) => {
+                    run_predict(arguments, &app_state, cx).await;
+                }
+                Some(Command::Eval(arguments)) => {
+                    run_evaluate(arguments, &app_state, cx).await;
+                }
                 Some(Command::ConvertExample {
                     path,
                     output_format,

crates/zeta_cli/src/predict.rs 🔗

@@ -1,16 +1,18 @@
-use crate::PromptFormat;
 use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
 use crate::headless::ZetaCliAppState;
 use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
+use crate::{
+    CacheMode, PredictArguments, PredictionOptions, PredictionProvider, PredictionsOutputFormat,
+};
 use ::serde::Serialize;
 use anyhow::{Context, Result, anyhow};
-use clap::{Args, ValueEnum};
 use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
 use collections::HashMap;
 use futures::StreamExt as _;
 use gpui::{AppContext, AsyncApp, Entity};
 use language::{Anchor, Buffer, Point};
 use project::Project;
+use project::buffer_store::BufferStoreEvent;
 use serde::Deserialize;
 use std::fs;
 use std::io::{IsTerminal, Write};
@@ -19,98 +21,86 @@ use std::path::PathBuf;
 use std::sync::Arc;
 use std::sync::Mutex;
 use std::time::{Duration, Instant};
+use sweep_ai::SweepAi;
 use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};
 
-#[derive(Debug, Args)]
-pub struct PredictArguments {
-    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
-    prompt_format: PromptFormat,
-    #[arg(long)]
-    use_expected_context: bool,
-    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
-    format: PredictionsOutputFormat,
-    example_path: PathBuf,
-    #[clap(long, value_enum, default_value_t = CacheMode::default())]
-    cache: CacheMode,
-}
-
-#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
-pub enum CacheMode {
-    /// Use cached LLM requests and responses, except when multiple repetitions are requested
-    #[default]
-    Auto,
-    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
-    #[value(alias = "request")]
-    Requests,
-    /// Ignore existing cache entries for both LLM and search.
-    Skip,
-    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
-    /// Useful for reproducing results and fixing bugs outside of search queries
-    Force,
-}
-
-impl CacheMode {
-    fn use_cached_llm_responses(&self) -> bool {
-        self.assert_not_auto();
-        matches!(self, CacheMode::Requests | CacheMode::Force)
-    }
-
-    fn use_cached_search_results(&self) -> bool {
-        self.assert_not_auto();
-        matches!(self, CacheMode::Force)
-    }
-
-    fn assert_not_auto(&self) {
-        assert_ne!(
-            *self,
-            CacheMode::Auto,
-            "Cache mode should not be auto at this point!"
-        );
-    }
-}
-
-#[derive(clap::ValueEnum, Debug, Clone)]
-pub enum PredictionsOutputFormat {
-    Json,
-    Md,
-    Diff,
-}
-
-pub async fn run_zeta2_predict(
+pub async fn run_predict(
     args: PredictArguments,
     app_state: &Arc<ZetaCliAppState>,
     cx: &mut AsyncApp,
 ) {
     let example = NamedExample::load(args.example_path).unwrap();
-    let (project, mut zetas, _edited_buffers) =
-        example.setup_project(app_state, 1, cx).await.unwrap();
-    let result = zeta2_predict(
-        example,
-        project,
-        zetas.remove(0),
-        None,
-        args.prompt_format,
-        args.use_expected_context,
-        args.cache,
-        cx,
-    )
-    .await
-    .unwrap();
+    let project = example.setup_project(app_state, cx).await.unwrap();
+    let zeta = setup_zeta(&project, app_state, cx).unwrap();
+    let sweep = if matches!(args.options.provider, PredictionProvider::Sweep) {
+        Some(setup_sweep(&project, cx).unwrap())
+    } else {
+        None
+    };
+    let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
+    let result = perform_predict(example, project, zeta, sweep, None, args.options, cx)
+        .await
+        .unwrap();
     result.write(args.format, std::io::stdout()).unwrap();
 
     print_run_data_dir(true, std::io::stdout().is_terminal());
 }
 
-pub async fn zeta2_predict(
+pub fn setup_zeta(
+    project: &Entity<Project>,
+    app_state: &Arc<ZetaCliAppState>,
+    cx: &mut AsyncApp,
+) -> Result<Entity<Zeta>> {
+    let zeta =
+        cx.new(|cx| zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?;
+
+    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
+
+    cx.subscribe(&buffer_store, {
+        let project = project.clone();
+        let zeta = zeta.clone();
+        move |_, event, cx| match event {
+            BufferStoreEvent::BufferAdded(buffer) => {
+                zeta.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
+            }
+            _ => {}
+        }
+    })?
+    .detach();
+
+    anyhow::Ok(zeta)
+}
+
+pub fn setup_sweep(project: &Entity<Project>, cx: &mut AsyncApp) -> Result<Entity<SweepAi>> {
+    let sweep = cx.new(|cx| SweepAi::new(cx))?;
+
+    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
+
+    cx.subscribe(&buffer_store, {
+        let project = project.clone();
+        let sweep = sweep.clone();
+        move |_, event, cx| match event {
+            BufferStoreEvent::BufferAdded(buffer) => {
+                sweep.update(cx, |sweep, cx| sweep.register_buffer(&buffer, &project, cx));
+            }
+            _ => {}
+        }
+    })?
+    .detach();
+
+    anyhow::Ok(sweep)
+}
+
+pub async fn perform_predict(
     example: NamedExample,
     project: Entity<Project>,
     zeta: Entity<Zeta>,
+    sweep: Option<Entity<SweepAi>>,
     repetition_ix: Option<u16>,
-    prompt_format: PromptFormat,
-    use_expected_context: bool,
-    mut cache_mode: CacheMode,
+    options: PredictionOptions,
     cx: &mut AsyncApp,
 ) -> Result<PredictionDetails> {
+    let mut cache_mode = options.cache;
     if repetition_ix.is_some() {
         if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip {
             panic!("Repetitions are not supported in Auto cache mode");
@@ -148,94 +138,8 @@ pub async fn zeta2_predict(
     let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
 
     let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone())));
-    let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
-
-    let debug_task = cx.background_spawn({
-        let result = result.clone();
-        async move {
-            let mut start_time = None;
-            let mut search_queries_generated_at = None;
-            let mut search_queries_executed_at = None;
-            while let Some(event) = debug_rx.next().await {
-                match event {
-                    zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
-                        start_time = Some(info.timestamp);
-                        fs::write(
-                            example_run_dir.join("search_prompt.md"),
-                            &info.search_prompt,
-                        )?;
-                    }
-                    zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
-                        search_queries_generated_at = Some(info.timestamp);
-                        fs::write(
-                            example_run_dir.join("search_queries.json"),
-                            serde_json::to_string_pretty(&info.search_queries).unwrap(),
-                        )?;
-                    }
-                    zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
-                        search_queries_executed_at = Some(info.timestamp);
-                    }
-                    zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
-                    zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
-                        let prediction_started_at = Instant::now();
-                        start_time.get_or_insert(prediction_started_at);
-                        let prompt = request.local_prompt.unwrap_or_default();
-                        fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
-
-                        {
-                            let mut result = result.lock().unwrap();
-                            result.prompt_len = prompt.chars().count();
-
-                            for included_file in request.request.included_files {
-                                let insertions =
-                                    vec![(request.request.cursor_point, CURSOR_MARKER)];
-                                result.excerpts.extend(included_file.excerpts.iter().map(
-                                    |excerpt| ActualExcerpt {
-                                        path: included_file.path.components().skip(1).collect(),
-                                        text: String::from(excerpt.text.as_ref()),
-                                    },
-                                ));
-                                write_codeblock(
-                                    &included_file.path,
-                                    included_file.excerpts.iter(),
-                                    if included_file.path == request.request.excerpt_path {
-                                        &insertions
-                                    } else {
-                                        &[]
-                                    },
-                                    included_file.max_row,
-                                    false,
-                                    &mut result.excerpts_text,
-                                );
-                            }
-                        }
 
-                        let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
-                        let response = zeta2::text_from_response(response).unwrap_or_default();
-                        let prediction_finished_at = Instant::now();
-                        fs::write(example_run_dir.join("prediction_response.md"), &response)?;
-
-                        let mut result = result.lock().unwrap();
-                        result.generated_len = response.chars().count();
-
-                        if !use_expected_context {
-                            result.planning_search_time =
-                                Some(search_queries_generated_at.unwrap() - start_time.unwrap());
-                            result.running_search_time = Some(
-                                search_queries_executed_at.unwrap()
-                                    - search_queries_generated_at.unwrap(),
-                            );
-                        }
-                        result.prediction_time = prediction_finished_at - prediction_started_at;
-                        result.total_time = prediction_finished_at - start_time.unwrap();
-
-                        break;
-                    }
-                }
-            }
-            anyhow::Ok(())
-        }
-    });
+    let prompt_format = options.zeta2.prompt_format;
 
     zeta.update(cx, |zeta, _cx| {
         let mut options = zeta.options().clone();
@@ -243,55 +147,194 @@ pub async fn zeta2_predict(
         zeta.set_options(options);
     })?;
 
-    if use_expected_context {
-        let context_excerpts_tasks = example
-            .example
-            .expected_context
-            .iter()
-            .flat_map(|section| {
-                section.alternatives[0].excerpts.iter().map(|excerpt| {
-                    resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
-                })
-            })
-            .collect::<Vec<_>>();
-        let context_excerpts_vec = futures::future::try_join_all(context_excerpts_tasks).await?;
-
-        let mut context_excerpts = HashMap::default();
-        for (buffer, mut excerpts) in context_excerpts_vec {
-            context_excerpts
-                .entry(buffer)
-                .or_insert(Vec::new())
-                .append(&mut excerpts);
-        }
+    let prediction = match options.provider {
+        crate::PredictionProvider::Zeta2 => {
+            let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
+
+            let debug_task = cx.background_spawn({
+                let result = result.clone();
+                async move {
+                    let mut start_time = None;
+                    let mut search_queries_generated_at = None;
+                    let mut search_queries_executed_at = None;
+                    while let Some(event) = debug_rx.next().await {
+                        match event {
+                            zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
+                                start_time = Some(info.timestamp);
+                                fs::write(
+                                    example_run_dir.join("search_prompt.md"),
+                                    &info.search_prompt,
+                                )?;
+                            }
+                            zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
+                                search_queries_generated_at = Some(info.timestamp);
+                                fs::write(
+                                    example_run_dir.join("search_queries.json"),
+                                    serde_json::to_string_pretty(&info.search_queries).unwrap(),
+                                )?;
+                            }
+                            zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
+                                search_queries_executed_at = Some(info.timestamp);
+                            }
+                            zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
+                            zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
+                                let prediction_started_at = Instant::now();
+                                start_time.get_or_insert(prediction_started_at);
+                                let prompt = request.local_prompt.unwrap_or_default();
+                                fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
+
+                                {
+                                    let mut result = result.lock().unwrap();
+                                    result.prompt_len = prompt.chars().count();
+
+                                    for included_file in request.request.included_files {
+                                        let insertions =
+                                            vec![(request.request.cursor_point, CURSOR_MARKER)];
+                                        result.excerpts.extend(included_file.excerpts.iter().map(
+                                            |excerpt| {
+                                                ActualExcerpt {
+                                                    path: included_file
+                                                        .path
+                                                        .components()
+                                                        .skip(1)
+                                                        .collect(),
+                                                    text: String::from(excerpt.text.as_ref()),
+                                                }
+                                            },
+                                        ));
+                                        write_codeblock(
+                                            &included_file.path,
+                                            included_file.excerpts.iter(),
+                                            if included_file.path == request.request.excerpt_path {
+                                                &insertions
+                                            } else {
+                                                &[]
+                                            },
+                                            included_file.max_row,
+                                            false,
+                                            &mut result.excerpts_text,
+                                        );
+                                    }
+                                }
+
+                                let response =
+                                    request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
+                                let response =
+                                    zeta2::text_from_response(response).unwrap_or_default();
+                                let prediction_finished_at = Instant::now();
+                                fs::write(
+                                    example_run_dir.join("prediction_response.md"),
+                                    &response,
+                                )?;
+
+                                let mut result = result.lock().unwrap();
+                                result.generated_len = response.chars().count();
+
+                                if !options.use_expected_context {
+                                    result.planning_search_time = Some(
+                                        search_queries_generated_at.unwrap() - start_time.unwrap(),
+                                    );
+                                    result.running_search_time = Some(
+                                        search_queries_executed_at.unwrap()
+                                            - search_queries_generated_at.unwrap(),
+                                    );
+                                }
+                                result.prediction_time =
+                                    prediction_finished_at - prediction_started_at;
+                                result.total_time = prediction_finished_at - start_time.unwrap();
+
+                                break;
+                            }
+                        }
+                    }
+                    anyhow::Ok(())
+                }
+            });
+
+            if options.use_expected_context {
+                let context_excerpts_tasks = example
+                    .example
+                    .expected_context
+                    .iter()
+                    .flat_map(|section| {
+                        section.alternatives[0].excerpts.iter().map(|excerpt| {
+                            resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
+                        })
+                    })
+                    .collect::<Vec<_>>();
+                let context_excerpts_vec =
+                    futures::future::try_join_all(context_excerpts_tasks).await?;
+
+                let mut context_excerpts = HashMap::default();
+                for (buffer, mut excerpts) in context_excerpts_vec {
+                    context_excerpts
+                        .entry(buffer)
+                        .or_insert(Vec::new())
+                        .append(&mut excerpts);
+                }
 
-        zeta.update(cx, |zeta, _cx| {
-            zeta.set_context(project.clone(), context_excerpts)
-        })?;
-    } else {
-        zeta.update(cx, |zeta, cx| {
-            zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
-        })?
-        .await?;
-    }
+                zeta.update(cx, |zeta, _cx| {
+                    zeta.set_context(project.clone(), context_excerpts)
+                })?;
+            } else {
+                zeta.update(cx, |zeta, cx| {
+                    zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
+                })?
+                .await?;
+            }
 
-    let prediction = zeta
-        .update(cx, |zeta, cx| {
-            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
-        })?
-        .await?;
+            let prediction = zeta
+                .update(cx, |zeta, cx| {
+                    zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
+                })?
+                .await?
+                .map(|prediction| (prediction.buffer, prediction.snapshot, prediction.edits));
+
+            debug_task.await?;
+
+            prediction
+        }
+        crate::PredictionProvider::Sweep => sweep
+            .unwrap()
+            .update(cx, |sweep, cx| {
+                let mut recent_paths = Vec::new();
+                for path in zeta
+                    .read(cx)
+                    .history_for_project(&project)
+                    .rev()
+                    .filter_map(|event| event.project_path(cx))
+                {
+                    if !recent_paths.contains(&path) {
+                        recent_paths.push(path);
+                    }
+                }
 
-    debug_task.await?;
+                sweep.request_completion(
+                    &project,
+                    recent_paths.into_iter(),
+                    &cursor_buffer,
+                    cursor_anchor,
+                    cx,
+                )
+            })?
+            .await?
+            .map(
+                |sweep_ai::EditPrediction {
+                     edits, snapshot, ..
+                 }| { (cursor_buffer.clone(), snapshot, edits) },
+            ),
+    };
 
     let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
+
     result.diff = prediction
-        .map(|prediction| {
-            let old_text = prediction.snapshot.text();
-            let new_text = prediction
-                .buffer
+        .map(|(buffer, snapshot, edits)| {
+            let old_text = snapshot.text();
+            let new_text = buffer
                 .update(cx, |buffer, cx| {
                     let branch = buffer.branch(cx);
                     branch.update(cx, |branch, cx| {
-                        branch.edit(prediction.edits.iter().cloned(), None, cx);
+                        branch.edit(edits.iter().cloned(), None, cx);
                         branch.text()
                     })
                 })