@@ -4,16 +4,15 @@ use anyhow::Context as _;
use aws_sdk_s3::presigning::PresigningConfig;
use axum::{
Extension, Json, Router,
- extract::{Path, Query},
+ extract::{Path, Query, RawQuery},
http::StatusCode,
response::Redirect,
routing::get,
};
-use cloud_api_types::{ExtensionApiManifest, ExtensionProvides, GetExtensionsResponse};
-use collections::{BTreeSet, HashMap};
+use cloud_api_types::{ExtensionApiManifest, GetExtensionsResponse};
+use collections::HashMap;
use semver::Version as SemanticVersion;
use serde::Deserialize;
-use std::str::FromStr;
use std::{sync::Arc, time::Duration};
use time::PrimitiveDateTime;
use util::{ResultExt, maybe};
@@ -33,74 +32,50 @@ pub fn router() -> Router {
)
}
-#[derive(Debug, Deserialize)]
-struct GetExtensionsParams {
- filter: Option<String>,
- /// A comma-delimited list of features that the extension must provide.
- ///
- /// For example:
- /// - `themes`
- /// - `themes,icon-themes`
- /// - `languages,language-servers`
- #[serde(default)]
- provides: Option<String>,
- #[serde(default)]
- max_schema_version: i32,
-}
+const UPSTREAM_EXTENSIONS_URL: &str = "https://cloud.zed.dev/extensions";
-async fn get_extensions(
- Extension(app): Extension<Arc<AppState>>,
- Query(params): Query<GetExtensionsParams>,
-) -> Result<Json<GetExtensionsResponse>> {
- let provides_filter = params.provides.map(|provides| {
- provides
- .split(',')
- .map(|value| value.trim())
- .filter_map(|value| ExtensionProvides::from_str(value).ok())
- .collect::<BTreeSet<_>>()
- });
+async fn get_extensions(RawQuery(query): RawQuery) -> Result<Json<GetExtensionsResponse>> {
+ let upstream_url = match query {
+ Some(query) => format!("{UPSTREAM_EXTENSIONS_URL}?{query}"),
+ None => UPSTREAM_EXTENSIONS_URL.to_string(),
+ };
- let mut extensions = app
- .db
- .get_extensions(
- params.filter.as_deref(),
- provides_filter.as_ref(),
- params.max_schema_version,
- 1_000,
+ let response = reqwest::get(&upstream_url).await.map_err(|error| {
+ tracing::error!(
+ ?error,
+ "failed to proxy request to upstream extensions service"
+ );
+ Error::http(
+ StatusCode::BAD_GATEWAY,
+ "upstream extensions service unavailable".into(),
)
- .await?;
-
- if let Some(filter) = params.filter.as_deref() {
- let extension_id = filter.to_lowercase();
- let mut exact_match = None;
- extensions.retain(|extension| {
- if extension.id.as_ref() == extension_id {
- exact_match = Some(extension.clone());
- false
- } else {
- true
- }
- });
- if exact_match.is_none() {
- exact_match = app
- .db
- .get_extensions_by_ids(&[&extension_id], None)
- .await?
- .first()
- .cloned();
- }
-
- if let Some(exact_match) = exact_match {
- extensions.insert(0, exact_match);
- }
- };
+ })?;
- if let Some(query) = params.filter.as_deref() {
- let count = extensions.len();
- tracing::info!(query, count, "extension_search")
+ let status = response.status();
+ if !status.is_success() {
+ let upstream_status =
+ StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
+ let body = response.text().await.unwrap_or_default();
+ tracing::error!(
+ status = status.as_u16(),
+ body,
+ "upstream extensions service returned an error"
+ );
+ return Err(Error::http(upstream_status, body));
}
- Ok(Json(GetExtensionsResponse { data: extensions }))
+ let body: GetExtensionsResponse = response.json().await.map_err(|error| {
+ tracing::error!(
+ ?error,
+ "failed to parse response from upstream extensions service"
+ );
+ Error::http(
+ StatusCode::BAD_GATEWAY,
+ "failed to parse upstream response".into(),
+ )
+ })?;
+
+ Ok(Json(body))
}
#[derive(Debug, Deserialize)]
@@ -8,359 +8,6 @@ use collab::db::{NewExtensionVersion, queries::extensions::convert_time_to_chron
use crate::test_both_dbs;
-test_both_dbs!(
- test_extensions,
- test_extensions_postgres,
- test_extensions_sqlite
-);
-
-test_both_dbs!(
- test_agent_servers_filter,
- test_agent_servers_filter_postgres,
- test_agent_servers_filter_sqlite
-);
-
-async fn test_agent_servers_filter(db: &Arc<Database>) {
- // No extensions initially
- let versions = db.get_known_extension_versions().await.unwrap();
- assert!(versions.is_empty());
-
- // Shared timestamp
- let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap();
- let t0 = time::PrimitiveDateTime::new(t0.date(), t0.time());
-
- // Insert two extensions, only one provides AgentServers
- db.insert_extension_versions(
- &[
- (
- "ext_agent_servers",
- vec![NewExtensionVersion {
- name: "Agent Servers Provider".into(),
- version: semver::Version::parse("1.0.0").unwrap(),
- description: "has agent servers".into(),
- authors: vec!["author".into()],
- repository: "org/agent-servers".into(),
- schema_version: 1,
- wasm_api_version: None,
- provides: BTreeSet::from_iter([ExtensionProvides::AgentServers]),
- published_at: t0,
- }],
- ),
- (
- "ext_plain",
- vec![NewExtensionVersion {
- name: "Plain Extension".into(),
- version: semver::Version::parse("0.1.0").unwrap(),
- description: "no agent servers".into(),
- authors: vec!["author2".into()],
- repository: "org/plain".into(),
- schema_version: 1,
- wasm_api_version: None,
- provides: BTreeSet::default(),
- published_at: t0,
- }],
- ),
- ]
- .into_iter()
- .collect(),
- )
- .await
- .unwrap();
-
- // Filter by AgentServers provides
- let provides_filter = BTreeSet::from_iter([ExtensionProvides::AgentServers]);
-
- let filtered = db
- .get_extensions(None, Some(&provides_filter), 1, 10)
- .await
- .unwrap();
-
- // Expect only the extension that declared AgentServers
- assert_eq!(filtered.len(), 1);
- assert_eq!(filtered[0].id.as_ref(), "ext_agent_servers");
-}
-
-async fn test_extensions(db: &Arc<Database>) {
- let versions = db.get_known_extension_versions().await.unwrap();
- assert!(versions.is_empty());
-
- let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
- assert!(extensions.is_empty());
-
- let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap();
- let t0 = time::PrimitiveDateTime::new(t0.date(), t0.time());
-
- let t0_chrono = convert_time_to_chrono(t0);
-
- db.insert_extension_versions(
- &[
- (
- "ext1",
- vec![
- NewExtensionVersion {
- name: "Extension 1".into(),
- version: semver::Version::parse("0.0.1").unwrap(),
- description: "an extension".into(),
- authors: vec!["max".into()],
- repository: "ext1/repo".into(),
- schema_version: 1,
- wasm_api_version: None,
- provides: BTreeSet::default(),
- published_at: t0,
- },
- NewExtensionVersion {
- name: "Extension One".into(),
- version: semver::Version::parse("0.0.2").unwrap(),
- description: "a good extension".into(),
- authors: vec!["max".into(), "marshall".into()],
- repository: "ext1/repo".into(),
- schema_version: 1,
- wasm_api_version: None,
- provides: BTreeSet::default(),
- published_at: t0,
- },
- ],
- ),
- (
- "ext2",
- vec![NewExtensionVersion {
- name: "Extension Two".into(),
- version: semver::Version::parse("0.2.0").unwrap(),
- description: "a great extension".into(),
- authors: vec!["marshall".into()],
- repository: "ext2/repo".into(),
- schema_version: 0,
- wasm_api_version: None,
- provides: BTreeSet::default(),
- published_at: t0,
- }],
- ),
- ]
- .into_iter()
- .collect(),
- )
- .await
- .unwrap();
-
- let versions = db.get_known_extension_versions().await.unwrap();
- assert_eq!(
- versions,
- [
- ("ext1".into(), vec!["0.0.1".into(), "0.0.2".into()]),
- ("ext2".into(), vec!["0.2.0".into()])
- ]
- .into_iter()
- .collect()
- );
-
- // The latest version of each extension is returned.
- let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
- assert_eq!(
- extensions,
- &[
- ExtensionMetadata {
- id: "ext1".into(),
- manifest: cloud_api_types::ExtensionApiManifest {
- name: "Extension One".into(),
- version: "0.0.2".into(),
- authors: vec!["max".into(), "marshall".into()],
- description: Some("a good extension".into()),
- repository: "ext1/repo".into(),
- schema_version: Some(1),
- wasm_api_version: None,
- provides: BTreeSet::default(),
- },
- published_at: t0_chrono,
- download_count: 0,
- },
- ExtensionMetadata {
- id: "ext2".into(),
- manifest: cloud_api_types::ExtensionApiManifest {
- name: "Extension Two".into(),
- version: "0.2.0".into(),
- authors: vec!["marshall".into()],
- description: Some("a great extension".into()),
- repository: "ext2/repo".into(),
- schema_version: Some(0),
- wasm_api_version: None,
- provides: BTreeSet::default(),
- },
- published_at: t0_chrono,
- download_count: 0
- },
- ]
- );
-
- // Extensions with too new of a schema version are excluded.
- let extensions = db.get_extensions(None, None, 0, 5).await.unwrap();
- assert_eq!(
- extensions,
- &[ExtensionMetadata {
- id: "ext2".into(),
- manifest: cloud_api_types::ExtensionApiManifest {
- name: "Extension Two".into(),
- version: "0.2.0".into(),
- authors: vec!["marshall".into()],
- description: Some("a great extension".into()),
- repository: "ext2/repo".into(),
- schema_version: Some(0),
- wasm_api_version: None,
- provides: BTreeSet::default(),
- },
- published_at: t0_chrono,
- download_count: 0
- },]
- );
-
- // Record extensions being downloaded.
- for _ in 0..7 {
- assert!(db.record_extension_download("ext2", "0.0.2").await.unwrap());
- }
-
- for _ in 0..3 {
- assert!(db.record_extension_download("ext1", "0.0.1").await.unwrap());
- }
-
- for _ in 0..2 {
- assert!(db.record_extension_download("ext1", "0.0.2").await.unwrap());
- }
-
- // Record download returns false if the extension does not exist.
- assert!(
- !db.record_extension_download("no-such-extension", "0.0.2")
- .await
- .unwrap()
- );
-
- // Extensions are returned in descending order of total downloads.
- let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
- assert_eq!(
- extensions,
- &[
- ExtensionMetadata {
- id: "ext2".into(),
- manifest: cloud_api_types::ExtensionApiManifest {
- name: "Extension Two".into(),
- version: "0.2.0".into(),
- authors: vec!["marshall".into()],
- description: Some("a great extension".into()),
- repository: "ext2/repo".into(),
- schema_version: Some(0),
- wasm_api_version: None,
- provides: BTreeSet::default(),
- },
- published_at: t0_chrono,
- download_count: 7
- },
- ExtensionMetadata {
- id: "ext1".into(),
- manifest: cloud_api_types::ExtensionApiManifest {
- name: "Extension One".into(),
- version: "0.0.2".into(),
- authors: vec!["max".into(), "marshall".into()],
- description: Some("a good extension".into()),
- repository: "ext1/repo".into(),
- schema_version: Some(1),
- wasm_api_version: None,
- provides: BTreeSet::default(),
- },
- published_at: t0_chrono,
- download_count: 5,
- },
- ]
- );
-
- // Add more extensions, including a new version of `ext1`, and backfilling
- // an older version of `ext2`.
- db.insert_extension_versions(
- &[
- (
- "ext1",
- vec![NewExtensionVersion {
- name: "Extension One".into(),
- version: semver::Version::parse("0.0.3").unwrap(),
- description: "a real good extension".into(),
- authors: vec!["max".into(), "marshall".into()],
- repository: "ext1/repo".into(),
- schema_version: 1,
- wasm_api_version: None,
- provides: BTreeSet::default(),
- published_at: t0,
- }],
- ),
- (
- "ext2",
- vec![NewExtensionVersion {
- name: "Extension Two".into(),
- version: semver::Version::parse("0.1.0").unwrap(),
- description: "an old extension".into(),
- authors: vec!["marshall".into()],
- repository: "ext2/repo".into(),
- schema_version: 0,
- wasm_api_version: None,
- provides: BTreeSet::default(),
- published_at: t0,
- }],
- ),
- ]
- .into_iter()
- .collect(),
- )
- .await
- .unwrap();
-
- let versions = db.get_known_extension_versions().await.unwrap();
- assert_eq!(
- versions,
- [
- (
- "ext1".into(),
- vec!["0.0.1".into(), "0.0.2".into(), "0.0.3".into()]
- ),
- ("ext2".into(), vec!["0.1.0".into(), "0.2.0".into()])
- ]
- .into_iter()
- .collect()
- );
-
- let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
- assert_eq!(
- extensions,
- &[
- ExtensionMetadata {
- id: "ext2".into(),
- manifest: cloud_api_types::ExtensionApiManifest {
- name: "Extension Two".into(),
- version: "0.2.0".into(),
- authors: vec!["marshall".into()],
- description: Some("a great extension".into()),
- repository: "ext2/repo".into(),
- schema_version: Some(0),
- wasm_api_version: None,
- provides: BTreeSet::default(),
- },
- published_at: t0_chrono,
- download_count: 7
- },
- ExtensionMetadata {
- id: "ext1".into(),
- manifest: cloud_api_types::ExtensionApiManifest {
- name: "Extension One".into(),
- version: "0.0.3".into(),
- authors: vec!["max".into(), "marshall".into()],
- description: Some("a real good extension".into()),
- repository: "ext1/repo".into(),
- schema_version: Some(1),
- wasm_api_version: None,
- provides: BTreeSet::default(),
- },
- published_at: t0_chrono,
- download_count: 5,
- },
- ]
- );
-}
-
test_both_dbs!(
test_extensions_by_id,
test_extensions_by_id_postgres,
@@ -371,9 +18,6 @@ async fn test_extensions_by_id(db: &Arc<Database>) {
let versions = db.get_known_extension_versions().await.unwrap();
assert!(versions.is_empty());
- let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
- assert!(extensions.is_empty());
-
let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap();
let t0 = time::PrimitiveDateTime::new(t0.date(), t0.time());