1use std::{io::Write as _, path::Path, sync::Arc};
2
3use crate::{
4 Agent, AgentThreadEntryContent, AgentThreadSummary, Message, MessageChunk, Role, Thread,
5 ThreadEntryId, ThreadId,
6};
7use agentic_coding_protocol as acp;
8use anyhow::{Context as _, Result};
9use async_trait::async_trait;
10use collections::HashMap;
11use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
12use parking_lot::Mutex;
13use project::Project;
14use smol::process::Child;
15use util::ResultExt;
16
/// An [`Agent`] backed by an external process that speaks the Agentic Coding Protocol (ACP).
///
/// Fields are dropped in declaration order, so the two task handles at the bottom are
/// released after the connection and the thread registry.
pub struct AcpAgent {
    /// Connection used to send requests (thread listing, creation, messages) to the agent.
    connection: Arc<acp::AgentConnection>,
    /// Registry of threads created through this agent, keyed by id. It is shared with the
    /// `AcpClientDelegate`, so callbacks from the agent can find the matching [`Thread`].
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>>,
    /// Project that threads created by this agent belong to.
    project: Entity<Project>,
    /// Foreground task running the connection's handler future (spawned in `stdio`).
    /// NOTE(review): presumably held only to keep the task alive — confirm that dropping a
    /// gpui `Task` cancels it.
    _handler_task: Task<()>,
    /// Background task that drives the connection's I/O, then waits for the child process
    /// to exit.
    _io_task: Task<()>,
}
24
/// Serves the requests the agent process sends back to the editor (the `acp::Client` side
/// of the connection): file stats and reads, and streamed message chunks.
struct AcpClientDelegate {
    /// Project used to resolve absolute paths and to open files on the agent's behalf.
    project: Entity<Project>,
    /// Registry of threads, shared with the owning `AcpAgent`. Entries are weak and are
    /// pruned lazily when they can no longer be upgraded.
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>>,
    /// App handle; each request handler clones it to get a mutable context.
    cx: AsyncApp,
    // NOTE(review): commented-out placeholder, presumably for tracking buffer snapshots
    // already sent to the agent — implement or remove.
    // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
}
31
32impl AcpClientDelegate {
33 fn new(
34 project: Entity<Project>,
35 threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>>,
36 cx: AsyncApp,
37 ) -> Self {
38 Self {
39 project,
40 threads,
41 cx: cx,
42 }
43 }
44
45 fn update_thread<R>(
46 &self,
47 thread_id: &ThreadId,
48 cx: &mut App,
49 callback: impl FnMut(&mut Thread, &mut Context<Thread>) -> R,
50 ) -> Option<R> {
51 let thread = self.threads.lock().get(&thread_id)?.clone();
52 let Some(thread) = thread.upgrade() else {
53 self.threads.lock().remove(&thread_id);
54 return None;
55 };
56 Some(thread.update(cx, callback))
57 }
58}
59
60#[async_trait(?Send)]
61impl acp::Client for AcpClientDelegate {
62 async fn stat(&self, params: acp::StatParams) -> Result<acp::StatResponse> {
63 let cx = &mut self.cx.clone();
64 self.project.update(cx, |project, cx| {
65 let path = project
66 .project_path_for_absolute_path(Path::new(¶ms.path), cx)
67 .context("Failed to get project path")?;
68
69 match project.entry_for_path(&path, cx) {
70 // todo! refresh entry?
71 None => Ok(acp::StatResponse {
72 exists: false,
73 is_directory: false,
74 }),
75 Some(entry) => Ok(acp::StatResponse {
76 exists: entry.is_created(),
77 is_directory: entry.is_dir(),
78 }),
79 }
80 })?
81 }
82
83 async fn stream_message_chunk(
84 &self,
85 params: acp::StreamMessageChunkParams,
86 ) -> Result<acp::StreamMessageChunkResponse> {
87 let cx = &mut self.cx.clone();
88
89 cx.update(|cx| {
90 self.update_thread(¶ms.thread_id.into(), cx, |thread, cx| {
91 let acp::MessageChunk::Text { chunk } = ¶ms.chunk;
92 thread.push_assistant_chunk(
93 MessageChunk::Text {
94 chunk: chunk.into(),
95 },
96 cx,
97 )
98 });
99 })?;
100
101 Ok(acp::StreamMessageChunkResponse)
102 }
103
104 async fn read_text_file(
105 &self,
106 request: acp::ReadTextFileParams,
107 ) -> Result<acp::ReadTextFileResponse> {
108 let cx = &mut self.cx.clone();
109 let buffer = self
110 .project
111 .update(cx, |project, cx| {
112 let path = project
113 .project_path_for_absolute_path(Path::new(&request.path), cx)
114 .context("Failed to get project path")?;
115 anyhow::Ok(project.open_buffer(path, cx))
116 })??
117 .await?;
118
119 buffer.update(cx, |buffer, cx| {
120 let start = language::Point::new(request.line_offset.unwrap_or(0), 0);
121 let end = match request.line_limit {
122 None => buffer.max_point(),
123 Some(limit) => start + language::Point::new(limit + 1, 0),
124 };
125
126 let content: String = buffer.text_for_range(start..end).collect();
127 self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
128 thread.push_entry(
129 AgentThreadEntryContent::ReadFile {
130 path: request.path.clone(),
131 content: content.clone(),
132 },
133 cx,
134 );
135 });
136
137 acp::ReadTextFileResponse {
138 content,
139 version: acp::FileVersion(0),
140 }
141 })
142 }
143
144 async fn read_binary_file(
145 &self,
146 request: acp::ReadBinaryFileParams,
147 ) -> Result<acp::ReadBinaryFileResponse> {
148 let cx = &mut self.cx.clone();
149 let file = self
150 .project
151 .update(cx, |project, cx| {
152 let (worktree, path) = project
153 .find_worktree(Path::new(&request.path), cx)
154 .context("Failed to get project path")?;
155
156 let task = worktree.update(cx, |worktree, cx| worktree.load_binary_file(&path, cx));
157 anyhow::Ok(task)
158 })??
159 .await?;
160
161 // todo! test
162 let content = cx
163 .background_spawn(async move {
164 let start = request.byte_offset.unwrap_or(0) as usize;
165 let end = request
166 .byte_limit
167 .map(|limit| (start + limit as usize).min(file.content.len()))
168 .unwrap_or(file.content.len());
169
170 let range_content = &file.content[start..end];
171
172 let mut base64_content = Vec::new();
173 let mut base64_encoder = base64::write::EncoderWriter::new(
174 std::io::Cursor::new(&mut base64_content),
175 &base64::engine::general_purpose::STANDARD,
176 );
177 base64_encoder.write_all(range_content)?;
178 drop(base64_encoder);
179
180 // SAFETY: The base64 encoder should not produce non-UTF8.
181 unsafe { anyhow::Ok(String::from_utf8_unchecked(base64_content)) }
182 })
183 .await?;
184
185 Ok(acp::ReadBinaryFileResponse {
186 content,
187 // todo!
188 version: acp::FileVersion(0),
189 })
190 }
191
192 async fn glob_search(&self, request: acp::GlobSearchParams) -> Result<acp::GlobSearchResponse> {
193 todo!()
194 }
195}
196
197impl AcpAgent {
198 pub fn stdio(mut process: Child, project: Entity<Project>, cx: &mut AsyncApp) -> Arc<Self> {
199 let stdin = process.stdin.take().expect("process didn't have stdin");
200 let stdout = process.stdout.take().expect("process didn't have stdout");
201
202 let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>> = Default::default();
203 let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
204 AcpClientDelegate::new(project.clone(), threads.clone(), cx.clone()),
205 stdin,
206 stdout,
207 );
208
209 let io_task = cx.background_spawn(async move {
210 io_fut.await.log_err();
211 process.status().await.log_err();
212 });
213
214 Arc::new(Self {
215 project,
216 connection: Arc::new(connection),
217 threads,
218 _handler_task: cx.foreground_executor().spawn(handler_fut),
219 _io_task: io_task,
220 })
221 }
222}
223
224#[async_trait(?Send)]
225impl Agent for AcpAgent {
226 async fn threads(&self, cx: &mut AsyncApp) -> Result<Vec<AgentThreadSummary>> {
227 let response = self.connection.request(acp::GetThreadsParams).await?;
228 response
229 .threads
230 .into_iter()
231 .map(|thread| {
232 Ok(AgentThreadSummary {
233 id: thread.id.into(),
234 title: thread.title,
235 created_at: thread.modified_at,
236 })
237 })
238 .collect()
239 }
240
241 async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<Thread>> {
242 let response = self.connection.request(acp::CreateThreadParams).await?;
243 let thread_id: ThreadId = response.thread_id.into();
244 let agent = self.clone();
245 let thread = cx.new(|_| Thread {
246 title: "The agent2 thread".into(),
247 id: thread_id.clone(),
248 next_entry_id: ThreadEntryId(0),
249 entries: Vec::default(),
250 project: self.project.clone(),
251 agent,
252 })?;
253 self.threads.lock().insert(thread_id, thread.downgrade());
254 Ok(thread)
255 }
256
257 async fn open_thread(&self, id: ThreadId, cx: &mut AsyncApp) -> Result<Entity<Thread>> {
258 todo!()
259 }
260
261 async fn thread_entries(
262 &self,
263 thread_id: ThreadId,
264 cx: &mut AsyncApp,
265 ) -> Result<Vec<AgentThreadEntryContent>> {
266 let response = self
267 .connection
268 .request(acp::GetThreadEntriesParams {
269 thread_id: thread_id.clone().into(),
270 })
271 .await?;
272
273 Ok(response
274 .entries
275 .into_iter()
276 .map(|entry| match entry {
277 acp::ThreadEntry::Message { message } => {
278 AgentThreadEntryContent::Message(Message {
279 role: match message.role {
280 acp::Role::User => Role::User,
281 acp::Role::Assistant => Role::Assistant,
282 },
283 chunks: message
284 .chunks
285 .into_iter()
286 .map(|chunk| match chunk {
287 acp::MessageChunk::Text { chunk } => MessageChunk::Text {
288 chunk: chunk.into(),
289 },
290 })
291 .collect(),
292 })
293 }
294 acp::ThreadEntry::ReadFile { path, content } => {
295 AgentThreadEntryContent::ReadFile { path, content }
296 }
297 })
298 .collect())
299 }
300
301 async fn send_thread_message(
302 &self,
303 thread_id: ThreadId,
304 message: crate::Message,
305 cx: &mut AsyncApp,
306 ) -> Result<()> {
307 self.connection
308 .request(acp::SendMessageParams {
309 thread_id: thread_id.clone().into(),
310 message: acp::Message {
311 role: match message.role {
312 Role::User => acp::Role::User,
313 Role::Assistant => acp::Role::Assistant,
314 },
315 chunks: message
316 .chunks
317 .into_iter()
318 .map(|chunk| match chunk {
319 MessageChunk::Text { chunk } => acp::MessageChunk::Text {
320 chunk: chunk.into(),
321 },
322 MessageChunk::File { .. } => todo!(),
323 MessageChunk::Directory { .. } => todo!(),
324 MessageChunk::Symbol { .. } => todo!(),
325 MessageChunk::Fetch { .. } => todo!(),
326 })
327 .collect(),
328 },
329 })
330 .await?;
331 Ok(())
332 }
333}
334
335impl From<acp::ThreadId> for ThreadId {
336 fn from(thread_id: acp::ThreadId) -> Self {
337 Self(thread_id.0.into())
338 }
339}
340
341impl From<ThreadId> for acp::ThreadId {
342 fn from(thread_id: ThreadId) -> Self {
343 acp::ThreadId(thread_id.0.to_string())
344 }
345}