1use agent_client_protocol as acp;
2use anyhow::anyhow;
3use collections::HashMap;
4use context_server::listener::McpServerTool;
5use context_server::types::requests;
6use context_server::{ContextServer, ContextServerCommand, ContextServerId};
7use futures::channel::{mpsc, oneshot};
8use project::Project;
9use smol::stream::StreamExt as _;
10use std::cell::RefCell;
11use std::rc::Rc;
12use std::{path::Path, sync::Arc};
13use util::ResultExt;
14
15use anyhow::{Context, Result};
16use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
17
18use crate::mcp_server::ZedMcpServer;
19use crate::{AgentServerCommand, mcp_server};
20use acp_thread::{AcpThread, AgentConnection, AuthRequired};
21
22pub struct AcpConnection {
23 auth_methods: Rc<RefCell<Vec<acp::AuthMethod>>>,
24 server_name: &'static str,
25 context_server: Arc<context_server::ContextServer>,
26 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
27 _session_update_task: Task<()>,
28}
29
30impl AcpConnection {
31 pub async fn stdio(
32 server_name: &'static str,
33 command: AgentServerCommand,
34 working_directory: Option<Arc<Path>>,
35 cx: &mut AsyncApp,
36 ) -> Result<Self> {
37 let context_server: Arc<ContextServer> = ContextServer::stdio(
38 ContextServerId(format!("{}-mcp-server", server_name).into()),
39 ContextServerCommand {
40 path: command.path,
41 args: command.args,
42 env: command.env,
43 },
44 working_directory,
45 )
46 .into();
47
48 let (notification_tx, mut notification_rx) = mpsc::unbounded();
49
50 let sessions = Rc::new(RefCell::new(HashMap::default()));
51
52 let session_update_handler_task = cx.spawn({
53 let sessions = sessions.clone();
54 async move |cx| {
55 while let Some(notification) = notification_rx.next().await {
56 Self::handle_session_notification(notification, sessions.clone(), cx)
57 }
58 }
59 });
60
61 context_server
62 .start_with_handlers(
63 vec![(acp::AGENT_METHODS.session_update, {
64 Box::new(move |notification, _cx| {
65 let notification_tx = notification_tx.clone();
66 log::trace!(
67 "ACP Notification: {}",
68 serde_json::to_string_pretty(¬ification).unwrap()
69 );
70
71 if let Some(notification) =
72 serde_json::from_value::<acp::SessionNotification>(notification)
73 .log_err()
74 {
75 notification_tx.unbounded_send(notification).ok();
76 }
77 })
78 })],
79 cx,
80 )
81 .await?;
82
83 Ok(Self {
84 auth_methods: Default::default(),
85 server_name,
86 context_server,
87 sessions,
88 _session_update_task: session_update_handler_task,
89 })
90 }
91
92 pub fn handle_session_notification(
93 notification: acp::SessionNotification,
94 threads: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
95 cx: &mut AsyncApp,
96 ) {
97 let threads = threads.borrow();
98 let Some(thread) = threads
99 .get(¬ification.session_id)
100 .and_then(|session| session.thread.upgrade())
101 else {
102 log::error!(
103 "Thread not found for session ID: {}",
104 notification.session_id
105 );
106 return;
107 };
108
109 thread
110 .update(cx, |thread, cx| {
111 thread.handle_session_update(notification.update, cx)
112 })
113 .log_err();
114 }
115}
116
/// Per-session state tracked by [`AcpConnection`].
// NOTE: field order is drop order; keep `_mcp_server` declared last unless the
// teardown order is intentionally changed.
pub struct AcpSession {
    /// Thread that receives this session's updates. It is held weakly, so the
    /// session map does not keep the thread alive.
    thread: WeakEntity<AcpThread>,
    /// Sender used by `cancel` to abort the most recent `prompt` request.
    /// `None` until a prompt is issued, and after `cancel` has taken it.
    cancel_tx: Option<oneshot::Sender<()>>,
    /// Zed-side MCP server advertised to the agent for client tools
    /// (permission requests, text file reads and writes). It is stored so it
    /// lives as long as the session.
    _mcp_server: ZedMcpServer,
}
122
123impl AgentConnection for AcpConnection {
124 fn new_thread(
125 self: Rc<Self>,
126 project: Entity<Project>,
127 cwd: &Path,
128 cx: &mut AsyncApp,
129 ) -> Task<Result<Entity<AcpThread>>> {
130 let client = self.context_server.client();
131 let sessions = self.sessions.clone();
132 let auth_methods = self.auth_methods.clone();
133 let cwd = cwd.to_path_buf();
134 cx.spawn(async move |cx| {
135 let client = client.context("MCP server is not initialized yet")?;
136 let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
137
138 let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
139
140 let response = client
141 .request::<requests::CallTool>(context_server::types::CallToolParams {
142 name: acp::AGENT_METHODS.new_session.into(),
143 arguments: Some(serde_json::to_value(acp::NewSessionArguments {
144 mcp_servers: vec![mcp_server.server_config()?],
145 client_tools: acp::ClientTools {
146 request_permission: Some(acp::McpToolId {
147 mcp_server: mcp_server::SERVER_NAME.into(),
148 tool_name: mcp_server::RequestPermissionTool::NAME.into(),
149 }),
150 read_text_file: Some(acp::McpToolId {
151 mcp_server: mcp_server::SERVER_NAME.into(),
152 tool_name: mcp_server::ReadTextFileTool::NAME.into(),
153 }),
154 write_text_file: Some(acp::McpToolId {
155 mcp_server: mcp_server::SERVER_NAME.into(),
156 tool_name: mcp_server::WriteTextFileTool::NAME.into(),
157 }),
158 },
159 cwd,
160 })?),
161 meta: None,
162 })
163 .await?;
164
165 if response.is_error.unwrap_or_default() {
166 return Err(anyhow!(response.text_contents()));
167 }
168
169 let result = serde_json::from_value::<acp::NewSessionOutput>(
170 response.structured_content.context("Empty response")?,
171 )?;
172
173 auth_methods.replace(result.auth_methods);
174
175 let Some(session_id) = result.session_id else {
176 anyhow::bail!(AuthRequired);
177 };
178
179 let thread = cx.new(|cx| {
180 AcpThread::new(
181 self.server_name,
182 self.clone(),
183 project,
184 session_id.clone(),
185 cx,
186 )
187 })?;
188
189 thread_tx.send(thread.downgrade())?;
190
191 let session = AcpSession {
192 thread: thread.downgrade(),
193 cancel_tx: None,
194 _mcp_server: mcp_server,
195 };
196 sessions.borrow_mut().insert(session_id, session);
197
198 Ok(thread)
199 })
200 }
201
202 fn auth_methods(&self) -> Vec<agent_client_protocol::AuthMethod> {
203 self.auth_methods.borrow().clone()
204 }
205
206 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
207 let client = self.context_server.client();
208 cx.foreground_executor().spawn(async move {
209 let params = acp::AuthenticateArguments { method_id };
210
211 let response = client
212 .context("MCP server is not initialized yet")?
213 .request::<requests::CallTool>(context_server::types::CallToolParams {
214 name: acp::AGENT_METHODS.authenticate.into(),
215 arguments: Some(serde_json::to_value(params)?),
216 meta: None,
217 })
218 .await?;
219
220 if response.is_error.unwrap_or_default() {
221 Err(anyhow!(response.text_contents()))
222 } else {
223 Ok(())
224 }
225 })
226 }
227
228 fn prompt(
229 &self,
230 params: agent_client_protocol::PromptArguments,
231 cx: &mut App,
232 ) -> Task<Result<()>> {
233 let client = self.context_server.client();
234 let sessions = self.sessions.clone();
235
236 cx.foreground_executor().spawn(async move {
237 let client = client.context("MCP server is not initialized yet")?;
238
239 let (new_cancel_tx, cancel_rx) = oneshot::channel();
240 {
241 let mut sessions = sessions.borrow_mut();
242 let session = sessions
243 .get_mut(¶ms.session_id)
244 .context("Session not found")?;
245 session.cancel_tx.replace(new_cancel_tx);
246 }
247
248 let result = client
249 .request_with::<requests::CallTool>(
250 context_server::types::CallToolParams {
251 name: acp::AGENT_METHODS.prompt.into(),
252 arguments: Some(serde_json::to_value(params)?),
253 meta: None,
254 },
255 Some(cancel_rx),
256 None,
257 )
258 .await;
259
260 if let Err(err) = &result
261 && err.is::<context_server::client::RequestCanceled>()
262 {
263 return Ok(());
264 }
265
266 let response = result?;
267
268 if response.is_error.unwrap_or_default() {
269 return Err(anyhow!(response.text_contents()));
270 }
271
272 Ok(())
273 })
274 }
275
276 fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
277 let mut sessions = self.sessions.borrow_mut();
278
279 if let Some(cancel_tx) = sessions
280 .get_mut(session_id)
281 .and_then(|session| session.cancel_tx.take())
282 {
283 cancel_tx.send(()).ok();
284 }
285 }
286}
287
288impl Drop for AcpConnection {
289 fn drop(&mut self) {
290 self.context_server.stop().log_err();
291 }
292}