extensions.rs

  1use crate::db::ExtensionVersionConstraints;
  2use crate::{db::NewExtensionVersion, AppState, Error, Result};
  3use anyhow::{anyhow, Context as _};
  4use aws_sdk_s3::presigning::PresigningConfig;
  5use axum::{
  6    extract::{Path, Query},
  7    http::StatusCode,
  8    response::Redirect,
  9    routing::get,
 10    Extension, Json, Router,
 11};
 12use collections::HashMap;
 13use rpc::{ExtensionApiManifest, GetExtensionsResponse};
 14use semantic_version::SemanticVersion;
 15use serde::Deserialize;
 16use std::{sync::Arc, time::Duration};
 17use time::PrimitiveDateTime;
 18use util::{maybe, ResultExt};
 19
 20pub fn router() -> Router {
 21    Router::new()
 22        .route("/extensions", get(get_extensions))
 23        .route("/extensions/updates", get(get_extension_updates))
 24        .route("/extensions/:extension_id", get(get_extension_versions))
 25        .route(
 26            "/extensions/:extension_id/download",
 27            get(download_latest_extension),
 28        )
 29        .route(
 30            "/extensions/:extension_id/:version/download",
 31            get(download_extension),
 32        )
 33}
 34
 35#[derive(Debug, Deserialize)]
 36struct GetExtensionsParams {
 37    filter: Option<String>,
 38    #[serde(default)]
 39    ids: Option<String>,
 40    #[serde(default)]
 41    max_schema_version: i32,
 42}
 43
 44async fn get_extensions(
 45    Extension(app): Extension<Arc<AppState>>,
 46    Query(params): Query<GetExtensionsParams>,
 47) -> Result<Json<GetExtensionsResponse>> {
 48    let extension_ids = params
 49        .ids
 50        .as_ref()
 51        .map(|s| s.split(',').map(|s| s.trim()).collect::<Vec<_>>());
 52
 53    let extensions = if let Some(extension_ids) = extension_ids {
 54        app.db.get_extensions_by_ids(&extension_ids, None).await?
 55    } else {
 56        let result = app
 57            .db
 58            .get_extensions(params.filter.as_deref(), params.max_schema_version, 500)
 59            .await?;
 60
 61        if let Some(query) = params.filter.as_deref() {
 62            let count = result.len();
 63            tracing::info!(query, count, "extension_search")
 64        }
 65
 66        result
 67    };
 68
 69    Ok(Json(GetExtensionsResponse { data: extensions }))
 70}
 71
 72#[derive(Debug, Deserialize)]
 73struct GetExtensionUpdatesParams {
 74    ids: String,
 75    min_schema_version: i32,
 76    max_schema_version: i32,
 77    min_wasm_api_version: SemanticVersion,
 78    max_wasm_api_version: SemanticVersion,
 79}
 80
 81async fn get_extension_updates(
 82    Extension(app): Extension<Arc<AppState>>,
 83    Query(params): Query<GetExtensionUpdatesParams>,
 84) -> Result<Json<GetExtensionsResponse>> {
 85    let constraints = ExtensionVersionConstraints {
 86        schema_versions: params.min_schema_version..=params.max_schema_version,
 87        wasm_api_versions: params.min_wasm_api_version..=params.max_wasm_api_version,
 88    };
 89
 90    let extension_ids = params.ids.split(',').map(|s| s.trim()).collect::<Vec<_>>();
 91
 92    let extensions = app
 93        .db
 94        .get_extensions_by_ids(&extension_ids, Some(&constraints))
 95        .await?;
 96
 97    Ok(Json(GetExtensionsResponse { data: extensions }))
 98}
 99
100#[derive(Debug, Deserialize)]
101struct GetExtensionVersionsParams {
102    extension_id: String,
103}
104
105async fn get_extension_versions(
106    Extension(app): Extension<Arc<AppState>>,
107    Path(params): Path<GetExtensionVersionsParams>,
108) -> Result<Json<GetExtensionsResponse>> {
109    let extension_versions = app.db.get_extension_versions(&params.extension_id).await?;
110
111    Ok(Json(GetExtensionsResponse {
112        data: extension_versions,
113    }))
114}
115
116#[derive(Debug, Deserialize)]
117struct DownloadLatestExtensionPathParams {
118    extension_id: String,
119}
120
121#[derive(Debug, Deserialize)]
122struct DownloadLatestExtensionQueryParams {
123    min_schema_version: Option<i32>,
124    max_schema_version: Option<i32>,
125    min_wasm_api_version: Option<SemanticVersion>,
126    max_wasm_api_version: Option<SemanticVersion>,
127}
128
129async fn download_latest_extension(
130    Extension(app): Extension<Arc<AppState>>,
131    Path(params): Path<DownloadLatestExtensionPathParams>,
132    Query(query): Query<DownloadLatestExtensionQueryParams>,
133) -> Result<Redirect> {
134    let constraints = maybe!({
135        let min_schema_version = query.min_schema_version?;
136        let max_schema_version = query.max_schema_version?;
137        let min_wasm_api_version = query.min_wasm_api_version?;
138        let max_wasm_api_version = query.max_wasm_api_version?;
139
140        Some(ExtensionVersionConstraints {
141            schema_versions: min_schema_version..=max_schema_version,
142            wasm_api_versions: min_wasm_api_version..=max_wasm_api_version,
143        })
144    });
145
146    let extension = app
147        .db
148        .get_extension(&params.extension_id, constraints.as_ref())
149        .await?
150        .ok_or_else(|| anyhow!("unknown extension"))?;
151    download_extension(
152        Extension(app),
153        Path(DownloadExtensionParams {
154            extension_id: params.extension_id,
155            version: extension.manifest.version.to_string(),
156        }),
157    )
158    .await
159}
160
161#[derive(Debug, Deserialize)]
162struct DownloadExtensionParams {
163    extension_id: String,
164    version: String,
165}
166
167async fn download_extension(
168    Extension(app): Extension<Arc<AppState>>,
169    Path(params): Path<DownloadExtensionParams>,
170) -> Result<Redirect> {
171    let Some((blob_store_client, bucket)) = app
172        .blob_store_client
173        .clone()
174        .zip(app.config.blob_store_bucket.clone())
175    else {
176        Err(Error::Http(
177            StatusCode::NOT_IMPLEMENTED,
178            "not supported".into(),
179        ))?
180    };
181
182    let DownloadExtensionParams {
183        extension_id,
184        version,
185    } = params;
186
187    let version_exists = app
188        .db
189        .record_extension_download(&extension_id, &version)
190        .await?;
191
192    if !version_exists {
193        Err(Error::Http(
194            StatusCode::NOT_FOUND,
195            "unknown extension version".into(),
196        ))?;
197    }
198
199    let url = blob_store_client
200        .get_object()
201        .bucket(bucket)
202        .key(format!(
203            "extensions/{extension_id}/{version}/archive.tar.gz"
204        ))
205        .presigned(PresigningConfig::expires_in(EXTENSION_DOWNLOAD_URL_LIFETIME).unwrap())
206        .await
207        .map_err(|e| anyhow!("failed to create presigned extension download url {e}"))?;
208
209    Ok(Redirect::temporary(url.uri()))
210}
211
/// Pause between two consecutive scans of the blob store by the background
/// task started in `fetch_extensions_from_blob_store_periodically` (5 minutes).
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(5 * 60);

/// Expiry of the presigned archive URL handed out by `download_extension`
/// (3 minutes).
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(3 * 60);
214
215pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
216    let Some(blob_store_client) = app_state.blob_store_client.clone() else {
217        log::info!("no blob store client");
218        return;
219    };
220    let Some(blob_store_bucket) = app_state.config.blob_store_bucket.clone() else {
221        log::info!("no blob store bucket");
222        return;
223    };
224
225    let executor = app_state.executor.clone();
226    executor.spawn_detached({
227        let executor = executor.clone();
228        async move {
229            loop {
230                fetch_extensions_from_blob_store(
231                    &blob_store_client,
232                    &blob_store_bucket,
233                    &app_state,
234                )
235                .await
236                .log_err();
237                executor.sleep(EXTENSION_FETCH_INTERVAL).await;
238            }
239        }
240    });
241}
242
243async fn fetch_extensions_from_blob_store(
244    blob_store_client: &aws_sdk_s3::Client,
245    blob_store_bucket: &String,
246    app_state: &Arc<AppState>,
247) -> anyhow::Result<()> {
248    log::info!("fetching extensions from blob store");
249
250    let mut next_marker = None;
251    let mut published_versions = HashMap::<String, Vec<String>>::default();
252
253    loop {
254        let list = blob_store_client
255            .list_objects()
256            .bucket(blob_store_bucket)
257            .prefix("extensions/")
258            .set_marker(next_marker.clone())
259            .send()
260            .await?;
261        let objects = list.contents.unwrap_or_default();
262        log::info!("fetched {} object(s) from blob store", objects.len());
263
264        for object in &objects {
265            let Some(key) = object.key.as_ref() else {
266                continue;
267            };
268            let mut parts = key.split('/');
269            let Some(_) = parts.next().filter(|part| *part == "extensions") else {
270                continue;
271            };
272            let Some(extension_id) = parts.next() else {
273                continue;
274            };
275            let Some(version) = parts.next() else {
276                continue;
277            };
278            if parts.next() == Some("manifest.json") {
279                published_versions
280                    .entry(extension_id.to_owned())
281                    .or_default()
282                    .push(version.to_owned());
283            }
284        }
285
286        if let (Some(true), Some(last_object)) = (list.is_truncated, objects.last()) {
287            next_marker.clone_from(&last_object.key);
288        } else {
289            break;
290        }
291    }
292
293    log::info!("found {} published extensions", published_versions.len());
294
295    let known_versions = app_state.db.get_known_extension_versions().await?;
296
297    let mut new_versions = HashMap::<&str, Vec<NewExtensionVersion>>::default();
298    let empty = Vec::new();
299    for (extension_id, published_versions) in &published_versions {
300        let known_versions = known_versions.get(extension_id).unwrap_or(&empty);
301
302        for published_version in published_versions {
303            if known_versions
304                .binary_search_by_key(&published_version, |known_version| known_version)
305                .is_err()
306            {
307                if let Some(extension) = fetch_extension_manifest(
308                    blob_store_client,
309                    blob_store_bucket,
310                    &extension_id,
311                    &published_version,
312                )
313                .await
314                .log_err()
315                {
316                    new_versions
317                        .entry(&extension_id)
318                        .or_default()
319                        .push(extension);
320                }
321            }
322        }
323    }
324
325    app_state
326        .db
327        .insert_extension_versions(&new_versions)
328        .await?;
329
330    log::info!(
331        "fetched {} new extensions from blob store",
332        new_versions.values().map(|v| v.len()).sum::<usize>()
333    );
334
335    Ok(())
336}
337
338async fn fetch_extension_manifest(
339    blob_store_client: &aws_sdk_s3::Client,
340    blob_store_bucket: &String,
341    extension_id: &str,
342    version: &str,
343) -> Result<NewExtensionVersion, anyhow::Error> {
344    let object = blob_store_client
345        .get_object()
346        .bucket(blob_store_bucket)
347        .key(format!("extensions/{extension_id}/{version}/manifest.json"))
348        .send()
349        .await?;
350    let manifest_bytes = object
351        .body
352        .collect()
353        .await
354        .map(|data| data.into_bytes())
355        .with_context(|| {
356            format!("failed to download manifest for extension {extension_id} version {version}")
357        })?
358        .to_vec();
359    let manifest =
360        serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
361            format!(
362                "invalid manifest for extension {extension_id} version {version}: {}",
363                String::from_utf8_lossy(&manifest_bytes)
364            )
365        })?;
366    let published_at = object.last_modified.ok_or_else(|| {
367        anyhow!("missing last modified timestamp for extension {extension_id} version {version}")
368    })?;
369    let published_at = time::OffsetDateTime::from_unix_timestamp_nanos(published_at.as_nanos())?;
370    let published_at = PrimitiveDateTime::new(published_at.date(), published_at.time());
371    let version = semver::Version::parse(&manifest.version).with_context(|| {
372        format!("invalid version for extension {extension_id} version {version}")
373    })?;
374    Ok(NewExtensionVersion {
375        name: manifest.name,
376        version,
377        description: manifest.description.unwrap_or_default(),
378        authors: manifest.authors,
379        repository: manifest.repository,
380        schema_version: manifest.schema_version.unwrap_or(0),
381        wasm_api_version: manifest.wasm_api_version,
382        published_at,
383    })
384}