1mod headless;
2mod source_location;
3mod syntax_retrieval_stats;
4mod util;
5
6use crate::syntax_retrieval_stats::retrieval_stats;
7use ::serde::Serialize;
8use ::util::paths::PathStyle;
9use anyhow::{Context as _, Result, anyhow};
10use clap::{Args, Parser, Subcommand};
11use cloud_llm_client::predict_edits_v3::{self, Excerpt};
12use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
13use edit_prediction_context::{
14 EditPredictionContextOptions, EditPredictionExcerpt, EditPredictionExcerptOptions,
15 EditPredictionScoreOptions, Line,
16};
17use futures::StreamExt as _;
18use futures::channel::mpsc;
19use gpui::{Application, AsyncApp, Entity, prelude::*};
20use language::{Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point};
21use language_model::LanguageModelRegistry;
22use project::{Project, Worktree};
23use reqwest_client::ReqwestClient;
24use serde_json::json;
25use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc};
26use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery};
27
28use crate::headless::ZetaCliAppState;
29use crate::source_location::SourceLocation;
30use crate::util::{open_buffer, open_buffer_with_language_server};
31
// Top-level command line of the `zeta` CLI. All work is selected by the subcommand.
// (Plain `//` comments are used on clap-derived types: `///` doc comments would become
// `--help` text.)
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // The command to run; dispatched in `main`.
    #[command(subcommand)]
    command: Command,
}
38
// Top-level subcommands, one per prediction model generation.
#[derive(Subcommand, Debug)]
enum Command {
    // `zeta1 <subcommand>`.
    Zeta1 {
        #[command(subcommand)]
        command: Zeta1Command,
    },
    // `zeta2 [options] <subcommand>`. The flattened `Zeta2Args` are handed to every zeta2
    // handler in `main`.
    Zeta2 {
        #[clap(flatten)]
        args: Zeta2Args,
        #[command(subcommand)]
        command: Zeta2Command,
    },
}
52
// Subcommands of `zeta1`.
#[derive(Subcommand, Debug)]
enum Zeta1Command {
    // Gather zeta1 context around a cursor position (handled by `zeta1_context`).
    Context {
        #[clap(flatten)]
        context_args: ContextArgs,
    },
}
60
// Subcommands of `zeta2`, one per context-retrieval strategy.
#[derive(Subcommand, Debug)]
enum Zeta2Command {
    // Syntax-based retrieval. `Zeta2SyntaxArgs` are turned into `ZetaOptions` by
    // `syntax_args_to_options`.
    Syntax {
        #[clap(flatten)]
        syntax_args: Zeta2SyntaxArgs,
        #[command(subcommand)]
        command: Zeta2SyntaxCommand,
    },
    // LLM-driven related-excerpt search.
    Llm {
        #[command(subcommand)]
        command: Zeta2LlmCommand,
    },
}
74
// Subcommands of `zeta2 syntax`.
#[derive(Subcommand, Debug)]
enum Zeta2SyntaxCommand {
    // Build and print a zeta2 request/prompt for one cursor position
    // (handled by `zeta2_syntax_context`).
    Context {
        #[clap(flatten)]
        context_args: ContextArgs,
    },
    // Compute retrieval statistics over a whole worktree (handled by `retrieval_stats`).
    // NOTE(review): `extension`, `limit` and `skip` are forwarded unchanged. Presumably they
    // filter by file extension and page through the files considered; confirm in
    // `syntax_retrieval_stats`.
    Stats {
        #[arg(long)]
        worktree: PathBuf,
        #[arg(long)]
        extension: Option<String>,
        #[arg(long)]
        limit: Option<usize>,
        #[arg(long)]
        skip: Option<usize>,
    },
}
92
// Subcommands of `zeta2 llm`.
#[derive(Subcommand, Debug)]
enum Zeta2LlmCommand {
    // Run the LLM-driven related-excerpt search for one cursor position and print the
    // result as JSON (handled by `zeta2_llm_context`).
    Context {
        #[clap(flatten)]
        context_args: ContextArgs,
    },
}
100
// Arguments shared by every `context` subcommand: which worktree to open and where the
// cursor is.
// NOTE(review): `#[group(requires = "worktree")]` looks redundant. `worktree` is a
// non-optional `--worktree` argument and is therefore already required; confirm before
// removing.
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Root of the worktree; `load_context` canonicalizes it before opening a project on it.
    #[arg(long)]
    worktree: PathBuf,
    // Cursor location. `load_context` opens `cursor.path` inside the worktree and rejects
    // a `cursor.point` that falls outside that buffer.
    #[arg(long)]
    cursor: SourceLocation,
    // Open the buffer via `open_buffer_with_language_server` instead of plain `open_buffer`.
    #[arg(long)]
    use_language_server: bool,
    // Optional edit history: a file path, or `-` for stdin.
    // zeta1 passes its text as the events string.
    // `zeta2 llm context` passes it as a unified diff.
    // `zeta2 syntax context` does not read it.
    #[arg(long)]
    edit_history: Option<FileOrStdin>,
}
113
// Options accepted by every `zeta2` subcommand.
// `zeta2 llm context` reads only the three excerpt-size options. The remaining ones feed
// `syntax_args_to_options`, and `output_format` is read only by `zeta2 syntax context`.
#[derive(Debug, Args)]
struct Zeta2Args {
    // Forwarded as `ZetaOptions::max_prompt_bytes`.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // Forwarded as `EditPredictionExcerptOptions::max_bytes`.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // Forwarded as `EditPredictionExcerptOptions::min_bytes`.
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // Forwarded as `EditPredictionExcerptOptions::target_before_cursor_over_total_bytes`.
    // Presumably this is the target share of excerpt bytes placed before the cursor;
    // confirm in `edit_prediction_context`.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // Forwarded as `ZetaOptions::max_diagnostic_bytes`.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // Prompt layout; converted into `predict_edits_v3::PromptFormat`.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // What `zeta2 syntax context` prints: the prompt, the request JSON, or both.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // Forwarded as `ZetaOptions::file_indexing_parallelism`.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
}
133
// Options specific to `zeta2 syntax` subcommands.
#[derive(Debug, Args)]
struct Zeta2SyntaxArgs {
    // When set, `EditPredictionContextOptions::use_imports` is false.
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
    // Forwarded as `EditPredictionContextOptions::max_retrieved_declarations`.
    // The default is the largest `u8`.
    #[arg(long, default_value_t = u8::MAX)]
    max_retrieved_definitions: u8,
}
141
142fn syntax_args_to_options(
143 zeta2_args: &Zeta2Args,
144 syntax_args: &Zeta2SyntaxArgs,
145 omit_excerpt_overlaps: bool,
146) -> zeta2::ZetaOptions {
147 zeta2::ZetaOptions {
148 context: ContextMode::Syntax(EditPredictionContextOptions {
149 max_retrieved_declarations: syntax_args.max_retrieved_definitions,
150 use_imports: !syntax_args.disable_imports_gathering,
151 excerpt: EditPredictionExcerptOptions {
152 max_bytes: zeta2_args.max_excerpt_bytes,
153 min_bytes: zeta2_args.min_excerpt_bytes,
154 target_before_cursor_over_total_bytes: zeta2_args
155 .target_before_cursor_over_total_bytes,
156 },
157 score: EditPredictionScoreOptions {
158 omit_excerpt_overlaps,
159 },
160 }),
161 max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
162 max_prompt_bytes: zeta2_args.max_prompt_bytes,
163 prompt_format: zeta2_args.prompt_format.clone().into(),
164 file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
165 }
166}
167
// Prompt layouts selectable with `--prompt-format`. clap exposes the variants in
// kebab-case, e.g. `numbered-lines`. Each one is converted into
// `predict_edits_v3::PromptFormat`; `NumberedLines` corresponds to `NumLinesUniDiff`.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum PromptFormat {
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
    #[default]
    NumberedLines,
}
176
177impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
178 fn into(self) -> predict_edits_v3::PromptFormat {
179 match self {
180 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
181 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
182 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
183 Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
184 }
185 }
186}
187
// What `zeta2 syntax context` prints:
// - `Prompt`: the rendered prompt string (the default).
// - `Request`: the cloud request as pretty-printed JSON.
// - `Full`: a JSON object with `request`, `prompt` and `section_labels`.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    #[default]
    Prompt,
    Request,
    Full,
}
195
/// An input source given on the command line: a file path, or `-` for standard input.
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read from this file.
    File(PathBuf),
    /// Read from the process's standard input.
    Stdin,
}
201
202impl FileOrStdin {
203 async fn read_to_string(&self) -> Result<String, std::io::Error> {
204 match self {
205 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
206 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
207 }
208 }
209}
210
211impl FromStr for FileOrStdin {
212 type Err = <PathBuf as FromStr>::Err;
213
214 fn from_str(s: &str) -> Result<Self, Self::Err> {
215 match s {
216 "-" => Ok(Self::Stdin),
217 _ => Ok(Self::File(PathBuf::from_str(s)?)),
218 }
219 }
220}
221
/// Everything `load_context` opens for a single cursor position.
struct LoadedContext {
    /// Worktree root name joined with the cursor's path, rendered in the local path style.
    full_path_str: String,
    /// Snapshot of the cursor buffer taken right after it was opened.
    snapshot: BufferSnapshot,
    /// The requested cursor point. `load_context` fails instead of returning when clipping
    /// would have moved it, so this equals the requested point.
    clipped_cursor: Point,
    /// The worktree created for `ContextArgs::worktree`.
    worktree: Entity<Worktree>,
    /// The local project that owns `worktree`.
    project: Entity<Project>,
    /// The buffer containing the cursor.
    buffer: Entity<Buffer>,
}
230
231async fn load_context(
232 args: &ContextArgs,
233 app_state: &Arc<ZetaCliAppState>,
234 cx: &mut AsyncApp,
235) -> Result<LoadedContext> {
236 let ContextArgs {
237 worktree: worktree_path,
238 cursor,
239 use_language_server,
240 ..
241 } = args;
242
243 let worktree_path = worktree_path.canonicalize()?;
244
245 let project = cx.update(|cx| {
246 Project::local(
247 app_state.client.clone(),
248 app_state.node_runtime.clone(),
249 app_state.user_store.clone(),
250 app_state.languages.clone(),
251 app_state.fs.clone(),
252 None,
253 cx,
254 )
255 })?;
256
257 let worktree = project
258 .update(cx, |project, cx| {
259 project.create_worktree(&worktree_path, true, cx)
260 })?
261 .await?;
262
263 let mut ready_languages = HashSet::default();
264 let (_lsp_open_handle, buffer) = if *use_language_server {
265 let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
266 project.clone(),
267 worktree.clone(),
268 cursor.path.clone(),
269 &mut ready_languages,
270 cx,
271 )
272 .await?;
273 (Some(lsp_open_handle), buffer)
274 } else {
275 let buffer =
276 open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
277 (None, buffer)
278 };
279
280 let full_path_str = worktree
281 .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
282 .display(PathStyle::local())
283 .to_string();
284
285 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
286 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
287 if clipped_cursor != cursor.point {
288 let max_row = snapshot.max_point().row;
289 if cursor.point.row < max_row {
290 return Err(anyhow!(
291 "Cursor position {:?} is out of bounds (line length is {})",
292 cursor.point,
293 snapshot.line_len(cursor.point.row)
294 ));
295 } else {
296 return Err(anyhow!(
297 "Cursor position {:?} is out of bounds (max row is {})",
298 cursor.point,
299 max_row
300 ));
301 }
302 }
303
304 Ok(LoadedContext {
305 full_path_str,
306 snapshot,
307 clipped_cursor,
308 worktree,
309 project,
310 buffer,
311 })
312}
313
/// Runs `zeta2 syntax context`: builds a zeta2 cloud request using syntax-based context
/// retrieval and renders it according to `zeta2_args.output_format`.
///
/// # Errors
/// Returns an error if loading the buffer, waiting for indexing, building the request or
/// prompt, or JSON serialization fails.
///
/// # Panics
/// Panics if the worktree is not local (`as_local().unwrap()`). `load_context` creates it
/// inside a `Project::local`, so this is not expected to happen.
async fn zeta2_syntax_context(
    zeta2_args: Zeta2Args,
    syntax_args: Zeta2SyntaxArgs,
    args: ContextArgs,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<String> {
    let LoadedContext {
        worktree,
        project,
        buffer,
        clipped_cursor,
        ..
    } = load_context(&args, app_state, cx).await?;

    // Wait for the worktree scan to finish before starting zeta2, so that
    // `wait_for_initial_indexing` covers the whole worktree.
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;
    let output = cx
        .update(|cx| {
            let zeta = cx.new(|cx| {
                zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
            });
            // Configure syntax retrieval with `omit_excerpt_overlaps = true` and register the
            // cursor buffer. Keep the task that resolves once initial indexing completes.
            let indexing_done_task = zeta.update(cx, |zeta, cx| {
                zeta.set_options(syntax_args_to_options(&zeta2_args, &syntax_args, true));
                zeta.register_buffer(&buffer, &project, cx);
                zeta.wait_for_initial_indexing(&project, cx)
            });
            cx.spawn(async move |cx| {
                indexing_done_task.await?;
                // After indexing, anchor the cursor in a fresh snapshot of the buffer.
                let request = zeta
                    .update(cx, |zeta, cx| {
                        let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
                        zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
                    })?
                    .await?;

                let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;

                match zeta2_args.output_format {
                    OutputFormat::Prompt => anyhow::Ok(prompt_string),
                    OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
                    OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
                        "request": request,
                        "prompt": prompt_string,
                        "section_labels": section_labels,
                    }))?),
                }
            })
        })?
        .await?;

    Ok(output)
}
372
373async fn zeta2_llm_context(
374 zeta2_args: Zeta2Args,
375 context_args: ContextArgs,
376 app_state: &Arc<ZetaCliAppState>,
377 cx: &mut AsyncApp,
378) -> Result<String> {
379 let LoadedContext {
380 buffer,
381 clipped_cursor,
382 snapshot: cursor_snapshot,
383 project,
384 ..
385 } = load_context(&context_args, app_state, cx).await?;
386
387 let cursor_position = cursor_snapshot.anchor_after(clipped_cursor);
388
389 cx.update(|cx| {
390 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
391 registry
392 .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
393 .unwrap()
394 .authenticate(cx)
395 })
396 })?
397 .await?;
398
399 let edit_history_unified_diff = match context_args.edit_history {
400 Some(events) => events.read_to_string().await?,
401 None => String::new(),
402 };
403
404 let (debug_tx, mut debug_rx) = mpsc::unbounded();
405
406 let excerpt_options = EditPredictionExcerptOptions {
407 max_bytes: zeta2_args.max_excerpt_bytes,
408 min_bytes: zeta2_args.min_excerpt_bytes,
409 target_before_cursor_over_total_bytes: zeta2_args.target_before_cursor_over_total_bytes,
410 };
411
412 let related_excerpts = cx
413 .update(|cx| {
414 zeta2::related_excerpts::find_related_excerpts(
415 buffer,
416 cursor_position,
417 &project,
418 edit_history_unified_diff,
419 &LlmContextOptions {
420 excerpt: excerpt_options.clone(),
421 },
422 Some(debug_tx),
423 cx,
424 )
425 })?
426 .await?;
427
428 let cursor_excerpt = EditPredictionExcerpt::select_from_buffer(
429 clipped_cursor,
430 &cursor_snapshot,
431 &excerpt_options,
432 None,
433 )
434 .context("line didn't fit")?;
435
436 #[derive(Serialize)]
437 struct Output {
438 excerpts: Vec<OutputExcerpt>,
439 formatted_excerpts: String,
440 meta: OutputMeta,
441 }
442
443 #[derive(Default, Serialize)]
444 struct OutputMeta {
445 search_prompt: String,
446 search_queries: Vec<SearchToolQuery>,
447 }
448
449 #[derive(Serialize)]
450 struct OutputExcerpt {
451 path: PathBuf,
452 #[serde(flatten)]
453 excerpt: Excerpt,
454 }
455
456 let mut meta = OutputMeta::default();
457
458 while let Some(debug_info) = debug_rx.next().await {
459 match debug_info {
460 zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
461 meta.search_prompt = info.search_prompt;
462 }
463 zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
464 meta.search_queries = info.queries
465 }
466 _ => {}
467 }
468 }
469
470 cx.update(|cx| {
471 let mut excerpts = Vec::new();
472 let mut formatted_excerpts = String::new();
473
474 let cursor_insertions = [(
475 predict_edits_v3::Point {
476 line: Line(clipped_cursor.row),
477 column: clipped_cursor.column,
478 },
479 CURSOR_MARKER,
480 )];
481
482 let mut cursor_excerpt_added = false;
483
484 for (buffer, ranges) in related_excerpts {
485 let excerpt_snapshot = buffer.read(cx).snapshot();
486
487 let mut line_ranges = ranges
488 .into_iter()
489 .map(|range| {
490 let point_range = range.to_point(&excerpt_snapshot);
491 Line(point_range.start.row)..Line(point_range.end.row)
492 })
493 .collect::<Vec<_>>();
494
495 let Some(file) = excerpt_snapshot.file() else {
496 continue;
497 };
498 let path = file.full_path(cx);
499
500 let is_cursor_file = path == cursor_snapshot.file().unwrap().full_path(cx);
501 if is_cursor_file {
502 let insertion_ix = line_ranges
503 .binary_search_by(|probe| {
504 probe
505 .start
506 .cmp(&cursor_excerpt.line_range.start)
507 .then(cursor_excerpt.line_range.end.cmp(&probe.end))
508 })
509 .unwrap_or_else(|ix| ix);
510 line_ranges.insert(insertion_ix, cursor_excerpt.line_range.clone());
511 cursor_excerpt_added = true;
512 }
513
514 let merged_excerpts =
515 zeta2::merge_excerpts::merge_excerpts(&excerpt_snapshot, line_ranges)
516 .into_iter()
517 .map(|excerpt| OutputExcerpt {
518 path: path.clone(),
519 excerpt,
520 });
521
522 let excerpt_start_ix = excerpts.len();
523 excerpts.extend(merged_excerpts);
524
525 write_codeblock(
526 &path,
527 excerpts[excerpt_start_ix..].iter().map(|e| &e.excerpt),
528 if is_cursor_file {
529 &cursor_insertions
530 } else {
531 &[]
532 },
533 Line(excerpt_snapshot.max_point().row),
534 true,
535 &mut formatted_excerpts,
536 );
537 }
538
539 if !cursor_excerpt_added {
540 write_codeblock(
541 &cursor_snapshot.file().unwrap().full_path(cx),
542 &[Excerpt {
543 start_line: cursor_excerpt.line_range.start,
544 text: cursor_excerpt.text(&cursor_snapshot).body.into(),
545 }],
546 &cursor_insertions,
547 Line(cursor_snapshot.max_point().row),
548 true,
549 &mut formatted_excerpts,
550 );
551 }
552
553 let output = Output {
554 excerpts,
555 formatted_excerpts,
556 meta,
557 };
558
559 Ok(serde_json::to_string_pretty(&output)?)
560 })
561 .unwrap()
562}
563
564async fn zeta1_context(
565 args: ContextArgs,
566 app_state: &Arc<ZetaCliAppState>,
567 cx: &mut AsyncApp,
568) -> Result<zeta::GatherContextOutput> {
569 let LoadedContext {
570 full_path_str,
571 snapshot,
572 clipped_cursor,
573 ..
574 } = load_context(&args, app_state, cx).await?;
575
576 let events = match args.edit_history {
577 Some(events) => events.read_to_string().await?,
578 None => String::new(),
579 };
580
581 let prompt_for_events = move || (events, 0);
582 cx.update(|cx| {
583 zeta::gather_context(
584 full_path_str,
585 &snapshot,
586 clipped_cursor,
587 prompt_for_events,
588 cx,
589 )
590 })?
591 .await
592}
593
594fn main() {
595 zlog::init();
596 zlog::init_output_stderr();
597 let args = ZetaCliArgs::parse();
598 let http_client = Arc::new(ReqwestClient::new());
599 let app = Application::headless().with_http_client(http_client);
600
601 app.run(move |cx| {
602 let app_state = Arc::new(headless::init(cx));
603 cx.spawn(async move |cx| {
604 let result = match args.command {
605 Command::Zeta1 {
606 command: Zeta1Command::Context { context_args },
607 } => {
608 let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
609 serde_json::to_string_pretty(&context.body).map_err(|err| anyhow::anyhow!(err))
610 }
611 Command::Zeta2 { args, command } => match command {
612 Zeta2Command::Syntax {
613 syntax_args,
614 command,
615 } => match command {
616 Zeta2SyntaxCommand::Context { context_args } => {
617 zeta2_syntax_context(args, syntax_args, context_args, &app_state, cx)
618 .await
619 }
620 Zeta2SyntaxCommand::Stats {
621 worktree,
622 extension,
623 limit,
624 skip,
625 } => {
626 retrieval_stats(
627 worktree,
628 app_state,
629 extension,
630 limit,
631 skip,
632 syntax_args_to_options(&args, &syntax_args, false),
633 cx,
634 )
635 .await
636 }
637 },
638 Zeta2Command::Llm { command } => match command {
639 Zeta2LlmCommand::Context { context_args } => {
640 zeta2_llm_context(args, context_args, &app_state, cx).await
641 }
642 },
643 },
644 };
645
646 match result {
647 Ok(output) => {
648 println!("{}", output);
649 let _ = cx.update(|cx| cx.quit());
650 }
651 Err(e) => {
652 eprintln!("Failed: {:?}", e);
653 exit(1);
654 }
655 }
656 })
657 .detach();
658 });
659}