1mod server;
2mod thread_view;
3
4use agentic_coding_protocol::{self as acp};
5use anyhow::{Context as _, Result};
6use buffer_diff::BufferDiff;
7use chrono::{DateTime, Utc};
8use editor::{MultiBuffer, PathKey};
9use futures::channel::oneshot;
10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
11use language::{Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _};
12use markdown::Markdown;
13use project::Project;
14use std::{mem, ops::Range, path::PathBuf, sync::Arc};
15use ui::{App, IconName};
16use util::{ResultExt, debug_panic};
17
18pub use server::AcpServer;
19pub use thread_view::AcpThreadView;
20
/// Identifier of an agent thread; handed to the [`AcpServer`] when sending a
/// message to, or canceling, that thread.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ThreadId(SharedString);

/// Version tag attached to file contents shared with the agent.
/// NOTE(review): how versions are assigned is not visible in this file — confirm.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct FileVersion(u64);

/// Summary of an agent thread: its id, title, and creation time.
#[derive(Debug)]
pub struct AgentThreadSummary {
    pub id: ThreadId,
    pub title: String,
    pub created_at: DateTime<Utc>,
}

/// The contents of the file at `path`, as of `version`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FileContent {
    pub path: PathBuf,
    pub version: FileVersion,
    pub content: SharedString,
}
40
41#[derive(Clone, Debug, Eq, PartialEq)]
42pub struct UserMessage {
43 pub chunks: Vec<UserMessageChunk>,
44}
45
46impl UserMessage {
47 fn into_acp(self, cx: &App) -> acp::UserMessage {
48 acp::UserMessage {
49 chunks: self
50 .chunks
51 .into_iter()
52 .map(|chunk| chunk.into_acp(cx))
53 .collect(),
54 }
55 }
56}
57
58#[derive(Clone, Debug, Eq, PartialEq)]
59pub enum UserMessageChunk {
60 Text {
61 chunk: Entity<Markdown>,
62 },
63 File {
64 content: FileContent,
65 },
66 Directory {
67 path: PathBuf,
68 contents: Vec<FileContent>,
69 },
70 Symbol {
71 path: PathBuf,
72 range: Range<u64>,
73 version: FileVersion,
74 name: SharedString,
75 content: SharedString,
76 },
77 Fetch {
78 url: SharedString,
79 content: SharedString,
80 },
81}
82
83impl UserMessageChunk {
84 pub fn into_acp(self, cx: &App) -> acp::UserMessageChunk {
85 match self {
86 Self::Text { chunk } => acp::UserMessageChunk::Text {
87 chunk: chunk.read(cx).source().to_string(),
88 },
89 Self::File { .. } => todo!(),
90 Self::Directory { .. } => todo!(),
91 Self::Symbol { .. } => todo!(),
92 Self::Fetch { .. } => todo!(),
93 }
94 }
95
96 pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
97 Self::Text {
98 chunk: cx.new(|cx| {
99 Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
100 }),
101 }
102 }
103}
104
105#[derive(Clone, Debug, Eq, PartialEq)]
106pub struct AssistantMessage {
107 pub chunks: Vec<AssistantMessageChunk>,
108}
109
110#[derive(Clone, Debug, Eq, PartialEq)]
111pub enum AssistantMessageChunk {
112 Text { chunk: Entity<Markdown> },
113 Thought { chunk: Entity<Markdown> },
114}
115
116impl AssistantMessageChunk {
117 pub fn from_acp(
118 chunk: acp::AssistantMessageChunk,
119 language_registry: Arc<LanguageRegistry>,
120 cx: &mut App,
121 ) -> Self {
122 match chunk {
123 acp::AssistantMessageChunk::Text { chunk } => Self::Text {
124 chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
125 },
126 acp::AssistantMessageChunk::Thought { chunk } => Self::Thought {
127 chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
128 },
129 }
130 }
131
132 pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
133 Self::Text {
134 chunk: cx.new(|cx| {
135 Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
136 }),
137 }
138 }
139}
140
/// The payload of a single [`ThreadEntry`].
#[derive(Debug)]
pub enum AgentThreadEntryContent {
    UserMessage(UserMessage),
    AssistantMessage(AssistantMessage),
    ToolCall(ToolCall),
}

/// A tool invocation made by the agent.
#[derive(Debug)]
pub struct ToolCall {
    /// Id of the thread entry that holds this tool call.
    id: ToolCallId,
    /// Human-readable label, rendered as markdown.
    label: Entity<Markdown>,
    icon: IconName,
    /// Optional output or preview of the call (markdown text or a file diff).
    content: Option<ToolCallContent>,
    status: ToolCallStatus,
}

/// Lifecycle state of a [`ToolCall`].
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The agent asked for permission to run the call. `respond_tx` delivers
    /// the user's decision back to the caller of `AcpThread::request_tool_call`.
    WaitingForConfirmation {
        confirmation: ToolCallConfirmation,
        respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
    },
    /// The call may run. `status` is the latest status reported for it.
    Allowed {
        status: acp::ToolCallStatus,
    },
    /// The user rejected the call.
    Rejected,
    /// The turn was canceled while the call was pending or running.
    Canceled,
}
169
170#[derive(Debug)]
171pub enum ToolCallConfirmation {
172 Edit {
173 description: Option<Entity<Markdown>>,
174 },
175 Execute {
176 command: String,
177 root_command: String,
178 description: Option<Entity<Markdown>>,
179 },
180 Mcp {
181 server_name: String,
182 tool_name: String,
183 tool_display_name: String,
184 description: Option<Entity<Markdown>>,
185 },
186 Fetch {
187 urls: Vec<String>,
188 description: Option<Entity<Markdown>>,
189 },
190 Other {
191 description: Entity<Markdown>,
192 },
193}
194
195impl ToolCallConfirmation {
196 pub fn from_acp(
197 confirmation: acp::ToolCallConfirmation,
198 language_registry: Arc<LanguageRegistry>,
199 cx: &mut App,
200 ) -> Self {
201 let to_md = |description: String, cx: &mut App| -> Entity<Markdown> {
202 cx.new(|cx| {
203 Markdown::new(
204 description.into(),
205 Some(language_registry.clone()),
206 None,
207 cx,
208 )
209 })
210 };
211
212 match confirmation {
213 acp::ToolCallConfirmation::Edit { description } => Self::Edit {
214 description: description.map(|description| to_md(description, cx)),
215 },
216 acp::ToolCallConfirmation::Execute {
217 command,
218 root_command,
219 description,
220 } => Self::Execute {
221 command,
222 root_command,
223 description: description.map(|description| to_md(description, cx)),
224 },
225 acp::ToolCallConfirmation::Mcp {
226 server_name,
227 tool_name,
228 tool_display_name,
229 description,
230 } => Self::Mcp {
231 server_name,
232 tool_name,
233 tool_display_name,
234 description: description.map(|description| to_md(description, cx)),
235 },
236 acp::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch {
237 urls,
238 description: description.map(|description| to_md(description, cx)),
239 },
240 acp::ToolCallConfirmation::Other { description } => Self::Other {
241 description: to_md(description, cx),
242 },
243 }
244 }
245}
246
247#[derive(Debug)]
248pub enum ToolCallContent {
249 Markdown { markdown: Entity<Markdown> },
250 Diff { diff: Diff },
251}
252
253impl ToolCallContent {
254 pub fn from_acp(
255 content: acp::ToolCallContent,
256 language_registry: Arc<LanguageRegistry>,
257 cx: &mut App,
258 ) -> Self {
259 match content {
260 acp::ToolCallContent::Markdown { markdown } => Self::Markdown {
261 markdown: cx.new(|cx| Markdown::new_text(markdown.into(), cx)),
262 },
263 acp::ToolCallContent::Diff { diff } => Self::Diff {
264 diff: Diff::from_acp(diff, language_registry, cx),
265 },
266 }
267 }
268}
269
270#[derive(Debug)]
271pub struct Diff {
272 multibuffer: Entity<MultiBuffer>,
273 path: PathBuf,
274 _task: Task<Result<()>>,
275}
276
277impl Diff {
278 pub fn from_acp(
279 diff: acp::Diff,
280 language_registry: Arc<LanguageRegistry>,
281 cx: &mut App,
282 ) -> Self {
283 let acp::Diff {
284 path,
285 old_text,
286 new_text,
287 } = diff;
288
289 let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
290
291 let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
292 let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
293 let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
294 let old_buffer_snapshot = old_buffer.read(cx).snapshot();
295 let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
296 let diff_task = buffer_diff.update(cx, |diff, cx| {
297 diff.set_base_text(
298 old_buffer_snapshot,
299 Some(language_registry.clone()),
300 new_buffer_snapshot,
301 cx,
302 )
303 });
304
305 let task = cx.spawn({
306 let multibuffer = multibuffer.clone();
307 let path = path.clone();
308 async move |cx| {
309 diff_task.await?;
310
311 multibuffer
312 .update(cx, |multibuffer, cx| {
313 let hunk_ranges = {
314 let buffer = new_buffer.read(cx);
315 let diff = buffer_diff.read(cx);
316 diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
317 .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
318 .collect::<Vec<_>>()
319 };
320
321 multibuffer.set_excerpts_for_path(
322 PathKey::for_buffer(&new_buffer, cx),
323 new_buffer.clone(),
324 hunk_ranges,
325 editor::DEFAULT_MULTIBUFFER_CONTEXT,
326 cx,
327 );
328 multibuffer.add_diff(buffer_diff.clone(), cx);
329 })
330 .log_err();
331
332 if let Some(language) = language_registry
333 .language_for_file_path(&path)
334 .await
335 .log_err()
336 {
337 new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
338 }
339
340 anyhow::Ok(())
341 }
342 });
343
344 Self {
345 multibuffer,
346 path,
347 _task: task,
348 }
349 }
350}
351
352/// A `ThreadEntryId` that is known to be a ToolCall
353#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
354pub struct ToolCallId(ThreadEntryId);
355
356impl ToolCallId {
357 pub fn as_u64(&self) -> u64 {
358 self.0.0
359 }
360}
361
/// Sequential id of an entry within a thread.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ThreadEntryId(pub u64);

impl ThreadEntryId {
    /// Advances this counter by one and returns the value it held before
    /// (post-increment, like `i++`).
    pub fn post_inc(&mut self) -> Self {
        let next = Self(self.0 + 1);
        mem::replace(self, next)
    }
}
372
/// One entry of an [`AcpThread`], tagged with its id.
#[derive(Debug)]
pub struct ThreadEntry {
    pub id: ThreadEntryId,
    pub content: AgentThreadEntryContent,
}

/// A conversation with an agent, driven through an [`AcpServer`].
pub struct AcpThread {
    /// Id used to address this thread when talking to `server`.
    id: ThreadId,
    /// Id that the next pushed entry will receive.
    next_entry_id: ThreadEntryId,
    entries: Vec<ThreadEntry>,
    server: Arc<AcpServer>,
    title: SharedString,
    /// Source of the language registry used for markdown and diffs.
    project: Entity<Project>,
    /// The in-flight turn started by `send`; `Some` while a turn is running.
    send_task: Option<Task<()>>,
}

/// Events emitted by [`AcpThread`] to notify subscribers of changes.
enum AcpThreadEvent {
    /// An entry was appended to the thread.
    NewEntry,
    /// The entry at this index into the thread's entries was modified.
    EntryUpdated(usize),
}

/// Coarse state of an [`AcpThread`], as computed by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No turn is in flight.
    Idle,
    /// A turn is in flight and a tool call in it awaits the user's confirmation.
    WaitingForToolConfirmation,
    /// A turn is in flight.
    Generating,
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
402
403impl AcpThread {
404 pub fn new(
405 server: Arc<AcpServer>,
406 thread_id: ThreadId,
407 entries: Vec<AgentThreadEntryContent>,
408 project: Entity<Project>,
409 _: &mut Context<Self>,
410 ) -> Self {
411 let mut next_entry_id = ThreadEntryId(0);
412 Self {
413 title: "ACP Thread".into(),
414 entries: entries
415 .into_iter()
416 .map(|entry| ThreadEntry {
417 id: next_entry_id.post_inc(),
418 content: entry,
419 })
420 .collect(),
421 server,
422 id: thread_id,
423 next_entry_id,
424 project,
425 send_task: None,
426 }
427 }
428
429 pub fn title(&self) -> SharedString {
430 self.title.clone()
431 }
432
433 pub fn entries(&self) -> &[ThreadEntry] {
434 &self.entries
435 }
436
437 pub fn status(&self) -> ThreadStatus {
438 if self.send_task.is_some() {
439 if self.waiting_for_tool_confirmation() {
440 ThreadStatus::WaitingForToolConfirmation
441 } else {
442 ThreadStatus::Generating
443 }
444 } else {
445 ThreadStatus::Idle
446 }
447 }
448
449 pub fn push_entry(
450 &mut self,
451 entry: AgentThreadEntryContent,
452 cx: &mut Context<Self>,
453 ) -> ThreadEntryId {
454 let id = self.next_entry_id.post_inc();
455 self.entries.push(ThreadEntry { id, content: entry });
456 cx.emit(AcpThreadEvent::NewEntry);
457 id
458 }
459
460 pub fn push_assistant_chunk(
461 &mut self,
462 chunk: acp::AssistantMessageChunk,
463 cx: &mut Context<Self>,
464 ) {
465 let entries_len = self.entries.len();
466 if let Some(last_entry) = self.entries.last_mut()
467 && let AgentThreadEntryContent::AssistantMessage(AssistantMessage { ref mut chunks }) =
468 last_entry.content
469 {
470 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
471
472 match (chunks.last_mut(), &chunk) {
473 (
474 Some(AssistantMessageChunk::Text { chunk: old_chunk }),
475 acp::AssistantMessageChunk::Text { chunk: new_chunk },
476 )
477 | (
478 Some(AssistantMessageChunk::Thought { chunk: old_chunk }),
479 acp::AssistantMessageChunk::Thought { chunk: new_chunk },
480 ) => {
481 old_chunk.update(cx, |old_chunk, cx| {
482 old_chunk.append(&new_chunk, cx);
483 });
484 }
485 _ => {
486 chunks.push(AssistantMessageChunk::from_acp(
487 chunk,
488 self.project.read(cx).languages().clone(),
489 cx,
490 ));
491 }
492 }
493 } else {
494 let chunk = AssistantMessageChunk::from_acp(
495 chunk,
496 self.project.read(cx).languages().clone(),
497 cx,
498 );
499
500 self.push_entry(
501 AgentThreadEntryContent::AssistantMessage(AssistantMessage {
502 chunks: vec![chunk],
503 }),
504 cx,
505 );
506 }
507 }
508
509 pub fn request_tool_call(
510 &mut self,
511 label: String,
512 icon: acp::Icon,
513 content: Option<acp::ToolCallContent>,
514 confirmation: acp::ToolCallConfirmation,
515 cx: &mut Context<Self>,
516 ) -> ToolCallRequest {
517 let (tx, rx) = oneshot::channel();
518
519 let status = ToolCallStatus::WaitingForConfirmation {
520 confirmation: ToolCallConfirmation::from_acp(
521 confirmation,
522 self.project.read(cx).languages().clone(),
523 cx,
524 ),
525 respond_tx: tx,
526 };
527
528 let id = self.insert_tool_call(label, status, icon, content, cx);
529 ToolCallRequest { id, outcome: rx }
530 }
531
532 pub fn push_tool_call(
533 &mut self,
534 label: String,
535 icon: acp::Icon,
536 content: Option<acp::ToolCallContent>,
537 cx: &mut Context<Self>,
538 ) -> ToolCallId {
539 let status = ToolCallStatus::Allowed {
540 status: acp::ToolCallStatus::Running,
541 };
542
543 self.insert_tool_call(label, status, icon, content, cx)
544 }
545
546 fn insert_tool_call(
547 &mut self,
548 label: String,
549 status: ToolCallStatus,
550 icon: acp::Icon,
551 content: Option<acp::ToolCallContent>,
552 cx: &mut Context<Self>,
553 ) -> ToolCallId {
554 let language_registry = self.project.read(cx).languages().clone();
555
556 let entry_id = self.push_entry(
557 AgentThreadEntryContent::ToolCall(ToolCall {
558 // todo! clean up id creation
559 id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
560 label: cx.new(|cx| {
561 Markdown::new(label.into(), Some(language_registry.clone()), None, cx)
562 }),
563 icon: acp_icon_to_ui_icon(icon),
564 content: content
565 .map(|content| ToolCallContent::from_acp(content, language_registry, cx)),
566 status,
567 }),
568 cx,
569 );
570
571 ToolCallId(entry_id)
572 }
573
574 pub fn authorize_tool_call(
575 &mut self,
576 id: ToolCallId,
577 outcome: acp::ToolCallConfirmationOutcome,
578 cx: &mut Context<Self>,
579 ) {
580 let Some(entry) = self.entry_mut(id.0) else {
581 return;
582 };
583
584 let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
585 debug_panic!("expected ToolCall");
586 return;
587 };
588
589 let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
590 ToolCallStatus::Rejected
591 } else {
592 ToolCallStatus::Allowed {
593 status: acp::ToolCallStatus::Running,
594 }
595 };
596
597 let curr_status = mem::replace(&mut call.status, new_status);
598
599 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
600 respond_tx.send(outcome).log_err();
601 } else {
602 debug_panic!("tried to authorize an already authorized tool call");
603 }
604
605 cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
606 }
607
608 pub fn update_tool_call(
609 &mut self,
610 id: ToolCallId,
611 new_status: acp::ToolCallStatus,
612 new_content: Option<acp::ToolCallContent>,
613 cx: &mut Context<Self>,
614 ) -> Result<()> {
615 let language_registry = self.project.read(cx).languages().clone();
616 let entry = self.entry_mut(id.0).context("Entry not found")?;
617
618 match &mut entry.content {
619 AgentThreadEntryContent::ToolCall(call) => {
620 call.content = new_content.map(|new_content| {
621 ToolCallContent::from_acp(new_content, language_registry, cx)
622 });
623
624 match &mut call.status {
625 ToolCallStatus::Allowed { status } => {
626 *status = new_status;
627 }
628 ToolCallStatus::WaitingForConfirmation { .. } => {
629 anyhow::bail!("Tool call hasn't been authorized yet")
630 }
631 ToolCallStatus::Rejected => {
632 anyhow::bail!("Tool call was rejected and therefore can't be updated")
633 }
634 ToolCallStatus::Canceled => {
635 // todo! test this case with fake server
636 call.status = ToolCallStatus::Allowed { status: new_status };
637 }
638 }
639 }
640 _ => anyhow::bail!("Entry is not a tool call"),
641 }
642
643 cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
644 Ok(())
645 }
646
647 fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
648 let entry = self.entries.get_mut(id.0 as usize);
649 debug_assert!(
650 entry.is_some(),
651 "We shouldn't give out ids to entries that don't exist"
652 );
653 entry
654 }
655
656 /// Returns true if the last turn is awaiting tool authorization
657 pub fn waiting_for_tool_confirmation(&self) -> bool {
658 // todo!("should we use a hashmap?")
659 for entry in self.entries.iter().rev() {
660 match &entry.content {
661 AgentThreadEntryContent::ToolCall(call) => match call.status {
662 ToolCallStatus::WaitingForConfirmation { .. } => return true,
663 ToolCallStatus::Allowed { .. }
664 | ToolCallStatus::Rejected
665 | ToolCallStatus::Canceled => continue,
666 },
667 AgentThreadEntryContent::UserMessage(_)
668 | AgentThreadEntryContent::AssistantMessage(_) => {
669 // Reached the beginning of the turn
670 return false;
671 }
672 }
673 }
674 false
675 }
676
677 pub fn send(
678 &mut self,
679 message: &str,
680 cx: &mut Context<Self>,
681 ) -> impl use<> + Future<Output = Result<()>> {
682 let agent = self.server.clone();
683 let id = self.id.clone();
684 let chunk =
685 UserMessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
686 let message = UserMessage {
687 chunks: vec![chunk],
688 };
689 self.push_entry(AgentThreadEntryContent::UserMessage(message.clone()), cx);
690 let acp_message = message.into_acp(cx);
691
692 let (tx, rx) = oneshot::channel();
693 let cancel = self.cancel(cx);
694
695 self.send_task = Some(cx.spawn(async move |this, cx| {
696 cancel.await.log_err();
697
698 let result = agent.send_message(id, acp_message, cx).await;
699 tx.send(result).log_err();
700 this.update(cx, |this, _cx| this.send_task.take()).log_err();
701 }));
702
703 async move {
704 match rx.await {
705 Ok(result) => result,
706 Err(_) => Ok(()),
707 }
708 }
709 }
710
711 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
712 let agent = self.server.clone();
713 let id = self.id.clone();
714
715 if self.send_task.take().is_some() {
716 cx.spawn(async move |this, cx| {
717 agent.cancel_send_message(id, cx).await?;
718
719 this.update(cx, |this, _cx| {
720 for entry in this.entries.iter_mut() {
721 if let AgentThreadEntryContent::ToolCall(call) = &mut entry.content {
722 let cancel = matches!(
723 call.status,
724 ToolCallStatus::WaitingForConfirmation { .. }
725 | ToolCallStatus::Allowed {
726 status: acp::ToolCallStatus::Running
727 }
728 );
729
730 if cancel {
731 let curr_status =
732 mem::replace(&mut call.status, ToolCallStatus::Canceled);
733
734 if let ToolCallStatus::WaitingForConfirmation {
735 respond_tx, ..
736 } = curr_status
737 {
738 respond_tx
739 .send(acp::ToolCallConfirmationOutcome::Cancel)
740 .ok();
741 }
742 }
743 }
744 }
745 })
746 })
747 } else {
748 Task::ready(Ok(()))
749 }
750 }
751}
752
753fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
754 match icon {
755 acp::Icon::FileSearch => IconName::FileSearch,
756 acp::Icon::Folder => IconName::Folder,
757 acp::Icon::Globe => IconName::Globe,
758 acp::Icon::Hammer => IconName::Hammer,
759 acp::Icon::LightBulb => IconName::LightBulb,
760 acp::Icon::Pencil => IconName::Pencil,
761 acp::Icon::Regex => IconName::Regex,
762 acp::Icon::Terminal => IconName::Terminal,
763 }
764}
765
/// Handle returned by `AcpThread::request_tool_call`.
pub struct ToolCallRequest {
    /// Id of the newly inserted tool-call entry.
    pub id: ToolCallId,
    /// Receives the user's decision. `AcpThread::authorize_tool_call` sends the
    /// chosen outcome; `AcpThread::cancel` sends `Cancel`.
    pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
}
770
#[cfg(test)]
mod tests {
    use super::*;
    use async_pipe::{PipeReader, PipeWriter};
    use async_trait::async_trait;
    use futures::{FutureExt as _, channel::mpsc, future::LocalBoxFuture, select};
    use gpui::{AsyncApp, TestAppContext};
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt as _;
    use std::{env, path::Path, process::Stdio, rc::Rc, time::Duration};
    use util::path;

    /// Installs the globals (settings, language support) the thread needs.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    #[gpui::test]
    async fn test_message_receipt(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (server, fake_server) = fake_acp_server(project, cx);

        server.initialize().await.unwrap();

        // NOTE(review): this only registers a handler; no thread is created and
        // no message is sent, so the handler never runs and nothing is asserted.
        // TODO: send a message through the fake server and check the entries.
        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |params, server, mut cx| async move {
                server
                    .update(&mut cx, |server, cx| {
                        let future =
                            server
                                .connection
                                .request(acp::StreamAssistantMessageChunkParams {
                                    thread_id: params.thread_id,
                                    chunk: acp::AssistantMessageChunk::Thought {
                                        chunk: "Thinking ".into(),
                                    },
                                });

                        cx.spawn(async move |_, _| future.await)
                    })?
                    .await
                    .unwrap();

                Ok(acp::SendUserMessageResponse)
            })
        })
    }

    #[gpui::test]
    async fn test_gemini_basic(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let server = gemini_acp_server(project.clone(), cx).await;
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.entries.len(), 2);
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::UserMessage(_)
            ));
            assert!(matches!(
                thread.entries[1].content,
                AgentThreadEntryContent::AssistantMessage(_)
            ));
        });
    }

    #[gpui::test]
    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx).await;
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    "Read the '/private/tmp/foo' file and tell me what you see.",
                    cx,
                )
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _cx| {
            assert!(matches!(
                &thread.entries()[2].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));

            assert!(matches!(
                thread.entries[3].content,
                AgentThreadEntryContent::AssistantMessage(_)
            ));
        });
    }

    #[gpui::test]
    async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx).await;
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
        });

        // Use the index the helper observed instead of assuming a fixed layout.
        let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;

        let tool_call_id = thread.read_with(cx, |thread, _cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                id,
                status:
                    ToolCallStatus::WaitingForConfirmation {
                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
                        ..
                    },
                ..
            }) = &thread.entries()[first_tool_call_ix].content
            else {
                panic!("{:?}", thread.entries()[first_tool_call_ix].content);
            };

            assert_eq!(root_command, "echo");

            *id
        });

        thread.update(cx, |thread, cx| {
            thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);

            assert!(matches!(
                &thread.entries()[first_tool_call_ix].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));
        });

        full_turn.await.unwrap();

        thread.read_with(cx, |thread, cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                content: Some(ToolCallContent::Markdown { markdown }),
                status: ToolCallStatus::Allowed { .. },
                ..
            }) = &thread.entries()[first_tool_call_ix].content
            else {
                panic!("{:?}", thread.entries()[first_tool_call_ix].content);
            };

            markdown.read_with(cx, |md, _cx| {
                assert!(
                    md.source().contains("Hello, world!"),
                    r#"Expected '{}' to contain "Hello, world!""#,
                    md.source()
                );
            });
        });
    }

    #[gpui::test]
    async fn test_gemini_cancel(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx).await;
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
        });

        let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                status:
                    ToolCallStatus::WaitingForConfirmation {
                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
                        ..
                    },
                ..
            }) = &thread.entries()[first_tool_call_ix].content
            else {
                // Report the entry that was actually inspected.
                panic!("{:?}", thread.entries()[first_tool_call_ix].content);
            };

            assert_eq!(root_command, "echo");
        });

        thread
            .update(cx, |thread, cx| thread.cancel(cx))
            .await
            .unwrap();
        full_turn.await.unwrap();
        thread.read_with(cx, |thread, _| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                status: ToolCallStatus::Canceled,
                ..
            }) = &thread.entries()[first_tool_call_ix].content
            else {
                panic!("{:?}", thread.entries()[first_tool_call_ix].content);
            };
        });

        thread
            .update(cx, |thread, cx| {
                thread.send(r#"Stop running and say goodbye to me."#, cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries().last().unwrap().content,
                AgentThreadEntryContent::AssistantMessage(..),
            ))
        });
    }

    /// Waits (up to 10 seconds) until the thread contains a tool call and
    /// returns the index of the first one.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                let first_tool_call = thread.read(cx).entries.iter().position(|entry| {
                    matches!(entry.content, AgentThreadEntryContent::ToolCall(_))
                });
                if let Some(ix) = first_tool_call {
                    // Every later event reports the same index. Once the bounded
                    // channel is full the extra sends are redundant, so drop them
                    // rather than panicking inside the subscription.
                    tx.try_send(ix).ok();
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                drop(subscription);
                ix.unwrap()
            }
        }
    }

    /// Spawns the Gemini CLI in ACP mode and connects an [`AcpServer`] to it.
    /// Requires a checkout of `gemini-cli` next to the repository.
    pub async fn gemini_acp_server(
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> Arc<AcpServer> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let mut command = util::command::new_smol_command("node");
        command
            .arg(cli_path)
            .arg("--acp")
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);

        if let Ok(gemini_key) = env::var("GEMINI_API_KEY") {
            command.env("GEMINI_API_KEY", gemini_key);
        }

        let child = command.spawn().unwrap();
        let server = cx.update(|cx| AcpServer::stdio(child, project, cx));
        server.initialize().await.unwrap();
        server
    }

    /// Connects an [`AcpServer`] to an in-process [`FakeAcpServer`] over pipes.
    pub fn fake_acp_server(
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> (Arc<AcpServer>, Entity<FakeAcpServer>) {
        let (stdin_tx, stdin_rx) = async_pipe::pipe();
        let (stdout_tx, stdout_rx) = async_pipe::pipe();
        let server = cx.update(|cx| AcpServer::fake(stdin_tx, stdout_rx, project, cx));
        let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
        (server, agent)
    }

    /// The agent side of a fake ACP connection, scriptable from tests.
    pub struct FakeAcpServer {
        connection: acp::ClientConnection,
        _handler_task: Task<()>,
        _io_task: Task<()>,
        on_user_message: Option<
            Rc<
                dyn Fn(
                    acp::SendUserMessageParams,
                    Entity<FakeAcpServer>,
                    AsyncApp,
                )
                    -> LocalBoxFuture<'static, Result<acp::SendUserMessageResponse>>,
            >,
        >,
    }

    #[derive(Clone)]
    struct FakeAgent {
        server: Entity<FakeAcpServer>,
        cx: AsyncApp,
    }

    #[async_trait(?Send)]
    impl acp::Agent for FakeAgent {
        async fn initialize(
            &self,
            _request: acp::InitializeParams,
        ) -> Result<acp::InitializeResponse> {
            Ok(acp::InitializeResponse {
                is_authenticated: true,
            })
        }

        async fn authenticate(
            &self,
            _request: acp::AuthenticateParams,
        ) -> Result<acp::AuthenticateResponse> {
            Ok(acp::AuthenticateResponse)
        }

        async fn create_thread(
            &self,
            _request: acp::CreateThreadParams,
        ) -> Result<acp::CreateThreadResponse> {
            Ok(acp::CreateThreadResponse {
                thread_id: acp::ThreadId("test-thread".into()),
            })
        }

        async fn send_user_message(
            &self,
            request: acp::SendUserMessageParams,
        ) -> Result<acp::SendUserMessageResponse> {
            let mut cx = self.cx.clone();
            let handler = self
                .server
                .update(&mut cx, |server, _| server.on_user_message.clone())
                .ok()
                .flatten();
            if let Some(handler) = handler {
                handler(request, self.server.clone(), self.cx.clone()).await
            } else {
                anyhow::bail!("No handler for on_user_message")
            }
        }
    }

    impl FakeAcpServer {
        fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
            let agent = FakeAgent {
                server: cx.entity(),
                cx: cx.to_async(),
            };

            let (connection, handler_fut, io_fut) =
                acp::ClientConnection::connect_to_client(agent, stdout, stdin);
            FakeAcpServer {
                connection,
                on_user_message: None,
                _handler_task: cx.foreground_executor().spawn(handler_fut),
                _io_task: cx.background_spawn(async move {
                    io_fut.await.log_err();
                }),
            }
        }

        /// Installs the handler invoked for every `send_user_message` request.
        fn on_user_message<F>(
            &mut self,
            handler: impl Fn(acp::SendUserMessageParams, Entity<FakeAcpServer>, AsyncApp) -> F
            + 'static,
        ) where
            F: Future<Output = Result<acp::SendUserMessageResponse>> + 'static,
        {
            self.on_user_message
                .replace(Rc::new(move |request, server, cx| {
                    handler(request, server, cx).boxed_local()
                }));
        }
    }
}