context_server_registry.rs

  1use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream, ToolInput};
  2use agent_client_protocol::ToolKind;
  3use anyhow::Result;
  4use collections::{BTreeMap, HashMap};
  5use context_server::{ContextServerId, client::NotificationSubscription};
  6use futures::FutureExt as _;
  7use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task};
  8use project::context_server_store::{ContextServerStatus, ContextServerStore};
  9use std::sync::Arc;
 10use util::ResultExt;
 11
 12/// Generates a tool ID for an MCP tool that can be used in settings.
 13///
 14/// The format is `mcp:<server_id>:<tool_name>` to avoid collisions with built-in tools.
 15pub fn mcp_tool_id(server_id: &str, tool_name: &str) -> String {
 16    format!("mcp:{}:{}", server_id, tool_name)
 17}
 18
 19pub struct ContextServerPrompt {
 20    pub server_id: ContextServerId,
 21    pub prompt: context_server::types::Prompt,
 22}
 23
 24pub enum ContextServerRegistryEvent {
 25    ToolsChanged,
 26    PromptsChanged,
 27}
 28
 29impl EventEmitter<ContextServerRegistryEvent> for ContextServerRegistry {}
 30
 31pub struct ContextServerRegistry {
 32    server_store: Entity<ContextServerStore>,
 33    registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
 34    _subscription: gpui::Subscription,
 35}
 36
 37struct RegisteredContextServer {
 38    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 39    prompts: BTreeMap<SharedString, ContextServerPrompt>,
 40    load_tools: Task<Result<()>>,
 41    load_prompts: Task<Result<()>>,
 42    _tools_updated_subscription: Option<NotificationSubscription>,
 43}
 44
 45impl ContextServerRegistry {
 46    pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
 47        let mut this = Self {
 48            server_store: server_store.clone(),
 49            registered_servers: HashMap::default(),
 50            _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
 51        };
 52        for server in server_store.read(cx).running_servers() {
 53            this.reload_tools_for_server(server.id(), cx);
 54            this.reload_prompts_for_server(server.id(), cx);
 55        }
 56        this
 57    }
 58
 59    pub fn tools_for_server(
 60        &self,
 61        server_id: &ContextServerId,
 62    ) -> impl Iterator<Item = &Arc<dyn AnyAgentTool>> {
 63        self.registered_servers
 64            .get(server_id)
 65            .map(|server| server.tools.values())
 66            .into_iter()
 67            .flatten()
 68    }
 69
 70    pub fn servers(
 71        &self,
 72    ) -> impl Iterator<
 73        Item = (
 74            &ContextServerId,
 75            &BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 76        ),
 77    > {
 78        self.registered_servers
 79            .iter()
 80            .map(|(id, server)| (id, &server.tools))
 81    }
 82
 83    pub fn prompts(&self) -> impl Iterator<Item = &ContextServerPrompt> {
 84        self.registered_servers
 85            .values()
 86            .flat_map(|server| server.prompts.values())
 87    }
 88
 89    pub fn find_prompt(
 90        &self,
 91        server_id: Option<&ContextServerId>,
 92        name: &str,
 93    ) -> Option<&ContextServerPrompt> {
 94        if let Some(server_id) = server_id {
 95            self.registered_servers
 96                .get(server_id)
 97                .and_then(|server| server.prompts.get(name))
 98        } else {
 99            self.registered_servers
100                .values()
101                .find_map(|server| server.prompts.get(name))
102        }
103    }
104
105    pub fn server_store(&self) -> &Entity<ContextServerStore> {
106        &self.server_store
107    }
108
109    fn get_or_register_server(
110        &mut self,
111        server_id: &ContextServerId,
112        cx: &mut Context<Self>,
113    ) -> &mut RegisteredContextServer {
114        self.registered_servers
115            .entry(server_id.clone())
116            .or_insert_with(|| Self::init_registered_server(server_id, &self.server_store, cx))
117    }
118
119    fn init_registered_server(
120        server_id: &ContextServerId,
121        server_store: &Entity<ContextServerStore>,
122        cx: &mut Context<Self>,
123    ) -> RegisteredContextServer {
124        let tools_updated_subscription = server_store
125            .read(cx)
126            .get_running_server(server_id)
127            .and_then(|server| {
128                let client = server.client()?;
129
130                if !client.capable(context_server::protocol::ServerCapability::Tools) {
131                    return None;
132                }
133
134                let server_id = server.id();
135                let this = cx.entity().downgrade();
136
137                Some(client.on_notification(
138                    "notifications/tools/list_changed",
139                    Box::new(move |_params, cx: AsyncApp| {
140                        let server_id = server_id.clone();
141                        let this = this.clone();
142                        cx.spawn(async move |cx| {
143                            this.update(cx, |this, cx| {
144                                log::info!(
145                                    "Received tools/list_changed notification for server {}",
146                                    server_id
147                                );
148                                this.reload_tools_for_server(server_id, cx);
149                            })
150                        })
151                        .detach();
152                    }),
153                ))
154            });
155
156        RegisteredContextServer {
157            tools: BTreeMap::default(),
158            prompts: BTreeMap::default(),
159            load_tools: Task::ready(Ok(())),
160            load_prompts: Task::ready(Ok(())),
161            _tools_updated_subscription: tools_updated_subscription,
162        }
163    }
164
165    fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
166        let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
167            return;
168        };
169        let Some(client) = server.client() else {
170            return;
171        };
172
173        if !client.capable(context_server::protocol::ServerCapability::Tools) {
174            return;
175        }
176
177        let registered_server = self.get_or_register_server(&server_id, cx);
178        registered_server.load_tools = cx.spawn(async move |this, cx| {
179            let response = client
180                .request::<context_server::types::requests::ListTools>(())
181                .await;
182
183            this.update(cx, |this, cx| {
184                let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
185                    return;
186                };
187
188                registered_server.tools.clear();
189                if let Some(response) = response.log_err() {
190                    for tool in response.tools {
191                        let tool = Arc::new(ContextServerTool::new(
192                            this.server_store.clone(),
193                            server.id(),
194                            tool,
195                        ));
196                        registered_server.tools.insert(tool.name(), tool);
197                    }
198                    cx.emit(ContextServerRegistryEvent::ToolsChanged);
199                    cx.notify();
200                }
201            })
202        });
203    }
204
205    fn reload_prompts_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
206        let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
207            return;
208        };
209        let Some(client) = server.client() else {
210            return;
211        };
212        if !client.capable(context_server::protocol::ServerCapability::Prompts) {
213            return;
214        }
215
216        let registered_server = self.get_or_register_server(&server_id, cx);
217
218        registered_server.load_prompts = cx.spawn(async move |this, cx| {
219            let response = client
220                .request::<context_server::types::requests::PromptsList>(())
221                .await;
222
223            this.update(cx, |this, cx| {
224                let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
225                    return;
226                };
227
228                registered_server.prompts.clear();
229                if let Some(response) = response.log_err() {
230                    for prompt in response.prompts {
231                        let name: SharedString = prompt.name.clone().into();
232                        registered_server.prompts.insert(
233                            name,
234                            ContextServerPrompt {
235                                server_id: server_id.clone(),
236                                prompt,
237                            },
238                        );
239                    }
240                    cx.emit(ContextServerRegistryEvent::PromptsChanged);
241                    cx.notify();
242                }
243            })
244        });
245    }
246
247    fn handle_context_server_store_event(
248        &mut self,
249        _: Entity<ContextServerStore>,
250        event: &project::context_server_store::ServerStatusChangedEvent,
251        cx: &mut Context<Self>,
252    ) {
253        let project::context_server_store::ServerStatusChangedEvent { server_id, status } = event;
254
255        match status {
256            ContextServerStatus::Starting => {}
257            ContextServerStatus::Running => {
258                self.reload_tools_for_server(server_id.clone(), cx);
259                self.reload_prompts_for_server(server_id.clone(), cx);
260            }
261            ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
262                if let Some(registered_server) = self.registered_servers.remove(server_id) {
263                    if !registered_server.tools.is_empty() {
264                        cx.emit(ContextServerRegistryEvent::ToolsChanged);
265                    }
266                    if !registered_server.prompts.is_empty() {
267                        cx.emit(ContextServerRegistryEvent::PromptsChanged);
268                    }
269                }
270                cx.notify();
271            }
272        };
273    }
274}
275
276struct ContextServerTool {
277    store: Entity<ContextServerStore>,
278    server_id: ContextServerId,
279    tool: context_server::types::Tool,
280}
281
282impl ContextServerTool {
283    fn new(
284        store: Entity<ContextServerStore>,
285        server_id: ContextServerId,
286        tool: context_server::types::Tool,
287    ) -> Self {
288        Self {
289            store,
290            server_id,
291            tool,
292        }
293    }
294}
295
296impl AnyAgentTool for ContextServerTool {
297    fn name(&self) -> SharedString {
298        self.tool.name.clone().into()
299    }
300
301    fn description(&self) -> SharedString {
302        self.tool.description.clone().unwrap_or_default().into()
303    }
304
305    fn kind(&self) -> ToolKind {
306        ToolKind::Other
307    }
308
309    fn initial_title(&self, _input: serde_json::Value, _cx: &mut App) -> SharedString {
310        format!("Run MCP tool `{}`", self.tool.name).into()
311    }
312
313    fn input_schema(
314        &self,
315        format: language_model::LanguageModelToolSchemaFormat,
316    ) -> Result<serde_json::Value> {
317        let mut schema = self.tool.input_schema.clone();
318        language_model::tool_schema::adapt_schema_to_format(&mut schema, format)?;
319        Ok(match schema {
320            serde_json::Value::Null => {
321                serde_json::json!({ "type": "object", "properties": [] })
322            }
323            serde_json::Value::Object(map) if map.is_empty() => {
324                serde_json::json!({ "type": "object", "properties": [] })
325            }
326            _ => schema,
327        })
328    }
329
330    fn run(
331        self: Arc<Self>,
332        input: ToolInput<serde_json::Value>,
333        event_stream: ToolCallEventStream,
334        cx: &mut App,
335    ) -> Task<Result<AgentToolOutput, AgentToolOutput>> {
336        let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
337            return Task::ready(Err(AgentToolOutput::from_error("Context server not found")));
338        };
339        let tool_name = self.tool.name.clone();
340        let tool_id = mcp_tool_id(&self.server_id.0, &self.tool.name);
341        let display_name = self.tool.name.clone();
342        let initial_title = self.initial_title(serde_json::Value::Null, cx);
343        let authorize =
344            event_stream.authorize_third_party_tool(initial_title, tool_id, display_name, cx);
345
346        cx.spawn(async move |_cx| {
347            let input = input.recv().await.map_err(|e| {
348                AgentToolOutput::from_error(format!("Failed to receive tool input: {e}"))
349            })?;
350
351            authorize.await.map_err(|e| AgentToolOutput::from_error(e.to_string()))?;
352
353            let Some(protocol) = server.client() else {
354                return Err(AgentToolOutput::from_error("Context server not initialized"));
355            };
356
357            let arguments = if let serde_json::Value::Object(map) = input {
358                Some(map.into_iter().collect())
359            } else {
360                None
361            };
362
363            log::trace!(
364                "Running tool: {} with arguments: {:?}",
365                tool_name,
366                arguments
367            );
368
369            let request = protocol.request::<context_server::types::requests::CallTool>(
370                context_server::types::CallToolParams {
371                    name: tool_name,
372                    arguments,
373                    meta: None,
374                },
375            );
376
377            let response = futures::select! {
378                response = request.fuse() => response.map_err(|e| AgentToolOutput::from_error(e.to_string()))?,
379                _ = event_stream.cancelled_by_user().fuse() => {
380                    return Err(AgentToolOutput::from_error("MCP tool cancelled by user"));
381                }
382            };
383
384            if response.is_error == Some(true) {
385                let error_message: String =
386                    response.content.iter().filter_map(|c| c.text()).collect();
387                return Err(AgentToolOutput::from_error(error_message));
388            }
389
390            let mut result = String::new();
391            for content in response.content {
392                match content {
393                    context_server::types::ToolResponseContent::Text { text } => {
394                        result.push_str(&text);
395                    }
396                    context_server::types::ToolResponseContent::Image { .. } => {
397                        log::warn!("Ignoring image content from tool response");
398                    }
399                    context_server::types::ToolResponseContent::Audio { .. } => {
400                        log::warn!("Ignoring audio content from tool response");
401                    }
402                    context_server::types::ToolResponseContent::Resource { .. } => {
403                        log::warn!("Ignoring resource content from tool response");
404                    }
405                }
406            }
407            Ok(AgentToolOutput {
408                raw_output: result.clone().into(),
409                llm_output: result.into(),
410            })
411        })
412    }
413
414    fn replay(
415        &self,
416        _input: serde_json::Value,
417        _output: serde_json::Value,
418        _event_stream: ToolCallEventStream,
419        _cx: &mut App,
420    ) -> Result<()> {
421        Ok(())
422    }
423}
424
425pub fn get_prompt(
426    server_store: &Entity<ContextServerStore>,
427    server_id: &ContextServerId,
428    prompt_name: &str,
429    arguments: HashMap<String, String>,
430    cx: &mut AsyncApp,
431) -> Task<Result<context_server::types::PromptsGetResponse>> {
432    let server = cx.update(|cx| server_store.read(cx).get_running_server(server_id));
433    let Some(server) = server else {
434        return Task::ready(Err(anyhow::anyhow!("Context server not found")));
435    };
436
437    let Some(protocol) = server.client() else {
438        return Task::ready(Err(anyhow::anyhow!("Context server not initialized")));
439    };
440
441    let prompt_name = prompt_name.to_string();
442
443    cx.background_spawn(async move {
444        let response = protocol
445            .request::<context_server::types::requests::PromptsGet>(
446                context_server::types::PromptsGetParams {
447                    name: prompt_name,
448                    arguments: (!arguments.is_empty()).then(|| arguments),
449                    meta: None,
450                },
451            )
452            .await?;
453
454        Ok(response)
455    })
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mcp_tool_id_format() {
        // (server_id, tool_name, expected id)
        let cases = [
            ("filesystem", "read_file", "mcp:filesystem:read_file"),
            ("github", "create_issue", "mcp:github:create_issue"),
            (
                "my-custom-server",
                "do_something",
                "mcp:my-custom-server:do_something",
            ),
            // Underscores in names
            ("my_server", "my_tool", "mcp:my_server:my_tool"),
        ];
        for (server_id, tool_name, expected) in cases {
            assert_eq!(mcp_tool_id(server_id, tool_name), expected);
        }
    }

    // Note: Tests for MCP tool ID collision with built-in tools and permission
    // decisions are in crates/agent/src/tool_permissions.rs to avoid duplication.
}