extension_lsp_adapter.rs

  1use crate::wasm_host::{
  2    wit::{self, LanguageServerConfig},
  3    WasmExtension, WasmHost,
  4};
  5use anyhow::{anyhow, Context, Result};
  6use async_trait::async_trait;
  7use collections::HashMap;
  8use extension::WorktreeDelegate;
  9use futures::{Future, FutureExt};
 10use gpui::AsyncAppContext;
 11use language::{
 12    CodeLabel, HighlightId, Language, LanguageToolchainStore, LspAdapter, LspAdapterDelegate,
 13};
 14use lsp::{CodeActionKind, LanguageServerBinary, LanguageServerBinaryOptions, LanguageServerName};
 15use serde::Serialize;
 16use serde_json::Value;
 17use std::ops::Range;
 18use std::{any::Any, path::PathBuf, pin::Pin, sync::Arc};
 19use util::{maybe, ResultExt};
 20use wasmtime_wasi::WasiView as _;
 21
 22/// An adapter that allows an [`LspAdapterDelegate`] to be used as a [`WorktreeDelegate`].
 23pub struct WorktreeDelegateAdapter(pub Arc<dyn LspAdapterDelegate>);
 24
 25#[async_trait]
 26impl WorktreeDelegate for WorktreeDelegateAdapter {
 27    fn id(&self) -> u64 {
 28        self.0.worktree_id().to_proto()
 29    }
 30
 31    fn root_path(&self) -> String {
 32        self.0.worktree_root_path().to_string_lossy().to_string()
 33    }
 34
 35    async fn read_text_file(&self, path: PathBuf) -> Result<String> {
 36        self.0.read_text_file(path).await
 37    }
 38
 39    async fn which(&self, binary_name: String) -> Option<String> {
 40        self.0
 41            .which(binary_name.as_ref())
 42            .await
 43            .map(|path| path.to_string_lossy().to_string())
 44    }
 45
 46    async fn shell_env(&self) -> Vec<(String, String)> {
 47        self.0.shell_env().await.into_iter().collect()
 48    }
 49}
 50
 51pub struct ExtensionLspAdapter {
 52    pub(crate) extension: WasmExtension,
 53    pub(crate) language_server_id: LanguageServerName,
 54    pub(crate) config: LanguageServerConfig,
 55    pub(crate) host: Arc<WasmHost>,
 56}
 57
 58#[async_trait(?Send)]
 59impl LspAdapter for ExtensionLspAdapter {
 60    fn name(&self) -> LanguageServerName {
 61        LanguageServerName(self.config.name.clone().into())
 62    }
 63
 64    fn get_language_server_command<'a>(
 65        self: Arc<Self>,
 66        delegate: Arc<dyn LspAdapterDelegate>,
 67        _: LanguageServerBinaryOptions,
 68        _: futures::lock::MutexGuard<'a, Option<LanguageServerBinary>>,
 69        _: &'a mut AsyncAppContext,
 70    ) -> Pin<Box<dyn 'a + Future<Output = Result<LanguageServerBinary>>>> {
 71        async move {
 72            let command = self
 73                .extension
 74                .call({
 75                    let this = self.clone();
 76                    |extension, store| {
 77                        async move {
 78                            let delegate = Arc::new(WorktreeDelegateAdapter(delegate.clone())) as _;
 79                            let resource = store.data_mut().table().push(delegate)?;
 80                            let command = extension
 81                                .call_language_server_command(
 82                                    store,
 83                                    &this.language_server_id,
 84                                    &this.config,
 85                                    resource,
 86                                )
 87                                .await?
 88                                .map_err(|e| anyhow!("{}", e))?;
 89                            anyhow::Ok(command)
 90                        }
 91                        .boxed()
 92                    }
 93                })
 94                .await?;
 95
 96            let path = self
 97                .host
 98                .path_from_extension(&self.extension.manifest.id, command.command.as_ref());
 99
100            // TODO: This should now be done via the `zed::make_file_executable` function in
101            // Zed extension API, but we're leaving these existing usages in place temporarily
102            // to avoid any compatibility issues between Zed and the extension versions.
103            //
104            // We can remove once the following extension versions no longer see any use:
105            // - toml@0.0.2
106            // - zig@0.0.1
107            if ["toml", "zig"].contains(&self.extension.manifest.id.as_ref())
108                && path.starts_with(&self.host.work_dir)
109            {
110                #[cfg(not(windows))]
111                {
112                    use std::fs::{self, Permissions};
113                    use std::os::unix::fs::PermissionsExt;
114
115                    fs::set_permissions(&path, Permissions::from_mode(0o755))
116                        .context("failed to set file permissions")?;
117                }
118            }
119
120            Ok(LanguageServerBinary {
121                path,
122                arguments: command.args.into_iter().map(|arg| arg.into()).collect(),
123                env: Some(command.env.into_iter().collect()),
124            })
125        }
126        .boxed_local()
127    }
128
129    async fn fetch_latest_server_version(
130        &self,
131        _: &dyn LspAdapterDelegate,
132    ) -> Result<Box<dyn 'static + Send + Any>> {
133        unreachable!("get_language_server_command is overridden")
134    }
135
136    async fn fetch_server_binary(
137        &self,
138        _: Box<dyn 'static + Send + Any>,
139        _: PathBuf,
140        _: &dyn LspAdapterDelegate,
141    ) -> Result<LanguageServerBinary> {
142        unreachable!("get_language_server_command is overridden")
143    }
144
145    async fn cached_server_binary(
146        &self,
147        _: PathBuf,
148        _: &dyn LspAdapterDelegate,
149    ) -> Option<LanguageServerBinary> {
150        unreachable!("get_language_server_command is overridden")
151    }
152
153    fn code_action_kinds(&self) -> Option<Vec<CodeActionKind>> {
154        let code_action_kinds = self
155            .extension
156            .manifest
157            .language_servers
158            .get(&self.language_server_id)
159            .and_then(|server| server.code_action_kinds.clone());
160
161        code_action_kinds.or(Some(vec![
162            CodeActionKind::EMPTY,
163            CodeActionKind::QUICKFIX,
164            CodeActionKind::REFACTOR,
165            CodeActionKind::REFACTOR_EXTRACT,
166            CodeActionKind::SOURCE,
167        ]))
168    }
169
170    fn language_ids(&self) -> HashMap<String, String> {
171        // TODO: The language IDs can be provided via the language server options
172        // in `extension.toml now but we're leaving these existing usages in place temporarily
173        // to avoid any compatibility issues between Zed and the extension versions.
174        //
175        // We can remove once the following extension versions no longer see any use:
176        // - php@0.0.1
177        if self.extension.manifest.id.as_ref() == "php" {
178            return HashMap::from_iter([("PHP".into(), "php".into())]);
179        }
180
181        self.extension
182            .manifest
183            .language_servers
184            .get(&LanguageServerName(self.config.name.clone().into()))
185            .map(|server| server.language_ids.clone())
186            .unwrap_or_default()
187    }
188
189    async fn initialization_options(
190        self: Arc<Self>,
191        delegate: &Arc<dyn LspAdapterDelegate>,
192    ) -> Result<Option<serde_json::Value>> {
193        let delegate = delegate.clone();
194        let json_options = self
195            .extension
196            .call({
197                let this = self.clone();
198                |extension, store| {
199                    async move {
200                        let delegate = Arc::new(WorktreeDelegateAdapter(delegate.clone())) as _;
201                        let resource = store.data_mut().table().push(delegate)?;
202                        let options = extension
203                            .call_language_server_initialization_options(
204                                store,
205                                &this.language_server_id,
206                                &this.config,
207                                resource,
208                            )
209                            .await?
210                            .map_err(|e| anyhow!("{}", e))?;
211                        anyhow::Ok(options)
212                    }
213                    .boxed()
214                }
215            })
216            .await?;
217        Ok(if let Some(json_options) = json_options {
218            serde_json::from_str(&json_options).with_context(|| {
219                format!("failed to parse initialization_options from extension: {json_options}")
220            })?
221        } else {
222            None
223        })
224    }
225
226    async fn workspace_configuration(
227        self: Arc<Self>,
228        delegate: &Arc<dyn LspAdapterDelegate>,
229        _: Arc<dyn LanguageToolchainStore>,
230        _cx: &mut AsyncAppContext,
231    ) -> Result<Value> {
232        let delegate = delegate.clone();
233        let json_options: Option<String> = self
234            .extension
235            .call({
236                let this = self.clone();
237                |extension, store| {
238                    async move {
239                        let delegate = Arc::new(WorktreeDelegateAdapter(delegate.clone())) as _;
240                        let resource = store.data_mut().table().push(delegate)?;
241                        let options = extension
242                            .call_language_server_workspace_configuration(
243                                store,
244                                &this.language_server_id,
245                                resource,
246                            )
247                            .await?
248                            .map_err(|e| anyhow!("{}", e))?;
249                        anyhow::Ok(options)
250                    }
251                    .boxed()
252                }
253            })
254            .await?;
255        Ok(if let Some(json_options) = json_options {
256            serde_json::from_str(&json_options).with_context(|| {
257                format!("failed to parse initialization_options from extension: {json_options}")
258            })?
259        } else {
260            serde_json::json!({})
261        })
262    }
263
264    async fn labels_for_completions(
265        self: Arc<Self>,
266        completions: &[lsp::CompletionItem],
267        language: &Arc<Language>,
268    ) -> Result<Vec<Option<CodeLabel>>> {
269        let completions = completions
270            .iter()
271            .map(|completion| wit::Completion::from(completion.clone()))
272            .collect::<Vec<_>>();
273
274        let labels = self
275            .extension
276            .call({
277                let this = self.clone();
278                |extension, store| {
279                    async move {
280                        extension
281                            .call_labels_for_completions(
282                                store,
283                                &this.language_server_id,
284                                completions,
285                            )
286                            .await?
287                            .map_err(|e| anyhow!("{}", e))
288                    }
289                    .boxed()
290                }
291            })
292            .await?;
293
294        Ok(labels_from_wit(labels, language))
295    }
296
297    async fn labels_for_symbols(
298        self: Arc<Self>,
299        symbols: &[(String, lsp::SymbolKind)],
300        language: &Arc<Language>,
301    ) -> Result<Vec<Option<CodeLabel>>> {
302        let symbols = symbols
303            .iter()
304            .cloned()
305            .map(|(name, kind)| wit::Symbol {
306                name,
307                kind: kind.into(),
308            })
309            .collect::<Vec<_>>();
310
311        let labels = self
312            .extension
313            .call({
314                let this = self.clone();
315                |extension, store| {
316                    async move {
317                        extension
318                            .call_labels_for_symbols(store, &this.language_server_id, symbols)
319                            .await?
320                            .map_err(|e| anyhow!("{}", e))
321                    }
322                    .boxed()
323                }
324            })
325            .await?;
326
327        Ok(labels_from_wit(labels, language))
328    }
329}
330
331fn labels_from_wit(
332    labels: Vec<Option<wit::CodeLabel>>,
333    language: &Arc<Language>,
334) -> Vec<Option<CodeLabel>> {
335    labels
336        .into_iter()
337        .map(|label| {
338            let label = label?;
339            let runs = if label.code.is_empty() {
340                Vec::new()
341            } else {
342                language.highlight_text(&label.code.as_str().into(), 0..label.code.len())
343            };
344            build_code_label(&label, &runs, language)
345        })
346        .collect()
347}
348
349fn build_code_label(
350    label: &wit::CodeLabel,
351    parsed_runs: &[(Range<usize>, HighlightId)],
352    language: &Arc<Language>,
353) -> Option<CodeLabel> {
354    let mut text = String::new();
355    let mut runs = vec![];
356
357    for span in &label.spans {
358        match span {
359            wit::CodeLabelSpan::CodeRange(range) => {
360                let range = Range::from(*range);
361                let code_span = &label.code.get(range.clone())?;
362                let mut input_ix = range.start;
363                let mut output_ix = text.len();
364                for (run_range, id) in parsed_runs {
365                    if run_range.start >= range.end {
366                        break;
367                    }
368                    if run_range.end <= input_ix {
369                        continue;
370                    }
371
372                    if run_range.start > input_ix {
373                        let len = run_range.start - input_ix;
374                        output_ix += len;
375                        input_ix += len;
376                    }
377
378                    let len = range.end.min(run_range.end) - input_ix;
379                    runs.push((output_ix..output_ix + len, *id));
380                    output_ix += len;
381                    input_ix += len;
382                }
383
384                text.push_str(code_span);
385            }
386            wit::CodeLabelSpan::Literal(span) => {
387                let highlight_id = language
388                    .grammar()
389                    .zip(span.highlight_name.as_ref())
390                    .and_then(|(grammar, highlight_name)| {
391                        grammar.highlight_id_for_name(highlight_name)
392                    })
393                    .unwrap_or_default();
394                let ix = text.len();
395                runs.push((ix..ix + span.text.len(), highlight_id));
396                text.push_str(&span.text);
397            }
398        }
399    }
400
401    let filter_range = Range::from(label.filter_range);
402    text.get(filter_range.clone())?;
403    Some(CodeLabel {
404        text,
405        runs,
406        filter_range,
407    })
408}
409
410impl From<wit::Range> for Range<usize> {
411    fn from(range: wit::Range) -> Self {
412        let start = range.start as usize;
413        let end = range.end as usize;
414        start..end
415    }
416}
417
418impl From<lsp::CompletionItem> for wit::Completion {
419    fn from(value: lsp::CompletionItem) -> Self {
420        Self {
421            label: value.label,
422            label_details: value.label_details.map(Into::into),
423            detail: value.detail,
424            kind: value.kind.map(Into::into),
425            insert_text_format: value.insert_text_format.map(Into::into),
426        }
427    }
428}
429
430impl From<lsp::CompletionItemLabelDetails> for wit::CompletionLabelDetails {
431    fn from(value: lsp::CompletionItemLabelDetails) -> Self {
432        Self {
433            detail: value.detail,
434            description: value.description,
435        }
436    }
437}
438
439impl From<lsp::CompletionItemKind> for wit::CompletionKind {
440    fn from(value: lsp::CompletionItemKind) -> Self {
441        match value {
442            lsp::CompletionItemKind::TEXT => Self::Text,
443            lsp::CompletionItemKind::METHOD => Self::Method,
444            lsp::CompletionItemKind::FUNCTION => Self::Function,
445            lsp::CompletionItemKind::CONSTRUCTOR => Self::Constructor,
446            lsp::CompletionItemKind::FIELD => Self::Field,
447            lsp::CompletionItemKind::VARIABLE => Self::Variable,
448            lsp::CompletionItemKind::CLASS => Self::Class,
449            lsp::CompletionItemKind::INTERFACE => Self::Interface,
450            lsp::CompletionItemKind::MODULE => Self::Module,
451            lsp::CompletionItemKind::PROPERTY => Self::Property,
452            lsp::CompletionItemKind::UNIT => Self::Unit,
453            lsp::CompletionItemKind::VALUE => Self::Value,
454            lsp::CompletionItemKind::ENUM => Self::Enum,
455            lsp::CompletionItemKind::KEYWORD => Self::Keyword,
456            lsp::CompletionItemKind::SNIPPET => Self::Snippet,
457            lsp::CompletionItemKind::COLOR => Self::Color,
458            lsp::CompletionItemKind::FILE => Self::File,
459            lsp::CompletionItemKind::REFERENCE => Self::Reference,
460            lsp::CompletionItemKind::FOLDER => Self::Folder,
461            lsp::CompletionItemKind::ENUM_MEMBER => Self::EnumMember,
462            lsp::CompletionItemKind::CONSTANT => Self::Constant,
463            lsp::CompletionItemKind::STRUCT => Self::Struct,
464            lsp::CompletionItemKind::EVENT => Self::Event,
465            lsp::CompletionItemKind::OPERATOR => Self::Operator,
466            lsp::CompletionItemKind::TYPE_PARAMETER => Self::TypeParameter,
467            _ => Self::Other(extract_int(value)),
468        }
469    }
470}
471
472impl From<lsp::InsertTextFormat> for wit::InsertTextFormat {
473    fn from(value: lsp::InsertTextFormat) -> Self {
474        match value {
475            lsp::InsertTextFormat::PLAIN_TEXT => Self::PlainText,
476            lsp::InsertTextFormat::SNIPPET => Self::Snippet,
477            _ => Self::Other(extract_int(value)),
478        }
479    }
480}
481
482impl From<lsp::SymbolKind> for wit::SymbolKind {
483    fn from(value: lsp::SymbolKind) -> Self {
484        match value {
485            lsp::SymbolKind::FILE => Self::File,
486            lsp::SymbolKind::MODULE => Self::Module,
487            lsp::SymbolKind::NAMESPACE => Self::Namespace,
488            lsp::SymbolKind::PACKAGE => Self::Package,
489            lsp::SymbolKind::CLASS => Self::Class,
490            lsp::SymbolKind::METHOD => Self::Method,
491            lsp::SymbolKind::PROPERTY => Self::Property,
492            lsp::SymbolKind::FIELD => Self::Field,
493            lsp::SymbolKind::CONSTRUCTOR => Self::Constructor,
494            lsp::SymbolKind::ENUM => Self::Enum,
495            lsp::SymbolKind::INTERFACE => Self::Interface,
496            lsp::SymbolKind::FUNCTION => Self::Function,
497            lsp::SymbolKind::VARIABLE => Self::Variable,
498            lsp::SymbolKind::CONSTANT => Self::Constant,
499            lsp::SymbolKind::STRING => Self::String,
500            lsp::SymbolKind::NUMBER => Self::Number,
501            lsp::SymbolKind::BOOLEAN => Self::Boolean,
502            lsp::SymbolKind::ARRAY => Self::Array,
503            lsp::SymbolKind::OBJECT => Self::Object,
504            lsp::SymbolKind::KEY => Self::Key,
505            lsp::SymbolKind::NULL => Self::Null,
506            lsp::SymbolKind::ENUM_MEMBER => Self::EnumMember,
507            lsp::SymbolKind::STRUCT => Self::Struct,
508            lsp::SymbolKind::EVENT => Self::Event,
509            lsp::SymbolKind::OPERATOR => Self::Operator,
510            lsp::SymbolKind::TYPE_PARAMETER => Self::TypeParameter,
511            _ => Self::Other(extract_int(value)),
512        }
513    }
514}
515
516fn extract_int<T: Serialize>(value: T) -> i32 {
517    maybe!({
518        let kind = serde_json::to_value(&value)?;
519        serde_json::from_value(kind)
520    })
521    .log_err()
522    .unwrap_or(-1)
523}
524
#[test]
fn test_build_code_label() {
    use util::test::marked_text_ranges;

    // Every «marked» region of the source code gets a highlight run.
    let (code, code_ranges) = marked_text_ranges(
        "«const» «a»: «fn»(«Bcd»(«Efgh»)) -> «Ijklm» = pqrs.tuv",
        false,
    );
    let code_runs = code_ranges
        .into_iter()
        .map(|range| (range, HighlightId(0)))
        .collect::<Vec<_>>();

    // The label reorders the code: the trailing `pqrs.tuv` comes first, followed
    // by the `: fn(...) -> Ijklm` type annotation. Highlight runs must follow
    // their text to its new position.
    let label = build_code_label(
        &wit::CodeLabel {
            spans: vec![
                wit::CodeLabelSpan::CodeRange(wit::Range {
                    start: code.find("pqrs").unwrap() as u32,
                    end: code.len() as u32,
                }),
                wit::CodeLabelSpan::CodeRange(wit::Range {
                    start: code.find(": fn").unwrap() as u32,
                    end: code.find(" = ").unwrap() as u32,
                }),
            ],
            filter_range: wit::Range {
                start: 0,
                end: "pqrs.tuv".len() as u32,
            },
            code,
        },
        &code_runs,
        &language::PLAIN_TEXT,
    )
    .unwrap();

    let (label_text, label_ranges) =
        marked_text_ranges("pqrs.tuv: «fn»(«Bcd»(«Efgh»)) -> «Ijklm»", false);
    let label_runs = label_ranges
        .into_iter()
        .map(|range| (range, HighlightId(0)))
        .collect::<Vec<_>>();

    assert_eq!(
        label,
        CodeLabel {
            text: label_text,
            runs: label_runs,
            // Pinned explicitly: copying `label.filter_range` here would make the
            // filter-range part of the comparison tautological.
            filter_range: 0.."pqrs.tuv".len(),
        }
    )
}
577
#[test]
fn test_build_code_label_with_invalid_ranges() {
    use util::test::marked_text_ranges;

    let (code, code_ranges) = marked_text_ranges("const «a»: «B» = '🏀'", false);
    let code_runs: Vec<_> = code_ranges
        .into_iter()
        .map(|range| (range, HighlightId(0)))
        .collect();

    // A span uses a code range that is invalid because it starts inside of
    // a multi-byte character.
    let type_start = code.find('B').unwrap() as u32;
    let type_end = code.find(" = ").unwrap() as u32;
    let inside_emoji = code.find('🏀').unwrap() as u32 + 1;
    let code_end = code.len() as u32;
    let mid_char_label = wit::CodeLabel {
        spans: vec![
            wit::CodeLabelSpan::CodeRange(wit::Range {
                start: type_start,
                end: type_end,
            }),
            wit::CodeLabelSpan::CodeRange(wit::Range {
                start: inside_emoji,
                end: code_end,
            }),
        ],
        filter_range: wit::Range {
            start: 0,
            end: "B".len() as u32,
        },
        code,
    };
    let built = build_code_label(&mid_char_label, &code_runs, &language::PLAIN_TEXT);
    assert!(built.is_none());

    // Filter range extends beyond actual text
    let overlong_filter_label = wit::CodeLabel {
        spans: vec![wit::CodeLabelSpan::Literal(wit::CodeLabelSpanLiteral {
            text: "abc".into(),
            highlight_name: Some("type".into()),
        })],
        filter_range: wit::Range { start: 0, end: 5 },
        code: String::new(),
    };
    let built = build_code_label(&overlong_filter_label, &code_runs, &language::PLAIN_TEXT);
    assert!(built.is_none());
}