From 83ce783856190afdbd2b9201551ce3e91dbfdef9 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Mon, 1 Apr 2024 17:10:30 -0400 Subject: [PATCH] Respect version constraints when installing extensions (#10052) This PR modifies the extension installation and update process to respect version constraints (schema version and Wasm API version) to ensure only compatible versions of extensions are able to be installed. To achieve this there is a new `GET /extensions/updates` endpoint that will return extension versions based on the provided constraints. Release Notes: - N/A --------- Co-authored-by: Max --- crates/collab/src/api/extensions.rs | 55 ++++++++- crates/collab/src/db.rs | 25 ++-- crates/collab/src/db/queries/extensions.rs | 107 ++++++++++++++---- crates/collab/src/db/tests/extension_tests.rs | 106 +++++++++++++++++ crates/extension/Cargo.toml | 4 - .../extension/src/extension_json_schemas.rs | 17 --- crates/extension/src/extension_store.rs | 54 +++++++-- crates/extension/src/wasm_host/wit.rs | 9 +- crates/extensions_ui/src/extensions_ui.rs | 3 +- 9 files changed, 303 insertions(+), 77 deletions(-) delete mode 100644 crates/extension/src/extension_json_schemas.rs diff --git a/crates/collab/src/api/extensions.rs b/crates/collab/src/api/extensions.rs index 2e8acd6c7b5ea1beb87cb08dd0a8504f7482d834..fd0909a2f46792791ab05f81bf85aa42a1c6b347 100644 --- a/crates/collab/src/api/extensions.rs +++ b/crates/collab/src/api/extensions.rs @@ -1,3 +1,4 @@ +use crate::db::ExtensionVersionConstraints; use crate::{db::NewExtensionVersion, AppState, Error, Result}; use anyhow::{anyhow, Context as _}; use aws_sdk_s3::presigning::PresigningConfig; @@ -10,14 +11,16 @@ use axum::{ }; use collections::HashMap; use rpc::{ExtensionApiManifest, GetExtensionsResponse}; +use semantic_version::SemanticVersion; use serde::Deserialize; use std::{sync::Arc, time::Duration}; use time::PrimitiveDateTime; -use util::ResultExt; +use util::{maybe, ResultExt}; pub fn router() -> Router { Router::new() .route("/extensions", get(get_extensions)) + .route("/extensions/updates", get(get_extension_updates)) .route("/extensions/:extension_id", get(get_extension_versions)) .route( "/extensions/:extension_id/download", @@ -48,9 +51,7 @@ async fn get_extensions( .map(|s| s.split(',').map(|s| s.trim()).collect::>()); let extensions = if let Some(extension_ids) = extension_ids { - app.db - .get_extensions_by_ids(&extension_ids, params.max_schema_version) - .await? + app.db.get_extensions_by_ids(&extension_ids, None).await? } else { app.db .get_extensions(params.filter.as_deref(), params.max_schema_version, 500) @@ -60,6 +61,34 @@ async fn get_extensions( Ok(Json(GetExtensionsResponse { data: extensions })) } +#[derive(Debug, Deserialize)] +struct GetExtensionUpdatesParams { + ids: String, + min_schema_version: i32, + max_schema_version: i32, + min_wasm_api_version: SemanticVersion, + max_wasm_api_version: SemanticVersion, +} + +async fn get_extension_updates( + Extension(app): Extension>, + Query(params): Query, +) -> Result> { + let constraints = ExtensionVersionConstraints { + schema_versions: params.min_schema_version..=params.max_schema_version, + wasm_api_versions: params.min_wasm_api_version..=params.max_wasm_api_version, + }; + + let extension_ids = params.ids.split(',').map(|s| s.trim()).collect::>(); + + let extensions = app + .db + .get_extensions_by_ids(&extension_ids, Some(&constraints)) + .await?; + + Ok(Json(GetExtensionsResponse { data: extensions })) +} + #[derive(Debug, Deserialize)] struct GetExtensionVersionsParams { extension_id: String, @@ -79,15 +108,31 @@ async fn get_extension_versions( #[derive(Debug, Deserialize)] struct DownloadLatestExtensionParams { extension_id: String, + min_schema_version: Option, + max_schema_version: Option, + min_wasm_api_version: Option, + max_wasm_api_version: Option, } async fn download_latest_extension( Extension(app): Extension>, Path(params): Path, ) -> Result { + let constraints = maybe!({ + let min_schema_version = params.min_schema_version?; + let max_schema_version = params.max_schema_version?; + let min_wasm_api_version = params.min_wasm_api_version?; + let max_wasm_api_version = params.max_wasm_api_version?; + + Some(ExtensionVersionConstraints { + schema_versions: min_schema_version..=max_schema_version, + wasm_api_versions: min_wasm_api_version..=max_wasm_api_version, + }) + }); + let extension = app .db - .get_extension(¶ms.extension_id) + .get_extension(¶ms.extension_id, constraints.as_ref()) .await? .ok_or_else(|| anyhow!("unknown extension"))?; download_extension( diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 637a8c31f5622145e79e268be8df66a49f7ae50b..0527e070ea0977a9558f24b4558d52cc0d97f1d9 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -21,11 +21,13 @@ use sea_orm::{ FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement, TransactionTrait, }; -use serde::{ser::Error as _, Deserialize, Serialize, Serializer}; +use semantic_version::SemanticVersion; +use serde::{Deserialize, Serialize}; use sqlx::{ migrate::{Migrate, Migration, MigrationSource}, Connection, }; +use std::ops::RangeInclusive; use std::{ fmt::Write as _, future::Future, @@ -36,7 +38,7 @@ use std::{ sync::Arc, time::Duration, }; -use time::{format_description::well_known::iso8601, PrimitiveDateTime}; +use time::PrimitiveDateTime; use tokio::sync::{Mutex, OwnedMutexGuard}; #[cfg(test)] @@ -730,20 +732,7 @@ pub struct NewExtensionVersion { pub published_at: PrimitiveDateTime, } -pub fn serialize_iso8601( - datetime: &PrimitiveDateTime, - serializer: S, -) -> Result { - const SERDE_CONFIG: iso8601::EncodedConfig = iso8601::Config::DEFAULT - .set_year_is_six_digits(false) - .set_time_precision(iso8601::TimePrecision::Second { - decimal_digits: None, - }) - .encode(); - - datetime - .assume_utc() - .format(&time::format_description::well_known::Iso8601::) - .map_err(S::Error::custom)? - .serialize(serializer) +pub struct ExtensionVersionConstraints { + pub schema_versions: RangeInclusive, + pub wasm_api_versions: RangeInclusive, } diff --git a/crates/collab/src/db/queries/extensions.rs b/crates/collab/src/db/queries/extensions.rs index fc3def1d6d99e0ee445a7bf9296fd90cae85a018..d6938fd776e2f8c0610e5dea77194a6ba940cccc 100644 --- a/crates/collab/src/db/queries/extensions.rs +++ b/crates/collab/src/db/queries/extensions.rs @@ -1,5 +1,8 @@ +use std::str::FromStr; + use chrono::Utc; use sea_orm::sea_query::IntoCondition; +use util::ResultExt; use super::*; @@ -32,23 +35,83 @@ impl Database { pub async fn get_extensions_by_ids( &self, ids: &[&str], - max_schema_version: i32, + constraints: Option<&ExtensionVersionConstraints>, ) -> Result> { self.transaction(|tx| async move { - let condition = Condition::all() - .add( - extension::Column::LatestVersion - .into_expr() - .eq(extension_version::Column::Version.into_expr()), - ) - .add(extension::Column::ExternalId.is_in(ids.iter().copied())) - .add(extension_version::Column::SchemaVersion.lte(max_schema_version)); + let extensions = extension::Entity::find() + .filter(extension::Column::ExternalId.is_in(ids.iter().copied())) + .all(&*tx) + .await?; - self.get_extensions_where(condition, None, &tx).await + let mut max_versions = self + .get_latest_versions_for_extensions(&extensions, constraints, &tx) + .await?; + + Ok(extensions + .into_iter() + .filter_map(|extension| { + let (version, _) = max_versions.remove(&extension.id)?; + Some(metadata_from_extension_and_version(extension, version)) + }) + .collect()) }) .await } + async fn get_latest_versions_for_extensions( + &self, + extensions: &[extension::Model], + constraints: Option<&ExtensionVersionConstraints>, + tx: &DatabaseTransaction, + ) -> Result> { + let mut versions = extension_version::Entity::find() + .filter( + extension_version::Column::ExtensionId + .is_in(extensions.iter().map(|extension| extension.id)), + ) + .stream(tx) + .await?; + + let mut max_versions = + HashMap::::default(); + while let Some(version) = versions.next().await { + let version = version?; + let Some(extension_version) = SemanticVersion::from_str(&version.version).log_err() + else { + continue; + }; + + if let Some((_, max_extension_version)) = &max_versions.get(&version.extension_id) { + if max_extension_version > &extension_version { + continue; + } + } + + if let Some(constraints) = constraints { + if !constraints + .schema_versions + .contains(&version.schema_version) + { + continue; + } + + if let Some(wasm_api_version) = version.wasm_api_version.as_ref() { + if let Some(version) = SemanticVersion::from_str(wasm_api_version).log_err() { + if !constraints.wasm_api_versions.contains(&version) { + continue; + } + } else { + continue; + } + } + } + + max_versions.insert(version.extension_id, (version, extension_version)); + } + + Ok(max_versions) + } + /// Returns all of the versions for the extension with the given ID. pub async fn get_extension_versions( &self, @@ -88,22 +151,26 @@ impl Database { .collect()) } - pub async fn get_extension(&self, extension_id: &str) -> Result> { + pub async fn get_extension( + &self, + extension_id: &str, + constraints: Option<&ExtensionVersionConstraints>, + ) -> Result> { self.transaction(|tx| async move { let extension = extension::Entity::find() .filter(extension::Column::ExternalId.eq(extension_id)) - .filter( - extension::Column::LatestVersion - .into_expr() - .eq(extension_version::Column::Version.into_expr()), - ) - .inner_join(extension_version::Entity) - .select_also(extension_version::Entity) .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such extension: {extension_id}"))?; + + let extensions = [extension]; + let mut versions = self + .get_latest_versions_for_extensions(&extensions, constraints, &tx) .await?; + let [extension] = extensions; - Ok(extension.and_then(|(extension, version)| { - Some(metadata_from_extension_and_version(extension, version?)) + Ok(versions.remove(&extension.id).map(|(max_version, _)| { + metadata_from_extension_and_version(extension, max_version) })) }) .await diff --git a/crates/collab/src/db/tests/extension_tests.rs b/crates/collab/src/db/tests/extension_tests.rs index 49e94e24d5d43517b77eaace5f95d1123e5a6a6f..b91570c49405c16ef9308bd63f352b8177403327 100644 --- a/crates/collab/src/db/tests/extension_tests.rs +++ b/crates/collab/src/db/tests/extension_tests.rs @@ -1,4 +1,5 @@ use super::Database; +use crate::db::ExtensionVersionConstraints; use crate::{ db::{queries::extensions::convert_time_to_chrono, ExtensionMetadata, NewExtensionVersion}, test_both_dbs, @@ -278,3 +279,108 @@ async fn test_extensions(db: &Arc) { ] ); } + +test_both_dbs!( + test_extensions_by_id, + test_extensions_by_id_postgres, + test_extensions_by_id_sqlite +); + +async fn test_extensions_by_id(db: &Arc) { + let versions = db.get_known_extension_versions().await.unwrap(); + assert!(versions.is_empty()); + + let extensions = db.get_extensions(None, 1, 5).await.unwrap(); + assert!(extensions.is_empty()); + + let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap(); + let t0 = time::PrimitiveDateTime::new(t0.date(), t0.time()); + + let t0_chrono = convert_time_to_chrono(t0); + + db.insert_extension_versions( + &[ + ( + "ext1", + vec![ + NewExtensionVersion { + name: "Extension 1".into(), + version: semver::Version::parse("0.0.1").unwrap(), + description: "an extension".into(), + authors: vec!["max".into()], + repository: "ext1/repo".into(), + schema_version: 1, + wasm_api_version: Some("0.0.4".into()), + published_at: t0, + }, + NewExtensionVersion { + name: "Extension 1".into(), + version: semver::Version::parse("0.0.2").unwrap(), + description: "a good extension".into(), + authors: vec!["max".into()], + repository: "ext1/repo".into(), + schema_version: 1, + wasm_api_version: Some("0.0.4".into()), + published_at: t0, + }, + NewExtensionVersion { + name: "Extension 1".into(), + version: semver::Version::parse("0.0.3").unwrap(), + description: "a real good extension".into(), + authors: vec!["max".into(), "marshall".into()], + repository: "ext1/repo".into(), + schema_version: 1, + wasm_api_version: Some("0.0.5".into()), + published_at: t0, + }, + ], + ), + ( + "ext2", + vec![NewExtensionVersion { + name: "Extension 2".into(), + version: semver::Version::parse("0.2.0").unwrap(), + description: "a great extension".into(), + authors: vec!["marshall".into()], + repository: "ext2/repo".into(), + schema_version: 0, + wasm_api_version: None, + published_at: t0, + }], + ), + ] + .into_iter() + .collect(), + ) + .await + .unwrap(); + + let extensions = db + .get_extensions_by_ids( + &["ext1"], + Some(&ExtensionVersionConstraints { + schema_versions: 1..=1, + wasm_api_versions: "0.0.1".parse().unwrap()..="0.0.4".parse().unwrap(), + }), + ) + .await + .unwrap(); + + assert_eq!( + extensions, + &[ExtensionMetadata { + id: "ext1".into(), + manifest: rpc::ExtensionApiManifest { + name: "Extension 1".into(), + version: "0.0.2".into(), + authors: vec!["max".into()], + description: Some("a good extension".into()), + repository: "ext1/repo".into(), + schema_version: Some(1), + wasm_api_version: Some("0.0.4".into()), + }, + published_at: t0_chrono, + download_count: 0, + }] + ); +} diff --git a/crates/extension/Cargo.toml b/crates/extension/Cargo.toml index e9b02c72dd29a3721d2fa6506faf793f8d1e4117..df02174e1e410301df2efdd9a7ef3299d96036af 100644 --- a/crates/extension/Cargo.toml +++ b/crates/extension/Cargo.toml @@ -12,10 +12,6 @@ workspace = true path = "src/extension_store.rs" doctest = false -[[bin]] -name = "extension_json_schemas" -path = "src/extension_json_schemas.rs" - [dependencies] anyhow.workspace = true async-compression.workspace = true diff --git a/crates/extension/src/extension_json_schemas.rs b/crates/extension/src/extension_json_schemas.rs deleted file mode 100644 index b46e72fce65bef2485354107166e8b9466d333d4..0000000000000000000000000000000000000000 --- a/crates/extension/src/extension_json_schemas.rs +++ /dev/null @@ -1,17 +0,0 @@ -use language::LanguageConfig; -use schemars::schema_for; -use theme::ThemeFamilyContent; - -fn main() { - let theme_family_schema = schema_for!(ThemeFamilyContent); - let language_config_schema = schema_for!(LanguageConfig); - - println!( - "{}", - serde_json::to_string_pretty(&theme_family_schema).unwrap() - ); - println!( - "{}", - serde_json::to_string_pretty(&language_config_schema).unwrap() - ); -} diff --git a/crates/extension/src/extension_store.rs b/crates/extension/src/extension_store.rs index 904d3cd25d2d9f82c1cfdd8b8b997cb7a77f0544..72001a0f73eaef040dcb6ace3ccc2a4af8eeb918 100644 --- a/crates/extension/src/extension_store.rs +++ b/crates/extension/src/extension_store.rs @@ -36,6 +36,7 @@ use node_runtime::NodeRuntime; use semantic_version::SemanticVersion; use serde::{Deserialize, Serialize}; use settings::Settings; +use std::ops::RangeInclusive; use std::str::FromStr; use std::{ cmp::Ordering, @@ -51,7 +52,10 @@ use util::{ paths::EXTENSIONS_DIR, ResultExt, }; -use wasm_host::{wit::is_supported_wasm_api_version, WasmExtension, WasmHost}; +use wasm_host::{ + wit::{is_supported_wasm_api_version, wasm_api_version_range}, + WasmExtension, WasmHost, +}; pub use extension_manifest::{ ExtensionLibraryKind, ExtensionManifest, GrammarManifestEntry, OldExtensionManifest, @@ -64,6 +68,11 @@ const FS_WATCH_LATENCY: Duration = Duration::from_millis(100); /// The current extension [`SchemaVersion`] supported by Zed. const CURRENT_SCHEMA_VERSION: SchemaVersion = SchemaVersion(1); +/// Returns the [`SchemaVersion`] range that is compatible with this version of Zed. +pub fn schema_version_range() -> RangeInclusive { + SchemaVersion::ZERO..=CURRENT_SCHEMA_VERSION +} + /// Returns whether the given extension version is compatible with this version of Zed. pub fn is_version_compatible(extension_version: &ExtensionMetadata) -> bool { let schema_version = extension_version.manifest.schema_version.unwrap_or(0); @@ -412,15 +421,15 @@ impl ExtensionStore { query.push(("filter", search)); } - self.fetch_extensions_from_api("/extensions", query, cx) + self.fetch_extensions_from_api("/extensions", &query, cx) } pub fn fetch_extensions_with_update_available( &mut self, cx: &mut ModelContext, ) -> Task>> { - let version = CURRENT_SCHEMA_VERSION.to_string(); - let mut query = vec![("max_schema_version", version.as_str())]; + let schema_versions = schema_version_range(); + let wasm_api_versions = wasm_api_version_range(); let extension_settings = ExtensionSettings::get_global(cx); let extension_ids = self .extension_index @@ -430,9 +439,20 @@ impl ExtensionStore { .filter(|id| extension_settings.should_auto_update(id)) .collect::>() .join(","); - query.push(("ids", &extension_ids)); - - let task = self.fetch_extensions_from_api("/extensions", query, cx); + let task = self.fetch_extensions_from_api( + "/extensions/updates", + &[ + ("min_schema_version", &schema_versions.start().to_string()), + ("max_schema_version", &schema_versions.end().to_string()), + ( + "min_wasm_api_version", + &wasm_api_versions.start().to_string(), + ), + ("max_wasm_api_version", &wasm_api_versions.end().to_string()), + ("ids", &extension_ids), + ], + cx, + ); cx.spawn(move |this, mut cx| async move { let extensions = task.await?; this.update(&mut cx, |this, _cx| { @@ -456,7 +476,7 @@ impl ExtensionStore { extension_id: &str, cx: &mut ModelContext, ) -> Task>> { - self.fetch_extensions_from_api(&format!("/extensions/{extension_id}"), Vec::new(), cx) + self.fetch_extensions_from_api(&format!("/extensions/{extension_id}"), &[], cx) } pub fn check_for_updates(&mut self, cx: &mut ModelContext) { @@ -500,7 +520,7 @@ impl ExtensionStore { fn fetch_extensions_from_api( &self, path: &str, - query: Vec<(&str, &str)>, + query: &[(&str, &str)], cx: &mut ModelContext<'_, ExtensionStore>, ) -> Task>> { let url = self.http_client.build_zed_api_url(path, &query); @@ -614,9 +634,23 @@ impl ExtensionStore { ) { log::info!("installing extension {extension_id} latest version"); + let schema_versions = schema_version_range(); + let wasm_api_versions = wasm_api_version_range(); + let Some(url) = self .http_client - .build_zed_api_url(&format!("/extensions/{extension_id}/download"), &[]) + .build_zed_api_url( + &format!("/extensions/{extension_id}/download"), + &[ + ("min_schema_version", &schema_versions.start().to_string()), + ("max_schema_version", &schema_versions.end().to_string()), + ( + "min_wasm_api_version", + &wasm_api_versions.start().to_string(), + ), + ("max_wasm_api_version", &wasm_api_versions.end().to_string()), + ], + ) .log_err() else { return; diff --git a/crates/extension/src/wasm_host/wit.rs b/crates/extension/src/wasm_host/wit.rs index a4790deec11a43c0545da954733ccc5de4daae7c..da14f0466423cf7e5d02ace11f5141c9e992af1a 100644 --- a/crates/extension/src/wasm_host/wit.rs +++ b/crates/extension/src/wasm_host/wit.rs @@ -5,6 +5,7 @@ use super::{wasm_engine, WasmState}; use anyhow::{Context, Result}; use language::LspAdapterDelegate; use semantic_version::SemanticVersion; +use std::ops::RangeInclusive; use std::sync::Arc; use wasmtime::{ component::{Component, Instance, Linker, Resource}, @@ -30,7 +31,13 @@ fn wasi_view(state: &mut WasmState) -> &mut WasmState { /// Returns whether the given Wasm API version is supported by the Wasm host. pub fn is_supported_wasm_api_version(version: SemanticVersion) -> bool { - since_v0_0_1::MIN_VERSION <= version && version <= latest::MAX_VERSION + wasm_api_version_range().contains(&version) +} + +/// Returns the Wasm API version range that is supported by the Wasm host. +#[inline(always)] +pub fn wasm_api_version_range() -> RangeInclusive { + since_v0_0_1::MIN_VERSION..=latest::MAX_VERSION } pub enum Extension { diff --git a/crates/extensions_ui/src/extensions_ui.rs b/crates/extensions_ui/src/extensions_ui.rs index 1649230f5f28fbacd298cf6c771cb94a567bce34..5a1278b0f783d5f71c554311c4cdaac3c7147efa 100644 --- a/crates/extensions_ui/src/extensions_ui.rs +++ b/crates/extensions_ui/src/extensions_ui.rs @@ -587,12 +587,11 @@ impl ExtensionsPage { .disabled(disabled) .on_click(cx.listener({ let extension_id = extension.id.clone(); - let version = extension.manifest.version.clone(); move |this, _, cx| { this.telemetry .report_app_event("extensions: install extension".to_string()); ExtensionStore::global(cx).update(cx, |store, cx| { - store.install_extension(extension_id.clone(), version.clone(), cx) + store.install_latest_extension(extension_id.clone(), cx) }); } })),