main.rs

  1mod evaluate;
  2mod example;
  3mod headless;
  4mod paths;
  5mod predict;
  6mod source_location;
  7mod syntax_retrieval_stats;
  8mod util;
  9
 10use crate::{
 11    evaluate::run_evaluate,
 12    example::{ExampleFormat, NamedExample},
 13    headless::ZetaCliAppState,
 14    predict::run_predict,
 15    source_location::SourceLocation,
 16    syntax_retrieval_stats::retrieval_stats,
 17    util::{open_buffer, open_buffer_with_language_server},
 18};
 19use ::util::paths::PathStyle;
 20use anyhow::{Result, anyhow};
 21use clap::{Args, Parser, Subcommand, ValueEnum};
 22use cloud_llm_client::predict_edits_v3;
 23use edit_prediction_context::{
 24    EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
 25};
 26use gpui::{Application, AsyncApp, Entity, prelude::*};
 27use language::{Bias, Buffer, BufferSnapshot, Point};
 28use project::{Project, Worktree};
 29use reqwest_client::ReqwestClient;
 30use serde_json::json;
 31use std::io::{self};
 32use std::time::Duration;
 33use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
 34use zeta2::ContextMode;
 35
// Top-level command-line arguments of the `zeta` CLI.
//
// Plain `//` comments are deliberate here: clap's derive macros surface `///` doc comments
// as help text, so adding them would change the program's output.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // Only consulted by `main` when no subcommand is given; in that case it makes `main`
    // call `::util::shell_env::print_env()`.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // The subcommand to run. `main` panics with "Expected a command" when this is `None`
    // and `printenv` is not set.
    #[command(subcommand)]
    command: Option<Command>,
}
 44
// Subcommands of the CLI; `main` dispatches on these.
//
// `//` rather than `///`: clap would surface doc comments as help text.
#[derive(Subcommand, Debug)]
enum Command {
    // Runs `zeta1_context` or `zeta2_syntax_context` (chosen by `ContextArgs::provider`)
    // and prints the result to stdout.
    Context(ContextArgs),
    // Runs `retrieval_stats` and prints its result to stdout.
    ContextStats(ContextStatsArgs),
    // Runs `run_predict`.
    Predict(PredictArguments),
    // Runs `run_evaluate`.
    Eval(EvaluateArguments),
    // Loads the example at `path` with `NamedExample::load` and writes it to stdout in
    // `output_format`.
    ConvertExample {
        path: PathBuf,
        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
        output_format: ExampleFormat,
    },
    // Removes the directory named by `crate::paths::TARGET_ZETA_DIR`.
    Clean,
}
 58
// Arguments of the `ContextStats` subcommand.
//
// `main` forwards `worktree`, `extension`, `limit` and `skip` unchanged to
// `retrieval_stats`; their precise meaning is defined there, not in this file.
// `//` rather than `///`: clap would surface doc comments as help text.
#[derive(Debug, Args)]
struct ContextStatsArgs {
    #[arg(long)]
    worktree: PathBuf,
    #[arg(long)]
    extension: Option<String>,
    #[arg(long)]
    limit: Option<usize>,
    #[arg(long)]
    skip: Option<usize>,
    // Converted with `zeta2_args_to_options(.., false)` before being passed on.
    #[clap(flatten)]
    zeta2_args: Zeta2Args,
}
 72
// Arguments of the `Context` subcommand.
//
// `//` rather than `///`: clap would surface doc comments as help text.
#[derive(Debug, Args)]
struct ContextArgs {
    // Selects which context gatherer `main` runs.
    #[arg(long)]
    provider: ContextProvider,
    // Root directory; `load_context` canonicalizes it and opens it as a worktree.
    #[arg(long)]
    worktree: PathBuf,
    // File path (relative to the worktree) and point of the cursor; `load_context`
    // rejects points that lie outside the buffer.
    #[arg(long)]
    cursor: SourceLocation,
    // When set, the buffer is opened through `open_buffer_with_language_server`
    // instead of `open_buffer`.
    #[arg(long)]
    use_language_server: bool,
    // Source of the events text; only read by `zeta1_context`. `-` selects stdin
    // (see `FileOrStdin::from_str`).
    #[arg(long)]
    edit_history: Option<FileOrStdin>,
    // Only read by `zeta2_syntax_context`.
    #[clap(flatten)]
    zeta2_args: Zeta2Args,
}
 88
// Which context implementation the `Context` subcommand uses.
//
// `//` rather than `///`: clap would surface doc comments as help text.
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum ContextProvider {
    // `main` calls `zeta1_context` and prints the JSON of the result's `body`.
    Zeta1,
    // `main` calls `zeta2_syntax_context`. This is the `Default::default()` value.
    #[default]
    Syntax,
}
 95
// Tunables shared by the zeta2-based subcommands. `zeta2_args_to_options` copies every
// field except `output_format` into a `zeta2::ZetaOptions`.
//
// `//` rather than `///`: clap would surface doc comments as help text.
#[derive(Clone, Debug, Args)]
struct Zeta2Args {
    // Becomes `ZetaOptions::max_prompt_bytes`.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // Becomes the excerpt options' `max_bytes`.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // Becomes the excerpt options' `min_bytes`.
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // Becomes the excerpt options' field of the same name.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // Becomes `ZetaOptions::max_diagnostic_bytes`.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // Converted into `predict_edits_v3::PromptFormat`.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // Not part of `ZetaOptions`; `zeta2_syntax_context` uses it to pick what to return.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // Becomes `ZetaOptions::file_indexing_parallelism`.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
    // Negated into the context options' `use_imports`.
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
    // Becomes the context options' `max_retrieved_declarations`.
    #[arg(long, default_value_t = u8::MAX)]
    max_retrieved_definitions: u8,
}
119
// Arguments of the `Predict` subcommand; `main` hands the whole struct to `run_predict`.
//
// `//` rather than `///`: clap would surface doc comments as help text.
#[derive(Debug, Args)]
pub struct PredictArguments {
    // Output format selector; defaults to `Md`.
    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
    format: PredictionsOutputFormat,
    // NOTE(review): presumably the example file to predict on — consumed by
    // `run_predict`, which is not visible here.
    example_path: PathBuf,
    // Options shared with the `Eval` subcommand.
    #[clap(flatten)]
    options: PredictionOptions,
}
128
// Options flattened into both `PredictArguments` and `EvaluateArguments`. The fields are
// consumed by code outside this file (`run_predict` / `run_evaluate`).
//
// `//` rather than `///`: clap would surface doc comments as help text.
#[derive(Clone, Debug, Args)]
pub struct PredictionOptions {
    #[arg(long)]
    use_expected_context: bool,
    #[clap(flatten)]
    zeta2: Zeta2Args,
    #[clap(long)]
    provider: PredictionProvider,
    // Defaults to `CacheMode::Auto`.
    #[clap(long, value_enum, default_value_t = CacheMode::default())]
    cache: CacheMode,
}
140
// Caching policy selected with `--cache`.
//
// The `///` comments on the variants are left untouched because clap exposes them as help
// text. `Auto` has to be replaced by a concrete mode before `use_cached_llm_responses` or
// `use_cached_search_results` is called; both assert that the mode is not `Auto`.
#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
pub enum CacheMode {
    /// Use cached LLM requests and responses, except when multiple repetitions are requested
    #[default]
    Auto,
    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
    #[value(alias = "request")]
    Requests,
    /// Ignore existing cache entries for both LLM and search.
    Skip,
    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
    /// Useful for reproducing results and fixing bugs outside of search queries
    Force,
}
155
156impl CacheMode {
157    fn use_cached_llm_responses(&self) -> bool {
158        self.assert_not_auto();
159        matches!(self, CacheMode::Requests | CacheMode::Force)
160    }
161
162    fn use_cached_search_results(&self) -> bool {
163        self.assert_not_auto();
164        matches!(self, CacheMode::Force)
165    }
166
167    fn assert_not_auto(&self) {
168        assert_ne!(
169            *self,
170            CacheMode::Auto,
171            "Cache mode should not be auto at this point!"
172        );
173    }
174}
175
// Values accepted by `PredictArguments::format` (which defaults to `Md`). How each format
// is rendered is decided by code outside this file.
//
// `//` rather than `///`: clap would surface doc comments as help text.
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    Json,
    Md,
    Diff,
}
182
// Arguments of the `Eval` subcommand; `main` hands the whole struct to `run_evaluate`.
//
// `//` rather than `///`: clap would surface doc comments as help text.
#[derive(Debug, Args)]
pub struct EvaluateArguments {
    // NOTE(review): presumably the example files to evaluate — consumed by
    // `run_evaluate`, which is not visible here.
    example_paths: Vec<PathBuf>,
    // Options shared with the `Predict` subcommand.
    #[clap(flatten)]
    options: PredictionOptions,
    // Defaults to 1; also accepted under the alias `repeat`.
    #[clap(short, long, default_value_t = 1, alias = "repeat")]
    repetitions: u16,
    #[arg(long)]
    skip_prediction: bool,
}
193
// Values accepted by `PredictionOptions::provider`. `Zeta2` is the `Default::default()`
// value; the providers themselves are implemented outside this file.
//
// `//` rather than `///`: clap would surface doc comments as help text.
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
enum PredictionProvider {
    #[default]
    Zeta2,
    Sweep,
}
200
201fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions {
202    zeta2::ZetaOptions {
203        context: ContextMode::Syntax(EditPredictionContextOptions {
204            max_retrieved_declarations: args.max_retrieved_definitions,
205            use_imports: !args.disable_imports_gathering,
206            excerpt: EditPredictionExcerptOptions {
207                max_bytes: args.max_excerpt_bytes,
208                min_bytes: args.min_excerpt_bytes,
209                target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
210            },
211            score: EditPredictionScoreOptions {
212                omit_excerpt_overlaps,
213            },
214        }),
215        max_diagnostic_bytes: args.max_diagnostic_bytes,
216        max_prompt_bytes: args.max_prompt_bytes,
217        prompt_format: args.prompt_format.into(),
218        file_indexing_parallelism: args.file_indexing_parallelism,
219        buffer_change_grouping_interval: Duration::ZERO,
220    }
221}
222
// CLI-facing mirror of `predict_edits_v3::PromptFormat`; a conversion into that type is
// implemented in this file. All variants map to the variant of the same name except
// `NumberedLines`, which maps to `NumLinesUniDiff`.
//
// `//` rather than `///`: clap would surface doc comments as help text.
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum PromptFormat {
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
    // The `Default::default()` value, used as the default of `--prompt-format`.
    #[default]
    NumberedLines,
    OldTextNewText,
    Minimal,
    MinimalQwen,
}
234
235impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
236    fn into(self) -> predict_edits_v3::PromptFormat {
237        match self {
238            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
239            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
240            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
241            Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
242            Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
243            Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
244            Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
245        }
246    }
247}
248
// What `zeta2_syntax_context` returns.
//
// `//` rather than `///`: clap would surface doc comments as help text.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    // The prompt string built by `cloud_zeta2_prompt::build_prompt`. This is the
    // `Default::default()` value.
    #[default]
    Prompt,
    // The request, pretty-printed as JSON.
    Request,
    // A pretty-printed JSON object with the keys `request`, `prompt` and
    // `section_labels`.
    Full,
}
256
/// A text input that comes either from a file or from standard input.
///
/// Parsed from a command-line string by the `FromStr` impl in this file: `-` yields
/// `Stdin`, any other string yields `File`.
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read the text from the file at this path.
    File(PathBuf),
    /// Read the text from the process's standard input.
    Stdin,
}
262
263impl FileOrStdin {
264    async fn read_to_string(&self) -> Result<String, std::io::Error> {
265        match self {
266            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
267            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
268        }
269    }
270}
271
272impl FromStr for FileOrStdin {
273    type Err = <PathBuf as FromStr>::Err;
274
275    fn from_str(s: &str) -> Result<Self, Self::Err> {
276        match s {
277            "-" => Ok(Self::Stdin),
278            _ => Ok(Self::File(PathBuf::from_str(s)?)),
279        }
280    }
281}
282
/// Everything `load_context` produces for the context-gathering functions.
struct LoadedContext {
    /// The worktree's root name joined with the cursor's path, rendered with
    /// `PathStyle::local()`.
    full_path_str: String,
    /// Snapshot of the opened buffer, taken right after opening it.
    snapshot: BufferSnapshot,
    /// The requested cursor point after `clip_point`; `load_context` only returns
    /// successfully when this equals the requested point.
    clipped_cursor: Point,
    /// The worktree created for the requested directory.
    worktree: Entity<Worktree>,
    /// The local project that owns the worktree.
    project: Entity<Project>,
    /// The buffer opened at the cursor's path.
    buffer: Entity<Buffer>,
}
291
292async fn load_context(
293    args: &ContextArgs,
294    app_state: &Arc<ZetaCliAppState>,
295    cx: &mut AsyncApp,
296) -> Result<LoadedContext> {
297    let ContextArgs {
298        worktree: worktree_path,
299        cursor,
300        use_language_server,
301        ..
302    } = args;
303
304    let worktree_path = worktree_path.canonicalize()?;
305
306    let project = cx.update(|cx| {
307        Project::local(
308            app_state.client.clone(),
309            app_state.node_runtime.clone(),
310            app_state.user_store.clone(),
311            app_state.languages.clone(),
312            app_state.fs.clone(),
313            None,
314            cx,
315        )
316    })?;
317
318    let worktree = project
319        .update(cx, |project, cx| {
320            project.create_worktree(&worktree_path, true, cx)
321        })?
322        .await?;
323
324    let mut ready_languages = HashSet::default();
325    let (_lsp_open_handle, buffer) = if *use_language_server {
326        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
327            project.clone(),
328            worktree.clone(),
329            cursor.path.clone(),
330            &mut ready_languages,
331            cx,
332        )
333        .await?;
334        (Some(lsp_open_handle), buffer)
335    } else {
336        let buffer =
337            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
338        (None, buffer)
339    };
340
341    let full_path_str = worktree
342        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
343        .display(PathStyle::local())
344        .to_string();
345
346    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
347    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
348    if clipped_cursor != cursor.point {
349        let max_row = snapshot.max_point().row;
350        if cursor.point.row < max_row {
351            return Err(anyhow!(
352                "Cursor position {:?} is out of bounds (line length is {})",
353                cursor.point,
354                snapshot.line_len(cursor.point.row)
355            ));
356        } else {
357            return Err(anyhow!(
358                "Cursor position {:?} is out of bounds (max row is {})",
359                cursor.point,
360                max_row
361            ));
362        }
363    }
364
365    Ok(LoadedContext {
366        full_path_str,
367        snapshot,
368        clipped_cursor,
369        worktree,
370        project,
371        buffer,
372    })
373}
374
/// Builds the zeta2 (syntax-based) cloud request for the cursor described by `args` and
/// renders it according to `args.zeta2_args.output_format`.
///
/// Returns the prompt string (`Prompt`), the pretty-printed request JSON (`Request`), or
/// a pretty-printed JSON object with `request`, `prompt` and `section_labels` (`Full`).
///
/// Panics if the worktree returned by `load_context` is not local
/// (`as_local().unwrap()`).
async fn zeta2_syntax_context(
    args: ContextArgs,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<String> {
    let LoadedContext {
        worktree,
        project,
        buffer,
        clipped_cursor,
        ..
    } = load_context(&args, app_state, cx).await?;

    // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
    // the whole worktree.
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;
    let output = cx
        .update(|cx| {
            let zeta = cx.new(|cx| {
                zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
            });
            // Configure zeta from the CLI options and register the buffer before asking
            // for the indexing task.
            let indexing_done_task = zeta.update(cx, |zeta, cx| {
                // `true` is the `omit_excerpt_overlaps` score option.
                zeta.set_options(zeta2_args_to_options(&args.zeta2_args, true));
                zeta.register_buffer(&buffer, &project, cx);
                zeta.wait_for_initial_indexing(&project, cx)
            });
            cx.spawn(async move |cx| {
                indexing_done_task.await?;
                let request = zeta
                    .update(cx, |zeta, cx| {
                        // Anchor the already-validated cursor in the buffer's current
                        // snapshot.
                        let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
                        zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
                    })?
                    .await?;

                // The prompt is built for every output format, so a `build_prompt` error
                // also fails the `Request` format.
                let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;

                match args.zeta2_args.output_format {
                    OutputFormat::Prompt => anyhow::Ok(prompt_string),
                    OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
                    OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
                        "request": request,
                        "prompt": prompt_string,
                        "section_labels": section_labels,
                    }))?),
                }
            })
        })?
        .await?;

    Ok(output)
}
431
432async fn zeta1_context(
433    args: ContextArgs,
434    app_state: &Arc<ZetaCliAppState>,
435    cx: &mut AsyncApp,
436) -> Result<zeta::GatherContextOutput> {
437    let LoadedContext {
438        full_path_str,
439        snapshot,
440        clipped_cursor,
441        ..
442    } = load_context(&args, app_state, cx).await?;
443
444    let events = match args.edit_history {
445        Some(events) => events.read_to_string().await?,
446        None => String::new(),
447    };
448
449    let prompt_for_events = move || (events, 0);
450    cx.update(|cx| {
451        zeta::gather_context(
452            full_path_str,
453            &snapshot,
454            clipped_cursor,
455            prompt_for_events,
456            cx,
457        )
458    })?
459    .await
460}
461
462fn main() {
463    zlog::init();
464    zlog::init_output_stderr();
465    let args = ZetaCliArgs::parse();
466    let http_client = Arc::new(ReqwestClient::new());
467    let app = Application::headless().with_http_client(http_client);
468
469    app.run(move |cx| {
470        let app_state = Arc::new(headless::init(cx));
471        cx.spawn(async move |cx| {
472            match args.command {
473                None => {
474                    if args.printenv {
475                        ::util::shell_env::print_env();
476                        return;
477                    } else {
478                        panic!("Expected a command");
479                    }
480                }
481                Some(Command::ContextStats(arguments)) => {
482                    let result = retrieval_stats(
483                        arguments.worktree,
484                        app_state,
485                        arguments.extension,
486                        arguments.limit,
487                        arguments.skip,
488                        zeta2_args_to_options(&arguments.zeta2_args, false),
489                        cx,
490                    )
491                    .await;
492                    println!("{}", result.unwrap());
493                }
494                Some(Command::Context(context_args)) => {
495                    let result = match context_args.provider {
496                        ContextProvider::Zeta1 => {
497                            let context =
498                                zeta1_context(context_args, &app_state, cx).await.unwrap();
499                            serde_json::to_string_pretty(&context.body).unwrap()
500                        }
501                        ContextProvider::Syntax => {
502                            zeta2_syntax_context(context_args, &app_state, cx)
503                                .await
504                                .unwrap()
505                        }
506                    };
507                    println!("{}", result);
508                }
509                Some(Command::Predict(arguments)) => {
510                    run_predict(arguments, &app_state, cx).await;
511                }
512                Some(Command::Eval(arguments)) => {
513                    run_evaluate(arguments, &app_state, cx).await;
514                }
515                Some(Command::ConvertExample {
516                    path,
517                    output_format,
518                }) => {
519                    let example = NamedExample::load(path).unwrap();
520                    example.write(output_format, io::stdout()).unwrap();
521                }
522                Some(Command::Clean) => {
523                    std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
524                }
525            };
526
527            let _ = cx.update(|cx| cx.quit());
528        })
529        .detach();
530    });
531}