context_server_registry.rs

  1use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
  2use agent_client_protocol::ToolKind;
  3use anyhow::{Result, anyhow, bail};
  4use collections::{BTreeMap, HashMap};
  5use context_server::ContextServerId;
  6use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task};
  7use project::context_server_store::{ContextServerStatus, ContextServerStore};
  8use std::sync::Arc;
  9use util::ResultExt;
 10
 11pub struct ContextServerPrompt {
 12    pub server_id: ContextServerId,
 13    pub prompt: context_server::types::Prompt,
 14}
 15
 16pub enum ContextServerRegistryEvent {
 17    ToolsChanged,
 18    PromptsChanged,
 19}
 20
 21impl EventEmitter<ContextServerRegistryEvent> for ContextServerRegistry {}
 22
 23pub struct ContextServerRegistry {
 24    server_store: Entity<ContextServerStore>,
 25    registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
 26    _subscription: gpui::Subscription,
 27}
 28
 29struct RegisteredContextServer {
 30    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 31    prompts: BTreeMap<SharedString, ContextServerPrompt>,
 32    load_tools: Task<Result<()>>,
 33    load_prompts: Task<Result<()>>,
 34}
 35
 36impl RegisteredContextServer {
 37    fn new() -> Self {
 38        Self {
 39            tools: BTreeMap::default(),
 40            prompts: BTreeMap::default(),
 41            load_tools: Task::ready(Ok(())),
 42            load_prompts: Task::ready(Ok(())),
 43        }
 44    }
 45}
 46
 47impl ContextServerRegistry {
 48    pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
 49        let mut this = Self {
 50            server_store: server_store.clone(),
 51            registered_servers: HashMap::default(),
 52            _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
 53        };
 54        for server in server_store.read(cx).running_servers() {
 55            this.reload_tools_for_server(server.id(), cx);
 56            this.reload_prompts_for_server(server.id(), cx);
 57        }
 58        this
 59    }
 60
 61    pub fn tools_for_server(
 62        &self,
 63        server_id: &ContextServerId,
 64    ) -> impl Iterator<Item = &Arc<dyn AnyAgentTool>> {
 65        self.registered_servers
 66            .get(server_id)
 67            .map(|server| server.tools.values())
 68            .into_iter()
 69            .flatten()
 70    }
 71
 72    pub fn servers(
 73        &self,
 74    ) -> impl Iterator<
 75        Item = (
 76            &ContextServerId,
 77            &BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 78        ),
 79    > {
 80        self.registered_servers
 81            .iter()
 82            .map(|(id, server)| (id, &server.tools))
 83    }
 84
 85    pub fn prompts(&self) -> impl Iterator<Item = &ContextServerPrompt> {
 86        self.registered_servers
 87            .values()
 88            .flat_map(|server| server.prompts.values())
 89    }
 90
 91    pub fn find_prompt(
 92        &self,
 93        server_id: Option<&ContextServerId>,
 94        name: &str,
 95    ) -> Option<&ContextServerPrompt> {
 96        if let Some(server_id) = server_id {
 97            self.registered_servers
 98                .get(server_id)
 99                .and_then(|server| server.prompts.get(name))
100        } else {
101            self.registered_servers
102                .values()
103                .find_map(|server| server.prompts.get(name))
104        }
105    }
106
107    pub fn server_store(&self) -> &Entity<ContextServerStore> {
108        &self.server_store
109    }
110
111    fn get_or_register_server(
112        &mut self,
113        server_id: &ContextServerId,
114    ) -> &mut RegisteredContextServer {
115        self.registered_servers
116            .entry(server_id.clone())
117            .or_insert_with(RegisteredContextServer::new)
118    }
119
120    fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
121        let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
122            return;
123        };
124        let Some(client) = server.client() else {
125            return;
126        };
127        if !client.capable(context_server::protocol::ServerCapability::Tools) {
128            return;
129        }
130
131        let registered_server = self.get_or_register_server(&server_id);
132        registered_server.load_tools = cx.spawn(async move |this, cx| {
133            let response = client
134                .request::<context_server::types::requests::ListTools>(())
135                .await;
136
137            this.update(cx, |this, cx| {
138                let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
139                    return;
140                };
141
142                registered_server.tools.clear();
143                if let Some(response) = response.log_err() {
144                    for tool in response.tools {
145                        let tool = Arc::new(ContextServerTool::new(
146                            this.server_store.clone(),
147                            server.id(),
148                            tool,
149                        ));
150                        registered_server.tools.insert(tool.name(), tool);
151                    }
152                    cx.emit(ContextServerRegistryEvent::ToolsChanged);
153                    cx.notify();
154                }
155            })
156        });
157    }
158
159    fn reload_prompts_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
160        let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
161            return;
162        };
163        let Some(client) = server.client() else {
164            return;
165        };
166        if !client.capable(context_server::protocol::ServerCapability::Prompts) {
167            return;
168        }
169
170        let registered_server = self.get_or_register_server(&server_id);
171
172        registered_server.load_prompts = cx.spawn(async move |this, cx| {
173            let response = client
174                .request::<context_server::types::requests::PromptsList>(())
175                .await;
176
177            this.update(cx, |this, cx| {
178                let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
179                    return;
180                };
181
182                registered_server.prompts.clear();
183                if let Some(response) = response.log_err() {
184                    for prompt in response.prompts {
185                        let name: SharedString = prompt.name.clone().into();
186                        registered_server.prompts.insert(
187                            name,
188                            ContextServerPrompt {
189                                server_id: server_id.clone(),
190                                prompt,
191                            },
192                        );
193                    }
194                    cx.emit(ContextServerRegistryEvent::PromptsChanged);
195                    cx.notify();
196                }
197            })
198        });
199    }
200
201    fn handle_context_server_store_event(
202        &mut self,
203        _: Entity<ContextServerStore>,
204        event: &project::context_server_store::Event,
205        cx: &mut Context<Self>,
206    ) {
207        match event {
208            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
209                match status {
210                    ContextServerStatus::Starting => {}
211                    ContextServerStatus::Running => {
212                        self.reload_tools_for_server(server_id.clone(), cx);
213                        self.reload_prompts_for_server(server_id.clone(), cx);
214                    }
215                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
216                        if let Some(registered_server) = self.registered_servers.remove(server_id) {
217                            if !registered_server.tools.is_empty() {
218                                cx.emit(ContextServerRegistryEvent::ToolsChanged);
219                            }
220                            if !registered_server.prompts.is_empty() {
221                                cx.emit(ContextServerRegistryEvent::PromptsChanged);
222                            }
223                        }
224                        cx.notify();
225                    }
226                }
227            }
228        }
229    }
230}
231
232struct ContextServerTool {
233    store: Entity<ContextServerStore>,
234    server_id: ContextServerId,
235    tool: context_server::types::Tool,
236}
237
238impl ContextServerTool {
239    fn new(
240        store: Entity<ContextServerStore>,
241        server_id: ContextServerId,
242        tool: context_server::types::Tool,
243    ) -> Self {
244        Self {
245            store,
246            server_id,
247            tool,
248        }
249    }
250}
251
252impl AnyAgentTool for ContextServerTool {
253    fn name(&self) -> SharedString {
254        self.tool.name.clone().into()
255    }
256
257    fn description(&self) -> SharedString {
258        self.tool.description.clone().unwrap_or_default().into()
259    }
260
261    fn kind(&self) -> ToolKind {
262        ToolKind::Other
263    }
264
265    fn initial_title(&self, _input: serde_json::Value, _cx: &mut App) -> SharedString {
266        format!("Run MCP tool `{}`", self.tool.name).into()
267    }
268
269    fn input_schema(
270        &self,
271        format: language_model::LanguageModelToolSchemaFormat,
272    ) -> Result<serde_json::Value> {
273        let mut schema = self.tool.input_schema.clone();
274        language_model::tool_schema::adapt_schema_to_format(&mut schema, format)?;
275        Ok(match schema {
276            serde_json::Value::Null => {
277                serde_json::json!({ "type": "object", "properties": [] })
278            }
279            serde_json::Value::Object(map) if map.is_empty() => {
280                serde_json::json!({ "type": "object", "properties": [] })
281            }
282            _ => schema,
283        })
284    }
285
286    fn run(
287        self: Arc<Self>,
288        input: serde_json::Value,
289        event_stream: ToolCallEventStream,
290        cx: &mut App,
291    ) -> Task<Result<AgentToolOutput>> {
292        let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
293            return Task::ready(Err(anyhow!("Context server not found")));
294        };
295        let tool_name = self.tool.name.clone();
296        let authorize = event_stream.authorize(self.initial_title(input.clone(), cx), cx);
297
298        cx.spawn(async move |_cx| {
299            authorize.await?;
300
301            let Some(protocol) = server.client() else {
302                bail!("Context server not initialized");
303            };
304
305            let arguments = if let serde_json::Value::Object(map) = input {
306                Some(map.into_iter().collect())
307            } else {
308                None
309            };
310
311            log::trace!(
312                "Running tool: {} with arguments: {:?}",
313                tool_name,
314                arguments
315            );
316            let response = protocol
317                .request::<context_server::types::requests::CallTool>(
318                    context_server::types::CallToolParams {
319                        name: tool_name,
320                        arguments,
321                        meta: None,
322                    },
323                )
324                .await?;
325
326            let mut result = String::new();
327            for content in response.content {
328                match content {
329                    context_server::types::ToolResponseContent::Text { text } => {
330                        result.push_str(&text);
331                    }
332                    context_server::types::ToolResponseContent::Image { .. } => {
333                        log::warn!("Ignoring image content from tool response");
334                    }
335                    context_server::types::ToolResponseContent::Audio { .. } => {
336                        log::warn!("Ignoring audio content from tool response");
337                    }
338                    context_server::types::ToolResponseContent::Resource { .. } => {
339                        log::warn!("Ignoring resource content from tool response");
340                    }
341                }
342            }
343            Ok(AgentToolOutput {
344                raw_output: result.clone().into(),
345                llm_output: result.into(),
346            })
347        })
348    }
349
350    fn replay(
351        &self,
352        _input: serde_json::Value,
353        _output: serde_json::Value,
354        _event_stream: ToolCallEventStream,
355        _cx: &mut App,
356    ) -> Result<()> {
357        Ok(())
358    }
359}
360
361pub fn get_prompt(
362    server_store: &Entity<ContextServerStore>,
363    server_id: &ContextServerId,
364    prompt_name: &str,
365    arguments: HashMap<String, String>,
366    cx: &mut AsyncApp,
367) -> Task<Result<context_server::types::PromptsGetResponse>> {
368    let server = match cx.update(|cx| server_store.read(cx).get_running_server(server_id)) {
369        Ok(server) => server,
370        Err(error) => return Task::ready(Err(error)),
371    };
372    let Some(server) = server else {
373        return Task::ready(Err(anyhow::anyhow!("Context server not found")));
374    };
375
376    let Some(protocol) = server.client() else {
377        return Task::ready(Err(anyhow::anyhow!("Context server not initialized")));
378    };
379
380    let prompt_name = prompt_name.to_string();
381
382    cx.background_spawn(async move {
383        let response = protocol
384            .request::<context_server::types::requests::PromptsGet>(
385                context_server::types::PromptsGetParams {
386                    name: prompt_name,
387                    arguments: (!arguments.is_empty()).then(|| arguments),
388                    meta: None,
389                },
390            )
391            .await?;
392
393        Ok(response)
394    })
395}