Store `result_id`s per language server (#32631)

Kirill Bulatov created

Follow-up of https://github.com/zed-industries/zed/pull/32403


Release Notes:

- N/A

Change summary

crates/editor/src/editor_tests.rs |   6 
crates/project/src/lsp_command.rs |  85 +++++-----
crates/project/src/lsp_store.rs   | 246 ++++++++++++++++++--------------
3 files changed, 186 insertions(+), 151 deletions(-)

Detailed changes

crates/editor/src/editor_tests.rs 🔗

@@ -21942,6 +21942,7 @@ async fn test_pulling_diagnostics(cx: &mut TestAppContext) {
         .downcast::<Editor>()
         .unwrap();
     let fake_server = fake_servers.next().await.unwrap();
+    let server_id = fake_server.server.server_id();
     let mut first_request = fake_server
         .set_request_handler::<lsp::request::DocumentDiagnosticRequest, _, _>(move |params, _| {
             let new_result_id = counter.fetch_add(1, atomic::Ordering::Release) + 1;
@@ -21973,7 +21974,10 @@ async fn test_pulling_diagnostics(cx: &mut TestAppContext) {
                 .expect("created a singleton buffer")
                 .read(cx)
                 .remote_id();
-            let buffer_result_id = project.lsp_store().read(cx).result_id(buffer_id, cx);
+            let buffer_result_id = project
+                .lsp_store()
+                .read(cx)
+                .result_id(server_id, buffer_id, cx);
             assert_eq!(expected, buffer_result_id);
         });
     };

crates/project/src/lsp_command.rs 🔗

@@ -3668,6 +3668,39 @@ impl LspCommand for LinkedEditingRange {
 }
 
 impl GetDocumentDiagnostics {
+    pub fn diagnostics_from_proto(
+        response: proto::GetDocumentDiagnosticsResponse,
+    ) -> Vec<LspPullDiagnostics> {
+        response
+            .pulled_diagnostics
+            .into_iter()
+            .filter_map(|diagnostics| {
+                Some(LspPullDiagnostics::Response {
+                    server_id: LanguageServerId::from_proto(diagnostics.server_id),
+                    uri: lsp::Url::from_str(diagnostics.uri.as_str()).log_err()?,
+                    diagnostics: if diagnostics.changed {
+                        PulledDiagnostics::Unchanged {
+                            result_id: diagnostics.result_id?,
+                        }
+                    } else {
+                        PulledDiagnostics::Changed {
+                            result_id: diagnostics.result_id,
+                            diagnostics: diagnostics
+                                .diagnostics
+                                .into_iter()
+                                .filter_map(|diagnostic| {
+                                    GetDocumentDiagnostics::deserialize_lsp_diagnostic(diagnostic)
+                                        .context("deserializing diagnostics")
+                                        .log_err()
+                                })
+                                .collect(),
+                        }
+                    },
+                })
+            })
+            .collect()
+    }
+
     fn deserialize_lsp_diagnostic(diagnostic: proto::LspDiagnostic) -> Result<lsp::Diagnostic> {
         let start = diagnostic.start.context("invalid start range")?;
         let end = diagnostic.end.context("invalid end range")?;
@@ -4037,21 +4070,14 @@ impl LspCommand for GetDocumentDiagnostics {
     }
 
     async fn from_proto(
-        message: proto::GetDocumentDiagnostics,
-        lsp_store: Entity<LspStore>,
-        buffer: Entity<Buffer>,
-        mut cx: AsyncApp,
+        _: proto::GetDocumentDiagnostics,
+        _: Entity<LspStore>,
+        _: Entity<Buffer>,
+        _: AsyncApp,
     ) -> Result<Self> {
-        buffer
-            .update(&mut cx, |buffer, _| {
-                buffer.wait_for_version(deserialize_version(&message.version))
-            })?
-            .await?;
-        let buffer_id = buffer.update(&mut cx, |buffer, _| buffer.remote_id())?;
-        Ok(Self {
-            previous_result_id: lsp_store
-                .update(&mut cx, |lsp_store, cx| lsp_store.result_id(buffer_id, cx))?,
-        })
+        anyhow::bail!(
+            "proto::GetDocumentDiagnostics is not expected to be converted from proto directly, as it needs `previous_result_id` fetched first"
+        )
     }
 
     fn response_to_proto(
@@ -4109,36 +4135,7 @@ impl LspCommand for GetDocumentDiagnostics {
         _: Entity<Buffer>,
         _: AsyncApp,
     ) -> Result<Self::Response> {
-        let pulled_diagnostics = response
-            .pulled_diagnostics
-            .into_iter()
-            .filter_map(|diagnostics| {
-                Some(LspPullDiagnostics::Response {
-                    server_id: LanguageServerId::from_proto(diagnostics.server_id),
-                    uri: lsp::Url::from_str(diagnostics.uri.as_str()).log_err()?,
-                    diagnostics: if diagnostics.changed {
-                        PulledDiagnostics::Unchanged {
-                            result_id: diagnostics.result_id?,
-                        }
-                    } else {
-                        PulledDiagnostics::Changed {
-                            result_id: diagnostics.result_id,
-                            diagnostics: diagnostics
-                                .diagnostics
-                                .into_iter()
-                                .filter_map(|diagnostic| {
-                                    GetDocumentDiagnostics::deserialize_lsp_diagnostic(diagnostic)
-                                        .context("deserializing diagnostics")
-                                        .log_err()
-                                })
-                                .collect(),
-                        }
-                    },
-                })
-            })
-            .collect();
-
-        Ok(pulled_diagnostics)
+        Ok(Self::diagnostics_from_proto(response))
     }
 
     fn buffer_id_from_proto(message: &proto::GetDocumentDiagnostics) -> Result<BufferId> {

crates/project/src/lsp_store.rs 🔗

@@ -166,7 +166,7 @@ pub struct LocalLspStore {
     _subscription: gpui::Subscription,
     lsp_tree: Entity<LanguageServerTree>,
     registered_buffers: HashMap<BufferId, usize>,
-    buffer_pull_diagnostics_result_ids: HashMap<PathBuf, Option<String>>,
+    buffer_pull_diagnostics_result_ids: HashMap<PathBuf, HashMap<LanguageServerId, Option<String>>>,
 }
 
 impl LocalLspStore {
@@ -316,11 +316,13 @@ impl LocalLspStore {
                         })?
                         .await
                         .inspect_err(|_| {
-                            if let Some(this) = this.upgrade() {
-                                this.update(cx, |_, cx| {
-                                    cx.emit(LspStoreEvent::LanguageServerRemoved(server_id))
-                                })
-                                .ok();
+                            if let Some(lsp_store) = this.upgrade() {
+                                lsp_store
+                                    .update(cx, |lsp_store, cx| {
+                                        lsp_store.remove_result_ids(server_id);
+                                        cx.emit(LspStoreEvent::LanguageServerRemoved(server_id))
+                                    })
+                                    .ok();
                             }
                         })?;
 
@@ -2297,7 +2299,9 @@ impl LocalLspStore {
         buffer.update(cx, |buffer, cx| {
             if let Some(abs_path) = File::from_dyn(buffer.file()).map(|f| f.abs_path(cx)) {
                 self.buffer_pull_diagnostics_result_ids
-                    .insert(abs_path, result_id);
+                    .entry(abs_path)
+                    .or_default()
+                    .insert(server_id, result_id);
             }
 
             buffer.update_diagnostics(server_id, set, cx)
@@ -3134,12 +3138,15 @@ impl LocalLspStore {
                     server_ids.remove(server_id_to_remove);
                 });
             self.language_server_watched_paths
-                .remove(&server_id_to_remove);
+                .remove(server_id_to_remove);
             self.language_server_paths_watched_for_rename
-                .remove(&server_id_to_remove);
+                .remove(server_id_to_remove);
             self.last_workspace_edits_by_language_server
-                .remove(&server_id_to_remove);
-            self.language_servers.remove(&server_id_to_remove);
+                .remove(server_id_to_remove);
+            self.language_servers.remove(server_id_to_remove);
+            for values_per_server in self.buffer_pull_diagnostics_result_ids.values_mut() {
+                values_per_server.remove(server_id_to_remove);
+            }
             cx.emit(LspStoreEvent::LanguageServerRemoved(*server_id_to_remove));
         }
         servers_to_remove.into_keys().collect()
@@ -5757,72 +5764,68 @@ impl LspStore {
     ) -> Task<Result<Vec<LspPullDiagnostics>>> {
         let buffer = buffer_handle.read(cx);
         let buffer_id = buffer.remote_id();
-        let result_id = self.result_id(buffer_id, cx);
 
         if let Some((client, upstream_project_id)) = self.upstream_client() {
             let request_task = client.request(proto::MultiLspQuery {
-                buffer_id: buffer_id.into(),
+                buffer_id: buffer_id.to_proto(),
                 version: serialize_version(&buffer_handle.read(cx).version()),
                 project_id: upstream_project_id,
                 strategy: Some(proto::multi_lsp_query::Strategy::All(
                     proto::AllLanguageServers {},
                 )),
                 request: Some(proto::multi_lsp_query::Request::GetDocumentDiagnostics(
-                    GetDocumentDiagnostics {
-                        previous_result_id: result_id.clone(),
-                    }
-                    .to_proto(upstream_project_id, buffer_handle.read(cx)),
+                    proto::GetDocumentDiagnostics {
+                        project_id: upstream_project_id,
+                        buffer_id: buffer_id.to_proto(),
+                        version: serialize_version(&buffer_handle.read(cx).version()),
+                    },
                 )),
             });
-            let buffer = buffer_handle.clone();
-            cx.spawn(async move |weak_project, cx| {
-                let Some(project) = weak_project.upgrade() else {
-                    return Ok(Vec::new());
-                };
-                let responses = request_task.await?.responses;
-                let diagnostics = join_all(
-                    responses
-                        .into_iter()
-                        .filter_map(|lsp_response| match lsp_response.response? {
-                            proto::lsp_response::Response::GetDocumentDiagnosticsResponse(
-                                response,
-                            ) => Some(response),
-                            unexpected => {
-                                debug_panic!("Unexpected response: {unexpected:?}");
-                                None
-                            }
-                        })
-                        .map(|diagnostics_response| {
-                            GetDocumentDiagnostics {
-                                previous_result_id: result_id.clone(),
-                            }
-                            .response_from_proto(
-                                diagnostics_response,
-                                project.clone(),
-                                buffer.clone(),
-                                cx.clone(),
-                            )
-                        }),
-                )
-                .await;
-
-                Ok(diagnostics
-                    .into_iter()
-                    .collect::<Result<Vec<_>>>()?
+            cx.background_spawn(async move {
+                Ok(request_task
+                    .await?
+                    .responses
                     .into_iter()
-                    .flatten()
+                    .filter_map(|lsp_response| match lsp_response.response? {
+                        proto::lsp_response::Response::GetDocumentDiagnosticsResponse(response) => {
+                            Some(response)
+                        }
+                        unexpected => {
+                            debug_panic!("Unexpected response: {unexpected:?}");
+                            None
+                        }
+                    })
+                    .flat_map(GetDocumentDiagnostics::diagnostics_from_proto)
                     .collect())
             })
         } else {
-            let all_actions_task = self.request_multiple_lsp_locally(
-                &buffer_handle,
-                None::<PointUtf16>,
-                GetDocumentDiagnostics {
-                    previous_result_id: result_id,
-                },
-                cx,
-            );
-            cx.spawn(async move |_, _| Ok(all_actions_task.await.into_iter().flatten().collect()))
+            let server_ids = buffer_handle.update(cx, |buffer, cx| {
+                self.language_servers_for_local_buffer(buffer, cx)
+                    .map(|(_, server)| server.server_id())
+                    .collect::<Vec<_>>()
+            });
+            let pull_diagnostics = server_ids
+                .into_iter()
+                .map(|server_id| {
+                    let result_id = self.result_id(server_id, buffer_id, cx);
+                    self.request_lsp(
+                        buffer_handle.clone(),
+                        LanguageServerToQuery::Other(server_id),
+                        GetDocumentDiagnostics {
+                            previous_result_id: result_id,
+                        },
+                        cx,
+                    )
+                })
+                .collect::<Vec<_>>();
+
+            cx.background_spawn(async move {
+                let mut responses = Vec::new();
+                for diagnostics in join_all(pull_diagnostics).await {
+                    responses.extend(diagnostics?);
+                }
+                Ok(responses)
+            })
         }
     }
 
@@ -7055,11 +7058,11 @@ impl LspStore {
     }
 
     async fn handle_multi_lsp_query(
-        this: Entity<Self>,
+        lsp_store: Entity<Self>,
         envelope: TypedEnvelope<proto::MultiLspQuery>,
         mut cx: AsyncApp,
     ) -> Result<proto::MultiLspQueryResponse> {
-        let response_from_ssh = this.read_with(&mut cx, |this, _| {
+        let response_from_ssh = lsp_store.read_with(&mut cx, |this, _| {
             let (upstream_client, project_id) = this.upstream_client()?;
             let mut payload = envelope.payload.clone();
             payload.project_id = project_id;
@@ -7073,7 +7076,7 @@ impl LspStore {
         let sender_id = envelope.original_sender_id().unwrap_or_default();
         let buffer_id = BufferId::new(envelope.payload.buffer_id)?;
         let version = deserialize_version(&envelope.payload.version);
-        let buffer = this.update(&mut cx, |this, cx| {
+        let buffer = lsp_store.update(&mut cx, |this, cx| {
             this.buffer_store.read(cx).get_existing(buffer_id)
         })??;
         buffer
@@ -7095,9 +7098,9 @@ impl LspStore {
         match envelope.payload.request {
             Some(proto::multi_lsp_query::Request::GetHover(get_hover)) => {
                 let get_hover =
-                    GetHover::from_proto(get_hover, this.clone(), buffer.clone(), cx.clone())
+                    GetHover::from_proto(get_hover, lsp_store.clone(), buffer.clone(), cx.clone())
                         .await?;
-                let all_hovers = this
+                let all_hovers = lsp_store
                     .update(&mut cx, |this, cx| {
                         this.request_multiple_lsp_locally(
                             &buffer,
@@ -7109,7 +7112,7 @@ impl LspStore {
                     .await
                     .into_iter()
                     .filter_map(|hover| remove_empty_hover_blocks(hover?));
-                this.update(&mut cx, |project, cx| proto::MultiLspQueryResponse {
+                lsp_store.update(&mut cx, |project, cx| proto::MultiLspQueryResponse {
                     responses: all_hovers
                         .map(|hover| proto::LspResponse {
                             response: Some(proto::lsp_response::Response::GetHoverResponse(
@@ -7128,13 +7131,13 @@ impl LspStore {
             Some(proto::multi_lsp_query::Request::GetCodeActions(get_code_actions)) => {
                 let get_code_actions = GetCodeActions::from_proto(
                     get_code_actions,
-                    this.clone(),
+                    lsp_store.clone(),
                     buffer.clone(),
                     cx.clone(),
                 )
                 .await?;
 
-                let all_actions = this
+                let all_actions = lsp_store
                     .update(&mut cx, |project, cx| {
                         project.request_multiple_lsp_locally(
                             &buffer,
@@ -7146,7 +7149,7 @@ impl LspStore {
                     .await
                     .into_iter();
 
-                this.update(&mut cx, |project, cx| proto::MultiLspQueryResponse {
+                lsp_store.update(&mut cx, |project, cx| proto::MultiLspQueryResponse {
                     responses: all_actions
                         .map(|code_actions| proto::LspResponse {
                             response: Some(proto::lsp_response::Response::GetCodeActionsResponse(
@@ -7165,13 +7168,13 @@ impl LspStore {
             Some(proto::multi_lsp_query::Request::GetSignatureHelp(get_signature_help)) => {
                 let get_signature_help = GetSignatureHelp::from_proto(
                     get_signature_help,
-                    this.clone(),
+                    lsp_store.clone(),
                     buffer.clone(),
                     cx.clone(),
                 )
                 .await?;
 
-                let all_signatures = this
+                let all_signatures = lsp_store
                     .update(&mut cx, |project, cx| {
                         project.request_multiple_lsp_locally(
                             &buffer,
@@ -7183,7 +7186,7 @@ impl LspStore {
                     .await
                     .into_iter();
 
-                this.update(&mut cx, |project, cx| proto::MultiLspQueryResponse {
+                lsp_store.update(&mut cx, |project, cx| proto::MultiLspQueryResponse {
                     responses: all_signatures
                         .map(|signature_help| proto::LspResponse {
                             response: Some(
@@ -7204,13 +7207,13 @@ impl LspStore {
             Some(proto::multi_lsp_query::Request::GetCodeLens(get_code_lens)) => {
                 let get_code_lens = GetCodeLens::from_proto(
                     get_code_lens,
-                    this.clone(),
+                    lsp_store.clone(),
                     buffer.clone(),
                     cx.clone(),
                 )
                 .await?;
 
-                let code_lens_actions = this
+                let code_lens_actions = lsp_store
                     .update(&mut cx, |project, cx| {
                         project.request_multiple_lsp_locally(
                             &buffer,
@@ -7222,7 +7225,7 @@ impl LspStore {
                     .await
                     .into_iter();
 
-                this.update(&mut cx, |project, cx| proto::MultiLspQueryResponse {
+                lsp_store.update(&mut cx, |project, cx| proto::MultiLspQueryResponse {
                     responses: code_lens_actions
                         .map(|actions| proto::LspResponse {
                             response: Some(proto::lsp_response::Response::GetCodeLensResponse(
@@ -7238,31 +7241,46 @@ impl LspStore {
                         .collect(),
                 })
             }
-            Some(proto::multi_lsp_query::Request::GetDocumentDiagnostics(
-                get_document_diagnostics,
-            )) => {
-                let get_document_diagnostics = GetDocumentDiagnostics::from_proto(
-                    get_document_diagnostics,
-                    this.clone(),
-                    buffer.clone(),
-                    cx.clone(),
-                )
-                .await?;
-
-                let all_diagnostics = this
-                    .update(&mut cx, |project, cx| {
-                        project.request_multiple_lsp_locally(
-                            &buffer,
-                            None::<PointUtf16>,
-                            get_document_diagnostics,
-                            cx,
-                        )
+            Some(proto::multi_lsp_query::Request::GetDocumentDiagnostics(message)) => {
+                buffer
+                    .update(&mut cx, |buffer, _| {
+                        buffer.wait_for_version(deserialize_version(&message.version))
                     })?
-                    .await
-                    .into_iter();
+                    .await?;
+                let pull_diagnostics = lsp_store.update(&mut cx, |lsp_store, cx| {
+                    let server_ids = buffer.update(cx, |buffer, cx| {
+                        lsp_store
+                            .language_servers_for_local_buffer(buffer, cx)
+                            .map(|(_, server)| server.server_id())
+                            .collect::<Vec<_>>()
+                    });
+
+                    server_ids
+                        .into_iter()
+                        .map(|server_id| {
+                            let result_id = lsp_store.result_id(server_id, buffer_id, cx);
+                            lsp_store.request_lsp(
+                                buffer.clone(),
+                                LanguageServerToQuery::Other(server_id),
+                                GetDocumentDiagnostics {
+                                    previous_result_id: result_id,
+                                },
+                                cx,
+                            )
+                        })
+                        .collect::<Vec<_>>()
+                })?;
+
+                let all_diagnostics_responses = join_all(pull_diagnostics).await;
+                let mut all_diagnostics = Vec::new();
+                for response in all_diagnostics_responses {
+                    let response = response?;
+                    all_diagnostics.push(response);
+                }
 
-                this.update(&mut cx, |project, cx| proto::MultiLspQueryResponse {
+                lsp_store.update(&mut cx, |project, cx| proto::MultiLspQueryResponse {
                     responses: all_diagnostics
+                        .into_iter()
                         .map(|lsp_diagnostic| proto::LspResponse {
                             response: Some(
                                 proto::lsp_response::Response::GetDocumentDiagnosticsResponse(
@@ -8828,6 +8846,7 @@ impl LspStore {
         local.language_server_watched_paths.remove(&server_id);
         let server_state = local.language_servers.remove(&server_id);
         cx.notify();
+        self.remove_result_ids(server_id);
         cx.emit(LspStoreEvent::LanguageServerRemoved(server_id));
         cx.spawn(async move |_, cx| {
             Self::shutdown_language_server(server_state, name, cx).await;
@@ -9716,7 +9735,20 @@ impl LspStore {
         }
     }
 
-    pub fn result_id(&self, buffer_id: BufferId, cx: &App) -> Option<String> {
+    fn remove_result_ids(&mut self, for_server: LanguageServerId) {
+        if let Some(local) = self.as_local_mut() {
+            for values_per_server in local.buffer_pull_diagnostics_result_ids.values_mut() {
+                values_per_server.remove(&for_server);
+            }
+        }
+    }
+
+    pub fn result_id(
+        &self,
+        server_id: LanguageServerId,
+        buffer_id: BufferId,
+        cx: &App,
+    ) -> Option<String> {
         let abs_path = self
             .buffer_store
             .read(cx)
@@ -9725,19 +9757,21 @@ impl LspStore {
             .map(|f| f.abs_path(cx))?;
         self.as_local()?
             .buffer_pull_diagnostics_result_ids
-            .get(&abs_path)
-            .cloned()
-            .flatten()
+            .get(&abs_path)?
+            .get(&server_id)?
+            .clone()
     }
 
-    pub fn all_result_ids(&self) -> HashMap<PathBuf, String> {
+    pub fn all_result_ids(&self, server_id: LanguageServerId) -> HashMap<PathBuf, String> {
         let Some(local) = self.as_local() else {
             return HashMap::default();
         };
         local
             .buffer_pull_diagnostics_result_ids
             .iter()
-            .filter_map(|(file_path, result_id)| Some((file_path.clone(), result_id.clone()?)))
+            .filter_map(|(file_path, result_ids)| {
+                Some((file_path.clone(), result_ids.get(&server_id)?.clone()?))
+            })
             .collect()
     }
 
@@ -9822,7 +9856,7 @@ fn lsp_workspace_diagnostics_refresh(
 
                 let Ok(previous_result_ids) = lsp_store.update(cx, |lsp_store, _| {
                     lsp_store
-                        .all_result_ids()
+                        .all_result_ids(server.server_id())
                         .into_iter()
                         .filter_map(|(abs_path, result_id)| {
                             let uri = file_path_to_lsp_url(&abs_path).ok()?;