1use crate::{AcpThread, ThreadEntryId, ThreadId, ToolCallId, ToolCallRequest};
2use agentic_coding_protocol as acp;
3use anyhow::{Context as _, Result};
4use async_trait::async_trait;
5use collections::HashMap;
6use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
7use parking_lot::Mutex;
8use project::Project;
9use smol::process::Child;
10use std::{io::Write as _, path::Path, process::ExitStatus, sync::Arc};
11use util::ResultExt;
12
/// Client-side handle to an external agent process that speaks the Agent
/// Client Protocol (ACP) over the process's stdio.
pub struct AcpServer {
    /// Connection used to send requests (initialize, authenticate, create
    /// thread, send message) to the agent.
    connection: Arc<acp::AgentConnection>,
    /// Threads created through this server, keyed by id. Entries are weak
    /// handles; the same map is shared with the `AcpClientDelegate` so that
    /// incoming agent calls can be routed to the matching thread.
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
    /// Project handed to each `AcpThread` created by this server.
    project: Entity<Project>,
    /// Exit status of the agent process. Written by the I/O task after the
    /// stdio transport finishes and the process has been waited on; `None`
    /// until then (or if waiting failed).
    exit_status: Arc<Mutex<Option<ExitStatus>>>,
    /// Foreground task running the connection's handler future (presumably
    /// dispatching agent requests to the delegate); stored so it lives as
    /// long as the server.
    _handler_task: Task<()>,
    /// Background task that drives the stdio transport and then records the
    /// process exit status.
    _io_task: Task<()>,
}
21
/// Serves requests coming from the agent (file stats/reads, streamed message
/// chunks, tool-call bookkeeping) against the project and registered threads.
struct AcpClientDelegate {
    /// Project used to resolve absolute paths requested by the agent.
    project: Entity<Project>,
    /// Thread registry shared with `AcpServer`; entries whose thread has been
    /// released are pruned lazily when looked up.
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
    /// Async app context, cloned by each request handler.
    cx: AsyncApp,
    // TODO: not implemented yet. Presumably intended to remember which buffer
    // snapshots were sent to the agent so real file versions can be reported
    // (the `version` fields in responses are currently always 0).
    // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
}
28
29impl AcpClientDelegate {
30 fn new(
31 project: Entity<Project>,
32 threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
33 cx: AsyncApp,
34 ) -> Self {
35 Self {
36 project,
37 threads,
38 cx: cx,
39 }
40 }
41
42 fn update_thread<R>(
43 &self,
44 thread_id: &ThreadId,
45 cx: &mut App,
46 callback: impl FnOnce(&mut AcpThread, &mut Context<AcpThread>) -> R,
47 ) -> Option<R> {
48 let thread = self.threads.lock().get(&thread_id)?.clone();
49 let Some(thread) = thread.upgrade() else {
50 self.threads.lock().remove(&thread_id);
51 return None;
52 };
53 Some(thread.update(cx, callback))
54 }
55}
56
57#[async_trait(?Send)]
58impl acp::Client for AcpClientDelegate {
59 async fn stat(&self, params: acp::StatParams) -> Result<acp::StatResponse> {
60 let cx = &mut self.cx.clone();
61 self.project.update(cx, |project, cx| {
62 let path = project
63 .project_path_for_absolute_path(Path::new(¶ms.path), cx)
64 .context("Failed to get project path")?;
65
66 match project.entry_for_path(&path, cx) {
67 // todo! refresh entry?
68 None => Ok(acp::StatResponse {
69 exists: false,
70 is_directory: false,
71 size: 0,
72 }),
73 Some(entry) => Ok(acp::StatResponse {
74 exists: entry.is_created(),
75 is_directory: entry.is_dir(),
76 size: entry.size,
77 }),
78 }
79 })?
80 }
81
82 async fn stream_message_chunk(
83 &self,
84 params: acp::StreamMessageChunkParams,
85 ) -> Result<acp::StreamMessageChunkResponse> {
86 let cx = &mut self.cx.clone();
87
88 cx.update(|cx| {
89 self.update_thread(¶ms.thread_id.into(), cx, |thread, cx| {
90 thread.push_assistant_chunk(params.chunk, cx)
91 });
92 })?;
93
94 Ok(acp::StreamMessageChunkResponse)
95 }
96
97 async fn read_text_file(
98 &self,
99 request: acp::ReadTextFileParams,
100 ) -> Result<acp::ReadTextFileResponse> {
101 let cx = &mut self.cx.clone();
102 let buffer = self
103 .project
104 .update(cx, |project, cx| {
105 let path = project
106 .project_path_for_absolute_path(Path::new(&request.path), cx)
107 .context("Failed to get project path")?;
108 anyhow::Ok(project.open_buffer(path, cx))
109 })??
110 .await?;
111
112 buffer.update(cx, |buffer, _cx| {
113 let start = language::Point::new(request.line_offset.unwrap_or(0), 0);
114 let end = match request.line_limit {
115 None => buffer.max_point(),
116 Some(limit) => start + language::Point::new(limit + 1, 0),
117 };
118
119 let content: String = buffer.text_for_range(start..end).collect();
120
121 acp::ReadTextFileResponse {
122 content,
123 version: acp::FileVersion(0),
124 }
125 })
126 }
127
128 async fn read_binary_file(
129 &self,
130 request: acp::ReadBinaryFileParams,
131 ) -> Result<acp::ReadBinaryFileResponse> {
132 let cx = &mut self.cx.clone();
133 let file = self
134 .project
135 .update(cx, |project, cx| {
136 let (worktree, path) = project
137 .find_worktree(Path::new(&request.path), cx)
138 .context("Failed to get project path")?;
139
140 let task = worktree.update(cx, |worktree, cx| worktree.load_binary_file(&path, cx));
141 anyhow::Ok(task)
142 })??
143 .await?;
144
145 // todo! test
146 let content = cx
147 .background_spawn(async move {
148 let start = request.byte_offset.unwrap_or(0) as usize;
149 let end = request
150 .byte_limit
151 .map(|limit| (start + limit as usize).min(file.content.len()))
152 .unwrap_or(file.content.len());
153
154 let range_content = &file.content[start..end];
155
156 let mut base64_content = Vec::new();
157 let mut base64_encoder = base64::write::EncoderWriter::new(
158 std::io::Cursor::new(&mut base64_content),
159 &base64::engine::general_purpose::STANDARD,
160 );
161 base64_encoder.write_all(range_content)?;
162 drop(base64_encoder);
163
164 // SAFETY: The base64 encoder should not produce non-UTF8.
165 unsafe { anyhow::Ok(String::from_utf8_unchecked(base64_content)) }
166 })
167 .await?;
168
169 Ok(acp::ReadBinaryFileResponse {
170 content,
171 // todo!
172 version: acp::FileVersion(0),
173 })
174 }
175
176 async fn glob_search(
177 &self,
178 _request: acp::GlobSearchParams,
179 ) -> Result<acp::GlobSearchResponse> {
180 todo!()
181 }
182
183 async fn request_tool_call_confirmation(
184 &self,
185 request: acp::RequestToolCallConfirmationParams,
186 ) -> Result<acp::RequestToolCallConfirmationResponse> {
187 let cx = &mut self.cx.clone();
188 let ToolCallRequest { id, outcome } = cx
189 .update(|cx| {
190 self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
191 thread.request_tool_call(request.label, request.icon, request.confirmation, cx)
192 })
193 })?
194 .context("Failed to update thread")?;
195
196 Ok(acp::RequestToolCallConfirmationResponse {
197 id: id.into(),
198 outcome: outcome.await?,
199 })
200 }
201
202 async fn push_tool_call(
203 &self,
204 request: acp::PushToolCallParams,
205 ) -> Result<acp::PushToolCallResponse> {
206 let cx = &mut self.cx.clone();
207 let entry_id = cx
208 .update(|cx| {
209 self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
210 thread.push_tool_call(request.label, request.icon, cx)
211 })
212 })?
213 .context("Failed to update thread")?;
214
215 Ok(acp::PushToolCallResponse {
216 id: entry_id.into(),
217 })
218 }
219
220 async fn update_tool_call(
221 &self,
222 request: acp::UpdateToolCallParams,
223 ) -> Result<acp::UpdateToolCallResponse> {
224 let cx = &mut self.cx.clone();
225
226 cx.update(|cx| {
227 self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
228 thread.update_tool_call(
229 request.tool_call_id.into(),
230 request.status,
231 request.content,
232 cx,
233 )
234 })
235 })?
236 .context("Failed to update thread")??;
237
238 Ok(acp::UpdateToolCallResponse)
239 }
240}
241
242impl AcpServer {
243 pub fn stdio(mut process: Child, project: Entity<Project>, cx: &mut App) -> Arc<Self> {
244 let stdin = process.stdin.take().expect("process didn't have stdin");
245 let stdout = process.stdout.take().expect("process didn't have stdout");
246
247 let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>> = Default::default();
248 let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
249 AcpClientDelegate::new(project.clone(), threads.clone(), cx.to_async()),
250 stdin,
251 stdout,
252 );
253
254 let exit_status: Arc<Mutex<Option<ExitStatus>>> = Default::default();
255 let io_task = cx.background_spawn({
256 let exit_status = exit_status.clone();
257 async move {
258 io_fut.await.log_err();
259 let result = process.status().await.log_err();
260 *exit_status.lock() = result;
261 }
262 });
263
264 Arc::new(Self {
265 project,
266 connection: Arc::new(connection),
267 threads,
268 exit_status,
269 _handler_task: cx.foreground_executor().spawn(handler_fut),
270 _io_task: io_task,
271 })
272 }
273
274 pub async fn initialize(&self) -> Result<acp::InitializeResponse> {
275 self.connection
276 .request(acp::InitializeParams)
277 .await
278 .map_err(to_anyhow)
279 }
280
281 pub async fn authenticate(&self) -> Result<()> {
282 self.connection
283 .request(acp::AuthenticateParams)
284 .await
285 .map_err(to_anyhow)?;
286
287 Ok(())
288 }
289
290 pub async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<AcpThread>> {
291 let response = self
292 .connection
293 .request(acp::CreateThreadParams)
294 .await
295 .map_err(to_anyhow)?;
296
297 let thread_id: ThreadId = response.thread_id.into();
298 let server = self.clone();
299 let thread = cx.new(|_| AcpThread {
300 // todo!
301 title: "ACP Thread".into(),
302 id: thread_id.clone(), // Either<ErrorState, Id>
303 next_entry_id: ThreadEntryId(0),
304 entries: Vec::default(),
305 project: self.project.clone(),
306 server,
307 })?;
308 self.threads.lock().insert(thread_id, thread.downgrade());
309 Ok(thread)
310 }
311
312 pub async fn send_message(
313 &self,
314 thread_id: ThreadId,
315 message: acp::Message,
316 _cx: &mut AsyncApp,
317 ) -> Result<()> {
318 self.connection
319 .request(acp::SendMessageParams {
320 thread_id: thread_id.clone().into(),
321 message,
322 })
323 .await
324 .map_err(to_anyhow)?;
325 Ok(())
326 }
327
328 pub fn exit_status(&self) -> Option<ExitStatus> {
329 self.exit_status.lock().clone()
330 }
331}
332
333#[track_caller]
334fn to_anyhow(e: acp::Error) -> anyhow::Error {
335 log::error!(
336 "failed to send message: {code}: {message}",
337 code = e.code,
338 message = e.message
339 );
340 anyhow::anyhow!(e.message)
341}
342
343impl From<acp::ThreadId> for ThreadId {
344 fn from(thread_id: acp::ThreadId) -> Self {
345 Self(thread_id.0.into())
346 }
347}
348
349impl From<ThreadId> for acp::ThreadId {
350 fn from(thread_id: ThreadId) -> Self {
351 acp::ThreadId(thread_id.0.to_string())
352 }
353}
354
355impl From<acp::ToolCallId> for ToolCallId {
356 fn from(tool_call_id: acp::ToolCallId) -> Self {
357 Self(ThreadEntryId(tool_call_id.0))
358 }
359}
360
361impl From<ToolCallId> for acp::ToolCallId {
362 fn from(tool_call_id: ToolCallId) -> Self {
363 acp::ToolCallId(tool_call_id.as_u64())
364 }
365}