1mod server;
2mod thread_view;
3
4use agentic_coding_protocol::{self as acp};
5use anyhow::{Context as _, Result};
6use buffer_diff::BufferDiff;
7use chrono::{DateTime, Utc};
8use editor::{MultiBuffer, PathKey};
9use futures::channel::oneshot;
10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
11use language::{Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _};
12use markdown::Markdown;
13use parking_lot::Mutex;
14use parking_lot::Mutex;
15use project::Project;
16use std::{mem, ops::Range, path::PathBuf, process::ExitStatus, sync::Arc};
17use ui::{App, IconName};
18use util::{ResultExt, debug_panic};
19
20pub use server::AcpServer;
21pub use thread_view::AcpThreadView;
22
23#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub struct ThreadId(SharedString);
25
26#[derive(Copy, Clone, Debug, PartialEq, Eq)]
27pub struct FileVersion(u64);
28
29#[derive(Debug)]
30pub struct AgentThreadSummary {
31 pub id: ThreadId,
32 pub title: String,
33 pub created_at: DateTime<Utc>,
34}
35
36#[derive(Clone, Debug, PartialEq, Eq)]
37pub struct FileContent {
38 pub path: PathBuf,
39 pub version: FileVersion,
40 pub content: SharedString,
41}
42
43#[derive(Clone, Debug, Eq, PartialEq)]
44pub struct UserMessage {
45 pub chunks: Vec<UserMessageChunk>,
46}
47
48impl UserMessage {
49 fn into_acp(self, cx: &App) -> acp::UserMessage {
50 acp::UserMessage {
51 chunks: self
52 .chunks
53 .into_iter()
54 .map(|chunk| chunk.into_acp(cx))
55 .collect(),
56 }
57 }
58}
59
60#[derive(Clone, Debug, Eq, PartialEq)]
61pub enum UserMessageChunk {
62 Text {
63 chunk: Entity<Markdown>,
64 },
65 File {
66 content: FileContent,
67 },
68 Directory {
69 path: PathBuf,
70 contents: Vec<FileContent>,
71 },
72 Symbol {
73 path: PathBuf,
74 range: Range<u64>,
75 version: FileVersion,
76 name: SharedString,
77 content: SharedString,
78 },
79 Fetch {
80 url: SharedString,
81 content: SharedString,
82 },
83}
84
85impl UserMessageChunk {
86 pub fn into_acp(self, cx: &App) -> acp::UserMessageChunk {
87 match self {
88 Self::Text { chunk } => acp::UserMessageChunk::Text {
89 chunk: chunk.read(cx).source().to_string(),
90 },
91 Self::File { .. } => todo!(),
92 Self::Directory { .. } => todo!(),
93 Self::Symbol { .. } => todo!(),
94 Self::Fetch { .. } => todo!(),
95 }
96 }
97
98 pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
99 Self::Text {
100 chunk: cx.new(|cx| {
101 Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
102 }),
103 }
104 }
105}
106
107#[derive(Clone, Debug, Eq, PartialEq)]
108pub struct AssistantMessage {
109 pub chunks: Vec<AssistantMessageChunk>,
110}
111
112#[derive(Clone, Debug, Eq, PartialEq)]
113pub enum AssistantMessageChunk {
114 Text { chunk: Entity<Markdown> },
115 Thought { chunk: Entity<Markdown> },
116}
117
118impl AssistantMessageChunk {
119 pub fn from_acp(
120 chunk: acp::AssistantMessageChunk,
121 language_registry: Arc<LanguageRegistry>,
122 cx: &mut App,
123 ) -> Self {
124 match chunk {
125 acp::AssistantMessageChunk::Text { chunk } => Self::Text {
126 chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
127 },
128 acp::AssistantMessageChunk::Thought { chunk } => Self::Thought {
129 chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
130 },
131 }
132 }
133
134 pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
135 Self::Text {
136 chunk: cx.new(|cx| {
137 Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
138 }),
139 }
140 }
141}
142
/// The content of a single entry in an `AcpThread`.
#[derive(Debug)]
pub enum AgentThreadEntryContent {
    UserMessage(UserMessage),
    AssistantMessage(AssistantMessage),
    ToolCall(ToolCall),
}

/// A tool call issued by the agent, as shown in the thread.
#[derive(Debug)]
pub struct ToolCall {
    /// Id of the thread entry that holds this tool call.
    id: ToolCallId,
    /// Label of the call, rendered as markdown.
    label: Entity<Markdown>,
    icon: IconName,
    /// Markdown or diff content attached to the call, if any.
    content: Option<ToolCallContent>,
    status: ToolCallStatus,
}

/// Lifecycle state of a [`ToolCall`].
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The user has not decided yet. `respond_tx` delivers the decision back
    /// to whoever requested the tool call.
    WaitingForConfirmation {
        confirmation: ToolCallConfirmation,
        respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
    },
    /// The call is allowed. `status` tracks its progress as reported by the agent.
    Allowed {
        status: acp::ToolCallStatus,
    },
    /// The user rejected the call.
    Rejected,
    /// The call was canceled, e.g. because the in-flight send was canceled.
    Canceled,
}
171
172#[derive(Debug)]
173pub enum ToolCallConfirmation {
174 Edit {
175 description: Option<Entity<Markdown>>,
176 },
177 Execute {
178 command: String,
179 root_command: String,
180 description: Option<Entity<Markdown>>,
181 },
182 Mcp {
183 server_name: String,
184 tool_name: String,
185 tool_display_name: String,
186 description: Option<Entity<Markdown>>,
187 },
188 Fetch {
189 urls: Vec<String>,
190 description: Option<Entity<Markdown>>,
191 },
192 Other {
193 description: Entity<Markdown>,
194 },
195}
196
197impl ToolCallConfirmation {
198 pub fn from_acp(
199 confirmation: acp::ToolCallConfirmation,
200 language_registry: Arc<LanguageRegistry>,
201 cx: &mut App,
202 ) -> Self {
203 let to_md = |description: String, cx: &mut App| -> Entity<Markdown> {
204 cx.new(|cx| {
205 Markdown::new(
206 description.into(),
207 Some(language_registry.clone()),
208 None,
209 cx,
210 )
211 })
212 };
213
214 match confirmation {
215 acp::ToolCallConfirmation::Edit { description } => Self::Edit {
216 description: description.map(|description| to_md(description, cx)),
217 },
218 acp::ToolCallConfirmation::Execute {
219 command,
220 root_command,
221 description,
222 } => Self::Execute {
223 command,
224 root_command,
225 description: description.map(|description| to_md(description, cx)),
226 },
227 acp::ToolCallConfirmation::Mcp {
228 server_name,
229 tool_name,
230 tool_display_name,
231 description,
232 } => Self::Mcp {
233 server_name,
234 tool_name,
235 tool_display_name,
236 description: description.map(|description| to_md(description, cx)),
237 },
238 acp::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch {
239 urls,
240 description: description.map(|description| to_md(description, cx)),
241 },
242 acp::ToolCallConfirmation::Other { description } => Self::Other {
243 description: to_md(description, cx),
244 },
245 }
246 }
247}
248
249#[derive(Debug)]
250pub enum ToolCallContent {
251 Markdown { markdown: Entity<Markdown> },
252 Diff { diff: Diff },
253}
254
255impl ToolCallContent {
256 pub fn from_acp(
257 content: acp::ToolCallContent,
258 language_registry: Arc<LanguageRegistry>,
259 cx: &mut App,
260 ) -> Self {
261 match content {
262 acp::ToolCallContent::Markdown { markdown } => Self::Markdown {
263 markdown: cx.new(|cx| Markdown::new_text(markdown.into(), cx)),
264 },
265 acp::ToolCallContent::Diff { diff } => Self::Diff {
266 diff: Diff::from_acp(diff, language_registry, cx),
267 },
268 }
269 }
270}
271
/// A read-only, diff-annotated view of a file edit attached to a tool call.
#[derive(Debug)]
pub struct Diff {
    /// Excerpts of the new text around each changed hunk, with the diff attached.
    multibuffer: Entity<MultiBuffer>,
    /// Path of the edited file, as reported by the agent.
    path: PathBuf,
    /// Work that computes the diff and populates `multibuffer`. It is stored so
    /// it lives as long as this `Diff`; presumably dropping it cancels any
    /// pending work (gpui `Task` semantics).
    _task: Task<Result<()>>,
}

impl Diff {
    /// Builds a [`Diff`] from its protocol form.
    ///
    /// The multibuffer starts empty. Excerpts are added asynchronously once the
    /// diff between `old_text` (empty when absent) and `new_text` has been
    /// computed. After that, the new buffer's language is detected from `path`.
    pub fn from_acp(
        diff: acp::Diff,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut App,
    ) -> Self {
        let acp::Diff {
            path,
            old_text,
            new_text,
        } = diff;

        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));

        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
        // A missing `old_text` becomes an empty base text (presumably a newly
        // created file; confirm against the protocol).
        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
        // `set_base_text` returns a future. It is awaited in the task below
        // before any hunks are read.
        let diff_task = buffer_diff.update(cx, |diff, cx| {
            diff.set_base_text(
                old_buffer_snapshot,
                Some(language_registry.clone()),
                new_buffer_snapshot,
                cx,
            )
        });

        let task = cx.spawn({
            let multibuffer = multibuffer.clone();
            let path = path.clone();
            async move |cx| {
                diff_task.await?;

                // Show only the changed regions of the new buffer, each with
                // `DEFAULT_MULTIBUFFER_CONTEXT` of surrounding context.
                multibuffer
                    .update(cx, |multibuffer, cx| {
                        let hunk_ranges = {
                            let buffer = new_buffer.read(cx);
                            let diff = buffer_diff.read(cx);
                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
                                .collect::<Vec<_>>()
                        };

                        multibuffer.set_excerpts_for_path(
                            PathKey::for_buffer(&new_buffer, cx),
                            new_buffer.clone(),
                            hunk_ranges,
                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
                            cx,
                        );
                        multibuffer.add_diff(buffer_diff.clone(), cx);
                    })
                    .log_err();

                // Language detection is best-effort: a lookup failure is logged
                // and the buffer keeps no language.
                if let Some(language) = language_registry
                    .language_for_file_path(&path)
                    .await
                    .log_err()
                {
                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
                }

                anyhow::Ok(())
            }
        });

        Self {
            multibuffer,
            path,
            _task: task,
        }
    }
}
353
/// A [`ThreadEntryId`] that is known to refer to a tool call entry.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ToolCallId(ThreadEntryId);

impl ToolCallId {
    /// Returns the raw numeric id of the underlying thread entry.
    pub fn as_u64(&self) -> u64 {
        let ToolCallId(ThreadEntryId(raw)) = *self;
        raw
    }
}

/// Sequential identifier of an entry within a thread.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ThreadEntryId(pub u64);

impl ThreadEntryId {
    /// Advances this id by one and returns the value it held before the
    /// increment (like a postfix `x++`).
    pub fn post_inc(&mut self) -> Self {
        let next = Self(self.0 + 1);
        mem::replace(self, next)
    }
}
374
/// An entry of an `AcpThread` together with its id.
#[derive(Debug)]
pub struct ThreadEntry {
    pub id: ThreadEntryId,
    pub content: AgentThreadEntryContent,
}

/// A conversation with an agent made of user messages, assistant messages,
/// and tool calls.
pub struct AcpThread {
    /// Id that the next pushed entry will receive.
    next_entry_id: ThreadEntryId,
    /// Entries in insertion order. An entry's id doubles as its index (see
    /// `entry_mut`).
    entries: Vec<ThreadEntry>,
    server: Arc<AcpServer>,
    title: SharedString,
    project: Entity<Project>,
    /// The in-flight `send`, if any. While it is `Some`, the thread is not idle.
    send_task: Option<Task<()>>,

    // NOTE(review): the fields below are not initialized by `AcpThread::new` in
    // this file. Confirm how they are meant to be populated.
    connection: Arc<acp::AgentConnection>,
    exit_status: Arc<Mutex<Option<ExitStatus>>>,
    _handler_task: Task<()>,
    _io_task: Task<()>,
}

/// Events emitted by an `AcpThread`.
enum AcpThreadEvent {
    /// An entry was appended to the thread.
    NewEntry,
    /// The entry at this index changed.
    EntryUpdated(usize),
}

/// Coarse state of an `AcpThread`, as computed by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No send is in flight.
    Idle,
    /// A send is in flight, and the current turn has a tool call awaiting the
    /// user's confirmation.
    WaitingForToolConfirmation,
    /// A send is in flight and is not blocked on a confirmation.
    Generating,
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
408
409impl AcpThread {
410 pub fn new(
411 server: Arc<AcpServer>,
412 entries: Vec<AgentThreadEntryContent>,
413 project: Entity<Project>,
414 _: &mut Context<Self>,
415 ) -> Self {
416 let mut next_entry_id = ThreadEntryId(0);
417 Self {
418 title: "ACP Thread".into(),
419 entries: entries
420 .into_iter()
421 .map(|entry| ThreadEntry {
422 id: next_entry_id.post_inc(),
423 content: entry,
424 })
425 .collect(),
426 server,
427 next_entry_id,
428 project,
429 send_task: None,
430 }
431 }
432
433 pub fn title(&self) -> SharedString {
434 self.title.clone()
435 }
436
437 pub fn entries(&self) -> &[ThreadEntry] {
438 &self.entries
439 }
440
441 pub fn status(&self) -> ThreadStatus {
442 if self.send_task.is_some() {
443 if self.waiting_for_tool_confirmation() {
444 ThreadStatus::WaitingForToolConfirmation
445 } else {
446 ThreadStatus::Generating
447 }
448 } else {
449 ThreadStatus::Idle
450 }
451 }
452
453 pub fn push_entry(
454 &mut self,
455 entry: AgentThreadEntryContent,
456 cx: &mut Context<Self>,
457 ) -> ThreadEntryId {
458 let id = self.next_entry_id.post_inc();
459 self.entries.push(ThreadEntry { id, content: entry });
460 cx.emit(AcpThreadEvent::NewEntry);
461 id
462 }
463
464 pub fn push_assistant_chunk(
465 &mut self,
466 chunk: acp::AssistantMessageChunk,
467 cx: &mut Context<Self>,
468 ) {
469 let entries_len = self.entries.len();
470 if let Some(last_entry) = self.entries.last_mut()
471 && let AgentThreadEntryContent::AssistantMessage(AssistantMessage { ref mut chunks }) =
472 last_entry.content
473 {
474 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
475
476 match (chunks.last_mut(), &chunk) {
477 (
478 Some(AssistantMessageChunk::Text { chunk: old_chunk }),
479 acp::AssistantMessageChunk::Text { chunk: new_chunk },
480 )
481 | (
482 Some(AssistantMessageChunk::Thought { chunk: old_chunk }),
483 acp::AssistantMessageChunk::Thought { chunk: new_chunk },
484 ) => {
485 old_chunk.update(cx, |old_chunk, cx| {
486 old_chunk.append(&new_chunk, cx);
487 });
488 }
489 _ => {
490 chunks.push(AssistantMessageChunk::from_acp(
491 chunk,
492 self.project.read(cx).languages().clone(),
493 cx,
494 ));
495 }
496 }
497 } else {
498 let chunk = AssistantMessageChunk::from_acp(
499 chunk,
500 self.project.read(cx).languages().clone(),
501 cx,
502 );
503
504 self.push_entry(
505 AgentThreadEntryContent::AssistantMessage(AssistantMessage {
506 chunks: vec![chunk],
507 }),
508 cx,
509 );
510 }
511 }
512
513 pub fn request_tool_call(
514 &mut self,
515 label: String,
516 icon: acp::Icon,
517 content: Option<acp::ToolCallContent>,
518 confirmation: acp::ToolCallConfirmation,
519 cx: &mut Context<Self>,
520 ) -> ToolCallRequest {
521 let (tx, rx) = oneshot::channel();
522
523 let status = ToolCallStatus::WaitingForConfirmation {
524 confirmation: ToolCallConfirmation::from_acp(
525 confirmation,
526 self.project.read(cx).languages().clone(),
527 cx,
528 ),
529 respond_tx: tx,
530 };
531
532 let id = self.insert_tool_call(label, status, icon, content, cx);
533 ToolCallRequest { id, outcome: rx }
534 }
535
536 pub fn push_tool_call(
537 &mut self,
538 label: String,
539 icon: acp::Icon,
540 content: Option<acp::ToolCallContent>,
541 cx: &mut Context<Self>,
542 ) -> ToolCallId {
543 let status = ToolCallStatus::Allowed {
544 status: acp::ToolCallStatus::Running,
545 };
546
547 self.insert_tool_call(label, status, icon, content, cx)
548 }
549
550 fn insert_tool_call(
551 &mut self,
552 label: String,
553 status: ToolCallStatus,
554 icon: acp::Icon,
555 content: Option<acp::ToolCallContent>,
556 cx: &mut Context<Self>,
557 ) -> ToolCallId {
558 let language_registry = self.project.read(cx).languages().clone();
559
560 let entry_id = self.push_entry(
561 AgentThreadEntryContent::ToolCall(ToolCall {
562 // todo! clean up id creation
563 id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
564 label: cx.new(|cx| {
565 Markdown::new(label.into(), Some(language_registry.clone()), None, cx)
566 }),
567 icon: acp_icon_to_ui_icon(icon),
568 content: content
569 .map(|content| ToolCallContent::from_acp(content, language_registry, cx)),
570 status,
571 }),
572 cx,
573 );
574
575 ToolCallId(entry_id)
576 }
577
578 pub fn authorize_tool_call(
579 &mut self,
580 id: ToolCallId,
581 outcome: acp::ToolCallConfirmationOutcome,
582 cx: &mut Context<Self>,
583 ) {
584 let Some(entry) = self.entry_mut(id.0) else {
585 return;
586 };
587
588 let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
589 debug_panic!("expected ToolCall");
590 return;
591 };
592
593 let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
594 ToolCallStatus::Rejected
595 } else {
596 ToolCallStatus::Allowed {
597 status: acp::ToolCallStatus::Running,
598 }
599 };
600
601 let curr_status = mem::replace(&mut call.status, new_status);
602
603 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
604 respond_tx.send(outcome).log_err();
605 } else {
606 debug_panic!("tried to authorize an already authorized tool call");
607 }
608
609 cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
610 }
611
612 pub fn update_tool_call(
613 &mut self,
614 id: ToolCallId,
615 new_status: acp::ToolCallStatus,
616 new_content: Option<acp::ToolCallContent>,
617 cx: &mut Context<Self>,
618 ) -> Result<()> {
619 let language_registry = self.project.read(cx).languages().clone();
620 let entry = self.entry_mut(id.0).context("Entry not found")?;
621
622 match &mut entry.content {
623 AgentThreadEntryContent::ToolCall(call) => {
624 call.content = new_content.map(|new_content| {
625 ToolCallContent::from_acp(new_content, language_registry, cx)
626 });
627
628 match &mut call.status {
629 ToolCallStatus::Allowed { status } => {
630 *status = new_status;
631 }
632 ToolCallStatus::WaitingForConfirmation { .. } => {
633 anyhow::bail!("Tool call hasn't been authorized yet")
634 }
635 ToolCallStatus::Rejected => {
636 anyhow::bail!("Tool call was rejected and therefore can't be updated")
637 }
638 ToolCallStatus::Canceled => {
639 // todo! test this case with fake server
640 call.status = ToolCallStatus::Allowed { status: new_status };
641 }
642 }
643 }
644 _ => anyhow::bail!("Entry is not a tool call"),
645 }
646
647 cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
648 Ok(())
649 }
650
651 fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
652 let entry = self.entries.get_mut(id.0 as usize);
653 debug_assert!(
654 entry.is_some(),
655 "We shouldn't give out ids to entries that don't exist"
656 );
657 entry
658 }
659
660 /// Returns true if the last turn is awaiting tool authorization
661 pub fn waiting_for_tool_confirmation(&self) -> bool {
662 // todo!("should we use a hashmap?")
663 for entry in self.entries.iter().rev() {
664 match &entry.content {
665 AgentThreadEntryContent::ToolCall(call) => match call.status {
666 ToolCallStatus::WaitingForConfirmation { .. } => return true,
667 ToolCallStatus::Allowed { .. }
668 | ToolCallStatus::Rejected
669 | ToolCallStatus::Canceled => continue,
670 },
671 AgentThreadEntryContent::UserMessage(_)
672 | AgentThreadEntryContent::AssistantMessage(_) => {
673 // Reached the beginning of the turn
674 return false;
675 }
676 }
677 }
678 false
679 }
680
681 pub fn send(
682 &mut self,
683 message: &str,
684 cx: &mut Context<Self>,
685 ) -> impl use<> + Future<Output = Result<()>> {
686 let agent = self.server.clone();
687 let chunk =
688 UserMessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
689 let message = UserMessage {
690 chunks: vec![chunk],
691 };
692 self.push_entry(AgentThreadEntryContent::UserMessage(message.clone()), cx);
693 let acp_message = message.into_acp(cx);
694
695 let (tx, rx) = oneshot::channel();
696 let cancel = self.cancel(cx);
697
698 self.send_task = Some(cx.spawn(async move |this, cx| {
699 cancel.await.log_err();
700
701 let result = agent.send_message(acp_message, cx).await;
702 tx.send(result).log_err();
703 this.update(cx, |this, _cx| this.send_task.take()).log_err();
704 }));
705
706 async move {
707 match rx.await {
708 Ok(result) => result,
709 Err(_) => Ok(()),
710 }
711 }
712 }
713
714 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
715 let agent = self.server.clone();
716
717 if self.send_task.take().is_some() {
718 cx.spawn(async move |this, cx| {
719 agent.cancel_send_message(cx).await?;
720
721 this.update(cx, |this, _cx| {
722 for entry in this.entries.iter_mut() {
723 if let AgentThreadEntryContent::ToolCall(call) = &mut entry.content {
724 let cancel = matches!(
725 call.status,
726 ToolCallStatus::WaitingForConfirmation { .. }
727 | ToolCallStatus::Allowed {
728 status: acp::ToolCallStatus::Running
729 }
730 );
731
732 if cancel {
733 let curr_status =
734 mem::replace(&mut call.status, ToolCallStatus::Canceled);
735
736 if let ToolCallStatus::WaitingForConfirmation {
737 respond_tx, ..
738 } = curr_status
739 {
740 respond_tx
741 .send(acp::ToolCallConfirmationOutcome::Cancel)
742 .ok();
743 }
744 }
745 }
746 }
747 })
748 })
749 } else {
750 Task::ready(Ok(()))
751 }
752 }
753
754 #[cfg(test)]
755 pub fn to_string(&self, cx: &App) -> String {
756 let mut result = String::new();
757 for entry in &self.entries {
758 match &entry.content {
759 AgentThreadEntryContent::UserMessage(user_message) => {
760 result.push_str("# User\n");
761 for chunk in &user_message.chunks {
762 match chunk {
763 UserMessageChunk::Text { chunk } => {
764 result.push_str(chunk.read(cx).source());
765 result.push('\n');
766 }
767 _ => unimplemented!(),
768 }
769 }
770 }
771 AgentThreadEntryContent::AssistantMessage(assistant_message) => {
772 result.push_str("# Assistant\n");
773 for chunk in &assistant_message.chunks {
774 match chunk {
775 AssistantMessageChunk::Text { chunk } => {
776 result.push_str(chunk.read(cx).source());
777 result.push('\n')
778 }
779 AssistantMessageChunk::Thought { chunk } => {
780 result.push_str("<thinking>\n");
781 result.push_str(chunk.read(cx).source());
782 result.push_str("\n</thinking>\n");
783 }
784 }
785 }
786 }
787 AgentThreadEntryContent::ToolCall(_tool_call) => unimplemented!(),
788 }
789 }
790 result
791 }
792}
793
794fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
795 match icon {
796 acp::Icon::FileSearch => IconName::FileSearch,
797 acp::Icon::Folder => IconName::Folder,
798 acp::Icon::Globe => IconName::Globe,
799 acp::Icon::Hammer => IconName::Hammer,
800 acp::Icon::LightBulb => IconName::LightBulb,
801 acp::Icon::Pencil => IconName::Pencil,
802 acp::Icon::Regex => IconName::Regex,
803 acp::Icon::Terminal => IconName::Terminal,
804 }
805}
806
/// Returned by `AcpThread::request_tool_call`.
pub struct ToolCallRequest {
    /// Id of the newly inserted tool call.
    pub id: ToolCallId,
    /// Receives the confirmation outcome sent by `authorize_tool_call` or `cancel`.
    pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
}
811
#[cfg(test)]
mod tests {
    use super::*;
    use async_pipe::{PipeReader, PipeWriter};
    use async_trait::async_trait;
    use futures::{FutureExt as _, channel::mpsc, future::LocalBoxFuture, select};
    use gpui::{AsyncApp, TestAppContext};
    use indoc::indoc;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::{future::BoxedLocal, stream::StreamExt as _};
    use std::{env, path::Path, process::Stdio, rc::Rc, time::Duration};
    use util::path;

    /// Sets up logging, a test settings store, project settings, and language
    /// support.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// Consecutive thought chunks streamed by the agent are merged into one
    /// thought block.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (server, fake_server) = fake_acp_server(project, cx);

        server.initialize().await.unwrap();

        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();

        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_params, server, mut cx| async move {
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp::StreamAssistantMessageChunkParams {
                            chunk: acp::AssistantMessageChunk::Thought {
                                chunk: "Thinking ".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp::StreamAssistantMessageChunkParams {
                            chunk: acp::AssistantMessageChunk::Thought {
                                chunk: "hard!".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();

                Ok(acp::SendUserMessageResponse)
            })
        });

        thread
            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_string(cx));
        assert_eq!(
            output,
            indoc! {r#"
            # User
            Hello from Zed!
            # Assistant
            <thinking>
            Thinking hard!
            </thinking>
            "#}
        );
    }

    #[gpui::test]
    async fn test_gemini_basic(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let server = gemini_acp_server(project.clone(), cx).await;
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.entries.len(), 2);
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::UserMessage(_)
            ));
            assert!(matches!(
                thread.entries[1].content,
                AgentThreadEntryContent::AssistantMessage(_)
            ));
        });
    }

    #[gpui::test]
    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx).await;
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    "Read the '/private/tmp/foo' file and tell me what you see.",
                    cx,
                )
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _cx| {
            assert!(matches!(
                &thread.entries()[2].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));

            assert!(matches!(
                thread.entries[3].content,
                AgentThreadEntryContent::AssistantMessage(_)
            ));
        });
    }

    #[gpui::test]
    async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx).await;
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        let tool_call_id = thread.read_with(cx, |thread, _cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                id,
                status:
                    ToolCallStatus::WaitingForConfirmation {
                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
                        ..
                    },
                ..
            }) = &thread.entries()[2].content
            else {
                panic!();
            };

            assert_eq!(root_command, "echo");

            *id
        });

        thread.update(cx, |thread, cx| {
            thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);

            assert!(matches!(
                &thread.entries()[2].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));
        });

        full_turn.await.unwrap();

        thread.read_with(cx, |thread, cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                content: Some(ToolCallContent::Markdown { markdown }),
                status: ToolCallStatus::Allowed { .. },
                ..
            }) = &thread.entries()[2].content
            else {
                panic!();
            };

            markdown.read_with(cx, |md, _cx| {
                assert!(
                    md.source().contains("Hello, world!"),
                    r#"Expected '{}' to contain "Hello, world!""#,
                    md.source()
                );
            });
        });
    }

    #[gpui::test]
    async fn test_gemini_cancel(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx).await;
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
        });

        let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                status:
                    ToolCallStatus::WaitingForConfirmation {
                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
                        ..
                    },
                ..
            }) = &thread.entries()[first_tool_call_ix].content
            else {
                // Report the entry that actually failed to match.
                panic!("{:?}", thread.entries()[first_tool_call_ix].content);
            };

            assert_eq!(root_command, "echo");
        });

        thread
            .update(cx, |thread, cx| thread.cancel(cx))
            .await
            .unwrap();
        full_turn.await.unwrap();
        thread.read_with(cx, |thread, _| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                status: ToolCallStatus::Canceled,
                ..
            }) = &thread.entries()[first_tool_call_ix].content
            else {
                panic!();
            };
        });

        thread
            .update(cx, |thread, cx| {
                thread.send(r#"Stop running and say goodbye to me."#, cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries().last().unwrap().content,
                AgentThreadEntryContent::AssistantMessage(..),
            ))
        });
    }

    /// Waits (up to 10 seconds) until the thread contains a tool call and
    /// returns the index of the first one.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                let first_tool_call = thread.read(cx).entries.iter().position(|entry| {
                    matches!(entry.content, AgentThreadEntryContent::ToolCall(_))
                });
                if let Some(ix) = first_tool_call {
                    // More events can arrive before the receiver is polled. Once
                    // the bounded channel is full, further notifications are
                    // redundant, so drop them instead of panicking in the handler.
                    tx.try_send(ix).ok();
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                drop(subscription);
                ix.unwrap()
            }
        }
    }

    /// Spawns the Gemini CLI in ACP mode and returns an initialized server.
    pub async fn gemini_acp_server(
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> Arc<AcpServer> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let mut command = util::command::new_smol_command("node");
        command
            .arg(cli_path)
            .arg("--acp")
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);

        if let Ok(gemini_key) = env::var("GEMINI_API_KEY") {
            command.env("GEMINI_API_KEY", gemini_key);
        }

        let child = command.spawn().unwrap();
        let server = cx.update(|cx| AcpServer::stdio(child, project, cx));
        server.initialize().await.unwrap();
        server
    }

    /// Connects an `AcpServer` to an in-process `FakeAcpServer` through a pair
    /// of pipes. Returns both ends.
    pub fn fake_acp_server(
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> (Arc<AcpServer>, Entity<FakeAcpServer>) {
        let (stdin_tx, stdin_rx) = async_pipe::pipe();
        let (stdout_tx, stdout_rx) = async_pipe::pipe();
        let server = cx.update(|cx| AcpServer::fake(stdin_tx, stdout_rx, project, cx));
        let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
        (server, agent)
    }

    /// The agent side of a fake ACP connection, driven by test-provided handlers.
    pub struct FakeAcpServer {
        connection: acp::ClientConnection,
        _handler_task: Task<()>,
        _io_task: Task<()>,
        on_user_message: Option<
            Rc<
                dyn Fn(
                    acp::SendUserMessageParams,
                    Entity<FakeAcpServer>,
                    AsyncApp,
                )
                    -> LocalBoxFuture<'static, Result<acp::SendUserMessageResponse>>,
            >,
        >,
    }

    /// `acp::Agent` implementation that forwards requests to a `FakeAcpServer`.
    #[derive(Clone)]
    struct FakeAgent {
        server: Entity<FakeAcpServer>,
        cx: AsyncApp,
    }

    #[async_trait(?Send)]
    impl acp::Agent for FakeAgent {
        async fn initialize(
            &self,
            _request: acp::InitializeParams,
        ) -> Result<acp::InitializeResponse> {
            Ok(acp::InitializeResponse {
                is_authenticated: true,
            })
        }

        async fn authenticate(
            &self,
            _request: acp::AuthenticateParams,
        ) -> Result<acp::AuthenticateResponse> {
            Ok(acp::AuthenticateResponse)
        }

        async fn send_user_message(
            &self,
            request: acp::SendUserMessageParams,
        ) -> Result<acp::SendUserMessageResponse> {
            let mut cx = self.cx.clone();
            let handler = self
                .server
                .update(&mut cx, |server, _| server.on_user_message.clone())
                .ok()
                .flatten();
            if let Some(handler) = handler {
                handler(request, self.server.clone(), self.cx.clone()).await
            } else {
                anyhow::bail!("No handler for on_user_message")
            }
        }
    }

    impl FakeAcpServer {
        fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
            let agent = FakeAgent {
                server: cx.entity(),
                cx: cx.to_async(),
            };

            let (connection, handler_fut, io_fut) =
                acp::ClientConnection::connect_to_client(agent.clone(), stdout, stdin);
            FakeAcpServer {
                connection,
                on_user_message: None,
                _handler_task: cx.foreground_executor().spawn(handler_fut),
                _io_task: cx.background_spawn(async move {
                    io_fut.await.log_err();
                }),
            }
        }

        /// Installs the handler invoked for each `send_user_message` request.
        fn on_user_message<F>(
            &mut self,
            handler: impl Fn(acp::SendUserMessageParams, Entity<FakeAcpServer>, AsyncApp) -> F
            + 'static,
        ) where
            F: Future<Output = Result<acp::SendUserMessageResponse>> + 'static,
        {
            self.on_user_message
                .replace(Rc::new(move |request, server, cx| {
                    handler(request, server, cx).boxed_local()
                }));
        }

        /// Sends a request from the fake agent to Zed's side of the connection.
        fn send_to_zed<T: acp::ClientRequest>(
            &self,
            message: T,
        ) -> BoxedLocal<Result<T::Response, acp::Error>> {
            self.connection.request(message).boxed_local()
        }
    }
}