acp.rs

  1mod server;
  2mod thread_view;
  3
  4use agentic_coding_protocol::{self as acp, Role};
  5use anyhow::{Context as _, Result};
  6use buffer_diff::BufferDiff;
  7use chrono::{DateTime, Utc};
  8use editor::MultiBuffer;
  9use futures::channel::oneshot;
 10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
 11use language::{Buffer, LanguageRegistry};
 12use markdown::Markdown;
 13use project::Project;
 14use std::{mem, ops::Range, path::PathBuf, sync::Arc};
 15use ui::{App, IconName};
 16use util::{ResultExt, debug_panic};
 17
 18pub use server::AcpServer;
 19pub use thread_view::AcpThreadView;
 20
/// Identifier of an agent thread.
///
/// `AcpThread::send` passes it to the server's `send_message` so the server knows which
/// thread a user message belongs to.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ThreadId(SharedString);
 23
/// Version tag carried by `FileContent` and `MessageChunk::Symbol`.
///
/// NOTE(review): what the wrapped `u64` measures (buffer version, mtime, counter, …) is not
/// established in this file — confirm against the code that constructs it.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct FileVersion(u64);
 26
/// Lightweight description of a thread: its id, title and creation time.
#[derive(Debug)]
pub struct AgentThreadSummary {
    /// Id of the summarized thread.
    pub id: ThreadId,
    /// Thread title.
    pub title: String,
    /// Creation timestamp, in UTC.
    pub created_at: DateTime<Utc>,
}
 33
/// Text of a file tagged with a version, as embedded in a message chunk.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FileContent {
    /// Path of the file.
    /// NOTE(review): whether this is absolute or worktree-relative is not established here.
    pub path: PathBuf,
    /// Version tag associated with `content`.
    pub version: FileVersion,
    /// The file's text.
    pub content: SharedString,
}
 40
/// A single message in a thread: who sent it and its ordered sequence of chunks.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Message {
    /// Sender of the message (e.g. `Role::User` or `Role::Assistant`).
    pub role: acp::Role,
    /// Message body, in order.
    pub chunks: Vec<MessageChunk>,
}
 46
 47impl Message {
 48    fn into_acp(self, cx: &App) -> acp::Message {
 49        acp::Message {
 50            role: self.role,
 51            chunks: self
 52                .chunks
 53                .into_iter()
 54                .map(|chunk| chunk.into_acp(cx))
 55                .collect(),
 56        }
 57    }
 58}
 59
/// One piece of a [`Message`] body.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MessageChunk {
    /// Text, held as a Markdown entity so it can be appended to while streaming.
    Text {
        chunk: Entity<Markdown>,
    },
    /// A single file's content.
    File {
        content: FileContent,
    },
    /// A directory and the contents of files within it.
    Directory {
        path: PathBuf,
        contents: Vec<FileContent>,
    },
    /// A named symbol within a file.
    Symbol {
        path: PathBuf,
        /// NOTE(review): units of this range (byte offsets vs. rows) are not established
        /// in this file — confirm against the producer.
        range: Range<u64>,
        version: FileVersion,
        name: SharedString,
        content: SharedString,
    },
    /// Content retrieved from a URL.
    Fetch {
        url: SharedString,
        content: SharedString,
    },
}
 84
 85impl MessageChunk {
 86    pub fn from_acp(
 87        chunk: acp::MessageChunk,
 88        language_registry: Arc<LanguageRegistry>,
 89        cx: &mut App,
 90    ) -> Self {
 91        match chunk {
 92            acp::MessageChunk::Text { chunk } => MessageChunk::Text {
 93                chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
 94            },
 95        }
 96    }
 97
 98    pub fn into_acp(self, cx: &App) -> acp::MessageChunk {
 99        match self {
100            MessageChunk::Text { chunk } => acp::MessageChunk::Text {
101                chunk: chunk.read(cx).source().to_string(),
102            },
103            MessageChunk::File { .. } => todo!(),
104            MessageChunk::Directory { .. } => todo!(),
105            MessageChunk::Symbol { .. } => todo!(),
106            MessageChunk::Fetch { .. } => todo!(),
107        }
108    }
109
110    pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
111        MessageChunk::Text {
112            chunk: cx.new(|cx| {
113                Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
114            }),
115        }
116    }
117}
118
/// Payload of a [`ThreadEntry`]: either a message or a tool call.
#[derive(Debug)]
pub enum AgentThreadEntryContent {
    Message(Message),
    ToolCall(ToolCall),
}
124
/// A tool call entry shown in the thread.
#[derive(Debug)]
pub struct ToolCall {
    /// Id of the thread entry that holds this call.
    id: ToolCallId,
    /// Label, rendered as Markdown.
    label: Entity<Markdown>,
    /// UI icon mapped from the `acp::Icon` supplied when the call was inserted.
    icon: IconName,
    /// Confirmation / execution state of the call.
    status: ToolCallStatus,
}
132
/// Lifecycle of a tool call with respect to confirmation.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The call must be confirmed before it may run.
    WaitingForConfirmation {
        /// What is being asked to be confirmed.
        confirmation: acp::ToolCallConfirmation,
        /// Delivers the outcome to the receiver handed out in `ToolCallRequest::outcome`.
        respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
    },
    /// The call was allowed (or was pushed without needing confirmation).
    Allowed {
        /// Execution status last reported for the call.
        status: acp::ToolCallStatus,
        /// Latest output, if any has been reported via `AcpThread::update_tool_call`.
        content: Option<ToolCallContent>,
    },
    /// The call was rejected; `AcpThread::update_tool_call` refuses to update it.
    Rejected,
}
145
/// Output attached to an allowed tool call.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Markdown text.
    Markdown {
        markdown: Entity<Markdown>,
    },
    /// A diff between an old and a new version of a file's text.
    Diff {
        /// Path of the file; also used to choose the buffer's language.
        path: PathBuf,
        /// Diff of the new text against the old text.
        diff: Entity<BufferDiff>,
        /// Singleton multibuffer over the new text, with `diff` attached.
        buffer: Entity<MultiBuffer>,
        /// Background task that computes the diff and sets the buffer language.
        /// NOTE(review): kept only so the task is not dropped — presumably dropping a gpui
        /// `Task` cancels it; confirm.
        _task: Task<Result<()>>,
    },
}
158
/// A [`ThreadEntryId`] that is known to refer to a [`ToolCall`] entry.
///
/// Created by `AcpThread` when it inserts a tool call; the wrapped id is the id of that
/// thread entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ToolCallId(ThreadEntryId);
162
163impl ToolCallId {
164    pub fn as_u64(&self) -> u64 {
165        self.0.0
166    }
167}
168
/// Sequential identifier of an entry within a single `AcpThread`.
///
/// `AcpThread` hands out ids starting at 0 in insertion order (see `post_inc`), so an id is
/// also the entry's index in the thread's entry list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ThreadEntryId(pub u64);
171
172impl ThreadEntryId {
173    pub fn post_inc(&mut self) -> Self {
174        let id = *self;
175        self.0 += 1;
176        id
177    }
178}
179
/// One entry in a thread's history: its id plus its content.
#[derive(Debug)]
pub struct ThreadEntry {
    /// Sequential id; equals the entry's index in the thread's entry list.
    pub id: ThreadEntryId,
    /// The message or tool call held by this entry.
    pub content: AgentThreadEntryContent,
}
185
/// State of a single agent conversation: its ordered entries plus what is needed to talk to
/// the agent server.
pub struct AcpThread {
    /// Thread id sent along with user messages to `server`.
    id: ThreadId,
    /// Id the next pushed entry will receive.
    next_entry_id: ThreadEntryId,
    /// Entries in insertion order; the entry at index `i` has id `i`.
    entries: Vec<ThreadEntry>,
    /// Server that user messages are sent to.
    server: Arc<AcpServer>,
    /// Title returned by `title()`.
    title: SharedString,
    /// Project supplying the language registry used for Markdown and diffs.
    project: Entity<Project>,
}
194
/// Events emitted by [`AcpThread`] when its entries change.
enum AcpThreadEvent {
    /// An entry was appended to the end of the entry list.
    NewEntry,
    /// The entry at this index in the entry list was modified in place.
    EntryUpdated(usize),
}
199
// Lets observers subscribe (e.g. via `cx.subscribe`) to the thread's `AcpThreadEvent`s.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
201
202impl AcpThread {
203    pub fn new(
204        server: Arc<AcpServer>,
205        thread_id: ThreadId,
206        entries: Vec<AgentThreadEntryContent>,
207        project: Entity<Project>,
208        _: &mut Context<Self>,
209    ) -> Self {
210        let mut next_entry_id = ThreadEntryId(0);
211        Self {
212            title: "A new agent2 thread".into(),
213            entries: entries
214                .into_iter()
215                .map(|entry| ThreadEntry {
216                    id: next_entry_id.post_inc(),
217                    content: entry,
218                })
219                .collect(),
220            server,
221            id: thread_id,
222            next_entry_id,
223            project,
224        }
225    }
226
227    pub fn title(&self) -> SharedString {
228        self.title.clone()
229    }
230
231    pub fn entries(&self) -> &[ThreadEntry] {
232        &self.entries
233    }
234
235    pub fn push_entry(
236        &mut self,
237        entry: AgentThreadEntryContent,
238        cx: &mut Context<Self>,
239    ) -> ThreadEntryId {
240        let id = self.next_entry_id.post_inc();
241        self.entries.push(ThreadEntry { id, content: entry });
242        cx.emit(AcpThreadEvent::NewEntry);
243        id
244    }
245
246    pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
247        let entries_len = self.entries.len();
248        if let Some(last_entry) = self.entries.last_mut()
249            && let AgentThreadEntryContent::Message(Message {
250                ref mut chunks,
251                role: Role::Assistant,
252            }) = last_entry.content
253        {
254            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
255
256            if let (
257                Some(MessageChunk::Text { chunk: old_chunk }),
258                acp::MessageChunk::Text { chunk: new_chunk },
259            ) = (chunks.last_mut(), &chunk)
260            {
261                old_chunk.update(cx, |old_chunk, cx| {
262                    old_chunk.append(&new_chunk, cx);
263                });
264            } else {
265                chunks.push(MessageChunk::from_acp(
266                    chunk,
267                    self.project.read(cx).languages().clone(),
268                    cx,
269                ));
270            }
271
272            return;
273        }
274
275        let chunk = MessageChunk::from_acp(chunk, self.project.read(cx).languages().clone(), cx);
276
277        self.push_entry(
278            AgentThreadEntryContent::Message(Message {
279                role: Role::Assistant,
280                chunks: vec![chunk],
281            }),
282            cx,
283        );
284    }
285
286    pub fn request_tool_call(
287        &mut self,
288        label: String,
289        icon: acp::Icon,
290        confirmation: acp::ToolCallConfirmation,
291        cx: &mut Context<Self>,
292    ) -> ToolCallRequest {
293        let (tx, rx) = oneshot::channel();
294
295        let status = ToolCallStatus::WaitingForConfirmation {
296            confirmation,
297            respond_tx: tx,
298        };
299
300        let id = self.insert_tool_call(label, status, icon, cx);
301        ToolCallRequest { id, outcome: rx }
302    }
303
304    pub fn push_tool_call(
305        &mut self,
306        label: String,
307        icon: acp::Icon,
308        cx: &mut Context<Self>,
309    ) -> ToolCallId {
310        let status = ToolCallStatus::Allowed {
311            status: acp::ToolCallStatus::Running,
312            content: None,
313        };
314
315        self.insert_tool_call(label, status, icon, cx)
316    }
317
318    fn insert_tool_call(
319        &mut self,
320        label: String,
321        status: ToolCallStatus,
322        icon: acp::Icon,
323        cx: &mut Context<Self>,
324    ) -> ToolCallId {
325        let language_registry = self.project.read(cx).languages().clone();
326
327        let entry_id = self.push_entry(
328            AgentThreadEntryContent::ToolCall(ToolCall {
329                // todo! clean up id creation
330                id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
331                label: cx.new(|cx| {
332                    Markdown::new(label.into(), Some(language_registry.clone()), None, cx)
333                }),
334                icon: acp_icon_to_ui_icon(icon),
335                status,
336            }),
337            cx,
338        );
339
340        ToolCallId(entry_id)
341    }
342
343    pub fn authorize_tool_call(
344        &mut self,
345        id: ToolCallId,
346        outcome: acp::ToolCallConfirmationOutcome,
347        cx: &mut Context<Self>,
348    ) {
349        let Some(entry) = self.entry_mut(id.0) else {
350            return;
351        };
352
353        let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
354            debug_panic!("expected ToolCall");
355            return;
356        };
357
358        let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
359            ToolCallStatus::Rejected
360        } else {
361            ToolCallStatus::Allowed {
362                status: acp::ToolCallStatus::Running,
363                content: None,
364            }
365        };
366
367        let curr_status = mem::replace(&mut call.status, new_status);
368
369        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
370            respond_tx.send(outcome).log_err();
371        } else {
372            debug_panic!("tried to authorize an already authorized tool call");
373        }
374
375        cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
376    }
377
378    pub fn update_tool_call(
379        &mut self,
380        id: ToolCallId,
381        new_status: acp::ToolCallStatus,
382        new_content: Option<acp::ToolCallContent>,
383        cx: &mut Context<Self>,
384    ) -> Result<()> {
385        let language_registry = self.project.read(cx).languages().clone();
386        let entry = self.entry_mut(id.0).context("Entry not found")?;
387
388        match &mut entry.content {
389            AgentThreadEntryContent::ToolCall(call) => match &mut call.status {
390                ToolCallStatus::Allowed { content, status } => {
391                    *content = new_content.map(|new_content| match new_content {
392                        acp::ToolCallContent::Markdown { markdown } => ToolCallContent::Markdown {
393                            markdown: cx.new(|cx| {
394                                Markdown::new(
395                                    markdown.into(),
396                                    Some(language_registry.clone()),
397                                    None,
398                                    cx,
399                                )
400                            }),
401                        },
402                        acp::ToolCallContent::Diff {
403                            path,
404                            old_text,
405                            new_text,
406                        } => {
407                            let buffer = cx.new(|cx| Buffer::local(new_text, cx));
408                            let text_snapshot = buffer.read(cx).text_snapshot();
409                            let buffer_diff = cx.new(|cx| BufferDiff::new(&text_snapshot, cx));
410
411                            let multibuffer = cx.new(|cx| {
412                                let mut multibuffer = MultiBuffer::singleton(buffer.clone(), cx);
413                                multibuffer.add_diff(buffer_diff.clone(), cx);
414                                multibuffer
415                            });
416
417                            ToolCallContent::Diff {
418                                path: path.clone(),
419                                diff: buffer_diff.clone(),
420                                buffer: multibuffer,
421                                _task: cx.spawn(async move |_this, cx| {
422                                    let diff_snapshot = BufferDiff::update_diff(
423                                        buffer_diff.clone(),
424                                        text_snapshot.clone(),
425                                        old_text.map(|o| o.into()),
426                                        true,
427                                        true,
428                                        None,
429                                        Some(language_registry.clone()),
430                                        cx,
431                                    )
432                                    .await?;
433
434                                    buffer_diff.update(cx, |diff, cx| {
435                                        diff.set_snapshot(diff_snapshot, &text_snapshot, cx)
436                                    })?;
437
438                                    if let Some(language) = language_registry
439                                        .language_for_file_path(&path)
440                                        .await
441                                        .log_err()
442                                    {
443                                        buffer.update(cx, |buffer, cx| {
444                                            buffer.set_language(Some(language), cx)
445                                        })?;
446                                    }
447
448                                    anyhow::Ok(())
449                                }),
450                            }
451                        }
452                    });
453                    *status = new_status;
454                }
455                ToolCallStatus::WaitingForConfirmation { .. } => {
456                    anyhow::bail!("Tool call hasn't been authorized yet")
457                }
458                ToolCallStatus::Rejected => {
459                    anyhow::bail!("Tool call was rejected and therefore can't be updated")
460                }
461            },
462            _ => anyhow::bail!("Entry is not a tool call"),
463        }
464
465        cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
466        Ok(())
467    }
468
469    fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
470        let entry = self.entries.get_mut(id.0 as usize);
471        debug_assert!(
472            entry.is_some(),
473            "We shouldn't give out ids to entries that don't exist"
474        );
475        entry
476    }
477
478    /// Returns true if the last turn is awaiting tool authorization
479    pub fn waiting_for_tool_confirmation(&self) -> bool {
480        for entry in self.entries.iter().rev() {
481            match &entry.content {
482                AgentThreadEntryContent::ToolCall(call) => match call.status {
483                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
484                    ToolCallStatus::Allowed { .. } | ToolCallStatus::Rejected => continue,
485                },
486                AgentThreadEntryContent::Message(_) => {
487                    // Reached the beginning of the turn
488                    return false;
489                }
490            }
491        }
492        false
493    }
494
495    pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
496        let agent = self.server.clone();
497        let id = self.id.clone();
498        let chunk = MessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
499        let message = Message {
500            role: Role::User,
501            chunks: vec![chunk],
502        };
503        self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
504        let acp_message = message.into_acp(cx);
505        cx.spawn(async move |_, cx| {
506            agent.send_message(id, acp_message, cx).await?;
507            Ok(())
508        })
509    }
510}
511
512fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
513    match icon {
514        acp::Icon::FileSearch => IconName::FileSearch,
515        acp::Icon::Folder => IconName::Folder,
516        acp::Icon::Globe => IconName::Globe,
517        acp::Icon::Hammer => IconName::Hammer,
518        acp::Icon::LightBulb => IconName::LightBulb,
519        acp::Icon::Pencil => IconName::Pencil,
520        acp::Icon::Regex => IconName::Regex,
521        acp::Icon::Terminal => IconName::Terminal,
522    }
523}
524
/// Handle returned by `AcpThread::request_tool_call`.
pub struct ToolCallRequest {
    /// Id of the newly inserted tool-call entry.
    pub id: ToolCallId,
    /// Receives the outcome passed to `AcpThread::authorize_tool_call` for `id`.
    pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
}
529
#[cfg(test)]
mod tests {
    // Integration tests that drive a real `gemini-cli` process (launched through `node`) over
    // ACP. They expect a `gemini-cli` checkout at `../../../gemini-cli/packages/cli` relative
    // to this crate's manifest directory.

    use super::*;
    use futures::{FutureExt as _, channel::mpsc, select};
    use gpui::{AsyncApp, TestAppContext};
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt as _;
    use std::{env, path::Path, process::Stdio, time::Duration};
    use util::path;

    /// Installs the test settings store and initializes project/language settings.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    #[gpui::test]
    async fn test_gemini_basic(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.entries.len(), 2);
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::User,
                    ..
                })
            ));
            assert!(matches!(
                thread.entries[1].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    #[gpui::test]
    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    "Read the '/private/tmp/foo' file and tell me what you see.",
                    cx,
                )
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _cx| {
            assert!(matches!(
                &thread.entries()[1].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));

            assert!(matches!(
                thread.entries[2].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    #[gpui::test]
    async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
        });

        run_until_tool_call(&thread, cx).await;

        let tool_call_id = thread.read_with(cx, |thread, _cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                id,
                status:
                    ToolCallStatus::WaitingForConfirmation {
                        confirmation: acp::ToolCallConfirmation::Execute { root_command, .. },
                        ..
                    },
                ..
            }) = &thread.entries()[1].content
            else {
                panic!();
            };

            assert_eq!(root_command, "echo");

            *id
        });

        thread.update(cx, |thread, cx| {
            thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);

            assert!(matches!(
                &thread.entries()[1].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));
        });

        full_turn.await.unwrap();

        thread.read_with(cx, |thread, cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                status:
                    ToolCallStatus::Allowed {
                        content: Some(ToolCallContent::Markdown { markdown }),
                        ..
                    },
                ..
            }) = &thread.entries()[1].content
            else {
                panic!();
            };

            markdown.read_with(cx, |md, _cx| {
                assert!(
                    md.source().contains("Hello, world!"),
                    r#"Expected '{}' to contain "Hello, world!""#,
                    md.source()
                );
            });
        });
    }

    /// Waits until the thread contains at least one tool-call entry.
    ///
    /// # Panics
    /// Panics if no tool call appears within 10 seconds.
    async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
        let (mut tx, mut rx) = mpsc::channel::<()>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                if thread
                    .read(cx)
                    .entries
                    .iter()
                    .any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
                {
                    // One signal is enough to wake the waiter. Every later event that still
                    // sees a tool call sends again and can find the bounded channel full;
                    // ignore that instead of panicking inside the subscription callback.
                    tx.try_send(()).ok();
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            _ = rx.next().fuse() => {
                drop(subscription);
            }
        }
    }

    /// Spawns `gemini-cli` in ACP mode via `node` and connects an [`AcpServer`] to its stdio.
    ///
    /// The process runs in `/private/tmp`; `GEMINI_API_KEY` is forwarded when it is set.
    ///
    /// # Errors
    /// Returns an error if spawning `node` fails or if `cx.update` fails.
    pub fn gemini_acp_server(project: Entity<Project>, cx: AsyncApp) -> Result<Arc<AcpServer>> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let mut command = util::command::new_smol_command("node");
        command
            .arg(cli_path)
            .arg("--acp")
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);

        if let Ok(gemini_key) = env::var("GEMINI_API_KEY") {
            command.env("GEMINI_API_KEY", gemini_key);
        }

        // The function already returns `Result`, so surface spawn failures to the caller
        // instead of panicking here.
        let child = command.spawn()?;

        cx.update(|cx| AcpServer::stdio(child, project, cx))
    }
}