extensions.rs

  1use crate::db::ExtensionVersionConstraints;
  2use crate::{AppState, Error, Result, db::NewExtensionVersion};
  3use anyhow::Context as _;
  4use aws_sdk_s3::presigning::PresigningConfig;
  5use axum::{
  6    Extension, Json, Router,
  7    extract::{Path, Query, RawQuery},
  8    http::StatusCode,
  9    response::Redirect,
 10    routing::get,
 11};
 12use cloud_api_types::{ExtensionApiManifest, GetExtensionsResponse};
 13use collections::HashMap;
 14use semver::Version as SemanticVersion;
 15use serde::Deserialize;
 16use std::{sync::Arc, time::Duration};
 17use time::PrimitiveDateTime;
 18use util::{ResultExt, maybe};
 19
 20pub fn router() -> Router {
 21    Router::new()
 22        .route("/extensions", get(get_extensions))
 23        .route("/extensions/updates", get(get_extension_updates))
 24        .route("/extensions/:extension_id", get(get_extension_versions))
 25        .route(
 26            "/extensions/:extension_id/download",
 27            get(download_latest_extension),
 28        )
 29        .route(
 30            "/extensions/:extension_id/:version/download",
 31            get(download_extension),
 32        )
 33}
 34
 35const UPSTREAM_EXTENSIONS_URL: &str = "https://cloud.zed.dev/extensions";
 36
 37async fn get_extensions(RawQuery(query): RawQuery) -> Result<Json<GetExtensionsResponse>> {
 38    let upstream_url = match query {
 39        Some(query) => format!("{UPSTREAM_EXTENSIONS_URL}?{query}"),
 40        None => UPSTREAM_EXTENSIONS_URL.to_string(),
 41    };
 42
 43    let response = reqwest::get(&upstream_url).await.map_err(|error| {
 44        tracing::error!(
 45            ?error,
 46            "failed to proxy request to upstream extensions service"
 47        );
 48        Error::http(
 49            StatusCode::BAD_GATEWAY,
 50            "upstream extensions service unavailable".into(),
 51        )
 52    })?;
 53
 54    let status = response.status();
 55    if !status.is_success() {
 56        let upstream_status =
 57            StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
 58        let body = response.text().await.unwrap_or_default();
 59        tracing::error!(
 60            status = status.as_u16(),
 61            body,
 62            "upstream extensions service returned an error"
 63        );
 64        return Err(Error::http(upstream_status, body));
 65    }
 66
 67    let body: GetExtensionsResponse = response.json().await.map_err(|error| {
 68        tracing::error!(
 69            ?error,
 70            "failed to parse response from upstream extensions service"
 71        );
 72        Error::http(
 73            StatusCode::BAD_GATEWAY,
 74            "failed to parse upstream response".into(),
 75        )
 76    })?;
 77
 78    Ok(Json(body))
 79}
 80
 81#[derive(Debug, Deserialize)]
 82struct GetExtensionUpdatesParams {
 83    ids: String,
 84    min_schema_version: i32,
 85    max_schema_version: i32,
 86    min_wasm_api_version: semver::Version,
 87    max_wasm_api_version: semver::Version,
 88}
 89
 90async fn get_extension_updates(
 91    Extension(app): Extension<Arc<AppState>>,
 92    Query(params): Query<GetExtensionUpdatesParams>,
 93) -> Result<Json<GetExtensionsResponse>> {
 94    let constraints = ExtensionVersionConstraints {
 95        schema_versions: params.min_schema_version..=params.max_schema_version,
 96        wasm_api_versions: params.min_wasm_api_version..=params.max_wasm_api_version,
 97    };
 98
 99    let extension_ids = params.ids.split(',').map(|s| s.trim()).collect::<Vec<_>>();
100
101    let extensions = app
102        .db
103        .get_extensions_by_ids(&extension_ids, Some(&constraints))
104        .await?;
105
106    Ok(Json(GetExtensionsResponse { data: extensions }))
107}
108
109#[derive(Debug, Deserialize)]
110struct GetExtensionVersionsParams {
111    extension_id: String,
112}
113
114async fn get_extension_versions(
115    Extension(app): Extension<Arc<AppState>>,
116    Path(params): Path<GetExtensionVersionsParams>,
117) -> Result<Json<GetExtensionsResponse>> {
118    let extension_versions = app.db.get_extension_versions(&params.extension_id).await?;
119
120    Ok(Json(GetExtensionsResponse {
121        data: extension_versions,
122    }))
123}
124
125#[derive(Debug, Deserialize)]
126struct DownloadLatestExtensionPathParams {
127    extension_id: String,
128}
129
130#[derive(Debug, Deserialize)]
131struct DownloadLatestExtensionQueryParams {
132    min_schema_version: Option<i32>,
133    max_schema_version: Option<i32>,
134    min_wasm_api_version: Option<SemanticVersion>,
135    max_wasm_api_version: Option<SemanticVersion>,
136}
137
138async fn download_latest_extension(
139    Extension(app): Extension<Arc<AppState>>,
140    Path(params): Path<DownloadLatestExtensionPathParams>,
141    Query(query): Query<DownloadLatestExtensionQueryParams>,
142) -> Result<Redirect> {
143    let constraints = maybe!({
144        let min_schema_version = query.min_schema_version?;
145        let max_schema_version = query.max_schema_version?;
146        let min_wasm_api_version = query.min_wasm_api_version?;
147        let max_wasm_api_version = query.max_wasm_api_version?;
148
149        Some(ExtensionVersionConstraints {
150            schema_versions: min_schema_version..=max_schema_version,
151            wasm_api_versions: min_wasm_api_version..=max_wasm_api_version,
152        })
153    });
154
155    let extension = app
156        .db
157        .get_extension(&params.extension_id, constraints.as_ref())
158        .await?
159        .context("unknown extension")?;
160    download_extension(
161        Extension(app),
162        Path(DownloadExtensionParams {
163            extension_id: params.extension_id,
164            version: extension.manifest.version.to_string(),
165        }),
166    )
167    .await
168}
169
170#[derive(Debug, Deserialize)]
171struct DownloadExtensionParams {
172    extension_id: String,
173    version: String,
174}
175
176async fn download_extension(
177    Extension(app): Extension<Arc<AppState>>,
178    Path(params): Path<DownloadExtensionParams>,
179) -> Result<Redirect> {
180    let Some((blob_store_client, bucket)) = app
181        .blob_store_client
182        .clone()
183        .zip(app.config.blob_store_bucket.clone())
184    else {
185        Err(Error::http(
186            StatusCode::NOT_IMPLEMENTED,
187            "not supported".into(),
188        ))?
189    };
190
191    let DownloadExtensionParams {
192        extension_id,
193        version,
194    } = params;
195
196    let version_exists = app
197        .db
198        .record_extension_download(&extension_id, &version)
199        .await?;
200
201    if !version_exists {
202        Err(Error::http(
203            StatusCode::NOT_FOUND,
204            "unknown extension version".into(),
205        ))?;
206    }
207
208    let url = blob_store_client
209        .get_object()
210        .bucket(bucket)
211        .key(format!(
212            "extensions/{extension_id}/{version}/archive.tar.gz"
213        ))
214        .presigned(PresigningConfig::expires_in(EXTENSION_DOWNLOAD_URL_LIFETIME).unwrap())
215        .await
216        .context("creating presigned extension download url")?;
217
218    Ok(Redirect::temporary(url.uri()))
219}
220
/// Pause between successive blob store sync passes in
/// `fetch_extensions_from_blob_store_periodically` (five minutes).
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(60 * 5);
/// Validity period of the presigned archive URLs returned by `download_extension`
/// (three minutes).
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(60 * 3);
223
224pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
225    let Some(blob_store_client) = app_state.blob_store_client.clone() else {
226        log::info!("no blob store client");
227        return;
228    };
229    let Some(blob_store_bucket) = app_state.config.blob_store_bucket.clone() else {
230        log::info!("no blob store bucket");
231        return;
232    };
233
234    let executor = app_state.executor.clone();
235    executor.spawn_detached({
236        let executor = executor.clone();
237        async move {
238            loop {
239                fetch_extensions_from_blob_store(
240                    &blob_store_client,
241                    &blob_store_bucket,
242                    &app_state,
243                )
244                .await
245                .log_err();
246                executor.sleep(EXTENSION_FETCH_INTERVAL).await;
247            }
248        }
249    });
250}
251
252async fn fetch_extensions_from_blob_store(
253    blob_store_client: &aws_sdk_s3::Client,
254    blob_store_bucket: &String,
255    app_state: &Arc<AppState>,
256) -> anyhow::Result<()> {
257    log::info!("fetching extensions from blob store");
258
259    let mut next_marker = None;
260    let mut published_versions = HashMap::<String, Vec<String>>::default();
261
262    loop {
263        let list = blob_store_client
264            .list_objects()
265            .bucket(blob_store_bucket)
266            .prefix("extensions/")
267            .set_marker(next_marker.clone())
268            .send()
269            .await?;
270        let objects = list.contents.unwrap_or_default();
271        log::info!("fetched {} object(s) from blob store", objects.len());
272
273        for object in &objects {
274            let Some(key) = object.key.as_ref() else {
275                continue;
276            };
277            let mut parts = key.split('/');
278            let Some(_) = parts.next().filter(|part| *part == "extensions") else {
279                continue;
280            };
281            let Some(extension_id) = parts.next() else {
282                continue;
283            };
284            let Some(version) = parts.next() else {
285                continue;
286            };
287            if parts.next() == Some("manifest.json") {
288                published_versions
289                    .entry(extension_id.to_owned())
290                    .or_default()
291                    .push(version.to_owned());
292            }
293        }
294
295        if let (Some(true), Some(last_object)) = (list.is_truncated, objects.last()) {
296            next_marker.clone_from(&last_object.key);
297        } else {
298            break;
299        }
300    }
301
302    log::info!("found {} published extensions", published_versions.len());
303
304    let known_versions = app_state.db.get_known_extension_versions().await?;
305
306    let mut new_versions = HashMap::<&str, Vec<NewExtensionVersion>>::default();
307    let empty = Vec::new();
308    for (extension_id, published_versions) in &published_versions {
309        let known_versions = known_versions.get(extension_id).unwrap_or(&empty);
310
311        for published_version in published_versions {
312            if known_versions
313                .binary_search_by_key(&published_version, |known_version| known_version)
314                .is_err()
315                && let Some(extension) = fetch_extension_manifest(
316                    blob_store_client,
317                    blob_store_bucket,
318                    extension_id,
319                    published_version,
320                )
321                .await
322                .log_err()
323            {
324                new_versions
325                    .entry(extension_id)
326                    .or_default()
327                    .push(extension);
328            }
329        }
330    }
331
332    app_state
333        .db
334        .insert_extension_versions(&new_versions)
335        .await?;
336
337    log::info!(
338        "fetched {} new extensions from blob store",
339        new_versions.values().map(|v| v.len()).sum::<usize>()
340    );
341
342    Ok(())
343}
344
345async fn fetch_extension_manifest(
346    blob_store_client: &aws_sdk_s3::Client,
347    blob_store_bucket: &String,
348    extension_id: &str,
349    version: &str,
350) -> anyhow::Result<NewExtensionVersion> {
351    let object = blob_store_client
352        .get_object()
353        .bucket(blob_store_bucket)
354        .key(format!("extensions/{extension_id}/{version}/manifest.json"))
355        .send()
356        .await?;
357    let manifest_bytes = object
358        .body
359        .collect()
360        .await
361        .map(|data| data.into_bytes())
362        .with_context(|| {
363            format!("failed to download manifest for extension {extension_id} version {version}")
364        })?
365        .to_vec();
366    let manifest =
367        serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
368            format!(
369                "invalid manifest for extension {extension_id} version {version}: {}",
370                String::from_utf8_lossy(&manifest_bytes)
371            )
372        })?;
373    let published_at = object.last_modified.with_context(|| {
374        format!("missing last modified timestamp for extension {extension_id} version {version}")
375    })?;
376    let published_at = time::OffsetDateTime::from_unix_timestamp_nanos(published_at.as_nanos())?;
377    let published_at = PrimitiveDateTime::new(published_at.date(), published_at.time());
378    let version = semver::Version::parse(&manifest.version).with_context(|| {
379        format!("invalid version for extension {extension_id} version {version}")
380    })?;
381    Ok(NewExtensionVersion {
382        name: manifest.name,
383        version,
384        description: manifest.description.unwrap_or_default(),
385        authors: manifest.authors,
386        repository: manifest.repository,
387        schema_version: manifest.schema_version.unwrap_or(0),
388        wasm_api_version: manifest.wasm_api_version,
389        provides: manifest.provides,
390        published_at,
391    })
392}