1use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
2use agent_client_protocol::ToolKind;
3use anyhow::{Result, anyhow};
4use collections::{BTreeMap, HashMap};
5use context_server::{ContextServerId, client::NotificationSubscription};
6use futures::FutureExt as _;
7use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task};
8use project::context_server_store::{ContextServerStatus, ContextServerStore};
9use std::sync::Arc;
10use util::ResultExt;
11
/// A prompt advertised by a context (MCP) server, tagged with the id of the
/// server that reported it.
pub struct ContextServerPrompt {
    /// The server whose prompt listing contained this prompt.
    pub server_id: ContextServerId,
    /// The prompt definition as returned by the server.
    pub prompt: context_server::types::Prompt,
}
16
/// Events emitted by `ContextServerRegistry` when the tools or prompts it
/// exposes for a server are reloaded or removed.
pub enum ContextServerRegistryEvent {
    /// A server's registered tools were reloaded or dropped.
    ToolsChanged,
    /// A server's registered prompts were reloaded or dropped.
    PromptsChanged,
}
21
// Allows the registry to `cx.emit` `ContextServerRegistryEvent`s to its subscribers.
impl EventEmitter<ContextServerRegistryEvent> for ContextServerRegistry {}
23
/// Tracks the tools and prompts exposed by each running context server.
///
/// The registry subscribes to a `ContextServerStore`. It (re)loads a server's
/// tools and prompts when the server reports `Running`, and drops them when the
/// server stops or errors.
pub struct ContextServerRegistry {
    /// Store used to look up running servers and their clients.
    server_store: Entity<ContextServerStore>,
    /// Per-server state, keyed by server id. Entries are created lazily the
    /// first time tools or prompts are loaded for a server.
    registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
    /// Keeps the subscription to `server_store` events alive.
    _subscription: gpui::Subscription,
}
29
/// Per-server state held by `ContextServerRegistry`.
struct RegisteredContextServer {
    /// Tools from the server's most recent tool listing, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Prompts from the server's most recent prompt listing, keyed by prompt name.
    prompts: BTreeMap<SharedString, ContextServerPrompt>,
    /// The latest tool-loading task; assigning a new one drops the previous task.
    load_tools: Task<Result<()>>,
    /// The latest prompt-loading task; assigning a new one drops the previous task.
    load_prompts: Task<Result<()>>,
    /// Keeps the `notifications/tools/list_changed` handler registered while this
    /// entry exists. `None` if, at registration time, the server was not running,
    /// had no client, or did not advertise the tools capability.
    _tools_updated_subscription: Option<NotificationSubscription>,
}
37
38impl ContextServerRegistry {
39 pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
40 let mut this = Self {
41 server_store: server_store.clone(),
42 registered_servers: HashMap::default(),
43 _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
44 };
45 for server in server_store.read(cx).running_servers() {
46 this.reload_tools_for_server(server.id(), cx);
47 this.reload_prompts_for_server(server.id(), cx);
48 }
49 this
50 }
51
52 pub fn tools_for_server(
53 &self,
54 server_id: &ContextServerId,
55 ) -> impl Iterator<Item = &Arc<dyn AnyAgentTool>> {
56 self.registered_servers
57 .get(server_id)
58 .map(|server| server.tools.values())
59 .into_iter()
60 .flatten()
61 }
62
63 pub fn servers(
64 &self,
65 ) -> impl Iterator<
66 Item = (
67 &ContextServerId,
68 &BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
69 ),
70 > {
71 self.registered_servers
72 .iter()
73 .map(|(id, server)| (id, &server.tools))
74 }
75
76 pub fn prompts(&self) -> impl Iterator<Item = &ContextServerPrompt> {
77 self.registered_servers
78 .values()
79 .flat_map(|server| server.prompts.values())
80 }
81
82 pub fn find_prompt(
83 &self,
84 server_id: Option<&ContextServerId>,
85 name: &str,
86 ) -> Option<&ContextServerPrompt> {
87 if let Some(server_id) = server_id {
88 self.registered_servers
89 .get(server_id)
90 .and_then(|server| server.prompts.get(name))
91 } else {
92 self.registered_servers
93 .values()
94 .find_map(|server| server.prompts.get(name))
95 }
96 }
97
98 pub fn server_store(&self) -> &Entity<ContextServerStore> {
99 &self.server_store
100 }
101
102 fn get_or_register_server(
103 &mut self,
104 server_id: &ContextServerId,
105 cx: &mut Context<Self>,
106 ) -> &mut RegisteredContextServer {
107 self.registered_servers
108 .entry(server_id.clone())
109 .or_insert_with(|| Self::init_registered_server(server_id, &self.server_store, cx))
110 }
111
112 fn init_registered_server(
113 server_id: &ContextServerId,
114 server_store: &Entity<ContextServerStore>,
115 cx: &mut Context<Self>,
116 ) -> RegisteredContextServer {
117 let tools_updated_subscription = server_store
118 .read(cx)
119 .get_running_server(server_id)
120 .and_then(|server| {
121 let client = server.client()?;
122
123 if !client.capable(context_server::protocol::ServerCapability::Tools) {
124 return None;
125 }
126
127 let server_id = server.id();
128 let this = cx.entity().downgrade();
129
130 Some(client.on_notification(
131 "notifications/tools/list_changed",
132 Box::new(move |_params, cx: AsyncApp| {
133 let server_id = server_id.clone();
134 let this = this.clone();
135 cx.spawn(async move |cx| {
136 this.update(cx, |this, cx| {
137 log::info!(
138 "Received tools/list_changed notification for server {}",
139 server_id
140 );
141 this.reload_tools_for_server(server_id, cx);
142 })
143 })
144 .detach();
145 }),
146 ))
147 });
148
149 RegisteredContextServer {
150 tools: BTreeMap::default(),
151 prompts: BTreeMap::default(),
152 load_tools: Task::ready(Ok(())),
153 load_prompts: Task::ready(Ok(())),
154 _tools_updated_subscription: tools_updated_subscription,
155 }
156 }
157
158 fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
159 let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
160 return;
161 };
162 let Some(client) = server.client() else {
163 return;
164 };
165
166 if !client.capable(context_server::protocol::ServerCapability::Tools) {
167 return;
168 }
169
170 let registered_server = self.get_or_register_server(&server_id, cx);
171 registered_server.load_tools = cx.spawn(async move |this, cx| {
172 let response = client
173 .request::<context_server::types::requests::ListTools>(())
174 .await;
175
176 this.update(cx, |this, cx| {
177 let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
178 return;
179 };
180
181 registered_server.tools.clear();
182 if let Some(response) = response.log_err() {
183 for tool in response.tools {
184 let tool = Arc::new(ContextServerTool::new(
185 this.server_store.clone(),
186 server.id(),
187 tool,
188 ));
189 registered_server.tools.insert(tool.name(), tool);
190 }
191 cx.emit(ContextServerRegistryEvent::ToolsChanged);
192 cx.notify();
193 }
194 })
195 });
196 }
197
198 fn reload_prompts_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
199 let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
200 return;
201 };
202 let Some(client) = server.client() else {
203 return;
204 };
205 if !client.capable(context_server::protocol::ServerCapability::Prompts) {
206 return;
207 }
208
209 let registered_server = self.get_or_register_server(&server_id, cx);
210
211 registered_server.load_prompts = cx.spawn(async move |this, cx| {
212 let response = client
213 .request::<context_server::types::requests::PromptsList>(())
214 .await;
215
216 this.update(cx, |this, cx| {
217 let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
218 return;
219 };
220
221 registered_server.prompts.clear();
222 if let Some(response) = response.log_err() {
223 for prompt in response.prompts {
224 let name: SharedString = prompt.name.clone().into();
225 registered_server.prompts.insert(
226 name,
227 ContextServerPrompt {
228 server_id: server_id.clone(),
229 prompt,
230 },
231 );
232 }
233 cx.emit(ContextServerRegistryEvent::PromptsChanged);
234 cx.notify();
235 }
236 })
237 });
238 }
239
240 fn handle_context_server_store_event(
241 &mut self,
242 _: Entity<ContextServerStore>,
243 event: &project::context_server_store::Event,
244 cx: &mut Context<Self>,
245 ) {
246 match event {
247 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
248 match status {
249 ContextServerStatus::Starting => {}
250 ContextServerStatus::Running => {
251 self.reload_tools_for_server(server_id.clone(), cx);
252 self.reload_prompts_for_server(server_id.clone(), cx);
253 }
254 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
255 if let Some(registered_server) = self.registered_servers.remove(server_id) {
256 if !registered_server.tools.is_empty() {
257 cx.emit(ContextServerRegistryEvent::ToolsChanged);
258 }
259 if !registered_server.prompts.is_empty() {
260 cx.emit(ContextServerRegistryEvent::PromptsChanged);
261 }
262 }
263 cx.notify();
264 }
265 }
266 }
267 }
268 }
269}
270
/// An `AnyAgentTool` that forwards invocations to a tool exposed by a context server.
struct ContextServerTool {
    /// Store used to find the running server each time the tool is invoked.
    store: Entity<ContextServerStore>,
    /// The server that advertised this tool.
    server_id: ContextServerId,
    /// The tool definition reported by the server.
    tool: context_server::types::Tool,
}
276
277impl ContextServerTool {
278 fn new(
279 store: Entity<ContextServerStore>,
280 server_id: ContextServerId,
281 tool: context_server::types::Tool,
282 ) -> Self {
283 Self {
284 store,
285 server_id,
286 tool,
287 }
288 }
289}
290
291impl AnyAgentTool for ContextServerTool {
292 fn name(&self) -> SharedString {
293 self.tool.name.clone().into()
294 }
295
296 fn description(&self) -> SharedString {
297 self.tool.description.clone().unwrap_or_default().into()
298 }
299
300 fn kind(&self) -> ToolKind {
301 ToolKind::Other
302 }
303
304 fn initial_title(&self, _input: serde_json::Value, _cx: &mut App) -> SharedString {
305 format!("Run MCP tool `{}`", self.tool.name).into()
306 }
307
308 fn input_schema(
309 &self,
310 format: language_model::LanguageModelToolSchemaFormat,
311 ) -> Result<serde_json::Value> {
312 let mut schema = self.tool.input_schema.clone();
313 language_model::tool_schema::adapt_schema_to_format(&mut schema, format)?;
314 Ok(match schema {
315 serde_json::Value::Null => {
316 serde_json::json!({ "type": "object", "properties": [] })
317 }
318 serde_json::Value::Object(map) if map.is_empty() => {
319 serde_json::json!({ "type": "object", "properties": [] })
320 }
321 _ => schema,
322 })
323 }
324
325 fn run(
326 self: Arc<Self>,
327 input: serde_json::Value,
328 event_stream: ToolCallEventStream,
329 cx: &mut App,
330 ) -> Task<Result<AgentToolOutput>> {
331 let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
332 return Task::ready(Err(anyhow!("Context server not found")));
333 };
334 let tool_name = self.tool.name.clone();
335 let authorize = event_stream.authorize(self.initial_title(input.clone(), cx), cx);
336
337 cx.spawn(async move |_cx| {
338 authorize.await?;
339
340 let Some(protocol) = server.client() else {
341 anyhow::bail!("Context server not initialized");
342 };
343
344 let arguments = if let serde_json::Value::Object(map) = input {
345 Some(map.into_iter().collect())
346 } else {
347 None
348 };
349
350 log::trace!(
351 "Running tool: {} with arguments: {:?}",
352 tool_name,
353 arguments
354 );
355
356 let request = protocol.request::<context_server::types::requests::CallTool>(
357 context_server::types::CallToolParams {
358 name: tool_name,
359 arguments,
360 meta: None,
361 },
362 );
363
364 let response = futures::select! {
365 response = request.fuse() => response?,
366 _ = event_stream.cancelled_by_user().fuse() => {
367 anyhow::bail!("MCP tool cancelled by user");
368 }
369 };
370
371 let mut result = String::new();
372 for content in response.content {
373 match content {
374 context_server::types::ToolResponseContent::Text { text } => {
375 result.push_str(&text);
376 }
377 context_server::types::ToolResponseContent::Image { .. } => {
378 log::warn!("Ignoring image content from tool response");
379 }
380 context_server::types::ToolResponseContent::Audio { .. } => {
381 log::warn!("Ignoring audio content from tool response");
382 }
383 context_server::types::ToolResponseContent::Resource { .. } => {
384 log::warn!("Ignoring resource content from tool response");
385 }
386 }
387 }
388 Ok(AgentToolOutput {
389 raw_output: result.clone().into(),
390 llm_output: result.into(),
391 })
392 })
393 }
394
395 fn replay(
396 &self,
397 _input: serde_json::Value,
398 _output: serde_json::Value,
399 _event_stream: ToolCallEventStream,
400 _cx: &mut App,
401 ) -> Result<()> {
402 Ok(())
403 }
404}
405
406pub fn get_prompt(
407 server_store: &Entity<ContextServerStore>,
408 server_id: &ContextServerId,
409 prompt_name: &str,
410 arguments: HashMap<String, String>,
411 cx: &mut AsyncApp,
412) -> Task<Result<context_server::types::PromptsGetResponse>> {
413 let server = cx.update(|cx| server_store.read(cx).get_running_server(server_id));
414 let Some(server) = server else {
415 return Task::ready(Err(anyhow::anyhow!("Context server not found")));
416 };
417
418 let Some(protocol) = server.client() else {
419 return Task::ready(Err(anyhow::anyhow!("Context server not initialized")));
420 };
421
422 let prompt_name = prompt_name.to_string();
423
424 cx.background_spawn(async move {
425 let response = protocol
426 .request::<context_server::types::requests::PromptsGet>(
427 context_server::types::PromptsGetParams {
428 name: prompt_name,
429 arguments: (!arguments.is_empty()).then(|| arguments),
430 meta: None,
431 },
432 )
433 .await?;
434
435 Ok(response)
436 })
437}