1use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream, ToolInput};
2use agent_client_protocol::ToolKind;
3use anyhow::Result;
4use collections::{BTreeMap, HashMap};
5use context_server::{ContextServerId, client::NotificationSubscription};
6use futures::FutureExt as _;
7use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task};
8use project::context_server_store::{ContextServerStatus, ContextServerStore};
9use std::sync::Arc;
10use util::ResultExt;
11
12/// Generates a tool ID for an MCP tool that can be used in settings.
13///
14/// The format is `mcp:<server_id>:<tool_name>` to avoid collisions with built-in tools.
/// Builds the settings-facing identifier for an MCP tool.
///
/// The identifier has the shape `mcp:<server_id>:<tool_name>`. The `mcp:` prefix keeps it
/// from colliding with the names of built-in tools.
pub fn mcp_tool_id(server_id: &str, tool_name: &str) -> String {
    format!("mcp:{server_id}:{tool_name}")
}
18
/// A prompt advertised by a context server, paired with the id of the server it came from.
pub struct ContextServerPrompt {
    /// The server that advertised this prompt.
    pub server_id: ContextServerId,
    /// The prompt definition as returned by the server's prompts list request.
    pub prompt: context_server::types::Prompt,
}
23
/// Events emitted by the context server registry when its cached tools or prompts change.
pub enum ContextServerRegistryEvent {
    /// The tools of at least one server were reloaded or removed.
    ToolsChanged,
    /// The prompts of at least one server were reloaded or removed.
    PromptsChanged,
}
28
// Allows the registry to `cx.emit` `ContextServerRegistryEvent`s to its subscribers.
impl EventEmitter<ContextServerRegistryEvent> for ContextServerRegistry {}
30
/// Mirrors the tools and prompts exposed by the running servers of a [`ContextServerStore`],
/// reloading them as servers start and dropping them when servers stop.
pub struct ContextServerRegistry {
    /// The store whose servers are mirrored here.
    server_store: Entity<ContextServerStore>,
    /// Cached tools and prompts, keyed by server id. An entry is removed when its server
    /// stops, errors, or requires authentication.
    registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
    /// Keeps the subscription to `server_store` status events alive.
    _subscription: gpui::Subscription,
}
36
/// Cached state for a single running context server.
struct RegisteredContextServer {
    /// Tools advertised by the server, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Prompts advertised by the server, keyed by prompt name.
    prompts: BTreeMap<SharedString, ContextServerPrompt>,
    /// Most recent tool-list load; assigning a new task drops the previous one.
    load_tools: Task<Result<()>>,
    /// Most recent prompt-list load; assigning a new task drops the previous one.
    load_prompts: Task<Result<()>>,
    /// Handler for `notifications/tools/list_changed`. It is `None` when the server was not
    /// running, had no client, or lacked the tools capability at registration time.
    _tools_updated_subscription: Option<NotificationSubscription>,
}
44
45impl ContextServerRegistry {
46 pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
47 let mut this = Self {
48 server_store: server_store.clone(),
49 registered_servers: HashMap::default(),
50 _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
51 };
52 for server in server_store.read(cx).running_servers() {
53 this.reload_tools_for_server(server.id(), cx);
54 this.reload_prompts_for_server(server.id(), cx);
55 }
56 this
57 }
58
59 pub fn tools_for_server(
60 &self,
61 server_id: &ContextServerId,
62 ) -> impl Iterator<Item = &Arc<dyn AnyAgentTool>> {
63 self.registered_servers
64 .get(server_id)
65 .map(|server| server.tools.values())
66 .into_iter()
67 .flatten()
68 }
69
70 pub fn servers(
71 &self,
72 ) -> impl Iterator<
73 Item = (
74 &ContextServerId,
75 &BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
76 ),
77 > {
78 self.registered_servers
79 .iter()
80 .map(|(id, server)| (id, &server.tools))
81 }
82
83 pub fn prompts(&self) -> impl Iterator<Item = &ContextServerPrompt> {
84 self.registered_servers
85 .values()
86 .flat_map(|server| server.prompts.values())
87 }
88
89 pub fn find_prompt(
90 &self,
91 server_id: Option<&ContextServerId>,
92 name: &str,
93 ) -> Option<&ContextServerPrompt> {
94 if let Some(server_id) = server_id {
95 self.registered_servers
96 .get(server_id)
97 .and_then(|server| server.prompts.get(name))
98 } else {
99 self.registered_servers
100 .values()
101 .find_map(|server| server.prompts.get(name))
102 }
103 }
104
105 pub fn server_store(&self) -> &Entity<ContextServerStore> {
106 &self.server_store
107 }
108
109 fn get_or_register_server(
110 &mut self,
111 server_id: &ContextServerId,
112 cx: &mut Context<Self>,
113 ) -> &mut RegisteredContextServer {
114 self.registered_servers
115 .entry(server_id.clone())
116 .or_insert_with(|| Self::init_registered_server(server_id, &self.server_store, cx))
117 }
118
119 fn init_registered_server(
120 server_id: &ContextServerId,
121 server_store: &Entity<ContextServerStore>,
122 cx: &mut Context<Self>,
123 ) -> RegisteredContextServer {
124 let tools_updated_subscription = server_store
125 .read(cx)
126 .get_running_server(server_id)
127 .and_then(|server| {
128 let client = server.client()?;
129
130 if !client.capable(context_server::protocol::ServerCapability::Tools) {
131 return None;
132 }
133
134 let server_id = server.id();
135 let this = cx.entity().downgrade();
136
137 Some(client.on_notification(
138 "notifications/tools/list_changed",
139 Box::new(move |_params, cx: AsyncApp| {
140 let server_id = server_id.clone();
141 let this = this.clone();
142 cx.spawn(async move |cx| {
143 this.update(cx, |this, cx| {
144 log::info!(
145 "Received tools/list_changed notification for server {}",
146 server_id
147 );
148 this.reload_tools_for_server(server_id, cx);
149 })
150 })
151 .detach();
152 }),
153 ))
154 });
155
156 RegisteredContextServer {
157 tools: BTreeMap::default(),
158 prompts: BTreeMap::default(),
159 load_tools: Task::ready(Ok(())),
160 load_prompts: Task::ready(Ok(())),
161 _tools_updated_subscription: tools_updated_subscription,
162 }
163 }
164
165 fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
166 let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
167 return;
168 };
169 let Some(client) = server.client() else {
170 return;
171 };
172
173 if !client.capable(context_server::protocol::ServerCapability::Tools) {
174 return;
175 }
176
177 let registered_server = self.get_or_register_server(&server_id, cx);
178 registered_server.load_tools = cx.spawn(async move |this, cx| {
179 let response = client
180 .request::<context_server::types::requests::ListTools>(())
181 .await;
182
183 this.update(cx, |this, cx| {
184 let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
185 return;
186 };
187
188 registered_server.tools.clear();
189 if let Some(response) = response.log_err() {
190 for tool in response.tools {
191 let tool = Arc::new(ContextServerTool::new(
192 this.server_store.clone(),
193 server.id(),
194 tool,
195 ));
196 registered_server.tools.insert(tool.name(), tool);
197 }
198 cx.emit(ContextServerRegistryEvent::ToolsChanged);
199 cx.notify();
200 }
201 })
202 });
203 }
204
205 fn reload_prompts_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
206 let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
207 return;
208 };
209 let Some(client) = server.client() else {
210 return;
211 };
212 if !client.capable(context_server::protocol::ServerCapability::Prompts) {
213 return;
214 }
215
216 let registered_server = self.get_or_register_server(&server_id, cx);
217
218 registered_server.load_prompts = cx.spawn(async move |this, cx| {
219 let response = client
220 .request::<context_server::types::requests::PromptsList>(())
221 .await;
222
223 this.update(cx, |this, cx| {
224 let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
225 return;
226 };
227
228 registered_server.prompts.clear();
229 if let Some(response) = response.log_err() {
230 for prompt in response.prompts {
231 let name: SharedString = prompt.name.clone().into();
232 registered_server.prompts.insert(
233 name,
234 ContextServerPrompt {
235 server_id: server_id.clone(),
236 prompt,
237 },
238 );
239 }
240 cx.emit(ContextServerRegistryEvent::PromptsChanged);
241 cx.notify();
242 }
243 })
244 });
245 }
246
247 fn handle_context_server_store_event(
248 &mut self,
249 _: Entity<ContextServerStore>,
250 event: &project::context_server_store::ServerStatusChangedEvent,
251 cx: &mut Context<Self>,
252 ) {
253 let project::context_server_store::ServerStatusChangedEvent { server_id, status } = event;
254
255 match status {
256 ContextServerStatus::Starting | ContextServerStatus::Authenticating => {}
257 ContextServerStatus::Running => {
258 self.reload_tools_for_server(server_id.clone(), cx);
259 self.reload_prompts_for_server(server_id.clone(), cx);
260 }
261 ContextServerStatus::Stopped
262 | ContextServerStatus::Error(_)
263 | ContextServerStatus::AuthRequired => {
264 if let Some(registered_server) = self.registered_servers.remove(server_id) {
265 if !registered_server.tools.is_empty() {
266 cx.emit(ContextServerRegistryEvent::ToolsChanged);
267 }
268 if !registered_server.prompts.is_empty() {
269 cx.emit(ContextServerRegistryEvent::PromptsChanged);
270 }
271 }
272 cx.notify();
273 }
274 };
275 }
276}
277
/// An [`AnyAgentTool`] that forwards calls to a tool exposed by an MCP context server.
struct ContextServerTool {
    /// Store used to look up the running server each time the tool is run.
    store: Entity<ContextServerStore>,
    /// Id of the server that exposes this tool.
    server_id: ContextServerId,
    /// Tool definition (name, description, input schema) reported by the server.
    tool: context_server::types::Tool,
}
283
284impl ContextServerTool {
285 fn new(
286 store: Entity<ContextServerStore>,
287 server_id: ContextServerId,
288 tool: context_server::types::Tool,
289 ) -> Self {
290 Self {
291 store,
292 server_id,
293 tool,
294 }
295 }
296}
297
298impl AnyAgentTool for ContextServerTool {
299 fn name(&self) -> SharedString {
300 self.tool.name.clone().into()
301 }
302
303 fn description(&self) -> SharedString {
304 self.tool.description.clone().unwrap_or_default().into()
305 }
306
307 fn kind(&self) -> ToolKind {
308 ToolKind::Other
309 }
310
311 fn initial_title(&self, _input: serde_json::Value, _cx: &mut App) -> SharedString {
312 format!("Run MCP tool `{}`", self.tool.name).into()
313 }
314
315 fn input_schema(
316 &self,
317 format: language_model::LanguageModelToolSchemaFormat,
318 ) -> Result<serde_json::Value> {
319 let mut schema = self.tool.input_schema.clone();
320 language_model::tool_schema::adapt_schema_to_format(&mut schema, format)?;
321 Ok(match schema {
322 serde_json::Value::Null => {
323 serde_json::json!({ "type": "object", "properties": [] })
324 }
325 serde_json::Value::Object(map) if map.is_empty() => {
326 serde_json::json!({ "type": "object", "properties": [] })
327 }
328 _ => schema,
329 })
330 }
331
332 fn run(
333 self: Arc<Self>,
334 input: ToolInput<serde_json::Value>,
335 event_stream: ToolCallEventStream,
336 cx: &mut App,
337 ) -> Task<Result<AgentToolOutput, AgentToolOutput>> {
338 let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
339 return Task::ready(Err(AgentToolOutput::from_error("Context server not found")));
340 };
341 let tool_name = self.tool.name.clone();
342 let tool_id = mcp_tool_id(&self.server_id.0, &self.tool.name);
343 let display_name = self.tool.name.clone();
344 let initial_title = self.initial_title(serde_json::Value::Null, cx);
345 let authorize =
346 event_stream.authorize_third_party_tool(initial_title, tool_id, display_name, cx);
347
348 cx.spawn(async move |_cx| {
349 let input = input.recv().await.map_err(|e| {
350 AgentToolOutput::from_error(format!("Failed to receive tool input: {e}"))
351 })?;
352
353 authorize.await.map_err(|e| AgentToolOutput::from_error(e.to_string()))?;
354
355 let Some(protocol) = server.client() else {
356 return Err(AgentToolOutput::from_error("Context server not initialized"));
357 };
358
359 let arguments = if let serde_json::Value::Object(map) = input {
360 Some(map.into_iter().collect())
361 } else {
362 None
363 };
364
365 log::trace!(
366 "Running tool: {} with arguments: {:?}",
367 tool_name,
368 arguments
369 );
370
371 let request = protocol.request::<context_server::types::requests::CallTool>(
372 context_server::types::CallToolParams {
373 name: tool_name,
374 arguments,
375 meta: None,
376 },
377 );
378
379 let response = futures::select! {
380 response = request.fuse() => response.map_err(|e| AgentToolOutput::from_error(e.to_string()))?,
381 _ = event_stream.cancelled_by_user().fuse() => {
382 return Err(AgentToolOutput::from_error("MCP tool cancelled by user"));
383 }
384 };
385
386 if response.is_error == Some(true) {
387 let error_message: String =
388 response.content.iter().filter_map(|c| c.text()).collect();
389 return Err(AgentToolOutput::from_error(error_message));
390 }
391
392 let mut result = String::new();
393 for content in response.content {
394 match content {
395 context_server::types::ToolResponseContent::Text { text } => {
396 result.push_str(&text);
397 }
398 context_server::types::ToolResponseContent::Image { .. } => {
399 log::warn!("Ignoring image content from tool response");
400 }
401 context_server::types::ToolResponseContent::Audio { .. } => {
402 log::warn!("Ignoring audio content from tool response");
403 }
404 context_server::types::ToolResponseContent::Resource { .. } => {
405 log::warn!("Ignoring resource content from tool response");
406 }
407 }
408 }
409 Ok(AgentToolOutput {
410 raw_output: result.clone().into(),
411 llm_output: result.into(),
412 })
413 })
414 }
415
416 fn replay(
417 &self,
418 _input: serde_json::Value,
419 _output: serde_json::Value,
420 _event_stream: ToolCallEventStream,
421 _cx: &mut App,
422 ) -> Result<()> {
423 Ok(())
424 }
425}
426
427pub fn get_prompt(
428 server_store: &Entity<ContextServerStore>,
429 server_id: &ContextServerId,
430 prompt_name: &str,
431 arguments: HashMap<String, String>,
432 cx: &mut AsyncApp,
433) -> Task<Result<context_server::types::PromptsGetResponse>> {
434 let server = cx.update(|cx| server_store.read(cx).get_running_server(server_id));
435 let Some(server) = server else {
436 return Task::ready(Err(anyhow::anyhow!("Context server not found")));
437 };
438
439 let Some(protocol) = server.client() else {
440 return Task::ready(Err(anyhow::anyhow!("Context server not initialized")));
441 };
442
443 let prompt_name = prompt_name.to_string();
444
445 cx.background_spawn(async move {
446 let response = protocol
447 .request::<context_server::types::requests::PromptsGet>(
448 context_server::types::PromptsGetParams {
449 name: prompt_name,
450 arguments: (!arguments.is_empty()).then(|| arguments),
451 meta: None,
452 },
453 )
454 .await?;
455
456 Ok(response)
457 })
458}
459
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mcp_tool_id_format() {
        let cases = [
            ("filesystem", "read_file", "mcp:filesystem:read_file"),
            ("github", "create_issue", "mcp:github:create_issue"),
            (
                "my-custom-server",
                "do_something",
                "mcp:my-custom-server:do_something",
            ),
            // Underscores in names
            ("my_server", "my_tool", "mcp:my_server:my_tool"),
        ];

        for (server_id, tool_name, expected) in cases {
            assert_eq!(
                mcp_tool_id(server_id, tool_name),
                expected,
                "server_id={server_id:?}, tool_name={tool_name:?}"
            );
        }
    }

    // Note: Tests for MCP tool ID collision with built-in tools and permission
    // decisions are in crates/agent/src/tool_permissions.rs to avoid duplication.
}