acp.rs

  1mod server;
  2mod thread_view;
  3
  4use agentic_coding_protocol::{self as acp, Role};
  5use anyhow::Result;
  6use chrono::{DateTime, Utc};
  7use futures::channel::oneshot;
  8use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  9use language::LanguageRegistry;
 10use markdown::Markdown;
 11use project::Project;
 12use std::{mem, ops::Range, path::PathBuf, sync::Arc};
 13use ui::App;
 14use util::{ResultExt, debug_panic};
 15
 16pub use server::AcpServer;
 17pub use thread_view::AcpThreadView;
 18
 19#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 20pub struct ThreadId(SharedString);
 21
 22#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 23pub struct FileVersion(u64);
 24
 25#[derive(Debug)]
 26pub struct AgentThreadSummary {
 27    pub id: ThreadId,
 28    pub title: String,
 29    pub created_at: DateTime<Utc>,
 30}
 31
 32#[derive(Clone, Debug, PartialEq, Eq)]
 33pub struct FileContent {
 34    pub path: PathBuf,
 35    pub version: FileVersion,
 36    pub content: SharedString,
 37}
 38
 39#[derive(Clone, Debug, Eq, PartialEq)]
 40pub struct Message {
 41    pub role: acp::Role,
 42    pub chunks: Vec<MessageChunk>,
 43}
 44
 45impl Message {
 46    fn into_acp(self, cx: &App) -> acp::Message {
 47        acp::Message {
 48            role: self.role,
 49            chunks: self
 50                .chunks
 51                .into_iter()
 52                .map(|chunk| chunk.into_acp(cx))
 53                .collect(),
 54        }
 55    }
 56}
 57
 58#[derive(Clone, Debug, Eq, PartialEq)]
 59pub enum MessageChunk {
 60    Text {
 61        chunk: Entity<Markdown>,
 62    },
 63    File {
 64        content: FileContent,
 65    },
 66    Directory {
 67        path: PathBuf,
 68        contents: Vec<FileContent>,
 69    },
 70    Symbol {
 71        path: PathBuf,
 72        range: Range<u64>,
 73        version: FileVersion,
 74        name: SharedString,
 75        content: SharedString,
 76    },
 77    Fetch {
 78        url: SharedString,
 79        content: SharedString,
 80    },
 81}
 82
 83impl MessageChunk {
 84    pub fn from_acp(
 85        chunk: acp::MessageChunk,
 86        language_registry: Arc<LanguageRegistry>,
 87        cx: &mut App,
 88    ) -> Self {
 89        match chunk {
 90            acp::MessageChunk::Text { chunk } => MessageChunk::Text {
 91                chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
 92            },
 93        }
 94    }
 95
 96    pub fn into_acp(self, cx: &App) -> acp::MessageChunk {
 97        match self {
 98            MessageChunk::Text { chunk } => acp::MessageChunk::Text {
 99                chunk: chunk.read(cx).source().to_string(),
100            },
101            MessageChunk::File { .. } => todo!(),
102            MessageChunk::Directory { .. } => todo!(),
103            MessageChunk::Symbol { .. } => todo!(),
104            MessageChunk::Fetch { .. } => todo!(),
105        }
106    }
107
108    pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
109        MessageChunk::Text {
110            chunk: cx.new(|cx| {
111                Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
112            }),
113        }
114    }
115}
116
117#[derive(Debug)]
118pub enum AgentThreadEntryContent {
119    Message(Message),
120    ReadFile { path: PathBuf, content: String },
121    ToolCall(ToolCall),
122}
123
124#[derive(Debug)]
125pub enum ToolCall {
126    WaitingForConfirmation {
127        id: ToolCallId,
128        tool_name: Entity<Markdown>,
129        description: Entity<Markdown>,
130        respond_tx: oneshot::Sender<bool>,
131    },
132    // todo! Running?
133    Allowed,
134    Rejected,
135}
136
137/// A `ThreadEntryId` that is known to be a ToolCall
138#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
139pub struct ToolCallId(ThreadEntryId);
140
/// Numeric id of an entry within a thread.
///
/// Within [`AcpThread`] ids are handed out sequentially and double as the
/// entry's index in the entry list.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ThreadEntryId(pub u64);

impl ThreadEntryId {
    /// Returns the current id and advances `self` to the next one
    /// (post-increment semantics).
    pub fn post_inc(&mut self) -> Self {
        let current = *self;
        *self = Self(current.0 + 1);
        current
    }
}
151
152#[derive(Debug)]
153pub struct ThreadEntry {
154    pub id: ThreadEntryId,
155    pub content: AgentThreadEntryContent,
156}
157
158pub struct AcpThread {
159    id: ThreadId,
160    next_entry_id: ThreadEntryId,
161    entries: Vec<ThreadEntry>,
162    server: Arc<AcpServer>,
163    title: SharedString,
164    project: Entity<Project>,
165}
166
/// Events emitted by `AcpThread` when its entry list changes.
enum AcpThreadEvent {
    /// A new entry was appended to the thread.
    NewEntry,
    /// The entry at the given index in the entry list was modified.
    EntryUpdated(usize),
}
171
172impl EventEmitter<AcpThreadEvent> for AcpThread {}
173
174impl AcpThread {
175    pub fn new(
176        server: Arc<AcpServer>,
177        thread_id: ThreadId,
178        entries: Vec<AgentThreadEntryContent>,
179        project: Entity<Project>,
180        _: &mut Context<Self>,
181    ) -> Self {
182        let mut next_entry_id = ThreadEntryId(0);
183        Self {
184            title: "A new agent2 thread".into(),
185            entries: entries
186                .into_iter()
187                .map(|entry| ThreadEntry {
188                    id: next_entry_id.post_inc(),
189                    content: entry,
190                })
191                .collect(),
192            server,
193            id: thread_id,
194            next_entry_id,
195            project,
196        }
197    }
198
199    pub fn title(&self) -> SharedString {
200        self.title.clone()
201    }
202
203    pub fn entries(&self) -> &[ThreadEntry] {
204        &self.entries
205    }
206
207    pub fn push_entry(
208        &mut self,
209        entry: AgentThreadEntryContent,
210        cx: &mut Context<Self>,
211    ) -> ThreadEntryId {
212        let id = self.next_entry_id.post_inc();
213        self.entries.push(ThreadEntry { id, content: entry });
214        cx.emit(AcpThreadEvent::NewEntry);
215        id
216    }
217
218    pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
219        let entries_len = self.entries.len();
220        if let Some(last_entry) = self.entries.last_mut()
221            && let AgentThreadEntryContent::Message(Message {
222                ref mut chunks,
223                role: Role::Assistant,
224            }) = last_entry.content
225        {
226            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
227
228            if let (
229                Some(MessageChunk::Text { chunk: old_chunk }),
230                acp::MessageChunk::Text { chunk: new_chunk },
231            ) = (chunks.last_mut(), &chunk)
232            {
233                old_chunk.update(cx, |old_chunk, cx| {
234                    old_chunk.append(&new_chunk, cx);
235                });
236            } else {
237                chunks.push(MessageChunk::from_acp(
238                    chunk,
239                    self.project.read(cx).languages().clone(),
240                    cx,
241                ));
242            }
243
244            return;
245        }
246
247        let chunk = MessageChunk::from_acp(chunk, self.project.read(cx).languages().clone(), cx);
248
249        self.push_entry(
250            AgentThreadEntryContent::Message(Message {
251                role: Role::Assistant,
252                chunks: vec![chunk],
253            }),
254            cx,
255        );
256    }
257
258    pub fn push_tool_call(
259        &mut self,
260        title: String,
261        description: String,
262        respond_tx: oneshot::Sender<bool>,
263        cx: &mut Context<Self>,
264    ) -> ToolCallId {
265        let language_registry = self.project.read(cx).languages().clone();
266
267        let entry_id = self.push_entry(
268            AgentThreadEntryContent::ToolCall(ToolCall::WaitingForConfirmation {
269                // todo! clean up id creation
270                id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
271                tool_name: cx.new(|cx| {
272                    Markdown::new(title.into(), Some(language_registry.clone()), None, cx)
273                }),
274                description: cx.new(|cx| {
275                    Markdown::new(
276                        description.into(),
277                        Some(language_registry.clone()),
278                        None,
279                        cx,
280                    )
281                }),
282                respond_tx,
283            }),
284            cx,
285        );
286
287        ToolCallId(entry_id)
288    }
289
290    pub fn authorize_tool_call(&mut self, id: ToolCallId, allowed: bool, cx: &mut Context<Self>) {
291        let Some(entry) = self.entry_mut(id.0) else {
292            return;
293        };
294
295        let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
296            debug_panic!("expected ToolCall");
297            return;
298        };
299
300        let new_state = if allowed {
301            ToolCall::Allowed
302        } else {
303            ToolCall::Rejected
304        };
305
306        let call = mem::replace(call, new_state);
307
308        if let ToolCall::WaitingForConfirmation { respond_tx, .. } = call {
309            respond_tx.send(allowed).log_err();
310        } else {
311            debug_panic!("tried to authorize an already authorized tool call");
312        }
313
314        cx.emit(AcpThreadEvent::EntryUpdated(id.0.0 as usize));
315    }
316
317    fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
318        let entry = self.entries.get_mut(id.0 as usize);
319        debug_assert!(
320            entry.is_some(),
321            "We shouldn't give out ids to entries that don't exist"
322        );
323        entry
324    }
325
326    pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
327        let agent = self.server.clone();
328        let id = self.id.clone();
329        let chunk = MessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
330        let message = Message {
331            role: Role::User,
332            chunks: vec![chunk],
333        };
334        self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
335        let acp_message = message.into_acp(cx);
336        cx.spawn(async move |_, cx| {
337            agent.send_message(id, acp_message, cx).await?;
338            Ok(())
339        })
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{AsyncApp, TestAppContext};
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use std::{env, path::Path, process::Stdio};
    use util::path;

    /// Installs logging and the globals the thread code needs.
    fn init_test(cx: &mut TestAppContext) {
        // `try_init` instead of `init`: a logger may already be installed by
        // another test in the same process, and `init` would panic then.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// End-to-end test against a real gemini-cli process (see
    /// `gemini_acp_server` for what must be available on the machine).
    #[gpui::test]
    async fn test_gemini(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    "Read the '/private/tmp/foo' file and output all of its contents.",
                    cx,
                )
            })
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            // The user's message is recorded first...
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::User,
                    ..
                })
            ));
            // ...and somewhere after it the agent must have read the file.
            let read_foo = thread.entries().iter().any(|entry| match &entry.content {
                AgentThreadEntryContent::ReadFile { path, content } => {
                    path.to_string_lossy() == "/private/tmp/foo" && content == "Lorem ipsum dolor"
                }
                _ => false,
            });
            assert!(
                read_foo,
                "Thread does not contain entry. Actual: {:?}",
                thread.entries()
            );
        });
    }

    /// Spawns gemini-cli in ACP mode over stdio and wraps it in an `AcpServer`.
    ///
    /// Requires `node` on the `PATH` and a gemini-cli checkout at
    /// `../../../gemini-cli` relative to this crate's manifest directory.
    /// `GEMINI_API_KEY` is forwarded to the child when it is set.
    ///
    /// # Errors
    ///
    /// Returns an error if the child process cannot be spawned.
    pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let mut command = util::command::new_smol_command("node");
        command
            .arg(cli_path)
            .arg("--acp")
            .args(["--model", "gemini-2.5-flash"])
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);

        if let Ok(gemini_key) = env::var("GEMINI_API_KEY") {
            command.env("GEMINI_API_KEY", gemini_key);
        }

        // Propagate spawn failures through the `Result` this function already
        // returns instead of panicking.
        let child = command.spawn()?;

        Ok(AcpServer::stdio(child, project, &mut cx))
    }
}