1use crate::{db::NewExtensionVersion, AppState, Error, Result};
2use anyhow::{anyhow, Context as _};
3use aws_sdk_s3::presigning::PresigningConfig;
4use axum::{
5 extract::{Path, Query},
6 http::StatusCode,
7 response::Redirect,
8 routing::get,
9 Extension, Json, Router,
10};
11use collections::HashMap;
12use rpc::{ExtensionApiManifest, GetExtensionsResponse};
13use serde::Deserialize;
14use std::{sync::Arc, time::Duration};
15use time::PrimitiveDateTime;
16use util::ResultExt;
17
18pub fn router() -> Router {
19 Router::new()
20 .route("/extensions", get(get_extensions))
21 .route("/extensions/:extension_id", get(get_extension_versions))
22 .route(
23 "/extensions/:extension_id/download",
24 get(download_latest_extension),
25 )
26 .route(
27 "/extensions/:extension_id/:version/download",
28 get(download_extension),
29 )
30}
31
32#[derive(Debug, Deserialize)]
33struct GetExtensionsParams {
34 filter: Option<String>,
35 #[serde(default)]
36 ids: Option<String>,
37 #[serde(default)]
38 max_schema_version: i32,
39}
40
41async fn get_extensions(
42 Extension(app): Extension<Arc<AppState>>,
43 Query(params): Query<GetExtensionsParams>,
44) -> Result<Json<GetExtensionsResponse>> {
45 let extension_ids = params
46 .ids
47 .as_ref()
48 .map(|s| s.split(',').map(|s| s.trim()).collect::<Vec<_>>());
49
50 let extensions = if let Some(extension_ids) = extension_ids {
51 app.db
52 .get_extensions_by_ids(&extension_ids, params.max_schema_version)
53 .await?
54 } else {
55 app.db
56 .get_extensions(params.filter.as_deref(), params.max_schema_version, 500)
57 .await?
58 };
59
60 Ok(Json(GetExtensionsResponse { data: extensions }))
61}
62
63#[derive(Debug, Deserialize)]
64struct GetExtensionVersionsParams {
65 extension_id: String,
66}
67
68async fn get_extension_versions(
69 Extension(app): Extension<Arc<AppState>>,
70 Path(params): Path<GetExtensionVersionsParams>,
71) -> Result<Json<GetExtensionsResponse>> {
72 let extension_versions = app.db.get_extension_versions(¶ms.extension_id).await?;
73
74 Ok(Json(GetExtensionsResponse {
75 data: extension_versions,
76 }))
77}
78
79#[derive(Debug, Deserialize)]
80struct DownloadLatestExtensionParams {
81 extension_id: String,
82}
83
84async fn download_latest_extension(
85 Extension(app): Extension<Arc<AppState>>,
86 Path(params): Path<DownloadLatestExtensionParams>,
87) -> Result<Redirect> {
88 let extension = app
89 .db
90 .get_extension(¶ms.extension_id)
91 .await?
92 .ok_or_else(|| anyhow!("unknown extension"))?;
93 download_extension(
94 Extension(app),
95 Path(DownloadExtensionParams {
96 extension_id: params.extension_id,
97 version: extension.manifest.version.to_string(),
98 }),
99 )
100 .await
101}
102
103#[derive(Debug, Deserialize)]
104struct DownloadExtensionParams {
105 extension_id: String,
106 version: String,
107}
108
109async fn download_extension(
110 Extension(app): Extension<Arc<AppState>>,
111 Path(params): Path<DownloadExtensionParams>,
112) -> Result<Redirect> {
113 let Some((blob_store_client, bucket)) = app
114 .blob_store_client
115 .clone()
116 .zip(app.config.blob_store_bucket.clone())
117 else {
118 Err(Error::Http(
119 StatusCode::NOT_IMPLEMENTED,
120 "not supported".into(),
121 ))?
122 };
123
124 let DownloadExtensionParams {
125 extension_id,
126 version,
127 } = params;
128
129 let version_exists = app
130 .db
131 .record_extension_download(&extension_id, &version)
132 .await?;
133
134 if !version_exists {
135 Err(Error::Http(
136 StatusCode::NOT_FOUND,
137 "unknown extension version".into(),
138 ))?;
139 }
140
141 let url = blob_store_client
142 .get_object()
143 .bucket(bucket)
144 .key(format!(
145 "extensions/{extension_id}/{version}/archive.tar.gz"
146 ))
147 .presigned(PresigningConfig::expires_in(EXTENSION_DOWNLOAD_URL_LIFETIME).unwrap())
148 .await
149 .map_err(|e| anyhow!("failed to create presigned extension download url {e}"))?;
150
151 Ok(Redirect::temporary(url.uri()))
152}
153
/// Delay between successive blob-store scans for newly published extension versions.
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(300);
/// Expiry of the presigned archive URL handed out by the download endpoints.
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(180);
156
157pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
158 let Some(blob_store_client) = app_state.blob_store_client.clone() else {
159 log::info!("no blob store client");
160 return;
161 };
162 let Some(blob_store_bucket) = app_state.config.blob_store_bucket.clone() else {
163 log::info!("no blob store bucket");
164 return;
165 };
166
167 let executor = app_state.executor.clone();
168 executor.spawn_detached({
169 let executor = executor.clone();
170 async move {
171 loop {
172 fetch_extensions_from_blob_store(
173 &blob_store_client,
174 &blob_store_bucket,
175 &app_state,
176 )
177 .await
178 .log_err();
179 executor.sleep(EXTENSION_FETCH_INTERVAL).await;
180 }
181 }
182 });
183}
184
185async fn fetch_extensions_from_blob_store(
186 blob_store_client: &aws_sdk_s3::Client,
187 blob_store_bucket: &String,
188 app_state: &Arc<AppState>,
189) -> anyhow::Result<()> {
190 log::info!("fetching extensions from blob store");
191
192 let list = blob_store_client
193 .list_objects()
194 .bucket(blob_store_bucket)
195 .prefix("extensions/")
196 .send()
197 .await?;
198
199 let objects = list.contents.unwrap_or_default();
200
201 let mut published_versions = HashMap::<&str, Vec<&str>>::default();
202 for object in &objects {
203 let Some(key) = object.key.as_ref() else {
204 continue;
205 };
206 let mut parts = key.split('/');
207 let Some(_) = parts.next().filter(|part| *part == "extensions") else {
208 continue;
209 };
210 let Some(extension_id) = parts.next() else {
211 continue;
212 };
213 let Some(version) = parts.next() else {
214 continue;
215 };
216 if parts.next() == Some("manifest.json") {
217 published_versions
218 .entry(extension_id)
219 .or_default()
220 .push(version);
221 }
222 }
223
224 let known_versions = app_state.db.get_known_extension_versions().await?;
225
226 let mut new_versions = HashMap::<&str, Vec<NewExtensionVersion>>::default();
227 let empty = Vec::new();
228 for (extension_id, published_versions) in published_versions {
229 let known_versions = known_versions.get(extension_id).unwrap_or(&empty);
230
231 for published_version in published_versions {
232 if known_versions
233 .binary_search_by_key(&published_version, String::as_str)
234 .is_err()
235 {
236 if let Some(extension) = fetch_extension_manifest(
237 blob_store_client,
238 blob_store_bucket,
239 extension_id,
240 published_version,
241 )
242 .await
243 .log_err()
244 {
245 new_versions
246 .entry(extension_id)
247 .or_default()
248 .push(extension);
249 }
250 }
251 }
252 }
253
254 app_state
255 .db
256 .insert_extension_versions(&new_versions)
257 .await?;
258
259 log::info!(
260 "fetched {} new extensions from blob store",
261 new_versions.values().map(|v| v.len()).sum::<usize>()
262 );
263
264 Ok(())
265}
266
267async fn fetch_extension_manifest(
268 blob_store_client: &aws_sdk_s3::Client,
269 blob_store_bucket: &String,
270 extension_id: &str,
271 version: &str,
272) -> Result<NewExtensionVersion, anyhow::Error> {
273 let object = blob_store_client
274 .get_object()
275 .bucket(blob_store_bucket)
276 .key(format!("extensions/{extension_id}/{version}/manifest.json"))
277 .send()
278 .await?;
279 let manifest_bytes = object
280 .body
281 .collect()
282 .await
283 .map(|data| data.into_bytes())
284 .with_context(|| {
285 format!("failed to download manifest for extension {extension_id} version {version}")
286 })?
287 .to_vec();
288 let manifest =
289 serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
290 format!(
291 "invalid manifest for extension {extension_id} version {version}: {}",
292 String::from_utf8_lossy(&manifest_bytes)
293 )
294 })?;
295 let published_at = object.last_modified.ok_or_else(|| {
296 anyhow!("missing last modified timestamp for extension {extension_id} version {version}")
297 })?;
298 let published_at = time::OffsetDateTime::from_unix_timestamp_nanos(published_at.as_nanos())?;
299 let published_at = PrimitiveDateTime::new(published_at.date(), published_at.time());
300 let version = semver::Version::parse(&manifest.version).with_context(|| {
301 format!("invalid version for extension {extension_id} version {version}")
302 })?;
303 Ok(NewExtensionVersion {
304 name: manifest.name,
305 version,
306 description: manifest.description.unwrap_or_default(),
307 authors: manifest.authors,
308 repository: manifest.repository,
309 schema_version: manifest.schema_version.unwrap_or(0),
310 wasm_api_version: manifest.wasm_api_version,
311 published_at,
312 })
313}