@@ -1,47 +1,29 @@
mod headless;
+mod retrieval_stats;
+mod source_location;
+mod util;
-use anyhow::{Context as _, Result, anyhow};
+use crate::retrieval_stats::retrieval_stats;
+use ::util::paths::PathStyle;
+use anyhow::{Result, anyhow};
use clap::{Args, Parser, Subcommand};
-use cloud_llm_client::predict_edits_v3::{self, DeclarationScoreComponents};
+use cloud_llm_client::predict_edits_v3;
use edit_prediction_context::{
- Declaration, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
- EditPredictionExcerptOptions, EditPredictionScoreOptions, Identifier, Imports, Reference,
- ReferenceRegion, SyntaxIndex, SyntaxIndexState, references_in_range,
+ EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
};
-use futures::channel::mpsc;
-use futures::{FutureExt as _, StreamExt as _};
-use gpui::{AppContext, Application, AsyncApp};
-use gpui::{Entity, Task};
-use language::{Bias, BufferSnapshot, LanguageServerId, Point};
-use language::{Buffer, OffsetRangeExt};
-use language::{LanguageId, ParseStatus};
+use gpui::{Application, AsyncApp, prelude::*};
+use language::Bias;
use language_model::LlmApiToken;
-use ordered_float::OrderedFloat;
-use project::{Project, ProjectEntryId, ProjectPath, Worktree};
+use project::Project;
use release_channel::AppVersion;
use reqwest_client::ReqwestClient;
-use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::json;
-use std::cmp::Reverse;
-use std::collections::{HashMap, HashSet};
-use std::fmt::{self, Display};
-use std::fs::File;
-use std::hash::Hash;
-use std::hash::Hasher;
-use std::io::{BufRead, BufReader, BufWriter, Write as _};
-use std::ops::Range;
-use std::path::{Path, PathBuf};
-use std::process::exit;
-use std::str::FromStr;
-use std::sync::atomic::AtomicUsize;
-use std::sync::{Arc, atomic};
-use std::time::Duration;
-use util::paths::PathStyle;
-use util::rel_path::RelPath;
-use util::{RangeExt, ResultExt as _};
+use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc};
use zeta::{PerformPredictEditsParams, Zeta};
use crate::headless::ZetaCliAppState;
+use crate::source_location::SourceLocation;
+use crate::util::{open_buffer, open_buffer_with_language_server};
#[derive(Parser, Debug)]
#[command(name = "zeta")]
@@ -166,70 +148,6 @@ impl FromStr for FileOrStdin {
}
}
-#[derive(Debug, Clone, Hash, Eq, PartialEq)]
-struct SourceLocation {
- path: Arc<RelPath>,
- point: Point,
-}
-
-impl Serialize for SourceLocation {
- fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
- where
- S: Serializer,
- {
- serializer.serialize_str(&self.to_string())
- }
-}
-
-impl<'de> Deserialize<'de> for SourceLocation {
- fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
- where
- D: Deserializer<'de>,
- {
- let s = String::deserialize(deserializer)?;
- s.parse().map_err(serde::de::Error::custom)
- }
-}
-
-impl Display for SourceLocation {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(
- f,
- "{}:{}:{}",
- self.path.display(PathStyle::Posix),
- self.point.row + 1,
- self.point.column + 1
- )
- }
-}
-
-impl FromStr for SourceLocation {
- type Err = anyhow::Error;
-
- fn from_str(s: &str) -> Result<Self> {
- let parts: Vec<&str> = s.split(':').collect();
- if parts.len() != 3 {
- return Err(anyhow!(
- "Invalid source location. Expected 'file.rs:line:column', got '{}'",
- s
- ));
- }
-
- let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
- let line: u32 = parts[1]
- .parse()
- .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
- let column: u32 = parts[2]
- .parse()
- .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
-
- // Convert from 1-based to 0-based indexing
- let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
-
- Ok(SourceLocation { path, point })
- }
-}
-
enum GetContextOutput {
Zeta1(zeta::GatherContextOutput),
Zeta2(String),
@@ -399,1012 +317,6 @@ impl Zeta2Args {
}
}
-pub async fn retrieval_stats(
- worktree: PathBuf,
- app_state: Arc<ZetaCliAppState>,
- only_extension: Option<String>,
- file_limit: Option<usize>,
- skip_files: Option<usize>,
- options: zeta2::ZetaOptions,
- cx: &mut AsyncApp,
-) -> Result<String> {
- let options = Arc::new(options);
- let worktree_path = worktree.canonicalize()?;
-
- let project = cx.update(|cx| {
- Project::local(
- app_state.client.clone(),
- app_state.node_runtime.clone(),
- app_state.user_store.clone(),
- app_state.languages.clone(),
- app_state.fs.clone(),
- None,
- cx,
- )
- })?;
-
- let worktree = project
- .update(cx, |project, cx| {
- project.create_worktree(&worktree_path, true, cx)
- })?
- .await?;
-
- // wait for worktree scan so that wait_for_initial_file_indexing waits for the whole worktree.
- worktree
- .read_with(cx, |worktree, _cx| {
- worktree.as_local().unwrap().scan_complete()
- })?
- .await;
-
- let index = cx.new(|cx| SyntaxIndex::new(&project, options.file_indexing_parallelism, cx))?;
- index
- .read_with(cx, |index, cx| index.wait_for_initial_file_indexing(cx))?
- .await?;
- let indexed_files = index
- .read_with(cx, |index, cx| index.indexed_file_paths(cx))?
- .await;
- let mut filtered_files = indexed_files
- .into_iter()
- .filter(|project_path| {
- let file_extension = project_path.path.extension();
- if let Some(only_extension) = only_extension.as_ref() {
- file_extension.is_some_and(|extension| extension == only_extension)
- } else {
- file_extension
- .is_some_and(|extension| !["md", "json", "sh", "diff"].contains(&extension))
- }
- })
- .collect::<Vec<_>>();
- filtered_files.sort_by(|a, b| a.path.cmp(&b.path));
-
- let index_state = index.read_with(cx, |index, _cx| index.state().clone())?;
- cx.update(|_| {
- drop(index);
- })?;
- let index_state = Arc::new(
- Arc::into_inner(index_state)
- .context("Index state had more than 1 reference")?
- .into_inner(),
- );
-
- struct FileSnapshot {
- project_entry_id: ProjectEntryId,
- snapshot: BufferSnapshot,
- hash: u64,
- parent_abs_path: Arc<Path>,
- }
-
- let files: Vec<FileSnapshot> = futures::future::try_join_all({
- filtered_files
- .iter()
- .map(|file| {
- let buffer_task =
- open_buffer(project.clone(), worktree.clone(), file.path.clone(), cx);
- cx.spawn(async move |cx| {
- let buffer = buffer_task.await?;
- let (project_entry_id, parent_abs_path, snapshot) =
- buffer.read_with(cx, |buffer, cx| {
- let file = project::File::from_dyn(buffer.file()).unwrap();
- let project_entry_id = file.project_entry_id().unwrap();
- let mut parent_abs_path = file.worktree.read(cx).absolutize(&file.path);
- if !parent_abs_path.pop() {
- panic!("Invalid worktree path");
- }
-
- (project_entry_id, parent_abs_path, buffer.snapshot())
- })?;
-
- anyhow::Ok(
- cx.background_spawn(async move {
- let mut hasher = collections::FxHasher::default();
- snapshot.text().hash(&mut hasher);
- FileSnapshot {
- project_entry_id,
- snapshot,
- hash: hasher.finish(),
- parent_abs_path: parent_abs_path.into(),
- }
- })
- .await,
- )
- })
- })
- .collect::<Vec<_>>()
- })
- .await?;
-
- let mut file_snapshots = HashMap::default();
- let mut hasher = collections::FxHasher::default();
- for FileSnapshot {
- project_entry_id,
- snapshot,
- hash,
- ..
- } in &files
- {
- file_snapshots.insert(*project_entry_id, snapshot.clone());
- hash.hash(&mut hasher);
- }
- let files_hash = hasher.finish();
- let file_snapshots = Arc::new(file_snapshots);
-
- let lsp_definitions_path = std::env::current_dir()?.join(format!(
- "target/zeta2-lsp-definitions-{:x}.jsonl",
- files_hash
- ));
-
- let mut lsp_definitions = HashMap::default();
- let mut lsp_files = 0;
-
- if std::fs::exists(&lsp_definitions_path)? {
- log::info!(
- "Using cached LSP definitions from {}",
- lsp_definitions_path.display()
- );
-
- let file = File::options()
- .read(true)
- .write(true)
- .open(&lsp_definitions_path)?;
- let lines = BufReader::new(&file).lines();
- let mut valid_len: usize = 0;
-
- for (line, expected_file) in lines.zip(files.iter()) {
- let line = line?;
- let FileLspDefinitions { path, references } = match serde_json::from_str(&line) {
- Ok(ok) => ok,
- Err(_) => {
- log::error!("Found invalid cache line. Truncating to #{lsp_files}.",);
- file.set_len(valid_len as u64)?;
- break;
- }
- };
- let expected_path = expected_file.snapshot.file().unwrap().path().as_unix_str();
- if expected_path != path.as_ref() {
- log::error!(
- "Expected file #{} to be {expected_path}, but found {path}. Truncating to #{lsp_files}.",
- lsp_files + 1
- );
- file.set_len(valid_len as u64)?;
- break;
- }
- for (point, ranges) in references {
- let Ok(path) = RelPath::new(Path::new(path.as_ref()), PathStyle::Posix) else {
- log::warn!("Invalid path: {}", path);
- continue;
- };
- lsp_definitions.insert(
- SourceLocation {
- path: path.into_arc(),
- point: point.into(),
- },
- ranges,
- );
- }
- lsp_files += 1;
- valid_len += line.len() + 1
- }
- }
-
- if lsp_files < files.len() {
- if lsp_files == 0 {
- log::warn!(
- "No LSP definitions found, populating {}",
- lsp_definitions_path.display()
- );
- } else {
- log::warn!("{} files missing from LSP cache", files.len() - lsp_files);
- }
-
- gather_lsp_definitions(
- &lsp_definitions_path,
- lsp_files,
- &filtered_files,
- &worktree,
- &project,
- &mut lsp_definitions,
- cx,
- )
- .await?;
- }
- let files_len = files.len().min(file_limit.unwrap_or(usize::MAX));
- let done_count = Arc::new(AtomicUsize::new(0));
-
- let (output_tx, mut output_rx) = mpsc::unbounded::<RetrievalStatsResult>();
- let mut output = std::fs::File::create("target/zeta-retrieval-stats.txt")?;
-
- let tasks = files
- .into_iter()
- .skip(skip_files.unwrap_or(0))
- .take(file_limit.unwrap_or(usize::MAX))
- .map(|project_file| {
- let index_state = index_state.clone();
- let lsp_definitions = lsp_definitions.clone();
- let options = options.clone();
- let output_tx = output_tx.clone();
- let done_count = done_count.clone();
- let file_snapshots = file_snapshots.clone();
- cx.background_spawn(async move {
- let snapshot = project_file.snapshot;
-
- let full_range = 0..snapshot.len();
- let references = references_in_range(
- full_range,
- &snapshot.text(),
- ReferenceRegion::Nearby,
- &snapshot,
- );
-
- println!("references: {}", references.len(),);
-
- let imports = if options.context.use_imports {
- Imports::gather(&snapshot, Some(&project_file.parent_abs_path))
- } else {
- Imports::default()
- };
-
- let path = snapshot.file().unwrap().path();
-
- for reference in references {
- let query_point = snapshot.offset_to_point(reference.range.start);
- let source_location = SourceLocation {
- path: path.clone(),
- point: query_point,
- };
- let lsp_definitions = lsp_definitions
- .get(&source_location)
- .cloned()
- .unwrap_or_else(|| {
- log::warn!(
- "No definitions found for source location: {:?}",
- source_location
- );
- Vec::new()
- });
-
- let retrieve_result = retrieve_definitions(
- &reference,
- &imports,
- query_point,
- &snapshot,
- &index_state,
- &file_snapshots,
- &options,
- )
- .await?;
-
- // TODO: LSP returns things like locals, this filters out some of those, but potentially
- // hides some retrieval issues.
- if retrieve_result.definitions.is_empty() {
- continue;
- }
-
- let mut best_match = None;
- let mut has_external_definition = false;
- let mut in_excerpt = false;
- for (index, retrieved_definition) in
- retrieve_result.definitions.iter().enumerate()
- {
- for lsp_definition in &lsp_definitions {
- let SourceRange {
- path,
- point_range,
- offset_range,
- } = lsp_definition;
- let lsp_point_range =
- SerializablePoint::into_language_point_range(point_range.clone());
- has_external_definition = has_external_definition
- || path.is_absolute()
- || path
- .components()
- .any(|component| component.as_os_str() == "node_modules");
- let is_match = path.as_path()
- == retrieved_definition.path.as_std_path()
- && retrieved_definition
- .range
- .contains_inclusive(&lsp_point_range);
- if is_match {
- if best_match.is_none() {
- best_match = Some(index);
- }
- }
- in_excerpt = in_excerpt
- || retrieve_result.excerpt_range.as_ref().is_some_and(
- |excerpt_range| excerpt_range.contains_inclusive(&offset_range),
- );
- }
- }
-
- let outcome = if let Some(best_match) = best_match {
- RetrievalOutcome::Match { best_match }
- } else if has_external_definition {
- RetrievalOutcome::NoMatchDueToExternalLspDefinitions
- } else if in_excerpt {
- RetrievalOutcome::ProbablyLocal
- } else {
- RetrievalOutcome::NoMatch
- };
-
- let result = RetrievalStatsResult {
- outcome,
- path: path.clone(),
- identifier: reference.identifier,
- point: query_point,
- lsp_definitions,
- retrieved_definitions: retrieve_result.definitions,
- };
-
- output_tx.unbounded_send(result).ok();
- }
-
- println!(
- "{:02}/{:02} done",
- done_count.fetch_add(1, atomic::Ordering::Relaxed) + 1,
- files_len,
- );
-
- anyhow::Ok(())
- })
- })
- .collect::<Vec<_>>();
-
- drop(output_tx);
-
- let results_task = cx.background_spawn(async move {
- let mut results = Vec::new();
- while let Some(result) = output_rx.next().await {
- output
- .write_all(format!("{:#?}\n", result).as_bytes())
- .log_err();
- results.push(result)
- }
- results
- });
-
- futures::future::try_join_all(tasks).await?;
- println!("Tasks completed");
- let results = results_task.await;
- println!("Results received");
-
- let mut references_count = 0;
-
- let mut included_count = 0;
- let mut both_absent_count = 0;
-
- let mut retrieved_count = 0;
- let mut top_match_count = 0;
- let mut non_top_match_count = 0;
- let mut ranking_involved_top_match_count = 0;
-
- let mut no_match_count = 0;
- let mut no_match_none_retrieved = 0;
- let mut no_match_wrong_retrieval = 0;
-
- let mut expected_no_match_count = 0;
- let mut in_excerpt_count = 0;
- let mut external_definition_count = 0;
-
- for result in results {
- references_count += 1;
- match &result.outcome {
- RetrievalOutcome::Match { best_match } => {
- included_count += 1;
- retrieved_count += 1;
- let multiple = result.retrieved_definitions.len() > 1;
- if *best_match == 0 {
- top_match_count += 1;
- if multiple {
- ranking_involved_top_match_count += 1;
- }
- } else {
- non_top_match_count += 1;
- }
- }
- RetrievalOutcome::NoMatch => {
- if result.lsp_definitions.is_empty() {
- included_count += 1;
- both_absent_count += 1;
- } else {
- no_match_count += 1;
- if result.retrieved_definitions.is_empty() {
- no_match_none_retrieved += 1;
- } else {
- no_match_wrong_retrieval += 1;
- }
- }
- }
- RetrievalOutcome::NoMatchDueToExternalLspDefinitions => {
- expected_no_match_count += 1;
- external_definition_count += 1;
- }
- RetrievalOutcome::ProbablyLocal => {
- included_count += 1;
- in_excerpt_count += 1;
- }
- }
- }
-
- fn count_and_percentage(part: usize, total: usize) -> String {
- format!("{} ({:.2}%)", part, (part as f64 / total as f64) * 100.0)
- }
-
- println!("");
- println!("╮ references: {}", references_count);
- println!(
- "├─╮ included: {}",
- count_and_percentage(included_count, references_count),
- );
- println!(
- "│ ├─╮ retrieved: {}",
- count_and_percentage(retrieved_count, references_count)
- );
- println!(
- "│ │ ├─╮ top match : {}",
- count_and_percentage(top_match_count, retrieved_count)
- );
- println!(
- "│ │ │ ╰─╴ involving ranking: {}",
- count_and_percentage(ranking_involved_top_match_count, top_match_count)
- );
- println!(
- "│ │ ╰─╴ non-top match: {}",
- count_and_percentage(non_top_match_count, retrieved_count)
- );
- println!(
- "│ ├─╴ both absent: {}",
- count_and_percentage(both_absent_count, included_count)
- );
- println!(
- "│ ╰─╴ in excerpt: {}",
- count_and_percentage(in_excerpt_count, included_count)
- );
- println!(
- "├─╮ no match: {}",
- count_and_percentage(no_match_count, references_count)
- );
- println!(
- "│ ├─╴ none retrieved: {}",
- count_and_percentage(no_match_none_retrieved, no_match_count)
- );
- println!(
- "│ ╰─╴ wrong retrieval: {}",
- count_and_percentage(no_match_wrong_retrieval, no_match_count)
- );
- println!(
- "╰─╮ expected no match: {}",
- count_and_percentage(expected_no_match_count, references_count)
- );
- println!(
- " ╰─╴ external definition: {}",
- count_and_percentage(external_definition_count, expected_no_match_count)
- );
-
- println!("");
- println!("LSP definition cache at {}", lsp_definitions_path.display());
-
- Ok("".to_string())
-}
-
-struct RetrieveResult {
- definitions: Vec<RetrievedDefinition>,
- excerpt_range: Option<Range<usize>>,
-}
-
-async fn retrieve_definitions(
- reference: &Reference,
- imports: &Imports,
- query_point: Point,
- snapshot: &BufferSnapshot,
- index: &Arc<SyntaxIndexState>,
- file_snapshots: &Arc<HashMap<ProjectEntryId, BufferSnapshot>>,
- options: &Arc<zeta2::ZetaOptions>,
-) -> Result<RetrieveResult> {
- let mut single_reference_map = HashMap::default();
- single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]);
- let edit_prediction_context = EditPredictionContext::gather_context_with_references_fn(
- query_point,
- snapshot,
- imports,
- &options.context,
- Some(&index),
- |_, _, _| single_reference_map,
- );
-
- let Some(edit_prediction_context) = edit_prediction_context else {
- return Ok(RetrieveResult {
- definitions: Vec::new(),
- excerpt_range: None,
- });
- };
-
- let mut retrieved_definitions = Vec::new();
- for scored_declaration in edit_prediction_context.declarations {
- match &scored_declaration.declaration {
- Declaration::File {
- project_entry_id,
- declaration,
- ..
- } => {
- let Some(snapshot) = file_snapshots.get(&project_entry_id) else {
- log::error!("bug: file project entry not found");
- continue;
- };
- let path = snapshot.file().unwrap().path().clone();
- retrieved_definitions.push(RetrievedDefinition {
- path,
- range: snapshot.offset_to_point(declaration.item_range.start)
- ..snapshot.offset_to_point(declaration.item_range.end),
- score: scored_declaration.score(DeclarationStyle::Declaration),
- retrieval_score: scored_declaration.retrieval_score(),
- components: scored_declaration.components,
- });
- }
- Declaration::Buffer {
- project_entry_id,
- rope,
- declaration,
- ..
- } => {
- let Some(snapshot) = file_snapshots.get(&project_entry_id) else {
- // This case happens when dependency buffers have been opened by
- // go-to-definition, resulting in single-file worktrees.
- continue;
- };
- let path = snapshot.file().unwrap().path().clone();
- retrieved_definitions.push(RetrievedDefinition {
- path,
- range: rope.offset_to_point(declaration.item_range.start)
- ..rope.offset_to_point(declaration.item_range.end),
- score: scored_declaration.score(DeclarationStyle::Declaration),
- retrieval_score: scored_declaration.retrieval_score(),
- components: scored_declaration.components,
- });
- }
- }
- }
- retrieved_definitions.sort_by_key(|definition| Reverse(OrderedFloat(definition.score)));
-
- Ok(RetrieveResult {
- definitions: retrieved_definitions,
- excerpt_range: Some(edit_prediction_context.excerpt.range),
- })
-}
-
-async fn gather_lsp_definitions(
- lsp_definitions_path: &Path,
- start_index: usize,
- files: &[ProjectPath],
- worktree: &Entity<Worktree>,
- project: &Entity<Project>,
- definitions: &mut HashMap<SourceLocation, Vec<SourceRange>>,
- cx: &mut AsyncApp,
-) -> Result<()> {
- let worktree_id = worktree.read_with(cx, |worktree, _cx| worktree.id())?;
-
- let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
- cx.subscribe(&lsp_store, {
- move |_, event, _| {
- if let project::LspStoreEvent::LanguageServerUpdate {
- message:
- client::proto::update_language_server::Variant::WorkProgress(
- client::proto::LspWorkProgress {
- message: Some(message),
- ..
- },
- ),
- ..
- } = event
- {
- println!("⟲ {message}")
- }
- }
- })?
- .detach();
-
- let (cache_line_tx, mut cache_line_rx) = mpsc::unbounded::<FileLspDefinitions>();
-
- let cache_file = File::options()
- .append(true)
- .create(true)
- .open(lsp_definitions_path)
- .unwrap();
-
- let cache_task = cx.background_spawn(async move {
- let mut writer = BufWriter::new(cache_file);
- while let Some(line) = cache_line_rx.next().await {
- serde_json::to_writer(&mut writer, &line).unwrap();
- writer.write_all(&[b'\n']).unwrap();
- }
- writer.flush().unwrap();
- });
-
- let mut error_count = 0;
- let mut lsp_open_handles = Vec::new();
- let mut ready_languages = HashSet::default();
- for (file_index, project_path) in files[start_index..].iter().enumerate() {
- println!(
- "Processing file {} of {}: {}",
- start_index + file_index + 1,
- files.len(),
- project_path.path.display(PathStyle::Posix)
- );
-
- let Some((lsp_open_handle, language_server_id, buffer)) = open_buffer_with_language_server(
- project.clone(),
- worktree.clone(),
- project_path.path.clone(),
- &mut ready_languages,
- cx,
- )
- .await
- .log_err() else {
- continue;
- };
- lsp_open_handles.push(lsp_open_handle);
-
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
- let full_range = 0..snapshot.len();
- let references = references_in_range(
- full_range,
- &snapshot.text(),
- ReferenceRegion::Nearby,
- &snapshot,
- );
-
- loop {
- let is_ready = lsp_store
- .read_with(cx, |lsp_store, _cx| {
- lsp_store
- .language_server_statuses
- .get(&language_server_id)
- .is_some_and(|status| status.pending_work.is_empty())
- })
- .unwrap();
- if is_ready {
- break;
- }
- cx.background_executor()
- .timer(Duration::from_millis(10))
- .await;
- }
-
- let mut cache_line_references = Vec::with_capacity(references.len());
-
- for reference in references {
- // TODO: Rename declaration to definition in edit_prediction_context?
- let lsp_result = project
- .update(cx, |project, cx| {
- project.definitions(&buffer, reference.range.start, cx)
- })?
- .await;
-
- match lsp_result {
- Ok(lsp_definitions) => {
- let mut targets = Vec::new();
- for target in lsp_definitions.unwrap_or_default() {
- let buffer = target.target.buffer;
- let anchor_range = target.target.range;
- buffer.read_with(cx, |buffer, cx| {
- let Some(file) = project::File::from_dyn(buffer.file()) else {
- return;
- };
- let file_worktree = file.worktree.read(cx);
- let file_worktree_id = file_worktree.id();
- // Relative paths for worktree files, absolute for all others
- let path = if worktree_id != file_worktree_id {
- file.worktree.read(cx).absolutize(&file.path)
- } else {
- file.path.as_std_path().to_path_buf()
- };
- let offset_range = anchor_range.to_offset(&buffer);
- let point_range = SerializablePoint::from_language_point_range(
- offset_range.to_point(&buffer),
- );
- targets.push(SourceRange {
- path,
- offset_range,
- point_range,
- });
- })?;
- }
-
- let point = snapshot.offset_to_point(reference.range.start);
-
- cache_line_references.push((point.into(), targets.clone()));
- definitions.insert(
- SourceLocation {
- path: project_path.path.clone(),
- point,
- },
- targets,
- );
- }
- Err(err) => {
- log::error!("Language server error: {err}");
- error_count += 1;
- }
- }
- }
-
- cache_line_tx
- .unbounded_send(FileLspDefinitions {
- path: project_path.path.as_unix_str().into(),
- references: cache_line_references,
- })
- .log_err();
- }
-
- drop(cache_line_tx);
-
- if error_count > 0 {
- log::error!("Encountered {} language server errors", error_count);
- }
-
- cache_task.await;
-
- Ok(())
-}
-
-#[derive(Serialize, Deserialize)]
-struct FileLspDefinitions {
- path: Arc<str>,
- references: Vec<(SerializablePoint, Vec<SourceRange>)>,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-struct SourceRange {
- path: PathBuf,
- point_range: Range<SerializablePoint>,
- offset_range: Range<usize>,
-}
-
-/// Serializes to 1-based row and column indices.
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct SerializablePoint {
- pub row: u32,
- pub column: u32,
-}
-
-impl SerializablePoint {
- pub fn into_language_point_range(range: Range<Self>) -> Range<Point> {
- range.start.into()..range.end.into()
- }
-
- pub fn from_language_point_range(range: Range<Point>) -> Range<Self> {
- range.start.into()..range.end.into()
- }
-}
-
-impl From<Point> for SerializablePoint {
- fn from(point: Point) -> Self {
- SerializablePoint {
- row: point.row + 1,
- column: point.column + 1,
- }
- }
-}
-
-impl From<SerializablePoint> for Point {
- fn from(serializable: SerializablePoint) -> Self {
- Point {
- row: serializable.row.saturating_sub(1),
- column: serializable.column.saturating_sub(1),
- }
- }
-}
-
-#[derive(Debug)]
-struct RetrievalStatsResult {
- outcome: RetrievalOutcome,
- #[allow(dead_code)]
- path: Arc<RelPath>,
- #[allow(dead_code)]
- identifier: Identifier,
- #[allow(dead_code)]
- point: Point,
- #[allow(dead_code)]
- lsp_definitions: Vec<SourceRange>,
- retrieved_definitions: Vec<RetrievedDefinition>,
-}
-
-#[derive(Debug)]
-enum RetrievalOutcome {
- Match {
- /// Lowest index within retrieved_definitions that matches an LSP definition.
- best_match: usize,
- },
- ProbablyLocal,
- NoMatch,
- NoMatchDueToExternalLspDefinitions,
-}
-
-#[derive(Debug)]
-struct RetrievedDefinition {
- path: Arc<RelPath>,
- range: Range<Point>,
- score: f32,
- #[allow(dead_code)]
- retrieval_score: f32,
- #[allow(dead_code)]
- components: DeclarationScoreComponents,
-}
-
-pub fn open_buffer(
- project: Entity<Project>,
- worktree: Entity<Worktree>,
- path: Arc<RelPath>,
- cx: &AsyncApp,
-) -> Task<Result<Entity<Buffer>>> {
- cx.spawn(async move |cx| {
- let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
- worktree_id: worktree.id(),
- path,
- })?;
-
- let buffer = project
- .update(cx, |project, cx| project.open_buffer(project_path, cx))?
- .await?;
-
- let mut parse_status = buffer.read_with(cx, |buffer, _cx| buffer.parse_status())?;
- while *parse_status.borrow() != ParseStatus::Idle {
- parse_status.changed().await?;
- }
-
- Ok(buffer)
- })
-}
-
-pub async fn open_buffer_with_language_server(
- project: Entity<Project>,
- worktree: Entity<Worktree>,
- path: Arc<RelPath>,
- ready_languages: &mut HashSet<LanguageId>,
- cx: &mut AsyncApp,
-) -> Result<(Entity<Entity<Buffer>>, LanguageServerId, Entity<Buffer>)> {
- let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;
-
- let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
- (
- project.register_buffer_with_language_servers(&buffer, cx),
- project.path_style(cx),
- )
- })?;
-
- let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
- buffer.language().map(|language| language.id())
- })?
- else {
- return Err(anyhow!("No language for {}", path.display(path_style)));
- };
-
- let log_prefix = path.display(path_style);
- if !ready_languages.contains(&language_id) {
- wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
- ready_languages.insert(language_id);
- }
-
- let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
-
- // hacky wait for buffer to be registered with the language server
- for _ in 0..100 {
- let Some(language_server_id) = lsp_store.update(cx, |lsp_store, cx| {
- buffer.update(cx, |buffer, cx| {
- lsp_store
- .language_servers_for_local_buffer(&buffer, cx)
- .next()
- .map(|(_, language_server)| language_server.server_id())
- })
- })?
- else {
- cx.background_executor()
- .timer(Duration::from_millis(10))
- .await;
- continue;
- };
-
- return Ok((lsp_open_handle, language_server_id, buffer));
- }
-
- return Err(anyhow!("No language server found for buffer"));
-}
-
-// TODO: Dedupe with similar function in crates/eval/src/instance.rs
-pub fn wait_for_lang_server(
- project: &Entity<Project>,
- buffer: &Entity<Buffer>,
- log_prefix: String,
- cx: &mut AsyncApp,
-) -> Task<Result<()>> {
- println!("{}⏵ Waiting for language server", log_prefix);
-
- let (mut tx, mut rx) = mpsc::channel(1);
-
- let lsp_store = project
- .read_with(cx, |project, _| project.lsp_store())
- .unwrap();
-
- let has_lang_server = buffer
- .update(cx, |buffer, cx| {
- lsp_store.update(cx, |lsp_store, cx| {
- lsp_store
- .language_servers_for_local_buffer(buffer, cx)
- .next()
- .is_some()
- })
- })
- .unwrap_or(false);
-
- if has_lang_server {
- project
- .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
- .unwrap()
- .detach();
- }
- let (mut added_tx, mut added_rx) = mpsc::channel(1);
-
- let subscriptions = [
- cx.subscribe(&lsp_store, {
- let log_prefix = log_prefix.clone();
- move |_, event, _| {
- if let project::LspStoreEvent::LanguageServerUpdate {
- message:
- client::proto::update_language_server::Variant::WorkProgress(
- client::proto::LspWorkProgress {
- message: Some(message),
- ..
- },
- ),
- ..
- } = event
- {
- println!("{}⟲ {message}", log_prefix)
- }
- }
- }),
- cx.subscribe(project, {
- let buffer = buffer.clone();
- move |project, event, cx| match event {
- project::Event::LanguageServerAdded(_, _, _) => {
- let buffer = buffer.clone();
- project
- .update(cx, |project, cx| project.save_buffer(buffer, cx))
- .detach();
- added_tx.try_send(()).ok();
- }
- project::Event::DiskBasedDiagnosticsFinished { .. } => {
- tx.try_send(()).ok();
- }
- _ => {}
- }
- }),
- ];
-
- cx.spawn(async move |cx| {
- if !has_lang_server {
- // some buffers never have a language server, so this aborts quickly in that case.
- let timeout = cx.background_executor().timer(Duration::from_secs(5));
- futures::select! {
- _ = added_rx.next() => {},
- _ = timeout.fuse() => {
- anyhow::bail!("Waiting for language server add timed out after 5 seconds");
- }
- };
- }
- let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
- let result = futures::select! {
- _ = rx.next() => {
- println!("{}⚑ Language server idle", log_prefix);
- anyhow::Ok(())
- },
- _ = timeout.fuse() => {
- anyhow::bail!("LSP wait timed out after 5 minutes");
- }
- };
- drop(subscriptions);
- result
- })
-}
-
fn main() {
zlog::init();
zlog::init_output_stderr();
@@ -0,0 +1,866 @@
+use ::util::rel_path::RelPath;
+use ::util::{RangeExt, ResultExt as _};
+use anyhow::{Context as _, Result};
+use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
+use edit_prediction_context::{
+ Declaration, DeclarationStyle, EditPredictionContext, Identifier, Imports, Reference,
+ ReferenceRegion, SyntaxIndex, SyntaxIndexState, references_in_range,
+};
+use futures::StreamExt as _;
+use futures::channel::mpsc;
+use gpui::{AppContext, AsyncApp, Entity};
+use language::{BufferSnapshot, OffsetRangeExt, Point};
+use ordered_float::OrderedFloat;
+use project::{Project, ProjectEntryId, ProjectPath, Worktree};
+use serde::{Deserialize, Serialize};
+use std::{
+ cmp::Reverse,
+ collections::{HashMap, HashSet},
+ fs::File,
+ hash::{Hash, Hasher},
+ io::{BufRead, BufReader, BufWriter, Write as _},
+ ops::Range,
+ path::{Path, PathBuf},
+ sync::{
+ Arc,
+ atomic::{self, AtomicUsize},
+ },
+ time::Duration,
+};
+use ::util::paths::PathStyle;
+
+use crate::headless::ZetaCliAppState;
+use crate::source_location::SourceLocation;
+use crate::util::{open_buffer, open_buffer_with_language_server};
+
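+/// Compares syntax-index retrieval against LSP go-to-definition for every
+/// identifier reference in the worktree, printing an outcome summary tree to
+/// stdout. Per-reference results go to target/zeta-retrieval-stats.txt, and LSP
+/// results are cached on disk as JSONL so that reruns are cheap.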
+pub async fn retrieval_stats(
+ worktree: PathBuf,
+ app_state: Arc<ZetaCliAppState>,
+ only_extension: Option<String>,
+ file_limit: Option<usize>,
+ skip_files: Option<usize>,
+ options: zeta2::ZetaOptions,
+ cx: &mut AsyncApp,
+) -> Result<String> {
+ let options = Arc::new(options);
+ let worktree_path = worktree.canonicalize()?;
+
+ let project = cx.update(|cx| {
+ Project::local(
+ app_state.client.clone(),
+ app_state.node_runtime.clone(),
+ app_state.user_store.clone(),
+ app_state.languages.clone(),
+ app_state.fs.clone(),
+ None,
+ cx,
+ )
+ })?;
+
+ let worktree = project
+ .update(cx, |project, cx| {
+ project.create_worktree(&worktree_path, true, cx)
+ })?
+ .await?;
+
+    // Wait for the worktree scan to complete so that wait_for_initial_file_indexing covers the whole worktree.
+ worktree
+ .read_with(cx, |worktree, _cx| {
+ worktree.as_local().unwrap().scan_complete()
+ })?
+ .await;
+
+ let index = cx.new(|cx| SyntaxIndex::new(&project, options.file_indexing_parallelism, cx))?;
+ index
+ .read_with(cx, |index, cx| index.wait_for_initial_file_indexing(cx))?
+ .await?;
+ let indexed_files = index
+ .read_with(cx, |index, cx| index.indexed_file_paths(cx))?
+ .await;
+ let mut filtered_files = indexed_files
+ .into_iter()
+ .filter(|project_path| {
+ let file_extension = project_path.path.extension();
+ if let Some(only_extension) = only_extension.as_ref() {
+ file_extension.is_some_and(|extension| extension == only_extension)
+ } else {
+ file_extension
+ .is_some_and(|extension| !["md", "json", "sh", "diff"].contains(&extension))
+ }
+ })
+ .collect::<Vec<_>>();
+ filtered_files.sort_by(|a, b| a.path.cmp(&b.path));
+
+ let index_state = index.read_with(cx, |index, _cx| index.state().clone())?;
+ cx.update(|_| {
+ drop(index);
+ })?;
+ let index_state = Arc::new(
+ Arc::into_inner(index_state)
+ .context("Index state had more than 1 reference")?
+ .into_inner(),
+ );
+
+ struct FileSnapshot {
+ project_entry_id: ProjectEntryId,
+ snapshot: BufferSnapshot,
+ hash: u64,
+ parent_abs_path: Arc<Path>,
+ }
+
+ let files: Vec<FileSnapshot> = futures::future::try_join_all({
+ filtered_files
+ .iter()
+ .map(|file| {
+ let buffer_task =
+ open_buffer(project.clone(), worktree.clone(), file.path.clone(), cx);
+ cx.spawn(async move |cx| {
+ let buffer = buffer_task.await?;
+ let (project_entry_id, parent_abs_path, snapshot) =
+ buffer.read_with(cx, |buffer, cx| {
+ let file = project::File::from_dyn(buffer.file()).unwrap();
+ let project_entry_id = file.project_entry_id().unwrap();
+ let mut parent_abs_path = file.worktree.read(cx).absolutize(&file.path);
+ if !parent_abs_path.pop() {
+ panic!("Invalid worktree path");
+ }
+
+ (project_entry_id, parent_abs_path, buffer.snapshot())
+ })?;
+
+ anyhow::Ok(
+ cx.background_spawn(async move {
+ let mut hasher = collections::FxHasher::default();
+ snapshot.text().hash(&mut hasher);
+ FileSnapshot {
+ project_entry_id,
+ snapshot,
+ hash: hasher.finish(),
+ parent_abs_path: parent_abs_path.into(),
+ }
+ })
+ .await,
+ )
+ })
+ })
+ .collect::<Vec<_>>()
+ })
+ .await?;
+
+ let mut file_snapshots = HashMap::default();
+ let mut hasher = collections::FxHasher::default();
+ for FileSnapshot {
+ project_entry_id,
+ snapshot,
+ hash,
+ ..
+ } in &files
+ {
+ file_snapshots.insert(*project_entry_id, snapshot.clone());
+ hash.hash(&mut hasher);
+ }
+ let files_hash = hasher.finish();
+ let file_snapshots = Arc::new(file_snapshots);
+
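+    // The LSP cache file name embeds a hash of all indexed file contents, so any
+    // change to those files results in a fresh cache being built.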
+ let lsp_definitions_path = std::env::current_dir()?.join(format!(
+ "target/zeta2-lsp-definitions-{:x}.jsonl",
+ files_hash
+ ));
+
+ let mut lsp_definitions = HashMap::default();
+ let mut lsp_files = 0;
+
+ if std::fs::exists(&lsp_definitions_path)? {
+ log::info!(
+ "Using cached LSP definitions from {}",
+ lsp_definitions_path.display()
+ );
+
+ let file = File::options()
+ .read(true)
+ .write(true)
+ .open(&lsp_definitions_path)?;
+ let lines = BufReader::new(&file).lines();
+ let mut valid_len: usize = 0;
+
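+        // Cache lines must correspond one-to-one, in order, with `files`. On a
+        // parse error or path mismatch, truncate the cache to the last valid line
+        // and regather the remainder below.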
+ for (line, expected_file) in lines.zip(files.iter()) {
+ let line = line?;
+ let FileLspDefinitions { path, references } = match serde_json::from_str(&line) {
+ Ok(ok) => ok,
+ Err(_) => {
+                    log::error!("Found invalid cache line. Truncating to #{lsp_files}.");
+ file.set_len(valid_len as u64)?;
+ break;
+ }
+ };
+ let expected_path = expected_file.snapshot.file().unwrap().path().as_unix_str();
+ if expected_path != path.as_ref() {
+ log::error!(
+ "Expected file #{} to be {expected_path}, but found {path}. Truncating to #{lsp_files}.",
+ lsp_files + 1
+ );
+ file.set_len(valid_len as u64)?;
+ break;
+ }
+ for (point, ranges) in references {
+ let Ok(path) = RelPath::new(Path::new(path.as_ref()), PathStyle::Posix) else {
+ log::warn!("Invalid path: {}", path);
+ continue;
+ };
+ lsp_definitions.insert(
+ SourceLocation {
+ path: path.into_arc(),
+ point: point.into(),
+ },
+ ranges,
+ );
+ }
+ lsp_files += 1;
+            valid_len += line.len() + 1;
+ }
+ }
+
+ if lsp_files < files.len() {
+ if lsp_files == 0 {
+ log::warn!(
+ "No LSP definitions found, populating {}",
+ lsp_definitions_path.display()
+ );
+ } else {
+ log::warn!("{} files missing from LSP cache", files.len() - lsp_files);
+ }
+
+ gather_lsp_definitions(
+ &lsp_definitions_path,
+ lsp_files,
+ &filtered_files,
+ &worktree,
+ &project,
+ &mut lsp_definitions,
+ cx,
+ )
+ .await?;
+ }
+ let files_len = files.len().min(file_limit.unwrap_or(usize::MAX));
+ let done_count = Arc::new(AtomicUsize::new(0));
+
+ let (output_tx, mut output_rx) = mpsc::unbounded::<RetrievalStatsResult>();
+ let mut output = std::fs::File::create("target/zeta-retrieval-stats.txt")?;
+
+ let tasks = files
+ .into_iter()
+ .skip(skip_files.unwrap_or(0))
+ .take(file_limit.unwrap_or(usize::MAX))
+ .map(|project_file| {
+ let index_state = index_state.clone();
+ let lsp_definitions = lsp_definitions.clone();
+ let options = options.clone();
+ let output_tx = output_tx.clone();
+ let done_count = done_count.clone();
+ let file_snapshots = file_snapshots.clone();
+ cx.background_spawn(async move {
+ let snapshot = project_file.snapshot;
+
+ let full_range = 0..snapshot.len();
+ let references = references_in_range(
+ full_range,
+ &snapshot.text(),
+ ReferenceRegion::Nearby,
+ &snapshot,
+ );
+
+                println!("references: {}", references.len());
+
+ let imports = if options.context.use_imports {
+ Imports::gather(&snapshot, Some(&project_file.parent_abs_path))
+ } else {
+ Imports::default()
+ };
+
+ let path = snapshot.file().unwrap().path();
+
+ for reference in references {
+ let query_point = snapshot.offset_to_point(reference.range.start);
+ let source_location = SourceLocation {
+ path: path.clone(),
+ point: query_point,
+ };
+ let lsp_definitions = lsp_definitions
+ .get(&source_location)
+ .cloned()
+ .unwrap_or_else(|| {
+ log::warn!(
+ "No definitions found for source location: {:?}",
+ source_location
+ );
+ Vec::new()
+ });
+
+ let retrieve_result = retrieve_definitions(
+ &reference,
+ &imports,
+ query_point,
+ &snapshot,
+ &index_state,
+ &file_snapshots,
+ &options,
+ )
+ .await?;
+
+                    // TODO: LSP returns results for things like locals; skipping
+                    // references with no retrieved definitions filters some of those
+                    // out, but may also hide real retrieval issues.
+ if retrieve_result.definitions.is_empty() {
+ continue;
+ }
+
+ let mut best_match = None;
+ let mut has_external_definition = false;
+ let mut in_excerpt = false;
+ for (index, retrieved_definition) in
+ retrieve_result.definitions.iter().enumerate()
+ {
+ for lsp_definition in &lsp_definitions {
+ let SourceRange {
+ path,
+ point_range,
+ offset_range,
+ } = lsp_definition;
+ let lsp_point_range =
+ SerializablePoint::into_language_point_range(point_range.clone());
+ has_external_definition = has_external_definition
+ || path.is_absolute()
+ || path
+ .components()
+ .any(|component| component.as_os_str() == "node_modules");
+ let is_match = path.as_path()
+ == retrieved_definition.path.as_std_path()
+ && retrieved_definition
+ .range
+ .contains_inclusive(&lsp_point_range);
+ if is_match {
+ if best_match.is_none() {
+ best_match = Some(index);
+ }
+ }
+ in_excerpt = in_excerpt
+ || retrieve_result.excerpt_range.as_ref().is_some_and(
+ |excerpt_range| excerpt_range.contains_inclusive(&offset_range),
+ );
+ }
+ }
+
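+                    // Classify in precedence order: any agreement with the LSP
+                    // counts as a match; otherwise external (absolute or
+                    // node_modules) LSP targets are expected misses; otherwise a
+                    // definition inside the excerpt itself is probably a local.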
+ let outcome = if let Some(best_match) = best_match {
+ RetrievalOutcome::Match { best_match }
+ } else if has_external_definition {
+ RetrievalOutcome::NoMatchDueToExternalLspDefinitions
+ } else if in_excerpt {
+ RetrievalOutcome::ProbablyLocal
+ } else {
+ RetrievalOutcome::NoMatch
+ };
+
+ let result = RetrievalStatsResult {
+ outcome,
+ path: path.clone(),
+ identifier: reference.identifier,
+ point: query_point,
+ lsp_definitions,
+ retrieved_definitions: retrieve_result.definitions,
+ };
+
+ output_tx.unbounded_send(result).ok();
+ }
+
+ println!(
+ "{:02}/{:02} done",
+ done_count.fetch_add(1, atomic::Ordering::Relaxed) + 1,
+ files_len,
+ );
+
+ anyhow::Ok(())
+ })
+ })
+ .collect::<Vec<_>>();
+
+ drop(output_tx);
+
+ let results_task = cx.background_spawn(async move {
+ let mut results = Vec::new();
+ while let Some(result) = output_rx.next().await {
+ output
+ .write_all(format!("{:#?}\n", result).as_bytes())
+ .log_err();
+ results.push(result)
+ }
+ results
+ });
+
+ futures::future::try_join_all(tasks).await?;
+ println!("Tasks completed");
+ let results = results_task.await;
+ println!("Results received");
+
+ let mut references_count = 0;
+
+ let mut included_count = 0;
+ let mut both_absent_count = 0;
+
+ let mut retrieved_count = 0;
+ let mut top_match_count = 0;
+ let mut non_top_match_count = 0;
+ let mut ranking_involved_top_match_count = 0;
+
+ let mut no_match_count = 0;
+ let mut no_match_none_retrieved = 0;
+ let mut no_match_wrong_retrieval = 0;
+
+ let mut expected_no_match_count = 0;
+ let mut in_excerpt_count = 0;
+ let mut external_definition_count = 0;
+
+ for result in results {
+ references_count += 1;
+ match &result.outcome {
+ RetrievalOutcome::Match { best_match } => {
+ included_count += 1;
+ retrieved_count += 1;
+ let multiple = result.retrieved_definitions.len() > 1;
+ if *best_match == 0 {
+ top_match_count += 1;
+ if multiple {
+ ranking_involved_top_match_count += 1;
+ }
+ } else {
+ non_top_match_count += 1;
+ }
+ }
+ RetrievalOutcome::NoMatch => {
+ if result.lsp_definitions.is_empty() {
+ included_count += 1;
+ both_absent_count += 1;
+ } else {
+ no_match_count += 1;
+ if result.retrieved_definitions.is_empty() {
+ no_match_none_retrieved += 1;
+ } else {
+ no_match_wrong_retrieval += 1;
+ }
+ }
+ }
+ RetrievalOutcome::NoMatchDueToExternalLspDefinitions => {
+ expected_no_match_count += 1;
+ external_definition_count += 1;
+ }
+ RetrievalOutcome::ProbablyLocal => {
+ included_count += 1;
+ in_excerpt_count += 1;
+ }
+ }
+ }
+
+ fn count_and_percentage(part: usize, total: usize) -> String {
+ format!("{} ({:.2}%)", part, (part as f64 / total as f64) * 100.0)
+ }
+
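+    // Outcome tree: "included" covers references where retrieval agreed with the
+    // LSP, both found nothing, or the target was local to the excerpt; "no match"
+    // covers real misses; "expected no match" covers references whose LSP
+    // definitions live outside the worktree.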
+    println!();
+ println!("╮ references: {}", references_count);
+ println!(
+ "├─╮ included: {}",
+ count_and_percentage(included_count, references_count),
+ );
+ println!(
+ "│ ├─╮ retrieved: {}",
+ count_and_percentage(retrieved_count, references_count)
+ );
+ println!(
+        "│ │ ├─╮ top match: {}",
+ count_and_percentage(top_match_count, retrieved_count)
+ );
+ println!(
+ "│ │ │ ╰─╴ involving ranking: {}",
+ count_and_percentage(ranking_involved_top_match_count, top_match_count)
+ );
+ println!(
+ "│ │ ╰─╴ non-top match: {}",
+ count_and_percentage(non_top_match_count, retrieved_count)
+ );
+ println!(
+ "│ ├─╴ both absent: {}",
+ count_and_percentage(both_absent_count, included_count)
+ );
+ println!(
+ "│ ╰─╴ in excerpt: {}",
+ count_and_percentage(in_excerpt_count, included_count)
+ );
+ println!(
+ "├─╮ no match: {}",
+ count_and_percentage(no_match_count, references_count)
+ );
+ println!(
+ "│ ├─╴ none retrieved: {}",
+ count_and_percentage(no_match_none_retrieved, no_match_count)
+ );
+ println!(
+ "│ ╰─╴ wrong retrieval: {}",
+ count_and_percentage(no_match_wrong_retrieval, no_match_count)
+ );
+ println!(
+ "╰─╮ expected no match: {}",
+ count_and_percentage(expected_no_match_count, references_count)
+ );
+ println!(
+ " ╰─╴ external definition: {}",
+ count_and_percentage(external_definition_count, expected_no_match_count)
+ );
+
+    println!();
+ println!("LSP definition cache at {}", lsp_definitions_path.display());
+
+    Ok(String::new())
+}
+
+struct RetrieveResult {
+ definitions: Vec<RetrievedDefinition>,
+ excerpt_range: Option<Range<usize>>,
+}
+
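+/// Gathers edit prediction context for a single reference and returns the
+/// retrieved declarations, sorted by declaration score (best first), along with
+/// the excerpt range used for the query.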
+async fn retrieve_definitions(
+ reference: &Reference,
+ imports: &Imports,
+ query_point: Point,
+ snapshot: &BufferSnapshot,
+ index: &Arc<SyntaxIndexState>,
+ file_snapshots: &Arc<HashMap<ProjectEntryId, BufferSnapshot>>,
+ options: &Arc<zeta2::ZetaOptions>,
+) -> Result<RetrieveResult> {
+ let mut single_reference_map = HashMap::default();
+ single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]);
+ let edit_prediction_context = EditPredictionContext::gather_context_with_references_fn(
+ query_point,
+ snapshot,
+ imports,
+ &options.context,
+ Some(&index),
+ |_, _, _| single_reference_map,
+ );
+
+ let Some(edit_prediction_context) = edit_prediction_context else {
+ return Ok(RetrieveResult {
+ definitions: Vec::new(),
+ excerpt_range: None,
+ });
+ };
+
+ let mut retrieved_definitions = Vec::new();
+ for scored_declaration in edit_prediction_context.declarations {
+ match &scored_declaration.declaration {
+ Declaration::File {
+ project_entry_id,
+ declaration,
+ ..
+ } => {
+                let Some(snapshot) = file_snapshots.get(project_entry_id) else {
+ log::error!("bug: file project entry not found");
+ continue;
+ };
+ let path = snapshot.file().unwrap().path().clone();
+ retrieved_definitions.push(RetrievedDefinition {
+ path,
+ range: snapshot.offset_to_point(declaration.item_range.start)
+ ..snapshot.offset_to_point(declaration.item_range.end),
+ score: scored_declaration.score(DeclarationStyle::Declaration),
+ retrieval_score: scored_declaration.retrieval_score(),
+ components: scored_declaration.components,
+ });
+ }
+ Declaration::Buffer {
+ project_entry_id,
+ rope,
+ declaration,
+ ..
+ } => {
+                let Some(snapshot) = file_snapshots.get(project_entry_id) else {
+ // This case happens when dependency buffers have been opened by
+ // go-to-definition, resulting in single-file worktrees.
+ continue;
+ };
+ let path = snapshot.file().unwrap().path().clone();
+ retrieved_definitions.push(RetrievedDefinition {
+ path,
+ range: rope.offset_to_point(declaration.item_range.start)
+ ..rope.offset_to_point(declaration.item_range.end),
+ score: scored_declaration.score(DeclarationStyle::Declaration),
+ retrieval_score: scored_declaration.retrieval_score(),
+ components: scored_declaration.components,
+ });
+ }
+ }
+ }
+ retrieved_definitions.sort_by_key(|definition| Reverse(OrderedFloat(definition.score)));
+
+ Ok(RetrieveResult {
+ definitions: retrieved_definitions,
+ excerpt_range: Some(edit_prediction_context.excerpt.range),
+ })
+}
+
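+/// Queries the language server for go-to-definition targets of every reference
+/// in `files`, starting at `start_index`, appending one JSONL cache line per
+/// file as results arrive.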
+async fn gather_lsp_definitions(
+ lsp_definitions_path: &Path,
+ start_index: usize,
+ files: &[ProjectPath],
+ worktree: &Entity<Worktree>,
+ project: &Entity<Project>,
+ definitions: &mut HashMap<SourceLocation, Vec<SourceRange>>,
+ cx: &mut AsyncApp,
+) -> Result<()> {
+ let worktree_id = worktree.read_with(cx, |worktree, _cx| worktree.id())?;
+
+ let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
+ cx.subscribe(&lsp_store, {
+ move |_, event, _| {
+ if let project::LspStoreEvent::LanguageServerUpdate {
+ message:
+ client::proto::update_language_server::Variant::WorkProgress(
+ client::proto::LspWorkProgress {
+ message: Some(message),
+ ..
+ },
+ ),
+ ..
+ } = event
+ {
+ println!("⟲ {message}")
+ }
+ }
+ })?
+ .detach();
+
+ let (cache_line_tx, mut cache_line_rx) = mpsc::unbounded::<FileLspDefinitions>();
+
+ let cache_file = File::options()
+ .append(true)
+ .create(true)
+ .open(lsp_definitions_path)
+ .unwrap();
+
+ let cache_task = cx.background_spawn(async move {
+ let mut writer = BufWriter::new(cache_file);
+ while let Some(line) = cache_line_rx.next().await {
+ serde_json::to_writer(&mut writer, &line).unwrap();
+ writer.write_all(&[b'\n']).unwrap();
+ }
+ writer.flush().unwrap();
+ });
+
+ let mut error_count = 0;
+ let mut lsp_open_handles = Vec::new();
+ let mut ready_languages = HashSet::default();
+ for (file_index, project_path) in files[start_index..].iter().enumerate() {
+ println!(
+ "Processing file {} of {}: {}",
+ start_index + file_index + 1,
+ files.len(),
+ project_path.path.display(PathStyle::Posix)
+ );
+
+ let Some((lsp_open_handle, language_server_id, buffer)) = open_buffer_with_language_server(
+ project.clone(),
+ worktree.clone(),
+ project_path.path.clone(),
+ &mut ready_languages,
+ cx,
+ )
+ .await
+ .log_err() else {
+ continue;
+ };
+ lsp_open_handles.push(lsp_open_handle);
+
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+ let full_range = 0..snapshot.len();
+ let references = references_in_range(
+ full_range,
+ &snapshot.text(),
+ ReferenceRegion::Nearby,
+ &snapshot,
+ );
+
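+        // Wait until this file's language server reports no pending work.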
+ loop {
+ let is_ready = lsp_store
+ .read_with(cx, |lsp_store, _cx| {
+ lsp_store
+ .language_server_statuses
+ .get(&language_server_id)
+ .is_some_and(|status| status.pending_work.is_empty())
+ })
+ .unwrap();
+ if is_ready {
+ break;
+ }
+ cx.background_executor()
+ .timer(Duration::from_millis(10))
+ .await;
+ }
+
+ let mut cache_line_references = Vec::with_capacity(references.len());
+
+ for reference in references {
+ // TODO: Rename declaration to definition in edit_prediction_context?
+ let lsp_result = project
+ .update(cx, |project, cx| {
+ project.definitions(&buffer, reference.range.start, cx)
+ })?
+ .await;
+
+ match lsp_result {
+ Ok(lsp_definitions) => {
+ let mut targets = Vec::new();
+ for target in lsp_definitions.unwrap_or_default() {
+ let buffer = target.target.buffer;
+ let anchor_range = target.target.range;
+ buffer.read_with(cx, |buffer, cx| {
+ let Some(file) = project::File::from_dyn(buffer.file()) else {
+ return;
+ };
+ let file_worktree = file.worktree.read(cx);
+ let file_worktree_id = file_worktree.id();
+ // Relative paths for worktree files, absolute for all others
+ let path = if worktree_id != file_worktree_id {
+ file.worktree.read(cx).absolutize(&file.path)
+ } else {
+ file.path.as_std_path().to_path_buf()
+ };
+ let offset_range = anchor_range.to_offset(&buffer);
+ let point_range = SerializablePoint::from_language_point_range(
+ offset_range.to_point(&buffer),
+ );
+ targets.push(SourceRange {
+ path,
+ offset_range,
+ point_range,
+ });
+ })?;
+ }
+
+ let point = snapshot.offset_to_point(reference.range.start);
+
+ cache_line_references.push((point.into(), targets.clone()));
+ definitions.insert(
+ SourceLocation {
+ path: project_path.path.clone(),
+ point,
+ },
+ targets,
+ );
+ }
+ Err(err) => {
+ log::error!("Language server error: {err}");
+ error_count += 1;
+ }
+ }
+ }
+
+ cache_line_tx
+ .unbounded_send(FileLspDefinitions {
+ path: project_path.path.as_unix_str().into(),
+ references: cache_line_references,
+ })
+ .log_err();
+ }
+
+ drop(cache_line_tx);
+
+ if error_count > 0 {
+ log::error!("Encountered {} language server errors", error_count);
+ }
+
+ cache_task.await;
+
+ Ok(())
+}
+
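+/// One line of the JSONL cache: all LSP definition targets for a single file,
+/// keyed by each reference's position.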
+#[derive(Serialize, Deserialize)]
+struct FileLspDefinitions {
+ path: Arc<str>,
+ references: Vec<(SerializablePoint, Vec<SourceRange>)>,
+}
+
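+/// The location of a single LSP definition target. Paths are worktree-relative
+/// for files within the queried worktree and absolute for everything else.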
+#[derive(Debug, Clone, Serialize, Deserialize)]
+struct SourceRange {
+ path: PathBuf,
+ point_range: Range<SerializablePoint>,
+ offset_range: Range<usize>,
+}
+
+/// Serializes to 1-based row and column indices.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct SerializablePoint {
+ pub row: u32,
+ pub column: u32,
+}
+
+impl SerializablePoint {
+ pub fn into_language_point_range(range: Range<Self>) -> Range<Point> {
+ range.start.into()..range.end.into()
+ }
+
+ pub fn from_language_point_range(range: Range<Point>) -> Range<Self> {
+ range.start.into()..range.end.into()
+ }
+}
+
+impl From<Point> for SerializablePoint {
+ fn from(point: Point) -> Self {
+ SerializablePoint {
+ row: point.row + 1,
+ column: point.column + 1,
+ }
+ }
+}
+
+impl From<SerializablePoint> for Point {
+ fn from(serializable: SerializablePoint) -> Self {
+ Point {
+ row: serializable.row.saturating_sub(1),
+ column: serializable.column.saturating_sub(1),
+ }
+ }
+}
+
+#[derive(Debug)]
+struct RetrievalStatsResult {
+ outcome: RetrievalOutcome,
+ #[allow(dead_code)]
+ path: Arc<RelPath>,
+ #[allow(dead_code)]
+ identifier: Identifier,
+ #[allow(dead_code)]
+ point: Point,
+ #[allow(dead_code)]
+ lsp_definitions: Vec<SourceRange>,
+ retrieved_definitions: Vec<RetrievedDefinition>,
+}
+
+#[derive(Debug)]
+enum RetrievalOutcome {
+ Match {
+        /// Lowest index within `retrieved_definitions` that matches an LSP definition.
+ best_match: usize,
+ },
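+    /// No retrieved definition matched, but an LSP definition falls inside the
+    /// query excerpt itself, so the target is probably a local binding.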
+ ProbablyLocal,
+ NoMatch,
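+    /// At least one LSP definition points outside the worktree (an absolute path
+    /// or node_modules), which the syntax index does not cover.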
+ NoMatchDueToExternalLspDefinitions,
+}
+
+#[derive(Debug)]
+struct RetrievedDefinition {
+ path: Arc<RelPath>,
+ range: Range<Point>,
+ score: f32,
+ #[allow(dead_code)]
+ retrieval_score: f32,
+ #[allow(dead_code)]
+ components: DeclarationScoreComponents,
+}
@@ -0,0 +1,186 @@
+use anyhow::{Result, anyhow};
+use futures::channel::mpsc;
+use futures::{FutureExt as _, StreamExt as _};
+use gpui::{AsyncApp, Entity, Task};
+use language::{Buffer, LanguageId, LanguageServerId, ParseStatus};
+use project::{Project, ProjectPath, Worktree};
+use std::{collections::HashSet, sync::Arc, time::Duration};
+use util::rel_path::RelPath;
+
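+/// Opens the buffer at a worktree-relative path and waits for it to be fully
+/// parsed before returning it.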
+pub fn open_buffer(
+ project: Entity<Project>,
+ worktree: Entity<Worktree>,
+ path: Arc<RelPath>,
+ cx: &AsyncApp,
+) -> Task<Result<Entity<Buffer>>> {
+ cx.spawn(async move |cx| {
+ let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
+ worktree_id: worktree.id(),
+ path,
+ })?;
+
+ let buffer = project
+ .update(cx, |project, cx| project.open_buffer(project_path, cx))?
+ .await?;
+
+ let mut parse_status = buffer.read_with(cx, |buffer, _cx| buffer.parse_status())?;
+ while *parse_status.borrow() != ParseStatus::Idle {
+ parse_status.changed().await?;
+ }
+
+ Ok(buffer)
+ })
+}
+
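+/// Opens a buffer and registers it with its language servers, waiting for the
+/// server to become idle the first time each language is encountered. The
+/// returned handle keeps the buffer registered for as long as it is held.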
+pub async fn open_buffer_with_language_server(
+ project: Entity<Project>,
+ worktree: Entity<Worktree>,
+ path: Arc<RelPath>,
+ ready_languages: &mut HashSet<LanguageId>,
+ cx: &mut AsyncApp,
+) -> Result<(Entity<Entity<Buffer>>, LanguageServerId, Entity<Buffer>)> {
+ let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;
+
+ let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
+ (
+ project.register_buffer_with_language_servers(&buffer, cx),
+ project.path_style(cx),
+ )
+ })?;
+
+ let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
+ buffer.language().map(|language| language.id())
+ })?
+ else {
+ return Err(anyhow!("No language for {}", path.display(path_style)));
+ };
+
+ let log_prefix = path.display(path_style);
+ if !ready_languages.contains(&language_id) {
+ wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
+ ready_languages.insert(language_id);
+ }
+
+ let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
+
+    // Hacky: poll until the buffer has been registered with a language server.
+ for _ in 0..100 {
+ let Some(language_server_id) = lsp_store.update(cx, |lsp_store, cx| {
+ buffer.update(cx, |buffer, cx| {
+ lsp_store
+ .language_servers_for_local_buffer(&buffer, cx)
+ .next()
+ .map(|(_, language_server)| language_server.server_id())
+ })
+ })?
+ else {
+ cx.background_executor()
+ .timer(Duration::from_millis(10))
+ .await;
+ continue;
+ };
+
+ return Ok((lsp_open_handle, language_server_id, buffer));
+ }
+
+    Err(anyhow!("No language server found for buffer"))
+}
+
+// TODO: Dedupe with similar function in crates/eval/src/instance.rs
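+/// Saves the buffer to prod its language server into running, then waits for
+/// disk-based diagnostics to finish (with a five-minute timeout).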
+pub fn wait_for_lang_server(
+ project: &Entity<Project>,
+ buffer: &Entity<Buffer>,
+ log_prefix: String,
+ cx: &mut AsyncApp,
+) -> Task<Result<()>> {
+ println!("{}⏵ Waiting for language server", log_prefix);
+
+ let (mut tx, mut rx) = mpsc::channel(1);
+
+ let lsp_store = project
+ .read_with(cx, |project, _| project.lsp_store())
+ .unwrap();
+
+ let has_lang_server = buffer
+ .update(cx, |buffer, cx| {
+ lsp_store.update(cx, |lsp_store, cx| {
+ lsp_store
+ .language_servers_for_local_buffer(buffer, cx)
+ .next()
+ .is_some()
+ })
+ })
+ .unwrap_or(false);
+
+ if has_lang_server {
+ project
+ .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
+ .unwrap()
+ .detach();
+ }
+ let (mut added_tx, mut added_rx) = mpsc::channel(1);
+
+ let subscriptions = [
+ cx.subscribe(&lsp_store, {
+ let log_prefix = log_prefix.clone();
+ move |_, event, _| {
+ if let project::LspStoreEvent::LanguageServerUpdate {
+ message:
+ client::proto::update_language_server::Variant::WorkProgress(
+ client::proto::LspWorkProgress {
+ message: Some(message),
+ ..
+ },
+ ),
+ ..
+ } = event
+ {
+ println!("{}⟲ {message}", log_prefix)
+ }
+ }
+ }),
+ cx.subscribe(project, {
+ let buffer = buffer.clone();
+ move |project, event, cx| match event {
+ project::Event::LanguageServerAdded(_, _, _) => {
+ let buffer = buffer.clone();
+ project
+ .update(cx, |project, cx| project.save_buffer(buffer, cx))
+ .detach();
+ added_tx.try_send(()).ok();
+ }
+ project::Event::DiskBasedDiagnosticsFinished { .. } => {
+ tx.try_send(()).ok();
+ }
+ _ => {}
+ }
+ }),
+ ];
+
+ cx.spawn(async move |cx| {
+ if !has_lang_server {
+            // Some buffers never get a language server, so this bails out quickly in that case.
+ let timeout = cx.background_executor().timer(Duration::from_secs(5));
+ futures::select! {
+ _ = added_rx.next() => {},
+ _ = timeout.fuse() => {
+ anyhow::bail!("Waiting for language server add timed out after 5 seconds");
+ }
+ };
+ }
+ let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
+ let result = futures::select! {
+ _ = rx.next() => {
+ println!("{}⚑ Language server idle", log_prefix);
+ anyhow::Ok(())
+ },
+ _ = timeout.fuse() => {
+ anyhow::bail!("LSP wait timed out after 5 minutes");
+ }
+ };
+ drop(subscriptions);
+ result
+ })
+}