extensions.rs

  1use crate::db::ExtensionVersionConstraints;
  2use crate::{db::NewExtensionVersion, AppState, Error, Result};
  3use anyhow::{anyhow, Context as _};
  4use aws_sdk_s3::presigning::PresigningConfig;
  5use axum::{
  6    extract::{Path, Query},
  7    http::StatusCode,
  8    response::Redirect,
  9    routing::get,
 10    Extension, Json, Router,
 11};
 12use collections::HashMap;
 13use rpc::{ExtensionApiManifest, GetExtensionsResponse};
 14use semantic_version::SemanticVersion;
 15use serde::Deserialize;
 16use std::{sync::Arc, time::Duration};
 17use time::PrimitiveDateTime;
 18use util::{maybe, ResultExt};
 19
 20pub fn router() -> Router {
 21    Router::new()
 22        .route("/extensions", get(get_extensions))
 23        .route("/extensions/updates", get(get_extension_updates))
 24        .route("/extensions/:extension_id", get(get_extension_versions))
 25        .route(
 26            "/extensions/:extension_id/download",
 27            get(download_latest_extension),
 28        )
 29        .route(
 30            "/extensions/:extension_id/:version/download",
 31            get(download_extension),
 32        )
 33}
 34
 35#[derive(Debug, Deserialize)]
 36struct GetExtensionsParams {
 37    filter: Option<String>,
 38    #[serde(default)]
 39    ids: Option<String>,
 40    #[serde(default)]
 41    max_schema_version: i32,
 42}
 43
 44async fn get_extensions(
 45    Extension(app): Extension<Arc<AppState>>,
 46    Query(params): Query<GetExtensionsParams>,
 47) -> Result<Json<GetExtensionsResponse>> {
 48    let extension_ids = params
 49        .ids
 50        .as_ref()
 51        .map(|s| s.split(',').map(|s| s.trim()).collect::<Vec<_>>());
 52
 53    let extensions = if let Some(extension_ids) = extension_ids {
 54        app.db.get_extensions_by_ids(&extension_ids, None).await?
 55    } else {
 56        app.db
 57            .get_extensions(params.filter.as_deref(), params.max_schema_version, 500)
 58            .await?
 59    };
 60
 61    Ok(Json(GetExtensionsResponse { data: extensions }))
 62}
 63
 64#[derive(Debug, Deserialize)]
 65struct GetExtensionUpdatesParams {
 66    ids: String,
 67    min_schema_version: i32,
 68    max_schema_version: i32,
 69    min_wasm_api_version: SemanticVersion,
 70    max_wasm_api_version: SemanticVersion,
 71}
 72
 73async fn get_extension_updates(
 74    Extension(app): Extension<Arc<AppState>>,
 75    Query(params): Query<GetExtensionUpdatesParams>,
 76) -> Result<Json<GetExtensionsResponse>> {
 77    let constraints = ExtensionVersionConstraints {
 78        schema_versions: params.min_schema_version..=params.max_schema_version,
 79        wasm_api_versions: params.min_wasm_api_version..=params.max_wasm_api_version,
 80    };
 81
 82    let extension_ids = params.ids.split(',').map(|s| s.trim()).collect::<Vec<_>>();
 83
 84    let extensions = app
 85        .db
 86        .get_extensions_by_ids(&extension_ids, Some(&constraints))
 87        .await?;
 88
 89    Ok(Json(GetExtensionsResponse { data: extensions }))
 90}
 91
 92#[derive(Debug, Deserialize)]
 93struct GetExtensionVersionsParams {
 94    extension_id: String,
 95}
 96
 97async fn get_extension_versions(
 98    Extension(app): Extension<Arc<AppState>>,
 99    Path(params): Path<GetExtensionVersionsParams>,
100) -> Result<Json<GetExtensionsResponse>> {
101    let extension_versions = app.db.get_extension_versions(&params.extension_id).await?;
102
103    Ok(Json(GetExtensionsResponse {
104        data: extension_versions,
105    }))
106}
107
108#[derive(Debug, Deserialize)]
109struct DownloadLatestExtensionPathParams {
110    extension_id: String,
111}
112
113#[derive(Debug, Deserialize)]
114struct DownloadLatestExtensionQueryParams {
115    min_schema_version: Option<i32>,
116    max_schema_version: Option<i32>,
117    min_wasm_api_version: Option<SemanticVersion>,
118    max_wasm_api_version: Option<SemanticVersion>,
119}
120
121async fn download_latest_extension(
122    Extension(app): Extension<Arc<AppState>>,
123    Path(params): Path<DownloadLatestExtensionPathParams>,
124    Query(query): Query<DownloadLatestExtensionQueryParams>,
125) -> Result<Redirect> {
126    let constraints = maybe!({
127        let min_schema_version = query.min_schema_version?;
128        let max_schema_version = query.max_schema_version?;
129        let min_wasm_api_version = query.min_wasm_api_version?;
130        let max_wasm_api_version = query.max_wasm_api_version?;
131
132        Some(ExtensionVersionConstraints {
133            schema_versions: min_schema_version..=max_schema_version,
134            wasm_api_versions: min_wasm_api_version..=max_wasm_api_version,
135        })
136    });
137
138    let extension = app
139        .db
140        .get_extension(&params.extension_id, constraints.as_ref())
141        .await?
142        .ok_or_else(|| anyhow!("unknown extension"))?;
143    download_extension(
144        Extension(app),
145        Path(DownloadExtensionParams {
146            extension_id: params.extension_id,
147            version: extension.manifest.version.to_string(),
148        }),
149    )
150    .await
151}
152
153#[derive(Debug, Deserialize)]
154struct DownloadExtensionParams {
155    extension_id: String,
156    version: String,
157}
158
159async fn download_extension(
160    Extension(app): Extension<Arc<AppState>>,
161    Path(params): Path<DownloadExtensionParams>,
162) -> Result<Redirect> {
163    let Some((blob_store_client, bucket)) = app
164        .blob_store_client
165        .clone()
166        .zip(app.config.blob_store_bucket.clone())
167    else {
168        Err(Error::Http(
169            StatusCode::NOT_IMPLEMENTED,
170            "not supported".into(),
171        ))?
172    };
173
174    let DownloadExtensionParams {
175        extension_id,
176        version,
177    } = params;
178
179    let version_exists = app
180        .db
181        .record_extension_download(&extension_id, &version)
182        .await?;
183
184    if !version_exists {
185        Err(Error::Http(
186            StatusCode::NOT_FOUND,
187            "unknown extension version".into(),
188        ))?;
189    }
190
191    let url = blob_store_client
192        .get_object()
193        .bucket(bucket)
194        .key(format!(
195            "extensions/{extension_id}/{version}/archive.tar.gz"
196        ))
197        .presigned(PresigningConfig::expires_in(EXTENSION_DOWNLOAD_URL_LIFETIME).unwrap())
198        .await
199        .map_err(|e| anyhow!("failed to create presigned extension download url {e}"))?;
200
201    Ok(Redirect::temporary(url.uri()))
202}
203
/// Validity window of the presigned archive URLs handed out by the download endpoints.
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(60 * 3);
/// Pause between consecutive blob-store scans in the background sync loop.
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(60 * 5);
206
207pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
208    let Some(blob_store_client) = app_state.blob_store_client.clone() else {
209        log::info!("no blob store client");
210        return;
211    };
212    let Some(blob_store_bucket) = app_state.config.blob_store_bucket.clone() else {
213        log::info!("no blob store bucket");
214        return;
215    };
216
217    let executor = app_state.executor.clone();
218    executor.spawn_detached({
219        let executor = executor.clone();
220        async move {
221            loop {
222                fetch_extensions_from_blob_store(
223                    &blob_store_client,
224                    &blob_store_bucket,
225                    &app_state,
226                )
227                .await
228                .log_err();
229                executor.sleep(EXTENSION_FETCH_INTERVAL).await;
230            }
231        }
232    });
233}
234
235async fn fetch_extensions_from_blob_store(
236    blob_store_client: &aws_sdk_s3::Client,
237    blob_store_bucket: &String,
238    app_state: &Arc<AppState>,
239) -> anyhow::Result<()> {
240    log::info!("fetching extensions from blob store");
241
242    let mut next_marker = None;
243    let mut published_versions = HashMap::<String, Vec<String>>::default();
244
245    loop {
246        let list = blob_store_client
247            .list_objects()
248            .bucket(blob_store_bucket)
249            .prefix("extensions/")
250            .set_marker(next_marker.clone())
251            .send()
252            .await?;
253        let objects = list.contents.unwrap_or_default();
254        log::info!("fetched {} object(s) from blob store", objects.len());
255
256        for object in &objects {
257            let Some(key) = object.key.as_ref() else {
258                continue;
259            };
260            let mut parts = key.split('/');
261            let Some(_) = parts.next().filter(|part| *part == "extensions") else {
262                continue;
263            };
264            let Some(extension_id) = parts.next() else {
265                continue;
266            };
267            let Some(version) = parts.next() else {
268                continue;
269            };
270            if parts.next() == Some("manifest.json") {
271                published_versions
272                    .entry(extension_id.to_owned())
273                    .or_default()
274                    .push(version.to_owned());
275            }
276        }
277
278        if let (Some(true), Some(last_object)) = (list.is_truncated, objects.last()) {
279            next_marker.clone_from(&last_object.key);
280        } else {
281            break;
282        }
283    }
284
285    log::info!("found {} published extensions", published_versions.len());
286
287    let known_versions = app_state.db.get_known_extension_versions().await?;
288
289    let mut new_versions = HashMap::<&str, Vec<NewExtensionVersion>>::default();
290    let empty = Vec::new();
291    for (extension_id, published_versions) in &published_versions {
292        let known_versions = known_versions.get(extension_id).unwrap_or(&empty);
293
294        for published_version in published_versions {
295            if known_versions
296                .binary_search_by_key(&published_version, |known_version| known_version)
297                .is_err()
298            {
299                if let Some(extension) = fetch_extension_manifest(
300                    blob_store_client,
301                    blob_store_bucket,
302                    &extension_id,
303                    &published_version,
304                )
305                .await
306                .log_err()
307                {
308                    new_versions
309                        .entry(&extension_id)
310                        .or_default()
311                        .push(extension);
312                }
313            }
314        }
315    }
316
317    app_state
318        .db
319        .insert_extension_versions(&new_versions)
320        .await?;
321
322    log::info!(
323        "fetched {} new extensions from blob store",
324        new_versions.values().map(|v| v.len()).sum::<usize>()
325    );
326
327    Ok(())
328}
329
330async fn fetch_extension_manifest(
331    blob_store_client: &aws_sdk_s3::Client,
332    blob_store_bucket: &String,
333    extension_id: &str,
334    version: &str,
335) -> Result<NewExtensionVersion, anyhow::Error> {
336    let object = blob_store_client
337        .get_object()
338        .bucket(blob_store_bucket)
339        .key(format!("extensions/{extension_id}/{version}/manifest.json"))
340        .send()
341        .await?;
342    let manifest_bytes = object
343        .body
344        .collect()
345        .await
346        .map(|data| data.into_bytes())
347        .with_context(|| {
348            format!("failed to download manifest for extension {extension_id} version {version}")
349        })?
350        .to_vec();
351    let manifest =
352        serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
353            format!(
354                "invalid manifest for extension {extension_id} version {version}: {}",
355                String::from_utf8_lossy(&manifest_bytes)
356            )
357        })?;
358    let published_at = object.last_modified.ok_or_else(|| {
359        anyhow!("missing last modified timestamp for extension {extension_id} version {version}")
360    })?;
361    let published_at = time::OffsetDateTime::from_unix_timestamp_nanos(published_at.as_nanos())?;
362    let published_at = PrimitiveDateTime::new(published_at.date(), published_at.time());
363    let version = semver::Version::parse(&manifest.version).with_context(|| {
364        format!("invalid version for extension {extension_id} version {version}")
365    })?;
366    Ok(NewExtensionVersion {
367        name: manifest.name,
368        version,
369        description: manifest.description.unwrap_or_default(),
370        authors: manifest.authors,
371        repository: manifest.repository,
372        schema_version: manifest.schema_version.unwrap_or(0),
373        wasm_api_version: manifest.wasm_api_version,
374        published_at,
375    })
376}