v1.rs

  1use agent_client_protocol::{self as acp, Agent as _};
  2use collections::HashMap;
  3use futures::channel::oneshot;
  4use project::Project;
  5use std::cell::RefCell;
  6use std::path::Path;
  7use std::rc::Rc;
  8
  9use anyhow::{Context as _, Result};
 10use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
 11
 12use crate::{AgentServerCommand, acp::UnsupportedVersion};
 13use acp_thread::{AcpThread, AgentConnection, AuthRequired};
 14
/// A connection to an ACP agent server that runs as a child process and speaks the
/// protocol over the child's stdin/stdout.
pub struct AcpConnection {
    /// Name passed to every `AcpThread` created through this connection.
    server_name: &'static str,
    /// Client side of the ACP connection; reference-counted so spawned tasks can share it.
    connection: Rc<acp::ClientSideConnection>,
    /// Sessions keyed by ACP session id. The same map is shared with the `ClientDelegate`
    /// so incoming agent requests can be routed to the right thread.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods advertised by the agent in its `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Task driving the connection's I/O. It is never read; it is owned here so it lives
    /// as long as the connection does.
    _io_task: Task<Result<()>>,
    /// The agent process. It is spawned with `kill_on_drop(true)`, so dropping the
    /// connection kills the process.
    _child: smol::process::Child,
}
 23
/// Per-session state tracked by an `AcpConnection`.
pub struct AcpSession {
    /// Weak handle to the session's thread, so the session map does not own the thread.
    thread: WeakEntity<AcpThread>,
}

/// Oldest protocol version accepted from an agent. `AcpConnection::stdio` rejects agents
/// reporting an older version with `UnsupportedVersion`.
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
 29
 30impl AcpConnection {
 31    pub async fn stdio(
 32        server_name: &'static str,
 33        command: AgentServerCommand,
 34        root_dir: &Path,
 35        cx: &mut AsyncApp,
 36    ) -> Result<Self> {
 37        let mut child = util::command::new_smol_command(&command.path)
 38            .args(command.args.iter().map(|arg| arg.as_str()))
 39            .envs(command.env.iter().flatten())
 40            .current_dir(root_dir)
 41            .stdin(std::process::Stdio::piped())
 42            .stdout(std::process::Stdio::piped())
 43            .stderr(std::process::Stdio::inherit())
 44            .kill_on_drop(true)
 45            .spawn()?;
 46
 47        let stdout = child.stdout.take().expect("Failed to take stdout");
 48        let stdin = child.stdin.take().expect("Failed to take stdin");
 49
 50        let sessions = Rc::new(RefCell::new(HashMap::default()));
 51
 52        let client = ClientDelegate {
 53            sessions: sessions.clone(),
 54            cx: cx.clone(),
 55        };
 56        let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
 57            let foreground_executor = cx.foreground_executor().clone();
 58            move |fut| {
 59                foreground_executor.spawn(fut).detach();
 60            }
 61        });
 62
 63        let io_task = cx.background_spawn(io_task);
 64
 65        let response = connection
 66            .initialize(acp::InitializeRequest {
 67                protocol_version: acp::VERSION,
 68                client_capabilities: acp::ClientCapabilities {
 69                    fs: acp::FileSystemCapability {
 70                        read_text_file: true,
 71                        write_text_file: true,
 72                    },
 73                },
 74            })
 75            .await?;
 76
 77        if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
 78            return Err(UnsupportedVersion.into());
 79        }
 80
 81        Ok(Self {
 82            auth_methods: response.auth_methods,
 83            connection: connection.into(),
 84            server_name,
 85            sessions,
 86            _child: child,
 87            _io_task: io_task,
 88        })
 89    }
 90}
 91
 92impl AgentConnection for AcpConnection {
 93    fn new_thread(
 94        self: Rc<Self>,
 95        project: Entity<Project>,
 96        cwd: &Path,
 97        cx: &mut AsyncApp,
 98    ) -> Task<Result<Entity<AcpThread>>> {
 99        let conn = self.connection.clone();
100        let sessions = self.sessions.clone();
101        let cwd = cwd.to_path_buf();
102        cx.spawn(async move |cx| {
103            let response = conn
104                .new_session(acp::NewSessionRequest {
105                    mcp_servers: vec![],
106                    cwd,
107                })
108                .await?;
109
110            let Some(session_id) = response.session_id else {
111                anyhow::bail!(AuthRequired);
112            };
113
114            let thread = cx.new(|cx| {
115                AcpThread::new(
116                    self.server_name,
117                    self.clone(),
118                    project,
119                    session_id.clone(),
120                    cx,
121                )
122            })?;
123
124            let session = AcpSession {
125                thread: thread.downgrade(),
126            };
127            sessions.borrow_mut().insert(session_id, session);
128
129            Ok(thread)
130        })
131    }
132
133    fn auth_methods(&self) -> &[acp::AuthMethod] {
134        &self.auth_methods
135    }
136
137    fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
138        let conn = self.connection.clone();
139        cx.foreground_executor().spawn(async move {
140            let result = conn
141                .authenticate(acp::AuthenticateRequest {
142                    method_id: method_id.clone(),
143                })
144                .await?;
145
146            Ok(result)
147        })
148    }
149
150    fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
151        let conn = self.connection.clone();
152        cx.foreground_executor()
153            .spawn(async move { Ok(conn.prompt(params).await?) })
154    }
155
156    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
157        let conn = self.connection.clone();
158        let params = acp::CancelledNotification {
159            session_id: session_id.clone(),
160        };
161        cx.foreground_executor()
162            .spawn(async move { conn.cancelled(params).await })
163            .detach();
164    }
165}
166
/// Implements the client side of ACP: handles requests and notifications sent by the
/// agent by dispatching them to the `AcpThread` registered for the targeted session.
struct ClientDelegate {
    /// The session map shared with the `AcpConnection` that created this delegate.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Async app context, cloned per call and used to update thread entities.
    cx: AsyncApp,
}
171
172impl acp::Client for ClientDelegate {
173    async fn request_permission(
174        &self,
175        arguments: acp::RequestPermissionRequest,
176    ) -> Result<acp::RequestPermissionResponse, acp::Error> {
177        let cx = &mut self.cx.clone();
178        let rx = self
179            .sessions
180            .borrow()
181            .get(&arguments.session_id)
182            .context("Failed to get session")?
183            .thread
184            .update(cx, |thread, cx| {
185                thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx)
186            })?;
187
188        let result = rx.await;
189
190        let outcome = match result {
191            Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
192            Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
193        };
194
195        Ok(acp::RequestPermissionResponse { outcome })
196    }
197
198    async fn write_text_file(
199        &self,
200        arguments: acp::WriteTextFileRequest,
201    ) -> Result<(), acp::Error> {
202        let cx = &mut self.cx.clone();
203        let task = self
204            .sessions
205            .borrow()
206            .get(&arguments.session_id)
207            .context("Failed to get session")?
208            .thread
209            .update(cx, |thread, cx| {
210                thread.write_text_file(arguments.path, arguments.content, cx)
211            })?;
212
213        task.await?;
214
215        Ok(())
216    }
217
218    async fn read_text_file(
219        &self,
220        arguments: acp::ReadTextFileRequest,
221    ) -> Result<acp::ReadTextFileResponse, acp::Error> {
222        let cx = &mut self.cx.clone();
223        let task = self
224            .sessions
225            .borrow()
226            .get(&arguments.session_id)
227            .context("Failed to get session")?
228            .thread
229            .update(cx, |thread, cx| {
230                thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
231            })?;
232
233        let content = task.await?;
234
235        Ok(acp::ReadTextFileResponse { content })
236    }
237
238    async fn session_notification(
239        &self,
240        notification: acp::SessionNotification,
241    ) -> Result<(), acp::Error> {
242        let cx = &mut self.cx.clone();
243        let sessions = self.sessions.borrow();
244        let session = sessions
245            .get(&notification.session_id)
246            .context("Failed to get session")?;
247
248        session.thread.update(cx, |thread, cx| {
249            thread.handle_session_update(notification.update, cx)
250        })??;
251
252        Ok(())
253    }
254}