extensions.rs

  1use crate::{db::NewExtensionVersion, AppState, Error, Result};
  2use anyhow::{anyhow, Context as _};
  3use aws_sdk_s3::presigning::PresigningConfig;
  4use axum::{
  5    extract::{Path, Query},
  6    http::StatusCode,
  7    response::Redirect,
  8    routing::get,
  9    Extension, Json, Router,
 10};
 11use collections::HashMap;
 12use rpc::{ExtensionApiManifest, GetExtensionsResponse};
 13use serde::Deserialize;
 14use std::{sync::Arc, time::Duration};
 15use time::PrimitiveDateTime;
 16use util::ResultExt;
 17
 18pub fn router() -> Router {
 19    Router::new()
 20        .route("/extensions", get(get_extensions))
 21        .route(
 22            "/extensions/:extension_id/download",
 23            get(download_latest_extension),
 24        )
 25        .route(
 26            "/extensions/:extension_id/:version/download",
 27            get(download_extension),
 28        )
 29}
 30
 31#[derive(Debug, Deserialize)]
 32struct GetExtensionsParams {
 33    filter: Option<String>,
 34    #[serde(default)]
 35    max_schema_version: i32,
 36}
 37
 38#[derive(Debug, Deserialize)]
 39struct DownloadLatestExtensionParams {
 40    extension_id: String,
 41}
 42
 43#[derive(Debug, Deserialize)]
 44struct DownloadExtensionParams {
 45    extension_id: String,
 46    version: String,
 47}
 48
 49async fn get_extensions(
 50    Extension(app): Extension<Arc<AppState>>,
 51    Query(params): Query<GetExtensionsParams>,
 52) -> Result<Json<GetExtensionsResponse>> {
 53    let extensions = app
 54        .db
 55        .get_extensions(params.filter.as_deref(), params.max_schema_version, 500)
 56        .await?;
 57    Ok(Json(GetExtensionsResponse { data: extensions }))
 58}
 59
 60async fn download_latest_extension(
 61    Extension(app): Extension<Arc<AppState>>,
 62    Path(params): Path<DownloadLatestExtensionParams>,
 63) -> Result<Redirect> {
 64    let extension = app
 65        .db
 66        .get_extension(&params.extension_id)
 67        .await?
 68        .ok_or_else(|| anyhow!("unknown extension"))?;
 69    download_extension(
 70        Extension(app),
 71        Path(DownloadExtensionParams {
 72            extension_id: params.extension_id,
 73            version: extension.manifest.version.to_string(),
 74        }),
 75    )
 76    .await
 77}
 78
 79async fn download_extension(
 80    Extension(app): Extension<Arc<AppState>>,
 81    Path(params): Path<DownloadExtensionParams>,
 82) -> Result<Redirect> {
 83    let Some((blob_store_client, bucket)) = app
 84        .blob_store_client
 85        .clone()
 86        .zip(app.config.blob_store_bucket.clone())
 87    else {
 88        Err(Error::Http(
 89            StatusCode::NOT_IMPLEMENTED,
 90            "not supported".into(),
 91        ))?
 92    };
 93
 94    let DownloadExtensionParams {
 95        extension_id,
 96        version,
 97    } = params;
 98
 99    let version_exists = app
100        .db
101        .record_extension_download(&extension_id, &version)
102        .await?;
103
104    if !version_exists {
105        Err(Error::Http(
106            StatusCode::NOT_FOUND,
107            "unknown extension version".into(),
108        ))?;
109    }
110
111    let url = blob_store_client
112        .get_object()
113        .bucket(bucket)
114        .key(format!(
115            "extensions/{extension_id}/{version}/archive.tar.gz"
116        ))
117        .presigned(PresigningConfig::expires_in(EXTENSION_DOWNLOAD_URL_LIFETIME).unwrap())
118        .await
119        .map_err(|e| anyhow!("failed to create presigned extension download url {e}"))?;
120
121    Ok(Redirect::temporary(url.uri()))
122}
123
/// Delay between consecutive blob store scans in the background fetch loop (5 minutes).
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(300);
/// Validity window of the presigned archive download URLs (3 minutes).
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(180);
126
127pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
128    let Some(blob_store_client) = app_state.blob_store_client.clone() else {
129        log::info!("no blob store client");
130        return;
131    };
132    let Some(blob_store_bucket) = app_state.config.blob_store_bucket.clone() else {
133        log::info!("no blob store bucket");
134        return;
135    };
136
137    let executor = app_state.executor.clone();
138    executor.spawn_detached({
139        let executor = executor.clone();
140        async move {
141            loop {
142                fetch_extensions_from_blob_store(
143                    &blob_store_client,
144                    &blob_store_bucket,
145                    &app_state,
146                )
147                .await
148                .log_err();
149                executor.sleep(EXTENSION_FETCH_INTERVAL).await;
150            }
151        }
152    });
153}
154
155async fn fetch_extensions_from_blob_store(
156    blob_store_client: &aws_sdk_s3::Client,
157    blob_store_bucket: &String,
158    app_state: &Arc<AppState>,
159) -> anyhow::Result<()> {
160    log::info!("fetching extensions from blob store");
161
162    let list = blob_store_client
163        .list_objects()
164        .bucket(blob_store_bucket)
165        .prefix("extensions/")
166        .send()
167        .await?;
168
169    let objects = list.contents.unwrap_or_default();
170
171    let mut published_versions = HashMap::<&str, Vec<&str>>::default();
172    for object in &objects {
173        let Some(key) = object.key.as_ref() else {
174            continue;
175        };
176        let mut parts = key.split('/');
177        let Some(_) = parts.next().filter(|part| *part == "extensions") else {
178            continue;
179        };
180        let Some(extension_id) = parts.next() else {
181            continue;
182        };
183        let Some(version) = parts.next() else {
184            continue;
185        };
186        if parts.next() == Some("manifest.json") {
187            published_versions
188                .entry(extension_id)
189                .or_default()
190                .push(version);
191        }
192    }
193
194    let known_versions = app_state.db.get_known_extension_versions().await?;
195
196    let mut new_versions = HashMap::<&str, Vec<NewExtensionVersion>>::default();
197    let empty = Vec::new();
198    for (extension_id, published_versions) in published_versions {
199        let known_versions = known_versions.get(extension_id).unwrap_or(&empty);
200
201        for published_version in published_versions {
202            if known_versions
203                .binary_search_by_key(&published_version, String::as_str)
204                .is_err()
205            {
206                if let Some(extension) = fetch_extension_manifest(
207                    blob_store_client,
208                    blob_store_bucket,
209                    extension_id,
210                    published_version,
211                )
212                .await
213                .log_err()
214                {
215                    new_versions
216                        .entry(extension_id)
217                        .or_default()
218                        .push(extension);
219                }
220            }
221        }
222    }
223
224    app_state
225        .db
226        .insert_extension_versions(&new_versions)
227        .await?;
228
229    log::info!(
230        "fetched {} new extensions from blob store",
231        new_versions.values().map(|v| v.len()).sum::<usize>()
232    );
233
234    Ok(())
235}
236
237async fn fetch_extension_manifest(
238    blob_store_client: &aws_sdk_s3::Client,
239    blob_store_bucket: &String,
240    extension_id: &str,
241    version: &str,
242) -> Result<NewExtensionVersion, anyhow::Error> {
243    let object = blob_store_client
244        .get_object()
245        .bucket(blob_store_bucket)
246        .key(format!("extensions/{extension_id}/{version}/manifest.json"))
247        .send()
248        .await?;
249    let manifest_bytes = object
250        .body
251        .collect()
252        .await
253        .map(|data| data.into_bytes())
254        .with_context(|| {
255            format!("failed to download manifest for extension {extension_id} version {version}")
256        })?
257        .to_vec();
258    let manifest =
259        serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
260            format!(
261                "invalid manifest for extension {extension_id} version {version}: {}",
262                String::from_utf8_lossy(&manifest_bytes)
263            )
264        })?;
265    let published_at = object.last_modified.ok_or_else(|| {
266        anyhow!("missing last modified timestamp for extension {extension_id} version {version}")
267    })?;
268    let published_at = time::OffsetDateTime::from_unix_timestamp_nanos(published_at.as_nanos())?;
269    let published_at = PrimitiveDateTime::new(published_at.date(), published_at.time());
270    let version = semver::Version::parse(&manifest.version).with_context(|| {
271        format!("invalid version for extension {extension_id} version {version}")
272    })?;
273    Ok(NewExtensionVersion {
274        name: manifest.name,
275        version,
276        description: manifest.description.unwrap_or_default(),
277        authors: manifest.authors,
278        repository: manifest.repository,
279        schema_version: manifest.schema_version.unwrap_or(0),
280        wasm_api_version: manifest.wasm_api_version,
281        published_at,
282    })
283}