main.rs

  1mod evaluate;
  2mod example;
  3mod headless;
  4mod paths;
  5mod predict;
  6mod source_location;
  7mod syntax_retrieval_stats;
  8mod util;
  9
 10use crate::{
 11    evaluate::run_evaluate,
 12    example::{ExampleFormat, NamedExample},
 13    headless::ZetaCliAppState,
 14    predict::run_predict,
 15    source_location::SourceLocation,
 16    syntax_retrieval_stats::retrieval_stats,
 17    util::{open_buffer, open_buffer_with_language_server},
 18};
 19use ::util::paths::PathStyle;
 20use anyhow::{Result, anyhow};
 21use clap::{Args, Parser, Subcommand, ValueEnum};
 22use cloud_llm_client::predict_edits_v3;
 23use edit_prediction_context::{
 24    EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
 25};
 26use gpui::{Application, AsyncApp, Entity, prelude::*};
 27use language::{Bias, Buffer, BufferSnapshot, Point};
 28use project::{Project, Worktree};
 29use reqwest_client::ReqwestClient;
 30use serde_json::json;
 31use std::io::{self};
 32use std::time::Duration;
 33use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
 34use zeta2::ContextMode;
 35
 36#[derive(Parser, Debug)]
 37#[command(name = "zeta")]
 38struct ZetaCliArgs {
 39    #[arg(long, default_value_t = false)]
 40    printenv: bool,
 41    #[command(subcommand)]
 42    command: Option<Command>,
 43}
 44
 45#[derive(Subcommand, Debug)]
 46enum Command {
 47    Context(ContextArgs),
 48    ContextStats(ContextStatsArgs),
 49    Predict(PredictArguments),
 50    Eval(EvaluateArguments),
 51    ConvertExample {
 52        path: PathBuf,
 53        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
 54        output_format: ExampleFormat,
 55    },
 56    Clean,
 57}
 58
 59#[derive(Debug, Args)]
 60struct ContextStatsArgs {
 61    #[arg(long)]
 62    worktree: PathBuf,
 63    #[arg(long)]
 64    extension: Option<String>,
 65    #[arg(long)]
 66    limit: Option<usize>,
 67    #[arg(long)]
 68    skip: Option<usize>,
 69    #[clap(flatten)]
 70    zeta2_args: Zeta2Args,
 71}
 72
 73#[derive(Debug, Args)]
 74struct ContextArgs {
 75    #[arg(long)]
 76    provider: ContextProvider,
 77    #[arg(long)]
 78    worktree: PathBuf,
 79    #[arg(long)]
 80    cursor: SourceLocation,
 81    #[arg(long)]
 82    use_language_server: bool,
 83    #[arg(long)]
 84    edit_history: Option<FileOrStdin>,
 85    #[clap(flatten)]
 86    zeta2_args: Zeta2Args,
 87}
 88
 89#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
 90enum ContextProvider {
 91    Zeta1,
 92    #[default]
 93    Syntax,
 94}
 95
 96#[derive(Clone, Debug, Args)]
 97struct Zeta2Args {
 98    #[arg(long, default_value_t = 8192)]
 99    max_prompt_bytes: usize,
100    #[arg(long, default_value_t = 2048)]
101    max_excerpt_bytes: usize,
102    #[arg(long, default_value_t = 1024)]
103    min_excerpt_bytes: usize,
104    #[arg(long, default_value_t = 0.66)]
105    target_before_cursor_over_total_bytes: f32,
106    #[arg(long, default_value_t = 1024)]
107    max_diagnostic_bytes: usize,
108    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
109    prompt_format: PromptFormat,
110    #[arg(long, value_enum, default_value_t = Default::default())]
111    output_format: OutputFormat,
112    #[arg(long, default_value_t = 42)]
113    file_indexing_parallelism: usize,
114    #[arg(long, default_value_t = false)]
115    disable_imports_gathering: bool,
116    #[arg(long, default_value_t = u8::MAX)]
117    max_retrieved_definitions: u8,
118}
119
120#[derive(Debug, Args)]
121pub struct PredictArguments {
122    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
123    format: PredictionsOutputFormat,
124    example_path: PathBuf,
125    #[clap(flatten)]
126    options: PredictionOptions,
127}
128
129#[derive(Clone, Debug, Args)]
130pub struct PredictionOptions {
131    #[arg(long)]
132    use_expected_context: bool,
133    #[clap(flatten)]
134    zeta2: Zeta2Args,
135    #[clap(long)]
136    provider: PredictionProvider,
137    #[clap(long, value_enum, default_value_t = CacheMode::default())]
138    cache: CacheMode,
139}
140
141#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
142pub enum CacheMode {
143    /// Use cached LLM requests and responses, except when multiple repetitions are requested
144    #[default]
145    Auto,
146    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
147    #[value(alias = "request")]
148    Requests,
149    /// Ignore existing cache entries for both LLM and search.
150    Skip,
151    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
152    /// Useful for reproducing results and fixing bugs outside of search queries
153    Force,
154}
155
156impl CacheMode {
157    fn use_cached_llm_responses(&self) -> bool {
158        self.assert_not_auto();
159        matches!(self, CacheMode::Requests | CacheMode::Force)
160    }
161
162    fn use_cached_search_results(&self) -> bool {
163        self.assert_not_auto();
164        matches!(self, CacheMode::Force)
165    }
166
167    fn assert_not_auto(&self) {
168        assert_ne!(
169            *self,
170            CacheMode::Auto,
171            "Cache mode should not be auto at this point!"
172        );
173    }
174}
175
176#[derive(clap::ValueEnum, Debug, Clone)]
177pub enum PredictionsOutputFormat {
178    Json,
179    Md,
180    Diff,
181}
182
183#[derive(Debug, Args)]
184pub struct EvaluateArguments {
185    example_paths: Vec<PathBuf>,
186    #[clap(flatten)]
187    options: PredictionOptions,
188    #[clap(short, long, default_value_t = 1, alias = "repeat")]
189    repetitions: u16,
190    #[arg(long)]
191    skip_prediction: bool,
192}
193
194#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
195enum PredictionProvider {
196    #[default]
197    Zeta2,
198    Sweep,
199}
200
201fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions {
202    zeta2::ZetaOptions {
203        context: ContextMode::Syntax(EditPredictionContextOptions {
204            max_retrieved_declarations: args.max_retrieved_definitions,
205            use_imports: !args.disable_imports_gathering,
206            excerpt: EditPredictionExcerptOptions {
207                max_bytes: args.max_excerpt_bytes,
208                min_bytes: args.min_excerpt_bytes,
209                target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
210            },
211            score: EditPredictionScoreOptions {
212                omit_excerpt_overlaps,
213            },
214        }),
215        max_diagnostic_bytes: args.max_diagnostic_bytes,
216        max_prompt_bytes: args.max_prompt_bytes,
217        prompt_format: args.prompt_format.into(),
218        file_indexing_parallelism: args.file_indexing_parallelism,
219        buffer_change_grouping_interval: Duration::ZERO,
220    }
221}
222
223#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
224enum PromptFormat {
225    MarkedExcerpt,
226    LabeledSections,
227    OnlySnippets,
228    #[default]
229    NumberedLines,
230    OldTextNewText,
231    Minimal,
232    MinimalQwen,
233    SeedCoder1120,
234}
235
236impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
237    fn into(self) -> predict_edits_v3::PromptFormat {
238        match self {
239            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
240            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
241            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
242            Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
243            Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
244            Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
245            Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
246            Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
247        }
248    }
249}
250
251#[derive(clap::ValueEnum, Default, Debug, Clone)]
252enum OutputFormat {
253    #[default]
254    Prompt,
255    Request,
256    Full,
257}
258
/// An input source: either a file on disk or standard input (written `-` on the command
/// line).
#[derive(Clone, Debug)]
enum FileOrStdin {
    /// Read from this path.
    File(PathBuf),
    /// Read from the process's standard input.
    Stdin,
}
264
265impl FileOrStdin {
266    async fn read_to_string(&self) -> Result<String, std::io::Error> {
267        match self {
268            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
269            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
270        }
271    }
272}
273
274impl FromStr for FileOrStdin {
275    type Err = <PathBuf as FromStr>::Err;
276
277    fn from_str(s: &str) -> Result<Self, Self::Err> {
278        match s {
279            "-" => Ok(Self::Stdin),
280            _ => Ok(Self::File(PathBuf::from_str(s)?)),
281        }
282    }
283}
284
285struct LoadedContext {
286    full_path_str: String,
287    snapshot: BufferSnapshot,
288    clipped_cursor: Point,
289    worktree: Entity<Worktree>,
290    project: Entity<Project>,
291    buffer: Entity<Buffer>,
292}
293
294async fn load_context(
295    args: &ContextArgs,
296    app_state: &Arc<ZetaCliAppState>,
297    cx: &mut AsyncApp,
298) -> Result<LoadedContext> {
299    let ContextArgs {
300        worktree: worktree_path,
301        cursor,
302        use_language_server,
303        ..
304    } = args;
305
306    let worktree_path = worktree_path.canonicalize()?;
307
308    let project = cx.update(|cx| {
309        Project::local(
310            app_state.client.clone(),
311            app_state.node_runtime.clone(),
312            app_state.user_store.clone(),
313            app_state.languages.clone(),
314            app_state.fs.clone(),
315            None,
316            cx,
317        )
318    })?;
319
320    let worktree = project
321        .update(cx, |project, cx| {
322            project.create_worktree(&worktree_path, true, cx)
323        })?
324        .await?;
325
326    let mut ready_languages = HashSet::default();
327    let (_lsp_open_handle, buffer) = if *use_language_server {
328        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
329            project.clone(),
330            worktree.clone(),
331            cursor.path.clone(),
332            &mut ready_languages,
333            cx,
334        )
335        .await?;
336        (Some(lsp_open_handle), buffer)
337    } else {
338        let buffer =
339            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
340        (None, buffer)
341    };
342
343    let full_path_str = worktree
344        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
345        .display(PathStyle::local())
346        .to_string();
347
348    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
349    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
350    if clipped_cursor != cursor.point {
351        let max_row = snapshot.max_point().row;
352        if cursor.point.row < max_row {
353            return Err(anyhow!(
354                "Cursor position {:?} is out of bounds (line length is {})",
355                cursor.point,
356                snapshot.line_len(cursor.point.row)
357            ));
358        } else {
359            return Err(anyhow!(
360                "Cursor position {:?} is out of bounds (max row is {})",
361                cursor.point,
362                max_row
363            ));
364        }
365    }
366
367    Ok(LoadedContext {
368        full_path_str,
369        snapshot,
370        clipped_cursor,
371        worktree,
372        project,
373        buffer,
374    })
375}
376
377async fn zeta2_syntax_context(
378    args: ContextArgs,
379    app_state: &Arc<ZetaCliAppState>,
380    cx: &mut AsyncApp,
381) -> Result<String> {
382    let LoadedContext {
383        worktree,
384        project,
385        buffer,
386        clipped_cursor,
387        ..
388    } = load_context(&args, app_state, cx).await?;
389
390    // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
391    // the whole worktree.
392    worktree
393        .read_with(cx, |worktree, _cx| {
394            worktree.as_local().unwrap().scan_complete()
395        })?
396        .await;
397    let output = cx
398        .update(|cx| {
399            let zeta = cx.new(|cx| {
400                zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
401            });
402            let indexing_done_task = zeta.update(cx, |zeta, cx| {
403                zeta.set_options(zeta2_args_to_options(&args.zeta2_args, true));
404                zeta.register_buffer(&buffer, &project, cx);
405                zeta.wait_for_initial_indexing(&project, cx)
406            });
407            cx.spawn(async move |cx| {
408                indexing_done_task.await?;
409                let request = zeta
410                    .update(cx, |zeta, cx| {
411                        let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
412                        zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
413                    })?
414                    .await?;
415
416                let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;
417
418                match args.zeta2_args.output_format {
419                    OutputFormat::Prompt => anyhow::Ok(prompt_string),
420                    OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
421                    OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
422                        "request": request,
423                        "prompt": prompt_string,
424                        "section_labels": section_labels,
425                    }))?),
426                }
427            })
428        })?
429        .await?;
430
431    Ok(output)
432}
433
434async fn zeta1_context(
435    args: ContextArgs,
436    app_state: &Arc<ZetaCliAppState>,
437    cx: &mut AsyncApp,
438) -> Result<zeta::GatherContextOutput> {
439    let LoadedContext {
440        full_path_str,
441        snapshot,
442        clipped_cursor,
443        ..
444    } = load_context(&args, app_state, cx).await?;
445
446    let events = match args.edit_history {
447        Some(events) => events.read_to_string().await?,
448        None => String::new(),
449    };
450
451    let prompt_for_events = move || (events, 0);
452    cx.update(|cx| {
453        zeta::gather_context(
454            full_path_str,
455            &snapshot,
456            clipped_cursor,
457            prompt_for_events,
458            cx,
459        )
460    })?
461    .await
462}
463
464fn main() {
465    zlog::init();
466    zlog::init_output_stderr();
467    let args = ZetaCliArgs::parse();
468    let http_client = Arc::new(ReqwestClient::new());
469    let app = Application::headless().with_http_client(http_client);
470
471    app.run(move |cx| {
472        let app_state = Arc::new(headless::init(cx));
473        cx.spawn(async move |cx| {
474            match args.command {
475                None => {
476                    if args.printenv {
477                        ::util::shell_env::print_env();
478                        return;
479                    } else {
480                        panic!("Expected a command");
481                    }
482                }
483                Some(Command::ContextStats(arguments)) => {
484                    let result = retrieval_stats(
485                        arguments.worktree,
486                        app_state,
487                        arguments.extension,
488                        arguments.limit,
489                        arguments.skip,
490                        zeta2_args_to_options(&arguments.zeta2_args, false),
491                        cx,
492                    )
493                    .await;
494                    println!("{}", result.unwrap());
495                }
496                Some(Command::Context(context_args)) => {
497                    let result = match context_args.provider {
498                        ContextProvider::Zeta1 => {
499                            let context =
500                                zeta1_context(context_args, &app_state, cx).await.unwrap();
501                            serde_json::to_string_pretty(&context.body).unwrap()
502                        }
503                        ContextProvider::Syntax => {
504                            zeta2_syntax_context(context_args, &app_state, cx)
505                                .await
506                                .unwrap()
507                        }
508                    };
509                    println!("{}", result);
510                }
511                Some(Command::Predict(arguments)) => {
512                    run_predict(arguments, &app_state, cx).await;
513                }
514                Some(Command::Eval(arguments)) => {
515                    run_evaluate(arguments, &app_state, cx).await;
516                }
517                Some(Command::ConvertExample {
518                    path,
519                    output_format,
520                }) => {
521                    let example = NamedExample::load(path).unwrap();
522                    example.write(output_format, io::stdout()).unwrap();
523                }
524                Some(Command::Clean) => {
525                    std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
526                }
527            };
528
529            let _ = cx.update(|cx| cx.quit());
530        })
531        .detach();
532    });
533}