1mod connection;
2pub use connection::*;
3
4pub use acp::ToolCallId;
5use agentic_coding_protocol::{
6 self as acp, AgentRequest, ProtocolVersion, ToolCallConfirmationOutcome, ToolCallLocation,
7 UserMessageChunk,
8};
9use anyhow::{Context as _, Result};
10use assistant_tool::ActionLog;
11use buffer_diff::BufferDiff;
12use editor::{Bias, MultiBuffer, PathKey};
13use futures::{FutureExt, channel::oneshot, future::BoxFuture};
14use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
15use itertools::Itertools;
16use language::{
17 Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
18 text_diff,
19};
20use markdown::Markdown;
21use project::{AgentLocation, Project};
22use std::collections::HashMap;
23use std::error::Error;
24use std::fmt::{Formatter, Write};
25use std::{
26 fmt::Display,
27 mem,
28 path::{Path, PathBuf},
29 sync::Arc,
30};
31use ui::{App, IconName};
32use util::ResultExt;
33
34#[derive(Clone, Debug, Eq, PartialEq)]
35pub struct UserMessage {
36 pub content: Entity<Markdown>,
37}
38
39impl UserMessage {
40 pub fn from_acp(
41 message: &acp::SendUserMessageParams,
42 language_registry: Arc<LanguageRegistry>,
43 cx: &mut App,
44 ) -> Self {
45 let mut md_source = String::new();
46
47 for chunk in &message.chunks {
48 match chunk {
49 UserMessageChunk::Text { text } => md_source.push_str(&text),
50 UserMessageChunk::Path { path } => {
51 write!(&mut md_source, "{}", MentionPath(&path)).unwrap()
52 }
53 }
54 }
55
56 Self {
57 content: cx
58 .new(|cx| Markdown::new(md_source.into(), Some(language_registry), None, cx)),
59 }
60 }
61
62 fn to_markdown(&self, cx: &App) -> String {
63 format!("## User\n\n{}\n\n", self.content.read(cx).source())
64 }
65}
66
/// A file mention embedded in message Markdown.
///
/// It is displayed as `[@<file name>](@file:<path>)`. [`MentionPath::try_parse`]
/// recovers the path from such a link target.
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// Prefix that marks a link target as a file mention.
    const PREFIX: &'static str = "@file:";

    /// Wraps `path` as a mention.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a link target of the form `@file:<path>`.
    ///
    /// Returns `None` when `url` does not start with the mention prefix.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self(Path::new(rest)))
    }

    /// The mentioned path.
    pub fn path(&self) -> &Path {
        self.0
    }
}

impl Display for MentionPath<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let path = self.0;
        // A path without a final component (e.g. `/`) gets an empty label.
        let name = path.file_name().unwrap_or_default();
        write!(
            f,
            "[@{}]({}{})",
            name.display(),
            Self::PREFIX,
            path.display()
        )
    }
}
98
99#[derive(Clone, Debug, Eq, PartialEq)]
100pub struct AssistantMessage {
101 pub chunks: Vec<AssistantMessageChunk>,
102}
103
104impl AssistantMessage {
105 pub fn to_markdown(&self, cx: &App) -> String {
106 format!(
107 "## Assistant\n\n{}\n\n",
108 self.chunks
109 .iter()
110 .map(|chunk| chunk.to_markdown(cx))
111 .join("\n\n")
112 )
113 }
114}
115
116#[derive(Clone, Debug, Eq, PartialEq)]
117pub enum AssistantMessageChunk {
118 Text { chunk: Entity<Markdown> },
119 Thought { chunk: Entity<Markdown> },
120}
121
122impl AssistantMessageChunk {
123 pub fn from_acp(
124 chunk: acp::AssistantMessageChunk,
125 language_registry: Arc<LanguageRegistry>,
126 cx: &mut App,
127 ) -> Self {
128 match chunk {
129 acp::AssistantMessageChunk::Text { text } => Self::Text {
130 chunk: cx.new(|cx| Markdown::new(text.into(), Some(language_registry), None, cx)),
131 },
132 acp::AssistantMessageChunk::Thought { thought } => Self::Thought {
133 chunk: cx
134 .new(|cx| Markdown::new(thought.into(), Some(language_registry), None, cx)),
135 },
136 }
137 }
138
139 pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
140 Self::Text {
141 chunk: cx.new(|cx| {
142 Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
143 }),
144 }
145 }
146
147 fn to_markdown(&self, cx: &App) -> String {
148 match self {
149 Self::Text { chunk } => chunk.read(cx).source().to_string(),
150 Self::Thought { chunk } => {
151 format!("<thinking>\n{}\n</thinking>", chunk.read(cx).source())
152 }
153 }
154 }
155}
156
157#[derive(Debug)]
158pub enum AgentThreadEntry {
159 UserMessage(UserMessage),
160 AssistantMessage(AssistantMessage),
161 ToolCall(ToolCall),
162}
163
164impl AgentThreadEntry {
165 fn to_markdown(&self, cx: &App) -> String {
166 match self {
167 Self::UserMessage(message) => message.to_markdown(cx),
168 Self::AssistantMessage(message) => message.to_markdown(cx),
169 Self::ToolCall(too_call) => too_call.to_markdown(cx),
170 }
171 }
172
173 pub fn diff(&self) -> Option<&Diff> {
174 if let AgentThreadEntry::ToolCall(ToolCall {
175 content: Some(ToolCallContent::Diff { diff }),
176 ..
177 }) = self
178 {
179 Some(&diff)
180 } else {
181 None
182 }
183 }
184
185 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
186 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
187 Some(locations)
188 } else {
189 None
190 }
191 }
192}
193
194#[derive(Debug)]
195pub struct ToolCall {
196 pub id: acp::ToolCallId,
197 pub label: Entity<Markdown>,
198 pub icon: IconName,
199 pub content: Option<ToolCallContent>,
200 pub status: ToolCallStatus,
201 pub locations: Vec<acp::ToolCallLocation>,
202}
203
204impl ToolCall {
205 fn to_markdown(&self, cx: &App) -> String {
206 let mut markdown = format!(
207 "**Tool Call: {}**\nStatus: {}\n\n",
208 self.label.read(cx).source(),
209 self.status
210 );
211 if let Some(content) = &self.content {
212 markdown.push_str(content.to_markdown(cx).as_str());
213 markdown.push_str("\n\n");
214 }
215 markdown
216 }
217}
218
219#[derive(Debug)]
220pub enum ToolCallStatus {
221 WaitingForConfirmation {
222 confirmation: ToolCallConfirmation,
223 respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
224 },
225 Allowed {
226 status: acp::ToolCallStatus,
227 },
228 Rejected,
229 Canceled,
230}
231
232impl Display for ToolCallStatus {
233 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
234 write!(
235 f,
236 "{}",
237 match self {
238 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
239 ToolCallStatus::Allowed { status } => match status {
240 acp::ToolCallStatus::Running => "Running",
241 acp::ToolCallStatus::Finished => "Finished",
242 acp::ToolCallStatus::Error => "Error",
243 },
244 ToolCallStatus::Rejected => "Rejected",
245 ToolCallStatus::Canceled => "Canceled",
246 }
247 )
248 }
249}
250
251#[derive(Debug)]
252pub enum ToolCallConfirmation {
253 Edit {
254 description: Option<Entity<Markdown>>,
255 },
256 Execute {
257 command: String,
258 root_command: String,
259 description: Option<Entity<Markdown>>,
260 },
261 Mcp {
262 server_name: String,
263 tool_name: String,
264 tool_display_name: String,
265 description: Option<Entity<Markdown>>,
266 },
267 Fetch {
268 urls: Vec<SharedString>,
269 description: Option<Entity<Markdown>>,
270 },
271 Other {
272 description: Entity<Markdown>,
273 },
274}
275
276impl ToolCallConfirmation {
277 pub fn from_acp(
278 confirmation: acp::ToolCallConfirmation,
279 language_registry: Arc<LanguageRegistry>,
280 cx: &mut App,
281 ) -> Self {
282 let to_md = |description: String, cx: &mut App| -> Entity<Markdown> {
283 cx.new(|cx| {
284 Markdown::new(
285 description.into(),
286 Some(language_registry.clone()),
287 None,
288 cx,
289 )
290 })
291 };
292
293 match confirmation {
294 acp::ToolCallConfirmation::Edit { description } => Self::Edit {
295 description: description.map(|description| to_md(description, cx)),
296 },
297 acp::ToolCallConfirmation::Execute {
298 command,
299 root_command,
300 description,
301 } => Self::Execute {
302 command,
303 root_command,
304 description: description.map(|description| to_md(description, cx)),
305 },
306 acp::ToolCallConfirmation::Mcp {
307 server_name,
308 tool_name,
309 tool_display_name,
310 description,
311 } => Self::Mcp {
312 server_name,
313 tool_name,
314 tool_display_name,
315 description: description.map(|description| to_md(description, cx)),
316 },
317 acp::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch {
318 urls: urls.iter().map(|url| url.into()).collect(),
319 description: description.map(|description| to_md(description, cx)),
320 },
321 acp::ToolCallConfirmation::Other { description } => Self::Other {
322 description: to_md(description, cx),
323 },
324 }
325 }
326}
327
328#[derive(Debug)]
329pub enum ToolCallContent {
330 Markdown { markdown: Entity<Markdown> },
331 Diff { diff: Diff },
332}
333
334impl ToolCallContent {
335 pub fn from_acp(
336 content: acp::ToolCallContent,
337 language_registry: Arc<LanguageRegistry>,
338 cx: &mut App,
339 ) -> Self {
340 match content {
341 acp::ToolCallContent::Markdown { markdown } => Self::Markdown {
342 markdown: cx.new(|cx| Markdown::new_text(markdown.into(), cx)),
343 },
344 acp::ToolCallContent::Diff { diff } => Self::Diff {
345 diff: Diff::from_acp(diff, language_registry, cx),
346 },
347 }
348 }
349
350 fn to_markdown(&self, cx: &App) -> String {
351 match self {
352 Self::Markdown { markdown } => markdown.read(cx).source().to_string(),
353 Self::Diff { diff } => diff.to_markdown(cx),
354 }
355 }
356}
357
/// A read-only diff between an old and a new text, presented through a
/// [`MultiBuffer`] that shows excerpts around each hunk of the new text.
#[derive(Debug)]
pub struct Diff {
    pub multibuffer: Entity<MultiBuffer>,
    pub path: PathBuf,
    /// Buffer holding the new text.
    pub new_buffer: Entity<Buffer>,
    /// Buffer holding the old text (empty when the protocol sent none).
    pub old_buffer: Entity<Buffer>,
    // Task spawned by `from_acp`; stored so the diff owns it (presumably so the
    // work is dropped together with the diff — gpui `Task` semantics, confirm).
    _task: Task<Result<()>>,
}

impl Diff {
    /// Builds a diff from the protocol's `acp::Diff`.
    ///
    /// It creates buffers for the new and old text and computes a `BufferDiff`
    /// with the old text as its base. It then spawns a task that, once the
    /// diff is ready, adds an excerpt per hunk of the new buffer to the
    /// multibuffer, attaches the diff, and assigns the new buffer a language
    /// derived from `path`.
    pub fn from_acp(
        diff: acp::Diff,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut App,
    ) -> Self {
        let acp::Diff {
            path,
            old_text,
            new_text,
        } = diff;

        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));

        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
        // A missing old text is diffed as an empty file (presumably a newly
        // created file — TODO confirm with the protocol docs).
        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
        // Start computing the hunks; the result is awaited in the task below.
        let diff_task = buffer_diff.update(cx, |diff, cx| {
            diff.set_base_text(
                old_buffer_snapshot,
                Some(language_registry.clone()),
                new_buffer_snapshot,
                cx,
            )
        });

        let task = cx.spawn({
            let multibuffer = multibuffer.clone();
            let path = path.clone();
            let new_buffer = new_buffer.clone();
            async move |cx| {
                diff_task.await?;

                // Show only the changed regions of the new buffer. Failure here
                // is logged rather than aborting the language assignment below.
                multibuffer
                    .update(cx, |multibuffer, cx| {
                        let hunk_ranges = {
                            let buffer = new_buffer.read(cx);
                            let diff = buffer_diff.read(cx);
                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
                                .collect::<Vec<_>>()
                        };

                        multibuffer.set_excerpts_for_path(
                            PathKey::for_buffer(&new_buffer, cx),
                            new_buffer.clone(),
                            hunk_ranges,
                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
                            cx,
                        );
                        multibuffer.add_diff(buffer_diff.clone(), cx);
                    })
                    .log_err();

                // Language lookup failures are logged and leave the buffer as plain text.
                if let Some(language) = language_registry
                    .language_for_file_path(&path)
                    .await
                    .log_err()
                {
                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
                }

                anyhow::Ok(())
            }
        });

        Self {
            multibuffer,
            path,
            new_buffer,
            old_buffer,
            _task: task,
        }
    }

    /// Renders the text of every buffer in the multibuffer, joined with
    /// newlines, in a fenced block headed by the path.
    // NOTE(review): only `new_buffer` is added to the multibuffer in `from_acp`,
    // so this emits the new text rather than a unified diff.
    fn to_markdown(&self, cx: &App) -> String {
        let buffer_text = self
            .multibuffer
            .read(cx)
            .all_buffers()
            .iter()
            .map(|buffer| buffer.read(cx).text())
            .join("\n");
        format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
    }
}
455
/// A conversation with an agent reached through an [`AgentConnection`].
pub struct AcpThread {
    // Transcript, oldest first. Tool call ids are indices into this vector
    // (see `insert_tool_call` / `tool_call_mut`).
    entries: Vec<AgentThreadEntry>,
    title: SharedString,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // Last snapshot returned to the agent for each buffer it read; used as the
    // diff base by `write_text_file` and by snapshot-reusing reads.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    // In-flight `send`; `Some` means the thread is generating (see `status`).
    send_task: Option<Task<()>>,
    connection: Arc<dyn AgentConnection>,
    // Handed out at most once through `child_status()`.
    child_status: Option<Task<Result<()>>>,
}
466
/// Events emitted by [`AcpThread`] when its entries change.
pub enum AcpThreadEvent {
    /// An entry was appended to the thread.
    NewEntry,
    /// The entry at this index changed.
    EntryUpdated(usize),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
473
/// Coarse state of an [`AcpThread`], as returned by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No send is in flight.
    Idle,
    /// A send is in flight and a tool call in the current turn awaits confirmation.
    WaitingForToolConfirmation,
    /// A send is in flight.
    Generating,
}
480
481#[derive(Debug, Clone)]
482pub enum LoadError {
483 Unsupported {
484 error_message: SharedString,
485 upgrade_message: SharedString,
486 upgrade_command: String,
487 },
488 Exited(i32),
489 Other(SharedString),
490}
491
492impl Display for LoadError {
493 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
494 match self {
495 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
496 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
497 LoadError::Other(msg) => write!(f, "{}", msg),
498 }
499 }
500}
501
502impl Error for LoadError {}
503
504impl AcpThread {
505 pub fn new(
506 connection: impl AgentConnection + 'static,
507 title: SharedString,
508 child_status: Option<Task<Result<()>>>,
509 project: Entity<Project>,
510 cx: &mut Context<Self>,
511 ) -> Self {
512 let action_log = cx.new(|_| ActionLog::new(project.clone()));
513
514 Self {
515 action_log,
516 shared_buffers: Default::default(),
517 entries: Default::default(),
518 title,
519 project,
520 send_task: None,
521 connection: Arc::new(connection),
522 child_status,
523 }
524 }
525
526 /// Send a request to the agent and wait for a response.
527 pub fn request<R: AgentRequest + 'static>(
528 &self,
529 params: R,
530 ) -> impl use<R> + Future<Output = Result<R::Response>> {
531 let params = params.into_any();
532 let result = self.connection.request_any(params);
533 async move {
534 let result = result.await?;
535 Ok(R::response_from_any(result)?)
536 }
537 }
538
539 pub fn action_log(&self) -> &Entity<ActionLog> {
540 &self.action_log
541 }
542
543 pub fn project(&self) -> &Entity<Project> {
544 &self.project
545 }
546
547 pub fn title(&self) -> SharedString {
548 self.title.clone()
549 }
550
551 pub fn entries(&self) -> &[AgentThreadEntry] {
552 &self.entries
553 }
554
555 pub fn status(&self) -> ThreadStatus {
556 if self.send_task.is_some() {
557 if self.waiting_for_tool_confirmation() {
558 ThreadStatus::WaitingForToolConfirmation
559 } else {
560 ThreadStatus::Generating
561 }
562 } else {
563 ThreadStatus::Idle
564 }
565 }
566
567 pub fn has_pending_edit_tool_calls(&self) -> bool {
568 for entry in self.entries.iter().rev() {
569 match entry {
570 AgentThreadEntry::UserMessage(_) => return false,
571 AgentThreadEntry::ToolCall(ToolCall {
572 status:
573 ToolCallStatus::Allowed {
574 status: acp::ToolCallStatus::Running,
575 ..
576 },
577 content: Some(ToolCallContent::Diff { .. }),
578 ..
579 }) => return true,
580 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
581 }
582 }
583
584 false
585 }
586
587 pub fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
588 self.entries.push(entry);
589 cx.emit(AcpThreadEvent::NewEntry);
590 }
591
592 pub fn push_assistant_chunk(
593 &mut self,
594 chunk: acp::AssistantMessageChunk,
595 cx: &mut Context<Self>,
596 ) {
597 let entries_len = self.entries.len();
598 if let Some(last_entry) = self.entries.last_mut()
599 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
600 {
601 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
602
603 match (chunks.last_mut(), &chunk) {
604 (
605 Some(AssistantMessageChunk::Text { chunk: old_chunk }),
606 acp::AssistantMessageChunk::Text { text: new_chunk },
607 )
608 | (
609 Some(AssistantMessageChunk::Thought { chunk: old_chunk }),
610 acp::AssistantMessageChunk::Thought { thought: new_chunk },
611 ) => {
612 old_chunk.update(cx, |old_chunk, cx| {
613 old_chunk.append(&new_chunk, cx);
614 });
615 }
616 _ => {
617 chunks.push(AssistantMessageChunk::from_acp(
618 chunk,
619 self.project.read(cx).languages().clone(),
620 cx,
621 ));
622 }
623 }
624 } else {
625 let chunk = AssistantMessageChunk::from_acp(
626 chunk,
627 self.project.read(cx).languages().clone(),
628 cx,
629 );
630
631 self.push_entry(
632 AgentThreadEntry::AssistantMessage(AssistantMessage {
633 chunks: vec![chunk],
634 }),
635 cx,
636 );
637 }
638 }
639
640 pub fn request_new_tool_call(
641 &mut self,
642 tool_call: acp::RequestToolCallConfirmationParams,
643 cx: &mut Context<Self>,
644 ) -> ToolCallRequest {
645 let (tx, rx) = oneshot::channel();
646
647 let status = ToolCallStatus::WaitingForConfirmation {
648 confirmation: ToolCallConfirmation::from_acp(
649 tool_call.confirmation,
650 self.project.read(cx).languages().clone(),
651 cx,
652 ),
653 respond_tx: tx,
654 };
655
656 let id = self.insert_tool_call(tool_call.tool_call, status, cx);
657 ToolCallRequest { id, outcome: rx }
658 }
659
660 pub fn request_tool_call_confirmation(
661 &mut self,
662 tool_call_id: ToolCallId,
663 confirmation: acp::ToolCallConfirmation,
664 cx: &mut Context<Self>,
665 ) -> Result<ToolCallRequest> {
666 let project = self.project.read(cx).languages().clone();
667 let Some((_, call)) = self.tool_call_mut(tool_call_id) else {
668 anyhow::bail!("Tool call not found");
669 };
670
671 let (tx, rx) = oneshot::channel();
672
673 call.status = ToolCallStatus::WaitingForConfirmation {
674 confirmation: ToolCallConfirmation::from_acp(confirmation, project, cx),
675 respond_tx: tx,
676 };
677
678 Ok(ToolCallRequest {
679 id: tool_call_id,
680 outcome: rx,
681 })
682 }
683
684 pub fn push_tool_call(
685 &mut self,
686 request: acp::PushToolCallParams,
687 cx: &mut Context<Self>,
688 ) -> acp::ToolCallId {
689 let status = ToolCallStatus::Allowed {
690 status: acp::ToolCallStatus::Running,
691 };
692
693 self.insert_tool_call(request, status, cx)
694 }
695
696 fn insert_tool_call(
697 &mut self,
698 tool_call: acp::PushToolCallParams,
699 status: ToolCallStatus,
700 cx: &mut Context<Self>,
701 ) -> acp::ToolCallId {
702 let language_registry = self.project.read(cx).languages().clone();
703 let id = acp::ToolCallId(self.entries.len() as u64);
704 let call = ToolCall {
705 id,
706 label: cx.new(|cx| {
707 Markdown::new(
708 tool_call.label.into(),
709 Some(language_registry.clone()),
710 None,
711 cx,
712 )
713 }),
714 icon: acp_icon_to_ui_icon(tool_call.icon),
715 content: tool_call
716 .content
717 .map(|content| ToolCallContent::from_acp(content, language_registry, cx)),
718 locations: tool_call.locations,
719 status,
720 };
721
722 let location = call.locations.last().cloned();
723 if let Some(location) = location {
724 self.set_project_location(location, cx)
725 }
726
727 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
728
729 id
730 }
731
732 pub fn authorize_tool_call(
733 &mut self,
734 id: acp::ToolCallId,
735 outcome: acp::ToolCallConfirmationOutcome,
736 cx: &mut Context<Self>,
737 ) {
738 let Some((ix, call)) = self.tool_call_mut(id) else {
739 return;
740 };
741
742 let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
743 ToolCallStatus::Rejected
744 } else {
745 ToolCallStatus::Allowed {
746 status: acp::ToolCallStatus::Running,
747 }
748 };
749
750 let curr_status = mem::replace(&mut call.status, new_status);
751
752 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
753 respond_tx.send(outcome).log_err();
754 } else if cfg!(debug_assertions) {
755 panic!("tried to authorize an already authorized tool call");
756 }
757
758 cx.emit(AcpThreadEvent::EntryUpdated(ix));
759 }
760
761 pub fn update_tool_call(
762 &mut self,
763 id: acp::ToolCallId,
764 new_status: acp::ToolCallStatus,
765 new_content: Option<acp::ToolCallContent>,
766 cx: &mut Context<Self>,
767 ) -> Result<()> {
768 let language_registry = self.project.read(cx).languages().clone();
769 let (ix, call) = self.tool_call_mut(id).context("Entry not found")?;
770
771 call.content = new_content
772 .map(|new_content| ToolCallContent::from_acp(new_content, language_registry, cx));
773
774 match &mut call.status {
775 ToolCallStatus::Allowed { status } => {
776 *status = new_status;
777 }
778 ToolCallStatus::WaitingForConfirmation { .. } => {
779 anyhow::bail!("Tool call hasn't been authorized yet")
780 }
781 ToolCallStatus::Rejected => {
782 anyhow::bail!("Tool call was rejected and therefore can't be updated")
783 }
784 ToolCallStatus::Canceled => {
785 call.status = ToolCallStatus::Allowed { status: new_status };
786 }
787 }
788
789 let location = call.locations.last().cloned();
790 if let Some(location) = location {
791 self.set_project_location(location, cx)
792 }
793
794 cx.emit(AcpThreadEvent::EntryUpdated(ix));
795 Ok(())
796 }
797
798 fn tool_call_mut(&mut self, id: acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
799 let entry = self.entries.get_mut(id.0 as usize);
800 debug_assert!(
801 entry.is_some(),
802 "We shouldn't give out ids to entries that don't exist"
803 );
804 match entry {
805 Some(AgentThreadEntry::ToolCall(call)) if call.id == id => Some((id.0 as usize, call)),
806 _ => {
807 if cfg!(debug_assertions) {
808 panic!("entry is not a tool call");
809 }
810 None
811 }
812 }
813 }
814
815 pub fn set_project_location(&self, location: ToolCallLocation, cx: &mut Context<Self>) {
816 self.project.update(cx, |project, cx| {
817 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
818 return;
819 };
820 let buffer = project.open_buffer(path, cx);
821 cx.spawn(async move |project, cx| {
822 let buffer = buffer.await?;
823
824 project.update(cx, |project, cx| {
825 let position = if let Some(line) = location.line {
826 let snapshot = buffer.read(cx).snapshot();
827 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
828 snapshot.anchor_before(point)
829 } else {
830 Anchor::MIN
831 };
832
833 project.set_agent_location(
834 Some(AgentLocation {
835 buffer: buffer.downgrade(),
836 position,
837 }),
838 cx,
839 );
840 })
841 })
842 .detach_and_log_err(cx);
843 });
844 }
845
846 /// Returns true if the last turn is awaiting tool authorization
847 pub fn waiting_for_tool_confirmation(&self) -> bool {
848 for entry in self.entries.iter().rev() {
849 match &entry {
850 AgentThreadEntry::ToolCall(call) => match call.status {
851 ToolCallStatus::WaitingForConfirmation { .. } => return true,
852 ToolCallStatus::Allowed { .. }
853 | ToolCallStatus::Rejected
854 | ToolCallStatus::Canceled => continue,
855 },
856 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
857 // Reached the beginning of the turn
858 return false;
859 }
860 }
861 }
862 false
863 }
864
865 pub fn initialize(&self) -> impl use<> + Future<Output = Result<acp::InitializeResponse>> {
866 self.request(acp::InitializeParams {
867 protocol_version: ProtocolVersion::latest(),
868 })
869 }
870
871 pub fn authenticate(&self) -> impl use<> + Future<Output = Result<()>> {
872 self.request(acp::AuthenticateParams)
873 }
874
875 #[cfg(any(test, feature = "test-support"))]
876 pub fn send_raw(
877 &mut self,
878 message: &str,
879 cx: &mut Context<Self>,
880 ) -> BoxFuture<'static, Result<(), acp::Error>> {
881 self.send(
882 acp::SendUserMessageParams {
883 chunks: vec![acp::UserMessageChunk::Text {
884 text: message.to_string(),
885 }],
886 },
887 cx,
888 )
889 }
890
891 pub fn send(
892 &mut self,
893 message: acp::SendUserMessageParams,
894 cx: &mut Context<Self>,
895 ) -> BoxFuture<'static, Result<(), acp::Error>> {
896 self.push_entry(
897 AgentThreadEntry::UserMessage(UserMessage::from_acp(
898 &message,
899 self.project.read(cx).languages().clone(),
900 cx,
901 )),
902 cx,
903 );
904
905 let (tx, rx) = oneshot::channel();
906 let cancel = self.cancel(cx);
907
908 self.send_task = Some(cx.spawn(async move |this, cx| {
909 async {
910 cancel.await.log_err();
911
912 let result = this.update(cx, |this, _| this.request(message))?.await;
913 tx.send(result).log_err();
914 this.update(cx, |this, _cx| this.send_task.take())?;
915 anyhow::Ok(())
916 }
917 .await
918 .log_err();
919 }));
920
921 async move {
922 match rx.await {
923 Ok(Err(e)) => Err(e)?,
924 _ => Ok(()),
925 }
926 }
927 .boxed()
928 }
929
930 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<(), acp::Error>> {
931 if self.send_task.take().is_some() {
932 let request = self.request(acp::CancelSendMessageParams);
933 cx.spawn(async move |this, cx| {
934 request.await?;
935 this.update(cx, |this, _cx| {
936 for entry in this.entries.iter_mut() {
937 if let AgentThreadEntry::ToolCall(call) = entry {
938 let cancel = matches!(
939 call.status,
940 ToolCallStatus::WaitingForConfirmation { .. }
941 | ToolCallStatus::Allowed {
942 status: acp::ToolCallStatus::Running
943 }
944 );
945
946 if cancel {
947 let curr_status =
948 mem::replace(&mut call.status, ToolCallStatus::Canceled);
949
950 if let ToolCallStatus::WaitingForConfirmation {
951 respond_tx, ..
952 } = curr_status
953 {
954 respond_tx
955 .send(acp::ToolCallConfirmationOutcome::Cancel)
956 .ok();
957 }
958 }
959 }
960 }
961 })?;
962 Ok(())
963 })
964 } else {
965 Task::ready(Ok(()))
966 }
967 }
968
969 pub fn read_text_file(
970 &self,
971 request: acp::ReadTextFileParams,
972 reuse_shared_snapshot: bool,
973 cx: &mut Context<Self>,
974 ) -> Task<Result<String>> {
975 let project = self.project.clone();
976 let action_log = self.action_log.clone();
977 cx.spawn(async move |this, cx| {
978 let load = project.update(cx, |project, cx| {
979 let path = project
980 .project_path_for_absolute_path(&request.path, cx)
981 .context("invalid path")?;
982 anyhow::Ok(project.open_buffer(path, cx))
983 });
984 let buffer = load??.await?;
985
986 let snapshot = if reuse_shared_snapshot {
987 this.read_with(cx, |this, _| {
988 this.shared_buffers.get(&buffer.clone()).cloned()
989 })
990 .log_err()
991 .flatten()
992 } else {
993 None
994 };
995
996 let snapshot = if let Some(snapshot) = snapshot {
997 snapshot
998 } else {
999 action_log.update(cx, |action_log, cx| {
1000 action_log.buffer_read(buffer.clone(), cx);
1001 })?;
1002 project.update(cx, |project, cx| {
1003 let position = buffer
1004 .read(cx)
1005 .snapshot()
1006 .anchor_before(Point::new(request.line.unwrap_or_default(), 0));
1007 project.set_agent_location(
1008 Some(AgentLocation {
1009 buffer: buffer.downgrade(),
1010 position,
1011 }),
1012 cx,
1013 );
1014 })?;
1015
1016 buffer.update(cx, |buffer, _| buffer.snapshot())?
1017 };
1018
1019 this.update(cx, |this, _| {
1020 let text = snapshot.text();
1021 this.shared_buffers.insert(buffer.clone(), snapshot);
1022 if request.line.is_none() && request.limit.is_none() {
1023 return Ok(text);
1024 }
1025 let limit = request.limit.unwrap_or(u32::MAX) as usize;
1026 let Some(line) = request.line else {
1027 return Ok(text.lines().take(limit).collect::<String>());
1028 };
1029
1030 let count = text.lines().count();
1031 if count < line as usize {
1032 anyhow::bail!("There are only {} lines", count);
1033 }
1034 Ok(text
1035 .lines()
1036 .skip(line as usize + 1)
1037 .take(limit)
1038 .collect::<String>())
1039 })?
1040 })
1041 }
1042
1043 pub fn write_text_file(
1044 &self,
1045 path: PathBuf,
1046 content: String,
1047 cx: &mut Context<Self>,
1048 ) -> Task<Result<()>> {
1049 let project = self.project.clone();
1050 let action_log = self.action_log.clone();
1051 cx.spawn(async move |this, cx| {
1052 let load = project.update(cx, |project, cx| {
1053 let path = project
1054 .project_path_for_absolute_path(&path, cx)
1055 .context("invalid path")?;
1056 anyhow::Ok(project.open_buffer(path, cx))
1057 });
1058 let buffer = load??.await?;
1059 let snapshot = this.update(cx, |this, cx| {
1060 this.shared_buffers
1061 .get(&buffer)
1062 .cloned()
1063 .unwrap_or_else(|| buffer.read(cx).snapshot())
1064 })?;
1065 let edits = cx
1066 .background_executor()
1067 .spawn(async move {
1068 let old_text = snapshot.text();
1069 text_diff(old_text.as_str(), &content)
1070 .into_iter()
1071 .map(|(range, replacement)| {
1072 (
1073 snapshot.anchor_after(range.start)
1074 ..snapshot.anchor_before(range.end),
1075 replacement,
1076 )
1077 })
1078 .collect::<Vec<_>>()
1079 })
1080 .await;
1081 cx.update(|cx| {
1082 project.update(cx, |project, cx| {
1083 project.set_agent_location(
1084 Some(AgentLocation {
1085 buffer: buffer.downgrade(),
1086 position: edits
1087 .last()
1088 .map(|(range, _)| range.end)
1089 .unwrap_or(Anchor::MIN),
1090 }),
1091 cx,
1092 );
1093 });
1094
1095 action_log.update(cx, |action_log, cx| {
1096 action_log.buffer_read(buffer.clone(), cx);
1097 });
1098 buffer.update(cx, |buffer, cx| {
1099 buffer.edit(edits, None, cx);
1100 });
1101 action_log.update(cx, |action_log, cx| {
1102 action_log.buffer_edited(buffer.clone(), cx);
1103 });
1104 })?;
1105 project
1106 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1107 .await
1108 })
1109 }
1110
1111 pub fn child_status(&mut self) -> Option<Task<Result<()>>> {
1112 self.child_status.take()
1113 }
1114
1115 pub fn to_markdown(&self, cx: &App) -> String {
1116 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1117 }
1118}
1119
1120#[derive(Clone)]
1121pub struct AcpClientDelegate {
1122 thread: WeakEntity<AcpThread>,
1123 cx: AsyncApp,
1124 // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
1125}
1126
1127impl AcpClientDelegate {
1128 pub fn new(thread: WeakEntity<AcpThread>, cx: AsyncApp) -> Self {
1129 Self { thread, cx }
1130 }
1131
1132 pub async fn request_existing_tool_call_confirmation(
1133 &self,
1134 tool_call_id: ToolCallId,
1135 confirmation: acp::ToolCallConfirmation,
1136 ) -> Result<ToolCallConfirmationOutcome> {
1137 let cx = &mut self.cx.clone();
1138 let ToolCallRequest { outcome, .. } = cx
1139 .update(|cx| {
1140 self.thread.update(cx, |thread, cx| {
1141 thread.request_tool_call_confirmation(tool_call_id, confirmation, cx)
1142 })
1143 })?
1144 .context("Failed to update thread")??;
1145
1146 Ok(outcome.await?)
1147 }
1148
1149 pub async fn read_text_file_reusing_snapshot(
1150 &self,
1151 request: acp::ReadTextFileParams,
1152 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
1153 let content = self
1154 .cx
1155 .update(|cx| {
1156 self.thread
1157 .update(cx, |thread, cx| thread.read_text_file(request, true, cx))
1158 })?
1159 .context("Failed to update thread")?
1160 .await?;
1161 Ok(acp::ReadTextFileResponse { content })
1162 }
1163}
1164
1165impl acp::Client for AcpClientDelegate {
1166 async fn stream_assistant_message_chunk(
1167 &self,
1168 params: acp::StreamAssistantMessageChunkParams,
1169 ) -> Result<(), acp::Error> {
1170 let cx = &mut self.cx.clone();
1171
1172 cx.update(|cx| {
1173 self.thread
1174 .update(cx, |thread, cx| {
1175 thread.push_assistant_chunk(params.chunk, cx)
1176 })
1177 .ok();
1178 })?;
1179
1180 Ok(())
1181 }
1182
1183 async fn request_tool_call_confirmation(
1184 &self,
1185 request: acp::RequestToolCallConfirmationParams,
1186 ) -> Result<acp::RequestToolCallConfirmationResponse, acp::Error> {
1187 let cx = &mut self.cx.clone();
1188 let ToolCallRequest { id, outcome } = cx
1189 .update(|cx| {
1190 self.thread
1191 .update(cx, |thread, cx| thread.request_new_tool_call(request, cx))
1192 })?
1193 .context("Failed to update thread")?;
1194
1195 Ok(acp::RequestToolCallConfirmationResponse {
1196 id,
1197 outcome: outcome.await.map_err(acp::Error::into_internal_error)?,
1198 })
1199 }
1200
1201 async fn push_tool_call(
1202 &self,
1203 request: acp::PushToolCallParams,
1204 ) -> Result<acp::PushToolCallResponse, acp::Error> {
1205 let cx = &mut self.cx.clone();
1206 let id = cx
1207 .update(|cx| {
1208 self.thread
1209 .update(cx, |thread, cx| thread.push_tool_call(request, cx))
1210 })?
1211 .context("Failed to update thread")?;
1212
1213 Ok(acp::PushToolCallResponse { id })
1214 }
1215
1216 async fn update_tool_call(&self, request: acp::UpdateToolCallParams) -> Result<(), acp::Error> {
1217 let cx = &mut self.cx.clone();
1218
1219 cx.update(|cx| {
1220 self.thread.update(cx, |thread, cx| {
1221 thread.update_tool_call(request.tool_call_id, request.status, request.content, cx)
1222 })
1223 })?
1224 .context("Failed to update thread")??;
1225
1226 Ok(())
1227 }
1228
1229 async fn read_text_file(
1230 &self,
1231 request: acp::ReadTextFileParams,
1232 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
1233 let content = self
1234 .cx
1235 .update(|cx| {
1236 self.thread
1237 .update(cx, |thread, cx| thread.read_text_file(request, false, cx))
1238 })?
1239 .context("Failed to update thread")?
1240 .await?;
1241 Ok(acp::ReadTextFileResponse { content })
1242 }
1243
1244 async fn write_text_file(&self, request: acp::WriteTextFileParams) -> Result<(), acp::Error> {
1245 self.cx
1246 .update(|cx| {
1247 self.thread.update(cx, |thread, cx| {
1248 thread.write_text_file(request.path, request.content, cx)
1249 })
1250 })?
1251 .context("Failed to update thread")?
1252 .await?;
1253
1254 Ok(())
1255 }
1256}
1257
1258fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
1259 match icon {
1260 acp::Icon::FileSearch => IconName::ToolSearch,
1261 acp::Icon::Folder => IconName::ToolFolder,
1262 acp::Icon::Globe => IconName::ToolWeb,
1263 acp::Icon::Hammer => IconName::ToolHammer,
1264 acp::Icon::LightBulb => IconName::ToolBulb,
1265 acp::Icon::Pencil => IconName::ToolPencil,
1266 acp::Icon::Regex => IconName::ToolRegex,
1267 acp::Icon::Terminal => IconName::ToolTerminal,
1268 }
1269}
1270
1271pub struct ToolCallRequest {
1272 pub id: acp::ToolCallId,
1273 pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
1274}
1275
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use async_pipe::{PipeReader, PipeWriter};
    use futures::{channel::mpsc, future::LocalBoxFuture, select};
    use gpui::{AsyncApp, TestAppContext};
    use indoc::indoc;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::{future::BoxedLocal, stream::StreamExt as _};
    use std::{cell::RefCell, rc::Rc, time::Duration};
    use util::path;

    /// Installs a test `SettingsStore` global and initializes the project and
    /// language settings used by the tests below.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project, cx);

        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| async move {
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp::StreamAssistantMessageChunkParams {
                            chunk: acp::AssistantMessageChunk::Thought {
                                thought: "Thinking ".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp::StreamAssistantMessageChunkParams {
                            chunk: acp::AssistantMessageChunk::Thought {
                                thought: "hard!".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();

                Ok(())
            })
        });

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }

    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project.clone(), cx);
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));

        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp::ReadTextFileParams {
                                path: path!("/tmp/foo").into(),
                                line: None,
                                limit: None,
                            })
                        })?
                        .await
                        .unwrap();
                    assert_eq!(content.content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp::WriteTextFileParams {
                                path: path!("/tmp/foo").into(),
                                content: "one\ntwo\nthree\nfour\nfive\n".to_string(),
                            })
                        })?
                        .await
                        .unwrap();
                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }

    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project, cx);

        let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>();

        let tool_call_id = Rc::new(RefCell::new(None));
        let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx)));
        fake_server.update(cx, |fake_server, _| {
            let tool_call_id = tool_call_id.clone();
            fake_server.on_user_message(move |_, server, mut cx| {
                let end_turn_rx = end_turn_rx.clone();
                let tool_call_id = tool_call_id.clone();
                async move {
                    let tool_call_result = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp::PushToolCallParams {
                                label: "Fetch".to_string(),
                                icon: acp::Icon::Globe,
                                content: None,
                                locations: vec![],
                            })
                        })?
                        .await
                        .unwrap();
                    *tool_call_id.borrow_mut() = Some(tool_call_result.id);
                    end_turn_rx.take().unwrap().await.ok();

                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Running,
                        ..
                    },
                    ..
                })
            ));
        });

        cx.run_until_parked();

        thread
            .update(cx, |thread, cx| thread.cancel(cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        fake_server
            .update(cx, |fake_server, _| {
                fake_server.send_to_zed(acp::UpdateToolCallParams {
                    tool_call_id: tool_call_id.borrow().unwrap(),
                    status: acp::ToolCallStatus::Finished,
                    content: None,
                })
            })
            .await
            .unwrap();

        drop(end_turn_tx);
        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Finished,
                        ..
                    },
                    ..
                })
            ));
        });
    }

    /// Waits until `thread` emits an event while it contains a tool call, and
    /// returns the index of the first tool-call entry.
    ///
    /// Panics if no tool call shows up within 10 seconds.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                let first_tool_call = thread
                    .read(cx)
                    .entries
                    .iter()
                    .position(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)));
                if let Some(ix) = first_tool_call {
                    // The callback fires on every thread event, but only the
                    // first index is consumed. Later sends may hit a full (or
                    // closed) bounded channel. That is harmless, so ignore it
                    // instead of panicking inside the subscription.
                    tx.try_send(ix).ok();
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                drop(subscription);
                ix.unwrap()
            }
        }
    }

    /// Builds an `AcpThread` wired to a `FakeAcpServer` through two in-memory
    /// pipes, so tests can script the agent side of the protocol.
    pub fn fake_acp_thread(
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> (Entity<AcpThread>, Entity<FakeAcpServer>) {
        let (stdin_tx, stdin_rx) = async_pipe::pipe();
        let (stdout_tx, stdout_rx) = async_pipe::pipe();

        let thread = cx.new(|cx| {
            let foreground_executor = cx.foreground_executor().clone();
            let (connection, io_fut) = acp::AgentConnection::connect_to_agent(
                AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
                stdin_tx,
                stdout_rx,
                move |fut| {
                    foreground_executor.spawn(fut).detach();
                },
            );

            let io_task = cx.background_spawn({
                async move {
                    io_fut.await.log_err();
                    Ok(())
                }
            });
            AcpThread::new(connection, "Test".into(), Some(io_task), project, cx)
        });
        let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
        (thread, agent)
    }

    /// Scriptable agent side of an ACP connection used in tests.
    pub struct FakeAcpServer {
        connection: acp::ClientConnection,

        _io_task: Task<()>,
        on_user_message: Option<
            Rc<
                dyn Fn(
                    acp::SendUserMessageParams,
                    Entity<FakeAcpServer>,
                    AsyncApp,
                ) -> LocalBoxFuture<'static, Result<(), acp::Error>>,
            >,
        >,
    }

    /// `acp::Agent` implementation that forwards user messages to the handler
    /// registered on its `FakeAcpServer`.
    #[derive(Clone)]
    struct FakeAgent {
        server: Entity<FakeAcpServer>,
        cx: AsyncApp,
    }

    impl acp::Agent for FakeAgent {
        async fn initialize(
            &self,
            params: acp::InitializeParams,
        ) -> Result<acp::InitializeResponse, acp::Error> {
            Ok(acp::InitializeResponse {
                protocol_version: params.protocol_version,
                is_authenticated: true,
            })
        }

        async fn authenticate(&self) -> Result<(), acp::Error> {
            Ok(())
        }

        async fn cancel_send_message(&self) -> Result<(), acp::Error> {
            Ok(())
        }

        async fn send_user_message(
            &self,
            request: acp::SendUserMessageParams,
        ) -> Result<(), acp::Error> {
            let mut cx = self.cx.clone();
            let handler = self
                .server
                .update(&mut cx, |server, _| server.on_user_message.clone())
                .ok()
                .flatten();
            if let Some(handler) = handler {
                handler(request, self.server.clone(), self.cx.clone()).await
            } else {
                Err(anyhow!("No handler for on_user_message").into())
            }
        }
    }

    impl FakeAcpServer {
        fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
            let agent = FakeAgent {
                server: cx.entity(),
                cx: cx.to_async(),
            };
            let foreground_executor = cx.foreground_executor().clone();

            let (connection, io_fut) = acp::ClientConnection::connect_to_client(
                agent.clone(),
                stdout,
                stdin,
                move |fut| {
                    foreground_executor.spawn(fut).detach();
                },
            );
            FakeAcpServer {
                connection,
                on_user_message: None,
                _io_task: cx.background_spawn(async move {
                    io_fut.await.log_err();
                }),
            }
        }

        /// Registers the handler invoked for every `send_user_message` request,
        /// replacing any previously registered handler.
        fn on_user_message<F>(
            &mut self,
            handler: impl Fn(acp::SendUserMessageParams, Entity<FakeAcpServer>, AsyncApp) -> F
            + 'static,
        ) where
            F: Future<Output = Result<(), acp::Error>> + 'static,
        {
            self.on_user_message
                .replace(Rc::new(move |request, server, cx| {
                    handler(request, server, cx).boxed_local()
                }));
        }

        /// Sends a client request to Zed over the connection, converting the
        /// protocol error into `anyhow::Error`.
        fn send_to_zed<T: acp::ClientRequest + 'static>(
            &self,
            message: T,
        ) -> BoxedLocal<Result<T::Response>> {
            self.connection
                .request(message)
                .map(|f| f.map_err(|err| anyhow!(err)))
                .boxed_local()
        }
    }
}