Detailed changes
@@ -298,25 +298,64 @@ fn register_context_server_handlers(cx: &mut AppContext) {
return;
};
- if let Some(prompts) = protocol.list_prompts().await.log_err() {
- for prompt in prompts
- .into_iter()
- .filter(context_server_command::acceptable_prompt)
- {
- log::info!(
- "registering context server command: {:?}",
- prompt.name
- );
- context_server_registry.register_command(
- server.id.clone(),
- prompt.name.as_str(),
- );
- slash_command_registry.register_command(
- context_server_command::ContextServerSlashCommand::new(
- &server, prompt,
- ),
- true,
- );
+ if protocol.capable(context_servers::protocol::ServerCapability::Prompts) {
+ if let Some(prompts) = protocol.list_prompts().await.log_err() {
+ for prompt in prompts
+ .into_iter()
+ .filter(context_server_command::acceptable_prompt)
+ {
+ log::info!(
+ "registering context server command: {:?}",
+ prompt.name
+ );
+ context_server_registry.register_command(
+ server.id.clone(),
+ prompt.name.as_str(),
+ );
+ slash_command_registry.register_command(
+ context_server_command::ContextServerSlashCommand::new(
+ &server, prompt,
+ ),
+ true,
+ );
+ }
+ }
+ }
+ })
+ .detach();
+ }
+ },
+ );
+
+ cx.update_model(
+ &manager,
+ |manager: &mut context_servers::manager::ContextServerManager, cx| {
+ let tool_registry = ToolRegistry::global(cx);
+ let context_server_registry = ContextServerRegistry::global(cx);
+ if let Some(server) = manager.get_server(server_id) {
+ cx.spawn(|_, _| async move {
+ let Some(protocol) = server.client.read().clone() else {
+ return;
+ };
+
+ if protocol.capable(context_servers::protocol::ServerCapability::Tools) {
+ if let Some(tools) = protocol.list_tools().await.log_err() {
+ for tool in tools.tools {
+ log::info!(
+ "registering context server tool: {:?}",
+ tool.name
+ );
+ context_server_registry.register_tool(
+ server.id.clone(),
+ tool.name.as_str(),
+ );
+ tool_registry.register_tool(
+ tools::context_server_tool::ContextServerTool::new(
+ server.id.clone(),
+ tool
+ ),
+ );
+ }
}
}
})
@@ -334,6 +373,14 @@ fn register_context_server_handlers(cx: &mut AppContext) {
context_server_registry.unregister_command(&server_id, &command_name);
}
}
+
+ if let Some(tools) = context_server_registry.get_tools(server_id) {
+ let tool_registry = ToolRegistry::global(cx);
+ for tool_name in tools {
+ tool_registry.unregister_tool_by_name(&tool_name);
+ context_server_registry.unregister_tool(&server_id, &tool_name);
+ }
+ }
}
},
)
@@ -1 +1,2 @@
+pub mod context_server_tool;
pub mod now_tool;
@@ -0,0 +1,82 @@
+use anyhow::{anyhow, bail};
+use assistant_tool::Tool;
+use context_servers::manager::ContextServerManager;
+use context_servers::types;
+use gpui::Task;
+
+pub struct ContextServerTool {
+ server_id: String,
+ tool: types::Tool,
+}
+
+impl ContextServerTool {
+ pub fn new(server_id: impl Into<String>, tool: types::Tool) -> Self {
+ Self {
+ server_id: server_id.into(),
+ tool,
+ }
+ }
+}
+
+impl Tool for ContextServerTool {
+ fn name(&self) -> String {
+ self.tool.name.clone()
+ }
+
+ fn description(&self) -> String {
+ self.tool.description.clone().unwrap_or_default()
+ }
+
+ fn input_schema(&self) -> serde_json::Value {
+ match &self.tool.input_schema {
+ serde_json::Value::Null => {
+ serde_json::json!({ "type": "object", "properties": [] })
+ }
+ serde_json::Value::Object(map) if map.is_empty() => {
+ serde_json::json!({ "type": "object", "properties": [] })
+ }
+ _ => self.tool.input_schema.clone(),
+ }
+ }
+
+ fn run(
+ self: std::sync::Arc<Self>,
+ input: serde_json::Value,
+ _workspace: gpui::WeakView<workspace::Workspace>,
+ cx: &mut ui::WindowContext,
+ ) -> gpui::Task<gpui::Result<String>> {
+ let manager = ContextServerManager::global(cx);
+ let manager = manager.read(cx);
+ if let Some(server) = manager.get_server(&self.server_id) {
+ cx.foreground_executor().spawn({
+ let tool_name = self.tool.name.clone();
+ async move {
+ let Some(protocol) = server.client.read().clone() else {
+ bail!("Context server not initialized");
+ };
+
+ let arguments = if let serde_json::Value::Object(map) = input {
+ Some(map.into_iter().collect())
+ } else {
+ None
+ };
+
+ log::trace!(
+ "Running tool: {} with arguments: {:?}",
+ tool_name,
+ arguments
+ );
+ let response = protocol.run_tool(tool_name, arguments).await?;
+
+ let tool_result = match response.tool_result {
+ serde_json::Value::String(s) => s,
+ _ => serde_json::to_string(&response.tool_result)?,
+ };
+ Ok(tool_result)
+ }
+ })
+ } else {
+ Task::ready(Err(anyhow!("Context server not found")))
+ }
+ }
+}
@@ -180,6 +180,39 @@ impl InitializedContextServerProtocol {
Ok(completion)
}
+
+ /// List MCP tools.
+ pub async fn list_tools(&self) -> Result<types::ListToolsResponse> {
+ self.check_capability(ServerCapability::Tools)?;
+
+ let response = self
+ .inner
+ .request::<types::ListToolsResponse>(types::RequestType::ListTools.as_str(), ())
+ .await?;
+
+ Ok(response)
+ }
+
+ /// Executes a tool with the given arguments
+ pub async fn run_tool<P: AsRef<str>>(
+ &self,
+ tool: P,
+ arguments: Option<HashMap<String, serde_json::Value>>,
+ ) -> Result<types::CallToolResponse> {
+ self.check_capability(ServerCapability::Tools)?;
+
+ let params = types::CallToolParams {
+ name: tool.as_ref().to_string(),
+ arguments,
+ };
+
+ let response: types::CallToolResponse = self
+ .inner
+ .request(types::RequestType::CallTool.as_str(), params)
+ .await?;
+
+ Ok(response)
+ }
}
impl InitializedContextServerProtocol {
@@ -9,7 +9,8 @@ struct GlobalContextServerRegistry(Arc<ContextServerRegistry>);
impl Global for GlobalContextServerRegistry {}
pub struct ContextServerRegistry {
- registry: RwLock<HashMap<String, Vec<Arc<str>>>>,
+ command_registry: RwLock<HashMap<String, Vec<Arc<str>>>>,
+ tool_registry: RwLock<HashMap<String, Vec<Arc<str>>>>,
}
impl ContextServerRegistry {
@@ -20,13 +21,14 @@ impl ContextServerRegistry {
pub fn register(cx: &mut AppContext) {
cx.set_global(GlobalContextServerRegistry(Arc::new(
ContextServerRegistry {
- registry: RwLock::new(HashMap::default()),
+ command_registry: RwLock::new(HashMap::default()),
+ tool_registry: RwLock::new(HashMap::default()),
},
)))
}
pub fn register_command(&self, server_id: String, command_name: &str) {
- let mut registry = self.registry.write();
+ let mut registry = self.command_registry.write();
registry
.entry(server_id)
.or_default()
@@ -34,14 +36,34 @@ impl ContextServerRegistry {
}
pub fn unregister_command(&self, server_id: &str, command_name: &str) {
- let mut registry = self.registry.write();
+ let mut registry = self.command_registry.write();
if let Some(commands) = registry.get_mut(server_id) {
commands.retain(|name| name.as_ref() != command_name);
}
}
pub fn get_commands(&self, server_id: &str) -> Option<Vec<Arc<str>>> {
- let registry = self.registry.read();
+ let registry = self.command_registry.read();
+ registry.get(server_id).cloned()
+ }
+
+ pub fn register_tool(&self, server_id: String, tool_name: &str) {
+ let mut registry = self.tool_registry.write();
+ registry
+ .entry(server_id)
+ .or_default()
+ .push(tool_name.into());
+ }
+
+ pub fn unregister_tool(&self, server_id: &str, tool_name: &str) {
+ let mut registry = self.tool_registry.write();
+ if let Some(tools) = registry.get_mut(server_id) {
+ tools.retain(|name| name.as_ref() != tool_name);
+ }
+ }
+
+ pub fn get_tools(&self, server_id: &str) -> Option<Vec<Arc<str>>> {
+ let registry = self.tool_registry.read();
registry.get(server_id).cloned()
}
}
@@ -16,6 +16,8 @@ pub enum RequestType {
PromptsList,
CompletionComplete,
Ping,
+ ListTools,
+ ListResourceTemplates,
}
impl RequestType {
@@ -32,6 +34,8 @@ impl RequestType {
RequestType::PromptsList => "prompts/list",
RequestType::CompletionComplete => "completion/complete",
RequestType::Ping => "ping",
+ RequestType::ListTools => "tools/list",
+ RequestType::ListResourceTemplates => "resources/templates/list",
}
}
}
@@ -402,3 +406,17 @@ pub struct Completion {
pub values: Vec<String>,
pub total: CompletionTotal,
}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CallToolResponse {
+ pub tool_result: serde_json::Value,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ListToolsResponse {
+ pub tools: Vec<Tool>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub next_cursor: Option<String>,
+}
@@ -505,10 +505,14 @@ pub fn map_to_language_model_completion_events(
LanguageModelToolUse {
id: tool_use.id,
name: tool_use.name,
- input: serde_json::Value::from_str(
- &tool_use.input_json,
- )
- .map_err(|err| anyhow!(err))?,
+ input: if tool_use.input_json.is_empty() {
+ serde_json::Value::Null
+ } else {
+ serde_json::Value::from_str(
+ &tool_use.input_json,
+ )
+ .map_err(|err| anyhow!(err))?
+ },
},
))
})),