Add `label_for_symbol` to extension API (#10179)

Marshall Bowers and Max created

This PR adds `label_for_symbol` to the extension API.

As a motivating example, we implemented `label_for_symbol` for the
Haskell extension.

Release Notes:

- N/A

Co-authored-by: Max <max@zed.dev>

Change summary

Cargo.lock                                          |   2 
crates/extension/src/extension_lsp_adapter.rs       | 120 +++++++++++---
crates/extension/src/wasm_host/wit.rs               |  22 ++
crates/extension_api/src/extension_api.rs           |  31 +++
crates/extension_api/wit/since_v0.0.6/extension.wit |   3 
crates/extension_api/wit/since_v0.0.6/lsp.wit       |  35 ++++
crates/language/src/language.rs                     |  13 +
crates/project/src/project.rs                       |   4 
extensions/haskell/Cargo.toml                       |   3 
extensions/haskell/src/haskell.rs                   |  42 +++++
10 files changed, 235 insertions(+), 40 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -12540,7 +12540,7 @@ dependencies = [
 name = "zed_haskell"
 version = "0.0.1"
 dependencies = [
- "zed_extension_api 0.0.4",
+ "zed_extension_api 0.0.6",
 ]
 
 [[package]]

crates/extension/src/extension_lsp_adapter.rs 🔗

@@ -11,6 +11,7 @@ use language::{
     CodeLabel, HighlightId, Language, LanguageServerName, LspAdapter, LspAdapterDelegate,
 };
 use lsp::LanguageServerBinary;
+use serde::Serialize;
 use std::ops::Range;
 use std::{
     any::Any,
@@ -210,21 +211,61 @@ impl LspAdapter for ExtensionLspAdapter {
             })
             .await?;
 
-        Ok(labels
+        Ok(labels_from_wit(labels, language))
+    }
+
+    async fn labels_for_symbols(
+        self: Arc<Self>,
+        symbols: &[(String, lsp::SymbolKind)],
+        language: &Arc<Language>,
+    ) -> Result<Vec<Option<CodeLabel>>> {
+        let symbols = symbols
             .into_iter()
-            .map(|label| {
-                label.map(|label| {
-                    build_code_label(
-                        &label,
-                        &language.highlight_text(&label.code.as_str().into(), 0..label.code.len()),
-                        &language,
-                    )
-                })
+            .cloned()
+            .map(|(name, kind)| wit::Symbol {
+                name,
+                kind: kind.into(),
             })
-            .collect())
+            .collect::<Vec<_>>();
+
+        let labels = self
+            .extension
+            .call({
+                let this = self.clone();
+                |extension, store| {
+                    async move {
+                        extension
+                            .call_labels_for_symbols(store, &this.language_server_id, symbols)
+                            .await?
+                            .map_err(|e| anyhow!("{}", e))
+                    }
+                    .boxed()
+                }
+            })
+            .await?;
+
+        Ok(labels_from_wit(labels, language))
     }
 }
 
+fn labels_from_wit(
+    labels: Vec<Option<wit::CodeLabel>>,
+    language: &Arc<Language>,
+) -> Vec<Option<CodeLabel>> {
+    labels
+        .into_iter()
+        .map(|label| {
+            label.map(|label| {
+                build_code_label(
+                    &label,
+                    &language.highlight_text(&label.code.as_str().into(), 0..label.code.len()),
+                    &language,
+                )
+            })
+        })
+        .collect()
+}
+
 fn build_code_label(
     label: &wit::CodeLabel,
     parsed_runs: &[(Range<usize>, HighlightId)],
@@ -332,14 +373,7 @@ impl From<lsp::CompletionItemKind> for wit::CompletionKind {
             lsp::CompletionItemKind::EVENT => Self::Event,
             lsp::CompletionItemKind::OPERATOR => Self::Operator,
             lsp::CompletionItemKind::TYPE_PARAMETER => Self::TypeParameter,
-            _ => {
-                let value = maybe!({
-                    let kind = serde_json::to_value(&value)?;
-                    serde_json::from_value(kind)
-                });
-
-                Self::Other(value.log_err().unwrap_or(-1))
-            }
+            _ => Self::Other(extract_int(value)),
         }
     }
 }
@@ -349,18 +383,54 @@ impl From<lsp::InsertTextFormat> for wit::InsertTextFormat {
         match value {
             lsp::InsertTextFormat::PLAIN_TEXT => Self::PlainText,
             lsp::InsertTextFormat::SNIPPET => Self::Snippet,
-            _ => {
-                let value = maybe!({
-                    let kind = serde_json::to_value(&value)?;
-                    serde_json::from_value(kind)
-                });
+            _ => Self::Other(extract_int(value)),
+        }
+    }
+}
 
-                Self::Other(value.log_err().unwrap_or(-1))
-            }
+impl From<lsp::SymbolKind> for wit::SymbolKind {
+    fn from(value: lsp::SymbolKind) -> Self {
+        match value {
+            lsp::SymbolKind::FILE => Self::File,
+            lsp::SymbolKind::MODULE => Self::Module,
+            lsp::SymbolKind::NAMESPACE => Self::Namespace,
+            lsp::SymbolKind::PACKAGE => Self::Package,
+            lsp::SymbolKind::CLASS => Self::Class,
+            lsp::SymbolKind::METHOD => Self::Method,
+            lsp::SymbolKind::PROPERTY => Self::Property,
+            lsp::SymbolKind::FIELD => Self::Field,
+            lsp::SymbolKind::CONSTRUCTOR => Self::Constructor,
+            lsp::SymbolKind::ENUM => Self::Enum,
+            lsp::SymbolKind::INTERFACE => Self::Interface,
+            lsp::SymbolKind::FUNCTION => Self::Function,
+            lsp::SymbolKind::VARIABLE => Self::Variable,
+            lsp::SymbolKind::CONSTANT => Self::Constant,
+            lsp::SymbolKind::STRING => Self::String,
+            lsp::SymbolKind::NUMBER => Self::Number,
+            lsp::SymbolKind::BOOLEAN => Self::Boolean,
+            lsp::SymbolKind::ARRAY => Self::Array,
+            lsp::SymbolKind::OBJECT => Self::Object,
+            lsp::SymbolKind::KEY => Self::Key,
+            lsp::SymbolKind::NULL => Self::Null,
+            lsp::SymbolKind::ENUM_MEMBER => Self::EnumMember,
+            lsp::SymbolKind::STRUCT => Self::Struct,
+            lsp::SymbolKind::EVENT => Self::Event,
+            lsp::SymbolKind::OPERATOR => Self::Operator,
+            lsp::SymbolKind::TYPE_PARAMETER => Self::TypeParameter,
+            _ => Self::Other(extract_int(value)),
         }
     }
 }
 
+fn extract_int<T: Serialize>(value: T) -> i32 {
+    maybe!({
+        let kind = serde_json::to_value(&value)?;
+        serde_json::from_value(kind)
+    })
+    .log_err()
+    .unwrap_or(-1)
+}
+
 #[test]
 fn test_build_code_label() {
     use util::test::marked_text_ranges;

crates/extension/src/wasm_host/wit.rs 🔗

@@ -5,7 +5,6 @@ mod since_v0_0_6;
 use std::ops::RangeInclusive;
 use std::sync::Arc;
 
-use anyhow::bail;
 use anyhow::{Context, Result};
 use language::{LanguageServerName, LspAdapterDelegate};
 use semantic_version::SemanticVersion;
@@ -19,7 +18,7 @@ use super::{wasm_engine, WasmState};
 use since_v0_0_6 as latest;
 
 pub use latest::{
-    zed::extension::lsp::{Completion, CompletionKind, InsertTextFormat},
+    zed::extension::lsp::{Completion, CompletionKind, InsertTextFormat, Symbol, SymbolKind},
     CodeLabel, CodeLabelSpan, Command, Range,
 };
 pub use since_v0_0_4::LanguageServerConfig;
@@ -156,15 +155,28 @@ impl Extension {
         completions: Vec<latest::Completion>,
     ) -> Result<Result<Vec<Option<CodeLabel>>, String>> {
         match self {
-            Extension::V001(_) | Extension::V004(_) => {
-                bail!("unsupported function: 'labels_for_completions'")
-            }
+            Extension::V001(_) | Extension::V004(_) => Ok(Ok(Vec::new())),
             Extension::V006(ext) => {
                 ext.call_labels_for_completions(store, &language_server_id.0, &completions)
                     .await
             }
         }
     }
+
+    pub async fn call_labels_for_symbols(
+        &self,
+        store: &mut Store<WasmState>,
+        language_server_id: &LanguageServerName,
+        symbols: Vec<latest::Symbol>,
+    ) -> Result<Result<Vec<Option<CodeLabel>>, String>> {
+        match self {
+            Extension::V001(_) | Extension::V004(_) => Ok(Ok(Vec::new())),
+            Extension::V006(ext) => {
+                ext.call_labels_for_symbols(store, &language_server_id.0, &symbols)
+                    .await
+            }
+        }
+    }
 }
 
 trait ToWasmtimeResult<T> {

crates/extension_api/src/extension_api.rs 🔗

@@ -64,6 +64,15 @@ pub trait Extension: Send + Sync {
     ) -> Option<CodeLabel> {
         None
     }
+
+    /// Returns the label for the given symbol.
+    fn label_for_symbol(
+        &self,
+        _language_server_id: &LanguageServerId,
+        _symbol: Symbol,
+    ) -> Option<CodeLabel> {
+        None
+    }
 }
 
 #[macro_export]
@@ -138,11 +147,33 @@ impl wit::Guest for Component {
         }
         Ok(labels)
     }
+
+    fn labels_for_symbols(
+        language_server_id: String,
+        symbols: Vec<Symbol>,
+    ) -> Result<Vec<Option<CodeLabel>>, String> {
+        let language_server_id = LanguageServerId(language_server_id);
+        let mut labels = Vec::new();
+        for (ix, symbol) in symbols.into_iter().enumerate() {
+            let label = extension().label_for_symbol(&language_server_id, symbol);
+            if let Some(label) = label {
+                labels.resize(ix + 1, None);
+                *labels.last_mut().unwrap() = Some(label);
+            }
+        }
+        Ok(labels)
+    }
 }
 
 #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
 pub struct LanguageServerId(String);
 
+impl AsRef<str> for LanguageServerId {
+    fn as_ref(&self) -> &str {
+        &self.0
+    }
+}
+
 impl fmt::Display for LanguageServerId {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         write!(f, "{}", self.0)

crates/extension_api/wit/since_v0.0.6/extension.wit 🔗

@@ -3,7 +3,7 @@ package zed:extension;
 world extension {
     import lsp;
 
-    use lsp.{completion};
+    use lsp.{completion, symbol};
 
     export init-extension: func();
 
@@ -117,4 +117,5 @@ world extension {
     }
 
     export labels-for-completions: func(language-server-id: string, completions: list<completion>) -> result<list<option<code-label>>, string>;
+    export labels-for-symbols: func(language-server-id: string, symbols: list<symbol>) -> result<list<option<code-label>>, string>;
 }

crates/extension_api/wit/since_v0.0.6/lsp.wit 🔗

@@ -41,4 +41,39 @@ interface lsp {
         snippet,
         other(s32),
     }
+
+    record symbol {
+        kind: symbol-kind,
+        name: string,
+    }
+
+    variant symbol-kind {
+        file,
+        module,
+        namespace,
+        %package,
+        class,
+        method,
+        property,
+        field,
+        %constructor,
+        %enum,
+        %interface,
+        function,
+        variable,
+        constant,
+        %string,
+        number,
+        boolean,
+        array,
+        object,
+        key,
+        null,
+        enum-member,
+        struct,
+        event,
+        operator,
+        type-parameter,
+        other(s32),
+    }
 }

crates/language/src/language.rs 🔗

@@ -224,8 +224,11 @@ impl CachedLspAdapter {
         &self,
         symbols: &[(String, lsp::SymbolKind)],
         language: &Arc<Language>,
-    ) -> Vec<Option<CodeLabel>> {
-        self.adapter.labels_for_symbols(symbols, language).await
+    ) -> Result<Vec<Option<CodeLabel>>> {
+        self.adapter
+            .clone()
+            .labels_for_symbols(symbols, language)
+            .await
     }
 
     #[cfg(any(test, feature = "test-support"))]
@@ -410,10 +413,10 @@ pub trait LspAdapter: 'static + Send + Sync {
     }
 
     async fn labels_for_symbols(
-        &self,
+        self: Arc<Self>,
         symbols: &[(String, lsp::SymbolKind)],
         language: &Arc<Language>,
-    ) -> Vec<Option<CodeLabel>> {
+    ) -> Result<Vec<Option<CodeLabel>>> {
         let mut labels = Vec::new();
         for (ix, (name, kind)) in symbols.into_iter().enumerate() {
             let label = self.label_for_symbol(name, *kind, language).await;
@@ -422,7 +425,7 @@ pub trait LspAdapter: 'static + Send + Sync {
                 *labels.last_mut().unwrap() = Some(label);
             }
         }
-        labels
+        Ok(labels)
     }
 
     async fn label_for_symbol(

crates/project/src/project.rs 🔗

@@ -9876,7 +9876,9 @@ async fn populate_labels_for_symbols(
             if let Some(lsp_adapter) = lsp_adapter {
                 labels = lsp_adapter
                     .labels_for_symbols(&label_params, &language)
-                    .await;
+                    .await
+                    .log_err()
+                    .unwrap_or_default();
             }
         }
 

extensions/haskell/Cargo.toml 🔗

@@ -13,4 +13,5 @@ path = "src/haskell.rs"
 crate-type = ["cdylib"]
 
 [dependencies]
-zed_extension_api = "0.0.4"
+# zed_extension_api = "0.0.4"
+zed_extension_api = { path = "../../crates/extension_api" }

extensions/haskell/src/haskell.rs 🔗

@@ -1,3 +1,5 @@
+use zed::lsp::{Symbol, SymbolKind};
+use zed::{CodeLabel, CodeLabelSpan};
 use zed_extension_api::{self as zed, Result};
 
 struct HaskellExtension;
@@ -9,7 +11,7 @@ impl zed::Extension for HaskellExtension {
 
     fn language_server_command(
         &mut self,
-        _config: zed::LanguageServerConfig,
+        _language_server_id: &zed::LanguageServerId,
         worktree: &zed::Worktree,
     ) -> Result<zed::Command> {
         let path = worktree
@@ -22,6 +24,44 @@ impl zed::Extension for HaskellExtension {
             env: Default::default(),
         })
     }
+
+    fn label_for_symbol(
+        &self,
+        _language_server_id: &zed::LanguageServerId,
+        symbol: Symbol,
+    ) -> Option<CodeLabel> {
+        let name = &symbol.name;
+
+        let (code, display_range, filter_range) = match symbol.kind {
+            SymbolKind::Struct => {
+                let data_decl = "data ";
+                let code = format!("{data_decl}{name} = A");
+                let display_range = 0..data_decl.len() + name.len();
+                let filter_range = data_decl.len()..display_range.end;
+                (code, display_range, filter_range)
+            }
+            SymbolKind::Constructor => {
+                let data_decl = "data A = ";
+                let code = format!("{data_decl}{name}");
+                let display_range = data_decl.len()..data_decl.len() + name.len();
+                let filter_range = 0..name.len();
+                (code, display_range, filter_range)
+            }
+            SymbolKind::Variable => {
+                let code = format!("{name} :: T");
+                let display_range = 0..name.len();
+                let filter_range = 0..name.len();
+                (code, display_range, filter_range)
+            }
+            _ => return None,
+        };
+
+        Some(CodeLabel {
+            spans: vec![CodeLabelSpan::code_range(display_range)],
+            filter_range: filter_range.into(),
+            code,
+        })
+    }
 }
 
 zed::register_extension!(HaskellExtension);