acp.rs

  1use acp_thread::AgentConnection;
  2use acp_tools::AcpConnectionRegistry;
  3use action_log::ActionLog;
  4use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
  5use anyhow::anyhow;
  6use collections::HashMap;
  7use futures::AsyncBufReadExt as _;
  8use futures::io::BufReader;
  9use project::Project;
 10use project::agent_server_store::AgentServerCommand;
 11use serde::Deserialize;
 12use util::ResultExt;
 13
 14use std::path::PathBuf;
 15use std::{any::Any, cell::RefCell};
 16use std::{path::Path, rc::Rc};
 17use thiserror::Error;
 18
 19use anyhow::{Context as _, Result};
 20use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
 21
 22use acp_thread::{AcpThread, AuthRequired, LoadError};
 23
 24#[derive(Debug, Error)]
 25#[error("Unsupported version")]
 26pub struct UnsupportedVersion;
 27
 28pub struct AcpConnection {
 29    server_name: SharedString,
 30    connection: Rc<acp::ClientSideConnection>,
 31    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
 32    auth_methods: Vec<acp::AuthMethod>,
 33    agent_capabilities: acp::AgentCapabilities,
 34    default_mode: Option<acp::SessionModeId>,
 35    root_dir: PathBuf,
 36    _io_task: Task<Result<()>>,
 37    _wait_task: Task<Result<()>>,
 38    _stderr_task: Task<Result<()>>,
 39}
 40
 41pub struct AcpSession {
 42    thread: WeakEntity<AcpThread>,
 43    suppress_abort_err: bool,
 44    session_modes: Option<Rc<RefCell<acp::SessionModeState>>>,
 45}
 46
 47pub async fn connect(
 48    server_name: SharedString,
 49    command: AgentServerCommand,
 50    root_dir: &Path,
 51    default_mode: Option<acp::SessionModeId>,
 52    is_remote: bool,
 53    cx: &mut AsyncApp,
 54) -> Result<Rc<dyn AgentConnection>> {
 55    let conn = AcpConnection::stdio(
 56        server_name,
 57        command.clone(),
 58        root_dir,
 59        default_mode,
 60        is_remote,
 61        cx,
 62    )
 63    .await?;
 64    Ok(Rc::new(conn) as _)
 65}
 66
 67const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
 68
 69impl AcpConnection {
 70    pub async fn stdio(
 71        server_name: SharedString,
 72        command: AgentServerCommand,
 73        root_dir: &Path,
 74        default_mode: Option<acp::SessionModeId>,
 75        is_remote: bool,
 76        cx: &mut AsyncApp,
 77    ) -> Result<Self> {
 78        let mut child = util::command::new_smol_command(command.path);
 79        child
 80            .args(command.args.iter().map(|arg| arg.as_str()))
 81            .envs(command.env.iter().flatten())
 82            .stdin(std::process::Stdio::piped())
 83            .stdout(std::process::Stdio::piped())
 84            .stderr(std::process::Stdio::piped())
 85            .kill_on_drop(true);
 86        if !is_remote {
 87            child.current_dir(root_dir);
 88        }
 89        let mut child = child.spawn()?;
 90
 91        let stdout = child.stdout.take().context("Failed to take stdout")?;
 92        let stdin = child.stdin.take().context("Failed to take stdin")?;
 93        let stderr = child.stderr.take().context("Failed to take stderr")?;
 94        log::trace!("Spawned (pid: {})", child.id());
 95
 96        let sessions = Rc::new(RefCell::new(HashMap::default()));
 97
 98        let client = ClientDelegate {
 99            sessions: sessions.clone(),
100            cx: cx.clone(),
101        };
102        let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
103            let foreground_executor = cx.foreground_executor().clone();
104            move |fut| {
105                foreground_executor.spawn(fut).detach();
106            }
107        });
108
109        let io_task = cx.background_spawn(io_task);
110
111        let stderr_task = cx.background_spawn(async move {
112            let mut stderr = BufReader::new(stderr);
113            let mut line = String::new();
114            while let Ok(n) = stderr.read_line(&mut line).await
115                && n > 0
116            {
117                log::warn!("agent stderr: {}", &line);
118                line.clear();
119            }
120            Ok(())
121        });
122
123        let wait_task = cx.spawn({
124            let sessions = sessions.clone();
125            async move |cx| {
126                let status = child.status().await?;
127
128                for session in sessions.borrow().values() {
129                    session
130                        .thread
131                        .update(cx, |thread, cx| {
132                            thread.emit_load_error(LoadError::Exited { status }, cx)
133                        })
134                        .ok();
135                }
136
137                anyhow::Ok(())
138            }
139        });
140
141        let connection = Rc::new(connection);
142
143        cx.update(|cx| {
144            AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
145                registry.set_active_connection(server_name.clone(), &connection, cx)
146            });
147        })?;
148
149        let response = connection
150            .initialize(acp::InitializeRequest {
151                protocol_version: acp::VERSION,
152                client_capabilities: acp::ClientCapabilities {
153                    fs: acp::FileSystemCapability {
154                        read_text_file: true,
155                        write_text_file: true,
156                    },
157                    terminal: true,
158                },
159            })
160            .await?;
161
162        if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
163            return Err(UnsupportedVersion.into());
164        }
165
166        Ok(Self {
167            auth_methods: response.auth_methods,
168            root_dir: root_dir.to_owned(),
169            connection,
170            server_name,
171            sessions,
172            agent_capabilities: response.agent_capabilities,
173            default_mode,
174            _io_task: io_task,
175            _wait_task: wait_task,
176            _stderr_task: stderr_task,
177        })
178    }
179
180    pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
181        &self.agent_capabilities.prompt_capabilities
182    }
183
184    pub fn root_dir(&self) -> &Path {
185        &self.root_dir
186    }
187}
188
189impl AgentConnection for AcpConnection {
190    fn new_thread(
191        self: Rc<Self>,
192        project: Entity<Project>,
193        cwd: &Path,
194        cx: &mut App,
195    ) -> Task<Result<Entity<AcpThread>>> {
196        let name = self.server_name.clone();
197        let conn = self.connection.clone();
198        let sessions = self.sessions.clone();
199        let default_mode = self.default_mode.clone();
200        let cwd = cwd.to_path_buf();
201        let context_server_store = project.read(cx).context_server_store().read(cx);
202        let mcp_servers = if project.read(cx).is_local() {
203            context_server_store
204                .configured_server_ids()
205                .iter()
206                .filter_map(|id| {
207                    let configuration = context_server_store.configuration_for_server(id)?;
208                    let command = configuration.command();
209                    Some(acp::McpServer::Stdio {
210                        name: id.0.to_string(),
211                        command: command.path.clone(),
212                        args: command.args.clone(),
213                        env: if let Some(env) = command.env.as_ref() {
214                            env.iter()
215                                .map(|(name, value)| acp::EnvVariable {
216                                    name: name.clone(),
217                                    value: value.clone(),
218                                })
219                                .collect()
220                        } else {
221                            vec![]
222                        },
223                    })
224                })
225                .collect()
226        } else {
227            // In SSH projects, the external agent is running on the remote
228            // machine, and currently we only run MCP servers on the local
229            // machine. So don't pass any MCP servers to the agent in that case.
230            Vec::new()
231        };
232
233        cx.spawn(async move |cx| {
234            let response = conn
235                .new_session(acp::NewSessionRequest { mcp_servers, cwd })
236                .await
237                .map_err(|err| {
238                    if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
239                        let mut error = AuthRequired::new();
240
241                        if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
242                            error = error.with_description(err.message);
243                        }
244
245                        anyhow!(error)
246                    } else {
247                        anyhow!(err)
248                    }
249                })?;
250
251            let modes = response.modes.map(|modes| Rc::new(RefCell::new(modes)));
252
253            if let Some(default_mode) = default_mode {
254                if let Some(modes) = modes.as_ref() {
255                    let mut modes_ref = modes.borrow_mut();
256                    let has_mode = modes_ref.available_modes.iter().any(|mode| mode.id == default_mode);
257
258                    if has_mode {
259                        let initial_mode_id = modes_ref.current_mode_id.clone();
260
261                        cx.spawn({
262                            let default_mode = default_mode.clone();
263                            let session_id = response.session_id.clone();
264                            let modes = modes.clone();
265                            async move |_| {
266                                let result = conn.set_session_mode(acp::SetSessionModeRequest {
267                                    session_id,
268                                    mode_id: default_mode,
269                                })
270                                .await.log_err();
271
272                                if result.is_none() {
273                                    modes.borrow_mut().current_mode_id = initial_mode_id;
274                                }
275                            }
276                        }).detach();
277
278                        modes_ref.current_mode_id = default_mode;
279                    } else {
280                        let available_modes = modes_ref
281                            .available_modes
282                            .iter()
283                            .map(|mode| format!("- `{}`: {}", mode.id, mode.name))
284                            .collect::<Vec<_>>()
285                            .join("\n");
286
287                        log::warn!(
288                            "`{default_mode}` is not valid {name} mode. Available options:\n{available_modes}",
289                        );
290                    }
291                } else {
292                    log::warn!(
293                        "`{name}` does not support modes, but `default_mode` was set in settings.",
294                    );
295                }
296            }
297
298            let session_id = response.session_id;
299            let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
300            let thread = cx.new(|cx| {
301                AcpThread::new(
302                    self.server_name.clone(),
303                    self.clone(),
304                    project,
305                    action_log,
306                    session_id.clone(),
307                    // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
308                    watch::Receiver::constant(self.agent_capabilities.prompt_capabilities),
309                    cx,
310                )
311            })?;
312
313            let session = AcpSession {
314                thread: thread.downgrade(),
315                suppress_abort_err: false,
316                session_modes: modes
317            };
318            sessions.borrow_mut().insert(session_id, session);
319
320            Ok(thread)
321        })
322    }
323
324    fn auth_methods(&self) -> &[acp::AuthMethod] {
325        &self.auth_methods
326    }
327
328    fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
329        let conn = self.connection.clone();
330        cx.foreground_executor().spawn(async move {
331            let result = conn
332                .authenticate(acp::AuthenticateRequest {
333                    method_id: method_id.clone(),
334                })
335                .await?;
336
337            Ok(result)
338        })
339    }
340
341    fn prompt(
342        &self,
343        _id: Option<acp_thread::UserMessageId>,
344        params: acp::PromptRequest,
345        cx: &mut App,
346    ) -> Task<Result<acp::PromptResponse>> {
347        let conn = self.connection.clone();
348        let sessions = self.sessions.clone();
349        let session_id = params.session_id.clone();
350        cx.foreground_executor().spawn(async move {
351            let result = conn.prompt(params).await;
352
353            let mut suppress_abort_err = false;
354
355            if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
356                suppress_abort_err = session.suppress_abort_err;
357                session.suppress_abort_err = false;
358            }
359
360            match result {
361                Ok(response) => Ok(response),
362                Err(err) => {
363                    if err.code != ErrorCode::INTERNAL_ERROR.code {
364                        anyhow::bail!(err)
365                    }
366
367                    let Some(data) = &err.data else {
368                        anyhow::bail!(err)
369                    };
370
371                    // Temporary workaround until the following PR is generally available:
372                    // https://github.com/google-gemini/gemini-cli/pull/6656
373
374                    #[derive(Deserialize)]
375                    #[serde(deny_unknown_fields)]
376                    struct ErrorDetails {
377                        details: Box<str>,
378                    }
379
380                    match serde_json::from_value(data.clone()) {
381                        Ok(ErrorDetails { details }) => {
382                            if suppress_abort_err
383                                && (details.contains("This operation was aborted")
384                                    || details.contains("The user aborted a request"))
385                            {
386                                Ok(acp::PromptResponse {
387                                    stop_reason: acp::StopReason::Cancelled,
388                                })
389                            } else {
390                                Err(anyhow!(details))
391                            }
392                        }
393                        Err(_) => Err(anyhow!(err)),
394                    }
395                }
396            }
397        })
398    }
399
400    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
401        if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
402            session.suppress_abort_err = true;
403        }
404        let conn = self.connection.clone();
405        let params = acp::CancelNotification {
406            session_id: session_id.clone(),
407        };
408        cx.foreground_executor()
409            .spawn(async move { conn.cancel(params).await })
410            .detach();
411    }
412
413    fn session_modes(
414        &self,
415        session_id: &acp::SessionId,
416        _cx: &App,
417    ) -> Option<Rc<dyn acp_thread::AgentSessionModes>> {
418        let sessions = self.sessions.clone();
419        let sessions_ref = sessions.borrow();
420        let Some(session) = sessions_ref.get(session_id) else {
421            return None;
422        };
423
424        if let Some(modes) = session.session_modes.as_ref() {
425            Some(Rc::new(AcpSessionModes {
426                connection: self.connection.clone(),
427                session_id: session_id.clone(),
428                state: modes.clone(),
429            }) as _)
430        } else {
431            None
432        }
433    }
434
435    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
436        self
437    }
438}
439
440struct AcpSessionModes {
441    session_id: acp::SessionId,
442    connection: Rc<acp::ClientSideConnection>,
443    state: Rc<RefCell<acp::SessionModeState>>,
444}
445
446impl acp_thread::AgentSessionModes for AcpSessionModes {
447    fn current_mode(&self) -> acp::SessionModeId {
448        self.state.borrow().current_mode_id.clone()
449    }
450
451    fn all_modes(&self) -> Vec<acp::SessionMode> {
452        self.state.borrow().available_modes.clone()
453    }
454
455    fn set_mode(&self, mode_id: acp::SessionModeId, cx: &mut App) -> Task<Result<()>> {
456        let connection = self.connection.clone();
457        let session_id = self.session_id.clone();
458        let old_mode_id;
459        {
460            let mut state = self.state.borrow_mut();
461            old_mode_id = state.current_mode_id.clone();
462            state.current_mode_id = mode_id.clone();
463        };
464        let state = self.state.clone();
465        cx.foreground_executor().spawn(async move {
466            let result = connection
467                .set_session_mode(acp::SetSessionModeRequest {
468                    session_id,
469                    mode_id,
470                })
471                .await;
472
473            if result.is_err() {
474                state.borrow_mut().current_mode_id = old_mode_id;
475            }
476
477            result?;
478
479            Ok(())
480        })
481    }
482}
483
484struct ClientDelegate {
485    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
486    cx: AsyncApp,
487}
488
489impl acp::Client for ClientDelegate {
490    async fn request_permission(
491        &self,
492        arguments: acp::RequestPermissionRequest,
493    ) -> Result<acp::RequestPermissionResponse, acp::Error> {
494        let respect_always_allow_setting;
495        let thread;
496        {
497            let sessions_ref = self.sessions.borrow();
498            let session = sessions_ref
499                .get(&arguments.session_id)
500                .context("Failed to get session")?;
501            respect_always_allow_setting = session.session_modes.is_none();
502            thread = session.thread.clone();
503        }
504
505        let cx = &mut self.cx.clone();
506
507        let task = thread.update(cx, |thread, cx| {
508            thread.request_tool_call_authorization(
509                arguments.tool_call,
510                arguments.options,
511                respect_always_allow_setting,
512                cx,
513            )
514        })??;
515
516        let outcome = task.await;
517
518        Ok(acp::RequestPermissionResponse { outcome })
519    }
520
521    async fn write_text_file(
522        &self,
523        arguments: acp::WriteTextFileRequest,
524    ) -> Result<(), acp::Error> {
525        let cx = &mut self.cx.clone();
526        let task = self
527            .session_thread(&arguments.session_id)?
528            .update(cx, |thread, cx| {
529                thread.write_text_file(arguments.path, arguments.content, cx)
530            })?;
531
532        task.await?;
533
534        Ok(())
535    }
536
537    async fn read_text_file(
538        &self,
539        arguments: acp::ReadTextFileRequest,
540    ) -> Result<acp::ReadTextFileResponse, acp::Error> {
541        let task = self.session_thread(&arguments.session_id)?.update(
542            &mut self.cx.clone(),
543            |thread, cx| {
544                thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
545            },
546        )?;
547
548        let content = task.await?;
549
550        Ok(acp::ReadTextFileResponse { content })
551    }
552
553    async fn session_notification(
554        &self,
555        notification: acp::SessionNotification,
556    ) -> Result<(), acp::Error> {
557        let sessions = self.sessions.borrow();
558        let session = sessions
559            .get(&notification.session_id)
560            .context("Failed to get session")?;
561
562        if let acp::SessionUpdate::CurrentModeUpdate { current_mode_id } = &notification.update {
563            if let Some(session_modes) = &session.session_modes {
564                session_modes.borrow_mut().current_mode_id = current_mode_id.clone();
565            } else {
566                log::error!(
567                    "Got a `CurrentModeUpdate` notification, but they agent didn't specify `modes` during setting setup."
568                );
569            }
570        }
571
572        session.thread.update(&mut self.cx.clone(), |thread, cx| {
573            thread.handle_session_update(notification.update, cx)
574        })??;
575
576        Ok(())
577    }
578
579    async fn create_terminal(
580        &self,
581        args: acp::CreateTerminalRequest,
582    ) -> Result<acp::CreateTerminalResponse, acp::Error> {
583        let terminal = self
584            .session_thread(&args.session_id)?
585            .update(&mut self.cx.clone(), |thread, cx| {
586                thread.create_terminal(
587                    args.command,
588                    args.args,
589                    args.env,
590                    args.cwd,
591                    args.output_byte_limit,
592                    cx,
593                )
594            })?
595            .await?;
596        Ok(
597            terminal.read_with(&self.cx, |terminal, _| acp::CreateTerminalResponse {
598                terminal_id: terminal.id().clone(),
599            })?,
600        )
601    }
602
603    async fn kill_terminal(&self, args: acp::KillTerminalRequest) -> Result<(), acp::Error> {
604        self.session_thread(&args.session_id)?
605            .update(&mut self.cx.clone(), |thread, cx| {
606                thread.kill_terminal(args.terminal_id, cx)
607            })??;
608
609        Ok(())
610    }
611
612    async fn release_terminal(&self, args: acp::ReleaseTerminalRequest) -> Result<(), acp::Error> {
613        self.session_thread(&args.session_id)?
614            .update(&mut self.cx.clone(), |thread, cx| {
615                thread.release_terminal(args.terminal_id, cx)
616            })??;
617
618        Ok(())
619    }
620
621    async fn terminal_output(
622        &self,
623        args: acp::TerminalOutputRequest,
624    ) -> Result<acp::TerminalOutputResponse, acp::Error> {
625        self.session_thread(&args.session_id)?
626            .read_with(&mut self.cx.clone(), |thread, cx| {
627                let out = thread
628                    .terminal(args.terminal_id)?
629                    .read(cx)
630                    .current_output(cx);
631
632                Ok(out)
633            })?
634    }
635
636    async fn wait_for_terminal_exit(
637        &self,
638        args: acp::WaitForTerminalExitRequest,
639    ) -> Result<acp::WaitForTerminalExitResponse, acp::Error> {
640        let exit_status = self
641            .session_thread(&args.session_id)?
642            .update(&mut self.cx.clone(), |thread, cx| {
643                anyhow::Ok(thread.terminal(args.terminal_id)?.read(cx).wait_for_exit())
644            })??
645            .await;
646
647        Ok(acp::WaitForTerminalExitResponse { exit_status })
648    }
649}
650
651impl ClientDelegate {
652    fn session_thread(&self, session_id: &acp::SessionId) -> Result<WeakEntity<AcpThread>> {
653        let sessions = self.sessions.borrow();
654        sessions
655            .get(session_id)
656            .context("Failed to get session")
657            .map(|session| session.thread.clone())
658    }
659}