acp.rs

  1mod server;
  2mod thread_view;
  3
  4use agentic_coding_protocol::{self as acp, Role};
  5use anyhow::Result;
  6use chrono::{DateTime, Utc};
  7use futures::channel::oneshot;
  8use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  9use language::LanguageRegistry;
 10use markdown::Markdown;
 11use project::Project;
 12use std::{mem, ops::Range, path::PathBuf, sync::Arc};
 13use ui::App;
 14use util::{ResultExt, debug_panic};
 15
 16pub use server::AcpServer;
 17pub use thread_view::AcpThreadView;
 18
 19#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 20pub struct ThreadId(SharedString);
 21
 22#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 23pub struct FileVersion(u64);
 24
 25#[derive(Debug)]
 26pub struct AgentThreadSummary {
 27    pub id: ThreadId,
 28    pub title: String,
 29    pub created_at: DateTime<Utc>,
 30}
 31
 32#[derive(Clone, Debug, PartialEq, Eq)]
 33pub struct FileContent {
 34    pub path: PathBuf,
 35    pub version: FileVersion,
 36    pub content: SharedString,
 37}
 38
 39#[derive(Clone, Debug, Eq, PartialEq)]
 40pub struct Message {
 41    pub role: acp::Role,
 42    pub chunks: Vec<MessageChunk>,
 43}
 44
 45impl Message {
 46    fn into_acp(self, cx: &App) -> acp::Message {
 47        acp::Message {
 48            role: self.role,
 49            chunks: self
 50                .chunks
 51                .into_iter()
 52                .map(|chunk| chunk.into_acp(cx))
 53                .collect(),
 54        }
 55    }
 56}
 57
 58#[derive(Clone, Debug, Eq, PartialEq)]
 59pub enum MessageChunk {
 60    Text {
 61        chunk: Entity<Markdown>,
 62    },
 63    File {
 64        content: FileContent,
 65    },
 66    Directory {
 67        path: PathBuf,
 68        contents: Vec<FileContent>,
 69    },
 70    Symbol {
 71        path: PathBuf,
 72        range: Range<u64>,
 73        version: FileVersion,
 74        name: SharedString,
 75        content: SharedString,
 76    },
 77    Fetch {
 78        url: SharedString,
 79        content: SharedString,
 80    },
 81}
 82
 83impl MessageChunk {
 84    pub fn from_acp(
 85        chunk: acp::MessageChunk,
 86        language_registry: Arc<LanguageRegistry>,
 87        cx: &mut App,
 88    ) -> Self {
 89        match chunk {
 90            acp::MessageChunk::Text { chunk } => MessageChunk::Text {
 91                chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
 92            },
 93        }
 94    }
 95
 96    pub fn into_acp(self, cx: &App) -> acp::MessageChunk {
 97        match self {
 98            MessageChunk::Text { chunk } => acp::MessageChunk::Text {
 99                chunk: chunk.read(cx).source().to_string(),
100            },
101            MessageChunk::File { .. } => todo!(),
102            MessageChunk::Directory { .. } => todo!(),
103            MessageChunk::Symbol { .. } => todo!(),
104            MessageChunk::Fetch { .. } => todo!(),
105        }
106    }
107
108    pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
109        MessageChunk::Text {
110            chunk: cx.new(|cx| {
111                Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
112            }),
113        }
114    }
115}
116
117#[derive(Debug)]
118pub enum AgentThreadEntryContent {
119    Message(Message),
120    ToolCall(ToolCall),
121}
122
123#[derive(Debug)]
124pub struct ToolCall {
125    id: ToolCallId,
126    tool_name: Entity<Markdown>,
127    status: ToolCallStatus,
128}
129
130#[derive(Debug)]
131pub enum ToolCallStatus {
132    WaitingForConfirmation {
133        description: Entity<Markdown>,
134        respond_tx: oneshot::Sender<bool>,
135    },
136    // todo! Running?
137    Allowed,
138    Rejected,
139}
140
/// A [`ThreadEntryId`] that is known to be a ToolCall
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ToolCallId(ThreadEntryId);

impl ToolCallId {
    /// Returns the raw numeric id of the underlying thread entry.
    pub fn as_u64(&self) -> u64 {
        let ToolCallId(ThreadEntryId(raw)) = *self;
        raw
    }
}

/// Sequential identifier of an entry within a thread.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ThreadEntryId(pub u64);

impl ThreadEntryId {
    /// Advances this id by one and returns the value it held before (like `x++`).
    pub fn post_inc(&mut self) -> Self {
        let next = ThreadEntryId(self.0 + 1);
        mem::replace(self, next)
    }
}
161
162#[derive(Debug)]
163pub struct ThreadEntry {
164    pub id: ThreadEntryId,
165    pub content: AgentThreadEntryContent,
166}
167
168pub struct AcpThread {
169    id: ThreadId,
170    next_entry_id: ThreadEntryId,
171    entries: Vec<ThreadEntry>,
172    server: Arc<AcpServer>,
173    title: SharedString,
174    project: Entity<Project>,
175}
176
177enum AcpThreadEvent {
178    NewEntry,
179    EntryUpdated(usize),
180}
181
182impl EventEmitter<AcpThreadEvent> for AcpThread {}
183
184impl AcpThread {
185    pub fn new(
186        server: Arc<AcpServer>,
187        thread_id: ThreadId,
188        entries: Vec<AgentThreadEntryContent>,
189        project: Entity<Project>,
190        _: &mut Context<Self>,
191    ) -> Self {
192        let mut next_entry_id = ThreadEntryId(0);
193        Self {
194            title: "A new agent2 thread".into(),
195            entries: entries
196                .into_iter()
197                .map(|entry| ThreadEntry {
198                    id: next_entry_id.post_inc(),
199                    content: entry,
200                })
201                .collect(),
202            server,
203            id: thread_id,
204            next_entry_id,
205            project,
206        }
207    }
208
209    pub fn title(&self) -> SharedString {
210        self.title.clone()
211    }
212
213    pub fn entries(&self) -> &[ThreadEntry] {
214        &self.entries
215    }
216
217    pub fn push_entry(
218        &mut self,
219        entry: AgentThreadEntryContent,
220        cx: &mut Context<Self>,
221    ) -> ThreadEntryId {
222        let id = self.next_entry_id.post_inc();
223        self.entries.push(ThreadEntry { id, content: entry });
224        cx.emit(AcpThreadEvent::NewEntry);
225        id
226    }
227
228    pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
229        let entries_len = self.entries.len();
230        if let Some(last_entry) = self.entries.last_mut()
231            && let AgentThreadEntryContent::Message(Message {
232                ref mut chunks,
233                role: Role::Assistant,
234            }) = last_entry.content
235        {
236            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
237
238            if let (
239                Some(MessageChunk::Text { chunk: old_chunk }),
240                acp::MessageChunk::Text { chunk: new_chunk },
241            ) = (chunks.last_mut(), &chunk)
242            {
243                old_chunk.update(cx, |old_chunk, cx| {
244                    old_chunk.append(&new_chunk, cx);
245                });
246            } else {
247                chunks.push(MessageChunk::from_acp(
248                    chunk,
249                    self.project.read(cx).languages().clone(),
250                    cx,
251                ));
252            }
253
254            return;
255        }
256
257        let chunk = MessageChunk::from_acp(chunk, self.project.read(cx).languages().clone(), cx);
258
259        self.push_entry(
260            AgentThreadEntryContent::Message(Message {
261                role: Role::Assistant,
262                chunks: vec![chunk],
263            }),
264            cx,
265        );
266    }
267
268    pub fn push_tool_call(
269        &mut self,
270        title: String,
271        description: String,
272        respond_tx: oneshot::Sender<bool>,
273        cx: &mut Context<Self>,
274    ) -> ToolCallId {
275        let language_registry = self.project.read(cx).languages().clone();
276
277        let entry_id = self.push_entry(
278            AgentThreadEntryContent::ToolCall(ToolCall {
279                // todo! clean up id creation
280                id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
281                tool_name: cx.new(|cx| {
282                    Markdown::new(title.into(), Some(language_registry.clone()), None, cx)
283                }),
284                status: ToolCallStatus::WaitingForConfirmation {
285                    description: cx.new(|cx| {
286                        Markdown::new(
287                            description.into(),
288                            Some(language_registry.clone()),
289                            None,
290                            cx,
291                        )
292                    }),
293                    respond_tx,
294                },
295            }),
296            cx,
297        );
298
299        ToolCallId(entry_id)
300    }
301
302    pub fn authorize_tool_call(&mut self, id: ToolCallId, allowed: bool, cx: &mut Context<Self>) {
303        let Some(entry) = self.entry_mut(id.0) else {
304            return;
305        };
306
307        let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
308            debug_panic!("expected ToolCall");
309            return;
310        };
311
312        let new_status = if allowed {
313            ToolCallStatus::Allowed
314        } else {
315            ToolCallStatus::Rejected
316        };
317
318        let curr_status = mem::replace(&mut call.status, new_status);
319
320        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
321            respond_tx.send(allowed).log_err();
322        } else {
323            debug_panic!("tried to authorize an already authorized tool call");
324        }
325
326        cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
327    }
328
329    fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
330        let entry = self.entries.get_mut(id.0 as usize);
331        debug_assert!(
332            entry.is_some(),
333            "We shouldn't give out ids to entries that don't exist"
334        );
335        entry
336    }
337
338    /// Returns true if the last turn is awaiting tool authorization
339    pub fn waiting_for_tool_confirmation(&self) -> bool {
340        for entry in self.entries.iter().rev() {
341            match &entry.content {
342                AgentThreadEntryContent::ToolCall(call) => match call.status {
343                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
344                    ToolCallStatus::Allowed | ToolCallStatus::Rejected => continue,
345                },
346                AgentThreadEntryContent::Message(_) => {
347                    // Reached the beginning of the turn
348                    return false;
349                }
350            }
351        }
352        false
353    }
354
355    pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
356        let agent = self.server.clone();
357        let id = self.id.clone();
358        let chunk = MessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
359        let message = Message {
360            role: Role::User,
361            chunks: vec![chunk],
362        };
363        self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
364        let acp_message = message.into_acp(cx);
365        cx.spawn(async move |_, cx| {
366            agent.send_message(id, acp_message, cx).await?;
367            Ok(())
368        })
369    }
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{FutureExt as _, channel::mpsc, select};
    use gpui::{AsyncApp, TestAppContext};
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt;
    use std::{env, path::Path, process::Stdio, time::Duration};
    use util::path;

    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    #[gpui::test]
    async fn test_gemini_basic(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.entries.len(), 2);
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::User,
                    ..
                })
            ));
            assert!(matches!(
                thread.entries[1].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    #[gpui::test]
    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send(
                "Read the '/private/tmp/foo' file and tell me what you see.",
                cx,
            )
        });

        run_until_tool_call(&thread, cx).await;

        let tool_call_id = thread.read_with(cx, |thread, cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                id,
                tool_name,
                status: ToolCallStatus::WaitingForConfirmation { description, .. },
            }) = &thread.entries().last().unwrap().content
            else {
                panic!();
            };

            tool_name.read_with(cx, |md, _cx| {
                assert_eq!(md.source(), "read_file");
            });

            description.read_with(cx, |md, _cx| {
                assert!(
                    md.source().contains("foo"),
                    "Expected description to contain 'foo', but got {}",
                    md.source()
                );
            });
            *id
        });

        thread.update(cx, |thread, cx| {
            thread.authorize_tool_call(tool_call_id, true, cx);
            assert!(matches!(
                thread.entries().last().unwrap().content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed,
                    ..
                })
            ));
        });

        full_turn.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(thread.entries.len() >= 3, "{:?}", &thread.entries);
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::User,
                    ..
                })
            ));
            assert!(matches!(
                thread.entries[1].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed,
                    ..
                })
            ));
            assert!(matches!(
                thread.entries[2].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    /// Waits (up to 5 seconds) until some event is emitted while `thread` contains a
    /// tool call entry. Panics on timeout.
    async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
        let (mut tx, mut rx) = mpsc::channel(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                if thread
                    .read(cx)
                    .entries
                    .iter()
                    .any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
                {
                    // Every event after the first tool call lands here too. One
                    // signal is enough to wake the waiter, so a send that fails
                    // because the channel is already full is ignored rather than
                    // unwrapped (which used to panic).
                    let _ = tx.try_send(());
                }
            })
        });

        select! {
            _ = cx.executor().timer(Duration::from_secs(5)).fuse() => {
                panic!("Timeout waiting for tool call")
            }
            _ = rx.next().fuse() => {
                drop(subscription);
            }
        }
    }

    /// Spawns a local gemini-cli checkout in ACP mode and connects an `AcpServer` to
    /// its stdio. Forwards `GEMINI_API_KEY` from the environment when set.
    pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let mut command = util::command::new_smol_command("node");
        command
            .arg(cli_path)
            .arg("--acp")
            .args(["--model", "gemini-2.5-flash"])
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);

        if let Ok(gemini_key) = env::var("GEMINI_API_KEY") {
            command.env("GEMINI_API_KEY", gemini_key);
        }

        let child = command.spawn()?;

        Ok(AcpServer::stdio(child, project, &mut cx))
    }
}