assistant2: Add support for using tools provided by context servers (#21418)

Marshall Bowers and Cole created

This PR adds support to Assistant 2 for using tools provided by context
servers.

As part of this I introduced a new `ThreadStore`.

Release Notes:

- N/A

---------

Co-authored-by: Cole <cole@zed.dev>

Change summary

Cargo.lock                               |   3 
crates/assistant2/Cargo.toml             |   3 
crates/assistant2/src/assistant.rs       |   1 
crates/assistant2/src/assistant_panel.rs |  20 ++++
crates/assistant2/src/thread_store.rs    | 114 +++++++++++++++++++++++++
5 files changed, 139 insertions(+), 2 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -458,12 +458,15 @@ dependencies = [
  "assistant_tool",
  "collections",
  "command_palette_hooks",
+ "context_server",
  "editor",
  "feature_flags",
  "futures 0.3.31",
  "gpui",
  "language_model",
  "language_model_selector",
+ "log",
+ "project",
  "proto",
  "serde",
  "serde_json",

crates/assistant2/Cargo.toml 🔗

@@ -17,12 +17,15 @@ anyhow.workspace = true
 assistant_tool.workspace = true
 collections.workspace = true
 command_palette_hooks.workspace = true
+context_server.workspace = true
 editor.workspace = true
 feature_flags.workspace = true
 futures.workspace = true
 gpui.workspace = true
 language_model.workspace = true
 language_model_selector.workspace = true
+log.workspace = true
+project.workspace = true
 proto.workspace = true
 settings.workspace = true
 serde.workspace = true

crates/assistant2/src/assistant.rs 🔗

@@ -1,6 +1,7 @@
 mod assistant_panel;
 mod message_editor;
 mod thread;
+mod thread_store;
 
 use command_palette_hooks::CommandPaletteFilter;
 use feature_flags::{Assistant2FeatureFlag, FeatureFlagAppExt};

crates/assistant2/src/assistant_panel.rs 🔗

@@ -14,6 +14,7 @@ use workspace::Workspace;
 
 use crate::message_editor::MessageEditor;
 use crate::thread::{Message, Thread, ThreadEvent};
+use crate::thread_store::ThreadStore;
 use crate::{NewThread, ToggleFocus, ToggleModelSelector};
 
 pub fn init(cx: &mut AppContext) {
@@ -29,6 +30,8 @@ pub fn init(cx: &mut AppContext) {
 
 pub struct AssistantPanel {
     workspace: WeakView<Workspace>,
+    #[allow(unused)]
+    thread_store: Model<ThreadStore>,
     thread: Model<Thread>,
     message_editor: View<MessageEditor>,
     tools: Arc<ToolWorkingSet>,
@@ -42,13 +45,25 @@ impl AssistantPanel {
     ) -> Task<Result<View<Self>>> {
         cx.spawn(|mut cx| async move {
             let tools = Arc::new(ToolWorkingSet::default());
+            let thread_store = workspace
+                .update(&mut cx, |workspace, cx| {
+                    let project = workspace.project().clone();
+                    ThreadStore::new(project, tools.clone(), cx)
+                })?
+                .await?;
+
             workspace.update(&mut cx, |workspace, cx| {
-                cx.new_view(|cx| Self::new(workspace, tools, cx))
+                cx.new_view(|cx| Self::new(workspace, thread_store, tools, cx))
             })
         })
     }
 
-    fn new(workspace: &Workspace, tools: Arc<ToolWorkingSet>, cx: &mut ViewContext<Self>) -> Self {
+    fn new(
+        workspace: &Workspace,
+        thread_store: Model<ThreadStore>,
+        tools: Arc<ToolWorkingSet>,
+        cx: &mut ViewContext<Self>,
+    ) -> Self {
         let thread = cx.new_model(|cx| Thread::new(tools.clone(), cx));
         let subscriptions = vec![
             cx.observe(&thread, |_, _, cx| cx.notify()),
@@ -57,6 +72,7 @@ impl AssistantPanel {
 
         Self {
             workspace: workspace.weak_handle(),
+            thread_store,
             thread: thread.clone(),
             message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)),
             tools,

crates/assistant2/src/thread_store.rs 🔗

@@ -0,0 +1,114 @@
+use std::sync::Arc;
+
+use anyhow::Result;
+use assistant_tool::{ToolId, ToolWorkingSet};
+use collections::HashMap;
+use context_server::manager::ContextServerManager;
+use context_server::{ContextServerFactoryRegistry, ContextServerTool};
+use gpui::{prelude::*, AppContext, Model, ModelContext, Task};
+use project::Project;
+use util::ResultExt as _;
+
+pub struct ThreadStore {
+    #[allow(unused)]
+    project: Model<Project>,
+    tools: Arc<ToolWorkingSet>,
+    context_server_manager: Model<ContextServerManager>,
+    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
+}
+
+impl ThreadStore {
+    pub fn new(
+        project: Model<Project>,
+        tools: Arc<ToolWorkingSet>,
+        cx: &mut AppContext,
+    ) -> Task<Result<Model<Self>>> {
+        cx.spawn(|mut cx| async move {
+            let this = cx.new_model(|cx: &mut ModelContext<Self>| {
+                let context_server_factory_registry =
+                    ContextServerFactoryRegistry::default_global(cx);
+                let context_server_manager = cx.new_model(|cx| {
+                    ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
+                });
+
+                let this = Self {
+                    project,
+                    tools,
+                    context_server_manager,
+                    context_server_tool_ids: HashMap::default(),
+                };
+                this.register_context_server_handlers(cx);
+
+                this
+            })?;
+
+            Ok(this)
+        })
+    }
+
+    fn register_context_server_handlers(&self, cx: &mut ModelContext<Self>) {
+        cx.subscribe(
+            &self.context_server_manager.clone(),
+            Self::handle_context_server_event,
+        )
+        .detach();
+    }
+
+    fn handle_context_server_event(
+        &mut self,
+        context_server_manager: Model<ContextServerManager>,
+        event: &context_server::manager::Event,
+        cx: &mut ModelContext<Self>,
+    ) {
+        let tool_working_set = self.tools.clone();
+        match event {
+            context_server::manager::Event::ServerStarted { server_id } => {
+                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
+                    let context_server_manager = context_server_manager.clone();
+                    cx.spawn({
+                        let server = server.clone();
+                        let server_id = server_id.clone();
+                        |this, mut cx| async move {
+                            let Some(protocol) = server.client() else {
+                                return;
+                            };
+
+                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
+                                if let Some(tools) = protocol.list_tools().await.log_err() {
+                                    let tool_ids = tools
+                                        .tools
+                                        .into_iter()
+                                        .map(|tool| {
+                                            log::info!(
+                                                "registering context server tool: {:?}",
+                                                tool.name
+                                            );
+                                            tool_working_set.insert(Arc::new(
+                                                ContextServerTool::new(
+                                                    context_server_manager.clone(),
+                                                    server.id(),
+                                                    tool,
+                                                ),
+                                            ))
+                                        })
+                                        .collect::<Vec<_>>();
+
+                                    this.update(&mut cx, |this, _cx| {
+                                        this.context_server_tool_ids.insert(server_id, tool_ids);
+                                    })
+                                    .log_err();
+                                }
+                            }
+                        }
+                    })
+                    .detach();
+                }
+            }
+            context_server::manager::Event::ServerStopped { server_id } => {
+                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
+                    tool_working_set.remove(&tool_ids);
+                }
+            }
+        }
+    }
+}