Cargo.lock 🔗
@@ -21864,6 +21864,7 @@ dependencies = [
"shellexpand 2.1.2",
"smol",
"soa-rs",
+ "sweep_ai",
"terminal_view",
"toml 0.8.23",
"util",
Max Brunsfeld, Ben Kunkle, and Agus created this PR.
This PR restructures the subcommands in `zeta-cli`, so that the
prediction engine (currently `zeta1` vs `zeta2`) is no longer the
highest order subcommand. Instead, there is just one layer of
subcommands: `eval`, `predict`, `context`, etc. Within these commands,
there are flags for using `zeta1`, `zeta2`, and now `sweep`.
Release Notes:
- N/A
---------
Co-authored-by: Ben Kunkle <ben@zed.dev>
Co-authored-by: Agus <agus@zed.dev>
Cargo.lock | 1
crates/sweep_ai/src/sweep_ai.rs | 64 ++--
crates/workspace/src/workspace.rs | 2
crates/zeta2/src/zeta2.rs | 15
crates/zeta_cli/Cargo.toml | 1
crates/zeta_cli/src/evaluate.rs | 102 +++---
crates/zeta_cli/src/example.rs | 39 --
crates/zeta_cli/src/main.rs | 292 +++++++++++---------
crates/zeta_cli/src/predict.rs | 449 ++++++++++++++++++--------------
9 files changed, 517 insertions(+), 448 deletions(-)
@@ -21864,6 +21864,7 @@ dependencies = [
"shellexpand 2.1.2",
"smol",
"soa-rs",
+ "sweep_ai",
"terminal_view",
"toml 0.8.23",
"util",
@@ -11,7 +11,7 @@ use http_client::{AsyncBody, Method};
use language::{
Anchor, Buffer, BufferSnapshot, EditPreview, Point, ToOffset as _, ToPoint, text_diff,
};
-use project::Project;
+use project::{Project, ProjectPath};
use release_channel::{AppCommitSha, AppVersion};
use std::collections::{VecDeque, hash_map};
use std::fmt::{self, Display};
@@ -48,11 +48,11 @@ impl Global for SweepAiGlobal {}
#[derive(Clone)]
pub struct EditPrediction {
- id: EditPredictionId,
- path: Arc<Path>,
- edits: Arc<[(Range<Anchor>, Arc<str>)]>,
- snapshot: BufferSnapshot,
- edit_preview: EditPreview,
+ pub id: EditPredictionId,
+ pub path: Arc<Path>,
+ pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
+ pub snapshot: BufferSnapshot,
+ pub edit_preview: EditPreview,
}
impl EditPrediction {
@@ -110,7 +110,7 @@ impl SweepAi {
}
}
- fn new(cx: &mut Context<Self>) -> Self {
+ pub fn new(cx: &mut Context<Self>) -> Self {
Self {
api_token: std::env::var("SWEEP_AI_TOKEN").ok(),
projects: HashMap::default(),
@@ -195,8 +195,8 @@ impl SweepAi {
pub fn request_completion(
&mut self,
- workspace: &WeakEntity<Workspace>,
project: &Entity<Project>,
+ recent_buffers: impl Iterator<Item = ProjectPath>,
active_buffer: &Entity<Buffer>,
position: language::Anchor,
cx: &mut Context<Self>,
@@ -223,26 +223,17 @@ impl SweepAi {
let events = project_state.events.clone();
let http_client = cx.http_client();
- let Some(recent_buffers) = workspace
- .read_with(cx, |workspace, cx| {
- workspace
- .recent_navigation_history_iter(cx)
- .filter_map(|(project_path, _)| {
- let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
-
- if active_buffer == &buffer {
- None
- } else {
- Some(buffer.read(cx).snapshot())
- }
- })
- .take(3)
- .collect::<Vec<_>>()
+ let recent_buffer_snapshots = recent_buffers
+ .filter_map(|project_path| {
+ let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
+ if active_buffer == &buffer {
+ None
+ } else {
+ Some(buffer.read(cx).snapshot())
+ }
})
- .log_err()
- else {
- return Task::ready(Ok(None));
- };
+ .take(3)
+ .collect::<Vec<_>>();
let result = cx.background_spawn({
let full_path = full_path.clone();
@@ -255,7 +246,7 @@ impl SweepAi {
writeln!(&mut recent_changes, "{event}")?;
}
- let file_chunks = recent_buffers
+ let file_chunks = recent_buffer_snapshots
.into_iter()
.map(|snapshot| {
let end_point = language::Point::new(30, 0).min(snapshot.max_point());
@@ -623,8 +614,23 @@ impl edit_prediction::EditPredictionProvider for SweepAiEditPredictionProvider {
let completion_request = this.update(cx, |this, cx| {
this.last_request_timestamp = Instant::now();
+
this.sweep_ai.update(cx, |sweep_ai, cx| {
- sweep_ai.request_completion(&workspace, &project, &buffer, position, cx)
+ let Some(recent_buffers) = workspace
+ .read_with(cx, |workspace, cx| {
+ workspace.recent_navigation_history_iter(cx)
+ })
+ .log_err()
+ else {
+ return Task::ready(Ok(None));
+ };
+ sweep_ai.request_completion(
+ &project,
+ recent_buffers.map(move |(project_path, _)| project_path),
+ &buffer,
+ position,
+ cx,
+ )
})
});
@@ -1845,7 +1845,7 @@ impl Workspace {
pub fn recent_navigation_history_iter(
&self,
cx: &App,
- ) -> impl Iterator<Item = (ProjectPath, Option<PathBuf>)> {
+ ) -> impl Iterator<Item = (ProjectPath, Option<PathBuf>)> + use<> {
let mut abs_paths_opened: HashMap<PathBuf, HashSet<ProjectPath>> = HashMap::default();
let mut history: HashMap<ProjectPath, (Option<PathBuf>, usize)> = HashMap::default();
@@ -50,7 +50,7 @@ pub mod udiff;
mod xml_edits;
use crate::assemble_excerpts::assemble_excerpts;
-use crate::prediction::EditPrediction;
+pub use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
pub use provider::ZetaEditPredictionProvider;
@@ -327,6 +327,14 @@ impl Event {
}
}
}
+
+ pub fn project_path(&self, cx: &App) -> Option<project::ProjectPath> {
+ match self {
+ Event::BufferChange { new_snapshot, .. } => new_snapshot
+ .file()
+ .map(|f| project::ProjectPath::from_file(f.as_ref(), cx)),
+ }
+ }
}
impl Zeta {
@@ -401,7 +409,10 @@ impl Zeta {
}
}
- pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
+ pub fn history_for_project(
+ &self,
+ project: &Entity<Project>,
+ ) -> impl DoubleEndedIterator<Item = &Event> {
self.projects
.get(&project.entity_id())
.map(|project| project.events.iter())
@@ -49,6 +49,7 @@ settings.workspace = true
shellexpand.workspace = true
smol.workspace = true
soa-rs = "0.8.1"
+sweep_ai.workspace = true
terminal_view.workspace = true
toml.workspace = true
util.workspace = true
@@ -1,41 +1,25 @@
use std::{
collections::{BTreeSet, HashMap},
io::{IsTerminal, Write},
- path::PathBuf,
sync::Arc,
};
use anyhow::Result;
-use clap::Args;
use collections::HashSet;
use gpui::{AsyncApp, Entity};
use project::Project;
+use sweep_ai::SweepAi;
use util::ResultExt as _;
use zeta2::{Zeta, udiff::DiffLine};
use crate::{
- PromptFormat,
+ EvaluateArguments, PredictionOptions, PredictionProvider,
example::{Example, NamedExample},
headless::ZetaCliAppState,
paths::print_run_data_dir,
- predict::{CacheMode, PredictionDetails, zeta2_predict},
+ predict::{PredictionDetails, perform_predict, setup_sweep, setup_zeta},
};
-#[derive(Debug, Args)]
-pub struct EvaluateArguments {
- example_paths: Vec<PathBuf>,
- #[arg(long, value_enum, default_value_t = PromptFormat::default())]
- prompt_format: PromptFormat,
- #[arg(long)]
- use_expected_context: bool,
- #[clap(long, value_enum, default_value_t = CacheMode::default())]
- cache: CacheMode,
- #[clap(short, long, default_value_t = 1, alias = "repeat")]
- repetitions: u16,
- #[arg(long)]
- skip_prediction: bool,
-}
-
#[derive(Debug)]
pub(crate) struct ExecutionData {
execution_id: String,
@@ -52,38 +36,56 @@ pub async fn run_evaluate(
eprintln!("No examples provided");
return;
}
+
let all_tasks = args.example_paths.into_iter().map(|path| {
+ let options = args.options.clone();
let app_state = app_state.clone();
let example = NamedExample::load(&path).expect("Failed to load example");
cx.spawn(async move |cx| {
- let (project, zetas, _edited_buffers) = example
- .setup_project(&app_state, args.repetitions, cx)
- .await
- .unwrap();
-
- let tasks = zetas.into_iter().enumerate().map(|(repetition_ix, zeta)| {
- let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
- let example = example.clone();
- let project = project.clone();
-
- cx.spawn(async move |cx| {
- let name = example.name.clone();
- run_evaluate_one(
- example,
- repetition_ix,
- project,
- zeta,
- args.prompt_format,
- args.use_expected_context,
- !args.skip_prediction,
- args.cache,
- cx,
+ let project = example.setup_project(&app_state, cx).await.unwrap();
+
+ let providers = (0..args.repetitions)
+ .map(|_| {
+ (
+ setup_zeta(&project, &app_state, cx).unwrap(),
+ if matches!(args.options.provider, PredictionProvider::Sweep) {
+ Some(setup_sweep(&project, cx).unwrap())
+ } else {
+ None
+ },
)
- .await
- .map_err(|err| (err, name, repetition_ix))
})
- });
+ .collect::<Vec<_>>();
+
+ let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
+
+ let tasks =
+ providers
+ .into_iter()
+ .enumerate()
+ .map(move |(repetition_ix, (zeta, sweep))| {
+ let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
+ let example = example.clone();
+ let project = project.clone();
+ let options = options.clone();
+
+ cx.spawn(async move |cx| {
+ let name = example.name.clone();
+ run_evaluate_one(
+ example,
+ repetition_ix,
+ project,
+ zeta,
+ sweep,
+ options,
+ !args.skip_prediction,
+ cx,
+ )
+ .await
+ .map_err(|err| (err, name, repetition_ix))
+ })
+ });
futures::future::join_all(tasks).await
})
});
@@ -175,20 +177,18 @@ pub async fn run_evaluate_one(
repetition_ix: Option<u16>,
project: Entity<Project>,
zeta: Entity<Zeta>,
- prompt_format: PromptFormat,
- use_expected_context: bool,
+ sweep: Option<Entity<SweepAi>>,
+ prediction_options: PredictionOptions,
predict: bool,
- cache_mode: CacheMode,
cx: &mut AsyncApp,
) -> Result<(EvaluationResult, ExecutionData)> {
- let predict_result = zeta2_predict(
+ let predict_result = perform_predict(
example.clone(),
project,
zeta,
+ sweep,
repetition_ix,
- prompt_format,
- use_expected_context,
- cache_mode,
+ prediction_options,
cx,
)
.await?;
@@ -20,13 +20,13 @@ use futures::{
lock::{Mutex, OwnedMutexGuard},
};
use futures::{FutureExt as _, future::Shared};
-use gpui::{AppContext as _, AsyncApp, Entity, Task, http_client::Url};
+use gpui::{AsyncApp, Entity, Task, http_client::Url};
use language::{Anchor, Buffer};
use project::{Project, ProjectPath};
use pulldown_cmark::CowStr;
use serde::{Deserialize, Serialize};
use util::{paths::PathStyle, rel_path::RelPath};
-use zeta2::{Zeta, udiff::OpenedBuffers};
+use zeta2::udiff::OpenedBuffers;
use crate::paths::{REPOS_DIR, WORKTREES_DIR};
@@ -318,12 +318,11 @@ impl NamedExample {
}
}
- pub async fn setup_project<'a>(
- &'a self,
+ pub async fn setup_project(
+ &self,
app_state: &Arc<ZetaCliAppState>,
- repetitions: u16,
cx: &mut AsyncApp,
- ) -> Result<(Entity<Project>, Vec<Entity<Zeta>>, OpenedBuffers<'a>)> {
+ ) -> Result<Entity<Project>> {
let worktree_path = self.setup_worktree().await?;
static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
@@ -365,33 +364,7 @@ impl NamedExample {
})?
.await;
- let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
-
- let zetas = (0..repetitions)
- .map(|_| {
- let zeta = cx.new(|cx| {
- zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
- })?;
-
- cx.subscribe(&buffer_store, {
- let project = project.clone();
- let zeta = zeta.clone();
- move |_, event, cx| match event {
- project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => {
- zeta.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
- }
- _ => {}
- }
- })?
- .detach();
-
- anyhow::Ok(zeta)
- })
- .collect::<Result<Vec<_>>>()?;
-
- let edited_buffers = self.apply_edit_history(&project, cx).await?;
-
- anyhow::Ok((project, zetas, edited_buffers))
+ anyhow::Ok(project)
}
pub async fn setup_worktree(&self) -> Result<PathBuf> {
@@ -7,13 +7,18 @@ mod source_location;
mod syntax_retrieval_stats;
mod util;
-use crate::evaluate::{EvaluateArguments, run_evaluate};
-use crate::example::{ExampleFormat, NamedExample};
-use crate::predict::{PredictArguments, run_zeta2_predict};
-use crate::syntax_retrieval_stats::retrieval_stats;
+use crate::{
+ evaluate::run_evaluate,
+ example::{ExampleFormat, NamedExample},
+ headless::ZetaCliAppState,
+ predict::run_predict,
+ source_location::SourceLocation,
+ syntax_retrieval_stats::retrieval_stats,
+ util::{open_buffer, open_buffer_with_language_server},
+};
use ::util::paths::PathStyle;
use anyhow::{Result, anyhow};
-use clap::{Args, Parser, Subcommand};
+use clap::{Args, Parser, Subcommand, ValueEnum};
use cloud_llm_client::predict_edits_v3;
use edit_prediction_context::{
EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
@@ -28,10 +33,6 @@ use std::time::Duration;
use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
use zeta2::ContextMode;
-use crate::headless::ZetaCliAppState;
-use crate::source_location::SourceLocation;
-use crate::util::{open_buffer, open_buffer_with_language_server};
-
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
@@ -43,14 +44,10 @@ struct ZetaCliArgs {
#[derive(Subcommand, Debug)]
enum Command {
- Zeta1 {
- #[command(subcommand)]
- command: Zeta1Command,
- },
- Zeta2 {
- #[command(subcommand)]
- command: Zeta2Command,
- },
+ Context(ContextArgs),
+ ContextStats(ContextStatsArgs),
+ Predict(PredictArguments),
+ Eval(EvaluateArguments),
ConvertExample {
path: PathBuf,
#[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
@@ -59,49 +56,24 @@ enum Command {
Clean,
}
-#[derive(Subcommand, Debug)]
-enum Zeta1Command {
- Context {
- #[clap(flatten)]
- context_args: ContextArgs,
- },
-}
-
-#[derive(Subcommand, Debug)]
-enum Zeta2Command {
- Syntax {
- #[clap(flatten)]
- args: Zeta2Args,
- #[clap(flatten)]
- syntax_args: Zeta2SyntaxArgs,
- #[command(subcommand)]
- command: Zeta2SyntaxCommand,
- },
- Predict(PredictArguments),
- Eval(EvaluateArguments),
-}
-
-#[derive(Subcommand, Debug)]
-enum Zeta2SyntaxCommand {
- Context {
- #[clap(flatten)]
- context_args: ContextArgs,
- },
- Stats {
- #[arg(long)]
- worktree: PathBuf,
- #[arg(long)]
- extension: Option<String>,
- #[arg(long)]
- limit: Option<usize>,
- #[arg(long)]
- skip: Option<usize>,
- },
+#[derive(Debug, Args)]
+struct ContextStatsArgs {
+ #[arg(long)]
+ worktree: PathBuf,
+ #[arg(long)]
+ extension: Option<String>,
+ #[arg(long)]
+ limit: Option<usize>,
+ #[arg(long)]
+ skip: Option<usize>,
+ #[clap(flatten)]
+ zeta2_args: Zeta2Args,
}
#[derive(Debug, Args)]
-#[group(requires = "worktree")]
struct ContextArgs {
+ #[arg(long)]
+ provider: ContextProvider,
#[arg(long)]
worktree: PathBuf,
#[arg(long)]
@@ -110,9 +82,18 @@ struct ContextArgs {
use_language_server: bool,
#[arg(long)]
edit_history: Option<FileOrStdin>,
+ #[clap(flatten)]
+ zeta2_args: Zeta2Args,
}
-#[derive(Debug, Args)]
+#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
+enum ContextProvider {
+ Zeta1,
+ #[default]
+ Syntax,
+}
+
+#[derive(Clone, Debug, Args)]
struct Zeta2Args {
#[arg(long, default_value_t = 8192)]
max_prompt_bytes: usize,
@@ -130,39 +111,111 @@ struct Zeta2Args {
output_format: OutputFormat,
#[arg(long, default_value_t = 42)]
file_indexing_parallelism: usize,
-}
-
-#[derive(Debug, Args)]
-struct Zeta2SyntaxArgs {
#[arg(long, default_value_t = false)]
disable_imports_gathering: bool,
#[arg(long, default_value_t = u8::MAX)]
max_retrieved_definitions: u8,
}
-fn syntax_args_to_options(
- zeta2_args: &Zeta2Args,
- syntax_args: &Zeta2SyntaxArgs,
- omit_excerpt_overlaps: bool,
-) -> zeta2::ZetaOptions {
+#[derive(Debug, Args)]
+pub struct PredictArguments {
+ #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
+ format: PredictionsOutputFormat,
+ example_path: PathBuf,
+ #[clap(flatten)]
+ options: PredictionOptions,
+}
+
+#[derive(Clone, Debug, Args)]
+pub struct PredictionOptions {
+ #[arg(long)]
+ use_expected_context: bool,
+ #[clap(flatten)]
+ zeta2: Zeta2Args,
+ #[clap(long)]
+ provider: PredictionProvider,
+ #[clap(long, value_enum, default_value_t = CacheMode::default())]
+ cache: CacheMode,
+}
+
+#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
+pub enum CacheMode {
+ /// Use cached LLM requests and responses, except when multiple repetitions are requested
+ #[default]
+ Auto,
+ /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
+ #[value(alias = "request")]
+ Requests,
+ /// Ignore existing cache entries for both LLM and search.
+ Skip,
+ /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
+ /// Useful for reproducing results and fixing bugs outside of search queries
+ Force,
+}
+
+impl CacheMode {
+ fn use_cached_llm_responses(&self) -> bool {
+ self.assert_not_auto();
+ matches!(self, CacheMode::Requests | CacheMode::Force)
+ }
+
+ fn use_cached_search_results(&self) -> bool {
+ self.assert_not_auto();
+ matches!(self, CacheMode::Force)
+ }
+
+ fn assert_not_auto(&self) {
+ assert_ne!(
+ *self,
+ CacheMode::Auto,
+ "Cache mode should not be auto at this point!"
+ );
+ }
+}
+
+#[derive(clap::ValueEnum, Debug, Clone)]
+pub enum PredictionsOutputFormat {
+ Json,
+ Md,
+ Diff,
+}
+
+#[derive(Debug, Args)]
+pub struct EvaluateArguments {
+ example_paths: Vec<PathBuf>,
+ #[clap(flatten)]
+ options: PredictionOptions,
+ #[clap(short, long, default_value_t = 1, alias = "repeat")]
+ repetitions: u16,
+ #[arg(long)]
+ skip_prediction: bool,
+}
+
+#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
+enum PredictionProvider {
+ #[default]
+ Zeta2,
+ Sweep,
+}
+
+fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions {
zeta2::ZetaOptions {
context: ContextMode::Syntax(EditPredictionContextOptions {
- max_retrieved_declarations: syntax_args.max_retrieved_definitions,
- use_imports: !syntax_args.disable_imports_gathering,
+ max_retrieved_declarations: args.max_retrieved_definitions,
+ use_imports: !args.disable_imports_gathering,
excerpt: EditPredictionExcerptOptions {
- max_bytes: zeta2_args.max_excerpt_bytes,
- min_bytes: zeta2_args.min_excerpt_bytes,
- target_before_cursor_over_total_bytes: zeta2_args
- .target_before_cursor_over_total_bytes,
+ max_bytes: args.max_excerpt_bytes,
+ min_bytes: args.min_excerpt_bytes,
+ target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
},
score: EditPredictionScoreOptions {
omit_excerpt_overlaps,
},
}),
- max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
- max_prompt_bytes: zeta2_args.max_prompt_bytes,
- prompt_format: zeta2_args.prompt_format.into(),
- file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
+ max_diagnostic_bytes: args.max_diagnostic_bytes,
+ max_prompt_bytes: args.max_prompt_bytes,
+ prompt_format: args.prompt_format.into(),
+ file_indexing_parallelism: args.file_indexing_parallelism,
buffer_change_grouping_interval: Duration::ZERO,
}
}
@@ -320,8 +373,6 @@ async fn load_context(
}
async fn zeta2_syntax_context(
- zeta2_args: Zeta2Args,
- syntax_args: Zeta2SyntaxArgs,
args: ContextArgs,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
@@ -347,7 +398,7 @@ async fn zeta2_syntax_context(
zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
});
let indexing_done_task = zeta.update(cx, |zeta, cx| {
- zeta.set_options(syntax_args_to_options(&zeta2_args, &syntax_args, true));
+ zeta.set_options(zeta2_args_to_options(&args.zeta2_args, true));
zeta.register_buffer(&buffer, &project, cx);
zeta.wait_for_initial_indexing(&project, cx)
});
@@ -362,7 +413,7 @@ async fn zeta2_syntax_context(
let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;
- match zeta2_args.output_format {
+ match args.zeta2_args.output_format {
OutputFormat::Prompt => anyhow::Ok(prompt_string),
OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
@@ -427,57 +478,40 @@ fn main() {
panic!("Expected a command");
}
}
- Some(Command::Zeta1 {
- command: Zeta1Command::Context { context_args },
- }) => {
- let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
- let result = serde_json::to_string_pretty(&context.body).unwrap();
- println!("{}", result);
+ Some(Command::ContextStats(arguments)) => {
+ let result = retrieval_stats(
+ arguments.worktree,
+ app_state,
+ arguments.extension,
+ arguments.limit,
+ arguments.skip,
+ zeta2_args_to_options(&arguments.zeta2_args, false),
+ cx,
+ )
+ .await;
+ println!("{}", result.unwrap());
}
- Some(Command::Zeta2 { command }) => match command {
- Zeta2Command::Predict(arguments) => {
- run_zeta2_predict(arguments, &app_state, cx).await;
- }
- Zeta2Command::Eval(arguments) => {
- run_evaluate(arguments, &app_state, cx).await;
- }
- Zeta2Command::Syntax {
- args,
- syntax_args,
- command,
- } => {
- let result = match command {
- Zeta2SyntaxCommand::Context { context_args } => {
- zeta2_syntax_context(
- args,
- syntax_args,
- context_args,
- &app_state,
- cx,
- )
+ Some(Command::Context(context_args)) => {
+ let result = match context_args.provider {
+ ContextProvider::Zeta1 => {
+ let context =
+ zeta1_context(context_args, &app_state, cx).await.unwrap();
+ serde_json::to_string_pretty(&context.body).unwrap()
+ }
+ ContextProvider::Syntax => {
+ zeta2_syntax_context(context_args, &app_state, cx)
.await
- }
- Zeta2SyntaxCommand::Stats {
- worktree,
- extension,
- limit,
- skip,
- } => {
- retrieval_stats(
- worktree,
- app_state,
- extension,
- limit,
- skip,
- syntax_args_to_options(&args, &syntax_args, false),
- cx,
- )
- .await
- }
- };
- println!("{}", result.unwrap());
- }
- },
+ .unwrap()
+ }
+ };
+ println!("{}", result);
+ }
+ Some(Command::Predict(arguments)) => {
+ run_predict(arguments, &app_state, cx).await;
+ }
+ Some(Command::Eval(arguments)) => {
+ run_evaluate(arguments, &app_state, cx).await;
+ }
Some(Command::ConvertExample {
path,
output_format,
@@ -1,16 +1,18 @@
-use crate::PromptFormat;
use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
use crate::headless::ZetaCliAppState;
use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
+use crate::{
+ CacheMode, PredictArguments, PredictionOptions, PredictionProvider, PredictionsOutputFormat,
+};
use ::serde::Serialize;
use anyhow::{Context, Result, anyhow};
-use clap::{Args, ValueEnum};
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
use collections::HashMap;
use futures::StreamExt as _;
use gpui::{AppContext, AsyncApp, Entity};
use language::{Anchor, Buffer, Point};
use project::Project;
+use project::buffer_store::BufferStoreEvent;
use serde::Deserialize;
use std::fs;
use std::io::{IsTerminal, Write};
@@ -19,98 +21,86 @@ use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::{Duration, Instant};
+use sweep_ai::SweepAi;
use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};
-#[derive(Debug, Args)]
-pub struct PredictArguments {
- #[arg(long, value_enum, default_value_t = PromptFormat::default())]
- prompt_format: PromptFormat,
- #[arg(long)]
- use_expected_context: bool,
- #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
- format: PredictionsOutputFormat,
- example_path: PathBuf,
- #[clap(long, value_enum, default_value_t = CacheMode::default())]
- cache: CacheMode,
-}
-
-#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
-pub enum CacheMode {
- /// Use cached LLM requests and responses, except when multiple repetitions are requested
- #[default]
- Auto,
- /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
- #[value(alias = "request")]
- Requests,
- /// Ignore existing cache entries for both LLM and search.
- Skip,
- /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
- /// Useful for reproducing results and fixing bugs outside of search queries
- Force,
-}
-
-impl CacheMode {
- fn use_cached_llm_responses(&self) -> bool {
- self.assert_not_auto();
- matches!(self, CacheMode::Requests | CacheMode::Force)
- }
-
- fn use_cached_search_results(&self) -> bool {
- self.assert_not_auto();
- matches!(self, CacheMode::Force)
- }
-
- fn assert_not_auto(&self) {
- assert_ne!(
- *self,
- CacheMode::Auto,
- "Cache mode should not be auto at this point!"
- );
- }
-}
-
-#[derive(clap::ValueEnum, Debug, Clone)]
-pub enum PredictionsOutputFormat {
- Json,
- Md,
- Diff,
-}
-
-pub async fn run_zeta2_predict(
+pub async fn run_predict(
args: PredictArguments,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) {
let example = NamedExample::load(args.example_path).unwrap();
- let (project, mut zetas, _edited_buffers) =
- example.setup_project(app_state, 1, cx).await.unwrap();
- let result = zeta2_predict(
- example,
- project,
- zetas.remove(0),
- None,
- args.prompt_format,
- args.use_expected_context,
- args.cache,
- cx,
- )
- .await
- .unwrap();
+ let project = example.setup_project(app_state, cx).await.unwrap();
+ let zeta = setup_zeta(&project, app_state, cx).unwrap();
+ let sweep = if matches!(args.options.provider, PredictionProvider::Sweep) {
+ Some(setup_sweep(&project, cx).unwrap())
+ } else {
+ None
+ };
+ let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
+ let result = perform_predict(example, project, zeta, sweep, None, args.options, cx)
+ .await
+ .unwrap();
result.write(args.format, std::io::stdout()).unwrap();
print_run_data_dir(true, std::io::stdout().is_terminal());
}
-pub async fn zeta2_predict(
+pub fn setup_zeta(
+ project: &Entity<Project>,
+ app_state: &Arc<ZetaCliAppState>,
+ cx: &mut AsyncApp,
+) -> Result<Entity<Zeta>> {
+ let zeta =
+ cx.new(|cx| zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?;
+
+ let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
+
+ cx.subscribe(&buffer_store, {
+ let project = project.clone();
+ let zeta = zeta.clone();
+ move |_, event, cx| match event {
+ BufferStoreEvent::BufferAdded(buffer) => {
+ zeta.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
+ }
+ _ => {}
+ }
+ })?
+ .detach();
+
+ anyhow::Ok(zeta)
+}
+
+pub fn setup_sweep(project: &Entity<Project>, cx: &mut AsyncApp) -> Result<Entity<SweepAi>> {
+ let sweep = cx.new(|cx| SweepAi::new(cx))?;
+
+ let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
+
+ cx.subscribe(&buffer_store, {
+ let project = project.clone();
+ let sweep = sweep.clone();
+ move |_, event, cx| match event {
+ BufferStoreEvent::BufferAdded(buffer) => {
+ sweep.update(cx, |sweep, cx| sweep.register_buffer(&buffer, &project, cx));
+ }
+ _ => {}
+ }
+ })?
+ .detach();
+
+ anyhow::Ok(sweep)
+}
+
+pub async fn perform_predict(
example: NamedExample,
project: Entity<Project>,
zeta: Entity<Zeta>,
+ sweep: Option<Entity<SweepAi>>,
repetition_ix: Option<u16>,
- prompt_format: PromptFormat,
- use_expected_context: bool,
- mut cache_mode: CacheMode,
+ options: PredictionOptions,
cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
+ let mut cache_mode = options.cache;
if repetition_ix.is_some() {
if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip {
panic!("Repetitions are not supported in Auto cache mode");
@@ -148,94 +138,8 @@ pub async fn zeta2_predict(
let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone())));
- let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
-
- let debug_task = cx.background_spawn({
- let result = result.clone();
- async move {
- let mut start_time = None;
- let mut search_queries_generated_at = None;
- let mut search_queries_executed_at = None;
- while let Some(event) = debug_rx.next().await {
- match event {
- zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
- start_time = Some(info.timestamp);
- fs::write(
- example_run_dir.join("search_prompt.md"),
- &info.search_prompt,
- )?;
- }
- zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
- search_queries_generated_at = Some(info.timestamp);
- fs::write(
- example_run_dir.join("search_queries.json"),
- serde_json::to_string_pretty(&info.search_queries).unwrap(),
- )?;
- }
- zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
- search_queries_executed_at = Some(info.timestamp);
- }
- zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
- zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
- let prediction_started_at = Instant::now();
- start_time.get_or_insert(prediction_started_at);
- let prompt = request.local_prompt.unwrap_or_default();
- fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
-
- {
- let mut result = result.lock().unwrap();
- result.prompt_len = prompt.chars().count();
-
- for included_file in request.request.included_files {
- let insertions =
- vec![(request.request.cursor_point, CURSOR_MARKER)];
- result.excerpts.extend(included_file.excerpts.iter().map(
- |excerpt| ActualExcerpt {
- path: included_file.path.components().skip(1).collect(),
- text: String::from(excerpt.text.as_ref()),
- },
- ));
- write_codeblock(
- &included_file.path,
- included_file.excerpts.iter(),
- if included_file.path == request.request.excerpt_path {
- &insertions
- } else {
- &[]
- },
- included_file.max_row,
- false,
- &mut result.excerpts_text,
- );
- }
- }
- let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
- let response = zeta2::text_from_response(response).unwrap_or_default();
- let prediction_finished_at = Instant::now();
- fs::write(example_run_dir.join("prediction_response.md"), &response)?;
-
- let mut result = result.lock().unwrap();
- result.generated_len = response.chars().count();
-
- if !use_expected_context {
- result.planning_search_time =
- Some(search_queries_generated_at.unwrap() - start_time.unwrap());
- result.running_search_time = Some(
- search_queries_executed_at.unwrap()
- - search_queries_generated_at.unwrap(),
- );
- }
- result.prediction_time = prediction_finished_at - prediction_started_at;
- result.total_time = prediction_finished_at - start_time.unwrap();
-
- break;
- }
- }
- }
- anyhow::Ok(())
- }
- });
+ let prompt_format = options.zeta2.prompt_format;
zeta.update(cx, |zeta, _cx| {
let mut options = zeta.options().clone();
@@ -243,55 +147,194 @@ pub async fn zeta2_predict(
zeta.set_options(options);
})?;
- if use_expected_context {
- let context_excerpts_tasks = example
- .example
- .expected_context
- .iter()
- .flat_map(|section| {
- section.alternatives[0].excerpts.iter().map(|excerpt| {
- resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
- })
- })
- .collect::<Vec<_>>();
- let context_excerpts_vec = futures::future::try_join_all(context_excerpts_tasks).await?;
-
- let mut context_excerpts = HashMap::default();
- for (buffer, mut excerpts) in context_excerpts_vec {
- context_excerpts
- .entry(buffer)
- .or_insert(Vec::new())
- .append(&mut excerpts);
- }
+ let prediction = match options.provider {
+ crate::PredictionProvider::Zeta2 => {
+ let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
+
+ let debug_task = cx.background_spawn({
+ let result = result.clone();
+ async move {
+ let mut start_time = None;
+ let mut search_queries_generated_at = None;
+ let mut search_queries_executed_at = None;
+ while let Some(event) = debug_rx.next().await {
+ match event {
+ zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
+ start_time = Some(info.timestamp);
+ fs::write(
+ example_run_dir.join("search_prompt.md"),
+ &info.search_prompt,
+ )?;
+ }
+ zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
+ search_queries_generated_at = Some(info.timestamp);
+ fs::write(
+ example_run_dir.join("search_queries.json"),
+ serde_json::to_string_pretty(&info.search_queries).unwrap(),
+ )?;
+ }
+ zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
+ search_queries_executed_at = Some(info.timestamp);
+ }
+ zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
+ zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
+ let prediction_started_at = Instant::now();
+ start_time.get_or_insert(prediction_started_at);
+ let prompt = request.local_prompt.unwrap_or_default();
+ fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
+
+ {
+ let mut result = result.lock().unwrap();
+ result.prompt_len = prompt.chars().count();
+
+ for included_file in request.request.included_files {
+ let insertions =
+ vec![(request.request.cursor_point, CURSOR_MARKER)];
+ result.excerpts.extend(included_file.excerpts.iter().map(
+ |excerpt| {
+ ActualExcerpt {
+ path: included_file
+ .path
+ .components()
+ .skip(1)
+ .collect(),
+ text: String::from(excerpt.text.as_ref()),
+ }
+ },
+ ));
+ write_codeblock(
+ &included_file.path,
+ included_file.excerpts.iter(),
+ if included_file.path == request.request.excerpt_path {
+ &insertions
+ } else {
+ &[]
+ },
+ included_file.max_row,
+ false,
+ &mut result.excerpts_text,
+ );
+ }
+ }
+
+ let response =
+ request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
+ let response =
+ zeta2::text_from_response(response).unwrap_or_default();
+ let prediction_finished_at = Instant::now();
+ fs::write(
+ example_run_dir.join("prediction_response.md"),
+ &response,
+ )?;
+
+ let mut result = result.lock().unwrap();
+ result.generated_len = response.chars().count();
+
+ if !options.use_expected_context {
+ result.planning_search_time = Some(
+ search_queries_generated_at.unwrap() - start_time.unwrap(),
+ );
+ result.running_search_time = Some(
+ search_queries_executed_at.unwrap()
+ - search_queries_generated_at.unwrap(),
+ );
+ }
+ result.prediction_time =
+ prediction_finished_at - prediction_started_at;
+ result.total_time = prediction_finished_at - start_time.unwrap();
+
+ break;
+ }
+ }
+ }
+ anyhow::Ok(())
+ }
+ });
+
+ if options.use_expected_context {
+ let context_excerpts_tasks = example
+ .example
+ .expected_context
+ .iter()
+ .flat_map(|section| {
+ section.alternatives[0].excerpts.iter().map(|excerpt| {
+ resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
+ })
+ })
+ .collect::<Vec<_>>();
+ let context_excerpts_vec =
+ futures::future::try_join_all(context_excerpts_tasks).await?;
+
+ let mut context_excerpts = HashMap::default();
+ for (buffer, mut excerpts) in context_excerpts_vec {
+ context_excerpts
+ .entry(buffer)
+ .or_insert(Vec::new())
+ .append(&mut excerpts);
+ }
- zeta.update(cx, |zeta, _cx| {
- zeta.set_context(project.clone(), context_excerpts)
- })?;
- } else {
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
- })?
- .await?;
- }
+ zeta.update(cx, |zeta, _cx| {
+ zeta.set_context(project.clone(), context_excerpts)
+ })?;
+ } else {
+ zeta.update(cx, |zeta, cx| {
+ zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
+ })?
+ .await?;
+ }
- let prediction = zeta
- .update(cx, |zeta, cx| {
- zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
- })?
- .await?;
+ let prediction = zeta
+ .update(cx, |zeta, cx| {
+ zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
+ })?
+ .await?
+ .map(|prediction| (prediction.buffer, prediction.snapshot, prediction.edits));
+
+ debug_task.await?;
+
+ prediction
+ }
+ crate::PredictionProvider::Sweep => sweep
+ .unwrap()
+ .update(cx, |sweep, cx| {
+ let mut recent_paths = Vec::new();
+ for path in zeta
+ .read(cx)
+ .history_for_project(&project)
+ .rev()
+ .filter_map(|event| event.project_path(cx))
+ {
+ if !recent_paths.contains(&path) {
+ recent_paths.push(path);
+ }
+ }
- debug_task.await?;
+ sweep.request_completion(
+ &project,
+ recent_paths.into_iter(),
+ &cursor_buffer,
+ cursor_anchor,
+ cx,
+ )
+ })?
+ .await?
+ .map(
+ |sweep_ai::EditPrediction {
+ edits, snapshot, ..
+ }| { (cursor_buffer.clone(), snapshot, edits) },
+ ),
+ };
let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
+
result.diff = prediction
- .map(|prediction| {
- let old_text = prediction.snapshot.text();
- let new_text = prediction
- .buffer
+ .map(|(buffer, snapshot, edits)| {
+ let old_text = snapshot.text();
+ let new_text = buffer
.update(cx, |buffer, cx| {
let branch = buffer.branch(cx);
branch.update(cx, |branch, cx| {
- branch.edit(prediction.edits.iter().cloned(), None, cx);
+ branch.edit(edits.iter().cloned(), None, cx);
branch.text()
})
})