acp.rs

  1use acp_thread::AgentConnection;
  2use acp_tools::AcpConnectionRegistry;
  3use action_log::ActionLog;
  4use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
  5use anyhow::anyhow;
  6use collections::HashMap;
  7use futures::AsyncBufReadExt as _;
  8use futures::io::BufReader;
  9use project::Project;
 10use project::agent_server_store::AgentServerCommand;
 11use serde::Deserialize;
 12use util::ResultExt as _;
 13
 14use std::path::PathBuf;
 15use std::{any::Any, cell::RefCell};
 16use std::{path::Path, rc::Rc};
 17use thiserror::Error;
 18
 19use anyhow::{Context as _, Result};
 20use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
 21
 22use acp_thread::{AcpThread, AuthRequired, LoadError};
 23
/// Error returned by `AcpConnection::stdio` when the protocol version in the
/// agent's `initialize` response is below `MINIMUM_SUPPORTED_VERSION`.
#[derive(Debug, Error)]
#[error("Unsupported version")]
pub struct UnsupportedVersion;
 27
/// A connection to an external agent process that speaks ACP over its
/// stdin/stdout.
pub struct AcpConnection {
    /// Name the connection was created with; passed to every `AcpThread` it creates.
    server_name: SharedString,
    /// Client side of the ACP connection, shared with spawned request futures.
    connection: Rc<acp::ClientSideConnection>,
    /// Per-session state keyed by session id; shared with the `ClientDelegate`
    /// and the task that waits for the process to exit.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods reported in the agent's `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Capabilities reported in the agent's `initialize` response.
    agent_capabilities: acp::AgentCapabilities,
    /// Mode that new sessions are switched to when the agent offers it.
    default_mode: Option<acp::SessionModeId>,
    /// Root directory the connection was created for.
    root_dir: PathBuf,
    // NB: Don't move this into the wait_task, since we need to ensure the process is
    // killed on drop (setting kill_on_drop on the command seems to not always work).
    child: smol::process::Child,
    /// Background task running the I/O future returned by
    /// `acp::ClientSideConnection::new`.
    _io_task: Task<Result<()>>,
    /// Task that awaits the child's exit status and then emits
    /// `LoadError::Exited` on every session's thread.
    _wait_task: Task<Result<()>>,
    /// Background task that forwards the agent's stderr lines to the log.
    _stderr_task: Task<Result<()>>,
}
 43
/// State tracked for a single ACP session.
pub struct AcpSession {
    /// Weak handle to the thread that displays this session.
    thread: WeakEntity<AcpThread>,
    /// Set by `cancel`; read and reset by the next `prompt` completion, which
    /// uses it to turn an "aborted" error into a `Cancelled` stop reason.
    suppress_abort_err: bool,
    /// Mode state taken from the `new_session` response, or `None` when the
    /// agent did not report any modes.
    session_modes: Option<Rc<RefCell<acp::SessionModeState>>>,
}
 49
 50pub async fn connect(
 51    server_name: SharedString,
 52    command: AgentServerCommand,
 53    root_dir: &Path,
 54    default_mode: Option<acp::SessionModeId>,
 55    is_remote: bool,
 56    cx: &mut AsyncApp,
 57) -> Result<Rc<dyn AgentConnection>> {
 58    let conn = AcpConnection::stdio(
 59        server_name,
 60        command.clone(),
 61        root_dir,
 62        default_mode,
 63        is_remote,
 64        cx,
 65    )
 66    .await?;
 67    Ok(Rc::new(conn) as _)
 68}
 69
/// Lowest protocol version accepted from an agent's `initialize` response;
/// anything below it is rejected with `UnsupportedVersion`.
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
 71
 72impl AcpConnection {
 73    pub async fn stdio(
 74        server_name: SharedString,
 75        command: AgentServerCommand,
 76        root_dir: &Path,
 77        default_mode: Option<acp::SessionModeId>,
 78        is_remote: bool,
 79        cx: &mut AsyncApp,
 80    ) -> Result<Self> {
 81        let mut child = util::command::new_smol_command(command.path);
 82        child
 83            .args(command.args.iter().map(|arg| arg.as_str()))
 84            .envs(command.env.iter().flatten())
 85            .stdin(std::process::Stdio::piped())
 86            .stdout(std::process::Stdio::piped())
 87            .stderr(std::process::Stdio::piped());
 88        if !is_remote {
 89            child.current_dir(root_dir);
 90        }
 91        let mut child = child.spawn()?;
 92
 93        let stdout = child.stdout.take().context("Failed to take stdout")?;
 94        let stdin = child.stdin.take().context("Failed to take stdin")?;
 95        let stderr = child.stderr.take().context("Failed to take stderr")?;
 96        log::trace!("Spawned (pid: {})", child.id());
 97
 98        let sessions = Rc::new(RefCell::new(HashMap::default()));
 99
100        let client = ClientDelegate {
101            sessions: sessions.clone(),
102            cx: cx.clone(),
103        };
104        let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
105            let foreground_executor = cx.foreground_executor().clone();
106            move |fut| {
107                foreground_executor.spawn(fut).detach();
108            }
109        });
110
111        let io_task = cx.background_spawn(io_task);
112
113        let stderr_task = cx.background_spawn(async move {
114            let mut stderr = BufReader::new(stderr);
115            let mut line = String::new();
116            while let Ok(n) = stderr.read_line(&mut line).await
117                && n > 0
118            {
119                log::warn!("agent stderr: {}", &line);
120                line.clear();
121            }
122            Ok(())
123        });
124
125        let wait_task = cx.spawn({
126            let sessions = sessions.clone();
127            let status_fut = child.status();
128            async move |cx| {
129                let status = status_fut.await?;
130
131                for session in sessions.borrow().values() {
132                    session
133                        .thread
134                        .update(cx, |thread, cx| {
135                            thread.emit_load_error(LoadError::Exited { status }, cx)
136                        })
137                        .ok();
138                }
139
140                anyhow::Ok(())
141            }
142        });
143
144        let connection = Rc::new(connection);
145
146        cx.update(|cx| {
147            AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
148                registry.set_active_connection(server_name.clone(), &connection, cx)
149            });
150        })?;
151
152        let response = connection
153            .initialize(acp::InitializeRequest {
154                protocol_version: acp::VERSION,
155                client_capabilities: acp::ClientCapabilities {
156                    fs: acp::FileSystemCapability {
157                        read_text_file: true,
158                        write_text_file: true,
159                    },
160                    terminal: true,
161                },
162            })
163            .await?;
164
165        if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
166            return Err(UnsupportedVersion.into());
167        }
168
169        Ok(Self {
170            auth_methods: response.auth_methods,
171            root_dir: root_dir.to_owned(),
172            connection,
173            server_name,
174            sessions,
175            agent_capabilities: response.agent_capabilities,
176            default_mode,
177            _io_task: io_task,
178            _wait_task: wait_task,
179            _stderr_task: stderr_task,
180            child,
181        })
182    }
183
    /// Returns the prompt capabilities the agent reported during `initialize`.
    pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
        &self.agent_capabilities.prompt_capabilities
    }
187
    /// Returns the root directory this connection was created with.
    pub fn root_dir(&self) -> &Path {
        &self.root_dir
    }
191}
192
impl Drop for AcpConnection {
    /// Kills the agent child process when the connection is dropped. A failed
    /// kill is passed to `log_err` instead of being propagated.
    fn drop(&mut self) {
        // See the comment on the child field.
        self.child.kill().log_err();
    }
}
199
200impl AgentConnection for AcpConnection {
201    fn new_thread(
202        self: Rc<Self>,
203        project: Entity<Project>,
204        cwd: &Path,
205        cx: &mut App,
206    ) -> Task<Result<Entity<AcpThread>>> {
207        let name = self.server_name.clone();
208        let conn = self.connection.clone();
209        let sessions = self.sessions.clone();
210        let default_mode = self.default_mode.clone();
211        let cwd = cwd.to_path_buf();
212        let context_server_store = project.read(cx).context_server_store().read(cx);
213        let mcp_servers = if project.read(cx).is_local() {
214            context_server_store
215                .configured_server_ids()
216                .iter()
217                .filter_map(|id| {
218                    let configuration = context_server_store.configuration_for_server(id)?;
219                    let command = configuration.command();
220                    Some(acp::McpServer::Stdio {
221                        name: id.0.to_string(),
222                        command: command.path.clone(),
223                        args: command.args.clone(),
224                        env: if let Some(env) = command.env.as_ref() {
225                            env.iter()
226                                .map(|(name, value)| acp::EnvVariable {
227                                    name: name.clone(),
228                                    value: value.clone(),
229                                })
230                                .collect()
231                        } else {
232                            vec![]
233                        },
234                    })
235                })
236                .collect()
237        } else {
238            // In SSH projects, the external agent is running on the remote
239            // machine, and currently we only run MCP servers on the local
240            // machine. So don't pass any MCP servers to the agent in that case.
241            Vec::new()
242        };
243
244        cx.spawn(async move |cx| {
245            let response = conn
246                .new_session(acp::NewSessionRequest { mcp_servers, cwd })
247                .await
248                .map_err(|err| {
249                    if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
250                        let mut error = AuthRequired::new();
251
252                        if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
253                            error = error.with_description(err.message);
254                        }
255
256                        anyhow!(error)
257                    } else {
258                        anyhow!(err)
259                    }
260                })?;
261
262            let modes = response.modes.map(|modes| Rc::new(RefCell::new(modes)));
263
264            if let Some(default_mode) = default_mode {
265                if let Some(modes) = modes.as_ref() {
266                    let mut modes_ref = modes.borrow_mut();
267                    let has_mode = modes_ref.available_modes.iter().any(|mode| mode.id == default_mode);
268
269                    if has_mode {
270                        let initial_mode_id = modes_ref.current_mode_id.clone();
271
272                        cx.spawn({
273                            let default_mode = default_mode.clone();
274                            let session_id = response.session_id.clone();
275                            let modes = modes.clone();
276                            async move |_| {
277                                let result = conn.set_session_mode(acp::SetSessionModeRequest {
278                                    session_id,
279                                    mode_id: default_mode,
280                                })
281                                .await.log_err();
282
283                                if result.is_none() {
284                                    modes.borrow_mut().current_mode_id = initial_mode_id;
285                                }
286                            }
287                        }).detach();
288
289                        modes_ref.current_mode_id = default_mode;
290                    } else {
291                        let available_modes = modes_ref
292                            .available_modes
293                            .iter()
294                            .map(|mode| format!("- `{}`: {}", mode.id, mode.name))
295                            .collect::<Vec<_>>()
296                            .join("\n");
297
298                        log::warn!(
299                            "`{default_mode}` is not valid {name} mode. Available options:\n{available_modes}",
300                        );
301                    }
302                } else {
303                    log::warn!(
304                        "`{name}` does not support modes, but `default_mode` was set in settings.",
305                    );
306                }
307            }
308
309            let session_id = response.session_id;
310            let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
311            let thread = cx.new(|cx| {
312                AcpThread::new(
313                    self.server_name.clone(),
314                    self.clone(),
315                    project,
316                    action_log,
317                    session_id.clone(),
318                    // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
319                    watch::Receiver::constant(self.agent_capabilities.prompt_capabilities),
320                    cx,
321                )
322            })?;
323
324            let session = AcpSession {
325                thread: thread.downgrade(),
326                suppress_abort_err: false,
327                session_modes: modes
328            };
329            sessions.borrow_mut().insert(session_id, session);
330
331            Ok(thread)
332        })
333    }
334
    /// Returns the authentication methods the agent reported during `initialize`.
    fn auth_methods(&self) -> &[acp::AuthMethod] {
        &self.auth_methods
    }
338
339    fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
340        let conn = self.connection.clone();
341        cx.foreground_executor().spawn(async move {
342            let result = conn
343                .authenticate(acp::AuthenticateRequest {
344                    method_id: method_id.clone(),
345                })
346                .await?;
347
348            Ok(result)
349        })
350    }
351
352    fn prompt(
353        &self,
354        _id: Option<acp_thread::UserMessageId>,
355        params: acp::PromptRequest,
356        cx: &mut App,
357    ) -> Task<Result<acp::PromptResponse>> {
358        let conn = self.connection.clone();
359        let sessions = self.sessions.clone();
360        let session_id = params.session_id.clone();
361        cx.foreground_executor().spawn(async move {
362            let result = conn.prompt(params).await;
363
364            let mut suppress_abort_err = false;
365
366            if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
367                suppress_abort_err = session.suppress_abort_err;
368                session.suppress_abort_err = false;
369            }
370
371            match result {
372                Ok(response) => Ok(response),
373                Err(err) => {
374                    if err.code != ErrorCode::INTERNAL_ERROR.code {
375                        anyhow::bail!(err)
376                    }
377
378                    let Some(data) = &err.data else {
379                        anyhow::bail!(err)
380                    };
381
382                    // Temporary workaround until the following PR is generally available:
383                    // https://github.com/google-gemini/gemini-cli/pull/6656
384
385                    #[derive(Deserialize)]
386                    #[serde(deny_unknown_fields)]
387                    struct ErrorDetails {
388                        details: Box<str>,
389                    }
390
391                    match serde_json::from_value(data.clone()) {
392                        Ok(ErrorDetails { details }) => {
393                            if suppress_abort_err
394                                && (details.contains("This operation was aborted")
395                                    || details.contains("The user aborted a request"))
396                            {
397                                Ok(acp::PromptResponse {
398                                    stop_reason: acp::StopReason::Cancelled,
399                                })
400                            } else {
401                                Err(anyhow!(details))
402                            }
403                        }
404                        Err(_) => Err(anyhow!(err)),
405                    }
406                }
407            }
408        })
409    }
410
411    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
412        if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
413            session.suppress_abort_err = true;
414        }
415        let conn = self.connection.clone();
416        let params = acp::CancelNotification {
417            session_id: session_id.clone(),
418        };
419        cx.foreground_executor()
420            .spawn(async move { conn.cancel(params).await })
421            .detach();
422    }
423
424    fn session_modes(
425        &self,
426        session_id: &acp::SessionId,
427        _cx: &App,
428    ) -> Option<Rc<dyn acp_thread::AgentSessionModes>> {
429        let sessions = self.sessions.clone();
430        let sessions_ref = sessions.borrow();
431        let Some(session) = sessions_ref.get(session_id) else {
432            return None;
433        };
434
435        if let Some(modes) = session.session_modes.as_ref() {
436            Some(Rc::new(AcpSessionModes {
437                connection: self.connection.clone(),
438                session_id: session_id.clone(),
439                state: modes.clone(),
440            }) as _)
441        } else {
442            None
443        }
444    }
445
    /// Upcasts the connection to `Rc<dyn Any>`.
    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
        self
    }
449}
450
/// Handle for reading and changing the modes of one session; created by
/// `AcpConnection::session_modes`.
struct AcpSessionModes {
    /// Session whose mode this handle controls.
    session_id: acp::SessionId,
    /// Connection used to send `set_session_mode` requests.
    connection: Rc<acp::ClientSideConnection>,
    /// Mode state shared with the session entry in `AcpConnection::sessions`.
    state: Rc<RefCell<acp::SessionModeState>>,
}
456
457impl acp_thread::AgentSessionModes for AcpSessionModes {
458    fn current_mode(&self) -> acp::SessionModeId {
459        self.state.borrow().current_mode_id.clone()
460    }
461
462    fn all_modes(&self) -> Vec<acp::SessionMode> {
463        self.state.borrow().available_modes.clone()
464    }
465
466    fn set_mode(&self, mode_id: acp::SessionModeId, cx: &mut App) -> Task<Result<()>> {
467        let connection = self.connection.clone();
468        let session_id = self.session_id.clone();
469        let old_mode_id;
470        {
471            let mut state = self.state.borrow_mut();
472            old_mode_id = state.current_mode_id.clone();
473            state.current_mode_id = mode_id.clone();
474        };
475        let state = self.state.clone();
476        cx.foreground_executor().spawn(async move {
477            let result = connection
478                .set_session_mode(acp::SetSessionModeRequest {
479                    session_id,
480                    mode_id,
481                })
482                .await;
483
484            if result.is_err() {
485                state.borrow_mut().current_mode_id = old_mode_id;
486            }
487
488            result?;
489
490            Ok(())
491        })
492    }
493}
494
/// Implements `acp::Client`: handles the requests and notifications that the
/// agent sends to this side of the connection.
struct ClientDelegate {
    /// Session table shared with the owning `AcpConnection`.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// App handle that is cloned whenever a thread entity has to be updated.
    cx: AsyncApp,
}
499
500impl acp::Client for ClientDelegate {
501    async fn request_permission(
502        &self,
503        arguments: acp::RequestPermissionRequest,
504    ) -> Result<acp::RequestPermissionResponse, acp::Error> {
505        let respect_always_allow_setting;
506        let thread;
507        {
508            let sessions_ref = self.sessions.borrow();
509            let session = sessions_ref
510                .get(&arguments.session_id)
511                .context("Failed to get session")?;
512            respect_always_allow_setting = session.session_modes.is_none();
513            thread = session.thread.clone();
514        }
515
516        let cx = &mut self.cx.clone();
517
518        let task = thread.update(cx, |thread, cx| {
519            thread.request_tool_call_authorization(
520                arguments.tool_call,
521                arguments.options,
522                respect_always_allow_setting,
523                cx,
524            )
525        })??;
526
527        let outcome = task.await;
528
529        Ok(acp::RequestPermissionResponse { outcome })
530    }
531
532    async fn write_text_file(
533        &self,
534        arguments: acp::WriteTextFileRequest,
535    ) -> Result<(), acp::Error> {
536        let cx = &mut self.cx.clone();
537        let task = self
538            .session_thread(&arguments.session_id)?
539            .update(cx, |thread, cx| {
540                thread.write_text_file(arguments.path, arguments.content, cx)
541            })?;
542
543        task.await?;
544
545        Ok(())
546    }
547
548    async fn read_text_file(
549        &self,
550        arguments: acp::ReadTextFileRequest,
551    ) -> Result<acp::ReadTextFileResponse, acp::Error> {
552        let task = self.session_thread(&arguments.session_id)?.update(
553            &mut self.cx.clone(),
554            |thread, cx| {
555                thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
556            },
557        )?;
558
559        let content = task.await?;
560
561        Ok(acp::ReadTextFileResponse { content })
562    }
563
564    async fn session_notification(
565        &self,
566        notification: acp::SessionNotification,
567    ) -> Result<(), acp::Error> {
568        let sessions = self.sessions.borrow();
569        let session = sessions
570            .get(&notification.session_id)
571            .context("Failed to get session")?;
572
573        if let acp::SessionUpdate::CurrentModeUpdate { current_mode_id } = &notification.update {
574            if let Some(session_modes) = &session.session_modes {
575                session_modes.borrow_mut().current_mode_id = current_mode_id.clone();
576            } else {
577                log::error!(
578                    "Got a `CurrentModeUpdate` notification, but they agent didn't specify `modes` during setting setup."
579                );
580            }
581        }
582
583        session.thread.update(&mut self.cx.clone(), |thread, cx| {
584            thread.handle_session_update(notification.update, cx)
585        })??;
586
587        Ok(())
588    }
589
590    async fn create_terminal(
591        &self,
592        args: acp::CreateTerminalRequest,
593    ) -> Result<acp::CreateTerminalResponse, acp::Error> {
594        let terminal = self
595            .session_thread(&args.session_id)?
596            .update(&mut self.cx.clone(), |thread, cx| {
597                thread.create_terminal(
598                    args.command,
599                    args.args,
600                    args.env,
601                    args.cwd,
602                    args.output_byte_limit,
603                    cx,
604                )
605            })?
606            .await?;
607        Ok(
608            terminal.read_with(&self.cx, |terminal, _| acp::CreateTerminalResponse {
609                terminal_id: terminal.id().clone(),
610            })?,
611        )
612    }
613
614    async fn kill_terminal(&self, args: acp::KillTerminalRequest) -> Result<(), acp::Error> {
615        self.session_thread(&args.session_id)?
616            .update(&mut self.cx.clone(), |thread, cx| {
617                thread.kill_terminal(args.terminal_id, cx)
618            })??;
619
620        Ok(())
621    }
622
623    async fn release_terminal(&self, args: acp::ReleaseTerminalRequest) -> Result<(), acp::Error> {
624        self.session_thread(&args.session_id)?
625            .update(&mut self.cx.clone(), |thread, cx| {
626                thread.release_terminal(args.terminal_id, cx)
627            })??;
628
629        Ok(())
630    }
631
632    async fn terminal_output(
633        &self,
634        args: acp::TerminalOutputRequest,
635    ) -> Result<acp::TerminalOutputResponse, acp::Error> {
636        self.session_thread(&args.session_id)?
637            .read_with(&mut self.cx.clone(), |thread, cx| {
638                let out = thread
639                    .terminal(args.terminal_id)?
640                    .read(cx)
641                    .current_output(cx);
642
643                Ok(out)
644            })?
645    }
646
647    async fn wait_for_terminal_exit(
648        &self,
649        args: acp::WaitForTerminalExitRequest,
650    ) -> Result<acp::WaitForTerminalExitResponse, acp::Error> {
651        let exit_status = self
652            .session_thread(&args.session_id)?
653            .update(&mut self.cx.clone(), |thread, cx| {
654                anyhow::Ok(thread.terminal(args.terminal_id)?.read(cx).wait_for_exit())
655            })??
656            .await;
657
658        Ok(acp::WaitForTerminalExitResponse { exit_status })
659    }
660}
661
662impl ClientDelegate {
663    fn session_thread(&self, session_id: &acp::SessionId) -> Result<WeakEntity<AcpThread>> {
664        let sessions = self.sessions.borrow();
665        sessions
666            .get(session_id)
667            .context("Failed to get session")
668            .map(|session| session.thread.clone())
669    }
670}