headless_host.rs

  1use std::{path::PathBuf, sync::Arc};
  2
  3use anyhow::{Context as _, Result};
  4use client::{TypedEnvelope, proto};
  5use collections::{HashMap, HashSet};
  6use extension::{
  7    Extension, ExtensionDebugAdapterProviderProxy, ExtensionHostProxy, ExtensionLanguageProxy,
  8    ExtensionLanguageServerProxy, ExtensionManifest,
  9};
 10use fs::{Fs, RemoveOptions, RenameOptions};
 11use futures::future::join_all;
 12use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
 13use http_client::HttpClient;
 14use language::{LanguageConfig, LanguageName, LanguageQueries, LoadedLanguage};
 15use lsp::LanguageServerName;
 16use node_runtime::NodeRuntime;
 17
 18use crate::wasm_host::{WasmExtension, WasmHost};
 19
 20#[derive(Clone, Debug)]
 21pub struct ExtensionVersion {
 22    pub id: String,
 23    pub version: String,
 24    pub dev: bool,
 25}
 26
 27pub struct HeadlessExtensionStore {
 28    pub fs: Arc<dyn Fs>,
 29    pub extension_dir: PathBuf,
 30    pub proxy: Arc<ExtensionHostProxy>,
 31    pub wasm_host: Arc<WasmHost>,
 32    pub loaded_extensions: HashMap<Arc<str>, Arc<str>>,
 33    pub loaded_languages: HashMap<Arc<str>, Vec<LanguageName>>,
 34    pub loaded_language_servers: HashMap<Arc<str>, Vec<(LanguageServerName, LanguageName)>>,
 35}
 36
 37impl HeadlessExtensionStore {
 38    pub fn new(
 39        fs: Arc<dyn Fs>,
 40        http_client: Arc<dyn HttpClient>,
 41        extension_dir: PathBuf,
 42        extension_host_proxy: Arc<ExtensionHostProxy>,
 43        node_runtime: NodeRuntime,
 44        cx: &mut App,
 45    ) -> Entity<Self> {
 46        cx.new(|cx| Self {
 47            fs: fs.clone(),
 48            wasm_host: WasmHost::new(
 49                fs.clone(),
 50                http_client.clone(),
 51                node_runtime,
 52                extension_host_proxy.clone(),
 53                extension_dir.join("work"),
 54                cx,
 55            ),
 56            extension_dir,
 57            proxy: extension_host_proxy,
 58            loaded_extensions: Default::default(),
 59            loaded_languages: Default::default(),
 60            loaded_language_servers: Default::default(),
 61        })
 62    }
 63
 64    pub fn sync_extensions(
 65        &mut self,
 66        extensions: Vec<ExtensionVersion>,
 67        cx: &Context<Self>,
 68    ) -> Task<Result<Vec<ExtensionVersion>>> {
 69        let on_client = HashSet::from_iter(extensions.iter().map(|e| e.id.as_str()));
 70        let to_remove: Vec<Arc<str>> = self
 71            .loaded_extensions
 72            .keys()
 73            .filter(|id| !on_client.contains(id.as_ref()))
 74            .cloned()
 75            .collect();
 76        let to_load: Vec<ExtensionVersion> = extensions
 77            .into_iter()
 78            .filter(|e| {
 79                if e.dev {
 80                    return true;
 81                }
 82                self.loaded_extensions
 83                    .get(e.id.as_str())
 84                    .is_none_or(|loaded| loaded.as_ref() != e.version.as_str())
 85            })
 86            .collect();
 87
 88        cx.spawn(async move |this, cx| {
 89            let mut missing = Vec::new();
 90
 91            for extension_id in to_remove {
 92                log::info!("removing extension: {}", extension_id);
 93                this.update(cx, |this, cx| this.uninstall_extension(&extension_id, cx))?
 94                    .await?;
 95            }
 96
 97            for extension in to_load {
 98                if let Err(e) = Self::load_extension(this.clone(), extension.clone(), cx).await {
 99                    log::info!("failed to load extension: {}, {:?}", extension.id, e);
100                    missing.push(extension)
101                } else if extension.dev {
102                    missing.push(extension)
103                }
104            }
105
106            Ok(missing)
107        })
108    }
109
110    pub async fn load_extension(
111        this: WeakEntity<Self>,
112        extension: ExtensionVersion,
113        cx: &mut AsyncApp,
114    ) -> Result<()> {
115        let (fs, wasm_host, extension_dir) = this.update(cx, |this, _cx| {
116            this.loaded_extensions.insert(
117                extension.id.clone().into(),
118                extension.version.clone().into(),
119            );
120            (
121                this.fs.clone(),
122                this.wasm_host.clone(),
123                this.extension_dir.join(&extension.id),
124            )
125        })?;
126
127        let manifest = Arc::new(ExtensionManifest::load(fs.clone(), &extension_dir).await?);
128
129        debug_assert!(!manifest.languages.is_empty() || manifest.allow_remote_load());
130
131        if manifest.version.as_ref() != extension.version.as_str() {
132            anyhow::bail!(
133                "mismatched versions: ({}) != ({})",
134                manifest.version,
135                extension.version
136            )
137        }
138
139        for language_path in &manifest.languages {
140            let language_path = extension_dir.join(language_path);
141            let config = fs.load(&language_path.join("config.toml")).await?;
142            let mut config = ::toml::from_str::<LanguageConfig>(&config)?;
143
144            this.update(cx, |this, _cx| {
145                this.loaded_languages
146                    .entry(manifest.id.clone())
147                    .or_default()
148                    .push(config.name.clone());
149
150                config.grammar = None;
151
152                this.proxy.register_language(
153                    config.name.clone(),
154                    None,
155                    config.matcher.clone(),
156                    config.hidden,
157                    Arc::new(move || {
158                        Ok(LoadedLanguage {
159                            config: config.clone(),
160                            queries: LanguageQueries::default(),
161                            context_provider: None,
162                            toolchain_provider: None,
163                            manifest_name: None,
164                        })
165                    }),
166                );
167            })?;
168        }
169
170        if !manifest.allow_remote_load() {
171            return Ok(());
172        }
173
174        let wasm_extension: Arc<dyn Extension> =
175            Arc::new(WasmExtension::load(&extension_dir, &manifest, wasm_host.clone(), cx).await?);
176
177        for (language_server_id, language_server_config) in &manifest.language_servers {
178            for language in language_server_config.languages() {
179                this.update(cx, |this, _cx| {
180                    this.loaded_language_servers
181                        .entry(manifest.id.clone())
182                        .or_default()
183                        .push((language_server_id.clone(), language.clone()));
184                    this.proxy.register_language_server(
185                        wasm_extension.clone(),
186                        language_server_id.clone(),
187                        language.clone(),
188                    );
189                })?;
190            }
191            log::info!("Loaded language server: {}", language_server_id);
192        }
193
194        for (debug_adapter, meta) in &manifest.debug_adapters {
195            let schema_path = extension::build_debug_adapter_schema_path(debug_adapter, meta);
196
197            this.update(cx, |this, _cx| {
198                this.proxy.register_debug_adapter(
199                    wasm_extension.clone(),
200                    debug_adapter.clone(),
201                    &extension_dir.join(schema_path),
202                );
203            })?;
204            log::info!("Loaded debug adapter: {}", debug_adapter);
205        }
206
207        for debug_locator in manifest.debug_locators.keys() {
208            this.update(cx, |this, _cx| {
209                this.proxy
210                    .register_debug_locator(wasm_extension.clone(), debug_locator.clone());
211            })?;
212            log::info!("Loaded debug locator: {}", debug_locator);
213        }
214
215        Ok(())
216    }
217
218    fn uninstall_extension(
219        &mut self,
220        extension_id: &Arc<str>,
221        cx: &mut Context<Self>,
222    ) -> Task<Result<()>> {
223        self.loaded_extensions.remove(extension_id);
224
225        let languages_to_remove = self
226            .loaded_languages
227            .remove(extension_id)
228            .unwrap_or_default();
229        self.proxy.remove_languages(&languages_to_remove, &[]);
230
231        let servers_to_remove = self
232            .loaded_language_servers
233            .remove(extension_id)
234            .unwrap_or_default();
235        let proxy = self.proxy.clone();
236        let path = self.extension_dir.join(&extension_id.to_string());
237        let fs = self.fs.clone();
238        cx.spawn(async move |_, cx| {
239            let mut removal_tasks = Vec::with_capacity(servers_to_remove.len());
240            cx.update(|cx| {
241                for (language_server_name, language) in servers_to_remove {
242                    removal_tasks.push(proxy.remove_language_server(
243                        &language,
244                        &language_server_name,
245                        cx,
246                    ));
247                }
248            })
249            .ok();
250            let _ = join_all(removal_tasks).await;
251
252            fs.remove_dir(
253                &path,
254                RemoveOptions {
255                    recursive: true,
256                    ignore_if_not_exists: true,
257                },
258            )
259            .await
260            .with_context(|| format!("Removing directory {path:?}"))
261        })
262    }
263
264    pub fn install_extension(
265        &mut self,
266        extension: ExtensionVersion,
267        tmp_path: PathBuf,
268        cx: &mut Context<Self>,
269    ) -> Task<Result<()>> {
270        let path = self.extension_dir.join(&extension.id);
271        let fs = self.fs.clone();
272
273        cx.spawn(async move |this, cx| {
274            if fs.is_dir(&path).await {
275                this.update(cx, |this, cx| {
276                    this.uninstall_extension(&extension.id.clone().into(), cx)
277                })?
278                .await?;
279            }
280
281            fs.rename(&tmp_path, &path, RenameOptions::default())
282                .await?;
283
284            Self::load_extension(this, extension, cx).await
285        })
286    }
287
288    pub async fn handle_sync_extensions(
289        extension_store: Entity<HeadlessExtensionStore>,
290        envelope: TypedEnvelope<proto::SyncExtensions>,
291        mut cx: AsyncApp,
292    ) -> Result<proto::SyncExtensionsResponse> {
293        let requested_extensions =
294            envelope
295                .payload
296                .extensions
297                .into_iter()
298                .map(|p| ExtensionVersion {
299                    id: p.id,
300                    version: p.version,
301                    dev: p.dev,
302                });
303        let missing_extensions = extension_store
304            .update(&mut cx, |extension_store, cx| {
305                extension_store.sync_extensions(requested_extensions.collect(), cx)
306            })?
307            .await?;
308
309        Ok(proto::SyncExtensionsResponse {
310            missing_extensions: missing_extensions
311                .into_iter()
312                .map(|e| proto::Extension {
313                    id: e.id,
314                    version: e.version,
315                    dev: e.dev,
316                })
317                .collect(),
318            tmp_dir: paths::remote_extensions_uploads_dir()
319                .to_string_lossy()
320                .to_string(),
321        })
322    }
323
324    pub async fn handle_install_extension(
325        extensions: Entity<HeadlessExtensionStore>,
326        envelope: TypedEnvelope<proto::InstallExtension>,
327        mut cx: AsyncApp,
328    ) -> Result<proto::Ack> {
329        let extension = envelope
330            .payload
331            .extension
332            .context("Invalid InstallExtension request")?;
333
334        extensions
335            .update(&mut cx, |extensions, cx| {
336                extensions.install_extension(
337                    ExtensionVersion {
338                        id: extension.id,
339                        version: extension.version,
340                        dev: extension.dev,
341                    },
342                    PathBuf::from(envelope.payload.tmp_dir),
343                    cx,
344                )
345            })?
346            .await?;
347
348        Ok(proto::Ack {})
349    }
350}