1mod evaluate;
2mod example;
3mod headless;
4mod predict;
5mod source_location;
6mod syntax_retrieval_stats;
7mod util;
8
9use crate::evaluate::{EvaluateArguments, run_evaluate};
10use crate::example::{ExampleFormat, NamedExample};
11use crate::predict::{PredictArguments, run_zeta2_predict};
12use crate::syntax_retrieval_stats::retrieval_stats;
13use ::serde::Serialize;
14use ::util::paths::PathStyle;
15use anyhow::{Context as _, Result, anyhow};
16use clap::{Args, Parser, Subcommand};
17use cloud_llm_client::predict_edits_v3::{self, Excerpt};
18use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
19use edit_prediction_context::{
20 EditPredictionContextOptions, EditPredictionExcerpt, EditPredictionExcerptOptions,
21 EditPredictionScoreOptions, Line,
22};
23use futures::StreamExt as _;
24use futures::channel::mpsc;
25use gpui::{Application, AsyncApp, Entity, prelude::*};
26use language::{Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point};
27use language_model::LanguageModelRegistry;
28use project::{Project, Worktree};
29use reqwest_client::ReqwestClient;
30use serde_json::json;
31use std::io::{self};
32use std::time::Duration;
33use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
34use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery};
35
36use crate::headless::ZetaCliAppState;
37use crate::source_location::SourceLocation;
38use crate::util::{open_buffer, open_buffer_with_language_server};
39
// Root command-line arguments of the `zeta` CLI.
//
// Plain `//` comments are used on the clap-derived types in this file on purpose: clap turns
// `///` doc comments into `--help` text, so documenting with them would change CLI output.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // The subcommand to run; `main` dispatches on it.
    #[command(subcommand)]
    command: Command,
}
46
// Top-level subcommands of the CLI.
#[derive(Subcommand, Debug)]
enum Command {
    // Tooling for the zeta1 model (currently: context gathering).
    Zeta1 {
        #[command(subcommand)]
        command: Zeta1Command,
    },
    // Tooling for zeta2: syntax/LLM context retrieval, prediction and evaluation.
    Zeta2 {
        #[command(subcommand)]
        command: Zeta2Command,
    },
    // Loads the example at `path` via `NamedExample::load` and writes it to stdout in
    // `output_format` (`ExampleFormat::Md` unless overridden).
    ConvertExample {
        path: PathBuf,
        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
        output_format: ExampleFormat,
    },
}
63
// Subcommands under `zeta1`.
#[derive(Subcommand, Debug)]
enum Zeta1Command {
    // Gathers zeta1 context for a cursor position and prints the resulting body as pretty JSON.
    Context {
        #[clap(flatten)]
        context_args: ContextArgs,
    },
}
71
// Subcommands under `zeta2`.
#[derive(Subcommand, Debug)]
enum Zeta2Command {
    // Context retrieval using `ContextMode::Syntax` (see `syntax_args_to_options`).
    Syntax {
        #[clap(flatten)]
        args: Zeta2Args,
        #[clap(flatten)]
        syntax_args: Zeta2SyntaxArgs,
        #[command(subcommand)]
        command: Zeta2SyntaxCommand,
    },
    // Context retrieval using LLM-driven related-excerpt search.
    Llm {
        #[clap(flatten)]
        args: Zeta2Args,
        #[command(subcommand)]
        command: Zeta2LlmCommand,
    },
    // Runs `run_zeta2_predict` with these arguments.
    Predict(PredictArguments),
    // Runs `run_evaluate` with these arguments.
    Eval(EvaluateArguments),
}
91
// Subcommands under `zeta2 syntax`.
#[derive(Subcommand, Debug)]
enum Zeta2SyntaxCommand {
    // Builds the zeta2 prompt/request for one cursor position and prints it in the selected
    // `--output-format`.
    Context {
        #[clap(flatten)]
        context_args: ContextArgs,
    },
    // Runs `retrieval_stats` over `worktree`. The filtering semantics of `extension`,
    // `limit` and `skip` live in `retrieval_stats` (presumably: file-extension filter and
    // how many files to process/skip — confirm there).
    Stats {
        #[arg(long)]
        worktree: PathBuf,
        #[arg(long)]
        extension: Option<String>,
        #[arg(long)]
        limit: Option<usize>,
        #[arg(long)]
        skip: Option<usize>,
    },
}
109
// Subcommands under `zeta2 llm`.
#[derive(Subcommand, Debug)]
enum Zeta2LlmCommand {
    // Finds LLM-selected related excerpts around the cursor and prints them as pretty JSON.
    Context {
        #[clap(flatten)]
        context_args: ContextArgs,
    },
}
117
// Arguments shared by every "context" subcommand: where the project is and where the cursor is.
// NOTE(review): `worktree` is a non-optional `--worktree` flag and therefore already
// required, so this group-level `requires` looks redundant — confirm before removing.
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Root directory of the project; canonicalized in `load_context`.
    #[arg(long)]
    worktree: PathBuf,
    // File (relative to the worktree) and point of the cursor.
    #[arg(long)]
    cursor: SourceLocation,
    // Open the cursor buffer through a language server instead of a plain buffer open.
    #[arg(long)]
    use_language_server: bool,
    // Optional edit history to feed the model; `-` reads it from stdin.
    #[arg(long)]
    edit_history: Option<FileOrStdin>,
}
130
// Tuning knobs shared by the `zeta2 syntax` and `zeta2 llm` commands.
// `zeta2 llm context` only reads the three excerpt-size fields; the rest are consumed by
// `syntax_args_to_options` and `zeta2_syntax_context`.
#[derive(Debug, Args)]
struct Zeta2Args {
    // Upper bound for the whole prompt (`ZetaOptions::max_prompt_bytes`).
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // Upper bound for a single excerpt (`EditPredictionExcerptOptions::max_bytes`).
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // Lower bound for a single excerpt (`EditPredictionExcerptOptions::min_bytes`).
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // Forwarded to `EditPredictionExcerptOptions::target_before_cursor_over_total_bytes`.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // Forwarded to `ZetaOptions::max_diagnostic_bytes`.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // Prompt layout; converted into `predict_edits_v3::PromptFormat`.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // What `zeta2 syntax context` prints (prompt, request JSON, or both).
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // Forwarded to `ZetaOptions::file_indexing_parallelism`.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
}
150
// Options specific to syntax-based context retrieval.
#[derive(Debug, Args)]
struct Zeta2SyntaxArgs {
    // When set, `EditPredictionContextOptions::use_imports` is false.
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
    // Forwarded to `EditPredictionContextOptions::max_retrieved_declarations`
    // (defaults to `u8::MAX`).
    #[arg(long, default_value_t = u8::MAX)]
    max_retrieved_definitions: u8,
}
158
159fn syntax_args_to_options(
160 zeta2_args: &Zeta2Args,
161 syntax_args: &Zeta2SyntaxArgs,
162 omit_excerpt_overlaps: bool,
163) -> zeta2::ZetaOptions {
164 zeta2::ZetaOptions {
165 context: ContextMode::Syntax(EditPredictionContextOptions {
166 max_retrieved_declarations: syntax_args.max_retrieved_definitions,
167 use_imports: !syntax_args.disable_imports_gathering,
168 excerpt: EditPredictionExcerptOptions {
169 max_bytes: zeta2_args.max_excerpt_bytes,
170 min_bytes: zeta2_args.min_excerpt_bytes,
171 target_before_cursor_over_total_bytes: zeta2_args
172 .target_before_cursor_over_total_bytes,
173 },
174 score: EditPredictionScoreOptions {
175 omit_excerpt_overlaps,
176 },
177 }),
178 max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
179 max_prompt_bytes: zeta2_args.max_prompt_bytes,
180 prompt_format: zeta2_args.prompt_format.clone().into(),
181 file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
182 buffer_change_grouping_interval: Duration::ZERO,
183 }
184}
185
// CLI spelling of the zeta2 prompt layouts; converted into `predict_edits_v3::PromptFormat`.
// (`//` comments only: clap would turn `///` on variants into possible-value help text.)
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum PromptFormat {
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
    // Default; maps to `predict_edits_v3::PromptFormat::NumLinesUniDiff`.
    #[default]
    NumberedLines,
}
194
195impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
196 fn into(self) -> predict_edits_v3::PromptFormat {
197 match self {
198 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
199 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
200 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
201 Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
202 }
203 }
204}
205
// What `zeta2 syntax context` prints (see the `match` in `zeta2_syntax_context`).
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    // Only the built prompt string (default).
    #[default]
    Prompt,
    // Only the cloud request, as pretty-printed JSON.
    Request,
    // JSON object with `request`, `prompt` and `section_labels`.
    Full,
}
213
/// Source of a text argument: either a file path, or stdin when the argument is `-`.
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read the text from this file.
    File(PathBuf),
    /// Read the text from the process's standard input.
    Stdin,
}
219
220impl FileOrStdin {
221 async fn read_to_string(&self) -> Result<String, std::io::Error> {
222 match self {
223 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
224 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
225 }
226 }
227}
228
229impl FromStr for FileOrStdin {
230 type Err = <PathBuf as FromStr>::Err;
231
232 fn from_str(s: &str) -> Result<Self, Self::Err> {
233 match s {
234 "-" => Ok(Self::Stdin),
235 _ => Ok(Self::File(PathBuf::from_str(s)?)),
236 }
237 }
238}
239
/// Everything `load_context` opens for a single cursor position.
struct LoadedContext {
    /// Worktree root name joined with the cursor's relative path, rendered with the local
    /// path style.
    full_path_str: String,
    /// Snapshot of the cursor buffer taken when the context was loaded.
    snapshot: BufferSnapshot,
    /// The requested cursor point. `load_context` errors out if clipping would have moved
    /// it, so this equals the requested point.
    clipped_cursor: Point,
    /// Worktree created for `--worktree`.
    worktree: Entity<Worktree>,
    /// Local project owning the worktree.
    project: Entity<Project>,
    /// Buffer for the cursor's file.
    buffer: Entity<Buffer>,
}
248
249async fn load_context(
250 args: &ContextArgs,
251 app_state: &Arc<ZetaCliAppState>,
252 cx: &mut AsyncApp,
253) -> Result<LoadedContext> {
254 let ContextArgs {
255 worktree: worktree_path,
256 cursor,
257 use_language_server,
258 ..
259 } = args;
260
261 let worktree_path = worktree_path.canonicalize()?;
262
263 let project = cx.update(|cx| {
264 Project::local(
265 app_state.client.clone(),
266 app_state.node_runtime.clone(),
267 app_state.user_store.clone(),
268 app_state.languages.clone(),
269 app_state.fs.clone(),
270 None,
271 cx,
272 )
273 })?;
274
275 let worktree = project
276 .update(cx, |project, cx| {
277 project.create_worktree(&worktree_path, true, cx)
278 })?
279 .await?;
280
281 let mut ready_languages = HashSet::default();
282 let (_lsp_open_handle, buffer) = if *use_language_server {
283 let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
284 project.clone(),
285 worktree.clone(),
286 cursor.path.clone(),
287 &mut ready_languages,
288 cx,
289 )
290 .await?;
291 (Some(lsp_open_handle), buffer)
292 } else {
293 let buffer =
294 open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
295 (None, buffer)
296 };
297
298 let full_path_str = worktree
299 .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
300 .display(PathStyle::local())
301 .to_string();
302
303 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
304 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
305 if clipped_cursor != cursor.point {
306 let max_row = snapshot.max_point().row;
307 if cursor.point.row < max_row {
308 return Err(anyhow!(
309 "Cursor position {:?} is out of bounds (line length is {})",
310 cursor.point,
311 snapshot.line_len(cursor.point.row)
312 ));
313 } else {
314 return Err(anyhow!(
315 "Cursor position {:?} is out of bounds (max row is {})",
316 cursor.point,
317 max_row
318 ));
319 }
320 }
321
322 Ok(LoadedContext {
323 full_path_str,
324 snapshot,
325 clipped_cursor,
326 worktree,
327 project,
328 buffer,
329 })
330}
331
/// Builds the zeta2 prompt/request for one cursor position using syntax-based retrieval.
///
/// Loads the worktree and cursor buffer, waits for the worktree scan, creates a `zeta2::Zeta`
/// with options from `syntax_args_to_options(.., true)` (excerpt overlaps omitted), registers
/// the buffer, and waits for initial indexing. It then asks for the cloud request at the
/// cursor and renders it according to `zeta2_args.output_format`.
///
/// # Errors
///
/// Propagates failures from `load_context`, indexing, request construction, prompt
/// building, JSON serialization, or a shut-down app.
async fn zeta2_syntax_context(
    zeta2_args: Zeta2Args,
    syntax_args: Zeta2SyntaxArgs,
    args: ContextArgs,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<String> {
    let LoadedContext {
        worktree,
        project,
        buffer,
        clipped_cursor,
        ..
    } = load_context(&args, app_state, cx).await?;

    // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
    // the whole worktree.
    // `as_local()` is unwrapped because `load_context` created this worktree in a
    // `Project::local` project.
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;
    let output = cx
        .update(|cx| {
            let zeta = cx.new(|cx| {
                zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
            });
            // Configure, register the cursor buffer, and obtain the indexing-completion task.
            let indexing_done_task = zeta.update(cx, |zeta, cx| {
                zeta.set_options(syntax_args_to_options(&zeta2_args, &syntax_args, true));
                zeta.register_buffer(&buffer, &project, cx);
                zeta.wait_for_initial_indexing(&project, cx)
            });
            cx.spawn(async move |cx| {
                indexing_done_task.await?;
                // The anchor is taken from a fresh snapshot, after indexing has finished.
                let request = zeta
                    .update(cx, |zeta, cx| {
                        let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
                        zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
                    })?
                    .await?;

                let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;

                match zeta2_args.output_format {
                    OutputFormat::Prompt => anyhow::Ok(prompt_string),
                    OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
                    OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
                        "request": request,
                        "prompt": prompt_string,
                        "section_labels": section_labels,
                    }))?),
                }
            })
        })?
        .await?;

    Ok(output)
}
390
391async fn zeta2_llm_context(
392 zeta2_args: Zeta2Args,
393 context_args: ContextArgs,
394 app_state: &Arc<ZetaCliAppState>,
395 cx: &mut AsyncApp,
396) -> Result<String> {
397 let LoadedContext {
398 buffer,
399 clipped_cursor,
400 snapshot: cursor_snapshot,
401 project,
402 ..
403 } = load_context(&context_args, app_state, cx).await?;
404
405 let cursor_position = cursor_snapshot.anchor_after(clipped_cursor);
406
407 cx.update(|cx| {
408 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
409 registry
410 .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
411 .unwrap()
412 .authenticate(cx)
413 })
414 })?
415 .await?;
416
417 let edit_history_unified_diff = match context_args.edit_history {
418 Some(events) => events.read_to_string().await?,
419 None => String::new(),
420 };
421
422 let (debug_tx, mut debug_rx) = mpsc::unbounded();
423
424 let excerpt_options = EditPredictionExcerptOptions {
425 max_bytes: zeta2_args.max_excerpt_bytes,
426 min_bytes: zeta2_args.min_excerpt_bytes,
427 target_before_cursor_over_total_bytes: zeta2_args.target_before_cursor_over_total_bytes,
428 };
429
430 let related_excerpts = cx
431 .update(|cx| {
432 zeta2::related_excerpts::find_related_excerpts(
433 buffer,
434 cursor_position,
435 &project,
436 edit_history_unified_diff,
437 &LlmContextOptions {
438 excerpt: excerpt_options.clone(),
439 },
440 Some(debug_tx),
441 cx,
442 )
443 })?
444 .await?;
445
446 let cursor_excerpt = EditPredictionExcerpt::select_from_buffer(
447 clipped_cursor,
448 &cursor_snapshot,
449 &excerpt_options,
450 None,
451 )
452 .context("line didn't fit")?;
453
454 #[derive(Serialize)]
455 struct Output {
456 excerpts: Vec<OutputExcerpt>,
457 formatted_excerpts: String,
458 meta: OutputMeta,
459 }
460
461 #[derive(Default, Serialize)]
462 struct OutputMeta {
463 search_prompt: String,
464 search_queries: Vec<SearchToolQuery>,
465 }
466
467 #[derive(Serialize)]
468 struct OutputExcerpt {
469 path: PathBuf,
470 #[serde(flatten)]
471 excerpt: Excerpt,
472 }
473
474 let mut meta = OutputMeta::default();
475
476 while let Some(debug_info) = debug_rx.next().await {
477 match debug_info {
478 zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
479 meta.search_prompt = info.search_prompt;
480 }
481 zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
482 meta.search_queries = info.queries
483 }
484 _ => {}
485 }
486 }
487
488 cx.update(|cx| {
489 let mut excerpts = Vec::new();
490 let mut formatted_excerpts = String::new();
491
492 let cursor_insertions = [(
493 predict_edits_v3::Point {
494 line: Line(clipped_cursor.row),
495 column: clipped_cursor.column,
496 },
497 CURSOR_MARKER,
498 )];
499
500 let mut cursor_excerpt_added = false;
501
502 for (buffer, ranges) in related_excerpts {
503 let excerpt_snapshot = buffer.read(cx).snapshot();
504
505 let mut line_ranges = ranges
506 .into_iter()
507 .map(|range| {
508 let point_range = range.to_point(&excerpt_snapshot);
509 Line(point_range.start.row)..Line(point_range.end.row)
510 })
511 .collect::<Vec<_>>();
512
513 let Some(file) = excerpt_snapshot.file() else {
514 continue;
515 };
516 let path = file.full_path(cx);
517
518 let is_cursor_file = path == cursor_snapshot.file().unwrap().full_path(cx);
519 if is_cursor_file {
520 let insertion_ix = line_ranges
521 .binary_search_by(|probe| {
522 probe
523 .start
524 .cmp(&cursor_excerpt.line_range.start)
525 .then(cursor_excerpt.line_range.end.cmp(&probe.end))
526 })
527 .unwrap_or_else(|ix| ix);
528 line_ranges.insert(insertion_ix, cursor_excerpt.line_range.clone());
529 cursor_excerpt_added = true;
530 }
531
532 let merged_excerpts =
533 zeta2::merge_excerpts::merge_excerpts(&excerpt_snapshot, line_ranges)
534 .into_iter()
535 .map(|excerpt| OutputExcerpt {
536 path: path.clone(),
537 excerpt,
538 });
539
540 let excerpt_start_ix = excerpts.len();
541 excerpts.extend(merged_excerpts);
542
543 write_codeblock(
544 &path,
545 excerpts[excerpt_start_ix..].iter().map(|e| &e.excerpt),
546 if is_cursor_file {
547 &cursor_insertions
548 } else {
549 &[]
550 },
551 Line(excerpt_snapshot.max_point().row),
552 true,
553 &mut formatted_excerpts,
554 );
555 }
556
557 if !cursor_excerpt_added {
558 write_codeblock(
559 &cursor_snapshot.file().unwrap().full_path(cx),
560 &[Excerpt {
561 start_line: cursor_excerpt.line_range.start,
562 text: cursor_excerpt.text(&cursor_snapshot).body.into(),
563 }],
564 &cursor_insertions,
565 Line(cursor_snapshot.max_point().row),
566 true,
567 &mut formatted_excerpts,
568 );
569 }
570
571 let output = Output {
572 excerpts,
573 formatted_excerpts,
574 meta,
575 };
576
577 Ok(serde_json::to_string_pretty(&output)?)
578 })
579 .unwrap()
580}
581
582async fn zeta1_context(
583 args: ContextArgs,
584 app_state: &Arc<ZetaCliAppState>,
585 cx: &mut AsyncApp,
586) -> Result<zeta::GatherContextOutput> {
587 let LoadedContext {
588 full_path_str,
589 snapshot,
590 clipped_cursor,
591 ..
592 } = load_context(&args, app_state, cx).await?;
593
594 let events = match args.edit_history {
595 Some(events) => events.read_to_string().await?,
596 None => String::new(),
597 };
598
599 let prompt_for_events = move || (events, 0);
600 cx.update(|cx| {
601 zeta::gather_context(
602 full_path_str,
603 &snapshot,
604 clipped_cursor,
605 prompt_for_events,
606 cx,
607 )
608 })?
609 .await
610}
611
612fn main() {
613 zlog::init();
614 zlog::init_output_stderr();
615 let args = ZetaCliArgs::parse();
616 let http_client = Arc::new(ReqwestClient::new());
617 let app = Application::headless().with_http_client(http_client);
618
619 app.run(move |cx| {
620 let app_state = Arc::new(headless::init(cx));
621 cx.spawn(async move |cx| {
622 match args.command {
623 Command::Zeta1 {
624 command: Zeta1Command::Context { context_args },
625 } => {
626 let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
627 let result = serde_json::to_string_pretty(&context.body).unwrap();
628 println!("{}", result);
629 }
630 Command::Zeta2 { command } => match command {
631 Zeta2Command::Predict(arguments) => {
632 run_zeta2_predict(arguments, &app_state, cx).await;
633 }
634 Zeta2Command::Eval(arguments) => {
635 run_evaluate(arguments, &app_state, cx).await;
636 }
637 Zeta2Command::Syntax {
638 args,
639 syntax_args,
640 command,
641 } => {
642 let result = match command {
643 Zeta2SyntaxCommand::Context { context_args } => {
644 zeta2_syntax_context(
645 args,
646 syntax_args,
647 context_args,
648 &app_state,
649 cx,
650 )
651 .await
652 }
653 Zeta2SyntaxCommand::Stats {
654 worktree,
655 extension,
656 limit,
657 skip,
658 } => {
659 retrieval_stats(
660 worktree,
661 app_state,
662 extension,
663 limit,
664 skip,
665 syntax_args_to_options(&args, &syntax_args, false),
666 cx,
667 )
668 .await
669 }
670 };
671 println!("{}", result.unwrap());
672 }
673 Zeta2Command::Llm { args, command } => match command {
674 Zeta2LlmCommand::Context { context_args } => {
675 let result =
676 zeta2_llm_context(args, context_args, &app_state, cx).await;
677 println!("{}", result.unwrap());
678 }
679 },
680 },
681 Command::ConvertExample {
682 path,
683 output_format,
684 } => {
685 let example = NamedExample::load(path).unwrap();
686 example.write(output_format, io::stdout()).unwrap();
687 }
688 };
689
690 let _ = cx.update(|cx| cx.quit());
691 })
692 .detach();
693 });
694}