collab: Proxy `GET /extensions` to Cloud (#48717)

Marshall Bowers created

This PR updates the `GET /extensions` endpoint in Collab to proxy to
Cloud.

Release Notes:

- N/A

Change summary

crates/collab/src/api/extensions.rs                         | 107 -
crates/collab/src/db/queries/extensions.rs                  |  81 -
crates/collab/tests/integration/db_tests/extension_tests.rs | 356 -------
3 files changed, 41 insertions(+), 503 deletions(-)

Detailed changes

crates/collab/src/api/extensions.rs 🔗

@@ -4,16 +4,15 @@ use anyhow::Context as _;
 use aws_sdk_s3::presigning::PresigningConfig;
 use axum::{
     Extension, Json, Router,
-    extract::{Path, Query},
+    extract::{Path, Query, RawQuery},
     http::StatusCode,
     response::Redirect,
     routing::get,
 };
-use cloud_api_types::{ExtensionApiManifest, ExtensionProvides, GetExtensionsResponse};
-use collections::{BTreeSet, HashMap};
+use cloud_api_types::{ExtensionApiManifest, GetExtensionsResponse};
+use collections::HashMap;
 use semver::Version as SemanticVersion;
 use serde::Deserialize;
-use std::str::FromStr;
 use std::{sync::Arc, time::Duration};
 use time::PrimitiveDateTime;
 use util::{ResultExt, maybe};
@@ -33,74 +32,50 @@ pub fn router() -> Router {
         )
 }
 
-#[derive(Debug, Deserialize)]
-struct GetExtensionsParams {
-    filter: Option<String>,
-    /// A comma-delimited list of features that the extension must provide.
-    ///
-    /// For example:
-    /// - `themes`
-    /// - `themes,icon-themes`
-    /// - `languages,language-servers`
-    #[serde(default)]
-    provides: Option<String>,
-    #[serde(default)]
-    max_schema_version: i32,
-}
+const UPSTREAM_EXTENSIONS_URL: &str = "https://cloud.zed.dev/extensions";
 
-async fn get_extensions(
-    Extension(app): Extension<Arc<AppState>>,
-    Query(params): Query<GetExtensionsParams>,
-) -> Result<Json<GetExtensionsResponse>> {
-    let provides_filter = params.provides.map(|provides| {
-        provides
-            .split(',')
-            .map(|value| value.trim())
-            .filter_map(|value| ExtensionProvides::from_str(value).ok())
-            .collect::<BTreeSet<_>>()
-    });
+async fn get_extensions(RawQuery(query): RawQuery) -> Result<Json<GetExtensionsResponse>> {
+    let upstream_url = match query {
+        Some(query) => format!("{UPSTREAM_EXTENSIONS_URL}?{query}"),
+        None => UPSTREAM_EXTENSIONS_URL.to_string(),
+    };
 
-    let mut extensions = app
-        .db
-        .get_extensions(
-            params.filter.as_deref(),
-            provides_filter.as_ref(),
-            params.max_schema_version,
-            1_000,
+    let response = reqwest::get(&upstream_url).await.map_err(|error| {
+        tracing::error!(
+            ?error,
+            "failed to proxy request to upstream extensions service"
+        );
+        Error::http(
+            StatusCode::BAD_GATEWAY,
+            "upstream extensions service unavailable".into(),
         )
-        .await?;
-
-    if let Some(filter) = params.filter.as_deref() {
-        let extension_id = filter.to_lowercase();
-        let mut exact_match = None;
-        extensions.retain(|extension| {
-            if extension.id.as_ref() == extension_id {
-                exact_match = Some(extension.clone());
-                false
-            } else {
-                true
-            }
-        });
-        if exact_match.is_none() {
-            exact_match = app
-                .db
-                .get_extensions_by_ids(&[&extension_id], None)
-                .await?
-                .first()
-                .cloned();
-        }
-
-        if let Some(exact_match) = exact_match {
-            extensions.insert(0, exact_match);
-        }
-    };
+    })?;
 
-    if let Some(query) = params.filter.as_deref() {
-        let count = extensions.len();
-        tracing::info!(query, count, "extension_search")
+    let status = response.status();
+    if !status.is_success() {
+        let upstream_status =
+            StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
+        let body = response.text().await.unwrap_or_default();
+        tracing::error!(
+            status = status.as_u16(),
+            body,
+            "upstream extensions service returned an error"
+        );
+        return Err(Error::http(upstream_status, body));
     }
 
-    Ok(Json(GetExtensionsResponse { data: extensions }))
+    let body: GetExtensionsResponse = response.json().await.map_err(|error| {
+        tracing::error!(
+            ?error,
+            "failed to parse response from upstream extensions service"
+        );
+        Error::http(
+            StatusCode::BAD_GATEWAY,
+            "failed to parse upstream response".into(),
+        )
+    })?;
+
+    Ok(Json(body))
 }
 
 #[derive(Debug, Deserialize)]

crates/collab/src/db/queries/extensions.rs 🔗

@@ -8,36 +8,6 @@ use util::ResultExt;
 use super::*;
 
 impl Database {
-    pub async fn get_extensions(
-        &self,
-        filter: Option<&str>,
-        provides_filter: Option<&BTreeSet<ExtensionProvides>>,
-        max_schema_version: i32,
-        limit: usize,
-    ) -> Result<Vec<ExtensionMetadata>> {
-        self.transaction(|tx| async move {
-            let mut condition = Condition::all()
-                .add(
-                    extension::Column::LatestVersion
-                        .into_expr()
-                        .eq(extension_version::Column::Version.into_expr()),
-                )
-                .add(extension_version::Column::SchemaVersion.lte(max_schema_version));
-            if let Some(filter) = filter {
-                let fuzzy_name_filter = Self::fuzzy_like_string(filter);
-                condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter));
-            }
-
-            if let Some(provides_filter) = provides_filter {
-                condition = apply_provides_filter(condition, provides_filter);
-            }
-
-            self.get_extensions_where(condition, Some(limit as u64), &tx)
-                .await
-        })
-        .await
-    }
-
     pub async fn get_extensions_by_ids(
         &self,
         ids: &[&str],
@@ -396,57 +366,6 @@ impl Database {
     }
 }
 
-fn apply_provides_filter(
-    mut condition: Condition,
-    provides_filter: &BTreeSet<ExtensionProvides>,
-) -> Condition {
-    if provides_filter.contains(&ExtensionProvides::Themes) {
-        condition = condition.add(extension_version::Column::ProvidesThemes.eq(true));
-    }
-
-    if provides_filter.contains(&ExtensionProvides::IconThemes) {
-        condition = condition.add(extension_version::Column::ProvidesIconThemes.eq(true));
-    }
-
-    if provides_filter.contains(&ExtensionProvides::Languages) {
-        condition = condition.add(extension_version::Column::ProvidesLanguages.eq(true));
-    }
-
-    if provides_filter.contains(&ExtensionProvides::Grammars) {
-        condition = condition.add(extension_version::Column::ProvidesGrammars.eq(true));
-    }
-
-    if provides_filter.contains(&ExtensionProvides::LanguageServers) {
-        condition = condition.add(extension_version::Column::ProvidesLanguageServers.eq(true));
-    }
-
-    if provides_filter.contains(&ExtensionProvides::ContextServers) {
-        condition = condition.add(extension_version::Column::ProvidesContextServers.eq(true));
-    }
-
-    if provides_filter.contains(&ExtensionProvides::AgentServers) {
-        condition = condition.add(extension_version::Column::ProvidesAgentServers.eq(true));
-    }
-
-    if provides_filter.contains(&ExtensionProvides::SlashCommands) {
-        condition = condition.add(extension_version::Column::ProvidesSlashCommands.eq(true));
-    }
-
-    if provides_filter.contains(&ExtensionProvides::IndexedDocsProviders) {
-        condition = condition.add(extension_version::Column::ProvidesIndexedDocsProviders.eq(true));
-    }
-
-    if provides_filter.contains(&ExtensionProvides::Snippets) {
-        condition = condition.add(extension_version::Column::ProvidesSnippets.eq(true));
-    }
-
-    if provides_filter.contains(&ExtensionProvides::DebugAdapters) {
-        condition = condition.add(extension_version::Column::ProvidesDebugAdapters.eq(true));
-    }
-
-    condition
-}
-
 fn metadata_from_extension_and_version(
     extension: extension::Model,
     version: extension_version::Model,

crates/collab/tests/integration/db_tests/extension_tests.rs 🔗

@@ -8,359 +8,6 @@ use collab::db::{NewExtensionVersion, queries::extensions::convert_time_to_chron
 
 use crate::test_both_dbs;
 
-test_both_dbs!(
-    test_extensions,
-    test_extensions_postgres,
-    test_extensions_sqlite
-);
-
-test_both_dbs!(
-    test_agent_servers_filter,
-    test_agent_servers_filter_postgres,
-    test_agent_servers_filter_sqlite
-);
-
-async fn test_agent_servers_filter(db: &Arc<Database>) {
-    // No extensions initially
-    let versions = db.get_known_extension_versions().await.unwrap();
-    assert!(versions.is_empty());
-
-    // Shared timestamp
-    let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap();
-    let t0 = time::PrimitiveDateTime::new(t0.date(), t0.time());
-
-    // Insert two extensions, only one provides AgentServers
-    db.insert_extension_versions(
-        &[
-            (
-                "ext_agent_servers",
-                vec![NewExtensionVersion {
-                    name: "Agent Servers Provider".into(),
-                    version: semver::Version::parse("1.0.0").unwrap(),
-                    description: "has agent servers".into(),
-                    authors: vec!["author".into()],
-                    repository: "org/agent-servers".into(),
-                    schema_version: 1,
-                    wasm_api_version: None,
-                    provides: BTreeSet::from_iter([ExtensionProvides::AgentServers]),
-                    published_at: t0,
-                }],
-            ),
-            (
-                "ext_plain",
-                vec![NewExtensionVersion {
-                    name: "Plain Extension".into(),
-                    version: semver::Version::parse("0.1.0").unwrap(),
-                    description: "no agent servers".into(),
-                    authors: vec!["author2".into()],
-                    repository: "org/plain".into(),
-                    schema_version: 1,
-                    wasm_api_version: None,
-                    provides: BTreeSet::default(),
-                    published_at: t0,
-                }],
-            ),
-        ]
-        .into_iter()
-        .collect(),
-    )
-    .await
-    .unwrap();
-
-    // Filter by AgentServers provides
-    let provides_filter = BTreeSet::from_iter([ExtensionProvides::AgentServers]);
-
-    let filtered = db
-        .get_extensions(None, Some(&provides_filter), 1, 10)
-        .await
-        .unwrap();
-
-    // Expect only the extension that declared AgentServers
-    assert_eq!(filtered.len(), 1);
-    assert_eq!(filtered[0].id.as_ref(), "ext_agent_servers");
-}
-
-async fn test_extensions(db: &Arc<Database>) {
-    let versions = db.get_known_extension_versions().await.unwrap();
-    assert!(versions.is_empty());
-
-    let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
-    assert!(extensions.is_empty());
-
-    let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap();
-    let t0 = time::PrimitiveDateTime::new(t0.date(), t0.time());
-
-    let t0_chrono = convert_time_to_chrono(t0);
-
-    db.insert_extension_versions(
-        &[
-            (
-                "ext1",
-                vec![
-                    NewExtensionVersion {
-                        name: "Extension 1".into(),
-                        version: semver::Version::parse("0.0.1").unwrap(),
-                        description: "an extension".into(),
-                        authors: vec!["max".into()],
-                        repository: "ext1/repo".into(),
-                        schema_version: 1,
-                        wasm_api_version: None,
-                        provides: BTreeSet::default(),
-                        published_at: t0,
-                    },
-                    NewExtensionVersion {
-                        name: "Extension One".into(),
-                        version: semver::Version::parse("0.0.2").unwrap(),
-                        description: "a good extension".into(),
-                        authors: vec!["max".into(), "marshall".into()],
-                        repository: "ext1/repo".into(),
-                        schema_version: 1,
-                        wasm_api_version: None,
-                        provides: BTreeSet::default(),
-                        published_at: t0,
-                    },
-                ],
-            ),
-            (
-                "ext2",
-                vec![NewExtensionVersion {
-                    name: "Extension Two".into(),
-                    version: semver::Version::parse("0.2.0").unwrap(),
-                    description: "a great extension".into(),
-                    authors: vec!["marshall".into()],
-                    repository: "ext2/repo".into(),
-                    schema_version: 0,
-                    wasm_api_version: None,
-                    provides: BTreeSet::default(),
-                    published_at: t0,
-                }],
-            ),
-        ]
-        .into_iter()
-        .collect(),
-    )
-    .await
-    .unwrap();
-
-    let versions = db.get_known_extension_versions().await.unwrap();
-    assert_eq!(
-        versions,
-        [
-            ("ext1".into(), vec!["0.0.1".into(), "0.0.2".into()]),
-            ("ext2".into(), vec!["0.2.0".into()])
-        ]
-        .into_iter()
-        .collect()
-    );
-
-    // The latest version of each extension is returned.
-    let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
-    assert_eq!(
-        extensions,
-        &[
-            ExtensionMetadata {
-                id: "ext1".into(),
-                manifest: cloud_api_types::ExtensionApiManifest {
-                    name: "Extension One".into(),
-                    version: "0.0.2".into(),
-                    authors: vec!["max".into(), "marshall".into()],
-                    description: Some("a good extension".into()),
-                    repository: "ext1/repo".into(),
-                    schema_version: Some(1),
-                    wasm_api_version: None,
-                    provides: BTreeSet::default(),
-                },
-                published_at: t0_chrono,
-                download_count: 0,
-            },
-            ExtensionMetadata {
-                id: "ext2".into(),
-                manifest: cloud_api_types::ExtensionApiManifest {
-                    name: "Extension Two".into(),
-                    version: "0.2.0".into(),
-                    authors: vec!["marshall".into()],
-                    description: Some("a great extension".into()),
-                    repository: "ext2/repo".into(),
-                    schema_version: Some(0),
-                    wasm_api_version: None,
-                    provides: BTreeSet::default(),
-                },
-                published_at: t0_chrono,
-                download_count: 0
-            },
-        ]
-    );
-
-    // Extensions with too new of a schema version are excluded.
-    let extensions = db.get_extensions(None, None, 0, 5).await.unwrap();
-    assert_eq!(
-        extensions,
-        &[ExtensionMetadata {
-            id: "ext2".into(),
-            manifest: cloud_api_types::ExtensionApiManifest {
-                name: "Extension Two".into(),
-                version: "0.2.0".into(),
-                authors: vec!["marshall".into()],
-                description: Some("a great extension".into()),
-                repository: "ext2/repo".into(),
-                schema_version: Some(0),
-                wasm_api_version: None,
-                provides: BTreeSet::default(),
-            },
-            published_at: t0_chrono,
-            download_count: 0
-        },]
-    );
-
-    // Record extensions being downloaded.
-    for _ in 0..7 {
-        assert!(db.record_extension_download("ext2", "0.0.2").await.unwrap());
-    }
-
-    for _ in 0..3 {
-        assert!(db.record_extension_download("ext1", "0.0.1").await.unwrap());
-    }
-
-    for _ in 0..2 {
-        assert!(db.record_extension_download("ext1", "0.0.2").await.unwrap());
-    }
-
-    // Record download returns false if the extension does not exist.
-    assert!(
-        !db.record_extension_download("no-such-extension", "0.0.2")
-            .await
-            .unwrap()
-    );
-
-    // Extensions are returned in descending order of total downloads.
-    let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
-    assert_eq!(
-        extensions,
-        &[
-            ExtensionMetadata {
-                id: "ext2".into(),
-                manifest: cloud_api_types::ExtensionApiManifest {
-                    name: "Extension Two".into(),
-                    version: "0.2.0".into(),
-                    authors: vec!["marshall".into()],
-                    description: Some("a great extension".into()),
-                    repository: "ext2/repo".into(),
-                    schema_version: Some(0),
-                    wasm_api_version: None,
-                    provides: BTreeSet::default(),
-                },
-                published_at: t0_chrono,
-                download_count: 7
-            },
-            ExtensionMetadata {
-                id: "ext1".into(),
-                manifest: cloud_api_types::ExtensionApiManifest {
-                    name: "Extension One".into(),
-                    version: "0.0.2".into(),
-                    authors: vec!["max".into(), "marshall".into()],
-                    description: Some("a good extension".into()),
-                    repository: "ext1/repo".into(),
-                    schema_version: Some(1),
-                    wasm_api_version: None,
-                    provides: BTreeSet::default(),
-                },
-                published_at: t0_chrono,
-                download_count: 5,
-            },
-        ]
-    );
-
-    // Add more extensions, including a new version of `ext1`, and backfilling
-    // an older version of `ext2`.
-    db.insert_extension_versions(
-        &[
-            (
-                "ext1",
-                vec![NewExtensionVersion {
-                    name: "Extension One".into(),
-                    version: semver::Version::parse("0.0.3").unwrap(),
-                    description: "a real good extension".into(),
-                    authors: vec!["max".into(), "marshall".into()],
-                    repository: "ext1/repo".into(),
-                    schema_version: 1,
-                    wasm_api_version: None,
-                    provides: BTreeSet::default(),
-                    published_at: t0,
-                }],
-            ),
-            (
-                "ext2",
-                vec![NewExtensionVersion {
-                    name: "Extension Two".into(),
-                    version: semver::Version::parse("0.1.0").unwrap(),
-                    description: "an old extension".into(),
-                    authors: vec!["marshall".into()],
-                    repository: "ext2/repo".into(),
-                    schema_version: 0,
-                    wasm_api_version: None,
-                    provides: BTreeSet::default(),
-                    published_at: t0,
-                }],
-            ),
-        ]
-        .into_iter()
-        .collect(),
-    )
-    .await
-    .unwrap();
-
-    let versions = db.get_known_extension_versions().await.unwrap();
-    assert_eq!(
-        versions,
-        [
-            (
-                "ext1".into(),
-                vec!["0.0.1".into(), "0.0.2".into(), "0.0.3".into()]
-            ),
-            ("ext2".into(), vec!["0.1.0".into(), "0.2.0".into()])
-        ]
-        .into_iter()
-        .collect()
-    );
-
-    let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
-    assert_eq!(
-        extensions,
-        &[
-            ExtensionMetadata {
-                id: "ext2".into(),
-                manifest: cloud_api_types::ExtensionApiManifest {
-                    name: "Extension Two".into(),
-                    version: "0.2.0".into(),
-                    authors: vec!["marshall".into()],
-                    description: Some("a great extension".into()),
-                    repository: "ext2/repo".into(),
-                    schema_version: Some(0),
-                    wasm_api_version: None,
-                    provides: BTreeSet::default(),
-                },
-                published_at: t0_chrono,
-                download_count: 7
-            },
-            ExtensionMetadata {
-                id: "ext1".into(),
-                manifest: cloud_api_types::ExtensionApiManifest {
-                    name: "Extension One".into(),
-                    version: "0.0.3".into(),
-                    authors: vec!["max".into(), "marshall".into()],
-                    description: Some("a real good extension".into()),
-                    repository: "ext1/repo".into(),
-                    schema_version: Some(1),
-                    wasm_api_version: None,
-                    provides: BTreeSet::default(),
-                },
-                published_at: t0_chrono,
-                download_count: 5,
-            },
-        ]
-    );
-}
-
 test_both_dbs!(
     test_extensions_by_id,
     test_extensions_by_id_postgres,
@@ -371,9 +18,6 @@ async fn test_extensions_by_id(db: &Arc<Database>) {
     let versions = db.get_known_extension_versions().await.unwrap();
     assert!(versions.is_empty());
 
-    let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
-    assert!(extensions.is_empty());
-
     let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap();
     let t0 = time::PrimitiveDateTime::new(t0.date(), t0.time());