1mod connection;
2pub use connection::*;
3
4pub use acp::ToolCallId;
5use agentic_coding_protocol::{
6 self as acp, AgentRequest, ProtocolVersion, ToolCallConfirmationOutcome, ToolCallLocation,
7 UserMessageChunk,
8};
9use anyhow::{Context as _, Result};
10use assistant_tool::ActionLog;
11use buffer_diff::BufferDiff;
12use editor::{Bias, MultiBuffer, PathKey};
13use futures::{FutureExt, channel::oneshot, future::BoxFuture};
14use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
15use itertools::Itertools;
16use language::{
17 Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
18 text_diff,
19};
20use markdown::Markdown;
21use project::{AgentLocation, Project};
22use std::collections::HashMap;
23use std::error::Error;
24use std::fmt::{Formatter, Write};
25use std::{
26 fmt::Display,
27 mem,
28 path::{Path, PathBuf},
29 sync::Arc,
30};
31use ui::{App, IconName};
32use util::ResultExt;
33
34#[derive(Clone, Debug, Eq, PartialEq)]
35pub struct UserMessage {
36 pub content: Entity<Markdown>,
37}
38
39impl UserMessage {
40 pub fn from_acp(
41 message: &acp::SendUserMessageParams,
42 language_registry: Arc<LanguageRegistry>,
43 cx: &mut App,
44 ) -> Self {
45 let mut md_source = String::new();
46
47 for chunk in &message.chunks {
48 match chunk {
49 UserMessageChunk::Text { text } => md_source.push_str(&text),
50 UserMessageChunk::Path { path } => {
51 write!(&mut md_source, "{}", MentionPath(&path)).unwrap()
52 }
53 }
54 }
55
56 Self {
57 content: cx
58 .new(|cx| Markdown::new(md_source.into(), Some(language_registry), None, cx)),
59 }
60 }
61
62 fn to_markdown(&self, cx: &App) -> String {
63 format!("## User\n\n{}\n\n", self.content.read(cx).source())
64 }
65}
66
/// A file path mentioned inside message markdown.
///
/// Its [`Display`] form is the markdown link `[@<file name>](@file:<path>)`.
/// The link target can be turned back into a [`MentionPath`] with
/// [`MentionPath::try_parse`].
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// Prefix that marks a link target as a file mention.
    const PREFIX: &'static str = "@file:";

    /// Wraps `path` without copying it.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a link target of the form `@file:<path>`.
    ///
    /// Returns `None` when `url` does not start with `@file:`.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self(Path::new(rest)))
    }

    /// The wrapped path.
    pub fn path(&self) -> &Path {
        let Self(inner) = self;
        inner
    }
}

impl Display for MentionPath<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Link text: `@` plus the final path component (empty when there is none).
        let file_name = self.0.file_name().unwrap_or_default();
        write!(f, "[@{}]", file_name.display())?;
        // Link target: the prefix followed by the full path.
        write!(f, "({}{})", Self::PREFIX, self.0.display())
    }
}
98
99#[derive(Clone, Debug, Eq, PartialEq)]
100pub struct AssistantMessage {
101 pub chunks: Vec<AssistantMessageChunk>,
102}
103
104impl AssistantMessage {
105 pub fn to_markdown(&self, cx: &App) -> String {
106 format!(
107 "## Assistant\n\n{}\n\n",
108 self.chunks
109 .iter()
110 .map(|chunk| chunk.to_markdown(cx))
111 .join("\n\n")
112 )
113 }
114}
115
116#[derive(Clone, Debug, Eq, PartialEq)]
117pub enum AssistantMessageChunk {
118 Text { chunk: Entity<Markdown> },
119 Thought { chunk: Entity<Markdown> },
120}
121
122impl AssistantMessageChunk {
123 pub fn from_acp(
124 chunk: acp::AssistantMessageChunk,
125 language_registry: Arc<LanguageRegistry>,
126 cx: &mut App,
127 ) -> Self {
128 match chunk {
129 acp::AssistantMessageChunk::Text { text } => Self::Text {
130 chunk: cx.new(|cx| Markdown::new(text.into(), Some(language_registry), None, cx)),
131 },
132 acp::AssistantMessageChunk::Thought { thought } => Self::Thought {
133 chunk: cx
134 .new(|cx| Markdown::new(thought.into(), Some(language_registry), None, cx)),
135 },
136 }
137 }
138
139 pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
140 Self::Text {
141 chunk: cx.new(|cx| {
142 Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
143 }),
144 }
145 }
146
147 fn to_markdown(&self, cx: &App) -> String {
148 match self {
149 Self::Text { chunk } => chunk.read(cx).source().to_string(),
150 Self::Thought { chunk } => {
151 format!("<thinking>\n{}\n</thinking>", chunk.read(cx).source())
152 }
153 }
154 }
155}
156
157#[derive(Debug)]
158pub enum AgentThreadEntry {
159 UserMessage(UserMessage),
160 AssistantMessage(AssistantMessage),
161 ToolCall(ToolCall),
162}
163
164impl AgentThreadEntry {
165 fn to_markdown(&self, cx: &App) -> String {
166 match self {
167 Self::UserMessage(message) => message.to_markdown(cx),
168 Self::AssistantMessage(message) => message.to_markdown(cx),
169 Self::ToolCall(too_call) => too_call.to_markdown(cx),
170 }
171 }
172
173 pub fn diff(&self) -> Option<&Diff> {
174 if let AgentThreadEntry::ToolCall(ToolCall {
175 content: Some(ToolCallContent::Diff { diff }),
176 ..
177 }) = self
178 {
179 Some(&diff)
180 } else {
181 None
182 }
183 }
184
185 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
186 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
187 Some(locations)
188 } else {
189 None
190 }
191 }
192}
193
194#[derive(Debug)]
195pub struct ToolCall {
196 pub id: acp::ToolCallId,
197 pub label: Entity<Markdown>,
198 pub icon: IconName,
199 pub content: Option<ToolCallContent>,
200 pub status: ToolCallStatus,
201 pub locations: Vec<acp::ToolCallLocation>,
202}
203
204impl ToolCall {
205 fn to_markdown(&self, cx: &App) -> String {
206 let mut markdown = format!(
207 "**Tool Call: {}**\nStatus: {}\n\n",
208 self.label.read(cx).source(),
209 self.status
210 );
211 if let Some(content) = &self.content {
212 markdown.push_str(content.to_markdown(cx).as_str());
213 markdown.push_str("\n\n");
214 }
215 markdown
216 }
217}
218
219#[derive(Debug)]
220pub enum ToolCallStatus {
221 WaitingForConfirmation {
222 confirmation: ToolCallConfirmation,
223 respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
224 },
225 Allowed {
226 status: acp::ToolCallStatus,
227 },
228 Rejected,
229 Canceled,
230}
231
232impl Display for ToolCallStatus {
233 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
234 write!(
235 f,
236 "{}",
237 match self {
238 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
239 ToolCallStatus::Allowed { status } => match status {
240 acp::ToolCallStatus::Running => "Running",
241 acp::ToolCallStatus::Finished => "Finished",
242 acp::ToolCallStatus::Error => "Error",
243 },
244 ToolCallStatus::Rejected => "Rejected",
245 ToolCallStatus::Canceled => "Canceled",
246 }
247 )
248 }
249}
250
251#[derive(Debug)]
252pub enum ToolCallConfirmation {
253 Edit {
254 description: Option<Entity<Markdown>>,
255 },
256 Execute {
257 command: String,
258 root_command: String,
259 description: Option<Entity<Markdown>>,
260 },
261 Mcp {
262 server_name: String,
263 tool_name: String,
264 tool_display_name: String,
265 description: Option<Entity<Markdown>>,
266 },
267 Fetch {
268 urls: Vec<SharedString>,
269 description: Option<Entity<Markdown>>,
270 },
271 Other {
272 description: Entity<Markdown>,
273 },
274}
275
276impl ToolCallConfirmation {
277 pub fn from_acp(
278 confirmation: acp::ToolCallConfirmation,
279 language_registry: Arc<LanguageRegistry>,
280 cx: &mut App,
281 ) -> Self {
282 let to_md = |description: String, cx: &mut App| -> Entity<Markdown> {
283 cx.new(|cx| {
284 Markdown::new(
285 description.into(),
286 Some(language_registry.clone()),
287 None,
288 cx,
289 )
290 })
291 };
292
293 match confirmation {
294 acp::ToolCallConfirmation::Edit { description } => Self::Edit {
295 description: description.map(|description| to_md(description, cx)),
296 },
297 acp::ToolCallConfirmation::Execute {
298 command,
299 root_command,
300 description,
301 } => Self::Execute {
302 command,
303 root_command,
304 description: description.map(|description| to_md(description, cx)),
305 },
306 acp::ToolCallConfirmation::Mcp {
307 server_name,
308 tool_name,
309 tool_display_name,
310 description,
311 } => Self::Mcp {
312 server_name,
313 tool_name,
314 tool_display_name,
315 description: description.map(|description| to_md(description, cx)),
316 },
317 acp::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch {
318 urls: urls.iter().map(|url| url.into()).collect(),
319 description: description.map(|description| to_md(description, cx)),
320 },
321 acp::ToolCallConfirmation::Other { description } => Self::Other {
322 description: to_md(description, cx),
323 },
324 }
325 }
326}
327
328#[derive(Debug)]
329pub enum ToolCallContent {
330 Markdown { markdown: Entity<Markdown> },
331 Diff { diff: Diff },
332}
333
334impl ToolCallContent {
335 pub fn from_acp(
336 content: acp::ToolCallContent,
337 language_registry: Arc<LanguageRegistry>,
338 cx: &mut App,
339 ) -> Self {
340 match content {
341 acp::ToolCallContent::Markdown { markdown } => Self::Markdown {
342 markdown: cx.new(|cx| Markdown::new_text(markdown.into(), cx)),
343 },
344 acp::ToolCallContent::Diff { diff } => Self::Diff {
345 diff: Diff::from_acp(diff, language_registry, cx),
346 },
347 }
348 }
349
350 fn to_markdown(&self, cx: &App) -> String {
351 match self {
352 Self::Markdown { markdown } => markdown.read(cx).source().to_string(),
353 Self::Diff { diff } => diff.to_markdown(cx),
354 }
355 }
356}
357
/// A tool-call diff.
///
/// Holds the agent's new text in a read-only multibuffer, with the old text
/// attached as the diff base.
#[derive(Debug)]
pub struct Diff {
    /// Header-less, read-only multibuffer. Its excerpts are filled in by the
    /// task spawned in `from_acp`, so it starts out empty.
    pub multibuffer: Entity<MultiBuffer>,
    /// Path the agent reported for the file; used to pick a language and in
    /// `to_markdown`.
    pub path: PathBuf,
    /// Local buffer holding the new text.
    pub new_buffer: Entity<Buffer>,
    /// Local buffer holding the old text (empty if the agent sent none).
    pub old_buffer: Entity<Buffer>,
    // Handle to the setup task spawned in `from_acp`, held so it is not dropped.
    // NOTE(review): presumably dropping a gpui `Task` cancels it — confirm.
    _task: Task<Result<()>>,
}

impl Diff {
    /// Builds a [`Diff`] from an ACP diff.
    ///
    /// Synchronously, this creates the old/new buffers and starts computing a
    /// [`BufferDiff`] of the new text against the old text. A spawned task
    /// then:
    /// 1. waits for that computation;
    /// 2. adds excerpts around every hunk of the new buffer to the multibuffer,
    ///    using `editor::DEFAULT_MULTIBUFFER_CONTEXT` as context, and attaches
    ///    the diff;
    /// 3. sets the new buffer's language from `path`.
    ///
    /// A failed excerpt update or language lookup is logged, not propagated.
    pub fn from_acp(
        diff: acp::Diff,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut App,
    ) -> Self {
        let acp::Diff {
            path,
            old_text,
            new_text,
        } = diff;

        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));

        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
        // A missing `old_text` is diffed as an empty file (presumably a newly created file).
        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
        // The diff tracks the new buffer; the old text becomes its base text.
        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
        let diff_task = buffer_diff.update(cx, |diff, cx| {
            diff.set_base_text(
                old_buffer_snapshot,
                Some(language_registry.clone()),
                new_buffer_snapshot,
                cx,
            )
        });

        let task = cx.spawn({
            let multibuffer = multibuffer.clone();
            let path = path.clone();
            let new_buffer = new_buffer.clone();
            async move |cx| {
                // Hunks are only available once the base text has been diffed.
                diff_task.await?;

                multibuffer
                    .update(cx, |multibuffer, cx| {
                        // Point ranges in the new buffer covered by each diff hunk.
                        let hunk_ranges = {
                            let buffer = new_buffer.read(cx);
                            let diff = buffer_diff.read(cx);
                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
                                .collect::<Vec<_>>()
                        };

                        multibuffer.set_excerpts_for_path(
                            PathKey::for_buffer(&new_buffer, cx),
                            new_buffer.clone(),
                            hunk_ranges,
                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
                            cx,
                        );
                        multibuffer.add_diff(buffer_diff.clone(), cx);
                    })
                    .log_err();

                // Language detection happens last and only affects the new buffer.
                if let Some(language) = language_registry
                    .language_for_file_path(&path)
                    .await
                    .log_err()
                {
                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
                }

                anyhow::Ok(())
            }
        });

        Self {
            multibuffer,
            path,
            new_buffer,
            old_buffer,
            _task: task,
        }
    }

    /// Renders the diff as `Diff: <path>` followed by a fenced block holding
    /// the text of every buffer in the multibuffer, joined by newlines.
    ///
    /// NOTE(review): only `new_buffer` is ever added to the multibuffer in
    /// this file, so the block contains the new text, not a unified diff. It is
    /// empty until the setup task has populated the excerpts.
    fn to_markdown(&self, cx: &App) -> String {
        let buffer_text = self
            .multibuffer
            .read(cx)
            .all_buffers()
            .iter()
            .map(|buffer| buffer.read(cx).text())
            .join("\n");
        format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
    }
}
455
456#[derive(Debug, Default)]
457pub struct Plan {
458 pub entries: Vec<PlanEntry>,
459}
460
461#[derive(Debug)]
462pub struct PlanStats<'a> {
463 pub in_progress_entry: Option<&'a PlanEntry>,
464 pub pending: u32,
465 pub completed: u32,
466}
467
468impl Plan {
469 pub fn is_empty(&self) -> bool {
470 self.entries.is_empty()
471 }
472
473 pub fn stats(&self) -> PlanStats<'_> {
474 let mut stats = PlanStats {
475 in_progress_entry: None,
476 pending: 0,
477 completed: 0,
478 };
479
480 for entry in &self.entries {
481 match &entry.status {
482 acp::PlanEntryStatus::Pending => {
483 stats.pending += 1;
484 }
485 acp::PlanEntryStatus::InProgress => {
486 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
487 }
488 acp::PlanEntryStatus::Completed => {
489 stats.completed += 1;
490 }
491 }
492 }
493
494 stats
495 }
496}
497
498#[derive(Debug)]
499pub struct PlanEntry {
500 pub content: Entity<Markdown>,
501 pub priority: acp::PlanEntryPriority,
502 pub status: acp::PlanEntryStatus,
503}
504
505impl PlanEntry {
506 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
507 Self {
508 content: cx.new(|cx| Markdown::new_text(entry.content.into(), cx)),
509 priority: entry.priority,
510 status: entry.status,
511 }
512 }
513}
514
/// A conversation with an ACP agent.
///
/// Holds the transcript, the agent's current plan, and the state used to
/// service the agent's file and tool-call requests.
pub struct AcpThread {
    title: SharedString,
    // Transcript in order. Tool call ids are indices into this vector
    // (assigned in `insert_tool_call`).
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // Last snapshot of each buffer handed to the agent by `read_text_file`;
    // `write_text_file` diffs against it when present.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    // In-flight `send`; `Some` while the thread is generating.
    send_task: Option<Task<()>>,
    connection: Arc<dyn AgentConnection>,
    // Handed out once by `child_status()`.
    // NOTE(review): presumably tracks the agent child process — confirm.
    child_status: Option<Task<Result<()>>>,
}
526
/// Events emitted by [`AcpThread`] when its transcript changes.
pub enum AcpThreadEvent {
    /// An entry was appended to the end of the transcript.
    NewEntry,
    /// The entry at this index was modified in place.
    EntryUpdated(usize),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
533
/// Coarse state of an [`AcpThread`], as reported by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No send is in flight.
    Idle,
    /// A send is in flight and the current turn has a tool call awaiting confirmation.
    WaitingForToolConfirmation,
    /// A send is in flight and nothing is awaiting confirmation.
    Generating,
}
540
541#[derive(Debug, Clone)]
542pub enum LoadError {
543 Unsupported {
544 error_message: SharedString,
545 upgrade_message: SharedString,
546 upgrade_command: String,
547 },
548 Exited(i32),
549 Other(SharedString),
550}
551
552impl Display for LoadError {
553 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
554 match self {
555 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
556 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
557 LoadError::Other(msg) => write!(f, "{}", msg),
558 }
559 }
560}
561
562impl Error for LoadError {}
563
564impl AcpThread {
565 pub fn new(
566 connection: impl AgentConnection + 'static,
567 title: SharedString,
568 child_status: Option<Task<Result<()>>>,
569 project: Entity<Project>,
570 cx: &mut Context<Self>,
571 ) -> Self {
572 let action_log = cx.new(|_| ActionLog::new(project.clone()));
573
574 Self {
575 action_log,
576 shared_buffers: Default::default(),
577 entries: Default::default(),
578 plan: Default::default(),
579 title,
580 project,
581 send_task: None,
582 connection: Arc::new(connection),
583 child_status,
584 }
585 }
586
587 /// Send a request to the agent and wait for a response.
588 pub fn request<R: AgentRequest + 'static>(
589 &self,
590 params: R,
591 ) -> impl use<R> + Future<Output = Result<R::Response>> {
592 let params = params.into_any();
593 let result = self.connection.request_any(params);
594 async move {
595 let result = result.await?;
596 Ok(R::response_from_any(result)?)
597 }
598 }
599
600 pub fn action_log(&self) -> &Entity<ActionLog> {
601 &self.action_log
602 }
603
604 pub fn project(&self) -> &Entity<Project> {
605 &self.project
606 }
607
608 pub fn title(&self) -> SharedString {
609 self.title.clone()
610 }
611
612 pub fn entries(&self) -> &[AgentThreadEntry] {
613 &self.entries
614 }
615
616 pub fn status(&self) -> ThreadStatus {
617 if self.send_task.is_some() {
618 if self.waiting_for_tool_confirmation() {
619 ThreadStatus::WaitingForToolConfirmation
620 } else {
621 ThreadStatus::Generating
622 }
623 } else {
624 ThreadStatus::Idle
625 }
626 }
627
628 pub fn has_pending_edit_tool_calls(&self) -> bool {
629 for entry in self.entries.iter().rev() {
630 match entry {
631 AgentThreadEntry::UserMessage(_) => return false,
632 AgentThreadEntry::ToolCall(ToolCall {
633 status:
634 ToolCallStatus::Allowed {
635 status: acp::ToolCallStatus::Running,
636 ..
637 },
638 content: Some(ToolCallContent::Diff { .. }),
639 ..
640 }) => return true,
641 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
642 }
643 }
644
645 false
646 }
647
648 pub fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
649 self.entries.push(entry);
650 cx.emit(AcpThreadEvent::NewEntry);
651 }
652
653 pub fn push_assistant_chunk(
654 &mut self,
655 chunk: acp::AssistantMessageChunk,
656 cx: &mut Context<Self>,
657 ) {
658 let entries_len = self.entries.len();
659 if let Some(last_entry) = self.entries.last_mut()
660 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
661 {
662 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
663
664 match (chunks.last_mut(), &chunk) {
665 (
666 Some(AssistantMessageChunk::Text { chunk: old_chunk }),
667 acp::AssistantMessageChunk::Text { text: new_chunk },
668 )
669 | (
670 Some(AssistantMessageChunk::Thought { chunk: old_chunk }),
671 acp::AssistantMessageChunk::Thought { thought: new_chunk },
672 ) => {
673 old_chunk.update(cx, |old_chunk, cx| {
674 old_chunk.append(&new_chunk, cx);
675 });
676 }
677 _ => {
678 chunks.push(AssistantMessageChunk::from_acp(
679 chunk,
680 self.project.read(cx).languages().clone(),
681 cx,
682 ));
683 }
684 }
685 } else {
686 let chunk = AssistantMessageChunk::from_acp(
687 chunk,
688 self.project.read(cx).languages().clone(),
689 cx,
690 );
691
692 self.push_entry(
693 AgentThreadEntry::AssistantMessage(AssistantMessage {
694 chunks: vec![chunk],
695 }),
696 cx,
697 );
698 }
699 }
700
701 pub fn request_new_tool_call(
702 &mut self,
703 tool_call: acp::RequestToolCallConfirmationParams,
704 cx: &mut Context<Self>,
705 ) -> ToolCallRequest {
706 let (tx, rx) = oneshot::channel();
707
708 let status = ToolCallStatus::WaitingForConfirmation {
709 confirmation: ToolCallConfirmation::from_acp(
710 tool_call.confirmation,
711 self.project.read(cx).languages().clone(),
712 cx,
713 ),
714 respond_tx: tx,
715 };
716
717 let id = self.insert_tool_call(tool_call.tool_call, status, cx);
718 ToolCallRequest { id, outcome: rx }
719 }
720
721 pub fn request_tool_call_confirmation(
722 &mut self,
723 tool_call_id: ToolCallId,
724 confirmation: acp::ToolCallConfirmation,
725 cx: &mut Context<Self>,
726 ) -> Result<ToolCallRequest> {
727 let project = self.project.read(cx).languages().clone();
728 let Some((idx, call)) = self.tool_call_mut(tool_call_id) else {
729 anyhow::bail!("Tool call not found");
730 };
731
732 let (tx, rx) = oneshot::channel();
733
734 call.status = ToolCallStatus::WaitingForConfirmation {
735 confirmation: ToolCallConfirmation::from_acp(confirmation, project, cx),
736 respond_tx: tx,
737 };
738
739 cx.emit(AcpThreadEvent::EntryUpdated(idx));
740
741 Ok(ToolCallRequest {
742 id: tool_call_id,
743 outcome: rx,
744 })
745 }
746
747 pub fn push_tool_call(
748 &mut self,
749 request: acp::PushToolCallParams,
750 cx: &mut Context<Self>,
751 ) -> acp::ToolCallId {
752 let status = ToolCallStatus::Allowed {
753 status: acp::ToolCallStatus::Running,
754 };
755
756 self.insert_tool_call(request, status, cx)
757 }
758
759 fn insert_tool_call(
760 &mut self,
761 tool_call: acp::PushToolCallParams,
762 status: ToolCallStatus,
763 cx: &mut Context<Self>,
764 ) -> acp::ToolCallId {
765 let language_registry = self.project.read(cx).languages().clone();
766 let id = acp::ToolCallId(self.entries.len() as u64);
767 let call = ToolCall {
768 id,
769 label: cx.new(|cx| {
770 Markdown::new(
771 tool_call.label.into(),
772 Some(language_registry.clone()),
773 None,
774 cx,
775 )
776 }),
777 icon: acp_icon_to_ui_icon(tool_call.icon),
778 content: tool_call
779 .content
780 .map(|content| ToolCallContent::from_acp(content, language_registry, cx)),
781 locations: tool_call.locations,
782 status,
783 };
784
785 let location = call.locations.last().cloned();
786 if let Some(location) = location {
787 self.set_project_location(location, cx)
788 }
789
790 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
791
792 id
793 }
794
795 pub fn authorize_tool_call(
796 &mut self,
797 id: acp::ToolCallId,
798 outcome: acp::ToolCallConfirmationOutcome,
799 cx: &mut Context<Self>,
800 ) {
801 let Some((ix, call)) = self.tool_call_mut(id) else {
802 return;
803 };
804
805 let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
806 ToolCallStatus::Rejected
807 } else {
808 ToolCallStatus::Allowed {
809 status: acp::ToolCallStatus::Running,
810 }
811 };
812
813 let curr_status = mem::replace(&mut call.status, new_status);
814
815 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
816 respond_tx.send(outcome).log_err();
817 } else if cfg!(debug_assertions) {
818 panic!("tried to authorize an already authorized tool call");
819 }
820
821 cx.emit(AcpThreadEvent::EntryUpdated(ix));
822 }
823
824 pub fn update_tool_call(
825 &mut self,
826 id: acp::ToolCallId,
827 new_status: acp::ToolCallStatus,
828 new_content: Option<acp::ToolCallContent>,
829 cx: &mut Context<Self>,
830 ) -> Result<()> {
831 let language_registry = self.project.read(cx).languages().clone();
832 let (ix, call) = self.tool_call_mut(id).context("Entry not found")?;
833
834 if let Some(new_content) = new_content {
835 call.content = Some(ToolCallContent::from_acp(
836 new_content,
837 language_registry,
838 cx,
839 ));
840 }
841
842 match &mut call.status {
843 ToolCallStatus::Allowed { status } => {
844 *status = new_status;
845 }
846 ToolCallStatus::WaitingForConfirmation { .. } => {
847 anyhow::bail!("Tool call hasn't been authorized yet")
848 }
849 ToolCallStatus::Rejected => {
850 anyhow::bail!("Tool call was rejected and therefore can't be updated")
851 }
852 ToolCallStatus::Canceled => {
853 call.status = ToolCallStatus::Allowed { status: new_status };
854 }
855 }
856
857 let location = call.locations.last().cloned();
858 if let Some(location) = location {
859 self.set_project_location(location, cx)
860 }
861
862 cx.emit(AcpThreadEvent::EntryUpdated(ix));
863 Ok(())
864 }
865
866 fn tool_call_mut(&mut self, id: acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
867 let entry = self.entries.get_mut(id.0 as usize);
868 debug_assert!(
869 entry.is_some(),
870 "We shouldn't give out ids to entries that don't exist"
871 );
872 match entry {
873 Some(AgentThreadEntry::ToolCall(call)) if call.id == id => Some((id.0 as usize, call)),
874 _ => {
875 if cfg!(debug_assertions) {
876 panic!("entry is not a tool call");
877 }
878 None
879 }
880 }
881 }
882
883 pub fn plan(&self) -> &Plan {
884 &self.plan
885 }
886
887 pub fn update_plan(&mut self, request: acp::UpdatePlanParams, cx: &mut Context<Self>) {
888 self.plan = Plan {
889 entries: request
890 .entries
891 .into_iter()
892 .map(|entry| PlanEntry::from_acp(entry, cx))
893 .collect(),
894 };
895
896 cx.notify();
897 }
898
899 pub fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
900 self.plan
901 .entries
902 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
903 cx.notify();
904 }
905
906 pub fn set_project_location(&self, location: ToolCallLocation, cx: &mut Context<Self>) {
907 self.project.update(cx, |project, cx| {
908 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
909 return;
910 };
911 let buffer = project.open_buffer(path, cx);
912 cx.spawn(async move |project, cx| {
913 let buffer = buffer.await?;
914
915 project.update(cx, |project, cx| {
916 let position = if let Some(line) = location.line {
917 let snapshot = buffer.read(cx).snapshot();
918 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
919 snapshot.anchor_before(point)
920 } else {
921 Anchor::MIN
922 };
923
924 project.set_agent_location(
925 Some(AgentLocation {
926 buffer: buffer.downgrade(),
927 position,
928 }),
929 cx,
930 );
931 })
932 })
933 .detach_and_log_err(cx);
934 });
935 }
936
937 /// Returns true if the last turn is awaiting tool authorization
938 pub fn waiting_for_tool_confirmation(&self) -> bool {
939 for entry in self.entries.iter().rev() {
940 match &entry {
941 AgentThreadEntry::ToolCall(call) => match call.status {
942 ToolCallStatus::WaitingForConfirmation { .. } => return true,
943 ToolCallStatus::Allowed { .. }
944 | ToolCallStatus::Rejected
945 | ToolCallStatus::Canceled => continue,
946 },
947 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
948 // Reached the beginning of the turn
949 return false;
950 }
951 }
952 }
953 false
954 }
955
956 pub fn initialize(&self) -> impl use<> + Future<Output = Result<acp::InitializeResponse>> {
957 self.request(acp::InitializeParams {
958 protocol_version: ProtocolVersion::latest(),
959 })
960 }
961
962 pub fn authenticate(&self) -> impl use<> + Future<Output = Result<()>> {
963 self.request(acp::AuthenticateParams)
964 }
965
966 #[cfg(any(test, feature = "test-support"))]
967 pub fn send_raw(
968 &mut self,
969 message: &str,
970 cx: &mut Context<Self>,
971 ) -> BoxFuture<'static, Result<(), acp::Error>> {
972 self.send(
973 acp::SendUserMessageParams {
974 chunks: vec![acp::UserMessageChunk::Text {
975 text: message.to_string(),
976 }],
977 },
978 cx,
979 )
980 }
981
982 pub fn send(
983 &mut self,
984 message: acp::SendUserMessageParams,
985 cx: &mut Context<Self>,
986 ) -> BoxFuture<'static, Result<(), acp::Error>> {
987 self.push_entry(
988 AgentThreadEntry::UserMessage(UserMessage::from_acp(
989 &message,
990 self.project.read(cx).languages().clone(),
991 cx,
992 )),
993 cx,
994 );
995
996 let (tx, rx) = oneshot::channel();
997 let cancel = self.cancel(cx);
998
999 self.send_task = Some(cx.spawn(async move |this, cx| {
1000 async {
1001 cancel.await.log_err();
1002
1003 let result = this.update(cx, |this, _| this.request(message))?.await;
1004 tx.send(result).log_err();
1005 this.update(cx, |this, _cx| this.send_task.take())?;
1006 anyhow::Ok(())
1007 }
1008 .await
1009 .log_err();
1010 }));
1011
1012 async move {
1013 match rx.await {
1014 Ok(Err(e)) => Err(e)?,
1015 _ => Ok(()),
1016 }
1017 }
1018 .boxed()
1019 }
1020
1021 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<(), acp::Error>> {
1022 if self.send_task.take().is_some() {
1023 let request = self.request(acp::CancelSendMessageParams);
1024 cx.spawn(async move |this, cx| {
1025 request.await?;
1026 this.update(cx, |this, _cx| {
1027 for entry in this.entries.iter_mut() {
1028 if let AgentThreadEntry::ToolCall(call) = entry {
1029 let cancel = matches!(
1030 call.status,
1031 ToolCallStatus::WaitingForConfirmation { .. }
1032 | ToolCallStatus::Allowed {
1033 status: acp::ToolCallStatus::Running
1034 }
1035 );
1036
1037 if cancel {
1038 let curr_status =
1039 mem::replace(&mut call.status, ToolCallStatus::Canceled);
1040
1041 if let ToolCallStatus::WaitingForConfirmation {
1042 respond_tx, ..
1043 } = curr_status
1044 {
1045 respond_tx
1046 .send(acp::ToolCallConfirmationOutcome::Cancel)
1047 .ok();
1048 }
1049 }
1050 }
1051 }
1052 })?;
1053 Ok(())
1054 })
1055 } else {
1056 Task::ready(Ok(()))
1057 }
1058 }
1059
1060 pub fn read_text_file(
1061 &self,
1062 request: acp::ReadTextFileParams,
1063 reuse_shared_snapshot: bool,
1064 cx: &mut Context<Self>,
1065 ) -> Task<Result<String>> {
1066 let project = self.project.clone();
1067 let action_log = self.action_log.clone();
1068 cx.spawn(async move |this, cx| {
1069 let load = project.update(cx, |project, cx| {
1070 let path = project
1071 .project_path_for_absolute_path(&request.path, cx)
1072 .context("invalid path")?;
1073 anyhow::Ok(project.open_buffer(path, cx))
1074 });
1075 let buffer = load??.await?;
1076
1077 let snapshot = if reuse_shared_snapshot {
1078 this.read_with(cx, |this, _| {
1079 this.shared_buffers.get(&buffer.clone()).cloned()
1080 })
1081 .log_err()
1082 .flatten()
1083 } else {
1084 None
1085 };
1086
1087 let snapshot = if let Some(snapshot) = snapshot {
1088 snapshot
1089 } else {
1090 action_log.update(cx, |action_log, cx| {
1091 action_log.buffer_read(buffer.clone(), cx);
1092 })?;
1093 project.update(cx, |project, cx| {
1094 let position = buffer
1095 .read(cx)
1096 .snapshot()
1097 .anchor_before(Point::new(request.line.unwrap_or_default(), 0));
1098 project.set_agent_location(
1099 Some(AgentLocation {
1100 buffer: buffer.downgrade(),
1101 position,
1102 }),
1103 cx,
1104 );
1105 })?;
1106
1107 buffer.update(cx, |buffer, _| buffer.snapshot())?
1108 };
1109
1110 this.update(cx, |this, _| {
1111 let text = snapshot.text();
1112 this.shared_buffers.insert(buffer.clone(), snapshot);
1113 if request.line.is_none() && request.limit.is_none() {
1114 return Ok(text);
1115 }
1116 let limit = request.limit.unwrap_or(u32::MAX) as usize;
1117 let Some(line) = request.line else {
1118 return Ok(text.lines().take(limit).collect::<String>());
1119 };
1120
1121 let count = text.lines().count();
1122 if count < line as usize {
1123 anyhow::bail!("There are only {} lines", count);
1124 }
1125 Ok(text
1126 .lines()
1127 .skip(line as usize + 1)
1128 .take(limit)
1129 .collect::<String>())
1130 })?
1131 })
1132 }
1133
1134 pub fn write_text_file(
1135 &self,
1136 path: PathBuf,
1137 content: String,
1138 cx: &mut Context<Self>,
1139 ) -> Task<Result<()>> {
1140 let project = self.project.clone();
1141 let action_log = self.action_log.clone();
1142 cx.spawn(async move |this, cx| {
1143 let load = project.update(cx, |project, cx| {
1144 let path = project
1145 .project_path_for_absolute_path(&path, cx)
1146 .context("invalid path")?;
1147 anyhow::Ok(project.open_buffer(path, cx))
1148 });
1149 let buffer = load??.await?;
1150 let snapshot = this.update(cx, |this, cx| {
1151 this.shared_buffers
1152 .get(&buffer)
1153 .cloned()
1154 .unwrap_or_else(|| buffer.read(cx).snapshot())
1155 })?;
1156 let edits = cx
1157 .background_executor()
1158 .spawn(async move {
1159 let old_text = snapshot.text();
1160 text_diff(old_text.as_str(), &content)
1161 .into_iter()
1162 .map(|(range, replacement)| {
1163 (
1164 snapshot.anchor_after(range.start)
1165 ..snapshot.anchor_before(range.end),
1166 replacement,
1167 )
1168 })
1169 .collect::<Vec<_>>()
1170 })
1171 .await;
1172 cx.update(|cx| {
1173 project.update(cx, |project, cx| {
1174 project.set_agent_location(
1175 Some(AgentLocation {
1176 buffer: buffer.downgrade(),
1177 position: edits
1178 .last()
1179 .map(|(range, _)| range.end)
1180 .unwrap_or(Anchor::MIN),
1181 }),
1182 cx,
1183 );
1184 });
1185
1186 action_log.update(cx, |action_log, cx| {
1187 action_log.buffer_read(buffer.clone(), cx);
1188 });
1189 buffer.update(cx, |buffer, cx| {
1190 buffer.edit(edits, None, cx);
1191 });
1192 action_log.update(cx, |action_log, cx| {
1193 action_log.buffer_edited(buffer.clone(), cx);
1194 });
1195 })?;
1196 project
1197 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1198 .await
1199 })
1200 }
1201
1202 pub fn child_status(&mut self) -> Option<Task<Result<()>>> {
1203 self.child_status.take()
1204 }
1205
1206 pub fn to_markdown(&self, cx: &App) -> String {
1207 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1208 }
1209}
1210
1211#[derive(Clone)]
1212pub struct AcpClientDelegate {
1213 thread: WeakEntity<AcpThread>,
1214 cx: AsyncApp,
1215 // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
1216}
1217
1218impl AcpClientDelegate {
1219 pub fn new(thread: WeakEntity<AcpThread>, cx: AsyncApp) -> Self {
1220 Self { thread, cx }
1221 }
1222
1223 pub async fn clear_completed_plan_entries(&self) -> Result<()> {
1224 let cx = &mut self.cx.clone();
1225 cx.update(|cx| {
1226 self.thread
1227 .update(cx, |thread, cx| thread.clear_completed_plan_entries(cx))
1228 })?
1229 .context("Failed to update thread")?;
1230
1231 Ok(())
1232 }
1233
1234 pub async fn request_existing_tool_call_confirmation(
1235 &self,
1236 tool_call_id: ToolCallId,
1237 confirmation: acp::ToolCallConfirmation,
1238 ) -> Result<ToolCallConfirmationOutcome> {
1239 let cx = &mut self.cx.clone();
1240 let ToolCallRequest { outcome, .. } = cx
1241 .update(|cx| {
1242 self.thread.update(cx, |thread, cx| {
1243 thread.request_tool_call_confirmation(tool_call_id, confirmation, cx)
1244 })
1245 })?
1246 .context("Failed to update thread")??;
1247
1248 Ok(outcome.await?)
1249 }
1250
1251 pub async fn read_text_file_reusing_snapshot(
1252 &self,
1253 request: acp::ReadTextFileParams,
1254 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
1255 let content = self
1256 .cx
1257 .update(|cx| {
1258 self.thread
1259 .update(cx, |thread, cx| thread.read_text_file(request, true, cx))
1260 })?
1261 .context("Failed to update thread")?
1262 .await?;
1263 Ok(acp::ReadTextFileResponse { content })
1264 }
1265}
1266
1267impl acp::Client for AcpClientDelegate {
1268 async fn stream_assistant_message_chunk(
1269 &self,
1270 params: acp::StreamAssistantMessageChunkParams,
1271 ) -> Result<(), acp::Error> {
1272 let cx = &mut self.cx.clone();
1273
1274 cx.update(|cx| {
1275 self.thread
1276 .update(cx, |thread, cx| {
1277 thread.push_assistant_chunk(params.chunk, cx)
1278 })
1279 .ok();
1280 })?;
1281
1282 Ok(())
1283 }
1284
1285 async fn request_tool_call_confirmation(
1286 &self,
1287 request: acp::RequestToolCallConfirmationParams,
1288 ) -> Result<acp::RequestToolCallConfirmationResponse, acp::Error> {
1289 let cx = &mut self.cx.clone();
1290 let ToolCallRequest { id, outcome } = cx
1291 .update(|cx| {
1292 self.thread
1293 .update(cx, |thread, cx| thread.request_new_tool_call(request, cx))
1294 })?
1295 .context("Failed to update thread")?;
1296
1297 Ok(acp::RequestToolCallConfirmationResponse {
1298 id,
1299 outcome: outcome.await.map_err(acp::Error::into_internal_error)?,
1300 })
1301 }
1302
1303 async fn push_tool_call(
1304 &self,
1305 request: acp::PushToolCallParams,
1306 ) -> Result<acp::PushToolCallResponse, acp::Error> {
1307 let cx = &mut self.cx.clone();
1308 let id = cx
1309 .update(|cx| {
1310 self.thread
1311 .update(cx, |thread, cx| thread.push_tool_call(request, cx))
1312 })?
1313 .context("Failed to update thread")?;
1314
1315 Ok(acp::PushToolCallResponse { id })
1316 }
1317
1318 async fn update_tool_call(&self, request: acp::UpdateToolCallParams) -> Result<(), acp::Error> {
1319 let cx = &mut self.cx.clone();
1320
1321 cx.update(|cx| {
1322 self.thread.update(cx, |thread, cx| {
1323 thread.update_tool_call(request.tool_call_id, request.status, request.content, cx)
1324 })
1325 })?
1326 .context("Failed to update thread")??;
1327
1328 Ok(())
1329 }
1330
1331 async fn update_plan(&self, request: acp::UpdatePlanParams) -> Result<(), acp::Error> {
1332 let cx = &mut self.cx.clone();
1333
1334 cx.update(|cx| {
1335 self.thread
1336 .update(cx, |thread, cx| thread.update_plan(request, cx))
1337 })?
1338 .context("Failed to update thread")?;
1339
1340 Ok(())
1341 }
1342
1343 async fn read_text_file(
1344 &self,
1345 request: acp::ReadTextFileParams,
1346 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
1347 let content = self
1348 .cx
1349 .update(|cx| {
1350 self.thread
1351 .update(cx, |thread, cx| thread.read_text_file(request, false, cx))
1352 })?
1353 .context("Failed to update thread")?
1354 .await?;
1355 Ok(acp::ReadTextFileResponse { content })
1356 }
1357
1358 async fn write_text_file(&self, request: acp::WriteTextFileParams) -> Result<(), acp::Error> {
1359 self.cx
1360 .update(|cx| {
1361 self.thread.update(cx, |thread, cx| {
1362 thread.write_text_file(request.path, request.content, cx)
1363 })
1364 })?
1365 .context("Failed to update thread")?
1366 .await?;
1367
1368 Ok(())
1369 }
1370}
1371
1372fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
1373 match icon {
1374 acp::Icon::FileSearch => IconName::ToolSearch,
1375 acp::Icon::Folder => IconName::ToolFolder,
1376 acp::Icon::Globe => IconName::ToolWeb,
1377 acp::Icon::Hammer => IconName::ToolHammer,
1378 acp::Icon::LightBulb => IconName::ToolBulb,
1379 acp::Icon::Pencil => IconName::ToolPencil,
1380 acp::Icon::Regex => IconName::ToolRegex,
1381 acp::Icon::Terminal => IconName::ToolTerminal,
1382 }
1383}
1384
/// A tool call registered in the thread that is waiting for the user to
/// confirm it.
pub struct ToolCallRequest {
    /// Id of the tool call within the thread.
    pub id: acp::ToolCallId,
    /// Resolves with the user's decision. It resolves to `Canceled` if the
    /// sending side is dropped before a decision is sent.
    pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
}
1389
1390#[cfg(test)]
1391mod tests {
1392 use super::*;
1393 use anyhow::anyhow;
1394 use async_pipe::{PipeReader, PipeWriter};
1395 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1396 use gpui::{AsyncApp, TestAppContext};
1397 use indoc::indoc;
1398 use project::FakeFs;
1399 use serde_json::json;
1400 use settings::SettingsStore;
1401 use smol::{future::BoxedLocal, stream::StreamExt as _};
1402 use std::{cell::RefCell, rc::Rc, time::Duration};
1403 use util::path;
1404
1405 fn init_test(cx: &mut TestAppContext) {
1406 env_logger::try_init().ok();
1407 cx.update(|cx| {
1408 let settings_store = SettingsStore::test(cx);
1409 cx.set_global(settings_store);
1410 Project::init_settings(cx);
1411 language::init(cx);
1412 });
1413 }
1414
    // Consecutive `Thought` chunks streamed by the agent must end up in a single
    // `<thinking>` section of the assistant message in the thread's markdown.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project, cx);

        // On receiving the user message, the fake agent streams two thought
        // chunks, awaiting each request before sending the next one.
        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| async move {
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp::StreamAssistantMessageChunkParams {
                            chunk: acp::AssistantMessageChunk::Thought {
                                thought: "Thinking ".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp::StreamAssistantMessageChunkParams {
                            chunk: acp::AssistantMessageChunk::Thought {
                                thought: "hard!".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();

                Ok(())
            })
        });

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        // Both chunks must show up as one concatenated thought.
        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
1472
    // The user edits a buffer after the agent has read the file but while the
    // agent's write is still in flight. The agent's write must be applied on top
    // of the user's edit instead of overwriting it, and the merged text must be
    // saved to disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project.clone(), cx);
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        // Fired by the fake agent once its read has returned, so the test can
        // make the user's edit after that read.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));

        // The fake agent reads the file, signals the test, then writes the file
        // back with two extra lines appended.
        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp::ReadTextFileParams {
                                path: path!("/tmp/foo").into(),
                                line: None,
                                limit: None,
                            })
                        })?
                        .await
                        .unwrap();
                    assert_eq!(content.content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp::WriteTextFileParams {
                                path: path!("/tmp/foo").into(),
                                content: "one\ntwo\nthree\nfour\nfive\n".to_string(),
                            })
                        })?
                        .await
                        .unwrap();
                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // The user's edit: prepend a line the agent never saw.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        // Both the user's edit and the agent's appended lines must survive, in
        // the buffer and on disk.
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1546
    // A running tool call that the user cancels can still be reported as finished
    // by the agent afterwards. The late `Finished` update moves the entry from
    // `Canceled` back to `Allowed { status: Finished, .. }`.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project, cx);

        // Dropping `end_turn_tx` releases the fake agent so it can finish its turn.
        let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>();

        // Holds the id of the pushed tool call so the test can update it later.
        let tool_call_id = Rc::new(RefCell::new(None));
        let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx)));
        fake_server.update(cx, |fake_server, _| {
            let tool_call_id = tool_call_id.clone();
            fake_server.on_user_message(move |_, server, mut cx| {
                let end_turn_rx = end_turn_rx.clone();
                let tool_call_id = tool_call_id.clone();
                async move {
                    let tool_call_result = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp::PushToolCallParams {
                                label: "Fetch".to_string(),
                                icon: acp::Icon::Globe,
                                content: None,
                                locations: vec![],
                            })
                        })?
                        .await
                        .unwrap();
                    *tool_call_id.clone().borrow_mut() = Some(tool_call_result.id);
                    // Keep the turn open until the test drops `end_turn_tx`.
                    end_turn_rx.take().unwrap().await.ok();

                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // A pushed tool call starts out allowed and running.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Running,
                        ..
                    },
                    ..
                })
            ));
        });

        cx.run_until_parked();

        thread
            .update(cx, |thread, cx| thread.cancel(cx))
            .await
            .unwrap();

        // Cancelling marks the running tool call as canceled.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent reports the same tool call as finished after the cancellation.
        fake_server
            .update(cx, |fake_server, _| {
                fake_server.send_to_zed(acp::UpdateToolCallParams {
                    tool_call_id: tool_call_id.borrow().unwrap(),
                    status: acp::ToolCallStatus::Finished,
                    content: None,
                })
            })
            .await
            .unwrap();

        drop(end_turn_tx);
        request.await.unwrap();

        // The late update wins: the entry is allowed and finished again.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Finished,
                        ..
                    },
                    ..
                })
            ));
        });
    }
1647
1648 async fn run_until_first_tool_call(
1649 thread: &Entity<AcpThread>,
1650 cx: &mut TestAppContext,
1651 ) -> usize {
1652 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1653
1654 let subscription = cx.update(|cx| {
1655 cx.subscribe(thread, move |thread, _, cx| {
1656 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1657 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1658 return tx.try_send(ix).unwrap();
1659 }
1660 }
1661 })
1662 });
1663
1664 select! {
1665 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1666 panic!("Timeout waiting for tool call")
1667 }
1668 ix = rx.next().fuse() => {
1669 drop(subscription);
1670 ix.unwrap()
1671 }
1672 }
1673 }
1674
1675 pub fn fake_acp_thread(
1676 project: Entity<Project>,
1677 cx: &mut TestAppContext,
1678 ) -> (Entity<AcpThread>, Entity<FakeAcpServer>) {
1679 let (stdin_tx, stdin_rx) = async_pipe::pipe();
1680 let (stdout_tx, stdout_rx) = async_pipe::pipe();
1681
1682 let thread = cx.new(|cx| {
1683 let foreground_executor = cx.foreground_executor().clone();
1684 let (connection, io_fut) = acp::AgentConnection::connect_to_agent(
1685 AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
1686 stdin_tx,
1687 stdout_rx,
1688 move |fut| {
1689 foreground_executor.spawn(fut).detach();
1690 },
1691 );
1692
1693 let io_task = cx.background_spawn({
1694 async move {
1695 io_fut.await.log_err();
1696 Ok(())
1697 }
1698 });
1699 AcpThread::new(connection, "Test".into(), Some(io_task), project, cx)
1700 });
1701 let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
1702 (thread, agent)
1703 }
1704
1705 pub struct FakeAcpServer {
1706 connection: acp::ClientConnection,
1707
1708 _io_task: Task<()>,
1709 on_user_message: Option<
1710 Rc<
1711 dyn Fn(
1712 acp::SendUserMessageParams,
1713 Entity<FakeAcpServer>,
1714 AsyncApp,
1715 ) -> LocalBoxFuture<'static, Result<(), acp::Error>>,
1716 >,
1717 >,
1718 }
1719
    /// `acp::Agent` implementation backing [`FakeAcpServer`]. It hands each user
    /// message to the handler registered on the server.
    #[derive(Clone)]
    struct FakeAgent {
        // Strong handle to the server whose `on_user_message` handler is invoked.
        // NOTE(review): the server stores the connection built from this agent,
        // so this may form a reference cycle; harmless in tests, but confirm.
        server: Entity<FakeAcpServer>,
        // Async app context used to read the server and passed on to the handler.
        cx: AsyncApp,
    }
1725
1726 impl acp::Agent for FakeAgent {
1727 async fn initialize(
1728 &self,
1729 params: acp::InitializeParams,
1730 ) -> Result<acp::InitializeResponse, acp::Error> {
1731 Ok(acp::InitializeResponse {
1732 protocol_version: params.protocol_version,
1733 is_authenticated: true,
1734 })
1735 }
1736
1737 async fn authenticate(&self) -> Result<(), acp::Error> {
1738 Ok(())
1739 }
1740
1741 async fn cancel_send_message(&self) -> Result<(), acp::Error> {
1742 Ok(())
1743 }
1744
1745 async fn send_user_message(
1746 &self,
1747 request: acp::SendUserMessageParams,
1748 ) -> Result<(), acp::Error> {
1749 let mut cx = self.cx.clone();
1750 let handler = self
1751 .server
1752 .update(&mut cx, |server, _| server.on_user_message.clone())
1753 .ok()
1754 .flatten();
1755 if let Some(handler) = handler {
1756 handler(request, self.server.clone(), self.cx.clone()).await
1757 } else {
1758 Err(anyhow::anyhow!("No handler for on_user_message").into())
1759 }
1760 }
1761 }
1762
1763 impl FakeAcpServer {
1764 fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
1765 let agent = FakeAgent {
1766 server: cx.entity(),
1767 cx: cx.to_async(),
1768 };
1769 let foreground_executor = cx.foreground_executor().clone();
1770
1771 let (connection, io_fut) = acp::ClientConnection::connect_to_client(
1772 agent.clone(),
1773 stdout,
1774 stdin,
1775 move |fut| {
1776 foreground_executor.spawn(fut).detach();
1777 },
1778 );
1779 FakeAcpServer {
1780 connection: connection,
1781 on_user_message: None,
1782 _io_task: cx.background_spawn(async move {
1783 io_fut.await.log_err();
1784 }),
1785 }
1786 }
1787
1788 fn on_user_message<F>(
1789 &mut self,
1790 handler: impl for<'a> Fn(acp::SendUserMessageParams, Entity<FakeAcpServer>, AsyncApp) -> F
1791 + 'static,
1792 ) where
1793 F: Future<Output = Result<(), acp::Error>> + 'static,
1794 {
1795 self.on_user_message
1796 .replace(Rc::new(move |request, server, cx| {
1797 handler(request, server, cx).boxed_local()
1798 }));
1799 }
1800
1801 fn send_to_zed<T: acp::ClientRequest + 'static>(
1802 &self,
1803 message: T,
1804 ) -> BoxedLocal<Result<T::Response>> {
1805 self.connection
1806 .request(message)
1807 .map(|f| f.map_err(|err| anyhow!(err)))
1808 .boxed_local()
1809 }
1810 }
1811}