1use acp_thread::AgentConnection;
2use acp_tools::AcpConnectionRegistry;
3use action_log::ActionLog;
4use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
5use anyhow::anyhow;
6use collections::HashMap;
7use futures::AsyncBufReadExt as _;
8use futures::io::BufReader;
9use project::Project;
10use project::agent_server_store::AgentServerCommand;
11use serde::Deserialize;
12use util::ResultExt as _;
13
14use std::path::PathBuf;
15use std::{any::Any, cell::RefCell};
16use std::{path::Path, rc::Rc};
17use thiserror::Error;
18
19use anyhow::{Context as _, Result};
20use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
21
22use acp_thread::{AcpThread, AuthRequired, LoadError};
23
24#[derive(Debug, Error)]
25#[error("Unsupported version")]
26pub struct UnsupportedVersion;
27
/// A connection to an external agent process that speaks the Agent Client Protocol (ACP)
/// over the process's piped stdin/stdout.
///
/// NOTE: struct fields are dropped in declaration order, after `Drop::drop` has run (which
/// kills `child`). Re-check the teardown sequence before reordering fields.
pub struct AcpConnection {
    /// Name of the agent server. Passed to `AcpThread::new` and used when registering this
    /// connection with `AcpConnectionRegistry`.
    server_name: SharedString,
    /// Client side of the ACP connection, used to send requests to the agent.
    connection: Rc<acp::ClientSideConnection>,
    /// Sessions keyed by ACP session id. Shared with the `ClientDelegate` that serves
    /// requests coming from the agent.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods advertised in the agent's `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Capabilities advertised in the agent's `initialize` response.
    agent_capabilities: acp::AgentCapabilities,
    /// Mode configured in settings for new sessions; applied only if the agent offers it.
    default_mode: Option<acp::SessionModeId>,
    /// Directory given to `stdio`; used as the agent's working directory for local projects.
    root_dir: PathBuf,
    // NB: Don't move this into the wait_task, since we need to ensure the process is
    // killed on drop (setting kill_on_drop on the command seems to not always work).
    child: smol::process::Child,
    /// Drives the I/O future returned by `ClientSideConnection::new`.
    _io_task: Task<Result<()>>,
    /// Waits for the process to exit, then reports `LoadError::Exited` to each session's thread.
    _wait_task: Task<Result<()>>,
    /// Forwards each line the agent writes to stderr into the log.
    _stderr_task: Task<Result<()>>,
}
43
44pub struct AcpSession {
45 thread: WeakEntity<AcpThread>,
46 suppress_abort_err: bool,
47 session_modes: Option<Rc<RefCell<acp::SessionModeState>>>,
48}
49
50pub async fn connect(
51 server_name: SharedString,
52 command: AgentServerCommand,
53 root_dir: &Path,
54 default_mode: Option<acp::SessionModeId>,
55 is_remote: bool,
56 cx: &mut AsyncApp,
57) -> Result<Rc<dyn AgentConnection>> {
58 let conn = AcpConnection::stdio(
59 server_name,
60 command.clone(),
61 root_dir,
62 default_mode,
63 is_remote,
64 cx,
65 )
66 .await?;
67 Ok(Rc::new(conn) as _)
68}
69
70const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
71
72impl AcpConnection {
73 pub async fn stdio(
74 server_name: SharedString,
75 command: AgentServerCommand,
76 root_dir: &Path,
77 default_mode: Option<acp::SessionModeId>,
78 is_remote: bool,
79 cx: &mut AsyncApp,
80 ) -> Result<Self> {
81 let mut child = util::command::new_smol_command(command.path);
82 child
83 .args(command.args.iter().map(|arg| arg.as_str()))
84 .envs(command.env.iter().flatten())
85 .stdin(std::process::Stdio::piped())
86 .stdout(std::process::Stdio::piped())
87 .stderr(std::process::Stdio::piped());
88 if !is_remote {
89 child.current_dir(root_dir);
90 }
91 let mut child = child.spawn()?;
92
93 let stdout = child.stdout.take().context("Failed to take stdout")?;
94 let stdin = child.stdin.take().context("Failed to take stdin")?;
95 let stderr = child.stderr.take().context("Failed to take stderr")?;
96 log::trace!("Spawned (pid: {})", child.id());
97
98 let sessions = Rc::new(RefCell::new(HashMap::default()));
99
100 let client = ClientDelegate {
101 sessions: sessions.clone(),
102 cx: cx.clone(),
103 };
104 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
105 let foreground_executor = cx.foreground_executor().clone();
106 move |fut| {
107 foreground_executor.spawn(fut).detach();
108 }
109 });
110
111 let io_task = cx.background_spawn(io_task);
112
113 let stderr_task = cx.background_spawn(async move {
114 let mut stderr = BufReader::new(stderr);
115 let mut line = String::new();
116 while let Ok(n) = stderr.read_line(&mut line).await
117 && n > 0
118 {
119 log::warn!("agent stderr: {}", &line);
120 line.clear();
121 }
122 Ok(())
123 });
124
125 let wait_task = cx.spawn({
126 let sessions = sessions.clone();
127 let status_fut = child.status();
128 async move |cx| {
129 let status = status_fut.await?;
130
131 for session in sessions.borrow().values() {
132 session
133 .thread
134 .update(cx, |thread, cx| {
135 thread.emit_load_error(LoadError::Exited { status }, cx)
136 })
137 .ok();
138 }
139
140 anyhow::Ok(())
141 }
142 });
143
144 let connection = Rc::new(connection);
145
146 cx.update(|cx| {
147 AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
148 registry.set_active_connection(server_name.clone(), &connection, cx)
149 });
150 })?;
151
152 let response = connection
153 .initialize(acp::InitializeRequest {
154 protocol_version: acp::VERSION,
155 client_capabilities: acp::ClientCapabilities {
156 fs: acp::FileSystemCapability {
157 read_text_file: true,
158 write_text_file: true,
159 },
160 terminal: true,
161 },
162 })
163 .await?;
164
165 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
166 return Err(UnsupportedVersion.into());
167 }
168
169 Ok(Self {
170 auth_methods: response.auth_methods,
171 root_dir: root_dir.to_owned(),
172 connection,
173 server_name,
174 sessions,
175 agent_capabilities: response.agent_capabilities,
176 default_mode,
177 _io_task: io_task,
178 _wait_task: wait_task,
179 _stderr_task: stderr_task,
180 child,
181 })
182 }
183
184 pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
185 &self.agent_capabilities.prompt_capabilities
186 }
187
188 pub fn root_dir(&self) -> &Path {
189 &self.root_dir
190 }
191}
192
193impl Drop for AcpConnection {
194 fn drop(&mut self) {
195 // See the comment on the child field.
196 self.child.kill().log_err();
197 }
198}
199
200impl AgentConnection for AcpConnection {
201 fn new_thread(
202 self: Rc<Self>,
203 project: Entity<Project>,
204 cwd: &Path,
205 cx: &mut App,
206 ) -> Task<Result<Entity<AcpThread>>> {
207 let name = self.server_name.clone();
208 let conn = self.connection.clone();
209 let sessions = self.sessions.clone();
210 let default_mode = self.default_mode.clone();
211 let cwd = cwd.to_path_buf();
212 let context_server_store = project.read(cx).context_server_store().read(cx);
213 let mcp_servers = if project.read(cx).is_local() {
214 context_server_store
215 .configured_server_ids()
216 .iter()
217 .filter_map(|id| {
218 let configuration = context_server_store.configuration_for_server(id)?;
219 let command = configuration.command();
220 Some(acp::McpServer::Stdio {
221 name: id.0.to_string(),
222 command: command.path.clone(),
223 args: command.args.clone(),
224 env: if let Some(env) = command.env.as_ref() {
225 env.iter()
226 .map(|(name, value)| acp::EnvVariable {
227 name: name.clone(),
228 value: value.clone(),
229 })
230 .collect()
231 } else {
232 vec![]
233 },
234 })
235 })
236 .collect()
237 } else {
238 // In SSH projects, the external agent is running on the remote
239 // machine, and currently we only run MCP servers on the local
240 // machine. So don't pass any MCP servers to the agent in that case.
241 Vec::new()
242 };
243
244 cx.spawn(async move |cx| {
245 let response = conn
246 .new_session(acp::NewSessionRequest { mcp_servers, cwd })
247 .await
248 .map_err(|err| {
249 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
250 let mut error = AuthRequired::new();
251
252 if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
253 error = error.with_description(err.message);
254 }
255
256 anyhow!(error)
257 } else {
258 anyhow!(err)
259 }
260 })?;
261
262 let modes = response.modes.map(|modes| Rc::new(RefCell::new(modes)));
263
264 if let Some(default_mode) = default_mode {
265 if let Some(modes) = modes.as_ref() {
266 let mut modes_ref = modes.borrow_mut();
267 let has_mode = modes_ref.available_modes.iter().any(|mode| mode.id == default_mode);
268
269 if has_mode {
270 let initial_mode_id = modes_ref.current_mode_id.clone();
271
272 cx.spawn({
273 let default_mode = default_mode.clone();
274 let session_id = response.session_id.clone();
275 let modes = modes.clone();
276 async move |_| {
277 let result = conn.set_session_mode(acp::SetSessionModeRequest {
278 session_id,
279 mode_id: default_mode,
280 })
281 .await.log_err();
282
283 if result.is_none() {
284 modes.borrow_mut().current_mode_id = initial_mode_id;
285 }
286 }
287 }).detach();
288
289 modes_ref.current_mode_id = default_mode;
290 } else {
291 let available_modes = modes_ref
292 .available_modes
293 .iter()
294 .map(|mode| format!("- `{}`: {}", mode.id, mode.name))
295 .collect::<Vec<_>>()
296 .join("\n");
297
298 log::warn!(
299 "`{default_mode}` is not valid {name} mode. Available options:\n{available_modes}",
300 );
301 }
302 } else {
303 log::warn!(
304 "`{name}` does not support modes, but `default_mode` was set in settings.",
305 );
306 }
307 }
308
309 let session_id = response.session_id;
310 let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
311 let thread = cx.new(|cx| {
312 AcpThread::new(
313 self.server_name.clone(),
314 self.clone(),
315 project,
316 action_log,
317 session_id.clone(),
318 // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
319 watch::Receiver::constant(self.agent_capabilities.prompt_capabilities),
320 cx,
321 )
322 })?;
323
324 let session = AcpSession {
325 thread: thread.downgrade(),
326 suppress_abort_err: false,
327 session_modes: modes
328 };
329 sessions.borrow_mut().insert(session_id, session);
330
331 Ok(thread)
332 })
333 }
334
335 fn auth_methods(&self) -> &[acp::AuthMethod] {
336 &self.auth_methods
337 }
338
339 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
340 let conn = self.connection.clone();
341 cx.foreground_executor().spawn(async move {
342 let result = conn
343 .authenticate(acp::AuthenticateRequest {
344 method_id: method_id.clone(),
345 })
346 .await?;
347
348 Ok(result)
349 })
350 }
351
352 fn prompt(
353 &self,
354 _id: Option<acp_thread::UserMessageId>,
355 params: acp::PromptRequest,
356 cx: &mut App,
357 ) -> Task<Result<acp::PromptResponse>> {
358 let conn = self.connection.clone();
359 let sessions = self.sessions.clone();
360 let session_id = params.session_id.clone();
361 cx.foreground_executor().spawn(async move {
362 let result = conn.prompt(params).await;
363
364 let mut suppress_abort_err = false;
365
366 if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
367 suppress_abort_err = session.suppress_abort_err;
368 session.suppress_abort_err = false;
369 }
370
371 match result {
372 Ok(response) => Ok(response),
373 Err(err) => {
374 if err.code != ErrorCode::INTERNAL_ERROR.code {
375 anyhow::bail!(err)
376 }
377
378 let Some(data) = &err.data else {
379 anyhow::bail!(err)
380 };
381
382 // Temporary workaround until the following PR is generally available:
383 // https://github.com/google-gemini/gemini-cli/pull/6656
384
385 #[derive(Deserialize)]
386 #[serde(deny_unknown_fields)]
387 struct ErrorDetails {
388 details: Box<str>,
389 }
390
391 match serde_json::from_value(data.clone()) {
392 Ok(ErrorDetails { details }) => {
393 if suppress_abort_err
394 && (details.contains("This operation was aborted")
395 || details.contains("The user aborted a request"))
396 {
397 Ok(acp::PromptResponse {
398 stop_reason: acp::StopReason::Cancelled,
399 })
400 } else {
401 Err(anyhow!(details))
402 }
403 }
404 Err(_) => Err(anyhow!(err)),
405 }
406 }
407 }
408 })
409 }
410
411 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
412 if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
413 session.suppress_abort_err = true;
414 }
415 let conn = self.connection.clone();
416 let params = acp::CancelNotification {
417 session_id: session_id.clone(),
418 };
419 cx.foreground_executor()
420 .spawn(async move { conn.cancel(params).await })
421 .detach();
422 }
423
424 fn session_modes(
425 &self,
426 session_id: &acp::SessionId,
427 _cx: &App,
428 ) -> Option<Rc<dyn acp_thread::AgentSessionModes>> {
429 let sessions = self.sessions.clone();
430 let sessions_ref = sessions.borrow();
431 let Some(session) = sessions_ref.get(session_id) else {
432 return None;
433 };
434
435 if let Some(modes) = session.session_modes.as_ref() {
436 Some(Rc::new(AcpSessionModes {
437 connection: self.connection.clone(),
438 session_id: session_id.clone(),
439 state: modes.clone(),
440 }) as _)
441 } else {
442 None
443 }
444 }
445
446 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
447 self
448 }
449}
450
451struct AcpSessionModes {
452 session_id: acp::SessionId,
453 connection: Rc<acp::ClientSideConnection>,
454 state: Rc<RefCell<acp::SessionModeState>>,
455}
456
457impl acp_thread::AgentSessionModes for AcpSessionModes {
458 fn current_mode(&self) -> acp::SessionModeId {
459 self.state.borrow().current_mode_id.clone()
460 }
461
462 fn all_modes(&self) -> Vec<acp::SessionMode> {
463 self.state.borrow().available_modes.clone()
464 }
465
466 fn set_mode(&self, mode_id: acp::SessionModeId, cx: &mut App) -> Task<Result<()>> {
467 let connection = self.connection.clone();
468 let session_id = self.session_id.clone();
469 let old_mode_id;
470 {
471 let mut state = self.state.borrow_mut();
472 old_mode_id = state.current_mode_id.clone();
473 state.current_mode_id = mode_id.clone();
474 };
475 let state = self.state.clone();
476 cx.foreground_executor().spawn(async move {
477 let result = connection
478 .set_session_mode(acp::SetSessionModeRequest {
479 session_id,
480 mode_id,
481 })
482 .await;
483
484 if result.is_err() {
485 state.borrow_mut().current_mode_id = old_mode_id;
486 }
487
488 result?;
489
490 Ok(())
491 })
492 }
493}
494
495struct ClientDelegate {
496 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
497 cx: AsyncApp,
498}
499
500impl acp::Client for ClientDelegate {
501 async fn request_permission(
502 &self,
503 arguments: acp::RequestPermissionRequest,
504 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
505 let respect_always_allow_setting;
506 let thread;
507 {
508 let sessions_ref = self.sessions.borrow();
509 let session = sessions_ref
510 .get(&arguments.session_id)
511 .context("Failed to get session")?;
512 respect_always_allow_setting = session.session_modes.is_none();
513 thread = session.thread.clone();
514 }
515
516 let cx = &mut self.cx.clone();
517
518 let task = thread.update(cx, |thread, cx| {
519 thread.request_tool_call_authorization(
520 arguments.tool_call,
521 arguments.options,
522 respect_always_allow_setting,
523 cx,
524 )
525 })??;
526
527 let outcome = task.await;
528
529 Ok(acp::RequestPermissionResponse { outcome })
530 }
531
532 async fn write_text_file(
533 &self,
534 arguments: acp::WriteTextFileRequest,
535 ) -> Result<(), acp::Error> {
536 let cx = &mut self.cx.clone();
537 let task = self
538 .session_thread(&arguments.session_id)?
539 .update(cx, |thread, cx| {
540 thread.write_text_file(arguments.path, arguments.content, cx)
541 })?;
542
543 task.await?;
544
545 Ok(())
546 }
547
548 async fn read_text_file(
549 &self,
550 arguments: acp::ReadTextFileRequest,
551 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
552 let task = self.session_thread(&arguments.session_id)?.update(
553 &mut self.cx.clone(),
554 |thread, cx| {
555 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
556 },
557 )?;
558
559 let content = task.await?;
560
561 Ok(acp::ReadTextFileResponse { content })
562 }
563
564 async fn session_notification(
565 &self,
566 notification: acp::SessionNotification,
567 ) -> Result<(), acp::Error> {
568 let sessions = self.sessions.borrow();
569 let session = sessions
570 .get(¬ification.session_id)
571 .context("Failed to get session")?;
572
573 if let acp::SessionUpdate::CurrentModeUpdate { current_mode_id } = ¬ification.update {
574 if let Some(session_modes) = &session.session_modes {
575 session_modes.borrow_mut().current_mode_id = current_mode_id.clone();
576 } else {
577 log::error!(
578 "Got a `CurrentModeUpdate` notification, but they agent didn't specify `modes` during setting setup."
579 );
580 }
581 }
582
583 session.thread.update(&mut self.cx.clone(), |thread, cx| {
584 thread.handle_session_update(notification.update, cx)
585 })??;
586
587 Ok(())
588 }
589
590 async fn create_terminal(
591 &self,
592 args: acp::CreateTerminalRequest,
593 ) -> Result<acp::CreateTerminalResponse, acp::Error> {
594 let terminal = self
595 .session_thread(&args.session_id)?
596 .update(&mut self.cx.clone(), |thread, cx| {
597 thread.create_terminal(
598 args.command,
599 args.args,
600 args.env,
601 args.cwd,
602 args.output_byte_limit,
603 cx,
604 )
605 })?
606 .await?;
607 Ok(
608 terminal.read_with(&self.cx, |terminal, _| acp::CreateTerminalResponse {
609 terminal_id: terminal.id().clone(),
610 })?,
611 )
612 }
613
614 async fn kill_terminal(&self, args: acp::KillTerminalRequest) -> Result<(), acp::Error> {
615 self.session_thread(&args.session_id)?
616 .update(&mut self.cx.clone(), |thread, cx| {
617 thread.kill_terminal(args.terminal_id, cx)
618 })??;
619
620 Ok(())
621 }
622
623 async fn release_terminal(&self, args: acp::ReleaseTerminalRequest) -> Result<(), acp::Error> {
624 self.session_thread(&args.session_id)?
625 .update(&mut self.cx.clone(), |thread, cx| {
626 thread.release_terminal(args.terminal_id, cx)
627 })??;
628
629 Ok(())
630 }
631
632 async fn terminal_output(
633 &self,
634 args: acp::TerminalOutputRequest,
635 ) -> Result<acp::TerminalOutputResponse, acp::Error> {
636 self.session_thread(&args.session_id)?
637 .read_with(&mut self.cx.clone(), |thread, cx| {
638 let out = thread
639 .terminal(args.terminal_id)?
640 .read(cx)
641 .current_output(cx);
642
643 Ok(out)
644 })?
645 }
646
647 async fn wait_for_terminal_exit(
648 &self,
649 args: acp::WaitForTerminalExitRequest,
650 ) -> Result<acp::WaitForTerminalExitResponse, acp::Error> {
651 let exit_status = self
652 .session_thread(&args.session_id)?
653 .update(&mut self.cx.clone(), |thread, cx| {
654 anyhow::Ok(thread.terminal(args.terminal_id)?.read(cx).wait_for_exit())
655 })??
656 .await;
657
658 Ok(acp::WaitForTerminalExitResponse { exit_status })
659 }
660}
661
662impl ClientDelegate {
663 fn session_thread(&self, session_id: &acp::SessionId) -> Result<WeakEntity<AcpThread>> {
664 let sessions = self.sessions.borrow();
665 sessions
666 .get(session_id)
667 .context("Failed to get session")
668 .map(|session| session.thread.clone())
669 }
670}