1mod headless;
2
3use anyhow::{Context as _, Result, anyhow};
4use clap::{Args, Parser, Subcommand};
5use cloud_llm_client::predict_edits_v3::{self, DeclarationScoreComponents};
6use edit_prediction_context::{
7 Declaration, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
8 EditPredictionExcerptOptions, EditPredictionScoreOptions, Identifier, Imports, Reference,
9 ReferenceRegion, SyntaxIndex, SyntaxIndexState, references_in_range,
10};
11use futures::channel::mpsc;
12use futures::{FutureExt as _, StreamExt as _};
13use gpui::{AppContext, Application, AsyncApp};
14use gpui::{Entity, Task};
15use language::{Bias, BufferSnapshot, LanguageServerId, Point};
16use language::{Buffer, OffsetRangeExt};
17use language::{LanguageId, ParseStatus};
18use language_model::LlmApiToken;
19use ordered_float::OrderedFloat;
20use project::{Project, ProjectEntryId, ProjectPath, Worktree};
21use release_channel::AppVersion;
22use reqwest_client::ReqwestClient;
23use serde::{Deserialize, Deserializer, Serialize, Serializer};
24use serde_json::json;
25use std::cmp::Reverse;
26use std::collections::{HashMap, HashSet};
27use std::fmt::{self, Display};
28use std::fs::File;
29use std::hash::Hash;
30use std::hash::Hasher;
31use std::io::Write as _;
32use std::ops::Range;
33use std::path::{Path, PathBuf};
34use std::process::exit;
35use std::str::FromStr;
36use std::sync::atomic::AtomicUsize;
37use std::sync::{Arc, atomic};
38use std::time::Duration;
39use util::paths::PathStyle;
40use util::rel_path::RelPath;
41use util::{RangeExt, ResultExt as _};
42use zeta::{PerformPredictEditsParams, Zeta};
43
44use crate::headless::ZetaCliAppState;
45
// Top-level argument parser for the `zeta` binary.
// NOTE(review): comments here use `//` rather than `///` because clap's
// derive macro turns doc comments into user-visible help text, which would
// change CLI output.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // The selected subcommand; see `Commands`.
    #[command(subcommand)]
    command: Commands,
}
52
// Subcommands of the zeta CLI.
// NOTE(review): `//` comments are used (not `///`) so clap does not surface
// them as help text, keeping CLI output unchanged.
#[derive(Subcommand, Debug)]
enum Commands {
    // Gather zeta1 context for a cursor position.
    Context(ContextArgs),
    // Gather zeta2 (syntax-index based) context for a cursor position.
    Zeta2Context {
        #[clap(flatten)]
        zeta2_args: Zeta2Args,
        #[clap(flatten)]
        context_args: ContextArgs,
    },
    // Request an edit prediction, either from a prepared request body
    // (file or stdin) or from freshly gathered context.
    Predict {
        #[arg(long)]
        predict_edits_body: Option<FileOrStdin>,
        #[clap(flatten)]
        context_args: Option<ContextArgs>,
    },
    // Compare zeta2 retrieval against language-server go-to-definition
    // results across a worktree; see `retrieval_stats`.
    RetrievalStats {
        #[clap(flatten)]
        zeta2_args: Zeta2Args,
        #[arg(long)]
        worktree: PathBuf,
        // Only process files with this extension.
        #[arg(long)]
        extension: Option<String>,
        // Maximum number of files to process.
        #[arg(long)]
        limit: Option<usize>,
        // Number of files to skip before processing.
        #[arg(long)]
        skip: Option<usize>,
    },
}
81
// Arguments shared by the context-gathering subcommands.
// NOTE(review): `//` comments avoid becoming clap help text.
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Root directory of the worktree to load.
    #[arg(long)]
    worktree: PathBuf,
    // Cursor position as `path:line:column` (1-based), relative to the worktree.
    #[arg(long)]
    cursor: SourceLocation,
    // Start a language server for the buffer and wait for it before gathering.
    #[arg(long)]
    use_language_server: bool,
    // Edit-history events, read from a file or from stdin when `-`.
    #[arg(long)]
    events: Option<FileOrStdin>,
}
94
// Tuning knobs for the zeta2 context pipeline (all sizes in bytes).
// NOTE(review): `//` comments avoid becoming clap help text.
#[derive(Debug, Args)]
struct Zeta2Args {
    // Upper bound on the total prompt size.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // Upper bound on the cursor excerpt.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // Lower bound on the cursor excerpt.
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // Fraction of the excerpt budget placed before the cursor.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // Upper bound on included diagnostics.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // How the prompt marks up the excerpt; see `PromptFormat`.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // What to print; see `OutputFormat`.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // Number of files indexed concurrently by the syntax index.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
    // Skip gathering import statements for context scoring.
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
}
116
// CLI-facing mirror of `predict_edits_v3::PromptFormat` (see the conversion
// impl below). `//` comments avoid becoming clap help text.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum PromptFormat {
    #[default]
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
}
124
125impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
126 fn into(self) -> predict_edits_v3::PromptFormat {
127 match self {
128 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
129 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
130 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
131 }
132 }
133}
134
// What the zeta2 context command prints: just the prompt string, the raw
// request JSON, or both plus section labels. `//` comments avoid becoming
// clap help text.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    #[default]
    Prompt,
    Request,
    Full,
}
142
/// CLI input source: a file path, or stdin when the argument is `-`.
#[derive(Debug, Clone)]
enum FileOrStdin {
    File(PathBuf),
    Stdin,
}
148
149impl FileOrStdin {
150 async fn read_to_string(&self) -> Result<String, std::io::Error> {
151 match self {
152 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
153 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
154 }
155 }
156}
157
158impl FromStr for FileOrStdin {
159 type Err = <PathBuf as FromStr>::Err;
160
161 fn from_str(s: &str) -> Result<Self, Self::Err> {
162 match s {
163 "-" => Ok(Self::Stdin),
164 _ => Ok(Self::File(PathBuf::from_str(s)?)),
165 }
166 }
167}
168
/// A cursor position within a worktree-relative file.
///
/// Rows and columns are stored 0-based in `point`, while the textual form
/// (`Display`/`FromStr`/serde) uses 1-based `path:line:column`.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct SourceLocation {
    path: Arc<RelPath>,
    point: Point,
}
174
175impl Serialize for SourceLocation {
176 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
177 where
178 S: Serializer,
179 {
180 serializer.serialize_str(&self.to_string())
181 }
182}
183
184impl<'de> Deserialize<'de> for SourceLocation {
185 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
186 where
187 D: Deserializer<'de>,
188 {
189 let s = String::deserialize(deserializer)?;
190 s.parse().map_err(serde::de::Error::custom)
191 }
192}
193
194impl Display for SourceLocation {
195 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196 write!(
197 f,
198 "{}:{}:{}",
199 self.path.display(PathStyle::Posix),
200 self.point.row + 1,
201 self.point.column + 1
202 )
203 }
204}
205
206impl FromStr for SourceLocation {
207 type Err = anyhow::Error;
208
209 fn from_str(s: &str) -> Result<Self> {
210 let parts: Vec<&str> = s.split(':').collect();
211 if parts.len() != 3 {
212 return Err(anyhow!(
213 "Invalid source location. Expected 'file.rs:line:column', got '{}'",
214 s
215 ));
216 }
217
218 let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
219 let line: u32 = parts[1]
220 .parse()
221 .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
222 let column: u32 = parts[2]
223 .parse()
224 .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
225
226 // Convert from 1-based to 0-based indexing
227 let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
228
229 Ok(SourceLocation { path, point })
230 }
231}
232
/// Output of `get_context`: either the structured zeta1 context-gathering
/// result, or an already-formatted string produced by the zeta2 pipeline.
enum GetContextOutput {
    Zeta1(zeta::GatherContextOutput),
    Zeta2(String),
}
237
/// Gathers edit-prediction context for the given cursor location in a local
/// worktree.
///
/// When `zeta2_args` is `Some`, the zeta2 pipeline is used (syntax indexing
/// plus prompt planning) and the result is a formatted string chosen by
/// `--output-format`; otherwise the zeta1 `gather_context` path runs.
///
/// Errors if the worktree/buffer cannot be opened or the cursor is out of
/// bounds.
async fn get_context(
    zeta2_args: Option<Zeta2Args>,
    args: ContextArgs,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<GetContextOutput> {
    let ContextArgs {
        worktree: worktree_path,
        cursor,
        use_language_server,
        events,
    } = args;

    let worktree_path = worktree_path.canonicalize()?;

    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;

    // Optionally attach a language server; the open handle keeps the buffer
    // registered with it for the duration of this function.
    let mut ready_languages = HashSet::default();
    let (_lsp_open_handle, buffer) = if use_language_server {
        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
            project.clone(),
            worktree.clone(),
            cursor.path.clone(),
            &mut ready_languages,
            cx,
        )
        .await?;
        (Some(lsp_open_handle), buffer)
    } else {
        let buffer =
            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
        (None, buffer)
    };

    let full_path_str = worktree
        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
        .display(PathStyle::local())
        .to_string();

    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
    // Reject out-of-bounds cursors explicitly instead of silently clipping.
    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
    if clipped_cursor != cursor.point {
        let max_row = snapshot.max_point().row;
        if cursor.point.row < max_row {
            return Err(anyhow!(
                "Cursor position {:?} is out of bounds (line length is {})",
                cursor.point,
                snapshot.line_len(cursor.point.row)
            ));
        } else {
            return Err(anyhow!(
                "Cursor position {:?} is out of bounds (max row is {})",
                cursor.point,
                max_row
            ));
        }
    }

    let events = match events {
        Some(events) => events.read_to_string().await?,
        None => String::new(),
    };

    if let Some(zeta2_args) = zeta2_args {
        // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
        // the whole worktree.
        worktree
            .read_with(cx, |worktree, _cx| {
                worktree.as_local().unwrap().scan_complete()
            })?
            .await;
        let output = cx
            .update(|cx| {
                let zeta = cx.new(|cx| {
                    zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
                });
                let indexing_done_task = zeta.update(cx, |zeta, cx| {
                    zeta.set_options(zeta2_args.to_options(true));
                    zeta.register_buffer(&buffer, &project, cx);
                    zeta.wait_for_initial_indexing(&project, cx)
                });
                cx.spawn(async move |cx| {
                    indexing_done_task.await?;
                    let request = zeta
                        .update(cx, |zeta, cx| {
                            let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
                            zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
                        })?
                        .await?;

                    let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
                    let (prompt_string, section_labels) = planned_prompt.to_prompt_string()?;

                    // Output shape is controlled by --output-format.
                    match zeta2_args.output_format {
                        OutputFormat::Prompt => anyhow::Ok(prompt_string),
                        OutputFormat::Request => {
                            anyhow::Ok(serde_json::to_string_pretty(&request)?)
                        }
                        OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
                            "request": request,
                            "prompt": prompt_string,
                            "section_labels": section_labels,
                        }))?),
                    }
                })
            })?
            .await?;
        Ok(GetContextOutput::Zeta2(output))
    } else {
        let prompt_for_events = move || (events, 0);
        Ok(GetContextOutput::Zeta1(
            cx.update(|cx| {
                zeta::gather_context(
                    full_path_str,
                    &snapshot,
                    clipped_cursor,
                    prompt_for_events,
                    cx,
                )
            })?
            .await?,
        ))
    }
}
378
379impl Zeta2Args {
380 fn to_options(&self, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions {
381 zeta2::ZetaOptions {
382 context: EditPredictionContextOptions {
383 use_imports: !self.disable_imports_gathering,
384 excerpt: EditPredictionExcerptOptions {
385 max_bytes: self.max_excerpt_bytes,
386 min_bytes: self.min_excerpt_bytes,
387 target_before_cursor_over_total_bytes: self
388 .target_before_cursor_over_total_bytes,
389 },
390 score: EditPredictionScoreOptions {
391 omit_excerpt_overlaps,
392 },
393 },
394 max_diagnostic_bytes: self.max_diagnostic_bytes,
395 max_prompt_bytes: self.max_prompt_bytes,
396 prompt_format: self.prompt_format.clone().into(),
397 file_indexing_parallelism: self.file_indexing_parallelism,
398 }
399 }
400}
401
/// Measures how well zeta2's syntactic retrieval matches language-server
/// go-to-definition results across a worktree.
///
/// For every identifier reference in every (filtered) indexed file, the
/// retrieved declarations are compared against cached-or-freshly-gathered
/// LSP definitions; per-reference results are streamed to
/// `target/zeta-retrieval-stats.txt` and aggregate statistics are printed as
/// a tree. Returns an empty string on success (output goes to stdout/file).
pub async fn retrieval_stats(
    worktree: PathBuf,
    app_state: Arc<ZetaCliAppState>,
    only_extension: Option<String>,
    file_limit: Option<usize>,
    skip_files: Option<usize>,
    options: zeta2::ZetaOptions,
    cx: &mut AsyncApp,
) -> Result<String> {
    let options = Arc::new(options);
    let worktree_path = worktree.canonicalize()?;

    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;

    // wait for worktree scan so that wait_for_initial_file_indexing waits for the whole worktree.
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;

    let index = cx.new(|cx| SyntaxIndex::new(&project, options.file_indexing_parallelism, cx))?;
    index
        .read_with(cx, |index, cx| index.wait_for_initial_file_indexing(cx))?
        .await?;
    let indexed_files = index
        .read_with(cx, |index, cx| index.indexed_file_paths(cx))?
        .await;
    // Keep only the requested extension, or drop known-noisy extensions when
    // none was requested.
    let mut filtered_files = indexed_files
        .into_iter()
        .filter(|project_path| {
            let file_extension = project_path.path.extension();
            if let Some(only_extension) = only_extension.as_ref() {
                file_extension.is_some_and(|extension| extension == only_extension)
            } else {
                file_extension
                    .is_some_and(|extension| !["md", "json", "sh", "diff"].contains(&extension))
            }
        })
        .collect::<Vec<_>>();
    filtered_files.sort_by(|a, b| a.path.cmp(&b.path));

    // Take sole ownership of the index state (dropping the entity first) so
    // the background scoring tasks can read it without entity access.
    let index_state = index.read_with(cx, |index, _cx| index.state().clone())?;
    cx.update(|_| {
        drop(index);
    })?;
    let index_state = Arc::new(
        Arc::into_inner(index_state)
            .context("Index state had more than 1 reference")?
            .into_inner(),
    );

    struct FileSnapshot {
        project_entry_id: ProjectEntryId,
        snapshot: BufferSnapshot,
        // Hash of the file's text; folded into the corpus-wide cache key.
        hash: u64,
        parent_abs_path: Arc<Path>,
    }

    // Open every filtered file, snapshot it, and hash its contents on a
    // background thread.
    let files: Vec<FileSnapshot> = futures::future::try_join_all({
        filtered_files
            .iter()
            .map(|file| {
                let buffer_task =
                    open_buffer(project.clone(), worktree.clone(), file.path.clone(), cx);
                cx.spawn(async move |cx| {
                    let buffer = buffer_task.await?;
                    let (project_entry_id, parent_abs_path, snapshot) =
                        buffer.read_with(cx, |buffer, cx| {
                            let file = project::File::from_dyn(buffer.file()).unwrap();
                            let project_entry_id = file.project_entry_id().unwrap();
                            let mut parent_abs_path = file.worktree.read(cx).absolutize(&file.path);
                            if !parent_abs_path.pop() {
                                panic!("Invalid worktree path");
                            }

                            (project_entry_id, parent_abs_path, buffer.snapshot())
                        })?;

                    anyhow::Ok(
                        cx.background_spawn(async move {
                            let mut hasher = collections::FxHasher::default();
                            snapshot.text().hash(&mut hasher);
                            FileSnapshot {
                                project_entry_id,
                                snapshot,
                                hash: hasher.finish(),
                                parent_abs_path: parent_abs_path.into(),
                            }
                        })
                        .await,
                    )
                })
            })
            .collect::<Vec<_>>()
    })
    .await?;

    // Combined hash over all file hashes identifies this exact corpus state;
    // it keys the on-disk LSP-definition cache below.
    let mut file_snapshots = HashMap::default();
    let mut hasher = collections::FxHasher::default();
    for FileSnapshot {
        project_entry_id,
        snapshot,
        hash,
        ..
    } in &files
    {
        file_snapshots.insert(*project_entry_id, snapshot.clone());
        hash.hash(&mut hasher);
    }
    let files_hash = hasher.finish();
    let file_snapshots = Arc::new(file_snapshots);

    let lsp_definitions_path = std::env::current_dir()?.join(format!(
        "target/zeta2-lsp-definitions-{:x}.json",
        files_hash
    ));

    // Reuse cached LSP definitions when the corpus is unchanged; otherwise
    // gather them (slow) and persist for future runs.
    let lsp_definitions: Arc<_> = if std::fs::exists(&lsp_definitions_path)? {
        log::info!(
            "Using cached LSP definitions from {}",
            lsp_definitions_path.display()
        );
        serde_json::from_reader(File::open(&lsp_definitions_path)?)?
    } else {
        log::warn!(
            "No LSP definitions found populating {}",
            lsp_definitions_path.display()
        );
        let lsp_definitions =
            gather_lsp_definitions(&filtered_files, &worktree, &project, cx).await?;
        serde_json::to_writer_pretty(File::create(&lsp_definitions_path)?, &lsp_definitions)?;
        lsp_definitions
    }
    .into();

    let files_len = files.len().min(file_limit.unwrap_or(usize::MAX));
    let done_count = Arc::new(AtomicUsize::new(0));

    let (output_tx, mut output_rx) = mpsc::unbounded::<RetrievalStatsResult>();
    let mut output = std::fs::File::create("target/zeta-retrieval-stats.txt")?;

    // One background task per file: score every reference and classify the
    // outcome against the LSP ground truth, streaming results over the channel.
    let tasks = files
        .into_iter()
        .skip(skip_files.unwrap_or(0))
        .take(file_limit.unwrap_or(usize::MAX))
        .map(|project_file| {
            let index_state = index_state.clone();
            let lsp_definitions = lsp_definitions.clone();
            let options = options.clone();
            let output_tx = output_tx.clone();
            let done_count = done_count.clone();
            let file_snapshots = file_snapshots.clone();
            cx.background_spawn(async move {
                let snapshot = project_file.snapshot;

                let full_range = 0..snapshot.len();
                let references = references_in_range(
                    full_range,
                    &snapshot.text(),
                    ReferenceRegion::Nearby,
                    &snapshot,
                );

                println!("references: {}", references.len(),);

                let imports = if options.context.use_imports {
                    Imports::gather(&snapshot, Some(&project_file.parent_abs_path))
                } else {
                    Imports::default()
                };

                let path = snapshot.file().unwrap().path();

                for reference in references {
                    let query_point = snapshot.offset_to_point(reference.range.start);
                    let source_location = SourceLocation {
                        path: path.clone(),
                        point: query_point,
                    };
                    let lsp_definitions = lsp_definitions
                        .definitions
                        .get(&source_location)
                        .cloned()
                        .unwrap_or_else(|| {
                            log::warn!(
                                "No definitions found for source location: {:?}",
                                source_location
                            );
                            Vec::new()
                        });

                    let retrieve_result = retrieve_definitions(
                        &reference,
                        &imports,
                        query_point,
                        &snapshot,
                        &index_state,
                        &file_snapshots,
                        &options,
                    )
                    .await?;

                    // TODO: LSP returns things like locals, this filters out some of those, but potentially
                    // hides some retrieval issues.
                    if retrieve_result.definitions.is_empty() {
                        continue;
                    }

                    // Find the best-ranked retrieved definition that agrees
                    // with any LSP definition, and note whether the LSP target
                    // was external or already inside the cursor excerpt.
                    let mut best_match = None;
                    let mut has_external_definition = false;
                    let mut in_excerpt = false;
                    for (index, retrieved_definition) in
                        retrieve_result.definitions.iter().enumerate()
                    {
                        for lsp_definition in &lsp_definitions {
                            let SourceRange {
                                path,
                                point_range,
                                offset_range,
                            } = lsp_definition;
                            let lsp_point_range =
                                SerializablePoint::into_language_point_range(point_range.clone());
                            has_external_definition = has_external_definition
                                || path.is_absolute()
                                || path
                                    .components()
                                    .any(|component| component.as_os_str() == "node_modules");
                            let is_match = path.as_path()
                                == retrieved_definition.path.as_std_path()
                                && retrieved_definition
                                    .range
                                    .contains_inclusive(&lsp_point_range);
                            if is_match {
                                if best_match.is_none() {
                                    best_match = Some(index);
                                }
                            }
                            in_excerpt = in_excerpt
                                || retrieve_result.excerpt_range.as_ref().is_some_and(
                                    |excerpt_range| excerpt_range.contains_inclusive(&offset_range),
                                );
                        }
                    }

                    let outcome = if let Some(best_match) = best_match {
                        RetrievalOutcome::Match { best_match }
                    } else if has_external_definition {
                        RetrievalOutcome::NoMatchDueToExternalLspDefinitions
                    } else if in_excerpt {
                        RetrievalOutcome::ProbablyLocal
                    } else {
                        RetrievalOutcome::NoMatch
                    };

                    let result = RetrievalStatsResult {
                        outcome,
                        path: path.clone(),
                        identifier: reference.identifier,
                        point: query_point,
                        lsp_definitions,
                        retrieved_definitions: retrieve_result.definitions,
                    };

                    output_tx.unbounded_send(result).ok();
                }

                println!(
                    "{:02}/{:02} done",
                    done_count.fetch_add(1, atomic::Ordering::Relaxed) + 1,
                    files_len,
                );

                anyhow::Ok(())
            })
        })
        .collect::<Vec<_>>();

    // Drop the original sender so the channel closes once all per-file task
    // clones are dropped, terminating the collector loop below.
    drop(output_tx);

    let results_task = cx.background_spawn(async move {
        let mut results = Vec::new();
        while let Some(result) = output_rx.next().await {
            output
                .write_all(format!("{:#?}\n", result).as_bytes())
                .log_err();
            results.push(result)
        }
        results
    });

    futures::future::try_join_all(tasks).await?;
    println!("Tasks completed");
    let results = results_task.await;
    println!("Results received");

    // Aggregate outcome counts for the summary tree printed below.
    let mut references_count = 0;

    let mut included_count = 0;
    let mut both_absent_count = 0;

    let mut retrieved_count = 0;
    let mut top_match_count = 0;
    let mut non_top_match_count = 0;
    let mut ranking_involved_top_match_count = 0;

    let mut no_match_count = 0;
    let mut no_match_none_retrieved = 0;
    let mut no_match_wrong_retrieval = 0;

    let mut expected_no_match_count = 0;
    let mut in_excerpt_count = 0;
    let mut external_definition_count = 0;

    for result in results {
        references_count += 1;
        match &result.outcome {
            RetrievalOutcome::Match { best_match } => {
                included_count += 1;
                retrieved_count += 1;
                let multiple = result.retrieved_definitions.len() > 1;
                if *best_match == 0 {
                    top_match_count += 1;
                    if multiple {
                        ranking_involved_top_match_count += 1;
                    }
                } else {
                    non_top_match_count += 1;
                }
            }
            RetrievalOutcome::NoMatch => {
                if result.lsp_definitions.is_empty() {
                    included_count += 1;
                    both_absent_count += 1;
                } else {
                    no_match_count += 1;
                    if result.retrieved_definitions.is_empty() {
                        no_match_none_retrieved += 1;
                    } else {
                        no_match_wrong_retrieval += 1;
                    }
                }
            }
            RetrievalOutcome::NoMatchDueToExternalLspDefinitions => {
                expected_no_match_count += 1;
                external_definition_count += 1;
            }
            RetrievalOutcome::ProbablyLocal => {
                included_count += 1;
                in_excerpt_count += 1;
            }
        }
    }

    // Formats "N (P%)" entries for the summary tree.
    fn count_and_percentage(part: usize, total: usize) -> String {
        format!("{} ({:.2}%)", part, (part as f64 / total as f64) * 100.0)
    }

    println!("");
    println!("╮ references: {}", references_count);
    println!(
        "├─╮ included: {}",
        count_and_percentage(included_count, references_count),
    );
    println!(
        "│ ├─╮ retrieved: {}",
        count_and_percentage(retrieved_count, references_count)
    );
    println!(
        "│ │ ├─╮ top match : {}",
        count_and_percentage(top_match_count, retrieved_count)
    );
    println!(
        "│ │ │ ╰─╴ involving ranking: {}",
        count_and_percentage(ranking_involved_top_match_count, top_match_count)
    );
    println!(
        "│ │ ╰─╴ non-top match: {}",
        count_and_percentage(non_top_match_count, retrieved_count)
    );
    println!(
        "│ ├─╴ both absent: {}",
        count_and_percentage(both_absent_count, included_count)
    );
    println!(
        "│ ╰─╴ in excerpt: {}",
        count_and_percentage(in_excerpt_count, included_count)
    );
    println!(
        "├─╮ no match: {}",
        count_and_percentage(no_match_count, references_count)
    );
    println!(
        "│ ├─╴ none retrieved: {}",
        count_and_percentage(no_match_none_retrieved, no_match_count)
    );
    println!(
        "│ ╰─╴ wrong retrieval: {}",
        count_and_percentage(no_match_wrong_retrieval, no_match_count)
    );
    println!(
        "╰─╮ expected no match: {}",
        count_and_percentage(expected_no_match_count, references_count)
    );
    println!(
        " ╰─╴ external definition: {}",
        count_and_percentage(external_definition_count, expected_no_match_count)
    );

    println!("");
    println!("LSP definition cache at {}", lsp_definitions_path.display());

    Ok("".to_string())
}
832
/// Result of `retrieve_definitions`: the scored definitions plus the byte
/// range of the cursor excerpt, when context gathering succeeded.
struct RetrieveResult {
    definitions: Vec<RetrievedDefinition>,
    excerpt_range: Option<Range<usize>>,
}
837
/// Runs zeta2's declaration retrieval for a single reference and converts the
/// scored declarations into `RetrievedDefinition`s sorted by descending score.
async fn retrieve_definitions(
    reference: &Reference,
    imports: &Imports,
    query_point: Point,
    snapshot: &BufferSnapshot,
    index: &Arc<SyntaxIndexState>,
    file_snapshots: &Arc<HashMap<ProjectEntryId, BufferSnapshot>>,
    options: &Arc<zeta2::ZetaOptions>,
) -> Result<RetrieveResult> {
    // Restrict context gathering to just this one reference so the resulting
    // scores are attributable to it.
    let mut single_reference_map = HashMap::default();
    single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]);
    let edit_prediction_context = EditPredictionContext::gather_context_with_references_fn(
        query_point,
        snapshot,
        imports,
        &options.context,
        Some(&index),
        |_, _, _| single_reference_map,
    );

    // No excerpt could be built around the query point; report nothing found.
    let Some(edit_prediction_context) = edit_prediction_context else {
        return Ok(RetrieveResult {
            definitions: Vec::new(),
            excerpt_range: None,
        });
    };

    let mut retrieved_definitions = Vec::new();
    for scored_declaration in edit_prediction_context.declarations {
        match &scored_declaration.declaration {
            // Declaration known only from the on-disk index; ranges are byte
            // offsets into that file's snapshot.
            Declaration::File {
                project_entry_id,
                declaration,
                ..
            } => {
                let Some(snapshot) = file_snapshots.get(&project_entry_id) else {
                    log::error!("bug: file project entry not found");
                    continue;
                };
                let path = snapshot.file().unwrap().path().clone();
                retrieved_definitions.push(RetrievedDefinition {
                    path,
                    range: snapshot.offset_to_point(declaration.item_range.start)
                        ..snapshot.offset_to_point(declaration.item_range.end),
                    score: scored_declaration.score(DeclarationStyle::Declaration),
                    retrieval_score: scored_declaration.retrieval_score(),
                    components: scored_declaration.components,
                });
            }
            // Declaration from an open buffer; ranges come from its rope.
            Declaration::Buffer {
                project_entry_id,
                rope,
                declaration,
                ..
            } => {
                let Some(snapshot) = file_snapshots.get(&project_entry_id) else {
                    // This case happens when dependency buffers have been opened by
                    // go-to-definition, resulting in single-file worktrees.
                    continue;
                };
                let path = snapshot.file().unwrap().path().clone();
                retrieved_definitions.push(RetrievedDefinition {
                    path,
                    range: rope.offset_to_point(declaration.item_range.start)
                        ..rope.offset_to_point(declaration.item_range.end),
                    score: scored_declaration.score(DeclarationStyle::Declaration),
                    retrieval_score: scored_declaration.retrieval_score(),
                    components: scored_declaration.components,
                });
            }
        }
    }
    // Highest-scoring definitions first.
    retrieved_definitions.sort_by_key(|definition| Reverse(OrderedFloat(definition.score)));

    Ok(RetrieveResult {
        definitions: retrieved_definitions,
        excerpt_range: Some(edit_prediction_context.excerpt.range),
    })
}
917
/// Collects language-server go-to-definition results for every reference in
/// `files`, keyed by each reference's source location.
///
/// Language-server errors are logged and counted rather than aborting the run.
async fn gather_lsp_definitions(
    files: &[ProjectPath],
    worktree: &Entity<Worktree>,
    project: &Entity<Project>,
    cx: &mut AsyncApp,
) -> Result<LspResults> {
    let worktree_id = worktree.read_with(cx, |worktree, _cx| worktree.id())?;

    // Surface language-server work-progress messages so long indexing runs
    // are visible on the console.
    let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
    cx.subscribe(&lsp_store, {
        move |_, event, _| {
            if let project::LspStoreEvent::LanguageServerUpdate {
                message:
                    client::proto::update_language_server::Variant::WorkProgress(
                        client::proto::LspWorkProgress {
                            message: Some(message),
                            ..
                        },
                    ),
                ..
            } = event
            {
                println!("⟲ {message}")
            }
        }
    })?
    .detach();

    let mut definitions = HashMap::default();
    let mut error_count = 0;
    // Keep buffers registered with their language servers for the whole run.
    let mut lsp_open_handles = Vec::new();
    let mut ready_languages = HashSet::default();
    for (file_index, project_path) in files.iter().enumerate() {
        println!(
            "Processing file {} of {}: {}",
            file_index + 1,
            files.len(),
            project_path.path.display(PathStyle::Posix)
        );

        let Some((lsp_open_handle, language_server_id, buffer)) = open_buffer_with_language_server(
            project.clone(),
            worktree.clone(),
            project_path.path.clone(),
            &mut ready_languages,
            cx,
        )
        .await
        .log_err() else {
            continue;
        };
        lsp_open_handles.push(lsp_open_handle);

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
        let full_range = 0..snapshot.len();
        let references = references_in_range(
            full_range,
            &snapshot.text(),
            ReferenceRegion::Nearby,
            &snapshot,
        );

        // Poll until the server reports no pending work, so definitions are
        // queried against a fully indexed state.
        loop {
            let is_ready = lsp_store
                .read_with(cx, |lsp_store, _cx| {
                    lsp_store
                        .language_server_statuses
                        .get(&language_server_id)
                        .is_some_and(|status| status.pending_work.is_empty())
                })
                .unwrap();
            if is_ready {
                break;
            }
            cx.background_executor()
                .timer(Duration::from_millis(10))
                .await;
        }

        for reference in references {
            // TODO: Rename declaration to definition in edit_prediction_context?
            let lsp_result = project
                .update(cx, |project, cx| {
                    project.definitions(&buffer, reference.range.start, cx)
                })?
                .await;

            match lsp_result {
                Ok(lsp_definitions) => {
                    let mut targets = Vec::new();
                    for target in lsp_definitions.unwrap_or_default() {
                        let buffer = target.target.buffer;
                        let anchor_range = target.target.range;
                        buffer.read_with(cx, |buffer, cx| {
                            // Targets without a project file (e.g. untitled
                            // buffers) are skipped.
                            let Some(file) = project::File::from_dyn(buffer.file()) else {
                                return;
                            };
                            let file_worktree = file.worktree.read(cx);
                            let file_worktree_id = file_worktree.id();
                            // Relative paths for worktree files, absolute for all others
                            let path = if worktree_id != file_worktree_id {
                                file.worktree.read(cx).absolutize(&file.path)
                            } else {
                                file.path.as_std_path().to_path_buf()
                            };
                            let offset_range = anchor_range.to_offset(&buffer);
                            let point_range = SerializablePoint::from_language_point_range(
                                offset_range.to_point(&buffer),
                            );
                            targets.push(SourceRange {
                                path,
                                offset_range,
                                point_range,
                            });
                        })?;
                    }

                    definitions.insert(
                        SourceLocation {
                            path: project_path.path.clone(),
                            point: snapshot.offset_to_point(reference.range.start),
                        },
                        targets,
                    );
                }
                Err(err) => {
                    log::error!("Language server error: {err}");
                    error_count += 1;
                }
            }
        }
    }

    log::error!("Encountered {} language server errors", error_count);

    Ok(LspResults { definitions })
}
1055
/// On-disk cache of LSP go-to-definition results, keyed by the location of
/// each queried reference. Serialized transparently as the inner map.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(transparent)]
struct LspResults {
    definitions: HashMap<SourceLocation, Vec<SourceRange>>,
}
1061
/// A definition target: a path plus the same span expressed both as a
/// (1-based, serializable) point range and a byte offset range.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SourceRange {
    // Relative for files inside the queried worktree, absolute otherwise
    // (see `gather_lsp_definitions`).
    path: PathBuf,
    point_range: Range<SerializablePoint>,
    offset_range: Range<usize>,
}
1068
/// Serializes to 1-based row and column indices.
///
/// Language `Point`s are 0-based; the `From` impls below shift by one in
/// each direction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializablePoint {
    /// 1-based line number.
    pub row: u32,
    /// 1-based column number.
    pub column: u32,
}
1075
1076impl SerializablePoint {
1077 pub fn into_language_point_range(range: Range<Self>) -> Range<Point> {
1078 range.start.into()..range.end.into()
1079 }
1080
1081 pub fn from_language_point_range(range: Range<Point>) -> Range<Self> {
1082 range.start.into()..range.end.into()
1083 }
1084}
1085
1086impl From<Point> for SerializablePoint {
1087 fn from(point: Point) -> Self {
1088 SerializablePoint {
1089 row: point.row + 1,
1090 column: point.column + 1,
1091 }
1092 }
1093}
1094
1095impl From<SerializablePoint> for Point {
1096 fn from(serializable: SerializablePoint) -> Self {
1097 Point {
1098 row: serializable.row.saturating_sub(1),
1099 column: serializable.column.saturating_sub(1),
1100 }
1101 }
1102}
1103
/// Per-reference comparison between zeta2 retrieval and LSP definitions.
///
/// Fields marked `#[allow(dead_code)]` are only read through the `Debug`
/// output written to the stats file.
#[derive(Debug)]
struct RetrievalStatsResult {
    outcome: RetrievalOutcome,
    #[allow(dead_code)]
    path: Arc<RelPath>,
    #[allow(dead_code)]
    identifier: Identifier,
    #[allow(dead_code)]
    point: Point,
    #[allow(dead_code)]
    lsp_definitions: Vec<SourceRange>,
    retrieved_definitions: Vec<RetrievedDefinition>,
}
1117
/// Classification of how a reference's retrieval compared to LSP ground truth.
#[derive(Debug)]
enum RetrievalOutcome {
    // A retrieved definition agreed with an LSP definition.
    Match {
        /// Lowest index within retrieved_definitions that matches an LSP definition.
        best_match: usize,
    },
    // The LSP target fell inside the cursor excerpt, so it is likely a local.
    ProbablyLocal,
    // LSP had definitions but none were retrieved (or the wrong ones were).
    NoMatch,
    // LSP pointed outside the worktree (absolute path or node_modules), which
    // the syntax index cannot retrieve by design.
    NoMatchDueToExternalLspDefinitions,
}
1128
/// A declaration produced by zeta2 retrieval, with its scores.
///
/// Fields marked `#[allow(dead_code)]` are only read through `Debug` output.
#[derive(Debug)]
struct RetrievedDefinition {
    path: Arc<RelPath>,
    range: Range<Point>,
    // Declaration-style score used for ranking (see `retrieve_definitions`).
    score: f32,
    #[allow(dead_code)]
    retrieval_score: f32,
    #[allow(dead_code)]
    components: DeclarationScoreComponents,
}
1139
/// Opens `path` from `worktree` as a project buffer and waits until it has
/// been fully parsed before returning it.
pub fn open_buffer(
    project: Entity<Project>,
    worktree: Entity<Worktree>,
    path: Arc<RelPath>,
    cx: &AsyncApp,
) -> Task<Result<Entity<Buffer>>> {
    cx.spawn(async move |cx| {
        let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
            worktree_id: worktree.id(),
            path,
        })?;

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
            .await?;

        // Block until parsing settles so callers can rely on syntax data
        // being available in the snapshot.
        let mut parse_status = buffer.read_with(cx, |buffer, _cx| buffer.parse_status())?;
        while *parse_status.borrow() != ParseStatus::Idle {
            parse_status.changed().await?;
        }

        Ok(buffer)
    })
}
1164
/// Opens a buffer and ensures a language server is attached to it.
///
/// Returns the LSP registration handle (keep it alive to keep the buffer
/// registered with its language servers), the id of one language server
/// attached to the buffer, and the buffer itself.
///
/// `ready_languages` records languages whose server readiness has already
/// been awaited, so the slow [`wait_for_lang_server`] step runs at most once
/// per language.
pub async fn open_buffer_with_language_server(
    project: Entity<Project>,
    worktree: Entity<Worktree>,
    path: Arc<RelPath>,
    ready_languages: &mut HashSet<LanguageId>,
    cx: &mut AsyncApp,
) -> Result<(Entity<Entity<Buffer>>, LanguageServerId, Entity<Buffer>)> {
    let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;

    // Register the buffer with matching language servers; the returned handle
    // keeps that registration alive for as long as the caller holds it.
    let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
        (
            project.register_buffer_with_language_servers(&buffer, cx),
            project.path_style(cx),
        )
    })?;

    // A buffer with no recognized language can never get a language server.
    let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
        buffer.language().map(|language| language.id())
    })?
    else {
        return Err(anyhow!("No language for {}", path.display(path_style)));
    };

    let log_prefix = path.display(path_style);
    // Only wait for full server readiness the first time each language is seen.
    if !ready_languages.contains(&language_id) {
        wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
        ready_languages.insert(language_id);
    }

    let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;

    // hacky wait for buffer to be registered with the language server
    // Polls up to 100 times with a 10ms sleep (~1s total) until some language
    // server reports this buffer.
    for _ in 0..100 {
        let Some(language_server_id) = lsp_store.update(cx, |lsp_store, cx| {
            buffer.update(cx, |buffer, cx| {
                lsp_store
                    .language_servers_for_local_buffer(&buffer, cx)
                    .next()
                    .map(|(_, language_server)| language_server.server_id())
            })
        })?
        else {
            cx.background_executor()
                .timer(Duration::from_millis(10))
                .await;
            continue;
        };

        return Ok((lsp_open_handle, language_server_id, buffer));
    }

    return Err(anyhow!("No language server found for buffer"));
}
1218
// TODO: Dedupe with similar function in crates/eval/src/instance.rs
/// Returns a task that resolves once the buffer's language server has gone
/// idle (signalled by `DiskBasedDiagnosticsFinished`), or fails on timeout.
///
/// If no server is attached yet, first waits (up to 5s) for one to be added;
/// then waits (up to 5 minutes) for the diagnostics pass to finish. Progress
/// messages from the server are echoed to stdout with `log_prefix`.
pub fn wait_for_lang_server(
    project: &Entity<Project>,
    buffer: &Entity<Buffer>,
    log_prefix: String,
    cx: &mut AsyncApp,
) -> Task<Result<()>> {
    println!("{}⏵ Waiting for language server", log_prefix);

    // Signals completion of a disk-based diagnostics pass.
    let (mut tx, mut rx) = mpsc::channel(1);

    let lsp_store = project
        .read_with(cx, |project, _| project.lsp_store())
        .unwrap();

    // Is a language server already attached to this buffer?
    let has_lang_server = buffer
        .update(cx, |buffer, cx| {
            lsp_store.update(cx, |lsp_store, cx| {
                lsp_store
                    .language_servers_for_local_buffer(buffer, cx)
                    .next()
                    .is_some()
            })
        })
        .unwrap_or(false);

    if has_lang_server {
        // Saving the buffer appears to kick off the diagnostics pass we wait
        // on below. NOTE(review): confirm this is the intended trigger.
        project
            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
            .unwrap()
            .detach();
    }
    // Signals that a language server was added for the buffer.
    let (mut added_tx, mut added_rx) = mpsc::channel(1);

    let subscriptions = [
        // Echo server work-progress messages for visibility.
        cx.subscribe(&lsp_store, {
            let log_prefix = log_prefix.clone();
            move |_, event, _| {
                if let project::LspStoreEvent::LanguageServerUpdate {
                    message:
                        client::proto::update_language_server::Variant::WorkProgress(
                            client::proto::LspWorkProgress {
                                message: Some(message),
                                ..
                            },
                        ),
                    ..
                } = event
                {
                    println!("{}⟲ {message}", log_prefix)
                }
            }
        }),
        cx.subscribe(project, {
            let buffer = buffer.clone();
            move |project, event, cx| match event {
                project::Event::LanguageServerAdded(_, _, _) => {
                    // Server just attached: save the buffer (same trigger as
                    // above) and notify the waiter below.
                    let buffer = buffer.clone();
                    project
                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
                        .detach();
                    added_tx.try_send(()).ok();
                }
                project::Event::DiskBasedDiagnosticsFinished { .. } => {
                    tx.try_send(()).ok();
                }
                _ => {}
            }
        }),
    ];

    cx.spawn(async move |cx| {
        if !has_lang_server {
            // some buffers never have a language server, so this aborts quickly in that case.
            let timeout = cx.background_executor().timer(Duration::from_secs(5));
            futures::select! {
                _ = added_rx.next() => {},
                _ = timeout.fuse() => {
                    anyhow::bail!("Waiting for language server add timed out after 5 seconds");
                }
            };
        }
        // Wait for the diagnostics pass to finish, bounded at 5 minutes.
        let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
        let result = futures::select! {
            _ = rx.next() => {
                println!("{}⚑ Language server idle", log_prefix);
                anyhow::Ok(())
            },
            _ = timeout.fuse() => {
                anyhow::bail!("LSP wait timed out after 5 minutes");
            }
        };
        // Keep subscriptions alive until we are done waiting.
        drop(subscriptions);
        result
    })
}
1315
/// CLI entry point: parses arguments, boots a headless GPUI application with
/// an HTTP client, dispatches the chosen subcommand, prints its string output
/// on success, and exits with status 1 on failure.
fn main() {
    zlog::init();
    zlog::init_output_stderr();
    let args = ZetaCliArgs::parse();
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        cx.spawn(async move |cx| {
            // Each arm produces Result<String>; the string is printed below.
            let result = match args.command {
                Commands::Zeta2Context {
                    zeta2_args,
                    context_args,
                } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
                    // Passing Some(zeta2_args) guarantees a Zeta2 output.
                    Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
                    Ok(GetContextOutput::Zeta2(output)) => Ok(output),
                    Err(err) => Err(err),
                },
                Commands::Context(context_args) => {
                    match get_context(None, context_args, &app_state, cx).await {
                        Ok(GetContextOutput::Zeta1(output)) => {
                            Ok(serde_json::to_string_pretty(&output.body).unwrap())
                        }
                        // Passing None guarantees a Zeta1 output.
                        Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
                        Err(err) => Err(err),
                    }
                }
                Commands::Predict {
                    predict_edits_body,
                    context_args,
                } => {
                    cx.spawn(async move |cx| {
                        let app_version = cx.update(|cx| AppVersion::global(cx))?;
                        // Sign in and obtain an LLM API token before predicting.
                        app_state.client.sign_in(true, cx).await?;
                        let llm_token = LlmApiToken::default();
                        llm_token.refresh(&app_state.client).await?;

                        // Request body comes either from a file or is built
                        // from `context` arguments.
                        let predict_edits_body =
                            if let Some(predict_edits_body) = predict_edits_body {
                                serde_json::from_str(&predict_edits_body.read_to_string().await?)?
                            } else if let Some(context_args) = context_args {
                                match get_context(None, context_args, &app_state, cx).await? {
                                    GetContextOutput::Zeta1(output) => output.body,
                                    GetContextOutput::Zeta2 { .. } => unreachable!(),
                                }
                            } else {
                                return Err(anyhow!(
                                    "Expected either --predict-edits-body-file \
                                    or the required args of the `context` command."
                                ));
                            };

                        let (response, _usage) =
                            Zeta::perform_predict_edits(PerformPredictEditsParams {
                                client: app_state.client.clone(),
                                llm_token,
                                app_version,
                                body: predict_edits_body,
                            })
                            .await?;

                        Ok(response.output_excerpt)
                    })
                    .await
                }
                Commands::RetrievalStats {
                    zeta2_args,
                    worktree,
                    extension,
                    limit,
                    skip,
                } => {
                    retrieval_stats(
                        worktree,
                        app_state,
                        extension,
                        limit,
                        skip,
                        (&zeta2_args).to_options(false),
                        cx,
                    )
                    .await
                }
            };
            // Print the subcommand's output and quit, or report the error and
            // exit with a failing status.
            match result {
                Ok(output) => {
                    println!("{}", output);
                    let _ = cx.update(|cx| cx.quit());
                }
                Err(e) => {
                    eprintln!("Failed: {:?}", e);
                    exit(1);
                }
            }
        })
        .detach();
    });
}