headless_host.rs

  1use std::{path::PathBuf, sync::Arc};
  2
  3use anyhow::{Context as _, Result};
  4use client::{
  5    TypedEnvelope,
  6    proto::{self, FromProto},
  7};
  8use collections::{HashMap, HashSet};
  9use extension::{
 10    Extension, ExtensionDebugAdapterProviderProxy, ExtensionHostProxy, ExtensionLanguageProxy,
 11    ExtensionLanguageServerProxy, ExtensionManifest,
 12};
 13use fs::{Fs, RemoveOptions, RenameOptions};
 14use futures::future::join_all;
 15use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
 16use http_client::HttpClient;
 17use language::{LanguageConfig, LanguageName, LanguageQueries, LoadedLanguage};
 18use lsp::LanguageServerName;
 19use node_runtime::NodeRuntime;
 20
 21use crate::wasm_host::{WasmExtension, WasmHost};
 22
 23#[derive(Clone, Debug)]
 24pub struct ExtensionVersion {
 25    pub id: String,
 26    pub version: String,
 27    pub dev: bool,
 28}
 29
 30pub struct HeadlessExtensionStore {
 31    pub fs: Arc<dyn Fs>,
 32    pub extension_dir: PathBuf,
 33    pub proxy: Arc<ExtensionHostProxy>,
 34    pub wasm_host: Arc<WasmHost>,
 35    pub loaded_extensions: HashMap<Arc<str>, Arc<str>>,
 36    pub loaded_languages: HashMap<Arc<str>, Vec<LanguageName>>,
 37    pub loaded_language_servers: HashMap<Arc<str>, Vec<(LanguageServerName, LanguageName)>>,
 38}
 39
 40impl HeadlessExtensionStore {
 41    pub fn new(
 42        fs: Arc<dyn Fs>,
 43        http_client: Arc<dyn HttpClient>,
 44        extension_dir: PathBuf,
 45        extension_host_proxy: Arc<ExtensionHostProxy>,
 46        node_runtime: NodeRuntime,
 47        cx: &mut App,
 48    ) -> Entity<Self> {
 49        cx.new(|cx| Self {
 50            fs: fs.clone(),
 51            wasm_host: WasmHost::new(
 52                fs.clone(),
 53                http_client.clone(),
 54                node_runtime,
 55                extension_host_proxy.clone(),
 56                extension_dir.join("work"),
 57                cx,
 58            ),
 59            extension_dir,
 60            proxy: extension_host_proxy,
 61            loaded_extensions: Default::default(),
 62            loaded_languages: Default::default(),
 63            loaded_language_servers: Default::default(),
 64        })
 65    }
 66
 67    pub fn sync_extensions(
 68        &mut self,
 69        extensions: Vec<ExtensionVersion>,
 70        cx: &Context<Self>,
 71    ) -> Task<Result<Vec<ExtensionVersion>>> {
 72        let on_client = HashSet::from_iter(extensions.iter().map(|e| e.id.as_str()));
 73        let to_remove: Vec<Arc<str>> = self
 74            .loaded_extensions
 75            .keys()
 76            .filter(|id| !on_client.contains(id.as_ref()))
 77            .cloned()
 78            .collect();
 79        let to_load: Vec<ExtensionVersion> = extensions
 80            .into_iter()
 81            .filter(|e| {
 82                if e.dev {
 83                    return true;
 84                }
 85                self.loaded_extensions
 86                    .get(e.id.as_str())
 87                    .is_none_or(|loaded| loaded.as_ref() != e.version.as_str())
 88            })
 89            .collect();
 90
 91        cx.spawn(async move |this, cx| {
 92            let mut missing = Vec::new();
 93
 94            for extension_id in to_remove {
 95                log::info!("removing extension: {}", extension_id);
 96                this.update(cx, |this, cx| this.uninstall_extension(&extension_id, cx))?
 97                    .await?;
 98            }
 99
100            for extension in to_load {
101                if let Err(e) = Self::load_extension(this.clone(), extension.clone(), cx).await {
102                    log::info!("failed to load extension: {}, {:?}", extension.id, e);
103                    missing.push(extension)
104                } else if extension.dev {
105                    missing.push(extension)
106                }
107            }
108
109            Ok(missing)
110        })
111    }
112
113    pub async fn load_extension(
114        this: WeakEntity<Self>,
115        extension: ExtensionVersion,
116        cx: &mut AsyncApp,
117    ) -> Result<()> {
118        let (fs, wasm_host, extension_dir) = this.update(cx, |this, _cx| {
119            this.loaded_extensions.insert(
120                extension.id.clone().into(),
121                extension.version.clone().into(),
122            );
123            (
124                this.fs.clone(),
125                this.wasm_host.clone(),
126                this.extension_dir.join(&extension.id),
127            )
128        })?;
129
130        let manifest = Arc::new(ExtensionManifest::load(fs.clone(), &extension_dir).await?);
131
132        debug_assert!(!manifest.languages.is_empty() || manifest.allow_remote_load());
133
134        if manifest.version.as_ref() != extension.version.as_str() {
135            anyhow::bail!(
136                "mismatched versions: ({}) != ({})",
137                manifest.version,
138                extension.version
139            )
140        }
141
142        for language_path in &manifest.languages {
143            let language_path = extension_dir.join(language_path);
144            let config = fs.load(&language_path.join("config.toml")).await?;
145            let mut config = ::toml::from_str::<LanguageConfig>(&config)?;
146
147            this.update(cx, |this, _cx| {
148                this.loaded_languages
149                    .entry(manifest.id.clone())
150                    .or_default()
151                    .push(config.name.clone());
152
153                config.grammar = None;
154
155                this.proxy.register_language(
156                    config.name.clone(),
157                    None,
158                    config.matcher.clone(),
159                    config.hidden,
160                    Arc::new(move || {
161                        Ok(LoadedLanguage {
162                            config: config.clone(),
163                            queries: LanguageQueries::default(),
164                            context_provider: None,
165                            toolchain_provider: None,
166                            manifest_name: None,
167                        })
168                    }),
169                );
170            })?;
171        }
172
173        if !manifest.allow_remote_load() {
174            return Ok(());
175        }
176
177        let wasm_extension: Arc<dyn Extension> =
178            Arc::new(WasmExtension::load(&extension_dir, &manifest, wasm_host.clone(), cx).await?);
179
180        for (language_server_id, language_server_config) in &manifest.language_servers {
181            for language in language_server_config.languages() {
182                this.update(cx, |this, _cx| {
183                    this.loaded_language_servers
184                        .entry(manifest.id.clone())
185                        .or_default()
186                        .push((language_server_id.clone(), language.clone()));
187                    this.proxy.register_language_server(
188                        wasm_extension.clone(),
189                        language_server_id.clone(),
190                        language.clone(),
191                    );
192                })?;
193            }
194            log::info!("Loaded language server: {}", language_server_id);
195        }
196
197        for (debug_adapter, meta) in &manifest.debug_adapters {
198            let schema_path = extension::build_debug_adapter_schema_path(debug_adapter, meta);
199
200            this.update(cx, |this, _cx| {
201                this.proxy.register_debug_adapter(
202                    wasm_extension.clone(),
203                    debug_adapter.clone(),
204                    &extension_dir.join(schema_path),
205                );
206            })?;
207            log::info!("Loaded debug adapter: {}", debug_adapter);
208        }
209
210        for debug_locator in manifest.debug_locators.keys() {
211            this.update(cx, |this, _cx| {
212                this.proxy
213                    .register_debug_locator(wasm_extension.clone(), debug_locator.clone());
214            })?;
215            log::info!("Loaded debug locator: {}", debug_locator);
216        }
217
218        Ok(())
219    }
220
221    fn uninstall_extension(
222        &mut self,
223        extension_id: &Arc<str>,
224        cx: &mut Context<Self>,
225    ) -> Task<Result<()>> {
226        self.loaded_extensions.remove(extension_id);
227
228        let languages_to_remove = self
229            .loaded_languages
230            .remove(extension_id)
231            .unwrap_or_default();
232        self.proxy.remove_languages(&languages_to_remove, &[]);
233
234        let servers_to_remove = self
235            .loaded_language_servers
236            .remove(extension_id)
237            .unwrap_or_default();
238        let proxy = self.proxy.clone();
239        let path = self.extension_dir.join(&extension_id.to_string());
240        let fs = self.fs.clone();
241        cx.spawn(async move |_, cx| {
242            let mut removal_tasks = Vec::with_capacity(servers_to_remove.len());
243            cx.update(|cx| {
244                for (language_server_name, language) in servers_to_remove {
245                    removal_tasks.push(proxy.remove_language_server(
246                        &language,
247                        &language_server_name,
248                        cx,
249                    ));
250                }
251            })
252            .ok();
253            let _ = join_all(removal_tasks).await;
254
255            fs.remove_dir(
256                &path,
257                RemoveOptions {
258                    recursive: true,
259                    ignore_if_not_exists: true,
260                },
261            )
262            .await
263            .with_context(|| format!("Removing directory {path:?}"))
264        })
265    }
266
267    pub fn install_extension(
268        &mut self,
269        extension: ExtensionVersion,
270        tmp_path: PathBuf,
271        cx: &mut Context<Self>,
272    ) -> Task<Result<()>> {
273        let path = self.extension_dir.join(&extension.id);
274        let fs = self.fs.clone();
275
276        cx.spawn(async move |this, cx| {
277            if fs.is_dir(&path).await {
278                this.update(cx, |this, cx| {
279                    this.uninstall_extension(&extension.id.clone().into(), cx)
280                })?
281                .await?;
282            }
283
284            fs.rename(&tmp_path, &path, RenameOptions::default())
285                .await?;
286
287            Self::load_extension(this, extension, cx).await
288        })
289    }
290
291    pub async fn handle_sync_extensions(
292        extension_store: Entity<HeadlessExtensionStore>,
293        envelope: TypedEnvelope<proto::SyncExtensions>,
294        mut cx: AsyncApp,
295    ) -> Result<proto::SyncExtensionsResponse> {
296        let requested_extensions =
297            envelope
298                .payload
299                .extensions
300                .into_iter()
301                .map(|p| ExtensionVersion {
302                    id: p.id,
303                    version: p.version,
304                    dev: p.dev,
305                });
306        let missing_extensions = extension_store
307            .update(&mut cx, |extension_store, cx| {
308                extension_store.sync_extensions(requested_extensions.collect(), cx)
309            })?
310            .await?;
311
312        Ok(proto::SyncExtensionsResponse {
313            missing_extensions: missing_extensions
314                .into_iter()
315                .map(|e| proto::Extension {
316                    id: e.id,
317                    version: e.version,
318                    dev: e.dev,
319                })
320                .collect(),
321            tmp_dir: paths::remote_extensions_uploads_dir()
322                .to_string_lossy()
323                .to_string(),
324        })
325    }
326
327    pub async fn handle_install_extension(
328        extensions: Entity<HeadlessExtensionStore>,
329        envelope: TypedEnvelope<proto::InstallExtension>,
330        mut cx: AsyncApp,
331    ) -> Result<proto::Ack> {
332        let extension = envelope
333            .payload
334            .extension
335            .context("Invalid InstallExtension request")?;
336
337        extensions
338            .update(&mut cx, |extensions, cx| {
339                extensions.install_extension(
340                    ExtensionVersion {
341                        id: extension.id,
342                        version: extension.version,
343                        dev: extension.dev,
344                    },
345                    PathBuf::from_proto(envelope.payload.tmp_dir),
346                    cx,
347                )
348            })?
349            .await?;
350
351        Ok(proto::Ack {})
352    }
353}