1mod example;
2mod headless;
3mod source_location;
4mod syntax_retrieval_stats;
5mod util;
6
7use crate::example::{ExampleFormat, NamedExample};
8use crate::syntax_retrieval_stats::retrieval_stats;
9use ::serde::Serialize;
10use ::util::paths::PathStyle;
11use anyhow::{Context as _, Result, anyhow};
12use clap::{Args, Parser, Subcommand};
13use cloud_llm_client::predict_edits_v3::{self, Excerpt};
14use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
15use edit_prediction_context::{
16 EditPredictionContextOptions, EditPredictionExcerpt, EditPredictionExcerptOptions,
17 EditPredictionScoreOptions, Line,
18};
19use futures::StreamExt as _;
20use futures::channel::mpsc;
21use gpui::{Application, AsyncApp, Entity, prelude::*};
22use language::{Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point};
23use language_model::LanguageModelRegistry;
24use project::{Project, Worktree};
25use reqwest_client::ReqwestClient;
26use serde_json::json;
27use std::io;
28use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc};
29use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery};
30
31use crate::headless::ZetaCliAppState;
32use crate::source_location::SourceLocation;
33use crate::util::{open_buffer, open_buffer_with_language_server};
34
35#[derive(Parser, Debug)]
36#[command(name = "zeta")]
37struct ZetaCliArgs {
38 #[command(subcommand)]
39 command: Command,
40}
41
42#[derive(Subcommand, Debug)]
43enum Command {
44 Zeta1 {
45 #[command(subcommand)]
46 command: Zeta1Command,
47 },
48 Zeta2 {
49 #[clap(flatten)]
50 args: Zeta2Args,
51 #[command(subcommand)]
52 command: Zeta2Command,
53 },
54 ConvertExample {
55 path: PathBuf,
56 #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
57 output_format: ExampleFormat,
58 },
59}
60
61#[derive(Subcommand, Debug)]
62enum Zeta1Command {
63 Context {
64 #[clap(flatten)]
65 context_args: ContextArgs,
66 },
67}
68
69#[derive(Subcommand, Debug)]
70enum Zeta2Command {
71 Syntax {
72 #[clap(flatten)]
73 syntax_args: Zeta2SyntaxArgs,
74 #[command(subcommand)]
75 command: Zeta2SyntaxCommand,
76 },
77 Llm {
78 #[command(subcommand)]
79 command: Zeta2LlmCommand,
80 },
81}
82
83#[derive(Subcommand, Debug)]
84enum Zeta2SyntaxCommand {
85 Context {
86 #[clap(flatten)]
87 context_args: ContextArgs,
88 },
89 Stats {
90 #[arg(long)]
91 worktree: PathBuf,
92 #[arg(long)]
93 extension: Option<String>,
94 #[arg(long)]
95 limit: Option<usize>,
96 #[arg(long)]
97 skip: Option<usize>,
98 },
99}
100
101#[derive(Subcommand, Debug)]
102enum Zeta2LlmCommand {
103 Context {
104 #[clap(flatten)]
105 context_args: ContextArgs,
106 },
107}
108
109#[derive(Debug, Args)]
110#[group(requires = "worktree")]
111struct ContextArgs {
112 #[arg(long)]
113 worktree: PathBuf,
114 #[arg(long)]
115 cursor: SourceLocation,
116 #[arg(long)]
117 use_language_server: bool,
118 #[arg(long)]
119 edit_history: Option<FileOrStdin>,
120}
121
122#[derive(Debug, Args)]
123struct Zeta2Args {
124 #[arg(long, default_value_t = 8192)]
125 max_prompt_bytes: usize,
126 #[arg(long, default_value_t = 2048)]
127 max_excerpt_bytes: usize,
128 #[arg(long, default_value_t = 1024)]
129 min_excerpt_bytes: usize,
130 #[arg(long, default_value_t = 0.66)]
131 target_before_cursor_over_total_bytes: f32,
132 #[arg(long, default_value_t = 1024)]
133 max_diagnostic_bytes: usize,
134 #[arg(long, value_enum, default_value_t = PromptFormat::default())]
135 prompt_format: PromptFormat,
136 #[arg(long, value_enum, default_value_t = Default::default())]
137 output_format: OutputFormat,
138 #[arg(long, default_value_t = 42)]
139 file_indexing_parallelism: usize,
140}
141
142#[derive(Debug, Args)]
143struct Zeta2SyntaxArgs {
144 #[arg(long, default_value_t = false)]
145 disable_imports_gathering: bool,
146 #[arg(long, default_value_t = u8::MAX)]
147 max_retrieved_definitions: u8,
148}
149
150fn syntax_args_to_options(
151 zeta2_args: &Zeta2Args,
152 syntax_args: &Zeta2SyntaxArgs,
153 omit_excerpt_overlaps: bool,
154) -> zeta2::ZetaOptions {
155 zeta2::ZetaOptions {
156 context: ContextMode::Syntax(EditPredictionContextOptions {
157 max_retrieved_declarations: syntax_args.max_retrieved_definitions,
158 use_imports: !syntax_args.disable_imports_gathering,
159 excerpt: EditPredictionExcerptOptions {
160 max_bytes: zeta2_args.max_excerpt_bytes,
161 min_bytes: zeta2_args.min_excerpt_bytes,
162 target_before_cursor_over_total_bytes: zeta2_args
163 .target_before_cursor_over_total_bytes,
164 },
165 score: EditPredictionScoreOptions {
166 omit_excerpt_overlaps,
167 },
168 }),
169 max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
170 max_prompt_bytes: zeta2_args.max_prompt_bytes,
171 prompt_format: zeta2_args.prompt_format.clone().into(),
172 file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
173 }
174}
175
176#[derive(clap::ValueEnum, Default, Debug, Clone)]
177enum PromptFormat {
178 MarkedExcerpt,
179 LabeledSections,
180 OnlySnippets,
181 #[default]
182 NumberedLines,
183}
184
185impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
186 fn into(self) -> predict_edits_v3::PromptFormat {
187 match self {
188 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
189 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
190 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
191 Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
192 }
193 }
194}
195
196#[derive(clap::ValueEnum, Default, Debug, Clone)]
197enum OutputFormat {
198 #[default]
199 Prompt,
200 Request,
201 Full,
202}
203
/// An input source named on the command line: a file path or standard input.
#[derive(Clone, Debug)]
enum FileOrStdin {
    /// Read from the file at this path.
    File(PathBuf),
    /// Read from the process's standard input.
    Stdin,
}
209
210impl FileOrStdin {
211 async fn read_to_string(&self) -> Result<String, std::io::Error> {
212 match self {
213 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
214 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
215 }
216 }
217}
218
219impl FromStr for FileOrStdin {
220 type Err = <PathBuf as FromStr>::Err;
221
222 fn from_str(s: &str) -> Result<Self, Self::Err> {
223 match s {
224 "-" => Ok(Self::Stdin),
225 _ => Ok(Self::File(PathBuf::from_str(s)?)),
226 }
227 }
228}
229
230struct LoadedContext {
231 full_path_str: String,
232 snapshot: BufferSnapshot,
233 clipped_cursor: Point,
234 worktree: Entity<Worktree>,
235 project: Entity<Project>,
236 buffer: Entity<Buffer>,
237}
238
239async fn load_context(
240 args: &ContextArgs,
241 app_state: &Arc<ZetaCliAppState>,
242 cx: &mut AsyncApp,
243) -> Result<LoadedContext> {
244 let ContextArgs {
245 worktree: worktree_path,
246 cursor,
247 use_language_server,
248 ..
249 } = args;
250
251 let worktree_path = worktree_path.canonicalize()?;
252
253 let project = cx.update(|cx| {
254 Project::local(
255 app_state.client.clone(),
256 app_state.node_runtime.clone(),
257 app_state.user_store.clone(),
258 app_state.languages.clone(),
259 app_state.fs.clone(),
260 None,
261 cx,
262 )
263 })?;
264
265 let worktree = project
266 .update(cx, |project, cx| {
267 project.create_worktree(&worktree_path, true, cx)
268 })?
269 .await?;
270
271 let mut ready_languages = HashSet::default();
272 let (_lsp_open_handle, buffer) = if *use_language_server {
273 let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
274 project.clone(),
275 worktree.clone(),
276 cursor.path.clone(),
277 &mut ready_languages,
278 cx,
279 )
280 .await?;
281 (Some(lsp_open_handle), buffer)
282 } else {
283 let buffer =
284 open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
285 (None, buffer)
286 };
287
288 let full_path_str = worktree
289 .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
290 .display(PathStyle::local())
291 .to_string();
292
293 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
294 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
295 if clipped_cursor != cursor.point {
296 let max_row = snapshot.max_point().row;
297 if cursor.point.row < max_row {
298 return Err(anyhow!(
299 "Cursor position {:?} is out of bounds (line length is {})",
300 cursor.point,
301 snapshot.line_len(cursor.point.row)
302 ));
303 } else {
304 return Err(anyhow!(
305 "Cursor position {:?} is out of bounds (max row is {})",
306 cursor.point,
307 max_row
308 ));
309 }
310 }
311
312 Ok(LoadedContext {
313 full_path_str,
314 snapshot,
315 clipped_cursor,
316 worktree,
317 project,
318 buffer,
319 })
320}
321
322async fn zeta2_syntax_context(
323 zeta2_args: Zeta2Args,
324 syntax_args: Zeta2SyntaxArgs,
325 args: ContextArgs,
326 app_state: &Arc<ZetaCliAppState>,
327 cx: &mut AsyncApp,
328) -> Result<String> {
329 let LoadedContext {
330 worktree,
331 project,
332 buffer,
333 clipped_cursor,
334 ..
335 } = load_context(&args, app_state, cx).await?;
336
337 // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
338 // the whole worktree.
339 worktree
340 .read_with(cx, |worktree, _cx| {
341 worktree.as_local().unwrap().scan_complete()
342 })?
343 .await;
344 let output = cx
345 .update(|cx| {
346 let zeta = cx.new(|cx| {
347 zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
348 });
349 let indexing_done_task = zeta.update(cx, |zeta, cx| {
350 zeta.set_options(syntax_args_to_options(&zeta2_args, &syntax_args, true));
351 zeta.register_buffer(&buffer, &project, cx);
352 zeta.wait_for_initial_indexing(&project, cx)
353 });
354 cx.spawn(async move |cx| {
355 indexing_done_task.await?;
356 let request = zeta
357 .update(cx, |zeta, cx| {
358 let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
359 zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
360 })?
361 .await?;
362
363 let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;
364
365 match zeta2_args.output_format {
366 OutputFormat::Prompt => anyhow::Ok(prompt_string),
367 OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
368 OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
369 "request": request,
370 "prompt": prompt_string,
371 "section_labels": section_labels,
372 }))?),
373 }
374 })
375 })?
376 .await?;
377
378 Ok(output)
379}
380
381async fn zeta2_llm_context(
382 zeta2_args: Zeta2Args,
383 context_args: ContextArgs,
384 app_state: &Arc<ZetaCliAppState>,
385 cx: &mut AsyncApp,
386) -> Result<String> {
387 let LoadedContext {
388 buffer,
389 clipped_cursor,
390 snapshot: cursor_snapshot,
391 project,
392 ..
393 } = load_context(&context_args, app_state, cx).await?;
394
395 let cursor_position = cursor_snapshot.anchor_after(clipped_cursor);
396
397 cx.update(|cx| {
398 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
399 registry
400 .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
401 .unwrap()
402 .authenticate(cx)
403 })
404 })?
405 .await?;
406
407 let edit_history_unified_diff = match context_args.edit_history {
408 Some(events) => events.read_to_string().await?,
409 None => String::new(),
410 };
411
412 let (debug_tx, mut debug_rx) = mpsc::unbounded();
413
414 let excerpt_options = EditPredictionExcerptOptions {
415 max_bytes: zeta2_args.max_excerpt_bytes,
416 min_bytes: zeta2_args.min_excerpt_bytes,
417 target_before_cursor_over_total_bytes: zeta2_args.target_before_cursor_over_total_bytes,
418 };
419
420 let related_excerpts = cx
421 .update(|cx| {
422 zeta2::related_excerpts::find_related_excerpts(
423 buffer,
424 cursor_position,
425 &project,
426 edit_history_unified_diff,
427 &LlmContextOptions {
428 excerpt: excerpt_options.clone(),
429 },
430 Some(debug_tx),
431 cx,
432 )
433 })?
434 .await?;
435
436 let cursor_excerpt = EditPredictionExcerpt::select_from_buffer(
437 clipped_cursor,
438 &cursor_snapshot,
439 &excerpt_options,
440 None,
441 )
442 .context("line didn't fit")?;
443
444 #[derive(Serialize)]
445 struct Output {
446 excerpts: Vec<OutputExcerpt>,
447 formatted_excerpts: String,
448 meta: OutputMeta,
449 }
450
451 #[derive(Default, Serialize)]
452 struct OutputMeta {
453 search_prompt: String,
454 search_queries: Vec<SearchToolQuery>,
455 }
456
457 #[derive(Serialize)]
458 struct OutputExcerpt {
459 path: PathBuf,
460 #[serde(flatten)]
461 excerpt: Excerpt,
462 }
463
464 let mut meta = OutputMeta::default();
465
466 while let Some(debug_info) = debug_rx.next().await {
467 match debug_info {
468 zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
469 meta.search_prompt = info.search_prompt;
470 }
471 zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
472 meta.search_queries = info.queries
473 }
474 _ => {}
475 }
476 }
477
478 cx.update(|cx| {
479 let mut excerpts = Vec::new();
480 let mut formatted_excerpts = String::new();
481
482 let cursor_insertions = [(
483 predict_edits_v3::Point {
484 line: Line(clipped_cursor.row),
485 column: clipped_cursor.column,
486 },
487 CURSOR_MARKER,
488 )];
489
490 let mut cursor_excerpt_added = false;
491
492 for (buffer, ranges) in related_excerpts {
493 let excerpt_snapshot = buffer.read(cx).snapshot();
494
495 let mut line_ranges = ranges
496 .into_iter()
497 .map(|range| {
498 let point_range = range.to_point(&excerpt_snapshot);
499 Line(point_range.start.row)..Line(point_range.end.row)
500 })
501 .collect::<Vec<_>>();
502
503 let Some(file) = excerpt_snapshot.file() else {
504 continue;
505 };
506 let path = file.full_path(cx);
507
508 let is_cursor_file = path == cursor_snapshot.file().unwrap().full_path(cx);
509 if is_cursor_file {
510 let insertion_ix = line_ranges
511 .binary_search_by(|probe| {
512 probe
513 .start
514 .cmp(&cursor_excerpt.line_range.start)
515 .then(cursor_excerpt.line_range.end.cmp(&probe.end))
516 })
517 .unwrap_or_else(|ix| ix);
518 line_ranges.insert(insertion_ix, cursor_excerpt.line_range.clone());
519 cursor_excerpt_added = true;
520 }
521
522 let merged_excerpts =
523 zeta2::merge_excerpts::merge_excerpts(&excerpt_snapshot, line_ranges)
524 .into_iter()
525 .map(|excerpt| OutputExcerpt {
526 path: path.clone(),
527 excerpt,
528 });
529
530 let excerpt_start_ix = excerpts.len();
531 excerpts.extend(merged_excerpts);
532
533 write_codeblock(
534 &path,
535 excerpts[excerpt_start_ix..].iter().map(|e| &e.excerpt),
536 if is_cursor_file {
537 &cursor_insertions
538 } else {
539 &[]
540 },
541 Line(excerpt_snapshot.max_point().row),
542 true,
543 &mut formatted_excerpts,
544 );
545 }
546
547 if !cursor_excerpt_added {
548 write_codeblock(
549 &cursor_snapshot.file().unwrap().full_path(cx),
550 &[Excerpt {
551 start_line: cursor_excerpt.line_range.start,
552 text: cursor_excerpt.text(&cursor_snapshot).body.into(),
553 }],
554 &cursor_insertions,
555 Line(cursor_snapshot.max_point().row),
556 true,
557 &mut formatted_excerpts,
558 );
559 }
560
561 let output = Output {
562 excerpts,
563 formatted_excerpts,
564 meta,
565 };
566
567 Ok(serde_json::to_string_pretty(&output)?)
568 })
569 .unwrap()
570}
571
572async fn zeta1_context(
573 args: ContextArgs,
574 app_state: &Arc<ZetaCliAppState>,
575 cx: &mut AsyncApp,
576) -> Result<zeta::GatherContextOutput> {
577 let LoadedContext {
578 full_path_str,
579 snapshot,
580 clipped_cursor,
581 ..
582 } = load_context(&args, app_state, cx).await?;
583
584 let events = match args.edit_history {
585 Some(events) => events.read_to_string().await?,
586 None => String::new(),
587 };
588
589 let prompt_for_events = move || (events, 0);
590 cx.update(|cx| {
591 zeta::gather_context(
592 full_path_str,
593 &snapshot,
594 clipped_cursor,
595 prompt_for_events,
596 cx,
597 )
598 })?
599 .await
600}
601
602fn main() {
603 zlog::init();
604 zlog::init_output_stderr();
605 let args = ZetaCliArgs::parse();
606 let http_client = Arc::new(ReqwestClient::new());
607 let app = Application::headless().with_http_client(http_client);
608
609 app.run(move |cx| {
610 let app_state = Arc::new(headless::init(cx));
611 cx.spawn(async move |cx| {
612 let result = match args.command {
613 Command::Zeta1 {
614 command: Zeta1Command::Context { context_args },
615 } => {
616 let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
617 serde_json::to_string_pretty(&context.body).map_err(|err| anyhow::anyhow!(err))
618 }
619 Command::Zeta2 { args, command } => match command {
620 Zeta2Command::Syntax {
621 syntax_args,
622 command,
623 } => match command {
624 Zeta2SyntaxCommand::Context { context_args } => {
625 zeta2_syntax_context(args, syntax_args, context_args, &app_state, cx)
626 .await
627 }
628 Zeta2SyntaxCommand::Stats {
629 worktree,
630 extension,
631 limit,
632 skip,
633 } => {
634 retrieval_stats(
635 worktree,
636 app_state,
637 extension,
638 limit,
639 skip,
640 syntax_args_to_options(&args, &syntax_args, false),
641 cx,
642 )
643 .await
644 }
645 },
646 Zeta2Command::Llm { command } => match command {
647 Zeta2LlmCommand::Context { context_args } => {
648 zeta2_llm_context(args, context_args, &app_state, cx).await
649 }
650 },
651 },
652 Command::ConvertExample {
653 path,
654 output_format,
655 } => {
656 let example = NamedExample::load(path).unwrap();
657 example.write(output_format, io::stdout()).unwrap();
658 let _ = cx.update(|cx| cx.quit());
659 return;
660 }
661 };
662
663 match result {
664 Ok(output) => {
665 println!("{}", output);
666 let _ = cx.update(|cx| cx.quit());
667 }
668 Err(e) => {
669 eprintln!("Failed: {:?}", e);
670 exit(1);
671 }
672 }
673 })
674 .detach();
675 });
676}