extension_host: Refactor capability checks (#35139)

Marshall Bowers created

This PR refactors the extension capability checks to be centralized in
the `CapabilityGranter`.

Release Notes:

- N/A

Change summary

crates/extension/src/extension_manifest.rs                       |  88 
crates/extension_host/benches/extension_compilation_benchmark.rs |  10 
crates/extension_host/src/capability_granter.rs                  | 115 ++
crates/extension_host/src/extension_host.rs                      |   1 
crates/extension_host/src/wasm_host.rs                           |  17 
crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs          |   3 
6 files changed, 193 insertions(+), 41 deletions(-)

Detailed changes

crates/extension/src/extension_manifest.rs 🔗

@@ -100,26 +100,9 @@ impl ExtensionManifest {
         desired_args: &[impl AsRef<str> + std::fmt::Debug],
     ) -> Result<()> {
         let is_allowed = self.capabilities.iter().any(|capability| match capability {
-            ExtensionCapability::ProcessExec { command, args } if command == desired_command => {
-                for (ix, arg) in args.iter().enumerate() {
-                    if arg == "**" {
-                        return true;
-                    }
-
-                    if ix >= desired_args.len() {
-                        return false;
-                    }
-
-                    if arg != "*" && arg != desired_args[ix].as_ref() {
-                        return false;
-                    }
-                }
-                if args.len() < desired_args.len() {
-                    return false;
-                }
-                true
+            ExtensionCapability::ProcessExec(capability) => {
+                capability.allows(desired_command, desired_args)
             }
-            _ => false,
         });
 
         if !is_allowed {
@@ -153,13 +136,50 @@ pub fn build_debug_adapter_schema_path(
 #[serde(tag = "kind")]
 pub enum ExtensionCapability {
     #[serde(rename = "process:exec")]
-    ProcessExec {
-        /// The command to execute.
-        command: String,
-        /// The arguments to pass to the command. Use `*` for a single wildcard argument.
-        /// If the last element is `**`, then any trailing arguments are allowed.
-        args: Vec<String>,
-    },
+    ProcessExec(ProcessExecCapability),
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct ProcessExecCapability {
+    /// The command to execute.
+    pub command: String,
+    /// The arguments to pass to the command. Use `*` for a single wildcard argument.
+    /// If the last element is `**`, then any trailing arguments are allowed.
+    pub args: Vec<String>,
+}
+
+impl ProcessExecCapability {
+    /// Returns whether the capability allows the given command and arguments.
+    pub fn allows(
+        &self,
+        desired_command: &str,
+        desired_args: &[impl AsRef<str> + std::fmt::Debug],
+    ) -> bool {
+        if self.command != desired_command && self.command != "*" {
+            return false;
+        }
+
+        for (ix, arg) in self.args.iter().enumerate() {
+            if arg == "**" {
+                return true;
+            }
+
+            if ix >= desired_args.len() {
+                return false;
+            }
+
+            if arg != "*" && arg != desired_args[ix].as_ref() {
+                return false;
+            }
+        }
+
+        if self.args.len() < desired_args.len() {
+            return false;
+        }
+
+        true
+    }
 }
 
 #[derive(Clone, Default, PartialEq, Eq, Debug, Deserialize, Serialize)]
@@ -362,10 +382,10 @@ mod tests {
     #[test]
     fn test_allow_exact_match() {
         let manifest = ExtensionManifest {
-            capabilities: vec![ExtensionCapability::ProcessExec {
+            capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability {
                 command: "ls".to_string(),
                 args: vec!["-la".to_string()],
-            }],
+            })],
             ..extension_manifest()
         };
 
@@ -377,10 +397,10 @@ mod tests {
     #[test]
     fn test_allow_wildcard_arg() {
         let manifest = ExtensionManifest {
-            capabilities: vec![ExtensionCapability::ProcessExec {
+            capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability {
                 command: "git".to_string(),
                 args: vec!["*".to_string()],
-            }],
+            })],
             ..extension_manifest()
         };
 
@@ -393,10 +413,10 @@ mod tests {
     #[test]
     fn test_allow_double_wildcard() {
         let manifest = ExtensionManifest {
-            capabilities: vec![ExtensionCapability::ProcessExec {
+            capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability {
                 command: "cargo".to_string(),
                 args: vec!["test".to_string(), "**".to_string()],
-            }],
+            })],
             ..extension_manifest()
         };
 
@@ -413,10 +433,10 @@ mod tests {
     #[test]
     fn test_allow_mixed_wildcards() {
         let manifest = ExtensionManifest {
-            capabilities: vec![ExtensionCapability::ProcessExec {
+            capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability {
                 command: "docker".to_string(),
                 args: vec!["run".to_string(), "*".to_string(), "**".to_string()],
-            }],
+            })],
             ..extension_manifest()
         };
 

crates/extension_host/benches/extension_compilation_benchmark.rs 🔗

@@ -134,10 +134,12 @@ fn manifest() -> ExtensionManifest {
         slash_commands: BTreeMap::default(),
         indexed_docs_providers: BTreeMap::default(),
         snippets: None,
-        capabilities: vec![ExtensionCapability::ProcessExec {
-            command: "echo".into(),
-            args: vec!["hello!".into()],
-        }],
+        capabilities: vec![ExtensionCapability::ProcessExec(
+            extension::ProcessExecCapability {
+                command: "echo".into(),
+                args: vec!["hello!".into()],
+            },
+        )],
         debug_adapters: Default::default(),
         debug_locators: Default::default(),
     }

crates/extension_host/src/capability_granter.rs 🔗

@@ -0,0 +1,115 @@
+use std::sync::Arc;
+
+use anyhow::{Result, bail};
+use extension::{ExtensionCapability, ExtensionManifest};
+
+pub struct CapabilityGranter {
+    granted_capabilities: Vec<ExtensionCapability>,
+    manifest: Arc<ExtensionManifest>,
+}
+
+impl CapabilityGranter {
+    pub fn new(
+        granted_capabilities: Vec<ExtensionCapability>,
+        manifest: Arc<ExtensionManifest>,
+    ) -> Self {
+        Self {
+            granted_capabilities,
+            manifest,
+        }
+    }
+
+    pub fn grant_exec(
+        &self,
+        desired_command: &str,
+        desired_args: &[impl AsRef<str> + std::fmt::Debug],
+    ) -> Result<()> {
+        self.manifest.allow_exec(desired_command, desired_args)?;
+
+        let is_allowed = self
+            .granted_capabilities
+            .iter()
+            .any(|capability| match capability {
+                ExtensionCapability::ProcessExec(capability) => {
+                    capability.allows(desired_command, desired_args)
+                }
+            });
+
+        if !is_allowed {
+            bail!(
+                "capability for process:exec {desired_command} {desired_args:?} is not granted by the extension host",
+            );
+        }
+
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::collections::BTreeMap;
+
+    use extension::{ProcessExecCapability, SchemaVersion};
+
+    use super::*;
+
+    fn extension_manifest() -> ExtensionManifest {
+        ExtensionManifest {
+            id: "test".into(),
+            name: "Test".to_string(),
+            version: "1.0.0".into(),
+            schema_version: SchemaVersion::ZERO,
+            description: None,
+            repository: None,
+            authors: vec![],
+            lib: Default::default(),
+            themes: vec![],
+            icon_themes: vec![],
+            languages: vec![],
+            grammars: BTreeMap::default(),
+            language_servers: BTreeMap::default(),
+            context_servers: BTreeMap::default(),
+            slash_commands: BTreeMap::default(),
+            indexed_docs_providers: BTreeMap::default(),
+            snippets: None,
+            capabilities: vec![],
+            debug_adapters: Default::default(),
+            debug_locators: Default::default(),
+        }
+    }
+
+    #[test]
+    fn test_grant_exec() {
+        let manifest = Arc::new(ExtensionManifest {
+            capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability {
+                command: "ls".to_string(),
+                args: vec!["-la".to_string()],
+            })],
+            ..extension_manifest()
+        });
+
+        // It returns an error when the extension host has no granted capabilities.
+        let granter = CapabilityGranter::new(Vec::new(), manifest.clone());
+        assert!(granter.grant_exec("ls", &["-la"]).is_err());
+
+        // It succeeds when the extension host has the exact capability.
+        let granter = CapabilityGranter::new(
+            vec![ExtensionCapability::ProcessExec(ProcessExecCapability {
+                command: "ls".to_string(),
+                args: vec!["-la".to_string()],
+            })],
+            manifest.clone(),
+        );
+        assert!(granter.grant_exec("ls", &["-la"]).is_ok());
+
+        // It succeeds when the extension host has a wildcard capability.
+        let granter = CapabilityGranter::new(
+            vec![ExtensionCapability::ProcessExec(ProcessExecCapability {
+                command: "*".to_string(),
+                args: vec!["**".to_string()],
+            })],
+            manifest.clone(),
+        );
+        assert!(granter.grant_exec("ls", &["-la"]).is_ok());
+    }
+}

crates/extension_host/src/wasm_host.rs 🔗

@@ -1,13 +1,15 @@
 pub mod wit;
 
 use crate::ExtensionManifest;
+use crate::capability_granter::CapabilityGranter;
 use anyhow::{Context as _, Result, anyhow, bail};
 use async_trait::async_trait;
 use dap::{DebugRequest, StartDebuggingRequestArgumentsRequest};
 use extension::{
     CodeLabel, Command, Completion, ContextServerConfiguration, DebugAdapterBinary,
-    DebugTaskDefinition, ExtensionHostProxy, KeyValueStoreDelegate, ProjectDelegate, SlashCommand,
-    SlashCommandArgumentCompletion, SlashCommandOutput, Symbol, WorktreeDelegate,
+    DebugTaskDefinition, ExtensionCapability, ExtensionHostProxy, KeyValueStoreDelegate,
+    ProcessExecCapability, ProjectDelegate, SlashCommand, SlashCommandArgumentCompletion,
+    SlashCommandOutput, Symbol, WorktreeDelegate,
 };
 use fs::{Fs, normalize_path};
 use futures::future::LocalBoxFuture;
@@ -50,6 +52,8 @@ pub struct WasmHost {
     pub(crate) proxy: Arc<ExtensionHostProxy>,
     fs: Arc<dyn Fs>,
     pub work_dir: PathBuf,
+    /// The capabilities granted to extensions running on the host.
+    pub(crate) granted_capabilities: Vec<ExtensionCapability>,
     _main_thread_message_task: Task<()>,
     main_thread_message_tx: mpsc::UnboundedSender<MainThreadCall>,
 }
@@ -486,6 +490,7 @@ pub struct WasmState {
     pub table: ResourceTable,
     ctx: wasi::WasiCtx,
     pub host: Arc<WasmHost>,
+    pub(crate) capability_granter: CapabilityGranter,
 }
 
 type MainThreadCall = Box<dyn Send + for<'a> FnOnce(&'a mut AsyncApp) -> LocalBoxFuture<'a, ()>>;
@@ -571,6 +576,10 @@ impl WasmHost {
             node_runtime,
             proxy,
             release_channel: ReleaseChannel::global(cx),
+            granted_capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability {
+                command: "*".to_string(),
+                args: vec!["**".to_string()],
+            })],
             _main_thread_message_task: task,
             main_thread_message_tx: tx,
         })
@@ -597,6 +606,10 @@ impl WasmHost {
                     manifest: manifest.clone(),
                     table: ResourceTable::new(),
                     host: this.clone(),
+                    capability_granter: CapabilityGranter::new(
+                        this.granted_capabilities.clone(),
+                        manifest.clone(),
+                    ),
                 },
             );
             // Store will yield after 1 tick, and get a new deadline of 1 tick after each yield.

crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs 🔗

@@ -847,7 +847,8 @@ impl process::Host for WasmState {
         command: process::Command,
     ) -> wasmtime::Result<Result<process::Output, String>> {
         maybe!(async {
-            self.manifest.allow_exec(&command.command, &command.args)?;
+            self.capability_granter
+                .grant_exec(&command.command, &command.args)?;
 
             let output = util::command::new_smol_command(command.command.as_str())
                 .args(&command.args)