extensions.rs

  1use crate::{db::NewExtensionVersion, AppState, Error, Result};
  2use anyhow::{anyhow, Context as _};
  3use aws_sdk_s3::presigning::PresigningConfig;
  4use axum::{
  5    extract::{Path, Query},
  6    http::StatusCode,
  7    response::Redirect,
  8    routing::get,
  9    Extension, Json, Router,
 10};
 11use collections::HashMap;
 12use rpc::{ExtensionApiManifest, GetExtensionsResponse};
 13use serde::Deserialize;
 14use std::{sync::Arc, time::Duration};
 15use time::PrimitiveDateTime;
 16use util::ResultExt;
 17
 18pub fn router() -> Router {
 19    Router::new()
 20        .route("/extensions", get(get_extensions))
 21        .route("/extensions/:extension_id", get(get_extension_versions))
 22        .route(
 23            "/extensions/:extension_id/download",
 24            get(download_latest_extension),
 25        )
 26        .route(
 27            "/extensions/:extension_id/:version/download",
 28            get(download_extension),
 29        )
 30}
 31
 32#[derive(Debug, Deserialize)]
 33struct GetExtensionsParams {
 34    filter: Option<String>,
 35    #[serde(default)]
 36    ids: Option<String>,
 37    #[serde(default)]
 38    max_schema_version: i32,
 39}
 40
 41async fn get_extensions(
 42    Extension(app): Extension<Arc<AppState>>,
 43    Query(params): Query<GetExtensionsParams>,
 44) -> Result<Json<GetExtensionsResponse>> {
 45    let extension_ids = params
 46        .ids
 47        .as_ref()
 48        .map(|s| s.split(',').map(|s| s.trim()).collect::<Vec<_>>());
 49
 50    let extensions = if let Some(extension_ids) = extension_ids {
 51        app.db
 52            .get_extensions_by_ids(&extension_ids, params.max_schema_version)
 53            .await?
 54    } else {
 55        app.db
 56            .get_extensions(params.filter.as_deref(), params.max_schema_version, 500)
 57            .await?
 58    };
 59
 60    Ok(Json(GetExtensionsResponse { data: extensions }))
 61}
 62
 63#[derive(Debug, Deserialize)]
 64struct GetExtensionVersionsParams {
 65    extension_id: String,
 66}
 67
 68async fn get_extension_versions(
 69    Extension(app): Extension<Arc<AppState>>,
 70    Path(params): Path<GetExtensionVersionsParams>,
 71) -> Result<Json<GetExtensionsResponse>> {
 72    let extension_versions = app.db.get_extension_versions(&params.extension_id).await?;
 73
 74    Ok(Json(GetExtensionsResponse {
 75        data: extension_versions,
 76    }))
 77}
 78
 79#[derive(Debug, Deserialize)]
 80struct DownloadLatestExtensionParams {
 81    extension_id: String,
 82}
 83
 84async fn download_latest_extension(
 85    Extension(app): Extension<Arc<AppState>>,
 86    Path(params): Path<DownloadLatestExtensionParams>,
 87) -> Result<Redirect> {
 88    let extension = app
 89        .db
 90        .get_extension(&params.extension_id)
 91        .await?
 92        .ok_or_else(|| anyhow!("unknown extension"))?;
 93    download_extension(
 94        Extension(app),
 95        Path(DownloadExtensionParams {
 96            extension_id: params.extension_id,
 97            version: extension.manifest.version.to_string(),
 98        }),
 99    )
100    .await
101}
102
103#[derive(Debug, Deserialize)]
104struct DownloadExtensionParams {
105    extension_id: String,
106    version: String,
107}
108
109async fn download_extension(
110    Extension(app): Extension<Arc<AppState>>,
111    Path(params): Path<DownloadExtensionParams>,
112) -> Result<Redirect> {
113    let Some((blob_store_client, bucket)) = app
114        .blob_store_client
115        .clone()
116        .zip(app.config.blob_store_bucket.clone())
117    else {
118        Err(Error::Http(
119            StatusCode::NOT_IMPLEMENTED,
120            "not supported".into(),
121        ))?
122    };
123
124    let DownloadExtensionParams {
125        extension_id,
126        version,
127    } = params;
128
129    let version_exists = app
130        .db
131        .record_extension_download(&extension_id, &version)
132        .await?;
133
134    if !version_exists {
135        Err(Error::Http(
136            StatusCode::NOT_FOUND,
137            "unknown extension version".into(),
138        ))?;
139    }
140
141    let url = blob_store_client
142        .get_object()
143        .bucket(bucket)
144        .key(format!(
145            "extensions/{extension_id}/{version}/archive.tar.gz"
146        ))
147        .presigned(PresigningConfig::expires_in(EXTENSION_DOWNLOAD_URL_LIFETIME).unwrap())
148        .await
149        .map_err(|e| anyhow!("failed to create presigned extension download url {e}"))?;
150
151    Ok(Redirect::temporary(url.uri()))
152}
153
/// Delay between successive blob-store scans for newly published extensions (five minutes).
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(300);
/// Expiry of the presigned archive URLs handed out by the download endpoints (three minutes).
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(180);
156
157pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
158    let Some(blob_store_client) = app_state.blob_store_client.clone() else {
159        log::info!("no blob store client");
160        return;
161    };
162    let Some(blob_store_bucket) = app_state.config.blob_store_bucket.clone() else {
163        log::info!("no blob store bucket");
164        return;
165    };
166
167    let executor = app_state.executor.clone();
168    executor.spawn_detached({
169        let executor = executor.clone();
170        async move {
171            loop {
172                fetch_extensions_from_blob_store(
173                    &blob_store_client,
174                    &blob_store_bucket,
175                    &app_state,
176                )
177                .await
178                .log_err();
179                executor.sleep(EXTENSION_FETCH_INTERVAL).await;
180            }
181        }
182    });
183}
184
185async fn fetch_extensions_from_blob_store(
186    blob_store_client: &aws_sdk_s3::Client,
187    blob_store_bucket: &String,
188    app_state: &Arc<AppState>,
189) -> anyhow::Result<()> {
190    log::info!("fetching extensions from blob store");
191
192    let list = blob_store_client
193        .list_objects()
194        .bucket(blob_store_bucket)
195        .prefix("extensions/")
196        .send()
197        .await?;
198
199    let objects = list.contents.unwrap_or_default();
200
201    let mut published_versions = HashMap::<&str, Vec<&str>>::default();
202    for object in &objects {
203        let Some(key) = object.key.as_ref() else {
204            continue;
205        };
206        let mut parts = key.split('/');
207        let Some(_) = parts.next().filter(|part| *part == "extensions") else {
208            continue;
209        };
210        let Some(extension_id) = parts.next() else {
211            continue;
212        };
213        let Some(version) = parts.next() else {
214            continue;
215        };
216        if parts.next() == Some("manifest.json") {
217            published_versions
218                .entry(extension_id)
219                .or_default()
220                .push(version);
221        }
222    }
223
224    let known_versions = app_state.db.get_known_extension_versions().await?;
225
226    let mut new_versions = HashMap::<&str, Vec<NewExtensionVersion>>::default();
227    let empty = Vec::new();
228    for (extension_id, published_versions) in published_versions {
229        let known_versions = known_versions.get(extension_id).unwrap_or(&empty);
230
231        for published_version in published_versions {
232            if known_versions
233                .binary_search_by_key(&published_version, String::as_str)
234                .is_err()
235            {
236                if let Some(extension) = fetch_extension_manifest(
237                    blob_store_client,
238                    blob_store_bucket,
239                    extension_id,
240                    published_version,
241                )
242                .await
243                .log_err()
244                {
245                    new_versions
246                        .entry(extension_id)
247                        .or_default()
248                        .push(extension);
249                }
250            }
251        }
252    }
253
254    app_state
255        .db
256        .insert_extension_versions(&new_versions)
257        .await?;
258
259    log::info!(
260        "fetched {} new extensions from blob store",
261        new_versions.values().map(|v| v.len()).sum::<usize>()
262    );
263
264    Ok(())
265}
266
267async fn fetch_extension_manifest(
268    blob_store_client: &aws_sdk_s3::Client,
269    blob_store_bucket: &String,
270    extension_id: &str,
271    version: &str,
272) -> Result<NewExtensionVersion, anyhow::Error> {
273    let object = blob_store_client
274        .get_object()
275        .bucket(blob_store_bucket)
276        .key(format!("extensions/{extension_id}/{version}/manifest.json"))
277        .send()
278        .await?;
279    let manifest_bytes = object
280        .body
281        .collect()
282        .await
283        .map(|data| data.into_bytes())
284        .with_context(|| {
285            format!("failed to download manifest for extension {extension_id} version {version}")
286        })?
287        .to_vec();
288    let manifest =
289        serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
290            format!(
291                "invalid manifest for extension {extension_id} version {version}: {}",
292                String::from_utf8_lossy(&manifest_bytes)
293            )
294        })?;
295    let published_at = object.last_modified.ok_or_else(|| {
296        anyhow!("missing last modified timestamp for extension {extension_id} version {version}")
297    })?;
298    let published_at = time::OffsetDateTime::from_unix_timestamp_nanos(published_at.as_nanos())?;
299    let published_at = PrimitiveDateTime::new(published_at.date(), published_at.time());
300    let version = semver::Version::parse(&manifest.version).with_context(|| {
301        format!("invalid version for extension {extension_id} version {version}")
302    })?;
303    Ok(NewExtensionVersion {
304        name: manifest.name,
305        version,
306        description: manifest.description.unwrap_or_default(),
307        authors: manifest.authors,
308        repository: manifest.repository,
309        schema_version: manifest.schema_version.unwrap_or(0),
310        wasm_api_version: manifest.wasm_api_version,
311        published_at,
312    })
313}