acp.rs

  1mod server;
  2mod thread_view;
  3
  4use agentic_coding_protocol::{self as acp, Role};
  5use anyhow::{Context as _, Result};
  6use buffer_diff::BufferDiff;
  7use chrono::{DateTime, Utc};
  8use editor::MultiBuffer;
  9use futures::channel::oneshot;
 10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
 11use language::{Buffer, LanguageRegistry};
 12use markdown::Markdown;
 13use project::Project;
 14use std::{mem, ops::Range, path::PathBuf, sync::Arc};
 15use ui::{App, IconName};
 16use util::{ResultExt, debug_panic};
 17
 18pub use server::AcpServer;
 19pub use thread_view::AcpThreadView;
 20
 21#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 22pub struct ThreadId(SharedString);
 23
 24#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 25pub struct FileVersion(u64);
 26
 27#[derive(Debug)]
 28pub struct AgentThreadSummary {
 29    pub id: ThreadId,
 30    pub title: String,
 31    pub created_at: DateTime<Utc>,
 32}
 33
 34#[derive(Clone, Debug, PartialEq, Eq)]
 35pub struct FileContent {
 36    pub path: PathBuf,
 37    pub version: FileVersion,
 38    pub content: SharedString,
 39}
 40
 41#[derive(Clone, Debug, Eq, PartialEq)]
 42pub struct Message {
 43    pub role: acp::Role,
 44    pub chunks: Vec<MessageChunk>,
 45}
 46
 47impl Message {
 48    fn into_acp(self, cx: &App) -> acp::Message {
 49        acp::Message {
 50            role: self.role,
 51            chunks: self
 52                .chunks
 53                .into_iter()
 54                .map(|chunk| chunk.into_acp(cx))
 55                .collect(),
 56        }
 57    }
 58}
 59
 60#[derive(Clone, Debug, Eq, PartialEq)]
 61pub enum MessageChunk {
 62    Text {
 63        chunk: Entity<Markdown>,
 64    },
 65    File {
 66        content: FileContent,
 67    },
 68    Directory {
 69        path: PathBuf,
 70        contents: Vec<FileContent>,
 71    },
 72    Symbol {
 73        path: PathBuf,
 74        range: Range<u64>,
 75        version: FileVersion,
 76        name: SharedString,
 77        content: SharedString,
 78    },
 79    Fetch {
 80        url: SharedString,
 81        content: SharedString,
 82    },
 83}
 84
 85impl MessageChunk {
 86    pub fn from_acp(
 87        chunk: acp::MessageChunk,
 88        language_registry: Arc<LanguageRegistry>,
 89        cx: &mut App,
 90    ) -> Self {
 91        match chunk {
 92            acp::MessageChunk::Text { chunk } => MessageChunk::Text {
 93                chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
 94            },
 95        }
 96    }
 97
 98    pub fn into_acp(self, cx: &App) -> acp::MessageChunk {
 99        match self {
100            MessageChunk::Text { chunk } => acp::MessageChunk::Text {
101                chunk: chunk.read(cx).source().to_string(),
102            },
103            MessageChunk::File { .. } => todo!(),
104            MessageChunk::Directory { .. } => todo!(),
105            MessageChunk::Symbol { .. } => todo!(),
106            MessageChunk::Fetch { .. } => todo!(),
107        }
108    }
109
110    pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
111        MessageChunk::Text {
112            chunk: cx.new(|cx| {
113                Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
114            }),
115        }
116    }
117}
118
119#[derive(Debug)]
120pub enum AgentThreadEntryContent {
121    Message(Message),
122    ToolCall(ToolCall),
123}
124
125#[derive(Debug)]
126pub struct ToolCall {
127    id: ToolCallId,
128    label: Entity<Markdown>,
129    icon: IconName,
130    status: ToolCallStatus,
131}
132
133#[derive(Debug)]
134pub enum ToolCallStatus {
135    WaitingForConfirmation {
136        confirmation: acp::ToolCallConfirmation,
137        respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
138    },
139    // todo! Running?
140    Allowed {
141        // todo! should this be variants in crate::ToolCallStatus instead?
142        status: acp::ToolCallStatus,
143        content: Option<ToolCallContent>,
144    },
145    Rejected,
146}
147
148#[derive(Debug)]
149pub enum ToolCallContent {
150    Markdown {
151        markdown: Entity<Markdown>,
152    },
153    Diff {
154        path: PathBuf,
155        diff: Entity<BufferDiff>,
156        buffer: Entity<MultiBuffer>,
157        _task: Task<Result<()>>,
158    },
159}
160
161/// A `ThreadEntryId` that is known to be a ToolCall
162#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
163pub struct ToolCallId(ThreadEntryId);
164
165impl ToolCallId {
166    pub fn as_u64(&self) -> u64 {
167        self.0.0
168    }
169}
170
/// Sequential identifier of an entry within a thread.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ThreadEntryId(pub u64);

impl ThreadEntryId {
    /// Postfix increment: returns the current id and advances `self` by one.
    pub fn post_inc(&mut self) -> Self {
        let next = ThreadEntryId(self.0 + 1);
        std::mem::replace(self, next)
    }
}
181
182#[derive(Debug)]
183pub struct ThreadEntry {
184    pub id: ThreadEntryId,
185    pub content: AgentThreadEntryContent,
186}
187
188pub struct AcpThread {
189    id: ThreadId,
190    next_entry_id: ThreadEntryId,
191    entries: Vec<ThreadEntry>,
192    server: Arc<AcpServer>,
193    title: SharedString,
194    project: Entity<Project>,
195}
196
197enum AcpThreadEvent {
198    NewEntry,
199    EntryUpdated(usize),
200}
201
202impl EventEmitter<AcpThreadEvent> for AcpThread {}
203
204impl AcpThread {
205    pub fn new(
206        server: Arc<AcpServer>,
207        thread_id: ThreadId,
208        entries: Vec<AgentThreadEntryContent>,
209        project: Entity<Project>,
210        _: &mut Context<Self>,
211    ) -> Self {
212        let mut next_entry_id = ThreadEntryId(0);
213        Self {
214            title: "A new agent2 thread".into(),
215            entries: entries
216                .into_iter()
217                .map(|entry| ThreadEntry {
218                    id: next_entry_id.post_inc(),
219                    content: entry,
220                })
221                .collect(),
222            server,
223            id: thread_id,
224            next_entry_id,
225            project,
226        }
227    }
228
229    pub fn title(&self) -> SharedString {
230        self.title.clone()
231    }
232
233    pub fn entries(&self) -> &[ThreadEntry] {
234        &self.entries
235    }
236
237    pub fn push_entry(
238        &mut self,
239        entry: AgentThreadEntryContent,
240        cx: &mut Context<Self>,
241    ) -> ThreadEntryId {
242        let id = self.next_entry_id.post_inc();
243        self.entries.push(ThreadEntry { id, content: entry });
244        cx.emit(AcpThreadEvent::NewEntry);
245        id
246    }
247
248    pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
249        let entries_len = self.entries.len();
250        if let Some(last_entry) = self.entries.last_mut()
251            && let AgentThreadEntryContent::Message(Message {
252                ref mut chunks,
253                role: Role::Assistant,
254            }) = last_entry.content
255        {
256            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
257
258            if let (
259                Some(MessageChunk::Text { chunk: old_chunk }),
260                acp::MessageChunk::Text { chunk: new_chunk },
261            ) = (chunks.last_mut(), &chunk)
262            {
263                old_chunk.update(cx, |old_chunk, cx| {
264                    old_chunk.append(&new_chunk, cx);
265                });
266            } else {
267                chunks.push(MessageChunk::from_acp(
268                    chunk,
269                    self.project.read(cx).languages().clone(),
270                    cx,
271                ));
272            }
273
274            return;
275        }
276
277        let chunk = MessageChunk::from_acp(chunk, self.project.read(cx).languages().clone(), cx);
278
279        self.push_entry(
280            AgentThreadEntryContent::Message(Message {
281                role: Role::Assistant,
282                chunks: vec![chunk],
283            }),
284            cx,
285        );
286    }
287
288    pub fn request_tool_call(
289        &mut self,
290        label: String,
291        icon: acp::Icon,
292        confirmation: acp::ToolCallConfirmation,
293        cx: &mut Context<Self>,
294    ) -> ToolCallRequest {
295        let (tx, rx) = oneshot::channel();
296
297        let status = ToolCallStatus::WaitingForConfirmation {
298            confirmation,
299            respond_tx: tx,
300        };
301
302        let id = self.insert_tool_call(label, status, icon, cx);
303        ToolCallRequest { id, outcome: rx }
304    }
305
306    pub fn push_tool_call(
307        &mut self,
308        label: String,
309        icon: acp::Icon,
310        cx: &mut Context<Self>,
311    ) -> ToolCallId {
312        let status = ToolCallStatus::Allowed {
313            status: acp::ToolCallStatus::Running,
314            content: None,
315        };
316
317        self.insert_tool_call(label, status, icon, cx)
318    }
319
320    fn insert_tool_call(
321        &mut self,
322        label: String,
323        status: ToolCallStatus,
324        icon: acp::Icon,
325        cx: &mut Context<Self>,
326    ) -> ToolCallId {
327        let language_registry = self.project.read(cx).languages().clone();
328
329        let entry_id = self.push_entry(
330            AgentThreadEntryContent::ToolCall(ToolCall {
331                // todo! clean up id creation
332                id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
333                label: cx.new(|cx| {
334                    Markdown::new(label.into(), Some(language_registry.clone()), None, cx)
335                }),
336                icon: acp_icon_to_ui_icon(icon),
337                status,
338            }),
339            cx,
340        );
341
342        ToolCallId(entry_id)
343    }
344
345    pub fn authorize_tool_call(
346        &mut self,
347        id: ToolCallId,
348        outcome: acp::ToolCallConfirmationOutcome,
349        cx: &mut Context<Self>,
350    ) {
351        let Some(entry) = self.entry_mut(id.0) else {
352            return;
353        };
354
355        let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
356            debug_panic!("expected ToolCall");
357            return;
358        };
359
360        let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
361            ToolCallStatus::Rejected
362        } else {
363            ToolCallStatus::Allowed {
364                status: acp::ToolCallStatus::Running,
365                content: None,
366            }
367        };
368
369        let curr_status = mem::replace(&mut call.status, new_status);
370
371        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
372            respond_tx.send(outcome).log_err();
373        } else {
374            debug_panic!("tried to authorize an already authorized tool call");
375        }
376
377        cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
378    }
379
380    pub fn update_tool_call(
381        &mut self,
382        id: ToolCallId,
383        new_status: acp::ToolCallStatus,
384        new_content: Option<acp::ToolCallContent>,
385        cx: &mut Context<Self>,
386    ) -> Result<()> {
387        let language_registry = self.project.read(cx).languages().clone();
388        let entry = self.entry_mut(id.0).context("Entry not found")?;
389
390        match &mut entry.content {
391            AgentThreadEntryContent::ToolCall(call) => match &mut call.status {
392                ToolCallStatus::Allowed { content, status } => {
393                    *content = new_content.map(|new_content| match new_content {
394                        acp::ToolCallContent::Markdown { markdown } => ToolCallContent::Markdown {
395                            markdown: cx.new(|cx| {
396                                Markdown::new(
397                                    markdown.into(),
398                                    Some(language_registry.clone()),
399                                    None,
400                                    cx,
401                                )
402                            }),
403                        },
404                        acp::ToolCallContent::Diff {
405                            path,
406                            old_text,
407                            new_text,
408                        } => {
409                            let buffer = cx.new(|cx| Buffer::local(new_text, cx));
410                            let text_snapshot = buffer.read(cx).text_snapshot();
411                            let buffer_diff = cx.new(|cx| BufferDiff::new(&text_snapshot, cx));
412
413                            let multibuffer = cx.new(|cx| {
414                                let mut multibuffer = MultiBuffer::singleton(buffer.clone(), cx);
415                                multibuffer.add_diff(buffer_diff.clone(), cx);
416                                multibuffer
417                            });
418
419                            ToolCallContent::Diff {
420                                path: path.clone(),
421                                diff: buffer_diff.clone(),
422                                buffer: multibuffer,
423                                _task: cx.spawn(async move |_this, cx| {
424                                    let diff_snapshot = BufferDiff::update_diff(
425                                        buffer_diff.clone(),
426                                        text_snapshot.clone(),
427                                        old_text.map(|o| o.into()),
428                                        true,
429                                        true,
430                                        None,
431                                        Some(language_registry.clone()),
432                                        cx,
433                                    )
434                                    .await?;
435
436                                    buffer_diff.update(cx, |diff, cx| {
437                                        diff.set_snapshot(diff_snapshot, &text_snapshot, cx)
438                                    })?;
439
440                                    if let Some(language) = language_registry
441                                        .language_for_file_path(&path)
442                                        .await
443                                        .log_err()
444                                    {
445                                        buffer.update(cx, |buffer, cx| {
446                                            buffer.set_language(Some(language), cx)
447                                        })?;
448                                    }
449
450                                    anyhow::Ok(())
451                                }),
452                            }
453                        }
454                    });
455                    *status = new_status;
456                }
457                ToolCallStatus::WaitingForConfirmation { .. } => {
458                    anyhow::bail!("Tool call hasn't been authorized yet")
459                }
460                ToolCallStatus::Rejected => {
461                    anyhow::bail!("Tool call was rejected and therefore can't be updated")
462                }
463            },
464            _ => anyhow::bail!("Entry is not a tool call"),
465        }
466
467        cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
468        Ok(())
469    }
470
471    fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
472        let entry = self.entries.get_mut(id.0 as usize);
473        debug_assert!(
474            entry.is_some(),
475            "We shouldn't give out ids to entries that don't exist"
476        );
477        entry
478    }
479
480    /// Returns true if the last turn is awaiting tool authorization
481    pub fn waiting_for_tool_confirmation(&self) -> bool {
482        for entry in self.entries.iter().rev() {
483            match &entry.content {
484                AgentThreadEntryContent::ToolCall(call) => match call.status {
485                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
486                    ToolCallStatus::Allowed { .. } | ToolCallStatus::Rejected => continue,
487                },
488                AgentThreadEntryContent::Message(_) => {
489                    // Reached the beginning of the turn
490                    return false;
491                }
492            }
493        }
494        false
495    }
496
497    pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
498        let agent = self.server.clone();
499        let id = self.id.clone();
500        let chunk = MessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
501        let message = Message {
502            role: Role::User,
503            chunks: vec![chunk],
504        };
505        self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
506        let acp_message = message.into_acp(cx);
507        cx.spawn(async move |_, cx| {
508            agent.send_message(id, acp_message, cx).await?;
509            Ok(())
510        })
511    }
512}
513
514fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
515    match icon {
516        acp::Icon::FileSearch => IconName::FileSearch,
517        acp::Icon::Folder => IconName::Folder,
518        acp::Icon::Globe => IconName::Globe,
519        acp::Icon::Hammer => IconName::Hammer,
520        acp::Icon::LightBulb => IconName::LightBulb,
521        acp::Icon::Pencil => IconName::Pencil,
522        acp::Icon::Regex => IconName::Regex,
523        acp::Icon::Terminal => IconName::Terminal,
524    }
525}
526
527pub struct ToolCallRequest {
528    pub id: ToolCallId,
529    pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
530}
531
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{FutureExt as _, channel::mpsc, select};
    use gpui::{AsyncApp, TestAppContext};
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt as _;
    use std::{path::Path, process::Stdio, time::Duration};
    use util::path;

    /// Installs logging, test settings, and language support shared by all tests.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    #[gpui::test]
    async fn test_gemini_basic(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.entries.len(), 2);
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::User,
                    ..
                })
            ));
            assert!(matches!(
                thread.entries[1].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    #[gpui::test]
    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    "Read the '/private/tmp/foo' file and tell me what you see.",
                    cx,
                )
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _cx| {
            assert!(matches!(
                &thread.entries()[1].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));

            assert!(matches!(
                thread.entries[2].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    #[gpui::test]
    async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
        });

        run_until_tool_call(&thread, cx).await;

        let tool_call_id = thread.read_with(cx, |thread, _cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                id,
                status:
                    ToolCallStatus::WaitingForConfirmation {
                        confirmation: acp::ToolCallConfirmation::Execute { root_command, .. },
                        ..
                    },
                ..
            }) = &thread.entries()[1].content
            else {
                panic!();
            };

            assert_eq!(root_command, "echo");

            *id
        });

        thread.update(cx, |thread, cx| {
            thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);

            assert!(matches!(
                &thread.entries()[1].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));
        });

        full_turn.await.unwrap();

        thread.read_with(cx, |thread, cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                status:
                    ToolCallStatus::Allowed {
                        content: Some(ToolCallContent::Markdown { markdown }),
                        ..
                    },
                ..
            }) = &thread.entries()[1].content
            else {
                panic!();
            };

            markdown.read_with(cx, |md, _cx| {
                assert!(
                    md.source().contains("Hello, world!"),
                    r#"Expected '{}' to contain "Hello, world!""#,
                    md.source()
                );
            });
        });
    }

    /// Waits (up to 10 seconds) until `thread` contains at least one tool call entry.
    async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
        let (mut tx, mut rx) = mpsc::channel::<()>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                if thread
                    .read(cx)
                    .entries
                    .iter()
                    .any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
                {
                    // Every later event re-sends while a tool call exists, so the
                    // bounded channel may already be full (or the receiver gone).
                    // One signal is enough, so send errors are deliberately ignored.
                    tx.try_send(()).ok();
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            _ = rx.next().fuse() => {
                drop(subscription);
            }
        }
    }

    /// Spawns the gemini CLI (expected in a sibling `gemini-cli` checkout) in ACP
    /// mode and connects an `AcpServer` to its stdio.
    pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let mut command = util::command::new_smol_command("node");
        command
            .arg(cli_path)
            .arg("--acp")
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);

        if let Ok(gemini_key) = std::env::var("GEMINI_API_KEY") {
            command.env("GEMINI_API_KEY", gemini_key);
        }

        let child = command.spawn().unwrap();

        Ok(AcpServer::stdio(child, project, &mut cx))
    }
}