diff --git a/crates/collab/src/api/extensions.rs b/crates/collab/src/api/extensions.rs index b7bff3d445e655b5a00150c6013e654849d42c67..e34de94692e66d23b4b21a74bf839674f0a683c3 100644 --- a/crates/collab/src/api/extensions.rs +++ b/crates/collab/src/api/extensions.rs @@ -4,16 +4,15 @@ use anyhow::Context as _; use aws_sdk_s3::presigning::PresigningConfig; use axum::{ Extension, Json, Router, - extract::{Path, Query}, + extract::{Path, Query, RawQuery}, http::StatusCode, response::Redirect, routing::get, }; -use cloud_api_types::{ExtensionApiManifest, ExtensionProvides, GetExtensionsResponse}; -use collections::{BTreeSet, HashMap}; +use cloud_api_types::{ExtensionApiManifest, GetExtensionsResponse}; +use collections::HashMap; use semver::Version as SemanticVersion; use serde::Deserialize; -use std::str::FromStr; use std::{sync::Arc, time::Duration}; use time::PrimitiveDateTime; use util::{ResultExt, maybe}; @@ -33,74 +32,50 @@ pub fn router() -> Router { ) } -#[derive(Debug, Deserialize)] -struct GetExtensionsParams { - filter: Option, - /// A comma-delimited list of features that the extension must provide. - /// - /// For example: - /// - `themes` - /// - `themes,icon-themes` - /// - `languages,language-servers` - #[serde(default)] - provides: Option, - #[serde(default)] - max_schema_version: i32, -} +const UPSTREAM_EXTENSIONS_URL: &str = "https://cloud.zed.dev/extensions"; -async fn get_extensions( - Extension(app): Extension>, - Query(params): Query, -) -> Result> { - let provides_filter = params.provides.map(|provides| { - provides - .split(',') - .map(|value| value.trim()) - .filter_map(|value| ExtensionProvides::from_str(value).ok()) - .collect::>() - }); +async fn get_extensions(RawQuery(query): RawQuery) -> Result> { + let upstream_url = match query { + Some(query) => format!("{UPSTREAM_EXTENSIONS_URL}?{query}"), + None => UPSTREAM_EXTENSIONS_URL.to_string(), + }; - let mut extensions = app - .db - .get_extensions( - params.filter.as_deref(), - provides_filter.as_ref(), - params.max_schema_version, - 1_000, + let response = reqwest::get(&upstream_url).await.map_err(|error| { + tracing::error!( + ?error, + "failed to proxy request to upstream extensions service" + ); + Error::http( + StatusCode::BAD_GATEWAY, + "upstream extensions service unavailable".into(), ) - .await?; - - if let Some(filter) = params.filter.as_deref() { - let extension_id = filter.to_lowercase(); - let mut exact_match = None; - extensions.retain(|extension| { - if extension.id.as_ref() == extension_id { - exact_match = Some(extension.clone()); - false - } else { - true - } - }); - if exact_match.is_none() { - exact_match = app - .db - .get_extensions_by_ids(&[&extension_id], None) - .await? - .first() - .cloned(); - } - - if let Some(exact_match) = exact_match { - extensions.insert(0, exact_match); - } - }; + })?; - if let Some(query) = params.filter.as_deref() { - let count = extensions.len(); - tracing::info!(query, count, "extension_search") + let status = response.status(); + if !status.is_success() { + let upstream_status = + StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); + let body = response.text().await.unwrap_or_default(); + tracing::error!( + status = status.as_u16(), + body, + "upstream extensions service returned an error" + ); + return Err(Error::http(upstream_status, body)); } - Ok(Json(GetExtensionsResponse { data: extensions })) + let body: GetExtensionsResponse = response.json().await.map_err(|error| { + tracing::error!( + ?error, + "failed to parse response from upstream extensions service" + ); + Error::http( + StatusCode::BAD_GATEWAY, + "failed to parse upstream response".into(), + ) + })?; + + Ok(Json(body)) } #[derive(Debug, Deserialize)] diff --git a/crates/collab/src/db/queries/extensions.rs b/crates/collab/src/db/queries/extensions.rs index be1196ecf9b3af1a580bc2dbda45a7adfb2ede38..dd2ea4aa6b718b42921011299a7892a811c73c49 100644 --- a/crates/collab/src/db/queries/extensions.rs +++ b/crates/collab/src/db/queries/extensions.rs @@ -8,36 +8,6 @@ use util::ResultExt; use super::*; impl Database { - pub async fn get_extensions( - &self, - filter: Option<&str>, - provides_filter: Option<&BTreeSet>, - max_schema_version: i32, - limit: usize, - ) -> Result> { - self.transaction(|tx| async move { - let mut condition = Condition::all() - .add( - extension::Column::LatestVersion - .into_expr() - .eq(extension_version::Column::Version.into_expr()), - ) - .add(extension_version::Column::SchemaVersion.lte(max_schema_version)); - if let Some(filter) = filter { - let fuzzy_name_filter = Self::fuzzy_like_string(filter); - condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter)); - } - - if let Some(provides_filter) = provides_filter { - condition = apply_provides_filter(condition, provides_filter); - } - - self.get_extensions_where(condition, Some(limit as u64), &tx) - .await - }) - .await - } - pub async fn get_extensions_by_ids( &self, ids: &[&str], @@ -396,57 +366,6 @@ impl Database { } } -fn apply_provides_filter( - mut condition: Condition, - provides_filter: &BTreeSet, -) -> Condition { - if provides_filter.contains(&ExtensionProvides::Themes) { - condition = condition.add(extension_version::Column::ProvidesThemes.eq(true)); - } - - if provides_filter.contains(&ExtensionProvides::IconThemes) { - condition = condition.add(extension_version::Column::ProvidesIconThemes.eq(true)); - } - - if provides_filter.contains(&ExtensionProvides::Languages) { - condition = condition.add(extension_version::Column::ProvidesLanguages.eq(true)); - } - - if provides_filter.contains(&ExtensionProvides::Grammars) { - condition = condition.add(extension_version::Column::ProvidesGrammars.eq(true)); - } - - if provides_filter.contains(&ExtensionProvides::LanguageServers) { - condition = condition.add(extension_version::Column::ProvidesLanguageServers.eq(true)); - } - - if provides_filter.contains(&ExtensionProvides::ContextServers) { - condition = condition.add(extension_version::Column::ProvidesContextServers.eq(true)); - } - - if provides_filter.contains(&ExtensionProvides::AgentServers) { - condition = condition.add(extension_version::Column::ProvidesAgentServers.eq(true)); - } - - if provides_filter.contains(&ExtensionProvides::SlashCommands) { - condition = condition.add(extension_version::Column::ProvidesSlashCommands.eq(true)); - } - - if provides_filter.contains(&ExtensionProvides::IndexedDocsProviders) { - condition = condition.add(extension_version::Column::ProvidesIndexedDocsProviders.eq(true)); - } - - if provides_filter.contains(&ExtensionProvides::Snippets) { - condition = condition.add(extension_version::Column::ProvidesSnippets.eq(true)); - } - - if provides_filter.contains(&ExtensionProvides::DebugAdapters) { - condition = condition.add(extension_version::Column::ProvidesDebugAdapters.eq(true)); - } - - condition -} - fn metadata_from_extension_and_version( extension: extension::Model, version: extension_version::Model, diff --git a/crates/collab/tests/integration/db_tests/extension_tests.rs b/crates/collab/tests/integration/db_tests/extension_tests.rs index 993188feb8d40aac35c916f58cddb8baf0a72d38..97bd8e8c65be14de9b50a609e9af7bea6b24e427 100644 --- a/crates/collab/tests/integration/db_tests/extension_tests.rs +++ b/crates/collab/tests/integration/db_tests/extension_tests.rs @@ -8,359 +8,6 @@ use collab::db::{NewExtensionVersion, queries::extensions::convert_time_to_chron use crate::test_both_dbs; -test_both_dbs!( - test_extensions, - test_extensions_postgres, - test_extensions_sqlite -); - -test_both_dbs!( - test_agent_servers_filter, - test_agent_servers_filter_postgres, - test_agent_servers_filter_sqlite -); - -async fn test_agent_servers_filter(db: &Arc) { - // No extensions initially - let versions = db.get_known_extension_versions().await.unwrap(); - assert!(versions.is_empty()); - - // Shared timestamp - let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap(); - let t0 = time::PrimitiveDateTime::new(t0.date(), t0.time()); - - // Insert two extensions, only one provides AgentServers - db.insert_extension_versions( - &[ - ( - "ext_agent_servers", - vec![NewExtensionVersion { - name: "Agent Servers Provider".into(), - version: semver::Version::parse("1.0.0").unwrap(), - description: "has agent servers".into(), - authors: vec!["author".into()], - repository: "org/agent-servers".into(), - schema_version: 1, - wasm_api_version: None, - provides: BTreeSet::from_iter([ExtensionProvides::AgentServers]), - published_at: t0, - }], - ), - ( - "ext_plain", - vec![NewExtensionVersion { - name: "Plain Extension".into(), - version: semver::Version::parse("0.1.0").unwrap(), - description: "no agent servers".into(), - authors: vec!["author2".into()], - repository: "org/plain".into(), - schema_version: 1, - wasm_api_version: None, - provides: BTreeSet::default(), - published_at: t0, - }], - ), - ] - .into_iter() - .collect(), - ) - .await - .unwrap(); - - // Filter by AgentServers provides - let provides_filter = BTreeSet::from_iter([ExtensionProvides::AgentServers]); - - let filtered = db - .get_extensions(None, Some(&provides_filter), 1, 10) - .await - .unwrap(); - - // Expect only the extension that declared AgentServers - assert_eq!(filtered.len(), 1); - assert_eq!(filtered[0].id.as_ref(), "ext_agent_servers"); -} - -async fn test_extensions(db: &Arc) { - let versions = db.get_known_extension_versions().await.unwrap(); - assert!(versions.is_empty()); - - let extensions = db.get_extensions(None, None, 1, 5).await.unwrap(); - assert!(extensions.is_empty()); - - let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap(); - let t0 = time::PrimitiveDateTime::new(t0.date(), t0.time()); - - let t0_chrono = convert_time_to_chrono(t0); - - db.insert_extension_versions( - &[ - ( - "ext1", - vec![ - NewExtensionVersion { - name: "Extension 1".into(), - version: semver::Version::parse("0.0.1").unwrap(), - description: "an extension".into(), - authors: vec!["max".into()], - repository: "ext1/repo".into(), - schema_version: 1, - wasm_api_version: None, - provides: BTreeSet::default(), - published_at: t0, - }, - NewExtensionVersion { - name: "Extension One".into(), - version: semver::Version::parse("0.0.2").unwrap(), - description: "a good extension".into(), - authors: vec!["max".into(), "marshall".into()], - repository: "ext1/repo".into(), - schema_version: 1, - wasm_api_version: None, - provides: BTreeSet::default(), - published_at: t0, - }, - ], - ), - ( - "ext2", - vec![NewExtensionVersion { - name: "Extension Two".into(), - version: semver::Version::parse("0.2.0").unwrap(), - description: "a great extension".into(), - authors: vec!["marshall".into()], - repository: "ext2/repo".into(), - schema_version: 0, - wasm_api_version: None, - provides: BTreeSet::default(), - published_at: t0, - }], - ), - ] - .into_iter() - .collect(), - ) - .await - .unwrap(); - - let versions = db.get_known_extension_versions().await.unwrap(); - assert_eq!( - versions, - [ - ("ext1".into(), vec!["0.0.1".into(), "0.0.2".into()]), - ("ext2".into(), vec!["0.2.0".into()]) - ] - .into_iter() - .collect() - ); - - // The latest version of each extension is returned. - let extensions = db.get_extensions(None, None, 1, 5).await.unwrap(); - assert_eq!( - extensions, - &[ - ExtensionMetadata { - id: "ext1".into(), - manifest: cloud_api_types::ExtensionApiManifest { - name: "Extension One".into(), - version: "0.0.2".into(), - authors: vec!["max".into(), "marshall".into()], - description: Some("a good extension".into()), - repository: "ext1/repo".into(), - schema_version: Some(1), - wasm_api_version: None, - provides: BTreeSet::default(), - }, - published_at: t0_chrono, - download_count: 0, - }, - ExtensionMetadata { - id: "ext2".into(), - manifest: cloud_api_types::ExtensionApiManifest { - name: "Extension Two".into(), - version: "0.2.0".into(), - authors: vec!["marshall".into()], - description: Some("a great extension".into()), - repository: "ext2/repo".into(), - schema_version: Some(0), - wasm_api_version: None, - provides: BTreeSet::default(), - }, - published_at: t0_chrono, - download_count: 0 - }, - ] - ); - - // Extensions with too new of a schema version are excluded. - let extensions = db.get_extensions(None, None, 0, 5).await.unwrap(); - assert_eq!( - extensions, - &[ExtensionMetadata { - id: "ext2".into(), - manifest: cloud_api_types::ExtensionApiManifest { - name: "Extension Two".into(), - version: "0.2.0".into(), - authors: vec!["marshall".into()], - description: Some("a great extension".into()), - repository: "ext2/repo".into(), - schema_version: Some(0), - wasm_api_version: None, - provides: BTreeSet::default(), - }, - published_at: t0_chrono, - download_count: 0 - },] - ); - - // Record extensions being downloaded. - for _ in 0..7 { - assert!(db.record_extension_download("ext2", "0.0.2").await.unwrap()); - } - - for _ in 0..3 { - assert!(db.record_extension_download("ext1", "0.0.1").await.unwrap()); - } - - for _ in 0..2 { - assert!(db.record_extension_download("ext1", "0.0.2").await.unwrap()); - } - - // Record download returns false if the extension does not exist. - assert!( - !db.record_extension_download("no-such-extension", "0.0.2") - .await - .unwrap() - ); - - // Extensions are returned in descending order of total downloads. - let extensions = db.get_extensions(None, None, 1, 5).await.unwrap(); - assert_eq!( - extensions, - &[ - ExtensionMetadata { - id: "ext2".into(), - manifest: cloud_api_types::ExtensionApiManifest { - name: "Extension Two".into(), - version: "0.2.0".into(), - authors: vec!["marshall".into()], - description: Some("a great extension".into()), - repository: "ext2/repo".into(), - schema_version: Some(0), - wasm_api_version: None, - provides: BTreeSet::default(), - }, - published_at: t0_chrono, - download_count: 7 - }, - ExtensionMetadata { - id: "ext1".into(), - manifest: cloud_api_types::ExtensionApiManifest { - name: "Extension One".into(), - version: "0.0.2".into(), - authors: vec!["max".into(), "marshall".into()], - description: Some("a good extension".into()), - repository: "ext1/repo".into(), - schema_version: Some(1), - wasm_api_version: None, - provides: BTreeSet::default(), - }, - published_at: t0_chrono, - download_count: 5, - }, - ] - ); - - // Add more extensions, including a new version of `ext1`, and backfilling - // an older version of `ext2`. - db.insert_extension_versions( - &[ - ( - "ext1", - vec![NewExtensionVersion { - name: "Extension One".into(), - version: semver::Version::parse("0.0.3").unwrap(), - description: "a real good extension".into(), - authors: vec!["max".into(), "marshall".into()], - repository: "ext1/repo".into(), - schema_version: 1, - wasm_api_version: None, - provides: BTreeSet::default(), - published_at: t0, - }], - ), - ( - "ext2", - vec![NewExtensionVersion { - name: "Extension Two".into(), - version: semver::Version::parse("0.1.0").unwrap(), - description: "an old extension".into(), - authors: vec!["marshall".into()], - repository: "ext2/repo".into(), - schema_version: 0, - wasm_api_version: None, - provides: BTreeSet::default(), - published_at: t0, - }], - ), - ] - .into_iter() - .collect(), - ) - .await - .unwrap(); - - let versions = db.get_known_extension_versions().await.unwrap(); - assert_eq!( - versions, - [ - ( - "ext1".into(), - vec!["0.0.1".into(), "0.0.2".into(), "0.0.3".into()] - ), - ("ext2".into(), vec!["0.1.0".into(), "0.2.0".into()]) - ] - .into_iter() - .collect() - ); - - let extensions = db.get_extensions(None, None, 1, 5).await.unwrap(); - assert_eq!( - extensions, - &[ - ExtensionMetadata { - id: "ext2".into(), - manifest: cloud_api_types::ExtensionApiManifest { - name: "Extension Two".into(), - version: "0.2.0".into(), - authors: vec!["marshall".into()], - description: Some("a great extension".into()), - repository: "ext2/repo".into(), - schema_version: Some(0), - wasm_api_version: None, - provides: BTreeSet::default(), - }, - published_at: t0_chrono, - download_count: 7 - }, - ExtensionMetadata { - id: "ext1".into(), - manifest: cloud_api_types::ExtensionApiManifest { - name: "Extension One".into(), - version: "0.0.3".into(), - authors: vec!["max".into(), "marshall".into()], - description: Some("a real good extension".into()), - repository: "ext1/repo".into(), - schema_version: Some(1), - wasm_api_version: None, - provides: BTreeSet::default(), - }, - published_at: t0_chrono, - download_count: 5, - }, - ] - ); -} - test_both_dbs!( test_extensions_by_id, test_extensions_by_id_postgres, @@ -371,9 +18,6 @@ async fn test_extensions_by_id(db: &Arc) { let versions = db.get_known_extension_versions().await.unwrap(); assert!(versions.is_empty()); - let extensions = db.get_extensions(None, None, 1, 5).await.unwrap(); - assert!(extensions.is_empty()); - let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap(); let t0 = time::PrimitiveDateTime::new(t0.date(), t0.time());