acp_connection.rs

  1use agent_client_protocol::{self as acp, Agent as _};
  2use collections::HashMap;
  3use futures::channel::oneshot;
  4use project::Project;
  5use std::cell::RefCell;
  6use std::path::Path;
  7use std::rc::Rc;
  8
  9use anyhow::{Context as _, Result};
 10use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
 11
 12use crate::AgentServerCommand;
 13use acp_thread::{AcpThread, AgentConnection, AuthRequired};
 14
 15pub struct AcpConnection {
 16    server_name: &'static str,
 17    connection: Rc<acp::ClientSideConnection>,
 18    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
 19    auth_methods: Vec<acp::AuthMethod>,
 20    _io_task: Task<Result<()>>,
 21}
 22
 23pub struct AcpSession {
 24    thread: WeakEntity<AcpThread>,
 25}
 26
 27impl AcpConnection {
 28    pub async fn stdio(
 29        server_name: &'static str,
 30        command: AgentServerCommand,
 31        root_dir: &Path,
 32        cx: &mut AsyncApp,
 33    ) -> Result<Self> {
 34        let mut child = util::command::new_smol_command(&command.path)
 35            .args(command.args.iter().map(|arg| arg.as_str()))
 36            .envs(command.env.iter().flatten())
 37            .current_dir(root_dir)
 38            .stdin(std::process::Stdio::piped())
 39            .stdout(std::process::Stdio::piped())
 40            .stderr(std::process::Stdio::inherit())
 41            .kill_on_drop(true)
 42            .spawn()?;
 43
 44        let stdout = child.stdout.take().expect("Failed to take stdout");
 45        let stdin = child.stdin.take().expect("Failed to take stdin");
 46
 47        let sessions = Rc::new(RefCell::new(HashMap::default()));
 48
 49        let client = ClientDelegate {
 50            sessions: sessions.clone(),
 51            cx: cx.clone(),
 52        };
 53        let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
 54            let foreground_executor = cx.foreground_executor().clone();
 55            move |fut| {
 56                foreground_executor.spawn(fut).detach();
 57            }
 58        });
 59
 60        let io_task = cx.background_spawn(io_task);
 61
 62        let response = connection
 63            .initialize(acp::InitializeRequest {
 64                protocol_version: acp::VERSION,
 65                client_capabilities: acp::ClientCapabilities {
 66                    fs: acp::FileSystemCapability {
 67                        read_text_file: true,
 68                        write_text_file: true,
 69                    },
 70                },
 71            })
 72            .await?;
 73
 74        // todo! check version
 75
 76        Ok(Self {
 77            auth_methods: response.auth_methods,
 78            connection: connection.into(),
 79            server_name,
 80            sessions,
 81            _io_task: io_task,
 82        })
 83    }
 84}
 85
 86impl AgentConnection for AcpConnection {
 87    fn new_thread(
 88        self: Rc<Self>,
 89        project: Entity<Project>,
 90        cwd: &Path,
 91        cx: &mut AsyncApp,
 92    ) -> Task<Result<Entity<AcpThread>>> {
 93        let conn = self.connection.clone();
 94        let sessions = self.sessions.clone();
 95        let cwd = cwd.to_path_buf();
 96        cx.spawn(async move |cx| {
 97            let response = conn
 98                .new_session(acp::NewSessionRequest {
 99                    // todo! Zed MCP server?
100                    mcp_servers: vec![],
101                    cwd,
102                })
103                .await?;
104
105            let Some(session_id) = response.session_id else {
106                anyhow::bail!(AuthRequired);
107            };
108
109            let thread = cx.new(|cx| {
110                AcpThread::new(
111                    self.server_name,
112                    self.clone(),
113                    project,
114                    session_id.clone(),
115                    cx,
116                )
117            })?;
118
119            let session = AcpSession {
120                thread: thread.downgrade(),
121            };
122            sessions.borrow_mut().insert(session_id, session);
123
124            Ok(thread)
125        })
126    }
127
128    fn auth_methods(&self) -> &[acp::AuthMethod] {
129        &self.auth_methods
130    }
131
132    fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
133        let conn = self.connection.clone();
134        cx.foreground_executor().spawn(async move {
135            let result = conn
136                .authenticate(acp::AuthenticateRequest {
137                    method_id: method_id.clone(),
138                })
139                .await?;
140
141            Ok(result)
142        })
143    }
144
145    fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
146        let conn = self.connection.clone();
147        cx.foreground_executor()
148            .spawn(async move { Ok(conn.prompt(params).await?) })
149    }
150
151    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
152        let conn = self.connection.clone();
153        let params = acp::CancelledNotification {
154            session_id: session_id.clone(),
155        };
156        cx.foreground_executor()
157            .spawn(async move { conn.cancelled(params).await })
158            .detach();
159    }
160}
161
162struct ClientDelegate {
163    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
164    cx: AsyncApp,
165}
166
167impl acp::Client for ClientDelegate {
168    async fn request_permission(
169        &self,
170        arguments: acp::RequestPermissionRequest,
171    ) -> Result<acp::RequestPermissionResponse, acp::Error> {
172        let cx = &mut self.cx.clone();
173        let result = self
174            .sessions
175            .borrow()
176            .get(&arguments.session_id)
177            .context("Failed to get session")?
178            .thread
179            .update(cx, |thread, cx| {
180                thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx)
181            })?
182            .await;
183
184        let outcome = match result {
185            Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
186            Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
187        };
188
189        Ok(acp::RequestPermissionResponse { outcome })
190    }
191
192    async fn write_text_file(
193        &self,
194        arguments: acp::WriteTextFileRequest,
195    ) -> Result<(), acp::Error> {
196        let cx = &mut self.cx.clone();
197        self.sessions
198            .borrow()
199            .get(&arguments.session_id)
200            .context("Failed to get session")?
201            .thread
202            .update(cx, |thread, cx| {
203                thread.write_text_file(arguments.path, arguments.content, cx)
204            })?
205            .await?;
206
207        Ok(())
208    }
209
210    async fn read_text_file(
211        &self,
212        arguments: acp::ReadTextFileRequest,
213    ) -> Result<acp::ReadTextFileResponse, acp::Error> {
214        let cx = &mut self.cx.clone();
215        let content = self
216            .sessions
217            .borrow()
218            .get(&arguments.session_id)
219            .context("Failed to get session")?
220            .thread
221            .update(cx, |thread, cx| {
222                thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
223            })?
224            .await?;
225
226        Ok(acp::ReadTextFileResponse { content })
227    }
228
229    async fn session_notification(
230        &self,
231        notification: acp::SessionNotification,
232    ) -> Result<(), acp::Error> {
233        let cx = &mut self.cx.clone();
234        let sessions = self.sessions.borrow();
235        let session = sessions
236            .get(&notification.session_id)
237            .context("Failed to get session")?;
238
239        session.thread.update(cx, |thread, cx| {
240            thread.handle_session_update(notification.update, cx)
241        })??;
242
243        Ok(())
244    }
245}