1use std::{io::Write as _, path::Path, sync::Arc};
2
3use crate::{
4 Agent, AgentThreadEntryContent, AgentThreadSummary, Message, MessageChunk, Role, Thread,
5 ThreadEntryId, ThreadId,
6};
7use agentic_coding_protocol as acp;
8use anyhow::{Context as _, Result, anyhow};
9use async_trait::async_trait;
10use collections::HashMap;
11use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
12use parking_lot::Mutex;
13use project::Project;
14use smol::process::Child;
15use util::ResultExt;
16
/// An [`Agent`] implemented by an external process speaking the agentic coding
/// protocol (`acp`), connected over the process's stdin/stdout by [`AcpAgent::stdio`].
pub struct AcpAgent {
    /// Connection used to send requests (thread listing, creation, messages) to the agent.
    connection: Arc<acp::AgentConnection>,
    /// Registry of threads created through this agent, keyed by id. The same map is
    /// shared with the `AcpClientDelegate` so agent-initiated callbacks can find their thread.
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>>,
    /// Project the agent operates on; handed to every thread it creates.
    project: Entity<Project>,
    // The two tasks below are never read; they are held so the tasks spawned in `stdio`
    // live as long as the agent. NOTE(review): assumes dropping a gpui `Task` cancels it.
    _handler_task: Task<()>,
    _io_task: Task<()>,
}
24
/// Handles requests the agent process makes back into the editor (see the
/// `acp::Client` impl for this type).
struct AcpClientDelegate {
    /// Project used to resolve the absolute paths the agent sends.
    project: Entity<Project>,
    /// Thread registry shared with the owning `AcpAgent`.
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>>,
    /// App handle; each request clones it to get a mutable async context.
    cx: AsyncApp,
    // TODO(review): a `sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>`
    // field was sketched here but is not implemented; `FileVersion` responses are currently 0.
}
31
32impl AcpClientDelegate {
33 fn new(
34 project: Entity<Project>,
35 threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>>,
36 cx: AsyncApp,
37 ) -> Self {
38 Self {
39 project,
40 threads,
41 cx: cx,
42 }
43 }
44
45 fn update_thread<R>(
46 &self,
47 thread_id: &ThreadId,
48 cx: &mut App,
49 callback: impl FnMut(&mut Thread, &mut Context<Thread>) -> R,
50 ) -> Option<R> {
51 let thread = self.threads.lock().get(&thread_id)?.clone();
52 let Some(thread) = thread.upgrade() else {
53 self.threads.lock().remove(&thread_id);
54 return None;
55 };
56 Some(thread.update(cx, callback))
57 }
58}
59
60#[async_trait(?Send)]
61impl acp::Client for AcpClientDelegate {
62 async fn stat(&self, params: acp::StatParams) -> Result<acp::StatResponse> {
63 let cx = &mut self.cx.clone();
64 self.project.update(cx, |project, cx| {
65 let path = project
66 .project_path_for_absolute_path(Path::new(¶ms.path), cx)
67 .context("Failed to get project path")?;
68
69 match project.entry_for_path(&path, cx) {
70 // todo! refresh entry?
71 None => Ok(acp::StatResponse {
72 exists: false,
73 is_directory: false,
74 }),
75 Some(entry) => Ok(acp::StatResponse {
76 exists: entry.is_created(),
77 is_directory: entry.is_dir(),
78 }),
79 }
80 })?
81 }
82
83 async fn stream_message_chunk(
84 &self,
85 chunk: acp::StreamMessageChunkParams,
86 ) -> Result<acp::StreamMessageChunkResponse> {
87 Ok(acp::StreamMessageChunkResponse)
88 }
89
90 async fn read_text_file(
91 &self,
92 request: acp::ReadTextFileParams,
93 ) -> Result<acp::ReadTextFileResponse> {
94 let cx = &mut self.cx.clone();
95 let buffer = self
96 .project
97 .update(cx, |project, cx| {
98 let path = project
99 .project_path_for_absolute_path(Path::new(&request.path), cx)
100 .context("Failed to get project path")?;
101 anyhow::Ok(project.open_buffer(path, cx))
102 })??
103 .await?;
104
105 buffer.update(cx, |buffer, cx| {
106 let start = language::Point::new(request.line_offset.unwrap_or(0), 0);
107 let end = match request.line_limit {
108 None => buffer.max_point(),
109 Some(limit) => start + language::Point::new(limit + 1, 0),
110 };
111
112 let content: String = buffer.text_for_range(start..end).collect();
113 self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
114 thread.push_entry(
115 AgentThreadEntryContent::ReadFile {
116 path: request.path.clone(),
117 content: content.clone(),
118 },
119 cx,
120 );
121 });
122
123 acp::ReadTextFileResponse {
124 content,
125 version: acp::FileVersion(0),
126 }
127 })
128 }
129
130 async fn read_binary_file(
131 &self,
132 request: acp::ReadBinaryFileParams,
133 ) -> Result<acp::ReadBinaryFileResponse> {
134 let cx = &mut self.cx.clone();
135 let file = self
136 .project
137 .update(cx, |project, cx| {
138 let (worktree, path) = project
139 .find_worktree(Path::new(&request.path), cx)
140 .context("Failed to get project path")?;
141
142 let task = worktree.update(cx, |worktree, cx| worktree.load_binary_file(&path, cx));
143 anyhow::Ok(task)
144 })??
145 .await?;
146
147 // todo! test
148 let content = cx
149 .background_spawn(async move {
150 let start = request.byte_offset.unwrap_or(0) as usize;
151 let end = request
152 .byte_limit
153 .map(|limit| (start + limit as usize).min(file.content.len()))
154 .unwrap_or(file.content.len());
155
156 let range_content = &file.content[start..end];
157
158 let mut base64_content = Vec::new();
159 let mut base64_encoder = base64::write::EncoderWriter::new(
160 std::io::Cursor::new(&mut base64_content),
161 &base64::engine::general_purpose::STANDARD,
162 );
163 base64_encoder.write_all(range_content)?;
164 drop(base64_encoder);
165
166 // SAFETY: The base64 encoder should not produce non-UTF8.
167 unsafe { anyhow::Ok(String::from_utf8_unchecked(base64_content)) }
168 })
169 .await?;
170
171 Ok(acp::ReadBinaryFileResponse {
172 content,
173 // todo!
174 version: acp::FileVersion(0),
175 })
176 }
177
178 async fn glob_search(&self, request: acp::GlobSearchParams) -> Result<acp::GlobSearchResponse> {
179 todo!()
180 }
181
182 async fn end_turn(&self, request: acp::EndTurnParams) -> Result<acp::EndTurnResponse> {
183 todo!()
184 }
185}
186
187impl AcpAgent {
188 pub fn stdio(mut process: Child, project: Entity<Project>, cx: AsyncApp) -> Self {
189 let stdin = process.stdin.take().expect("process didn't have stdin");
190 let stdout = process.stdout.take().expect("process didn't have stdout");
191
192 let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>> = Default::default();
193 let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
194 AcpClientDelegate::new(project.clone(), threads.clone(), cx.clone()),
195 stdin,
196 stdout,
197 );
198
199 let io_task = cx.background_spawn(async move {
200 io_fut.await.log_err();
201 process.status().await.log_err();
202 });
203
204 Self {
205 project,
206 connection: Arc::new(connection),
207 threads,
208 _handler_task: cx.foreground_executor().spawn(handler_fut),
209 _io_task: io_task,
210 }
211 }
212}
213
214#[async_trait(?Send)]
215impl Agent for AcpAgent {
216 async fn threads(&self, cx: &mut AsyncApp) -> Result<Vec<AgentThreadSummary>> {
217 let response = self.connection.request(acp::GetThreadsParams).await?;
218 response
219 .threads
220 .into_iter()
221 .map(|thread| {
222 Ok(AgentThreadSummary {
223 id: thread.id.into(),
224 title: thread.title,
225 created_at: thread.modified_at,
226 })
227 })
228 .collect()
229 }
230
231 async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<Thread>> {
232 let response = self.connection.request(acp::CreateThreadParams).await?;
233 let thread_id: ThreadId = response.thread_id.into();
234 let agent = self.clone();
235 let thread = cx.new(|_| Thread {
236 id: thread_id.clone(),
237 next_entry_id: ThreadEntryId(0),
238 entries: Vec::default(),
239 project: self.project.clone(),
240 agent,
241 })?;
242 self.threads.lock().insert(thread_id, thread.downgrade());
243 Ok(thread)
244 }
245
246 async fn open_thread(&self, id: ThreadId, cx: &mut AsyncApp) -> Result<Entity<Thread>> {
247 todo!()
248 }
249
250 async fn thread_entries(
251 &self,
252 thread_id: ThreadId,
253 cx: &mut AsyncApp,
254 ) -> Result<Vec<AgentThreadEntryContent>> {
255 let response = self
256 .connection
257 .request(acp::GetThreadEntriesParams {
258 thread_id: thread_id.clone().into(),
259 })
260 .await?;
261
262 Ok(response
263 .entries
264 .into_iter()
265 .map(|entry| match entry {
266 acp::ThreadEntry::Message { message } => {
267 AgentThreadEntryContent::Message(Message {
268 role: match message.role {
269 acp::Role::User => Role::User,
270 acp::Role::Assistant => Role::Assistant,
271 },
272 chunks: message
273 .chunks
274 .into_iter()
275 .map(|chunk| match chunk {
276 acp::MessageChunk::Text { chunk } => MessageChunk::Text { chunk },
277 })
278 .collect(),
279 })
280 }
281 acp::ThreadEntry::ReadFile { path, content } => {
282 AgentThreadEntryContent::ReadFile { path, content }
283 }
284 })
285 .collect())
286 }
287
288 async fn send_thread_message(
289 &self,
290 thread_id: ThreadId,
291 message: crate::Message,
292 cx: &mut AsyncApp,
293 ) -> Result<()> {
294 let thread = self
295 .threads
296 .lock()
297 .get(&thread_id)
298 .cloned()
299 .ok_or_else(|| anyhow!("no such thread"))?;
300 self.connection
301 .request(acp::SendMessageParams {
302 thread_id: thread_id.clone().into(),
303 message: acp::Message {
304 role: match message.role {
305 Role::User => acp::Role::User,
306 Role::Assistant => acp::Role::Assistant,
307 },
308 chunks: message
309 .chunks
310 .into_iter()
311 .map(|chunk| match chunk {
312 MessageChunk::Text { chunk } => acp::MessageChunk::Text { chunk },
313 MessageChunk::File { .. } => todo!(),
314 MessageChunk::Directory { .. } => todo!(),
315 MessageChunk::Symbol { .. } => todo!(),
316 MessageChunk::Thread { .. } => todo!(),
317 MessageChunk::Fetch { .. } => todo!(),
318 })
319 .collect(),
320 },
321 })
322 .await?;
323 Ok(())
324 }
325}
326
327impl From<acp::ThreadId> for ThreadId {
328 fn from(thread_id: acp::ThreadId) -> Self {
329 Self(thread_id.0.into())
330 }
331}
332
333impl From<ThreadId> for acp::ThreadId {
334 fn from(thread_id: ThreadId) -> Self {
335 acp::ThreadId(thread_id.0.to_string())
336 }
337}