extensions.rs

  1use crate::db::ExtensionVersionConstraints;
  2use crate::{db::NewExtensionVersion, AppState, Error, Result};
  3use anyhow::{anyhow, Context as _};
  4use aws_sdk_s3::presigning::PresigningConfig;
  5use axum::{
  6    extract::{Path, Query},
  7    http::StatusCode,
  8    response::Redirect,
  9    routing::get,
 10    Extension, Json, Router,
 11};
 12use collections::HashMap;
 13use rpc::{ExtensionApiManifest, ExtensionMetadata, GetExtensionsResponse};
 14use semantic_version::SemanticVersion;
 15use serde::Deserialize;
 16use std::{sync::Arc, time::Duration};
 17use time::PrimitiveDateTime;
 18use util::{maybe, ResultExt};
 19
 20pub fn router() -> Router {
 21    Router::new()
 22        .route("/extensions", get(get_extensions))
 23        .route("/extensions/updates", get(get_extension_updates))
 24        .route("/extensions/:extension_id", get(get_extension_versions))
 25        .route(
 26            "/extensions/:extension_id/download",
 27            get(download_latest_extension),
 28        )
 29        .route(
 30            "/extensions/:extension_id/:version/download",
 31            get(download_extension),
 32        )
 33}
 34
 35#[derive(Debug, Deserialize)]
 36struct GetExtensionsParams {
 37    filter: Option<String>,
 38    #[serde(default)]
 39    max_schema_version: i32,
 40}
 41
 42async fn get_extensions(
 43    Extension(app): Extension<Arc<AppState>>,
 44    Query(params): Query<GetExtensionsParams>,
 45) -> Result<Json<GetExtensionsResponse>> {
 46    let mut extensions = app
 47        .db
 48        .get_extensions(params.filter.as_deref(), params.max_schema_version, 500)
 49        .await?;
 50
 51    if let Some(query) = params.filter.as_deref() {
 52        let count = extensions.len();
 53        tracing::info!(query, count, "extension_search")
 54    }
 55
 56    if let Some(filter) = params.filter.as_deref() {
 57        let mut exact_match: Option<ExtensionMetadata> = None;
 58        extensions.retain(|extension| {
 59            exact_match = Some(extension.clone());
 60            extension.id.as_ref() != &filter.to_lowercase()
 61        });
 62        if exact_match == None {
 63            exact_match = app
 64                .db
 65                .get_extensions_by_ids(&[&filter.to_lowercase()], None)
 66                .await?
 67                .first()
 68                .cloned();
 69        }
 70        extensions.splice(0..0, exact_match);
 71    };
 72
 73    Ok(Json(GetExtensionsResponse { data: extensions }))
 74}
 75
 76#[derive(Debug, Deserialize)]
 77struct GetExtensionUpdatesParams {
 78    ids: String,
 79    min_schema_version: i32,
 80    max_schema_version: i32,
 81    min_wasm_api_version: SemanticVersion,
 82    max_wasm_api_version: SemanticVersion,
 83}
 84
 85async fn get_extension_updates(
 86    Extension(app): Extension<Arc<AppState>>,
 87    Query(params): Query<GetExtensionUpdatesParams>,
 88) -> Result<Json<GetExtensionsResponse>> {
 89    let constraints = ExtensionVersionConstraints {
 90        schema_versions: params.min_schema_version..=params.max_schema_version,
 91        wasm_api_versions: params.min_wasm_api_version..=params.max_wasm_api_version,
 92    };
 93
 94    let extension_ids = params.ids.split(',').map(|s| s.trim()).collect::<Vec<_>>();
 95
 96    let extensions = app
 97        .db
 98        .get_extensions_by_ids(&extension_ids, Some(&constraints))
 99        .await?;
100
101    Ok(Json(GetExtensionsResponse { data: extensions }))
102}
103
104#[derive(Debug, Deserialize)]
105struct GetExtensionVersionsParams {
106    extension_id: String,
107}
108
109async fn get_extension_versions(
110    Extension(app): Extension<Arc<AppState>>,
111    Path(params): Path<GetExtensionVersionsParams>,
112) -> Result<Json<GetExtensionsResponse>> {
113    let extension_versions = app.db.get_extension_versions(&params.extension_id).await?;
114
115    Ok(Json(GetExtensionsResponse {
116        data: extension_versions,
117    }))
118}
119
120#[derive(Debug, Deserialize)]
121struct DownloadLatestExtensionPathParams {
122    extension_id: String,
123}
124
125#[derive(Debug, Deserialize)]
126struct DownloadLatestExtensionQueryParams {
127    min_schema_version: Option<i32>,
128    max_schema_version: Option<i32>,
129    min_wasm_api_version: Option<SemanticVersion>,
130    max_wasm_api_version: Option<SemanticVersion>,
131}
132
133async fn download_latest_extension(
134    Extension(app): Extension<Arc<AppState>>,
135    Path(params): Path<DownloadLatestExtensionPathParams>,
136    Query(query): Query<DownloadLatestExtensionQueryParams>,
137) -> Result<Redirect> {
138    let constraints = maybe!({
139        let min_schema_version = query.min_schema_version?;
140        let max_schema_version = query.max_schema_version?;
141        let min_wasm_api_version = query.min_wasm_api_version?;
142        let max_wasm_api_version = query.max_wasm_api_version?;
143
144        Some(ExtensionVersionConstraints {
145            schema_versions: min_schema_version..=max_schema_version,
146            wasm_api_versions: min_wasm_api_version..=max_wasm_api_version,
147        })
148    });
149
150    let extension = app
151        .db
152        .get_extension(&params.extension_id, constraints.as_ref())
153        .await?
154        .ok_or_else(|| anyhow!("unknown extension"))?;
155    download_extension(
156        Extension(app),
157        Path(DownloadExtensionParams {
158            extension_id: params.extension_id,
159            version: extension.manifest.version.to_string(),
160        }),
161    )
162    .await
163}
164
165#[derive(Debug, Deserialize)]
166struct DownloadExtensionParams {
167    extension_id: String,
168    version: String,
169}
170
171async fn download_extension(
172    Extension(app): Extension<Arc<AppState>>,
173    Path(params): Path<DownloadExtensionParams>,
174) -> Result<Redirect> {
175    let Some((blob_store_client, bucket)) = app
176        .blob_store_client
177        .clone()
178        .zip(app.config.blob_store_bucket.clone())
179    else {
180        Err(Error::Http(
181            StatusCode::NOT_IMPLEMENTED,
182            "not supported".into(),
183        ))?
184    };
185
186    let DownloadExtensionParams {
187        extension_id,
188        version,
189    } = params;
190
191    let version_exists = app
192        .db
193        .record_extension_download(&extension_id, &version)
194        .await?;
195
196    if !version_exists {
197        Err(Error::Http(
198            StatusCode::NOT_FOUND,
199            "unknown extension version".into(),
200        ))?;
201    }
202
203    let url = blob_store_client
204        .get_object()
205        .bucket(bucket)
206        .key(format!(
207            "extensions/{extension_id}/{version}/archive.tar.gz"
208        ))
209        .presigned(PresigningConfig::expires_in(EXTENSION_DOWNLOAD_URL_LIFETIME).unwrap())
210        .await
211        .map_err(|e| anyhow!("failed to create presigned extension download url {e}"))?;
212
213    Ok(Redirect::temporary(url.uri()))
214}
215
216const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(5 * 60);
217const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(3 * 60);
218
219pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
220    let Some(blob_store_client) = app_state.blob_store_client.clone() else {
221        log::info!("no blob store client");
222        return;
223    };
224    let Some(blob_store_bucket) = app_state.config.blob_store_bucket.clone() else {
225        log::info!("no blob store bucket");
226        return;
227    };
228
229    let executor = app_state.executor.clone();
230    executor.spawn_detached({
231        let executor = executor.clone();
232        async move {
233            loop {
234                fetch_extensions_from_blob_store(
235                    &blob_store_client,
236                    &blob_store_bucket,
237                    &app_state,
238                )
239                .await
240                .log_err();
241                executor.sleep(EXTENSION_FETCH_INTERVAL).await;
242            }
243        }
244    });
245}
246
247async fn fetch_extensions_from_blob_store(
248    blob_store_client: &aws_sdk_s3::Client,
249    blob_store_bucket: &String,
250    app_state: &Arc<AppState>,
251) -> anyhow::Result<()> {
252    log::info!("fetching extensions from blob store");
253
254    let mut next_marker = None;
255    let mut published_versions = HashMap::<String, Vec<String>>::default();
256
257    loop {
258        let list = blob_store_client
259            .list_objects()
260            .bucket(blob_store_bucket)
261            .prefix("extensions/")
262            .set_marker(next_marker.clone())
263            .send()
264            .await?;
265        let objects = list.contents.unwrap_or_default();
266        log::info!("fetched {} object(s) from blob store", objects.len());
267
268        for object in &objects {
269            let Some(key) = object.key.as_ref() else {
270                continue;
271            };
272            let mut parts = key.split('/');
273            let Some(_) = parts.next().filter(|part| *part == "extensions") else {
274                continue;
275            };
276            let Some(extension_id) = parts.next() else {
277                continue;
278            };
279            let Some(version) = parts.next() else {
280                continue;
281            };
282            if parts.next() == Some("manifest.json") {
283                published_versions
284                    .entry(extension_id.to_owned())
285                    .or_default()
286                    .push(version.to_owned());
287            }
288        }
289
290        if let (Some(true), Some(last_object)) = (list.is_truncated, objects.last()) {
291            next_marker.clone_from(&last_object.key);
292        } else {
293            break;
294        }
295    }
296
297    log::info!("found {} published extensions", published_versions.len());
298
299    let known_versions = app_state.db.get_known_extension_versions().await?;
300
301    let mut new_versions = HashMap::<&str, Vec<NewExtensionVersion>>::default();
302    let empty = Vec::new();
303    for (extension_id, published_versions) in &published_versions {
304        let known_versions = known_versions.get(extension_id).unwrap_or(&empty);
305
306        for published_version in published_versions {
307            if known_versions
308                .binary_search_by_key(&published_version, |known_version| known_version)
309                .is_err()
310            {
311                if let Some(extension) = fetch_extension_manifest(
312                    blob_store_client,
313                    blob_store_bucket,
314                    &extension_id,
315                    &published_version,
316                )
317                .await
318                .log_err()
319                {
320                    new_versions
321                        .entry(&extension_id)
322                        .or_default()
323                        .push(extension);
324                }
325            }
326        }
327    }
328
329    app_state
330        .db
331        .insert_extension_versions(&new_versions)
332        .await?;
333
334    log::info!(
335        "fetched {} new extensions from blob store",
336        new_versions.values().map(|v| v.len()).sum::<usize>()
337    );
338
339    Ok(())
340}
341
342async fn fetch_extension_manifest(
343    blob_store_client: &aws_sdk_s3::Client,
344    blob_store_bucket: &String,
345    extension_id: &str,
346    version: &str,
347) -> Result<NewExtensionVersion, anyhow::Error> {
348    let object = blob_store_client
349        .get_object()
350        .bucket(blob_store_bucket)
351        .key(format!("extensions/{extension_id}/{version}/manifest.json"))
352        .send()
353        .await?;
354    let manifest_bytes = object
355        .body
356        .collect()
357        .await
358        .map(|data| data.into_bytes())
359        .with_context(|| {
360            format!("failed to download manifest for extension {extension_id} version {version}")
361        })?
362        .to_vec();
363    let manifest =
364        serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
365            format!(
366                "invalid manifest for extension {extension_id} version {version}: {}",
367                String::from_utf8_lossy(&manifest_bytes)
368            )
369        })?;
370    let published_at = object.last_modified.ok_or_else(|| {
371        anyhow!("missing last modified timestamp for extension {extension_id} version {version}")
372    })?;
373    let published_at = time::OffsetDateTime::from_unix_timestamp_nanos(published_at.as_nanos())?;
374    let published_at = PrimitiveDateTime::new(published_at.date(), published_at.time());
375    let version = semver::Version::parse(&manifest.version).with_context(|| {
376        format!("invalid version for extension {extension_id} version {version}")
377    })?;
378    Ok(NewExtensionVersion {
379        name: manifest.name,
380        version,
381        description: manifest.description.unwrap_or_default(),
382        authors: manifest.authors,
383        repository: manifest.repository,
384        schema_version: manifest.schema_version.unwrap_or(0),
385        wasm_api_version: manifest.wasm_api_version,
386        published_at,
387    })
388}