headless_host.rs

  1use std::{path::PathBuf, sync::Arc};
  2
  3use anyhow::{Context as _, Result};
  4use client::{
  5    TypedEnvelope,
  6    proto::{self, FromProto},
  7};
  8use collections::{HashMap, HashSet};
  9use extension::{
 10    Extension, ExtensionDebugAdapterProviderProxy, ExtensionHostProxy, ExtensionLanguageProxy,
 11    ExtensionLanguageServerProxy, ExtensionManifest,
 12};
 13use fs::{Fs, RemoveOptions, RenameOptions};
 14use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
 15use http_client::HttpClient;
 16use language::{LanguageConfig, LanguageName, LanguageQueries, LoadedLanguage};
 17use lsp::LanguageServerName;
 18use node_runtime::NodeRuntime;
 19
 20use crate::wasm_host::{WasmExtension, WasmHost};
 21
 22#[derive(Clone, Debug)]
 23pub struct ExtensionVersion {
 24    pub id: String,
 25    pub version: String,
 26    pub dev: bool,
 27}
 28
 29pub struct HeadlessExtensionStore {
 30    pub fs: Arc<dyn Fs>,
 31    pub extension_dir: PathBuf,
 32    pub proxy: Arc<ExtensionHostProxy>,
 33    pub wasm_host: Arc<WasmHost>,
 34    pub loaded_extensions: HashMap<Arc<str>, Arc<str>>,
 35    pub loaded_languages: HashMap<Arc<str>, Vec<LanguageName>>,
 36    pub loaded_language_servers: HashMap<Arc<str>, Vec<(LanguageServerName, LanguageName)>>,
 37}
 38
 39impl HeadlessExtensionStore {
 40    pub fn new(
 41        fs: Arc<dyn Fs>,
 42        http_client: Arc<dyn HttpClient>,
 43        extension_dir: PathBuf,
 44        extension_host_proxy: Arc<ExtensionHostProxy>,
 45        node_runtime: NodeRuntime,
 46        cx: &mut App,
 47    ) -> Entity<Self> {
 48        cx.new(|cx| Self {
 49            fs: fs.clone(),
 50            wasm_host: WasmHost::new(
 51                fs.clone(),
 52                http_client.clone(),
 53                node_runtime,
 54                extension_host_proxy.clone(),
 55                extension_dir.join("work"),
 56                cx,
 57            ),
 58            extension_dir,
 59            proxy: extension_host_proxy,
 60            loaded_extensions: Default::default(),
 61            loaded_languages: Default::default(),
 62            loaded_language_servers: Default::default(),
 63        })
 64    }
 65
 66    pub fn sync_extensions(
 67        &mut self,
 68        extensions: Vec<ExtensionVersion>,
 69        cx: &Context<Self>,
 70    ) -> Task<Result<Vec<ExtensionVersion>>> {
 71        let on_client = HashSet::from_iter(extensions.iter().map(|e| e.id.as_str()));
 72        let to_remove: Vec<Arc<str>> = self
 73            .loaded_extensions
 74            .keys()
 75            .filter(|id| !on_client.contains(id.as_ref()))
 76            .cloned()
 77            .collect();
 78        let to_load: Vec<ExtensionVersion> = extensions
 79            .into_iter()
 80            .filter(|e| {
 81                if e.dev {
 82                    return true;
 83                }
 84                self.loaded_extensions
 85                    .get(e.id.as_str())
 86                    .is_none_or(|loaded| loaded.as_ref() != e.version.as_str())
 87            })
 88            .collect();
 89
 90        cx.spawn(async move |this, cx| {
 91            let mut missing = Vec::new();
 92
 93            for extension_id in to_remove {
 94                log::info!("removing extension: {}", extension_id);
 95                this.update(cx, |this, cx| this.uninstall_extension(&extension_id, cx))?
 96                    .await?;
 97            }
 98
 99            for extension in to_load {
100                if let Err(e) = Self::load_extension(this.clone(), extension.clone(), cx).await {
101                    log::info!("failed to load extension: {}, {:?}", extension.id, e);
102                    missing.push(extension)
103                } else if extension.dev {
104                    missing.push(extension)
105                }
106            }
107
108            Ok(missing)
109        })
110    }
111
112    pub async fn load_extension(
113        this: WeakEntity<Self>,
114        extension: ExtensionVersion,
115        cx: &mut AsyncApp,
116    ) -> Result<()> {
117        let (fs, wasm_host, extension_dir) = this.update(cx, |this, _cx| {
118            this.loaded_extensions.insert(
119                extension.id.clone().into(),
120                extension.version.clone().into(),
121            );
122            (
123                this.fs.clone(),
124                this.wasm_host.clone(),
125                this.extension_dir.join(&extension.id),
126            )
127        })?;
128
129        let manifest = Arc::new(ExtensionManifest::load(fs.clone(), &extension_dir).await?);
130
131        debug_assert!(!manifest.languages.is_empty() || manifest.allow_remote_load());
132
133        if manifest.version.as_ref() != extension.version.as_str() {
134            anyhow::bail!(
135                "mismatched versions: ({}) != ({})",
136                manifest.version,
137                extension.version
138            )
139        }
140
141        for language_path in &manifest.languages {
142            let language_path = extension_dir.join(language_path);
143            let config = fs.load(&language_path.join("config.toml")).await?;
144            let mut config = ::toml::from_str::<LanguageConfig>(&config)?;
145
146            this.update(cx, |this, _cx| {
147                this.loaded_languages
148                    .entry(manifest.id.clone())
149                    .or_default()
150                    .push(config.name.clone());
151
152                config.grammar = None;
153
154                this.proxy.register_language(
155                    config.name.clone(),
156                    None,
157                    config.matcher.clone(),
158                    config.hidden,
159                    Arc::new(move || {
160                        Ok(LoadedLanguage {
161                            config: config.clone(),
162                            queries: LanguageQueries::default(),
163                            context_provider: None,
164                            toolchain_provider: None,
165                        })
166                    }),
167                );
168            })?;
169        }
170
171        if !manifest.allow_remote_load() {
172            return Ok(());
173        }
174
175        let wasm_extension: Arc<dyn Extension> = Arc::new(
176            WasmExtension::load(extension_dir.clone(), &manifest, wasm_host.clone(), &cx).await?,
177        );
178
179        for (language_server_id, language_server_config) in &manifest.language_servers {
180            for language in language_server_config.languages() {
181                this.update(cx, |this, _cx| {
182                    this.loaded_language_servers
183                        .entry(manifest.id.clone())
184                        .or_default()
185                        .push((language_server_id.clone(), language.clone()));
186                    this.proxy.register_language_server(
187                        wasm_extension.clone(),
188                        language_server_id.clone(),
189                        language.clone(),
190                    );
191                })?;
192            }
193            log::info!("Loaded language server: {}", language_server_id);
194        }
195
196        for (debug_adapter, meta) in &manifest.debug_adapters {
197            let schema_path = extension::build_debug_adapter_schema_path(debug_adapter, meta);
198
199            this.update(cx, |this, _cx| {
200                this.proxy.register_debug_adapter(
201                    wasm_extension.clone(),
202                    debug_adapter.clone(),
203                    &extension_dir.join(schema_path),
204                );
205            })?;
206            log::info!("Loaded debug adapter: {}", debug_adapter);
207        }
208
209        for debug_locator in manifest.debug_locators.keys() {
210            this.update(cx, |this, _cx| {
211                this.proxy
212                    .register_debug_locator(wasm_extension.clone(), debug_locator.clone());
213            })?;
214            log::info!("Loaded debug locator: {}", debug_locator);
215        }
216
217        Ok(())
218    }
219
220    fn uninstall_extension(
221        &mut self,
222        extension_id: &Arc<str>,
223        cx: &mut Context<Self>,
224    ) -> Task<Result<()>> {
225        self.loaded_extensions.remove(extension_id);
226
227        let languages_to_remove = self
228            .loaded_languages
229            .remove(extension_id)
230            .unwrap_or_default();
231        self.proxy.remove_languages(&languages_to_remove, &[]);
232
233        for (language_server_name, language) in self
234            .loaded_language_servers
235            .remove(extension_id)
236            .unwrap_or_default()
237        {
238            self.proxy
239                .remove_language_server(&language, &language_server_name);
240        }
241
242        let path = self.extension_dir.join(&extension_id.to_string());
243        let fs = self.fs.clone();
244        cx.spawn(async move |_, _| {
245            fs.remove_dir(
246                &path,
247                RemoveOptions {
248                    recursive: true,
249                    ignore_if_not_exists: true,
250                },
251            )
252            .await
253        })
254    }
255
256    pub fn install_extension(
257        &mut self,
258        extension: ExtensionVersion,
259        tmp_path: PathBuf,
260        cx: &mut Context<Self>,
261    ) -> Task<Result<()>> {
262        let path = self.extension_dir.join(&extension.id);
263        let fs = self.fs.clone();
264
265        cx.spawn(async move |this, cx| {
266            if fs.is_dir(&path).await {
267                this.update(cx, |this, cx| {
268                    this.uninstall_extension(&extension.id.clone().into(), cx)
269                })?
270                .await?;
271            }
272
273            fs.rename(&tmp_path, &path, RenameOptions::default())
274                .await?;
275
276            Self::load_extension(this, extension, cx).await
277        })
278    }
279
280    pub async fn handle_sync_extensions(
281        extension_store: Entity<HeadlessExtensionStore>,
282        envelope: TypedEnvelope<proto::SyncExtensions>,
283        mut cx: AsyncApp,
284    ) -> Result<proto::SyncExtensionsResponse> {
285        let requested_extensions =
286            envelope
287                .payload
288                .extensions
289                .into_iter()
290                .map(|p| ExtensionVersion {
291                    id: p.id,
292                    version: p.version,
293                    dev: p.dev,
294                });
295        let missing_extensions = extension_store
296            .update(&mut cx, |extension_store, cx| {
297                extension_store.sync_extensions(requested_extensions.collect(), cx)
298            })?
299            .await?;
300
301        Ok(proto::SyncExtensionsResponse {
302            missing_extensions: missing_extensions
303                .into_iter()
304                .map(|e| proto::Extension {
305                    id: e.id,
306                    version: e.version,
307                    dev: e.dev,
308                })
309                .collect(),
310            tmp_dir: paths::remote_extensions_uploads_dir()
311                .to_string_lossy()
312                .to_string(),
313        })
314    }
315
316    pub async fn handle_install_extension(
317        extensions: Entity<HeadlessExtensionStore>,
318        envelope: TypedEnvelope<proto::InstallExtension>,
319        mut cx: AsyncApp,
320    ) -> Result<proto::Ack> {
321        let extension = envelope
322            .payload
323            .extension
324            .context("Invalid InstallExtension request")?;
325
326        extensions
327            .update(&mut cx, |extensions, cx| {
328                extensions.install_extension(
329                    ExtensionVersion {
330                        id: extension.id,
331                        version: extension.version,
332                        dev: extension.dev,
333                    },
334                    PathBuf::from_proto(envelope.payload.tmp_dir),
335                    cx,
336                )
337            })?
338            .await?;
339
340        Ok(proto::Ack {})
341    }
342}