1pub use acp::ToolCallId;
2use agent_servers::AgentServer;
3use agentic_coding_protocol::{self as acp, UserMessageChunk};
4use anyhow::{Context as _, Result, anyhow};
5use assistant_tool::ActionLog;
6use buffer_diff::BufferDiff;
7use editor::{MultiBuffer, PathKey};
8use futures::{FutureExt, channel::oneshot, future::BoxFuture};
9use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
10use itertools::Itertools;
11use language::{
12 Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
13 text_diff,
14};
15use markdown::Markdown;
16use project::{AgentLocation, Project};
17use std::collections::HashMap;
18use std::error::Error;
19use std::fmt::{Formatter, Write};
20use std::{
21 fmt::Display,
22 mem,
23 path::{Path, PathBuf},
24 sync::Arc,
25};
26use ui::{App, IconName};
27use util::ResultExt;
28
29#[derive(Clone, Debug, Eq, PartialEq)]
30pub struct UserMessage {
31 pub content: Entity<Markdown>,
32}
33
34impl UserMessage {
35 pub fn from_acp(
36 message: &acp::SendUserMessageParams,
37 language_registry: Arc<LanguageRegistry>,
38 cx: &mut App,
39 ) -> Self {
40 let mut md_source = String::new();
41
42 for chunk in &message.chunks {
43 match chunk {
44 UserMessageChunk::Text { text } => md_source.push_str(&text),
45 UserMessageChunk::Path { path } => {
46 write!(&mut md_source, "{}", MentionPath(&path)).unwrap()
47 }
48 }
49 }
50
51 Self {
52 content: cx
53 .new(|cx| Markdown::new(md_source.into(), Some(language_registry), None, cx)),
54 }
55 }
56
57 fn to_markdown(&self, cx: &App) -> String {
58 format!("## User\n\n{}\n\n", self.content.read(cx).source())
59 }
60}
61
/// A file reference embedded in markdown as a link of the form
/// `[@<file name>](@file:<full path>)`.
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// Link-target prefix that marks a URL as a file mention.
    const PREFIX: &'static str = "@file:";

    /// Wraps `path` as a mention.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a link target produced by this type's `Display` impl.
    ///
    /// Returns `None` when `url` does not start with the `@file:` prefix;
    /// everything after the prefix is taken as the path.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self(Path::new(rest)))
    }

    /// The mentioned path.
    pub fn path(&self) -> &Path {
        self.0
    }
}

impl Display for MentionPath<'_> {
    /// Writes `[@<file name>](@file:<path>)`; the label is empty when the
    /// path has no final component.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = self.0.file_name().unwrap_or_default();
        write!(
            f,
            "[@{}]({}{})",
            label.display(),
            Self::PREFIX,
            self.0.display()
        )
    }
}
93
94#[derive(Clone, Debug, Eq, PartialEq)]
95pub struct AssistantMessage {
96 pub chunks: Vec<AssistantMessageChunk>,
97}
98
99impl AssistantMessage {
100 fn to_markdown(&self, cx: &App) -> String {
101 format!(
102 "## Assistant\n\n{}\n\n",
103 self.chunks
104 .iter()
105 .map(|chunk| chunk.to_markdown(cx))
106 .join("\n\n")
107 )
108 }
109}
110
111#[derive(Clone, Debug, Eq, PartialEq)]
112pub enum AssistantMessageChunk {
113 Text { chunk: Entity<Markdown> },
114 Thought { chunk: Entity<Markdown> },
115}
116
117impl AssistantMessageChunk {
118 pub fn from_acp(
119 chunk: acp::AssistantMessageChunk,
120 language_registry: Arc<LanguageRegistry>,
121 cx: &mut App,
122 ) -> Self {
123 match chunk {
124 acp::AssistantMessageChunk::Text { text } => Self::Text {
125 chunk: cx.new(|cx| Markdown::new(text.into(), Some(language_registry), None, cx)),
126 },
127 acp::AssistantMessageChunk::Thought { thought } => Self::Thought {
128 chunk: cx
129 .new(|cx| Markdown::new(thought.into(), Some(language_registry), None, cx)),
130 },
131 }
132 }
133
134 pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
135 Self::Text {
136 chunk: cx.new(|cx| {
137 Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
138 }),
139 }
140 }
141
142 fn to_markdown(&self, cx: &App) -> String {
143 match self {
144 Self::Text { chunk } => chunk.read(cx).source().to_string(),
145 Self::Thought { chunk } => {
146 format!("<thinking>\n{}\n</thinking>", chunk.read(cx).source())
147 }
148 }
149 }
150}
151
152#[derive(Debug)]
153pub enum AgentThreadEntry {
154 UserMessage(UserMessage),
155 AssistantMessage(AssistantMessage),
156 ToolCall(ToolCall),
157}
158
159impl AgentThreadEntry {
160 fn to_markdown(&self, cx: &App) -> String {
161 match self {
162 Self::UserMessage(message) => message.to_markdown(cx),
163 Self::AssistantMessage(message) => message.to_markdown(cx),
164 Self::ToolCall(too_call) => too_call.to_markdown(cx),
165 }
166 }
167
168 pub fn diff(&self) -> Option<&Diff> {
169 if let AgentThreadEntry::ToolCall(ToolCall {
170 content: Some(ToolCallContent::Diff { diff }),
171 ..
172 }) = self
173 {
174 Some(&diff)
175 } else {
176 None
177 }
178 }
179
180 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
181 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
182 Some(locations)
183 } else {
184 None
185 }
186 }
187}
188
189#[derive(Debug)]
190pub struct ToolCall {
191 pub id: acp::ToolCallId,
192 pub label: Entity<Markdown>,
193 pub icon: IconName,
194 pub content: Option<ToolCallContent>,
195 pub status: ToolCallStatus,
196 pub locations: Vec<acp::ToolCallLocation>,
197}
198
199impl ToolCall {
200 fn to_markdown(&self, cx: &App) -> String {
201 let mut markdown = format!(
202 "**Tool Call: {}**\nStatus: {}\n\n",
203 self.label.read(cx).source(),
204 self.status
205 );
206 if let Some(content) = &self.content {
207 markdown.push_str(content.to_markdown(cx).as_str());
208 markdown.push_str("\n\n");
209 }
210 markdown
211 }
212}
213
214#[derive(Debug)]
215pub enum ToolCallStatus {
216 WaitingForConfirmation {
217 confirmation: ToolCallConfirmation,
218 respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
219 },
220 Allowed {
221 status: acp::ToolCallStatus,
222 },
223 Rejected,
224 Canceled,
225}
226
227impl Display for ToolCallStatus {
228 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
229 write!(
230 f,
231 "{}",
232 match self {
233 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
234 ToolCallStatus::Allowed { status } => match status {
235 acp::ToolCallStatus::Running => "Running",
236 acp::ToolCallStatus::Finished => "Finished",
237 acp::ToolCallStatus::Error => "Error",
238 },
239 ToolCallStatus::Rejected => "Rejected",
240 ToolCallStatus::Canceled => "Canceled",
241 }
242 )
243 }
244}
245
246#[derive(Debug)]
247pub enum ToolCallConfirmation {
248 Edit {
249 description: Option<Entity<Markdown>>,
250 },
251 Execute {
252 command: String,
253 root_command: String,
254 description: Option<Entity<Markdown>>,
255 },
256 Mcp {
257 server_name: String,
258 tool_name: String,
259 tool_display_name: String,
260 description: Option<Entity<Markdown>>,
261 },
262 Fetch {
263 urls: Vec<SharedString>,
264 description: Option<Entity<Markdown>>,
265 },
266 Other {
267 description: Entity<Markdown>,
268 },
269}
270
271impl ToolCallConfirmation {
272 pub fn from_acp(
273 confirmation: acp::ToolCallConfirmation,
274 language_registry: Arc<LanguageRegistry>,
275 cx: &mut App,
276 ) -> Self {
277 let to_md = |description: String, cx: &mut App| -> Entity<Markdown> {
278 cx.new(|cx| {
279 Markdown::new(
280 description.into(),
281 Some(language_registry.clone()),
282 None,
283 cx,
284 )
285 })
286 };
287
288 match confirmation {
289 acp::ToolCallConfirmation::Edit { description } => Self::Edit {
290 description: description.map(|description| to_md(description, cx)),
291 },
292 acp::ToolCallConfirmation::Execute {
293 command,
294 root_command,
295 description,
296 } => Self::Execute {
297 command,
298 root_command,
299 description: description.map(|description| to_md(description, cx)),
300 },
301 acp::ToolCallConfirmation::Mcp {
302 server_name,
303 tool_name,
304 tool_display_name,
305 description,
306 } => Self::Mcp {
307 server_name,
308 tool_name,
309 tool_display_name,
310 description: description.map(|description| to_md(description, cx)),
311 },
312 acp::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch {
313 urls: urls.iter().map(|url| url.into()).collect(),
314 description: description.map(|description| to_md(description, cx)),
315 },
316 acp::ToolCallConfirmation::Other { description } => Self::Other {
317 description: to_md(description, cx),
318 },
319 }
320 }
321}
322
323#[derive(Debug)]
324pub enum ToolCallContent {
325 Markdown { markdown: Entity<Markdown> },
326 Diff { diff: Diff },
327}
328
329impl ToolCallContent {
330 pub fn from_acp(
331 content: acp::ToolCallContent,
332 language_registry: Arc<LanguageRegistry>,
333 cx: &mut App,
334 ) -> Self {
335 match content {
336 acp::ToolCallContent::Markdown { markdown } => Self::Markdown {
337 markdown: cx.new(|cx| Markdown::new_text(markdown.into(), cx)),
338 },
339 acp::ToolCallContent::Diff { diff } => Self::Diff {
340 diff: Diff::from_acp(diff, language_registry, cx),
341 },
342 }
343 }
344
345 fn to_markdown(&self, cx: &App) -> String {
346 match self {
347 Self::Markdown { markdown } => markdown.read(cx).source().to_string(),
348 Self::Diff { diff } => diff.to_markdown(cx),
349 }
350 }
351}
352
/// A file change reported by a tool call, displayed as a read-only
/// multibuffer of the changed hunks of the new text against the old text.
#[derive(Debug)]
pub struct Diff {
    /// Read-only multibuffer that the background task fills with excerpts
    /// around each hunk of `new_buffer`.
    pub multibuffer: Entity<MultiBuffer>,
    /// Path of the changed file as reported by the agent.
    pub path: PathBuf,
    /// Local buffer holding the new text.
    pub new_buffer: Entity<Buffer>,
    /// Local buffer holding the old text (empty when the agent sent none).
    pub old_buffer: Entity<Buffer>,
    /// Held so the diff/excerpt task spawned in `from_acp` is not dropped.
    _task: Task<Result<()>>,
}

impl Diff {
    /// Builds a [`Diff`] from an ACP diff payload.
    ///
    /// An `old_text` of `None` is treated as an empty file. The buffers and
    /// the (initially empty) multibuffer are created synchronously. A spawned
    /// task then:
    /// 1. waits for the diff to be computed;
    /// 2. adds an excerpt for each hunk of the new buffer and attaches the diff;
    /// 3. sets the new buffer's language based on `path`.
    pub fn from_acp(
        diff: acp::Diff,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut App,
    ) -> Self {
        let acp::Diff {
            path,
            old_text,
            new_text,
        } = diff;

        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));

        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
        // The diff tracks the new buffer, with the old text installed as its base.
        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
        let diff_task = buffer_diff.update(cx, |diff, cx| {
            diff.set_base_text(
                old_buffer_snapshot,
                Some(language_registry.clone()),
                new_buffer_snapshot,
                cx,
            )
        });

        let task = cx.spawn({
            let multibuffer = multibuffer.clone();
            let path = path.clone();
            let new_buffer = new_buffer.clone();
            async move |cx| {
                diff_task.await?;

                // A failure here is only logged, so the language assignment
                // below still runs.
                multibuffer
                    .update(cx, |multibuffer, cx| {
                        // Point ranges (in the new buffer) of every hunk.
                        let hunk_ranges = {
                            let buffer = new_buffer.read(cx);
                            let diff = buffer_diff.read(cx);
                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
                                .collect::<Vec<_>>()
                        };

                        multibuffer.set_excerpts_for_path(
                            PathKey::for_buffer(&new_buffer, cx),
                            new_buffer.clone(),
                            hunk_ranges,
                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
                            cx,
                        );
                        multibuffer.add_diff(buffer_diff.clone(), cx);
                    })
                    .log_err();

                // Only the new buffer is assigned a language here.
                if let Some(language) = language_registry
                    .language_for_file_path(&path)
                    .await
                    .log_err()
                {
                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
                }

                anyhow::Ok(())
            }
        });

        Self {
            multibuffer,
            path,
            new_buffer,
            old_buffer,
            _task: task,
        }
    }

    /// Renders the path followed by the full text of every buffer in the
    /// multibuffer, inside a fenced code block.
    ///
    /// This is the buffer contents, not a unified diff.
    fn to_markdown(&self, cx: &App) -> String {
        let buffer_text = self
            .multibuffer
            .read(cx)
            .all_buffers()
            .iter()
            .map(|buffer| buffer.read(cx).text())
            .join("\n");
        format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
    }
}
450
/// A conversation with an external agent process speaking the Agentic Coding
/// Protocol (ACP).
pub struct AcpThread {
    /// Transcript entries in order. Tool call ids are indices into this vector.
    entries: Vec<AgentThreadEntry>,
    title: SharedString,
    project: Entity<Project>,
    /// Records buffers the agent reads and edits.
    action_log: Entity<ActionLog>,
    /// Snapshot of each buffer as last returned to the agent by
    /// `read_text_file`. `write_text_file` diffs the agent's new content
    /// against it.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// The in-flight `send` request; `Some` while the thread is generating.
    send_task: Option<Task<()>>,
    connection: Arc<acp::AgentConnection>,
    /// Resolves when the agent process exits. It is taken by `child_status()`
    /// and is `None` for threads built with `fake`.
    child_status: Option<Task<Result<()>>>,
    /// Drives the connection's I/O; held so it is not dropped.
    _io_task: Task<()>,
}
462
463pub enum AcpThreadEvent {
464 NewEntry,
465 EntryUpdated(usize),
466}
467
468impl EventEmitter<AcpThreadEvent> for AcpThread {}
469
/// Coarse state of a thread, as reported by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No request is in flight.
    Idle,
    /// A request is in flight and the current turn has a tool call awaiting
    /// user confirmation.
    WaitingForToolConfirmation,
    /// A request is in flight and nothing is awaiting confirmation.
    Generating,
}
476
477#[derive(Debug, Clone)]
478pub enum LoadError {
479 Unsupported { current_version: SharedString },
480 Exited(i32),
481 Other(SharedString),
482}
483
484impl Display for LoadError {
485 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
486 match self {
487 LoadError::Unsupported { current_version } => {
488 write!(
489 f,
490 "Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).",
491 current_version
492 )
493 }
494 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
495 LoadError::Other(msg) => write!(f, "{}", msg),
496 }
497 }
498}
499
500impl Error for LoadError {}
501
502impl AcpThread {
503 pub async fn spawn(
504 server: impl AgentServer + 'static,
505 root_dir: &Path,
506 project: Entity<Project>,
507 cx: &mut AsyncApp,
508 ) -> Result<Entity<Self>> {
509 let command = match server.command(&project, cx).await {
510 Ok(command) => command,
511 Err(e) => return Err(anyhow!(LoadError::Other(format!("{e}").into()))),
512 };
513
514 let mut child = util::command::new_smol_command(&command.path)
515 .args(command.args.iter())
516 .current_dir(root_dir)
517 .stdin(std::process::Stdio::piped())
518 .stdout(std::process::Stdio::piped())
519 .stderr(std::process::Stdio::inherit())
520 .kill_on_drop(true)
521 .spawn()?;
522
523 let stdin = child.stdin.take().unwrap();
524 let stdout = child.stdout.take().unwrap();
525
526 cx.new(|cx| {
527 let foreground_executor = cx.foreground_executor().clone();
528
529 let (connection, io_fut) = acp::AgentConnection::connect_to_agent(
530 AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
531 stdin,
532 stdout,
533 move |fut| foreground_executor.spawn(fut).detach(),
534 );
535
536 let io_task = cx.background_spawn(async move {
537 io_fut.await.log_err();
538 });
539
540 let child_status = cx.background_spawn(async move {
541 match child.status().await {
542 Err(e) => Err(anyhow!(e)),
543 Ok(result) if result.success() => Ok(()),
544 Ok(result) => {
545 if let Some(version) = server.version(&command).await.log_err()
546 && !version.supported
547 {
548 Err(anyhow!(LoadError::Unsupported {
549 current_version: version.current_version
550 }))
551 } else {
552 Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
553 }
554 }
555 }
556 });
557
558 let action_log = cx.new(|_| ActionLog::new(project.clone()));
559
560 Self {
561 action_log,
562 shared_buffers: Default::default(),
563 entries: Default::default(),
564 title: "ACP Thread".into(),
565 project,
566 send_task: None,
567 connection: Arc::new(connection),
568 child_status: Some(child_status),
569 _io_task: io_task,
570 }
571 })
572 }
573
574 pub fn action_log(&self) -> &Entity<ActionLog> {
575 &self.action_log
576 }
577
578 pub fn project(&self) -> &Entity<Project> {
579 &self.project
580 }
581
582 #[cfg(test)]
583 pub fn fake(
584 stdin: async_pipe::PipeWriter,
585 stdout: async_pipe::PipeReader,
586 project: Entity<Project>,
587 cx: &mut Context<Self>,
588 ) -> Self {
589 let foreground_executor = cx.foreground_executor().clone();
590
591 let (connection, io_fut) = acp::AgentConnection::connect_to_agent(
592 AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
593 stdin,
594 stdout,
595 move |fut| {
596 foreground_executor.spawn(fut).detach();
597 },
598 );
599
600 let io_task = cx.background_spawn({
601 async move {
602 io_fut.await.log_err();
603 }
604 });
605
606 let action_log = cx.new(|_| ActionLog::new(project.clone()));
607
608 Self {
609 action_log,
610 shared_buffers: Default::default(),
611 entries: Default::default(),
612 title: "ACP Thread".into(),
613 project,
614 send_task: None,
615 connection: Arc::new(connection),
616 child_status: None,
617 _io_task: io_task,
618 }
619 }
620
621 pub fn title(&self) -> SharedString {
622 self.title.clone()
623 }
624
625 pub fn entries(&self) -> &[AgentThreadEntry] {
626 &self.entries
627 }
628
629 pub fn status(&self) -> ThreadStatus {
630 if self.send_task.is_some() {
631 if self.waiting_for_tool_confirmation() {
632 ThreadStatus::WaitingForToolConfirmation
633 } else {
634 ThreadStatus::Generating
635 }
636 } else {
637 ThreadStatus::Idle
638 }
639 }
640
641 pub fn has_pending_edit_tool_calls(&self) -> bool {
642 for entry in self.entries.iter().rev() {
643 match entry {
644 AgentThreadEntry::UserMessage(_) => return false,
645 AgentThreadEntry::ToolCall(ToolCall {
646 status:
647 ToolCallStatus::Allowed {
648 status: acp::ToolCallStatus::Running,
649 ..
650 },
651 content: Some(ToolCallContent::Diff { .. }),
652 ..
653 }) => return true,
654 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
655 }
656 }
657
658 false
659 }
660
661 pub fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
662 self.entries.push(entry);
663 cx.emit(AcpThreadEvent::NewEntry);
664 }
665
666 pub fn push_assistant_chunk(
667 &mut self,
668 chunk: acp::AssistantMessageChunk,
669 cx: &mut Context<Self>,
670 ) {
671 let entries_len = self.entries.len();
672 if let Some(last_entry) = self.entries.last_mut()
673 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
674 {
675 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
676
677 match (chunks.last_mut(), &chunk) {
678 (
679 Some(AssistantMessageChunk::Text { chunk: old_chunk }),
680 acp::AssistantMessageChunk::Text { text: new_chunk },
681 )
682 | (
683 Some(AssistantMessageChunk::Thought { chunk: old_chunk }),
684 acp::AssistantMessageChunk::Thought { thought: new_chunk },
685 ) => {
686 old_chunk.update(cx, |old_chunk, cx| {
687 old_chunk.append(&new_chunk, cx);
688 });
689 }
690 _ => {
691 chunks.push(AssistantMessageChunk::from_acp(
692 chunk,
693 self.project.read(cx).languages().clone(),
694 cx,
695 ));
696 }
697 }
698 } else {
699 let chunk = AssistantMessageChunk::from_acp(
700 chunk,
701 self.project.read(cx).languages().clone(),
702 cx,
703 );
704
705 self.push_entry(
706 AgentThreadEntry::AssistantMessage(AssistantMessage {
707 chunks: vec![chunk],
708 }),
709 cx,
710 );
711 }
712 }
713
714 pub fn request_tool_call(
715 &mut self,
716 tool_call: acp::RequestToolCallConfirmationParams,
717 cx: &mut Context<Self>,
718 ) -> ToolCallRequest {
719 let (tx, rx) = oneshot::channel();
720
721 let status = ToolCallStatus::WaitingForConfirmation {
722 confirmation: ToolCallConfirmation::from_acp(
723 tool_call.confirmation,
724 self.project.read(cx).languages().clone(),
725 cx,
726 ),
727 respond_tx: tx,
728 };
729
730 let id = self.insert_tool_call(tool_call.tool_call, status, cx);
731 ToolCallRequest { id, outcome: rx }
732 }
733
734 pub fn push_tool_call(
735 &mut self,
736 request: acp::PushToolCallParams,
737 cx: &mut Context<Self>,
738 ) -> acp::ToolCallId {
739 let status = ToolCallStatus::Allowed {
740 status: acp::ToolCallStatus::Running,
741 };
742
743 self.insert_tool_call(request, status, cx)
744 }
745
746 fn insert_tool_call(
747 &mut self,
748 tool_call: acp::PushToolCallParams,
749 status: ToolCallStatus,
750 cx: &mut Context<Self>,
751 ) -> acp::ToolCallId {
752 let language_registry = self.project.read(cx).languages().clone();
753 let id = acp::ToolCallId(self.entries.len() as u64);
754 let call = ToolCall {
755 id,
756 label: cx.new(|cx| {
757 Markdown::new(
758 tool_call.label.into(),
759 Some(language_registry.clone()),
760 None,
761 cx,
762 )
763 }),
764 icon: acp_icon_to_ui_icon(tool_call.icon),
765 content: tool_call
766 .content
767 .map(|content| ToolCallContent::from_acp(content, language_registry, cx)),
768 locations: tool_call.locations,
769 status,
770 };
771
772 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
773
774 id
775 }
776
777 pub fn authorize_tool_call(
778 &mut self,
779 id: acp::ToolCallId,
780 outcome: acp::ToolCallConfirmationOutcome,
781 cx: &mut Context<Self>,
782 ) {
783 let Some((ix, call)) = self.tool_call_mut(id) else {
784 return;
785 };
786
787 let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
788 ToolCallStatus::Rejected
789 } else {
790 ToolCallStatus::Allowed {
791 status: acp::ToolCallStatus::Running,
792 }
793 };
794
795 let curr_status = mem::replace(&mut call.status, new_status);
796
797 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
798 respond_tx.send(outcome).log_err();
799 } else if cfg!(debug_assertions) {
800 panic!("tried to authorize an already authorized tool call");
801 }
802
803 cx.emit(AcpThreadEvent::EntryUpdated(ix));
804 }
805
806 pub fn update_tool_call(
807 &mut self,
808 id: acp::ToolCallId,
809 new_status: acp::ToolCallStatus,
810 new_content: Option<acp::ToolCallContent>,
811 cx: &mut Context<Self>,
812 ) -> Result<()> {
813 let language_registry = self.project.read(cx).languages().clone();
814 let (ix, call) = self.tool_call_mut(id).context("Entry not found")?;
815
816 call.content = new_content
817 .map(|new_content| ToolCallContent::from_acp(new_content, language_registry, cx));
818
819 match &mut call.status {
820 ToolCallStatus::Allowed { status } => {
821 *status = new_status;
822 }
823 ToolCallStatus::WaitingForConfirmation { .. } => {
824 anyhow::bail!("Tool call hasn't been authorized yet")
825 }
826 ToolCallStatus::Rejected => {
827 anyhow::bail!("Tool call was rejected and therefore can't be updated")
828 }
829 ToolCallStatus::Canceled => {
830 call.status = ToolCallStatus::Allowed { status: new_status };
831 }
832 }
833
834 cx.emit(AcpThreadEvent::EntryUpdated(ix));
835 Ok(())
836 }
837
838 fn tool_call_mut(&mut self, id: acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
839 let entry = self.entries.get_mut(id.0 as usize);
840 debug_assert!(
841 entry.is_some(),
842 "We shouldn't give out ids to entries that don't exist"
843 );
844 match entry {
845 Some(AgentThreadEntry::ToolCall(call)) if call.id == id => Some((id.0 as usize, call)),
846 _ => {
847 if cfg!(debug_assertions) {
848 panic!("entry is not a tool call");
849 }
850 None
851 }
852 }
853 }
854
855 /// Returns true if the last turn is awaiting tool authorization
856 pub fn waiting_for_tool_confirmation(&self) -> bool {
857 for entry in self.entries.iter().rev() {
858 match &entry {
859 AgentThreadEntry::ToolCall(call) => match call.status {
860 ToolCallStatus::WaitingForConfirmation { .. } => return true,
861 ToolCallStatus::Allowed { .. }
862 | ToolCallStatus::Rejected
863 | ToolCallStatus::Canceled => continue,
864 },
865 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
866 // Reached the beginning of the turn
867 return false;
868 }
869 }
870 }
871 false
872 }
873
874 pub fn initialize(
875 &self,
876 ) -> impl use<> + Future<Output = Result<acp::InitializeResponse, acp::Error>> {
877 let connection = self.connection.clone();
878 async move { connection.initialize().await }
879 }
880
881 pub fn authenticate(&self) -> impl use<> + Future<Output = Result<(), acp::Error>> {
882 let connection = self.connection.clone();
883 async move { connection.request(acp::AuthenticateParams).await }
884 }
885
886 #[cfg(test)]
887 pub fn send_raw(
888 &mut self,
889 message: &str,
890 cx: &mut Context<Self>,
891 ) -> BoxFuture<'static, Result<(), acp::Error>> {
892 self.send(
893 acp::SendUserMessageParams {
894 chunks: vec![acp::UserMessageChunk::Text {
895 text: message.to_string(),
896 }],
897 },
898 cx,
899 )
900 }
901
902 pub fn send(
903 &mut self,
904 message: acp::SendUserMessageParams,
905 cx: &mut Context<Self>,
906 ) -> BoxFuture<'static, Result<(), acp::Error>> {
907 let agent = self.connection.clone();
908 self.push_entry(
909 AgentThreadEntry::UserMessage(UserMessage::from_acp(
910 &message,
911 self.project.read(cx).languages().clone(),
912 cx,
913 )),
914 cx,
915 );
916
917 let (tx, rx) = oneshot::channel();
918 let cancel = self.cancel(cx);
919
920 self.send_task = Some(cx.spawn(async move |this, cx| {
921 cancel.await.log_err();
922
923 let result = agent.request(message).await;
924 tx.send(result).log_err();
925 this.update(cx, |this, _cx| this.send_task.take()).log_err();
926 }));
927
928 async move {
929 match rx.await {
930 Ok(Err(e)) => Err(e)?,
931 _ => Ok(()),
932 }
933 }
934 .boxed()
935 }
936
937 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<(), acp::Error>> {
938 let agent = self.connection.clone();
939
940 if self.send_task.take().is_some() {
941 cx.spawn(async move |this, cx| {
942 agent.request(acp::CancelSendMessageParams).await?;
943
944 this.update(cx, |this, _cx| {
945 for entry in this.entries.iter_mut() {
946 if let AgentThreadEntry::ToolCall(call) = entry {
947 let cancel = matches!(
948 call.status,
949 ToolCallStatus::WaitingForConfirmation { .. }
950 | ToolCallStatus::Allowed {
951 status: acp::ToolCallStatus::Running
952 }
953 );
954
955 if cancel {
956 let curr_status =
957 mem::replace(&mut call.status, ToolCallStatus::Canceled);
958
959 if let ToolCallStatus::WaitingForConfirmation {
960 respond_tx, ..
961 } = curr_status
962 {
963 respond_tx
964 .send(acp::ToolCallConfirmationOutcome::Cancel)
965 .ok();
966 }
967 }
968 }
969 }
970 })?;
971 Ok(())
972 })
973 } else {
974 Task::ready(Ok(()))
975 }
976 }
977
978 pub fn read_text_file(
979 &self,
980 request: acp::ReadTextFileParams,
981 cx: &mut Context<Self>,
982 ) -> Task<Result<String>> {
983 let project = self.project.clone();
984 let action_log = self.action_log.clone();
985 cx.spawn(async move |this, cx| {
986 let load = project.update(cx, |project, cx| {
987 let path = project
988 .project_path_for_absolute_path(&request.path, cx)
989 .context("invalid path")?;
990 anyhow::Ok(project.open_buffer(path, cx))
991 });
992 let buffer = load??.await?;
993
994 action_log.update(cx, |action_log, cx| {
995 action_log.buffer_read(buffer.clone(), cx);
996 })?;
997 project.update(cx, |project, cx| {
998 let position = buffer
999 .read(cx)
1000 .snapshot()
1001 .anchor_before(Point::new(request.line.unwrap_or_default(), 0));
1002 project.set_agent_location(
1003 Some(AgentLocation {
1004 buffer: buffer.downgrade(),
1005 position,
1006 }),
1007 cx,
1008 );
1009 })?;
1010 let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
1011 this.update(cx, |this, _| {
1012 let text = snapshot.text();
1013 this.shared_buffers.insert(buffer.clone(), snapshot);
1014 text
1015 })
1016 })
1017 }
1018
1019 pub fn write_text_file(
1020 &self,
1021 path: PathBuf,
1022 content: String,
1023 cx: &mut Context<Self>,
1024 ) -> Task<Result<()>> {
1025 let project = self.project.clone();
1026 let action_log = self.action_log.clone();
1027 cx.spawn(async move |this, cx| {
1028 let load = project.update(cx, |project, cx| {
1029 let path = project
1030 .project_path_for_absolute_path(&path, cx)
1031 .context("invalid path")?;
1032 anyhow::Ok(project.open_buffer(path, cx))
1033 });
1034 let buffer = load??.await?;
1035 let snapshot = this.update(cx, |this, cx| {
1036 this.shared_buffers
1037 .get(&buffer)
1038 .cloned()
1039 .unwrap_or_else(|| buffer.read(cx).snapshot())
1040 })?;
1041 let edits = cx
1042 .background_executor()
1043 .spawn(async move {
1044 let old_text = snapshot.text();
1045 text_diff(old_text.as_str(), &content)
1046 .into_iter()
1047 .map(|(range, replacement)| {
1048 (
1049 snapshot.anchor_after(range.start)
1050 ..snapshot.anchor_before(range.end),
1051 replacement,
1052 )
1053 })
1054 .collect::<Vec<_>>()
1055 })
1056 .await;
1057 cx.update(|cx| {
1058 project.update(cx, |project, cx| {
1059 project.set_agent_location(
1060 Some(AgentLocation {
1061 buffer: buffer.downgrade(),
1062 position: edits
1063 .last()
1064 .map(|(range, _)| range.end)
1065 .unwrap_or(Anchor::MIN),
1066 }),
1067 cx,
1068 );
1069 });
1070
1071 action_log.update(cx, |action_log, cx| {
1072 action_log.buffer_read(buffer.clone(), cx);
1073 });
1074 buffer.update(cx, |buffer, cx| {
1075 buffer.edit(edits, None, cx);
1076 });
1077 action_log.update(cx, |action_log, cx| {
1078 action_log.buffer_edited(buffer.clone(), cx);
1079 });
1080 })?;
1081 project
1082 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1083 .await
1084 })
1085 }
1086
1087 pub fn child_status(&mut self) -> Option<Task<Result<()>>> {
1088 self.child_status.take()
1089 }
1090
1091 pub fn to_markdown(&self, cx: &App) -> String {
1092 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1093 }
1094}
1095
1096struct AcpClientDelegate {
1097 thread: WeakEntity<AcpThread>,
1098 cx: AsyncApp,
1099 // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
1100}
1101
1102impl AcpClientDelegate {
1103 fn new(thread: WeakEntity<AcpThread>, cx: AsyncApp) -> Self {
1104 Self { thread, cx }
1105 }
1106}
1107
1108impl acp::Client for AcpClientDelegate {
1109 async fn stream_assistant_message_chunk(
1110 &self,
1111 params: acp::StreamAssistantMessageChunkParams,
1112 ) -> Result<(), acp::Error> {
1113 let cx = &mut self.cx.clone();
1114
1115 cx.update(|cx| {
1116 self.thread
1117 .update(cx, |thread, cx| {
1118 thread.push_assistant_chunk(params.chunk, cx)
1119 })
1120 .ok();
1121 })?;
1122
1123 Ok(())
1124 }
1125
1126 async fn request_tool_call_confirmation(
1127 &self,
1128 request: acp::RequestToolCallConfirmationParams,
1129 ) -> Result<acp::RequestToolCallConfirmationResponse, acp::Error> {
1130 let cx = &mut self.cx.clone();
1131 let ToolCallRequest { id, outcome } = cx
1132 .update(|cx| {
1133 self.thread
1134 .update(cx, |thread, cx| thread.request_tool_call(request, cx))
1135 })?
1136 .context("Failed to update thread")?;
1137
1138 Ok(acp::RequestToolCallConfirmationResponse {
1139 id,
1140 outcome: outcome.await.map_err(acp::Error::into_internal_error)?,
1141 })
1142 }
1143
1144 async fn push_tool_call(
1145 &self,
1146 request: acp::PushToolCallParams,
1147 ) -> Result<acp::PushToolCallResponse, acp::Error> {
1148 let cx = &mut self.cx.clone();
1149 let id = cx
1150 .update(|cx| {
1151 self.thread
1152 .update(cx, |thread, cx| thread.push_tool_call(request, cx))
1153 })?
1154 .context("Failed to update thread")?;
1155
1156 Ok(acp::PushToolCallResponse { id })
1157 }
1158
1159 async fn update_tool_call(&self, request: acp::UpdateToolCallParams) -> Result<(), acp::Error> {
1160 let cx = &mut self.cx.clone();
1161
1162 cx.update(|cx| {
1163 self.thread.update(cx, |thread, cx| {
1164 thread.update_tool_call(request.tool_call_id, request.status, request.content, cx)
1165 })
1166 })?
1167 .context("Failed to update thread")??;
1168
1169 Ok(())
1170 }
1171
1172 async fn read_text_file(
1173 &self,
1174 request: acp::ReadTextFileParams,
1175 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
1176 let content = self
1177 .cx
1178 .update(|cx| {
1179 self.thread
1180 .update(cx, |thread, cx| thread.read_text_file(request, cx))
1181 })?
1182 .context("Failed to update thread")?
1183 .await?;
1184 Ok(acp::ReadTextFileResponse { content })
1185 }
1186
1187 async fn write_text_file(&self, request: acp::WriteTextFileParams) -> Result<(), acp::Error> {
1188 self.cx
1189 .update(|cx| {
1190 self.thread.update(cx, |thread, cx| {
1191 thread.write_text_file(request.path, request.content, cx)
1192 })
1193 })?
1194 .context("Failed to update thread")?
1195 .await?;
1196
1197 Ok(())
1198 }
1199}
1200
1201fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
1202 match icon {
1203 acp::Icon::FileSearch => IconName::ToolSearch,
1204 acp::Icon::Folder => IconName::ToolFolder,
1205 acp::Icon::Globe => IconName::ToolWeb,
1206 acp::Icon::Hammer => IconName::ToolHammer,
1207 acp::Icon::LightBulb => IconName::ToolBulb,
1208 acp::Icon::Pencil => IconName::ToolPencil,
1209 acp::Icon::Regex => IconName::ToolRegex,
1210 acp::Icon::Terminal => IconName::ToolTerminal,
1211 }
1212}
1213
1214pub struct ToolCallRequest {
1215 pub id: acp::ToolCallId,
1216 pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
1217}
1218
1219#[cfg(test)]
1220mod tests {
1221 use super::*;
1222 use agent_servers::{AgentServerCommand, AgentServerVersion};
1223 use async_pipe::{PipeReader, PipeWriter};
1224 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1225 use gpui::{AsyncApp, TestAppContext};
1226 use indoc::indoc;
1227 use project::FakeFs;
1228 use serde_json::json;
1229 use settings::SettingsStore;
1230 use smol::{future::BoxedLocal, stream::StreamExt as _};
1231 use std::{cell::RefCell, env, path::Path, rc::Rc, time::Duration};
1232 use util::path;
1233
1234 fn init_test(cx: &mut TestAppContext) {
1235 env_logger::try_init().ok();
1236 cx.update(|cx| {
1237 let settings_store = SettingsStore::test(cx);
1238 cx.set_global(settings_store);
1239 Project::init_settings(cx);
1240 language::init(cx);
1241 });
1242 }
1243
1244 #[gpui::test]
1245 async fn test_thinking_concatenation(cx: &mut TestAppContext) {
1246 init_test(cx);
1247
1248 let fs = FakeFs::new(cx.executor());
1249 let project = Project::test(fs, [], cx).await;
1250 let (thread, fake_server) = fake_acp_thread(project, cx);
1251
1252 fake_server.update(cx, |fake_server, _| {
1253 fake_server.on_user_message(move |_, server, mut cx| async move {
1254 server
1255 .update(&mut cx, |server, _| {
1256 server.send_to_zed(acp::StreamAssistantMessageChunkParams {
1257 chunk: acp::AssistantMessageChunk::Thought {
1258 thought: "Thinking ".into(),
1259 },
1260 })
1261 })?
1262 .await
1263 .unwrap();
1264 server
1265 .update(&mut cx, |server, _| {
1266 server.send_to_zed(acp::StreamAssistantMessageChunkParams {
1267 chunk: acp::AssistantMessageChunk::Thought {
1268 thought: "hard!".into(),
1269 },
1270 })
1271 })?
1272 .await
1273 .unwrap();
1274
1275 Ok(())
1276 })
1277 });
1278
1279 thread
1280 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1281 .await
1282 .unwrap();
1283
1284 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1285 assert_eq!(
1286 output,
1287 indoc! {r#"
1288 ## User
1289
1290 Hello from Zed!
1291
1292 ## Assistant
1293
1294 <thinking>
1295 Thinking hard!
1296 </thinking>
1297
1298 "#}
1299 );
1300 }
1301
    /// An agent write that happens after a concurrent user edit must keep both
    /// changes, in the open buffer as well as on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project.clone(), cx);
        // Open the file the agent will touch so the test can edit it as the user.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        // Signalled by the fake agent after its read, so the user edit below lands
        // between the agent's read and its write.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        // The handler is `Fn`, so the one-shot sender is moved out of a shared cell.
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));

        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    // 1. Read the original contents through Zed.
                    let content = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp::ReadTextFileParams {
                                path: path!("/tmp/foo").into(),
                                line: None,
                                limit: None,
                            })
                        })?
                        .await
                        .unwrap();
                    assert_eq!(content.content, "one\ntwo\nthree\n");
                    // 2. Let the test perform its user edit.
                    read_file_tx.take().unwrap().send(()).unwrap();
                    // 3. Write an extended version based on the (now stale) read.
                    server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp::WriteTextFileParams {
                                path: path!("/tmp/foo").into(),
                                content: "one\ntwo\nthree\nfour\nfive\n".to_string(),
                            })
                        })?
                        .await
                        .unwrap();
                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // User edit made after the agent's read but before its write.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        // Both the user's prepended line and the agent's appended lines survive...
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        // ...and the merged result has been saved to disk.
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1375
1376 #[gpui::test]
1377 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
1378 init_test(cx);
1379
1380 let fs = FakeFs::new(cx.executor());
1381 let project = Project::test(fs, [], cx).await;
1382 let (thread, fake_server) = fake_acp_thread(project, cx);
1383
1384 let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>();
1385
1386 let tool_call_id = Rc::new(RefCell::new(None));
1387 let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx)));
1388 fake_server.update(cx, |fake_server, _| {
1389 let tool_call_id = tool_call_id.clone();
1390 fake_server.on_user_message(move |_, server, mut cx| {
1391 let end_turn_rx = end_turn_rx.clone();
1392 let tool_call_id = tool_call_id.clone();
1393 async move {
1394 let tool_call_result = server
1395 .update(&mut cx, |server, _| {
1396 server.send_to_zed(acp::PushToolCallParams {
1397 label: "Fetch".to_string(),
1398 icon: acp::Icon::Globe,
1399 content: None,
1400 locations: vec![],
1401 })
1402 })?
1403 .await
1404 .unwrap();
1405 *tool_call_id.clone().borrow_mut() = Some(tool_call_result.id);
1406 end_turn_rx.take().unwrap().await.ok();
1407
1408 Ok(())
1409 }
1410 })
1411 });
1412
1413 let request = thread.update(cx, |thread, cx| {
1414 thread.send_raw("Fetch https://example.com", cx)
1415 });
1416
1417 run_until_first_tool_call(&thread, cx).await;
1418
1419 thread.read_with(cx, |thread, _| {
1420 assert!(matches!(
1421 thread.entries[1],
1422 AgentThreadEntry::ToolCall(ToolCall {
1423 status: ToolCallStatus::Allowed {
1424 status: acp::ToolCallStatus::Running,
1425 ..
1426 },
1427 ..
1428 })
1429 ));
1430 });
1431
1432 cx.run_until_parked();
1433
1434 thread
1435 .update(cx, |thread, cx| thread.cancel(cx))
1436 .await
1437 .unwrap();
1438
1439 thread.read_with(cx, |thread, _| {
1440 assert!(matches!(
1441 &thread.entries[1],
1442 AgentThreadEntry::ToolCall(ToolCall {
1443 status: ToolCallStatus::Canceled,
1444 ..
1445 })
1446 ));
1447 });
1448
1449 fake_server
1450 .update(cx, |fake_server, _| {
1451 fake_server.send_to_zed(acp::UpdateToolCallParams {
1452 tool_call_id: tool_call_id.borrow().unwrap(),
1453 status: acp::ToolCallStatus::Finished,
1454 content: None,
1455 })
1456 })
1457 .await
1458 .unwrap();
1459
1460 drop(end_turn_tx);
1461 request.await.unwrap();
1462
1463 thread.read_with(cx, |thread, _| {
1464 assert!(matches!(
1465 thread.entries[1],
1466 AgentThreadEntry::ToolCall(ToolCall {
1467 status: ToolCallStatus::Allowed {
1468 status: acp::ToolCallStatus::Finished,
1469 ..
1470 },
1471 ..
1472 })
1473 ));
1474 });
1475 }
1476
1477 #[gpui::test]
1478 #[cfg_attr(not(feature = "gemini"), ignore)]
1479 async fn test_gemini_basic(cx: &mut TestAppContext) {
1480 init_test(cx);
1481
1482 cx.executor().allow_parking();
1483
1484 let fs = FakeFs::new(cx.executor());
1485 let project = Project::test(fs, [], cx).await;
1486 let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
1487 thread
1488 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1489 .await
1490 .unwrap();
1491
1492 thread.read_with(cx, |thread, _| {
1493 assert_eq!(thread.entries.len(), 2);
1494 assert!(matches!(
1495 thread.entries[0],
1496 AgentThreadEntry::UserMessage(_)
1497 ));
1498 assert!(matches!(
1499 thread.entries[1],
1500 AgentThreadEntry::AssistantMessage(_)
1501 ));
1502 });
1503 }
1504
1505 #[gpui::test]
1506 #[cfg_attr(not(feature = "gemini"), ignore)]
1507 async fn test_gemini_path_mentions(cx: &mut TestAppContext) {
1508 init_test(cx);
1509
1510 cx.executor().allow_parking();
1511 let tempdir = tempfile::tempdir().unwrap();
1512 std::fs::write(
1513 tempdir.path().join("foo.rs"),
1514 indoc! {"
1515 fn main() {
1516 println!(\"Hello, world!\");
1517 }
1518 "},
1519 )
1520 .expect("failed to write file");
1521 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
1522 let thread = gemini_acp_thread(project.clone(), tempdir.path(), cx).await;
1523 thread
1524 .update(cx, |thread, cx| {
1525 thread.send(
1526 acp::SendUserMessageParams {
1527 chunks: vec![
1528 acp::UserMessageChunk::Text {
1529 text: "Read the file ".into(),
1530 },
1531 acp::UserMessageChunk::Path {
1532 path: Path::new("foo.rs").into(),
1533 },
1534 acp::UserMessageChunk::Text {
1535 text: " and tell me what the content of the println! is".into(),
1536 },
1537 ],
1538 },
1539 cx,
1540 )
1541 })
1542 .await
1543 .unwrap();
1544
1545 thread.read_with(cx, |thread, cx| {
1546 assert_eq!(thread.entries.len(), 3);
1547 assert!(matches!(
1548 thread.entries[0],
1549 AgentThreadEntry::UserMessage(_)
1550 ));
1551 assert!(matches!(thread.entries[1], AgentThreadEntry::ToolCall(_)));
1552 let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries[2] else {
1553 panic!("Expected AssistantMessage")
1554 };
1555 assert!(
1556 assistant_message.to_markdown(cx).contains("Hello, world!"),
1557 "unexpected assistant message: {:?}",
1558 assistant_message.to_markdown(cx)
1559 );
1560 });
1561 }
1562
1563 #[gpui::test]
1564 #[cfg_attr(not(feature = "gemini"), ignore)]
1565 async fn test_gemini_tool_call(cx: &mut TestAppContext) {
1566 init_test(cx);
1567
1568 cx.executor().allow_parking();
1569
1570 let fs = FakeFs::new(cx.executor());
1571 fs.insert_tree(
1572 path!("/private/tmp"),
1573 json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
1574 )
1575 .await;
1576 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
1577 let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
1578 thread
1579 .update(cx, |thread, cx| {
1580 thread.send_raw(
1581 "Read the '/private/tmp/foo' file and tell me what you see.",
1582 cx,
1583 )
1584 })
1585 .await
1586 .unwrap();
1587 thread.read_with(cx, |thread, _cx| {
1588 assert!(matches!(
1589 &thread.entries()[2],
1590 AgentThreadEntry::ToolCall(ToolCall {
1591 status: ToolCallStatus::Allowed { .. },
1592 ..
1593 })
1594 ));
1595
1596 assert!(matches!(
1597 thread.entries[3],
1598 AgentThreadEntry::AssistantMessage(_)
1599 ));
1600 });
1601 }
1602
1603 #[gpui::test]
1604 #[cfg_attr(not(feature = "gemini"), ignore)]
1605 async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
1606 init_test(cx);
1607
1608 cx.executor().allow_parking();
1609
1610 let fs = FakeFs::new(cx.executor());
1611 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
1612 let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
1613 let full_turn = thread.update(cx, |thread, cx| {
1614 thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
1615 });
1616
1617 run_until_first_tool_call(&thread, cx).await;
1618
1619 let tool_call_id = thread.read_with(cx, |thread, _cx| {
1620 let AgentThreadEntry::ToolCall(ToolCall {
1621 id,
1622 status:
1623 ToolCallStatus::WaitingForConfirmation {
1624 confirmation: ToolCallConfirmation::Execute { root_command, .. },
1625 ..
1626 },
1627 ..
1628 }) = &thread.entries()[2]
1629 else {
1630 panic!();
1631 };
1632
1633 assert_eq!(root_command, "echo");
1634
1635 *id
1636 });
1637
1638 thread.update(cx, |thread, cx| {
1639 thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
1640
1641 assert!(matches!(
1642 &thread.entries()[2],
1643 AgentThreadEntry::ToolCall(ToolCall {
1644 status: ToolCallStatus::Allowed { .. },
1645 ..
1646 })
1647 ));
1648 });
1649
1650 full_turn.await.unwrap();
1651
1652 thread.read_with(cx, |thread, cx| {
1653 let AgentThreadEntry::ToolCall(ToolCall {
1654 content: Some(ToolCallContent::Markdown { markdown }),
1655 status: ToolCallStatus::Allowed { .. },
1656 ..
1657 }) = &thread.entries()[2]
1658 else {
1659 panic!();
1660 };
1661
1662 markdown.read_with(cx, |md, _cx| {
1663 assert!(
1664 md.source().contains("Hello, world!"),
1665 r#"Expected '{}' to contain "Hello, world!""#,
1666 md.source()
1667 );
1668 });
1669 });
1670 }
1671
1672 #[gpui::test]
1673 #[cfg_attr(not(feature = "gemini"), ignore)]
1674 async fn test_gemini_cancel(cx: &mut TestAppContext) {
1675 init_test(cx);
1676
1677 cx.executor().allow_parking();
1678
1679 let fs = FakeFs::new(cx.executor());
1680 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
1681 let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
1682 let full_turn = thread.update(cx, |thread, cx| {
1683 thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
1684 });
1685
1686 let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;
1687
1688 thread.read_with(cx, |thread, _cx| {
1689 let AgentThreadEntry::ToolCall(ToolCall {
1690 id,
1691 status:
1692 ToolCallStatus::WaitingForConfirmation {
1693 confirmation: ToolCallConfirmation::Execute { root_command, .. },
1694 ..
1695 },
1696 ..
1697 }) = &thread.entries()[first_tool_call_ix]
1698 else {
1699 panic!("{:?}", thread.entries()[1]);
1700 };
1701
1702 assert_eq!(root_command, "echo");
1703
1704 *id
1705 });
1706
1707 thread
1708 .update(cx, |thread, cx| thread.cancel(cx))
1709 .await
1710 .unwrap();
1711 full_turn.await.unwrap();
1712 thread.read_with(cx, |thread, _| {
1713 let AgentThreadEntry::ToolCall(ToolCall {
1714 status: ToolCallStatus::Canceled,
1715 ..
1716 }) = &thread.entries()[first_tool_call_ix]
1717 else {
1718 panic!();
1719 };
1720 });
1721
1722 thread
1723 .update(cx, |thread, cx| {
1724 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
1725 })
1726 .await
1727 .unwrap();
1728 thread.read_with(cx, |thread, _| {
1729 assert!(matches!(
1730 &thread.entries().last().unwrap(),
1731 AgentThreadEntry::AssistantMessage(..),
1732 ))
1733 });
1734 }
1735
1736 async fn run_until_first_tool_call(
1737 thread: &Entity<AcpThread>,
1738 cx: &mut TestAppContext,
1739 ) -> usize {
1740 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1741
1742 let subscription = cx.update(|cx| {
1743 cx.subscribe(thread, move |thread, _, cx| {
1744 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1745 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1746 return tx.try_send(ix).unwrap();
1747 }
1748 }
1749 })
1750 });
1751
1752 select! {
1753 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1754 panic!("Timeout waiting for tool call")
1755 }
1756 ix = rx.next().fuse() => {
1757 drop(subscription);
1758 ix.unwrap()
1759 }
1760 }
1761 }
1762
1763 pub async fn gemini_acp_thread(
1764 project: Entity<Project>,
1765 current_dir: impl AsRef<Path>,
1766 cx: &mut TestAppContext,
1767 ) -> Entity<AcpThread> {
1768 struct DevGemini;
1769
1770 impl agent_servers::AgentServer for DevGemini {
1771 async fn command(
1772 &self,
1773 _project: &Entity<Project>,
1774 _cx: &mut AsyncApp,
1775 ) -> Result<agent_servers::AgentServerCommand> {
1776 let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
1777 .join("../../../gemini-cli/packages/cli")
1778 .to_string_lossy()
1779 .to_string();
1780
1781 Ok(AgentServerCommand {
1782 path: "node".into(),
1783 args: vec![cli_path, "--acp".into()],
1784 env: None,
1785 })
1786 }
1787
1788 async fn version(
1789 &self,
1790 _command: &agent_servers::AgentServerCommand,
1791 ) -> Result<AgentServerVersion> {
1792 Ok(AgentServerVersion {
1793 current_version: "0.1.0".into(),
1794 supported: true,
1795 })
1796 }
1797 }
1798
1799 let thread = AcpThread::spawn(DevGemini, current_dir.as_ref(), project, &mut cx.to_async())
1800 .await
1801 .unwrap();
1802
1803 thread
1804 .update(cx, |thread, _| thread.initialize())
1805 .await
1806 .unwrap();
1807 thread
1808 }
1809
1810 pub fn fake_acp_thread(
1811 project: Entity<Project>,
1812 cx: &mut TestAppContext,
1813 ) -> (Entity<AcpThread>, Entity<FakeAcpServer>) {
1814 let (stdin_tx, stdin_rx) = async_pipe::pipe();
1815 let (stdout_tx, stdout_rx) = async_pipe::pipe();
1816 let thread = cx.update(|cx| cx.new(|cx| AcpThread::fake(stdin_tx, stdout_rx, project, cx)));
1817 let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
1818 (thread, agent)
1819 }
1820
1821 pub struct FakeAcpServer {
1822 connection: acp::ClientConnection,
1823 _io_task: Task<()>,
1824 on_user_message: Option<
1825 Rc<
1826 dyn Fn(
1827 acp::SendUserMessageParams,
1828 Entity<FakeAcpServer>,
1829 AsyncApp,
1830 ) -> LocalBoxFuture<'static, Result<(), acp::Error>>,
1831 >,
1832 >,
1833 }
1834
    /// `acp::Agent` implementation served by [`FakeAcpServer`]; user messages are
    /// forwarded to the server's installed `on_user_message` handler.
    #[derive(Clone)]
    struct FakeAgent {
        /// Fake server whose handler answers incoming user messages.
        server: Entity<FakeAcpServer>,
        /// Async context used to read `server` and handed to the handler.
        cx: AsyncApp,
    }
1840
1841 impl acp::Agent for FakeAgent {
1842 async fn initialize(&self) -> Result<acp::InitializeResponse, acp::Error> {
1843 Ok(acp::InitializeResponse {
1844 is_authenticated: true,
1845 })
1846 }
1847
1848 async fn authenticate(&self) -> Result<(), acp::Error> {
1849 Ok(())
1850 }
1851
1852 async fn cancel_send_message(&self) -> Result<(), acp::Error> {
1853 Ok(())
1854 }
1855
1856 async fn send_user_message(
1857 &self,
1858 request: acp::SendUserMessageParams,
1859 ) -> Result<(), acp::Error> {
1860 let mut cx = self.cx.clone();
1861 let handler = self
1862 .server
1863 .update(&mut cx, |server, _| server.on_user_message.clone())
1864 .ok()
1865 .flatten();
1866 if let Some(handler) = handler {
1867 handler(request, self.server.clone(), self.cx.clone()).await
1868 } else {
1869 Err(anyhow::anyhow!("No handler for on_user_message").into())
1870 }
1871 }
1872 }
1873
1874 impl FakeAcpServer {
1875 fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
1876 let agent = FakeAgent {
1877 server: cx.entity(),
1878 cx: cx.to_async(),
1879 };
1880 let foreground_executor = cx.foreground_executor().clone();
1881
1882 let (connection, io_fut) = acp::ClientConnection::connect_to_client(
1883 agent.clone(),
1884 stdout,
1885 stdin,
1886 move |fut| {
1887 foreground_executor.spawn(fut).detach();
1888 },
1889 );
1890 FakeAcpServer {
1891 connection: connection,
1892 on_user_message: None,
1893 _io_task: cx.background_spawn(async move {
1894 io_fut.await.log_err();
1895 }),
1896 }
1897 }
1898
1899 fn on_user_message<F>(
1900 &mut self,
1901 handler: impl for<'a> Fn(acp::SendUserMessageParams, Entity<FakeAcpServer>, AsyncApp) -> F
1902 + 'static,
1903 ) where
1904 F: Future<Output = Result<(), acp::Error>> + 'static,
1905 {
1906 self.on_user_message
1907 .replace(Rc::new(move |request, server, cx| {
1908 handler(request, server, cx).boxed_local()
1909 }));
1910 }
1911
1912 fn send_to_zed<T: acp::ClientRequest + 'static>(
1913 &self,
1914 message: T,
1915 ) -> BoxedLocal<Result<T::Response>> {
1916 self.connection
1917 .request(message)
1918 .map(|f| f.map_err(|err| anyhow!(err)))
1919 .boxed_local()
1920 }
1921 }
1922}