1use crate::AgentServerCommand;
2use acp_thread::AgentConnection;
3use acp_tools::AcpConnectionRegistry;
4use action_log::ActionLog;
5use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
6use anyhow::anyhow;
7use collections::HashMap;
8use futures::AsyncBufReadExt as _;
9use futures::channel::oneshot;
10use futures::io::BufReader;
11use project::Project;
12use serde::Deserialize;
13use std::{any::Any, cell::RefCell};
14use std::{path::Path, rc::Rc};
15use thiserror::Error;
16
17use anyhow::{Context as _, Result};
18use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
19
20use acp_thread::{AcpThread, AuthRequired, LoadError};
21
22#[derive(Debug, Error)]
23#[error("Unsupported version")]
24pub struct UnsupportedVersion;
25
/// A connection to an external agent server that speaks the Agent Client Protocol (ACP)
/// over a child process's stdin/stdout (see `AcpConnection::stdio`).
pub struct AcpConnection {
    /// Name of the agent server; used when registering the connection with
    /// `AcpConnectionRegistry` and passed to every `AcpThread` this connection creates.
    server_name: SharedString,
    /// The ACP client-side connection used to send requests to the agent.
    connection: Rc<acp::ClientSideConnection>,
    /// Sessions created by `new_thread`, keyed by ACP session id. Shared with the
    /// `ClientDelegate` so agent-initiated requests can be routed to the right thread.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods advertised by the agent in its `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Prompt capabilities advertised by the agent in its `initialize` response.
    prompt_capabilities: acp::PromptCapabilities,
    /// Whether the agent advertised `rewind_session` support during `initialize`.
    supports_rewind: bool,
    /// Task driving the connection's I/O; stored here so it lives as long as the connection.
    _io_task: Task<Result<()>>,
}
35
36pub struct AcpSession {
37 thread: WeakEntity<AcpThread>,
38 suppress_abort_err: bool,
39}
40
41pub async fn connect(
42 server_name: SharedString,
43 command: AgentServerCommand,
44 root_dir: &Path,
45 cx: &mut AsyncApp,
46) -> Result<Rc<dyn AgentConnection>> {
47 let conn = AcpConnection::stdio(server_name, command.clone(), root_dir, cx).await?;
48 Ok(Rc::new(conn) as _)
49}
50
/// Oldest ACP protocol version accepted from an agent's `initialize` response; older agents
/// are rejected with [`UnsupportedVersion`].
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
52
53impl AcpConnection {
54 pub async fn stdio(
55 server_name: SharedString,
56 command: AgentServerCommand,
57 root_dir: &Path,
58 cx: &mut AsyncApp,
59 ) -> Result<Self> {
60 let mut child = util::command::new_smol_command(&command.path)
61 .args(command.args.iter().map(|arg| arg.as_str()))
62 .envs(command.env.iter().flatten())
63 .current_dir(root_dir)
64 .stdin(std::process::Stdio::piped())
65 .stdout(std::process::Stdio::piped())
66 .stderr(std::process::Stdio::piped())
67 .kill_on_drop(true)
68 .spawn()?;
69
70 let stdout = child.stdout.take().context("Failed to take stdout")?;
71 let stdin = child.stdin.take().context("Failed to take stdin")?;
72 let stderr = child.stderr.take().context("Failed to take stderr")?;
73 log::trace!("Spawned (pid: {})", child.id());
74
75 let sessions = Rc::new(RefCell::new(HashMap::default()));
76
77 let client = ClientDelegate {
78 sessions: sessions.clone(),
79 cx: cx.clone(),
80 };
81 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
82 let foreground_executor = cx.foreground_executor().clone();
83 move |fut| {
84 foreground_executor.spawn(fut).detach();
85 }
86 });
87
88 let io_task = cx.background_spawn(io_task);
89
90 cx.background_spawn(async move {
91 let mut stderr = BufReader::new(stderr);
92 let mut line = String::new();
93 while let Ok(n) = stderr.read_line(&mut line).await
94 && n > 0
95 {
96 log::warn!("agent stderr: {}", &line);
97 line.clear();
98 }
99 })
100 .detach();
101
102 cx.spawn({
103 let sessions = sessions.clone();
104 async move |cx| {
105 let status = child.status().await?;
106
107 for session in sessions.borrow().values() {
108 session
109 .thread
110 .update(cx, |thread, cx| {
111 thread.emit_load_error(LoadError::Exited { status }, cx)
112 })
113 .ok();
114 }
115
116 anyhow::Ok(())
117 }
118 })
119 .detach();
120
121 let connection = Rc::new(connection);
122
123 cx.update(|cx| {
124 AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
125 registry.set_active_connection(server_name.clone(), &connection, cx)
126 });
127 })?;
128
129 let response = connection
130 .initialize(acp::InitializeRequest {
131 protocol_version: acp::VERSION,
132 client_capabilities: acp::ClientCapabilities {
133 fs: acp::FileSystemCapability {
134 read_text_file: true,
135 write_text_file: true,
136 },
137 },
138 })
139 .await?;
140
141 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
142 return Err(UnsupportedVersion.into());
143 }
144
145 Ok(Self {
146 auth_methods: response.auth_methods,
147 connection,
148 server_name,
149 sessions,
150 prompt_capabilities: response.agent_capabilities.prompt_capabilities,
151 supports_rewind: response.agent_capabilities.rewind_session,
152 _io_task: io_task,
153 })
154 }
155}
156
157impl AgentConnection for AcpConnection {
158 fn new_thread(
159 self: Rc<Self>,
160 project: Entity<Project>,
161 cwd: &Path,
162 cx: &mut App,
163 ) -> Task<Result<Entity<AcpThread>>> {
164 let conn = self.connection.clone();
165 let sessions = self.sessions.clone();
166 let cwd = cwd.to_path_buf();
167 cx.spawn(async move |cx| {
168 let response = conn
169 .new_session(acp::NewSessionRequest {
170 mcp_servers: vec![],
171 cwd,
172 })
173 .await
174 .map_err(|err| {
175 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
176 let mut error = AuthRequired::new();
177
178 if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
179 error = error.with_description(err.message);
180 }
181
182 anyhow!(error)
183 } else {
184 anyhow!(err)
185 }
186 })?;
187
188 let session_id = response.session_id;
189 let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
190 let thread = cx.new(|cx| {
191 AcpThread::new(
192 self.server_name.clone(),
193 self.clone(),
194 project,
195 action_log,
196 session_id.clone(),
197 // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
198 watch::Receiver::constant(self.prompt_capabilities),
199 cx,
200 )
201 })?;
202
203 let session = AcpSession {
204 thread: thread.downgrade(),
205 suppress_abort_err: false,
206 };
207 sessions.borrow_mut().insert(session_id, session);
208
209 Ok(thread)
210 })
211 }
212
213 fn auth_methods(&self) -> &[acp::AuthMethod] {
214 &self.auth_methods
215 }
216
217 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
218 let conn = self.connection.clone();
219 cx.foreground_executor().spawn(async move {
220 let result = conn
221 .authenticate(acp::AuthenticateRequest {
222 method_id: method_id.clone(),
223 })
224 .await?;
225
226 Ok(result)
227 })
228 }
229
230 fn rewind(
231 &self,
232 session_id: &agent_client_protocol::SessionId,
233 _cx: &App,
234 ) -> Option<Rc<dyn acp_thread::AgentSessionRewind>> {
235 if !self.supports_rewind {
236 return None;
237 }
238 Some(Rc::new(AcpRewinder {
239 connection: self.connection.clone(),
240 session_id: session_id.clone(),
241 }) as _)
242 }
243
244 fn prompt(
245 &self,
246 params: acp::PromptRequest,
247 cx: &mut App,
248 ) -> Task<Result<acp::PromptResponse>> {
249 let conn = self.connection.clone();
250 let sessions = self.sessions.clone();
251 let session_id = params.session_id.clone();
252 cx.foreground_executor().spawn(async move {
253 let result = conn.prompt(params).await;
254
255 let mut suppress_abort_err = false;
256
257 if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
258 suppress_abort_err = session.suppress_abort_err;
259 session.suppress_abort_err = false;
260 }
261
262 match result {
263 Ok(response) => Ok(response),
264 Err(err) => {
265 if err.code != ErrorCode::INTERNAL_ERROR.code {
266 anyhow::bail!(err)
267 }
268
269 let Some(data) = &err.data else {
270 anyhow::bail!(err)
271 };
272
273 // Temporary workaround until the following PR is generally available:
274 // https://github.com/google-gemini/gemini-cli/pull/6656
275
276 #[derive(Deserialize)]
277 #[serde(deny_unknown_fields)]
278 struct ErrorDetails {
279 details: Box<str>,
280 }
281
282 match serde_json::from_value(data.clone()) {
283 Ok(ErrorDetails { details }) => {
284 if suppress_abort_err
285 && (details.contains("This operation was aborted")
286 || details.contains("The user aborted a request"))
287 {
288 Ok(acp::PromptResponse {
289 stop_reason: acp::StopReason::Cancelled,
290 })
291 } else {
292 Err(anyhow!(details))
293 }
294 }
295 Err(_) => Err(anyhow!(err)),
296 }
297 }
298 }
299 })
300 }
301
302 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
303 if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
304 session.suppress_abort_err = true;
305 }
306 let conn = self.connection.clone();
307 let params = acp::CancelNotification {
308 session_id: session_id.clone(),
309 };
310 cx.foreground_executor()
311 .spawn(async move { conn.cancel(params).await })
312 .detach();
313 }
314
315 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
316 self
317 }
318}
319
320struct AcpRewinder {
321 connection: Rc<acp::ClientSideConnection>,
322 session_id: acp::SessionId,
323}
324
325impl acp_thread::AgentSessionRewind for AcpRewinder {
326 fn rewind(&self, prompt_id: acp::PromptId, cx: &mut App) -> Task<Result<()>> {
327 let conn = self.connection.clone();
328 let params = acp::RewindRequest {
329 session_id: self.session_id.clone(),
330 prompt_id,
331 };
332 cx.foreground_executor().spawn(async move {
333 conn.rewind(params).await?;
334 Ok(())
335 })
336 }
337}
338
339struct ClientDelegate {
340 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
341 cx: AsyncApp,
342}
343
344impl acp::Client for ClientDelegate {
345 async fn request_permission(
346 &self,
347 arguments: acp::RequestPermissionRequest,
348 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
349 let cx = &mut self.cx.clone();
350 let rx = self
351 .sessions
352 .borrow()
353 .get(&arguments.session_id)
354 .context("Failed to get session")?
355 .thread
356 .update(cx, |thread, cx| {
357 thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
358 })?;
359
360 let result = rx?.await;
361
362 let outcome = match result {
363 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
364 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
365 };
366
367 Ok(acp::RequestPermissionResponse { outcome })
368 }
369
370 async fn write_text_file(
371 &self,
372 arguments: acp::WriteTextFileRequest,
373 ) -> Result<(), acp::Error> {
374 let cx = &mut self.cx.clone();
375 let task = self
376 .sessions
377 .borrow()
378 .get(&arguments.session_id)
379 .context("Failed to get session")?
380 .thread
381 .update(cx, |thread, cx| {
382 thread.write_text_file(arguments.path, arguments.content, cx)
383 })?;
384
385 task.await?;
386
387 Ok(())
388 }
389
390 async fn read_text_file(
391 &self,
392 arguments: acp::ReadTextFileRequest,
393 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
394 let cx = &mut self.cx.clone();
395 let task = self
396 .sessions
397 .borrow()
398 .get(&arguments.session_id)
399 .context("Failed to get session")?
400 .thread
401 .update(cx, |thread, cx| {
402 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
403 })?;
404
405 let content = task.await?;
406
407 Ok(acp::ReadTextFileResponse { content })
408 }
409
410 async fn session_notification(
411 &self,
412 notification: acp::SessionNotification,
413 ) -> Result<(), acp::Error> {
414 let cx = &mut self.cx.clone();
415 let sessions = self.sessions.borrow();
416 let session = sessions
417 .get(¬ification.session_id)
418 .context("Failed to get session")?;
419
420 session.thread.update(cx, |thread, cx| {
421 thread.handle_session_update(notification.update, cx)
422 })??;
423
424 Ok(())
425 }
426}