1use crate::{AcpThread, ThreadEntryId, ThreadId, ToolCallId, ToolCallRequest};
2use agentic_coding_protocol as acp;
3use anyhow::{Context as _, Result};
4use async_trait::async_trait;
5use collections::HashMap;
6use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
7use parking_lot::Mutex;
8use project::Project;
9use smol::process::Child;
10use std::{io::Write as _, path::Path, sync::Arc};
11use util::ResultExt;
12
/// Host-side endpoint of an `agentic_coding_protocol` (`acp`) connection to an agent.
///
/// Built by `AcpServer::stdio`, which wires an agent process's stdin/stdout into an
/// `acp::AgentConnection` and stores the tasks that drive it.
pub struct AcpServer {
    /// Connection used to issue requests to the agent (`create_thread`, `send_message`).
    connection: Arc<acp::AgentConnection>,
    /// Threads created through this server, keyed by their id. The same map is shared with
    /// the `AcpClientDelegate` so agent callbacks can be routed to the right thread.
    /// Entries are weak handles, so this registry does not keep threads alive.
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
    /// Project handed to every thread this server creates; the client delegate is built
    /// with a clone of the same entity.
    project: Entity<Project>,
    /// Task running the `handler_fut` returned by `connect_to_agent`. It is held only to
    /// keep the task alive as long as the server (presumably dropping a gpui `Task`
    /// cancels it — confirm).
    _handler_task: Task<()>,
    /// Task that drives the connection's I/O future and then awaits the agent process's
    /// exit status. Held for the same reason as `_handler_task`.
    _io_task: Task<()>,
}
20
/// Handler passed to `acp::AgentConnection::connect_to_agent`.
///
/// It implements `acp::Client`, answering stat/file-read calls against the project and
/// forwarding thread-scoped calls (message chunks, tool calls) to the matching `AcpThread`.
struct AcpClientDelegate {
    /// Project whose worktrees and buffers back the file-access calls.
    project: Entity<Project>,
    /// Thread registry shared with `AcpServer`; looked up (and pruned of dead entries) by
    /// `update_thread`.
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
    /// Async app context; each handler clones it to run entity updates.
    cx: AsyncApp,
    // NOTE(review): disabled field left from an earlier draft. It looks intended for tracking
    // buffer snapshots sent to the agent (the read handlers currently return a placeholder
    // `FileVersion(0)`). Implement or delete — confirm with the author.
    // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
}
27
28impl AcpClientDelegate {
29 fn new(
30 project: Entity<Project>,
31 threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
32 cx: AsyncApp,
33 ) -> Self {
34 Self {
35 project,
36 threads,
37 cx: cx,
38 }
39 }
40
41 fn update_thread<R>(
42 &self,
43 thread_id: &ThreadId,
44 cx: &mut App,
45 callback: impl FnOnce(&mut AcpThread, &mut Context<AcpThread>) -> R,
46 ) -> Option<R> {
47 let thread = self.threads.lock().get(&thread_id)?.clone();
48 let Some(thread) = thread.upgrade() else {
49 self.threads.lock().remove(&thread_id);
50 return None;
51 };
52 Some(thread.update(cx, callback))
53 }
54}
55
56#[async_trait(?Send)]
57impl acp::Client for AcpClientDelegate {
58 async fn stat(&self, params: acp::StatParams) -> Result<acp::StatResponse> {
59 let cx = &mut self.cx.clone();
60 self.project.update(cx, |project, cx| {
61 let path = project
62 .project_path_for_absolute_path(Path::new(¶ms.path), cx)
63 .context("Failed to get project path")?;
64
65 match project.entry_for_path(&path, cx) {
66 // todo! refresh entry?
67 None => Ok(acp::StatResponse {
68 exists: false,
69 is_directory: false,
70 }),
71 Some(entry) => Ok(acp::StatResponse {
72 exists: entry.is_created(),
73 is_directory: entry.is_dir(),
74 }),
75 }
76 })?
77 }
78
79 async fn stream_message_chunk(
80 &self,
81 params: acp::StreamMessageChunkParams,
82 ) -> Result<acp::StreamMessageChunkResponse> {
83 let cx = &mut self.cx.clone();
84
85 cx.update(|cx| {
86 self.update_thread(¶ms.thread_id.into(), cx, |thread, cx| {
87 thread.push_assistant_chunk(params.chunk, cx)
88 });
89 })?;
90
91 Ok(acp::StreamMessageChunkResponse)
92 }
93
94 async fn read_text_file(
95 &self,
96 request: acp::ReadTextFileParams,
97 ) -> Result<acp::ReadTextFileResponse> {
98 let cx = &mut self.cx.clone();
99 let buffer = self
100 .project
101 .update(cx, |project, cx| {
102 let path = project
103 .project_path_for_absolute_path(Path::new(&request.path), cx)
104 .context("Failed to get project path")?;
105 anyhow::Ok(project.open_buffer(path, cx))
106 })??
107 .await?;
108
109 buffer.update(cx, |buffer, _cx| {
110 let start = language::Point::new(request.line_offset.unwrap_or(0), 0);
111 let end = match request.line_limit {
112 None => buffer.max_point(),
113 Some(limit) => start + language::Point::new(limit + 1, 0),
114 };
115
116 let content: String = buffer.text_for_range(start..end).collect();
117
118 acp::ReadTextFileResponse {
119 content,
120 version: acp::FileVersion(0),
121 }
122 })
123 }
124
125 async fn read_binary_file(
126 &self,
127 request: acp::ReadBinaryFileParams,
128 ) -> Result<acp::ReadBinaryFileResponse> {
129 let cx = &mut self.cx.clone();
130 let file = self
131 .project
132 .update(cx, |project, cx| {
133 let (worktree, path) = project
134 .find_worktree(Path::new(&request.path), cx)
135 .context("Failed to get project path")?;
136
137 let task = worktree.update(cx, |worktree, cx| worktree.load_binary_file(&path, cx));
138 anyhow::Ok(task)
139 })??
140 .await?;
141
142 // todo! test
143 let content = cx
144 .background_spawn(async move {
145 let start = request.byte_offset.unwrap_or(0) as usize;
146 let end = request
147 .byte_limit
148 .map(|limit| (start + limit as usize).min(file.content.len()))
149 .unwrap_or(file.content.len());
150
151 let range_content = &file.content[start..end];
152
153 let mut base64_content = Vec::new();
154 let mut base64_encoder = base64::write::EncoderWriter::new(
155 std::io::Cursor::new(&mut base64_content),
156 &base64::engine::general_purpose::STANDARD,
157 );
158 base64_encoder.write_all(range_content)?;
159 drop(base64_encoder);
160
161 // SAFETY: The base64 encoder should not produce non-UTF8.
162 unsafe { anyhow::Ok(String::from_utf8_unchecked(base64_content)) }
163 })
164 .await?;
165
166 Ok(acp::ReadBinaryFileResponse {
167 content,
168 // todo!
169 version: acp::FileVersion(0),
170 })
171 }
172
173 async fn glob_search(
174 &self,
175 _request: acp::GlobSearchParams,
176 ) -> Result<acp::GlobSearchResponse> {
177 todo!()
178 }
179
180 async fn request_tool_call_confirmation(
181 &self,
182 request: acp::RequestToolCallConfirmationParams,
183 ) -> Result<acp::RequestToolCallConfirmationResponse> {
184 let cx = &mut self.cx.clone();
185 let ToolCallRequest { id, outcome } = cx
186 .update(|cx| {
187 self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
188 thread.request_tool_call(request.display_name, request.confirmation, cx)
189 })
190 })?
191 .context("Failed to update thread")?;
192
193 Ok(acp::RequestToolCallConfirmationResponse {
194 id: id.into(),
195 outcome: outcome.await?,
196 })
197 }
198
199 async fn push_tool_call(
200 &self,
201 request: acp::PushToolCallParams,
202 ) -> Result<acp::PushToolCallResponse> {
203 let cx = &mut self.cx.clone();
204 let entry_id = cx
205 .update(|cx| {
206 self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
207 thread.push_tool_call(request.display_name, cx)
208 })
209 })?
210 .context("Failed to update thread")?;
211
212 Ok(acp::PushToolCallResponse {
213 id: entry_id.into(),
214 })
215 }
216
217 async fn update_tool_call(
218 &self,
219 request: acp::UpdateToolCallParams,
220 ) -> Result<acp::UpdateToolCallResponse> {
221 let cx = &mut self.cx.clone();
222
223 cx.update(|cx| {
224 self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
225 thread.update_tool_call(
226 request.tool_call_id.into(),
227 request.status,
228 request.content,
229 cx,
230 )
231 })
232 })?
233 .context("Failed to update thread")??;
234
235 Ok(acp::UpdateToolCallResponse)
236 }
237}
238
239impl AcpServer {
240 pub fn stdio(mut process: Child, project: Entity<Project>, cx: &mut AsyncApp) -> Arc<Self> {
241 let stdin = process.stdin.take().expect("process didn't have stdin");
242 let stdout = process.stdout.take().expect("process didn't have stdout");
243
244 let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>> = Default::default();
245 let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
246 AcpClientDelegate::new(project.clone(), threads.clone(), cx.clone()),
247 stdin,
248 stdout,
249 );
250
251 let io_task = cx.background_spawn(async move {
252 io_fut.await.log_err();
253 process.status().await.log_err();
254 });
255
256 Arc::new(Self {
257 project,
258 connection: Arc::new(connection),
259 threads,
260 _handler_task: cx.foreground_executor().spawn(handler_fut),
261 _io_task: io_task,
262 })
263 }
264}
265
266impl AcpServer {
267 pub async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<AcpThread>> {
268 let response = self.connection.request(acp::CreateThreadParams).await?;
269 let thread_id: ThreadId = response.thread_id.into();
270 let server = self.clone();
271 let thread = cx.new(|_| AcpThread {
272 // todo!
273 title: "ACP Thread".into(),
274 id: thread_id.clone(),
275 next_entry_id: ThreadEntryId(0),
276 entries: Vec::default(),
277 project: self.project.clone(),
278 server,
279 })?;
280 self.threads.lock().insert(thread_id, thread.downgrade());
281 Ok(thread)
282 }
283
284 pub async fn send_message(
285 &self,
286 thread_id: ThreadId,
287 message: acp::Message,
288 _cx: &mut AsyncApp,
289 ) -> Result<()> {
290 self.connection
291 .request(acp::SendMessageParams {
292 thread_id: thread_id.clone().into(),
293 message,
294 })
295 .await?;
296 Ok(())
297 }
298}
299
300impl From<acp::ThreadId> for ThreadId {
301 fn from(thread_id: acp::ThreadId) -> Self {
302 Self(thread_id.0.into())
303 }
304}
305
306impl From<ThreadId> for acp::ThreadId {
307 fn from(thread_id: ThreadId) -> Self {
308 acp::ThreadId(thread_id.0.to_string())
309 }
310}
311
312impl From<acp::ToolCallId> for ToolCallId {
313 fn from(tool_call_id: acp::ToolCallId) -> Self {
314 Self(ThreadEntryId(tool_call_id.0))
315 }
316}
317
318impl From<ToolCallId> for acp::ToolCallId {
319 fn from(tool_call_id: ToolCallId) -> Self {
320 acp::ToolCallId(tool_call_id.as_u64())
321 }
322}