acp.rs

  1use crate::AgentServerCommand;
  2use acp_thread::AgentConnection;
  3use acp_tools::AcpConnectionRegistry;
  4use action_log::ActionLog;
  5use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
  6use anyhow::anyhow;
  7use collections::HashMap;
  8use futures::AsyncBufReadExt as _;
  9use futures::io::BufReader;
 10use project::Project;
 11use serde::Deserialize;
 12use util::ResultExt as _;
 13
 14use std::{any::Any, cell::RefCell};
 15use std::{path::Path, rc::Rc};
 16use thiserror::Error;
 17
 18use anyhow::{Context as _, Result};
 19use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
 20
 21use acp_thread::{AcpThread, AuthRequired, LoadError};
 22
 23#[derive(Debug, Error)]
 24#[error("Unsupported version")]
 25pub struct UnsupportedVersion;
 26
/// A live connection to an external agent server process speaking the Agent Client Protocol
/// over the process's stdin/stdout.
pub struct AcpConnection {
    /// Name the connection was registered under in the `AcpConnectionRegistry`.
    server_name: SharedString,
    /// The ACP client-side connection used to send requests to the agent.
    connection: Rc<acp::ClientSideConnection>,
    /// Open sessions keyed by the agent-assigned session id; shared with the `ClientDelegate`
    /// that routes incoming agent requests to the matching thread.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods advertised by the agent in its `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Capabilities advertised by the agent in its `initialize` response.
    agent_capabilities: acp::AgentCapabilities,
    // NB: Don't move this into the wait_task, since we need to ensure the process is
    // killed on drop (setting kill_on_drop on the command seems to not always work).
    child: smol::process::Child,
    /// Drives the connection's I/O; held only to keep it running.
    _io_task: Task<Result<()>>,
    /// Waits for the child to exit and reports it to open sessions; held only to keep it alive.
    _wait_task: Task<Result<()>>,
    /// Forwards the agent's stderr to the log; held only to keep it running.
    _stderr_task: Task<Result<()>>,
}
 40
 41pub struct AcpSession {
 42    thread: WeakEntity<AcpThread>,
 43    suppress_abort_err: bool,
 44}
 45
 46pub async fn connect(
 47    server_name: SharedString,
 48    command: AgentServerCommand,
 49    root_dir: &Path,
 50    cx: &mut AsyncApp,
 51) -> Result<Rc<dyn AgentConnection>> {
 52    let conn = AcpConnection::stdio(server_name, command.clone(), root_dir, cx).await?;
 53    Ok(Rc::new(conn) as _)
 54}
 55
/// Oldest ACP protocol version accepted from an agent's `initialize` response; agents reporting
/// an older version are rejected with [`UnsupportedVersion`].
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
 57
 58impl AcpConnection {
 59    pub async fn stdio(
 60        server_name: SharedString,
 61        command: AgentServerCommand,
 62        root_dir: &Path,
 63        cx: &mut AsyncApp,
 64    ) -> Result<Self> {
 65        let mut child = util::command::new_smol_command(command.path)
 66            .args(command.args.iter().map(|arg| arg.as_str()))
 67            .envs(command.env.iter().flatten())
 68            .current_dir(root_dir)
 69            .stdin(std::process::Stdio::piped())
 70            .stdout(std::process::Stdio::piped())
 71            .stderr(std::process::Stdio::piped())
 72            .spawn()?;
 73
 74        let stdout = child.stdout.take().context("Failed to take stdout")?;
 75        let stdin = child.stdin.take().context("Failed to take stdin")?;
 76        let stderr = child.stderr.take().context("Failed to take stderr")?;
 77        log::trace!("Spawned (pid: {})", child.id());
 78
 79        let sessions = Rc::new(RefCell::new(HashMap::default()));
 80
 81        let client = ClientDelegate {
 82            sessions: sessions.clone(),
 83            cx: cx.clone(),
 84        };
 85        let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
 86            let foreground_executor = cx.foreground_executor().clone();
 87            move |fut| {
 88                foreground_executor.spawn(fut).detach();
 89            }
 90        });
 91
 92        let io_task = cx.background_spawn(io_task);
 93
 94        let stderr_task = cx.background_spawn(async move {
 95            let mut stderr = BufReader::new(stderr);
 96            let mut line = String::new();
 97            while let Ok(n) = stderr.read_line(&mut line).await
 98                && n > 0
 99            {
100                log::warn!("agent stderr: {}", &line);
101                line.clear();
102            }
103            Ok(())
104        });
105
106        let wait_task = cx.spawn({
107            let sessions = sessions.clone();
108            let status_fut = child.status();
109            async move |cx| {
110                let status = status_fut.await?;
111
112                for session in sessions.borrow().values() {
113                    session
114                        .thread
115                        .update(cx, |thread, cx| {
116                            thread.emit_load_error(LoadError::Exited { status }, cx)
117                        })
118                        .ok();
119                }
120
121                anyhow::Ok(())
122            }
123        });
124
125        let connection = Rc::new(connection);
126
127        cx.update(|cx| {
128            AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
129                registry.set_active_connection(server_name.clone(), &connection, cx)
130            });
131        })?;
132
133        let response = connection
134            .initialize(acp::InitializeRequest {
135                protocol_version: acp::VERSION,
136                client_capabilities: acp::ClientCapabilities {
137                    fs: acp::FileSystemCapability {
138                        read_text_file: true,
139                        write_text_file: true,
140                    },
141                    terminal: true,
142                },
143            })
144            .await?;
145
146        if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
147            return Err(UnsupportedVersion.into());
148        }
149
150        Ok(Self {
151            auth_methods: response.auth_methods,
152            connection,
153            server_name,
154            sessions,
155            agent_capabilities: response.agent_capabilities,
156            _io_task: io_task,
157            _wait_task: wait_task,
158            _stderr_task: stderr_task,
159            child,
160        })
161    }
162
163    pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
164        &self.agent_capabilities.prompt_capabilities
165    }
166}
167
168impl Drop for AcpConnection {
169    fn drop(&mut self) {
170        // See the comment on the child field.
171        self.child.kill().log_err();
172    }
173}
174
175impl AgentConnection for AcpConnection {
176    fn new_thread(
177        self: Rc<Self>,
178        project: Entity<Project>,
179        cwd: &Path,
180        cx: &mut App,
181    ) -> Task<Result<Entity<AcpThread>>> {
182        let conn = self.connection.clone();
183        let sessions = self.sessions.clone();
184        let cwd = cwd.to_path_buf();
185        let context_server_store = project.read(cx).context_server_store().read(cx);
186        let mcp_servers = context_server_store
187            .configured_server_ids()
188            .iter()
189            .filter_map(|id| {
190                let configuration = context_server_store.configuration_for_server(id)?;
191                let command = configuration.command();
192                Some(acp::McpServer {
193                    name: id.0.to_string(),
194                    command: command.path.clone(),
195                    args: command.args.clone(),
196                    env: if let Some(env) = command.env.as_ref() {
197                        env.iter()
198                            .map(|(name, value)| acp::EnvVariable {
199                                name: name.clone(),
200                                value: value.clone(),
201                            })
202                            .collect()
203                    } else {
204                        vec![]
205                    },
206                })
207            })
208            .collect();
209
210        cx.spawn(async move |cx| {
211            let response = conn
212                .new_session(acp::NewSessionRequest { mcp_servers, cwd })
213                .await
214                .map_err(|err| {
215                    if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
216                        let mut error = AuthRequired::new();
217
218                        if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
219                            error = error.with_description(err.message);
220                        }
221
222                        anyhow!(error)
223                    } else {
224                        anyhow!(err)
225                    }
226                })?;
227
228            let session_id = response.session_id;
229            let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
230            let thread = cx.new(|cx| {
231                AcpThread::new(
232                    self.server_name.clone(),
233                    self.clone(),
234                    project,
235                    action_log,
236                    session_id.clone(),
237                    // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
238                    watch::Receiver::constant(self.agent_capabilities.prompt_capabilities),
239                    cx,
240                )
241            })?;
242
243            let session = AcpSession {
244                thread: thread.downgrade(),
245                suppress_abort_err: false,
246            };
247            sessions.borrow_mut().insert(session_id, session);
248
249            Ok(thread)
250        })
251    }
252
253    fn auth_methods(&self) -> &[acp::AuthMethod] {
254        &self.auth_methods
255    }
256
257    fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
258        let conn = self.connection.clone();
259        cx.foreground_executor().spawn(async move {
260            let result = conn
261                .authenticate(acp::AuthenticateRequest {
262                    method_id: method_id.clone(),
263                })
264                .await?;
265
266            Ok(result)
267        })
268    }
269
270    fn prompt(
271        &self,
272        _id: Option<acp_thread::UserMessageId>,
273        params: acp::PromptRequest,
274        cx: &mut App,
275    ) -> Task<Result<acp::PromptResponse>> {
276        let conn = self.connection.clone();
277        let sessions = self.sessions.clone();
278        let session_id = params.session_id.clone();
279        cx.foreground_executor().spawn(async move {
280            let result = conn.prompt(params).await;
281
282            let mut suppress_abort_err = false;
283
284            if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
285                suppress_abort_err = session.suppress_abort_err;
286                session.suppress_abort_err = false;
287            }
288
289            match result {
290                Ok(response) => Ok(response),
291                Err(err) => {
292                    if err.code != ErrorCode::INTERNAL_ERROR.code {
293                        anyhow::bail!(err)
294                    }
295
296                    let Some(data) = &err.data else {
297                        anyhow::bail!(err)
298                    };
299
300                    // Temporary workaround until the following PR is generally available:
301                    // https://github.com/google-gemini/gemini-cli/pull/6656
302
303                    #[derive(Deserialize)]
304                    #[serde(deny_unknown_fields)]
305                    struct ErrorDetails {
306                        details: Box<str>,
307                    }
308
309                    match serde_json::from_value(data.clone()) {
310                        Ok(ErrorDetails { details }) => {
311                            if suppress_abort_err
312                                && (details.contains("This operation was aborted")
313                                    || details.contains("The user aborted a request"))
314                            {
315                                Ok(acp::PromptResponse {
316                                    stop_reason: acp::StopReason::Cancelled,
317                                })
318                            } else {
319                                Err(anyhow!(details))
320                            }
321                        }
322                        Err(_) => Err(anyhow!(err)),
323                    }
324                }
325            }
326        })
327    }
328
329    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
330        if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
331            session.suppress_abort_err = true;
332        }
333        let conn = self.connection.clone();
334        let params = acp::CancelNotification {
335            session_id: session_id.clone(),
336        };
337        cx.foreground_executor()
338            .spawn(async move { conn.cancel(params).await })
339            .detach();
340    }
341
342    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
343        self
344    }
345}
346
/// Implements the client side of ACP: handles requests and notifications sent by the agent,
/// routing each to the `AcpThread` registered for its session id.
struct ClientDelegate {
    /// Session map shared with the owning `AcpConnection`, which inserts into it in `new_thread`.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// App handle that is cloned for each thread update.
    cx: AsyncApp,
}
351
352impl acp::Client for ClientDelegate {
353    async fn request_permission(
354        &self,
355        arguments: acp::RequestPermissionRequest,
356    ) -> Result<acp::RequestPermissionResponse, acp::Error> {
357        let cx = &mut self.cx.clone();
358
359        let task = self
360            .session_thread(&arguments.session_id)?
361            .update(cx, |thread, cx| {
362                thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
363            })??;
364
365        let outcome = task.await;
366
367        Ok(acp::RequestPermissionResponse { outcome })
368    }
369
370    async fn write_text_file(
371        &self,
372        arguments: acp::WriteTextFileRequest,
373    ) -> Result<(), acp::Error> {
374        let cx = &mut self.cx.clone();
375        let task = self
376            .session_thread(&arguments.session_id)?
377            .update(cx, |thread, cx| {
378                thread.write_text_file(arguments.path, arguments.content, cx)
379            })?;
380
381        task.await?;
382
383        Ok(())
384    }
385
386    async fn read_text_file(
387        &self,
388        arguments: acp::ReadTextFileRequest,
389    ) -> Result<acp::ReadTextFileResponse, acp::Error> {
390        let task = self.session_thread(&arguments.session_id)?.update(
391            &mut self.cx.clone(),
392            |thread, cx| {
393                thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
394            },
395        )?;
396
397        let content = task.await?;
398
399        Ok(acp::ReadTextFileResponse { content })
400    }
401
402    async fn session_notification(
403        &self,
404        notification: acp::SessionNotification,
405    ) -> Result<(), acp::Error> {
406        self.session_thread(&notification.session_id)?
407            .update(&mut self.cx.clone(), |thread, cx| {
408                thread.handle_session_update(notification.update, cx)
409            })??;
410
411        Ok(())
412    }
413
414    async fn create_terminal(
415        &self,
416        args: acp::CreateTerminalRequest,
417    ) -> Result<acp::CreateTerminalResponse, acp::Error> {
418        let terminal = self
419            .session_thread(&args.session_id)?
420            .update(&mut self.cx.clone(), |thread, cx| {
421                thread.create_terminal(
422                    args.command,
423                    args.args,
424                    args.env,
425                    args.cwd,
426                    args.output_byte_limit,
427                    cx,
428                )
429            })?
430            .await?;
431        Ok(
432            terminal.read_with(&self.cx, |terminal, _| acp::CreateTerminalResponse {
433                terminal_id: terminal.id().clone(),
434            })?,
435        )
436    }
437
438    async fn kill_terminal(&self, args: acp::KillTerminalRequest) -> Result<(), acp::Error> {
439        self.session_thread(&args.session_id)?
440            .update(&mut self.cx.clone(), |thread, cx| {
441                thread.kill_terminal(args.terminal_id, cx)
442            })??;
443
444        Ok(())
445    }
446
447    async fn release_terminal(&self, args: acp::ReleaseTerminalRequest) -> Result<(), acp::Error> {
448        self.session_thread(&args.session_id)?
449            .update(&mut self.cx.clone(), |thread, cx| {
450                thread.release_terminal(args.terminal_id, cx)
451            })??;
452
453        Ok(())
454    }
455
456    async fn terminal_output(
457        &self,
458        args: acp::TerminalOutputRequest,
459    ) -> Result<acp::TerminalOutputResponse, acp::Error> {
460        self.session_thread(&args.session_id)?
461            .read_with(&mut self.cx.clone(), |thread, cx| {
462                let out = thread
463                    .terminal(args.terminal_id)?
464                    .read(cx)
465                    .current_output(cx);
466
467                Ok(out)
468            })?
469    }
470
471    async fn wait_for_terminal_exit(
472        &self,
473        args: acp::WaitForTerminalExitRequest,
474    ) -> Result<acp::WaitForTerminalExitResponse, acp::Error> {
475        let exit_status = self
476            .session_thread(&args.session_id)?
477            .update(&mut self.cx.clone(), |thread, cx| {
478                anyhow::Ok(thread.terminal(args.terminal_id)?.read(cx).wait_for_exit())
479            })??
480            .await;
481
482        Ok(acp::WaitForTerminalExitResponse { exit_status })
483    }
484}
485
486impl ClientDelegate {
487    fn session_thread(&self, session_id: &acp::SessionId) -> Result<WeakEntity<AcpThread>> {
488        let sessions = self.sessions.borrow();
489        sessions
490            .get(session_id)
491            .context("Failed to get session")
492            .map(|session| session.thread.clone())
493    }
494}