1mod headless;
2
3use anyhow::{Context as _, Result, anyhow};
4use clap::{Args, Parser, Subcommand};
5use cloud_llm_client::predict_edits_v3::{self, DeclarationScoreComponents};
6use edit_prediction_context::{
7 Declaration, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
8 EditPredictionExcerptOptions, EditPredictionScoreOptions, Identifier, Imports, Reference,
9 ReferenceRegion, SyntaxIndex, SyntaxIndexState, references_in_range,
10};
11use futures::channel::mpsc;
12use futures::{FutureExt as _, StreamExt as _};
13use gpui::{AppContext, Application, AsyncApp};
14use gpui::{Entity, Task};
15use language::{Bias, BufferSnapshot, LanguageServerId, Point};
16use language::{Buffer, OffsetRangeExt};
17use language::{LanguageId, ParseStatus};
18use language_model::LlmApiToken;
19use ordered_float::OrderedFloat;
20use project::{Project, ProjectEntryId, ProjectPath, Worktree};
21use release_channel::AppVersion;
22use reqwest_client::ReqwestClient;
23use serde::{Deserialize, Deserializer, Serialize, Serializer};
24use serde_json::json;
25use std::cmp::Reverse;
26use std::collections::{HashMap, HashSet};
27use std::fmt::{self, Display};
28use std::fs::File;
29use std::hash::Hash;
30use std::hash::Hasher;
31use std::io::Write as _;
32use std::ops::Range;
33use std::path::{Path, PathBuf};
34use std::process::exit;
35use std::str::FromStr;
36use std::sync::atomic::AtomicUsize;
37use std::sync::{Arc, atomic};
38use std::time::Duration;
39use util::paths::PathStyle;
40use util::rel_path::RelPath;
41use util::{RangeExt, ResultExt as _};
42use zeta::{PerformPredictEditsParams, Zeta};
43
44use crate::headless::ZetaCliAppState;
45
// Top-level CLI definition for the `zeta` binary; all behavior lives in the
// subcommands. (Plain `//` comments are used on clap types so the generated
// --help output is unchanged.)
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    #[command(subcommand)]
    command: Commands,
}
52
// Subcommands of the `zeta` CLI.
#[derive(Subcommand, Debug)]
enum Commands {
    // Gather zeta1 context for a cursor position (see `get_context`).
    Context(ContextArgs),
    // Gather zeta2 (syntax-index based) context for a cursor position.
    Zeta2Context {
        #[clap(flatten)]
        zeta2_args: Zeta2Args,
        #[clap(flatten)]
        context_args: ContextArgs,
    },
    // Request an edit prediction, either from a prepared request body
    // (file or stdin) or by gathering context first.
    Predict {
        #[arg(long)]
        predict_edits_body: Option<FileOrStdin>,
        #[clap(flatten)]
        context_args: Option<ContextArgs>,
    },
    // Compare zeta2 retrieval against LSP go-to-definition across a worktree
    // (see `retrieval_stats`).
    RetrievalStats {
        #[clap(flatten)]
        zeta2_args: Zeta2Args,
        #[arg(long)]
        worktree: PathBuf,
        // Only process files with this extension.
        #[arg(long)]
        extension: Option<String>,
        // Cap on the number of files processed.
        #[arg(long)]
        limit: Option<usize>,
        // Number of files to skip before processing.
        #[arg(long)]
        skip: Option<usize>,
    },
}
81
// Arguments shared by the context-gathering subcommands.
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Root directory of the worktree to load.
    #[arg(long)]
    worktree: PathBuf,
    // Cursor position in `path:line:column` form (1-based line/column).
    #[arg(long)]
    cursor: SourceLocation,
    // Attach a language server to the buffer before gathering context.
    #[arg(long)]
    use_language_server: bool,
    // Edit-history events, read from a file path or `-` for stdin.
    #[arg(long)]
    events: Option<FileOrStdin>,
}
94
// Tuning knobs for the zeta2 pipeline; converted into `zeta2::ZetaOptions`
// by `to_options`.
#[derive(Debug, Args)]
struct Zeta2Args {
    // Overall prompt size budget, in bytes.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // Upper/lower bounds on the cursor excerpt, in bytes.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // Fraction of the excerpt budget placed before the cursor.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // Number of files indexed concurrently by the syntax index.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
    // Skip import gathering (sets `use_imports: false` in the options).
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
    #[arg(long, default_value_t = 0.5)]
    prefilter_score_ratio: f32,
}
118
// CLI-side mirror of `predict_edits_v3::PromptFormat`; converted when
// building requests.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum PromptFormat {
    #[default]
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
}
126
127impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
128 fn into(self) -> predict_edits_v3::PromptFormat {
129 match self {
130 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
131 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
132 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
133 }
134 }
135}
136
// What the zeta2 context path prints: just the rendered prompt, the JSON
// request, or both bundled together (see the match in `get_context`).
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    #[default]
    Prompt,
    Request,
    Full,
}
144
/// A CLI input source: a file path, or stdin when the argument is `-`.
#[derive(Debug, Clone)]
enum FileOrStdin {
    File(PathBuf),
    Stdin,
}
150
impl FileOrStdin {
    /// Reads the entire contents, from disk or by draining stdin.
    async fn read_to_string(&self) -> Result<String, std::io::Error> {
        match self {
            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
            // Stdin reads block, so move them off the async executor.
            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
        }
    }
}
159
160impl FromStr for FileOrStdin {
161 type Err = <PathBuf as FromStr>::Err;
162
163 fn from_str(s: &str) -> Result<Self, Self::Err> {
164 match s {
165 "-" => Ok(Self::Stdin),
166 _ => Ok(Self::File(PathBuf::from_str(s)?)),
167 }
168 }
169}
170
/// A cursor position: a worktree-relative path plus a 0-based `Point`.
/// Displayed/serialized and parsed in 1-based `path:line:column` form.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct SourceLocation {
    path: Arc<RelPath>,
    // 0-based row/column.
    point: Point,
}
176
/// Serializes as the 1-based `path:line:column` display string.
impl Serialize for SourceLocation {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(&self.to_string())
    }
}
185
/// Deserializes from the `path:line:column` string form via `FromStr`.
impl<'de> Deserialize<'de> for SourceLocation {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        s.parse().map_err(serde::de::Error::custom)
    }
}
195
impl Display for SourceLocation {
    /// Formats as `path:line:column`, converting the 0-based point back to
    /// 1-based for display.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "{}:{}:{}",
            self.path.display(PathStyle::Posix),
            self.point.row + 1,
            self.point.column + 1
        )
    }
}
207
208impl FromStr for SourceLocation {
209 type Err = anyhow::Error;
210
211 fn from_str(s: &str) -> Result<Self> {
212 let parts: Vec<&str> = s.split(':').collect();
213 if parts.len() != 3 {
214 return Err(anyhow!(
215 "Invalid source location. Expected 'file.rs:line:column', got '{}'",
216 s
217 ));
218 }
219
220 let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
221 let line: u32 = parts[1]
222 .parse()
223 .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
224 let column: u32 = parts[2]
225 .parse()
226 .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
227
228 // Convert from 1-based to 0-based indexing
229 let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
230
231 Ok(SourceLocation { path, point })
232 }
233}
234
/// Output of `get_context`: zeta1 yields structured gather output, while the
/// zeta2 path returns an already-rendered string.
enum GetContextOutput {
    Zeta1(zeta::GatherContextOutput),
    Zeta2(String),
}
239
/// Gathers edit-prediction context for the cursor described by `args`.
///
/// With `zeta2_args` set, indexes the worktree via zeta2 and renders output in
/// the requested `OutputFormat`; otherwise runs the zeta1 `gather_context`
/// path. Fails if the cursor position lies outside the buffer.
async fn get_context(
    zeta2_args: Option<Zeta2Args>,
    args: ContextArgs,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<GetContextOutput> {
    let ContextArgs {
        worktree: worktree_path,
        cursor,
        use_language_server,
        events,
    } = args;

    let worktree_path = worktree_path.canonicalize()?;

    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;

    // Optionally attach a language server so LSP-backed context is available.
    let mut ready_languages = HashSet::default();
    let (_lsp_open_handle, buffer) = if use_language_server {
        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
            project.clone(),
            worktree.clone(),
            cursor.path.clone(),
            &mut ready_languages,
            cx,
        )
        .await?;
        (Some(lsp_open_handle), buffer)
    } else {
        let buffer =
            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
        (None, buffer)
    };

    let full_path_str = worktree
        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
        .display(PathStyle::local())
        .to_string();

    // If clipping moved the cursor, it was out of bounds; report whether the
    // column or the row was the problem.
    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
    if clipped_cursor != cursor.point {
        let max_row = snapshot.max_point().row;
        if cursor.point.row < max_row {
            return Err(anyhow!(
                "Cursor position {:?} is out of bounds (line length is {})",
                cursor.point,
                snapshot.line_len(cursor.point.row)
            ));
        } else {
            return Err(anyhow!(
                "Cursor position {:?} is out of bounds (max row is {})",
                cursor.point,
                max_row
            ));
        }
    }

    let events = match events {
        Some(events) => events.read_to_string().await?,
        None => String::new(),
    };

    if let Some(zeta2_args) = zeta2_args {
        // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
        // the whole worktree.
        worktree
            .read_with(cx, |worktree, _cx| {
                worktree.as_local().unwrap().scan_complete()
            })?
            .await;
        let output = cx
            .update(|cx| {
                let zeta = cx.new(|cx| {
                    zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
                });
                let indexing_done_task = zeta.update(cx, |zeta, cx| {
                    zeta.set_options(zeta2_args.to_options(true));
                    zeta.register_buffer(&buffer, &project, cx);
                    zeta.wait_for_initial_indexing(&project, cx)
                });
                cx.spawn(async move |cx| {
                    indexing_done_task.await?;
                    let request = zeta
                        .update(cx, |zeta, cx| {
                            let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
                            zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
                        })?
                        .await?;

                    let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
                    let (prompt_string, section_labels) = planned_prompt.to_prompt_string()?;

                    // Render according to --output-format: prompt text, the
                    // JSON request, or both bundled together.
                    match zeta2_args.output_format {
                        OutputFormat::Prompt => anyhow::Ok(prompt_string),
                        OutputFormat::Request => {
                            anyhow::Ok(serde_json::to_string_pretty(&request)?)
                        }
                        OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
                            "request": request,
                            "prompt": prompt_string,
                            "section_labels": section_labels,
                        }))?),
                    }
                })
            })?
            .await?;
        Ok(GetContextOutput::Zeta2(output))
    } else {
        // zeta1 path: events are passed through verbatim with a zero count.
        let prompt_for_events = move || (events, 0);
        Ok(GetContextOutput::Zeta1(
            cx.update(|cx| {
                zeta::gather_context(
                    full_path_str,
                    &snapshot,
                    clipped_cursor,
                    prompt_for_events,
                    cx,
                )
            })?
            .await?,
        ))
    }
}
380
impl Zeta2Args {
    /// Converts the CLI flags into `zeta2::ZetaOptions`.
    ///
    /// `omit_excerpt_overlaps` is caller-chosen rather than a flag: the
    /// context path passes `true`.
    fn to_options(&self, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions {
        zeta2::ZetaOptions {
            context: EditPredictionContextOptions {
                // Note the inversion: the flag disables, the option enables.
                use_imports: !self.disable_imports_gathering,
                excerpt: EditPredictionExcerptOptions {
                    max_bytes: self.max_excerpt_bytes,
                    min_bytes: self.min_excerpt_bytes,
                    target_before_cursor_over_total_bytes: self
                        .target_before_cursor_over_total_bytes,
                },
                score: EditPredictionScoreOptions {
                    omit_excerpt_overlaps,
                    prefilter_score_ratio: self.prefilter_score_ratio,
                },
            },
            max_diagnostic_bytes: self.max_diagnostic_bytes,
            max_prompt_bytes: self.max_prompt_bytes,
            prompt_format: self.prompt_format.clone().into(),
            file_indexing_parallelism: self.file_indexing_parallelism,
        }
    }
}
404
/// Measures how well zeta2's declaration retrieval agrees with LSP
/// go-to-definition across every indexed file in `worktree`.
///
/// Pipeline: index the worktree, snapshot each file, load (or gather and
/// cache) LSP definitions for every reference, then retrieve definitions per
/// reference in parallel and classify each outcome. Per-reference results go
/// to `target/zeta-retrieval-stats.txt`; a summary tree is printed to stdout.
/// Returns an empty string on success (output is via side effects).
pub async fn retrieval_stats(
    worktree: PathBuf,
    app_state: Arc<ZetaCliAppState>,
    only_extension: Option<String>,
    file_limit: Option<usize>,
    skip_files: Option<usize>,
    options: zeta2::ZetaOptions,
    cx: &mut AsyncApp,
) -> Result<String> {
    let options = Arc::new(options);
    let worktree_path = worktree.canonicalize()?;

    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;

    // wait for worktree scan so that wait_for_initial_file_indexing waits for the whole worktree.
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;

    let index = cx.new(|cx| SyntaxIndex::new(&project, options.file_indexing_parallelism, cx))?;
    index
        .read_with(cx, |index, cx| index.wait_for_initial_file_indexing(cx))?
        .await?;
    let indexed_files = index
        .read_with(cx, |index, cx| index.indexed_file_paths(cx))?
        .await;
    // Keep only the requested extension, or by default drop non-code files.
    let mut filtered_files = indexed_files
        .into_iter()
        .filter(|project_path| {
            let file_extension = project_path.path.extension();
            if let Some(only_extension) = only_extension.as_ref() {
                file_extension.is_some_and(|extension| extension == only_extension)
            } else {
                file_extension
                    .is_some_and(|extension| !["md", "json", "sh", "diff"].contains(&extension))
            }
        })
        .collect::<Vec<_>>();
    filtered_files.sort_by(|a, b| a.path.cmp(&b.path));

    // Extract sole ownership of the index state so background tasks can use
    // it directly; dropping the entity must release its reference first.
    let index_state = index.read_with(cx, |index, _cx| index.state().clone())?;
    cx.update(|_| {
        drop(index);
    })?;
    let index_state = Arc::new(
        Arc::into_inner(index_state)
            .context("Index state had more than 1 reference")?
            .into_inner(),
    );

    // Per-file snapshot plus a content hash; the hashes are combined into a
    // key for the on-disk LSP-definition cache below.
    struct FileSnapshot {
        project_entry_id: ProjectEntryId,
        snapshot: BufferSnapshot,
        hash: u64,
        parent_abs_path: Arc<Path>,
    }

    let files: Vec<FileSnapshot> = futures::future::try_join_all({
        filtered_files
            .iter()
            .map(|file| {
                let buffer_task =
                    open_buffer(project.clone(), worktree.clone(), file.path.clone(), cx);
                cx.spawn(async move |cx| {
                    let buffer = buffer_task.await?;
                    let (project_entry_id, parent_abs_path, snapshot) =
                        buffer.read_with(cx, |buffer, cx| {
                            let file = project::File::from_dyn(buffer.file()).unwrap();
                            let project_entry_id = file.project_entry_id().unwrap();
                            let mut parent_abs_path = file.worktree.read(cx).absolutize(&file.path);
                            // pop() strips the file name, leaving the parent dir.
                            if !parent_abs_path.pop() {
                                panic!("Invalid worktree path");
                            }

                            (project_entry_id, parent_abs_path, buffer.snapshot())
                        })?;

                    anyhow::Ok(
                        // Hashing the full text can be slow; do it off the main thread.
                        cx.background_spawn(async move {
                            let mut hasher = collections::FxHasher::default();
                            snapshot.text().hash(&mut hasher);
                            FileSnapshot {
                                project_entry_id,
                                snapshot,
                                hash: hasher.finish(),
                                parent_abs_path: parent_abs_path.into(),
                            }
                        })
                        .await,
                    )
                })
            })
            .collect::<Vec<_>>()
    })
    .await?;

    // Combine per-file hashes into one cache key, and build the entry-id →
    // snapshot map used during retrieval.
    let mut file_snapshots = HashMap::default();
    let mut hasher = collections::FxHasher::default();
    for FileSnapshot {
        project_entry_id,
        snapshot,
        hash,
        ..
    } in &files
    {
        file_snapshots.insert(*project_entry_id, snapshot.clone());
        hash.hash(&mut hasher);
    }
    let files_hash = hasher.finish();
    let file_snapshots = Arc::new(file_snapshots);

    // Gathering LSP definitions is slow, so cache them on disk keyed by the
    // combined content hash.
    let lsp_definitions_path = std::env::current_dir()?.join(format!(
        "target/zeta2-lsp-definitions-{:x}.json",
        files_hash
    ));

    let lsp_definitions: Arc<_> = if std::fs::exists(&lsp_definitions_path)? {
        log::info!(
            "Using cached LSP definitions from {}",
            lsp_definitions_path.display()
        );
        serde_json::from_reader(File::open(&lsp_definitions_path)?)?
    } else {
        log::warn!(
            "No LSP definitions found populating {}",
            lsp_definitions_path.display()
        );
        let lsp_definitions =
            gather_lsp_definitions(&filtered_files, &worktree, &project, cx).await?;
        serde_json::to_writer_pretty(File::create(&lsp_definitions_path)?, &lsp_definitions)?;
        lsp_definitions
    }
    .into();

    // NOTE(review): files_len is computed before `skip` is applied, so the
    // progress denominator can overshoot when --skip is used — confirm intent.
    let files_len = files.len().min(file_limit.unwrap_or(usize::MAX));
    let done_count = Arc::new(AtomicUsize::new(0));

    let (output_tx, mut output_rx) = mpsc::unbounded::<RetrievalStatsResult>();
    let mut output = std::fs::File::create("target/zeta-retrieval-stats.txt")?;

    // One background task per file: enumerate references, retrieve
    // definitions for each, and classify against the cached LSP results.
    let tasks = files
        .into_iter()
        .skip(skip_files.unwrap_or(0))
        .take(file_limit.unwrap_or(usize::MAX))
        .map(|project_file| {
            let index_state = index_state.clone();
            let lsp_definitions = lsp_definitions.clone();
            let options = options.clone();
            let output_tx = output_tx.clone();
            let done_count = done_count.clone();
            let file_snapshots = file_snapshots.clone();
            cx.background_spawn(async move {
                let snapshot = project_file.snapshot;

                let full_range = 0..snapshot.len();
                let references = references_in_range(
                    full_range,
                    &snapshot.text(),
                    ReferenceRegion::Nearby,
                    &snapshot,
                );

                println!("references: {}", references.len(),);

                let imports = if options.context.use_imports {
                    Imports::gather(&snapshot, Some(&project_file.parent_abs_path))
                } else {
                    Imports::default()
                };

                let path = snapshot.file().unwrap().path();

                for reference in references {
                    let query_point = snapshot.offset_to_point(reference.range.start);
                    let source_location = SourceLocation {
                        path: path.clone(),
                        point: query_point,
                    };
                    // Missing cache entries are tolerated (logged) rather than fatal.
                    let lsp_definitions = lsp_definitions
                        .definitions
                        .get(&source_location)
                        .cloned()
                        .unwrap_or_else(|| {
                            log::warn!(
                                "No definitions found for source location: {:?}",
                                source_location
                            );
                            Vec::new()
                        });

                    let retrieve_result = retrieve_definitions(
                        &reference,
                        &imports,
                        query_point,
                        &snapshot,
                        &index_state,
                        &file_snapshots,
                        &options,
                    )
                    .await?;

                    // TODO: LSP returns things like locals, this filters out some of those, but potentially
                    // hides some retrieval issues.
                    if retrieve_result.definitions.is_empty() {
                        continue;
                    }

                    // Compare each retrieved definition to each LSP target:
                    // remember the best (lowest-index) match, whether any LSP
                    // target is external, and whether one falls in the excerpt.
                    let mut best_match = None;
                    let mut has_external_definition = false;
                    let mut in_excerpt = false;
                    for (index, retrieved_definition) in
                        retrieve_result.definitions.iter().enumerate()
                    {
                        for lsp_definition in &lsp_definitions {
                            let SourceRange {
                                path,
                                point_range,
                                offset_range,
                            } = lsp_definition;
                            let lsp_point_range =
                                SerializablePoint::into_language_point_range(point_range.clone());
                            // Absolute paths / node_modules indicate targets
                            // outside the worktree that retrieval can't see.
                            has_external_definition = has_external_definition
                                || path.is_absolute()
                                || path
                                    .components()
                                    .any(|component| component.as_os_str() == "node_modules");
                            let is_match = path.as_path()
                                == retrieved_definition.path.as_std_path()
                                && retrieved_definition
                                    .range
                                    .contains_inclusive(&lsp_point_range);
                            if is_match {
                                if best_match.is_none() {
                                    best_match = Some(index);
                                }
                            }
                            in_excerpt = in_excerpt
                                || retrieve_result.excerpt_range.as_ref().is_some_and(
                                    |excerpt_range| excerpt_range.contains_inclusive(&offset_range),
                                );
                        }
                    }

                    let outcome = if let Some(best_match) = best_match {
                        RetrievalOutcome::Match { best_match }
                    } else if has_external_definition {
                        RetrievalOutcome::NoMatchDueToExternalLspDefinitions
                    } else if in_excerpt {
                        RetrievalOutcome::ProbablyLocal
                    } else {
                        RetrievalOutcome::NoMatch
                    };

                    let result = RetrievalStatsResult {
                        outcome,
                        path: path.clone(),
                        identifier: reference.identifier,
                        point: query_point,
                        lsp_definitions,
                        retrieved_definitions: retrieve_result.definitions,
                    };

                    // Receiver may already be gone on shutdown; ignore send errors.
                    output_tx.unbounded_send(result).ok();
                }

                println!(
                    "{:02}/{:02} done",
                    done_count.fetch_add(1, atomic::Ordering::Relaxed) + 1,
                    files_len,
                );

                anyhow::Ok(())
            })
        })
        .collect::<Vec<_>>();

    // Drop our sender so the results stream terminates once all tasks finish.
    drop(output_tx);

    // Stream results to the stats file while collecting them for aggregation.
    let results_task = cx.background_spawn(async move {
        let mut results = Vec::new();
        while let Some(result) = output_rx.next().await {
            output
                .write_all(format!("{:#?}\n", result).as_bytes())
                .log_err();
            results.push(result)
        }
        results
    });

    futures::future::try_join_all(tasks).await?;
    println!("Tasks completed");
    let results = results_task.await;
    println!("Results received");

    // Tally outcomes for the summary tree printed below.
    let mut references_count = 0;

    let mut included_count = 0;
    let mut both_absent_count = 0;

    let mut retrieved_count = 0;
    let mut top_match_count = 0;
    let mut non_top_match_count = 0;
    let mut ranking_involved_top_match_count = 0;

    let mut no_match_count = 0;
    let mut no_match_none_retrieved = 0;
    let mut no_match_wrong_retrieval = 0;

    let mut expected_no_match_count = 0;
    let mut in_excerpt_count = 0;
    let mut external_definition_count = 0;

    for result in results {
        references_count += 1;
        match &result.outcome {
            RetrievalOutcome::Match { best_match } => {
                included_count += 1;
                retrieved_count += 1;
                // Ranking only "mattered" when there was more than one candidate.
                let multiple = result.retrieved_definitions.len() > 1;
                if *best_match == 0 {
                    top_match_count += 1;
                    if multiple {
                        ranking_involved_top_match_count += 1;
                    }
                } else {
                    non_top_match_count += 1;
                }
            }
            RetrievalOutcome::NoMatch => {
                if result.lsp_definitions.is_empty() {
                    // Neither side found anything — counted as agreement.
                    included_count += 1;
                    both_absent_count += 1;
                } else {
                    no_match_count += 1;
                    if result.retrieved_definitions.is_empty() {
                        no_match_none_retrieved += 1;
                    } else {
                        no_match_wrong_retrieval += 1;
                    }
                }
            }
            RetrievalOutcome::NoMatchDueToExternalLspDefinitions => {
                expected_no_match_count += 1;
                external_definition_count += 1;
            }
            RetrievalOutcome::ProbablyLocal => {
                included_count += 1;
                in_excerpt_count += 1;
            }
        }
    }

    // Helper for the summary lines; prints NaN% when total is 0.
    fn count_and_percentage(part: usize, total: usize) -> String {
        format!("{} ({:.2}%)", part, (part as f64 / total as f64) * 100.0)
    }

    println!("");
    println!("╮ references: {}", references_count);
    println!(
        "├─╮ included: {}",
        count_and_percentage(included_count, references_count),
    );
    println!(
        "│ ├─╮ retrieved: {}",
        count_and_percentage(retrieved_count, references_count)
    );
    println!(
        "│ │ ├─╮ top match : {}",
        count_and_percentage(top_match_count, retrieved_count)
    );
    println!(
        "│ │ │ ╰─╴ involving ranking: {}",
        count_and_percentage(ranking_involved_top_match_count, top_match_count)
    );
    println!(
        "│ │ ╰─╴ non-top match: {}",
        count_and_percentage(non_top_match_count, retrieved_count)
    );
    println!(
        "│ ├─╴ both absent: {}",
        count_and_percentage(both_absent_count, included_count)
    );
    println!(
        "│ ╰─╴ in excerpt: {}",
        count_and_percentage(in_excerpt_count, included_count)
    );
    println!(
        "├─╮ no match: {}",
        count_and_percentage(no_match_count, references_count)
    );
    println!(
        "│ ├─╴ none retrieved: {}",
        count_and_percentage(no_match_none_retrieved, no_match_count)
    );
    println!(
        "│ ╰─╴ wrong retrieval: {}",
        count_and_percentage(no_match_wrong_retrieval, no_match_count)
    );
    println!(
        "╰─╮ expected no match: {}",
        count_and_percentage(expected_no_match_count, references_count)
    );
    println!(
        " ╰─╴ external definition: {}",
        count_and_percentage(external_definition_count, expected_no_match_count)
    );

    println!("");
    println!("LSP definition cache at {}", lsp_definitions_path.display());

    Ok("".to_string())
}
835
/// Result of `retrieve_definitions` for one reference.
struct RetrieveResult {
    // Candidate definitions, sorted by score (best first).
    definitions: Vec<RetrievedDefinition>,
    // Byte range of the cursor excerpt, when context gathering produced one.
    excerpt_range: Option<Range<usize>>,
}
840
/// Runs zeta2 context gathering restricted to a single `reference` and
/// converts the scored declarations into `RetrievedDefinition`s, sorted by
/// descending declaration score.
async fn retrieve_definitions(
    reference: &Reference,
    imports: &Imports,
    query_point: Point,
    snapshot: &BufferSnapshot,
    index: &Arc<SyntaxIndexState>,
    file_snapshots: &Arc<HashMap<ProjectEntryId, BufferSnapshot>>,
    options: &Arc<zeta2::ZetaOptions>,
) -> Result<RetrieveResult> {
    // Override the reference-collection step so only this one identifier is
    // considered, isolating its retrieval from the rest of the file.
    let mut single_reference_map = HashMap::default();
    single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]);
    let edit_prediction_context = EditPredictionContext::gather_context_with_references_fn(
        query_point,
        snapshot,
        imports,
        &options.context,
        Some(&index),
        |_, _, _| single_reference_map,
    );

    // No excerpt could be built — report no definitions rather than failing.
    let Some(edit_prediction_context) = edit_prediction_context else {
        return Ok(RetrieveResult {
            definitions: Vec::new(),
            excerpt_range: None,
        });
    };

    let mut retrieved_definitions = Vec::new();
    for scored_declaration in edit_prediction_context.declarations {
        match &scored_declaration.declaration {
            // File-backed declaration: ranges are offsets into the file snapshot.
            Declaration::File {
                project_entry_id,
                declaration,
                ..
            } => {
                let Some(snapshot) = file_snapshots.get(&project_entry_id) else {
                    log::error!("bug: file project entry not found");
                    continue;
                };
                let path = snapshot.file().unwrap().path().clone();
                retrieved_definitions.push(RetrievedDefinition {
                    path,
                    range: snapshot.offset_to_point(declaration.item_range.start)
                        ..snapshot.offset_to_point(declaration.item_range.end),
                    score: scored_declaration.score(DeclarationStyle::Declaration),
                    retrieval_score: scored_declaration.retrieval_score(),
                    components: scored_declaration.components,
                });
            }
            // Buffer-backed declaration: ranges are offsets into its rope.
            Declaration::Buffer {
                project_entry_id,
                rope,
                declaration,
                ..
            } => {
                let Some(snapshot) = file_snapshots.get(&project_entry_id) else {
                    // This case happens when dependency buffers have been opened by
                    // go-to-definition, resulting in single-file worktrees.
                    continue;
                };
                let path = snapshot.file().unwrap().path().clone();
                retrieved_definitions.push(RetrievedDefinition {
                    path,
                    range: rope.offset_to_point(declaration.item_range.start)
                        ..rope.offset_to_point(declaration.item_range.end),
                    score: scored_declaration.score(DeclarationStyle::Declaration),
                    retrieval_score: scored_declaration.retrieval_score(),
                    components: scored_declaration.components,
                });
            }
        }
    }
    // Highest score first.
    retrieved_definitions.sort_by_key(|definition| Reverse(OrderedFloat(definition.score)));

    Ok(RetrieveResult {
        definitions: retrieved_definitions,
        excerpt_range: Some(edit_prediction_context.excerpt.range),
    })
}
920
/// Collects LSP go-to-definition targets for every reference in `files`.
///
/// Opens each buffer with its language server, waits for the server to go
/// idle, then queries definitions for each reference. Server errors are
/// logged and counted rather than aborting the run.
async fn gather_lsp_definitions(
    files: &[ProjectPath],
    worktree: &Entity<Worktree>,
    project: &Entity<Project>,
    cx: &mut AsyncApp,
) -> Result<LspResults> {
    let worktree_id = worktree.read_with(cx, |worktree, _cx| worktree.id())?;

    // Echo language-server progress messages so long runs show activity.
    let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
    cx.subscribe(&lsp_store, {
        move |_, event, _| {
            if let project::LspStoreEvent::LanguageServerUpdate {
                message:
                    client::proto::update_language_server::Variant::WorkProgress(
                        client::proto::LspWorkProgress {
                            message: Some(message),
                            ..
                        },
                    ),
                ..
            } = event
            {
                println!("⟲ {message}")
            }
        }
    })?
    .detach();

    let mut definitions = HashMap::default();
    let mut error_count = 0;
    // Handles are kept alive for the whole loop so buffers stay registered
    // with their language servers.
    let mut lsp_open_handles = Vec::new();
    let mut ready_languages = HashSet::default();
    for (file_index, project_path) in files.iter().enumerate() {
        println!(
            "Processing file {} of {}: {}",
            file_index + 1,
            files.len(),
            project_path.path.display(PathStyle::Posix)
        );

        // Files that fail to open are skipped (logged by log_err).
        let Some((lsp_open_handle, language_server_id, buffer)) = open_buffer_with_language_server(
            project.clone(),
            worktree.clone(),
            project_path.path.clone(),
            &mut ready_languages,
            cx,
        )
        .await
        .log_err() else {
            continue;
        };
        lsp_open_handles.push(lsp_open_handle);

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
        let full_range = 0..snapshot.len();
        let references = references_in_range(
            full_range,
            &snapshot.text(),
            ReferenceRegion::Nearby,
            &snapshot,
        );

        // Poll (10ms intervals) until the server reports no pending work.
        loop {
            let is_ready = lsp_store
                .read_with(cx, |lsp_store, _cx| {
                    lsp_store
                        .language_server_statuses
                        .get(&language_server_id)
                        .is_some_and(|status| status.pending_work.is_empty())
                })
                .unwrap();
            if is_ready {
                break;
            }
            cx.background_executor()
                .timer(Duration::from_millis(10))
                .await;
        }

        for reference in references {
            // TODO: Rename declaration to definition in edit_prediction_context?
            let lsp_result = project
                .update(cx, |project, cx| {
                    project.definitions(&buffer, reference.range.start, cx)
                })?
                .await;

            match lsp_result {
                Ok(lsp_definitions) => {
                    let mut targets = Vec::new();
                    for target in lsp_definitions.unwrap_or_default() {
                        let buffer = target.target.buffer;
                        let anchor_range = target.target.range;
                        buffer.read_with(cx, |buffer, cx| {
                            let Some(file) = project::File::from_dyn(buffer.file()) else {
                                return;
                            };
                            let file_worktree = file.worktree.read(cx);
                            let file_worktree_id = file_worktree.id();
                            // Relative paths for worktree files, absolute for all others
                            let path = if worktree_id != file_worktree_id {
                                file.worktree.read(cx).absolutize(&file.path)
                            } else {
                                file.path.as_std_path().to_path_buf()
                            };
                            let offset_range = anchor_range.to_offset(&buffer);
                            let point_range = SerializablePoint::from_language_point_range(
                                offset_range.to_point(&buffer),
                            );
                            targets.push(SourceRange {
                                path,
                                offset_range,
                                point_range,
                            });
                        })?;
                    }

                    definitions.insert(
                        SourceLocation {
                            path: project_path.path.clone(),
                            point: snapshot.offset_to_point(reference.range.start),
                        },
                        targets,
                    );
                }
                Err(err) => {
                    log::error!("Language server error: {err}");
                    error_count += 1;
                }
            }
        }
    }

    log::error!("Encountered {} language server errors", error_count);

    Ok(LspResults { definitions })
}
1058
/// On-disk cache of LSP go-to-definition results, keyed by query location.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(transparent)]
struct LspResults {
    definitions: HashMap<SourceLocation, Vec<SourceRange>>,
}
1064
/// A definition target location: worktree-relative path for worktree files,
/// absolute otherwise (see `gather_lsp_definitions`), with both point and
/// byte-offset forms of the range.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SourceRange {
    path: PathBuf,
    point_range: Range<SerializablePoint>,
    offset_range: Range<usize>,
}
1071
/// Serializes to 1-based row and column indices.
/// (The in-memory `language::Point` is 0-based; the `From` impls below
/// perform the off-by-one conversion in each direction.)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializablePoint {
    pub row: u32,
    pub column: u32,
}
1078
impl SerializablePoint {
    /// Converts a serialized (1-based) range into a 0-based `Point` range.
    pub fn into_language_point_range(range: Range<Self>) -> Range<Point> {
        range.start.into()..range.end.into()
    }

    /// Converts a 0-based `Point` range into the serialized (1-based) form.
    pub fn from_language_point_range(range: Range<Point>) -> Range<Self> {
        range.start.into()..range.end.into()
    }
}
1088
1089impl From<Point> for SerializablePoint {
1090 fn from(point: Point) -> Self {
1091 SerializablePoint {
1092 row: point.row + 1,
1093 column: point.column + 1,
1094 }
1095 }
1096}
1097
1098impl From<SerializablePoint> for Point {
1099 fn from(serializable: SerializablePoint) -> Self {
1100 Point {
1101 row: serializable.row.saturating_sub(1),
1102 column: serializable.column.saturating_sub(1),
1103 }
1104 }
1105}
1106
/// One reference's comparison of retrieved vs. LSP definitions; written to
/// the stats file via Debug formatting (hence the dead_code allowances).
#[derive(Debug)]
struct RetrievalStatsResult {
    outcome: RetrievalOutcome,
    #[allow(dead_code)]
    path: Arc<RelPath>,
    #[allow(dead_code)]
    identifier: Identifier,
    #[allow(dead_code)]
    point: Point,
    #[allow(dead_code)]
    lsp_definitions: Vec<SourceRange>,
    retrieved_definitions: Vec<RetrievedDefinition>,
}
1120
/// Classification of one reference's retrieval result.
#[derive(Debug)]
enum RetrievalOutcome {
    Match {
        /// Lowest index within retrieved_definitions that matches an LSP definition.
        best_match: usize,
    },
    // No match, but an LSP target fell inside the cursor excerpt — likely a local.
    ProbablyLocal,
    NoMatch,
    // No match, but LSP pointed outside the worktree (absolute path or
    // node_modules), which retrieval cannot see.
    NoMatchDueToExternalLspDefinitions,
}
1131
/// A definition candidate produced by zeta2 retrieval, with its scores.
#[derive(Debug)]
struct RetrievedDefinition {
    path: Arc<RelPath>,
    range: Range<Point>,
    // Declaration-style score used for ranking.
    score: f32,
    #[allow(dead_code)]
    retrieval_score: f32,
    #[allow(dead_code)]
    components: DeclarationScoreComponents,
}
1142
/// Opens `path` within `worktree` as a project buffer and waits until its
/// syntax parse is complete before returning it.
pub fn open_buffer(
    project: Entity<Project>,
    worktree: Entity<Worktree>,
    path: Arc<RelPath>,
    cx: &AsyncApp,
) -> Task<Result<Entity<Buffer>>> {
    cx.spawn(async move |cx| {
        let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
            worktree_id: worktree.id(),
            path,
        })?;

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
            .await?;

        // Block until parsing settles so callers see a fully-parsed snapshot.
        let mut parse_status = buffer.read_with(cx, |buffer, _cx| buffer.parse_status())?;
        while *parse_status.borrow() != ParseStatus::Idle {
            parse_status.changed().await?;
        }

        Ok(buffer)
    })
}
1167
/// Opens a buffer and registers it with its language servers, waiting until a
/// server is attached.
///
/// Returns the LSP open handle (which must be kept alive to keep the buffer
/// registered), the id of the first language server found for the buffer, and
/// the buffer itself.
pub async fn open_buffer_with_language_server(
    project: Entity<Project>,
    worktree: Entity<Worktree>,
    path: Arc<RelPath>,
    ready_languages: &mut HashSet<LanguageId>,
    cx: &mut AsyncApp,
) -> Result<(Entity<Entity<Buffer>>, LanguageServerId, Entity<Buffer>)> {
    let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;

    let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
        (
            project.register_buffer_with_language_servers(&buffer, cx),
            project.path_style(cx),
        )
    })?;

    let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
        buffer.language().map(|language| language.id())
    })?
    else {
        return Err(anyhow!("No language for {}", path.display(path_style)));
    };

    let log_prefix = path.display(path_style);
    // Only wait for language-server startup once per language; subsequent
    // buffers of an already-ready language skip the wait.
    if !ready_languages.contains(&language_id) {
        wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
        ready_languages.insert(language_id);
    }

    let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;

    // hacky wait for buffer to be registered with the language server
    // (polls up to 100 times with a 10ms timer, ~1s total).
    for _ in 0..100 {
        let Some(language_server_id) = lsp_store.update(cx, |lsp_store, cx| {
            buffer.update(cx, |buffer, cx| {
                lsp_store
                    .language_servers_for_local_buffer(&buffer, cx)
                    .next()
                    .map(|(_, language_server)| language_server.server_id())
            })
        })?
        else {
            cx.background_executor()
                .timer(Duration::from_millis(10))
                .await;
            continue;
        };

        return Ok((lsp_open_handle, language_server_id, buffer));
    }

    return Err(anyhow!("No language server found for buffer"));
}
1221
// TODO: Dedupe with similar function in crates/eval/src/instance.rs
/// Waits until the buffer's language server has finished its initial work,
/// signalled by a `DiskBasedDiagnosticsFinished` event.
///
/// If no language server is attached yet, first waits (up to 5 seconds) for
/// one to be added; the overall wait times out after 5 minutes. Saving the
/// buffer is used to nudge the server into producing diagnostics.
pub fn wait_for_lang_server(
    project: &Entity<Project>,
    buffer: &Entity<Buffer>,
    log_prefix: String,
    cx: &mut AsyncApp,
) -> Task<Result<()>> {
    println!("{}⏵ Waiting for language server", log_prefix);

    // Signalled when disk-based diagnostics finish (the "idle" condition).
    let (mut tx, mut rx) = mpsc::channel(1);

    let lsp_store = project
        .read_with(cx, |project, _| project.lsp_store())
        .unwrap();

    let has_lang_server = buffer
        .update(cx, |buffer, cx| {
            lsp_store.update(cx, |lsp_store, cx| {
                lsp_store
                    .language_servers_for_local_buffer(buffer, cx)
                    .next()
                    .is_some()
            })
        })
        .unwrap_or(false);

    if has_lang_server {
        // A server is already attached: save to trigger diagnostics now.
        project
            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
            .unwrap()
            .detach();
    }
    // Signalled when a language server is added for this buffer.
    let (mut added_tx, mut added_rx) = mpsc::channel(1);

    let subscriptions = [
        cx.subscribe(&lsp_store, {
            let log_prefix = log_prefix.clone();
            move |_, event, _| {
                // Forward LSP work-progress messages to the console.
                if let project::LspStoreEvent::LanguageServerUpdate {
                    message:
                        client::proto::update_language_server::Variant::WorkProgress(
                            client::proto::LspWorkProgress {
                                message: Some(message),
                                ..
                            },
                        ),
                    ..
                } = event
                {
                    println!("{}⟲ {message}", log_prefix)
                }
            }
        }),
        cx.subscribe(project, {
            let buffer = buffer.clone();
            move |project, event, cx| match event {
                project::Event::LanguageServerAdded(_, _, _) => {
                    let buffer = buffer.clone();
                    // Save once the server appears so it starts diagnostics.
                    project
                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
                        .detach();
                    added_tx.try_send(()).ok();
                }
                project::Event::DiskBasedDiagnosticsFinished { .. } => {
                    tx.try_send(()).ok();
                }
                _ => {}
            }
        }),
    ];

    cx.spawn(async move |cx| {
        if !has_lang_server {
            // some buffers never have a language server, so this aborts quickly in that case.
            let timeout = cx.background_executor().timer(Duration::from_secs(5));
            futures::select! {
                _ = added_rx.next() => {},
                _ = timeout.fuse() => {
                    anyhow::bail!("Waiting for language server add timed out after 5 seconds");
                }
            };
        }
        let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
        let result = futures::select! {
            _ = rx.next() => {
                println!("{}⚑ Language server idle", log_prefix);
                anyhow::Ok(())
            },
            _ = timeout.fuse() => {
                anyhow::bail!("LSP wait timed out after 5 minutes");
            }
        };
        // Keep the event subscriptions alive until the wait completes.
        drop(subscriptions);
        result
    })
}
1318
1319fn main() {
1320 zlog::init();
1321 zlog::init_output_stderr();
1322 let args = ZetaCliArgs::parse();
1323 let http_client = Arc::new(ReqwestClient::new());
1324 let app = Application::headless().with_http_client(http_client);
1325
1326 app.run(move |cx| {
1327 let app_state = Arc::new(headless::init(cx));
1328 cx.spawn(async move |cx| {
1329 let result = match args.command {
1330 Commands::Zeta2Context {
1331 zeta2_args,
1332 context_args,
1333 } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
1334 Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
1335 Ok(GetContextOutput::Zeta2(output)) => Ok(output),
1336 Err(err) => Err(err),
1337 },
1338 Commands::Context(context_args) => {
1339 match get_context(None, context_args, &app_state, cx).await {
1340 Ok(GetContextOutput::Zeta1(output)) => {
1341 Ok(serde_json::to_string_pretty(&output.body).unwrap())
1342 }
1343 Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
1344 Err(err) => Err(err),
1345 }
1346 }
1347 Commands::Predict {
1348 predict_edits_body,
1349 context_args,
1350 } => {
1351 cx.spawn(async move |cx| {
1352 let app_version = cx.update(|cx| AppVersion::global(cx))?;
1353 app_state.client.sign_in(true, cx).await?;
1354 let llm_token = LlmApiToken::default();
1355 llm_token.refresh(&app_state.client).await?;
1356
1357 let predict_edits_body =
1358 if let Some(predict_edits_body) = predict_edits_body {
1359 serde_json::from_str(&predict_edits_body.read_to_string().await?)?
1360 } else if let Some(context_args) = context_args {
1361 match get_context(None, context_args, &app_state, cx).await? {
1362 GetContextOutput::Zeta1(output) => output.body,
1363 GetContextOutput::Zeta2 { .. } => unreachable!(),
1364 }
1365 } else {
1366 return Err(anyhow!(
1367 "Expected either --predict-edits-body-file \
1368 or the required args of the `context` command."
1369 ));
1370 };
1371
1372 let (response, _usage) =
1373 Zeta::perform_predict_edits(PerformPredictEditsParams {
1374 client: app_state.client.clone(),
1375 llm_token,
1376 app_version,
1377 body: predict_edits_body,
1378 })
1379 .await?;
1380
1381 Ok(response.output_excerpt)
1382 })
1383 .await
1384 }
1385 Commands::RetrievalStats {
1386 zeta2_args,
1387 worktree,
1388 extension,
1389 limit,
1390 skip,
1391 } => {
1392 retrieval_stats(
1393 worktree,
1394 app_state,
1395 extension,
1396 limit,
1397 skip,
1398 (&zeta2_args).to_options(false),
1399 cx,
1400 )
1401 .await
1402 }
1403 };
1404 match result {
1405 Ok(output) => {
1406 println!("{}", output);
1407 let _ = cx.update(|cx| cx.quit());
1408 }
1409 Err(e) => {
1410 eprintln!("Failed: {:?}", e);
1411 exit(1);
1412 }
1413 }
1414 })
1415 .detach();
1416 });
1417}