From e1919b41215f2db23c1ff79edae435191f3fd510 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Wed, 5 Feb 2025 17:12:18 -0500 Subject: [PATCH] collab: Add the ability to filter extensions by what they provide (#24315) This PR adds the ability to filter extension results from the extension API by the features that they provide. For instance, to filter down just to extensions that provide icon themes: ``` https://api.zed.dev/extensions?provides=icon-themes ``` Release Notes: - N/A --- crates/collab/src/api/extensions.rs | 28 +++++++++-- crates/collab/src/db/queries/extensions.rs | 48 +++++++++++++++++++ crates/collab/src/db/tests/extension_tests.rs | 12 ++--- crates/rpc/src/extension.rs | 8 +++- 4 files changed, 85 insertions(+), 11 deletions(-) diff --git a/crates/collab/src/api/extensions.rs b/crates/collab/src/api/extensions.rs index e132acaf0b106c4c93012e968e9aa74f4f51b718..73aea4534067f4f6e557c98fac3b27d3260c5d3a 100644 --- a/crates/collab/src/api/extensions.rs +++ b/crates/collab/src/api/extensions.rs @@ -9,10 +9,11 @@ use axum::{ routing::get, Extension, Json, Router, }; -use collections::HashMap; -use rpc::{ExtensionApiManifest, GetExtensionsResponse}; +use collections::{BTreeSet, HashMap}; +use rpc::{ExtensionApiManifest, ExtensionProvides, GetExtensionsResponse}; use semantic_version::SemanticVersion; use serde::Deserialize; +use std::str::FromStr; use std::{sync::Arc, time::Duration}; use time::PrimitiveDateTime; use util::{maybe, ResultExt}; @@ -35,6 +36,14 @@ pub fn router() -> Router { #[derive(Debug, Deserialize)] struct GetExtensionsParams { filter: Option, + /// A comma-delimited list of features that the extension must provide. + /// + /// For example: + /// - `themes` + /// - `themes,icon-themes` + /// - `languages,language-servers` + #[serde(default)] + provides: Option, #[serde(default)] max_schema_version: i32, } @@ -43,9 +52,22 @@ async fn get_extensions( Extension(app): Extension>, Query(params): Query, ) -> Result> { + let provides_filter = params.provides.map(|provides| { + provides + .split(',') + .map(|value| value.trim()) + .filter_map(|value| ExtensionProvides::from_str(value).ok()) + .collect::>() + }); + let mut extensions = app .db - .get_extensions(params.filter.as_deref(), params.max_schema_version, 500) + .get_extensions( + params.filter.as_deref(), + provides_filter.as_ref(), + params.max_schema_version, + 500, + ) .await?; if let Some(filter) = params.filter.as_deref() { diff --git a/crates/collab/src/db/queries/extensions.rs b/crates/collab/src/db/queries/extensions.rs index 54f47ae45ee015e80cf916c0d06703f0faff2e01..2b76e12335108a6f57c24b6ba17dd19c2d998708 100644 --- a/crates/collab/src/db/queries/extensions.rs +++ b/crates/collab/src/db/queries/extensions.rs @@ -10,6 +10,7 @@ impl Database { pub async fn get_extensions( &self, filter: Option<&str>, + provides_filter: Option<&BTreeSet>, max_schema_version: i32, limit: usize, ) -> Result> { @@ -26,6 +27,10 @@ impl Database { condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter)); } + if let Some(provides_filter) = provides_filter { + condition = apply_provides_filter(condition, provides_filter); + } + self.get_extensions_where(condition, Some(limit as u64), &tx) .await }) @@ -385,6 +390,49 @@ impl Database { } } +fn apply_provides_filter( + mut condition: Condition, + provides_filter: &BTreeSet, +) -> Condition { + if provides_filter.contains(&ExtensionProvides::Themes) { + condition = condition.add(extension_version::Column::ProvidesThemes.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::IconThemes) { + condition = condition.add(extension_version::Column::ProvidesIconThemes.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::Languages) { + condition = condition.add(extension_version::Column::ProvidesLanguages.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::Grammars) { + condition = condition.add(extension_version::Column::ProvidesGrammars.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::LanguageServers) { + condition = condition.add(extension_version::Column::ProvidesLanguageServers.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::ContextServers) { + condition = condition.add(extension_version::Column::ProvidesContextServers.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::SlashCommands) { + condition = condition.add(extension_version::Column::ProvidesSlashCommands.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::IndexedDocsProviders) { + condition = condition.add(extension_version::Column::ProvidesIndexedDocsProviders.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::Snippets) { + condition = condition.add(extension_version::Column::ProvidesSnippets.eq(true)); + } + + condition +} + fn metadata_from_extension_and_version( extension: extension::Model, version: extension_version::Model, diff --git a/crates/collab/src/db/tests/extension_tests.rs b/crates/collab/src/db/tests/extension_tests.rs index f7a5398d3c1849c0bdb689186fcd50f914bc31bb..460d74ffc0588c8243962a1a2b5e9d4bf9006fe8 100644 --- a/crates/collab/src/db/tests/extension_tests.rs +++ b/crates/collab/src/db/tests/extension_tests.rs @@ -20,7 +20,7 @@ async fn test_extensions(db: &Arc) { let versions = db.get_known_extension_versions().await.unwrap(); assert!(versions.is_empty()); - let extensions = db.get_extensions(None, 1, 5).await.unwrap(); + let extensions = db.get_extensions(None, None, 1, 5).await.unwrap(); assert!(extensions.is_empty()); let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap(); @@ -90,7 +90,7 @@ async fn test_extensions(db: &Arc) { ); // The latest version of each extension is returned. - let extensions = db.get_extensions(None, 1, 5).await.unwrap(); + let extensions = db.get_extensions(None, None, 1, 5).await.unwrap(); assert_eq!( extensions, &[ @@ -128,7 +128,7 @@ async fn test_extensions(db: &Arc) { ); // Extensions with too new of a schema version are excluded. - let extensions = db.get_extensions(None, 0, 5).await.unwrap(); + let extensions = db.get_extensions(None, None, 0, 5).await.unwrap(); assert_eq!( extensions, &[ExtensionMetadata { @@ -168,7 +168,7 @@ async fn test_extensions(db: &Arc) { .unwrap()); // Extensions are returned in descending order of total downloads. - let extensions = db.get_extensions(None, 1, 5).await.unwrap(); + let extensions = db.get_extensions(None, None, 1, 5).await.unwrap(); assert_eq!( extensions, &[ @@ -258,7 +258,7 @@ async fn test_extensions(db: &Arc) { .collect() ); - let extensions = db.get_extensions(None, 1, 5).await.unwrap(); + let extensions = db.get_extensions(None, None, 1, 5).await.unwrap(); assert_eq!( extensions, &[ @@ -306,7 +306,7 @@ async fn test_extensions_by_id(db: &Arc) { let versions = db.get_known_extension_versions().await.unwrap(); assert!(versions.is_empty()); - let extensions = db.get_extensions(None, 1, 5).await.unwrap(); + let extensions = db.get_extensions(None, None, 1, 5).await.unwrap(); assert!(extensions.is_empty()); let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap(); diff --git a/crates/rpc/src/extension.rs b/crates/rpc/src/extension.rs index 67b9116b83b73cdac50c002f7b61de7c6682ca6c..f1dcdc28d669251e811a15d8fa7cafb35d7eebcf 100644 --- a/crates/rpc/src/extension.rs +++ b/crates/rpc/src/extension.rs @@ -1,8 +1,9 @@ use std::collections::BTreeSet; +use std::sync::Arc; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; -use std::sync::Arc; +use strum::EnumString; #[derive(Clone, Serialize, Deserialize, Debug, PartialEq)] pub struct ExtensionApiManifest { @@ -17,8 +18,11 @@ pub struct ExtensionApiManifest { pub provides: BTreeSet, } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize, EnumString, +)] #[serde(rename_all = "kebab-case")] +#[strum(serialize_all = "kebab-case")] pub enum ExtensionProvides { Themes, IconThemes,