diff --git a/Cargo.lock b/Cargo.lock index 0a3d358410784f5fd9057a30f9a70d49e2fd2d90..4f9a3f26e9a20df498bd3b735cfec54aa77c77cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -21864,6 +21864,7 @@ dependencies = [ "shellexpand 2.1.2", "smol", "soa-rs", + "sweep_ai", "terminal_view", "toml 0.8.23", "util", diff --git a/crates/sweep_ai/src/sweep_ai.rs b/crates/sweep_ai/src/sweep_ai.rs index 75f6f123d5f2460fa7f2f078bec17fad0eb8acaf..1b4c92120d866a218987f36161e9520a0f3f703a 100644 --- a/crates/sweep_ai/src/sweep_ai.rs +++ b/crates/sweep_ai/src/sweep_ai.rs @@ -11,7 +11,7 @@ use http_client::{AsyncBody, Method}; use language::{ Anchor, Buffer, BufferSnapshot, EditPreview, Point, ToOffset as _, ToPoint, text_diff, }; -use project::Project; +use project::{Project, ProjectPath}; use release_channel::{AppCommitSha, AppVersion}; use std::collections::{VecDeque, hash_map}; use std::fmt::{self, Display}; @@ -48,11 +48,11 @@ impl Global for SweepAiGlobal {} #[derive(Clone)] pub struct EditPrediction { - id: EditPredictionId, - path: Arc, - edits: Arc<[(Range, Arc)]>, - snapshot: BufferSnapshot, - edit_preview: EditPreview, + pub id: EditPredictionId, + pub path: Arc, + pub edits: Arc<[(Range, Arc)]>, + pub snapshot: BufferSnapshot, + pub edit_preview: EditPreview, } impl EditPrediction { @@ -110,7 +110,7 @@ impl SweepAi { } } - fn new(cx: &mut Context) -> Self { + pub fn new(cx: &mut Context) -> Self { Self { api_token: std::env::var("SWEEP_AI_TOKEN").ok(), projects: HashMap::default(), @@ -195,8 +195,8 @@ impl SweepAi { pub fn request_completion( &mut self, - workspace: &WeakEntity, project: &Entity, + recent_buffers: impl Iterator, active_buffer: &Entity, position: language::Anchor, cx: &mut Context, @@ -223,26 +223,17 @@ impl SweepAi { let events = project_state.events.clone(); let http_client = cx.http_client(); - let Some(recent_buffers) = workspace - .read_with(cx, |workspace, cx| { - workspace - .recent_navigation_history_iter(cx) - .filter_map(|(project_path, _)| { - let buffer = project.read(cx).get_open_buffer(&project_path, cx)?; - - if active_buffer == &buffer { - None - } else { - Some(buffer.read(cx).snapshot()) - } - }) - .take(3) - .collect::>() + let recent_buffer_snapshots = recent_buffers + .filter_map(|project_path| { + let buffer = project.read(cx).get_open_buffer(&project_path, cx)?; + if active_buffer == &buffer { + None + } else { + Some(buffer.read(cx).snapshot()) + } }) - .log_err() - else { - return Task::ready(Ok(None)); - }; + .take(3) + .collect::>(); let result = cx.background_spawn({ let full_path = full_path.clone(); @@ -255,7 +246,7 @@ impl SweepAi { writeln!(&mut recent_changes, "{event}")?; } - let file_chunks = recent_buffers + let file_chunks = recent_buffer_snapshots .into_iter() .map(|snapshot| { let end_point = language::Point::new(30, 0).min(snapshot.max_point()); @@ -623,8 +614,23 @@ impl edit_prediction::EditPredictionProvider for SweepAiEditPredictionProvider { let completion_request = this.update(cx, |this, cx| { this.last_request_timestamp = Instant::now(); + this.sweep_ai.update(cx, |sweep_ai, cx| { - sweep_ai.request_completion(&workspace, &project, &buffer, position, cx) + let Some(recent_buffers) = workspace + .read_with(cx, |workspace, cx| { + workspace.recent_navigation_history_iter(cx) + }) + .log_err() + else { + return Task::ready(Ok(None)); + }; + sweep_ai.request_completion( + &project, + recent_buffers.map(move |(project_path, _)| project_path), + &buffer, + position, + cx, + ) }) }); diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index 20139e3ae8104fc0d4c1bce98f265144ef344f0d..14b33af6cd1f8778a9bbafeb8e9854cc9fc11247 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -1845,7 +1845,7 @@ impl Workspace { pub fn recent_navigation_history_iter( &self, cx: &App, - ) -> impl Iterator)> { + ) -> impl Iterator)> + use<> { let mut abs_paths_opened: HashMap> = HashMap::default(); let mut history: HashMap, usize)> = HashMap::default(); diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 881a7254f876e1b2df636513480115bf36489a24..099cd95134ec3d1fd59bbc33306bc439c0a8ee1a 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -50,7 +50,7 @@ pub mod udiff; mod xml_edits; use crate::assemble_excerpts::assemble_excerpts; -use crate::prediction::EditPrediction; +pub use crate::prediction::EditPrediction; pub use crate::prediction::EditPredictionId; pub use provider::ZetaEditPredictionProvider; @@ -327,6 +327,14 @@ impl Event { } } } + + pub fn project_path(&self, cx: &App) -> Option { + match self { + Event::BufferChange { new_snapshot, .. } => new_snapshot + .file() + .map(|f| project::ProjectPath::from_file(f.as_ref(), cx)), + } + } } impl Zeta { @@ -401,7 +409,10 @@ impl Zeta { } } - pub fn history_for_project(&self, project: &Entity) -> impl Iterator { + pub fn history_for_project( + &self, + project: &Entity, + ) -> impl DoubleEndedIterator { self.projects .get(&project.entity_id()) .map(|project| project.events.iter()) diff --git a/crates/zeta_cli/Cargo.toml b/crates/zeta_cli/Cargo.toml index e18cf54787ca98e2be60db4977dd2de18e9c09e2..35fbcb1c61097156d2f0e172d700ed12d3d3894e 100644 --- a/crates/zeta_cli/Cargo.toml +++ b/crates/zeta_cli/Cargo.toml @@ -49,6 +49,7 @@ settings.workspace = true shellexpand.workspace = true smol.workspace = true soa-rs = "0.8.1" +sweep_ai.workspace = true terminal_view.workspace = true toml.workspace = true util.workspace = true diff --git a/crates/zeta_cli/src/evaluate.rs b/crates/zeta_cli/src/evaluate.rs index d808e3d743d7009ca66a75b3a349914b0a4f5447..09fbbb29dd6cf58910a2b6e6ff7fb4a31fc4a10a 100644 --- a/crates/zeta_cli/src/evaluate.rs +++ b/crates/zeta_cli/src/evaluate.rs @@ -1,41 +1,25 @@ use std::{ collections::{BTreeSet, HashMap}, io::{IsTerminal, Write}, - path::PathBuf, sync::Arc, }; use anyhow::Result; -use clap::Args; use collections::HashSet; use gpui::{AsyncApp, Entity}; use project::Project; +use sweep_ai::SweepAi; use util::ResultExt as _; use zeta2::{Zeta, udiff::DiffLine}; use crate::{ - PromptFormat, + EvaluateArguments, PredictionOptions, PredictionProvider, example::{Example, NamedExample}, headless::ZetaCliAppState, paths::print_run_data_dir, - predict::{CacheMode, PredictionDetails, zeta2_predict}, + predict::{PredictionDetails, perform_predict, setup_sweep, setup_zeta}, }; -#[derive(Debug, Args)] -pub struct EvaluateArguments { - example_paths: Vec, - #[arg(long, value_enum, default_value_t = PromptFormat::default())] - prompt_format: PromptFormat, - #[arg(long)] - use_expected_context: bool, - #[clap(long, value_enum, default_value_t = CacheMode::default())] - cache: CacheMode, - #[clap(short, long, default_value_t = 1, alias = "repeat")] - repetitions: u16, - #[arg(long)] - skip_prediction: bool, -} - #[derive(Debug)] pub(crate) struct ExecutionData { execution_id: String, @@ -52,38 +36,56 @@ pub async fn run_evaluate( eprintln!("No examples provided"); return; } + let all_tasks = args.example_paths.into_iter().map(|path| { + let options = args.options.clone(); let app_state = app_state.clone(); let example = NamedExample::load(&path).expect("Failed to load example"); cx.spawn(async move |cx| { - let (project, zetas, _edited_buffers) = example - .setup_project(&app_state, args.repetitions, cx) - .await - .unwrap(); - - let tasks = zetas.into_iter().enumerate().map(|(repetition_ix, zeta)| { - let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16); - let example = example.clone(); - let project = project.clone(); - - cx.spawn(async move |cx| { - let name = example.name.clone(); - run_evaluate_one( - example, - repetition_ix, - project, - zeta, - args.prompt_format, - args.use_expected_context, - !args.skip_prediction, - args.cache, - cx, + let project = example.setup_project(&app_state, cx).await.unwrap(); + + let providers = (0..args.repetitions) + .map(|_| { + ( + setup_zeta(&project, &app_state, cx).unwrap(), + if matches!(args.options.provider, PredictionProvider::Sweep) { + Some(setup_sweep(&project, cx).unwrap()) + } else { + None + }, ) - .await - .map_err(|err| (err, name, repetition_ix)) }) - }); + .collect::>(); + + let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap(); + + let tasks = + providers + .into_iter() + .enumerate() + .map(move |(repetition_ix, (zeta, sweep))| { + let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16); + let example = example.clone(); + let project = project.clone(); + let options = options.clone(); + + cx.spawn(async move |cx| { + let name = example.name.clone(); + run_evaluate_one( + example, + repetition_ix, + project, + zeta, + sweep, + options, + !args.skip_prediction, + cx, + ) + .await + .map_err(|err| (err, name, repetition_ix)) + }) + }); futures::future::join_all(tasks).await }) }); @@ -175,20 +177,18 @@ pub async fn run_evaluate_one( repetition_ix: Option, project: Entity, zeta: Entity, - prompt_format: PromptFormat, - use_expected_context: bool, + sweep: Option>, + prediction_options: PredictionOptions, predict: bool, - cache_mode: CacheMode, cx: &mut AsyncApp, ) -> Result<(EvaluationResult, ExecutionData)> { - let predict_result = zeta2_predict( + let predict_result = perform_predict( example.clone(), project, zeta, + sweep, repetition_ix, - prompt_format, - use_expected_context, - cache_mode, + prediction_options, cx, ) .await?; diff --git a/crates/zeta_cli/src/example.rs b/crates/zeta_cli/src/example.rs index 300e453af93bd3c69a47f5e155e274431aa01c92..67eed23f90dc1a5b48a53a2a7de07f500396ba9f 100644 --- a/crates/zeta_cli/src/example.rs +++ b/crates/zeta_cli/src/example.rs @@ -20,13 +20,13 @@ use futures::{ lock::{Mutex, OwnedMutexGuard}, }; use futures::{FutureExt as _, future::Shared}; -use gpui::{AppContext as _, AsyncApp, Entity, Task, http_client::Url}; +use gpui::{AsyncApp, Entity, Task, http_client::Url}; use language::{Anchor, Buffer}; use project::{Project, ProjectPath}; use pulldown_cmark::CowStr; use serde::{Deserialize, Serialize}; use util::{paths::PathStyle, rel_path::RelPath}; -use zeta2::{Zeta, udiff::OpenedBuffers}; +use zeta2::udiff::OpenedBuffers; use crate::paths::{REPOS_DIR, WORKTREES_DIR}; @@ -318,12 +318,11 @@ impl NamedExample { } } - pub async fn setup_project<'a>( - &'a self, + pub async fn setup_project( + &self, app_state: &Arc, - repetitions: u16, cx: &mut AsyncApp, - ) -> Result<(Entity, Vec>, OpenedBuffers<'a>)> { + ) -> Result> { let worktree_path = self.setup_worktree().await?; static AUTHENTICATED: OnceLock>> = OnceLock::new(); @@ -365,33 +364,7 @@ impl NamedExample { })? .await; - let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?; - - let zetas = (0..repetitions) - .map(|_| { - let zeta = cx.new(|cx| { - zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx) - })?; - - cx.subscribe(&buffer_store, { - let project = project.clone(); - let zeta = zeta.clone(); - move |_, event, cx| match event { - project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => { - zeta.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx)); - } - _ => {} - } - })? - .detach(); - - anyhow::Ok(zeta) - }) - .collect::>>()?; - - let edited_buffers = self.apply_edit_history(&project, cx).await?; - - anyhow::Ok((project, zetas, edited_buffers)) + anyhow::Ok(project) } pub async fn setup_worktree(&self) -> Result { diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index 517deb6ec7482ca2712a347531b24eca5ed16796..803e02b10cfb7533341a3009e0325a7bcf13df1e 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -7,13 +7,18 @@ mod source_location; mod syntax_retrieval_stats; mod util; -use crate::evaluate::{EvaluateArguments, run_evaluate}; -use crate::example::{ExampleFormat, NamedExample}; -use crate::predict::{PredictArguments, run_zeta2_predict}; -use crate::syntax_retrieval_stats::retrieval_stats; +use crate::{ + evaluate::run_evaluate, + example::{ExampleFormat, NamedExample}, + headless::ZetaCliAppState, + predict::run_predict, + source_location::SourceLocation, + syntax_retrieval_stats::retrieval_stats, + util::{open_buffer, open_buffer_with_language_server}, +}; use ::util::paths::PathStyle; use anyhow::{Result, anyhow}; -use clap::{Args, Parser, Subcommand}; +use clap::{Args, Parser, Subcommand, ValueEnum}; use cloud_llm_client::predict_edits_v3; use edit_prediction_context::{ EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions, @@ -28,10 +33,6 @@ use std::time::Duration; use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc}; use zeta2::ContextMode; -use crate::headless::ZetaCliAppState; -use crate::source_location::SourceLocation; -use crate::util::{open_buffer, open_buffer_with_language_server}; - #[derive(Parser, Debug)] #[command(name = "zeta")] struct ZetaCliArgs { @@ -43,14 +44,10 @@ struct ZetaCliArgs { #[derive(Subcommand, Debug)] enum Command { - Zeta1 { - #[command(subcommand)] - command: Zeta1Command, - }, - Zeta2 { - #[command(subcommand)] - command: Zeta2Command, - }, + Context(ContextArgs), + ContextStats(ContextStatsArgs), + Predict(PredictArguments), + Eval(EvaluateArguments), ConvertExample { path: PathBuf, #[arg(long, value_enum, default_value_t = ExampleFormat::Md)] @@ -59,49 +56,24 @@ enum Command { Clean, } -#[derive(Subcommand, Debug)] -enum Zeta1Command { - Context { - #[clap(flatten)] - context_args: ContextArgs, - }, -} - -#[derive(Subcommand, Debug)] -enum Zeta2Command { - Syntax { - #[clap(flatten)] - args: Zeta2Args, - #[clap(flatten)] - syntax_args: Zeta2SyntaxArgs, - #[command(subcommand)] - command: Zeta2SyntaxCommand, - }, - Predict(PredictArguments), - Eval(EvaluateArguments), -} - -#[derive(Subcommand, Debug)] -enum Zeta2SyntaxCommand { - Context { - #[clap(flatten)] - context_args: ContextArgs, - }, - Stats { - #[arg(long)] - worktree: PathBuf, - #[arg(long)] - extension: Option, - #[arg(long)] - limit: Option, - #[arg(long)] - skip: Option, - }, +#[derive(Debug, Args)] +struct ContextStatsArgs { + #[arg(long)] + worktree: PathBuf, + #[arg(long)] + extension: Option, + #[arg(long)] + limit: Option, + #[arg(long)] + skip: Option, + #[clap(flatten)] + zeta2_args: Zeta2Args, } #[derive(Debug, Args)] -#[group(requires = "worktree")] struct ContextArgs { + #[arg(long)] + provider: ContextProvider, #[arg(long)] worktree: PathBuf, #[arg(long)] @@ -110,9 +82,18 @@ struct ContextArgs { use_language_server: bool, #[arg(long)] edit_history: Option, + #[clap(flatten)] + zeta2_args: Zeta2Args, } -#[derive(Debug, Args)] +#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)] +enum ContextProvider { + Zeta1, + #[default] + Syntax, +} + +#[derive(Clone, Debug, Args)] struct Zeta2Args { #[arg(long, default_value_t = 8192)] max_prompt_bytes: usize, @@ -130,39 +111,111 @@ struct Zeta2Args { output_format: OutputFormat, #[arg(long, default_value_t = 42)] file_indexing_parallelism: usize, -} - -#[derive(Debug, Args)] -struct Zeta2SyntaxArgs { #[arg(long, default_value_t = false)] disable_imports_gathering: bool, #[arg(long, default_value_t = u8::MAX)] max_retrieved_definitions: u8, } -fn syntax_args_to_options( - zeta2_args: &Zeta2Args, - syntax_args: &Zeta2SyntaxArgs, - omit_excerpt_overlaps: bool, -) -> zeta2::ZetaOptions { +#[derive(Debug, Args)] +pub struct PredictArguments { + #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)] + format: PredictionsOutputFormat, + example_path: PathBuf, + #[clap(flatten)] + options: PredictionOptions, +} + +#[derive(Clone, Debug, Args)] +pub struct PredictionOptions { + #[arg(long)] + use_expected_context: bool, + #[clap(flatten)] + zeta2: Zeta2Args, + #[clap(long)] + provider: PredictionProvider, + #[clap(long, value_enum, default_value_t = CacheMode::default())] + cache: CacheMode, +} + +#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)] +pub enum CacheMode { + /// Use cached LLM requests and responses, except when multiple repetitions are requested + #[default] + Auto, + /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint. + #[value(alias = "request")] + Requests, + /// Ignore existing cache entries for both LLM and search. + Skip, + /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet. + /// Useful for reproducing results and fixing bugs outside of search queries + Force, +} + +impl CacheMode { + fn use_cached_llm_responses(&self) -> bool { + self.assert_not_auto(); + matches!(self, CacheMode::Requests | CacheMode::Force) + } + + fn use_cached_search_results(&self) -> bool { + self.assert_not_auto(); + matches!(self, CacheMode::Force) + } + + fn assert_not_auto(&self) { + assert_ne!( + *self, + CacheMode::Auto, + "Cache mode should not be auto at this point!" + ); + } +} + +#[derive(clap::ValueEnum, Debug, Clone)] +pub enum PredictionsOutputFormat { + Json, + Md, + Diff, +} + +#[derive(Debug, Args)] +pub struct EvaluateArguments { + example_paths: Vec, + #[clap(flatten)] + options: PredictionOptions, + #[clap(short, long, default_value_t = 1, alias = "repeat")] + repetitions: u16, + #[arg(long)] + skip_prediction: bool, +} + +#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)] +enum PredictionProvider { + #[default] + Zeta2, + Sweep, +} + +fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions { zeta2::ZetaOptions { context: ContextMode::Syntax(EditPredictionContextOptions { - max_retrieved_declarations: syntax_args.max_retrieved_definitions, - use_imports: !syntax_args.disable_imports_gathering, + max_retrieved_declarations: args.max_retrieved_definitions, + use_imports: !args.disable_imports_gathering, excerpt: EditPredictionExcerptOptions { - max_bytes: zeta2_args.max_excerpt_bytes, - min_bytes: zeta2_args.min_excerpt_bytes, - target_before_cursor_over_total_bytes: zeta2_args - .target_before_cursor_over_total_bytes, + max_bytes: args.max_excerpt_bytes, + min_bytes: args.min_excerpt_bytes, + target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes, }, score: EditPredictionScoreOptions { omit_excerpt_overlaps, }, }), - max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes, - max_prompt_bytes: zeta2_args.max_prompt_bytes, - prompt_format: zeta2_args.prompt_format.into(), - file_indexing_parallelism: zeta2_args.file_indexing_parallelism, + max_diagnostic_bytes: args.max_diagnostic_bytes, + max_prompt_bytes: args.max_prompt_bytes, + prompt_format: args.prompt_format.into(), + file_indexing_parallelism: args.file_indexing_parallelism, buffer_change_grouping_interval: Duration::ZERO, } } @@ -320,8 +373,6 @@ async fn load_context( } async fn zeta2_syntax_context( - zeta2_args: Zeta2Args, - syntax_args: Zeta2SyntaxArgs, args: ContextArgs, app_state: &Arc, cx: &mut AsyncApp, @@ -347,7 +398,7 @@ async fn zeta2_syntax_context( zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx) }); let indexing_done_task = zeta.update(cx, |zeta, cx| { - zeta.set_options(syntax_args_to_options(&zeta2_args, &syntax_args, true)); + zeta.set_options(zeta2_args_to_options(&args.zeta2_args, true)); zeta.register_buffer(&buffer, &project, cx); zeta.wait_for_initial_indexing(&project, cx) }); @@ -362,7 +413,7 @@ async fn zeta2_syntax_context( let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?; - match zeta2_args.output_format { + match args.zeta2_args.output_format { OutputFormat::Prompt => anyhow::Ok(prompt_string), OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?), OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({ @@ -427,57 +478,40 @@ fn main() { panic!("Expected a command"); } } - Some(Command::Zeta1 { - command: Zeta1Command::Context { context_args }, - }) => { - let context = zeta1_context(context_args, &app_state, cx).await.unwrap(); - let result = serde_json::to_string_pretty(&context.body).unwrap(); - println!("{}", result); + Some(Command::ContextStats(arguments)) => { + let result = retrieval_stats( + arguments.worktree, + app_state, + arguments.extension, + arguments.limit, + arguments.skip, + zeta2_args_to_options(&arguments.zeta2_args, false), + cx, + ) + .await; + println!("{}", result.unwrap()); } - Some(Command::Zeta2 { command }) => match command { - Zeta2Command::Predict(arguments) => { - run_zeta2_predict(arguments, &app_state, cx).await; - } - Zeta2Command::Eval(arguments) => { - run_evaluate(arguments, &app_state, cx).await; - } - Zeta2Command::Syntax { - args, - syntax_args, - command, - } => { - let result = match command { - Zeta2SyntaxCommand::Context { context_args } => { - zeta2_syntax_context( - args, - syntax_args, - context_args, - &app_state, - cx, - ) + Some(Command::Context(context_args)) => { + let result = match context_args.provider { + ContextProvider::Zeta1 => { + let context = + zeta1_context(context_args, &app_state, cx).await.unwrap(); + serde_json::to_string_pretty(&context.body).unwrap() + } + ContextProvider::Syntax => { + zeta2_syntax_context(context_args, &app_state, cx) .await - } - Zeta2SyntaxCommand::Stats { - worktree, - extension, - limit, - skip, - } => { - retrieval_stats( - worktree, - app_state, - extension, - limit, - skip, - syntax_args_to_options(&args, &syntax_args, false), - cx, - ) - .await - } - }; - println!("{}", result.unwrap()); - } - }, + .unwrap() + } + }; + println!("{}", result); + } + Some(Command::Predict(arguments)) => { + run_predict(arguments, &app_state, cx).await; + } + Some(Command::Eval(arguments)) => { + run_evaluate(arguments, &app_state, cx).await; + } Some(Command::ConvertExample { path, output_format, diff --git a/crates/zeta_cli/src/predict.rs b/crates/zeta_cli/src/predict.rs index 28eb7e426c21126b1c91dc62132c1bf460a93661..4505035eaf992751e85216a314b731a12ffbd342 100644 --- a/crates/zeta_cli/src/predict.rs +++ b/crates/zeta_cli/src/predict.rs @@ -1,16 +1,18 @@ -use crate::PromptFormat; use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample}; use crate::headless::ZetaCliAppState; use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir}; +use crate::{ + CacheMode, PredictArguments, PredictionOptions, PredictionProvider, PredictionsOutputFormat, +}; use ::serde::Serialize; use anyhow::{Context, Result, anyhow}; -use clap::{Args, ValueEnum}; use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock}; use collections::HashMap; use futures::StreamExt as _; use gpui::{AppContext, AsyncApp, Entity}; use language::{Anchor, Buffer, Point}; use project::Project; +use project::buffer_store::BufferStoreEvent; use serde::Deserialize; use std::fs; use std::io::{IsTerminal, Write}; @@ -19,98 +21,86 @@ use std::path::PathBuf; use std::sync::Arc; use std::sync::Mutex; use std::time::{Duration, Instant}; +use sweep_ai::SweepAi; use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta}; -#[derive(Debug, Args)] -pub struct PredictArguments { - #[arg(long, value_enum, default_value_t = PromptFormat::default())] - prompt_format: PromptFormat, - #[arg(long)] - use_expected_context: bool, - #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)] - format: PredictionsOutputFormat, - example_path: PathBuf, - #[clap(long, value_enum, default_value_t = CacheMode::default())] - cache: CacheMode, -} - -#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)] -pub enum CacheMode { - /// Use cached LLM requests and responses, except when multiple repetitions are requested - #[default] - Auto, - /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint. - #[value(alias = "request")] - Requests, - /// Ignore existing cache entries for both LLM and search. - Skip, - /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet. - /// Useful for reproducing results and fixing bugs outside of search queries - Force, -} - -impl CacheMode { - fn use_cached_llm_responses(&self) -> bool { - self.assert_not_auto(); - matches!(self, CacheMode::Requests | CacheMode::Force) - } - - fn use_cached_search_results(&self) -> bool { - self.assert_not_auto(); - matches!(self, CacheMode::Force) - } - - fn assert_not_auto(&self) { - assert_ne!( - *self, - CacheMode::Auto, - "Cache mode should not be auto at this point!" - ); - } -} - -#[derive(clap::ValueEnum, Debug, Clone)] -pub enum PredictionsOutputFormat { - Json, - Md, - Diff, -} - -pub async fn run_zeta2_predict( +pub async fn run_predict( args: PredictArguments, app_state: &Arc, cx: &mut AsyncApp, ) { let example = NamedExample::load(args.example_path).unwrap(); - let (project, mut zetas, _edited_buffers) = - example.setup_project(app_state, 1, cx).await.unwrap(); - let result = zeta2_predict( - example, - project, - zetas.remove(0), - None, - args.prompt_format, - args.use_expected_context, - args.cache, - cx, - ) - .await - .unwrap(); + let project = example.setup_project(app_state, cx).await.unwrap(); + let zeta = setup_zeta(&project, app_state, cx).unwrap(); + let sweep = if matches!(args.options.provider, PredictionProvider::Sweep) { + Some(setup_sweep(&project, cx).unwrap()) + } else { + None + }; + let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap(); + let result = perform_predict(example, project, zeta, sweep, None, args.options, cx) + .await + .unwrap(); result.write(args.format, std::io::stdout()).unwrap(); print_run_data_dir(true, std::io::stdout().is_terminal()); } -pub async fn zeta2_predict( +pub fn setup_zeta( + project: &Entity, + app_state: &Arc, + cx: &mut AsyncApp, +) -> Result> { + let zeta = + cx.new(|cx| zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?; + + let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?; + + cx.subscribe(&buffer_store, { + let project = project.clone(); + let zeta = zeta.clone(); + move |_, event, cx| match event { + BufferStoreEvent::BufferAdded(buffer) => { + zeta.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx)); + } + _ => {} + } + })? + .detach(); + + anyhow::Ok(zeta) +} + +pub fn setup_sweep(project: &Entity, cx: &mut AsyncApp) -> Result> { + let sweep = cx.new(|cx| SweepAi::new(cx))?; + + let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?; + + cx.subscribe(&buffer_store, { + let project = project.clone(); + let sweep = sweep.clone(); + move |_, event, cx| match event { + BufferStoreEvent::BufferAdded(buffer) => { + sweep.update(cx, |sweep, cx| sweep.register_buffer(&buffer, &project, cx)); + } + _ => {} + } + })? + .detach(); + + anyhow::Ok(sweep) +} + +pub async fn perform_predict( example: NamedExample, project: Entity, zeta: Entity, + sweep: Option>, repetition_ix: Option, - prompt_format: PromptFormat, - use_expected_context: bool, - mut cache_mode: CacheMode, + options: PredictionOptions, cx: &mut AsyncApp, ) -> Result { + let mut cache_mode = options.cache; if repetition_ix.is_some() { if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip { panic!("Repetitions are not supported in Auto cache mode"); @@ -148,94 +138,8 @@ pub async fn zeta2_predict( let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?; let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone()))); - let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?; - - let debug_task = cx.background_spawn({ - let result = result.clone(); - async move { - let mut start_time = None; - let mut search_queries_generated_at = None; - let mut search_queries_executed_at = None; - while let Some(event) = debug_rx.next().await { - match event { - zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => { - start_time = Some(info.timestamp); - fs::write( - example_run_dir.join("search_prompt.md"), - &info.search_prompt, - )?; - } - zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => { - search_queries_generated_at = Some(info.timestamp); - fs::write( - example_run_dir.join("search_queries.json"), - serde_json::to_string_pretty(&info.search_queries).unwrap(), - )?; - } - zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => { - search_queries_executed_at = Some(info.timestamp); - } - zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {} - zeta2::ZetaDebugInfo::EditPredictionRequested(request) => { - let prediction_started_at = Instant::now(); - start_time.get_or_insert(prediction_started_at); - let prompt = request.local_prompt.unwrap_or_default(); - fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?; - - { - let mut result = result.lock().unwrap(); - result.prompt_len = prompt.chars().count(); - - for included_file in request.request.included_files { - let insertions = - vec![(request.request.cursor_point, CURSOR_MARKER)]; - result.excerpts.extend(included_file.excerpts.iter().map( - |excerpt| ActualExcerpt { - path: included_file.path.components().skip(1).collect(), - text: String::from(excerpt.text.as_ref()), - }, - )); - write_codeblock( - &included_file.path, - included_file.excerpts.iter(), - if included_file.path == request.request.excerpt_path { - &insertions - } else { - &[] - }, - included_file.max_row, - false, - &mut result.excerpts_text, - ); - } - } - let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?; - let response = zeta2::text_from_response(response).unwrap_or_default(); - let prediction_finished_at = Instant::now(); - fs::write(example_run_dir.join("prediction_response.md"), &response)?; - - let mut result = result.lock().unwrap(); - result.generated_len = response.chars().count(); - - if !use_expected_context { - result.planning_search_time = - Some(search_queries_generated_at.unwrap() - start_time.unwrap()); - result.running_search_time = Some( - search_queries_executed_at.unwrap() - - search_queries_generated_at.unwrap(), - ); - } - result.prediction_time = prediction_finished_at - prediction_started_at; - result.total_time = prediction_finished_at - start_time.unwrap(); - - break; - } - } - } - anyhow::Ok(()) - } - }); + let prompt_format = options.zeta2.prompt_format; zeta.update(cx, |zeta, _cx| { let mut options = zeta.options().clone(); @@ -243,55 +147,194 @@ pub async fn zeta2_predict( zeta.set_options(options); })?; - if use_expected_context { - let context_excerpts_tasks = example - .example - .expected_context - .iter() - .flat_map(|section| { - section.alternatives[0].excerpts.iter().map(|excerpt| { - resolve_context_entry(project.clone(), excerpt.clone(), cx.clone()) - }) - }) - .collect::>(); - let context_excerpts_vec = futures::future::try_join_all(context_excerpts_tasks).await?; - - let mut context_excerpts = HashMap::default(); - for (buffer, mut excerpts) in context_excerpts_vec { - context_excerpts - .entry(buffer) - .or_insert(Vec::new()) - .append(&mut excerpts); - } + let prediction = match options.provider { + crate::PredictionProvider::Zeta2 => { + let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?; + + let debug_task = cx.background_spawn({ + let result = result.clone(); + async move { + let mut start_time = None; + let mut search_queries_generated_at = None; + let mut search_queries_executed_at = None; + while let Some(event) = debug_rx.next().await { + match event { + zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => { + start_time = Some(info.timestamp); + fs::write( + example_run_dir.join("search_prompt.md"), + &info.search_prompt, + )?; + } + zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => { + search_queries_generated_at = Some(info.timestamp); + fs::write( + example_run_dir.join("search_queries.json"), + serde_json::to_string_pretty(&info.search_queries).unwrap(), + )?; + } + zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => { + search_queries_executed_at = Some(info.timestamp); + } + zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {} + zeta2::ZetaDebugInfo::EditPredictionRequested(request) => { + let prediction_started_at = Instant::now(); + start_time.get_or_insert(prediction_started_at); + let prompt = request.local_prompt.unwrap_or_default(); + fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?; + + { + let mut result = result.lock().unwrap(); + result.prompt_len = prompt.chars().count(); + + for included_file in request.request.included_files { + let insertions = + vec![(request.request.cursor_point, CURSOR_MARKER)]; + result.excerpts.extend(included_file.excerpts.iter().map( + |excerpt| { + ActualExcerpt { + path: included_file + .path + .components() + .skip(1) + .collect(), + text: String::from(excerpt.text.as_ref()), + } + }, + )); + write_codeblock( + &included_file.path, + included_file.excerpts.iter(), + if included_file.path == request.request.excerpt_path { + &insertions + } else { + &[] + }, + included_file.max_row, + false, + &mut result.excerpts_text, + ); + } + } + + let response = + request.response_rx.await?.0.map_err(|err| anyhow!(err))?; + let response = + zeta2::text_from_response(response).unwrap_or_default(); + let prediction_finished_at = Instant::now(); + fs::write( + example_run_dir.join("prediction_response.md"), + &response, + )?; + + let mut result = result.lock().unwrap(); + result.generated_len = response.chars().count(); + + if !options.use_expected_context { + result.planning_search_time = Some( + search_queries_generated_at.unwrap() - start_time.unwrap(), + ); + result.running_search_time = Some( + search_queries_executed_at.unwrap() + - search_queries_generated_at.unwrap(), + ); + } + result.prediction_time = + prediction_finished_at - prediction_started_at; + result.total_time = prediction_finished_at - start_time.unwrap(); + + break; + } + } + } + anyhow::Ok(()) + } + }); + + if options.use_expected_context { + let context_excerpts_tasks = example + .example + .expected_context + .iter() + .flat_map(|section| { + section.alternatives[0].excerpts.iter().map(|excerpt| { + resolve_context_entry(project.clone(), excerpt.clone(), cx.clone()) + }) + }) + .collect::>(); + let context_excerpts_vec = + futures::future::try_join_all(context_excerpts_tasks).await?; + + let mut context_excerpts = HashMap::default(); + for (buffer, mut excerpts) in context_excerpts_vec { + context_excerpts + .entry(buffer) + .or_insert(Vec::new()) + .append(&mut excerpts); + } - zeta.update(cx, |zeta, _cx| { - zeta.set_context(project.clone(), context_excerpts) - })?; - } else { - zeta.update(cx, |zeta, cx| { - zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx) - })? - .await?; - } + zeta.update(cx, |zeta, _cx| { + zeta.set_context(project.clone(), context_excerpts) + })?; + } else { + zeta.update(cx, |zeta, cx| { + zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx) + })? + .await?; + } - let prediction = zeta - .update(cx, |zeta, cx| { - zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx) - })? - .await?; + let prediction = zeta + .update(cx, |zeta, cx| { + zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx) + })? + .await? + .map(|prediction| (prediction.buffer, prediction.snapshot, prediction.edits)); + + debug_task.await?; + + prediction + } + crate::PredictionProvider::Sweep => sweep + .unwrap() + .update(cx, |sweep, cx| { + let mut recent_paths = Vec::new(); + for path in zeta + .read(cx) + .history_for_project(&project) + .rev() + .filter_map(|event| event.project_path(cx)) + { + if !recent_paths.contains(&path) { + recent_paths.push(path); + } + } - debug_task.await?; + sweep.request_completion( + &project, + recent_paths.into_iter(), + &cursor_buffer, + cursor_anchor, + cx, + ) + })? + .await? + .map( + |sweep_ai::EditPrediction { + edits, snapshot, .. + }| { (cursor_buffer.clone(), snapshot, edits) }, + ), + }; let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap(); + result.diff = prediction - .map(|prediction| { - let old_text = prediction.snapshot.text(); - let new_text = prediction - .buffer + .map(|(buffer, snapshot, edits)| { + let old_text = snapshot.text(); + let new_text = buffer .update(cx, |buffer, cx| { let branch = buffer.branch(cx); branch.update(cx, |branch, cx| { - branch.edit(prediction.edits.iter().cloned(), None, cx); + branch.edit(edits.iter().cloned(), None, cx); branch.text() }) })