main.rs

  1mod evaluate;
  2mod example;
  3mod headless;
  4mod predict;
  5mod source_location;
  6mod syntax_retrieval_stats;
  7mod util;
  8
  9use crate::evaluate::{EvaluateArguments, run_evaluate};
 10use crate::example::{ExampleFormat, NamedExample};
 11use crate::predict::{PredictArguments, run_zeta2_predict};
 12use crate::syntax_retrieval_stats::retrieval_stats;
 13use ::serde::Serialize;
 14use ::util::paths::PathStyle;
 15use anyhow::{Context as _, Result, anyhow};
 16use clap::{Args, Parser, Subcommand};
 17use cloud_llm_client::predict_edits_v3::{self, Excerpt};
 18use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
 19use edit_prediction_context::{
 20    EditPredictionContextOptions, EditPredictionExcerpt, EditPredictionExcerptOptions,
 21    EditPredictionScoreOptions, Line,
 22};
 23use futures::StreamExt as _;
 24use futures::channel::mpsc;
 25use gpui::{Application, AsyncApp, Entity, prelude::*};
 26use language::{Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point};
 27use language_model::LanguageModelRegistry;
 28use project::{Project, Worktree};
 29use reqwest_client::ReqwestClient;
 30use serde_json::json;
 31use std::io::{self};
 32use std::time::Duration;
 33use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
 34use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery};
 35
 36use crate::headless::ZetaCliAppState;
 37use crate::source_location::SourceLocation;
 38use crate::util::{open_buffer, open_buffer_with_language_server};
 39
// Top-level command-line arguments for the `zeta` CLI. All work is dispatched through the
// selected subcommand (see `Command`), which `main` matches on.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    #[command(subcommand)]
    command: Command,
}
 46
// Top-level subcommands.
#[derive(Subcommand, Debug)]
enum Command {
    // Tools for the zeta1 model (context gathering only).
    Zeta1 {
        #[command(subcommand)]
        command: Zeta1Command,
    },
    // Tools for the zeta2 model: syntax/LLM context retrieval, prediction, and evaluation.
    Zeta2 {
        #[command(subcommand)]
        command: Zeta2Command,
    },
    // Loads a `NamedExample` from `path` and writes it to stdout in `output_format`
    // (markdown by default).
    ConvertExample {
        path: PathBuf,
        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
        output_format: ExampleFormat,
    },
}
 63
// zeta1 subcommands.
#[derive(Subcommand, Debug)]
enum Zeta1Command {
    // Gathers zeta1 context for a cursor location and prints the request body as pretty JSON.
    Context {
        #[clap(flatten)]
        context_args: ContextArgs,
    },
}
 71
// zeta2 subcommands.
#[derive(Subcommand, Debug)]
enum Zeta2Command {
    // Syntax-based context retrieval (`ContextMode::Syntax`), configured by both the shared
    // zeta2 flags and the syntax-specific flags.
    Syntax {
        #[clap(flatten)]
        args: Zeta2Args,
        #[clap(flatten)]
        syntax_args: Zeta2SyntaxArgs,
        #[command(subcommand)]
        command: Zeta2SyntaxCommand,
    },
    // LLM-driven related-excerpt retrieval.
    Llm {
        #[clap(flatten)]
        args: Zeta2Args,
        #[command(subcommand)]
        command: Zeta2LlmCommand,
    },
    // Handled by `predict::run_zeta2_predict`.
    Predict(PredictArguments),
    // Handled by `evaluate::run_evaluate`.
    Eval(EvaluateArguments),
}
 91
// Subcommands of `zeta2 syntax`.
#[derive(Subcommand, Debug)]
enum Zeta2SyntaxCommand {
    // Builds the zeta2 request for one cursor location and prints it in `--output-format`.
    Context {
        #[clap(flatten)]
        context_args: ContextArgs,
    },
    // Computes retrieval statistics over a whole worktree via `retrieval_stats`.
    // NOTE(review): the exact meaning of `extension`, `limit`, and `skip` is defined by
    // `syntax_retrieval_stats::retrieval_stats` (presumably a file-extension filter and
    // paging over files) — confirm there.
    Stats {
        #[arg(long)]
        worktree: PathBuf,
        #[arg(long)]
        extension: Option<String>,
        #[arg(long)]
        limit: Option<usize>,
        #[arg(long)]
        skip: Option<usize>,
    },
}
109
// Subcommands of `zeta2 llm`.
#[derive(Subcommand, Debug)]
enum Zeta2LlmCommand {
    // Finds related excerpts for a cursor location and prints them (plus debug metadata) as JSON.
    Context {
        #[clap(flatten)]
        context_args: ContextArgs,
    },
}
117
// Arguments shared by every `context` subcommand: which worktree to open and where the cursor is.
// NOTE(review): `worktree` is already a required (non-`Option`) argument, so this group's
// `requires = "worktree"` constraint appears redundant — confirm before removing.
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Root directory of the worktree; canonicalized by `load_context`.
    #[arg(long)]
    worktree: PathBuf,
    // Cursor file (relative to the worktree) and position, parsed by `SourceLocation`.
    #[arg(long)]
    cursor: SourceLocation,
    // When set, the buffer is opened via `open_buffer_with_language_server` instead of
    // `open_buffer`.
    #[arg(long)]
    use_language_server: bool,
    // Optional edit-history text: a file path, or `-` to read it from stdin.
    #[arg(long)]
    edit_history: Option<FileOrStdin>,
}
130
// Flags shared by the zeta2 `syntax` and `llm` subcommands. The `llm` context path only reads
// the three excerpt-size fields; the rest feed `zeta2::ZetaOptions` in `syntax_args_to_options`.
#[derive(Debug, Args)]
struct Zeta2Args {
    // Forwarded to `ZetaOptions::max_prompt_bytes`.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // Upper and lower bounds (bytes) for an excerpt, forwarded to `EditPredictionExcerptOptions`.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // Forwarded to `EditPredictionExcerptOptions::target_before_cursor_over_total_bytes`
    // (presumably the share of excerpt bytes placed before the cursor — confirm in
    // `edit_prediction_context`).
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // Forwarded to `ZetaOptions::max_diagnostic_bytes`.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // Converted into `predict_edits_v3::PromptFormat` for `ZetaOptions::prompt_format`.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // How `zeta2 syntax context` renders its result; not consulted by `zeta2 llm context`.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // Forwarded to `ZetaOptions::file_indexing_parallelism`.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
}
150
// Flags specific to syntax-based (`ContextMode::Syntax`) retrieval.
#[derive(Debug, Args)]
struct Zeta2SyntaxArgs {
    // Inverted into `EditPredictionContextOptions::use_imports`.
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
    // Forwarded as `EditPredictionContextOptions::max_retrieved_declarations`; defaults to the
    // largest `u8`.
    #[arg(long, default_value_t = u8::MAX)]
    max_retrieved_definitions: u8,
}
158
159fn syntax_args_to_options(
160    zeta2_args: &Zeta2Args,
161    syntax_args: &Zeta2SyntaxArgs,
162    omit_excerpt_overlaps: bool,
163) -> zeta2::ZetaOptions {
164    zeta2::ZetaOptions {
165        context: ContextMode::Syntax(EditPredictionContextOptions {
166            max_retrieved_declarations: syntax_args.max_retrieved_definitions,
167            use_imports: !syntax_args.disable_imports_gathering,
168            excerpt: EditPredictionExcerptOptions {
169                max_bytes: zeta2_args.max_excerpt_bytes,
170                min_bytes: zeta2_args.min_excerpt_bytes,
171                target_before_cursor_over_total_bytes: zeta2_args
172                    .target_before_cursor_over_total_bytes,
173            },
174            score: EditPredictionScoreOptions {
175                omit_excerpt_overlaps,
176            },
177        }),
178        max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
179        max_prompt_bytes: zeta2_args.max_prompt_bytes,
180        prompt_format: zeta2_args.prompt_format.clone().into(),
181        file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
182        buffer_change_grouping_interval: Duration::ZERO,
183    }
184}
185
// CLI-facing prompt format choices. Each variant maps to a `predict_edits_v3::PromptFormat`;
// `NumberedLines` (the default) maps to `NumLinesUniDiff`.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum PromptFormat {
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
    #[default]
    NumberedLines,
}
194
195impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
196    fn into(self) -> predict_edits_v3::PromptFormat {
197        match self {
198            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
199            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
200            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
201            Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
202        }
203    }
204}
205
// How `zeta2 syntax context` renders its result (see `zeta2_syntax_context`):
// - `Prompt` (default): the prompt string built by `cloud_zeta2_prompt::build_prompt`;
// - `Request`: the request serialized as pretty-printed JSON;
// - `Full`: a JSON object with `request`, `prompt`, and `section_labels`.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    #[default]
    Prompt,
    Request,
    Full,
}
213
// Input source given on the command line: a file path, or stdin when the argument is `-`
// (see the `FromStr` impl).
#[derive(Debug, Clone)]
enum FileOrStdin {
    File(PathBuf),
    Stdin,
}
219
220impl FileOrStdin {
221    async fn read_to_string(&self) -> Result<String, std::io::Error> {
222        match self {
223            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
224            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
225        }
226    }
227}
228
229impl FromStr for FileOrStdin {
230    type Err = <PathBuf as FromStr>::Err;
231
232    fn from_str(s: &str) -> Result<Self, Self::Err> {
233        match s {
234            "-" => Ok(Self::Stdin),
235            _ => Ok(Self::File(PathBuf::from_str(s)?)),
236        }
237    }
238}
239
// Everything `load_context` opens for a cursor location.
struct LoadedContext {
    // Worktree root name joined with the cursor's relative path, rendered with the local
    // path style.
    full_path_str: String,
    // Snapshot of `buffer` taken right after it was opened.
    snapshot: BufferSnapshot,
    // Cursor position after clipping; `load_context` errors if clipping changed it, so this
    // equals the requested position.
    clipped_cursor: Point,
    worktree: Entity<Worktree>,
    project: Entity<Project>,
    buffer: Entity<Buffer>,
}
248
249async fn load_context(
250    args: &ContextArgs,
251    app_state: &Arc<ZetaCliAppState>,
252    cx: &mut AsyncApp,
253) -> Result<LoadedContext> {
254    let ContextArgs {
255        worktree: worktree_path,
256        cursor,
257        use_language_server,
258        ..
259    } = args;
260
261    let worktree_path = worktree_path.canonicalize()?;
262
263    let project = cx.update(|cx| {
264        Project::local(
265            app_state.client.clone(),
266            app_state.node_runtime.clone(),
267            app_state.user_store.clone(),
268            app_state.languages.clone(),
269            app_state.fs.clone(),
270            None,
271            cx,
272        )
273    })?;
274
275    let worktree = project
276        .update(cx, |project, cx| {
277            project.create_worktree(&worktree_path, true, cx)
278        })?
279        .await?;
280
281    let mut ready_languages = HashSet::default();
282    let (_lsp_open_handle, buffer) = if *use_language_server {
283        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
284            project.clone(),
285            worktree.clone(),
286            cursor.path.clone(),
287            &mut ready_languages,
288            cx,
289        )
290        .await?;
291        (Some(lsp_open_handle), buffer)
292    } else {
293        let buffer =
294            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
295        (None, buffer)
296    };
297
298    let full_path_str = worktree
299        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
300        .display(PathStyle::local())
301        .to_string();
302
303    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
304    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
305    if clipped_cursor != cursor.point {
306        let max_row = snapshot.max_point().row;
307        if cursor.point.row < max_row {
308            return Err(anyhow!(
309                "Cursor position {:?} is out of bounds (line length is {})",
310                cursor.point,
311                snapshot.line_len(cursor.point.row)
312            ));
313        } else {
314            return Err(anyhow!(
315                "Cursor position {:?} is out of bounds (max row is {})",
316                cursor.point,
317                max_row
318            ));
319        }
320    }
321
322    Ok(LoadedContext {
323        full_path_str,
324        snapshot,
325        clipped_cursor,
326        worktree,
327        project,
328        buffer,
329    })
330}
331
/// Builds a zeta2 request using syntax-based context retrieval (`ContextMode::Syntax`) and
/// renders it according to `zeta2_args.output_format`.
///
/// Steps:
/// 1. Wait for the worktree scan to complete.
/// 2. Create a fresh `zeta2::Zeta` configured with
///    `syntax_args_to_options(.., omit_excerpt_overlaps = true)`.
/// 3. Register the cursor buffer and wait for initial indexing.
/// 4. Build the request for the cursor.
///
/// Output by `OutputFormat`:
/// - `Prompt`: the prompt string from `cloud_zeta2_prompt::build_prompt`.
/// - `Request`: the request as pretty-printed JSON.
/// - `Full`: a JSON object with `request`, `prompt`, and `section_labels`.
async fn zeta2_syntax_context(
    zeta2_args: Zeta2Args,
    syntax_args: Zeta2SyntaxArgs,
    args: ContextArgs,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<String> {
    let LoadedContext {
        worktree,
        project,
        buffer,
        clipped_cursor,
        ..
    } = load_context(&args, app_state, cx).await?;

    // Wait for the worktree scan before starting zeta2 so that wait_for_initial_indexing waits
    // for the whole worktree.
    // The project was created with `Project::local` in `load_context`, so `as_local()` is
    // expected to be `Some` here.
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;
    let output = cx
        .update(|cx| {
            let zeta = cx.new(|cx| {
                zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
            });
            let indexing_done_task = zeta.update(cx, |zeta, cx| {
                zeta.set_options(syntax_args_to_options(&zeta2_args, &syntax_args, true));
                zeta.register_buffer(&buffer, &project, cx);
                zeta.wait_for_initial_indexing(&project, cx)
            });
            cx.spawn(async move |cx| {
                indexing_done_task.await?;
                let request = zeta
                    .update(cx, |zeta, cx| {
                        // Anchor the cursor in the buffer's snapshot as of now, after indexing.
                        let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
                        zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
                    })?
                    .await?;

                let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;

                match zeta2_args.output_format {
                    OutputFormat::Prompt => anyhow::Ok(prompt_string),
                    OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
                    OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
                        "request": request,
                        "prompt": prompt_string,
                        "section_labels": section_labels,
                    }))?),
                }
            })
        })?
        .await?;

    Ok(output)
}
390
391async fn zeta2_llm_context(
392    zeta2_args: Zeta2Args,
393    context_args: ContextArgs,
394    app_state: &Arc<ZetaCliAppState>,
395    cx: &mut AsyncApp,
396) -> Result<String> {
397    let LoadedContext {
398        buffer,
399        clipped_cursor,
400        snapshot: cursor_snapshot,
401        project,
402        ..
403    } = load_context(&context_args, app_state, cx).await?;
404
405    let cursor_position = cursor_snapshot.anchor_after(clipped_cursor);
406
407    cx.update(|cx| {
408        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
409            registry
410                .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
411                .unwrap()
412                .authenticate(cx)
413        })
414    })?
415    .await?;
416
417    let edit_history_unified_diff = match context_args.edit_history {
418        Some(events) => events.read_to_string().await?,
419        None => String::new(),
420    };
421
422    let (debug_tx, mut debug_rx) = mpsc::unbounded();
423
424    let excerpt_options = EditPredictionExcerptOptions {
425        max_bytes: zeta2_args.max_excerpt_bytes,
426        min_bytes: zeta2_args.min_excerpt_bytes,
427        target_before_cursor_over_total_bytes: zeta2_args.target_before_cursor_over_total_bytes,
428    };
429
430    let related_excerpts = cx
431        .update(|cx| {
432            zeta2::related_excerpts::find_related_excerpts(
433                buffer,
434                cursor_position,
435                &project,
436                edit_history_unified_diff,
437                &LlmContextOptions {
438                    excerpt: excerpt_options.clone(),
439                },
440                Some(debug_tx),
441                cx,
442            )
443        })?
444        .await?;
445
446    let cursor_excerpt = EditPredictionExcerpt::select_from_buffer(
447        clipped_cursor,
448        &cursor_snapshot,
449        &excerpt_options,
450        None,
451    )
452    .context("line didn't fit")?;
453
454    #[derive(Serialize)]
455    struct Output {
456        excerpts: Vec<OutputExcerpt>,
457        formatted_excerpts: String,
458        meta: OutputMeta,
459    }
460
461    #[derive(Default, Serialize)]
462    struct OutputMeta {
463        search_prompt: String,
464        search_queries: Vec<SearchToolQuery>,
465    }
466
467    #[derive(Serialize)]
468    struct OutputExcerpt {
469        path: PathBuf,
470        #[serde(flatten)]
471        excerpt: Excerpt,
472    }
473
474    let mut meta = OutputMeta::default();
475
476    while let Some(debug_info) = debug_rx.next().await {
477        match debug_info {
478            zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
479                meta.search_prompt = info.search_prompt;
480            }
481            zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
482                meta.search_queries = info.queries
483            }
484            _ => {}
485        }
486    }
487
488    cx.update(|cx| {
489        let mut excerpts = Vec::new();
490        let mut formatted_excerpts = String::new();
491
492        let cursor_insertions = [(
493            predict_edits_v3::Point {
494                line: Line(clipped_cursor.row),
495                column: clipped_cursor.column,
496            },
497            CURSOR_MARKER,
498        )];
499
500        let mut cursor_excerpt_added = false;
501
502        for (buffer, ranges) in related_excerpts {
503            let excerpt_snapshot = buffer.read(cx).snapshot();
504
505            let mut line_ranges = ranges
506                .into_iter()
507                .map(|range| {
508                    let point_range = range.to_point(&excerpt_snapshot);
509                    Line(point_range.start.row)..Line(point_range.end.row)
510                })
511                .collect::<Vec<_>>();
512
513            let Some(file) = excerpt_snapshot.file() else {
514                continue;
515            };
516            let path = file.full_path(cx);
517
518            let is_cursor_file = path == cursor_snapshot.file().unwrap().full_path(cx);
519            if is_cursor_file {
520                let insertion_ix = line_ranges
521                    .binary_search_by(|probe| {
522                        probe
523                            .start
524                            .cmp(&cursor_excerpt.line_range.start)
525                            .then(cursor_excerpt.line_range.end.cmp(&probe.end))
526                    })
527                    .unwrap_or_else(|ix| ix);
528                line_ranges.insert(insertion_ix, cursor_excerpt.line_range.clone());
529                cursor_excerpt_added = true;
530            }
531
532            let merged_excerpts =
533                zeta2::merge_excerpts::merge_excerpts(&excerpt_snapshot, line_ranges)
534                    .into_iter()
535                    .map(|excerpt| OutputExcerpt {
536                        path: path.clone(),
537                        excerpt,
538                    });
539
540            let excerpt_start_ix = excerpts.len();
541            excerpts.extend(merged_excerpts);
542
543            write_codeblock(
544                &path,
545                excerpts[excerpt_start_ix..].iter().map(|e| &e.excerpt),
546                if is_cursor_file {
547                    &cursor_insertions
548                } else {
549                    &[]
550                },
551                Line(excerpt_snapshot.max_point().row),
552                true,
553                &mut formatted_excerpts,
554            );
555        }
556
557        if !cursor_excerpt_added {
558            write_codeblock(
559                &cursor_snapshot.file().unwrap().full_path(cx),
560                &[Excerpt {
561                    start_line: cursor_excerpt.line_range.start,
562                    text: cursor_excerpt.text(&cursor_snapshot).body.into(),
563                }],
564                &cursor_insertions,
565                Line(cursor_snapshot.max_point().row),
566                true,
567                &mut formatted_excerpts,
568            );
569        }
570
571        let output = Output {
572            excerpts,
573            formatted_excerpts,
574            meta,
575        };
576
577        Ok(serde_json::to_string_pretty(&output)?)
578    })
579    .unwrap()
580}
581
582async fn zeta1_context(
583    args: ContextArgs,
584    app_state: &Arc<ZetaCliAppState>,
585    cx: &mut AsyncApp,
586) -> Result<zeta::GatherContextOutput> {
587    let LoadedContext {
588        full_path_str,
589        snapshot,
590        clipped_cursor,
591        ..
592    } = load_context(&args, app_state, cx).await?;
593
594    let events = match args.edit_history {
595        Some(events) => events.read_to_string().await?,
596        None => String::new(),
597    };
598
599    let prompt_for_events = move || (events, 0);
600    cx.update(|cx| {
601        zeta::gather_context(
602            full_path_str,
603            &snapshot,
604            clipped_cursor,
605            prompt_for_events,
606            cx,
607        )
608    })?
609    .await
610}
611
612fn main() {
613    zlog::init();
614    zlog::init_output_stderr();
615    let args = ZetaCliArgs::parse();
616    let http_client = Arc::new(ReqwestClient::new());
617    let app = Application::headless().with_http_client(http_client);
618
619    app.run(move |cx| {
620        let app_state = Arc::new(headless::init(cx));
621        cx.spawn(async move |cx| {
622            match args.command {
623                Command::Zeta1 {
624                    command: Zeta1Command::Context { context_args },
625                } => {
626                    let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
627                    let result = serde_json::to_string_pretty(&context.body).unwrap();
628                    println!("{}", result);
629                }
630                Command::Zeta2 { command } => match command {
631                    Zeta2Command::Predict(arguments) => {
632                        run_zeta2_predict(arguments, &app_state, cx).await;
633                    }
634                    Zeta2Command::Eval(arguments) => {
635                        run_evaluate(arguments, &app_state, cx).await;
636                    }
637                    Zeta2Command::Syntax {
638                        args,
639                        syntax_args,
640                        command,
641                    } => {
642                        let result = match command {
643                            Zeta2SyntaxCommand::Context { context_args } => {
644                                zeta2_syntax_context(
645                                    args,
646                                    syntax_args,
647                                    context_args,
648                                    &app_state,
649                                    cx,
650                                )
651                                .await
652                            }
653                            Zeta2SyntaxCommand::Stats {
654                                worktree,
655                                extension,
656                                limit,
657                                skip,
658                            } => {
659                                retrieval_stats(
660                                    worktree,
661                                    app_state,
662                                    extension,
663                                    limit,
664                                    skip,
665                                    syntax_args_to_options(&args, &syntax_args, false),
666                                    cx,
667                                )
668                                .await
669                            }
670                        };
671                        println!("{}", result.unwrap());
672                    }
673                    Zeta2Command::Llm { args, command } => match command {
674                        Zeta2LlmCommand::Context { context_args } => {
675                            let result =
676                                zeta2_llm_context(args, context_args, &app_state, cx).await;
677                            println!("{}", result.unwrap());
678                        }
679                    },
680                },
681                Command::ConvertExample {
682                    path,
683                    output_format,
684                } => {
685                    let example = NamedExample::load(path).unwrap();
686                    example.write(output_format, io::stdout()).unwrap();
687                }
688            };
689
690            let _ = cx.update(|cx| cx.quit());
691        })
692        .detach();
693    });
694}