server.rs

  1use crate::{AcpThread, AgentThreadEntryContent, MessageChunk, Role, ThreadEntryId, ThreadId};
  2use agentic_coding_protocol as acp;
  3use anyhow::{Context as _, Result};
  4use async_trait::async_trait;
  5use collections::HashMap;
  6use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
  7use parking_lot::Mutex;
  8use project::Project;
  9use smol::process::Child;
 10use std::{io::Write as _, path::Path, sync::Arc};
 11use util::ResultExt;
 12
/// Host-side handle to an agent process that speaks `agentic_coding_protocol`.
pub struct AcpServer {
    /// Connection used to send requests to the agent.
    connection: Arc<acp::AgentConnection>,
    /// Registry of live threads, shared with the `AcpClientDelegate` that
    /// handles agent-initiated calls. Entries are weak, so this map does not
    /// keep threads alive.
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
    /// Project handed to every thread this server creates.
    project: Entity<Project>,
    // The two tasks below are never read; they are held so they live exactly as
    // long as the server. NOTE(review): presumably dropping a gpui `Task`
    // cancels it — confirm before reordering or removing these fields.
    _handler_task: Task<()>,
    _io_task: Task<()>,
}
 20
 21struct AcpClientDelegate {
 22    project: Entity<Project>,
 23    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
 24    cx: AsyncApp,
 25    // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
 26}
 27
 28impl AcpClientDelegate {
 29    fn new(
 30        project: Entity<Project>,
 31        threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
 32        cx: AsyncApp,
 33    ) -> Self {
 34        Self {
 35            project,
 36            threads,
 37            cx: cx,
 38        }
 39    }
 40
 41    fn update_thread<R>(
 42        &self,
 43        thread_id: &ThreadId,
 44        cx: &mut App,
 45        callback: impl FnMut(&mut AcpThread, &mut Context<AcpThread>) -> R,
 46    ) -> Option<R> {
 47        let thread = self.threads.lock().get(&thread_id)?.clone();
 48        let Some(thread) = thread.upgrade() else {
 49            self.threads.lock().remove(&thread_id);
 50            return None;
 51        };
 52        Some(thread.update(cx, callback))
 53    }
 54}
 55
 56#[async_trait(?Send)]
 57impl acp::Client for AcpClientDelegate {
 58    async fn stat(&self, params: acp::StatParams) -> Result<acp::StatResponse> {
 59        let cx = &mut self.cx.clone();
 60        self.project.update(cx, |project, cx| {
 61            let path = project
 62                .project_path_for_absolute_path(Path::new(&params.path), cx)
 63                .context("Failed to get project path")?;
 64
 65            match project.entry_for_path(&path, cx) {
 66                // todo! refresh entry?
 67                None => Ok(acp::StatResponse {
 68                    exists: false,
 69                    is_directory: false,
 70                }),
 71                Some(entry) => Ok(acp::StatResponse {
 72                    exists: entry.is_created(),
 73                    is_directory: entry.is_dir(),
 74                }),
 75            }
 76        })?
 77    }
 78
 79    async fn stream_message_chunk(
 80        &self,
 81        params: acp::StreamMessageChunkParams,
 82    ) -> Result<acp::StreamMessageChunkResponse> {
 83        dbg!();
 84        let cx = &mut self.cx.clone();
 85
 86        cx.update(|cx| {
 87            self.update_thread(&params.thread_id.into(), cx, |thread, cx| {
 88                let acp::MessageChunk::Text { chunk } = &params.chunk;
 89                thread.push_assistant_chunk(
 90                    MessageChunk::Text {
 91                        chunk: chunk.into(),
 92                    },
 93                    cx,
 94                )
 95            });
 96        })?;
 97
 98        Ok(acp::StreamMessageChunkResponse)
 99    }
100
101    async fn read_text_file(
102        &self,
103        request: acp::ReadTextFileParams,
104    ) -> Result<acp::ReadTextFileResponse> {
105        let cx = &mut self.cx.clone();
106        let buffer = self
107            .project
108            .update(cx, |project, cx| {
109                let path = project
110                    .project_path_for_absolute_path(Path::new(&request.path), cx)
111                    .context("Failed to get project path")?;
112                anyhow::Ok(project.open_buffer(path, cx))
113            })??
114            .await?;
115
116        buffer.update(cx, |buffer, cx| {
117            let start = language::Point::new(request.line_offset.unwrap_or(0), 0);
118            let end = match request.line_limit {
119                None => buffer.max_point(),
120                Some(limit) => start + language::Point::new(limit + 1, 0),
121            };
122
123            let content: String = buffer.text_for_range(start..end).collect();
124            self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
125                thread.push_entry(
126                    AgentThreadEntryContent::ReadFile {
127                        path: request.path.clone(),
128                        content: content.clone(),
129                    },
130                    cx,
131                );
132            });
133
134            acp::ReadTextFileResponse {
135                content,
136                version: acp::FileVersion(0),
137            }
138        })
139    }
140
141    async fn read_binary_file(
142        &self,
143        request: acp::ReadBinaryFileParams,
144    ) -> Result<acp::ReadBinaryFileResponse> {
145        let cx = &mut self.cx.clone();
146        let file = self
147            .project
148            .update(cx, |project, cx| {
149                let (worktree, path) = project
150                    .find_worktree(Path::new(&request.path), cx)
151                    .context("Failed to get project path")?;
152
153                let task = worktree.update(cx, |worktree, cx| worktree.load_binary_file(&path, cx));
154                anyhow::Ok(task)
155            })??
156            .await?;
157
158        // todo! test
159        let content = cx
160            .background_spawn(async move {
161                let start = request.byte_offset.unwrap_or(0) as usize;
162                let end = request
163                    .byte_limit
164                    .map(|limit| (start + limit as usize).min(file.content.len()))
165                    .unwrap_or(file.content.len());
166
167                let range_content = &file.content[start..end];
168
169                let mut base64_content = Vec::new();
170                let mut base64_encoder = base64::write::EncoderWriter::new(
171                    std::io::Cursor::new(&mut base64_content),
172                    &base64::engine::general_purpose::STANDARD,
173                );
174                base64_encoder.write_all(range_content)?;
175                drop(base64_encoder);
176
177                // SAFETY: The base64 encoder should not produce non-UTF8.
178                unsafe { anyhow::Ok(String::from_utf8_unchecked(base64_content)) }
179            })
180            .await?;
181
182        Ok(acp::ReadBinaryFileResponse {
183            content,
184            // todo!
185            version: acp::FileVersion(0),
186        })
187    }
188
189    async fn glob_search(&self, request: acp::GlobSearchParams) -> Result<acp::GlobSearchResponse> {
190        todo!()
191    }
192}
193
194impl AcpServer {
195    pub fn stdio(mut process: Child, project: Entity<Project>, cx: &mut AsyncApp) -> Arc<Self> {
196        let stdin = process.stdin.take().expect("process didn't have stdin");
197        let stdout = process.stdout.take().expect("process didn't have stdout");
198
199        let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>> = Default::default();
200        let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
201            AcpClientDelegate::new(project.clone(), threads.clone(), cx.clone()),
202            stdin,
203            stdout,
204        );
205
206        let io_task = cx.background_spawn(async move {
207            io_fut.await.log_err();
208            process.status().await.log_err();
209        });
210
211        Arc::new(Self {
212            project,
213            connection: Arc::new(connection),
214            threads,
215            _handler_task: cx.foreground_executor().spawn(handler_fut),
216            _io_task: io_task,
217        })
218    }
219}
220
221impl AcpServer {
222    pub async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<AcpThread>> {
223        let response = self.connection.request(acp::CreateThreadParams).await?;
224        let thread_id: ThreadId = response.thread_id.into();
225        let server = self.clone();
226        let thread = cx.new(|_| AcpThread {
227            title: "The agent2 thread".into(),
228            id: thread_id.clone(),
229            next_entry_id: ThreadEntryId(0),
230            entries: Vec::default(),
231            project: self.project.clone(),
232            server,
233        })?;
234        self.threads.lock().insert(thread_id, thread.downgrade());
235        Ok(thread)
236    }
237
238    pub async fn send_message(
239        &self,
240        thread_id: ThreadId,
241        message: crate::Message,
242        cx: &mut AsyncApp,
243    ) -> Result<()> {
244        self.connection
245            .request(acp::SendMessageParams {
246                thread_id: thread_id.clone().into(),
247                message: acp::Message {
248                    role: match message.role {
249                        Role::User => acp::Role::User,
250                        Role::Assistant => acp::Role::Assistant,
251                    },
252                    chunks: message
253                        .chunks
254                        .into_iter()
255                        .map(|chunk| match chunk {
256                            MessageChunk::Text { chunk } => acp::MessageChunk::Text {
257                                chunk: chunk.into(),
258                            },
259                            MessageChunk::File { .. } => todo!(),
260                            MessageChunk::Directory { .. } => todo!(),
261                            MessageChunk::Symbol { .. } => todo!(),
262                            MessageChunk::Fetch { .. } => todo!(),
263                        })
264                        .collect(),
265                },
266            })
267            .await?;
268        Ok(())
269    }
270}
271
272impl From<acp::ThreadId> for ThreadId {
273    fn from(thread_id: acp::ThreadId) -> Self {
274        Self(thread_id.0.into())
275    }
276}
277
278impl From<ThreadId> for acp::ThreadId {
279    fn from(thread_id: ThreadId) -> Self {
280        acp::ThreadId(thread_id.0.to_string())
281    }
282}