extensions.rs

  1use crate::{db::NewExtensionVersion, AppState, Error, Result};
  2use anyhow::{anyhow, Context as _};
  3use aws_sdk_s3::presigning::PresigningConfig;
  4use axum::{
  5    extract::{Path, Query},
  6    http::StatusCode,
  7    response::Redirect,
  8    routing::get,
  9    Extension, Json, Router,
 10};
 11use collections::HashMap;
 12use rpc::{ExtensionApiManifest, ExtensionMetadata};
 13use serde::{Deserialize, Serialize};
 14use std::{sync::Arc, time::Duration};
 15use time::PrimitiveDateTime;
 16use util::ResultExt;
 17
 18pub fn router() -> Router {
 19    Router::new()
 20        .route("/extensions", get(get_extensions))
 21        .route(
 22            "/extensions/:extension_id/download",
 23            get(download_latest_extension),
 24        )
 25        .route(
 26            "/extensions/:extension_id/:version/download",
 27            get(download_extension),
 28        )
 29}
 30
 31#[derive(Debug, Deserialize)]
 32struct GetExtensionsParams {
 33    filter: Option<String>,
 34    #[serde(default)]
 35    max_schema_version: i32,
 36}
 37
 38#[derive(Debug, Deserialize)]
 39struct DownloadLatestExtensionParams {
 40    extension_id: String,
 41}
 42
 43#[derive(Debug, Deserialize)]
 44struct DownloadExtensionParams {
 45    extension_id: String,
 46    version: String,
 47}
 48
 49#[derive(Debug, Serialize)]
 50struct GetExtensionsResponse {
 51    pub data: Vec<ExtensionMetadata>,
 52}
 53
 54async fn get_extensions(
 55    Extension(app): Extension<Arc<AppState>>,
 56    Query(params): Query<GetExtensionsParams>,
 57) -> Result<Json<GetExtensionsResponse>> {
 58    let extensions = app
 59        .db
 60        .get_extensions(params.filter.as_deref(), params.max_schema_version, 500)
 61        .await?;
 62    Ok(Json(GetExtensionsResponse { data: extensions }))
 63}
 64
 65async fn download_latest_extension(
 66    Extension(app): Extension<Arc<AppState>>,
 67    Path(params): Path<DownloadLatestExtensionParams>,
 68) -> Result<Redirect> {
 69    let extension = app
 70        .db
 71        .get_extension(&params.extension_id)
 72        .await?
 73        .ok_or_else(|| anyhow!("unknown extension"))?;
 74    download_extension(
 75        Extension(app),
 76        Path(DownloadExtensionParams {
 77            extension_id: params.extension_id,
 78            version: extension.manifest.version,
 79        }),
 80    )
 81    .await
 82}
 83
 84async fn download_extension(
 85    Extension(app): Extension<Arc<AppState>>,
 86    Path(params): Path<DownloadExtensionParams>,
 87) -> Result<Redirect> {
 88    let Some((blob_store_client, bucket)) = app
 89        .blob_store_client
 90        .clone()
 91        .zip(app.config.blob_store_bucket.clone())
 92    else {
 93        Err(Error::Http(
 94            StatusCode::NOT_IMPLEMENTED,
 95            "not supported".into(),
 96        ))?
 97    };
 98
 99    let DownloadExtensionParams {
100        extension_id,
101        version,
102    } = params;
103
104    let version_exists = app
105        .db
106        .record_extension_download(&extension_id, &version)
107        .await?;
108
109    if !version_exists {
110        Err(Error::Http(
111            StatusCode::NOT_FOUND,
112            "unknown extension version".into(),
113        ))?;
114    }
115
116    let url = blob_store_client
117        .get_object()
118        .bucket(bucket)
119        .key(format!(
120            "extensions/{extension_id}/{version}/archive.tar.gz"
121        ))
122        .presigned(PresigningConfig::expires_in(EXTENSION_DOWNLOAD_URL_LIFETIME).unwrap())
123        .await
124        .map_err(|e| anyhow!("failed to create presigned extension download url {e}"))?;
125
126    Ok(Redirect::temporary(url.uri()))
127}
128
/// Pause between two consecutive scans of the blob store for newly published
/// extension versions (5 minutes).
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(60 * 5);

/// How long a presigned extension download URL remains valid (3 minutes).
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(60 * 3);
131
132pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
133    let Some(blob_store_client) = app_state.blob_store_client.clone() else {
134        log::info!("no blob store client");
135        return;
136    };
137    let Some(blob_store_bucket) = app_state.config.blob_store_bucket.clone() else {
138        log::info!("no blob store bucket");
139        return;
140    };
141
142    let executor = app_state.executor.clone();
143    executor.spawn_detached({
144        let executor = executor.clone();
145        async move {
146            loop {
147                fetch_extensions_from_blob_store(
148                    &blob_store_client,
149                    &blob_store_bucket,
150                    &app_state,
151                )
152                .await
153                .log_err();
154                executor.sleep(EXTENSION_FETCH_INTERVAL).await;
155            }
156        }
157    });
158}
159
160async fn fetch_extensions_from_blob_store(
161    blob_store_client: &aws_sdk_s3::Client,
162    blob_store_bucket: &String,
163    app_state: &Arc<AppState>,
164) -> anyhow::Result<()> {
165    log::info!("fetching extensions from blob store");
166
167    let list = blob_store_client
168        .list_objects()
169        .bucket(blob_store_bucket)
170        .prefix("extensions/")
171        .send()
172        .await?;
173
174    let objects = list.contents.unwrap_or_default();
175
176    let mut published_versions = HashMap::<&str, Vec<&str>>::default();
177    for object in &objects {
178        let Some(key) = object.key.as_ref() else {
179            continue;
180        };
181        let mut parts = key.split('/');
182        let Some(_) = parts.next().filter(|part| *part == "extensions") else {
183            continue;
184        };
185        let Some(extension_id) = parts.next() else {
186            continue;
187        };
188        let Some(version) = parts.next() else {
189            continue;
190        };
191        if parts.next() == Some("manifest.json") {
192            published_versions
193                .entry(extension_id)
194                .or_default()
195                .push(version);
196        }
197    }
198
199    let known_versions = app_state.db.get_known_extension_versions().await?;
200
201    let mut new_versions = HashMap::<&str, Vec<NewExtensionVersion>>::default();
202    let empty = Vec::new();
203    for (extension_id, published_versions) in published_versions {
204        let known_versions = known_versions.get(extension_id).unwrap_or(&empty);
205
206        for published_version in published_versions {
207            if known_versions
208                .binary_search_by_key(&published_version, String::as_str)
209                .is_err()
210            {
211                if let Some(extension) = fetch_extension_manifest(
212                    blob_store_client,
213                    blob_store_bucket,
214                    extension_id,
215                    published_version,
216                )
217                .await
218                .log_err()
219                {
220                    new_versions
221                        .entry(extension_id)
222                        .or_default()
223                        .push(extension);
224                }
225            }
226        }
227    }
228
229    app_state
230        .db
231        .insert_extension_versions(&new_versions)
232        .await?;
233
234    log::info!(
235        "fetched {} new extensions from blob store",
236        new_versions.values().map(|v| v.len()).sum::<usize>()
237    );
238
239    Ok(())
240}
241
242async fn fetch_extension_manifest(
243    blob_store_client: &aws_sdk_s3::Client,
244    blob_store_bucket: &String,
245    extension_id: &str,
246    version: &str,
247) -> Result<NewExtensionVersion, anyhow::Error> {
248    let object = blob_store_client
249        .get_object()
250        .bucket(blob_store_bucket)
251        .key(format!("extensions/{extension_id}/{version}/manifest.json"))
252        .send()
253        .await?;
254    let manifest_bytes = object
255        .body
256        .collect()
257        .await
258        .map(|data| data.into_bytes())
259        .with_context(|| {
260            format!("failed to download manifest for extension {extension_id} version {version}")
261        })?
262        .to_vec();
263    let manifest =
264        serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
265            format!(
266                "invalid manifest for extension {extension_id} version {version}: {}",
267                String::from_utf8_lossy(&manifest_bytes)
268            )
269        })?;
270    let published_at = object.last_modified.ok_or_else(|| {
271        anyhow!("missing last modified timestamp for extension {extension_id} version {version}")
272    })?;
273    let published_at = time::OffsetDateTime::from_unix_timestamp_nanos(published_at.as_nanos())?;
274    let published_at = PrimitiveDateTime::new(published_at.date(), published_at.time());
275    let version = semver::Version::parse(&manifest.version).with_context(|| {
276        format!("invalid version for extension {extension_id} version {version}")
277    })?;
278    Ok(NewExtensionVersion {
279        name: manifest.name,
280        version,
281        description: manifest.description.unwrap_or_default(),
282        authors: manifest.authors,
283        repository: manifest.repository,
284        schema_version: manifest.schema_version.unwrap_or(0),
285        wasm_api_version: manifest.wasm_api_version,
286        published_at,
287    })
288}