Detailed changes
@@ -9,10 +9,11 @@ use axum::{
routing::get,
Extension, Json, Router,
};
-use collections::HashMap;
-use rpc::{ExtensionApiManifest, GetExtensionsResponse};
+use collections::{BTreeSet, HashMap};
+use rpc::{ExtensionApiManifest, ExtensionProvides, GetExtensionsResponse};
use semantic_version::SemanticVersion;
use serde::Deserialize;
+use std::str::FromStr;
use std::{sync::Arc, time::Duration};
use time::PrimitiveDateTime;
use util::{maybe, ResultExt};
@@ -35,6 +36,14 @@ pub fn router() -> Router {
#[derive(Debug, Deserialize)]
struct GetExtensionsParams {
filter: Option<String>,
+ /// A comma-delimited list of features that the extension must provide.
+ ///
+ /// For example:
+ /// - `themes`
+ /// - `themes,icon-themes`
+ /// - `languages,language-servers`
+ #[serde(default)]
+ provides: Option<String>,
#[serde(default)]
max_schema_version: i32,
}
@@ -43,9 +52,22 @@ async fn get_extensions(
Extension(app): Extension<Arc<AppState>>,
Query(params): Query<GetExtensionsParams>,
) -> Result<Json<GetExtensionsResponse>> {
+ let provides_filter = params.provides.map(|provides| {
+ provides
+ .split(',')
+ .map(|value| value.trim())
+ .filter_map(|value| ExtensionProvides::from_str(value).ok())
+ .collect::<BTreeSet<_>>()
+ });
+
let mut extensions = app
.db
- .get_extensions(params.filter.as_deref(), params.max_schema_version, 500)
+ .get_extensions(
+ params.filter.as_deref(),
+ provides_filter.as_ref(),
+ params.max_schema_version,
+ 500,
+ )
.await?;
if let Some(filter) = params.filter.as_deref() {
@@ -10,6 +10,7 @@ impl Database {
pub async fn get_extensions(
&self,
filter: Option<&str>,
+ provides_filter: Option<&BTreeSet<ExtensionProvides>>,
max_schema_version: i32,
limit: usize,
) -> Result<Vec<ExtensionMetadata>> {
@@ -26,6 +27,10 @@ impl Database {
condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter));
}
+ if let Some(provides_filter) = provides_filter {
+ condition = apply_provides_filter(condition, provides_filter);
+ }
+
self.get_extensions_where(condition, Some(limit as u64), &tx)
.await
})
@@ -385,6 +390,49 @@ impl Database {
}
}
+fn apply_provides_filter(
+ mut condition: Condition,
+ provides_filter: &BTreeSet<ExtensionProvides>,
+) -> Condition {
+ if provides_filter.contains(&ExtensionProvides::Themes) {
+ condition = condition.add(extension_version::Column::ProvidesThemes.eq(true));
+ }
+
+ if provides_filter.contains(&ExtensionProvides::IconThemes) {
+ condition = condition.add(extension_version::Column::ProvidesIconThemes.eq(true));
+ }
+
+ if provides_filter.contains(&ExtensionProvides::Languages) {
+ condition = condition.add(extension_version::Column::ProvidesLanguages.eq(true));
+ }
+
+ if provides_filter.contains(&ExtensionProvides::Grammars) {
+ condition = condition.add(extension_version::Column::ProvidesGrammars.eq(true));
+ }
+
+ if provides_filter.contains(&ExtensionProvides::LanguageServers) {
+ condition = condition.add(extension_version::Column::ProvidesLanguageServers.eq(true));
+ }
+
+ if provides_filter.contains(&ExtensionProvides::ContextServers) {
+ condition = condition.add(extension_version::Column::ProvidesContextServers.eq(true));
+ }
+
+ if provides_filter.contains(&ExtensionProvides::SlashCommands) {
+ condition = condition.add(extension_version::Column::ProvidesSlashCommands.eq(true));
+ }
+
+ if provides_filter.contains(&ExtensionProvides::IndexedDocsProviders) {
+ condition = condition.add(extension_version::Column::ProvidesIndexedDocsProviders.eq(true));
+ }
+
+ if provides_filter.contains(&ExtensionProvides::Snippets) {
+ condition = condition.add(extension_version::Column::ProvidesSnippets.eq(true));
+ }
+
+ condition
+}
+
fn metadata_from_extension_and_version(
extension: extension::Model,
version: extension_version::Model,
@@ -20,7 +20,7 @@ async fn test_extensions(db: &Arc<Database>) {
let versions = db.get_known_extension_versions().await.unwrap();
assert!(versions.is_empty());
- let extensions = db.get_extensions(None, 1, 5).await.unwrap();
+ let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
assert!(extensions.is_empty());
let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap();
@@ -90,7 +90,7 @@ async fn test_extensions(db: &Arc<Database>) {
);
// The latest version of each extension is returned.
- let extensions = db.get_extensions(None, 1, 5).await.unwrap();
+ let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
assert_eq!(
extensions,
&[
@@ -128,7 +128,7 @@ async fn test_extensions(db: &Arc<Database>) {
);
// Extensions with too new of a schema version are excluded.
- let extensions = db.get_extensions(None, 0, 5).await.unwrap();
+ let extensions = db.get_extensions(None, None, 0, 5).await.unwrap();
assert_eq!(
extensions,
&[ExtensionMetadata {
@@ -168,7 +168,7 @@ async fn test_extensions(db: &Arc<Database>) {
.unwrap());
// Extensions are returned in descending order of total downloads.
- let extensions = db.get_extensions(None, 1, 5).await.unwrap();
+ let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
assert_eq!(
extensions,
&[
@@ -258,7 +258,7 @@ async fn test_extensions(db: &Arc<Database>) {
.collect()
);
- let extensions = db.get_extensions(None, 1, 5).await.unwrap();
+ let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
assert_eq!(
extensions,
&[
@@ -306,7 +306,7 @@ async fn test_extensions_by_id(db: &Arc<Database>) {
let versions = db.get_known_extension_versions().await.unwrap();
assert!(versions.is_empty());
- let extensions = db.get_extensions(None, 1, 5).await.unwrap();
+ let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
assert!(extensions.is_empty());
let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap();
@@ -1,8 +1,9 @@
use std::collections::BTreeSet;
+use std::sync::Arc;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
-use std::sync::Arc;
+use strum::EnumString;
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct ExtensionApiManifest {
@@ -17,8 +18,11 @@ pub struct ExtensionApiManifest {
pub provides: BTreeSet<ExtensionProvides>,
}
-#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
+#[derive(
+ Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize, EnumString,
+)]
#[serde(rename_all = "kebab-case")]
+#[strum(serialize_all = "kebab-case")]
pub enum ExtensionProvides {
Themes,
IconThemes,