1pub use acp::ToolCallId;
2use agent_servers::AgentServer;
3use agentic_coding_protocol::{self as acp, UserMessageChunk};
4use anyhow::{Context as _, Result, anyhow};
5use assistant_tool::ActionLog;
6use buffer_diff::BufferDiff;
7use editor::{MultiBuffer, PathKey};
8use futures::{FutureExt, channel::oneshot, future::BoxFuture};
9use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
10use itertools::Itertools;
11use language::{
12 Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
13 text_diff,
14};
15use markdown::Markdown;
16use project::{AgentLocation, Project};
17use std::collections::HashMap;
18use std::error::Error;
19use std::fmt::{Formatter, Write};
20use std::{
21 fmt::Display,
22 mem,
23 path::{Path, PathBuf},
24 sync::Arc,
25};
26use ui::{App, IconName};
27use util::ResultExt;
28
29#[derive(Clone, Debug, Eq, PartialEq)]
30pub struct UserMessage {
31 pub content: Entity<Markdown>,
32}
33
34impl UserMessage {
35 pub fn from_acp(
36 message: &acp::SendUserMessageParams,
37 language_registry: Arc<LanguageRegistry>,
38 cx: &mut App,
39 ) -> Self {
40 let mut md_source = String::new();
41
42 for chunk in &message.chunks {
43 match chunk {
44 UserMessageChunk::Text { text } => md_source.push_str(&text),
45 UserMessageChunk::Path { path } => {
46 write!(&mut md_source, "{}", MentionPath(&path)).unwrap()
47 }
48 }
49 }
50
51 Self {
52 content: cx
53 .new(|cx| Markdown::new(md_source.into(), Some(language_registry), None, cx)),
54 }
55 }
56
57 fn to_markdown(&self, cx: &App) -> String {
58 format!("## User\n\n{}\n\n", self.content.read(cx).source())
59 }
60}
61
/// A borrowed file path that displays as a markdown mention link of the form
/// `[@<file name>](@file:<full path>)`.
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// Prefix that marks a link target as a file mention.
    const PREFIX: &'static str = "@file:";

    /// Wraps `path` so it can be displayed as a mention link.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a link target of the form `@file:<path>`.
    ///
    /// Returns `None` when `url` does not start with the mention prefix.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self(Path::new(rest)))
    }

    /// The wrapped path.
    pub fn path(&self) -> &Path {
        self.0
    }
}

impl Display for MentionPath<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A path without a final component yields an empty label.
        let file_name = self.0.file_name().unwrap_or_default();
        let label = file_name.display();
        let target = self.0.display();
        write!(f, "[@{}]({}{})", label, Self::PREFIX, target)
    }
}
93
94#[derive(Clone, Debug, Eq, PartialEq)]
95pub struct AssistantMessage {
96 pub chunks: Vec<AssistantMessageChunk>,
97}
98
99impl AssistantMessage {
100 fn to_markdown(&self, cx: &App) -> String {
101 format!(
102 "## Assistant\n\n{}\n\n",
103 self.chunks
104 .iter()
105 .map(|chunk| chunk.to_markdown(cx))
106 .join("\n\n")
107 )
108 }
109}
110
111#[derive(Clone, Debug, Eq, PartialEq)]
112pub enum AssistantMessageChunk {
113 Text { chunk: Entity<Markdown> },
114 Thought { chunk: Entity<Markdown> },
115}
116
117impl AssistantMessageChunk {
118 pub fn from_acp(
119 chunk: acp::AssistantMessageChunk,
120 language_registry: Arc<LanguageRegistry>,
121 cx: &mut App,
122 ) -> Self {
123 match chunk {
124 acp::AssistantMessageChunk::Text { text } => Self::Text {
125 chunk: cx.new(|cx| Markdown::new(text.into(), Some(language_registry), None, cx)),
126 },
127 acp::AssistantMessageChunk::Thought { thought } => Self::Thought {
128 chunk: cx
129 .new(|cx| Markdown::new(thought.into(), Some(language_registry), None, cx)),
130 },
131 }
132 }
133
134 pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
135 Self::Text {
136 chunk: cx.new(|cx| {
137 Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
138 }),
139 }
140 }
141
142 fn to_markdown(&self, cx: &App) -> String {
143 match self {
144 Self::Text { chunk } => chunk.read(cx).source().to_string(),
145 Self::Thought { chunk } => {
146 format!("<thinking>\n{}\n</thinking>", chunk.read(cx).source())
147 }
148 }
149 }
150}
151
152#[derive(Debug)]
153pub enum AgentThreadEntry {
154 UserMessage(UserMessage),
155 AssistantMessage(AssistantMessage),
156 ToolCall(ToolCall),
157}
158
159impl AgentThreadEntry {
160 fn to_markdown(&self, cx: &App) -> String {
161 match self {
162 Self::UserMessage(message) => message.to_markdown(cx),
163 Self::AssistantMessage(message) => message.to_markdown(cx),
164 Self::ToolCall(too_call) => too_call.to_markdown(cx),
165 }
166 }
167
168 pub fn diff(&self) -> Option<&Diff> {
169 if let AgentThreadEntry::ToolCall(ToolCall {
170 content: Some(ToolCallContent::Diff { diff }),
171 ..
172 }) = self
173 {
174 Some(&diff)
175 } else {
176 None
177 }
178 }
179}
180
181#[derive(Debug)]
182pub struct ToolCall {
183 pub id: acp::ToolCallId,
184 pub label: Entity<Markdown>,
185 pub icon: IconName,
186 pub content: Option<ToolCallContent>,
187 pub status: ToolCallStatus,
188 pub locations: Vec<acp::ToolCallLocation>,
189}
190
191impl ToolCall {
192 fn to_markdown(&self, cx: &App) -> String {
193 let mut markdown = format!(
194 "**Tool Call: {}**\nStatus: {}\n\n",
195 self.label.read(cx).source(),
196 self.status
197 );
198 if let Some(content) = &self.content {
199 markdown.push_str(content.to_markdown(cx).as_str());
200 markdown.push_str("\n\n");
201 }
202 markdown
203 }
204}
205
206#[derive(Debug)]
207pub enum ToolCallStatus {
208 WaitingForConfirmation {
209 confirmation: ToolCallConfirmation,
210 respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
211 },
212 Allowed {
213 status: acp::ToolCallStatus,
214 },
215 Rejected,
216 Canceled,
217}
218
219impl Display for ToolCallStatus {
220 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
221 write!(
222 f,
223 "{}",
224 match self {
225 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
226 ToolCallStatus::Allowed { status } => match status {
227 acp::ToolCallStatus::Running => "Running",
228 acp::ToolCallStatus::Finished => "Finished",
229 acp::ToolCallStatus::Error => "Error",
230 },
231 ToolCallStatus::Rejected => "Rejected",
232 ToolCallStatus::Canceled => "Canceled",
233 }
234 )
235 }
236}
237
238#[derive(Debug)]
239pub enum ToolCallConfirmation {
240 Edit {
241 description: Option<Entity<Markdown>>,
242 },
243 Execute {
244 command: String,
245 root_command: String,
246 description: Option<Entity<Markdown>>,
247 },
248 Mcp {
249 server_name: String,
250 tool_name: String,
251 tool_display_name: String,
252 description: Option<Entity<Markdown>>,
253 },
254 Fetch {
255 urls: Vec<SharedString>,
256 description: Option<Entity<Markdown>>,
257 },
258 Other {
259 description: Entity<Markdown>,
260 },
261}
262
263impl ToolCallConfirmation {
264 pub fn from_acp(
265 confirmation: acp::ToolCallConfirmation,
266 language_registry: Arc<LanguageRegistry>,
267 cx: &mut App,
268 ) -> Self {
269 let to_md = |description: String, cx: &mut App| -> Entity<Markdown> {
270 cx.new(|cx| {
271 Markdown::new(
272 description.into(),
273 Some(language_registry.clone()),
274 None,
275 cx,
276 )
277 })
278 };
279
280 match confirmation {
281 acp::ToolCallConfirmation::Edit { description } => Self::Edit {
282 description: description.map(|description| to_md(description, cx)),
283 },
284 acp::ToolCallConfirmation::Execute {
285 command,
286 root_command,
287 description,
288 } => Self::Execute {
289 command,
290 root_command,
291 description: description.map(|description| to_md(description, cx)),
292 },
293 acp::ToolCallConfirmation::Mcp {
294 server_name,
295 tool_name,
296 tool_display_name,
297 description,
298 } => Self::Mcp {
299 server_name,
300 tool_name,
301 tool_display_name,
302 description: description.map(|description| to_md(description, cx)),
303 },
304 acp::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch {
305 urls: urls.iter().map(|url| url.into()).collect(),
306 description: description.map(|description| to_md(description, cx)),
307 },
308 acp::ToolCallConfirmation::Other { description } => Self::Other {
309 description: to_md(description, cx),
310 },
311 }
312 }
313}
314
315#[derive(Debug)]
316pub enum ToolCallContent {
317 Markdown { markdown: Entity<Markdown> },
318 Diff { diff: Diff },
319}
320
321impl ToolCallContent {
322 pub fn from_acp(
323 content: acp::ToolCallContent,
324 language_registry: Arc<LanguageRegistry>,
325 cx: &mut App,
326 ) -> Self {
327 match content {
328 acp::ToolCallContent::Markdown { markdown } => Self::Markdown {
329 markdown: cx.new(|cx| Markdown::new_text(markdown.into(), cx)),
330 },
331 acp::ToolCallContent::Diff { diff } => Self::Diff {
332 diff: Diff::from_acp(diff, language_registry, cx),
333 },
334 }
335 }
336
337 fn to_markdown(&self, cx: &App) -> String {
338 match self {
339 Self::Markdown { markdown } => markdown.read(cx).source().to_string(),
340 Self::Diff { diff } => diff.to_markdown(cx),
341 }
342 }
343}
344
345#[derive(Debug)]
346pub struct Diff {
347 pub multibuffer: Entity<MultiBuffer>,
348 pub path: PathBuf,
349 pub new_buffer: Entity<Buffer>,
350 pub old_buffer: Entity<Buffer>,
351 _task: Task<Result<()>>,
352}
353
354impl Diff {
355 pub fn from_acp(
356 diff: acp::Diff,
357 language_registry: Arc<LanguageRegistry>,
358 cx: &mut App,
359 ) -> Self {
360 let acp::Diff {
361 path,
362 old_text,
363 new_text,
364 } = diff;
365
366 let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
367
368 let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
369 let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
370 let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
371 let old_buffer_snapshot = old_buffer.read(cx).snapshot();
372 let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
373 let diff_task = buffer_diff.update(cx, |diff, cx| {
374 diff.set_base_text(
375 old_buffer_snapshot,
376 Some(language_registry.clone()),
377 new_buffer_snapshot,
378 cx,
379 )
380 });
381
382 let task = cx.spawn({
383 let multibuffer = multibuffer.clone();
384 let path = path.clone();
385 let new_buffer = new_buffer.clone();
386 async move |cx| {
387 diff_task.await?;
388
389 multibuffer
390 .update(cx, |multibuffer, cx| {
391 let hunk_ranges = {
392 let buffer = new_buffer.read(cx);
393 let diff = buffer_diff.read(cx);
394 diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
395 .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
396 .collect::<Vec<_>>()
397 };
398
399 multibuffer.set_excerpts_for_path(
400 PathKey::for_buffer(&new_buffer, cx),
401 new_buffer.clone(),
402 hunk_ranges,
403 editor::DEFAULT_MULTIBUFFER_CONTEXT,
404 cx,
405 );
406 multibuffer.add_diff(buffer_diff.clone(), cx);
407 })
408 .log_err();
409
410 if let Some(language) = language_registry
411 .language_for_file_path(&path)
412 .await
413 .log_err()
414 {
415 new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
416 }
417
418 anyhow::Ok(())
419 }
420 });
421
422 Self {
423 multibuffer,
424 path,
425 new_buffer,
426 old_buffer,
427 _task: task,
428 }
429 }
430
431 fn to_markdown(&self, cx: &App) -> String {
432 let buffer_text = self
433 .multibuffer
434 .read(cx)
435 .all_buffers()
436 .iter()
437 .map(|buffer| buffer.read(cx).text())
438 .join("\n");
439 format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
440 }
441}
442
443pub struct AcpThread {
444 entries: Vec<AgentThreadEntry>,
445 title: SharedString,
446 project: Entity<Project>,
447 action_log: Entity<ActionLog>,
448 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
449 send_task: Option<Task<()>>,
450 connection: Arc<acp::AgentConnection>,
451 child_status: Option<Task<Result<()>>>,
452 _io_task: Task<()>,
453}
454
455pub enum AcpThreadEvent {
456 NewEntry,
457 EntryUpdated(usize),
458}
459
460impl EventEmitter<AcpThreadEvent> for AcpThread {}
461
/// Coarse state of an [`AcpThread`], as reported by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No send is in flight.
    Idle,
    /// A send is in flight and a tool call is awaiting user confirmation.
    WaitingForToolConfirmation,
    /// A send is in flight and no confirmation is pending.
    Generating,
}
468
469#[derive(Debug, Clone)]
470pub enum LoadError {
471 Unsupported { current_version: SharedString },
472 Exited(i32),
473 Other(SharedString),
474}
475
476impl Display for LoadError {
477 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
478 match self {
479 LoadError::Unsupported { current_version } => {
480 write!(
481 f,
482 "Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).",
483 current_version
484 )
485 }
486 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
487 LoadError::Other(msg) => write!(f, "{}", msg),
488 }
489 }
490}
491
492impl Error for LoadError {}
493
494impl AcpThread {
495 pub async fn spawn(
496 server: impl AgentServer + 'static,
497 root_dir: &Path,
498 project: Entity<Project>,
499 cx: &mut AsyncApp,
500 ) -> Result<Entity<Self>> {
501 let command = match server.command(&project, cx).await {
502 Ok(command) => command,
503 Err(e) => return Err(anyhow!(LoadError::Other(format!("{e}").into()))),
504 };
505
506 let mut child = util::command::new_smol_command(&command.path)
507 .args(command.args.iter())
508 .current_dir(root_dir)
509 .stdin(std::process::Stdio::piped())
510 .stdout(std::process::Stdio::piped())
511 .stderr(std::process::Stdio::inherit())
512 .kill_on_drop(true)
513 .spawn()?;
514
515 let stdin = child.stdin.take().unwrap();
516 let stdout = child.stdout.take().unwrap();
517
518 cx.new(|cx| {
519 let foreground_executor = cx.foreground_executor().clone();
520
521 let (connection, io_fut) = acp::AgentConnection::connect_to_agent(
522 AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
523 stdin,
524 stdout,
525 move |fut| foreground_executor.spawn(fut).detach(),
526 );
527
528 let io_task = cx.background_spawn(async move {
529 io_fut.await.log_err();
530 });
531
532 let child_status = cx.background_spawn(async move {
533 match child.status().await {
534 Err(e) => Err(anyhow!(e)),
535 Ok(result) if result.success() => Ok(()),
536 Ok(result) => {
537 if let Some(version) = server.version(&command).await.log_err()
538 && !version.supported
539 {
540 Err(anyhow!(LoadError::Unsupported {
541 current_version: version.current_version
542 }))
543 } else {
544 Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
545 }
546 }
547 }
548 });
549
550 let action_log = cx.new(|_| ActionLog::new(project.clone()));
551
552 Self {
553 action_log,
554 shared_buffers: Default::default(),
555 entries: Default::default(),
556 title: "ACP Thread".into(),
557 project,
558 send_task: None,
559 connection: Arc::new(connection),
560 child_status: Some(child_status),
561 _io_task: io_task,
562 }
563 })
564 }
565
566 pub fn action_log(&self) -> &Entity<ActionLog> {
567 &self.action_log
568 }
569
570 pub fn project(&self) -> &Entity<Project> {
571 &self.project
572 }
573
574 #[cfg(test)]
575 pub fn fake(
576 stdin: async_pipe::PipeWriter,
577 stdout: async_pipe::PipeReader,
578 project: Entity<Project>,
579 cx: &mut Context<Self>,
580 ) -> Self {
581 let foreground_executor = cx.foreground_executor().clone();
582
583 let (connection, io_fut) = acp::AgentConnection::connect_to_agent(
584 AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
585 stdin,
586 stdout,
587 move |fut| {
588 foreground_executor.spawn(fut).detach();
589 },
590 );
591
592 let io_task = cx.background_spawn({
593 async move {
594 io_fut.await.log_err();
595 }
596 });
597
598 let action_log = cx.new(|_| ActionLog::new(project.clone()));
599
600 Self {
601 action_log,
602 shared_buffers: Default::default(),
603 entries: Default::default(),
604 title: "ACP Thread".into(),
605 project,
606 send_task: None,
607 connection: Arc::new(connection),
608 child_status: None,
609 _io_task: io_task,
610 }
611 }
612
613 pub fn title(&self) -> SharedString {
614 self.title.clone()
615 }
616
617 pub fn entries(&self) -> &[AgentThreadEntry] {
618 &self.entries
619 }
620
621 pub fn status(&self) -> ThreadStatus {
622 if self.send_task.is_some() {
623 if self.waiting_for_tool_confirmation() {
624 ThreadStatus::WaitingForToolConfirmation
625 } else {
626 ThreadStatus::Generating
627 }
628 } else {
629 ThreadStatus::Idle
630 }
631 }
632
633 pub fn has_pending_edit_tool_calls(&self) -> bool {
634 for entry in self.entries.iter().rev() {
635 match entry {
636 AgentThreadEntry::UserMessage(_) => return false,
637 AgentThreadEntry::ToolCall(ToolCall {
638 status:
639 ToolCallStatus::Allowed {
640 status: acp::ToolCallStatus::Running,
641 ..
642 },
643 content: Some(ToolCallContent::Diff { .. }),
644 ..
645 }) => return true,
646 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
647 }
648 }
649
650 false
651 }
652
653 pub fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
654 self.entries.push(entry);
655 cx.emit(AcpThreadEvent::NewEntry);
656 }
657
658 pub fn push_assistant_chunk(
659 &mut self,
660 chunk: acp::AssistantMessageChunk,
661 cx: &mut Context<Self>,
662 ) {
663 let entries_len = self.entries.len();
664 if let Some(last_entry) = self.entries.last_mut()
665 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
666 {
667 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
668
669 match (chunks.last_mut(), &chunk) {
670 (
671 Some(AssistantMessageChunk::Text { chunk: old_chunk }),
672 acp::AssistantMessageChunk::Text { text: new_chunk },
673 )
674 | (
675 Some(AssistantMessageChunk::Thought { chunk: old_chunk }),
676 acp::AssistantMessageChunk::Thought { thought: new_chunk },
677 ) => {
678 old_chunk.update(cx, |old_chunk, cx| {
679 old_chunk.append(&new_chunk, cx);
680 });
681 }
682 _ => {
683 chunks.push(AssistantMessageChunk::from_acp(
684 chunk,
685 self.project.read(cx).languages().clone(),
686 cx,
687 ));
688 }
689 }
690 } else {
691 let chunk = AssistantMessageChunk::from_acp(
692 chunk,
693 self.project.read(cx).languages().clone(),
694 cx,
695 );
696
697 self.push_entry(
698 AgentThreadEntry::AssistantMessage(AssistantMessage {
699 chunks: vec![chunk],
700 }),
701 cx,
702 );
703 }
704 }
705
706 pub fn request_tool_call(
707 &mut self,
708 tool_call: acp::RequestToolCallConfirmationParams,
709 cx: &mut Context<Self>,
710 ) -> ToolCallRequest {
711 let (tx, rx) = oneshot::channel();
712
713 let status = ToolCallStatus::WaitingForConfirmation {
714 confirmation: ToolCallConfirmation::from_acp(
715 tool_call.confirmation,
716 self.project.read(cx).languages().clone(),
717 cx,
718 ),
719 respond_tx: tx,
720 };
721
722 let id = self.insert_tool_call(tool_call.tool_call, status, cx);
723 ToolCallRequest { id, outcome: rx }
724 }
725
726 pub fn push_tool_call(
727 &mut self,
728 request: acp::PushToolCallParams,
729 cx: &mut Context<Self>,
730 ) -> acp::ToolCallId {
731 let status = ToolCallStatus::Allowed {
732 status: acp::ToolCallStatus::Running,
733 };
734
735 self.insert_tool_call(request, status, cx)
736 }
737
738 fn insert_tool_call(
739 &mut self,
740 tool_call: acp::PushToolCallParams,
741 status: ToolCallStatus,
742 cx: &mut Context<Self>,
743 ) -> acp::ToolCallId {
744 let language_registry = self.project.read(cx).languages().clone();
745 let id = acp::ToolCallId(self.entries.len() as u64);
746 let call = ToolCall {
747 id,
748 label: cx.new(|cx| {
749 Markdown::new(
750 tool_call.label.into(),
751 Some(language_registry.clone()),
752 None,
753 cx,
754 )
755 }),
756 icon: acp_icon_to_ui_icon(tool_call.icon),
757 content: tool_call
758 .content
759 .map(|content| ToolCallContent::from_acp(content, language_registry, cx)),
760 locations: tool_call.locations,
761 status,
762 };
763
764 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
765
766 id
767 }
768
769 pub fn authorize_tool_call(
770 &mut self,
771 id: acp::ToolCallId,
772 outcome: acp::ToolCallConfirmationOutcome,
773 cx: &mut Context<Self>,
774 ) {
775 let Some((ix, call)) = self.tool_call_mut(id) else {
776 return;
777 };
778
779 let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
780 ToolCallStatus::Rejected
781 } else {
782 ToolCallStatus::Allowed {
783 status: acp::ToolCallStatus::Running,
784 }
785 };
786
787 let curr_status = mem::replace(&mut call.status, new_status);
788
789 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
790 respond_tx.send(outcome).log_err();
791 } else if cfg!(debug_assertions) {
792 panic!("tried to authorize an already authorized tool call");
793 }
794
795 cx.emit(AcpThreadEvent::EntryUpdated(ix));
796 }
797
798 pub fn update_tool_call(
799 &mut self,
800 id: acp::ToolCallId,
801 new_status: acp::ToolCallStatus,
802 new_content: Option<acp::ToolCallContent>,
803 cx: &mut Context<Self>,
804 ) -> Result<()> {
805 let language_registry = self.project.read(cx).languages().clone();
806 let (ix, call) = self.tool_call_mut(id).context("Entry not found")?;
807
808 call.content = new_content
809 .map(|new_content| ToolCallContent::from_acp(new_content, language_registry, cx));
810
811 match &mut call.status {
812 ToolCallStatus::Allowed { status } => {
813 *status = new_status;
814 }
815 ToolCallStatus::WaitingForConfirmation { .. } => {
816 anyhow::bail!("Tool call hasn't been authorized yet")
817 }
818 ToolCallStatus::Rejected => {
819 anyhow::bail!("Tool call was rejected and therefore can't be updated")
820 }
821 ToolCallStatus::Canceled => {
822 call.status = ToolCallStatus::Allowed { status: new_status };
823 }
824 }
825
826 cx.emit(AcpThreadEvent::EntryUpdated(ix));
827 Ok(())
828 }
829
830 fn tool_call_mut(&mut self, id: acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
831 let entry = self.entries.get_mut(id.0 as usize);
832 debug_assert!(
833 entry.is_some(),
834 "We shouldn't give out ids to entries that don't exist"
835 );
836 match entry {
837 Some(AgentThreadEntry::ToolCall(call)) if call.id == id => Some((id.0 as usize, call)),
838 _ => {
839 if cfg!(debug_assertions) {
840 panic!("entry is not a tool call");
841 }
842 None
843 }
844 }
845 }
846
847 /// Returns true if the last turn is awaiting tool authorization
848 pub fn waiting_for_tool_confirmation(&self) -> bool {
849 for entry in self.entries.iter().rev() {
850 match &entry {
851 AgentThreadEntry::ToolCall(call) => match call.status {
852 ToolCallStatus::WaitingForConfirmation { .. } => return true,
853 ToolCallStatus::Allowed { .. }
854 | ToolCallStatus::Rejected
855 | ToolCallStatus::Canceled => continue,
856 },
857 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
858 // Reached the beginning of the turn
859 return false;
860 }
861 }
862 }
863 false
864 }
865
866 pub fn initialize(
867 &self,
868 ) -> impl use<> + Future<Output = Result<acp::InitializeResponse, acp::Error>> {
869 let connection = self.connection.clone();
870 async move { connection.request(acp::InitializeParams).await }
871 }
872
873 pub fn authenticate(&self) -> impl use<> + Future<Output = Result<(), acp::Error>> {
874 let connection = self.connection.clone();
875 async move { connection.request(acp::AuthenticateParams).await }
876 }
877
878 #[cfg(test)]
879 pub fn send_raw(
880 &mut self,
881 message: &str,
882 cx: &mut Context<Self>,
883 ) -> BoxFuture<'static, Result<(), acp::Error>> {
884 self.send(
885 acp::SendUserMessageParams {
886 chunks: vec![acp::UserMessageChunk::Text {
887 text: message.to_string(),
888 }],
889 },
890 cx,
891 )
892 }
893
894 pub fn send(
895 &mut self,
896 message: acp::SendUserMessageParams,
897 cx: &mut Context<Self>,
898 ) -> BoxFuture<'static, Result<(), acp::Error>> {
899 let agent = self.connection.clone();
900 self.push_entry(
901 AgentThreadEntry::UserMessage(UserMessage::from_acp(
902 &message,
903 self.project.read(cx).languages().clone(),
904 cx,
905 )),
906 cx,
907 );
908
909 let (tx, rx) = oneshot::channel();
910 let cancel = self.cancel(cx);
911
912 self.send_task = Some(cx.spawn(async move |this, cx| {
913 cancel.await.log_err();
914
915 let result = agent.request(message).await;
916 tx.send(result).log_err();
917 this.update(cx, |this, _cx| this.send_task.take()).log_err();
918 }));
919
920 async move {
921 match rx.await {
922 Ok(Err(e)) => Err(e)?,
923 _ => Ok(()),
924 }
925 }
926 .boxed()
927 }
928
929 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<(), acp::Error>> {
930 let agent = self.connection.clone();
931
932 if self.send_task.take().is_some() {
933 cx.spawn(async move |this, cx| {
934 agent.request(acp::CancelSendMessageParams).await?;
935
936 this.update(cx, |this, _cx| {
937 for entry in this.entries.iter_mut() {
938 if let AgentThreadEntry::ToolCall(call) = entry {
939 let cancel = matches!(
940 call.status,
941 ToolCallStatus::WaitingForConfirmation { .. }
942 | ToolCallStatus::Allowed {
943 status: acp::ToolCallStatus::Running
944 }
945 );
946
947 if cancel {
948 let curr_status =
949 mem::replace(&mut call.status, ToolCallStatus::Canceled);
950
951 if let ToolCallStatus::WaitingForConfirmation {
952 respond_tx, ..
953 } = curr_status
954 {
955 respond_tx
956 .send(acp::ToolCallConfirmationOutcome::Cancel)
957 .ok();
958 }
959 }
960 }
961 }
962 })?;
963 Ok(())
964 })
965 } else {
966 Task::ready(Ok(()))
967 }
968 }
969
970 pub fn read_text_file(
971 &self,
972 request: acp::ReadTextFileParams,
973 cx: &mut Context<Self>,
974 ) -> Task<Result<String>> {
975 let project = self.project.clone();
976 let action_log = self.action_log.clone();
977 cx.spawn(async move |this, cx| {
978 let load = project.update(cx, |project, cx| {
979 let path = project
980 .project_path_for_absolute_path(&request.path, cx)
981 .context("invalid path")?;
982 anyhow::Ok(project.open_buffer(path, cx))
983 });
984 let buffer = load??.await?;
985
986 action_log.update(cx, |action_log, cx| {
987 action_log.buffer_read(buffer.clone(), cx);
988 })?;
989 project.update(cx, |project, cx| {
990 let position = buffer
991 .read(cx)
992 .snapshot()
993 .anchor_before(Point::new(request.line.unwrap_or_default(), 0));
994 project.set_agent_location(
995 Some(AgentLocation {
996 buffer: buffer.downgrade(),
997 position,
998 }),
999 cx,
1000 );
1001 })?;
1002 let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
1003 this.update(cx, |this, _| {
1004 let text = snapshot.text();
1005 this.shared_buffers.insert(buffer.clone(), snapshot);
1006 text
1007 })
1008 })
1009 }
1010
1011 pub fn write_text_file(
1012 &self,
1013 path: PathBuf,
1014 content: String,
1015 cx: &mut Context<Self>,
1016 ) -> Task<Result<()>> {
1017 let project = self.project.clone();
1018 let action_log = self.action_log.clone();
1019 cx.spawn(async move |this, cx| {
1020 let load = project.update(cx, |project, cx| {
1021 let path = project
1022 .project_path_for_absolute_path(&path, cx)
1023 .context("invalid path")?;
1024 anyhow::Ok(project.open_buffer(path, cx))
1025 });
1026 let buffer = load??.await?;
1027 let snapshot = this.update(cx, |this, cx| {
1028 this.shared_buffers
1029 .get(&buffer)
1030 .cloned()
1031 .unwrap_or_else(|| buffer.read(cx).snapshot())
1032 })?;
1033 let edits = cx
1034 .background_executor()
1035 .spawn(async move {
1036 let old_text = snapshot.text();
1037 text_diff(old_text.as_str(), &content)
1038 .into_iter()
1039 .map(|(range, replacement)| {
1040 (
1041 snapshot.anchor_after(range.start)
1042 ..snapshot.anchor_before(range.end),
1043 replacement,
1044 )
1045 })
1046 .collect::<Vec<_>>()
1047 })
1048 .await;
1049 cx.update(|cx| {
1050 project.update(cx, |project, cx| {
1051 project.set_agent_location(
1052 Some(AgentLocation {
1053 buffer: buffer.downgrade(),
1054 position: edits
1055 .last()
1056 .map(|(range, _)| range.end)
1057 .unwrap_or(Anchor::MIN),
1058 }),
1059 cx,
1060 );
1061 });
1062
1063 action_log.update(cx, |action_log, cx| {
1064 action_log.buffer_read(buffer.clone(), cx);
1065 });
1066 buffer.update(cx, |buffer, cx| {
1067 buffer.edit(edits, None, cx);
1068 });
1069 action_log.update(cx, |action_log, cx| {
1070 action_log.buffer_edited(buffer.clone(), cx);
1071 });
1072 })?;
1073 project
1074 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1075 .await
1076 })
1077 }
1078
1079 pub fn child_status(&mut self) -> Option<Task<Result<()>>> {
1080 self.child_status.take()
1081 }
1082
1083 pub fn to_markdown(&self, cx: &App) -> String {
1084 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1085 }
1086}
1087
1088struct AcpClientDelegate {
1089 thread: WeakEntity<AcpThread>,
1090 cx: AsyncApp,
1091 // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
1092}
1093
1094impl AcpClientDelegate {
1095 fn new(thread: WeakEntity<AcpThread>, cx: AsyncApp) -> Self {
1096 Self { thread, cx }
1097 }
1098}
1099
1100impl acp::Client for AcpClientDelegate {
1101 async fn stream_assistant_message_chunk(
1102 &self,
1103 params: acp::StreamAssistantMessageChunkParams,
1104 ) -> Result<(), acp::Error> {
1105 let cx = &mut self.cx.clone();
1106
1107 cx.update(|cx| {
1108 self.thread
1109 .update(cx, |thread, cx| {
1110 thread.push_assistant_chunk(params.chunk, cx)
1111 })
1112 .ok();
1113 })?;
1114
1115 Ok(())
1116 }
1117
1118 async fn request_tool_call_confirmation(
1119 &self,
1120 request: acp::RequestToolCallConfirmationParams,
1121 ) -> Result<acp::RequestToolCallConfirmationResponse, acp::Error> {
1122 let cx = &mut self.cx.clone();
1123 let ToolCallRequest { id, outcome } = cx
1124 .update(|cx| {
1125 self.thread
1126 .update(cx, |thread, cx| thread.request_tool_call(request, cx))
1127 })?
1128 .context("Failed to update thread")?;
1129
1130 Ok(acp::RequestToolCallConfirmationResponse {
1131 id,
1132 outcome: outcome.await.map_err(acp::Error::into_internal_error)?,
1133 })
1134 }
1135
1136 async fn push_tool_call(
1137 &self,
1138 request: acp::PushToolCallParams,
1139 ) -> Result<acp::PushToolCallResponse, acp::Error> {
1140 let cx = &mut self.cx.clone();
1141 let id = cx
1142 .update(|cx| {
1143 self.thread
1144 .update(cx, |thread, cx| thread.push_tool_call(request, cx))
1145 })?
1146 .context("Failed to update thread")?;
1147
1148 Ok(acp::PushToolCallResponse { id })
1149 }
1150
1151 async fn update_tool_call(&self, request: acp::UpdateToolCallParams) -> Result<(), acp::Error> {
1152 let cx = &mut self.cx.clone();
1153
1154 cx.update(|cx| {
1155 self.thread.update(cx, |thread, cx| {
1156 thread.update_tool_call(request.tool_call_id, request.status, request.content, cx)
1157 })
1158 })?
1159 .context("Failed to update thread")??;
1160
1161 Ok(())
1162 }
1163
1164 async fn read_text_file(
1165 &self,
1166 request: acp::ReadTextFileParams,
1167 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
1168 let content = self
1169 .cx
1170 .update(|cx| {
1171 self.thread
1172 .update(cx, |thread, cx| thread.read_text_file(request, cx))
1173 })?
1174 .context("Failed to update thread")?
1175 .await?;
1176 Ok(acp::ReadTextFileResponse { content })
1177 }
1178
1179 async fn write_text_file(&self, request: acp::WriteTextFileParams) -> Result<(), acp::Error> {
1180 self.cx
1181 .update(|cx| {
1182 self.thread.update(cx, |thread, cx| {
1183 thread.write_text_file(request.path, request.content, cx)
1184 })
1185 })?
1186 .context("Failed to update thread")?
1187 .await?;
1188
1189 Ok(())
1190 }
1191}
1192
1193fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
1194 match icon {
1195 acp::Icon::FileSearch => IconName::ToolSearch,
1196 acp::Icon::Folder => IconName::ToolFolder,
1197 acp::Icon::Globe => IconName::ToolWeb,
1198 acp::Icon::Hammer => IconName::ToolHammer,
1199 acp::Icon::LightBulb => IconName::ToolBulb,
1200 acp::Icon::Pencil => IconName::ToolPencil,
1201 acp::Icon::Regex => IconName::ToolRegex,
1202 acp::Icon::Terminal => IconName::ToolTerminal,
1203 }
1204}
1205
1206pub struct ToolCallRequest {
1207 pub id: acp::ToolCallId,
1208 pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
1209}
1210
1211#[cfg(test)]
1212mod tests {
1213 use super::*;
1214 use agent_servers::{AgentServerCommand, AgentServerVersion};
1215 use async_pipe::{PipeReader, PipeWriter};
1216 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1217 use gpui::{AsyncApp, TestAppContext};
1218 use indoc::indoc;
1219 use project::FakeFs;
1220 use serde_json::json;
1221 use settings::SettingsStore;
1222 use smol::{future::BoxedLocal, stream::StreamExt as _};
1223 use std::{cell::RefCell, env, path::Path, rc::Rc, time::Duration};
1224 use util::path;
1225
1226 fn init_test(cx: &mut TestAppContext) {
1227 env_logger::try_init().ok();
1228 cx.update(|cx| {
1229 let settings_store = SettingsStore::test(cx);
1230 cx.set_global(settings_store);
1231 Project::init_settings(cx);
1232 language::init(cx);
1233 });
1234 }
1235
1236 #[gpui::test]
1237 async fn test_thinking_concatenation(cx: &mut TestAppContext) {
1238 init_test(cx);
1239
1240 let fs = FakeFs::new(cx.executor());
1241 let project = Project::test(fs, [], cx).await;
1242 let (thread, fake_server) = fake_acp_thread(project, cx);
1243
1244 fake_server.update(cx, |fake_server, _| {
1245 fake_server.on_user_message(move |_, server, mut cx| async move {
1246 server
1247 .update(&mut cx, |server, _| {
1248 server.send_to_zed(acp::StreamAssistantMessageChunkParams {
1249 chunk: acp::AssistantMessageChunk::Thought {
1250 thought: "Thinking ".into(),
1251 },
1252 })
1253 })?
1254 .await
1255 .unwrap();
1256 server
1257 .update(&mut cx, |server, _| {
1258 server.send_to_zed(acp::StreamAssistantMessageChunkParams {
1259 chunk: acp::AssistantMessageChunk::Thought {
1260 thought: "hard!".into(),
1261 },
1262 })
1263 })?
1264 .await
1265 .unwrap();
1266
1267 Ok(())
1268 })
1269 });
1270
1271 thread
1272 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1273 .await
1274 .unwrap();
1275
1276 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1277 assert_eq!(
1278 output,
1279 indoc! {r#"
1280 ## User
1281
1282 Hello from Zed!
1283
1284 ## Assistant
1285
1286 <thinking>
1287 Thinking hard!
1288 </thinking>
1289
1290 "#}
1291 );
1292 }
1293
    // Verifies that an agent write is combined with a user edit made after the
    // agent read the file.
    //
    // The fake agent reads `/tmp/foo`, signals the test, and then writes a version
    // with two extra lines appended. Between the read and the write, the test
    // inserts "zero\n" at the start of the open buffer. Both the buffer and the
    // file on the fake filesystem are expected to contain the user's line plus the
    // agent's lines.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project.clone(), cx);
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        // Open the file as a buffer so the test can edit it like a user would.
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        // Lets the agent handler tell the test that its read has completed. The
        // sender lives in a shared cell because the handler is an `Fn` closure but
        // a oneshot sender can only be used once.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));

        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    // Step 1: the agent reads the file's current content.
                    let content = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp::ReadTextFileParams {
                                path: path!("/tmp/foo").into(),
                                line: None,
                                limit: None,
                            })
                        })?
                        .await
                        .unwrap();
                    assert_eq!(content.content, "one\ntwo\nthree\n");
                    // Step 2: take the sender out of the cell and notify the test
                    // that the read is done.
                    read_file_tx.take().unwrap().send(()).unwrap();
                    // Step 3: the agent writes back what it read plus two lines.
                    server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp::WriteTextFileParams {
                                path: path!("/tmp/foo").into(),
                                content: "one\ntwo\nthree\nfour\nfive\n".to_string(),
                            })
                        })?
                        .await
                        .unwrap();
                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait for the agent's read before making the user edit.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        // Let all pending work (including the agent's write) run to completion.
        cx.run_until_parked();
        // The user's "zero" line and the agent's "four"/"five" lines must both be
        // present in the buffer...
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        // ...and in the file on the fake filesystem.
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1367
1368 #[gpui::test]
1369 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
1370 init_test(cx);
1371
1372 let fs = FakeFs::new(cx.executor());
1373 let project = Project::test(fs, [], cx).await;
1374 let (thread, fake_server) = fake_acp_thread(project, cx);
1375
1376 let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>();
1377
1378 let tool_call_id = Rc::new(RefCell::new(None));
1379 let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx)));
1380 fake_server.update(cx, |fake_server, _| {
1381 let tool_call_id = tool_call_id.clone();
1382 fake_server.on_user_message(move |_, server, mut cx| {
1383 let end_turn_rx = end_turn_rx.clone();
1384 let tool_call_id = tool_call_id.clone();
1385 async move {
1386 let tool_call_result = server
1387 .update(&mut cx, |server, _| {
1388 server.send_to_zed(acp::PushToolCallParams {
1389 label: "Fetch".to_string(),
1390 icon: acp::Icon::Globe,
1391 content: None,
1392 locations: vec![],
1393 })
1394 })?
1395 .await
1396 .unwrap();
1397 *tool_call_id.clone().borrow_mut() = Some(tool_call_result.id);
1398 end_turn_rx.take().unwrap().await.ok();
1399
1400 Ok(())
1401 }
1402 })
1403 });
1404
1405 let request = thread.update(cx, |thread, cx| {
1406 thread.send_raw("Fetch https://example.com", cx)
1407 });
1408
1409 run_until_first_tool_call(&thread, cx).await;
1410
1411 thread.read_with(cx, |thread, _| {
1412 assert!(matches!(
1413 thread.entries[1],
1414 AgentThreadEntry::ToolCall(ToolCall {
1415 status: ToolCallStatus::Allowed {
1416 status: acp::ToolCallStatus::Running,
1417 ..
1418 },
1419 ..
1420 })
1421 ));
1422 });
1423
1424 cx.run_until_parked();
1425
1426 thread
1427 .update(cx, |thread, cx| thread.cancel(cx))
1428 .await
1429 .unwrap();
1430
1431 thread.read_with(cx, |thread, _| {
1432 assert!(matches!(
1433 &thread.entries[1],
1434 AgentThreadEntry::ToolCall(ToolCall {
1435 status: ToolCallStatus::Canceled,
1436 ..
1437 })
1438 ));
1439 });
1440
1441 fake_server
1442 .update(cx, |fake_server, _| {
1443 fake_server.send_to_zed(acp::UpdateToolCallParams {
1444 tool_call_id: tool_call_id.borrow().unwrap(),
1445 status: acp::ToolCallStatus::Finished,
1446 content: None,
1447 })
1448 })
1449 .await
1450 .unwrap();
1451
1452 drop(end_turn_tx);
1453 request.await.unwrap();
1454
1455 thread.read_with(cx, |thread, _| {
1456 assert!(matches!(
1457 thread.entries[1],
1458 AgentThreadEntry::ToolCall(ToolCall {
1459 status: ToolCallStatus::Allowed {
1460 status: acp::ToolCallStatus::Finished,
1461 ..
1462 },
1463 ..
1464 })
1465 ));
1466 });
1467 }
1468
1469 #[gpui::test]
1470 #[cfg_attr(not(feature = "gemini"), ignore)]
1471 async fn test_gemini_basic(cx: &mut TestAppContext) {
1472 init_test(cx);
1473
1474 cx.executor().allow_parking();
1475
1476 let fs = FakeFs::new(cx.executor());
1477 let project = Project::test(fs, [], cx).await;
1478 let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
1479 thread
1480 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1481 .await
1482 .unwrap();
1483
1484 thread.read_with(cx, |thread, _| {
1485 assert_eq!(thread.entries.len(), 2);
1486 assert!(matches!(
1487 thread.entries[0],
1488 AgentThreadEntry::UserMessage(_)
1489 ));
1490 assert!(matches!(
1491 thread.entries[1],
1492 AgentThreadEntry::AssistantMessage(_)
1493 ));
1494 });
1495 }
1496
    // Checks, against a real Gemini CLI, that a path mention inside a user message
    // results in a tool call followed by an assistant message that quotes the
    // mentioned file's content.
    //
    // Ignored unless the `gemini` feature is enabled.
    #[gpui::test]
    #[cfg_attr(not(feature = "gemini"), ignore)]
    async fn test_gemini_path_mentions(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();
        // Use a real temporary directory containing the file that will be mentioned.
        let tempdir = tempfile::tempdir().unwrap();
        std::fs::write(
            tempdir.path().join("foo.rs"),
            indoc! {"
                fn main() {
                    println!(\"Hello, world!\");
                }
            "},
        )
        .expect("failed to write file");
        let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
        let thread = gemini_acp_thread(project.clone(), tempdir.path(), cx).await;
        // The message is built from chunks so the file is sent as a `Path` chunk
        // rather than as plain text.
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    acp::SendUserMessageParams {
                        chunks: vec![
                            acp::UserMessageChunk::Text {
                                text: "Read the file ".into(),
                            },
                            acp::UserMessageChunk::Path {
                                path: Path::new("foo.rs").into(),
                            },
                            acp::UserMessageChunk::Text {
                                text: " and tell me what the content of the println! is".into(),
                            },
                        ],
                    },
                    cx,
                )
            })
            .await
            .unwrap();

        // Expected entries: user message, tool call, assistant message.
        thread.read_with(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            assert!(matches!(
                thread.entries[0],
                AgentThreadEntry::UserMessage(_)
            ));
            assert!(matches!(thread.entries[1], AgentThreadEntry::ToolCall(_)));
            let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries[2] else {
                panic!("Expected AssistantMessage")
            };
            // The answer must contain the string literal from the file.
            assert!(
                assistant_message.to_markdown(cx).contains("Hello, world!"),
                "unexpected assistant message: {:?}",
                assistant_message.to_markdown(cx)
            );
        });
    }
1554
1555 #[gpui::test]
1556 #[cfg_attr(not(feature = "gemini"), ignore)]
1557 async fn test_gemini_tool_call(cx: &mut TestAppContext) {
1558 init_test(cx);
1559
1560 cx.executor().allow_parking();
1561
1562 let fs = FakeFs::new(cx.executor());
1563 fs.insert_tree(
1564 path!("/private/tmp"),
1565 json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
1566 )
1567 .await;
1568 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
1569 let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
1570 thread
1571 .update(cx, |thread, cx| {
1572 thread.send_raw(
1573 "Read the '/private/tmp/foo' file and tell me what you see.",
1574 cx,
1575 )
1576 })
1577 .await
1578 .unwrap();
1579 thread.read_with(cx, |thread, _cx| {
1580 assert!(matches!(
1581 &thread.entries()[2],
1582 AgentThreadEntry::ToolCall(ToolCall {
1583 status: ToolCallStatus::Allowed { .. },
1584 ..
1585 })
1586 ));
1587
1588 assert!(matches!(
1589 thread.entries[3],
1590 AgentThreadEntry::AssistantMessage(_)
1591 ));
1592 });
1593 }
1594
    // Checks, against a real Gemini CLI, the confirmation flow for a command
    // execution tool call: the call first waits for confirmation, becomes allowed
    // once the test authorizes it, and finally carries markdown content that
    // contains the command's output.
    //
    // Ignored unless the `gemini` feature is enabled.
    #[gpui::test]
    #[cfg_attr(not(feature = "gemini"), ignore)]
    async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
        // Start the turn but do not await it yet: it cannot finish until the tool
        // call below has been authorized.
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // The tool call is expected at entry 2, waiting for an `Execute`
        // confirmation whose root command is `echo`.
        let tool_call_id = thread.read_with(cx, |thread, _cx| {
            let AgentThreadEntry::ToolCall(ToolCall {
                id,
                status:
                    ToolCallStatus::WaitingForConfirmation {
                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
                        ..
                    },
                ..
            }) = &thread.entries()[2]
            else {
                panic!();
            };

            assert_eq!(root_command, "echo");

            *id
        });

        // Authorizing must flip the entry to `Allowed` immediately.
        thread.update(cx, |thread, cx| {
            thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);

            assert!(matches!(
                &thread.entries()[2],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));
        });

        full_turn.await.unwrap();

        // After the turn, the tool call's markdown content should include the
        // text that was echoed.
        thread.read_with(cx, |thread, cx| {
            let AgentThreadEntry::ToolCall(ToolCall {
                content: Some(ToolCallContent::Markdown { markdown }),
                status: ToolCallStatus::Allowed { .. },
                ..
            }) = &thread.entries()[2]
            else {
                panic!();
            };

            markdown.read_with(cx, |md, _cx| {
                assert!(
                    md.source().contains("Hello, world!"),
                    r#"Expected '{}' to contain "Hello, world!""#,
                    md.source()
                );
            });
        });
    }
1663
1664 #[gpui::test]
1665 #[cfg_attr(not(feature = "gemini"), ignore)]
1666 async fn test_gemini_cancel(cx: &mut TestAppContext) {
1667 init_test(cx);
1668
1669 cx.executor().allow_parking();
1670
1671 let fs = FakeFs::new(cx.executor());
1672 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
1673 let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
1674 let full_turn = thread.update(cx, |thread, cx| {
1675 thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
1676 });
1677
1678 let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;
1679
1680 thread.read_with(cx, |thread, _cx| {
1681 let AgentThreadEntry::ToolCall(ToolCall {
1682 id,
1683 status:
1684 ToolCallStatus::WaitingForConfirmation {
1685 confirmation: ToolCallConfirmation::Execute { root_command, .. },
1686 ..
1687 },
1688 ..
1689 }) = &thread.entries()[first_tool_call_ix]
1690 else {
1691 panic!("{:?}", thread.entries()[1]);
1692 };
1693
1694 assert_eq!(root_command, "echo");
1695
1696 *id
1697 });
1698
1699 thread
1700 .update(cx, |thread, cx| thread.cancel(cx))
1701 .await
1702 .unwrap();
1703 full_turn.await.unwrap();
1704 thread.read_with(cx, |thread, _| {
1705 let AgentThreadEntry::ToolCall(ToolCall {
1706 status: ToolCallStatus::Canceled,
1707 ..
1708 }) = &thread.entries()[first_tool_call_ix]
1709 else {
1710 panic!();
1711 };
1712 });
1713
1714 thread
1715 .update(cx, |thread, cx| {
1716 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
1717 })
1718 .await
1719 .unwrap();
1720 thread.read_with(cx, |thread, _| {
1721 assert!(matches!(
1722 &thread.entries().last().unwrap(),
1723 AgentThreadEntry::AssistantMessage(..),
1724 ))
1725 });
1726 }
1727
1728 async fn run_until_first_tool_call(
1729 thread: &Entity<AcpThread>,
1730 cx: &mut TestAppContext,
1731 ) -> usize {
1732 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1733
1734 let subscription = cx.update(|cx| {
1735 cx.subscribe(thread, move |thread, _, cx| {
1736 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1737 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1738 return tx.try_send(ix).unwrap();
1739 }
1740 }
1741 })
1742 });
1743
1744 select! {
1745 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1746 panic!("Timeout waiting for tool call")
1747 }
1748 ix = rx.next().fuse() => {
1749 drop(subscription);
1750 ix.unwrap()
1751 }
1752 }
1753 }
1754
1755 pub async fn gemini_acp_thread(
1756 project: Entity<Project>,
1757 current_dir: impl AsRef<Path>,
1758 cx: &mut TestAppContext,
1759 ) -> Entity<AcpThread> {
1760 struct DevGemini;
1761
1762 impl agent_servers::AgentServer for DevGemini {
1763 async fn command(
1764 &self,
1765 _project: &Entity<Project>,
1766 _cx: &mut AsyncApp,
1767 ) -> Result<agent_servers::AgentServerCommand> {
1768 let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
1769 .join("../../../gemini-cli/packages/cli")
1770 .to_string_lossy()
1771 .to_string();
1772
1773 Ok(AgentServerCommand {
1774 path: "node".into(),
1775 args: vec![cli_path, "--acp".into()],
1776 env: None,
1777 })
1778 }
1779
1780 async fn version(
1781 &self,
1782 _command: &agent_servers::AgentServerCommand,
1783 ) -> Result<AgentServerVersion> {
1784 Ok(AgentServerVersion {
1785 current_version: "0.1.0".into(),
1786 supported: true,
1787 })
1788 }
1789 }
1790
1791 let thread = AcpThread::spawn(DevGemini, current_dir.as_ref(), project, &mut cx.to_async())
1792 .await
1793 .unwrap();
1794
1795 thread
1796 .update(cx, |thread, _| thread.initialize())
1797 .await
1798 .unwrap();
1799 thread
1800 }
1801
1802 pub fn fake_acp_thread(
1803 project: Entity<Project>,
1804 cx: &mut TestAppContext,
1805 ) -> (Entity<AcpThread>, Entity<FakeAcpServer>) {
1806 let (stdin_tx, stdin_rx) = async_pipe::pipe();
1807 let (stdout_tx, stdout_rx) = async_pipe::pipe();
1808 let thread = cx.update(|cx| cx.new(|cx| AcpThread::fake(stdin_tx, stdout_rx, project, cx)));
1809 let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
1810 (thread, agent)
1811 }
1812
    /// In-process stand-in for an agent server, used by the tests in this module.
    pub struct FakeAcpServer {
        /// Connection used to send requests to the client side (see `send_to_zed`).
        connection: acp::ClientConnection,
        /// Holds the task that drives the connection's I/O future; never read.
        _io_task: Task<()>,
        /// Optional handler invoked when a user message arrives. It receives the
        /// message, a handle to this server, and an async app context, and returns
        /// a boxed local future with the result of handling the message.
        on_user_message: Option<
            Rc<
                dyn Fn(
                    acp::SendUserMessageParams,
                    Entity<FakeAcpServer>,
                    AsyncApp,
                ) -> LocalBoxFuture<'static, Result<(), acp::Error>>,
            >,
        >,
    }
1826
    /// `acp::Agent` implementation that forwards user messages to the handler
    /// registered on a `FakeAcpServer`.
    #[derive(Clone)]
    struct FakeAgent {
        /// The fake server whose `on_user_message` handler is looked up.
        server: Entity<FakeAcpServer>,
        /// Async app context used to access `server` and passed on to the handler.
        cx: AsyncApp,
    }
1832
1833 impl acp::Agent for FakeAgent {
1834 async fn initialize(&self) -> Result<acp::InitializeResponse, acp::Error> {
1835 Ok(acp::InitializeResponse {
1836 is_authenticated: true,
1837 })
1838 }
1839
1840 async fn authenticate(&self) -> Result<(), acp::Error> {
1841 Ok(())
1842 }
1843
1844 async fn cancel_send_message(&self) -> Result<(), acp::Error> {
1845 Ok(())
1846 }
1847
1848 async fn send_user_message(
1849 &self,
1850 request: acp::SendUserMessageParams,
1851 ) -> Result<(), acp::Error> {
1852 let mut cx = self.cx.clone();
1853 let handler = self
1854 .server
1855 .update(&mut cx, |server, _| server.on_user_message.clone())
1856 .ok()
1857 .flatten();
1858 if let Some(handler) = handler {
1859 handler(request, self.server.clone(), self.cx.clone()).await
1860 } else {
1861 Err(anyhow::anyhow!("No handler for on_user_message").into())
1862 }
1863 }
1864 }
1865
1866 impl FakeAcpServer {
1867 fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
1868 let agent = FakeAgent {
1869 server: cx.entity(),
1870 cx: cx.to_async(),
1871 };
1872 let foreground_executor = cx.foreground_executor().clone();
1873
1874 let (connection, io_fut) = acp::ClientConnection::connect_to_client(
1875 agent.clone(),
1876 stdout,
1877 stdin,
1878 move |fut| {
1879 foreground_executor.spawn(fut).detach();
1880 },
1881 );
1882 FakeAcpServer {
1883 connection: connection,
1884 on_user_message: None,
1885 _io_task: cx.background_spawn(async move {
1886 io_fut.await.log_err();
1887 }),
1888 }
1889 }
1890
1891 fn on_user_message<F>(
1892 &mut self,
1893 handler: impl for<'a> Fn(acp::SendUserMessageParams, Entity<FakeAcpServer>, AsyncApp) -> F
1894 + 'static,
1895 ) where
1896 F: Future<Output = Result<(), acp::Error>> + 'static,
1897 {
1898 self.on_user_message
1899 .replace(Rc::new(move |request, server, cx| {
1900 handler(request, server, cx).boxed_local()
1901 }));
1902 }
1903
1904 fn send_to_zed<T: acp::ClientRequest + 'static>(
1905 &self,
1906 message: T,
1907 ) -> BoxedLocal<Result<T::Response>> {
1908 self.connection
1909 .request(message)
1910 .map(|f| f.map_err(|err| anyhow!(err)))
1911 .boxed_local()
1912 }
1913 }
1914}