1use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
2use agent_client_protocol::ToolKind;
3use anyhow::{Result, anyhow, bail};
4use collections::{BTreeMap, HashMap};
5use context_server::ContextServerId;
6use gpui::{App, Context, Entity, SharedString, Task};
7use project::context_server_store::{ContextServerStatus, ContextServerStore};
8use std::sync::Arc;
9use util::ResultExt;
10
/// Tracks the tools exposed by the context (MCP) servers of a [`ContextServerStore`].
///
/// A server's tools are (re)loaded when the store reports it as
/// [`ContextServerStatus::Running`], and its entry is discarded when the store reports
/// `Stopped` or `Error`.
pub struct ContextServerRegistry {
    /// Store that is queried for running servers and their clients.
    server_store: Entity<ContextServerStore>,
    /// Loaded tools per server; only servers that advertise the tools capability get an entry.
    registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
    /// Held for the registry's lifetime so store events keep being delivered.
    _subscription: gpui::Subscription,
}
16
/// Per-server state kept by [`ContextServerRegistry`].
struct RegisteredContextServer {
    /// Tools keyed by tool name; the `BTreeMap` keeps iteration in name order.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Most recently spawned `ListTools` load; assigning a new task drops the previous one.
    load_tools: Task<Result<()>>,
}
21
22impl ContextServerRegistry {
23 pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
24 let mut this = Self {
25 server_store: server_store.clone(),
26 registered_servers: HashMap::default(),
27 _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
28 };
29 for server in server_store.read(cx).running_servers() {
30 this.reload_tools_for_server(server.id(), cx);
31 }
32 this
33 }
34
35 pub fn servers(
36 &self,
37 ) -> impl Iterator<
38 Item = (
39 &ContextServerId,
40 &BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
41 ),
42 > {
43 self.registered_servers
44 .iter()
45 .map(|(id, server)| (id, &server.tools))
46 }
47
48 fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
49 let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
50 return;
51 };
52 let Some(client) = server.client() else {
53 return;
54 };
55 if !client.capable(context_server::protocol::ServerCapability::Tools) {
56 return;
57 }
58
59 let registered_server =
60 self.registered_servers
61 .entry(server_id.clone())
62 .or_insert(RegisteredContextServer {
63 tools: BTreeMap::default(),
64 load_tools: Task::ready(Ok(())),
65 });
66 registered_server.load_tools = cx.spawn(async move |this, cx| {
67 let response = client
68 .request::<context_server::types::requests::ListTools>(())
69 .await;
70
71 this.update(cx, |this, cx| {
72 let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
73 return;
74 };
75
76 registered_server.tools.clear();
77 if let Some(response) = response.log_err() {
78 for tool in response.tools {
79 let tool = Arc::new(ContextServerTool::new(
80 this.server_store.clone(),
81 server.id(),
82 tool,
83 ));
84 registered_server.tools.insert(tool.name(), tool);
85 }
86 cx.notify();
87 }
88 })
89 });
90 }
91
92 fn handle_context_server_store_event(
93 &mut self,
94 _: Entity<ContextServerStore>,
95 event: &project::context_server_store::Event,
96 cx: &mut Context<Self>,
97 ) {
98 match event {
99 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
100 match status {
101 ContextServerStatus::Starting => {}
102 ContextServerStatus::Running => {
103 self.reload_tools_for_server(server_id.clone(), cx);
104 }
105 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
106 self.registered_servers.remove(server_id);
107 cx.notify();
108 }
109 }
110 }
111 }
112 }
113}
114
/// An agent tool that forwards calls to one tool exposed by a context server.
struct ContextServerTool {
    /// Store used to look up the running server each time the tool is run.
    store: Entity<ContextServerStore>,
    /// Identifier of the server that exposes this tool.
    server_id: ContextServerId,
    /// Tool definition as reported in the server's `ListTools` response.
    tool: context_server::types::Tool,
}
120
121impl ContextServerTool {
122 fn new(
123 store: Entity<ContextServerStore>,
124 server_id: ContextServerId,
125 tool: context_server::types::Tool,
126 ) -> Self {
127 Self {
128 store,
129 server_id,
130 tool,
131 }
132 }
133}
134
135impl AnyAgentTool for ContextServerTool {
136 fn name(&self) -> SharedString {
137 self.tool.name.clone().into()
138 }
139
140 fn description(&self) -> SharedString {
141 self.tool.description.clone().unwrap_or_default().into()
142 }
143
144 fn kind(&self) -> ToolKind {
145 ToolKind::Other
146 }
147
148 fn initial_title(&self, _input: serde_json::Value) -> SharedString {
149 format!("Run MCP tool `{}`", self.tool.name).into()
150 }
151
152 fn input_schema(
153 &self,
154 format: language_model::LanguageModelToolSchemaFormat,
155 ) -> Result<serde_json::Value> {
156 let mut schema = self.tool.input_schema.clone();
157 assistant_tool::adapt_schema_to_format(&mut schema, format)?;
158 Ok(match schema {
159 serde_json::Value::Null => {
160 serde_json::json!({ "type": "object", "properties": [] })
161 }
162 serde_json::Value::Object(map) if map.is_empty() => {
163 serde_json::json!({ "type": "object", "properties": [] })
164 }
165 _ => schema,
166 })
167 }
168
169 fn run(
170 self: Arc<Self>,
171 input: serde_json::Value,
172 _event_stream: ToolCallEventStream,
173 cx: &mut App,
174 ) -> Task<Result<AgentToolOutput>> {
175 let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
176 return Task::ready(Err(anyhow!("Context server not found")));
177 };
178 let tool_name = self.tool.name.clone();
179 let server_clone = server.clone();
180 let input_clone = input.clone();
181
182 cx.spawn(async move |_cx| {
183 let Some(protocol) = server_clone.client() else {
184 bail!("Context server not initialized");
185 };
186
187 let arguments = if let serde_json::Value::Object(map) = input_clone {
188 Some(map.into_iter().collect())
189 } else {
190 None
191 };
192
193 log::trace!(
194 "Running tool: {} with arguments: {:?}",
195 tool_name,
196 arguments
197 );
198 let response = protocol
199 .request::<context_server::types::requests::CallTool>(
200 context_server::types::CallToolParams {
201 name: tool_name,
202 arguments,
203 meta: None,
204 },
205 )
206 .await?;
207
208 let mut result = String::new();
209 for content in response.content {
210 match content {
211 context_server::types::ToolResponseContent::Text { text } => {
212 result.push_str(&text);
213 }
214 context_server::types::ToolResponseContent::Image { .. } => {
215 log::warn!("Ignoring image content from tool response");
216 }
217 context_server::types::ToolResponseContent::Audio { .. } => {
218 log::warn!("Ignoring audio content from tool response");
219 }
220 context_server::types::ToolResponseContent::Resource { .. } => {
221 log::warn!("Ignoring resource content from tool response");
222 }
223 }
224 }
225 Ok(AgentToolOutput {
226 raw_output: result.clone().into(),
227 llm_output: result.into(),
228 })
229 })
230 }
231}