1use crate::db::ExtensionVersionConstraints;
2use crate::{db::NewExtensionVersion, AppState, Error, Result};
3use anyhow::{anyhow, Context as _};
4use aws_sdk_s3::presigning::PresigningConfig;
5use axum::{
6 extract::{Path, Query},
7 http::StatusCode,
8 response::Redirect,
9 routing::get,
10 Extension, Json, Router,
11};
12use collections::HashMap;
13use rpc::{ExtensionApiManifest, GetExtensionsResponse};
14use semantic_version::SemanticVersion;
15use serde::Deserialize;
16use std::{sync::Arc, time::Duration};
17use time::PrimitiveDateTime;
18use util::{maybe, ResultExt};
19
20pub fn router() -> Router {
21 Router::new()
22 .route("/extensions", get(get_extensions))
23 .route("/extensions/updates", get(get_extension_updates))
24 .route("/extensions/:extension_id", get(get_extension_versions))
25 .route(
26 "/extensions/:extension_id/download",
27 get(download_latest_extension),
28 )
29 .route(
30 "/extensions/:extension_id/:version/download",
31 get(download_extension),
32 )
33}
34
35#[derive(Debug, Deserialize)]
36struct GetExtensionsParams {
37 filter: Option<String>,
38 #[serde(default)]
39 ids: Option<String>,
40 #[serde(default)]
41 max_schema_version: i32,
42}
43
44async fn get_extensions(
45 Extension(app): Extension<Arc<AppState>>,
46 Query(params): Query<GetExtensionsParams>,
47) -> Result<Json<GetExtensionsResponse>> {
48 let extension_ids = params
49 .ids
50 .as_ref()
51 .map(|s| s.split(',').map(|s| s.trim()).collect::<Vec<_>>());
52
53 let extensions = if let Some(extension_ids) = extension_ids {
54 app.db.get_extensions_by_ids(&extension_ids, None).await?
55 } else {
56 app.db
57 .get_extensions(params.filter.as_deref(), params.max_schema_version, 500)
58 .await?
59 };
60
61 Ok(Json(GetExtensionsResponse { data: extensions }))
62}
63
64#[derive(Debug, Deserialize)]
65struct GetExtensionUpdatesParams {
66 ids: String,
67 min_schema_version: i32,
68 max_schema_version: i32,
69 min_wasm_api_version: SemanticVersion,
70 max_wasm_api_version: SemanticVersion,
71}
72
73async fn get_extension_updates(
74 Extension(app): Extension<Arc<AppState>>,
75 Query(params): Query<GetExtensionUpdatesParams>,
76) -> Result<Json<GetExtensionsResponse>> {
77 let constraints = ExtensionVersionConstraints {
78 schema_versions: params.min_schema_version..=params.max_schema_version,
79 wasm_api_versions: params.min_wasm_api_version..=params.max_wasm_api_version,
80 };
81
82 let extension_ids = params.ids.split(',').map(|s| s.trim()).collect::<Vec<_>>();
83
84 let extensions = app
85 .db
86 .get_extensions_by_ids(&extension_ids, Some(&constraints))
87 .await?;
88
89 Ok(Json(GetExtensionsResponse { data: extensions }))
90}
91
92#[derive(Debug, Deserialize)]
93struct GetExtensionVersionsParams {
94 extension_id: String,
95}
96
97async fn get_extension_versions(
98 Extension(app): Extension<Arc<AppState>>,
99 Path(params): Path<GetExtensionVersionsParams>,
100) -> Result<Json<GetExtensionsResponse>> {
101 let extension_versions = app.db.get_extension_versions(¶ms.extension_id).await?;
102
103 Ok(Json(GetExtensionsResponse {
104 data: extension_versions,
105 }))
106}
107
108#[derive(Debug, Deserialize)]
109struct DownloadLatestExtensionPathParams {
110 extension_id: String,
111}
112
113#[derive(Debug, Deserialize)]
114struct DownloadLatestExtensionQueryParams {
115 min_schema_version: Option<i32>,
116 max_schema_version: Option<i32>,
117 min_wasm_api_version: Option<SemanticVersion>,
118 max_wasm_api_version: Option<SemanticVersion>,
119}
120
121async fn download_latest_extension(
122 Extension(app): Extension<Arc<AppState>>,
123 Path(params): Path<DownloadLatestExtensionPathParams>,
124 Query(query): Query<DownloadLatestExtensionQueryParams>,
125) -> Result<Redirect> {
126 let constraints = maybe!({
127 let min_schema_version = query.min_schema_version?;
128 let max_schema_version = query.max_schema_version?;
129 let min_wasm_api_version = query.min_wasm_api_version?;
130 let max_wasm_api_version = query.max_wasm_api_version?;
131
132 Some(ExtensionVersionConstraints {
133 schema_versions: min_schema_version..=max_schema_version,
134 wasm_api_versions: min_wasm_api_version..=max_wasm_api_version,
135 })
136 });
137
138 let extension = app
139 .db
140 .get_extension(¶ms.extension_id, constraints.as_ref())
141 .await?
142 .ok_or_else(|| anyhow!("unknown extension"))?;
143 download_extension(
144 Extension(app),
145 Path(DownloadExtensionParams {
146 extension_id: params.extension_id,
147 version: extension.manifest.version.to_string(),
148 }),
149 )
150 .await
151}
152
153#[derive(Debug, Deserialize)]
154struct DownloadExtensionParams {
155 extension_id: String,
156 version: String,
157}
158
159async fn download_extension(
160 Extension(app): Extension<Arc<AppState>>,
161 Path(params): Path<DownloadExtensionParams>,
162) -> Result<Redirect> {
163 let Some((blob_store_client, bucket)) = app
164 .blob_store_client
165 .clone()
166 .zip(app.config.blob_store_bucket.clone())
167 else {
168 Err(Error::Http(
169 StatusCode::NOT_IMPLEMENTED,
170 "not supported".into(),
171 ))?
172 };
173
174 let DownloadExtensionParams {
175 extension_id,
176 version,
177 } = params;
178
179 let version_exists = app
180 .db
181 .record_extension_download(&extension_id, &version)
182 .await?;
183
184 if !version_exists {
185 Err(Error::Http(
186 StatusCode::NOT_FOUND,
187 "unknown extension version".into(),
188 ))?;
189 }
190
191 let url = blob_store_client
192 .get_object()
193 .bucket(bucket)
194 .key(format!(
195 "extensions/{extension_id}/{version}/archive.tar.gz"
196 ))
197 .presigned(PresigningConfig::expires_in(EXTENSION_DOWNLOAD_URL_LIFETIME).unwrap())
198 .await
199 .map_err(|e| anyhow!("failed to create presigned extension download url {e}"))?;
200
201 Ok(Redirect::temporary(url.uri()))
202}
203
204const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(5 * 60);
205const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(3 * 60);
206
207pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
208 let Some(blob_store_client) = app_state.blob_store_client.clone() else {
209 log::info!("no blob store client");
210 return;
211 };
212 let Some(blob_store_bucket) = app_state.config.blob_store_bucket.clone() else {
213 log::info!("no blob store bucket");
214 return;
215 };
216
217 let executor = app_state.executor.clone();
218 executor.spawn_detached({
219 let executor = executor.clone();
220 async move {
221 loop {
222 fetch_extensions_from_blob_store(
223 &blob_store_client,
224 &blob_store_bucket,
225 &app_state,
226 )
227 .await
228 .log_err();
229 executor.sleep(EXTENSION_FETCH_INTERVAL).await;
230 }
231 }
232 });
233}
234
235async fn fetch_extensions_from_blob_store(
236 blob_store_client: &aws_sdk_s3::Client,
237 blob_store_bucket: &String,
238 app_state: &Arc<AppState>,
239) -> anyhow::Result<()> {
240 log::info!("fetching extensions from blob store");
241
242 let mut next_marker = None;
243 let mut published_versions = HashMap::<String, Vec<String>>::default();
244
245 loop {
246 let list = blob_store_client
247 .list_objects()
248 .bucket(blob_store_bucket)
249 .prefix("extensions/")
250 .set_marker(next_marker.clone())
251 .send()
252 .await?;
253 let objects = list.contents.unwrap_or_default();
254 log::info!("fetched {} object(s) from blob store", objects.len());
255
256 for object in &objects {
257 let Some(key) = object.key.as_ref() else {
258 continue;
259 };
260 let mut parts = key.split('/');
261 let Some(_) = parts.next().filter(|part| *part == "extensions") else {
262 continue;
263 };
264 let Some(extension_id) = parts.next() else {
265 continue;
266 };
267 let Some(version) = parts.next() else {
268 continue;
269 };
270 if parts.next() == Some("manifest.json") {
271 published_versions
272 .entry(extension_id.to_owned())
273 .or_default()
274 .push(version.to_owned());
275 }
276 }
277
278 if let (Some(true), Some(last_object)) = (list.is_truncated, objects.last()) {
279 next_marker.clone_from(&last_object.key);
280 } else {
281 break;
282 }
283 }
284
285 log::info!("found {} published extensions", published_versions.len());
286
287 let known_versions = app_state.db.get_known_extension_versions().await?;
288
289 let mut new_versions = HashMap::<&str, Vec<NewExtensionVersion>>::default();
290 let empty = Vec::new();
291 for (extension_id, published_versions) in &published_versions {
292 let known_versions = known_versions.get(extension_id).unwrap_or(&empty);
293
294 for published_version in published_versions {
295 if known_versions
296 .binary_search_by_key(&published_version, |known_version| known_version)
297 .is_err()
298 {
299 if let Some(extension) = fetch_extension_manifest(
300 blob_store_client,
301 blob_store_bucket,
302 &extension_id,
303 &published_version,
304 )
305 .await
306 .log_err()
307 {
308 new_versions
309 .entry(&extension_id)
310 .or_default()
311 .push(extension);
312 }
313 }
314 }
315 }
316
317 app_state
318 .db
319 .insert_extension_versions(&new_versions)
320 .await?;
321
322 log::info!(
323 "fetched {} new extensions from blob store",
324 new_versions.values().map(|v| v.len()).sum::<usize>()
325 );
326
327 Ok(())
328}
329
330async fn fetch_extension_manifest(
331 blob_store_client: &aws_sdk_s3::Client,
332 blob_store_bucket: &String,
333 extension_id: &str,
334 version: &str,
335) -> Result<NewExtensionVersion, anyhow::Error> {
336 let object = blob_store_client
337 .get_object()
338 .bucket(blob_store_bucket)
339 .key(format!("extensions/{extension_id}/{version}/manifest.json"))
340 .send()
341 .await?;
342 let manifest_bytes = object
343 .body
344 .collect()
345 .await
346 .map(|data| data.into_bytes())
347 .with_context(|| {
348 format!("failed to download manifest for extension {extension_id} version {version}")
349 })?
350 .to_vec();
351 let manifest =
352 serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
353 format!(
354 "invalid manifest for extension {extension_id} version {version}: {}",
355 String::from_utf8_lossy(&manifest_bytes)
356 )
357 })?;
358 let published_at = object.last_modified.ok_or_else(|| {
359 anyhow!("missing last modified timestamp for extension {extension_id} version {version}")
360 })?;
361 let published_at = time::OffsetDateTime::from_unix_timestamp_nanos(published_at.as_nanos())?;
362 let published_at = PrimitiveDateTime::new(published_at.date(), published_at.time());
363 let version = semver::Version::parse(&manifest.version).with_context(|| {
364 format!("invalid version for extension {extension_id} version {version}")
365 })?;
366 Ok(NewExtensionVersion {
367 name: manifest.name,
368 version,
369 description: manifest.description.unwrap_or_default(),
370 authors: manifest.authors,
371 repository: manifest.repository,
372 schema_version: manifest.schema_version.unwrap_or(0),
373 wasm_api_version: manifest.wasm_api_version,
374 published_at,
375 })
376}