1use std::{
2 io::{Cursor, Write as _},
3 path::Path,
4 sync::{Arc, Weak},
5};
6
7use crate::{
8 Agent, AgentThread, AgentThreadEntryContent, AgentThreadSummary, Message, MessageChunk,
9 ResponseEvent, Role, Thread, ThreadEntry, ThreadId,
10};
11use agentic_coding_protocol::{self as acp, TurnId};
12use anyhow::{Context as _, Result};
13use async_trait::async_trait;
14use collections::HashMap;
15use futures::channel::mpsc::UnboundedReceiver;
16use gpui::{AppContext, AsyncApp, Entity, Task, WeakEntity};
17use parking_lot::Mutex;
18use project::Project;
19use smol::process::Child;
20use util::ResultExt;
21
/// An [`Agent`] backed by an external process speaking the Agentic Coding Protocol (ACP).
///
/// Fields are dropped in declaration order, so the connection and thread map are released
/// before the two background tasks.
pub struct AcpAgent {
    /// Connection used to send ACP requests (e.g. `GetThreadsParams`, `CreateThreadParams`)
    /// to the agent process.
    connection: Arc<acp::AgentConnection>,
    /// Weak handles to threads, keyed by their ACP thread id. The `Arc` presumably lets this
    /// map be shared with `AcpClientDelegate`, which has a field of the same type — confirm.
    threads: Arc<Mutex<HashMap<acp::ThreadId, WeakEntity<Thread>>>>,
    /// Task driving the connection's request handler. It is only held here, never read
    /// (hence the leading underscore).
    _handler_task: Task<()>,
    /// Task driving the connection's I/O and then awaiting the child process's exit. It is
    /// only held here, never read.
    _io_task: Task<()>,
}
28
/// Handles requests that the agent process sends back to us (file stats and reads, etc.).
struct AcpClientDelegate {
    /// Project used to resolve absolute paths and to open buffers and files.
    project: Entity<Project>,
    /// Weak handles to threads, keyed by ACP thread id. Used to record file reads on the
    /// thread that requested them.
    threads: Arc<Mutex<HashMap<acp::ThreadId, WeakEntity<Thread>>>>,
    /// Async app context; each request handler clones it before use.
    cx: AsyncApp,
    // TODO: track the buffer versions sent to the agent so responses can report a real
    // `acp::FileVersion`; all read responses currently report `FileVersion(0)`.
}
35
36#[async_trait(?Send)]
37impl acp::Client for AcpClientDelegate {
38 async fn stat(&self, params: acp::StatParams) -> Result<acp::StatResponse> {
39 let cx = &mut self.cx.clone();
40 self.project.update(cx, |project, cx| {
41 let path = project
42 .project_path_for_absolute_path(Path::new(¶ms.path), cx)
43 .context("Failed to get project path")?;
44
45 match project.entry_for_path(&path, cx) {
46 // todo! refresh entry?
47 None => Ok(acp::StatResponse {
48 exists: false,
49 is_directory: false,
50 }),
51 Some(entry) => Ok(acp::StatResponse {
52 exists: entry.is_created(),
53 is_directory: entry.is_dir(),
54 }),
55 }
56 })?
57 }
58
59 async fn stream_message_chunk(
60 &self,
61 request: acp::StreamMessageChunkParams,
62 ) -> Result<acp::StreamMessageChunkResponse> {
63 Ok(acp::StreamMessageChunkResponse)
64 }
65
66 async fn read_text_file(
67 &self,
68 request: acp::ReadTextFileParams,
69 ) -> Result<acp::ReadTextFileResponse> {
70 let cx = &mut self.cx.clone();
71 let buffer = self
72 .project
73 .update(cx, |project, cx| {
74 let path = project
75 .project_path_for_absolute_path(Path::new(&request.path), cx)
76 .context("Failed to get project path")?;
77 anyhow::Ok(project.open_buffer(path, cx))
78 })??
79 .await?;
80
81 buffer.update(cx, |buffer, _| {
82 let start = language::Point::new(request.line_offset.unwrap_or(0), 0);
83 let end = match request.line_limit {
84 None => buffer.max_point(),
85 Some(limit) => start + language::Point::new(limit + 1, 0),
86 };
87
88 let content = buffer.text_for_range(start..end).collect();
89
90 if let Some(thread) = self.threads.lock().get(&request.thread_id) {
91 thread.update(cx, |thread, cx| {
92 thread.push_entry(ThreadEntry {
93 content: AgentThreadEntryContent::ReadFile {
94 path: request.path.clone(),
95 content: content.clone(),
96 },
97 });
98 })
99 }
100
101 acp::ReadTextFileResponse {
102 content,
103 version: acp::FileVersion(0),
104 }
105 })
106 }
107
108 async fn read_binary_file(
109 &self,
110 request: acp::ReadBinaryFileParams,
111 ) -> Result<acp::ReadBinaryFileResponse> {
112 let cx = &mut self.cx.clone();
113 let file = self
114 .project
115 .update(cx, |project, cx| {
116 let (worktree, path) = project
117 .find_worktree(Path::new(&request.path), cx)
118 .context("Failed to get project path")?;
119
120 let task = worktree.update(cx, |worktree, cx| worktree.load_binary_file(&path, cx));
121 anyhow::Ok(task)
122 })??
123 .await?;
124
125 // todo! test
126 let content = cx
127 .background_spawn(async move {
128 let start = request.byte_offset.unwrap_or(0) as usize;
129 let end = request
130 .byte_limit
131 .map(|limit| (start + limit as usize).min(file.content.len()))
132 .unwrap_or(file.content.len());
133
134 let range_content = &file.content[start..end];
135
136 let mut base64_content = Vec::new();
137 let mut base64_encoder = base64::write::EncoderWriter::new(
138 Cursor::new(&mut base64_content),
139 &base64::engine::general_purpose::STANDARD,
140 );
141 base64_encoder.write_all(range_content)?;
142 drop(base64_encoder);
143
144 // SAFETY: The base64 encoder should not produce non-UTF8.
145 unsafe { anyhow::Ok(String::from_utf8_unchecked(base64_content)) }
146 })
147 .await?;
148
149 Ok(acp::ReadBinaryFileResponse {
150 content,
151 // todo!
152 version: acp::FileVersion(0),
153 })
154 }
155
156 async fn glob_search(&self, request: acp::GlobSearchParams) -> Result<acp::GlobSearchResponse> {
157 todo!()
158 }
159
160 async fn end_turn(&self, request: acp::EndTurnParams) -> Result<acp::EndTurnResponse> {
161 todo!()
162 }
163}
164
165impl AcpAgent {
166 pub fn stdio(mut process: Child, project: Entity<Project>, cx: AsyncApp) -> Self {
167 let stdin = process.stdin.take().expect("process didn't have stdin");
168 let stdout = process.stdout.take().expect("process didn't have stdout");
169
170 let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
171 AcpClientDelegate {
172 project,
173 cx: cx.clone(),
174 },
175 stdin,
176 stdout,
177 );
178
179 let io_task = cx.background_spawn(async move {
180 io_fut.await.log_err();
181 process.status().await.log_err();
182 });
183
184 Self {
185 connection: Arc::new(connection),
186 threads: Mutex::default(),
187 _handler_task: cx.foreground_executor().spawn(handler_fut),
188 _io_task: io_task,
189 }
190 }
191}
192
193#[async_trait]
194impl Agent for AcpAgent {
195 async fn threads(&self) -> Result<Vec<AgentThreadSummary>> {
196 let response = self.connection.request(acp::GetThreadsParams).await?;
197 response
198 .threads
199 .into_iter()
200 .map(|thread| {
201 Ok(AgentThreadSummary {
202 id: thread.id.into(),
203 title: thread.title,
204 created_at: thread.modified_at,
205 })
206 })
207 .collect()
208 }
209
210 async fn create_thread(&self) -> Result<Arc<Self::Thread>> {
211 let response = self.connection.request(acp::CreateThreadParams).await?;
212 let thread = Arc::new(AcpAgentThread {
213 id: response.thread_id.clone(),
214 connection: self.connection.clone(),
215 state: Mutex::new(AcpAgentThreadState {
216 turn: None,
217 next_turn_id: TurnId::default(),
218 }),
219 });
220 self.threads
221 .lock()
222 .insert(response.thread_id, Arc::downgrade(&thread));
223 Ok(thread)
224 }
225
226 async fn open_thread(&self, id: ThreadId) -> Result<Thread> {
227 todo!()
228 }
229
230 async fn thread_entries(&self, thread_id: ThreadId) -> Result<Vec<AgentThreadEntryContent>> {
231 let response = self
232 .connection
233 .request(acp::GetThreadEntriesParams {
234 thread_id: self.id.clone(),
235 })
236 .await?;
237
238 Ok(response
239 .entries
240 .into_iter()
241 .map(|entry| match entry {
242 acp::ThreadEntry::Message { message } => {
243 AgentThreadEntryContent::Message(Message {
244 role: match message.role {
245 acp::Role::User => Role::User,
246 acp::Role::Assistant => Role::Assistant,
247 },
248 chunks: message
249 .chunks
250 .into_iter()
251 .map(|chunk| match chunk {
252 acp::MessageChunk::Text { chunk } => MessageChunk::Text { chunk },
253 })
254 .collect(),
255 })
256 }
257 acp::ThreadEntry::ReadFile { path, content } => {
258 AgentThreadEntryContent::ReadFile { path, content }
259 }
260 })
261 .collect())
262 }
263
264 async fn send_thread_message(
265 &self,
266 thread_id: ThreadId,
267 message: crate::Message,
268 ) -> Result<UnboundedReceiver<Result<ResponseEvent>>> {
269 let turn_id = {
270 let mut state = self.state.lock();
271 let turn_id = state.next_turn_id.post_inc();
272 state.turn = Some(AcpAgentThreadTurn { id: turn_id });
273 turn_id
274 };
275 let response = self
276 .connection
277 .request(acp::SendMessageParams {
278 thread_id: self.id.clone(),
279 turn_id,
280 message: acp::Message {
281 role: match message.role {
282 Role::User => acp::Role::User,
283 Role::Assistant => acp::Role::Assistant,
284 },
285 chunks: message
286 .chunks
287 .into_iter()
288 .map(|chunk| match chunk {
289 MessageChunk::Text { chunk } => acp::MessageChunk::Text { chunk },
290 MessageChunk::File { .. } => todo!(),
291 MessageChunk::Directory { .. } => todo!(),
292 MessageChunk::Symbol { .. } => todo!(),
293 MessageChunk::Thread { .. } => todo!(),
294 MessageChunk::Fetch { .. } => todo!(),
295 })
296 .collect(),
297 },
298 })
299 .await?;
300 todo!()
301 }
302}
303
/// Client-side handle to a thread created through an ACP agent connection.
pub struct AcpAgentThread {
    /// The agent-assigned id of this thread.
    id: acp::ThreadId,
    /// Connection shared with the owning `AcpAgent`.
    connection: Arc<acp::AgentConnection>,
    /// Turn bookkeeping, guarded by a `parking_lot` mutex.
    state: Mutex<AcpAgentThreadState>,
}
309
/// Mutable per-thread turn state.
struct AcpAgentThreadState {
    /// Id to hand out to the next turn. Advanced with `post_inc()` whenever a message is sent.
    next_turn_id: acp::TurnId,
    /// The most recently started turn, if any. Starts as `None` and is set when a message
    /// is sent.
    turn: Option<AcpAgentThreadTurn>,
}
314
/// A single request/response turn within a thread.
struct AcpAgentThreadTurn {
    /// The protocol turn id sent with the message that started this turn.
    id: acp::TurnId,
}
318
319impl From<acp::ThreadId> for ThreadId {
320 fn from(thread_id: acp::ThreadId) -> Self {
321 Self(thread_id.0)
322 }
323}
324
325impl From<ThreadId> for acp::ThreadId {
326 fn from(thread_id: ThreadId) -> Self {
327 acp::ThreadId(thread_id.0)
328 }
329}