context_server_registry.rs

  1use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
  2use agent_client_protocol::ToolKind;
  3use anyhow::{Result, anyhow, bail};
  4use collections::{BTreeMap, HashMap};
  5use context_server::{ContextServerId, client::NotificationSubscription};
  6use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task};
  7use project::context_server_store::{ContextServerStatus, ContextServerStore};
  8use std::sync::Arc;
  9use util::ResultExt;
 10
 11pub struct ContextServerPrompt {
 12    pub server_id: ContextServerId,
 13    pub prompt: context_server::types::Prompt,
 14}
 15
 16pub enum ContextServerRegistryEvent {
 17    ToolsChanged,
 18    PromptsChanged,
 19}
 20
 21impl EventEmitter<ContextServerRegistryEvent> for ContextServerRegistry {}
 22
 23pub struct ContextServerRegistry {
 24    server_store: Entity<ContextServerStore>,
 25    registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
 26    _subscription: gpui::Subscription,
 27}
 28
 29struct RegisteredContextServer {
 30    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 31    prompts: BTreeMap<SharedString, ContextServerPrompt>,
 32    load_tools: Task<Result<()>>,
 33    load_prompts: Task<Result<()>>,
 34    _tools_updated_subscription: Option<NotificationSubscription>,
 35}
 36
 37impl ContextServerRegistry {
 38    pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
 39        let mut this = Self {
 40            server_store: server_store.clone(),
 41            registered_servers: HashMap::default(),
 42            _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
 43        };
 44        for server in server_store.read(cx).running_servers() {
 45            this.reload_tools_for_server(server.id(), cx);
 46            this.reload_prompts_for_server(server.id(), cx);
 47        }
 48        this
 49    }
 50
 51    pub fn tools_for_server(
 52        &self,
 53        server_id: &ContextServerId,
 54    ) -> impl Iterator<Item = &Arc<dyn AnyAgentTool>> {
 55        self.registered_servers
 56            .get(server_id)
 57            .map(|server| server.tools.values())
 58            .into_iter()
 59            .flatten()
 60    }
 61
 62    pub fn servers(
 63        &self,
 64    ) -> impl Iterator<
 65        Item = (
 66            &ContextServerId,
 67            &BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 68        ),
 69    > {
 70        self.registered_servers
 71            .iter()
 72            .map(|(id, server)| (id, &server.tools))
 73    }
 74
 75    pub fn prompts(&self) -> impl Iterator<Item = &ContextServerPrompt> {
 76        self.registered_servers
 77            .values()
 78            .flat_map(|server| server.prompts.values())
 79    }
 80
 81    pub fn find_prompt(
 82        &self,
 83        server_id: Option<&ContextServerId>,
 84        name: &str,
 85    ) -> Option<&ContextServerPrompt> {
 86        if let Some(server_id) = server_id {
 87            self.registered_servers
 88                .get(server_id)
 89                .and_then(|server| server.prompts.get(name))
 90        } else {
 91            self.registered_servers
 92                .values()
 93                .find_map(|server| server.prompts.get(name))
 94        }
 95    }
 96
 97    pub fn server_store(&self) -> &Entity<ContextServerStore> {
 98        &self.server_store
 99    }
100
101    fn get_or_register_server(
102        &mut self,
103        server_id: &ContextServerId,
104        cx: &mut Context<Self>,
105    ) -> &mut RegisteredContextServer {
106        self.registered_servers
107            .entry(server_id.clone())
108            .or_insert_with(|| Self::init_registered_server(server_id, &self.server_store, cx))
109    }
110
111    fn init_registered_server(
112        server_id: &ContextServerId,
113        server_store: &Entity<ContextServerStore>,
114        cx: &mut Context<Self>,
115    ) -> RegisteredContextServer {
116        let tools_updated_subscription = server_store
117            .read(cx)
118            .get_running_server(server_id)
119            .and_then(|server| {
120                let client = server.client()?;
121
122                if !client.capable(context_server::protocol::ServerCapability::Tools) {
123                    return None;
124                }
125
126                let server_id = server.id();
127                let this = cx.entity().downgrade();
128
129                Some(client.on_notification(
130                    "notifications/tools/list_changed",
131                    Box::new(move |_params, cx: AsyncApp| {
132                        let server_id = server_id.clone();
133                        let this = this.clone();
134                        cx.spawn(async move |cx| {
135                            this.update(cx, |this, cx| {
136                                log::info!(
137                                    "Received tools/list_changed notification for server {}",
138                                    server_id
139                                );
140                                this.reload_tools_for_server(server_id, cx);
141                            })
142                        })
143                        .detach();
144                    }),
145                ))
146            });
147
148        RegisteredContextServer {
149            tools: BTreeMap::default(),
150            prompts: BTreeMap::default(),
151            load_tools: Task::ready(Ok(())),
152            load_prompts: Task::ready(Ok(())),
153            _tools_updated_subscription: tools_updated_subscription,
154        }
155    }
156
157    fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
158        let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
159            return;
160        };
161        let Some(client) = server.client() else {
162            return;
163        };
164
165        if !client.capable(context_server::protocol::ServerCapability::Tools) {
166            return;
167        }
168
169        let registered_server = self.get_or_register_server(&server_id, cx);
170        registered_server.load_tools = cx.spawn(async move |this, cx| {
171            let response = client
172                .request::<context_server::types::requests::ListTools>(())
173                .await;
174
175            this.update(cx, |this, cx| {
176                let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
177                    return;
178                };
179
180                registered_server.tools.clear();
181                if let Some(response) = response.log_err() {
182                    for tool in response.tools {
183                        let tool = Arc::new(ContextServerTool::new(
184                            this.server_store.clone(),
185                            server.id(),
186                            tool,
187                        ));
188                        registered_server.tools.insert(tool.name(), tool);
189                    }
190                    cx.emit(ContextServerRegistryEvent::ToolsChanged);
191                    cx.notify();
192                }
193            })
194        });
195    }
196
197    fn reload_prompts_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
198        let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
199            return;
200        };
201        let Some(client) = server.client() else {
202            return;
203        };
204        if !client.capable(context_server::protocol::ServerCapability::Prompts) {
205            return;
206        }
207
208        let registered_server = self.get_or_register_server(&server_id, cx);
209
210        registered_server.load_prompts = cx.spawn(async move |this, cx| {
211            let response = client
212                .request::<context_server::types::requests::PromptsList>(())
213                .await;
214
215            this.update(cx, |this, cx| {
216                let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
217                    return;
218                };
219
220                registered_server.prompts.clear();
221                if let Some(response) = response.log_err() {
222                    for prompt in response.prompts {
223                        let name: SharedString = prompt.name.clone().into();
224                        registered_server.prompts.insert(
225                            name,
226                            ContextServerPrompt {
227                                server_id: server_id.clone(),
228                                prompt,
229                            },
230                        );
231                    }
232                    cx.emit(ContextServerRegistryEvent::PromptsChanged);
233                    cx.notify();
234                }
235            })
236        });
237    }
238
239    fn handle_context_server_store_event(
240        &mut self,
241        _: Entity<ContextServerStore>,
242        event: &project::context_server_store::Event,
243        cx: &mut Context<Self>,
244    ) {
245        match event {
246            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
247                match status {
248                    ContextServerStatus::Starting => {}
249                    ContextServerStatus::Running => {
250                        self.reload_tools_for_server(server_id.clone(), cx);
251                        self.reload_prompts_for_server(server_id.clone(), cx);
252                    }
253                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
254                        if let Some(registered_server) = self.registered_servers.remove(server_id) {
255                            if !registered_server.tools.is_empty() {
256                                cx.emit(ContextServerRegistryEvent::ToolsChanged);
257                            }
258                            if !registered_server.prompts.is_empty() {
259                                cx.emit(ContextServerRegistryEvent::PromptsChanged);
260                            }
261                        }
262                        cx.notify();
263                    }
264                }
265            }
266        }
267    }
268}
269
270struct ContextServerTool {
271    store: Entity<ContextServerStore>,
272    server_id: ContextServerId,
273    tool: context_server::types::Tool,
274}
275
276impl ContextServerTool {
277    fn new(
278        store: Entity<ContextServerStore>,
279        server_id: ContextServerId,
280        tool: context_server::types::Tool,
281    ) -> Self {
282        Self {
283            store,
284            server_id,
285            tool,
286        }
287    }
288}
289
290impl AnyAgentTool for ContextServerTool {
291    fn name(&self) -> SharedString {
292        self.tool.name.clone().into()
293    }
294
295    fn description(&self) -> SharedString {
296        self.tool.description.clone().unwrap_or_default().into()
297    }
298
299    fn kind(&self) -> ToolKind {
300        ToolKind::Other
301    }
302
303    fn initial_title(&self, _input: serde_json::Value, _cx: &mut App) -> SharedString {
304        format!("Run MCP tool `{}`", self.tool.name).into()
305    }
306
307    fn input_schema(
308        &self,
309        format: language_model::LanguageModelToolSchemaFormat,
310    ) -> Result<serde_json::Value> {
311        let mut schema = self.tool.input_schema.clone();
312        language_model::tool_schema::adapt_schema_to_format(&mut schema, format)?;
313        Ok(match schema {
314            serde_json::Value::Null => {
315                serde_json::json!({ "type": "object", "properties": [] })
316            }
317            serde_json::Value::Object(map) if map.is_empty() => {
318                serde_json::json!({ "type": "object", "properties": [] })
319            }
320            _ => schema,
321        })
322    }
323
324    fn run(
325        self: Arc<Self>,
326        input: serde_json::Value,
327        event_stream: ToolCallEventStream,
328        cx: &mut App,
329    ) -> Task<Result<AgentToolOutput>> {
330        let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
331            return Task::ready(Err(anyhow!("Context server not found")));
332        };
333        let tool_name = self.tool.name.clone();
334        let authorize = event_stream.authorize(self.initial_title(input.clone(), cx), cx);
335
336        cx.spawn(async move |_cx| {
337            authorize.await?;
338
339            let Some(protocol) = server.client() else {
340                bail!("Context server not initialized");
341            };
342
343            let arguments = if let serde_json::Value::Object(map) = input {
344                Some(map.into_iter().collect())
345            } else {
346                None
347            };
348
349            log::trace!(
350                "Running tool: {} with arguments: {:?}",
351                tool_name,
352                arguments
353            );
354            let response = protocol
355                .request::<context_server::types::requests::CallTool>(
356                    context_server::types::CallToolParams {
357                        name: tool_name,
358                        arguments,
359                        meta: None,
360                    },
361                )
362                .await?;
363
364            let mut result = String::new();
365            for content in response.content {
366                match content {
367                    context_server::types::ToolResponseContent::Text { text } => {
368                        result.push_str(&text);
369                    }
370                    context_server::types::ToolResponseContent::Image { .. } => {
371                        log::warn!("Ignoring image content from tool response");
372                    }
373                    context_server::types::ToolResponseContent::Audio { .. } => {
374                        log::warn!("Ignoring audio content from tool response");
375                    }
376                    context_server::types::ToolResponseContent::Resource { .. } => {
377                        log::warn!("Ignoring resource content from tool response");
378                    }
379                }
380            }
381            Ok(AgentToolOutput {
382                raw_output: result.clone().into(),
383                llm_output: result.into(),
384            })
385        })
386    }
387
388    fn replay(
389        &self,
390        _input: serde_json::Value,
391        _output: serde_json::Value,
392        _event_stream: ToolCallEventStream,
393        _cx: &mut App,
394    ) -> Result<()> {
395        Ok(())
396    }
397}
398
399pub fn get_prompt(
400    server_store: &Entity<ContextServerStore>,
401    server_id: &ContextServerId,
402    prompt_name: &str,
403    arguments: HashMap<String, String>,
404    cx: &mut AsyncApp,
405) -> Task<Result<context_server::types::PromptsGetResponse>> {
406    let server = match cx.update(|cx| server_store.read(cx).get_running_server(server_id)) {
407        Ok(server) => server,
408        Err(error) => return Task::ready(Err(error)),
409    };
410    let Some(server) = server else {
411        return Task::ready(Err(anyhow::anyhow!("Context server not found")));
412    };
413
414    let Some(protocol) = server.client() else {
415        return Task::ready(Err(anyhow::anyhow!("Context server not initialized")));
416    };
417
418    let prompt_name = prompt_name.to_string();
419
420    cx.background_spawn(async move {
421        let response = protocol
422            .request::<context_server::types::requests::PromptsGet>(
423                context_server::types::PromptsGetParams {
424                    name: prompt_name,
425                    arguments: (!arguments.is_empty()).then(|| arguments),
426                    meta: None,
427                },
428            )
429            .await?;
430
431        Ok(response)
432    })
433}