1use action_log::ActionLog;
2use agent_client_protocol::{self as acp, Agent as _};
3use anyhow::anyhow;
4use collections::HashMap;
5use futures::AsyncBufReadExt as _;
6use futures::channel::oneshot;
7use futures::io::BufReader;
8use project::Project;
9use std::path::Path;
10use std::rc::Rc;
11use std::{any::Any, cell::RefCell};
12
13use anyhow::{Context as _, Result};
14use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
15
16use crate::{AgentServerCommand, acp::UnsupportedVersion};
17use acp_thread::{AcpThread, AgentConnection, AuthRequired, LoadError};
18
19pub struct AcpConnection {
20 server_name: &'static str,
21 connection: Rc<acp::ClientSideConnection>,
22 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
23 auth_methods: Vec<acp::AuthMethod>,
24 _io_task: Task<Result<()>>,
25}
26
27pub struct AcpSession {
28 thread: WeakEntity<AcpThread>,
29}
30
31const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
32
33impl AcpConnection {
34 pub async fn stdio(
35 server_name: &'static str,
36 command: AgentServerCommand,
37 root_dir: &Path,
38 cx: &mut AsyncApp,
39 ) -> Result<Self> {
40 let mut child = util::command::new_smol_command(&command.path)
41 .args(command.args.iter().map(|arg| arg.as_str()))
42 .envs(command.env.iter().flatten())
43 .current_dir(root_dir)
44 .stdin(std::process::Stdio::piped())
45 .stdout(std::process::Stdio::piped())
46 .stderr(std::process::Stdio::piped())
47 .kill_on_drop(true)
48 .spawn()?;
49
50 let stdout = child.stdout.take().context("Failed to take stdout")?;
51 let stdin = child.stdin.take().context("Failed to take stdin")?;
52 let stderr = child.stderr.take().context("Failed to take stderr")?;
53 log::trace!("Spawned (pid: {})", child.id());
54
55 let sessions = Rc::new(RefCell::new(HashMap::default()));
56
57 let client = ClientDelegate {
58 sessions: sessions.clone(),
59 cx: cx.clone(),
60 };
61 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
62 let foreground_executor = cx.foreground_executor().clone();
63 move |fut| {
64 foreground_executor.spawn(fut).detach();
65 }
66 });
67
68 let io_task = cx.background_spawn(io_task);
69
70 cx.background_spawn(async move {
71 let mut stderr = BufReader::new(stderr);
72 let mut line = String::new();
73 while let Ok(n) = stderr.read_line(&mut line).await
74 && n > 0
75 {
76 log::warn!("agent stderr: {}", &line);
77 line.clear();
78 }
79 })
80 .detach();
81
82 cx.spawn({
83 let sessions = sessions.clone();
84 async move |cx| {
85 let status = child.status().await?;
86
87 for session in sessions.borrow().values() {
88 session
89 .thread
90 .update(cx, |thread, cx| {
91 thread.emit_load_error(LoadError::Exited { status }, cx)
92 })
93 .ok();
94 }
95
96 anyhow::Ok(())
97 }
98 })
99 .detach();
100
101 let response = connection
102 .initialize(acp::InitializeRequest {
103 protocol_version: acp::VERSION,
104 client_capabilities: acp::ClientCapabilities {
105 fs: acp::FileSystemCapability {
106 read_text_file: true,
107 write_text_file: true,
108 },
109 },
110 })
111 .await?;
112
113 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
114 return Err(UnsupportedVersion.into());
115 }
116
117 Ok(Self {
118 auth_methods: response.auth_methods,
119 connection: connection.into(),
120 server_name,
121 sessions,
122 _io_task: io_task,
123 })
124 }
125}
126
127impl AgentConnection for AcpConnection {
128 fn new_thread(
129 self: Rc<Self>,
130 project: Entity<Project>,
131 cwd: &Path,
132 cx: &mut App,
133 ) -> Task<Result<Entity<AcpThread>>> {
134 let conn = self.connection.clone();
135 let sessions = self.sessions.clone();
136 let cwd = cwd.to_path_buf();
137 cx.spawn(async move |cx| {
138 let response = conn
139 .new_session(acp::NewSessionRequest {
140 mcp_servers: vec![],
141 cwd,
142 })
143 .await
144 .map_err(|err| {
145 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
146 let mut error = AuthRequired::new();
147
148 if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
149 error = error.with_description(err.message);
150 }
151
152 anyhow!(error)
153 } else {
154 anyhow!(err)
155 }
156 })?;
157
158 let session_id = response.session_id;
159 let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
160 let thread = cx.new(|_cx| {
161 AcpThread::new(
162 self.server_name,
163 self.clone(),
164 project,
165 action_log,
166 session_id.clone(),
167 )
168 })?;
169
170 let session = AcpSession {
171 thread: thread.downgrade(),
172 };
173 sessions.borrow_mut().insert(session_id, session);
174
175 Ok(thread)
176 })
177 }
178
179 fn auth_methods(&self) -> &[acp::AuthMethod] {
180 &self.auth_methods
181 }
182
183 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
184 let conn = self.connection.clone();
185 cx.foreground_executor().spawn(async move {
186 let result = conn
187 .authenticate(acp::AuthenticateRequest {
188 method_id: method_id.clone(),
189 })
190 .await?;
191
192 Ok(result)
193 })
194 }
195
196 fn prompt(
197 &self,
198 _id: Option<acp_thread::UserMessageId>,
199 params: acp::PromptRequest,
200 cx: &mut App,
201 ) -> Task<Result<acp::PromptResponse>> {
202 let conn = self.connection.clone();
203 cx.foreground_executor().spawn(async move {
204 let response = conn.prompt(params).await?;
205 Ok(response)
206 })
207 }
208
209 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
210 let conn = self.connection.clone();
211 let params = acp::CancelNotification {
212 session_id: session_id.clone(),
213 };
214 cx.foreground_executor()
215 .spawn(async move { conn.cancel(params).await })
216 .detach();
217 }
218
219 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
220 self
221 }
222}
223
224struct ClientDelegate {
225 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
226 cx: AsyncApp,
227}
228
229impl acp::Client for ClientDelegate {
230 async fn request_permission(
231 &self,
232 arguments: acp::RequestPermissionRequest,
233 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
234 let cx = &mut self.cx.clone();
235 let rx = self
236 .sessions
237 .borrow()
238 .get(&arguments.session_id)
239 .context("Failed to get session")?
240 .thread
241 .update(cx, |thread, cx| {
242 thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
243 })?;
244
245 let result = rx?.await;
246
247 let outcome = match result {
248 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
249 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Canceled,
250 };
251
252 Ok(acp::RequestPermissionResponse { outcome })
253 }
254
255 async fn write_text_file(
256 &self,
257 arguments: acp::WriteTextFileRequest,
258 ) -> Result<(), acp::Error> {
259 let cx = &mut self.cx.clone();
260 let task = self
261 .sessions
262 .borrow()
263 .get(&arguments.session_id)
264 .context("Failed to get session")?
265 .thread
266 .update(cx, |thread, cx| {
267 thread.write_text_file(arguments.path, arguments.content, cx)
268 })?;
269
270 task.await?;
271
272 Ok(())
273 }
274
275 async fn read_text_file(
276 &self,
277 arguments: acp::ReadTextFileRequest,
278 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
279 let cx = &mut self.cx.clone();
280 let task = self
281 .sessions
282 .borrow()
283 .get(&arguments.session_id)
284 .context("Failed to get session")?
285 .thread
286 .update(cx, |thread, cx| {
287 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
288 })?;
289
290 let content = task.await?;
291
292 Ok(acp::ReadTextFileResponse { content })
293 }
294
295 async fn session_notification(
296 &self,
297 notification: acp::SessionNotification,
298 ) -> Result<(), acp::Error> {
299 let cx = &mut self.cx.clone();
300 let sessions = self.sessions.borrow();
301 let session = sessions
302 .get(¬ification.session_id)
303 .context("Failed to get session")?;
304
305 session.thread.update(cx, |thread, cx| {
306 thread.handle_session_update(notification.update, cx)
307 })??;
308
309 Ok(())
310 }
311}