v1.rs

  1use agent_client_protocol::{self as acp, Agent as _};
  2use anyhow::anyhow;
  3use collections::HashMap;
  4use futures::AsyncBufReadExt as _;
  5use futures::channel::oneshot;
  6use futures::io::BufReader;
  7use project::Project;
  8use std::path::Path;
  9use std::rc::Rc;
 10use std::{any::Any, cell::RefCell};
 11
 12use anyhow::{Context as _, Result};
 13use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
 14
 15use crate::{AgentServerCommand, acp::UnsupportedVersion};
 16use acp_thread::{AcpThread, AgentConnection, AuthRequired};
 17
/// Client-side handle to an agent server spoken to over the Agent Client Protocol.
///
/// Instances are produced by [`AcpConnection::stdio`], which spawns the server as a
/// child process and performs the `initialize` handshake.
pub struct AcpConnection {
    /// Name passed to every `AcpThread` created by `new_thread`.
    server_name: &'static str,
    /// Protocol connection used to send requests and notifications to the agent.
    connection: Rc<acp::ClientSideConnection>,
    /// Sessions opened through this connection, keyed by session id. The same map is
    /// shared with the `ClientDelegate` that serves requests coming from the agent.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods taken from the agent's `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Task driving the connection's I/O. It is stored rather than detached —
    /// presumably so it is dropped together with the connection (TODO confirm).
    _io_task: Task<Result<()>>,
}
 25
/// State tracked for a single agent session.
pub struct AcpSession {
    /// Weak handle to the thread entity created for this session in `new_thread`;
    /// being weak, it does not keep the thread alive on its own.
    thread: WeakEntity<AcpThread>,
}
 29
/// Lowest protocol version accepted from an agent. `AcpConnection::stdio` fails with
/// `UnsupportedVersion` when the `initialize` response reports a lower version.
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
 31
 32impl AcpConnection {
 33    pub async fn stdio(
 34        server_name: &'static str,
 35        command: AgentServerCommand,
 36        root_dir: &Path,
 37        cx: &mut AsyncApp,
 38    ) -> Result<Self> {
 39        let mut child = util::command::new_smol_command(&command.path)
 40            .args(command.args.iter().map(|arg| arg.as_str()))
 41            .envs(command.env.iter().flatten())
 42            .current_dir(root_dir)
 43            .stdin(std::process::Stdio::piped())
 44            .stdout(std::process::Stdio::piped())
 45            .stderr(std::process::Stdio::piped())
 46            .kill_on_drop(true)
 47            .spawn()?;
 48
 49        let stdout = child.stdout.take().context("Failed to take stdout")?;
 50        let stdin = child.stdin.take().context("Failed to take stdin")?;
 51        let stderr = child.stderr.take().context("Failed to take stderr")?;
 52        log::trace!("Spawned (pid: {})", child.id());
 53
 54        let sessions = Rc::new(RefCell::new(HashMap::default()));
 55
 56        let client = ClientDelegate {
 57            sessions: sessions.clone(),
 58            cx: cx.clone(),
 59        };
 60        let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
 61            let foreground_executor = cx.foreground_executor().clone();
 62            move |fut| {
 63                foreground_executor.spawn(fut).detach();
 64            }
 65        });
 66
 67        let io_task = cx.background_spawn(io_task);
 68
 69        cx.background_spawn(async move {
 70            let mut stderr = BufReader::new(stderr);
 71            let mut line = String::new();
 72            while let Ok(n) = stderr.read_line(&mut line).await
 73                && n > 0
 74            {
 75                log::warn!("agent stderr: {}", &line);
 76                line.clear();
 77            }
 78        })
 79        .detach();
 80
 81        cx.spawn({
 82            let sessions = sessions.clone();
 83            async move |cx| {
 84                let status = child.status().await?;
 85
 86                for session in sessions.borrow().values() {
 87                    session
 88                        .thread
 89                        .update(cx, |thread, cx| thread.emit_server_exited(status, cx))
 90                        .ok();
 91                }
 92
 93                anyhow::Ok(())
 94            }
 95        })
 96        .detach();
 97
 98        let response = connection
 99            .initialize(acp::InitializeRequest {
100                protocol_version: acp::VERSION,
101                client_capabilities: acp::ClientCapabilities {
102                    fs: acp::FileSystemCapability {
103                        read_text_file: true,
104                        write_text_file: true,
105                    },
106                },
107            })
108            .await?;
109
110        if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
111            return Err(UnsupportedVersion.into());
112        }
113
114        Ok(Self {
115            auth_methods: response.auth_methods,
116            connection: connection.into(),
117            server_name,
118            sessions,
119            _io_task: io_task,
120        })
121    }
122}
123
124impl AgentConnection for AcpConnection {
125    fn new_thread(
126        self: Rc<Self>,
127        project: Entity<Project>,
128        cwd: &Path,
129        cx: &mut App,
130    ) -> Task<Result<Entity<AcpThread>>> {
131        let conn = self.connection.clone();
132        let sessions = self.sessions.clone();
133        let cwd = cwd.to_path_buf();
134        cx.spawn(async move |cx| {
135            let response = conn
136                .new_session(acp::NewSessionRequest {
137                    mcp_servers: vec![],
138                    cwd,
139                })
140                .await
141                .map_err(|err| {
142                    if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
143                        anyhow!(AuthRequired)
144                    } else {
145                        anyhow!(err)
146                    }
147                })?;
148
149            let session_id = response.session_id;
150
151            let thread = cx.new(|cx| {
152                AcpThread::new(
153                    self.server_name,
154                    self.clone(),
155                    project,
156                    session_id.clone(),
157                    cx,
158                )
159            })?;
160
161            let session = AcpSession {
162                thread: thread.downgrade(),
163            };
164            sessions.borrow_mut().insert(session_id, session);
165
166            Ok(thread)
167        })
168    }
169
170    fn auth_methods(&self) -> &[acp::AuthMethod] {
171        &self.auth_methods
172    }
173
174    fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
175        let conn = self.connection.clone();
176        cx.foreground_executor().spawn(async move {
177            let result = conn
178                .authenticate(acp::AuthenticateRequest {
179                    method_id: method_id.clone(),
180                })
181                .await?;
182
183            Ok(result)
184        })
185    }
186
187    fn prompt(
188        &self,
189        _id: Option<acp_thread::UserMessageId>,
190        params: acp::PromptRequest,
191        cx: &mut App,
192    ) -> Task<Result<acp::PromptResponse>> {
193        let conn = self.connection.clone();
194        cx.foreground_executor().spawn(async move {
195            let response = conn.prompt(params).await?;
196            Ok(response)
197        })
198    }
199
200    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
201        let conn = self.connection.clone();
202        let params = acp::CancelNotification {
203            session_id: session_id.clone(),
204        };
205        cx.foreground_executor()
206            .spawn(async move { conn.cancel(params).await })
207            .detach();
208    }
209
210    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
211        self
212    }
213}
214
/// Handler for requests and notifications sent by the agent (it implements
/// `acp::Client`). Each request is routed to the thread of the session it names.
struct ClientDelegate {
    /// Session map shared with the owning `AcpConnection`.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// App handle, cloned per request to update thread entities.
    cx: AsyncApp,
}
219
220impl acp::Client for ClientDelegate {
221    async fn request_permission(
222        &self,
223        arguments: acp::RequestPermissionRequest,
224    ) -> Result<acp::RequestPermissionResponse, acp::Error> {
225        let cx = &mut self.cx.clone();
226        let rx = self
227            .sessions
228            .borrow()
229            .get(&arguments.session_id)
230            .context("Failed to get session")?
231            .thread
232            .update(cx, |thread, cx| {
233                thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
234            })?;
235
236        let result = rx?.await;
237
238        let outcome = match result {
239            Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
240            Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Canceled,
241        };
242
243        Ok(acp::RequestPermissionResponse { outcome })
244    }
245
246    async fn write_text_file(
247        &self,
248        arguments: acp::WriteTextFileRequest,
249    ) -> Result<(), acp::Error> {
250        let cx = &mut self.cx.clone();
251        let task = self
252            .sessions
253            .borrow()
254            .get(&arguments.session_id)
255            .context("Failed to get session")?
256            .thread
257            .update(cx, |thread, cx| {
258                thread.write_text_file(arguments.path, arguments.content, cx)
259            })?;
260
261        task.await?;
262
263        Ok(())
264    }
265
266    async fn read_text_file(
267        &self,
268        arguments: acp::ReadTextFileRequest,
269    ) -> Result<acp::ReadTextFileResponse, acp::Error> {
270        let cx = &mut self.cx.clone();
271        let task = self
272            .sessions
273            .borrow()
274            .get(&arguments.session_id)
275            .context("Failed to get session")?
276            .thread
277            .update(cx, |thread, cx| {
278                thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
279            })?;
280
281        let content = task.await?;
282
283        Ok(acp::ReadTextFileResponse { content })
284    }
285
286    async fn session_notification(
287        &self,
288        notification: acp::SessionNotification,
289    ) -> Result<(), acp::Error> {
290        let cx = &mut self.cx.clone();
291        let sessions = self.sessions.borrow();
292        let session = sessions
293            .get(&notification.session_id)
294            .context("Failed to get session")?;
295
296        session.thread.update(cx, |thread, cx| {
297            thread.handle_session_update(notification.update, cx)
298        })??;
299
300        Ok(())
301    }
302}