extensions.rs

  1use crate::db::ExtensionVersionConstraints;
  2use crate::{db::NewExtensionVersion, AppState, Error, Result};
  3use anyhow::{anyhow, Context as _};
  4use aws_sdk_s3::presigning::PresigningConfig;
  5use axum::{
  6    extract::{Path, Query},
  7    http::StatusCode,
  8    response::Redirect,
  9    routing::get,
 10    Extension, Json, Router,
 11};
 12use collections::HashMap;
 13use rpc::{ExtensionApiManifest, GetExtensionsResponse};
 14use semantic_version::SemanticVersion;
 15use serde::Deserialize;
 16use std::{sync::Arc, time::Duration};
 17use time::PrimitiveDateTime;
 18use util::{maybe, ResultExt};
 19
 20pub fn router() -> Router {
 21    Router::new()
 22        .route("/extensions", get(get_extensions))
 23        .route("/extensions/updates", get(get_extension_updates))
 24        .route("/extensions/:extension_id", get(get_extension_versions))
 25        .route(
 26            "/extensions/:extension_id/download",
 27            get(download_latest_extension),
 28        )
 29        .route(
 30            "/extensions/:extension_id/:version/download",
 31            get(download_extension),
 32        )
 33}
 34
 35#[derive(Debug, Deserialize)]
 36struct GetExtensionsParams {
 37    filter: Option<String>,
 38    #[serde(default)]
 39    max_schema_version: i32,
 40}
 41
 42async fn get_extensions(
 43    Extension(app): Extension<Arc<AppState>>,
 44    Query(params): Query<GetExtensionsParams>,
 45) -> Result<Json<GetExtensionsResponse>> {
 46    let mut extensions = app
 47        .db
 48        .get_extensions(params.filter.as_deref(), params.max_schema_version, 500)
 49        .await?;
 50
 51    if let Some(filter) = params.filter.as_deref() {
 52        let extension_id = filter.to_lowercase();
 53        let mut exact_match = None;
 54        extensions.retain(|extension| {
 55            if extension.id.as_ref() == extension_id {
 56                exact_match = Some(extension.clone());
 57                false
 58            } else {
 59                true
 60            }
 61        });
 62        if exact_match.is_none() {
 63            exact_match = app
 64                .db
 65                .get_extensions_by_ids(&[&extension_id], None)
 66                .await?
 67                .first()
 68                .cloned();
 69        }
 70
 71        if let Some(exact_match) = exact_match {
 72            extensions.insert(0, exact_match);
 73        }
 74    };
 75
 76    if let Some(query) = params.filter.as_deref() {
 77        let count = extensions.len();
 78        tracing::info!(query, count, "extension_search")
 79    }
 80
 81    Ok(Json(GetExtensionsResponse { data: extensions }))
 82}
 83
 84#[derive(Debug, Deserialize)]
 85struct GetExtensionUpdatesParams {
 86    ids: String,
 87    min_schema_version: i32,
 88    max_schema_version: i32,
 89    min_wasm_api_version: SemanticVersion,
 90    max_wasm_api_version: SemanticVersion,
 91}
 92
 93async fn get_extension_updates(
 94    Extension(app): Extension<Arc<AppState>>,
 95    Query(params): Query<GetExtensionUpdatesParams>,
 96) -> Result<Json<GetExtensionsResponse>> {
 97    let constraints = ExtensionVersionConstraints {
 98        schema_versions: params.min_schema_version..=params.max_schema_version,
 99        wasm_api_versions: params.min_wasm_api_version..=params.max_wasm_api_version,
100    };
101
102    let extension_ids = params.ids.split(',').map(|s| s.trim()).collect::<Vec<_>>();
103
104    let extensions = app
105        .db
106        .get_extensions_by_ids(&extension_ids, Some(&constraints))
107        .await?;
108
109    Ok(Json(GetExtensionsResponse { data: extensions }))
110}
111
112#[derive(Debug, Deserialize)]
113struct GetExtensionVersionsParams {
114    extension_id: String,
115}
116
117async fn get_extension_versions(
118    Extension(app): Extension<Arc<AppState>>,
119    Path(params): Path<GetExtensionVersionsParams>,
120) -> Result<Json<GetExtensionsResponse>> {
121    let extension_versions = app.db.get_extension_versions(&params.extension_id).await?;
122
123    Ok(Json(GetExtensionsResponse {
124        data: extension_versions,
125    }))
126}
127
128#[derive(Debug, Deserialize)]
129struct DownloadLatestExtensionPathParams {
130    extension_id: String,
131}
132
133#[derive(Debug, Deserialize)]
134struct DownloadLatestExtensionQueryParams {
135    min_schema_version: Option<i32>,
136    max_schema_version: Option<i32>,
137    min_wasm_api_version: Option<SemanticVersion>,
138    max_wasm_api_version: Option<SemanticVersion>,
139}
140
141async fn download_latest_extension(
142    Extension(app): Extension<Arc<AppState>>,
143    Path(params): Path<DownloadLatestExtensionPathParams>,
144    Query(query): Query<DownloadLatestExtensionQueryParams>,
145) -> Result<Redirect> {
146    let constraints = maybe!({
147        let min_schema_version = query.min_schema_version?;
148        let max_schema_version = query.max_schema_version?;
149        let min_wasm_api_version = query.min_wasm_api_version?;
150        let max_wasm_api_version = query.max_wasm_api_version?;
151
152        Some(ExtensionVersionConstraints {
153            schema_versions: min_schema_version..=max_schema_version,
154            wasm_api_versions: min_wasm_api_version..=max_wasm_api_version,
155        })
156    });
157
158    let extension = app
159        .db
160        .get_extension(&params.extension_id, constraints.as_ref())
161        .await?
162        .ok_or_else(|| anyhow!("unknown extension"))?;
163    download_extension(
164        Extension(app),
165        Path(DownloadExtensionParams {
166            extension_id: params.extension_id,
167            version: extension.manifest.version.to_string(),
168        }),
169    )
170    .await
171}
172
173#[derive(Debug, Deserialize)]
174struct DownloadExtensionParams {
175    extension_id: String,
176    version: String,
177}
178
179async fn download_extension(
180    Extension(app): Extension<Arc<AppState>>,
181    Path(params): Path<DownloadExtensionParams>,
182) -> Result<Redirect> {
183    let Some((blob_store_client, bucket)) = app
184        .blob_store_client
185        .clone()
186        .zip(app.config.blob_store_bucket.clone())
187    else {
188        Err(Error::http(
189            StatusCode::NOT_IMPLEMENTED,
190            "not supported".into(),
191        ))?
192    };
193
194    let DownloadExtensionParams {
195        extension_id,
196        version,
197    } = params;
198
199    let version_exists = app
200        .db
201        .record_extension_download(&extension_id, &version)
202        .await?;
203
204    if !version_exists {
205        Err(Error::http(
206            StatusCode::NOT_FOUND,
207            "unknown extension version".into(),
208        ))?;
209    }
210
211    let url = blob_store_client
212        .get_object()
213        .bucket(bucket)
214        .key(format!(
215            "extensions/{extension_id}/{version}/archive.tar.gz"
216        ))
217        .presigned(PresigningConfig::expires_in(EXTENSION_DOWNLOAD_URL_LIFETIME).unwrap())
218        .await
219        .map_err(|e| anyhow!("failed to create presigned extension download url {e}"))?;
220
221    Ok(Redirect::temporary(url.uri()))
222}
223
224const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(5 * 60);
225const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(3 * 60);
226
227pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
228    let Some(blob_store_client) = app_state.blob_store_client.clone() else {
229        log::info!("no blob store client");
230        return;
231    };
232    let Some(blob_store_bucket) = app_state.config.blob_store_bucket.clone() else {
233        log::info!("no blob store bucket");
234        return;
235    };
236
237    let executor = app_state.executor.clone();
238    executor.spawn_detached({
239        let executor = executor.clone();
240        async move {
241            loop {
242                fetch_extensions_from_blob_store(
243                    &blob_store_client,
244                    &blob_store_bucket,
245                    &app_state,
246                )
247                .await
248                .log_err();
249                executor.sleep(EXTENSION_FETCH_INTERVAL).await;
250            }
251        }
252    });
253}
254
255async fn fetch_extensions_from_blob_store(
256    blob_store_client: &aws_sdk_s3::Client,
257    blob_store_bucket: &String,
258    app_state: &Arc<AppState>,
259) -> anyhow::Result<()> {
260    log::info!("fetching extensions from blob store");
261
262    let mut next_marker = None;
263    let mut published_versions = HashMap::<String, Vec<String>>::default();
264
265    loop {
266        let list = blob_store_client
267            .list_objects()
268            .bucket(blob_store_bucket)
269            .prefix("extensions/")
270            .set_marker(next_marker.clone())
271            .send()
272            .await?;
273        let objects = list.contents.unwrap_or_default();
274        log::info!("fetched {} object(s) from blob store", objects.len());
275
276        for object in &objects {
277            let Some(key) = object.key.as_ref() else {
278                continue;
279            };
280            let mut parts = key.split('/');
281            let Some(_) = parts.next().filter(|part| *part == "extensions") else {
282                continue;
283            };
284            let Some(extension_id) = parts.next() else {
285                continue;
286            };
287            let Some(version) = parts.next() else {
288                continue;
289            };
290            if parts.next() == Some("manifest.json") {
291                published_versions
292                    .entry(extension_id.to_owned())
293                    .or_default()
294                    .push(version.to_owned());
295            }
296        }
297
298        if let (Some(true), Some(last_object)) = (list.is_truncated, objects.last()) {
299            next_marker.clone_from(&last_object.key);
300        } else {
301            break;
302        }
303    }
304
305    log::info!("found {} published extensions", published_versions.len());
306
307    let known_versions = app_state.db.get_known_extension_versions().await?;
308
309    let mut new_versions = HashMap::<&str, Vec<NewExtensionVersion>>::default();
310    let empty = Vec::new();
311    for (extension_id, published_versions) in &published_versions {
312        let known_versions = known_versions.get(extension_id).unwrap_or(&empty);
313
314        for published_version in published_versions {
315            if known_versions
316                .binary_search_by_key(&published_version, |known_version| known_version)
317                .is_err()
318            {
319                if let Some(extension) = fetch_extension_manifest(
320                    blob_store_client,
321                    blob_store_bucket,
322                    extension_id,
323                    published_version,
324                )
325                .await
326                .log_err()
327                {
328                    new_versions
329                        .entry(extension_id)
330                        .or_default()
331                        .push(extension);
332                }
333            }
334        }
335    }
336
337    app_state
338        .db
339        .insert_extension_versions(&new_versions)
340        .await?;
341
342    log::info!(
343        "fetched {} new extensions from blob store",
344        new_versions.values().map(|v| v.len()).sum::<usize>()
345    );
346
347    Ok(())
348}
349
350async fn fetch_extension_manifest(
351    blob_store_client: &aws_sdk_s3::Client,
352    blob_store_bucket: &String,
353    extension_id: &str,
354    version: &str,
355) -> Result<NewExtensionVersion, anyhow::Error> {
356    let object = blob_store_client
357        .get_object()
358        .bucket(blob_store_bucket)
359        .key(format!("extensions/{extension_id}/{version}/manifest.json"))
360        .send()
361        .await?;
362    let manifest_bytes = object
363        .body
364        .collect()
365        .await
366        .map(|data| data.into_bytes())
367        .with_context(|| {
368            format!("failed to download manifest for extension {extension_id} version {version}")
369        })?
370        .to_vec();
371    let manifest =
372        serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
373            format!(
374                "invalid manifest for extension {extension_id} version {version}: {}",
375                String::from_utf8_lossy(&manifest_bytes)
376            )
377        })?;
378    let published_at = object.last_modified.ok_or_else(|| {
379        anyhow!("missing last modified timestamp for extension {extension_id} version {version}")
380    })?;
381    let published_at = time::OffsetDateTime::from_unix_timestamp_nanos(published_at.as_nanos())?;
382    let published_at = PrimitiveDateTime::new(published_at.date(), published_at.time());
383    let version = semver::Version::parse(&manifest.version).with_context(|| {
384        format!("invalid version for extension {extension_id} version {version}")
385    })?;
386    Ok(NewExtensionVersion {
387        name: manifest.name,
388        version,
389        description: manifest.description.unwrap_or_default(),
390        authors: manifest.authors,
391        repository: manifest.repository,
392        schema_version: manifest.schema_version.unwrap_or(0),
393        wasm_api_version: manifest.wasm_api_version,
394        published_at,
395    })
396}