1use crate::db::ExtensionVersionConstraints;
2use crate::{db::NewExtensionVersion, AppState, Error, Result};
3use anyhow::{anyhow, Context as _};
4use aws_sdk_s3::presigning::PresigningConfig;
5use axum::{
6 extract::{Path, Query},
7 http::StatusCode,
8 response::Redirect,
9 routing::get,
10 Extension, Json, Router,
11};
12use collections::HashMap;
13use rpc::{ExtensionApiManifest, GetExtensionsResponse};
14use semantic_version::SemanticVersion;
15use serde::Deserialize;
16use std::{sync::Arc, time::Duration};
17use time::PrimitiveDateTime;
18use util::{maybe, ResultExt};
19
20pub fn router() -> Router {
21 Router::new()
22 .route("/extensions", get(get_extensions))
23 .route("/extensions/updates", get(get_extension_updates))
24 .route("/extensions/:extension_id", get(get_extension_versions))
25 .route(
26 "/extensions/:extension_id/download",
27 get(download_latest_extension),
28 )
29 .route(
30 "/extensions/:extension_id/:version/download",
31 get(download_extension),
32 )
33}
34
/// Query parameters accepted by `GET /extensions` (see `get_extensions`).
#[derive(Debug, Deserialize)]
struct GetExtensionsParams {
    /// Optional search string passed to the database listing query; also logged as an
    /// `extension_search` event. Not consulted when `ids` is present.
    filter: Option<String>,
    /// Optional comma-separated list of extension ids. When present, exactly those
    /// extensions are looked up instead of running a filtered listing.
    #[serde(default)]
    ids: Option<String>,
    /// Upper schema-version bound passed to the listing query; `0` when omitted
    /// (via `#[serde(default)]`). Not consulted when `ids` is present.
    #[serde(default)]
    max_schema_version: i32,
}
43
44async fn get_extensions(
45 Extension(app): Extension<Arc<AppState>>,
46 Query(params): Query<GetExtensionsParams>,
47) -> Result<Json<GetExtensionsResponse>> {
48 let extension_ids = params
49 .ids
50 .as_ref()
51 .map(|s| s.split(',').map(|s| s.trim()).collect::<Vec<_>>());
52
53 let extensions = if let Some(extension_ids) = extension_ids {
54 app.db.get_extensions_by_ids(&extension_ids, None).await?
55 } else {
56 let result = app
57 .db
58 .get_extensions(params.filter.as_deref(), params.max_schema_version, 500)
59 .await?;
60
61 if let Some(query) = params.filter.as_deref() {
62 let count = result.len();
63 tracing::info!(query, count, "extension_search")
64 }
65
66 result
67 };
68
69 Ok(Json(GetExtensionsResponse { data: extensions }))
70}
71
/// Query parameters accepted by `GET /extensions/updates` (see `get_extension_updates`).
///
/// No field is optional or defaulted, so a request missing any of them fails to deserialize.
#[derive(Debug, Deserialize)]
struct GetExtensionUpdatesParams {
    /// Comma-separated list of extension ids; each entry is trimmed before use.
    ids: String,
    /// Inclusive lower bound of the acceptable extension schema version.
    min_schema_version: i32,
    /// Inclusive upper bound of the acceptable extension schema version.
    max_schema_version: i32,
    /// Inclusive lower bound of the acceptable wasm API version.
    min_wasm_api_version: SemanticVersion,
    /// Inclusive upper bound of the acceptable wasm API version.
    max_wasm_api_version: SemanticVersion,
}
80
81async fn get_extension_updates(
82 Extension(app): Extension<Arc<AppState>>,
83 Query(params): Query<GetExtensionUpdatesParams>,
84) -> Result<Json<GetExtensionsResponse>> {
85 let constraints = ExtensionVersionConstraints {
86 schema_versions: params.min_schema_version..=params.max_schema_version,
87 wasm_api_versions: params.min_wasm_api_version..=params.max_wasm_api_version,
88 };
89
90 let extension_ids = params.ids.split(',').map(|s| s.trim()).collect::<Vec<_>>();
91
92 let extensions = app
93 .db
94 .get_extensions_by_ids(&extension_ids, Some(&constraints))
95 .await?;
96
97 Ok(Json(GetExtensionsResponse { data: extensions }))
98}
99
/// Path parameters for `GET /extensions/:extension_id` (see `get_extension_versions`).
#[derive(Debug, Deserialize)]
struct GetExtensionVersionsParams {
    /// The `:extension_id` path segment.
    extension_id: String,
}
104
105async fn get_extension_versions(
106 Extension(app): Extension<Arc<AppState>>,
107 Path(params): Path<GetExtensionVersionsParams>,
108) -> Result<Json<GetExtensionsResponse>> {
109 let extension_versions = app.db.get_extension_versions(¶ms.extension_id).await?;
110
111 Ok(Json(GetExtensionsResponse {
112 data: extension_versions,
113 }))
114}
115
/// Path parameters for `GET /extensions/:extension_id/download`
/// (see `download_latest_extension`).
#[derive(Debug, Deserialize)]
struct DownloadLatestExtensionPathParams {
    /// The `:extension_id` path segment.
    extension_id: String,
}
120
/// Optional compatibility bounds for `GET /extensions/:extension_id/download`.
///
/// `download_latest_extension` turns these into `ExtensionVersionConstraints` only when
/// all four are present; if any one is missing, no constraints are applied at all.
#[derive(Debug, Deserialize)]
struct DownloadLatestExtensionQueryParams {
    /// Inclusive lower bound of the acceptable extension schema version.
    min_schema_version: Option<i32>,
    /// Inclusive upper bound of the acceptable extension schema version.
    max_schema_version: Option<i32>,
    /// Inclusive lower bound of the acceptable wasm API version.
    min_wasm_api_version: Option<SemanticVersion>,
    /// Inclusive upper bound of the acceptable wasm API version.
    max_wasm_api_version: Option<SemanticVersion>,
}
128
129async fn download_latest_extension(
130 Extension(app): Extension<Arc<AppState>>,
131 Path(params): Path<DownloadLatestExtensionPathParams>,
132 Query(query): Query<DownloadLatestExtensionQueryParams>,
133) -> Result<Redirect> {
134 let constraints = maybe!({
135 let min_schema_version = query.min_schema_version?;
136 let max_schema_version = query.max_schema_version?;
137 let min_wasm_api_version = query.min_wasm_api_version?;
138 let max_wasm_api_version = query.max_wasm_api_version?;
139
140 Some(ExtensionVersionConstraints {
141 schema_versions: min_schema_version..=max_schema_version,
142 wasm_api_versions: min_wasm_api_version..=max_wasm_api_version,
143 })
144 });
145
146 let extension = app
147 .db
148 .get_extension(¶ms.extension_id, constraints.as_ref())
149 .await?
150 .ok_or_else(|| anyhow!("unknown extension"))?;
151 download_extension(
152 Extension(app),
153 Path(DownloadExtensionParams {
154 extension_id: params.extension_id,
155 version: extension.manifest.version.to_string(),
156 }),
157 )
158 .await
159}
160
/// Path parameters for `GET /extensions/:extension_id/:version/download`.
///
/// Also constructed directly by `download_latest_extension` after it resolves the
/// latest version.
#[derive(Debug, Deserialize)]
struct DownloadExtensionParams {
    /// The `:extension_id` path segment.
    extension_id: String,
    /// The `:version` path segment, used verbatim in the blob store key.
    version: String,
}
166
167async fn download_extension(
168 Extension(app): Extension<Arc<AppState>>,
169 Path(params): Path<DownloadExtensionParams>,
170) -> Result<Redirect> {
171 let Some((blob_store_client, bucket)) = app
172 .blob_store_client
173 .clone()
174 .zip(app.config.blob_store_bucket.clone())
175 else {
176 Err(Error::Http(
177 StatusCode::NOT_IMPLEMENTED,
178 "not supported".into(),
179 ))?
180 };
181
182 let DownloadExtensionParams {
183 extension_id,
184 version,
185 } = params;
186
187 let version_exists = app
188 .db
189 .record_extension_download(&extension_id, &version)
190 .await?;
191
192 if !version_exists {
193 Err(Error::Http(
194 StatusCode::NOT_FOUND,
195 "unknown extension version".into(),
196 ))?;
197 }
198
199 let url = blob_store_client
200 .get_object()
201 .bucket(bucket)
202 .key(format!(
203 "extensions/{extension_id}/{version}/archive.tar.gz"
204 ))
205 .presigned(PresigningConfig::expires_in(EXTENSION_DOWNLOAD_URL_LIFETIME).unwrap())
206 .await
207 .map_err(|e| anyhow!("failed to create presigned extension download url {e}"))?;
208
209 Ok(Redirect::temporary(url.uri()))
210}
211
/// Delay between consecutive blob store syncs run by
/// `fetch_extensions_from_blob_store_periodically` (5 minutes).
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(5 * 60);
/// Validity window of the presigned archive URLs handed out by `download_extension`
/// (3 minutes).
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(3 * 60);
214
215pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
216 let Some(blob_store_client) = app_state.blob_store_client.clone() else {
217 log::info!("no blob store client");
218 return;
219 };
220 let Some(blob_store_bucket) = app_state.config.blob_store_bucket.clone() else {
221 log::info!("no blob store bucket");
222 return;
223 };
224
225 let executor = app_state.executor.clone();
226 executor.spawn_detached({
227 let executor = executor.clone();
228 async move {
229 loop {
230 fetch_extensions_from_blob_store(
231 &blob_store_client,
232 &blob_store_bucket,
233 &app_state,
234 )
235 .await
236 .log_err();
237 executor.sleep(EXTENSION_FETCH_INTERVAL).await;
238 }
239 }
240 });
241}
242
243async fn fetch_extensions_from_blob_store(
244 blob_store_client: &aws_sdk_s3::Client,
245 blob_store_bucket: &String,
246 app_state: &Arc<AppState>,
247) -> anyhow::Result<()> {
248 log::info!("fetching extensions from blob store");
249
250 let mut next_marker = None;
251 let mut published_versions = HashMap::<String, Vec<String>>::default();
252
253 loop {
254 let list = blob_store_client
255 .list_objects()
256 .bucket(blob_store_bucket)
257 .prefix("extensions/")
258 .set_marker(next_marker.clone())
259 .send()
260 .await?;
261 let objects = list.contents.unwrap_or_default();
262 log::info!("fetched {} object(s) from blob store", objects.len());
263
264 for object in &objects {
265 let Some(key) = object.key.as_ref() else {
266 continue;
267 };
268 let mut parts = key.split('/');
269 let Some(_) = parts.next().filter(|part| *part == "extensions") else {
270 continue;
271 };
272 let Some(extension_id) = parts.next() else {
273 continue;
274 };
275 let Some(version) = parts.next() else {
276 continue;
277 };
278 if parts.next() == Some("manifest.json") {
279 published_versions
280 .entry(extension_id.to_owned())
281 .or_default()
282 .push(version.to_owned());
283 }
284 }
285
286 if let (Some(true), Some(last_object)) = (list.is_truncated, objects.last()) {
287 next_marker.clone_from(&last_object.key);
288 } else {
289 break;
290 }
291 }
292
293 log::info!("found {} published extensions", published_versions.len());
294
295 let known_versions = app_state.db.get_known_extension_versions().await?;
296
297 let mut new_versions = HashMap::<&str, Vec<NewExtensionVersion>>::default();
298 let empty = Vec::new();
299 for (extension_id, published_versions) in &published_versions {
300 let known_versions = known_versions.get(extension_id).unwrap_or(&empty);
301
302 for published_version in published_versions {
303 if known_versions
304 .binary_search_by_key(&published_version, |known_version| known_version)
305 .is_err()
306 {
307 if let Some(extension) = fetch_extension_manifest(
308 blob_store_client,
309 blob_store_bucket,
310 &extension_id,
311 &published_version,
312 )
313 .await
314 .log_err()
315 {
316 new_versions
317 .entry(&extension_id)
318 .or_default()
319 .push(extension);
320 }
321 }
322 }
323 }
324
325 app_state
326 .db
327 .insert_extension_versions(&new_versions)
328 .await?;
329
330 log::info!(
331 "fetched {} new extensions from blob store",
332 new_versions.values().map(|v| v.len()).sum::<usize>()
333 );
334
335 Ok(())
336}
337
338async fn fetch_extension_manifest(
339 blob_store_client: &aws_sdk_s3::Client,
340 blob_store_bucket: &String,
341 extension_id: &str,
342 version: &str,
343) -> Result<NewExtensionVersion, anyhow::Error> {
344 let object = blob_store_client
345 .get_object()
346 .bucket(blob_store_bucket)
347 .key(format!("extensions/{extension_id}/{version}/manifest.json"))
348 .send()
349 .await?;
350 let manifest_bytes = object
351 .body
352 .collect()
353 .await
354 .map(|data| data.into_bytes())
355 .with_context(|| {
356 format!("failed to download manifest for extension {extension_id} version {version}")
357 })?
358 .to_vec();
359 let manifest =
360 serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
361 format!(
362 "invalid manifest for extension {extension_id} version {version}: {}",
363 String::from_utf8_lossy(&manifest_bytes)
364 )
365 })?;
366 let published_at = object.last_modified.ok_or_else(|| {
367 anyhow!("missing last modified timestamp for extension {extension_id} version {version}")
368 })?;
369 let published_at = time::OffsetDateTime::from_unix_timestamp_nanos(published_at.as_nanos())?;
370 let published_at = PrimitiveDateTime::new(published_at.date(), published_at.time());
371 let version = semver::Version::parse(&manifest.version).with_context(|| {
372 format!("invalid version for extension {extension_id} version {version}")
373 })?;
374 Ok(NewExtensionVersion {
375 name: manifest.name,
376 version,
377 description: manifest.description.unwrap_or_default(),
378 authors: manifest.authors,
379 repository: manifest.repository,
380 schema_version: manifest.schema_version.unwrap_or(0),
381 wasm_api_version: manifest.wasm_api_version,
382 published_at,
383 })
384}