1mod headless;
2
3use anyhow::{Result, anyhow};
4use clap::{Args, Parser, Subcommand};
5use cloud_llm_client::predict_edits_v3;
6use edit_prediction_context::{
7 Declaration, EditPredictionContext, EditPredictionExcerptOptions, Identifier, ReferenceRegion,
8 SyntaxIndex, references_in_range,
9};
10use futures::channel::mpsc;
11use futures::{FutureExt as _, StreamExt as _};
12use gpui::{AppContext, Application, AsyncApp};
13use gpui::{Entity, Task};
14use language::{Bias, LanguageServerId};
15use language::{Buffer, OffsetRangeExt};
16use language::{LanguageId, Point};
17use language_model::LlmApiToken;
18use ordered_float::OrderedFloat;
19use project::{Project, ProjectPath, Worktree};
20use release_channel::AppVersion;
21use reqwest_client::ReqwestClient;
22use serde_json::json;
23use std::cmp::Reverse;
24use std::collections::{HashMap, HashSet};
25use std::io::Write as _;
26use std::ops::Range;
27use std::path::{Path, PathBuf};
28use std::process::exit;
29use std::str::FromStr;
30use std::sync::Arc;
31use std::time::Duration;
32use util::paths::PathStyle;
33use util::rel_path::RelPath;
34use util::{RangeExt, ResultExt as _};
35use zeta::{PerformPredictEditsParams, Zeta};
36
37use crate::headless::ZetaCliAppState;
38
// Top-level command-line arguments of the `zeta` binary. Everything the tool
// does is selected through one of the subcommands in `Commands`.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // The subcommand chosen on the command line; `main` dispatches on it.
    #[command(subcommand)]
    command: Commands,
}
45
// Subcommands of the CLI. `main` matches on these to decide what to run.
#[derive(Subcommand, Debug)]
enum Commands {
    // Gathers context with `get_context` (no zeta2 options) and prints the
    // resulting request body as pretty-printed JSON.
    Context(ContextArgs),
    // Gathers context through the zeta2 code path of `get_context` and prints
    // the string it produces (shape selected by `Zeta2Args::output_format`).
    Zeta2Context {
        #[clap(flatten)]
        zeta2_args: Zeta2Args,
        #[clap(flatten)]
        context_args: ContextArgs,
    },
    // Performs an edit prediction. The request body is either read from
    // `predict_edits_body` or, when that is absent, built from `context_args`.
    Predict {
        // File (or `-` for stdin) holding a JSON request body.
        #[arg(long)]
        predict_edits_body: Option<FileOrStdin>,
        #[clap(flatten)]
        context_args: Option<ContextArgs>,
    },
    // Runs `retrieval_stats` over every indexed file of a worktree.
    RetrievalStats {
        // Root directory of the worktree to analyze.
        #[arg(long)]
        worktree: PathBuf,
        // Forwarded to `SyntaxIndex::new`.
        #[arg(long, default_value_t = 42)]
        file_indexing_parallelism: usize,
    },
}
68
// Arguments that locate the cursor whose context should be gathered; consumed
// by `get_context`.
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Root directory of the worktree to open; canonicalized before use.
    #[arg(long)]
    worktree: PathBuf,
    // Cursor location, parsed from `file.rs:line:column` (see `CursorPosition`).
    #[arg(long)]
    cursor: CursorPosition,
    // When set, the buffer is opened with `open_buffer_with_language_server`
    // instead of plain `open_buffer`.
    #[arg(long)]
    use_language_server: bool,
    // Optional source of the events string (`-` reads stdin). An empty string
    // is used when this is absent.
    #[arg(long)]
    events: Option<FileOrStdin>,
}
81
// Tuning knobs for the zeta2 code path. `get_context` copies these into
// `zeta2::ZetaOptions` / `EditPredictionExcerptOptions`, except for
// `output_format`, which only selects what is printed.
#[derive(Debug, Args)]
struct Zeta2Args {
    // Becomes `ZetaOptions::max_prompt_bytes`.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // Becomes `EditPredictionExcerptOptions::max_bytes`.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // Becomes `EditPredictionExcerptOptions::min_bytes`.
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // Becomes `EditPredictionExcerptOptions::target_before_cursor_over_total_bytes`.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // Becomes `ZetaOptions::max_diagnostic_bytes`.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // Converted into `predict_edits_v3::PromptFormat` for `ZetaOptions::prompt_format`.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // Selects which representation of the result `get_context` returns.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // Becomes `ZetaOptions::file_indexing_parallelism`.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
}
101
// Command-line mirror of `predict_edits_v3::PromptFormat`; each variant is
// converted to the variant of the same name. `MarkedExcerpt` is the default.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum PromptFormat {
    #[default]
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
}
109
110impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
111 fn into(self) -> predict_edits_v3::PromptFormat {
112 match self {
113 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
114 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
115 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
116 }
117 }
118}
119
// What the zeta2 branch of `get_context` returns as its output string.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    // Only the prompt string (default).
    #[default]
    Prompt,
    // The request, serialized as pretty-printed JSON.
    Request,
    // Pretty-printed JSON object with `request`, `prompt` and `section_labels`.
    Full,
}
127
/// An input source given on the command line: either a file path or standard
/// input. Parsed by the `FromStr` impl, where the literal `-` selects stdin.
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read the contents of the file at this path.
    File(PathBuf),
    /// Read everything from standard input.
    Stdin,
}
133
134impl FileOrStdin {
135 async fn read_to_string(&self) -> Result<String, std::io::Error> {
136 match self {
137 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
138 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
139 }
140 }
141}
142
143impl FromStr for FileOrStdin {
144 type Err = <PathBuf as FromStr>::Err;
145
146 fn from_str(s: &str) -> Result<Self, Self::Err> {
147 match s {
148 "-" => Ok(Self::Stdin),
149 _ => Ok(Self::File(PathBuf::from_str(s)?)),
150 }
151 }
152}
153
/// A cursor location inside a worktree, as given on the command line in the
/// form `file.rs:line:column` (see the `FromStr` impl).
#[derive(Debug, Clone)]
struct CursorPosition {
    /// Path of the file, relative to the worktree root.
    path: Arc<RelPath>,
    /// Zero-based row/column, converted from the one-based command-line values.
    point: Point,
}
159
160impl FromStr for CursorPosition {
161 type Err = anyhow::Error;
162
163 fn from_str(s: &str) -> Result<Self> {
164 let parts: Vec<&str> = s.split(':').collect();
165 if parts.len() != 3 {
166 return Err(anyhow!(
167 "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
168 s
169 ));
170 }
171
172 let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
173 let line: u32 = parts[1]
174 .parse()
175 .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
176 let column: u32 = parts[2]
177 .parse()
178 .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
179
180 // Convert from 1-based to 0-based indexing
181 let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
182
183 Ok(CursorPosition { path, point })
184 }
185}
186
/// Result of `get_context`; which variant is produced depends on whether
/// zeta2 arguments were supplied.
enum GetContextOutput {
    /// Output of `zeta::gather_context` (no zeta2 arguments given).
    Zeta1(zeta::GatherContextOutput),
    /// Already-formatted output string of the zeta2 path.
    Zeta2(String),
}
191
192async fn get_context(
193 zeta2_args: Option<Zeta2Args>,
194 args: ContextArgs,
195 app_state: &Arc<ZetaCliAppState>,
196 cx: &mut AsyncApp,
197) -> Result<GetContextOutput> {
198 let ContextArgs {
199 worktree: worktree_path,
200 cursor,
201 use_language_server,
202 events,
203 } = args;
204
205 let worktree_path = worktree_path.canonicalize()?;
206
207 let project = cx.update(|cx| {
208 Project::local(
209 app_state.client.clone(),
210 app_state.node_runtime.clone(),
211 app_state.user_store.clone(),
212 app_state.languages.clone(),
213 app_state.fs.clone(),
214 None,
215 cx,
216 )
217 })?;
218
219 let worktree = project
220 .update(cx, |project, cx| {
221 project.create_worktree(&worktree_path, true, cx)
222 })?
223 .await?;
224
225 let mut ready_languages = HashSet::default();
226 let (_lsp_open_handle, buffer) = if use_language_server {
227 let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
228 &project,
229 &worktree,
230 &cursor.path,
231 &mut ready_languages,
232 cx,
233 )
234 .await?;
235 (Some(lsp_open_handle), buffer)
236 } else {
237 let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
238 (None, buffer)
239 };
240
241 let full_path_str = worktree
242 .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
243 .display(PathStyle::local())
244 .to_string();
245
246 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
247 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
248 if clipped_cursor != cursor.point {
249 let max_row = snapshot.max_point().row;
250 if cursor.point.row < max_row {
251 return Err(anyhow!(
252 "Cursor position {:?} is out of bounds (line length is {})",
253 cursor.point,
254 snapshot.line_len(cursor.point.row)
255 ));
256 } else {
257 return Err(anyhow!(
258 "Cursor position {:?} is out of bounds (max row is {})",
259 cursor.point,
260 max_row
261 ));
262 }
263 }
264
265 let events = match events {
266 Some(events) => events.read_to_string().await?,
267 None => String::new(),
268 };
269
270 if let Some(zeta2_args) = zeta2_args {
271 // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
272 // the whole worktree.
273 worktree
274 .read_with(cx, |worktree, _cx| {
275 worktree.as_local().unwrap().scan_complete()
276 })?
277 .await;
278 let output = cx
279 .update(|cx| {
280 let zeta = cx.new(|cx| {
281 zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
282 });
283 let indexing_done_task = zeta.update(cx, |zeta, cx| {
284 zeta.set_options(zeta2::ZetaOptions {
285 excerpt: EditPredictionExcerptOptions {
286 max_bytes: zeta2_args.max_excerpt_bytes,
287 min_bytes: zeta2_args.min_excerpt_bytes,
288 target_before_cursor_over_total_bytes: zeta2_args
289 .target_before_cursor_over_total_bytes,
290 },
291 max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
292 max_prompt_bytes: zeta2_args.max_prompt_bytes,
293 prompt_format: zeta2_args.prompt_format.into(),
294 file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
295 });
296 zeta.register_buffer(&buffer, &project, cx);
297 zeta.wait_for_initial_indexing(&project, cx)
298 });
299 cx.spawn(async move |cx| {
300 indexing_done_task.await?;
301 let request = zeta
302 .update(cx, |zeta, cx| {
303 let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
304 zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
305 })?
306 .await?;
307
308 let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
309 let (prompt_string, section_labels) = planned_prompt.to_prompt_string()?;
310
311 match zeta2_args.output_format {
312 OutputFormat::Prompt => anyhow::Ok(prompt_string),
313 OutputFormat::Request => {
314 anyhow::Ok(serde_json::to_string_pretty(&request)?)
315 }
316 OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
317 "request": request,
318 "prompt": prompt_string,
319 "section_labels": section_labels,
320 }))?),
321 }
322 })
323 })?
324 .await?;
325 Ok(GetContextOutput::Zeta2(output))
326 } else {
327 let prompt_for_events = move || (events, 0);
328 Ok(GetContextOutput::Zeta1(
329 cx.update(|cx| {
330 zeta::gather_context(
331 full_path_str,
332 &snapshot,
333 clipped_cursor,
334 prompt_for_events,
335 cx,
336 )
337 })?
338 .await?,
339 ))
340 }
341}
342
343pub async fn retrieval_stats(
344 worktree: PathBuf,
345 file_indexing_parallelism: usize,
346 app_state: Arc<ZetaCliAppState>,
347 cx: &mut AsyncApp,
348) -> Result<String> {
349 let worktree_path = worktree.canonicalize()?;
350
351 let project = cx.update(|cx| {
352 Project::local(
353 app_state.client.clone(),
354 app_state.node_runtime.clone(),
355 app_state.user_store.clone(),
356 app_state.languages.clone(),
357 app_state.fs.clone(),
358 None,
359 cx,
360 )
361 })?;
362
363 let worktree = project
364 .update(cx, |project, cx| {
365 project.create_worktree(&worktree_path, true, cx)
366 })?
367 .await?;
368 let worktree_id = worktree.read_with(cx, |worktree, _cx| worktree.id())?;
369
370 // wait for worktree scan so that wait_for_initial_file_indexing waits for the whole worktree.
371 worktree
372 .read_with(cx, |worktree, _cx| {
373 worktree.as_local().unwrap().scan_complete()
374 })?
375 .await;
376
377 let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx))?;
378 index
379 .read_with(cx, |index, cx| index.wait_for_initial_file_indexing(cx))?
380 .await?;
381 let files = index
382 .read_with(cx, |index, cx| index.indexed_file_paths(cx))?
383 .await
384 .into_iter()
385 .filter(|project_path| {
386 project_path
387 .path
388 .extension()
389 .is_some_and(|extension| !["md", "json", "sh", "diff"].contains(&extension))
390 })
391 .collect::<Vec<_>>();
392
393 let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
394 cx.subscribe(&lsp_store, {
395 move |_, event, _| {
396 if let project::LspStoreEvent::LanguageServerUpdate {
397 message:
398 client::proto::update_language_server::Variant::WorkProgress(
399 client::proto::LspWorkProgress {
400 message: Some(message),
401 ..
402 },
403 ),
404 ..
405 } = event
406 {
407 println!("⟲ {message}")
408 }
409 }
410 })?
411 .detach();
412
413 let mut lsp_open_handles = Vec::new();
414 let mut output = std::fs::File::create("retrieval-stats.txt")?;
415 let mut results = Vec::new();
416 let mut ready_languages = HashSet::default();
417 for (file_index, project_path) in files.iter().enumerate() {
418 let processing_file_message = format!(
419 "Processing file {} of {}: {}",
420 file_index + 1,
421 files.len(),
422 project_path.path.display(PathStyle::Posix)
423 );
424 println!("{}", processing_file_message);
425 write!(output, "{processing_file_message}\n\n").ok();
426
427 let Some((lsp_open_handle, language_server_id, buffer)) = open_buffer_with_language_server(
428 &project,
429 &worktree,
430 &project_path.path,
431 &mut ready_languages,
432 cx,
433 )
434 .await
435 .log_err() else {
436 continue;
437 };
438 lsp_open_handles.push(lsp_open_handle);
439
440 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
441 let full_range = 0..snapshot.len();
442 let references = references_in_range(
443 full_range,
444 &snapshot.text(),
445 ReferenceRegion::Nearby,
446 &snapshot,
447 );
448
449 loop {
450 let is_ready = lsp_store
451 .read_with(cx, |lsp_store, _cx| {
452 lsp_store
453 .language_server_statuses
454 .get(&language_server_id)
455 .is_some_and(|status| status.pending_work.is_empty())
456 })
457 .unwrap();
458 if is_ready {
459 break;
460 }
461 cx.background_executor()
462 .timer(Duration::from_millis(10))
463 .await;
464 }
465
466 let index = index.read_with(cx, |index, _cx| index.state().clone())?;
467 let index = index.lock().await;
468 for reference in references {
469 let query_point = snapshot.offset_to_point(reference.range.start);
470 let mut single_reference_map = HashMap::default();
471 single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]);
472 let edit_prediction_context = EditPredictionContext::gather_context_with_references_fn(
473 query_point,
474 &snapshot,
475 &zeta2::DEFAULT_EXCERPT_OPTIONS,
476 Some(&index),
477 |_, _, _| single_reference_map,
478 );
479
480 let Some(edit_prediction_context) = edit_prediction_context else {
481 let result = RetrievalStatsResult {
482 identifier: reference.identifier,
483 point: query_point,
484 outcome: RetrievalStatsOutcome::NoExcerpt,
485 };
486 write!(output, "{:?}\n\n", result)?;
487 results.push(result);
488 continue;
489 };
490
491 let mut retrieved_definitions = Vec::new();
492 for scored_declaration in edit_prediction_context.declarations {
493 match &scored_declaration.declaration {
494 Declaration::File {
495 project_entry_id,
496 declaration,
497 } => {
498 let Some(path) = worktree.read_with(cx, |worktree, _cx| {
499 worktree
500 .entry_for_id(*project_entry_id)
501 .map(|entry| entry.path.clone())
502 })?
503 else {
504 log::error!("bug: file project entry not found");
505 continue;
506 };
507 let project_path = ProjectPath {
508 worktree_id,
509 path: path.clone(),
510 };
511 let buffer = project
512 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
513 .await?;
514 let rope = buffer.read_with(cx, |buffer, _cx| buffer.as_rope().clone())?;
515 retrieved_definitions.push((
516 path,
517 rope.offset_to_point(declaration.item_range.start)
518 ..rope.offset_to_point(declaration.item_range.end),
519 scored_declaration.scores.declaration,
520 scored_declaration.scores.retrieval,
521 ));
522 }
523 Declaration::Buffer {
524 project_entry_id,
525 rope,
526 declaration,
527 ..
528 } => {
529 let Some(path) = worktree.read_with(cx, |worktree, _cx| {
530 worktree
531 .entry_for_id(*project_entry_id)
532 .map(|entry| entry.path.clone())
533 })?
534 else {
535 // This case happens when dependency buffers have been opened by
536 // go-to-definition, resulting in single-file worktrees.
537 continue;
538 };
539 retrieved_definitions.push((
540 path,
541 rope.offset_to_point(declaration.item_range.start)
542 ..rope.offset_to_point(declaration.item_range.end),
543 scored_declaration.scores.declaration,
544 scored_declaration.scores.retrieval,
545 ));
546 }
547 }
548 }
549 retrieved_definitions
550 .sort_by_key(|(_, _, _, retrieval_score)| Reverse(OrderedFloat(*retrieval_score)));
551
552 // TODO: Consider still checking language server in this case, or having a mode for
553 // this. For now assuming that the purpose of this is to refine the ranking rather than
554 // refining whether the definition is present at all.
555 if retrieved_definitions.is_empty() {
556 continue;
557 }
558
559 // TODO: Rename declaration to definition in edit_prediction_context?
560 let lsp_result = project
561 .update(cx, |project, cx| {
562 project.definitions(&buffer, reference.range.start, cx)
563 })?
564 .await;
565 match lsp_result {
566 Ok(lsp_definitions) => {
567 let lsp_definitions = lsp_definitions
568 .unwrap_or_default()
569 .into_iter()
570 .filter_map(|definition| {
571 definition
572 .target
573 .buffer
574 .read_with(cx, |buffer, _cx| {
575 let path = buffer.file()?.path();
576 // filter out definitions from single-file worktrees
577 if path.is_empty() {
578 None
579 } else {
580 Some((
581 path.clone(),
582 definition.target.range.to_point(&buffer),
583 ))
584 }
585 })
586 .ok()?
587 })
588 .collect::<Vec<_>>();
589
590 let result = RetrievalStatsResult {
591 identifier: reference.identifier,
592 point: query_point,
593 outcome: RetrievalStatsOutcome::Success {
594 matches: lsp_definitions
595 .iter()
596 .map(|(path, range)| {
597 retrieved_definitions.iter().position(
598 |(retrieved_path, retrieved_range, _, _)| {
599 path == retrieved_path
600 && retrieved_range.contains_inclusive(&range)
601 },
602 )
603 })
604 .collect(),
605 lsp_definitions,
606 retrieved_definitions,
607 },
608 };
609 write!(output, "{:?}\n\n", result)?;
610 results.push(result);
611 }
612 Err(err) => {
613 let result = RetrievalStatsResult {
614 identifier: reference.identifier,
615 point: query_point,
616 outcome: RetrievalStatsOutcome::LanguageServerError {
617 message: err.to_string(),
618 },
619 };
620 write!(output, "{:?}\n\n", result)?;
621 results.push(result);
622 }
623 }
624 }
625 }
626
627 let mut no_excerpt_count = 0;
628 let mut error_count = 0;
629 let mut definitions_count = 0;
630 let mut top_match_count = 0;
631 let mut non_top_match_count = 0;
632 let mut ranking_involved_count = 0;
633 let mut ranking_involved_top_match_count = 0;
634 let mut ranking_involved_non_top_match_count = 0;
635 for result in &results {
636 match &result.outcome {
637 RetrievalStatsOutcome::NoExcerpt => no_excerpt_count += 1,
638 RetrievalStatsOutcome::LanguageServerError { .. } => error_count += 1,
639 RetrievalStatsOutcome::Success {
640 matches,
641 retrieved_definitions,
642 ..
643 } => {
644 definitions_count += 1;
645 let top_matches = matches.contains(&Some(0));
646 if top_matches {
647 top_match_count += 1;
648 }
649 let non_top_matches = !top_matches && matches.iter().any(|index| *index != Some(0));
650 if non_top_matches {
651 non_top_match_count += 1;
652 }
653 if retrieved_definitions.len() > 1 {
654 ranking_involved_count += 1;
655 if top_matches {
656 ranking_involved_top_match_count += 1;
657 }
658 if non_top_matches {
659 ranking_involved_non_top_match_count += 1;
660 }
661 }
662 }
663 }
664 }
665
666 println!("\nStats:\n");
667 println!("No Excerpt: {}", no_excerpt_count);
668 println!("Language Server Error: {}", error_count);
669 println!("Definitions: {}", definitions_count);
670 println!("Top Match: {}", top_match_count);
671 println!("Non-Top Match: {}", non_top_match_count);
672 println!("Ranking Involved: {}", ranking_involved_count);
673 println!(
674 "Ranking Involved Top Match: {}",
675 ranking_involved_top_match_count
676 );
677 println!(
678 "Ranking Involved Non-Top Match: {}",
679 ranking_involved_non_top_match_count
680 );
681
682 Ok("".to_string())
683}
684
/// Outcome of evaluating a single reference in `retrieval_stats`. Its `Debug`
/// representation is what gets written to the stats file.
#[derive(Debug)]
struct RetrievalStatsResult {
    /// Identifier of the reference that was looked up.
    #[allow(dead_code)]
    identifier: Identifier,
    /// Start of the reference in the buffer it was found in.
    #[allow(dead_code)]
    point: Point,
    /// What happened when definitions for the reference were retrieved.
    outcome: RetrievalStatsOutcome,
}
693
/// The possible results of evaluating one reference in `retrieval_stats`.
#[derive(Debug)]
enum RetrievalStatsOutcome {
    /// Context gathering returned nothing for the reference.
    NoExcerpt,
    /// The language-server definitions request failed.
    LanguageServerError {
        /// The error, rendered with `to_string()`.
        #[allow(dead_code)]
        message: String,
    },
    /// Both retrieval and the language-server request produced a result.
    Success {
        /// One entry per language-server definition: the index into
        /// `retrieved_definitions` of the first retrieved definition with the
        /// same path whose range contains it, or `None` if there is none.
        matches: Vec<Option<usize>>,
        /// Path and range of each definition reported by the language server.
        #[allow(dead_code)]
        lsp_definitions: Vec<(Arc<RelPath>, Range<Point>)>,
        /// Path, range, declaration score and retrieval score of each
        /// retrieved definition, sorted by descending retrieval score.
        retrieved_definitions: Vec<(Arc<RelPath>, Range<Point>, f32, f32)>,
    },
}
708
709pub async fn open_buffer(
710 project: &Entity<Project>,
711 worktree: &Entity<Worktree>,
712 path: &RelPath,
713 cx: &mut AsyncApp,
714) -> Result<Entity<Buffer>> {
715 let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
716 worktree_id: worktree.id(),
717 path: path.into(),
718 })?;
719
720 project
721 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
722 .await
723}
724
725pub async fn open_buffer_with_language_server(
726 project: &Entity<Project>,
727 worktree: &Entity<Worktree>,
728 path: &RelPath,
729 ready_languages: &mut HashSet<LanguageId>,
730 cx: &mut AsyncApp,
731) -> Result<(Entity<Entity<Buffer>>, LanguageServerId, Entity<Buffer>)> {
732 let buffer = open_buffer(project, worktree, path, cx).await?;
733
734 let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
735 (
736 project.register_buffer_with_language_servers(&buffer, cx),
737 project.path_style(cx),
738 )
739 })?;
740
741 let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
742 buffer.language().map(|language| language.id())
743 })?
744 else {
745 return Err(anyhow!("No language for {}", path.display(path_style)));
746 };
747
748 let log_prefix = path.display(path_style);
749 if !ready_languages.contains(&language_id) {
750 wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
751 ready_languages.insert(language_id);
752 }
753
754 let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
755
756 // hacky wait for buffer to be registered with the language server
757 for _ in 0..100 {
758 let Some(language_server_id) = lsp_store.update(cx, |lsp_store, cx| {
759 buffer.update(cx, |buffer, cx| {
760 lsp_store
761 .language_servers_for_local_buffer(&buffer, cx)
762 .next()
763 .map(|(_, language_server)| language_server.server_id())
764 })
765 })?
766 else {
767 cx.background_executor()
768 .timer(Duration::from_millis(10))
769 .await;
770 continue;
771 };
772
773 return Ok((lsp_open_handle, language_server_id, buffer));
774 }
775
776 return Err(anyhow!("No language server found for buffer"));
777}
778
// TODO: Dedupe with similar function in crates/eval/src/instance.rs
/// Returns a task that completes once the project emits
/// `DiskBasedDiagnosticsFinished`, printing progress prefixed by `log_prefix`.
///
/// If the buffer has no language server yet, the task first waits up to 5
/// seconds for a `LanguageServerAdded` event and fails if none arrives. The
/// wait for diagnostics itself fails after 5 minutes.
///
/// # Panics
/// Panics if the project entity cannot be read or updated (the `unwrap`s
/// below), since this function returns a `Task` rather than a `Result`.
pub fn wait_for_lang_server(
    project: &Entity<Project>,
    buffer: &Entity<Buffer>,
    log_prefix: String,
    cx: &mut AsyncApp,
) -> Task<Result<()>> {
    println!("{}⏵ Waiting for language server", log_prefix);

    // Signalled when disk-based diagnostics finish.
    let (mut tx, mut rx) = mpsc::channel(1);

    let lsp_store = project
        .read_with(cx, |project, _| project.lsp_store())
        .unwrap();

    // Whether at least one language server is already attached to the buffer;
    // a failed update is treated as "no server".
    let has_lang_server = buffer
        .update(cx, |buffer, cx| {
            lsp_store.update(cx, |lsp_store, cx| {
                lsp_store
                    .language_servers_for_local_buffer(buffer, cx)
                    .next()
                    .is_some()
            })
        })
        .unwrap_or(false);

    // Save the buffer when a server is already present — presumably to make
    // the server produce diagnostics, which is what the task waits for.
    if has_lang_server {
        project
            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
            .unwrap()
            .detach();
    }
    // Signalled when a language server is added to the project.
    let (mut added_tx, mut added_rx) = mpsc::channel(1);

    // Both subscriptions are moved into the spawned task below so they stay
    // active for as long as it is waiting.
    let subscriptions = [
        // Echo language-server work-progress messages to stdout.
        cx.subscribe(&lsp_store, {
            let log_prefix = log_prefix.clone();
            move |_, event, _| {
                if let project::LspStoreEvent::LanguageServerUpdate {
                    message:
                        client::proto::update_language_server::Variant::WorkProgress(
                            client::proto::LspWorkProgress {
                                message: Some(message),
                                ..
                            },
                        ),
                    ..
                } = event
                {
                    println!("{}⟲ {message}", log_prefix)
                }
            }
        }),
        // Forward the two project events the task waits on. `try_send` on a
        // full channel is ignored: one pending signal is enough.
        cx.subscribe(project, {
            let buffer = buffer.clone();
            move |project, event, cx| match event {
                project::Event::LanguageServerAdded(_, _, _) => {
                    // Same save as above, now that a server exists.
                    let buffer = buffer.clone();
                    project
                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
                        .detach();
                    added_tx.try_send(()).ok();
                }
                project::Event::DiskBasedDiagnosticsFinished { .. } => {
                    tx.try_send(()).ok();
                }
                _ => {}
            }
        }),
    ];

    cx.spawn(async move |cx| {
        if !has_lang_server {
            // some buffers never have a language server, so this aborts quickly in that case.
            let timeout = cx.background_executor().timer(Duration::from_secs(5));
            futures::select! {
                _ = added_rx.next() => {},
                _ = timeout.fuse() => {
                    anyhow::bail!("Waiting for language server add timed out after 5 seconds");
                }
            };
        }
        let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
        let result = futures::select! {
            _ = rx.next() => {
                println!("{}⚑ Language server idle", log_prefix);
                anyhow::Ok(())
            },
            _ = timeout.fuse() => {
                anyhow::bail!("LSP wait timed out after 5 minutes");
            }
        };
        // Explicitly end the subscriptions once the wait succeeded.
        drop(subscriptions);
        result
    })
}
875
876fn main() {
877 zlog::init();
878 zlog::init_output_stderr();
879 let args = ZetaCliArgs::parse();
880 let http_client = Arc::new(ReqwestClient::new());
881 let app = Application::headless().with_http_client(http_client);
882
883 app.run(move |cx| {
884 let app_state = Arc::new(headless::init(cx));
885 cx.spawn(async move |cx| {
886 let result = match args.command {
887 Commands::Zeta2Context {
888 zeta2_args,
889 context_args,
890 } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
891 Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
892 Ok(GetContextOutput::Zeta2(output)) => Ok(output),
893 Err(err) => Err(err),
894 },
895 Commands::Context(context_args) => {
896 match get_context(None, context_args, &app_state, cx).await {
897 Ok(GetContextOutput::Zeta1(output)) => {
898 Ok(serde_json::to_string_pretty(&output.body).unwrap())
899 }
900 Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
901 Err(err) => Err(err),
902 }
903 }
904 Commands::Predict {
905 predict_edits_body,
906 context_args,
907 } => {
908 cx.spawn(async move |cx| {
909 let app_version = cx.update(|cx| AppVersion::global(cx))?;
910 app_state.client.sign_in(true, cx).await?;
911 let llm_token = LlmApiToken::default();
912 llm_token.refresh(&app_state.client).await?;
913
914 let predict_edits_body =
915 if let Some(predict_edits_body) = predict_edits_body {
916 serde_json::from_str(&predict_edits_body.read_to_string().await?)?
917 } else if let Some(context_args) = context_args {
918 match get_context(None, context_args, &app_state, cx).await? {
919 GetContextOutput::Zeta1(output) => output.body,
920 GetContextOutput::Zeta2 { .. } => unreachable!(),
921 }
922 } else {
923 return Err(anyhow!(
924 "Expected either --predict-edits-body-file \
925 or the required args of the `context` command."
926 ));
927 };
928
929 let (response, _usage) =
930 Zeta::perform_predict_edits(PerformPredictEditsParams {
931 client: app_state.client.clone(),
932 llm_token,
933 app_version,
934 body: predict_edits_body,
935 })
936 .await?;
937
938 Ok(response.output_excerpt)
939 })
940 .await
941 }
942 Commands::RetrievalStats {
943 worktree,
944 file_indexing_parallelism,
945 } => retrieval_stats(worktree, file_indexing_parallelism, app_state, cx).await,
946 };
947 match result {
948 Ok(output) => {
949 println!("{}", output);
950 let _ = cx.update(|cx| cx.quit());
951 }
952 Err(e) => {
953 eprintln!("Failed: {:?}", e);
954 exit(1);
955 }
956 }
957 })
958 .detach();
959 });
960}