1use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
2use agent_client_protocol::ToolKind;
3use anyhow::{Result, anyhow};
4use collections::{BTreeMap, HashMap};
5use context_server::{ContextServerId, client::NotificationSubscription};
6use futures::FutureExt as _;
7use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task};
8use project::context_server_store::{ContextServerStatus, ContextServerStore};
9use std::sync::Arc;
10use util::ResultExt;
11
/// Builds the settings-facing identifier for a tool exposed by an MCP server.
///
/// Identifiers have the shape `mcp:<server_id>:<tool_name>`. The `mcp:` prefix keeps
/// them from colliding with the identifiers of built-in tools.
pub fn mcp_tool_id(server_id: &str, tool_name: &str) -> String {
    ["mcp", server_id, tool_name].join(":")
}
18
/// A prompt advertised by a context server, tagged with the server that provided it.
pub struct ContextServerPrompt {
    /// Id of the server that advertised this prompt.
    pub server_id: ContextServerId,
    /// The prompt definition as returned by the server.
    pub prompt: context_server::types::Prompt,
}
23
/// Events emitted by [`ContextServerRegistry`] when its cached server data changes.
pub enum ContextServerRegistryEvent {
    /// The cached tools of some server were reloaded or removed.
    ToolsChanged,
    /// The cached prompts of some server were reloaded or removed.
    PromptsChanged,
}

impl EventEmitter<ContextServerRegistryEvent> for ContextServerRegistry {}
30
/// Caches the tools and prompts exposed by the running context (MCP) servers of a
/// [`ContextServerStore`].
pub struct ContextServerRegistry {
    /// The store whose servers this registry mirrors.
    server_store: Entity<ContextServerStore>,
    /// Cached tools/prompts per server, keyed by server id.
    registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
    /// Keeps the subscription to `server_store` status events alive.
    _subscription: gpui::Subscription,
}
36
/// Cached state for a single context server known to the registry.
struct RegisteredContextServer {
    /// Tools exposed by the server, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Prompts exposed by the server, keyed by prompt name.
    prompts: BTreeMap<SharedString, ContextServerPrompt>,
    /// Most recent tool-list request; assigning a new task drops the previous one.
    load_tools: Task<Result<()>>,
    /// Most recent prompt-list request; assigning a new task drops the previous one.
    load_prompts: Task<Result<()>>,
    /// Handler for `notifications/tools/list_changed`; `None` when, at registration
    /// time, the server was not running, had no client, or lacked the tools capability.
    _tools_updated_subscription: Option<NotificationSubscription>,
}
44
45impl ContextServerRegistry {
46 pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
47 let mut this = Self {
48 server_store: server_store.clone(),
49 registered_servers: HashMap::default(),
50 _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
51 };
52 for server in server_store.read(cx).running_servers() {
53 this.reload_tools_for_server(server.id(), cx);
54 this.reload_prompts_for_server(server.id(), cx);
55 }
56 this
57 }
58
59 pub fn tools_for_server(
60 &self,
61 server_id: &ContextServerId,
62 ) -> impl Iterator<Item = &Arc<dyn AnyAgentTool>> {
63 self.registered_servers
64 .get(server_id)
65 .map(|server| server.tools.values())
66 .into_iter()
67 .flatten()
68 }
69
70 pub fn servers(
71 &self,
72 ) -> impl Iterator<
73 Item = (
74 &ContextServerId,
75 &BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
76 ),
77 > {
78 self.registered_servers
79 .iter()
80 .map(|(id, server)| (id, &server.tools))
81 }
82
83 pub fn prompts(&self) -> impl Iterator<Item = &ContextServerPrompt> {
84 self.registered_servers
85 .values()
86 .flat_map(|server| server.prompts.values())
87 }
88
89 pub fn find_prompt(
90 &self,
91 server_id: Option<&ContextServerId>,
92 name: &str,
93 ) -> Option<&ContextServerPrompt> {
94 if let Some(server_id) = server_id {
95 self.registered_servers
96 .get(server_id)
97 .and_then(|server| server.prompts.get(name))
98 } else {
99 self.registered_servers
100 .values()
101 .find_map(|server| server.prompts.get(name))
102 }
103 }
104
105 pub fn server_store(&self) -> &Entity<ContextServerStore> {
106 &self.server_store
107 }
108
109 fn get_or_register_server(
110 &mut self,
111 server_id: &ContextServerId,
112 cx: &mut Context<Self>,
113 ) -> &mut RegisteredContextServer {
114 self.registered_servers
115 .entry(server_id.clone())
116 .or_insert_with(|| Self::init_registered_server(server_id, &self.server_store, cx))
117 }
118
119 fn init_registered_server(
120 server_id: &ContextServerId,
121 server_store: &Entity<ContextServerStore>,
122 cx: &mut Context<Self>,
123 ) -> RegisteredContextServer {
124 let tools_updated_subscription = server_store
125 .read(cx)
126 .get_running_server(server_id)
127 .and_then(|server| {
128 let client = server.client()?;
129
130 if !client.capable(context_server::protocol::ServerCapability::Tools) {
131 return None;
132 }
133
134 let server_id = server.id();
135 let this = cx.entity().downgrade();
136
137 Some(client.on_notification(
138 "notifications/tools/list_changed",
139 Box::new(move |_params, cx: AsyncApp| {
140 let server_id = server_id.clone();
141 let this = this.clone();
142 cx.spawn(async move |cx| {
143 this.update(cx, |this, cx| {
144 log::info!(
145 "Received tools/list_changed notification for server {}",
146 server_id
147 );
148 this.reload_tools_for_server(server_id, cx);
149 })
150 })
151 .detach();
152 }),
153 ))
154 });
155
156 RegisteredContextServer {
157 tools: BTreeMap::default(),
158 prompts: BTreeMap::default(),
159 load_tools: Task::ready(Ok(())),
160 load_prompts: Task::ready(Ok(())),
161 _tools_updated_subscription: tools_updated_subscription,
162 }
163 }
164
165 fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
166 let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
167 return;
168 };
169 let Some(client) = server.client() else {
170 return;
171 };
172
173 if !client.capable(context_server::protocol::ServerCapability::Tools) {
174 return;
175 }
176
177 let registered_server = self.get_or_register_server(&server_id, cx);
178 registered_server.load_tools = cx.spawn(async move |this, cx| {
179 let response = client
180 .request::<context_server::types::requests::ListTools>(())
181 .await;
182
183 this.update(cx, |this, cx| {
184 let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
185 return;
186 };
187
188 registered_server.tools.clear();
189 if let Some(response) = response.log_err() {
190 for tool in response.tools {
191 let tool = Arc::new(ContextServerTool::new(
192 this.server_store.clone(),
193 server.id(),
194 tool,
195 ));
196 registered_server.tools.insert(tool.name(), tool);
197 }
198 cx.emit(ContextServerRegistryEvent::ToolsChanged);
199 cx.notify();
200 }
201 })
202 });
203 }
204
205 fn reload_prompts_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
206 let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
207 return;
208 };
209 let Some(client) = server.client() else {
210 return;
211 };
212 if !client.capable(context_server::protocol::ServerCapability::Prompts) {
213 return;
214 }
215
216 let registered_server = self.get_or_register_server(&server_id, cx);
217
218 registered_server.load_prompts = cx.spawn(async move |this, cx| {
219 let response = client
220 .request::<context_server::types::requests::PromptsList>(())
221 .await;
222
223 this.update(cx, |this, cx| {
224 let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
225 return;
226 };
227
228 registered_server.prompts.clear();
229 if let Some(response) = response.log_err() {
230 for prompt in response.prompts {
231 let name: SharedString = prompt.name.clone().into();
232 registered_server.prompts.insert(
233 name,
234 ContextServerPrompt {
235 server_id: server_id.clone(),
236 prompt,
237 },
238 );
239 }
240 cx.emit(ContextServerRegistryEvent::PromptsChanged);
241 cx.notify();
242 }
243 })
244 });
245 }
246
247 fn handle_context_server_store_event(
248 &mut self,
249 _: Entity<ContextServerStore>,
250 event: &project::context_server_store::ServerStatusChangedEvent,
251 cx: &mut Context<Self>,
252 ) {
253 let project::context_server_store::ServerStatusChangedEvent { server_id, status } = event;
254
255 match status {
256 ContextServerStatus::Starting => {}
257 ContextServerStatus::Running => {
258 self.reload_tools_for_server(server_id.clone(), cx);
259 self.reload_prompts_for_server(server_id.clone(), cx);
260 }
261 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
262 if let Some(registered_server) = self.registered_servers.remove(server_id) {
263 if !registered_server.tools.is_empty() {
264 cx.emit(ContextServerRegistryEvent::ToolsChanged);
265 }
266 if !registered_server.prompts.is_empty() {
267 cx.emit(ContextServerRegistryEvent::PromptsChanged);
268 }
269 }
270 cx.notify();
271 }
272 };
273 }
274}
275
276struct ContextServerTool {
277 store: Entity<ContextServerStore>,
278 server_id: ContextServerId,
279 tool: context_server::types::Tool,
280}
281
282impl ContextServerTool {
283 fn new(
284 store: Entity<ContextServerStore>,
285 server_id: ContextServerId,
286 tool: context_server::types::Tool,
287 ) -> Self {
288 Self {
289 store,
290 server_id,
291 tool,
292 }
293 }
294}
295
296impl AnyAgentTool for ContextServerTool {
297 fn name(&self) -> SharedString {
298 self.tool.name.clone().into()
299 }
300
301 fn description(&self) -> SharedString {
302 self.tool.description.clone().unwrap_or_default().into()
303 }
304
305 fn kind(&self) -> ToolKind {
306 ToolKind::Other
307 }
308
309 fn initial_title(&self, _input: serde_json::Value, _cx: &mut App) -> SharedString {
310 format!("Run MCP tool `{}`", self.tool.name).into()
311 }
312
313 fn input_schema(
314 &self,
315 format: language_model::LanguageModelToolSchemaFormat,
316 ) -> Result<serde_json::Value> {
317 let mut schema = self.tool.input_schema.clone();
318 language_model::tool_schema::adapt_schema_to_format(&mut schema, format)?;
319 Ok(match schema {
320 serde_json::Value::Null => {
321 serde_json::json!({ "type": "object", "properties": [] })
322 }
323 serde_json::Value::Object(map) if map.is_empty() => {
324 serde_json::json!({ "type": "object", "properties": [] })
325 }
326 _ => schema,
327 })
328 }
329
330 fn run(
331 self: Arc<Self>,
332 input: serde_json::Value,
333 event_stream: ToolCallEventStream,
334 cx: &mut App,
335 ) -> Task<Result<AgentToolOutput>> {
336 let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
337 return Task::ready(Err(anyhow!("Context server not found")));
338 };
339 let tool_name = self.tool.name.clone();
340 let tool_id = mcp_tool_id(&self.server_id.0, &self.tool.name);
341 let display_name = self.tool.name.clone();
342 let authorize = event_stream.authorize_third_party_tool(
343 self.initial_title(input.clone(), cx),
344 tool_id,
345 display_name,
346 cx,
347 );
348
349 cx.spawn(async move |_cx| {
350 authorize.await?;
351
352 let Some(protocol) = server.client() else {
353 anyhow::bail!("Context server not initialized");
354 };
355
356 let arguments = if let serde_json::Value::Object(map) = input {
357 Some(map.into_iter().collect())
358 } else {
359 None
360 };
361
362 log::trace!(
363 "Running tool: {} with arguments: {:?}",
364 tool_name,
365 arguments
366 );
367
368 let request = protocol.request::<context_server::types::requests::CallTool>(
369 context_server::types::CallToolParams {
370 name: tool_name,
371 arguments,
372 meta: None,
373 },
374 );
375
376 let response = futures::select! {
377 response = request.fuse() => response?,
378 _ = event_stream.cancelled_by_user().fuse() => {
379 anyhow::bail!("MCP tool cancelled by user");
380 }
381 };
382
383 if response.is_error == Some(true) {
384 let error_message: String =
385 response.content.iter().filter_map(|c| c.text()).collect();
386 anyhow::bail!(error_message);
387 }
388
389 let mut result = String::new();
390 for content in response.content {
391 match content {
392 context_server::types::ToolResponseContent::Text { text } => {
393 result.push_str(&text);
394 }
395 context_server::types::ToolResponseContent::Image { .. } => {
396 log::warn!("Ignoring image content from tool response");
397 }
398 context_server::types::ToolResponseContent::Audio { .. } => {
399 log::warn!("Ignoring audio content from tool response");
400 }
401 context_server::types::ToolResponseContent::Resource { .. } => {
402 log::warn!("Ignoring resource content from tool response");
403 }
404 }
405 }
406 Ok(AgentToolOutput {
407 raw_output: result.clone().into(),
408 llm_output: result.into(),
409 })
410 })
411 }
412
413 fn replay(
414 &self,
415 _input: serde_json::Value,
416 _output: serde_json::Value,
417 _event_stream: ToolCallEventStream,
418 _cx: &mut App,
419 ) -> Result<()> {
420 Ok(())
421 }
422}
423
424pub fn get_prompt(
425 server_store: &Entity<ContextServerStore>,
426 server_id: &ContextServerId,
427 prompt_name: &str,
428 arguments: HashMap<String, String>,
429 cx: &mut AsyncApp,
430) -> Task<Result<context_server::types::PromptsGetResponse>> {
431 let server = cx.update(|cx| server_store.read(cx).get_running_server(server_id));
432 let Some(server) = server else {
433 return Task::ready(Err(anyhow::anyhow!("Context server not found")));
434 };
435
436 let Some(protocol) = server.client() else {
437 return Task::ready(Err(anyhow::anyhow!("Context server not initialized")));
438 };
439
440 let prompt_name = prompt_name.to_string();
441
442 cx.background_spawn(async move {
443 let response = protocol
444 .request::<context_server::types::requests::PromptsGet>(
445 context_server::types::PromptsGetParams {
446 name: prompt_name,
447 arguments: (!arguments.is_empty()).then(|| arguments),
448 meta: None,
449 },
450 )
451 .await?;
452
453 Ok(response)
454 })
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    /// `mcp_tool_id` joins its parts as `mcp:<server_id>:<tool_name>`, keeping
    /// hyphens and underscores in either part untouched.
    #[test]
    fn test_mcp_tool_id_format() {
        let cases = [
            ("filesystem", "read_file", "mcp:filesystem:read_file"),
            ("github", "create_issue", "mcp:github:create_issue"),
            (
                "my-custom-server",
                "do_something",
                "mcp:my-custom-server:do_something",
            ),
            // Underscores in names
            ("my_server", "my_tool", "mcp:my_server:my_tool"),
        ];
        for (server_id, tool_name, expected) in cases {
            assert_eq!(mcp_tool_id(server_id, tool_name), expected);
        }
    }

    // Note: Tests for MCP tool ID collision with built-in tools and permission
    // decisions are in crates/agent/src/tool_permissions.rs to avoid duplication.
}