headless_host.rs

  1use std::{path::PathBuf, sync::Arc};
  2
  3use anyhow::{Context as _, Result};
  4use client::{TypedEnvelope, proto};
  5use collections::{HashMap, HashSet};
  6use extension::{
  7    Extension, ExtensionDebugAdapterProviderProxy, ExtensionHostProxy, ExtensionLanguageProxy,
  8    ExtensionLanguageServerProxy, ExtensionManifest,
  9};
 10use fs::{Fs, RemoveOptions, RenameOptions};
 11use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
 12use http_client::HttpClient;
 13use language::{LanguageConfig, LanguageName, LanguageQueries, LoadedLanguage};
 14use lsp::LanguageServerName;
 15use node_runtime::NodeRuntime;
 16
 17use crate::wasm_host::{WasmExtension, WasmHost};
 18
 19#[derive(Clone, Debug)]
 20pub struct ExtensionVersion {
 21    pub id: String,
 22    pub version: String,
 23    pub dev: bool,
 24}
 25
/// Extension store used on the headless (remote) host.
///
/// It loads extensions the client has synced or uploaded into `extension_dir` and registers
/// their languages, language servers, and debug integrations through `proxy`. It also keeps
/// enough bookkeeping to unregister languages and language servers again on uninstall.
pub struct HeadlessExtensionStore {
    /// Filesystem used to read manifests/configs and to move or delete extension directories.
    pub fs: Arc<dyn Fs>,
    /// Directory holding one subdirectory per installed extension (named by extension id).
    /// Its `work` subdirectory is given to the WASM host.
    pub extension_dir: PathBuf,
    /// Proxy through which languages, language servers, debug adapters, and debug locators
    /// are registered and removed.
    pub proxy: Arc<ExtensionHostProxy>,
    /// Host used to load the WASM part of extensions that allow remote loading.
    pub wasm_host: Arc<WasmHost>,
    /// Extension id -> version recorded as loaded for it.
    pub loaded_extensions: HashMap<Arc<str>, Arc<str>>,
    /// Per-extension list of language names registered through `proxy`, removed on uninstall.
    pub loaded_languages: HashMap<Arc<str>, Vec<LanguageName>>,
    /// Per-extension list of (language server, language) registrations, removed on uninstall.
    pub loaded_language_servers: HashMap<Arc<str>, Vec<(LanguageServerName, LanguageName)>>,
}
 35
 36impl HeadlessExtensionStore {
 37    pub fn new(
 38        fs: Arc<dyn Fs>,
 39        http_client: Arc<dyn HttpClient>,
 40        extension_dir: PathBuf,
 41        extension_host_proxy: Arc<ExtensionHostProxy>,
 42        node_runtime: NodeRuntime,
 43        cx: &mut App,
 44    ) -> Entity<Self> {
 45        cx.new(|cx| Self {
 46            fs: fs.clone(),
 47            wasm_host: WasmHost::new(
 48                fs.clone(),
 49                http_client.clone(),
 50                node_runtime,
 51                extension_host_proxy.clone(),
 52                extension_dir.join("work"),
 53                cx,
 54            ),
 55            extension_dir,
 56            proxy: extension_host_proxy,
 57            loaded_extensions: Default::default(),
 58            loaded_languages: Default::default(),
 59            loaded_language_servers: Default::default(),
 60        })
 61    }
 62
 63    pub fn sync_extensions(
 64        &mut self,
 65        extensions: Vec<ExtensionVersion>,
 66        cx: &Context<Self>,
 67    ) -> Task<Result<Vec<ExtensionVersion>>> {
 68        let on_client = HashSet::from_iter(extensions.iter().map(|e| e.id.as_str()));
 69        let to_remove: Vec<Arc<str>> = self
 70            .loaded_extensions
 71            .keys()
 72            .filter(|id| !on_client.contains(id.as_ref()))
 73            .cloned()
 74            .collect();
 75        let to_load: Vec<ExtensionVersion> = extensions
 76            .into_iter()
 77            .filter(|e| {
 78                if e.dev {
 79                    return true;
 80                }
 81                self.loaded_extensions
 82                    .get(e.id.as_str())
 83                    .is_none_or(|loaded| loaded.as_ref() != e.version.as_str())
 84            })
 85            .collect();
 86
 87        cx.spawn(async move |this, cx| {
 88            let mut missing = Vec::new();
 89
 90            for extension_id in to_remove {
 91                log::info!("removing extension: {}", extension_id);
 92                this.update(cx, |this, cx| this.uninstall_extension(&extension_id, cx))?
 93                    .await?;
 94            }
 95
 96            for extension in to_load {
 97                if let Err(e) = Self::load_extension(this.clone(), extension.clone(), cx).await {
 98                    log::info!("failed to load extension: {}, {:?}", extension.id, e);
 99                    missing.push(extension)
100                } else if extension.dev {
101                    missing.push(extension)
102                }
103            }
104
105            Ok(missing)
106        })
107    }
108
109    pub async fn load_extension(
110        this: WeakEntity<Self>,
111        extension: ExtensionVersion,
112        cx: &mut AsyncApp,
113    ) -> Result<()> {
114        let (fs, wasm_host, extension_dir) = this.update(cx, |this, _cx| {
115            this.loaded_extensions.insert(
116                extension.id.clone().into(),
117                extension.version.clone().into(),
118            );
119            (
120                this.fs.clone(),
121                this.wasm_host.clone(),
122                this.extension_dir.join(&extension.id),
123            )
124        })?;
125
126        let manifest = Arc::new(ExtensionManifest::load(fs.clone(), &extension_dir).await?);
127
128        debug_assert!(!manifest.languages.is_empty() || manifest.allow_remote_load());
129
130        if manifest.version.as_ref() != extension.version.as_str() {
131            anyhow::bail!(
132                "mismatched versions: ({}) != ({})",
133                manifest.version,
134                extension.version
135            )
136        }
137
138        for language_path in &manifest.languages {
139            let language_path = extension_dir.join(language_path);
140            let config = fs.load(&language_path.join("config.toml")).await?;
141            let mut config = ::toml::from_str::<LanguageConfig>(&config)?;
142
143            this.update(cx, |this, _cx| {
144                this.loaded_languages
145                    .entry(manifest.id.clone())
146                    .or_default()
147                    .push(config.name.clone());
148
149                config.grammar = None;
150
151                this.proxy.register_language(
152                    config.name.clone(),
153                    None,
154                    config.matcher.clone(),
155                    config.hidden,
156                    Arc::new(move || {
157                        Ok(LoadedLanguage {
158                            config: config.clone(),
159                            queries: LanguageQueries::default(),
160                            context_provider: None,
161                            toolchain_provider: None,
162                        })
163                    }),
164                );
165            })?;
166        }
167
168        if !manifest.allow_remote_load() {
169            return Ok(());
170        }
171
172        let wasm_extension: Arc<dyn Extension> = Arc::new(
173            WasmExtension::load(extension_dir.clone(), &manifest, wasm_host.clone(), &cx).await?,
174        );
175
176        for (language_server_id, language_server_config) in &manifest.language_servers {
177            for language in language_server_config.languages() {
178                this.update(cx, |this, _cx| {
179                    this.loaded_language_servers
180                        .entry(manifest.id.clone())
181                        .or_default()
182                        .push((language_server_id.clone(), language.clone()));
183                    this.proxy.register_language_server(
184                        wasm_extension.clone(),
185                        language_server_id.clone(),
186                        language.clone(),
187                    );
188                })?;
189            }
190            log::info!("Loaded language server: {}", language_server_id);
191        }
192
193        for (debug_adapter, meta) in &manifest.debug_adapters {
194            let schema_path = extension::build_debug_adapter_schema_path(debug_adapter, meta);
195
196            this.update(cx, |this, _cx| {
197                this.proxy.register_debug_adapter(
198                    wasm_extension.clone(),
199                    debug_adapter.clone(),
200                    &extension_dir.join(schema_path),
201                );
202            })?;
203            log::info!("Loaded debug adapter: {}", debug_adapter);
204        }
205
206        for debug_locator in manifest.debug_locators.keys() {
207            this.update(cx, |this, _cx| {
208                this.proxy
209                    .register_debug_locator(wasm_extension.clone(), debug_locator.clone());
210            })?;
211            log::info!("Loaded debug locator: {}", debug_locator);
212        }
213
214        Ok(())
215    }
216
217    fn uninstall_extension(
218        &mut self,
219        extension_id: &Arc<str>,
220        cx: &mut Context<Self>,
221    ) -> Task<Result<()>> {
222        self.loaded_extensions.remove(extension_id);
223
224        let languages_to_remove = self
225            .loaded_languages
226            .remove(extension_id)
227            .unwrap_or_default();
228        self.proxy.remove_languages(&languages_to_remove, &[]);
229
230        for (language_server_name, language) in self
231            .loaded_language_servers
232            .remove(extension_id)
233            .unwrap_or_default()
234        {
235            self.proxy
236                .remove_language_server(&language, &language_server_name);
237        }
238
239        let path = self.extension_dir.join(&extension_id.to_string());
240        let fs = self.fs.clone();
241        cx.spawn(async move |_, _| {
242            fs.remove_dir(
243                &path,
244                RemoveOptions {
245                    recursive: true,
246                    ignore_if_not_exists: true,
247                },
248            )
249            .await
250        })
251    }
252
253    pub fn install_extension(
254        &mut self,
255        extension: ExtensionVersion,
256        tmp_path: PathBuf,
257        cx: &mut Context<Self>,
258    ) -> Task<Result<()>> {
259        let path = self.extension_dir.join(&extension.id);
260        let fs = self.fs.clone();
261
262        cx.spawn(async move |this, cx| {
263            if fs.is_dir(&path).await {
264                this.update(cx, |this, cx| {
265                    this.uninstall_extension(&extension.id.clone().into(), cx)
266                })?
267                .await?;
268            }
269
270            fs.rename(&tmp_path, &path, RenameOptions::default())
271                .await?;
272
273            Self::load_extension(this, extension, cx).await
274        })
275    }
276
277    pub async fn handle_sync_extensions(
278        extension_store: Entity<HeadlessExtensionStore>,
279        envelope: TypedEnvelope<proto::SyncExtensions>,
280        mut cx: AsyncApp,
281    ) -> Result<proto::SyncExtensionsResponse> {
282        let requested_extensions =
283            envelope
284                .payload
285                .extensions
286                .into_iter()
287                .map(|p| ExtensionVersion {
288                    id: p.id,
289                    version: p.version,
290                    dev: p.dev,
291                });
292        let missing_extensions = extension_store
293            .update(&mut cx, |extension_store, cx| {
294                extension_store.sync_extensions(requested_extensions.collect(), cx)
295            })?
296            .await?;
297
298        Ok(proto::SyncExtensionsResponse {
299            missing_extensions: missing_extensions
300                .into_iter()
301                .map(|e| proto::Extension {
302                    id: e.id,
303                    version: e.version,
304                    dev: e.dev,
305                })
306                .collect(),
307            tmp_dir: paths::remote_extensions_uploads_dir()
308                .to_string_lossy()
309                .to_string(),
310        })
311    }
312
313    pub async fn handle_install_extension(
314        extensions: Entity<HeadlessExtensionStore>,
315        envelope: TypedEnvelope<proto::InstallExtension>,
316        mut cx: AsyncApp,
317    ) -> Result<proto::Ack> {
318        let extension = envelope
319            .payload
320            .extension
321            .context("Invalid InstallExtension request")?;
322
323        extensions
324            .update(&mut cx, |extensions, cx| {
325                extensions.install_extension(
326                    ExtensionVersion {
327                        id: extension.id,
328                        version: extension.version,
329                        dev: extension.dev,
330                    },
331                    PathBuf::from(envelope.payload.tmp_dir),
332                    cx,
333                )
334            })?
335            .await?;
336
337        Ok(proto::Ack {})
338    }
339}