acp.rs

  1use std::{io::Write as _, path::Path, sync::Arc};
  2
  3use crate::{
  4    Agent, AgentThreadEntryContent, AgentThreadSummary, Message, MessageChunk, ResponseEvent, Role,
  5    Thread, ThreadEntryId, ThreadId,
  6};
  7use agentic_coding_protocol as acp;
  8use anyhow::{Context as _, Result, anyhow};
  9use async_trait::async_trait;
 10use collections::HashMap;
 11use futures::channel::mpsc::UnboundedReceiver;
 12use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
 13use parking_lot::Mutex;
 14use project::Project;
 15use smol::process::Child;
 16use util::ResultExt;
 17
 18pub struct AcpAgent {
 19    connection: Arc<acp::AgentConnection>,
 20    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>>,
 21    project: Entity<Project>,
 22    _handler_task: Task<()>,
 23    _io_task: Task<()>,
 24}
 25
 26struct AcpClientDelegate {
 27    project: Entity<Project>,
 28    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>>,
 29    cx: AsyncApp,
 30    // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
 31}
 32
 33impl AcpClientDelegate {
 34    fn new(project: Entity<Project>, cx: AsyncApp) -> Self {
 35        Self {
 36            project,
 37            threads: Default::default(),
 38            cx: cx,
 39        }
 40    }
 41
 42    fn update_thread<R>(
 43        &self,
 44        thread_id: &ThreadId,
 45        cx: &mut App,
 46        callback: impl FnMut(&mut Thread, &mut Context<Thread>) -> R,
 47    ) -> Option<R> {
 48        let thread = self.threads.lock().get(&thread_id)?.clone();
 49        let Some(thread) = thread.upgrade() else {
 50            self.threads.lock().remove(&thread_id);
 51            return None;
 52        };
 53        Some(thread.update(cx, callback))
 54    }
 55}
 56
 57#[async_trait(?Send)]
 58impl acp::Client for AcpClientDelegate {
 59    async fn stat(&self, params: acp::StatParams) -> Result<acp::StatResponse> {
 60        let cx = &mut self.cx.clone();
 61        self.project.update(cx, |project, cx| {
 62            let path = project
 63                .project_path_for_absolute_path(Path::new(&params.path), cx)
 64                .context("Failed to get project path")?;
 65
 66            match project.entry_for_path(&path, cx) {
 67                // todo! refresh entry?
 68                None => Ok(acp::StatResponse {
 69                    exists: false,
 70                    is_directory: false,
 71                }),
 72                Some(entry) => Ok(acp::StatResponse {
 73                    exists: entry.is_created(),
 74                    is_directory: entry.is_dir(),
 75                }),
 76            }
 77        })?
 78    }
 79
 80    async fn stream_message_chunk(
 81        &self,
 82        chunk: acp::StreamMessageChunkParams,
 83    ) -> Result<acp::StreamMessageChunkResponse> {
 84        Ok(acp::StreamMessageChunkResponse)
 85    }
 86
 87    async fn read_text_file(
 88        &self,
 89        request: acp::ReadTextFileParams,
 90    ) -> Result<acp::ReadTextFileResponse> {
 91        let cx = &mut self.cx.clone();
 92        let buffer = self
 93            .project
 94            .update(cx, |project, cx| {
 95                let path = project
 96                    .project_path_for_absolute_path(Path::new(&request.path), cx)
 97                    .context("Failed to get project path")?;
 98                anyhow::Ok(project.open_buffer(path, cx))
 99            })??
100            .await?;
101
102        buffer.update(cx, |buffer, cx| {
103            let start = language::Point::new(request.line_offset.unwrap_or(0), 0);
104            let end = match request.line_limit {
105                None => buffer.max_point(),
106                Some(limit) => start + language::Point::new(limit + 1, 0),
107            };
108
109            let content: String = buffer.text_for_range(start..end).collect();
110            self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
111                thread.push_entry(
112                    AgentThreadEntryContent::ReadFile {
113                        path: request.path.clone(),
114                        content: content.clone(),
115                    },
116                    cx,
117                );
118            });
119
120            acp::ReadTextFileResponse {
121                content,
122                version: acp::FileVersion(0),
123            }
124        })
125    }
126
127    async fn read_binary_file(
128        &self,
129        request: acp::ReadBinaryFileParams,
130    ) -> Result<acp::ReadBinaryFileResponse> {
131        let cx = &mut self.cx.clone();
132        let file = self
133            .project
134            .update(cx, |project, cx| {
135                let (worktree, path) = project
136                    .find_worktree(Path::new(&request.path), cx)
137                    .context("Failed to get project path")?;
138
139                let task = worktree.update(cx, |worktree, cx| worktree.load_binary_file(&path, cx));
140                anyhow::Ok(task)
141            })??
142            .await?;
143
144        // todo! test
145        let content = cx
146            .background_spawn(async move {
147                let start = request.byte_offset.unwrap_or(0) as usize;
148                let end = request
149                    .byte_limit
150                    .map(|limit| (start + limit as usize).min(file.content.len()))
151                    .unwrap_or(file.content.len());
152
153                let range_content = &file.content[start..end];
154
155                let mut base64_content = Vec::new();
156                let mut base64_encoder = base64::write::EncoderWriter::new(
157                    std::io::Cursor::new(&mut base64_content),
158                    &base64::engine::general_purpose::STANDARD,
159                );
160                base64_encoder.write_all(range_content)?;
161                drop(base64_encoder);
162
163                // SAFETY: The base64 encoder should not produce non-UTF8.
164                unsafe { anyhow::Ok(String::from_utf8_unchecked(base64_content)) }
165            })
166            .await?;
167
168        Ok(acp::ReadBinaryFileResponse {
169            content,
170            // todo!
171            version: acp::FileVersion(0),
172        })
173    }
174
175    async fn glob_search(&self, request: acp::GlobSearchParams) -> Result<acp::GlobSearchResponse> {
176        todo!()
177    }
178
179    async fn end_turn(&self, request: acp::EndTurnParams) -> Result<acp::EndTurnResponse> {
180        todo!()
181    }
182}
183
184impl AcpAgent {
185    pub fn stdio(mut process: Child, project: Entity<Project>, cx: AsyncApp) -> Self {
186        let stdin = process.stdin.take().expect("process didn't have stdin");
187        let stdout = process.stdout.take().expect("process didn't have stdout");
188
189        let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
190            AcpClientDelegate::new(project.clone(), cx.clone()),
191            stdin,
192            stdout,
193        );
194
195        let io_task = cx.background_spawn(async move {
196            io_fut.await.log_err();
197            process.status().await.log_err();
198        });
199
200        Self {
201            project,
202            connection: Arc::new(connection),
203            threads: Default::default(),
204            _handler_task: cx.foreground_executor().spawn(handler_fut),
205            _io_task: io_task,
206        }
207    }
208}
209
210#[async_trait(?Send)]
211impl Agent for AcpAgent {
212    async fn threads(&self, cx: &mut AsyncApp) -> Result<Vec<AgentThreadSummary>> {
213        let response = self.connection.request(acp::GetThreadsParams).await?;
214        response
215            .threads
216            .into_iter()
217            .map(|thread| {
218                Ok(AgentThreadSummary {
219                    id: thread.id.into(),
220                    title: thread.title,
221                    created_at: thread.modified_at,
222                })
223            })
224            .collect()
225    }
226
227    async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<Thread>> {
228        let response = self.connection.request(acp::CreateThreadParams).await?;
229        let thread_id: ThreadId = response.thread_id.into();
230        let agent = self.clone();
231        let thread = cx.new(|_| Thread {
232            id: thread_id.clone(),
233            next_entry_id: ThreadEntryId(0),
234            entries: Vec::default(),
235            project: self.project.clone(),
236            agent,
237        })?;
238        self.threads.lock().insert(thread_id, thread.downgrade());
239        Ok(thread)
240    }
241
242    async fn open_thread(&self, id: ThreadId, cx: &mut AsyncApp) -> Result<Entity<Thread>> {
243        todo!()
244    }
245
246    async fn thread_entries(
247        &self,
248        thread_id: ThreadId,
249        cx: &mut AsyncApp,
250    ) -> Result<Vec<AgentThreadEntryContent>> {
251        let response = self
252            .connection
253            .request(acp::GetThreadEntriesParams {
254                thread_id: thread_id.clone().into(),
255            })
256            .await?;
257
258        Ok(response
259            .entries
260            .into_iter()
261            .map(|entry| match entry {
262                acp::ThreadEntry::Message { message } => {
263                    AgentThreadEntryContent::Message(Message {
264                        role: match message.role {
265                            acp::Role::User => Role::User,
266                            acp::Role::Assistant => Role::Assistant,
267                        },
268                        chunks: message
269                            .chunks
270                            .into_iter()
271                            .map(|chunk| match chunk {
272                                acp::MessageChunk::Text { chunk } => MessageChunk::Text { chunk },
273                            })
274                            .collect(),
275                    })
276                }
277                acp::ThreadEntry::ReadFile { path, content } => {
278                    AgentThreadEntryContent::ReadFile { path, content }
279                }
280            })
281            .collect())
282    }
283
284    async fn send_thread_message(
285        &self,
286        thread_id: ThreadId,
287        message: crate::Message,
288        cx: &mut AsyncApp,
289    ) -> Result<UnboundedReceiver<Result<ResponseEvent>>> {
290        let thread = self
291            .threads
292            .lock()
293            .get(&thread_id)
294            .cloned()
295            .ok_or_else(|| anyhow!("no such thread"))?;
296        let response = self
297            .connection
298            .request(acp::SendMessageParams {
299                thread_id: thread_id.clone().into(),
300                message: acp::Message {
301                    role: match message.role {
302                        Role::User => acp::Role::User,
303                        Role::Assistant => acp::Role::Assistant,
304                    },
305                    chunks: message
306                        .chunks
307                        .into_iter()
308                        .map(|chunk| match chunk {
309                            MessageChunk::Text { chunk } => acp::MessageChunk::Text { chunk },
310                            MessageChunk::File { .. } => todo!(),
311                            MessageChunk::Directory { .. } => todo!(),
312                            MessageChunk::Symbol { .. } => todo!(),
313                            MessageChunk::Thread { .. } => todo!(),
314                            MessageChunk::Fetch { .. } => todo!(),
315                        })
316                        .collect(),
317                },
318            })
319            .await?;
320        todo!()
321    }
322}
323
324impl From<acp::ThreadId> for ThreadId {
325    fn from(thread_id: acp::ThreadId) -> Self {
326        Self(thread_id.0.into())
327    }
328}
329
330impl From<ThreadId> for acp::ThreadId {
331    fn from(thread_id: ThreadId) -> Self {
332        acp::ThreadId(thread_id.0.to_string())
333    }
334}