1use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
2use agent_client_protocol::ToolKind;
3use anyhow::{Result, anyhow, bail};
4use collections::{BTreeMap, HashMap};
5use context_server::ContextServerId;
6use gpui::{App, Context, Entity, SharedString, Task};
7use project::context_server_store::{ContextServerStatus, ContextServerStore};
8use std::sync::Arc;
9use util::ResultExt;
10
11pub struct ContextServerRegistry {
12 server_store: Entity<ContextServerStore>,
13 registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
14 _subscription: gpui::Subscription,
15}
16
17struct RegisteredContextServer {
18 tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
19 load_tools: Task<Result<()>>,
20}
21
22impl ContextServerRegistry {
23 pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
24 let mut this = Self {
25 server_store: server_store.clone(),
26 registered_servers: HashMap::default(),
27 _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
28 };
29 for server in server_store.read(cx).running_servers() {
30 this.reload_tools_for_server(server.id(), cx);
31 }
32 this
33 }
34
35 pub fn servers(
36 &self,
37 ) -> impl Iterator<
38 Item = (
39 &ContextServerId,
40 &BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
41 ),
42 > {
43 self.registered_servers
44 .iter()
45 .map(|(id, server)| (id, &server.tools))
46 }
47
48 fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
49 let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
50 return;
51 };
52 let Some(client) = server.client() else {
53 return;
54 };
55 if !client.capable(context_server::protocol::ServerCapability::Tools) {
56 return;
57 }
58
59 let registered_server =
60 self.registered_servers
61 .entry(server_id.clone())
62 .or_insert(RegisteredContextServer {
63 tools: BTreeMap::default(),
64 load_tools: Task::ready(Ok(())),
65 });
66 registered_server.load_tools = cx.spawn(async move |this, cx| {
67 let response = client
68 .request::<context_server::types::requests::ListTools>(())
69 .await;
70
71 this.update(cx, |this, cx| {
72 let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
73 return;
74 };
75
76 registered_server.tools.clear();
77 if let Some(response) = response.log_err() {
78 for tool in response.tools {
79 let tool = Arc::new(ContextServerTool::new(
80 this.server_store.clone(),
81 server.id(),
82 tool,
83 ));
84 registered_server.tools.insert(tool.name(), tool);
85 }
86 cx.notify();
87 }
88 })
89 });
90 }
91
92 fn handle_context_server_store_event(
93 &mut self,
94 _: Entity<ContextServerStore>,
95 event: &project::context_server_store::Event,
96 cx: &mut Context<Self>,
97 ) {
98 match event {
99 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
100 match status {
101 ContextServerStatus::Starting => {}
102 ContextServerStatus::Running => {
103 self.reload_tools_for_server(server_id.clone(), cx);
104 }
105 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
106 self.registered_servers.remove(server_id);
107 cx.notify();
108 }
109 }
110 }
111 }
112 }
113}
114
115struct ContextServerTool {
116 store: Entity<ContextServerStore>,
117 server_id: ContextServerId,
118 tool: context_server::types::Tool,
119}
120
121impl ContextServerTool {
122 fn new(
123 store: Entity<ContextServerStore>,
124 server_id: ContextServerId,
125 tool: context_server::types::Tool,
126 ) -> Self {
127 Self {
128 store,
129 server_id,
130 tool,
131 }
132 }
133}
134
135impl AnyAgentTool for ContextServerTool {
136 fn name(&self) -> SharedString {
137 self.tool.name.clone().into()
138 }
139
140 fn description(&self) -> SharedString {
141 self.tool.description.clone().unwrap_or_default().into()
142 }
143
144 fn kind(&self) -> ToolKind {
145 ToolKind::Other
146 }
147
148 fn initial_title(&self, _input: serde_json::Value, _cx: &mut App) -> SharedString {
149 format!("Run MCP tool `{}`", self.tool.name).into()
150 }
151
152 fn input_schema(
153 &self,
154 format: language_model::LanguageModelToolSchemaFormat,
155 ) -> Result<serde_json::Value> {
156 let mut schema = self.tool.input_schema.clone();
157 assistant_tool::adapt_schema_to_format(&mut schema, format)?;
158 Ok(match schema {
159 serde_json::Value::Null => {
160 serde_json::json!({ "type": "object", "properties": [] })
161 }
162 serde_json::Value::Object(map) if map.is_empty() => {
163 serde_json::json!({ "type": "object", "properties": [] })
164 }
165 _ => schema,
166 })
167 }
168
169 fn run(
170 self: Arc<Self>,
171 input: serde_json::Value,
172 event_stream: ToolCallEventStream,
173 cx: &mut App,
174 ) -> Task<Result<AgentToolOutput>> {
175 let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
176 return Task::ready(Err(anyhow!("Context server not found")));
177 };
178 let tool_name = self.tool.name.clone();
179 let authorize = event_stream.authorize(self.initial_title(input.clone(), cx), cx);
180
181 cx.spawn(async move |_cx| {
182 authorize.await?;
183
184 let Some(protocol) = server.client() else {
185 bail!("Context server not initialized");
186 };
187
188 let arguments = if let serde_json::Value::Object(map) = input {
189 Some(map.into_iter().collect())
190 } else {
191 None
192 };
193
194 log::trace!(
195 "Running tool: {} with arguments: {:?}",
196 tool_name,
197 arguments
198 );
199 let response = protocol
200 .request::<context_server::types::requests::CallTool>(
201 context_server::types::CallToolParams {
202 name: tool_name,
203 arguments,
204 meta: None,
205 },
206 )
207 .await?;
208
209 let mut result = String::new();
210 for content in response.content {
211 match content {
212 context_server::types::ToolResponseContent::Text { text } => {
213 result.push_str(&text);
214 }
215 context_server::types::ToolResponseContent::Image { .. } => {
216 log::warn!("Ignoring image content from tool response");
217 }
218 context_server::types::ToolResponseContent::Audio { .. } => {
219 log::warn!("Ignoring audio content from tool response");
220 }
221 context_server::types::ToolResponseContent::Resource { .. } => {
222 log::warn!("Ignoring resource content from tool response");
223 }
224 }
225 }
226 Ok(AgentToolOutput {
227 raw_output: result.clone().into(),
228 llm_output: result.into(),
229 })
230 })
231 }
232
233 fn replay(
234 &self,
235 _input: serde_json::Value,
236 _output: serde_json::Value,
237 _event_stream: ToolCallEventStream,
238 _cx: &mut App,
239 ) -> Result<()> {
240 Ok(())
241 }
242}