1use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
2use agent_client_protocol::ToolKind;
3use anyhow::{Result, anyhow};
4use collections::{BTreeMap, HashMap};
5use context_server::{ContextServerId, client::NotificationSubscription};
6use futures::FutureExt as _;
7use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task};
8use project::context_server_store::{ContextServerStatus, ContextServerStore};
9use std::sync::Arc;
10use util::ResultExt;
11
/// Builds the settings-facing identifier for a tool exposed by an MCP server.
///
/// The identifier has the shape `mcp:<server_id>:<tool_name>`; the `mcp:` prefix keeps
/// MCP tools from colliding with built-in tool names. Neither component is escaped, so a
/// `:` inside `server_id` or `tool_name` can make two different pairs produce the same ID.
pub fn mcp_tool_id(server_id: &str, tool_name: &str) -> String {
    format!("mcp:{server_id}:{tool_name}")
}
18
19pub struct ContextServerPrompt {
20 pub server_id: ContextServerId,
21 pub prompt: context_server::types::Prompt,
22}
23
24pub enum ContextServerRegistryEvent {
25 ToolsChanged,
26 PromptsChanged,
27}
28
29impl EventEmitter<ContextServerRegistryEvent> for ContextServerRegistry {}
30
31pub struct ContextServerRegistry {
32 server_store: Entity<ContextServerStore>,
33 registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
34 _subscription: gpui::Subscription,
35}
36
37struct RegisteredContextServer {
38 tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
39 prompts: BTreeMap<SharedString, ContextServerPrompt>,
40 load_tools: Task<Result<()>>,
41 load_prompts: Task<Result<()>>,
42 _tools_updated_subscription: Option<NotificationSubscription>,
43}
44
45impl ContextServerRegistry {
46 pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
47 let mut this = Self {
48 server_store: server_store.clone(),
49 registered_servers: HashMap::default(),
50 _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
51 };
52 for server in server_store.read(cx).running_servers() {
53 this.reload_tools_for_server(server.id(), cx);
54 this.reload_prompts_for_server(server.id(), cx);
55 }
56 this
57 }
58
59 pub fn tools_for_server(
60 &self,
61 server_id: &ContextServerId,
62 ) -> impl Iterator<Item = &Arc<dyn AnyAgentTool>> {
63 self.registered_servers
64 .get(server_id)
65 .map(|server| server.tools.values())
66 .into_iter()
67 .flatten()
68 }
69
70 pub fn servers(
71 &self,
72 ) -> impl Iterator<
73 Item = (
74 &ContextServerId,
75 &BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
76 ),
77 > {
78 self.registered_servers
79 .iter()
80 .map(|(id, server)| (id, &server.tools))
81 }
82
83 pub fn prompts(&self) -> impl Iterator<Item = &ContextServerPrompt> {
84 self.registered_servers
85 .values()
86 .flat_map(|server| server.prompts.values())
87 }
88
89 pub fn find_prompt(
90 &self,
91 server_id: Option<&ContextServerId>,
92 name: &str,
93 ) -> Option<&ContextServerPrompt> {
94 if let Some(server_id) = server_id {
95 self.registered_servers
96 .get(server_id)
97 .and_then(|server| server.prompts.get(name))
98 } else {
99 self.registered_servers
100 .values()
101 .find_map(|server| server.prompts.get(name))
102 }
103 }
104
105 pub fn server_store(&self) -> &Entity<ContextServerStore> {
106 &self.server_store
107 }
108
109 fn get_or_register_server(
110 &mut self,
111 server_id: &ContextServerId,
112 cx: &mut Context<Self>,
113 ) -> &mut RegisteredContextServer {
114 self.registered_servers
115 .entry(server_id.clone())
116 .or_insert_with(|| Self::init_registered_server(server_id, &self.server_store, cx))
117 }
118
119 fn init_registered_server(
120 server_id: &ContextServerId,
121 server_store: &Entity<ContextServerStore>,
122 cx: &mut Context<Self>,
123 ) -> RegisteredContextServer {
124 let tools_updated_subscription = server_store
125 .read(cx)
126 .get_running_server(server_id)
127 .and_then(|server| {
128 let client = server.client()?;
129
130 if !client.capable(context_server::protocol::ServerCapability::Tools) {
131 return None;
132 }
133
134 let server_id = server.id();
135 let this = cx.entity().downgrade();
136
137 Some(client.on_notification(
138 "notifications/tools/list_changed",
139 Box::new(move |_params, cx: AsyncApp| {
140 let server_id = server_id.clone();
141 let this = this.clone();
142 cx.spawn(async move |cx| {
143 this.update(cx, |this, cx| {
144 log::info!(
145 "Received tools/list_changed notification for server {}",
146 server_id
147 );
148 this.reload_tools_for_server(server_id, cx);
149 })
150 })
151 .detach();
152 }),
153 ))
154 });
155
156 RegisteredContextServer {
157 tools: BTreeMap::default(),
158 prompts: BTreeMap::default(),
159 load_tools: Task::ready(Ok(())),
160 load_prompts: Task::ready(Ok(())),
161 _tools_updated_subscription: tools_updated_subscription,
162 }
163 }
164
165 fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
166 let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
167 return;
168 };
169 let Some(client) = server.client() else {
170 return;
171 };
172
173 if !client.capable(context_server::protocol::ServerCapability::Tools) {
174 return;
175 }
176
177 let registered_server = self.get_or_register_server(&server_id, cx);
178 registered_server.load_tools = cx.spawn(async move |this, cx| {
179 let response = client
180 .request::<context_server::types::requests::ListTools>(())
181 .await;
182
183 this.update(cx, |this, cx| {
184 let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
185 return;
186 };
187
188 registered_server.tools.clear();
189 if let Some(response) = response.log_err() {
190 for tool in response.tools {
191 let tool = Arc::new(ContextServerTool::new(
192 this.server_store.clone(),
193 server.id(),
194 tool,
195 ));
196 registered_server.tools.insert(tool.name(), tool);
197 }
198 cx.emit(ContextServerRegistryEvent::ToolsChanged);
199 cx.notify();
200 }
201 })
202 });
203 }
204
205 fn reload_prompts_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
206 let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
207 return;
208 };
209 let Some(client) = server.client() else {
210 return;
211 };
212 if !client.capable(context_server::protocol::ServerCapability::Prompts) {
213 return;
214 }
215
216 let registered_server = self.get_or_register_server(&server_id, cx);
217
218 registered_server.load_prompts = cx.spawn(async move |this, cx| {
219 let response = client
220 .request::<context_server::types::requests::PromptsList>(())
221 .await;
222
223 this.update(cx, |this, cx| {
224 let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
225 return;
226 };
227
228 registered_server.prompts.clear();
229 if let Some(response) = response.log_err() {
230 for prompt in response.prompts {
231 let name: SharedString = prompt.name.clone().into();
232 registered_server.prompts.insert(
233 name,
234 ContextServerPrompt {
235 server_id: server_id.clone(),
236 prompt,
237 },
238 );
239 }
240 cx.emit(ContextServerRegistryEvent::PromptsChanged);
241 cx.notify();
242 }
243 })
244 });
245 }
246
247 fn handle_context_server_store_event(
248 &mut self,
249 _: Entity<ContextServerStore>,
250 event: &project::context_server_store::Event,
251 cx: &mut Context<Self>,
252 ) {
253 match event {
254 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
255 match status {
256 ContextServerStatus::Starting => {}
257 ContextServerStatus::Running => {
258 self.reload_tools_for_server(server_id.clone(), cx);
259 self.reload_prompts_for_server(server_id.clone(), cx);
260 }
261 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
262 if let Some(registered_server) = self.registered_servers.remove(server_id) {
263 if !registered_server.tools.is_empty() {
264 cx.emit(ContextServerRegistryEvent::ToolsChanged);
265 }
266 if !registered_server.prompts.is_empty() {
267 cx.emit(ContextServerRegistryEvent::PromptsChanged);
268 }
269 }
270 cx.notify();
271 }
272 }
273 }
274 }
275 }
276}
277
278struct ContextServerTool {
279 store: Entity<ContextServerStore>,
280 server_id: ContextServerId,
281 tool: context_server::types::Tool,
282}
283
284impl ContextServerTool {
285 fn new(
286 store: Entity<ContextServerStore>,
287 server_id: ContextServerId,
288 tool: context_server::types::Tool,
289 ) -> Self {
290 Self {
291 store,
292 server_id,
293 tool,
294 }
295 }
296}
297
298impl AnyAgentTool for ContextServerTool {
299 fn name(&self) -> SharedString {
300 self.tool.name.clone().into()
301 }
302
303 fn description(&self) -> SharedString {
304 self.tool.description.clone().unwrap_or_default().into()
305 }
306
307 fn kind(&self) -> ToolKind {
308 ToolKind::Other
309 }
310
311 fn initial_title(&self, _input: serde_json::Value, _cx: &mut App) -> SharedString {
312 format!("Run MCP tool `{}`", self.tool.name).into()
313 }
314
315 fn input_schema(
316 &self,
317 format: language_model::LanguageModelToolSchemaFormat,
318 ) -> Result<serde_json::Value> {
319 let mut schema = self.tool.input_schema.clone();
320 language_model::tool_schema::adapt_schema_to_format(&mut schema, format)?;
321 Ok(match schema {
322 serde_json::Value::Null => {
323 serde_json::json!({ "type": "object", "properties": [] })
324 }
325 serde_json::Value::Object(map) if map.is_empty() => {
326 serde_json::json!({ "type": "object", "properties": [] })
327 }
328 _ => schema,
329 })
330 }
331
332 fn run(
333 self: Arc<Self>,
334 input: serde_json::Value,
335 event_stream: ToolCallEventStream,
336 cx: &mut App,
337 ) -> Task<Result<AgentToolOutput>> {
338 let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
339 return Task::ready(Err(anyhow!("Context server not found")));
340 };
341 let tool_name = self.tool.name.clone();
342 let tool_id = mcp_tool_id(&self.server_id.0, &self.tool.name);
343 let display_name = self.tool.name.clone();
344 let authorize = event_stream.authorize_third_party_tool(
345 self.initial_title(input.clone(), cx),
346 tool_id,
347 display_name,
348 cx,
349 );
350
351 cx.spawn(async move |_cx| {
352 authorize.await?;
353
354 let Some(protocol) = server.client() else {
355 anyhow::bail!("Context server not initialized");
356 };
357
358 let arguments = if let serde_json::Value::Object(map) = input {
359 Some(map.into_iter().collect())
360 } else {
361 None
362 };
363
364 log::trace!(
365 "Running tool: {} with arguments: {:?}",
366 tool_name,
367 arguments
368 );
369
370 let request = protocol.request::<context_server::types::requests::CallTool>(
371 context_server::types::CallToolParams {
372 name: tool_name,
373 arguments,
374 meta: None,
375 },
376 );
377
378 let response = futures::select! {
379 response = request.fuse() => response?,
380 _ = event_stream.cancelled_by_user().fuse() => {
381 anyhow::bail!("MCP tool cancelled by user");
382 }
383 };
384
385 let mut result = String::new();
386 for content in response.content {
387 match content {
388 context_server::types::ToolResponseContent::Text { text } => {
389 result.push_str(&text);
390 }
391 context_server::types::ToolResponseContent::Image { .. } => {
392 log::warn!("Ignoring image content from tool response");
393 }
394 context_server::types::ToolResponseContent::Audio { .. } => {
395 log::warn!("Ignoring audio content from tool response");
396 }
397 context_server::types::ToolResponseContent::Resource { .. } => {
398 log::warn!("Ignoring resource content from tool response");
399 }
400 }
401 }
402 Ok(AgentToolOutput {
403 raw_output: result.clone().into(),
404 llm_output: result.into(),
405 })
406 })
407 }
408
409 fn replay(
410 &self,
411 _input: serde_json::Value,
412 _output: serde_json::Value,
413 _event_stream: ToolCallEventStream,
414 _cx: &mut App,
415 ) -> Result<()> {
416 Ok(())
417 }
418}
419
420pub fn get_prompt(
421 server_store: &Entity<ContextServerStore>,
422 server_id: &ContextServerId,
423 prompt_name: &str,
424 arguments: HashMap<String, String>,
425 cx: &mut AsyncApp,
426) -> Task<Result<context_server::types::PromptsGetResponse>> {
427 let server = cx.update(|cx| server_store.read(cx).get_running_server(server_id));
428 let Some(server) = server else {
429 return Task::ready(Err(anyhow::anyhow!("Context server not found")));
430 };
431
432 let Some(protocol) = server.client() else {
433 return Task::ready(Err(anyhow::anyhow!("Context server not initialized")));
434 };
435
436 let prompt_name = prompt_name.to_string();
437
438 cx.background_spawn(async move {
439 let response = protocol
440 .request::<context_server::types::requests::PromptsGet>(
441 context_server::types::PromptsGetParams {
442 name: prompt_name,
443 arguments: (!arguments.is_empty()).then(|| arguments),
444 meta: None,
445 },
446 )
447 .await?;
448
449 Ok(response)
450 })
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    /// `mcp_tool_id` joins the `mcp` prefix, server id and tool name with `:`.
    #[test]
    fn test_mcp_tool_id_format() {
        let cases = [
            ("filesystem", "read_file", "mcp:filesystem:read_file"),
            ("github", "create_issue", "mcp:github:create_issue"),
            (
                "my-custom-server",
                "do_something",
                "mcp:my-custom-server:do_something",
            ),
            // Underscores in names
            ("my_server", "my_tool", "mcp:my_server:my_tool"),
        ];

        for (server_id, tool_name, expected) in cases {
            assert_eq!(
                mcp_tool_id(server_id, tool_name),
                expected,
                "server_id={server_id:?} tool_name={tool_name:?}"
            );
        }
    }

    // Note: Tests for MCP tool ID collision with built-in tools and permission
    // decisions are in crates/agent/src/tool_permissions.rs to avoid duplication.
}