1mod headless;
2
3use anyhow::{Result, anyhow};
4use clap::{Args, Parser, Subcommand};
5use cloud_llm_client::predict_edits_v3;
6use edit_prediction_context::{
7 Declaration, EditPredictionContext, EditPredictionExcerptOptions, Identifier, ReferenceRegion,
8 SyntaxIndex, references_in_range,
9};
10use futures::channel::mpsc;
11use futures::{FutureExt as _, StreamExt as _};
12use gpui::{AppContext, Application, AsyncApp};
13use gpui::{Entity, Task};
14use language::Bias;
15use language::Point;
16use language::{Buffer, OffsetRangeExt};
17use language_model::LlmApiToken;
18use ordered_float::OrderedFloat;
19use project::{Project, ProjectPath, Worktree};
20use release_channel::AppVersion;
21use reqwest_client::ReqwestClient;
22use serde_json::json;
23use std::cmp::Reverse;
24use std::collections::HashMap;
25use std::io::Write as _;
26use std::ops::Range;
27use std::path::{Path, PathBuf};
28use std::process::exit;
29use std::str::FromStr;
30use std::sync::Arc;
31use std::time::Duration;
32use util::paths::PathStyle;
33use util::rel_path::RelPath;
34use util::{RangeExt, ResultExt as _};
35use zeta::{PerformPredictEditsParams, Zeta};
36
37use crate::headless::ZetaCliAppState;
38
/// Top-level command-line arguments for the `zeta` CLI binary.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    /// Which subcommand to run (see [`Commands`]).
    #[command(subcommand)]
    command: Commands,
}
45
/// Subcommands supported by the `zeta` CLI.
#[derive(Subcommand, Debug)]
enum Commands {
    /// Gather zeta1-style edit-prediction context for a cursor position.
    Context(ContextArgs),
    /// Gather zeta2-style context; output format controlled by `Zeta2Args`.
    Zeta2Context {
        #[clap(flatten)]
        zeta2_args: Zeta2Args,
        #[clap(flatten)]
        context_args: ContextArgs,
    },
    /// Run an edit prediction, either from a prepared request body or by
    /// gathering context first.
    Predict {
        // Request body read from a file, or stdin when the arg is "-".
        #[arg(long)]
        predict_edits_body: Option<FileOrStdin>,
        #[clap(flatten)]
        context_args: Option<ContextArgs>,
    },
    /// Compare syntax-index retrieval against language-server definitions
    /// across a whole worktree.
    RetrievalStats {
        #[arg(long)]
        worktree: PathBuf,
        #[arg(long, default_value_t = 42)]
        file_indexing_parallelism: usize,
    },
}
68
/// Arguments identifying where to gather edit-prediction context.
// NOTE(review): the clap group declares `requires = "worktree"` — presumably so
// that supplying any of these args requires `--worktree`; confirm against clap
// group semantics for flattened optional groups.
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    /// Path to the worktree (project root) to operate on.
    #[arg(long)]
    worktree: PathBuf,
    /// Cursor position in `file:line:column` form (1-based line/column).
    #[arg(long)]
    cursor: CursorPosition,
    /// When set, opens the buffer with language servers and waits for them.
    #[arg(long)]
    use_language_server: bool,
    /// Optional edit-history events, from a file or stdin ("-").
    #[arg(long)]
    events: Option<FileOrStdin>,
}
81
/// Tuning knobs for zeta2 context gathering and prompt construction.
#[derive(Debug, Args)]
struct Zeta2Args {
    /// Maximum total prompt size in bytes.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    /// Maximum size of the cursor excerpt in bytes.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    /// Minimum size of the cursor excerpt in bytes.
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    /// Fraction of the excerpt budget spent on text before the cursor.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    /// Byte budget for diagnostics included in the prompt.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    /// How the prompt is laid out (see [`PromptFormat`]).
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    /// What to print: just the prompt, the request JSON, or both.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    /// How many files to index concurrently.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
}
101
/// CLI-facing mirror of `predict_edits_v3::PromptFormat`.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum PromptFormat {
    #[default]
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
}
109
110impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
111 fn into(self) -> predict_edits_v3::PromptFormat {
112 match self {
113 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
114 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
115 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
116 }
117 }
118}
119
/// What `zeta2-context` prints to stdout.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    /// Only the rendered prompt string.
    #[default]
    Prompt,
    /// Only the request, pretty-printed as JSON.
    Request,
    /// Request, prompt, and section labels bundled into one JSON object.
    Full,
}
127
/// A CLI input source: either a file path, or stdin (spelled "-").
#[derive(Debug, Clone)]
enum FileOrStdin {
    File(PathBuf),
    Stdin,
}
133
134impl FileOrStdin {
135 async fn read_to_string(&self) -> Result<String, std::io::Error> {
136 match self {
137 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
138 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
139 }
140 }
141}
142
143impl FromStr for FileOrStdin {
144 type Err = <PathBuf as FromStr>::Err;
145
146 fn from_str(s: &str) -> Result<Self, Self::Err> {
147 match s {
148 "-" => Ok(Self::Stdin),
149 _ => Ok(Self::File(PathBuf::from_str(s)?)),
150 }
151 }
152}
153
/// A parsed `--cursor` argument: a worktree-relative path plus a 0-based
/// buffer point.
#[derive(Debug, Clone)]
struct CursorPosition {
    path: Arc<RelPath>,
    // 0-based row/column (converted from the 1-based CLI form in `FromStr`).
    point: Point,
}
159
160impl FromStr for CursorPosition {
161 type Err = anyhow::Error;
162
163 fn from_str(s: &str) -> Result<Self> {
164 let parts: Vec<&str> = s.split(':').collect();
165 if parts.len() != 3 {
166 return Err(anyhow!(
167 "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
168 s
169 ));
170 }
171
172 let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
173 let line: u32 = parts[1]
174 .parse()
175 .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
176 let column: u32 = parts[2]
177 .parse()
178 .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
179
180 // Convert from 1-based to 0-based indexing
181 let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
182
183 Ok(CursorPosition { path, point })
184 }
185}
186
/// Result of `get_context`: zeta1 produces a structured request body, zeta2
/// produces an already-serialized string (prompt or JSON, per `OutputFormat`).
enum GetContextOutput {
    Zeta1(zeta::GatherContextOutput),
    Zeta2(String),
}
191
/// Gathers edit-prediction context for `args.cursor` inside `args.worktree`.
///
/// When `zeta2_args` is `Some`, runs the zeta2 pipeline (syntax indexing plus
/// prompt planning) and returns the rendered output string; otherwise runs the
/// zeta1 `gather_context` path and returns its structured output.
///
/// Errors if the worktree or buffer cannot be opened, or if the cursor
/// position is out of bounds for the buffer.
async fn get_context(
    zeta2_args: Option<Zeta2Args>,
    args: ContextArgs,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<GetContextOutput> {
    let ContextArgs {
        worktree: worktree_path,
        cursor,
        use_language_server,
        events,
    } = args;

    let worktree_path = worktree_path.canonicalize()?;

    // Build a local (non-remote) project from the headless app state.
    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;

    // The LSP open handle must stay alive for as long as we want language
    // servers attached to the buffer, hence the `_lsp_open_handle` binding.
    let (_lsp_open_handle, buffer) = if use_language_server {
        let (lsp_open_handle, buffer) =
            open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?;
        (Some(lsp_open_handle), buffer)
    } else {
        let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
        (None, buffer)
    };

    let full_path_str = worktree
        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
        .display(PathStyle::local())
        .to_string();

    // Reject cursors outside the buffer rather than silently clipping them:
    // clipping changes which point would be used, so report an error instead.
    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
    if clipped_cursor != cursor.point {
        let max_row = snapshot.max_point().row;
        if cursor.point.row < max_row {
            // Row is valid, so the column must be past the end of the line.
            return Err(anyhow!(
                "Cursor position {:?} is out of bounds (line length is {})",
                cursor.point,
                snapshot.line_len(cursor.point.row)
            ));
        } else {
            return Err(anyhow!(
                "Cursor position {:?} is out of bounds (max row is {})",
                cursor.point,
                max_row
            ));
        }
    }

    let events = match events {
        Some(events) => events.read_to_string().await?,
        None => String::new(),
    };

    if let Some(zeta2_args) = zeta2_args {
        // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
        // the whole worktree.
        worktree
            .read_with(cx, |worktree, _cx| {
                worktree.as_local().unwrap().scan_complete()
            })?
            .await;
        let output = cx
            .update(|cx| {
                let zeta = cx.new(|cx| {
                    zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
                });
                // Apply CLI tuning options, register the buffer, and kick off
                // indexing before building the request.
                let indexing_done_task = zeta.update(cx, |zeta, cx| {
                    zeta.set_options(zeta2::ZetaOptions {
                        excerpt: EditPredictionExcerptOptions {
                            max_bytes: zeta2_args.max_excerpt_bytes,
                            min_bytes: zeta2_args.min_excerpt_bytes,
                            target_before_cursor_over_total_bytes: zeta2_args
                                .target_before_cursor_over_total_bytes,
                        },
                        max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
                        max_prompt_bytes: zeta2_args.max_prompt_bytes,
                        prompt_format: zeta2_args.prompt_format.into(),
                        file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
                    });
                    zeta.register_buffer(&buffer, &project, cx);
                    zeta.wait_for_initial_indexing(&project, cx)
                });
                cx.spawn(async move |cx| {
                    indexing_done_task.await?;
                    let request = zeta
                        .update(cx, |zeta, cx| {
                            let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
                            zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
                        })?
                        .await?;

                    let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
                    let (prompt_string, section_labels) = planned_prompt.to_prompt_string()?;

                    // Serialize according to the requested output format.
                    match zeta2_args.output_format {
                        OutputFormat::Prompt => anyhow::Ok(prompt_string),
                        OutputFormat::Request => {
                            anyhow::Ok(serde_json::to_string_pretty(&request)?)
                        }
                        OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
                            "request": request,
                            "prompt": prompt_string,
                            "section_labels": section_labels,
                        }))?),
                    }
                })
            })?
            .await?;
        Ok(GetContextOutput::Zeta2(output))
    } else {
        // zeta1 path: hand the raw events string through unchanged.
        let prompt_for_events = move || (events, 0);
        Ok(GetContextOutput::Zeta1(
            cx.update(|cx| {
                zeta::gather_context(
                    full_path_str,
                    &snapshot,
                    clipped_cursor,
                    prompt_for_events,
                    cx,
                )
            })?
            .await?,
        ))
    }
}
335
/// Measures how well syntax-index retrieval agrees with language-server
/// go-to-definition across every indexed file in `worktree`.
///
/// For each identifier reference in each file, gathers edit-prediction
/// context restricted to that one reference, resolves the retrieved
/// declarations to path + point ranges, and compares them against the LSP
/// definitions for the same position. Per-reference results are appended to
/// `retrieval-stats.txt`; aggregate counts are printed to stdout.
///
/// Returns an empty string on success (all useful output is side effects).
pub async fn retrieval_stats(
    worktree: PathBuf,
    file_indexing_parallelism: usize,
    app_state: Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<String> {
    let worktree_path = worktree.canonicalize()?;

    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;
    let worktree_id = worktree.read_with(cx, |worktree, _cx| worktree.id())?;

    // wait for worktree scan so that wait_for_initial_file_indexing waits for the whole worktree.
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;

    let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx))?;
    index
        .read_with(cx, |index, cx| index.wait_for_initial_file_indexing(cx))?
        .await?;
    let files = index
        .read_with(cx, |index, cx| index.indexed_file_paths(cx))?
        .await;

    // Handles are accumulated so language servers stay registered for every
    // buffer for the duration of the run.
    let mut lsp_open_handles = Vec::new();
    let mut output = std::fs::File::create("retrieval-stats.txt")?;
    let mut results = Vec::new();
    for (file_index, project_path) in files.iter().enumerate() {
        println!(
            "Processing file {} of {}: {}",
            file_index + 1,
            files.len(),
            project_path.path.display(PathStyle::Posix)
        );
        // Files whose buffers fail to open are logged and skipped.
        let Some((lsp_open_handle, buffer)) =
            open_buffer_with_language_server(&project, &worktree, &project_path.path, cx)
                .await
                .log_err()
        else {
            continue;
        };
        lsp_open_handles.push(lsp_open_handle);

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
        let full_range = 0..snapshot.len();
        let references = references_in_range(
            full_range,
            &snapshot.text(),
            ReferenceRegion::Nearby,
            &snapshot,
        );

        // Shadow the entity with a locked clone of the index state for the
        // duration of this file's references.
        let index = index.read_with(cx, |index, _cx| index.state().clone())?;
        let index = index.lock().await;
        for reference in references {
            let query_point = snapshot.offset_to_point(reference.range.start);
            // Restrict context gathering to exactly this one reference.
            let mut single_reference_map = HashMap::default();
            single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]);
            let edit_prediction_context = EditPredictionContext::gather_context_with_references_fn(
                query_point,
                &snapshot,
                &zeta2::DEFAULT_EXCERPT_OPTIONS,
                Some(&index),
                |_, _, _| single_reference_map,
            );

            let Some(edit_prediction_context) = edit_prediction_context else {
                let result = RetrievalStatsResult {
                    identifier: reference.identifier,
                    point: query_point,
                    outcome: RetrievalStatsOutcome::NoExcerpt,
                };
                write!(output, "{:?}\n\n", result)?;
                results.push(result);
                continue;
            };

            // Resolve each retrieved declaration to (path, point range,
            // declaration score, retrieval score).
            let mut retrieved_definitions = Vec::new();
            for scored_declaration in edit_prediction_context.declarations {
                match &scored_declaration.declaration {
                    Declaration::File {
                        project_entry_id,
                        declaration,
                    } => {
                        let Some(path) = worktree.read_with(cx, |worktree, _cx| {
                            worktree
                                .entry_for_id(*project_entry_id)
                                .map(|entry| entry.path.clone())
                        })?
                        else {
                            log::error!("bug: file project entry not found");
                            continue;
                        };
                        let project_path = ProjectPath {
                            worktree_id,
                            path: path.clone(),
                        };
                        // File-backed declarations need the buffer opened to
                        // convert byte offsets into points.
                        let buffer = project
                            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
                            .await?;
                        let rope = buffer.read_with(cx, |buffer, _cx| buffer.as_rope().clone())?;
                        retrieved_definitions.push((
                            path,
                            rope.offset_to_point(declaration.item_range.start)
                                ..rope.offset_to_point(declaration.item_range.end),
                            scored_declaration.scores.declaration,
                            scored_declaration.scores.retrieval,
                        ));
                    }
                    Declaration::Buffer {
                        project_entry_id,
                        rope,
                        declaration,
                        ..
                    } => {
                        let Some(path) = worktree.read_with(cx, |worktree, _cx| {
                            worktree
                                .entry_for_id(*project_entry_id)
                                .map(|entry| entry.path.clone())
                        })?
                        else {
                            log::error!("bug: buffer project entry not found");
                            continue;
                        };
                        retrieved_definitions.push((
                            path,
                            rope.offset_to_point(declaration.item_range.start)
                                ..rope.offset_to_point(declaration.item_range.end),
                            scored_declaration.scores.declaration,
                            scored_declaration.scores.retrieval,
                        ));
                    }
                }
            }
            // Best retrieval score first.
            retrieved_definitions
                .sort_by_key(|(_, _, _, retrieval_score)| Reverse(OrderedFloat(*retrieval_score)));

            // TODO: Consider still checking language server in this case, or having a mode for
            // this. For now assuming that the purpose of this is to refine the ranking rather than
            // refining whether the definition is present at all.
            if retrieved_definitions.is_empty() {
                continue;
            }

            // TODO: Rename declaration to definition in edit_prediction_context?
            let lsp_result = project
                .update(cx, |project, cx| {
                    project.definitions(&buffer, reference.range.start, cx)
                })?
                .await;
            match lsp_result {
                Ok(lsp_definitions) => {
                    let lsp_definitions = lsp_definitions
                        .unwrap_or_default()
                        .into_iter()
                        .filter_map(|definition| {
                            definition
                                .target
                                .buffer
                                .read_with(cx, |buffer, _cx| {
                                    Some((
                                        buffer.file()?.path().clone(),
                                        definition.target.range.to_point(&buffer),
                                    ))
                                })
                                .ok()?
                        })
                        .collect::<Vec<_>>();

                    let result = RetrievalStatsResult {
                        identifier: reference.identifier,
                        point: query_point,
                        outcome: RetrievalStatsOutcome::Success {
                            // For each LSP definition, the rank (if any) of a
                            // retrieved definition whose range contains it.
                            matches: lsp_definitions
                                .iter()
                                .map(|(path, range)| {
                                    retrieved_definitions.iter().position(
                                        |(retrieved_path, retrieved_range, _, _)| {
                                            path == retrieved_path
                                                && retrieved_range.contains_inclusive(&range)
                                        },
                                    )
                                })
                                .collect(),
                            lsp_definitions,
                            retrieved_definitions,
                        },
                    };
                    write!(output, "{:?}\n\n", result)?;
                    results.push(result);
                }
                Err(err) => {
                    let result = RetrievalStatsResult {
                        identifier: reference.identifier,
                        point: query_point,
                        outcome: RetrievalStatsOutcome::LanguageServerError {
                            message: err.to_string(),
                        },
                    };
                    write!(output, "{:?}\n\n", result)?;
                    results.push(result);
                }
            }
        }
    }

    // Aggregate the per-reference results into summary counts.
    let mut no_excerpt_count = 0;
    let mut error_count = 0;
    let mut definitions_count = 0;
    let mut top_match_count = 0;
    let mut non_top_match_count = 0;
    let mut ranking_involved_count = 0;
    let mut ranking_involved_top_match_count = 0;
    let mut ranking_involved_non_top_match_count = 0;
    for result in &results {
        match &result.outcome {
            RetrievalStatsOutcome::NoExcerpt => no_excerpt_count += 1,
            RetrievalStatsOutcome::LanguageServerError { .. } => error_count += 1,
            RetrievalStatsOutcome::Success {
                matches,
                retrieved_definitions,
                ..
            } => {
                definitions_count += 1;
                let top_matches = matches.contains(&Some(0));
                if top_matches {
                    top_match_count += 1;
                }
                // NOTE(review): `None` entries (LSP definitions with no
                // retrieved match) also satisfy `*index != Some(0)`, so they
                // count as "non-top" here — confirm that is intended.
                let non_top_matches = !top_matches && matches.iter().any(|index| *index != Some(0));
                if non_top_matches {
                    non_top_match_count += 1;
                }
                // Only count ranking quality when more than one candidate
                // existed, i.e. ranking actually had a choice to make.
                if retrieved_definitions.len() > 1 {
                    ranking_involved_count += 1;
                    if top_matches {
                        ranking_involved_top_match_count += 1;
                    }
                    if non_top_matches {
                        ranking_involved_non_top_match_count += 1;
                    }
                }
            }
        }
    }

    println!("\nStats:\n");
    println!("No Excerpt: {}", no_excerpt_count);
    println!("Language Server Error: {}", error_count);
    println!("Definitions: {}", definitions_count);
    println!("Top Match: {}", top_match_count);
    println!("Non-Top Match: {}", non_top_match_count);
    println!("Ranking Involved: {}", ranking_involved_count);
    println!(
        "Ranking Involved Top Match: {}",
        ranking_involved_top_match_count
    );
    println!(
        "Ranking Involved Non-Top Match: {}",
        ranking_involved_non_top_match_count
    );

    Ok("".to_string())
}
617
/// One per-reference comparison record; fields marked `dead_code` are only
/// consumed via the `Debug` output written to `retrieval-stats.txt`.
#[derive(Debug)]
struct RetrievalStatsResult {
    #[allow(dead_code)]
    identifier: Identifier,
    #[allow(dead_code)]
    point: Point,
    outcome: RetrievalStatsOutcome,
}
626
/// Outcome of comparing retrieval against LSP definitions for one reference.
#[derive(Debug)]
enum RetrievalStatsOutcome {
    /// Context gathering produced no excerpt for the query point.
    NoExcerpt,
    /// The LSP definitions request failed.
    LanguageServerError {
        #[allow(dead_code)]
        message: String,
    },
    /// Both sides produced results to compare.
    Success {
        // For each LSP definition: the rank of the retrieved definition that
        // contains it, or `None` if no retrieved definition matched.
        matches: Vec<Option<usize>>,
        #[allow(dead_code)]
        lsp_definitions: Vec<(Arc<RelPath>, Range<Point>)>,
        // (path, range, declaration score, retrieval score), sorted by
        // retrieval score descending.
        retrieved_definitions: Vec<(Arc<RelPath>, Range<Point>, f32, f32)>,
    },
}
641
642pub async fn open_buffer(
643 project: &Entity<Project>,
644 worktree: &Entity<Worktree>,
645 path: &RelPath,
646 cx: &mut AsyncApp,
647) -> Result<Entity<Buffer>> {
648 let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
649 worktree_id: worktree.id(),
650 path: path.into(),
651 })?;
652
653 project
654 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
655 .await
656}
657
658pub async fn open_buffer_with_language_server(
659 project: &Entity<Project>,
660 worktree: &Entity<Worktree>,
661 path: &RelPath,
662 cx: &mut AsyncApp,
663) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
664 let buffer = open_buffer(project, worktree, path, cx).await?;
665
666 let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
667 (
668 project.register_buffer_with_language_servers(&buffer, cx),
669 project.path_style(cx),
670 )
671 })?;
672
673 let log_prefix = path.display(path_style);
674 wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
675
676 Ok((lsp_open_handle, buffer))
677}
678
679// TODO: Dedupe with similar function in crates/eval/src/instance.rs
680pub fn wait_for_lang_server(
681 project: &Entity<Project>,
682 buffer: &Entity<Buffer>,
683 log_prefix: String,
684 cx: &mut AsyncApp,
685) -> Task<Result<()>> {
686 println!("{}⏵ Waiting for language server", log_prefix);
687
688 let (mut tx, mut rx) = mpsc::channel(1);
689
690 let lsp_store = project
691 .read_with(cx, |project, _| project.lsp_store())
692 .unwrap();
693
694 let has_lang_server = buffer
695 .update(cx, |buffer, cx| {
696 lsp_store.update(cx, |lsp_store, cx| {
697 lsp_store
698 .language_servers_for_local_buffer(buffer, cx)
699 .next()
700 .is_some()
701 })
702 })
703 .unwrap_or(false);
704
705 if has_lang_server {
706 project
707 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
708 .unwrap()
709 .detach();
710 }
711 let (mut added_tx, mut added_rx) = mpsc::channel(1);
712
713 let subscriptions = [
714 cx.subscribe(&lsp_store, {
715 let log_prefix = log_prefix.clone();
716 move |_, event, _| {
717 if let project::LspStoreEvent::LanguageServerUpdate {
718 message:
719 client::proto::update_language_server::Variant::WorkProgress(
720 client::proto::LspWorkProgress {
721 message: Some(message),
722 ..
723 },
724 ),
725 ..
726 } = event
727 {
728 println!("{}⟲ {message}", log_prefix)
729 }
730 }
731 }),
732 cx.subscribe(project, {
733 let buffer = buffer.clone();
734 move |project, event, cx| match event {
735 project::Event::LanguageServerAdded(_, _, _) => {
736 let buffer = buffer.clone();
737 project
738 .update(cx, |project, cx| project.save_buffer(buffer, cx))
739 .detach();
740 added_tx.try_send(()).ok();
741 }
742 project::Event::DiskBasedDiagnosticsFinished { .. } => {
743 tx.try_send(()).ok();
744 }
745 _ => {}
746 }
747 }),
748 ];
749
750 cx.spawn(async move |cx| {
751 if !has_lang_server {
752 // some buffers never have a language server, so this aborts quickly in that case.
753 let timeout = cx.background_executor().timer(Duration::from_secs(1));
754 futures::select! {
755 _ = added_rx.next() => {},
756 _ = timeout.fuse() => {
757 anyhow::bail!("Waiting for language server add timed out after 1 second");
758 }
759 };
760 }
761 let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
762 let result = futures::select! {
763 _ = rx.next() => {
764 println!("{}⚑ Language server idle", log_prefix);
765 anyhow::Ok(())
766 },
767 _ = timeout.fuse() => {
768 anyhow::bail!("LSP wait timed out after 5 minutes");
769 }
770 };
771 drop(subscriptions);
772 result
773 })
774}
775
/// CLI entry point: parses args, boots a headless gpui application, and
/// dispatches to the selected subcommand. Prints the command's output and
/// quits on success; prints the error and exits with status 1 on failure.
fn main() {
    zlog::init();
    zlog::init_output_stderr();
    let args = ZetaCliArgs::parse();
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        cx.spawn(async move |cx| {
            let result = match args.command {
                // Zeta2Context/Context share get_context; each branch is
                // unreachable for the other's output variant.
                Commands::Zeta2Context {
                    zeta2_args,
                    context_args,
                } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
                    Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
                    Ok(GetContextOutput::Zeta2(output)) => Ok(output),
                    Err(err) => Err(err),
                },
                Commands::Context(context_args) => {
                    match get_context(None, context_args, &app_state, cx).await {
                        Ok(GetContextOutput::Zeta1(output)) => {
                            Ok(serde_json::to_string_pretty(&output.body).unwrap())
                        }
                        Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
                        Err(err) => Err(err),
                    }
                }
                Commands::Predict {
                    predict_edits_body,
                    context_args,
                } => {
                    cx.spawn(async move |cx| {
                        let app_version = cx.update(|cx| AppVersion::global(cx))?;
                        // Prediction requires a signed-in client and a fresh
                        // LLM API token.
                        app_state.client.sign_in(true, cx).await?;
                        let llm_token = LlmApiToken::default();
                        llm_token.refresh(&app_state.client).await?;

                        // The request body comes either pre-built from a
                        // file/stdin, or is gathered from context args.
                        let predict_edits_body =
                            if let Some(predict_edits_body) = predict_edits_body {
                                serde_json::from_str(&predict_edits_body.read_to_string().await?)?
                            } else if let Some(context_args) = context_args {
                                match get_context(None, context_args, &app_state, cx).await? {
                                    GetContextOutput::Zeta1(output) => output.body,
                                    GetContextOutput::Zeta2 { .. } => unreachable!(),
                                }
                            } else {
                                return Err(anyhow!(
                                    "Expected either --predict-edits-body-file \
                                    or the required args of the `context` command."
                                ));
                            };

                        let (response, _usage) =
                            Zeta::perform_predict_edits(PerformPredictEditsParams {
                                client: app_state.client.clone(),
                                llm_token,
                                app_version,
                                body: predict_edits_body,
                            })
                            .await?;

                        Ok(response.output_excerpt)
                    })
                    .await
                }
                Commands::RetrievalStats {
                    worktree,
                    file_indexing_parallelism,
                } => retrieval_stats(worktree, file_indexing_parallelism, app_state, cx).await,
            };
            match result {
                Ok(output) => {
                    println!("{}", output);
                    let _ = cx.update(|cx| cx.quit());
                }
                Err(e) => {
                    eprintln!("Failed: {:?}", e);
                    exit(1);
                }
            }
        })
        .detach();
    });
}