Extract `SlashCommand` trait from `assistant` (#12252)

Marshall Bowers created

This PR extracts the `SlashCommand` trait (along with the
`SlashCommandRegistry`) from the `assistant` crate.

This will allow us to register slash commands from extensions without
having to make `extension` depend on `assistant`.

Release Notes:

- N/A

Change summary

Cargo.lock                                                    | 13 
Cargo.toml                                                    |  2 
crates/assistant/Cargo.toml                                   |  1 
crates/assistant/src/assistant.rs                             |  1 
crates/assistant/src/assistant_panel.rs                       | 28 +
crates/assistant/src/slash_command.rs                         | 94 ----
crates/assistant_slash_command/Cargo.toml                     | 20 +
crates/assistant_slash_command/LICENSE-GPL                    |  1 
crates/assistant_slash_command/src/assistant_slash_command.rs | 50 ++
crates/assistant_slash_command/src/slash_command_registry.rs  | 64 +++
10 files changed, 183 insertions(+), 91 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -337,6 +337,7 @@ version = "0.1.0"
 dependencies = [
  "anthropic",
  "anyhow",
+ "assistant_slash_command",
  "cargo_toml",
  "chrono",
  "client",
@@ -426,6 +427,18 @@ dependencies = [
  "workspace",
 ]
 
+[[package]]
+name = "assistant_slash_command"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "collections",
+ "derive_more",
+ "futures 0.3.28",
+ "gpui",
+ "parking_lot",
+]
+
 [[package]]
 name = "assistant_tooling"
 version = "0.1.0"

Cargo.toml 🔗

@@ -5,6 +5,7 @@ members = [
     "crates/assets",
     "crates/assistant",
     "crates/assistant2",
+    "crates/assistant_slash_command",
     "crates/assistant_tooling",
     "crates/audio",
     "crates/auto_update",
@@ -148,6 +149,7 @@ anthropic = { path = "crates/anthropic" }
 assets = { path = "crates/assets" }
 assistant = { path = "crates/assistant" }
 assistant2 = { path = "crates/assistant2" }
+assistant_slash_command = { path = "crates/assistant_slash_command" }
 assistant_tooling = { path = "crates/assistant_tooling" }
 audio = { path = "crates/audio" }
 auto_update = { path = "crates/auto_update" }

crates/assistant/Cargo.toml 🔗

@@ -12,6 +12,7 @@ doctest = false
 [dependencies]
 anyhow.workspace = true
 anthropic = { workspace = true, features = ["schemars"] }
+assistant_slash_command.workspace = true
 cargo_toml.workspace = true
 chrono.workspace = true
 client.workspace = true

crates/assistant/src/assistant.rs 🔗

@@ -17,7 +17,6 @@ use client::{proto, Client};
 use command_palette_hooks::CommandPaletteFilter;
 pub(crate) use completion_provider::*;
 use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
-pub(crate) use prompts::prompt_library::*;
 pub(crate) use saved_conversation::*;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};

crates/assistant/src/assistant_panel.rs 🔗

@@ -9,7 +9,8 @@ use crate::{
     prompts::prompt::generate_content_prompt,
     search::*,
     slash_command::{
-        SlashCommandCleanup, SlashCommandCompletionProvider, SlashCommandLine, SlashCommandRegistry,
+        current_file_command, file_command, prompt_command, SlashCommandCleanup,
+        SlashCommandCompletionProvider, SlashCommandLine, SlashCommandRegistry,
     },
     ApplyEdit, Assist, CompletionProvider, CycleMessageRole, InlineAssist, LanguageModel,
     LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata, MessageStatus,
@@ -204,11 +205,21 @@ impl AssistantPanel {
                     })
                     .detach();
 
-                    let slash_command_registry = SlashCommandRegistry::new(
+                    let slash_command_registry = SlashCommandRegistry::new();
+
+                    let window = cx.window_handle().downcast::<Workspace>();
+
+                    slash_command_registry.register_command(file_command::FileSlashCommand::new(
                         workspace.project().clone(),
-                        prompt_library.clone(),
-                        cx.window_handle().downcast::<Workspace>(),
+                    ));
+                    slash_command_registry.register_command(
+                        prompt_command::PromptSlashCommand::new(prompt_library.clone()),
                     );
+                    if let Some(window) = window {
+                        slash_command_registry.register_command(
+                            current_file_command::CurrentFileSlashCommand::new(window),
+                        );
+                    }
 
                     Self {
                         workspace: workspace_handle,
@@ -4273,8 +4284,13 @@ mod tests {
 
         let project = Project::test(fs.clone(), ["/test".as_ref()], cx).await;
         let prompt_library = Arc::new(PromptLibrary::default());
-        let slash_command_registry =
-            SlashCommandRegistry::new(project.clone(), prompt_library, None);
+        let slash_command_registry = SlashCommandRegistry::new();
+
+        slash_command_registry
+            .register_command(file_command::FileSlashCommand::new(project.clone()));
+        slash_command_registry.register_command(prompt_command::PromptSlashCommand::new(
+            prompt_library.clone(),
+        ));
 
         let registry = Arc::new(LanguageRegistry::test(cx.executor()));
         let conversation = cx.new_model(|cx| {

crates/assistant/src/slash_command.rs 🔗

@@ -1,12 +1,9 @@
 use anyhow::Result;
-use collections::HashMap;
 use editor::{CompletionProvider, Editor};
-use futures::channel::oneshot;
 use fuzzy::{match_strings, StringMatchCandidate};
-use gpui::{AppContext, Model, Task, ViewContext, WindowHandle};
+use gpui::{AppContext, Model, Task, ViewContext};
 use language::{Anchor, Buffer, CodeLabel, Documentation, LanguageServerId, ToPoint};
 use parking_lot::{Mutex, RwLock};
-use project::Project;
 use rope::Point;
 use std::{
     ops::Range,
@@ -15,60 +12,20 @@ use std::{
         Arc,
     },
 };
-use workspace::Workspace;
 
-use crate::PromptLibrary;
+pub use assistant_slash_command::{
+    SlashCommand, SlashCommandCleanup, SlashCommandInvocation, SlashCommandRegistry,
+};
 
-mod current_file_command;
-mod file_command;
-mod prompt_command;
+pub mod current_file_command;
+pub mod file_command;
+pub mod prompt_command;
 
 pub(crate) struct SlashCommandCompletionProvider {
     commands: Arc<SlashCommandRegistry>,
     cancel_flag: Mutex<Arc<AtomicBool>>,
 }
 
-#[derive(Default)]
-pub(crate) struct SlashCommandRegistry {
-    commands: HashMap<String, Box<dyn SlashCommand>>,
-}
-
-pub(crate) trait SlashCommand: 'static + Send + Sync {
-    fn name(&self) -> String;
-    fn description(&self) -> String;
-    fn complete_argument(
-        &self,
-        query: String,
-        cancel: Arc<AtomicBool>,
-        cx: &mut AppContext,
-    ) -> Task<Result<Vec<String>>>;
-    fn requires_argument(&self) -> bool;
-    fn run(&self, argument: Option<&str>, cx: &mut AppContext) -> SlashCommandInvocation;
-}
-
-pub(crate) struct SlashCommandInvocation {
-    pub output: Task<Result<String>>,
-    pub invalidated: oneshot::Receiver<()>,
-    pub cleanup: SlashCommandCleanup,
-}
-
-#[derive(Default)]
-pub(crate) struct SlashCommandCleanup(Option<Box<dyn FnOnce()>>);
-
-impl SlashCommandCleanup {
-    pub fn new(cleanup: impl FnOnce() + 'static) -> Self {
-        Self(Some(Box::new(cleanup)))
-    }
-}
-
-impl Drop for SlashCommandCleanup {
-    fn drop(&mut self) {
-        if let Some(cleanup) = self.0.take() {
-            cleanup();
-        }
-    }
-}
-
 pub(crate) struct SlashCommandLine {
     /// The range within the line containing the command name.
     pub name: Range<usize>,
@@ -76,38 +33,6 @@ pub(crate) struct SlashCommandLine {
     pub argument: Option<Range<usize>>,
 }
 
-impl SlashCommandRegistry {
-    pub fn new(
-        project: Model<Project>,
-        prompt_library: Arc<PromptLibrary>,
-        window: Option<WindowHandle<Workspace>>,
-    ) -> Arc<Self> {
-        let mut this = Self {
-            commands: HashMap::default(),
-        };
-
-        this.register_command(file_command::FileSlashCommand::new(project));
-        this.register_command(prompt_command::PromptSlashCommand::new(prompt_library));
-        if let Some(window) = window {
-            this.register_command(current_file_command::CurrentFileSlashCommand::new(window));
-        }
-
-        Arc::new(this)
-    }
-
-    fn register_command(&mut self, command: impl SlashCommand) {
-        self.commands.insert(command.name(), Box::new(command));
-    }
-
-    fn command_names(&self) -> impl Iterator<Item = &String> {
-        self.commands.keys()
-    }
-
-    pub(crate) fn command(&self, name: &str) -> Option<&dyn SlashCommand> {
-        self.commands.get(name).map(|b| &**b)
-    }
-}
-
 impl SlashCommandCompletionProvider {
     pub fn new(commands: Arc<SlashCommandRegistry>) -> Self {
         Self {
@@ -125,11 +50,12 @@ impl SlashCommandCompletionProvider {
         let candidates = self
             .commands
             .command_names()
+            .into_iter()
             .enumerate()
             .map(|(ix, def)| StringMatchCandidate {
                 id: ix,
-                string: def.clone(),
-                char_bag: def.as_str().into(),
+                string: def.to_string(),
+                char_bag: def.as_ref().into(),
             })
             .collect::<Vec<_>>();
         let commands = self.commands.clone();

crates/assistant_slash_command/Cargo.toml 🔗

@@ -0,0 +1,20 @@
+[package]
+name = "assistant_slash_command"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/assistant_slash_command.rs"
+
+[dependencies]
+anyhow.workspace = true
+collections.workspace = true
+derive_more.workspace = true
+futures.workspace = true
+gpui.workspace = true
+parking_lot.workspace = true

crates/assistant_slash_command/src/assistant_slash_command.rs 🔗

@@ -0,0 +1,50 @@
+mod slash_command_registry;
+
+use std::sync::atomic::AtomicBool;
+use std::sync::Arc;
+
+use anyhow::Result;
+use futures::channel::oneshot;
+use gpui::{AppContext, Task};
+
+pub use slash_command_registry::*;
+
+pub fn init(cx: &mut AppContext) {
+    SlashCommandRegistry::default_global(cx);
+}
+
+pub trait SlashCommand: 'static + Send + Sync {
+    fn name(&self) -> String;
+    fn description(&self) -> String;
+    fn complete_argument(
+        &self,
+        query: String,
+        cancel: Arc<AtomicBool>,
+        cx: &mut AppContext,
+    ) -> Task<Result<Vec<String>>>;
+    fn requires_argument(&self) -> bool;
+    fn run(&self, argument: Option<&str>, cx: &mut AppContext) -> SlashCommandInvocation;
+}
+
+pub struct SlashCommandInvocation {
+    pub output: Task<Result<String>>,
+    pub invalidated: oneshot::Receiver<()>,
+    pub cleanup: SlashCommandCleanup,
+}
+
+#[derive(Default)]
+pub struct SlashCommandCleanup(Option<Box<dyn FnOnce()>>);
+
+impl SlashCommandCleanup {
+    pub fn new(cleanup: impl FnOnce() + 'static) -> Self {
+        Self(Some(Box::new(cleanup)))
+    }
+}
+
+impl Drop for SlashCommandCleanup {
+    fn drop(&mut self) {
+        if let Some(cleanup) = self.0.take() {
+            cleanup();
+        }
+    }
+}

crates/assistant_slash_command/src/slash_command_registry.rs 🔗

@@ -0,0 +1,64 @@
+use std::sync::Arc;
+
+use collections::HashMap;
+use derive_more::{Deref, DerefMut};
+use gpui::Global;
+use gpui::{AppContext, ReadGlobal};
+use parking_lot::RwLock;
+
+use crate::SlashCommand;
+
+#[derive(Default, Deref, DerefMut)]
+struct GlobalSlashCommandRegistry(Arc<SlashCommandRegistry>);
+
+impl Global for GlobalSlashCommandRegistry {}
+
+#[derive(Default)]
+struct SlashCommandRegistryState {
+    commands: HashMap<Arc<str>, Arc<dyn SlashCommand>>,
+}
+
+#[derive(Default)]
+pub struct SlashCommandRegistry {
+    state: RwLock<SlashCommandRegistryState>,
+}
+
+impl SlashCommandRegistry {
+    /// Returns the global [`SlashCommandRegistry`].
+    pub fn global(cx: &AppContext) -> Arc<Self> {
+        GlobalSlashCommandRegistry::global(cx).0.clone()
+    }
+
+    /// Returns the global [`SlashCommandRegistry`].
+    ///
+    /// Inserts a default [`SlashCommandRegistry`] if one does not yet exist.
+    pub fn default_global(cx: &mut AppContext) -> Arc<Self> {
+        cx.default_global::<GlobalSlashCommandRegistry>().0.clone()
+    }
+
+    pub fn new() -> Arc<Self> {
+        Arc::new(Self {
+            state: RwLock::new(SlashCommandRegistryState {
+                commands: HashMap::default(),
+            }),
+        })
+    }
+
+    /// Registers the provided [`SlashCommand`].
+    pub fn register_command(&self, command: impl SlashCommand) {
+        self.state
+            .write()
+            .commands
+            .insert(command.name().into(), Arc::new(command));
+    }
+
+    /// Returns the names of registered [`SlashCommand`]s.
+    pub fn command_names(&self) -> Vec<Arc<str>> {
+        self.state.read().commands.keys().cloned().collect()
+    }
+
+    /// Returns the [`SlashCommand`] with the given name.
+    pub fn command(&self, name: &str) -> Option<Arc<dyn SlashCommand>> {
+        self.state.read().commands.get(name).cloned()
+    }
+}