acp.rs

  1use crate::AgentServerCommand;
  2use acp_thread::AgentConnection;
  3use acp_tools::AcpConnectionRegistry;
  4use action_log::ActionLog;
  5use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
  6use anyhow::anyhow;
  7use collections::HashMap;
  8use futures::AsyncBufReadExt as _;
  9use futures::io::BufReader;
 10use project::Project;
 11use serde::Deserialize;
 12use util::ResultExt as _;
 13
 14use std::{any::Any, cell::RefCell};
 15use std::{path::Path, rc::Rc};
 16use thiserror::Error;
 17
 18use anyhow::{Context as _, Result};
 19use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
 20
 21use acp_thread::{AcpThread, AuthRequired, LoadError};
 22
 23#[derive(Debug, Error)]
 24#[error("Unsupported version")]
 25pub struct UnsupportedVersion;
 26
 27pub struct AcpConnection {
 28    server_name: SharedString,
 29    connection: Rc<acp::ClientSideConnection>,
 30    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
 31    auth_methods: Vec<acp::AuthMethod>,
 32    agent_capabilities: acp::AgentCapabilities,
 33    default_mode: Option<acp::SessionModeId>,
 34    root_dir: PathBuf,
 35    // NB: Don't move this into the wait_task, since we need to ensure the process is
 36    // killed on drop (setting kill_on_drop on the command seems to not always work).
 37    child: smol::process::Child,
 38    _io_task: Task<Result<()>>,
 39    _wait_task: Task<Result<()>>,
 40    _stderr_task: Task<Result<()>>,
 41}
 42
 43pub struct AcpSession {
 44    thread: WeakEntity<AcpThread>,
 45    suppress_abort_err: bool,
 46}
 47
 48pub async fn connect(
 49    server_name: SharedString,
 50    command: AgentServerCommand,
 51    root_dir: &Path,
 52    cx: &mut AsyncApp,
 53) -> Result<Rc<dyn AgentConnection>> {
 54    let conn = AcpConnection::stdio(server_name, command.clone(), root_dir, cx).await?;
 55    Ok(Rc::new(conn) as _)
 56}
 57
 58const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
 59
 60impl AcpConnection {
 61    pub async fn stdio(
 62        server_name: SharedString,
 63        command: AgentServerCommand,
 64        root_dir: &Path,
 65        cx: &mut AsyncApp,
 66    ) -> Result<Self> {
 67        let mut child = util::command::new_smol_command(command.path)
 68            .args(command.args.iter().map(|arg| arg.as_str()))
 69            .envs(command.env.iter().flatten())
 70            .current_dir(root_dir)
 71            .stdin(std::process::Stdio::piped())
 72            .stdout(std::process::Stdio::piped())
 73            .stderr(std::process::Stdio::piped());
 74        if !is_remote {
 75            child.current_dir(root_dir);
 76        }
 77        let mut child = child.spawn()?;
 78
 79        let stdout = child.stdout.take().context("Failed to take stdout")?;
 80        let stdin = child.stdin.take().context("Failed to take stdin")?;
 81        let stderr = child.stderr.take().context("Failed to take stderr")?;
 82        log::trace!("Spawned (pid: {})", child.id());
 83
 84        let sessions = Rc::new(RefCell::new(HashMap::default()));
 85
 86        let client = ClientDelegate {
 87            sessions: sessions.clone(),
 88            cx: cx.clone(),
 89        };
 90        let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
 91            let foreground_executor = cx.foreground_executor().clone();
 92            move |fut| {
 93                foreground_executor.spawn(fut).detach();
 94            }
 95        });
 96
 97        let io_task = cx.background_spawn(io_task);
 98
 99        let stderr_task = cx.background_spawn(async move {
100            let mut stderr = BufReader::new(stderr);
101            let mut line = String::new();
102            while let Ok(n) = stderr.read_line(&mut line).await
103                && n > 0
104            {
105                log::warn!("agent stderr: {}", &line);
106                line.clear();
107            }
108            Ok(())
109        });
110
111        let wait_task = cx.spawn({
112            let sessions = sessions.clone();
113            let status_fut = child.status();
114            async move |cx| {
115                let status = status_fut.await?;
116
117                for session in sessions.borrow().values() {
118                    session
119                        .thread
120                        .update(cx, |thread, cx| {
121                            thread.emit_load_error(LoadError::Exited { status }, cx)
122                        })
123                        .ok();
124                }
125
126                anyhow::Ok(())
127            }
128        });
129
130        let connection = Rc::new(connection);
131
132        cx.update(|cx| {
133            AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
134                registry.set_active_connection(server_name.clone(), &connection, cx)
135            });
136        })?;
137
138        let response = connection
139            .initialize(acp::InitializeRequest {
140                protocol_version: acp::VERSION,
141                client_capabilities: acp::ClientCapabilities {
142                    fs: acp::FileSystemCapability {
143                        read_text_file: true,
144                        write_text_file: true,
145                    },
146                    terminal: true,
147                },
148            })
149            .await?;
150
151        if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
152            return Err(UnsupportedVersion.into());
153        }
154
155        Ok(Self {
156            auth_methods: response.auth_methods,
157            connection,
158            server_name,
159            sessions,
160            agent_capabilities: response.agent_capabilities,
161            _io_task: io_task,
162            _wait_task: wait_task,
163            _stderr_task: stderr_task,
164            child,
165        })
166    }
167
168    pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
169        &self.agent_capabilities.prompt_capabilities
170    }
171}
172
173impl Drop for AcpConnection {
174    fn drop(&mut self) {
175        // See the comment on the child field.
176        self.child.kill().log_err();
177    }
178}
179
180impl AgentConnection for AcpConnection {
181    fn new_thread(
182        self: Rc<Self>,
183        project: Entity<Project>,
184        cwd: &Path,
185        cx: &mut App,
186    ) -> Task<Result<Entity<AcpThread>>> {
187        let conn = self.connection.clone();
188        let sessions = self.sessions.clone();
189        let cwd = cwd.to_path_buf();
190        let context_server_store = project.read(cx).context_server_store().read(cx);
191        let mcp_servers = context_server_store
192            .configured_server_ids()
193            .iter()
194            .filter_map(|id| {
195                let configuration = context_server_store.configuration_for_server(id)?;
196                let command = configuration.command();
197                Some(acp::McpServer {
198                    name: id.0.to_string(),
199                    command: command.path.clone(),
200                    args: command.args.clone(),
201                    env: if let Some(env) = command.env.as_ref() {
202                        env.iter()
203                            .map(|(name, value)| acp::EnvVariable {
204                                name: name.clone(),
205                                value: value.clone(),
206                            })
207                            .collect()
208                    } else {
209                        vec![]
210                    },
211                })
212            })
213            .collect();
214
215        cx.spawn(async move |cx| {
216            let response = conn
217                .new_session(acp::NewSessionRequest { mcp_servers, cwd })
218                .await
219                .map_err(|err| {
220                    if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
221                        let mut error = AuthRequired::new();
222
223                        if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
224                            error = error.with_description(err.message);
225                        }
226
227                        anyhow!(error)
228                    } else {
229                        anyhow!(err)
230                    }
231                })?;
232
233            let session_id = response.session_id;
234            let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
235            let thread = cx.new(|cx| {
236                AcpThread::new(
237                    self.server_name.clone(),
238                    self.clone(),
239                    project,
240                    action_log,
241                    session_id.clone(),
242                    // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
243                    watch::Receiver::constant(self.agent_capabilities.prompt_capabilities),
244                    cx,
245                )
246            })?;
247
248            let session = AcpSession {
249                thread: thread.downgrade(),
250                suppress_abort_err: false,
251            };
252            sessions.borrow_mut().insert(session_id, session);
253
254            Ok(thread)
255        })
256    }
257
258    fn auth_methods(&self) -> &[acp::AuthMethod] {
259        &self.auth_methods
260    }
261
262    fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
263        let conn = self.connection.clone();
264        cx.foreground_executor().spawn(async move {
265            let result = conn
266                .authenticate(acp::AuthenticateRequest {
267                    method_id: method_id.clone(),
268                })
269                .await?;
270
271            Ok(result)
272        })
273    }
274
275    fn prompt(
276        &self,
277        _id: Option<acp_thread::UserMessageId>,
278        params: acp::PromptRequest,
279        cx: &mut App,
280    ) -> Task<Result<acp::PromptResponse>> {
281        let conn = self.connection.clone();
282        let sessions = self.sessions.clone();
283        let session_id = params.session_id.clone();
284        cx.foreground_executor().spawn(async move {
285            let result = conn.prompt(params).await;
286
287            let mut suppress_abort_err = false;
288
289            if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
290                suppress_abort_err = session.suppress_abort_err;
291                session.suppress_abort_err = false;
292            }
293
294            match result {
295                Ok(response) => Ok(response),
296                Err(err) => {
297                    if err.code != ErrorCode::INTERNAL_ERROR.code {
298                        anyhow::bail!(err)
299                    }
300
301                    let Some(data) = &err.data else {
302                        anyhow::bail!(err)
303                    };
304
305                    // Temporary workaround until the following PR is generally available:
306                    // https://github.com/google-gemini/gemini-cli/pull/6656
307
308                    #[derive(Deserialize)]
309                    #[serde(deny_unknown_fields)]
310                    struct ErrorDetails {
311                        details: Box<str>,
312                    }
313
314                    match serde_json::from_value(data.clone()) {
315                        Ok(ErrorDetails { details }) => {
316                            if suppress_abort_err
317                                && (details.contains("This operation was aborted")
318                                    || details.contains("The user aborted a request"))
319                            {
320                                Ok(acp::PromptResponse {
321                                    stop_reason: acp::StopReason::Cancelled,
322                                })
323                            } else {
324                                Err(anyhow!(details))
325                            }
326                        }
327                        Err(_) => Err(anyhow!(err)),
328                    }
329                }
330            }
331        })
332    }
333
334    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
335        if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
336            session.suppress_abort_err = true;
337        }
338        let conn = self.connection.clone();
339        let params = acp::CancelNotification {
340            session_id: session_id.clone(),
341        };
342        cx.foreground_executor()
343            .spawn(async move { conn.cancel(params).await })
344            .detach();
345    }
346
347    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
348        self
349    }
350}
351
352struct ClientDelegate {
353    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
354    cx: AsyncApp,
355}
356
357impl acp::Client for ClientDelegate {
358    async fn request_permission(
359        &self,
360        arguments: acp::RequestPermissionRequest,
361    ) -> Result<acp::RequestPermissionResponse, acp::Error> {
362        let cx = &mut self.cx.clone();
363
364        let task = self
365            .session_thread(&arguments.session_id)?
366            .update(cx, |thread, cx| {
367                thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
368            })??;
369
370        let outcome = task.await;
371
372        Ok(acp::RequestPermissionResponse { outcome })
373    }
374
375    async fn write_text_file(
376        &self,
377        arguments: acp::WriteTextFileRequest,
378    ) -> Result<(), acp::Error> {
379        let cx = &mut self.cx.clone();
380        let task = self
381            .session_thread(&arguments.session_id)?
382            .update(cx, |thread, cx| {
383                thread.write_text_file(arguments.path, arguments.content, cx)
384            })?;
385
386        task.await?;
387
388        Ok(())
389    }
390
391    async fn read_text_file(
392        &self,
393        arguments: acp::ReadTextFileRequest,
394    ) -> Result<acp::ReadTextFileResponse, acp::Error> {
395        let task = self.session_thread(&arguments.session_id)?.update(
396            &mut self.cx.clone(),
397            |thread, cx| {
398                thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
399            },
400        )?;
401
402        let content = task.await?;
403
404        Ok(acp::ReadTextFileResponse { content })
405    }
406
407    async fn session_notification(
408        &self,
409        notification: acp::SessionNotification,
410    ) -> Result<(), acp::Error> {
411        self.session_thread(&notification.session_id)?
412            .update(&mut self.cx.clone(), |thread, cx| {
413                thread.handle_session_update(notification.update, cx)
414            })??;
415
416        Ok(())
417    }
418
419    async fn create_terminal(
420        &self,
421        args: acp::CreateTerminalRequest,
422    ) -> Result<acp::CreateTerminalResponse, acp::Error> {
423        let terminal = self
424            .session_thread(&args.session_id)?
425            .update(&mut self.cx.clone(), |thread, cx| {
426                thread.create_terminal(
427                    args.command,
428                    args.args,
429                    args.env,
430                    args.cwd,
431                    args.output_byte_limit,
432                    cx,
433                )
434            })?
435            .await?;
436        Ok(
437            terminal.read_with(&self.cx, |terminal, _| acp::CreateTerminalResponse {
438                terminal_id: terminal.id().clone(),
439            })?,
440        )
441    }
442
443    async fn kill_terminal(&self, args: acp::KillTerminalRequest) -> Result<(), acp::Error> {
444        self.session_thread(&args.session_id)?
445            .update(&mut self.cx.clone(), |thread, cx| {
446                thread.kill_terminal(args.terminal_id, cx)
447            })??;
448
449        Ok(())
450    }
451
452    async fn release_terminal(&self, args: acp::ReleaseTerminalRequest) -> Result<(), acp::Error> {
453        self.session_thread(&args.session_id)?
454            .update(&mut self.cx.clone(), |thread, cx| {
455                thread.release_terminal(args.terminal_id, cx)
456            })??;
457
458        Ok(())
459    }
460
461    async fn terminal_output(
462        &self,
463        args: acp::TerminalOutputRequest,
464    ) -> Result<acp::TerminalOutputResponse, acp::Error> {
465        self.session_thread(&args.session_id)?
466            .read_with(&mut self.cx.clone(), |thread, cx| {
467                let out = thread
468                    .terminal(args.terminal_id)?
469                    .read(cx)
470                    .current_output(cx);
471
472                Ok(out)
473            })?
474    }
475
476    async fn wait_for_terminal_exit(
477        &self,
478        args: acp::WaitForTerminalExitRequest,
479    ) -> Result<acp::WaitForTerminalExitResponse, acp::Error> {
480        let exit_status = self
481            .session_thread(&args.session_id)?
482            .update(&mut self.cx.clone(), |thread, cx| {
483                anyhow::Ok(thread.terminal(args.terminal_id)?.read(cx).wait_for_exit())
484            })??
485            .await;
486
487        Ok(acp::WaitForTerminalExitResponse { exit_status })
488    }
489}
490
491impl ClientDelegate {
492    fn session_thread(&self, session_id: &acp::SessionId) -> Result<WeakEntity<AcpThread>> {
493        let sessions = self.sessions.borrow();
494        sessions
495            .get(session_id)
496            .context("Failed to get session")
497            .map(|session| session.thread.clone())
498    }
499}