1use crate::db::ExtensionVersionConstraints;
  2use crate::{AppState, Error, Result, db::NewExtensionVersion};
  3use anyhow::Context as _;
  4use aws_sdk_s3::presigning::PresigningConfig;
  5use axum::{
  6    Extension, Json, Router,
  7    extract::{Path, Query},
  8    http::StatusCode,
  9    response::Redirect,
 10    routing::get,
 11};
 12use collections::{BTreeSet, HashMap};
 13use rpc::{ExtensionApiManifest, ExtensionProvides, GetExtensionsResponse};
 14use semantic_version::SemanticVersion;
 15use serde::Deserialize;
 16use std::str::FromStr;
 17use std::{sync::Arc, time::Duration};
 18use time::PrimitiveDateTime;
 19use util::{ResultExt, maybe};
 20
 21pub fn router() -> Router {
 22    Router::new()
 23        .route("/extensions", get(get_extensions))
 24        .route("/extensions/updates", get(get_extension_updates))
 25        .route("/extensions/:extension_id", get(get_extension_versions))
 26        .route(
 27            "/extensions/:extension_id/download",
 28            get(download_latest_extension),
 29        )
 30        .route(
 31            "/extensions/:extension_id/:version/download",
 32            get(download_extension),
 33        )
 34}
 35
 36#[derive(Debug, Deserialize)]
 37struct GetExtensionsParams {
 38    filter: Option<String>,
 39    /// A comma-delimited list of features that the extension must provide.
 40    ///
 41    /// For example:
 42    /// - `themes`
 43    /// - `themes,icon-themes`
 44    /// - `languages,language-servers`
 45    #[serde(default)]
 46    provides: Option<String>,
 47    #[serde(default)]
 48    max_schema_version: i32,
 49}
 50
 51async fn get_extensions(
 52    Extension(app): Extension<Arc<AppState>>,
 53    Query(params): Query<GetExtensionsParams>,
 54) -> Result<Json<GetExtensionsResponse>> {
 55    let provides_filter = params.provides.map(|provides| {
 56        provides
 57            .split(',')
 58            .map(|value| value.trim())
 59            .filter_map(|value| ExtensionProvides::from_str(value).ok())
 60            .collect::<BTreeSet<_>>()
 61    });
 62
 63    let mut extensions = app
 64        .db
 65        .get_extensions(
 66            params.filter.as_deref(),
 67            provides_filter.as_ref(),
 68            params.max_schema_version,
 69            1_000,
 70        )
 71        .await?;
 72
 73    if let Some(filter) = params.filter.as_deref() {
 74        let extension_id = filter.to_lowercase();
 75        let mut exact_match = None;
 76        extensions.retain(|extension| {
 77            if extension.id.as_ref() == extension_id {
 78                exact_match = Some(extension.clone());
 79                false
 80            } else {
 81                true
 82            }
 83        });
 84        if exact_match.is_none() {
 85            exact_match = app
 86                .db
 87                .get_extensions_by_ids(&[&extension_id], None)
 88                .await?
 89                .first()
 90                .cloned();
 91        }
 92
 93        if let Some(exact_match) = exact_match {
 94            extensions.insert(0, exact_match);
 95        }
 96    };
 97
 98    if let Some(query) = params.filter.as_deref() {
 99        let count = extensions.len();
100        tracing::info!(query, count, "extension_search")
101    }
102
103    Ok(Json(GetExtensionsResponse { data: extensions }))
104}
105
106#[derive(Debug, Deserialize)]
107struct GetExtensionUpdatesParams {
108    ids: String,
109    min_schema_version: i32,
110    max_schema_version: i32,
111    min_wasm_api_version: SemanticVersion,
112    max_wasm_api_version: SemanticVersion,
113}
114
115async fn get_extension_updates(
116    Extension(app): Extension<Arc<AppState>>,
117    Query(params): Query<GetExtensionUpdatesParams>,
118) -> Result<Json<GetExtensionsResponse>> {
119    let constraints = ExtensionVersionConstraints {
120        schema_versions: params.min_schema_version..=params.max_schema_version,
121        wasm_api_versions: params.min_wasm_api_version..=params.max_wasm_api_version,
122    };
123
124    let extension_ids = params.ids.split(',').map(|s| s.trim()).collect::<Vec<_>>();
125
126    let extensions = app
127        .db
128        .get_extensions_by_ids(&extension_ids, Some(&constraints))
129        .await?;
130
131    Ok(Json(GetExtensionsResponse { data: extensions }))
132}
133
134#[derive(Debug, Deserialize)]
135struct GetExtensionVersionsParams {
136    extension_id: String,
137}
138
139async fn get_extension_versions(
140    Extension(app): Extension<Arc<AppState>>,
141    Path(params): Path<GetExtensionVersionsParams>,
142) -> Result<Json<GetExtensionsResponse>> {
143    let extension_versions = app.db.get_extension_versions(¶ms.extension_id).await?;
144
145    Ok(Json(GetExtensionsResponse {
146        data: extension_versions,
147    }))
148}
149
150#[derive(Debug, Deserialize)]
151struct DownloadLatestExtensionPathParams {
152    extension_id: String,
153}
154
155#[derive(Debug, Deserialize)]
156struct DownloadLatestExtensionQueryParams {
157    min_schema_version: Option<i32>,
158    max_schema_version: Option<i32>,
159    min_wasm_api_version: Option<SemanticVersion>,
160    max_wasm_api_version: Option<SemanticVersion>,
161}
162
163async fn download_latest_extension(
164    Extension(app): Extension<Arc<AppState>>,
165    Path(params): Path<DownloadLatestExtensionPathParams>,
166    Query(query): Query<DownloadLatestExtensionQueryParams>,
167) -> Result<Redirect> {
168    let constraints = maybe!({
169        let min_schema_version = query.min_schema_version?;
170        let max_schema_version = query.max_schema_version?;
171        let min_wasm_api_version = query.min_wasm_api_version?;
172        let max_wasm_api_version = query.max_wasm_api_version?;
173
174        Some(ExtensionVersionConstraints {
175            schema_versions: min_schema_version..=max_schema_version,
176            wasm_api_versions: min_wasm_api_version..=max_wasm_api_version,
177        })
178    });
179
180    let extension = app
181        .db
182        .get_extension(¶ms.extension_id, constraints.as_ref())
183        .await?
184        .context("unknown extension")?;
185    download_extension(
186        Extension(app),
187        Path(DownloadExtensionParams {
188            extension_id: params.extension_id,
189            version: extension.manifest.version.to_string(),
190        }),
191    )
192    .await
193}
194
195#[derive(Debug, Deserialize)]
196struct DownloadExtensionParams {
197    extension_id: String,
198    version: String,
199}
200
201async fn download_extension(
202    Extension(app): Extension<Arc<AppState>>,
203    Path(params): Path<DownloadExtensionParams>,
204) -> Result<Redirect> {
205    let Some((blob_store_client, bucket)) = app
206        .blob_store_client
207        .clone()
208        .zip(app.config.blob_store_bucket.clone())
209    else {
210        Err(Error::http(
211            StatusCode::NOT_IMPLEMENTED,
212            "not supported".into(),
213        ))?
214    };
215
216    let DownloadExtensionParams {
217        extension_id,
218        version,
219    } = params;
220
221    let version_exists = app
222        .db
223        .record_extension_download(&extension_id, &version)
224        .await?;
225
226    if !version_exists {
227        Err(Error::http(
228            StatusCode::NOT_FOUND,
229            "unknown extension version".into(),
230        ))?;
231    }
232
233    let url = blob_store_client
234        .get_object()
235        .bucket(bucket)
236        .key(format!(
237            "extensions/{extension_id}/{version}/archive.tar.gz"
238        ))
239        .presigned(PresigningConfig::expires_in(EXTENSION_DOWNLOAD_URL_LIFETIME).unwrap())
240        .await
241        .context("creating presigned extension download url")?;
242
243    Ok(Redirect::temporary(url.uri()))
244}
245
/// Delay between consecutive blob-store scans in
/// `fetch_extensions_from_blob_store_periodically` (five minutes).
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(60 * 5);
/// Expiry of the presigned download URLs issued by `download_extension`
/// (three minutes).
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(60 * 3);
248
249pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
250    let Some(blob_store_client) = app_state.blob_store_client.clone() else {
251        log::info!("no blob store client");
252        return;
253    };
254    let Some(blob_store_bucket) = app_state.config.blob_store_bucket.clone() else {
255        log::info!("no blob store bucket");
256        return;
257    };
258
259    let executor = app_state.executor.clone();
260    executor.spawn_detached({
261        let executor = executor.clone();
262        async move {
263            loop {
264                fetch_extensions_from_blob_store(
265                    &blob_store_client,
266                    &blob_store_bucket,
267                    &app_state,
268                )
269                .await
270                .log_err();
271                executor.sleep(EXTENSION_FETCH_INTERVAL).await;
272            }
273        }
274    });
275}
276
277async fn fetch_extensions_from_blob_store(
278    blob_store_client: &aws_sdk_s3::Client,
279    blob_store_bucket: &String,
280    app_state: &Arc<AppState>,
281) -> anyhow::Result<()> {
282    log::info!("fetching extensions from blob store");
283
284    let mut next_marker = None;
285    let mut published_versions = HashMap::<String, Vec<String>>::default();
286
287    loop {
288        let list = blob_store_client
289            .list_objects()
290            .bucket(blob_store_bucket)
291            .prefix("extensions/")
292            .set_marker(next_marker.clone())
293            .send()
294            .await?;
295        let objects = list.contents.unwrap_or_default();
296        log::info!("fetched {} object(s) from blob store", objects.len());
297
298        for object in &objects {
299            let Some(key) = object.key.as_ref() else {
300                continue;
301            };
302            let mut parts = key.split('/');
303            let Some(_) = parts.next().filter(|part| *part == "extensions") else {
304                continue;
305            };
306            let Some(extension_id) = parts.next() else {
307                continue;
308            };
309            let Some(version) = parts.next() else {
310                continue;
311            };
312            if parts.next() == Some("manifest.json") {
313                published_versions
314                    .entry(extension_id.to_owned())
315                    .or_default()
316                    .push(version.to_owned());
317            }
318        }
319
320        if let (Some(true), Some(last_object)) = (list.is_truncated, objects.last()) {
321            next_marker.clone_from(&last_object.key);
322        } else {
323            break;
324        }
325    }
326
327    log::info!("found {} published extensions", published_versions.len());
328
329    let known_versions = app_state.db.get_known_extension_versions().await?;
330
331    let mut new_versions = HashMap::<&str, Vec<NewExtensionVersion>>::default();
332    let empty = Vec::new();
333    for (extension_id, published_versions) in &published_versions {
334        let known_versions = known_versions.get(extension_id).unwrap_or(&empty);
335
336        for published_version in published_versions {
337            if known_versions
338                .binary_search_by_key(&published_version, |known_version| known_version)
339                .is_err()
340                && let Some(extension) = fetch_extension_manifest(
341                    blob_store_client,
342                    blob_store_bucket,
343                    extension_id,
344                    published_version,
345                )
346                .await
347                .log_err()
348            {
349                new_versions
350                    .entry(extension_id)
351                    .or_default()
352                    .push(extension);
353            }
354        }
355    }
356
357    app_state
358        .db
359        .insert_extension_versions(&new_versions)
360        .await?;
361
362    log::info!(
363        "fetched {} new extensions from blob store",
364        new_versions.values().map(|v| v.len()).sum::<usize>()
365    );
366
367    Ok(())
368}
369
370async fn fetch_extension_manifest(
371    blob_store_client: &aws_sdk_s3::Client,
372    blob_store_bucket: &String,
373    extension_id: &str,
374    version: &str,
375) -> anyhow::Result<NewExtensionVersion> {
376    let object = blob_store_client
377        .get_object()
378        .bucket(blob_store_bucket)
379        .key(format!("extensions/{extension_id}/{version}/manifest.json"))
380        .send()
381        .await?;
382    let manifest_bytes = object
383        .body
384        .collect()
385        .await
386        .map(|data| data.into_bytes())
387        .with_context(|| {
388            format!("failed to download manifest for extension {extension_id} version {version}")
389        })?
390        .to_vec();
391    let manifest =
392        serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
393            format!(
394                "invalid manifest for extension {extension_id} version {version}: {}",
395                String::from_utf8_lossy(&manifest_bytes)
396            )
397        })?;
398    let published_at = object.last_modified.with_context(|| {
399        format!("missing last modified timestamp for extension {extension_id} version {version}")
400    })?;
401    let published_at = time::OffsetDateTime::from_unix_timestamp_nanos(published_at.as_nanos())?;
402    let published_at = PrimitiveDateTime::new(published_at.date(), published_at.time());
403    let version = semver::Version::parse(&manifest.version).with_context(|| {
404        format!("invalid version for extension {extension_id} version {version}")
405    })?;
406    Ok(NewExtensionVersion {
407        name: manifest.name,
408        version,
409        description: manifest.description.unwrap_or_default(),
410        authors: manifest.authors,
411        repository: manifest.repository,
412        schema_version: manifest.schema_version.unwrap_or(0),
413        wasm_api_version: manifest.wasm_api_version,
414        provides: manifest.provides,
415        published_at,
416    })
417}