extensions.rs

  1use crate::db::ExtensionVersionConstraints;
  2use crate::{AppState, Error, Result, db::NewExtensionVersion};
  3use anyhow::Context as _;
  4use aws_sdk_s3::presigning::PresigningConfig;
  5use axum::{
  6    Extension, Json, Router,
  7    extract::{Path, Query},
  8    http::StatusCode,
  9    response::Redirect,
 10    routing::get,
 11};
 12use cloud_api_types::{ExtensionApiManifest, GetExtensionsResponse};
 13use collections::HashMap;
 14use semver::Version as SemanticVersion;
 15use serde::Deserialize;
 16use std::{sync::Arc, time::Duration};
 17use time::PrimitiveDateTime;
 18use util::{ResultExt, maybe};
 19
 20pub fn router() -> Router {
 21    Router::new()
 22        .route("/extensions/updates", get(get_extension_updates))
 23        .route("/extensions/:extension_id", get(get_extension_versions))
 24        .route(
 25            "/extensions/:extension_id/download",
 26            get(download_latest_extension),
 27        )
 28        .route(
 29            "/extensions/:extension_id/:version/download",
 30            get(download_extension),
 31        )
 32}
 33
 34#[derive(Debug, Deserialize)]
 35struct GetExtensionUpdatesParams {
 36    ids: String,
 37    min_schema_version: i32,
 38    max_schema_version: i32,
 39    min_wasm_api_version: semver::Version,
 40    max_wasm_api_version: semver::Version,
 41}
 42
 43async fn get_extension_updates(
 44    Extension(app): Extension<Arc<AppState>>,
 45    Query(params): Query<GetExtensionUpdatesParams>,
 46) -> Result<Json<GetExtensionsResponse>> {
 47    let constraints = ExtensionVersionConstraints {
 48        schema_versions: params.min_schema_version..=params.max_schema_version,
 49        wasm_api_versions: params.min_wasm_api_version..=params.max_wasm_api_version,
 50    };
 51
 52    let extension_ids = params.ids.split(',').map(|s| s.trim()).collect::<Vec<_>>();
 53
 54    let extensions = app
 55        .db
 56        .get_extensions_by_ids(&extension_ids, Some(&constraints))
 57        .await?;
 58
 59    Ok(Json(GetExtensionsResponse { data: extensions }))
 60}
 61
 62#[derive(Debug, Deserialize)]
 63struct GetExtensionVersionsParams {
 64    extension_id: String,
 65}
 66
 67async fn get_extension_versions(
 68    Extension(app): Extension<Arc<AppState>>,
 69    Path(params): Path<GetExtensionVersionsParams>,
 70) -> Result<Json<GetExtensionsResponse>> {
 71    let extension_versions = app.db.get_extension_versions(&params.extension_id).await?;
 72
 73    Ok(Json(GetExtensionsResponse {
 74        data: extension_versions,
 75    }))
 76}
 77
 78#[derive(Debug, Deserialize)]
 79struct DownloadLatestExtensionPathParams {
 80    extension_id: String,
 81}
 82
 83#[derive(Debug, Deserialize)]
 84struct DownloadLatestExtensionQueryParams {
 85    min_schema_version: Option<i32>,
 86    max_schema_version: Option<i32>,
 87    min_wasm_api_version: Option<SemanticVersion>,
 88    max_wasm_api_version: Option<SemanticVersion>,
 89}
 90
 91async fn download_latest_extension(
 92    Extension(app): Extension<Arc<AppState>>,
 93    Path(params): Path<DownloadLatestExtensionPathParams>,
 94    Query(query): Query<DownloadLatestExtensionQueryParams>,
 95) -> Result<Redirect> {
 96    let constraints = maybe!({
 97        let min_schema_version = query.min_schema_version?;
 98        let max_schema_version = query.max_schema_version?;
 99        let min_wasm_api_version = query.min_wasm_api_version?;
100        let max_wasm_api_version = query.max_wasm_api_version?;
101
102        Some(ExtensionVersionConstraints {
103            schema_versions: min_schema_version..=max_schema_version,
104            wasm_api_versions: min_wasm_api_version..=max_wasm_api_version,
105        })
106    });
107
108    let extension = app
109        .db
110        .get_extension(&params.extension_id, constraints.as_ref())
111        .await?
112        .context("unknown extension")?;
113    download_extension(
114        Extension(app),
115        Path(DownloadExtensionParams {
116            extension_id: params.extension_id,
117            version: extension.manifest.version.to_string(),
118        }),
119    )
120    .await
121}
122
123#[derive(Debug, Deserialize)]
124struct DownloadExtensionParams {
125    extension_id: String,
126    version: String,
127}
128
129async fn download_extension(
130    Extension(app): Extension<Arc<AppState>>,
131    Path(params): Path<DownloadExtensionParams>,
132) -> Result<Redirect> {
133    let Some((blob_store_client, bucket)) = app
134        .blob_store_client
135        .clone()
136        .zip(app.config.blob_store_bucket.clone())
137    else {
138        Err(Error::http(
139            StatusCode::NOT_IMPLEMENTED,
140            "not supported".into(),
141        ))?
142    };
143
144    let DownloadExtensionParams {
145        extension_id,
146        version,
147    } = params;
148
149    let version_exists = app
150        .db
151        .record_extension_download(&extension_id, &version)
152        .await?;
153
154    if !version_exists {
155        Err(Error::http(
156            StatusCode::NOT_FOUND,
157            "unknown extension version".into(),
158        ))?;
159    }
160
161    let url = blob_store_client
162        .get_object()
163        .bucket(bucket)
164        .key(format!(
165            "extensions/{extension_id}/{version}/archive.tar.gz"
166        ))
167        .presigned(PresigningConfig::expires_in(EXTENSION_DOWNLOAD_URL_LIFETIME).unwrap())
168        .await
169        .context("creating presigned extension download url")?;
170
171    Ok(Redirect::temporary(url.uri()))
172}
173
/// Pause between two background passes that sync newly published extensions
/// from the blob store (five minutes).
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(300);
/// Validity window of a presigned extension download URL (three minutes).
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(180);
176
177pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
178    let Some(blob_store_client) = app_state.blob_store_client.clone() else {
179        log::info!("no blob store client");
180        return;
181    };
182    let Some(blob_store_bucket) = app_state.config.blob_store_bucket.clone() else {
183        log::info!("no blob store bucket");
184        return;
185    };
186
187    let executor = app_state.executor.clone();
188    executor.spawn_detached({
189        let executor = executor.clone();
190        async move {
191            loop {
192                fetch_extensions_from_blob_store(
193                    &blob_store_client,
194                    &blob_store_bucket,
195                    &app_state,
196                )
197                .await
198                .log_err();
199                executor.sleep(EXTENSION_FETCH_INTERVAL).await;
200            }
201        }
202    });
203}
204
205async fn fetch_extensions_from_blob_store(
206    blob_store_client: &aws_sdk_s3::Client,
207    blob_store_bucket: &String,
208    app_state: &Arc<AppState>,
209) -> anyhow::Result<()> {
210    log::info!("fetching extensions from blob store");
211
212    let mut next_marker = None;
213    let mut published_versions = HashMap::<String, Vec<String>>::default();
214
215    loop {
216        let list = blob_store_client
217            .list_objects()
218            .bucket(blob_store_bucket)
219            .prefix("extensions/")
220            .set_marker(next_marker.clone())
221            .send()
222            .await?;
223        let objects = list.contents.unwrap_or_default();
224        log::info!("fetched {} object(s) from blob store", objects.len());
225
226        for object in &objects {
227            let Some(key) = object.key.as_ref() else {
228                continue;
229            };
230            let mut parts = key.split('/');
231            let Some(_) = parts.next().filter(|part| *part == "extensions") else {
232                continue;
233            };
234            let Some(extension_id) = parts.next() else {
235                continue;
236            };
237            let Some(version) = parts.next() else {
238                continue;
239            };
240            if parts.next() == Some("manifest.json") {
241                published_versions
242                    .entry(extension_id.to_owned())
243                    .or_default()
244                    .push(version.to_owned());
245            }
246        }
247
248        if let (Some(true), Some(last_object)) = (list.is_truncated, objects.last()) {
249            next_marker.clone_from(&last_object.key);
250        } else {
251            break;
252        }
253    }
254
255    log::info!("found {} published extensions", published_versions.len());
256
257    let known_versions = app_state.db.get_known_extension_versions().await?;
258
259    let mut new_versions = HashMap::<&str, Vec<NewExtensionVersion>>::default();
260    let empty = Vec::new();
261    for (extension_id, published_versions) in &published_versions {
262        let known_versions = known_versions.get(extension_id).unwrap_or(&empty);
263
264        for published_version in published_versions {
265            if known_versions
266                .binary_search_by_key(&published_version, |known_version| known_version)
267                .is_err()
268                && let Some(extension) = fetch_extension_manifest(
269                    blob_store_client,
270                    blob_store_bucket,
271                    extension_id,
272                    published_version,
273                )
274                .await
275                .log_err()
276            {
277                new_versions
278                    .entry(extension_id)
279                    .or_default()
280                    .push(extension);
281            }
282        }
283    }
284
285    app_state
286        .db
287        .insert_extension_versions(&new_versions)
288        .await?;
289
290    log::info!(
291        "fetched {} new extensions from blob store",
292        new_versions.values().map(|v| v.len()).sum::<usize>()
293    );
294
295    Ok(())
296}
297
298async fn fetch_extension_manifest(
299    blob_store_client: &aws_sdk_s3::Client,
300    blob_store_bucket: &String,
301    extension_id: &str,
302    version: &str,
303) -> anyhow::Result<NewExtensionVersion> {
304    let object = blob_store_client
305        .get_object()
306        .bucket(blob_store_bucket)
307        .key(format!("extensions/{extension_id}/{version}/manifest.json"))
308        .send()
309        .await?;
310    let manifest_bytes = object
311        .body
312        .collect()
313        .await
314        .map(|data| data.into_bytes())
315        .with_context(|| {
316            format!("failed to download manifest for extension {extension_id} version {version}")
317        })?
318        .to_vec();
319    let manifest =
320        serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
321            format!(
322                "invalid manifest for extension {extension_id} version {version}: {}",
323                String::from_utf8_lossy(&manifest_bytes)
324            )
325        })?;
326    let published_at = object.last_modified.with_context(|| {
327        format!("missing last modified timestamp for extension {extension_id} version {version}")
328    })?;
329    let published_at = time::OffsetDateTime::from_unix_timestamp_nanos(published_at.as_nanos())?;
330    let published_at = PrimitiveDateTime::new(published_at.date(), published_at.time());
331    let version = semver::Version::parse(&manifest.version).with_context(|| {
332        format!("invalid version for extension {extension_id} version {version}")
333    })?;
334    Ok(NewExtensionVersion {
335        name: manifest.name,
336        version,
337        description: manifest.description.unwrap_or_default(),
338        authors: manifest.authors,
339        repository: manifest.repository,
340        schema_version: manifest.schema_version.unwrap_or(0),
341        wasm_api_version: manifest.wasm_api_version,
342        provides: manifest.provides,
343        published_at,
344    })
345}