1mod example;
  2mod headless;
  3mod source_location;
  4mod syntax_retrieval_stats;
  5mod util;
  6
  7use crate::example::{ExampleFormat, NamedExample};
  8use crate::syntax_retrieval_stats::retrieval_stats;
  9use ::serde::Serialize;
 10use ::util::paths::PathStyle;
 11use anyhow::{Context as _, Result, anyhow};
 12use clap::{Args, Parser, Subcommand};
 13use cloud_llm_client::predict_edits_v3::{self, Excerpt};
 14use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
 15use edit_prediction_context::{
 16    EditPredictionContextOptions, EditPredictionExcerpt, EditPredictionExcerptOptions,
 17    EditPredictionScoreOptions, Line,
 18};
 19use futures::StreamExt as _;
 20use futures::channel::mpsc;
 21use gpui::{Application, AsyncApp, Entity, prelude::*};
 22use language::{Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point};
 23use language_model::LanguageModelRegistry;
 24use project::{Project, Worktree};
 25use reqwest_client::ReqwestClient;
 26use serde_json::json;
 27use std::io;
 28use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc};
 29use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery};
 30
 31use crate::headless::ZetaCliAppState;
 32use crate::source_location::SourceLocation;
 33use crate::util::{open_buffer, open_buffer_with_language_server};
 34
 35#[derive(Parser, Debug)]
 36#[command(name = "zeta")]
 37struct ZetaCliArgs {
 38    #[command(subcommand)]
 39    command: Command,
 40}
 41
 42#[derive(Subcommand, Debug)]
 43enum Command {
 44    Zeta1 {
 45        #[command(subcommand)]
 46        command: Zeta1Command,
 47    },
 48    Zeta2 {
 49        #[clap(flatten)]
 50        args: Zeta2Args,
 51        #[command(subcommand)]
 52        command: Zeta2Command,
 53    },
 54    ConvertExample {
 55        path: PathBuf,
 56        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
 57        output_format: ExampleFormat,
 58    },
 59}
 60
 61#[derive(Subcommand, Debug)]
 62enum Zeta1Command {
 63    Context {
 64        #[clap(flatten)]
 65        context_args: ContextArgs,
 66    },
 67}
 68
 69#[derive(Subcommand, Debug)]
 70enum Zeta2Command {
 71    Syntax {
 72        #[clap(flatten)]
 73        syntax_args: Zeta2SyntaxArgs,
 74        #[command(subcommand)]
 75        command: Zeta2SyntaxCommand,
 76    },
 77    Llm {
 78        #[command(subcommand)]
 79        command: Zeta2LlmCommand,
 80    },
 81}
 82
 83#[derive(Subcommand, Debug)]
 84enum Zeta2SyntaxCommand {
 85    Context {
 86        #[clap(flatten)]
 87        context_args: ContextArgs,
 88    },
 89    Stats {
 90        #[arg(long)]
 91        worktree: PathBuf,
 92        #[arg(long)]
 93        extension: Option<String>,
 94        #[arg(long)]
 95        limit: Option<usize>,
 96        #[arg(long)]
 97        skip: Option<usize>,
 98    },
 99}
100
101#[derive(Subcommand, Debug)]
102enum Zeta2LlmCommand {
103    Context {
104        #[clap(flatten)]
105        context_args: ContextArgs,
106    },
107}
108
109#[derive(Debug, Args)]
110#[group(requires = "worktree")]
111struct ContextArgs {
112    #[arg(long)]
113    worktree: PathBuf,
114    #[arg(long)]
115    cursor: SourceLocation,
116    #[arg(long)]
117    use_language_server: bool,
118    #[arg(long)]
119    edit_history: Option<FileOrStdin>,
120}
121
122#[derive(Debug, Args)]
123struct Zeta2Args {
124    #[arg(long, default_value_t = 8192)]
125    max_prompt_bytes: usize,
126    #[arg(long, default_value_t = 2048)]
127    max_excerpt_bytes: usize,
128    #[arg(long, default_value_t = 1024)]
129    min_excerpt_bytes: usize,
130    #[arg(long, default_value_t = 0.66)]
131    target_before_cursor_over_total_bytes: f32,
132    #[arg(long, default_value_t = 1024)]
133    max_diagnostic_bytes: usize,
134    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
135    prompt_format: PromptFormat,
136    #[arg(long, value_enum, default_value_t = Default::default())]
137    output_format: OutputFormat,
138    #[arg(long, default_value_t = 42)]
139    file_indexing_parallelism: usize,
140}
141
142#[derive(Debug, Args)]
143struct Zeta2SyntaxArgs {
144    #[arg(long, default_value_t = false)]
145    disable_imports_gathering: bool,
146    #[arg(long, default_value_t = u8::MAX)]
147    max_retrieved_definitions: u8,
148}
149
150fn syntax_args_to_options(
151    zeta2_args: &Zeta2Args,
152    syntax_args: &Zeta2SyntaxArgs,
153    omit_excerpt_overlaps: bool,
154) -> zeta2::ZetaOptions {
155    zeta2::ZetaOptions {
156        context: ContextMode::Syntax(EditPredictionContextOptions {
157            max_retrieved_declarations: syntax_args.max_retrieved_definitions,
158            use_imports: !syntax_args.disable_imports_gathering,
159            excerpt: EditPredictionExcerptOptions {
160                max_bytes: zeta2_args.max_excerpt_bytes,
161                min_bytes: zeta2_args.min_excerpt_bytes,
162                target_before_cursor_over_total_bytes: zeta2_args
163                    .target_before_cursor_over_total_bytes,
164            },
165            score: EditPredictionScoreOptions {
166                omit_excerpt_overlaps,
167            },
168        }),
169        max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
170        max_prompt_bytes: zeta2_args.max_prompt_bytes,
171        prompt_format: zeta2_args.prompt_format.clone().into(),
172        file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
173    }
174}
175
176#[derive(clap::ValueEnum, Default, Debug, Clone)]
177enum PromptFormat {
178    MarkedExcerpt,
179    LabeledSections,
180    OnlySnippets,
181    #[default]
182    NumberedLines,
183}
184
185impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
186    fn into(self) -> predict_edits_v3::PromptFormat {
187        match self {
188            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
189            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
190            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
191            Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
192        }
193    }
194}
195
196#[derive(clap::ValueEnum, Default, Debug, Clone)]
197enum OutputFormat {
198    #[default]
199    Prompt,
200    Request,
201    Full,
202}
203
/// A text input named on the command line: either a file path or standard input.
#[derive(Clone, Debug)]
enum FileOrStdin {
    /// Read the text from the file at this path.
    File(PathBuf),
    /// Read the text from standard input (spelled `-` on the command line).
    Stdin,
}
209
210impl FileOrStdin {
211    async fn read_to_string(&self) -> Result<String, std::io::Error> {
212        match self {
213            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
214            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
215        }
216    }
217}
218
219impl FromStr for FileOrStdin {
220    type Err = <PathBuf as FromStr>::Err;
221
222    fn from_str(s: &str) -> Result<Self, Self::Err> {
223        match s {
224            "-" => Ok(Self::Stdin),
225            _ => Ok(Self::File(PathBuf::from_str(s)?)),
226        }
227    }
228}
229
/// Everything `load_context` produces for a cursor location.
struct LoadedContext {
    /// Worktree root name joined with the cursor's path, rendered with the local path style.
    full_path_str: String,
    /// Snapshot of the buffer taken right after it was opened.
    snapshot: BufferSnapshot,
    /// The requested cursor point after clipping to `snapshot`; `load_context` fails when
    /// clipping would have moved it, so this equals the requested point.
    clipped_cursor: Point,
    /// The worktree created for the requested path.
    worktree: Entity<Worktree>,
    /// The local project that owns `worktree`.
    project: Entity<Project>,
    /// The buffer opened at the cursor's path.
    buffer: Entity<Buffer>,
}
238
239async fn load_context(
240    args: &ContextArgs,
241    app_state: &Arc<ZetaCliAppState>,
242    cx: &mut AsyncApp,
243) -> Result<LoadedContext> {
244    let ContextArgs {
245        worktree: worktree_path,
246        cursor,
247        use_language_server,
248        ..
249    } = args;
250
251    let worktree_path = worktree_path.canonicalize()?;
252
253    let project = cx.update(|cx| {
254        Project::local(
255            app_state.client.clone(),
256            app_state.node_runtime.clone(),
257            app_state.user_store.clone(),
258            app_state.languages.clone(),
259            app_state.fs.clone(),
260            None,
261            cx,
262        )
263    })?;
264
265    let worktree = project
266        .update(cx, |project, cx| {
267            project.create_worktree(&worktree_path, true, cx)
268        })?
269        .await?;
270
271    let mut ready_languages = HashSet::default();
272    let (_lsp_open_handle, buffer) = if *use_language_server {
273        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
274            project.clone(),
275            worktree.clone(),
276            cursor.path.clone(),
277            &mut ready_languages,
278            cx,
279        )
280        .await?;
281        (Some(lsp_open_handle), buffer)
282    } else {
283        let buffer =
284            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
285        (None, buffer)
286    };
287
288    let full_path_str = worktree
289        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
290        .display(PathStyle::local())
291        .to_string();
292
293    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
294    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
295    if clipped_cursor != cursor.point {
296        let max_row = snapshot.max_point().row;
297        if cursor.point.row < max_row {
298            return Err(anyhow!(
299                "Cursor position {:?} is out of bounds (line length is {})",
300                cursor.point,
301                snapshot.line_len(cursor.point.row)
302            ));
303        } else {
304            return Err(anyhow!(
305                "Cursor position {:?} is out of bounds (max row is {})",
306                cursor.point,
307                max_row
308            ));
309        }
310    }
311
312    Ok(LoadedContext {
313        full_path_str,
314        snapshot,
315        clipped_cursor,
316        worktree,
317        project,
318        buffer,
319    })
320}
321
322async fn zeta2_syntax_context(
323    zeta2_args: Zeta2Args,
324    syntax_args: Zeta2SyntaxArgs,
325    args: ContextArgs,
326    app_state: &Arc<ZetaCliAppState>,
327    cx: &mut AsyncApp,
328) -> Result<String> {
329    let LoadedContext {
330        worktree,
331        project,
332        buffer,
333        clipped_cursor,
334        ..
335    } = load_context(&args, app_state, cx).await?;
336
337    // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
338    // the whole worktree.
339    worktree
340        .read_with(cx, |worktree, _cx| {
341            worktree.as_local().unwrap().scan_complete()
342        })?
343        .await;
344    let output = cx
345        .update(|cx| {
346            let zeta = cx.new(|cx| {
347                zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
348            });
349            let indexing_done_task = zeta.update(cx, |zeta, cx| {
350                zeta.set_options(syntax_args_to_options(&zeta2_args, &syntax_args, true));
351                zeta.register_buffer(&buffer, &project, cx);
352                zeta.wait_for_initial_indexing(&project, cx)
353            });
354            cx.spawn(async move |cx| {
355                indexing_done_task.await?;
356                let request = zeta
357                    .update(cx, |zeta, cx| {
358                        let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
359                        zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
360                    })?
361                    .await?;
362
363                let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;
364
365                match zeta2_args.output_format {
366                    OutputFormat::Prompt => anyhow::Ok(prompt_string),
367                    OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
368                    OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
369                        "request": request,
370                        "prompt": prompt_string,
371                        "section_labels": section_labels,
372                    }))?),
373                }
374            })
375        })?
376        .await?;
377
378    Ok(output)
379}
380
381async fn zeta2_llm_context(
382    zeta2_args: Zeta2Args,
383    context_args: ContextArgs,
384    app_state: &Arc<ZetaCliAppState>,
385    cx: &mut AsyncApp,
386) -> Result<String> {
387    let LoadedContext {
388        buffer,
389        clipped_cursor,
390        snapshot: cursor_snapshot,
391        project,
392        ..
393    } = load_context(&context_args, app_state, cx).await?;
394
395    let cursor_position = cursor_snapshot.anchor_after(clipped_cursor);
396
397    cx.update(|cx| {
398        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
399            registry
400                .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
401                .unwrap()
402                .authenticate(cx)
403        })
404    })?
405    .await?;
406
407    let edit_history_unified_diff = match context_args.edit_history {
408        Some(events) => events.read_to_string().await?,
409        None => String::new(),
410    };
411
412    let (debug_tx, mut debug_rx) = mpsc::unbounded();
413
414    let excerpt_options = EditPredictionExcerptOptions {
415        max_bytes: zeta2_args.max_excerpt_bytes,
416        min_bytes: zeta2_args.min_excerpt_bytes,
417        target_before_cursor_over_total_bytes: zeta2_args.target_before_cursor_over_total_bytes,
418    };
419
420    let related_excerpts = cx
421        .update(|cx| {
422            zeta2::related_excerpts::find_related_excerpts(
423                buffer,
424                cursor_position,
425                &project,
426                edit_history_unified_diff,
427                &LlmContextOptions {
428                    excerpt: excerpt_options.clone(),
429                },
430                Some(debug_tx),
431                cx,
432            )
433        })?
434        .await?;
435
436    let cursor_excerpt = EditPredictionExcerpt::select_from_buffer(
437        clipped_cursor,
438        &cursor_snapshot,
439        &excerpt_options,
440        None,
441    )
442    .context("line didn't fit")?;
443
444    #[derive(Serialize)]
445    struct Output {
446        excerpts: Vec<OutputExcerpt>,
447        formatted_excerpts: String,
448        meta: OutputMeta,
449    }
450
451    #[derive(Default, Serialize)]
452    struct OutputMeta {
453        search_prompt: String,
454        search_queries: Vec<SearchToolQuery>,
455    }
456
457    #[derive(Serialize)]
458    struct OutputExcerpt {
459        path: PathBuf,
460        #[serde(flatten)]
461        excerpt: Excerpt,
462    }
463
464    let mut meta = OutputMeta::default();
465
466    while let Some(debug_info) = debug_rx.next().await {
467        match debug_info {
468            zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
469                meta.search_prompt = info.search_prompt;
470            }
471            zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
472                meta.search_queries = info.queries
473            }
474            _ => {}
475        }
476    }
477
478    cx.update(|cx| {
479        let mut excerpts = Vec::new();
480        let mut formatted_excerpts = String::new();
481
482        let cursor_insertions = [(
483            predict_edits_v3::Point {
484                line: Line(clipped_cursor.row),
485                column: clipped_cursor.column,
486            },
487            CURSOR_MARKER,
488        )];
489
490        let mut cursor_excerpt_added = false;
491
492        for (buffer, ranges) in related_excerpts {
493            let excerpt_snapshot = buffer.read(cx).snapshot();
494
495            let mut line_ranges = ranges
496                .into_iter()
497                .map(|range| {
498                    let point_range = range.to_point(&excerpt_snapshot);
499                    Line(point_range.start.row)..Line(point_range.end.row)
500                })
501                .collect::<Vec<_>>();
502
503            let Some(file) = excerpt_snapshot.file() else {
504                continue;
505            };
506            let path = file.full_path(cx);
507
508            let is_cursor_file = path == cursor_snapshot.file().unwrap().full_path(cx);
509            if is_cursor_file {
510                let insertion_ix = line_ranges
511                    .binary_search_by(|probe| {
512                        probe
513                            .start
514                            .cmp(&cursor_excerpt.line_range.start)
515                            .then(cursor_excerpt.line_range.end.cmp(&probe.end))
516                    })
517                    .unwrap_or_else(|ix| ix);
518                line_ranges.insert(insertion_ix, cursor_excerpt.line_range.clone());
519                cursor_excerpt_added = true;
520            }
521
522            let merged_excerpts =
523                zeta2::merge_excerpts::merge_excerpts(&excerpt_snapshot, line_ranges)
524                    .into_iter()
525                    .map(|excerpt| OutputExcerpt {
526                        path: path.clone(),
527                        excerpt,
528                    });
529
530            let excerpt_start_ix = excerpts.len();
531            excerpts.extend(merged_excerpts);
532
533            write_codeblock(
534                &path,
535                excerpts[excerpt_start_ix..].iter().map(|e| &e.excerpt),
536                if is_cursor_file {
537                    &cursor_insertions
538                } else {
539                    &[]
540                },
541                Line(excerpt_snapshot.max_point().row),
542                true,
543                &mut formatted_excerpts,
544            );
545        }
546
547        if !cursor_excerpt_added {
548            write_codeblock(
549                &cursor_snapshot.file().unwrap().full_path(cx),
550                &[Excerpt {
551                    start_line: cursor_excerpt.line_range.start,
552                    text: cursor_excerpt.text(&cursor_snapshot).body.into(),
553                }],
554                &cursor_insertions,
555                Line(cursor_snapshot.max_point().row),
556                true,
557                &mut formatted_excerpts,
558            );
559        }
560
561        let output = Output {
562            excerpts,
563            formatted_excerpts,
564            meta,
565        };
566
567        Ok(serde_json::to_string_pretty(&output)?)
568    })
569    .unwrap()
570}
571
572async fn zeta1_context(
573    args: ContextArgs,
574    app_state: &Arc<ZetaCliAppState>,
575    cx: &mut AsyncApp,
576) -> Result<zeta::GatherContextOutput> {
577    let LoadedContext {
578        full_path_str,
579        snapshot,
580        clipped_cursor,
581        ..
582    } = load_context(&args, app_state, cx).await?;
583
584    let events = match args.edit_history {
585        Some(events) => events.read_to_string().await?,
586        None => String::new(),
587    };
588
589    let prompt_for_events = move || (events, 0);
590    cx.update(|cx| {
591        zeta::gather_context(
592            full_path_str,
593            &snapshot,
594            clipped_cursor,
595            prompt_for_events,
596            cx,
597        )
598    })?
599    .await
600}
601
602fn main() {
603    zlog::init();
604    zlog::init_output_stderr();
605    let args = ZetaCliArgs::parse();
606    let http_client = Arc::new(ReqwestClient::new());
607    let app = Application::headless().with_http_client(http_client);
608
609    app.run(move |cx| {
610        let app_state = Arc::new(headless::init(cx));
611        cx.spawn(async move |cx| {
612            let result = match args.command {
613                Command::Zeta1 {
614                    command: Zeta1Command::Context { context_args },
615                } => {
616                    let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
617                    serde_json::to_string_pretty(&context.body).map_err(|err| anyhow::anyhow!(err))
618                }
619                Command::Zeta2 { args, command } => match command {
620                    Zeta2Command::Syntax {
621                        syntax_args,
622                        command,
623                    } => match command {
624                        Zeta2SyntaxCommand::Context { context_args } => {
625                            zeta2_syntax_context(args, syntax_args, context_args, &app_state, cx)
626                                .await
627                        }
628                        Zeta2SyntaxCommand::Stats {
629                            worktree,
630                            extension,
631                            limit,
632                            skip,
633                        } => {
634                            retrieval_stats(
635                                worktree,
636                                app_state,
637                                extension,
638                                limit,
639                                skip,
640                                syntax_args_to_options(&args, &syntax_args, false),
641                                cx,
642                            )
643                            .await
644                        }
645                    },
646                    Zeta2Command::Llm { command } => match command {
647                        Zeta2LlmCommand::Context { context_args } => {
648                            zeta2_llm_context(args, context_args, &app_state, cx).await
649                        }
650                    },
651                },
652                Command::ConvertExample {
653                    path,
654                    output_format,
655                } => {
656                    let example = NamedExample::load(path).unwrap();
657                    example.write(output_format, io::stdout()).unwrap();
658                    let _ = cx.update(|cx| cx.quit());
659                    return;
660                }
661            };
662
663            match result {
664                Ok(output) => {
665                    println!("{}", output);
666                    let _ = cx.update(|cx| cx.quit());
667                }
668                Err(e) => {
669                    eprintln!("Failed: {:?}", e);
670                    exit(1);
671                }
672            }
673        })
674        .detach();
675    });
676}