From 8a96ea25c465697ec74ce3447bcd2ce9cb25b4f0 Mon Sep 17 00:00:00 2001 From: David Soria Parra <167242713+dsp-ant@users.noreply.github.com> Date: Mon, 28 Oct 2024 14:37:58 +0000 Subject: [PATCH] context_servers: Support tools (#19548) This PR depends on #19547 This PR adds support for tools from context servers. Context servers are free to expose tools that Zed can pass to models. When called by the model, Zed forwards the request to context servers. This allows for some interesting techniques. Context servers can easily expose tools such as querying local databases, reading or writing local files, reading resources over authenticated APIs (e.g. kubernetes, asana, etc). This is currently experimental. Things to discuss * I want to still add a confirm dialog asking people if a server is allows to use the tool. Should do this or just use the tool and assume trustworthyness of context servers? * Can we add tool use behind a local setting flag? Release Notes: - N/A --------- Co-authored-by: Marshall Bowers --- crates/assistant/src/assistant.rs | 85 ++++++++++++++----- crates/assistant/src/tools.rs | 1 + .../src/tools/context_server_tool.rs | 82 ++++++++++++++++++ crates/context_servers/src/protocol.rs | 33 +++++++ crates/context_servers/src/registry.rs | 32 +++++-- crates/context_servers/src/types.rs | 18 ++++ .../language_model/src/provider/anthropic.rs | 12 ++- 7 files changed, 235 insertions(+), 28 deletions(-) create mode 100644 crates/assistant/src/tools/context_server_tool.rs diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index e1e574744fff61a05da0a7ccb6e1ddff9162cb11..a48f6d6c29424a6de87ec038e8b42ba2726f6f79 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -298,25 +298,64 @@ fn register_context_server_handlers(cx: &mut AppContext) { return; }; - if let Some(prompts) = protocol.list_prompts().await.log_err() { - for prompt in prompts - .into_iter() - .filter(context_server_command::acceptable_prompt) - { - log::info!( - "registering context server command: {:?}", - prompt.name - ); - context_server_registry.register_command( - server.id.clone(), - prompt.name.as_str(), - ); - slash_command_registry.register_command( - context_server_command::ContextServerSlashCommand::new( - &server, prompt, - ), - true, - ); + if protocol.capable(context_servers::protocol::ServerCapability::Prompts) { + if let Some(prompts) = protocol.list_prompts().await.log_err() { + for prompt in prompts + .into_iter() + .filter(context_server_command::acceptable_prompt) + { + log::info!( + "registering context server command: {:?}", + prompt.name + ); + context_server_registry.register_command( + server.id.clone(), + prompt.name.as_str(), + ); + slash_command_registry.register_command( + context_server_command::ContextServerSlashCommand::new( + &server, prompt, + ), + true, + ); + } + } + } + }) + .detach(); + } + }, + ); + + cx.update_model( + &manager, + |manager: &mut context_servers::manager::ContextServerManager, cx| { + let tool_registry = ToolRegistry::global(cx); + let context_server_registry = ContextServerRegistry::global(cx); + if let Some(server) = manager.get_server(server_id) { + cx.spawn(|_, _| async move { + let Some(protocol) = server.client.read().clone() else { + return; + }; + + if protocol.capable(context_servers::protocol::ServerCapability::Tools) { + if let Some(tools) = protocol.list_tools().await.log_err() { + for tool in tools.tools { + log::info!( + "registering context server tool: {:?}", + tool.name + ); + context_server_registry.register_tool( + server.id.clone(), + tool.name.as_str(), + ); + tool_registry.register_tool( + tools::context_server_tool::ContextServerTool::new( + server.id.clone(), + tool + ), + ); + } } } }) @@ -334,6 +373,14 @@ fn register_context_server_handlers(cx: &mut AppContext) { context_server_registry.unregister_command(&server_id, &command_name); } } + + if let Some(tools) = context_server_registry.get_tools(server_id) { + let tool_registry = ToolRegistry::global(cx); + for tool_name in tools { + tool_registry.unregister_tool_by_name(&tool_name); + context_server_registry.unregister_tool(&server_id, &tool_name); + } + } } }, ) diff --git a/crates/assistant/src/tools.rs b/crates/assistant/src/tools.rs index abde04e760e3ee92e8d6e05fb503637734beadcd..83a396c0203cb24fb6053c857a6065ed500c2542 100644 --- a/crates/assistant/src/tools.rs +++ b/crates/assistant/src/tools.rs @@ -1 +1,2 @@ +pub mod context_server_tool; pub mod now_tool; diff --git a/crates/assistant/src/tools/context_server_tool.rs b/crates/assistant/src/tools/context_server_tool.rs new file mode 100644 index 0000000000000000000000000000000000000000..93edb32b75b72586347b8794615868dd881d3881 --- /dev/null +++ b/crates/assistant/src/tools/context_server_tool.rs @@ -0,0 +1,82 @@ +use anyhow::{anyhow, bail}; +use assistant_tool::Tool; +use context_servers::manager::ContextServerManager; +use context_servers::types; +use gpui::Task; + +pub struct ContextServerTool { + server_id: String, + tool: types::Tool, +} + +impl ContextServerTool { + pub fn new(server_id: impl Into, tool: types::Tool) -> Self { + Self { + server_id: server_id.into(), + tool, + } + } +} + +impl Tool for ContextServerTool { + fn name(&self) -> String { + self.tool.name.clone() + } + + fn description(&self) -> String { + self.tool.description.clone().unwrap_or_default() + } + + fn input_schema(&self) -> serde_json::Value { + match &self.tool.input_schema { + serde_json::Value::Null => { + serde_json::json!({ "type": "object", "properties": [] }) + } + serde_json::Value::Object(map) if map.is_empty() => { + serde_json::json!({ "type": "object", "properties": [] }) + } + _ => self.tool.input_schema.clone(), + } + } + + fn run( + self: std::sync::Arc, + input: serde_json::Value, + _workspace: gpui::WeakView, + cx: &mut ui::WindowContext, + ) -> gpui::Task> { + let manager = ContextServerManager::global(cx); + let manager = manager.read(cx); + if let Some(server) = manager.get_server(&self.server_id) { + cx.foreground_executor().spawn({ + let tool_name = self.tool.name.clone(); + async move { + let Some(protocol) = server.client.read().clone() else { + bail!("Context server not initialized"); + }; + + let arguments = if let serde_json::Value::Object(map) = input { + Some(map.into_iter().collect()) + } else { + None + }; + + log::trace!( + "Running tool: {} with arguments: {:?}", + tool_name, + arguments + ); + let response = protocol.run_tool(tool_name, arguments).await?; + + let tool_result = match response.tool_result { + serde_json::Value::String(s) => s, + _ => serde_json::to_string(&response.tool_result)?, + }; + Ok(tool_result) + } + }) + } else { + Task::ready(Err(anyhow!("Context server not found"))) + } + } +} diff --git a/crates/context_servers/src/protocol.rs b/crates/context_servers/src/protocol.rs index 80a7a7f991a23f5fe963ae54e836b3240b8844c5..996fc34f462c5f7e5ab3cdfa59fec1990643aa22 100644 --- a/crates/context_servers/src/protocol.rs +++ b/crates/context_servers/src/protocol.rs @@ -180,6 +180,39 @@ impl InitializedContextServerProtocol { Ok(completion) } + + /// List MCP tools. + pub async fn list_tools(&self) -> Result { + self.check_capability(ServerCapability::Tools)?; + + let response = self + .inner + .request::(types::RequestType::ListTools.as_str(), ()) + .await?; + + Ok(response) + } + + /// Executes a tool with the given arguments + pub async fn run_tool>( + &self, + tool: P, + arguments: Option>, + ) -> Result { + self.check_capability(ServerCapability::Tools)?; + + let params = types::CallToolParams { + name: tool.as_ref().to_string(), + arguments, + }; + + let response: types::CallToolResponse = self + .inner + .request(types::RequestType::CallTool.as_str(), params) + .await?; + + Ok(response) + } } impl InitializedContextServerProtocol { diff --git a/crates/context_servers/src/registry.rs b/crates/context_servers/src/registry.rs index 625f308c15228fc5f69795f601e87c30433fdaa5..5490187034972448152c377d369854a43702d29f 100644 --- a/crates/context_servers/src/registry.rs +++ b/crates/context_servers/src/registry.rs @@ -9,7 +9,8 @@ struct GlobalContextServerRegistry(Arc); impl Global for GlobalContextServerRegistry {} pub struct ContextServerRegistry { - registry: RwLock>>>, + command_registry: RwLock>>>, + tool_registry: RwLock>>>, } impl ContextServerRegistry { @@ -20,13 +21,14 @@ impl ContextServerRegistry { pub fn register(cx: &mut AppContext) { cx.set_global(GlobalContextServerRegistry(Arc::new( ContextServerRegistry { - registry: RwLock::new(HashMap::default()), + command_registry: RwLock::new(HashMap::default()), + tool_registry: RwLock::new(HashMap::default()), }, ))) } pub fn register_command(&self, server_id: String, command_name: &str) { - let mut registry = self.registry.write(); + let mut registry = self.command_registry.write(); registry .entry(server_id) .or_default() @@ -34,14 +36,34 @@ impl ContextServerRegistry { } pub fn unregister_command(&self, server_id: &str, command_name: &str) { - let mut registry = self.registry.write(); + let mut registry = self.command_registry.write(); if let Some(commands) = registry.get_mut(server_id) { commands.retain(|name| name.as_ref() != command_name); } } pub fn get_commands(&self, server_id: &str) -> Option>> { - let registry = self.registry.read(); + let registry = self.command_registry.read(); + registry.get(server_id).cloned() + } + + pub fn register_tool(&self, server_id: String, tool_name: &str) { + let mut registry = self.tool_registry.write(); + registry + .entry(server_id) + .or_default() + .push(tool_name.into()); + } + + pub fn unregister_tool(&self, server_id: &str, tool_name: &str) { + let mut registry = self.tool_registry.write(); + if let Some(tools) = registry.get_mut(server_id) { + tools.retain(|name| name.as_ref() != tool_name); + } + } + + pub fn get_tools(&self, server_id: &str) -> Option>> { + let registry = self.tool_registry.read(); registry.get(server_id).cloned() } } diff --git a/crates/context_servers/src/types.rs b/crates/context_servers/src/types.rs index 2bca0a021a129029b55d6371f9db98332418a7a5..b6d8a958bb1264c1e323dc9f13450ddf7551ff2f 100644 --- a/crates/context_servers/src/types.rs +++ b/crates/context_servers/src/types.rs @@ -16,6 +16,8 @@ pub enum RequestType { PromptsList, CompletionComplete, Ping, + ListTools, + ListResourceTemplates, } impl RequestType { @@ -32,6 +34,8 @@ impl RequestType { RequestType::PromptsList => "prompts/list", RequestType::CompletionComplete => "completion/complete", RequestType::Ping => "ping", + RequestType::ListTools => "tools/list", + RequestType::ListResourceTemplates => "resources/templates/list", } } } @@ -402,3 +406,17 @@ pub struct Completion { pub values: Vec, pub total: CompletionTotal, } + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CallToolResponse { + pub tool_result: serde_json::Value, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ListToolsResponse { + pub tools: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, +} diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs index fe88c73b90deb6ee8a7af07497b0b59cab1fd7a5..b7e65650b55a3075fcb598d06fd027189ea0df31 100644 --- a/crates/language_model/src/provider/anthropic.rs +++ b/crates/language_model/src/provider/anthropic.rs @@ -505,10 +505,14 @@ pub fn map_to_language_model_completion_events( LanguageModelToolUse { id: tool_use.id, name: tool_use.name, - input: serde_json::Value::from_str( - &tool_use.input_json, - ) - .map_err(|err| anyhow!(err))?, + input: if tool_use.input_json.is_empty() { + serde_json::Value::Null + } else { + serde_json::Value::from_str( + &tool_use.input_json, + ) + .map_err(|err| anyhow!(err))? + }, }, )) })),