main.rs

  1mod headless;
  2mod source_location;
  3mod syntax_retrieval_stats;
  4mod util;
  5
  6use crate::syntax_retrieval_stats::retrieval_stats;
  7use ::serde::Serialize;
  8use ::util::paths::PathStyle;
  9use anyhow::{Context as _, Result, anyhow};
 10use clap::{Args, Parser, Subcommand};
 11use cloud_llm_client::predict_edits_v3::{self, Excerpt};
 12use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
 13use edit_prediction_context::{
 14    EditPredictionContextOptions, EditPredictionExcerpt, EditPredictionExcerptOptions,
 15    EditPredictionScoreOptions, Line,
 16};
 17use futures::StreamExt as _;
 18use futures::channel::mpsc;
 19use gpui::{Application, AsyncApp, Entity, prelude::*};
 20use language::{Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point};
 21use language_model::LanguageModelRegistry;
 22use project::{Project, Worktree};
 23use reqwest_client::ReqwestClient;
 24use serde_json::json;
 25use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc};
 26use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery};
 27
 28use crate::headless::ZetaCliAppState;
 29use crate::source_location::SourceLocation;
 30use crate::util::{open_buffer, open_buffer_with_language_server};
 31
// Root command-line arguments for the `zeta` CLI; everything it does is selected by `command`.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // The model family (`zeta1` / `zeta2`) and the action to run for it.
    #[command(subcommand)]
    command: Command,
}
 38
 39#[derive(Subcommand, Debug)]
 40enum Command {
 41    Zeta1 {
 42        #[command(subcommand)]
 43        command: Zeta1Command,
 44    },
 45    Zeta2 {
 46        #[clap(flatten)]
 47        args: Zeta2Args,
 48        #[command(subcommand)]
 49        command: Zeta2Command,
 50    },
 51}
 52
 53#[derive(Subcommand, Debug)]
 54enum Zeta1Command {
 55    Context {
 56        #[clap(flatten)]
 57        context_args: ContextArgs,
 58    },
 59}
 60
 61#[derive(Subcommand, Debug)]
 62enum Zeta2Command {
 63    Syntax {
 64        #[clap(flatten)]
 65        syntax_args: Zeta2SyntaxArgs,
 66        #[command(subcommand)]
 67        command: Zeta2SyntaxCommand,
 68    },
 69    Llm {
 70        #[command(subcommand)]
 71        command: Zeta2LlmCommand,
 72    },
 73}
 74
 75#[derive(Subcommand, Debug)]
 76enum Zeta2SyntaxCommand {
 77    Context {
 78        #[clap(flatten)]
 79        context_args: ContextArgs,
 80    },
 81    Stats {
 82        #[arg(long)]
 83        worktree: PathBuf,
 84        #[arg(long)]
 85        extension: Option<String>,
 86        #[arg(long)]
 87        limit: Option<usize>,
 88        #[arg(long)]
 89        skip: Option<usize>,
 90    },
 91}
 92
 93#[derive(Subcommand, Debug)]
 94enum Zeta2LlmCommand {
 95    Context {
 96        #[clap(flatten)]
 97        context_args: ContextArgs,
 98    },
 99}
100
// Arguments locating the cursor that a `context` command operates on.
// NOTE(review): `worktree` is already a required (non-`Option`) argument, so the group
// constraint below looks redundant — confirm before removing.
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Root directory of the worktree; `load_context` canonicalizes it before opening a project.
    #[arg(long)]
    worktree: PathBuf,
    // Cursor location: a path (joined onto the worktree root name, so presumably relative to
    // the worktree) plus a point in that file.
    #[arg(long)]
    cursor: SourceLocation,
    // Open the buffer via `open_buffer_with_language_server` instead of `open_buffer`.
    #[arg(long)]
    use_language_server: bool,
    // Optional edit history: a file path, or `-` to read it from stdin.
    #[arg(long)]
    edit_history: Option<FileOrStdin>,
}
113
114#[derive(Debug, Args)]
115struct Zeta2Args {
116    #[arg(long, default_value_t = 8192)]
117    max_prompt_bytes: usize,
118    #[arg(long, default_value_t = 2048)]
119    max_excerpt_bytes: usize,
120    #[arg(long, default_value_t = 1024)]
121    min_excerpt_bytes: usize,
122    #[arg(long, default_value_t = 0.66)]
123    target_before_cursor_over_total_bytes: f32,
124    #[arg(long, default_value_t = 1024)]
125    max_diagnostic_bytes: usize,
126    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
127    prompt_format: PromptFormat,
128    #[arg(long, value_enum, default_value_t = Default::default())]
129    output_format: OutputFormat,
130    #[arg(long, default_value_t = 42)]
131    file_indexing_parallelism: usize,
132}
133
// Options specific to syntax-based (`ContextMode::Syntax`) retrieval.
#[derive(Debug, Args)]
struct Zeta2SyntaxArgs {
    // Inverted into `EditPredictionContextOptions::use_imports`.
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
    // Forwarded to `EditPredictionContextOptions::max_retrieved_declarations`; the default is
    // the largest value a `u8` can hold.
    #[arg(long, default_value_t = u8::MAX)]
    max_retrieved_definitions: u8,
}
141
142fn syntax_args_to_options(
143    zeta2_args: &Zeta2Args,
144    syntax_args: &Zeta2SyntaxArgs,
145    omit_excerpt_overlaps: bool,
146) -> zeta2::ZetaOptions {
147    zeta2::ZetaOptions {
148        context: ContextMode::Syntax(EditPredictionContextOptions {
149            max_retrieved_declarations: syntax_args.max_retrieved_definitions,
150            use_imports: !syntax_args.disable_imports_gathering,
151            excerpt: EditPredictionExcerptOptions {
152                max_bytes: zeta2_args.max_excerpt_bytes,
153                min_bytes: zeta2_args.min_excerpt_bytes,
154                target_before_cursor_over_total_bytes: zeta2_args
155                    .target_before_cursor_over_total_bytes,
156            },
157            score: EditPredictionScoreOptions {
158                omit_excerpt_overlaps,
159            },
160        }),
161        max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
162        max_prompt_bytes: zeta2_args.max_prompt_bytes,
163        prompt_format: zeta2_args.prompt_format.clone().into(),
164        file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
165    }
166}
167
// CLI-facing prompt formats, converted to `predict_edits_v3::PromptFormat` for the request.
// `NumberedLines` (the default) maps to `NumLinesUniDiff`; every other variant maps to the
// upstream variant of the same name.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum PromptFormat {
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
    #[default]
    NumberedLines,
}
176
177impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
178    fn into(self) -> predict_edits_v3::PromptFormat {
179        match self {
180            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
181            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
182            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
183            Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
184        }
185    }
186}
187
// How `zeta zeta2 syntax context` renders its result:
// - `Prompt` (default): the prompt string built from the request.
// - `Request`: the request itself, pretty-printed as JSON.
// - `Full`: a JSON object with `request`, `prompt` and `section_labels`.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    #[default]
    Prompt,
    Request,
    Full,
}
195
/// An input source given on the command line: `-` selects standard input, anything else is
/// treated as a file path (see the `FromStr` impl).
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read from the file at this path.
    File(PathBuf),
    /// Read from the process's standard input.
    Stdin,
}
201
202impl FileOrStdin {
203    async fn read_to_string(&self) -> Result<String, std::io::Error> {
204        match self {
205            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
206            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
207        }
208    }
209}
210
211impl FromStr for FileOrStdin {
212    type Err = <PathBuf as FromStr>::Err;
213
214    fn from_str(s: &str) -> Result<Self, Self::Err> {
215        match s {
216            "-" => Ok(Self::Stdin),
217            _ => Ok(Self::File(PathBuf::from_str(s)?)),
218        }
219    }
220}
221
/// Everything `load_context` produces for a cursor location.
struct LoadedContext {
    /// The worktree root name joined with the cursor's path, rendered with the local path style.
    full_path_str: String,
    /// Snapshot of `buffer`, taken right after it was opened.
    snapshot: BufferSnapshot,
    /// The requested cursor clipped to `snapshot` with `Bias::Left`. `load_context` returns an
    /// error if clipping moved the point, so this equals the requested cursor.
    clipped_cursor: Point,
    /// The worktree created for the requested root directory.
    worktree: Entity<Worktree>,
    /// The local project (`Project::local`) that owns `worktree`.
    project: Entity<Project>,
    /// The buffer containing the cursor.
    buffer: Entity<Buffer>,
}
230
231async fn load_context(
232    args: &ContextArgs,
233    app_state: &Arc<ZetaCliAppState>,
234    cx: &mut AsyncApp,
235) -> Result<LoadedContext> {
236    let ContextArgs {
237        worktree: worktree_path,
238        cursor,
239        use_language_server,
240        ..
241    } = args;
242
243    let worktree_path = worktree_path.canonicalize()?;
244
245    let project = cx.update(|cx| {
246        Project::local(
247            app_state.client.clone(),
248            app_state.node_runtime.clone(),
249            app_state.user_store.clone(),
250            app_state.languages.clone(),
251            app_state.fs.clone(),
252            None,
253            cx,
254        )
255    })?;
256
257    let worktree = project
258        .update(cx, |project, cx| {
259            project.create_worktree(&worktree_path, true, cx)
260        })?
261        .await?;
262
263    let mut ready_languages = HashSet::default();
264    let (_lsp_open_handle, buffer) = if *use_language_server {
265        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
266            project.clone(),
267            worktree.clone(),
268            cursor.path.clone(),
269            &mut ready_languages,
270            cx,
271        )
272        .await?;
273        (Some(lsp_open_handle), buffer)
274    } else {
275        let buffer =
276            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
277        (None, buffer)
278    };
279
280    let full_path_str = worktree
281        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
282        .display(PathStyle::local())
283        .to_string();
284
285    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
286    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
287    if clipped_cursor != cursor.point {
288        let max_row = snapshot.max_point().row;
289        if cursor.point.row < max_row {
290            return Err(anyhow!(
291                "Cursor position {:?} is out of bounds (line length is {})",
292                cursor.point,
293                snapshot.line_len(cursor.point.row)
294            ));
295        } else {
296            return Err(anyhow!(
297                "Cursor position {:?} is out of bounds (max row is {})",
298                cursor.point,
299                max_row
300            ));
301        }
302    }
303
304    Ok(LoadedContext {
305        full_path_str,
306        snapshot,
307        clipped_cursor,
308        worktree,
309        project,
310        buffer,
311    })
312}
313
314async fn zeta2_syntax_context(
315    zeta2_args: Zeta2Args,
316    syntax_args: Zeta2SyntaxArgs,
317    args: ContextArgs,
318    app_state: &Arc<ZetaCliAppState>,
319    cx: &mut AsyncApp,
320) -> Result<String> {
321    let LoadedContext {
322        worktree,
323        project,
324        buffer,
325        clipped_cursor,
326        ..
327    } = load_context(&args, app_state, cx).await?;
328
329    // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
330    // the whole worktree.
331    worktree
332        .read_with(cx, |worktree, _cx| {
333            worktree.as_local().unwrap().scan_complete()
334        })?
335        .await;
336    let output = cx
337        .update(|cx| {
338            let zeta = cx.new(|cx| {
339                zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
340            });
341            let indexing_done_task = zeta.update(cx, |zeta, cx| {
342                zeta.set_options(syntax_args_to_options(&zeta2_args, &syntax_args, true));
343                zeta.register_buffer(&buffer, &project, cx);
344                zeta.wait_for_initial_indexing(&project, cx)
345            });
346            cx.spawn(async move |cx| {
347                indexing_done_task.await?;
348                let request = zeta
349                    .update(cx, |zeta, cx| {
350                        let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
351                        zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
352                    })?
353                    .await?;
354
355                let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;
356
357                match zeta2_args.output_format {
358                    OutputFormat::Prompt => anyhow::Ok(prompt_string),
359                    OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
360                    OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
361                        "request": request,
362                        "prompt": prompt_string,
363                        "section_labels": section_labels,
364                    }))?),
365                }
366            })
367        })?
368        .await?;
369
370    Ok(output)
371}
372
373async fn zeta2_llm_context(
374    zeta2_args: Zeta2Args,
375    context_args: ContextArgs,
376    app_state: &Arc<ZetaCliAppState>,
377    cx: &mut AsyncApp,
378) -> Result<String> {
379    let LoadedContext {
380        buffer,
381        clipped_cursor,
382        snapshot: cursor_snapshot,
383        project,
384        ..
385    } = load_context(&context_args, app_state, cx).await?;
386
387    let cursor_position = cursor_snapshot.anchor_after(clipped_cursor);
388
389    cx.update(|cx| {
390        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
391            registry
392                .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
393                .unwrap()
394                .authenticate(cx)
395        })
396    })?
397    .await?;
398
399    let edit_history_unified_diff = match context_args.edit_history {
400        Some(events) => events.read_to_string().await?,
401        None => String::new(),
402    };
403
404    let (debug_tx, mut debug_rx) = mpsc::unbounded();
405
406    let excerpt_options = EditPredictionExcerptOptions {
407        max_bytes: zeta2_args.max_excerpt_bytes,
408        min_bytes: zeta2_args.min_excerpt_bytes,
409        target_before_cursor_over_total_bytes: zeta2_args.target_before_cursor_over_total_bytes,
410    };
411
412    let related_excerpts = cx
413        .update(|cx| {
414            zeta2::related_excerpts::find_related_excerpts(
415                buffer,
416                cursor_position,
417                &project,
418                edit_history_unified_diff,
419                &LlmContextOptions {
420                    excerpt: excerpt_options.clone(),
421                },
422                Some(debug_tx),
423                cx,
424            )
425        })?
426        .await?;
427
428    let cursor_excerpt = EditPredictionExcerpt::select_from_buffer(
429        clipped_cursor,
430        &cursor_snapshot,
431        &excerpt_options,
432        None,
433    )
434    .context("line didn't fit")?;
435
436    #[derive(Serialize)]
437    struct Output {
438        excerpts: Vec<OutputExcerpt>,
439        formatted_excerpts: String,
440        meta: OutputMeta,
441    }
442
443    #[derive(Default, Serialize)]
444    struct OutputMeta {
445        search_prompt: String,
446        search_queries: Vec<SearchToolQuery>,
447    }
448
449    #[derive(Serialize)]
450    struct OutputExcerpt {
451        path: PathBuf,
452        #[serde(flatten)]
453        excerpt: Excerpt,
454    }
455
456    let mut meta = OutputMeta::default();
457
458    while let Some(debug_info) = debug_rx.next().await {
459        match debug_info {
460            zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
461                meta.search_prompt = info.search_prompt;
462            }
463            zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
464                meta.search_queries = info.queries
465            }
466            _ => {}
467        }
468    }
469
470    cx.update(|cx| {
471        let mut excerpts = Vec::new();
472        let mut formatted_excerpts = String::new();
473
474        let cursor_insertions = [(
475            predict_edits_v3::Point {
476                line: Line(clipped_cursor.row),
477                column: clipped_cursor.column,
478            },
479            CURSOR_MARKER,
480        )];
481
482        let mut cursor_excerpt_added = false;
483
484        for (buffer, ranges) in related_excerpts {
485            let excerpt_snapshot = buffer.read(cx).snapshot();
486
487            let mut line_ranges = ranges
488                .into_iter()
489                .map(|range| {
490                    let point_range = range.to_point(&excerpt_snapshot);
491                    Line(point_range.start.row)..Line(point_range.end.row)
492                })
493                .collect::<Vec<_>>();
494
495            let Some(file) = excerpt_snapshot.file() else {
496                continue;
497            };
498            let path = file.full_path(cx);
499
500            let is_cursor_file = path == cursor_snapshot.file().unwrap().full_path(cx);
501            if is_cursor_file {
502                let insertion_ix = line_ranges
503                    .binary_search_by(|probe| {
504                        probe
505                            .start
506                            .cmp(&cursor_excerpt.line_range.start)
507                            .then(cursor_excerpt.line_range.end.cmp(&probe.end))
508                    })
509                    .unwrap_or_else(|ix| ix);
510                line_ranges.insert(insertion_ix, cursor_excerpt.line_range.clone());
511                cursor_excerpt_added = true;
512            }
513
514            let merged_excerpts =
515                zeta2::merge_excerpts::merge_excerpts(&excerpt_snapshot, line_ranges)
516                    .into_iter()
517                    .map(|excerpt| OutputExcerpt {
518                        path: path.clone(),
519                        excerpt,
520                    });
521
522            let excerpt_start_ix = excerpts.len();
523            excerpts.extend(merged_excerpts);
524
525            write_codeblock(
526                &path,
527                excerpts[excerpt_start_ix..].iter().map(|e| &e.excerpt),
528                if is_cursor_file {
529                    &cursor_insertions
530                } else {
531                    &[]
532                },
533                Line(excerpt_snapshot.max_point().row),
534                true,
535                &mut formatted_excerpts,
536            );
537        }
538
539        if !cursor_excerpt_added {
540            write_codeblock(
541                &cursor_snapshot.file().unwrap().full_path(cx),
542                &[Excerpt {
543                    start_line: cursor_excerpt.line_range.start,
544                    text: cursor_excerpt.text(&cursor_snapshot).body.into(),
545                }],
546                &cursor_insertions,
547                Line(cursor_snapshot.max_point().row),
548                true,
549                &mut formatted_excerpts,
550            );
551        }
552
553        let output = Output {
554            excerpts,
555            formatted_excerpts,
556            meta,
557        };
558
559        Ok(serde_json::to_string_pretty(&output)?)
560    })
561    .unwrap()
562}
563
564async fn zeta1_context(
565    args: ContextArgs,
566    app_state: &Arc<ZetaCliAppState>,
567    cx: &mut AsyncApp,
568) -> Result<zeta::GatherContextOutput> {
569    let LoadedContext {
570        full_path_str,
571        snapshot,
572        clipped_cursor,
573        ..
574    } = load_context(&args, app_state, cx).await?;
575
576    let events = match args.edit_history {
577        Some(events) => events.read_to_string().await?,
578        None => String::new(),
579    };
580
581    let prompt_for_events = move || (events, 0);
582    cx.update(|cx| {
583        zeta::gather_context(
584            full_path_str,
585            &snapshot,
586            clipped_cursor,
587            prompt_for_events,
588            cx,
589        )
590    })?
591    .await
592}
593
594fn main() {
595    zlog::init();
596    zlog::init_output_stderr();
597    let args = ZetaCliArgs::parse();
598    let http_client = Arc::new(ReqwestClient::new());
599    let app = Application::headless().with_http_client(http_client);
600
601    app.run(move |cx| {
602        let app_state = Arc::new(headless::init(cx));
603        cx.spawn(async move |cx| {
604            let result = match args.command {
605                Command::Zeta1 {
606                    command: Zeta1Command::Context { context_args },
607                } => {
608                    let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
609                    serde_json::to_string_pretty(&context.body).map_err(|err| anyhow::anyhow!(err))
610                }
611                Command::Zeta2 { args, command } => match command {
612                    Zeta2Command::Syntax {
613                        syntax_args,
614                        command,
615                    } => match command {
616                        Zeta2SyntaxCommand::Context { context_args } => {
617                            zeta2_syntax_context(args, syntax_args, context_args, &app_state, cx)
618                                .await
619                        }
620                        Zeta2SyntaxCommand::Stats {
621                            worktree,
622                            extension,
623                            limit,
624                            skip,
625                        } => {
626                            retrieval_stats(
627                                worktree,
628                                app_state,
629                                extension,
630                                limit,
631                                skip,
632                                syntax_args_to_options(&args, &syntax_args, false),
633                                cx,
634                            )
635                            .await
636                        }
637                    },
638                    Zeta2Command::Llm { command } => match command {
639                        Zeta2LlmCommand::Context { context_args } => {
640                            zeta2_llm_context(args, context_args, &app_state, cx).await
641                        }
642                    },
643                },
644            };
645
646            match result {
647                Ok(output) => {
648                    println!("{}", output);
649                    let _ = cx.update(|cx| cx.quit());
650                }
651                Err(e) => {
652                    eprintln!("Failed: {:?}", e);
653                    exit(1);
654                }
655            }
656        })
657        .detach();
658    });
659}