1mod connection;
2pub use connection::*;
3
4pub use acp::ToolCallId;
5use agentic_coding_protocol::{
6 self as acp, AgentRequest, ProtocolVersion, ToolCallConfirmationOutcome, ToolCallLocation,
7 UserMessageChunk,
8};
9use anyhow::{Context as _, Result};
10use assistant_tool::ActionLog;
11use buffer_diff::BufferDiff;
12use editor::{Bias, MultiBuffer, PathKey};
13use futures::{FutureExt, channel::oneshot, future::BoxFuture};
14use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
15use itertools::Itertools;
16use language::{
17 Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
18 text_diff,
19};
20use markdown::Markdown;
21use project::{AgentLocation, Project};
22use std::collections::HashMap;
23use std::error::Error;
24use std::fmt::{Formatter, Write};
25use std::{
26 fmt::Display,
27 mem,
28 path::{Path, PathBuf},
29 sync::Arc,
30};
31use ui::{App, IconName};
32use util::ResultExt;
33
34#[derive(Clone, Debug, Eq, PartialEq)]
35pub struct UserMessage {
36 pub content: Entity<Markdown>,
37}
38
39impl UserMessage {
40 pub fn from_acp(
41 message: &acp::SendUserMessageParams,
42 language_registry: Arc<LanguageRegistry>,
43 cx: &mut App,
44 ) -> Self {
45 let mut md_source = String::new();
46
47 for chunk in &message.chunks {
48 match chunk {
49 UserMessageChunk::Text { text } => md_source.push_str(&text),
50 UserMessageChunk::Path { path } => {
51 write!(&mut md_source, "{}", MentionPath(&path)).unwrap()
52 }
53 }
54 }
55
56 Self {
57 content: cx
58 .new(|cx| Markdown::new(md_source.into(), Some(language_registry), None, cx)),
59 }
60 }
61
62 fn to_markdown(&self, cx: &App) -> String {
63 format!("## User\n\n{}\n\n", self.content.read(cx).source())
64 }
65}
66
/// A file path mentioned inside message markdown.
///
/// `Display` renders it as `[@<file name>](@file:<full path>)`. The link
/// target can be turned back into a path with [`MentionPath::try_parse`].
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// Marker that prefixes the link target of a mention.
    const PREFIX: &'static str = "@file:";

    /// Wraps `path` as a mention.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a mention link target.
    ///
    /// Returns `None` when `url` does not start with the mention prefix.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self::new(Path::new(rest)))
    }

    /// The mentioned path.
    pub fn path(&self) -> &Path {
        self.0
    }
}

impl Display for MentionPath<'_> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let path = self.0;
        // Paths without a final component (e.g. `/`) get an empty label.
        let label = path.file_name().unwrap_or_default();
        write!(f, "[@{}]({}{})", label.display(), Self::PREFIX, path.display())
    }
}
98
99#[derive(Clone, Debug, Eq, PartialEq)]
100pub struct AssistantMessage {
101 pub chunks: Vec<AssistantMessageChunk>,
102}
103
104impl AssistantMessage {
105 pub fn to_markdown(&self, cx: &App) -> String {
106 format!(
107 "## Assistant\n\n{}\n\n",
108 self.chunks
109 .iter()
110 .map(|chunk| chunk.to_markdown(cx))
111 .join("\n\n")
112 )
113 }
114}
115
116#[derive(Clone, Debug, Eq, PartialEq)]
117pub enum AssistantMessageChunk {
118 Text { chunk: Entity<Markdown> },
119 Thought { chunk: Entity<Markdown> },
120}
121
122impl AssistantMessageChunk {
123 pub fn from_acp(
124 chunk: acp::AssistantMessageChunk,
125 language_registry: Arc<LanguageRegistry>,
126 cx: &mut App,
127 ) -> Self {
128 match chunk {
129 acp::AssistantMessageChunk::Text { text } => Self::Text {
130 chunk: cx.new(|cx| Markdown::new(text.into(), Some(language_registry), None, cx)),
131 },
132 acp::AssistantMessageChunk::Thought { thought } => Self::Thought {
133 chunk: cx
134 .new(|cx| Markdown::new(thought.into(), Some(language_registry), None, cx)),
135 },
136 }
137 }
138
139 pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
140 Self::Text {
141 chunk: cx.new(|cx| {
142 Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
143 }),
144 }
145 }
146
147 fn to_markdown(&self, cx: &App) -> String {
148 match self {
149 Self::Text { chunk } => chunk.read(cx).source().to_string(),
150 Self::Thought { chunk } => {
151 format!("<thinking>\n{}\n</thinking>", chunk.read(cx).source())
152 }
153 }
154 }
155}
156
157#[derive(Debug)]
158pub enum AgentThreadEntry {
159 UserMessage(UserMessage),
160 AssistantMessage(AssistantMessage),
161 ToolCall(ToolCall),
162}
163
164impl AgentThreadEntry {
165 fn to_markdown(&self, cx: &App) -> String {
166 match self {
167 Self::UserMessage(message) => message.to_markdown(cx),
168 Self::AssistantMessage(message) => message.to_markdown(cx),
169 Self::ToolCall(too_call) => too_call.to_markdown(cx),
170 }
171 }
172
173 pub fn diff(&self) -> Option<&Diff> {
174 if let AgentThreadEntry::ToolCall(ToolCall {
175 content: Some(ToolCallContent::Diff { diff }),
176 ..
177 }) = self
178 {
179 Some(&diff)
180 } else {
181 None
182 }
183 }
184
185 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
186 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
187 Some(locations)
188 } else {
189 None
190 }
191 }
192}
193
194#[derive(Debug)]
195pub struct ToolCall {
196 pub id: acp::ToolCallId,
197 pub label: Entity<Markdown>,
198 pub icon: IconName,
199 pub content: Option<ToolCallContent>,
200 pub status: ToolCallStatus,
201 pub locations: Vec<acp::ToolCallLocation>,
202}
203
204impl ToolCall {
205 fn to_markdown(&self, cx: &App) -> String {
206 let mut markdown = format!(
207 "**Tool Call: {}**\nStatus: {}\n\n",
208 self.label.read(cx).source(),
209 self.status
210 );
211 if let Some(content) = &self.content {
212 markdown.push_str(content.to_markdown(cx).as_str());
213 markdown.push_str("\n\n");
214 }
215 markdown
216 }
217}
218
219#[derive(Debug)]
220pub enum ToolCallStatus {
221 WaitingForConfirmation {
222 confirmation: ToolCallConfirmation,
223 respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
224 },
225 Allowed {
226 status: acp::ToolCallStatus,
227 },
228 Rejected,
229 Canceled,
230}
231
232impl Display for ToolCallStatus {
233 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
234 write!(
235 f,
236 "{}",
237 match self {
238 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
239 ToolCallStatus::Allowed { status } => match status {
240 acp::ToolCallStatus::Running => "Running",
241 acp::ToolCallStatus::Finished => "Finished",
242 acp::ToolCallStatus::Error => "Error",
243 },
244 ToolCallStatus::Rejected => "Rejected",
245 ToolCallStatus::Canceled => "Canceled",
246 }
247 )
248 }
249}
250
251#[derive(Debug)]
252pub enum ToolCallConfirmation {
253 Edit {
254 description: Option<Entity<Markdown>>,
255 },
256 Execute {
257 command: String,
258 root_command: String,
259 description: Option<Entity<Markdown>>,
260 },
261 Mcp {
262 server_name: String,
263 tool_name: String,
264 tool_display_name: String,
265 description: Option<Entity<Markdown>>,
266 },
267 Fetch {
268 urls: Vec<SharedString>,
269 description: Option<Entity<Markdown>>,
270 },
271 Other {
272 description: Entity<Markdown>,
273 },
274}
275
276impl ToolCallConfirmation {
277 pub fn from_acp(
278 confirmation: acp::ToolCallConfirmation,
279 language_registry: Arc<LanguageRegistry>,
280 cx: &mut App,
281 ) -> Self {
282 let to_md = |description: String, cx: &mut App| -> Entity<Markdown> {
283 cx.new(|cx| {
284 Markdown::new(
285 description.into(),
286 Some(language_registry.clone()),
287 None,
288 cx,
289 )
290 })
291 };
292
293 match confirmation {
294 acp::ToolCallConfirmation::Edit { description } => Self::Edit {
295 description: description.map(|description| to_md(description, cx)),
296 },
297 acp::ToolCallConfirmation::Execute {
298 command,
299 root_command,
300 description,
301 } => Self::Execute {
302 command,
303 root_command,
304 description: description.map(|description| to_md(description, cx)),
305 },
306 acp::ToolCallConfirmation::Mcp {
307 server_name,
308 tool_name,
309 tool_display_name,
310 description,
311 } => Self::Mcp {
312 server_name,
313 tool_name,
314 tool_display_name,
315 description: description.map(|description| to_md(description, cx)),
316 },
317 acp::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch {
318 urls: urls.iter().map(|url| url.into()).collect(),
319 description: description.map(|description| to_md(description, cx)),
320 },
321 acp::ToolCallConfirmation::Other { description } => Self::Other {
322 description: to_md(description, cx),
323 },
324 }
325 }
326}
327
328#[derive(Debug)]
329pub enum ToolCallContent {
330 Markdown { markdown: Entity<Markdown> },
331 Diff { diff: Diff },
332}
333
334impl ToolCallContent {
335 pub fn from_acp(
336 content: acp::ToolCallContent,
337 language_registry: Arc<LanguageRegistry>,
338 cx: &mut App,
339 ) -> Self {
340 match content {
341 acp::ToolCallContent::Markdown { markdown } => Self::Markdown {
342 markdown: cx.new(|cx| Markdown::new_text(markdown.into(), cx)),
343 },
344 acp::ToolCallContent::Diff { diff } => Self::Diff {
345 diff: Diff::from_acp(diff, language_registry, cx),
346 },
347 }
348 }
349
350 fn to_markdown(&self, cx: &App) -> String {
351 match self {
352 Self::Markdown { markdown } => markdown.read(cx).source().to_string(),
353 Self::Diff { diff } => diff.to_markdown(cx),
354 }
355 }
356}
357
/// A file diff attached to a tool call, displayed through an excerpted multibuffer.
#[derive(Debug)]
pub struct Diff {
    /// Read-only multibuffer that receives excerpts of `new_buffer` around the changed hunks.
    pub multibuffer: Entity<MultiBuffer>,
    /// Path reported by the agent for the changed file.
    pub path: PathBuf,
    /// Local buffer holding the new text.
    pub new_buffer: Entity<Buffer>,
    /// Local buffer holding the old text (empty when the agent sent none).
    pub old_buffer: Entity<Buffer>,
    /// Task that finishes the diff and fills `multibuffer`; kept for the `Diff`'s
    /// lifetime (presumably dropping it would cancel the work — confirm).
    _task: Task<Result<()>>,
}

impl Diff {
    /// Builds a `Diff` from an ACP diff.
    ///
    /// How the diff is built:
    /// 1. Both texts are loaded into local buffers; a missing `old_text` becomes empty.
    /// 2. A `BufferDiff` with the old text as its base is started.
    /// 3. A spawned task waits for that diff.
    /// 4. The task excerpts the changed hunks of `new_buffer` into `multibuffer`,
    ///    with `editor::DEFAULT_MULTIBUFFER_CONTEXT` lines of context, and attaches the diff.
    /// 5. Finally, it sets `new_buffer`'s language based on `path`.
    pub fn from_acp(
        diff: acp::Diff,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut App,
    ) -> Self {
        let acp::Diff {
            path,
            old_text,
            new_text,
        } = diff;

        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));

        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
        // The old text is the diff base, so hunks describe old -> new.
        let diff_task = buffer_diff.update(cx, |diff, cx| {
            diff.set_base_text(
                old_buffer_snapshot,
                Some(language_registry.clone()),
                new_buffer_snapshot,
                cx,
            )
        });

        let task = cx.spawn({
            let multibuffer = multibuffer.clone();
            let path = path.clone();
            let new_buffer = new_buffer.clone();
            async move |cx| {
                // Hunks are only available once the base text has been diffed.
                diff_task.await?;

                multibuffer
                    .update(cx, |multibuffer, cx| {
                        // Excerpt only the changed regions of the new buffer.
                        let hunk_ranges = {
                            let buffer = new_buffer.read(cx);
                            let diff = buffer_diff.read(cx);
                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
                                .collect::<Vec<_>>()
                        };

                        multibuffer.set_excerpts_for_path(
                            PathKey::for_buffer(&new_buffer, cx),
                            new_buffer.clone(),
                            hunk_ranges,
                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
                            cx,
                        );
                        multibuffer.add_diff(buffer_diff.clone(), cx);
                    })
                    .log_err();

                // Language detection is best-effort; a lookup failure is only logged.
                if let Some(language) = language_registry
                    .language_for_file_path(&path)
                    .await
                    .log_err()
                {
                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
                }

                anyhow::Ok(())
            }
        });

        Self {
            multibuffer,
            path,
            new_buffer,
            old_buffer,
            _task: task,
        }
    }

    /// Renders the path followed by the full text of every buffer in the
    /// multibuffer, inside a code fence.
    ///
    /// Only `new_buffer` is added to the multibuffer, so this is the new text,
    /// not a unified diff. NOTE(review): the text is likely empty until `_task`
    /// has populated the multibuffer — confirm.
    fn to_markdown(&self, cx: &App) -> String {
        let buffer_text = self
            .multibuffer
            .read(cx)
            .all_buffers()
            .iter()
            .map(|buffer| buffer.read(cx).text())
            .join("\n");
        format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
    }
}
455
/// A conversation with an ACP agent.
///
/// Holds the transcript plus the state needed to serve the agent's file and
/// tool-call requests.
pub struct AcpThread {
    /// Transcript entries in order; a tool call's id is its index in this vector.
    entries: Vec<AgentThreadEntry>,
    title: SharedString,
    project: Entity<Project>,
    /// Records the buffers the agent reads and edits.
    action_log: Entity<ActionLog>,
    /// Snapshot of each buffer as last returned by `read_text_file`;
    /// `write_text_file` diffs new content against it.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// The in-flight `send`, if any; `Some` means the thread is not idle.
    send_task: Option<Task<()>>,
    connection: Arc<dyn AgentConnection>,
    /// Task supplied at construction and handed out once via `child_status()`.
    // NOTE(review): presumably resolves when the agent process exits — confirm.
    child_status: Option<Task<Result<()>>>,
}
466
/// Events emitted by [`AcpThread`] when its entries change.
pub enum AcpThreadEvent {
    /// An entry was appended to the end of the thread.
    NewEntry,
    /// The entry at this index was modified in place.
    EntryUpdated(usize),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
473
/// Coarse state of an [`AcpThread`], as reported by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No send is in progress.
    Idle,
    /// A send is in progress, and a trailing tool call awaits the user's confirmation.
    WaitingForToolConfirmation,
    /// A send is in progress.
    Generating,
}
480
481#[derive(Debug, Clone)]
482pub enum LoadError {
483 Unsupported {
484 error_message: SharedString,
485 upgrade_message: SharedString,
486 upgrade_command: String,
487 },
488 Exited(i32),
489 Other(SharedString),
490}
491
492impl Display for LoadError {
493 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
494 match self {
495 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
496 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
497 LoadError::Other(msg) => write!(f, "{}", msg),
498 }
499 }
500}
501
502impl Error for LoadError {}
503
504impl AcpThread {
505 pub fn new(
506 connection: impl AgentConnection + 'static,
507 title: SharedString,
508 child_status: Option<Task<Result<()>>>,
509 project: Entity<Project>,
510 cx: &mut Context<Self>,
511 ) -> Self {
512 let action_log = cx.new(|_| ActionLog::new(project.clone()));
513
514 Self {
515 action_log,
516 shared_buffers: Default::default(),
517 entries: Default::default(),
518 title,
519 project,
520 send_task: None,
521 connection: Arc::new(connection),
522 child_status,
523 }
524 }
525
526 /// Send a request to the agent and wait for a response.
527 pub fn request<R: AgentRequest + 'static>(
528 &self,
529 params: R,
530 ) -> impl use<R> + Future<Output = Result<R::Response>> {
531 let params = params.into_any();
532 let result = self.connection.request_any(params);
533 async move {
534 let result = result.await?;
535 Ok(R::response_from_any(result)?)
536 }
537 }
538
539 pub fn action_log(&self) -> &Entity<ActionLog> {
540 &self.action_log
541 }
542
543 pub fn project(&self) -> &Entity<Project> {
544 &self.project
545 }
546
547 pub fn title(&self) -> SharedString {
548 self.title.clone()
549 }
550
551 pub fn entries(&self) -> &[AgentThreadEntry] {
552 &self.entries
553 }
554
555 pub fn status(&self) -> ThreadStatus {
556 if self.send_task.is_some() {
557 if self.waiting_for_tool_confirmation() {
558 ThreadStatus::WaitingForToolConfirmation
559 } else {
560 ThreadStatus::Generating
561 }
562 } else {
563 ThreadStatus::Idle
564 }
565 }
566
567 pub fn has_pending_edit_tool_calls(&self) -> bool {
568 for entry in self.entries.iter().rev() {
569 match entry {
570 AgentThreadEntry::UserMessage(_) => return false,
571 AgentThreadEntry::ToolCall(ToolCall {
572 status:
573 ToolCallStatus::Allowed {
574 status: acp::ToolCallStatus::Running,
575 ..
576 },
577 content: Some(ToolCallContent::Diff { .. }),
578 ..
579 }) => return true,
580 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
581 }
582 }
583
584 false
585 }
586
587 pub fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
588 self.entries.push(entry);
589 cx.emit(AcpThreadEvent::NewEntry);
590 }
591
    /// Adds a streamed assistant chunk to the thread.
    ///
    /// If the last entry is an assistant message:
    /// - when its last chunk has the same kind (text or thought), the new text
    ///   is appended to that chunk in place;
    /// - otherwise the new chunk is pushed onto that message.
    ///
    /// If the last entry is anything else, a new assistant message entry is created.
    pub fn push_assistant_chunk(
        &mut self,
        chunk: acp::AssistantMessageChunk,
        cx: &mut Context<Self>,
    ) {
        let entries_len = self.entries.len();
        if let Some(last_entry) = self.entries.last_mut()
            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
        {
            // Emitted before the mutation below. NOTE(review): presumably gpui delivers
            // the event after this update finishes, so observers see the new text — confirm.
            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));

            match (chunks.last_mut(), &chunk) {
                (
                    Some(AssistantMessageChunk::Text { chunk: old_chunk }),
                    acp::AssistantMessageChunk::Text { text: new_chunk },
                )
                | (
                    Some(AssistantMessageChunk::Thought { chunk: old_chunk }),
                    acp::AssistantMessageChunk::Thought { thought: new_chunk },
                ) => {
                    // Same kind as the previous chunk: extend it.
                    old_chunk.update(cx, |old_chunk, cx| {
                        old_chunk.append(&new_chunk, cx);
                    });
                }
                _ => {
                    // Kind changed (or no chunks yet): start a new chunk.
                    chunks.push(AssistantMessageChunk::from_acp(
                        chunk,
                        self.project.read(cx).languages().clone(),
                        cx,
                    ));
                }
            }
        } else {
            let chunk = AssistantMessageChunk::from_acp(
                chunk,
                self.project.read(cx).languages().clone(),
                cx,
            );

            self.push_entry(
                AgentThreadEntry::AssistantMessage(AssistantMessage {
                    chunks: vec![chunk],
                }),
                cx,
            );
        }
    }
639
640 pub fn request_new_tool_call(
641 &mut self,
642 tool_call: acp::RequestToolCallConfirmationParams,
643 cx: &mut Context<Self>,
644 ) -> ToolCallRequest {
645 let (tx, rx) = oneshot::channel();
646
647 let status = ToolCallStatus::WaitingForConfirmation {
648 confirmation: ToolCallConfirmation::from_acp(
649 tool_call.confirmation,
650 self.project.read(cx).languages().clone(),
651 cx,
652 ),
653 respond_tx: tx,
654 };
655
656 let id = self.insert_tool_call(tool_call.tool_call, status, cx);
657 ToolCallRequest { id, outcome: rx }
658 }
659
660 pub fn request_tool_call_confirmation(
661 &mut self,
662 tool_call_id: ToolCallId,
663 confirmation: acp::ToolCallConfirmation,
664 cx: &mut Context<Self>,
665 ) -> Result<ToolCallRequest> {
666 let project = self.project.read(cx).languages().clone();
667 let Some((idx, call)) = self.tool_call_mut(tool_call_id) else {
668 anyhow::bail!("Tool call not found");
669 };
670
671 let (tx, rx) = oneshot::channel();
672
673 call.status = ToolCallStatus::WaitingForConfirmation {
674 confirmation: ToolCallConfirmation::from_acp(confirmation, project, cx),
675 respond_tx: tx,
676 };
677
678 cx.emit(AcpThreadEvent::EntryUpdated(idx));
679
680 Ok(ToolCallRequest {
681 id: tool_call_id,
682 outcome: rx,
683 })
684 }
685
686 pub fn push_tool_call(
687 &mut self,
688 request: acp::PushToolCallParams,
689 cx: &mut Context<Self>,
690 ) -> acp::ToolCallId {
691 let status = ToolCallStatus::Allowed {
692 status: acp::ToolCallStatus::Running,
693 };
694
695 self.insert_tool_call(request, status, cx)
696 }
697
698 fn insert_tool_call(
699 &mut self,
700 tool_call: acp::PushToolCallParams,
701 status: ToolCallStatus,
702 cx: &mut Context<Self>,
703 ) -> acp::ToolCallId {
704 let language_registry = self.project.read(cx).languages().clone();
705 let id = acp::ToolCallId(self.entries.len() as u64);
706 let call = ToolCall {
707 id,
708 label: cx.new(|cx| {
709 Markdown::new(
710 tool_call.label.into(),
711 Some(language_registry.clone()),
712 None,
713 cx,
714 )
715 }),
716 icon: acp_icon_to_ui_icon(tool_call.icon),
717 content: tool_call
718 .content
719 .map(|content| ToolCallContent::from_acp(content, language_registry, cx)),
720 locations: tool_call.locations,
721 status,
722 };
723
724 let location = call.locations.last().cloned();
725 if let Some(location) = location {
726 self.set_project_location(location, cx)
727 }
728
729 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
730
731 id
732 }
733
734 pub fn authorize_tool_call(
735 &mut self,
736 id: acp::ToolCallId,
737 outcome: acp::ToolCallConfirmationOutcome,
738 cx: &mut Context<Self>,
739 ) {
740 let Some((ix, call)) = self.tool_call_mut(id) else {
741 return;
742 };
743
744 let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
745 ToolCallStatus::Rejected
746 } else {
747 ToolCallStatus::Allowed {
748 status: acp::ToolCallStatus::Running,
749 }
750 };
751
752 let curr_status = mem::replace(&mut call.status, new_status);
753
754 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
755 respond_tx.send(outcome).log_err();
756 } else if cfg!(debug_assertions) {
757 panic!("tried to authorize an already authorized tool call");
758 }
759
760 cx.emit(AcpThreadEvent::EntryUpdated(ix));
761 }
762
763 pub fn update_tool_call(
764 &mut self,
765 id: acp::ToolCallId,
766 new_status: acp::ToolCallStatus,
767 new_content: Option<acp::ToolCallContent>,
768 cx: &mut Context<Self>,
769 ) -> Result<()> {
770 let language_registry = self.project.read(cx).languages().clone();
771 let (ix, call) = self.tool_call_mut(id).context("Entry not found")?;
772
773 if let Some(new_content) = new_content {
774 call.content = Some(ToolCallContent::from_acp(
775 new_content,
776 language_registry,
777 cx,
778 ));
779 }
780
781 match &mut call.status {
782 ToolCallStatus::Allowed { status } => {
783 *status = new_status;
784 }
785 ToolCallStatus::WaitingForConfirmation { .. } => {
786 anyhow::bail!("Tool call hasn't been authorized yet")
787 }
788 ToolCallStatus::Rejected => {
789 anyhow::bail!("Tool call was rejected and therefore can't be updated")
790 }
791 ToolCallStatus::Canceled => {
792 call.status = ToolCallStatus::Allowed { status: new_status };
793 }
794 }
795
796 let location = call.locations.last().cloned();
797 if let Some(location) = location {
798 self.set_project_location(location, cx)
799 }
800
801 cx.emit(AcpThreadEvent::EntryUpdated(ix));
802 Ok(())
803 }
804
805 fn tool_call_mut(&mut self, id: acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
806 let entry = self.entries.get_mut(id.0 as usize);
807 debug_assert!(
808 entry.is_some(),
809 "We shouldn't give out ids to entries that don't exist"
810 );
811 match entry {
812 Some(AgentThreadEntry::ToolCall(call)) if call.id == id => Some((id.0 as usize, call)),
813 _ => {
814 if cfg!(debug_assertions) {
815 panic!("entry is not a tool call");
816 }
817 None
818 }
819 }
820 }
821
    /// Moves the project's agent location to `location`.
    ///
    /// Does nothing if `location.path` is not inside the project. Otherwise the
    /// buffer is opened asynchronously. The position is the start of row
    /// `location.line`, clipped to the buffer, or the buffer start when no line
    /// is given. Failures while opening the buffer are logged, not returned.
    pub fn set_project_location(&self, location: ToolCallLocation, cx: &mut Context<Self>) {
        self.project.update(cx, |project, cx| {
            let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
                return;
            };
            let buffer = project.open_buffer(path, cx);
            cx.spawn(async move |project, cx| {
                let buffer = buffer.await?;

                project.update(cx, |project, cx| {
                    let position = if let Some(line) = location.line {
                        let snapshot = buffer.read(cx).snapshot();
                        // Clip so a line past the end of the buffer still yields a valid anchor.
                        let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
                        snapshot.anchor_before(point)
                    } else {
                        Anchor::MIN
                    };

                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position,
                        }),
                        cx,
                    );
                })
            })
            .detach_and_log_err(cx);
        });
    }
852
853 /// Returns true if the last turn is awaiting tool authorization
854 pub fn waiting_for_tool_confirmation(&self) -> bool {
855 for entry in self.entries.iter().rev() {
856 match &entry {
857 AgentThreadEntry::ToolCall(call) => match call.status {
858 ToolCallStatus::WaitingForConfirmation { .. } => return true,
859 ToolCallStatus::Allowed { .. }
860 | ToolCallStatus::Rejected
861 | ToolCallStatus::Canceled => continue,
862 },
863 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
864 // Reached the beginning of the turn
865 return false;
866 }
867 }
868 }
869 false
870 }
871
872 pub fn initialize(&self) -> impl use<> + Future<Output = Result<acp::InitializeResponse>> {
873 self.request(acp::InitializeParams {
874 protocol_version: ProtocolVersion::latest(),
875 })
876 }
877
878 pub fn authenticate(&self) -> impl use<> + Future<Output = Result<()>> {
879 self.request(acp::AuthenticateParams)
880 }
881
882 #[cfg(any(test, feature = "test-support"))]
883 pub fn send_raw(
884 &mut self,
885 message: &str,
886 cx: &mut Context<Self>,
887 ) -> BoxFuture<'static, Result<(), acp::Error>> {
888 self.send(
889 acp::SendUserMessageParams {
890 chunks: vec![acp::UserMessageChunk::Text {
891 text: message.to_string(),
892 }],
893 },
894 cx,
895 )
896 }
897
    /// Sends a user message to the agent.
    ///
    /// Steps:
    /// 1. The message is appended to the thread immediately.
    /// 2. Any in-flight send is canceled via `cancel`.
    /// 3. The message is sent as a request.
    /// 4. When the request completes, `send_task` is cleared.
    ///
    /// The returned future resolves to the request's error, if any. It resolves
    /// to `Ok(())` on success, and also when the send task is dropped before
    /// producing a result.
    pub fn send(
        &mut self,
        message: acp::SendUserMessageParams,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<(), acp::Error>> {
        self.push_entry(
            AgentThreadEntry::UserMessage(UserMessage::from_acp(
                &message,
                self.project.read(cx).languages().clone(),
                cx,
            )),
            cx,
        );

        let (tx, rx) = oneshot::channel();
        // Takes (and drops) the previous send task, if there was one.
        let cancel = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            async {
                // Wait for the previous turn's cancellation before sending.
                cancel.await.log_err();

                let result = this.update(cx, |this, _| this.request(message))?.await;
                tx.send(result).log_err();
                this.update(cx, |this, _cx| this.send_task.take())?;
                anyhow::Ok(())
            }
            .await
            .log_err();
        }));

        async move {
            match rx.await {
                Ok(Err(e)) => Err(e)?,
                // Success, or the sender was dropped (e.g. the send was canceled).
                _ => Ok(()),
            }
        }
        .boxed()
    }
936
937 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<(), acp::Error>> {
938 if self.send_task.take().is_some() {
939 let request = self.request(acp::CancelSendMessageParams);
940 cx.spawn(async move |this, cx| {
941 request.await?;
942 this.update(cx, |this, _cx| {
943 for entry in this.entries.iter_mut() {
944 if let AgentThreadEntry::ToolCall(call) = entry {
945 let cancel = matches!(
946 call.status,
947 ToolCallStatus::WaitingForConfirmation { .. }
948 | ToolCallStatus::Allowed {
949 status: acp::ToolCallStatus::Running
950 }
951 );
952
953 if cancel {
954 let curr_status =
955 mem::replace(&mut call.status, ToolCallStatus::Canceled);
956
957 if let ToolCallStatus::WaitingForConfirmation {
958 respond_tx, ..
959 } = curr_status
960 {
961 respond_tx
962 .send(acp::ToolCallConfirmationOutcome::Cancel)
963 .ok();
964 }
965 }
966 }
967 }
968 })?;
969 Ok(())
970 })
971 } else {
972 Task::ready(Ok(()))
973 }
974 }
975
976 pub fn read_text_file(
977 &self,
978 request: acp::ReadTextFileParams,
979 reuse_shared_snapshot: bool,
980 cx: &mut Context<Self>,
981 ) -> Task<Result<String>> {
982 let project = self.project.clone();
983 let action_log = self.action_log.clone();
984 cx.spawn(async move |this, cx| {
985 let load = project.update(cx, |project, cx| {
986 let path = project
987 .project_path_for_absolute_path(&request.path, cx)
988 .context("invalid path")?;
989 anyhow::Ok(project.open_buffer(path, cx))
990 });
991 let buffer = load??.await?;
992
993 let snapshot = if reuse_shared_snapshot {
994 this.read_with(cx, |this, _| {
995 this.shared_buffers.get(&buffer.clone()).cloned()
996 })
997 .log_err()
998 .flatten()
999 } else {
1000 None
1001 };
1002
1003 let snapshot = if let Some(snapshot) = snapshot {
1004 snapshot
1005 } else {
1006 action_log.update(cx, |action_log, cx| {
1007 action_log.buffer_read(buffer.clone(), cx);
1008 })?;
1009 project.update(cx, |project, cx| {
1010 let position = buffer
1011 .read(cx)
1012 .snapshot()
1013 .anchor_before(Point::new(request.line.unwrap_or_default(), 0));
1014 project.set_agent_location(
1015 Some(AgentLocation {
1016 buffer: buffer.downgrade(),
1017 position,
1018 }),
1019 cx,
1020 );
1021 })?;
1022
1023 buffer.update(cx, |buffer, _| buffer.snapshot())?
1024 };
1025
1026 this.update(cx, |this, _| {
1027 let text = snapshot.text();
1028 this.shared_buffers.insert(buffer.clone(), snapshot);
1029 if request.line.is_none() && request.limit.is_none() {
1030 return Ok(text);
1031 }
1032 let limit = request.limit.unwrap_or(u32::MAX) as usize;
1033 let Some(line) = request.line else {
1034 return Ok(text.lines().take(limit).collect::<String>());
1035 };
1036
1037 let count = text.lines().count();
1038 if count < line as usize {
1039 anyhow::bail!("There are only {} lines", count);
1040 }
1041 Ok(text
1042 .lines()
1043 .skip(line as usize + 1)
1044 .take(limit)
1045 .collect::<String>())
1046 })?
1047 })
1048 }
1049
    /// Replaces the contents of the file at `path` with `content` and saves it.
    ///
    /// How the write is applied:
    /// 1. `content` is diffed, on the background executor, against the snapshot
    ///    last shared with the agent. The buffer's current snapshot is used if
    ///    none was shared.
    /// 2. The result is applied to the buffer as anchored edits. NOTE(review):
    ///    presumably anchoring keeps edits made since that snapshot intact — confirm.
    /// 3. The agent location moves to the end of the last edit.
    /// 4. The action log records the read and then the edit.
    /// 5. The buffer is saved.
    ///
    /// # Errors
    /// Fails in three cases:
    /// - the path is not inside the project;
    /// - the buffer cannot be opened;
    /// - saving fails.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the snapshot the agent last read, so the diff matches its view.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;
            cx.update(|cx| {
                project.update(cx, |project, cx| {
                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position: edits
                                .last()
                                .map(|(range, _)| range.end)
                                .unwrap_or(Anchor::MIN),
                        }),
                        cx,
                    );
                });

                // Record a read before the edit, then the edit itself.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });
                buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            })?;
            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1117
1118 pub fn child_status(&mut self) -> Option<Task<Result<()>>> {
1119 self.child_status.take()
1120 }
1121
1122 pub fn to_markdown(&self, cx: &App) -> String {
1123 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1124 }
1125}
1126
/// Client-side handler that forwards the agent's ACP requests to an [`AcpThread`].
///
/// Only a weak handle to the thread is held, so the delegate does not keep it alive.
#[derive(Clone)]
pub struct AcpClientDelegate {
    thread: WeakEntity<AcpThread>,
    cx: AsyncApp,
    // NOTE(review): the commented-out field below is unused; consider removing it.
    // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
}
1133
1134impl AcpClientDelegate {
1135 pub fn new(thread: WeakEntity<AcpThread>, cx: AsyncApp) -> Self {
1136 Self { thread, cx }
1137 }
1138
1139 pub async fn request_existing_tool_call_confirmation(
1140 &self,
1141 tool_call_id: ToolCallId,
1142 confirmation: acp::ToolCallConfirmation,
1143 ) -> Result<ToolCallConfirmationOutcome> {
1144 let cx = &mut self.cx.clone();
1145 let ToolCallRequest { outcome, .. } = cx
1146 .update(|cx| {
1147 self.thread.update(cx, |thread, cx| {
1148 thread.request_tool_call_confirmation(tool_call_id, confirmation, cx)
1149 })
1150 })?
1151 .context("Failed to update thread")??;
1152
1153 Ok(outcome.await?)
1154 }
1155
1156 pub async fn read_text_file_reusing_snapshot(
1157 &self,
1158 request: acp::ReadTextFileParams,
1159 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
1160 let content = self
1161 .cx
1162 .update(|cx| {
1163 self.thread
1164 .update(cx, |thread, cx| thread.read_text_file(request, true, cx))
1165 })?
1166 .context("Failed to update thread")?
1167 .await?;
1168 Ok(acp::ReadTextFileResponse { content })
1169 }
1170}
1171
/// Handles ACP requests from the agent by forwarding them to the [`AcpThread`].
///
/// Most methods report "Failed to update thread" when the thread has been
/// dropped; streamed chunks are silently discarded instead.
impl acp::Client for AcpClientDelegate {
    /// Appends a streamed assistant chunk; failures to reach the thread are ignored.
    async fn stream_assistant_message_chunk(
        &self,
        params: acp::StreamAssistantMessageChunkParams,
    ) -> Result<(), acp::Error> {
        let cx = &mut self.cx.clone();

        cx.update(|cx| {
            self.thread
                .update(cx, |thread, cx| {
                    thread.push_assistant_chunk(params.chunk, cx)
                })
                .ok();
        })?;

        Ok(())
    }

    /// Adds a new tool call needing confirmation and waits for the user's decision.
    async fn request_tool_call_confirmation(
        &self,
        request: acp::RequestToolCallConfirmationParams,
    ) -> Result<acp::RequestToolCallConfirmationResponse, acp::Error> {
        let cx = &mut self.cx.clone();
        let ToolCallRequest { id, outcome } = cx
            .update(|cx| {
                self.thread
                    .update(cx, |thread, cx| thread.request_new_tool_call(request, cx))
            })?
            .context("Failed to update thread")?;

        Ok(acp::RequestToolCallConfirmationResponse {
            id,
            // A dropped sender (no decision delivered) becomes an internal error.
            outcome: outcome.await.map_err(acp::Error::into_internal_error)?,
        })
    }

    /// Adds a tool call that runs without confirmation and returns its id.
    async fn push_tool_call(
        &self,
        request: acp::PushToolCallParams,
    ) -> Result<acp::PushToolCallResponse, acp::Error> {
        let cx = &mut self.cx.clone();
        let id = cx
            .update(|cx| {
                self.thread
                    .update(cx, |thread, cx| thread.push_tool_call(request, cx))
            })?
            .context("Failed to update thread")?;

        Ok(acp::PushToolCallResponse { id })
    }

    /// Updates an existing tool call's status and content.
    async fn update_tool_call(&self, request: acp::UpdateToolCallParams) -> Result<(), acp::Error> {
        let cx = &mut self.cx.clone();

        cx.update(|cx| {
            self.thread.update(cx, |thread, cx| {
                thread.update_tool_call(request.tool_call_id, request.status, request.content, cx)
            })
        })?
        .context("Failed to update thread")??;

        Ok(())
    }

    /// Reads a file's current contents, without reusing a previously shared snapshot.
    async fn read_text_file(
        &self,
        request: acp::ReadTextFileParams,
    ) -> Result<acp::ReadTextFileResponse, acp::Error> {
        let content = self
            .cx
            .update(|cx| {
                self.thread
                    .update(cx, |thread, cx| thread.read_text_file(request, false, cx))
            })?
            .context("Failed to update thread")?
            .await?;
        Ok(acp::ReadTextFileResponse { content })
    }

    /// Writes new contents to a file and saves it.
    async fn write_text_file(&self, request: acp::WriteTextFileParams) -> Result<(), acp::Error> {
        self.cx
            .update(|cx| {
                self.thread.update(cx, |thread, cx| {
                    thread.write_text_file(request.path, request.content, cx)
                })
            })?
            .context("Failed to update thread")?
            .await?;

        Ok(())
    }
}
1264
1265fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
1266 match icon {
1267 acp::Icon::FileSearch => IconName::ToolSearch,
1268 acp::Icon::Folder => IconName::ToolFolder,
1269 acp::Icon::Globe => IconName::ToolWeb,
1270 acp::Icon::Hammer => IconName::ToolHammer,
1271 acp::Icon::LightBulb => IconName::ToolBulb,
1272 acp::Icon::Pencil => IconName::ToolPencil,
1273 acp::Icon::Regex => IconName::ToolRegex,
1274 acp::Icon::Terminal => IconName::ToolTerminal,
1275 }
1276}
1277
/// A newly registered tool call together with a channel that yields its confirmation outcome.
pub struct ToolCallRequest {
    /// Identifier of the tool call.
    pub id: acp::ToolCallId,
    /// Resolves with the confirmation outcome for this tool call. If the sending half is
    /// dropped first, awaiting this yields `oneshot::Canceled`.
    pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
}
1282
1283#[cfg(test)]
1284mod tests {
1285 use super::*;
1286 use anyhow::anyhow;
1287 use async_pipe::{PipeReader, PipeWriter};
1288 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1289 use gpui::{AsyncApp, TestAppContext};
1290 use indoc::indoc;
1291 use project::FakeFs;
1292 use serde_json::json;
1293 use settings::SettingsStore;
1294 use smol::{future::BoxedLocal, stream::StreamExt as _};
1295 use std::{cell::RefCell, rc::Rc, time::Duration};
1296 use util::path;
1297
1298 fn init_test(cx: &mut TestAppContext) {
1299 env_logger::try_init().ok();
1300 cx.update(|cx| {
1301 let settings_store = SettingsStore::test(cx);
1302 cx.set_global(settings_store);
1303 Project::init_settings(cx);
1304 language::init(cx);
1305 });
1306 }
1307
1308 #[gpui::test]
1309 async fn test_thinking_concatenation(cx: &mut TestAppContext) {
1310 init_test(cx);
1311
1312 let fs = FakeFs::new(cx.executor());
1313 let project = Project::test(fs, [], cx).await;
1314 let (thread, fake_server) = fake_acp_thread(project, cx);
1315
1316 fake_server.update(cx, |fake_server, _| {
1317 fake_server.on_user_message(move |_, server, mut cx| async move {
1318 server
1319 .update(&mut cx, |server, _| {
1320 server.send_to_zed(acp::StreamAssistantMessageChunkParams {
1321 chunk: acp::AssistantMessageChunk::Thought {
1322 thought: "Thinking ".into(),
1323 },
1324 })
1325 })?
1326 .await
1327 .unwrap();
1328 server
1329 .update(&mut cx, |server, _| {
1330 server.send_to_zed(acp::StreamAssistantMessageChunkParams {
1331 chunk: acp::AssistantMessageChunk::Thought {
1332 thought: "hard!".into(),
1333 },
1334 })
1335 })?
1336 .await
1337 .unwrap();
1338
1339 Ok(())
1340 })
1341 });
1342
1343 thread
1344 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1345 .await
1346 .unwrap();
1347
1348 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1349 assert_eq!(
1350 output,
1351 indoc! {r#"
1352 ## User
1353
1354 Hello from Zed!
1355
1356 ## Assistant
1357
1358 <thinking>
1359 Thinking hard!
1360 </thinking>
1361
1362 "#}
1363 );
1364 }
1365
    /// The agent reads `/tmp/foo`, the user then edits the open buffer, and the agent writes
    /// its extended version of the file. Both the user's edit and the agent's additions must
    /// end up in the buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project.clone(), cx);
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        // Open the file in a buffer; the edit made to it below plays the role of the user's
        // concurrent edit.
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        // Signals the test body once the agent has read the file, so that the user edit lands
        // between the agent's read and its write.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        // Wrapped so the `Fn` handler can take the single-use sender out exactly once.
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));

        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    // Step 1: the agent reads the original file contents.
                    let content = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp::ReadTextFileParams {
                                path: path!("/tmp/foo").into(),
                                line: None,
                                limit: None,
                            })
                        })?
                        .await
                        .unwrap();
                    assert_eq!(content.content, "one\ntwo\nthree\n");
                    // Step 2: let the test body perform the user's edit.
                    read_file_tx.take().unwrap().send(()).unwrap();
                    // Step 3: the agent writes its version, based on what it read in step 1.
                    server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp::WriteTextFileParams {
                                path: path!("/tmp/foo").into(),
                                content: "one\ntwo\nthree\nfour\nfive\n".to_string(),
                            })
                        })?
                        .await
                        .unwrap();
                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // The user's edit: prepend a line after the agent has read the file.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        // The user's "zero" line and the agent's "four"/"five" lines must both be present...
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        // ...in the buffer and in the file on disk.
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1439
1440 #[gpui::test]
1441 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
1442 init_test(cx);
1443
1444 let fs = FakeFs::new(cx.executor());
1445 let project = Project::test(fs, [], cx).await;
1446 let (thread, fake_server) = fake_acp_thread(project, cx);
1447
1448 let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>();
1449
1450 let tool_call_id = Rc::new(RefCell::new(None));
1451 let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx)));
1452 fake_server.update(cx, |fake_server, _| {
1453 let tool_call_id = tool_call_id.clone();
1454 fake_server.on_user_message(move |_, server, mut cx| {
1455 let end_turn_rx = end_turn_rx.clone();
1456 let tool_call_id = tool_call_id.clone();
1457 async move {
1458 let tool_call_result = server
1459 .update(&mut cx, |server, _| {
1460 server.send_to_zed(acp::PushToolCallParams {
1461 label: "Fetch".to_string(),
1462 icon: acp::Icon::Globe,
1463 content: None,
1464 locations: vec![],
1465 })
1466 })?
1467 .await
1468 .unwrap();
1469 *tool_call_id.clone().borrow_mut() = Some(tool_call_result.id);
1470 end_turn_rx.take().unwrap().await.ok();
1471
1472 Ok(())
1473 }
1474 })
1475 });
1476
1477 let request = thread.update(cx, |thread, cx| {
1478 thread.send_raw("Fetch https://example.com", cx)
1479 });
1480
1481 run_until_first_tool_call(&thread, cx).await;
1482
1483 thread.read_with(cx, |thread, _| {
1484 assert!(matches!(
1485 thread.entries[1],
1486 AgentThreadEntry::ToolCall(ToolCall {
1487 status: ToolCallStatus::Allowed {
1488 status: acp::ToolCallStatus::Running,
1489 ..
1490 },
1491 ..
1492 })
1493 ));
1494 });
1495
1496 cx.run_until_parked();
1497
1498 thread
1499 .update(cx, |thread, cx| thread.cancel(cx))
1500 .await
1501 .unwrap();
1502
1503 thread.read_with(cx, |thread, _| {
1504 assert!(matches!(
1505 &thread.entries[1],
1506 AgentThreadEntry::ToolCall(ToolCall {
1507 status: ToolCallStatus::Canceled,
1508 ..
1509 })
1510 ));
1511 });
1512
1513 fake_server
1514 .update(cx, |fake_server, _| {
1515 fake_server.send_to_zed(acp::UpdateToolCallParams {
1516 tool_call_id: tool_call_id.borrow().unwrap(),
1517 status: acp::ToolCallStatus::Finished,
1518 content: None,
1519 })
1520 })
1521 .await
1522 .unwrap();
1523
1524 drop(end_turn_tx);
1525 request.await.unwrap();
1526
1527 thread.read_with(cx, |thread, _| {
1528 assert!(matches!(
1529 thread.entries[1],
1530 AgentThreadEntry::ToolCall(ToolCall {
1531 status: ToolCallStatus::Allowed {
1532 status: acp::ToolCallStatus::Finished,
1533 ..
1534 },
1535 ..
1536 })
1537 ));
1538 });
1539 }
1540
1541 async fn run_until_first_tool_call(
1542 thread: &Entity<AcpThread>,
1543 cx: &mut TestAppContext,
1544 ) -> usize {
1545 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1546
1547 let subscription = cx.update(|cx| {
1548 cx.subscribe(thread, move |thread, _, cx| {
1549 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1550 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1551 return tx.try_send(ix).unwrap();
1552 }
1553 }
1554 })
1555 });
1556
1557 select! {
1558 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1559 panic!("Timeout waiting for tool call")
1560 }
1561 ix = rx.next().fuse() => {
1562 drop(subscription);
1563 ix.unwrap()
1564 }
1565 }
1566 }
1567
1568 pub fn fake_acp_thread(
1569 project: Entity<Project>,
1570 cx: &mut TestAppContext,
1571 ) -> (Entity<AcpThread>, Entity<FakeAcpServer>) {
1572 let (stdin_tx, stdin_rx) = async_pipe::pipe();
1573 let (stdout_tx, stdout_rx) = async_pipe::pipe();
1574
1575 let thread = cx.new(|cx| {
1576 let foreground_executor = cx.foreground_executor().clone();
1577 let (connection, io_fut) = acp::AgentConnection::connect_to_agent(
1578 AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
1579 stdin_tx,
1580 stdout_rx,
1581 move |fut| {
1582 foreground_executor.spawn(fut).detach();
1583 },
1584 );
1585
1586 let io_task = cx.background_spawn({
1587 async move {
1588 io_fut.await.log_err();
1589 Ok(())
1590 }
1591 });
1592 AcpThread::new(connection, "Test".into(), Some(io_task), project, cx)
1593 });
1594 let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
1595 (thread, agent)
1596 }
1597
    /// In-process stand-in for an ACP agent, used to drive an [`AcpThread`] in tests.
    pub struct FakeAcpServer {
        /// Connection over which the fake agent sends requests to Zed (see `send_to_zed`).
        connection: acp::ClientConnection,

        /// Task running the connection's I/O future; stored so the task is not dropped.
        _io_task: Task<()>,
        /// Handler invoked for each `send_user_message` request, if one has been installed.
        on_user_message: Option<
            Rc<
                dyn Fn(
                    acp::SendUserMessageParams,
                    Entity<FakeAcpServer>,
                    AsyncApp,
                ) -> LocalBoxFuture<'static, Result<(), acp::Error>>,
            >,
        >,
    }
1612
    /// The `acp::Agent` implementation behind [`FakeAcpServer`]. It answers Zed's requests and
    /// routes user messages to the server's installed handler.
    #[derive(Clone)]
    struct FakeAgent {
        /// Server entity whose `on_user_message` handler receives user messages.
        server: Entity<FakeAcpServer>,
        /// Async context used to read the server entity; a clone is handed to the handler.
        cx: AsyncApp,
    }
1618
1619 impl acp::Agent for FakeAgent {
1620 async fn initialize(
1621 &self,
1622 params: acp::InitializeParams,
1623 ) -> Result<acp::InitializeResponse, acp::Error> {
1624 Ok(acp::InitializeResponse {
1625 protocol_version: params.protocol_version,
1626 is_authenticated: true,
1627 })
1628 }
1629
1630 async fn authenticate(&self) -> Result<(), acp::Error> {
1631 Ok(())
1632 }
1633
1634 async fn cancel_send_message(&self) -> Result<(), acp::Error> {
1635 Ok(())
1636 }
1637
1638 async fn send_user_message(
1639 &self,
1640 request: acp::SendUserMessageParams,
1641 ) -> Result<(), acp::Error> {
1642 let mut cx = self.cx.clone();
1643 let handler = self
1644 .server
1645 .update(&mut cx, |server, _| server.on_user_message.clone())
1646 .ok()
1647 .flatten();
1648 if let Some(handler) = handler {
1649 handler(request, self.server.clone(), self.cx.clone()).await
1650 } else {
1651 Err(anyhow::anyhow!("No handler for on_user_message").into())
1652 }
1653 }
1654 }
1655
1656 impl FakeAcpServer {
1657 fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
1658 let agent = FakeAgent {
1659 server: cx.entity(),
1660 cx: cx.to_async(),
1661 };
1662 let foreground_executor = cx.foreground_executor().clone();
1663
1664 let (connection, io_fut) = acp::ClientConnection::connect_to_client(
1665 agent.clone(),
1666 stdout,
1667 stdin,
1668 move |fut| {
1669 foreground_executor.spawn(fut).detach();
1670 },
1671 );
1672 FakeAcpServer {
1673 connection: connection,
1674 on_user_message: None,
1675 _io_task: cx.background_spawn(async move {
1676 io_fut.await.log_err();
1677 }),
1678 }
1679 }
1680
1681 fn on_user_message<F>(
1682 &mut self,
1683 handler: impl for<'a> Fn(acp::SendUserMessageParams, Entity<FakeAcpServer>, AsyncApp) -> F
1684 + 'static,
1685 ) where
1686 F: Future<Output = Result<(), acp::Error>> + 'static,
1687 {
1688 self.on_user_message
1689 .replace(Rc::new(move |request, server, cx| {
1690 handler(request, server, cx).boxed_local()
1691 }));
1692 }
1693
1694 fn send_to_zed<T: acp::ClientRequest + 'static>(
1695 &self,
1696 message: T,
1697 ) -> BoxedLocal<Result<T::Response>> {
1698 self.connection
1699 .request(message)
1700 .map(|f| f.map_err(|err| anyhow!(err)))
1701 .boxed_local()
1702 }
1703 }
1704}