extensions.rs

  1use crate::db::ExtensionVersionConstraints;
  2use crate::{db::NewExtensionVersion, AppState, Error, Result};
  3use anyhow::{anyhow, Context as _};
  4use aws_sdk_s3::presigning::PresigningConfig;
  5use axum::{
  6    extract::{Path, Query},
  7    http::StatusCode,
  8    response::Redirect,
  9    routing::get,
 10    Extension, Json, Router,
 11};
 12use collections::HashMap;
 13use rpc::{ExtensionApiManifest, GetExtensionsResponse};
 14use semantic_version::SemanticVersion;
 15use serde::Deserialize;
 16use std::{sync::Arc, time::Duration};
 17use time::PrimitiveDateTime;
 18use util::{maybe, ResultExt};
 19
 20pub fn router() -> Router {
 21    Router::new()
 22        .route("/extensions", get(get_extensions))
 23        .route("/extensions/updates", get(get_extension_updates))
 24        .route("/extensions/:extension_id", get(get_extension_versions))
 25        .route(
 26            "/extensions/:extension_id/download",
 27            get(download_latest_extension),
 28        )
 29        .route(
 30            "/extensions/:extension_id/:version/download",
 31            get(download_extension),
 32        )
 33}
 34
 35#[derive(Debug, Deserialize)]
 36struct GetExtensionsParams {
 37    filter: Option<String>,
 38    #[serde(default)]
 39    max_schema_version: i32,
 40}
 41
 42async fn get_extensions(
 43    Extension(app): Extension<Arc<AppState>>,
 44    Query(params): Query<GetExtensionsParams>,
 45) -> Result<Json<GetExtensionsResponse>> {
 46    let extensions = app
 47        .db
 48        .get_extensions(params.filter.as_deref(), params.max_schema_version, 500)
 49        .await?;
 50
 51    if let Some(query) = params.filter.as_deref() {
 52        let count = extensions.len();
 53        tracing::info!(query, count, "extension_search")
 54    }
 55
 56    Ok(Json(GetExtensionsResponse { data: extensions }))
 57}
 58
 59#[derive(Debug, Deserialize)]
 60struct GetExtensionUpdatesParams {
 61    ids: String,
 62    min_schema_version: i32,
 63    max_schema_version: i32,
 64    min_wasm_api_version: SemanticVersion,
 65    max_wasm_api_version: SemanticVersion,
 66}
 67
 68async fn get_extension_updates(
 69    Extension(app): Extension<Arc<AppState>>,
 70    Query(params): Query<GetExtensionUpdatesParams>,
 71) -> Result<Json<GetExtensionsResponse>> {
 72    let constraints = ExtensionVersionConstraints {
 73        schema_versions: params.min_schema_version..=params.max_schema_version,
 74        wasm_api_versions: params.min_wasm_api_version..=params.max_wasm_api_version,
 75    };
 76
 77    let extension_ids = params.ids.split(',').map(|s| s.trim()).collect::<Vec<_>>();
 78
 79    let extensions = app
 80        .db
 81        .get_extensions_by_ids(&extension_ids, Some(&constraints))
 82        .await?;
 83
 84    Ok(Json(GetExtensionsResponse { data: extensions }))
 85}
 86
 87#[derive(Debug, Deserialize)]
 88struct GetExtensionVersionsParams {
 89    extension_id: String,
 90}
 91
 92async fn get_extension_versions(
 93    Extension(app): Extension<Arc<AppState>>,
 94    Path(params): Path<GetExtensionVersionsParams>,
 95) -> Result<Json<GetExtensionsResponse>> {
 96    let extension_versions = app.db.get_extension_versions(&params.extension_id).await?;
 97
 98    Ok(Json(GetExtensionsResponse {
 99        data: extension_versions,
100    }))
101}
102
103#[derive(Debug, Deserialize)]
104struct DownloadLatestExtensionPathParams {
105    extension_id: String,
106}
107
108#[derive(Debug, Deserialize)]
109struct DownloadLatestExtensionQueryParams {
110    min_schema_version: Option<i32>,
111    max_schema_version: Option<i32>,
112    min_wasm_api_version: Option<SemanticVersion>,
113    max_wasm_api_version: Option<SemanticVersion>,
114}
115
116async fn download_latest_extension(
117    Extension(app): Extension<Arc<AppState>>,
118    Path(params): Path<DownloadLatestExtensionPathParams>,
119    Query(query): Query<DownloadLatestExtensionQueryParams>,
120) -> Result<Redirect> {
121    let constraints = maybe!({
122        let min_schema_version = query.min_schema_version?;
123        let max_schema_version = query.max_schema_version?;
124        let min_wasm_api_version = query.min_wasm_api_version?;
125        let max_wasm_api_version = query.max_wasm_api_version?;
126
127        Some(ExtensionVersionConstraints {
128            schema_versions: min_schema_version..=max_schema_version,
129            wasm_api_versions: min_wasm_api_version..=max_wasm_api_version,
130        })
131    });
132
133    let extension = app
134        .db
135        .get_extension(&params.extension_id, constraints.as_ref())
136        .await?
137        .ok_or_else(|| anyhow!("unknown extension"))?;
138    download_extension(
139        Extension(app),
140        Path(DownloadExtensionParams {
141            extension_id: params.extension_id,
142            version: extension.manifest.version.to_string(),
143        }),
144    )
145    .await
146}
147
148#[derive(Debug, Deserialize)]
149struct DownloadExtensionParams {
150    extension_id: String,
151    version: String,
152}
153
154async fn download_extension(
155    Extension(app): Extension<Arc<AppState>>,
156    Path(params): Path<DownloadExtensionParams>,
157) -> Result<Redirect> {
158    let Some((blob_store_client, bucket)) = app
159        .blob_store_client
160        .clone()
161        .zip(app.config.blob_store_bucket.clone())
162    else {
163        Err(Error::Http(
164            StatusCode::NOT_IMPLEMENTED,
165            "not supported".into(),
166        ))?
167    };
168
169    let DownloadExtensionParams {
170        extension_id,
171        version,
172    } = params;
173
174    let version_exists = app
175        .db
176        .record_extension_download(&extension_id, &version)
177        .await?;
178
179    if !version_exists {
180        Err(Error::Http(
181            StatusCode::NOT_FOUND,
182            "unknown extension version".into(),
183        ))?;
184    }
185
186    let url = blob_store_client
187        .get_object()
188        .bucket(bucket)
189        .key(format!(
190            "extensions/{extension_id}/{version}/archive.tar.gz"
191        ))
192        .presigned(PresigningConfig::expires_in(EXTENSION_DOWNLOAD_URL_LIFETIME).unwrap())
193        .await
194        .map_err(|e| anyhow!("failed to create presigned extension download url {e}"))?;
195
196    Ok(Redirect::temporary(url.uri()))
197}
198
/// Delay between consecutive blob store sync passes (five minutes).
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(300);
/// Validity window of a presigned extension download URL (three minutes).
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(180);
201
202pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
203    let Some(blob_store_client) = app_state.blob_store_client.clone() else {
204        log::info!("no blob store client");
205        return;
206    };
207    let Some(blob_store_bucket) = app_state.config.blob_store_bucket.clone() else {
208        log::info!("no blob store bucket");
209        return;
210    };
211
212    let executor = app_state.executor.clone();
213    executor.spawn_detached({
214        let executor = executor.clone();
215        async move {
216            loop {
217                fetch_extensions_from_blob_store(
218                    &blob_store_client,
219                    &blob_store_bucket,
220                    &app_state,
221                )
222                .await
223                .log_err();
224                executor.sleep(EXTENSION_FETCH_INTERVAL).await;
225            }
226        }
227    });
228}
229
230async fn fetch_extensions_from_blob_store(
231    blob_store_client: &aws_sdk_s3::Client,
232    blob_store_bucket: &String,
233    app_state: &Arc<AppState>,
234) -> anyhow::Result<()> {
235    log::info!("fetching extensions from blob store");
236
237    let mut next_marker = None;
238    let mut published_versions = HashMap::<String, Vec<String>>::default();
239
240    loop {
241        let list = blob_store_client
242            .list_objects()
243            .bucket(blob_store_bucket)
244            .prefix("extensions/")
245            .set_marker(next_marker.clone())
246            .send()
247            .await?;
248        let objects = list.contents.unwrap_or_default();
249        log::info!("fetched {} object(s) from blob store", objects.len());
250
251        for object in &objects {
252            let Some(key) = object.key.as_ref() else {
253                continue;
254            };
255            let mut parts = key.split('/');
256            let Some(_) = parts.next().filter(|part| *part == "extensions") else {
257                continue;
258            };
259            let Some(extension_id) = parts.next() else {
260                continue;
261            };
262            let Some(version) = parts.next() else {
263                continue;
264            };
265            if parts.next() == Some("manifest.json") {
266                published_versions
267                    .entry(extension_id.to_owned())
268                    .or_default()
269                    .push(version.to_owned());
270            }
271        }
272
273        if let (Some(true), Some(last_object)) = (list.is_truncated, objects.last()) {
274            next_marker.clone_from(&last_object.key);
275        } else {
276            break;
277        }
278    }
279
280    log::info!("found {} published extensions", published_versions.len());
281
282    let known_versions = app_state.db.get_known_extension_versions().await?;
283
284    let mut new_versions = HashMap::<&str, Vec<NewExtensionVersion>>::default();
285    let empty = Vec::new();
286    for (extension_id, published_versions) in &published_versions {
287        let known_versions = known_versions.get(extension_id).unwrap_or(&empty);
288
289        for published_version in published_versions {
290            if known_versions
291                .binary_search_by_key(&published_version, |known_version| known_version)
292                .is_err()
293            {
294                if let Some(extension) = fetch_extension_manifest(
295                    blob_store_client,
296                    blob_store_bucket,
297                    &extension_id,
298                    &published_version,
299                )
300                .await
301                .log_err()
302                {
303                    new_versions
304                        .entry(&extension_id)
305                        .or_default()
306                        .push(extension);
307                }
308            }
309        }
310    }
311
312    app_state
313        .db
314        .insert_extension_versions(&new_versions)
315        .await?;
316
317    log::info!(
318        "fetched {} new extensions from blob store",
319        new_versions.values().map(|v| v.len()).sum::<usize>()
320    );
321
322    Ok(())
323}
324
325async fn fetch_extension_manifest(
326    blob_store_client: &aws_sdk_s3::Client,
327    blob_store_bucket: &String,
328    extension_id: &str,
329    version: &str,
330) -> Result<NewExtensionVersion, anyhow::Error> {
331    let object = blob_store_client
332        .get_object()
333        .bucket(blob_store_bucket)
334        .key(format!("extensions/{extension_id}/{version}/manifest.json"))
335        .send()
336        .await?;
337    let manifest_bytes = object
338        .body
339        .collect()
340        .await
341        .map(|data| data.into_bytes())
342        .with_context(|| {
343            format!("failed to download manifest for extension {extension_id} version {version}")
344        })?
345        .to_vec();
346    let manifest =
347        serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
348            format!(
349                "invalid manifest for extension {extension_id} version {version}: {}",
350                String::from_utf8_lossy(&manifest_bytes)
351            )
352        })?;
353    let published_at = object.last_modified.ok_or_else(|| {
354        anyhow!("missing last modified timestamp for extension {extension_id} version {version}")
355    })?;
356    let published_at = time::OffsetDateTime::from_unix_timestamp_nanos(published_at.as_nanos())?;
357    let published_at = PrimitiveDateTime::new(published_at.date(), published_at.time());
358    let version = semver::Version::parse(&manifest.version).with_context(|| {
359        format!("invalid version for extension {extension_id} version {version}")
360    })?;
361    Ok(NewExtensionVersion {
362        name: manifest.name,
363        version,
364        description: manifest.description.unwrap_or_default(),
365        authors: manifest.authors,
366        repository: manifest.repository,
367        schema_version: manifest.schema_version.unwrap_or(0),
368        wasm_api_version: manifest.wasm_api_version,
369        published_at,
370    })
371}