1mod headless;
2
3use anyhow::{Context as _, Result, anyhow};
4use clap::{Args, Parser, Subcommand};
5use cloud_llm_client::predict_edits_v3::{self, DeclarationScoreComponents};
6use edit_prediction_context::{
7 Declaration, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
8 EditPredictionExcerptOptions, EditPredictionScoreOptions, Identifier, Imports, Reference,
9 ReferenceRegion, SyntaxIndex, SyntaxIndexState, references_in_range,
10};
11use futures::channel::mpsc;
12use futures::{FutureExt as _, StreamExt as _};
13use gpui::{AppContext, Application, AsyncApp};
14use gpui::{Entity, Task};
15use language::{Bias, BufferSnapshot, LanguageServerId, Point};
16use language::{Buffer, OffsetRangeExt};
17use language::{LanguageId, ParseStatus};
18use language_model::LlmApiToken;
19use ordered_float::OrderedFloat;
20use project::{Project, ProjectEntryId, ProjectPath, Worktree};
21use release_channel::AppVersion;
22use reqwest_client::ReqwestClient;
23use serde::{Deserialize, Deserializer, Serialize, Serializer};
24use serde_json::json;
25use std::cmp::Reverse;
26use std::collections::{HashMap, HashSet};
27use std::fmt::{self, Display};
28use std::fs::File;
29use std::hash::Hash;
30use std::hash::Hasher;
31use std::io::{BufRead, BufReader, BufWriter, Write as _};
32use std::ops::Range;
33use std::path::{Path, PathBuf};
34use std::process::exit;
35use std::str::FromStr;
36use std::sync::atomic::AtomicUsize;
37use std::sync::{Arc, atomic};
38use std::time::Duration;
39use util::paths::PathStyle;
40use util::rel_path::RelPath;
41use util::{RangeExt, ResultExt as _};
42use zeta::{PerformPredictEditsParams, Zeta};
43
44use crate::headless::ZetaCliAppState;
45
/// Top-level command-line arguments for the `zeta` binary.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    /// Which subcommand to run; see [`Commands`].
    #[command(subcommand)]
    command: Commands,
}
52
/// Subcommands of the `zeta` CLI.
#[derive(Subcommand, Debug)]
enum Commands {
    /// Gather zeta1-style edit-prediction context for a cursor position.
    Context(ContextArgs),
    /// Gather zeta2-style context (syntax index + planned prompt) for a cursor position.
    Zeta2Context {
        #[clap(flatten)]
        zeta2_args: Zeta2Args,
        #[clap(flatten)]
        context_args: ContextArgs,
    },
    /// Request an edit prediction, either from a prebuilt request body or from
    /// freshly gathered context.
    Predict {
        // JSON request body to send directly; `-` reads from stdin.
        #[arg(long)]
        predict_edits_body: Option<FileOrStdin>,
        #[clap(flatten)]
        context_args: Option<ContextArgs>,
    },
    /// Compare syntax-index retrieval against LSP go-to-definition across a worktree.
    RetrievalStats {
        #[clap(flatten)]
        zeta2_args: Zeta2Args,
        // Root directory of the worktree to analyze.
        #[arg(long)]
        worktree: PathBuf,
        // Restrict analysis to files with this extension (e.g. "rs").
        #[arg(long)]
        extension: Option<String>,
        // Process at most this many files.
        #[arg(long)]
        limit: Option<usize>,
        // Skip this many files from the start of the (sorted) file list.
        #[arg(long)]
        skip: Option<usize>,
    },
}
81
/// Arguments describing where to gather context: a worktree plus a cursor position.
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Root directory of the worktree to open.
    #[arg(long)]
    worktree: PathBuf,
    // Cursor position as `path:line:column` (1-based line/column).
    #[arg(long)]
    cursor: SourceLocation,
    // When set, start a language server for the buffer before gathering context.
    #[arg(long)]
    use_language_server: bool,
    // Edit-history events to include; `-` reads from stdin.
    #[arg(long)]
    events: Option<FileOrStdin>,
}
94
/// Tuning knobs for the zeta2 context-gathering pipeline.
/// Converted into `zeta2::ZetaOptions` via [`Zeta2Args::to_options`].
#[derive(Debug, Args)]
struct Zeta2Args {
    // Upper bound on the total prompt size, in bytes.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // Upper bound on the cursor excerpt size, in bytes.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // Lower bound on the cursor excerpt size, in bytes.
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // Fraction of the excerpt budget placed before the cursor (0.0..=1.0).
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // Upper bound on included diagnostics, in bytes.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // Number of files indexed concurrently by the syntax index.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
    // Disables import gathering (sets `use_imports: false` in the options).
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
}
116
/// CLI-side mirror of `predict_edits_v3::PromptFormat`, so it can derive
/// `clap::ValueEnum`.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum PromptFormat {
    #[default]
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
}
124
125impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
126 fn into(self) -> predict_edits_v3::PromptFormat {
127 match self {
128 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
129 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
130 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
131 }
132 }
133}
134
/// What the `zeta2-context` subcommand prints.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    // Just the rendered prompt string.
    #[default]
    Prompt,
    // The JSON request body sent to the model.
    Request,
    // JSON object containing the request, the prompt, and section labels.
    Full,
}
142
/// A CLI input source: a file path, or stdin when the argument is `-`.
#[derive(Debug, Clone)]
enum FileOrStdin {
    File(PathBuf),
    Stdin,
}
148
149impl FileOrStdin {
150 async fn read_to_string(&self) -> Result<String, std::io::Error> {
151 match self {
152 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
153 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
154 }
155 }
156}
157
158impl FromStr for FileOrStdin {
159 type Err = <PathBuf as FromStr>::Err;
160
161 fn from_str(s: &str) -> Result<Self, Self::Err> {
162 match s {
163 "-" => Ok(Self::Stdin),
164 _ => Ok(Self::File(PathBuf::from_str(s)?)),
165 }
166 }
167}
168
/// A worktree-relative path plus a 0-based cursor position within that file.
/// Serialized/displayed in the 1-based `path:line:column` form.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct SourceLocation {
    path: Arc<RelPath>,
    point: Point,
}
174
175impl Serialize for SourceLocation {
176 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
177 where
178 S: Serializer,
179 {
180 serializer.serialize_str(&self.to_string())
181 }
182}
183
184impl<'de> Deserialize<'de> for SourceLocation {
185 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
186 where
187 D: Deserializer<'de>,
188 {
189 let s = String::deserialize(deserializer)?;
190 s.parse().map_err(serde::de::Error::custom)
191 }
192}
193
194impl Display for SourceLocation {
195 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196 write!(
197 f,
198 "{}:{}:{}",
199 self.path.display(PathStyle::Posix),
200 self.point.row + 1,
201 self.point.column + 1
202 )
203 }
204}
205
206impl FromStr for SourceLocation {
207 type Err = anyhow::Error;
208
209 fn from_str(s: &str) -> Result<Self> {
210 let parts: Vec<&str> = s.split(':').collect();
211 if parts.len() != 3 {
212 return Err(anyhow!(
213 "Invalid source location. Expected 'file.rs:line:column', got '{}'",
214 s
215 ));
216 }
217
218 let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
219 let line: u32 = parts[1]
220 .parse()
221 .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
222 let column: u32 = parts[2]
223 .parse()
224 .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
225
226 // Convert from 1-based to 0-based indexing
227 let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
228
229 Ok(SourceLocation { path, point })
230 }
231}
232
/// Result of `get_context`: either the structured zeta1 output or the
/// already-formatted zeta2 output string.
enum GetContextOutput {
    Zeta1(zeta::GatherContextOutput),
    Zeta2(String),
}
237
/// Gathers edit-prediction context for one cursor position in a worktree.
///
/// Opens a local project over `args.worktree`, opens the buffer at
/// `args.cursor.path` (optionally attaching a language server), validates the
/// cursor, then runs either the zeta2 pipeline (when `zeta2_args` is `Some`)
/// or the zeta1 `gather_context` path.
///
/// Returns an error if the worktree path cannot be canonicalized or the
/// cursor position lies outside the buffer.
async fn get_context(
    zeta2_args: Option<Zeta2Args>,
    args: ContextArgs,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<GetContextOutput> {
    let ContextArgs {
        worktree: worktree_path,
        cursor,
        use_language_server,
        events,
    } = args;

    let worktree_path = worktree_path.canonicalize()?;

    // Spin up an in-process project over the target worktree.
    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;

    // The language-server handle must be kept alive for the buffer to stay
    // registered with the server, hence the `_lsp_open_handle` binding.
    let mut ready_languages = HashSet::default();
    let (_lsp_open_handle, buffer) = if use_language_server {
        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
            project.clone(),
            worktree.clone(),
            cursor.path.clone(),
            &mut ready_languages,
            cx,
        )
        .await?;
        (Some(lsp_open_handle), buffer)
    } else {
        let buffer =
            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
        (None, buffer)
    };

    let full_path_str = worktree
        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
        .display(PathStyle::local())
        .to_string();

    // Clipping only moves the point when it is out of bounds, so a changed
    // point means the given cursor was invalid; report which axis overflowed.
    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
    if clipped_cursor != cursor.point {
        let max_row = snapshot.max_point().row;
        if cursor.point.row < max_row {
            return Err(anyhow!(
                "Cursor position {:?} is out of bounds (line length is {})",
                cursor.point,
                snapshot.line_len(cursor.point.row)
            ));
        } else {
            return Err(anyhow!(
                "Cursor position {:?} is out of bounds (max row is {})",
                cursor.point,
                max_row
            ));
        }
    }

    let events = match events {
        Some(events) => events.read_to_string().await?,
        None => String::new(),
    };

    if let Some(zeta2_args) = zeta2_args {
        // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
        // the whole worktree.
        worktree
            .read_with(cx, |worktree, _cx| {
                worktree.as_local().unwrap().scan_complete()
            })?
            .await;
        let output = cx
            .update(|cx| {
                let zeta = cx.new(|cx| {
                    zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
                });
                let indexing_done_task = zeta.update(cx, |zeta, cx| {
                    zeta.set_options(zeta2_args.to_options(true));
                    zeta.register_buffer(&buffer, &project, cx);
                    zeta.wait_for_initial_indexing(&project, cx)
                });
                cx.spawn(async move |cx| {
                    indexing_done_task.await?;
                    let request = zeta
                        .update(cx, |zeta, cx| {
                            let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
                            zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
                        })?
                        .await?;

                    let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
                    let (prompt_string, section_labels) = planned_prompt.to_prompt_string()?;

                    // Which pieces get emitted depends on the CLI output format.
                    match zeta2_args.output_format {
                        OutputFormat::Prompt => anyhow::Ok(prompt_string),
                        OutputFormat::Request => {
                            anyhow::Ok(serde_json::to_string_pretty(&request)?)
                        }
                        OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
                            "request": request,
                            "prompt": prompt_string,
                            "section_labels": section_labels,
                        }))?),
                    }
                })
            })?
            .await?;
        Ok(GetContextOutput::Zeta2(output))
    } else {
        // NOTE(review): the second tuple element appears to be an event-count
        // or token-count placeholder expected by zeta::gather_context — confirm
        // against its signature.
        let prompt_for_events = move || (events, 0);
        Ok(GetContextOutput::Zeta1(
            cx.update(|cx| {
                zeta::gather_context(
                    full_path_str,
                    &snapshot,
                    clipped_cursor,
                    prompt_for_events,
                    cx,
                )
            })?
            .await?,
        ))
    }
}
378
impl Zeta2Args {
    /// Builds `zeta2::ZetaOptions` from the CLI flags.
    ///
    /// `omit_excerpt_overlaps` is chosen by the caller rather than exposed as
    /// a flag (context gathering passes `true`).
    fn to_options(&self, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions {
        zeta2::ZetaOptions {
            context: EditPredictionContextOptions {
                // The CLI flag is a negative ("disable"), the option is positive.
                use_imports: !self.disable_imports_gathering,
                excerpt: EditPredictionExcerptOptions {
                    max_bytes: self.max_excerpt_bytes,
                    min_bytes: self.min_excerpt_bytes,
                    target_before_cursor_over_total_bytes: self
                        .target_before_cursor_over_total_bytes,
                },
                score: EditPredictionScoreOptions {
                    omit_excerpt_overlaps,
                },
            },
            max_diagnostic_bytes: self.max_diagnostic_bytes,
            max_prompt_bytes: self.max_prompt_bytes,
            prompt_format: self.prompt_format.clone().into(),
            file_indexing_parallelism: self.file_indexing_parallelism,
        }
    }
}
401
/// Measures how well syntax-index retrieval agrees with LSP go-to-definition.
///
/// Pipeline:
/// 1. Open a project over `worktree` and build a `SyntaxIndex` for it.
/// 2. Snapshot every indexed file (filtered by `only_extension`, or by a
///    default extension denylist) and hash the combined contents.
/// 3. Load LSP definitions from a content-hash-keyed JSONL cache under
///    `target/`, truncating the cache at the first invalid/mismatched line;
///    gather any missing entries via `gather_lsp_definitions`.
/// 4. For each identifier reference in each file, retrieve definitions from
///    the syntax index and classify the result against the LSP answer.
/// 5. Print a summary tree of counts and write per-reference results to
///    `target/zeta-retrieval-stats.txt`.
///
/// `skip_files` / `file_limit` bound which files are processed in step 4.
/// The returned string is always empty; all reporting goes to stdout/disk.
pub async fn retrieval_stats(
    worktree: PathBuf,
    app_state: Arc<ZetaCliAppState>,
    only_extension: Option<String>,
    file_limit: Option<usize>,
    skip_files: Option<usize>,
    options: zeta2::ZetaOptions,
    cx: &mut AsyncApp,
) -> Result<String> {
    let options = Arc::new(options);
    let worktree_path = worktree.canonicalize()?;

    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;

    // wait for worktree scan so that wait_for_initial_file_indexing waits for the whole worktree.
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;

    let index = cx.new(|cx| SyntaxIndex::new(&project, options.file_indexing_parallelism, cx))?;
    index
        .read_with(cx, |index, cx| index.wait_for_initial_file_indexing(cx))?
        .await?;
    let indexed_files = index
        .read_with(cx, |index, cx| index.indexed_file_paths(cx))?
        .await;
    // With no explicit extension, skip non-code files by denylist.
    let mut filtered_files = indexed_files
        .into_iter()
        .filter(|project_path| {
            let file_extension = project_path.path.extension();
            if let Some(only_extension) = only_extension.as_ref() {
                file_extension.is_some_and(|extension| extension == only_extension)
            } else {
                file_extension
                    .is_some_and(|extension| !["md", "json", "sh", "diff"].contains(&extension))
            }
        })
        .collect::<Vec<_>>();
    // Sort for deterministic ordering; the LSP cache relies on this order.
    filtered_files.sort_by(|a, b| a.path.cmp(&b.path));

    // Drop the index entity so we hold the sole Arc to its state, then
    // unwrap it for lock-free sharing across the background tasks below.
    let index_state = index.read_with(cx, |index, _cx| index.state().clone())?;
    cx.update(|_| {
        drop(index);
    })?;
    let index_state = Arc::new(
        Arc::into_inner(index_state)
            .context("Index state had more than 1 reference")?
            .into_inner(),
    );

    // Per-file data captured up front: snapshot plus a content hash used to
    // key the on-disk LSP definition cache.
    struct FileSnapshot {
        project_entry_id: ProjectEntryId,
        snapshot: BufferSnapshot,
        hash: u64,
        parent_abs_path: Arc<Path>,
    }

    let files: Vec<FileSnapshot> = futures::future::try_join_all({
        filtered_files
            .iter()
            .map(|file| {
                let buffer_task =
                    open_buffer(project.clone(), worktree.clone(), file.path.clone(), cx);
                cx.spawn(async move |cx| {
                    let buffer = buffer_task.await?;
                    let (project_entry_id, parent_abs_path, snapshot) =
                        buffer.read_with(cx, |buffer, cx| {
                            let file = project::File::from_dyn(buffer.file()).unwrap();
                            let project_entry_id = file.project_entry_id().unwrap();
                            let mut parent_abs_path = file.worktree.read(cx).absolutize(&file.path);
                            if !parent_abs_path.pop() {
                                panic!("Invalid worktree path");
                            }

                            (project_entry_id, parent_abs_path, buffer.snapshot())
                        })?;

                    // Hash file contents off the main thread.
                    anyhow::Ok(
                        cx.background_spawn(async move {
                            let mut hasher = collections::FxHasher::default();
                            snapshot.text().hash(&mut hasher);
                            FileSnapshot {
                                project_entry_id,
                                snapshot,
                                hash: hasher.finish(),
                                parent_abs_path: parent_abs_path.into(),
                            }
                        })
                        .await,
                    )
                })
            })
            .collect::<Vec<_>>()
    })
    .await?;

    // Combine per-file hashes into one hash identifying this exact corpus.
    let mut file_snapshots = HashMap::default();
    let mut hasher = collections::FxHasher::default();
    for FileSnapshot {
        project_entry_id,
        snapshot,
        hash,
        ..
    } in &files
    {
        file_snapshots.insert(*project_entry_id, snapshot.clone());
        hash.hash(&mut hasher);
    }
    let files_hash = hasher.finish();
    let file_snapshots = Arc::new(file_snapshots);

    // Cache file name embeds the corpus hash, so edits invalidate the cache.
    let lsp_definitions_path = std::env::current_dir()?.join(format!(
        "target/zeta2-lsp-definitions-{:x}.jsonl",
        files_hash
    ));

    let mut lsp_definitions = HashMap::default();
    let mut lsp_files = 0;

    if std::fs::exists(&lsp_definitions_path)? {
        log::info!(
            "Using cached LSP definitions from {}",
            lsp_definitions_path.display()
        );

        let file = File::options()
            .read(true)
            .write(true)
            .open(&lsp_definitions_path)?;
        let lines = BufReader::new(&file).lines();
        // Byte offset of the last fully-valid cache line, used to truncate
        // the file when a corrupt or out-of-order line is found.
        let mut valid_len: usize = 0;

        // Cache lines are expected to be in the same (sorted) order as `files`.
        for (line, expected_file) in lines.zip(files.iter()) {
            let line = line?;
            let FileLspDefinitions { path, references } = match serde_json::from_str(&line) {
                Ok(ok) => ok,
                Err(_) => {
                    log::error!("Found invalid cache line. Truncating to #{lsp_files}.",);
                    file.set_len(valid_len as u64)?;
                    break;
                }
            };
            let expected_path = expected_file.snapshot.file().unwrap().path().as_unix_str();
            if expected_path != path.as_ref() {
                log::error!(
                    "Expected file #{} to be {expected_path}, but found {path}. Truncating to #{lsp_files}.",
                    lsp_files + 1
                );
                file.set_len(valid_len as u64)?;
                break;
            }
            for (point, ranges) in references {
                let Ok(path) = RelPath::new(Path::new(path.as_ref()), PathStyle::Posix) else {
                    log::warn!("Invalid path: {}", path);
                    continue;
                };
                lsp_definitions.insert(
                    SourceLocation {
                        path: path.into_arc(),
                        point: point.into(),
                    },
                    ranges,
                );
            }
            lsp_files += 1;
            // +1 accounts for the newline stripped by `lines()`.
            valid_len += line.len() + 1
        }
    }

    // Fill in whatever the cache didn't cover, appending from `lsp_files` on.
    if lsp_files < files.len() {
        if lsp_files == 0 {
            log::warn!(
                "No LSP definitions found, populating {}",
                lsp_definitions_path.display()
            );
        } else {
            log::warn!("{} files missing from LSP cache", files.len() - lsp_files);
        }

        gather_lsp_definitions(
            &lsp_definitions_path,
            lsp_files,
            &filtered_files,
            &worktree,
            &project,
            &mut lsp_definitions,
            cx,
        )
        .await?;
    }
    let files_len = files.len().min(file_limit.unwrap_or(usize::MAX));
    let done_count = Arc::new(AtomicUsize::new(0));

    // Background tasks stream per-reference results to a single writer task.
    let (output_tx, mut output_rx) = mpsc::unbounded::<RetrievalStatsResult>();
    let mut output = std::fs::File::create("target/zeta-retrieval-stats.txt")?;

    let tasks = files
        .into_iter()
        .skip(skip_files.unwrap_or(0))
        .take(file_limit.unwrap_or(usize::MAX))
        .map(|project_file| {
            let index_state = index_state.clone();
            let lsp_definitions = lsp_definitions.clone();
            let options = options.clone();
            let output_tx = output_tx.clone();
            let done_count = done_count.clone();
            let file_snapshots = file_snapshots.clone();
            cx.background_spawn(async move {
                let snapshot = project_file.snapshot;

                let full_range = 0..snapshot.len();
                let references = references_in_range(
                    full_range,
                    &snapshot.text(),
                    ReferenceRegion::Nearby,
                    &snapshot,
                );

                println!("references: {}", references.len(),);

                let imports = if options.context.use_imports {
                    Imports::gather(&snapshot, Some(&project_file.parent_abs_path))
                } else {
                    Imports::default()
                };

                let path = snapshot.file().unwrap().path();

                for reference in references {
                    let query_point = snapshot.offset_to_point(reference.range.start);
                    let source_location = SourceLocation {
                        path: path.clone(),
                        point: query_point,
                    };
                    let lsp_definitions = lsp_definitions
                        .get(&source_location)
                        .cloned()
                        .unwrap_or_else(|| {
                            log::warn!(
                                "No definitions found for source location: {:?}",
                                source_location
                            );
                            Vec::new()
                        });

                    let retrieve_result = retrieve_definitions(
                        &reference,
                        &imports,
                        query_point,
                        &snapshot,
                        &index_state,
                        &file_snapshots,
                        &options,
                    )
                    .await?;

                    // TODO: LSP returns things like locals, this filters out some of those, but potentially
                    // hides some retrieval issues.
                    if retrieve_result.definitions.is_empty() {
                        continue;
                    }

                    // Classify each retrieval against every LSP definition:
                    // record the best (lowest-index) matching retrieval, and
                    // note external (absolute / node_modules) definitions and
                    // definitions already inside the excerpt.
                    let mut best_match = None;
                    let mut has_external_definition = false;
                    let mut in_excerpt = false;
                    for (index, retrieved_definition) in
                        retrieve_result.definitions.iter().enumerate()
                    {
                        for lsp_definition in &lsp_definitions {
                            let SourceRange {
                                path,
                                point_range,
                                offset_range,
                            } = lsp_definition;
                            let lsp_point_range =
                                SerializablePoint::into_language_point_range(point_range.clone());
                            has_external_definition = has_external_definition
                                || path.is_absolute()
                                || path
                                    .components()
                                    .any(|component| component.as_os_str() == "node_modules");
                            let is_match = path.as_path()
                                == retrieved_definition.path.as_std_path()
                                && retrieved_definition
                                    .range
                                    .contains_inclusive(&lsp_point_range);
                            if is_match {
                                if best_match.is_none() {
                                    best_match = Some(index);
                                }
                            }
                            in_excerpt = in_excerpt
                                || retrieve_result.excerpt_range.as_ref().is_some_and(
                                    |excerpt_range| excerpt_range.contains_inclusive(&offset_range),
                                );
                        }
                    }

                    let outcome = if let Some(best_match) = best_match {
                        RetrievalOutcome::Match { best_match }
                    } else if has_external_definition {
                        RetrievalOutcome::NoMatchDueToExternalLspDefinitions
                    } else if in_excerpt {
                        RetrievalOutcome::ProbablyLocal
                    } else {
                        RetrievalOutcome::NoMatch
                    };

                    let result = RetrievalStatsResult {
                        outcome,
                        path: path.clone(),
                        identifier: reference.identifier,
                        point: query_point,
                        lsp_definitions,
                        retrieved_definitions: retrieve_result.definitions,
                    };

                    output_tx.unbounded_send(result).ok();
                }

                println!(
                    "{:02}/{:02} done",
                    done_count.fetch_add(1, atomic::Ordering::Relaxed) + 1,
                    files_len,
                );

                anyhow::Ok(())
            })
        })
        .collect::<Vec<_>>();

    // Drop the original sender so the receiver terminates once all tasks finish.
    drop(output_tx);

    let results_task = cx.background_spawn(async move {
        let mut results = Vec::new();
        while let Some(result) = output_rx.next().await {
            output
                .write_all(format!("{:#?}\n", result).as_bytes())
                .log_err();
            results.push(result)
        }
        results
    });

    futures::future::try_join_all(tasks).await?;
    println!("Tasks completed");
    let results = results_task.await;
    println!("Results received");

    // Tally outcomes for the summary tree printed below.
    let mut references_count = 0;

    let mut included_count = 0;
    let mut both_absent_count = 0;

    let mut retrieved_count = 0;
    let mut top_match_count = 0;
    let mut non_top_match_count = 0;
    let mut ranking_involved_top_match_count = 0;

    let mut no_match_count = 0;
    let mut no_match_none_retrieved = 0;
    let mut no_match_wrong_retrieval = 0;

    let mut expected_no_match_count = 0;
    let mut in_excerpt_count = 0;
    let mut external_definition_count = 0;

    for result in results {
        references_count += 1;
        match &result.outcome {
            RetrievalOutcome::Match { best_match } => {
                included_count += 1;
                retrieved_count += 1;
                let multiple = result.retrieved_definitions.len() > 1;
                if *best_match == 0 {
                    top_match_count += 1;
                    if multiple {
                        // Ranking only "mattered" when there was competition.
                        ranking_involved_top_match_count += 1;
                    }
                } else {
                    non_top_match_count += 1;
                }
            }
            RetrievalOutcome::NoMatch => {
                if result.lsp_definitions.is_empty() {
                    // Neither LSP nor retrieval found anything: counted as agreement.
                    included_count += 1;
                    both_absent_count += 1;
                } else {
                    no_match_count += 1;
                    if result.retrieved_definitions.is_empty() {
                        no_match_none_retrieved += 1;
                    } else {
                        no_match_wrong_retrieval += 1;
                    }
                }
            }
            RetrievalOutcome::NoMatchDueToExternalLspDefinitions => {
                expected_no_match_count += 1;
                external_definition_count += 1;
            }
            RetrievalOutcome::ProbablyLocal => {
                included_count += 1;
                in_excerpt_count += 1;
            }
        }
    }

    // Helper for the "count (percent%)" cells in the summary tree.
    fn count_and_percentage(part: usize, total: usize) -> String {
        format!("{} ({:.2}%)", part, (part as f64 / total as f64) * 100.0)
    }

    println!("");
    println!("╮ references: {}", references_count);
    println!(
        "├─╮ included: {}",
        count_and_percentage(included_count, references_count),
    );
    println!(
        "│ ├─╮ retrieved: {}",
        count_and_percentage(retrieved_count, references_count)
    );
    println!(
        "│ │ ├─╮ top match : {}",
        count_and_percentage(top_match_count, retrieved_count)
    );
    println!(
        "│ │ │ ╰─╴ involving ranking: {}",
        count_and_percentage(ranking_involved_top_match_count, top_match_count)
    );
    println!(
        "│ │ ╰─╴ non-top match: {}",
        count_and_percentage(non_top_match_count, retrieved_count)
    );
    println!(
        "│ ├─╴ both absent: {}",
        count_and_percentage(both_absent_count, included_count)
    );
    println!(
        "│ ╰─╴ in excerpt: {}",
        count_and_percentage(in_excerpt_count, included_count)
    );
    println!(
        "├─╮ no match: {}",
        count_and_percentage(no_match_count, references_count)
    );
    println!(
        "│ ├─╴ none retrieved: {}",
        count_and_percentage(no_match_none_retrieved, no_match_count)
    );
    println!(
        "│ ╰─╴ wrong retrieval: {}",
        count_and_percentage(no_match_wrong_retrieval, no_match_count)
    );
    println!(
        "╰─╮ expected no match: {}",
        count_and_percentage(expected_no_match_count, references_count)
    );
    println!(
        " ╰─╴ external definition: {}",
        count_and_percentage(external_definition_count, expected_no_match_count)
    );

    println!("");
    println!("LSP definition cache at {}", lsp_definitions_path.display());

    Ok("".to_string())
}
887
/// Output of `retrieve_definitions`: ranked candidate definitions plus the
/// byte range of the excerpt around the query point, when one was produced.
struct RetrieveResult {
    definitions: Vec<RetrievedDefinition>,
    excerpt_range: Option<Range<usize>>,
}
892
/// Retrieves candidate definitions for a single identifier reference from the
/// syntax index, returning them sorted by descending declaration score.
///
/// The reference-gathering callback is overridden so that context is computed
/// for exactly this one reference rather than everything near the cursor.
/// Declarations whose file snapshot is unknown are skipped (with an error log
/// for the `File` case, silently for single-file `Buffer` worktrees).
async fn retrieve_definitions(
    reference: &Reference,
    imports: &Imports,
    query_point: Point,
    snapshot: &BufferSnapshot,
    index: &Arc<SyntaxIndexState>,
    file_snapshots: &Arc<HashMap<ProjectEntryId, BufferSnapshot>>,
    options: &Arc<zeta2::ZetaOptions>,
) -> Result<RetrieveResult> {
    let mut single_reference_map = HashMap::default();
    single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]);
    let edit_prediction_context = EditPredictionContext::gather_context_with_references_fn(
        query_point,
        snapshot,
        imports,
        &options.context,
        Some(&index),
        |_, _, _| single_reference_map,
    );

    // No excerpt could be built around the query point: empty result.
    let Some(edit_prediction_context) = edit_prediction_context else {
        return Ok(RetrieveResult {
            definitions: Vec::new(),
            excerpt_range: None,
        });
    };

    let mut retrieved_definitions = Vec::new();
    for scored_declaration in edit_prediction_context.declarations {
        match &scored_declaration.declaration {
            // Declaration backed by an indexed file: offsets are resolved via
            // that file's snapshot.
            Declaration::File {
                project_entry_id,
                declaration,
                ..
            } => {
                let Some(snapshot) = file_snapshots.get(&project_entry_id) else {
                    log::error!("bug: file project entry not found");
                    continue;
                };
                let path = snapshot.file().unwrap().path().clone();
                retrieved_definitions.push(RetrievedDefinition {
                    path,
                    range: snapshot.offset_to_point(declaration.item_range.start)
                        ..snapshot.offset_to_point(declaration.item_range.end),
                    score: scored_declaration.score(DeclarationStyle::Declaration),
                    retrieval_score: scored_declaration.retrieval_score(),
                    components: scored_declaration.components,
                });
            }
            // Declaration backed by an open buffer: offsets are resolved via
            // the declaration's own rope.
            Declaration::Buffer {
                project_entry_id,
                rope,
                declaration,
                ..
            } => {
                let Some(snapshot) = file_snapshots.get(&project_entry_id) else {
                    // This case happens when dependency buffers have been opened by
                    // go-to-definition, resulting in single-file worktrees.
                    continue;
                };
                let path = snapshot.file().unwrap().path().clone();
                retrieved_definitions.push(RetrievedDefinition {
                    path,
                    range: rope.offset_to_point(declaration.item_range.start)
                        ..rope.offset_to_point(declaration.item_range.end),
                    score: scored_declaration.score(DeclarationStyle::Declaration),
                    retrieval_score: scored_declaration.retrieval_score(),
                    components: scored_declaration.components,
                });
            }
        }
    }
    // Highest score first.
    retrieved_definitions.sort_by_key(|definition| Reverse(OrderedFloat(definition.score)));

    Ok(RetrieveResult {
        definitions: retrieved_definitions,
        excerpt_range: Some(edit_prediction_context.excerpt.range),
    })
}
972
/// Collects go-to-definition results from language servers for every
/// identifier reference in `files[start_index..]`, inserting them into
/// `definitions` and appending one JSONL cache line per file to
/// `lsp_definitions_path`.
///
/// Cache lines are appended in file order so the cache can be replayed (and
/// truncated) positionally by `retrieval_stats`. Per-reference LSP errors are
/// counted and logged rather than aborting the run.
async fn gather_lsp_definitions(
    lsp_definitions_path: &Path,
    start_index: usize,
    files: &[ProjectPath],
    worktree: &Entity<Worktree>,
    project: &Entity<Project>,
    definitions: &mut HashMap<SourceLocation, Vec<SourceRange>>,
    cx: &mut AsyncApp,
) -> Result<()> {
    let worktree_id = worktree.read_with(cx, |worktree, _cx| worktree.id())?;

    // Echo language-server progress messages to the console.
    let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
    cx.subscribe(&lsp_store, {
        move |_, event, _| {
            if let project::LspStoreEvent::LanguageServerUpdate {
                message:
                    client::proto::update_language_server::Variant::WorkProgress(
                        client::proto::LspWorkProgress {
                            message: Some(message),
                            ..
                        },
                    ),
                ..
            } = event
            {
                println!("⟲ {message}")
            }
        }
    })?
    .detach();

    // Cache lines are written by a dedicated background task so file I/O
    // doesn't block the LSP query loop.
    let (cache_line_tx, mut cache_line_rx) = mpsc::unbounded::<FileLspDefinitions>();

    let cache_file = File::options()
        .append(true)
        .create(true)
        .open(lsp_definitions_path)
        .unwrap();

    let cache_task = cx.background_spawn(async move {
        let mut writer = BufWriter::new(cache_file);
        while let Some(line) = cache_line_rx.next().await {
            serde_json::to_writer(&mut writer, &line).unwrap();
            writer.write_all(&[b'\n']).unwrap();
        }
        writer.flush().unwrap();
    });

    let mut error_count = 0;
    // Handles are accumulated to keep every buffer registered with its
    // language server for the duration of the run.
    let mut lsp_open_handles = Vec::new();
    let mut ready_languages = HashSet::default();
    for (file_index, project_path) in files[start_index..].iter().enumerate() {
        println!(
            "Processing file {} of {}: {}",
            start_index + file_index + 1,
            files.len(),
            project_path.path.display(PathStyle::Posix)
        );

        let Some((lsp_open_handle, language_server_id, buffer)) = open_buffer_with_language_server(
            project.clone(),
            worktree.clone(),
            project_path.path.clone(),
            &mut ready_languages,
            cx,
        )
        .await
        .log_err() else {
            continue;
        };
        lsp_open_handles.push(lsp_open_handle);

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
        let full_range = 0..snapshot.len();
        let references = references_in_range(
            full_range,
            &snapshot.text(),
            ReferenceRegion::Nearby,
            &snapshot,
        );

        // Poll until the server reports no pending work, so definition
        // queries aren't issued against a half-indexed server.
        loop {
            let is_ready = lsp_store
                .read_with(cx, |lsp_store, _cx| {
                    lsp_store
                        .language_server_statuses
                        .get(&language_server_id)
                        .is_some_and(|status| status.pending_work.is_empty())
                })
                .unwrap();
            if is_ready {
                break;
            }
            cx.background_executor()
                .timer(Duration::from_millis(10))
                .await;
        }

        let mut cache_line_references = Vec::with_capacity(references.len());

        for reference in references {
            // TODO: Rename declaration to definition in edit_prediction_context?
            let lsp_result = project
                .update(cx, |project, cx| {
                    project.definitions(&buffer, reference.range.start, cx)
                })?
                .await;

            match lsp_result {
                Ok(lsp_definitions) => {
                    let mut targets = Vec::new();
                    for target in lsp_definitions.unwrap_or_default() {
                        let buffer = target.target.buffer;
                        let anchor_range = target.target.range;
                        buffer.read_with(cx, |buffer, cx| {
                            let Some(file) = project::File::from_dyn(buffer.file()) else {
                                return;
                            };
                            let file_worktree = file.worktree.read(cx);
                            let file_worktree_id = file_worktree.id();
                            // Relative paths for worktree files, absolute for all others
                            let path = if worktree_id != file_worktree_id {
                                file.worktree.read(cx).absolutize(&file.path)
                            } else {
                                file.path.as_std_path().to_path_buf()
                            };
                            let offset_range = anchor_range.to_offset(&buffer);
                            let point_range = SerializablePoint::from_language_point_range(
                                offset_range.to_point(&buffer),
                            );
                            targets.push(SourceRange {
                                path,
                                offset_range,
                                point_range,
                            });
                        })?;
                    }

                    let point = snapshot.offset_to_point(reference.range.start);

                    cache_line_references.push((point.into(), targets.clone()));
                    definitions.insert(
                        SourceLocation {
                            path: project_path.path.clone(),
                            point,
                        },
                        targets,
                    );
                }
                Err(err) => {
                    log::error!("Language server error: {err}");
                    error_count += 1;
                }
            }
        }

        cache_line_tx
            .unbounded_send(FileLspDefinitions {
                path: project_path.path.as_unix_str().into(),
                references: cache_line_references,
            })
            .log_err();
    }

    // Close the channel so the writer task drains and flushes.
    drop(cache_line_tx);

    if error_count > 0 {
        log::error!("Encountered {} language server errors", error_count);
    }

    cache_task.await;

    Ok(())
}
1147
/// Cached LSP definition results for one file, sent over the cache channel
/// and serialized for later reuse.
#[derive(Serialize, Deserialize)]
struct FileLspDefinitions {
    /// Worktree-relative path of the file, in unix (`/`-separated) form.
    path: Arc<str>,
    /// For each reference position in the file, the definition targets the
    /// language server reported for it.
    references: Vec<(SerializablePoint, Vec<SourceRange>)>,
}
1153
/// A span in a source file, carried in both point (row/column) and byte
/// offset form.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SourceRange {
    /// Path to the file: relative for files inside the queried worktree,
    /// absolute for files outside it.
    path: PathBuf,
    /// The span as 1-based row/column coordinates.
    point_range: Range<SerializablePoint>,
    /// The same span as byte offsets into the file.
    offset_range: Range<usize>,
}
1160
/// Serializes to 1-based row and column indices.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializablePoint {
    /// 1-based row (line) index.
    pub row: u32,
    /// 1-based column index.
    pub column: u32,
}
1167
1168impl SerializablePoint {
1169 pub fn into_language_point_range(range: Range<Self>) -> Range<Point> {
1170 range.start.into()..range.end.into()
1171 }
1172
1173 pub fn from_language_point_range(range: Range<Point>) -> Range<Self> {
1174 range.start.into()..range.end.into()
1175 }
1176}
1177
1178impl From<Point> for SerializablePoint {
1179 fn from(point: Point) -> Self {
1180 SerializablePoint {
1181 row: point.row + 1,
1182 column: point.column + 1,
1183 }
1184 }
1185}
1186
1187impl From<SerializablePoint> for Point {
1188 fn from(serializable: SerializablePoint) -> Self {
1189 Point {
1190 row: serializable.row.saturating_sub(1),
1191 column: serializable.column.saturating_sub(1),
1192 }
1193 }
1194}
1195
/// Result of comparing retrieved declarations against LSP definitions for a
/// single reference. The `allow(dead_code)` fields are only read through the
/// derived `Debug` impl.
#[derive(Debug)]
struct RetrievalStatsResult {
    /// How the retrieved definitions compared to the LSP's answer.
    outcome: RetrievalOutcome,
    /// Worktree-relative path of the file containing the reference.
    #[allow(dead_code)]
    path: Arc<RelPath>,
    /// The identifier that was referenced.
    #[allow(dead_code)]
    identifier: Identifier,
    /// 0-based position of the reference in the file.
    #[allow(dead_code)]
    point: Point,
    /// Definitions the language server reported for the reference.
    #[allow(dead_code)]
    lsp_definitions: Vec<SourceRange>,
    /// Definitions produced by the retrieval system, in ranked order.
    retrieved_definitions: Vec<RetrievedDefinition>,
}
1209
/// Classification of how retrieved definitions compared to LSP definitions.
/// NOTE(review): the constructing code is outside this excerpt; variant
/// descriptions below are inferred from the names — confirm at the use site.
#[derive(Debug)]
enum RetrievalOutcome {
    /// At least one retrieved definition coincided with an LSP definition.
    Match {
        /// Lowest index within retrieved_definitions that matches an LSP definition.
        best_match: usize,
    },
    /// The reference likely resolves to a local binding, which retrieval is
    /// not expected to cover.
    ProbablyLocal,
    /// No retrieved definition matched any LSP definition.
    NoMatch,
    /// No match, but the LSP's definitions lie outside the indexed worktree,
    /// so a match was not possible.
    NoMatchDueToExternalLspDefinitions,
}
1220
/// A declaration produced by the retrieval system for a reference, together
/// with the scores that ranked it. The `allow(dead_code)` fields are only
/// read through the derived `Debug` impl.
#[derive(Debug)]
struct RetrievedDefinition {
    /// Worktree-relative path of the file containing the declaration.
    path: Arc<RelPath>,
    /// Location of the declaration as a 0-based point range.
    range: Range<Point>,
    /// Score used to rank this definition among the retrieved set.
    score: f32,
    /// Retrieval-stage score, kept for diagnostics.
    #[allow(dead_code)]
    retrieval_score: f32,
    /// Breakdown of the individual components that produced the score.
    #[allow(dead_code)]
    components: DeclarationScoreComponents,
}
1231
1232pub fn open_buffer(
1233 project: Entity<Project>,
1234 worktree: Entity<Worktree>,
1235 path: Arc<RelPath>,
1236 cx: &AsyncApp,
1237) -> Task<Result<Entity<Buffer>>> {
1238 cx.spawn(async move |cx| {
1239 let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
1240 worktree_id: worktree.id(),
1241 path,
1242 })?;
1243
1244 let buffer = project
1245 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
1246 .await?;
1247
1248 let mut parse_status = buffer.read_with(cx, |buffer, _cx| buffer.parse_status())?;
1249 while *parse_status.borrow() != ParseStatus::Idle {
1250 parse_status.changed().await?;
1251 }
1252
1253 Ok(buffer)
1254 })
1255}
1256
/// Opens a buffer via [`open_buffer`] and registers it with language servers,
/// waiting until a server is actually attached to the buffer.
///
/// `ready_languages` tracks languages whose server has already been waited
/// for, so the slow [`wait_for_lang_server`] pass runs only once per language.
///
/// Returns the LSP registration handle (keep it alive while using the buffer),
/// the id of the first language server found for the buffer, and the buffer.
pub async fn open_buffer_with_language_server(
    project: Entity<Project>,
    worktree: Entity<Worktree>,
    path: Arc<RelPath>,
    ready_languages: &mut HashSet<LanguageId>,
    cx: &mut AsyncApp,
) -> Result<(Entity<Entity<Buffer>>, LanguageServerId, Entity<Buffer>)> {
    let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;

    let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
        (
            project.register_buffer_with_language_servers(&buffer, cx),
            project.path_style(cx),
        )
    })?;

    // Buffers with no recognized language cannot have a language server.
    let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
        buffer.language().map(|language| language.id())
    })?
    else {
        return Err(anyhow!("No language for {}", path.display(path_style)));
    };

    let log_prefix = path.display(path_style);
    if !ready_languages.contains(&language_id) {
        wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
        ready_languages.insert(language_id);
    }

    let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;

    // hacky wait for buffer to be registered with the language server
    // (polls every 10ms, up to 100 attempts, i.e. ~1 second total)
    for _ in 0..100 {
        let Some(language_server_id) = lsp_store.update(cx, |lsp_store, cx| {
            buffer.update(cx, |buffer, cx| {
                lsp_store
                    .language_servers_for_local_buffer(&buffer, cx)
                    .next()
                    .map(|(_, language_server)| language_server.server_id())
            })
        })?
        else {
            cx.background_executor()
                .timer(Duration::from_millis(10))
                .await;
            continue;
        };

        return Ok((lsp_open_handle, language_server_id, buffer));
    }

    return Err(anyhow!("No language server found for buffer"));
}
1310
// TODO: Dedupe with similar function in crates/eval/src/instance.rs
/// Waits until a language server attached to `buffer` finishes its disk-based
/// diagnostics pass, printing progress messages prefixed with `log_prefix`.
///
/// If no server is attached yet, first waits (up to 5 seconds) for one to be
/// added; the overall diagnostics wait times out after 5 minutes. Saving the
/// buffer is used to nudge the server into processing it.
pub fn wait_for_lang_server(
    project: &Entity<Project>,
    buffer: &Entity<Buffer>,
    log_prefix: String,
    cx: &mut AsyncApp,
) -> Task<Result<()>> {
    println!("{}⏵ Waiting for language server", log_prefix);

    // Signalled when disk-based diagnostics finish (server considered idle).
    let (mut tx, mut rx) = mpsc::channel(1);

    let lsp_store = project
        .read_with(cx, |project, _| project.lsp_store())
        .unwrap();

    let has_lang_server = buffer
        .update(cx, |buffer, cx| {
            lsp_store.update(cx, |lsp_store, cx| {
                lsp_store
                    .language_servers_for_local_buffer(buffer, cx)
                    .next()
                    .is_some()
            })
        })
        .unwrap_or(false);

    if has_lang_server {
        // Save to prompt the already-attached server to publish diagnostics.
        project
            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
            .unwrap()
            .detach();
    }
    // Signalled when a language server is added for the buffer.
    let (mut added_tx, mut added_rx) = mpsc::channel(1);

    let subscriptions = [
        cx.subscribe(&lsp_store, {
            let log_prefix = log_prefix.clone();
            move |_, event, _| {
                // Echo language-server work-progress messages to stdout.
                if let project::LspStoreEvent::LanguageServerUpdate {
                    message:
                        client::proto::update_language_server::Variant::WorkProgress(
                            client::proto::LspWorkProgress {
                                message: Some(message),
                                ..
                            },
                        ),
                    ..
                } = event
                {
                    println!("{}⟲ {message}", log_prefix)
                }
            }
        }),
        cx.subscribe(project, {
            let buffer = buffer.clone();
            move |project, event, cx| match event {
                project::Event::LanguageServerAdded(_, _, _) => {
                    // A server appeared after we started waiting: save the
                    // buffer to nudge it, then signal the added channel.
                    let buffer = buffer.clone();
                    project
                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
                        .detach();
                    added_tx.try_send(()).ok();
                }
                project::Event::DiskBasedDiagnosticsFinished { .. } => {
                    tx.try_send(()).ok();
                }
                _ => {}
            }
        }),
    ];

    cx.spawn(async move |cx| {
        if !has_lang_server {
            // some buffers never have a language server, so this aborts quickly in that case.
            let timeout = cx.background_executor().timer(Duration::from_secs(5));
            futures::select! {
                _ = added_rx.next() => {},
                _ = timeout.fuse() => {
                    anyhow::bail!("Waiting for language server add timed out after 5 seconds");
                }
            };
        }
        let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
        let result = futures::select! {
            _ = rx.next() => {
                println!("{}⚑ Language server idle", log_prefix);
                anyhow::Ok(())
            },
            _ = timeout.fuse() => {
                anyhow::bail!("LSP wait timed out after 5 minutes");
            }
        };
        // Keep the event subscriptions alive until the wait completes.
        drop(subscriptions);
        result
    })
}
1407
1408fn main() {
1409 zlog::init();
1410 zlog::init_output_stderr();
1411 let args = ZetaCliArgs::parse();
1412 let http_client = Arc::new(ReqwestClient::new());
1413 let app = Application::headless().with_http_client(http_client);
1414
1415 app.run(move |cx| {
1416 let app_state = Arc::new(headless::init(cx));
1417 cx.spawn(async move |cx| {
1418 let result = match args.command {
1419 Commands::Zeta2Context {
1420 zeta2_args,
1421 context_args,
1422 } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
1423 Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
1424 Ok(GetContextOutput::Zeta2(output)) => Ok(output),
1425 Err(err) => Err(err),
1426 },
1427 Commands::Context(context_args) => {
1428 match get_context(None, context_args, &app_state, cx).await {
1429 Ok(GetContextOutput::Zeta1(output)) => {
1430 Ok(serde_json::to_string_pretty(&output.body).unwrap())
1431 }
1432 Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
1433 Err(err) => Err(err),
1434 }
1435 }
1436 Commands::Predict {
1437 predict_edits_body,
1438 context_args,
1439 } => {
1440 cx.spawn(async move |cx| {
1441 let app_version = cx.update(|cx| AppVersion::global(cx))?;
1442 app_state.client.sign_in(true, cx).await?;
1443 let llm_token = LlmApiToken::default();
1444 llm_token.refresh(&app_state.client).await?;
1445
1446 let predict_edits_body =
1447 if let Some(predict_edits_body) = predict_edits_body {
1448 serde_json::from_str(&predict_edits_body.read_to_string().await?)?
1449 } else if let Some(context_args) = context_args {
1450 match get_context(None, context_args, &app_state, cx).await? {
1451 GetContextOutput::Zeta1(output) => output.body,
1452 GetContextOutput::Zeta2 { .. } => unreachable!(),
1453 }
1454 } else {
1455 return Err(anyhow!(
1456 "Expected either --predict-edits-body-file \
1457 or the required args of the `context` command."
1458 ));
1459 };
1460
1461 let (response, _usage) =
1462 Zeta::perform_predict_edits(PerformPredictEditsParams {
1463 client: app_state.client.clone(),
1464 llm_token,
1465 app_version,
1466 body: predict_edits_body,
1467 })
1468 .await?;
1469
1470 Ok(response.output_excerpt)
1471 })
1472 .await
1473 }
1474 Commands::RetrievalStats {
1475 zeta2_args,
1476 worktree,
1477 extension,
1478 limit,
1479 skip,
1480 } => {
1481 retrieval_stats(
1482 worktree,
1483 app_state,
1484 extension,
1485 limit,
1486 skip,
1487 (&zeta2_args).to_options(false),
1488 cx,
1489 )
1490 .await
1491 }
1492 };
1493 match result {
1494 Ok(output) => {
1495 println!("{}", output);
1496 let _ = cx.update(|cx| cx.quit());
1497 }
1498 Err(e) => {
1499 eprintln!("Failed: {:?}", e);
1500 exit(1);
1501 }
1502 }
1503 })
1504 .detach();
1505 });
1506}