context_server_registry.rs

  1use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
  2use agent_client_protocol::ToolKind;
  3use anyhow::{Result, anyhow};
  4use collections::{BTreeMap, HashMap};
  5use context_server::{ContextServerId, client::NotificationSubscription};
  6use futures::FutureExt as _;
  7use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task};
  8use project::context_server_store::{ContextServerStatus, ContextServerStore};
  9use std::sync::Arc;
 10use util::ResultExt;
 11
 12pub struct ContextServerPrompt {
 13    pub server_id: ContextServerId,
 14    pub prompt: context_server::types::Prompt,
 15}
 16
 17pub enum ContextServerRegistryEvent {
 18    ToolsChanged,
 19    PromptsChanged,
 20}
 21
 22impl EventEmitter<ContextServerRegistryEvent> for ContextServerRegistry {}
 23
 24pub struct ContextServerRegistry {
 25    server_store: Entity<ContextServerStore>,
 26    registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
 27    _subscription: gpui::Subscription,
 28}
 29
 30struct RegisteredContextServer {
 31    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 32    prompts: BTreeMap<SharedString, ContextServerPrompt>,
 33    load_tools: Task<Result<()>>,
 34    load_prompts: Task<Result<()>>,
 35    _tools_updated_subscription: Option<NotificationSubscription>,
 36}
 37
 38impl ContextServerRegistry {
 39    pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
 40        let mut this = Self {
 41            server_store: server_store.clone(),
 42            registered_servers: HashMap::default(),
 43            _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
 44        };
 45        for server in server_store.read(cx).running_servers() {
 46            this.reload_tools_for_server(server.id(), cx);
 47            this.reload_prompts_for_server(server.id(), cx);
 48        }
 49        this
 50    }
 51
 52    pub fn tools_for_server(
 53        &self,
 54        server_id: &ContextServerId,
 55    ) -> impl Iterator<Item = &Arc<dyn AnyAgentTool>> {
 56        self.registered_servers
 57            .get(server_id)
 58            .map(|server| server.tools.values())
 59            .into_iter()
 60            .flatten()
 61    }
 62
 63    pub fn servers(
 64        &self,
 65    ) -> impl Iterator<
 66        Item = (
 67            &ContextServerId,
 68            &BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 69        ),
 70    > {
 71        self.registered_servers
 72            .iter()
 73            .map(|(id, server)| (id, &server.tools))
 74    }
 75
 76    pub fn prompts(&self) -> impl Iterator<Item = &ContextServerPrompt> {
 77        self.registered_servers
 78            .values()
 79            .flat_map(|server| server.prompts.values())
 80    }
 81
 82    pub fn find_prompt(
 83        &self,
 84        server_id: Option<&ContextServerId>,
 85        name: &str,
 86    ) -> Option<&ContextServerPrompt> {
 87        if let Some(server_id) = server_id {
 88            self.registered_servers
 89                .get(server_id)
 90                .and_then(|server| server.prompts.get(name))
 91        } else {
 92            self.registered_servers
 93                .values()
 94                .find_map(|server| server.prompts.get(name))
 95        }
 96    }
 97
 98    pub fn server_store(&self) -> &Entity<ContextServerStore> {
 99        &self.server_store
100    }
101
102    fn get_or_register_server(
103        &mut self,
104        server_id: &ContextServerId,
105        cx: &mut Context<Self>,
106    ) -> &mut RegisteredContextServer {
107        self.registered_servers
108            .entry(server_id.clone())
109            .or_insert_with(|| Self::init_registered_server(server_id, &self.server_store, cx))
110    }
111
112    fn init_registered_server(
113        server_id: &ContextServerId,
114        server_store: &Entity<ContextServerStore>,
115        cx: &mut Context<Self>,
116    ) -> RegisteredContextServer {
117        let tools_updated_subscription = server_store
118            .read(cx)
119            .get_running_server(server_id)
120            .and_then(|server| {
121                let client = server.client()?;
122
123                if !client.capable(context_server::protocol::ServerCapability::Tools) {
124                    return None;
125                }
126
127                let server_id = server.id();
128                let this = cx.entity().downgrade();
129
130                Some(client.on_notification(
131                    "notifications/tools/list_changed",
132                    Box::new(move |_params, cx: AsyncApp| {
133                        let server_id = server_id.clone();
134                        let this = this.clone();
135                        cx.spawn(async move |cx| {
136                            this.update(cx, |this, cx| {
137                                log::info!(
138                                    "Received tools/list_changed notification for server {}",
139                                    server_id
140                                );
141                                this.reload_tools_for_server(server_id, cx);
142                            })
143                        })
144                        .detach();
145                    }),
146                ))
147            });
148
149        RegisteredContextServer {
150            tools: BTreeMap::default(),
151            prompts: BTreeMap::default(),
152            load_tools: Task::ready(Ok(())),
153            load_prompts: Task::ready(Ok(())),
154            _tools_updated_subscription: tools_updated_subscription,
155        }
156    }
157
158    fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
159        let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
160            return;
161        };
162        let Some(client) = server.client() else {
163            return;
164        };
165
166        if !client.capable(context_server::protocol::ServerCapability::Tools) {
167            return;
168        }
169
170        let registered_server = self.get_or_register_server(&server_id, cx);
171        registered_server.load_tools = cx.spawn(async move |this, cx| {
172            let response = client
173                .request::<context_server::types::requests::ListTools>(())
174                .await;
175
176            this.update(cx, |this, cx| {
177                let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
178                    return;
179                };
180
181                registered_server.tools.clear();
182                if let Some(response) = response.log_err() {
183                    for tool in response.tools {
184                        let tool = Arc::new(ContextServerTool::new(
185                            this.server_store.clone(),
186                            server.id(),
187                            tool,
188                        ));
189                        registered_server.tools.insert(tool.name(), tool);
190                    }
191                    cx.emit(ContextServerRegistryEvent::ToolsChanged);
192                    cx.notify();
193                }
194            })
195        });
196    }
197
198    fn reload_prompts_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
199        let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
200            return;
201        };
202        let Some(client) = server.client() else {
203            return;
204        };
205        if !client.capable(context_server::protocol::ServerCapability::Prompts) {
206            return;
207        }
208
209        let registered_server = self.get_or_register_server(&server_id, cx);
210
211        registered_server.load_prompts = cx.spawn(async move |this, cx| {
212            let response = client
213                .request::<context_server::types::requests::PromptsList>(())
214                .await;
215
216            this.update(cx, |this, cx| {
217                let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
218                    return;
219                };
220
221                registered_server.prompts.clear();
222                if let Some(response) = response.log_err() {
223                    for prompt in response.prompts {
224                        let name: SharedString = prompt.name.clone().into();
225                        registered_server.prompts.insert(
226                            name,
227                            ContextServerPrompt {
228                                server_id: server_id.clone(),
229                                prompt,
230                            },
231                        );
232                    }
233                    cx.emit(ContextServerRegistryEvent::PromptsChanged);
234                    cx.notify();
235                }
236            })
237        });
238    }
239
240    fn handle_context_server_store_event(
241        &mut self,
242        _: Entity<ContextServerStore>,
243        event: &project::context_server_store::Event,
244        cx: &mut Context<Self>,
245    ) {
246        match event {
247            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
248                match status {
249                    ContextServerStatus::Starting => {}
250                    ContextServerStatus::Running => {
251                        self.reload_tools_for_server(server_id.clone(), cx);
252                        self.reload_prompts_for_server(server_id.clone(), cx);
253                    }
254                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
255                        if let Some(registered_server) = self.registered_servers.remove(server_id) {
256                            if !registered_server.tools.is_empty() {
257                                cx.emit(ContextServerRegistryEvent::ToolsChanged);
258                            }
259                            if !registered_server.prompts.is_empty() {
260                                cx.emit(ContextServerRegistryEvent::PromptsChanged);
261                            }
262                        }
263                        cx.notify();
264                    }
265                }
266            }
267        }
268    }
269}
270
271struct ContextServerTool {
272    store: Entity<ContextServerStore>,
273    server_id: ContextServerId,
274    tool: context_server::types::Tool,
275}
276
277impl ContextServerTool {
278    fn new(
279        store: Entity<ContextServerStore>,
280        server_id: ContextServerId,
281        tool: context_server::types::Tool,
282    ) -> Self {
283        Self {
284            store,
285            server_id,
286            tool,
287        }
288    }
289}
290
291impl AnyAgentTool for ContextServerTool {
292    fn name(&self) -> SharedString {
293        self.tool.name.clone().into()
294    }
295
296    fn description(&self) -> SharedString {
297        self.tool.description.clone().unwrap_or_default().into()
298    }
299
300    fn kind(&self) -> ToolKind {
301        ToolKind::Other
302    }
303
304    fn initial_title(&self, _input: serde_json::Value, _cx: &mut App) -> SharedString {
305        format!("Run MCP tool `{}`", self.tool.name).into()
306    }
307
308    fn input_schema(
309        &self,
310        format: language_model::LanguageModelToolSchemaFormat,
311    ) -> Result<serde_json::Value> {
312        let mut schema = self.tool.input_schema.clone();
313        language_model::tool_schema::adapt_schema_to_format(&mut schema, format)?;
314        Ok(match schema {
315            serde_json::Value::Null => {
316                serde_json::json!({ "type": "object", "properties": [] })
317            }
318            serde_json::Value::Object(map) if map.is_empty() => {
319                serde_json::json!({ "type": "object", "properties": [] })
320            }
321            _ => schema,
322        })
323    }
324
325    fn run(
326        self: Arc<Self>,
327        input: serde_json::Value,
328        event_stream: ToolCallEventStream,
329        cx: &mut App,
330    ) -> Task<Result<AgentToolOutput>> {
331        let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
332            return Task::ready(Err(anyhow!("Context server not found")));
333        };
334        let tool_name = self.tool.name.clone();
335        let authorize = event_stream.authorize(self.initial_title(input.clone(), cx), cx);
336
337        cx.spawn(async move |_cx| {
338            authorize.await?;
339
340            let Some(protocol) = server.client() else {
341                anyhow::bail!("Context server not initialized");
342            };
343
344            let arguments = if let serde_json::Value::Object(map) = input {
345                Some(map.into_iter().collect())
346            } else {
347                None
348            };
349
350            log::trace!(
351                "Running tool: {} with arguments: {:?}",
352                tool_name,
353                arguments
354            );
355
356            let request = protocol.request::<context_server::types::requests::CallTool>(
357                context_server::types::CallToolParams {
358                    name: tool_name,
359                    arguments,
360                    meta: None,
361                },
362            );
363
364            let response = futures::select! {
365                response = request.fuse() => response?,
366                _ = event_stream.cancelled_by_user().fuse() => {
367                    anyhow::bail!("MCP tool cancelled by user");
368                }
369            };
370
371            let mut result = String::new();
372            for content in response.content {
373                match content {
374                    context_server::types::ToolResponseContent::Text { text } => {
375                        result.push_str(&text);
376                    }
377                    context_server::types::ToolResponseContent::Image { .. } => {
378                        log::warn!("Ignoring image content from tool response");
379                    }
380                    context_server::types::ToolResponseContent::Audio { .. } => {
381                        log::warn!("Ignoring audio content from tool response");
382                    }
383                    context_server::types::ToolResponseContent::Resource { .. } => {
384                        log::warn!("Ignoring resource content from tool response");
385                    }
386                }
387            }
388            Ok(AgentToolOutput {
389                raw_output: result.clone().into(),
390                llm_output: result.into(),
391            })
392        })
393    }
394
395    fn replay(
396        &self,
397        _input: serde_json::Value,
398        _output: serde_json::Value,
399        _event_stream: ToolCallEventStream,
400        _cx: &mut App,
401    ) -> Result<()> {
402        Ok(())
403    }
404}
405
406pub fn get_prompt(
407    server_store: &Entity<ContextServerStore>,
408    server_id: &ContextServerId,
409    prompt_name: &str,
410    arguments: HashMap<String, String>,
411    cx: &mut AsyncApp,
412) -> Task<Result<context_server::types::PromptsGetResponse>> {
413    let server = cx.update(|cx| server_store.read(cx).get_running_server(server_id));
414    let Some(server) = server else {
415        return Task::ready(Err(anyhow::anyhow!("Context server not found")));
416    };
417
418    let Some(protocol) = server.client() else {
419        return Task::ready(Err(anyhow::anyhow!("Context server not initialized")));
420    };
421
422    let prompt_name = prompt_name.to_string();
423
424    cx.background_spawn(async move {
425        let response = protocol
426            .request::<context_server::types::requests::PromptsGet>(
427                context_server::types::PromptsGetParams {
428                    name: prompt_name,
429                    arguments: (!arguments.is_empty()).then(|| arguments),
430                    meta: None,
431                },
432            )
433            .await?;
434
435        Ok(response)
436    })
437}