extensions.rs

  1use crate::{
  2    db::{ExtensionMetadata, NewExtensionVersion},
  3    AppState, Error, Result,
  4};
  5use anyhow::{anyhow, Context as _};
  6use aws_sdk_s3::presigning::PresigningConfig;
  7use axum::{
  8    extract::{Path, Query},
  9    http::StatusCode,
 10    response::Redirect,
 11    routing::get,
 12    Extension, Json, Router,
 13};
 14use collections::HashMap;
 15use rpc::ExtensionApiManifest;
 16use serde::{Deserialize, Serialize};
 17use std::{sync::Arc, time::Duration};
 18use time::PrimitiveDateTime;
 19use util::ResultExt;
 20
 21pub fn router() -> Router {
 22    Router::new()
 23        .route("/extensions", get(get_extensions))
 24        .route(
 25            "/extensions/:extension_id/download",
 26            get(download_latest_extension),
 27        )
 28        .route(
 29            "/extensions/:extension_id/:version/download",
 30            get(download_extension),
 31        )
 32}
 33
 34#[derive(Debug, Deserialize)]
 35struct GetExtensionsParams {
 36    filter: Option<String>,
 37    #[serde(default)]
 38    max_schema_version: i32,
 39}
 40
 41#[derive(Debug, Deserialize)]
 42struct DownloadLatestExtensionParams {
 43    extension_id: String,
 44}
 45
 46#[derive(Debug, Deserialize)]
 47struct DownloadExtensionParams {
 48    extension_id: String,
 49    version: String,
 50}
 51
 52#[derive(Debug, Serialize)]
 53struct GetExtensionsResponse {
 54    pub data: Vec<ExtensionMetadata>,
 55}
 56
 57async fn get_extensions(
 58    Extension(app): Extension<Arc<AppState>>,
 59    Query(params): Query<GetExtensionsParams>,
 60) -> Result<Json<GetExtensionsResponse>> {
 61    let extensions = app
 62        .db
 63        .get_extensions(params.filter.as_deref(), params.max_schema_version, 500)
 64        .await?;
 65    Ok(Json(GetExtensionsResponse { data: extensions }))
 66}
 67
 68async fn download_latest_extension(
 69    Extension(app): Extension<Arc<AppState>>,
 70    Path(params): Path<DownloadLatestExtensionParams>,
 71) -> Result<Redirect> {
 72    let extension = app
 73        .db
 74        .get_extension(&params.extension_id)
 75        .await?
 76        .ok_or_else(|| anyhow!("unknown extension"))?;
 77    download_extension(
 78        Extension(app),
 79        Path(DownloadExtensionParams {
 80            extension_id: params.extension_id,
 81            version: extension.version,
 82        }),
 83    )
 84    .await
 85}
 86
 87async fn download_extension(
 88    Extension(app): Extension<Arc<AppState>>,
 89    Path(params): Path<DownloadExtensionParams>,
 90) -> Result<Redirect> {
 91    let Some((blob_store_client, bucket)) = app
 92        .blob_store_client
 93        .clone()
 94        .zip(app.config.blob_store_bucket.clone())
 95    else {
 96        Err(Error::Http(
 97            StatusCode::NOT_IMPLEMENTED,
 98            "not supported".into(),
 99        ))?
100    };
101
102    let DownloadExtensionParams {
103        extension_id,
104        version,
105    } = params;
106
107    let version_exists = app
108        .db
109        .record_extension_download(&extension_id, &version)
110        .await?;
111
112    if !version_exists {
113        Err(Error::Http(
114            StatusCode::NOT_FOUND,
115            "unknown extension version".into(),
116        ))?;
117    }
118
119    let url = blob_store_client
120        .get_object()
121        .bucket(bucket)
122        .key(format!(
123            "extensions/{extension_id}/{version}/archive.tar.gz"
124        ))
125        .presigned(PresigningConfig::expires_in(EXTENSION_DOWNLOAD_URL_LIFETIME).unwrap())
126        .await
127        .map_err(|e| anyhow!("failed to create presigned extension download url {e}"))?;
128
129    Ok(Redirect::temporary(url.uri()))
130}
131
/// Delay between passes of the background blob-store sync (5 minutes).
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(300);
/// Validity window of a presigned extension download URL (3 minutes).
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(180);
134
135pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
136    let Some(blob_store_client) = app_state.blob_store_client.clone() else {
137        log::info!("no blob store client");
138        return;
139    };
140    let Some(blob_store_bucket) = app_state.config.blob_store_bucket.clone() else {
141        log::info!("no blob store bucket");
142        return;
143    };
144
145    let executor = app_state.executor.clone();
146    executor.spawn_detached({
147        let executor = executor.clone();
148        async move {
149            loop {
150                fetch_extensions_from_blob_store(
151                    &blob_store_client,
152                    &blob_store_bucket,
153                    &app_state,
154                )
155                .await
156                .log_err();
157                executor.sleep(EXTENSION_FETCH_INTERVAL).await;
158            }
159        }
160    });
161}
162
163async fn fetch_extensions_from_blob_store(
164    blob_store_client: &aws_sdk_s3::Client,
165    blob_store_bucket: &String,
166    app_state: &Arc<AppState>,
167) -> anyhow::Result<()> {
168    log::info!("fetching extensions from blob store");
169
170    let list = blob_store_client
171        .list_objects()
172        .bucket(blob_store_bucket)
173        .prefix("extensions/")
174        .send()
175        .await?;
176
177    let objects = list.contents.unwrap_or_default();
178
179    let mut published_versions = HashMap::<&str, Vec<&str>>::default();
180    for object in &objects {
181        let Some(key) = object.key.as_ref() else {
182            continue;
183        };
184        let mut parts = key.split('/');
185        let Some(_) = parts.next().filter(|part| *part == "extensions") else {
186            continue;
187        };
188        let Some(extension_id) = parts.next() else {
189            continue;
190        };
191        let Some(version) = parts.next() else {
192            continue;
193        };
194        if parts.next() == Some("manifest.json") {
195            published_versions
196                .entry(extension_id)
197                .or_default()
198                .push(version);
199        }
200    }
201
202    let known_versions = app_state.db.get_known_extension_versions().await?;
203
204    let mut new_versions = HashMap::<&str, Vec<NewExtensionVersion>>::default();
205    let empty = Vec::new();
206    for (extension_id, published_versions) in published_versions {
207        let known_versions = known_versions.get(extension_id).unwrap_or(&empty);
208
209        for published_version in published_versions {
210            if known_versions
211                .binary_search_by_key(&published_version, String::as_str)
212                .is_err()
213            {
214                if let Some(extension) = fetch_extension_manifest(
215                    blob_store_client,
216                    blob_store_bucket,
217                    extension_id,
218                    published_version,
219                )
220                .await
221                .log_err()
222                {
223                    new_versions
224                        .entry(extension_id)
225                        .or_default()
226                        .push(extension);
227                }
228            }
229        }
230    }
231
232    app_state
233        .db
234        .insert_extension_versions(&new_versions)
235        .await?;
236
237    log::info!(
238        "fetched {} new extensions from blob store",
239        new_versions.values().map(|v| v.len()).sum::<usize>()
240    );
241
242    Ok(())
243}
244
245async fn fetch_extension_manifest(
246    blob_store_client: &aws_sdk_s3::Client,
247    blob_store_bucket: &String,
248    extension_id: &str,
249    version: &str,
250) -> Result<NewExtensionVersion, anyhow::Error> {
251    let object = blob_store_client
252        .get_object()
253        .bucket(blob_store_bucket)
254        .key(format!("extensions/{extension_id}/{version}/manifest.json"))
255        .send()
256        .await?;
257    let manifest_bytes = object
258        .body
259        .collect()
260        .await
261        .map(|data| data.into_bytes())
262        .with_context(|| {
263            format!("failed to download manifest for extension {extension_id} version {version}")
264        })?
265        .to_vec();
266    let manifest =
267        serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
268            format!(
269                "invalid manifest for extension {extension_id} version {version}: {}",
270                String::from_utf8_lossy(&manifest_bytes)
271            )
272        })?;
273    let published_at = object.last_modified.ok_or_else(|| {
274        anyhow!("missing last modified timestamp for extension {extension_id} version {version}")
275    })?;
276    let published_at = time::OffsetDateTime::from_unix_timestamp_nanos(published_at.as_nanos())?;
277    let published_at = PrimitiveDateTime::new(published_at.date(), published_at.time());
278    let version = semver::Version::parse(&manifest.version).with_context(|| {
279        format!("invalid version for extension {extension_id} version {version}")
280    })?;
281    Ok(NewExtensionVersion {
282        name: manifest.name,
283        version,
284        description: manifest.description.unwrap_or_default(),
285        authors: manifest.authors,
286        repository: manifest.repository,
287        schema_version: manifest.schema_version.unwrap_or(0),
288        published_at,
289    })
290}