1use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
2use agent_client_protocol::ToolKind;
3use anyhow::{Result, anyhow, bail};
4use collections::{BTreeMap, HashMap};
5use context_server::ContextServerId;
6use gpui::{App, Context, Entity, SharedString, Task};
7use project::context_server_store::{ContextServerStatus, ContextServerStore};
8use std::sync::Arc;
9use util::ResultExt;
10
11pub struct ContextServerRegistry {
12 server_store: Entity<ContextServerStore>,
13 registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
14 _subscription: gpui::Subscription,
15}
16
17struct RegisteredContextServer {
18 tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
19 load_tools: Task<Result<()>>,
20}
21
22impl ContextServerRegistry {
23 pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
24 let mut this = Self {
25 server_store: server_store.clone(),
26 registered_servers: HashMap::default(),
27 _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
28 };
29 for server in server_store.read(cx).running_servers() {
30 this.reload_tools_for_server(server.id(), cx);
31 }
32 this
33 }
34
35 pub fn tools_for_server(
36 &self,
37 server_id: &ContextServerId,
38 ) -> impl Iterator<Item = &Arc<dyn AnyAgentTool>> {
39 self.registered_servers
40 .get(server_id)
41 .map(|server| server.tools.values())
42 .into_iter()
43 .flatten()
44 }
45
46 pub fn servers(
47 &self,
48 ) -> impl Iterator<
49 Item = (
50 &ContextServerId,
51 &BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
52 ),
53 > {
54 self.registered_servers
55 .iter()
56 .map(|(id, server)| (id, &server.tools))
57 }
58
59 fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
60 let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
61 return;
62 };
63 let Some(client) = server.client() else {
64 return;
65 };
66 if !client.capable(context_server::protocol::ServerCapability::Tools) {
67 return;
68 }
69
70 let registered_server =
71 self.registered_servers
72 .entry(server_id.clone())
73 .or_insert(RegisteredContextServer {
74 tools: BTreeMap::default(),
75 load_tools: Task::ready(Ok(())),
76 });
77 registered_server.load_tools = cx.spawn(async move |this, cx| {
78 let response = client
79 .request::<context_server::types::requests::ListTools>(())
80 .await;
81
82 this.update(cx, |this, cx| {
83 let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
84 return;
85 };
86
87 registered_server.tools.clear();
88 if let Some(response) = response.log_err() {
89 for tool in response.tools {
90 let tool = Arc::new(ContextServerTool::new(
91 this.server_store.clone(),
92 server.id(),
93 tool,
94 ));
95 registered_server.tools.insert(tool.name(), tool);
96 }
97 cx.notify();
98 }
99 })
100 });
101 }
102
103 fn handle_context_server_store_event(
104 &mut self,
105 _: Entity<ContextServerStore>,
106 event: &project::context_server_store::Event,
107 cx: &mut Context<Self>,
108 ) {
109 match event {
110 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
111 match status {
112 ContextServerStatus::Starting => {}
113 ContextServerStatus::Running => {
114 self.reload_tools_for_server(server_id.clone(), cx);
115 }
116 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
117 self.registered_servers.remove(server_id);
118 cx.notify();
119 }
120 }
121 }
122 }
123 }
124}
125
126struct ContextServerTool {
127 store: Entity<ContextServerStore>,
128 server_id: ContextServerId,
129 tool: context_server::types::Tool,
130}
131
132impl ContextServerTool {
133 fn new(
134 store: Entity<ContextServerStore>,
135 server_id: ContextServerId,
136 tool: context_server::types::Tool,
137 ) -> Self {
138 Self {
139 store,
140 server_id,
141 tool,
142 }
143 }
144}
145
146impl AnyAgentTool for ContextServerTool {
147 fn name(&self) -> SharedString {
148 self.tool.name.clone().into()
149 }
150
151 fn description(&self) -> SharedString {
152 self.tool.description.clone().unwrap_or_default().into()
153 }
154
155 fn kind(&self) -> ToolKind {
156 ToolKind::Other
157 }
158
159 fn initial_title(&self, _input: serde_json::Value, _cx: &mut App) -> SharedString {
160 format!("Run MCP tool `{}`", self.tool.name).into()
161 }
162
163 fn input_schema(
164 &self,
165 format: language_model::LanguageModelToolSchemaFormat,
166 ) -> Result<serde_json::Value> {
167 let mut schema = self.tool.input_schema.clone();
168 crate::tool_schema::adapt_schema_to_format(&mut schema, format)?;
169 Ok(match schema {
170 serde_json::Value::Null => {
171 serde_json::json!({ "type": "object", "properties": [] })
172 }
173 serde_json::Value::Object(map) if map.is_empty() => {
174 serde_json::json!({ "type": "object", "properties": [] })
175 }
176 _ => schema,
177 })
178 }
179
180 fn run(
181 self: Arc<Self>,
182 input: serde_json::Value,
183 event_stream: ToolCallEventStream,
184 cx: &mut App,
185 ) -> Task<Result<AgentToolOutput>> {
186 let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
187 return Task::ready(Err(anyhow!("Context server not found")));
188 };
189 let tool_name = self.tool.name.clone();
190 let authorize = event_stream.authorize(self.initial_title(input.clone(), cx), cx);
191
192 cx.spawn(async move |_cx| {
193 authorize.await?;
194
195 let Some(protocol) = server.client() else {
196 bail!("Context server not initialized");
197 };
198
199 let arguments = if let serde_json::Value::Object(map) = input {
200 Some(map.into_iter().collect())
201 } else {
202 None
203 };
204
205 log::trace!(
206 "Running tool: {} with arguments: {:?}",
207 tool_name,
208 arguments
209 );
210 let response = protocol
211 .request::<context_server::types::requests::CallTool>(
212 context_server::types::CallToolParams {
213 name: tool_name,
214 arguments,
215 meta: None,
216 },
217 )
218 .await?;
219
220 let mut result = String::new();
221 for content in response.content {
222 match content {
223 context_server::types::ToolResponseContent::Text { text } => {
224 result.push_str(&text);
225 }
226 context_server::types::ToolResponseContent::Image { .. } => {
227 log::warn!("Ignoring image content from tool response");
228 }
229 context_server::types::ToolResponseContent::Audio { .. } => {
230 log::warn!("Ignoring audio content from tool response");
231 }
232 context_server::types::ToolResponseContent::Resource { .. } => {
233 log::warn!("Ignoring resource content from tool response");
234 }
235 }
236 }
237 Ok(AgentToolOutput {
238 raw_output: result.clone().into(),
239 llm_output: result.into(),
240 })
241 })
242 }
243
244 fn replay(
245 &self,
246 _input: serde_json::Value,
247 _output: serde_json::Value,
248 _event_stream: ToolCallEventStream,
249 _cx: &mut App,
250 ) -> Result<()> {
251 Ok(())
252 }
253}