1use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream, ToolInput};
2use agent_client_protocol::ToolKind;
3use anyhow::Result;
4use collections::{BTreeMap, HashMap};
5use context_server::{ContextServerId, client::NotificationSubscription};
6use futures::FutureExt as _;
7use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task};
8use project::context_server_store::{ContextServerStatus, ContextServerStore};
9use std::sync::Arc;
10use util::ResultExt;
11
/// Builds the settings identifier for a tool exposed by an MCP context server.
///
/// Identifiers take the shape `mcp:<server_id>:<tool_name>`; the `mcp:` prefix keeps them
/// from colliding with the names of built-in tools.
pub fn mcp_tool_id(server_id: &str, tool_name: &str) -> String {
    ["mcp", server_id, tool_name].join(":")
}
18
/// A prompt advertised by a context server, paired with the id of the server that provides it.
pub struct ContextServerPrompt {
    /// The server that advertised this prompt.
    pub server_id: ContextServerId,
    /// The prompt definition as returned by the server's prompt listing.
    pub prompt: context_server::types::Prompt,
}
23
/// Events emitted by [`ContextServerRegistry`] when the tools or prompts it tracks change.
pub enum ContextServerRegistryEvent {
    /// The tools registered for a server were reloaded or removed.
    ToolsChanged,
    /// The prompts registered for a server were reloaded or removed.
    PromptsChanged,
}

impl EventEmitter<ContextServerRegistryEvent> for ContextServerRegistry {}
30
/// Tracks the tools and prompts exposed by the running servers of a [`ContextServerStore`]
/// and emits [`ContextServerRegistryEvent`]s when they change.
pub struct ContextServerRegistry {
    /// Store whose server status changes drive this registry.
    server_store: Entity<ContextServerStore>,
    /// Loaded tools and prompts per server. Entries are created on demand when a server's
    /// tools or prompts are loaded and are removed when the server stops or errors.
    registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
    /// Subscription to `server_store` status events, held so it lives as long as the registry.
    _subscription: gpui::Subscription,
}
36
/// Per-server state: the tools and prompts most recently loaded from one context server.
struct RegisteredContextServer {
    /// Tools keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Prompts keyed by prompt name.
    prompts: BTreeMap<SharedString, ContextServerPrompt>,
    /// Task for the latest tool-list request. Assigning a new task drops the previous one,
    /// and a dropped gpui `Task` is cancelled.
    load_tools: Task<Result<()>>,
    /// Task for the latest prompt-list request. Assigning a new task drops the previous one,
    /// and a dropped gpui `Task` is cancelled.
    load_prompts: Task<Result<()>>,
    /// Subscription to the server's `notifications/tools/list_changed` notification.
    /// `None` when, at registration time, the server was not running, had no client, or did
    /// not advertise the tools capability.
    _tools_updated_subscription: Option<NotificationSubscription>,
}
44
45impl ContextServerRegistry {
46 pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
47 let mut this = Self {
48 server_store: server_store.clone(),
49 registered_servers: HashMap::default(),
50 _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
51 };
52 for server in server_store.read(cx).running_servers() {
53 this.reload_tools_for_server(server.id(), cx);
54 this.reload_prompts_for_server(server.id(), cx);
55 }
56 this
57 }
58
59 pub fn tools_for_server(
60 &self,
61 server_id: &ContextServerId,
62 ) -> impl Iterator<Item = &Arc<dyn AnyAgentTool>> {
63 self.registered_servers
64 .get(server_id)
65 .map(|server| server.tools.values())
66 .into_iter()
67 .flatten()
68 }
69
70 pub fn servers(
71 &self,
72 ) -> impl Iterator<
73 Item = (
74 &ContextServerId,
75 &BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
76 ),
77 > {
78 self.registered_servers
79 .iter()
80 .map(|(id, server)| (id, &server.tools))
81 }
82
83 pub fn prompts(&self) -> impl Iterator<Item = &ContextServerPrompt> {
84 self.registered_servers
85 .values()
86 .flat_map(|server| server.prompts.values())
87 }
88
89 pub fn find_prompt(
90 &self,
91 server_id: Option<&ContextServerId>,
92 name: &str,
93 ) -> Option<&ContextServerPrompt> {
94 if let Some(server_id) = server_id {
95 self.registered_servers
96 .get(server_id)
97 .and_then(|server| server.prompts.get(name))
98 } else {
99 self.registered_servers
100 .values()
101 .find_map(|server| server.prompts.get(name))
102 }
103 }
104
105 pub fn server_store(&self) -> &Entity<ContextServerStore> {
106 &self.server_store
107 }
108
109 fn get_or_register_server(
110 &mut self,
111 server_id: &ContextServerId,
112 cx: &mut Context<Self>,
113 ) -> &mut RegisteredContextServer {
114 self.registered_servers
115 .entry(server_id.clone())
116 .or_insert_with(|| Self::init_registered_server(server_id, &self.server_store, cx))
117 }
118
119 fn init_registered_server(
120 server_id: &ContextServerId,
121 server_store: &Entity<ContextServerStore>,
122 cx: &mut Context<Self>,
123 ) -> RegisteredContextServer {
124 let tools_updated_subscription = server_store
125 .read(cx)
126 .get_running_server(server_id)
127 .and_then(|server| {
128 let client = server.client()?;
129
130 if !client.capable(context_server::protocol::ServerCapability::Tools) {
131 return None;
132 }
133
134 let server_id = server.id();
135 let this = cx.entity().downgrade();
136
137 Some(client.on_notification(
138 "notifications/tools/list_changed",
139 Box::new(move |_params, cx: AsyncApp| {
140 let server_id = server_id.clone();
141 let this = this.clone();
142 cx.spawn(async move |cx| {
143 this.update(cx, |this, cx| {
144 log::info!(
145 "Received tools/list_changed notification for server {}",
146 server_id
147 );
148 this.reload_tools_for_server(server_id, cx);
149 })
150 })
151 .detach();
152 }),
153 ))
154 });
155
156 RegisteredContextServer {
157 tools: BTreeMap::default(),
158 prompts: BTreeMap::default(),
159 load_tools: Task::ready(Ok(())),
160 load_prompts: Task::ready(Ok(())),
161 _tools_updated_subscription: tools_updated_subscription,
162 }
163 }
164
165 fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
166 let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
167 return;
168 };
169 let Some(client) = server.client() else {
170 return;
171 };
172
173 if !client.capable(context_server::protocol::ServerCapability::Tools) {
174 return;
175 }
176
177 let registered_server = self.get_or_register_server(&server_id, cx);
178 registered_server.load_tools = cx.spawn(async move |this, cx| {
179 let response = client
180 .request::<context_server::types::requests::ListTools>(())
181 .await;
182
183 this.update(cx, |this, cx| {
184 let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
185 return;
186 };
187
188 registered_server.tools.clear();
189 if let Some(response) = response.log_err() {
190 for tool in response.tools {
191 let tool = Arc::new(ContextServerTool::new(
192 this.server_store.clone(),
193 server.id(),
194 tool,
195 ));
196 registered_server.tools.insert(tool.name(), tool);
197 }
198 cx.emit(ContextServerRegistryEvent::ToolsChanged);
199 cx.notify();
200 }
201 })
202 });
203 }
204
205 fn reload_prompts_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
206 let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
207 return;
208 };
209 let Some(client) = server.client() else {
210 return;
211 };
212 if !client.capable(context_server::protocol::ServerCapability::Prompts) {
213 return;
214 }
215
216 let registered_server = self.get_or_register_server(&server_id, cx);
217
218 registered_server.load_prompts = cx.spawn(async move |this, cx| {
219 let response = client
220 .request::<context_server::types::requests::PromptsList>(())
221 .await;
222
223 this.update(cx, |this, cx| {
224 let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
225 return;
226 };
227
228 registered_server.prompts.clear();
229 if let Some(response) = response.log_err() {
230 for prompt in response.prompts {
231 let name: SharedString = prompt.name.clone().into();
232 registered_server.prompts.insert(
233 name,
234 ContextServerPrompt {
235 server_id: server_id.clone(),
236 prompt,
237 },
238 );
239 }
240 cx.emit(ContextServerRegistryEvent::PromptsChanged);
241 cx.notify();
242 }
243 })
244 });
245 }
246
247 fn handle_context_server_store_event(
248 &mut self,
249 _: Entity<ContextServerStore>,
250 event: &project::context_server_store::ServerStatusChangedEvent,
251 cx: &mut Context<Self>,
252 ) {
253 let project::context_server_store::ServerStatusChangedEvent { server_id, status } = event;
254
255 match status {
256 ContextServerStatus::Starting => {}
257 ContextServerStatus::Running => {
258 self.reload_tools_for_server(server_id.clone(), cx);
259 self.reload_prompts_for_server(server_id.clone(), cx);
260 }
261 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
262 if let Some(registered_server) = self.registered_servers.remove(server_id) {
263 if !registered_server.tools.is_empty() {
264 cx.emit(ContextServerRegistryEvent::ToolsChanged);
265 }
266 if !registered_server.prompts.is_empty() {
267 cx.emit(ContextServerRegistryEvent::PromptsChanged);
268 }
269 }
270 cx.notify();
271 }
272 };
273 }
274}
275
/// An [`AnyAgentTool`] backed by a single tool advertised by an MCP context server.
struct ContextServerTool {
    /// Store used to look up the running server when the tool is invoked.
    store: Entity<ContextServerStore>,
    /// Id of the server that advertised `tool`.
    server_id: ContextServerId,
    /// The tool definition (name, description, input schema) reported by the server.
    tool: context_server::types::Tool,
}
281
282impl ContextServerTool {
283 fn new(
284 store: Entity<ContextServerStore>,
285 server_id: ContextServerId,
286 tool: context_server::types::Tool,
287 ) -> Self {
288 Self {
289 store,
290 server_id,
291 tool,
292 }
293 }
294}
295
296impl AnyAgentTool for ContextServerTool {
297 fn name(&self) -> SharedString {
298 self.tool.name.clone().into()
299 }
300
301 fn description(&self) -> SharedString {
302 self.tool.description.clone().unwrap_or_default().into()
303 }
304
305 fn kind(&self) -> ToolKind {
306 ToolKind::Other
307 }
308
309 fn initial_title(&self, _input: serde_json::Value, _cx: &mut App) -> SharedString {
310 format!("Run MCP tool `{}`", self.tool.name).into()
311 }
312
313 fn input_schema(
314 &self,
315 format: language_model::LanguageModelToolSchemaFormat,
316 ) -> Result<serde_json::Value> {
317 let mut schema = self.tool.input_schema.clone();
318 language_model::tool_schema::adapt_schema_to_format(&mut schema, format)?;
319 Ok(match schema {
320 serde_json::Value::Null => {
321 serde_json::json!({ "type": "object", "properties": [] })
322 }
323 serde_json::Value::Object(map) if map.is_empty() => {
324 serde_json::json!({ "type": "object", "properties": [] })
325 }
326 _ => schema,
327 })
328 }
329
330 fn run(
331 self: Arc<Self>,
332 input: ToolInput<serde_json::Value>,
333 event_stream: ToolCallEventStream,
334 cx: &mut App,
335 ) -> Task<Result<AgentToolOutput, AgentToolOutput>> {
336 let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
337 return Task::ready(Err(AgentToolOutput::from_error("Context server not found")));
338 };
339 let tool_name = self.tool.name.clone();
340 let tool_id = mcp_tool_id(&self.server_id.0, &self.tool.name);
341 let display_name = self.tool.name.clone();
342 let initial_title = self.initial_title(serde_json::Value::Null, cx);
343 let authorize =
344 event_stream.authorize_third_party_tool(initial_title, tool_id, display_name, cx);
345
346 cx.spawn(async move |_cx| {
347 let input = input.recv().await.map_err(|e| {
348 AgentToolOutput::from_error(format!("Failed to receive tool input: {e}"))
349 })?;
350
351 authorize.await.map_err(|e| AgentToolOutput::from_error(e.to_string()))?;
352
353 let Some(protocol) = server.client() else {
354 return Err(AgentToolOutput::from_error("Context server not initialized"));
355 };
356
357 let arguments = if let serde_json::Value::Object(map) = input {
358 Some(map.into_iter().collect())
359 } else {
360 None
361 };
362
363 log::trace!(
364 "Running tool: {} with arguments: {:?}",
365 tool_name,
366 arguments
367 );
368
369 let request = protocol.request::<context_server::types::requests::CallTool>(
370 context_server::types::CallToolParams {
371 name: tool_name,
372 arguments,
373 meta: None,
374 },
375 );
376
377 let response = futures::select! {
378 response = request.fuse() => response.map_err(|e| AgentToolOutput::from_error(e.to_string()))?,
379 _ = event_stream.cancelled_by_user().fuse() => {
380 return Err(AgentToolOutput::from_error("MCP tool cancelled by user"));
381 }
382 };
383
384 if response.is_error == Some(true) {
385 let error_message: String =
386 response.content.iter().filter_map(|c| c.text()).collect();
387 return Err(AgentToolOutput::from_error(error_message));
388 }
389
390 let mut result = String::new();
391 for content in response.content {
392 match content {
393 context_server::types::ToolResponseContent::Text { text } => {
394 result.push_str(&text);
395 }
396 context_server::types::ToolResponseContent::Image { .. } => {
397 log::warn!("Ignoring image content from tool response");
398 }
399 context_server::types::ToolResponseContent::Audio { .. } => {
400 log::warn!("Ignoring audio content from tool response");
401 }
402 context_server::types::ToolResponseContent::Resource { .. } => {
403 log::warn!("Ignoring resource content from tool response");
404 }
405 }
406 }
407 Ok(AgentToolOutput {
408 raw_output: result.clone().into(),
409 llm_output: result.into(),
410 })
411 })
412 }
413
414 fn replay(
415 &self,
416 _input: serde_json::Value,
417 _output: serde_json::Value,
418 _event_stream: ToolCallEventStream,
419 _cx: &mut App,
420 ) -> Result<()> {
421 Ok(())
422 }
423}
424
425pub fn get_prompt(
426 server_store: &Entity<ContextServerStore>,
427 server_id: &ContextServerId,
428 prompt_name: &str,
429 arguments: HashMap<String, String>,
430 cx: &mut AsyncApp,
431) -> Task<Result<context_server::types::PromptsGetResponse>> {
432 let server = cx.update(|cx| server_store.read(cx).get_running_server(server_id));
433 let Some(server) = server else {
434 return Task::ready(Err(anyhow::anyhow!("Context server not found")));
435 };
436
437 let Some(protocol) = server.client() else {
438 return Task::ready(Err(anyhow::anyhow!("Context server not initialized")));
439 };
440
441 let prompt_name = prompt_name.to_string();
442
443 cx.background_spawn(async move {
444 let response = protocol
445 .request::<context_server::types::requests::PromptsGet>(
446 context_server::types::PromptsGetParams {
447 name: prompt_name,
448 arguments: (!arguments.is_empty()).then(|| arguments),
449 meta: None,
450 },
451 )
452 .await?;
453
454 Ok(response)
455 })
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mcp_tool_id_format() {
        // (server_id, tool_name, expected id)
        let cases = [
            ("filesystem", "read_file", "mcp:filesystem:read_file"),
            ("github", "create_issue", "mcp:github:create_issue"),
            (
                "my-custom-server",
                "do_something",
                "mcp:my-custom-server:do_something",
            ),
            // Underscores in names
            ("my_server", "my_tool", "mcp:my_server:my_tool"),
        ];
        for (server_id, tool_name, expected) in cases {
            assert_eq!(mcp_tool_id(server_id, tool_name), expected);
        }
    }

    // Note: Tests for MCP tool ID collision with built-in tools and permission
    // decisions are in crates/agent/src/tool_permissions.rs to avoid duplication.
}