v1.rs

  1use agent_client_protocol::{self as acp, Agent as _};
  2use anyhow::anyhow;
  3use collections::HashMap;
  4use futures::AsyncBufReadExt as _;
  5use futures::channel::oneshot;
  6use futures::io::BufReader;
  7use project::Project;
  8use std::path::Path;
  9use std::rc::Rc;
 10use std::{any::Any, cell::RefCell};
 11
 12use anyhow::{Context as _, Result};
 13use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
 14
 15use crate::{AgentServerCommand, acp::UnsupportedVersion};
 16use acp_thread::{AcpThread, AgentConnection, AuthRequired};
 17
/// A connection to an agent server that speaks the Agent Client Protocol,
/// created by [`AcpConnection::stdio`].
pub struct AcpConnection {
    /// Name handed to every `AcpThread` this connection creates.
    server_name: &'static str,
    /// Client side of the ACP connection; every request to the agent goes through it.
    connection: Rc<acp::ClientSideConnection>,
    /// Live sessions keyed by ACP session id. Shared with the `ClientDelegate`
    /// that routes agent-initiated requests and with the task that reports
    /// server exit.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods advertised by the agent in its `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Background task driving the connection's I/O; never read, only held so
    /// that it lives as long as the connection.
    _io_task: Task<Result<()>>,
}
 25
/// Per-session state: the thread that ACP messages for this session are routed to.
pub struct AcpSession {
    /// Weak handle to the session's thread; updates through it can fail once
    /// the thread is gone, and callers either propagate or ignore that error.
    thread: WeakEntity<AcpThread>,
}
 29
/// Oldest protocol version accepted from the agent's `initialize` response;
/// anything older makes `AcpConnection::stdio` fail with `UnsupportedVersion`.
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
 31
 32impl AcpConnection {
 33    pub async fn stdio(
 34        server_name: &'static str,
 35        command: AgentServerCommand,
 36        root_dir: &Path,
 37        cx: &mut AsyncApp,
 38    ) -> Result<Self> {
 39        let mut child = util::command::new_smol_command(&command.path)
 40            .args(command.args.iter().map(|arg| arg.as_str()))
 41            .envs(command.env.iter().flatten())
 42            .current_dir(root_dir)
 43            .stdin(std::process::Stdio::piped())
 44            .stdout(std::process::Stdio::piped())
 45            .stderr(std::process::Stdio::piped())
 46            .kill_on_drop(true)
 47            .spawn()?;
 48
 49        let stdout = child.stdout.take().context("Failed to take stdout")?;
 50        let stdin = child.stdin.take().context("Failed to take stdin")?;
 51        let stderr = child.stderr.take().context("Failed to take stderr")?;
 52        log::trace!("Spawned (pid: {})", child.id());
 53
 54        let sessions = Rc::new(RefCell::new(HashMap::default()));
 55
 56        let client = ClientDelegate {
 57            sessions: sessions.clone(),
 58            cx: cx.clone(),
 59        };
 60        let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
 61            let foreground_executor = cx.foreground_executor().clone();
 62            move |fut| {
 63                foreground_executor.spawn(fut).detach();
 64            }
 65        });
 66
 67        let io_task = cx.background_spawn(io_task);
 68
 69        cx.background_spawn(async move {
 70            let mut stderr = BufReader::new(stderr);
 71            let mut line = String::new();
 72            while let Ok(n) = stderr.read_line(&mut line).await
 73                && n > 0
 74            {
 75                log::warn!("agent stderr: {}", &line);
 76                line.clear();
 77            }
 78        })
 79        .detach();
 80
 81        cx.spawn({
 82            let sessions = sessions.clone();
 83            async move |cx| {
 84                let status = child.status().await?;
 85
 86                for session in sessions.borrow().values() {
 87                    session
 88                        .thread
 89                        .update(cx, |thread, cx| thread.emit_server_exited(status, cx))
 90                        .ok();
 91                }
 92
 93                anyhow::Ok(())
 94            }
 95        })
 96        .detach();
 97
 98        let response = connection
 99            .initialize(acp::InitializeRequest {
100                protocol_version: acp::VERSION,
101                client_capabilities: acp::ClientCapabilities {
102                    fs: acp::FileSystemCapability {
103                        read_text_file: true,
104                        write_text_file: true,
105                    },
106                },
107            })
108            .await?;
109
110        if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
111            return Err(UnsupportedVersion.into());
112        }
113
114        Ok(Self {
115            auth_methods: response.auth_methods,
116            connection: connection.into(),
117            server_name,
118            sessions,
119            _io_task: io_task,
120        })
121    }
122}
123
124impl AgentConnection for AcpConnection {
125    fn new_thread(
126        self: Rc<Self>,
127        project: Entity<Project>,
128        cwd: &Path,
129        cx: &mut App,
130    ) -> Task<Result<Entity<AcpThread>>> {
131        let conn = self.connection.clone();
132        let sessions = self.sessions.clone();
133        let cwd = cwd.to_path_buf();
134        cx.spawn(async move |cx| {
135            let response = conn
136                .new_session(acp::NewSessionRequest {
137                    mcp_servers: vec![],
138                    cwd,
139                })
140                .await
141                .map_err(|err| {
142                    if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
143                        let mut error = AuthRequired::new();
144
145                        if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
146                            error = error.with_description(err.message);
147                        }
148
149                        anyhow!(error)
150                    } else {
151                        anyhow!(err)
152                    }
153                })?;
154
155            let session_id = response.session_id;
156
157            let thread = cx.new(|cx| {
158                AcpThread::new(
159                    self.server_name,
160                    self.clone(),
161                    project,
162                    session_id.clone(),
163                    cx,
164                )
165            })?;
166
167            let session = AcpSession {
168                thread: thread.downgrade(),
169            };
170            sessions.borrow_mut().insert(session_id, session);
171
172            Ok(thread)
173        })
174    }
175
176    fn auth_methods(&self) -> &[acp::AuthMethod] {
177        &self.auth_methods
178    }
179
180    fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
181        let conn = self.connection.clone();
182        cx.foreground_executor().spawn(async move {
183            let result = conn
184                .authenticate(acp::AuthenticateRequest {
185                    method_id: method_id.clone(),
186                })
187                .await?;
188
189            Ok(result)
190        })
191    }
192
193    fn prompt(
194        &self,
195        _id: Option<acp_thread::UserMessageId>,
196        params: acp::PromptRequest,
197        cx: &mut App,
198    ) -> Task<Result<acp::PromptResponse>> {
199        let conn = self.connection.clone();
200        cx.foreground_executor().spawn(async move {
201            let response = conn.prompt(params).await?;
202            Ok(response)
203        })
204    }
205
206    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
207        let conn = self.connection.clone();
208        let params = acp::CancelNotification {
209            session_id: session_id.clone(),
210        };
211        cx.foreground_executor()
212            .spawn(async move { conn.cancel(params).await })
213            .detach();
214    }
215
216    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
217        self
218    }
219}
220
/// Handles the requests and notifications the agent sends to the client,
/// routing each one to the `AcpThread` registered for its session id.
struct ClientDelegate {
    /// The same session map held by the owning `AcpConnection`.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Async app context; cloned by each handler before updating a thread.
    cx: AsyncApp,
}
225
226impl acp::Client for ClientDelegate {
227    async fn request_permission(
228        &self,
229        arguments: acp::RequestPermissionRequest,
230    ) -> Result<acp::RequestPermissionResponse, acp::Error> {
231        let cx = &mut self.cx.clone();
232        let rx = self
233            .sessions
234            .borrow()
235            .get(&arguments.session_id)
236            .context("Failed to get session")?
237            .thread
238            .update(cx, |thread, cx| {
239                thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
240            })?;
241
242        let result = rx?.await;
243
244        let outcome = match result {
245            Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
246            Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Canceled,
247        };
248
249        Ok(acp::RequestPermissionResponse { outcome })
250    }
251
252    async fn write_text_file(
253        &self,
254        arguments: acp::WriteTextFileRequest,
255    ) -> Result<(), acp::Error> {
256        let cx = &mut self.cx.clone();
257        let task = self
258            .sessions
259            .borrow()
260            .get(&arguments.session_id)
261            .context("Failed to get session")?
262            .thread
263            .update(cx, |thread, cx| {
264                thread.write_text_file(arguments.path, arguments.content, cx)
265            })?;
266
267        task.await?;
268
269        Ok(())
270    }
271
272    async fn read_text_file(
273        &self,
274        arguments: acp::ReadTextFileRequest,
275    ) -> Result<acp::ReadTextFileResponse, acp::Error> {
276        let cx = &mut self.cx.clone();
277        let task = self
278            .sessions
279            .borrow()
280            .get(&arguments.session_id)
281            .context("Failed to get session")?
282            .thread
283            .update(cx, |thread, cx| {
284                thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
285            })?;
286
287        let content = task.await?;
288
289        Ok(acp::ReadTextFileResponse { content })
290    }
291
292    async fn session_notification(
293        &self,
294        notification: acp::SessionNotification,
295    ) -> Result<(), acp::Error> {
296        let cx = &mut self.cx.clone();
297        let sessions = self.sessions.borrow();
298        let session = sessions
299            .get(&notification.session_id)
300            .context("Failed to get session")?;
301
302        session.thread.update(cx, |thread, cx| {
303            thread.handle_session_update(notification.update, cx)
304        })??;
305
306        Ok(())
307    }
308}