acp.rs

  1use std::{io::Write as _, path::Path, sync::Arc};
  2
  3use crate::{
  4    Agent, AgentThreadEntryContent, AgentThreadSummary, Message, MessageChunk, Role, Thread,
  5    ThreadEntryId, ThreadId,
  6};
  7use agentic_coding_protocol as acp;
  8use anyhow::{Context as _, Result};
  9use async_trait::async_trait;
 10use collections::HashMap;
 11use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
 12use parking_lot::Mutex;
 13use project::Project;
 14use smol::process::Child;
 15use util::ResultExt;
 16
/// An agent reached through an `acp::AgentConnection`.
///
/// Instances are built by [`AcpAgent::stdio`], which wires the connection to
/// the stdin/stdout of a child process.
pub struct AcpAgent {
    /// Connection used to send requests to the agent.
    connection: Arc<acp::AgentConnection>,
    /// Threads created through this agent, keyed by id and held weakly.
    /// The same map is shared with the `AcpClientDelegate` that serves
    /// requests coming from the agent.
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>>,
    /// Project handed to every thread this agent creates.
    project: Entity<Project>,
    /// Task running the connection's handler future.
    /// Presumably stored so the task is not dropped while the agent is alive.
    _handler_task: Task<()>,
    /// Task that drives the connection's I/O future and then awaits the
    /// child process' exit status.
    _io_task: Task<()>,
}
 24
/// Client-side handler for requests that the agent sends back to us
/// (see the `acp::Client` implementation for this type).
struct AcpClientDelegate {
    /// Project used to resolve paths and open buffers/files on the agent's behalf.
    project: Entity<Project>,
    /// Weak handles to known threads, shared with the owning `AcpAgent`.
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>>,
    /// Async app handle; cloned at the start of each request handler.
    cx: AsyncApp,
    // TODO: not implemented yet; placeholder for tracking buffer versions sent to the agent.
    // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
}
 31
 32impl AcpClientDelegate {
 33    fn new(
 34        project: Entity<Project>,
 35        threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>>,
 36        cx: AsyncApp,
 37    ) -> Self {
 38        Self {
 39            project,
 40            threads,
 41            cx: cx,
 42        }
 43    }
 44
 45    fn update_thread<R>(
 46        &self,
 47        thread_id: &ThreadId,
 48        cx: &mut App,
 49        callback: impl FnMut(&mut Thread, &mut Context<Thread>) -> R,
 50    ) -> Option<R> {
 51        let thread = self.threads.lock().get(&thread_id)?.clone();
 52        let Some(thread) = thread.upgrade() else {
 53            self.threads.lock().remove(&thread_id);
 54            return None;
 55        };
 56        Some(thread.update(cx, callback))
 57    }
 58}
 59
 60#[async_trait(?Send)]
 61impl acp::Client for AcpClientDelegate {
 62    async fn stat(&self, params: acp::StatParams) -> Result<acp::StatResponse> {
 63        let cx = &mut self.cx.clone();
 64        self.project.update(cx, |project, cx| {
 65            let path = project
 66                .project_path_for_absolute_path(Path::new(&params.path), cx)
 67                .context("Failed to get project path")?;
 68
 69            match project.entry_for_path(&path, cx) {
 70                // todo! refresh entry?
 71                None => Ok(acp::StatResponse {
 72                    exists: false,
 73                    is_directory: false,
 74                }),
 75                Some(entry) => Ok(acp::StatResponse {
 76                    exists: entry.is_created(),
 77                    is_directory: entry.is_dir(),
 78                }),
 79            }
 80        })?
 81    }
 82
 83    async fn stream_message_chunk(
 84        &self,
 85        params: acp::StreamMessageChunkParams,
 86    ) -> Result<acp::StreamMessageChunkResponse> {
 87        let cx = &mut self.cx.clone();
 88
 89        cx.update(|cx| {
 90            self.update_thread(&params.thread_id.into(), cx, |thread, cx| {
 91                let acp::MessageChunk::Text { chunk } = &params.chunk;
 92                thread.push_assistant_chunk(
 93                    MessageChunk::Text {
 94                        chunk: chunk.into(),
 95                    },
 96                    cx,
 97                )
 98            });
 99        })?;
100
101        Ok(acp::StreamMessageChunkResponse)
102    }
103
104    async fn read_text_file(
105        &self,
106        request: acp::ReadTextFileParams,
107    ) -> Result<acp::ReadTextFileResponse> {
108        let cx = &mut self.cx.clone();
109        let buffer = self
110            .project
111            .update(cx, |project, cx| {
112                let path = project
113                    .project_path_for_absolute_path(Path::new(&request.path), cx)
114                    .context("Failed to get project path")?;
115                anyhow::Ok(project.open_buffer(path, cx))
116            })??
117            .await?;
118
119        buffer.update(cx, |buffer, cx| {
120            let start = language::Point::new(request.line_offset.unwrap_or(0), 0);
121            let end = match request.line_limit {
122                None => buffer.max_point(),
123                Some(limit) => start + language::Point::new(limit + 1, 0),
124            };
125
126            let content: String = buffer.text_for_range(start..end).collect();
127            self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
128                thread.push_entry(
129                    AgentThreadEntryContent::ReadFile {
130                        path: request.path.clone(),
131                        content: content.clone(),
132                    },
133                    cx,
134                );
135            });
136
137            acp::ReadTextFileResponse {
138                content,
139                version: acp::FileVersion(0),
140            }
141        })
142    }
143
144    async fn read_binary_file(
145        &self,
146        request: acp::ReadBinaryFileParams,
147    ) -> Result<acp::ReadBinaryFileResponse> {
148        let cx = &mut self.cx.clone();
149        let file = self
150            .project
151            .update(cx, |project, cx| {
152                let (worktree, path) = project
153                    .find_worktree(Path::new(&request.path), cx)
154                    .context("Failed to get project path")?;
155
156                let task = worktree.update(cx, |worktree, cx| worktree.load_binary_file(&path, cx));
157                anyhow::Ok(task)
158            })??
159            .await?;
160
161        // todo! test
162        let content = cx
163            .background_spawn(async move {
164                let start = request.byte_offset.unwrap_or(0) as usize;
165                let end = request
166                    .byte_limit
167                    .map(|limit| (start + limit as usize).min(file.content.len()))
168                    .unwrap_or(file.content.len());
169
170                let range_content = &file.content[start..end];
171
172                let mut base64_content = Vec::new();
173                let mut base64_encoder = base64::write::EncoderWriter::new(
174                    std::io::Cursor::new(&mut base64_content),
175                    &base64::engine::general_purpose::STANDARD,
176                );
177                base64_encoder.write_all(range_content)?;
178                drop(base64_encoder);
179
180                // SAFETY: The base64 encoder should not produce non-UTF8.
181                unsafe { anyhow::Ok(String::from_utf8_unchecked(base64_content)) }
182            })
183            .await?;
184
185        Ok(acp::ReadBinaryFileResponse {
186            content,
187            // todo!
188            version: acp::FileVersion(0),
189        })
190    }
191
192    async fn glob_search(&self, request: acp::GlobSearchParams) -> Result<acp::GlobSearchResponse> {
193        todo!()
194    }
195}
196
197impl AcpAgent {
198    pub fn stdio(mut process: Child, project: Entity<Project>, cx: &mut AsyncApp) -> Arc<Self> {
199        let stdin = process.stdin.take().expect("process didn't have stdin");
200        let stdout = process.stdout.take().expect("process didn't have stdout");
201
202        let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>> = Default::default();
203        let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
204            AcpClientDelegate::new(project.clone(), threads.clone(), cx.clone()),
205            stdin,
206            stdout,
207        );
208
209        let io_task = cx.background_spawn(async move {
210            io_fut.await.log_err();
211            process.status().await.log_err();
212        });
213
214        Arc::new(Self {
215            project,
216            connection: Arc::new(connection),
217            threads,
218            _handler_task: cx.foreground_executor().spawn(handler_fut),
219            _io_task: io_task,
220        })
221    }
222}
223
224#[async_trait(?Send)]
225impl Agent for AcpAgent {
226    async fn threads(&self, cx: &mut AsyncApp) -> Result<Vec<AgentThreadSummary>> {
227        let response = self.connection.request(acp::GetThreadsParams).await?;
228        response
229            .threads
230            .into_iter()
231            .map(|thread| {
232                Ok(AgentThreadSummary {
233                    id: thread.id.into(),
234                    title: thread.title,
235                    created_at: thread.modified_at,
236                })
237            })
238            .collect()
239    }
240
241    async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<Thread>> {
242        let response = self.connection.request(acp::CreateThreadParams).await?;
243        let thread_id: ThreadId = response.thread_id.into();
244        let agent = self.clone();
245        let thread = cx.new(|_| Thread {
246            title: "The agent2 thread".into(),
247            id: thread_id.clone(),
248            next_entry_id: ThreadEntryId(0),
249            entries: Vec::default(),
250            project: self.project.clone(),
251            agent,
252        })?;
253        self.threads.lock().insert(thread_id, thread.downgrade());
254        Ok(thread)
255    }
256
257    async fn open_thread(&self, id: ThreadId, cx: &mut AsyncApp) -> Result<Entity<Thread>> {
258        todo!()
259    }
260
261    async fn thread_entries(
262        &self,
263        thread_id: ThreadId,
264        cx: &mut AsyncApp,
265    ) -> Result<Vec<AgentThreadEntryContent>> {
266        let response = self
267            .connection
268            .request(acp::GetThreadEntriesParams {
269                thread_id: thread_id.clone().into(),
270            })
271            .await?;
272
273        Ok(response
274            .entries
275            .into_iter()
276            .map(|entry| match entry {
277                acp::ThreadEntry::Message { message } => {
278                    AgentThreadEntryContent::Message(Message {
279                        role: match message.role {
280                            acp::Role::User => Role::User,
281                            acp::Role::Assistant => Role::Assistant,
282                        },
283                        chunks: message
284                            .chunks
285                            .into_iter()
286                            .map(|chunk| match chunk {
287                                acp::MessageChunk::Text { chunk } => MessageChunk::Text {
288                                    chunk: chunk.into(),
289                                },
290                            })
291                            .collect(),
292                    })
293                }
294                acp::ThreadEntry::ReadFile { path, content } => {
295                    AgentThreadEntryContent::ReadFile { path, content }
296                }
297            })
298            .collect())
299    }
300
301    async fn send_thread_message(
302        &self,
303        thread_id: ThreadId,
304        message: crate::Message,
305        cx: &mut AsyncApp,
306    ) -> Result<()> {
307        self.connection
308            .request(acp::SendMessageParams {
309                thread_id: thread_id.clone().into(),
310                message: acp::Message {
311                    role: match message.role {
312                        Role::User => acp::Role::User,
313                        Role::Assistant => acp::Role::Assistant,
314                    },
315                    chunks: message
316                        .chunks
317                        .into_iter()
318                        .map(|chunk| match chunk {
319                            MessageChunk::Text { chunk } => acp::MessageChunk::Text {
320                                chunk: chunk.into(),
321                            },
322                            MessageChunk::File { .. } => todo!(),
323                            MessageChunk::Directory { .. } => todo!(),
324                            MessageChunk::Symbol { .. } => todo!(),
325                            MessageChunk::Fetch { .. } => todo!(),
326                        })
327                        .collect(),
328                },
329            })
330            .await?;
331        Ok(())
332    }
333}
334
335impl From<acp::ThreadId> for ThreadId {
336    fn from(thread_id: acp::ThreadId) -> Self {
337        Self(thread_id.0.into())
338    }
339}
340
341impl From<ThreadId> for acp::ThreadId {
342    fn from(thread_id: ThreadId) -> Self {
343        acp::ThreadId(thread_id.0.to_string())
344    }
345}