1use acp_thread::AgentConnection;
2use acp_tools::AcpConnectionRegistry;
3use action_log::ActionLog;
4use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
5use anyhow::anyhow;
6use collections::HashMap;
7use futures::AsyncBufReadExt as _;
8use futures::io::BufReader;
9use project::Project;
10use project::agent_server_store::AgentServerCommand;
11use serde::Deserialize;
12use util::ResultExt as _;
13
14use std::path::PathBuf;
15use std::{any::Any, cell::RefCell};
16use std::{path::Path, rc::Rc};
17use thiserror::Error;
18
19use anyhow::{Context as _, Result};
20use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
21
22use acp_thread::{AcpThread, AuthRequired, LoadError};
23
24#[derive(Debug, Error)]
25#[error("Unsupported version")]
26pub struct UnsupportedVersion;
27
/// A connection to an agent server spawned as a child process, speaking ACP
/// over the child's stdin/stdout.
pub struct AcpConnection {
    // Name under which the connection is registered in `AcpConnectionRegistry`.
    server_name: SharedString,
    // Client side of the ACP connection; also handed to per-session mode/model handles.
    connection: Rc<acp::ClientSideConnection>,
    // Per-session state keyed by session id; shared with the `ClientDelegate`
    // that serves agent-initiated requests.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    // Authentication methods reported in the `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    // Capabilities reported in the `initialize` response.
    agent_capabilities: acp::AgentCapabilities,
    // Mode that new sessions are switched into, when the agent advertises it.
    default_mode: Option<acp::SessionModeId>,
    root_dir: PathBuf,
    // NB: Don't move this into the wait_task, since we need to ensure the process is
    // killed on drop (setting kill_on_drop on the command seems to not always work).
    child: smol::process::Child,
    // Task driving the connection's I/O (returned by `ClientSideConnection::new`).
    _io_task: Task<Result<()>>,
    // Task that waits for the child to exit and reports it to every session's thread.
    _wait_task: Task<Result<()>>,
    // Task forwarding the child's stderr to the log.
    _stderr_task: Task<Result<()>>,
}
43
/// Client-side state tracked for one ACP session.
pub struct AcpSession {
    // Weak handle to the `AcpThread` created for this session in `new_thread`.
    thread: WeakEntity<AcpThread>,
    // Set by `cancel`; makes the in-flight `prompt` report an "aborted" internal
    // error as a cancelled turn. Reset once that prompt completes.
    suppress_abort_err: bool,
    // Models reported when the session was created, if any; shared with `AcpModelSelector`.
    models: Option<Rc<RefCell<acp::SessionModelState>>>,
    // Modes reported when the session was created, if any; shared with
    // `AcpSessionModes` and updated on `CurrentModeUpdate` notifications.
    session_modes: Option<Rc<RefCell<acp::SessionModeState>>>,
}
50
51pub async fn connect(
52 server_name: SharedString,
53 command: AgentServerCommand,
54 root_dir: &Path,
55 default_mode: Option<acp::SessionModeId>,
56 is_remote: bool,
57 cx: &mut AsyncApp,
58) -> Result<Rc<dyn AgentConnection>> {
59 let conn = AcpConnection::stdio(
60 server_name,
61 command.clone(),
62 root_dir,
63 default_mode,
64 is_remote,
65 cx,
66 )
67 .await?;
68 Ok(Rc::new(conn) as _)
69}
70
/// Oldest ACP protocol version accepted from an agent's `initialize` response;
/// older agents are rejected with [`UnsupportedVersion`].
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
72
73impl AcpConnection {
74 pub async fn stdio(
75 server_name: SharedString,
76 command: AgentServerCommand,
77 root_dir: &Path,
78 default_mode: Option<acp::SessionModeId>,
79 is_remote: bool,
80 cx: &mut AsyncApp,
81 ) -> Result<Self> {
82 let mut child = util::command::new_smol_command(command.path);
83 child
84 .args(command.args.iter().map(|arg| arg.as_str()))
85 .envs(command.env.iter().flatten())
86 .stdin(std::process::Stdio::piped())
87 .stdout(std::process::Stdio::piped())
88 .stderr(std::process::Stdio::piped());
89 if !is_remote {
90 child.current_dir(root_dir);
91 }
92 let mut child = child.spawn()?;
93
94 let stdout = child.stdout.take().context("Failed to take stdout")?;
95 let stdin = child.stdin.take().context("Failed to take stdin")?;
96 let stderr = child.stderr.take().context("Failed to take stderr")?;
97 log::trace!("Spawned (pid: {})", child.id());
98
99 let sessions = Rc::new(RefCell::new(HashMap::default()));
100
101 let client = ClientDelegate {
102 sessions: sessions.clone(),
103 cx: cx.clone(),
104 };
105 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
106 let foreground_executor = cx.foreground_executor().clone();
107 move |fut| {
108 foreground_executor.spawn(fut).detach();
109 }
110 });
111
112 let io_task = cx.background_spawn(io_task);
113
114 let stderr_task = cx.background_spawn(async move {
115 let mut stderr = BufReader::new(stderr);
116 let mut line = String::new();
117 while let Ok(n) = stderr.read_line(&mut line).await
118 && n > 0
119 {
120 log::warn!("agent stderr: {}", &line);
121 line.clear();
122 }
123 Ok(())
124 });
125
126 let wait_task = cx.spawn({
127 let sessions = sessions.clone();
128 let status_fut = child.status();
129 async move |cx| {
130 let status = status_fut.await?;
131
132 for session in sessions.borrow().values() {
133 session
134 .thread
135 .update(cx, |thread, cx| {
136 thread.emit_load_error(LoadError::Exited { status }, cx)
137 })
138 .ok();
139 }
140
141 anyhow::Ok(())
142 }
143 });
144
145 let connection = Rc::new(connection);
146
147 cx.update(|cx| {
148 AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
149 registry.set_active_connection(server_name.clone(), &connection, cx)
150 });
151 })?;
152
153 let response = connection
154 .initialize(acp::InitializeRequest {
155 protocol_version: acp::VERSION,
156 client_capabilities: acp::ClientCapabilities {
157 fs: acp::FileSystemCapability {
158 read_text_file: true,
159 write_text_file: true,
160 meta: None,
161 },
162 terminal: true,
163 meta: None,
164 },
165 meta: None,
166 })
167 .await?;
168
169 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
170 return Err(UnsupportedVersion.into());
171 }
172
173 Ok(Self {
174 auth_methods: response.auth_methods,
175 root_dir: root_dir.to_owned(),
176 connection,
177 server_name,
178 sessions,
179 agent_capabilities: response.agent_capabilities,
180 default_mode,
181 _io_task: io_task,
182 _wait_task: wait_task,
183 _stderr_task: stderr_task,
184 child,
185 })
186 }
187
188 pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
189 &self.agent_capabilities.prompt_capabilities
190 }
191
192 pub fn root_dir(&self) -> &Path {
193 &self.root_dir
194 }
195}
196
197impl Drop for AcpConnection {
198 fn drop(&mut self) {
199 // See the comment on the child field.
200 self.child.kill().log_err();
201 }
202}
203
204impl AgentConnection for AcpConnection {
205 fn new_thread(
206 self: Rc<Self>,
207 project: Entity<Project>,
208 cwd: &Path,
209 cx: &mut App,
210 ) -> Task<Result<Entity<AcpThread>>> {
211 let name = self.server_name.clone();
212 let conn = self.connection.clone();
213 let sessions = self.sessions.clone();
214 let default_mode = self.default_mode.clone();
215 let cwd = cwd.to_path_buf();
216 let context_server_store = project.read(cx).context_server_store().read(cx);
217 let mcp_servers = if project.read(cx).is_local() {
218 context_server_store
219 .configured_server_ids()
220 .iter()
221 .filter_map(|id| {
222 let configuration = context_server_store.configuration_for_server(id)?;
223 let command = configuration.command();
224 Some(acp::McpServer::Stdio {
225 name: id.0.to_string(),
226 command: command.path.clone(),
227 args: command.args.clone(),
228 env: if let Some(env) = command.env.as_ref() {
229 env.iter()
230 .map(|(name, value)| acp::EnvVariable {
231 name: name.clone(),
232 value: value.clone(),
233 meta: None,
234 })
235 .collect()
236 } else {
237 vec![]
238 },
239 })
240 })
241 .collect()
242 } else {
243 // In SSH projects, the external agent is running on the remote
244 // machine, and currently we only run MCP servers on the local
245 // machine. So don't pass any MCP servers to the agent in that case.
246 Vec::new()
247 };
248
249 cx.spawn(async move |cx| {
250 let response = conn
251 .new_session(acp::NewSessionRequest { mcp_servers, cwd, meta: None })
252 .await
253 .map_err(|err| {
254 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
255 let mut error = AuthRequired::new();
256
257 if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
258 error = error.with_description(err.message);
259 }
260
261 anyhow!(error)
262 } else {
263 anyhow!(err)
264 }
265 })?;
266
267 let modes = response.modes.map(|modes| Rc::new(RefCell::new(modes)));
268 let models = response.models.map(|models| Rc::new(RefCell::new(models)));
269
270 if let Some(default_mode) = default_mode {
271 if let Some(modes) = modes.as_ref() {
272 let mut modes_ref = modes.borrow_mut();
273 let has_mode = modes_ref.available_modes.iter().any(|mode| mode.id == default_mode);
274
275 if has_mode {
276 let initial_mode_id = modes_ref.current_mode_id.clone();
277
278 cx.spawn({
279 let default_mode = default_mode.clone();
280 let session_id = response.session_id.clone();
281 let modes = modes.clone();
282 async move |_| {
283 let result = conn.set_session_mode(acp::SetSessionModeRequest {
284 session_id,
285 mode_id: default_mode,
286 meta: None,
287 })
288 .await.log_err();
289
290 if result.is_none() {
291 modes.borrow_mut().current_mode_id = initial_mode_id;
292 }
293 }
294 }).detach();
295
296 modes_ref.current_mode_id = default_mode;
297 } else {
298 let available_modes = modes_ref
299 .available_modes
300 .iter()
301 .map(|mode| format!("- `{}`: {}", mode.id, mode.name))
302 .collect::<Vec<_>>()
303 .join("\n");
304
305 log::warn!(
306 "`{default_mode}` is not valid {name} mode. Available options:\n{available_modes}",
307 );
308 }
309 } else {
310 log::warn!(
311 "`{name}` does not support modes, but `default_mode` was set in settings.",
312 );
313 }
314 }
315
316 let session_id = response.session_id;
317 let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
318 let thread = cx.new(|cx| {
319 AcpThread::new(
320 self.server_name.clone(),
321 self.clone(),
322 project,
323 action_log,
324 session_id.clone(),
325 // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
326 watch::Receiver::constant(self.agent_capabilities.prompt_capabilities.clone()),
327 cx,
328 )
329 })?;
330
331
332 let session = AcpSession {
333 thread: thread.downgrade(),
334 suppress_abort_err: false,
335 session_modes: modes,
336 models,
337 };
338 sessions.borrow_mut().insert(session_id, session);
339
340 Ok(thread)
341 })
342 }
343
344 fn auth_methods(&self) -> &[acp::AuthMethod] {
345 &self.auth_methods
346 }
347
348 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
349 let conn = self.connection.clone();
350 cx.foreground_executor().spawn(async move {
351 conn.authenticate(acp::AuthenticateRequest {
352 method_id: method_id.clone(),
353 meta: None,
354 })
355 .await?;
356
357 Ok(())
358 })
359 }
360
361 fn prompt(
362 &self,
363 _id: Option<acp_thread::UserMessageId>,
364 params: acp::PromptRequest,
365 cx: &mut App,
366 ) -> Task<Result<acp::PromptResponse>> {
367 let conn = self.connection.clone();
368 let sessions = self.sessions.clone();
369 let session_id = params.session_id.clone();
370 cx.foreground_executor().spawn(async move {
371 let result = conn.prompt(params).await;
372
373 let mut suppress_abort_err = false;
374
375 if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
376 suppress_abort_err = session.suppress_abort_err;
377 session.suppress_abort_err = false;
378 }
379
380 match result {
381 Ok(response) => Ok(response),
382 Err(err) => {
383 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
384 return Err(anyhow!(acp::Error::auth_required()));
385 }
386
387 if err.code != ErrorCode::INTERNAL_ERROR.code {
388 anyhow::bail!(err)
389 }
390
391 let Some(data) = &err.data else {
392 anyhow::bail!(err)
393 };
394
395 // Temporary workaround until the following PR is generally available:
396 // https://github.com/google-gemini/gemini-cli/pull/6656
397
398 #[derive(Deserialize)]
399 #[serde(deny_unknown_fields)]
400 struct ErrorDetails {
401 details: Box<str>,
402 }
403
404 match serde_json::from_value(data.clone()) {
405 Ok(ErrorDetails { details }) => {
406 if suppress_abort_err
407 && (details.contains("This operation was aborted")
408 || details.contains("The user aborted a request"))
409 {
410 Ok(acp::PromptResponse {
411 stop_reason: acp::StopReason::Cancelled,
412 meta: None,
413 })
414 } else {
415 Err(anyhow!(details))
416 }
417 }
418 Err(_) => Err(anyhow!(err)),
419 }
420 }
421 }
422 })
423 }
424
425 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
426 if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
427 session.suppress_abort_err = true;
428 }
429 let conn = self.connection.clone();
430 let params = acp::CancelNotification {
431 session_id: session_id.clone(),
432 meta: None,
433 };
434 cx.foreground_executor()
435 .spawn(async move { conn.cancel(params).await })
436 .detach();
437 }
438
439 fn session_modes(
440 &self,
441 session_id: &acp::SessionId,
442 _cx: &App,
443 ) -> Option<Rc<dyn acp_thread::AgentSessionModes>> {
444 let sessions = self.sessions.clone();
445 let sessions_ref = sessions.borrow();
446 let Some(session) = sessions_ref.get(session_id) else {
447 return None;
448 };
449
450 if let Some(modes) = session.session_modes.as_ref() {
451 Some(Rc::new(AcpSessionModes {
452 connection: self.connection.clone(),
453 session_id: session_id.clone(),
454 state: modes.clone(),
455 }) as _)
456 } else {
457 None
458 }
459 }
460
461 fn model_selector(
462 &self,
463 session_id: &acp::SessionId,
464 ) -> Option<Rc<dyn acp_thread::AgentModelSelector>> {
465 let sessions = self.sessions.clone();
466 let sessions_ref = sessions.borrow();
467 let Some(session) = sessions_ref.get(session_id) else {
468 return None;
469 };
470
471 if let Some(models) = session.models.as_ref() {
472 Some(Rc::new(AcpModelSelector::new(
473 session_id.clone(),
474 self.connection.clone(),
475 models.clone(),
476 )) as _)
477 } else {
478 None
479 }
480 }
481
482 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
483 self
484 }
485}
486
/// [`acp_thread::AgentSessionModes`] handle for one session, backed by the mode
/// state stored in that session's `AcpSession::session_modes`.
struct AcpSessionModes {
    session_id: acp::SessionId,
    connection: Rc<acp::ClientSideConnection>,
    // Same `Rc` as the session's `session_modes`, so updates are visible through both.
    state: Rc<RefCell<acp::SessionModeState>>,
}
492
493impl acp_thread::AgentSessionModes for AcpSessionModes {
494 fn current_mode(&self) -> acp::SessionModeId {
495 self.state.borrow().current_mode_id.clone()
496 }
497
498 fn all_modes(&self) -> Vec<acp::SessionMode> {
499 self.state.borrow().available_modes.clone()
500 }
501
502 fn set_mode(&self, mode_id: acp::SessionModeId, cx: &mut App) -> Task<Result<()>> {
503 let connection = self.connection.clone();
504 let session_id = self.session_id.clone();
505 let old_mode_id;
506 {
507 let mut state = self.state.borrow_mut();
508 old_mode_id = state.current_mode_id.clone();
509 state.current_mode_id = mode_id.clone();
510 };
511 let state = self.state.clone();
512 cx.foreground_executor().spawn(async move {
513 let result = connection
514 .set_session_mode(acp::SetSessionModeRequest {
515 session_id,
516 mode_id,
517 meta: None,
518 })
519 .await;
520
521 if result.is_err() {
522 state.borrow_mut().current_mode_id = old_mode_id;
523 }
524
525 result?;
526
527 Ok(())
528 })
529 }
530}
531
/// [`acp_thread::AgentModelSelector`] for one session, backed by the model
/// state stored in that session's `AcpSession::models`.
struct AcpModelSelector {
    session_id: acp::SessionId,
    connection: Rc<acp::ClientSideConnection>,
    // Same `Rc` as the session's `models`, so updates are visible through both.
    state: Rc<RefCell<acp::SessionModelState>>,
}
537
538impl AcpModelSelector {
539 fn new(
540 session_id: acp::SessionId,
541 connection: Rc<acp::ClientSideConnection>,
542 state: Rc<RefCell<acp::SessionModelState>>,
543 ) -> Self {
544 Self {
545 session_id,
546 connection,
547 state,
548 }
549 }
550}
551
552impl acp_thread::AgentModelSelector for AcpModelSelector {
553 fn list_models(&self, _cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
554 Task::ready(Ok(acp_thread::AgentModelList::Flat(
555 self.state
556 .borrow()
557 .available_models
558 .clone()
559 .into_iter()
560 .map(acp_thread::AgentModelInfo::from)
561 .collect(),
562 )))
563 }
564
565 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
566 let connection = self.connection.clone();
567 let session_id = self.session_id.clone();
568 let old_model_id;
569 {
570 let mut state = self.state.borrow_mut();
571 old_model_id = state.current_model_id.clone();
572 state.current_model_id = model_id.clone();
573 };
574 let state = self.state.clone();
575 cx.foreground_executor().spawn(async move {
576 let result = connection
577 .set_session_model(acp::SetSessionModelRequest {
578 session_id,
579 model_id,
580 meta: None,
581 })
582 .await;
583
584 if result.is_err() {
585 state.borrow_mut().current_model_id = old_model_id;
586 }
587
588 result?;
589
590 Ok(())
591 })
592 }
593
594 fn selected_model(&self, _cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
595 let state = self.state.borrow();
596 Task::ready(
597 state
598 .available_models
599 .iter()
600 .find(|m| m.model_id == state.current_model_id)
601 .cloned()
602 .map(acp_thread::AgentModelInfo::from)
603 .ok_or_else(|| anyhow::anyhow!("Model not found")),
604 )
605 }
606}
607
/// Serves agent-to-client ACP requests and notifications by routing them to
/// the [`AcpThread`] of the session they target.
struct ClientDelegate {
    // Session map shared with the owning `AcpConnection`; looked up by session id.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    cx: AsyncApp,
}
612
613#[async_trait::async_trait(?Send)]
614impl acp::Client for ClientDelegate {
615 async fn request_permission(
616 &self,
617 arguments: acp::RequestPermissionRequest,
618 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
619 let respect_always_allow_setting;
620 let thread;
621 {
622 let sessions_ref = self.sessions.borrow();
623 let session = sessions_ref
624 .get(&arguments.session_id)
625 .context("Failed to get session")?;
626 respect_always_allow_setting = session.session_modes.is_none();
627 thread = session.thread.clone();
628 }
629
630 let cx = &mut self.cx.clone();
631
632 let task = thread.update(cx, |thread, cx| {
633 thread.request_tool_call_authorization(
634 arguments.tool_call,
635 arguments.options,
636 respect_always_allow_setting,
637 cx,
638 )
639 })??;
640
641 let outcome = task.await;
642
643 Ok(acp::RequestPermissionResponse {
644 outcome,
645 meta: None,
646 })
647 }
648
649 async fn write_text_file(
650 &self,
651 arguments: acp::WriteTextFileRequest,
652 ) -> Result<acp::WriteTextFileResponse, acp::Error> {
653 let cx = &mut self.cx.clone();
654 let task = self
655 .session_thread(&arguments.session_id)?
656 .update(cx, |thread, cx| {
657 thread.write_text_file(arguments.path, arguments.content, cx)
658 })?;
659
660 task.await?;
661
662 Ok(Default::default())
663 }
664
665 async fn read_text_file(
666 &self,
667 arguments: acp::ReadTextFileRequest,
668 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
669 let task = self.session_thread(&arguments.session_id)?.update(
670 &mut self.cx.clone(),
671 |thread, cx| {
672 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
673 },
674 )?;
675
676 let content = task.await?;
677
678 Ok(acp::ReadTextFileResponse {
679 content,
680 meta: None,
681 })
682 }
683
684 async fn session_notification(
685 &self,
686 notification: acp::SessionNotification,
687 ) -> Result<(), acp::Error> {
688 let sessions = self.sessions.borrow();
689 let session = sessions
690 .get(¬ification.session_id)
691 .context("Failed to get session")?;
692
693 if let acp::SessionUpdate::CurrentModeUpdate { current_mode_id } = ¬ification.update {
694 if let Some(session_modes) = &session.session_modes {
695 session_modes.borrow_mut().current_mode_id = current_mode_id.clone();
696 } else {
697 log::error!(
698 "Got a `CurrentModeUpdate` notification, but they agent didn't specify `modes` during setting setup."
699 );
700 }
701 }
702
703 session.thread.update(&mut self.cx.clone(), |thread, cx| {
704 thread.handle_session_update(notification.update, cx)
705 })??;
706
707 Ok(())
708 }
709
710 async fn create_terminal(
711 &self,
712 args: acp::CreateTerminalRequest,
713 ) -> Result<acp::CreateTerminalResponse, acp::Error> {
714 let terminal = self
715 .session_thread(&args.session_id)?
716 .update(&mut self.cx.clone(), |thread, cx| {
717 thread.create_terminal(
718 args.command,
719 args.args,
720 args.env,
721 args.cwd,
722 args.output_byte_limit,
723 cx,
724 )
725 })?
726 .await?;
727 Ok(
728 terminal.read_with(&self.cx, |terminal, _| acp::CreateTerminalResponse {
729 terminal_id: terminal.id().clone(),
730 meta: None,
731 })?,
732 )
733 }
734
735 async fn kill_terminal_command(
736 &self,
737 args: acp::KillTerminalCommandRequest,
738 ) -> Result<acp::KillTerminalCommandResponse, acp::Error> {
739 self.session_thread(&args.session_id)?
740 .update(&mut self.cx.clone(), |thread, cx| {
741 thread.kill_terminal(args.terminal_id, cx)
742 })??;
743
744 Ok(Default::default())
745 }
746
747 async fn ext_method(&self, _args: acp::ExtRequest) -> Result<acp::ExtResponse, acp::Error> {
748 Err(acp::Error::method_not_found())
749 }
750
751 async fn ext_notification(&self, _args: acp::ExtNotification) -> Result<(), acp::Error> {
752 Err(acp::Error::method_not_found())
753 }
754
755 async fn release_terminal(
756 &self,
757 args: acp::ReleaseTerminalRequest,
758 ) -> Result<acp::ReleaseTerminalResponse, acp::Error> {
759 self.session_thread(&args.session_id)?
760 .update(&mut self.cx.clone(), |thread, cx| {
761 thread.release_terminal(args.terminal_id, cx)
762 })??;
763
764 Ok(Default::default())
765 }
766
767 async fn terminal_output(
768 &self,
769 args: acp::TerminalOutputRequest,
770 ) -> Result<acp::TerminalOutputResponse, acp::Error> {
771 self.session_thread(&args.session_id)?
772 .read_with(&mut self.cx.clone(), |thread, cx| {
773 let out = thread
774 .terminal(args.terminal_id)?
775 .read(cx)
776 .current_output(cx);
777
778 Ok(out)
779 })?
780 }
781
782 async fn wait_for_terminal_exit(
783 &self,
784 args: acp::WaitForTerminalExitRequest,
785 ) -> Result<acp::WaitForTerminalExitResponse, acp::Error> {
786 let exit_status = self
787 .session_thread(&args.session_id)?
788 .update(&mut self.cx.clone(), |thread, cx| {
789 anyhow::Ok(thread.terminal(args.terminal_id)?.read(cx).wait_for_exit())
790 })??
791 .await;
792
793 Ok(acp::WaitForTerminalExitResponse {
794 exit_status,
795 meta: None,
796 })
797 }
798}
799
800impl ClientDelegate {
801 fn session_thread(&self, session_id: &acp::SessionId) -> Result<WeakEntity<AcpThread>> {
802 let sessions = self.sessions.borrow();
803 sessions
804 .get(session_id)
805 .context("Failed to get session")
806 .map(|session| session.thread.clone())
807 }
808}