1use crate::db::ExtensionVersionConstraints;
2use crate::{db::NewExtensionVersion, AppState, Error, Result};
3use anyhow::{anyhow, Context as _};
4use aws_sdk_s3::presigning::PresigningConfig;
5use axum::{
6 extract::{Path, Query},
7 http::StatusCode,
8 response::Redirect,
9 routing::get,
10 Extension, Json, Router,
11};
12use collections::HashMap;
13use rpc::{ExtensionApiManifest, GetExtensionsResponse};
14use semantic_version::SemanticVersion;
15use serde::Deserialize;
16use std::{sync::Arc, time::Duration};
17use time::PrimitiveDateTime;
18use util::{maybe, ResultExt};
19
20pub fn router() -> Router {
21 Router::new()
22 .route("/extensions", get(get_extensions))
23 .route("/extensions/updates", get(get_extension_updates))
24 .route("/extensions/:extension_id", get(get_extension_versions))
25 .route(
26 "/extensions/:extension_id/download",
27 get(download_latest_extension),
28 )
29 .route(
30 "/extensions/:extension_id/:version/download",
31 get(download_extension),
32 )
33}
34
35#[derive(Debug, Deserialize)]
36struct GetExtensionsParams {
37 filter: Option<String>,
38 #[serde(default)]
39 max_schema_version: i32,
40}
41
42async fn get_extensions(
43 Extension(app): Extension<Arc<AppState>>,
44 Query(params): Query<GetExtensionsParams>,
45) -> Result<Json<GetExtensionsResponse>> {
46 let mut extensions = app
47 .db
48 .get_extensions(params.filter.as_deref(), params.max_schema_version, 500)
49 .await?;
50
51 if let Some(filter) = params.filter.as_deref() {
52 let extension_id = filter.to_lowercase();
53 let mut exact_match = None;
54 extensions.retain(|extension| {
55 if extension.id.as_ref() == extension_id {
56 exact_match = Some(extension.clone());
57 false
58 } else {
59 true
60 }
61 });
62 if exact_match.is_none() {
63 exact_match = app
64 .db
65 .get_extensions_by_ids(&[&extension_id], None)
66 .await?
67 .first()
68 .cloned();
69 }
70
71 if let Some(exact_match) = exact_match {
72 extensions.insert(0, exact_match);
73 }
74 };
75
76 if let Some(query) = params.filter.as_deref() {
77 let count = extensions.len();
78 tracing::info!(query, count, "extension_search")
79 }
80
81 Ok(Json(GetExtensionsResponse { data: extensions }))
82}
83
84#[derive(Debug, Deserialize)]
85struct GetExtensionUpdatesParams {
86 ids: String,
87 min_schema_version: i32,
88 max_schema_version: i32,
89 min_wasm_api_version: SemanticVersion,
90 max_wasm_api_version: SemanticVersion,
91}
92
93async fn get_extension_updates(
94 Extension(app): Extension<Arc<AppState>>,
95 Query(params): Query<GetExtensionUpdatesParams>,
96) -> Result<Json<GetExtensionsResponse>> {
97 let constraints = ExtensionVersionConstraints {
98 schema_versions: params.min_schema_version..=params.max_schema_version,
99 wasm_api_versions: params.min_wasm_api_version..=params.max_wasm_api_version,
100 };
101
102 let extension_ids = params.ids.split(',').map(|s| s.trim()).collect::<Vec<_>>();
103
104 let extensions = app
105 .db
106 .get_extensions_by_ids(&extension_ids, Some(&constraints))
107 .await?;
108
109 Ok(Json(GetExtensionsResponse { data: extensions }))
110}
111
112#[derive(Debug, Deserialize)]
113struct GetExtensionVersionsParams {
114 extension_id: String,
115}
116
117async fn get_extension_versions(
118 Extension(app): Extension<Arc<AppState>>,
119 Path(params): Path<GetExtensionVersionsParams>,
120) -> Result<Json<GetExtensionsResponse>> {
121 let extension_versions = app.db.get_extension_versions(¶ms.extension_id).await?;
122
123 Ok(Json(GetExtensionsResponse {
124 data: extension_versions,
125 }))
126}
127
128#[derive(Debug, Deserialize)]
129struct DownloadLatestExtensionPathParams {
130 extension_id: String,
131}
132
133#[derive(Debug, Deserialize)]
134struct DownloadLatestExtensionQueryParams {
135 min_schema_version: Option<i32>,
136 max_schema_version: Option<i32>,
137 min_wasm_api_version: Option<SemanticVersion>,
138 max_wasm_api_version: Option<SemanticVersion>,
139}
140
141async fn download_latest_extension(
142 Extension(app): Extension<Arc<AppState>>,
143 Path(params): Path<DownloadLatestExtensionPathParams>,
144 Query(query): Query<DownloadLatestExtensionQueryParams>,
145) -> Result<Redirect> {
146 let constraints = maybe!({
147 let min_schema_version = query.min_schema_version?;
148 let max_schema_version = query.max_schema_version?;
149 let min_wasm_api_version = query.min_wasm_api_version?;
150 let max_wasm_api_version = query.max_wasm_api_version?;
151
152 Some(ExtensionVersionConstraints {
153 schema_versions: min_schema_version..=max_schema_version,
154 wasm_api_versions: min_wasm_api_version..=max_wasm_api_version,
155 })
156 });
157
158 let extension = app
159 .db
160 .get_extension(¶ms.extension_id, constraints.as_ref())
161 .await?
162 .ok_or_else(|| anyhow!("unknown extension"))?;
163 download_extension(
164 Extension(app),
165 Path(DownloadExtensionParams {
166 extension_id: params.extension_id,
167 version: extension.manifest.version.to_string(),
168 }),
169 )
170 .await
171}
172
173#[derive(Debug, Deserialize)]
174struct DownloadExtensionParams {
175 extension_id: String,
176 version: String,
177}
178
179async fn download_extension(
180 Extension(app): Extension<Arc<AppState>>,
181 Path(params): Path<DownloadExtensionParams>,
182) -> Result<Redirect> {
183 let Some((blob_store_client, bucket)) = app
184 .blob_store_client
185 .clone()
186 .zip(app.config.blob_store_bucket.clone())
187 else {
188 Err(Error::http(
189 StatusCode::NOT_IMPLEMENTED,
190 "not supported".into(),
191 ))?
192 };
193
194 let DownloadExtensionParams {
195 extension_id,
196 version,
197 } = params;
198
199 let version_exists = app
200 .db
201 .record_extension_download(&extension_id, &version)
202 .await?;
203
204 if !version_exists {
205 Err(Error::http(
206 StatusCode::NOT_FOUND,
207 "unknown extension version".into(),
208 ))?;
209 }
210
211 let url = blob_store_client
212 .get_object()
213 .bucket(bucket)
214 .key(format!(
215 "extensions/{extension_id}/{version}/archive.tar.gz"
216 ))
217 .presigned(PresigningConfig::expires_in(EXTENSION_DOWNLOAD_URL_LIFETIME).unwrap())
218 .await
219 .map_err(|e| anyhow!("failed to create presigned extension download url {e}"))?;
220
221 Ok(Redirect::temporary(url.uri()))
222}
223
/// Pause between two consecutive blob-store scans for new extension versions (5 minutes).
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(5 * 60);
/// Validity period of the presigned download URLs handed out to clients (3 minutes).
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(3 * 60);
226
227pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
228 let Some(blob_store_client) = app_state.blob_store_client.clone() else {
229 log::info!("no blob store client");
230 return;
231 };
232 let Some(blob_store_bucket) = app_state.config.blob_store_bucket.clone() else {
233 log::info!("no blob store bucket");
234 return;
235 };
236
237 let executor = app_state.executor.clone();
238 executor.spawn_detached({
239 let executor = executor.clone();
240 async move {
241 loop {
242 fetch_extensions_from_blob_store(
243 &blob_store_client,
244 &blob_store_bucket,
245 &app_state,
246 )
247 .await
248 .log_err();
249 executor.sleep(EXTENSION_FETCH_INTERVAL).await;
250 }
251 }
252 });
253}
254
255async fn fetch_extensions_from_blob_store(
256 blob_store_client: &aws_sdk_s3::Client,
257 blob_store_bucket: &String,
258 app_state: &Arc<AppState>,
259) -> anyhow::Result<()> {
260 log::info!("fetching extensions from blob store");
261
262 let mut next_marker = None;
263 let mut published_versions = HashMap::<String, Vec<String>>::default();
264
265 loop {
266 let list = blob_store_client
267 .list_objects()
268 .bucket(blob_store_bucket)
269 .prefix("extensions/")
270 .set_marker(next_marker.clone())
271 .send()
272 .await?;
273 let objects = list.contents.unwrap_or_default();
274 log::info!("fetched {} object(s) from blob store", objects.len());
275
276 for object in &objects {
277 let Some(key) = object.key.as_ref() else {
278 continue;
279 };
280 let mut parts = key.split('/');
281 let Some(_) = parts.next().filter(|part| *part == "extensions") else {
282 continue;
283 };
284 let Some(extension_id) = parts.next() else {
285 continue;
286 };
287 let Some(version) = parts.next() else {
288 continue;
289 };
290 if parts.next() == Some("manifest.json") {
291 published_versions
292 .entry(extension_id.to_owned())
293 .or_default()
294 .push(version.to_owned());
295 }
296 }
297
298 if let (Some(true), Some(last_object)) = (list.is_truncated, objects.last()) {
299 next_marker.clone_from(&last_object.key);
300 } else {
301 break;
302 }
303 }
304
305 log::info!("found {} published extensions", published_versions.len());
306
307 let known_versions = app_state.db.get_known_extension_versions().await?;
308
309 let mut new_versions = HashMap::<&str, Vec<NewExtensionVersion>>::default();
310 let empty = Vec::new();
311 for (extension_id, published_versions) in &published_versions {
312 let known_versions = known_versions.get(extension_id).unwrap_or(&empty);
313
314 for published_version in published_versions {
315 if known_versions
316 .binary_search_by_key(&published_version, |known_version| known_version)
317 .is_err()
318 {
319 if let Some(extension) = fetch_extension_manifest(
320 blob_store_client,
321 blob_store_bucket,
322 extension_id,
323 published_version,
324 )
325 .await
326 .log_err()
327 {
328 new_versions
329 .entry(extension_id)
330 .or_default()
331 .push(extension);
332 }
333 }
334 }
335 }
336
337 app_state
338 .db
339 .insert_extension_versions(&new_versions)
340 .await?;
341
342 log::info!(
343 "fetched {} new extensions from blob store",
344 new_versions.values().map(|v| v.len()).sum::<usize>()
345 );
346
347 Ok(())
348}
349
350async fn fetch_extension_manifest(
351 blob_store_client: &aws_sdk_s3::Client,
352 blob_store_bucket: &String,
353 extension_id: &str,
354 version: &str,
355) -> Result<NewExtensionVersion, anyhow::Error> {
356 let object = blob_store_client
357 .get_object()
358 .bucket(blob_store_bucket)
359 .key(format!("extensions/{extension_id}/{version}/manifest.json"))
360 .send()
361 .await?;
362 let manifest_bytes = object
363 .body
364 .collect()
365 .await
366 .map(|data| data.into_bytes())
367 .with_context(|| {
368 format!("failed to download manifest for extension {extension_id} version {version}")
369 })?
370 .to_vec();
371 let manifest =
372 serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
373 format!(
374 "invalid manifest for extension {extension_id} version {version}: {}",
375 String::from_utf8_lossy(&manifest_bytes)
376 )
377 })?;
378 let published_at = object.last_modified.ok_or_else(|| {
379 anyhow!("missing last modified timestamp for extension {extension_id} version {version}")
380 })?;
381 let published_at = time::OffsetDateTime::from_unix_timestamp_nanos(published_at.as_nanos())?;
382 let published_at = PrimitiveDateTime::new(published_at.date(), published_at.time());
383 let version = semver::Version::parse(&manifest.version).with_context(|| {
384 format!("invalid version for extension {extension_id} version {version}")
385 })?;
386 Ok(NewExtensionVersion {
387 name: manifest.name,
388 version,
389 description: manifest.description.unwrap_or_default(),
390 authors: manifest.authors,
391 repository: manifest.repository,
392 schema_version: manifest.schema_version.unwrap_or(0),
393 wasm_api_version: manifest.wasm_api_version,
394 provides: manifest.provides,
395 published_at,
396 })
397}