1use acp_thread::AgentConnection;
2use acp_tools::AcpConnectionRegistry;
3use action_log::ActionLog;
4use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
5use anyhow::anyhow;
6use collections::HashMap;
7use futures::AsyncBufReadExt as _;
8use futures::io::BufReader;
9use project::Project;
10use project::agent_server_store::AgentServerCommand;
11use serde::Deserialize;
12use util::ResultExt;
13
14use std::path::PathBuf;
15use std::{any::Any, cell::RefCell};
16use std::{path::Path, rc::Rc};
17use thiserror::Error;
18
19use anyhow::{Context as _, Result};
20use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
21
22use acp_thread::{AcpThread, AuthRequired, LoadError};
23
/// Error returned by [`AcpConnection::stdio`] when the agent answers `initialize` with an
/// ACP protocol version older than [`MINIMUM_SUPPORTED_VERSION`].
#[derive(Debug, Error)]
#[error("Unsupported version")]
pub struct UnsupportedVersion;
27
/// A connection to an external agent server speaking the Agent Client Protocol over the
/// stdio of a spawned child process (see [`AcpConnection::stdio`]).
pub struct AcpConnection {
    /// Name of the agent server; used when registering the connection and creating threads.
    server_name: SharedString,
    /// Client side of the ACP connection; requests to the agent go through this.
    connection: Rc<acp::ClientSideConnection>,
    /// Open sessions by id. Shared with the `ClientDelegate` that handles agent requests.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods reported by the agent in its `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Capabilities reported by the agent in its `initialize` response.
    agent_capabilities: acp::AgentCapabilities,
    /// Mode to switch new sessions into, if the agent advertises it.
    default_mode: Option<acp::SessionModeId>,
    /// Directory the connection was created for.
    root_dir: PathBuf,
    // The tasks below are only held so they live as long as the connection
    // (presumably dropping a gpui `Task` cancels it — confirm).
    /// Drives the protocol I/O over the child's stdin/stdout.
    _io_task: Task<Result<()>>,
    /// Owns the child process and reports its exit to every open thread.
    _wait_task: Task<Result<()>>,
    /// Forwards the child's stderr to the log.
    _stderr_task: Task<Result<()>>,
}
40
/// Client-side bookkeeping for one ACP session.
pub struct AcpSession {
    /// The thread displaying this session.
    thread: WeakEntity<AcpThread>,
    /// Set by `cancel`; the next `prompt` completion reads and clears it, turning
    /// recognized "aborted" errors into a `Cancelled` stop reason.
    suppress_abort_err: bool,
    /// Mode state shared with `AcpSessionModes`; `None` when the agent did not advertise
    /// modes. When `None`, permission requests respect the "always allow" setting.
    session_modes: Option<Rc<RefCell<acp::SessionModeState>>>,
}
46
47pub async fn connect(
48 server_name: SharedString,
49 command: AgentServerCommand,
50 root_dir: &Path,
51 default_mode: Option<acp::SessionModeId>,
52 is_remote: bool,
53 cx: &mut AsyncApp,
54) -> Result<Rc<dyn AgentConnection>> {
55 let conn = AcpConnection::stdio(
56 server_name,
57 command.clone(),
58 root_dir,
59 default_mode,
60 is_remote,
61 cx,
62 )
63 .await?;
64 Ok(Rc::new(conn) as _)
65}
66
/// Oldest ACP protocol version accepted; agents whose `initialize` response reports an
/// older version are rejected with [`UnsupportedVersion`].
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
68
69impl AcpConnection {
70 pub async fn stdio(
71 server_name: SharedString,
72 command: AgentServerCommand,
73 root_dir: &Path,
74 default_mode: Option<acp::SessionModeId>,
75 is_remote: bool,
76 cx: &mut AsyncApp,
77 ) -> Result<Self> {
78 let mut child = util::command::new_smol_command(command.path);
79 child
80 .args(command.args.iter().map(|arg| arg.as_str()))
81 .envs(command.env.iter().flatten())
82 .stdin(std::process::Stdio::piped())
83 .stdout(std::process::Stdio::piped())
84 .stderr(std::process::Stdio::piped())
85 .kill_on_drop(true);
86 if !is_remote {
87 child.current_dir(root_dir);
88 }
89 let mut child = child.spawn()?;
90
91 let stdout = child.stdout.take().context("Failed to take stdout")?;
92 let stdin = child.stdin.take().context("Failed to take stdin")?;
93 let stderr = child.stderr.take().context("Failed to take stderr")?;
94 log::trace!("Spawned (pid: {})", child.id());
95
96 let sessions = Rc::new(RefCell::new(HashMap::default()));
97
98 let client = ClientDelegate {
99 sessions: sessions.clone(),
100 cx: cx.clone(),
101 };
102 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
103 let foreground_executor = cx.foreground_executor().clone();
104 move |fut| {
105 foreground_executor.spawn(fut).detach();
106 }
107 });
108
109 let io_task = cx.background_spawn(io_task);
110
111 let stderr_task = cx.background_spawn(async move {
112 let mut stderr = BufReader::new(stderr);
113 let mut line = String::new();
114 while let Ok(n) = stderr.read_line(&mut line).await
115 && n > 0
116 {
117 log::warn!("agent stderr: {}", &line);
118 line.clear();
119 }
120 Ok(())
121 });
122
123 let wait_task = cx.spawn({
124 let sessions = sessions.clone();
125 async move |cx| {
126 let status = child.status().await?;
127
128 for session in sessions.borrow().values() {
129 session
130 .thread
131 .update(cx, |thread, cx| {
132 thread.emit_load_error(LoadError::Exited { status }, cx)
133 })
134 .ok();
135 }
136
137 anyhow::Ok(())
138 }
139 });
140
141 let connection = Rc::new(connection);
142
143 cx.update(|cx| {
144 AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
145 registry.set_active_connection(server_name.clone(), &connection, cx)
146 });
147 })?;
148
149 let response = connection
150 .initialize(acp::InitializeRequest {
151 protocol_version: acp::VERSION,
152 client_capabilities: acp::ClientCapabilities {
153 fs: acp::FileSystemCapability {
154 read_text_file: true,
155 write_text_file: true,
156 },
157 terminal: true,
158 },
159 })
160 .await?;
161
162 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
163 return Err(UnsupportedVersion.into());
164 }
165
166 Ok(Self {
167 auth_methods: response.auth_methods,
168 root_dir: root_dir.to_owned(),
169 connection,
170 server_name,
171 sessions,
172 agent_capabilities: response.agent_capabilities,
173 default_mode,
174 _io_task: io_task,
175 _wait_task: wait_task,
176 _stderr_task: stderr_task,
177 })
178 }
179
180 pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
181 &self.agent_capabilities.prompt_capabilities
182 }
183
184 pub fn root_dir(&self) -> &Path {
185 &self.root_dir
186 }
187}
188
189impl AgentConnection for AcpConnection {
190 fn new_thread(
191 self: Rc<Self>,
192 project: Entity<Project>,
193 cwd: &Path,
194 cx: &mut App,
195 ) -> Task<Result<Entity<AcpThread>>> {
196 let name = self.server_name.clone();
197 let conn = self.connection.clone();
198 let sessions = self.sessions.clone();
199 let default_mode = self.default_mode.clone();
200 let cwd = cwd.to_path_buf();
201 let context_server_store = project.read(cx).context_server_store().read(cx);
202 let mcp_servers = if project.read(cx).is_local() {
203 context_server_store
204 .configured_server_ids()
205 .iter()
206 .filter_map(|id| {
207 let configuration = context_server_store.configuration_for_server(id)?;
208 let command = configuration.command();
209 Some(acp::McpServer::Stdio {
210 name: id.0.to_string(),
211 command: command.path.clone(),
212 args: command.args.clone(),
213 env: if let Some(env) = command.env.as_ref() {
214 env.iter()
215 .map(|(name, value)| acp::EnvVariable {
216 name: name.clone(),
217 value: value.clone(),
218 })
219 .collect()
220 } else {
221 vec![]
222 },
223 })
224 })
225 .collect()
226 } else {
227 // In SSH projects, the external agent is running on the remote
228 // machine, and currently we only run MCP servers on the local
229 // machine. So don't pass any MCP servers to the agent in that case.
230 Vec::new()
231 };
232
233 cx.spawn(async move |cx| {
234 let response = conn
235 .new_session(acp::NewSessionRequest { mcp_servers, cwd })
236 .await
237 .map_err(|err| {
238 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
239 let mut error = AuthRequired::new();
240
241 if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
242 error = error.with_description(err.message);
243 }
244
245 anyhow!(error)
246 } else {
247 anyhow!(err)
248 }
249 })?;
250
251 let modes = response.modes.map(|modes| Rc::new(RefCell::new(modes)));
252
253 if let Some(default_mode) = default_mode {
254 if let Some(modes) = modes.as_ref() {
255 let mut modes_ref = modes.borrow_mut();
256 let has_mode = modes_ref.available_modes.iter().any(|mode| mode.id == default_mode);
257
258 if has_mode {
259 let initial_mode_id = modes_ref.current_mode_id.clone();
260
261 cx.spawn({
262 let default_mode = default_mode.clone();
263 let session_id = response.session_id.clone();
264 let modes = modes.clone();
265 async move |_| {
266 let result = conn.set_session_mode(acp::SetSessionModeRequest {
267 session_id,
268 mode_id: default_mode,
269 })
270 .await.log_err();
271
272 if result.is_none() {
273 modes.borrow_mut().current_mode_id = initial_mode_id;
274 }
275 }
276 }).detach();
277
278 modes_ref.current_mode_id = default_mode;
279 } else {
280 let available_modes = modes_ref
281 .available_modes
282 .iter()
283 .map(|mode| format!("- `{}`: {}", mode.id, mode.name))
284 .collect::<Vec<_>>()
285 .join("\n");
286
287 log::warn!(
288 "`{default_mode}` is not valid {name} mode. Available options:\n{available_modes}",
289 );
290 }
291 } else {
292 log::warn!(
293 "`{name}` does not support modes, but `default_mode` was set in settings.",
294 );
295 }
296 }
297
298 let session_id = response.session_id;
299 let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
300 let thread = cx.new(|cx| {
301 AcpThread::new(
302 self.server_name.clone(),
303 self.clone(),
304 project,
305 action_log,
306 session_id.clone(),
307 // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
308 watch::Receiver::constant(self.agent_capabilities.prompt_capabilities),
309 cx,
310 )
311 })?;
312
313 let session = AcpSession {
314 thread: thread.downgrade(),
315 suppress_abort_err: false,
316 session_modes: modes
317 };
318 sessions.borrow_mut().insert(session_id, session);
319
320 Ok(thread)
321 })
322 }
323
324 fn auth_methods(&self) -> &[acp::AuthMethod] {
325 &self.auth_methods
326 }
327
328 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
329 let conn = self.connection.clone();
330 cx.foreground_executor().spawn(async move {
331 let result = conn
332 .authenticate(acp::AuthenticateRequest {
333 method_id: method_id.clone(),
334 })
335 .await?;
336
337 Ok(result)
338 })
339 }
340
341 fn prompt(
342 &self,
343 _id: Option<acp_thread::UserMessageId>,
344 params: acp::PromptRequest,
345 cx: &mut App,
346 ) -> Task<Result<acp::PromptResponse>> {
347 let conn = self.connection.clone();
348 let sessions = self.sessions.clone();
349 let session_id = params.session_id.clone();
350 cx.foreground_executor().spawn(async move {
351 let result = conn.prompt(params).await;
352
353 let mut suppress_abort_err = false;
354
355 if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
356 suppress_abort_err = session.suppress_abort_err;
357 session.suppress_abort_err = false;
358 }
359
360 match result {
361 Ok(response) => Ok(response),
362 Err(err) => {
363 if err.code != ErrorCode::INTERNAL_ERROR.code {
364 anyhow::bail!(err)
365 }
366
367 let Some(data) = &err.data else {
368 anyhow::bail!(err)
369 };
370
371 // Temporary workaround until the following PR is generally available:
372 // https://github.com/google-gemini/gemini-cli/pull/6656
373
374 #[derive(Deserialize)]
375 #[serde(deny_unknown_fields)]
376 struct ErrorDetails {
377 details: Box<str>,
378 }
379
380 match serde_json::from_value(data.clone()) {
381 Ok(ErrorDetails { details }) => {
382 if suppress_abort_err
383 && (details.contains("This operation was aborted")
384 || details.contains("The user aborted a request"))
385 {
386 Ok(acp::PromptResponse {
387 stop_reason: acp::StopReason::Cancelled,
388 })
389 } else {
390 Err(anyhow!(details))
391 }
392 }
393 Err(_) => Err(anyhow!(err)),
394 }
395 }
396 }
397 })
398 }
399
400 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
401 if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
402 session.suppress_abort_err = true;
403 }
404 let conn = self.connection.clone();
405 let params = acp::CancelNotification {
406 session_id: session_id.clone(),
407 };
408 cx.foreground_executor()
409 .spawn(async move { conn.cancel(params).await })
410 .detach();
411 }
412
413 fn session_modes(
414 &self,
415 session_id: &acp::SessionId,
416 _cx: &App,
417 ) -> Option<Rc<dyn acp_thread::AgentSessionModes>> {
418 let sessions = self.sessions.clone();
419 let sessions_ref = sessions.borrow();
420 let Some(session) = sessions_ref.get(session_id) else {
421 return None;
422 };
423
424 if let Some(modes) = session.session_modes.as_ref() {
425 Some(Rc::new(AcpSessionModes {
426 connection: self.connection.clone(),
427 session_id: session_id.clone(),
428 state: modes.clone(),
429 }) as _)
430 } else {
431 None
432 }
433 }
434
435 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
436 self
437 }
438}
439
/// `AgentSessionModes` implementation for one ACP session, handed out by
/// `AcpConnection::session_modes`.
struct AcpSessionModes {
    /// Session whose mode this controls.
    session_id: acp::SessionId,
    /// Connection used to send `set_session_mode` requests.
    connection: Rc<acp::ClientSideConnection>,
    /// Mode state shared with the session's `AcpSession::session_modes`.
    state: Rc<RefCell<acp::SessionModeState>>,
}
445
446impl acp_thread::AgentSessionModes for AcpSessionModes {
447 fn current_mode(&self) -> acp::SessionModeId {
448 self.state.borrow().current_mode_id.clone()
449 }
450
451 fn all_modes(&self) -> Vec<acp::SessionMode> {
452 self.state.borrow().available_modes.clone()
453 }
454
455 fn set_mode(&self, mode_id: acp::SessionModeId, cx: &mut App) -> Task<Result<()>> {
456 let connection = self.connection.clone();
457 let session_id = self.session_id.clone();
458 let old_mode_id;
459 {
460 let mut state = self.state.borrow_mut();
461 old_mode_id = state.current_mode_id.clone();
462 state.current_mode_id = mode_id.clone();
463 };
464 let state = self.state.clone();
465 cx.foreground_executor().spawn(async move {
466 let result = connection
467 .set_session_mode(acp::SetSessionModeRequest {
468 session_id,
469 mode_id,
470 })
471 .await;
472
473 if result.is_err() {
474 state.borrow_mut().current_mode_id = old_mode_id;
475 }
476
477 result?;
478
479 Ok(())
480 })
481 }
482}
483
/// Handles requests and notifications sent by the agent to this client.
struct ClientDelegate {
    /// Open sessions by id; the same map held by the owning `AcpConnection`.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Async app context used to reach the sessions' threads.
    cx: AsyncApp,
}
488
489impl acp::Client for ClientDelegate {
490 async fn request_permission(
491 &self,
492 arguments: acp::RequestPermissionRequest,
493 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
494 let respect_always_allow_setting;
495 let thread;
496 {
497 let sessions_ref = self.sessions.borrow();
498 let session = sessions_ref
499 .get(&arguments.session_id)
500 .context("Failed to get session")?;
501 respect_always_allow_setting = session.session_modes.is_none();
502 thread = session.thread.clone();
503 }
504
505 let cx = &mut self.cx.clone();
506
507 let task = thread.update(cx, |thread, cx| {
508 thread.request_tool_call_authorization(
509 arguments.tool_call,
510 arguments.options,
511 respect_always_allow_setting,
512 cx,
513 )
514 })??;
515
516 let outcome = task.await;
517
518 Ok(acp::RequestPermissionResponse { outcome })
519 }
520
521 async fn write_text_file(
522 &self,
523 arguments: acp::WriteTextFileRequest,
524 ) -> Result<(), acp::Error> {
525 let cx = &mut self.cx.clone();
526 let task = self
527 .session_thread(&arguments.session_id)?
528 .update(cx, |thread, cx| {
529 thread.write_text_file(arguments.path, arguments.content, cx)
530 })?;
531
532 task.await?;
533
534 Ok(())
535 }
536
537 async fn read_text_file(
538 &self,
539 arguments: acp::ReadTextFileRequest,
540 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
541 let task = self.session_thread(&arguments.session_id)?.update(
542 &mut self.cx.clone(),
543 |thread, cx| {
544 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
545 },
546 )?;
547
548 let content = task.await?;
549
550 Ok(acp::ReadTextFileResponse { content })
551 }
552
553 async fn session_notification(
554 &self,
555 notification: acp::SessionNotification,
556 ) -> Result<(), acp::Error> {
557 let sessions = self.sessions.borrow();
558 let session = sessions
559 .get(¬ification.session_id)
560 .context("Failed to get session")?;
561
562 if let acp::SessionUpdate::CurrentModeUpdate { current_mode_id } = ¬ification.update {
563 if let Some(session_modes) = &session.session_modes {
564 session_modes.borrow_mut().current_mode_id = current_mode_id.clone();
565 } else {
566 log::error!(
567 "Got a `CurrentModeUpdate` notification, but they agent didn't specify `modes` during setting setup."
568 );
569 }
570 }
571
572 session.thread.update(&mut self.cx.clone(), |thread, cx| {
573 thread.handle_session_update(notification.update, cx)
574 })??;
575
576 Ok(())
577 }
578
579 async fn create_terminal(
580 &self,
581 args: acp::CreateTerminalRequest,
582 ) -> Result<acp::CreateTerminalResponse, acp::Error> {
583 let terminal = self
584 .session_thread(&args.session_id)?
585 .update(&mut self.cx.clone(), |thread, cx| {
586 thread.create_terminal(
587 args.command,
588 args.args,
589 args.env,
590 args.cwd,
591 args.output_byte_limit,
592 cx,
593 )
594 })?
595 .await?;
596 Ok(
597 terminal.read_with(&self.cx, |terminal, _| acp::CreateTerminalResponse {
598 terminal_id: terminal.id().clone(),
599 })?,
600 )
601 }
602
603 async fn kill_terminal(&self, args: acp::KillTerminalRequest) -> Result<(), acp::Error> {
604 self.session_thread(&args.session_id)?
605 .update(&mut self.cx.clone(), |thread, cx| {
606 thread.kill_terminal(args.terminal_id, cx)
607 })??;
608
609 Ok(())
610 }
611
612 async fn release_terminal(&self, args: acp::ReleaseTerminalRequest) -> Result<(), acp::Error> {
613 self.session_thread(&args.session_id)?
614 .update(&mut self.cx.clone(), |thread, cx| {
615 thread.release_terminal(args.terminal_id, cx)
616 })??;
617
618 Ok(())
619 }
620
621 async fn terminal_output(
622 &self,
623 args: acp::TerminalOutputRequest,
624 ) -> Result<acp::TerminalOutputResponse, acp::Error> {
625 self.session_thread(&args.session_id)?
626 .read_with(&mut self.cx.clone(), |thread, cx| {
627 let out = thread
628 .terminal(args.terminal_id)?
629 .read(cx)
630 .current_output(cx);
631
632 Ok(out)
633 })?
634 }
635
636 async fn wait_for_terminal_exit(
637 &self,
638 args: acp::WaitForTerminalExitRequest,
639 ) -> Result<acp::WaitForTerminalExitResponse, acp::Error> {
640 let exit_status = self
641 .session_thread(&args.session_id)?
642 .update(&mut self.cx.clone(), |thread, cx| {
643 anyhow::Ok(thread.terminal(args.terminal_id)?.read(cx).wait_for_exit())
644 })??
645 .await;
646
647 Ok(acp::WaitForTerminalExitResponse { exit_status })
648 }
649}
650
651impl ClientDelegate {
652 fn session_thread(&self, session_id: &acp::SessionId) -> Result<WeakEntity<AcpThread>> {
653 let sessions = self.sessions.borrow();
654 sessions
655 .get(session_id)
656 .context("Failed to get session")
657 .map(|session| session.thread.clone())
658 }
659}