1use acp_thread::AgentConnection;
2use acp_tools::AcpConnectionRegistry;
3use action_log::ActionLog;
4use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
5use anyhow::anyhow;
6use collections::HashMap;
7use futures::AsyncBufReadExt as _;
8use futures::io::BufReader;
9use project::Project;
10use project::agent_server_store::AgentServerCommand;
11use serde::Deserialize;
12use util::ResultExt as _;
13
14use std::path::PathBuf;
15use std::{any::Any, cell::RefCell};
16use std::{path::Path, rc::Rc};
17use thiserror::Error;
18
19use anyhow::{Context as _, Result};
20use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
21
22use acp_thread::{AcpThread, AuthRequired, LoadError};
23
/// Error returned by [`AcpConnection::stdio`] when the agent's `initialize`
/// response reports a protocol version older than [`MINIMUM_SUPPORTED_VERSION`].
#[derive(Debug, Error)]
#[error("Unsupported version")]
pub struct UnsupportedVersion;
27
/// An [`AgentConnection`] backed by an external agent process that speaks the
/// Agent Client Protocol (ACP) over its stdin/stdout.
pub struct AcpConnection {
    /// Name of the agent server; used as the connection-registry key, in log
    /// messages, and passed to every `AcpThread` this connection creates.
    server_name: SharedString,
    /// Client side of the ACP connection to the agent process.
    connection: Rc<acp::ClientSideConnection>,
    /// Live sessions keyed by ACP session id. Shared with the `ClientDelegate`
    /// (agent-initiated requests) and with the process-exit watcher task.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods advertised in the agent's `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Capabilities advertised in the agent's `initialize` response.
    agent_capabilities: acp::AgentCapabilities,
    /// Mode that new sessions are switched into when the agent advertises it.
    default_mode: Option<acp::SessionModeId>,
    /// Root directory the connection was created for; also the child's working
    /// directory when the project is not remote.
    root_dir: PathBuf,
    // NB: Don't move this into the wait_task, since we need to ensure the process is
    // killed on drop (setting kill_on_drop on the command seems to not always work).
    child: smol::process::Child,
    /// Drives the I/O of the ACP connection.
    _io_task: Task<Result<()>>,
    /// Waits for the child to exit and reports `LoadError::Exited` to every session.
    _wait_task: Task<Result<()>>,
    /// Forwards the child's stderr, line by line, to the log.
    _stderr_task: Task<Result<()>>,
}
43
/// Per-session state tracked by [`AcpConnection`].
pub struct AcpSession {
    /// Weak handle to the thread that displays this session.
    thread: WeakEntity<AcpThread>,
    /// Set by `cancel`; read and reset by the next `prompt` completion so that the
    /// agent's "aborted" internal error is reported as `StopReason::Cancelled`.
    suppress_abort_err: bool,
    /// Mode state returned by the agent when the session was created, if any.
    session_modes: Option<Rc<RefCell<acp::SessionModeState>>>,
}
49
50pub async fn connect(
51 server_name: SharedString,
52 command: AgentServerCommand,
53 root_dir: &Path,
54 default_mode: Option<acp::SessionModeId>,
55 is_remote: bool,
56 cx: &mut AsyncApp,
57) -> Result<Rc<dyn AgentConnection>> {
58 let conn = AcpConnection::stdio(
59 server_name,
60 command.clone(),
61 root_dir,
62 default_mode,
63 is_remote,
64 cx,
65 )
66 .await?;
67 Ok(Rc::new(conn) as _)
68}
69
/// Oldest ACP protocol version accepted from an agent; agents reporting an older
/// version in their `initialize` response are rejected with [`UnsupportedVersion`].
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
71
72impl AcpConnection {
73 pub async fn stdio(
74 server_name: SharedString,
75 command: AgentServerCommand,
76 root_dir: &Path,
77 default_mode: Option<acp::SessionModeId>,
78 is_remote: bool,
79 cx: &mut AsyncApp,
80 ) -> Result<Self> {
81 let mut child = util::command::new_smol_command(command.path);
82 child
83 .args(command.args.iter().map(|arg| arg.as_str()))
84 .envs(command.env.iter().flatten())
85 .stdin(std::process::Stdio::piped())
86 .stdout(std::process::Stdio::piped())
87 .stderr(std::process::Stdio::piped());
88 if !is_remote {
89 child.current_dir(root_dir);
90 }
91 let mut child = child.spawn()?;
92
93 let stdout = child.stdout.take().context("Failed to take stdout")?;
94 let stdin = child.stdin.take().context("Failed to take stdin")?;
95 let stderr = child.stderr.take().context("Failed to take stderr")?;
96 log::trace!("Spawned (pid: {})", child.id());
97
98 let sessions = Rc::new(RefCell::new(HashMap::default()));
99
100 let client = ClientDelegate {
101 sessions: sessions.clone(),
102 cx: cx.clone(),
103 };
104 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
105 let foreground_executor = cx.foreground_executor().clone();
106 move |fut| {
107 foreground_executor.spawn(fut).detach();
108 }
109 });
110
111 let io_task = cx.background_spawn(io_task);
112
113 let stderr_task = cx.background_spawn(async move {
114 let mut stderr = BufReader::new(stderr);
115 let mut line = String::new();
116 while let Ok(n) = stderr.read_line(&mut line).await
117 && n > 0
118 {
119 log::warn!("agent stderr: {}", &line);
120 line.clear();
121 }
122 Ok(())
123 });
124
125 let wait_task = cx.spawn({
126 let sessions = sessions.clone();
127 let status_fut = child.status();
128 async move |cx| {
129 let status = status_fut.await?;
130
131 for session in sessions.borrow().values() {
132 session
133 .thread
134 .update(cx, |thread, cx| {
135 thread.emit_load_error(LoadError::Exited { status }, cx)
136 })
137 .ok();
138 }
139
140 anyhow::Ok(())
141 }
142 });
143
144 let connection = Rc::new(connection);
145
146 cx.update(|cx| {
147 AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
148 registry.set_active_connection(server_name.clone(), &connection, cx)
149 });
150 })?;
151
152 let response = connection
153 .initialize(acp::InitializeRequest {
154 protocol_version: acp::VERSION,
155 client_capabilities: acp::ClientCapabilities {
156 fs: acp::FileSystemCapability {
157 read_text_file: true,
158 write_text_file: true,
159 meta: None,
160 },
161 terminal: true,
162 meta: None,
163 },
164 meta: None,
165 })
166 .await?;
167
168 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
169 return Err(UnsupportedVersion.into());
170 }
171
172 Ok(Self {
173 auth_methods: response.auth_methods,
174 root_dir: root_dir.to_owned(),
175 connection,
176 server_name,
177 sessions,
178 agent_capabilities: response.agent_capabilities,
179 default_mode,
180 _io_task: io_task,
181 _wait_task: wait_task,
182 _stderr_task: stderr_task,
183 child,
184 })
185 }
186
187 pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
188 &self.agent_capabilities.prompt_capabilities
189 }
190
191 pub fn root_dir(&self) -> &Path {
192 &self.root_dir
193 }
194}
195
196impl Drop for AcpConnection {
197 fn drop(&mut self) {
198 // See the comment on the child field.
199 self.child.kill().log_err();
200 }
201}
202
203impl AgentConnection for AcpConnection {
204 fn new_thread(
205 self: Rc<Self>,
206 project: Entity<Project>,
207 cwd: &Path,
208 cx: &mut App,
209 ) -> Task<Result<Entity<AcpThread>>> {
210 let name = self.server_name.clone();
211 let conn = self.connection.clone();
212 let sessions = self.sessions.clone();
213 let default_mode = self.default_mode.clone();
214 let cwd = cwd.to_path_buf();
215 let context_server_store = project.read(cx).context_server_store().read(cx);
216 let mcp_servers = if project.read(cx).is_local() {
217 context_server_store
218 .configured_server_ids()
219 .iter()
220 .filter_map(|id| {
221 let configuration = context_server_store.configuration_for_server(id)?;
222 let command = configuration.command();
223 Some(acp::McpServer::Stdio {
224 name: id.0.to_string(),
225 command: command.path.clone(),
226 args: command.args.clone(),
227 env: if let Some(env) = command.env.as_ref() {
228 env.iter()
229 .map(|(name, value)| acp::EnvVariable {
230 name: name.clone(),
231 value: value.clone(),
232 meta: None,
233 })
234 .collect()
235 } else {
236 vec![]
237 },
238 })
239 })
240 .collect()
241 } else {
242 // In SSH projects, the external agent is running on the remote
243 // machine, and currently we only run MCP servers on the local
244 // machine. So don't pass any MCP servers to the agent in that case.
245 Vec::new()
246 };
247
248 cx.spawn(async move |cx| {
249 let response = conn
250 .new_session(acp::NewSessionRequest { mcp_servers, cwd, meta: None })
251 .await
252 .map_err(|err| {
253 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
254 let mut error = AuthRequired::new();
255
256 if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
257 error = error.with_description(err.message);
258 }
259
260 anyhow!(error)
261 } else {
262 anyhow!(err)
263 }
264 })?;
265
266 let modes = response.modes.map(|modes| Rc::new(RefCell::new(modes)));
267
268 if let Some(default_mode) = default_mode {
269 if let Some(modes) = modes.as_ref() {
270 let mut modes_ref = modes.borrow_mut();
271 let has_mode = modes_ref.available_modes.iter().any(|mode| mode.id == default_mode);
272
273 if has_mode {
274 let initial_mode_id = modes_ref.current_mode_id.clone();
275
276 cx.spawn({
277 let default_mode = default_mode.clone();
278 let session_id = response.session_id.clone();
279 let modes = modes.clone();
280 async move |_| {
281 let result = conn.set_session_mode(acp::SetSessionModeRequest {
282 session_id,
283 mode_id: default_mode,
284 meta: None,
285 })
286 .await.log_err();
287
288 if result.is_none() {
289 modes.borrow_mut().current_mode_id = initial_mode_id;
290 }
291 }
292 }).detach();
293
294 modes_ref.current_mode_id = default_mode;
295 } else {
296 let available_modes = modes_ref
297 .available_modes
298 .iter()
299 .map(|mode| format!("- `{}`: {}", mode.id, mode.name))
300 .collect::<Vec<_>>()
301 .join("\n");
302
303 log::warn!(
304 "`{default_mode}` is not valid {name} mode. Available options:\n{available_modes}",
305 );
306 }
307 } else {
308 log::warn!(
309 "`{name}` does not support modes, but `default_mode` was set in settings.",
310 );
311 }
312 }
313
314 let session_id = response.session_id;
315 let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
316 let thread = cx.new(|cx| {
317 AcpThread::new(
318 self.server_name.clone(),
319 self.clone(),
320 project,
321 action_log,
322 session_id.clone(),
323 // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
324 watch::Receiver::constant(self.agent_capabilities.prompt_capabilities.clone()),
325 cx,
326 )
327 })?;
328
329 let session = AcpSession {
330 thread: thread.downgrade(),
331 suppress_abort_err: false,
332 session_modes: modes
333 };
334 sessions.borrow_mut().insert(session_id, session);
335
336 Ok(thread)
337 })
338 }
339
340 fn auth_methods(&self) -> &[acp::AuthMethod] {
341 &self.auth_methods
342 }
343
344 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
345 let conn = self.connection.clone();
346 cx.foreground_executor().spawn(async move {
347 conn.authenticate(acp::AuthenticateRequest {
348 method_id: method_id.clone(),
349 meta: None,
350 })
351 .await?;
352
353 Ok(())
354 })
355 }
356
357 fn prompt(
358 &self,
359 _id: Option<acp_thread::UserMessageId>,
360 params: acp::PromptRequest,
361 cx: &mut App,
362 ) -> Task<Result<acp::PromptResponse>> {
363 let conn = self.connection.clone();
364 let sessions = self.sessions.clone();
365 let session_id = params.session_id.clone();
366 cx.foreground_executor().spawn(async move {
367 let result = conn.prompt(params).await;
368
369 let mut suppress_abort_err = false;
370
371 if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
372 suppress_abort_err = session.suppress_abort_err;
373 session.suppress_abort_err = false;
374 }
375
376 match result {
377 Ok(response) => Ok(response),
378 Err(err) => {
379 if err.code != ErrorCode::INTERNAL_ERROR.code {
380 anyhow::bail!(err)
381 }
382
383 let Some(data) = &err.data else {
384 anyhow::bail!(err)
385 };
386
387 // Temporary workaround until the following PR is generally available:
388 // https://github.com/google-gemini/gemini-cli/pull/6656
389
390 #[derive(Deserialize)]
391 #[serde(deny_unknown_fields)]
392 struct ErrorDetails {
393 details: Box<str>,
394 }
395
396 match serde_json::from_value(data.clone()) {
397 Ok(ErrorDetails { details }) => {
398 if suppress_abort_err
399 && (details.contains("This operation was aborted")
400 || details.contains("The user aborted a request"))
401 {
402 Ok(acp::PromptResponse {
403 stop_reason: acp::StopReason::Cancelled,
404 meta: None,
405 })
406 } else {
407 Err(anyhow!(details))
408 }
409 }
410 Err(_) => Err(anyhow!(err)),
411 }
412 }
413 }
414 })
415 }
416
417 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
418 if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
419 session.suppress_abort_err = true;
420 }
421 let conn = self.connection.clone();
422 let params = acp::CancelNotification {
423 session_id: session_id.clone(),
424 meta: None,
425 };
426 cx.foreground_executor()
427 .spawn(async move { conn.cancel(params).await })
428 .detach();
429 }
430
431 fn session_modes(
432 &self,
433 session_id: &acp::SessionId,
434 _cx: &App,
435 ) -> Option<Rc<dyn acp_thread::AgentSessionModes>> {
436 let sessions = self.sessions.clone();
437 let sessions_ref = sessions.borrow();
438 let Some(session) = sessions_ref.get(session_id) else {
439 return None;
440 };
441
442 if let Some(modes) = session.session_modes.as_ref() {
443 Some(Rc::new(AcpSessionModes {
444 connection: self.connection.clone(),
445 session_id: session_id.clone(),
446 state: modes.clone(),
447 }) as _)
448 } else {
449 None
450 }
451 }
452
453 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
454 self
455 }
456}
457
/// [`acp_thread::AgentSessionModes`] implementation for a single ACP session.
/// Its `state` is the same `Rc` stored in the session's `AcpSession` entry.
struct AcpSessionModes {
    /// Session whose mode is read and changed.
    session_id: acp::SessionId,
    /// Connection used to send `set_session_mode` requests to the agent.
    connection: Rc<acp::ClientSideConnection>,
    /// Shared mode state; also updated by `CurrentModeUpdate` notifications.
    state: Rc<RefCell<acp::SessionModeState>>,
}
463
464impl acp_thread::AgentSessionModes for AcpSessionModes {
465 fn current_mode(&self) -> acp::SessionModeId {
466 self.state.borrow().current_mode_id.clone()
467 }
468
469 fn all_modes(&self) -> Vec<acp::SessionMode> {
470 self.state.borrow().available_modes.clone()
471 }
472
473 fn set_mode(&self, mode_id: acp::SessionModeId, cx: &mut App) -> Task<Result<()>> {
474 let connection = self.connection.clone();
475 let session_id = self.session_id.clone();
476 let old_mode_id;
477 {
478 let mut state = self.state.borrow_mut();
479 old_mode_id = state.current_mode_id.clone();
480 state.current_mode_id = mode_id.clone();
481 };
482 let state = self.state.clone();
483 cx.foreground_executor().spawn(async move {
484 let result = connection
485 .set_session_mode(acp::SetSessionModeRequest {
486 session_id,
487 mode_id,
488 meta: None,
489 })
490 .await;
491
492 if result.is_err() {
493 state.borrow_mut().current_mode_id = old_mode_id;
494 }
495
496 result?;
497
498 Ok(())
499 })
500 }
501}
502
/// Handles requests and notifications sent by the agent to the client, routing
/// each one to the [`AcpThread`] of the session it addresses.
struct ClientDelegate {
    /// Session map shared with the owning [`AcpConnection`].
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// App context used to read and update threads from the request handlers.
    cx: AsyncApp,
}
507
508#[async_trait::async_trait(?Send)]
509impl acp::Client for ClientDelegate {
510 async fn request_permission(
511 &self,
512 arguments: acp::RequestPermissionRequest,
513 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
514 let respect_always_allow_setting;
515 let thread;
516 {
517 let sessions_ref = self.sessions.borrow();
518 let session = sessions_ref
519 .get(&arguments.session_id)
520 .context("Failed to get session")?;
521 respect_always_allow_setting = session.session_modes.is_none();
522 thread = session.thread.clone();
523 }
524
525 let cx = &mut self.cx.clone();
526
527 let task = thread.update(cx, |thread, cx| {
528 thread.request_tool_call_authorization(
529 arguments.tool_call,
530 arguments.options,
531 respect_always_allow_setting,
532 cx,
533 )
534 })??;
535
536 let outcome = task.await;
537
538 Ok(acp::RequestPermissionResponse {
539 outcome,
540 meta: None,
541 })
542 }
543
544 async fn write_text_file(
545 &self,
546 arguments: acp::WriteTextFileRequest,
547 ) -> Result<acp::WriteTextFileResponse, acp::Error> {
548 let cx = &mut self.cx.clone();
549 let task = self
550 .session_thread(&arguments.session_id)?
551 .update(cx, |thread, cx| {
552 thread.write_text_file(arguments.path, arguments.content, cx)
553 })?;
554
555 task.await?;
556
557 Ok(Default::default())
558 }
559
560 async fn read_text_file(
561 &self,
562 arguments: acp::ReadTextFileRequest,
563 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
564 let task = self.session_thread(&arguments.session_id)?.update(
565 &mut self.cx.clone(),
566 |thread, cx| {
567 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
568 },
569 )?;
570
571 let content = task.await?;
572
573 Ok(acp::ReadTextFileResponse {
574 content,
575 meta: None,
576 })
577 }
578
579 async fn session_notification(
580 &self,
581 notification: acp::SessionNotification,
582 ) -> Result<(), acp::Error> {
583 let sessions = self.sessions.borrow();
584 let session = sessions
585 .get(¬ification.session_id)
586 .context("Failed to get session")?;
587
588 if let acp::SessionUpdate::CurrentModeUpdate { current_mode_id } = ¬ification.update {
589 if let Some(session_modes) = &session.session_modes {
590 session_modes.borrow_mut().current_mode_id = current_mode_id.clone();
591 } else {
592 log::error!(
593 "Got a `CurrentModeUpdate` notification, but they agent didn't specify `modes` during setting setup."
594 );
595 }
596 }
597
598 session.thread.update(&mut self.cx.clone(), |thread, cx| {
599 thread.handle_session_update(notification.update, cx)
600 })??;
601
602 Ok(())
603 }
604
605 async fn create_terminal(
606 &self,
607 args: acp::CreateTerminalRequest,
608 ) -> Result<acp::CreateTerminalResponse, acp::Error> {
609 let terminal = self
610 .session_thread(&args.session_id)?
611 .update(&mut self.cx.clone(), |thread, cx| {
612 thread.create_terminal(
613 args.command,
614 args.args,
615 args.env,
616 args.cwd,
617 args.output_byte_limit,
618 cx,
619 )
620 })?
621 .await?;
622 Ok(
623 terminal.read_with(&self.cx, |terminal, _| acp::CreateTerminalResponse {
624 terminal_id: terminal.id().clone(),
625 meta: None,
626 })?,
627 )
628 }
629
630 async fn kill_terminal_command(
631 &self,
632 args: acp::KillTerminalCommandRequest,
633 ) -> Result<acp::KillTerminalCommandResponse, acp::Error> {
634 self.session_thread(&args.session_id)?
635 .update(&mut self.cx.clone(), |thread, cx| {
636 thread.kill_terminal(args.terminal_id, cx)
637 })??;
638
639 Ok(Default::default())
640 }
641
642 async fn ext_method(&self, _args: acp::ExtRequest) -> Result<acp::ExtResponse, acp::Error> {
643 Err(acp::Error::method_not_found())
644 }
645
646 async fn ext_notification(&self, _args: acp::ExtNotification) -> Result<(), acp::Error> {
647 Err(acp::Error::method_not_found())
648 }
649
650 async fn release_terminal(
651 &self,
652 args: acp::ReleaseTerminalRequest,
653 ) -> Result<acp::ReleaseTerminalResponse, acp::Error> {
654 self.session_thread(&args.session_id)?
655 .update(&mut self.cx.clone(), |thread, cx| {
656 thread.release_terminal(args.terminal_id, cx)
657 })??;
658
659 Ok(Default::default())
660 }
661
662 async fn terminal_output(
663 &self,
664 args: acp::TerminalOutputRequest,
665 ) -> Result<acp::TerminalOutputResponse, acp::Error> {
666 self.session_thread(&args.session_id)?
667 .read_with(&mut self.cx.clone(), |thread, cx| {
668 let out = thread
669 .terminal(args.terminal_id)?
670 .read(cx)
671 .current_output(cx);
672
673 Ok(out)
674 })?
675 }
676
677 async fn wait_for_terminal_exit(
678 &self,
679 args: acp::WaitForTerminalExitRequest,
680 ) -> Result<acp::WaitForTerminalExitResponse, acp::Error> {
681 let exit_status = self
682 .session_thread(&args.session_id)?
683 .update(&mut self.cx.clone(), |thread, cx| {
684 anyhow::Ok(thread.terminal(args.terminal_id)?.read(cx).wait_for_exit())
685 })??
686 .await;
687
688 Ok(acp::WaitForTerminalExitResponse {
689 exit_status,
690 meta: None,
691 })
692 }
693}
694
695impl ClientDelegate {
696 fn session_thread(&self, session_id: &acp::SessionId) -> Result<WeakEntity<AcpThread>> {
697 let sessions = self.sessions.borrow();
698 sessions
699 .get(session_id)
700 .context("Failed to get session")
701 .map(|session| session.thread.clone())
702 }
703}