context_server_registry.rs

  1use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
  2use agent_client_protocol::ToolKind;
  3use anyhow::{Result, anyhow};
  4use collections::{BTreeMap, HashMap};
  5use context_server::{ContextServerId, client::NotificationSubscription};
  6use futures::FutureExt as _;
  7use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task};
  8use project::context_server_store::{ContextServerStatus, ContextServerStore};
  9use std::sync::Arc;
 10use util::ResultExt;
 11
 12/// Generates a tool ID for an MCP tool that can be used in settings.
 13///
 14/// The format is `mcp:<server_id>:<tool_name>` to avoid collisions with built-in tools.
 15pub fn mcp_tool_id(server_id: &str, tool_name: &str) -> String {
 16    format!("mcp:{}:{}", server_id, tool_name)
 17}
 18
/// A prompt advertised by a context (MCP) server, tagged with the server that provides it.
pub struct ContextServerPrompt {
    /// The server whose prompt list contained this prompt.
    pub server_id: ContextServerId,
    /// The prompt definition exactly as returned in the server's prompt-list response.
    pub prompt: context_server::types::Prompt,
}
 23
/// Events emitted by `ContextServerRegistry` when its cached view of server tools or prompts changes.
pub enum ContextServerRegistryEvent {
    /// The registered tool set changed: a server's tools were reloaded, or a server that had
    /// tools stopped or errored.
    ToolsChanged,
    /// The registered prompt set changed: a server's prompts were reloaded, or a server that had
    /// prompts stopped or errored.
    PromptsChanged,
}
 28
// Lets the registry `cx.emit` `ContextServerRegistryEvent`s to its subscribers.
impl EventEmitter<ContextServerRegistryEvent> for ContextServerRegistry {}
 30
/// Caches the tools and prompts exposed by the running servers of a `ContextServerStore`,
/// reloading them when servers start and dropping them when servers stop or error.
pub struct ContextServerRegistry {
    /// Store that owns the context servers; queried for running servers and their clients.
    server_store: Entity<ContextServerStore>,
    /// Cached tools/prompts per server. An entry is removed when its server stops or errors.
    registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
    /// Subscription created in `new` that routes `server_store` events to
    /// `handle_context_server_store_event`; held for the registry's lifetime.
    _subscription: gpui::Subscription,
}
 36
/// Cached state for a single context server known to the registry.
struct RegisteredContextServer {
    /// Tools from the server's `ListTools` response, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Prompts from the server's `PromptsList` response, keyed by prompt name.
    prompts: BTreeMap<SharedString, ContextServerPrompt>,
    /// Most recent tool-list reload; each reload replaces this task with a new one.
    load_tools: Task<Result<()>>,
    /// Most recent prompt-list reload; each reload replaces this task with a new one.
    load_prompts: Task<Result<()>>,
    /// Handler for `notifications/tools/list_changed`; `None` when the server was not running,
    /// had no client, or did not advertise the `Tools` capability at registration time.
    _tools_updated_subscription: Option<NotificationSubscription>,
}
 44
 45impl ContextServerRegistry {
 46    pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
 47        let mut this = Self {
 48            server_store: server_store.clone(),
 49            registered_servers: HashMap::default(),
 50            _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
 51        };
 52        for server in server_store.read(cx).running_servers() {
 53            this.reload_tools_for_server(server.id(), cx);
 54            this.reload_prompts_for_server(server.id(), cx);
 55        }
 56        this
 57    }
 58
 59    pub fn tools_for_server(
 60        &self,
 61        server_id: &ContextServerId,
 62    ) -> impl Iterator<Item = &Arc<dyn AnyAgentTool>> {
 63        self.registered_servers
 64            .get(server_id)
 65            .map(|server| server.tools.values())
 66            .into_iter()
 67            .flatten()
 68    }
 69
 70    pub fn servers(
 71        &self,
 72    ) -> impl Iterator<
 73        Item = (
 74            &ContextServerId,
 75            &BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 76        ),
 77    > {
 78        self.registered_servers
 79            .iter()
 80            .map(|(id, server)| (id, &server.tools))
 81    }
 82
 83    pub fn prompts(&self) -> impl Iterator<Item = &ContextServerPrompt> {
 84        self.registered_servers
 85            .values()
 86            .flat_map(|server| server.prompts.values())
 87    }
 88
 89    pub fn find_prompt(
 90        &self,
 91        server_id: Option<&ContextServerId>,
 92        name: &str,
 93    ) -> Option<&ContextServerPrompt> {
 94        if let Some(server_id) = server_id {
 95            self.registered_servers
 96                .get(server_id)
 97                .and_then(|server| server.prompts.get(name))
 98        } else {
 99            self.registered_servers
100                .values()
101                .find_map(|server| server.prompts.get(name))
102        }
103    }
104
105    pub fn server_store(&self) -> &Entity<ContextServerStore> {
106        &self.server_store
107    }
108
109    fn get_or_register_server(
110        &mut self,
111        server_id: &ContextServerId,
112        cx: &mut Context<Self>,
113    ) -> &mut RegisteredContextServer {
114        self.registered_servers
115            .entry(server_id.clone())
116            .or_insert_with(|| Self::init_registered_server(server_id, &self.server_store, cx))
117    }
118
119    fn init_registered_server(
120        server_id: &ContextServerId,
121        server_store: &Entity<ContextServerStore>,
122        cx: &mut Context<Self>,
123    ) -> RegisteredContextServer {
124        let tools_updated_subscription = server_store
125            .read(cx)
126            .get_running_server(server_id)
127            .and_then(|server| {
128                let client = server.client()?;
129
130                if !client.capable(context_server::protocol::ServerCapability::Tools) {
131                    return None;
132                }
133
134                let server_id = server.id();
135                let this = cx.entity().downgrade();
136
137                Some(client.on_notification(
138                    "notifications/tools/list_changed",
139                    Box::new(move |_params, cx: AsyncApp| {
140                        let server_id = server_id.clone();
141                        let this = this.clone();
142                        cx.spawn(async move |cx| {
143                            this.update(cx, |this, cx| {
144                                log::info!(
145                                    "Received tools/list_changed notification for server {}",
146                                    server_id
147                                );
148                                this.reload_tools_for_server(server_id, cx);
149                            })
150                        })
151                        .detach();
152                    }),
153                ))
154            });
155
156        RegisteredContextServer {
157            tools: BTreeMap::default(),
158            prompts: BTreeMap::default(),
159            load_tools: Task::ready(Ok(())),
160            load_prompts: Task::ready(Ok(())),
161            _tools_updated_subscription: tools_updated_subscription,
162        }
163    }
164
165    fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
166        let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
167            return;
168        };
169        let Some(client) = server.client() else {
170            return;
171        };
172
173        if !client.capable(context_server::protocol::ServerCapability::Tools) {
174            return;
175        }
176
177        let registered_server = self.get_or_register_server(&server_id, cx);
178        registered_server.load_tools = cx.spawn(async move |this, cx| {
179            let response = client
180                .request::<context_server::types::requests::ListTools>(())
181                .await;
182
183            this.update(cx, |this, cx| {
184                let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
185                    return;
186                };
187
188                registered_server.tools.clear();
189                if let Some(response) = response.log_err() {
190                    for tool in response.tools {
191                        let tool = Arc::new(ContextServerTool::new(
192                            this.server_store.clone(),
193                            server.id(),
194                            tool,
195                        ));
196                        registered_server.tools.insert(tool.name(), tool);
197                    }
198                    cx.emit(ContextServerRegistryEvent::ToolsChanged);
199                    cx.notify();
200                }
201            })
202        });
203    }
204
205    fn reload_prompts_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
206        let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
207            return;
208        };
209        let Some(client) = server.client() else {
210            return;
211        };
212        if !client.capable(context_server::protocol::ServerCapability::Prompts) {
213            return;
214        }
215
216        let registered_server = self.get_or_register_server(&server_id, cx);
217
218        registered_server.load_prompts = cx.spawn(async move |this, cx| {
219            let response = client
220                .request::<context_server::types::requests::PromptsList>(())
221                .await;
222
223            this.update(cx, |this, cx| {
224                let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
225                    return;
226                };
227
228                registered_server.prompts.clear();
229                if let Some(response) = response.log_err() {
230                    for prompt in response.prompts {
231                        let name: SharedString = prompt.name.clone().into();
232                        registered_server.prompts.insert(
233                            name,
234                            ContextServerPrompt {
235                                server_id: server_id.clone(),
236                                prompt,
237                            },
238                        );
239                    }
240                    cx.emit(ContextServerRegistryEvent::PromptsChanged);
241                    cx.notify();
242                }
243            })
244        });
245    }
246
247    fn handle_context_server_store_event(
248        &mut self,
249        _: Entity<ContextServerStore>,
250        event: &project::context_server_store::Event,
251        cx: &mut Context<Self>,
252    ) {
253        match event {
254            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
255                match status {
256                    ContextServerStatus::Starting => {}
257                    ContextServerStatus::Running => {
258                        self.reload_tools_for_server(server_id.clone(), cx);
259                        self.reload_prompts_for_server(server_id.clone(), cx);
260                    }
261                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
262                        if let Some(registered_server) = self.registered_servers.remove(server_id) {
263                            if !registered_server.tools.is_empty() {
264                                cx.emit(ContextServerRegistryEvent::ToolsChanged);
265                            }
266                            if !registered_server.prompts.is_empty() {
267                                cx.emit(ContextServerRegistryEvent::PromptsChanged);
268                            }
269                        }
270                        cx.notify();
271                    }
272                }
273            }
274        }
275    }
276}
277
/// Agent-facing wrapper around a single tool advertised by an MCP context server.
struct ContextServerTool {
    /// Store used to look up the running server when the tool is invoked.
    store: Entity<ContextServerStore>,
    /// Id of the server that advertised this tool.
    server_id: ContextServerId,
    /// Tool definition from the server's `ListTools` response.
    tool: context_server::types::Tool,
}
283
284impl ContextServerTool {
285    fn new(
286        store: Entity<ContextServerStore>,
287        server_id: ContextServerId,
288        tool: context_server::types::Tool,
289    ) -> Self {
290        Self {
291            store,
292            server_id,
293            tool,
294        }
295    }
296}
297
298impl AnyAgentTool for ContextServerTool {
299    fn name(&self) -> SharedString {
300        self.tool.name.clone().into()
301    }
302
303    fn description(&self) -> SharedString {
304        self.tool.description.clone().unwrap_or_default().into()
305    }
306
307    fn kind(&self) -> ToolKind {
308        ToolKind::Other
309    }
310
311    fn initial_title(&self, _input: serde_json::Value, _cx: &mut App) -> SharedString {
312        format!("Run MCP tool `{}`", self.tool.name).into()
313    }
314
315    fn input_schema(
316        &self,
317        format: language_model::LanguageModelToolSchemaFormat,
318    ) -> Result<serde_json::Value> {
319        let mut schema = self.tool.input_schema.clone();
320        language_model::tool_schema::adapt_schema_to_format(&mut schema, format)?;
321        Ok(match schema {
322            serde_json::Value::Null => {
323                serde_json::json!({ "type": "object", "properties": [] })
324            }
325            serde_json::Value::Object(map) if map.is_empty() => {
326                serde_json::json!({ "type": "object", "properties": [] })
327            }
328            _ => schema,
329        })
330    }
331
332    fn run(
333        self: Arc<Self>,
334        input: serde_json::Value,
335        event_stream: ToolCallEventStream,
336        cx: &mut App,
337    ) -> Task<Result<AgentToolOutput>> {
338        let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
339            return Task::ready(Err(anyhow!("Context server not found")));
340        };
341        let tool_name = self.tool.name.clone();
342        let tool_id = mcp_tool_id(&self.server_id.0, &self.tool.name);
343        let display_name = self.tool.name.clone();
344        let authorize = event_stream.authorize_third_party_tool(
345            self.initial_title(input.clone(), cx),
346            tool_id,
347            display_name,
348            cx,
349        );
350
351        cx.spawn(async move |_cx| {
352            authorize.await?;
353
354            let Some(protocol) = server.client() else {
355                anyhow::bail!("Context server not initialized");
356            };
357
358            let arguments = if let serde_json::Value::Object(map) = input {
359                Some(map.into_iter().collect())
360            } else {
361                None
362            };
363
364            log::trace!(
365                "Running tool: {} with arguments: {:?}",
366                tool_name,
367                arguments
368            );
369
370            let request = protocol.request::<context_server::types::requests::CallTool>(
371                context_server::types::CallToolParams {
372                    name: tool_name,
373                    arguments,
374                    meta: None,
375                },
376            );
377
378            let response = futures::select! {
379                response = request.fuse() => response?,
380                _ = event_stream.cancelled_by_user().fuse() => {
381                    anyhow::bail!("MCP tool cancelled by user");
382                }
383            };
384
385            let mut result = String::new();
386            for content in response.content {
387                match content {
388                    context_server::types::ToolResponseContent::Text { text } => {
389                        result.push_str(&text);
390                    }
391                    context_server::types::ToolResponseContent::Image { .. } => {
392                        log::warn!("Ignoring image content from tool response");
393                    }
394                    context_server::types::ToolResponseContent::Audio { .. } => {
395                        log::warn!("Ignoring audio content from tool response");
396                    }
397                    context_server::types::ToolResponseContent::Resource { .. } => {
398                        log::warn!("Ignoring resource content from tool response");
399                    }
400                }
401            }
402            Ok(AgentToolOutput {
403                raw_output: result.clone().into(),
404                llm_output: result.into(),
405            })
406        })
407    }
408
409    fn replay(
410        &self,
411        _input: serde_json::Value,
412        _output: serde_json::Value,
413        _event_stream: ToolCallEventStream,
414        _cx: &mut App,
415    ) -> Result<()> {
416        Ok(())
417    }
418}
419
420pub fn get_prompt(
421    server_store: &Entity<ContextServerStore>,
422    server_id: &ContextServerId,
423    prompt_name: &str,
424    arguments: HashMap<String, String>,
425    cx: &mut AsyncApp,
426) -> Task<Result<context_server::types::PromptsGetResponse>> {
427    let server = cx.update(|cx| server_store.read(cx).get_running_server(server_id));
428    let Some(server) = server else {
429        return Task::ready(Err(anyhow::anyhow!("Context server not found")));
430    };
431
432    let Some(protocol) = server.client() else {
433        return Task::ready(Err(anyhow::anyhow!("Context server not initialized")));
434    };
435
436    let prompt_name = prompt_name.to_string();
437
438    cx.background_spawn(async move {
439        let response = protocol
440            .request::<context_server::types::requests::PromptsGet>(
441                context_server::types::PromptsGetParams {
442                    name: prompt_name,
443                    arguments: (!arguments.is_empty()).then(|| arguments),
444                    meta: None,
445                },
446            )
447            .await?;
448
449        Ok(response)
450    })
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mcp_tool_id_format() {
        // (server_id, tool_name, expected id)
        let cases = [
            ("filesystem", "read_file", "mcp:filesystem:read_file"),
            ("github", "create_issue", "mcp:github:create_issue"),
            (
                "my-custom-server",
                "do_something",
                "mcp:my-custom-server:do_something",
            ),
            // Underscores in names
            ("my_server", "my_tool", "mcp:my_server:my_tool"),
        ];
        for (server_id, tool_name, expected) in cases {
            assert_eq!(mcp_tool_id(server_id, tool_name), expected);
        }
    }

    // Note: Tests for MCP tool ID collision with built-in tools and permission
    // decisions are in crates/agent/src/tool_permissions.rs to avoid duplication.
}