1use crate::db::ExtensionVersionConstraints;
2use crate::{AppState, Error, Result, db::NewExtensionVersion};
3use anyhow::Context as _;
4use aws_sdk_s3::presigning::PresigningConfig;
5use axum::{
6 Extension, Json, Router,
7 extract::{Path, Query},
8 http::StatusCode,
9 response::Redirect,
10 routing::get,
11};
12use collections::{BTreeSet, HashMap};
13use rpc::{ExtensionApiManifest, ExtensionProvides, GetExtensionsResponse};
14use semantic_version::SemanticVersion;
15use serde::Deserialize;
16use std::str::FromStr;
17use std::{sync::Arc, time::Duration};
18use time::PrimitiveDateTime;
19use util::{ResultExt, maybe};
20
21pub fn router() -> Router {
22 Router::new()
23 .route("/extensions", get(get_extensions))
24 .route("/extensions/updates", get(get_extension_updates))
25 .route("/extensions/:extension_id", get(get_extension_versions))
26 .route(
27 "/extensions/:extension_id/download",
28 get(download_latest_extension),
29 )
30 .route(
31 "/extensions/:extension_id/:version/download",
32 get(download_extension),
33 )
34}
35
36#[derive(Debug, Deserialize)]
37struct GetExtensionsParams {
38 filter: Option<String>,
39 /// A comma-delimited list of features that the extension must provide.
40 ///
41 /// For example:
42 /// - `themes`
43 /// - `themes,icon-themes`
44 /// - `languages,language-servers`
45 #[serde(default)]
46 provides: Option<String>,
47 #[serde(default)]
48 max_schema_version: i32,
49}
50
51async fn get_extensions(
52 Extension(app): Extension<Arc<AppState>>,
53 Query(params): Query<GetExtensionsParams>,
54) -> Result<Json<GetExtensionsResponse>> {
55 let provides_filter = params.provides.map(|provides| {
56 provides
57 .split(',')
58 .map(|value| value.trim())
59 .filter_map(|value| ExtensionProvides::from_str(value).ok())
60 .collect::<BTreeSet<_>>()
61 });
62
63 let mut extensions = app
64 .db
65 .get_extensions(
66 params.filter.as_deref(),
67 provides_filter.as_ref(),
68 params.max_schema_version,
69 1_000,
70 )
71 .await?;
72
73 if let Some(filter) = params.filter.as_deref() {
74 let extension_id = filter.to_lowercase();
75 let mut exact_match = None;
76 extensions.retain(|extension| {
77 if extension.id.as_ref() == extension_id {
78 exact_match = Some(extension.clone());
79 false
80 } else {
81 true
82 }
83 });
84 if exact_match.is_none() {
85 exact_match = app
86 .db
87 .get_extensions_by_ids(&[&extension_id], None)
88 .await?
89 .first()
90 .cloned();
91 }
92
93 if let Some(exact_match) = exact_match {
94 extensions.insert(0, exact_match);
95 }
96 };
97
98 if let Some(query) = params.filter.as_deref() {
99 let count = extensions.len();
100 tracing::info!(query, count, "extension_search")
101 }
102
103 Ok(Json(GetExtensionsResponse { data: extensions }))
104}
105
106#[derive(Debug, Deserialize)]
107struct GetExtensionUpdatesParams {
108 ids: String,
109 min_schema_version: i32,
110 max_schema_version: i32,
111 min_wasm_api_version: SemanticVersion,
112 max_wasm_api_version: SemanticVersion,
113}
114
115async fn get_extension_updates(
116 Extension(app): Extension<Arc<AppState>>,
117 Query(params): Query<GetExtensionUpdatesParams>,
118) -> Result<Json<GetExtensionsResponse>> {
119 let constraints = ExtensionVersionConstraints {
120 schema_versions: params.min_schema_version..=params.max_schema_version,
121 wasm_api_versions: params.min_wasm_api_version..=params.max_wasm_api_version,
122 };
123
124 let extension_ids = params.ids.split(',').map(|s| s.trim()).collect::<Vec<_>>();
125
126 let extensions = app
127 .db
128 .get_extensions_by_ids(&extension_ids, Some(&constraints))
129 .await?;
130
131 Ok(Json(GetExtensionsResponse { data: extensions }))
132}
133
134#[derive(Debug, Deserialize)]
135struct GetExtensionVersionsParams {
136 extension_id: String,
137}
138
139async fn get_extension_versions(
140 Extension(app): Extension<Arc<AppState>>,
141 Path(params): Path<GetExtensionVersionsParams>,
142) -> Result<Json<GetExtensionsResponse>> {
143 let extension_versions = app.db.get_extension_versions(¶ms.extension_id).await?;
144
145 Ok(Json(GetExtensionsResponse {
146 data: extension_versions,
147 }))
148}
149
150#[derive(Debug, Deserialize)]
151struct DownloadLatestExtensionPathParams {
152 extension_id: String,
153}
154
155#[derive(Debug, Deserialize)]
156struct DownloadLatestExtensionQueryParams {
157 min_schema_version: Option<i32>,
158 max_schema_version: Option<i32>,
159 min_wasm_api_version: Option<SemanticVersion>,
160 max_wasm_api_version: Option<SemanticVersion>,
161}
162
163async fn download_latest_extension(
164 Extension(app): Extension<Arc<AppState>>,
165 Path(params): Path<DownloadLatestExtensionPathParams>,
166 Query(query): Query<DownloadLatestExtensionQueryParams>,
167) -> Result<Redirect> {
168 let constraints = maybe!({
169 let min_schema_version = query.min_schema_version?;
170 let max_schema_version = query.max_schema_version?;
171 let min_wasm_api_version = query.min_wasm_api_version?;
172 let max_wasm_api_version = query.max_wasm_api_version?;
173
174 Some(ExtensionVersionConstraints {
175 schema_versions: min_schema_version..=max_schema_version,
176 wasm_api_versions: min_wasm_api_version..=max_wasm_api_version,
177 })
178 });
179
180 let extension = app
181 .db
182 .get_extension(¶ms.extension_id, constraints.as_ref())
183 .await?
184 .context("unknown extension")?;
185 download_extension(
186 Extension(app),
187 Path(DownloadExtensionParams {
188 extension_id: params.extension_id,
189 version: extension.manifest.version.to_string(),
190 }),
191 )
192 .await
193}
194
195#[derive(Debug, Deserialize)]
196struct DownloadExtensionParams {
197 extension_id: String,
198 version: String,
199}
200
201async fn download_extension(
202 Extension(app): Extension<Arc<AppState>>,
203 Path(params): Path<DownloadExtensionParams>,
204) -> Result<Redirect> {
205 let Some((blob_store_client, bucket)) = app
206 .blob_store_client
207 .clone()
208 .zip(app.config.blob_store_bucket.clone())
209 else {
210 Err(Error::http(
211 StatusCode::NOT_IMPLEMENTED,
212 "not supported".into(),
213 ))?
214 };
215
216 let DownloadExtensionParams {
217 extension_id,
218 version,
219 } = params;
220
221 let version_exists = app
222 .db
223 .record_extension_download(&extension_id, &version)
224 .await?;
225
226 if !version_exists {
227 Err(Error::http(
228 StatusCode::NOT_FOUND,
229 "unknown extension version".into(),
230 ))?;
231 }
232
233 let url = blob_store_client
234 .get_object()
235 .bucket(bucket)
236 .key(format!(
237 "extensions/{extension_id}/{version}/archive.tar.gz"
238 ))
239 .presigned(PresigningConfig::expires_in(EXTENSION_DOWNLOAD_URL_LIFETIME).unwrap())
240 .await
241 .context("creating presigned extension download url")?;
242
243 Ok(Redirect::temporary(url.uri()))
244}
245
/// Pause between two consecutive blob-store synchronization passes (5 minutes).
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(5 * 60);
/// Validity period of a presigned extension download URL (3 minutes).
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(3 * 60);
248
249pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
250 let Some(blob_store_client) = app_state.blob_store_client.clone() else {
251 log::info!("no blob store client");
252 return;
253 };
254 let Some(blob_store_bucket) = app_state.config.blob_store_bucket.clone() else {
255 log::info!("no blob store bucket");
256 return;
257 };
258
259 let executor = app_state.executor.clone();
260 executor.spawn_detached({
261 let executor = executor.clone();
262 async move {
263 loop {
264 fetch_extensions_from_blob_store(
265 &blob_store_client,
266 &blob_store_bucket,
267 &app_state,
268 )
269 .await
270 .log_err();
271 executor.sleep(EXTENSION_FETCH_INTERVAL).await;
272 }
273 }
274 });
275}
276
277async fn fetch_extensions_from_blob_store(
278 blob_store_client: &aws_sdk_s3::Client,
279 blob_store_bucket: &String,
280 app_state: &Arc<AppState>,
281) -> anyhow::Result<()> {
282 log::info!("fetching extensions from blob store");
283
284 let mut next_marker = None;
285 let mut published_versions = HashMap::<String, Vec<String>>::default();
286
287 loop {
288 let list = blob_store_client
289 .list_objects()
290 .bucket(blob_store_bucket)
291 .prefix("extensions/")
292 .set_marker(next_marker.clone())
293 .send()
294 .await?;
295 let objects = list.contents.unwrap_or_default();
296 log::info!("fetched {} object(s) from blob store", objects.len());
297
298 for object in &objects {
299 let Some(key) = object.key.as_ref() else {
300 continue;
301 };
302 let mut parts = key.split('/');
303 let Some(_) = parts.next().filter(|part| *part == "extensions") else {
304 continue;
305 };
306 let Some(extension_id) = parts.next() else {
307 continue;
308 };
309 let Some(version) = parts.next() else {
310 continue;
311 };
312 if parts.next() == Some("manifest.json") {
313 published_versions
314 .entry(extension_id.to_owned())
315 .or_default()
316 .push(version.to_owned());
317 }
318 }
319
320 if let (Some(true), Some(last_object)) = (list.is_truncated, objects.last()) {
321 next_marker.clone_from(&last_object.key);
322 } else {
323 break;
324 }
325 }
326
327 log::info!("found {} published extensions", published_versions.len());
328
329 let known_versions = app_state.db.get_known_extension_versions().await?;
330
331 let mut new_versions = HashMap::<&str, Vec<NewExtensionVersion>>::default();
332 let empty = Vec::new();
333 for (extension_id, published_versions) in &published_versions {
334 let known_versions = known_versions.get(extension_id).unwrap_or(&empty);
335
336 for published_version in published_versions {
337 if known_versions
338 .binary_search_by_key(&published_version, |known_version| known_version)
339 .is_err()
340 && let Some(extension) = fetch_extension_manifest(
341 blob_store_client,
342 blob_store_bucket,
343 extension_id,
344 published_version,
345 )
346 .await
347 .log_err()
348 {
349 new_versions
350 .entry(extension_id)
351 .or_default()
352 .push(extension);
353 }
354 }
355 }
356
357 app_state
358 .db
359 .insert_extension_versions(&new_versions)
360 .await?;
361
362 log::info!(
363 "fetched {} new extensions from blob store",
364 new_versions.values().map(|v| v.len()).sum::<usize>()
365 );
366
367 Ok(())
368}
369
370async fn fetch_extension_manifest(
371 blob_store_client: &aws_sdk_s3::Client,
372 blob_store_bucket: &String,
373 extension_id: &str,
374 version: &str,
375) -> anyhow::Result<NewExtensionVersion> {
376 let object = blob_store_client
377 .get_object()
378 .bucket(blob_store_bucket)
379 .key(format!("extensions/{extension_id}/{version}/manifest.json"))
380 .send()
381 .await?;
382 let manifest_bytes = object
383 .body
384 .collect()
385 .await
386 .map(|data| data.into_bytes())
387 .with_context(|| {
388 format!("failed to download manifest for extension {extension_id} version {version}")
389 })?
390 .to_vec();
391 let manifest =
392 serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
393 format!(
394 "invalid manifest for extension {extension_id} version {version}: {}",
395 String::from_utf8_lossy(&manifest_bytes)
396 )
397 })?;
398 let published_at = object.last_modified.with_context(|| {
399 format!("missing last modified timestamp for extension {extension_id} version {version}")
400 })?;
401 let published_at = time::OffsetDateTime::from_unix_timestamp_nanos(published_at.as_nanos())?;
402 let published_at = PrimitiveDateTime::new(published_at.date(), published_at.time());
403 let version = semver::Version::parse(&manifest.version).with_context(|| {
404 format!("invalid version for extension {extension_id} version {version}")
405 })?;
406 Ok(NewExtensionVersion {
407 name: manifest.name,
408 version,
409 description: manifest.description.unwrap_or_default(),
410 authors: manifest.authors,
411 repository: manifest.repository,
412 schema_version: manifest.schema_version.unwrap_or(0),
413 wasm_api_version: manifest.wasm_api_version,
414 provides: manifest.provides,
415 published_at,
416 })
417}