headless_host.rs

  1use std::{path::PathBuf, sync::Arc};
  2
  3use anyhow::{anyhow, Context as _, Result};
  4use client::{proto, TypedEnvelope};
  5use collections::{HashMap, HashSet};
  6use extension::{Extension, ExtensionManifest};
  7use fs::{Fs, RemoveOptions, RenameOptions};
  8use gpui::{AppContext, AsyncAppContext, Context, Model, ModelContext, Task, WeakModel};
  9use http_client::HttpClient;
 10use language::{LanguageConfig, LanguageName, LanguageQueries, LanguageRegistry, LoadedLanguage};
 11use lsp::LanguageServerName;
 12use node_runtime::NodeRuntime;
 13
 14use crate::{
 15    extension_lsp_adapter::ExtensionLspAdapter,
 16    wasm_host::{WasmExtension, WasmHost},
 17    ExtensionRegistrationHooks,
 18};
 19
 20pub struct HeadlessExtensionStore {
 21    pub registration_hooks: Arc<dyn ExtensionRegistrationHooks>,
 22    pub fs: Arc<dyn Fs>,
 23    pub extension_dir: PathBuf,
 24    pub wasm_host: Arc<WasmHost>,
 25    pub loaded_extensions: HashMap<Arc<str>, Arc<str>>,
 26    pub loaded_languages: HashMap<Arc<str>, Vec<LanguageName>>,
 27    pub loaded_language_servers: HashMap<Arc<str>, Vec<(LanguageServerName, LanguageName)>>,
 28}
 29
 30#[derive(Clone, Debug)]
 31pub struct ExtensionVersion {
 32    pub id: String,
 33    pub version: String,
 34    pub dev: bool,
 35}
 36
 37impl HeadlessExtensionStore {
 38    pub fn new(
 39        fs: Arc<dyn Fs>,
 40        http_client: Arc<dyn HttpClient>,
 41        languages: Arc<LanguageRegistry>,
 42        extension_dir: PathBuf,
 43        node_runtime: NodeRuntime,
 44        cx: &mut AppContext,
 45    ) -> Model<Self> {
 46        let registration_hooks = Arc::new(HeadlessRegistrationHooks::new(languages.clone()));
 47        cx.new_model(|cx| Self {
 48            registration_hooks: registration_hooks.clone(),
 49            fs: fs.clone(),
 50            wasm_host: WasmHost::new(
 51                fs.clone(),
 52                http_client.clone(),
 53                node_runtime,
 54                registration_hooks,
 55                extension_dir.join("work"),
 56                cx,
 57            ),
 58            extension_dir,
 59            loaded_extensions: Default::default(),
 60            loaded_languages: Default::default(),
 61            loaded_language_servers: Default::default(),
 62        })
 63    }
 64
 65    pub fn sync_extensions(
 66        &mut self,
 67        extensions: Vec<ExtensionVersion>,
 68        cx: &ModelContext<Self>,
 69    ) -> Task<Result<Vec<ExtensionVersion>>> {
 70        let on_client = HashSet::from_iter(extensions.iter().map(|e| e.id.as_str()));
 71        let to_remove: Vec<Arc<str>> = self
 72            .loaded_extensions
 73            .keys()
 74            .filter(|id| !on_client.contains(id.as_ref()))
 75            .cloned()
 76            .collect();
 77        let to_load: Vec<ExtensionVersion> = extensions
 78            .into_iter()
 79            .filter(|e| {
 80                if e.dev {
 81                    return true;
 82                }
 83                !self
 84                    .loaded_extensions
 85                    .get(e.id.as_str())
 86                    .is_some_and(|loaded| loaded.as_ref() == e.version.as_str())
 87            })
 88            .collect();
 89
 90        cx.spawn(|this, mut cx| async move {
 91            let mut missing = Vec::new();
 92
 93            for extension_id in to_remove {
 94                log::info!("removing extension: {}", extension_id);
 95                this.update(&mut cx, |this, cx| {
 96                    this.uninstall_extension(&extension_id, cx)
 97                })?
 98                .await?;
 99            }
100
101            for extension in to_load {
102                if let Err(e) = Self::load_extension(this.clone(), extension.clone(), &mut cx).await
103                {
104                    log::info!("failed to load extension: {}, {:?}", extension.id, e);
105                    missing.push(extension)
106                } else if extension.dev {
107                    missing.push(extension)
108                }
109            }
110
111            Ok(missing)
112        })
113    }
114
115    pub async fn load_extension(
116        this: WeakModel<Self>,
117        extension: ExtensionVersion,
118        cx: &mut AsyncAppContext,
119    ) -> Result<()> {
120        let (fs, wasm_host, extension_dir) = this.update(cx, |this, _cx| {
121            this.loaded_extensions.insert(
122                extension.id.clone().into(),
123                extension.version.clone().into(),
124            );
125            (
126                this.fs.clone(),
127                this.wasm_host.clone(),
128                this.extension_dir.join(&extension.id),
129            )
130        })?;
131
132        let manifest = Arc::new(ExtensionManifest::load(fs.clone(), &extension_dir).await?);
133
134        debug_assert!(!manifest.languages.is_empty() || !manifest.language_servers.is_empty());
135
136        if manifest.version.as_ref() != extension.version.as_str() {
137            anyhow::bail!(
138                "mismatched versions: ({}) != ({})",
139                manifest.version,
140                extension.version
141            )
142        }
143
144        for language_path in &manifest.languages {
145            let language_path = extension_dir.join(language_path);
146            let config = fs.load(&language_path.join("config.toml")).await?;
147            let mut config = ::toml::from_str::<LanguageConfig>(&config)?;
148
149            this.update(cx, |this, _cx| {
150                this.loaded_languages
151                    .entry(manifest.id.clone())
152                    .or_default()
153                    .push(config.name.clone());
154
155                config.grammar = None;
156
157                this.registration_hooks.register_language(
158                    config.name.clone(),
159                    None,
160                    config.matcher.clone(),
161                    Arc::new(move || {
162                        Ok(LoadedLanguage {
163                            config: config.clone(),
164                            queries: LanguageQueries::default(),
165                            context_provider: None,
166                            toolchain_provider: None,
167                        })
168                    }),
169                );
170            })?;
171        }
172
173        if manifest.language_servers.is_empty() {
174            return Ok(());
175        }
176
177        let wasm_extension: Arc<dyn Extension> =
178            Arc::new(WasmExtension::load(extension_dir, &manifest, wasm_host.clone(), &cx).await?);
179
180        for (language_server_name, language_server_config) in &manifest.language_servers {
181            for language in language_server_config.languages() {
182                this.update(cx, |this, _cx| {
183                    this.loaded_language_servers
184                        .entry(manifest.id.clone())
185                        .or_default()
186                        .push((language_server_name.clone(), language.clone()));
187                    this.registration_hooks.register_lsp_adapter(
188                        language.clone(),
189                        ExtensionLspAdapter {
190                            extension: wasm_extension.clone(),
191                            language_server_id: language_server_name.clone(),
192                            language_name: language,
193                        },
194                    );
195                })?;
196            }
197        }
198
199        Ok(())
200    }
201
202    fn uninstall_extension(
203        &mut self,
204        extension_id: &Arc<str>,
205        cx: &mut ModelContext<Self>,
206    ) -> Task<Result<()>> {
207        self.loaded_extensions.remove(extension_id);
208        let languages_to_remove = self
209            .loaded_languages
210            .remove(extension_id)
211            .unwrap_or_default();
212        self.registration_hooks
213            .remove_languages(&languages_to_remove, &[]);
214        for (language_server_name, language) in self
215            .loaded_language_servers
216            .remove(extension_id)
217            .unwrap_or_default()
218        {
219            self.registration_hooks
220                .remove_lsp_adapter(&language, &language_server_name);
221        }
222
223        let path = self.extension_dir.join(&extension_id.to_string());
224        let fs = self.fs.clone();
225        cx.spawn(|_, _| async move {
226            fs.remove_dir(
227                &path,
228                RemoveOptions {
229                    recursive: true,
230                    ignore_if_not_exists: true,
231                },
232            )
233            .await
234        })
235    }
236
237    pub fn install_extension(
238        &mut self,
239        extension: ExtensionVersion,
240        tmp_path: PathBuf,
241        cx: &mut ModelContext<Self>,
242    ) -> Task<Result<()>> {
243        let path = self.extension_dir.join(&extension.id);
244        let fs = self.fs.clone();
245
246        cx.spawn(|this, mut cx| async move {
247            if fs.is_dir(&path).await {
248                this.update(&mut cx, |this, cx| {
249                    this.uninstall_extension(&extension.id.clone().into(), cx)
250                })?
251                .await?;
252            }
253
254            fs.rename(&tmp_path, &path, RenameOptions::default())
255                .await?;
256
257            Self::load_extension(this, extension, &mut cx).await
258        })
259    }
260
261    pub async fn handle_sync_extensions(
262        extension_store: Model<HeadlessExtensionStore>,
263        envelope: TypedEnvelope<proto::SyncExtensions>,
264        mut cx: AsyncAppContext,
265    ) -> Result<proto::SyncExtensionsResponse> {
266        let requested_extensions =
267            envelope
268                .payload
269                .extensions
270                .into_iter()
271                .map(|p| ExtensionVersion {
272                    id: p.id,
273                    version: p.version,
274                    dev: p.dev,
275                });
276        let missing_extensions = extension_store
277            .update(&mut cx, |extension_store, cx| {
278                extension_store.sync_extensions(requested_extensions.collect(), cx)
279            })?
280            .await?;
281
282        Ok(proto::SyncExtensionsResponse {
283            missing_extensions: missing_extensions
284                .into_iter()
285                .map(|e| proto::Extension {
286                    id: e.id,
287                    version: e.version,
288                    dev: e.dev,
289                })
290                .collect(),
291            tmp_dir: paths::remote_extensions_uploads_dir()
292                .to_string_lossy()
293                .to_string(),
294        })
295    }
296
297    pub async fn handle_install_extension(
298        extensions: Model<HeadlessExtensionStore>,
299        envelope: TypedEnvelope<proto::InstallExtension>,
300        mut cx: AsyncAppContext,
301    ) -> Result<proto::Ack> {
302        let extension = envelope
303            .payload
304            .extension
305            .with_context(|| anyhow!("Invalid InstallExtension request"))?;
306
307        extensions
308            .update(&mut cx, |extensions, cx| {
309                extensions.install_extension(
310                    ExtensionVersion {
311                        id: extension.id,
312                        version: extension.version,
313                        dev: extension.dev,
314                    },
315                    PathBuf::from(envelope.payload.tmp_dir),
316                    cx,
317                )
318            })?
319            .await?;
320
321        Ok(proto::Ack {})
322    }
323}
324
325struct HeadlessRegistrationHooks {
326    language_registry: Arc<LanguageRegistry>,
327}
328
329impl HeadlessRegistrationHooks {
330    fn new(language_registry: Arc<LanguageRegistry>) -> Self {
331        Self { language_registry }
332    }
333}
334
335impl ExtensionRegistrationHooks for HeadlessRegistrationHooks {
336    fn register_language(
337        &self,
338        language: LanguageName,
339        _grammar: Option<Arc<str>>,
340        matcher: language::LanguageMatcher,
341        load: Arc<dyn Fn() -> Result<LoadedLanguage> + 'static + Send + Sync>,
342    ) {
343        log::info!("registering language: {:?}", language);
344        self.language_registry
345            .register_language(language, None, matcher, load)
346    }
347    fn register_lsp_adapter(&self, language: LanguageName, adapter: ExtensionLspAdapter) {
348        log::info!("registering lsp adapter {:?}", language);
349        self.language_registry
350            .register_lsp_adapter(language, Arc::new(adapter) as _);
351    }
352
353    fn register_wasm_grammars(&self, grammars: Vec<(Arc<str>, PathBuf)>) {
354        self.language_registry.register_wasm_grammars(grammars)
355    }
356
357    fn remove_lsp_adapter(&self, language: &LanguageName, server_name: &LanguageServerName) {
358        self.language_registry
359            .remove_lsp_adapter(language, server_name)
360    }
361
362    fn remove_languages(
363        &self,
364        languages_to_remove: &[LanguageName],
365        _grammars_to_remove: &[Arc<str>],
366    ) {
367        self.language_registry
368            .remove_languages(languages_to_remove, &[])
369    }
370
371    fn update_lsp_status(
372        &self,
373        server_name: LanguageServerName,
374        status: language::LanguageServerBinaryStatus,
375    ) {
376        self.language_registry
377            .update_lsp_status(server_name, status)
378    }
379}