acp.rs

  1use acp_thread::AgentConnection;
  2use acp_tools::AcpConnectionRegistry;
  3use action_log::ActionLog;
  4use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
  5use anyhow::anyhow;
  6use collections::HashMap;
  7use futures::AsyncBufReadExt as _;
  8use futures::io::BufReader;
  9use project::Project;
 10use project::agent_server_store::AgentServerCommand;
 11use serde::Deserialize;
 12use util::ResultExt as _;
 13
 14use std::path::PathBuf;
 15use std::{any::Any, cell::RefCell};
 16use std::{path::Path, rc::Rc, sync::Arc};
 17use thiserror::Error;
 18
 19use anyhow::{Context as _, Result};
 20use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
 21
 22use acp_thread::{AcpThread, AuthRequired, LoadError};
 23
 24#[derive(Debug, Error)]
 25#[error("Unsupported version")]
 26pub struct UnsupportedVersion;
 27
/// A connection to an external agent process that speaks ACP over its
/// stdin/stdout pipes.
///
/// Fields are dropped in declaration order, so do not reorder them casually.
pub struct AcpConnection {
    /// Name of the agent server; used when registering the connection and
    /// when creating threads.
    server_name: SharedString,
    /// Client side of the protocol connection to the agent.
    connection: Rc<acp::ClientSideConnection>,
    /// Sessions created through this connection, keyed by session id. Shared
    /// with the client delegate and the process wait task.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods reported by the agent in its `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Capabilities reported by the agent in its `initialize` response.
    agent_capabilities: acp::AgentCapabilities,
    /// Mode to switch new sessions to, when the agent offers it.
    default_mode: Option<acp::SessionModeId>,
    /// Root directory the connection was created for.
    root_dir: PathBuf,
    // NB: Don't move this into the wait_task, since we need to ensure the process is
    // killed on drop (setting kill_on_drop on the command seems to not always work).
    child: smol::process::Child,
    /// Background task driving the protocol I/O.
    _io_task: Task<Result<()>>,
    /// Task that waits for the child to exit and notifies session threads.
    _wait_task: Task<Result<()>>,
    /// Background task forwarding the child's stderr to the log.
    _stderr_task: Task<Result<()>>,
}
 43
 44pub struct AcpSession {
 45    thread: WeakEntity<AcpThread>,
 46    suppress_abort_err: bool,
 47    session_modes: Option<Rc<RefCell<acp::SessionModeState>>>,
 48}
 49
 50pub async fn connect(
 51    server_name: SharedString,
 52    command: AgentServerCommand,
 53    root_dir: &Path,
 54    default_mode: Option<acp::SessionModeId>,
 55    is_remote: bool,
 56    cx: &mut AsyncApp,
 57) -> Result<Rc<dyn AgentConnection>> {
 58    let conn = AcpConnection::stdio(
 59        server_name,
 60        command.clone(),
 61        root_dir,
 62        default_mode,
 63        is_remote,
 64        cx,
 65    )
 66    .await?;
 67    Ok(Rc::new(conn) as _)
 68}
 69
/// Oldest protocol version this client accepts; `AcpConnection::stdio`
/// rejects agents that report anything lower with `UnsupportedVersion`.
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
 71
 72impl AcpConnection {
 73    pub async fn stdio(
 74        server_name: SharedString,
 75        command: AgentServerCommand,
 76        root_dir: &Path,
 77        default_mode: Option<acp::SessionModeId>,
 78        is_remote: bool,
 79        cx: &mut AsyncApp,
 80    ) -> Result<Self> {
 81        let mut child = util::command::new_smol_command(command.path);
 82        child
 83            .args(command.args.iter().map(|arg| arg.as_str()))
 84            .envs(command.env.iter().flatten())
 85            .stdin(std::process::Stdio::piped())
 86            .stdout(std::process::Stdio::piped())
 87            .stderr(std::process::Stdio::piped());
 88        if !is_remote {
 89            child.current_dir(root_dir);
 90        }
 91        let mut child = child.spawn()?;
 92
 93        let stdout = child.stdout.take().context("Failed to take stdout")?;
 94        let stdin = child.stdin.take().context("Failed to take stdin")?;
 95        let stderr = child.stderr.take().context("Failed to take stderr")?;
 96        log::trace!("Spawned (pid: {})", child.id());
 97
 98        let sessions = Rc::new(RefCell::new(HashMap::default()));
 99
100        let client = ClientDelegate {
101            sessions: sessions.clone(),
102            cx: cx.clone(),
103        };
104        let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
105            let foreground_executor = cx.foreground_executor().clone();
106            move |fut| {
107                foreground_executor.spawn(fut).detach();
108            }
109        });
110
111        let io_task = cx.background_spawn(io_task);
112
113        let stderr_task = cx.background_spawn(async move {
114            let mut stderr = BufReader::new(stderr);
115            let mut line = String::new();
116            while let Ok(n) = stderr.read_line(&mut line).await
117                && n > 0
118            {
119                log::warn!("agent stderr: {}", &line);
120                line.clear();
121            }
122            Ok(())
123        });
124
125        let wait_task = cx.spawn({
126            let sessions = sessions.clone();
127            let status_fut = child.status();
128            async move |cx| {
129                let status = status_fut.await?;
130
131                for session in sessions.borrow().values() {
132                    session
133                        .thread
134                        .update(cx, |thread, cx| {
135                            thread.emit_load_error(LoadError::Exited { status }, cx)
136                        })
137                        .ok();
138                }
139
140                anyhow::Ok(())
141            }
142        });
143
144        let connection = Rc::new(connection);
145
146        cx.update(|cx| {
147            AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
148                registry.set_active_connection(server_name.clone(), &connection, cx)
149            });
150        })?;
151
152        let response = connection
153            .initialize(acp::InitializeRequest {
154                protocol_version: acp::VERSION,
155                client_capabilities: acp::ClientCapabilities {
156                    fs: acp::FileSystemCapability {
157                        read_text_file: true,
158                        write_text_file: true,
159                        meta: None,
160                    },
161                    terminal: true,
162                    meta: None,
163                },
164                meta: None,
165            })
166            .await?;
167
168        if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
169            return Err(UnsupportedVersion.into());
170        }
171
172        Ok(Self {
173            auth_methods: response.auth_methods,
174            root_dir: root_dir.to_owned(),
175            connection,
176            server_name,
177            sessions,
178            agent_capabilities: response.agent_capabilities,
179            default_mode,
180            _io_task: io_task,
181            _wait_task: wait_task,
182            _stderr_task: stderr_task,
183            child,
184        })
185    }
186
187    pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
188        &self.agent_capabilities.prompt_capabilities
189    }
190
191    pub fn root_dir(&self) -> &Path {
192        &self.root_dir
193    }
194}
195
196impl Drop for AcpConnection {
197    fn drop(&mut self) {
198        // See the comment on the child field.
199        self.child.kill().log_err();
200    }
201}
202
203impl AgentConnection for AcpConnection {
204    fn new_thread(
205        self: Rc<Self>,
206        project: Entity<Project>,
207        cwd: &Path,
208        cx: &mut App,
209    ) -> Task<Result<Entity<AcpThread>>> {
210        let name = self.server_name.clone();
211        let conn = self.connection.clone();
212        let sessions = self.sessions.clone();
213        let default_mode = self.default_mode.clone();
214        let cwd = cwd.to_path_buf();
215        let context_server_store = project.read(cx).context_server_store().read(cx);
216        let mcp_servers = if project.read(cx).is_local() {
217            context_server_store
218                .configured_server_ids()
219                .iter()
220                .filter_map(|id| {
221                    let configuration = context_server_store.configuration_for_server(id)?;
222                    let command = configuration.command();
223                    Some(acp::McpServer::Stdio {
224                        name: id.0.to_string(),
225                        command: command.path.clone(),
226                        args: command.args.clone(),
227                        env: if let Some(env) = command.env.as_ref() {
228                            env.iter()
229                                .map(|(name, value)| acp::EnvVariable {
230                                    name: name.clone(),
231                                    value: value.clone(),
232                                    meta: None,
233                                })
234                                .collect()
235                        } else {
236                            vec![]
237                        },
238                    })
239                })
240                .collect()
241        } else {
242            // In SSH projects, the external agent is running on the remote
243            // machine, and currently we only run MCP servers on the local
244            // machine. So don't pass any MCP servers to the agent in that case.
245            Vec::new()
246        };
247
248        cx.spawn(async move |cx| {
249            let response = conn
250                .new_session(acp::NewSessionRequest { mcp_servers, cwd, meta: None })
251                .await
252                .map_err(|err| {
253                    if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
254                        let mut error = AuthRequired::new();
255
256                        if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
257                            error = error.with_description(err.message);
258                        }
259
260                        anyhow!(error)
261                    } else {
262                        anyhow!(err)
263                    }
264                })?;
265
266            let modes = response.modes.map(|modes| Rc::new(RefCell::new(modes)));
267
268            if let Some(default_mode) = default_mode {
269                if let Some(modes) = modes.as_ref() {
270                    let mut modes_ref = modes.borrow_mut();
271                    let has_mode = modes_ref.available_modes.iter().any(|mode| mode.id == default_mode);
272
273                    if has_mode {
274                        let initial_mode_id = modes_ref.current_mode_id.clone();
275
276                        cx.spawn({
277                            let default_mode = default_mode.clone();
278                            let session_id = response.session_id.clone();
279                            let modes = modes.clone();
280                            async move |_| {
281                                let result = conn.set_session_mode(acp::SetSessionModeRequest {
282                                    session_id,
283                                    mode_id: default_mode,
284                                    meta: None,
285                                })
286                                .await.log_err();
287
288                                if result.is_none() {
289                                    modes.borrow_mut().current_mode_id = initial_mode_id;
290                                }
291                            }
292                        }).detach();
293
294                        modes_ref.current_mode_id = default_mode;
295                    } else {
296                        let available_modes = modes_ref
297                            .available_modes
298                            .iter()
299                            .map(|mode| format!("- `{}`: {}", mode.id, mode.name))
300                            .collect::<Vec<_>>()
301                            .join("\n");
302
303                        log::warn!(
304                            "`{default_mode}` is not valid {name} mode. Available options:\n{available_modes}",
305                        );
306                    }
307                } else {
308                    log::warn!(
309                        "`{name}` does not support modes, but `default_mode` was set in settings.",
310                    );
311                }
312            }
313
314            let session_id = response.session_id;
315            let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
316            let thread = cx.new(|cx| {
317                AcpThread::new(
318                    self.server_name.clone(),
319                    self.clone(),
320                    project,
321                    action_log,
322                    session_id.clone(),
323                    // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
324                    watch::Receiver::constant(self.agent_capabilities.prompt_capabilities.clone()),
325                    cx,
326                )
327            })?;
328
329            let session = AcpSession {
330                thread: thread.downgrade(),
331                suppress_abort_err: false,
332                session_modes: modes
333            };
334            sessions.borrow_mut().insert(session_id, session);
335
336            Ok(thread)
337        })
338    }
339
340    fn auth_methods(&self) -> &[acp::AuthMethod] {
341        &self.auth_methods
342    }
343
344    fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
345        let conn = self.connection.clone();
346        cx.foreground_executor().spawn(async move {
347            conn.authenticate(acp::AuthenticateRequest {
348                method_id: method_id.clone(),
349                meta: None,
350            })
351            .await?;
352
353            Ok(())
354        })
355    }
356
357    fn prompt(
358        &self,
359        _id: Option<acp_thread::UserMessageId>,
360        params: acp::PromptRequest,
361        cx: &mut App,
362    ) -> Task<Result<acp::PromptResponse>> {
363        let conn = self.connection.clone();
364        let sessions = self.sessions.clone();
365        let session_id = params.session_id.clone();
366        cx.foreground_executor().spawn(async move {
367            let result = conn.prompt(params).await;
368
369            let mut suppress_abort_err = false;
370
371            if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
372                suppress_abort_err = session.suppress_abort_err;
373                session.suppress_abort_err = false;
374            }
375
376            match result {
377                Ok(response) => Ok(response),
378                Err(err) => {
379                    if err.code != ErrorCode::INTERNAL_ERROR.code {
380                        anyhow::bail!(err)
381                    }
382
383                    let Some(data) = &err.data else {
384                        anyhow::bail!(err)
385                    };
386
387                    // Temporary workaround until the following PR is generally available:
388                    // https://github.com/google-gemini/gemini-cli/pull/6656
389
390                    #[derive(Deserialize)]
391                    #[serde(deny_unknown_fields)]
392                    struct ErrorDetails {
393                        details: Box<str>,
394                    }
395
396                    match serde_json::from_value(data.clone()) {
397                        Ok(ErrorDetails { details }) => {
398                            if suppress_abort_err
399                                && (details.contains("This operation was aborted")
400                                    || details.contains("The user aborted a request"))
401                            {
402                                Ok(acp::PromptResponse {
403                                    stop_reason: acp::StopReason::Cancelled,
404                                    meta: None,
405                                })
406                            } else {
407                                Err(anyhow!(details))
408                            }
409                        }
410                        Err(_) => Err(anyhow!(err)),
411                    }
412                }
413            }
414        })
415    }
416
417    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
418        if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
419            session.suppress_abort_err = true;
420        }
421        let conn = self.connection.clone();
422        let params = acp::CancelNotification {
423            session_id: session_id.clone(),
424            meta: None,
425        };
426        cx.foreground_executor()
427            .spawn(async move { conn.cancel(params).await })
428            .detach();
429    }
430
431    fn session_modes(
432        &self,
433        session_id: &acp::SessionId,
434        _cx: &App,
435    ) -> Option<Rc<dyn acp_thread::AgentSessionModes>> {
436        let sessions = self.sessions.clone();
437        let sessions_ref = sessions.borrow();
438        let Some(session) = sessions_ref.get(session_id) else {
439            return None;
440        };
441
442        if let Some(modes) = session.session_modes.as_ref() {
443            Some(Rc::new(AcpSessionModes {
444                connection: self.connection.clone(),
445                session_id: session_id.clone(),
446                state: modes.clone(),
447            }) as _)
448        } else {
449            None
450        }
451    }
452
453    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
454        self
455    }
456}
457
458struct AcpSessionModes {
459    session_id: acp::SessionId,
460    connection: Rc<acp::ClientSideConnection>,
461    state: Rc<RefCell<acp::SessionModeState>>,
462}
463
464impl acp_thread::AgentSessionModes for AcpSessionModes {
465    fn current_mode(&self) -> acp::SessionModeId {
466        self.state.borrow().current_mode_id.clone()
467    }
468
469    fn all_modes(&self) -> Vec<acp::SessionMode> {
470        self.state.borrow().available_modes.clone()
471    }
472
473    fn set_mode(&self, mode_id: acp::SessionModeId, cx: &mut App) -> Task<Result<()>> {
474        let connection = self.connection.clone();
475        let session_id = self.session_id.clone();
476        let old_mode_id;
477        {
478            let mut state = self.state.borrow_mut();
479            old_mode_id = state.current_mode_id.clone();
480            state.current_mode_id = mode_id.clone();
481        };
482        let state = self.state.clone();
483        cx.foreground_executor().spawn(async move {
484            let result = connection
485                .set_session_mode(acp::SetSessionModeRequest {
486                    session_id,
487                    mode_id,
488                    meta: None,
489                })
490                .await;
491
492            if result.is_err() {
493                state.borrow_mut().current_mode_id = old_mode_id;
494            }
495
496            result?;
497
498            Ok(())
499        })
500    }
501}
502
503struct ClientDelegate {
504    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
505    cx: AsyncApp,
506}
507
508impl acp::Client for ClientDelegate {
509    async fn request_permission(
510        &self,
511        arguments: acp::RequestPermissionRequest,
512    ) -> Result<acp::RequestPermissionResponse, acp::Error> {
513        let respect_always_allow_setting;
514        let thread;
515        {
516            let sessions_ref = self.sessions.borrow();
517            let session = sessions_ref
518                .get(&arguments.session_id)
519                .context("Failed to get session")?;
520            respect_always_allow_setting = session.session_modes.is_none();
521            thread = session.thread.clone();
522        }
523
524        let cx = &mut self.cx.clone();
525
526        let task = thread.update(cx, |thread, cx| {
527            thread.request_tool_call_authorization(
528                arguments.tool_call,
529                arguments.options,
530                respect_always_allow_setting,
531                cx,
532            )
533        })??;
534
535        let outcome = task.await;
536
537        Ok(acp::RequestPermissionResponse {
538            outcome,
539            meta: None,
540        })
541    }
542
543    async fn write_text_file(
544        &self,
545        arguments: acp::WriteTextFileRequest,
546    ) -> Result<acp::WriteTextFileResponse, acp::Error> {
547        let cx = &mut self.cx.clone();
548        let task = self
549            .session_thread(&arguments.session_id)?
550            .update(cx, |thread, cx| {
551                thread.write_text_file(arguments.path, arguments.content, cx)
552            })?;
553
554        task.await?;
555
556        Ok(Default::default())
557    }
558
559    async fn read_text_file(
560        &self,
561        arguments: acp::ReadTextFileRequest,
562    ) -> Result<acp::ReadTextFileResponse, acp::Error> {
563        let task = self.session_thread(&arguments.session_id)?.update(
564            &mut self.cx.clone(),
565            |thread, cx| {
566                thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
567            },
568        )?;
569
570        let content = task.await?;
571
572        Ok(acp::ReadTextFileResponse {
573            content,
574            meta: None,
575        })
576    }
577
578    async fn session_notification(
579        &self,
580        notification: acp::SessionNotification,
581    ) -> Result<(), acp::Error> {
582        let sessions = self.sessions.borrow();
583        let session = sessions
584            .get(&notification.session_id)
585            .context("Failed to get session")?;
586
587        if let acp::SessionUpdate::CurrentModeUpdate { current_mode_id } = &notification.update {
588            if let Some(session_modes) = &session.session_modes {
589                session_modes.borrow_mut().current_mode_id = current_mode_id.clone();
590            } else {
591                log::error!(
592                    "Got a `CurrentModeUpdate` notification, but they agent didn't specify `modes` during setting setup."
593                );
594            }
595        }
596
597        session.thread.update(&mut self.cx.clone(), |thread, cx| {
598            thread.handle_session_update(notification.update, cx)
599        })??;
600
601        Ok(())
602    }
603
604    async fn create_terminal(
605        &self,
606        args: acp::CreateTerminalRequest,
607    ) -> Result<acp::CreateTerminalResponse, acp::Error> {
608        let terminal = self
609            .session_thread(&args.session_id)?
610            .update(&mut self.cx.clone(), |thread, cx| {
611                thread.create_terminal(
612                    args.command,
613                    args.args,
614                    args.env,
615                    args.cwd,
616                    args.output_byte_limit,
617                    cx,
618                )
619            })?
620            .await?;
621        Ok(
622            terminal.read_with(&self.cx, |terminal, _| acp::CreateTerminalResponse {
623                terminal_id: terminal.id().clone(),
624                meta: None,
625            })?,
626        )
627    }
628
629    async fn kill_terminal_command(
630        &self,
631        args: acp::KillTerminalCommandRequest,
632    ) -> Result<acp::KillTerminalCommandResponse, acp::Error> {
633        self.session_thread(&args.session_id)?
634            .update(&mut self.cx.clone(), |thread, cx| {
635                thread.kill_terminal(args.terminal_id, cx)
636            })??;
637
638        Ok(Default::default())
639    }
640
641    async fn ext_method(
642        &self,
643        _name: Arc<str>,
644        _params: Arc<serde_json::value::RawValue>,
645    ) -> Result<Arc<serde_json::value::RawValue>, acp::Error> {
646        Err(acp::Error::method_not_found())
647    }
648
649    async fn ext_notification(
650        &self,
651        _name: Arc<str>,
652        _params: Arc<serde_json::value::RawValue>,
653    ) -> Result<(), acp::Error> {
654        Err(acp::Error::method_not_found())
655    }
656
657    async fn release_terminal(
658        &self,
659        args: acp::ReleaseTerminalRequest,
660    ) -> Result<acp::ReleaseTerminalResponse, acp::Error> {
661        self.session_thread(&args.session_id)?
662            .update(&mut self.cx.clone(), |thread, cx| {
663                thread.release_terminal(args.terminal_id, cx)
664            })??;
665
666        Ok(Default::default())
667    }
668
669    async fn terminal_output(
670        &self,
671        args: acp::TerminalOutputRequest,
672    ) -> Result<acp::TerminalOutputResponse, acp::Error> {
673        self.session_thread(&args.session_id)?
674            .read_with(&mut self.cx.clone(), |thread, cx| {
675                let out = thread
676                    .terminal(args.terminal_id)?
677                    .read(cx)
678                    .current_output(cx);
679
680                Ok(out)
681            })?
682    }
683
684    async fn wait_for_terminal_exit(
685        &self,
686        args: acp::WaitForTerminalExitRequest,
687    ) -> Result<acp::WaitForTerminalExitResponse, acp::Error> {
688        let exit_status = self
689            .session_thread(&args.session_id)?
690            .update(&mut self.cx.clone(), |thread, cx| {
691                anyhow::Ok(thread.terminal(args.terminal_id)?.read(cx).wait_for_exit())
692            })??
693            .await;
694
695        Ok(acp::WaitForTerminalExitResponse {
696            exit_status,
697            meta: None,
698        })
699    }
700}
701
702impl ClientDelegate {
703    fn session_thread(&self, session_id: &acp::SessionId) -> Result<WeakEntity<AcpThread>> {
704        let sessions = self.sessions.borrow();
705        sessions
706            .get(session_id)
707            .context("Failed to get session")
708            .map(|session| session.thread.clone())
709    }
710}