1use crate::db::ExtensionVersionConstraints;
2use crate::{db::NewExtensionVersion, AppState, Error, Result};
3use anyhow::{anyhow, Context as _};
4use aws_sdk_s3::presigning::PresigningConfig;
5use axum::{
6 extract::{Path, Query},
7 http::StatusCode,
8 response::Redirect,
9 routing::get,
10 Extension, Json, Router,
11};
12use collections::HashMap;
13use rpc::{ExtensionApiManifest, GetExtensionsResponse};
14use semantic_version::SemanticVersion;
15use serde::Deserialize;
16use std::{sync::Arc, time::Duration};
17use time::PrimitiveDateTime;
18use util::{maybe, ResultExt};
19
20pub fn router() -> Router {
21 Router::new()
22 .route("/extensions", get(get_extensions))
23 .route("/extensions/updates", get(get_extension_updates))
24 .route("/extensions/:extension_id", get(get_extension_versions))
25 .route(
26 "/extensions/:extension_id/download",
27 get(download_latest_extension),
28 )
29 .route(
30 "/extensions/:extension_id/:version/download",
31 get(download_extension),
32 )
33}
34
/// Query parameters for `GET /extensions`.
#[derive(Debug, Deserialize)]
struct GetExtensionsParams {
    /// Optional search text. It is forwarded to the database query and, when
    /// present, logged as the `query` field of the `extension_search` event.
    filter: Option<String>,
    /// Forwarded to the database query; deserializes to `0` when omitted.
    #[serde(default)]
    max_schema_version: i32,
}
41
42async fn get_extensions(
43 Extension(app): Extension<Arc<AppState>>,
44 Query(params): Query<GetExtensionsParams>,
45) -> Result<Json<GetExtensionsResponse>> {
46 let extensions = app
47 .db
48 .get_extensions(params.filter.as_deref(), params.max_schema_version, 500)
49 .await?;
50
51 if let Some(query) = params.filter.as_deref() {
52 let count = extensions.len();
53 tracing::info!(query, count, "extension_search")
54 }
55
56 Ok(Json(GetExtensionsResponse { data: extensions }))
57}
58
/// Query parameters for `GET /extensions/updates`.
#[derive(Debug, Deserialize)]
struct GetExtensionUpdatesParams {
    /// Comma-separated extension ids; the handler trims whitespace around each entry.
    ids: String,
    /// Lower bound (inclusive) of the schema-version constraint range.
    min_schema_version: i32,
    /// Upper bound (inclusive) of the schema-version constraint range.
    max_schema_version: i32,
    /// Lower bound (inclusive) of the wasm API version constraint range.
    min_wasm_api_version: SemanticVersion,
    /// Upper bound (inclusive) of the wasm API version constraint range.
    max_wasm_api_version: SemanticVersion,
}
67
68async fn get_extension_updates(
69 Extension(app): Extension<Arc<AppState>>,
70 Query(params): Query<GetExtensionUpdatesParams>,
71) -> Result<Json<GetExtensionsResponse>> {
72 let constraints = ExtensionVersionConstraints {
73 schema_versions: params.min_schema_version..=params.max_schema_version,
74 wasm_api_versions: params.min_wasm_api_version..=params.max_wasm_api_version,
75 };
76
77 let extension_ids = params.ids.split(',').map(|s| s.trim()).collect::<Vec<_>>();
78
79 let extensions = app
80 .db
81 .get_extensions_by_ids(&extension_ids, Some(&constraints))
82 .await?;
83
84 Ok(Json(GetExtensionsResponse { data: extensions }))
85}
86
/// Path parameters for `GET /extensions/:extension_id`.
#[derive(Debug, Deserialize)]
struct GetExtensionVersionsParams {
    /// The `:extension_id` path segment.
    extension_id: String,
}
91
92async fn get_extension_versions(
93 Extension(app): Extension<Arc<AppState>>,
94 Path(params): Path<GetExtensionVersionsParams>,
95) -> Result<Json<GetExtensionsResponse>> {
96 let extension_versions = app.db.get_extension_versions(¶ms.extension_id).await?;
97
98 Ok(Json(GetExtensionsResponse {
99 data: extension_versions,
100 }))
101}
102
/// Path parameters for `GET /extensions/:extension_id/download`.
#[derive(Debug, Deserialize)]
struct DownloadLatestExtensionPathParams {
    /// The `:extension_id` path segment.
    extension_id: String,
}
107
/// Optional query parameters for `GET /extensions/:extension_id/download`.
///
/// The handler applies version constraints only when all four bounds are
/// present; if any one is missing, the lookup is unconstrained.
#[derive(Debug, Deserialize)]
struct DownloadLatestExtensionQueryParams {
    /// Lower bound (inclusive) of the schema-version range.
    min_schema_version: Option<i32>,
    /// Upper bound (inclusive) of the schema-version range.
    max_schema_version: Option<i32>,
    /// Lower bound (inclusive) of the wasm API version range.
    min_wasm_api_version: Option<SemanticVersion>,
    /// Upper bound (inclusive) of the wasm API version range.
    max_wasm_api_version: Option<SemanticVersion>,
}
115
116async fn download_latest_extension(
117 Extension(app): Extension<Arc<AppState>>,
118 Path(params): Path<DownloadLatestExtensionPathParams>,
119 Query(query): Query<DownloadLatestExtensionQueryParams>,
120) -> Result<Redirect> {
121 let constraints = maybe!({
122 let min_schema_version = query.min_schema_version?;
123 let max_schema_version = query.max_schema_version?;
124 let min_wasm_api_version = query.min_wasm_api_version?;
125 let max_wasm_api_version = query.max_wasm_api_version?;
126
127 Some(ExtensionVersionConstraints {
128 schema_versions: min_schema_version..=max_schema_version,
129 wasm_api_versions: min_wasm_api_version..=max_wasm_api_version,
130 })
131 });
132
133 let extension = app
134 .db
135 .get_extension(¶ms.extension_id, constraints.as_ref())
136 .await?
137 .ok_or_else(|| anyhow!("unknown extension"))?;
138 download_extension(
139 Extension(app),
140 Path(DownloadExtensionParams {
141 extension_id: params.extension_id,
142 version: extension.manifest.version.to_string(),
143 }),
144 )
145 .await
146}
147
/// Identifies one extension archive by id and version.
///
/// Deserialized from the `/extensions/:extension_id/:version/download` path. It is
/// also constructed directly by `download_latest_extension`.
#[derive(Debug, Deserialize)]
struct DownloadExtensionParams {
    /// The `:extension_id` path segment.
    extension_id: String,
    /// The `:version` path segment, used verbatim in the blob-store key.
    version: String,
}
153
154async fn download_extension(
155 Extension(app): Extension<Arc<AppState>>,
156 Path(params): Path<DownloadExtensionParams>,
157) -> Result<Redirect> {
158 let Some((blob_store_client, bucket)) = app
159 .blob_store_client
160 .clone()
161 .zip(app.config.blob_store_bucket.clone())
162 else {
163 Err(Error::Http(
164 StatusCode::NOT_IMPLEMENTED,
165 "not supported".into(),
166 ))?
167 };
168
169 let DownloadExtensionParams {
170 extension_id,
171 version,
172 } = params;
173
174 let version_exists = app
175 .db
176 .record_extension_download(&extension_id, &version)
177 .await?;
178
179 if !version_exists {
180 Err(Error::Http(
181 StatusCode::NOT_FOUND,
182 "unknown extension version".into(),
183 ))?;
184 }
185
186 let url = blob_store_client
187 .get_object()
188 .bucket(bucket)
189 .key(format!(
190 "extensions/{extension_id}/{version}/archive.tar.gz"
191 ))
192 .presigned(PresigningConfig::expires_in(EXTENSION_DOWNLOAD_URL_LIFETIME).unwrap())
193 .await
194 .map_err(|e| anyhow!("failed to create presigned extension download url {e}"))?;
195
196 Ok(Redirect::temporary(url.uri()))
197}
198
/// Pause between two blob-store scans in the background sync task (five minutes).
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(300);
/// How long a presigned extension download URL stays valid (three minutes).
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(180);
201
202pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
203 let Some(blob_store_client) = app_state.blob_store_client.clone() else {
204 log::info!("no blob store client");
205 return;
206 };
207 let Some(blob_store_bucket) = app_state.config.blob_store_bucket.clone() else {
208 log::info!("no blob store bucket");
209 return;
210 };
211
212 let executor = app_state.executor.clone();
213 executor.spawn_detached({
214 let executor = executor.clone();
215 async move {
216 loop {
217 fetch_extensions_from_blob_store(
218 &blob_store_client,
219 &blob_store_bucket,
220 &app_state,
221 )
222 .await
223 .log_err();
224 executor.sleep(EXTENSION_FETCH_INTERVAL).await;
225 }
226 }
227 });
228}
229
230async fn fetch_extensions_from_blob_store(
231 blob_store_client: &aws_sdk_s3::Client,
232 blob_store_bucket: &String,
233 app_state: &Arc<AppState>,
234) -> anyhow::Result<()> {
235 log::info!("fetching extensions from blob store");
236
237 let mut next_marker = None;
238 let mut published_versions = HashMap::<String, Vec<String>>::default();
239
240 loop {
241 let list = blob_store_client
242 .list_objects()
243 .bucket(blob_store_bucket)
244 .prefix("extensions/")
245 .set_marker(next_marker.clone())
246 .send()
247 .await?;
248 let objects = list.contents.unwrap_or_default();
249 log::info!("fetched {} object(s) from blob store", objects.len());
250
251 for object in &objects {
252 let Some(key) = object.key.as_ref() else {
253 continue;
254 };
255 let mut parts = key.split('/');
256 let Some(_) = parts.next().filter(|part| *part == "extensions") else {
257 continue;
258 };
259 let Some(extension_id) = parts.next() else {
260 continue;
261 };
262 let Some(version) = parts.next() else {
263 continue;
264 };
265 if parts.next() == Some("manifest.json") {
266 published_versions
267 .entry(extension_id.to_owned())
268 .or_default()
269 .push(version.to_owned());
270 }
271 }
272
273 if let (Some(true), Some(last_object)) = (list.is_truncated, objects.last()) {
274 next_marker.clone_from(&last_object.key);
275 } else {
276 break;
277 }
278 }
279
280 log::info!("found {} published extensions", published_versions.len());
281
282 let known_versions = app_state.db.get_known_extension_versions().await?;
283
284 let mut new_versions = HashMap::<&str, Vec<NewExtensionVersion>>::default();
285 let empty = Vec::new();
286 for (extension_id, published_versions) in &published_versions {
287 let known_versions = known_versions.get(extension_id).unwrap_or(&empty);
288
289 for published_version in published_versions {
290 if known_versions
291 .binary_search_by_key(&published_version, |known_version| known_version)
292 .is_err()
293 {
294 if let Some(extension) = fetch_extension_manifest(
295 blob_store_client,
296 blob_store_bucket,
297 &extension_id,
298 &published_version,
299 )
300 .await
301 .log_err()
302 {
303 new_versions
304 .entry(&extension_id)
305 .or_default()
306 .push(extension);
307 }
308 }
309 }
310 }
311
312 app_state
313 .db
314 .insert_extension_versions(&new_versions)
315 .await?;
316
317 log::info!(
318 "fetched {} new extensions from blob store",
319 new_versions.values().map(|v| v.len()).sum::<usize>()
320 );
321
322 Ok(())
323}
324
325async fn fetch_extension_manifest(
326 blob_store_client: &aws_sdk_s3::Client,
327 blob_store_bucket: &String,
328 extension_id: &str,
329 version: &str,
330) -> Result<NewExtensionVersion, anyhow::Error> {
331 let object = blob_store_client
332 .get_object()
333 .bucket(blob_store_bucket)
334 .key(format!("extensions/{extension_id}/{version}/manifest.json"))
335 .send()
336 .await?;
337 let manifest_bytes = object
338 .body
339 .collect()
340 .await
341 .map(|data| data.into_bytes())
342 .with_context(|| {
343 format!("failed to download manifest for extension {extension_id} version {version}")
344 })?
345 .to_vec();
346 let manifest =
347 serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
348 format!(
349 "invalid manifest for extension {extension_id} version {version}: {}",
350 String::from_utf8_lossy(&manifest_bytes)
351 )
352 })?;
353 let published_at = object.last_modified.ok_or_else(|| {
354 anyhow!("missing last modified timestamp for extension {extension_id} version {version}")
355 })?;
356 let published_at = time::OffsetDateTime::from_unix_timestamp_nanos(published_at.as_nanos())?;
357 let published_at = PrimitiveDateTime::new(published_at.date(), published_at.time());
358 let version = semver::Version::parse(&manifest.version).with_context(|| {
359 format!("invalid version for extension {extension_id} version {version}")
360 })?;
361 Ok(NewExtensionVersion {
362 name: manifest.name,
363 version,
364 description: manifest.description.unwrap_or_default(),
365 authors: manifest.authors,
366 repository: manifest.repository,
367 schema_version: manifest.schema_version.unwrap_or(0),
368 wasm_api_version: manifest.wasm_api_version,
369 published_at,
370 })
371}