1use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
2use agent_client_protocol::ToolKind;
3use anyhow::{Result, anyhow, bail};
4use collections::{BTreeMap, HashMap};
5use context_server::ContextServerId;
6use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task};
7use project::context_server_store::{ContextServerStatus, ContextServerStore};
8use std::sync::Arc;
9use util::ResultExt;
10
11pub struct ContextServerPrompt {
12 pub server_id: ContextServerId,
13 pub prompt: context_server::types::Prompt,
14}
15
/// Notifications emitted by [`ContextServerRegistry`] when its cached data changes.
pub enum ContextServerRegistryEvent {
    /// The cached tools of some server were reloaded or removed.
    ToolsChanged,
    /// The cached prompts of some server were reloaded or removed.
    PromptsChanged,
}
20
21impl EventEmitter<ContextServerRegistryEvent> for ContextServerRegistry {}
22
23pub struct ContextServerRegistry {
24 server_store: Entity<ContextServerStore>,
25 registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
26 _subscription: gpui::Subscription,
27}
28
29struct RegisteredContextServer {
30 tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
31 prompts: BTreeMap<SharedString, ContextServerPrompt>,
32 load_tools: Task<Result<()>>,
33 load_prompts: Task<Result<()>>,
34}
35
36impl RegisteredContextServer {
37 fn new() -> Self {
38 Self {
39 tools: BTreeMap::default(),
40 prompts: BTreeMap::default(),
41 load_tools: Task::ready(Ok(())),
42 load_prompts: Task::ready(Ok(())),
43 }
44 }
45}
46
47impl ContextServerRegistry {
48 pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
49 let mut this = Self {
50 server_store: server_store.clone(),
51 registered_servers: HashMap::default(),
52 _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
53 };
54 for server in server_store.read(cx).running_servers() {
55 this.reload_tools_for_server(server.id(), cx);
56 this.reload_prompts_for_server(server.id(), cx);
57 }
58 this
59 }
60
61 pub fn tools_for_server(
62 &self,
63 server_id: &ContextServerId,
64 ) -> impl Iterator<Item = &Arc<dyn AnyAgentTool>> {
65 self.registered_servers
66 .get(server_id)
67 .map(|server| server.tools.values())
68 .into_iter()
69 .flatten()
70 }
71
72 pub fn servers(
73 &self,
74 ) -> impl Iterator<
75 Item = (
76 &ContextServerId,
77 &BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
78 ),
79 > {
80 self.registered_servers
81 .iter()
82 .map(|(id, server)| (id, &server.tools))
83 }
84
85 pub fn prompts(&self) -> impl Iterator<Item = &ContextServerPrompt> {
86 self.registered_servers
87 .values()
88 .flat_map(|server| server.prompts.values())
89 }
90
91 pub fn find_prompt(
92 &self,
93 server_id: Option<&ContextServerId>,
94 name: &str,
95 ) -> Option<&ContextServerPrompt> {
96 if let Some(server_id) = server_id {
97 self.registered_servers
98 .get(server_id)
99 .and_then(|server| server.prompts.get(name))
100 } else {
101 self.registered_servers
102 .values()
103 .find_map(|server| server.prompts.get(name))
104 }
105 }
106
107 pub fn server_store(&self) -> &Entity<ContextServerStore> {
108 &self.server_store
109 }
110
111 fn get_or_register_server(
112 &mut self,
113 server_id: &ContextServerId,
114 ) -> &mut RegisteredContextServer {
115 self.registered_servers
116 .entry(server_id.clone())
117 .or_insert_with(RegisteredContextServer::new)
118 }
119
120 fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
121 let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
122 return;
123 };
124 let Some(client) = server.client() else {
125 return;
126 };
127 if !client.capable(context_server::protocol::ServerCapability::Tools) {
128 return;
129 }
130
131 let registered_server = self.get_or_register_server(&server_id);
132 registered_server.load_tools = cx.spawn(async move |this, cx| {
133 let response = client
134 .request::<context_server::types::requests::ListTools>(())
135 .await;
136
137 this.update(cx, |this, cx| {
138 let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
139 return;
140 };
141
142 registered_server.tools.clear();
143 if let Some(response) = response.log_err() {
144 for tool in response.tools {
145 let tool = Arc::new(ContextServerTool::new(
146 this.server_store.clone(),
147 server.id(),
148 tool,
149 ));
150 registered_server.tools.insert(tool.name(), tool);
151 }
152 cx.emit(ContextServerRegistryEvent::ToolsChanged);
153 cx.notify();
154 }
155 })
156 });
157 }
158
159 fn reload_prompts_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
160 let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
161 return;
162 };
163 let Some(client) = server.client() else {
164 return;
165 };
166 if !client.capable(context_server::protocol::ServerCapability::Prompts) {
167 return;
168 }
169
170 let registered_server = self.get_or_register_server(&server_id);
171
172 registered_server.load_prompts = cx.spawn(async move |this, cx| {
173 let response = client
174 .request::<context_server::types::requests::PromptsList>(())
175 .await;
176
177 this.update(cx, |this, cx| {
178 let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
179 return;
180 };
181
182 registered_server.prompts.clear();
183 if let Some(response) = response.log_err() {
184 for prompt in response.prompts {
185 let name: SharedString = prompt.name.clone().into();
186 registered_server.prompts.insert(
187 name,
188 ContextServerPrompt {
189 server_id: server_id.clone(),
190 prompt,
191 },
192 );
193 }
194 cx.emit(ContextServerRegistryEvent::PromptsChanged);
195 cx.notify();
196 }
197 })
198 });
199 }
200
201 fn handle_context_server_store_event(
202 &mut self,
203 _: Entity<ContextServerStore>,
204 event: &project::context_server_store::Event,
205 cx: &mut Context<Self>,
206 ) {
207 match event {
208 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
209 match status {
210 ContextServerStatus::Starting => {}
211 ContextServerStatus::Running => {
212 self.reload_tools_for_server(server_id.clone(), cx);
213 self.reload_prompts_for_server(server_id.clone(), cx);
214 }
215 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
216 if let Some(registered_server) = self.registered_servers.remove(server_id) {
217 if !registered_server.tools.is_empty() {
218 cx.emit(ContextServerRegistryEvent::ToolsChanged);
219 }
220 if !registered_server.prompts.is_empty() {
221 cx.emit(ContextServerRegistryEvent::PromptsChanged);
222 }
223 }
224 cx.notify();
225 }
226 }
227 }
228 }
229 }
230}
231
232struct ContextServerTool {
233 store: Entity<ContextServerStore>,
234 server_id: ContextServerId,
235 tool: context_server::types::Tool,
236}
237
238impl ContextServerTool {
239 fn new(
240 store: Entity<ContextServerStore>,
241 server_id: ContextServerId,
242 tool: context_server::types::Tool,
243 ) -> Self {
244 Self {
245 store,
246 server_id,
247 tool,
248 }
249 }
250}
251
252impl AnyAgentTool for ContextServerTool {
253 fn name(&self) -> SharedString {
254 self.tool.name.clone().into()
255 }
256
257 fn description(&self) -> SharedString {
258 self.tool.description.clone().unwrap_or_default().into()
259 }
260
261 fn kind(&self) -> ToolKind {
262 ToolKind::Other
263 }
264
265 fn initial_title(&self, _input: serde_json::Value, _cx: &mut App) -> SharedString {
266 format!("Run MCP tool `{}`", self.tool.name).into()
267 }
268
269 fn input_schema(
270 &self,
271 format: language_model::LanguageModelToolSchemaFormat,
272 ) -> Result<serde_json::Value> {
273 let mut schema = self.tool.input_schema.clone();
274 language_model::tool_schema::adapt_schema_to_format(&mut schema, format)?;
275 Ok(match schema {
276 serde_json::Value::Null => {
277 serde_json::json!({ "type": "object", "properties": [] })
278 }
279 serde_json::Value::Object(map) if map.is_empty() => {
280 serde_json::json!({ "type": "object", "properties": [] })
281 }
282 _ => schema,
283 })
284 }
285
286 fn run(
287 self: Arc<Self>,
288 input: serde_json::Value,
289 event_stream: ToolCallEventStream,
290 cx: &mut App,
291 ) -> Task<Result<AgentToolOutput>> {
292 let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
293 return Task::ready(Err(anyhow!("Context server not found")));
294 };
295 let tool_name = self.tool.name.clone();
296 let authorize = event_stream.authorize(self.initial_title(input.clone(), cx), cx);
297
298 cx.spawn(async move |_cx| {
299 authorize.await?;
300
301 let Some(protocol) = server.client() else {
302 bail!("Context server not initialized");
303 };
304
305 let arguments = if let serde_json::Value::Object(map) = input {
306 Some(map.into_iter().collect())
307 } else {
308 None
309 };
310
311 log::trace!(
312 "Running tool: {} with arguments: {:?}",
313 tool_name,
314 arguments
315 );
316 let response = protocol
317 .request::<context_server::types::requests::CallTool>(
318 context_server::types::CallToolParams {
319 name: tool_name,
320 arguments,
321 meta: None,
322 },
323 )
324 .await?;
325
326 let mut result = String::new();
327 for content in response.content {
328 match content {
329 context_server::types::ToolResponseContent::Text { text } => {
330 result.push_str(&text);
331 }
332 context_server::types::ToolResponseContent::Image { .. } => {
333 log::warn!("Ignoring image content from tool response");
334 }
335 context_server::types::ToolResponseContent::Audio { .. } => {
336 log::warn!("Ignoring audio content from tool response");
337 }
338 context_server::types::ToolResponseContent::Resource { .. } => {
339 log::warn!("Ignoring resource content from tool response");
340 }
341 }
342 }
343 Ok(AgentToolOutput {
344 raw_output: result.clone().into(),
345 llm_output: result.into(),
346 })
347 })
348 }
349
350 fn replay(
351 &self,
352 _input: serde_json::Value,
353 _output: serde_json::Value,
354 _event_stream: ToolCallEventStream,
355 _cx: &mut App,
356 ) -> Result<()> {
357 Ok(())
358 }
359}
360
361pub fn get_prompt(
362 server_store: &Entity<ContextServerStore>,
363 server_id: &ContextServerId,
364 prompt_name: &str,
365 arguments: HashMap<String, String>,
366 cx: &mut AsyncApp,
367) -> Task<Result<context_server::types::PromptsGetResponse>> {
368 let server = match cx.update(|cx| server_store.read(cx).get_running_server(server_id)) {
369 Ok(server) => server,
370 Err(error) => return Task::ready(Err(error)),
371 };
372 let Some(server) = server else {
373 return Task::ready(Err(anyhow::anyhow!("Context server not found")));
374 };
375
376 let Some(protocol) = server.client() else {
377 return Task::ready(Err(anyhow::anyhow!("Context server not initialized")));
378 };
379
380 let prompt_name = prompt_name.to_string();
381
382 cx.background_spawn(async move {
383 let response = protocol
384 .request::<context_server::types::requests::PromptsGet>(
385 context_server::types::PromptsGetParams {
386 name: prompt_name,
387 arguments: (!arguments.is_empty()).then(|| arguments),
388 meta: None,
389 },
390 )
391 .await?;
392
393 Ok(response)
394 })
395}