1use acp_thread::AgentConnection;
2use acp_tools::AcpConnectionRegistry;
3use action_log::ActionLog;
4use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
5use anyhow::anyhow;
6use collections::HashMap;
7use futures::AsyncBufReadExt as _;
8use futures::io::BufReader;
9use project::Project;
10use project::agent_server_store::AgentServerCommand;
11use serde::Deserialize;
12use util::ResultExt as _;
13
14use std::path::PathBuf;
15use std::{any::Any, cell::RefCell};
16use std::{path::Path, rc::Rc, sync::Arc};
17use thiserror::Error;
18
19use anyhow::{Context as _, Result};
20use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
21
22use acp_thread::{AcpThread, AuthRequired, LoadError};
23
24#[derive(Debug, Error)]
25#[error("Unsupported version")]
26pub struct UnsupportedVersion;
27
/// A connection to an external agent server that speaks the Agent Client
/// Protocol over the stdin/stdout of a spawned child process.
///
/// NOTE: field order determines drop order; `Drop` kills `child` first, then
/// the fields are dropped in declaration order.
pub struct AcpConnection {
    /// Name of the agent server; used in logs and when registering the connection.
    server_name: SharedString,
    /// Client side of the ACP connection, shared with session-mode handles.
    connection: Rc<acp::ClientSideConnection>,
    /// Open sessions, shared with the `ClientDelegate` that serves agent requests.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods advertised by the agent in its `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Capabilities advertised by the agent in its `initialize` response.
    agent_capabilities: acp::AgentCapabilities,
    /// Mode new sessions are switched into, if configured.
    default_mode: Option<acp::SessionModeId>,
    /// Directory the agent was launched for (its working directory when not remote).
    root_dir: PathBuf,
    // NB: Don't move this into the wait_task, since we need to ensure the process is
    // killed on drop (setting kill_on_drop on the command seems to not always work).
    child: smol::process::Child,
    /// Task driving the ACP message I/O; held so it lives as long as the connection.
    _io_task: Task<Result<()>>,
    /// Task that waits for the child to exit and reports it to open threads.
    _wait_task: Task<Result<()>>,
    /// Task that forwards the agent's stderr to the log.
    _stderr_task: Task<Result<()>>,
}
43
44pub struct AcpSession {
45 thread: WeakEntity<AcpThread>,
46 suppress_abort_err: bool,
47 session_modes: Option<Rc<RefCell<acp::SessionModeState>>>,
48}
49
50pub async fn connect(
51 server_name: SharedString,
52 command: AgentServerCommand,
53 root_dir: &Path,
54 default_mode: Option<acp::SessionModeId>,
55 is_remote: bool,
56 cx: &mut AsyncApp,
57) -> Result<Rc<dyn AgentConnection>> {
58 let conn = AcpConnection::stdio(
59 server_name,
60 command.clone(),
61 root_dir,
62 default_mode,
63 is_remote,
64 cx,
65 )
66 .await?;
67 Ok(Rc::new(conn) as _)
68}
69
/// Oldest ACP protocol version this client accepts; `AcpConnection::stdio`
/// rejects agents reporting an older version with [`UnsupportedVersion`].
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
71
72impl AcpConnection {
73 pub async fn stdio(
74 server_name: SharedString,
75 command: AgentServerCommand,
76 root_dir: &Path,
77 default_mode: Option<acp::SessionModeId>,
78 is_remote: bool,
79 cx: &mut AsyncApp,
80 ) -> Result<Self> {
81 let mut child = util::command::new_smol_command(command.path);
82 child
83 .args(command.args.iter().map(|arg| arg.as_str()))
84 .envs(command.env.iter().flatten())
85 .stdin(std::process::Stdio::piped())
86 .stdout(std::process::Stdio::piped())
87 .stderr(std::process::Stdio::piped());
88 if !is_remote {
89 child.current_dir(root_dir);
90 }
91 let mut child = child.spawn()?;
92
93 let stdout = child.stdout.take().context("Failed to take stdout")?;
94 let stdin = child.stdin.take().context("Failed to take stdin")?;
95 let stderr = child.stderr.take().context("Failed to take stderr")?;
96 log::trace!("Spawned (pid: {})", child.id());
97
98 let sessions = Rc::new(RefCell::new(HashMap::default()));
99
100 let client = ClientDelegate {
101 sessions: sessions.clone(),
102 cx: cx.clone(),
103 };
104 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
105 let foreground_executor = cx.foreground_executor().clone();
106 move |fut| {
107 foreground_executor.spawn(fut).detach();
108 }
109 });
110
111 let io_task = cx.background_spawn(io_task);
112
113 let stderr_task = cx.background_spawn(async move {
114 let mut stderr = BufReader::new(stderr);
115 let mut line = String::new();
116 while let Ok(n) = stderr.read_line(&mut line).await
117 && n > 0
118 {
119 log::warn!("agent stderr: {}", &line);
120 line.clear();
121 }
122 Ok(())
123 });
124
125 let wait_task = cx.spawn({
126 let sessions = sessions.clone();
127 let status_fut = child.status();
128 async move |cx| {
129 let status = status_fut.await?;
130
131 for session in sessions.borrow().values() {
132 session
133 .thread
134 .update(cx, |thread, cx| {
135 thread.emit_load_error(LoadError::Exited { status }, cx)
136 })
137 .ok();
138 }
139
140 anyhow::Ok(())
141 }
142 });
143
144 let connection = Rc::new(connection);
145
146 cx.update(|cx| {
147 AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
148 registry.set_active_connection(server_name.clone(), &connection, cx)
149 });
150 })?;
151
152 let response = connection
153 .initialize(acp::InitializeRequest {
154 protocol_version: acp::VERSION,
155 client_capabilities: acp::ClientCapabilities {
156 fs: acp::FileSystemCapability {
157 read_text_file: true,
158 write_text_file: true,
159 meta: None,
160 },
161 terminal: true,
162 meta: None,
163 },
164 meta: None,
165 })
166 .await?;
167
168 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
169 return Err(UnsupportedVersion.into());
170 }
171
172 Ok(Self {
173 auth_methods: response.auth_methods,
174 root_dir: root_dir.to_owned(),
175 connection,
176 server_name,
177 sessions,
178 agent_capabilities: response.agent_capabilities,
179 default_mode,
180 _io_task: io_task,
181 _wait_task: wait_task,
182 _stderr_task: stderr_task,
183 child,
184 })
185 }
186
187 pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
188 &self.agent_capabilities.prompt_capabilities
189 }
190
191 pub fn root_dir(&self) -> &Path {
192 &self.root_dir
193 }
194}
195
196impl Drop for AcpConnection {
197 fn drop(&mut self) {
198 // See the comment on the child field.
199 self.child.kill().log_err();
200 }
201}
202
203impl AgentConnection for AcpConnection {
204 fn new_thread(
205 self: Rc<Self>,
206 project: Entity<Project>,
207 cwd: &Path,
208 cx: &mut App,
209 ) -> Task<Result<Entity<AcpThread>>> {
210 let name = self.server_name.clone();
211 let conn = self.connection.clone();
212 let sessions = self.sessions.clone();
213 let default_mode = self.default_mode.clone();
214 let cwd = cwd.to_path_buf();
215 let context_server_store = project.read(cx).context_server_store().read(cx);
216 let mcp_servers = if project.read(cx).is_local() {
217 context_server_store
218 .configured_server_ids()
219 .iter()
220 .filter_map(|id| {
221 let configuration = context_server_store.configuration_for_server(id)?;
222 let command = configuration.command();
223 Some(acp::McpServer::Stdio {
224 name: id.0.to_string(),
225 command: command.path.clone(),
226 args: command.args.clone(),
227 env: if let Some(env) = command.env.as_ref() {
228 env.iter()
229 .map(|(name, value)| acp::EnvVariable {
230 name: name.clone(),
231 value: value.clone(),
232 meta: None,
233 })
234 .collect()
235 } else {
236 vec![]
237 },
238 })
239 })
240 .collect()
241 } else {
242 // In SSH projects, the external agent is running on the remote
243 // machine, and currently we only run MCP servers on the local
244 // machine. So don't pass any MCP servers to the agent in that case.
245 Vec::new()
246 };
247
248 cx.spawn(async move |cx| {
249 let response = conn
250 .new_session(acp::NewSessionRequest { mcp_servers, cwd, meta: None })
251 .await
252 .map_err(|err| {
253 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
254 let mut error = AuthRequired::new();
255
256 if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
257 error = error.with_description(err.message);
258 }
259
260 anyhow!(error)
261 } else {
262 anyhow!(err)
263 }
264 })?;
265
266 let modes = response.modes.map(|modes| Rc::new(RefCell::new(modes)));
267
268 if let Some(default_mode) = default_mode {
269 if let Some(modes) = modes.as_ref() {
270 let mut modes_ref = modes.borrow_mut();
271 let has_mode = modes_ref.available_modes.iter().any(|mode| mode.id == default_mode);
272
273 if has_mode {
274 let initial_mode_id = modes_ref.current_mode_id.clone();
275
276 cx.spawn({
277 let default_mode = default_mode.clone();
278 let session_id = response.session_id.clone();
279 let modes = modes.clone();
280 async move |_| {
281 let result = conn.set_session_mode(acp::SetSessionModeRequest {
282 session_id,
283 mode_id: default_mode,
284 meta: None,
285 })
286 .await.log_err();
287
288 if result.is_none() {
289 modes.borrow_mut().current_mode_id = initial_mode_id;
290 }
291 }
292 }).detach();
293
294 modes_ref.current_mode_id = default_mode;
295 } else {
296 let available_modes = modes_ref
297 .available_modes
298 .iter()
299 .map(|mode| format!("- `{}`: {}", mode.id, mode.name))
300 .collect::<Vec<_>>()
301 .join("\n");
302
303 log::warn!(
304 "`{default_mode}` is not valid {name} mode. Available options:\n{available_modes}",
305 );
306 }
307 } else {
308 log::warn!(
309 "`{name}` does not support modes, but `default_mode` was set in settings.",
310 );
311 }
312 }
313
314 let session_id = response.session_id;
315 let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
316 let thread = cx.new(|cx| {
317 AcpThread::new(
318 self.server_name.clone(),
319 self.clone(),
320 project,
321 action_log,
322 session_id.clone(),
323 // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
324 watch::Receiver::constant(self.agent_capabilities.prompt_capabilities.clone()),
325 cx,
326 )
327 })?;
328
329 let session = AcpSession {
330 thread: thread.downgrade(),
331 suppress_abort_err: false,
332 session_modes: modes
333 };
334 sessions.borrow_mut().insert(session_id, session);
335
336 Ok(thread)
337 })
338 }
339
340 fn auth_methods(&self) -> &[acp::AuthMethod] {
341 &self.auth_methods
342 }
343
344 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
345 let conn = self.connection.clone();
346 cx.foreground_executor().spawn(async move {
347 conn.authenticate(acp::AuthenticateRequest {
348 method_id: method_id.clone(),
349 meta: None,
350 })
351 .await?;
352
353 Ok(())
354 })
355 }
356
357 fn prompt(
358 &self,
359 _id: Option<acp_thread::UserMessageId>,
360 params: acp::PromptRequest,
361 cx: &mut App,
362 ) -> Task<Result<acp::PromptResponse>> {
363 let conn = self.connection.clone();
364 let sessions = self.sessions.clone();
365 let session_id = params.session_id.clone();
366 cx.foreground_executor().spawn(async move {
367 let result = conn.prompt(params).await;
368
369 let mut suppress_abort_err = false;
370
371 if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
372 suppress_abort_err = session.suppress_abort_err;
373 session.suppress_abort_err = false;
374 }
375
376 match result {
377 Ok(response) => Ok(response),
378 Err(err) => {
379 if err.code != ErrorCode::INTERNAL_ERROR.code {
380 anyhow::bail!(err)
381 }
382
383 let Some(data) = &err.data else {
384 anyhow::bail!(err)
385 };
386
387 // Temporary workaround until the following PR is generally available:
388 // https://github.com/google-gemini/gemini-cli/pull/6656
389
390 #[derive(Deserialize)]
391 #[serde(deny_unknown_fields)]
392 struct ErrorDetails {
393 details: Box<str>,
394 }
395
396 match serde_json::from_value(data.clone()) {
397 Ok(ErrorDetails { details }) => {
398 if suppress_abort_err
399 && (details.contains("This operation was aborted")
400 || details.contains("The user aborted a request"))
401 {
402 Ok(acp::PromptResponse {
403 stop_reason: acp::StopReason::Cancelled,
404 meta: None,
405 })
406 } else {
407 Err(anyhow!(details))
408 }
409 }
410 Err(_) => Err(anyhow!(err)),
411 }
412 }
413 }
414 })
415 }
416
417 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
418 if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
419 session.suppress_abort_err = true;
420 }
421 let conn = self.connection.clone();
422 let params = acp::CancelNotification {
423 session_id: session_id.clone(),
424 meta: None,
425 };
426 cx.foreground_executor()
427 .spawn(async move { conn.cancel(params).await })
428 .detach();
429 }
430
431 fn session_modes(
432 &self,
433 session_id: &acp::SessionId,
434 _cx: &App,
435 ) -> Option<Rc<dyn acp_thread::AgentSessionModes>> {
436 let sessions = self.sessions.clone();
437 let sessions_ref = sessions.borrow();
438 let Some(session) = sessions_ref.get(session_id) else {
439 return None;
440 };
441
442 if let Some(modes) = session.session_modes.as_ref() {
443 Some(Rc::new(AcpSessionModes {
444 connection: self.connection.clone(),
445 session_id: session_id.clone(),
446 state: modes.clone(),
447 }) as _)
448 } else {
449 None
450 }
451 }
452
453 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
454 self
455 }
456}
457
458struct AcpSessionModes {
459 session_id: acp::SessionId,
460 connection: Rc<acp::ClientSideConnection>,
461 state: Rc<RefCell<acp::SessionModeState>>,
462}
463
464impl acp_thread::AgentSessionModes for AcpSessionModes {
465 fn current_mode(&self) -> acp::SessionModeId {
466 self.state.borrow().current_mode_id.clone()
467 }
468
469 fn all_modes(&self) -> Vec<acp::SessionMode> {
470 self.state.borrow().available_modes.clone()
471 }
472
473 fn set_mode(&self, mode_id: acp::SessionModeId, cx: &mut App) -> Task<Result<()>> {
474 let connection = self.connection.clone();
475 let session_id = self.session_id.clone();
476 let old_mode_id;
477 {
478 let mut state = self.state.borrow_mut();
479 old_mode_id = state.current_mode_id.clone();
480 state.current_mode_id = mode_id.clone();
481 };
482 let state = self.state.clone();
483 cx.foreground_executor().spawn(async move {
484 let result = connection
485 .set_session_mode(acp::SetSessionModeRequest {
486 session_id,
487 mode_id,
488 meta: None,
489 })
490 .await;
491
492 if result.is_err() {
493 state.borrow_mut().current_mode_id = old_mode_id;
494 }
495
496 result?;
497
498 Ok(())
499 })
500 }
501}
502
503struct ClientDelegate {
504 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
505 cx: AsyncApp,
506}
507
508impl acp::Client for ClientDelegate {
509 async fn request_permission(
510 &self,
511 arguments: acp::RequestPermissionRequest,
512 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
513 let respect_always_allow_setting;
514 let thread;
515 {
516 let sessions_ref = self.sessions.borrow();
517 let session = sessions_ref
518 .get(&arguments.session_id)
519 .context("Failed to get session")?;
520 respect_always_allow_setting = session.session_modes.is_none();
521 thread = session.thread.clone();
522 }
523
524 let cx = &mut self.cx.clone();
525
526 let task = thread.update(cx, |thread, cx| {
527 thread.request_tool_call_authorization(
528 arguments.tool_call,
529 arguments.options,
530 respect_always_allow_setting,
531 cx,
532 )
533 })??;
534
535 let outcome = task.await;
536
537 Ok(acp::RequestPermissionResponse {
538 outcome,
539 meta: None,
540 })
541 }
542
543 async fn write_text_file(
544 &self,
545 arguments: acp::WriteTextFileRequest,
546 ) -> Result<acp::WriteTextFileResponse, acp::Error> {
547 let cx = &mut self.cx.clone();
548 let task = self
549 .session_thread(&arguments.session_id)?
550 .update(cx, |thread, cx| {
551 thread.write_text_file(arguments.path, arguments.content, cx)
552 })?;
553
554 task.await?;
555
556 Ok(Default::default())
557 }
558
559 async fn read_text_file(
560 &self,
561 arguments: acp::ReadTextFileRequest,
562 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
563 let task = self.session_thread(&arguments.session_id)?.update(
564 &mut self.cx.clone(),
565 |thread, cx| {
566 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
567 },
568 )?;
569
570 let content = task.await?;
571
572 Ok(acp::ReadTextFileResponse {
573 content,
574 meta: None,
575 })
576 }
577
578 async fn session_notification(
579 &self,
580 notification: acp::SessionNotification,
581 ) -> Result<(), acp::Error> {
582 let sessions = self.sessions.borrow();
583 let session = sessions
584 .get(¬ification.session_id)
585 .context("Failed to get session")?;
586
587 if let acp::SessionUpdate::CurrentModeUpdate { current_mode_id } = ¬ification.update {
588 if let Some(session_modes) = &session.session_modes {
589 session_modes.borrow_mut().current_mode_id = current_mode_id.clone();
590 } else {
591 log::error!(
592 "Got a `CurrentModeUpdate` notification, but they agent didn't specify `modes` during setting setup."
593 );
594 }
595 }
596
597 session.thread.update(&mut self.cx.clone(), |thread, cx| {
598 thread.handle_session_update(notification.update, cx)
599 })??;
600
601 Ok(())
602 }
603
604 async fn create_terminal(
605 &self,
606 args: acp::CreateTerminalRequest,
607 ) -> Result<acp::CreateTerminalResponse, acp::Error> {
608 let terminal = self
609 .session_thread(&args.session_id)?
610 .update(&mut self.cx.clone(), |thread, cx| {
611 thread.create_terminal(
612 args.command,
613 args.args,
614 args.env,
615 args.cwd,
616 args.output_byte_limit,
617 cx,
618 )
619 })?
620 .await?;
621 Ok(
622 terminal.read_with(&self.cx, |terminal, _| acp::CreateTerminalResponse {
623 terminal_id: terminal.id().clone(),
624 meta: None,
625 })?,
626 )
627 }
628
629 async fn kill_terminal_command(
630 &self,
631 args: acp::KillTerminalCommandRequest,
632 ) -> Result<acp::KillTerminalCommandResponse, acp::Error> {
633 self.session_thread(&args.session_id)?
634 .update(&mut self.cx.clone(), |thread, cx| {
635 thread.kill_terminal(args.terminal_id, cx)
636 })??;
637
638 Ok(Default::default())
639 }
640
641 async fn ext_method(
642 &self,
643 _name: Arc<str>,
644 _params: Arc<serde_json::value::RawValue>,
645 ) -> Result<Arc<serde_json::value::RawValue>, acp::Error> {
646 Err(acp::Error::method_not_found())
647 }
648
649 async fn ext_notification(
650 &self,
651 _name: Arc<str>,
652 _params: Arc<serde_json::value::RawValue>,
653 ) -> Result<(), acp::Error> {
654 Err(acp::Error::method_not_found())
655 }
656
657 async fn release_terminal(
658 &self,
659 args: acp::ReleaseTerminalRequest,
660 ) -> Result<acp::ReleaseTerminalResponse, acp::Error> {
661 self.session_thread(&args.session_id)?
662 .update(&mut self.cx.clone(), |thread, cx| {
663 thread.release_terminal(args.terminal_id, cx)
664 })??;
665
666 Ok(Default::default())
667 }
668
669 async fn terminal_output(
670 &self,
671 args: acp::TerminalOutputRequest,
672 ) -> Result<acp::TerminalOutputResponse, acp::Error> {
673 self.session_thread(&args.session_id)?
674 .read_with(&mut self.cx.clone(), |thread, cx| {
675 let out = thread
676 .terminal(args.terminal_id)?
677 .read(cx)
678 .current_output(cx);
679
680 Ok(out)
681 })?
682 }
683
684 async fn wait_for_terminal_exit(
685 &self,
686 args: acp::WaitForTerminalExitRequest,
687 ) -> Result<acp::WaitForTerminalExitResponse, acp::Error> {
688 let exit_status = self
689 .session_thread(&args.session_id)?
690 .update(&mut self.cx.clone(), |thread, cx| {
691 anyhow::Ok(thread.terminal(args.terminal_id)?.read(cx).wait_for_exit())
692 })??
693 .await;
694
695 Ok(acp::WaitForTerminalExitResponse {
696 exit_status,
697 meta: None,
698 })
699 }
700}
701
702impl ClientDelegate {
703 fn session_thread(&self, session_id: &acp::SessionId) -> Result<WeakEntity<AcpThread>> {
704 let sessions = self.sessions.borrow();
705 sessions
706 .get(session_id)
707 .context("Failed to get session")
708 .map(|session| session.thread.clone())
709 }
710}