Respect server capabilities on queries (#33538)

Kirill Bulatov created

Closes https://github.com/zed-industries/zed/issues/33522

Turns out a bunch of Zed requests were not checking their capabilities
correctly, due to odd copy-paste and due to default that assumed that
the capabilities are met.

Adjust the code, which includes the document colors, add the test on the
colors case.

Release Notes:

- Fixed excessive document colors requests for unrelated files

Change summary

crates/editor/src/editor_tests.rs                |  78 ++++++++--
crates/project/src/lsp_command.rs                |  68 ++++++++-
crates/project/src/lsp_store.rs                  | 121 +++++++++++------
crates/project/src/lsp_store/lsp_ext_command.rs  |  22 +++
crates/project/src/project.rs                    |  31 ++++
crates/remote_server/src/remote_editing_tests.rs |  13 +
6 files changed, 259 insertions(+), 74 deletions(-)

Detailed changes

crates/editor/src/editor_tests.rs 🔗

@@ -22631,6 +22631,18 @@ async fn test_mtime_and_document_colors(cx: &mut TestAppContext) {
                 color_provider: Some(lsp::ColorProviderCapability::Simple(true)),
                 ..lsp::ServerCapabilities::default()
             },
+            name: "rust-analyzer",
+            ..FakeLspAdapter::default()
+        },
+    );
+    let mut fake_servers_without_capabilities = language_registry.register_fake_lsp(
+        "Rust",
+        FakeLspAdapter {
+            capabilities: lsp::ServerCapabilities {
+                color_provider: Some(lsp::ColorProviderCapability::Simple(false)),
+                ..lsp::ServerCapabilities::default()
+            },
+            name: "not-rust-analyzer",
             ..FakeLspAdapter::default()
         },
     );
@@ -22650,6 +22662,8 @@ async fn test_mtime_and_document_colors(cx: &mut TestAppContext) {
         .downcast::<Editor>()
         .unwrap();
     let fake_language_server = fake_servers.next().await.unwrap();
+    let fake_language_server_without_capabilities =
+        fake_servers_without_capabilities.next().await.unwrap();
     let requests_made = Arc::new(AtomicUsize::new(0));
     let closure_requests_made = Arc::clone(&requests_made);
     let mut color_request_handle = fake_language_server
@@ -22661,34 +22675,59 @@ async fn test_mtime_and_document_colors(cx: &mut TestAppContext) {
                     lsp::Url::from_file_path(path!("/a/first.rs")).unwrap()
                 );
                 requests_made.fetch_add(1, atomic::Ordering::Release);
-                Ok(vec![lsp::ColorInformation {
-                    range: lsp::Range {
-                        start: lsp::Position {
-                            line: 0,
-                            character: 0,
+                Ok(vec![
+                    lsp::ColorInformation {
+                        range: lsp::Range {
+                            start: lsp::Position {
+                                line: 0,
+                                character: 0,
+                            },
+                            end: lsp::Position {
+                                line: 0,
+                                character: 1,
+                            },
                         },
-                        end: lsp::Position {
-                            line: 0,
-                            character: 1,
+                        color: lsp::Color {
+                            red: 0.33,
+                            green: 0.33,
+                            blue: 0.33,
+                            alpha: 0.33,
                         },
                     },
-                    color: lsp::Color {
-                        red: 0.33,
-                        green: 0.33,
-                        blue: 0.33,
-                        alpha: 0.33,
+                    lsp::ColorInformation {
+                        range: lsp::Range {
+                            start: lsp::Position {
+                                line: 0,
+                                character: 0,
+                            },
+                            end: lsp::Position {
+                                line: 0,
+                                character: 1,
+                            },
+                        },
+                        color: lsp::Color {
+                            red: 0.33,
+                            green: 0.33,
+                            blue: 0.33,
+                            alpha: 0.33,
+                        },
                     },
-                }])
+                ])
             }
         });
+
+    let _handle = fake_language_server_without_capabilities
+        .set_request_handler::<lsp::request::DocumentColor, _, _>(move |_, _| async move {
+            panic!("Should not be called");
+        });
     color_request_handle.next().await.unwrap();
     cx.run_until_parked();
     color_request_handle.next().await.unwrap();
     cx.run_until_parked();
     assert_eq!(
-        2,
+        3,
         requests_made.load(atomic::Ordering::Acquire),
-        "Should query for colors once per editor open and once after the language server startup"
+        "Should query for colors once per editor open (1) and once after the language server startup (2)"
     );
 
     cx.executor().advance_clock(Duration::from_millis(500));
@@ -22718,7 +22757,7 @@ async fn test_mtime_and_document_colors(cx: &mut TestAppContext) {
     color_request_handle.next().await.unwrap();
     cx.run_until_parked();
     assert_eq!(
-        4,
+        5,
         requests_made.load(atomic::Ordering::Acquire),
         "Should query for colors once per save and once per formatting after save"
     );
@@ -22733,7 +22772,7 @@ async fn test_mtime_and_document_colors(cx: &mut TestAppContext) {
         .unwrap();
     close.await.unwrap();
     assert_eq!(
-        4,
+        5,
         requests_made.load(atomic::Ordering::Acquire),
         "After saving and closing the editor, no extra requests should be made"
     );
@@ -22745,10 +22784,11 @@ async fn test_mtime_and_document_colors(cx: &mut TestAppContext) {
             })
         })
         .unwrap();
+    cx.executor().advance_clock(Duration::from_millis(100));
     color_request_handle.next().await.unwrap();
     cx.run_until_parked();
     assert_eq!(
-        5,
+        6,
         requests_made.load(atomic::Ordering::Acquire),
         "After navigating back to an editor and reopening it, another color request should be made"
     );

crates/project/src/lsp_command.rs 🔗

@@ -107,9 +107,7 @@ pub trait LspCommand: 'static + Sized + Send + std::fmt::Debug {
     }
 
     /// When false, `to_lsp_params_or_response` default implementation will return the default response.
-    fn check_capabilities(&self, _: AdapterServerCapabilities) -> bool {
-        true
-    }
+    fn check_capabilities(&self, _: AdapterServerCapabilities) -> bool;
 
     fn to_lsp(
         &self,
@@ -277,6 +275,16 @@ impl LspCommand for PrepareRename {
         "Prepare rename"
     }
 
+    fn check_capabilities(&self, capabilities: AdapterServerCapabilities) -> bool {
+        capabilities
+            .server_capabilities
+            .rename_provider
+            .is_some_and(|capability| match capability {
+                OneOf::Left(enabled) => enabled,
+                OneOf::Right(options) => options.prepare_provider.unwrap_or(false),
+            })
+    }
+
     fn to_lsp_params_or_response(
         &self,
         path: &Path,
@@ -459,6 +467,16 @@ impl LspCommand for PerformRename {
         "Rename"
     }
 
+    fn check_capabilities(&self, capabilities: AdapterServerCapabilities) -> bool {
+        capabilities
+            .server_capabilities
+            .rename_provider
+            .is_some_and(|capability| match capability {
+                OneOf::Left(enabled) => enabled,
+                OneOf::Right(_options) => true,
+            })
+    }
+
     fn to_lsp(
         &self,
         path: &Path,
@@ -583,7 +601,10 @@ impl LspCommand for GetDefinition {
         capabilities
             .server_capabilities
             .definition_provider
-            .is_some()
+            .is_some_and(|capability| match capability {
+                OneOf::Left(supported) => supported,
+                OneOf::Right(_options) => true,
+            })
     }
 
     fn to_lsp(
@@ -682,7 +703,11 @@ impl LspCommand for GetDeclaration {
         capabilities
             .server_capabilities
             .declaration_provider
-            .is_some()
+            .is_some_and(|capability| match capability {
+                lsp::DeclarationCapability::Simple(supported) => supported,
+                lsp::DeclarationCapability::RegistrationOptions(..) => true,
+                lsp::DeclarationCapability::Options(..) => true,
+            })
     }
 
     fn to_lsp(
@@ -777,6 +802,16 @@ impl LspCommand for GetImplementation {
         "Get implementation"
     }
 
+    fn check_capabilities(&self, capabilities: AdapterServerCapabilities) -> bool {
+        capabilities
+            .server_capabilities
+            .implementation_provider
+            .is_some_and(|capability| match capability {
+                lsp::ImplementationProviderCapability::Simple(enabled) => enabled,
+                lsp::ImplementationProviderCapability::Options(_options) => true,
+            })
+    }
+
     fn to_lsp(
         &self,
         path: &Path,
@@ -1437,7 +1472,10 @@ impl LspCommand for GetDocumentHighlights {
         capabilities
             .server_capabilities
             .document_highlight_provider
-            .is_some()
+            .is_some_and(|capability| match capability {
+                OneOf::Left(supported) => supported,
+                OneOf::Right(_options) => true,
+            })
     }
 
     fn to_lsp(
@@ -1590,7 +1628,10 @@ impl LspCommand for GetDocumentSymbols {
         capabilities
             .server_capabilities
             .document_symbol_provider
-            .is_some()
+            .is_some_and(|capability| match capability {
+                OneOf::Left(supported) => supported,
+                OneOf::Right(_options) => true,
+            })
     }
 
     fn to_lsp(
@@ -2116,6 +2157,13 @@ impl LspCommand for GetCompletions {
         "Get completion"
     }
 
+    fn check_capabilities(&self, capabilities: AdapterServerCapabilities) -> bool {
+        capabilities
+            .server_capabilities
+            .completion_provider
+            .is_some()
+    }
+
     fn to_lsp(
         &self,
         path: &Path,
@@ -4161,7 +4209,11 @@ impl LspCommand for GetDocumentColor {
         server_capabilities
             .server_capabilities
             .color_provider
-            .is_some()
+            .is_some_and(|capability| match capability {
+                lsp::ColorProviderCapability::Simple(supported) => supported,
+                lsp::ColorProviderCapability::ColorProvider(..) => true,
+                lsp::ColorProviderCapability::Options(..) => true,
+            })
     }
 
     fn to_lsp(

crates/project/src/lsp_store.rs 🔗

@@ -3545,7 +3545,8 @@ pub struct LspStore {
     lsp_data: Option<LspData>,
 }
 
-type DocumentColorTask = Shared<Task<std::result::Result<Vec<DocumentColor>, Arc<anyhow::Error>>>>;
+type DocumentColorTask =
+    Shared<Task<std::result::Result<HashSet<DocumentColor>, Arc<anyhow::Error>>>>;
 
 #[derive(Debug)]
 struct LspData {
@@ -3557,7 +3558,7 @@ struct LspData {
 
 #[derive(Debug, Default)]
 struct BufferLspData {
-    colors: Option<Vec<DocumentColor>>,
+    colors: Option<HashSet<DocumentColor>>,
 }
 
 #[derive(Debug)]
@@ -6237,13 +6238,13 @@ impl LspStore {
             .flat_map(|lsp_data| lsp_data.buffer_lsp_data.values())
             .filter_map(|buffer_data| buffer_data.get(&abs_path))
             .filter_map(|buffer_data| {
-                let colors = buffer_data.colors.as_deref()?;
+                let colors = buffer_data.colors.as_ref()?;
                 received_colors_data = true;
                 Some(colors)
             })
             .flatten()
             .cloned()
-            .collect::<Vec<_>>();
+            .collect::<HashSet<_>>();
 
         if buffer_lsp_data.is_empty() || for_server_id.is_some() {
             if received_colors_data && for_server_id.is_none() {
@@ -6297,42 +6298,25 @@ impl LspStore {
             let task_abs_path = abs_path.clone();
             let new_task = cx
                 .spawn(async move |lsp_store, cx| {
-                    cx.background_executor().timer(Duration::from_millis(50)).await;
-                    let fetched_colors = match lsp_store
-                        .update(cx, |lsp_store, cx| {
-                            lsp_store.fetch_document_colors(buffer, cx)
-                        }) {
-                            Ok(fetch_task) => fetch_task.await
-                            .with_context(|| {
-                                format!(
-                                    "Fetching document colors for buffer with path {task_abs_path:?}"
-                                )
-                            }),
-                            Err(e) => return Err(Arc::new(e)),
-                        };
-                    let fetched_colors = match fetched_colors {
-                        Ok(fetched_colors) => fetched_colors,
-                        Err(e) => return Err(Arc::new(e)),
-                    };
-
-                    let lsp_colors = lsp_store.update(cx, |lsp_store, _| {
-                        let lsp_data = lsp_store.lsp_data.as_mut().with_context(|| format!(
-                            "Document lsp data got updated between fetch and update for path {task_abs_path:?}"
-                        ))?;
-                        let mut lsp_colors = Vec::new();
-                        anyhow::ensure!(lsp_data.mtime == buffer_mtime, "Buffer lsp data got updated between fetch and update for path {task_abs_path:?}");
-                        for (server_id, colors) in fetched_colors {
-                            let colors_lsp_data = &mut lsp_data.buffer_lsp_data.entry(server_id).or_default().entry(task_abs_path.clone()).or_default().colors;
-                            *colors_lsp_data = Some(colors.clone());
-                            lsp_colors.extend(colors);
+                    match fetch_document_colors(
+                        lsp_store.clone(),
+                        buffer,
+                        task_abs_path.clone(),
+                        cx,
+                    )
+                    .await
+                    {
+                        Ok(colors) => Ok(colors),
+                        Err(e) => {
+                            lsp_store
+                                .update(cx, |lsp_store, _| {
+                                    if let Some(lsp_data) = lsp_store.lsp_data.as_mut() {
+                                        lsp_data.colors_update.remove(&task_abs_path);
+                                    }
+                                })
+                                .ok();
+                            Err(Arc::new(e))
                         }
-                        Ok(lsp_colors)
-                    });
-
-                    match lsp_colors {
-                        Ok(Ok(lsp_colors)) => Ok(lsp_colors),
-                        Ok(Err(e)) => Err(Arc::new(e)),
-                        Err(e) => Err(Arc::new(e)),
                     }
                 })
                 .shared();
@@ -6350,11 +6334,11 @@ impl LspStore {
         }
     }
 
-    fn fetch_document_colors(
+    fn fetch_document_colors_for_buffer(
         &mut self,
         buffer: Entity<Buffer>,
         cx: &mut Context<Self>,
-    ) -> Task<anyhow::Result<Vec<(LanguageServerId, Vec<DocumentColor>)>>> {
+    ) -> Task<anyhow::Result<Vec<(LanguageServerId, HashSet<DocumentColor>)>>> {
         if let Some((client, project_id)) = self.upstream_client() {
             let request_task = client.request(proto::MultiLspQuery {
                 project_id,
@@ -6403,7 +6387,9 @@ impl LspStore {
                 .await
                 .into_iter()
                 .fold(HashMap::default(), |mut acc, (server_id, colors)| {
-                    acc.entry(server_id).or_insert_with(Vec::new).extend(colors);
+                    acc.entry(server_id)
+                        .or_insert_with(HashSet::default)
+                        .extend(colors);
                     acc
                 })
                 .into_iter()
@@ -6418,7 +6404,9 @@ impl LspStore {
                     .await
                     .into_iter()
                     .fold(HashMap::default(), |mut acc, (server_id, colors)| {
-                        acc.entry(server_id).or_insert_with(Vec::new).extend(colors);
+                        acc.entry(server_id)
+                            .or_insert_with(HashSet::default)
+                            .extend(colors);
                         acc
                     })
                     .into_iter()
@@ -10691,6 +10679,53 @@ impl LspStore {
     }
 }
 
+async fn fetch_document_colors(
+    lsp_store: WeakEntity<LspStore>,
+    buffer: Entity<Buffer>,
+    task_abs_path: PathBuf,
+    cx: &mut AsyncApp,
+) -> anyhow::Result<HashSet<DocumentColor>> {
+    cx.background_executor()
+        .timer(Duration::from_millis(50))
+        .await;
+    let Some(buffer_mtime) = buffer.update(cx, |buffer, _| buffer.saved_mtime())? else {
+        return Ok(HashSet::default());
+    };
+    let fetched_colors = lsp_store
+        .update(cx, |lsp_store, cx| {
+            lsp_store.fetch_document_colors_for_buffer(buffer, cx)
+        })?
+        .await
+        .with_context(|| {
+            format!("Fetching document colors for buffer with path {task_abs_path:?}")
+        })?;
+
+    lsp_store.update(cx, |lsp_store, _| {
+        let lsp_data = lsp_store.lsp_data.as_mut().with_context(|| {
+            format!(
+                "Document lsp data got updated between fetch and update for path {task_abs_path:?}"
+            )
+        })?;
+        let mut lsp_colors = HashSet::default();
+        anyhow::ensure!(
+            lsp_data.mtime == buffer_mtime,
+            "Buffer lsp data got updated between fetch and update for path {task_abs_path:?}"
+        );
+        for (server_id, colors) in fetched_colors {
+            let colors_lsp_data = &mut lsp_data
+                .buffer_lsp_data
+                .entry(server_id)
+                .or_default()
+                .entry(task_abs_path.clone())
+                .or_default()
+                .colors;
+            *colors_lsp_data = Some(colors.clone());
+            lsp_colors.extend(colors);
+        }
+        Ok(lsp_colors)
+    })?
+}
+
 fn subscribe_to_binary_statuses(
     languages: &Arc<LanguageRegistry>,
     cx: &mut Context<'_, LspStore>,

crates/project/src/lsp_store/lsp_ext_command.rs 🔗

@@ -16,7 +16,7 @@ use language::{
     Buffer, point_to_lsp,
     proto::{deserialize_anchor, serialize_anchor},
 };
-use lsp::{LanguageServer, LanguageServerId};
+use lsp::{AdapterServerCapabilities, LanguageServer, LanguageServerId};
 use rpc::proto::{self, PeerId};
 use serde::{Deserialize, Serialize};
 use std::{
@@ -68,6 +68,10 @@ impl LspCommand for ExpandMacro {
         "Expand macro"
     }
 
+    fn check_capabilities(&self, _: AdapterServerCapabilities) -> bool {
+        true
+    }
+
     fn to_lsp(
         &self,
         path: &Path,
@@ -196,6 +200,10 @@ impl LspCommand for OpenDocs {
         "Open docs"
     }
 
+    fn check_capabilities(&self, _: AdapterServerCapabilities) -> bool {
+        true
+    }
+
     fn to_lsp(
         &self,
         path: &Path,
@@ -326,6 +334,10 @@ impl LspCommand for SwitchSourceHeader {
         "Switch source header"
     }
 
+    fn check_capabilities(&self, _: AdapterServerCapabilities) -> bool {
+        true
+    }
+
     fn to_lsp(
         &self,
         path: &Path,
@@ -404,6 +416,10 @@ impl LspCommand for GoToParentModule {
         "Go to parent module"
     }
 
+    fn check_capabilities(&self, _: AdapterServerCapabilities) -> bool {
+        true
+    }
+
     fn to_lsp(
         &self,
         path: &Path,
@@ -578,6 +594,10 @@ impl LspCommand for GetLspRunnables {
         "LSP Runnables"
     }
 
+    fn check_capabilities(&self, _: AdapterServerCapabilities) -> bool {
+        true
+    }
+
     fn to_lsp(
         &self,
         path: &Path,

crates/project/src/project.rs 🔗

@@ -779,13 +779,42 @@ pub struct DocumentColor {
     pub color_presentations: Vec<ColorPresentation>,
 }
 
-#[derive(Clone, Debug, PartialEq)]
+impl Eq for DocumentColor {}
+
+impl std::hash::Hash for DocumentColor {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.lsp_range.hash(state);
+        self.color.red.to_bits().hash(state);
+        self.color.green.to_bits().hash(state);
+        self.color.blue.to_bits().hash(state);
+        self.color.alpha.to_bits().hash(state);
+        self.resolved.hash(state);
+        self.color_presentations.hash(state);
+    }
+}
+
+#[derive(Clone, Debug, PartialEq, Eq)]
 pub struct ColorPresentation {
     pub label: String,
     pub text_edit: Option<lsp::TextEdit>,
     pub additional_text_edits: Vec<lsp::TextEdit>,
 }
 
+impl std::hash::Hash for ColorPresentation {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.label.hash(state);
+        if let Some(ref edit) = self.text_edit {
+            edit.range.hash(state);
+            edit.new_text.hash(state);
+        }
+        self.additional_text_edits.len().hash(state);
+        for edit in &self.additional_text_edits {
+            edit.range.hash(state);
+            edit.new_text.hash(state);
+        }
+    }
+}
+
 #[derive(Clone)]
 pub enum DirectoryLister {
     Project(Entity<Project>),

crates/remote_server/src/remote_editing_tests.rs 🔗

@@ -422,7 +422,12 @@ async fn test_remote_lsp(cx: &mut TestAppContext, server_cx: &mut TestAppContext
             "Rust",
             FakeLspAdapter {
                 name: "rust-analyzer",
-                ..Default::default()
+                capabilities: lsp::ServerCapabilities {
+                    completion_provider: Some(lsp::CompletionOptions::default()),
+                    rename_provider: Some(lsp::OneOf::Left(true)),
+                    ..lsp::ServerCapabilities::default()
+                },
+                ..FakeLspAdapter::default()
             },
         )
     });
@@ -430,7 +435,11 @@ async fn test_remote_lsp(cx: &mut TestAppContext, server_cx: &mut TestAppContext
     let mut fake_lsp = server_cx.update(|cx| {
         headless.read(cx).languages.register_fake_language_server(
             LanguageServerName("rust-analyzer".into()),
-            Default::default(),
+            lsp::ServerCapabilities {
+                completion_provider: Some(lsp::CompletionOptions::default()),
+                rename_provider: Some(lsp::OneOf::Left(true)),
+                ..lsp::ServerCapabilities::default()
+            },
             None,
         )
     });