server.rs

  1use crate::{AcpThread, ThreadEntryId, ThreadId, ToolCallId, ToolCallRequest};
  2use agentic_coding_protocol as acp;
  3use anyhow::{Context as _, Result};
  4use async_trait::async_trait;
  5use collections::HashMap;
  6use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
  7use parking_lot::Mutex;
  8use project::Project;
  9use smol::process::Child;
 10use std::{process::ExitStatus, sync::Arc};
 11use util::ResultExt;
 12
/// Client-side handle to an ACP agent process.
///
/// Constructed by [`AcpServer::stdio`], which spawns the connection's handler
/// and I/O futures and stores the resulting tasks here.
pub struct AcpServer {
    /// Connection used to send requests to the agent.
    connection: Arc<acp::AgentConnection>,
    /// Threads created through this server, keyed by id and held weakly.
    /// The same map is shared with the client delegate, which prunes entries
    /// whose thread entity has been released.
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
    /// Project handed to every thread created by this server.
    project: Entity<Project>,
    /// Exit status of the agent process. `None` until the I/O future has
    /// completed and the status was collected (and stays `None` if collecting
    /// it failed).
    exit_status: Arc<Mutex<Option<ExitStatus>>>,
    // Held only so the spawned tasks are not dropped; presumably dropping a
    // gpui `Task` cancels it — confirm against gpui docs.
    _handler_task: Task<()>,
    _io_task: Task<()>,
}
 21
/// Implements [`acp::Client`]: receives callbacks from the agent and routes
/// them to the matching [`AcpThread`] by thread id.
struct AcpClientDelegate {
    /// Project handle; not read by any handler in this file.
    project: Entity<Project>,
    /// Thread registry shared with the owning [`AcpServer`].
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
    /// Async app handle used by the handlers to enter the app context.
    cx: AsyncApp,
    // NOTE(review): planned but not implemented — per-buffer tracking of the
    // buffer versions sent to the agent, previously sketched as
    // `sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>`.
}
 28
 29impl AcpClientDelegate {
 30    fn new(
 31        project: Entity<Project>,
 32        threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
 33        cx: AsyncApp,
 34    ) -> Self {
 35        Self {
 36            project,
 37            threads,
 38            cx: cx,
 39        }
 40    }
 41
 42    fn update_thread<R>(
 43        &self,
 44        thread_id: &ThreadId,
 45        cx: &mut App,
 46        callback: impl FnOnce(&mut AcpThread, &mut Context<AcpThread>) -> R,
 47    ) -> Option<R> {
 48        let thread = self.threads.lock().get(&thread_id)?.clone();
 49        let Some(thread) = thread.upgrade() else {
 50            self.threads.lock().remove(&thread_id);
 51            return None;
 52        };
 53        Some(thread.update(cx, callback))
 54    }
 55}
 56
 57#[async_trait(?Send)]
 58impl acp::Client for AcpClientDelegate {
 59    async fn stream_message_chunk(
 60        &self,
 61        params: acp::StreamMessageChunkParams,
 62    ) -> Result<acp::StreamMessageChunkResponse> {
 63        let cx = &mut self.cx.clone();
 64
 65        cx.update(|cx| {
 66            self.update_thread(&params.thread_id.into(), cx, |thread, cx| {
 67                thread.push_assistant_chunk(params.chunk, cx)
 68            });
 69        })?;
 70
 71        Ok(acp::StreamMessageChunkResponse)
 72    }
 73
 74    async fn request_tool_call_confirmation(
 75        &self,
 76        request: acp::RequestToolCallConfirmationParams,
 77    ) -> Result<acp::RequestToolCallConfirmationResponse> {
 78        let cx = &mut self.cx.clone();
 79        let ToolCallRequest { id, outcome } = cx
 80            .update(|cx| {
 81                self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
 82                    thread.request_tool_call(
 83                        request.label,
 84                        request.icon,
 85                        request.content,
 86                        request.confirmation,
 87                        cx,
 88                    )
 89                })
 90            })?
 91            .context("Failed to update thread")?;
 92
 93        Ok(acp::RequestToolCallConfirmationResponse {
 94            id: id.into(),
 95            outcome: outcome.await?,
 96        })
 97    }
 98
 99    async fn push_tool_call(
100        &self,
101        request: acp::PushToolCallParams,
102    ) -> Result<acp::PushToolCallResponse> {
103        let cx = &mut self.cx.clone();
104        let entry_id = cx
105            .update(|cx| {
106                self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
107                    thread.push_tool_call(request.label, request.icon, request.content, cx)
108                })
109            })?
110            .context("Failed to update thread")?;
111
112        Ok(acp::PushToolCallResponse {
113            id: entry_id.into(),
114        })
115    }
116
117    async fn update_tool_call(
118        &self,
119        request: acp::UpdateToolCallParams,
120    ) -> Result<acp::UpdateToolCallResponse> {
121        let cx = &mut self.cx.clone();
122
123        cx.update(|cx| {
124            self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
125                thread.update_tool_call(
126                    request.tool_call_id.into(),
127                    request.status,
128                    request.content,
129                    cx,
130                )
131            })
132        })?
133        .context("Failed to update thread")??;
134
135        Ok(acp::UpdateToolCallResponse)
136    }
137}
138
139impl AcpServer {
140    pub fn stdio(mut process: Child, project: Entity<Project>, cx: &mut App) -> Arc<Self> {
141        let stdin = process.stdin.take().expect("process didn't have stdin");
142        let stdout = process.stdout.take().expect("process didn't have stdout");
143
144        let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>> = Default::default();
145        let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
146            AcpClientDelegate::new(project.clone(), threads.clone(), cx.to_async()),
147            stdin,
148            stdout,
149        );
150
151        let exit_status: Arc<Mutex<Option<ExitStatus>>> = Default::default();
152        let io_task = cx.background_spawn({
153            let exit_status = exit_status.clone();
154            async move {
155                io_fut.await.log_err();
156                let result = process.status().await.log_err();
157                *exit_status.lock() = result;
158            }
159        });
160
161        Arc::new(Self {
162            project,
163            connection: Arc::new(connection),
164            threads,
165            exit_status,
166            _handler_task: cx.foreground_executor().spawn(handler_fut),
167            _io_task: io_task,
168        })
169    }
170
171    pub async fn initialize(&self) -> Result<acp::InitializeResponse> {
172        self.connection
173            .request(acp::InitializeParams)
174            .await
175            .map_err(to_anyhow)
176    }
177
178    pub async fn authenticate(&self) -> Result<()> {
179        self.connection
180            .request(acp::AuthenticateParams)
181            .await
182            .map_err(to_anyhow)?;
183
184        Ok(())
185    }
186
187    pub async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<AcpThread>> {
188        let response = self
189            .connection
190            .request(acp::CreateThreadParams)
191            .await
192            .map_err(to_anyhow)?;
193
194        let thread_id: ThreadId = response.thread_id.into();
195        let server = self.clone();
196        let thread = cx.new(|_| AcpThread {
197            // todo!
198            title: "ACP Thread".into(),
199            id: thread_id.clone(), // Either<ErrorState, Id>
200            next_entry_id: ThreadEntryId(0),
201            entries: Vec::default(),
202            project: self.project.clone(),
203            server,
204        })?;
205        self.threads.lock().insert(thread_id, thread.downgrade());
206        Ok(thread)
207    }
208
209    pub async fn send_message(
210        &self,
211        thread_id: ThreadId,
212        message: acp::Message,
213        _cx: &mut AsyncApp,
214    ) -> Result<()> {
215        self.connection
216            .request(acp::SendMessageParams {
217                thread_id: thread_id.clone().into(),
218                message,
219            })
220            .await
221            .map_err(to_anyhow)?;
222        Ok(())
223    }
224
225    pub fn exit_status(&self) -> Option<ExitStatus> {
226        *self.exit_status.lock()
227    }
228}
229
230#[track_caller]
231fn to_anyhow(e: acp::Error) -> anyhow::Error {
232    log::error!(
233        "failed to send message: {code}: {message}",
234        code = e.code,
235        message = e.message
236    );
237    anyhow::anyhow!(e.message)
238}
239
240impl From<acp::ThreadId> for ThreadId {
241    fn from(thread_id: acp::ThreadId) -> Self {
242        Self(thread_id.0.into())
243    }
244}
245
246impl From<ThreadId> for acp::ThreadId {
247    fn from(thread_id: ThreadId) -> Self {
248        acp::ThreadId(thread_id.0.to_string())
249    }
250}
251
252impl From<acp::ToolCallId> for ToolCallId {
253    fn from(tool_call_id: acp::ToolCallId) -> Self {
254        Self(ThreadEntryId(tool_call_id.0))
255    }
256}
257
258impl From<ToolCallId> for acp::ToolCallId {
259    fn from(tool_call_id: ToolCallId) -> Self {
260        acp::ToolCallId(tool_call_id.as_u64())
261    }
262}