main.rs

  1mod evaluate;
  2mod example;
  3mod headless;
  4mod paths;
  5mod predict;
  6mod source_location;
  7mod syntax_retrieval_stats;
  8mod util;
  9
 10use crate::{
 11    evaluate::run_evaluate,
 12    example::{ExampleFormat, NamedExample},
 13    headless::ZetaCliAppState,
 14    predict::run_predict,
 15    source_location::SourceLocation,
 16    syntax_retrieval_stats::retrieval_stats,
 17    util::{open_buffer, open_buffer_with_language_server},
 18};
 19use ::util::paths::PathStyle;
 20use anyhow::{Result, anyhow};
 21use clap::{Args, Parser, Subcommand, ValueEnum};
 22use cloud_llm_client::predict_edits_v3;
 23use edit_prediction_context::{
 24    EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
 25};
 26use gpui::{Application, AsyncApp, Entity, prelude::*};
 27use language::{Bias, Buffer, BufferSnapshot, Point};
 28use project::{Project, Worktree};
 29use reqwest_client::ReqwestClient;
 30use serde_json::json;
 31use std::io::{self};
 32use std::time::Duration;
 33use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
 34use zeta::ContextMode;
 35
// Top-level command-line arguments for the `zeta` CLI.
//
// Plain `//` comments (not `///`) are used on clap-derived items so that the
// generated `--help` text is left unchanged.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // Only consulted when no subcommand is given: `main` then prints the shell
    // environment instead of failing.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // The subcommand to run. It is optional, which lets `--printenv` be used on its own.
    #[command(subcommand)]
    command: Option<Command>,
}
 44
// Subcommands of the `zeta` CLI; each one is dispatched by the `match` in `main`.
#[derive(Subcommand, Debug)]
enum Command {
    // Print the prompt context for one cursor position, using the chosen provider.
    Context(ContextArgs),
    // Print the result of `retrieval_stats` for a worktree.
    ContextStats(ContextStatsArgs),
    // Handled by `run_predict`.
    Predict(PredictArguments),
    // Handled by `run_evaluate`.
    Eval(EvaluateArguments),
    // Load the example at `path` and write it to stdout in `output_format`.
    ConvertExample {
        path: PathBuf,
        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
        output_format: ExampleFormat,
    },
    // Recursively delete the directory at `paths::TARGET_ZETA_DIR`.
    Clean,
}
 58
// Arguments for the `context-stats` subcommand; `main` forwards every field to
// `retrieval_stats`.
#[derive(Debug, Args)]
struct ContextStatsArgs {
    // Root directory of the worktree to analyze.
    #[arg(long)]
    worktree: PathBuf,
    // Optional file-extension filter. Its exact semantics are defined by
    // `retrieval_stats` (presumably it limits which files are sampled — TODO confirm).
    #[arg(long)]
    extension: Option<String>,
    // Optional limit / skip counts, passed through unchanged to `retrieval_stats`.
    #[arg(long)]
    limit: Option<usize>,
    #[arg(long)]
    skip: Option<usize>,
    // Converted with `zeta2_args_to_options(.., false)`, i.e. `omit_excerpt_overlaps` is false.
    #[clap(flatten)]
    zeta2_args: Zeta2Args,
}
 72
// Arguments for the `context` subcommand.
#[derive(Debug, Args)]
struct ContextArgs {
    // Which context builder runs: `zeta1_context` or `zeta2_syntax_context`.
    #[arg(long)]
    provider: ContextProvider,
    // Worktree root; `load_context` canonicalizes it before opening the project.
    #[arg(long)]
    worktree: PathBuf,
    // File path (relative to the worktree) and position of the cursor.
    #[arg(long)]
    cursor: SourceLocation,
    // When set, the buffer is opened through `open_buffer_with_language_server`
    // instead of `open_buffer`.
    #[arg(long)]
    use_language_server: bool,
    // Edit history text, read from a file or from stdin when given as `-`.
    // In this file it is only read by `zeta1_context`.
    #[arg(long)]
    edit_history: Option<FileOrStdin>,
    #[clap(flatten)]
    zeta2_args: Zeta2Args,
}
 88
// Context builder selected by `context --provider`.
//
// NOTE: `ContextArgs::provider` sets no `default_value_t`, so `--provider` is still
// required on the command line even though this type has a `Default`.
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum ContextProvider {
    // Handled by `zeta1_context`.
    Zeta1,
    // Handled by `zeta2_syntax_context`.
    #[default]
    Syntax,
}
 95
// Zeta2 tuning flags, flattened into several subcommands' argument structs.
// `zeta2_args_to_options` turns most of them into a `zeta::ZetaOptions`.
#[derive(Clone, Debug, Args)]
struct Zeta2Args {
    // Becomes `ZetaOptions::max_prompt_bytes`.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // Excerpt size bounds, forwarded as `EditPredictionExcerptOptions::{max_bytes, min_bytes}`.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // Forwarded as `EditPredictionExcerptOptions::target_before_cursor_over_total_bytes`.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // Becomes `ZetaOptions::max_diagnostic_bytes`.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // Converted into `predict_edits_v3::PromptFormat` for `ZetaOptions::prompt_format`.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // Not part of `ZetaOptions`. In this file it is read only by `zeta2_syntax_context`
    // to choose what it prints.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // Becomes `ZetaOptions::file_indexing_parallelism`.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
    // Negated into `EditPredictionContextOptions::use_imports`.
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
    // Becomes `EditPredictionContextOptions::max_retrieved_declarations`. The default
    // is the largest value a `u8` can hold.
    #[arg(long, default_value_t = u8::MAX)]
    max_retrieved_definitions: u8,
}
119
// Arguments for the `predict` subcommand, handed to `run_predict`.
#[derive(Debug, Args)]
pub struct PredictArguments {
    // Output format (`-f` / `--format`), defaulting to Markdown.
    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
    format: PredictionsOutputFormat,
    // Positional: the example file to run a prediction for.
    example_path: PathBuf,
    #[clap(flatten)]
    options: PredictionOptions,
}
128
// Prediction options shared by `predict` (`PredictArguments`) and `eval`
// (`EvaluateArguments`).
#[derive(Clone, Debug, Args)]
pub struct PredictionOptions {
    #[clap(flatten)]
    zeta2: Zeta2Args,
    // Required: no `default_value_t` is set, even though `PredictionProvider`
    // implements `Default`.
    #[clap(long)]
    provider: PredictionProvider,
    // Cache reuse policy; defaults to `CacheMode::Auto`.
    #[clap(long, value_enum, default_value_t = CacheMode::default())]
    cache: CacheMode,
}
138
// How cached LLM requests/responses and search results are reused.
//
// The `///` lines on the variants double as clap's `--cache` help text.
// `Auto` must be resolved to a concrete mode before `use_cached_llm_responses` or
// `use_cached_search_results` is called; both assert that it has been.
#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
pub enum CacheMode {
    /// Use cached LLM requests and responses, except when multiple repetitions are requested
    #[default]
    Auto,
    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
    #[value(alias = "request")]
    Requests,
    /// Ignore existing cache entries for both LLM and search.
    Skip,
    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
    /// Useful for reproducing results and fixing bugs outside of search queries
    Force,
}
153
154impl CacheMode {
155    fn use_cached_llm_responses(&self) -> bool {
156        self.assert_not_auto();
157        matches!(self, CacheMode::Requests | CacheMode::Force)
158    }
159
160    fn use_cached_search_results(&self) -> bool {
161        self.assert_not_auto();
162        matches!(self, CacheMode::Force)
163    }
164
165    fn assert_not_auto(&self) {
166        assert_ne!(
167            *self,
168            CacheMode::Auto,
169            "Cache mode should not be auto at this point!"
170        );
171    }
172}
173
// Output format for `predict` (`--format`); `PredictArguments` defaults it to `Md`.
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    Json,
    Md,
    Diff,
}
180
// Arguments for the `eval` subcommand, handed to `run_evaluate`.
#[derive(Debug, Args)]
pub struct EvaluateArguments {
    // Positional example file paths, collected into a `Vec`.
    example_paths: Vec<PathBuf>,
    #[clap(flatten)]
    options: PredictionOptions,
    // Repetition count (`-r`, `--repetitions`, alias `--repeat`); defaults to 1.
    #[clap(short, long, default_value_t = 1, alias = "repeat")]
    repetitions: u16,
    // NOTE(review): presumably evaluates without running new predictions — confirm in
    // `evaluate::run_evaluate`.
    #[arg(long)]
    skip_prediction: bool,
}
191
// Prediction backend selected by `--provider` for `predict` and `eval`.
// `Zeta2` is the `Default`, but `PredictionOptions` does not use it as a clap default.
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
enum PredictionProvider {
    Zeta1,
    #[default]
    Zeta2,
    Sweep,
}
199
200fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta::ZetaOptions {
201    zeta::ZetaOptions {
202        context: ContextMode::Syntax(EditPredictionContextOptions {
203            max_retrieved_declarations: args.max_retrieved_definitions,
204            use_imports: !args.disable_imports_gathering,
205            excerpt: EditPredictionExcerptOptions {
206                max_bytes: args.max_excerpt_bytes,
207                min_bytes: args.min_excerpt_bytes,
208                target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
209            },
210            score: EditPredictionScoreOptions {
211                omit_excerpt_overlaps,
212            },
213        }),
214        max_diagnostic_bytes: args.max_diagnostic_bytes,
215        max_prompt_bytes: args.max_prompt_bytes,
216        prompt_format: args.prompt_format.into(),
217        file_indexing_parallelism: args.file_indexing_parallelism,
218        buffer_change_grouping_interval: Duration::ZERO,
219    }
220}
221
// Prompt layout selectable with `--prompt-format`.
//
// Converted into `predict_edits_v3::PromptFormat` variant for variant, except that
// `NumberedLines` maps to `NumLinesUniDiff`.
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum PromptFormat {
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
    // The default choice.
    #[default]
    NumberedLines,
    OldTextNewText,
    Minimal,
    MinimalQwen,
    SeedCoder1120,
}
234
235impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
236    fn into(self) -> predict_edits_v3::PromptFormat {
237        match self {
238            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
239            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
240            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
241            Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
242            Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
243            Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
244            Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
245            Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
246        }
247    }
248}
249
// What `zeta2_syntax_context` prints (`--output-format`).
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    // Only the rendered prompt string.
    #[default]
    Prompt,
    // The cloud request, as pretty-printed JSON.
    Request,
    // A pretty-printed JSON object with `request`, `prompt`, and `section_labels`.
    Full,
}
257
/// Source of a text argument such as `--edit-history`: the process's standard input
/// (spelled `-` on the command line) or a file path.
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read from standard input.
    Stdin,
    /// Read from the file at this path.
    File(PathBuf),
}
263
264impl FileOrStdin {
265    async fn read_to_string(&self) -> Result<String, std::io::Error> {
266        match self {
267            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
268            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
269        }
270    }
271}
272
273impl FromStr for FileOrStdin {
274    type Err = <PathBuf as FromStr>::Err;
275
276    fn from_str(s: &str) -> Result<Self, Self::Err> {
277        match s {
278            "-" => Ok(Self::Stdin),
279            _ => Ok(Self::File(PathBuf::from_str(s)?)),
280        }
281    }
282}
283
// Everything `load_context` produces for a single cursor position.
//
// NOTE: field declaration order is also drop order; keep that in mind before
// reordering these handles.
struct LoadedContext {
    // Worktree root name joined with the cursor's relative path, rendered with the
    // local path style.
    full_path_str: String,
    // Snapshot of `buffer` taken while loading the context.
    snapshot: BufferSnapshot,
    // Cursor after `clip_point`. `load_context` only returns successfully when this
    // equals the requested position.
    clipped_cursor: Point,
    worktree: Entity<Worktree>,
    project: Entity<Project>,
    // Buffer holding the file named by the cursor.
    buffer: Entity<Buffer>,
}
292
293async fn load_context(
294    args: &ContextArgs,
295    app_state: &Arc<ZetaCliAppState>,
296    cx: &mut AsyncApp,
297) -> Result<LoadedContext> {
298    let ContextArgs {
299        worktree: worktree_path,
300        cursor,
301        use_language_server,
302        ..
303    } = args;
304
305    let worktree_path = worktree_path.canonicalize()?;
306
307    let project = cx.update(|cx| {
308        Project::local(
309            app_state.client.clone(),
310            app_state.node_runtime.clone(),
311            app_state.user_store.clone(),
312            app_state.languages.clone(),
313            app_state.fs.clone(),
314            None,
315            cx,
316        )
317    })?;
318
319    let worktree = project
320        .update(cx, |project, cx| {
321            project.create_worktree(&worktree_path, true, cx)
322        })?
323        .await?;
324
325    let mut ready_languages = HashSet::default();
326    let (_lsp_open_handle, buffer) = if *use_language_server {
327        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
328            project.clone(),
329            worktree.clone(),
330            cursor.path.clone(),
331            &mut ready_languages,
332            cx,
333        )
334        .await?;
335        (Some(lsp_open_handle), buffer)
336    } else {
337        let buffer =
338            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
339        (None, buffer)
340    };
341
342    let full_path_str = worktree
343        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
344        .display(PathStyle::local())
345        .to_string();
346
347    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
348    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
349    if clipped_cursor != cursor.point {
350        let max_row = snapshot.max_point().row;
351        if cursor.point.row < max_row {
352            return Err(anyhow!(
353                "Cursor position {:?} is out of bounds (line length is {})",
354                cursor.point,
355                snapshot.line_len(cursor.point.row)
356            ));
357        } else {
358            return Err(anyhow!(
359                "Cursor position {:?} is out of bounds (max row is {})",
360                cursor.point,
361                max_row
362            ));
363        }
364    }
365
366    Ok(LoadedContext {
367        full_path_str,
368        snapshot,
369        clipped_cursor,
370        worktree,
371        project,
372        buffer,
373    })
374}
375
/// Builds the zeta2 (syntax-retrieval) context for the cursor in `args` and renders it
/// according to `args.zeta2_args.output_format`.
///
/// The function first waits for the worktree scan and for zeta's initial indexing, then
/// builds the cloud request. Depending on the output format it returns the prompt
/// string, the pretty-printed request JSON, or a JSON object holding the request,
/// prompt, and section labels.
///
/// # Errors
/// Propagates failures from `load_context`, indexing, building the cloud request,
/// `cloud_zeta2_prompt::build_prompt`, and JSON serialization.
async fn zeta2_syntax_context(
    args: ContextArgs,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<String> {
    let LoadedContext {
        worktree,
        project,
        buffer,
        clipped_cursor,
        ..
    } = load_context(&args, app_state, cx).await?;

    // Wait for the worktree scan before starting zeta2 so that wait_for_initial_indexing
    // waits for the whole worktree.
    worktree
        .read_with(cx, |worktree, _cx| {
            // The worktree comes from a `Project::local` project in `load_context`, so it
            // is expected to be local; the unwrap panics otherwise.
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;
    let output = cx
        .update(|cx| {
            let zeta = cx.new(|cx| {
                zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
            });
            // Configure zeta from the CLI flags (with `omit_excerpt_overlaps = true`),
            // register the buffer, and get a task that resolves once initial indexing
            // is done.
            let indexing_done_task = zeta.update(cx, |zeta, cx| {
                zeta.set_options(zeta2_args_to_options(&args.zeta2_args, true));
                zeta.register_buffer(&buffer, &project, cx);
                zeta.wait_for_initial_indexing(&project, cx)
            });
            cx.spawn(async move |cx| {
                indexing_done_task.await?;
                // The cursor anchor is taken from a fresh buffer snapshot once indexing
                // has finished.
                let request = zeta
                    .update(cx, |zeta, cx| {
                        let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
                        zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
                    })?
                    .await?;

                let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;

                match args.zeta2_args.output_format {
                    OutputFormat::Prompt => anyhow::Ok(prompt_string),
                    OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
                    OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
                        "request": request,
                        "prompt": prompt_string,
                        "section_labels": section_labels,
                    }))?),
                }
            })
        })?
        .await?;

    Ok(output)
}
432
433async fn zeta1_context(
434    args: ContextArgs,
435    app_state: &Arc<ZetaCliAppState>,
436    cx: &mut AsyncApp,
437) -> Result<zeta::zeta1::GatherContextOutput> {
438    let LoadedContext {
439        full_path_str,
440        snapshot,
441        clipped_cursor,
442        ..
443    } = load_context(&args, app_state, cx).await?;
444
445    let events = match args.edit_history {
446        Some(events) => events.read_to_string().await?,
447        None => String::new(),
448    };
449
450    let prompt_for_events = move || (events, 0);
451    cx.update(|cx| {
452        zeta::zeta1::gather_context(
453            full_path_str,
454            &snapshot,
455            clipped_cursor,
456            prompt_for_events,
457            cx,
458        )
459    })?
460    .await
461}
462
463fn main() {
464    zlog::init();
465    zlog::init_output_stderr();
466    let args = ZetaCliArgs::parse();
467    let http_client = Arc::new(ReqwestClient::new());
468    let app = Application::headless().with_http_client(http_client);
469
470    app.run(move |cx| {
471        let app_state = Arc::new(headless::init(cx));
472        cx.spawn(async move |cx| {
473            match args.command {
474                None => {
475                    if args.printenv {
476                        ::util::shell_env::print_env();
477                        return;
478                    } else {
479                        panic!("Expected a command");
480                    }
481                }
482                Some(Command::ContextStats(arguments)) => {
483                    let result = retrieval_stats(
484                        arguments.worktree,
485                        app_state,
486                        arguments.extension,
487                        arguments.limit,
488                        arguments.skip,
489                        zeta2_args_to_options(&arguments.zeta2_args, false),
490                        cx,
491                    )
492                    .await;
493                    println!("{}", result.unwrap());
494                }
495                Some(Command::Context(context_args)) => {
496                    let result = match context_args.provider {
497                        ContextProvider::Zeta1 => {
498                            let context =
499                                zeta1_context(context_args, &app_state, cx).await.unwrap();
500                            serde_json::to_string_pretty(&context.body).unwrap()
501                        }
502                        ContextProvider::Syntax => {
503                            zeta2_syntax_context(context_args, &app_state, cx)
504                                .await
505                                .unwrap()
506                        }
507                    };
508                    println!("{}", result);
509                }
510                Some(Command::Predict(arguments)) => {
511                    run_predict(arguments, &app_state, cx).await;
512                }
513                Some(Command::Eval(arguments)) => {
514                    run_evaluate(arguments, &app_state, cx).await;
515                }
516                Some(Command::ConvertExample {
517                    path,
518                    output_format,
519                }) => {
520                    let example = NamedExample::load(path).unwrap();
521                    example.write(output_format, io::stdout()).unwrap();
522                }
523                Some(Command::Clean) => {
524                    std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
525                }
526            };
527
528            let _ = cx.update(|cx| cx.quit());
529        })
530        .detach();
531    });
532}