zeta_cli: Avoid unnecessary rechecks in `retrieval-stats` (#39267)

Michael Sloan and Agus created

Before this change, it would save every buffer and wait for diagnostics.
For rust analyzer this would cause a lot of rechecking and greatly slow
down the analysis

Release Notes:

- N/A

Co-authored-by: Agus <agus@zed.dev>

Change summary

crates/edit_prediction_context/src/syntax_index.rs |   4 
crates/zeta_cli/src/main.rs                        | 144 +++++++++++++--
2 files changed, 126 insertions(+), 22 deletions(-)

Detailed changes

crates/edit_prediction_context/src/syntax_index.rs 🔗

@@ -35,6 +35,10 @@ use crate::outline::declarations_in_buffer;
 //
 // * Send multiple selected excerpt ranges. Challenge is that excerpt ranges influence which
 // references are present and their scores.
+//
+// * Include single-file worktrees / non visible worktrees? E.g. go to definition that resolves to a
+// file in a build dependency. Should not be editable in that case - but how to distinguish the case
+// where it should be editable?
 
 // Potential future optimizations:
 //

crates/zeta_cli/src/main.rs 🔗

@@ -11,9 +11,9 @@ use futures::channel::mpsc;
 use futures::{FutureExt as _, StreamExt as _};
 use gpui::{AppContext, Application, AsyncApp};
 use gpui::{Entity, Task};
-use language::Bias;
-use language::Point;
+use language::{Bias, LanguageServerId};
 use language::{Buffer, OffsetRangeExt};
+use language::{LanguageId, Point};
 use language_model::LlmApiToken;
 use ordered_float::OrderedFloat;
 use project::{Project, ProjectPath, Worktree};
@@ -21,7 +21,7 @@ use release_channel::AppVersion;
 use reqwest_client::ReqwestClient;
 use serde_json::json;
 use std::cmp::Reverse;
-use std::collections::HashMap;
+use std::collections::{HashMap, HashSet};
 use std::io::Write as _;
 use std::ops::Range;
 use std::path::{Path, PathBuf};
@@ -222,9 +222,16 @@ async fn get_context(
         })?
         .await?;
 
+    let mut ready_languages = HashSet::default();
     let (_lsp_open_handle, buffer) = if use_language_server {
-        let (lsp_open_handle, buffer) =
-            open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?;
+        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
+            &project,
+            &worktree,
+            &cursor.path,
+            &mut ready_languages,
+            cx,
+        )
+        .await?;
         (Some(lsp_open_handle), buffer)
     } else {
         let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
@@ -373,23 +380,59 @@ pub async fn retrieval_stats(
         .await?;
     let files = index
         .read_with(cx, |index, cx| index.indexed_file_paths(cx))?
-        .await;
+        .await
+        .into_iter()
+        .filter(|project_path| {
+            project_path
+                .path
+                .extension()
+                .is_some_and(|extension| !["md", "json", "sh", "diff"].contains(&extension))
+        })
+        .collect::<Vec<_>>();
+
+    let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
+    cx.subscribe(&lsp_store, {
+        move |_, event, _| {
+            if let project::LspStoreEvent::LanguageServerUpdate {
+                message:
+                    client::proto::update_language_server::Variant::WorkProgress(
+                        client::proto::LspWorkProgress {
+                            message: Some(message),
+                            ..
+                        },
+                    ),
+                ..
+            } = event
+            {
+                println!("⟲ {message}")
+            }
+        }
+    })?
+    .detach();
 
     let mut lsp_open_handles = Vec::new();
     let mut output = std::fs::File::create("retrieval-stats.txt")?;
     let mut results = Vec::new();
+    let mut ready_languages = HashSet::default();
     for (file_index, project_path) in files.iter().enumerate() {
-        println!(
+        let processing_file_message = format!(
             "Processing file {} of {}: {}",
             file_index + 1,
             files.len(),
             project_path.path.display(PathStyle::Posix)
         );
-        let Some((lsp_open_handle, buffer)) =
-            open_buffer_with_language_server(&project, &worktree, &project_path.path, cx)
-                .await
-                .log_err()
-        else {
+        println!("{}", processing_file_message);
+        write!(output, "{processing_file_message}\n\n").ok();
+
+        let Some((lsp_open_handle, language_server_id, buffer)) = open_buffer_with_language_server(
+            &project,
+            &worktree,
+            &project_path.path,
+            &mut ready_languages,
+            cx,
+        )
+        .await
+        .log_err() else {
             continue;
         };
         lsp_open_handles.push(lsp_open_handle);
@@ -403,6 +446,23 @@ pub async fn retrieval_stats(
             &snapshot,
         );
 
+        loop {
+            let is_ready = lsp_store
+                .read_with(cx, |lsp_store, _cx| {
+                    lsp_store
+                        .language_server_statuses
+                        .get(&language_server_id)
+                        .is_some_and(|status| status.pending_work.is_empty())
+                })
+                .unwrap();
+            if is_ready {
+                break;
+            }
+            cx.background_executor()
+                .timer(Duration::from_millis(10))
+                .await;
+        }
+
         let index = index.read_with(cx, |index, _cx| index.state().clone())?;
         let index = index.lock().await;
         for reference in references {
@@ -472,7 +532,8 @@ pub async fn retrieval_stats(
                                 .map(|entry| entry.path.clone())
                         })?
                         else {
-                            log::error!("bug: buffer project entry not found");
+                            // This case happens when dependency buffers have been opened by
+                            // go-to-definition, resulting in single-file worktrees.
                             continue;
                         };
                         retrieved_definitions.push((
@@ -511,10 +572,16 @@ pub async fn retrieval_stats(
                                 .target
                                 .buffer
                                 .read_with(cx, |buffer, _cx| {
-                                    Some((
-                                        buffer.file()?.path().clone(),
-                                        definition.target.range.to_point(&buffer),
-                                    ))
+                                    let path = buffer.file()?.path();
+                                    // filter out definitions from single-file worktrees
+                                    if path.is_empty() {
+                                        None
+                                    } else {
+                                        Some((
+                                            path.clone(),
+                                            definition.target.range.to_point(&buffer),
+                                        ))
+                                    }
                                 })
                                 .ok()?
                         })
@@ -659,8 +726,9 @@ pub async fn open_buffer_with_language_server(
     project: &Entity<Project>,
     worktree: &Entity<Worktree>,
     path: &RelPath,
+    ready_languages: &mut HashSet<LanguageId>,
     cx: &mut AsyncApp,
-) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
+) -> Result<(Entity<Entity<Buffer>>, LanguageServerId, Entity<Buffer>)> {
     let buffer = open_buffer(project, worktree, path, cx).await?;
 
     let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
@@ -670,10 +738,42 @@ pub async fn open_buffer_with_language_server(
         )
     })?;
 
+    let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
+        buffer.language().map(|language| language.id())
+    })?
+    else {
+        return Err(anyhow!("No language for {}", path.display(path_style)));
+    };
+
     let log_prefix = path.display(path_style);
-    wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
+    if !ready_languages.contains(&language_id) {
+        wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
+        ready_languages.insert(language_id);
+    }
+
+    let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
+
+    // hacky wait for buffer to be registered with the language server
+    for _ in 0..100 {
+        let Some(language_server_id) = lsp_store.update(cx, |lsp_store, cx| {
+            buffer.update(cx, |buffer, cx| {
+                lsp_store
+                    .language_servers_for_local_buffer(&buffer, cx)
+                    .next()
+                    .map(|(_, language_server)| language_server.server_id())
+            })
+        })?
+        else {
+            cx.background_executor()
+                .timer(Duration::from_millis(10))
+                .await;
+            continue;
+        };
+
+        return Ok((lsp_open_handle, language_server_id, buffer));
+    }
 
-    Ok((lsp_open_handle, buffer))
+    return Err(anyhow!("No language server found for buffer"));
 }
 
 // TODO: Dedupe with similar function in crates/eval/src/instance.rs
@@ -750,11 +850,11 @@ pub fn wait_for_lang_server(
     cx.spawn(async move |cx| {
         if !has_lang_server {
             // some buffers never have a language server, so this aborts quickly in that case.
-            let timeout = cx.background_executor().timer(Duration::from_secs(1));
+            let timeout = cx.background_executor().timer(Duration::from_secs(5));
             futures::select! {
                 _ = added_rx.next() => {},
                 _ = timeout.fuse() => {
-                    anyhow::bail!("Waiting for language server add timed out after 1 second");
+                    anyhow::bail!("Waiting for language server add timed out after 5 seconds");
                 }
             };
         }