extension_api.rs

//! The Zed Rust Extension API allows you to write extensions for [Zed](https://zed.dev/) in Rust.
  2
  3pub mod http_client;
  4pub mod process;
  5pub mod settings;
  6
  7use core::fmt;
  8
  9use wit::*;
 10
 11pub use serde_json;
 12
 13// WIT re-exports.
 14//
 15// We explicitly enumerate the symbols we want to re-export, as there are some
 16// that we may want to shadow to provide a cleaner Rust API.
 17pub use wit::{
 18    download_file, make_file_executable,
 19    zed::extension::github::{
 20        github_release_by_tag_name, latest_github_release, GithubRelease, GithubReleaseAsset,
 21        GithubReleaseOptions,
 22    },
 23    zed::extension::nodejs::{
 24        node_binary_path, npm_install_package, npm_package_installed_version,
 25        npm_package_latest_version,
 26    },
 27    zed::extension::platform::{current_platform, Architecture, Os},
 28    zed::extension::slash_command::{
 29        SlashCommand, SlashCommandArgumentCompletion, SlashCommandOutput, SlashCommandOutputSection,
 30    },
 31    CodeLabel, CodeLabelSpan, CodeLabelSpanLiteral, Command, DownloadedFileType, EnvVars,
 32    KeyValueStore, LanguageServerInstallationStatus, Project, Range, Worktree,
 33};
 34
 35// Undocumented WIT re-exports.
 36//
 37// These are symbols that need to be public for the purposes of implementing
 38// the extension host, but aren't relevant to extension authors.
 39#[doc(hidden)]
 40pub use wit::Guest;
 41
 42/// Constructs for interacting with language servers over the
 43/// Language Server Protocol (LSP).
 44pub mod lsp {
 45    pub use crate::wit::zed::extension::lsp::{
 46        Completion, CompletionKind, InsertTextFormat, Symbol, SymbolKind,
 47    };
 48}
 49
 50/// A result returned from a Zed extension.
 51pub type Result<T, E = String> = core::result::Result<T, E>;
 52
 53/// Updates the installation status for the given language server.
 54pub fn set_language_server_installation_status(
 55    language_server_id: &LanguageServerId,
 56    status: &LanguageServerInstallationStatus,
 57) {
 58    wit::set_language_server_installation_status(&language_server_id.0, status)
 59}
 60
 61/// A Zed extension.
 62pub trait Extension: Send + Sync {
 63    /// Returns a new instance of the extension.
 64    fn new() -> Self
 65    where
 66        Self: Sized;
 67
 68    /// Returns the command used to start the language server for the specified
 69    /// language.
 70    fn language_server_command(
 71        &mut self,
 72        _language_server_id: &LanguageServerId,
 73        _worktree: &Worktree,
 74    ) -> Result<Command> {
 75        Err("`language_server_command` not implemented".to_string())
 76    }
 77
 78    /// Returns the initialization options to pass to the specified language server.
 79    fn language_server_initialization_options(
 80        &mut self,
 81        _language_server_id: &LanguageServerId,
 82        _worktree: &Worktree,
 83    ) -> Result<Option<serde_json::Value>> {
 84        Ok(None)
 85    }
 86
 87    /// Returns the workspace configuration options to pass to the language server.
 88    fn language_server_workspace_configuration(
 89        &mut self,
 90        _language_server_id: &LanguageServerId,
 91        _worktree: &Worktree,
 92    ) -> Result<Option<serde_json::Value>> {
 93        Ok(None)
 94    }
 95
 96    /// Returns the label for the given completion.
 97    fn label_for_completion(
 98        &self,
 99        _language_server_id: &LanguageServerId,
100        _completion: Completion,
101    ) -> Option<CodeLabel> {
102        None
103    }
104
105    /// Returns the label for the given symbol.
106    fn label_for_symbol(
107        &self,
108        _language_server_id: &LanguageServerId,
109        _symbol: Symbol,
110    ) -> Option<CodeLabel> {
111        None
112    }
113
114    /// Returns the completions that should be shown when completing the provided slash command with the given query.
115    fn complete_slash_command_argument(
116        &self,
117        _command: SlashCommand,
118        _args: Vec<String>,
119    ) -> Result<Vec<SlashCommandArgumentCompletion>, String> {
120        Ok(Vec::new())
121    }
122
123    /// Returns the output from running the provided slash command.
124    fn run_slash_command(
125        &self,
126        _command: SlashCommand,
127        _args: Vec<String>,
128        _worktree: Option<&Worktree>,
129    ) -> Result<SlashCommandOutput, String> {
130        Err("`run_slash_command` not implemented".to_string())
131    }
132
133    /// Returns the command used to start a context server.
134    fn context_server_command(
135        &mut self,
136        _context_server_id: &ContextServerId,
137        _project: &Project,
138    ) -> Result<Command> {
139        Err("`context_server_command` not implemented".to_string())
140    }
141
142    /// Returns a list of package names as suggestions to be included in the
143    /// search results of the `/docs` slash command.
144    ///
145    /// This can be used to provide completions for known packages (e.g., from the
146    /// local project or a registry) before a package has been indexed.
147    fn suggest_docs_packages(&self, _provider: String) -> Result<Vec<String>, String> {
148        Ok(Vec::new())
149    }
150
151    /// Indexes the docs for the specified package.
152    fn index_docs(
153        &self,
154        _provider: String,
155        _package: String,
156        _database: &KeyValueStore,
157    ) -> Result<(), String> {
158        Err("`index_docs` not implemented".to_string())
159    }
160}
161
/// Registers the provided type as a Zed extension.
///
/// The type must implement the [`Extension`] trait. The macro defines the
/// exported `init-extension` entry point. That entry point first changes the
/// current directory to the one named by the `PWD` environment variable. It
/// then constructs the extension with [`Extension::new`] and registers it.
///
/// # Panics
///
/// The generated entry point panics in either of these cases:
/// - `PWD` is unset or is not valid Unicode.
/// - The current directory cannot be changed to `PWD`.
///
/// # Examples
///
/// ```ignore
/// struct MyExtension;
///
/// impl zed_extension_api::Extension for MyExtension {
///     fn new() -> Self {
///         MyExtension
///     }
/// }
///
/// zed_extension_api::register_extension!(MyExtension);
/// ```
#[macro_export]
macro_rules! register_extension {
    ($extension_type:ty) => {
        #[export_name = "init-extension"]
        pub extern "C" fn __init_extension() {
            // `$crate` (instead of a hard-coded crate name) keeps the macro working
            // when the dependency is renamed in the extension's `Cargo.toml`, and
            // the `::std` paths cannot be shadowed by a local `std` module.
            let working_dir = ::std::env::var("PWD")
                .expect("the `PWD` environment variable is not set or is not valid Unicode");
            ::std::env::set_current_dir(working_dir)
                .expect("failed to change the current directory to `PWD`");
            $crate::register_extension(|| {
                ::std::boxed::Box::new(<$extension_type as $crate::Extension>::new())
            });
        }
    };
}
177
178#[doc(hidden)]
179pub fn register_extension(build_extension: fn() -> Box<dyn Extension>) {
180    unsafe { EXTENSION = Some((build_extension)()) }
181}
182
183fn extension() -> &'static mut dyn Extension {
184    unsafe { EXTENSION.as_deref_mut().unwrap() }
185}
186
187static mut EXTENSION: Option<Box<dyn Extension>> = None;
188
/// Extension API version bytes, embedded in the `zed:api-version` custom
/// section of the compiled WebAssembly module.
///
/// The bytes are read at compile time from the `version_bytes` file in
/// `OUT_DIR`.
// NOTE(review): `version_bytes` is presumably written by this crate's build
// script, and the meaning of the six bytes (e.g. how version components are
// packed) is defined there; confirm before relying on the layout.
#[cfg(target_arch = "wasm32")]
#[link_section = "zed:api-version"]
#[doc(hidden)]
pub static ZED_API_VERSION: [u8; 6] = *include_bytes!(concat!(env!("OUT_DIR"), "/version_bytes"));
193
/// Bindings generated by `wit_bindgen` from the WIT definitions in
/// `./wit/since_v0.3.0`.
mod wit {
    // The contents of this module are generated, so clippy lints triggered by
    // the generated code cannot be fixed at the source and are silenced here.
    #![allow(clippy::too_many_arguments, clippy::missing_safety_doc)]

    wit_bindgen::generate!({
        // `init-extension` is skipped because the `register_extension!` macro
        // exports that entry point by hand.
        skip: ["init-extension"],
        path: "./wit/since_v0.3.0",
    });
}
202
203wit::export!(Component);
204
205struct Component;
206
207impl wit::Guest for Component {
208    fn language_server_command(
209        language_server_id: String,
210        worktree: &wit::Worktree,
211    ) -> Result<wit::Command> {
212        let language_server_id = LanguageServerId(language_server_id);
213        extension().language_server_command(&language_server_id, worktree)
214    }
215
216    fn language_server_initialization_options(
217        language_server_id: String,
218        worktree: &Worktree,
219    ) -> Result<Option<String>, String> {
220        let language_server_id = LanguageServerId(language_server_id);
221        Ok(extension()
222            .language_server_initialization_options(&language_server_id, worktree)?
223            .and_then(|value| serde_json::to_string(&value).ok()))
224    }
225
226    fn language_server_workspace_configuration(
227        language_server_id: String,
228        worktree: &Worktree,
229    ) -> Result<Option<String>, String> {
230        let language_server_id = LanguageServerId(language_server_id);
231        Ok(extension()
232            .language_server_workspace_configuration(&language_server_id, worktree)?
233            .and_then(|value| serde_json::to_string(&value).ok()))
234    }
235
236    fn labels_for_completions(
237        language_server_id: String,
238        completions: Vec<Completion>,
239    ) -> Result<Vec<Option<CodeLabel>>, String> {
240        let language_server_id = LanguageServerId(language_server_id);
241        let mut labels = Vec::new();
242        for (ix, completion) in completions.into_iter().enumerate() {
243            let label = extension().label_for_completion(&language_server_id, completion);
244            if let Some(label) = label {
245                labels.resize(ix + 1, None);
246                *labels.last_mut().unwrap() = Some(label);
247            }
248        }
249        Ok(labels)
250    }
251
252    fn labels_for_symbols(
253        language_server_id: String,
254        symbols: Vec<Symbol>,
255    ) -> Result<Vec<Option<CodeLabel>>, String> {
256        let language_server_id = LanguageServerId(language_server_id);
257        let mut labels = Vec::new();
258        for (ix, symbol) in symbols.into_iter().enumerate() {
259            let label = extension().label_for_symbol(&language_server_id, symbol);
260            if let Some(label) = label {
261                labels.resize(ix + 1, None);
262                *labels.last_mut().unwrap() = Some(label);
263            }
264        }
265        Ok(labels)
266    }
267
268    fn complete_slash_command_argument(
269        command: SlashCommand,
270        args: Vec<String>,
271    ) -> Result<Vec<SlashCommandArgumentCompletion>, String> {
272        extension().complete_slash_command_argument(command, args)
273    }
274
275    fn run_slash_command(
276        command: SlashCommand,
277        args: Vec<String>,
278        worktree: Option<&Worktree>,
279    ) -> Result<SlashCommandOutput, String> {
280        extension().run_slash_command(command, args, worktree)
281    }
282
283    fn context_server_command(
284        context_server_id: String,
285        project: &Project,
286    ) -> Result<wit::Command> {
287        let context_server_id = ContextServerId(context_server_id);
288        extension().context_server_command(&context_server_id, project)
289    }
290
291    fn suggest_docs_packages(provider: String) -> Result<Vec<String>, String> {
292        extension().suggest_docs_packages(provider)
293    }
294
295    fn index_docs(
296        provider: String,
297        package: String,
298        database: &KeyValueStore,
299    ) -> Result<(), String> {
300        extension().index_docs(provider, package, database)
301    }
302}
303
/// The ID of a language server.
///
/// IDs are created from strings supplied by the host. Read the underlying
/// string through [`AsRef<str>`] or [`Display`](fmt::Display).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
pub struct LanguageServerId(String);

impl AsRef<str> for LanguageServerId {
    fn as_ref(&self) -> &str {
        &self.0
    }
}

impl fmt::Display for LanguageServerId {
    /// Writes the raw ID, honoring width, fill, alignment, and precision flags
    /// (e.g. `{:>20}`) exactly as a plain `str` would.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
319
/// The ID of a context server.
///
/// IDs are created from strings supplied by the host. Read the underlying
/// string through [`AsRef<str>`] or [`Display`](fmt::Display).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
pub struct ContextServerId(String);

impl AsRef<str> for ContextServerId {
    fn as_ref(&self) -> &str {
        &self.0
    }
}

impl fmt::Display for ContextServerId {
    /// Writes the raw ID, honoring width, fill, alignment, and precision flags
    /// (e.g. `{:>20}`) exactly as a plain `str` would.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
335
336impl CodeLabelSpan {
337    /// Returns a [`CodeLabelSpan::CodeRange`].
338    pub fn code_range(range: impl Into<wit::Range>) -> Self {
339        Self::CodeRange(range.into())
340    }
341
342    /// Returns a [`CodeLabelSpan::Literal`].
343    pub fn literal(text: impl Into<String>, highlight_name: Option<String>) -> Self {
344        Self::Literal(CodeLabelSpanLiteral {
345            text: text.into(),
346            highlight_name,
347        })
348    }
349}
350
351impl From<std::ops::Range<u32>> for wit::Range {
352    fn from(value: std::ops::Range<u32>) -> Self {
353        Self {
354            start: value.start,
355            end: value.end,
356        }
357    }
358}
359
360impl From<std::ops::Range<usize>> for wit::Range {
361    fn from(value: std::ops::Range<usize>) -> Self {
362        Self {
363            start: value.start as u32,
364            end: value.end as u32,
365        }
366    }
367}