1use acp_thread::AgentConnection;
2use acp_tools::AcpConnectionRegistry;
3use action_log::ActionLog;
4use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
5use anyhow::anyhow;
6use collections::HashMap;
7use futures::AsyncBufReadExt as _;
8use futures::io::BufReader;
9use project::Project;
10use project::agent_server_store::AgentServerCommand;
11use serde::Deserialize;
12use util::ResultExt as _;
13
14use std::path::PathBuf;
15use std::{any::Any, cell::RefCell};
16use std::{path::Path, rc::Rc};
17use thiserror::Error;
18
19use anyhow::{Context as _, Result};
20use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
21
22use acp_thread::{AcpThread, AuthRequired, LoadError};
23
/// Error returned by [`AcpConnection::stdio`] when the agent's `initialize`
/// response reports a protocol version older than [`MINIMUM_SUPPORTED_VERSION`].
#[derive(Debug, Error)]
#[error("Unsupported version")]
pub struct UnsupportedVersion;
27
/// A connection to an external agent process that speaks the Agent Client
/// Protocol (ACP) over the process's stdin/stdout.
pub struct AcpConnection {
    /// Name of the agent server; used when registering the connection, in log
    /// messages, and when creating threads.
    server_name: SharedString,
    /// The ACP client-side connection used to send requests to the agent.
    connection: Rc<acp::ClientSideConnection>,
    /// Open sessions keyed by ACP session id. Shared with the `ClientDelegate`
    /// that serves requests coming from the agent.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods advertised in the agent's `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Capabilities advertised in the agent's `initialize` response.
    agent_capabilities: acp::AgentCapabilities,
    /// Mode that new sessions are switched to, if the agent advertises it.
    default_mode: Option<acp::SessionModeId>,
    /// Directory the connection was created for; it is the process's working
    /// directory when the project is not remote.
    root_dir: PathBuf,
    // NB: Don't move this into the wait_task, since we need to ensure the process is
    // killed on drop (setting kill_on_drop on the command seems to not always work).
    child: smol::process::Child,
    // The tasks below are never read; they are held only so they stay alive for
    // as long as the connection does.
    _io_task: Task<Result<()>>,
    _wait_task: Task<Result<()>>,
    _stderr_task: Task<Result<()>>,
}
43
/// Per-session state tracked by an [`AcpConnection`].
pub struct AcpSession {
    /// The thread displaying this session (held as a weak handle).
    thread: WeakEntity<AcpThread>,
    /// Set by `cancel`; while true, the next `prompt` result treats the agent's
    /// "aborted" internal error as a cancellation. `prompt` resets it.
    suppress_abort_err: bool,
    /// Model state reported by the agent in its `new_session` response, if any.
    models: Option<Rc<RefCell<acp::SessionModelState>>>,
    /// Mode state reported by the agent in its `new_session` response, if any.
    session_modes: Option<Rc<RefCell<acp::SessionModeState>>>,
}
50
51pub async fn connect(
52 server_name: SharedString,
53 command: AgentServerCommand,
54 root_dir: &Path,
55 default_mode: Option<acp::SessionModeId>,
56 is_remote: bool,
57 cx: &mut AsyncApp,
58) -> Result<Rc<dyn AgentConnection>> {
59 let conn = AcpConnection::stdio(
60 server_name,
61 command.clone(),
62 root_dir,
63 default_mode,
64 is_remote,
65 cx,
66 )
67 .await?;
68 Ok(Rc::new(conn) as _)
69}
70
/// Oldest ACP protocol version accepted from an agent's `initialize` response;
/// older versions are rejected with [`UnsupportedVersion`].
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
72
73impl AcpConnection {
74 pub async fn stdio(
75 server_name: SharedString,
76 command: AgentServerCommand,
77 root_dir: &Path,
78 default_mode: Option<acp::SessionModeId>,
79 is_remote: bool,
80 cx: &mut AsyncApp,
81 ) -> Result<Self> {
82 let mut child = util::command::new_smol_command(command.path);
83 child
84 .args(command.args.iter().map(|arg| arg.as_str()))
85 .envs(command.env.iter().flatten())
86 .stdin(std::process::Stdio::piped())
87 .stdout(std::process::Stdio::piped())
88 .stderr(std::process::Stdio::piped());
89 if !is_remote {
90 child.current_dir(root_dir);
91 }
92 let mut child = child.spawn()?;
93
94 let stdout = child.stdout.take().context("Failed to take stdout")?;
95 let stdin = child.stdin.take().context("Failed to take stdin")?;
96 let stderr = child.stderr.take().context("Failed to take stderr")?;
97 log::trace!("Spawned (pid: {})", child.id());
98
99 let sessions = Rc::new(RefCell::new(HashMap::default()));
100
101 let client = ClientDelegate {
102 sessions: sessions.clone(),
103 cx: cx.clone(),
104 };
105 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
106 let foreground_executor = cx.foreground_executor().clone();
107 move |fut| {
108 foreground_executor.spawn(fut).detach();
109 }
110 });
111
112 let io_task = cx.background_spawn(io_task);
113
114 let stderr_task = cx.background_spawn(async move {
115 let mut stderr = BufReader::new(stderr);
116 let mut line = String::new();
117 while let Ok(n) = stderr.read_line(&mut line).await
118 && n > 0
119 {
120 log::warn!("agent stderr: {}", &line);
121 line.clear();
122 }
123 Ok(())
124 });
125
126 let wait_task = cx.spawn({
127 let sessions = sessions.clone();
128 let status_fut = child.status();
129 async move |cx| {
130 let status = status_fut.await?;
131
132 for session in sessions.borrow().values() {
133 session
134 .thread
135 .update(cx, |thread, cx| {
136 thread.emit_load_error(LoadError::Exited { status }, cx)
137 })
138 .ok();
139 }
140
141 anyhow::Ok(())
142 }
143 });
144
145 let connection = Rc::new(connection);
146
147 cx.update(|cx| {
148 AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
149 registry.set_active_connection(server_name.clone(), &connection, cx)
150 });
151 })?;
152
153 let response = connection
154 .initialize(acp::InitializeRequest {
155 protocol_version: acp::VERSION,
156 client_capabilities: acp::ClientCapabilities {
157 fs: acp::FileSystemCapability {
158 read_text_file: true,
159 write_text_file: true,
160 meta: None,
161 },
162 terminal: true,
163 meta: None,
164 },
165 meta: None,
166 })
167 .await?;
168
169 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
170 return Err(UnsupportedVersion.into());
171 }
172
173 Ok(Self {
174 auth_methods: response.auth_methods,
175 root_dir: root_dir.to_owned(),
176 connection,
177 server_name,
178 sessions,
179 agent_capabilities: response.agent_capabilities,
180 default_mode,
181 _io_task: io_task,
182 _wait_task: wait_task,
183 _stderr_task: stderr_task,
184 child,
185 })
186 }
187
188 pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
189 &self.agent_capabilities.prompt_capabilities
190 }
191
192 pub fn root_dir(&self) -> &Path {
193 &self.root_dir
194 }
195}
196
197impl Drop for AcpConnection {
198 fn drop(&mut self) {
199 // See the comment on the child field.
200 self.child.kill().log_err();
201 }
202}
203
204impl AgentConnection for AcpConnection {
205 fn new_thread(
206 self: Rc<Self>,
207 project: Entity<Project>,
208 cwd: &Path,
209 cx: &mut App,
210 ) -> Task<Result<Entity<AcpThread>>> {
211 let name = self.server_name.clone();
212 let conn = self.connection.clone();
213 let sessions = self.sessions.clone();
214 let default_mode = self.default_mode.clone();
215 let cwd = cwd.to_path_buf();
216 let context_server_store = project.read(cx).context_server_store().read(cx);
217 let mcp_servers = if project.read(cx).is_local() {
218 context_server_store
219 .configured_server_ids()
220 .iter()
221 .filter_map(|id| {
222 let configuration = context_server_store.configuration_for_server(id)?;
223 let command = configuration.command();
224 Some(acp::McpServer::Stdio {
225 name: id.0.to_string(),
226 command: command.path.clone(),
227 args: command.args.clone(),
228 env: if let Some(env) = command.env.as_ref() {
229 env.iter()
230 .map(|(name, value)| acp::EnvVariable {
231 name: name.clone(),
232 value: value.clone(),
233 meta: None,
234 })
235 .collect()
236 } else {
237 vec![]
238 },
239 })
240 })
241 .collect()
242 } else {
243 // In SSH projects, the external agent is running on the remote
244 // machine, and currently we only run MCP servers on the local
245 // machine. So don't pass any MCP servers to the agent in that case.
246 Vec::new()
247 };
248
249 cx.spawn(async move |cx| {
250 let response = conn
251 .new_session(acp::NewSessionRequest { mcp_servers, cwd, meta: None })
252 .await
253 .map_err(|err| {
254 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
255 let mut error = AuthRequired::new();
256
257 if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
258 error = error.with_description(err.message);
259 }
260
261 anyhow!(error)
262 } else {
263 anyhow!(err)
264 }
265 })?;
266
267 let modes = response.modes.map(|modes| Rc::new(RefCell::new(modes)));
268 let models = response.models.map(|models| Rc::new(RefCell::new(models)));
269
270 if let Some(default_mode) = default_mode {
271 if let Some(modes) = modes.as_ref() {
272 let mut modes_ref = modes.borrow_mut();
273 let has_mode = modes_ref.available_modes.iter().any(|mode| mode.id == default_mode);
274
275 if has_mode {
276 let initial_mode_id = modes_ref.current_mode_id.clone();
277
278 cx.spawn({
279 let default_mode = default_mode.clone();
280 let session_id = response.session_id.clone();
281 let modes = modes.clone();
282 async move |_| {
283 let result = conn.set_session_mode(acp::SetSessionModeRequest {
284 session_id,
285 mode_id: default_mode,
286 meta: None,
287 })
288 .await.log_err();
289
290 if result.is_none() {
291 modes.borrow_mut().current_mode_id = initial_mode_id;
292 }
293 }
294 }).detach();
295
296 modes_ref.current_mode_id = default_mode;
297 } else {
298 let available_modes = modes_ref
299 .available_modes
300 .iter()
301 .map(|mode| format!("- `{}`: {}", mode.id, mode.name))
302 .collect::<Vec<_>>()
303 .join("\n");
304
305 log::warn!(
306 "`{default_mode}` is not valid {name} mode. Available options:\n{available_modes}",
307 );
308 }
309 } else {
310 log::warn!(
311 "`{name}` does not support modes, but `default_mode` was set in settings.",
312 );
313 }
314 }
315
316 let session_id = response.session_id;
317 let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
318 let thread = cx.new(|cx| {
319 AcpThread::new(
320 self.server_name.clone(),
321 self.clone(),
322 project,
323 action_log,
324 session_id.clone(),
325 // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
326 watch::Receiver::constant(self.agent_capabilities.prompt_capabilities.clone()),
327 cx,
328 )
329 })?;
330
331
332 let session = AcpSession {
333 thread: thread.downgrade(),
334 suppress_abort_err: false,
335 session_modes: modes,
336 models,
337 };
338 sessions.borrow_mut().insert(session_id, session);
339
340 Ok(thread)
341 })
342 }
343
344 fn auth_methods(&self) -> &[acp::AuthMethod] {
345 &self.auth_methods
346 }
347
348 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
349 let conn = self.connection.clone();
350 cx.foreground_executor().spawn(async move {
351 conn.authenticate(acp::AuthenticateRequest {
352 method_id: method_id.clone(),
353 meta: None,
354 })
355 .await?;
356
357 Ok(())
358 })
359 }
360
361 fn prompt(
362 &self,
363 _id: Option<acp_thread::UserMessageId>,
364 params: acp::PromptRequest,
365 cx: &mut App,
366 ) -> Task<Result<acp::PromptResponse>> {
367 let conn = self.connection.clone();
368 let sessions = self.sessions.clone();
369 let session_id = params.session_id.clone();
370 cx.foreground_executor().spawn(async move {
371 let result = conn.prompt(params).await;
372
373 let mut suppress_abort_err = false;
374
375 if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
376 suppress_abort_err = session.suppress_abort_err;
377 session.suppress_abort_err = false;
378 }
379
380 match result {
381 Ok(response) => Ok(response),
382 Err(err) => {
383 if err.code != ErrorCode::INTERNAL_ERROR.code {
384 anyhow::bail!(err)
385 }
386
387 let Some(data) = &err.data else {
388 anyhow::bail!(err)
389 };
390
391 // Temporary workaround until the following PR is generally available:
392 // https://github.com/google-gemini/gemini-cli/pull/6656
393
394 #[derive(Deserialize)]
395 #[serde(deny_unknown_fields)]
396 struct ErrorDetails {
397 details: Box<str>,
398 }
399
400 match serde_json::from_value(data.clone()) {
401 Ok(ErrorDetails { details }) => {
402 if suppress_abort_err
403 && (details.contains("This operation was aborted")
404 || details.contains("The user aborted a request"))
405 {
406 Ok(acp::PromptResponse {
407 stop_reason: acp::StopReason::Cancelled,
408 meta: None,
409 })
410 } else {
411 Err(anyhow!(details))
412 }
413 }
414 Err(_) => Err(anyhow!(err)),
415 }
416 }
417 }
418 })
419 }
420
421 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
422 if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
423 session.suppress_abort_err = true;
424 }
425 let conn = self.connection.clone();
426 let params = acp::CancelNotification {
427 session_id: session_id.clone(),
428 meta: None,
429 };
430 cx.foreground_executor()
431 .spawn(async move { conn.cancel(params).await })
432 .detach();
433 }
434
435 fn session_modes(
436 &self,
437 session_id: &acp::SessionId,
438 _cx: &App,
439 ) -> Option<Rc<dyn acp_thread::AgentSessionModes>> {
440 let sessions = self.sessions.clone();
441 let sessions_ref = sessions.borrow();
442 let Some(session) = sessions_ref.get(session_id) else {
443 return None;
444 };
445
446 if let Some(modes) = session.session_modes.as_ref() {
447 Some(Rc::new(AcpSessionModes {
448 connection: self.connection.clone(),
449 session_id: session_id.clone(),
450 state: modes.clone(),
451 }) as _)
452 } else {
453 None
454 }
455 }
456
457 fn model_selector(
458 &self,
459 session_id: &acp::SessionId,
460 ) -> Option<Rc<dyn acp_thread::AgentModelSelector>> {
461 let sessions = self.sessions.clone();
462 let sessions_ref = sessions.borrow();
463 let Some(session) = sessions_ref.get(session_id) else {
464 return None;
465 };
466
467 if let Some(models) = session.models.as_ref() {
468 Some(Rc::new(AcpModelSelector::new(
469 session_id.clone(),
470 self.connection.clone(),
471 models.clone(),
472 )) as _)
473 } else {
474 None
475 }
476 }
477
478 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
479 self
480 }
481}
482
/// [`acp_thread::AgentSessionModes`] implementation for a single ACP session.
struct AcpSessionModes {
    /// Session whose mode is read and changed.
    session_id: acp::SessionId,
    /// Connection used to send `set_session_mode` requests.
    connection: Rc<acp::ClientSideConnection>,
    /// Mode state shared with the owning `AcpSession`.
    state: Rc<RefCell<acp::SessionModeState>>,
}
488
489impl acp_thread::AgentSessionModes for AcpSessionModes {
490 fn current_mode(&self) -> acp::SessionModeId {
491 self.state.borrow().current_mode_id.clone()
492 }
493
494 fn all_modes(&self) -> Vec<acp::SessionMode> {
495 self.state.borrow().available_modes.clone()
496 }
497
498 fn set_mode(&self, mode_id: acp::SessionModeId, cx: &mut App) -> Task<Result<()>> {
499 let connection = self.connection.clone();
500 let session_id = self.session_id.clone();
501 let old_mode_id;
502 {
503 let mut state = self.state.borrow_mut();
504 old_mode_id = state.current_mode_id.clone();
505 state.current_mode_id = mode_id.clone();
506 };
507 let state = self.state.clone();
508 cx.foreground_executor().spawn(async move {
509 let result = connection
510 .set_session_mode(acp::SetSessionModeRequest {
511 session_id,
512 mode_id,
513 meta: None,
514 })
515 .await;
516
517 if result.is_err() {
518 state.borrow_mut().current_mode_id = old_mode_id;
519 }
520
521 result?;
522
523 Ok(())
524 })
525 }
526}
527
/// [`acp_thread::AgentModelSelector`] implementation for a single ACP session.
struct AcpModelSelector {
    /// Session whose model is read and changed.
    session_id: acp::SessionId,
    /// Connection used to send `set_session_model` requests.
    connection: Rc<acp::ClientSideConnection>,
    /// Model state shared with the owning `AcpSession`.
    state: Rc<RefCell<acp::SessionModelState>>,
}
533
534impl AcpModelSelector {
535 fn new(
536 session_id: acp::SessionId,
537 connection: Rc<acp::ClientSideConnection>,
538 state: Rc<RefCell<acp::SessionModelState>>,
539 ) -> Self {
540 Self {
541 session_id,
542 connection,
543 state,
544 }
545 }
546}
547
548impl acp_thread::AgentModelSelector for AcpModelSelector {
549 fn list_models(&self, _cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
550 Task::ready(Ok(acp_thread::AgentModelList::Flat(
551 self.state
552 .borrow()
553 .available_models
554 .clone()
555 .into_iter()
556 .map(acp_thread::AgentModelInfo::from)
557 .collect(),
558 )))
559 }
560
561 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
562 let connection = self.connection.clone();
563 let session_id = self.session_id.clone();
564 let old_model_id;
565 {
566 let mut state = self.state.borrow_mut();
567 old_model_id = state.current_model_id.clone();
568 state.current_model_id = model_id.clone();
569 };
570 let state = self.state.clone();
571 cx.foreground_executor().spawn(async move {
572 let result = connection
573 .set_session_model(acp::SetSessionModelRequest {
574 session_id,
575 model_id,
576 meta: None,
577 })
578 .await;
579
580 if result.is_err() {
581 state.borrow_mut().current_model_id = old_model_id;
582 }
583
584 result?;
585
586 Ok(())
587 })
588 }
589
590 fn selected_model(&self, _cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
591 let state = self.state.borrow();
592 Task::ready(
593 state
594 .available_models
595 .iter()
596 .find(|m| m.model_id == state.current_model_id)
597 .cloned()
598 .map(acp_thread::AgentModelInfo::from)
599 .ok_or_else(|| anyhow::anyhow!("Model not found")),
600 )
601 }
602}
603
/// Serves requests and notifications that the agent sends to the client,
/// routing each one to the thread of the session it names.
struct ClientDelegate {
    /// Session map shared with the owning `AcpConnection`.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Async app context used to update threads and terminals.
    cx: AsyncApp,
}
608
609#[async_trait::async_trait(?Send)]
610impl acp::Client for ClientDelegate {
611 async fn request_permission(
612 &self,
613 arguments: acp::RequestPermissionRequest,
614 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
615 let respect_always_allow_setting;
616 let thread;
617 {
618 let sessions_ref = self.sessions.borrow();
619 let session = sessions_ref
620 .get(&arguments.session_id)
621 .context("Failed to get session")?;
622 respect_always_allow_setting = session.session_modes.is_none();
623 thread = session.thread.clone();
624 }
625
626 let cx = &mut self.cx.clone();
627
628 let task = thread.update(cx, |thread, cx| {
629 thread.request_tool_call_authorization(
630 arguments.tool_call,
631 arguments.options,
632 respect_always_allow_setting,
633 cx,
634 )
635 })??;
636
637 let outcome = task.await;
638
639 Ok(acp::RequestPermissionResponse {
640 outcome,
641 meta: None,
642 })
643 }
644
645 async fn write_text_file(
646 &self,
647 arguments: acp::WriteTextFileRequest,
648 ) -> Result<acp::WriteTextFileResponse, acp::Error> {
649 let cx = &mut self.cx.clone();
650 let task = self
651 .session_thread(&arguments.session_id)?
652 .update(cx, |thread, cx| {
653 thread.write_text_file(arguments.path, arguments.content, cx)
654 })?;
655
656 task.await?;
657
658 Ok(Default::default())
659 }
660
661 async fn read_text_file(
662 &self,
663 arguments: acp::ReadTextFileRequest,
664 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
665 let task = self.session_thread(&arguments.session_id)?.update(
666 &mut self.cx.clone(),
667 |thread, cx| {
668 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
669 },
670 )?;
671
672 let content = task.await?;
673
674 Ok(acp::ReadTextFileResponse {
675 content,
676 meta: None,
677 })
678 }
679
680 async fn session_notification(
681 &self,
682 notification: acp::SessionNotification,
683 ) -> Result<(), acp::Error> {
684 let sessions = self.sessions.borrow();
685 let session = sessions
686 .get(¬ification.session_id)
687 .context("Failed to get session")?;
688
689 if let acp::SessionUpdate::CurrentModeUpdate { current_mode_id } = ¬ification.update {
690 if let Some(session_modes) = &session.session_modes {
691 session_modes.borrow_mut().current_mode_id = current_mode_id.clone();
692 } else {
693 log::error!(
694 "Got a `CurrentModeUpdate` notification, but they agent didn't specify `modes` during setting setup."
695 );
696 }
697 }
698
699 session.thread.update(&mut self.cx.clone(), |thread, cx| {
700 thread.handle_session_update(notification.update, cx)
701 })??;
702
703 Ok(())
704 }
705
706 async fn create_terminal(
707 &self,
708 args: acp::CreateTerminalRequest,
709 ) -> Result<acp::CreateTerminalResponse, acp::Error> {
710 let terminal = self
711 .session_thread(&args.session_id)?
712 .update(&mut self.cx.clone(), |thread, cx| {
713 thread.create_terminal(
714 args.command,
715 args.args,
716 args.env,
717 args.cwd,
718 args.output_byte_limit,
719 cx,
720 )
721 })?
722 .await?;
723 Ok(
724 terminal.read_with(&self.cx, |terminal, _| acp::CreateTerminalResponse {
725 terminal_id: terminal.id().clone(),
726 meta: None,
727 })?,
728 )
729 }
730
731 async fn kill_terminal_command(
732 &self,
733 args: acp::KillTerminalCommandRequest,
734 ) -> Result<acp::KillTerminalCommandResponse, acp::Error> {
735 self.session_thread(&args.session_id)?
736 .update(&mut self.cx.clone(), |thread, cx| {
737 thread.kill_terminal(args.terminal_id, cx)
738 })??;
739
740 Ok(Default::default())
741 }
742
743 async fn ext_method(&self, _args: acp::ExtRequest) -> Result<acp::ExtResponse, acp::Error> {
744 Err(acp::Error::method_not_found())
745 }
746
747 async fn ext_notification(&self, _args: acp::ExtNotification) -> Result<(), acp::Error> {
748 Err(acp::Error::method_not_found())
749 }
750
751 async fn release_terminal(
752 &self,
753 args: acp::ReleaseTerminalRequest,
754 ) -> Result<acp::ReleaseTerminalResponse, acp::Error> {
755 self.session_thread(&args.session_id)?
756 .update(&mut self.cx.clone(), |thread, cx| {
757 thread.release_terminal(args.terminal_id, cx)
758 })??;
759
760 Ok(Default::default())
761 }
762
763 async fn terminal_output(
764 &self,
765 args: acp::TerminalOutputRequest,
766 ) -> Result<acp::TerminalOutputResponse, acp::Error> {
767 self.session_thread(&args.session_id)?
768 .read_with(&mut self.cx.clone(), |thread, cx| {
769 let out = thread
770 .terminal(args.terminal_id)?
771 .read(cx)
772 .current_output(cx);
773
774 Ok(out)
775 })?
776 }
777
778 async fn wait_for_terminal_exit(
779 &self,
780 args: acp::WaitForTerminalExitRequest,
781 ) -> Result<acp::WaitForTerminalExitResponse, acp::Error> {
782 let exit_status = self
783 .session_thread(&args.session_id)?
784 .update(&mut self.cx.clone(), |thread, cx| {
785 anyhow::Ok(thread.terminal(args.terminal_id)?.read(cx).wait_for_exit())
786 })??
787 .await;
788
789 Ok(acp::WaitForTerminalExitResponse {
790 exit_status,
791 meta: None,
792 })
793 }
794}
795
796impl ClientDelegate {
797 fn session_thread(&self, session_id: &acp::SessionId) -> Result<WeakEntity<AcpThread>> {
798 let sessions = self.sessions.borrow();
799 sessions
800 .get(session_id)
801 .context("Failed to get session")
802 .map(|session| session.thread.clone())
803 }
804}