1use std::{path::PathBuf, sync::Arc};
2
3use anyhow::{anyhow, Context as _, Result};
4use client::{proto, TypedEnvelope};
5use collections::{HashMap, HashSet};
6use extension::{
7 Extension, ExtensionHostProxy, ExtensionLanguageProxy, ExtensionLanguageServerProxy,
8 ExtensionManifest,
9};
10use fs::{Fs, RemoveOptions, RenameOptions};
11use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
12use http_client::HttpClient;
13use language::{LanguageConfig, LanguageName, LanguageQueries, LoadedLanguage};
14use lsp::LanguageServerName;
15use node_runtime::NodeRuntime;
16
17use crate::wasm_host::{WasmExtension, WasmHost};
18
/// An extension as requested by the client: its id, the version the client
/// has, and whether it is a dev extension.
#[derive(Debug, Clone)]
pub struct ExtensionVersion {
    /// Extension identifier; also used as the name of the extension's
    /// subdirectory inside the store's `extension_dir`.
    pub id: String,
    /// Version string, compared verbatim against the loaded version and the
    /// version found in the extension's manifest.
    pub version: String,
    /// When `true`, `sync_extensions` always attempts to load the extension
    /// and always reports it back as missing, regardless of its version.
    pub dev: bool,
}
25
26pub struct HeadlessExtensionStore {
27 pub fs: Arc<dyn Fs>,
28 pub extension_dir: PathBuf,
29 pub proxy: Arc<ExtensionHostProxy>,
30 pub wasm_host: Arc<WasmHost>,
31 pub loaded_extensions: HashMap<Arc<str>, Arc<str>>,
32 pub loaded_languages: HashMap<Arc<str>, Vec<LanguageName>>,
33 pub loaded_language_servers: HashMap<Arc<str>, Vec<(LanguageServerName, LanguageName)>>,
34}
35
36impl HeadlessExtensionStore {
37 pub fn new(
38 fs: Arc<dyn Fs>,
39 http_client: Arc<dyn HttpClient>,
40 extension_dir: PathBuf,
41 extension_host_proxy: Arc<ExtensionHostProxy>,
42 node_runtime: NodeRuntime,
43 cx: &mut App,
44 ) -> Entity<Self> {
45 cx.new(|cx| Self {
46 fs: fs.clone(),
47 wasm_host: WasmHost::new(
48 fs.clone(),
49 http_client.clone(),
50 node_runtime,
51 extension_host_proxy.clone(),
52 extension_dir.join("work"),
53 cx,
54 ),
55 extension_dir,
56 proxy: extension_host_proxy,
57 loaded_extensions: Default::default(),
58 loaded_languages: Default::default(),
59 loaded_language_servers: Default::default(),
60 })
61 }
62
63 pub fn sync_extensions(
64 &mut self,
65 extensions: Vec<ExtensionVersion>,
66 cx: &Context<Self>,
67 ) -> Task<Result<Vec<ExtensionVersion>>> {
68 let on_client = HashSet::from_iter(extensions.iter().map(|e| e.id.as_str()));
69 let to_remove: Vec<Arc<str>> = self
70 .loaded_extensions
71 .keys()
72 .filter(|id| !on_client.contains(id.as_ref()))
73 .cloned()
74 .collect();
75 let to_load: Vec<ExtensionVersion> = extensions
76 .into_iter()
77 .filter(|e| {
78 if e.dev {
79 return true;
80 }
81 !self
82 .loaded_extensions
83 .get(e.id.as_str())
84 .is_some_and(|loaded| loaded.as_ref() == e.version.as_str())
85 })
86 .collect();
87
88 cx.spawn(|this, mut cx| async move {
89 let mut missing = Vec::new();
90
91 for extension_id in to_remove {
92 log::info!("removing extension: {}", extension_id);
93 this.update(&mut cx, |this, cx| {
94 this.uninstall_extension(&extension_id, cx)
95 })?
96 .await?;
97 }
98
99 for extension in to_load {
100 if let Err(e) = Self::load_extension(this.clone(), extension.clone(), &mut cx).await
101 {
102 log::info!("failed to load extension: {}, {:?}", extension.id, e);
103 missing.push(extension)
104 } else if extension.dev {
105 missing.push(extension)
106 }
107 }
108
109 Ok(missing)
110 })
111 }
112
113 pub async fn load_extension(
114 this: WeakEntity<Self>,
115 extension: ExtensionVersion,
116 cx: &mut AsyncApp,
117 ) -> Result<()> {
118 let (fs, wasm_host, extension_dir) = this.update(cx, |this, _cx| {
119 this.loaded_extensions.insert(
120 extension.id.clone().into(),
121 extension.version.clone().into(),
122 );
123 (
124 this.fs.clone(),
125 this.wasm_host.clone(),
126 this.extension_dir.join(&extension.id),
127 )
128 })?;
129
130 let manifest = Arc::new(ExtensionManifest::load(fs.clone(), &extension_dir).await?);
131
132 debug_assert!(!manifest.languages.is_empty() || !manifest.language_servers.is_empty());
133
134 if manifest.version.as_ref() != extension.version.as_str() {
135 anyhow::bail!(
136 "mismatched versions: ({}) != ({})",
137 manifest.version,
138 extension.version
139 )
140 }
141
142 for language_path in &manifest.languages {
143 let language_path = extension_dir.join(language_path);
144 let config = fs.load(&language_path.join("config.toml")).await?;
145 let mut config = ::toml::from_str::<LanguageConfig>(&config)?;
146
147 this.update(cx, |this, _cx| {
148 this.loaded_languages
149 .entry(manifest.id.clone())
150 .or_default()
151 .push(config.name.clone());
152
153 config.grammar = None;
154
155 this.proxy.register_language(
156 config.name.clone(),
157 None,
158 config.matcher.clone(),
159 config.hidden,
160 Arc::new(move || {
161 Ok(LoadedLanguage {
162 config: config.clone(),
163 queries: LanguageQueries::default(),
164 context_provider: None,
165 toolchain_provider: None,
166 })
167 }),
168 );
169 })?;
170 }
171
172 if manifest.language_servers.is_empty() {
173 return Ok(());
174 }
175
176 let wasm_extension: Arc<dyn Extension> =
177 Arc::new(WasmExtension::load(extension_dir, &manifest, wasm_host.clone(), &cx).await?);
178
179 for (language_server_id, language_server_config) in &manifest.language_servers {
180 for language in language_server_config.languages() {
181 this.update(cx, |this, _cx| {
182 this.loaded_language_servers
183 .entry(manifest.id.clone())
184 .or_default()
185 .push((language_server_id.clone(), language.clone()));
186 this.proxy.register_language_server(
187 wasm_extension.clone(),
188 language_server_id.clone(),
189 language.clone(),
190 );
191 })?;
192 }
193 }
194
195 Ok(())
196 }
197
198 fn uninstall_extension(
199 &mut self,
200 extension_id: &Arc<str>,
201 cx: &mut Context<Self>,
202 ) -> Task<Result<()>> {
203 self.loaded_extensions.remove(extension_id);
204
205 let languages_to_remove = self
206 .loaded_languages
207 .remove(extension_id)
208 .unwrap_or_default();
209 self.proxy.remove_languages(&languages_to_remove, &[]);
210
211 for (language_server_name, language) in self
212 .loaded_language_servers
213 .remove(extension_id)
214 .unwrap_or_default()
215 {
216 self.proxy
217 .remove_language_server(&language, &language_server_name);
218 }
219
220 let path = self.extension_dir.join(&extension_id.to_string());
221 let fs = self.fs.clone();
222 cx.spawn(|_, _| async move {
223 fs.remove_dir(
224 &path,
225 RemoveOptions {
226 recursive: true,
227 ignore_if_not_exists: true,
228 },
229 )
230 .await
231 })
232 }
233
234 pub fn install_extension(
235 &mut self,
236 extension: ExtensionVersion,
237 tmp_path: PathBuf,
238 cx: &mut Context<Self>,
239 ) -> Task<Result<()>> {
240 let path = self.extension_dir.join(&extension.id);
241 let fs = self.fs.clone();
242
243 cx.spawn(|this, mut cx| async move {
244 if fs.is_dir(&path).await {
245 this.update(&mut cx, |this, cx| {
246 this.uninstall_extension(&extension.id.clone().into(), cx)
247 })?
248 .await?;
249 }
250
251 fs.rename(&tmp_path, &path, RenameOptions::default())
252 .await?;
253
254 Self::load_extension(this, extension, &mut cx).await
255 })
256 }
257
258 pub async fn handle_sync_extensions(
259 extension_store: Entity<HeadlessExtensionStore>,
260 envelope: TypedEnvelope<proto::SyncExtensions>,
261 mut cx: AsyncApp,
262 ) -> Result<proto::SyncExtensionsResponse> {
263 let requested_extensions =
264 envelope
265 .payload
266 .extensions
267 .into_iter()
268 .map(|p| ExtensionVersion {
269 id: p.id,
270 version: p.version,
271 dev: p.dev,
272 });
273 let missing_extensions = extension_store
274 .update(&mut cx, |extension_store, cx| {
275 extension_store.sync_extensions(requested_extensions.collect(), cx)
276 })?
277 .await?;
278
279 Ok(proto::SyncExtensionsResponse {
280 missing_extensions: missing_extensions
281 .into_iter()
282 .map(|e| proto::Extension {
283 id: e.id,
284 version: e.version,
285 dev: e.dev,
286 })
287 .collect(),
288 tmp_dir: paths::remote_extensions_uploads_dir()
289 .to_string_lossy()
290 .to_string(),
291 })
292 }
293
294 pub async fn handle_install_extension(
295 extensions: Entity<HeadlessExtensionStore>,
296 envelope: TypedEnvelope<proto::InstallExtension>,
297 mut cx: AsyncApp,
298 ) -> Result<proto::Ack> {
299 let extension = envelope
300 .payload
301 .extension
302 .with_context(|| anyhow!("Invalid InstallExtension request"))?;
303
304 extensions
305 .update(&mut cx, |extensions, cx| {
306 extensions.install_extension(
307 ExtensionVersion {
308 id: extension.id,
309 version: extension.version,
310 dev: extension.dev,
311 },
312 PathBuf::from(envelope.payload.tmp_dir),
313 cx,
314 )
315 })?
316 .await?;
317
318 Ok(proto::Ack {})
319 }
320}