1use crate::db::ExtensionVersionConstraints;
2use crate::{db::NewExtensionVersion, AppState, Error, Result};
3use anyhow::{anyhow, Context as _};
4use aws_sdk_s3::presigning::PresigningConfig;
5use axum::{
6 extract::{Path, Query},
7 http::StatusCode,
8 response::Redirect,
9 routing::get,
10 Extension, Json, Router,
11};
12use collections::HashMap;
13use rpc::{ExtensionApiManifest, ExtensionMetadata, GetExtensionsResponse};
14use semantic_version::SemanticVersion;
15use serde::Deserialize;
16use std::{sync::Arc, time::Duration};
17use time::PrimitiveDateTime;
18use util::{maybe, ResultExt};
19
20pub fn router() -> Router {
21 Router::new()
22 .route("/extensions", get(get_extensions))
23 .route("/extensions/updates", get(get_extension_updates))
24 .route("/extensions/:extension_id", get(get_extension_versions))
25 .route(
26 "/extensions/:extension_id/download",
27 get(download_latest_extension),
28 )
29 .route(
30 "/extensions/:extension_id/:version/download",
31 get(download_extension),
32 )
33}
34
35#[derive(Debug, Deserialize)]
36struct GetExtensionsParams {
37 filter: Option<String>,
38 #[serde(default)]
39 max_schema_version: i32,
40}
41
42async fn get_extensions(
43 Extension(app): Extension<Arc<AppState>>,
44 Query(params): Query<GetExtensionsParams>,
45) -> Result<Json<GetExtensionsResponse>> {
46 let mut extensions = app
47 .db
48 .get_extensions(params.filter.as_deref(), params.max_schema_version, 500)
49 .await?;
50
51 if let Some(query) = params.filter.as_deref() {
52 let count = extensions.len();
53 tracing::info!(query, count, "extension_search")
54 }
55
56 if let Some(filter) = params.filter.as_deref() {
57 let mut exact_match: Option<ExtensionMetadata> = None;
58 extensions.retain(|extension| {
59 exact_match = Some(extension.clone());
60 extension.id.as_ref() != &filter.to_lowercase()
61 });
62 if exact_match == None {
63 exact_match = app
64 .db
65 .get_extensions_by_ids(&[&filter.to_lowercase()], None)
66 .await?
67 .first()
68 .cloned();
69 }
70 extensions.splice(0..0, exact_match);
71 };
72
73 Ok(Json(GetExtensionsResponse { data: extensions }))
74}
75
76#[derive(Debug, Deserialize)]
77struct GetExtensionUpdatesParams {
78 ids: String,
79 min_schema_version: i32,
80 max_schema_version: i32,
81 min_wasm_api_version: SemanticVersion,
82 max_wasm_api_version: SemanticVersion,
83}
84
85async fn get_extension_updates(
86 Extension(app): Extension<Arc<AppState>>,
87 Query(params): Query<GetExtensionUpdatesParams>,
88) -> Result<Json<GetExtensionsResponse>> {
89 let constraints = ExtensionVersionConstraints {
90 schema_versions: params.min_schema_version..=params.max_schema_version,
91 wasm_api_versions: params.min_wasm_api_version..=params.max_wasm_api_version,
92 };
93
94 let extension_ids = params.ids.split(',').map(|s| s.trim()).collect::<Vec<_>>();
95
96 let extensions = app
97 .db
98 .get_extensions_by_ids(&extension_ids, Some(&constraints))
99 .await?;
100
101 Ok(Json(GetExtensionsResponse { data: extensions }))
102}
103
104#[derive(Debug, Deserialize)]
105struct GetExtensionVersionsParams {
106 extension_id: String,
107}
108
109async fn get_extension_versions(
110 Extension(app): Extension<Arc<AppState>>,
111 Path(params): Path<GetExtensionVersionsParams>,
112) -> Result<Json<GetExtensionsResponse>> {
113 let extension_versions = app.db.get_extension_versions(¶ms.extension_id).await?;
114
115 Ok(Json(GetExtensionsResponse {
116 data: extension_versions,
117 }))
118}
119
120#[derive(Debug, Deserialize)]
121struct DownloadLatestExtensionPathParams {
122 extension_id: String,
123}
124
125#[derive(Debug, Deserialize)]
126struct DownloadLatestExtensionQueryParams {
127 min_schema_version: Option<i32>,
128 max_schema_version: Option<i32>,
129 min_wasm_api_version: Option<SemanticVersion>,
130 max_wasm_api_version: Option<SemanticVersion>,
131}
132
133async fn download_latest_extension(
134 Extension(app): Extension<Arc<AppState>>,
135 Path(params): Path<DownloadLatestExtensionPathParams>,
136 Query(query): Query<DownloadLatestExtensionQueryParams>,
137) -> Result<Redirect> {
138 let constraints = maybe!({
139 let min_schema_version = query.min_schema_version?;
140 let max_schema_version = query.max_schema_version?;
141 let min_wasm_api_version = query.min_wasm_api_version?;
142 let max_wasm_api_version = query.max_wasm_api_version?;
143
144 Some(ExtensionVersionConstraints {
145 schema_versions: min_schema_version..=max_schema_version,
146 wasm_api_versions: min_wasm_api_version..=max_wasm_api_version,
147 })
148 });
149
150 let extension = app
151 .db
152 .get_extension(¶ms.extension_id, constraints.as_ref())
153 .await?
154 .ok_or_else(|| anyhow!("unknown extension"))?;
155 download_extension(
156 Extension(app),
157 Path(DownloadExtensionParams {
158 extension_id: params.extension_id,
159 version: extension.manifest.version.to_string(),
160 }),
161 )
162 .await
163}
164
165#[derive(Debug, Deserialize)]
166struct DownloadExtensionParams {
167 extension_id: String,
168 version: String,
169}
170
171async fn download_extension(
172 Extension(app): Extension<Arc<AppState>>,
173 Path(params): Path<DownloadExtensionParams>,
174) -> Result<Redirect> {
175 let Some((blob_store_client, bucket)) = app
176 .blob_store_client
177 .clone()
178 .zip(app.config.blob_store_bucket.clone())
179 else {
180 Err(Error::Http(
181 StatusCode::NOT_IMPLEMENTED,
182 "not supported".into(),
183 ))?
184 };
185
186 let DownloadExtensionParams {
187 extension_id,
188 version,
189 } = params;
190
191 let version_exists = app
192 .db
193 .record_extension_download(&extension_id, &version)
194 .await?;
195
196 if !version_exists {
197 Err(Error::Http(
198 StatusCode::NOT_FOUND,
199 "unknown extension version".into(),
200 ))?;
201 }
202
203 let url = blob_store_client
204 .get_object()
205 .bucket(bucket)
206 .key(format!(
207 "extensions/{extension_id}/{version}/archive.tar.gz"
208 ))
209 .presigned(PresigningConfig::expires_in(EXTENSION_DOWNLOAD_URL_LIFETIME).unwrap())
210 .await
211 .map_err(|e| anyhow!("failed to create presigned extension download url {e}"))?;
212
213 Ok(Redirect::temporary(url.uri()))
214}
215
/// Delay between successive blob-store scans for newly published extensions (five minutes).
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(300);
/// Validity period of a presigned extension download URL (three minutes).
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(180);
218
219pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
220 let Some(blob_store_client) = app_state.blob_store_client.clone() else {
221 log::info!("no blob store client");
222 return;
223 };
224 let Some(blob_store_bucket) = app_state.config.blob_store_bucket.clone() else {
225 log::info!("no blob store bucket");
226 return;
227 };
228
229 let executor = app_state.executor.clone();
230 executor.spawn_detached({
231 let executor = executor.clone();
232 async move {
233 loop {
234 fetch_extensions_from_blob_store(
235 &blob_store_client,
236 &blob_store_bucket,
237 &app_state,
238 )
239 .await
240 .log_err();
241 executor.sleep(EXTENSION_FETCH_INTERVAL).await;
242 }
243 }
244 });
245}
246
247async fn fetch_extensions_from_blob_store(
248 blob_store_client: &aws_sdk_s3::Client,
249 blob_store_bucket: &String,
250 app_state: &Arc<AppState>,
251) -> anyhow::Result<()> {
252 log::info!("fetching extensions from blob store");
253
254 let mut next_marker = None;
255 let mut published_versions = HashMap::<String, Vec<String>>::default();
256
257 loop {
258 let list = blob_store_client
259 .list_objects()
260 .bucket(blob_store_bucket)
261 .prefix("extensions/")
262 .set_marker(next_marker.clone())
263 .send()
264 .await?;
265 let objects = list.contents.unwrap_or_default();
266 log::info!("fetched {} object(s) from blob store", objects.len());
267
268 for object in &objects {
269 let Some(key) = object.key.as_ref() else {
270 continue;
271 };
272 let mut parts = key.split('/');
273 let Some(_) = parts.next().filter(|part| *part == "extensions") else {
274 continue;
275 };
276 let Some(extension_id) = parts.next() else {
277 continue;
278 };
279 let Some(version) = parts.next() else {
280 continue;
281 };
282 if parts.next() == Some("manifest.json") {
283 published_versions
284 .entry(extension_id.to_owned())
285 .or_default()
286 .push(version.to_owned());
287 }
288 }
289
290 if let (Some(true), Some(last_object)) = (list.is_truncated, objects.last()) {
291 next_marker.clone_from(&last_object.key);
292 } else {
293 break;
294 }
295 }
296
297 log::info!("found {} published extensions", published_versions.len());
298
299 let known_versions = app_state.db.get_known_extension_versions().await?;
300
301 let mut new_versions = HashMap::<&str, Vec<NewExtensionVersion>>::default();
302 let empty = Vec::new();
303 for (extension_id, published_versions) in &published_versions {
304 let known_versions = known_versions.get(extension_id).unwrap_or(&empty);
305
306 for published_version in published_versions {
307 if known_versions
308 .binary_search_by_key(&published_version, |known_version| known_version)
309 .is_err()
310 {
311 if let Some(extension) = fetch_extension_manifest(
312 blob_store_client,
313 blob_store_bucket,
314 &extension_id,
315 &published_version,
316 )
317 .await
318 .log_err()
319 {
320 new_versions
321 .entry(&extension_id)
322 .or_default()
323 .push(extension);
324 }
325 }
326 }
327 }
328
329 app_state
330 .db
331 .insert_extension_versions(&new_versions)
332 .await?;
333
334 log::info!(
335 "fetched {} new extensions from blob store",
336 new_versions.values().map(|v| v.len()).sum::<usize>()
337 );
338
339 Ok(())
340}
341
342async fn fetch_extension_manifest(
343 blob_store_client: &aws_sdk_s3::Client,
344 blob_store_bucket: &String,
345 extension_id: &str,
346 version: &str,
347) -> Result<NewExtensionVersion, anyhow::Error> {
348 let object = blob_store_client
349 .get_object()
350 .bucket(blob_store_bucket)
351 .key(format!("extensions/{extension_id}/{version}/manifest.json"))
352 .send()
353 .await?;
354 let manifest_bytes = object
355 .body
356 .collect()
357 .await
358 .map(|data| data.into_bytes())
359 .with_context(|| {
360 format!("failed to download manifest for extension {extension_id} version {version}")
361 })?
362 .to_vec();
363 let manifest =
364 serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
365 format!(
366 "invalid manifest for extension {extension_id} version {version}: {}",
367 String::from_utf8_lossy(&manifest_bytes)
368 )
369 })?;
370 let published_at = object.last_modified.ok_or_else(|| {
371 anyhow!("missing last modified timestamp for extension {extension_id} version {version}")
372 })?;
373 let published_at = time::OffsetDateTime::from_unix_timestamp_nanos(published_at.as_nanos())?;
374 let published_at = PrimitiveDateTime::new(published_at.date(), published_at.time());
375 let version = semver::Version::parse(&manifest.version).with_context(|| {
376 format!("invalid version for extension {extension_id} version {version}")
377 })?;
378 Ok(NewExtensionVersion {
379 name: manifest.name,
380 version,
381 description: manifest.description.unwrap_or_default(),
382 authors: manifest.authors,
383 repository: manifest.repository,
384 schema_version: manifest.schema_version.unwrap_or(0),
385 wasm_api_version: manifest.wasm_api_version,
386 published_at,
387 })
388}