1pub use acp::ToolCallId;
2use agent_servers::AgentServer;
3use agentic_coding_protocol::{self as acp, ToolCallLocation, UserMessageChunk};
4use anyhow::{Context as _, Result, anyhow};
5use assistant_tool::ActionLog;
6use buffer_diff::BufferDiff;
7use editor::{Bias, MultiBuffer, PathKey};
8use futures::{FutureExt, channel::oneshot, future::BoxFuture};
9use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
10use itertools::Itertools;
11use language::{
12 Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
13 text_diff,
14};
15use markdown::Markdown;
16use project::{AgentLocation, Project};
17use std::collections::HashMap;
18use std::error::Error;
19use std::fmt::{Formatter, Write};
20use std::{
21 fmt::Display,
22 mem,
23 path::{Path, PathBuf},
24 sync::Arc,
25};
26use ui::{App, IconName};
27use util::ResultExt;
28
29#[derive(Clone, Debug, Eq, PartialEq)]
30pub struct UserMessage {
31 pub content: Entity<Markdown>,
32}
33
34impl UserMessage {
35 pub fn from_acp(
36 message: &acp::SendUserMessageParams,
37 language_registry: Arc<LanguageRegistry>,
38 cx: &mut App,
39 ) -> Self {
40 let mut md_source = String::new();
41
42 for chunk in &message.chunks {
43 match chunk {
44 UserMessageChunk::Text { text } => md_source.push_str(&text),
45 UserMessageChunk::Path { path } => {
46 write!(&mut md_source, "{}", MentionPath(&path)).unwrap()
47 }
48 }
49 }
50
51 Self {
52 content: cx
53 .new(|cx| Markdown::new(md_source.into(), Some(language_registry), None, cx)),
54 }
55 }
56
57 fn to_markdown(&self, cx: &App) -> String {
58 format!("## User\n\n{}\n\n", self.content.read(cx).source())
59 }
60}
61
/// A file mention embedded in user-message markdown.
///
/// Its `Display` impl renders a markdown link whose label is `@<file name>`
/// and whose destination is the path prefixed with `MentionPath::PREFIX`.
/// [`MentionPath::try_parse`] recovers the path from such a destination.
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// Prefix that marks a link destination as a file mention.
    const PREFIX: &'static str = "@file:";

    /// Wraps a borrowed path as a mention.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a link destination of the form `@file:<path>`.
    ///
    /// Returns `None` when `url` does not start with the mention prefix.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self(Path::new(rest)))
    }

    /// The mentioned path.
    pub fn path(&self) -> &Path {
        self.0
    }
}

impl Display for MentionPath<'_> {
    /// Writes `[@<file name>](@file:<full path>)`. When the path has no
    /// final component, the file name is empty.
    ///
    /// NOTE(review): the path is written into the link destination unescaped,
    /// so paths containing spaces or parentheses may not parse back as a single
    /// markdown link. Confirm whether callers need escaping.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let file_name = self.0.file_name().unwrap_or_default();
        write!(
            f,
            "[@{}]({}{})",
            file_name.display(),
            Self::PREFIX,
            self.0.display()
        )
    }
}
93
94#[derive(Clone, Debug, Eq, PartialEq)]
95pub struct AssistantMessage {
96 pub chunks: Vec<AssistantMessageChunk>,
97}
98
99impl AssistantMessage {
100 fn to_markdown(&self, cx: &App) -> String {
101 format!(
102 "## Assistant\n\n{}\n\n",
103 self.chunks
104 .iter()
105 .map(|chunk| chunk.to_markdown(cx))
106 .join("\n\n")
107 )
108 }
109}
110
111#[derive(Clone, Debug, Eq, PartialEq)]
112pub enum AssistantMessageChunk {
113 Text { chunk: Entity<Markdown> },
114 Thought { chunk: Entity<Markdown> },
115}
116
117impl AssistantMessageChunk {
118 pub fn from_acp(
119 chunk: acp::AssistantMessageChunk,
120 language_registry: Arc<LanguageRegistry>,
121 cx: &mut App,
122 ) -> Self {
123 match chunk {
124 acp::AssistantMessageChunk::Text { text } => Self::Text {
125 chunk: cx.new(|cx| Markdown::new(text.into(), Some(language_registry), None, cx)),
126 },
127 acp::AssistantMessageChunk::Thought { thought } => Self::Thought {
128 chunk: cx
129 .new(|cx| Markdown::new(thought.into(), Some(language_registry), None, cx)),
130 },
131 }
132 }
133
134 pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
135 Self::Text {
136 chunk: cx.new(|cx| {
137 Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
138 }),
139 }
140 }
141
142 fn to_markdown(&self, cx: &App) -> String {
143 match self {
144 Self::Text { chunk } => chunk.read(cx).source().to_string(),
145 Self::Thought { chunk } => {
146 format!("<thinking>\n{}\n</thinking>", chunk.read(cx).source())
147 }
148 }
149 }
150}
151
152#[derive(Debug)]
153pub enum AgentThreadEntry {
154 UserMessage(UserMessage),
155 AssistantMessage(AssistantMessage),
156 ToolCall(ToolCall),
157}
158
159impl AgentThreadEntry {
160 fn to_markdown(&self, cx: &App) -> String {
161 match self {
162 Self::UserMessage(message) => message.to_markdown(cx),
163 Self::AssistantMessage(message) => message.to_markdown(cx),
164 Self::ToolCall(too_call) => too_call.to_markdown(cx),
165 }
166 }
167
168 pub fn diff(&self) -> Option<&Diff> {
169 if let AgentThreadEntry::ToolCall(ToolCall {
170 content: Some(ToolCallContent::Diff { diff }),
171 ..
172 }) = self
173 {
174 Some(&diff)
175 } else {
176 None
177 }
178 }
179
180 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
181 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
182 Some(locations)
183 } else {
184 None
185 }
186 }
187}
188
189#[derive(Debug)]
190pub struct ToolCall {
191 pub id: acp::ToolCallId,
192 pub label: Entity<Markdown>,
193 pub icon: IconName,
194 pub content: Option<ToolCallContent>,
195 pub status: ToolCallStatus,
196 pub locations: Vec<acp::ToolCallLocation>,
197}
198
199impl ToolCall {
200 fn to_markdown(&self, cx: &App) -> String {
201 let mut markdown = format!(
202 "**Tool Call: {}**\nStatus: {}\n\n",
203 self.label.read(cx).source(),
204 self.status
205 );
206 if let Some(content) = &self.content {
207 markdown.push_str(content.to_markdown(cx).as_str());
208 markdown.push_str("\n\n");
209 }
210 markdown
211 }
212}
213
214#[derive(Debug)]
215pub enum ToolCallStatus {
216 WaitingForConfirmation {
217 confirmation: ToolCallConfirmation,
218 respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
219 },
220 Allowed {
221 status: acp::ToolCallStatus,
222 },
223 Rejected,
224 Canceled,
225}
226
227impl Display for ToolCallStatus {
228 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
229 write!(
230 f,
231 "{}",
232 match self {
233 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
234 ToolCallStatus::Allowed { status } => match status {
235 acp::ToolCallStatus::Running => "Running",
236 acp::ToolCallStatus::Finished => "Finished",
237 acp::ToolCallStatus::Error => "Error",
238 },
239 ToolCallStatus::Rejected => "Rejected",
240 ToolCallStatus::Canceled => "Canceled",
241 }
242 )
243 }
244}
245
246#[derive(Debug)]
247pub enum ToolCallConfirmation {
248 Edit {
249 description: Option<Entity<Markdown>>,
250 },
251 Execute {
252 command: String,
253 root_command: String,
254 description: Option<Entity<Markdown>>,
255 },
256 Mcp {
257 server_name: String,
258 tool_name: String,
259 tool_display_name: String,
260 description: Option<Entity<Markdown>>,
261 },
262 Fetch {
263 urls: Vec<SharedString>,
264 description: Option<Entity<Markdown>>,
265 },
266 Other {
267 description: Entity<Markdown>,
268 },
269}
270
271impl ToolCallConfirmation {
272 pub fn from_acp(
273 confirmation: acp::ToolCallConfirmation,
274 language_registry: Arc<LanguageRegistry>,
275 cx: &mut App,
276 ) -> Self {
277 let to_md = |description: String, cx: &mut App| -> Entity<Markdown> {
278 cx.new(|cx| {
279 Markdown::new(
280 description.into(),
281 Some(language_registry.clone()),
282 None,
283 cx,
284 )
285 })
286 };
287
288 match confirmation {
289 acp::ToolCallConfirmation::Edit { description } => Self::Edit {
290 description: description.map(|description| to_md(description, cx)),
291 },
292 acp::ToolCallConfirmation::Execute {
293 command,
294 root_command,
295 description,
296 } => Self::Execute {
297 command,
298 root_command,
299 description: description.map(|description| to_md(description, cx)),
300 },
301 acp::ToolCallConfirmation::Mcp {
302 server_name,
303 tool_name,
304 tool_display_name,
305 description,
306 } => Self::Mcp {
307 server_name,
308 tool_name,
309 tool_display_name,
310 description: description.map(|description| to_md(description, cx)),
311 },
312 acp::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch {
313 urls: urls.iter().map(|url| url.into()).collect(),
314 description: description.map(|description| to_md(description, cx)),
315 },
316 acp::ToolCallConfirmation::Other { description } => Self::Other {
317 description: to_md(description, cx),
318 },
319 }
320 }
321}
322
323#[derive(Debug)]
324pub enum ToolCallContent {
325 Markdown { markdown: Entity<Markdown> },
326 Diff { diff: Diff },
327}
328
329impl ToolCallContent {
330 pub fn from_acp(
331 content: acp::ToolCallContent,
332 language_registry: Arc<LanguageRegistry>,
333 cx: &mut App,
334 ) -> Self {
335 match content {
336 acp::ToolCallContent::Markdown { markdown } => Self::Markdown {
337 markdown: cx.new(|cx| Markdown::new_text(markdown.into(), cx)),
338 },
339 acp::ToolCallContent::Diff { diff } => Self::Diff {
340 diff: Diff::from_acp(diff, language_registry, cx),
341 },
342 }
343 }
344
345 fn to_markdown(&self, cx: &App) -> String {
346 match self {
347 Self::Markdown { markdown } => markdown.read(cx).source().to_string(),
348 Self::Diff { diff } => diff.to_markdown(cx),
349 }
350 }
351}
352
353#[derive(Debug)]
354pub struct Diff {
355 pub multibuffer: Entity<MultiBuffer>,
356 pub path: PathBuf,
357 pub new_buffer: Entity<Buffer>,
358 pub old_buffer: Entity<Buffer>,
359 _task: Task<Result<()>>,
360}
361
362impl Diff {
363 pub fn from_acp(
364 diff: acp::Diff,
365 language_registry: Arc<LanguageRegistry>,
366 cx: &mut App,
367 ) -> Self {
368 let acp::Diff {
369 path,
370 old_text,
371 new_text,
372 } = diff;
373
374 let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
375
376 let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
377 let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
378 let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
379 let old_buffer_snapshot = old_buffer.read(cx).snapshot();
380 let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
381 let diff_task = buffer_diff.update(cx, |diff, cx| {
382 diff.set_base_text(
383 old_buffer_snapshot,
384 Some(language_registry.clone()),
385 new_buffer_snapshot,
386 cx,
387 )
388 });
389
390 let task = cx.spawn({
391 let multibuffer = multibuffer.clone();
392 let path = path.clone();
393 let new_buffer = new_buffer.clone();
394 async move |cx| {
395 diff_task.await?;
396
397 multibuffer
398 .update(cx, |multibuffer, cx| {
399 let hunk_ranges = {
400 let buffer = new_buffer.read(cx);
401 let diff = buffer_diff.read(cx);
402 diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
403 .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
404 .collect::<Vec<_>>()
405 };
406
407 multibuffer.set_excerpts_for_path(
408 PathKey::for_buffer(&new_buffer, cx),
409 new_buffer.clone(),
410 hunk_ranges,
411 editor::DEFAULT_MULTIBUFFER_CONTEXT,
412 cx,
413 );
414 multibuffer.add_diff(buffer_diff.clone(), cx);
415 })
416 .log_err();
417
418 if let Some(language) = language_registry
419 .language_for_file_path(&path)
420 .await
421 .log_err()
422 {
423 new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
424 }
425
426 anyhow::Ok(())
427 }
428 });
429
430 Self {
431 multibuffer,
432 path,
433 new_buffer,
434 old_buffer,
435 _task: task,
436 }
437 }
438
439 fn to_markdown(&self, cx: &App) -> String {
440 let buffer_text = self
441 .multibuffer
442 .read(cx)
443 .all_buffers()
444 .iter()
445 .map(|buffer| buffer.read(cx).text())
446 .join("\n");
447 format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
448 }
449}
450
451pub struct AcpThread {
452 entries: Vec<AgentThreadEntry>,
453 title: SharedString,
454 project: Entity<Project>,
455 action_log: Entity<ActionLog>,
456 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
457 send_task: Option<Task<()>>,
458 connection: Arc<acp::AgentConnection>,
459 child_status: Option<Task<Result<()>>>,
460 _io_task: Task<()>,
461}
462
463pub enum AcpThreadEvent {
464 NewEntry,
465 EntryUpdated(usize),
466}
467
468impl EventEmitter<AcpThreadEvent> for AcpThread {}
469
/// Coarse state of an [`AcpThread`], as reported by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No turn is in progress.
    Idle,
    /// A turn is in progress and blocked on a tool-call confirmation.
    WaitingForToolConfirmation,
    /// A turn is in progress.
    Generating,
}
476
477#[derive(Debug, Clone)]
478pub enum LoadError {
479 Unsupported { current_version: SharedString },
480 Exited(i32),
481 Other(SharedString),
482}
483
484impl Display for LoadError {
485 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
486 match self {
487 LoadError::Unsupported { current_version } => {
488 write!(
489 f,
490 "Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).",
491 current_version
492 )
493 }
494 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
495 LoadError::Other(msg) => write!(f, "{}", msg),
496 }
497 }
498}
499
500impl Error for LoadError {}
501
502impl AcpThread {
503 pub async fn spawn(
504 server: impl AgentServer + 'static,
505 root_dir: &Path,
506 project: Entity<Project>,
507 cx: &mut AsyncApp,
508 ) -> Result<Entity<Self>> {
509 let command = match server.command(&project, cx).await {
510 Ok(command) => command,
511 Err(e) => return Err(anyhow!(LoadError::Other(format!("{e}").into()))),
512 };
513
514 let mut child = util::command::new_smol_command(&command.path)
515 .args(command.args.iter())
516 .current_dir(root_dir)
517 .stdin(std::process::Stdio::piped())
518 .stdout(std::process::Stdio::piped())
519 .stderr(std::process::Stdio::inherit())
520 .kill_on_drop(true)
521 .spawn()?;
522
523 let stdin = child.stdin.take().unwrap();
524 let stdout = child.stdout.take().unwrap();
525
526 cx.new(|cx| {
527 let foreground_executor = cx.foreground_executor().clone();
528
529 let (connection, io_fut) = acp::AgentConnection::connect_to_agent(
530 AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
531 stdin,
532 stdout,
533 move |fut| foreground_executor.spawn(fut).detach(),
534 );
535
536 let io_task = cx.background_spawn(async move {
537 io_fut.await.log_err();
538 });
539
540 let child_status = cx.background_spawn(async move {
541 match child.status().await {
542 Err(e) => Err(anyhow!(e)),
543 Ok(result) if result.success() => Ok(()),
544 Ok(result) => {
545 if let Some(version) = server.version(&command).await.log_err()
546 && !version.supported
547 {
548 Err(anyhow!(LoadError::Unsupported {
549 current_version: version.current_version
550 }))
551 } else {
552 Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
553 }
554 }
555 }
556 });
557
558 let action_log = cx.new(|_| ActionLog::new(project.clone()));
559
560 Self {
561 action_log,
562 shared_buffers: Default::default(),
563 entries: Default::default(),
564 title: "ACP Thread".into(),
565 project,
566 send_task: None,
567 connection: Arc::new(connection),
568 child_status: Some(child_status),
569 _io_task: io_task,
570 }
571 })
572 }
573
574 pub fn action_log(&self) -> &Entity<ActionLog> {
575 &self.action_log
576 }
577
578 pub fn project(&self) -> &Entity<Project> {
579 &self.project
580 }
581
582 #[cfg(test)]
583 pub fn fake(
584 stdin: async_pipe::PipeWriter,
585 stdout: async_pipe::PipeReader,
586 project: Entity<Project>,
587 cx: &mut Context<Self>,
588 ) -> Self {
589 let foreground_executor = cx.foreground_executor().clone();
590
591 let (connection, io_fut) = acp::AgentConnection::connect_to_agent(
592 AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
593 stdin,
594 stdout,
595 move |fut| {
596 foreground_executor.spawn(fut).detach();
597 },
598 );
599
600 let io_task = cx.background_spawn({
601 async move {
602 io_fut.await.log_err();
603 }
604 });
605
606 let action_log = cx.new(|_| ActionLog::new(project.clone()));
607
608 Self {
609 action_log,
610 shared_buffers: Default::default(),
611 entries: Default::default(),
612 title: "ACP Thread".into(),
613 project,
614 send_task: None,
615 connection: Arc::new(connection),
616 child_status: None,
617 _io_task: io_task,
618 }
619 }
620
621 pub fn title(&self) -> SharedString {
622 self.title.clone()
623 }
624
625 pub fn entries(&self) -> &[AgentThreadEntry] {
626 &self.entries
627 }
628
629 pub fn status(&self) -> ThreadStatus {
630 if self.send_task.is_some() {
631 if self.waiting_for_tool_confirmation() {
632 ThreadStatus::WaitingForToolConfirmation
633 } else {
634 ThreadStatus::Generating
635 }
636 } else {
637 ThreadStatus::Idle
638 }
639 }
640
641 pub fn has_pending_edit_tool_calls(&self) -> bool {
642 for entry in self.entries.iter().rev() {
643 match entry {
644 AgentThreadEntry::UserMessage(_) => return false,
645 AgentThreadEntry::ToolCall(ToolCall {
646 status:
647 ToolCallStatus::Allowed {
648 status: acp::ToolCallStatus::Running,
649 ..
650 },
651 content: Some(ToolCallContent::Diff { .. }),
652 ..
653 }) => return true,
654 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
655 }
656 }
657
658 false
659 }
660
661 pub fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
662 self.entries.push(entry);
663 cx.emit(AcpThreadEvent::NewEntry);
664 }
665
666 pub fn push_assistant_chunk(
667 &mut self,
668 chunk: acp::AssistantMessageChunk,
669 cx: &mut Context<Self>,
670 ) {
671 let entries_len = self.entries.len();
672 if let Some(last_entry) = self.entries.last_mut()
673 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
674 {
675 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
676
677 match (chunks.last_mut(), &chunk) {
678 (
679 Some(AssistantMessageChunk::Text { chunk: old_chunk }),
680 acp::AssistantMessageChunk::Text { text: new_chunk },
681 )
682 | (
683 Some(AssistantMessageChunk::Thought { chunk: old_chunk }),
684 acp::AssistantMessageChunk::Thought { thought: new_chunk },
685 ) => {
686 old_chunk.update(cx, |old_chunk, cx| {
687 old_chunk.append(&new_chunk, cx);
688 });
689 }
690 _ => {
691 chunks.push(AssistantMessageChunk::from_acp(
692 chunk,
693 self.project.read(cx).languages().clone(),
694 cx,
695 ));
696 }
697 }
698 } else {
699 let chunk = AssistantMessageChunk::from_acp(
700 chunk,
701 self.project.read(cx).languages().clone(),
702 cx,
703 );
704
705 self.push_entry(
706 AgentThreadEntry::AssistantMessage(AssistantMessage {
707 chunks: vec![chunk],
708 }),
709 cx,
710 );
711 }
712 }
713
714 pub fn request_tool_call(
715 &mut self,
716 tool_call: acp::RequestToolCallConfirmationParams,
717 cx: &mut Context<Self>,
718 ) -> ToolCallRequest {
719 let (tx, rx) = oneshot::channel();
720
721 let status = ToolCallStatus::WaitingForConfirmation {
722 confirmation: ToolCallConfirmation::from_acp(
723 tool_call.confirmation,
724 self.project.read(cx).languages().clone(),
725 cx,
726 ),
727 respond_tx: tx,
728 };
729
730 let id = self.insert_tool_call(tool_call.tool_call, status, cx);
731 ToolCallRequest { id, outcome: rx }
732 }
733
734 pub fn push_tool_call(
735 &mut self,
736 request: acp::PushToolCallParams,
737 cx: &mut Context<Self>,
738 ) -> acp::ToolCallId {
739 let status = ToolCallStatus::Allowed {
740 status: acp::ToolCallStatus::Running,
741 };
742
743 self.insert_tool_call(request, status, cx)
744 }
745
746 fn insert_tool_call(
747 &mut self,
748 tool_call: acp::PushToolCallParams,
749 status: ToolCallStatus,
750 cx: &mut Context<Self>,
751 ) -> acp::ToolCallId {
752 let language_registry = self.project.read(cx).languages().clone();
753 let id = acp::ToolCallId(self.entries.len() as u64);
754 let call = ToolCall {
755 id,
756 label: cx.new(|cx| {
757 Markdown::new(
758 tool_call.label.into(),
759 Some(language_registry.clone()),
760 None,
761 cx,
762 )
763 }),
764 icon: acp_icon_to_ui_icon(tool_call.icon),
765 content: tool_call
766 .content
767 .map(|content| ToolCallContent::from_acp(content, language_registry, cx)),
768 locations: tool_call.locations,
769 status,
770 };
771
772 let location = call.locations.last().cloned();
773 if let Some(location) = location {
774 self.set_project_location(location, cx)
775 }
776
777 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
778
779 id
780 }
781
782 pub fn authorize_tool_call(
783 &mut self,
784 id: acp::ToolCallId,
785 outcome: acp::ToolCallConfirmationOutcome,
786 cx: &mut Context<Self>,
787 ) {
788 let Some((ix, call)) = self.tool_call_mut(id) else {
789 return;
790 };
791
792 let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
793 ToolCallStatus::Rejected
794 } else {
795 ToolCallStatus::Allowed {
796 status: acp::ToolCallStatus::Running,
797 }
798 };
799
800 let curr_status = mem::replace(&mut call.status, new_status);
801
802 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
803 respond_tx.send(outcome).log_err();
804 } else if cfg!(debug_assertions) {
805 panic!("tried to authorize an already authorized tool call");
806 }
807
808 cx.emit(AcpThreadEvent::EntryUpdated(ix));
809 }
810
811 pub fn update_tool_call(
812 &mut self,
813 id: acp::ToolCallId,
814 new_status: acp::ToolCallStatus,
815 new_content: Option<acp::ToolCallContent>,
816 cx: &mut Context<Self>,
817 ) -> Result<()> {
818 let language_registry = self.project.read(cx).languages().clone();
819 let (ix, call) = self.tool_call_mut(id).context("Entry not found")?;
820
821 call.content = new_content
822 .map(|new_content| ToolCallContent::from_acp(new_content, language_registry, cx));
823
824 match &mut call.status {
825 ToolCallStatus::Allowed { status } => {
826 *status = new_status;
827 }
828 ToolCallStatus::WaitingForConfirmation { .. } => {
829 anyhow::bail!("Tool call hasn't been authorized yet")
830 }
831 ToolCallStatus::Rejected => {
832 anyhow::bail!("Tool call was rejected and therefore can't be updated")
833 }
834 ToolCallStatus::Canceled => {
835 call.status = ToolCallStatus::Allowed { status: new_status };
836 }
837 }
838
839 let location = call.locations.last().cloned();
840 if let Some(location) = location {
841 self.set_project_location(location, cx)
842 }
843
844 cx.emit(AcpThreadEvent::EntryUpdated(ix));
845 Ok(())
846 }
847
848 fn tool_call_mut(&mut self, id: acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
849 let entry = self.entries.get_mut(id.0 as usize);
850 debug_assert!(
851 entry.is_some(),
852 "We shouldn't give out ids to entries that don't exist"
853 );
854 match entry {
855 Some(AgentThreadEntry::ToolCall(call)) if call.id == id => Some((id.0 as usize, call)),
856 _ => {
857 if cfg!(debug_assertions) {
858 panic!("entry is not a tool call");
859 }
860 None
861 }
862 }
863 }
864
865 pub fn set_project_location(&self, location: ToolCallLocation, cx: &mut Context<Self>) {
866 self.project.update(cx, |project, cx| {
867 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
868 return;
869 };
870 let buffer = project.open_buffer(path, cx);
871 cx.spawn(async move |project, cx| {
872 let buffer = buffer.await?;
873
874 project.update(cx, |project, cx| {
875 let position = if let Some(line) = location.line {
876 let snapshot = buffer.read(cx).snapshot();
877 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
878 snapshot.anchor_before(point)
879 } else {
880 Anchor::MIN
881 };
882
883 project.set_agent_location(
884 Some(AgentLocation {
885 buffer: buffer.downgrade(),
886 position,
887 }),
888 cx,
889 );
890 })
891 })
892 .detach_and_log_err(cx);
893 });
894 }
895
896 /// Returns true if the last turn is awaiting tool authorization
897 pub fn waiting_for_tool_confirmation(&self) -> bool {
898 for entry in self.entries.iter().rev() {
899 match &entry {
900 AgentThreadEntry::ToolCall(call) => match call.status {
901 ToolCallStatus::WaitingForConfirmation { .. } => return true,
902 ToolCallStatus::Allowed { .. }
903 | ToolCallStatus::Rejected
904 | ToolCallStatus::Canceled => continue,
905 },
906 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
907 // Reached the beginning of the turn
908 return false;
909 }
910 }
911 }
912 false
913 }
914
915 pub fn initialize(
916 &self,
917 ) -> impl use<> + Future<Output = Result<acp::InitializeResponse, acp::Error>> {
918 let connection = self.connection.clone();
919 async move { connection.initialize().await }
920 }
921
922 pub fn authenticate(&self) -> impl use<> + Future<Output = Result<(), acp::Error>> {
923 let connection = self.connection.clone();
924 async move { connection.request(acp::AuthenticateParams).await }
925 }
926
927 #[cfg(test)]
928 pub fn send_raw(
929 &mut self,
930 message: &str,
931 cx: &mut Context<Self>,
932 ) -> BoxFuture<'static, Result<(), acp::Error>> {
933 self.send(
934 acp::SendUserMessageParams {
935 chunks: vec![acp::UserMessageChunk::Text {
936 text: message.to_string(),
937 }],
938 },
939 cx,
940 )
941 }
942
943 pub fn send(
944 &mut self,
945 message: acp::SendUserMessageParams,
946 cx: &mut Context<Self>,
947 ) -> BoxFuture<'static, Result<(), acp::Error>> {
948 let agent = self.connection.clone();
949 self.push_entry(
950 AgentThreadEntry::UserMessage(UserMessage::from_acp(
951 &message,
952 self.project.read(cx).languages().clone(),
953 cx,
954 )),
955 cx,
956 );
957
958 let (tx, rx) = oneshot::channel();
959 let cancel = self.cancel(cx);
960
961 self.send_task = Some(cx.spawn(async move |this, cx| {
962 cancel.await.log_err();
963
964 let result = agent.request(message).await;
965 tx.send(result).log_err();
966 this.update(cx, |this, _cx| this.send_task.take()).log_err();
967 }));
968
969 async move {
970 match rx.await {
971 Ok(Err(e)) => Err(e)?,
972 _ => Ok(()),
973 }
974 }
975 .boxed()
976 }
977
978 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<(), acp::Error>> {
979 let agent = self.connection.clone();
980
981 if self.send_task.take().is_some() {
982 cx.spawn(async move |this, cx| {
983 agent.request(acp::CancelSendMessageParams).await?;
984
985 this.update(cx, |this, _cx| {
986 for entry in this.entries.iter_mut() {
987 if let AgentThreadEntry::ToolCall(call) = entry {
988 let cancel = matches!(
989 call.status,
990 ToolCallStatus::WaitingForConfirmation { .. }
991 | ToolCallStatus::Allowed {
992 status: acp::ToolCallStatus::Running
993 }
994 );
995
996 if cancel {
997 let curr_status =
998 mem::replace(&mut call.status, ToolCallStatus::Canceled);
999
1000 if let ToolCallStatus::WaitingForConfirmation {
1001 respond_tx, ..
1002 } = curr_status
1003 {
1004 respond_tx
1005 .send(acp::ToolCallConfirmationOutcome::Cancel)
1006 .ok();
1007 }
1008 }
1009 }
1010 }
1011 })?;
1012 Ok(())
1013 })
1014 } else {
1015 Task::ready(Ok(()))
1016 }
1017 }
1018
1019 pub fn read_text_file(
1020 &self,
1021 request: acp::ReadTextFileParams,
1022 cx: &mut Context<Self>,
1023 ) -> Task<Result<String>> {
1024 let project = self.project.clone();
1025 let action_log = self.action_log.clone();
1026 cx.spawn(async move |this, cx| {
1027 let load = project.update(cx, |project, cx| {
1028 let path = project
1029 .project_path_for_absolute_path(&request.path, cx)
1030 .context("invalid path")?;
1031 anyhow::Ok(project.open_buffer(path, cx))
1032 });
1033 let buffer = load??.await?;
1034
1035 action_log.update(cx, |action_log, cx| {
1036 action_log.buffer_read(buffer.clone(), cx);
1037 })?;
1038 project.update(cx, |project, cx| {
1039 let position = buffer
1040 .read(cx)
1041 .snapshot()
1042 .anchor_before(Point::new(request.line.unwrap_or_default(), 0));
1043 project.set_agent_location(
1044 Some(AgentLocation {
1045 buffer: buffer.downgrade(),
1046 position,
1047 }),
1048 cx,
1049 );
1050 })?;
1051 let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
1052 this.update(cx, |this, _| {
1053 let text = snapshot.text();
1054 this.shared_buffers.insert(buffer.clone(), snapshot);
1055 text
1056 })
1057 })
1058 }
1059
1060 pub fn write_text_file(
1061 &self,
1062 path: PathBuf,
1063 content: String,
1064 cx: &mut Context<Self>,
1065 ) -> Task<Result<()>> {
1066 let project = self.project.clone();
1067 let action_log = self.action_log.clone();
1068 cx.spawn(async move |this, cx| {
1069 let load = project.update(cx, |project, cx| {
1070 let path = project
1071 .project_path_for_absolute_path(&path, cx)
1072 .context("invalid path")?;
1073 anyhow::Ok(project.open_buffer(path, cx))
1074 });
1075 let buffer = load??.await?;
1076 let snapshot = this.update(cx, |this, cx| {
1077 this.shared_buffers
1078 .get(&buffer)
1079 .cloned()
1080 .unwrap_or_else(|| buffer.read(cx).snapshot())
1081 })?;
1082 let edits = cx
1083 .background_executor()
1084 .spawn(async move {
1085 let old_text = snapshot.text();
1086 text_diff(old_text.as_str(), &content)
1087 .into_iter()
1088 .map(|(range, replacement)| {
1089 (
1090 snapshot.anchor_after(range.start)
1091 ..snapshot.anchor_before(range.end),
1092 replacement,
1093 )
1094 })
1095 .collect::<Vec<_>>()
1096 })
1097 .await;
1098 cx.update(|cx| {
1099 project.update(cx, |project, cx| {
1100 project.set_agent_location(
1101 Some(AgentLocation {
1102 buffer: buffer.downgrade(),
1103 position: edits
1104 .last()
1105 .map(|(range, _)| range.end)
1106 .unwrap_or(Anchor::MIN),
1107 }),
1108 cx,
1109 );
1110 });
1111
1112 action_log.update(cx, |action_log, cx| {
1113 action_log.buffer_read(buffer.clone(), cx);
1114 });
1115 buffer.update(cx, |buffer, cx| {
1116 buffer.edit(edits, None, cx);
1117 });
1118 action_log.update(cx, |action_log, cx| {
1119 action_log.buffer_edited(buffer.clone(), cx);
1120 });
1121 })?;
1122 project
1123 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1124 .await
1125 })
1126 }
1127
1128 pub fn child_status(&mut self) -> Option<Task<Result<()>>> {
1129 self.child_status.take()
1130 }
1131
1132 pub fn to_markdown(&self, cx: &App) -> String {
1133 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1134 }
1135}
1136
1137struct AcpClientDelegate {
1138 thread: WeakEntity<AcpThread>,
1139 cx: AsyncApp,
1140 // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
1141}
1142
1143impl AcpClientDelegate {
1144 fn new(thread: WeakEntity<AcpThread>, cx: AsyncApp) -> Self {
1145 Self { thread, cx }
1146 }
1147}
1148
1149impl acp::Client for AcpClientDelegate {
1150 async fn stream_assistant_message_chunk(
1151 &self,
1152 params: acp::StreamAssistantMessageChunkParams,
1153 ) -> Result<(), acp::Error> {
1154 let cx = &mut self.cx.clone();
1155
1156 cx.update(|cx| {
1157 self.thread
1158 .update(cx, |thread, cx| {
1159 thread.push_assistant_chunk(params.chunk, cx)
1160 })
1161 .ok();
1162 })?;
1163
1164 Ok(())
1165 }
1166
1167 async fn request_tool_call_confirmation(
1168 &self,
1169 request: acp::RequestToolCallConfirmationParams,
1170 ) -> Result<acp::RequestToolCallConfirmationResponse, acp::Error> {
1171 let cx = &mut self.cx.clone();
1172 let ToolCallRequest { id, outcome } = cx
1173 .update(|cx| {
1174 self.thread
1175 .update(cx, |thread, cx| thread.request_tool_call(request, cx))
1176 })?
1177 .context("Failed to update thread")?;
1178
1179 Ok(acp::RequestToolCallConfirmationResponse {
1180 id,
1181 outcome: outcome.await.map_err(acp::Error::into_internal_error)?,
1182 })
1183 }
1184
1185 async fn push_tool_call(
1186 &self,
1187 request: acp::PushToolCallParams,
1188 ) -> Result<acp::PushToolCallResponse, acp::Error> {
1189 let cx = &mut self.cx.clone();
1190 let id = cx
1191 .update(|cx| {
1192 self.thread
1193 .update(cx, |thread, cx| thread.push_tool_call(request, cx))
1194 })?
1195 .context("Failed to update thread")?;
1196
1197 Ok(acp::PushToolCallResponse { id })
1198 }
1199
1200 async fn update_tool_call(&self, request: acp::UpdateToolCallParams) -> Result<(), acp::Error> {
1201 let cx = &mut self.cx.clone();
1202
1203 cx.update(|cx| {
1204 self.thread.update(cx, |thread, cx| {
1205 thread.update_tool_call(request.tool_call_id, request.status, request.content, cx)
1206 })
1207 })?
1208 .context("Failed to update thread")??;
1209
1210 Ok(())
1211 }
1212
1213 async fn read_text_file(
1214 &self,
1215 request: acp::ReadTextFileParams,
1216 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
1217 let content = self
1218 .cx
1219 .update(|cx| {
1220 self.thread
1221 .update(cx, |thread, cx| thread.read_text_file(request, cx))
1222 })?
1223 .context("Failed to update thread")?
1224 .await?;
1225 Ok(acp::ReadTextFileResponse { content })
1226 }
1227
1228 async fn write_text_file(&self, request: acp::WriteTextFileParams) -> Result<(), acp::Error> {
1229 self.cx
1230 .update(|cx| {
1231 self.thread.update(cx, |thread, cx| {
1232 thread.write_text_file(request.path, request.content, cx)
1233 })
1234 })?
1235 .context("Failed to update thread")?
1236 .await?;
1237
1238 Ok(())
1239 }
1240}
1241
1242fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
1243 match icon {
1244 acp::Icon::FileSearch => IconName::ToolSearch,
1245 acp::Icon::Folder => IconName::ToolFolder,
1246 acp::Icon::Globe => IconName::ToolWeb,
1247 acp::Icon::Hammer => IconName::ToolHammer,
1248 acp::Icon::LightBulb => IconName::ToolBulb,
1249 acp::Icon::Pencil => IconName::ToolPencil,
1250 acp::Icon::Regex => IconName::ToolRegex,
1251 acp::Icon::Terminal => IconName::ToolTerminal,
1252 }
1253}
1254
1255pub struct ToolCallRequest {
1256 pub id: acp::ToolCallId,
1257 pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
1258}
1259
1260#[cfg(test)]
1261mod tests {
1262 use super::*;
1263 use agent_servers::{AgentServerCommand, AgentServerVersion};
1264 use async_pipe::{PipeReader, PipeWriter};
1265 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1266 use gpui::{AsyncApp, TestAppContext};
1267 use indoc::indoc;
1268 use project::FakeFs;
1269 use serde_json::json;
1270 use settings::SettingsStore;
1271 use smol::{future::BoxedLocal, stream::StreamExt as _};
1272 use std::{cell::RefCell, env, path::Path, rc::Rc, time::Duration};
1273 use util::path;
1274
1275 fn init_test(cx: &mut TestAppContext) {
1276 env_logger::try_init().ok();
1277 cx.update(|cx| {
1278 let settings_store = SettingsStore::test(cx);
1279 cx.set_global(settings_store);
1280 Project::init_settings(cx);
1281 language::init(cx);
1282 });
1283 }
1284
    // Streams two `Thought` chunks from the fake agent in response to a single
    // user message and checks that they are concatenated into one `<thinking>`
    // section of the thread's markdown.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project, cx);

        // On the user message, send two thought chunks, awaiting each request
        // before sending the next so they arrive in order.
        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| async move {
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp::StreamAssistantMessageChunkParams {
                            chunk: acp::AssistantMessageChunk::Thought {
                                thought: "Thinking ".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp::StreamAssistantMessageChunkParams {
                            chunk: acp::AssistantMessageChunk::Thought {
                                thought: "hard!".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();

                Ok(())
            })
        });

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        // Both chunks must be rendered as one contiguous thought.
        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
1342
    // The fake agent reads /tmp/foo and then writes an extended version of it.
    // Between the read and the write, the user inserts "zero\n" into the open
    // buffer. The test expects the user's insertion and the agent's appended
    // lines to both be present, in the buffer and in the file on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project.clone(), cx);
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        // Signals the test body once the agent has read the file, so the user
        // edit below is intended to land between the agent's read and write.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        // Shared through `Rc<RefCell<Option<_>>>` so the `Fn` handler can take
        // the one-shot sender out exactly once.
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));

        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp::ReadTextFileParams {
                                path: path!("/tmp/foo").into(),
                                line: None,
                                limit: None,
                            })
                        })?
                        .await
                        .unwrap();
                    assert_eq!(content.content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    // Write based on the content read earlier; the user's edit is
                    // not reflected in this payload.
                    server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp::WriteTextFileParams {
                                path: path!("/tmp/foo").into(),
                                content: "one\ntwo\nthree\nfour\nfive\n".to_string(),
                            })
                        })?
                        .await
                        .unwrap();
                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // Simulated user edit made while the agent's turn is still in flight.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1416
1417 #[gpui::test]
1418 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
1419 init_test(cx);
1420
1421 let fs = FakeFs::new(cx.executor());
1422 let project = Project::test(fs, [], cx).await;
1423 let (thread, fake_server) = fake_acp_thread(project, cx);
1424
1425 let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>();
1426
1427 let tool_call_id = Rc::new(RefCell::new(None));
1428 let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx)));
1429 fake_server.update(cx, |fake_server, _| {
1430 let tool_call_id = tool_call_id.clone();
1431 fake_server.on_user_message(move |_, server, mut cx| {
1432 let end_turn_rx = end_turn_rx.clone();
1433 let tool_call_id = tool_call_id.clone();
1434 async move {
1435 let tool_call_result = server
1436 .update(&mut cx, |server, _| {
1437 server.send_to_zed(acp::PushToolCallParams {
1438 label: "Fetch".to_string(),
1439 icon: acp::Icon::Globe,
1440 content: None,
1441 locations: vec![],
1442 })
1443 })?
1444 .await
1445 .unwrap();
1446 *tool_call_id.clone().borrow_mut() = Some(tool_call_result.id);
1447 end_turn_rx.take().unwrap().await.ok();
1448
1449 Ok(())
1450 }
1451 })
1452 });
1453
1454 let request = thread.update(cx, |thread, cx| {
1455 thread.send_raw("Fetch https://example.com", cx)
1456 });
1457
1458 run_until_first_tool_call(&thread, cx).await;
1459
1460 thread.read_with(cx, |thread, _| {
1461 assert!(matches!(
1462 thread.entries[1],
1463 AgentThreadEntry::ToolCall(ToolCall {
1464 status: ToolCallStatus::Allowed {
1465 status: acp::ToolCallStatus::Running,
1466 ..
1467 },
1468 ..
1469 })
1470 ));
1471 });
1472
1473 cx.run_until_parked();
1474
1475 thread
1476 .update(cx, |thread, cx| thread.cancel(cx))
1477 .await
1478 .unwrap();
1479
1480 thread.read_with(cx, |thread, _| {
1481 assert!(matches!(
1482 &thread.entries[1],
1483 AgentThreadEntry::ToolCall(ToolCall {
1484 status: ToolCallStatus::Canceled,
1485 ..
1486 })
1487 ));
1488 });
1489
1490 fake_server
1491 .update(cx, |fake_server, _| {
1492 fake_server.send_to_zed(acp::UpdateToolCallParams {
1493 tool_call_id: tool_call_id.borrow().unwrap(),
1494 status: acp::ToolCallStatus::Finished,
1495 content: None,
1496 })
1497 })
1498 .await
1499 .unwrap();
1500
1501 drop(end_turn_tx);
1502 request.await.unwrap();
1503
1504 thread.read_with(cx, |thread, _| {
1505 assert!(matches!(
1506 thread.entries[1],
1507 AgentThreadEntry::ToolCall(ToolCall {
1508 status: ToolCallStatus::Allowed {
1509 status: acp::ToolCallStatus::Finished,
1510 ..
1511 },
1512 ..
1513 })
1514 ));
1515 });
1516 }
1517
1518 #[gpui::test]
1519 #[cfg_attr(not(feature = "gemini"), ignore)]
1520 async fn test_gemini_basic(cx: &mut TestAppContext) {
1521 init_test(cx);
1522
1523 cx.executor().allow_parking();
1524
1525 let fs = FakeFs::new(cx.executor());
1526 let project = Project::test(fs, [], cx).await;
1527 let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
1528 thread
1529 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1530 .await
1531 .unwrap();
1532
1533 thread.read_with(cx, |thread, _| {
1534 assert_eq!(thread.entries.len(), 2);
1535 assert!(matches!(
1536 thread.entries[0],
1537 AgentThreadEntry::UserMessage(_)
1538 ));
1539 assert!(matches!(
1540 thread.entries[1],
1541 AgentThreadEntry::AssistantMessage(_)
1542 ));
1543 });
1544 }
1545
1546 #[gpui::test]
1547 #[cfg_attr(not(feature = "gemini"), ignore)]
1548 async fn test_gemini_path_mentions(cx: &mut TestAppContext) {
1549 init_test(cx);
1550
1551 cx.executor().allow_parking();
1552 let tempdir = tempfile::tempdir().unwrap();
1553 std::fs::write(
1554 tempdir.path().join("foo.rs"),
1555 indoc! {"
1556 fn main() {
1557 println!(\"Hello, world!\");
1558 }
1559 "},
1560 )
1561 .expect("failed to write file");
1562 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
1563 let thread = gemini_acp_thread(project.clone(), tempdir.path(), cx).await;
1564 thread
1565 .update(cx, |thread, cx| {
1566 thread.send(
1567 acp::SendUserMessageParams {
1568 chunks: vec![
1569 acp::UserMessageChunk::Text {
1570 text: "Read the file ".into(),
1571 },
1572 acp::UserMessageChunk::Path {
1573 path: Path::new("foo.rs").into(),
1574 },
1575 acp::UserMessageChunk::Text {
1576 text: " and tell me what the content of the println! is".into(),
1577 },
1578 ],
1579 },
1580 cx,
1581 )
1582 })
1583 .await
1584 .unwrap();
1585
1586 thread.read_with(cx, |thread, cx| {
1587 assert_eq!(thread.entries.len(), 3);
1588 assert!(matches!(
1589 thread.entries[0],
1590 AgentThreadEntry::UserMessage(_)
1591 ));
1592 assert!(matches!(thread.entries[1], AgentThreadEntry::ToolCall(_)));
1593 let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries[2] else {
1594 panic!("Expected AssistantMessage")
1595 };
1596 assert!(
1597 assistant_message.to_markdown(cx).contains("Hello, world!"),
1598 "unexpected assistant message: {:?}",
1599 assistant_message.to_markdown(cx)
1600 );
1601 });
1602 }
1603
1604 #[gpui::test]
1605 #[cfg_attr(not(feature = "gemini"), ignore)]
1606 async fn test_gemini_tool_call(cx: &mut TestAppContext) {
1607 init_test(cx);
1608
1609 cx.executor().allow_parking();
1610
1611 let fs = FakeFs::new(cx.executor());
1612 fs.insert_tree(
1613 path!("/private/tmp"),
1614 json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
1615 )
1616 .await;
1617 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
1618 let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
1619 thread
1620 .update(cx, |thread, cx| {
1621 thread.send_raw(
1622 "Read the '/private/tmp/foo' file and tell me what you see.",
1623 cx,
1624 )
1625 })
1626 .await
1627 .unwrap();
1628 thread.read_with(cx, |thread, _cx| {
1629 assert!(matches!(
1630 &thread.entries()[2],
1631 AgentThreadEntry::ToolCall(ToolCall {
1632 status: ToolCallStatus::Allowed { .. },
1633 ..
1634 })
1635 ));
1636
1637 assert!(matches!(
1638 thread.entries[3],
1639 AgentThreadEntry::AssistantMessage(_)
1640 ));
1641 });
1642 }
1643
    // Asks gemini to run an `echo` command, approves the resulting execute-type
    // tool call, and checks that the command's output ends up in the tool call's
    // markdown content.
    #[gpui::test]
    #[cfg_attr(not(feature = "gemini"), ignore)]
    async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
        // Start the turn without awaiting it; it is awaited only after the tool
        // call has been authorized below.
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // NOTE(review): assumes the tool call is entry 2 instead of using the index
        // returned by `run_until_first_tool_call` — confirm this is stable.
        let tool_call_id = thread.read_with(cx, |thread, _cx| {
            let AgentThreadEntry::ToolCall(ToolCall {
                id,
                status:
                    ToolCallStatus::WaitingForConfirmation {
                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
                        ..
                    },
                ..
            }) = &thread.entries()[2]
            else {
                panic!();
            };

            assert_eq!(root_command, "echo");

            *id
        });

        // Authorizing moves the tool call to `Allowed` immediately.
        thread.update(cx, |thread, cx| {
            thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);

            assert!(matches!(
                &thread.entries()[2],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));
        });

        full_turn.await.unwrap();

        // After the turn, the tool call carries the command output as markdown.
        thread.read_with(cx, |thread, cx| {
            let AgentThreadEntry::ToolCall(ToolCall {
                content: Some(ToolCallContent::Markdown { markdown }),
                status: ToolCallStatus::Allowed { .. },
                ..
            }) = &thread.entries()[2]
            else {
                panic!();
            };

            markdown.read_with(cx, |md, _cx| {
                assert!(
                    md.source().contains("Hello, world!"),
                    r#"Expected '{}' to contain "Hello, world!""#,
                    md.source()
                );
            });
        });
    }
1712
1713 #[gpui::test]
1714 #[cfg_attr(not(feature = "gemini"), ignore)]
1715 async fn test_gemini_cancel(cx: &mut TestAppContext) {
1716 init_test(cx);
1717
1718 cx.executor().allow_parking();
1719
1720 let fs = FakeFs::new(cx.executor());
1721 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
1722 let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
1723 let full_turn = thread.update(cx, |thread, cx| {
1724 thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
1725 });
1726
1727 let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;
1728
1729 thread.read_with(cx, |thread, _cx| {
1730 let AgentThreadEntry::ToolCall(ToolCall {
1731 id,
1732 status:
1733 ToolCallStatus::WaitingForConfirmation {
1734 confirmation: ToolCallConfirmation::Execute { root_command, .. },
1735 ..
1736 },
1737 ..
1738 }) = &thread.entries()[first_tool_call_ix]
1739 else {
1740 panic!("{:?}", thread.entries()[1]);
1741 };
1742
1743 assert_eq!(root_command, "echo");
1744
1745 *id
1746 });
1747
1748 thread
1749 .update(cx, |thread, cx| thread.cancel(cx))
1750 .await
1751 .unwrap();
1752 full_turn.await.unwrap();
1753 thread.read_with(cx, |thread, _| {
1754 let AgentThreadEntry::ToolCall(ToolCall {
1755 status: ToolCallStatus::Canceled,
1756 ..
1757 }) = &thread.entries()[first_tool_call_ix]
1758 else {
1759 panic!();
1760 };
1761 });
1762
1763 thread
1764 .update(cx, |thread, cx| {
1765 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
1766 })
1767 .await
1768 .unwrap();
1769 thread.read_with(cx, |thread, _| {
1770 assert!(matches!(
1771 &thread.entries().last().unwrap(),
1772 AgentThreadEntry::AssistantMessage(..),
1773 ))
1774 });
1775 }
1776
1777 async fn run_until_first_tool_call(
1778 thread: &Entity<AcpThread>,
1779 cx: &mut TestAppContext,
1780 ) -> usize {
1781 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1782
1783 let subscription = cx.update(|cx| {
1784 cx.subscribe(thread, move |thread, _, cx| {
1785 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1786 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1787 return tx.try_send(ix).unwrap();
1788 }
1789 }
1790 })
1791 });
1792
1793 select! {
1794 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1795 panic!("Timeout waiting for tool call")
1796 }
1797 ix = rx.next().fuse() => {
1798 drop(subscription);
1799 ix.unwrap()
1800 }
1801 }
1802 }
1803
1804 pub async fn gemini_acp_thread(
1805 project: Entity<Project>,
1806 current_dir: impl AsRef<Path>,
1807 cx: &mut TestAppContext,
1808 ) -> Entity<AcpThread> {
1809 struct DevGemini;
1810
1811 impl agent_servers::AgentServer for DevGemini {
1812 async fn command(
1813 &self,
1814 _project: &Entity<Project>,
1815 _cx: &mut AsyncApp,
1816 ) -> Result<agent_servers::AgentServerCommand> {
1817 let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
1818 .join("../../../gemini-cli/packages/cli")
1819 .to_string_lossy()
1820 .to_string();
1821
1822 Ok(AgentServerCommand {
1823 path: "node".into(),
1824 args: vec![cli_path, "--experimental-acp".into()],
1825 env: None,
1826 })
1827 }
1828
1829 async fn version(
1830 &self,
1831 _command: &agent_servers::AgentServerCommand,
1832 ) -> Result<AgentServerVersion> {
1833 Ok(AgentServerVersion {
1834 current_version: "0.1.0".into(),
1835 supported: true,
1836 })
1837 }
1838 }
1839
1840 let thread = AcpThread::spawn(DevGemini, current_dir.as_ref(), project, &mut cx.to_async())
1841 .await
1842 .unwrap();
1843
1844 thread
1845 .update(cx, |thread, _| thread.initialize())
1846 .await
1847 .unwrap();
1848 thread
1849 }
1850
1851 pub fn fake_acp_thread(
1852 project: Entity<Project>,
1853 cx: &mut TestAppContext,
1854 ) -> (Entity<AcpThread>, Entity<FakeAcpServer>) {
1855 let (stdin_tx, stdin_rx) = async_pipe::pipe();
1856 let (stdout_tx, stdout_rx) = async_pipe::pipe();
1857 let thread = cx.update(|cx| cx.new(|cx| AcpThread::fake(stdin_tx, stdout_rx, project, cx)));
1858 let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
1859 (thread, agent)
1860 }
1861
1862 pub struct FakeAcpServer {
1863 connection: acp::ClientConnection,
1864 _io_task: Task<()>,
1865 on_user_message: Option<
1866 Rc<
1867 dyn Fn(
1868 acp::SendUserMessageParams,
1869 Entity<FakeAcpServer>,
1870 AsyncApp,
1871 ) -> LocalBoxFuture<'static, Result<(), acp::Error>>,
1872 >,
1873 >,
1874 }
1875
    /// `acp::Agent` implementation handed to the connection; it forwards user
    /// messages to the owning `FakeAcpServer`'s handler.
    #[derive(Clone)]
    struct FakeAgent {
        /// Server whose `on_user_message` handler is invoked.
        server: Entity<FakeAcpServer>,
        /// Async context used to read the server and to pass to the handler.
        cx: AsyncApp,
    }
1881
1882 impl acp::Agent for FakeAgent {
1883 async fn initialize(
1884 &self,
1885 params: acp::InitializeParams,
1886 ) -> Result<acp::InitializeResponse, acp::Error> {
1887 Ok(acp::InitializeResponse {
1888 protocol_version: params.protocol_version,
1889 is_authenticated: true,
1890 })
1891 }
1892
1893 async fn authenticate(&self) -> Result<(), acp::Error> {
1894 Ok(())
1895 }
1896
1897 async fn cancel_send_message(&self) -> Result<(), acp::Error> {
1898 Ok(())
1899 }
1900
1901 async fn send_user_message(
1902 &self,
1903 request: acp::SendUserMessageParams,
1904 ) -> Result<(), acp::Error> {
1905 let mut cx = self.cx.clone();
1906 let handler = self
1907 .server
1908 .update(&mut cx, |server, _| server.on_user_message.clone())
1909 .ok()
1910 .flatten();
1911 if let Some(handler) = handler {
1912 handler(request, self.server.clone(), self.cx.clone()).await
1913 } else {
1914 Err(anyhow::anyhow!("No handler for on_user_message").into())
1915 }
1916 }
1917 }
1918
1919 impl FakeAcpServer {
1920 fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
1921 let agent = FakeAgent {
1922 server: cx.entity(),
1923 cx: cx.to_async(),
1924 };
1925 let foreground_executor = cx.foreground_executor().clone();
1926
1927 let (connection, io_fut) = acp::ClientConnection::connect_to_client(
1928 agent.clone(),
1929 stdout,
1930 stdin,
1931 move |fut| {
1932 foreground_executor.spawn(fut).detach();
1933 },
1934 );
1935 FakeAcpServer {
1936 connection: connection,
1937 on_user_message: None,
1938 _io_task: cx.background_spawn(async move {
1939 io_fut.await.log_err();
1940 }),
1941 }
1942 }
1943
1944 fn on_user_message<F>(
1945 &mut self,
1946 handler: impl for<'a> Fn(acp::SendUserMessageParams, Entity<FakeAcpServer>, AsyncApp) -> F
1947 + 'static,
1948 ) where
1949 F: Future<Output = Result<(), acp::Error>> + 'static,
1950 {
1951 self.on_user_message
1952 .replace(Rc::new(move |request, server, cx| {
1953 handler(request, server, cx).boxed_local()
1954 }));
1955 }
1956
1957 fn send_to_zed<T: acp::ClientRequest + 'static>(
1958 &self,
1959 message: T,
1960 ) -> BoxedLocal<Result<T::Response>> {
1961 self.connection
1962 .request(message)
1963 .map(|f| f.map_err(|err| anyhow!(err)))
1964 .boxed_local()
1965 }
1966 }
1967}