1use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
2use agent_client_protocol::ToolKind;
3use anyhow::{Result, anyhow, bail};
4use collections::{BTreeMap, HashMap};
5use context_server::{ContextServerId, client::NotificationSubscription};
6use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task};
7use project::context_server_store::{ContextServerStatus, ContextServerStore};
8use std::sync::Arc;
9use util::ResultExt;
10
/// A prompt advertised by a context server, paired with the id of the server
/// that listed it.
pub struct ContextServerPrompt {
    /// Server whose prompt listing contained this prompt.
    pub server_id: ContextServerId,
    /// Prompt definition exactly as returned by the server.
    pub prompt: context_server::types::Prompt,
}
15
/// Events emitted by `ContextServerRegistry` when its registered data changes.
pub enum ContextServerRegistryEvent {
    /// A server's tools were (re)loaded or removed.
    ToolsChanged,
    /// A server's prompts were (re)loaded or removed.
    PromptsChanged,
}
20
21impl EventEmitter<ContextServerRegistryEvent> for ContextServerRegistry {}
22
23pub struct ContextServerRegistry {
24 server_store: Entity<ContextServerStore>,
25 registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
26 _subscription: gpui::Subscription,
27}
28
/// Per-server state kept by `ContextServerRegistry`.
// Field order also fixes drop order; keep it stable.
struct RegisteredContextServer {
    /// Tools from the most recent tool listing, keyed by tool name
    /// (emptied when a listing fails).
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Prompts from the most recent prompt listing, keyed by prompt name
    /// (emptied when a listing fails).
    prompts: BTreeMap<SharedString, ContextServerPrompt>,
    /// Task loading `tools`; replaced each time a tool reload starts.
    load_tools: Task<Result<()>>,
    /// Task loading `prompts`; replaced each time a prompt reload starts.
    load_prompts: Task<Result<()>>,
    /// Handler for `notifications/tools/list_changed`. `None` when, at
    /// registration time, the server was not running, had no client, or did
    /// not advertise the tools capability.
    _tools_updated_subscription: Option<NotificationSubscription>,
}
36
37impl ContextServerRegistry {
38 pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
39 let mut this = Self {
40 server_store: server_store.clone(),
41 registered_servers: HashMap::default(),
42 _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
43 };
44 for server in server_store.read(cx).running_servers() {
45 this.reload_tools_for_server(server.id(), cx);
46 this.reload_prompts_for_server(server.id(), cx);
47 }
48 this
49 }
50
51 pub fn tools_for_server(
52 &self,
53 server_id: &ContextServerId,
54 ) -> impl Iterator<Item = &Arc<dyn AnyAgentTool>> {
55 self.registered_servers
56 .get(server_id)
57 .map(|server| server.tools.values())
58 .into_iter()
59 .flatten()
60 }
61
62 pub fn servers(
63 &self,
64 ) -> impl Iterator<
65 Item = (
66 &ContextServerId,
67 &BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
68 ),
69 > {
70 self.registered_servers
71 .iter()
72 .map(|(id, server)| (id, &server.tools))
73 }
74
75 pub fn prompts(&self) -> impl Iterator<Item = &ContextServerPrompt> {
76 self.registered_servers
77 .values()
78 .flat_map(|server| server.prompts.values())
79 }
80
81 pub fn find_prompt(
82 &self,
83 server_id: Option<&ContextServerId>,
84 name: &str,
85 ) -> Option<&ContextServerPrompt> {
86 if let Some(server_id) = server_id {
87 self.registered_servers
88 .get(server_id)
89 .and_then(|server| server.prompts.get(name))
90 } else {
91 self.registered_servers
92 .values()
93 .find_map(|server| server.prompts.get(name))
94 }
95 }
96
97 pub fn server_store(&self) -> &Entity<ContextServerStore> {
98 &self.server_store
99 }
100
101 fn get_or_register_server(
102 &mut self,
103 server_id: &ContextServerId,
104 cx: &mut Context<Self>,
105 ) -> &mut RegisteredContextServer {
106 self.registered_servers
107 .entry(server_id.clone())
108 .or_insert_with(|| Self::init_registered_server(server_id, &self.server_store, cx))
109 }
110
111 fn init_registered_server(
112 server_id: &ContextServerId,
113 server_store: &Entity<ContextServerStore>,
114 cx: &mut Context<Self>,
115 ) -> RegisteredContextServer {
116 let tools_updated_subscription = server_store
117 .read(cx)
118 .get_running_server(server_id)
119 .and_then(|server| {
120 let client = server.client()?;
121
122 if !client.capable(context_server::protocol::ServerCapability::Tools) {
123 return None;
124 }
125
126 let server_id = server.id();
127 let this = cx.entity().downgrade();
128
129 Some(client.on_notification(
130 "notifications/tools/list_changed",
131 Box::new(move |_params, cx: AsyncApp| {
132 let server_id = server_id.clone();
133 let this = this.clone();
134 cx.spawn(async move |cx| {
135 this.update(cx, |this, cx| {
136 log::info!(
137 "Received tools/list_changed notification for server {}",
138 server_id
139 );
140 this.reload_tools_for_server(server_id, cx);
141 })
142 })
143 .detach();
144 }),
145 ))
146 });
147
148 RegisteredContextServer {
149 tools: BTreeMap::default(),
150 prompts: BTreeMap::default(),
151 load_tools: Task::ready(Ok(())),
152 load_prompts: Task::ready(Ok(())),
153 _tools_updated_subscription: tools_updated_subscription,
154 }
155 }
156
157 fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
158 let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
159 return;
160 };
161 let Some(client) = server.client() else {
162 return;
163 };
164
165 if !client.capable(context_server::protocol::ServerCapability::Tools) {
166 return;
167 }
168
169 let registered_server = self.get_or_register_server(&server_id, cx);
170 registered_server.load_tools = cx.spawn(async move |this, cx| {
171 let response = client
172 .request::<context_server::types::requests::ListTools>(())
173 .await;
174
175 this.update(cx, |this, cx| {
176 let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
177 return;
178 };
179
180 registered_server.tools.clear();
181 if let Some(response) = response.log_err() {
182 for tool in response.tools {
183 let tool = Arc::new(ContextServerTool::new(
184 this.server_store.clone(),
185 server.id(),
186 tool,
187 ));
188 registered_server.tools.insert(tool.name(), tool);
189 }
190 cx.emit(ContextServerRegistryEvent::ToolsChanged);
191 cx.notify();
192 }
193 })
194 });
195 }
196
197 fn reload_prompts_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
198 let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
199 return;
200 };
201 let Some(client) = server.client() else {
202 return;
203 };
204 if !client.capable(context_server::protocol::ServerCapability::Prompts) {
205 return;
206 }
207
208 let registered_server = self.get_or_register_server(&server_id, cx);
209
210 registered_server.load_prompts = cx.spawn(async move |this, cx| {
211 let response = client
212 .request::<context_server::types::requests::PromptsList>(())
213 .await;
214
215 this.update(cx, |this, cx| {
216 let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
217 return;
218 };
219
220 registered_server.prompts.clear();
221 if let Some(response) = response.log_err() {
222 for prompt in response.prompts {
223 let name: SharedString = prompt.name.clone().into();
224 registered_server.prompts.insert(
225 name,
226 ContextServerPrompt {
227 server_id: server_id.clone(),
228 prompt,
229 },
230 );
231 }
232 cx.emit(ContextServerRegistryEvent::PromptsChanged);
233 cx.notify();
234 }
235 })
236 });
237 }
238
239 fn handle_context_server_store_event(
240 &mut self,
241 _: Entity<ContextServerStore>,
242 event: &project::context_server_store::Event,
243 cx: &mut Context<Self>,
244 ) {
245 match event {
246 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
247 match status {
248 ContextServerStatus::Starting => {}
249 ContextServerStatus::Running => {
250 self.reload_tools_for_server(server_id.clone(), cx);
251 self.reload_prompts_for_server(server_id.clone(), cx);
252 }
253 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
254 if let Some(registered_server) = self.registered_servers.remove(server_id) {
255 if !registered_server.tools.is_empty() {
256 cx.emit(ContextServerRegistryEvent::ToolsChanged);
257 }
258 if !registered_server.prompts.is_empty() {
259 cx.emit(ContextServerRegistryEvent::PromptsChanged);
260 }
261 }
262 cx.notify();
263 }
264 }
265 }
266 }
267 }
268}
269
/// Exposes a single tool advertised by a context server through `AnyAgentTool`.
struct ContextServerTool {
    /// Store used to look the server up again each time the tool runs.
    store: Entity<ContextServerStore>,
    /// Server that advertised the tool.
    server_id: ContextServerId,
    /// Tool definition (name, description, input schema) as reported by the server.
    tool: context_server::types::Tool,
}
275
276impl ContextServerTool {
277 fn new(
278 store: Entity<ContextServerStore>,
279 server_id: ContextServerId,
280 tool: context_server::types::Tool,
281 ) -> Self {
282 Self {
283 store,
284 server_id,
285 tool,
286 }
287 }
288}
289
290impl AnyAgentTool for ContextServerTool {
291 fn name(&self) -> SharedString {
292 self.tool.name.clone().into()
293 }
294
295 fn description(&self) -> SharedString {
296 self.tool.description.clone().unwrap_or_default().into()
297 }
298
299 fn kind(&self) -> ToolKind {
300 ToolKind::Other
301 }
302
303 fn initial_title(&self, _input: serde_json::Value, _cx: &mut App) -> SharedString {
304 format!("Run MCP tool `{}`", self.tool.name).into()
305 }
306
307 fn input_schema(
308 &self,
309 format: language_model::LanguageModelToolSchemaFormat,
310 ) -> Result<serde_json::Value> {
311 let mut schema = self.tool.input_schema.clone();
312 language_model::tool_schema::adapt_schema_to_format(&mut schema, format)?;
313 Ok(match schema {
314 serde_json::Value::Null => {
315 serde_json::json!({ "type": "object", "properties": [] })
316 }
317 serde_json::Value::Object(map) if map.is_empty() => {
318 serde_json::json!({ "type": "object", "properties": [] })
319 }
320 _ => schema,
321 })
322 }
323
324 fn run(
325 self: Arc<Self>,
326 input: serde_json::Value,
327 event_stream: ToolCallEventStream,
328 cx: &mut App,
329 ) -> Task<Result<AgentToolOutput>> {
330 let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
331 return Task::ready(Err(anyhow!("Context server not found")));
332 };
333 let tool_name = self.tool.name.clone();
334 let authorize = event_stream.authorize(self.initial_title(input.clone(), cx), cx);
335
336 cx.spawn(async move |_cx| {
337 authorize.await?;
338
339 let Some(protocol) = server.client() else {
340 bail!("Context server not initialized");
341 };
342
343 let arguments = if let serde_json::Value::Object(map) = input {
344 Some(map.into_iter().collect())
345 } else {
346 None
347 };
348
349 log::trace!(
350 "Running tool: {} with arguments: {:?}",
351 tool_name,
352 arguments
353 );
354 let response = protocol
355 .request::<context_server::types::requests::CallTool>(
356 context_server::types::CallToolParams {
357 name: tool_name,
358 arguments,
359 meta: None,
360 },
361 )
362 .await?;
363
364 let mut result = String::new();
365 for content in response.content {
366 match content {
367 context_server::types::ToolResponseContent::Text { text } => {
368 result.push_str(&text);
369 }
370 context_server::types::ToolResponseContent::Image { .. } => {
371 log::warn!("Ignoring image content from tool response");
372 }
373 context_server::types::ToolResponseContent::Audio { .. } => {
374 log::warn!("Ignoring audio content from tool response");
375 }
376 context_server::types::ToolResponseContent::Resource { .. } => {
377 log::warn!("Ignoring resource content from tool response");
378 }
379 }
380 }
381 Ok(AgentToolOutput {
382 raw_output: result.clone().into(),
383 llm_output: result.into(),
384 })
385 })
386 }
387
388 fn replay(
389 &self,
390 _input: serde_json::Value,
391 _output: serde_json::Value,
392 _event_stream: ToolCallEventStream,
393 _cx: &mut App,
394 ) -> Result<()> {
395 Ok(())
396 }
397}
398
399pub fn get_prompt(
400 server_store: &Entity<ContextServerStore>,
401 server_id: &ContextServerId,
402 prompt_name: &str,
403 arguments: HashMap<String, String>,
404 cx: &mut AsyncApp,
405) -> Task<Result<context_server::types::PromptsGetResponse>> {
406 let server = match cx.update(|cx| server_store.read(cx).get_running_server(server_id)) {
407 Ok(server) => server,
408 Err(error) => return Task::ready(Err(error)),
409 };
410 let Some(server) = server else {
411 return Task::ready(Err(anyhow::anyhow!("Context server not found")));
412 };
413
414 let Some(protocol) = server.client() else {
415 return Task::ready(Err(anyhow::anyhow!("Context server not initialized")));
416 };
417
418 let prompt_name = prompt_name.to_string();
419
420 cx.background_spawn(async move {
421 let response = protocol
422 .request::<context_server::types::requests::PromptsGet>(
423 context_server::types::PromptsGetParams {
424 name: prompt_name,
425 arguments: (!arguments.is_empty()).then(|| arguments),
426 meta: None,
427 },
428 )
429 .await?;
430
431 Ok(response)
432 })
433}