1use crate::AgentServerCommand;
2use acp_thread::AgentConnection;
3use acp_tools::AcpConnectionRegistry;
4use action_log::ActionLog;
5use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
6use anyhow::anyhow;
7use collections::HashMap;
8use futures::AsyncBufReadExt as _;
9use futures::io::BufReader;
10use project::Project;
11use serde::Deserialize;
12use util::ResultExt as _;
13
14use std::{any::Any, cell::RefCell};
15use std::{path::Path, rc::Rc};
16use thiserror::Error;
17
18use anyhow::{Context as _, Result};
19use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
20
21use acp_thread::{AcpThread, AuthRequired, LoadError};
22
23#[derive(Debug, Error)]
24#[error("Unsupported version")]
25pub struct UnsupportedVersion;
26
27pub struct AcpConnection {
28 server_name: SharedString,
29 connection: Rc<acp::ClientSideConnection>,
30 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
31 auth_methods: Vec<acp::AuthMethod>,
32 agent_capabilities: acp::AgentCapabilities,
33 // NB: Don't move this into the wait_task, since we need to ensure the process is
34 // killed on drop (setting kill_on_drop on the command seems to not always work).
35 child: smol::process::Child,
36 _io_task: Task<Result<()>>,
37 _wait_task: Task<Result<()>>,
38 _stderr_task: Task<Result<()>>,
39}
40
41pub struct AcpSession {
42 thread: WeakEntity<AcpThread>,
43 suppress_abort_err: bool,
44}
45
46pub async fn connect(
47 server_name: SharedString,
48 command: AgentServerCommand,
49 root_dir: &Path,
50 cx: &mut AsyncApp,
51) -> Result<Rc<dyn AgentConnection>> {
52 let conn = AcpConnection::stdio(server_name, command.clone(), root_dir, cx).await?;
53 Ok(Rc::new(conn) as _)
54}
55
56const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
57
58impl AcpConnection {
59 pub async fn stdio(
60 server_name: SharedString,
61 command: AgentServerCommand,
62 root_dir: &Path,
63 cx: &mut AsyncApp,
64 ) -> Result<Self> {
65 let mut child = util::command::new_smol_command(command.path)
66 .args(command.args.iter().map(|arg| arg.as_str()))
67 .envs(command.env.iter().flatten())
68 .current_dir(root_dir)
69 .stdin(std::process::Stdio::piped())
70 .stdout(std::process::Stdio::piped())
71 .stderr(std::process::Stdio::piped())
72 .spawn()?;
73
74 let stdout = child.stdout.take().context("Failed to take stdout")?;
75 let stdin = child.stdin.take().context("Failed to take stdin")?;
76 let stderr = child.stderr.take().context("Failed to take stderr")?;
77 log::trace!("Spawned (pid: {})", child.id());
78
79 let sessions = Rc::new(RefCell::new(HashMap::default()));
80
81 let client = ClientDelegate {
82 sessions: sessions.clone(),
83 cx: cx.clone(),
84 };
85 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
86 let foreground_executor = cx.foreground_executor().clone();
87 move |fut| {
88 foreground_executor.spawn(fut).detach();
89 }
90 });
91
92 let io_task = cx.background_spawn(io_task);
93
94 let stderr_task = cx.background_spawn(async move {
95 let mut stderr = BufReader::new(stderr);
96 let mut line = String::new();
97 while let Ok(n) = stderr.read_line(&mut line).await
98 && n > 0
99 {
100 log::warn!("agent stderr: {}", &line);
101 line.clear();
102 }
103 Ok(())
104 });
105
106 let wait_task = cx.spawn({
107 let sessions = sessions.clone();
108 let status_fut = child.status();
109 async move |cx| {
110 let status = status_fut.await?;
111
112 for session in sessions.borrow().values() {
113 session
114 .thread
115 .update(cx, |thread, cx| {
116 thread.emit_load_error(LoadError::Exited { status }, cx)
117 })
118 .ok();
119 }
120
121 anyhow::Ok(())
122 }
123 });
124
125 let connection = Rc::new(connection);
126
127 cx.update(|cx| {
128 AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
129 registry.set_active_connection(server_name.clone(), &connection, cx)
130 });
131 })?;
132
133 let response = connection
134 .initialize(acp::InitializeRequest {
135 protocol_version: acp::VERSION,
136 client_capabilities: acp::ClientCapabilities {
137 fs: acp::FileSystemCapability {
138 read_text_file: true,
139 write_text_file: true,
140 },
141 terminal: true,
142 },
143 })
144 .await?;
145
146 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
147 return Err(UnsupportedVersion.into());
148 }
149
150 Ok(Self {
151 auth_methods: response.auth_methods,
152 connection,
153 server_name,
154 sessions,
155 agent_capabilities: response.agent_capabilities,
156 _io_task: io_task,
157 _wait_task: wait_task,
158 _stderr_task: stderr_task,
159 child,
160 })
161 }
162
163 pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
164 &self.agent_capabilities.prompt_capabilities
165 }
166}
167
168impl Drop for AcpConnection {
169 fn drop(&mut self) {
170 // See the comment on the child field.
171 self.child.kill().log_err();
172 }
173}
174
175impl AgentConnection for AcpConnection {
176 fn new_thread(
177 self: Rc<Self>,
178 project: Entity<Project>,
179 cwd: &Path,
180 cx: &mut App,
181 ) -> Task<Result<Entity<AcpThread>>> {
182 let conn = self.connection.clone();
183 let sessions = self.sessions.clone();
184 let cwd = cwd.to_path_buf();
185 let context_server_store = project.read(cx).context_server_store().read(cx);
186 let mcp_servers = context_server_store
187 .configured_server_ids()
188 .iter()
189 .filter_map(|id| {
190 let configuration = context_server_store.configuration_for_server(id)?;
191 let command = configuration.command();
192 Some(acp::McpServer {
193 name: id.0.to_string(),
194 command: command.path.clone(),
195 args: command.args.clone(),
196 env: if let Some(env) = command.env.as_ref() {
197 env.iter()
198 .map(|(name, value)| acp::EnvVariable {
199 name: name.clone(),
200 value: value.clone(),
201 })
202 .collect()
203 } else {
204 vec![]
205 },
206 })
207 })
208 .collect();
209
210 cx.spawn(async move |cx| {
211 let response = conn
212 .new_session(acp::NewSessionRequest { mcp_servers, cwd })
213 .await
214 .map_err(|err| {
215 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
216 let mut error = AuthRequired::new();
217
218 if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
219 error = error.with_description(err.message);
220 }
221
222 anyhow!(error)
223 } else {
224 anyhow!(err)
225 }
226 })?;
227
228 let session_id = response.session_id;
229 let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
230 let thread = cx.new(|cx| {
231 AcpThread::new(
232 self.server_name.clone(),
233 self.clone(),
234 project,
235 action_log,
236 session_id.clone(),
237 // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
238 watch::Receiver::constant(self.agent_capabilities.prompt_capabilities),
239 cx,
240 )
241 })?;
242
243 let session = AcpSession {
244 thread: thread.downgrade(),
245 suppress_abort_err: false,
246 };
247 sessions.borrow_mut().insert(session_id, session);
248
249 Ok(thread)
250 })
251 }
252
253 fn auth_methods(&self) -> &[acp::AuthMethod] {
254 &self.auth_methods
255 }
256
257 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
258 let conn = self.connection.clone();
259 cx.foreground_executor().spawn(async move {
260 let result = conn
261 .authenticate(acp::AuthenticateRequest {
262 method_id: method_id.clone(),
263 })
264 .await?;
265
266 Ok(result)
267 })
268 }
269
270 fn prompt(
271 &self,
272 _id: Option<acp_thread::UserMessageId>,
273 params: acp::PromptRequest,
274 cx: &mut App,
275 ) -> Task<Result<acp::PromptResponse>> {
276 let conn = self.connection.clone();
277 let sessions = self.sessions.clone();
278 let session_id = params.session_id.clone();
279 cx.foreground_executor().spawn(async move {
280 let result = conn.prompt(params).await;
281
282 let mut suppress_abort_err = false;
283
284 if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
285 suppress_abort_err = session.suppress_abort_err;
286 session.suppress_abort_err = false;
287 }
288
289 match result {
290 Ok(response) => Ok(response),
291 Err(err) => {
292 if err.code != ErrorCode::INTERNAL_ERROR.code {
293 anyhow::bail!(err)
294 }
295
296 let Some(data) = &err.data else {
297 anyhow::bail!(err)
298 };
299
300 // Temporary workaround until the following PR is generally available:
301 // https://github.com/google-gemini/gemini-cli/pull/6656
302
303 #[derive(Deserialize)]
304 #[serde(deny_unknown_fields)]
305 struct ErrorDetails {
306 details: Box<str>,
307 }
308
309 match serde_json::from_value(data.clone()) {
310 Ok(ErrorDetails { details }) => {
311 if suppress_abort_err
312 && (details.contains("This operation was aborted")
313 || details.contains("The user aborted a request"))
314 {
315 Ok(acp::PromptResponse {
316 stop_reason: acp::StopReason::Cancelled,
317 })
318 } else {
319 Err(anyhow!(details))
320 }
321 }
322 Err(_) => Err(anyhow!(err)),
323 }
324 }
325 }
326 })
327 }
328
329 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
330 if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
331 session.suppress_abort_err = true;
332 }
333 let conn = self.connection.clone();
334 let params = acp::CancelNotification {
335 session_id: session_id.clone(),
336 };
337 cx.foreground_executor()
338 .spawn(async move { conn.cancel(params).await })
339 .detach();
340 }
341
342 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
343 self
344 }
345}
346
347struct ClientDelegate {
348 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
349 cx: AsyncApp,
350}
351
352impl acp::Client for ClientDelegate {
353 async fn request_permission(
354 &self,
355 arguments: acp::RequestPermissionRequest,
356 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
357 let cx = &mut self.cx.clone();
358
359 let task = self
360 .session_thread(&arguments.session_id)?
361 .update(cx, |thread, cx| {
362 thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
363 })??;
364
365 let outcome = task.await;
366
367 Ok(acp::RequestPermissionResponse { outcome })
368 }
369
370 async fn write_text_file(
371 &self,
372 arguments: acp::WriteTextFileRequest,
373 ) -> Result<(), acp::Error> {
374 let cx = &mut self.cx.clone();
375 let task = self
376 .session_thread(&arguments.session_id)?
377 .update(cx, |thread, cx| {
378 thread.write_text_file(arguments.path, arguments.content, cx)
379 })?;
380
381 task.await?;
382
383 Ok(())
384 }
385
386 async fn read_text_file(
387 &self,
388 arguments: acp::ReadTextFileRequest,
389 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
390 let task = self.session_thread(&arguments.session_id)?.update(
391 &mut self.cx.clone(),
392 |thread, cx| {
393 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
394 },
395 )?;
396
397 let content = task.await?;
398
399 Ok(acp::ReadTextFileResponse { content })
400 }
401
402 async fn session_notification(
403 &self,
404 notification: acp::SessionNotification,
405 ) -> Result<(), acp::Error> {
406 self.session_thread(¬ification.session_id)?
407 .update(&mut self.cx.clone(), |thread, cx| {
408 thread.handle_session_update(notification.update, cx)
409 })??;
410
411 Ok(())
412 }
413
414 async fn create_terminal(
415 &self,
416 args: acp::CreateTerminalRequest,
417 ) -> Result<acp::CreateTerminalResponse, acp::Error> {
418 let terminal = self
419 .session_thread(&args.session_id)?
420 .update(&mut self.cx.clone(), |thread, cx| {
421 thread.create_terminal(
422 args.command,
423 args.args,
424 args.env,
425 args.cwd,
426 args.output_byte_limit,
427 cx,
428 )
429 })?
430 .await?;
431 Ok(
432 terminal.read_with(&self.cx, |terminal, _| acp::CreateTerminalResponse {
433 terminal_id: terminal.id().clone(),
434 })?,
435 )
436 }
437
438 async fn kill_terminal(&self, args: acp::KillTerminalRequest) -> Result<(), acp::Error> {
439 self.session_thread(&args.session_id)?
440 .update(&mut self.cx.clone(), |thread, cx| {
441 thread.kill_terminal(args.terminal_id, cx)
442 })??;
443
444 Ok(())
445 }
446
447 async fn release_terminal(&self, args: acp::ReleaseTerminalRequest) -> Result<(), acp::Error> {
448 self.session_thread(&args.session_id)?
449 .update(&mut self.cx.clone(), |thread, cx| {
450 thread.release_terminal(args.terminal_id, cx)
451 })??;
452
453 Ok(())
454 }
455
456 async fn terminal_output(
457 &self,
458 args: acp::TerminalOutputRequest,
459 ) -> Result<acp::TerminalOutputResponse, acp::Error> {
460 self.session_thread(&args.session_id)?
461 .read_with(&mut self.cx.clone(), |thread, cx| {
462 let out = thread
463 .terminal(args.terminal_id)?
464 .read(cx)
465 .current_output(cx);
466
467 Ok(out)
468 })?
469 }
470
471 async fn wait_for_terminal_exit(
472 &self,
473 args: acp::WaitForTerminalExitRequest,
474 ) -> Result<acp::WaitForTerminalExitResponse, acp::Error> {
475 let exit_status = self
476 .session_thread(&args.session_id)?
477 .update(&mut self.cx.clone(), |thread, cx| {
478 anyhow::Ok(thread.terminal(args.terminal_id)?.read(cx).wait_for_exit())
479 })??
480 .await;
481
482 Ok(acp::WaitForTerminalExitResponse { exit_status })
483 }
484}
485
486impl ClientDelegate {
487 fn session_thread(&self, session_id: &acp::SessionId) -> Result<WeakEntity<AcpThread>> {
488 let sessions = self.sessions.borrow();
489 sessions
490 .get(session_id)
491 .context("Failed to get session")
492 .map(|session| session.thread.clone())
493 }
494}