edit prediction cli: Improve language server reliability (#44666)

Agus Zubiaga created

We weren't waiting for ALL language servers of a buffer to start, only
the first one.

Release Notes:

- N/A

Change summary

crates/edit_prediction_cli/src/load_project.rs     |   5 
crates/edit_prediction_cli/src/retrieve_context.rs | 175 ++++++---------
crates/eval/src/instance.rs                        |   2 
crates/project/src/lsp_store.rs                    |  15 +
crates/project/src/project.rs                      |   2 
crates/project/src/project_tests.rs                |   6 
6 files changed, 93 insertions(+), 112 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/load_project.rs 🔗

@@ -223,11 +223,6 @@ pub async fn setup_worktree(example: &Example) {
         let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"])
             .await
             .unwrap();
-        if revision != example.revision {
-            run_git(&repo_dir, &["tag", &example.revision, &revision])
-                .await
-                .unwrap();
-        }
         revision
     };
 

crates/edit_prediction_cli/src/retrieve_context.rs 🔗

@@ -3,11 +3,10 @@ use crate::{
     headless::EpAppState,
     load_project::run_load_project,
 };
-use anyhow::Result;
 use collections::HashSet;
 use edit_prediction::{DebugEvent, EditPredictionStore};
 use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
-use gpui::{AsyncApp, Entity, Task};
+use gpui::{AsyncApp, Entity};
 use language::Buffer;
 use project::Project;
 use std::{sync::Arc, time::Duration};
@@ -31,8 +30,7 @@ pub async fn run_context_retrieval(
             project.register_buffer_with_language_servers(&state.buffer, cx)
         })
         .unwrap();
-
-    wait_for_language_server_to_start(example, &project, &state.buffer, &mut cx).await;
+    wait_for_language_servers_to_start(example, &project, &state.buffer, &mut cx).await;
 
     let ep_store = cx
         .update(|cx| EditPredictionStore::try_global(cx).unwrap())
@@ -65,92 +63,73 @@ pub async fn run_context_retrieval(
     });
 }
 
-async fn wait_for_language_server_to_start(
+async fn wait_for_language_servers_to_start(
     example: &Example,
     project: &Entity<Project>,
     buffer: &Entity<Buffer>,
     cx: &mut AsyncApp,
 ) {
-    let Some(language_id) = buffer
-        .read_with(cx, |buffer, _cx| {
-            buffer.language().map(|language| language.id())
-        })
-        .unwrap()
-    else {
-        panic!("No language for {:?}", example.cursor_path);
-    };
-
-    let mut ready_languages = HashSet::default();
     let log_prefix = format!("{} | ", example.name);
-    if !ready_languages.contains(&language_id) {
-        wait_for_lang_server(&project, &buffer, log_prefix, cx)
-            .await
-            .unwrap();
-        ready_languages.insert(language_id);
-    }
-
-    let lsp_store = project
-        .read_with(cx, |project, _cx| project.lsp_store())
-        .unwrap();
-
-    // hacky wait for buffer to be registered with the language server
-    for _ in 0..100 {
-        if lsp_store
-            .update(cx, |lsp_store, cx| {
-                buffer.update(cx, |buffer, cx| {
-                    lsp_store
-                        .language_servers_for_local_buffer(&buffer, cx)
-                        .next()
-                        .map(|(_, language_server)| language_server.server_id())
-                })
-            })
-            .unwrap()
-            .is_some()
-        {
-            return;
-        } else {
-            cx.background_executor()
-                .timer(Duration::from_millis(10))
-                .await;
-        }
-    }
-
-    panic!("No language server found for buffer");
-}
-
-pub fn wait_for_lang_server(
-    project: &Entity<Project>,
-    buffer: &Entity<Buffer>,
-    log_prefix: String,
-    cx: &mut AsyncApp,
-) -> Task<Result<()>> {
-    eprintln!("{}⏵ Waiting for language server", log_prefix);
-
-    let (mut tx, mut rx) = mpsc::channel(1);
 
     let lsp_store = project
         .read_with(cx, |project, _| project.lsp_store())
         .unwrap();
 
-    let has_lang_server = buffer
+    let lang_server_ids = buffer
         .update(cx, |buffer, cx| {
             lsp_store.update(cx, |lsp_store, cx| {
-                lsp_store
-                    .language_servers_for_local_buffer(buffer, cx)
-                    .next()
-                    .is_some()
+                lsp_store.language_servers_for_local_buffer(buffer, cx)
             })
         })
-        .unwrap_or(false);
+        .unwrap_or_default();
 
-    if has_lang_server {
+    if !lang_server_ids.is_empty() {
         project
             .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
             .unwrap()
             .detach();
     }
-    let (mut added_tx, mut added_rx) = mpsc::channel(1);
 
+    eprintln!(
+        "{}⏵ Waiting for {} language servers",
+        log_prefix,
+        lang_server_ids.len()
+    );
+
+    let timeout = cx
+        .background_executor()
+        .timer(Duration::from_secs(60 * 5))
+        .shared();
+
+    let (mut tx, mut rx) = mpsc::channel(lang_server_ids.len());
+    let added_subscription = cx.subscribe(project, {
+        let log_prefix = log_prefix.clone();
+        move |_, event, _| match event {
+            project::Event::LanguageServerAdded(language_server_id, name, _) => {
+                eprintln!("{}+ Language server started: {}", log_prefix, name);
+                tx.try_send(*language_server_id).ok();
+            }
+            _ => {}
+        }
+    });
+
+    let mut pending_language_server_ids = HashSet::from_iter(lang_server_ids.iter());
+    while !pending_language_server_ids.is_empty() {
+        futures::select! {
+            language_server_id = rx.next() => {
+                if let Some(id) = language_server_id {
+                    pending_language_server_ids.remove(&id);
+                }
+            },
+            _ = timeout.clone().fuse() => {
+                panic!("LSP wait timed out after 5 minutes");
+            }
+        }
+    }
+
+    drop(added_subscription);
+
+    let (mut tx, mut rx) = mpsc::channel(lang_server_ids.len());
     let subscriptions = [
         cx.subscribe(&lsp_store, {
             let log_prefix = log_prefix.clone();
@@ -171,45 +150,41 @@ pub fn wait_for_lang_server(
             }
         }),
         cx.subscribe(project, {
-            let buffer = buffer.clone();
-            move |project, event, cx| match event {
-                project::Event::LanguageServerAdded(_, _, _) => {
-                    let buffer = buffer.clone();
-                    project
-                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
-                        .detach();
-                    added_tx.try_send(()).ok();
-                }
-                project::Event::DiskBasedDiagnosticsFinished { .. } => {
-                    tx.try_send(()).ok();
+            let log_prefix = log_prefix.clone();
+            move |_, event, cx| match event {
+                project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
+                    let lsp_store = lsp_store.read(cx);
+                    let name = lsp_store
+                        .language_server_adapter_for_id(*language_server_id)
+                        .unwrap()
+                        .name();
+                    eprintln!("{}⚑ Language server idle: {}", log_prefix, name);
+                    tx.try_send(*language_server_id).ok();
                 }
                 _ => {}
             }
         }),
     ];
 
-    cx.spawn(async move |cx| {
-        if !has_lang_server {
-            // some buffers never have a language server, so this aborts quickly in that case.
-            let timeout = cx.background_executor().timer(Duration::from_secs(500));
-            futures::select! {
-                _ = added_rx.next() => {},
-                _ = timeout.fuse() => {
-                    anyhow::bail!("Waiting for language server add timed out after 5 seconds");
+    project
+        .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
+        .unwrap()
+        .await
+        .unwrap();
+
+    let mut pending_language_server_ids = HashSet::from_iter(lang_server_ids.into_iter());
+    while !pending_language_server_ids.is_empty() {
+        futures::select! {
+            language_server_id = rx.next() => {
+                if let Some(id) = language_server_id {
+                    pending_language_server_ids.remove(&id);
                 }
-            };
-        }
-        let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
-        let result = futures::select! {
-            _ = rx.next() => {
-                eprintln!("{}⚑ Language server idle", log_prefix);
-                anyhow::Ok(())
             },
-            _ = timeout.fuse() => {
-                anyhow::bail!("LSP wait timed out after 5 minutes");
+            _ = timeout.clone().fuse() => {
+                panic!("LSP wait timed out after 5 minutes");
             }
-        };
-        drop(subscriptions);
-        result
-    })
+        }
+    }
+
+    drop(subscriptions);
 }

crates/eval/src/instance.rs 🔗

@@ -892,7 +892,7 @@ pub fn wait_for_lang_server(
         .update(cx, |buffer, cx| {
             lsp_store.update(cx, |lsp_store, cx| {
                 lsp_store
-                    .language_servers_for_local_buffer(buffer, cx)
+                    .running_language_servers_for_local_buffer(buffer, cx)
                     .next()
                     .is_some()
             })

crates/project/src/lsp_store.rs 🔗

@@ -6783,7 +6783,7 @@ impl LspStore {
             })
         } else {
             let servers = buffer.update(cx, |buffer, cx| {
-                self.language_servers_for_local_buffer(buffer, cx)
+                self.running_language_servers_for_local_buffer(buffer, cx)
                     .map(|(_, server)| server.clone())
                     .collect::<Vec<_>>()
             });
@@ -8123,7 +8123,7 @@ impl LspStore {
         })
     }
 
-    pub fn language_servers_for_local_buffer<'a>(
+    pub fn running_language_servers_for_local_buffer<'a>(
         &'a self,
         buffer: &Buffer,
         cx: &mut App,
@@ -8145,6 +8145,17 @@ impl LspStore {
             )
     }
 
+    pub fn language_servers_for_local_buffer(
+        &self,
+        buffer: &Buffer,
+        cx: &mut App,
+    ) -> Vec<LanguageServerId> {
+        let local = self.as_local();
+        local
+            .map(|local| local.language_server_ids_for_buffer(buffer, cx))
+            .unwrap_or_default()
+    }
+
     pub fn language_server_for_local_buffer<'a>(
         &'a self,
         buffer: &'a Buffer,

crates/project/src/project.rs 🔗

@@ -5190,7 +5190,7 @@ impl Project {
     #[cfg(any(test, feature = "test-support"))]
     pub fn has_language_servers_for(&self, buffer: &Buffer, cx: &mut App) -> bool {
         self.lsp_store.update(cx, |this, cx| {
-            this.language_servers_for_local_buffer(buffer, cx)
+            this.running_language_servers_for_local_buffer(buffer, cx)
                 .next()
                 .is_some()
         })

crates/project/src/project_tests.rs 🔗

@@ -691,7 +691,7 @@ async fn test_running_multiple_instances_of_a_single_server_in_one_worktree(
     let servers = project.update(cx, |project, cx| {
         project.lsp_store.update(cx, |this, cx| {
             first_buffer.update(cx, |buffer, cx| {
-                this.language_servers_for_local_buffer(buffer, cx)
+                this.running_language_servers_for_local_buffer(buffer, cx)
                     .map(|(adapter, server)| (adapter.clone(), server.clone()))
                     .collect::<Vec<_>>()
             })
@@ -720,7 +720,7 @@ async fn test_running_multiple_instances_of_a_single_server_in_one_worktree(
     let servers = project.update(cx, |project, cx| {
         project.lsp_store.update(cx, |this, cx| {
             second_project_buffer.update(cx, |buffer, cx| {
-                this.language_servers_for_local_buffer(buffer, cx)
+                this.running_language_servers_for_local_buffer(buffer, cx)
                     .map(|(adapter, server)| (adapter.clone(), server.clone()))
                     .collect::<Vec<_>>()
             })
@@ -791,7 +791,7 @@ async fn test_running_multiple_instances_of_a_single_server_in_one_worktree(
     let servers = project.update(cx, |project, cx| {
         project.lsp_store.update(cx, |this, cx| {
             second_project_buffer.update(cx, |buffer, cx| {
-                this.language_servers_for_local_buffer(buffer, cx)
+                this.running_language_servers_for_local_buffer(buffer, cx)
                     .map(|(adapter, server)| (adapter.clone(), server.clone()))
                     .collect::<Vec<_>>()
             })