1use agent_client_protocol::{self as acp, Agent as _};
2use anyhow::anyhow;
3use collections::HashMap;
4use futures::channel::oneshot;
5use project::Project;
6use std::cell::RefCell;
7use std::path::Path;
8use std::rc::Rc;
9
10use anyhow::{Context as _, Result};
11use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
12
13use crate::{AgentServerCommand, acp::UnsupportedVersion};
14use acp_thread::{AcpThread, AgentConnection, AuthRequired};
15
16pub struct AcpConnection {
17 server_name: &'static str,
18 connection: Rc<acp::ClientSideConnection>,
19 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
20 auth_methods: Vec<acp::AuthMethod>,
21 _io_task: Task<Result<()>>,
22}
23
24pub struct AcpSession {
25 thread: WeakEntity<AcpThread>,
26}
27
28const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
29
30impl AcpConnection {
31 pub async fn stdio(
32 server_name: &'static str,
33 command: AgentServerCommand,
34 root_dir: &Path,
35 cx: &mut AsyncApp,
36 ) -> Result<Self> {
37 let mut child = util::command::new_smol_command(&command.path)
38 .args(command.args.iter().map(|arg| arg.as_str()))
39 .envs(command.env.iter().flatten())
40 .current_dir(root_dir)
41 .stdin(std::process::Stdio::piped())
42 .stdout(std::process::Stdio::piped())
43 .stderr(std::process::Stdio::inherit())
44 .kill_on_drop(true)
45 .spawn()?;
46
47 let stdout = child.stdout.take().expect("Failed to take stdout");
48 let stdin = child.stdin.take().expect("Failed to take stdin");
49 log::trace!("Spawned (pid: {})", child.id());
50
51 let sessions = Rc::new(RefCell::new(HashMap::default()));
52
53 let client = ClientDelegate {
54 sessions: sessions.clone(),
55 cx: cx.clone(),
56 };
57 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
58 let foreground_executor = cx.foreground_executor().clone();
59 move |fut| {
60 foreground_executor.spawn(fut).detach();
61 }
62 });
63
64 let io_task = cx.background_spawn(io_task);
65
66 cx.spawn({
67 let sessions = sessions.clone();
68 async move |cx| {
69 let status = child.status().await?;
70
71 for session in sessions.borrow().values() {
72 session
73 .thread
74 .update(cx, |thread, cx| thread.emit_server_exited(status, cx))
75 .ok();
76 }
77
78 anyhow::Ok(())
79 }
80 })
81 .detach();
82
83 let response = connection
84 .initialize(acp::InitializeRequest {
85 protocol_version: acp::VERSION,
86 client_capabilities: acp::ClientCapabilities {
87 fs: acp::FileSystemCapability {
88 read_text_file: true,
89 write_text_file: true,
90 },
91 },
92 })
93 .await?;
94
95 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
96 return Err(UnsupportedVersion.into());
97 }
98
99 Ok(Self {
100 auth_methods: response.auth_methods,
101 connection: connection.into(),
102 server_name,
103 sessions,
104 _io_task: io_task,
105 })
106 }
107}
108
109impl AgentConnection for AcpConnection {
110 fn new_thread(
111 self: Rc<Self>,
112 project: Entity<Project>,
113 cwd: &Path,
114 cx: &mut AsyncApp,
115 ) -> Task<Result<Entity<AcpThread>>> {
116 let conn = self.connection.clone();
117 let sessions = self.sessions.clone();
118 let cwd = cwd.to_path_buf();
119 cx.spawn(async move |cx| {
120 let response = conn
121 .new_session(acp::NewSessionRequest {
122 mcp_servers: vec![],
123 cwd,
124 })
125 .await
126 .map_err(|err| {
127 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
128 anyhow!(AuthRequired)
129 } else {
130 anyhow!(err)
131 }
132 })?;
133
134 let session_id = response.session_id;
135
136 let thread = cx.new(|cx| {
137 AcpThread::new(
138 self.server_name,
139 self.clone(),
140 project,
141 session_id.clone(),
142 cx,
143 )
144 })?;
145
146 let session = AcpSession {
147 thread: thread.downgrade(),
148 };
149 sessions.borrow_mut().insert(session_id, session);
150
151 Ok(thread)
152 })
153 }
154
155 fn auth_methods(&self) -> &[acp::AuthMethod] {
156 &self.auth_methods
157 }
158
159 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
160 let conn = self.connection.clone();
161 cx.foreground_executor().spawn(async move {
162 let result = conn
163 .authenticate(acp::AuthenticateRequest {
164 method_id: method_id.clone(),
165 })
166 .await?;
167
168 Ok(result)
169 })
170 }
171
172 fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
173 let conn = self.connection.clone();
174 cx.foreground_executor().spawn(async move {
175 conn.prompt(params).await?;
176 Ok(())
177 })
178 }
179
180 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
181 let conn = self.connection.clone();
182 let params = acp::CancelNotification {
183 session_id: session_id.clone(),
184 };
185 cx.foreground_executor()
186 .spawn(async move { conn.cancel(params).await })
187 .detach();
188 }
189}
190
191struct ClientDelegate {
192 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
193 cx: AsyncApp,
194}
195
196impl acp::Client for ClientDelegate {
197 async fn request_permission(
198 &self,
199 arguments: acp::RequestPermissionRequest,
200 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
201 let cx = &mut self.cx.clone();
202 let rx = self
203 .sessions
204 .borrow()
205 .get(&arguments.session_id)
206 .context("Failed to get session")?
207 .thread
208 .update(cx, |thread, cx| {
209 thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx)
210 })?;
211
212 let result = rx.await;
213
214 let outcome = match result {
215 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
216 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
217 };
218
219 Ok(acp::RequestPermissionResponse { outcome })
220 }
221
222 async fn write_text_file(
223 &self,
224 arguments: acp::WriteTextFileRequest,
225 ) -> Result<(), acp::Error> {
226 let cx = &mut self.cx.clone();
227 let task = self
228 .sessions
229 .borrow()
230 .get(&arguments.session_id)
231 .context("Failed to get session")?
232 .thread
233 .update(cx, |thread, cx| {
234 thread.write_text_file(arguments.path, arguments.content, cx)
235 })?;
236
237 task.await?;
238
239 Ok(())
240 }
241
242 async fn read_text_file(
243 &self,
244 arguments: acp::ReadTextFileRequest,
245 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
246 let cx = &mut self.cx.clone();
247 let task = self
248 .sessions
249 .borrow()
250 .get(&arguments.session_id)
251 .context("Failed to get session")?
252 .thread
253 .update(cx, |thread, cx| {
254 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
255 })?;
256
257 let content = task.await?;
258
259 Ok(acp::ReadTextFileResponse { content })
260 }
261
262 async fn session_notification(
263 &self,
264 notification: acp::SessionNotification,
265 ) -> Result<(), acp::Error> {
266 let cx = &mut self.cx.clone();
267 let sessions = self.sessions.borrow();
268 let session = sessions
269 .get(¬ification.session_id)
270 .context("Failed to get session")?;
271
272 session.thread.update(cx, |thread, cx| {
273 thread.handle_session_update(notification.update, cx)
274 })??;
275
276 Ok(())
277 }
278}