1mod server;
2mod thread_view;
3
4use agentic_coding_protocol::{self as acp};
5use anyhow::{Context as _, Result};
6use buffer_diff::BufferDiff;
7use chrono::{DateTime, Utc};
8use editor::{MultiBuffer, PathKey};
9use futures::channel::oneshot;
10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
11use language::{Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _};
12use markdown::Markdown;
13use project::Project;
14use std::{mem, ops::Range, path::PathBuf, sync::Arc};
15use ui::{App, IconName};
16use util::{ResultExt, debug_panic};
17
18pub use server::AcpServer;
19pub use thread_view::AcpThreadView;
20
21#[derive(Debug, Clone, PartialEq, Eq, Hash)]
22pub struct ThreadId(SharedString);
23
/// Version tag carried alongside file contents (see `FileContent` and `UserMessageChunk::Symbol`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct FileVersion(u64);
26
27#[derive(Debug)]
28pub struct AgentThreadSummary {
29 pub id: ThreadId,
30 pub title: String,
31 pub created_at: DateTime<Utc>,
32}
33
34#[derive(Clone, Debug, PartialEq, Eq)]
35pub struct FileContent {
36 pub path: PathBuf,
37 pub version: FileVersion,
38 pub content: SharedString,
39}
40
41#[derive(Clone, Debug, Eq, PartialEq)]
42pub struct UserMessage {
43 pub chunks: Vec<UserMessageChunk>,
44}
45
46impl UserMessage {
47 fn into_acp(self, cx: &App) -> acp::UserMessage {
48 acp::UserMessage {
49 chunks: self
50 .chunks
51 .into_iter()
52 .map(|chunk| chunk.into_acp(cx))
53 .collect(),
54 }
55 }
56}
57
58#[derive(Clone, Debug, Eq, PartialEq)]
59pub enum UserMessageChunk {
60 Text {
61 chunk: Entity<Markdown>,
62 },
63 File {
64 content: FileContent,
65 },
66 Directory {
67 path: PathBuf,
68 contents: Vec<FileContent>,
69 },
70 Symbol {
71 path: PathBuf,
72 range: Range<u64>,
73 version: FileVersion,
74 name: SharedString,
75 content: SharedString,
76 },
77 Fetch {
78 url: SharedString,
79 content: SharedString,
80 },
81}
82
83impl UserMessageChunk {
84 pub fn into_acp(self, cx: &App) -> acp::UserMessageChunk {
85 match self {
86 Self::Text { chunk } => acp::UserMessageChunk::Text {
87 chunk: chunk.read(cx).source().to_string(),
88 },
89 Self::File { .. } => todo!(),
90 Self::Directory { .. } => todo!(),
91 Self::Symbol { .. } => todo!(),
92 Self::Fetch { .. } => todo!(),
93 }
94 }
95
96 pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
97 Self::Text {
98 chunk: cx.new(|cx| {
99 Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
100 }),
101 }
102 }
103}
104
105#[derive(Clone, Debug, Eq, PartialEq)]
106pub struct AssistantMessage {
107 pub chunks: Vec<AssistantMessageChunk>,
108}
109
110#[derive(Clone, Debug, Eq, PartialEq)]
111pub enum AssistantMessageChunk {
112 Text { chunk: Entity<Markdown> },
113 Thought { chunk: Entity<Markdown> },
114}
115
116impl AssistantMessageChunk {
117 pub fn from_acp(
118 chunk: acp::AssistantMessageChunk,
119 language_registry: Arc<LanguageRegistry>,
120 cx: &mut App,
121 ) -> Self {
122 match chunk {
123 acp::AssistantMessageChunk::Text { chunk } => Self::Text {
124 chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
125 },
126 acp::AssistantMessageChunk::Thought { chunk } => Self::Thought {
127 chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
128 },
129 }
130 }
131
132 pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
133 Self::Text {
134 chunk: cx.new(|cx| {
135 Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
136 }),
137 }
138 }
139}
140
141#[derive(Debug)]
142pub enum AgentThreadEntryContent {
143 UserMessage(UserMessage),
144 AssistantMessage(AssistantMessage),
145 ToolCall(ToolCall),
146}
147
148#[derive(Debug)]
149pub struct ToolCall {
150 id: ToolCallId,
151 label: Entity<Markdown>,
152 icon: IconName,
153 content: Option<ToolCallContent>,
154 status: ToolCallStatus,
155}
156
157#[derive(Debug)]
158pub enum ToolCallStatus {
159 WaitingForConfirmation {
160 confirmation: ToolCallConfirmation,
161 respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
162 },
163 Allowed {
164 status: acp::ToolCallStatus,
165 },
166 Rejected,
167 Canceled,
168}
169
170#[derive(Debug)]
171pub enum ToolCallConfirmation {
172 Edit {
173 description: Option<Entity<Markdown>>,
174 },
175 Execute {
176 command: String,
177 root_command: String,
178 description: Option<Entity<Markdown>>,
179 },
180 Mcp {
181 server_name: String,
182 tool_name: String,
183 tool_display_name: String,
184 description: Option<Entity<Markdown>>,
185 },
186 Fetch {
187 urls: Vec<String>,
188 description: Option<Entity<Markdown>>,
189 },
190 Other {
191 description: Entity<Markdown>,
192 },
193}
194
195impl ToolCallConfirmation {
196 pub fn from_acp(
197 confirmation: acp::ToolCallConfirmation,
198 language_registry: Arc<LanguageRegistry>,
199 cx: &mut App,
200 ) -> Self {
201 let to_md = |description: String, cx: &mut App| -> Entity<Markdown> {
202 cx.new(|cx| {
203 Markdown::new(
204 description.into(),
205 Some(language_registry.clone()),
206 None,
207 cx,
208 )
209 })
210 };
211
212 match confirmation {
213 acp::ToolCallConfirmation::Edit { description } => Self::Edit {
214 description: description.map(|description| to_md(description, cx)),
215 },
216 acp::ToolCallConfirmation::Execute {
217 command,
218 root_command,
219 description,
220 } => Self::Execute {
221 command,
222 root_command,
223 description: description.map(|description| to_md(description, cx)),
224 },
225 acp::ToolCallConfirmation::Mcp {
226 server_name,
227 tool_name,
228 tool_display_name,
229 description,
230 } => Self::Mcp {
231 server_name,
232 tool_name,
233 tool_display_name,
234 description: description.map(|description| to_md(description, cx)),
235 },
236 acp::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch {
237 urls,
238 description: description.map(|description| to_md(description, cx)),
239 },
240 acp::ToolCallConfirmation::Other { description } => Self::Other {
241 description: to_md(description, cx),
242 },
243 }
244 }
245}
246
247#[derive(Debug)]
248pub enum ToolCallContent {
249 Markdown { markdown: Entity<Markdown> },
250 Diff { diff: Diff },
251}
252
253impl ToolCallContent {
254 pub fn from_acp(
255 content: acp::ToolCallContent,
256 language_registry: Arc<LanguageRegistry>,
257 cx: &mut App,
258 ) -> Self {
259 match content {
260 acp::ToolCallContent::Markdown { markdown } => Self::Markdown {
261 markdown: cx.new(|cx| Markdown::new_text(markdown.into(), cx)),
262 },
263 acp::ToolCallContent::Diff { diff } => Self::Diff {
264 diff: Diff::from_acp(diff, language_registry, cx),
265 },
266 }
267 }
268}
269
270#[derive(Debug)]
271pub struct Diff {
272 multibuffer: Entity<MultiBuffer>,
273 path: PathBuf,
274 _task: Task<Result<()>>,
275}
276
impl Diff {
    /// Builds a read-only, header-less multibuffer that shows `diff.new_text` with its changes
    /// relative to `diff.old_text` (an absent `old_text` is treated as empty).
    ///
    /// Returns immediately. The spawned task (kept in `_task`) waits for the diff, fills the
    /// multibuffer with excerpts around each hunk, and then sets the new buffer's language
    /// from `diff.path`. Errors from the multibuffer update and from the language lookup are
    /// logged rather than propagated.
    pub fn from_acp(
        diff: acp::Diff,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut App,
    ) -> Self {
        let acp::Diff {
            path,
            old_text,
            new_text,
        } = diff;

        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));

        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
        // A missing `old_text` is diffed against an empty base text.
        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
        // The result is awaited in the task below before any hunks are read.
        let diff_task = buffer_diff.update(cx, |diff, cx| {
            diff.set_base_text(
                old_buffer_snapshot,
                Some(language_registry.clone()),
                new_buffer_snapshot,
                cx,
            )
        });

        let task = cx.spawn({
            let multibuffer = multibuffer.clone();
            let path = path.clone();
            async move |cx| {
                diff_task.await?;

                multibuffer
                    .update(cx, |multibuffer, cx| {
                        // Point ranges of every hunk across the whole new buffer.
                        let hunk_ranges = {
                            let buffer = new_buffer.read(cx);
                            let diff = buffer_diff.read(cx);
                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
                                .collect::<Vec<_>>()
                        };

                        // Show only the hunks, padded by `editor::DEFAULT_MULTIBUFFER_CONTEXT`.
                        multibuffer.set_excerpts_for_path(
                            PathKey::for_buffer(&new_buffer, cx),
                            new_buffer.clone(),
                            hunk_ranges,
                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
                            cx,
                        );
                        multibuffer.add_diff(buffer_diff.clone(), cx);
                    })
                    .log_err();

                // The language is resolved from `path` after the excerpts are in place; a failed
                // lookup is logged and skipped.
                if let Some(language) = language_registry
                    .language_for_file_path(&path)
                    .await
                    .log_err()
                {
                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
                }

                anyhow::Ok(())
            }
        });

        Self {
            multibuffer,
            path,
            _task: task,
        }
    }
}
351
/// A [`ThreadEntryId`] that is known to identify a tool call entry.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ToolCallId(ThreadEntryId);

impl ToolCallId {
    /// Returns the raw numeric value of the underlying entry id.
    pub fn as_u64(&self) -> u64 {
        let ToolCallId(ThreadEntryId(raw)) = *self;
        raw
    }
}

/// Identifier of an entry within a thread; `AcpThread` hands these out sequentially.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ThreadEntryId(pub u64);

impl ThreadEntryId {
    /// Returns the current id and advances `self` to the next one (post-increment).
    pub fn post_inc(&mut self) -> Self {
        let current = *self;
        *self = Self(current.0 + 1);
        current
    }
}
372
373#[derive(Debug)]
374pub struct ThreadEntry {
375 pub id: ThreadEntryId,
376 pub content: AgentThreadEntryContent,
377}
378
379pub struct AcpThread {
380 id: ThreadId,
381 next_entry_id: ThreadEntryId,
382 entries: Vec<ThreadEntry>,
383 server: Arc<AcpServer>,
384 title: SharedString,
385 project: Entity<Project>,
386 send_task: Option<Task<()>>,
387}
388
389enum AcpThreadEvent {
390 NewEntry,
391 EntryUpdated(usize),
392}
393
394#[derive(PartialEq, Eq)]
395pub enum ThreadStatus {
396 Idle,
397 WaitingForToolConfirmation,
398 Generating,
399}
400
401impl EventEmitter<AcpThreadEvent> for AcpThread {}
402
403impl AcpThread {
404 pub fn new(
405 server: Arc<AcpServer>,
406 thread_id: ThreadId,
407 entries: Vec<AgentThreadEntryContent>,
408 project: Entity<Project>,
409 _: &mut Context<Self>,
410 ) -> Self {
411 let mut next_entry_id = ThreadEntryId(0);
412 Self {
413 title: "ACP Thread".into(),
414 entries: entries
415 .into_iter()
416 .map(|entry| ThreadEntry {
417 id: next_entry_id.post_inc(),
418 content: entry,
419 })
420 .collect(),
421 server,
422 id: thread_id,
423 next_entry_id,
424 project,
425 send_task: None,
426 }
427 }
428
429 pub fn title(&self) -> SharedString {
430 self.title.clone()
431 }
432
433 pub fn entries(&self) -> &[ThreadEntry] {
434 &self.entries
435 }
436
437 pub fn status(&self) -> ThreadStatus {
438 if self.send_task.is_some() {
439 if self.waiting_for_tool_confirmation() {
440 ThreadStatus::WaitingForToolConfirmation
441 } else {
442 ThreadStatus::Generating
443 }
444 } else {
445 ThreadStatus::Idle
446 }
447 }
448
449 pub fn push_entry(
450 &mut self,
451 entry: AgentThreadEntryContent,
452 cx: &mut Context<Self>,
453 ) -> ThreadEntryId {
454 let id = self.next_entry_id.post_inc();
455 self.entries.push(ThreadEntry { id, content: entry });
456 cx.emit(AcpThreadEvent::NewEntry);
457 id
458 }
459
460 pub fn push_assistant_chunk(
461 &mut self,
462 chunk: acp::AssistantMessageChunk,
463 cx: &mut Context<Self>,
464 ) {
465 let entries_len = self.entries.len();
466 if let Some(last_entry) = self.entries.last_mut()
467 && let AgentThreadEntryContent::AssistantMessage(AssistantMessage { ref mut chunks }) =
468 last_entry.content
469 {
470 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
471
472 match (chunks.last_mut(), &chunk) {
473 (
474 Some(AssistantMessageChunk::Text { chunk: old_chunk }),
475 acp::AssistantMessageChunk::Text { chunk: new_chunk },
476 )
477 | (
478 Some(AssistantMessageChunk::Thought { chunk: old_chunk }),
479 acp::AssistantMessageChunk::Thought { chunk: new_chunk },
480 ) => {
481 old_chunk.update(cx, |old_chunk, cx| {
482 old_chunk.append(&new_chunk, cx);
483 });
484 }
485 _ => {
486 chunks.push(AssistantMessageChunk::from_acp(
487 chunk,
488 self.project.read(cx).languages().clone(),
489 cx,
490 ));
491 }
492 }
493 } else {
494 let chunk = AssistantMessageChunk::from_acp(
495 chunk,
496 self.project.read(cx).languages().clone(),
497 cx,
498 );
499
500 self.push_entry(
501 AgentThreadEntryContent::AssistantMessage(AssistantMessage {
502 chunks: vec![chunk],
503 }),
504 cx,
505 );
506 }
507 }
508
509 pub fn request_tool_call(
510 &mut self,
511 label: String,
512 icon: acp::Icon,
513 content: Option<acp::ToolCallContent>,
514 confirmation: acp::ToolCallConfirmation,
515 cx: &mut Context<Self>,
516 ) -> ToolCallRequest {
517 let (tx, rx) = oneshot::channel();
518
519 let status = ToolCallStatus::WaitingForConfirmation {
520 confirmation: ToolCallConfirmation::from_acp(
521 confirmation,
522 self.project.read(cx).languages().clone(),
523 cx,
524 ),
525 respond_tx: tx,
526 };
527
528 let id = self.insert_tool_call(label, status, icon, content, cx);
529 ToolCallRequest { id, outcome: rx }
530 }
531
532 pub fn push_tool_call(
533 &mut self,
534 label: String,
535 icon: acp::Icon,
536 content: Option<acp::ToolCallContent>,
537 cx: &mut Context<Self>,
538 ) -> ToolCallId {
539 let status = ToolCallStatus::Allowed {
540 status: acp::ToolCallStatus::Running,
541 };
542
543 self.insert_tool_call(label, status, icon, content, cx)
544 }
545
546 fn insert_tool_call(
547 &mut self,
548 label: String,
549 status: ToolCallStatus,
550 icon: acp::Icon,
551 content: Option<acp::ToolCallContent>,
552 cx: &mut Context<Self>,
553 ) -> ToolCallId {
554 let language_registry = self.project.read(cx).languages().clone();
555
556 let entry_id = self.push_entry(
557 AgentThreadEntryContent::ToolCall(ToolCall {
558 // todo! clean up id creation
559 id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
560 label: cx.new(|cx| {
561 Markdown::new(label.into(), Some(language_registry.clone()), None, cx)
562 }),
563 icon: acp_icon_to_ui_icon(icon),
564 content: content
565 .map(|content| ToolCallContent::from_acp(content, language_registry, cx)),
566 status,
567 }),
568 cx,
569 );
570
571 ToolCallId(entry_id)
572 }
573
574 pub fn authorize_tool_call(
575 &mut self,
576 id: ToolCallId,
577 outcome: acp::ToolCallConfirmationOutcome,
578 cx: &mut Context<Self>,
579 ) {
580 let Some(entry) = self.entry_mut(id.0) else {
581 return;
582 };
583
584 let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
585 debug_panic!("expected ToolCall");
586 return;
587 };
588
589 let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
590 ToolCallStatus::Rejected
591 } else {
592 ToolCallStatus::Allowed {
593 status: acp::ToolCallStatus::Running,
594 }
595 };
596
597 let curr_status = mem::replace(&mut call.status, new_status);
598
599 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
600 respond_tx.send(outcome).log_err();
601 } else {
602 debug_panic!("tried to authorize an already authorized tool call");
603 }
604
605 cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
606 }
607
608 pub fn update_tool_call(
609 &mut self,
610 id: ToolCallId,
611 new_status: acp::ToolCallStatus,
612 new_content: Option<acp::ToolCallContent>,
613 cx: &mut Context<Self>,
614 ) -> Result<()> {
615 let language_registry = self.project.read(cx).languages().clone();
616 let entry = self.entry_mut(id.0).context("Entry not found")?;
617
618 match &mut entry.content {
619 AgentThreadEntryContent::ToolCall(call) => {
620 call.content = new_content.map(|new_content| {
621 ToolCallContent::from_acp(new_content, language_registry, cx)
622 });
623
624 match &mut call.status {
625 ToolCallStatus::Allowed { status } => {
626 *status = new_status;
627 }
628 ToolCallStatus::WaitingForConfirmation { .. } => {
629 anyhow::bail!("Tool call hasn't been authorized yet")
630 }
631 ToolCallStatus::Rejected => {
632 anyhow::bail!("Tool call was rejected and therefore can't be updated")
633 }
634 ToolCallStatus::Canceled => {
635 // todo! test this case with fake server
636 call.status = ToolCallStatus::Allowed { status: new_status };
637 }
638 }
639 }
640 _ => anyhow::bail!("Entry is not a tool call"),
641 }
642
643 cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
644 Ok(())
645 }
646
647 fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
648 let entry = self.entries.get_mut(id.0 as usize);
649 debug_assert!(
650 entry.is_some(),
651 "We shouldn't give out ids to entries that don't exist"
652 );
653 entry
654 }
655
656 /// Returns true if the last turn is awaiting tool authorization
657 pub fn waiting_for_tool_confirmation(&self) -> bool {
658 // todo!("should we use a hashmap?")
659 for entry in self.entries.iter().rev() {
660 match &entry.content {
661 AgentThreadEntryContent::ToolCall(call) => match call.status {
662 ToolCallStatus::WaitingForConfirmation { .. } => return true,
663 ToolCallStatus::Allowed { .. }
664 | ToolCallStatus::Rejected
665 | ToolCallStatus::Canceled => continue,
666 },
667 AgentThreadEntryContent::UserMessage(_)
668 | AgentThreadEntryContent::AssistantMessage(_) => {
669 // Reached the beginning of the turn
670 return false;
671 }
672 }
673 }
674 false
675 }
676
677 pub fn send(
678 &mut self,
679 message: &str,
680 cx: &mut Context<Self>,
681 ) -> impl use<> + Future<Output = Result<()>> {
682 let agent = self.server.clone();
683 let id = self.id.clone();
684 let chunk =
685 UserMessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
686 let message = UserMessage {
687 chunks: vec![chunk],
688 };
689 self.push_entry(AgentThreadEntryContent::UserMessage(message.clone()), cx);
690 let acp_message = message.into_acp(cx);
691
692 let (tx, rx) = oneshot::channel();
693 let cancel = self.cancel(cx);
694
695 self.send_task = Some(cx.spawn(async move |this, cx| {
696 cancel.await.log_err();
697
698 let result = agent.send_message(id, acp_message, cx).await;
699 tx.send(result).log_err();
700 this.update(cx, |this, _cx| this.send_task.take()).log_err();
701 }));
702
703 async move {
704 match rx.await {
705 Ok(result) => result,
706 Err(_) => Ok(()),
707 }
708 }
709 }
710
711 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
712 let agent = self.server.clone();
713 let id = self.id.clone();
714
715 if self.send_task.take().is_some() {
716 cx.spawn(async move |this, cx| {
717 agent.cancel_send_message(id, cx).await?;
718
719 this.update(cx, |this, _cx| {
720 for entry in this.entries.iter_mut() {
721 if let AgentThreadEntryContent::ToolCall(call) = &mut entry.content {
722 let cancel = matches!(
723 call.status,
724 ToolCallStatus::WaitingForConfirmation { .. }
725 | ToolCallStatus::Allowed {
726 status: acp::ToolCallStatus::Running
727 }
728 );
729
730 if cancel {
731 let curr_status =
732 mem::replace(&mut call.status, ToolCallStatus::Canceled);
733
734 if let ToolCallStatus::WaitingForConfirmation {
735 respond_tx, ..
736 } = curr_status
737 {
738 respond_tx
739 .send(acp::ToolCallConfirmationOutcome::Cancel)
740 .ok();
741 }
742 }
743 }
744 }
745 })
746 })
747 } else {
748 Task::ready(Ok(()))
749 }
750 }
751}
752
753fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
754 match icon {
755 acp::Icon::FileSearch => IconName::FileSearch,
756 acp::Icon::Folder => IconName::Folder,
757 acp::Icon::Globe => IconName::Globe,
758 acp::Icon::Hammer => IconName::Hammer,
759 acp::Icon::LightBulb => IconName::LightBulb,
760 acp::Icon::Pencil => IconName::Pencil,
761 acp::Icon::Regex => IconName::Regex,
762 acp::Icon::Terminal => IconName::Terminal,
763 }
764}
765
766pub struct ToolCallRequest {
767 pub id: ToolCallId,
768 pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
769}
770
#[cfg(test)]
mod tests {
    use super::*;
    use async_pipe::{PipeReader, PipeWriter};
    use async_trait::async_trait;
    use futures::{FutureExt as _, channel::mpsc, future::LocalBoxFuture, select};
    use gpui::{AsyncApp, TestAppContext};
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt as _;
    use std::{env, path::Path, process::Stdio, rc::Rc, time::Duration};
    use util::path;

    /// Installs the settings and language globals the thread and project need.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    #[gpui::test]
    async fn test_message_receipt(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (server, fake_server) = fake_acp_server(project, cx);

        server.initialize().await.unwrap();

        // TODO(review): the handler below is registered, but no thread is created and no
        // message is sent, so it never runs and nothing about message receipt is asserted yet.
        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |params, server, mut cx| async move {
                server
                    .update(&mut cx, |server, cx| {
                        server.send_to_zed(
                            acp::StreamAssistantMessageChunkParams {
                                thread_id: params.thread_id.clone(),
                                chunk: acp::AssistantMessageChunk::Thought {
                                    chunk: "Thinking ".into(),
                                },
                            },
                            cx,
                        )
                    })?
                    .await
                    .unwrap();
                server
                    .update(&mut cx, |server, cx| {
                        server.send_to_zed(
                            acp::StreamAssistantMessageChunkParams {
                                thread_id: params.thread_id,
                                chunk: acp::AssistantMessageChunk::Thought {
                                    chunk: "hard!".into(),
                                },
                            },
                            cx,
                        )
                    })?
                    .await
                    .unwrap();

                Ok(acp::SendUserMessageResponse)
            })
        })
    }

    #[gpui::test]
    async fn test_gemini_basic(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let server = gemini_acp_server(project.clone(), cx).await;
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.entries.len(), 2);
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::UserMessage(_)
            ));
            assert!(matches!(
                thread.entries[1].content,
                AgentThreadEntryContent::AssistantMessage(_)
            ));
        });
    }

    #[gpui::test]
    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx).await;
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    "Read the '/private/tmp/foo' file and tell me what you see.",
                    cx,
                )
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _cx| {
            assert!(matches!(
                &thread.entries()[2].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));

            assert!(matches!(
                thread.entries[3].content,
                AgentThreadEntryContent::AssistantMessage(_)
            ));
        });
    }

    #[gpui::test]
    async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx).await;
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        let tool_call_id = thread.read_with(cx, |thread, _cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                id,
                status:
                    ToolCallStatus::WaitingForConfirmation {
                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
                        ..
                    },
                ..
            }) = &thread.entries()[2].content
            else {
                panic!();
            };

            assert_eq!(root_command, "echo");

            *id
        });

        thread.update(cx, |thread, cx| {
            thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);

            assert!(matches!(
                &thread.entries()[2].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));
        });

        full_turn.await.unwrap();

        thread.read_with(cx, |thread, cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                content: Some(ToolCallContent::Markdown { markdown }),
                status: ToolCallStatus::Allowed { .. },
                ..
            }) = &thread.entries()[2].content
            else {
                panic!();
            };

            markdown.read_with(cx, |md, _cx| {
                assert!(
                    md.source().contains("Hello, world!"),
                    r#"Expected '{}' to contain "Hello, world!""#,
                    md.source()
                );
            });
        });
    }

    #[gpui::test]
    async fn test_gemini_cancel(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx).await;
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
        });

        let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                status:
                    ToolCallStatus::WaitingForConfirmation {
                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
                        ..
                    },
                ..
            }) = &thread.entries()[first_tool_call_ix].content
            else {
                // Report the entry that was actually inspected.
                panic!("{:?}", thread.entries()[first_tool_call_ix].content);
            };

            assert_eq!(root_command, "echo");
        });

        thread
            .update(cx, |thread, cx| thread.cancel(cx))
            .await
            .unwrap();
        full_turn.await.unwrap();
        thread.read_with(cx, |thread, _| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                status: ToolCallStatus::Canceled,
                ..
            }) = &thread.entries()[first_tool_call_ix].content
            else {
                panic!();
            };
        });

        thread
            .update(cx, |thread, cx| {
                thread.send(r#"Stop running and say goodbye to me."#, cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries().last().unwrap().content,
                AgentThreadEntryContent::AssistantMessage(..),
            ))
        });
    }

    /// Waits until `thread` contains a tool call entry and returns the index of the first one.
    ///
    /// Panics if no tool call appears within 10 seconds.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                let first_tool_call = thread.read(cx).entries.iter().position(|entry| {
                    matches!(entry.content, AgentThreadEntryContent::ToolCall(_))
                });
                if let Some(ix) = first_tool_call {
                    // Events keep arriving until the subscription is dropped and only the first
                    // index is consumed, so a full channel here is expected and ignored.
                    tx.try_send(ix).ok();
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                drop(subscription);
                ix.unwrap()
            }
        }
    }

    /// Spawns the gemini CLI in ACP mode and connects an initialized `AcpServer` to it.
    pub async fn gemini_acp_server(
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> Arc<AcpServer> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let mut command = util::command::new_smol_command("node");
        command
            .arg(cli_path)
            .arg("--acp")
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);

        if let Ok(gemini_key) = env::var("GEMINI_API_KEY") {
            command.env("GEMINI_API_KEY", gemini_key);
        }

        let child = command.spawn().unwrap();
        let server = cx.update(|cx| AcpServer::stdio(child, project, cx));
        server.initialize().await.unwrap();
        server
    }

    /// Connects an `AcpServer` to an in-process `FakeAcpServer` over a pair of pipes.
    pub fn fake_acp_server(
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> (Arc<AcpServer>, Entity<FakeAcpServer>) {
        let (stdin_tx, stdin_rx) = async_pipe::pipe();
        let (stdout_tx, stdout_rx) = async_pipe::pipe();
        let server = cx.update(|cx| AcpServer::fake(stdin_tx, stdout_rx, project, cx));
        let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
        (server, agent)
    }

    /// In-process agent used by tests; user messages are handled by `on_user_message`.
    pub struct FakeAcpServer {
        connection: acp::ClientConnection,
        _handler_task: Task<()>,
        _io_task: Task<()>,
        on_user_message: Option<
            Rc<
                dyn Fn(
                    acp::SendUserMessageParams,
                    Entity<FakeAcpServer>,
                    AsyncApp,
                )
                    -> LocalBoxFuture<'static, Result<acp::SendUserMessageResponse>>,
            >,
        >,
    }

    #[derive(Clone)]
    struct FakeAgent {
        server: Entity<FakeAcpServer>,
        cx: AsyncApp,
    }

    #[async_trait(?Send)]
    impl acp::Agent for FakeAgent {
        async fn initialize(
            &self,
            _request: acp::InitializeParams,
        ) -> Result<acp::InitializeResponse> {
            Ok(acp::InitializeResponse {
                is_authenticated: true,
            })
        }

        async fn authenticate(
            &self,
            _request: acp::AuthenticateParams,
        ) -> Result<acp::AuthenticateResponse> {
            Ok(acp::AuthenticateResponse)
        }

        async fn create_thread(
            &self,
            _request: acp::CreateThreadParams,
        ) -> Result<acp::CreateThreadResponse> {
            Ok(acp::CreateThreadResponse {
                thread_id: acp::ThreadId("test-thread".into()),
            })
        }

        async fn send_user_message(
            &self,
            request: acp::SendUserMessageParams,
        ) -> Result<acp::SendUserMessageResponse> {
            let mut cx = self.cx.clone();
            let handler = self
                .server
                .update(&mut cx, |server, _| server.on_user_message.clone())
                .ok()
                .flatten();
            if let Some(handler) = handler {
                handler(request, self.server.clone(), self.cx.clone()).await
            } else {
                anyhow::bail!("No handler for on_user_message")
            }
        }
    }

    impl FakeAcpServer {
        fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
            let agent = FakeAgent {
                server: cx.entity(),
                cx: cx.to_async(),
            };

            let (connection, handler_fut, io_fut) =
                acp::ClientConnection::connect_to_client(agent.clone(), stdout, stdin);
            FakeAcpServer {
                connection,
                on_user_message: None,
                _handler_task: cx.foreground_executor().spawn(handler_fut),
                _io_task: cx.background_spawn(async move {
                    io_fut.await.log_err();
                }),
            }
        }

        /// Installs the handler invoked for every `send_user_message` request.
        fn on_user_message<F>(
            &mut self,
            handler: impl Fn(acp::SendUserMessageParams, Entity<FakeAcpServer>, AsyncApp) -> F
            + 'static,
        ) where
            F: Future<Output = Result<acp::SendUserMessageResponse>> + 'static,
        {
            self.on_user_message
                .replace(Rc::new(move |request, server, cx| {
                    handler(request, server, cx).boxed_local()
                }));
        }

        /// Sends a client request from the fake agent to Zed.
        fn send_to_zed<T: acp::ClientRequest>(
            &self,
            message: T,
            cx: &Context<Self>,
        ) -> Task<Result<T::Response, acp::Error>> {
            let future = self.connection.request(message);
            cx.foreground_executor().spawn(future)
        }
    }
}