acp.rs

  1use std::{
  2    io::{Cursor, Write as _},
  3    path::Path,
  4    sync::{Arc, Weak},
  5};
  6
  7use crate::{
  8    Agent, AgentThread, AgentThreadEntryContent, AgentThreadSummary, Message, MessageChunk,
  9    ResponseEvent, Role, Thread, ThreadEntry, ThreadId,
 10};
 11use agentic_coding_protocol::{self as acp, TurnId};
 12use anyhow::{Context as _, Result};
 13use async_trait::async_trait;
 14use collections::HashMap;
 15use futures::channel::mpsc::UnboundedReceiver;
 16use gpui::{AppContext, AsyncApp, Entity, Task, WeakEntity};
 17use parking_lot::Mutex;
 18use project::Project;
 19use smol::process::Child;
 20use util::ResultExt;
 21
 22pub struct AcpAgent {
 23    connection: Arc<acp::AgentConnection>,
 24    threads: Arc<Mutex<HashMap<acp::ThreadId, WeakEntity<Thread>>>>,
 25    _handler_task: Task<()>,
 26    _io_task: Task<()>,
 27}
 28
 29struct AcpClientDelegate {
 30    project: Entity<Project>,
 31    threads: Arc<Mutex<HashMap<acp::ThreadId, WeakEntity<Thread>>>>,
 32    cx: AsyncApp,
 33    // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
 34}
 35
 36#[async_trait(?Send)]
 37impl acp::Client for AcpClientDelegate {
 38    async fn stat(&self, params: acp::StatParams) -> Result<acp::StatResponse> {
 39        let cx = &mut self.cx.clone();
 40        self.project.update(cx, |project, cx| {
 41            let path = project
 42                .project_path_for_absolute_path(Path::new(&params.path), cx)
 43                .context("Failed to get project path")?;
 44
 45            match project.entry_for_path(&path, cx) {
 46                // todo! refresh entry?
 47                None => Ok(acp::StatResponse {
 48                    exists: false,
 49                    is_directory: false,
 50                }),
 51                Some(entry) => Ok(acp::StatResponse {
 52                    exists: entry.is_created(),
 53                    is_directory: entry.is_dir(),
 54                }),
 55            }
 56        })?
 57    }
 58
 59    async fn stream_message_chunk(
 60        &self,
 61        request: acp::StreamMessageChunkParams,
 62    ) -> Result<acp::StreamMessageChunkResponse> {
 63        Ok(acp::StreamMessageChunkResponse)
 64    }
 65
 66    async fn read_text_file(
 67        &self,
 68        request: acp::ReadTextFileParams,
 69    ) -> Result<acp::ReadTextFileResponse> {
 70        let cx = &mut self.cx.clone();
 71        let buffer = self
 72            .project
 73            .update(cx, |project, cx| {
 74                let path = project
 75                    .project_path_for_absolute_path(Path::new(&request.path), cx)
 76                    .context("Failed to get project path")?;
 77                anyhow::Ok(project.open_buffer(path, cx))
 78            })??
 79            .await?;
 80
 81        buffer.update(cx, |buffer, _| {
 82            let start = language::Point::new(request.line_offset.unwrap_or(0), 0);
 83            let end = match request.line_limit {
 84                None => buffer.max_point(),
 85                Some(limit) => start + language::Point::new(limit + 1, 0),
 86            };
 87
 88            let content = buffer.text_for_range(start..end).collect();
 89
 90            if let Some(thread) = self.threads.lock().get(&request.thread_id) {
 91                thread.update(cx, |thread, cx| {
 92                    thread.push_entry(ThreadEntry {
 93                        content: AgentThreadEntryContent::ReadFile {
 94                            path: request.path.clone(),
 95                            content: content.clone(),
 96                        },
 97                    });
 98                })
 99            }
100
101            acp::ReadTextFileResponse {
102                content,
103                version: acp::FileVersion(0),
104            }
105        })
106    }
107
108    async fn read_binary_file(
109        &self,
110        request: acp::ReadBinaryFileParams,
111    ) -> Result<acp::ReadBinaryFileResponse> {
112        let cx = &mut self.cx.clone();
113        let file = self
114            .project
115            .update(cx, |project, cx| {
116                let (worktree, path) = project
117                    .find_worktree(Path::new(&request.path), cx)
118                    .context("Failed to get project path")?;
119
120                let task = worktree.update(cx, |worktree, cx| worktree.load_binary_file(&path, cx));
121                anyhow::Ok(task)
122            })??
123            .await?;
124
125        // todo! test
126        let content = cx
127            .background_spawn(async move {
128                let start = request.byte_offset.unwrap_or(0) as usize;
129                let end = request
130                    .byte_limit
131                    .map(|limit| (start + limit as usize).min(file.content.len()))
132                    .unwrap_or(file.content.len());
133
134                let range_content = &file.content[start..end];
135
136                let mut base64_content = Vec::new();
137                let mut base64_encoder = base64::write::EncoderWriter::new(
138                    Cursor::new(&mut base64_content),
139                    &base64::engine::general_purpose::STANDARD,
140                );
141                base64_encoder.write_all(range_content)?;
142                drop(base64_encoder);
143
144                // SAFETY: The base64 encoder should not produce non-UTF8.
145                unsafe { anyhow::Ok(String::from_utf8_unchecked(base64_content)) }
146            })
147            .await?;
148
149        Ok(acp::ReadBinaryFileResponse {
150            content,
151            // todo!
152            version: acp::FileVersion(0),
153        })
154    }
155
156    async fn glob_search(&self, request: acp::GlobSearchParams) -> Result<acp::GlobSearchResponse> {
157        todo!()
158    }
159
160    async fn end_turn(&self, request: acp::EndTurnParams) -> Result<acp::EndTurnResponse> {
161        todo!()
162    }
163}
164
165impl AcpAgent {
166    pub fn stdio(mut process: Child, project: Entity<Project>, cx: AsyncApp) -> Self {
167        let stdin = process.stdin.take().expect("process didn't have stdin");
168        let stdout = process.stdout.take().expect("process didn't have stdout");
169
170        let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
171            AcpClientDelegate {
172                project,
173                cx: cx.clone(),
174            },
175            stdin,
176            stdout,
177        );
178
179        let io_task = cx.background_spawn(async move {
180            io_fut.await.log_err();
181            process.status().await.log_err();
182        });
183
184        Self {
185            connection: Arc::new(connection),
186            threads: Mutex::default(),
187            _handler_task: cx.foreground_executor().spawn(handler_fut),
188            _io_task: io_task,
189        }
190    }
191}
192
193#[async_trait]
194impl Agent for AcpAgent {
195    async fn threads(&self) -> Result<Vec<AgentThreadSummary>> {
196        let response = self.connection.request(acp::GetThreadsParams).await?;
197        response
198            .threads
199            .into_iter()
200            .map(|thread| {
201                Ok(AgentThreadSummary {
202                    id: thread.id.into(),
203                    title: thread.title,
204                    created_at: thread.modified_at,
205                })
206            })
207            .collect()
208    }
209
210    async fn create_thread(&self) -> Result<Arc<Self::Thread>> {
211        let response = self.connection.request(acp::CreateThreadParams).await?;
212        let thread = Arc::new(AcpAgentThread {
213            id: response.thread_id.clone(),
214            connection: self.connection.clone(),
215            state: Mutex::new(AcpAgentThreadState {
216                turn: None,
217                next_turn_id: TurnId::default(),
218            }),
219        });
220        self.threads
221            .lock()
222            .insert(response.thread_id, Arc::downgrade(&thread));
223        Ok(thread)
224    }
225
226    async fn open_thread(&self, id: ThreadId) -> Result<Thread> {
227        todo!()
228    }
229
230    async fn thread_entries(&self, thread_id: ThreadId) -> Result<Vec<AgentThreadEntryContent>> {
231        let response = self
232            .connection
233            .request(acp::GetThreadEntriesParams {
234                thread_id: self.id.clone(),
235            })
236            .await?;
237
238        Ok(response
239            .entries
240            .into_iter()
241            .map(|entry| match entry {
242                acp::ThreadEntry::Message { message } => {
243                    AgentThreadEntryContent::Message(Message {
244                        role: match message.role {
245                            acp::Role::User => Role::User,
246                            acp::Role::Assistant => Role::Assistant,
247                        },
248                        chunks: message
249                            .chunks
250                            .into_iter()
251                            .map(|chunk| match chunk {
252                                acp::MessageChunk::Text { chunk } => MessageChunk::Text { chunk },
253                            })
254                            .collect(),
255                    })
256                }
257                acp::ThreadEntry::ReadFile { path, content } => {
258                    AgentThreadEntryContent::ReadFile { path, content }
259                }
260            })
261            .collect())
262    }
263
264    async fn send_thread_message(
265        &self,
266        thread_id: ThreadId,
267        message: crate::Message,
268    ) -> Result<UnboundedReceiver<Result<ResponseEvent>>> {
269        let turn_id = {
270            let mut state = self.state.lock();
271            let turn_id = state.next_turn_id.post_inc();
272            state.turn = Some(AcpAgentThreadTurn { id: turn_id });
273            turn_id
274        };
275        let response = self
276            .connection
277            .request(acp::SendMessageParams {
278                thread_id: self.id.clone(),
279                turn_id,
280                message: acp::Message {
281                    role: match message.role {
282                        Role::User => acp::Role::User,
283                        Role::Assistant => acp::Role::Assistant,
284                    },
285                    chunks: message
286                        .chunks
287                        .into_iter()
288                        .map(|chunk| match chunk {
289                            MessageChunk::Text { chunk } => acp::MessageChunk::Text { chunk },
290                            MessageChunk::File { .. } => todo!(),
291                            MessageChunk::Directory { .. } => todo!(),
292                            MessageChunk::Symbol { .. } => todo!(),
293                            MessageChunk::Thread { .. } => todo!(),
294                            MessageChunk::Fetch { .. } => todo!(),
295                        })
296                        .collect(),
297                },
298            })
299            .await?;
300        todo!()
301    }
302}
303
304pub struct AcpAgentThread {
305    id: acp::ThreadId,
306    connection: Arc<acp::AgentConnection>,
307    state: Mutex<AcpAgentThreadState>,
308}
309
310struct AcpAgentThreadState {
311    next_turn_id: acp::TurnId,
312    turn: Option<AcpAgentThreadTurn>,
313}
314
315struct AcpAgentThreadTurn {
316    id: acp::TurnId,
317}
318
319impl From<acp::ThreadId> for ThreadId {
320    fn from(thread_id: acp::ThreadId) -> Self {
321        Self(thread_id.0)
322    }
323}
324
325impl From<ThreadId> for acp::ThreadId {
326    fn from(thread_id: ThreadId) -> Self {
327        acp::ThreadId(thread_id.0)
328    }
329}