1use action_log::ActionLog;
2use agent_client_protocol::{self as acp, Agent as _};
3use anyhow::anyhow;
4use collections::HashMap;
5use futures::AsyncBufReadExt as _;
6use futures::channel::oneshot;
7use futures::io::BufReader;
8use project::Project;
9use std::path::Path;
10use std::rc::Rc;
11use std::{any::Any, cell::RefCell};
12
13use anyhow::{Context as _, Result};
14use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
15
16use crate::{AgentServerCommand, acp::UnsupportedVersion};
17use acp_thread::{AcpThread, AgentConnection, AuthRequired, LoadError};
18
19pub struct AcpConnection {
20 server_name: &'static str,
21 connection: Rc<acp::ClientSideConnection>,
22 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
23 auth_methods: Vec<acp::AuthMethod>,
24 prompt_capabilities: acp::PromptCapabilities,
25 _io_task: Task<Result<()>>,
26}
27
28pub struct AcpSession {
29 thread: WeakEntity<AcpThread>,
30}
31
32const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
33
34impl AcpConnection {
35 pub async fn stdio(
36 server_name: &'static str,
37 command: AgentServerCommand,
38 root_dir: &Path,
39 cx: &mut AsyncApp,
40 ) -> Result<Self> {
41 let mut child = util::command::new_smol_command(&command.path)
42 .args(command.args.iter().map(|arg| arg.as_str()))
43 .envs(command.env.iter().flatten())
44 .current_dir(root_dir)
45 .stdin(std::process::Stdio::piped())
46 .stdout(std::process::Stdio::piped())
47 .stderr(std::process::Stdio::piped())
48 .kill_on_drop(true)
49 .spawn()?;
50
51 let stdout = child.stdout.take().context("Failed to take stdout")?;
52 let stdin = child.stdin.take().context("Failed to take stdin")?;
53 let stderr = child.stderr.take().context("Failed to take stderr")?;
54 log::trace!("Spawned (pid: {})", child.id());
55
56 let sessions = Rc::new(RefCell::new(HashMap::default()));
57
58 let client = ClientDelegate {
59 sessions: sessions.clone(),
60 cx: cx.clone(),
61 };
62 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
63 let foreground_executor = cx.foreground_executor().clone();
64 move |fut| {
65 foreground_executor.spawn(fut).detach();
66 }
67 });
68
69 let io_task = cx.background_spawn(io_task);
70
71 cx.background_spawn(async move {
72 let mut stderr = BufReader::new(stderr);
73 let mut line = String::new();
74 while let Ok(n) = stderr.read_line(&mut line).await
75 && n > 0
76 {
77 log::warn!("agent stderr: {}", &line);
78 line.clear();
79 }
80 })
81 .detach();
82
83 cx.spawn({
84 let sessions = sessions.clone();
85 async move |cx| {
86 let status = child.status().await?;
87
88 for session in sessions.borrow().values() {
89 session
90 .thread
91 .update(cx, |thread, cx| {
92 thread.emit_load_error(LoadError::Exited { status }, cx)
93 })
94 .ok();
95 }
96
97 anyhow::Ok(())
98 }
99 })
100 .detach();
101
102 let response = connection
103 .initialize(acp::InitializeRequest {
104 protocol_version: acp::VERSION,
105 client_capabilities: acp::ClientCapabilities {
106 fs: acp::FileSystemCapability {
107 read_text_file: true,
108 write_text_file: true,
109 },
110 },
111 })
112 .await?;
113
114 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
115 return Err(UnsupportedVersion.into());
116 }
117
118 Ok(Self {
119 auth_methods: response.auth_methods,
120 connection: connection.into(),
121 server_name,
122 sessions,
123 prompt_capabilities: response.agent_capabilities.prompt_capabilities,
124 _io_task: io_task,
125 })
126 }
127}
128
129impl AgentConnection for AcpConnection {
130 fn new_thread(
131 self: Rc<Self>,
132 project: Entity<Project>,
133 cwd: &Path,
134 cx: &mut App,
135 ) -> Task<Result<Entity<AcpThread>>> {
136 let conn = self.connection.clone();
137 let sessions = self.sessions.clone();
138 let cwd = cwd.to_path_buf();
139 cx.spawn(async move |cx| {
140 let response = conn
141 .new_session(acp::NewSessionRequest {
142 mcp_servers: vec![],
143 cwd,
144 })
145 .await
146 .map_err(|err| {
147 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
148 let mut error = AuthRequired::new();
149
150 if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
151 error = error.with_description(err.message);
152 }
153
154 anyhow!(error)
155 } else {
156 anyhow!(err)
157 }
158 })?;
159
160 let session_id = response.session_id;
161 let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
162 let thread = cx.new(|_cx| {
163 AcpThread::new(
164 self.server_name,
165 self.clone(),
166 project,
167 action_log,
168 session_id.clone(),
169 )
170 })?;
171
172 let session = AcpSession {
173 thread: thread.downgrade(),
174 };
175 sessions.borrow_mut().insert(session_id, session);
176
177 Ok(thread)
178 })
179 }
180
181 fn auth_methods(&self) -> &[acp::AuthMethod] {
182 &self.auth_methods
183 }
184
185 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
186 let conn = self.connection.clone();
187 cx.foreground_executor().spawn(async move {
188 let result = conn
189 .authenticate(acp::AuthenticateRequest {
190 method_id: method_id.clone(),
191 })
192 .await?;
193
194 Ok(result)
195 })
196 }
197
198 fn prompt(
199 &self,
200 _id: Option<acp_thread::UserMessageId>,
201 params: acp::PromptRequest,
202 cx: &mut App,
203 ) -> Task<Result<acp::PromptResponse>> {
204 let conn = self.connection.clone();
205 cx.foreground_executor().spawn(async move {
206 let response = conn.prompt(params).await?;
207 Ok(response)
208 })
209 }
210
211 fn prompt_capabilities(&self) -> acp::PromptCapabilities {
212 self.prompt_capabilities
213 }
214
215 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
216 let conn = self.connection.clone();
217 let params = acp::CancelNotification {
218 session_id: session_id.clone(),
219 };
220 cx.foreground_executor()
221 .spawn(async move { conn.cancel(params).await })
222 .detach();
223 }
224
225 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
226 self
227 }
228}
229
230struct ClientDelegate {
231 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
232 cx: AsyncApp,
233}
234
235impl acp::Client for ClientDelegate {
236 async fn request_permission(
237 &self,
238 arguments: acp::RequestPermissionRequest,
239 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
240 let cx = &mut self.cx.clone();
241 let rx = self
242 .sessions
243 .borrow()
244 .get(&arguments.session_id)
245 .context("Failed to get session")?
246 .thread
247 .update(cx, |thread, cx| {
248 thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
249 })?;
250
251 let result = rx?.await;
252
253 let outcome = match result {
254 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
255 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Canceled,
256 };
257
258 Ok(acp::RequestPermissionResponse { outcome })
259 }
260
261 async fn write_text_file(
262 &self,
263 arguments: acp::WriteTextFileRequest,
264 ) -> Result<(), acp::Error> {
265 let cx = &mut self.cx.clone();
266 let task = self
267 .sessions
268 .borrow()
269 .get(&arguments.session_id)
270 .context("Failed to get session")?
271 .thread
272 .update(cx, |thread, cx| {
273 thread.write_text_file(arguments.path, arguments.content, cx)
274 })?;
275
276 task.await?;
277
278 Ok(())
279 }
280
281 async fn read_text_file(
282 &self,
283 arguments: acp::ReadTextFileRequest,
284 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
285 let cx = &mut self.cx.clone();
286 let task = self
287 .sessions
288 .borrow()
289 .get(&arguments.session_id)
290 .context("Failed to get session")?
291 .thread
292 .update(cx, |thread, cx| {
293 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
294 })?;
295
296 let content = task.await?;
297
298 Ok(acp::ReadTextFileResponse { content })
299 }
300
301 async fn session_notification(
302 &self,
303 notification: acp::SessionNotification,
304 ) -> Result<(), acp::Error> {
305 let cx = &mut self.cx.clone();
306 let sessions = self.sessions.borrow();
307 let session = sessions
308 .get(¬ification.session_id)
309 .context("Failed to get session")?;
310
311 session.thread.update(cx, |thread, cx| {
312 thread.handle_session_update(notification.update, cx)
313 })??;
314
315 Ok(())
316 }
317}