acp.rs

  1mod server;
  2mod thread_view;
  3
  4use agentic_coding_protocol::{self as acp, Role};
  5use anyhow::{Context as _, Result};
  6use chrono::{DateTime, Utc};
  7use futures::channel::oneshot;
  8use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  9use language::LanguageRegistry;
 10use markdown::Markdown;
 11use project::Project;
 12use std::{mem, ops::Range, path::PathBuf, sync::Arc};
 13use ui::{App, IconName};
 14use util::{ResultExt, debug_panic};
 15
 16pub use server::AcpServer;
 17pub use thread_view::AcpThreadView;
 18
 19#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 20pub struct ThreadId(SharedString);
 21
 22#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 23pub struct FileVersion(u64);
 24
 25#[derive(Debug)]
 26pub struct AgentThreadSummary {
 27    pub id: ThreadId,
 28    pub title: String,
 29    pub created_at: DateTime<Utc>,
 30}
 31
 32#[derive(Clone, Debug, PartialEq, Eq)]
 33pub struct FileContent {
 34    pub path: PathBuf,
 35    pub version: FileVersion,
 36    pub content: SharedString,
 37}
 38
 39#[derive(Clone, Debug, Eq, PartialEq)]
 40pub struct Message {
 41    pub role: acp::Role,
 42    pub chunks: Vec<MessageChunk>,
 43}
 44
 45impl Message {
 46    fn into_acp(self, cx: &App) -> acp::Message {
 47        acp::Message {
 48            role: self.role,
 49            chunks: self
 50                .chunks
 51                .into_iter()
 52                .map(|chunk| chunk.into_acp(cx))
 53                .collect(),
 54        }
 55    }
 56}
 57
 58#[derive(Clone, Debug, Eq, PartialEq)]
 59pub enum MessageChunk {
 60    Text {
 61        chunk: Entity<Markdown>,
 62    },
 63    File {
 64        content: FileContent,
 65    },
 66    Directory {
 67        path: PathBuf,
 68        contents: Vec<FileContent>,
 69    },
 70    Symbol {
 71        path: PathBuf,
 72        range: Range<u64>,
 73        version: FileVersion,
 74        name: SharedString,
 75        content: SharedString,
 76    },
 77    Fetch {
 78        url: SharedString,
 79        content: SharedString,
 80    },
 81}
 82
 83impl MessageChunk {
 84    pub fn from_acp(
 85        chunk: acp::MessageChunk,
 86        language_registry: Arc<LanguageRegistry>,
 87        cx: &mut App,
 88    ) -> Self {
 89        match chunk {
 90            acp::MessageChunk::Text { chunk } => MessageChunk::Text {
 91                chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
 92            },
 93        }
 94    }
 95
 96    pub fn into_acp(self, cx: &App) -> acp::MessageChunk {
 97        match self {
 98            MessageChunk::Text { chunk } => acp::MessageChunk::Text {
 99                chunk: chunk.read(cx).source().to_string(),
100            },
101            MessageChunk::File { .. } => todo!(),
102            MessageChunk::Directory { .. } => todo!(),
103            MessageChunk::Symbol { .. } => todo!(),
104            MessageChunk::Fetch { .. } => todo!(),
105        }
106    }
107
108    pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
109        MessageChunk::Text {
110            chunk: cx.new(|cx| {
111                Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
112            }),
113        }
114    }
115}
116
117#[derive(Debug)]
118pub enum AgentThreadEntryContent {
119    Message(Message),
120    ToolCall(ToolCall),
121}
122
123#[derive(Debug)]
124pub struct ToolCall {
125    id: ToolCallId,
126    label: Entity<Markdown>,
127    icon: IconName,
128    status: ToolCallStatus,
129}
130
131#[derive(Debug)]
132pub enum ToolCallStatus {
133    WaitingForConfirmation {
134        confirmation: acp::ToolCallConfirmation,
135        respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
136    },
137    // todo! Running?
138    Allowed {
139        // todo! should this be variants in crate::ToolCallStatus instead?
140        status: acp::ToolCallStatus,
141        content: Option<Entity<Markdown>>,
142    },
143    Rejected,
144}
145
146/// A `ThreadEntryId` that is known to be a ToolCall
147#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
148pub struct ToolCallId(ThreadEntryId);
149
150impl ToolCallId {
151    pub fn as_u64(&self) -> u64 {
152        self.0.0
153    }
154}
155
/// Identifier of an entry within a thread, handed out sequentially.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct ThreadEntryId(pub u64);

impl ThreadEntryId {
    /// Advances this id by one and returns the value it held before the increment.
    pub fn post_inc(&mut self) -> Self {
        let next = ThreadEntryId(self.0 + 1);
        std::mem::replace(self, next)
    }
}
166
167#[derive(Debug)]
168pub struct ThreadEntry {
169    pub id: ThreadEntryId,
170    pub content: AgentThreadEntryContent,
171}
172
173pub struct AcpThread {
174    id: ThreadId,
175    next_entry_id: ThreadEntryId,
176    entries: Vec<ThreadEntry>,
177    server: Arc<AcpServer>,
178    title: SharedString,
179    project: Entity<Project>,
180}
181
182enum AcpThreadEvent {
183    NewEntry,
184    EntryUpdated(usize),
185}
186
187impl EventEmitter<AcpThreadEvent> for AcpThread {}
188
189impl AcpThread {
190    pub fn new(
191        server: Arc<AcpServer>,
192        thread_id: ThreadId,
193        entries: Vec<AgentThreadEntryContent>,
194        project: Entity<Project>,
195        _: &mut Context<Self>,
196    ) -> Self {
197        let mut next_entry_id = ThreadEntryId(0);
198        Self {
199            title: "A new agent2 thread".into(),
200            entries: entries
201                .into_iter()
202                .map(|entry| ThreadEntry {
203                    id: next_entry_id.post_inc(),
204                    content: entry,
205                })
206                .collect(),
207            server,
208            id: thread_id,
209            next_entry_id,
210            project,
211        }
212    }
213
214    pub fn title(&self) -> SharedString {
215        self.title.clone()
216    }
217
218    pub fn entries(&self) -> &[ThreadEntry] {
219        &self.entries
220    }
221
222    pub fn push_entry(
223        &mut self,
224        entry: AgentThreadEntryContent,
225        cx: &mut Context<Self>,
226    ) -> ThreadEntryId {
227        let id = self.next_entry_id.post_inc();
228        self.entries.push(ThreadEntry { id, content: entry });
229        cx.emit(AcpThreadEvent::NewEntry);
230        id
231    }
232
233    pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
234        let entries_len = self.entries.len();
235        if let Some(last_entry) = self.entries.last_mut()
236            && let AgentThreadEntryContent::Message(Message {
237                ref mut chunks,
238                role: Role::Assistant,
239            }) = last_entry.content
240        {
241            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
242
243            if let (
244                Some(MessageChunk::Text { chunk: old_chunk }),
245                acp::MessageChunk::Text { chunk: new_chunk },
246            ) = (chunks.last_mut(), &chunk)
247            {
248                old_chunk.update(cx, |old_chunk, cx| {
249                    old_chunk.append(&new_chunk, cx);
250                });
251            } else {
252                chunks.push(MessageChunk::from_acp(
253                    chunk,
254                    self.project.read(cx).languages().clone(),
255                    cx,
256                ));
257            }
258
259            return;
260        }
261
262        let chunk = MessageChunk::from_acp(chunk, self.project.read(cx).languages().clone(), cx);
263
264        self.push_entry(
265            AgentThreadEntryContent::Message(Message {
266                role: Role::Assistant,
267                chunks: vec![chunk],
268            }),
269            cx,
270        );
271    }
272
273    pub fn request_tool_call(
274        &mut self,
275        label: String,
276        icon: acp::Icon,
277        confirmation: acp::ToolCallConfirmation,
278        cx: &mut Context<Self>,
279    ) -> ToolCallRequest {
280        let (tx, rx) = oneshot::channel();
281
282        let status = ToolCallStatus::WaitingForConfirmation {
283            confirmation,
284            respond_tx: tx,
285        };
286
287        let id = self.insert_tool_call(label, status, icon, cx);
288        ToolCallRequest { id, outcome: rx }
289    }
290
291    pub fn push_tool_call(
292        &mut self,
293        label: String,
294        icon: acp::Icon,
295        cx: &mut Context<Self>,
296    ) -> ToolCallId {
297        let status = ToolCallStatus::Allowed {
298            status: acp::ToolCallStatus::Running,
299            content: None,
300        };
301
302        self.insert_tool_call(label, status, icon, cx)
303    }
304
305    fn insert_tool_call(
306        &mut self,
307        label: String,
308        status: ToolCallStatus,
309        icon: acp::Icon,
310        cx: &mut Context<Self>,
311    ) -> ToolCallId {
312        let language_registry = self.project.read(cx).languages().clone();
313
314        let entry_id = self.push_entry(
315            AgentThreadEntryContent::ToolCall(ToolCall {
316                // todo! clean up id creation
317                id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
318                label: cx.new(|cx| {
319                    Markdown::new(label.into(), Some(language_registry.clone()), None, cx)
320                }),
321                icon: acp_icon_to_ui_icon(icon),
322                status,
323            }),
324            cx,
325        );
326
327        ToolCallId(entry_id)
328    }
329
330    pub fn authorize_tool_call(
331        &mut self,
332        id: ToolCallId,
333        outcome: acp::ToolCallConfirmationOutcome,
334        cx: &mut Context<Self>,
335    ) {
336        let Some(entry) = self.entry_mut(id.0) else {
337            return;
338        };
339
340        let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
341            debug_panic!("expected ToolCall");
342            return;
343        };
344
345        let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
346            ToolCallStatus::Rejected
347        } else {
348            ToolCallStatus::Allowed {
349                status: acp::ToolCallStatus::Running,
350                content: None,
351            }
352        };
353
354        let curr_status = mem::replace(&mut call.status, new_status);
355
356        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
357            respond_tx.send(outcome).log_err();
358        } else {
359            debug_panic!("tried to authorize an already authorized tool call");
360        }
361
362        cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
363    }
364
365    pub fn update_tool_call(
366        &mut self,
367        id: ToolCallId,
368        new_status: acp::ToolCallStatus,
369        new_content: Option<acp::ToolCallContent>,
370        cx: &mut Context<Self>,
371    ) -> Result<()> {
372        let language_registry = self.project.read(cx).languages().clone();
373        let entry = self.entry_mut(id.0).context("Entry not found")?;
374
375        match &mut entry.content {
376            AgentThreadEntryContent::ToolCall(call) => match &mut call.status {
377                ToolCallStatus::Allowed { content, status } => {
378                    *content = new_content.map(|new_content| {
379                        let acp::ToolCallContent::Markdown { markdown } = new_content;
380
381                        cx.new(|cx| {
382                            Markdown::new(markdown.into(), Some(language_registry), None, cx)
383                        })
384                    });
385
386                    *status = new_status;
387                }
388                ToolCallStatus::WaitingForConfirmation { .. } => {
389                    anyhow::bail!("Tool call hasn't been authorized yet")
390                }
391                ToolCallStatus::Rejected => {
392                    anyhow::bail!("Tool call was rejected and therefore can't be updated")
393                }
394            },
395            _ => anyhow::bail!("Entry is not a tool call"),
396        }
397
398        cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
399        Ok(())
400    }
401
402    fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
403        let entry = self.entries.get_mut(id.0 as usize);
404        debug_assert!(
405            entry.is_some(),
406            "We shouldn't give out ids to entries that don't exist"
407        );
408        entry
409    }
410
411    /// Returns true if the last turn is awaiting tool authorization
412    pub fn waiting_for_tool_confirmation(&self) -> bool {
413        for entry in self.entries.iter().rev() {
414            match &entry.content {
415                AgentThreadEntryContent::ToolCall(call) => match call.status {
416                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
417                    ToolCallStatus::Allowed { .. } | ToolCallStatus::Rejected => continue,
418                },
419                AgentThreadEntryContent::Message(_) => {
420                    // Reached the beginning of the turn
421                    return false;
422                }
423            }
424        }
425        false
426    }
427
428    pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
429        let agent = self.server.clone();
430        let id = self.id.clone();
431        let chunk = MessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
432        let message = Message {
433            role: Role::User,
434            chunks: vec![chunk],
435        };
436        self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
437        let acp_message = message.into_acp(cx);
438        cx.spawn(async move |_, cx| {
439            agent.send_message(id, acp_message, cx).await?;
440            Ok(())
441        })
442    }
443}
444
445fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
446    match icon {
447        acp::Icon::FileSearch => IconName::FileSearch,
448        acp::Icon::Folder => IconName::Folder,
449        acp::Icon::Globe => IconName::Globe,
450        acp::Icon::Hammer => IconName::Hammer,
451        acp::Icon::LightBulb => IconName::LightBulb,
452        acp::Icon::Pencil => IconName::Pencil,
453        acp::Icon::Regex => IconName::Regex,
454        acp::Icon::Terminal => IconName::Terminal,
455    }
456}
457
458pub struct ToolCallRequest {
459    pub id: ToolCallId,
460    pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
461}
462
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{FutureExt as _, channel::mpsc, select};
    use gpui::{AsyncApp, TestAppContext};
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt as _;
    use std::{env, path::Path, process::Stdio, time::Duration};
    use util::path;

    /// Installs the global settings/language state the tests depend on.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    #[gpui::test]
    async fn test_gemini_basic(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.entries.len(), 2);
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::User,
                    ..
                })
            ));
            assert!(matches!(
                thread.entries[1].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    #[gpui::test]
    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    "Read the '/private/tmp/foo' file and tell me what you see.",
                    cx,
                )
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _cx| {
            assert!(matches!(
                &thread.entries()[1].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));

            assert!(matches!(
                thread.entries[2].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    #[gpui::test]
    async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
        });

        run_until_tool_call(&thread, cx).await;

        let tool_call_id = thread.read_with(cx, |thread, _cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                id,
                status:
                    ToolCallStatus::WaitingForConfirmation {
                        confirmation: acp::ToolCallConfirmation::Execute { root_command, .. },
                        ..
                    },
                ..
            }) = &thread.entries()[1].content
            else {
                panic!();
            };

            assert_eq!(root_command, "echo");

            *id
        });

        thread.update(cx, |thread, cx| {
            thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);

            assert!(matches!(
                &thread.entries()[1].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));
        });

        full_turn.await.unwrap();

        thread.read_with(cx, |thread, cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                status: ToolCallStatus::Allowed { content, .. },
                ..
            }) = &thread.entries()[1].content
            else {
                panic!();
            };

            content.as_ref().unwrap().read_with(cx, |md, _cx| {
                assert!(
                    md.source().contains("Hello, world!"),
                    r#"Expected '{}' to contain "Hello, world!""#,
                    md.source()
                );
            });
        });
    }

    /// Waits (up to 10 seconds) until `thread` contains at least one tool-call entry.
    async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
        fn has_tool_call(thread: &AcpThread) -> bool {
            thread
                .entries
                .iter()
                .any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
        }

        // The tool call may already be present; no further event would then arrive.
        if thread.read_with(cx, |thread, _| has_tool_call(thread)) {
            return;
        }

        let (mut tx, mut rx) = mpsc::channel::<()>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                if has_tool_call(thread.read(cx)) {
                    // Every later event re-triggers this, so the bounded channel can
                    // fill up before the receiver is polled. One signal is enough;
                    // a full channel is not an error.
                    tx.try_send(()).ok();
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            _ = rx.next().fuse() => {
                drop(subscription);
            }
        }
    }

    /// Spawns the gemini CLI (from a sibling `gemini-cli` checkout) in ACP mode
    /// and connects an [`AcpServer`] to it over stdio.
    pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let mut command = util::command::new_smol_command("node");
        command
            .arg(cli_path)
            .arg("--acp")
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);

        if let Ok(gemini_key) = std::env::var("GEMINI_API_KEY") {
            command.env("GEMINI_API_KEY", gemini_key);
        }

        let child = command.spawn().unwrap();

        Ok(AcpServer::stdio(child, project, &mut cx))
    }
}