headless_host.rs

  1use std::{path::PathBuf, sync::Arc};
  2
  3use anyhow::{anyhow, Context as _, Result};
  4use client::{proto, TypedEnvelope};
  5use collections::{HashMap, HashSet};
  6use extension::{
  7    Extension, ExtensionHostProxy, ExtensionLanguageProxy, ExtensionLanguageServerProxy,
  8    ExtensionManifest,
  9};
 10use fs::{Fs, RemoveOptions, RenameOptions};
 11use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
 12use http_client::HttpClient;
 13use language::{LanguageConfig, LanguageName, LanguageQueries, LoadedLanguage};
 14use lsp::LanguageServerName;
 15use node_runtime::NodeRuntime;
 16
 17use crate::wasm_host::{WasmExtension, WasmHost};
 18
 19#[derive(Clone, Debug)]
 20pub struct ExtensionVersion {
 21    pub id: String,
 22    pub version: String,
 23    pub dev: bool,
 24}
 25
 26pub struct HeadlessExtensionStore {
 27    pub fs: Arc<dyn Fs>,
 28    pub extension_dir: PathBuf,
 29    pub proxy: Arc<ExtensionHostProxy>,
 30    pub wasm_host: Arc<WasmHost>,
 31    pub loaded_extensions: HashMap<Arc<str>, Arc<str>>,
 32    pub loaded_languages: HashMap<Arc<str>, Vec<LanguageName>>,
 33    pub loaded_language_servers: HashMap<Arc<str>, Vec<(LanguageServerName, LanguageName)>>,
 34}
 35
 36impl HeadlessExtensionStore {
 37    pub fn new(
 38        fs: Arc<dyn Fs>,
 39        http_client: Arc<dyn HttpClient>,
 40        extension_dir: PathBuf,
 41        extension_host_proxy: Arc<ExtensionHostProxy>,
 42        node_runtime: NodeRuntime,
 43        cx: &mut App,
 44    ) -> Entity<Self> {
 45        cx.new(|cx| Self {
 46            fs: fs.clone(),
 47            wasm_host: WasmHost::new(
 48                fs.clone(),
 49                http_client.clone(),
 50                node_runtime,
 51                extension_host_proxy.clone(),
 52                extension_dir.join("work"),
 53                cx,
 54            ),
 55            extension_dir,
 56            proxy: extension_host_proxy,
 57            loaded_extensions: Default::default(),
 58            loaded_languages: Default::default(),
 59            loaded_language_servers: Default::default(),
 60        })
 61    }
 62
 63    pub fn sync_extensions(
 64        &mut self,
 65        extensions: Vec<ExtensionVersion>,
 66        cx: &Context<Self>,
 67    ) -> Task<Result<Vec<ExtensionVersion>>> {
 68        let on_client = HashSet::from_iter(extensions.iter().map(|e| e.id.as_str()));
 69        let to_remove: Vec<Arc<str>> = self
 70            .loaded_extensions
 71            .keys()
 72            .filter(|id| !on_client.contains(id.as_ref()))
 73            .cloned()
 74            .collect();
 75        let to_load: Vec<ExtensionVersion> = extensions
 76            .into_iter()
 77            .filter(|e| {
 78                if e.dev {
 79                    return true;
 80                }
 81                !self
 82                    .loaded_extensions
 83                    .get(e.id.as_str())
 84                    .is_some_and(|loaded| loaded.as_ref() == e.version.as_str())
 85            })
 86            .collect();
 87
 88        cx.spawn(|this, mut cx| async move {
 89            let mut missing = Vec::new();
 90
 91            for extension_id in to_remove {
 92                log::info!("removing extension: {}", extension_id);
 93                this.update(&mut cx, |this, cx| {
 94                    this.uninstall_extension(&extension_id, cx)
 95                })?
 96                .await?;
 97            }
 98
 99            for extension in to_load {
100                if let Err(e) = Self::load_extension(this.clone(), extension.clone(), &mut cx).await
101                {
102                    log::info!("failed to load extension: {}, {:?}", extension.id, e);
103                    missing.push(extension)
104                } else if extension.dev {
105                    missing.push(extension)
106                }
107            }
108
109            Ok(missing)
110        })
111    }
112
113    pub async fn load_extension(
114        this: WeakEntity<Self>,
115        extension: ExtensionVersion,
116        cx: &mut AsyncApp,
117    ) -> Result<()> {
118        let (fs, wasm_host, extension_dir) = this.update(cx, |this, _cx| {
119            this.loaded_extensions.insert(
120                extension.id.clone().into(),
121                extension.version.clone().into(),
122            );
123            (
124                this.fs.clone(),
125                this.wasm_host.clone(),
126                this.extension_dir.join(&extension.id),
127            )
128        })?;
129
130        let manifest = Arc::new(ExtensionManifest::load(fs.clone(), &extension_dir).await?);
131
132        debug_assert!(!manifest.languages.is_empty() || !manifest.language_servers.is_empty());
133
134        if manifest.version.as_ref() != extension.version.as_str() {
135            anyhow::bail!(
136                "mismatched versions: ({}) != ({})",
137                manifest.version,
138                extension.version
139            )
140        }
141
142        for language_path in &manifest.languages {
143            let language_path = extension_dir.join(language_path);
144            let config = fs.load(&language_path.join("config.toml")).await?;
145            let mut config = ::toml::from_str::<LanguageConfig>(&config)?;
146
147            this.update(cx, |this, _cx| {
148                this.loaded_languages
149                    .entry(manifest.id.clone())
150                    .or_default()
151                    .push(config.name.clone());
152
153                config.grammar = None;
154
155                this.proxy.register_language(
156                    config.name.clone(),
157                    None,
158                    config.matcher.clone(),
159                    config.hidden,
160                    Arc::new(move || {
161                        Ok(LoadedLanguage {
162                            config: config.clone(),
163                            queries: LanguageQueries::default(),
164                            context_provider: None,
165                            toolchain_provider: None,
166                        })
167                    }),
168                );
169            })?;
170        }
171
172        if manifest.language_servers.is_empty() {
173            return Ok(());
174        }
175
176        let wasm_extension: Arc<dyn Extension> =
177            Arc::new(WasmExtension::load(extension_dir, &manifest, wasm_host.clone(), &cx).await?);
178
179        for (language_server_id, language_server_config) in &manifest.language_servers {
180            for language in language_server_config.languages() {
181                this.update(cx, |this, _cx| {
182                    this.loaded_language_servers
183                        .entry(manifest.id.clone())
184                        .or_default()
185                        .push((language_server_id.clone(), language.clone()));
186                    this.proxy.register_language_server(
187                        wasm_extension.clone(),
188                        language_server_id.clone(),
189                        language.clone(),
190                    );
191                })?;
192            }
193        }
194
195        Ok(())
196    }
197
198    fn uninstall_extension(
199        &mut self,
200        extension_id: &Arc<str>,
201        cx: &mut Context<Self>,
202    ) -> Task<Result<()>> {
203        self.loaded_extensions.remove(extension_id);
204
205        let languages_to_remove = self
206            .loaded_languages
207            .remove(extension_id)
208            .unwrap_or_default();
209        self.proxy.remove_languages(&languages_to_remove, &[]);
210
211        for (language_server_name, language) in self
212            .loaded_language_servers
213            .remove(extension_id)
214            .unwrap_or_default()
215        {
216            self.proxy
217                .remove_language_server(&language, &language_server_name);
218        }
219
220        let path = self.extension_dir.join(&extension_id.to_string());
221        let fs = self.fs.clone();
222        cx.spawn(|_, _| async move {
223            fs.remove_dir(
224                &path,
225                RemoveOptions {
226                    recursive: true,
227                    ignore_if_not_exists: true,
228                },
229            )
230            .await
231        })
232    }
233
234    pub fn install_extension(
235        &mut self,
236        extension: ExtensionVersion,
237        tmp_path: PathBuf,
238        cx: &mut Context<Self>,
239    ) -> Task<Result<()>> {
240        let path = self.extension_dir.join(&extension.id);
241        let fs = self.fs.clone();
242
243        cx.spawn(|this, mut cx| async move {
244            if fs.is_dir(&path).await {
245                this.update(&mut cx, |this, cx| {
246                    this.uninstall_extension(&extension.id.clone().into(), cx)
247                })?
248                .await?;
249            }
250
251            fs.rename(&tmp_path, &path, RenameOptions::default())
252                .await?;
253
254            Self::load_extension(this, extension, &mut cx).await
255        })
256    }
257
258    pub async fn handle_sync_extensions(
259        extension_store: Entity<HeadlessExtensionStore>,
260        envelope: TypedEnvelope<proto::SyncExtensions>,
261        mut cx: AsyncApp,
262    ) -> Result<proto::SyncExtensionsResponse> {
263        let requested_extensions =
264            envelope
265                .payload
266                .extensions
267                .into_iter()
268                .map(|p| ExtensionVersion {
269                    id: p.id,
270                    version: p.version,
271                    dev: p.dev,
272                });
273        let missing_extensions = extension_store
274            .update(&mut cx, |extension_store, cx| {
275                extension_store.sync_extensions(requested_extensions.collect(), cx)
276            })?
277            .await?;
278
279        Ok(proto::SyncExtensionsResponse {
280            missing_extensions: missing_extensions
281                .into_iter()
282                .map(|e| proto::Extension {
283                    id: e.id,
284                    version: e.version,
285                    dev: e.dev,
286                })
287                .collect(),
288            tmp_dir: paths::remote_extensions_uploads_dir()
289                .to_string_lossy()
290                .to_string(),
291        })
292    }
293
294    pub async fn handle_install_extension(
295        extensions: Entity<HeadlessExtensionStore>,
296        envelope: TypedEnvelope<proto::InstallExtension>,
297        mut cx: AsyncApp,
298    ) -> Result<proto::Ack> {
299        let extension = envelope
300            .payload
301            .extension
302            .with_context(|| anyhow!("Invalid InstallExtension request"))?;
303
304        extensions
305            .update(&mut cx, |extensions, cx| {
306                extensions.install_extension(
307                    ExtensionVersion {
308                        id: extension.id,
309                        version: extension.version,
310                        dev: extension.dev,
311                    },
312                    PathBuf::from(envelope.payload.tmp_dir),
313                    cx,
314                )
315            })?
316            .await?;
317
318        Ok(proto::Ack {})
319    }
320}