1use agent_client_protocol as acp;
2use anyhow::anyhow;
3use collections::HashMap;
4use context_server::listener::McpServerTool;
5use context_server::types::requests;
6use context_server::{ContextServer, ContextServerCommand, ContextServerId};
7use futures::channel::{mpsc, oneshot};
8use project::Project;
9use smol::stream::StreamExt as _;
10use std::cell::RefCell;
11use std::rc::Rc;
12use std::{path::Path, sync::Arc};
13use util::ResultExt;
14
15use anyhow::{Context, Result};
16use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
17
18use crate::mcp_server::ZedMcpServer;
19use crate::{AgentServerCommand, mcp_server};
20use acp_thread::{AcpThread, AgentConnection};
21
22pub struct AcpConnection {
23 server_name: &'static str,
24 client: Arc<context_server::ContextServer>,
25 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
26 _notification_handler_task: Task<()>,
27}
28
29impl AcpConnection {
30 pub async fn stdio(
31 server_name: &'static str,
32 command: AgentServerCommand,
33 cx: &mut AsyncApp,
34 ) -> Result<Self> {
35 let client: Arc<ContextServer> = ContextServer::stdio(
36 ContextServerId(format!("{}-mcp-server", server_name).into()),
37 ContextServerCommand {
38 path: command.path,
39 args: command.args,
40 env: command.env,
41 },
42 )
43 .into();
44 ContextServer::start(client.clone(), cx).await?;
45
46 let (notification_tx, mut notification_rx) = mpsc::unbounded();
47 client
48 .client()
49 .context("Failed to subscribe")?
50 .on_notification(acp::AGENT_METHODS.session_update, {
51 move |notification, _cx| {
52 let notification_tx = notification_tx.clone();
53 log::trace!(
54 "ACP Notification: {}",
55 serde_json::to_string_pretty(¬ification).unwrap()
56 );
57
58 if let Some(notification) =
59 serde_json::from_value::<acp::SessionNotification>(notification).log_err()
60 {
61 notification_tx.unbounded_send(notification).ok();
62 }
63 }
64 });
65
66 let sessions = Rc::new(RefCell::new(HashMap::default()));
67
68 let notification_handler_task = cx.spawn({
69 let sessions = sessions.clone();
70 async move |cx| {
71 while let Some(notification) = notification_rx.next().await {
72 Self::handle_session_notification(notification, sessions.clone(), cx)
73 }
74 }
75 });
76
77 Ok(Self {
78 server_name,
79 client,
80 sessions,
81 _notification_handler_task: notification_handler_task,
82 })
83 }
84
85 pub fn handle_session_notification(
86 notification: acp::SessionNotification,
87 threads: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
88 cx: &mut AsyncApp,
89 ) {
90 let threads = threads.borrow();
91 let Some(thread) = threads
92 .get(¬ification.session_id)
93 .and_then(|session| session.thread.upgrade())
94 else {
95 log::error!(
96 "Thread not found for session ID: {}",
97 notification.session_id
98 );
99 return;
100 };
101
102 thread
103 .update(cx, |thread, cx| {
104 thread.handle_session_update(notification.update, cx)
105 })
106 .log_err();
107 }
108}
109
110pub struct AcpSession {
111 thread: WeakEntity<AcpThread>,
112 cancel_tx: Option<oneshot::Sender<()>>,
113 _mcp_server: ZedMcpServer,
114}
115
116impl AgentConnection for AcpConnection {
117 fn new_thread(
118 self: Rc<Self>,
119 project: Entity<Project>,
120 cwd: &Path,
121 cx: &mut AsyncApp,
122 ) -> Task<Result<Entity<AcpThread>>> {
123 let client = self.client.client();
124 let sessions = self.sessions.clone();
125 let cwd = cwd.to_path_buf();
126 cx.spawn(async move |cx| {
127 let client = client.context("MCP server is not initialized yet")?;
128 let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
129
130 let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
131
132 let response = client
133 .request::<requests::CallTool>(context_server::types::CallToolParams {
134 name: acp::AGENT_METHODS.new_session.into(),
135 arguments: Some(serde_json::to_value(acp::NewSessionArguments {
136 mcp_servers: vec![mcp_server.server_config()?],
137 client_tools: acp::ClientTools {
138 request_permission: Some(acp::McpToolId {
139 mcp_server: mcp_server::SERVER_NAME.into(),
140 tool_name: mcp_server::RequestPermissionTool::NAME.into(),
141 }),
142 read_text_file: Some(acp::McpToolId {
143 mcp_server: mcp_server::SERVER_NAME.into(),
144 tool_name: mcp_server::ReadTextFileTool::NAME.into(),
145 }),
146 write_text_file: Some(acp::McpToolId {
147 mcp_server: mcp_server::SERVER_NAME.into(),
148 tool_name: mcp_server::WriteTextFileTool::NAME.into(),
149 }),
150 },
151 cwd,
152 })?),
153 meta: None,
154 })
155 .await?;
156
157 if response.is_error.unwrap_or_default() {
158 return Err(anyhow!(response.text_contents()));
159 }
160
161 let result = serde_json::from_value::<acp::NewSessionOutput>(
162 response.structured_content.context("Empty response")?,
163 )?;
164
165 let thread = cx.new(|cx| {
166 AcpThread::new(
167 self.server_name,
168 self.clone(),
169 project,
170 result.session_id.clone(),
171 cx,
172 )
173 })?;
174
175 thread_tx.send(thread.downgrade())?;
176
177 let session = AcpSession {
178 thread: thread.downgrade(),
179 cancel_tx: None,
180 _mcp_server: mcp_server,
181 };
182 sessions.borrow_mut().insert(result.session_id, session);
183
184 Ok(thread)
185 })
186 }
187
188 fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
189 Task::ready(Err(anyhow!("Authentication not supported")))
190 }
191
192 fn prompt(
193 &self,
194 params: agent_client_protocol::PromptArguments,
195 cx: &mut App,
196 ) -> Task<Result<()>> {
197 let client = self.client.client();
198 let sessions = self.sessions.clone();
199
200 cx.foreground_executor().spawn(async move {
201 let client = client.context("MCP server is not initialized yet")?;
202
203 let (new_cancel_tx, cancel_rx) = oneshot::channel();
204 {
205 let mut sessions = sessions.borrow_mut();
206 let session = sessions
207 .get_mut(¶ms.session_id)
208 .context("Session not found")?;
209 session.cancel_tx.replace(new_cancel_tx);
210 }
211
212 let result = client
213 .request_with::<requests::CallTool>(
214 context_server::types::CallToolParams {
215 name: acp::AGENT_METHODS.prompt.into(),
216 arguments: Some(serde_json::to_value(params)?),
217 meta: None,
218 },
219 Some(cancel_rx),
220 None,
221 )
222 .await;
223
224 if let Err(err) = &result
225 && err.is::<context_server::client::RequestCanceled>()
226 {
227 return Ok(());
228 }
229
230 let response = result?;
231
232 if response.is_error.unwrap_or_default() {
233 return Err(anyhow!(response.text_contents()));
234 }
235
236 Ok(())
237 })
238 }
239
240 fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
241 let mut sessions = self.sessions.borrow_mut();
242
243 if let Some(cancel_tx) = sessions
244 .get_mut(session_id)
245 .and_then(|session| session.cancel_tx.take())
246 {
247 cancel_tx.send(()).ok();
248 }
249 }
250}
251
252impl Drop for AcpConnection {
253 fn drop(&mut self) {
254 self.client.stop().log_err();
255 }
256}