main.rs

  1mod evaluate;
  2mod example;
  3mod headless;
  4mod metrics;
  5mod paths;
  6mod predict;
  7mod source_location;
  8mod syntax_retrieval_stats;
  9mod util;
 10
 11use crate::{
 12    evaluate::run_evaluate,
 13    example::{ExampleFormat, NamedExample},
 14    headless::ZetaCliAppState,
 15    predict::run_predict,
 16    source_location::SourceLocation,
 17    syntax_retrieval_stats::retrieval_stats,
 18    util::{open_buffer, open_buffer_with_language_server},
 19};
 20use ::util::paths::PathStyle;
 21use anyhow::{Result, anyhow};
 22use clap::{Args, Parser, Subcommand, ValueEnum};
 23use cloud_llm_client::predict_edits_v3;
 24use edit_prediction_context::{
 25    EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
 26};
 27use gpui::{Application, AsyncApp, Entity, prelude::*};
 28use language::{Bias, Buffer, BufferSnapshot, Point};
 29use project::{Project, Worktree};
 30use reqwest_client::ReqwestClient;
 31use serde_json::json;
 32use std::io::{self};
 33use std::time::Duration;
 34use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
 35use zeta::ContextMode;
 36
// Top-level command-line arguments of the `zeta` CLI.
//
// Plain `//` comments are used here on purpose: clap's derive turns `///` doc comments
// into `--help` text, so adding them would change the CLI's output.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // When no subcommand is given, `main` uses this flag to run
    // `util::shell_env::print_env`; without it, a missing subcommand is an error.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // The subcommand to run; optional only so that `--printenv` can be used on its own.
    #[command(subcommand)]
    command: Option<Command>,
}
 45
// Subcommands of the `zeta` CLI; each variant is dispatched in `main`.
//
// `//` (not `///`) comments: clap would render doc comments as subcommand help.
#[derive(Subcommand, Debug)]
enum Command {
    // Gather context for a cursor position with the provider chosen in `ContextArgs`.
    Context(ContextArgs),
    // Compute retrieval statistics over a worktree via `retrieval_stats`.
    ContextStats(ContextStatsArgs),
    // Run a prediction for one example via `run_predict`.
    Predict(PredictArguments),
    // Evaluate one or more examples via `run_evaluate`.
    Eval(EvaluateArguments),
    // Load the example at `path` and write it to stdout in `output_format`.
    ConvertExample {
        path: PathBuf,
        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
        output_format: ExampleFormat,
    },
    // Delete the `paths::TARGET_ZETA_DIR` directory.
    Clean,
}
 59
 60#[derive(Debug, Args)]
 61struct ContextStatsArgs {
 62    #[arg(long)]
 63    worktree: PathBuf,
 64    #[arg(long)]
 65    extension: Option<String>,
 66    #[arg(long)]
 67    limit: Option<usize>,
 68    #[arg(long)]
 69    skip: Option<usize>,
 70    #[clap(flatten)]
 71    zeta2_args: Zeta2Args,
 72}
 73
 74#[derive(Debug, Args)]
 75struct ContextArgs {
 76    #[arg(long)]
 77    provider: ContextProvider,
 78    #[arg(long)]
 79    worktree: PathBuf,
 80    #[arg(long)]
 81    cursor: SourceLocation,
 82    #[arg(long)]
 83    use_language_server: bool,
 84    #[arg(long)]
 85    edit_history: Option<FileOrStdin>,
 86    #[clap(flatten)]
 87    zeta2_args: Zeta2Args,
 88}
 89
// Context gatherer for the `Context` subcommand: `Zeta1` runs `zeta1_context`,
// `Syntax` runs `zeta2_syntax_context` (see `main`). `Syntax` is the `Default` value.
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum ContextProvider {
    Zeta1,
    #[default]
    Syntax,
}
 96
// Tuning options flattened into the `Context`, `ContextStats`, `Predict`, and `Eval`
// argument structs. `zeta2_args_to_options` maps them onto `zeta::ZetaOptions`; only
// `output_format` is read elsewhere (by `zeta2_syntax_context`).
//
// `//` comments are deliberate: clap would render `///` doc comments as `--help` text.
#[derive(Clone, Debug, Args)]
struct Zeta2Args {
    // -> `ZetaOptions::max_prompt_bytes`.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // -> `EditPredictionExcerptOptions::max_bytes`.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // -> `EditPredictionExcerptOptions::min_bytes`.
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // -> `EditPredictionExcerptOptions::target_before_cursor_over_total_bytes`.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // -> `ZetaOptions::max_diagnostic_bytes`.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // Converted into `predict_edits_v3::PromptFormat` for `ZetaOptions::prompt_format`.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // How `zeta2_syntax_context` renders its result; not part of `ZetaOptions`.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // -> `ZetaOptions::file_indexing_parallelism`.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
    // Inverted into `EditPredictionContextOptions::use_imports`.
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
    // -> `EditPredictionContextOptions::max_retrieved_declarations`; defaults to `u8::MAX`.
    #[arg(long, default_value_t = u8::MAX)]
    max_retrieved_definitions: u8,
}
120
121#[derive(Debug, Args)]
122pub struct PredictArguments {
123    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
124    format: PredictionsOutputFormat,
125    example_path: PathBuf,
126    #[clap(flatten)]
127    options: PredictionOptions,
128}
129
130#[derive(Clone, Debug, Args)]
131pub struct PredictionOptions {
132    #[clap(flatten)]
133    zeta2: Zeta2Args,
134    #[clap(long)]
135    provider: PredictionProvider,
136    #[clap(long, value_enum, default_value_t = CacheMode::default())]
137    cache: CacheMode,
138}
139
// Cache policy for LLM requests and search results, selected via `PredictionOptions::cache`.
//
// The `///` comments on the variants are rendered by clap as possible-value help, so they
// are user-facing text. `Auto` must be resolved to a concrete mode before the
// `use_cached_*` helpers are called; they assert that it has been.
#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
pub enum CacheMode {
    /// Use cached LLM requests and responses, except when multiple repetitions are requested
    #[default]
    Auto,
    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
    #[value(alias = "request")]
    Requests,
    /// Ignore existing cache entries for both LLM and search.
    Skip,
    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
    /// Useful for reproducing results and fixing bugs outside of search queries
    Force,
}
154
155impl CacheMode {
156    fn use_cached_llm_responses(&self) -> bool {
157        self.assert_not_auto();
158        matches!(self, CacheMode::Requests | CacheMode::Force)
159    }
160
161    fn use_cached_search_results(&self) -> bool {
162        self.assert_not_auto();
163        matches!(self, CacheMode::Force)
164    }
165
166    fn assert_not_auto(&self) {
167        assert_ne!(
168            *self,
169            CacheMode::Auto,
170            "Cache mode should not be auto at this point!"
171        );
172    }
173}
174
// Output format for the `Predict` subcommand (`PredictArguments::format`, default `Md`).
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    Json,
    Md,
    Diff,
}
181
182#[derive(Debug, Args)]
183pub struct EvaluateArguments {
184    example_paths: Vec<PathBuf>,
185    #[clap(flatten)]
186    options: PredictionOptions,
187    #[clap(short, long, default_value_t = 1, alias = "repeat")]
188    repetitions: u16,
189    #[arg(long)]
190    skip_prediction: bool,
191}
192
// Prediction backend selected by `PredictionOptions::provider`. `Zeta2` is the `Default`
// impl's value (the CLI flag itself has no default).
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
enum PredictionProvider {
    Zeta1,
    #[default]
    Zeta2,
    Sweep,
}
200
201fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta::ZetaOptions {
202    zeta::ZetaOptions {
203        context: ContextMode::Syntax(EditPredictionContextOptions {
204            max_retrieved_declarations: args.max_retrieved_definitions,
205            use_imports: !args.disable_imports_gathering,
206            excerpt: EditPredictionExcerptOptions {
207                max_bytes: args.max_excerpt_bytes,
208                min_bytes: args.min_excerpt_bytes,
209                target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
210            },
211            score: EditPredictionScoreOptions {
212                omit_excerpt_overlaps,
213            },
214        }),
215        max_diagnostic_bytes: args.max_diagnostic_bytes,
216        max_prompt_bytes: args.max_prompt_bytes,
217        prompt_format: args.prompt_format.into(),
218        file_indexing_parallelism: args.file_indexing_parallelism,
219        buffer_change_grouping_interval: Duration::ZERO,
220    }
221}
222
// Prompt layout selectable on the command line (`Zeta2Args::prompt_format`), converted to
// `predict_edits_v3::PromptFormat`. Variants map one-to-one by name, except
// `NumberedLines`, which maps to `NumLinesUniDiff`. `NumberedLines` is the default.
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum PromptFormat {
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
    #[default]
    NumberedLines,
    OldTextNewText,
    Minimal,
    MinimalQwen,
    SeedCoder1120,
}
235
236impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
237    fn into(self) -> predict_edits_v3::PromptFormat {
238        match self {
239            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
240            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
241            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
242            Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
243            Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
244            Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
245            Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
246            Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
247        }
248    }
249}
250
// How `zeta2_syntax_context` renders its result:
// - `Prompt`: the built prompt string (the default);
// - `Request`: the request, pretty-printed as JSON;
// - `Full`: a JSON object with `request`, `prompt`, and `section_labels`.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    #[default]
    Prompt,
    Request,
    Full,
}
258
259#[derive(Debug, Clone)]
260enum FileOrStdin {
261    File(PathBuf),
262    Stdin,
263}
264
265impl FileOrStdin {
266    async fn read_to_string(&self) -> Result<String, std::io::Error> {
267        match self {
268            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
269            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
270        }
271    }
272}
273
274impl FromStr for FileOrStdin {
275    type Err = <PathBuf as FromStr>::Err;
276
277    fn from_str(s: &str) -> Result<Self, Self::Err> {
278        match s {
279            "-" => Ok(Self::Stdin),
280            _ => Ok(Self::File(PathBuf::from_str(s)?)),
281        }
282    }
283}
284
/// Everything `load_context` produces for one `Context` invocation: the opened project,
/// worktree, and buffer, plus values derived from the cursor argument.
struct LoadedContext {
    /// Display string of the worktree's root name joined with the cursor's path.
    full_path_str: String,
    /// Snapshot of `buffer`, taken after the buffer was opened.
    snapshot: BufferSnapshot,
    /// The cursor clipped to `snapshot`. `load_context` only returns this struct when
    /// clipping left the requested point unchanged.
    clipped_cursor: Point,
    /// Worktree created for the `--worktree` directory.
    worktree: Entity<Worktree>,
    /// Local project that owns `worktree`.
    project: Entity<Project>,
    /// Buffer for the file at the cursor's path.
    buffer: Entity<Buffer>,
}
293
294async fn load_context(
295    args: &ContextArgs,
296    app_state: &Arc<ZetaCliAppState>,
297    cx: &mut AsyncApp,
298) -> Result<LoadedContext> {
299    let ContextArgs {
300        worktree: worktree_path,
301        cursor,
302        use_language_server,
303        ..
304    } = args;
305
306    let worktree_path = worktree_path.canonicalize()?;
307
308    let project = cx.update(|cx| {
309        Project::local(
310            app_state.client.clone(),
311            app_state.node_runtime.clone(),
312            app_state.user_store.clone(),
313            app_state.languages.clone(),
314            app_state.fs.clone(),
315            None,
316            cx,
317        )
318    })?;
319
320    let worktree = project
321        .update(cx, |project, cx| {
322            project.create_worktree(&worktree_path, true, cx)
323        })?
324        .await?;
325
326    let mut ready_languages = HashSet::default();
327    let (_lsp_open_handle, buffer) = if *use_language_server {
328        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
329            project.clone(),
330            worktree.clone(),
331            cursor.path.clone(),
332            &mut ready_languages,
333            cx,
334        )
335        .await?;
336        (Some(lsp_open_handle), buffer)
337    } else {
338        let buffer =
339            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
340        (None, buffer)
341    };
342
343    let full_path_str = worktree
344        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
345        .display(PathStyle::local())
346        .to_string();
347
348    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
349    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
350    if clipped_cursor != cursor.point {
351        let max_row = snapshot.max_point().row;
352        if cursor.point.row < max_row {
353            return Err(anyhow!(
354                "Cursor position {:?} is out of bounds (line length is {})",
355                cursor.point,
356                snapshot.line_len(cursor.point.row)
357            ));
358        } else {
359            return Err(anyhow!(
360                "Cursor position {:?} is out of bounds (max row is {})",
361                cursor.point,
362                max_row
363            ));
364        }
365    }
366
367    Ok(LoadedContext {
368        full_path_str,
369        snapshot,
370        clipped_cursor,
371        worktree,
372        project,
373        buffer,
374    })
375}
376
/// Runs the Zeta2 syntax-based context pipeline for `args` and renders the result
/// according to `args.zeta2_args.output_format`.
///
/// Steps:
/// 1. Load the project, worktree, and buffer via `load_context`.
/// 2. Wait for the worktree scan to complete.
/// 3. Create a `zeta::Zeta` configured from `args.zeta2_args`, with excerpt-overlap
///    omission enabled.
/// 4. Register the buffer and wait for initial indexing.
/// 5. Build the cloud request for the cursor, and the prompt from that request.
///
/// # Errors
///
/// Propagates errors from `load_context`, app/entity updates, indexing, request
/// construction, `cloud_zeta2_prompt::build_prompt`, and JSON serialization.
///
/// # Panics
///
/// Panics if the worktree is not local (`as_local().unwrap()`). `load_context` creates it
/// in a `Project::local` project, so this is presumably unreachable — confirm.
async fn zeta2_syntax_context(
    args: ContextArgs,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<String> {
    let LoadedContext {
        worktree,
        project,
        buffer,
        clipped_cursor,
        ..
    } = load_context(&args, app_state, cx).await?;

    // Wait for the worktree scan before starting zeta2, so that wait_for_initial_indexing
    // covers the whole worktree.
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;
    let output = cx
        .update(|cx| {
            let zeta = cx.new(|cx| {
                zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
            });
            let indexing_done_task = zeta.update(cx, |zeta, cx| {
                // `true` = omit excerpt overlaps (the `ContextStats` path passes `false`).
                zeta.set_options(zeta2_args_to_options(&args.zeta2_args, true));
                zeta.register_buffer(&buffer, &project, cx);
                zeta.wait_for_initial_indexing(&project, cx)
            });
            cx.spawn(async move |cx| {
                indexing_done_task.await?;
                let request = zeta
                    .update(cx, |zeta, cx| {
                        // Anchor the (already validated) cursor in the current buffer snapshot.
                        let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
                        zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
                    })?
                    .await?;

                let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;

                // `anyhow::Ok` pins the error type of the spawned future's result.
                match args.zeta2_args.output_format {
                    OutputFormat::Prompt => anyhow::Ok(prompt_string),
                    OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
                    OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
                        "request": request,
                        "prompt": prompt_string,
                        "section_labels": section_labels,
                    }))?),
                }
            })
        })?
        .await?;

    Ok(output)
}
433
434async fn zeta1_context(
435    args: ContextArgs,
436    app_state: &Arc<ZetaCliAppState>,
437    cx: &mut AsyncApp,
438) -> Result<zeta::zeta1::GatherContextOutput> {
439    let LoadedContext {
440        full_path_str,
441        snapshot,
442        clipped_cursor,
443        ..
444    } = load_context(&args, app_state, cx).await?;
445
446    let events = match args.edit_history {
447        Some(events) => events.read_to_string().await?,
448        None => String::new(),
449    };
450
451    let prompt_for_events = move || (events, 0);
452    cx.update(|cx| {
453        zeta::zeta1::gather_context(
454            full_path_str,
455            &snapshot,
456            clipped_cursor,
457            prompt_for_events,
458            cloud_llm_client::PredictEditsRequestTrigger::Cli,
459            cx,
460        )
461    })?
462    .await
463}
464
465fn main() {
466    zlog::init();
467    zlog::init_output_stderr();
468    let args = ZetaCliArgs::parse();
469    let http_client = Arc::new(ReqwestClient::new());
470    let app = Application::headless().with_http_client(http_client);
471
472    app.run(move |cx| {
473        let app_state = Arc::new(headless::init(cx));
474        cx.spawn(async move |cx| {
475            match args.command {
476                None => {
477                    if args.printenv {
478                        ::util::shell_env::print_env();
479                        return;
480                    } else {
481                        panic!("Expected a command");
482                    }
483                }
484                Some(Command::ContextStats(arguments)) => {
485                    let result = retrieval_stats(
486                        arguments.worktree,
487                        app_state,
488                        arguments.extension,
489                        arguments.limit,
490                        arguments.skip,
491                        zeta2_args_to_options(&arguments.zeta2_args, false),
492                        cx,
493                    )
494                    .await;
495                    println!("{}", result.unwrap());
496                }
497                Some(Command::Context(context_args)) => {
498                    let result = match context_args.provider {
499                        ContextProvider::Zeta1 => {
500                            let context =
501                                zeta1_context(context_args, &app_state, cx).await.unwrap();
502                            serde_json::to_string_pretty(&context.body).unwrap()
503                        }
504                        ContextProvider::Syntax => {
505                            zeta2_syntax_context(context_args, &app_state, cx)
506                                .await
507                                .unwrap()
508                        }
509                    };
510                    println!("{}", result);
511                }
512                Some(Command::Predict(arguments)) => {
513                    run_predict(arguments, &app_state, cx).await;
514                }
515                Some(Command::Eval(arguments)) => {
516                    run_evaluate(arguments, &app_state, cx).await;
517                }
518                Some(Command::ConvertExample {
519                    path,
520                    output_format,
521                }) => {
522                    let example = NamedExample::load(path).unwrap();
523                    example.write(output_format, io::stdout()).unwrap();
524                }
525                Some(Command::Clean) => {
526                    std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
527                }
528            };
529
530            let _ = cx.update(|cx| cx.quit());
531        })
532        .detach();
533    });
534}